|
@@ -46,11 +46,15 @@
|
|
|
%% @doc Performs a handshake using the supplied functions for communication.
|
|
|
%% Returns an ok or an error record. Raises errors when various unimplemented
|
|
|
%% features are requested.
|
|
|
--spec handshake(Username :: iodata(), Password :: iodata(), Database :: iodata() | undefined,
|
|
|
- SockModule :: module(), SSLOpts :: list() | undefined, Socket :: term(),
|
|
|
+-spec handshake(Username :: iodata(), Password :: iodata(),
|
|
|
+ Database :: iodata() | undefined,
|
|
|
+ SockModule :: module(), SSLOpts :: list() | undefined,
|
|
|
+ Socket :: term(),
|
|
|
SetFoundRows :: boolean()) ->
|
|
|
- {ok, #handshake{}, SockModule :: module(), Socket :: term()} | #error{}.
|
|
|
-handshake(Username, Password, Database, SockModule0, SSLOpts, Socket0, SetFoundRows) ->
|
|
|
+ {ok, #handshake{}, SockModule :: module(), Socket :: term()} |
|
|
|
+ #error{}.
|
|
|
+handshake(Username, Password, Database, SockModule0, SSLOpts, Socket0,
|
|
|
+ SetFoundRows) ->
|
|
|
SeqNum0 = 0,
|
|
|
{ok, HandshakePacket, SeqNum1} = recv_packet(SockModule0, Socket0, SeqNum0),
|
|
|
Handshake = parse_handshake(HandshakePacket),
|
|
@@ -60,15 +64,18 @@ handshake(Username, Password, Database, SockModule0, SSLOpts, Socket0, SetFoundR
|
|
|
Response = build_handshake_response(Handshake, Username, Password,
|
|
|
Database, SetFoundRows),
|
|
|
{ok, SeqNum3} = send_packet(SockModule, Socket, Response, SeqNum2),
|
|
|
- handshake_finish_or_switch_auth(Handshake, Password, SockModule, Socket, SeqNum3).
|
|
|
+ handshake_finish_or_switch_auth(Handshake, Password, SockModule, Socket,
|
|
|
+ SeqNum3).
|
|
|
|
|
|
-handshake_finish_or_switch_auth(Handshake, Password, SockModule, Socket, SeqNum0) ->
|
|
|
+handshake_finish_or_switch_auth(Handshake, Password, SockModule, Socket,
|
|
|
+ SeqNum0) ->
|
|
|
{ok, ConfirmPacket, SeqNum1} = recv_packet(SockModule, Socket, SeqNum0),
|
|
|
case parse_handshake_confirm(ConfirmPacket) of
|
|
|
#ok{status = OkStatus} ->
|
|
|
OkStatus = Handshake#handshake.status,
|
|
|
{ok, Handshake, SockModule, Socket};
|
|
|
- #auth_method_switch{auth_plugin_name = AuthPluginName, auth_plugin_data = AuthPluginData} ->
|
|
|
+ #auth_method_switch{auth_plugin_name = AuthPluginName,
|
|
|
+ auth_plugin_data = AuthPluginData} ->
|
|
|
Hash = case AuthPluginName of
|
|
|
<<>> ->
|
|
|
hash_password(Password, AuthPluginData);
|
|
@@ -78,7 +85,8 @@ handshake_finish_or_switch_auth(Handshake, Password, SockModule, Socket, SeqNum0
|
|
|
error({auth_method, UnknownAuthMethod})
|
|
|
end,
|
|
|
{ok, SeqNum2} = send_packet(SockModule, Socket, Hash, SeqNum1),
|
|
|
- handshake_finish_or_switch_auth(Handshake, Password, SockModule, Socket, SeqNum2);
|
|
|
+ handshake_finish_or_switch_auth(Handshake, Password, SockModule,
|
|
|
+ Socket, SeqNum2);
|
|
|
Error ->
|
|
|
Error
|
|
|
end.
|
|
@@ -250,11 +258,13 @@ server_version_to_list(ServerVersion) ->
|
|
|
SSLOpts :: undefined | list(),
|
|
|
Database :: iodata() | undefined,
|
|
|
SetFoundRows :: boolean()) ->
|
|
|
- {ok, SockModule :: module(), Socket :: term(), SeqNum2 :: non_neg_integer()}.
|
|
|
+ {ok, SockModule :: module(), Socket :: term(),
|
|
|
+ SeqNum2 :: non_neg_integer()}.
|
|
|
maybe_do_ssl_upgrade(SockModule0, Socket0, SeqNum1, _Handshake, undefined,
|
|
|
_Database, _SetFoundRows) ->
|
|
|
{ok, SockModule0, Socket0, SeqNum1};
|
|
|
-maybe_do_ssl_upgrade(SockModule0, Socket0, SeqNum1, Handshake, SSLOpts, Database, SetFoundRows) ->
|
|
|
+maybe_do_ssl_upgrade(SockModule0, Socket0, SeqNum1, Handshake, SSLOpts,
|
|
|
+ Database, SetFoundRows) ->
|
|
|
Response = build_handshake_response(Handshake, Database, SetFoundRows),
|
|
|
{ok, SeqNum2} = send_packet(SockModule0, Socket0, Response, SeqNum1),
|
|
|
case mysql_sock_ssl:connect(Socket0, SSLOpts, 5000) of
|
|
@@ -264,7 +274,8 @@ maybe_do_ssl_upgrade(SockModule0, Socket0, SeqNum1, Handshake, SSLOpts, Database
|
|
|
exit({failed_to_upgrade_socket, Reason})
|
|
|
end.
|
|
|
|
|
|
--spec build_handshake_response(#handshake{}, iodata() | undefined, boolean()) -> binary().
|
|
|
+-spec build_handshake_response(#handshake{}, iodata() | undefined, boolean()) ->
|
|
|
+ binary().
|
|
|
build_handshake_response(Handshake, Database, SetFoundRows) ->
|
|
|
CapabilityFlags = basic_capabilities(Database /= undefined, SetFoundRows),
|
|
|
verify_server_capabilities(Handshake, CapabilityFlags),
|
|
@@ -280,7 +291,8 @@ build_handshake_response(Handshake, Database, SetFoundRows) ->
|
|
|
%% initial handshake from the server
|
|
|
-spec build_handshake_response(#handshake{}, iodata(), iodata(),
|
|
|
iodata() | undefined, boolean()) -> binary().
|
|
|
-build_handshake_response(Handshake, Username, Password, Database, SetFoundRows) ->
|
|
|
+build_handshake_response(Handshake, Username, Password, Database,
|
|
|
+ SetFoundRows) ->
|
|
|
CapabilityFlags = basic_capabilities(Database /= undefined, SetFoundRows),
|
|
|
verify_server_capabilities(Handshake, CapabilityFlags),
|
|
|
%% Add some extra capability flags only for signalling to the server what
|
|
@@ -313,13 +325,16 @@ build_handshake_response(Handshake, Username, Password, Database, SetFoundRows)
|
|
|
Hash/binary,
|
|
|
DbBin/binary>>.
|
|
|
|
|
|
--spec verify_server_capabilities(Handshake :: #handshake{}, CapabilityFlags :: integer()) -> true | no_return().
|
|
|
+-spec verify_server_capabilities(Handshake :: #handshake{},
|
|
|
+ CapabilityFlags :: integer()) ->
|
|
|
+ true | no_return().
|
|
|
verify_server_capabilities(Handshake, CapabilityFlags) ->
|
|
|
%% We require these capabilities. Make sure the server handles them.
|
|
|
Handshake#handshake.capabilities band CapabilityFlags == CapabilityFlags
|
|
|
orelse error(old_server_version).
|
|
|
|
|
|
--spec basic_capabilities(ConnectWithDB :: boolean(), SetFoundRows :: boolean()) -> integer().
|
|
|
+-spec basic_capabilities(ConnectWithDB :: boolean(),
|
|
|
+ SetFoundRows :: boolean()) -> integer().
|
|
|
basic_capabilities(ConnectWithDB, SetFoundRows) ->
|
|
|
CapabilityFlags0 = ?CLIENT_PROTOCOL_41 bor
|
|
|
?CLIENT_TRANSACTIONS bor
|
|
@@ -343,7 +358,8 @@ add_client_capabilities(Caps) ->
|
|
|
%% @doc Handles the second packet from the server, when we have replied to the
|
|
|
%% initial handshake. Returns an error if the server returns an error. Raises
|
|
|
%% an error if unimplemented features are required.
|
|
|
--spec parse_handshake_confirm(binary()) -> #ok{} | #error{}.
|
|
|
+-spec parse_handshake_confirm(binary()) -> #ok{} | #auth_method_switch{} |
|
|
|
+ #error{}.
|
|
|
parse_handshake_confirm(Packet) ->
|
|
|
case Packet of
|
|
|
?ok_pattern ->
|
|
@@ -410,8 +426,8 @@ fetch_response(SockModule, Socket, Timeout, Proto, Acc) ->
|
|
|
-spec fetch_resultset(atom(), term(), integer(), integer()) ->
|
|
|
#resultset{} | #error{}.
|
|
|
fetch_resultset(SockModule, Socket, FieldCount, SeqNum) ->
|
|
|
- {ok, ColDefs, SeqNum1} = fetch_column_definitions(SockModule, Socket, SeqNum,
|
|
|
- FieldCount, []),
|
|
|
+ {ok, ColDefs, SeqNum1} = fetch_column_definitions(SockModule, Socket,
|
|
|
+ SeqNum, FieldCount, []),
|
|
|
{ok, DelimiterPacket, SeqNum2} = recv_packet(SockModule, Socket, SeqNum1),
|
|
|
#eof{status = S, warning_count = W} = parse_eof_packet(DelimiterPacket),
|
|
|
case fetch_resultset_rows(SockModule, Socket, SeqNum2, []) of
|
|
@@ -423,11 +439,13 @@ fetch_resultset(SockModule, Socket, FieldCount, SeqNum) ->
|
|
|
E
|
|
|
end.
|
|
|
|
|
|
-parse_resultset(#resultset{cols = ColDefs, rows = Rows} = R, ColumnCount, text) ->
|
|
|
+parse_resultset(#resultset{cols = ColDefs, rows = Rows} = R, ColumnCount,
|
|
|
+ text) ->
|
|
|
%% Parse the rows according to the 'text protocol' representation.
|
|
|
Rows1 = [decode_text_row(ColumnCount, ColDefs, Row) || Row <- Rows],
|
|
|
R#resultset{rows = Rows1};
|
|
|
-parse_resultset(#resultset{cols = ColDefs, rows = Rows} = R, ColumnCount, binary) ->
|
|
|
+parse_resultset(#resultset{cols = ColDefs, rows = Rows} = R, ColumnCount,
|
|
|
+ binary) ->
|
|
|
%% Parse the rows according to the 'binary protocol' representation.
|
|
|
Rows1 = [decode_binary_row(ColumnCount, ColDefs, Row) || Row <- Rows],
|
|
|
R#resultset{rows = Rows1}.
|
|
@@ -1054,7 +1072,8 @@ parse_auth_method_switch(AMSData) ->
|
|
|
auth_plugin_data = AuthPluginData
|
|
|
}.
|
|
|
|
|
|
--spec get_null_terminated_binary(binary()) -> {Binary :: binary(), Rest :: binary()}.
|
|
|
+-spec get_null_terminated_binary(binary()) -> {Binary :: binary(),
|
|
|
+ Rest :: binary()}.
|
|
|
get_null_terminated_binary(In) ->
|
|
|
get_null_terminated_binary(In, <<>>).
|
|
|
|
|
@@ -1082,7 +1101,8 @@ hash_password(Password, Salt) ->
|
|
|
false -> hash_non_empty_password(Password, Salt)
|
|
|
end.
|
|
|
|
|
|
--spec hash_non_empty_password(Password :: iodata(), Salt :: binary()) -> Hash :: binary().
|
|
|
+-spec hash_non_empty_password(Password :: iodata(), Salt :: binary()) ->
|
|
|
+ Hash :: binary().
|
|
|
hash_non_empty_password(Password, Salt) ->
|
|
|
Salt1 = case Salt of
|
|
|
<<SaltNoNul:20/binary-unit:8, 0>> -> SaltNoNul;
|