%% pgsql_sock.erl (4.9 KB)

  %%% Copyright (C) 2009 - Will Glozer. All rights reserved.
  -module(pgsql_sock).
  -behavior(gen_server).

  %% Client API.
  -export([start_link/4, cancel/1]).
  %% NOTE(review): decode_string/1 and lower_atom/1 are not defined in this
  %% listing; confirm they exist elsewhere in the module or drop this export.
  -export([decode_string/1, lower_atom/1]).
  %% gen_server callbacks.
  -export([handle_call/3, handle_cast/2, handle_info/2]).
  -export([init/1, code_change/3, terminate/2]).

  -include("pgsql.hrl").
  %% Presumably provides the ?int32 bit-syntax macro used below — verify.
  -include("pgsql_binary.hrl").

  %% Connection process state.
  %%   mod     - transport module used for socket I/O: gen_tcp or ssl.
  %%   sock    - the connected socket (matching mod).
  %%   tail    - bytes received but not yet decoded into a complete message;
  %%             defaults to <<>> so binary appends work from the start.
  %%   backend - {Pid, Key} integers sent in the CancelRequest built by
  %%             handle_cast/2; undefined until known.
  %%   decoder - value returned by pgsql_wire:init/1 in init/1.
  %%   timeout - connect/recv timeout in milliseconds (init/1 default 5000).
  -record(state, {mod,
                  sock,
                  tail = <<>>,
                  backend,
                  decoder,
                  timeout}).
  %% -- client interface --
  %% @doc Starts a linked connection process and hands the connection
  %% parameters to init/1 as a single argument list.
  start_link(Host, Username, Password, Opts) ->
      InitArgs = [Host, Username, Password, Opts],
      gen_server:start_link(?MODULE, InitArgs, []).
  %% @doc Asks connection process S to cancel the query running on its
  %% backend. Asynchronous (gen_server:cast/2): always returns ok at once.
  cancel(S) ->
      gen_server:cast(S, cancel).
  %% -- gen_server implementation --
  %% @doc gen_server init: opens the TCP connection, optionally negotiates
  %% SSL, sends the startup message and waits for authentication.
  %%
  %% Opts (proplist) read here:
  %%   timeout  - connect/recv timeout in ms (default 5000)
  %%   port     - server port (default 5432)
  %%   ssl      - true | required requests SSL; anything else means plain TCP
  %%   database - optional database name for the startup message
  init([Host, Username, Password, Opts]) ->
      %% TODO split connect/query timeout?
      Timeout = proplists:get_value(timeout, Opts, 5000),
      Port = proplists:get_value(port, Opts, 5432),
      SockOpts = [{active, false}, {packet, raw}, binary, {nodelay, true}],
      {ok, S} = gen_tcp:connect(Host, Port, SockOpts, Timeout),
      State = #state{mod = gen_tcp,
                     sock = S,
                     tail = <<>>,
                     decoder = pgsql_wire:init([]),
                     timeout = Timeout},
      State2 =
          case proplists:get_value(ssl, Opts) of
              T when T =:= true; T =:= required ->
                  %% SSLRequest (length 8, code 80877103); the server replies
                  %% with one byte that start_ssl/4 interprets ($S or $N).
                  ok = gen_tcp:send(S, <<8:?int32, 80877103:?int32>>),
                  {ok, <<Code>>} = gen_tcp:recv(S, 1, Timeout),
                  start_ssl(Code, T, Opts, State);
              _ ->
                  State
          end,
      %% Startup parameters: NUL-terminated name/value pairs.
      Params =
          case proplists:get_value(database, Opts) of
              undefined -> ["user", 0, Username, 0];
              Database  -> ["user", 0, Username, 0, "database", 0, Database, 0]
          end,
      %% 196608 = 3 bsl 16 (protocol version 3.0); the final 0 terminates
      %% the parameter list. Fail loudly if the socket write fails.
      ok = send([<<196608:?int32>>, Params, 0], State2),
      %% TODO Async = proplists:get_value(async, Opts, undefined),
      %% TODO setopts(State2, [{active, true}]),
      %% NOTE(review): initialize/1 is not defined in this listing; confirm.
      {ok, initialize(auth(Username, Password, State2))}.
  %% Synchronous requests are not part of this server's protocol: any call
  %% stops the process with reason {unsupported_call, Request}.
  handle_call(Request, _From, State) ->
      Reason = {unsupported_call, Request},
      {stop, Reason, State}.
  %% @doc Handles cancel/1.
  %%
  %% Sends a CancelRequest for the backend {Pid, Key} stored in #state.backend
  %% over a fresh, short-lived TCP connection to the same server address as
  %% the main socket. If no backend key is known yet, there is nothing to
  %% cancel, so the request is ignored instead of crashing the connection.
  handle_cast(cancel, State = #state{backend = undefined}) ->
      {noreply, State};
  handle_cast(cancel, State = #state{backend = {Pid, Key}, timeout = Timeout}) ->
      {ok, {Addr, Port}} = peername(State),
      SockOpts = [{active, false}, {packet, raw}, binary],
      %% Bound the connect so a dead server cannot block this process forever.
      {ok, Sock} = gen_tcp:connect(Addr, Port, SockOpts, Timeout),
      %% CancelRequest: length 16, code 80877102, backend pid, secret key.
      Msg = <<16:?int32, 80877102:?int32, Pid:?int32, Key:?int32>>,
      try
          ok = gen_tcp:send(Sock, Msg)
      after
          gen_tcp:close(Sock)
      end,
      {noreply, State}.

  %% Peer address of the main socket; ssl sockets must be queried through
  %% ssl:peername/1, since inet:peername/1 only accepts plain sockets.
  peername(#state{mod = gen_tcp, sock = Sock}) -> inet:peername(Sock);
  peername(#state{mod = ssl, sock = Sock}) -> ssl:peername(Sock).
  %% @doc Handles socket messages (delivered once the socket is active).
  %% Close/error notifications stop the server. Incoming data is appended to
  %% the undecoded tail, and complete messages are processed by on_tail/1.
  %% Any other message is dropped so stray messages cannot crash the
  %% connection.
  handle_info({Closed, _Sock}, State)
          when Closed =:= tcp_closed; Closed =:= ssl_closed ->
      {stop, sock_closed, State};
  handle_info({Error, _Sock, Reason}, State)
          when Error =:= tcp_error; Error =:= ssl_error ->
      {stop, {sock_error, Reason}, State};
  handle_info({Tag, _Sock, Data}, #state{tail = Tail} = State)
          when Tag =:= tcp; Tag =:= ssl ->
      on_tail(State#state{tail = <<Tail/binary, Data/binary>>});
  handle_info(_Unexpected, State) ->
      {noreply, State}.
  %% Decodes and dispatches every complete message buffered in #state.tail.
  %% Stops as soon as pgsql_wire:decode_message/1 returns anything other than
  %% {Message, Rest} (presumably an incomplete message). In that case the
  %% undecoded bytes stay in the tail for the next packet.
  on_tail(#state{tail = Tail} = State) ->
      case pgsql_wire:decode_message(Tail) of
          {Message, Tail2} ->
              on_tail(on_message(Message, State#state{tail = Tail2}));
          _Incomplete ->
              {noreply, State}
      end.
  %% Shutdown hook: performs no cleanup and returns ok for every reason.
  terminate(_Why, _St) ->
      ok.
  %% Hot code upgrade: the state is carried over unchanged.
  code_change(_PrevVsn, St, _Extra) ->
      {ok, St}.
  %% -- internal functions --
  %% Completes SSL negotiation given the server's one-byte SSLRequest answer.
  %%   $S - accepted: upgrade the socket with ssl:connect/3. A failed
  %%        handshake exits with {ssl_negotiation_failed, Reason}.
  %%   $N - refused: continue in plain TCP when Flag is true; exit with
  %%        ssl_not_available when Flag is required.
  %% NOTE(review): the entire connection option list (port, database, ...)
  %% is passed to ssl:connect/3 — confirm ssl tolerates non-SSL keys.
  start_ssl($S, _Flag, Opts, State) ->
      #state{sock = PlainSock, timeout = Timeout} = State,
      case ssl:connect(PlainSock, Opts, Timeout) of
          {error, Why} ->
              exit({ssl_negotiation_failed, Why});
          {ok, SslSock} ->
              State#state{mod = ssl, sock = SslSock}
      end;
  start_ssl($N, true, _Opts, State) ->
      State;
  start_ssl($N, required, _Opts, _State) ->
      exit(ssl_not_available).
  %% Applies socket options through the transport recorded in the state:
  %% inet:setopts/2 for gen_tcp sockets, ssl:setopts/2 for ssl sockets.
  setopts(#state{mod = gen_tcp, sock = Sock}, Opts) ->
      inet:setopts(Sock, Opts);
  setopts(#state{mod = ssl, sock = Sock}, Opts) ->
      ssl:setopts(Sock, Opts).
  %% Encodes Data with pgsql_wire:encode/1 and writes it to the socket via
  %% the state's transport module; returns that module's send/2 result.
  send(Data, #state{mod = Mod, sock = Sock}) ->
      Mod:send(Sock, pgsql_wire:encode(Data)).
  %% Encodes a typed message with pgsql_wire:encode(Type, Data) and writes
  %% it via the state's transport module; returns that module's send/2
  %% result.
  send(Type, Data, #state{mod = Mod, sock = Sock}) ->
      Mod:send(Sock, pgsql_wire:encode(Type, Data)).
  %% Blocks up to #state.timeout ms for the next chunk of socket data and
  %% appends it to the undecoded tail. Crashes with badmatch if the receive
  %% fails (e.g. timeout or closed socket).
  recv(State = #state{mod = Transport, sock = Sock, timeout = Timeout}) ->
      {ok, Chunk} = Transport:recv(Sock, 0, Timeout),
      Buffered = State#state.tail,
      State#state{tail = <<Buffered/binary, Chunk/binary>>}.
  %% @doc Runs the authentication phase of the startup handshake.
  %%
  %% Receives and decodes messages until AuthenticationOk arrives, then
  %% returns the state; any bytes after it remain in #state.tail. Handled
  %% authentication requests ({$R, Code}):
  %%   0 - AuthenticationOk: done.
  %%   3 - cleartext password: reply with Password.
  %%   5 - MD5 with a 4-byte salt: reply with
  %%       "md5" ++ md5hex(md5hex(Password ++ User) ++ Salt).
  %% An ErrorResponse ($E) exits with {authentication_failed, Error}; any
  %% other message crashes with case_clause.
  %% NOTE(review): assumes send/3 (pgsql_wire:encode/2) frames Type + Data
  %% as a complete protocol message — confirm.
  auth(User, Password, State) ->
      State2 = #state{tail = Tail} = recv(State),
      case pgsql_wire:decode_message(Tail) of
          {Message, Tail2} ->
              State3 = State2#state{tail = Tail2},
              case Message of
                  {$R, <<0:?int32>>} ->
                      State3;
                  {$R, <<3:?int32>>} ->
                      ok = send($p, [Password, 0], State3),
                      auth(User, Password, State3);
                  {$R, <<5:?int32, Salt:4/binary>>} ->
                      Digest = md5_hex([md5_hex([Password, User]), Salt]),
                      ok = send($p, ["md5", Digest, 0], State3),
                      auth(User, Password, State3);
                  {$E, Data} ->
                      exit({authentication_failed,
                            pgsql_wire:decode_error(Data)})
              end;
          _Incomplete ->
              %% Not a full message yet: read more bytes and retry.
              auth(User, Password, State2)
      end.

  %% Lowercase 32-character hex encoding of the MD5 digest of iodata Data.
  md5_hex(Data) ->
      <<N:128>> = erlang:md5(Data),
      io_lib:format("~32.16.0b", [N]).
  %% Applies one decoded backend message {Type, Payload} to the state and
  %% returns the new state.
  %%
  %% Only BackendKeyData ($K) changes the state: it records {Pid, Key} for
  %% handle_cast(cancel, ...). Notices, parameter statuses, errors and
  %% notifications are decoded but not yet delivered anywhere (see TODOs).
  %% Unrecognised messages are ignored.
  on_message({$N, Data}, State) ->
      %% TODO use it
      _Notice = {notice, pgsql_wire:decode_error(Data)},
      State;
  on_message({$S, Data}, State) ->
      [Name, Value] = pgsql_wire:decode_strings(Data),
      %% TODO use it
      _ParameterStatus = {parameter_status, Name, Value},
      State;
  on_message({$K, <<Pid:?int32, Key:?int32>>}, State) ->
      State#state{backend = {Pid, Key}};
  on_message({$E, Data}, State) ->
      %% TODO use it
      _Error = {error, pgsql_wire:decode_error(Data)},
      State;
  on_message({$A, <<Pid:?int32, Strings/binary>>}, State) ->
      %% The payload string is optional; default it to an empty binary.
      {Channel, Payload} =
          case pgsql_wire:decode_strings(Strings) of
              [Chan, Pay] -> {Chan, Pay};
              [Chan] -> {Chan, <<>>}
          end,
      %% TODO use it
      _Notification = {notification, Channel, Pid, Payload},
      State;
  on_message(_Msg, State) ->
      State.