epgsql_cmd_connect.erl 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  %%% @doc Connects to the server and performs all the necessary handshakes.
  %%%
  %%% This is a special kind of command: it is exclusive, so no other command
  %%% can run until it finishes. It also uses some "private" epgsql_sock APIs.
  %%%
  -module(epgsql_cmd_connect).
  -behaviour(epgsql_command).

  %% Helpers exported for use outside this command
  -export([hide_password/1, opts_hide_password/1, open_socket/2]).
  %% epgsql_command callbacks
  -export([init/1, execute/2, handle_message/4]).
  -export_type([response/0, connect_error/0]).

  -type response() :: connected
                    | {error, connect_error()}.

  %% Reasons the connect command can fail with. Besides authentication and
  %% server errors, this includes the failures of open_socket/2, which
  %% execute/2 returns unchanged.
  -type connect_error() ::
          invalid_authorization_specification
        | invalid_password
        | {unsupported_auth_method,
           kerberosV5 | crypt | scm | gss | sspi | {unknown, integer()} | {sasl, [binary()]}}
        | {sasl_server_final, any()}
        | ssl_not_available
        | {ssl_negotiation_failed, any()}
        | closed
        | timeout
        | inet:posix()
        | epgsql:query_error().

  -include("epgsql.hrl").
  -include("protocol.hrl").

  %% One authentication step. It is called with `init' (or a server auth
  %% packet), the current auth state and the #connect{} record. It returns a
  %% packet to send plus the new auth state, `ok' (nothing to send),
  %% `{error, _}', or `unknown' (the packet is not understood).
  -type auth_fun() :: fun((init | binary(), _, _) ->
                                 {send, byte(), iodata(), any()}
                               | ok
                               | {error, any()}
                               | unknown).

  %% Command state.
  -record(connect,
          {opts :: map(),                               % connect options
           auth_fun :: auth_fun() | undefined,          % active auth step function
           auth_state :: any() | undefined,             % state threaded through auth_fun
           auth_send :: {integer(), iodata()} | undefined, % packet queued for execute/2
           stage = connect :: connect | maybe_auth | auth | initialization}).

  %% The only SASL mechanism implemented here
  -define(SCRAM_AUTH_METHOD, <<"SCRAM-SHA-256">>).
  %% Authentication request codes: the leading int32 of an AuthenticationXXX body
  -define(AUTH_OK, 0).
  -define(AUTH_CLEARTEXT, 3).
  -define(AUTH_MD5, 5).
  -define(AUTH_SASL, 10).
  -define(AUTH_SASL_CONTINUE, 11).
  -define(AUTH_SASL_FINAL, 12).
  %% @doc Create the initial command state. `host' and `username' are
  %% mandatory connect options; any other map shape fails with function_clause.
  init(#{username := _, host := _} = ConnectOpts) ->
      #connect{opts = ConnectOpts}.
  %% @doc Run one step of the connect command.
  %%
  %% `connect' stage:
  %%   1. Store the options, with credentials removed, as the `connect_opts'
  %%      attribute.
  %%   2. Open the (optionally TLS) socket.
  %%   3. Send the StartupMessage (protocol version 196608, i.e. 3.0) with the
  %%      `user' parameter plus the optional `database', `replication' and
  %%      `application_name' parameters.
  %%   4. Apply `async'.
  %%   5. Move on to `maybe_auth'.
  %% If the socket cannot be opened, the command stops with the open_socket/2
  %% error.
  %%
  %% `auth' stage: send the packet that auth_handle/3 stored in `auth_send'.
  execute(PgSock, #connect{opts = #{username := Username} = Opts, stage = connect} = State) ->
      SockOpts = prepare_tcp_opts(maps:get(tcp_opts, Opts, [])),
      PgSock1 = epgsql_sock:set_attr(connect_opts, filter_sensitive_info(Opts), PgSock),
      case open_socket(SockOpts, Opts) of
          {error, Reason} = Error ->
              {stop, Reason, Error, PgSock};
          {ok, Mode, Sock} ->
              PgSock2 = epgsql_sock:set_net_socket(Mode, Sock, PgSock1),
              BaseParams = add_startup_param(database, "database", Opts,
                                             ["user", 0, Username, 0]),
              %% Replication also needs extra socket state, so it is not a
              %% plain add_startup_param/4 call.
              {ReplParams, PgSock3} =
                  case maps:find(replication, Opts) of
                      {ok, Replication} ->
                          {[BaseParams | ["replication", 0, Replication, 0]],
                           epgsql_sock:init_replication_state(PgSock2)};
                      error ->
                          {BaseParams, PgSock2}
                  end,
              Params = add_startup_param(application_name, "application_name",
                                         Opts, ReplParams),
              ok = epgsql_sock:send(PgSock3, [<<196608:?int32>>, Params, 0]),
              PgSock4 = case maps:find(async, Opts) of
                            {ok, Async} -> epgsql_sock:set_attr(async, Async, PgSock3);
                            error -> PgSock3
                        end,
              {ok, PgSock4, State#connect{stage = maybe_auth}}
      end;
  execute(PgSock, #connect{stage = auth, auth_send = {PacketType, Data}} = St) ->
      {send, PacketType, Data, PgSock, St#connect{auth_send = undefined}}.

  %% Append the startup parameter `Name\0Value\0' to Acc when Key is present in Opts.
  add_startup_param(Key, Name, Opts, Acc) ->
      case maps:find(Key, Opts) of
          {ok, Value} -> [Acc | [Name, 0, Value, 0]];
          error -> Acc
      end.
  %% @doc Open a TCP connection to `host':`port' (default port 5432) and run
  %% the client handshake, which may upgrade the connection to TLS.
  %%
  %% `timeout' (default 5000 ms) bounds the TCP connect. The same budget,
  %% measured from before the connect, is shared by the SSL negotiation.
  %% Returns `{ok, Mode, Socket}' on success, or `{error, Reason}'. If the
  %% handshake fails after the TCP connection was established, the TCP socket
  %% is closed before returning, so the calling process never keeps an
  %% orphaned socket.
  -spec open_socket([{atom(), any()}], epgsql:connect_opts()) ->
            {ok, gen_tcp | ssl, port() | ssl:sslsocket()} | {error, any()}.
  open_socket(SockOpts, #{host := Host} = ConnectOpts) ->
      Timeout = maps:get(timeout, ConnectOpts, 5000),
      Deadline = deadline(Timeout),
      Port = maps:get(port, ConnectOpts, 5432),
      case gen_tcp:connect(Host, Port, SockOpts, Timeout) of
          {ok, Sock} ->
              case client_handshake(Sock, ConnectOpts, Deadline) of
                  {ok, _Mode, _Socket} = Ok ->
                      Ok;
                  {error, _} = HandshakeError ->
                      %% The caller only receives a socket on success. Close the
                      %% TCP socket here so it does not stay open for as long as
                      %% the calling process lives. gen_tcp:close/1 is safe even
                      %% if the socket is already closed.
                      ok = gen_tcp:close(Sock),
                      HandshakeError
              end;
          {error, _Reason} = Error ->
              Error
      end.
  %% @doc Post-connect socket setup followed by optional SSL negotiation.
  %% The kernel buffer is enlarged only when the user supplied no `tcp_opts'.
  %% When `tcp_opts' is given, all TCP options are considered user-owned.
  client_handshake(Sock, ConnectOpts, Deadline) ->
      case maps:is_key(tcp_opts, ConnectOpts) of
          true -> noop;
          false -> grow_socket_buffer(Sock)
      end,
      SslFlag = maps:get(ssl, ConnectOpts, false),
      maybe_ssl(Sock, SslFlag, ConnectOpts, Deadline).

  %% Follow the inet man page advice: val(buffer) >= max(val(sndbuf), val(recbuf)).
  grow_socket_buffer(Sock) ->
      {ok, [{recbuf, RecBuf}, {sndbuf, SndBuf}]} = inet:getopts(Sock, [recbuf, sndbuf]),
      inet:setopts(Sock, [{buffer, max(RecBuf, SndBuf)}]).
  %% @doc Optionally upgrade a freshly connected socket to TLS.
  %%
  %% With Flag `false', no SSLRequest is sent. Otherwise an SSLRequest
  %% (length 8, code 80877103) is sent and the server's one-byte reply decides:
  %%   `S' - negotiate TLS with the `ssl_opts' connect option;
  %%   `N' - continue in plain TCP if Flag is `true', or fail with
  %%         `ssl_not_available' if Flag is `required'.
  %% Every wait uses the time remaining until Deadline.
  maybe_ssl(Sock, false, _ConnectOpts, _Deadline) ->
      {ok, gen_tcp, Sock};
  maybe_ssl(Sock, Flag, ConnectOpts, Deadline) ->
      SslRequest = <<8:?int32, 80877103:?int32>>,
      ok = gen_tcp:send(Sock, SslRequest),
      ReplyTimeout = timeout(Deadline),
      case gen_tcp:recv(Sock, 1, ReplyTimeout) of
          {ok, <<$S>>} ->
              start_tls(Sock, ConnectOpts, Deadline);
          {ok, <<$N>>} ->
              case Flag of
                  true -> {ok, gen_tcp, Sock};
                  required -> {error, ssl_not_available}
              end;
          {error, _} = RecvError ->
              RecvError
      end.

  %% Run the TLS client handshake on Sock after the server accepted the SSLRequest.
  start_tls(Sock, ConnectOpts, Deadline) ->
      SslOpts = maps:get(ssl_opts, ConnectOpts, []),
      HandshakeTimeout = timeout(Deadline),
      case ssl:connect(Sock, SslOpts, HandshakeTimeout) of
          {ok, SslSock} -> {ok, ssl, SslSock};
          {error, Why} -> {error, {ssl_negotiation_failed, Why}}
      end.
  %% @doc Return Opts with its `password' (if any) wrapped by hide_password/1.
  %% Options without a `password' key are returned unchanged.
  -spec opts_hide_password(map()) -> map().
  opts_hide_password(Opts) ->
      case maps:find(password, Opts) of
          {ok, Password} -> Opts#{password => hide_password(Password)};
          error -> Opts
      end.
  %% @doc Drop credentials (`username' and `password') from the connect options.
  %% The result is kept in long-lived state, which might crash during a code
  %% upgrade, so secrets must not live there.
  filter_sensitive_info(ConnectOpts) ->
      Sensitive = [username, password],
      maps:without(Sensitive, ConnectOpts).
  %% @doc Wrap a plaintext password in a zero-arity fun. If the epgsql_sock
  %% process crashes while executing the `connect' command, the password then
  %% does not appear in the crash log.
  %%
  %% A password that is already a zero-arity fun is returned unchanged. This
  %% makes the function idempotent and lets callers pass a lazy password
  %% themselves. The spec now admits that second input form.
  -spec hide_password(iodata() | fun(() -> iodata())) -> fun(() -> iodata()).
  hide_password(Password) when is_list(Password);
                               is_binary(Password) ->
      fun() -> Password end;
  hide_password(PasswordFun) when is_function(PasswordFun, 0) ->
      PasswordFun.
  %% Auth sub-protocol

  %% @doc Start an authentication exchange for an AuthenticationXXX request
  %% (Packet is the message body, beginning with the int32 request code).
  %%
  %% Cleartext, MD5 (with its 4-byte salt) and SASL with SCRAM-SHA-256 start the
  %% matching auth step function. Any other method stops the command with
  %% `{error, {unsupported_auth_method, Method}}'. All `stop' results use the
  %% epgsql_command 4-tuple `{stop, Reason, Response, PgSock}'.
  auth_init(<<?AUTH_CLEARTEXT:?int32>>, Sock, St) ->
      auth_init(fun auth_cleartext/3, undefined, Sock, St);
  auth_init(<<?AUTH_MD5:?int32, Salt:4/binary>>, Sock, St) ->
      auth_init(fun auth_md5/3, Salt, Sock, St);
  auth_init(<<?AUTH_SASL:?int32, MethodsB/binary>>, Sock, St) ->
      Methods = epgsql_wire:decode_strings(MethodsB),
      case lists:member(?SCRAM_AUTH_METHOD, Methods) of
          true ->
              auth_init(fun auth_scram/3, undefined, Sock, St);
          false ->
              %% Report the offered mechanisms; one empty string entry (if
              %% decode_strings/1 produced one) is dropped.
              {stop, normal,
               {error, {unsupported_auth_method, {sasl, lists:delete(<<>>, Methods)}}},
               Sock}
      end;
  auth_init(<<M:?int32, _/binary>>, Sock, _St) ->
      Method = case M of
                   2 -> kerberosV5;
                   4 -> crypt;
                   6 -> scm;
                   7 -> gss;
                   8 -> sspi;
                   _ -> {unknown, M}
               end,
      {stop, normal, {error, {unsupported_auth_method, Method}}, Sock}.
  %% @doc Install Fun as the auth step function with InitState, switch to the
  %% `auth' stage, and run the first step immediately with `init'.
  auth_init(Fun, InitState, PgSock, St) ->
      AuthSt = St#connect{stage = auth, auth_fun = Fun, auth_state = InitState},
      auth_handle(init, PgSock, AuthSt).
  %% @doc Feed Data (`init' or a server auth packet) to the current auth step
  %% function and translate its result into an epgsql_command reply:
  %%   {send, Id, Bytes, S1} - store the packet in `auth_send' and `requeue'.
  %%                           The `auth' clause of execute/2 sends it
  %%                           (presumably epgsql_sock re-runs execute/2 on
  %%                           requeue; confirm there).
  %%   ok                    - nothing to send; wait for the next message.
  %%   {error, Reason}       - stop the command with {error, Reason}, using the
  %%                           4-tuple `{stop, Reason, Response, PgSock}' form.
  %%   unknown               - the packet is not understood.
  auth_handle(Data, PgSock, #connect{auth_fun = Fun, auth_state = AuthSt} = St) ->
      case Fun(Data, AuthSt, St) of
          {send, SendPacketId, SendData, AuthSt1} ->
              {requeue, PgSock, St#connect{auth_state = AuthSt1,
                                           auth_send = {SendPacketId, SendData}}};
          ok ->
              {noaction, PgSock, St};
          {error, Reason} ->
              {stop, normal, {error, Reason}, PgSock};
          unknown ->
              unknown
      end.
  %% AuthenticationCleartextPassword: answer `init' with a PasswordMessage
  %% carrying the NUL-terminated plaintext password. Any later packet is
  %% not understood.
  auth_cleartext(init, _AuthState, #connect{opts = Opts}) ->
      PasswordMsg = [get_password(Opts), 0],
      {send, ?PASSWORD, PasswordMsg, undefined};
  auth_cleartext(_Data, _AuthState, _Conn) ->
      unknown.
  %% AuthenticationMD5Password: on `init', reply with
  %% "md5" ++ hex(md5(hex(md5(Password ++ User)) ++ Salt)), NUL-terminated.
  %% The auth state is the 4-byte salt from the server. Any later packet is
  %% not understood.
  auth_md5(init, Salt, #connect{opts = Opts}) ->
      User = maps:get(username, Opts),
      Password = get_password(Opts),
      InnerDigest = hex(erlang:md5([Password, User])),
      OuterDigest = hex(erlang:md5([InnerDigest, Salt])),
      {send, ?PASSWORD, ["md5", OuterDigest, 0], undefined};
  auth_md5(_Data, _AuthState, _Conn) ->
      unknown.
  %% AuthenticationSASL (SCRAM-SHA-256). The auth state threads through
  %% three steps:
  %%   init, undefined             -> send the SASLInitialResponse (mechanism
  %%                                  name, int32 length, client-first);
  %%                                  new state {auth_request, Nonce}
  %%   SASLContinue, auth_request  -> send client-final;
  %%                                  new state {server_final, ExpectedProof}
  %%   SASLFinal, server_final     -> ok if parse_server_final/1 yields
  %%                                  ExpectedProof, else {error, _}
  %% Anything else is not understood.
  auth_scram(init, undefined, #connect{opts = Opts}) ->
      User = maps:get(username, Opts),
      Nonce = epgsql_scram:get_nonce(16),
      ClientFirst = epgsql_scram:get_client_first(User, Nonce),
      ClientFirstLen = iolist_size(ClientFirst),
      InitialResponse = [?SCRAM_AUTH_METHOD, 0, <<ClientFirstLen:?int32>>, ClientFirst],
      {send, ?SASL_ANY_RESPONSE, InitialResponse, {auth_request, Nonce}};
  auth_scram(<<?AUTH_SASL_CONTINUE:?int32, ServerFirst/binary>>, {auth_request, Nonce},
             #connect{opts = Opts}) ->
      User = maps:get(username, Opts),
      Password = get_password(Opts),
      Parsed = epgsql_scram:parse_server_first(ServerFirst, Nonce),
      {ClientFinal, ExpectedProof} =
          epgsql_scram:get_client_final(Parsed, Nonce, User, Password),
      {send, ?SASL_ANY_RESPONSE, ClientFinal, {server_final, ExpectedProof}};
  auth_scram(<<?AUTH_SASL_FINAL:?int32, ServerFinal/binary>>, {server_final, ExpectedProof},
             _Conn) ->
      case epgsql_scram:parse_server_final(ServerFinal) of
          {ok, ExpectedProof} -> ok;
          Mismatch -> {error, {sasl_server_final, Mismatch}}
      end;
  auth_scram(_Data, _AuthState, _Conn) ->
      unknown.
  %% --- Auth ---
  %% AuthenticationOk: credentials accepted. Clear all auth bookkeeping and
  %% enter the `initialization' stage.
  handle_message(?AUTHENTICATION_REQUEST, <<?AUTH_OK:?int32>>, Sock, State) ->
      Cleared = State#connect{stage = initialization,
                              auth_fun = undefined,
                              auth_state = undefined,
                              auth_send = undefined},
      {noaction, Sock, Cleared};
  %% Any other authentication request: continue the running exchange when
  %% already in `auth', otherwise start a new one.
  handle_message(?AUTHENTICATION_REQUEST, Packet, Sock, #connect{stage = auth} = St) ->
      auth_handle(Packet, Sock, St);
  handle_message(?AUTHENTICATION_REQUEST, Packet, Sock, #connect{} = St) ->
      auth_init(Packet, Sock, St);
  %% --- Initialization ---
  %% BackendKeyData: store the backend's {Pid, Key} cancellation key.
  handle_message(?CANCELLATION_KEY, <<Pid:?int32, Key:?int32>>, Sock, _State) ->
      {noaction, epgsql_sock:set_attr(backend, {Pid, Key}, Sock)};
  %% ReadyForQuery: install the codec (only the `nulls' option applies) and
  %% finish with `connected'.
  handle_message(?READY_FOR_QUERY, _, Sock, #connect{opts = Opts}) ->
      Codec = epgsql_binary:new_codec(Sock, maps:with([nulls], Opts)),
      {finish, connected, connected, epgsql_sock:set_attr(codec, Codec, Sock)};
  %% ErrorResponse: during authentication, SQLSTATE 28000 and 28P01 become
  %% friendly atoms. Any other error is passed through as the #error{} record.
  handle_message(?ERROR, #error{code = Code} = Err, Sock, #connect{stage = Stage}) ->
      InAuth = lists:member(Stage, [auth, maybe_auth]),
      Why = case {InAuth, Code} of
                {true, <<"28000">>} -> invalid_authorization_specification;
                {true, <<"28P01">>} -> invalid_password;
                {_, _} -> Err
            end,
      {stop, normal, {error, Why}, Sock};
  handle_message(_, _, _, _) ->
      unknown.
  %% @doc Build the gen_tcp option list for the connection.
  %%
  %% With no user options, the defaults are passive, raw, binary, nodelay and
  %% keepalive. User options are accepted as long as none of them sets data
  %% mode, packet framing/size, header or active mode. `{active, false}',
  %% `{packet, raw}' and `{mode, binary}' are then prepended. Otherwise this
  %% raises error({forbidden_tcp_opts, Forbidden}).
  prepare_tcp_opts([]) ->
      [{active, false}, {packet, raw}, {mode, binary}, {nodelay, true}, {keepalive, true}];
  prepare_tcp_opts(Opts0) ->
      ForbiddenKeys = [mode, packet, packet_size, header, active],
      IsForbidden = fun(binary) -> true;
                       (list) -> true;
                       ({Key, _Value}) -> lists:member(Key, ForbiddenKeys);
                       (_Other) -> false
                    end,
      case lists:filter(IsForbidden, Opts0) of
          [] -> [{active, false}, {packet, raw}, {mode, binary} | Opts0];
          Forbidden -> error({forbidden_tcp_opts, Forbidden})
      end.
  %% @doc Unwrap the password kept as a zero-arity fun under `password'.
  %% Fails with {badkey, password} when the option is absent.
  get_password(Opts) ->
      (maps:get(password, Opts))().
  %% @doc Lowercase hexadecimal encoding of a binary, two characters per byte
  %% (used for the MD5 password digests).
  hex(Bin) ->
      Digits = <<"0123456789abcdef">>,
      << <<(binary:at(Digits, Hi)), (binary:at(Digits, Lo))>> || <<Hi:4, Lo:4>> <= Bin >>.
  %% @doc Turn a relative timeout in milliseconds into an absolute deadline on
  %% the monotonic clock. `infinity' (accepted by gen_tcp:connect/4) is passed
  %% through instead of crashing with badarith.
  deadline(infinity) ->
      infinity;
  deadline(Timeout) ->
      erlang:monotonic_time(milli_seconds) + Timeout.

  %% @doc Milliseconds left until Deadline, clamped at 0, or `infinity' when
  %% there is no deadline. The result is a valid timeout for gen_tcp:recv/3
  %% and ssl:connect/3.
  timeout(infinity) ->
      infinity;
  timeout(Deadline) ->
      erlang:max(0, Deadline - erlang:monotonic_time(milli_seconds)).