epgsql_cmd_connect.erl 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. %%% Special kind of command - it's exclusive: no other commands can run until
  2. %%% this one finishes.
  3. %%% It also uses some 'private' epgsql_sock's APIs
  4. %%%
  5. -module(epgsql_cmd_connect).
  6. -behaviour(epgsql_command).
  7. -export([init/1, execute/2, handle_message/4]).
  8. -export_type([response/0, connect_error/0]).
  9. -type response() :: connected
  10. | {error, connect_error()}.
  11. -type connect_error() ::
  12. invalid_authorization_specification
  13. | invalid_password
  14. | {unsupported_auth_method,
  15. kerberosV5 | crypt | scm | gss | sspi | {unknown, integer()} | {sasl, [binary()]}}
  16. | {sasl_server_final, any()}
  17. | epgsql:query_error().
  18. -include("epgsql.hrl").
  19. -include("protocol.hrl").
  20. -type auth_fun() :: fun((init | binary(), _, _) ->
  21. {send, byte(), iodata(), any()}
  22. | ok
  23. | {error, any()}
  24. | unknown).
  25. -record(connect,
  26. {opts :: list(),
  27. auth_fun :: auth_fun() | undefined,
  28. auth_state :: any() | undefined,
  29. auth_send :: {integer(), iodata()} | undefined,
  30. stage = connect :: connect | maybe_auth | auth | initialization}).
  31. -define(SCRAM_AUTH_METHOD, <<"SCRAM-SHA-256">>).
  32. -define(AUTH_OK, 0).
  33. -define(AUTH_CLEARTEXT, 3).
  34. -define(AUTH_MD5, 5).
  35. -define(AUTH_SASL, 10).
  36. -define(AUTH_SASL_CONTINUE, 11).
  37. -define(AUTH_SASL_FINAL, 12).
  38. init({Host, Username, Password, Opts}) ->
  39. Opts1 = [{host, Host},
  40. {username, Username},
  41. {password, Password}
  42. | Opts],
  43. #connect{opts = Opts1}.
  44. execute(PgSock, #connect{opts = Opts, stage = connect} = State) ->
  45. Host = get_val(host, Opts),
  46. Username = get_val(username, Opts),
  47. %% _ = get_val(password, Opts),
  48. Timeout = proplists:get_value(timeout, Opts, 5000),
  49. Port = proplists:get_value(port, Opts, 5432),
  50. SockOpts = [{active, false}, {packet, raw}, binary, {nodelay, true}, {keepalive, true}],
  51. case gen_tcp:connect(Host, Port, SockOpts, Timeout) of
  52. {ok, Sock} ->
  53. %% Increase the buffer size. Following the recommendation in the inet man page:
  54. %%
  55. %% It is recommended to have val(buffer) >=
  56. %% max(val(sndbuf),val(recbuf)).
  57. {ok, [{recbuf, RecBufSize}, {sndbuf, SndBufSize}]} =
  58. inet:getopts(Sock, [recbuf, sndbuf]),
  59. inet:setopts(Sock, [{buffer, max(RecBufSize, SndBufSize)}]),
  60. PgSock1 = maybe_ssl(Sock, proplists:get_value(ssl, Opts, false), Opts, PgSock),
  61. Opts2 = ["user", 0, Username, 0],
  62. Opts3 = case proplists:get_value(database, Opts, undefined) of
  63. undefined -> Opts2;
  64. Database -> [Opts2 | ["database", 0, Database, 0]]
  65. end,
  66. Replication = proplists:get_value(replication, Opts, undefined),
  67. Opts4 = case Replication of
  68. undefined -> Opts3;
  69. Replication ->
  70. [Opts3 | ["replication", 0, Replication, 0]]
  71. end,
  72. PgSock2 = case Replication of
  73. undefined -> PgSock1;
  74. _ -> epgsql_sock:init_replication_state(PgSock1)
  75. end,
  76. epgsql_sock:send(PgSock2, [<<196608:?int32>>, Opts4, 0]),
  77. PgSock3 = case proplists:get_value(async, Opts, undefined) of
  78. undefined -> PgSock2;
  79. Async -> epgsql_sock:set_attr(async, Async, PgSock2)
  80. end,
  81. {ok, PgSock3, State#connect{stage = maybe_auth}};
  82. {error, Reason} = Error ->
  83. {stop, Reason, Error, PgSock}
  84. end;
  85. execute(PgSock, #connect{stage = auth, auth_send = {PacketId, Data}} = St) ->
  86. epgsql_sock:send(PgSock, PacketId, Data),
  87. {ok, PgSock, St#connect{auth_send = undefined}}.
  88. maybe_ssl(S, false, _, PgSock) ->
  89. epgsql_sock:set_net_socket(gen_tcp, S, PgSock);
  90. maybe_ssl(S, Flag, Opts, PgSock) ->
  91. ok = gen_tcp:send(S, <<8:?int32, 80877103:?int32>>),
  92. Timeout = proplists:get_value(timeout, Opts, 5000),
  93. {ok, <<Code>>} = gen_tcp:recv(S, 1, Timeout),
  94. case Code of
  95. $S ->
  96. SslOpts = proplists:get_value(ssl_opts, Opts, []),
  97. case ssl:connect(S, SslOpts, Timeout) of
  98. {ok, S2} ->
  99. epgsql_sock:set_net_socket(ssl, S2, PgSock);
  100. {error, Reason} ->
  101. exit({ssl_negotiation_failed, Reason})
  102. end;
  103. $N ->
  104. case Flag of
  105. true ->
  106. epgsql_sock:set_net_socket(gen_tcp, S, PgSock);
  107. required ->
  108. exit(ssl_not_available)
  109. end
  110. end.
  111. %% Auth sub-protocol
  112. auth_init(<<?AUTH_CLEARTEXT:?int32>>, Sock, St) ->
  113. auth_init(fun auth_cleartext/3, undefined, Sock, St);
  114. auth_init(<<?AUTH_MD5:?int32, Salt:4/binary>>, Sock, St) ->
  115. auth_init(fun auth_md5/3, Salt, Sock, St);
  116. auth_init(<<?AUTH_SASL:?int32, MethodsB/binary>>, Sock, St) ->
  117. Methods = epgsql_wire:decode_strings(MethodsB),
  118. case lists:member(?SCRAM_AUTH_METHOD, Methods) of
  119. true ->
  120. auth_init(fun auth_scram/3, undefined, Sock, St);
  121. false ->
  122. {stop, normal, {error, {unsupported_auth_method,
  123. {sasl, lists:delete(<<>>, Methods)}}}}
  124. end;
  125. auth_init(<<M:?int32, _/binary>>, Sock, _St) ->
  126. Method = case M of
  127. 2 -> kerberosV5;
  128. 4 -> crypt;
  129. 6 -> scm;
  130. 7 -> gss;
  131. 8 -> sspi;
  132. _ -> {unknown, M}
  133. end,
  134. {stop, normal, {error, {unsupported_auth_method, Method}}, Sock}.
  135. auth_init(Fun, InitState, PgSock, St) ->
  136. auth_handle(init, PgSock, St#connect{auth_fun = Fun, auth_state = InitState,
  137. stage = auth}).
  138. auth_handle(Data, PgSock, #connect{auth_fun = Fun, auth_state = AuthSt} = St) ->
  139. case Fun(Data, AuthSt, St) of
  140. {send, SendPacketId, SendData, AuthSt1} ->
  141. {requeue, PgSock, St#connect{auth_state = AuthSt1,
  142. auth_send = {SendPacketId, SendData}}};
  143. ok -> {noaction, PgSock, St};
  144. {error, Reason} ->
  145. {stop, normal, {error, Reason}};
  146. unknown -> unknown
  147. end.
  148. %% AuthenticationCleartextPassword
  149. auth_cleartext(init, _AuthState, #connect{opts = Opts}) ->
  150. Password = get_val(password, Opts),
  151. {send, ?PASSWORD, [Password, 0], undefined};
  152. auth_cleartext(_, _, _) -> unknown.
  153. %% AuthenticationMD5Password
  154. auth_md5(init, Salt, #connect{opts = Opts}) ->
  155. User = get_val(username, Opts),
  156. Password = get_val(password, Opts),
  157. Digest1 = hex(erlang:md5([Password, User])),
  158. Str = ["md5", hex(erlang:md5([Digest1, Salt])), 0],
  159. {send, ?PASSWORD, Str, undefined};
  160. auth_md5(_, _, _) -> unknown.
  161. %% AuthenticationSASL
  162. auth_scram(init, undefined, #connect{opts = Opts}) ->
  163. User = get_val(username, Opts),
  164. Nonce = epgsql_scram:get_nonce(16),
  165. ClientFirst = epgsql_scram:get_client_first(User, Nonce),
  166. SaslInitialResponse = [?SCRAM_AUTH_METHOD, 0, <<(iolist_size(ClientFirst)):?int32>>, ClientFirst],
  167. {send, ?SASL_ANY_RESPONSE, SaslInitialResponse, {auth_request, Nonce}};
  168. auth_scram(<<?AUTH_SASL_CONTINUE:?int32, ServerFirst/binary>>, {auth_request, Nonce}, #connect{opts = Opts}) ->
  169. User = get_val(username, Opts),
  170. Password = get_val(password, Opts),
  171. ServerFirstParts = epgsql_scram:parse_server_first(ServerFirst, Nonce),
  172. {ClientFinalMessage, ServerProof} = epgsql_scram:get_client_final(ServerFirstParts, Nonce, User, Password),
  173. {send, ?SASL_ANY_RESPONSE, ClientFinalMessage, {server_final, ServerProof}};
  174. auth_scram(<<?AUTH_SASL_FINAL:?int32, ServerFinalMsg/binary>>, {server_final, ServerProof}, _Conn) ->
  175. case epgsql_scram:parse_server_final(ServerFinalMsg) of
  176. {ok, ServerProof} -> ok;
  177. Other -> {error, {sasl_server_final, Other}}
  178. end;
  179. auth_scram(_, _, _) ->
  180. unknown.
  181. %% --- Auth ---
  182. %% AuthenticationOk
  183. handle_message(?AUTHENTICATION_REQUEST, <<?AUTH_OK:?int32>>, Sock, State) ->
  184. {noaction, Sock, State#connect{stage = initialization,
  185. auth_fun = undefined,
  186. auth_state = undefned,
  187. auth_send = undefined}};
  188. handle_message(?AUTHENTICATION_REQUEST, Message, Sock, #connect{stage = Stage} = St) when Stage =/= auth ->
  189. auth_init(Message, Sock, St);
  190. handle_message(?AUTHENTICATION_REQUEST, Packet, Sock, #connect{stage = auth} = St) ->
  191. auth_handle(Packet, Sock, St);
  192. %% --- Initialization ---
  193. %% BackendKeyData
  194. handle_message(?CANCELLATION_KEY, <<Pid:?int32, Key:?int32>>, Sock, _State) ->
  195. {noaction, epgsql_sock:set_attr(backend, {Pid, Key}, Sock)};
  196. %% ReadyForQuery
  197. handle_message(?READY_FOR_QUERY, _, Sock, _State) ->
  198. Codec = epgsql_binary:new_codec(Sock, []),
  199. Sock1 = epgsql_sock:set_attr(codec, Codec, Sock),
  200. {finish, connected, connected, Sock1};
  201. %% ErrorResponse
  202. handle_message(?ERROR, Err, Sock, #connect{stage = Stage} = _State) when Stage == auth;
  203. Stage == maybe_auth ->
  204. Why = case Err#error.code of
  205. <<"28000">> -> invalid_authorization_specification;
  206. <<"28P01">> -> invalid_password;
  207. Any -> Any
  208. end,
  209. {stop, normal, {error, Why}, Sock};
  210. handle_message(_, _, _, _) ->
  211. unknown.
  212. get_val(Key, Proplist) ->
  213. Val = proplists:get_value(Key, Proplist),
  214. (Val =/= undefined) orelse error({required_option, Key}),
  215. Val.
  216. hex(Bin) ->
  217. HChar = fun(N) when N < 10 -> $0 + N;
  218. (N) when N < 16 -> $W + N
  219. end,
  220. <<<<(HChar(H)), (HChar(L))>> || <<H:4, L:4>> <= Bin>>.