Browse Source

Add 'decode_decimal' option to control how DECIMALs are returned (#194)

The new option `{decode_decimal, auto | number | float | binary}` controls how a
DECIMAL value is translated to an Erlang term. The default 'auto' is the legacy behaviour which returns an integer or a float when no precision loss can happen for the column and a binary otherwise. The options 'number' and 'float' may return a float even if it can result in precision loss, while 'binary' always returns the textual representation of the number.
Jesse Gumm 1 year ago
parent
commit
b97ef3dc13
7 changed files with 143 additions and 64 deletions
  1. 2 0
      .gitignore
  2. 2 1
      include/records.hrl
  3. 13 2
      src/mysql.erl
  4. 27 16
      src/mysql_conn.erl
  5. 86 34
      src/mysql_protocol.erl
  6. 2 2
      test/mysql_protocol_tests.erl
  7. 11 9
      test/mysql_tests.erl

+ 2 - 0
.gitignore

@@ -23,3 +23,5 @@ test/ssl/server*
 test/ssl/my-ssl.cnf
 test/ssl/my-ssl.cnf-e
 mysql.d
+*~
+*.sw?

+ 2 - 1
include/records.hrl

@@ -47,7 +47,8 @@
 -record(eof, {status, warning_count}).
 
 %% Column definition, used while parsing a result set.
--record(col, {name, type, charset, length, decimals, flags}).
+-record(col, {name, type, charset, length, decimals, flags,
+              decode_decimal=auto}).
 
 %% A resultset. The rows can be either lists of terms or unparsed binaries as
 %% received from the server using either the text protocol or the binary

+ 13 - 2
src/mysql.erl

@@ -37,7 +37,8 @@
 
 -export_type([option/0, connection/0, query/0, statement_name/0,
               statement_ref/0, query_param/0, query_filtermap_fun/0,
-              query_result/0, transaction_result/1, server_reason/0]).
+              query_result/0, transaction_result/1, server_reason/0,
+              decode_decimal/0]).
 
 %% A connection is a ServerRef as in gen_server:call/2,3.
 -type connection() :: Name :: atom() |
@@ -65,6 +66,8 @@
 -type statement_name() :: atom().
 -type statement_ref() :: statement_id() | statement_name().
 
+-type decode_decimal() :: auto | binary | float | number.
+
 -type query_result() :: ok
                       | {ok, [column_name()], [row()]}
                       | {ok, [{[column_name()], [row()]}, ...]}
@@ -93,7 +96,8 @@
                 | {query_cache_time, non_neg_integer()}
                 | {tcp_options, [gen_tcp:connect_option()]}
                 | {ssl, term()}
-                | {float_as_decimal, boolean() | non_neg_integer()}.
+                | {float_as_decimal, boolean() | non_neg_integer()}
+                | {decode_decimal, decode_decimal()}.
 
 -include("exception.hrl").
 
@@ -199,6 +203,13 @@
 %%       rounding and truncation errors from happening on the server side. If a
 %%       number is specified, the float is rounded to this number of
 %%       decimals. This is off (false) by default.</dd>
+%%   <dt>`{decode_decimal, auto | float | number | binary}'</dt>
+%%   <dd>When decoding `decimal' columns from the server, force the return the
+%%       value as either a `binary()', `float()`, or `number()' (specified by
+%%       the atoms `binary', `float', `number' respectively). Defaults to
+%%       `auto', which will return a number (`integer()' or `float()') unless
+%%       the conversion to `float()' would result in a loss of precision, in
+%%       which case, `binary()' is returned.</dd>
 %% </dl>
 -spec start_link(Options :: [option()]) -> {ok, pid()} | ignore | {error, term()}.
 start_link(Options) ->

+ 27 - 16
src/mysql_conn.erl

@@ -58,7 +58,7 @@
                 affected_rows = 0, status = 0, warning_count = 0, insert_id = 0,
                 transaction_levels = [], ping_ref = undefined,
                 stmts = dict:new(), query_cache = empty, cap_found_rows = false,
-                float_as_decimal = false}).
+                float_as_decimal = false, decode_decimal = auto}).
 
 %% @private
 init(Opts) ->
@@ -91,6 +91,7 @@ init(Opts) ->
     Queries           = proplists:get_value(queries, Opts, []),
     Prepares          = proplists:get_value(prepare, Opts, []),
     FloatAsDecimal    = proplists:get_value(float_as_decimal, Opts, false),
+    DecodeDecimal     = proplists:get_value(decode_decimal, Opts, auto),
 
     true = lists:all(fun mysql_protocol:valid_path/1, AllowedLocalPaths),
 
@@ -114,7 +115,8 @@ init(Opts) ->
         query_timeout = QueryTimeout,
         query_cache_time = QueryCacheTime,
         cap_found_rows = (SetFoundRows =:= true),
-        float_as_decimal = FloatAsDecimal
+        float_as_decimal = FloatAsDecimal,
+        decode_decimal = DecodeDecimal
     },
 
     case proplists:get_value(connect_mode, Opts, synchronous) of
@@ -468,6 +470,7 @@ handle_call(start_transaction, {FromPid, _},
     {reply, {error, busy}, State};
 handle_call(start_transaction, {FromPid, _},
             State = #state{socket = Socket, sockmod = SockMod,
+                           decode_decimal = DecodeDecimal,
                            transaction_levels = L, status = Status})
   when Status band ?SERVER_STATUS_IN_TRANS == 0, L == [];
        Status band ?SERVER_STATUS_IN_TRANS /= 0, L /= [] ->
@@ -478,13 +481,14 @@ handle_call(start_transaction, {FromPid, _},
     end,
     setopts(SockMod, Socket, [{active, false}]),
     {ok, [Res = #ok{}]} = mysql_protocol:query(Query, SockMod, Socket,
-                                               [], no_filtermap_fun,
+                                               [], DecodeDecimal, no_filtermap_fun,
                                                ?cmd_timeout),
     setopts(SockMod, Socket, [{active, once}]),
     State1 = update_state(Res, State),
     {reply, ok, State1#state{transaction_levels = [{FromPid, MRef} | L]}};
 handle_call(rollback, {FromPid, _},
             State = #state{socket = Socket, sockmod = SockMod, status = Status,
+                           decode_decimal = DecodeDecimal,
                            transaction_levels = [{FromPid, MRef} | L]})
   when Status band ?SERVER_STATUS_IN_TRANS /= 0 ->
     erlang:demonitor(MRef),
@@ -494,13 +498,14 @@ handle_call(rollback, {FromPid, _},
     end,
     setopts(SockMod, Socket, [{active, false}]),
     {ok, [Res = #ok{}]} = mysql_protocol:query(Query, SockMod, Socket,
-                                               [], no_filtermap_fun,
+                                               [], DecodeDecimal, no_filtermap_fun,
                                                ?cmd_timeout),
     setopts(SockMod, Socket, [{active, once}]),
     State1 = update_state(Res, State),
     {reply, ok, State1#state{transaction_levels = L}};
 handle_call(commit, {FromPid, _},
             State = #state{socket = Socket, sockmod = SockMod, status = Status,
+                           decode_decimal = DecodeDecimal,
                            transaction_levels = [{FromPid, MRef} | L]})
   when Status band ?SERVER_STATUS_IN_TRANS /= 0 ->
     erlang:demonitor(MRef),
@@ -510,7 +515,7 @@ handle_call(commit, {FromPid, _},
     end,
     setopts(SockMod, Socket, [{active, false}]),
     {ok, [Res = #ok{}]} = mysql_protocol:query(Query, SockMod, Socket,
-                                               [], no_filtermap_fun,
+                                               [], DecodeDecimal, no_filtermap_fun,
                                                ?cmd_timeout),
     setopts(SockMod, Socket, [{active, once}]),
     State1 = update_state(Res, State),
@@ -592,7 +597,8 @@ code_change(_OldVsn, _State, _Extra) ->
 execute_stmt(Stmt, Args, FilterMap, Timeout, State) ->
     #state{socket = Socket, sockmod = SockMod,
            allowed_local_paths = AllowedPaths,
-           float_as_decimal = FloatAsDecimal} = State,
+           float_as_decimal = FloatAsDecimal,
+           decode_decimal = DecodeDecimal} = State,
     Args1 = case FloatAsDecimal of
                 false ->
                     Args;
@@ -601,12 +607,13 @@ execute_stmt(Stmt, Args, FilterMap, Timeout, State) ->
             end,
     setopts(SockMod, Socket, [{active, false}]),
     {ok, Recs} = case mysql_protocol:execute(Stmt, Args1, SockMod, Socket,
-                                             AllowedPaths, FilterMap,
-                                             Timeout) of
+                                             AllowedPaths, DecodeDecimal,
+                                             FilterMap, Timeout) of
         {error, timeout} when State#state.server_version >= [5, 0, 0] ->
             kill_query(State),
-            mysql_protocol:fetch_execute_response(SockMod, Socket,
-                                                  [], FilterMap, ?cmd_timeout);
+            mysql_protocol:fetch_execute_response(SockMod, Socket, [],
+                                                  DecodeDecimal, FilterMap,
+                                                  ?cmd_timeout);
         {error, Reason} ->
             exit(Reason);
         QueryResult ->
@@ -656,15 +663,16 @@ query(Query, FilterMap, default_timeout,
     query(Query, FilterMap, DefaultTimeout, State);
 query(Query, FilterMap, Timeout, State) ->
     #state{sockmod = SockMod, socket = Socket,
-           allowed_local_paths = AllowedPaths} = State,
+           allowed_local_paths = AllowedPaths,
+           decode_decimal = DecodeDecimal} = State,
     setopts(SockMod, Socket, [{active, false}]),
     Result = mysql_protocol:query(Query, SockMod, Socket, AllowedPaths,
-                                  FilterMap, Timeout),
+                                  DecodeDecimal, FilterMap, Timeout),
     {ok, Recs} = case Result of
         {error, timeout} when State#state.server_version >= [5, 0, 0] ->
             kill_query(State),
             mysql_protocol:fetch_query_response(SockMod, Socket,
-                                                [], FilterMap,
+                                                [], DecodeDecimal, FilterMap,
                                                 ?cmd_timeout);
         {error, Reason} ->
             exit(Reason);
@@ -759,11 +767,13 @@ schedule_ping(State = #state{ping_timeout = Timeout, ping_ref = Ref}) ->
     State#state{ping_ref = erlang:send_after(Timeout, self(), ping)}.
 
 %% @doc Fetches and logs warnings. Query is the query that gave the warnings.
-log_warnings(#state{socket = Socket, sockmod = SockMod}, Query) ->
+log_warnings(#state{socket = Socket, sockmod = SockMod,
+                    decode_decimal = DecodeDecimal}, Query) ->
     setopts(SockMod, Socket, [{active, false}]),
     {ok, [#resultset{rows = Rows}]} = mysql_protocol:query(<<"SHOW WARNINGS">>,
                                                            SockMod, Socket,
-                                                           [], no_filtermap_fun,
+                                                           [], DecodeDecimal,
+                                                           no_filtermap_fun,
                                                            ?cmd_timeout),
     setopts(SockMod, Socket, [{active, once}]),
     Lines = [[Level, " ", integer_to_binary(Code), ": ", Message, "\n"]
@@ -810,7 +820,8 @@ kill_query(#state{connection_id = ConnId, host = Host, port = Port,
             IdBin = integer_to_binary(ConnId),
             {ok, [#ok{}]} = mysql_protocol:query(<<"KILL QUERY ", IdBin/binary>>,
                                                  SockMod, Socket,
-                                                 [], no_filtermap_fun,
+                                                 [], auto,
+                                                 no_filtermap_fun,
                                                  ?cmd_timeout),
             mysql_protocol:quit(SockMod, Socket);
         #error{} = E ->

+ 86 - 34
src/mysql_protocol.erl

@@ -28,8 +28,8 @@
 -module(mysql_protocol).
 
 -export([handshake/8, change_user/8, quit/2, ping/2,
-         query/6, fetch_query_response/5, prepare/3, unprepare/3,
-         execute/7, fetch_execute_response/5, reset_connnection/2,
+         query/7, fetch_query_response/6, prepare/3, unprepare/3,
+         execute/8, fetch_execute_response/6, reset_connnection/2,
          valid_params/1, valid_path/1]).
 
 -type query_filtermap() :: no_filtermap_fun | mysql:query_filtermap_fun().
@@ -38,6 +38,8 @@
                         | full_auth_requested
                         | {public_key, term()}.
 
+-type decode_decimal() :: mysql:decode_decimal().
+
 %% How much data do we want per packet?
 -define(MAX_BYTES_PER_PACKET, 16#1000000).
 
@@ -193,19 +195,19 @@ ping(SockModule, Socket) ->
     {ok, OkPacket, _SeqNum2} = recv_packet(SockModule, Socket, SeqNum1),
     parse_ok_packet(OkPacket).
 
--spec query(Query :: iodata(), module(), term(), [binary()], query_filtermap(),
-            timeout()) ->
+-spec query(Query :: iodata(), module(), term(), [binary()],
+            decode_decimal(), query_filtermap(), timeout()) ->
     {ok, [#ok{} | #resultset{} | #error{}]} | {error, term()}.
-query(Query, SockModule, Socket, AllowedPaths, FilterMap, Timeout) ->
+query(Query, SockModule, Socket, AllowedPaths, DecodeDecimal, FilterMap, Timeout) ->
     Req = <<?COM_QUERY, (iolist_to_binary(Query))/binary>>,
     SeqNum0 = 0,
     {ok, _SeqNum1} = send_packet(SockModule, Socket, Req, SeqNum0),
-    fetch_query_response(SockModule, Socket, AllowedPaths, FilterMap, Timeout).
+    fetch_query_response(SockModule, Socket, AllowedPaths, DecodeDecimal, FilterMap, Timeout).
 
 %% @doc This is used by query/4. If query/4 returns {error, timeout}, this
 %% function can be called to retry to fetch the results of the query.
-fetch_query_response(SockModule, Socket, AllowedPaths, FilterMap, Timeout) ->
-    fetch_response(SockModule, Socket, Timeout, text, AllowedPaths, FilterMap, []).
+fetch_query_response(SockModule, Socket, AllowedPaths, DecodeDecimal, FilterMap, Timeout) ->
+    fetch_response(SockModule, Socket, Timeout, text, AllowedPaths, DecodeDecimal, FilterMap, []).
 
 %% @doc Prepares a statement.
 -spec prepare(iodata(), module(), term()) -> #error{} | #prepared{}.
@@ -251,10 +253,10 @@ unprepare(#prepared{statement_id = Id}, SockModule, Socket) ->
 
 %% @doc Executes a prepared statement.
 -spec execute(#prepared{}, [term()], module(), term(), [binary()],
-              query_filtermap(), timeout()) ->
+              decode_decimal(), query_filtermap(), timeout()) ->
     {ok, [#ok{} | #resultset{} | #error{}]} | {error, term()}.
 execute(#prepared{statement_id = Id, param_count = ParamCount}, ParamValues,
-        SockModule, Socket, AllowedPaths, FilterMap, Timeout)
+        SockModule, Socket, AllowedPaths, DecodeDecimal, FilterMap, Timeout)
   when ParamCount == length(ParamValues) ->
     %% Flags Constant Name
     %% 0x00 CURSOR_TYPE_NO_CURSOR
@@ -282,12 +284,12 @@ execute(#prepared{statement_id = Id, param_count = ParamCount}, ParamValues,
             iolist_to_binary([Req1, TypesAndSigns, EncValues])
     end,
     {ok, _SeqNum1} = send_packet(SockModule, Socket, Req, 0),
-    fetch_execute_response(SockModule, Socket, AllowedPaths, FilterMap, Timeout).
+    fetch_execute_response(SockModule, Socket, AllowedPaths, DecodeDecimal, FilterMap, Timeout).
 
 %% @doc This is used by execute/5. If execute/5 returns {error, timeout}, this
 %% function can be called to retry to fetch the results of the query.
-fetch_execute_response(SockModule, Socket, AllowedPaths, FilterMap, Timeout) ->
-    fetch_response(SockModule, Socket, Timeout, binary, AllowedPaths, FilterMap, []).
+fetch_execute_response(SockModule, Socket, AllowedPaths, DecodeDecimal, FilterMap, Timeout) ->
+    fetch_response(SockModule, Socket, Timeout, binary, AllowedPaths, DecodeDecimal, FilterMap, []).
 
 %% @doc Changes the user of the connection.
 -spec change_user(module(), term(), iodata(), iodata(), binary(), binary(),
@@ -589,9 +591,9 @@ parse_handshake_confirm(<<?MORE_DATA, MoreData/binary>>) ->
 %% either the text format (for plain queries) or the binary format (for
 %% prepared statements).
 -spec fetch_response(module(), term(), timeout(), text | binary, [binary()],
-                     query_filtermap(), list()) ->
+                     decode_decimal(), query_filtermap(), list()) ->
     {ok, [#ok{} | #resultset{} | #error{}]} | {error, term()}.
-fetch_response(SockModule, Socket, Timeout, Proto, AllowedPaths, FilterMap, Acc) ->
+fetch_response(SockModule, Socket, Timeout, Proto, AllowedPaths, DecodeDecimal, FilterMap, Acc) ->
     case recv_packet(SockModule, Socket, Timeout, any) of
         {ok, ?local_infile_pattern = Packet, SeqNum2} ->
             Filename = parse_local_infile_packet(Packet),
@@ -610,7 +612,7 @@ fetch_response(SockModule, Socket, Timeout, Proto, AllowedPaths, FilterMap, Acc)
                     [#error{code = -2, msg = ErrorMsg}|Acc]
             end,
             fetch_response(SockModule, Socket, Timeout, Proto, AllowedPaths,
-                           FilterMap, Acc1);
+                           DecodeDecimal, FilterMap, Acc1);
         {ok, Packet, SeqNum2} ->
             Result = case Packet of
                 ?ok_pattern ->
@@ -621,13 +623,13 @@ fetch_response(SockModule, Socket, Timeout, Proto, AllowedPaths, FilterMap, Acc)
                     %% The first packet in a resultset is only the column count.
                     {ColCount, <<>>} = lenenc_int(ResultPacket),
                     fetch_resultset(SockModule, Socket, ColCount, Proto,
-                                    FilterMap, SeqNum2)
+                                    DecodeDecimal, FilterMap, SeqNum2)
             end,
             Acc1 = [Result | Acc],
             case more_results_exists(Result) of
                 true ->
                     fetch_response(SockModule, Socket, Timeout, Proto,
-                                   AllowedPaths, FilterMap, Acc1);
+                                   AllowedPaths, DecodeDecimal, FilterMap, Acc1);
                 false ->
                     {ok, lists:reverse(Acc1)}
             end;
@@ -637,14 +639,16 @@ fetch_response(SockModule, Socket, Timeout, Proto, AllowedPaths, FilterMap, Acc)
 
 %% @doc Fetches a result set.
 -spec fetch_resultset(module(), term(), integer(), text | binary,
-                      query_filtermap(), integer()) ->
+                      decode_decimal(), query_filtermap(), integer()) ->
     #resultset{} | #error{}.
-fetch_resultset(SockModule, Socket, FieldCount, Proto, FilterMap, SeqNum0) ->
+fetch_resultset(SockModule, Socket, FieldCount, Proto, DecodeDecimal, FilterMap, SeqNum0) ->
     {ok, ColDefs0, SeqNum1} = fetch_column_definitions(SockModule, Socket,
                                                        SeqNum0, FieldCount, []),
     {ok, DelimPacket, SeqNum2} = recv_packet(SockModule, Socket, SeqNum1),
     #eof{} = parse_eof_packet(DelimPacket),
-    ColDefs1 = lists:map(fun parse_column_definition/1, ColDefs0),
+    ColDefs1 = lists:map(fun(ColDef) ->
+                                 parse_column_definition(ColDef, DecodeDecimal)
+                         end, ColDefs0),
     case fetch_resultset_rows(SockModule, Socket, FieldCount, ColDefs1, Proto,
                               FilterMap, SeqNum2, []) of
         {ok, Rows, _SeqNum3, #eof{status = S, warning_count = W}} ->
@@ -669,7 +673,7 @@ fetch_resultset_rows(SockModule, Socket, FieldCount, ColDefs, Proto,
             Eof = parse_eof_packet(Packet),
             {ok, lists:reverse(Acc), SeqNum1, Eof};
         RowPacket ->
-            Row0=decode_row(FieldCount, ColDefs, RowPacket, Proto),
+            Row0 = decode_row(FieldCount, ColDefs, RowPacket, Proto),
             Acc1 = case filtermap_resultset_row(FilterMap, ColDefs, Row0) of
                 false ->
                     Acc;
@@ -712,7 +716,7 @@ fetch_column_definitions(_SockModule, _Socket, SeqNum, 0, Acc) ->
     {ok, lists:reverse(Acc), SeqNum}.
 
 %% Parses a packet containing a column definition (part of a result set)
-parse_column_definition(Data) ->
+parse_column_definition(Data, DecodeDecimal) ->
     {<<"def">>, Rest1} = lenenc_str(Data),   %% catalog (always "def")
     {_Schema, Rest2} = lenenc_str(Rest1),    %% schema-name
     {_Table, Rest3} = lenenc_str(Rest2),     %% virtual table-name
@@ -734,7 +738,7 @@ parse_column_definition(Data) ->
     %% }
     <<>> = Rest8,
     #col{name = Name, type = Type, charset = Charset, length = Length,
-         decimals = Decimals, flags = Flags}.
+         decimals = Decimals, flags = Flags, decode_decimal = DecodeDecimal}.
 
 %% @doc Decodes a row using either the text or binary format.
 -spec decode_row(integer(), [#col{}], binary(), text | binary) -> [term()].
@@ -783,11 +787,12 @@ decode_text(#col{type = T}, Text)
 decode_text(#col{type = ?TYPE_BIT, length = Length}, Text) ->
     %% Convert to <<_:Length/bitstring>>
     decode_bitstring(Text, Length);
-decode_text(#col{type = T, decimals = S, length = L}, Text)
+decode_text(#col{type = T, decimals = S, length = L,
+                 decode_decimal = DecodeDecimal}, Text)
   when T == ?TYPE_DECIMAL; T == ?TYPE_NEWDECIMAL ->
     %% Length is the max number of symbols incl. dot and minus sign, e.g. the
     %% number of digits plus 2.
-    decode_decimal(Text, L - 2, S);
+    decode_decimal(DecodeDecimal, Text, L - 2, S);
 decode_text(#col{type = ?TYPE_DATE},
             <<Y:4/binary, "-", M:2/binary, "-", D:2/binary>>) ->
     {binary_to_integer(Y), binary_to_integer(M), binary_to_integer(D)};
@@ -870,7 +875,7 @@ decode_binary_row_acc([ColDef | ColDefs], <<0:1, NullBitMap/bitstring>>, Data,
     %% Not NULL
     {Term, Rest} = decode_binary(ColDef, Data),
     decode_binary_row_acc(ColDefs, NullBitMap, Rest, [Term | Acc]);
-decode_binary_row_acc([], _, <<>>, Acc) ->
+decode_binary_row_acc([], _NullBitMap, <<>>, Acc) ->
     lists:reverse(Acc).
 
 %% @doc Decodes a null bitmap as stored by MySQL and returns it in a strait
@@ -963,12 +968,13 @@ decode_binary(#col{type = ?TYPE_TINY, flags = F},
               <<Value:8/signed, Rest/binary>>)
   when F band ?UNSIGNED_FLAG == 0 ->
     {Value, Rest};
-decode_binary(#col{type = T, decimals = S, length = L}, Data)
+decode_binary(#col{type = T, decimals = S, length = L,
+                   decode_decimal = DecodeDecimal}, Data)
   when T == ?TYPE_DECIMAL; T == ?TYPE_NEWDECIMAL ->
     %% Length is the max number of symbols incl. dot and minus sign, e.g. the
     %% number of digits plus 2.
     {Binary, Rest} = lenenc_str(Data),
-    {decode_decimal(Binary, L - 2, S), Rest};
+    {decode_decimal(DecodeDecimal, Binary, L - 2, S), Rest};
 decode_binary(#col{type = ?TYPE_DOUBLE},
               <<Value:64/float-little, Rest/binary>>) ->
     {Value, Rest};
@@ -1224,11 +1230,16 @@ encode_bitstring(Bitstring) ->
     PaddingSize = byte_size(Bitstring) * 8 - Size,
     <<0:PaddingSize, Bitstring:Size/bitstring>>.
 
-decode_decimal(Bin, _P, 0) ->
+decode_decimal(Decode, Bin, _P, 0) when Decode =:= number;
+                                        Decode =:= auto ->
     binary_to_integer(Bin);
-decode_decimal(Bin, P, S) when P =< 15, S > 0 ->
+decode_decimal(Decode, Bin, _P, 0) when Decode =:= float ->
+    float(binary_to_integer(Bin));
+decode_decimal(Decode, Bin, P, S) when Decode =:= auto, P =< 15, S > 0;
+                                       Decode =:= number;
+                                       Decode =:= float ->
     binary_to_float(Bin);
-decode_decimal(Bin, P, S) when P >= 16, S > 0 ->
+decode_decimal(_Decode, Bin, _P, _S) ->
     Bin.
 
 %% -- Protocol basics: packets --
@@ -1645,10 +1656,51 @@ decode_text_test() ->
                   [?TYPE_FLOAT, ?TYPE_DOUBLE]),
     %% Decimal types
     lists:foreach(fun (T) ->
-                      ColDef = #col{type = T, decimals = 1, length = 4},
-                      ?assertMatch(3.0, decode_text(ColDef, <<"3.0">>))
+                      ColDef = #col{type = T},
+                      ?assertMatch(3.0,
+                                   decode_text(ColDef#col{decimals = 1, length = 4},
+                                               <<"3.0">>)),
+                      %% Decimal Decode Options
+                      %% A small value like this would be returned as a float
+                      %% if decode_decimal=auto. We want to test forcing the
+                      %% binary return
+                      ?assertMatch(<<"3.0">>,
+                                   decode_text(ColDef#col{decimals = 1, length = 4,
+                                                          decode_decimal = binary},
+                                               <<"3.0">>)),
+                      %% When decode_decimal=number, we expect a float and accept the
+                      %% precision loss.
+                      ?assertMatch(123456789.45678912,
+                                   decode_text(ColDef#col{decimals = 12, length = 23,
+                                                          decode_decimal = number},
+                                               <<"123456789.456789123456789">>)),
+                      %% When with decode_decimal=auto, we expect a binary
+                      %% when the float value is large
+                      ?assertMatch(<<"12345678901234567890.12345">>,
+                                   decode_text(ColDef#col{decimals = 5, length = 27,
+                                                          decode_decimal = auto},
+                                               <<"12345678901234567890.12345">>)),
+                      %% When with decode_decimal=number, we expect a large float
+                      ?assertMatch(1.2345678901234567e19,
+                                   decode_text(ColDef#col{decimals = 5, length = 27,
+                                                          decode_decimal = number},
+                                               <<"12345678901234567890.12345">>)),
+                      %% When decode_decimal=float, even if the expected return value
+                      %% would be an integer, force the float
+                      ?assertMatch(3.0,
+                                   decode_text(ColDef#col{decimals = 0, length = 2,
+                                                          decode_decimal = float},
+                                               <<"3">>)),
+                      %% decimal_decode=auto will encode to binary to prevent
+                      %% the loss of precision when converting to float, so
+                      %% this is just testing to ensure that that happens.
+                      ?assertMatch(<<"123456789.456789123456789">>,
+                                   decode_text(ColDef#col{decimals = 12, length = 23,
+                                                          decode_decimal = auto},
+                                               <<"123456789.456789123456789">>))
                   end,
                   [?TYPE_DECIMAL, ?TYPE_NEWDECIMAL]),
+
     ?assertEqual(3.0,  decode_text(#col{type = ?TYPE_FLOAT}, <<"3">>)),
     ?assertEqual(30.0, decode_text(#col{type = ?TYPE_FLOAT}, <<"3e1">>)),
     ?assertEqual(3,    decode_text(#col{type = ?TYPE_LONG}, <<"3">>)),

+ 2 - 2
test/mysql_protocol_tests.erl

@@ -44,7 +44,7 @@ resultset_test() ->
     ExpectedCommunication = [{send, ExpectedReq},
                              {recv, ExpectedResponse}],
     Sock = mock_tcp:create(ExpectedCommunication),
-    {ok, [ResultSet]} = mysql_protocol:query(Query, mock_tcp, Sock, [],
+    {ok, [ResultSet]} = mysql_protocol:query(Query, mock_tcp, Sock, [], auto,
                                              no_filtermap_fun, infinity),
     mock_tcp:close(Sock),
     ?assertMatch(#resultset{cols = [#col{name = <<"@@version_comment">>}],
@@ -82,7 +82,7 @@ resultset_error_test() ->
         "48 04 23 48 59 30 30 30    4e 6f 20 74 61 62 6c 65    H.#HY000No table"
         "73 20 75 73 65 64                                     s used"),
     Sock = mock_tcp:create([{send, ExpectedReq}, {recv, ExpectedResponse}]),
-    {ok, [Result]} = mysql_protocol:query(Query, mock_tcp, Sock, [],
+    {ok, [Result]} = mysql_protocol:query(Query, mock_tcp, Sock, [], auto,
                                           no_filtermap_fun, infinity),
     ?assertMatch(#error{}, Result),
     mock_tcp:close(Sock),

+ 11 - 9
test/mysql_tests.erl

@@ -40,6 +40,7 @@
                           "  f FLOAT,"
                           "  d DOUBLE,"
                           "  dc DECIMAL(5,3),"
+                          "  ldc DECIMAL(25,3),"
                           "  y YEAR,"
                           "  ti TIME,"
                           "  ts TIMESTAMP,"
@@ -641,8 +642,8 @@ multi_statements(Pid) ->
 
 text_protocol(Pid) ->
     ok = mysql:query(Pid, ?create_table_t),
-    ok = mysql:query(Pid, <<"INSERT INTO t (bl, f, d, dc, y, ti, ts, da, c)"
-                            " VALUES ('blob', 3.14, 3.14, 3.14, 2014,"
+    ok = mysql:query(Pid, <<"INSERT INTO t (bl, f, d, dc, ldc, y, ti, ts, da, c)"
+                            " VALUES ('blob', 3.14, 3.14, 3.14, 3.14, 2014,"
                             "'00:22:11', '2014-11-03 00:22:24', '2014-11-03',"
                             " NULL)">>),
     ?assertEqual(1, mysql:warning_count(Pid)), %% tx has no default value
@@ -652,9 +653,10 @@ text_protocol(Pid) ->
     %% select
     {ok, Columns, Rows} = mysql:query(Pid, <<"SELECT * FROM t">>),
     ?assertEqual([<<"id">>, <<"bl">>, <<"tx">>, <<"f">>, <<"d">>, <<"dc">>,
-                  <<"y">>, <<"ti">>, <<"ts">>, <<"da">>, <<"c">>], Columns),
+                  <<"ldc">>, <<"y">>, <<"ti">>, <<"ts">>, <<"da">>, <<"c">>],
+                 Columns),
     ?assertEqual([[1, <<"blob">>, <<>>, 3.14, 3.14, 3.14,
-                   2014, {0, {0, 22, 11}},
+                   <<"3.140">>, 2014, {0, {0, 22, 11}},
                    {{2014, 11, 03}, {00, 22, 24}}, {2014, 11, 03}, null]],
                  Rows),
 
@@ -663,22 +665,22 @@ text_protocol(Pid) ->
 binary_protocol(Pid) ->
     ok = mysql:query(Pid, ?create_table_t),
     %% The same queries as in the text protocol. Expect the same results.
-    {ok, Ins} = mysql:prepare(Pid, <<"INSERT INTO t (bl, tx, f, d, dc, y, ti,"
+    {ok, Ins} = mysql:prepare(Pid, <<"INSERT INTO t (bl, tx, f, d, dc, ldc, y, ti,"
                                      " ts, da, c)"
-                                     " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)">>),
+                                     " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)">>),
     %% 16#161 is the codepoint for "s with caron"; <<197, 161>> in UTF-8.
     ok = mysql:execute(Pid, Ins, [<<"blob">>, [16#161], 3.14, 3.14, 3.14,
-                                  2014, {0, {0, 22, 11}},
+                                  3.14, 2014, {0, {0, 22, 11}},
                                   {{2014, 11, 03}, {0, 22, 24}},
                                   {2014, 11, 03}, null]),
 
     {ok, Stmt} = mysql:prepare(Pid, <<"SELECT * FROM t WHERE id=?">>),
     {ok, Columns, Rows} = mysql:execute(Pid, Stmt, [1]),
     ?assertEqual([<<"id">>, <<"bl">>, <<"tx">>, <<"f">>, <<"d">>, <<"dc">>,
-                  <<"y">>, <<"ti">>,
+                  <<"ldc">>, <<"y">>, <<"ti">>,
                   <<"ts">>, <<"da">>, <<"c">>], Columns),
     ?assertEqual([[1, <<"blob">>, <<197, 161>>, 3.14, 3.14, 3.14,
-                   2014, {0, {0, 22, 11}},
+                   <<"3.140">>, 2014, {0, {0, 22, 11}},
                    {{2014, 11, 03}, {00, 22, 24}}, {2014, 11, 03}, null]],
                  Rows),