pgsql_sock.erl 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. %%% Copyright (C) 2009 - Will Glozer. All rights reserved.
  2. -module(pgsql_sock).
  3. -behavior(gen_server).
  %% Client API. cancel takes only the socket process, so it is exported
  %% with arity 1 to match its definition.
  -export([start_link/4, cancel/1]).
  5. -export([decode_string/1, lower_atom/1]).
  6. -export([handle_call/3, handle_cast/2, handle_info/2]).
  7. -export([init/1, code_change/3, terminate/2]).
  8. -include("pgsql.hrl").
  9. -include("pgsql_binary.hrl").
  %% gen_server state.
  %%   c       - process that decoded backend messages are forwarded to
  %%   mod     - socket module currently in use (gen_tcp or ssl)
  %%   sock    - the connected socket
  %%   tail    - received bytes that do not yet form a complete message
  %%   decoder - not referenced by the code visible in this module
  %%   backend - {Pid, Key} pair used to build a cancel request
  -record(state, {c, mod, sock, tail = <<>>, decoder, backend}).
  11. %% -- client interface --
  %% Start a socket process linked to the caller.
  %%   C        - process that will receive the decoded backend messages
  %%   Host     - server address passed to gen_tcp:connect/3
  %%   Username - login role sent in the startup message
  %%   Opts     - proplist; init/1 reads the database, port and ssl keys
  %% Returns whatever gen_server:start_link/3 returns.
  start_link(C, Host, Username, Opts) ->
      %% init/1 matches a four-element argument list, so C must be included.
      gen_server:start_link(?MODULE, [C, Host, Username, Opts], []).
  %% Ask the socket process S to cancel the request running on its backend.
  %% Asynchronous: the cast is queued and ok is returned immediately.
  cancel(S) ->
      gen_server:cast(S, cancel).
  16. %% -- gen_server implementation --
  %% gen_server init callback: connect to the server, optionally upgrade the
  %% connection to SSL, switch the socket to active mode and queue the
  %% startup message.
  %%
  %% Opts is a proplist; the keys read here are:
  %%   database - database name; left out of the startup message when unset
  %%   port     - TCP port, defaults to 5432
  %%   ssl      - true (use SSL if the server accepts it) or required
  %% A failed connect, send or recv crashes init through the assertive
  %% matches below.
  init([C, Host, Username, Opts]) ->
      UserParam = ["user", 0, Username, 0],
      Params = case proplists:get_value(database, Opts, undefined) of
          undefined -> UserParam;
          Database  -> [UserParam | ["database", 0, Database, 0]]
      end,

      Port = proplists:get_value(port, Opts, 5432),
      SockOpts = [{active, false}, {packet, raw}, binary, {nodelay, true}],
      {ok, S} = gen_tcp:connect(Host, Port, SockOpts),

      %% c has to be stored: decode/2 forwards every decoded message to it.
      State = #state{
          c    = C,
          mod  = gen_tcp,
          sock = S,
          tail = <<>>},

      State2 = case proplists:get_value(ssl, Opts) of
          T when T == true; T == required ->
              %% SSL request packet; the server answers with a single byte
              %% that start_ssl/4 interprets.
              ok = gen_tcp:send(S, <<8:?int32, 80877103:?int32>>),
              {ok, <<Code>>} = gen_tcp:recv(S, 1),
              start_ssl(Code, T, Opts, State);
          _ ->
              State
      end,

      setopts(State2, [{active, true}]),

      %% No send/2 is visible in this module, so the startup packet is framed
      %% here (32-bit length that counts itself, then the payload) and handed
      %% to handle_cast({send, _}, _), which writes it to the socket once
      %% init has returned.
      Startup = iolist_to_binary([<<196608:32>>, Params, 0]),
      Packet = <<(byte_size(Startup) + 4):?int32, Startup/binary>>,
      gen_server:cast(self(), {send, Packet}),
      {ok, State2}.
  %% gen_server call callback. This server supports no synchronous calls:
  %% any call stops it with reason {unsupported_call, Request}.
  handle_call(Request, _Caller, St) ->
      StopReason = {unsupported_call, Request},
      {stop, StopReason, St}.
  %% gen_server cast callback.
  %%   {send, Data} - write Data to the socket unchanged
  %%   cancel       - open a second, plain TCP connection to the same peer
  %%                  and send a cancel request built from the stored
  %%                  backend {Pid, Key}; does nothing while no backend key
  %%                  is stored
  %% Any other cast stops the server with {unsupported_cast, Cast}.
  handle_cast({send, Data}, State) ->
      #state{mod = Mod, sock = Sock} = State,
      ok = Mod:send(Sock, Data),
      {noreply, State};

  handle_cast(cancel, State = #state{backend = {Pid, Key}}) ->
      {ok, {Addr, Port}} = peername(State),
      SockOpts = [{active, false}, {packet, raw}, binary],
      {ok, Sock} = gen_tcp:connect(Addr, Port, SockOpts),
      Msg = <<16:?int32, 80877102:?int32, Pid:?int32, Key:?int32>>,
      ok = gen_tcp:send(Sock, Msg),
      gen_tcp:close(Sock),
      {noreply, State};

  %% Without a backend key there is nothing to cancel; ignoring the request
  %% keeps the connection alive instead of stopping with unsupported_cast.
  handle_cast(cancel, State = #state{backend = undefined}) ->
      {noreply, State};

  handle_cast(Cast, State) ->
      {stop, {unsupported_cast, Cast}, State}.

  %% Peer address of the connection. inet:peername/1 only accepts plain
  %% sockets, so an SSL connection has to go through ssl:peername/1.
  peername(#state{mod = gen_tcp, sock = Sock}) ->
      inet:peername(Sock);
  peername(#state{mod = ssl, sock = Sock}) ->
      ssl:peername(Sock).
  %% gen_server info callback: handles the active-mode socket messages.
  %%   {tcp | ssl, Sock, Data}               - append Data to the buffered
  %%                                           tail and decode what is complete
  %%   {tcp_closed | ssl_closed, Sock}       - stop with sock_closed
  %%   {tcp_error | ssl_error, Sock, Reason} - stop with {sock_error, Reason}
  %% The data clause is guarded on its tag; unguarded it would also match the
  %% error tuples, making the error clause unreachable and feeding Reason
  %% into the binary buffer.
  handle_info({Tag, _Sock, Data}, #state{tail = Tail} = State)
    when Tag == tcp; Tag == ssl ->
      State2 = decode(<<Tail/binary, Data/binary>>, State),
      {noreply, State2};

  handle_info({Closed, _Sock}, State)
    when Closed == tcp_closed; Closed == ssl_closed ->
      {stop, sock_closed, State};

  handle_info({Error, _Sock, Reason}, State)
    when Error == tcp_error; Error == ssl_error ->
      {stop, {sock_error, Reason}, State}.
  %% gen_server terminate callback. Performs no cleanup and returns ok.
  terminate(_Why, _St) ->
      ok.
  %% gen_server code_change callback. The state is carried over unchanged.
  code_change(_FromVsn, St, _Extra) ->
      Result = {ok, St},
      Result.
  70. %% -- internal functions --
  %% Act on the one-byte answer the server gave to the SSL request.
  %%   $S - upgrade the existing socket with ssl:connect/2 and return a state
  %%        whose mod/sock refer to the SSL socket; a failed handshake exits
  %%        with {ssl_negotiation_failed, Reason}
  %%   $N - keep the plain socket when Flag is true; exit with
  %%        ssl_not_available when Flag is required
  start_ssl($S, _Flag, Opts, State) ->
      #state{sock = PlainSock} = State,
      Upgrade = ssl:connect(PlainSock, Opts),
      case Upgrade of
          {error, Why} ->
              exit({ssl_negotiation_failed, Why});
          {ok, SslSock} ->
              State#state{mod = ssl, sock = SslSock}
      end;
  start_ssl($N, Flag, _Opts, State) ->
      case Flag of
          required -> exit(ssl_not_available);
          true     -> State
      end.
  %% Apply socket options through the setopts function that belongs to the
  %% socket module stored in the state (inet for gen_tcp sockets, ssl for
  %% SSL sockets). Returns that function's result.
  setopts(#state{mod = Mod, sock = Sock}, Opts) ->
      Apply = case Mod of
          gen_tcp -> fun inet:setopts/2;
          ssl     -> fun ssl:setopts/2
      end,
      Apply(Sock, Opts).
  %% Split Bin into complete backend messages and dispatch each of them.
  %% A message is a one-byte type, a 32-bit length that counts itself but not
  %% the type byte, and Len - 4 bytes of payload. Bytes that do not yet form
  %% a complete message are kept in the tail field of the returned state.
  decode(<<Type:8, Len:?int32, Rest/binary>> = Bin, State) ->
      Len2 = Len - 4,
      case Rest of
          <<Data:Len2/binary, Tail/binary>> ->
              decode(Tail, dispatch(Type, Data, State));
          _Incomplete ->
              State#state{tail = Bin}
      end;
  decode(Bin, State) ->
      State#state{tail = Bin}.

  %% Handle one complete message of the given Type and return the state.
  %%   $N - decoded and sent to c as {notice, Error} (all-state event)
  %%   $S - sent to c as {parameter_status, Name, Value} (all-state event)
  %%   $E - decoded and sent to c as {error, Error}
  %%   $A - sent to c as {notification, Channel, Pid, Payload}; Payload is
  %%        <<>> when the message carries only a channel name
  %%   $K - an eight-byte payload is stored as backend {Pid, Key}, which the
  %%        cancel cast needs, and the message is still forwarded to c
  %%   anything else is forwarded to c as {Type, Data}
  dispatch($N, Data, #state{c = C} = State) ->
      gen_fsm:send_all_state_event(C, {notice, decode_error(Data)}),
      State;
  dispatch($S, Data, #state{c = C} = State) ->
      [Name, Value] = decode_strings(Data),
      gen_fsm:send_all_state_event(C, {parameter_status, Name, Value}),
      State;
  dispatch($E, Data, #state{c = C} = State) ->
      gen_fsm:send_event(C, {error, decode_error(Data)}),
      State;
  dispatch($A, Data, #state{c = C} = State) ->
      <<Pid:?int32, Strings/binary>> = Data,
      {Channel, Payload} = case decode_strings(Strings) of
          [Ch, P] -> {Ch, P};
          [Ch]    -> {Ch, <<>>}
      end,
      gen_fsm:send_all_state_event(C, {notification, Channel, Pid, Payload}),
      State;
  dispatch($K = Type, <<Pid:?int32, Key:?int32>> = Data, #state{c = C} = State) ->
      gen_fsm:send_event(C, {Type, Data}),
      State#state{backend = {Pid, Key}};
  dispatch(Type, Data, #state{c = C} = State) ->
      gen_fsm:send_event(C, {Type, Data}),
      State.
  %% Decode a single null-terminated string from the front of Bin.
  %% Returns {Str, Rest}: Str is everything before the first zero byte and
  %% Rest is everything after it. Crashes with badmatch when Bin contains no
  %% zero byte.
  decode_string(Bin) ->
      %% binary:split/2 splits at the first terminator only, which replaces
      %% the former byte-by-byte accumulation loop.
      [Str, Rest] = binary:split(Bin, <<0>>),
      {Str, Rest}.
  %% Decode a sequence of null-terminated strings that fills Bin completely.
  %% Returns the strings in the order in which they appear.
  decode_strings(Bin) ->
      decode_strings(Bin, []).

  %% Accumulator loop: Found holds the strings decoded so far, newest first.
  decode_strings(<<>>, Found) ->
      lists:reverse(Found);
  decode_strings(Remaining, Found) ->
      {Next, Unparsed} = decode_string(Remaining),
      decode_strings(Unparsed, [Next | Found]).
  %% Decode the fields of an error or notice payload. Each field is a
  %% one-byte type code followed by a null-terminated string; a lone zero
  %% byte ends the list. Returns [{TypeCode, String}] with the last field of
  %% the payload first (the accumulator is not reversed).
  decode_fields(Bin) ->
      decode_fields(Bin, []).

  decode_fields(<<0>>, Decoded) ->
      Decoded;
  decode_fields(<<Code:8, AfterCode/binary>>, Decoded) ->
      {Text, Remaining} = decode_string(AfterCode),
      Field = {Code, Text},
      decode_fields(Remaining, [Field | Decoded]).
  %% Decode an error or notice payload into an #error{} record.
  %% severity is the $S field converted to a lower-case atom, code and
  %% message are the $C and $M fields, and extra is the proplist built by
  %% decode_error_extra/1.
  decode_error(Bin) ->
      Fields = decode_fields(Bin),
      Severity = lower_atom(proplists:get_value($S, Fields)),
      Code = proplists:get_value($C, Fields),
      Message = proplists:get_value($M, Fields),
      #error{
          severity = Severity,
          code     = Code,
          message  = Message,
          extra    = decode_error_extra(Fields)}.
  %% Collect the optional error fields that are present in Fields.
  %% $D, $H and $P are looked up in that order and stored under the names
  %% detail, hint and position; each hit is prepended, so the result lists
  %% the field found last first. Fields that are absent are skipped.
  decode_error_extra(Fields) ->
      Known = [{$D, detail}, {$H, hint}, {$P, position}],
      Collect = fun({Code, Label}, Found) ->
          case proplists:get_value(Code, Fields) of
              undefined -> Found;
              Text      -> [{Label, Text} | Found]
          end
      end,
      lists:foldl(Collect, [], Known).
  %% Convert a string, given as a binary or as a character list, into an atom
  %% made of its lower-cased characters.
  lower_atom(Text) when is_binary(Text) ->
      Chars = binary_to_list(Text),
      lower_atom(Chars);
  lower_atom(Text) when is_list(Text) ->
      Lowered = string:to_lower(Text),
      list_to_atom(Lowered).