pgsql_sock.erl 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601
  1. %%% Copyright (C) 2009 - Will Glozer. All rights reserved.
  2. %%% Copyright (C) 2011 - Anton Lebedevich. All rights reserved.
  3. -module(pgsql_sock).
  4. -behavior(gen_server).
  5. -export([start_link/0,
  6. close/1,
  7. get_parameter/2,
  8. cancel/1]).
  9. -export([handle_call/3, handle_cast/2, handle_info/2]).
  10. -export([init/1, code_change/3, terminate/2]).
  11. %% state callbacks
  12. -export([auth/2, initializing/2, on_message/2]).
  13. -include("pgsql.hrl").
  14. -include("pgsql_binary.hrl").
  15. -record(state, {mod,
  16. sock,
  17. data = <<>>,
  18. backend,
  19. handler,
  20. queue = queue:new(),
  21. async,
  22. parameters = [],
  23. types = [],
  24. columns = [],
  25. rows = [],
  26. results = [],
  27. sync_required,
  28. txstatus}).
  29. %% -- client interface --
  30. start_link() ->
  31. gen_server:start_link(?MODULE, [], []).
  32. close(C) when is_pid(C) ->
  33. catch gen_server:cast(C, stop),
  34. ok.
  35. get_parameter(C, Name) ->
  36. gen_server:call(C, {get_parameter, to_binary(Name)}, infinity).
  37. cancel(S) ->
  38. gen_server:cast(S, cancel).
  39. %% -- gen_server implementation --
  40. init([]) ->
  41. {ok, #state{}}.
  42. handle_call({get_parameter, Name}, _From, State) ->
  43. case lists:keysearch(Name, 1, State#state.parameters) of
  44. {value, {Name, Value}} -> Value;
  45. false -> Value = undefined
  46. end,
  47. {reply, {ok, Value}, State};
  48. handle_call(Command, From, State) ->
  49. #state{queue = Q} = State,
  50. Req = {{call, From}, Command},
  51. command(Command, State#state{queue = queue:in(Req, Q)}).
  52. handle_cast({{Method, From, Ref}, Command} = Req, State)
  53. when ((Method == cast) or (Method == incremental)),
  54. is_pid(From),
  55. is_reference(Ref) ->
  56. #state{queue = Q} = State,
  57. command(Command, State#state{queue = queue:in(Req, Q)});
  58. handle_cast(stop, State) ->
  59. {stop, normal, flush_queue(State, {error, closed})};
  60. handle_cast(cancel, State = #state{backend = {Pid, Key}}) ->
  61. {ok, {Addr, Port}} = inet:peername(State#state.sock),
  62. SockOpts = [{active, false}, {packet, raw}, binary],
  63. %% TODO timeout
  64. {ok, Sock} = gen_tcp:connect(Addr, Port, SockOpts),
  65. Msg = <<16:?int32, 80877102:?int32, Pid:?int32, Key:?int32>>,
  66. ok = gen_tcp:send(Sock, Msg),
  67. gen_tcp:close(Sock),
  68. {noreply, State}.
  69. handle_info({Closed, Sock}, #state{sock = Sock} = State)
  70. when Closed == tcp_closed; Closed == ssl_closed ->
  71. {stop, sock_closed, flush_queue(State, {error, sock_closed})};
  72. handle_info({Error, Sock, Reason}, #state{sock = Sock} = State)
  73. when Error == tcp_error; Error == ssl_error ->
  74. Why = {sock_error, Reason},
  75. {stop, Why, flush_queue(State, {error, Why})};
  76. handle_info({inet_reply, _, ok}, State) ->
  77. {noreply, State};
  78. handle_info({inet_reply, _, Status}, State) ->
  79. {stop, Status, flush_queue(State, {error, Status})};
  80. handle_info({_, Sock, Data2}, #state{data = Data, sock = Sock} = State) ->
  81. loop(State#state{data = <<Data/binary, Data2/binary>>}).
  82. terminate(_Reason, _State) ->
  83. %% TODO send termination msg, close socket ??
  84. ok.
  85. code_change(_OldVsn, State, _Extra) ->
  86. {ok, State}.
  87. %% -- internal functions --
  88. command(Command, State = #state{sync_required = true})
  89. when Command /= sync ->
  90. {noreply, finish(State, {error, sync_required})};
  91. command({connect, Host, Username, Password, Opts}, State) ->
  92. Timeout = proplists:get_value(timeout, Opts, 5000),
  93. Port = proplists:get_value(port, Opts, 5432),
  94. SockOpts = [{active, false}, {packet, raw}, binary, {nodelay, true}],
  95. {ok, Sock} = gen_tcp:connect(Host, Port, SockOpts, Timeout),
  96. State2 = case proplists:get_value(ssl, Opts) of
  97. T when T == true; T == required ->
  98. start_ssl(Sock, T, Opts, State);
  99. _ ->
  100. State#state{mod = gen_tcp, sock = Sock}
  101. end,
  102. Opts2 = ["user", 0, Username, 0],
  103. case proplists:get_value(database, Opts, undefined) of
  104. undefined -> Opts3 = Opts2;
  105. Database -> Opts3 = [Opts2 | ["database", 0, Database, 0]]
  106. end,
  107. send(State2, [<<196608:?int32>>, Opts3, 0]),
  108. Async = proplists:get_value(async, Opts, undefined),
  109. setopts(State2, [{active, true}]),
  110. put(username, Username),
  111. put(password, Password),
  112. {noreply,
  113. State2#state{handler = auth,
  114. async = Async}};
  115. command({squery, Sql}, State) ->
  116. send(State, $Q, [Sql, 0]),
  117. {noreply, State};
  118. %% TODO add fast_equery command that doesn't need parsed statement,
  119. %% uses default (text) column format,
  120. %% sends Describe after Bind to get RowDescription
  121. command({equery, Statement, Parameters}, State) ->
  122. #statement{name = StatementName, columns = Columns} = Statement,
  123. Bin1 = pgsql_wire:encode_parameters(Parameters),
  124. Bin2 = pgsql_wire:encode_formats(Columns),
  125. send(State, $B, ["", 0, StatementName, 0, Bin1, Bin2]),
  126. send(State, $E, ["", 0, <<0:?int32>>]),
  127. send(State, $C, [$S, "", 0]),
  128. send(State, $S, []),
  129. {noreply, State};
  130. command({parse, Name, Sql, Types}, State) ->
  131. Bin = pgsql_wire:encode_types(Types),
  132. send(State, $P, [Name, 0, Sql, 0, Bin]),
  133. send(State, $D, [$S, Name, 0]),
  134. send(State, $H, []),
  135. {noreply, State};
  136. command({bind, Statement, PortalName, Parameters}, State) ->
  137. #statement{name = StatementName, columns = Columns, types = Types} = Statement,
  138. Typed_Parameters = lists:zip(Types, Parameters),
  139. Bin1 = pgsql_wire:encode_parameters(Typed_Parameters),
  140. Bin2 = pgsql_wire:encode_formats(Columns),
  141. send(State, $B, [PortalName, 0, StatementName, 0, Bin1, Bin2]),
  142. send(State, $H, []),
  143. {noreply, State};
  144. command({execute, _Statement, PortalName, MaxRows}, State) ->
  145. send(State, $E, [PortalName, 0, <<MaxRows:?int32>>]),
  146. send(State, $H, []),
  147. {noreply, State};
  148. command({describe_statement, Name}, State) ->
  149. send(State, $D, [$S, Name, 0]),
  150. send(State, $H, []),
  151. {noreply, State};
  152. command({describe_portal, Name}, State) ->
  153. send(State, $D, [$P, Name, 0]),
  154. send(State, $H, []),
  155. {noreply, State};
  156. command({close, Type, Name}, State) ->
  157. case Type of
  158. statement -> Type2 = $S;
  159. portal -> Type2 = $P
  160. end,
  161. send(State, $C, [Type2, Name, 0]),
  162. send(State, $H, []),
  163. {noreply, State};
  164. command(sync, State) ->
  165. send(State, $S, []),
  166. {noreply, State#state{sync_required = false}}.
  167. start_ssl(S, Flag, Opts, State) ->
  168. ok = gen_tcp:send(S, <<8:?int32, 80877103:?int32>>),
  169. Timeout = proplists:get_value(timeout, Opts, 5000),
  170. {ok, <<Code>>} = gen_tcp:recv(S, 1, Timeout),
  171. case Code of
  172. $S ->
  173. case ssl:connect(S, Opts, Timeout) of
  174. {ok, S2} -> State#state{mod = ssl, sock = S2};
  175. {error, Reason} -> exit({ssl_negotiation_failed, Reason})
  176. end;
  177. $N ->
  178. case Flag of
  179. true -> State;
  180. required -> exit(ssl_not_available)
  181. end
  182. end.
  183. setopts(#state{mod = Mod, sock = Sock}, Opts) ->
  184. case Mod of
  185. gen_tcp -> inet:setopts(Sock, Opts);
  186. ssl -> ssl:setopts(Sock, Opts)
  187. end.
  188. send(#state{mod = Mod, sock = Sock}, Data) ->
  189. do_send(Mod, Sock, pgsql_wire:encode(Data)).
  190. send(#state{mod = Mod, sock = Sock}, Type, Data) ->
  191. do_send(Mod, Sock, pgsql_wire:encode(Type, Data)).
  192. do_send(gen_tcp, Sock, Bin) ->
  193. try erlang:port_command(Sock, Bin) of
  194. true ->
  195. ok
  196. catch
  197. error:_Error ->
  198. {error,einval}
  199. end;
  200. do_send(Mod, Sock, Bin) ->
  201. Mod:send(Sock, Bin).
  202. loop(#state{data = Data, handler = Handler} = State) ->
  203. case pgsql_wire:decode_message(Data) of
  204. {Message, Tail} ->
  205. case ?MODULE:Handler(Message, State#state{data = Tail}) of
  206. {noreply, State2} ->
  207. loop(State2);
  208. R = {stop, _Reason2, _State2} ->
  209. R
  210. end;
  211. _ ->
  212. {noreply, State}
  213. end.
  214. finish(State, Result) ->
  215. finish(State, Result, Result).
  216. finish(State = #state{queue = Q}, Notice, Result) ->
  217. case queue:get(Q) of
  218. {{cast, From, Ref}, _} ->
  219. From ! {self(), Ref, Result};
  220. {{incremental, From, Ref}, _} ->
  221. From ! {self(), Ref, Notice};
  222. {{call, From}, _} ->
  223. gen_server:reply(From, Result)
  224. end,
  225. State#state{queue = queue:drop(Q),
  226. types = [],
  227. columns = [],
  228. rows = [],
  229. results = []}.
  230. add_result(State = #state{queue = Q, results = Results}, Notice, Result) ->
  231. Results2 = case queue:get(Q) of
  232. {{incremental, From, Ref}, _} ->
  233. From ! {self(), Ref, Notice},
  234. Results;
  235. _ ->
  236. [Result | Results]
  237. end,
  238. State#state{types = [],
  239. columns = [],
  240. rows = [],
  241. results = Results2}.
  242. add_row(State = #state{queue = Q, rows = Rows}, Data) ->
  243. Rows2 = case queue:get(Q) of
  244. {{incremental, From, Ref}, _} ->
  245. From ! {self(), Ref, {data, Data}},
  246. Rows;
  247. _ ->
  248. [Data | Rows]
  249. end,
  250. State#state{rows = Rows2}.
  251. notify(State = #state{queue = Q}, Notice) ->
  252. case queue:get(Q) of
  253. {{incremental, From, Ref}, _} ->
  254. From ! {self(), Ref, Notice};
  255. _ ->
  256. ignore
  257. end,
  258. State.
  259. notify_async(State = #state{async = Pid}, Msg) ->
  260. case is_pid(Pid) of
  261. true -> Pid ! {pgsql, self(), Msg};
  262. false -> false
  263. end,
  264. State.
  265. command_tag(#state{queue = Q}) ->
  266. {_, Req} = queue:get(Q),
  267. if is_tuple(Req) ->
  268. element(1, Req);
  269. is_atom(Req) ->
  270. Req
  271. end.
  272. get_columns(State) ->
  273. #state{queue = Q, columns = Columns} = State,
  274. case queue:get(Q) of
  275. {_, {equery, #statement{columns = C}, _}} ->
  276. C;
  277. {_, {execute, #statement{columns = C}, _, _}} ->
  278. C;
  279. {_, {squery, _}} ->
  280. Columns
  281. end.
  282. make_statement(State) ->
  283. #state{queue = Q, types = Types, columns = Columns} = State,
  284. Name = case queue:get(Q) of
  285. {_, {parse, N, _, _}} -> N;
  286. {_, {describe_statement, N}} -> N
  287. end,
  288. #statement{name = Name, types = Types, columns = Columns}.
  289. sync_required(#state{queue = Q} = State) ->
  290. case queue:is_empty(Q) of
  291. false ->
  292. case command_tag(State) of
  293. sync ->
  294. State;
  295. _ ->
  296. sync_required(finish(State, {error, sync_required}))
  297. end;
  298. true ->
  299. State#state{sync_required = true}
  300. end.
  301. flush_queue(#state{queue = Q} = State, Error) ->
  302. case queue:is_empty(Q) of
  303. false ->
  304. flush_queue(finish(State, Error), Error);
  305. true -> State
  306. end.
  307. to_binary(B) when is_binary(B) -> B;
  308. to_binary(L) when is_list(L) -> list_to_binary(L).
  309. hex(Bin) ->
  310. HChar = fun(N) when N < 10 -> $0 + N;
  311. (N) when N < 16 -> $W + N
  312. end,
  313. <<<<(HChar(H)), (HChar(L))>> || <<H:4, L:4>> <= Bin>>.
  314. %% -- backend message handling --
  315. %% AuthenticationOk
  316. auth({$R, <<0:?int32>>}, State) ->
  317. {noreply, State#state{handler = initializing}};
  318. %% AuthenticationCleartextPassword
  319. auth({$R, <<3:?int32>>}, State) ->
  320. send(State, $p, [get(password), 0]),
  321. {noreply, State};
  322. %% AuthenticationMD5Password
  323. auth({$R, <<5:?int32, Salt:4/binary>>}, State) ->
  324. Digest1 = hex(erlang:md5([get(password), get(username)])),
  325. Str = ["md5", hex(erlang:md5([Digest1, Salt])), 0],
  326. send(State, $p, Str),
  327. {noreply, State};
  328. auth({$R, <<M:?int32, _/binary>>}, State) ->
  329. case M of
  330. 2 -> Method = kerberosV5;
  331. 4 -> Method = crypt;
  332. 6 -> Method = scm;
  333. 7 -> Method = gss;
  334. 8 -> Method = sspi;
  335. _ -> Method = unknown
  336. end,
  337. State2 = finish(State, {error, {unsupported_auth_method, Method}}),
  338. {stop, normal, State2};
  339. %% ErrorResponse
  340. auth({error, E}, State) ->
  341. case E#error.code of
  342. <<"28000">> -> Why = invalid_authorization_specification;
  343. <<"28P01">> -> Why = invalid_password;
  344. Any -> Why = Any
  345. end,
  346. {stop, normal, finish(State, {error, Why})};
  347. auth(Other, State) ->
  348. on_message(Other, State).
  349. %% BackendKeyData
  350. initializing({$K, <<Pid:?int32, Key:?int32>>}, State) ->
  351. {noreply, State#state{backend = {Pid, Key}}};
  352. %% ReadyForQuery
  353. initializing({$Z, <<Status:8>>}, State) ->
  354. #state{parameters = Parameters} = State,
  355. erase(username),
  356. erase(password),
  357. %% TODO decode dates to now() format
  358. case lists:keysearch(<<"integer_datetimes">>, 1, Parameters) of
  359. {value, {_, <<"on">>}} -> put(datetime_mod, pgsql_idatetime);
  360. {value, {_, <<"off">>}} -> put(datetime_mod, pgsql_fdatetime)
  361. end,
  362. State2 = finish(State#state{handler = on_message,
  363. txstatus = Status},
  364. connected),
  365. {noreply, State2};
  366. initializing({error, _} = Error, State) ->
  367. {stop, normal, finish(State, Error)};
  368. initializing(Other, State) ->
  369. on_message(Other, State).
  370. %% ParseComplete
  371. on_message({$1, <<>>}, State) ->
  372. {noreply, State};
  373. %% ParameterDescription
  374. on_message({$t, <<_Count:?int16, Bin/binary>>}, State) ->
  375. Types = [pgsql_types:oid2type(Oid) || <<Oid:?int32>> <= Bin],
  376. State2 = notify(State#state{types = Types}, {types, Types}),
  377. {noreply, State2};
  378. %% RowDescription
  379. on_message({$T, <<Count:?int16, Bin/binary>>}, State) ->
  380. Columns = pgsql_wire:decode_columns(Count, Bin),
  381. Columns2 =
  382. case command_tag(State) of
  383. C when C == describe_portal; C == squery ->
  384. Columns;
  385. C when C == parse; C == describe_statement ->
  386. [Col#column{format = pgsql_wire:format(Col#column.type)}
  387. || Col <- Columns]
  388. end,
  389. State2 = State#state{columns = Columns2},
  390. Message = {columns, Columns2},
  391. State3 = case command_tag(State2) of
  392. squery ->
  393. notify(State2, Message);
  394. T when T == parse; T == describe_statement ->
  395. finish(State2, Message, {ok, make_statement(State2)});
  396. describe_portal ->
  397. finish(State2, Message, {ok, Columns})
  398. end,
  399. {noreply, State3};
  400. %% NoData
  401. on_message({$n, <<>>}, State) ->
  402. State2 = case command_tag(State) of
  403. C when C == parse; C == describe_statement ->
  404. finish(State, no_data, {ok, make_statement(State)});
  405. describe_portal ->
  406. finish(State, no_data, {ok, []})
  407. end,
  408. {noreply, State2};
  409. %% BindComplete
  410. on_message({$2, <<>>}, State) ->
  411. State2 = case command_tag(State) of
  412. equery ->
  413. %% TODO send Describe as a part of equery, needs text format support
  414. notify(State, {columns, get_columns(State)});
  415. bind ->
  416. finish(State, ok)
  417. end,
  418. {noreply, State2};
  419. %% CloseComplete
  420. on_message({$3, <<>>}, State) ->
  421. State2 = case command_tag(State) of
  422. equery ->
  423. State;
  424. close ->
  425. finish(State, ok)
  426. end,
  427. {noreply, State2};
  428. %% DataRow
  429. on_message({$D, <<_Count:?int16, Bin/binary>>}, State) ->
  430. Data = pgsql_wire:decode_data(get_columns(State), Bin),
  431. {noreply, add_row(State, Data)};
  432. %% PortalSuspended
  433. on_message({$s, <<>>}, State) ->
  434. State2 = finish(State,
  435. suspended,
  436. {partial, lists:reverse(State#state.rows)}),
  437. {noreply, State2};
  438. %% CommandComplete
  439. on_message({$C, Bin}, State) ->
  440. Complete = pgsql_wire:decode_complete(Bin),
  441. Command = command_tag(State),
  442. Notice = {complete, Complete},
  443. Rows = lists:reverse(State#state.rows),
  444. State2 = case {Command, Complete, Rows} of
  445. {execute, {_, Count}, []} ->
  446. finish(State, Notice, {ok, Count});
  447. {execute, {_, Count}, _} ->
  448. finish(State, Notice, {ok, Count, Rows});
  449. {execute, _, _} ->
  450. finish(State, Notice, {ok, Rows});
  451. {C, {_, Count}, []} when C == squery; C == equery ->
  452. add_result(State, Notice, {ok, Count});
  453. {C, {_, Count}, _} when C == squery; C == equery ->
  454. add_result(State, Notice, {ok, Count, get_columns(State), Rows});
  455. {C, _, _} when C == squery; C == equery ->
  456. add_result(State, Notice, {ok, get_columns(State), Rows})
  457. end,
  458. {noreply, State2};
  459. %% EmptyQueryResponse
  460. on_message({$I, _Bin}, State) ->
  461. Notice = {complete, empty},
  462. State2 = case command_tag(State) of
  463. execute ->
  464. finish(State, Notice, {ok, [], []});
  465. C when C == squery; C == equery ->
  466. add_result(State, Notice, {ok, [], []})
  467. end,
  468. {noreply, State2};
  469. %% ReadyForQuery
  470. on_message({$Z, <<Status:8>>}, State) ->
  471. State2 = case command_tag(State) of
  472. squery ->
  473. case State#state.results of
  474. [Result] ->
  475. finish(State, done, Result);
  476. Results ->
  477. finish(State, done, lists:reverse(Results))
  478. end;
  479. equery ->
  480. case State#state.results of
  481. [Result] ->
  482. finish(State, done, Result);
  483. [] ->
  484. finish(State, done)
  485. end;
  486. sync ->
  487. finish(State, ok)
  488. end,
  489. {noreply, State2#state{txstatus = Status}};
  490. on_message(Error = {error, _}, State) ->
  491. State2 = case command_tag(State) of
  492. C when C == squery; C == equery ->
  493. add_result(State, Error, Error);
  494. _ ->
  495. sync_required(finish(State, Error))
  496. end,
  497. {noreply, State2};
  498. %% NoticeResponse
  499. on_message({$N, Data}, State) ->
  500. State2 = notify_async(State, {notice, pgsql_wire:decode_error(Data)}),
  501. {noreply, State2};
  502. %% ParameterStatus
  503. on_message({$S, Data}, State) ->
  504. [Name, Value] = pgsql_wire:decode_strings(Data),
  505. Parameters2 = lists:keystore(Name, 1, State#state.parameters,
  506. {Name, Value}),
  507. {noreply, State#state{parameters = Parameters2}};
  508. %% NotificationResponse
  509. on_message({$A, <<Pid:?int32, Strings/binary>>}, State) ->
  510. case pgsql_wire:decode_strings(Strings) of
  511. [Channel, Payload] -> ok;
  512. [Channel] -> Payload = <<>>
  513. end,
  514. State2 = notify_async(State, {notification, Channel, Pid, Payload}),
  515. {noreply, State2}.