Browse Source

Prepared statements without params (bin protocol)

Viktor Söderqvist 10 years ago
parent
commit
09a2adfcb2
7 changed files with 420 additions and 106 deletions
  1. 3 3
      include/records.hrl
  2. 16 4
      src/mysql.erl
  3. 64 0
      src/mysql_binary.erl
  4. 35 5
      src/mysql_connection.erl
  5. 258 82
      src/mysql_protocol.erl
  6. 4 4
      test/mysql_protocol_tests.erl
  7. 40 8
      test/mysql_tests.erl

+ 3 - 3
include/records.hrl

@@ -25,9 +25,9 @@
 -record(column_definition, {name, type, charset}).
 -record(column_definition, {name, type, charset}).
 
 
 %% A resultset as received from the server using the text protocol.
 %% A resultset as received from the server using the text protocol.
-%% All values are binary (SQL code) except NULL.
--record(text_resultset, {column_definitions :: [#column_definition{}],
-                         rows :: [[binary() | null]]}).
+%% For text protocol resultsets, rows :: [[binary() | null]].
+-record(resultset, {column_definitions :: [#column_definition{}],
+                    rows :: [[term()]]}).
 
 
 %% Response of a successfull prepare call.
 %% Response of a successfull prepare call.
 -record(prepared, {statement_id :: integer(),
 -record(prepared, {statement_id :: integer(),

+ 16 - 4
src/mysql.erl

@@ -1,8 +1,12 @@
 %% @doc MySQL/OTP
 %% @doc MySQL/OTP
 -module(mysql).
 -module(mysql).
 
 
--export([connect/1, disconnect/1, query/2, warning_count/1, affected_rows/1,
-         insert_id/1]).
+-export([connect/1, disconnect/1, query/2, query/3, prepare/2, warning_count/1,
+         affected_rows/1, insert_id/1]).
+
+%% @doc A MySQL error with the codes and message returned from the server.
+-type reason() :: {Code :: integer(), SQLState :: binary(),
+                   Message :: binary()}.
 
 
 -spec connect(list()) -> {ok, pid()} | ignore | {error, term()}.
 -spec connect(list()) -> {ok, pid()} | ignore | {error, term()}.
 connect(Opts) ->
 connect(Opts) ->
@@ -18,11 +22,19 @@ disconnect(Conn) ->
          Query :: iodata(),
          Query :: iodata(),
          Fields :: [binary()],
          Fields :: [binary()],
          Rows :: [[term()]],
          Rows :: [[term()]],
-         Reason :: {Code :: integer(), SQLState :: binary(),
-                    Message :: binary()}.
+         Reason :: reason().
 query(Conn, Query) ->
 query(Conn, Query) ->
     gen_server:call(Conn, {query, Query}).
     gen_server:call(Conn, {query, Query}).
 
 
+%% @doc Executes a prepared statement.
+query(Conn, StatementId, Args) ->
+    gen_server:call(Conn, {query, StatementId, Args}).
+
+-spec prepare(Conn :: pid(), Query :: iodata()) ->
+    {ok, StatementId :: integer()} | {error, Reason :: reason()}.
+prepare(Conn, Query) ->
+    gen_server:call(Conn, {prepare, Query}).
+
 -spec warning_count(pid()) -> integer().
 -spec warning_count(pid()) -> integer().
 warning_count(Conn) ->
 warning_count(Conn) ->
     gen_server:call(Conn, warning_count).
     gen_server:call(Conn, warning_count).

+ 64 - 0
src/mysql_binary.erl

@@ -0,0 +1,64 @@
+%% MySQL/OTP – a MySQL driver for Erlang/OTP
+%% Copyright (C) 2014 Viktor Söderqvist
+%%
+%% This program is free software: you can redistribute it and/or modify
+%% it under the terms of the GNU General Public License as published by
+%% the Free Software Foundation, either version 3 of the License, or
+%% (at your option) any later version.
+%%
+%% This program is distributed in the hope that it will be useful,
+%% but WITHOUT ANY WARRANTY; without even the implied warranty of
+%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+%% GNU General Public License for more details.
+%%
+%% You should have received a copy of the GNU General Public License
+%% along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+%% @doc The MySQL binary protocol is used for prepared statements. This module
+%% is used mainly from the mysql_protocol module.
+-module(mysql_binary).
+
+-export([null_bitmap_decode/3, null_bitmap_encode/2]).
+
+%% @doc Decodes a null bitmap as stored by MySQL and returns it in a strait
+%% bitstring from left to right. Returns it together with the rest of the data.
+%%
+%% In the MySQL null bitmap the bits are stored counting bytes from the left and
+%% bits within each byte from the right. (Sort of little endian.)
+-spec null_bitmap_decode(NumColumns :: integer(), BitOffset :: integer(),
+                         Data :: binary()) ->
+    {NullBitstring :: bitstring(), Rest :: binary()}.
+null_bitmap_decode(NumColumns, Data, BitOffset) ->
+    %% Binary shift right by 3 is equivallent to integer division by 8.
+    BitMapLength = (NumColumns + BitOffset + 7) bsr 3,
+    <<NullBitstring0:BitMapLength/binary, Rest/binary>> = Data,
+    <<_:BitOffset, NullBitstring:NumColumns/bitstring, _/bitstring>> =
+        << <<(reverse_byte(B))/binary>> || <<B:1/binary>> <= NullBitstring0 >>,
+    {NullBitstring, Rest}.
+
+%% @doc The reverse of null_bitmap_decode/3. The number of columns is taken to
+%% be the number of bits in NullBitstring. Returns the MySQL null bitmap as a
+%% binary (i.e. full bytes). BitOffset is the number of unused bits that should
+%% be inserted before the other bits.
+-spec null_bitmap_encode(bitstring(), integer()) -> binary().
+null_bitmap_encode(NullBitstring, BitOffset) ->
+    PayloadLength = bit_size(NullBitstring) + BitOffset,
+    %% Round up to a multiple of 8.
+    BitMapLength = (PayloadLength + 7) band bnot 7,
+    PadBitsLength = BitMapLength - PayloadLength,
+    PaddedBitstring = <<0:BitOffset, NullBitstring/bitstring, 0:PadBitsLength>>,
+    << <<(reverse_byte(B))/binary>> || <<B:1/binary>> <= PaddedBitstring >>.
+
+%% Reverses the bits in a byte.
+reverse_byte(<<A:1, B:1, C:1, D:1, E:1, F:1, G:1, H:1>>) ->
+    <<H:1, G:1, F:1, E:1, D:1, C:1, B:1, A:1>>.
+
+-ifdef(TEST).
+-include_lib("eunit/include/eunit.hrl").
+
+null_bitmap_test() ->
+    ?assertEqual({<<0, 1:1>>, <<>>}, null_bitmap_decode(9, <<0, 4>>, 2)),
+    ?assertEqual(<<0, 4>>, null_bitmap_encode(<<0, 1:1>>, 2)),
+    ok.
+
+-endif.

+ 35 - 5
src/mysql_connection.erl

@@ -17,7 +17,7 @@
 
 
 %% Gen_server state
 %% Gen_server state
 -record(state, {socket, timeout = infinity, affected_rows = 0, status = 0,
 -record(state, {socket, timeout = infinity, affected_rows = 0, status = 0,
-                warning_count = 0, insert_id = 0}).
+                warning_count = 0, insert_id = 0, stmts = dict:new()}).
 
 
 %% A tuple representing a MySQL server error, typically returned in the form
 %% A tuple representing a MySQL server error, typically returned in the form
 %% {error, reason()}.
 %% {error, reason()}.
@@ -62,18 +62,48 @@ handle_call({query, Query}, _From, State) when is_binary(Query);
             {reply, ok, State1};
             {reply, ok, State1};
         #error{} = E ->
         #error{} = E ->
             {reply, {error, error_to_reason(E)}, State1};
             {reply, {error, error_to_reason(E)}, State1};
-        #text_resultset{column_definitions = ColDefs, rows = Rows} ->
+        #resultset{column_definitions = ColDefs, rows = Rows} ->
             Names = [Def#column_definition.name || Def <- ColDefs],
             Names = [Def#column_definition.name || Def <- ColDefs],
             Rows1 = decode_text_rows(ColDefs, Rows),
             Rows1 = decode_text_rows(ColDefs, Rows),
             {reply, {ok, Names, Rows1}, State1}
             {reply, {ok, Names, Rows1}, State1}
     end;
     end;
+handle_call({query, Stmt, Args}, _From, State) when is_integer(Stmt);
+                                                    is_atom(Stmt) ->
+    StmtRec = dict:fetch(Stmt, State#state.stmts),
+    #state{socket = Socket, timeout = Timeout} = State,
+    SendFun = fun (Data) -> gen_tcp:send(Socket, Data) end,
+    RecvFun = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
+    Rec = mysql_protocol:execute(StmtRec, Args, SendFun, RecvFun),
+    State1 = update_state(State, Rec),
+    case Rec of
+        #ok{} ->
+            {reply, ok, State1};
+        #error{} = E ->
+            {reply, {error, error_to_reason(E)}, State1};
+        #resultset{column_definitions = ColDefs, rows = Rows} ->
+            Names = [Def#column_definition.name || Def <- ColDefs],
+            {reply, {ok, Names, Rows}, State1}
+    end;
+handle_call({prepare, Query}, _From, State) ->
+    #state{socket = Socket, timeout = Timeout} = State,
+    SendFun = fun (Data) -> gen_tcp:send(Socket, Data) end,
+    RecvFun = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
+    Rec = mysql_protocol:prepare(Query, SendFun, RecvFun),
+    State1 = update_state(State, Rec),
+    case Rec of
+        #error{} = E ->
+            {reply, {error, error_to_reason(E)}, State1};
+        #prepared{statement_id = Id} = Stmt ->
+            Stmts1 = dict:store(Id, Stmt, State1#state.stmts),
+            State2 = State#state{stmts = Stmts1},
+            {reply, {ok, Id}, State2}
+    end;
 handle_call(warning_count, _From, State) ->
 handle_call(warning_count, _From, State) ->
     {reply, State#state.warning_count, State};
     {reply, State#state.warning_count, State};
 handle_call(insert_id, _From, State) ->
 handle_call(insert_id, _From, State) ->
     {reply, State#state.insert_id, State};
     {reply, State#state.insert_id, State};
-handle_call(status_flags, _From, State) ->
-    %% Bitmask of status flags from the last ok packet, etc.
-    {reply, State#state.status, State}.
+handle_call(affected_rows, _From, State) ->
+    {reply, State#state.affected_rows, State}.
 
 
 handle_cast(_, _) -> todo.
 handle_cast(_, _) -> todo.
 
 

+ 258 - 82
src/mysql_protocol.erl

@@ -9,7 +9,7 @@
 
 
 -export([handshake/5,
 -export([handshake/5,
          query/3,
          query/3,
-         prepare/3]).
+         prepare/3, execute/4]).
 
 
 -export_type([sendfun/0, recvfun/0]).
 -export_type([sendfun/0, recvfun/0]).
 
 
@@ -27,36 +27,6 @@
 -define(error_pattern, <<?ERROR, _/binary>>).
 -define(error_pattern, <<?ERROR, _/binary>>).
 -define(eof_pattern, <<?EOF, _:4/binary>>).
 -define(eof_pattern, <<?EOF, _:4/binary>>).
 
 
-%% @doc Parses a packet header (32 bits) and returns a tuple.
-%%
-%% The client should first read a header and parse it. Then read PacketLength
-%% bytes. If there are more packets, read another header and read a new packet
-%% length of payload until there are no more packets. The seq num should
-%% increment from 0 and may wrap around at 255 back to 0.
-%%
-%% When all packets are read and the payload of all packets are concatenated, it
-%% can be parsed using parse_response/1, etc. depending on what type of response
-%% is expected.
--spec parse_packet_header(PackerHeader :: binary()) ->
-    {PacketLength :: integer(),
-     SeqNum :: integer(),
-     MorePacketsExist :: boolean()}.
-parse_packet_header(<<PacketLength:24/little-integer, SeqNum:8/integer>>) ->
-    {PacketLength, SeqNum, PacketLength == 16#ffffff}.
-
-%% @doc Splits a packet body into chunks and wraps them in headers. The
-%% resulting list is ready to sent to the socket.
--spec add_packet_headers(PacketBody :: iodata(), SeqNum :: integer()) ->
-    {PacketWithHeaders :: iodata(), NextSeqNum :: integer()}.
-add_packet_headers(PacketBody, SeqNum) ->
-    Bin = iolist_to_binary(PacketBody),
-    Size = size(Bin),
-    SeqNum1 = (SeqNum + 1) rem 16#100,
-    %% Todo: implement the case when Size >= 16#ffffff.
-    if Size < 16#ffffff ->
-        {[<<Size:24/little, SeqNum:8>>, Bin], SeqNum1}
-    end.
-
 %% @doc Performs a handshake using the supplied functions for communication.
 %% @doc Performs a handshake using the supplied functions for communication.
 %% Returns an ok or an error record. Raises errors when various unimplemented
 %% Returns an ok or an error record. Raises errors when various unimplemented
 %% features are requested.
 %% features are requested.
@@ -75,6 +45,109 @@ handshake(Username, Password, Database, SendFun, RecvFun) ->
     {ok, ConfirmPacket, _SeqNum3} = recv_packet(RecvFun, SeqNum2),
     {ok, ConfirmPacket, _SeqNum3} = recv_packet(RecvFun, SeqNum2),
     parse_handshake_confirm(ConfirmPacket).
     parse_handshake_confirm(ConfirmPacket).
 
 
+-spec query(Query :: iodata(), sendfun(), recvfun()) ->
+    #ok{} | #error{} | #resultset{}.
+query(Query, SendFun, RecvFun) ->
+    Req = <<?COM_QUERY, (iolist_to_binary(Query))/binary>>,
+    SeqNum0 = 0,
+    {ok, SeqNum1} = send_packet(SendFun, Req, SeqNum0),
+    {ok, Resp, SeqNum2} = recv_packet(RecvFun, SeqNum1),
+    case Resp of
+        ?ok_pattern ->
+            parse_ok_packet(Resp);
+        ?error_pattern ->
+            parse_error_packet(Resp);
+        _ResultSet ->
+            %% The first packet in a resultset is only the column count.
+            {ColumnCount, <<>>} = lenenc_int(Resp),
+            ResultSet = fetch_resultset(RecvFun, ColumnCount, SeqNum2),
+            %% TODO: Factor out parsing the rows from fetch_resultset/3 and do
+            %% that here instead.
+            ResultSet
+    end.
+
+%% @doc Prepares a statement.
+-spec prepare(iodata(), sendfun(), recvfun()) -> #error{} | #prepared{}.
+prepare(Query, SendFun, RecvFun) ->
+    Req = <<?COM_STMT_PREPARE, (iolist_to_binary(Query))/binary>>,
+    {ok, SeqNum1} = send_packet(SendFun, Req, 0),
+    {ok, Resp, SeqNum2} = recv_packet(RecvFun, SeqNum1),
+    case Resp of
+        ?error_pattern ->
+            parse_error_packet(Resp);
+        <<?OK,
+          StmtId:32/little,
+          NumColumns:16/little,
+          NumParams:16/little,
+          0, %% reserved_1 -- [00] filler
+          WarningCount:16/little>> ->
+            %% This was the first packet.
+            %% If NumParams > 0 more packets will follow:
+            {ok, ParamDefs, SeqNum3} =
+                fetch_column_definitions(RecvFun, SeqNum2, NumParams, []),
+            %% The eof packet is not here in mysql 5.6 but it's in the examples.
+            SeqNum4 = case NumParams of
+                0 ->
+                    SeqNum3;
+                _ ->
+                    {ok, ?eof_pattern, SeqNum3x} = recv_packet(RecvFun, SeqNum3),
+                    SeqNum3x
+            end,
+            {ok, ColDefs, SeqNum5} =
+                fetch_column_definitions(RecvFun, SeqNum4, NumColumns, []),
+            {ok, ?eof_pattern, _SeqNum6} = recv_packet(RecvFun, SeqNum5),
+            #prepared{statement_id = StmtId,
+                      params = ParamDefs,
+                      columns = ColDefs,
+                      warning_count = WarningCount}
+    end.
+
+%% @doc Executes a prepared statement.
+-spec execute(#prepared{}, [term()], sendfun(), recvfun()) -> #resultset{}.
+execute(#prepared{statement_id = Id, params = ParamDefs}, ParamValues,
+        SendFun, RecvFun) ->
+    %% Flags Constant Name
+    %% 0x00 CURSOR_TYPE_NO_CURSOR
+    %% 0x01 CURSOR_TYPE_READ_ONLY
+    %% 0x02 CURSOR_TYPE_FOR_UPDATE
+    %% 0x04 CURSOR_TYPE_SCROLLABLE
+    Flags = 0,
+    Req0 = <<?COM_STMT_EXECUTE, Id:32/little, Flags, 1:32/little>>,
+    Req = case ParamDefs of
+        [] ->
+            Req0;
+        _ ->
+            Types = [Def#column_definition.type || Def <- ParamDefs],
+            NullBitMap = build_null_bitmap(Types, ParamValues),
+            NewParamsBoundFlag = 1,
+            Req1 = <<Req0/binary, NullBitMap/binary, NewParamsBoundFlag>>,
+            %% Append type and signedness (16#80 signed or 00 unsigned)
+            %% for each value
+            lists:foldl(
+                fun ({Type, Value}, Acc) ->
+                    BinValue = binary_encode(Type, Value),
+                    Signedness = 0, %% Hmm.....
+                    <<Acc/binary, Type, Signedness, BinValue/binary>>
+                end,
+                Req1,
+                lists:zip(Types, ParamValues)
+            )
+    end,
+    {ok, SeqNum1} = send_packet(SendFun, Req, 0),
+    {ok, Resp, SeqNum2} = recv_packet(RecvFun, SeqNum1),
+    case Resp of
+        ?ok_pattern ->
+            parse_ok_packet(Resp);
+        ?error_pattern ->
+            parse_error_packet(Resp);
+        _ResultSet ->
+            %% The first packet in a resultset is only the column count.
+            {ColumnCount, <<>>} = lenenc_int(Resp),
+            fetch_resultset_bin(RecvFun, ColumnCount, SeqNum2)
+    end.
+
+%% --- internal ---
+
 %% @doc Parses a handshake. This is the first thing that comes from the server
 %% @doc Parses a handshake. This is the first thing that comes from the server
 %% when connecting. If an unsupported version or variant of the protocol is used
 %% when connecting. If an unsupported version or variant of the protocol is used
 %% an error is raised.
 %% an error is raised.
@@ -167,55 +240,8 @@ parse_handshake_confirm(Packet) ->
             error(auth_method_switch)
             error(auth_method_switch)
     end.
     end.
 
 
--spec query(Query :: iodata(), sendfun(), recvfun()) ->
-    #ok{} | #error{} | #text_resultset{}.
-query(Query, SendFun, RecvFun) ->
-    Req = <<?COM_QUERY, (iolist_to_binary(Query))/binary>>,
-    SeqNum0 = 0,
-    {ok, SeqNum1} = send_packet(SendFun, Req, SeqNum0),
-    {ok, Resp, SeqNum2} = recv_packet(RecvFun, SeqNum1),
-    case Resp of
-        ?ok_pattern ->
-            parse_ok_packet(Resp);
-        ?error_pattern ->
-            parse_error_packet(Resp);
-        _ResultSet ->
-            %% The first packet in a resultset is just the field count.
-            {FieldCount, <<>>} = lenenc_int(Resp),
-            fetch_resultset(RecvFun, FieldCount, SeqNum2)
-    end.
-
-%% @doc Prepares a statement.
--spec prepare(iodata(), sendfun(), recvfun()) -> #error{} | #prepared{}.
-prepare(Query, SendFun, RecvFun) ->
-    Req = <<?COM_STMT_PREPARE, (iolist_to_binary(Query))/binary>>,
-    {ok, SeqNum1} = send_packet(SendFun, Req, 0),
-    {ok, Resp, SeqNum2} = recv_packet(RecvFun, SeqNum1),
-    case Resp of
-        ?error_pattern ->
-            parse_error_packet(Resp);
-        <<?OK,
-          StmtId:32/little,
-          NumColumns:16/little,
-          NumParams:16/little,
-          0, %% reserved_1 -- [00] filler
-          WarningCount:16/little>> ->
-            %% This was the first packet.
-            %% If NumParams > 0 more packets will follow:
-            {ok, ParamDefs, SeqNum3} =
-                fetch_column_definitions(RecvFun, SeqNum2, NumParams, []),
-            {ok, ?eof_pattern, SeqNum4} = recv_packet(RecvFun, SeqNum3),
-            {ok, ColDefs, SeqNum5} =
-                fetch_column_definitions(RecvFun, SeqNum4, NumColumns, []),
-            {ok, ?eof_pattern, _SeqNum6} = recv_packet(RecvFun, SeqNum5),
-            #prepared{statement_id = StmtId,
-                      params = ParamDefs,
-                      columns = ColDefs,
-                      warning_count = WarningCount}
-    end.
-
 -spec fetch_resultset(recvfun(), integer(), integer()) ->
 -spec fetch_resultset(recvfun(), integer(), integer()) ->
-    #text_resultset{} | #error{}.
+    #resultset{} | #error{}.
 fetch_resultset(RecvFun, FieldCount, SeqNum) ->
 fetch_resultset(RecvFun, FieldCount, SeqNum) ->
     {ok, ColDefs, SeqNum1} = fetch_column_definitions(RecvFun, SeqNum,
     {ok, ColDefs, SeqNum1} = fetch_column_definitions(RecvFun, SeqNum,
                                                       FieldCount, []),
                                                       FieldCount, []),
@@ -223,7 +249,7 @@ fetch_resultset(RecvFun, FieldCount, SeqNum) ->
     #eof{} = parse_eof_packet(DelimiterPacket),
     #eof{} = parse_eof_packet(DelimiterPacket),
     case fetch_resultset_rows(RecvFun, ColDefs, SeqNum2, []) of
     case fetch_resultset_rows(RecvFun, ColDefs, SeqNum2, []) of
         {ok, Rows, _SeqNum3} ->
         {ok, Rows, _SeqNum3} ->
-            #text_resultset{column_definitions = ColDefs, rows = Rows};
+            #resultset{column_definitions = ColDefs, rows = Rows};
         #error{} = E ->
         #error{} = E ->
             E
             E
     end.
     end.
@@ -270,6 +296,123 @@ parse_resultset_row([_ColDef | ColDefs], Data, Acc) ->
 parse_resultset_row([], <<>>, Acc) ->
 parse_resultset_row([], <<>>, Acc) ->
     lists:reverse(Acc).
     lists:reverse(Acc).
 
 
+%% -- binary protocol --
+%% TODO: move this to mysql_binary.
+
+build_null_bitmap(_Types, _Values) -> todo.
+
+binary_encode(_Type, _Value) -> todo.
+
+%% @doc Fetches a result set and parses it in the binary protocol
+%%
+%% TODO: Merge this with fetch_resultset/3 and don't parse the rows.
+-spec fetch_resultset_bin(recvfun(), integer(), integer()) ->
+    #resultset{} | #error{}.
+fetch_resultset_bin(RecvFun, ColumnCount, SeqNum) ->
+    {ok, ColDefs, SeqNum1} = fetch_column_definitions(RecvFun, SeqNum,
+                                                      ColumnCount, []),
+    {ok, ?eof_pattern, SeqNum2} = recv_packet(RecvFun, SeqNum1),
+    case fetch_resultset_bin_rows(RecvFun, ColumnCount, ColDefs, SeqNum2, []) of
+        {ok, Rows, _SeqNum3} ->
+            #resultset{column_definitions = ColDefs, rows = Rows};
+        #error{} = E ->
+            E
+    end.
+
+%% @doc For the binary protocol. Almost identical to fetch_resultset_rows/4.
+fetch_resultset_bin_rows(RecvFun, NumColumns, ColDefs, SeqNum, Acc) ->
+    {ok, Packet, SeqNum1} = recv_packet(RecvFun, SeqNum),
+    case Packet of
+        ?error_pattern ->
+            parse_error_packet(Packet);
+        ?eof_pattern ->
+            {ok, lists:reverse(Acc), SeqNum1};
+        _AnotherRow ->
+            Row = parse_resultset_bin_row(NumColumns, ColDefs, Packet),
+            fetch_resultset_bin_rows(RecvFun, NumColumns, ColDefs, SeqNum1,
+                                     [Row | Acc])
+    end.
+
+parse_resultset_bin_row(NumColumns, ColDefs, <<0, Data/binary>>) ->
+    {NullBitMap, Rest} = mysql_binary:null_bitmap_decode(NumColumns, Data, 2),
+    parse_resultset_bin_values(ColDefs, NullBitMap, Rest, []).
+
+parse_resultset_bin_values([_ | ColDefs], <<1:1, NullBitMap/bitstring>>, Data,
+                           Acc) ->
+    %% NULL
+    parse_resultset_bin_values(ColDefs, NullBitMap, Data, [null | Acc]);
+parse_resultset_bin_values([ColDef | ColDefs], <<0:1, NullBitMap/bitstring>>,
+                           Data, Acc) ->
+   %% Not NULL
+   {Term, Rest} = bin_protocol_decode(ColDef#column_definition.type, Data),
+   parse_resultset_bin_values(ColDefs, NullBitMap, Rest, [Term | Acc]);
+parse_resultset_bin_values([], _, <<>>, Acc) ->
+    lists:reverse(Acc).
+
+%% The types are type constants for the binary protocol, such as
+%% ProtocolBinary::MYSQL_TYPE_STRING. We assume that these are the same as for
+%% the text protocol.
+-spec bin_protocol_decode(Type :: integer(), Data :: binary()) ->
+    {Term :: term(), Rest :: binary()}.
+bin_protocol_decode(T, Data)
+  when T == ?TYPE_STRING; T == ?TYPE_VARCHAR; T == ?TYPE_VAR_STRING;
+       T == ?TYPE_ENUM; T == ?TYPE_SET; T == ?TYPE_LONG_BLOB;
+       T == ?TYPE_MEDIUM_BLOB; T == ?TYPE_BLOB; T == ?TYPE_TINY_BLOB;
+       T == ?TYPE_GEOMETRY; T == ?TYPE_BIT; T == ?TYPE_DECIMAL;
+       T == ?TYPE_NEWDECIMAL ->
+    lenenc_str(Data);
+bin_protocol_decode(?TYPE_LONGLONG, <<Value:64/little, Rest/binary>>) ->
+    {Value, Rest};
+bin_protocol_decode(T, <<Value:32/little, Rest/binary>>)
+  when T == ?TYPE_LONG; T == ?TYPE_INT24 ->
+    {Value, Rest};
+bin_protocol_decode(T, <<Value:16/little, Rest/binary>>)
+  when T == ?TYPE_SHORT; T == ?TYPE_YEAR ->
+    {Value, Rest};
+bin_protocol_decode(?TYPE_TINY, <<Value:8, Rest/binary>>) ->
+    {Value, Rest};
+bin_protocol_decode(?TYPE_DOUBLE, <<Value:64/float-little, Rest/binary>>) ->
+    {Value, Rest};
+bin_protocol_decode(?TYPE_FLOAT, <<Value:32/float-little, Rest/binary>>) ->
+    {Value, Rest};
+bin_protocol_decode(?TYPE_DATE, <<Length, Data/binary>>) ->
+    %% Coded in the same way as DATETIME and TIMESTAMP below, but returned in
+    %% a simple triple.
+    case {Length, Data} of
+        {0, _} -> {{0, 0, 0}, Data};
+        {4, <<Y:16/little, M, D, Rest/binary>>} -> {{Y, M, D}, Rest}
+    end;
+bin_protocol_decode(T, <<Length, Data/binary>>)
+  when T == ?TYPE_DATETIME; T == ?TYPE_TIMESTAMP ->
+    %% length (1) -- number of bytes following (valid values: 0, 4, 7, 11)
+    case {Length, Data} of
+        {0, _} ->
+            {{{0,0,0},{0,0,0}}, Data};
+        {4, <<Y:16/little, M, D, Rest/binary>>} ->
+            {{{Y, M, D}, {0, 0, 0}}, Rest};
+        {7, <<Y:16/little, M, D, H, Mi, S, Rest/binary>>} ->
+            {{{Y, M, D}, {H, Mi, S}}, Rest};
+        {11, <<Y:16/little, M, D, H, Mi, S, Micro:32/little, Rest/binary>>} ->
+            {{{Y, M, D}, {H, Mi, S + 0.000001 * Micro}}, Rest}
+    end;
+bin_protocol_decode(?TYPE_TIME, <<Length, Data/binary>>) ->
+    %% length (1) -- number of bytes following (valid values: 0, 8, 12)
+    %% is_negative (1) -- (1 if minus, 0 for plus)
+    %% days (4) -- days
+    %% hours (1) -- hours
+    %% minutes (1) -- minutes
+    %% seconds (1) -- seconds
+    %% micro_seconds (4) -- micro-seconds
+    case {Length, Data} of
+        {0, _} ->
+            {{0, 0, 0}, Data};
+        {8, <<IsNeg, D:32/little, H, M, S, Rest/binary>>} ->
+            {{(-IsNeg bsl 1 + 1) * (D * 24 + H), M, S}, Rest};
+        {8, <<IsNeg, D:32/little, H, M, S, Micro:32/little, Rest/binary>>} ->
+            {{(-IsNeg bsl 1 + 1) * (D * 24 + H), M, S + 0.000001 * Micro},
+             Rest}
+    end.
+
 %% Parses a packet containing a column definition (part of a result set)
 %% Parses a packet containing a column definition (part of a result set)
 parse_column_definition(Data) ->
 parse_column_definition(Data) ->
     {<<"def">>, Rest1} = lenenc_str(Data),   %% catalog (always "def")
     {<<"def">>, Rest1} = lenenc_str(Data),   %% catalog (always "def")
@@ -294,8 +437,6 @@ parse_column_definition(Data) ->
     <<>> = Rest8,
     <<>> = Rest8,
     #column_definition{name = Name, type = ColumnType, charset = Charset}.
     #column_definition{name = Name, type = ColumnType, charset = Charset}.
 
 
-%% --- internal ---
-
 %% @doc Wraps Data in packet headers, sends it by calling SendFun and returns
 %% @doc Wraps Data in packet headers, sends it by calling SendFun and returns
 %% {ok, SeqNum1} where SeqNum1 is the next sequence number.
 %% {ok, SeqNum1} where SeqNum1 is the next sequence number.
 -spec send_packet(sendfun(), Data :: binary(), SeqNum :: integer()) ->
 -spec send_packet(sendfun(), Data :: binary(), SeqNum :: integer()) ->
@@ -328,6 +469,36 @@ recv_packet(RecvFun, ExpectSeqNum, Acc) ->
         true  -> recv_packet(RecvFun, NextSeqNum, Acc1)
         true  -> recv_packet(RecvFun, NextSeqNum, Acc1)
     end.
     end.
 
 
+%% @doc Parses a packet header (32 bits) and returns a tuple.
+%%
+%% The client should first read a header and parse it. Then read PacketLength
+%% bytes. If there are more packets, read another header and read a new packet
+%% length of payload until there are no more packets. The seq num should
+%% increment from 0 and may wrap around at 255 back to 0.
+%%
+%% When all packets are read and the payload of all packets are concatenated, it
+%% can be parsed using parse_response/1, etc. depending on what type of response
+%% is expected.
+-spec parse_packet_header(PackerHeader :: binary()) ->
+    {PacketLength :: integer(),
+     SeqNum :: integer(),
+     MorePacketsExist :: boolean()}.
+parse_packet_header(<<PacketLength:24/little-integer, SeqNum:8/integer>>) ->
+    {PacketLength, SeqNum, PacketLength == 16#ffffff}.
+
+%% @doc Splits a packet body into chunks and wraps them in headers. The
+%% resulting list is ready to sent to the socket.
+-spec add_packet_headers(PacketBody :: iodata(), SeqNum :: integer()) ->
+    {PacketWithHeaders :: iodata(), NextSeqNum :: integer()}.
+add_packet_headers(PacketBody, SeqNum) ->
+    Bin = iolist_to_binary(PacketBody),
+    Size = size(Bin),
+    SeqNum1 = (SeqNum + 1) rem 16#100,
+    %% Todo: implement the case when Size >= 16#ffffff.
+    if Size < 16#ffffff ->
+        {[<<Size:24/little, SeqNum:8>>, Bin], SeqNum1}
+    end.
+
 -spec parse_ok_packet(binary()) -> #ok{}.
 -spec parse_ok_packet(binary()) -> #ok{}.
 parse_ok_packet(<<?OK:8, Rest/binary>>) ->
 parse_ok_packet(<<?OK:8, Rest/binary>>) ->
     {AffectedRows, Rest1} = lenenc_int(Rest),
     {AffectedRows, Rest1} = lenenc_int(Rest),
@@ -394,6 +565,8 @@ hash_password(Password, <<"mysql_native_password">>, AuthData) ->
 hash_password(_, AuthPlugin, _) ->
 hash_password(_, AuthPlugin, _) ->
     error({auth_method, AuthPlugin}).
     error({auth_method, AuthPlugin}).
 
 
+%% --- Lowlevel: decoding variable length integers and strings ---
+
 %% lenenc_int/1 decodes length-encoded-integer values
 %% lenenc_int/1 decodes length-encoded-integer values
 -spec lenenc_int(Input :: binary()) -> {Value :: integer(), Rest :: binary()}.
 -spec lenenc_int(Input :: binary()) -> {Value :: integer(), Rest :: binary()}.
 lenenc_int(<<Value:8, Rest/bits>>) when Value < 251 -> {Value, Rest};
 lenenc_int(<<Value:8, Rest/bits>>) when Value < 251 -> {Value, Rest};
@@ -417,6 +590,9 @@ nulterm_str(Bin) ->
 -ifdef(TEST).
 -ifdef(TEST).
 -include_lib("eunit/include/eunit.hrl").
 -include_lib("eunit/include/eunit.hrl").
 
 
+%% Testing some of the internal functions, mostly the cases we don't cover in
+%% other tests.
+
 lenenc_int_test() ->
 lenenc_int_test() ->
     ?assertEqual({40, <<>>}, lenenc_int(<<40>>)),
     ?assertEqual({40, <<>>}, lenenc_int(<<40>>)),
     ?assertEqual({16#ff, <<>>}, lenenc_int(<<16#fc, 255, 0>>)),
     ?assertEqual({16#ff, <<>>}, lenenc_int(<<16#fc, 255, 0>>)),

+ 4 - 4
test/mysql_protocol_tests.erl

@@ -25,10 +25,10 @@ resultset_test() ->
     RecvFun = fun (Size) -> fakesocket_recv(FakeSock, Size) end,
     RecvFun = fun (Size) -> fakesocket_recv(FakeSock, Size) end,
     ResultSet = mysql_protocol:query(Query, SendFun, RecvFun),
     ResultSet = mysql_protocol:query(Query, SendFun, RecvFun),
     fakesocket_close(FakeSock),
     fakesocket_close(FakeSock),
-    ?assertMatch(#text_resultset{column_definitions =
-                                     [#column_definition{
-                                          name = <<"@@version_comment">>}],
-                                 rows = [[<<"MySQL Community Server (GPL)">>]]},
+    ?assertMatch(#resultset{column_definitions =
+                                [#column_definition{
+                                     name = <<"@@version_comment">>}],
+                            rows = [[<<"MySQL Community Server (GPL)">>]]},
                  ResultSet),
                  ResultSet),
     ok.
     ok.
 
 

+ 40 - 8
test/mysql_tests.erl

@@ -6,27 +6,37 @@
 -define(user,     "otptest").
 -define(user,     "otptest").
 -define(password, "otptest").
 -define(password, "otptest").
 
 
+-define(create_table_t, <<"CREATE TABLE t ("
+                          "  id INT NOT NULL PRIMARY KEY AUTO_INCREMENT,"
+                          "  bl BLOB,"
+                          "  tx TEXT NOT NULL," %% No default value
+                          "  f FLOAT,"
+                          "  dc DECIMAL(5,3),"
+                          "  ti TIME,"
+                          "  ts TIMESTAMP,"
+                          "  da DATE,"
+                          "  c CHAR(2)"
+                          ") ENGINE=InnoDB">>).
+
 connect_test() ->
 connect_test() ->
     {ok, Pid} = mysql:connect([{user, ?user}, {password, ?password}]),
     {ok, Pid} = mysql:connect([{user, ?user}, {password, ?password}]),
-
-    %% A query without a result set
-    ?assertEqual(ok, mysql:query(Pid, <<"USE otptest">>)),
-
     ?assertEqual(ok, mysql:disconnect(Pid)).
     ?assertEqual(ok, mysql:disconnect(Pid)).
 
 
 query_test_() ->
 query_test_() ->
     {setup,
     {setup,
      fun () ->
      fun () ->
          {ok, Pid} = mysql:connect([{user, ?user}, {password, ?password}]),
          {ok, Pid} = mysql:connect([{user, ?user}, {password, ?password}]),
-         %ok = mysql:query(Pid, <<"DROP DATABASE IF EXISTS otptest">>),
-         %ok = mysql:query(Pid, <<"CREATE DATABASE otptest">>),
+         ok = mysql:query(Pid, <<"DROP DATABASE IF EXISTS otptest">>),
+         ok = mysql:query(Pid, <<"CREATE DATABASE otptest">>),
          ok = mysql:query(Pid, <<"USE otptest">>),
          ok = mysql:query(Pid, <<"USE otptest">>),
+         ok = mysql:query(Pid, ?create_table_t),
          Pid
          Pid
      end,
      end,
      fun (Pid) ->
      fun (Pid) ->
+         ok = mysql:query(Pid, "DROP TABLE t;"),
          mysql:disconnect(Pid)
          mysql:disconnect(Pid)
      end,
      end,
-     {with, [fun basic_queries/1]}}.
+     {with, [fun basic_queries/1, fun text_protocol/1, fun binary_protocol/1]}}.
 
 
 basic_queries(Pid) ->
 basic_queries(Pid) ->
 
 
@@ -42,5 +52,27 @@ basic_queries(Pid) ->
     ?assertEqual({ok, [<<"i">>, <<"s">>], [[42, <<"foo">>]]},
     ?assertEqual({ok, [<<"i">>, <<"s">>], [[42, <<"foo">>]]},
                  mysql:query(Pid, <<"SELECT 42 AS i, 'foo' AS s;">>)),
                  mysql:query(Pid, <<"SELECT 42 AS i, 'foo' AS s;">>)),
 
 
-    %{ok, Fields, Rows} = mysql:query(Pid, <<"SELECT * FROM settest">>),
     ok.
     ok.
+
+text_protocol(Pid) ->
+    ok = mysql:query(Pid, <<"INSERT INTO t (bl, f, dc, ti, ts, da, c)"
+                            " VALUES ('blob', 3.14, 3.14, '00:22:11',"
+                            " '2014-11-03 00:22:24', '2014-11-03',"
+                            " NULL)">>),
+    ?assertEqual(1, mysql:warning_count(Pid)), %% tx has no default value
+    ?assertEqual(1, mysql:insert_id(Pid)),     %% auto_increment starts from 1
+    ?assertEqual(1, mysql:affected_rows(Pid)),
+
+    %% select
+    ?assertEqual({ok, [<<"id">>, <<"bl">>, <<"tx">>, <<"f">>, <<"dc">>,
+                       <<"ti">>, <<"ts">>, <<"da">>, <<"c">>],
+                      [[1, <<"blob">>, <<>>, 3.14, 3.14, {0, 22, 11},
+                        {{2014, 11, 03}, {00, 22, 24}}, {2014, 11, 03}, null]]},
+                 mysql:query(Pid, <<"SELECT * FROM t">>)),
+    ok.
+
+binary_protocol(Pid) ->
+    {ok, Stmt} = mysql:prepare(Pid, <<"SELECT * FROM t">>),
+    {ok, Cols, Rows} = mysql:query(Pid, Stmt, []),
+    io:format("Cols: ~p~nRows: ~p~n", [Cols, Rows]),
+    todo.