Browse Source

Merge branch 'devel'

Sergey Prokhorov 4 years ago
parent
commit
8681806b5c

+ 1 - 0
.travis.yml

@@ -11,6 +11,7 @@ install: "true"
 language: erlang
 language: erlang
 matrix:
 matrix:
   include:
   include:
+    - otp_release: 23.0
     - otp_release: 22.2
     - otp_release: 22.2
     - otp_release: 21.3
     - otp_release: 21.3
     - otp_release: 20.3
     - otp_release: 20.3

+ 11 - 0
CHANGES

@@ -1,3 +1,14 @@
+In 4.5.0
+
+* Add support for `application_name` connection parameter #226
+* Execute request cancelation over TLS, when main connection is TLS as well #227
+* Handle skipped commands in execute_batch #228
+* Add sasl_prep implementation for validating passwords according to sasl specification #229
+* OTP-23 in CI #237
+* switch to `crypto:mac/4` since `crypto:hmac/3` is deprecated #239
+* Add `tcp_opts` connect option #242
+* Command API improvements #243
+
 In 4.4.0
 In 4.4.0
 
 
 * Guards are now added to avoid silent integer truncation for numeric and
 * Guards are now added to avoid silent integer truncation for numeric and

+ 1 - 1
Makefile

@@ -4,7 +4,7 @@ MINIMAL_COVERAGE = 55
 all: compile
 all: compile
 
 
 $(REBAR):
 $(REBAR):
-	wget https://s3.amazonaws.com/rebar3/rebar3
+	wget https://github.com/erlang/rebar3/releases/download/3.13.2/rebar3
 	chmod +x rebar3
 	chmod +x rebar3
 
 
 compile: src/epgsql_errcodes.erl $(REBAR)
 compile: src/epgsql_errcodes.erl $(REBAR)

+ 18 - 4
README.md

@@ -71,6 +71,7 @@ connect(Opts) -> {ok, Connection :: epgsql:connection()} | {error, Reason :: epg
       port =>     inet:port_number(),
       port =>     inet:port_number(),
       ssl =>      boolean() | required,
       ssl =>      boolean() | required,
       ssl_opts => [ssl:ssl_option()],    % @see OTP ssl app, ssl_api.hrl
       ssl_opts => [ssl:ssl_option()],    % @see OTP ssl app, ssl_api.hrl
+      tcp_opts => [gen_tcp:option()],    % @see OTP gen_tcp module documentation
       timeout =>  timeout(),             % socket connect timeout, default: 5000 ms
       timeout =>  timeout(),             % socket connect timeout, default: 5000 ms
       async =>    pid() | atom(),        % process to receive LISTEN/NOTIFY msgs
       async =>    pid() | atom(),        % process to receive LISTEN/NOTIFY msgs
       codecs =>   [{epgsql_codec:codec_mod(), any()}]}
       codecs =>   [{epgsql_codec:codec_mod(), any()}]}
@@ -84,7 +85,10 @@ connect(Host, Username, Password, Opts) -> {ok, C} | {error, Reason}.
 example:
 example:
 
 
 ```erlang
 ```erlang
-{ok, C} = epgsql:connect("localhost", "username", "psss", #{
+{ok, C} = epgsql:connect(#{
+    host => "localhost",
+    username => "username",
+    password => "psss",
     database => "test_db",
     database => "test_db",
     timeout => 4000
     timeout => 4000
 }),
 }),
@@ -103,6 +107,10 @@ Only `host` and `username` are mandatory, but most likely you would need `databa
   if encryption isn't supported by server. if set to `required` connection will fail if encryption
   if encryption isn't supported by server. if set to `required` connection will fail if encryption
   is not available.
   is not available.
 - `ssl_opts` will be passed as is to `ssl:connect/3`
 - `ssl_opts` will be passed as is to `ssl:connect/3`
+- `tcp_opts` will be passed as is to `gen_tcp:connect/3`. Some options are forbidden, such as
+  `mode`, `packet`, `header`, `active`. When `tcp_opts` is not provided, epgsql does some tuning
+  (eg, sets TCP `keepalive` and auto-tunes `buffer`), but when `tcp_opts` is provided, no
+  additional tweaks are added by epgsql itself, other than necessary ones (`active`, `packet` and `mode`).
 - `async` see [Server notifications](#server-notifications)
 - `async` see [Server notifications](#server-notifications)
 - `codecs` see [Pluggable datatype codecs](#pluggable-datatype-codecs)
 - `codecs` see [Pluggable datatype codecs](#pluggable-datatype-codecs)
 - `nulls` terms which will be used to represent SQL `NULL`. If any of those has been encountered in
 - `nulls` terms which will be used to represent SQL `NULL`. If any of those has been encountered in
@@ -112,6 +120,9 @@ Only `host` and `username` are mandatory, but most likely you would need `databa
    Default is `[null, undefined]`, i.e. encode `null` or `undefined` in parameters as `NULL`
    Default is `[null, undefined]`, i.e. encode `null` or `undefined` in parameters as `NULL`
    and decode `NULL`s as atom `null`.
    and decode `NULL`s as atom `null`.
 - `replication` see [Streaming replication protocol](#streaming-replication-protocol)
 - `replication` see [Streaming replication protocol](#streaming-replication-protocol)
+- `application_name` is an optional string parameter. It is usually set by an application upon
+   connection to the server. The name will be displayed in the `pg_stat_activity`
+   view and included in CSV log entries.
 
 
 Options may be passed as proplist or as map with the same key names.
 Options may be passed as proplist or as map with the same key names.
 
 
@@ -427,8 +438,11 @@ epgsql:execute_batch(C, "INSERT INTO account (name, age) VALUES ($1, $2) RETURNI
                      [ ["Joe", 35], ["Paul", 26], ["Mary", 24] ]).
                      [ ["Joe", 35], ["Paul", 26], ["Mary", 24] ]).
 ```
 ```
 
 
-In case one of the batch items causes an error, the result returned for this particular
-item will be `{error, #error{}}` and no more results will be produced.
+In case one of the batch items causes an error, all the remaining queries of
+that batch will be ignored. So, last element of the result list will be 
+`{error, #error{}}` and the length of the result list might be shorter that 
+the length of the batch. For a better illustration of such scenario please 
+refer to `epgsql_SUITE:batch_error/1`
 
 
 `epgsqla:execute_batch/{2,3}` sends `{C, Ref, Results}`
 `epgsqla:execute_batch/{2,3}` sends `{C, Ref, Results}`
 
 
@@ -448,7 +462,7 @@ epgsql:cancel(connection()) -> ok.
 
 
 PostgreSQL protocol supports [cancellation](https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.9)
 PostgreSQL protocol supports [cancellation](https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.9)
 of currently executing command. `cancel/1` sends a cancellation request via the
 of currently executing command. `cancel/1` sends a cancellation request via the
-new temporary TCP connection asynchronously, it doesn't await for the command to
+new temporary TCP/TLS_over_TCP connection asynchronously, it doesn't await for the command to
 be cancelled. Instead, client should expect to get
 be cancelled. Instead, client should expect to get
 `{error, #error{code = <<"57014">>, codename = query_canceled}}` back from
 `{error, #error{code = <<"57014">>, codename = query_canceled}}` back from
 the command that was cancelled. However, normal response can still be received as well.
 the command that was cancelled. However, normal response can still be received as well.

+ 4 - 2
doc/pluggable_commands.md

@@ -49,12 +49,14 @@ passed to all subsequent callbacks. No PostgreSQL interactions should be done he
 ```erlang
 ```erlang
 execute(pg_sock(), state()) ->
 execute(pg_sock(), state()) ->
     {ok, pg_sock(), state()}
     {ok, pg_sock(), state()}
+  | {send, epgsql_wire:packet_type(), iodata(), pg_sock(), state()}
+  | {send_multi, [{epgsql_wire:packet_type(), iodata()}], pg_sock(), state()}
   | {stop, Reason :: any(), Response :: any(), pg_sock()}.
   | {stop, Reason :: any(), Response :: any(), pg_sock()}.
-
 ```
 ```
 
 
 Client -> Server packets should be sent from this callback by `epgsql_sock:send_multi/2` or
 Client -> Server packets should be sent from this callback by `epgsql_sock:send_multi/2` or
-`epgsql_sock:send/3`. `epgsql_wire` module is usually used to create wire protocol packets.
+`epgsql_sock:send/3` or by returning equivalent `send` or `send_multi` values.
+`epgsql_wire` module is usually used to create wire protocol packets.
 Please note that many packets might be sent at once. See `epgsql_cmd_equery` as an example.
 Please note that many packets might be sent at once. See `epgsql_cmd_equery` as an example.
 
 
 This callback might be executed more than once for a single command execution if your command
 This callback might be executed more than once for a single command execution if your command

+ 9 - 11
src/commands/epgsql_cmd_batch.erl

@@ -36,7 +36,8 @@
 -type response() :: [{ok, Count :: non_neg_integer(), Rows :: [tuple()]}
 -type response() :: [{ok, Count :: non_neg_integer(), Rows :: [tuple()]}
                      | {ok, Count :: non_neg_integer()}
                      | {ok, Count :: non_neg_integer()}
                      | {ok, Rows :: [tuple()]}
                      | {ok, Rows :: [tuple()]}
-                     | {error, epgsql:query_error()}].
+                     | {error, epgsql:query_error()}
+                     ].
 -type state() :: #batch{}.
 -type state() :: #batch{}.
 
 
 -spec init(arguments()) -> state().
 -spec init(arguments()) -> state().
@@ -57,10 +58,9 @@ execute(Sock, #batch{batch = Batch, statement = undefined} = State) ->
                   BinFormats = epgsql_wire:encode_formats(Columns),
                   BinFormats = epgsql_wire:encode_formats(Columns),
                   add_command(StatementName, Types, Parameters, BinFormats, Codec, Acc)
                   add_command(StatementName, Types, Parameters, BinFormats, Codec, Acc)
           end,
           end,
-          [{?SYNC, []}],
+          [epgsql_wire:encode_sync()],
           Batch),
           Batch),
-    epgsql_sock:send_multi(Sock, Commands),
-    {ok, Sock, State};
+    {send_multi, Commands, Sock, State};
 execute(Sock, #batch{batch = Batch,
 execute(Sock, #batch{batch = Batch,
                      statement = #statement{name = StatementName,
                      statement = #statement{name = StatementName,
                                             columns = Columns,
                                             columns = Columns,
@@ -73,16 +73,15 @@ execute(Sock, #batch{batch = Batch,
           fun(Parameters, Acc) ->
           fun(Parameters, Acc) ->
                   add_command(StatementName, Types, Parameters, BinFormats, Codec, Acc)
                   add_command(StatementName, Types, Parameters, BinFormats, Codec, Acc)
           end,
           end,
-          [{?SYNC, []}],
+          [epgsql_wire:encode_sync()],
           Batch),
           Batch),
-    epgsql_sock:send_multi(Sock, Commands),
-    {ok, Sock, State}.
+    {send_multi, Commands, Sock, State}.
 
 
 add_command(StmtName, Types, Params, BinFormats, Codec, Acc) ->
 add_command(StmtName, Types, Params, BinFormats, Codec, Acc) ->
     TypedParameters = lists:zip(Types, Params),
     TypedParameters = lists:zip(Types, Params),
     BinParams = epgsql_wire:encode_parameters(TypedParameters, Codec),
     BinParams = epgsql_wire:encode_parameters(TypedParameters, Codec),
-    [{?BIND, [0, StmtName, 0, BinParams, BinFormats]},
-     {?EXECUTE, [0, <<0:?int32>>]} | Acc].
+    [epgsql_wire:encode_bind("", StmtName, BinParams, BinFormats),
+     epgsql_wire:encode_execute("", 0) | Acc].
 
 
 handle_message(?BIND_COMPLETE, <<>>, Sock, State) ->
 handle_message(?BIND_COMPLETE, <<>>, Sock, State) ->
     Columns = current_cols(State),
     Columns = current_cols(State),
@@ -110,8 +109,7 @@ handle_message(?COMMAND_COMPLETE, Bin, Sock,
                      {ok, Rows}
                      {ok, Rows}
              end,
              end,
     {add_result, Result, {complete, Complete}, Sock, State#batch{batch = Batch}};
     {add_result, Result, {complete, Complete}, Sock, State#batch{batch = Batch}};
-handle_message(?READY_FOR_QUERY, _Status, Sock, #batch{batch = B} = _State) when
-      length(B) =< 1 ->
+handle_message(?READY_FOR_QUERY, _Status, Sock, _State) ->
     Results = epgsql_sock:get_results(Sock),
     Results = epgsql_sock:get_results(Sock),
     {finish, Results, done, Sock};
     {finish, Results, done, Sock};
 handle_message(?ERROR, Error, Sock, #batch{batch = [_ | Batch]} = State) ->
 handle_message(?ERROR, Error, Sock, #batch{batch = [_ | Batch]} = State) ->

+ 6 - 8
src/commands/epgsql_cmd_bind.erl

@@ -1,4 +1,4 @@
-%% @doc Binds placeholder parameters to prepared statement
+%% @doc Binds placeholder parameters to prepared statement, creating a "portal"
 %%
 %%
 %% ```
 %% ```
 %% > Bind
 %% > Bind
@@ -30,13 +30,11 @@ execute(Sock, #bind{stmt = Stmt, portal = PortalName, params = Params} = St) ->
     TypedParams = lists:zip(Types, Params),
     TypedParams = lists:zip(Types, Params),
     Bin1 = epgsql_wire:encode_parameters(TypedParams, Codec),
     Bin1 = epgsql_wire:encode_parameters(TypedParams, Codec),
     Bin2 = epgsql_wire:encode_formats(Columns),
     Bin2 = epgsql_wire:encode_formats(Columns),
-    epgsql_sock:send_multi(
-      Sock,
-      [
-       {?BIND, [PortalName, 0, StatementName, 0, Bin1, Bin2]},
-       {?FLUSH, []}
-      ]),
-    {ok, Sock, St}.
+    Commands = [
+       epgsql_wire:encode_bind(PortalName, StatementName, Bin1, Bin2),
+       epgsql_wire:encode_flush()
+      ],
+    {send_multi, Commands, Sock, St}.
 
 
 handle_message(?BIND_COMPLETE, <<>>, Sock, _State) ->
 handle_message(?BIND_COMPLETE, <<>>, Sock, _State) ->
     {finish, ok, ok, Sock};
     {finish, ok, ok, Sock};

+ 5 - 11
src/commands/epgsql_cmd_close.erl

@@ -22,17 +22,11 @@ init({Type, Name}) ->
     #close{type = Type, name = Name}.
     #close{type = Type, name = Name}.
 
 
 execute(Sock, #close{type = Type, name = Name} = St) ->
 execute(Sock, #close{type = Type, name = Name} = St) ->
-    Type2 = case Type of
-        statement -> ?PREPARED_STATEMENT;
-        portal    -> ?PORTAL
-    end,
-    epgsql_sock:send_multi(
-      Sock,
-      [
-       {?CLOSE, [Type2, Name, 0]},
-       {?FLUSH, []}
-      ]),
-    {ok, Sock, St}.
+    Packets = [
+       epgsql_wire:encode_close(Type, Name),
+       epgsql_wire:encode_flush()
+      ],
+    {send_multi, Packets, Sock, St}.
 
 
 handle_message(?CLOSE_COMPLETE, <<>>, Sock, _St) ->
 handle_message(?CLOSE_COMPLETE, <<>>, Sock, _St) ->
     {finish, ok, ok, Sock};
     {finish, ok, ok, Sock};

+ 109 - 73
src/commands/epgsql_cmd_connect.erl

@@ -6,7 +6,7 @@
 %%%
 %%%
 -module(epgsql_cmd_connect).
 -module(epgsql_cmd_connect).
 -behaviour(epgsql_command).
 -behaviour(epgsql_command).
--export([hide_password/1, opts_hide_password/1]).
+-export([hide_password/1, opts_hide_password/1, open_socket/2]).
 -export([init/1, execute/2, handle_message/4]).
 -export([init/1, execute/2, handle_message/4]).
 -export_type([response/0, connect_error/0]).
 -export_type([response/0, connect_error/0]).
 
 
@@ -47,57 +47,99 @@
 init(#{host := _, username := _} = Opts) ->
 init(#{host := _, username := _} = Opts) ->
     #connect{opts = Opts}.
     #connect{opts = Opts}.
 
 
-execute(PgSock, #connect{opts = #{host := Host} = Opts, stage = connect} = State) ->
-    Timeout = maps:get(timeout, Opts, 5000),
-    Deadline = deadline(Timeout),
-    Port = maps:get(port, Opts, 5432),
-    SockOpts = [{active, false}, {packet, raw}, binary, {nodelay, true}, {keepalive, true}],
-    case gen_tcp:connect(Host, Port, SockOpts, Timeout) of
-        {ok, Sock} ->
-            client_handshake(Sock, PgSock, State, Deadline);
-        {error, Reason} = Error ->
-            {stop, Reason, Error, PgSock}
-    end;
-execute(PgSock, #connect{stage = auth, auth_send = {PacketId, Data}} = St) ->
-    ok = epgsql_sock:send(PgSock, PacketId, Data),
-    {ok, PgSock, St#connect{auth_send = undefined}}.
-
-client_handshake(Sock, PgSock, #connect{opts = #{username := Username} = Opts} = State, Deadline) ->
-    %% Increase the buffer size.  Following the recommendation in the inet man page:
-    %%
-    %%    It is recommended to have val(buffer) >=
-    %%    max(val(sndbuf),val(recbuf)).
-
-    {ok, [{recbuf, RecBufSize}, {sndbuf, SndBufSize}]} =
-        inet:getopts(Sock, [recbuf, sndbuf]),
-    inet:setopts(Sock, [{buffer, max(RecBufSize, SndBufSize)}]),
-
-    case maybe_ssl(Sock, maps:get(ssl, Opts, false), Opts, PgSock, Deadline) of
-        {error, Reason} ->
-            {stop, Reason, {error, Reason}, PgSock};
-        PgSock1 ->
+execute(PgSock, #connect{opts = #{username := Username} = Opts, stage = connect} = State) ->
+    SockOpts = prepare_tcp_opts(maps:get(tcp_opts, Opts, [])),
+    FilteredOpts = filter_sensitive_info(Opts),
+    PgSock1 = epgsql_sock:set_attr(connect_opts, FilteredOpts, PgSock),
+    case open_socket(SockOpts, Opts) of
+        {ok, Mode, Sock} ->
+            PgSock2 = epgsql_sock:set_net_socket(Mode, Sock, PgSock1),
             Opts2 = ["user", 0, Username, 0],
             Opts2 = ["user", 0, Username, 0],
             Opts3 = case maps:find(database, Opts) of
             Opts3 = case maps:find(database, Opts) of
                         error -> Opts2;
                         error -> Opts2;
                         {ok, Database}  -> [Opts2 | ["database", 0, Database, 0]]
                         {ok, Database}  -> [Opts2 | ["database", 0, Database, 0]]
                     end,
                     end,
+           {Opts4, PgSock3} =
+               case Opts of
+                   #{replication := Replication}  ->
+                       {[Opts3 | ["replication", 0, Replication, 0]],
+                        epgsql_sock:init_replication_state(PgSock2)};
+                   _ -> {Opts3, PgSock2}
+               end,
+            Opts5 = case Opts of
+                        #{application_name := ApplicationName}  ->
+                            [Opts4 | ["application_name", 0, ApplicationName, 0]];
+                        _ ->
+                            Opts4
+                    end,
+           ok = epgsql_sock:send(PgSock3, [<<196608:?int32>>, Opts5, 0]),
+           PgSock4 = case Opts of
+                         #{async := Async} ->
+                             epgsql_sock:set_attr(async, Async, PgSock3);
+                         _ -> PgSock3
+                     end,
+           {ok, PgSock4, State#connect{stage = maybe_auth}};
+        {error, Reason} = Error ->
+            {stop, Reason, Error, PgSock}
+    end;
+execute(PgSock, #connect{stage = auth, auth_send = {PacketType, Data}} = St) ->
+    {send, PacketType, Data, PgSock, St#connect{auth_send = undefined}}.
 
 
-            {Opts4, PgSock2} =
-                case Opts of
-                    #{replication := Replication}  ->
-                        {[Opts3 | ["replication", 0, Replication, 0]],
-                         epgsql_sock:init_replication_state(PgSock1)};
-                    _ -> {Opts3, PgSock1}
-                end,
-            ok = epgsql_sock:send(PgSock2, [<<196608:?int32>>, Opts4, 0]),
-            PgSock3 = case Opts of
-                          #{async := Async} ->
-                              epgsql_sock:set_attr(async, Async, PgSock2);
-                          _ -> PgSock2
-                      end,
-            {ok, PgSock3, State#connect{stage = maybe_auth}}
+-spec open_socket([{atom(), any()}], epgsql:connect_opts()) ->
+    {ok , gen_tcp | ssl, port() | ssl:sslsocket()} | {error, any()}.
+open_socket(SockOpts, #{host := Host} = ConnectOpts) ->
+    Timeout = maps:get(timeout, ConnectOpts, 5000),
+    Deadline = deadline(Timeout),
+    Port = maps:get(port, ConnectOpts, 5432),
+    case gen_tcp:connect(Host, Port, SockOpts, Timeout) of
+       {ok, Sock} ->
+           client_handshake(Sock, ConnectOpts, Deadline);
+       {error, _Reason} = Error ->
+           Error
     end.
     end.
 
 
+client_handshake(Sock, ConnectOpts, Deadline) ->
+    case maps:is_key(tcp_opts, ConnectOpts) of
+        false ->
+            %% Increase the buffer size.  Following the recommendation in the inet man page:
+            %%
+            %%    It is recommended to have val(buffer) >=
+            %%    max(val(sndbuf),val(recbuf)).
+            {ok, [{recbuf, RecBufSize}, {sndbuf, SndBufSize}]} =
+                inet:getopts(Sock, [recbuf, sndbuf]),
+            inet:setopts(Sock, [{buffer, max(RecBufSize, SndBufSize)}]);
+        true ->
+            %% All TCP options are provided by the user
+            noop
+    end,
+    maybe_ssl(Sock, maps:get(ssl, ConnectOpts, false), ConnectOpts, Deadline).
+
+maybe_ssl(Sock, false, _ConnectOpts, _Deadline) ->
+    {ok, gen_tcp, Sock};
+maybe_ssl(Sock, Flag, ConnectOpts, Deadline) ->
+    ok = gen_tcp:send(Sock, <<8:?int32, 80877103:?int32>>),
+    Timeout0 = timeout(Deadline),
+    case gen_tcp:recv(Sock, 1, Timeout0) of
+        {ok, <<$S>>}  ->
+            SslOpts = maps:get(ssl_opts, ConnectOpts, []),
+            Timeout = timeout(Deadline),
+            case ssl:connect(Sock, SslOpts, Timeout) of
+                {ok, Sock2} ->
+                    {ok, ssl, Sock2};
+                {error, Reason} ->
+                    Err = {ssl_negotiation_failed, Reason},
+                    {error, Err}
+            end;
+        {ok, <<$N>>} ->
+            case Flag of
+                true ->
+                   {ok, gen_tcp, Sock};
+                required ->
+                    {error, ssl_not_available}
+            end;
+        {error, Reason} ->
+            {error, Reason}
+    end.
 
 
 %% @doc Replace `password' in Opts map with obfuscated one
 %% @doc Replace `password' in Opts map with obfuscated one
 opts_hide_password(#{password := Password} = Opts) ->
 opts_hide_password(#{password := Password} = Opts) ->
@@ -105,6 +147,10 @@ opts_hide_password(#{password := Password} = Opts) ->
     Opts#{password => HiddenPassword};
     Opts#{password => HiddenPassword};
 opts_hide_password(Opts) -> Opts.
 opts_hide_password(Opts) -> Opts.
 
 
+%% @doc password and username are sensitive data that should not be stored in a
+%% permanent state that might crash during code upgrade
+filter_sensitive_info(Opts0) ->
+  maps:without([password, username], Opts0).
 
 
 %% @doc this function wraps plaintext password to a lambda function, so, if
 %% @doc this function wraps plaintext password to a lambda function, so, if
 %% epgsql_sock process crashes when executing `connect' command, password will
 %% epgsql_sock process crashes when executing `connect' command, password will
@@ -118,34 +164,6 @@ hide_password(Password) when is_list(Password);
 hide_password(PasswordFun) when is_function(PasswordFun, 0) ->
 hide_password(PasswordFun) when is_function(PasswordFun, 0) ->
     PasswordFun.
     PasswordFun.
 
 
-
-maybe_ssl(S, false, _, PgSock, _Deadline) ->
-    epgsql_sock:set_net_socket(gen_tcp, S, PgSock);
-maybe_ssl(S, Flag, Opts, PgSock, Deadline) ->
-    ok = gen_tcp:send(S, <<8:?int32, 80877103:?int32>>),
-    Timeout0 = timeout(Deadline),
-    case gen_tcp:recv(S, 1, Timeout0) of
-        {ok, <<$S>>}  ->
-            SslOpts = maps:get(ssl_opts, Opts, []),
-            Timeout = timeout(Deadline),
-            case ssl:connect(S, SslOpts, Timeout) of
-                {ok, S2}        ->
-                    epgsql_sock:set_net_socket(ssl, S2, PgSock);
-                {error, Reason} ->
-                    Err = {ssl_negotiation_failed, Reason},
-                    {error, Err}
-            end;
-        {ok, <<$N>>} ->
-            case Flag of
-                true ->
-                    epgsql_sock:set_net_socket(gen_tcp, S, PgSock);
-                required ->
-                    {error, ssl_not_available}
-            end;
-        {error, Reason} ->
-            {error, Reason}
-    end.
-
 %% Auth sub-protocol
 %% Auth sub-protocol
 
 
 auth_init(<<?AUTH_CLEARTEXT:?int32>>, Sock, St) ->
 auth_init(<<?AUTH_CLEARTEXT:?int32>>, Sock, St) ->
@@ -268,6 +286,24 @@ handle_message(?ERROR, #error{code = Code} = Err, Sock, #connect{stage = Stage}
 handle_message(_, _, _, _) ->
 handle_message(_, _, _, _) ->
     unknown.
     unknown.
 
 
+prepare_tcp_opts([]) ->
+    [{active, false}, {packet, raw}, {mode, binary}, {nodelay, true}, {keepalive, true}];
+prepare_tcp_opts(Opts0) ->
+    case lists:filter(fun(binary) -> true;
+                         (list) -> true;
+                         ({mode, _}) -> true;
+                         ({packet, _}) -> true;
+                         ({packet_size, _}) -> true;
+                         ({header, _}) -> true;
+                         ({active, _}) -> true;
+                         (_) -> false
+                      end, Opts0) of
+        [] ->
+            [{active, false}, {packet, raw}, {mode, binary} | Opts0];
+        Forbidden ->
+            error({forbidden_tcp_opts, Forbidden})
+    end.
+
 
 
 get_password(Opts) ->
 get_password(Opts) ->
     PasswordFun = maps:get(password, Opts),
     PasswordFun = maps:get(password, Opts),
@@ -284,4 +320,4 @@ deadline(Timeout) ->
     erlang:monotonic_time(milli_seconds) + Timeout.
     erlang:monotonic_time(milli_seconds) + Timeout.
 
 
 timeout(Deadline) ->
 timeout(Deadline) ->
-    erlang:max(0, Deadline - erlang:monotonic_time(milli_seconds)).
+    erlang:max(0, Deadline - erlang:monotonic_time(milli_seconds)).

+ 5 - 6
src/commands/epgsql_cmd_describe_portal.erl

@@ -22,13 +22,12 @@ init(Name) ->
     #desc_portal{name = Name}.
     #desc_portal{name = Name}.
 
 
 execute(Sock, #desc_portal{name = Name} = St) ->
 execute(Sock, #desc_portal{name = Name} = St) ->
-    epgsql_sock:send_multi(
-      Sock,
+    Commands =
       [
       [
-       {?DESCRIBE, [?PORTAL, Name, 0]},
-       {?FLUSH, []}
-      ]),
-    {ok, Sock, St}.
+       epgsql_wire:encode_describe(portal, Name),
+       epgsql_wire:encode_flush()
+      ],
+    {send_multi, Commands, Sock, St}.
 
 
 handle_message(?ROW_DESCRIPTION, <<Count:?int16, Bin/binary>>, Sock, _St) ->
 handle_message(?ROW_DESCRIPTION, <<Count:?int16, Bin/binary>>, Sock, _St) ->
     Codec = epgsql_sock:get_codec(Sock),
     Codec = epgsql_sock:get_codec(Sock),

+ 5 - 6
src/commands/epgsql_cmd_describe_statement.erl

@@ -26,13 +26,12 @@ init(Name) ->
     #desc_stmt{name = Name}.
     #desc_stmt{name = Name}.
 
 
 execute(Sock, #desc_stmt{name = Name} = St) ->
 execute(Sock, #desc_stmt{name = Name} = St) ->
-    epgsql_sock:send_multi(
-      Sock,
+    Commands =
       [
       [
-       {?DESCRIBE, [?PREPARED_STATEMENT, Name, 0]},
-       {?FLUSH, []}
-      ]),
-    {ok, Sock, St}.
+       epgsql_wire:encode_describe(statement, Name),
+       epgsql_wire:encode_flush()
+      ],
+    {send_multi, Commands, Sock, St}.
 
 
 handle_message(?PARAMETER_DESCRIPTION, Bin, Sock, State) ->
 handle_message(?PARAMETER_DESCRIPTION, Bin, Sock, State) ->
     Codec = epgsql_sock:get_codec(Sock),
     Codec = epgsql_sock:get_codec(Sock),

+ 7 - 8
src/commands/epgsql_cmd_equery.erl

@@ -43,15 +43,14 @@ execute(Sock, #equery{stmt = Stmt, params = TypedParams} = St) ->
     Codec = epgsql_sock:get_codec(Sock),
     Codec = epgsql_sock:get_codec(Sock),
     Bin1 = epgsql_wire:encode_parameters(TypedParams, Codec),
     Bin1 = epgsql_wire:encode_parameters(TypedParams, Codec),
     Bin2 = epgsql_wire:encode_formats(Columns),
     Bin2 = epgsql_wire:encode_formats(Columns),
-    epgsql_sock:send_multi(
-      Sock,
+    Commands =
       [
       [
-       {?BIND, ["", 0, StatementName, 0, Bin1, Bin2]},
-       {?EXECUTE, ["", 0, <<0:?int32>>]},
-       {?CLOSE, [?PREPARED_STATEMENT, StatementName, 0]},
-       {?SYNC, []}
-      ]),
-    {ok, Sock, St}.
+       epgsql_wire:encode_bind("", StatementName, Bin1, Bin2),
+       epgsql_wire:encode_execute("", 0),
+       epgsql_wire:encode_close(statement, StatementName),
+       epgsql_wire:encode_sync()
+      ],
+    {send_multi, Commands, Sock, St}.
 
 
 handle_message(?BIND_COMPLETE, <<>>, Sock, #equery{stmt = Stmt} = State) ->
 handle_message(?BIND_COMPLETE, <<>>, Sock, #equery{stmt = Stmt} = State) ->
     #statement{columns = Columns} = Stmt,
     #statement{columns = Columns} = Stmt,

+ 6 - 7
src/commands/epgsql_cmd_execute.erl

@@ -32,16 +32,15 @@ init({Stmt, PortalName, MaxRows}) ->
     #execute{stmt = Stmt, portal_name = PortalName, max_rows = MaxRows}.
     #execute{stmt = Stmt, portal_name = PortalName, max_rows = MaxRows}.
 
 
 execute(Sock, #execute{stmt = Stmt, portal_name = PortalName, max_rows = MaxRows} = State) ->
 execute(Sock, #execute{stmt = Stmt, portal_name = PortalName, max_rows = MaxRows} = State) ->
-    epgsql_sock:send_multi(
-      Sock,
-      [
-       {?EXECUTE, [PortalName, 0, <<MaxRows:?int32>>]},
-       {?FLUSH, []}
-      ]),
     #statement{columns = Columns} = Stmt,
     #statement{columns = Columns} = Stmt,
     Codec = epgsql_sock:get_codec(Sock),
     Codec = epgsql_sock:get_codec(Sock),
     Decoder = epgsql_wire:build_decoder(Columns, Codec),
     Decoder = epgsql_wire:build_decoder(Columns, Codec),
-    {ok, Sock, State#execute{decoder = Decoder}}.
+    Commands =
+      [
+       epgsql_wire:encode_execute(PortalName, MaxRows),
+       epgsql_wire:encode_flush()
+      ],
+    {send_multi, Commands, Sock, State#execute{decoder = Decoder}}.
 
 
 handle_message(?DATA_ROW, <<_Count:?int16, Bin/binary>>, Sock,
 handle_message(?DATA_ROW, <<_Count:?int16, Bin/binary>>, Sock,
                #execute{decoder = Decoder} = St) ->
                #execute{decoder = Decoder} = St) ->

+ 11 - 7
src/commands/epgsql_cmd_parse.erl

@@ -1,5 +1,10 @@
 %% @doc Asks server to parse SQL query and send information aboud bind-parameters and result columns.
 %% @doc Asks server to parse SQL query and send information aboud bind-parameters and result columns.
 %%
 %%
+%% Empty `Name' creates a "disposable" anonymous prepared statement.
+%% Non-empty `Name' creates a named prepared statement (name is not shared between connections),
+%% which should be explicitly closed when no logner needed (but will be terminated automatically
+%% when connection is closed).
+%% Non-empty name can't be rebound to another query; it should be closed for being available again.
 %% ```
 %% ```
 %% > Parse
 %% > Parse
 %% < ParseComplete
 %% < ParseComplete
@@ -31,14 +36,13 @@ init({Name, Sql, Types}) ->
 execute(Sock, #parse{name = Name, sql = Sql, types = Types} = St) ->
 execute(Sock, #parse{name = Name, sql = Sql, types = Types} = St) ->
     Codec = epgsql_sock:get_codec(Sock),
     Codec = epgsql_sock:get_codec(Sock),
     Bin = epgsql_wire:encode_types(Types, Codec),
     Bin = epgsql_wire:encode_types(Types, Codec),
-    epgsql_sock:send_multi(
-      Sock,
+    Commands =
       [
       [
-       {?PARSE, [Name, 0, Sql, 0, Bin]},
-       {?DESCRIBE, [?PREPARED_STATEMENT, Name, 0]},
-       {?FLUSH, []}
-      ]),
-    {ok, Sock, St}.
+       epgsql_wire:encode_parse(Name, Sql, Bin),
+       epgsql_wire:encode_describe(statement, Name),
+       epgsql_wire:encode_flush()
+      ],
+    {send_multi, Commands, Sock, St}.
 
 
 handle_message(?PARSE_COMPLETE, <<>>, Sock, _State) ->
 handle_message(?PARSE_COMPLETE, <<>>, Sock, _State) ->
     {noaction, Sock};
     {noaction, Sock};

+ 6 - 7
src/commands/epgsql_cmd_prepared_query.erl

@@ -37,14 +37,13 @@ execute(Sock, #pquery{stmt = Stmt, params = TypedParams} = St) ->
     Codec = epgsql_sock:get_codec(Sock),
     Codec = epgsql_sock:get_codec(Sock),
     Bin1 = epgsql_wire:encode_parameters(TypedParams, Codec),
     Bin1 = epgsql_wire:encode_parameters(TypedParams, Codec),
     Bin2 = epgsql_wire:encode_formats(Columns),
     Bin2 = epgsql_wire:encode_formats(Columns),
-    epgsql_sock:send_multi(
-      Sock,
+    Commands =
       [
       [
-       {?BIND, ["", 0, StatementName, 0, Bin1, Bin2]},
-       {?EXECUTE, ["", 0, <<0:?int32>>]},
-       {?SYNC, []}
-      ]),
-    {ok, Sock, St}.
+       epgsql_wire:encode_bind("", StatementName, Bin1, Bin2),
+       epgsql_wire:encode_execute("", 0),
+       epgsql_wire:encode_sync()
+      ],
+    {send_multi, Commands, Sock, St}.
 
 
 handle_message(?BIND_COMPLETE, <<>>, Sock, #pquery{stmt = Stmt} = State) ->
 handle_message(?BIND_COMPLETE, <<>>, Sock, #pquery{stmt = Stmt} = State) ->
     #statement{columns = Columns} = Stmt,
     #statement{columns = Columns} = Stmt,

+ 2 - 2
src/commands/epgsql_cmd_squery.erl

@@ -38,8 +38,8 @@ init(Sql) ->
     #squery{query = Sql}.
     #squery{query = Sql}.
 
 
 execute(Sock, #squery{query = Q} = State) ->
 execute(Sock, #squery{query = Q} = State) ->
-    epgsql_sock:send(Sock, ?SIMPLEQUERY, [Q, 0]),
-    {ok, Sock, State}.
+    {Type, Data} = epgsql_wire:encode_query(Q),
+    {send, Type, Data, Sock, State}.
 
 
 handle_message(?ROW_DESCRIPTION, <<Count:?int16, Bin/binary>>, Sock, State) ->
 handle_message(?ROW_DESCRIPTION, <<Count:?int16, Bin/binary>>, Sock, State) ->
     Codec = epgsql_sock:get_codec(Sock),
     Codec = epgsql_sock:get_codec(Sock),

+ 2 - 3
src/commands/epgsql_cmd_start_replication.erl

@@ -59,9 +59,8 @@ execute(Sock, #start_repl{slot = ReplicationSlot, callback = Callback,
                        align_lsn = AlignLsn},
                        align_lsn = AlignLsn},
     Sock2 = epgsql_sock:set_attr(replication_state, Repl3, Sock),
     Sock2 = epgsql_sock:set_attr(replication_state, Repl3, Sock),
                          %% handler = on_replication},
                          %% handler = on_replication},
-
-    epgsql_sock:send(Sock2, ?SIMPLEQUERY, [Sql2, 0]),
-    {ok, Sock2, St}.
+    {PktType, PktData} = epgsql_wire:encode_query(Sql2),
+    {send, PktType, PktData, Sock2, St}.
 
 
 %% CopyBothResponse
 %% CopyBothResponse
 handle_message(?COPY_BOTH_RESPONSE, _Data, Sock, _State) ->
 handle_message(?COPY_BOTH_RESPONSE, _Data, Sock, _State) ->

+ 3 - 2
src/commands/epgsql_cmd_sync.erl

@@ -1,6 +1,7 @@
 %% @doc Synchronize client and server states for multi-command combinations
 %% @doc Synchronize client and server states for multi-command combinations
 %%
 %%
 %% Should be executed if APIs start to return `{error, sync_required}'.
 %% Should be executed if APIs start to return `{error, sync_required}'.
+%% See [https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY]
 %% ```
 %% ```
 %% > Sync
 %% > Sync
 %% < ReadyForQuery
 %% < ReadyForQuery
@@ -20,9 +21,9 @@ init(_) ->
     undefined.
     undefined.
 
 
 execute(Sock, St) ->
 execute(Sock, St) ->
-    epgsql_sock:send(Sock, ?SYNC, []),
     Sock1 = epgsql_sock:set_attr(sync_required, false, Sock),
     Sock1 = epgsql_sock:set_attr(sync_required, false, Sock),
-    {ok, Sock1, St}.
+    {Type, Data} = epgsql_wire:encode_sync(),
+    {send, Type, Data, Sock1, St}.
 
 
 handle_message(?READY_FOR_QUERY, _, Sock, _State) ->
 handle_message(?READY_FOR_QUERY, _, Sock, _State) ->
     {finish, ok, ok, Sock};
     {finish, ok, ok, Sock};

+ 2 - 2
src/commands/epgsql_cmd_update_type_cache.erl

@@ -22,8 +22,8 @@ execute(Sock, #upd{codecs = Codecs} = State) ->
     CodecEntries = epgsql_codec:init_mods(Codecs, Sock),
     CodecEntries = epgsql_codec:init_mods(Codecs, Sock),
     TypeNames = [element(1, Entry) || Entry <- CodecEntries],
     TypeNames = [element(1, Entry) || Entry <- CodecEntries],
     Query = epgsql_oid_db:build_query(TypeNames),
     Query = epgsql_oid_db:build_query(TypeNames),
-    epgsql_sock:send(Sock, ?SIMPLEQUERY, [Query, 0]),
-    {ok, Sock, State#upd{codec_entries = CodecEntries}}.
+    {PktType, PktData} = epgsql_wire:encode_query(Query),
+    {send, PktType, PktData, Sock, State#upd{codec_entries = CodecEntries}}.
 
 
 handle_message(?ROW_DESCRIPTION, <<Count:?int16, Bin/binary>>, Sock, State) ->
 handle_message(?ROW_DESCRIPTION, <<Count:?int16, Bin/binary>>, Sock, State) ->
     Codec = epgsql_sock:get_codec(Sock),
     Codec = epgsql_sock:get_codec(Sock),

+ 1 - 1
src/epgsql.app.src

@@ -1,6 +1,6 @@
 {application, epgsql,
 {application, epgsql,
  [{description, "PostgreSQL Client"},
  [{description, "PostgreSQL Client"},
-  {vsn, "4.4.0"},
+  {vsn, "4.5.0"},
   {modules, []},
   {modules, []},
   {registered, []},
   {registered, []},
   {applications, [kernel,
   {applications, [kernel,

+ 7 - 2
src/epgsql.erl

@@ -58,11 +58,13 @@
     {port,     PortNum    :: inet:port_number()}   |
     {port,     PortNum    :: inet:port_number()}   |
     {ssl,      IsEnabled  :: boolean() | required} |
     {ssl,      IsEnabled  :: boolean() | required} |
     {ssl_opts, SslOptions :: [ssl:ssl_option()]}   | % see OTP ssl app, ssl_api.hrl
     {ssl_opts, SslOptions :: [ssl:ssl_option()]}   | % see OTP ssl app, ssl_api.hrl
+    {tcp_opts, TcpOptions :: [gen_tcp:option()]}   | % see OTP ssl app, ssl_api.hrl
     {timeout,  TimeoutMs  :: timeout()}            | % default: 5000 ms
     {timeout,  TimeoutMs  :: timeout()}            | % default: 5000 ms
     {async,    Receiver   :: pid() | atom()}       | % process to receive LISTEN/NOTIFY msgs
     {async,    Receiver   :: pid() | atom()}       | % process to receive LISTEN/NOTIFY msgs
     {codecs,   Codecs     :: [{epgsql_codec:codec_mod(), any()}]} |
     {codecs,   Codecs     :: [{epgsql_codec:codec_mod(), any()}]} |
     {nulls,    Nulls      :: [any(), ...]} |    % terms to be used as NULL
     {nulls,    Nulls      :: [any(), ...]} |    % terms to be used as NULL
-    {replication, Replication :: string()}. % Pass "database" to connect in replication mode
+    {replication, Replication :: string()} | % Pass "database" to connect in replication mode
+    {application_name, ApplicationName :: string()}.
 
 
 -type connect_opts() ::
 -type connect_opts() ::
         [connect_option()]
         [connect_option()]
@@ -73,11 +75,14 @@
           port => inet:port_number(),
           port => inet:port_number(),
           ssl => boolean() | required,
           ssl => boolean() | required,
           ssl_opts => [ssl:ssl_option()],
           ssl_opts => [ssl:ssl_option()],
+          tcp_opts => [gen_tcp:option()],
           timeout => timeout(),
           timeout => timeout(),
           async => pid() | atom(),
           async => pid() | atom(),
           codecs => [{epgsql_codec:codec_mod(), any()}],
           codecs => [{epgsql_codec:codec_mod(), any()}],
           nulls => [any(), ...],
           nulls => [any(), ...],
-          replication => string()}.
+          replication => string(),
+          application_name => string()
+          }.
 
 
 -type connect_error() :: epgsql_cmd_connect:connect_error().
 -type connect_error() :: epgsql_cmd_connect:connect_error().
 -type query_error() :: #error{}.              % Error report generated by server
 -type query_error() :: #error{}.              % Error report generated by server

+ 4 - 0
src/epgsql_command.erl

@@ -15,6 +15,10 @@
 
 
 -type execute_return() ::
 -type execute_return() ::
         {ok, epgsql_sock:pg_sock(), state()}
         {ok, epgsql_sock:pg_sock(), state()}
+      | {send, epgsql_wire:packet_type(), PktData :: iodata(),
+         epgsql_sock:pg_sock(), state()}
+      | {send_multi, [{epgsql_wire:packet_type(), PktData :: iodata()}],
+         epgsql_sock:pg_sock(), state()}
       | {stop, Reason :: any(), Response :: any(), epgsql_sock:pg_sock()}.
       | {stop, Reason :: any(), Response :: any(), epgsql_sock:pg_sock()}.
 
 
 %% Execute command. It should send commands to socket.
 %% Execute command. It should send commands to socket.

+ 160 - 0
src/epgsql_sasl_prep_profile.erl

@@ -0,0 +1,160 @@
+%%% coding: utf-8
+%%% @doc
+%%% This is a helper module that will validate a utf-8
+%%% string based on sasl_prep profile as defined in
+%%% https://tools.ietf.org/html/rfc4013
+%%% @end
+
+-module(epgsql_sasl_prep_profile).
+
+-export([ validate/1
+        ]).
+
+-spec validate(iolist()) -> iolist().
+validate(Str) ->
+    CharL = unicode:characters_to_list(Str, utf8),
+    lists:foreach(fun(F) ->
+                      lists:any(F, CharL)
+                          andalso error({non_valid_scram_password, Str})
+                  end, [ fun is_non_asci_space_character/1
+                       , fun is_ascii_control_character/1
+                       , fun is_non_ascii_control_character/1
+                       , fun is_private_use_characters/1
+                       , fun is_non_character_code_points/1
+                       , fun is_surrogate_code_points/1
+                       , fun is_inappropriate_for_plain_text_char/1
+                       , fun is_inappropriate_for_canonical_representation_char/1
+                       , fun is_change_display_properties_or_deprecated_char/1
+                       , fun is_tagging_char/1 ]),
+    Str.
+
+%% @doc Return true if the given character is a non-ASCII space character
+%% as defined by https://tools.ietf.org/html/rfc3454#appendix-C.1.2
+-spec is_non_asci_space_character(char()) -> boolean().
+is_non_asci_space_character(C) ->
+    C == 16#00A0
+        orelse C == 16#1680
+        orelse (16#2000 =< C andalso C =< 16#200B)
+        orelse C == 16#202F
+        orelse C == 16#205F
+        orelse C == 16#3000.
+
+%% @doc Return true if the given character is an ASCII control character
+%% as defined by https://tools.ietf.org/html/rfc3454#appendix-C.2.1
+-spec is_ascii_control_character(char()) -> boolean().
+is_ascii_control_character(C) ->
+    C =< 16#001F orelse C == 16#007F.
+
+%% @doc Return true if the given character is a non-ASCII control character
+%% as defined by https://tools.ietf.org/html/rfc3454#appendix-C.2.2
+-spec is_non_ascii_control_character(char()) -> boolean().
+is_non_ascii_control_character(C) ->
+    (16#0080 =< C andalso C =< 16#009F)
+        orelse C == 16#06DD
+        orelse C == 16#070F
+        orelse C == 16#180E
+        orelse C == 16#200C
+        orelse C == 16#200D
+        orelse C == 16#2028
+        orelse C == 16#2029
+        orelse C == 16#2060
+        orelse C == 16#2061
+        orelse C == 16#2062
+        orelse C == 16#2063
+        orelse (16#206A =< C andalso C =< 16#206F)
+        orelse C == 16#FEFF
+        orelse (16#FFF9 =< C andalso C =< 16#FFFC)
+        orelse (16#1D173 =< C andalso C =< 16#1D17A).
+
+%% @doc Return true if the given character is a private use character
+%% as defined by https://tools.ietf.org/html/rfc3454#appendix-C.3
+-spec is_private_use_characters(char()) -> boolean().
+is_private_use_characters(C) ->
+    (16#E000 =< C andalso C =< 16#F8FF)
+         orelse (16#F000 =< C andalso C =< 16#FFFFD)
+        orelse (16#100000 =< C andalso C =< 16#10FFFD).
+
+%% @doc Return true if the given character is a non-character code point
+%% as defined by https://tools.ietf.org/html/rfc3454#appendix-C.4
+-spec is_non_character_code_points(char()) -> boolean().
+is_non_character_code_points(C) ->
+    (16#FDD0 =< C andalso C =< 16#FDEF)
+        orelse (16#FFFE =< C andalso C =< 16#FFFF)
+        orelse (16#1FFFE =< C andalso C =< 16#1FFFF)
+        orelse (16#2FFFE =< C andalso C =< 16#2FFFF)
+        orelse (16#3FFFE =< C andalso C =< 16#3FFFF)
+        orelse (16#4FFFE =< C andalso C =< 16#4FFFF)
+        orelse (16#5FFFE =< C andalso C =< 16#5FFFF)
+        orelse (16#6FFFE =< C andalso C =< 16#6FFFF)
+        orelse (16#7FFFE =< C andalso C =< 16#7FFFF)
+        orelse (16#8FFFE =< C andalso C =< 16#8FFFF)
+        orelse (16#9FFFE =< C andalso C =< 16#9FFFF)
+        orelse (16#AFFFE =< C andalso C =< 16#AFFFF)
+        orelse (16#BFFFE =< C andalso C =< 16#BFFFF)
+        orelse (16#CFFFE =< C andalso C =< 16#CFFFF)
+        orelse (16#DFFFE =< C andalso C =< 16#DFFFF)
+        orelse (16#EFFFE =< C andalso C =< 16#EFFFF)
+        orelse (16#FFFFE =< C andalso C =< 16#FFFFF)
+        orelse (16#10FFFE =< C andalso C =< 16#10FFFF).
+
+%% @doc Return true if the given character is a surrogate code point as defined by
+%% https://tools.ietf.org/html/rfc3454#appendix-C.5
+-spec is_surrogate_code_points(char()) -> boolean().
+is_surrogate_code_points(C) ->
+    16#D800 =< C andalso C =< 16#DFFF.
+
+%% @doc Return true if the given character is inappropriate for plain text characters
+%% as defined by https://tools.ietf.org/html/rfc3454#appendix-C.6
+-spec is_inappropriate_for_plain_text_char(char()) -> boolean().
+is_inappropriate_for_plain_text_char(C) ->
+    C == 16#FFF9
+        orelse C == 16#FFFA
+        orelse C == 16#FFFB
+        orelse C == 16#FFFC
+        orelse C == 16#FFFD.
+
+%% @doc Return true if the given character is inappropriate for canonical representation
+%% as defined by https://tools.ietf.org/html/rfc3454#appendix-C.7
+-spec is_inappropriate_for_canonical_representation_char(char()) -> boolean().
+is_inappropriate_for_canonical_representation_char(C) ->
+    16#2FF0 =< C andalso C =< 16#2FFB.
+
+%% @doc Return true if the given character is change display properties or deprecated
+%% characters as defined by https://tools.ietf.org/html/rfc3454#appendix-C.8
+-spec is_change_display_properties_or_deprecated_char(char()) -> boolean().
+is_change_display_properties_or_deprecated_char(C) ->
+    C == 16#0340
+        orelse C == 16#0341
+        orelse C == 16#200E
+        orelse C == 16#200F
+        orelse C == 16#202A
+        orelse C == 16#202B
+        orelse C == 16#202C
+        orelse C == 16#202D
+        orelse C == 16#202E
+        orelse C == 16#206A
+        orelse C == 16#206B
+        orelse C == 16#206C
+        orelse C == 16#206D
+        orelse C == 16#206E
+        orelse C == 16#206F.
+
+%% @doc Return true if the given character is a tagging character as defined by
+%% https://tools.ietf.org/html/rfc3454#appendix-C.9
+-spec is_tagging_char(char()) -> boolean().
+is_tagging_char(C) ->
+    C == 16#E0001 orelse
+        (16#E0020 =< C andalso C =< 16#E007F).
+
+-ifdef(TEST).
+-include_lib("eunit/include/eunit.hrl").
+
+normalize_test() ->
+    ?assertEqual(<<"123 !~">>, validate(<<"123 !~">>)),
+    ?assertEqual(<<"привет"/utf8>>, validate(<<"привет"/utf8>>)),
+    ?assertEqual(<<"Χαίρετε"/utf8>>, validate(<<"Χαίρετε"/utf8>>)),
+    ?assertEqual(<<"你好"/utf8>>, validate(<<"你好"/utf8>>)),
+    ?assertError({non_valid_scram_password, _},
+                 validate(<<"boom in the last char  ́"/utf8>>)).
+
+-endif.

+ 11 - 14
src/epgsql_scram.erl

@@ -76,7 +76,7 @@ get_client_final(SrvFirst, ClientNonce, UserName, Password) ->
     Salt = proplists:get_value(salt, SrvFirst),
     Salt = proplists:get_value(salt, SrvFirst),
     I = proplists:get_value(i, SrvFirst),
     I = proplists:get_value(i, SrvFirst),
 
 
-    SaltedPassword = hi(normalize(Password), Salt, I),
+    SaltedPassword = hi(epgsql_sasl_prep_profile:validate(Password), Salt, I),
     ClientKey = hmac(SaltedPassword, "Client Key"),
     ClientKey = hmac(SaltedPassword, "Client Key"),
     StoredKey = h(ClientKey),
     StoredKey = h(ClientKey),
     ClientFirstBare = client_first_bare(UserName, ClientNonce),
     ClientFirstBare = client_first_bare(UserName, ClientNonce),
@@ -100,15 +100,6 @@ parse_server_final(<<"e=", ServerError/binary>>) ->
 
 
 %% Helpers
 %% Helpers
 
 
-%% TODO: implement according to rfc3454
-normalize(Str) ->
-    lists:all(fun is_ascii_non_control/1, unicode:characters_to_list(Str, utf8))
-        orelse error({scram_non_ascii_password, Str}),
-    Str.
-
-is_ascii_non_control(C) when C > 16#1F, C < 16#7F -> true;
-is_ascii_non_control(_) -> false.
-
 check_nonce(ClientNonce, ServerNonce) ->
 check_nonce(ClientNonce, ServerNonce) ->
     Size = size(ClientNonce),
     Size = size(ClientNonce),
     <<ClientNonce:Size/binary, _/binary>> = ServerNonce,
     <<ClientNonce:Size/binary, _/binary>> = ServerNonce,
@@ -125,8 +116,18 @@ hi1(Str, U, Hi, I) ->
     Hi1 = bin_xor(Hi, U2),
     Hi1 = bin_xor(Hi, U2),
     hi1(Str, U2, Hi1, I - 1).
     hi1(Str, U2, Hi1, I - 1).
 
 
+-ifdef(OTP_RELEASE).
+-if(OTP_RELEASE >= 23).
+hmac(Key, Str) ->
+    crypto:mac(hmac, sha256, Key, Str).
+-else.
+hmac(Key, Str) ->
+    crypto:hmac(sha256, Key, Str).
+-endif.
+-else.
 hmac(Key, Str) ->
 hmac(Key, Str) ->
     crypto:hmac(sha256, Key, Str).
     crypto:hmac(sha256, Key, Str).
+-endif.
 
 
 h(Str) ->
 h(Str) ->
     crypto:hash(sha256, Str).
     crypto:hash(sha256, Str).
@@ -159,8 +160,4 @@ exchange_test() ->
     ?assertEqual(ClientFinal, iolist_to_binary(CF)),
     ?assertEqual(ClientFinal, iolist_to_binary(CF)),
     ?assertEqual({ok, ServerProof}, parse_server_final(ServerFinal)).
     ?assertEqual({ok, ServerProof}, parse_server_final(ServerFinal)).
 
 
-normalize_test() ->
-    ?assertEqual(<<"123 !~">>, normalize(<<"123 !~">>)),
-    ?assertError({scram_non_ascii_password, _}, normalize(<<"привет"/utf8>>)).
-
 -endif.
 -endif.

+ 24 - 15
src/epgsql_sock.erl

@@ -92,7 +92,8 @@
                 sync_required :: boolean() | undefined,
                 sync_required :: boolean() | undefined,
                 txstatus :: byte() | undefined,  % $I | $T | $E,
                 txstatus :: byte() | undefined,  % $I | $T | $E,
                 complete_status :: atom() | {atom(), integer()} | undefined,
                 complete_status :: atom() | {atom(), integer()} | undefined,
-                repl :: repl_state() | undefined}).
+                repl :: repl_state() | undefined,
+                connect_opts :: epgsql:connect_opts() | undefined}).
 
 
 -opaque pg_sock() :: #state{}.
 -opaque pg_sock() :: #state{}.
 
 
@@ -158,7 +159,9 @@ set_attr(codec, Codec, State) ->
 set_attr(sync_required, Value, State) ->
 set_attr(sync_required, Value, State) ->
     State#state{sync_required = Value};
     State#state{sync_required = Value};
 set_attr(replication_state, Value, State) ->
 set_attr(replication_state, Value, State) ->
-    State#state{repl = Value}.
+    State#state{repl = Value};
+set_attr(connect_opts, ConnectOpts, State) ->
+    State#state{connect_opts = ConnectOpts}.
 
 
 %% XXX: be careful!
 %% XXX: be careful!
 -spec set_packet_handler(atom(), pg_sock()) -> pg_sock().
 -spec set_packet_handler(atom(), pg_sock()) -> pg_sock().
@@ -225,17 +228,17 @@ handle_cast(stop, State) ->
     {stop, normal, flush_queue(State, {error, closed})};
     {stop, normal, flush_queue(State, {error, closed})};
 
 
 handle_cast(cancel, State = #state{backend = {Pid, Key},
 handle_cast(cancel, State = #state{backend = {Pid, Key},
-                                   sock = TimedOutSock}) ->
-    {ok, {Addr, Port}} = case State#state.mod of
-                             gen_tcp -> inet:peername(TimedOutSock);
-                             ssl -> ssl:peername(TimedOutSock)
-                         end,
+                                   connect_opts = ConnectOpts,
+                                   mod = Mode}) ->
     SockOpts = [{active, false}, {packet, raw}, binary],
     SockOpts = [{active, false}, {packet, raw}, binary],
-    %% TODO timeout
-    {ok, Sock} = gen_tcp:connect(Addr, Port, SockOpts),
     Msg = <<16:?int32, 80877102:?int32, Pid:?int32, Key:?int32>>,
     Msg = <<16:?int32, 80877102:?int32, Pid:?int32, Key:?int32>>,
-    ok = gen_tcp:send(Sock, Msg),
-    gen_tcp:close(Sock),
+    case epgsql_cmd_connect:open_socket(SockOpts, ConnectOpts) of
+      {ok, Mode, Sock} ->
+          ok = apply(Mode, send, [Sock, Msg]),
+          apply(Mode, close, [Sock]);
+      {error, _Reason} ->
+          noop
+    end,
     {noreply, State}.
     {noreply, State}.
 
 
 handle_info({Closed, Sock}, #state{sock = Sock} = State)
 handle_info({Closed, Sock}, #state{sock = Sock} = State)
@@ -287,6 +290,12 @@ command_exec(Transport, Command, CmdState, State) ->
     case epgsql_command:execute(Command, State, CmdState) of
     case epgsql_command:execute(Command, State, CmdState) of
         {ok, State1, CmdState1} ->
         {ok, State1, CmdState1} ->
             {noreply, command_enqueue(Transport, Command, CmdState1, State1)};
             {noreply, command_enqueue(Transport, Command, CmdState1, State1)};
+        {send, PktType, PktData, State1, CmdState1} ->
+            ok = send(State1, PktType, PktData),
+            {noreply, command_enqueue(Transport, Command, CmdState1, State1)};
+        {send_multi, Packets, State1, CmdState1} when is_list(Packets) ->
+            ok = send_multi(State1, Packets),
+            {noreply, command_enqueue(Transport, Command, CmdState1, State1)};
         {stop, StopReason, Response, State1} ->
         {stop, StopReason, Response, State1} ->
             reply(Transport, Response, Response),
             reply(Transport, Response, Response),
             {stop, StopReason, State1}
             {stop, StopReason, State1}
@@ -365,15 +374,15 @@ setopts(#state{mod = Mod, sock = Sock}, Opts) ->
 send(#state{mod = Mod, sock = Sock}, Data) ->
 send(#state{mod = Mod, sock = Sock}, Data) ->
     do_send(Mod, Sock, epgsql_wire:encode_command(Data)).
     do_send(Mod, Sock, epgsql_wire:encode_command(Data)).
 
 
--spec send(pg_sock(), byte(), iodata()) -> ok | {error, any()}.
+-spec send(pg_sock(), epgsql_wire:packet_type(), iodata()) -> ok | {error, any()}.
 send(#state{mod = Mod, sock = Sock}, Type, Data) ->
 send(#state{mod = Mod, sock = Sock}, Type, Data) ->
     do_send(Mod, Sock, epgsql_wire:encode_command(Type, Data)).
     do_send(Mod, Sock, epgsql_wire:encode_command(Type, Data)).
 
 
--spec send_multi(pg_sock(), [{byte(), iodata()}]) -> ok | {error, any()}.
+-spec send_multi(pg_sock(), [{epgsql_wire:packet_type(), iodata()}]) -> ok | {error, any()}.
 send_multi(#state{mod = Mod, sock = Sock}, List) ->
 send_multi(#state{mod = Mod, sock = Sock}, List) ->
     do_send(Mod, Sock, lists:map(fun({Type, Data}) ->
     do_send(Mod, Sock, lists:map(fun({Type, Data}) ->
-        epgsql_wire:encode_command(Type, Data)
-    end, List)).
+                                    epgsql_wire:encode_command(Type, Data)
+                                 end, List)).
 
 
 do_send(gen_tcp, Sock, Bin) ->
 do_send(gen_tcp, Sock, Bin) ->
     %% Why not gen_tcp:send/2?
     %% Why not gen_tcp:send/2?

+ 102 - 2
src/epgsql_wire.erl

@@ -24,15 +24,27 @@
          format/2,
          format/2,
          encode_parameters/2,
          encode_parameters/2,
          encode_standby_status_update/3]).
          encode_standby_status_update/3]).
--export_type([row_decoder/0]).
+%% Encoders for Client -> Server packets
+-export([encode_query/1,
+         encode_parse/3,
+         encode_describe/2,
+         encode_bind/4,
+         encode_execute/2,
+         encode_close/2,
+         encode_flush/0,
+         encode_sync/0]).
+
+-export_type([row_decoder/0, packet_type/0]).
 
 
 -include("epgsql.hrl").
 -include("epgsql.hrl").
 -include("protocol.hrl").
 -include("protocol.hrl").
 
 
 -opaque row_decoder() :: {[epgsql_binary:decoder()], [epgsql:column()], epgsql_binary:codec()}.
 -opaque row_decoder() :: {[epgsql_binary:decoder()], [epgsql:column()], epgsql_binary:codec()}.
+-type packet_type() :: byte().                 % see protocol.hrl
+%% -type packet_type(Exact) :: Exact.   % TODO: uncomment when OTP-18 is dropped
 
 
 %% @doc tries to extract single postgresql packet from TCP stream
 %% @doc tries to extract single postgresql packet from TCP stream
--spec decode_message(binary()) -> {byte(), binary(), binary()} | binary().
+-spec decode_message(binary()) -> {packet_type(), binary(), binary()} | binary().
 decode_message(<<Type:8, Len:?int32, Rest/binary>> = Bin) ->
 decode_message(<<Type:8, Len:?int32, Rest/binary>> = Bin) ->
     Len2 = Len - 4,
     Len2 = Len - 4,
     case Rest of
     case Rest of
@@ -60,6 +72,10 @@ decode_strings(Bin) ->
     <<Subj:Sz/binary, 0>> = Bin,
     <<Subj:Sz/binary, 0>> = Bin,
     binary:split(Subj, <<0>>, [global]).
     binary:split(Subj, <<0>>, [global]).
 
 
+-spec encode_string(iodata()) -> iodata().
+encode_string(Val) ->
+    [Val, 0].
+
 %% @doc decode error's field
 %% @doc decode error's field
 -spec decode_fields(binary()) -> [{byte(), binary()}].
 -spec decode_fields(binary()) -> [{byte(), binary()}].
 decode_fields(Bin) ->
 decode_fields(Bin) ->
@@ -282,6 +298,7 @@ encode_command(Data) ->
     [<<(Size + 4):?int32>> | Data].
     [<<(Size + 4):?int32>> | Data].
 
 
 %% @doc Encode PG command with type and size prefix
 %% @doc Encode PG command with type and size prefix
+-spec encode_command(packet_type(), iodata()) -> iodata().
 encode_command(Type, Data) ->
 encode_command(Type, Data) ->
     Size = iolist_size(Data),
     Size = iolist_size(Data),
     [<<Type:8, (Size + 4):?int32>> | Data].
     [<<Type:8, (Size + 4):?int32>> | Data].
@@ -292,3 +309,86 @@ encode_standby_status_update(ReceivedLSN, FlushedLSN, AppliedLSN) ->
     %% microseconds since midnight on 2000-01-01
     %% microseconds since midnight on 2000-01-01
     Timestamp = ((MegaSecs * 1000000 + Secs) * 1000000 + MicroSecs) - 946684800*1000000,
     Timestamp = ((MegaSecs * 1000000 + Secs) * 1000000 + MicroSecs) - 946684800*1000000,
     <<$r:8, ReceivedLSN:?int64, FlushedLSN:?int64, AppliedLSN:?int64, Timestamp:?int64, 0:8>>.
     <<$r:8, ReceivedLSN:?int64, FlushedLSN:?int64, AppliedLSN:?int64, Timestamp:?int64, 0:8>>.
+
+%%
+%% Encoders for various PostgreSQL protocol client-side packets
+%% See https://www.postgresql.org/docs/current/protocol-message-formats.html
+%%
+
+%% @doc encodes simple 'Query' packet.
+encode_query(SQL) ->
+    {?SIMPLEQUERY, encode_string(SQL)}.
+
+%% @doc encodes 'Parse' packet.
+%%
+%% Results in `ParseComplete' response.
+%%
+%% @param ColumnEncoding see {@link encode_types/2}
+-spec encode_parse(iodata(), iodata(), iodata()) -> {packet_type(), iodata()}.
+encode_parse(Name, SQL, ColumnEncoding) ->
+    {?PARSE, [encode_string(Name), encode_string(SQL), ColumnEncoding]}.
+
+%% @doc encodes `Describe' packet.
+%%
+%% @param What might be `?PORTAL' (results in `RowDescription' response) or `?PREPARED_STATEMENT'
+%%   (results in `ParameterDescription' followed by `RowDescription' or `NoData' response)
+-spec encode_describe(byte() | statement | portal, iodata()) ->
+          {packet_type(), iodata()}.
+encode_describe(What, Name) when What =:= ?PREPARED_STATEMENT;
+                                 What =:= ?PORTAL ->
+    {?DESCRIBE, [What, encode_string(Name)]};
+encode_describe(What, Name) when is_atom(What) ->
+    encode_describe(obj_atom_to_byte(What), Name).
+
+%% @doc encodes `Bind' packet.
+%%
+%% @param BinParams see {@link encode_parameters/2}.
+%% @param BinFormats  see {@link encode_formats/1}
+-spec encode_bind(iodata(), iodata(), iodata(), iodata()) -> {packet_type(), iodata()}.
+encode_bind(PortalName, StmtName, BinParams, BinFormats) ->
+    {?BIND, [encode_string(PortalName), encode_string(StmtName), BinParams, BinFormats]}.
+
+%% @doc encodes `Execute' packet.
+%%
+%% Results in 0 or up to `MaxRows' packets of `DataRow' type followed by `CommandComplete' (when no
+%% more rows are available) or `PortalSuspend' (repeated `Execute' will return more rows)
+%%
+%% @param PortalName  might be an empty string (anonymous portal) or name of the named portal
+%% @param MaxRows  how many rows server should send (0 means all of them)
+-spec encode_execute(iodata(), non_neg_integer()) -> {packet_type(), iodata()}.
+encode_execute("", 0) ->
+    %% optimization: literal for most common case
+    {?EXECUTE, [0, <<0:?int32>>]};
+encode_execute(PortalName, MaxRows) ->
+    {?EXECUTE, [encode_string(PortalName), <<MaxRows:?int32>>]}.
+
+%% @doc encodes `Close' packet.
+%%
+%% Results in `CloseComplete' response
+%%
+%% @param What see {@link encode_describe/2}
+-spec encode_close(byte() | statement | portal, iodata()) ->
+          {packet_type(), iodata()}.
+encode_close(What, Name) when What =:= ?PREPARED_STATEMENT;
+                              What =:= ?PORTAL ->
+    {?CLOSE, [What, encode_string(Name)]};
+encode_close(What, Name) when is_atom(What) ->
+    encode_close(obj_atom_to_byte(What), Name).
+
+%% @doc encodes `Flush' packet.
+%%
+%% It doesn't cause any specific response packet, but tells PostgreSQL server to flush it's send
+%% network buffers
+-spec encode_flush() -> {packet_type(), iodata()}.
+encode_flush() ->
+    {?FLUSH, []}.
+
+%% @doc encodes `Sync' packet.
+%%
+%% Results in `ReadyForQuery' response
+-spec encode_sync() -> {packet_type(), iodata()}.
+encode_sync() ->
+    {?SYNC, []}.
+
+obj_atom_to_byte(statement) -> ?PREPARED_STATEMENT;
+obj_atom_to_byte(portal) -> ?PORTAL.

+ 82 - 3
test/epgsql_SUITE.erl

@@ -34,6 +34,7 @@ groups() ->
     Groups = [
     Groups = [
         {connect, [parrallel], [
         {connect, [parrallel], [
             connect,
             connect,
+            connect_with_application_name,
             connect_to_db,
             connect_to_db,
             connect_as,
             connect_as,
             connect_with_cleartext,
             connect_with_cleartext,
@@ -44,6 +45,8 @@ groups() ->
             connect_to_invalid_database,
             connect_to_invalid_database,
             connect_with_other_error,
             connect_with_other_error,
             connect_with_ssl,
             connect_with_ssl,
+            cancel_query_for_connection_with_ssl,
+            cancel_query_for_connection_with_gen_tcp,
             connect_with_client_cert,
             connect_with_client_cert,
             connect_with_invalid_client_cert,
             connect_with_invalid_client_cert,
             connect_to_closed_port,
             connect_to_closed_port,
@@ -170,6 +173,16 @@ end_per_group(_GroupName, _Config) ->
                  {routine, _} | _]
                  {routine, _} | _]
         }}).
         }}).
 
 
+-define(QUERY_CANCELED, {error, #error{
+        severity = error,
+        code = <<"57014">>,
+        codename = query_canceled,
+        message = <<"canceling statement due to user request">>,
+        extra = [{file, <<"postgres.c">>},
+                 {line, _},
+                 {routine, _} | _]
+        }}).
+
 %% From uuid.erl in http://gitorious.org/avtobiff/erlang-uuid
 %% From uuid.erl in http://gitorious.org/avtobiff/erlang-uuid
 uuid_to_bin_string(<<U0:32, U1:16, U2:16, U3:16, U4:48>>) ->
 uuid_to_bin_string(<<U0:32, U1:16, U2:16, U3:16, U4:48>>) ->
     iolist_to_binary(io_lib:format(
     iolist_to_binary(io_lib:format(
@@ -179,6 +192,18 @@ uuid_to_bin_string(<<U0:32, U1:16, U2:16, U3:16, U4:48>>) ->
 connect(Config) ->
 connect(Config) ->
     epgsql_ct:connect_only(Config, []).
     epgsql_ct:connect_only(Config, []).
 
 
+connect_with_application_name(Config) ->
+    Module = ?config(module, Config),
+    Fun = fun(C) ->
+              Query = "select application_name from pg_stat_activity",
+              {ok, _Columns, Rows} = Module:equery(C, Query),
+              ?assert(lists:member({<<"app_test">>}, Rows))
+          end,
+    epgsql_ct:with_connection(Config,
+                              Fun,
+                              "epgsql_test",
+                              [{application_name, "app_test"}]).
+
 connect_to_db(Connect) ->
 connect_to_db(Connect) ->
     epgsql_ct:connect_only(Connect, [{database, "epgsql_test_db1"}]).
     epgsql_ct:connect_only(Connect, [{database, "epgsql_test_db1"}]).
 
 
@@ -271,6 +296,58 @@ connect_with_ssl(Config) ->
         "epgsql_test",
         "epgsql_test",
         [{ssl, true}]).
         [{ssl, true}]).
 
 
+cancel_query_for_connection_with_ssl(Config) ->
+    Module = ?config(module, Config),
+    {Host, Port} = epgsql_ct:connection_data(Config),
+    Module = ?config(module, Config),
+    Args2 = [ {port, Port}, {database, "epgsql_test_db1"}
+            | [ {ssl, true}
+              , {timeout, 1000} ]
+            ],
+    {ok, C} = Module:connect(Host, "epgsql_test", Args2),
+    ?assertMatch({ok, _Cols, [{true}]},
+                Module:equery(C, "select ssl_is_used()")),
+    Self = self(),
+    spawn_link(fun() ->
+                   ?assertMatch(?QUERY_CANCELED, Module:equery(C, "SELECT pg_sleep(5)")),
+                   Self ! done
+               end),
+    %% this timer is needed for the test not to be flaky
+    timer:sleep(1000),
+    epgsql:cancel(C),
+    receive done ->
+        ?assert(true)
+    after 5000 ->
+        epgsql:close(C),
+        ?assert(false)
+    end,
+    epgsql_ct:flush().
+
+cancel_query_for_connection_with_gen_tcp(Config) ->
+    Module = ?config(module, Config),
+    {Host, Port} = epgsql_ct:connection_data(Config),
+    Module = ?config(module, Config),
+    Args2 = [ {port, Port}, {database, "epgsql_test_db1"}
+            | [ {timeout, 1000} ]
+            ],
+    {ok, C} = Module:connect(Host, "epgsql_test", Args2),
+    process_flag(trap_exit, true),
+    Self = self(),
+    spawn_link(fun() ->
+                   ?assertMatch(?QUERY_CANCELED, Module:equery(C, "SELECT pg_sleep(5)")),
+                   Self ! done
+               end),
+    %% this timer is needed for the test not to be flaky
+    timer:sleep(1000),
+    epgsql:cancel(C),
+    receive done ->
+        ?assert(true)
+    after 5000 ->
+        epgsql:close(C),
+        ?assert(false)
+    end,
+    epgsql_ct:flush().
+
 connect_with_client_cert(Config) ->
 connect_with_client_cert(Config) ->
     Module = ?config(module, Config),
     Module = ?config(module, Config),
     Dir = filename:join(code:lib_dir(epgsql), ?TEST_DATA_DIR),
     Dir = filename:join(code:lib_dir(epgsql), ?TEST_DATA_DIR),
@@ -466,13 +543,15 @@ batch_error(Config) ->
     Module = ?config(module, Config),
     Module = ?config(module, Config),
     epgsql_ct:with_rollback(Config, fun(C) ->
     epgsql_ct:with_rollback(Config, fun(C) ->
         {ok, S} = Module:parse(C, "insert into test_table1(id, value) values($1, $2)"),
         {ok, S} = Module:parse(C, "insert into test_table1(id, value) values($1, $2)"),
-        [{ok, 1}, {error, _}] =
+        [{ok, 1}, {error, Error}] =
             Module:execute_batch(
             Module:execute_batch(
               C,
               C,
               [{S, [3, "batch_error 3"]},
               [{S, [3, "batch_error 3"]},
                {S, [2, "batch_error 2"]}, % duplicate key error
                {S, [2, "batch_error 2"]}, % duplicate key error
-               {S, [5, "batch_error 5"]}  % won't be executed
-              ])
+               {S, [5, "batch_error 5"]},  % won't be executed
+               {S, [6, "batch_error 6"]}  % won't be executed
+              ]),
+        ?assertMatch(#error{}, Error)
     end).
     end).
 
 
 single_batch(Config) ->
 single_batch(Config) ->