pgsql_sock.erl 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. %%% Copyright (C) 2009 - Will Glozer. All rights reserved.
  2. -module(pgsql_sock).
  3. -behavior(gen_server).
  4. -export([start_link/4, send/2, send/3, cancel/3]).
  5. -export([decode_string/1, lower_atom/1]).
  6. -export([handle_call/3, handle_cast/2, handle_info/2]).
  7. -export([init/1, code_change/3, terminate/2]).
  8. -include("pgsql.hrl").
  %% Process state:
  %%   c    - client process; decode/2 forwards server messages to it with
  %%          gen_fsm:send_event/2 and gen_fsm:send_all_state_event/2
  %%   mod  - transport module: gen_tcp, or ssl once start_ssl/4 upgrades it
  %%   sock - socket handle used with `mod`
  %%   tail - received bytes that do not yet form a complete message
  -record(state, {c, mod, sock, tail}).
  %% Bit-syntax type specs for big-endian signed integers
  %% (size 1 x unit 16/32 bits), used as `Value:?int32`.
  -define(int16, 1/big-signed-unit:16).
  -define(int32, 1/big-signed-unit:32).
  12. %% -- client interface --
  %% Start and link a socket process for client C that connects to Host as
  %% Username using the connection options Opts (see init/1).
  start_link(C, Host, Username, Opts) ->
      Args = [C, Host, Username, Opts],
      gen_server:start_link(?MODULE, Args, []).
  %% Queue a typed frontend message for socket process S:
  %% 1-byte Type, int32 length (counting itself), then the flattened Data.
  send(S, Type, Data) ->
      gen_server:cast(S, {send, <<Type:8, (length_prefixed(Data))/binary>>}).
  %% Queue an untyped message (int32 length, then payload); init/1 uses this
  %% form for the startup message.
  send(S, Data) ->
      gen_server:cast(S, {send, length_prefixed(Data)}).
  %% Flatten the iolist Data and prefix it with its size plus the 4 length bytes.
  length_prefixed(Data) ->
      Payload = iolist_to_binary(Data),
      <<(byte_size(Payload) + 4):?int32, Payload/binary>>.
  %% Ask socket process S (asynchronously) to send a cancel request carrying
  %% Pid and Key; handled by handle_cast/2.
  cancel(S, Pid, Key) ->
      Request = {cancel, Pid, Key},
      gen_server:cast(S, Request).
  25. %% -- gen_server implementation --
  %% Connect to Host (port option, default 5432), optionally negotiate SSL,
  %% switch the socket to active mode and queue the startup message.
  %% Opts keys read here: database, port, ssl (true | required).
  init([C, Host, Username, Opts]) ->
      %% Exits of linked processes arrive as {'EXIT', _, _} messages.
      process_flag(trap_exit, true),
      %% Startup parameters: NUL-terminated name/value pairs.
      UserParam = ["user", 0, Username, 0],
      StartupParams =
          case proplists:get_value(database, Opts, undefined) of
              undefined -> UserParam;
              Database -> [UserParam | ["database", 0, Database, 0]]
          end,
      TcpPort = proplists:get_value(port, Opts, 5432),
      {ok, Sock} = gen_tcp:connect(Host, TcpPort,
                                   [{active, false}, {packet, raw}, binary, {nodelay, true}]),
      Plain = #state{c = C, mod = gen_tcp, sock = Sock, tail = <<>>},
      Connected =
          case proplists:get_value(ssl, Opts) of
              Flag when Flag == true; Flag == required ->
                  %% SSLRequest (length 8, code 80877103); the server replies
                  %% with a single byte that start_ssl/4 interprets.
                  ok = gen_tcp:send(Sock, <<8:?int32, 80877103:?int32>>),
                  {ok, <<Answer>>} = gen_tcp:recv(Sock, 1),
                  start_ssl(Answer, Flag, Opts, Plain);
              _ ->
                  Plain
          end,
      setopts(Connected, [{active, true}]),
      %% Startup message for protocol 3.0 (196608 = 3 bsl 16). It is queued as
      %% a cast to this process, so it is written after init/1 returns.
      send(self(), [<<196608:32>>, StartupParams, 0]),
      {ok, Connected}.
  %% No synchronous requests are supported: any call stops the server.
  handle_call(Request, _From, State) ->
      Reason = {unsupported_call, Request},
      {stop, Reason, State}.
  %% {send, Data}: write an already-framed message to the server socket
  %% using the current transport module.
  handle_cast({send, Data}, State) ->
      #state{mod = Mod, sock = Sock} = State,
      ok = Mod:send(Sock, Data),
      {noreply, State};
  %% {cancel, Pid, Key}: open a separate plain TCP connection to the same
  %% server address, send a CancelRequest (length 16, code 80877102) with
  %% Pid and Key, then close that connection.
  %% The peer address must be read via the module that owns the socket:
  %% after SSL negotiation `sock` is an ssl socket, which inet:peername/1
  %% does not accept.
  handle_cast({cancel, Pid, Key}, State) ->
      #state{mod = Mod, sock = Sock} = State,
      {ok, {Addr, Port}} =
          case Mod of
              gen_tcp -> inet:peername(Sock);
              ssl -> ssl:peername(Sock)
          end,
      SockOpts = [{active, false}, {packet, raw}, binary],
      {ok, CancelSock} = gen_tcp:connect(Addr, Port, SockOpts),
      Msg = <<16:?int32, 80877102:?int32, Pid:?int32, Key:?int32>>,
      ok = gen_tcp:send(CancelSock, Msg),
      gen_tcp:close(CancelSock),
      {noreply, State};
  %% Any other cast is unexpected: stop the server.
  handle_cast(Cast, State) ->
      {stop, {unsupported_cast, Cast}, State}.
  %% Socket data: append it to the buffered partial message and decode.
  %% The tag must be restricted to tcp/ssl. A bare {_, _, Data} pattern would
  %% also match {tcp_error, _, _}, {ssl_error, _, _} and {'EXIT', _, _},
  %% leaving the clauses below unreachable and crashing with badarg when the
  %% non-binary third element is concatenated.
  handle_info({Tag, _Sock, Data}, #state{tail = Tail} = State)
    when Tag == tcp; Tag == ssl ->
      State2 = decode(<<Tail/binary, Data/binary>>, State),
      {noreply, State2};
  %% The socket was closed.
  handle_info({Closed, _Sock}, State)
    when Closed == tcp_closed; Closed == ssl_closed ->
      {stop, sock_closed, State};
  %% The socket reported an error.
  handle_info({Error, _Sock, Reason}, State)
    when Error == tcp_error; Error == ssl_error ->
      {stop, {sock_error, Reason}, State};
  %% A linked process exited (exits are trapped in init/1): stop with its reason.
  handle_info({'EXIT', _Pid, Reason}, State) ->
      {stop, Reason, State};
  %% Anything else is unexpected: stop the server.
  handle_info(Info, State) ->
      {stop, {unsupported_info, Info}, State}.
  %% No cleanup is performed on shutdown.
  terminate(_Why, _S) -> ok.
  %% Hot code upgrade: the state is carried over unchanged.
  code_change(_Vsn, S, _Extra) -> {ok, S}.
  85. %% -- internal functions --
  %% Complete SSL negotiation from the server's one-byte answer to SSLRequest.
  %%   $S - upgrade the TCP socket with ssl:connect/2 and switch `mod` to ssl;
  %%        exits with {ssl_negotiation_failed, Reason} if that fails.
  %%   $N - SSL refused: keep plain TCP when Flag is true, exit with
  %%        ssl_not_available when Flag is required.
  %% NOTE(review): Opts is the whole connection option list (database, port,
  %% ssl, ...) and is handed to ssl:connect/2 unfiltered; confirm the ssl
  %% application accepts the non-TLS entries.
  start_ssl($S, _Flag, Opts, #state{sock = TcpSock} = State) ->
      case ssl:connect(TcpSock, Opts) of
          {ok, SslSock} ->
              State#state{mod = ssl, sock = SslSock};
          {error, Reason} ->
              exit({ssl_negotiation_failed, Reason})
      end;
  start_ssl($N, true, _Opts, State) ->
      State;
  start_ssl($N, required, _Opts, _State) ->
      exit(ssl_not_available).
  %% Apply socket options through whichever transport currently owns `sock`.
  setopts(#state{mod = gen_tcp, sock = Sock}, Opts) ->
      inet:setopts(Sock, Opts);
  setopts(#state{mod = ssl, sock = Sock}, Opts) ->
      ssl:setopts(Sock, Opts).
  %% Split Bin into complete backend messages (1-byte type, int32 length that
  %% counts itself, body) and forward each to the client process.
  %% Whatever cannot yet be framed (including a short header) is kept in
  %% `tail` and returned in the new state.
  decode(<<Type:8, Len:?int32, Rest/binary>> = Bin, #state{c = C} = State) ->
      BodyLen = Len - 4,
      case Rest of
          <<Body:BodyLen/binary, More/binary>> ->
              dispatch_message(Type, Body, C),
              decode(More, State);
          _Incomplete ->
              State#state{tail = Bin}
      end;
  decode(Bin, State) ->
      State#state{tail = Bin}.
  %% Forward one message body to client C according to its type byte:
  %% $N notice, $S parameter status, $E error, $A notification, and any other
  %% type as a raw {Type, Body} event.
  dispatch_message($N, Body, C) ->
      gen_fsm:send_all_state_event(C, {notice, decode_error(Body)});
  dispatch_message($S, Body, C) ->
      [Name, Value] = decode_strings(Body),
      gen_fsm:send_all_state_event(C, {parameter_status, Name, Value});
  dispatch_message($E, Body, C) ->
      gen_fsm:send_event(C, {error, decode_error(Body)});
  dispatch_message($A, Body, C) ->
      <<Pid:?int32, Strings/binary>> = Body,
      {Channel, Payload} =
          case decode_strings(Strings) of
              [Chan, Pay] -> {Chan, Pay};
              [Chan] -> {Chan, <<>>}
          end,
      gen_fsm:send_all_state_event(C, {notification, Channel, Pid, Payload});
  dispatch_message(Type, Body, C) ->
      gen_fsm:send_event(C, {Type, Body}).
  %% Decode a single NUL-terminated string.
  %% Returns {Str, Rest}, where Str is the bytes before the first 0 byte and
  %% Rest is everything after it. Crashes if Bin contains no 0 byte.
  %% binary:split/2 scans natively instead of rebuilding Str one byte at a
  %% time. Str and Rest are sub-binaries of Bin; callers that keep Str for a
  %% long time may want binary:copy/1.
  decode_string(Bin) ->
      [Str, Rest] = binary:split(Bin, <<0>>),
      {Str, Rest}.
  %% Decode a binary made of consecutive NUL-terminated strings into a list of
  %% binaries, in their original order. Crashes if the last string has no
  %% terminating 0 byte.
  decode_strings(<<>>) ->
      [];
  decode_strings(Bin) ->
      {Str, Rest} = decode_string(Bin),
      [Str | decode_strings(Rest)].
  %% Decode the field list of an error/notice body: repeated
  %% <<FieldType:8, Value, 0>> entries closed by a lone 0 byte.
  %% Returns [{FieldType, Value}] in reverse wire order.
  decode_fields(Bin) ->
      collect_fields(Bin, []).
  %% Accumulate fields until the closing 0 byte.
  collect_fields(<<0>>, Acc) -> Acc;
  collect_fields(<<Code, Tail/binary>>, Acc) ->
      {Value, Next} = decode_string(Tail),
      collect_fields(Next, [{Code, Value} | Acc]).
  %% Build an #error{} from an error/notice message body: severity ($S) is
  %% lower-cased into an atom, code ($C) and message ($M) are the raw field
  %% values, and detail/hint/position are collected into `extra`.
  decode_error(Bin) ->
      Fields = decode_fields(Bin),
      Get = fun(Code) -> proplists:get_value(Code, Fields) end,
      #error{severity = lower_atom(Get($S)),
             code = Get($C),
             message = Get($M),
             extra = decode_error_extra(Fields)}.
  %% Pick the optional detail ($D), hint ($H) and position ($P) fields out of
  %% Fields as {Name, Value} pairs. Absent fields are omitted, and the result
  %% lists them in reverse of the D, H, P lookup order.
  decode_error_extra(Fields) ->
      Wanted = [{$D, detail}, {$H, hint}, {$P, position}],
      lists:foldl(
          fun({Code, Name}, Acc) ->
                  case proplists:get_value(Code, Fields) of
                      undefined -> Acc;
                      Value -> [{Name, Value} | Acc]
                  end
          end,
          [],
          Wanted).
  %% Lower-case a string (binary or character list) and turn it into an atom.
  %% NOTE: list_to_atom/1 creates atoms, which are never garbage collected;
  %% only feed this values drawn from a small, bounded set.
  lower_atom(Name) when is_binary(Name); is_list(Name) ->
      Chars = if is_binary(Name) -> binary_to_list(Name); true -> Name end,
      list_to_atom(string:to_lower(Chars)).