Browse Source

Refactor epgsql_cmd_connect to allow other parts of the source code to make use of the newly introduced fun, open_socket

Enid Gjoleka 5 years ago
parent
commit
dac4653b21
2 changed files with 76 additions and 71 deletions
  1. 68 67
      src/commands/epgsql_cmd_connect.erl
  2. 8 4
      src/epgsql_sock.erl

+ 68 - 67
src/commands/epgsql_cmd_connect.erl

@@ -6,7 +6,7 @@
 %%%
 -module(epgsql_cmd_connect).
 -behaviour(epgsql_command).
--export([hide_password/1, opts_hide_password/1]).
+-export([hide_password/1, opts_hide_password/1, open_socket/2]).
 -export([init/1, execute/2, handle_message/4]).
 -export_type([response/0, connect_error/0]).
 
@@ -47,14 +47,37 @@
 init(#{host := _, username := _} = Opts) ->
     #connect{opts = Opts}.
 
-execute(PgSock, #connect{opts = #{host := Host} = Opts, stage = connect} = State) ->
-    Timeout = maps:get(timeout, Opts, 5000),
-    Deadline = deadline(Timeout),
-    Port = maps:get(port, Opts, 5432),
+execute(PgSock, #connect{opts = #{username := Username} = Opts, stage = connect} = State) ->
     SockOpts = [{active, false}, {packet, raw}, binary, {nodelay, true}, {keepalive, true}],
-    case gen_tcp:connect(Host, Port, SockOpts, Timeout) of
-        {ok, Sock} ->
-            client_handshake(Sock, PgSock, State, Deadline);
+    epgsql_sock:set_attr(connect_opts, Opts, PgSock),
+    case open_socket(SockOpts, Opts) of
+        {ok, Mode, Sock} ->
+            PgSock1 = epgsql_sock:set_net_socket(Mode, Sock, PgSock),
+            Opts2 = ["user", 0, Username, 0],
+            Opts3 = case maps:find(database, Opts) of
+                        error -> Opts2;
+                        {ok, Database}  -> [Opts2 | ["database", 0, Database, 0]]
+                    end,
+           {Opts4, PgSock2} =
+               case Opts of
+                   #{replication := Replication}  ->
+                       {[Opts3 | ["replication", 0, Replication, 0]],
+                        epgsql_sock:init_replication_state(PgSock1)};
+                   _ -> {Opts3, PgSock1}
+               end,
+            Opts5 = case Opts of
+                        #{application_name := ApplicationName}  ->
+                            [Opts4 | ["application_name", 0, ApplicationName, 0]];
+                        _ ->
+                            Opts4
+                    end,
+           ok = epgsql_sock:send(PgSock2, [<<196608:?int32>>, Opts5, 0]),
+           PgSock3 = case Opts of
+                         #{async := Async} ->
+                             epgsql_sock:set_attr(async, Async, PgSock2);
+                         _ -> PgSock2
+                     end,
+           {ok, PgSock3, State#connect{stage = maybe_auth}};
         {error, Reason} = Error ->
             {stop, Reason, Error, PgSock}
     end;
@@ -62,7 +85,18 @@ execute(PgSock, #connect{stage = auth, auth_send = {PacketId, Data}} = St) ->
     ok = epgsql_sock:send(PgSock, PacketId, Data),
     {ok, PgSock, St#connect{auth_send = undefined}}.
 
-client_handshake(Sock, PgSock, #connect{opts = #{username := Username} = Opts} = State, Deadline) ->
+open_socket(SockOpts, #{host := Host} = ConnectOpts) ->
+    Timeout = maps:get(timeout, ConnectOpts, 5000),
+    Deadline = deadline(Timeout),
+    Port = maps:get(port, ConnectOpts, 5432),
+    case gen_tcp:connect(Host, Port, SockOpts, Timeout) of
+       {ok, Sock} ->
+           client_handshake(Sock, ConnectOpts, Deadline);
+       {error, _Reason} = Error ->
+           Error
+    end.
+  
+client_handshake(Sock, ConnectOpts, Deadline) ->
     %% Increase the buffer size.  Following the recommendation in the inet man page:
     %%
     %%    It is recommended to have val(buffer) >=
@@ -71,40 +105,35 @@ client_handshake(Sock, PgSock, #connect{opts = #{username := Username} = Opts} =
     {ok, [{recbuf, RecBufSize}, {sndbuf, SndBufSize}]} =
         inet:getopts(Sock, [recbuf, sndbuf]),
     inet:setopts(Sock, [{buffer, max(RecBufSize, SndBufSize)}]),
+    maybe_ssl(Sock, maps:get(ssl, ConnectOpts, false), ConnectOpts, Deadline).
 
-    case maybe_ssl(Sock, maps:get(ssl, Opts, false), Opts, PgSock, Deadline) of
+maybe_ssl(Sock, false, _ConnectOpts, _Deadline) ->
+    {ok, gen_tcp, Sock};
+maybe_ssl(Sock, Flag, ConnectOpts, Deadline) ->
+    ok = gen_tcp:send(Sock, <<8:?int32, 80877103:?int32>>),
+    Timeout0 = timeout(Deadline),
+    case gen_tcp:recv(Sock, 1, Timeout0) of
+        {ok, <<$S>>}  ->
+            SslOpts = maps:get(ssl_opts, ConnectOpts, []),
+            Timeout = timeout(Deadline),
+            case ssl:connect(Sock, SslOpts, Timeout) of
+                {ok, Sock2} ->
+                    {ok, ssl, Sock2};
+                {error, Reason} ->
+                    Err = {ssl_negotiation_failed, Reason},
+                    {error, Err}
+            end;
+        {ok, <<$N>>} ->
+            case Flag of
+                true ->
+                   {ok, gen_tcp, Sock};
+                required ->
+                    {error, ssl_not_available}
+            end;
         {error, Reason} ->
-            {stop, Reason, {error, Reason}, PgSock};
-        PgSock1 ->
-            Opts2 = ["user", 0, Username, 0],
-            Opts3 = case maps:find(database, Opts) of
-                        error -> Opts2;
-                        {ok, Database}  -> [Opts2 | ["database", 0, Database, 0]]
-                    end,
-
-            {Opts4, PgSock2} =
-                case Opts of
-                    #{replication := Replication}  ->
-                        {[Opts3 | ["replication", 0, Replication, 0]],
-                         epgsql_sock:init_replication_state(PgSock1)};
-                    _ -> {Opts3, PgSock1}
-                end,
-            Opts5 = case Opts of
-                        #{application_name := ApplicationName}  ->
-                            [Opts3 | ["application_name", 0, ApplicationName, 0]];
-                        _ ->
-                            Opts4
-                    end,
-            ok = epgsql_sock:send(PgSock2, [<<196608:?int32>>, Opts5, 0]),
-            PgSock3 = case Opts of
-                          #{async := Async} ->
-                              epgsql_sock:set_attr(async, Async, PgSock2);
-                          _ -> PgSock2
-                      end,
-            {ok, PgSock3, State#connect{stage = maybe_auth}}
+            {error, Reason}
     end.
 
-
 %% @doc Replace `password' in Opts map with obfuscated one
 opts_hide_password(#{password := Password} = Opts) ->
     HiddenPassword = hide_password(Password),
@@ -124,34 +153,6 @@ hide_password(Password) when is_list(Password);
 hide_password(PasswordFun) when is_function(PasswordFun, 0) ->
     PasswordFun.
 
-
-maybe_ssl(S, false, _, PgSock, _Deadline) ->
-    epgsql_sock:set_net_socket(gen_tcp, S, PgSock);
-maybe_ssl(S, Flag, Opts, PgSock, Deadline) ->
-    ok = gen_tcp:send(S, <<8:?int32, 80877103:?int32>>),
-    Timeout0 = timeout(Deadline),
-    case gen_tcp:recv(S, 1, Timeout0) of
-        {ok, <<$S>>}  ->
-            SslOpts = maps:get(ssl_opts, Opts, []),
-            Timeout = timeout(Deadline),
-            case ssl:connect(S, SslOpts, Timeout) of
-                {ok, S2}        ->
-                    epgsql_sock:set_net_socket(ssl, S2, PgSock);
-                {error, Reason} ->
-                    Err = {ssl_negotiation_failed, Reason},
-                    {error, Err}
-            end;
-        {ok, <<$N>>} ->
-            case Flag of
-                true ->
-                    epgsql_sock:set_net_socket(gen_tcp, S, PgSock);
-                required ->
-                    {error, ssl_not_available}
-            end;
-        {error, Reason} ->
-            {error, Reason}
-    end.
-
 %% Auth sub-protocol
 
 auth_init(<<?AUTH_CLEARTEXT:?int32>>, Sock, St) ->

+ 8 - 4
src/epgsql_sock.erl

@@ -92,7 +92,8 @@
                 sync_required :: boolean() | undefined,
                 txstatus :: byte() | undefined,  % $I | $T | $E,
                 complete_status :: atom() | {atom(), integer()} | undefined,
-                repl :: repl_state() | undefined}).
+                repl :: repl_state() | undefined,
+                connect_opts :: epgsql:connect_opts()}).
 
 -opaque pg_sock() :: #state{}.
 
@@ -158,7 +159,9 @@ set_attr(codec, Codec, State) ->
 set_attr(sync_required, Value, State) ->
     State#state{sync_required = Value};
 set_attr(replication_state, Value, State) ->
-    State#state{repl = Value}.
+    State#state{repl = Value};
+set_attr(connect_opts, ConnectOpts, State) ->
+    State#state{connect_opts = ConnectOpts}.
 
 %% XXX: be careful!
 -spec set_packet_handler(atom(), pg_sock()) -> pg_sock().
@@ -232,6 +235,7 @@ handle_cast(cancel, State = #state{backend = {Pid, Key},
                          end,
     SockOpts = [{active, false}, {packet, raw}, binary],
     %% TODO timeout
+    %% TODO DO NOT use gen_tcp
     {ok, Sock} = gen_tcp:connect(Addr, Port, SockOpts),
     Msg = <<16:?int32, 80877102:?int32, Pid:?int32, Key:?int32>>,
     ok = gen_tcp:send(Sock, Msg),
@@ -372,8 +376,8 @@ send(#state{mod = Mod, sock = Sock}, Type, Data) ->
 -spec send_multi(pg_sock(), [{byte(), iodata()}]) -> ok | {error, any()}.
 send_multi(#state{mod = Mod, sock = Sock}, List) ->
     do_send(Mod, Sock, lists:map(fun({Type, Data}) ->
-        epgsql_wire:encode_command(Type, Data)
-    end, List)).
+                                    epgsql_wire:encode_command(Type, Data)
+                                 end, List)).
 
 do_send(gen_tcp, Sock, Bin) ->
     %% Why not gen_tcp:send/2?