pgsql_sock.erl 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. %%% Copyright (C) 2009 - Will Glozer. All rights reserved.
  -module(pgsql_sock).
  -behaviour(gen_server).

  %% Client API: connection lifecycle.
  -export([start_link/0, connect/5, close/1, cancel/1, get_parameter/2]).
  %% Client API: query requests forwarded to the connection process.
  -export([squery/2, parse/4, bind/4, execute/4, describe/3, close/3, sync/1]).
  %% gen_server callbacks.
  -export([init/1, handle_call/3, handle_cast/2, handle_info/2,
           terminate/2, code_change/3]).
  %% Backend-message handlers. Exported because the current handler is
  %% invoked dynamically as ?MODULE:Handler(Message, State).
  -export([auth/2, initializing/2, on_message/2]).

  -include("pgsql.hrl").
  -include("pgsql_binary.hrl").

  %% Per-connection server state.
  -record(state, {mod,                 % transport module: gen_tcp | ssl
                  sock,                % socket used through `mod'
                  data = <<>>,         % received bytes not yet decoded
                  backend,             % {Pid, Key} from BackendKeyData
                  handler,             % name of the current message handler
                  queue = queue:new(), % From tags of callers awaiting a reply
                  async,               % pid sent {pgsql, self(), Msg} events
                  ready,               % set to true at the first ReadyForQuery
                  timeout = 5000,      % ms, used during connection setup
                  parameters = [],     % [{Name, Value}] from ParameterStatus
                  txstatus}).          % status byte from ReadyForQuery
  %% -- client interface --

  %% Spawn a new, not yet connected, connection process linked to the
  %% caller.
  start_link() ->
      Args = [],
      Options = [],
      gen_server:start_link(?MODULE, Args, Options).

  %% Connect process C to the server at Host and authenticate as
  %% Username/Password. Waits without a client-side timeout; the
  %% connection process reads its own settings (port, timeout, ssl, ...)
  %% from Opts.
  connect(C, Host, Username, Password, Opts) ->
      Request = {connect, Host, Username, Password, Opts},
      gen_server:call(C, Request, infinity).
  %% Stop connection process C and return ok.
  %% gen_server:call/3 signals failure by exiting (e.g. noproc when C is
  %% already dead, or a normal exit racing the reply); those exits are
  %% deliberately ignored so close/1 is safe to call more than once.
  %% Unlike the old-style `catch', this does not hide errors or throws.
  close(C) when is_pid(C) ->
      try
          gen_server:call(C, stop, infinity)
      catch
          exit:_ -> ok
      end,
      ok.
  %% Fetch server parameter Name (string or binary) from connection C.
  %% Unlike the other requests this uses the default gen_server call
  %% timeout.
  get_parameter(C, Name) ->
      gen_server:call(C, {get_parameter, to_binary(Name)}).

  %% Query requests; none of them has a client-side timeout.
  squery(C, Sql) ->
      call_infinity(C, {squery, Sql}).

  parse(C, Name, Sql, Types) ->
      call_infinity(C, {parse, Name, Sql, Types}).

  bind(C, Statement, PortalName, Parameters) ->
      call_infinity(C, {bind, Statement, PortalName, Parameters}).

  execute(C, Statement, PortalName, MaxRows) ->
      call_infinity(C, {execute, Statement, PortalName, MaxRows}).

  describe(C, Type, Name) ->
      call_infinity(C, {describe, Type, Name}).

  close(C, Type, Name) ->
      call_infinity(C, {close, Type, Name}).

  sync(C) ->
      call_infinity(C, sync).

  %% Ask connection process C, asynchronously, to cancel the running query.
  cancel(C) ->
      gen_server:cast(C, cancel).

  %% gen_server:call/3 without a timeout.
  call_infinity(C, Request) ->
      gen_server:call(C, Request, infinity).
  %% -- gen_server implementation --

  %% No socket is opened here; the process waits for a {connect, ...} call.
  init([]) ->
      InitialState = #state{},
      {ok, InitialState}.
  %% Synchronous requests.
  %%
  %% {connect, Host, Username, Password, Opts}
  %%   Opens the TCP connection, optionally negotiates SSL, sends the
  %%   StartupMessage and hands incoming messages to auth/2. The caller is
  %%   not answered here: its From is queued and replied to later through
  %%   gen_reply/2. If the TCP connect itself fails, {error, Reason} is
  %%   returned and the process stops normally.
  %%   Opts: timeout (ms, default 5000), port (default 5432),
  %%   ssl (true | required), database, async (pid for notifications).
  %% {get_parameter, Name}
  %%   Returns {ok, Value} for a ParameterStatus value reported by the
  %%   server, or {ok, undefined} if it was never reported.
  %% stop
  %%   Replies ok and stops the process normally.
  handle_call({connect, Host, Username, Password, Opts},
              From,
              #state{queue = Queue} = State) ->
      %% TODO split connect/query timeout?
      Timeout = proplists:get_value(timeout, Opts, 5000),
      Port = proplists:get_value(port, Opts, 5432),
      SockOpts = [{active, false}, {packet, raw}, binary, {nodelay, true}],
      case gen_tcp:connect(Host, Port, SockOpts, Timeout) of
          {ok, Sock} ->
              %% Keep the configured timeout in the state so start_ssl/4
              %% and the auth/initializing handlers use it rather than the
              %% record default.
              State1 = State#state{timeout = Timeout},
              State2 = case proplists:get_value(ssl, Opts) of
                           T when T == true; T == required ->
                               start_ssl(Sock, T, Opts, State1);
                           _ ->
                               State1#state{mod = gen_tcp, sock = Sock}
                       end,
              %% StartupMessage: protocol version 3.0 (196608 = 3 bsl 16),
              %% the key/value pairs, then a terminating zero byte.
              send(State2, [<<196608:?int32>>, startup_params(Username, Opts), 0]),
              setopts(State2, [{active, true}]),
              %% auth/2 reads the credentials from the process dictionary;
              %% initializing/2 erases them at ReadyForQuery.
              put(username, Username),
              put(password, Password),
              Async = proplists:get_value(async, Opts, undefined),
              {noreply,
               State2#state{handler = auth,
                            queue = queue:in(From, Queue),
                            async = Async},
               Timeout};
          {error, Reason} ->
              {stop, normal, {error, Reason}, State}
      end;
  handle_call({get_parameter, Name}, _From, #state{parameters = Parameters} = State) ->
      Value = case lists:keyfind(Name, 1, Parameters) of
                  {Name, V} -> V;
                  false -> undefined
              end,
      {reply, {ok, Value}, State};
  handle_call(stop, _From, State) ->
      %% TODO flush queue
      {stop, normal, ok, State}.

  %% Key/value section of the StartupMessage: "user", plus "database" when
  %% given in Opts. Every key and value is NUL-terminated.
  startup_params(Username, Opts) ->
      User = ["user", 0, Username, 0],
      case proplists:get_value(database, Opts, undefined) of
          undefined -> User;
          Database -> [User, "database", 0, Database, 0]
      end.
  %% Cancel the query currently running on this connection. A CancelRequest
  %% (length 16, request code 80877102, backend pid and secret key from
  %% BackendKeyData) is sent on a separate, short-lived TCP connection to
  %% the same server address.
  handle_cast(cancel, #state{backend = {Pid, Key}, mod = Mod, sock = MainSock,
                             timeout = Timeout} = State) ->
      %% The main socket may be an ssl socket, for which inet:peername/1
      %% does not apply; ask the owning transport module instead.
      {ok, {Addr, Port}} = case Mod of
                               ssl -> ssl:peername(MainSock);
                               gen_tcp -> inet:peername(MainSock)
                           end,
      SockOpts = [{active, false}, {packet, raw}, binary],
      %% Bounded, so an unreachable server cannot block this gen_server.
      {ok, Sock} = gen_tcp:connect(Addr, Port, SockOpts, Timeout),
      Msg = <<16:?int32, 80877102:?int32, Pid:?int32, Key:?int32>>,
      %% Always release the extra socket, even when the send fails.
      try
          ok = gen_tcp:send(Sock, Msg)
      after
          gen_tcp:close(Sock)
      end,
      {noreply, State}.
  %% Messages from the owned socket and gen_server timeouts:
  %%  - a closed socket or a socket error stops the server;
  %%  - a gen_server timeout is passed to the current handler as `timeout';
  %%  - received data is appended to the buffer and decoded by loop/2.
  handle_info({Tag, Sock}, #state{sock = Sock} = State)
    when Tag =:= tcp_closed; Tag =:= ssl_closed ->
      {stop, sock_closed, State};
  handle_info({Tag, Sock, Reason}, #state{sock = Sock} = State)
    when Tag =:= tcp_error; Tag =:= ssl_error ->
      {stop, {sock_error, Reason}, State};
  handle_info(timeout, #state{handler = Current} = State) ->
      ?MODULE:Current(timeout, State);
  handle_info({_Tag, Sock, Received}, #state{data = Buffered, sock = Sock} = State) ->
      Buffer = <<Buffered/binary, Received/binary>>,
      loop(State#state{data = Buffer}, infinity).
  %% Decode and dispatch every complete backend message in the receive
  %% buffer, feeding each to the current handler (?MODULE:Handler/2).
  %% Returns immediately when a handler asks to stop. Once no complete
  %% message is left, returns {noreply, State, Timeout}. Timeout is the one
  %% requested by the last handler, or infinity if it requested none.
  loop(#state{data = Buffer, handler = Handler} = State, Timeout) ->
      case pgsql_wire:decode_message(Buffer) of
          {Message, Rest} ->
              Result = ?MODULE:Handler(Message, State#state{data = Rest}),
              case Result of
                  {stop, _Reason, _NewState} ->
                      Result;
                  {noreply, NewState} ->
                      loop(NewState, infinity);
                  {noreply, NewState, NewTimeout} ->
                      loop(NewState, NewTimeout)
              end;
          _Incomplete ->
              %% presumably decode_message/1 returns a non-{Msg, Tail}
              %% value when only a partial message is buffered - confirm
              {noreply, State, Timeout}
      end.
  %% On shutdown, send a Terminate message ($X) and close the socket.
  %% Both steps are best effort: the peer may already have closed the
  %% connection, in which case the transport returns an error tuple.
  %% That result is ignored on purpose. If no socket was ever opened,
  %% there is nothing to do.
  terminate(_Reason, #state{sock = undefined}) ->
      ok;
  terminate(_Reason, #state{mod = Mod, sock = Sock} = State) ->
      _ = send(State, $X, []),
      _ = Mod:close(Sock),
      ok.
  %% The state is carried over unchanged on a code upgrade.
  code_change(_PreviousVsn, CurrentState, _ExtraInfo) ->
      {ok, CurrentState}.
  %% -- internal functions --

  %% Negotiate SSL on the freshly connected TCP socket S. It sends an
  %% SSLRequest (length 8, code 80877103) and reads the server's one-byte
  %% answer:
  %%   $S - accepted; the socket is upgraded with ssl:connect/3;
  %%   $N - refused; with Flag = required the process exits with
  %%        ssl_not_available, with Flag = true the connection continues
  %%        over plain TCP.
  %% Returns State with mod/sock set to the transport to use from now on.
  %% NOTE(review): the complete connect option list (including non-SSL
  %% keys such as database, port, timeout) is passed to ssl:connect/3 -
  %% confirm the ssl application in use accepts/ignores those keys.
  start_ssl(S, Flag, Opts, #state{timeout = Timeout} = State) ->
      ok = gen_tcp:send(S, <<8:?int32, 80877103:?int32>>),
      {ok, <<Code>>} = gen_tcp:recv(S, 1, Timeout),
      case Code of
          $S ->
              case ssl:connect(S, Opts, Timeout) of
                  {ok, S2} -> State#state{mod = ssl, sock = S2};
                  {error, Reason} -> exit({ssl_negotiation_failed, Reason})
              end;
          $N ->
              case Flag of
                  %% SSL was optional: keep using the plain TCP socket. The
                  %% transport must be recorded, since the caller sends the
                  %% StartupMessage through mod/sock right after this.
                  true -> State#state{mod = gen_tcp, sock = S};
                  required -> exit(ssl_not_available)
              end
      end.
  %% Apply socket options through whichever transport owns the socket.
  setopts(#state{mod = Transport, sock = Socket}, Options) ->
      case Transport of
          ssl -> ssl:setopts(Socket, Options);
          gen_tcp -> inet:setopts(Socket, Options)
      end.
  %% Encode Data with pgsql_wire:encode/1 (e.g. the StartupMessage) and
  %% write the result on the current transport.
  send(#state{mod = Transport, sock = Socket}, Data) ->
      Packet = pgsql_wire:encode(Data),
      Transport:send(Socket, Packet).
  %% Encode a message tagged with Type (a byte such as $p) and payload Data
  %% via pgsql_wire:encode/2, then write it on the current transport.
  send(#state{mod = Transport, sock = Socket}, Type, Data) ->
      Packet = pgsql_wire:encode(Type, Data),
      Transport:send(Socket, Packet).
  %% Pop the oldest queued caller and send Message straight to its pid
  %% (a plain message, not gen_server:reply/2). Returns the updated state.
  %% NOTE(review): not called by the code visible in this module.
  reply(#state{queue = Pending} = State, Message) ->
      {{value, {CallerPid, _Tag}}, Rest} = queue:out(Pending),
      CallerPid ! Message,
      State#state{queue = Rest}.
  %% Pop the oldest queued caller and answer its pending gen_server call
  %% with Message. Returns the updated state; crashes with badmatch when no
  %% caller is queued.
  gen_reply(#state{queue = Pending} = State, Message) ->
      {{value, Caller}, Rest} = queue:out(Pending),
      gen_server:reply(Caller, Message),
      State#state{queue = Rest}.
  %% Send {pgsql, self(), Msg} to the configured async pid. Returns false
  %% when no pid was configured.
  notify_async(#state{async = Pid}, Msg) when is_pid(Pid) ->
      Pid ! {pgsql, self(), Msg};
  notify_async(#state{}, _Msg) ->
      false.
  %% -- backend message handling --

  %% auth/2 is the handler from the StartupMessage until AuthenticationOk.
  %% Every failure answers the queued connect caller with {error, Reason}
  %% and stops the process normally.

  %% AuthenticationOk: continue with initializing/2.
  auth({$R, <<0:?int32>>}, State) ->
      #state{timeout = Timeout} = State,
      {noreply, State#state{handler = initializing}, Timeout};
  %% AuthenticationCleartextPassword
  auth({$R, <<3:?int32>>}, State) ->
      #state{timeout = Timeout} = State,
      send(State, $p, [get(password), 0]),
      {noreply, State, Timeout};
  %% AuthenticationMD5Password: reply with
  %% "md5" ++ hex(md5(hex(md5(Password ++ Username)) ++ Salt)).
  auth({$R, <<5:?int32, Salt:4/binary>>}, State) ->
      #state{timeout = Timeout} = State,
      Digest1 = hex(erlang:md5([get(password), get(username)])),
      Str = ["md5", hex(erlang:md5([Digest1, Salt])), 0],
      send(State, $p, Str),
      {noreply, State, Timeout};
  %% Any other authentication request is not supported.
  auth({$R, <<M:?int32, _/binary>>}, State) ->
      Method = auth_method(M),
      {stop,
       normal,
       gen_reply(State, {error, {unsupported_auth_method, Method}})};
  %% ErrorResponse
  auth({error, E}, State) ->
      Why = case E#error.code of
                <<"28000">> -> invalid_authorization_specification;
                <<"28P01">> -> invalid_password;
                Any -> Any
            end,
      {stop, normal, gen_reply(State, {error, Why})};
  auth(timeout, State) ->
      {stop, normal, gen_reply(State, {error, timeout})};
  %% Anything else is handled by on_message/2; the setup timeout is kept.
  auth(Other, State) ->
      #state{timeout = Timeout} = State,
      {noreply, State2} = on_message(Other, State),
      {noreply, State2, Timeout}.

  %% Name of an unsupported Authentication request code, following the
  %% PostgreSQL v3 protocol. 8 is AuthenticationGSSContinue, part of a GSS
  %% exchange; 9 is AuthenticationSSPI; 10 is AuthenticationSASL.
  auth_method(2) -> kerberosV5;
  auth_method(4) -> crypt;
  auth_method(6) -> scm;
  auth_method(7) -> gss;
  auth_method(8) -> gss;
  auth_method(9) -> sspi;
  auth_method(10) -> sasl;
  auth_method(_) -> unknown.
  %% initializing/2 is the handler after AuthenticationOk until the first
  %% ReadyForQuery. At that point the queued connect caller receives
  %% {ok, self()} and on_message/2 takes over.

  %% BackendKeyData: the pid/secret key are needed for cancel requests.
  initializing({$K, <<Pid:?int32, Key:?int32>>}, State) ->
      #state{timeout = Timeout} = State,
      State2 = State#state{backend = {Pid, Key}},
      {noreply, State2, Timeout};
  initializing(timeout, State) ->
      {stop, normal, gen_reply(State, {error, timeout})};
  %% ReadyForQuery
  initializing({$Z, <<Status:8>>}, State) ->
      #state{parameters = Parameters} = State,
      %% The credentials are only needed while authenticating.
      erase(username),
      erase(password),
      %% TODO decode dates to now() format
      %% Choose the date/time codec from the server's integer_datetimes
      %% setting. A server that never reported the setting is assumed to
      %% use floating-point date/times (the PostgreSQL default before 8.4),
      %% instead of crashing the connection with a case_clause.
      DatetimeMod = case lists:keyfind(<<"integer_datetimes">>, 1, Parameters) of
                        {_, <<"on">>} -> pgsql_idatetime;
                        {_, <<"off">>} -> pgsql_fdatetime;
                        false -> pgsql_fdatetime
                    end,
      put(datetime_mod, DatetimeMod),
      State2 = State#state{handler = on_message,
                           txstatus = Status,
                           ready = true},
      {noreply, gen_reply(State2, {ok, self()})};
  initializing({error, _} = Error, State) ->
      {stop, normal, gen_reply(State, Error)};
  %% Anything else is handled by on_message/2; the setup timeout is kept.
  initializing(Other, State) ->
      #state{timeout = Timeout} = State,
      {noreply, State2} = on_message(Other, State),
      {noreply, State2, Timeout}.
  %% on_message/2 handles backend messages that auth/2 and initializing/2
  %% do not handle themselves; it is also the handler after ReadyForQuery.

  %% NoticeResponse: forward the decoded notice to the async pid, if any.
  on_message({$N, Data}, State) ->
      notify_async(State, {notice, pgsql_wire:decode_error(Data)}),
      {noreply, State};
  %% ParameterStatus: store the reported value, replacing any earlier one.
  on_message({$S, Data}, State) ->
      [Name, Value] = pgsql_wire:decode_strings(Data),
      Known = State#state.parameters,
      {noreply, State#state{parameters = lists:keystore(Name, 1, Known, {Name, Value})}};
  %% NotificationResponse: the payload string may be absent, in which case
  %% <<>> is reported.
  on_message({$A, <<Pid:?int32, Strings/binary>>}, State) ->
      {Channel, Payload} = case pgsql_wire:decode_strings(Strings) of
                               [C, P] -> {C, P};
                               [C] -> {C, <<>>}
                           end,
      %% TODO use it
      notify_async(State, {notification, Channel, Pid, Payload}),
      {noreply, State}.
  %% Normalise a name given as a binary or as a string/iolist to a binary.
  %% Anything else fails with function_clause.
  to_binary(Name) when is_binary(Name); is_list(Name) ->
      iolist_to_binary(Name).
  %% Lower-case hexadecimal encoding of a binary: each byte becomes two
  %% ASCII characters from 0-9a-f, e.g. <<255, 1>> -> <<"ff01">>.
  hex(Bin) ->
      Digit = fun(N) when N < 10 -> $0 + N;
                 (N) when N < 16 -> $a + (N - 10)
              end,
      << <<(Digit(Hi)), (Digit(Lo))>> || <<Hi:4, Lo:4>> <= Bin >>.