Browse Source

Add more tests; cleanup the way we handle errors in COPY data. GH-137

Sergey Prokhorov 4 years ago
parent
commit
ca40d81537

+ 2 - 0
src/commands/epgsql_cmd_copy_done.erl

@@ -15,11 +15,13 @@
 
 
 %% -include("epgsql.hrl").
 %% -include("epgsql.hrl").
 -include("protocol.hrl").
 -include("protocol.hrl").
+-include("../epgsql_copy.hrl").
 
 
 init(_) ->
 init(_) ->
     [].
     [].
 
 
 execute(Sock0, St) ->
 execute(Sock0, St) ->
+    #copy{} = epgsql_sock:get_subproto_state(Sock0), % assert we are in copy-mode
     {PktType, PktData} = epgsql_wire:encode_copy_done(),
     {PktType, PktData} = epgsql_wire:encode_copy_done(),
     Sock1 = epgsql_sock:set_packet_handler(on_message, Sock0),
     Sock1 = epgsql_sock:set_packet_handler(on_message, Sock0),
     Sock = epgsql_sock:set_attr(subproto_state, undefined, Sock1),
     Sock = epgsql_sock:set_attr(subproto_state, undefined, Sock1),

+ 6 - 5
src/commands/epgsql_cmd_copy_from_stdin.erl

@@ -31,10 +31,10 @@
 -include("../epgsql_copy.hrl").
 -include("../epgsql_copy.hrl").
 
 
 -record(copy_stdin,
 -record(copy_stdin,
-        {query :: iodata()}).
+        {query :: iodata(), initiator :: pid()}).
 
 
-init(SQL) ->
-    #copy_stdin{query = SQL}.
+init({SQL, Initiator}) ->
+    #copy_stdin{query = SQL, initiator = Initiator}.
 
 
 execute(Sock, #copy_stdin{query = SQL} = St) ->
 execute(Sock, #copy_stdin{query = SQL} = St) ->
     undefined = epgsql_sock:get_subproto_state(Sock), % assert we are not in copy-mode already
     undefined = epgsql_sock:get_subproto_state(Sock), % assert we are not in copy-mode already
@@ -42,7 +42,8 @@ execute(Sock, #copy_stdin{query = SQL} = St) ->
     {send, PktType, PktData, Sock, St}.
     {send, PktType, PktData, Sock, St}.
 
 
 %% CopyBothResponseщ
 %% CopyBothResponseщ
-handle_message(?COPY_IN_RESPONSE, <<BinOrText, NumColumns:?int16, Formats/binary>>, Sock, _State) ->
+handle_message(?COPY_IN_RESPONSE, <<BinOrText, NumColumns:?int16, Formats/binary>>, Sock,
+               #copy_stdin{initiator = Initiator}) ->
     ColumnFormats =
     ColumnFormats =
         [case Format of
         [case Format of
              0 -> text;
              0 -> text;
@@ -59,7 +60,7 @@ handle_message(?COPY_IN_RESPONSE, <<BinOrText, NumColumns:?int16, Formats/binary
         _ ->
         _ ->
             ok
             ok
     end,
     end,
-    CopyState = #copy{},
+    CopyState = #copy{initiator = Initiator},
     Sock1 = epgsql_sock:set_attr(subproto_state, CopyState, Sock),
     Sock1 = epgsql_sock:set_attr(subproto_state, CopyState, Sock),
     Res = {ok, ColumnFormats},
     Res = {ok, ColumnFormats},
     {finish, Res, Res, epgsql_sock:set_packet_handler(on_copy_from_stdin, Sock1)};
     {finish, Res, Res, epgsql_sock:set_packet_handler(on_copy_from_stdin, Sock1)};

+ 8 - 1
src/epgsql.erl

@@ -454,11 +454,18 @@ sync_on_error(_C, R) ->
 %%
 %%
 %% Erlang IO-protocol can be used to transfer "raw" COPY data to the server (see, eg,
 %% Erlang IO-protocol can be used to transfer "raw" COPY data to the server (see, eg,
 %% `io:put_chars/2' and `file:write/2' etc).
 %% `io:put_chars/2' and `file:write/2' etc).
+%%
+%% In case COPY-payload is invalid, asynchronous message of the form
+%% `{epgsql, connection(), {error, epgsql:query_error()}}' (similar to asynchronous notification,
+%% see {@link set_notice_receiver/2}) will be sent to the process that called `copy_from_stdin'
+%% and all the subsequent IO-protocol requests will return error.
+%% It's important to not call `copy_done' if such error is detected!
+%%
 %% @param SQL have to be `COPY ... FROM STDIN ...' statement
 %% @param SQL have to be `COPY ... FROM STDIN ...' statement
 -spec copy_from_stdin(connection(), sql_query()) ->
 -spec copy_from_stdin(connection(), sql_query()) ->
           epgsql_cmd_copy_from_stdin:response().
           epgsql_cmd_copy_from_stdin:response().
 copy_from_stdin(C, SQL) ->
 copy_from_stdin(C, SQL) ->
-    epgsql_sock:sync_command(C, epgsql_cmd_copy_from_stdin, SQL).
+    epgsql_sock:sync_command(C, epgsql_cmd_copy_from_stdin, {SQL, self()}).
 
 
 %% @doc Tells server that the transfer of COPY data is done
 %% @doc Tells server that the transfer of COPY data is done
 %%
 %%

+ 7 - 1
src/epgsql_copy.hrl

@@ -1 +1,7 @@
--record(copy, {}).
+-record(copy,
+        {
+         %% pid of the process that started the COPY. It is used to receive asynchronous error
+         %% messages when some error in data stream was detected
+         initiator :: pid(),
+         last_error :: undefined | epgsql:query_error()
+        }).

+ 18 - 7
src/epgsql_sock.erl

@@ -263,7 +263,7 @@ handle_info({inet_reply, _, ok}, State) ->
 handle_info({inet_reply, _, Status}, State) ->
 handle_info({inet_reply, _, Status}, State) ->
     {stop, Status, flush_queue(State, {error, Status})};
     {stop, Status, flush_queue(State, {error, Status})};
 
 
-handle_info({io_request, From, ReplyAs, Request}, #state{handler = on_copy_from_stdin} = State) ->
+handle_info({io_request, From, ReplyAs, Request}, State) ->
     Response = handle_io_request(Request, State),
     Response = handle_io_request(Request, State),
     io_reply(Response, From, ReplyAs),
     io_reply(Response, From, ReplyAs),
     {noreply, State}.
     {noreply, State}.
@@ -501,6 +501,12 @@ flush_queue(State, Error) ->
 %%
 %%
 %% COPY FROM STDIN is implemented as Erlang
 %% COPY FROM STDIN is implemented as Erlang
 %% <a href="https://erlang.org/doc/apps/stdlib/io_protocol.html">io protocol</a>.
 %% <a href="https://erlang.org/doc/apps/stdlib/io_protocol.html">io protocol</a>.
+handle_io_request(_, #state{handler = Handler}) when Handler =/= on_copy_from_stdin ->
+    %% Received IO request when `epgsql_cmd_copy_from_stdin' haven't yet been called or it was
+    %% terminated with error and already sent `ReadyForQuery'
+    {error, not_in_copy_mode};
+handle_io_request(_, #state{subproto_state = #copy{last_error = Err}}) when Err =/= undefined ->
+    {error, Err};
 handle_io_request({put_chars, Encoding, Chars}, State) ->
 handle_io_request({put_chars, Encoding, Chars}, State) ->
     send(State, ?COPY_DATA, encode_chars(Encoding, Chars));
     send(State, ?COPY_DATA, encode_chars(Encoding, Chars));
 handle_io_request({put_chars, Encoding, Mod, Fun, Args}, State) ->
 handle_io_request({put_chars, Encoding, Mod, Fun, Args}, State) ->
@@ -604,13 +610,18 @@ on_message(Msg, Payload, State) ->
 
 
 %% @doc Handle "copy subprotocol" for COPY .. FROM STDIN
 %% @doc Handle "copy subprotocol" for COPY .. FROM STDIN
 %%
 %%
-%% Activated by `epgsql_cmd_copy_from_stdin' and deactivated by `epgsql_cmd_copy_done'
-on_copy_from_stdin(?COMMAND_COMPLETE, Bin, State) ->
-    _Complete = epgsql_wire:decode_complete(Bin),
-    {noreply, State#state{subproto_state = undefined, handler = on_message}};
-on_copy_from_stdin(?ERROR, Err, State) ->
+%% Activated by `epgsql_cmd_copy_from_stdin', deactivated by `epgsql_cmd_copy_done' or error
+on_copy_from_stdin(?READY_FOR_QUERY, <<Status:8>>,
+                   #state{subproto_state = #copy{last_error = Err,
+                                                 initiator = Pid}} = State) when Err =/= undefined ->
+    %% Reporting error from here and not from ?ERROR so it's easier to be in sync state
+    Pid ! {epgsql, self(), {error, Err}},
+    {noreply, State#state{subproto_state = undefined,
+                          handler = on_message,
+                          txstatus = Status}};
+on_copy_from_stdin(?ERROR, Err, #state{subproto_state = SubState} = State) ->
     Reason = epgsql_wire:decode_error(Err),
     Reason = epgsql_wire:decode_error(Err),
-    {stop, {error, Reason}, State};
+    {noreply, State#state{subproto_state = SubState#copy{last_error = Reason}}};
 on_copy_from_stdin(M, Data, Sock) when M == ?NOTICE;
 on_copy_from_stdin(M, Data, Sock) when M == ?NOTICE;
                                        M == ?NOTIFICATION;
                                        M == ?NOTIFICATION;
                                        M == ?PARAMETER_STATUS ->
                                        M == ?PARAMETER_STATUS ->

+ 223 - 10
test/epgsql_copy_SUITE.erl

@@ -8,7 +8,11 @@
     all/0,
     all/0,
     end_per_suite/1,
     end_per_suite/1,
 
 
-    from_stdin_text/1
+    from_stdin_text/1,
+    from_stdin_csv/1,
+    from_stdin_io_apis/1,
+    from_stdin_with_terminator/1,
+    from_stdin_corrupt_data/1
 ]).
 ]).
 
 
 init_per_suite(Config) ->
 init_per_suite(Config) ->
@@ -19,14 +23,14 @@ end_per_suite(_Config) ->
 
 
 all() ->
 all() ->
     [
     [
-     from_stdin_text%% ,
-     %% from_stdin_csv,
-     %% from_stdin_io_apis,
-     %% from_stdin_fragmented,
-     %% from_stdin_with_terminator,
-     %% from_stdin_corrupt_data
+     from_stdin_text,
+     from_stdin_csv,
+     from_stdin_io_apis,
+     from_stdin_with_terminator,
+     from_stdin_corrupt_data
     ].
     ].
 
 
+%% @doc Test that COPY in text format works
 from_stdin_text(Config) ->
 from_stdin_text(Config) ->
     Module = ?config(module, Config),
     Module = ?config(module, Config),
     epgsql_ct:with_connection(
     epgsql_ct:with_connection(
@@ -46,14 +50,223 @@ from_stdin_text(Config) ->
                    ok,
                    ok,
                    io:put_chars(C, "13\tline 13\n")),
                    io:put_chars(C, "13\tline 13\n")),
                 ?assertEqual(
                 ?assertEqual(
-                   {ok, 4},
+                   ok,
+                   io:put_chars(C, "14\tli")),
+                ?assertEqual(
+                   ok,
+                   io:put_chars(C, "ne 14\n")),
+                ?assertEqual(
+                   {ok, 5},
                    Module:copy_done(C)),
                    Module:copy_done(C)),
                 ?assertMatch(
                 ?assertMatch(
                    {ok, _, [{10, <<"hello world">>},
                    {ok, _, [{10, <<"hello world">>},
                             {11, null},
                             {11, null},
                             {12, <<"line 12">>},
                             {12, <<"line 12">>},
-                            {13, <<"line 13">>}]},
+                            {13, <<"line 13">>},
+                            {14, <<"line 14">>}]},
+                   Module:equery(C,
+                                 "SELECT id, value FROM test_table1"
+                                 " WHERE id IN (10, 11, 12, 13, 14) ORDER BY id"))
+        end).
+
+%% @doc Test that COPY in CSV format works
+from_stdin_csv(Config) ->
+    Module = ?config(module, Config),
+    epgsql_ct:with_connection(
+        Config,
+        fun(C) ->
+                ?assertEqual(
+                   {ok, [text, text]},
+                   Module:copy_from_stdin(
+                     C, "COPY test_table1 (id, value) FROM STDIN WITH (FORMAT csv, QUOTE '''')")),
+                ?assertEqual(
+                   ok,
+                   io:put_chars(C,
+                                "20,'hello world'\n"
+                                "21,\n"
+                                "22,line 22\n")),
+                ?assertEqual(
+                   ok,
+                   io:put_chars(C, "23,'line 23'\n")),
+                ?assertEqual(
+                   ok,
+                   io:put_chars(C, "24,'li")),
+                ?assertEqual(
+                   ok,
+                   io:put_chars(C, "ne 24'\n")),
+                ?assertEqual(
+                   {ok, 5},
+                   Module:copy_done(C)),
+                ?assertMatch(
+                   {ok, _, [{20, <<"hello world">>},
+                            {21, null},
+                            {22, <<"line 22">>},
+                            {23, <<"line 23">>},
+                            {24, <<"line 24">>}]},
+                   Module:equery(C,
+                                 "SELECT id, value FROM test_table1"
+                                 " WHERE id IN (20, 21, 22, 23, 24) ORDER BY id"))
+        end).
+
+%% @doc Tests that different IO-protocol APIs work
+from_stdin_io_apis(Config) ->
+    Module = ?config(module, Config),
+    epgsql_ct:with_connection(
+        Config,
+        fun(C) ->
+                ?assertEqual(
+                   {ok, [text, text]},
+                   Module:copy_from_stdin(
+                     C, "COPY test_table1 (id, value) FROM STDIN WITH (FORMAT text)")),
+                ?assertEqual(ok, io:format(C, "30\thello world\n", [])),
+                ?assertEqual(ok, io:format(C, "~b\t~s\n", [31, "line 31"])),
+                %% Output "32\thello\n" in multiple calls
+                ?assertEqual(ok, io:write(C, 32)),
+                ?assertEqual(ok, io:put_chars(C, "\t")),
+                ?assertEqual(ok, io:write(C, hello)),
+                ?assertEqual(ok, io:nl(C)),
+                %% Using `file` API
+                ?assertEqual(ok, file:write(C, "33\tline 33\n34\tline 34\n")),
+                %% Binary
+                ?assertEqual(ok, io:put_chars(C, <<"35\tline 35\n">>)),
+                ?assertEqual(ok, file:write(C, <<"36\tline 36\n">>)),
+                %% IoData
+                ?assertEqual(ok, io:put_chars(C, [<<"37">>, $\t, <<"line 37">>, <<$\n>>])),
+                ?assertEqual(ok, file:write(C, [["38", <<$\t>>], [<<"line 38">>, $\n]])),
+                %% Raw IO-protocol message-passing
+                C ! {io_request, self(), ?FUNCTION_NAME, {put_chars, unicode, "39\tline 39\n"}},
+                ?assertEqual(ok, receive {io_reply, ?FUNCTION_NAME, Resp} -> Resp
+                                 after 5000 ->
+                                         timeout
+                                 end),
+                %% Not documented!
+                ?assertEqual(ok, io:requests(
+                                   C,
+                                   [{put_chars, unicode, "40\tline 40\n"},
+                                    {put_chars, latin1, "41\tline 41\n"},
+                                    {format, "~w\t~s", [42, "line 42"]},
+                                    nl])),
+                ?assertEqual(
+                   {ok, 13},
+                   Module:copy_done(C)),
+                ?assertMatch(
+                   {ok, _, [{30, <<"hello world">>},
+                            {31, <<"line 31">>},
+                            {32, <<"hello">>},
+                            {33, <<"line 33">>},
+                            {34, <<"line 34">>},
+                            {35, <<"line 35">>},
+                            {36, <<"line 36">>},
+                            {37, <<"line 37">>},
+                            {38, <<"line 38">>},
+                            {39, <<"line 39">>},
+                            {40, <<"line 40">>},
+                            {41, <<"line 41">>},
+                            {42, <<"line 42">>}
+                            ]},
+                   Module:equery(
+                     C,
+                     "SELECT id, value FROM test_table1"
+                     " WHERE id IN (30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42)"
+                     " ORDER BY id"))
+        end).
+
+%% @doc Tests that "end-of-data" terminator is successfully ignored
+from_stdin_with_terminator(Config) ->
+    Module = ?config(module, Config),
+    epgsql_ct:with_connection(
+        Config,
+        fun(C) ->
+                %% TEXT
+                ?assertEqual(
+                   {ok, [text, text]},
+                   Module:copy_from_stdin(
+                     C, "COPY test_table1 (id, value) FROM STDIN WITH (FORMAT text)")),
+                ?assertEqual(ok, io:put_chars(
+                                   C,
+                                   "50\tline 50\n"
+                                   "51\tline 51\n"
+                                   "\\.\n")),
+                ?assertEqual({ok, 2}, Module:copy_done(C)),
+                %% CSV
+                ?assertEqual(
+                   {ok, [text, text]},
+                   Module:copy_from_stdin(
+                     C, "COPY test_table1 (id, value) FROM STDIN WITH (FORMAT csv)")),
+                ?assertEqual(ok, io:put_chars(
+                                   C,
+                                   "52,line 52\n"
+                                   "53,line 53\n"
+                                   "\\.\n")),
+                ?assertEqual({ok, 2}, Module:copy_done(C)),
+                ?assertMatch(
+                   {ok, _, [{50, <<"line 50">>},
+                            {51, <<"line 51">>},
+                            {52, <<"line 52">>},
+                            {53, <<"line 53">>}
+                            ]},
                    Module:equery(C,
                    Module:equery(C,
                                  "SELECT id, value FROM test_table1"
                                  "SELECT id, value FROM test_table1"
-                                 " WHERE id IN (10, 11, 12, 13) ORDER BY id"))
+                                 " WHERE id IN (50, 51, 52, 53) ORDER BY id"))
+        end).
+
+from_stdin_corrupt_data(Config) ->
+    Module = ?config(module, Config),
+    epgsql_ct:with_connection(
+        Config,
+        fun(C) ->
+                ?assertEqual(
+                   {ok, [text, text]},
+                   Module:copy_from_stdin(
+                     C, "COPY test_table1 (id, value) FROM STDIN WITH (FORMAT text)")),
+                %% Wrong number of arguments to io:format
+                Fmt = "~w\t~s\n",
+                ?assertMatch({error, {fun_exception, {error, badarg, _Stack}}},
+                             io:request(C, {format, Fmt, []})),
+                ?assertError(badarg, io:format(C, Fmt, [])),
+                %% Wrong return value from IO function
+                ?assertEqual({error, {fun_return_not_characters, node()}},
+                             io:request(C, {put_chars, unicode, erlang, node, []})),
+                ?assertEqual({ok, 0}, Module:copy_done(C)),
+                %% Corrupt text format
+                ?assertEqual(
+                   {ok, [text, text]},
+                   Module:copy_from_stdin(
+                     C, "COPY test_table1 (id, value) FROM STDIN WITH (FORMAT text)")),
+                ?assertEqual(ok, io:put_chars(
+                                   C,
+                                   "42\n43\nwasd\n")),
+                ?assertMatch(
+                   #error{codename = bad_copy_file_format,
+                          severity = error},
+                   receive
+                       {epgsql, C, {error, Err}} ->
+                           Err
+                   after 5000 ->
+                           timeout
+                   end),
+                ?assertEqual({error, not_in_copy_mode},
+                             io:request(C, {put_chars, unicode, "queque\n"})),
+                ?assertError(badarg, io:format(C, "~w\n~s\n", [60, "wasd"])),
+                %% Corrupt CSV format
+                ?assertEqual(
+                   {ok, [text, text]},
+                   Module:copy_from_stdin(
+                     C, "COPY test_table1 (id, value) FROM STDIN WITH (FORMAT csv)")),
+                ?assertEqual(ok, io:put_chars(
+                                   C,
+                                   "42\n43\nwasd\n")),
+                ?assertMatch(
+                   #error{codename = bad_copy_file_format,
+                          severity = error},
+                   receive
+                       {epgsql, C, {error, Err}} ->
+                           Err
+                   after 5000 ->
+                           timeout
+                   end),
+                %% Connection is still usable
+                ?assertMatch(
+                   {ok, _, [{1}]},
+                   Module:equery(C, "SELECT 1", []))
         end).
         end).