Просмотр исходного кода

implement support for SSL session encryption

Will 16 лет назад
Родитель
Сommit
6c1f19d2c5
7 измененных файлов с 249 добавлено и 139 удалено
  1. 1 1
      Makefile
  2. 1 0
      README
  3. 4 3
      src/epgsql.app
  4. 26 133
      src/pgsql_connection.erl
  5. 201 0
      src/pgsql_sock.erl
  6. 13 2
      test_src/pgsql_tests.erl
  7. 3 0
      test_src/test_schema.sql

+ 1 - 1
Makefile

@@ -1,5 +1,5 @@
 NAME		:= epgsql
-VERSION		:= 1.0
+VERSION		:= 1.1
 
 ERL  		:= erl
 ERLC 		:= erlc

+ 1 - 0
README

@@ -8,6 +8,7 @@ Erlang PostgreSQL Database Client
 
   - database
   - port
+  - ssl (true | false | required)
 
   ok = pgsql:close(C).
 

+ 4 - 3
src/epgsql.app

@@ -1,7 +1,8 @@
 {application, epgsql,
  [{description, "PostgreSQL Client"},
-  {vsn, "1.0"},
-  {modules, [pgsql, pgsql_binary, pgsql_connection, pgsql_datetime, pgsql_types]},
+  {vsn, "1.1"},
+  {modules, [pgsql, pgsql_binary, pgsql_connection, pgsql_fdatetime,
+             pgsql_idatetime, pgsql_sock, pgsql_types]},
   {registered, []},
-  {applications, [kernel, stdlib]},
+  {applications, [kernel, stdlib, crypto, ssl]},
   {included_applications, []}]}.

+ 26 - 133
src/pgsql_connection.erl

@@ -11,7 +11,6 @@
 
 -export([init/1, handle_event/3, handle_sync_event/4]).
 -export([handle_info/3, terminate/3, code_change/4]).
--export([read/3]).
 
 -export([startup/3, auth/2, initializing/2, ready/2, ready/3]).
 -export([querying/2, parsing/2, binding/2, describing/2]).
@@ -93,7 +92,7 @@ handle_event(Event, _State_Name, State) ->
 handle_sync_event(Event, _From, _State_Name, State) ->
     {stop, {unsupported_sync_event, Event}, State}.
 
-handle_info({'EXIT', Pid, Reason}, _State_Name, State = #state{reader = Pid}) ->
+handle_info({'EXIT', Pid, Reason}, _State_Name, State = #state{sock = Pid}) ->
     {stop, Reason, State};
 
 handle_info(Info, _State_Name, State) ->
@@ -101,8 +100,7 @@ handle_info(Info, _State_Name, State) ->
 
 terminate(_Reason, _State_Name, State = #state{sock = Sock})
   when Sock =/= undefined ->
-    send(State, $X, []),
-    gen_tcp:close(Sock);
+    send(State, $X, []);
 
 terminate(_Reason, _State_Name, _State) ->
     ok.
@@ -113,25 +111,11 @@ code_change(_Old_Vsn, State_Name, State, _Extra) ->
 %% -- states --
 
 startup({connect, Host, Username, Password, Opts}, From, State) ->
-    Port      = proplists:get_value(port, Opts, 5432),
-    Sock_Opts = [{active, false}, {packet, raw}, binary],
-    case gen_tcp:connect(Host, Port, Sock_Opts) of
+    case pgsql_sock:start_link(self(), Host, Username, Opts) of
         {ok, Sock} ->
-            Reader = spawn_link(?MODULE, read, [self(), Sock, <<>>]),
-
-            Opts2 = ["user", 0, Username, 0],
-            case proplists:get_value(database, Opts, undefined) of
-                undefined -> Opts3 = Opts2;
-                Database  -> Opts3 = [Opts2 | ["database", 0, Database, 0]]
-            end,
-
             put(username, Username),
             put(password, Password),
-            State2 = State#state{reader   = Reader,
-                                 sock     = Sock,
-                                 reply_to = From},
-            send(State2, [<<196608:32>>, Opts3, 0]),
-
+            State2 = State#state{sock = Sock, reply_to = From},
             {next_state, auth, State2};
         Error ->
             {stop, normal, Error, State}
@@ -167,9 +151,8 @@ auth({$R, <<M:?int32, _/binary>>}, State) ->
     {stop, normal, State};
 
 %% ErrorResponse
-auth({$E, Bin}, State) ->
-    Error = decode_error(Bin),
-    case Error#error.code of
+auth({error, E}, State) ->
+    case E#error.code of
         <<"28000">> -> Why = invalid_authorization_specification;
         Any         -> Why = Any
     end,
@@ -182,9 +165,8 @@ initializing({$K, <<Pid:?int32, Key:?int32>>}, State) ->
     {next_state, initializing, State2};
 
 %% ErrorResponse
-initializing({$E, Bin}, State) ->
-    Error = decode_error(Bin),
-    case Error#error.code of
+initializing({error, E}, State) ->
+    case E#error.code of
         <<"28000">> -> Why = invalid_authorization_specification;
         Any         -> Why = Any
     end,
@@ -311,9 +293,8 @@ querying({$I, _Bin}, State) ->
     {next_state, querying, State};
 
 %% ErrorResponse
-querying({$E, Bin}, State) ->
-    Error = decode_error(Bin),
-    notify(State, {error, Error}),
+querying({error, E}, State) ->
+    notify(State, {error, E}),
     {next_state, querying, State};
 
 %% ReadyForQuery
@@ -326,8 +307,8 @@ parsing({$1, <<>>}, State) ->
     {next_state, describing, State};
 
 %% ErrorResponse
-parsing({$E, Bin}, State) ->
-    Reply = {error, decode_error(Bin)},
+parsing({error, E}, State) ->
+    Reply = {error, E},
     send(State, $S, []),
     {next_state, parsing, State#state{reply = Reply}};
 
@@ -343,8 +324,8 @@ binding({$2, <<>>}, State) ->
     {next_state, ready, State};
 
 %% ErrorResponse
-binding({$E, Bin}, State) ->
-    Reply = {error, decode_error(Bin)},
+binding({error, E}, State) ->
+    Reply = {error, E},
     send(State, $S, []),
     {next_state, binding, State#state{reply = Reply}};
 
@@ -375,8 +356,8 @@ describing({$n, <<>>}, State) ->
     {next_state, ready, State};
 
 %% ErrorResponse
-describing({$E, Bin}, State) ->
-    Reply = {error, decode_error(Bin)},
+describing({error, E}, State) ->
+    Reply = {error, E},
     send(State, $S, []),
     {next_state, describing, State#state{reply = Reply}};
 
@@ -409,8 +390,8 @@ executing({$I, _Bin}, State) ->
     {next_state, ready, State};
 
 %% ErrorResponse
-executing({$E, Bin}, State) ->
-    notify(State, {error, decode_error(Bin)}),
+executing({error, E}, State) ->
+    notify(State, {error, E}),
     {next_state, executing, State}.
 
 %% CloseComplete
@@ -419,14 +400,14 @@ closing({$3, <<>>}, State) ->
     {next_state, ready, State};
 
 %% ErrorResponse
-closing({$E, Bin}, State) ->
-    Error = {error, decode_error(Bin)},
+closing({error, E}, State) ->
+    Error = {error, E},
     gen_fsm:reply(State#state.reply_to, Error),
     {next_state, ready, State}.
 
 %% ErrorResponse
-synchronizing({$E, Bin}, State) ->
-    Reply = {error, decode_error(Bin)},
+synchronizing({error, E}, State) ->
+    Reply = {error, E},
     {next_state, synchronizing, State#state{reply = Reply}};
 
 %% ReadyForQuery
@@ -437,35 +418,6 @@ synchronizing({$Z, <<Status:8>>}, State) ->
 
 %% -- internal functions --
 
-%% decode a single null-terminated string
-decode_string(Bin) ->
-    decode_string(Bin, <<>>).
-
-decode_string(<<0, Rest/binary>>, Str) ->
-    {Str, Rest};
-decode_string(<<C, Rest/binary>>, Str) ->
-    decode_string(Rest, <<Str/binary, C>>).
-
-%% decode multiple null-terminated string
-decode_strings(Bin) ->
-    decode_strings(Bin, []).
-
-decode_strings(<<>>, Acc) ->
-    lists:reverse(Acc);
-decode_strings(Bin, Acc) ->
-    {Str, Rest} = decode_string(Bin),
-    decode_strings(Rest, [Str | Acc]).
-
-%% decode field
-decode_fields(Bin) ->
-    decode_fields(Bin, []).
-
-decode_fields(<<0>>, Acc) ->
-    Acc;
-decode_fields(<<Type:8, Rest/binary>>, Acc) ->
-    {Str, Rest2} = decode_string(Rest),
-    decode_fields(Rest2, [{Type, Str} | Acc]).
-
 %% decode data
 decode_data(Columns, Bin) ->
     decode_data(Columns, Bin, []).
@@ -488,7 +440,7 @@ decode_columns(Count, Bin) ->
 decode_columns(0, _Bin, Acc) ->
     lists:reverse(Acc);
 decode_columns(N, Bin, Acc) ->
-    {Name, Rest} = decode_string(Bin),
+    {Name, Rest} = pgsql_sock:decode_string(Bin),
     <<_Table_Oid:?int32, _Attrib_Num:?int16, Type_Oid:?int32,
      Size:?int16, Modifier:?int32, Format:?int16, Rest2/binary>> = Rest,
     Desc = #column{
@@ -504,36 +456,14 @@ decode_complete(<<"SELECT", 0>>)   -> select;
 decode_complete(<<"BEGIN", 0>>)    -> 'begin';
 decode_complete(<<"ROLLBACK", 0>>) -> rollback;
 decode_complete(Bin) ->
-    {Str, _} = decode_string(Bin),
+    {Str, _} = pgsql_sock:decode_string(Bin),
     case string:tokens(binary_to_list(Str), " ") of
         ["INSERT", _Oid, Rows] -> {insert, list_to_integer(Rows)};
         ["UPDATE", Rows]       -> {update, list_to_integer(Rows)};
         ["DELETE", Rows]       -> {delete, list_to_integer(Rows)};
         ["MOVE", Rows]         -> {move, list_to_integer(Rows)};
         ["FETCH", Rows]        -> {fetch, list_to_integer(Rows)};
-        [Type | _Rest]         -> lower_atom(Type)
-    end.
-
-%% decode ErrorResponse
-decode_error(Bin) ->
-    Fields = decode_fields(Bin),
-    Error = #error{
-      severity = lower_atom(proplists:get_value($S, Fields)),
-      code     = proplists:get_value($C, Fields),
-      message  = proplists:get_value($M, Fields),
-      extra    = decode_error_extra(Fields)},
-    Error.
-
-decode_error_extra(Fields) ->
-    Types = [{$D, detail}, {$H, hint}, {$P, position}],
-    decode_error_extra(Types, Fields, []).
-
-decode_error_extra([], _Fields, Extra) ->
-    Extra;
-decode_error_extra([{Type, Name} | T], Fields, Extra) ->
-    case proplists:get_value(Type, Fields) of
-        undefined -> decode_error_extra(T, Fields, Extra);
-        Value     -> decode_error_extra(T, Fields, [{Name, Value} | Extra])
+        [Type | _Rest]         -> pgsql_sock:lower_atom(Type)
     end.
 
 %% encode types
@@ -599,11 +529,6 @@ encode_list(L) ->
 notify(#state{reply_to = {Pid, _Tag}}, Msg) ->
     Pid ! {pgsql, self(), Msg}.
 
-lower_atom(Str) when is_binary(Str) ->
-    lower_atom(binary_to_list(Str));
-lower_atom(Str) when is_list(Str) ->
-    list_to_atom(string:to_lower(Str)).
-
 to_binary(B) when is_binary(B) -> B;
 to_binary(L) when is_list(L)   -> list_to_binary(L).
 
@@ -616,36 +541,4 @@ hex(Bin) ->
 %% send data to server
 
 send(#state{sock = Sock}, Type, Data) ->
-    Bin = iolist_to_binary(Data),
-    gen_tcp:send(Sock, <<Type:8, (byte_size(Bin) + 4):?int32, Bin/binary>>).
-
-send(#state{sock = Sock}, Data) ->
-    Bin = iolist_to_binary(Data),
-    gen_tcp:send(Sock, <<(byte_size(Bin) + 4):?int32, Bin/binary>>).
-
-%% -- socket read loop --
-
-read(Fsm, Sock, Tail) ->
-    case gen_tcp:recv(Sock, 0) of
-        {ok, Bin} -> decode(Fsm, Sock, <<Tail/binary, Bin/binary>>);
-        Error     -> exit(Error)
-    end.
-
-decode(Fsm, Sock, <<Type:8, Len:?int32, Rest/binary>> = Bin) ->
-    Len2 = Len - 4,
-    case Rest of
-        <<Data:Len2/binary, Tail/binary>> when Type == $N ->
-            gen_fsm:send_all_state_event(Fsm, {notice, decode_error(Data)}),
-            decode(Fsm, Sock, Tail);
-        <<Data:Len2/binary, Tail/binary>> when Type == $S ->
-            [Name, Value] = decode_strings(Data),
-            gen_fsm:send_all_state_event(Fsm, {parameter_status, Name, Value}),
-            decode(Fsm, Sock, Tail);
-        <<Data:Len2/binary, Tail/binary>> ->
-            gen_fsm:send_event(Fsm, {Type, Data}),
-            decode(Fsm, Sock, Tail);
-        _Other ->
-            ?MODULE:read(Fsm, Sock, Bin)
-    end;
-decode(Fsm, Sock, Bin) ->
-    ?MODULE:read(Fsm, Sock, Bin).
+    pgsql_sock:send(Sock, Type, Data).

+ 201 - 0
src/pgsql_sock.erl

@@ -0,0 +1,201 @@
+%%% Copyright (C) 2009 - Will Glozer.  All rights reserved.
+
+-module(pgsql_sock).
+
+-behavior(gen_server).
+
+-export([start_link/4, send/2, send/3]).
+-export([decode_string/1, lower_atom/1]).
+
+-export([handle_call/3, handle_cast/2, handle_info/2]).
+-export([init/1, code_change/3, terminate/2]).
+
+-include("pgsql.hrl").
+
+-record(state, {c, mod, sock, tail}).
+
+-define(int16, 1/big-signed-unit:16).
+-define(int32, 1/big-signed-unit:32).
+
+%% -- client interface --
+
+start_link(C, Host, Username, Opts) ->
+    gen_server:start_link(?MODULE, [C, Host, Username, Opts], []).
+
+send(S, Type, Data) ->
+    Bin = iolist_to_binary(Data),
+    Msg = <<Type:8, (byte_size(Bin) + 4):?int32, Bin/binary>>,
+    gen_server:cast(S, {send, Msg}).
+
+send(S, Data) ->
+    Bin = iolist_to_binary(Data),
+    Msg = <<(byte_size(Bin) + 4):?int32, Bin/binary>>,
+    gen_server:cast(S, {send, Msg}).
+
+%% -- gen_server implementation --
+
+init([C, Host, Username, Opts]) ->
+    process_flag(trap_exit, true),
+
+    Opts2 = ["user", 0, Username, 0],
+    case proplists:get_value(database, Opts, undefined) of
+        undefined -> Opts3 = Opts2;
+        Database  -> Opts3 = [Opts2 | ["database", 0, Database, 0]]
+    end,
+
+    Port = proplists:get_value(port, Opts, 5432),
+    SockOpts = [{active, false}, {packet, raw}, binary],
+    {ok, S} = gen_tcp:connect(Host, Port, SockOpts),
+
+    State = #state{
+      c    = C,
+      mod  = gen_tcp,
+      sock = S,
+      tail = <<>>},
+
+    case proplists:get_value(ssl, Opts) of
+        T when T == true; T == required ->
+            ok = gen_tcp:send(S, <<8:?int32, 80877103:?int32>>),
+            {ok, <<Code>>} = gen_tcp:recv(S, 1),
+            State2 = start_ssl(Code, T, Opts, State);
+        _ ->
+            State2 = State
+    end,
+
+    setopts(State2, [{active, true}]),
+    send(self(), [<<196608:32>>, Opts3, 0]),
+    {ok, State2}.
+
+handle_call(Call, _From, State) ->
+    {stop, {unsupported_call, Call}, State}.
+
+handle_cast({send, Data}, State) ->
+    #state{mod = Mod, sock = Sock} = State,
+    ok = Mod:send(Sock, Data),
+    {noreply, State};
+
+handle_cast(Cast, State) ->
+    {stop, {unsupported_cast, Cast}, State}.
+
+handle_info({_, _Sock, Data}, #state{tail = Tail} = State) ->
+    State2 = decode(<<Tail/binary, Data/binary>>, State),
+    {noreply, State2};
+
+handle_info({Closed, _Sock}, State)
+  when Closed == tcp_closed; Closed == ssl_closed ->
+    {stop, sock_closed, State};
+
+handle_info({Error, _Sock, Reason}, State)
+  when Error == tcp_error; Error == ssl_error ->
+    {stop, {sock_error, Reason}, State};
+
+handle_info({'EXIT', _Pid, Reason}, State) ->
+    {stop, Reason, State};
+
+handle_info(Info, State) ->
+    {stop, {unsupported_info, Info}, State}.
+
+terminate(_Reason, _State) ->
+    ok.
+
+code_change(_OldVsn, State, _Extra) ->
+    {ok, State}.
+
+%% -- internal functions --
+
+start_ssl($S, _Flag, Opts, State) ->
+    #state{sock = S1} = State,
+    case ssl:connect(S1, Opts) of
+        {ok, S2}        -> State#state{mod = ssl, sock = S2};
+        {error, Reason} -> exit({ssl_negotiation_failed, Reason})
+    end;
+
+start_ssl($N, Flag, _Opts, State) ->
+    case Flag of
+        true     -> State;
+        required -> exit(ssl_not_available)
+    end.
+
+setopts(#state{mod = Mod, sock = Sock}, Opts) ->
+    case Mod of
+        gen_tcp -> inet:setopts(Sock, Opts);
+        ssl     -> ssl:setopts(Sock, Opts)
+    end.
+
+decode(<<Type:8, Len:?int32, Rest/binary>> = Bin, #state{c = C} = State) ->
+    Len2 = Len - 4,
+    case Rest of
+        <<Data:Len2/binary, Tail/binary>> when Type == $N ->
+            gen_fsm:send_all_state_event(C, {notice, decode_error(Data)}),
+            decode(Tail, State);
+        <<Data:Len2/binary, Tail/binary>> when Type == $S ->
+            [Name, Value] = decode_strings(Data),
+            gen_fsm:send_all_state_event(C, {parameter_status, Name, Value}),
+            decode(Tail, State);
+        <<Data:Len2/binary, Tail/binary>> when Type == $E ->
+            gen_fsm:send_event(C, {error, decode_error(Data)}),
+            decode(Tail, State);
+        <<Data:Len2/binary, Tail/binary>> ->
+            gen_fsm:send_event(C, {Type, Data}),
+            decode(Tail, State);
+        _Other ->
+            State#state{tail = Bin}
+    end;
+decode(Bin, State) ->
+    State#state{tail = Bin}.
+
+%% decode a single null-terminated string
+decode_string(Bin) ->
+    decode_string(Bin, <<>>).
+
+decode_string(<<0, Rest/binary>>, Str) ->
+    {Str, Rest};
+decode_string(<<C, Rest/binary>>, Str) ->
+    decode_string(Rest, <<Str/binary, C>>).
+
+%% decode multiple null-terminated string
+decode_strings(Bin) ->
+    decode_strings(Bin, []).
+
+decode_strings(<<>>, Acc) ->
+    lists:reverse(Acc);
+decode_strings(Bin, Acc) ->
+    {Str, Rest} = decode_string(Bin),
+    decode_strings(Rest, [Str | Acc]).
+
+%% decode field
+decode_fields(Bin) ->
+    decode_fields(Bin, []).
+
+decode_fields(<<0>>, Acc) ->
+    Acc;
+decode_fields(<<Type:8, Rest/binary>>, Acc) ->
+    {Str, Rest2} = decode_string(Rest),
+    decode_fields(Rest2, [{Type, Str} | Acc]).
+
+%% decode ErrorResponse
+decode_error(Bin) ->
+    Fields = decode_fields(Bin),
+    Error = #error{
+      severity = lower_atom(proplists:get_value($S, Fields)),
+      code     = proplists:get_value($C, Fields),
+      message  = proplists:get_value($M, Fields),
+      extra    = decode_error_extra(Fields)},
+    Error.
+
+decode_error_extra(Fields) ->
+    Types = [{$D, detail}, {$H, hint}, {$P, position}],
+    decode_error_extra(Types, Fields, []).
+
+decode_error_extra([], _Fields, Extra) ->
+    Extra;
+decode_error_extra([{Type, Name} | T], Fields, Extra) ->
+    case proplists:get_value(Type, Fields) of
+        undefined -> decode_error_extra(T, Fields, Extra);
+        Value     -> decode_error_extra(T, Fields, [{Name, Value} | Extra])
+    end.
+
+lower_atom(Str) when is_binary(Str) ->
+    lower_atom(binary_to_list(Str));
+lower_atom(Str) when is_list(Str) ->
+    list_to_atom(string:to_lower(Str)).

+ 13 - 2
test_src/pgsql_tests.erl

@@ -34,6 +34,14 @@ connect_with_invalid_password_test() ->
                       "epgsql_test_sha1",
                       [{port, ?port}, {database, "epgsql_test_db1"}]).
 
+connect_with_ssl_test() ->
+    lists:foreach(fun application:start/1, [crypto, ssl]),
+    with_connection(
+      fun(C) ->
+              {ok, _Cols, [{true}]} = pgsql:equery(C, "select ssl_is_used()")
+      end,
+      [{ssl, true}]).
+
 select_test() ->
     with_connection(
       fun(C) ->
@@ -394,8 +402,11 @@ connect_only(Args) ->
     flush().
 
 with_connection(F) ->
-    Args = [{port, ?port}, {database, "epgsql_test_db1"}],
-    {ok, C} = pgsql:connect(?host, "epgsql_test", Args),
+    with_connection(F, []).
+
+with_connection(F, Args) ->
+    Args2 = [{port, ?port}, {database, "epgsql_test_db1"} | Args],
+    {ok, C} = pgsql:connect(?host, "epgsql_test", Args2),
     try
         F(C)
     after

+ 3 - 0
test_src/test_schema.sql

@@ -12,6 +12,9 @@
 --
 -- any 'trust all' must be commented out for the invalid password test
 -- to succeed.
+--
+-- ssl support must be configured, and the sslinfo contrib module
+-- loaded for the ssl tests to succeed.
 
 
 CREATE USER epgsql_test;