|
@@ -1,5 +1,6 @@
|
|
%% MySQL/OTP – MySQL client library for Erlang/OTP
|
|
%% MySQL/OTP – MySQL client library for Erlang/OTP
|
|
%% Copyright (C) 2014 Viktor Söderqvist
|
|
%% Copyright (C) 2014 Viktor Söderqvist
|
|
|
|
+%% 2017 Piotr Nosek, Michal Slaski
|
|
%%
|
|
%%
|
|
%% This file is part of MySQL/OTP.
|
|
%% This file is part of MySQL/OTP.
|
|
%%
|
|
%%
|
|
@@ -26,7 +27,7 @@
|
|
%% @private
|
|
%% @private
|
|
-module(mysql_protocol).
|
|
-module(mysql_protocol).
|
|
|
|
|
|
--export([handshake/6, quit/2, ping/2,
|
|
|
|
|
|
+-export([handshake/7, quit/2, ping/2,
|
|
query/4, fetch_query_response/3,
|
|
query/4, fetch_query_response/3,
|
|
prepare/3, unprepare/3, execute/5, fetch_execute_response/3]).
|
|
prepare/3, unprepare/3, execute/5, fetch_execute_response/3]).
|
|
|
|
|
|
@@ -45,23 +46,28 @@
|
|
%% @doc Performs a handshake using the supplied functions for communication.
|
|
%% @doc Performs a handshake using the supplied functions for communication.
|
|
%% Returns an ok or an error record. Raises errors when various unimplemented
|
|
%% Returns an ok or an error record. Raises errors when various unimplemented
|
|
%% features are requested.
|
|
%% features are requested.
|
|
--spec handshake(iodata(), iodata(), iodata() | undefined, atom(),
|
|
|
|
- term(), boolean()) -> #handshake{} | #error{}.
|
|
|
|
-handshake(Username, Password, Database, TcpModule, Socket, SetFoundRows) ->
|
|
|
|
|
|
+-spec handshake(Username :: iodata(), Password :: iodata(), Database :: iodata() | undefined,
|
|
|
|
+ SockModule :: module(), SSLOpts :: list() | undefined, Socket :: term(),
|
|
|
|
+ SetFoundRows :: boolean()) ->
|
|
|
|
+ {ok, #handshake{}, SockModule :: module(), Socket :: term()} | #error{}.
|
|
|
|
+handshake(Username, Password, Database, SockModule0, SSLOpts, Socket0, SetFoundRows) ->
|
|
SeqNum0 = 0,
|
|
SeqNum0 = 0,
|
|
- {ok, HandshakePacket, SeqNum1} = recv_packet(TcpModule, Socket, SeqNum0),
|
|
|
|
|
|
+ {ok, HandshakePacket, SeqNum1} = recv_packet(SockModule0, Socket0, SeqNum0),
|
|
Handshake = parse_handshake(HandshakePacket),
|
|
Handshake = parse_handshake(HandshakePacket),
|
|
|
|
+ {ok, SockModule, Socket, SeqNum2}
|
|
|
|
+ = maybe_do_ssl_upgrade(SockModule0, Socket0, SeqNum1, Handshake,
|
|
|
|
+ SSLOpts, Database, SetFoundRows),
|
|
Response = build_handshake_response(Handshake, Username, Password,
|
|
Response = build_handshake_response(Handshake, Username, Password,
|
|
Database, SetFoundRows),
|
|
Database, SetFoundRows),
|
|
- {ok, SeqNum2} = send_packet(TcpModule, Socket, Response, SeqNum1),
|
|
|
|
- handshake_finish_or_switch_auth(Handshake, Password, TcpModule, Socket, SeqNum2).
|
|
|
|
|
|
+ {ok, SeqNum3} = send_packet(SockModule, Socket, Response, SeqNum2),
|
|
|
|
+ handshake_finish_or_switch_auth(Handshake, Password, SockModule, Socket, SeqNum3).
|
|
|
|
|
|
-handshake_finish_or_switch_auth(Handshake, Password, TcpModule, Socket, SeqNum0) ->
|
|
|
|
- {ok, ConfirmPacket, SeqNum1} = recv_packet(TcpModule, Socket, SeqNum0),
|
|
|
|
|
|
+handshake_finish_or_switch_auth(Handshake, Password, SockModule, Socket, SeqNum0) ->
|
|
|
|
+ {ok, ConfirmPacket, SeqNum1} = recv_packet(SockModule, Socket, SeqNum0),
|
|
case parse_handshake_confirm(ConfirmPacket) of
|
|
case parse_handshake_confirm(ConfirmPacket) of
|
|
#ok{status = OkStatus} ->
|
|
#ok{status = OkStatus} ->
|
|
OkStatus = Handshake#handshake.status,
|
|
OkStatus = Handshake#handshake.status,
|
|
- Handshake;
|
|
|
|
|
|
+ {ok, Handshake, SockModule, Socket};
|
|
#auth_method_switch{auth_plugin_name = AuthPluginName, auth_plugin_data = AuthPluginData} ->
|
|
#auth_method_switch{auth_plugin_name = AuthPluginName, auth_plugin_data = AuthPluginData} ->
|
|
Hash = case AuthPluginName of
|
|
Hash = case AuthPluginName of
|
|
<<>> ->
|
|
<<>> ->
|
|
@@ -71,45 +77,45 @@ handshake_finish_or_switch_auth(Handshake, Password, TcpModule, Socket, SeqNum0)
|
|
UnknownAuthMethod ->
|
|
UnknownAuthMethod ->
|
|
error({auth_method, UnknownAuthMethod})
|
|
error({auth_method, UnknownAuthMethod})
|
|
end,
|
|
end,
|
|
- {ok, SeqNum2} = send_packet(TcpModule, Socket, Hash, SeqNum1),
|
|
|
|
- handshake_finish_or_switch_auth(Handshake, Password, TcpModule, Socket, SeqNum2);
|
|
|
|
|
|
+ {ok, SeqNum2} = send_packet(SockModule, Socket, Hash, SeqNum1),
|
|
|
|
+ handshake_finish_or_switch_auth(Handshake, Password, SockModule, Socket, SeqNum2);
|
|
Error ->
|
|
Error ->
|
|
Error
|
|
Error
|
|
end.
|
|
end.
|
|
|
|
|
|
-spec quit(atom(), term()) -> ok.
|
|
-spec quit(atom(), term()) -> ok.
|
|
-quit(TcpModule, Socket) ->
|
|
|
|
- {ok, SeqNum1} = send_packet(TcpModule, Socket, <<?COM_QUIT>>, 0),
|
|
|
|
- case recv_packet(TcpModule, Socket, SeqNum1) of
|
|
|
|
|
|
+quit(SockModule, Socket) ->
|
|
|
|
+ {ok, SeqNum1} = send_packet(SockModule, Socket, <<?COM_QUIT>>, 0),
|
|
|
|
+ case recv_packet(SockModule, Socket, SeqNum1) of
|
|
{error, closed} -> ok; %% MySQL 5.5.40 and more
|
|
{error, closed} -> ok; %% MySQL 5.5.40 and more
|
|
{ok, ?ok_pattern, _SeqNum2} -> ok %% Some older MySQL versions?
|
|
{ok, ?ok_pattern, _SeqNum2} -> ok %% Some older MySQL versions?
|
|
end.
|
|
end.
|
|
|
|
|
|
-spec ping(atom(), term()) -> #ok{}.
|
|
-spec ping(atom(), term()) -> #ok{}.
|
|
-ping(TcpModule, Socket) ->
|
|
|
|
- {ok, SeqNum1} = send_packet(TcpModule, Socket, <<?COM_PING>>, 0),
|
|
|
|
- {ok, OkPacket, _SeqNum2} = recv_packet(TcpModule, Socket, SeqNum1),
|
|
|
|
|
|
+ping(SockModule, Socket) ->
|
|
|
|
+ {ok, SeqNum1} = send_packet(SockModule, Socket, <<?COM_PING>>, 0),
|
|
|
|
+ {ok, OkPacket, _SeqNum2} = recv_packet(SockModule, Socket, SeqNum1),
|
|
parse_ok_packet(OkPacket).
|
|
parse_ok_packet(OkPacket).
|
|
|
|
|
|
-spec query(Query :: iodata(), atom(), term(), timeout()) ->
|
|
-spec query(Query :: iodata(), atom(), term(), timeout()) ->
|
|
{ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
|
|
{ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
|
|
-query(Query, TcpModule, Socket, Timeout) ->
|
|
|
|
|
|
+query(Query, SockModule, Socket, Timeout) ->
|
|
Req = <<?COM_QUERY, (iolist_to_binary(Query))/binary>>,
|
|
Req = <<?COM_QUERY, (iolist_to_binary(Query))/binary>>,
|
|
SeqNum0 = 0,
|
|
SeqNum0 = 0,
|
|
- {ok, _SeqNum1} = send_packet(TcpModule, Socket, Req, SeqNum0),
|
|
|
|
- fetch_query_response(TcpModule, Socket, Timeout).
|
|
|
|
|
|
+ {ok, _SeqNum1} = send_packet(SockModule, Socket, Req, SeqNum0),
|
|
|
|
+ fetch_query_response(SockModule, Socket, Timeout).
|
|
|
|
|
|
%% @doc This is used by query/4. If query/4 returns {error, timeout}, this
|
|
%% @doc This is used by query/4. If query/4 returns {error, timeout}, this
|
|
%% function can be called to retry to fetch the results of the query.
|
|
%% function can be called to retry to fetch the results of the query.
|
|
-fetch_query_response(TcpModule, Socket, Timeout) ->
|
|
|
|
- fetch_response(TcpModule, Socket, Timeout, text, []).
|
|
|
|
|
|
+fetch_query_response(SockModule, Socket, Timeout) ->
|
|
|
|
+ fetch_response(SockModule, Socket, Timeout, text, []).
|
|
|
|
|
|
%% @doc Prepares a statement.
|
|
%% @doc Prepares a statement.
|
|
-spec prepare(iodata(), atom(), term()) -> #error{} | #prepared{}.
|
|
-spec prepare(iodata(), atom(), term()) -> #error{} | #prepared{}.
|
|
-prepare(Query, TcpModule, Socket) ->
|
|
|
|
|
|
+prepare(Query, SockModule, Socket) ->
|
|
Req = <<?COM_STMT_PREPARE, (iolist_to_binary(Query))/binary>>,
|
|
Req = <<?COM_STMT_PREPARE, (iolist_to_binary(Query))/binary>>,
|
|
- {ok, SeqNum1} = send_packet(TcpModule, Socket, Req, 0),
|
|
|
|
- {ok, Resp, SeqNum2} = recv_packet(TcpModule, Socket, SeqNum1),
|
|
|
|
|
|
+ {ok, SeqNum1} = send_packet(SockModule, Socket, Req, 0),
|
|
|
|
+ {ok, Resp, SeqNum2} = recv_packet(SockModule, Socket, SeqNum1),
|
|
case Resp of
|
|
case Resp of
|
|
?error_pattern ->
|
|
?error_pattern ->
|
|
parse_error_packet(Resp);
|
|
parse_error_packet(Resp);
|
|
@@ -125,13 +131,13 @@ prepare(Query, TcpModule, Socket) ->
|
|
%% with charset 'binary' so we have to select a type ourselves for
|
|
%% with charset 'binary' so we have to select a type ourselves for
|
|
%% the parameters we have in execute/4.
|
|
%% the parameters we have in execute/4.
|
|
{_ParamDefs, SeqNum3} =
|
|
{_ParamDefs, SeqNum3} =
|
|
- fetch_column_definitions_if_any(NumParams, TcpModule, Socket,
|
|
|
|
|
|
+ fetch_column_definitions_if_any(NumParams, SockModule, Socket,
|
|
SeqNum2),
|
|
SeqNum2),
|
|
%% Column Definition Block. We get column definitions in execute
|
|
%% Column Definition Block. We get column definitions in execute
|
|
%% too, so we don't need them here. We *could* store them to be able
|
|
%% too, so we don't need them here. We *could* store them to be able
|
|
%% to provide the user with some info about a prepared statement.
|
|
%% to provide the user with some info about a prepared statement.
|
|
{_ColDefs, _SeqNum4} =
|
|
{_ColDefs, _SeqNum4} =
|
|
- fetch_column_definitions_if_any(NumColumns, TcpModule, Socket,
|
|
|
|
|
|
+ fetch_column_definitions_if_any(NumColumns, SockModule, Socket,
|
|
SeqNum3),
|
|
SeqNum3),
|
|
#prepared{statement_id = StmtId,
|
|
#prepared{statement_id = StmtId,
|
|
orig_query = Query,
|
|
orig_query = Query,
|
|
@@ -141,8 +147,8 @@ prepare(Query, TcpModule, Socket) ->
|
|
|
|
|
|
%% @doc Deallocates a prepared statement.
|
|
%% @doc Deallocates a prepared statement.
|
|
-spec unprepare(#prepared{}, atom(), term()) -> ok.
|
|
-spec unprepare(#prepared{}, atom(), term()) -> ok.
|
|
-unprepare(#prepared{statement_id = Id}, TcpModule, Socket) ->
|
|
|
|
- {ok, _SeqNum} = send_packet(TcpModule, Socket,
|
|
|
|
|
|
+unprepare(#prepared{statement_id = Id}, SockModule, Socket) ->
|
|
|
|
+ {ok, _SeqNum} = send_packet(SockModule, Socket,
|
|
<<?COM_STMT_CLOSE, Id:32/little>>, 0),
|
|
<<?COM_STMT_CLOSE, Id:32/little>>, 0),
|
|
ok.
|
|
ok.
|
|
|
|
|
|
@@ -150,7 +156,7 @@ unprepare(#prepared{statement_id = Id}, TcpModule, Socket) ->
|
|
-spec execute(#prepared{}, [term()], atom(), term(), timeout()) ->
|
|
-spec execute(#prepared{}, [term()], atom(), term(), timeout()) ->
|
|
{ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
|
|
{ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
|
|
execute(#prepared{statement_id = Id, param_count = ParamCount}, ParamValues,
|
|
execute(#prepared{statement_id = Id, param_count = ParamCount}, ParamValues,
|
|
- TcpModule, Socket, Timeout) when ParamCount == length(ParamValues) ->
|
|
|
|
|
|
+ SockModule, Socket, Timeout) when ParamCount == length(ParamValues) ->
|
|
%% Flags Constant Name
|
|
%% Flags Constant Name
|
|
%% 0x00 CURSOR_TYPE_NO_CURSOR
|
|
%% 0x00 CURSOR_TYPE_NO_CURSOR
|
|
%% 0x01 CURSOR_TYPE_READ_ONLY
|
|
%% 0x01 CURSOR_TYPE_READ_ONLY
|
|
@@ -176,13 +182,13 @@ execute(#prepared{statement_id = Id, param_count = ParamCount}, ParamValues,
|
|
{TypesAndSigns, EncValues} = lists:unzip(EncodedParams),
|
|
{TypesAndSigns, EncValues} = lists:unzip(EncodedParams),
|
|
iolist_to_binary([Req1, TypesAndSigns, EncValues])
|
|
iolist_to_binary([Req1, TypesAndSigns, EncValues])
|
|
end,
|
|
end,
|
|
- {ok, _SeqNum1} = send_packet(TcpModule, Socket, Req, 0),
|
|
|
|
- fetch_execute_response(TcpModule, Socket, Timeout).
|
|
|
|
|
|
+ {ok, _SeqNum1} = send_packet(SockModule, Socket, Req, 0),
|
|
|
|
+ fetch_execute_response(SockModule, Socket, Timeout).
|
|
|
|
|
|
%% @doc This is used by execute/5. If execute/5 returns {error, timeout}, this
|
|
%% @doc This is used by execute/5. If execute/5 returns {error, timeout}, this
|
|
%% function can be called to retry to fetch the results of the query.
|
|
%% function can be called to retry to fetch the results of the query.
|
|
-fetch_execute_response(TcpModule, Socket, Timeout) ->
|
|
|
|
- fetch_response(TcpModule, Socket, Timeout, binary, []).
|
|
|
|
|
|
+fetch_execute_response(SockModule, Socket, Timeout) ->
|
|
|
|
+ fetch_response(SockModule, Socket, Timeout, binary, []).
|
|
|
|
|
|
%% --- internal ---
|
|
%% --- internal ---
|
|
|
|
|
|
@@ -237,32 +243,50 @@ server_version_to_list(ServerVersion) ->
|
|
[{capture, all_but_first, binary}]),
|
|
[{capture, all_but_first, binary}]),
|
|
lists:map(fun binary_to_integer/1, Parts).
|
|
lists:map(fun binary_to_integer/1, Parts).
|
|
|
|
|
|
|
|
+-spec maybe_do_ssl_upgrade(SockModule0 :: module(),
|
|
|
|
+ Socket0 :: term(),
|
|
|
|
+ SeqNum1 :: non_neg_integer(),
|
|
|
|
+ Handshake :: #handshake{},
|
|
|
|
+ SSLOpts :: undefined | list(),
|
|
|
|
+ Database :: iodata() | undefined,
|
|
|
|
+ SetFoundRows :: boolean()) ->
|
|
|
|
+ {ok, SockModule :: module(), Socket :: term(), SeqNum2 :: non_neg_integer()}.
|
|
|
|
+maybe_do_ssl_upgrade(SockModule0, Socket0, SeqNum1, _Handshake, undefined,
|
|
|
|
+ _Database, _SetFoundRows) ->
|
|
|
|
+ {ok, SockModule0, Socket0, SeqNum1};
|
|
|
|
+maybe_do_ssl_upgrade(SockModule0, Socket0, SeqNum1, Handshake, SSLOpts, Database, SetFoundRows) ->
|
|
|
|
+ Response = build_handshake_response(Handshake, Database, SetFoundRows),
|
|
|
|
+ {ok, SeqNum2} = send_packet(SockModule0, Socket0, Response, SeqNum1),
|
|
|
|
+ case mysql_sock_ssl:connect(Socket0, SSLOpts, 5000) of
|
|
|
|
+ {ok, SSLSocket} ->
|
|
|
|
+ {ok, ssl, SSLSocket, SeqNum2};
|
|
|
|
+ {error, Reason} ->
|
|
|
|
+ exit({failed_to_upgrade_socket, Reason})
|
|
|
|
+ end.
|
|
|
|
+
|
|
|
|
+-spec build_handshake_response(#handshake{}, iodata() | undefined, boolean()) -> binary().
|
|
|
|
+build_handshake_response(Handshake, Database, SetFoundRows) ->
|
|
|
|
+ CapabilityFlags = basic_capabilities(Database /= undefined, SetFoundRows),
|
|
|
|
+ verify_server_capabilities(Handshake, CapabilityFlags),
|
|
|
|
+ ClientCapabilities = add_client_capabilities(CapabilityFlags),
|
|
|
|
+ ClientSSLCapabilities = ClientCapabilities bor ?CLIENT_SSL,
|
|
|
|
+ CharacterSet = ?UTF8,
|
|
|
|
+ <<ClientSSLCapabilities:32/little,
|
|
|
|
+ ?MAX_BYTES_PER_PACKET:32/little,
|
|
|
|
+ CharacterSet:8,
|
|
|
|
+ 0:23/unit:8>>.
|
|
|
|
+
|
|
%% @doc The response sent by the client to the server after receiving the
|
|
%% @doc The response sent by the client to the server after receiving the
|
|
%% initial handshake from the server
|
|
%% initial handshake from the server
|
|
-spec build_handshake_response(#handshake{}, iodata(), iodata(),
|
|
-spec build_handshake_response(#handshake{}, iodata(), iodata(),
|
|
iodata() | undefined, boolean()) -> binary().
|
|
iodata() | undefined, boolean()) -> binary().
|
|
build_handshake_response(Handshake, Username, Password, Database, SetFoundRows) ->
|
|
build_handshake_response(Handshake, Username, Password, Database, SetFoundRows) ->
|
|
- %% We require these capabilities. Make sure the server handles them.
|
|
|
|
- CapabilityFlags0 = ?CLIENT_PROTOCOL_41 bor
|
|
|
|
- ?CLIENT_TRANSACTIONS bor
|
|
|
|
- ?CLIENT_SECURE_CONNECTION,
|
|
|
|
- CapabilityFlags1 = case Database of
|
|
|
|
- undefined -> CapabilityFlags0;
|
|
|
|
- _ -> CapabilityFlags0 bor ?CLIENT_CONNECT_WITH_DB
|
|
|
|
- end,
|
|
|
|
- CapabilityFlags = case SetFoundRows of
|
|
|
|
- true -> CapabilityFlags1 bor ?CLIENT_FOUND_ROWS;
|
|
|
|
- _ -> CapabilityFlags1
|
|
|
|
- end,
|
|
|
|
- Handshake#handshake.capabilities band CapabilityFlags == CapabilityFlags
|
|
|
|
- orelse error(old_server_version),
|
|
|
|
|
|
+ CapabilityFlags = basic_capabilities(Database /= undefined, SetFoundRows),
|
|
|
|
+ verify_server_capabilities(Handshake, CapabilityFlags),
|
|
%% Add some extra capability flags only for signalling to the server what
|
|
%% Add some extra capability flags only for signalling to the server what
|
|
%% the client wants to do. The server doesn't say it handles them although
|
|
%% the client wants to do. The server doesn't say it handles them although
|
|
%% it does. (http://bugs.mysql.com/bug.php?id=42268)
|
|
%% it does. (http://bugs.mysql.com/bug.php?id=42268)
|
|
- ClientCapabilityFlags = CapabilityFlags bor
|
|
|
|
- ?CLIENT_MULTI_STATEMENTS bor
|
|
|
|
- ?CLIENT_MULTI_RESULTS bor
|
|
|
|
- ?CLIENT_PS_MULTI_RESULTS,
|
|
|
|
|
|
+ ClientCapabilityFlags = add_client_capabilities(CapabilityFlags),
|
|
Hash = case Handshake#handshake.auth_plugin_name of
|
|
Hash = case Handshake#handshake.auth_plugin_name of
|
|
<<>> ->
|
|
<<>> ->
|
|
%% Server doesn't know auth plugins
|
|
%% Server doesn't know auth plugins
|
|
@@ -289,6 +313,33 @@ build_handshake_response(Handshake, Username, Password, Database, SetFoundRows)
|
|
Hash/binary,
|
|
Hash/binary,
|
|
DbBin/binary>>.
|
|
DbBin/binary>>.
|
|
|
|
|
|
|
|
+-spec verify_server_capabilities(Handshake :: #handshake{}, CapabilityFlags :: integer()) -> true | no_return().
|
|
|
|
+verify_server_capabilities(Handshake, CapabilityFlags) ->
|
|
|
|
+ %% We require these capabilities. Make sure the server handles them.
|
|
|
|
+ Handshake#handshake.capabilities band CapabilityFlags == CapabilityFlags
|
|
|
|
+ orelse error(old_server_version).
|
|
|
|
+
|
|
|
|
+-spec basic_capabilities(ConnectWithDB :: boolean(), SetFoundRows :: boolean()) -> integer().
|
|
|
|
+basic_capabilities(ConnectWithDB, SetFoundRows) ->
|
|
|
|
+ CapabilityFlags0 = ?CLIENT_PROTOCOL_41 bor
|
|
|
|
+ ?CLIENT_TRANSACTIONS bor
|
|
|
|
+ ?CLIENT_SECURE_CONNECTION,
|
|
|
|
+ CapabilityFlags1 = case ConnectWithDB of
|
|
|
|
+ true -> CapabilityFlags0 bor ?CLIENT_CONNECT_WITH_DB;
|
|
|
|
+ _ -> CapabilityFlags0
|
|
|
|
+ end,
|
|
|
|
+ case SetFoundRows of
|
|
|
|
+ true -> CapabilityFlags1 bor ?CLIENT_FOUND_ROWS;
|
|
|
|
+ _ -> CapabilityFlags1
|
|
|
|
+ end.
|
|
|
|
+
|
|
|
|
+-spec add_client_capabilities(Caps :: integer()) -> integer().
|
|
|
|
+add_client_capabilities(Caps) ->
|
|
|
|
+ Caps bor
|
|
|
|
+ ?CLIENT_MULTI_STATEMENTS bor
|
|
|
|
+ ?CLIENT_MULTI_RESULTS bor
|
|
|
|
+ ?CLIENT_PS_MULTI_RESULTS.
|
|
|
|
+
|
|
%% @doc Handles the second packet from the server, when we have replied to the
|
|
%% @doc Handles the second packet from the server, when we have replied to the
|
|
%% initial handshake. Returns an error if the server returns an error. Raises
|
|
%% initial handshake. Returns an error if the server returns an error. Raises
|
|
%% an error if unimplemented features are required.
|
|
%% an error if unimplemented features are required.
|
|
@@ -321,8 +372,8 @@ parse_handshake_confirm(Packet) ->
|
|
%% prepared statements).
|
|
%% prepared statements).
|
|
-spec fetch_response(atom(), term(), timeout(), text | binary, list()) ->
|
|
-spec fetch_response(atom(), term(), timeout(), text | binary, list()) ->
|
|
{ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
|
|
{ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
|
|
-fetch_response(TcpModule, Socket, Timeout, Proto, Acc) ->
|
|
|
|
- case recv_packet(TcpModule, Socket, Timeout, any) of
|
|
|
|
|
|
+fetch_response(SockModule, Socket, Timeout, Proto, Acc) ->
|
|
|
|
+ case recv_packet(SockModule, Socket, Timeout, any) of
|
|
{ok, Packet, SeqNum2} ->
|
|
{ok, Packet, SeqNum2} ->
|
|
Result = case Packet of
|
|
Result = case Packet of
|
|
?ok_pattern ->
|
|
?ok_pattern ->
|
|
@@ -332,7 +383,7 @@ fetch_response(TcpModule, Socket, Timeout, Proto, Acc) ->
|
|
ResultPacket ->
|
|
ResultPacket ->
|
|
%% The first packet in a resultset is only the column count.
|
|
%% The first packet in a resultset is only the column count.
|
|
{ColCount, <<>>} = lenenc_int(ResultPacket),
|
|
{ColCount, <<>>} = lenenc_int(ResultPacket),
|
|
- R0 = fetch_resultset(TcpModule, Socket, ColCount, SeqNum2),
|
|
|
|
|
|
+ R0 = fetch_resultset(SockModule, Socket, ColCount, SeqNum2),
|
|
case R0 of
|
|
case R0 of
|
|
#error{} = E ->
|
|
#error{} = E ->
|
|
%% TODO: Find a way to get here + testcase
|
|
%% TODO: Find a way to get here + testcase
|
|
@@ -344,7 +395,7 @@ fetch_response(TcpModule, Socket, Timeout, Proto, Acc) ->
|
|
Acc1 = [Result | Acc],
|
|
Acc1 = [Result | Acc],
|
|
case more_results_exists(Result) of
|
|
case more_results_exists(Result) of
|
|
true ->
|
|
true ->
|
|
- fetch_response(TcpModule, Socket, Timeout, Proto, Acc1);
|
|
|
|
|
|
+ fetch_response(SockModule, Socket, Timeout, Proto, Acc1);
|
|
false ->
|
|
false ->
|
|
{ok, lists:reverse(Acc1)}
|
|
{ok, lists:reverse(Acc1)}
|
|
end;
|
|
end;
|
|
@@ -358,12 +409,12 @@ fetch_response(TcpModule, Socket, Timeout, Proto, Acc) ->
|
|
%% be parsed.
|
|
%% be parsed.
|
|
-spec fetch_resultset(atom(), term(), integer(), integer()) ->
|
|
-spec fetch_resultset(atom(), term(), integer(), integer()) ->
|
|
#resultset{} | #error{}.
|
|
#resultset{} | #error{}.
|
|
-fetch_resultset(TcpModule, Socket, FieldCount, SeqNum) ->
|
|
|
|
- {ok, ColDefs, SeqNum1} = fetch_column_definitions(TcpModule, Socket, SeqNum,
|
|
|
|
|
|
+fetch_resultset(SockModule, Socket, FieldCount, SeqNum) ->
|
|
|
|
+ {ok, ColDefs, SeqNum1} = fetch_column_definitions(SockModule, Socket, SeqNum,
|
|
FieldCount, []),
|
|
FieldCount, []),
|
|
- {ok, DelimiterPacket, SeqNum2} = recv_packet(TcpModule, Socket, SeqNum1),
|
|
|
|
|
|
+ {ok, DelimiterPacket, SeqNum2} = recv_packet(SockModule, Socket, SeqNum1),
|
|
#eof{status = S, warning_count = W} = parse_eof_packet(DelimiterPacket),
|
|
#eof{status = S, warning_count = W} = parse_eof_packet(DelimiterPacket),
|
|
- case fetch_resultset_rows(TcpModule, Socket, SeqNum2, []) of
|
|
|
|
|
|
+ case fetch_resultset_rows(SockModule, Socket, SeqNum2, []) of
|
|
{ok, Rows, _SeqNum3} ->
|
|
{ok, Rows, _SeqNum3} ->
|
|
ColDefs1 = lists:map(fun parse_column_definition/1, ColDefs),
|
|
ColDefs1 = lists:map(fun parse_column_definition/1, ColDefs),
|
|
#resultset{cols = ColDefs1, rows = Rows,
|
|
#resultset{cols = ColDefs1, rows = Rows,
|
|
@@ -393,12 +444,12 @@ more_results_exists(#resultset{status = S}) ->
|
|
-spec fetch_column_definitions(atom(), term(), SeqNum :: integer(),
|
|
-spec fetch_column_definitions(atom(), term(), SeqNum :: integer(),
|
|
NumLeft :: integer(), Acc :: [binary()]) ->
|
|
NumLeft :: integer(), Acc :: [binary()]) ->
|
|
{ok, ColDefPackets :: [binary()], NextSeqNum :: integer()}.
|
|
{ok, ColDefPackets :: [binary()], NextSeqNum :: integer()}.
|
|
-fetch_column_definitions(TcpModule, Socket, SeqNum, NumLeft, Acc)
|
|
|
|
|
|
+fetch_column_definitions(SockModule, Socket, SeqNum, NumLeft, Acc)
|
|
when NumLeft > 0 ->
|
|
when NumLeft > 0 ->
|
|
- {ok, Packet, SeqNum1} = recv_packet(TcpModule, Socket, SeqNum),
|
|
|
|
- fetch_column_definitions(TcpModule, Socket, SeqNum1, NumLeft - 1,
|
|
|
|
|
|
+ {ok, Packet, SeqNum1} = recv_packet(SockModule, Socket, SeqNum),
|
|
|
|
+ fetch_column_definitions(SockModule, Socket, SeqNum1, NumLeft - 1,
|
|
[Packet | Acc]);
|
|
[Packet | Acc]);
|
|
-fetch_column_definitions(_TcpModule, _Socket, SeqNum, 0, Acc) ->
|
|
|
|
|
|
+fetch_column_definitions(_SockModule, _Socket, SeqNum, 0, Acc) ->
|
|
{ok, lists:reverse(Acc), SeqNum}.
|
|
{ok, lists:reverse(Acc), SeqNum}.
|
|
|
|
|
|
%% @doc Fetches rows in a result set. There is a packet per row. The row packets
|
|
%% @doc Fetches rows in a result set. There is a packet per row. The row packets
|
|
@@ -408,15 +459,15 @@ fetch_column_definitions(_TcpModule, _Socket, SeqNum, 0, Acc) ->
|
|
{ok, Rows, integer()} | #error{}
|
|
{ok, Rows, integer()} | #error{}
|
|
when Acc :: [binary()],
|
|
when Acc :: [binary()],
|
|
Rows :: [binary()].
|
|
Rows :: [binary()].
|
|
-fetch_resultset_rows(TcpModule, Socket, SeqNum, Acc) ->
|
|
|
|
- {ok, Packet, SeqNum1} = recv_packet(TcpModule, Socket, SeqNum),
|
|
|
|
|
|
+fetch_resultset_rows(SockModule, Socket, SeqNum, Acc) ->
|
|
|
|
+ {ok, Packet, SeqNum1} = recv_packet(SockModule, Socket, SeqNum),
|
|
case Packet of
|
|
case Packet of
|
|
?error_pattern ->
|
|
?error_pattern ->
|
|
parse_error_packet(Packet);
|
|
parse_error_packet(Packet);
|
|
?eof_pattern ->
|
|
?eof_pattern ->
|
|
{ok, lists:reverse(Acc), SeqNum1};
|
|
{ok, lists:reverse(Acc), SeqNum1};
|
|
Row ->
|
|
Row ->
|
|
- fetch_resultset_rows(TcpModule, Socket, SeqNum1, [Row | Acc])
|
|
|
|
|
|
+ fetch_resultset_rows(SockModule, Socket, SeqNum1, [Row | Acc])
|
|
end.
|
|
end.
|
|
|
|
|
|
%% Parses a packet containing a column definition (part of a result set)
|
|
%% Parses a packet containing a column definition (part of a result set)
|
|
@@ -543,12 +594,12 @@ decode_text(#col{type = T}, Text) when T == ?TYPE_FLOAT;
|
|
|
|
|
|
%% @doc If NumColumns is non-zero, fetches this number of column definitions
|
|
%% @doc If NumColumns is non-zero, fetches this number of column definitions
|
|
%% and an EOF packet. Used by prepare/3.
|
|
%% and an EOF packet. Used by prepare/3.
|
|
-fetch_column_definitions_if_any(0, _TcpModule, _Socket, SeqNum) ->
|
|
|
|
|
|
+fetch_column_definitions_if_any(0, _SockModule, _Socket, SeqNum) ->
|
|
{[], SeqNum};
|
|
{[], SeqNum};
|
|
-fetch_column_definitions_if_any(N, TcpModule, Socket, SeqNum) ->
|
|
|
|
- {ok, Defs, SeqNum1} = fetch_column_definitions(TcpModule, Socket, SeqNum,
|
|
|
|
|
|
+fetch_column_definitions_if_any(N, SockModule, Socket, SeqNum) ->
|
|
|
|
+ {ok, Defs, SeqNum1} = fetch_column_definitions(SockModule, Socket, SeqNum,
|
|
N, []),
|
|
N, []),
|
|
- {ok, ?eof_pattern, SeqNum2} = recv_packet(TcpModule, Socket, SeqNum1),
|
|
|
|
|
|
+ {ok, ?eof_pattern, SeqNum2} = recv_packet(SockModule, Socket, SeqNum1),
|
|
{Defs, SeqNum2}.
|
|
{Defs, SeqNum2}.
|
|
|
|
|
|
%% @doc Decodes a packet representing a row in a binary result set.
|
|
%% @doc Decodes a packet representing a row in a binary result set.
|
|
@@ -884,40 +935,40 @@ decode_decimal(Bin, P, S) when P >= 16, S > 0 ->
|
|
|
|
|
|
%% -- Protocol basics: packets --
|
|
%% -- Protocol basics: packets --
|
|
|
|
|
|
-%% @doc Wraps Data in packet headers, sends it by calling TcpModule:send/2 with
|
|
|
|
|
|
+%% @doc Wraps Data in packet headers, sends it by calling SockModule:send/2 with
|
|
%% Socket and returns {ok, SeqNum1} where SeqNum1 is the next sequence number.
|
|
%% Socket and returns {ok, SeqNum1} where SeqNum1 is the next sequence number.
|
|
-spec send_packet(atom(), term(), Data :: binary(), SeqNum :: integer()) ->
|
|
-spec send_packet(atom(), term(), Data :: binary(), SeqNum :: integer()) ->
|
|
{ok, NextSeqNum :: integer()}.
|
|
{ok, NextSeqNum :: integer()}.
|
|
-send_packet(TcpModule, Socket, Data, SeqNum) ->
|
|
|
|
|
|
+send_packet(SockModule, Socket, Data, SeqNum) ->
|
|
{WithHeaders, SeqNum1} = add_packet_headers(Data, SeqNum),
|
|
{WithHeaders, SeqNum1} = add_packet_headers(Data, SeqNum),
|
|
- ok = TcpModule:send(Socket, WithHeaders),
|
|
|
|
|
|
+ ok = SockModule:send(Socket, WithHeaders),
|
|
{ok, SeqNum1}.
|
|
{ok, SeqNum1}.
|
|
|
|
|
|
%% @see recv_packet/4
|
|
%% @see recv_packet/4
|
|
-recv_packet(TcpModule, Socket, SeqNum) ->
|
|
|
|
- recv_packet(TcpModule, Socket, infinity, SeqNum).
|
|
|
|
|
|
+recv_packet(SockModule, Socket, SeqNum) ->
|
|
|
|
+ recv_packet(SockModule, Socket, infinity, SeqNum).
|
|
|
|
|
|
-%% @doc Receives data by calling TcpModule:recv/2 and removes the packet
|
|
|
|
|
|
+%% @doc Receives data by calling SockModule:recv/2 and removes the packet
|
|
%% headers. Returns the packet contents and the next packet sequence number.
|
|
%% headers. Returns the packet contents and the next packet sequence number.
|
|
-spec recv_packet(atom(), term(), timeout(), integer() | any) ->
|
|
-spec recv_packet(atom(), term(), timeout(), integer() | any) ->
|
|
{ok, Data :: binary(), NextSeqNum :: integer()} | {error, term()}.
|
|
{ok, Data :: binary(), NextSeqNum :: integer()} | {error, term()}.
|
|
-recv_packet(TcpModule, Socket, Timeout, SeqNum) ->
|
|
|
|
- recv_packet(TcpModule, Socket, Timeout, SeqNum, <<>>).
|
|
|
|
|
|
+recv_packet(SockModule, Socket, Timeout, SeqNum) ->
|
|
|
|
+ recv_packet(SockModule, Socket, Timeout, SeqNum, <<>>).
|
|
|
|
|
|
%% @doc Accumulating helper for recv_packet/4
|
|
%% @doc Accumulating helper for recv_packet/4
|
|
-spec recv_packet(atom(), term(), timeout(), integer() | any, binary()) ->
|
|
-spec recv_packet(atom(), term(), timeout(), integer() | any, binary()) ->
|
|
{ok, Data :: binary(), NextSeqNum :: integer()} | {error, term()}.
|
|
{ok, Data :: binary(), NextSeqNum :: integer()} | {error, term()}.
|
|
-recv_packet(TcpModule, Socket, Timeout, ExpectSeqNum, Acc) ->
|
|
|
|
- case TcpModule:recv(Socket, 4, Timeout) of
|
|
|
|
|
|
+recv_packet(SockModule, Socket, Timeout, ExpectSeqNum, Acc) ->
|
|
|
|
+ case SockModule:recv(Socket, 4, Timeout) of
|
|
{ok, Header} ->
|
|
{ok, Header} ->
|
|
{Size, SeqNum, More} = parse_packet_header(Header),
|
|
{Size, SeqNum, More} = parse_packet_header(Header),
|
|
true = SeqNum == ExpectSeqNum orelse ExpectSeqNum == any,
|
|
true = SeqNum == ExpectSeqNum orelse ExpectSeqNum == any,
|
|
- {ok, Body} = TcpModule:recv(Socket, Size),
|
|
|
|
|
|
+ {ok, Body} = SockModule:recv(Socket, Size),
|
|
Acc1 = <<Acc/binary, Body/binary>>,
|
|
Acc1 = <<Acc/binary, Body/binary>>,
|
|
NextSeqNum = (SeqNum + 1) band 16#ff,
|
|
NextSeqNum = (SeqNum + 1) band 16#ff,
|
|
case More of
|
|
case More of
|
|
false -> {ok, Acc1, NextSeqNum};
|
|
false -> {ok, Acc1, NextSeqNum};
|
|
- true -> recv_packet(TcpModule, Socket, Timeout, NextSeqNum,
|
|
|
|
|
|
+ true -> recv_packet(SockModule, Socket, Timeout, NextSeqNum,
|
|
Acc1)
|
|
Acc1)
|
|
end;
|
|
end;
|
|
{error, Reason} ->
|
|
{error, Reason} ->
|