%%% Copyright (C) 2009 - Will Glozer.  All rights reserved.
%%% Copyright (C) 2011 - Anton Lebedevich.  All rights reserved.
%%%
%%% @doc GenServer holding all connection state (including socket).
%%%
%%% See https://www.postgresql.org/docs/current/static/protocol-flow.html
%%%
%%% Commands in PostgreSQL are pipelined: you don't need to wait for the reply
%%% to one command before sending the next one.
%%% Commands are processed (and responses to them are generated) in FIFO order.
%%% E.g., if you execute two SimpleQuery commands, #1 and #2, you first get all
%%% response packets for #1 and then all response packets for #2:
%%%
%%% > SQuery #1
%%% > SQuery #2
%%% < RowDescription #1
%%% < DataRow #1
%%% < CommandComplete #1
%%% < RowDescription #2
%%% < DataRow #2
%%% < CommandComplete #2
%%%
%%% See epgsql_cmd_connect for network connection and authentication setup.

-module(epgsql_sock).

-behavior(gen_server).

%% Client API.
-export([start_link/0,
         close/1,
         sync_command/3,
         async_command/4,
         get_parameter/2,
         set_notice_receiver/2,
         get_cmd_status/1,
         cancel/1]).

%% gen_server callbacks.
-export([init/1,
         handle_call/3,
         handle_cast/2,
         handle_info/2,
         terminate/2,
         code_change/3]).

%% Packet handlers; loop/1 invokes them dynamically as ?MODULE:Handler/3.
-export([on_message/3,
         on_replication/3]).

%% Command API: helpers through which command modules read and update the
%% connection state and write to the socket.
-export([set_net_socket/3,
         init_replication_state/1,
         set_attr/3,
         get_codec/1,
         get_rows/1,
         get_results/1,
         notify/2,
         send/2,
         send/3,
         send_multi/2,
         get_parameter_internal/2,
         get_replication_state/1,
         set_packet_handler/2]).

-export_type([transport/0, pg_sock/0]).

-include("epgsql.hrl").
-include("protocol.hrl").
-include("epgsql_replication.hrl").

%% How a command's reply is delivered:
%%   {call, From}            - one gen_server:reply/2 to a sync_command/3 caller;
%%   {cast, Pid, Ref}        - one `{Conn, Ref, Result}' message to Pid;
%%   {incremental, Pid, Ref} - rows and partial results streamed to Pid as
%%                             `{Conn, Ref, ...}' messages while the command runs.
-type transport() :: {call, any()}
                   | {cast, pid(), reference()}
                   | {incremental, pid(), reference()}.

%% gen_tcp:socket() isn't exported prior to erl 18
-type tcp_socket() :: port().

-type repl_state() :: #repl{}.
%% Connection state. Command modules see it only as the opaque pg_sock() type
%% and manipulate it through the command API exported by this module.
-record(state, {mod :: gen_tcp | ssl | undefined,
                sock :: tcp_socket() | ssl:sslsocket() | undefined,
                %% bytes received but not yet decoded into a complete message
                data = <<>>,
                backend :: {Pid :: integer(), Key :: integer()} | undefined,
                %% name of the exported 3-arity function loop/1 dispatches to
                handler = on_message :: on_message | on_replication | undefined,
                codec :: epgsql_binary:codec() | undefined,
                queue = queue:new() :: queue:queue({epgsql_command:command(), any(), transport()}),
                current_cmd :: epgsql_command:command() | undefined,
                current_cmd_state :: any() | undefined,
                current_cmd_transport :: transport() | undefined,
                %% receiver of asynchronous notices / notifications
                async :: undefined | atom() | pid(),
                parameters = [] :: [{Key :: binary(), Value :: binary()}],
                rows = [] :: [tuple()],
                results = [],
                sync_required :: boolean() | undefined,
                txstatus :: byte() | undefined,  % $I | $T | $E
                complete_status :: atom() | {atom(), integer()} | undefined,
                repl :: repl_state() | undefined}).

-opaque pg_sock() :: #state{}.

%% -- client interface --

%% @doc Starts a connection process, linked to the caller, with an empty
%% (not yet connected) state.
-spec start_link() -> {ok, pid()} | ignore | {error, term()}.
start_link() ->
    gen_server:start_link(?MODULE, [], []).

%% @doc Asks the connection process to stop. Pending commands are answered
%% with `{error, closed}'.
%%
%% gen_server:cast/2 never raises for a pid destination (a dead process is
%% silently ignored), so no `catch' is needed around it.
-spec close(epgsql:connection()) -> ok.
close(C) when is_pid(C) ->
    gen_server:cast(C, stop),
    ok.

%% @doc Runs Command with Args and blocks (without timeout) until its result
%% is available.
-spec sync_command(epgsql:connection(), epgsql_command:command(), any()) -> any().
sync_command(C, Command, Args) ->
    gen_server:call(C, {command, Command, Args}, infinity).

%% @doc Submits Command without waiting. Replies arrive at the calling process
%% as `{C, Ref, Message}' tuples, where Ref is the returned reference. With
%% `incremental', rows and partial results are streamed as they arrive.
-spec async_command(epgsql:connection(), cast | incremental,
                    epgsql_command:command(), any()) -> reference().
async_command(C, Transport, Command, Args) ->
    Ref = make_ref(),
    Pid = self(),
    ok = gen_server:cast(C, {{Transport, Pid, Ref}, Command, Args}),
    Ref.

%% @doc Returns the last value the server reported for run-time parameter
%% Name, or `undefined' if it never reported one.
-spec get_parameter(epgsql:connection(), binary() | string()) ->
                           {ok, binary() | undefined}.
get_parameter(C, Name) ->
    gen_server:call(C, {get_parameter, to_binary(Name)}, infinity).

%% @doc Sets the process (pid or registered name) that receives asynchronous
%% notices and notifications. Returns the previous receiver.
-spec set_notice_receiver(epgsql:connection(), pid() | atom()) ->
                                 {ok, Previous :: pid() | atom()}.
set_notice_receiver(C, PidOrName) when is_pid(PidOrName);
                                       is_atom(PidOrName) ->
    gen_server:call(C, {set_async_receiver, PidOrName}, infinity).

%% @doc Returns the status of the most recent CommandComplete message, or
%% `undefined' if none was received since the last command was enqueued.
-spec get_cmd_status(epgsql:connection()) ->
                            {ok, atom() | {atom(), integer()} | undefined}.
get_cmd_status(C) ->
    gen_server:call(C, get_cmd_status, infinity).

%% @doc Requests cancellation of the query currently running on the backend.
%% Fire-and-forget: no result is reported to the caller.
-spec cancel(epgsql:connection()) -> ok.
cancel(S) ->
    gen_server:cast(S, cancel).

%% -- command APIs --
%% Used by command modules: send/2, send/3 and send_multi/2 write to the
%% socket, the remaining functions read or update the connection state.

-spec set_net_socket(gen_tcp | ssl, tcp_socket() | ssl:sslsocket(), pg_sock()) -> pg_sock().
%% @doc Installs the connected socket and switches it to active mode, so that
%% incoming data is delivered to handle_info/2 as messages.
%% NOTE(review): the return value of setopts/2 is not checked.
set_net_socket(Mod, Socket, State) ->
    State1 = State#state{mod = Mod, sock = Socket},
    setopts(State1, [{active, true}]),
    State1.

%% @doc Resets replication bookkeeping to a fresh #repl{} record.
-spec init_replication_state(pg_sock()) -> pg_sock().
init_replication_state(State) ->
    State#state{repl = #repl{}}.

%% @doc Sets one connection-state attribute. Names other than those listed
%% below crash with function_clause.
-spec set_attr(atom(), any(), pg_sock()) -> pg_sock().
set_attr(backend, {_Pid, _Key} = Backend, State) ->
    State#state{backend = Backend};
set_attr(async, Async, State) ->
    State#state{async = Async};
set_attr(txstatus, Status, State) ->
    State#state{txstatus = Status};
set_attr(codec, Codec, State) ->
    State#state{codec = Codec};
set_attr(sync_required, Value, State) ->
    State#state{sync_required = Value};
set_attr(replication_state, Value, State) ->
    State#state{repl = Value}.

%% @doc Replaces the function that handles incoming backend messages.
%% XXX: be careful! loop/1 calls `?MODULE:Handler(Type, Payload, State)', so
%% Handler must name an exported 3-arity function of this module
%% (on_message or on_replication); anything else crashes the connection on
%% the next received message.
-spec set_packet_handler(atom(), pg_sock()) -> pg_sock().
set_packet_handler(Handler, State) ->
    State#state{handler = Handler}.

-spec get_codec(pg_sock()) -> epgsql_binary:codec().
get_codec(#state{codec = Codec}) ->
    Codec.

-spec get_replication_state(pg_sock()) -> repl_state().
get_replication_state(#state{repl = Repl}) ->
    Repl.

%% @doc Rows accumulated for the current command, in arrival order (they are
%% stored newest-first).
-spec get_rows(pg_sock()) -> [tuple()].
get_rows(#state{rows = Rows}) ->
    lists:reverse(Rows).

%% @doc Results accumulated for the current command, in arrival order.
-spec get_results(pg_sock()) -> [any()].
get_results(#state{results = Results}) ->
    lists:reverse(Results).

%% @doc Value of a server-reported run-time parameter, or `undefined'.
-spec get_parameter_internal(binary(), pg_sock()) -> binary() | undefined.
get_parameter_internal(Name, #state{parameters = Parameters}) ->
    case lists:keyfind(Name, 1, Parameters) of
        {Name, Value} -> Value;
        false -> undefined
    end.

%% -- gen_server implementation --

init([]) ->
    {ok, #state{}}.
%% Synchronous requests: parameter/status lookups, notice-receiver changes,
%% replication standby status updates and `sync_command/3' commands (which are
%% answered later via gen_server:reply/2 when the command finishes).
handle_call({get_parameter, Name}, _From, State) ->
    {reply, {ok, get_parameter_internal(Name, State)}, State};
handle_call({set_async_receiver, PidOrName}, _From, #state{async = Previous} = State) ->
    {reply, {ok, Previous}, State#state{async = PidOrName}};
handle_call(get_cmd_status, _From, #state{complete_status = Status} = State) ->
    {reply, {ok, Status}, State};
handle_call({standby_status_update, FlushedLSN, AppliedLSN}, _From,
            #state{handler = on_replication,
                   repl = #repl{last_received_lsn = ReceivedLSN} = Repl} = State) ->
    send(State, ?COPY_DATA,
         epgsql_wire:encode_standby_status_update(ReceivedLSN, FlushedLSN, AppliedLSN)),
    Repl1 = Repl#repl{last_flushed_lsn = FlushedLSN,
                      last_applied_lsn = AppliedLSN},
    {reply, ok, State#state{repl = Repl1}};
handle_call({command, Command, Args}, From, State) ->
    Transport = {call, From},
    command_new(Transport, Command, Args, State).

%% Asynchronous requests: `async_command/4' commands, `close/1' and `cancel/1'.
handle_cast({{Method, From, Ref} = Transport, Command, Args}, State)
  when (Method =:= cast orelse Method =:= incremental),
       is_pid(From),
       is_reference(Ref) ->
    command_new(Transport, Command, Args, State);
handle_cast(stop, State) ->
    {stop, normal, flush_queue(State, {error, closed})};
handle_cast(cancel, State = #state{backend = {Pid, Key},
                                   sock = TimedOutSock}) ->
    %% Cancellation goes over a separate, plain TCP connection to the same
    %% server address, carrying the backend pid and secret key.
    {ok, {Addr, Port}} = case State#state.mod of
                             gen_tcp -> inet:peername(TimedOutSock);
                             ssl -> ssl:peername(TimedOutSock)
                         end,
    SockOpts = [{active, false}, {packet, raw}, binary],
    %% TODO timeout: this connect blocks the connection process until it
    %% succeeds or fails.
    {ok, Sock} = gen_tcp:connect(Addr, Port, SockOpts),
    %% 16-byte CancelRequest: length, request code 80877102, pid, key
    Msg = <<16:?int32, 80877102:?int32, Pid:?int32, Key:?int32>>,
    ok = gen_tcp:send(Sock, Msg),
    gen_tcp:close(Sock),
    {noreply, State}.
%% Socket events. Closure, errors and failed asynchronous sends stop the
%% connection after failing every pending command; data is appended to the
%% receive buffer and decoded.
handle_info({Closed, Sock}, #state{sock = Sock} = State)
  when Closed =:= tcp_closed; Closed =:= ssl_closed ->
    {stop, sock_closed, flush_queue(State#state{sock = undefined}, {error, sock_closed})};
handle_info({Error, Sock, Reason}, #state{sock = Sock} = State)
  when Error =:= tcp_error; Error =:= ssl_error ->
    Why = {sock_error, Reason},
    {stop, Why, flush_queue(State, {error, Why})};
%% Completion notices for erlang:port_command/2 sends issued by do_send/3.
handle_info({inet_reply, _, ok}, State) ->
    {noreply, State};
handle_info({inet_reply, _, Status}, State) ->
    {stop, Status, flush_queue(State, {error, Status})};
handle_info({_, Sock, Data2}, #state{data = Data, sock = Sock} = State) ->
    %% Append the new chunk to whatever partial message is still buffered,
    %% then decode as many complete messages as possible.
    loop(State#state{data = <<Data/binary, Data2/binary>>}).

terminate(_Reason, #state{sock = undefined}) -> ok;
terminate(_Reason, #state{mod = gen_tcp, sock = Sock}) -> gen_tcp:close(Sock);
terminate(_Reason, #state{mod = ssl, sock = Sock}) -> ssl:close(Sock).

code_change(_OldVsn, State, _Extra) ->
    {ok, State}.

%% -- internal functions --

%% Initialises the command's private state and executes it.
-spec command_new(transport(), epgsql_command:command(), any(), pg_sock()) ->
                         Result when
      Result :: {noreply, pg_sock()}
              | {stop, Reason :: any(), pg_sock()}.
command_new(Transport, Command, Args, State) ->
    CmdState = epgsql_command:init(Command, Args),
    command_exec(Transport, Command, CmdState, State).

%% Executes a command (sends its packets) and enqueues it to await responses.
%% While the connection requires a Sync, every command other than
%% epgsql_cmd_sync is rejected with `{error, sync_required}'.
-spec command_exec(transport(), epgsql_command:command(), any(), pg_sock()) ->
                          Result when
      Result :: {noreply, pg_sock()}
              | {stop, Reason :: any(), pg_sock()}.
command_exec(Transport, Command, _, State = #state{sync_required = true,
                                                   current_cmd = undefined})
  when Command =/= epgsql_cmd_sync ->
    %% Nothing in flight: install the command just long enough for finish/2
    %% to reply to it and advance the queue.
    {noreply,
     finish(State#state{current_cmd = Command,
                        current_cmd_transport = Transport},
            {error, sync_required})};
command_exec(Transport, Command, _, State = #state{sync_required = true})
  when Command =/= epgsql_cmd_sync ->
    %% Another command (normally the pending epgsql_cmd_sync) is in flight.
    %% Reply directly: overwriting current_cmd here would drop that
    %% command's transport, so its caller would never get a reply.
    reply(Transport, {error, sync_required}, {error, sync_required}),
    {noreply, State};
command_exec(Transport, Command, CmdState, State) ->
    case epgsql_command:execute(Command, State, CmdState) of
        {ok, State1, CmdState1} ->
            {noreply, command_enqueue(Transport, Command, CmdState1, State1)};
        {stop, StopReason, Response, State1} ->
            reply(Transport, Response, Response),
            {stop, StopReason, State1}
    end.
%% Makes the command current when nothing is in flight, otherwise appends it
%% to the FIFO queue. Either way the last CommandComplete status is cleared.
-spec command_enqueue(transport(), epgsql_command:command(), epgsql_command:state(), pg_sock()) -> pg_sock().
command_enqueue(Transport, Command, CmdState,
                #state{current_cmd = InFlight, queue = Pending} = State) ->
    Scheduled =
        case InFlight of
            undefined ->
                State#state{current_cmd = Command,
                            current_cmd_state = CmdState,
                            current_cmd_transport = Transport};
            _ ->
                State#state{queue = queue:in({Command, CmdState, Transport}, Pending)}
        end,
    Scheduled#state{complete_status = undefined}.

%% Hands a backend message to the current command module and applies the
%% action it returns (accumulate a row/result, finish, requeue, stop, ...).
-spec command_handle_message(byte(), binary() | epgsql:query_error(), pg_sock()) ->
                                    {noreply, pg_sock()}
                                  | {stop, any(), pg_sock()}.
command_handle_message(Msg, Payload,
                       #state{current_cmd = Cmd,
                              current_cmd_state = CmdSt} = State0) ->
    Outcome = epgsql_command:handle_message(Cmd, Msg, Payload, State0, CmdSt),
    case Outcome of
        {noaction, State1} ->
            {noreply, State1};
        {noaction, State1, CmdSt1} ->
            {noreply, State1#state{current_cmd_state = CmdSt1}};
        {add_row, Row, State1, CmdSt1} ->
            {noreply, add_row(State1#state{current_cmd_state = CmdSt1}, Row)};
        {add_result, Result, Notice, State1, CmdSt1} ->
            {noreply, add_result(State1#state{current_cmd_state = CmdSt1}, Notice, Result)};
        {finish, Result, Notice, State1} ->
            {noreply, finish(State1, Notice, Result)};
        {requeue, State1, CmdSt1} ->
            %% Re-execute the same command with its new state, keeping the
            %% original caller's transport.
            command_exec(State1#state.current_cmd_transport, Cmd, CmdSt1,
                         State1#state{current_cmd = undefined});
        {stop, Reason, Response, State1} ->
            {stop, Reason, finish(State1, Response)};
        {sync_required, Why} ->
            %% Protocol error: finish this command and fail every pending one
            %% until a Sync is issued.
            {noreply, sync_required(finish(State0#state{sync_required = true}, Why))};
        unknown ->
            {stop, {error, {unexpected_message, Msg, Cmd, CmdSt}}, State0}
    end.
%% Retires the current command and promotes the next queued one, if any.
%% Row and result accumulators are cleared in both cases.
command_next(#state{current_cmd = PrevCmd, queue = Pending} = State)
  when PrevCmd =/= undefined ->
    Cleared = State#state{rows = [], results = []},
    case queue:out(Pending) of
        {{value, {Command, CmdState, Transport}}, Remaining} ->
            Cleared#state{current_cmd = Command,
                          current_cmd_state = CmdState,
                          current_cmd_transport = Transport,
                          queue = Remaining};
        {empty, _} ->
            Cleared#state{current_cmd = undefined,
                          current_cmd_state = undefined,
                          current_cmd_transport = undefined}
    end.

%% Applies socket options through the module matching the transport.
setopts(#state{mod = Mod, sock = Sock}, Opts) ->
    case Mod of
        ssl -> ssl:setopts(Sock, Opts);
        gen_tcp -> inet:setopts(Sock, Opts)
    end.

%% Sends an untyped packet. Only used during connection setup, for the
%% client's `StartupMessage' and `SSLRequest' packets.
-spec send(pg_sock(), iodata()) -> ok | {error, any()}.
send(#state{mod = Mod, sock = Sock}, Data) ->
    Packet = epgsql_wire:encode_command(Data),
    do_send(Mod, Sock, Packet).

%% Sends one packet of the given message Type.
-spec send(pg_sock(), byte(), iodata()) -> ok | {error, any()}.
send(#state{mod = Mod, sock = Sock}, Type, Data) ->
    Packet = epgsql_wire:encode_command(Type, Data),
    do_send(Mod, Sock, Packet).

%% Encodes several `{Type, Data}' packets and writes them in a single send.
-spec send_multi(pg_sock(), [{byte(), iodata()}]) -> ok | {error, any()}.
send_multi(#state{mod = Mod, sock = Sock}, List) ->
    EncodeOne = fun({Type, Data}) -> epgsql_wire:encode_command(Type, Data) end,
    do_send(Mod, Sock, lists:map(EncodeOne, List)).

do_send(ssl, Sock, Bin) ->
    ssl:send(Sock, Bin);
do_send(gen_tcp, Sock, Bin) ->
    %% Why not gen_tcp:send/2?
    %% See https://github.com/rabbitmq/rabbitmq-common/blob/v3.7.4/src/rabbit_writer.erl#L367-L384
    %% erlang:port_command/2 does not wait for the write to complete; its
    %% outcome arrives later as `{inet_reply, ...}', handled in handle_info/2.
    try erlang:port_command(Sock, Bin) of
        true -> ok
    catch
        error:_Error -> {error, einval}
    end.
%% Decodes every complete backend message buffered in `data' and dispatches
%% it to the current packet handler (on_message/3 or on_replication/3).
%% Any other return from epgsql_wire:decode_message/1 (presumably an
%% incomplete trailing message) ends the batch and leaves `data' buffered.
loop(#state{data = Data, handler = Handler, repl = Repl} = State) ->
    case epgsql_wire:decode_message(Data) of
        {Type, Payload, Tail} ->
            case ?MODULE:Handler(Type, Payload, State#state{data = Tail}) of
                {noreply, State2} ->
                    loop(State2);
                R = {stop, _Reason2, _State2} ->
                    R
            end;
        _ ->
            %% in replication mode send feedback after each batch of messages
            case (Repl =/= undefined) andalso (Repl#repl.feedback_required) of
                true ->
                    #repl{last_received_lsn = LastReceivedLSN,
                          last_flushed_lsn = LastFlushedLSN,
                          last_applied_lsn = LastAppliedLSN} = Repl,
                    send(State, ?COPY_DATA,
                         epgsql_wire:encode_standby_status_update(
                           LastReceivedLSN, LastFlushedLSN, LastAppliedLSN)),
                    {noreply, State#state{repl = Repl#repl{feedback_required = false}}};
                _ ->
                    {noreply, State}
            end
    end.

%% Replies to the current command's transport and advances to the next queued
%% command. finish/2 uses the same value as both Notice and Result.
finish(State, Result) ->
    finish(State, Result, Result).

finish(State = #state{current_cmd_transport = Transport}, Notice, Result) ->
    reply(Transport, Notice, Result),
    command_next(State).

%% Delivers a final reply. Incremental transports receive the Notice form,
%% the others receive Result.
reply({cast, From, Ref}, _, Result) ->
    From ! {self(), Ref, Result};
reply({incremental, From, Ref}, Notice, _) ->
    From ! {self(), Ref, Notice};
reply({call, From}, _, Result) ->
    gen_server:reply(From, Result).

%% Records a partial result. Incremental callers get the Notice immediately
%% instead of having the Result accumulated. The row buffer is reset either way.
add_result(#state{results = Results, current_cmd_transport = Transport} = State, Notice, Result) ->
    Results2 = case Transport of
                   {incremental, From, Ref} ->
                       From ! {self(), Ref, Notice},
                       Results;
                   _ ->
                       [Result | Results]
               end,
    State#state{rows = [],
                results = Results2}.

%% Records a data row. Incremental callers get `{data, Row}' immediately
%% instead of having the row accumulated.
add_row(#state{rows = Rows, current_cmd_transport = Transport} = State, Data) ->
    Rows2 = case Transport of
                {incremental, From, Ref} ->
                    From ! {self(), Ref, {data, Data}},
                    Rows;
                _ ->
                    [Data | Rows]
            end,
    State#state{rows = Rows2}.

%% Forwards Notice to incremental callers; a no-op for other transports.
notify(#state{current_cmd_transport = {incremental, From, Ref}} = State, Notice) ->
    From ! {self(), Ref, Notice},
    State;
notify(State, _) ->
    State.

%% Send asynchronous messages (notice / notification) to the registered
%% receiver, if any. Returns whether the message was sent.
notify_async(#state{async = undefined}, _) ->
    false;
notify_async(#state{async = PidOrName}, Msg) ->
    try PidOrName ! {epgsql, self(), Msg} of
        _ -> true
    catch error:badarg ->
            %% no process registered under this name
            false
    end.

%% Fails every pending command with `{error, sync_required}' until the queue
%% is empty or the current command is a Sync.
sync_required(#state{current_cmd = epgsql_cmd_sync} = State) ->
    State;
sync_required(#state{current_cmd = undefined} = State) ->
    State#state{sync_required = true};
sync_required(State) ->
    sync_required(finish(State, {error, sync_required})).

%% Fails the current command and every queued command with Error.
flush_queue(#state{current_cmd = undefined} = State, _) ->
    State;
flush_queue(State, Error) ->
    flush_queue(finish(State, Error), Error).

to_binary(B) when is_binary(B) -> B;
to_binary(L) when is_list(L)   -> list_to_binary(L).


%% -- backend message handling --

%% CommandComplete
on_message(?COMMAND_COMPLETE = Msg, Bin, State) ->
    Complete = epgsql_wire:decode_complete(Bin),
    command_handle_message(Msg, Bin, State#state{complete_status = Complete});

%% ReadyForQuery: a single transaction-status byte
on_message(?READY_FOR_QUERY = Msg, <<Status:8>> = Bin, State) ->
    command_handle_message(Msg, Bin, State#state{txstatus = Status});

%% Error
on_message(?ERROR = Msg, Err, #state{current_cmd = CurrentCmd} = State) ->
    Reason = epgsql_wire:decode_error(Err),
    case CurrentCmd of
        undefined ->
            %% Message generated by server asynchronously
            {stop, {shutdown, Reason}, State};
        _ ->
            command_handle_message(Msg, Reason, State)
    end;

%% NoticeResponse
on_message(?NOTICE, Data, State) ->
    notify_async(State, {notice, epgsql_wire:decode_error(Data)}),
    {noreply, State};

%% ParameterStatus
on_message(?PARAMETER_STATUS, Data, State) ->
    [Name, Value] = epgsql_wire:decode_strings(Data),
    Parameters2 = lists:keystore(Name, 1, State#state.parameters,
                                 {Name, Value}),
    {noreply, State#state{parameters = Parameters2}};

%% NotificationResponse: notifying backend pid, channel, optional payload
on_message(?NOTIFICATION, <<Pid:?int32, Strings/binary>>, State) ->
    {Channel1, Payload1} = case epgsql_wire:decode_strings(Strings) of
        [Channel, Payload] -> {Channel, Payload};
        [Channel]          -> {Channel, <<>>}
    end,
    notify_async(State, {notification, Channel1, Pid, Payload1}),
    {noreply, State};

%% Everything else is handled by the current command, e.g.:
%% ParseComplete, ParameterDescription, RowDescription, NoData, BindComplete,
%% CloseComplete, DataRow, PortalSuspended, EmptyQueryResponse, CopyData,
%% CopyBothResponse
on_message(Msg, Payload, State) ->
    command_handle_message(Msg, Payload, State).

%% CopyData for Replication mode: primary keepalive message
%% (WAL end LSN, server timestamp, reply-requested flag).
on_replication(?COPY_DATA, <<?PRIMARY_KEEPALIVE_MESSAGE:8, LSN:?int64,
                             _Timestamp:?int64, ReplyRequired:8>>,
               #state{repl = #repl{last_flushed_lsn = LastFlushedLSN,
                                   last_applied_lsn = LastAppliedLSN,
                                   align_lsn = AlignLsn} = Repl} = State) ->
    Repl1 =
        case ReplyRequired of
            1 when AlignLsn ->
                %% Report everything up to LSN as flushed and applied.
                send(State, ?COPY_DATA,
                     epgsql_wire:encode_standby_status_update(LSN, LSN, LSN)),
                Repl#repl{feedback_required = false,
                          last_received_lsn = LSN,
                          last_applied_lsn = LSN,
                          last_flushed_lsn = LSN};
            1 when not AlignLsn ->
                send(State, ?COPY_DATA,
                     epgsql_wire:encode_standby_status_update(LSN, LastFlushedLSN,
                                                              LastAppliedLSN)),
                Repl#repl{feedback_required = false,
                          last_received_lsn = LSN};
            _ ->
                %% No reply requested now; loop/1 sends feedback after the batch.
                Repl#repl{feedback_required = true,
                          last_received_lsn = LSN}
        end,
    {noreply, State#state{repl = Repl1}};

%% CopyData for Replication mode: XLogData (start LSN, end LSN, server
%% timestamp, WAL payload).
on_replication(?COPY_DATA, <<?X_LOG_DATA, StartLSN:?int64, EndLSN:?int64,
                             _Timestamp:?int64, WALRecord/binary>>,
               #state{repl = Repl} = State) ->
    Repl1 = handle_xlog_data(StartLSN, EndLSN, WALRecord, Repl),
    {noreply, State#state{repl = Repl1}};
on_replication(?ERROR, Err, State) ->
    Reason = epgsql_wire:decode_error(Err),
    {stop, {error, Reason}, State};
on_replication(M, Data, State) when M == ?NOTICE;
                                    M == ?NOTIFICATION;
                                    M == ?PARAMETER_STATUS ->
    on_message(M, Data, State).

%% Delivers a WAL record either as an `x_log_data' message to the receiver
%% process or through the configured callback module, whose returned
%% flushed/applied LSNs are recorded for the next feedback.
handle_xlog_data(StartLSN, EndLSN, WALRecord,
                 #repl{cbmodule = undefined, receiver = Receiver} = Repl) ->
    %% with async messages
    Receiver ! {epgsql, self(), {x_log_data, StartLSN, EndLSN, WALRecord}},
    Repl#repl{feedback_required = true,
              last_received_lsn = EndLSN};
handle_xlog_data(StartLSN, EndLSN, WALRecord,
                 #repl{cbmodule = CbModule, cbstate = CbState, receiver = undefined} = Repl) ->
    %% with callback method
    {ok, LastFlushedLSN, LastAppliedLSN, NewCbState} =
        epgsql:handle_x_log_data(CbModule, StartLSN, EndLSN, WALRecord, CbState),
    Repl#repl{feedback_required = true,
              last_received_lsn = EndLSN,
              last_flushed_lsn = LastFlushedLSN,
              last_applied_lsn = LastAppliedLSN,
              cbstate = NewCbState}.