Viktor Söderqvist 10 лет назад
Родитель
Сommit
d1260adf16
5 измененных файлов с 239 добавлено и 186 удалено
  1. 2 2
      include/records.hrl
  2. 14 10
      src/mysql_connection.erl
  3. 149 100
      src/mysql_protocol.erl
  4. 1 3
      test/mysql_protocol_tests.erl
  5. 73 71
      test/mysql_tests.erl

+ 2 - 2
include/records.hrl

@@ -40,12 +40,12 @@
 -record(eof, {status, warning_count}).
 
 %% Column definition, used while parsing a result set.
--record(column_definition, {name, type, charset}).
+-record(col, {name, type, charset, length, decimals, flags}).
 
 %% A resultset. The rows can be either lists of terms or unparsed binaries as
 %% received from the server using either the text protocol or the binary
 %% protocol.
--record(resultset, {column_definitions :: [#column_definition{}],
+-record(resultset, {cols :: [#col{}],
                     rows :: [[term()] | binary()]}).
 
 %% Response of a successfull prepare call.

+ 14 - 10
src/mysql_connection.erl

@@ -99,11 +99,12 @@ handle_call({query, Query}, _From, State) when is_binary(Query);
             {reply, ok, State1};
         #error{} = E ->
             {reply, {error, error_to_reason(E)}, State1};
-        #resultset{column_definitions = ColDefs, rows = Rows} ->
-            Names = [Def#column_definition.name || Def <- ColDefs],
+        #resultset{cols = ColDefs, rows = Rows} ->
+            Names = [Def#col.name || Def <- ColDefs],
             {reply, {ok, Names, Rows}, State1}
     end;
-handle_call({execute, Stmt, Args}, _From, State) ->
+handle_call({execute, Stmt, Args}, _From, State) when is_atom(Stmt);
+                                                      is_integer(Stmt) ->
     case dict:find(Stmt, State#state.stmts) of
         {ok, StmtRec} ->
             #state{socket = Socket, timeout = Timeout} = State,
@@ -116,8 +117,8 @@ handle_call({execute, Stmt, Args}, _From, State) ->
                     {reply, ok, State1};
                 #error{} = E ->
                     {reply, {error, error_to_reason(E)}, State1};
-                #resultset{column_definitions = ColDefs, rows = Rows} ->
-                    Names = [Def#column_definition.name || Def <- ColDefs],
+                #resultset{cols = ColDefs, rows = Rows} ->
+                    Names = [Def#col.name || Def <- ColDefs],
                     {reply, {ok, Names, Rows}, State1}
             end;
         error ->
@@ -159,14 +160,15 @@ handle_call({prepare, Name, Query}, _From, State) when is_atom(Name) ->
             State3 = State2#state{stmts = Stmts1},
             {reply, {ok, Name}, State3}
     end;
-handle_call({unprepare, Name}, _From, State) ->
-    case dict:find(Name, State#state.stmts) of
+handle_call({unprepare, Stmt}, _From, State) when is_atom(Stmt);
+                                                  is_integer(Stmt) ->
+    case dict:find(Stmt, State#state.stmts) of
         {ok, StmtRec} ->
             #state{socket = Socket, timeout = Timeout} = State,
             SendFun = fun (Data) -> gen_tcp:send(Socket, Data) end,
             RecvFun = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
             mysql_protocol:unprepare(StmtRec, SendFun, RecvFun),
-            Stmts1 = dict:erase(Name, State#state.stmts),
+            Stmts1 = dict:erase(Stmt, State#state.stmts),
             {reply, ok, State#state{stmts = Stmts1}};
         error ->
             {reply, {error, not_prepared}, State}
@@ -193,12 +195,14 @@ handle_cast(_Msg, State) ->
 handle_info(_Info, State) ->
     {noreply, State}.
 
-terminate(_Reason, State) ->
+terminate(Reason, State) when Reason == normal; Reason == shutdown ->
     %% Send the goodbye message for politeness.
     #state{socket = Socket, timeout = Timeout} = State,
     SendFun = fun (Data) -> gen_tcp:send(Socket, Data) end,
     RecvFun = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
-    mysql_protocol:quit(SendFun, RecvFun).
+    mysql_protocol:quit(SendFun, RecvFun);
+terminate(_Reason, _State) ->
+    ok.
 
 code_change(_OldVsn, State, _Extra) ->
     {ok, State}.

+ 149 - 100
src/mysql_protocol.erl

@@ -23,6 +23,7 @@
 %%
 %% TCP communication is not handled in this module. Most of the public functions
 %% take funs for data communitaction as parameters.
+%% @private
 -module(mysql_protocol).
 
 -export([handshake/5, quit/2,
@@ -88,12 +89,10 @@ query(Query, SendFun, RecvFun) ->
             case fetch_resultset(RecvFun, ColumnCount, SeqNum2) of
                 #error{} = E ->
                     E;
-                #resultset{column_definitions = ColDefs, rows = Rows} = R ->
+                #resultset{cols = ColDefs, rows = Rows} = R ->
                     %% Parse the rows according to the 'text protocol'
                     %% representation.
-                    ColumnTypes = [ColDef#column_definition.type
-                                   || ColDef <- ColDefs],
-                    Rows1 = [decode_text_row(ColumnCount, ColumnTypes, Row)
+                    Rows1 = [decode_text_row(ColumnCount, ColDefs, Row)
                              || Row <- Rows],
                     R#resultset{rows = Rows1}
             end
@@ -182,12 +181,10 @@ execute(#prepared{statement_id = Id, param_count = ParamCount}, ParamValues,
                     %% This can happen for the text protocol but maybe not for
                     %% the binary protocol.
                     E;
-                #resultset{column_definitions = ColDefs, rows = Rows} = R ->
+                #resultset{cols = ColDefs, rows = Rows} = R ->
                     %% Parse the rows according to the 'binary protocol'
                     %% representation.
-                    ColumnTypes = [ColDef#column_definition.type
-                                   || ColDef <- ColDefs],
-                    Rows1 = [decode_binary_row(ColumnCount, ColumnTypes, Row)
+                    Rows1 = [decode_binary_row(ColumnCount, ColDefs, Row)
                              || Row <- Rows],
                     R#resultset{rows = Rows1}
             end
@@ -293,7 +290,12 @@ parse_handshake_confirm(Packet) ->
             error(auth_method_switch)
     end.
 
-%% Fetches packets until a
+%% -- both text and binary protocol --
+
+%% @doc Fetches packets for a result set. The column definitions are parsed but
+%% the rows are unparsed binary packages. This function is used for both the
+%% text protocol and the binary protocol. This affects the way the rows need to
+%% be parsed.
 -spec fetch_resultset(recvfun(), integer(), integer()) ->
     #resultset{} | #error{}.
 fetch_resultset(RecvFun, FieldCount, SeqNum) ->
@@ -303,21 +305,20 @@ fetch_resultset(RecvFun, FieldCount, SeqNum) ->
     #eof{} = parse_eof_packet(DelimiterPacket),
     case fetch_resultset_rows(RecvFun, SeqNum2, []) of
         {ok, Rows, _SeqNum3} ->
-            #resultset{column_definitions = ColDefs, rows = Rows};
+            ColDefs1 = lists:map(fun parse_column_definition/1, ColDefs),
+            #resultset{cols = ColDefs1, rows = Rows};
         #error{} = E ->
             E
     end.
 
-%% @doc Receives NumLeft packets and parses them as column definitions.
-%% TODO: Don't parse them here. That's a sepatate thing we not always need to
-%% do.
+%% @doc Receives NumLeft column definition packets. They are not parsed.
+%% @see parse_column_definition/1
 -spec fetch_column_definitions(recvfun(), SeqNum :: integer(),
-                               NumLeft :: integer(), Acc :: [tuple()]) ->
-    {ok, [tuple()], NextSeqNum :: integer()}.
+                               NumLeft :: integer(), Acc :: [binary()]) ->
+    {ok, ColDefPackets :: [binary()], NextSeqNum :: integer()}.
 fetch_column_definitions(RecvFun, SeqNum, NumLeft, Acc) when NumLeft > 0 ->
     {ok, Packet, SeqNum1} = recv_packet(RecvFun, SeqNum),
-    ColDef = parse_column_definition(Packet),
-    fetch_column_definitions(RecvFun, SeqNum1, NumLeft - 1, [ColDef | Acc]);
+    fetch_column_definitions(RecvFun, SeqNum1, NumLeft - 1, [Packet | Acc]);
 fetch_column_definitions(_RecvFun, SeqNum, 0, Acc) ->
     {ok, lists:reverse(Acc), SeqNum}.
 
@@ -339,8 +340,6 @@ fetch_resultset_rows(RecvFun, SeqNum, Acc) ->
             fetch_resultset_rows(RecvFun, SeqNum1, [Row | Acc])
     end.
 
-%% -- both text and binary protocol --
-
 %% Parses a packet containing a column definition (part of a result set)
 parse_column_definition(Data) ->
     {<<"def">>, Rest1} = lenenc_str(Data),   %% catalog (always "def")
@@ -352,10 +351,10 @@ parse_column_definition(Data) ->
     {16#0c, Rest7} = lenenc_int(Rest6),      %% length of the following fields
                                              %% (always 0x0c)
     <<Charset:16/little,        %% column character set
-      _ColumnLength:32/little,  %% maximum length of the field
-      ColumnType:8,             %% type of the column as defined in Column Type
-      _Flags:16/little,         %% flags
-      _Decimals:8,              %% max shown decimal digits:
+      Length:32/little,         %% maximum length of the field
+      Type:8,                   %% type of the column as defined in Column Type
+      Flags:16/little,          %% flags
+      Decimals:8,               %% max shown decimal digits:
       0,  %% "filler"           %%   - 0x00 for integers and static strings
       0,                        %%   - 0x1f for dynamic strings, double, float
       Rest8/binary>> = Rest7,   %%   - 0x00 to 0x51 for decimals
@@ -363,26 +362,28 @@ parse_column_definition(Data) ->
     %%   default values: lenenc_str
     %% }
     <<>> = Rest8,
-    #column_definition{name = Name, type = ColumnType, charset = Charset}.
+    #col{name = Name, type = Type, charset = Charset, length = Length,
+         decimals = Decimals, flags = Flags}.
 
 %% -- text protocol --
 
--spec decode_text_row(NumColumns :: integer(), ColumnTypes :: integer(),
+-spec decode_text_row(NumColumns :: integer(),
+                      ColumnDefinitions :: [#col{}],
                       Data :: binary()) -> [term()].
-decode_text_row(_NumColumns, ColumnTypes, Data) ->
-    decode_text_row_acc(ColumnTypes, Data, []).
+decode_text_row(_NumColumns, ColumnDefs, Data) ->
+    decode_text_row_acc(ColumnDefs, Data, []).
 
 %% parses Data using ColDefs and builds the values Acc.
-decode_text_row_acc([Type | Types], Data, Acc) ->
+decode_text_row_acc([ColDef | ColDefs], Data, Acc) ->
     case Data of
         <<16#fb, Rest/binary>> ->
             %% NULL
-            decode_text_row_acc(Types, Rest, [null | Acc]);
+            decode_text_row_acc(ColDefs, Rest, [null | Acc]);
         _ ->
             %% Every thing except NULL
             {Text, Rest} = lenenc_str(Data),
-            Term = decode_text(Type, Text),
-            decode_text_row_acc(Types, Rest, [Term | Acc])
+            Term = decode_text(ColDef, Text),
+            decode_text_row_acc(ColDefs, Rest, [Term | Acc])
     end;
 decode_text_row_acc([], <<>>, Acc) ->
     lists:reverse(Acc).
@@ -392,20 +393,25 @@ decode_text_row_acc([], <<>>, Acc) ->
 decode_text(_, null) ->
     %% NULL is the only value not represented as a binary.
     null;
-decode_text(T, Text)
+decode_text(#col{type = T}, Text)
   when T == ?TYPE_TINY; T == ?TYPE_SHORT; T == ?TYPE_LONG; T == ?TYPE_LONGLONG;
-       T == ?TYPE_INT24; T == ?TYPE_YEAR; T == ?TYPE_BIT ->
-    %% For BIT, do we want bitstring, int or binary?
+       T == ?TYPE_INT24; T == ?TYPE_YEAR ->
     binary_to_integer(Text);
-decode_text(T, Text)
-  when T == ?TYPE_DECIMAL; T == ?TYPE_NEWDECIMAL; T == ?TYPE_VARCHAR;
-       T == ?TYPE_ENUM; T == ?TYPE_TINY_BLOB; T == ?TYPE_MEDIUM_BLOB;
-       T == ?TYPE_LONG_BLOB; T == ?TYPE_BLOB; T == ?TYPE_VAR_STRING;
-       T == ?TYPE_STRING; T == ?TYPE_GEOMETRY ->
+decode_text(#col{type = T}, Text)
+  when T == ?TYPE_STRING; T == ?TYPE_VARCHAR; T == ?TYPE_VAR_STRING;
+       T == ?TYPE_ENUM; T == ?TYPE_SET; T == ?TYPE_LONG_BLOB;
+       T == ?TYPE_MEDIUM_BLOB; T == ?TYPE_BLOB; T == ?TYPE_TINY_BLOB;
+       T == ?TYPE_GEOMETRY; T == ?TYPE_DECIMAL; T == ?TYPE_NEWDECIMAL ->
+    %% As of MySQL 5.6.21 we receive SET and ENUM values as STRING, i.e. we
+    %% cannot convert them to atom() or sets:set(), etc.
+    Text;
+decode_text(#col{type = ?TYPE_BIT, length = _Length}, Text) ->
+    %% TODO: Convert to <<_:Length/bitstring>>
     Text;
-decode_text(?TYPE_DATE, <<Y:4/binary, "-", M:2/binary, "-", D:2/binary>>) ->
+decode_text(#col{type = ?TYPE_DATE},
+            <<Y:4/binary, "-", M:2/binary, "-", D:2/binary>>) ->
     {binary_to_integer(Y), binary_to_integer(M), binary_to_integer(D)};
-decode_text(?TYPE_TIME, Text) ->
+decode_text(#col{type = ?TYPE_TIME}, Text) ->
     {match, [Sign, Hbin, Mbin, Sbin, Frac]} =
         re:run(Text,
                <<"^(-?)(\\d+):(\\d+):(\\d+)(\\.?\\d*)$">>,
@@ -426,19 +432,22 @@ decode_text(?TYPE_TIME, Text) ->
            end,
     {Days, {Hours, Minutes, Seconds}} = calendar:seconds_to_daystime(Sec3),
     {Days, {Hours, Minutes, Seconds + Fraction}};
-decode_text(T, <<Y:4/binary, "-", M:2/binary, "-", D:2/binary, " ",
-                 H:2/binary, ":", Mi:2/binary, ":", S:2/binary>>)
+decode_text(#col{type = T},
+            <<Y:4/binary, "-", M:2/binary, "-", D:2/binary, " ",
+              H:2/binary, ":", Mi:2/binary, ":", S:2/binary>>)
   when T == ?TYPE_TIMESTAMP; T == ?TYPE_DATETIME ->
     %% Without fractions.
     {{binary_to_integer(Y), binary_to_integer(M), binary_to_integer(D)},
      {binary_to_integer(H), binary_to_integer(Mi), binary_to_integer(S)}};
-decode_text(T, <<Y:4/binary, "-", M:2/binary, "-", D:2/binary, " ",
-                 H:2/binary, ":", Mi:2/binary, ":", FloatS/binary>>)
+decode_text(#col{type = T},
+            <<Y:4/binary, "-", M:2/binary, "-", D:2/binary, " ",
+              H:2/binary, ":", Mi:2/binary, ":", FloatS/binary>>)
   when T == ?TYPE_TIMESTAMP; T == ?TYPE_DATETIME ->
     %% With fractions.
     {{binary_to_integer(Y), binary_to_integer(M), binary_to_integer(D)},
      {binary_to_integer(H), binary_to_integer(Mi), binary_to_float(FloatS)}};
-decode_text(T, Text) when T == ?TYPE_FLOAT; T == ?TYPE_DOUBLE ->
+decode_text(#col{type = T}, Text) when T == ?TYPE_FLOAT;
+                                                     T == ?TYPE_DOUBLE ->
     try binary_to_float(Text)
     catch error:badarg ->
         try binary_to_integer(Text) of
@@ -447,11 +456,7 @@ decode_text(T, Text) when T == ?TYPE_FLOAT; T == ?TYPE_DOUBLE ->
             %% It is something like "4e75" that must be turned into "4.0e75"
             binary_to_float(binary:replace(Text, <<"e">>, <<".0e">>))
         end
-    end;
-decode_text(?TYPE_SET, <<>>) ->
-    sets:new();
-decode_text(?TYPE_SET, Text) ->
-    sets:from_list(binary:split(Text, <<",">>, [global])).
+    end.
 
 %% -- binary protocol --
 
@@ -468,21 +473,22 @@ fetch_column_definitions_if_any(N, RecvFun, SeqNum) ->
 %% It consists of a 0 byte, then a null bitmap, then the values.
 %% Returns a list of length NumColumns with terms of appropriate types for each
 %% MySQL type in ColumnTypes.
--spec decode_binary_row(NumColumns :: integer(), ColumnTypes :: [integer()],
-                 Data :: binary()) -> [term()].
-decode_binary_row(NumColumns, ColumnTypes, <<0, Data/binary>>) ->
+-spec decode_binary_row(NumColumns :: integer(),
+                        ColumnDefs :: [#col{}],
+                        Data :: binary()) -> [term()].
+decode_binary_row(NumColumns, ColumnDefs, <<0, Data/binary>>) ->
     {NullBitMap, Rest} = null_bitmap_decode(NumColumns, Data, 2),
-    decode_binary_row_acc(ColumnTypes, NullBitMap, Rest, []).
+    decode_binary_row_acc(ColumnDefs, NullBitMap, Rest, []).
 
 %% @doc Accumulating helper for decode_binary_row/3.
-decode_binary_row_acc([_ | Types], <<1:1, NullBitMap/bitstring>>, Data, Acc) ->
+decode_binary_row_acc([_ | ColDefs], <<1:1, NullBitMap/bitstring>>, Data, Acc) ->
     %% NULL
-    decode_binary_row_acc(Types, NullBitMap, Data, [null | Acc]);
-decode_binary_row_acc([Type | Types], <<0:1, NullBitMap/bitstring>>, Data,
+    decode_binary_row_acc(ColDefs, NullBitMap, Data, [null | Acc]);
+decode_binary_row_acc([ColDef | ColDefs], <<0:1, NullBitMap/bitstring>>, Data,
                       Acc) ->
-   %% Not NULL
-   {Term, Rest} = decode_binary(Type, Data),
-   decode_binary_row_acc(Types, NullBitMap, Rest, [Term | Acc]);
+    %% Not NULL
+    {Term, Rest} = decode_binary(ColDef, Data),
+    decode_binary_row_acc(ColDefs, NullBitMap, Rest, [Term | Acc]);
 decode_binary_row_acc([], _, <<>>, Acc) ->
     lists:reverse(Acc).
 
@@ -531,28 +537,33 @@ build_null_bitmap(Values) ->
 %% The types are type constants for the binary protocol, such as
 %% ProtocolBinary::MYSQL_TYPE_STRING. In the guide "MySQL Internals" these are
 %% not listed, but we assume that are the same as for the text protocol.
--spec decode_binary(Type :: integer(), Data :: binary()) ->
+-spec decode_binary(ColDef :: #col{}, Data :: binary()) ->
     {Term :: term(), Rest :: binary()}.
-decode_binary(T, Data)
+decode_binary(#col{type = T}, Data)
   when T == ?TYPE_STRING; T == ?TYPE_VARCHAR; T == ?TYPE_VAR_STRING;
        T == ?TYPE_ENUM; T == ?TYPE_SET; T == ?TYPE_LONG_BLOB;
        T == ?TYPE_MEDIUM_BLOB; T == ?TYPE_BLOB; T == ?TYPE_TINY_BLOB;
        T == ?TYPE_GEOMETRY; T == ?TYPE_BIT; T == ?TYPE_DECIMAL;
        T == ?TYPE_NEWDECIMAL ->
+    %% As of MySQL 5.6.21 we receive SET and ENUM values as STRING, i.e. we
+    %% cannot convert them to atom() or sets:set(), etc.
     lenenc_str(Data);
-decode_binary(?TYPE_LONGLONG, <<Value:64/signed-little, Rest/binary>>) ->
+decode_binary(#col{type = ?TYPE_LONGLONG},
+              <<Value:64/signed-little, Rest/binary>>) ->
     {Value, Rest};
-decode_binary(T, <<Value:32/signed-little, Rest/binary>>)
+decode_binary(#col{type = T}, <<Value:32/signed-little, Rest/binary>>)
   when T == ?TYPE_LONG; T == ?TYPE_INT24 ->
     {Value, Rest};
-decode_binary(T, <<Value:16/signed-little, Rest/binary>>)
+decode_binary(#col{type = T}, <<Value:16/signed-little, Rest/binary>>)
   when T == ?TYPE_SHORT; T == ?TYPE_YEAR ->
     {Value, Rest};
-decode_binary(?TYPE_TINY, <<Value:8, Rest/binary>>) ->
+decode_binary(#col{type = ?TYPE_TINY}, <<Value:8, Rest/binary>>) ->
     {Value, Rest};
-decode_binary(?TYPE_DOUBLE, <<Value:64/float-little, Rest/binary>>) ->
+decode_binary(#col{type = ?TYPE_DOUBLE},
+              <<Value:64/float-little, Rest/binary>>) ->
     {Value, Rest};
-decode_binary(?TYPE_FLOAT, <<Value:32/float-little, Rest/binary>>) ->
+decode_binary(#col{type = ?TYPE_FLOAT},
+              <<Value:32/float-little, Rest/binary>>) ->
     %% There is a precision loss when storing and fetching a 32-bit float.
     %% In the text protocol, it is obviously rounded. Storing 3.14 in a FLOAT
     %% column and fetching it using the text protocol, we get "3.14" which we
@@ -587,14 +598,14 @@ decode_binary(?TYPE_FLOAT, <<Value:32/float-little, Rest/binary>>) ->
     Factor = math:pow(10, floor(6 - math:log10(abs(Value)))),
     RoundedValue = round(Value * Factor) / Factor,
     {RoundedValue, Rest};
-decode_binary(?TYPE_DATE, <<Length, Data/binary>>) ->
+decode_binary(#col{type = ?TYPE_DATE}, <<Length, Data/binary>>) ->
     %% Coded in the same way as DATETIME and TIMESTAMP below, but returned in
     %% a simple triple.
     case {Length, Data} of
         {0, _} -> {{0, 0, 0}, Data};
         {4, <<Y:16/little, M, D, Rest/binary>>} -> {{Y, M, D}, Rest}
     end;
-decode_binary(T, <<Length, Data/binary>>)
+decode_binary(#col{type = T}, <<Length, Data/binary>>)
   when T == ?TYPE_DATETIME; T == ?TYPE_TIMESTAMP ->
     %% length (1) -- number of bytes following (valid values: 0, 4, 7, 11)
     case {Length, Data} of
@@ -607,7 +618,7 @@ decode_binary(T, <<Length, Data/binary>>)
         {11, <<Y:16/little, M, D, H, Mi, S, Micro:32/little, Rest/binary>>} ->
             {{{Y, M, D}, {H, Mi, S + 0.000001 * Micro}}, Rest}
     end;
-decode_binary(?TYPE_TIME, <<Length, Data/binary>>) ->
+decode_binary(#col{type = ?TYPE_TIME}, <<Length, Data/binary>>) ->
     %% length (1) -- number of bytes following (valid values: 0, 8, 12)
     %% is_negative (1) -- (1 if minus, 0 for plus)
     %% days (4) -- days
@@ -690,6 +701,11 @@ encode_param(Value) when is_integer(Value), Value < 0 ->
     end;
 encode_param(Value) when is_float(Value) ->
     {<<?TYPE_DOUBLE, 0>>, <<Value:64/float-little>>};
+encode_param(Set) when is_tuple(Set), element(1, Set) == set ->
+    %% For convenience; encode only. When decoding, a set is returned as binary.
+    Binary = set_to_binary(Set),
+    EncLength = lenenc_int_encode(byte_size(Binary)),
+    {<<?TYPE_SET, 0>>, <<EncLength/binary, Binary/binary>>};
 encode_param({Y, M, D}) ->
     %% calendar:date()
     {<<?TYPE_DATE, 0>>, <<4, Y:16/little, M, D>>};
@@ -729,6 +745,20 @@ encode_param({D, {H, M, S}}) when is_float(S), S > 0.0, D < 0 ->
 encode_param({D, {H, M, 0.0}}) ->
     encode_param({D, {H, M, 0}}).
 
+%% -- Value representation in both the text and binary protocols --
+
+%% @doc Converts a set of atoms (or binaries) to a comma-separated binary.
+set_to_binary(Set) ->
+    List = [if is_atom(X) -> atom_to_binary(X, utf8); is_binary(X) -> X end
+            || X <- sets:to_list(Set)],
+    case List of
+        [] ->
+            <<>>;
+        [First | Rest] ->
+            lists:foldl(fun (X, Acc) -> <<Acc/binary, ",", X/binary>> end,
+                        First, Rest)
+    end.
+
 %% -- Protocol basics: packets --
 
 %% @doc Wraps Data in packet headers, sends it by calling SendFun and returns
@@ -903,61 +933,80 @@ nulterm_str(Bin) ->
 
 decode_text_test() ->
     %% Int types
-    lists:foreach(fun (T) -> ?assertEqual(1, decode_text(T, <<"1">>)) end,
+    lists:foreach(fun (T) ->
+                      ?assertEqual(1, decode_text(#col{type = T}, <<"1">>))
+                  end,
                   [?TYPE_TINY, ?TYPE_SHORT, ?TYPE_LONG, ?TYPE_LONGLONG,
-                   ?TYPE_INT24, ?TYPE_YEAR, ?TYPE_BIT]),
+                   ?TYPE_INT24, ?TYPE_YEAR]),
+
+    %% BIT
+    <<217>> = decode_text(#col{type = ?TYPE_BIT}, <<217>>),
 
     %% Floating point and decimal numbers
-    lists:foreach(fun (T) -> ?assertEqual(3.0, decode_text(T, <<"3.0">>)) end,
+    lists:foreach(fun (T) ->
+                      ?assertEqual(3.0, decode_text(#col{type = T}, <<"3.0">>))
+                  end,
                   [?TYPE_FLOAT, ?TYPE_DOUBLE]),
     %% Decimal types
     lists:foreach(fun (T) ->
-                      ?assertEqual(<<"3.0">>, decode_text(T, <<"3.0">>))
+                      ColDef = #col{type = T},
+                      ?assertEqual(<<"3.0">>, decode_text(ColDef, <<"3.0">>))
                   end,
                   [?TYPE_DECIMAL, ?TYPE_NEWDECIMAL]),
-    ?assertEqual(3.0,  decode_text(?TYPE_FLOAT, <<"3">>)),
-    ?assertEqual(30.0, decode_text(?TYPE_FLOAT, <<"3e1">>)),
-    ?assertEqual(3,    decode_text(?TYPE_LONG, <<"3">>)),
+    ?assertEqual(3.0,  decode_text(#col{type = ?TYPE_FLOAT}, <<"3">>)),
+    ?assertEqual(30.0, decode_text(#col{type = ?TYPE_FLOAT}, <<"3e1">>)),
+    ?assertEqual(3,    decode_text(#col{type = ?TYPE_LONG}, <<"3">>)),
 
     %% Date and time
-    ?assertEqual({2014, 11, 01}, decode_text(?TYPE_DATE, <<"2014-11-01">>)),
-    ?assertEqual({0, {23, 59, 01}}, decode_text(?TYPE_TIME, <<"23:59:01">>)),
+    ?assertEqual({2014, 11, 01},
+                 decode_text(#col{type = ?TYPE_DATE}, <<"2014-11-01">>)),
+    ?assertEqual({0, {23, 59, 01}},
+                 decode_text(#col{type = ?TYPE_TIME}, <<"23:59:01">>)),
     ?assertEqual({{2014, 11, 01}, {23, 59, 01}},
-                 decode_text(?TYPE_DATETIME, <<"2014-11-01 23:59:01">>)),
+                 decode_text(#col{type = ?TYPE_DATETIME},
+                             <<"2014-11-01 23:59:01">>)),
     ?assertEqual({{2014, 11, 01}, {23, 59, 01}},
-                 decode_text(?TYPE_TIMESTAMP, <<"2014-11-01 23:59:01">>)),
+                 decode_text(#col{type = ?TYPE_TIMESTAMP},
+                             <<"2014-11-01 23:59:01">>)),
 
     %% Strings and blobs
     lists:foreach(fun (T) ->
-                      ?assertEqual(<<"x">>, decode_text(T, <<"x">>))
+                      ColDef = #col{type = T},
+                      ?assertEqual(<<"x">>, decode_text(ColDef, <<"x">>))
                   end,
                   [?TYPE_VARCHAR, ?TYPE_ENUM, ?TYPE_TINY_BLOB,
                    ?TYPE_MEDIUM_BLOB, ?TYPE_LONG_BLOB, ?TYPE_BLOB,
                    ?TYPE_VAR_STRING, ?TYPE_STRING, ?TYPE_GEOMETRY]),
 
-    %% Set
-    ?assertEqual(sets:from_list([<<"b">>, <<"a">>]),
-                 decode_text(?TYPE_SET, <<"a,b">>)),
-    ?assertEqual(sets:from_list([]), decode_text(?TYPE_SET, <<>>)),
-
     %% NULL
-    ?assertEqual(null, decode_text(?TYPE_FLOAT, null)),
+    ?assertEqual(null, decode_text(#col{type = ?TYPE_FLOAT}, null)),
     ok.
 
 decode_binary_test() ->
     %% Test the special rounding we apply to (single precision) floats.
-    %?assertEqual({1.0, <<>>},
-    %             decode_binary(?TYPE_FLOAT, <<1.0:32/float-little>>)),
-    %?assertEqual({0.2, <<>>},
-    %             decode_binary(?TYPE_FLOAT, <<0.2:32/float-little>>)),
-    %?assertEqual({-33.3333, <<>>},
-    %             decode_binary(?TYPE_FLOAT, <<-33.333333:32/float-little>>)),
-    %?assertEqual({0.000123457, <<>>},
-    %             decode_binary(?TYPE_FLOAT, <<0.00012345678:32/float-little>>)),
-    %?assertEqual({1234.57, <<>>},
-    %             decode_binary(?TYPE_FLOAT, <<1234.56789:32/float-little>>)),
+    ?assertEqual({1.0, <<>>},
+                 decode_binary(#col{type = ?TYPE_FLOAT},
+                               <<1.0:32/float-little>>)),
+    ?assertEqual({0.2, <<>>},
+                 decode_binary(#col{type = ?TYPE_FLOAT},
+                               <<0.2:32/float-little>>)),
+    ?assertEqual({-33.3333, <<>>},
+                 decode_binary(#col{type = ?TYPE_FLOAT},
+                               <<-33.333333:32/float-little>>)),
+    ?assertEqual({0.000123457, <<>>},
+                 decode_binary(#col{type = ?TYPE_FLOAT},
+                               <<0.00012345678:32/float-little>>)),
+    ?assertEqual({1234.57, <<>>},
+                 decode_binary(#col{type = ?TYPE_FLOAT},
+                               <<1234.56789:32/float-little>>)),
     ok.
 
+encode_param_test() ->
+    %% Additional representations for common types for convenience
+    {<<?TYPE_SET, 0>>, EncodedSet} = encode_param(sets:from_list([foo, bar])),
+    ?assert(EncodedSet == <<7, "foo,bar">> orelse
+            EncodedSet == <<7, "bar,foo">>).
+
 null_bitmap_test() ->
     ?assertEqual({<<0, 1:1>>, <<>>}, null_bitmap_decode(9, <<0, 4>>, 2)),
     ?assertEqual(<<0, 4>>, null_bitmap_encode(<<0, 1:1>>, 2)),

+ 1 - 3
test/mysql_protocol_tests.erl

@@ -43,9 +43,7 @@ resultset_test() ->
     RecvFun = fun (Size) -> fakesocket_recv(FakeSock, Size) end,
     ResultSet = mysql_protocol:query(Query, SendFun, RecvFun),
     fakesocket_close(FakeSock),
-    ?assertMatch(#resultset{column_definitions =
-                                [#column_definition{
-                                     name = <<"@@version_comment">>}],
+    ?assertMatch(#resultset{cols = [#col{name = <<"@@version_comment">>}],
                             rows = [[<<"MySQL Community Server (GPL)">>]]},
                  ResultSet),
     ok.

+ 73 - 71
test/mysql_tests.erl

@@ -61,6 +61,8 @@ query_test_() ->
              fun text_protocol/1,
              fun binary_protocol/1,
              fun float_rounding/1,
+             fun int/1,
+             %fun bit/1,
              fun time/1,
              fun microseconds/1]}}.
 
@@ -107,10 +109,6 @@ text_protocol(Pid) ->
 
     %% TODO:
     %% * More types: BIT, SET, ENUM, GEOMETRY
-    %% * TIME with negative hours
-    %% * TIME with more than 2 digits in hour.
-    %% * TIME with microseconds
-    %% * Negative TIME
 
     ok = mysql:query(Pid, <<"DROP TABLE t">>).
 
@@ -141,9 +139,6 @@ binary_protocol(Pid) ->
     %% * Values for all types
     %% * Negative numbers for all integer types
     %% * Integer overflow
-    %% * TIME with more than 2 digits in hour.
-    %% * TIME with microseconds
-    %% * Negative TIME
 
     ok = mysql:query(Pid, <<"DROP TABLE t">>).
 
@@ -195,42 +190,39 @@ float_rounding(Pid) ->
                 TestData),
     ok = mysql:query(Pid, "DROP TABLE f").
 
+int(Pid) ->
+    ok = mysql:query(Pid, "CREATE TABLE ints (i INT)"),
+    write_read_text_binary(Pid, 42, <<"42">>, <<"ints">>, <<"i">>),
+    write_read_text_binary(Pid, -42, <<"-42">>, <<"ints">>, <<"i">>),
+    write_read_text_binary(Pid, 987654321, <<"987654321">>, <<"ints">>,
+                           <<"i">>),
+    write_read_text_binary(Pid, -987654321, <<"-987654321">>,
+                           <<"ints">>, <<"i">>),
+    ok = mysql:query(Pid, "DROP TABLE ints").
+
+%% The BIT(N) datatype in MySQL 5.0.3 and later: the equivallent to bitstring()
+bit(Pid) ->
+    ok = mysql:query(Pid, "CREATE TABLE bits (b BIT(11))"),
+    write_read_text_binary(Pid, <<16#ff, 0:3>>, <<"b'11111111000'">>,
+                           <<"bits">>, <<"b">>),
+    write_read_text_binary(Pid, <<16#7f, 2:3>>, <<"b'01111111110'">>,
+                           <<"bits">>, <<"b">>),
+    ok = mysql:query(Pid, "DROP TABLE bits").
+
 %% Test TIME value representation. There are a few things to check.
 time(Pid) ->
     ok = mysql:query(Pid, "CREATE TABLE tm (tm TIME)"),
-    {ok, Insert} = mysql:prepare(Pid, "INSERT INTO tm VALUES (?)"),
-    {ok, Select} = mysql:prepare(Pid, "SELECT tm FROM tm"),
     lists:foreach(
-        fun ({Value, Text}) ->
-            %% --- Insert using text query ---
-            ok = mysql:query(Pid, ["INSERT INTO tm VALUES ('", Text, "')"]),
-            %% Select using prepared statement
-            ?assertEqual({ok, [<<"tm">>], [[Value]]},
-                         mysql:execute(Pid, Select, [])),
-            %% Select using plain query
-            ?assertEqual({ok, [<<"tm">>], [[Value]]},
-                         mysql:query(Pid, "SELECT tm FROM tm")),
-            %% Empty table
-            ok = mysql:query(Pid, "DELETE FROM tm"),
-            %% --- Insert using prepared statement ---
-            ok = mysql:execute(Pid, Insert, [Value]),
-            %% Select using prepared statement
-            ?assertEqual({ok, [<<"tm">>], [[Value]]},
-                         mysql:execute(Pid, Select, [])),
-            %% Select using plain query
-            ?assertEqual({ok, [<<"tm">>], [[Value]]},
-                         mysql:query(Pid, "SELECT tm FROM tm")),
-            %% Empty table
-            ok = mysql:query(Pid, "DELETE FROM tm"),
-            ok
+        fun ({Value, SqlLiteral}) ->
+            write_read_text_binary(Pid, Value, SqlLiteral, <<"tm">>, <<"tm">>)
         end,
-        [{{0, {10, 11, 12}},   "10:11:12"},
-         {{5, {0, 0, 1}},     "120:00:01"},
-         {{-1, {23, 59, 59}}, "-00:00:01"},
-         {{-1, {23, 59, 0}},  "-00:01:00"},
-         {{-1, {23, 0, 0}},   "-01:00:00"},
-         {{-1, {0, 0, 0}},    "-24:00:00"},
-         {{-5, {10, 0, 0}},  "-110:00:00"}]
+        [{{0, {10, 11, 12}},   <<"'10:11:12'">>},
+         {{5, {0, 0, 1}},     <<"'120:00:01'">>},
+         {{-1, {23, 59, 59}}, <<"'-00:00:01'">>},
+         {{-1, {23, 59, 0}},  <<"'-00:01:00'">>},
+         {{-1, {23, 0, 0}},   <<"'-01:00:00'">>},
+         {{-1, {0, 0, 0}},    <<"'-24:00:00'">>},
+         {{-5, {10, 0, 0}},  <<"'-110:00:00'">>}]
     ),
     ok = mysql:query(Pid, "DROP TABLE tm").
 
@@ -244,47 +236,57 @@ microseconds(Pid) ->
                              binary:split(Version1, <<".">>, [global])),
         Version2 >= [5, 6, 4] orelse throw(nope)
     of _ ->
-        run_test_microseconds(Pid)
+        test_time_microseconds(Pid),
+        test_datetime_microseconds(Pid)
     catch _:_ ->
         error_logger:info_msg("Skipping microseconds test. Current MySQL"
                               " version is ~s. Required version is >= 5.6.4.~n",
                               [Version])
     end.
 
-run_test_microseconds(Pid) ->
+test_time_microseconds(Pid) ->
     ok = mysql:query(Pid, "CREATE TABLE m (t TIME(6))"),
-    SelectTime = "SELECT t FROM m",
-    {ok, SelectStmt} = mysql:prepare(Pid, SelectTime),
-    {ok, InsertStmt} = mysql:prepare(Pid, "INSERT INTO m VALUES (?)"),
-    %% Positive time, insert using plain query
-    E1 = {0, {23, 59, 57.654321}},
-    ok = mysql:query(Pid, <<"INSERT INTO m VALUES ('23:59:57.654321')">>),
-    ?assertEqual({ok, [<<"t">>], [[E1]]}, mysql:query(Pid, SelectTime)),
-    ?assertEqual({ok, [<<"t">>], [[E1]]}, mysql:execute(Pid, SelectStmt, [])),
-    ok = mysql:query(Pid, "DELETE FROM m"),
-    %% The same, but insert using prepared stmt
-    ok = mysql:execute(Pid, InsertStmt, [E1]),
-    ?assertEqual({ok, [<<"t">>], [[E1]]}, mysql:query(Pid, SelectTime)),
-    ?assertEqual({ok, [<<"t">>], [[E1]]}, mysql:execute(Pid, SelectStmt, [])),
-    ok = mysql:query(Pid, "DELETE FROM m"),
+    %% Positive time
+    write_read_text_binary(Pid, {0, {23, 59, 57.654321}},
+                           <<"'23:59:57.654321'">>, <<"m">>, <<"t">>),
     %% Negative time
-    E2 = {-1, {23, 59, 57.654321}},
-    ok = mysql:query(Pid, <<"INSERT INTO m VALUES ('-00:00:02.345679')">>),
-    ?assertEqual({ok, [<<"t">>], [[E2]]}, mysql:query(Pid, SelectTime)),
-    ?assertEqual({ok, [<<"t">>], [[E2]]}, mysql:execute(Pid, SelectStmt, [])),
-    ok = mysql:query(Pid, "DELETE FROM m"),
-    %% The same, but insert using prepared stmt
-    ok = mysql:execute(Pid, InsertStmt, [E2]),
-    ?assertEqual({ok, [<<"t">>], [[E2]]}, mysql:query(Pid, SelectTime)),
-    ?assertEqual({ok, [<<"t">>], [[E2]]}, mysql:execute(Pid, SelectStmt, [])),
-    ok = mysql:query(Pid, "DROP TABLE m"),
-    %% Datetime
-    Q3 = <<"SELECT TIMESTAMP '2014-11-23 23:59:57.654321' AS t">>,
-    E3 = [[{{2014, 11, 23}, {23, 59, 57.654321}}]],
-    ?assertEqual({ok, [<<"t">>], E3}, mysql:query(Pid, Q3)),
-    {ok, S3} = mysql:prepare(Pid, Q3),
-    ?assertEqual({ok, [<<"t">>], E3}, mysql:execute(Pid, S3, [])),
-    ok.
+    write_read_text_binary(Pid, {-1, {23, 59, 57.654321}},
+                           <<"'-00:00:02.345679'">>, <<"m">>, <<"t">>),
+    ok = mysql:query(Pid, "DROP TABLE m").
+
+test_datetime_microseconds(Pid) ->
+    ok = mysql:query(Pid, "CREATE TABLE dt (dt DATETIME(6))"),
+    write_read_text_binary(Pid, {{2014, 11, 23}, {23, 59, 57.654321}},
+                           <<"'2014-11-23 23:59:57.654321'">>, <<"dt">>,
+                           <<"dt">>),
+    ok = mysql:query(Pid, "DROP TABLE dt").
+
+%% @doc Tests write and read in text and the binary protocol, all combinations.
+%% This helper function assumes an empty table with a single column.
+write_read_text_binary(Conn, Term, SqlLiteral, Table, Column) ->
+    SelectQuery = <<"SELECT ", Column/binary, " FROM ", Table/binary>>,
+    {ok, SelectStmt} = mysql:prepare(Conn, SelectQuery),
+
+    %% Insert as text, read text and binary, delete
+    InsertQuery = <<"INSERT INTO ", Table/binary, " (", Column/binary, ")"
+                    " VALUES (", SqlLiteral/binary, ")">>,
+    ok = mysql:query(Conn, InsertQuery),
+    ?assertEqual({ok, [Column], [[Term]]}, mysql:query(Conn, SelectQuery)),
+    ?assertEqual({ok, [Column], [[Term]]}, mysql:execute(Conn, SelectStmt, [])),
+    mysql:query(Conn, <<"DELETE FROM ", Table/binary>>),
+
+    %% Insert as binary, read text and binary, delete
+    InsertQ = <<"INSERT INTO ", Table/binary, " (", Column/binary, ")",
+                " VALUES (?)">>,
+    {ok, InsertStmt} = mysql:prepare(Conn, InsertQ),
+    ok = mysql:execute(Conn, InsertStmt, [Term]),
+    ok = mysql:unprepare(Conn, InsertStmt),
+    ?assertEqual({ok, [Column], [[Term]]}, mysql:query(Conn, SelectQuery)),
+    ?assertEqual({ok, [Column], [[Term]]}, mysql:execute(Conn, SelectStmt, [])),
+    mysql:query(Conn, <<"DELETE FROM ", Table/binary>>),
+
+    %% Cleanup
+    ok = mysql:unprepare(Conn, SelectStmt).
 
 %% --------------------------------------------------------------------------