%%% pgsql_sock.erl
%%% Copyright (C) 2009 - Will Glozer. All rights reserved.
  2. -module(pgsql_sock).
  3. -behavior(gen_server).
  4. -export([start_link/4, cancel/1]).
  5. -export([handle_call/3, handle_cast/2, handle_info/2]).
  6. -export([init/1, code_change/3, terminate/2]).
  7. -include("pgsql.hrl").
  8. -include("pgsql_binary.hrl").
  9. -record(state, {mod,
  10. sock,
  11. data,
  12. backend,
  13. on_message,
  14. on_timeout,
  15. ready,
  16. timeout}).
%% -- client interface --
  18. start_link(Host, Username, Password, Opts) ->
  19. gen_server:start_link(?MODULE, [Host, Username, Password, Opts], []).
  20. cancel(S) ->
  21. gen_server:cast(S, cancel).
%% -- gen_server implementation --
  23. init([Host, Username, Password, Opts]) ->
  24. gen_server:cast(self(), {connect, Host, Username, Password, Opts}),
  25. %% TODO split connect/query timeout?
  26. Timeout = proplists:get_value(timeout, Opts, 5000),
  27. {ok, #state{timeout = Timeout}}.
  28. handle_call(Call, _From, State) ->
  29. {stop, {unsupported_call, Call}, State}.
  30. handle_cast({connect, Host, Username, Password, Opts},
  31. #state{timeout = Timeout} = State) ->
  32. Port = proplists:get_value(port, Opts, 5432),
  33. SockOpts = [{active, false}, {packet, raw}, binary, {nodelay, true}],
  34. {ok, Sock} = gen_tcp:connect(Host, Port, SockOpts, Timeout),
  35. State2 = case proplists:get_value(ssl, Opts) of
  36. T when T == true; T == required ->
  37. start_ssl(Sock, T, Opts, State);
  38. _ ->
  39. State#state{mod = gen_tcp, sock = Sock}
  40. end,
  41. Opts2 = ["user", 0, Username, 0],
  42. case proplists:get_value(database, Opts, undefined) of
  43. undefined -> Opts3 = Opts2;
  44. Database -> Opts3 = [Opts2 | ["database", 0, Database, 0]]
  45. end,
  46. send(State2, [<<196608:?int32>>, Opts3, 0]),
  47. %% TODO Async = proplists:get_value(async, Opts, undefined),
  48. setopts(State2, [{active, true}]),
  49. {noreply,
  50. State2#state{on_message = fun(M, S) ->
  51. auth(Username, Password, M, S)
  52. end,
  53. on_timeout = fun auth_timeout/1},
  54. Timeout};
  55. handle_cast(cancel, State = #state{backend = {Pid, Key}}) ->
  56. {ok, {Addr, Port}} = inet:peername(State#state.sock),
  57. SockOpts = [{active, false}, {packet, raw}, binary],
  58. {ok, Sock} = gen_tcp:connect(Addr, Port, SockOpts),
  59. Msg = <<16:?int32, 80877102:?int32, Pid:?int32, Key:?int32>>,
  60. ok = gen_tcp:send(Sock, Msg),
  61. gen_tcp:close(Sock),
  62. {noreply, State}.
  63. handle_info({Closed, Sock}, #state{sock = Sock} = State)
  64. when Closed == tcp_closed; Closed == ssl_closed ->
  65. {stop, sock_closed, State};
  66. handle_info({Error, Sock, Reason}, #state{sock = Sock} = State)
  67. when Error == tcp_error; Error == ssl_error ->
  68. {stop, {sock_error, Reason}, State};
  69. handle_info(timeout, #state{on_timeout = OnTimeout} = State) ->
  70. OnTimeout(State);
  71. handle_info({_, Sock, Data2}, #state{data = Data, sock = Sock} = State) ->
  72. loop(State#state{data = <<Data/binary, Data2/binary>>}, infinity).
  73. loop(#state{data = Data, on_message = OnMessage} = State, Timeout) ->
  74. case pgsql_wire:decode_message(Data) of
  75. {Message, Tail} ->
  76. case OnMessage(Message, State#state{data = Tail}) of
  77. {noreply, State2} ->
  78. loop(State2, infinity);
  79. {noreply, State2, Timeout2} ->
  80. loop(State2, Timeout2);
  81. R = {stop, _Reason2, _State2} ->
  82. R
  83. end;
  84. _ ->
  85. {noreply, State, Timeout}
  86. end.
  87. terminate(_Reason, _State) ->
  88. ok.
  89. code_change(_OldVsn, State, _Extra) ->
  90. {ok, State}.
%% -- internal functions --
  92. start_ssl(S, Flag, Opts, #state{timeout = Timeout} = State) ->
  93. ok = gen_tcp:send(S, <<8:?int32, 80877103:?int32>>),
  94. {ok, <<Code>>} = gen_tcp:recv(S, 1, Timeout),
  95. case Code of
  96. $S ->
  97. case ssl:connect(S, Opts, Timeout) of
  98. {ok, S2} -> State#state{mod = ssl, sock = S2};
  99. {error, Reason} -> exit({ssl_negotiation_failed, Reason})
  100. end;
  101. $N ->
  102. case Flag of
  103. true -> State;
  104. required -> exit(ssl_not_available)
  105. end
  106. end.
  107. setopts(#state{mod = Mod, sock = Sock}, Opts) ->
  108. case Mod of
  109. gen_tcp -> inet:setopts(Sock, Opts);
  110. ssl -> ssl:setopts(Sock, Opts)
  111. end.
  112. send(#state{mod = Mod, sock = Sock}, Data) ->
  113. Mod:send(Sock, pgsql_wire:encode(Data)).
  114. send(#state{mod = Mod, sock = Sock}, Type, Data) ->
  115. Mod:send(Sock, pgsql_wire:encode(Type, Data)).
%% -- backend message handling --
  117. %% AuthenticationOk
  118. auth(_Username, _Password, {$R, <<0:?int32>>}, State) ->
  119. #state{timeout = Timeout} = State,
  120. {State#state{on_message = fun initializing/2}, Timeout};
  121. %% AuthenticationCleartextPassword
  122. auth(_Username, Password, {$R, <<3:?int32>>}, State) ->
  123. #state{timeout = Timeout} = State,
  124. send(State, $p, [Password, 0]),
  125. {noreply, State, Timeout};
  126. %% AuthenticationMD5Password
  127. auth(Username, Password, {$R, <<5:?int32, Salt:4/binary>>}, State) ->
  128. #state{timeout = Timeout} = State,
  129. Digest1 = hex(erlang:md5([Password, Username])),
  130. Str = ["md5", hex(erlang:md5([Digest1, Salt])), 0],
  131. send(State, $p, Str),
  132. {noreply, State, Timeout};
  133. auth(_Username, _Password, {$R, <<M:?int32, _/binary>>}, State) ->
  134. case M of
  135. 2 -> Method = kerberosV5;
  136. 4 -> Method = crypt;
  137. 6 -> Method = scm;
  138. 7 -> Method = gss;
  139. 8 -> Method = sspi;
  140. _ -> Method = unknown
  141. end,
  142. Error = {error, {unsupported_auth_method, Method}},
  143. %% TODO send error response
  144. {stop, Error, State};
  145. %% ErrorResponse
  146. %% TODO who decodes error ?
  147. auth(_Username, _Password, {error, E}, State) ->
  148. case E#error.code of
  149. <<"28000">> -> Why = invalid_authorization_specification;
  150. <<"28P01">> -> Why = invalid_password;
  151. Any -> Why = Any
  152. end,
  153. %% TODO send error response
  154. {stop, {error, Why}, State}.
  155. auth_timeout(State) ->
  156. %% TODO send error response
  157. {stop, {error, timeout}, State}.
  158. initializing(_, State) ->
  159. %% TODO incomplete
  160. {noreply, State#state{on_message = fun on_message/2}}.
  161. on_message({$N, Data}, State) ->
  162. %% TODO use it
  163. {notice, pgsql_wire:decode_error(Data)},
  164. {infinity, State};
  165. on_message({$S, Data}, State) ->
  166. [Name, Value] = pgsql_wire:decode_strings(Data),
  167. %% TODO use it
  168. {parameter_status, Name, Value},
  169. {infinity, State};
  170. on_message({$E, Data}, State) ->
  171. %% TODO use it
  172. {error, pgsql_wire:decode_error(Data)},
  173. {infinity, State};
  174. on_message({$A, <<Pid:?int32, Strings/binary>>}, State) ->
  175. case pgsql_wire:decode_strings(Strings) of
  176. [Channel, Payload] -> ok;
  177. [Channel] -> Payload = <<>>
  178. end,
  179. %% TODO use it
  180. {notification, Channel, Pid, Payload},
  181. {infinity, State};
  182. on_message(_Msg, State) ->
  183. {infinity, State}.
  184. hex(Bin) ->
  185. HChar = fun(N) when N < 10 -> $0 + N;
  186. (N) when N < 16 -> $W + N
  187. end,
  188. <<<<(HChar(H)), (HChar(L))>> || <<H:4, L:4>> <= Bin>>.