Browse Source

Fix nonce generation code; minor fixes

* Nonce should be unique
* Make normalize/1 stricter
* add handling of 'unknown' from auth functions
Сергей Прохоров 7 years ago
parent
commit
86fe6854a5
2 changed files with 47 additions and 9 deletions
  1. 15 5
      src/commands/epgsql_cmd_connect.erl
  2. 32 4
      src/epgsql_scram.erl

+ 15 - 5
src/commands/epgsql_cmd_connect.erl

@@ -20,7 +20,11 @@
 -include("epgsql.hrl").
 -include("epgsql.hrl").
 -include("protocol.hrl").
 -include("protocol.hrl").
 
 
--type auth_fun() :: fun((init | binary(), _, _) -> {send, byte(), iodata(), any()} | ok | {error, any()}).
+-type auth_fun() :: fun((init | binary(), _, _) ->
+                                     {send, byte(), iodata(), any()}
+                                   | ok
+                                   | {error, any()}
+                                   | unknown).
 
 
 -record(connect,
 -record(connect,
         {opts :: list(),
         {opts :: list(),
@@ -157,7 +161,8 @@ auth_handle(Data, PgSock, #connect{auth_fun = Fun, auth_state = AuthSt} = St) ->
                                          auth_send = {SendPacketId, SendData}}};
                                          auth_send = {SendPacketId, SendData}}};
         ok -> {noaction, PgSock, St};
         ok -> {noaction, PgSock, St};
         {error, Reason} ->
         {error, Reason} ->
-            {stop, normal, {error, Reason}}
+            {stop, normal, {error, Reason}};
+        unknown -> unknown
     end.
     end.
 
 
 %% AuthenticationCleartextPassword
 %% AuthenticationCleartextPassword
@@ -178,7 +183,7 @@ auth_md5(_, _, _) -> unknown.
 %% AuthenticationSASL
 %% AuthenticationSASL
 auth_scram(init, undefined, #connect{opts = Opts}) ->
 auth_scram(init, undefined, #connect{opts = Opts}) ->
     User = get_val(username, Opts),
     User = get_val(username, Opts),
-    Nonce = epgsql_scram:get_nonce(10),
+    Nonce = epgsql_scram:get_nonce(16),
     ClientFirst = epgsql_scram:get_client_first(User, Nonce),
     ClientFirst = epgsql_scram:get_client_first(User, Nonce),
     SaslInitialResponse = [?SCRAM_AUTH_METHOD, 0, <<(iolist_size(ClientFirst)):?int32>>, ClientFirst],
     SaslInitialResponse = [?SCRAM_AUTH_METHOD, 0, <<(iolist_size(ClientFirst)):?int32>>, ClientFirst],
     {send, ?SASL_ANY_RESPONSE, SaslInitialResponse, {auth_request, Nonce}};
     {send, ?SASL_ANY_RESPONSE, SaslInitialResponse, {auth_request, Nonce}};
@@ -192,14 +197,19 @@ auth_scram(<<?AUTH_SASL_FINAL:?int32, ServerFinalMsg/binary>>, {server_final, Se
     case epgsql_scram:parse_server_final(ServerFinalMsg) of
     case epgsql_scram:parse_server_final(ServerFinalMsg) of
         {ok, ServerProof} -> ok;
         {ok, ServerProof} -> ok;
         Other -> {error, {sasl_server_final, Other}}
         Other -> {error, {sasl_server_final, Other}}
-    end.
+    end;
+auth_scram(_, _, _) ->
+    unknown.
 
 
 
 
 %% --- Auth ---
 %% --- Auth ---
 
 
 %% AuthenticationOk
 %% AuthenticationOk
 handle_message(?AUTHENTICATION_REQUEST, <<?AUTH_OK:?int32>>, Sock, State) ->
 handle_message(?AUTHENTICATION_REQUEST, <<?AUTH_OK:?int32>>, Sock, State) ->
-    {noaction, Sock, State#connect{stage = initialization}};
+    {noaction, Sock, State#connect{stage = initialization,
+                                   auth_fun = undefined,
+                                   auth_state = undefned,
+                                   auth_send = undefined}};
 
 
 handle_message(?AUTHENTICATION_REQUEST, Message, Sock, #connect{stage = Stage} = St) when Stage =/= auth ->
 handle_message(?AUTHENTICATION_REQUEST, Message, Sock, #connect{stage = Stage} = St) when Stage =/= auth ->
     auth_init(Message, Sock, St);
     auth_init(Message, Sock, St);

+ 32 - 4
src/epgsql_scram.erl

@@ -1,3 +1,4 @@
+%%% coding: utf-8
 %%% @doc
 %%% @doc
 %%% SCRAM--SHA-256 helper functions
 %%% SCRAM--SHA-256 helper functions
 %%% See
 %%% See
@@ -32,10 +33,15 @@ get_client_first(UserName, Nonce) ->
 client_first_bare(UserName, Nonce) ->
 client_first_bare(UserName, Nonce) ->
     [<<"n=">>, UserName, <<",r=">>, Nonce].
     [<<"n=">>, UserName, <<",r=">>, Nonce].
 
 
+%% @doc Generate unique ASCII string.
+%% Resulting string length isn't guaranteed, but it's guaranteed to be unique and will
+%% contain `NumRandomBytes' of random data.
 -spec get_nonce(pos_integer()) -> nonce().
 -spec get_nonce(pos_integer()) -> nonce().
-get_nonce(Len) ->
-    Nonce = crypto:strong_rand_bytes(Len),
-    base64:encode(Nonce).
+get_nonce(NumRandomBytes) when NumRandomBytes < 255 ->
+    Random = crypto:strong_rand_bytes(NumRandomBytes),
+    Unique = binary:encode_unsigned(unique()),
+    NonceBin = <<NumRandomBytes, Random:NumRandomBytes/binary, Unique/binary>>,
+    base64:encode(NonceBin).
 
 
 -spec parse_server_first(binary(), nonce()) -> server_first().
 -spec parse_server_first(binary(), nonce()) -> server_first().
 parse_server_first(ServerFirst, ClientNonce) ->
 parse_server_first(ServerFirst, ClientNonce) ->
@@ -92,10 +98,15 @@ parse_server_final(<<"e=", ServerError/binary>>) ->
 
 
 %% Helpers
 %% Helpers
 
 
-%% TODO: implement
+%% TODO: implement according to rfc3454
 normalize(Str) ->
 normalize(Str) ->
+    lists:all(fun is_ascii_non_control/1, unicode:characters_to_list(Str, utf8))
+        orelse error({scram_non_ascii_password, Str}),
     Str.
     Str.
 
 
+is_ascii_non_control(C) when C > 16#1F, C < 16#7F -> true;
+is_ascii_non_control(_) -> false.
+
 check_nonce(ClientNonce, ServerNonce) ->
 check_nonce(ClientNonce, ServerNonce) ->
     Size = size(ClientNonce),
     Size = size(ClientNonce),
     <<ClientNonce:Size/binary, _/binary>> = ServerNonce,
     <<ClientNonce:Size/binary, _/binary>> = ServerNonce,
@@ -122,6 +133,18 @@ h(Str) ->
 bin_xor(B1, B2) ->
 bin_xor(B1, B2) ->
     crypto:exor(B1, B2).
     crypto:exor(B1, B2).
 
 
+
+-ifdef(FAST_MAPS).
+unique() ->
+    erlang:unique_integer([positive]).
+-else.
+unique() ->
+    %% POSIX timestamp microseconds
+    {Mega, Secs, Micro} = erlang:now(),
+    (Mega * 1000000 + Secs) * 1000000 + Micro.
+-endif.
+
+
 -ifdef(TEST).
 -ifdef(TEST).
 -include_lib("eunit/include/eunit.hrl").
 -include_lib("eunit/include/eunit.hrl").
 
 
@@ -141,4 +164,9 @@ exchange_test() ->
     {CF, ServerProof} = get_client_final(SF, Nonce, Username, Password),
     {CF, ServerProof} = get_client_final(SF, Nonce, Username, Password),
     ?assertEqual(ClientFinal, iolist_to_binary(CF)),
     ?assertEqual(ClientFinal, iolist_to_binary(CF)),
     ?assertEqual({ok, ServerProof}, parse_server_final(ServerFinal)).
     ?assertEqual({ok, ServerProof}, parse_server_final(ServerFinal)).
+
+normalize_test() ->
+    ?assertEqual(<<"123 !~">>, normalize(<<"123 !~">>)),
+    ?assertError({scram_non_ascii_password, _}, normalize(<<"привет"/utf8>>)).
+
 -endif.
 -endif.