Browse Source

swap send/2,3, on_data/1 params, handle auth replies

Anton Lebedevich 13 years ago
parent
commit
8bdd58380a
1 changed files with 65 additions and 8 deletions
  1. 65 8
      src/pgsql_sock.erl

+ 65 - 8
src/pgsql_sock.erl

@@ -58,13 +58,14 @@ handle_cast({connect, Host, Username, Password, Opts},
         undefined -> Opts3 = Opts2;
         Database  -> Opts3 = [Opts2 | ["database", 0, Database, 0]]
     end,
-    send([<<196608:?int32>>, Opts3, 0], State2),
+    send(State2, [<<196608:?int32>>, Opts3, 0]),
     %% TODO    Async   = proplists:get_value(async, Opts, undefined),
     setopts(State2, [{active, true}]),
     {noreply,
      State2#state{on_message = fun(M, S) ->
                                        auth(Username, Password, M, S)
-                               end},
+                               end,
+                  on_timeout = fun auth_timeout/1},
      Timeout};
 
 handle_cast(cancel, State = #state{backend = {Pid, Key}}) ->
@@ -88,9 +89,9 @@ handle_info(timeout, #state{on_timeout = OnTimeout} = State) ->
     OnTimeout(State);
 
 handle_info({_, Sock, Data2}, #state{data = Data, sock = Sock} = State) ->
-    on_data({infinity, State#state{data = <<Data/binary, Data2/binary>>}}).
+    on_data({State#state{data = <<Data/binary, Data2/binary>>}, infinity}).
 
-on_data({Timeout, #state{data = Data, on_message = OnMessage} = State}) ->
+on_data({#state{data = Data, on_message = OnMessage} = State, Timeout}) ->
     case pgsql_wire:decode_message(Data) of
         {Message, Tail} ->
             on_data(OnMessage(Message, State#state{data = Tail}));
@@ -128,15 +129,64 @@ setopts(#state{mod = Mod, sock = Sock}, Opts) ->
         ssl     -> ssl:setopts(Sock, Opts)
     end.
 
-send(Data, #state{mod = Mod, sock = Sock}) ->
+send(#state{mod = Mod, sock = Sock}, Data) ->
     Mod:send(Sock, pgsql_wire:encode(Data)).
 
-send(Type, Data, #state{mod = Mod, sock = Sock}) ->
+send(#state{mod = Mod, sock = Sock}, Type, Data) ->
     Mod:send(Sock, pgsql_wire:encode(Type, Data)).
 
+%% -- backend message handling --
+
 %% AuthenticationOk
-auth(User, Password, {$R, <<0:?int32>>}, State) ->
-    State#state{on_message = fun on_message/2}.
+auth(_Username, _Password, {$R, <<0:?int32>>}, State) ->
+    #state{timeout = Timeout} = State,
+    {State#state{on_message = fun initializing/2}, Timeout};
+
+%% AuthenticationCleartextPassword
+auth(_Username, Password, {$R, <<3:?int32>>}, State) ->
+    #state{timeout = Timeout} = State,
+    send(State, $p, [Password, 0]),
+    {State, Timeout};
+
+%% AuthenticationMD5Password
+auth(Username, Password, {$R, <<5:?int32, Salt:4/binary>>}, State) ->
+    #state{timeout = Timeout} = State,
+    Digest1 = hex(erlang:md5([Password, Username])),
+    Str = ["md5", hex(erlang:md5([Digest1, Salt])), 0],
+    send(State, $p, Str),
+    {State, Timeout};
+
+auth(_Username, _Password, {$R, <<M:?int32, _/binary>>}, State) ->
+    case M of
+        2 -> Method = kerberosV5;
+        4 -> Method = crypt;
+        6 -> Method = scm;
+        7 -> Method = gss;
+        8 -> Method = sspi;
+        _ -> Method = unknown
+    end,
+    Error = {error, {unsupported_auth_method, Method}},
+    %% TODO send error response
+    {stop, Error, State};
+
+%% ErrorResponse
+%% TODO who decodes error ?
+auth(_Username, _Password, {error, E}, State) ->
+    case E#error.code of
+        <<"28000">> -> Why = invalid_authorization_specification;
+        <<"28P01">> -> Why = invalid_password;
+        Any         -> Why = Any
+    end,
+    %% TODO send error response
+    {stop, {error, Why}, State}.
+
+auth_timeout(State) ->
+    %% TODO send error response
+    {stop, {error, timeout}, State}.
+
+
+initializing(_, State) ->
+    {infinity, State}.
 
 on_message({$N, Data}, State) ->
     %% TODO use it
@@ -165,3 +215,10 @@ on_message({$A, <<Pid:?int32, Strings/binary>>}, State) ->
 
 on_message(_Msg, State) ->
     {infinity, State}.
+
+
+hex(Bin) ->
+    HChar = fun(N) when N < 10 -> $0 + N;
+               (N) when N < 16 -> $W + N
+            end,
+    <<<<(HChar(H)), (HChar(L))>> || <<H:4, L:4>> <= Bin>>.