%%% pgsql_sock.erl (18 KB)
  1. %%% Copyright (C) 2009 - Will Glozer. All rights reserved.
  2. %%% Copyright (C) 2011 - Anton Lebedevich. All rights reserved.
  3. -module(pgsql_sock).
  4. -behavior(gen_server).
  5. -export([start_link/0,
  6. connect/5,
  7. close/1,
  8. get_parameter/2,
  9. squery/2,
  10. equery/3,
  11. parse/4,
  12. bind/4,
  13. execute/4,
  14. describe/3,
  15. close/3,
  16. sync/1,
  17. cancel/1]).
  18. -export([handle_call/3, handle_cast/2, handle_info/2]).
  19. -export([init/1, code_change/3, terminate/2]).
  20. %% state callbacks
  21. -export([auth/2, initializing/2, on_message/2]).
  22. -include("pgsql.hrl").
  23. -include("pgsql_binary.hrl").
  %% Per-connection server state.
  -record(state, {mod,                  % transport module: gen_tcp | ssl
                  sock,                 % socket used with `mod'
                  data = <<>>,          % received bytes not yet decoded
                  backend,              % {Pid, Key} from BackendKeyData
                  handler,              % message handler: auth | initializing | on_message
                  queue = queue:new(),  % pending {{From, Ref}, Command} requests
                  async,                % pid receiving notices/notifications, if any
                  parameters = [],      % [{Name, Value}] from ParameterStatus
                  statement,            % #statement{} being described
                  columns = [],         % RowDescription columns for squery
                  rows = [],            % DataRow values, newest first
                  results = [],         % completed results, newest first
                  sync_required,        % true after an extended-query error
                  txstatus}).           % status byte of the last ReadyForQuery
  38. %% -- client interface --
  %% @doc Start a linked, not-yet-connected connection server.
  start_link() ->
      InitArgs = [],
      gen_server:start_link(?MODULE, InitArgs, []).
  %% @doc Ask connection C to connect and authenticate. Returns the request
  %% reference that tags the eventual reply message.
  connect(C, Host, Username, Password, Opts) ->
      Request = {connect, Host, Username, Password, Opts},
      cast(C, Request).
  %% TODO extract API functions
  %% @doc Ask connection C to stop. Requests still queued are answered with
  %% {error, closed} by the server before it exits. Always returns ok:
  %% gen_server:cast/2 to a pid never raises (even for a dead process), so
  %% the old-style `catch' that used to wrap it was redundant.
  close(C) when is_pid(C) ->
      gen_server:cast(C, stop).
  %% @doc Synchronously read a server parameter reported via ParameterStatus.
  %% The reply is {ok, Value}, or {ok, undefined} when the parameter is unknown.
  get_parameter(C, Name) ->
      Key = to_binary(Name),
      gen_server:call(C, {get_parameter, Key}, infinity).
  %% @doc Run a simple (text protocol) query; returns the request reference.
  squery(C, Sql) ->
      Request = {squery, Sql},
      cast(C, Request).
  %% @doc Bind and execute an already-parsed statement; returns the request
  %% reference.
  equery(C, Statement, Parameters) ->
      Request = {equery, Statement, Parameters},
      cast(C, Request).
  %% @doc Parse Sql as statement Name with parameter Types; returns the
  %% request reference.
  parse(C, Name, Sql, Types) ->
      Request = {parse, Name, Sql, Types},
      cast(C, Request).
  %% @doc Bind Parameters to Statement in portal PortalName; returns the
  %% request reference.
  bind(C, Statement, PortalName, Parameters) ->
      Request = {bind, Statement, PortalName, Parameters},
      cast(C, Request).
  %% @doc Execute portal PortalName returning at most MaxRows rows; returns the
  %% request reference.
  execute(C, Statement, PortalName, MaxRows) ->
      Request = {execute, Statement, PortalName, MaxRows},
      cast(C, Request).
  %% @doc Describe a prepared statement or a portal; returns the request
  %% reference.
  describe(C, statement, Name) ->
      Request = {describe_statement, Name},
      cast(C, Request);
  describe(C, portal, Name) ->
      Request = {describe_portal, Name},
      cast(C, Request).
  %% @doc Close the statement or portal Name (Type is statement | portal);
  %% returns the request reference.
  close(C, Type, Name) ->
      Request = {close, Type, Name},
      cast(C, Request).
  %% @doc Send Sync, which also clears a pending sync_required condition;
  %% returns the request reference.
  sync(C) ->
      Request = sync,
      cast(C, Request).
  %% @doc Ask connection S to send a CancelRequest for its running query.
  %% This is not tagged with a reference, so no reply is ever sent.
  cancel(S) ->
      Request = cancel,
      gen_server:cast(S, Request).
  69. %% -- gen_server implementation --
  %% No socket yet: it is opened when the `connect' request is handled.
  init([]) ->
      Initial = #state{},
      {ok, Initial}.
  %% get_parameter: answer {ok, Value} from the ParameterStatus values seen so
  %% far, or {ok, undefined} when Name has not been reported.
  handle_call({get_parameter, Name}, _From, State) ->
      Known = State#state.parameters,
      Value = case lists:keyfind(Name, 1, Known) of
                  {Name, Found} -> Found;
                  false -> undefined
              end,
      {reply, {ok, Value}, State}.
  %% TODO request format broken
  %% @doc Handle client requests.
  %%
  %% Every request except `stop' and `cancel' arrives as {{From, Ref}, Command}
  %% (built by cast/2). The matching frontend messages are written to the
  %% socket and the whole request is appended to the queue. The answer is
  %% sent later, when the backend's response has been processed.
  %%
  %% While sync_required is true, every command other than `sync' is rejected
  %% immediately with {Ref, {error, sync_required}}.
  handle_cast({{From, Ref}, Command}, State = #state{sync_required = true})
    when Command /= sync ->
      From ! {Ref, {error, sync_required}},
      {noreply, State};
  %% connect: open the socket (optionally negotiating SSL), send the
  %% StartupMessage and hand incoming messages to auth/2.
  handle_cast(Req = {_, {connect, Host, Username, Password, Opts}}, State) ->
      #state{queue = Q} = State,
      Timeout = proplists:get_value(timeout, Opts, 5000),
      Port = proplists:get_value(port, Opts, 5432),
      SockOpts = [{active, false}, {packet, raw}, binary, {nodelay, true}],
      {ok, Sock} = gen_tcp:connect(Host, Port, SockOpts, Timeout),
      State2 = case proplists:get_value(ssl, Opts) of
                   T when T == true; T == required ->
                       start_ssl(Sock, T, Opts, State);
                   _ ->
                       State#state{mod = gen_tcp, sock = Sock}
               end,
      UserOpt = ["user", 0, Username, 0],
      StartupOpts = case proplists:get_value(database, Opts, undefined) of
                        undefined -> UserOpt;
                        Database -> [UserOpt | ["database", 0, Database, 0]]
                    end,
      %% 196608 = 3 bsl 16, i.e. protocol version 3.0.
      send(State2, [<<196608:?int32>>, StartupOpts, 0]),
      Async = proplists:get_value(async, Opts, undefined),
      setopts(State2, [{active, true}]),
      %% Credentials are read by auth/2 and erased by initializing/2.
      put(username, Username),
      put(password, Password),
      {noreply,
       State2#state{handler = auth,
                    queue = queue:in(Req, Q),
                    async = Async}};
  handle_cast(stop, State) ->
      {stop, normal, flush_queue(State, {error, closed})};
  %% Simple query ('Q').
  handle_cast(Req = {_, {squery, Sql}}, State) ->
      #state{queue = Q} = State,
      send(State, $Q, [Sql, 0]),
      {noreply, State#state{queue = queue:in(Req, Q)}};
  %% TODO add fast_equery command that doesn't need parsed statement,
  %% uses default (text) column format,
  %% sends Describe after Bind to get RowDescription
  %% equery: Bind the unnamed portal to the statement, Execute it with no row
  %% limit, Close the unnamed statement, then Sync.
  handle_cast(Req = {_, {equery, Statement, Parameters}}, State) ->
      #state{queue = Q} = State,
      #statement{name = StatementName, columns = Columns} = Statement,
      Bin1 = pgsql_wire:encode_parameters(Parameters),
      Bin2 = pgsql_wire:encode_formats(Columns),
      send(State, $B, ["", 0, StatementName, 0, Bin1, Bin2]),
      send(State, $E, ["", 0, <<0:?int32>>]),
      send(State, $C, [$S, "", 0]),
      send(State, $S, []),
      {noreply, State#state{queue = queue:in(Req, Q)}};
  %% parse: Parse + Describe statement, then Flush.
  handle_cast(Req = {_, {parse, Name, Sql, Types}}, State) ->
      #state{queue = Q} = State,
      Bin = pgsql_wire:encode_types(Types),
      send(State, $P, [Name, 0, Sql, 0, Bin]),
      send(State, $D, [$S, Name, 0]),
      send(State, $H, []),
      {noreply, State#state{queue = queue:in(Req, Q)}};
  %% bind: Bind typed parameters to a named portal, then Flush.
  handle_cast(Req = {_, {bind, Statement, PortalName, Parameters}}, State) ->
      #state{queue = Q} = State,
      #statement{name = StatementName, columns = Columns, types = Types} = Statement,
      Typed_Parameters = lists:zip(Types, Parameters),
      Bin1 = pgsql_wire:encode_parameters(Typed_Parameters),
      Bin2 = pgsql_wire:encode_formats(Columns),
      send(State, $B, [PortalName, 0, StatementName, 0, Bin1, Bin2]),
      send(State, $H, []),
      {noreply, State#state{queue = queue:in(Req, Q)}};
  %% execute: Execute a portal with a row limit, then Flush.
  handle_cast(Req = {_, {execute, _, PortalName, MaxRows}}, State) ->
      #state{queue = Q} = State,
      send(State, $E, [PortalName, 0, <<MaxRows:?int32>>]),
      send(State, $H, []),
      {noreply, State#state{queue = queue:in(Req, Q)}};
  handle_cast(Req = {_, {describe_statement, Name}}, State) ->
      #state{queue = Q} = State,
      send(State, $D, [$S, Name, 0]),
      send(State, $H, []),
      {noreply, State#state{queue = queue:in(Req, Q)}};
  handle_cast(Req = {_, {describe_portal, Name}}, State) ->
      #state{queue = Q} = State,
      send(State, $D, [$P, Name, 0]),
      send(State, $H, []),
      {noreply, State#state{queue = queue:in(Req, Q)}};
  handle_cast(Req = {_, {close, Type, Name}}, State) ->
      #state{queue = Q} = State,
      Type2 = case Type of
                  statement -> $S;
                  portal -> $P
              end,
      send(State, $C, [Type2, Name, 0]),
      send(State, $H, []),
      {noreply, State#state{queue = queue:in(Req, Q)}};
  %% sync is always accepted and clears the sync_required flag.
  handle_cast(Req = {_, sync}, State) ->
      #state{queue = Q} = State,
      send(State, $S, []),
      {noreply, State#state{queue = queue:in(Req, Q), sync_required = false}};
  %% cancel: send a CancelRequest over a separate, short-lived plain TCP
  %% connection to the same server address. The peer address is looked up
  %% through the connection's transport, because the connection socket is
  %% an ssl socket when SSL is in use and inet:peername/1 cannot query it.
  handle_cast(cancel, State = #state{backend = {Pid, Key}}) ->
      {ok, {Addr, Port}} = peer_address(State),
      SockOpts = [{active, false}, {packet, raw}, binary],
      %% TODO timeout
      {ok, Sock} = gen_tcp:connect(Addr, Port, SockOpts),
      %% CancelRequest: length 16, request code 80877102, process id, secret key.
      Msg = <<16:?int32, 80877102:?int32, Pid:?int32, Key:?int32>>,
      ok = gen_tcp:send(Sock, Msg),
      gen_tcp:close(Sock),
      {noreply, State}.

  %% Remote {Address, Port} of the connection socket, for either transport.
  peer_address(#state{mod = ssl, sock = Sock}) ->
      ssl:peername(Sock);
  peer_address(#state{sock = Sock}) ->
      inet:peername(Sock).
  %% Socket closed by the peer: answer every queued request and stop.
  handle_info({Closed, Sock}, #state{sock = Sock} = State)
    when Closed =:= tcp_closed orelse Closed =:= ssl_closed ->
      {stop, sock_closed, flush_queue(State, {error, sock_closed})};
  %% Transport error: answer queued requests with the wrapped reason and stop.
  handle_info({Error, Sock, Reason}, #state{sock = Sock} = State)
    when Error =:= tcp_error orelse Error =:= ssl_error ->
      Why = {sock_error, Reason},
      {stop, Why, flush_queue(State, {error, Why})};
  %% Incoming bytes: append them to the buffer and decode what is complete.
  handle_info({_, Sock, Received}, #state{data = Buffered, sock = Sock} = State) ->
      Combined = <<Buffered/binary, Received/binary>>,
      loop(State#state{data = Combined}).
  %% No explicit cleanup is performed on shutdown.
  %% TODO send termination msg, close socket ??
  terminate(_Reason, _State) -> ok.
  %% The state layout is unchanged across upgrades.
  code_change(_OldVsn, State, _Extra) -> {ok, State}.
  195. %% -- internal functions --
  %% Send Command to connection C tagged with {CallerPid, Ref} and return Ref.
  %% The caller later matches the reply message {Ref, Result}.
  cast(C, Command) ->
      Tag = make_ref(),
      Tagged = {{self(), Tag}, Command},
      gen_server:cast(C, Tagged),
      Tag.
  %% @doc Negotiate SSL on the freshly connected TCP socket S.
  %%
  %% Sends an SSLRequest (length 8, code 80877103) and reads the one-byte answer:
  %%   $S - the server accepts: upgrade with ssl:connect/3 and use `ssl'.
  %%   $N - the server refuses: with Flag = true keep using plain gen_tcp on S;
  %%        with Flag = required exit(ssl_not_available).
  %% Returns State with `mod' and `sock' set for the transport to be used.
  %% The fallback branch must set them too: the caller passes in a State that
  %% has no transport yet, and returning it unchanged made the next send/2
  %% call undefined:send/2.
  %% NOTE(review): the whole connect option list is handed to ssl:connect/3.
  %% Confirm the ssl version in use accepts non-ssl keys such as port/database.
  start_ssl(S, Flag, Opts, State) ->
      ok = gen_tcp:send(S, <<8:?int32, 80877103:?int32>>),
      Timeout = proplists:get_value(timeout, Opts, 5000),
      {ok, <<Code>>} = gen_tcp:recv(S, 1, Timeout),
      case Code of
          $S ->
              case ssl:connect(S, Opts, Timeout) of
                  {ok, S2} -> State#state{mod = ssl, sock = S2};
                  {error, Reason} -> exit({ssl_negotiation_failed, Reason})
              end;
          $N ->
              case Flag of
                  true -> State#state{mod = gen_tcp, sock = S};
                  required -> exit(ssl_not_available)
              end
      end.
  %% Apply socket options through the transport module that owns the socket.
  setopts(#state{mod = Mod, sock = Sock}, Opts) ->
      case Mod of
          ssl -> ssl:setopts(Sock, Opts);
          gen_tcp -> inet:setopts(Sock, Opts)
      end.
  %% Encode Data as an untyped frontend message (used for the StartupMessage)
  %% and write it with the transport module.
  send(#state{mod = Mod, sock = Sock}, Data) ->
      Packet = pgsql_wire:encode(Data),
      Mod:send(Sock, Packet).
  %% Encode Data as a frontend message of type byte Type and write it with
  %% the transport module.
  send(#state{mod = Mod, sock = Sock}, Type, Data) ->
      Packet = pgsql_wire:encode(Type, Data),
      Mod:send(Sock, Packet).
  %% Decode buffered backend messages one at a time and pass each to the
  %% current handler (called as ?MODULE:Handler, so it must be exported).
  %% Stops early when a handler returns {stop, ...}. When decode_message/1
  %% does not yield {Message, Tail} (presumably an incomplete message), the
  %% buffer is kept unchanged for the next chunk.
  loop(#state{data = Data, handler = Handler} = State) ->
      case pgsql_wire:decode_message(Data) of
          {Message, Tail} ->
              Next = State#state{data = Tail},
              case ?MODULE:Handler(Message, Next) of
                  {noreply, Next2} -> loop(Next2);
                  {stop, _, _} = Stop -> Stop
              end;
          _Incomplete ->
              {noreply, State}
      end.
  %% Send {Ref, Message} to the caller of the oldest queued request, dequeue
  %% it and reset the per-request accumulators.
  reply(State = #state{queue = Q}, Message) ->
      {{From, Ref}, _Command} = queue:get(Q),
      From ! {Ref, Message},
      Remaining = queue:drop(Q),
      State#state{queue = Remaining, statement = undefined,
                  columns = [], rows = [], results = []}.
  %% Push one statement's Result onto `results' (newest first) and clear the
  %% per-statement accumulators. The request stays queued.
  add_result(State = #state{results = Results}, Result) ->
      Cleared = State#state{statement = undefined, columns = [], rows = []},
      Cleared#state{results = [Result | Results]}.
  %% Forward Msg as {pgsql, ConnPid, Msg} to the async listener if one is set.
  %% Returns the sent message, or false when there is no listener pid.
  notify_async(#state{async = Pid}, Msg) when is_pid(Pid) ->
      Pid ! {pgsql, self(), Msg};
  notify_async(#state{}, _Msg) ->
      false.
  %% Name of the command at the head of the queue: the first element of a
  %% tuple command, or the atom itself (e.g. sync).
  command_tag(#state{queue = Q}) ->
      {_From, Command} = queue:get(Q),
      if is_atom(Command) -> Command;
         is_tuple(Command) -> element(1, Command)
      end.
  %% Columns used to decode DataRow for the request at the head of the queue.
  %% equery/execute carry them in their #statement{}; squery uses the
  %% RowDescription stored in the state.
  get_columns(State) ->
      #state{queue = Q, columns = Described} = State,
      case queue:get(Q) of
          {_, {squery, _}} -> Described;
          {_, {equery, #statement{columns = Cols}, _}} -> Cols;
          {_, {execute, #statement{columns = Cols}, _, _}} -> Cols
      end.
  %% After an extended-query error, answer every queued request up to the
  %% next `sync' with {error, sync_required}. If a sync is already queued, the
  %% state is returned as is. Otherwise the sync_required flag is raised.
  sync_required(#state{queue = Q} = State) ->
      case queue:is_empty(Q) of
          true ->
              State#state{sync_required = true};
          false ->
              case command_tag(State) of
                  sync -> State;
                  _Other -> sync_required(reply(State, {error, sync_required}))
              end
      end.
  %% Answer every queued request, oldest first, with Error.
  flush_queue(#state{queue = Q} = State, Error) ->
      case queue:is_empty(Q) of
          true -> State;
          false -> flush_queue(reply(State, Error), Error)
      end.
  %% Normalize a parameter name to a binary: binaries pass through, byte lists
  %% go through list_to_binary/1, and atoms are UTF-8 encoded. The atom clause
  %% lets callers write get_parameter(C, server_version).
  to_binary(B) when is_binary(B) -> B;
  to_binary(L) when is_list(L) -> list_to_binary(L);
  to_binary(A) when is_atom(A) -> atom_to_binary(A, utf8).
  %% Lowercase hexadecimal encoding of Bin, two characters per byte
  %% (used for the MD5 password response).
  hex(Bin) ->
      Digit = fun(N) when N < 10 -> $0 + N;
                 (N) when N < 16 -> $a + (N - 10)
              end,
      << <<(Digit(Hi)), (Digit(Lo))>> || <<Hi:4, Lo:4>> <= Bin >>.
  297. %% -- backend message handling --
  %% AuthenticationOk: continue with initializing/2.
  auth({$R, <<0:?int32>>}, State) ->
      {noreply, State#state{handler = initializing}};
  %% AuthenticationCleartextPassword: send the stored password.
  auth({$R, <<3:?int32>>}, State) ->
      Password = get(password),
      send(State, $p, [Password, 0]),
      {noreply, State};
  %% AuthenticationMD5Password: send
  %% "md5" ++ hex(md5(hex(md5(Password ++ Username)) ++ Salt)).
  auth({$R, <<5:?int32, Salt:4/binary>>}, State) ->
      Inner = hex(erlang:md5([get(password), get(username)])),
      Outer = hex(erlang:md5([Inner, Salt])),
      send(State, $p, ["md5", Outer, 0]),
      {noreply, State};
  %% Any other authentication request is unsupported: fail connect and stop.
  auth({$R, <<M:?int32, _/binary>>}, State) ->
      Method = case M of
                   2 -> kerberosV5;
                   4 -> crypt;
                   6 -> scm;
                   7 -> gss;
                   8 -> sspi;
                   _ -> unknown
               end,
      {stop, normal, reply(State, {error, {unsupported_auth_method, Method}})};
  %% ErrorResponse: map the two auth SQLSTATEs to atoms, else keep the code.
  auth({error, E}, State) ->
      Why = case E#error.code of
                <<"28000">> -> invalid_authorization_specification;
                <<"28P01">> -> invalid_password;
                Code -> Code
            end,
      {stop, normal, reply(State, {error, Why})};
  %% Any other message goes to the common handler.
  auth(Other, State) ->
      on_message(Other, State).
  %% BackendKeyData: remember {Pid, Key}, which cancel requests need.
  initializing({$K, <<Pid:?int32, Key:?int32>>}, State) ->
      {noreply, State#state{backend = {Pid, Key}}};
  %% ReadyForQuery: startup is complete. Drop the stored credentials, select
  %% the datetime codec and answer the connect request with `connected'.
  initializing({$Z, <<Status:8>>}, State) ->
      #state{parameters = Parameters} = State,
      erase(username),
      erase(password),
      %% TODO decode dates to now() format
      put(datetime_mod, datetime_codec(Parameters)),
      State2 = reply(State#state{handler = on_message,
                                 txstatus = Status},
                     connected),
      {noreply, State2};
  initializing({error, _} = Error, State) ->
      {stop, normal, reply(State, Error)};
  initializing(Other, State) ->
      on_message(Other, State).

  %% Choose the datetime codec from the integer_datetimes ParameterStatus.
  %% Only an explicit "off" selects float datetimes. A server that does not
  %% report the parameter is assumed to use integer datetimes (the PostgreSQL
  %% default since 8.4, and the only option since 10); previously this case
  %% crashed the connection with case_clause.
  datetime_codec(Parameters) ->
      case lists:keyfind(<<"integer_datetimes">>, 1, Parameters) of
          {_, <<"off">>} -> pgsql_fdatetime;
          _ -> pgsql_idatetime
      end.
  %% ParseComplete: nothing to do. `parse' is answered once the statement
  %% description (RowDescription or NoData) arrives.
  on_message({$1, <<>>}, State) ->
      {noreply, State};
  %% ParameterDescription: start a #statement{} with the parameter types for
  %% the parse/describe_statement request at the head of the queue.
  on_message({$t, <<_Count:?int16, Bin/binary>>}, State) ->
      #state{queue = Q} = State,
      Types = [pgsql_types:oid2type(Oid) || <<Oid:?int32>> <= Bin],
      Name = case queue:get(Q) of
                 {_, {parse, N, _, _}} -> N;
                 {_, {describe_statement, N}} -> N
             end,
      Stmt = #statement{name = Name, types = Types},
      {noreply, State#state{statement = Stmt}};
  %% RowDescription
  on_message({$T, <<Count:?int16, Bin/binary>>}, State) ->
      Columns = pgsql_wire:decode_columns(Count, Bin),
      State2 = case command_tag(State) of
                   squery ->
                       State#state{columns = Columns};
                   Cmd when Cmd == parse; Cmd == describe_statement ->
                       Formatted = [Col#column{format = pgsql_wire:format(Col#column.type)}
                                    || Col <- Columns],
                       Stmt = State#state.statement,
                       reply(State, {ok, Stmt#statement{columns = Formatted}});
                   describe_portal ->
                       reply(State, {ok, Columns})
               end,
      {noreply, State2};
  %% NoData
  on_message({$n, <<>>}, State) ->
      State2 = case command_tag(State) of
                   Cmd when Cmd == parse; Cmd == describe_statement ->
                       Stmt = State#state.statement,
                       reply(State, {ok, Stmt#statement{columns = []}});
                   describe_portal ->
                       reply(State, {ok, []})
               end,
      {noreply, State2};
  %% BindComplete: answers `bind'; part of the equery sequence otherwise.
  on_message({$2, <<>>}, State) ->
      State2 = case command_tag(State) of
                   equery -> State;
                   bind -> reply(State, ok)
               end,
      {noreply, State2};
  %% CloseComplete: answers `close'; part of the equery sequence otherwise.
  on_message({$3, <<>>}, State) ->
      State2 = case command_tag(State) of
                   close -> reply(State, ok);
                   equery -> State
               end,
      {noreply, State2};
  %% DataRow: decode and prepend (rows are kept newest first).
  on_message({$D, <<_Count:?int16, Bin/binary>>}, State) ->
      Row = pgsql_wire:decode_data(get_columns(State), Bin),
      {noreply, State#state{rows = [Row | State#state.rows]}};
  %% PortalSuspended: row limit reached; answer with the rows so far.
  on_message({$s, <<>>}, State) ->
      Partial = lists:reverse(State#state.rows),
      {noreply, reply(State, {partial, Partial})};
  %% CommandComplete: `execute' is answered now; squery/equery collect a
  %% result and are answered on ReadyForQuery.
  on_message({$C, Bin}, State) ->
      Complete = pgsql_wire:decode_complete(Bin),
      Command = command_tag(State),
      Rows = lists:reverse(State#state.rows),
      State2 = case {Command, Complete, Rows} of
                   {execute, {_, Count}, []} ->
                       reply(State, {ok, Count});
                   {execute, {_, Count}, _} ->
                       reply(State, {ok, Count, Rows});
                   {execute, _, _} ->
                       reply(State, {ok, Rows});
                   {Cmd, {_, Count}, []} when Cmd == squery; Cmd == equery ->
                       add_result(State, {ok, Count});
                   {Cmd, {_, Count}, _} when Cmd == squery; Cmd == equery ->
                       add_result(State, {ok, Count, get_columns(State), Rows});
                   {Cmd, _, _} when Cmd == squery; Cmd == equery ->
                       add_result(State, {ok, get_columns(State), Rows})
               end,
      {noreply, State2};
  %% EmptyQueryResponse
  on_message({$I, _Bin}, State) ->
      State2 = case command_tag(State) of
                   execute ->
                       reply(State, {ok, [], []});
                   Cmd when Cmd == squery; Cmd == equery ->
                       add_result(State, {ok, [], []})
               end,
      {noreply, State2};
  %% ReadyForQuery: deliver collected results (a single squery result is not
  %% wrapped in a list), acknowledge `sync', record the transaction status.
  on_message({$Z, <<Status:8>>}, State) ->
      State2 = case command_tag(State) of
                   squery ->
                       case State#state.results of
                           [Single] -> reply(State, Single);
                           Many -> reply(State, lists:reverse(Many))
                       end;
                   equery ->
                       [Single] = State#state.results,
                       reply(State, Single);
                   sync ->
                       reply(State, ok)
               end,
      {noreply, State2#state{txstatus = Status}};
  %% ErrorResponse: a result for squery/equery. Any other command is answered
  %% with the error and later requests need a sync first.
  on_message(Error = {error, _}, State) ->
      State2 = case command_tag(State) of
                   Cmd when Cmd == squery; Cmd == equery ->
                       add_result(State, Error);
                   _ ->
                       sync_required(reply(State, Error))
               end,
      {noreply, State2};
  %% NoticeResponse: forwarded to the async listener, if any.
  on_message({$N, Data}, State) ->
      Notice = pgsql_wire:decode_error(Data),
      notify_async(State, {notice, Notice}),
      {noreply, State};
  %% ParameterStatus: insert or replace {Name, Value}.
  on_message({$S, Data}, State) ->
      [Name, Value] = pgsql_wire:decode_strings(Data),
      Params = lists:keystore(Name, 1, State#state.parameters, {Name, Value}),
      {noreply, State#state{parameters = Params}};
  %% NotificationResponse: forwarded to the async listener; a missing payload
  %% is reported as <<>>.
  on_message({$A, <<Pid:?int32, Strings/binary>>}, State) ->
      {Channel, Payload} = case pgsql_wire:decode_strings(Strings) of
                               [Chan, Body] -> {Chan, Body};
                               [Chan] -> {Chan, <<>>}
                           end,
      notify_async(State, {notification, Channel, Pid, Payload}),
      {noreply, State}.