Browse Source

Merge pull request #243 from seriyps/command-api-improvements

Command api improvements
Sergey Prokhorov 4 years ago
parent
commit
11d4321b8c

+ 4 - 2
doc/pluggable_commands.md

@@ -49,12 +49,14 @@ passed to all subsequent callbacks. No PostgreSQL interactions should be done he
 ```erlang
 execute(pg_sock(), state()) ->
     {ok, pg_sock(), state()}
+  | {send, epgsql_wire:packet_type(), iodata(), pg_sock(), state()}
+  | {send_multi, [{epgsql_wire:packet_type(), iodata()}], pg_sock(), state()}
   | {stop, Reason :: any(), Response :: any(), pg_sock()}.
-
 ```
 
 Client -> Server packets should be sent from this callback by `epgsql_sock:send_multi/2` or
-`epgsql_sock:send/3`. `epgsql_wire` module is usually used to create wire protocol packets.
+`epgsql_sock:send/3` or by returning equivalent `send` or `send_multi` values.
+`epgsql_wire` module is usually used to create wire protocol packets.
 Please note that many packets might be sent at once. See `epgsql_cmd_equery` as an example.
 
 This callback might be executed more than once for a single command execution if your command

+ 6 - 8
src/commands/epgsql_cmd_batch.erl

@@ -58,10 +58,9 @@ execute(Sock, #batch{batch = Batch, statement = undefined} = State) ->
                   BinFormats = epgsql_wire:encode_formats(Columns),
                   add_command(StatementName, Types, Parameters, BinFormats, Codec, Acc)
           end,
-          [{?SYNC, []}],
+          [epgsql_wire:encode_sync()],
           Batch),
-    epgsql_sock:send_multi(Sock, Commands),
-    {ok, Sock, State};
+    {send_multi, Commands, Sock, State};
 execute(Sock, #batch{batch = Batch,
                      statement = #statement{name = StatementName,
                                             columns = Columns,
@@ -74,16 +73,15 @@ execute(Sock, #batch{batch = Batch,
           fun(Parameters, Acc) ->
                   add_command(StatementName, Types, Parameters, BinFormats, Codec, Acc)
           end,
-          [{?SYNC, []}],
+          [epgsql_wire:encode_sync()],
           Batch),
-    epgsql_sock:send_multi(Sock, Commands),
-    {ok, Sock, State}.
+    {send_multi, Commands, Sock, State}.
 
 add_command(StmtName, Types, Params, BinFormats, Codec, Acc) ->
     TypedParameters = lists:zip(Types, Params),
     BinParams = epgsql_wire:encode_parameters(TypedParameters, Codec),
-    [{?BIND, [0, StmtName, 0, BinParams, BinFormats]},
-     {?EXECUTE, [0, <<0:?int32>>]} | Acc].
+    [epgsql_wire:encode_bind("", StmtName, BinParams, BinFormats),
+     epgsql_wire:encode_execute("", 0) | Acc].
 
 handle_message(?BIND_COMPLETE, <<>>, Sock, State) ->
     Columns = current_cols(State),

+ 6 - 8
src/commands/epgsql_cmd_bind.erl

@@ -1,4 +1,4 @@
-%% @doc Binds placeholder parameters to prepared statement
+%% @doc Binds placeholder parameters to prepared statement, creating a "portal"
 %%
 %% ```
 %% > Bind
@@ -30,13 +30,11 @@ execute(Sock, #bind{stmt = Stmt, portal = PortalName, params = Params} = St) ->
     TypedParams = lists:zip(Types, Params),
     Bin1 = epgsql_wire:encode_parameters(TypedParams, Codec),
     Bin2 = epgsql_wire:encode_formats(Columns),
-    epgsql_sock:send_multi(
-      Sock,
-      [
-       {?BIND, [PortalName, 0, StatementName, 0, Bin1, Bin2]},
-       {?FLUSH, []}
-      ]),
-    {ok, Sock, St}.
+    Commands = [
+       epgsql_wire:encode_bind(PortalName, StatementName, Bin1, Bin2),
+       epgsql_wire:encode_flush()
+      ],
+    {send_multi, Commands, Sock, St}.
 
 handle_message(?BIND_COMPLETE, <<>>, Sock, _State) ->
     {finish, ok, ok, Sock};

+ 5 - 11
src/commands/epgsql_cmd_close.erl

@@ -22,17 +22,11 @@ init({Type, Name}) ->
     #close{type = Type, name = Name}.
 
 execute(Sock, #close{type = Type, name = Name} = St) ->
-    Type2 = case Type of
-        statement -> ?PREPARED_STATEMENT;
-        portal    -> ?PORTAL
-    end,
-    epgsql_sock:send_multi(
-      Sock,
-      [
-       {?CLOSE, [Type2, Name, 0]},
-       {?FLUSH, []}
-      ]),
-    {ok, Sock, St}.
+    Packets = [
+       epgsql_wire:encode_close(Type, Name),
+       epgsql_wire:encode_flush()
+      ],
+    {send_multi, Packets, Sock, St}.
 
 handle_message(?CLOSE_COMPLETE, <<>>, Sock, _St) ->
     {finish, ok, ok, Sock};

+ 2 - 3
src/commands/epgsql_cmd_connect.erl

@@ -82,9 +82,8 @@ execute(PgSock, #connect{opts = #{username := Username} = Opts, stage = connect}
         {error, Reason} = Error ->
             {stop, Reason, Error, PgSock}
     end;
-execute(PgSock, #connect{stage = auth, auth_send = {PacketId, Data}} = St) ->
-    ok = epgsql_sock:send(PgSock, PacketId, Data),
-    {ok, PgSock, St#connect{auth_send = undefined}}.
+execute(PgSock, #connect{stage = auth, auth_send = {PacketType, Data}} = St) ->
+    {send, PacketType, Data, PgSock, St#connect{auth_send = undefined}}.
 
 -spec open_socket([{atom(), any()}], epgsql:connect_opts()) ->
     {ok , gen_tcp | ssl, port() | ssl:sslsocket()} | {error, any()}.

+ 5 - 6
src/commands/epgsql_cmd_describe_portal.erl

@@ -22,13 +22,12 @@ init(Name) ->
     #desc_portal{name = Name}.
 
 execute(Sock, #desc_portal{name = Name} = St) ->
-    epgsql_sock:send_multi(
-      Sock,
+    Commands =
       [
-       {?DESCRIBE, [?PORTAL, Name, 0]},
-       {?FLUSH, []}
-      ]),
-    {ok, Sock, St}.
+       epgsql_wire:encode_describe(portal, Name),
+       epgsql_wire:encode_flush()
+      ],
+    {send_multi, Commands, Sock, St}.
 
 handle_message(?ROW_DESCRIPTION, <<Count:?int16, Bin/binary>>, Sock, _St) ->
     Codec = epgsql_sock:get_codec(Sock),

+ 5 - 6
src/commands/epgsql_cmd_describe_statement.erl

@@ -26,13 +26,12 @@ init(Name) ->
     #desc_stmt{name = Name}.
 
 execute(Sock, #desc_stmt{name = Name} = St) ->
-    epgsql_sock:send_multi(
-      Sock,
+    Commands =
       [
-       {?DESCRIBE, [?PREPARED_STATEMENT, Name, 0]},
-       {?FLUSH, []}
-      ]),
-    {ok, Sock, St}.
+       epgsql_wire:encode_describe(statement, Name),
+       epgsql_wire:encode_flush()
+      ],
+    {send_multi, Commands, Sock, St}.
 
 handle_message(?PARAMETER_DESCRIPTION, Bin, Sock, State) ->
     Codec = epgsql_sock:get_codec(Sock),

+ 7 - 8
src/commands/epgsql_cmd_equery.erl

@@ -43,15 +43,14 @@ execute(Sock, #equery{stmt = Stmt, params = TypedParams} = St) ->
     Codec = epgsql_sock:get_codec(Sock),
     Bin1 = epgsql_wire:encode_parameters(TypedParams, Codec),
     Bin2 = epgsql_wire:encode_formats(Columns),
-    epgsql_sock:send_multi(
-      Sock,
+    Commands =
       [
-       {?BIND, ["", 0, StatementName, 0, Bin1, Bin2]},
-       {?EXECUTE, ["", 0, <<0:?int32>>]},
-       {?CLOSE, [?PREPARED_STATEMENT, StatementName, 0]},
-       {?SYNC, []}
-      ]),
-    {ok, Sock, St}.
+       epgsql_wire:encode_bind("", StatementName, Bin1, Bin2),
+       epgsql_wire:encode_execute("", 0),
+       epgsql_wire:encode_close(statement, StatementName),
+       epgsql_wire:encode_sync()
+      ],
+    {send_multi, Commands, Sock, St}.
 
 handle_message(?BIND_COMPLETE, <<>>, Sock, #equery{stmt = Stmt} = State) ->
     #statement{columns = Columns} = Stmt,

+ 6 - 7
src/commands/epgsql_cmd_execute.erl

@@ -32,16 +32,15 @@ init({Stmt, PortalName, MaxRows}) ->
     #execute{stmt = Stmt, portal_name = PortalName, max_rows = MaxRows}.
 
 execute(Sock, #execute{stmt = Stmt, portal_name = PortalName, max_rows = MaxRows} = State) ->
-    epgsql_sock:send_multi(
-      Sock,
-      [
-       {?EXECUTE, [PortalName, 0, <<MaxRows:?int32>>]},
-       {?FLUSH, []}
-      ]),
     #statement{columns = Columns} = Stmt,
     Codec = epgsql_sock:get_codec(Sock),
     Decoder = epgsql_wire:build_decoder(Columns, Codec),
-    {ok, Sock, State#execute{decoder = Decoder}}.
+    Commands =
+      [
+       epgsql_wire:encode_execute(PortalName, MaxRows),
+       epgsql_wire:encode_flush()
+      ],
+    {send_multi, Commands, Sock, State#execute{decoder = Decoder}}.
 
 handle_message(?DATA_ROW, <<_Count:?int16, Bin/binary>>, Sock,
                #execute{decoder = Decoder} = St) ->

+ 11 - 7
src/commands/epgsql_cmd_parse.erl

@@ -1,5 +1,10 @@
 %% @doc Asks server to parse SQL query and send information aboud bind-parameters and result columns.
 %%
+%% Empty `Name' creates a "disposable" anonymous prepared statement.
+%% Non-empty `Name' creates a named prepared statement (name is not shared between connections),
+%% which should be explicitly closed when no logner needed (but will be terminated automatically
+%% when connection is closed).
+%% Non-empty name can't be rebound to another query; it should be closed for being available again.
 %% ```
 %% > Parse
 %% < ParseComplete
@@ -31,14 +36,13 @@ init({Name, Sql, Types}) ->
 execute(Sock, #parse{name = Name, sql = Sql, types = Types} = St) ->
     Codec = epgsql_sock:get_codec(Sock),
     Bin = epgsql_wire:encode_types(Types, Codec),
-    epgsql_sock:send_multi(
-      Sock,
+    Commands =
       [
-       {?PARSE, [Name, 0, Sql, 0, Bin]},
-       {?DESCRIBE, [?PREPARED_STATEMENT, Name, 0]},
-       {?FLUSH, []}
-      ]),
-    {ok, Sock, St}.
+       epgsql_wire:encode_parse(Name, Sql, Bin),
+       epgsql_wire:encode_describe(statement, Name),
+       epgsql_wire:encode_flush()
+      ],
+    {send_multi, Commands, Sock, St}.
 
 handle_message(?PARSE_COMPLETE, <<>>, Sock, _State) ->
     {noaction, Sock};

+ 6 - 7
src/commands/epgsql_cmd_prepared_query.erl

@@ -37,14 +37,13 @@ execute(Sock, #pquery{stmt = Stmt, params = TypedParams} = St) ->
     Codec = epgsql_sock:get_codec(Sock),
     Bin1 = epgsql_wire:encode_parameters(TypedParams, Codec),
     Bin2 = epgsql_wire:encode_formats(Columns),
-    epgsql_sock:send_multi(
-      Sock,
+    Commands =
       [
-       {?BIND, ["", 0, StatementName, 0, Bin1, Bin2]},
-       {?EXECUTE, ["", 0, <<0:?int32>>]},
-       {?SYNC, []}
-      ]),
-    {ok, Sock, St}.
+       epgsql_wire:encode_bind("", StatementName, Bin1, Bin2),
+       epgsql_wire:encode_execute("", 0),
+       epgsql_wire:encode_sync()
+      ],
+    {send_multi, Commands, Sock, St}.
 
 handle_message(?BIND_COMPLETE, <<>>, Sock, #pquery{stmt = Stmt} = State) ->
     #statement{columns = Columns} = Stmt,

+ 2 - 2
src/commands/epgsql_cmd_squery.erl

@@ -38,8 +38,8 @@ init(Sql) ->
     #squery{query = Sql}.
 
 execute(Sock, #squery{query = Q} = State) ->
-    epgsql_sock:send(Sock, ?SIMPLEQUERY, [Q, 0]),
-    {ok, Sock, State}.
+    {Type, Data} = epgsql_wire:encode_query(Q),
+    {send, Type, Data, Sock, State}.
 
 handle_message(?ROW_DESCRIPTION, <<Count:?int16, Bin/binary>>, Sock, State) ->
     Codec = epgsql_sock:get_codec(Sock),

+ 2 - 3
src/commands/epgsql_cmd_start_replication.erl

@@ -59,9 +59,8 @@ execute(Sock, #start_repl{slot = ReplicationSlot, callback = Callback,
                        align_lsn = AlignLsn},
     Sock2 = epgsql_sock:set_attr(replication_state, Repl3, Sock),
                          %% handler = on_replication},
-
-    epgsql_sock:send(Sock2, ?SIMPLEQUERY, [Sql2, 0]),
-    {ok, Sock2, St}.
+    {PktType, PktData} = epgsql_wire:encode_query(Sql2),
+    {send, PktType, PktData, Sock2, St}.
 
 %% CopyBothResponse
 handle_message(?COPY_BOTH_RESPONSE, _Data, Sock, _State) ->

+ 3 - 2
src/commands/epgsql_cmd_sync.erl

@@ -1,6 +1,7 @@
 %% @doc Synchronize client and server states for multi-command combinations
 %%
 %% Should be executed if APIs start to return `{error, sync_required}'.
+%% See [https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY]
 %% ```
 %% > Sync
 %% < ReadyForQuery
@@ -20,9 +21,9 @@ init(_) ->
     undefined.
 
 execute(Sock, St) ->
-    epgsql_sock:send(Sock, ?SYNC, []),
     Sock1 = epgsql_sock:set_attr(sync_required, false, Sock),
-    {ok, Sock1, St}.
+    {Type, Data} = epgsql_wire:encode_sync(),
+    {send, Type, Data, Sock1, St}.
 
 handle_message(?READY_FOR_QUERY, _, Sock, _State) ->
     {finish, ok, ok, Sock};

+ 2 - 2
src/commands/epgsql_cmd_update_type_cache.erl

@@ -22,8 +22,8 @@ execute(Sock, #upd{codecs = Codecs} = State) ->
     CodecEntries = epgsql_codec:init_mods(Codecs, Sock),
     TypeNames = [element(1, Entry) || Entry <- CodecEntries],
     Query = epgsql_oid_db:build_query(TypeNames),
-    epgsql_sock:send(Sock, ?SIMPLEQUERY, [Query, 0]),
-    {ok, Sock, State#upd{codec_entries = CodecEntries}}.
+    {PktType, PktData} = epgsql_wire:encode_query(Query),
+    {send, PktType, PktData, Sock, State#upd{codec_entries = CodecEntries}}.
 
 handle_message(?ROW_DESCRIPTION, <<Count:?int16, Bin/binary>>, Sock, State) ->
     Codec = epgsql_sock:get_codec(Sock),

+ 4 - 0
src/epgsql_command.erl

@@ -15,6 +15,10 @@
 
 -type execute_return() ::
         {ok, epgsql_sock:pg_sock(), state()}
+      | {send, epgsql_wire:packet_type(), PktData :: iodata(),
+         epgsql_sock:pg_sock(), state()}
+      | {send_multi, [{epgsql_wire:packet_type(), PktData :: iodata()}],
+         epgsql_sock:pg_sock(), state()}
       | {stop, Reason :: any(), Response :: any(), epgsql_sock:pg_sock()}.
 
 %% Execute command. It should send commands to socket.

+ 8 - 2
src/epgsql_sock.erl

@@ -290,6 +290,12 @@ command_exec(Transport, Command, CmdState, State) ->
     case epgsql_command:execute(Command, State, CmdState) of
         {ok, State1, CmdState1} ->
             {noreply, command_enqueue(Transport, Command, CmdState1, State1)};
+        {send, PktType, PktData, State1, CmdState1} ->
+            ok = send(State1, PktType, PktData),
+            {noreply, command_enqueue(Transport, Command, CmdState1, State1)};
+        {send_multi, Packets, State1, CmdState1} when is_list(Packets) ->
+            ok = send_multi(State1, Packets),
+            {noreply, command_enqueue(Transport, Command, CmdState1, State1)};
         {stop, StopReason, Response, State1} ->
             reply(Transport, Response, Response),
             {stop, StopReason, State1}
@@ -368,11 +374,11 @@ setopts(#state{mod = Mod, sock = Sock}, Opts) ->
 send(#state{mod = Mod, sock = Sock}, Data) ->
     do_send(Mod, Sock, epgsql_wire:encode_command(Data)).
 
--spec send(pg_sock(), byte(), iodata()) -> ok | {error, any()}.
+-spec send(pg_sock(), epgsql_wire:packet_type(), iodata()) -> ok | {error, any()}.
 send(#state{mod = Mod, sock = Sock}, Type, Data) ->
     do_send(Mod, Sock, epgsql_wire:encode_command(Type, Data)).
 
--spec send_multi(pg_sock(), [{byte(), iodata()}]) -> ok | {error, any()}.
+-spec send_multi(pg_sock(), [{epgsql_wire:packet_type(), iodata()}]) -> ok | {error, any()}.
 send_multi(#state{mod = Mod, sock = Sock}, List) ->
     do_send(Mod, Sock, lists:map(fun({Type, Data}) ->
                                     epgsql_wire:encode_command(Type, Data)

+ 102 - 2
src/epgsql_wire.erl

@@ -24,15 +24,27 @@
          format/2,
          encode_parameters/2,
          encode_standby_status_update/3]).
--export_type([row_decoder/0]).
+%% Encoders for Client -> Server packets
+-export([encode_query/1,
+         encode_parse/3,
+         encode_describe/2,
+         encode_bind/4,
+         encode_execute/2,
+         encode_close/2,
+         encode_flush/0,
+         encode_sync/0]).
+
+-export_type([row_decoder/0, packet_type/0]).
 
 -include("epgsql.hrl").
 -include("protocol.hrl").
 
 -opaque row_decoder() :: {[epgsql_binary:decoder()], [epgsql:column()], epgsql_binary:codec()}.
+-type packet_type() :: byte().                 % see protocol.hrl
+%% -type packet_type(Exact) :: Exact.   % TODO: uncomment when OTP-18 is dropped
 
 %% @doc tries to extract single postgresql packet from TCP stream
--spec decode_message(binary()) -> {byte(), binary(), binary()} | binary().
+-spec decode_message(binary()) -> {packet_type(), binary(), binary()} | binary().
 decode_message(<<Type:8, Len:?int32, Rest/binary>> = Bin) ->
     Len2 = Len - 4,
     case Rest of
@@ -60,6 +72,10 @@ decode_strings(Bin) ->
     <<Subj:Sz/binary, 0>> = Bin,
     binary:split(Subj, <<0>>, [global]).
 
+-spec encode_string(iodata()) -> iodata().
+encode_string(Val) ->
+    [Val, 0].
+
 %% @doc decode error's field
 -spec decode_fields(binary()) -> [{byte(), binary()}].
 decode_fields(Bin) ->
@@ -282,6 +298,7 @@ encode_command(Data) ->
     [<<(Size + 4):?int32>> | Data].
 
 %% @doc Encode PG command with type and size prefix
+-spec encode_command(packet_type(), iodata()) -> iodata().
 encode_command(Type, Data) ->
     Size = iolist_size(Data),
     [<<Type:8, (Size + 4):?int32>> | Data].
@@ -292,3 +309,86 @@ encode_standby_status_update(ReceivedLSN, FlushedLSN, AppliedLSN) ->
     %% microseconds since midnight on 2000-01-01
     Timestamp = ((MegaSecs * 1000000 + Secs) * 1000000 + MicroSecs) - 946684800*1000000,
     <<$r:8, ReceivedLSN:?int64, FlushedLSN:?int64, AppliedLSN:?int64, Timestamp:?int64, 0:8>>.
+
+%%
+%% Encoders for various PostgreSQL protocol client-side packets
+%% See https://www.postgresql.org/docs/current/protocol-message-formats.html
+%%
+
+%% @doc encodes simple 'Query' packet.
+encode_query(SQL) ->
+    {?SIMPLEQUERY, encode_string(SQL)}.
+
+%% @doc encodes 'Parse' packet.
+%%
+%% Results in `ParseComplete' response.
+%%
+%% @param ColumnEncoding see {@link encode_types/2}
+-spec encode_parse(iodata(), iodata(), iodata()) -> {packet_type(), iodata()}.
+encode_parse(Name, SQL, ColumnEncoding) ->
+    {?PARSE, [encode_string(Name), encode_string(SQL), ColumnEncoding]}.
+
+%% @doc encodes `Describe' packet.
+%%
+%% @param What might be `?PORTAL' (results in `RowDescription' response) or `?PREPARED_STATEMENT'
+%%   (results in `ParameterDescription' followed by `RowDescription' or `NoData' response)
+-spec encode_describe(byte() | statement | portal, iodata()) ->
+          {packet_type(), iodata()}.
+encode_describe(What, Name) when What =:= ?PREPARED_STATEMENT;
+                                 What =:= ?PORTAL ->
+    {?DESCRIBE, [What, encode_string(Name)]};
+encode_describe(What, Name) when is_atom(What) ->
+    encode_describe(obj_atom_to_byte(What), Name).
+
+%% @doc encodes `Bind' packet.
+%%
+%% @param BinParams see {@link encode_parameters/2}.
+%% @param BinFormats  see {@link encode_formats/1}
+-spec encode_bind(iodata(), iodata(), iodata(), iodata()) -> {packet_type(), iodata()}.
+encode_bind(PortalName, StmtName, BinParams, BinFormats) ->
+    {?BIND, [encode_string(PortalName), encode_string(StmtName), BinParams, BinFormats]}.
+
+%% @doc encodes `Execute' packet.
+%%
+%% Results in 0 or up to `MaxRows' packets of `DataRow' type followed by `CommandComplete' (when no
+%% more rows are available) or `PortalSuspend' (repeated `Execute' will return more rows)
+%%
+%% @param PortalName  might be an empty string (anonymous portal) or name of the named portal
+%% @param MaxRows  how many rows server should send (0 means all of them)
+-spec encode_execute(iodata(), non_neg_integer()) -> {packet_type(), iodata()}.
+encode_execute("", 0) ->
+    %% optimization: literal for most common case
+    {?EXECUTE, [0, <<0:?int32>>]};
+encode_execute(PortalName, MaxRows) ->
+    {?EXECUTE, [encode_string(PortalName), <<MaxRows:?int32>>]}.
+
+%% @doc encodes `Close' packet.
+%%
+%% Results in `CloseComplete' response
+%%
+%% @param What see {@link encode_describe/2}
+-spec encode_close(byte() | statement | portal, iodata()) ->
+          {packet_type(), iodata()}.
+encode_close(What, Name) when What =:= ?PREPARED_STATEMENT;
+                              What =:= ?PORTAL ->
+    {?CLOSE, [What, encode_string(Name)]};
+encode_close(What, Name) when is_atom(What) ->
+    encode_close(obj_atom_to_byte(What), Name).
+
+%% @doc encodes `Flush' packet.
+%%
+%% It doesn't cause any specific response packet, but tells PostgreSQL server to flush it's send
+%% network buffers
+-spec encode_flush() -> {packet_type(), iodata()}.
+encode_flush() ->
+    {?FLUSH, []}.
+
+%% @doc encodes `Sync' packet.
+%%
+%% Results in `ReadyForQuery' response
+-spec encode_sync() -> {packet_type(), iodata()}.
+encode_sync() ->
+    {?SYNC, []}.
+
+obj_atom_to_byte(statement) -> ?PREPARED_STATEMENT;
+obj_atom_to_byte(portal) -> ?PORTAL.