epgsql_cmd_connect.erl 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
%%% Special kind of command: it is exclusive, so no other commands can run
%%% until this one finishes.
%%% It also uses some "private" epgsql_sock APIs.
%%%
  5. -module(epgsql_cmd_connect).
  6. -behaviour(epgsql_command).
  7. -export([hide_password/1, opts_hide_password/1]).
  8. -export([init/1, execute/2, handle_message/4]).
  9. -export_type([response/0, connect_error/0]).
  10. -type response() :: connected
  11. | {error, connect_error()}.
  12. -type connect_error() ::
  13. invalid_authorization_specification
  14. | invalid_password
  15. | {unsupported_auth_method,
  16. kerberosV5 | crypt | scm | gss | sspi | {unknown, integer()} | {sasl, [binary()]}}
  17. | {sasl_server_final, any()}
  18. | epgsql:query_error().
  19. -include("epgsql.hrl").
  20. -include("protocol.hrl").
  21. -type auth_fun() :: fun((init | binary(), _, _) ->
  22. {send, byte(), iodata(), any()}
  23. | ok
  24. | {error, any()}
  25. | unknown).
  26. -record(connect,
  27. {opts :: map(),
  28. auth_fun :: auth_fun() | undefined,
  29. auth_state :: any() | undefined,
  30. auth_send :: {integer(), iodata()} | undefined,
  31. stage = connect :: connect | maybe_auth | auth | initialization}).
  32. -define(SCRAM_AUTH_METHOD, <<"SCRAM-SHA-256">>).
  33. -define(AUTH_OK, 0).
  34. -define(AUTH_CLEARTEXT, 3).
  35. -define(AUTH_MD5, 5).
  36. -define(AUTH_SASL, 10).
  37. -define(AUTH_SASL_CONTINUE, 11).
  38. -define(AUTH_SASL_FINAL, 12).
  39. init(#{host := _, username := _} = Opts) ->
  40. #connect{opts = Opts}.
  41. execute(PgSock, #connect{opts = #{host := Host} = Opts, stage = connect} = State) ->
  42. Timeout = maps:get(timeout, Opts, 5000),
  43. Port = maps:get(port, Opts, 5432),
  44. SockOpts = [{active, false}, {packet, raw}, binary, {nodelay, true}, {keepalive, true}],
  45. case gen_tcp:connect(Host, Port, SockOpts, Timeout) of
  46. {ok, Sock} ->
  47. client_handshake(Sock, PgSock, State);
  48. {error, Reason} = Error ->
  49. {stop, Reason, Error, PgSock}
  50. end;
  51. execute(PgSock, #connect{stage = auth, auth_send = {PacketId, Data}} = St) ->
  52. epgsql_sock:send(PgSock, PacketId, Data),
  53. {ok, PgSock, St#connect{auth_send = undefined}}.
  54. client_handshake(Sock, PgSock, #connect{opts = #{username := Username} = Opts} = State) ->
  55. %% Increase the buffer size. Following the recommendation in the inet man page:
  56. %%
  57. %% It is recommended to have val(buffer) >=
  58. %% max(val(sndbuf),val(recbuf)).
  59. {ok, [{recbuf, RecBufSize}, {sndbuf, SndBufSize}]} =
  60. inet:getopts(Sock, [recbuf, sndbuf]),
  61. inet:setopts(Sock, [{buffer, max(RecBufSize, SndBufSize)}]),
  62. case maybe_ssl(Sock, maps:get(ssl, Opts, false), Opts, PgSock) of
  63. {error, Reason} ->
  64. {stop, Reason, {error, Reason}, PgSock};
  65. PgSock1 ->
  66. Opts2 = ["user", 0, Username, 0],
  67. Opts3 = case maps:find(database, Opts) of
  68. error -> Opts2;
  69. {ok, Database} -> [Opts2 | ["database", 0, Database, 0]]
  70. end,
  71. {Opts4, PgSock2} =
  72. case Opts of
  73. #{replication := Replication} ->
  74. {[Opts3 | ["replication", 0, Replication, 0]],
  75. epgsql_sock:init_replication_state(PgSock1)};
  76. _ -> {Opts3, PgSock1}
  77. end,
  78. epgsql_sock:send(PgSock2, [<<196608:?int32>>, Opts4, 0]),
  79. PgSock3 = case Opts of
  80. #{async := Async} ->
  81. epgsql_sock:set_attr(async, Async, PgSock2);
  82. _ -> PgSock2
  83. end,
  84. {ok, PgSock3, State#connect{stage = maybe_auth}}
  85. end.
  86. %% @doc Replace `password' in Opts map with obfuscated one
  87. opts_hide_password(#{password := Password} = Opts) ->
  88. HiddenPassword = hide_password(Password),
  89. Opts#{password => HiddenPassword};
  90. opts_hide_password(Opts) -> Opts.
  91. %% @doc this function wraps plaintext password to a lambda function, so, if
  92. %% epgsql_sock process crashes when executing `connect` command, password will
  93. %% not appear in a crash log
  94. -spec hide_password(iodata()) -> fun( () -> iodata() ).
  95. hide_password(Password) when is_list(Password);
  96. is_binary(Password) ->
  97. fun() ->
  98. Password
  99. end;
  100. hide_password(PasswordFun) when is_function(PasswordFun, 0) ->
  101. PasswordFun.
  102. maybe_ssl(S, false, _, PgSock) ->
  103. epgsql_sock:set_net_socket(gen_tcp, S, PgSock);
  104. maybe_ssl(S, Flag, Opts, PgSock) ->
  105. ok = gen_tcp:send(S, <<8:?int32, 80877103:?int32>>),
  106. Timeout = maps:get(timeout, Opts, 5000),
  107. {ok, <<Code>>} = gen_tcp:recv(S, 1, Timeout),
  108. case Code of
  109. $S ->
  110. SslOpts = maps:get(ssl_opts, Opts, []),
  111. case ssl:connect(S, SslOpts, Timeout) of
  112. {ok, S2} ->
  113. epgsql_sock:set_net_socket(ssl, S2, PgSock);
  114. {error, Reason} ->
  115. Err = {ssl_negotiation_failed, Reason},
  116. {error, Err}
  117. end;
  118. $N ->
  119. case Flag of
  120. true ->
  121. epgsql_sock:set_net_socket(gen_tcp, S, PgSock);
  122. required ->
  123. {error, ssl_not_available}
  124. end
  125. end.
  126. %% Auth sub-protocol
  127. auth_init(<<?AUTH_CLEARTEXT:?int32>>, Sock, St) ->
  128. auth_init(fun auth_cleartext/3, undefined, Sock, St);
  129. auth_init(<<?AUTH_MD5:?int32, Salt:4/binary>>, Sock, St) ->
  130. auth_init(fun auth_md5/3, Salt, Sock, St);
  131. auth_init(<<?AUTH_SASL:?int32, MethodsB/binary>>, Sock, St) ->
  132. Methods = epgsql_wire:decode_strings(MethodsB),
  133. case lists:member(?SCRAM_AUTH_METHOD, Methods) of
  134. true ->
  135. auth_init(fun auth_scram/3, undefined, Sock, St);
  136. false ->
  137. {stop, normal, {error, {unsupported_auth_method,
  138. {sasl, lists:delete(<<>>, Methods)}}}}
  139. end;
  140. auth_init(<<M:?int32, _/binary>>, Sock, _St) ->
  141. Method = case M of
  142. 2 -> kerberosV5;
  143. 4 -> crypt;
  144. 6 -> scm;
  145. 7 -> gss;
  146. 8 -> sspi;
  147. _ -> {unknown, M}
  148. end,
  149. {stop, normal, {error, {unsupported_auth_method, Method}}, Sock}.
  150. auth_init(Fun, InitState, PgSock, St) ->
  151. auth_handle(init, PgSock, St#connect{auth_fun = Fun, auth_state = InitState,
  152. stage = auth}).
  153. auth_handle(Data, PgSock, #connect{auth_fun = Fun, auth_state = AuthSt} = St) ->
  154. case Fun(Data, AuthSt, St) of
  155. {send, SendPacketId, SendData, AuthSt1} ->
  156. {requeue, PgSock, St#connect{auth_state = AuthSt1,
  157. auth_send = {SendPacketId, SendData}}};
  158. ok -> {noaction, PgSock, St};
  159. {error, Reason} ->
  160. {stop, normal, {error, Reason}};
  161. unknown -> unknown
  162. end.
  163. %% AuthenticationCleartextPassword
  164. auth_cleartext(init, _AuthState, #connect{opts = Opts}) ->
  165. Password = get_password(Opts),
  166. {send, ?PASSWORD, [Password, 0], undefined};
  167. auth_cleartext(_, _, _) -> unknown.
  168. %% AuthenticationMD5Password
  169. auth_md5(init, Salt, #connect{opts = Opts}) ->
  170. User = maps:get(username, Opts),
  171. Password = get_password(Opts),
  172. Digest1 = hex(erlang:md5([Password, User])),
  173. Str = ["md5", hex(erlang:md5([Digest1, Salt])), 0],
  174. {send, ?PASSWORD, Str, undefined};
  175. auth_md5(_, _, _) -> unknown.
  176. %% AuthenticationSASL
  177. auth_scram(init, undefined, #connect{opts = Opts}) ->
  178. User = maps:get(username, Opts),
  179. Nonce = epgsql_scram:get_nonce(16),
  180. ClientFirst = epgsql_scram:get_client_first(User, Nonce),
  181. SaslInitialResponse = [?SCRAM_AUTH_METHOD, 0, <<(iolist_size(ClientFirst)):?int32>>, ClientFirst],
  182. {send, ?SASL_ANY_RESPONSE, SaslInitialResponse, {auth_request, Nonce}};
  183. auth_scram(<<?AUTH_SASL_CONTINUE:?int32, ServerFirst/binary>>, {auth_request, Nonce}, #connect{opts = Opts}) ->
  184. User = maps:get(username, Opts),
  185. Password = get_password(Opts),
  186. ServerFirstParts = epgsql_scram:parse_server_first(ServerFirst, Nonce),
  187. {ClientFinalMessage, ServerProof} = epgsql_scram:get_client_final(ServerFirstParts, Nonce, User, Password),
  188. {send, ?SASL_ANY_RESPONSE, ClientFinalMessage, {server_final, ServerProof}};
  189. auth_scram(<<?AUTH_SASL_FINAL:?int32, ServerFinalMsg/binary>>, {server_final, ServerProof}, _Conn) ->
  190. case epgsql_scram:parse_server_final(ServerFinalMsg) of
  191. {ok, ServerProof} -> ok;
  192. Other -> {error, {sasl_server_final, Other}}
  193. end;
  194. auth_scram(_, _, _) ->
  195. unknown.
  196. %% --- Auth ---
  197. %% AuthenticationOk
  198. handle_message(?AUTHENTICATION_REQUEST, <<?AUTH_OK:?int32>>, Sock, State) ->
  199. {noaction, Sock, State#connect{stage = initialization,
  200. auth_fun = undefined,
  201. auth_state = undefined,
  202. auth_send = undefined}};
  203. handle_message(?AUTHENTICATION_REQUEST, Message, Sock, #connect{stage = Stage} = St) when Stage =/= auth ->
  204. auth_init(Message, Sock, St);
  205. handle_message(?AUTHENTICATION_REQUEST, Packet, Sock, #connect{stage = auth} = St) ->
  206. auth_handle(Packet, Sock, St);
  207. %% --- Initialization ---
  208. %% BackendKeyData
  209. handle_message(?CANCELLATION_KEY, <<Pid:?int32, Key:?int32>>, Sock, _State) ->
  210. {noaction, epgsql_sock:set_attr(backend, {Pid, Key}, Sock)};
  211. %% ReadyForQuery
  212. handle_message(?READY_FOR_QUERY, _, Sock, #connect{opts = Opts}) ->
  213. CodecOpts = maps:with([nulls], Opts),
  214. Codec = epgsql_binary:new_codec(Sock, CodecOpts),
  215. Sock1 = epgsql_sock:set_attr(codec, Codec, Sock),
  216. {finish, connected, connected, Sock1};
  217. %% ErrorResponse
  218. handle_message(?ERROR, #error{code = Code} = Err, Sock, #connect{stage = Stage} = _State) ->
  219. IsAuthStage = (Stage == auth) orelse (Stage == maybe_auth),
  220. Why = case Code of
  221. <<"28000">> when IsAuthStage ->
  222. invalid_authorization_specification;
  223. <<"28P01">> when IsAuthStage ->
  224. invalid_password;
  225. _ ->
  226. Err
  227. end,
  228. {stop, normal, {error, Why}, Sock};
  229. handle_message(_, _, _, _) ->
  230. unknown.
  231. get_password(Opts) ->
  232. PasswordFun = maps:get(password, Opts),
  233. PasswordFun().
  234. hex(Bin) ->
  235. HChar = fun(N) when N < 10 -> $0 + N;
  236. (N) when N < 16 -> $W + N
  237. end,
  238. <<<<(HChar(H)), (HChar(L))>> || <<H:4, L:4>> <= Bin>>.