Browse Source

Implement binary format of `COPY .. FROM STDIN`. GH-137

* New API function added: `epgsql:copy_send_rows/3`
Sergey Prokhorov 4 years ago
parent
commit
70149c63ab

+ 10 - 3
src/commands/epgsql_cmd_copy_done.erl

@@ -21,11 +21,18 @@ init(_) ->
     [].
     [].
 
 
 execute(Sock0, St) ->
 execute(Sock0, St) ->
-    #copy{} = epgsql_sock:get_subproto_state(Sock0), % assert we are in copy-mode
-    {PktType, PktData} = epgsql_wire:encode_copy_done(),
+    #copy{format = Format} = epgsql_sock:get_subproto_state(Sock0), % assert we are in copy-mode
     Sock1 = epgsql_sock:set_packet_handler(on_message, Sock0),
     Sock1 = epgsql_sock:set_packet_handler(on_message, Sock0),
     Sock = epgsql_sock:set_attr(subproto_state, undefined, Sock1),
     Sock = epgsql_sock:set_attr(subproto_state, undefined, Sock1),
-    {send, PktType, PktData, Sock, St}.
+    {PktType, PktData} = epgsql_wire:encode_copy_done(),
+    case Format of
+        text ->
+            {send, PktType, PktData, Sock, St};
+        binary ->
+            Pkts = [{?COPY_DATA, epgsql_wire:encode_copy_trailer()},
+                    {PktType, PktData}],
+            {send_multi, Pkts, Sock, St}
+    end.
 
 
 handle_message(?COMMAND_COMPLETE, Bin, Sock, St) ->
 handle_message(?COMMAND_COMPLETE, Bin, Sock, St) ->
     Complete = epgsql_wire:decode_complete(Bin),
     Complete = epgsql_wire:decode_complete(Bin),

+ 52 - 20
src/commands/epgsql_cmd_copy_from_stdin.erl

@@ -3,11 +3,19 @@
 %%% See [https://www.postgresql.org/docs/current/sql-copy.html].
 %%% See [https://www.postgresql.org/docs/current/sql-copy.html].
 %%% See [https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-COPY].
 %%% See [https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-COPY].
 %%%
 %%%
-%%% The copy data can then be delivered using Erlang
+%%% When `Format' is `text', copy data should then be delivered using Erlang
 %%% <a href="https://erlang.org/doc/apps/stdlib/io_protocol.html">io protocol</a>.
 %%% <a href="https://erlang.org/doc/apps/stdlib/io_protocol.html">io protocol</a>.
 %%% See {@link file:write/2}, {@link io:put_chars/2}.
 %%% See {@link file:write/2}, {@link io:put_chars/2}.
+%%% "End-of-data" marker `\.' at the end of TEXT or CSV data stream is not needed.
+%%%
+%%% When `Format' is `{binary, [epgsql_type()]}', recommended way to deliver data is
+%%% {@link epgsql:copy_send_rows/3}. IO-protocol can be used as well, as long as you can
+%%% do proper binary encoding of data tuples (header and trailer are sent automatically),
+%%% see [https://www.postgresql.org/docs/current/sql-copy.html#id-1.9.3.55.9.4.6].
+%%% When you don't know what are the correct type names for your columns, you could try to
+%%% construct equivalent `INSERT' or `SELECT' statement and call {@link epgsql:parse/2} command.
+%%% It will return `#statement{columns = [#column{type = TypeName}]}' with correct type names.
 %%%
 %%%
-%%% "End-of-data" marker `\.' at the end of TEXT or CSV data stream is not needed,
 %%% {@link epgsql_cmd_copy_done} should be called in the end.
 %%% {@link epgsql_cmd_copy_done} should be called in the end.
 %%%
 %%%
 %%% This command should not be used with command pipelining!
 %%% This command should not be used with command pipelining!
@@ -31,36 +39,35 @@
 -include("../epgsql_copy.hrl").
 -include("../epgsql_copy.hrl").
 
 
 -record(copy_stdin,
 -record(copy_stdin,
-        {query :: iodata(), initiator :: pid()}).
+        {query :: iodata(),
+         initiator :: pid(),
+         format :: {binary, [epgsql:epgsql_type()]} | text}).
 
 
-init({SQL, Initiator}) ->
-    #copy_stdin{query = SQL, initiator = Initiator}.
+init({SQL, Initiator, Format}) ->
+    #copy_stdin{query = SQL, initiator = Initiator, format = Format}.
 
 
-execute(Sock, #copy_stdin{query = SQL} = St) ->
+execute(Sock, #copy_stdin{query = SQL, format = Format} = St) ->
     undefined = epgsql_sock:get_subproto_state(Sock), % assert we are not in copy-mode already
     undefined = epgsql_sock:get_subproto_state(Sock), % assert we are not in copy-mode already
     {PktType, PktData} = epgsql_wire:encode_query(SQL),
     {PktType, PktData} = epgsql_wire:encode_query(SQL),
-    {send, PktType, PktData, Sock, St}.
+    case Format of
+        text ->
+            {send, PktType, PktData, Sock, St};
+        {binary, _} ->
+            Header = epgsql_wire:encode_copy_header(),
+            {send_multi, [{PktType, PktData},
+                          {?COPY_DATA, Header}], Sock, St}
+    end.
 
 
-%% CopyBothResponseщ
+%% CopyBothResponses
 handle_message(?COPY_IN_RESPONSE, <<BinOrText, NumColumns:?int16, Formats/binary>>, Sock,
 handle_message(?COPY_IN_RESPONSE, <<BinOrText, NumColumns:?int16, Formats/binary>>, Sock,
-               #copy_stdin{initiator = Initiator}) ->
+               #copy_stdin{initiator = Initiator, format = RequestedFormat}) ->
     ColumnFormats =
     ColumnFormats =
         [case Format of
         [case Format of
              0 -> text;
              0 -> text;
              1 -> binary
              1 -> binary
          end || <<Format:?int16>> <= Formats],
          end || <<Format:?int16>> <= Formats],
     length(ColumnFormats) =:= NumColumns orelse error(invalid_copy_in_response),
     length(ColumnFormats) =:= NumColumns orelse error(invalid_copy_in_response),
-    case BinOrText of
-        0 ->
-            %% When BinOrText is 0, all "columns" should be 0 format as well.
-            %% See https://www.postgresql.org/docs/current/protocol-message-formats.html
-            %% CopyInResponse
-            (lists:member(binary, ColumnFormats) == false)
-                orelse error(invalid_copy_in_response);
-        _ ->
-            ok
-    end,
-    CopyState = #copy{initiator = Initiator},
+    CopyState = init_copy_state(BinOrText, RequestedFormat, ColumnFormats, Initiator),
     Sock1 = epgsql_sock:set_attr(subproto_state, CopyState, Sock),
     Sock1 = epgsql_sock:set_attr(subproto_state, CopyState, Sock),
     Res = {ok, ColumnFormats},
     Res = {ok, ColumnFormats},
     {finish, Res, Res, epgsql_sock:set_packet_handler(on_copy_from_stdin, Sock1)};
     {finish, Res, Res, epgsql_sock:set_packet_handler(on_copy_from_stdin, Sock1)};
@@ -69,3 +76,28 @@ handle_message(?ERROR, Error, _Sock, _State) ->
     {sync_required, Result};
     {sync_required, Result};
 handle_message(_, _, _, _) ->
 handle_message(_, _, _, _) ->
     unknown.
     unknown.
+
+init_copy_state(0, text, ColumnFormats, Initiator) ->
+    %% When BinOrText is 0, all "columns" should be 0 format as well.
+    %% See https://www.postgresql.org/docs/current/protocol-message-formats.html
+    %% CopyInResponse
+    (lists:member(binary, ColumnFormats) == false)
+        orelse error(invalid_copy_in_response),
+    #copy{initiator = Initiator, format = text};
+init_copy_state(1, {binary, ColumnTypes}, ColumnFormats, Initiator) ->
+    %% https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-COPY
+    %% "As of the present implementation, all columns in a given COPY operation will use the same
+    %% format, but the message design does not assume this."
+    (lists:member(text, ColumnFormats) == false)
+        orelse error(invalid_copy_in_response),
+    NumColumns = length(ColumnFormats),
+    %% Eg, `epgsql:copy_from_stdin(C, "COPY tab (a, b, c) WITH (FORMAT binary)", {binary, [int2, int4]})'
+    %% so number of columns in SQL is not same as number of types in `binary'
+    (NumColumns == length(ColumnTypes))
+        orelse error({column_count_mismatch, ColumnTypes, NumColumns}),
+    #copy{initiator = Initiator, format = binary, binary_types = ColumnTypes};
+init_copy_state(ServerExpectedFormat, RequestedFormat, _, _Initiator) ->
+    %% Eg, `epgsql:copy_from_stdin(C, "COPY ... WITH (FORMAT text)", {binary, ...})' or
+    %% `epgsql:copy_from_stdin(C, "COPY ... WITH (FORMAT binary)", text)' or maybe PostgreSQL
+    %% got some new format epgsql is not aware of
+    error({format_mismatch, RequestedFormat, ServerExpectedFormat}).

+ 28 - 7
src/epgsql.erl

@@ -29,6 +29,8 @@
          with_transaction/3,
          with_transaction/3,
          sync_on_error/2,
          sync_on_error/2,
          copy_from_stdin/2,
          copy_from_stdin/2,
+         copy_from_stdin/3,
+         copy_send_rows/3,
          copy_done/1,
          copy_done/1,
          standby_status_update/3,
          standby_status_update/3,
          start_replication/5,
          start_replication/5,
@@ -450,10 +452,16 @@ sync_on_error(C, Error = {error, _}) ->
 sync_on_error(_C, R) ->
 sync_on_error(_C, R) ->
     R.
     R.
 
 
+%% @equiv copy_from_stdin(C, SQL, text)
+copy_from_stdin(C, SQL) ->
+    copy_from_stdin(C, SQL, text).
+
 %% @doc Switches epgsql into COPY-mode
 %% @doc Switches epgsql into COPY-mode
 %%
 %%
-%% Erlang IO-protocol can be used to transfer "raw" COPY data to the server (see, eg,
-%% `io:put_chars/2' and `file:write/2' etc).
+%% When `Format' is `text', Erlang IO-protocol should be used to transfer "raw" COPY data to the
+%% server (see, eg, `io:put_chars/2' and `file:write/2' etc).
+%%
+%% When `Format' is `{binary, Types}', {@link copy_send_rows/3} should be used instead.
 %%
 %%
 %% In case COPY-payload is invalid, asynchronous message of the form
 %% In case COPY-payload is invalid, asynchronous message of the form
 %% `{epgsql, connection(), {error, epgsql:query_error()}}' (similar to asynchronous notification,
 %% `{epgsql, connection(), {error, epgsql:query_error()}}' (similar to asynchronous notification,
@@ -462,14 +470,27 @@ sync_on_error(_C, R) ->
 %% It's important to not call `copy_done' if such error is detected!
 %% It's important to not call `copy_done' if such error is detected!
 %%
 %%
 %% @param SQL have to be `COPY ... FROM STDIN ...' statement
 %% @param SQL have to be `COPY ... FROM STDIN ...' statement
--spec copy_from_stdin(connection(), sql_query()) ->
+%% @param Format data transfer format specification: `text' or `{binary, epgsql_type()}'. Have to
+%%        match `WHERE (FORMAT ???)' from SQL (`text' for `text'/`csv' OR `{binary, ..}' for `binary').
+-spec copy_from_stdin(connection(), sql_query(), text | {binary, [epgsql_type()]}) ->
           epgsql_cmd_copy_from_stdin:response().
           epgsql_cmd_copy_from_stdin:response().
-copy_from_stdin(C, SQL) ->
-    epgsql_sock:sync_command(C, epgsql_cmd_copy_from_stdin, {SQL, self()}).
+copy_from_stdin(C, SQL, Format) ->
+    epgsql_sock:sync_command(C, epgsql_cmd_copy_from_stdin, {SQL, self(), Format}).
+
+%% @doc Send a batch of rows to `COPY .. FROM STDIN WITH (FORMAT binary)' in Erlang format
+%%
+%% Erlang values will be converted to postgres types same way as parameters of, eg, {@link equery/3}
+%% using data type specification from 3rd argument of {@link copy_from_stdin/3} (number of columns in
+%% each element of `Rows' should match the number of elements in `{binary, Types}').
+%% @param Rows might be a list of tuples or list of lists. List of lists is slightly more efficient.
+-spec copy_send_rows(connection(), [tuple() | [bind_param()]], timeout()) -> ok | {error, ErrReason} when
+      ErrReason :: not_in_copy_mode | not_binary_format | query_error().
+copy_send_rows(C, Rows, Timeout) ->
+    epgsql_sock:copy_send_rows(C, Rows, Timeout).
 
 
 %% @doc Tells server that the transfer of COPY data is done
 %% @doc Tells server that the transfer of COPY data is done
 %%
 %%
-%% Stops copy-mode and returns number of inserted rows
+%% Stops copy-mode and returns the number of inserted rows.
 -spec copy_done(connection()) -> epgsql_cmd_copy_done:response().
 -spec copy_done(connection()) -> epgsql_cmd_copy_done:response().
 copy_done(C) ->
 copy_done(C) ->
     epgsql_sock:sync_command(C, epgsql_cmd_copy_done, []).
     epgsql_sock:sync_command(C, epgsql_cmd_copy_done, []).
@@ -478,7 +499,7 @@ copy_done(C) ->
 %% @doc sends last flushed and applied WAL positions to the server in a standby status update message via
 %% @doc sends last flushed and applied WAL positions to the server in a standby status update message via
 %% given `Connection'
 %% given `Connection'
 standby_status_update(Connection, FlushedLSN, AppliedLSN) ->
 standby_status_update(Connection, FlushedLSN, AppliedLSN) ->
-    gen_server:call(Connection, {standby_status_update, FlushedLSN, AppliedLSN}).
+    epgsql_sock:standby_status_update(Connection, FlushedLSN, AppliedLSN).
 
 
 handle_x_log_data(Mod, StartLSN, EndLSN, WALRecord, Repl) ->
 handle_x_log_data(Mod, StartLSN, EndLSN, WALRecord, Repl) ->
     Mod:handle_x_log_data(StartLSN, EndLSN, WALRecord, Repl).
     Mod:handle_x_log_data(StartLSN, EndLSN, WALRecord, Repl).

+ 3 - 1
src/epgsql_copy.hrl

@@ -3,5 +3,7 @@
          %% pid of the process that started the COPY. It is used to receive asynchronous error
          %% pid of the process that started the COPY. It is used to receive asynchronous error
          %% messages when some error in data stream was detected
          %% messages when some error in data stream was detected
          initiator :: pid(),
          initiator :: pid(),
-         last_error :: undefined | epgsql:query_error()
+         last_error :: undefined | epgsql:query_error(),
+         format :: binary | text,
+         binary_types :: [epgsql:epgsql_type()] | undefined
         }).
         }).

+ 37 - 2
src/epgsql_sock.erl

@@ -30,6 +30,10 @@
 %%% some conflicting low-level commands (such as `parse', `bind', `execute') are
 %%% some conflicting low-level commands (such as `parse', `bind', `execute') are
 %%% executed in a wrong order. In this case server and epgsql states become out of
 %%% executed in a wrong order. In this case server and epgsql states become out of
 %%% sync and {@link epgsql_cmd_sync} have to be executed in order to recover.
 %%% sync and {@link epgsql_cmd_sync} have to be executed in order to recover.
+%%%
+%%% {@link epgsql_cmd_copy_from_stdin} and {@link epgsql_cmd_start_replication} switches the
+%%% "state machine" of connection process to a special "COPY mode" subprotocol.
+%%% See [https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-COPY].
 %%% @see epgsql_cmd_connect. epgsql_cmd_connect for network connection and authentication setup
 %%% @see epgsql_cmd_connect. epgsql_cmd_connect for network connection and authentication setup
 %%% @end
 %%% @end
 %%% Copyright (C) 2009 - Will Glozer.  All rights reserved.
 %%% Copyright (C) 2009 - Will Glozer.  All rights reserved.
@@ -46,7 +50,9 @@
          get_parameter/2,
          get_parameter/2,
          set_notice_receiver/2,
          set_notice_receiver/2,
          get_cmd_status/1,
          get_cmd_status/1,
-         cancel/1]).
+         cancel/1,
+         copy_send_rows/3,
+         standby_status_update/3]).
 
 
 -export([handle_call/3, handle_cast/2, handle_info/2]).
 -export([handle_call/3, handle_cast/2, handle_info/2]).
 -export([init/1, code_change/3, terminate/2]).
 -export([init/1, code_change/3, terminate/2]).
@@ -133,6 +139,12 @@ get_cmd_status(C) ->
 cancel(S) ->
 cancel(S) ->
     gen_server:cast(S, cancel).
     gen_server:cast(S, cancel).
 
 
+copy_send_rows(C, Rows, Timeout) ->
+    gen_server:call(C, {copy_send_rows, Rows}, Timeout).
+
+standby_status_update(C, FlushedLSN, AppliedLSN) ->
+    gen_server:call(C, {standby_status_update, FlushedLSN, AppliedLSN}).
+
 
 
 %% -- command APIs --
 %% -- command APIs --
 
 
@@ -218,7 +230,12 @@ handle_call({standby_status_update, FlushedLSN, AppliedLSN}, _From,
     send(State, ?COPY_DATA, epgsql_wire:encode_standby_status_update(ReceivedLSN, FlushedLSN, AppliedLSN)),
     send(State, ?COPY_DATA, epgsql_wire:encode_standby_status_update(ReceivedLSN, FlushedLSN, AppliedLSN)),
     Repl1 = Repl#repl{last_flushed_lsn = FlushedLSN,
     Repl1 = Repl#repl{last_flushed_lsn = FlushedLSN,
                       last_applied_lsn = AppliedLSN},
                       last_applied_lsn = AppliedLSN},
-    {reply, ok, State#state{subproto_state = Repl1}}.
+    {reply, ok, State#state{subproto_state = Repl1}};
+
+handle_call({copy_send_rows, Rows}, _From,
+           #state{handler = Handler, subproto_state = CopyState} = State) ->
+    Response = handle_copy_send_rows(Rows, Handler, CopyState, State),
+    {reply, Response, State}.
 
 
 handle_cast({{Method, From, Ref} = Transport, Command, Args}, State)
 handle_cast({{Method, From, Ref} = Transport, Command, Args}, State)
   when ((Method == cast) or (Method == incremental)),
   when ((Method == cast) or (Method == incremental)),
@@ -539,6 +556,24 @@ try_requests([], _, LastRes) ->
 io_reply(Result, From, ReplyAs) ->
 io_reply(Result, From, ReplyAs) ->
     From ! {io_reply, ReplyAs, Result}.
     From ! {io_reply, ReplyAs, Result}.
 
 
+%% @doc Handler for `copy_send_rows' API
+%%
+%% Only supports binary protocol right now.
+%% But, in theory, can be used for text / csv formats as well, but we would need to add
+%% some more callbacks to `epgsql_type' behaviour (eg, `encode_text')
+handle_copy_send_rows(_Rows, Handler, _CopyState, _State) when Handler =/= on_copy_from_stdin ->
+    {error, not_in_copy_mode};
+handle_copy_send_rows(_, _, #copy{format = Format}, _) when Format =/= binary ->
+    %% copy_send_rows only supports "binary" format
+    {error, not_binary_format};
+handle_copy_send_rows(_, _, #copy{last_error = LastError}, _) when LastError =/= undefined ->
+    %% server already reported error in data stream asynchronously
+    {error, LastError};
+handle_copy_send_rows(Rows, _, #copy{binary_types = Types}, State) ->
+    Data = [epgsql_wire:encode_copy_row(Values, Types, get_codec(State))
+            || Values <- Rows],
+    ok = send(State, ?COPY_DATA, Data).
+
 encode_chars(_, Bin) when is_binary(Bin) ->
 encode_chars(_, Bin) when is_binary(Bin) ->
     Bin;
     Bin;
 encode_chars(unicode, Chars) when is_list(Chars) ->
 encode_chars(unicode, Chars) when is_list(Chars) ->

+ 39 - 2
src/epgsql_wire.erl

@@ -23,7 +23,10 @@
          encode_formats/1,
          encode_formats/1,
          format/2,
          format/2,
          encode_parameters/2,
          encode_parameters/2,
-         encode_standby_status_update/3]).
+         encode_standby_status_update/3,
+         encode_copy_header/0,
+         encode_copy_row/3,
+         encode_copy_trailer/0]).
 %% Encoders for Client -> Server packets
 %% Encoders for Client -> Server packets
 -export([encode_query/1,
 -export([encode_query/1,
          encode_parse/3,
          encode_parse/3,
@@ -253,7 +256,8 @@ format(#column{oid = Oid}, Codec) ->
     end.
     end.
 
 
 %% @doc encode parameters for 'Bind'
 %% @doc encode parameters for 'Bind'
--spec encode_parameters([], epgsql_binary:codec()) -> iolist().
+-spec encode_parameters([{epgsql:epgsql_type(), epgsql:bind_param()}],
+                        epgsql_binary:codec()) -> iolist().
 encode_parameters(Parameters, Codec) ->
 encode_parameters(Parameters, Codec) ->
     encode_parameters(Parameters, 0, <<>>, [], Codec).
     encode_parameters(Parameters, 0, <<>>, [], Codec).
 
 
@@ -312,6 +316,39 @@ encode_standby_status_update(ReceivedLSN, FlushedLSN, AppliedLSN) ->
     Timestamp = ((MegaSecs * 1000000 + Secs) * 1000000 + MicroSecs) - 946684800*1000000,
     Timestamp = ((MegaSecs * 1000000 + Secs) * 1000000 + MicroSecs) - 946684800*1000000,
     <<$r:8, ReceivedLSN:?int64, FlushedLSN:?int64, AppliedLSN:?int64, Timestamp:?int64, 0:8>>.
     <<$r:8, ReceivedLSN:?int64, FlushedLSN:?int64, AppliedLSN:?int64, Timestamp:?int64, 0:8>>.
 
 
+%% @doc encode binary copy data file header
+%%
+%% See [https://www.postgresql.org/docs/current/sql-copy.html#id-1.9.3.55.9.4.5]
+encode_copy_header() ->
+    <<
+      "PGCOPY\n", 8#377, "\r\n", 0,             % "signature"
+      0:?int32,                                 % flags
+      0:?int32                                  % length of the extensions area
+    >>.
+
+%% @doc encode binary copy data file row / tuple
+%%
+%% See [https://www.postgresql.org/docs/current/sql-copy.html#id-1.9.3.55.9.4.6]
+encode_copy_row(ValuesTuple, Types, Codec) when is_tuple(ValuesTuple) ->
+    encode_copy_row(tuple_to_list(ValuesTuple), Types, Codec);
+encode_copy_row(Values, Types, Codec) ->
+    NumCols = length(Types),
+    [<<NumCols:?int16>>
+    | [
+       case epgsql_binary:is_null(Value, Codec) of
+           true ->
+               <<-1:?int32>>;
+           false ->
+               epgsql_binary:encode(Type, Value, Codec)
+       end || {Type, Value} <- lists:zip(Types, Values) % TODO: parallel iteration ninstead
+      ]].
+
+%% @doc encode binary copy data file header
+%%
+%% See [https://www.postgresql.org/docs/current/sql-copy.html#id-1.9.3.55.9.4.7]
+encode_copy_trailer() ->
+    <<-1:?int16>>.
+
 %%
 %%
 %% Encoders for various PostgreSQL protocol client-side packets
 %% Encoders for various PostgreSQL protocol client-side packets
 %% See https://www.postgresql.org/docs/current/protocol-message-formats.html
 %% See https://www.postgresql.org/docs/current/protocol-message-formats.html

+ 74 - 0
test/epgsql_copy_SUITE.erl

@@ -10,6 +10,7 @@
 
 
     from_stdin_text/1,
     from_stdin_text/1,
     from_stdin_csv/1,
     from_stdin_csv/1,
+    from_stdin_binary/1,
     from_stdin_io_apis/1,
     from_stdin_io_apis/1,
     from_stdin_with_terminator/1,
     from_stdin_with_terminator/1,
     from_stdin_corrupt_data/1
     from_stdin_corrupt_data/1
@@ -25,6 +26,7 @@ all() ->
     [
     [
      from_stdin_text,
      from_stdin_text,
      from_stdin_csv,
      from_stdin_csv,
+     from_stdin_binary,
      from_stdin_io_apis,
      from_stdin_io_apis,
      from_stdin_with_terminator,
      from_stdin_with_terminator,
      from_stdin_corrupt_data
      from_stdin_corrupt_data
@@ -108,6 +110,56 @@ from_stdin_csv(Config) ->
                                  " WHERE id IN (20, 21, 22, 23, 24) ORDER BY id"))
                                  " WHERE id IN (20, 21, 22, 23, 24) ORDER BY id"))
         end).
         end).
 
 
+%% @doc Test that COPY in binary format works
+from_stdin_binary(Config) ->
+    Module = ?config(module, Config),
+    epgsql_ct:with_connection(
+        Config,
+        fun(C) ->
+                ?assertEqual(
+                   {ok, [binary, binary]},
+                   Module:copy_from_stdin(
+                     C, "COPY test_table1 (id, value) FROM STDIN WITH (FORMAT binary)",
+                     {binary, [int4, text]})),
+                %% Batch of rows
+                ?assertEqual(
+                   ok,
+                   Module:copy_send_rows(
+                     C,
+                     [{60, <<"hello world">>},
+                      {61, null},
+                      {62, "line 62"}],
+                     5000)),
+                %% Single row
+                ?assertEqual(
+                   ok,
+                   Module:copy_send_rows(
+                     C,
+                     [{63, <<"line 63">>}],
+                     1000)),
+                %% Rows as lists
+                ?assertEqual(
+                   ok,
+                   Module:copy_send_rows(
+                     C,
+                     [
+                      [64, <<"line 64">>],
+                      [65, <<"line 65">>]
+                     ],
+                     infinity)),
+                ?assertEqual({ok, 6}, Module:copy_done(C)),
+                ?assertMatch(
+                   {ok, _, [{60, <<"hello world">>},
+                            {61, null},
+                            {62, <<"line 62">>},
+                            {63, <<"line 63">>},
+                            {64, <<"line 64">>},
+                            {65, <<"line 65">>}]},
+                   Module:equery(C,
+                                 "SELECT id, value FROM test_table1"
+                                 " WHERE id IN (60, 61, 62, 63, 64, 65) ORDER BY id"))
+        end).
+
 %% @doc Tests that different IO-protocol APIs work
 %% @doc Tests that different IO-protocol APIs work
 from_stdin_io_apis(Config) ->
 from_stdin_io_apis(Config) ->
     Module = ?config(module, Config),
     Module = ?config(module, Config),
@@ -228,6 +280,7 @@ from_stdin_corrupt_data(Config) ->
                 ?assertEqual({error, {fun_return_not_characters, node()}},
                 ?assertEqual({error, {fun_return_not_characters, node()}},
                              io:request(C, {put_chars, unicode, erlang, node, []})),
                              io:request(C, {put_chars, unicode, erlang, node, []})),
                 ?assertEqual({ok, 0}, Module:copy_done(C)),
                 ?assertEqual({ok, 0}, Module:copy_done(C)),
+                %%
                 %% Corrupt text format
                 %% Corrupt text format
                 ?assertEqual(
                 ?assertEqual(
                    {ok, [text, text]},
                    {ok, [text, text]},
@@ -248,6 +301,7 @@ from_stdin_corrupt_data(Config) ->
                 ?assertEqual({error, not_in_copy_mode},
                 ?assertEqual({error, not_in_copy_mode},
                              io:request(C, {put_chars, unicode, "queque\n"})),
                              io:request(C, {put_chars, unicode, "queque\n"})),
                 ?assertError(badarg, io:format(C, "~w\n~s\n", [60, "wasd"])),
                 ?assertError(badarg, io:format(C, "~w\n~s\n", [60, "wasd"])),
+                %%
                 %% Corrupt CSV format
                 %% Corrupt CSV format
                 ?assertEqual(
                 ?assertEqual(
                    {ok, [text, text]},
                    {ok, [text, text]},
@@ -265,6 +319,26 @@ from_stdin_corrupt_data(Config) ->
                    after 5000 ->
                    after 5000 ->
                            timeout
                            timeout
                    end),
                    end),
+                %%
+                %% Corrupt binary format
+                ?assertEqual(
+                   {ok, [binary, binary]},
+                   Module:copy_from_stdin(
+                     C, "COPY test_table1 (id, value) FROM STDIN WITH (FORMAT binary)",
+                     {binary, [int4, text]})),
+                ?assertEqual(
+                   ok,
+                   Module:copy_send_rows(C, [{44, <<"line 44">>}], 1000)),
+                ?assertEqual(ok, io:put_chars(C, "45\tThis is not ok!\n")),
+                ?assertMatch(
+                   #error{codename = bad_copy_file_format,
+                          severity = error},
+                   receive
+                       {epgsql, C, {error, Err}} ->
+                           Err
+                   after 5000 ->
+                           timeout
+                   end),
                 %% Connection is still usable
                 %% Connection is still usable
                 ?assertMatch(
                 ?assertMatch(
                    {ok, _, [{1}]},
                    {ok, _, [{1}]},