Browse Source

Implement plain queries and the text protocol

Viktor Söderqvist 10 years ago
parent
commit
7a767be4b5
11 changed files with 707 additions and 183 deletions
  1. 1 1
      README.md
  2. 33 0
      include/protocol.hrl
  3. 19 8
      include/records.hrl
  4. 1 0
      rebar.config
  5. 1 1
      src/mysql.app.src
  6. 36 0
      src/mysql.erl
  7. 61 98
      src/mysql_connection.erl
  8. 203 75
      src/mysql_protocol.erl
  9. 108 0
      src/mysql_text_protocol.erl
  10. 46 0
      test/mysql_tests.erl
  11. 198 0
      test/protocol_tests.erl

+ 1 - 1
README.md

@@ -3,7 +3,7 @@ MySQL/OTP
 
 This is a MySQL driver for Erlang following the OTP principles.
 
-Status: Just started. Connecting works but nothing else.
+Status: Pre-alpha. Connecting and queries using the text protocol work. The API and the value representation are subjects to change.
 
 Background: We are starting this project with the aim at overcoming the problems with Emysql (the currently most popular driver) and erlang-mysql-driver (the even older driver).
 

+ 33 - 0
include/protocol.hrl

@@ -24,6 +24,9 @@
 -define(EOF, 16#fe).
 -define(ERROR, 16#ff).
 
+%% Character sets
+-define(UTF8, 16#21). %% utf8_general_ci
+
 %% --- Capability flags ---
 
 %% Server: supports the 4.1 protocol 
@@ -86,3 +89,33 @@
 -define(COM_SET_OPTION, 16#1b).
 -define(COM_STMT_FETCH, 16#1c).
 
+%% --- Types ---
+
+-define(TYPE_DECIMAL, 16#00).
+-define(TYPE_TINY, 16#01).
+-define(TYPE_SHORT, 16#02).
+-define(TYPE_LONG, 16#03).
+-define(TYPE_FLOAT, 16#04).
+-define(TYPE_DOUBLE, 16#05).
+-define(TYPE_NULL, 16#06).
+-define(TYPE_TIMESTAMP, 16#07).
+-define(TYPE_LONGLONG, 16#08).
+-define(TYPE_INT24, 16#09).
+-define(TYPE_DATE, 16#0a).
+-define(TYPE_TIME, 16#0b).
+-define(TYPE_DATETIME, 16#0c).
+-define(TYPE_YEAR, 16#0d).
+-define(TYPE_VARCHAR, 16#0f).
+-define(TYPE_BIT, 16#10).
+-define(TYPE_NEWDECIMAL, 16#f6).
+-define(TYPE_ENUM, 16#f7).
+-define(TYPE_SET, 16#f8).
+-define(TYPE_TINY_BLOB, 16#f9).
+-define(TYPE_MEDIUM_BLOB, 16#fa).
+-define(TYPE_LONG_BLOB, 16#fb).
+-define(TYPE_BLOB, 16#fc).
+-define(TYPE_VAR_STRING, 16#fd).
+-define(TYPE_STRING, 16#fe).
+-define(TYPE_GEOMETRY, 16#ff).
+
+

+ 19 - 8
include/records.hrl

@@ -9,12 +9,23 @@
                     auth_plugin_data :: binary(),
                     auth_plugin_name :: binary()}).
 
-%% Records returned by parse_response/1.
--record(ok_packet, {affected_rows :: integer(),
-                    insert_id :: integer(),
-                    status :: integer(),
-                    warning_count :: integer(),
-                    msg :: binary()}).
--record(error_packet, {code, state, msg}).
--record(eof_packet, {status, warning_count}).
+%% OK packet, commonly used in the protocol.
+-record(ok, {affected_rows :: integer(),
+             insert_id :: integer(),
+             status :: integer(),
+             warning_count :: integer(),
+             msg :: binary()}).
+%% Error packet, commonly used in the protocol.
+-record(error, {code, state, msg}).
+
+%% EOF packet, commonly used in the protocol.
+-record(eof, {status, warning_count}).
+
+
+%% Column definition, used while parsing a result set.
+-record(column_definition, {name, type, charset}).
 
+%% A resultset as received from the server using the text protocol.
+%% All values are binary (SQL code) except NULL.
+-record(text_resultset, {column_definitions :: [#column_definition{}],
+                         rows :: [[binary() | null]]}).

+ 1 - 0
rebar.config

@@ -0,0 +1 @@
+{cover_enabled, true}.

+ 1 - 1
src/mysql.app.src

@@ -1,5 +1,5 @@
 {application, mysql, [
     {description, "MySQL/OTP - Erlang MySQL driver"},
-    {vsn, "0.0.1"},
+    {vsn, "0.0.2"},
     {modules, []}
 ]}.

+ 36 - 0
src/mysql.erl

@@ -0,0 +1,36 @@
+%% @doc MySQL/OTP
+-module(mysql).
+
+-export([connect/1, disconnect/1, query/2, warning_count/1, affected_rows/1,
+         insert_id/1]).
+
+-spec connect(list()) -> {ok, pid()} | ignore | {error, term()}.
+connect(Opts) ->
+    gen_server:start_link(mysql_connection, Opts, []).
+
+-spec disconnect(pid()) -> ok.
+disconnect(Conn) ->
+    exit(Conn, normal),
+    ok.
+
+-spec query(Conn, Query) -> ok | {ok, Fields, Rows} | {error, Reason}
+    when Conn :: pid(),
+         Query :: iodata(),
+         Fields :: [binary()],
+         Rows :: [[term()]],
+         Reason :: {Code :: integer(), SQLState :: binary(),
+                    Message :: binary()}.
+query(Conn, Query) ->
+    gen_server:call(Conn, {query, Query}).
+
+-spec warning_count(pid()) -> integer().
+warning_count(Conn) ->
+    gen_server:call(Conn, warning_count).
+
+-spec affected_rows(pid()) -> integer().
+affected_rows(Conn) ->
+    gen_server:call(Conn, affected_rows).
+
+-spec insert_id(pid()) -> integer().
+insert_id(Conn) ->
+    gen_server:call(Conn, insert_id).

+ 61 - 98
src/mysql_connection.erl

@@ -1,69 +1,69 @@
+%% A mysql connection implemented as a gen_server. This is a gen_server callback
+%% module only. The API functions are located in the mysql module.
 -module(mysql_connection).
 -behaviour(gen_server).
 
--export([start_link/1]).
-
-%% Gen_server callbacks
 -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2,
          code_change/3]).
 
--include("records.hrl").
-
-start_link(Args) ->
-    gen_server:start_link(?MODULE, Args, []).
-
-%% --- Gen_server ballbacks ---
-
+%% Some defaults
 -define(default_host, "localhost").
 -define(default_port, 3306).
 -define(default_user, <<>>).
 -define(default_password, <<>>).
 -define(default_timeout, infinity).
 
+-include("records.hrl").
+
+%% Gen_server state
 -record(state, {socket, affected_rows = 0, status = 0, warning_count = 0,
                 insert_id = 0}).
 
+%% A tuple representing a MySQL server error, typically returned in the form
+%% {error, reason()}.
+-type reason() :: {Code :: integer(), SQLState :: binary(), Msg :: binary()}.
+
 init(Opts) ->
     %% Connect
     Host     = proplists:get_value(host,     Opts, ?default_host),
     Port     = proplists:get_value(port,     Opts, ?default_port),
     User     = proplists:get_value(user,     Opts, ?default_user),
     Password = proplists:get_value(password, Opts, ?default_password),
+    Database = proplists:get_value(database, Opts, undefined),
     Timeout  = proplists:get_value(timeout,  Opts, ?default_timeout),
 
     %% Connect socket
     SockOpts = [{active, false}, binary, {packet, raw}],
     {ok, Socket} = gen_tcp:connect(Host, Port, SockOpts),
 
-    %% Receive handshake
-    {ok, HandshakeBin, 1} = recv(Socket, 0, Timeout),
-    Handshake = mysql_protocol:parse_handshake(HandshakeBin),
-
-    %% Reply to handshake
-    HandshakeResp =
-        mysql_protocol:build_handshake_response(Handshake, User, Password),
-    {ok, 2} = send(Socket, HandshakeResp, 1),
-
-    %% Receive connection ok or error
-    {ok, ContBin, 3} = recv(Socket, 2, Timeout),
-    case mysql_protocol:parse_handshake_confirm(ContBin) of
-        #ok_packet{status = Status} ->
+    %% Exchange handshake communication.
+    Result = mysql_protocol:handshake(User, Password, Database,
+                                      fun (Data) ->
+                                          gen_tcp:send(Socket, Data)
+                                      end,
+                                      fun (Size) ->
+                                          gen_tcp:recv(Socket, Size, Timeout)
+                                      end),
+    case Result of
+        #ok{status = Status} ->
             {ok, #state{status = Status, socket = Socket}};
-        #error_packet{msg = Reason} ->
-            {stop, Reason}
+        #error{} = E ->
+            {stop, error_to_reason(E)}
     end.
 
-handle_call({'query', Query}, _From, State) when is_binary(Query) ->
-    Req = mysql_protocol:build_query(Query),
-    Resp = call_db(State, Req),
-    Rec = mysql_protocol:parse_query_response(Resp),
+handle_call({query, Query}, _From, State) when is_binary(Query) ->
+    Rec = mysql_protocol:query_tcp(Query, State#state.socket,
+                                   infinity),
     State1 = update_state(State, Rec),
     case Rec of
-        #ok_packet{} ->
+        #ok{} ->
             {reply, ok, State1};
-        #error_packet{msg = Msg} ->
-            {reply, {error, Msg}, State1}
-        %% TODO: Add result set here.
+        #error{} = E ->
+            {reply, {error, error_to_reason(E)}, State1};
+        #text_resultset{column_definitions = ColDefs, rows = Rows} ->
+            Names = [Def#column_definition.name || Def <- ColDefs],
+            Rows1 = decode_text_rows(ColDefs, Rows),
+            {reply, {ok, Names, Rows1}, State1}
     end;
 handle_call(warning_count, _From, State) ->
     {reply, State#state.warning_count, State};
@@ -83,71 +83,34 @@ code_change(_, _, _) -> todo.
 
 %% --- Helpers ---
 
+%% @doc Produces a tuple to return when an error needs to be returned to in the
+%% public API.
+-spec error_to_reason(#error{}) -> reason().
+error_to_reason(#error{code = Code, state = State, msg = Msg}) ->
+    {Code, State, Msg}.
+
 %% @doc Updates a state with information from a response.
--spec update_state(#state{}, #ok_packet{} | #error_packet{} | #eof_packet{}) ->
-    #state{}.
-update_state(State, #ok_packet{status = S, affected_rows = R,
-                               insert_id = Id, warning_count = W}) ->
+-spec update_state(#state{}, #ok{} | #eof{} | any()) -> #state{}.
+update_state(State, #ok{status = S, affected_rows = R,
+                        insert_id = Id, warning_count = W}) ->
     State#state{status = S, affected_rows = R, insert_id = Id,
                 warning_count = W};
-update_state(State, #error_packet{}) ->
-    State;
-update_state(State, #eof_packet{status = S, warning_count = W}) ->
-    State#state{status = S, warning_count = W}.
-
-%% @doc Sends data to mysql and receives the response.
-call_db(State, PacketBody) ->
-    call_db(State, PacketBody, infinity).
-
-%% @doc Sends data to mysql and receives the response.
-call_db(#state{socket = Socket}, PacketBody, Timeout) ->
-    {ok, SeqNum} = send(Socket, PacketBody, 0),
-    {ok, Response, _SeqNum} = recv(Socket, SeqNum, Timeout),
-    Response.
-
-%% @doc Sends data and returns {ok, SeqNum1} where SeqNum1 is the next sequence
-%% number.
--spec send(Socket :: gen_tcp:socket(), Data :: binary(), SeqNum :: integer()) ->
-    {ok, NextSeqNum :: integer()}.
-send(Socket, Data, SeqNum) ->
-    {WithHeaders, SeqNum1} = mysql_protocol:add_packet_headers(Data, SeqNum),
-    ok = gen_tcp:send(Socket, WithHeaders),
-    {ok, SeqNum1}.
-
-%% @doc Receives data from the server and removes packet headers. Returns the
-%% next packet sequence number.
--spec recv(Socket :: gen_tcp:socket(), SeqNum :: integer(),
-           Timeout :: timeout()) ->
-    {ok, Data :: binary(), NextSeqNum :: integer()}.
-recv(Socket, SeqNum, Timeout) ->
-    recv(Socket, SeqNum, Timeout, <<>>).
-
-%% @doc Receives data from the server and removes packet headers. Returns the
-%% next packet sequence number.
--spec recv(Socket :: gen_tcp:socket(), ExpectSeqNum :: integer(),
-           Timeout :: timeout(), Acc :: binary()) ->
-    {ok, Data :: binary(), NextSeqNum :: integer()}.
-recv(Socket, ExpectSeqNum, Timeout, Acc) ->
-    {ok, Header} = gen_tcp:recv(Socket, 4, Timeout),
-    {Size, ExpectSeqNum, More} = mysql_protocol:parse_packet_header(Header),
-    {ok, Body} = gen_tcp:recv(Socket, Size, Timeout),
-    Acc1 = <<Acc/binary, Body/binary>>,
-    NextSeqNum = (ExpectSeqNum + 1) band 16#ff,
-    case More of
-        false -> {ok, Acc1, NextSeqNum};
-        true  -> recv(Socket, NextSeqNum, Acc1)
-    end.
-
--ifdef(TEST).
--include_lib("eunit/include/eunit.hrl").
-
-connect_test() ->
-    {ok, Pid} = start_link([{user, "test"}, {password, "test"}]),
-    %ok = gen_server:call(Pid, {'query', <<"CREATE DATABASE foo">>}),
-    ok = gen_server:call(Pid, {'query', <<"USE foo">>}),
-    ok = gen_server:call(Pid, {'query', <<"DROP TABLE IF EXISTS foo">>}),
-    1 = gen_server:call(Pid, warning_count),
-    {error, <<"You h", _/binary>>} = gen_server:call(Pid, {'query', <<"FOO">>}),
-    ok.
-
--endif.
+update_state(State, #eof{status = S, warning_count = W}) ->
+    State#state{status = S, warning_count = W, insert_id = 0,
+                affected_rows = 0};
+update_state(State, _Other) ->
+    %% This includes errors, resultsets, etc.
+    %% Reset warnings, etc. (Note: We don't reset 'status'.)
+    State#state{warning_count = 0, insert_id = 0, affected_rows = 0}.
+
+%% @doc Uses a list of column definitions to decode rows returned in the text
+%% protocol. Returns the rows with values as for their type their appropriate
+%% Erlang terms.
+decode_text_rows(ColDefs, Rows) ->
+    [decode_text_row_acc(ColDefs, Row, []) || Row <- Rows].
+
+decode_text_row_acc([#column_definition{type = T} | Defs], [V | Vs], Acc) ->
+    Term = mysql_text_protocol:text_to_term(T, V),
+    decode_text_row_acc(Defs, Vs, [Term | Acc]);
+decode_text_row_acc([], [], Acc) ->
+    lists:reverse(Acc).

+ 203 - 75
src/mysql_protocol.erl

@@ -6,21 +6,26 @@
 %% TCP communication is not handled in this module.
 -module(mysql_protocol).
 
--export([parse_packet_header/1, add_packet_headers/2,
-         parse_handshake/1, build_handshake_response/3,
-         parse_handshake_confirm/1,
-         build_query/1, parse_query_response/1]).
+-export([
+         %parse_packet_header/1, add_packet_headers/2,
+         %parse_handshake/1, build_handshake_response/3,
+         %parse_handshake_confirm/1,
+         handshake/5,
+         query_tcp/3, query/3]).
 
 -include("records.hrl").
 -include("protocol.hrl").
 
+-type sendfun() :: fun((binary()) -> ok).
+-type recvfun() :: fun((integer()) -> {ok, binary()}).
+
 %% How much data do we want to send at most?
 -define(MAX_BYTES_PER_PACKET, 50000000).
 
 %% Macros for pattern matching on packets.
 -define(ok_pattern, <<?OK, _/binary>>).
 -define(error_pattern, <<?ERROR, _/binary>>).
--define(eof_pattern, <<?EOF, _/binary>>).
+-define(eof_pattern, <<?EOF, _:4/binary>>).
 
 %% @doc Parses a packet header (32 bits) and returns a tuple.
 %%
@@ -52,13 +57,30 @@ add_packet_headers(PacketBody, SeqNum) ->
         {[<<Size:24/little, SeqNum:8>>, Bin], SeqNum1}
     end.
 
+%% @doc Performs a handshake using the supplied functions for communication.
+%% Returns an ok or an error record.
+%%
+%% TODO: Implement setting the database in the handshake. Currently an error
+%% occurs if Database is anything other than undefined.
+-spec handshake(iodata(), iodata(), iodata() | undefined, sendfun(),
+                recvfun()) -> #ok{} | #error{}.
+handshake(Username, Password, Database, SendFun, RecvFun) ->
+    SeqNum0 = 0,
+    Database == undefined orelse error(database_in_handshake_not_implemented),
+    {ok, HandshakePacket, SeqNum1} = recv_packet(RecvFun, SeqNum0),
+    Handshake = parse_handshake(HandshakePacket),
+    Response = build_handshake_response(Handshake, Username, Password),
+    {ok, SeqNum2} = send_packet(SendFun, Response, SeqNum1),
+    {ok, ConfirmPacket, _SeqNum3} = recv_packet(RecvFun, SeqNum2),
+    parse_handshake_confirm(ConfirmPacket).
+
 %% @doc Parses a handshake. This is the first thing that comes from the server
 %% when connecting. If an unsupported version of variant of the protocol is used
 %% an error is raised.
 -spec parse_handshake(binary()) -> #handshake{}.
 parse_handshake(<<10, Rest/binary>>) ->
     %% Protocol version 10.
-    {ServerVersion, Rest1} = nulterm(Rest),
+    {ServerVersion, Rest1} = nulterm_str(Rest),
     <<ConnectionId:32/little,
       AuthPluginDataPart1:8/binary-unit:8,
       0, %% "filler" -- everything below is optional
@@ -103,12 +125,12 @@ build_handshake_response(Handshake, Username, Password) ->
                       ?CLIENT_TRANSACTIONS bor
                       ?CLIENT_SECURE_CONNECTION,
     Handshake#handshake.capabilities band CapabilityFlags == CapabilityFlags
-        orelse error({incompatible, <<"Server version is too old">>}),
+        orelse error({not_implemented, old_server_version}),
     Hash = hash_password(Password,
                          Handshake#handshake.auth_plugin_name,
                          Handshake#handshake.auth_plugin_data),
     HashLength = size(Hash),
-    CharacterSet = 16#21, %% utf8_general_ci
+    CharacterSet = ?UTF8,
     UsernameUtf8 = unicode:characters_to_binary(Username),
     <<CapabilityFlags:32/little,
       ?MAX_BYTES_PER_PACKET:32/little,
@@ -120,8 +142,9 @@ build_handshake_response(Handshake, Username, Password) ->
       Hash/binary>>.
 
 %% @doc Handles the second packet from the server, when we have replied to the
-%% initial handshake. Returns an error if unimplemented features are required.
--spec parse_handshake_confirm(binary()) -> #ok_packet{} | #error_packet{}.
+%% initial handshake. Returns an error if the server returns an error. Raises
+%% an error if unimplemented features are required.
+-spec parse_handshake_confirm(binary()) -> #ok{} | #error{}.
 parse_handshake_confirm(Packet) ->
     case Packet of
         ?ok_pattern ->
@@ -135,45 +158,151 @@ parse_handshake_confirm(Packet) ->
             %% single 0xfe byte. It is sent by server to request client to
             %% switch to Old Password Authentication if CLIENT_PLUGIN_AUTH
             %% capability is not supported (by either the client or the server)"
-            %%
-            %% Simulate an error packet (without code)
-            #error_packet{msg = <<"Old auth method not implemented">>};
+            error({not_implemented, old_auth});
         <<?EOF, _/binary>> ->
             %% "Authentication Method Switch Request Packet. If both server and
             %% client support CLIENT_PLUGIN_AUTH capability, server can send
             %% this packet to ask client to use another authentication method."
-            %%
-            %% Simulate an error packet (without code)
-            #error_packet{msg = <<"Auth method switch not implemented">>}
+            error({not_implemented, auth_method_switch})
     end.
 
-build_query(Query) when is_binary(Query) ->
-    <<?COM_QUERY, Query/binary>>.
+%% @doc Query on a tcp socket.
+query_tcp(Query, Socket, Timeout) ->
+    SendFun = fun (Data) -> gen_tcp:send(Socket, Data) end,
+    RecvFun = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
+    query(Query, SendFun, RecvFun).
 
-%% @doc TODO: Handle result set responses.
--spec parse_query_response(binary()) -> #ok_packet{} | #error_packet{}.
-parse_query_response(Resp) ->
+%% @doc Normally fun gen_tcp:send/2 and fun gen_tcp:recv/3 are used, except in
+%% unit testing.
+query(Query, SendFun, RecvFun) ->
+    Req = <<?COM_QUERY, (iolist_to_binary(Query))/binary>>,
+    SeqNum0 = 0,
+    {ok, SeqNum1} = send_packet(SendFun, Req, SeqNum0),
+    {ok, Resp, SeqNum2} = recv_packet(RecvFun, SeqNum1),
     case Resp of
-        ?ok_pattern -> parse_ok_packet(Resp);
-        ?error_pattern -> parse_error_packet(Resp);
-        _ -> error(result_set_not_implemented)
+        ?ok_pattern ->
+            parse_ok_packet(Resp);
+        ?error_pattern ->
+            parse_error_packet(Resp);
+        _ResultSet ->
+            %% The first packet in a resultset is just the field count.
+            {FieldCount, <<>>} = lenenc_int(Resp),
+            fetch_resultset(RecvFun, FieldCount, SeqNum2)
     end.
 
+-spec fetch_resultset(recvfun(), integer(), integer()) ->
+    #text_resultset{} | #error{}.
+fetch_resultset(RecvFun, FieldCount, SeqNum) ->
+    {ok, ColDefs, SeqNum1} = fetch_column_definitions(RecvFun, SeqNum,
+                                                      FieldCount, []),
+    {ok, DelimiterPacket, SeqNum2} = recv_packet(RecvFun, SeqNum1),
+    case DelimiterPacket of
+        ?eof_pattern ->
+            #eof{} = parse_eof_packet(DelimiterPacket),
+            {ok, Rows, _SeqNum3} = fetch_resultset_rows(RecvFun, ColDefs,
+                                                        SeqNum2, []),
+            #text_resultset{column_definitions = ColDefs, rows = Rows};
+        ?error_pattern ->
+            parse_error_packet(DelimiterPacket)
+    end.
+
+%% Receives NumLeft packets and parses them as column definitions.
+-spec fetch_column_definitions(recvfun(), SeqNum :: integer(),
+                               NumLeft :: integer(), Acc :: [tuple()]) ->
+    {ok, [tuple()], NextSeqNum :: integer()}.
+fetch_column_definitions(RecvFun, SeqNum, NumLeft, Acc) when NumLeft > 0 ->
+    {ok, Packet, SeqNum1} = recv_packet(RecvFun, SeqNum),
+    ColDef = parse_column_definition(Packet),
+    fetch_column_definitions(RecvFun, SeqNum1, NumLeft - 1, [ColDef | Acc]);
+fetch_column_definitions(_RecvFun, SeqNum, 0, Acc) ->
+    {ok, lists:reverse(Acc), SeqNum}.
+
+fetch_resultset_rows(RecvFun, ColDefs, SeqNum, Acc) ->
+    {ok, Packet, SeqNum1} = recv_packet(RecvFun, SeqNum),
+    case Packet of
+        ?eof_pattern ->
+            {ok, lists:reverse(Acc), SeqNum1};
+        _AnotherRow ->
+            Row = parse_resultset_row(ColDefs, Packet, []),
+            fetch_resultset_rows(RecvFun, ColDefs, SeqNum1, [Row | Acc])
+    end.
+
+%% parses Data using ColDefs and builds the values Acc.
+parse_resultset_row([_ColDef | ColDefs], Data, Acc) ->
+    case Data of
+        <<16#fb, Rest/binary>> ->
+            %% NULL
+            parse_resultset_row(ColDefs, Rest, [null | Acc]);
+        _ ->
+            %% Every thing except NULL
+            {Str, Rest} = lenenc_str(Data),
+            parse_resultset_row(ColDefs, Rest, [Str | Acc])
+    end;
+parse_resultset_row([], <<>>, Acc) ->
+    lists:reverse(Acc).
+
+%% Parses a packet containing a column definition (part of a result set)
+parse_column_definition(Data) ->
+    {<<"def">>, Rest1} = lenenc_str(Data),   %% catalog (always "def")
+    {_Schema, Rest2} = lenenc_str(Rest1),    %% schema-name 
+    {_Table, Rest3} = lenenc_str(Rest2),     %% virtual table-name 
+    {_OrgTable, Rest4} = lenenc_str(Rest3),  %% physical table-name 
+    {Name, Rest5} = lenenc_str(Rest4),       %% virtual column name
+    {_OrgName, Rest6} = lenenc_str(Rest5),   %% physical column name
+    {16#0c, Rest7} = lenenc_int(Rest6),      %% length of the following fields
+                                             %% (always 0x0c)
+    <<Charset:16/little,        %% column character set
+      _ColumnLength:32/little,  %% maximum length of the field
+      ColumnType:8,             %% type of the column as defined in Column Type
+      _Flags:16/little,         %% flags
+      _Decimals:8,              %% max shown decimal digits:
+      0,  %% "filler"           %%   - 0x00 for integers and static strings
+      0,                        %%   - 0x1f for dynamic strings, double, float
+      Rest8/binary>> = Rest7,   %%   - 0x00 to 0x51 for decimals
+    %% Here, if command was COM_FIELD_LIST {
+    %%   default values: lenenc_str
+    %% }
+    <<>> = Rest8,
+    #column_definition{name = Name, type = ColumnType, charset = Charset}.
+
 %% --- internal ---
 
-%is_ok_packet(<<?OK, _/binary>>) -> true;
-%is_ok_packet(_)                 -> false;
+%% @doc Wraps Data in packet headers, sends it by calling SendFun and returns
+%% {ok, SeqNum1} where SeqNum1 is the next sequence number.
+-spec send_packet(sendfun(), Data :: binary(), SeqNum :: integer()) ->
+    {ok, NextSeqNum :: integer()}.
+send_packet(SendFun, Data, SeqNum) ->
+    {WithHeaders, SeqNum1} = add_packet_headers(Data, SeqNum),
+    ok = SendFun(WithHeaders),
+    {ok, SeqNum1}.
 
-%is_error_packet(<<?ERROR, _/binary>>) -> true;
-%is_error_packet(_)                    -> false;
+%% @doc Receives data by calling RecvFun and removes the packet headers. Returns
+%% the packet contents and the next packet sequence number.
+-spec recv_packet(RecvFun :: recvfun(), SeqNum :: integer()) ->
+    {ok, Data :: binary(), NextSeqNum :: integer()}.
+recv_packet(RecvFun, SeqNum) ->
+    recv_packet(RecvFun, SeqNum, <<>>).
 
-%is_eof_packet(<<?EOF, _/binary>>) -> true;
-%is_eof_paclet(_)                  -> false;
+%% @doc Receives data by calling RecvFun and removes packet headers. Returns the
+%% data and the next packet sequence number.
+-spec recv_packet(RecvFun :: recvfun(), ExpectSeqNum :: integer(),
+                  Acc :: binary()) ->
+    {ok, Data :: binary(), NextSeqNum :: integer()}.
+recv_packet(RecvFun, ExpectSeqNum, Acc) ->
+    {ok, Header} = RecvFun(4),
+    {Size, ExpectSeqNum, More} = parse_packet_header(Header),
+    {ok, Body} = RecvFun(Size),
+    Acc1 = <<Acc/binary, Body/binary>>,
+    NextSeqNum = (ExpectSeqNum + 1) band 16#ff,
+    case More of
+        false -> {ok, Acc1, NextSeqNum};
+        true  -> recv_packet(RecvFun, NextSeqNum, Acc1)
+    end.
 
--spec parse_ok_packet(binary()) -> #ok_packet{}.
+-spec parse_ok_packet(binary()) -> #ok{}.
 parse_ok_packet(<<?OK:8, Rest/binary>>) ->
-    {AffectedRows, Rest1} = lci(Rest),
-    {InsertId, Rest2} = lci(Rest1),
+    {AffectedRows, Rest1} = lenenc_int(Rest),
+    {InsertId, Rest2} = lenenc_int(Rest1),
     <<StatusFlags:16/little, WarningCount:16/little, Msg/binary>> = Rest2,
     %% We have enabled CLIENT_PROTOCOL_41 but not CLIENT_SESSION_TRACK in the
     %% conditional protocol:
@@ -192,24 +321,24 @@ parse_ok_packet(<<?OK:8, Rest/binary>>) ->
     %% } else {
     %%   string<EOF> info
     %% }
-    #ok_packet{affected_rows = AffectedRows,
-               insert_id = InsertId,
-               status = StatusFlags,
-               warning_count = WarningCount,
-               msg = Msg}.
+    #ok{affected_rows = AffectedRows,
+        insert_id = InsertId,
+        status = StatusFlags,
+        warning_count = WarningCount,
+        msg = Msg}.
 
--spec parse_error_packet(binary()) -> #error_packet{}.
+-spec parse_error_packet(binary()) -> #error{}.
 parse_error_packet(<<?ERROR:8, ErrNo:16/little, "#", SQLState:5/binary-unit:8,
                      Msg/binary>>) ->
     %% Error, 4.1 protocol.
     %% (Older protocol: <<?ERROR:8, ErrNo:16/little, Msg/binary>>)
-    #error_packet{code = ErrNo, state = SQLState, msg = Msg}.
+    #error{code = ErrNo, state = SQLState, msg = Msg}.
 
--spec parse_eof_packet(binary()) -> #eof_packet{}.
+-spec parse_eof_packet(binary()) -> #eof{}.
 parse_eof_packet(<<?EOF:8, NumWarnings:16/little, StatusFlags:16/little>>) ->
     %% EOF packet, 4.1 protocol.
     %% (Older protocol: <<?EOF:8>>)
-    #eof_packet{status = StatusFlags, warning_count = NumWarnings}.
+    #eof{status = StatusFlags, warning_count = NumWarnings}.
 
 -spec hash_password(Password :: iodata(), AuthPluginName :: binary(),
                     AuthPluginData :: binary()) -> binary().
@@ -236,42 +365,42 @@ hash_password(Password, <<"mysql_native_password">>, AuthData) ->
 hash_password(_, AuthPlugin, _) ->
     error({unsupported_auth_method, AuthPlugin}).
 
-%% lci/1 decodes length-coded-integer values
--spec lci(Input :: binary()) -> {Value :: integer(), Rest :: binary()}.
-lci(<<Value:8, Rest/bits>>) when Value < 251 -> {Value, Rest};
-lci(<<16#fc:8, Value:16/little, Rest/binary>>) -> {Value, Rest};
-lci(<<16#fd:8, Value:24/little, Rest/binary>>) -> {Value, Rest};
-lci(<<16#fe:8, Value:64/little, Rest/binary>>) -> {Value, Rest}.
-
-%% lcs/1 decodes length-encoded-string values
--spec lcs(Input :: binary()) -> {String :: binary(), Rest :: binary()}.
-lcs(Bin) ->
-    {Length, Rest} = lci(Bin),
+%% lenenc_int/1 decodes length-encoded-integer values
+-spec lenenc_int(Input :: binary()) -> {Value :: integer(), Rest :: binary()}.
+lenenc_int(<<Value:8, Rest/bits>>) when Value < 251 -> {Value, Rest};
+lenenc_int(<<16#fc:8, Value:16/little, Rest/binary>>) -> {Value, Rest};
+lenenc_int(<<16#fd:8, Value:24/little, Rest/binary>>) -> {Value, Rest};
+lenenc_int(<<16#fe:8, Value:64/little, Rest/binary>>) -> {Value, Rest}.
+
+%% lenenc_str/1 decodes length-encoded-string values
+-spec lenenc_str(Input :: binary()) -> {String :: binary(), Rest :: binary()}.
+lenenc_str(Bin) ->
+    {Length, Rest} = lenenc_int(Bin),
     <<String:Length/binary, Rest1/binary>> = Rest,
     {String, Rest1}.
 
 %% nts/1 decodes a nul-terminated string
--spec nulterm(Input :: binary()) -> {String :: binary(), Rest :: binary()}.
-nulterm(Bin) ->
+-spec nulterm_str(Input :: binary()) -> {String :: binary(), Rest :: binary()}.
+nulterm_str(Bin) ->
     [String, Rest] = binary:split(Bin, <<0>>),
     {String, Rest}.
 
 -ifdef(TEST).
 -include_lib("eunit/include/eunit.hrl").
 
-lci_test() ->
-    ?assertEqual({40, <<>>}, lci(<<40>>)),
-    ?assertEqual({16#ff, <<>>}, lci(<<16#fc, 255, 0>>)),
-    ?assertEqual({16#33aaff, <<>>}, lci(<<16#fd, 16#ff, 16#aa, 16#33>>)),
-    ?assertEqual({16#12345678, <<>>}, lci(<<16#fe, 16#78, 16#56, 16#34, 16#12,
-                                            0, 0, 0, 0>>)),
+lenenc_int_test() ->
+    ?assertEqual({40, <<>>}, lenenc_int(<<40>>)),
+    ?assertEqual({16#ff, <<>>}, lenenc_int(<<16#fc, 255, 0>>)),
+    ?assertEqual({16#33aaff, <<>>}, lenenc_int(<<16#fd, 16#ff, 16#aa, 16#33>>)),
+    ?assertEqual({16#12345678, <<>>}, lenenc_int(<<16#fe, 16#78, 16#56, 16#34,
+                                                 16#12, 0, 0, 0, 0>>)),
     ok.
 
-lcs_test() ->
-    ?assertEqual({<<"Foo">>, <<"bar">>}, lcs(<<3, "Foobar">>)).
+lenenc_str_test() ->
+    ?assertEqual({<<"Foo">>, <<"bar">>}, lenenc_str(<<3, "Foobar">>)).
 
 nulterm_test() ->
-    ?assertEqual({<<"Foo">>, <<"bar">>}, nulterm(<<"Foo", 0, "bar">>)).
+    ?assertEqual({<<"Foo">>, <<"bar">>}, nulterm_str(<<"Foo", 0, "bar">>)).
 
 parse_header_test() ->
     %% Example from "MySQL Internals", revision 307, section 14.1.3.3 EOF_Packet
@@ -280,26 +409,24 @@ parse_header_test() ->
     %% Check header contents and body length
     ?assertEqual({size(Body), 5, false}, parse_packet_header(Header)),
     ok.
-    
+
 add_packet_headers_test() ->
     {Data, 43} = add_packet_headers(<<"foo">>, 42),
     ?assertEqual(<<3, 0, 0, 42, "foo">>, list_to_binary(Data)).
 
 parse_ok_test() ->
     Body = <<0, 5, 1, 2, 0, 0, 0, "Foo">>,
-    ?assertEqual(#ok_packet{affected_rows = 5,
-                            insert_id = 1,
-                            status = ?SERVER_STATUS_AUTOCOMMIT,
-                            warning_count = 0,
-                            msg = <<"Foo">>},
+    ?assertEqual(#ok{affected_rows = 5,
+                     insert_id = 1,
+                     status = ?SERVER_STATUS_AUTOCOMMIT,
+                     warning_count = 0,
+                     msg = <<"Foo">>},
                  parse_ok_packet(Body)).
 
 parse_error_test() ->
     %% Protocol 4.1
     Body = <<255, 42, 0, "#", "XYZxx", "Foo">>,
-    ?assertEqual(#error_packet{code = 42,
-                               state = <<"XYZxx">>,
-                               msg = <<"Foo">>},
+    ?assertEqual(#error{code = 42, state = <<"XYZxx">>, msg = <<"Foo">>},
                  parse_error_packet(Body)),
     ok.
 
@@ -308,9 +435,10 @@ parse_eof_test() ->
     Packet = <<16#05, 16#00, 16#00, 16#05, 16#fe, 16#00, 16#00, 16#02, 16#00>>,
     <<_Header:4/binary-unit:8, Body/binary>> = Packet,
     %% Ignore header. Parse body as an eof_packet.
-    ?assertEqual(#eof_packet{warning_count = 0,
-                             status = ?SERVER_STATUS_AUTOCOMMIT},
+    ?assertEqual(#eof{warning_count = 0,
+                      status = ?SERVER_STATUS_AUTOCOMMIT},
                  parse_eof_packet(Body)),
     ok.
 
+
 -endif.

+ 108 - 0
src/mysql_text_protocol.erl

@@ -0,0 +1,108 @@
+%% @doc This module handles conversion of values in the form they are
+%% represented in the text protocol to our prefered Erlang term representations.
+-module(mysql_text_protocol).
+
+-export([text_to_term/2]).
+
+-include("records.hrl").
+-include("protocol.hrl"). %% The TYPE_* macros.
+
+%% @doc When receiving data in the text protocol, we get everything as binaries
+%% (except NULL). This function is used to parse these strings values.
+text_to_term(Type, Text) when is_binary(Text) ->
+    case Type of
+        ?TYPE_DECIMAL -> parse_float(Text); %% <-- this will probably change
+        ?TYPE_TINY -> binary_to_integer(Text);
+        ?TYPE_SHORT -> binary_to_integer(Text);
+        ?TYPE_LONG -> binary_to_integer(Text);
+        ?TYPE_FLOAT -> parse_float(Text);
+        ?TYPE_DOUBLE -> parse_float(Text);
+        ?TYPE_TIMESTAMP -> parse_datetime(Text);
+        ?TYPE_LONGLONG -> binary_to_integer(Text);
+        ?TYPE_INT24 -> binary_to_integer(Text);
+        ?TYPE_DATE -> parse_date(Text);
+        ?TYPE_TIME -> parse_time(Text);
+        ?TYPE_DATETIME -> parse_datetime(Text);
+        ?TYPE_YEAR -> binary_to_integer(Text);
+        ?TYPE_VARCHAR -> Text;
+        ?TYPE_BIT -> binary_to_integer(Text);
+        ?TYPE_NEWDECIMAL -> parse_float(Text); %% <-- this will probably change
+        ?TYPE_ENUM -> Text;
+        ?TYPE_SET when Text == <<>> -> sets:new();
+        ?TYPE_SET -> sets:from_list(binary:split(Text, <<",">>, [global]));
+        ?TYPE_TINY_BLOB -> Text; %% charset?
+        ?TYPE_MEDIUM_BLOB -> Text;
+        ?TYPE_LONG_BLOB -> Text;
+        ?TYPE_BLOB -> Text;
+        ?TYPE_VAR_STRING -> Text;
+        ?TYPE_STRING -> Text;
+        ?TYPE_GEOMETRY -> Text %% <-- what do we want here?
+    end;
+text_to_term(_, null) ->
+    %% NULL is the only value not represented as a binary.
+    null.
+
+parse_datetime(<<Y:4/binary, "-", M:2/binary, "-", D:2/binary, " ",
+                 H:2/binary, ":", Mi:2/binary, ":", S:2/binary>>) ->
+    {{binary_to_integer(Y), binary_to_integer(M), binary_to_integer(D)},
+     {binary_to_integer(H), binary_to_integer(Mi), binary_to_integer(S)}}.
+
+parse_date(<<Y:4/binary, "-", M:2/binary, "-", D:2/binary>>) ->
+    {binary_to_integer(Y), binary_to_integer(M), binary_to_integer(D)}.
+
+parse_time(<<H:2/binary, ":", Mi:2/binary, ":", S:2/binary>>) ->
+    {binary_to_integer(H), binary_to_integer(Mi), binary_to_integer(S)}.
+
+parse_float(Text) ->
+    try binary_to_float(Text)
+    catch error:badarg ->
+        try binary_to_integer(Text) of
+            Int -> float(Int)
+        catch error:badarg ->
+            %% It is something like "4e75" that must be turned into "4.0e75"
+            binary_to_float(binary:replace(Text, <<"e">>, <<".0e">>))
+        end
+    end.
+
+-ifdef(TEST).
+-include_lib("eunit/include/eunit.hrl").
+
+text_to_term_test() ->
+    %% Int types
+    lists:foreach(fun (T) -> ?assertEqual(1, text_to_term(T, <<"1">>)) end,
+                  [?TYPE_TINY, ?TYPE_SHORT, ?TYPE_LONG, ?TYPE_LONGLONG,
+                   ?TYPE_INT24, ?TYPE_YEAR, ?TYPE_BIT]),
+
+    %% Floating point and decimal numbers
+    lists:foreach(fun (T) -> ?assertEqual(3.0, text_to_term(T, <<"3.0">>)) end,
+                  [?TYPE_FLOAT, ?TYPE_DOUBLE, ?TYPE_DECIMAL, ?TYPE_NEWDECIMAL]),
+    ?assertEqual(3.0,  text_to_term(?TYPE_FLOAT, <<"3">>)),
+    ?assertEqual(30.0, text_to_term(?TYPE_FLOAT, <<"3e1">>)),
+    ?assertEqual(3,    text_to_term(?TYPE_LONG, <<"3">>)),
+
+    %% Date and time
+    ?assertEqual({2014, 11, 01}, text_to_term(?TYPE_DATE, <<"2014-11-01">>)),
+    ?assertEqual({23, 59, 01}, text_to_term(?TYPE_TIME, <<"23:59:01">>)),
+    ?assertEqual({{2014, 11, 01}, {23, 59, 01}},
+                 text_to_term(?TYPE_DATETIME, <<"2014-11-01 23:59:01">>)),
+    ?assertEqual({{2014, 11, 01}, {23, 59, 01}},
+                 text_to_term(?TYPE_TIMESTAMP, <<"2014-11-01 23:59:01">>)),
+
+    %% Strings and blobs
+    lists:foreach(fun (T) ->
+                      ?assertEqual(<<"x">>, text_to_term(T, <<"x">>))
+                  end,
+                  [?TYPE_VARCHAR, ?TYPE_ENUM, ?TYPE_TINY_BLOB,
+                   ?TYPE_MEDIUM_BLOB, ?TYPE_LONG_BLOB, ?TYPE_BLOB,
+                   ?TYPE_VAR_STRING, ?TYPE_STRING, ?TYPE_GEOMETRY]),
+
+    %% Set
+    ?assertEqual(sets:from_list([<<"b">>, <<"a">>]),
+                 text_to_term(?TYPE_SET, <<"a,b">>)),
+    ?assertEqual(sets:from_list([]), text_to_term(?TYPE_SET, <<>>)),
+
+    %% NULL
+    ?assertEqual(null, text_to_term(?TYPE_FLOAT, null)),
+    ok.
+
+-endif.

+ 46 - 0
test/mysql_tests.erl

@@ -0,0 +1,46 @@
+%% @doc This module performs test to an actual database.
+-module(mysql_tests).
+
+-include_lib("eunit/include/eunit.hrl").
+
+-define(user,     "otptest").
+-define(password, "otptest").
+
+connect_test() ->
+    {ok, Pid} = mysql:connect([{user, ?user}, {password, ?password}]),
+
+    %% A query without a result set
+    ?assertEqual(ok, mysql:query(Pid, <<"USE otptest">>)),
+
+    ?assertEqual(ok, mysql:disconnect(Pid)).
+
+query_test_() ->
+    {setup,
+     fun () ->
+         {ok, Pid} = mysql:connect([{user, ?user}, {password, ?password}]),
+         %ok = mysql:query(Pid, <<"DROP DATABASE IF EXISTS otptest">>),
+         %ok = mysql:query(Pid, <<"CREATE DATABASE otptest">>),
+         ok = mysql:query(Pid, <<"USE otptest">>),
+         Pid
+     end,
+     fun (Pid) ->
+         mysql:disconnect(Pid)
+     end,
+     {with, [fun basic_queries/1]}}.
+
+basic_queries(Pid) ->
+
+    %% warning count
+    ?assertEqual(ok, mysql:query(Pid, <<"DROP TABLE IF EXISTS foo">>)),
+    ?assertEqual(1, mysql:warning_count(Pid)),
+
+    %% SQL parse error
+    ?assertMatch({error, {1064, <<"42000">>, <<"You have an erro", _/binary>>}},
+                 mysql:query(Pid, <<"FOO">>)),
+
+    %% Simple resultset with various types
+    ?assertEqual({ok, [<<"i">>, <<"s">>], [[42, <<"foo">>]]},
+                 mysql:query(Pid, <<"SELECT 42 AS i, 'foo' AS s;">>)),
+
+    %{ok, Fields, Rows} = mysql:query(Pid, <<"SELECT * FROM settest">>),
+    ok.

+ 198 - 0
test/protocol_tests.erl

@@ -0,0 +1,198 @@
+%% @doc Eunit test cases for the mysql_protocol module.
+-module(protocol_tests).
+
+-include_lib("eunit/include/eunit.hrl").
+
+-include("protocol.hrl").
+-include("records.hrl").
+
+resultset_test() ->
+    %% A query that returns a result set in the text protocol.
+    Query = <<"SELECT @@version_comment">>,
+    ExpectedReq = <<(size(Query) + 1):24/little, 0, ?COM_QUERY, Query/binary>>,
+    ExpectedResponse = hexdump_to_bin(
+        "01 00 00 01 01|27 00 00    02 03 64 65 66 00 00 00    .....'....def..."
+        "11 40 40 76 65 72 73 69    6f 6e 5f 63 6f 6d 6d 65    .@@version_comme"
+        "6e 74 00 0c 08 00 1c 00    00 00 fd 00 00 1f 00 00|   nt.............."
+        "05 00 00 03 fe 00 00 02    00|1d 00 00 04 1c 4d 79    ..............My"
+        "53 51 4c 20 43 6f 6d 6d    75 6e 69 74 79 20 53 65    SQL Community Se"
+        "72 76 65 72 20 28 47 50    4c 29|05 00 00 05 fe 00    rver (GPL)......"
+        "00 02 00                                              ..."),
+    ExpectedCommunication = [{send, ExpectedReq},
+                             {recv, ExpectedResponse}],
+    FakeSock = fakesocket_create(ExpectedCommunication),
+    SendFun = fun (Data) -> fakesocket_send(FakeSock, Data) end,
+    RecvFun = fun (Size) -> fakesocket_recv(FakeSock, Size) end,
+    ResultSet = mysql_protocol:query(Query, SendFun, RecvFun),
+    fakesocket_close(FakeSock),
+    ?assertMatch(#text_resultset{column_definitions =
+                                     [#column_definition{
+                                          name = <<"@@version_comment">>}],
+                                 rows = [[<<"MySQL Community Server (GPL)">>]]},
+                 ResultSet),
+    ok.
+
+resultset_error_test() ->
+    %% A query that returns a response starting as a result set but then
+    %% interupts itself and decides that it is an error.
+    Query = <<"EXPLAIN SELECT * FROM dual;">>,
+    ExpectedReq = <<(size(Query) + 1):24/little, 0, ?COM_QUERY, Query/binary>>,
+    ExpectedResponse = hexdump_to_bin(
+        "01 00 00 01 0a 18 00 00    02 03 64 65 66 00 00 00    ..........def..."
+        "02 69 64 00 0c 3f 00 03    00 00 00 08 a1 00 00 00    .id..?.........."
+        "00 21 00 00 03 03 64 65    66 00 00 00 0b 73 65 6c    .!....def....sel"
+        "65 63 74 5f 74 79 70 65    00 0c 08 00 13 00 00 00    ect_type........"
+        "fd 01 00 1f 00 00 1b 00    00 04 03 64 65 66 00 00    ...........def.."
+        "00 05 74 61 62 6c 65 00    0c 08 00 40 00 00 00 fd    ..table....@...."
+        "00 00 1f 00 00 1a 00 00    05 03 64 65 66 00 00 00    ..........def..."
+        "04 74 79 70 65 00 0c 08    00 0a 00 00 00 fd 00 00    .type..........."
+        "1f 00 00 23 00 00 06 03    64 65 66 00 00 00 0d 70    ...#....def....p"
+        "6f 73 73 69 62 6c 65 5f    6b 65 79 73 00 0c 08 00    ossible_keys...."
+        "00 10 00 00 fd 00 00 1f    00 00 19 00 00 07 03 64    ...............d"
+        "65 66 00 00 00 03 6b 65    79 00 0c 08 00 40 00 00    ef....key....@.."
+        "00 fd 00 00 1f 00 00 1d    00 00 08 03 64 65 66 00    ............def."
+        "00 00 07 6b 65 79 5f 6c    65 6e 00 0c 08 00 00 10    ...key_len......"
+        "00 00 fd 00 00 1f 00 00    19 00 00 09 03 64 65 66    .............def"
+        "00 00 00 03 72 65 66 00    0c 08 00 00 04 00 00 fd    ....ref........."
+        "00 00 1f 00 00 1a 00 00    0a 03 64 65 66 00 00 00    ..........def..."
+        "04 72 6f 77 73 00 0c 3f    00 0a 00 00 00 08 a0 00    .rows..?........"
+        "00 00 00 1b 00 00 0b 03    64 65 66 00 00 00 05 45    ........def....E"
+        "78 74 72 61 00 0c 08 00    ff 00 00 00 fd 01 00 1f    xtra............"
+        "00 00 05 00 00 0c fe 00    00 02 00 17 00 00 0d ff    ................"
+        "48 04 23 48 59 30 30 30    4e 6f 20 74 61 62 6c 65    H.#HY000No table"
+        "73 20 75 73 65 64                                     s used"),
+    Sock = fakesocket_create([{send, ExpectedReq}, {recv, ExpectedResponse}]),
+    SendFun = fun (Data) -> fakesocket_send(Sock, Data) end,
+    RecvFun = fun (Size) -> fakesocket_recv(Sock, Size) end,
+    Result = mysql_protocol:query(Query, SendFun, RecvFun),
+    ?assertMatch(#error{}, Result),
+    fakesocket_close(Sock),
+    ok.
+
+%% --- Helper functions for the above tests ---
+
+%% Convert hex dumps to binaries. This is a helper function for the tests.
+%% This function is also tested below.
+hexdump_to_bin(HexDump) ->
+    hexdump_to_bin(iolist_to_binary(HexDump), <<>>).
+
+hexdump_to_bin(<<Line:50/binary, _Junk:20/binary, Rest/binary>>, Acc) ->
+    hexdump_to_bin(Line, Rest, Acc);
+hexdump_to_bin(<<Line:50/binary, _Junk/binary>>, Acc) ->
+    %% last line (shorter than 70)
+    hexdump_to_bin(Line, <<>>, Acc);
+hexdump_to_bin(<<>>, Acc) ->
+    Acc.
+
+hexdump_to_bin(Line, Rest, Acc) ->
+    HexNums = re:split(Line, <<"[ |]+">>, [{return, list}, trim]),
+    Acc1 = lists:foldl(fun (HexNum, Acc0) ->
+                           {ok, [Byte], []} = io_lib:fread("~16u", HexNum),
+                           <<Acc0/binary, Byte:8>>
+                       end,
+                       Acc,
+                       HexNums),
+    hexdump_to_bin(Rest, Acc1).
+
+hexdump_to_bin_test() ->
+    HexDump =
+        "0e 00 00 00 03 73 65 6c    65 63 74 20 55 53 45 52    .....select USER"
+        "28 29                                                 ()",
+    Expect = <<16#0e, 16#00, 16#00, 16#00, 16#03, 16#73, 16#65, 16#6c,
+               16#65, 16#63, 16#74, 16#20, 16#55, 16#53, 16#45, 16#52,
+               16#28, 16#29>>,
+    ?assertEqual(Expect, hexdump_to_bin(HexDump)).
+
+%% --- Fake socket ---
+%%
+%% A "fake socket" is used in test where we need to mock socket communication.
+%% It is a pid maintaining a list of expected send and recv events.
+
+%% @doc Creates a fakesocket process with a buffer of expected recv and send
+%% calls. The pid of the fakesocket process is returned.
+-spec fakesocket_create([{recv, binary()} | {send, binary()}]) -> pid().
+fakesocket_create(ExpectedEvents) ->
+    spawn_link(fun () -> fakesocket_loop(ExpectedEvents) end).
+
+%% @doc Receives NumBytes bytes from fakesocket Pid. This function can be used
+%% as a replacement for gen_tcp:recv/2 in unit tests. If there not enough data
+%% in the fakesocket's buffer, an error is raised.
+fakesocket_recv(Pid, NumBytes) ->
+    Pid ! {recv, NumBytes, self()},
+    receive
+        {ok, Data} -> {ok, Data};
+        error -> error({unexpected_recv, NumBytes})
+    after 100 ->
+        error(noreply)
+    end.
+
+%% @doc Sends data to fa fakesocket. This can be used as replacement for
+%% gen_tcp:send/2 in unit tests. If the data sent is not what the fakesocket
+%% expected, an error is raised.
+fakesocket_send(Pid, Data) ->
+    Pid ! {send, iolist_to_binary(Data), self()},
+    receive
+        ok -> ok;
+        error -> error({unexpected_send, Data})
+    after 100 ->
+        error(noreply)
+    end.
+
+%% Stops the fakesocket process. If the fakesocket's buffer is not empty,
+%% an error is raised.
+fakesocket_close(Pid) ->
+    Pid ! {done, self()},
+    receive
+        ok -> ok;
+        {remains, Remains} -> error({unexpected_close, Remains})
+    after 100 ->
+        error(noreply)
+    end.
+
+%% Used by fakesocket_create/1.
+fakesocket_loop(AllEvents = [{Func, Data} | Events]) ->
+    receive
+        {recv, NumBytes, FromPid} when Func == recv, NumBytes == size(Data) ->
+            FromPid ! {ok, Data},
+            fakesocket_loop(Events);
+        {recv, NumBytes, FromPid} when Func == recv, NumBytes < size(Data) ->
+            <<Data1:NumBytes/binary, Rest/binary>> = Data,
+            FromPid ! {ok, Data1},
+            fakesocket_loop([{recv, Rest} | Events]);
+        {send, Bytes, FromPid} when Func == send, Bytes == Data ->
+            FromPid ! ok,
+            fakesocket_loop(Events);
+        {send, Bytes, FromPid} when Func == send, size(Bytes) < size(Data) ->
+            Size = size(Bytes),
+            case Data of
+                <<Bytes:Size/binary, Rest/binary>> ->
+                    FromPid ! ok,
+                    fakesocket_loop([{send, Rest} | Events]);
+                _ ->
+                    FromPid ! error
+            end;
+        {_, _, FromPid} ->
+            FromPid ! error;
+        {done, FromPid} ->
+            FromPid ! {remains, AllEvents}
+    end;
+fakesocket_loop([]) ->
+    receive
+        {done, FromPid} -> FromPid ! ok;
+        {_, _, FromPid} -> FromPid ! error
+    end.
+
+%% Tests for the fakesocket functions.
+fakesocket_bad_recv_test() ->
+    Pid = fakesocket_create([{recv, <<"foobar">>}]),
+    ?assertError(_, fakesocket_recv(Pid, 10)).
+
+fakesocket_success_test() ->
+    Pid = fakesocket_create([{recv, <<"foobar">>}, {send, <<"baz">>}]),
+    %?assertError({unexpected_close, _}, fakesocket_close(Pid)),
+    ?assertEqual({ok, <<"foo">>}, fakesocket_recv(Pid, 3)),
+    ?assertEqual({ok, <<"bar">>}, fakesocket_recv(Pid, 3)),
+    ?assertEqual(ok, fakesocket_send(Pid, <<"baz">>)),
+    ?assertEqual(ok, fakesocket_close(Pid)),
+    %% The process will exit after close. Another recv will raise noreply.
+    ?assertError(noreply, fakesocket_recv(Pid, 3)).