Browse Source

Cleanup #142

* Add typespecs
* Fix tests
* Add server proof validation
* Change the way how server version is handled
Сергей Прохоров 7 years ago
parent
commit
b92065eccc
6 changed files with 128 additions and 73 deletions
  1. 48 41
      src/commands/epgsql_cmd_connect.erl
  2. 26 5
      src/epgsql_scram.erl
  3. 3 3
      test/data/test_schema.sql
  4. 18 12
      test/epgsql_SUITE.erl
  5. 10 9
      test/epgsql_ct.erl
  6. 23 3
      test/epgsql_cth.erl

+ 48 - 41
src/commands/epgsql_cmd_connect.erl

@@ -12,18 +12,22 @@
 -type connect_error() ::
         invalid_authorization_specification
       | invalid_password
-      | {unsupported_auth_method, integer()}
+      | {unsupported_auth_method,
+         kerberosV5 | crypt | scm | gss | sspi | {unknown, integer()} | {sasl, [binary()]}}
+      | {sasl_server_final, any()}
       | epgsql:query_error().
 
 -include("epgsql.hrl").
 -include("protocol.hrl").
 
+-type auth_fun() :: fun((init | binary(), _, _) -> {send, byte(), iodata(), any()} | ok | {error, any()}).
+
 -record(connect,
         {opts :: list(),
-         auth_fun :: fun() | undefined,
+         auth_fun :: auth_fun() | undefined,
          auth_state :: any() | undefined,
          auth_send :: {integer(), iodata()} | undefined,
-         stage = connect :: connect | auth | initialization}).
+         stage = connect :: connect | maybe_auth | auth | initialization}).
 
 -define(SCRAM_AUTH_METHOD, <<"SCRAM-SHA-256">>).
 -define(AUTH_OK, 0).
@@ -83,7 +87,7 @@ execute(PgSock, #connect{opts = Opts, stage = connect} = State) ->
                           undefined -> PgSock2;
                           Async -> epgsql_sock:set_attr(async, Async, PgSock2)
                       end,
-            {ok, PgSock3, State#connect{stage = auth}};
+            {ok, PgSock3, State#connect{stage = maybe_auth}};
         {error, Reason} = Error ->
             {stop, Reason, Error, PgSock}
     end;
@@ -116,27 +120,53 @@ maybe_ssl(S, Flag, Opts, PgSock) ->
             end
     end.
 
-
 %% Auth sub-protocol
 
+auth_init(<<?AUTH_CLEARTEXT:?int32>>, Sock, St) ->
+    auth_init(fun auth_cleartext/3, undefined, Sock, St);
+auth_init(<<?AUTH_MD5:?int32, Salt:4/binary>>, Sock, St) ->
+    auth_init(fun auth_md5/3, Salt, Sock, St);
+auth_init(<<?AUTH_SASL:?int32, MethodsB/binary>>, Sock, St) ->
+    Methods = epgsql_wire:decode_strings(MethodsB),
+    case lists:member(?SCRAM_AUTH_METHOD, Methods) of
+        true ->
+            auth_init(fun auth_scram/3, undefined, Sock, St);
+        false ->
+            {stop, normal, {error, {unsupported_auth_method,
+                                    {sasl, lists:delete(<<>>, Methods)}}}}
+    end;
+auth_init(<<M:?int32, _/binary>>, Sock, _St) ->
+    Method = case M of
+                 2 -> kerberosV5;
+                 4 -> crypt;
+                 6 -> scm;
+                 7 -> gss;
+                 8 -> sspi;
+                 _ -> {unknown, M}
+             end,
+    {stop, normal, {error, {unsupported_auth_method, Method}}, Sock}.
+
 auth_init(Fun, InitState, PgSock, St) ->
     auth_handle(init, PgSock, St#connect{auth_fun = Fun, auth_state = InitState,
-                                                    stage = auth}).
+                                         stage = auth}).
 
 auth_handle(Data, PgSock, #connect{auth_fun = Fun, auth_state = AuthSt} = St) ->
     case Fun(Data, AuthSt, St) of
         {send, SendPacketId, SendData, AuthSt1} ->
             {requeue, PgSock, St#connect{auth_state = AuthSt1,
                                          auth_send = {SendPacketId, SendData}}};
-        ok ->
-            {noaction, PgSock, St}
+        ok -> {noaction, PgSock, St};
+        {error, Reason} ->
+            {stop, normal, {error, Reason}}
     end.
 
+%% AuthenticationCleartextPassword
 auth_cleartext(init, _AuthState, #connect{opts = Opts}) ->
     Password = get_val(password, Opts),
     {send, ?PASSWORD, [Password, 0], undefined};
 auth_cleartext(_, _, _) -> unknown.
 
+%% AuthenticationMD5Password
 auth_md5(init, Salt, #connect{opts = Opts}) ->
     User = get_val(username, Opts),
     Password = get_val(password, Opts),
@@ -145,7 +175,7 @@ auth_md5(init, Salt, #connect{opts = Opts}) ->
     {send, ?PASSWORD, Str, undefined};
 auth_md5(_, _, _) -> unknown.
 
-
+%% AuthenticationSASL
 auth_scram(init, undefined, #connect{opts = Opts}) ->
     User = get_val(username, Opts),
     Nonce = epgsql_scram:get_nonce(10),
@@ -158,8 +188,11 @@ auth_scram(<<?AUTH_SASL_CONTINUE:?int32, ServerFirst/binary>>, {auth_request, No
     ServerFirstParts = epgsql_scram:parse_server_first(ServerFirst, Nonce),
     {ClientFinalMessage, ServerProof} = epgsql_scram:get_client_final(ServerFirstParts, Nonce, User, Password),
     {send, ?SASL_ANY_RESPONSE, ClientFinalMessage, {server_final, ServerProof}};
-auth_scram(_Msg, {server_final, _ServerProof}, _Conn) ->
-    ok.
+auth_scram(<<?AUTH_SASL_FINAL:?int32, ServerFinalMsg/binary>>, {server_final, ServerProof}, _Conn) ->
+    case epgsql_scram:parse_server_final(ServerFinalMsg) of
+        {ok, ServerProof} -> ok;
+        Other -> {error, {sasl_server_final, Other}}
+    end.
 
 
 %% --- Auth ---
@@ -168,39 +201,12 @@ auth_scram(_Msg, {server_final, _ServerProof}, _Conn) ->
 handle_message(?AUTHENTICATION_REQUEST, <<?AUTH_OK:?int32>>, Sock, State) ->
     {noaction, Sock, State#connect{stage = initialization}};
 
-%% AuthenticationCleartextPassword
-handle_message(?AUTHENTICATION_REQUEST, <<?AUTH_CLEARTEXT:?int32>>, Sock, St) ->
-    auth_init(fun auth_cleartext/3, undefined, Sock, St);
-
-%% AuthenticationMD5Password
-handle_message(?AUTHENTICATION_REQUEST, <<?AUTH_MD5:?int32, Salt:4/binary>>, Sock, St) ->
-    auth_init(fun auth_md5/3, Salt, Sock, St);
-
-%% AuthenticationSASL
-handle_message(?AUTHENTICATION_REQUEST, <<?AUTH_SASL:?int32, MethodsB/binary>>, Sock, St) ->
-    Methods = epgsql_wire:decode_strings(MethodsB),
-    case lists:member(?SCRAM_AUTH_METHOD, Methods) of
-        true ->
-            auth_init(fun auth_scram/3, undefined, Sock, St);
-        false ->
-            {stop, normal, {error, {unsupported_auth_method,
-                                    lists:delete(<<>>, Methods)}}}
-    end;
+handle_message(?AUTHENTICATION_REQUEST, Message, Sock, #connect{stage = Stage} = St) when Stage =/= auth ->
+    auth_init(Message, Sock, St);
 
 handle_message(?AUTHENTICATION_REQUEST, Packet, Sock, #connect{stage = auth} = St) ->
     auth_handle(Packet, Sock, St);
 
-handle_message(?AUTHENTICATION_REQUEST, <<M:?int32, _/binary>>, Sock, _State) ->
-    Method = case M of
-        2 -> kerberosV5;
-        4 -> crypt;
-        6 -> scm;
-        7 -> gss;
-        8 -> sspi;
-        _ -> {unknown, M}
-    end,
-    {stop, normal, {error, {unsupported_auth_method, Method}}, Sock};
-
 %% --- Initialization ---
 
 %% BackendKeyData
@@ -215,7 +221,8 @@ handle_message(?READY_FOR_QUERY, _, Sock, _State) ->
 
 
 %% ErrorResponse
-handle_message(?ERROR, Err, Sock, #connect{stage = auth} = _State) ->
+handle_message(?ERROR, Err, Sock, #connect{stage = Stage} = _State) when Stage == auth;
+                                                                         Stage == maybe_auth ->
     Why = case Err#error.code of
         <<"28000">> -> invalid_authorization_specification;
         <<"28P01">> -> invalid_password;

+ 26 - 5
src/epgsql_scram.erl

@@ -11,12 +11,20 @@
 -export([get_nonce/1,
          get_client_first/2,
          get_client_final/4,
-         parse_server_first/2]).
+         parse_server_first/2,
+         parse_server_final/1]).
 -export([hi/3,
          hmac/2,
          h/1,
          bin_xor/2]).
 
+-type nonce() :: binary().
+-type server_first() :: [{nonce, nonce()} |
+                         {salt, binary()} |
+                         {i, pos_integer()} |
+                         {raw, binary()}].
+
+-spec get_client_first(iodata(), nonce()) -> iodata().
 get_client_first(UserName, Nonce) ->
     %% Username is ignored by postgresql
     [<<"n,,">> | client_first_bare(UserName, Nonce)].
@@ -24,10 +32,12 @@ get_client_first(UserName, Nonce) ->
 client_first_bare(UserName, Nonce) ->
     [<<"n=">>, UserName, <<",r=">>, Nonce].
 
+-spec get_nonce(pos_integer()) -> nonce().
 get_nonce(Len) ->
     Nonce = crypto:strong_rand_bytes(Len),
     base64:encode(Nonce).
 
+-spec parse_server_first(binary(), nonce()) -> server_first().
 parse_server_first(ServerFirst, ClientNonce) ->
     PartsB = binary:split(ServerFirst, <<",">>, [global]),
     (length(PartsB) == 3) orelse error({invalid_server_first, ServerFirst}),
@@ -49,6 +59,8 @@ parse_server_first(ServerFirst, ClientNonce) ->
 %% AuthMessage     := client-first-message-bare + "," + server-first-message + "," + client-final-message-without-proof
 %% ClientSignature := HMAC(StoredKey, AuthMessage)
 %% ClientProof     := ClientKey XOR ClientSignature
+-spec get_client_final(server_first(), nonce(), iodata(), iodata()) ->
+                              {ClientFinal :: iodata(), ServerSignature :: binary()}.
 get_client_final(SrvFirst, ClientNonce, UserName, Password) ->
     ChannelBinding = <<"c=biws">>,                 %channel-binding isn't implemented
     Nonce = [<<"r=">>, proplists:get_value(nonce, SrvFirst)],
@@ -71,6 +83,13 @@ get_client_final(SrvFirst, ClientNonce, UserName, Password) ->
 
     {[ClientFinalWithoutProof, ",p=", base64:encode(ClientProof)], ServerSignature}.
 
+-spec parse_server_final(binary()) -> {ok, binary()} | {error, binary()}.
+parse_server_final(<<"v=", ServerFinal/binary>>) ->
+    [ServerFinal1 | _] = binary:split(ServerFinal, <<",">>),
+    {ok, base64:decode(ServerFinal1)};
+parse_server_final(<<"e=", ServerError/binary>>) ->
+    {error, ServerError}.
+
 %% Helpers
 
 %% TODO: implement
@@ -113,11 +132,13 @@ exchange_test() ->
 
     ClientFirst = <<"n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB">>,
     ServerFirst = <<"r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i=4096">>,
-    ClientFinal = <<"c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS31NTlQYNs5BTeQjdHdk7lOflDo5re2an8=">>,
-    _ServerFinal = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=",
+    ClientFinal = <<"c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS31NTlQ"
+                    "YNs5BTeQjdHdk7lOflDo5re2an8=">>,
+    ServerFinal = <<"v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=">>,
 
     ?assertEqual(ClientFirst, iolist_to_binary(get_client_first(Username, Nonce))),
     SF = parse_server_first(ServerFirst, Nonce),
-    {CF, _} = get_client_final(SF, Nonce, Username, Password),
-    ?assertEqual(ClientFinal, iolist_to_binary(CF), CF).
+    {CF, ServerProof} = get_client_final(SF, Nonce, Username, Password),
+    ?assertEqual(ClientFinal, iolist_to_binary(CF)),
+    ?assertEqual({ok, ServerProof}, parse_server_final(ServerFinal)).
 -endif.

+ 3 - 3
test/data/test_schema.sql

@@ -9,12 +9,12 @@
 
 CREATE USER epgsql_test;
 CREATE USER epgsql_test_md5 WITH PASSWORD 'epgsql_test_md5';
-SET password_encryption TO 'scram-sha-256';
-CREATE USER epgsql_test_scram WITH PASSWORD 'epgsql_test_scram';
-SET password_encryption TO 'md5';
 CREATE USER epgsql_test_cleartext WITH PASSWORD 'epgsql_test_cleartext';
 CREATE USER epgsql_test_cert;
 CREATE USER epgsql_test_replication WITH REPLICATION PASSWORD 'epgsql_test_replication';
+SET password_encryption TO 'scram-sha-256';
+CREATE USER epgsql_test_scram WITH PASSWORD 'epgsql_test_scram';
+SET password_encryption TO 'md5';
 
 CREATE DATABASE epgsql_test_db1 WITH ENCODING 'UTF8';
 CREATE DATABASE epgsql_test_db2 WITH ENCODING 'UTF8';

+ 18 - 12
test/epgsql_SUITE.erl

@@ -197,11 +197,17 @@ connect_with_md5(Config) ->
     ]).
 
 connect_with_scram(Config) ->
-    epgsql_ct:connect_only(Config, [
-        "epgsql_test_scram",
-        "epgsql_test_scram",
-        [{database, "epgsql_test_db1"}]
-    ]).
+    PgConf = ?config(pg_config, Config),
+    Ver = ?config(version, PgConf),
+    (Ver >= [10, 0])
+        andalso
+        epgsql_ct:connect_only(
+          Config,
+          [
+           "epgsql_test_scram",
+           "epgsql_test_scram",
+           [{database, "epgsql_test_db1"}]
+          ]).
 
 connect_with_invalid_user(Config) ->
     {Host, Port} = epgsql_ct:connection_data(Config),
@@ -868,8 +874,8 @@ query_timeout(Config) ->
     Module = ?config(module, Config),
     epgsql_ct:with_connection(Config, fun(C) ->
         {ok, _, _} = Module:squery(C, "SET statement_timeout = 500"),
-        ?TIMEOUT_ERROR = Module:squery(C, "SELECT pg_sleep(1)"),
-        ?TIMEOUT_ERROR = Module:equery(C, "SELECT pg_sleep(2)"),
+        ?assertMatch(?TIMEOUT_ERROR, Module:squery(C, "SELECT pg_sleep(1)")),
+        ?assertMatch(?TIMEOUT_ERROR, Module:equery(C, "SELECT pg_sleep(2)")),
         {ok, _Cols, [{1}]} = Module:equery(C, "SELECT 1")
     end, []).
 
@@ -879,7 +885,7 @@ execute_timeout(Config) ->
         {ok, _, _} = Module:squery(C, "SET statement_timeout = 500"),
         {ok, S} = Module:parse(C, "select pg_sleep($1)"),
         ok = Module:bind(C, S, [2]),
-        ?TIMEOUT_ERROR = Module:execute(C, S, 0),
+        ?assertMatch(?TIMEOUT_ERROR, Module:execute(C, S, 0)),
         ok = Module:sync(C),
         ok = Module:bind(C, S, [0]),
         {ok, [{<<>>}]} = Module:execute(C, S, 0),
@@ -990,7 +996,7 @@ listen_notify(Config) ->
 
 listen_notify_payload(Config) ->
     Module = ?config(module, Config),
-    epgsql_ct:with_min_version(Config, 9.0, fun(C) ->
+    epgsql_ct:with_min_version(Config, [9, 0], fun(C) ->
         {ok, [], []}     = Module:squery(C, "listen epgsql_test"),
         {ok, _, [{Pid}]} = Module:equery(C, "select pg_backend_pid()"),
         {ok, [], []}     = Module:squery(C, "notify epgsql_test, 'test!'"),
@@ -1003,7 +1009,7 @@ listen_notify_payload(Config) ->
 
 set_notice_receiver(Config) ->
     Module = ?config(module, Config),
-    epgsql_ct:with_min_version(Config, 9.0, fun(C) ->
+    epgsql_ct:with_min_version(Config, [9, 0], fun(C) ->
         {ok, [], []}     = Module:squery(C, "listen epgsql_test"),
         {ok, _, [{Pid}]} = Module:equery(C, "select pg_backend_pid()"),
 
@@ -1081,7 +1087,7 @@ get_cmd_status(Config) ->
     end).
 
 range_type(Config) ->
-    epgsql_ct:with_min_version(Config, 9.2, fun(_C) ->
+    epgsql_ct:with_min_version(Config, [9, 2], fun(_C) ->
         check_type(Config, int4range, "int4range(10, 20)", {10, 20}, [
             {1, 58}, {-1, 12}, {-985521, 5412687}, {minus_infinity, 0},
             {984655, plus_infinity}, {minus_infinity, plus_infinity}
@@ -1089,7 +1095,7 @@ range_type(Config) ->
    end, []).
 
 range8_type(Config) ->
-    epgsql_ct:with_min_version(Config, 9.2, fun(_C) ->
+    epgsql_ct:with_min_version(Config, [9, 2], fun(_C) ->
         check_type(Config, int8range, "int8range(10, 20)", {10, 20}, [
             {1, 58}, {-1, 12}, {-9223372036854775808, 5412687},
             {minus_infinity, 9223372036854775807},

+ 10 - 9
test/epgsql_ct.erl

@@ -64,15 +64,16 @@ with_rollback(Config, F) ->
       end).
 
 with_min_version(Config, Min, F, Args) ->
-    Module = ?config(module, Config),
-    epgsql_ct:with_connection(Config, fun(C) ->
-        {ok, Bin} = Module:get_parameter(C, <<"server_version">>),
-        {ok, [{float, 1, Ver} | _], _} = erl_scan:string(binary_to_list(Bin)),
-        case Ver >= Min of
-            true  -> F(C);
-            false -> ?debugFmt("skipping test requiring PostgreSQL >= ~.2f~n", [Min])
-        end
-    end, Args).
+    PgConf = ?config(pg_config, Config),
+    Ver = ?config(version, PgConf),
+
+    case Ver >= Min of
+        true ->
+            epgsql_ct:with_connection(Config, F, Args);
+        false ->
+            ?debugFmt("skipping test requiring PostgreSQL >= ~p, but we have ~p ~p",
+                      [Min, Ver, Config])
+    end.
 
 %% flush mailbox
 flush() ->

+ 23 - 3
test/epgsql_cth.erl

@@ -51,6 +51,7 @@ start_postgres() ->
     ok = application:start(erlexec),
     pipe([
         fun find_utils/1,
+        fun get_version/1,
         fun init_database/1,
         fun write_postgresql_config/1,
         fun copy_certs/1,
@@ -102,7 +103,8 @@ start_postgresql(Config) ->
             [{stderr,
               fun(_, _, Msg) ->
                   ct:pal(info, "postgres: ~s", [Msg])
-              end}]),
+              end},
+             {env, [{"LANGUAGE", "en"}]}]),
         loop(I)
     end),
     ConfigR = [
@@ -152,6 +154,17 @@ init_database(Config) ->
     {ok, _} = exec:run(Initdb ++ " --locale en_US.UTF8 " ++ PgDataDir, [sync,stdout,stderr]),
     [{datadir, PgDataDir}|Config].
 
+get_version(Config) ->
+    %% XXX: maybe use datadir/PG_VERSION after initdb?
+    Utils = ?config(utils, Config),
+    Postgres = ?config(postgres, Utils),
+
+    VersionStdout = list_to_binary(string:strip(os:cmd(Postgres ++ " -V"), both, $\n)),
+    VersionBin = lists:last(binary:split(VersionStdout, <<" ">>, [global])),
+    Version = lists:map(fun erlang:binary_to_integer/1,
+                        binary:split(VersionBin, <<".">>, [global])),
+    [{version, Version} | Config].
+
 write_postgresql_config(Config) ->
     PgDataDir = ?config(datadir, Config),
 
@@ -159,6 +172,7 @@ write_postgresql_config(Config) ->
         "ssl = on\n",
         "ssl_ca_file = 'root.crt'\n",
         "lc_messages = 'en_US.UTF-8'\n",
+        "fsync = off\n",
         "wal_level = 'logical'\n",
         "max_replication_slots = 15\n",
         "max_wal_senders = 15"
@@ -186,6 +200,7 @@ copy_certs(Config) ->
 
 write_pg_hba_config(Config) ->
     PgDataDir = ?config(datadir, Config),
+    Version = ?config(version, Config),
 
     User = os:getenv("USER"),
     PGConfig = [
@@ -196,10 +211,15 @@ write_pg_hba_config(Config) ->
         "host    epgsql_test_db1 ", User, "              127.0.0.1/32    trust\n",
         "host    epgsql_test_db1 epgsql_test             127.0.0.1/32    trust\n",
         "host    epgsql_test_db1 epgsql_test_md5         127.0.0.1/32    md5\n",
-        "host    epgsql_test_db1 epgsql_test_scram       127.0.0.1/32    scram-sha-256\n",
         "host    epgsql_test_db1 epgsql_test_cleartext   127.0.0.1/32    password\n",
         "hostssl epgsql_test_db1 epgsql_test_cert        127.0.0.1/32    cert clientcert=1\n",
-        "host    replication     epgsql_test_replication 127.0.0.1/32    trust"
+        "host    replication     epgsql_test_replication 127.0.0.1/32    trust\n" |
+        case Version >= [10, 0] of
+            true ->
+                "host    epgsql_test_db1 epgsql_test_scram       127.0.0.1/32    scram-sha-256\n";
+            false ->
+                []
+        end
     ],
     FilePath = filename:join(PgDataDir, "pg_hba.conf"),
     ok = file:write_file(FilePath, PGConfig),