Browse Source

Merge pull request #227 from enidgjoleka/handle-tls-connection-during-cancellation

Handle tls connection during cancellation
Sergey Prokhorov 5 years ago
parent
commit
f0d07e3fff
4 changed files with 156 additions and 81 deletions
  1. 1 1
      README.md
  2. 75 67
      src/commands/epgsql_cmd_connect.erl
  3. 16 13
      src/epgsql_sock.erl
  4. 64 0
      test/epgsql_SUITE.erl

+ 1 - 1
README.md

@@ -451,7 +451,7 @@ epgsql:cancel(connection()) -> ok.
 
 PostgreSQL protocol supports [cancellation](https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.9)
 of currently executing command. `cancel/1` sends a cancellation request via the
-new temporary TCP connection asynchronously, it doesn't await for the command to
+new temporary TCP/TLS_over_TCP connection asynchronously, it doesn't await for the command to
 be cancelled. Instead, client should expect to get
 `{error, #error{code = <<"57014">>, codename = query_canceled}}` back from
 the command that was cancelled. However, normal response can still be received as well.

+ 75 - 67
src/commands/epgsql_cmd_connect.erl

@@ -6,7 +6,7 @@
 %%%
 -module(epgsql_cmd_connect).
 -behaviour(epgsql_command).
--export([hide_password/1, opts_hide_password/1]).
+-export([hide_password/1, opts_hide_password/1, open_socket/2]).
 -export([init/1, execute/2, handle_message/4]).
 -export_type([response/0, connect_error/0]).
 
@@ -47,14 +47,38 @@
 init(#{host := _, username := _} = Opts) ->
     #connect{opts = Opts}.
 
-execute(PgSock, #connect{opts = #{host := Host} = Opts, stage = connect} = State) ->
-    Timeout = maps:get(timeout, Opts, 5000),
-    Deadline = deadline(Timeout),
-    Port = maps:get(port, Opts, 5432),
+execute(PgSock, #connect{opts = #{username := Username} = Opts, stage = connect} = State) ->
     SockOpts = [{active, false}, {packet, raw}, binary, {nodelay, true}, {keepalive, true}],
-    case gen_tcp:connect(Host, Port, SockOpts, Timeout) of
-        {ok, Sock} ->
-            client_handshake(Sock, PgSock, State, Deadline);
+    FilteredOpts = filter_sensitive_info(Opts),
+    PgSock1 = epgsql_sock:set_attr(connect_opts, FilteredOpts, PgSock),
+    case open_socket(SockOpts, Opts) of
+        {ok, Mode, Sock} ->
+            PgSock2 = epgsql_sock:set_net_socket(Mode, Sock, PgSock1),
+            Opts2 = ["user", 0, Username, 0],
+            Opts3 = case maps:find(database, Opts) of
+                        error -> Opts2;
+                        {ok, Database}  -> [Opts2 | ["database", 0, Database, 0]]
+                    end,
+           {Opts4, PgSock3} =
+               case Opts of
+                   #{replication := Replication}  ->
+                       {[Opts3 | ["replication", 0, Replication, 0]],
+                        epgsql_sock:init_replication_state(PgSock2)};
+                   _ -> {Opts3, PgSock2}
+               end,
+            Opts5 = case Opts of
+                        #{application_name := ApplicationName}  ->
+                            [Opts4 | ["application_name", 0, ApplicationName, 0]];
+                        _ ->
+                            Opts4
+                    end,
+           ok = epgsql_sock:send(PgSock3, [<<196608:?int32>>, Opts5, 0]),
+           PgSock4 = case Opts of
+                         #{async := Async} ->
+                             epgsql_sock:set_attr(async, Async, PgSock3);
+                         _ -> PgSock3
+                     end,
+           {ok, PgSock4, State#connect{stage = maybe_auth}};
         {error, Reason} = Error ->
             {stop, Reason, Error, PgSock}
     end;
@@ -62,7 +86,20 @@ execute(PgSock, #connect{stage = auth, auth_send = {PacketId, Data}} = St) ->
     ok = epgsql_sock:send(PgSock, PacketId, Data),
     {ok, PgSock, St#connect{auth_send = undefined}}.
 
-client_handshake(Sock, PgSock, #connect{opts = #{username := Username} = Opts} = State, Deadline) ->
+-spec open_socket([{atom(), any()}], epgsql:connect_opts()) ->
+    {ok , gen_tcp | ssl, port() | ssl:sslsocket()} | {error, any()}.
+open_socket(SockOpts, #{host := Host} = ConnectOpts) ->
+    Timeout = maps:get(timeout, ConnectOpts, 5000),
+    Deadline = deadline(Timeout),
+    Port = maps:get(port, ConnectOpts, 5432),
+    case gen_tcp:connect(Host, Port, SockOpts, Timeout) of
+       {ok, Sock} ->
+           client_handshake(Sock, ConnectOpts, Deadline);
+       {error, _Reason} = Error ->
+           Error
+    end.
+
+client_handshake(Sock, ConnectOpts, Deadline) ->
     %% Increase the buffer size.  Following the recommendation in the inet man page:
     %%
     %%    It is recommended to have val(buffer) >=
@@ -71,46 +108,45 @@ client_handshake(Sock, PgSock, #connect{opts = #{username := Username} = Opts} =
     {ok, [{recbuf, RecBufSize}, {sndbuf, SndBufSize}]} =
         inet:getopts(Sock, [recbuf, sndbuf]),
     inet:setopts(Sock, [{buffer, max(RecBufSize, SndBufSize)}]),
+    maybe_ssl(Sock, maps:get(ssl, ConnectOpts, false), ConnectOpts, Deadline).
 
-    case maybe_ssl(Sock, maps:get(ssl, Opts, false), Opts, PgSock, Deadline) of
+maybe_ssl(Sock, false, _ConnectOpts, _Deadline) ->
+    {ok, gen_tcp, Sock};
+maybe_ssl(Sock, Flag, ConnectOpts, Deadline) ->
+    ok = gen_tcp:send(Sock, <<8:?int32, 80877103:?int32>>),
+    Timeout0 = timeout(Deadline),
+    case gen_tcp:recv(Sock, 1, Timeout0) of
+        {ok, <<$S>>}  ->
+            SslOpts = maps:get(ssl_opts, ConnectOpts, []),
+            Timeout = timeout(Deadline),
+            case ssl:connect(Sock, SslOpts, Timeout) of
+                {ok, Sock2} ->
+                    {ok, ssl, Sock2};
+                {error, Reason} ->
+                    Err = {ssl_negotiation_failed, Reason},
+                    {error, Err}
+            end;
+        {ok, <<$N>>} ->
+            case Flag of
+                true ->
+                   {ok, gen_tcp, Sock};
+                required ->
+                    {error, ssl_not_available}
+            end;
         {error, Reason} ->
-            {stop, Reason, {error, Reason}, PgSock};
-        PgSock1 ->
-            Opts2 = ["user", 0, Username, 0],
-            Opts3 = case maps:find(database, Opts) of
-                        error -> Opts2;
-                        {ok, Database}  -> [Opts2 | ["database", 0, Database, 0]]
-                    end,
-
-            {Opts4, PgSock2} =
-                case Opts of
-                    #{replication := Replication}  ->
-                        {[Opts3 | ["replication", 0, Replication, 0]],
-                         epgsql_sock:init_replication_state(PgSock1)};
-                    _ -> {Opts3, PgSock1}
-                end,
-            Opts5 = case Opts of
-                        #{application_name := ApplicationName}  ->
-                            [Opts3 | ["application_name", 0, ApplicationName, 0]];
-                        _ ->
-                            Opts4
-                    end,
-            ok = epgsql_sock:send(PgSock2, [<<196608:?int32>>, Opts5, 0]),
-            PgSock3 = case Opts of
-                          #{async := Async} ->
-                              epgsql_sock:set_attr(async, Async, PgSock2);
-                          _ -> PgSock2
-                      end,
-            {ok, PgSock3, State#connect{stage = maybe_auth}}
+            {error, Reason}
     end.
 
-
 %% @doc Replace `password' in Opts map with obfuscated one
 opts_hide_password(#{password := Password} = Opts) ->
     HiddenPassword = hide_password(Password),
     Opts#{password => HiddenPassword};
 opts_hide_password(Opts) -> Opts.
 
+%% @doc password and username are sensitive data that should not be stored in a
+%% permanent state that might crash during code upgrade
+filter_sensitive_info(Opts0) ->
+  maps:without([password, username], Opts0).
 
 %% @doc this function wraps plaintext password to a lambda function, so, if
 %% epgsql_sock process crashes when executing `connect' command, password will
@@ -124,34 +160,6 @@ hide_password(Password) when is_list(Password);
 hide_password(PasswordFun) when is_function(PasswordFun, 0) ->
     PasswordFun.
 
-
-maybe_ssl(S, false, _, PgSock, _Deadline) ->
-    epgsql_sock:set_net_socket(gen_tcp, S, PgSock);
-maybe_ssl(S, Flag, Opts, PgSock, Deadline) ->
-    ok = gen_tcp:send(S, <<8:?int32, 80877103:?int32>>),
-    Timeout0 = timeout(Deadline),
-    case gen_tcp:recv(S, 1, Timeout0) of
-        {ok, <<$S>>}  ->
-            SslOpts = maps:get(ssl_opts, Opts, []),
-            Timeout = timeout(Deadline),
-            case ssl:connect(S, SslOpts, Timeout) of
-                {ok, S2}        ->
-                    epgsql_sock:set_net_socket(ssl, S2, PgSock);
-                {error, Reason} ->
-                    Err = {ssl_negotiation_failed, Reason},
-                    {error, Err}
-            end;
-        {ok, <<$N>>} ->
-            case Flag of
-                true ->
-                    epgsql_sock:set_net_socket(gen_tcp, S, PgSock);
-                required ->
-                    {error, ssl_not_available}
-            end;
-        {error, Reason} ->
-            {error, Reason}
-    end.
-
 %% Auth sub-protocol
 
 auth_init(<<?AUTH_CLEARTEXT:?int32>>, Sock, St) ->

+ 16 - 13
src/epgsql_sock.erl

@@ -92,7 +92,8 @@
                 sync_required :: boolean() | undefined,
                 txstatus :: byte() | undefined,  % $I | $T | $E,
                 complete_status :: atom() | {atom(), integer()} | undefined,
-                repl :: repl_state() | undefined}).
+                repl :: repl_state() | undefined,
+                connect_opts :: epgsql:connect_opts() | undefined}).
 
 -opaque pg_sock() :: #state{}.
 
@@ -158,7 +159,9 @@ set_attr(codec, Codec, State) ->
 set_attr(sync_required, Value, State) ->
     State#state{sync_required = Value};
 set_attr(replication_state, Value, State) ->
-    State#state{repl = Value}.
+    State#state{repl = Value};
+set_attr(connect_opts, ConnectOpts, State) ->
+    State#state{connect_opts = ConnectOpts}.
 
 %% XXX: be careful!
 -spec set_packet_handler(atom(), pg_sock()) -> pg_sock().
@@ -225,17 +228,17 @@ handle_cast(stop, State) ->
     {stop, normal, flush_queue(State, {error, closed})};
 
 handle_cast(cancel, State = #state{backend = {Pid, Key},
-                                   sock = TimedOutSock}) ->
-    {ok, {Addr, Port}} = case State#state.mod of
-                             gen_tcp -> inet:peername(TimedOutSock);
-                             ssl -> ssl:peername(TimedOutSock)
-                         end,
+                                   connect_opts = ConnectOpts,
+                                   mod = Mode}) ->
     SockOpts = [{active, false}, {packet, raw}, binary],
-    %% TODO timeout
-    {ok, Sock} = gen_tcp:connect(Addr, Port, SockOpts),
     Msg = <<16:?int32, 80877102:?int32, Pid:?int32, Key:?int32>>,
-    ok = gen_tcp:send(Sock, Msg),
-    gen_tcp:close(Sock),
+    case epgsql_cmd_connect:open_socket(SockOpts, ConnectOpts) of
+      {ok, Mode, Sock} ->
+          ok = apply(Mode, send, [Sock, Msg]),
+          apply(Mode, close, [Sock]);
+      {error, _Reason} ->
+          noop
+    end,
     {noreply, State}.
 
 handle_info({Closed, Sock}, #state{sock = Sock} = State)
@@ -372,8 +375,8 @@ send(#state{mod = Mod, sock = Sock}, Type, Data) ->
 -spec send_multi(pg_sock(), [{byte(), iodata()}]) -> ok | {error, any()}.
 send_multi(#state{mod = Mod, sock = Sock}, List) ->
     do_send(Mod, Sock, lists:map(fun({Type, Data}) ->
-        epgsql_wire:encode_command(Type, Data)
-    end, List)).
+                                    epgsql_wire:encode_command(Type, Data)
+                                 end, List)).
 
 do_send(gen_tcp, Sock, Bin) ->
     %% Why not gen_tcp:send/2?

+ 64 - 0
test/epgsql_SUITE.erl

@@ -45,6 +45,8 @@ groups() ->
             connect_to_invalid_database,
             connect_with_other_error,
             connect_with_ssl,
+            cancel_query_for_connection_with_ssl,
+            cancel_query_for_connection_with_gen_tcp,
             connect_with_client_cert,
             connect_with_invalid_client_cert,
             connect_to_closed_port,
@@ -171,6 +173,16 @@ end_per_group(_GroupName, _Config) ->
                  {routine, _} | _]
         }}).
 
+-define(QUERY_CANCELED, {error, #error{
+        severity = error,
+        code = <<"57014">>,
+        codename = query_canceled,
+        message = <<"canceling statement due to user request">>,
+        extra = [{file, <<"postgres.c">>},
+                 {line, _},
+                 {routine, _} | _]
+        }}).
+
 %% From uuid.erl in http://gitorious.org/avtobiff/erlang-uuid
 uuid_to_bin_string(<<U0:32, U1:16, U2:16, U3:16, U4:48>>) ->
     iolist_to_binary(io_lib:format(
@@ -284,6 +296,58 @@ connect_with_ssl(Config) ->
         "epgsql_test",
         [{ssl, true}]).
 
+cancel_query_for_connection_with_ssl(Config) ->
+    Module = ?config(module, Config),
+    {Host, Port} = epgsql_ct:connection_data(Config),
+    Module = ?config(module, Config),
+    Args2 = [ {port, Port}, {database, "epgsql_test_db1"}
+            | [ {ssl, true}
+              , {timeout, 1000} ]
+            ],
+    {ok, C} = Module:connect(Host, "epgsql_test", Args2),
+    ?assertMatch({ok, _Cols, [{true}]},
+                Module:equery(C, "select ssl_is_used()")),
+    Self = self(),
+    spawn_link(fun() ->
+                   ?assertMatch(?QUERY_CANCELED, Module:equery(C, "SELECT pg_sleep(5)")),
+                   Self ! done
+               end),
+    %% this timer is needed for the test not to be flaky
+    timer:sleep(1000),
+    epgsql:cancel(C),
+    receive done ->
+        ?assert(true)
+    after 5000 ->
+        epgsql:close(C),
+        ?assert(false)
+    end,
+    epgsql_ct:flush().
+
+cancel_query_for_connection_with_gen_tcp(Config) ->
+    Module = ?config(module, Config),
+    {Host, Port} = epgsql_ct:connection_data(Config),
+    Module = ?config(module, Config),
+    Args2 = [ {port, Port}, {database, "epgsql_test_db1"}
+            | [ {timeout, 1000} ]
+            ],
+    {ok, C} = Module:connect(Host, "epgsql_test", Args2),
+    process_flag(trap_exit, true),
+    Self = self(),
+    spawn_link(fun() ->
+                   ?assertMatch(?QUERY_CANCELED, Module:equery(C, "SELECT pg_sleep(5)")),
+                   Self ! done
+               end),
+    %% this timer is needed for the test not to be flaky
+    timer:sleep(1000),
+    epgsql:cancel(C),
+    receive done ->
+        ?assert(true)
+    after 5000 ->
+        epgsql:close(C),
+        ?assert(false)
+    end,
+    epgsql_ct:flush().
+
 connect_with_client_cert(Config) ->
     Module = ?config(module, Config),
     Dir = filename:join(code:lib_dir(epgsql), ?TEST_DATA_DIR),