pgsql_sock.erl 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. %%% Copyright (C) 2009 - Will Glozer. All rights reserved.
  2. -module(pgsql_sock).
  3. -behavior(gen_server).
  4. -export([start_link/0, cancel/1]).
  5. -export([handle_call/3, handle_cast/2, handle_info/2]).
  6. -export([init/1, code_change/3, terminate/2]).
  7. -include("pgsql.hrl").
  8. -include("pgsql_binary.hrl").
  %% Connection state carried through every gen_server callback.
  %%
  %% `data', `timeout' and `parameters' carry defaults because they are
  %% used before anything assigns them: `data' is appended to with binary
  %% syntax, `parameters' is passed to lists:keystore/4, and `timeout' is
  %% returned as a gen_server timeout. None of those accept `undefined'.
  -record(state, {mod,                 % socket module: gen_tcp or ssl
                  sock,                % connected socket
                  data = <<>>,         % received bytes not yet decoded into a message
                  backend,             % {Pid, Key} taken from the $K message
                  handler,             % atom naming the current handler: auth | initializing | on_message
                  queue = queue:new(), % callers waiting for a reply
                  async,               % value of the `async' connect option
                  ready,               % set to true once the $Z message arrives
                  timeout = 5000,      % milliseconds; same default as the `timeout' connect option
                  parameters = [],     % [{Name, Value}] collected from $S messages
                  txstatus}).          % status byte carried by the $Z message
  20. %% -- client interface --
  %% Start an unregistered, linked connection process with no init
  %% arguments and no gen_server options.
  start_link() ->
      InitArgs = [],
      ServerOpts = [],
      gen_server:start_link(?MODULE, InitArgs, ServerOpts).
  %% Ask connection process S to cancel; sent as a cast, so the
  %% call returns immediately without waiting for the outcome.
  cancel(S) ->
      Request = cancel,
      gen_server:cast(S, Request).
  25. %% -- gen_server implementation --
  %% gen_server init callback: begin with a state record holding only
  %% its declared defaults. No socket is opened here; that happens when
  %% the `connect' call is handled.
  init([]) ->
      InitialState = #state{},
      {ok, InitialState}.
  %% Handle the `connect' request.
  %%
  %% Opens a TCP connection to Host, optionally upgrades it to SSL, sends
  %% the startup message and switches the socket to active mode. The caller
  %% is not answered here: From is queued and the reply is produced later
  %% by the `auth' / `initializing' handlers.
  %%
  %% Options read from Opts:
  %%   timeout  - milliseconds, default 5000
  %%   port     - default 5432
  %%   ssl      - `true' or `required' triggers start_ssl/4
  %%   database - optional database name
  %%   async    - stored in the state as-is
  %%
  %% A failed gen_tcp:connect/4 crashes the process with a badmatch.
  handle_call({connect, Host, Username, Password, Opts},
              From,
              #state{queue = Queue} = State) ->
      %% TODO split connect/query timeout?
      Timeout = proplists:get_value(timeout, Opts, 5000),
      Port = proplists:get_value(port, Opts, 5432),
      SockOpts = [{active, false}, {packet, raw}, binary, {nodelay, true}],
      {ok, Sock} = gen_tcp:connect(Host, Port, SockOpts, Timeout),
      %% Store the timeout before anything reads #state.timeout: start_ssl/4
      %% and the auth/initializing handlers all take their timeout from the
      %% state record, not from Opts.
      State1 = State#state{timeout = Timeout},
      State2 = case proplists:get_value(ssl, Opts) of
                   T when T == true; T == required ->
                       start_ssl(Sock, T, Opts, State1);
                   _ ->
                       State1#state{mod = gen_tcp, sock = Sock}
               end,
      %% Startup parameters are NUL-terminated name/value pairs (iolist).
      UserParam = ["user", 0, Username, 0],
      StartupParams = case proplists:get_value(database, Opts, undefined) of
                          undefined -> UserParam;
                          Database  -> [UserParam | ["database", 0, Database, 0]]
                      end,
      %% 196608 is 3 bsl 16: protocol version 3.0 in the startup packet.
      send(State2, [<<196608:?int32>>, StartupParams, 0]),
      Async = proplists:get_value(async, Opts, undefined),
      setopts(State2, [{active, true}]),
      %% Credentials are kept in the process dictionary for the auth
      %% handler; they are erased once the $Z message is handled.
      put(username, Username),
      put(password, Password),
      {noreply,
       State2#state{handler = auth,
                    queue = queue:in(From, Queue),
                    async = Async},
       Timeout}.
  %% Handle the `cancel' cast.
  %%
  %% Opens a second, short-lived plain TCP connection to the address and
  %% port of the current socket and sends a 16-byte message: length 16,
  %% code 80877102, then the backend Pid and Key stored in the state.
  %% The main connection and the state are left untouched.
  %%
  %% Only matches once `backend' holds a {Pid, Key} pair.
  handle_cast(cancel, #state{mod = Mod, sock = ConnSock, backend = {Pid, Key}} = State) ->
      %% The peer address must be read through the module that owns the
      %% socket: inet:peername/1 does not accept an ssl socket.
      {ok, {Addr, Port}} = case Mod of
                               gen_tcp -> inet:peername(ConnSock);
                               ssl     -> ssl:peername(ConnSock)
                           end,
      SockOpts = [{active, false}, {packet, raw}, binary],
      {ok, CancelSock} = gen_tcp:connect(Addr, Port, SockOpts),
      Msg = <<16:?int32, 80877102:?int32, Pid:?int32, Key:?int32>>,
      ok = gen_tcp:send(CancelSock, Msg),
      gen_tcp:close(CancelSock),
      {noreply, State}.
  %% gen_server info callback.
  %%
  %%   {tcp_closed | ssl_closed, Sock}        -> stop with sock_closed
  %%   {tcp_error | ssl_error, Sock, Reason}  -> stop with {sock_error, Reason}
  %%   timeout                                -> forwarded to the current handler
  %%   {_, Sock, Data}                        -> appended to the receive buffer,
  %%                                             then decoded by loop/2
  %%
  %% All socket clauses only match messages for the socket held in the state.
  handle_info({Closed, Sock}, #state{sock = Sock} = State)
    when Closed == tcp_closed; Closed == ssl_closed ->
      {stop, sock_closed, State};
  handle_info({Error, Sock, Reason}, #state{sock = Sock} = State)
    when Error == tcp_error; Error == ssl_error ->
      {stop, {sock_error, Reason}, State};
  handle_info(timeout, #state{handler = Handler} = State) ->
      %% Handler is an atom, not a fun, so it cannot be applied directly.
      dispatch(Handler, timeout, State);
  handle_info({_, Sock, Data2}, #state{data = Data, sock = Sock} = State) ->
      loop(State#state{data = <<Data/binary, Data2/binary>>}, infinity).

  %% Decode and handle every complete message in the receive buffer.
  %%
  %% Each decoded message is passed to the handler named in the state; the
  %% handler may change the state (including the handler itself) before the
  %% next message is decoded. Stops early when a handler returns a `stop'
  %% tuple. When pgsql_wire:decode_message/1 returns anything other than
  %% {Message, Tail}, the current state is returned with the most recent
  %% timeout.
  loop(#state{data = Data, handler = Handler} = State, Timeout) ->
      case pgsql_wire:decode_message(Data) of
          {Message, Tail} ->
              case dispatch(Handler, Message, State#state{data = Tail}) of
                  {noreply, State2} ->
                      loop(State2, infinity);
                  {noreply, State2, Timeout2} ->
                      loop(State2, Timeout2);
                  R = {stop, _Reason2, _State2} ->
                      R
              end;
          _ ->
              {noreply, State, Timeout}
      end.

  %% Route a message to the handler named by the atom in #state.handler.
  %% Local calls are used so the handlers do not have to be exported.
  dispatch(auth, Message, State)         -> auth(Message, State);
  dispatch(initializing, Message, State) -> initializing(Message, State);
  dispatch(on_message, Message, State)   -> on_message(Message, State).
  %% gen_server terminate callback: performs no cleanup of its own and
  %% ignores both the reason and the final state.
  terminate(_Why, _FinalState) ->
      ok.
  %% gen_server code_change callback: the state is carried over unchanged
  %% across a code upgrade.
  code_change(_FromVsn, CurrentState, _UpgradeExtra) ->
      {ok, CurrentState}.
  93. %% -- internal functions --
  %% Try to upgrade TCP socket S to SSL.
  %%
  %% Sends an 8-byte request (length 8, code 80877103) and reads a one-byte
  %% answer:
  %%   $S - the server accepts; run ssl:connect/3 and return the state with
  %%        mod = ssl and the upgraded socket. Exits with
  %%        {ssl_negotiation_failed, Reason} if the handshake fails.
  %%   $N - the server refuses; with Flag = required exit with
  %%        ssl_not_available, with Flag = true carry on over plain TCP.
  %%
  %% Always returns a state whose `mod' and `sock' are set, so the caller
  %% can send on it straight away.
  start_ssl(S, Flag, Opts, #state{timeout = StateTimeout} = State) ->
      %% The state may not carry a timeout yet; fall back to the connect
      %% option (and its default) rather than passing `undefined' on.
      Timeout = case StateTimeout of
                    undefined -> proplists:get_value(timeout, Opts, 5000);
                    _         -> StateTimeout
                end,
      ok = gen_tcp:send(S, <<8:?int32, 80877103:?int32>>),
      {ok, <<Code>>} = gen_tcp:recv(S, 1, Timeout),
      case Code of
          $S ->
              %% NOTE(review): the whole connect option list is handed to
              %% ssl:connect/3 - confirm it contains only ssl options.
              case ssl:connect(S, Opts, Timeout) of
                  {ok, S2}        -> State#state{mod = ssl, sock = S2};
                  {error, Reason} -> exit({ssl_negotiation_failed, Reason})
              end;
          $N ->
              case Flag of
                  %% SSL was optional: keep using the plain TCP socket.
                  true     -> State#state{mod = gen_tcp, sock = S};
                  required -> exit(ssl_not_available)
              end
      end.
  %% Apply socket options Opts to the connection's socket, using
  %% inet:setopts/2 for a gen_tcp socket and ssl:setopts/2 for an ssl one.
  setopts(State, Opts) ->
      Sock = State#state.sock,
      case State#state.mod of
          ssl     -> ssl:setopts(Sock, Opts);
          gen_tcp -> inet:setopts(Sock, Opts)
      end.
  %% Encode Data with pgsql_wire:encode/1 and write it to the connection's
  %% socket through the socket module stored in the state.
  send(#state{mod = Mod, sock = Sock}, Data) ->
      Packet = pgsql_wire:encode(Data),
      Mod:send(Sock, Packet).
  %% Encode a message of the given Type with pgsql_wire:encode/2 and write
  %% it to the connection's socket through the socket module in the state.
  send(#state{mod = Mod, sock = Sock}, Type, Data) ->
      Packet = pgsql_wire:encode(Type, Data),
      Mod:send(Sock, Packet).
  %% Answer the oldest waiting caller with Message and drop it from the
  %% queue; returns the updated state.
  %%
  %% Queue entries are the `From' tags stored by handle_call/3, so the
  %% answer has to go through gen_server:reply/2: a bare `Pid ! Message'
  %% is never matched by the gen_server:call the caller is blocked in.
  %%
  %% Crashes with a badmatch if the queue is empty.
  reply(#state{queue = Q} = State, Message) ->
      {{value, From}, Q2} = queue:out(Q),
      gen_server:reply(From, Message),
      State#state{queue = Q2}.
  122. %% -- backend message handling --
  %% Handler used while authenticating. Receives decoded backend messages
  %% (and `timeout') and returns a gen_server-style result tuple.

  %% AuthenticationOk: move on to the `initializing' handler.
  auth({$R, <<0:?int32>>}, State) ->
      Timeout = State#state.timeout,
      {noreply, State#state{handler = initializing}, Timeout};

  %% AuthenticationCleartextPassword: send the stored password,
  %% NUL-terminated.
  auth({$R, <<3:?int32>>}, State) ->
      Timeout = State#state.timeout,
      send(State, $p, [get(password), 0]),
      {noreply, State, Timeout};

  %% AuthenticationMD5Password: send "md5" followed by
  %% hex(md5(hex(md5(password ++ username)) ++ salt)), NUL-terminated.
  auth({$R, <<5:?int32, Salt:4/binary>>}, State) ->
      Timeout = State#state.timeout,
      Credentials = [get(password), get(username)],
      InnerDigest = hex(erlang:md5(Credentials)),
      OuterDigest = hex(erlang:md5([InnerDigest, Salt])),
      send(State, $p, ["md5", OuterDigest, 0]),
      {noreply, State, Timeout};

  %% Any other authentication request is unsupported: tell the waiting
  %% caller and stop.
  auth({$R, <<M:?int32, _/binary>>}, State) ->
      Method = case M of
                   2 -> kerberosV5;
                   4 -> crypt;
                   6 -> scm;
                   7 -> gss;
                   8 -> sspi;
                   _ -> unknown
               end,
      Error = {error, {unsupported_auth_method, Method}},
      {stop, Error, reply(State, Error)};

  %% ErrorResponse: map two known codes to atoms, pass any other code
  %% through unchanged; tell the waiting caller and stop.
  auth({error, E}, State) ->
      Why = case E#error.code of
                <<"28000">> -> invalid_authorization_specification;
                <<"28P01">> -> invalid_password;
                Any         -> Any
            end,
      Error = {error, Why},
      {stop, Error, reply(State, Error)};

  %% No answer within the timeout: tell the waiting caller and stop.
  auth(timeout, State) ->
      Error = {error, timeout},
      {stop, Error, reply(State, Error)};

  %% Everything else goes to the general message handler.
  auth(Other, State) ->
      on_message(Other, State).
  %% Handler used after authentication succeeded and until the backend
  %% reports it is ready.

  %% BackendKeyData: remember the backend's Pid and Key (used by cancel).
  initializing({$K, <<Pid:?int32, Key:?int32>>}, State) ->
      Timeout = State#state.timeout,
      {noreply, State#state{backend = {Pid, Key}}, Timeout};

  %% No answer within the timeout: tell the waiting caller and stop.
  initializing(timeout, State) ->
      Error = {error, timeout},
      {stop, Error, reply(State, Error)};

  %% ReadyForQuery: drop the stored credentials, pick the datetime codec,
  %% switch to the general handler and answer the caller with {ok, self()}.
  initializing({$Z, <<Status:8>>}, State) ->
      Parameters = State#state.parameters,
      erase(username),
      erase(password),
      %% TODO decode dates to now() format
      %% Only "on" and "off" are handled; a missing parameter or any other
      %% value crashes with case_clause.
      DatetimeMod = case lists:keysearch(<<"integer_datetimes">>, 1, Parameters) of
                        {value, {_, <<"on">>}}  -> pgsql_idatetime;
                        {value, {_, <<"off">>}} -> pgsql_fdatetime
                    end,
      put(datetime_mod, DatetimeMod),
      Ready = State#state{handler = on_message,
                          txstatus = Status,
                          ready = true},
      {noreply, reply(Ready, {ok, self()})};

  %% Error during start-up: pass it to the waiting caller and stop.
  initializing({error, _} = Error, State) ->
      {stop, Error, reply(State, Error)};

  %% Everything else goes to the general message handler.
  initializing(Other, State) ->
      on_message(Other, State).
  %% General backend message handler, also used as the fallback of the
  %% `auth' and `initializing' handlers.
  %%
  %% Every clause returns {noreply, State}: the result is consumed by
  %% loop/2 and, for timeouts, returned straight to gen_server, and
  %% neither accepts any other two-element tuple.

  %% $N message: decoded as a notice; the result is not used yet.
  on_message({$N, Data}, State) ->
      %% TODO use it
      _Notice = {notice, pgsql_wire:decode_error(Data)},
      {noreply, State};

  %% ParameterStatus: store or replace the {Name, Value} pair.
  on_message({$S, Data}, State) ->
      [Name, Value] = pgsql_wire:decode_strings(Data),
      Parameters2 = lists:keystore(Name, 1, State#state.parameters,
                                   {Name, Value}),
      {noreply, State#state{parameters = Parameters2}};

  %% $E message: decoded as an error; the result is not used yet.
  on_message({$E, Data}, State) ->
      %% TODO use it
      _Error = {error, pgsql_wire:decode_error(Data)},
      {noreply, State};

  %% $A message: a sender Pid followed by a channel name and an optional
  %% payload; the decoded notification is not used yet.
  on_message({$A, <<Pid:?int32, Strings/binary>>}, State) ->
      {Channel, Payload} = case pgsql_wire:decode_strings(Strings) of
                               [C, P] -> {C, P};
                               [C]    -> {C, <<>>}
                           end,
      %% TODO use it
      _Notification = {notification, Channel, Pid, Payload},
      {noreply, State};

  %% Anything else (including `timeout') is ignored.
  on_message(_Msg, State) ->
      {noreply, State}.
  %% Render a binary as lowercase hexadecimal text: two characters per
  %% input byte, returned as a binary.
  hex(Bin) ->
      Digit = fun
                  (Nibble) when Nibble < 10 -> Nibble + $0;
                  %% $W + 10 =:= $a, so 10..15 map onto $a..$f
                  (Nibble) when Nibble < 16 -> Nibble + $W
              end,
      << <<(Digit(Hi)):8, (Digit(Lo)):8>> || <<Hi:4, Lo:4>> <= Bin >>.