|
@@ -6,7 +6,7 @@
|
|
|
%%%
|
|
|
-module(epgsql_cmd_connect).
|
|
|
-behaviour(epgsql_command).
|
|
|
--export([hide_password/1, opts_hide_password/1]).
|
|
|
+-export([hide_password/1, opts_hide_password/1, open_socket/2]).
|
|
|
-export([init/1, execute/2, handle_message/4]).
|
|
|
-export_type([response/0, connect_error/0]).
|
|
|
|
|
@@ -47,57 +47,99 @@
|
|
|
init(#{host := _, username := _} = Opts) ->
|
|
|
#connect{opts = Opts}.
|
|
|
|
|
|
-execute(PgSock, #connect{opts = #{host := Host} = Opts, stage = connect} = State) ->
|
|
|
- Timeout = maps:get(timeout, Opts, 5000),
|
|
|
- Deadline = deadline(Timeout),
|
|
|
- Port = maps:get(port, Opts, 5432),
|
|
|
- SockOpts = [{active, false}, {packet, raw}, binary, {nodelay, true}, {keepalive, true}],
|
|
|
- case gen_tcp:connect(Host, Port, SockOpts, Timeout) of
|
|
|
- {ok, Sock} ->
|
|
|
- client_handshake(Sock, PgSock, State, Deadline);
|
|
|
- {error, Reason} = Error ->
|
|
|
- {stop, Reason, Error, PgSock}
|
|
|
- end;
|
|
|
-execute(PgSock, #connect{stage = auth, auth_send = {PacketId, Data}} = St) ->
|
|
|
- ok = epgsql_sock:send(PgSock, PacketId, Data),
|
|
|
- {ok, PgSock, St#connect{auth_send = undefined}}.
|
|
|
-
|
|
|
-client_handshake(Sock, PgSock, #connect{opts = #{username := Username} = Opts} = State, Deadline) ->
|
|
|
- %% Increase the buffer size. Following the recommendation in the inet man page:
|
|
|
- %%
|
|
|
- %% It is recommended to have val(buffer) >=
|
|
|
- %% max(val(sndbuf),val(recbuf)).
|
|
|
-
|
|
|
- {ok, [{recbuf, RecBufSize}, {sndbuf, SndBufSize}]} =
|
|
|
- inet:getopts(Sock, [recbuf, sndbuf]),
|
|
|
- inet:setopts(Sock, [{buffer, max(RecBufSize, SndBufSize)}]),
|
|
|
-
|
|
|
- case maybe_ssl(Sock, maps:get(ssl, Opts, false), Opts, PgSock, Deadline) of
|
|
|
- {error, Reason} ->
|
|
|
- {stop, Reason, {error, Reason}, PgSock};
|
|
|
- PgSock1 ->
|
|
|
+execute(PgSock, #connect{opts = #{username := Username} = Opts, stage = connect} = State) ->
|
|
|
+ SockOpts = prepare_tcp_opts(maps:get(tcp_opts, Opts, [])),
|
|
|
+ FilteredOpts = filter_sensitive_info(Opts),
|
|
|
+ PgSock1 = epgsql_sock:set_attr(connect_opts, FilteredOpts, PgSock),
|
|
|
+ case open_socket(SockOpts, Opts) of
|
|
|
+ {ok, Mode, Sock} ->
|
|
|
+ PgSock2 = epgsql_sock:set_net_socket(Mode, Sock, PgSock1),
|
|
|
Opts2 = ["user", 0, Username, 0],
|
|
|
Opts3 = case maps:find(database, Opts) of
|
|
|
error -> Opts2;
|
|
|
{ok, Database} -> [Opts2 | ["database", 0, Database, 0]]
|
|
|
end,
|
|
|
+ {Opts4, PgSock3} =
|
|
|
+ case Opts of
|
|
|
+ #{replication := Replication} ->
|
|
|
+ {[Opts3 | ["replication", 0, Replication, 0]],
|
|
|
+ epgsql_sock:init_replication_state(PgSock2)};
|
|
|
+ _ -> {Opts3, PgSock2}
|
|
|
+ end,
|
|
|
+ Opts5 = case Opts of
|
|
|
+ #{application_name := ApplicationName} ->
|
|
|
+ [Opts4 | ["application_name", 0, ApplicationName, 0]];
|
|
|
+ _ ->
|
|
|
+ Opts4
|
|
|
+ end,
|
|
|
+ ok = epgsql_sock:send(PgSock3, [<<196608:?int32>>, Opts5, 0]),
|
|
|
+ PgSock4 = case Opts of
|
|
|
+ #{async := Async} ->
|
|
|
+ epgsql_sock:set_attr(async, Async, PgSock3);
|
|
|
+ _ -> PgSock3
|
|
|
+ end,
|
|
|
+ {ok, PgSock4, State#connect{stage = maybe_auth}};
|
|
|
+ {error, Reason} = Error ->
|
|
|
+ {stop, Reason, Error, PgSock}
|
|
|
+ end;
|
|
|
+execute(PgSock, #connect{stage = auth, auth_send = {PacketType, Data}} = St) ->
|
|
|
+ {send, PacketType, Data, PgSock, St#connect{auth_send = undefined}}.
|
|
|
|
|
|
- {Opts4, PgSock2} =
|
|
|
- case Opts of
|
|
|
- #{replication := Replication} ->
|
|
|
- {[Opts3 | ["replication", 0, Replication, 0]],
|
|
|
- epgsql_sock:init_replication_state(PgSock1)};
|
|
|
- _ -> {Opts3, PgSock1}
|
|
|
- end,
|
|
|
- ok = epgsql_sock:send(PgSock2, [<<196608:?int32>>, Opts4, 0]),
|
|
|
- PgSock3 = case Opts of
|
|
|
- #{async := Async} ->
|
|
|
- epgsql_sock:set_attr(async, Async, PgSock2);
|
|
|
- _ -> PgSock2
|
|
|
- end,
|
|
|
- {ok, PgSock3, State#connect{stage = maybe_auth}}
|
|
|
+-spec open_socket([{atom(), any()}], epgsql:connect_opts()) ->
|
|
|
+ {ok , gen_tcp | ssl, port() | ssl:sslsocket()} | {error, any()}.
|
|
|
+open_socket(SockOpts, #{host := Host} = ConnectOpts) ->
|
|
|
+ Timeout = maps:get(timeout, ConnectOpts, 5000),
|
|
|
+ Deadline = deadline(Timeout),
|
|
|
+ Port = maps:get(port, ConnectOpts, 5432),
|
|
|
+ case gen_tcp:connect(Host, Port, SockOpts, Timeout) of
|
|
|
+ {ok, Sock} ->
|
|
|
+ client_handshake(Sock, ConnectOpts, Deadline);
|
|
|
+ {error, _Reason} = Error ->
|
|
|
+ Error
|
|
|
end.
|
|
|
|
|
|
+client_handshake(Sock, ConnectOpts, Deadline) ->
|
|
|
+ case maps:is_key(tcp_opts, ConnectOpts) of
|
|
|
+ false ->
|
|
|
+ %% Increase the buffer size. Following the recommendation in the inet man page:
|
|
|
+ %%
|
|
|
+ %% It is recommended to have val(buffer) >=
|
|
|
+ %% max(val(sndbuf),val(recbuf)).
|
|
|
+ {ok, [{recbuf, RecBufSize}, {sndbuf, SndBufSize}]} =
|
|
|
+ inet:getopts(Sock, [recbuf, sndbuf]),
|
|
|
+ inet:setopts(Sock, [{buffer, max(RecBufSize, SndBufSize)}]);
|
|
|
+ true ->
|
|
|
+ %% All TCP options are provided by the user
|
|
|
+ noop
|
|
|
+ end,
|
|
|
+ maybe_ssl(Sock, maps:get(ssl, ConnectOpts, false), ConnectOpts, Deadline).
|
|
|
+
|
|
|
+maybe_ssl(Sock, false, _ConnectOpts, _Deadline) ->
|
|
|
+ {ok, gen_tcp, Sock};
|
|
|
+maybe_ssl(Sock, Flag, ConnectOpts, Deadline) ->
|
|
|
+ ok = gen_tcp:send(Sock, <<8:?int32, 80877103:?int32>>),
|
|
|
+ Timeout0 = timeout(Deadline),
|
|
|
+ case gen_tcp:recv(Sock, 1, Timeout0) of
|
|
|
+ {ok, <<$S>>} ->
|
|
|
+ SslOpts = maps:get(ssl_opts, ConnectOpts, []),
|
|
|
+ Timeout = timeout(Deadline),
|
|
|
+ case ssl:connect(Sock, SslOpts, Timeout) of
|
|
|
+ {ok, Sock2} ->
|
|
|
+ {ok, ssl, Sock2};
|
|
|
+ {error, Reason} ->
|
|
|
+ Err = {ssl_negotiation_failed, Reason},
|
|
|
+ {error, Err}
|
|
|
+ end;
|
|
|
+ {ok, <<$N>>} ->
|
|
|
+ case Flag of
|
|
|
+ true ->
|
|
|
+ {ok, gen_tcp, Sock};
|
|
|
+ required ->
|
|
|
+ {error, ssl_not_available}
|
|
|
+ end;
|
|
|
+ {error, Reason} ->
|
|
|
+ {error, Reason}
|
|
|
+ end.
|
|
|
|
|
|
%% @doc Replace `password' in Opts map with obfuscated one
|
|
|
opts_hide_password(#{password := Password} = Opts) ->
|
|
@@ -105,6 +147,10 @@ opts_hide_password(#{password := Password} = Opts) ->
|
|
|
Opts#{password => HiddenPassword};
|
|
|
opts_hide_password(Opts) -> Opts.
|
|
|
|
|
|
+%% @doc password and username are sensitive data that should not be stored in a
|
|
|
+%% permanent state that might crash during code upgrade
|
|
|
+filter_sensitive_info(Opts0) ->
|
|
|
+ maps:without([password, username], Opts0).
|
|
|
|
|
|
%% @doc this function wraps plaintext password to a lambda function, so, if
|
|
|
%% epgsql_sock process crashes when executing `connect' command, password will
|
|
@@ -118,34 +164,6 @@ hide_password(Password) when is_list(Password);
|
|
|
hide_password(PasswordFun) when is_function(PasswordFun, 0) ->
|
|
|
PasswordFun.
|
|
|
|
|
|
-
|
|
|
-maybe_ssl(S, false, _, PgSock, _Deadline) ->
|
|
|
- epgsql_sock:set_net_socket(gen_tcp, S, PgSock);
|
|
|
-maybe_ssl(S, Flag, Opts, PgSock, Deadline) ->
|
|
|
- ok = gen_tcp:send(S, <<8:?int32, 80877103:?int32>>),
|
|
|
- Timeout0 = timeout(Deadline),
|
|
|
- case gen_tcp:recv(S, 1, Timeout0) of
|
|
|
- {ok, <<$S>>} ->
|
|
|
- SslOpts = maps:get(ssl_opts, Opts, []),
|
|
|
- Timeout = timeout(Deadline),
|
|
|
- case ssl:connect(S, SslOpts, Timeout) of
|
|
|
- {ok, S2} ->
|
|
|
- epgsql_sock:set_net_socket(ssl, S2, PgSock);
|
|
|
- {error, Reason} ->
|
|
|
- Err = {ssl_negotiation_failed, Reason},
|
|
|
- {error, Err}
|
|
|
- end;
|
|
|
- {ok, <<$N>>} ->
|
|
|
- case Flag of
|
|
|
- true ->
|
|
|
- epgsql_sock:set_net_socket(gen_tcp, S, PgSock);
|
|
|
- required ->
|
|
|
- {error, ssl_not_available}
|
|
|
- end;
|
|
|
- {error, Reason} ->
|
|
|
- {error, Reason}
|
|
|
- end.
|
|
|
-
|
|
|
%% Auth sub-protocol
|
|
|
|
|
|
auth_init(<<?AUTH_CLEARTEXT:?int32>>, Sock, St) ->
|
|
@@ -268,6 +286,24 @@ handle_message(?ERROR, #error{code = Code} = Err, Sock, #connect{stage = Stage}
|
|
|
handle_message(_, _, _, _) ->
|
|
|
unknown.
|
|
|
|
|
|
+prepare_tcp_opts([]) ->
|
|
|
+ [{active, false}, {packet, raw}, {mode, binary}, {nodelay, true}, {keepalive, true}];
|
|
|
+prepare_tcp_opts(Opts0) ->
|
|
|
+ case lists:filter(fun(binary) -> true;
|
|
|
+ (list) -> true;
|
|
|
+ ({mode, _}) -> true;
|
|
|
+ ({packet, _}) -> true;
|
|
|
+ ({packet_size, _}) -> true;
|
|
|
+ ({header, _}) -> true;
|
|
|
+ ({active, _}) -> true;
|
|
|
+ (_) -> false
|
|
|
+ end, Opts0) of
|
|
|
+ [] ->
|
|
|
+ [{active, false}, {packet, raw}, {mode, binary} | Opts0];
|
|
|
+ Forbidden ->
|
|
|
+ error({forbidden_tcp_opts, Forbidden})
|
|
|
+ end.
|
|
|
+
|
|
|
|
|
|
get_password(Opts) ->
|
|
|
PasswordFun = maps:get(password, Opts),
|
|
@@ -284,4 +320,4 @@ deadline(Timeout) ->
|
|
|
erlang:monotonic_time(milli_seconds) + Timeout.
|
|
|
|
|
|
timeout(Deadline) ->
|
|
|
- erlang:max(0, Deadline - erlang:monotonic_time(milli_seconds)).
|
|
|
+ erlang:max(0, Deadline - erlang:monotonic_time(milli_seconds)).
|