Browse Source

Refactoring

Viktor Söderqvist 10 years ago
parent
commit
6ffe865853
6 changed files with 281 additions and 308 deletions
  1. 4 3
      include/records.hrl
  2. 0 64
      src/mysql_binary.erl
  3. 1 2
      src/mysql_connection.erl
  4. 270 110
      src/mysql_protocol.erl
  5. 0 124
      src/mysql_text_protocol.erl
  6. 6 5
      test/mysql_tests.erl

+ 4 - 3
include/records.hrl

@@ -40,10 +40,11 @@
 %% Column definition, used while parsing a result set.
 %% Column definition, used while parsing a result set.
 -record(column_definition, {name, type, charset}).
 -record(column_definition, {name, type, charset}).
 
 
-%% A resultset as received from the server using the text protocol.
-%% For text protocol resultsets, rows :: [[binary() | null]].
+%% A resultset. The rows can be either lists of terms or unparsed binaries as
+%% received from the server using either the text protocol or the binary
+%% protocol.
 -record(resultset, {column_definitions :: [#column_definition{}],
 -record(resultset, {column_definitions :: [#column_definition{}],
-                    rows :: [[term()]]}).
+                    rows :: [[term()] | binary()]}).
 
 
 %% Response of a successfull prepare call.
 %% Response of a successfull prepare call.
 -record(prepared, {statement_id :: integer(),
 -record(prepared, {statement_id :: integer(),

+ 0 - 64
src/mysql_binary.erl

@@ -1,64 +0,0 @@
-%% MySQL/OTP – a MySQL driver for Erlang/OTP
-%% Copyright (C) 2014 Viktor Söderqvist
-%%
-%% This program is free software: you can redistribute it and/or modify
-%% it under the terms of the GNU General Public License as published by
-%% the Free Software Foundation, either version 3 of the License, or
-%% (at your option) any later version.
-%%
-%% This program is distributed in the hope that it will be useful,
-%% but WITHOUT ANY WARRANTY; without even the implied warranty of
-%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-%% GNU General Public License for more details.
-%%
-%% You should have received a copy of the GNU General Public License
-%% along with this program. If not, see <https://www.gnu.org/licenses/>.
-
-%% @doc The MySQL binary protocol is used for prepared statements. This module
-%% is used mainly from the mysql_protocol module.
--module(mysql_binary).
-
--export([null_bitmap_decode/3, null_bitmap_encode/2]).
-
-%% @doc Decodes a null bitmap as stored by MySQL and returns it in a strait
-%% bitstring from left to right. Returns it together with the rest of the data.
-%%
-%% In the MySQL null bitmap the bits are stored counting bytes from the left and
-%% bits within each byte from the right. (Sort of little endian.)
--spec null_bitmap_decode(NumColumns :: integer(), BitOffset :: integer(),
-                         Data :: binary()) ->
-    {NullBitstring :: bitstring(), Rest :: binary()}.
-null_bitmap_decode(NumColumns, Data, BitOffset) ->
-    %% Binary shift right by 3 is equivallent to integer division by 8.
-    BitMapLength = (NumColumns + BitOffset + 7) bsr 3,
-    <<NullBitstring0:BitMapLength/binary, Rest/binary>> = Data,
-    <<_:BitOffset, NullBitstring:NumColumns/bitstring, _/bitstring>> =
-        << <<(reverse_byte(B))/binary>> || <<B:1/binary>> <= NullBitstring0 >>,
-    {NullBitstring, Rest}.
-
-%% @doc The reverse of null_bitmap_decode/3. The number of columns is taken to
-%% be the number of bits in NullBitstring. Returns the MySQL null bitmap as a
-%% binary (i.e. full bytes). BitOffset is the number of unused bits that should
-%% be inserted before the other bits.
--spec null_bitmap_encode(bitstring(), integer()) -> binary().
-null_bitmap_encode(NullBitstring, BitOffset) ->
-    PayloadLength = bit_size(NullBitstring) + BitOffset,
-    %% Round up to a multiple of 8.
-    BitMapLength = (PayloadLength + 7) band bnot 7,
-    PadBitsLength = BitMapLength - PayloadLength,
-    PaddedBitstring = <<0:BitOffset, NullBitstring/bitstring, 0:PadBitsLength>>,
-    << <<(reverse_byte(B))/binary>> || <<B:1/binary>> <= PaddedBitstring >>.
-
-%% Reverses the bits in a byte.
-reverse_byte(<<A:1, B:1, C:1, D:1, E:1, F:1, G:1, H:1>>) ->
-    <<H:1, G:1, F:1, E:1, D:1, C:1, B:1, A:1>>.
-
--ifdef(TEST).
--include_lib("eunit/include/eunit.hrl").
-
-null_bitmap_test() ->
-    ?assertEqual({<<0, 1:1>>, <<>>}, null_bitmap_decode(9, <<0, 4>>, 2)),
-    ?assertEqual(<<0, 4>>, null_bitmap_encode(<<0, 1:1>>, 2)),
-    ok.
-
--endif.

+ 1 - 2
src/mysql_connection.erl

@@ -80,8 +80,7 @@ handle_call({query, Query}, _From, State) when is_binary(Query);
             {reply, {error, error_to_reason(E)}, State1};
             {reply, {error, error_to_reason(E)}, State1};
         #resultset{column_definitions = ColDefs, rows = Rows} ->
         #resultset{column_definitions = ColDefs, rows = Rows} ->
             Names = [Def#column_definition.name || Def <- ColDefs],
             Names = [Def#column_definition.name || Def <- ColDefs],
-            Rows1 = decode_text_rows(ColDefs, Rows),
-            {reply, {ok, Names, Rows1}, State1}
+            {reply, {ok, Names, Rows}, State1}
     end;
     end;
 handle_call({query, Stmt, Args}, _From, State) when is_integer(Stmt);
 handle_call({query, Stmt, Args}, _From, State) when is_integer(Stmt);
                                                     is_atom(Stmt) ->
                                                     is_atom(Stmt) ->

+ 270 - 110
src/mysql_protocol.erl

@@ -76,10 +76,18 @@ query(Query, SendFun, RecvFun) ->
         _ResultSet ->
         _ResultSet ->
             %% The first packet in a resultset is only the column count.
             %% The first packet in a resultset is only the column count.
             {ColumnCount, <<>>} = lenenc_int(Resp),
             {ColumnCount, <<>>} = lenenc_int(Resp),
-            ResultSet = fetch_resultset(RecvFun, ColumnCount, SeqNum2),
-            %% TODO: Factor out parsing the rows from fetch_resultset/3 and do
-            %% that here instead.
-            ResultSet
+            case fetch_resultset(RecvFun, ColumnCount, SeqNum2) of
+                #error{} = E ->
+                    E;
+                #resultset{column_definitions = ColDefs, rows = Rows} = R ->
+                    %% Parse the rows according to the 'text protocol'
+                    %% representation.
+                    ColumnTypes = [ColDef#column_definition.type
+                                   || ColDef <- ColDefs],
+                    Rows1 = [decode_text_row(ColumnCount, ColumnTypes, Row)
+                             || Row <- Rows],
+                    R#resultset{rows = Rows1}
+            end
     end.
     end.
 
 
 %% @doc Prepares a statement.
 %% @doc Prepares a statement.
@@ -106,7 +114,8 @@ prepare(Query, SendFun, RecvFun) ->
                 0 ->
                 0 ->
                     SeqNum3;
                     SeqNum3;
                 _ ->
                 _ ->
-                    {ok, ?eof_pattern, SeqNum3x} = recv_packet(RecvFun, SeqNum3),
+                    {ok, ?eof_pattern, SeqNum3x} = recv_packet(RecvFun,
+                                                               SeqNum3),
                     SeqNum3x
                     SeqNum3x
             end,
             end,
             {ok, ColDefs, SeqNum5} =
             {ok, ColDefs, SeqNum5} =
@@ -121,7 +130,7 @@ prepare(Query, SendFun, RecvFun) ->
 %% @doc Executes a prepared statement.
 %% @doc Executes a prepared statement.
 -spec execute(#prepared{}, [term()], sendfun(), recvfun()) -> #resultset{}.
 -spec execute(#prepared{}, [term()], sendfun(), recvfun()) -> #resultset{}.
 execute(#prepared{statement_id = Id, params = ParamDefs}, ParamValues,
 execute(#prepared{statement_id = Id, params = ParamDefs}, ParamValues,
-        SendFun, RecvFun) ->
+        SendFun, RecvFun) when length(ParamDefs) == length(ParamValues) ->
     %% Flags Constant Name
     %% Flags Constant Name
     %% 0x00 CURSOR_TYPE_NO_CURSOR
     %% 0x00 CURSOR_TYPE_NO_CURSOR
     %% 0x01 CURSOR_TYPE_READ_ONLY
     %% 0x01 CURSOR_TYPE_READ_ONLY
@@ -133,20 +142,21 @@ execute(#prepared{statement_id = Id, params = ParamDefs}, ParamValues,
         [] ->
         [] ->
             Req0;
             Req0;
         _ ->
         _ ->
-            Types = [Def#column_definition.type || Def <- ParamDefs],
-            NullBitMap = build_null_bitmap(Types, ParamValues),
+            ParamTypes = [Def#column_definition.type || Def <- ParamDefs],
+            NullBitMap = build_null_bitmap(ParamValues),
+            %% TODO: Find out when you would use NewParamsBoundFlag = 0.
             NewParamsBoundFlag = 1,
             NewParamsBoundFlag = 1,
             Req1 = <<Req0/binary, NullBitMap/binary, NewParamsBoundFlag>>,
             Req1 = <<Req0/binary, NullBitMap/binary, NewParamsBoundFlag>>,
             %% Append type and signedness (16#80 signed or 00 unsigned)
             %% Append type and signedness (16#80 signed or 00 unsigned)
             %% for each value
             %% for each value
             lists:foldl(
             lists:foldl(
                 fun ({Type, Value}, Acc) ->
                 fun ({Type, Value}, Acc) ->
-                    BinValue = binary_encode(Type, Value),
+                    BinValue = encode_binary(Type, Value),
                     Signedness = 0, %% Hmm.....
                     Signedness = 0, %% Hmm.....
                     <<Acc/binary, Type, Signedness, BinValue/binary>>
                     <<Acc/binary, Type, Signedness, BinValue/binary>>
                 end,
                 end,
                 Req1,
                 Req1,
-                lists:zip(Types, ParamValues)
+                lists:zip(ParamTypes, ParamValues)
             )
             )
     end,
     end,
     {ok, SeqNum1} = send_packet(SendFun, Req, 0),
     {ok, SeqNum1} = send_packet(SendFun, Req, 0),
@@ -156,10 +166,24 @@ execute(#prepared{statement_id = Id, params = ParamDefs}, ParamValues,
             parse_ok_packet(Resp);
             parse_ok_packet(Resp);
         ?error_pattern ->
         ?error_pattern ->
             parse_error_packet(Resp);
             parse_error_packet(Resp);
-        _ResultSet ->
+        _ResultPacket ->
             %% The first packet in a resultset is only the column count.
             %% The first packet in a resultset is only the column count.
             {ColumnCount, <<>>} = lenenc_int(Resp),
             {ColumnCount, <<>>} = lenenc_int(Resp),
-            fetch_resultset_bin(RecvFun, ColumnCount, SeqNum2)
+            case fetch_resultset(RecvFun, ColumnCount, SeqNum2) of
+                #error{} = E ->
+                    %% TODO: Find a way to get here and write a testcase.
+                    %% This can happen for the text protocol but maybe not for
+                    %% the binary protocol.
+                    E;
+                #resultset{column_definitions = ColDefs, rows = Rows} = R ->
+                    %% Parse the rows according to the 'binary protocol'
+                    %% representation.
+                    ColumnTypes = [ColDef#column_definition.type
+                                   || ColDef <- ColDefs],
+                    Rows1 = [decode_binary_row(ColumnCount, ColumnTypes, Row)
+                             || Row <- Rows],
+                    R#resultset{rows = Rows1}
+            end
     end.
     end.
 
 
 %% --- internal ---
 %% --- internal ---
@@ -256,6 +280,7 @@ parse_handshake_confirm(Packet) ->
             error(auth_method_switch)
             error(auth_method_switch)
     end.
     end.
 
 
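+%% Fetches the column definitions and then row packets until an EOF or error
+%% packet is received. The rows are returned undecoded.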
 -spec fetch_resultset(recvfun(), integer(), integer()) ->
 -spec fetch_resultset(recvfun(), integer(), integer()) ->
     #resultset{} | #error{}.
     #resultset{} | #error{}.
 fetch_resultset(RecvFun, FieldCount, SeqNum) ->
 fetch_resultset(RecvFun, FieldCount, SeqNum) ->
@@ -263,7 +288,7 @@ fetch_resultset(RecvFun, FieldCount, SeqNum) ->
                                                       FieldCount, []),
                                                       FieldCount, []),
     {ok, DelimiterPacket, SeqNum2} = recv_packet(RecvFun, SeqNum1),
     {ok, DelimiterPacket, SeqNum2} = recv_packet(RecvFun, SeqNum1),
     #eof{} = parse_eof_packet(DelimiterPacket),
     #eof{} = parse_eof_packet(DelimiterPacket),
-    case fetch_resultset_rows(RecvFun, ColDefs, SeqNum2, []) of
+    case fetch_resultset_rows(RecvFun, SeqNum2, []) of
         {ok, Rows, _SeqNum3} ->
         {ok, Rows, _SeqNum3} ->
             #resultset{column_definitions = ColDefs, rows = Rows};
             #resultset{column_definitions = ColDefs, rows = Rows};
         #error{} = E ->
         #error{} = E ->
@@ -281,124 +306,214 @@ fetch_column_definitions(RecvFun, SeqNum, NumLeft, Acc) when NumLeft > 0 ->
 fetch_column_definitions(_RecvFun, SeqNum, 0, Acc) ->
 fetch_column_definitions(_RecvFun, SeqNum, 0, Acc) ->
     {ok, lists:reverse(Acc), SeqNum}.
     {ok, lists:reverse(Acc), SeqNum}.
 
 
--spec fetch_resultset_rows(recvfun(), ColumnDefinitions, integer(),
-                           Acc) -> {ok, Rows, integer()} | #error{}
-    when ColumnDefinitions :: [#column_definition{}],
-         Acc :: [[binary() | null]],
-         Rows :: [[binary() | null]].
-fetch_resultset_rows(RecvFun, ColDefs, SeqNum, Acc) ->
+%% @doc Fetches rows in a result set. There is a packet per row. The row packets
+%% are not decoded. This function can be used for both the binary and the text
+%% protocol result sets.
+-spec fetch_resultset_rows(recvfun(), SeqNum :: integer(), Acc) ->
+    {ok, Rows, integer()} | #error{}
+    when Acc :: [binary()],
+         Rows :: [binary()].
+fetch_resultset_rows(RecvFun, SeqNum, Acc) ->
     {ok, Packet, SeqNum1} = recv_packet(RecvFun, SeqNum),
     {ok, Packet, SeqNum1} = recv_packet(RecvFun, SeqNum),
     case Packet of
     case Packet of
         ?error_pattern ->
         ?error_pattern ->
             parse_error_packet(Packet);
             parse_error_packet(Packet);
         ?eof_pattern ->
         ?eof_pattern ->
             {ok, lists:reverse(Acc), SeqNum1};
             {ok, lists:reverse(Acc), SeqNum1};
-        _AnotherRow ->
-            Row = parse_resultset_row(ColDefs, Packet, []),
-            fetch_resultset_rows(RecvFun, ColDefs, SeqNum1, [Row | Acc])
+        Row ->
+            fetch_resultset_rows(RecvFun, SeqNum1, [Row | Acc])
     end.
     end.
 
 
+%% -- both text and binary protocol --
+
+%% Parses a packet containing a column definition (part of a result set)
+parse_column_definition(Data) ->
+    {<<"def">>, Rest1} = lenenc_str(Data),   %% catalog (always "def")
+    {_Schema, Rest2} = lenenc_str(Rest1),    %% schema-name 
+    {_Table, Rest3} = lenenc_str(Rest2),     %% virtual table-name 
+    {_OrgTable, Rest4} = lenenc_str(Rest3),  %% physical table-name 
+    {Name, Rest5} = lenenc_str(Rest4),       %% virtual column name
+    {_OrgName, Rest6} = lenenc_str(Rest5),   %% physical column name
+    {16#0c, Rest7} = lenenc_int(Rest6),      %% length of the following fields
+                                             %% (always 0x0c)
+    <<Charset:16/little,        %% column character set
+      _ColumnLength:32/little,  %% maximum length of the field
+      ColumnType:8,             %% type of the column as defined in Column Type
+      _Flags:16/little,         %% flags
+      _Decimals:8,              %% max shown decimal digits:
+      0,  %% "filler"           %%   - 0x00 for integers and static strings
+      0,                        %%   - 0x1f for dynamic strings, double, float
+      Rest8/binary>> = Rest7,   %%   - 0x00 to 0x51 for decimals
+    %% Here, if command was COM_FIELD_LIST {
+    %%   default values: lenenc_str
+    %% }
+    <<>> = Rest8,
+    #column_definition{name = Name, type = ColumnType, charset = Charset}.
+
+%% -- text protocol --
+
+-spec decode_text_row(NumColumns :: integer(), ColumnTypes :: integer(),
+                      Data :: binary()) -> [term()].
+decode_text_row(_NumColumns, ColumnTypes, Data) ->
+    decode_text_row_acc(ColumnTypes, Data, []).
+
 %% parses Data using ColDefs and builds the values Acc.
 %% parses Data using ColDefs and builds the values Acc.
-parse_resultset_row([_ColDef | ColDefs], Data, Acc) ->
+decode_text_row_acc([Type | Types], Data, Acc) ->
     case Data of
     case Data of
         <<16#fb, Rest/binary>> ->
         <<16#fb, Rest/binary>> ->
             %% NULL
             %% NULL
-            parse_resultset_row(ColDefs, Rest, [null | Acc]);
+            decode_text_row_acc(Types, Rest, [null | Acc]);
         _ ->
         _ ->
             %% Every thing except NULL
             %% Every thing except NULL
-            {Str, Rest} = lenenc_str(Data),
-            parse_resultset_row(ColDefs, Rest, [Str | Acc])
+            {Text, Rest} = lenenc_str(Data),
+            Term = decode_text(Type, Text),
+            decode_text_row_acc(Types, Rest, [Term | Acc])
     end;
     end;
-parse_resultset_row([], <<>>, Acc) ->
+decode_text_row_acc([], <<>>, Acc) ->
     lists:reverse(Acc).
     lists:reverse(Acc).
 
 
-%% -- binary protocol --
-%% TODO: move this to mysql_binary.
-
-build_null_bitmap(_Types, _Values) -> todo.
-
-binary_encode(_Type, _Value) -> todo.
-
-%% @doc Fetches a result set and parses it in the binary protocol
-%%
-%% TODO: Merge this with fetch_resultset/3 and don't parse the rows.
--spec fetch_resultset_bin(recvfun(), integer(), integer()) ->
-    #resultset{} | #error{}.
-fetch_resultset_bin(RecvFun, ColumnCount, SeqNum) ->
-    {ok, ColDefs, SeqNum1} = fetch_column_definitions(RecvFun, SeqNum,
-                                                      ColumnCount, []),
-    {ok, ?eof_pattern, SeqNum2} = recv_packet(RecvFun, SeqNum1),
-    case fetch_resultset_bin_rows(RecvFun, ColumnCount, ColDefs, SeqNum2, []) of
-        {ok, Rows, _SeqNum3} ->
-            #resultset{column_definitions = ColDefs, rows = Rows};
-        #error{} = E ->
-            E
-    end.
-
-%% @doc For the binary protocol. Almost identical to fetch_resultset_rows/4.
-fetch_resultset_bin_rows(RecvFun, NumColumns, ColDefs, SeqNum, Acc) ->
-    {ok, Packet, SeqNum1} = recv_packet(RecvFun, SeqNum),
-    case Packet of
-        ?error_pattern ->
-            parse_error_packet(Packet);
-        ?eof_pattern ->
-            {ok, lists:reverse(Acc), SeqNum1};
-        _AnotherRow ->
-            Row = parse_resultset_bin_row(NumColumns, ColDefs, Packet),
-            fetch_resultset_bin_rows(RecvFun, NumColumns, ColDefs, SeqNum1,
-                                     [Row | Acc])
-    end.
+%% @doc When receiving data in the text protocol, we get everything as binaries
+%% (except NULL). This function is used to parse these string values.
+decode_text(_, null) ->
+    %% NULL is the only value not represented as a binary.
+    null;
+decode_text(T, Text)
+  when T == ?TYPE_TINY; T == ?TYPE_SHORT; T == ?TYPE_LONG; T == ?TYPE_LONGLONG;
+       T == ?TYPE_INT24; T == ?TYPE_YEAR; T == ?TYPE_BIT ->
+    %% For BIT, do we want bitstring, int or binary?
+    binary_to_integer(Text);
+decode_text(T, Text)
+  when T == ?TYPE_DECIMAL; T == ?TYPE_NEWDECIMAL; T == ?TYPE_VARCHAR;
+       T == ?TYPE_ENUM; T == ?TYPE_TINY_BLOB; T == ?TYPE_MEDIUM_BLOB;
+       T == ?TYPE_LONG_BLOB; T == ?TYPE_BLOB; T == ?TYPE_VAR_STRING;
+       T == ?TYPE_STRING; T == ?TYPE_GEOMETRY ->
+    Text;
+decode_text(?TYPE_DATE, <<Y:4/binary, "-", M:2/binary, "-", D:2/binary>>) ->
+    {binary_to_integer(Y), binary_to_integer(M), binary_to_integer(D)};
+decode_text(?TYPE_TIME, <<H:2/binary, ":", Mi:2/binary, ":", S:2/binary>>) ->
+    %% FIXME: Hours can be negative + more digits. Seconds can have fractions.
+    %% Add tests for these cases.
+    {binary_to_integer(H), binary_to_integer(Mi), binary_to_integer(S)};
+decode_text(T, <<Y:4/binary, "-", M:2/binary, "-", D:2/binary, " ",
+                 H:2/binary, ":", Mi:2/binary, ":", S:2/binary>>)
+  when T == ?TYPE_TIMESTAMP; T == ?TYPE_DATETIME ->
+    {{binary_to_integer(Y), binary_to_integer(M), binary_to_integer(D)},
+     {binary_to_integer(H), binary_to_integer(Mi), binary_to_integer(S)}};
+decode_text(T, Text) when T == ?TYPE_FLOAT; T == ?TYPE_DOUBLE ->
+    try binary_to_float(Text)
+    catch error:badarg ->
+        try binary_to_integer(Text) of
+            Int -> float(Int)
+        catch error:badarg ->
+            %% It is something like "4e75" that must be turned into "4.0e75"
+            binary_to_float(binary:replace(Text, <<"e">>, <<".0e">>))
+        end
+    end;
+decode_text(?TYPE_SET, <<>>) ->
+    sets:new();
+decode_text(?TYPE_SET, Text) ->
+    sets:from_list(binary:split(Text, <<",">>, [global])).
 
 
-parse_resultset_bin_row(NumColumns, ColDefs, <<0, Data/binary>>) ->
-    {NullBitMap, Rest} = mysql_binary:null_bitmap_decode(NumColumns, Data, 2),
-    parse_resultset_bin_values(ColDefs, NullBitMap, Rest, []).
+%% -- binary protocol --
 
 
-parse_resultset_bin_values([_ | ColDefs], <<1:1, NullBitMap/bitstring>>, Data,
-                           Acc) ->
+%% @doc Decodes a packet representing a row in a binary result set.
+%% It consists of a 0 byte, then a null bitmap, then the values.
+%% Returns a list of length NumColumns with terms of appropriate types for each
+%% MySQL type in ColumnTypes.
+-spec decode_binary_row(NumColumns :: integer(), ColumnTypes :: [integer()],
+                 Data :: binary()) -> [term()].
+decode_binary_row(NumColumns, ColumnTypes, <<0, Data/binary>>) ->
+    {NullBitMap, Rest} = null_bitmap_decode(NumColumns, Data, 2),
+    decode_binary_row_acc(ColumnTypes, NullBitMap, Rest, []).
+
+%% @doc Accumulating helper for decode_binary_row/3.
+decode_binary_row_acc([_ | Types], <<1:1, NullBitMap/bitstring>>, Data, Acc) ->
     %% NULL
     %% NULL
-    parse_resultset_bin_values(ColDefs, NullBitMap, Data, [null | Acc]);
-parse_resultset_bin_values([ColDef | ColDefs], <<0:1, NullBitMap/bitstring>>,
-                           Data, Acc) ->
+    decode_binary_row_acc(Types, NullBitMap, Data, [null | Acc]);
+decode_binary_row_acc([Type | Types], <<0:1, NullBitMap/bitstring>>, Data,
+                      Acc) ->
    %% Not NULL
    %% Not NULL
-   {Term, Rest} = bin_protocol_decode(ColDef#column_definition.type, Data),
-   parse_resultset_bin_values(ColDefs, NullBitMap, Rest, [Term | Acc]);
-parse_resultset_bin_values([], _, <<>>, Acc) ->
+   {Term, Rest} = decode_binary(Type, Data),
+   decode_binary_row_acc(Types, NullBitMap, Rest, [Term | Acc]);
+decode_binary_row_acc([], _, <<>>, Acc) ->
     lists:reverse(Acc).
     lists:reverse(Acc).
 
 
+%% @doc Decodes a null bitmap as stored by MySQL and returns it in a straight
+%% bitstring counting bits from left to right in a tuple with remaining data.
+%%
+%% In the MySQL null bitmap the bits are stored counting bytes from the left and
+%% bits within each byte from the right. (Sort of little endian.)
+-spec null_bitmap_decode(NumColumns :: integer(), BitOffset :: integer(),
+                         Data :: binary()) ->
+    {NullBitstring :: bitstring(), Rest :: binary()}.
+null_bitmap_decode(NumColumns, Data, BitOffset) ->
+    %% Binary shift right by 3 is equivalent to integer division by 8.
+    BitMapLength = (NumColumns + BitOffset + 7) bsr 3,
+    <<NullBitstring0:BitMapLength/binary, Rest/binary>> = Data,
+    <<_:BitOffset, NullBitstring:NumColumns/bitstring, _/bitstring>> =
+        << <<(reverse_byte(B))/binary>> || <<B:1/binary>> <= NullBitstring0 >>,
+    {NullBitstring, Rest}.
+
+%% @doc The reverse of null_bitmap_decode/3. The number of columns is taken to
+%% be the number of bits in NullBitstring. Returns the MySQL null bitmap as a
+%% binary (i.e. full bytes). BitOffset is the number of unused bits that should
+%% be inserted before the other bits.
+-spec null_bitmap_encode(bitstring(), integer()) -> binary().
+null_bitmap_encode(NullBitstring, BitOffset) ->
+    PayloadLength = bit_size(NullBitstring) + BitOffset,
+    %% Round up to a multiple of 8.
+    BitMapLength = (PayloadLength + 7) band bnot 7,
+    PadBitsLength = BitMapLength - PayloadLength,
+    PaddedBitstring = <<0:BitOffset, NullBitstring/bitstring, 0:PadBitsLength>>,
+    << <<(reverse_byte(B))/binary>> || <<B:1/binary>> <= PaddedBitstring >>.
+
+%% Reverses the bits in a byte.
+reverse_byte(<<A:1, B:1, C:1, D:1, E:1, F:1, G:1, H:1>>) ->
+    <<H:1, G:1, F:1, E:1, D:1, C:1, B:1, A:1>>.
+
+%% @doc Used for executing prepared statements. The bit offset would be 0 in
+%% this case.
+-spec build_null_bitmap([any()]) -> binary().
+build_null_bitmap(Values) ->
+    Bits = << <<(case V of null -> 1; _ -> 0 end):1/bits>> || V <- Values >>,
+    null_bitmap_encode(Bits, 0).
+
+%% Decodes a value as received in the 'binary protocol' result set.
+%%
 %% The types are type constants for the binary protocol, such as
 %% The types are type constants for the binary protocol, such as
-%% ProtocolBinary::MYSQL_TYPE_STRING. We assume that these are the same as for
-%% the text protocol.
--spec bin_protocol_decode(Type :: integer(), Data :: binary()) ->
+%% ProtocolBinary::MYSQL_TYPE_STRING. In the guide "MySQL Internals" these are
+%% not listed, but we assume that they are the same as for the text protocol.
+-spec decode_binary(Type :: integer(), Data :: binary()) ->
     {Term :: term(), Rest :: binary()}.
     {Term :: term(), Rest :: binary()}.
-bin_protocol_decode(T, Data)
+decode_binary(T, Data)
   when T == ?TYPE_STRING; T == ?TYPE_VARCHAR; T == ?TYPE_VAR_STRING;
   when T == ?TYPE_STRING; T == ?TYPE_VARCHAR; T == ?TYPE_VAR_STRING;
        T == ?TYPE_ENUM; T == ?TYPE_SET; T == ?TYPE_LONG_BLOB;
        T == ?TYPE_ENUM; T == ?TYPE_SET; T == ?TYPE_LONG_BLOB;
        T == ?TYPE_MEDIUM_BLOB; T == ?TYPE_BLOB; T == ?TYPE_TINY_BLOB;
        T == ?TYPE_MEDIUM_BLOB; T == ?TYPE_BLOB; T == ?TYPE_TINY_BLOB;
        T == ?TYPE_GEOMETRY; T == ?TYPE_BIT; T == ?TYPE_DECIMAL;
        T == ?TYPE_GEOMETRY; T == ?TYPE_BIT; T == ?TYPE_DECIMAL;
        T == ?TYPE_NEWDECIMAL ->
        T == ?TYPE_NEWDECIMAL ->
     lenenc_str(Data);
     lenenc_str(Data);
-bin_protocol_decode(?TYPE_LONGLONG, <<Value:64/little, Rest/binary>>) ->
+decode_binary(?TYPE_LONGLONG, <<Value:64/little, Rest/binary>>) ->
     {Value, Rest};
     {Value, Rest};
-bin_protocol_decode(T, <<Value:32/little, Rest/binary>>)
+decode_binary(T, <<Value:32/little, Rest/binary>>)
   when T == ?TYPE_LONG; T == ?TYPE_INT24 ->
   when T == ?TYPE_LONG; T == ?TYPE_INT24 ->
     {Value, Rest};
     {Value, Rest};
-bin_protocol_decode(T, <<Value:16/little, Rest/binary>>)
+decode_binary(T, <<Value:16/little, Rest/binary>>)
   when T == ?TYPE_SHORT; T == ?TYPE_YEAR ->
   when T == ?TYPE_SHORT; T == ?TYPE_YEAR ->
     {Value, Rest};
     {Value, Rest};
-bin_protocol_decode(?TYPE_TINY, <<Value:8, Rest/binary>>) ->
+decode_binary(?TYPE_TINY, <<Value:8, Rest/binary>>) ->
     {Value, Rest};
     {Value, Rest};
-bin_protocol_decode(?TYPE_DOUBLE, <<Value:64/float-little, Rest/binary>>) ->
+decode_binary(?TYPE_DOUBLE, <<Value:64/float-little, Rest/binary>>) ->
     {Value, Rest};
     {Value, Rest};
-bin_protocol_decode(?TYPE_FLOAT, <<Value:32/float-little, Rest/binary>>) ->
+decode_binary(?TYPE_FLOAT, <<Value:32/float-little, Rest/binary>>) ->
     {Value, Rest};
     {Value, Rest};
-bin_protocol_decode(?TYPE_DATE, <<Length, Data/binary>>) ->
+decode_binary(?TYPE_DATE, <<Length, Data/binary>>) ->
     %% Coded in the same way as DATETIME and TIMESTAMP below, but returned in
     %% Coded in the same way as DATETIME and TIMESTAMP below, but returned in
     %% a simple triple.
     %% a simple triple.
     case {Length, Data} of
     case {Length, Data} of
         {0, _} -> {{0, 0, 0}, Data};
         {0, _} -> {{0, 0, 0}, Data};
         {4, <<Y:16/little, M, D, Rest/binary>>} -> {{Y, M, D}, Rest}
         {4, <<Y:16/little, M, D, Rest/binary>>} -> {{Y, M, D}, Rest}
     end;
     end;
-bin_protocol_decode(T, <<Length, Data/binary>>)
+decode_binary(T, <<Length, Data/binary>>)
   when T == ?TYPE_DATETIME; T == ?TYPE_TIMESTAMP ->
   when T == ?TYPE_DATETIME; T == ?TYPE_TIMESTAMP ->
     %% length (1) -- number of bytes following (valid values: 0, 4, 7, 11)
     %% length (1) -- number of bytes following (valid values: 0, 4, 7, 11)
     case {Length, Data} of
     case {Length, Data} of
@@ -411,7 +526,7 @@ bin_protocol_decode(T, <<Length, Data/binary>>)
         {11, <<Y:16/little, M, D, H, Mi, S, Micro:32/little, Rest/binary>>} ->
         {11, <<Y:16/little, M, D, H, Mi, S, Micro:32/little, Rest/binary>>} ->
             {{{Y, M, D}, {H, Mi, S + 0.000001 * Micro}}, Rest}
             {{{Y, M, D}, {H, Mi, S + 0.000001 * Micro}}, Rest}
     end;
     end;
-bin_protocol_decode(?TYPE_TIME, <<Length, Data/binary>>) ->
+decode_binary(?TYPE_TIME, <<Length, Data/binary>>) ->
     %% length (1) -- number of bytes following (valid values: 0, 8, 12)
     %% length (1) -- number of bytes following (valid values: 0, 8, 12)
     %% is_negative (1) -- (1 if minus, 0 for plus)
     %% is_negative (1) -- (1 if minus, 0 for plus)
     %% days (4) -- days
     %% days (4) -- days
@@ -429,29 +544,26 @@ bin_protocol_decode(?TYPE_TIME, <<Length, Data/binary>>) ->
              Rest}
              Rest}
     end.
     end.
 
 
-%% Parses a packet containing a column definition (part of a result set)
-parse_column_definition(Data) ->
-    {<<"def">>, Rest1} = lenenc_str(Data),   %% catalog (always "def")
-    {_Schema, Rest2} = lenenc_str(Rest1),    %% schema-name 
-    {_Table, Rest3} = lenenc_str(Rest2),     %% virtual table-name 
-    {_OrgTable, Rest4} = lenenc_str(Rest3),  %% physical table-name 
-    {Name, Rest5} = lenenc_str(Rest4),       %% virtual column name
-    {_OrgName, Rest6} = lenenc_str(Rest5),   %% physical column name
-    {16#0c, Rest7} = lenenc_int(Rest6),      %% length of the following fields
-                                             %% (always 0x0c)
-    <<Charset:16/little,        %% column character set
-      _ColumnLength:32/little,  %% maximum length of the field
-      ColumnType:8,             %% type of the column as defined in Column Type
-      _Flags:16/little,         %% flags
-      _Decimals:8,              %% max shown decimal digits:
-      0,  %% "filler"           %%   - 0x00 for integers and static strings
-      0,                        %%   - 0x1f for dynamic strings, double, float
-      Rest8/binary>> = Rest7,   %%   - 0x00 to 0x51 for decimals
-    %% Here, if command was COM_FIELD_LIST {
-    %%   default values: lenenc_str
-    %% }
-    <<>> = Rest8,
-    #column_definition{name = Name, type = ColumnType, charset = Charset}.
+%% @doc Encodes a term representing a value of type Type as a binary for use
+%% in the binary protocol.
+-spec encode_binary(Type :: integer(), Value :: term()) -> binary().
+encode_binary(_Type, null) ->
+    <<>>;
+encode_binary(T, Value)
+  when T == ?TYPE_STRING; T == ?TYPE_VARCHAR; T == ?TYPE_VAR_STRING;
+       T == ?TYPE_ENUM; T == ?TYPE_SET; T == ?TYPE_LONG_BLOB;
+       T == ?TYPE_MEDIUM_BLOB; T == ?TYPE_BLOB; T == ?TYPE_TINY_BLOB;
+       T == ?TYPE_GEOMETRY; T == ?TYPE_BIT; T == ?TYPE_DECIMAL;
+       T == ?TYPE_NEWDECIMAL ->
+    build_lenenc_str(Value);
+encode_binary(_T, _Value) ->
+    fixme = todo.
+
+%% Rename this and lenenc_str (the decode function)
+build_lenenc_str(_Value) ->
+    ok = fixme.
+
+%% -- Protocol basics: packets --
 
 
 %% @doc Wraps Data in packet headers, sends it by calling SendFun and returns
 %% @doc Wraps Data in packet headers, sends it by calling SendFun and returns
 %% {ok, SeqNum1} where SeqNum1 is the next sequence number.
 %% {ok, SeqNum1} where SeqNum1 is the next sequence number.
@@ -609,6 +721,54 @@ nulterm_str(Bin) ->
 %% Testing some of the internal functions, mostly the cases we don't cover in
 %% Testing some of the internal functions, mostly the cases we don't cover in
 %% other tests.
 %% other tests.
 
 
+decode_text_test() ->
+    %% Int types
+    lists:foreach(fun (T) -> ?assertEqual(1, decode_text(T, <<"1">>)) end,
+                  [?TYPE_TINY, ?TYPE_SHORT, ?TYPE_LONG, ?TYPE_LONGLONG,
+                   ?TYPE_INT24, ?TYPE_YEAR, ?TYPE_BIT]),
+
+    %% Floating point numbers
+    lists:foreach(fun (T) -> ?assertEqual(3.0, decode_text(T, <<"3.0">>)) end,
+                  [?TYPE_FLOAT, ?TYPE_DOUBLE]),
+    %% Decimal types
+    lists:foreach(fun (T) ->
+                      ?assertEqual(<<"3.0">>, decode_text(T, <<"3.0">>))
+                  end,
+                  [?TYPE_DECIMAL, ?TYPE_NEWDECIMAL]),
+    ?assertEqual(3.0,  decode_text(?TYPE_FLOAT, <<"3">>)),
+    ?assertEqual(30.0, decode_text(?TYPE_FLOAT, <<"3e1">>)),
+    ?assertEqual(3,    decode_text(?TYPE_LONG, <<"3">>)),
+
+    %% Date and time
+    ?assertEqual({2014, 11, 01}, decode_text(?TYPE_DATE, <<"2014-11-01">>)),
+    ?assertEqual({23, 59, 01}, decode_text(?TYPE_TIME, <<"23:59:01">>)),
+    ?assertEqual({{2014, 11, 01}, {23, 59, 01}},
+                 decode_text(?TYPE_DATETIME, <<"2014-11-01 23:59:01">>)),
+    ?assertEqual({{2014, 11, 01}, {23, 59, 01}},
+                 decode_text(?TYPE_TIMESTAMP, <<"2014-11-01 23:59:01">>)),
+
+    %% Strings and blobs
+    lists:foreach(fun (T) ->
+                      ?assertEqual(<<"x">>, decode_text(T, <<"x">>))
+                  end,
+                  [?TYPE_VARCHAR, ?TYPE_ENUM, ?TYPE_TINY_BLOB,
+                   ?TYPE_MEDIUM_BLOB, ?TYPE_LONG_BLOB, ?TYPE_BLOB,
+                   ?TYPE_VAR_STRING, ?TYPE_STRING, ?TYPE_GEOMETRY]),
+
+    %% Set
+    ?assertEqual(sets:from_list([<<"b">>, <<"a">>]),
+                 decode_text(?TYPE_SET, <<"a,b">>)),
+    ?assertEqual(sets:from_list([]), decode_text(?TYPE_SET, <<>>)),
+
+    %% NULL
+    ?assertEqual(null, decode_text(?TYPE_FLOAT, null)),
+    ok.
+
+null_bitmap_test() ->
+    ?assertEqual({<<0, 1:1>>, <<>>}, null_bitmap_decode(9, <<0, 4>>, 2)),
+    ?assertEqual(<<0, 4>>, null_bitmap_encode(<<0, 1:1>>, 2)),
+    ok.
+
 lenenc_int_test() ->
 lenenc_int_test() ->
     ?assertEqual({40, <<>>}, lenenc_int(<<40>>)),
     ?assertEqual({40, <<>>}, lenenc_int(<<40>>)),
     ?assertEqual({16#ff, <<>>}, lenenc_int(<<16#fc, 255, 0>>)),
     ?assertEqual({16#ff, <<>>}, lenenc_int(<<16#fc, 255, 0>>)),

+ 0 - 124
src/mysql_text_protocol.erl

@@ -1,124 +0,0 @@
-%% MySQL/OTP – a MySQL driver for Erlang/OTP
-%% Copyright (C) 2014 Viktor Söderqvist
-%%
-%% This program is free software: you can redistribute it and/or modify
-%% it under the terms of the GNU General Public License as published by
-%% the Free Software Foundation, either version 3 of the License, or
-%% (at your option) any later version.
-%%
-%% This program is distributed in the hope that it will be useful,
-%% but WITHOUT ANY WARRANTY; without even the implied warranty of
-%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-%% GNU General Public License for more details.
-%%
-%% You should have received a copy of the GNU General Public License
-%% along with this program. If not, see <https://www.gnu.org/licenses/>.
-
-%% @doc This module handles conversion of values in the form they are
-%% represented in the text protocol to our preferred Erlang term representations.
--module(mysql_text_protocol).
-
--export([text_to_term/2]).
-
--include("records.hrl").
--include("protocol.hrl"). %% The TYPE_* macros.
-
-%% @doc When receiving data in the text protocol, we get everything as binaries
-%% (except NULL). This function is used to parse these string values.
-text_to_term(Type, Text) when is_binary(Text) ->
-    case Type of
-        ?TYPE_DECIMAL -> parse_float(Text); %% <-- this will probably change
-        ?TYPE_TINY -> binary_to_integer(Text);
-        ?TYPE_SHORT -> binary_to_integer(Text);
-        ?TYPE_LONG -> binary_to_integer(Text);
-        ?TYPE_FLOAT -> parse_float(Text);
-        ?TYPE_DOUBLE -> parse_float(Text);
-        ?TYPE_TIMESTAMP -> parse_datetime(Text);
-        ?TYPE_LONGLONG -> binary_to_integer(Text);
-        ?TYPE_INT24 -> binary_to_integer(Text);
-        ?TYPE_DATE -> parse_date(Text);
-        ?TYPE_TIME -> parse_time(Text);
-        ?TYPE_DATETIME -> parse_datetime(Text);
-        ?TYPE_YEAR -> binary_to_integer(Text);
-        ?TYPE_VARCHAR -> Text;
-        ?TYPE_BIT -> binary_to_integer(Text);
-        ?TYPE_NEWDECIMAL -> parse_float(Text); %% <-- this will probably change
-        ?TYPE_ENUM -> Text;
-        ?TYPE_SET when Text == <<>> -> sets:new();
-        ?TYPE_SET -> sets:from_list(binary:split(Text, <<",">>, [global]));
-        ?TYPE_TINY_BLOB -> Text; %% charset?
-        ?TYPE_MEDIUM_BLOB -> Text;
-        ?TYPE_LONG_BLOB -> Text;
-        ?TYPE_BLOB -> Text;
-        ?TYPE_VAR_STRING -> Text;
-        ?TYPE_STRING -> Text;
-        ?TYPE_GEOMETRY -> Text %% <-- what do we want here?
-    end;
-text_to_term(_, null) ->
-    %% NULL is the only value not represented as a binary.
-    null.
-
-parse_datetime(<<Y:4/binary, "-", M:2/binary, "-", D:2/binary, " ",
-                 H:2/binary, ":", Mi:2/binary, ":", S:2/binary>>) ->
-    {{binary_to_integer(Y), binary_to_integer(M), binary_to_integer(D)},
-     {binary_to_integer(H), binary_to_integer(Mi), binary_to_integer(S)}}.
-
-parse_date(<<Y:4/binary, "-", M:2/binary, "-", D:2/binary>>) ->
-    {binary_to_integer(Y), binary_to_integer(M), binary_to_integer(D)}.
-
-parse_time(<<H:2/binary, ":", Mi:2/binary, ":", S:2/binary>>) ->
-    {binary_to_integer(H), binary_to_integer(Mi), binary_to_integer(S)}.
-
-parse_float(Text) ->
-    try binary_to_float(Text)
-    catch error:badarg ->
-        try binary_to_integer(Text) of
-            Int -> float(Int)
-        catch error:badarg ->
-            %% It is something like "4e75" that must be turned into "4.0e75"
-            binary_to_float(binary:replace(Text, <<"e">>, <<".0e">>))
-        end
-    end.
-
--ifdef(TEST).
--include_lib("eunit/include/eunit.hrl").
-
-text_to_term_test() ->
-    %% Int types
-    lists:foreach(fun (T) -> ?assertEqual(1, text_to_term(T, <<"1">>)) end,
-                  [?TYPE_TINY, ?TYPE_SHORT, ?TYPE_LONG, ?TYPE_LONGLONG,
-                   ?TYPE_INT24, ?TYPE_YEAR, ?TYPE_BIT]),
-
-    %% Floating point and decimal numbers
-    lists:foreach(fun (T) -> ?assertEqual(3.0, text_to_term(T, <<"3.0">>)) end,
-                  [?TYPE_FLOAT, ?TYPE_DOUBLE, ?TYPE_DECIMAL, ?TYPE_NEWDECIMAL]),
-    ?assertEqual(3.0,  text_to_term(?TYPE_FLOAT, <<"3">>)),
-    ?assertEqual(30.0, text_to_term(?TYPE_FLOAT, <<"3e1">>)),
-    ?assertEqual(3,    text_to_term(?TYPE_LONG, <<"3">>)),
-
-    %% Date and time
-    ?assertEqual({2014, 11, 01}, text_to_term(?TYPE_DATE, <<"2014-11-01">>)),
-    ?assertEqual({23, 59, 01}, text_to_term(?TYPE_TIME, <<"23:59:01">>)),
-    ?assertEqual({{2014, 11, 01}, {23, 59, 01}},
-                 text_to_term(?TYPE_DATETIME, <<"2014-11-01 23:59:01">>)),
-    ?assertEqual({{2014, 11, 01}, {23, 59, 01}},
-                 text_to_term(?TYPE_TIMESTAMP, <<"2014-11-01 23:59:01">>)),
-
-    %% Strings and blobs
-    lists:foreach(fun (T) ->
-                      ?assertEqual(<<"x">>, text_to_term(T, <<"x">>))
-                  end,
-                  [?TYPE_VARCHAR, ?TYPE_ENUM, ?TYPE_TINY_BLOB,
-                   ?TYPE_MEDIUM_BLOB, ?TYPE_LONG_BLOB, ?TYPE_BLOB,
-                   ?TYPE_VAR_STRING, ?TYPE_STRING, ?TYPE_GEOMETRY]),
-
-    %% Set
-    ?assertEqual(sets:from_list([<<"b">>, <<"a">>]),
-                 text_to_term(?TYPE_SET, <<"a,b">>)),
-    ?assertEqual(sets:from_list([]), text_to_term(?TYPE_SET, <<>>)),
-
-    %% NULL
-    ?assertEqual(null, text_to_term(?TYPE_FLOAT, null)),
-    ok.
-
--endif.

+ 6 - 5
test/mysql_tests.erl

@@ -80,11 +80,12 @@ text_protocol(Pid) ->
     ?assertEqual(1, mysql:affected_rows(Pid)),
     ?assertEqual(1, mysql:affected_rows(Pid)),
 
 
     %% select
     %% select
-    ?assertEqual({ok, [<<"id">>, <<"bl">>, <<"tx">>, <<"f">>, <<"dc">>,
-                       <<"ti">>, <<"ts">>, <<"da">>, <<"c">>],
-                      [[1, <<"blob">>, <<>>, 3.14, 3.14, {0, 22, 11},
-                        {{2014, 11, 03}, {00, 22, 24}}, {2014, 11, 03}, null]]},
-                 mysql:query(Pid, <<"SELECT * FROM t">>)),
+    {ok, Columns, Rows} = mysql:query(Pid, <<"SELECT * FROM t">>),
+    ?assertEqual([<<"id">>, <<"bl">>, <<"tx">>, <<"f">>, <<"dc">>, <<"ti">>,
+                  <<"ts">>, <<"da">>, <<"c">>], Columns),
+    ?assertEqual([[1, <<"blob">>, <<>>, 3.14, <<"3.140">>, {0, 22, 11},
+                   {{2014, 11, 03}, {00, 22, 24}}, {2014, 11, 03}, null]],
+                 Rows),
     ok.
     ok.
 
 
 binary_protocol(Pid) ->
 binary_protocol(Pid) ->