%%% Copyright (C) 2009 - Will Glozer. All rights reserved.
%%% Copyright (C) 2011 - Anton Lebedevich. All rights reserved.
-module(pgsql_sock).

-behavior(gen_server).

%% client interface
-export([start_link/0, close/1, get_parameter/2, cancel/1]).

%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2]).
-export([terminate/2, code_change/3]).

%% state callbacks
-export([auth/2, initializing/2, on_message/2]).

-include("pgsql.hrl").
-include("pgsql_binary.hrl").
%% State of one connection process.
-record(state, {
    mod,                  % transport module: gen_tcp or ssl
    sock,                 % connected socket, used through mod
    data = <<>>,          % received bytes not yet decoded into a message
    backend,              % {Pid, Key} from BackendKeyData, needed by cancel
    handler,              % name of the state callback that receives messages
    queue = queue:new(),  % pending requests, oldest first
    async,                % pid that receives notices and notifications, if any
    parameters = [],      % {Name, Value} pairs taken from ParameterStatus
    types = [],           % parameter types from ParameterDescription
    columns = [],         % columns from the last RowDescription
    rows = [],            % rows collected so far, newest first
    results = [],         % results collected for the current request, newest first
    sync_required,        % set to true after an error until sync is requested
    txstatus              % status byte of the last ReadyForQuery
}).
%% -- client interface --
%% Starts a new, not yet connected, socket process linked to the caller.
%% Returns whatever gen_server:start_link/3 returns.
start_link() ->
    InitArgs = [],
    ServerOpts = [],
    gen_server:start_link(?MODULE, InitArgs, ServerOpts).
%% Asks the socket process C to stop. Always returns ok: any failure while
%% casting the stop request is deliberately ignored.
close(C) when is_pid(C) ->
    _ = (catch gen_server:cast(C, stop)),
    ok.
%% Looks up a server parameter by name (string or binary).
%% Returns {ok, Value}, or {ok, undefined} when the parameter is not known.
%% Waits for the reply without a timeout.
get_parameter(C, Name) ->
    Key = to_binary(Name),
    gen_server:call(C, {get_parameter, Key}, infinity).
%% Asynchronously asks the socket process S to cancel the query that is
%% currently running on its backend.
cancel(S) ->
    Request = cancel,
    gen_server:cast(S, Request).
%% -- gen_server implementation --
%% gen_server callback: starts with a default state. No socket is opened
%% here; that happens when the connect command is handled.
init([]) ->
    Initial = #state{},
    {ok, Initial}.
%% gen_server callback for synchronous requests.
%%
%% {get_parameter, Name} is answered immediately from the cached server
%% parameters with {ok, Value}, where Value is undefined if unknown.
%% Any other command is queued together with its caller and passed to
%% command/2; the reply is sent later through finish/3.
handle_call({get_parameter, Name}, _From, State) ->
    Value = case lists:keyfind(Name, 1, State#state.parameters) of
                {Name, Found} -> Found;
                false         -> undefined
            end,
    {reply, {ok, Value}, State};
handle_call(Command, From, State = #state{queue = Pending}) ->
    Request = {{call, From}, Command},
    command(Command, State#state{queue = queue:in(Request, Pending)}).
%% gen_server callback for asynchronous requests.
%%
%% {{cast | incremental, From, Ref}, Command}
%%     queues the request and runs Command; results are later sent to From
%%     as {self(), Ref, ...} messages.
%% stop
%%     fails every queued request with {error, closed} and stops normally.
%% cancel
%%     opens a separate plain TCP connection to the same server and sends a
%%     CancelRequest carrying the backend pid and secret key. Only matches
%%     once BackendKeyData has been received.
handle_cast({{Method, From, Ref}, Command} = Req, State)
  when ((Method == cast) or (Method == incremental)),
       is_pid(From),
       is_reference(Ref) ->
    #state{queue = Q} = State,
    command(Command, State#state{queue = queue:in(Req, Q)});

handle_cast(stop, State) ->
    {stop, normal, flush_queue(State, {error, closed})};

handle_cast(cancel, State = #state{backend = {Pid, Key},
                                   mod = Mod,
                                   sock = MainSock}) ->
    %% An ssl socket is not an inet port, so inet:peername/1 cannot be used
    %% on it; ask the module that owns the socket.
    {ok, {Addr, Port}} = case Mod of
                             gen_tcp -> inet:peername(MainSock);
                             ssl     -> ssl:peername(MainSock)
                         end,
    SockOpts = [{active, false}, {packet, raw}, binary],
    %% TODO timeout
    {ok, Sock} = gen_tcp:connect(Addr, Port, SockOpts),
    %% CancelRequest: length 16, request code 80877102, backend pid, key.
    Msg = <<16:?int32, 80877102:?int32, Pid:?int32, Key:?int32>>,
    ok = gen_tcp:send(Sock, Msg),
    gen_tcp:close(Sock),
    {noreply, State}.
%% gen_server callback for socket traffic and port replies.
%%
%% - socket closed: fail all queued requests and stop with sock_closed.
%% - socket error: fail all queued requests and stop with {sock_error, Reason}.
%% - {inet_reply, _, ok}: acknowledgement of a port_command send; ignored.
%% - {inet_reply, _, Status}: failed send; fail queued requests and stop.
%% - any other 3-tuple for our socket: append the payload to the buffer and
%%   decode as many messages as possible.
handle_info({Tag, Sock}, State = #state{sock = Sock})
  when Tag == tcp_closed; Tag == ssl_closed ->
    Flushed = flush_queue(State, {error, sock_closed}),
    {stop, sock_closed, Flushed};
handle_info({Tag, Sock, Reason}, State = #state{sock = Sock})
  when Tag == tcp_error; Tag == ssl_error ->
    StopReason = {sock_error, Reason},
    {stop, StopReason, flush_queue(State, {error, StopReason})};
handle_info({inet_reply, _Port, ok}, State) ->
    {noreply, State};
handle_info({inet_reply, _Port, Failure}, State) ->
    {stop, Failure, flush_queue(State, {error, Failure})};
handle_info({_Tag, Sock, Chunk}, State = #state{sock = Sock, data = Buffered}) ->
    loop(State#state{data = <<Buffered/binary, Chunk/binary>>}).
%% gen_server callback: nothing is cleaned up explicitly.
%% TODO send termination msg, close socket ??
terminate(_Why, _Final) ->
    ok.
%% gen_server callback: the state needs no conversion between versions.
code_change(_FromVsn, Unchanged, _Extra) ->
    {ok, Unchanged}.
%% -- internal functions --
%% Executes one client command by writing the matching protocol messages to
%% the socket. The request has already been queued by the caller; the answer
%% is produced later by the message handlers. Always returns {noreply, State}.
%%
%% While sync_required is set, every command except sync is rejected with
%% {error, sync_required}.
command(Command, State = #state{sync_required = true})
  when Command /= sync ->
    Rejected = finish(State, {error, sync_required}),
    {noreply, Rejected};

%% Opens the TCP connection (optionally upgraded to SSL), sends the startup
%% packet and switches to the auth handler. Credentials are kept in the
%% process dictionary until the handshake has finished.
command({connect, Host, Username, Password, Opts}, State) ->
    Timeout = proplists:get_value(timeout, Opts, 5000),
    Port = proplists:get_value(port, Opts, 5432),
    SockOpts = [{active, false}, {packet, raw}, binary, {nodelay, true}],
    {ok, Sock} = gen_tcp:connect(Host, Port, SockOpts, Timeout),
    SslFlag = proplists:get_value(ssl, Opts),
    Connected = if
                    SslFlag == true; SslFlag == required ->
                        start_ssl(Sock, SslFlag, Opts, State);
                    true ->
                        State#state{mod = gen_tcp, sock = Sock}
                end,
    UserParam = ["user", 0, Username, 0],
    StartupParams = case proplists:get_value(database, Opts, undefined) of
                        undefined -> UserParam;
                        Database  -> [UserParam | ["database", 0, Database, 0]]
                    end,
    %% 196608 is protocol version 3.0 (3 bsl 16).
    send(Connected, [<<196608:?int32>>, StartupParams, 0]),
    Async = proplists:get_value(async, Opts, undefined),
    setopts(Connected, [{active, true}]),
    put(username, Username),
    put(password, Password),
    {noreply, Connected#state{handler = auth, async = Async}};

%% Simple query.
command({squery, Sql}, State) ->
    send(State, $Q, [Sql, 0]),
    {noreply, State};

%% TODO add fast_equery command that doesn't need parsed statement,
%% uses default (text) column format,
%% sends Describe after Bind to get RowDescription
%% Extended query: Bind to the unnamed portal, Execute without row limit,
%% Close the unnamed statement, Sync.
command({equery, Statement, Parameters}, State) ->
    #statement{name = StatementName, columns = Columns} = Statement,
    ParamBin = pgsql_wire:encode_parameters(Parameters),
    FormatBin = pgsql_wire:encode_formats(Columns),
    send(State, $B, ["", 0, StatementName, 0, ParamBin, FormatBin]),
    send(State, $E, ["", 0, <<0:?int32>>]),
    send(State, $C, [$S, "", 0]),
    send(State, $S, []),
    {noreply, State};

%% Parse followed by Describe of the new statement, then Flush.
command({parse, Name, Sql, Types}, State) ->
    TypeBin = pgsql_wire:encode_types(Types),
    send(State, $P, [Name, 0, Sql, 0, TypeBin]),
    send(State, $D, [$S, Name, 0]),
    send(State, $H, []),
    {noreply, State};

%% Bind the statement to a portal; parameters are paired with the
%% statement's parameter types before encoding.
command({bind, Statement, PortalName, Parameters}, State) ->
    #statement{name = StatementName,
               columns = Columns,
               types = Types} = Statement,
    TypedParameters = lists:zip(Types, Parameters),
    ParamBin = pgsql_wire:encode_parameters(TypedParameters),
    FormatBin = pgsql_wire:encode_formats(Columns),
    send(State, $B, [PortalName, 0, StatementName, 0, ParamBin, FormatBin]),
    send(State, $H, []),
    {noreply, State};

%% Execute a portal, fetching at most MaxRows rows.
command({execute, _Statement, PortalName, MaxRows}, State) ->
    send(State, $E, [PortalName, 0, <<MaxRows:?int32>>]),
    send(State, $H, []),
    {noreply, State};

command({describe_statement, Name}, State) ->
    send(State, $D, [$S, Name, 0]),
    send(State, $H, []),
    {noreply, State};

command({describe_portal, Name}, State) ->
    send(State, $D, [$P, Name, 0]),
    send(State, $H, []),
    {noreply, State};

%% Close a statement or a portal.
command({close, Type, Name}, State) ->
    Target = case Type of
                 statement -> $S;
                 portal    -> $P
             end,
    send(State, $C, [Target, Name, 0]),
    send(State, $H, []),
    {noreply, State};

%% Sync also clears the sync_required flag.
command(sync, State) ->
    send(State, $S, []),
    {noreply, State#state{sync_required = false}}.
%% Negotiates SSL on the freshly connected TCP socket S.
%%
%% Sends an SSLRequest (length 8, code 80877103) and reads the one-byte
%% answer:
%%   $S - the server accepts; the socket is upgraded with ssl:connect/3 and
%%        the state uses mod = ssl.
%%   $N - the server refuses; with Flag = true the connection continues over
%%        the plain TCP socket, with Flag = required the process exits with
%%        ssl_not_available.
%%
%% Returns the state with mod and sock set. Exits with
%% {ssl_negotiation_failed, Reason} when the SSL handshake fails.
start_ssl(S, Flag, Opts, State) ->
    ok = gen_tcp:send(S, <<8:?int32, 80877103:?int32>>),
    Timeout = proplists:get_value(timeout, Opts, 5000),
    {ok, <<Code>>} = gen_tcp:recv(S, 1, Timeout),
    case Code of
        $S ->
            case ssl:connect(S, Opts, Timeout) of
                {ok, S2}        -> State#state{mod = ssl, sock = S2};
                {error, Reason} -> exit({ssl_negotiation_failed, Reason})
            end;
        $N ->
            case Flag of
                %% Optional SSL: keep going unencrypted. The socket must be
                %% recorded here, otherwise the caller would send through
                %% an undefined module and socket.
                true     -> State#state{mod = gen_tcp, sock = S};
                required -> exit(ssl_not_available)
            end
    end.
%% Applies socket options through the API that matches the transport:
%% inet:setopts/2 for gen_tcp, ssl:setopts/2 for ssl.
setopts(State, Opts) ->
    Sock = State#state.sock,
    case State#state.mod of
        ssl     -> ssl:setopts(Sock, Opts);
        gen_tcp -> inet:setopts(Sock, Opts)
    end.
%% Encodes Data with pgsql_wire:encode/1 (no message type byte) and writes
%% it to the connection.
send(State, Data) ->
    #state{mod = Mod, sock = Sock} = State,
    Packet = pgsql_wire:encode(Data),
    do_send(Mod, Sock, Packet).
%% Encodes Data as a message of the given Type with pgsql_wire:encode/2 and
%% writes it to the connection.
send(State, Type, Data) ->
    #state{mod = Mod, sock = Sock} = State,
    Packet = pgsql_wire:encode(Type, Data),
    do_send(Mod, Sock, Packet).
%% Writes an encoded packet to the socket.
%%
%% For gen_tcp the data is handed straight to the port with
%% erlang:port_command/2; the port's answer arrives later as an
%% {inet_reply, _, Status} message. Returns ok, or {error, einval} if
%% port_command raises an error. Any other transport uses Mod:send/2.
do_send(gen_tcp, Port, Packet) ->
    try erlang:port_command(Port, Packet) of
        true -> ok
    catch
        error:_ -> {error, einval}
    end;
do_send(Transport, Sock, Packet) ->
    Transport:send(Sock, Packet).
%% Decodes buffered data one message at a time and feeds each message to the
%% current state callback (auth, initializing or on_message). Stops decoding
%% when the buffer holds no complete message or a callback asks to stop.
%% Returns a gen_server callback result.
loop(State = #state{data = Buffer, handler = Handler}) ->
    case pgsql_wire:decode_message(Buffer) of
        {Message, Rest} ->
            Outcome = ?MODULE:Handler(Message, State#state{data = Rest}),
            case Outcome of
                {noreply, Next} -> loop(Next);
                {stop, _, _}    -> Outcome
            end;
        _ ->
            %% Not enough data for a whole message yet.
            {noreply, State}
    end.
%% Completes the oldest queued request, using the same term as both the
%% incremental notice and the final result.
finish(State, Outcome) ->
    finish(State, Outcome, Outcome).
%% Completes the oldest queued request and removes it from the queue.
%%
%% call requests are answered with gen_server:reply/2, cast requests get
%% {self(), Ref, Result}, incremental requests get {self(), Ref, Notice}.
%% The per-request accumulators (types, columns, rows, results) are reset.
finish(State = #state{queue = Pending}, Notice, Result) ->
    case queue:get(Pending) of
        {{call, Caller}, _} ->
            gen_server:reply(Caller, Result);
        {{cast, Pid, Ref}, _} ->
            Pid ! {self(), Ref, Result};
        {{incremental, Pid, Ref}, _} ->
            Pid ! {self(), Ref, Notice}
    end,
    Remaining = queue:drop(Pending),
    State#state{queue = Remaining,
                types = [],
                columns = [],
                rows = [],
                results = []}.
%% Records one finished statement of the current request.
%% Incremental requests are sent Notice right away and nothing is stored;
%% for all other requests Result is pushed onto the results list. The row
%% and column accumulators are cleared for the next statement.
add_result(State = #state{queue = Pending, results = Acc}, Notice, Result) ->
    Collected = case queue:get(Pending) of
                    {{incremental, Pid, Ref}, _} ->
                        Pid ! {self(), Ref, Notice},
                        Acc;
                    _ ->
                        [Result | Acc]
                end,
    State#state{results = Collected,
                rows = [],
                columns = [],
                types = []}.
%% Handles one decoded data row. Incremental requests receive it at once as
%% {self(), Ref, {data, Data}}; otherwise the row is pushed onto the rows
%% list (newest first).
add_row(State = #state{queue = Pending, rows = Acc}, Data) ->
    Collected = case queue:get(Pending) of
                    {{incremental, Pid, Ref}, _} ->
                        Pid ! {self(), Ref, {data, Data}},
                        Acc;
                    _ ->
                        [Data | Acc]
                end,
    State#state{rows = Collected}.
%% Sends Notice to the requester if the oldest queued request is
%% incremental; does nothing for other request kinds. Returns the state
%% unchanged.
notify(State = #state{queue = Pending}, Notice) ->
    case queue:get(Pending) of
        {{incremental, Pid, Ref}, _} -> Pid ! {self(), Ref, Notice};
        _                            -> ignore
    end,
    State.
%% Forwards Msg as {pgsql, self(), Msg} to the async receiver when one is
%% configured as a pid; otherwise the message is dropped. Returns the state
%% unchanged.
notify_async(State = #state{async = Receiver}, Msg) ->
    if
        is_pid(Receiver) -> Receiver ! {pgsql, self(), Msg};
        true             -> false
    end,
    State.
%% Returns the name of the oldest queued command: the command itself when it
%% is an atom (e.g. sync), or its first element when it is a tuple.
command_tag(#state{queue = Pending}) ->
    {_Origin, Command} = queue:get(Pending),
    if
        is_atom(Command)  -> Command;
        is_tuple(Command) -> element(1, Command)
    end.
%% Returns the column descriptions that apply to the oldest queued command:
%% for squery the columns stored in the state, for equery and execute the
%% columns carried by the statement in the request.
get_columns(#state{queue = Pending, columns = Described}) ->
    case queue:get(Pending) of
        {_, {squery, _}} ->
            Described;
        {_, {equery, #statement{columns = FromStatement}, _}} ->
            FromStatement;
        {_, {execute, #statement{columns = FromStatement}, _, _}} ->
            FromStatement
    end.
%% Builds a #statement{} for the oldest queued parse or describe_statement
%% command from its name and the types and columns gathered in the state.
make_statement(#state{queue = Pending, types = Types, columns = Columns}) ->
    StatementName = case queue:get(Pending) of
                        {_, {describe_statement, Described}} -> Described;
                        {_, {parse, Parsed, _, _}}           -> Parsed
                    end,
    #statement{name = StatementName, types = Types, columns = Columns}.
%% Fails queued requests with {error, sync_required} until a sync command is
%% at the head of the queue. If the queue runs empty first, the
%% sync_required flag is set so later commands are rejected until sync.
sync_required(State = #state{queue = Pending}) ->
    case queue:is_empty(Pending) of
        true ->
            State#state{sync_required = true};
        false ->
            case command_tag(State) of
                sync ->
                    State;
                _ ->
                    Dropped = finish(State, {error, sync_required}),
                    sync_required(Dropped)
            end
    end.
%% Completes every queued request with Error and returns the state with an
%% empty queue.
flush_queue(State = #state{queue = Pending}, Error) ->
    case queue:is_empty(Pending) of
        true ->
            State;
        false ->
            Drained = finish(State, Error),
            flush_queue(Drained, Error)
    end.
%% Converts a list to a binary; binaries are returned as they are.
to_binary(Value) when is_list(Value)   -> list_to_binary(Value);
to_binary(Value) when is_binary(Value) -> Value.
%% Returns the lowercase hexadecimal representation of Bin as a binary,
%% two characters per input byte.
hex(Bin) ->
    Digit = fun(Nibble) ->
                    if
                        Nibble < 10 -> $0 + Nibble;
                        Nibble < 16 -> $a + Nibble - 10
                    end
            end,
    << <<(Digit(Hi)), (Digit(Lo))>> || <<Hi:4, Lo:4>> <= Bin >>.
%% -- backend message handling --
%% State callback used while the server authenticates the client.
%% Credentials are read from the process dictionary, where the connect
%% command stored them. Messages that are not part of authentication are
%% passed on to on_message/2.

%% AuthenticationOk
auth({$R, <<0:?int32>>}, State) ->
    {noreply, State#state{handler = initializing}};

%% AuthenticationCleartextPassword
auth({$R, <<3:?int32>>}, State) ->
    send(State, $p, [get(password), 0]),
    {noreply, State};

%% AuthenticationMD5Password
%% The response is "md5" ++ hex(md5(hex(md5(password ++ username)) ++ salt)).
auth({$R, <<5:?int32, Salt:4/binary>>}, State) ->
    Digest1 = hex(erlang:md5([get(password), get(username)])),
    Str = ["md5", hex(erlang:md5([Digest1, Salt])), 0],
    send(State, $p, Str),
    {noreply, State};

%% Any other authentication request is unsupported: the pending connect
%% request fails with {error, {unsupported_auth_method, Method}} and the
%% process stops. Codes follow the PostgreSQL protocol: 8 is
%% AuthenticationGSSContinue, 9 is AuthenticationSSPI and 10 is
%% AuthenticationSASL.
auth({$R, <<M:?int32, _/binary>>}, State) ->
    Method = case M of
                 2  -> kerberosV5;
                 4  -> crypt;
                 6  -> scm;
                 7  -> gss;
                 8  -> gss;
                 9  -> sspi;
                 10 -> sasl;
                 _  -> unknown
             end,
    State2 = finish(State, {error, {unsupported_auth_method, Method}}),
    {stop, normal, State2};

%% ErrorResponse
%% Well-known SQLSTATE codes are translated to atoms; any other code is
%% returned as received.
auth({error, E}, State) ->
    Why = case E#error.code of
              <<"28000">> -> invalid_authorization_specification;
              <<"28P01">> -> invalid_password;
              Any         -> Any
          end,
    {stop, normal, finish(State, {error, Why})};

auth(Other, State) ->
    on_message(Other, State).
%% State callback used between successful authentication and the first
%% ReadyForQuery. Other messages are passed on to on_message/2.

%% BackendKeyData: remember the backend pid and key for cancel requests.
initializing({$K, <<Pid:?int32, Key:?int32>>}, State) ->
    Backend = {Pid, Key},
    {noreply, State#state{backend = Backend}};

%% ReadyForQuery: the connection is usable. The stored credentials are
%% erased, the datetime codec is chosen from the integer_datetimes server
%% parameter and the pending connect request is answered with connected.
initializing({$Z, <<Status:8>>}, State) ->
    erase(username),
    erase(password),
    %% TODO decode dates to now() format
    Setting = lists:keysearch(<<"integer_datetimes">>, 1,
                              State#state.parameters),
    DatetimeMod = case Setting of
                      {value, {_, <<"on">>}}  -> pgsql_idatetime;
                      {value, {_, <<"off">>}} -> pgsql_fdatetime
                  end,
    put(datetime_mod, DatetimeMod),
    Ready = State#state{handler = on_message, txstatus = Status},
    {noreply, finish(Ready, connected)};

%% An error during startup fails the connect request and stops the process.
initializing({error, _} = Error, State) ->
    Failed = finish(State, Error),
    {stop, normal, Failed};

initializing(Other, State) ->
    on_message(Other, State).
%% State callback for an established connection. Each clause handles one
%% backend message in the context of the oldest queued command and returns
%% {noreply, State}.

%% ParseComplete
on_message({$1, <<>>}, State) ->
    {noreply, State};

%% ParameterDescription
on_message({$t, <<_Count:?int16, Oids/binary>>}, State) ->
    Types = [pgsql_types:oid2type(Oid) || <<Oid:?int32>> <= Oids],
    Typed = State#state{types = Types},
    {noreply, notify(Typed, {types, Types})};

%% RowDescription
%% For parse and describe_statement each column also gets the format that
%% pgsql_wire:format/1 picks for its type, and the request is completed with
%% a statement. describe_portal is completed with the columns as decoded.
on_message({$T, <<Count:?int16, Bin/binary>>}, State) ->
    Decoded = pgsql_wire:decode_columns(Count, Bin),
    Tag = command_tag(State),
    Columns = case Tag of
                  describe_portal ->
                      Decoded;
                  squery ->
                      Decoded;
                  _ when Tag == parse; Tag == describe_statement ->
                      [Col#column{format = pgsql_wire:format(Col#column.type)}
                       || Col <- Decoded]
              end,
    Described = State#state{columns = Columns},
    Notice = {columns, Columns},
    Next = case Tag of
               squery ->
                   notify(Described, Notice);
               describe_portal ->
                   finish(Described, Notice, {ok, Decoded});
               _ when Tag == parse; Tag == describe_statement ->
                   finish(Described, Notice, {ok, make_statement(Described)})
           end,
    {noreply, Next};

%% NoData
on_message({$n, <<>>}, State) ->
    Result = case command_tag(State) of
                 describe_portal ->
                     {ok, []};
                 Tag when Tag == parse; Tag == describe_statement ->
                     {ok, make_statement(State)}
             end,
    {noreply, finish(State, no_data, Result)};

%% BindComplete
on_message({$2, <<>>}, State) ->
    case command_tag(State) of
        bind ->
            {noreply, finish(State, ok)};
        equery ->
            %% TODO send Describe as a part of equery, needs text format support
            {noreply, notify(State, {columns, get_columns(State)})}
    end;

%% CloseComplete
on_message({$3, <<>>}, State) ->
    case command_tag(State) of
        close  -> {noreply, finish(State, ok)};
        equery -> {noreply, State}
    end;

%% DataRow
on_message({$D, <<_Count:?int16, Bin/binary>>}, State) ->
    Row = pgsql_wire:decode_data(get_columns(State), Bin),
    {noreply, add_row(State, Row)};

%% PortalSuspended
on_message({$s, <<>>}, State) ->
    Rows = lists:reverse(State#state.rows),
    {noreply, finish(State, suspended, {partial, Rows})};

%% CommandComplete
%% execute completes the request at once; squery and equery only record a
%% result, the request is completed by ReadyForQuery.
on_message({$C, Bin}, State) ->
    Complete = pgsql_wire:decode_complete(Bin),
    Command = command_tag(State),
    Notice = {complete, Complete},
    Rows = lists:reverse(State#state.rows),
    Next = case Command of
               execute ->
                   ExecResult = case {Complete, Rows} of
                                    {{_, N}, []} -> {ok, N};
                                    {{_, N}, _}  -> {ok, N, Rows};
                                    _            -> {ok, Rows}
                                end,
                   finish(State, Notice, ExecResult);
               _ when Command == squery; Command == equery ->
                   QueryResult = case {Complete, Rows} of
                                     {{_, M}, []} ->
                                         {ok, M};
                                     {{_, M}, _} ->
                                         {ok, M, get_columns(State), Rows};
                                     _ ->
                                         {ok, get_columns(State), Rows}
                                 end,
                   add_result(State, Notice, QueryResult)
           end,
    {noreply, Next};

%% EmptyQueryResponse
on_message({$I, _Bin}, State) ->
    Notice = {complete, empty},
    Empty = {ok, [], []},
    Next = case command_tag(State) of
               execute ->
                   finish(State, Notice, Empty);
               Tag when Tag == squery; Tag == equery ->
                   add_result(State, Notice, Empty)
           end,
    {noreply, Next};

%% ReadyForQuery
%% Completes squery, equery and sync requests. A single collected result is
%% returned bare; squery returns several results as a list in statement
%% order.
on_message({$Z, <<Status:8>>}, State) ->
    Next = case {command_tag(State), State#state.results} of
               {sync, _} ->
                   finish(State, ok);
               {squery, [Only]} ->
                   finish(State, done, Only);
               {squery, Many} ->
                   finish(State, done, lists:reverse(Many));
               {equery, [Only]} ->
                   finish(State, done, Only);
               {equery, []} ->
                   finish(State, done)
           end,
    {noreply, Next#state{txstatus = Status}};

%% ErrorResponse
%% For squery and equery the error becomes a result (ReadyForQuery follows).
%% Any other command fails and the connection waits for a sync.
on_message(Error = {error, _}, State) ->
    Next = case command_tag(State) of
               Tag when Tag == squery; Tag == equery ->
                   add_result(State, Error, Error);
               _ ->
                   sync_required(finish(State, Error))
           end,
    {noreply, Next};

%% NoticeResponse
on_message({$N, Data}, State) ->
    Notice = {notice, pgsql_wire:decode_error(Data)},
    {noreply, notify_async(State, Notice)};

%% ParameterStatus
on_message({$S, Data}, State) ->
    [Name, Value] = pgsql_wire:decode_strings(Data),
    Updated = lists:keystore(Name, 1, State#state.parameters, {Name, Value}),
    {noreply, State#state{parameters = Updated}};

%% NotificationResponse
on_message({$A, <<Pid:?int32, Strings/binary>>}, State) ->
    {Channel, Payload} = case pgsql_wire:decode_strings(Strings) of
                             [C, P] -> {C, P};
                             [C]    -> {C, <<>>}
                         end,
    Notification = {notification, Channel, Pid, Payload},
    {noreply, notify_async(State, Notification)}.
|