Просмотр исходного кода

Fixes handshake for servers without CLIENT_PLUGIN_AUTH

Viktor Söderqvist 10 лет назад
Родитель
Сommit
6987efa974
1 измененных файлов с 18 добавлено и 23 удалено
  1. 18 23
      src/mysql_protocol.erl

+ 18 - 23
src/mysql_protocol.erl

@@ -199,8 +199,8 @@ parse_handshake(<<10, Rest/binary>>) ->
       Rest3/binary>> = Rest1,
     Capabilities = CapabilitiesLower + 16#10000 * CapabilitiesUpper,
     Len = case AuthPluginDataLength of
-        0 -> 13;    %% if not CLIENT_PLUGIN_AUTH
-        K -> K - 8
+        0 -> 13;   %% Server has not CLIENT_PLUGIN_AUTH
+        K -> K - 8 %% Part 2 length = Total length minus the 8 bytes in part 1.
     end,
     <<AuthPluginDataPart2:Len/binary-unit:8, AuthPluginName/binary>> = Rest3,
     AuthPluginData = <<AuthPluginDataPart1/binary, AuthPluginDataPart2/binary>>,
@@ -232,9 +232,15 @@ build_handshake_response(Handshake, Username, Password) ->
                       ?CLIENT_SECURE_CONNECTION,
     Handshake#handshake.capabilities band CapabilityFlags == CapabilityFlags
         orelse error(old_server_version),
-    Hash = hash_password(Password,
-                         Handshake#handshake.auth_plugin_name,
-                         Handshake#handshake.auth_plugin_data),
+    Hash = case Handshake#handshake.auth_plugin_name of
+        <<>> ->
+            %% Server doesn't know auth plugins
+            hash_password(Password, Handshake#handshake.auth_plugin_data);
+        <<"mysql_native_password">> ->
+            hash_password(Password, Handshake#handshake.auth_plugin_data);
+        UnknownAuthMethod ->
+            error({auth_method, UnknownAuthMethod})
+    end,
     HashLength = size(Hash),
     CharacterSet = ?UTF8,
     UsernameUtf8 = unicode:characters_to_binary(Username),
@@ -813,11 +819,8 @@ parse_eof_packet(<<?EOF:8, NumWarnings:16/little, StatusFlags:16/little>>) ->
     %% (Older protocol: <<?EOF:8>>)
     #eof{status = StatusFlags, warning_count = NumWarnings}.
 
--spec hash_password(Password :: iodata(), AuthPluginName :: binary(),
-                    AuthPluginData :: binary()) -> binary().
-hash_password(_Password, <<"mysql_old_password">>, _Salt) ->
-    error(old_auth);
-hash_password(Password, <<"mysql_native_password">>, AuthData) ->
+-spec hash_password(Password :: iodata(), Salt :: binary()) -> Hash :: binary().
+hash_password(Password, Salt) ->
     %% From the "MySQL Internals" manual:
     %% SHA1( password ) XOR SHA1( "20-bytes random data from server" <concat>
     %%                            SHA1( SHA1( password ) ) )
@@ -826,17 +829,15 @@ hash_password(Password, <<"mysql_native_password">>, AuthData) ->
     %%
     %% The auth data is obviously nul-terminated. For the "native" auth
     %% method, it should be a 20 byte salt, so let's trim it in this case.
-    Salt = case AuthData of
+    Salt1 = case Salt of
         <<SaltNoNul:20/binary-unit:8, 0>> -> SaltNoNul;
-        _ when size(AuthData) == 20       -> AuthData
+        _ when size(Salt) == 20           -> Salt
     end,
     %% Hash as described above.
     <<Hash1Num:160>> = Hash1 = crypto:hash(sha, Password),
     Hash2 = crypto:hash(sha, Hash1),
-    <<Hash3Num:160>> = crypto:hash(sha, <<Salt/binary, Hash2/binary>>),
-    <<(Hash1Num bxor Hash3Num):160>>;
-hash_password(_, AuthPlugin, _) ->
-    error({auth_method, AuthPlugin}).
+    <<Hash3Num:160>> = crypto:hash(sha, <<Salt1/binary, Hash2/binary>>),
+    <<(Hash1Num bxor Hash3Num):160>>.
 
 %% --- Lowlevel: variable length integers and strings ---
 
@@ -1005,12 +1006,6 @@ parse_eof_test() ->
 hash_password_test() ->
     ?assertEqual(<<222,207,222,139,41,181,202,13,191,241,
                    234,234,73,127,244,101,205,3,28,251>>,
-                 hash_password(<<"foo">>, <<"mysql_native_password">>,
-                               <<"abcdefghijklmnopqrst">>)),
-    ?assertError(old_auth,
-                 hash_password(<<"foo">>, <<"mysql_old_password">>, <<"abc">>)),
-    ?assertError({auth_method, <<"dummy">>},
-                 hash_password(<<"foo">>, <<"dummy">>, <<"dummy_salt">>)),
-    ok.
+                 hash_password(<<"foo">>, <<"abcdefghijklmnopqrst">>)).
 
 -endif.