%%% epgsql_cmd_connect.erl
%%% Special kind of command: it is exclusive, so no other command can run
%%% until this one finishes.
%%% It also uses some of epgsql_sock's 'private' APIs.
%%%
  5. -module(epgsql_cmd_connect).
  6. -behaviour(epgsql_command).
  7. -export([init/1, execute/2, handle_message/4]).
  8. -export_type([response/0, connect_error/0]).
  9. -type response() :: connected
  10. | {error, connect_error()}.
  11. -type connect_error() ::
  12. invalid_authorization_specification
  13. | invalid_password
  14. | {unsupported_auth_method, integer()}
  15. | epgsql:query_error().
  16. -include("epgsql.hrl").
  17. -include("protocol.hrl").
  18. -record(connect,
  19. {opts :: list(),
  20. auth_method,
  21. stage = connect :: connect | auth | initialization}).
  22. init({Host, Username, Password, Opts}) ->
  23. Opts1 = [{host, Host},
  24. {username, Username},
  25. {password, Password}
  26. | Opts],
  27. #connect{opts = Opts1}.
  28. execute(PgSock, #connect{opts = Opts, stage = connect} = State) ->
  29. Host = get_val(host, Opts),
  30. Username = get_val(username, Opts),
  31. %% _ = get_val(password, Opts),
  32. Timeout = proplists:get_value(timeout, Opts, 5000),
  33. Port = proplists:get_value(port, Opts, 5432),
  34. SockOpts = [{active, false}, {packet, raw}, binary, {nodelay, true}, {keepalive, true}],
  35. case gen_tcp:connect(Host, Port, SockOpts, Timeout) of
  36. {ok, Sock} ->
  37. %% Increase the buffer size. Following the recommendation in the inet man page:
  38. %%
  39. %% It is recommended to have val(buffer) >=
  40. %% max(val(sndbuf),val(recbuf)).
  41. {ok, [{recbuf, RecBufSize}, {sndbuf, SndBufSize}]} =
  42. inet:getopts(Sock, [recbuf, sndbuf]),
  43. inet:setopts(Sock, [{buffer, max(RecBufSize, SndBufSize)}]),
  44. PgSock1 = maybe_ssl(Sock, proplists:get_value(ssl, Opts, false), Opts, PgSock),
  45. Opts2 = ["user", 0, Username, 0],
  46. Opts3 = case proplists:get_value(database, Opts, undefined) of
  47. undefined -> Opts2;
  48. Database -> [Opts2 | ["database", 0, Database, 0]]
  49. end,
  50. Replication = proplists:get_value(replication, Opts, undefined),
  51. Opts4 = case Replication of
  52. undefined -> Opts3;
  53. Replication ->
  54. [Opts3 | ["replication", 0, Replication, 0]]
  55. end,
  56. PgSock2 = case Replication of
  57. undefined -> PgSock1;
  58. _ -> epgsql_sock:init_replication_state(PgSock1)
  59. end,
  60. epgsql_sock:send(PgSock2, [<<196608:?int32>>, Opts4, 0]),
  61. PgSock3 = case proplists:get_value(async, Opts, undefined) of
  62. undefined -> PgSock2;
  63. Async -> epgsql_sock:set_attr(async, Async, PgSock2)
  64. end,
  65. {ok, PgSock3, State#connect{stage = auth}};
  66. {error, Reason} = Error ->
  67. {stop, Reason, Error, PgSock}
  68. end;
  69. execute(PgSock, #connect{stage = auth, auth_method = cleartext, opts = Opts} = St) ->
  70. Password = get_val(password, Opts),
  71. epgsql_sock:send(PgSock, ?PASSWORD, [Password, 0]),
  72. {ok, PgSock, St};
  73. execute(PgSock, #connect{stage = auth, auth_method = {md5, Salt}, opts = Opts} = St) ->
  74. User = get_val(username, Opts),
  75. Password = get_val(password, Opts),
  76. Digest1 = hex(erlang:md5([Password, User])),
  77. Str = ["md5", hex(erlang:md5([Digest1, Salt])), 0],
  78. epgsql_sock:send(PgSock, ?PASSWORD, Str),
  79. {ok, PgSock, St}.
  80. maybe_ssl(S, false, _, PgSock) ->
  81. epgsql_sock:set_net_socket(gen_tcp, S, PgSock);
  82. maybe_ssl(S, Flag, Opts, PgSock) ->
  83. ok = gen_tcp:send(S, <<8:?int32, 80877103:?int32>>),
  84. Timeout = proplists:get_value(timeout, Opts, 5000),
  85. {ok, <<Code>>} = gen_tcp:recv(S, 1, Timeout),
  86. case Code of
  87. $S ->
  88. SslOpts = proplists:get_value(ssl_opts, Opts, []),
  89. case ssl:connect(S, SslOpts, Timeout) of
  90. {ok, S2} ->
  91. epgsql_sock:set_net_socket(ssl, S2, PgSock);
  92. {error, Reason} ->
  93. exit({ssl_negotiation_failed, Reason})
  94. end;
  95. $N ->
  96. case Flag of
  97. true ->
  98. epgsql_sock:set_net_socket(gen_tcp, S, PgSock);
  99. required ->
  100. exit(ssl_not_available)
  101. end
  102. end.
  103. %% --- Auth ---
  104. %% AuthenticationOk
  105. handle_message(?AUTHENTICATION_REQUEST, <<0:?int32>>, Sock, State) ->
  106. {noaction, Sock, State#connect{stage = initialization}};
  107. %% AuthenticationCleartextPassword
  108. handle_message(?AUTHENTICATION_REQUEST, <<3:?int32>>, Sock, St) ->
  109. {requeue, Sock, St#connect{stage = auth, auth_method = cleartext}};
  110. %% AuthenticationMD5Password
  111. handle_message(?AUTHENTICATION_REQUEST, <<5:?int32, Salt:4/binary>>, Sock, St) ->
  112. {requeue, Sock, St#connect{stage = auth, auth_method = {md5, Salt}}};
  113. handle_message(?AUTHENTICATION_REQUEST, <<M:?int32, _/binary>>, Sock, _State) ->
  114. Method = case M of
  115. 2 -> kerberosV5;
  116. 4 -> crypt;
  117. 6 -> scm;
  118. 7 -> gss;
  119. 8 -> sspi;
  120. _ -> unknown
  121. end,
  122. {stop, normal, {error, {unsupported_auth_method, Method}}, Sock};
  123. %% --- Initialization ---
  124. %% BackendKeyData
  125. handle_message(?CANCELLATION_KEY, <<Pid:?int32, Key:?int32>>, Sock, _State) ->
  126. {noaction, epgsql_sock:set_attr(backend, {Pid, Key}, Sock)};
  127. %% ReadyForQuery
  128. handle_message(?READY_FOR_QUERY, _, Sock, _State) ->
  129. Codec = epgsql_binary:new_codec(Sock, []),
  130. Sock1 = epgsql_sock:set_attr(codec, Codec, Sock),
  131. {finish, connected, connected, Sock1};
  132. %% ErrorResponse
  133. handle_message(?ERROR, Err, Sock, #connect{stage = auth} = _State) ->
  134. Why = case Err#error.code of
  135. <<"28000">> -> invalid_authorization_specification;
  136. <<"28P01">> -> invalid_password;
  137. Any -> Any
  138. end,
  139. {stop, normal, {error, Why}, Sock};
  140. handle_message(_, _, _, _) ->
  141. unknown.
  142. get_val(Key, Proplist) ->
  143. Val = proplists:get_value(Key, Proplist),
  144. (Val =/= undefined) orelse error({required_option, Key}),
  145. Val.
  146. hex(Bin) ->
  147. HChar = fun(N) when N < 10 -> $0 + N;
  148. (N) when N < 16 -> $W + N
  149. end,
  150. <<<<(HChar(H)), (HChar(L))>> || <<H:4, L:4>> <= Bin>>.