Browse Source

Add optional float-as-decimal encoding

* Possibility to send numbers as decimals using {decimal, Value}.
* New option {float_as_decimal, boolean() | non_neg_integer()} to
  do this automatically for floats, optionally rounded to a given
  number of decimals.
Viktor Söderqvist 3 years ago
parent
commit
a74aff1e45
7 changed files with 158 additions and 16 deletions
  1. 13 8
      doc/overview.edoc
  2. 9 2
      src/mysql.erl
  3. 23 5
      src/mysql_conn.erl
  4. 6 0
      src/mysql_encode.erl
  5. 12 1
      src/mysql_protocol.erl
  6. 3 0
      test/mysql_encode_tests.erl
  7. 92 0
      test/mysql_tests.erl

+ 13 - 8
doc/overview.edoc

@@ -69,12 +69,14 @@ For the reference manual see the <a href="mysql.html">mysql</a> module.
       <td>DECIMAL(P, S)</td>
       <td>`integer()' when S == 0<br />
           `float()' when P =&lt; 15 and S &gt; 0<br />
-          `binary()' when P &gt;= 16 and S &gt; 0 [<a href="#vn2">2</a>]</td>
-      <td>`42'<br />`3.14'<br />`<<"3.14159265358979323846">>'</td>
+          `binary()' when P &gt;= 16 and S &gt; 0 [<a href="#vn2">2</a>]<br />
+          `{decimal, Value}' [<a href="#vn3">3</a>] (parameter only)</td>
+      <td>`42'<br />`3.14'<br />`<<"3.14159265358979323846">>'<br />
+          `{decimal, 10.2}'</td>
     </tr>
     <tr>
       <td>DATETIME, TIMESTAMP</td>
-      <td>`calendar:datetime()' [<a href="#vn3">3</a>]</td>
+      <td>`calendar:datetime()' [<a href="#vn4">4</a>]</td>
       <td>`{{2014, 11, 18}, {10, 22, 36}}'</td>
     </tr>
     <tr>
@@ -84,8 +86,8 @@ For the reference manual see the <a href="mysql.html">mysql</a> module.
     </tr>
     <tr>
       <td>TIME</td>
-      <td>`{Days, calendar:time()}' [<a href="#vn3">3</a>,
-          <a href="#vn4">4</a>]</td>
+      <td>`{Days, calendar:time()}' [<a href="#vn4">4</a>,
+          <a href="#vn5">5</a>]</td>
       <td>`{0, {10, 22, 36}}'</td>
     </tr>
     <tr>
@@ -111,17 +113,20 @@ Notes:
     can be represented without precision loss and as `binary()' for high
     precision DECIMAL values. This is similar to how the `odbc' OTP application
     treats DECIMALs.</li>
-  <li id="vn3">For `DATETIME', `TIMESTAMP' and `TIME' values with fractions of
+  <li id="vn3">DECIMALs can be sent as `{decimal, Value}' (where Value is a
+    number, string or binary) but values received from the database are
+    never returned in this format.</li>
+  <li id="vn4">For `DATETIME', `TIMESTAMP' and `TIME' values with fractions of
     seconds, we use a float for the seconds part. (These are unusual and were
     added to MySQL in version 5.6.4.)</li>
-  <li id="vn4">Since `TIME' can be outside the `calendar:time()' interval, we use
+  <li id="vn5">Since `TIME' can be outside the `calendar:time()' interval, we use
     the format as returned by `calendar:seconds_to_daystime/1' for `TIME'
     values.</li>
 </ol>
 
 <h2>Copying</h2>
 
-Copyright 2014-2019 The authors of MySQL/OTP. See the project page at
+Copyright 2014-2021 The authors of MySQL/OTP. See the project page at
 <a href="https://github.com/mysql-otp/mysql-otp"
    target="_top">https://github.com/mysql-otp/mysql-otp</a>.
 

+ 9 - 2
src/mysql.erl

@@ -1,5 +1,5 @@
 %% MySQL/OTP – MySQL client library for Erlang/OTP
-%% Copyright (C) 2014-2015, 2018 Viktor Söderqvist,
+%% Copyright (C) 2014-2015, 2018, 2021 Viktor Söderqvist,
 %%               2016 Johan Lövdahl
 %%               2017 Piotr Nosek, Michal Slaski
 %%
@@ -92,7 +92,8 @@
                 | {found_rows, boolean()}
                 | {query_cache_time, non_neg_integer()}
                 | {tcp_options, [gen_tcp:connect_option()]}
-                | {ssl, term()}.
+                | {ssl, term()}
+                | {float_as_decimal, boolean() | non_neg_integer()}.
 
 -include("exception.hrl").
 
@@ -192,6 +193,12 @@
 %%       The `server_name_indication' option, if omitted, defaults to the value
 %%       of the `host' option if it is a hostname string, otherwise no default
 %%       value is set.</dd>
+%%   <dt>`{float_as_decimal, boolean() | non_neg_integer()}'</dt>
+%%   <dd>Encode floats as decimals when sending parameters for parametrized
+%%       queries and prepared statements to the server. This prevents float
+%%       rounding and truncation errors from happening on the server side. If a
+%%       number is specified, the float is rounded to this number of
+%%       decimals. This is off (false) by default.</dd>
 %% </dl>
 -spec start_link(Options :: [option()]) -> {ok, pid()} | ignore | {error, term()}.
 start_link(Options) ->

+ 23 - 5
src/mysql_conn.erl

@@ -1,5 +1,5 @@
 %% MySQL/OTP – MySQL client library for Erlang/OTP
-%% Copyright (C) 2014-2018 Viktor Söderqvist
+%% Copyright (C) 2014-2021 Viktor Söderqvist
 %%
 %% This file is part of MySQL/OTP.
 %%
@@ -57,7 +57,8 @@
                 connect_timeout, ping_timeout, query_timeout, query_cache_time,
                 affected_rows = 0, status = 0, warning_count = 0, insert_id = 0,
                 transaction_levels = [], ping_ref = undefined,
-                stmts = dict:new(), query_cache = empty, cap_found_rows = false}).
+                stmts = dict:new(), query_cache = empty, cap_found_rows = false,
+                float_as_decimal = false}).
 
 %% @private
 init(Opts) ->
@@ -89,6 +90,7 @@ init(Opts) ->
 
     Queries           = proplists:get_value(queries, Opts, []),
     Prepares          = proplists:get_value(prepare, Opts, []),
+    FloatAsDecimal    = proplists:get_value(float_as_decimal, Opts, false),
 
     true = lists:all(fun mysql_protocol:valid_path/1, AllowedLocalPaths),
 
@@ -111,7 +113,8 @@ init(Opts) ->
         ping_timeout = PingTimeout,
         query_timeout = QueryTimeout,
         query_cache_time = QueryCacheTime,
-        cap_found_rows = (SetFoundRows =:= true)
+        cap_found_rows = (SetFoundRows =:= true),
+        float_as_decimal = FloatAsDecimal
     },
 
     case proplists:get_value(connect_mode, Opts, synchronous) of
@@ -554,9 +557,16 @@ code_change(_OldVsn, _State, _Extra) ->
 %% @doc Executes a prepared statement and returns {Reply, NewState}.
 execute_stmt(Stmt, Args, FilterMap, Timeout, State) ->
     #state{socket = Socket, sockmod = SockMod,
-           allowed_local_paths = AllowedPaths} = State,
+           allowed_local_paths = AllowedPaths,
+           float_as_decimal = FloatAsDecimal} = State,
+    Args1 = case FloatAsDecimal of
+                false ->
+                    Args;
+                _ ->
+                    [float_to_decimal(Arg, FloatAsDecimal) || Arg <- Args]
+            end,
     setopts(SockMod, Socket, [{active, false}]),
-    {ok, Recs} = case mysql_protocol:execute(Stmt, Args, SockMod, Socket,
+    {ok, Recs} = case mysql_protocol:execute(Stmt, Args1, SockMod, Socket,
                                              AllowedPaths, FilterMap,
                                              Timeout) of
         {error, timeout} when State#state.server_version >= [5, 0, 0] ->
@@ -576,6 +586,14 @@ execute_stmt(Stmt, Args, FilterMap, Timeout, State) ->
         andalso log_warnings(State1, Stmt#prepared.orig_query),
     handle_query_call_result(Recs, Stmt#prepared.orig_query, State1).
 
+%% @doc Formats floats as decimals, optionally with a given number of decimals.
+float_to_decimal(Arg, true) when is_float(Arg) ->
+    {decimal, list_to_binary(io_lib:format("~w", [Arg]))};
+float_to_decimal(Arg, N) when is_float(Arg), is_integer(N) ->
+    {decimal, float_to_binary(Arg, [{decimals, N}, compact])};
+float_to_decimal(Arg, _) ->
+    Arg.
+
 %% @doc Produces a tuple to return as an error reason.
 -spec error_to_reason(#error{}) -> mysql:server_reason().
 error_to_reason(#error{code = Code, state = State, msg = Msg}) ->

+ 6 - 0
src/mysql_encode.erl

@@ -23,6 +23,12 @@ encode(String) when is_list(String) ->
     encode(unicode:characters_to_binary(String));
 encode(Bitstring) when is_bitstring(Bitstring) ->
     ["b'", [ case B of 0 -> $0; 1 -> $1 end || <<B:1>> <= Bitstring ], $'];
+encode({decimal, Num}) when is_float(Num); is_integer(Num) ->
+    encode(Num);
+encode({decimal, Str}) when is_binary(Str); is_list(Str) ->
+    %% Simple injection block
+    nomatch = re:run(Str, <<"[^0-9.+\\-eE]">>),
+    Str;
 encode({Y, M, D}) ->
     io_lib:format("'~4..0b-~2..0b-~2..0b'", [Y, M, D]);
 encode({{Y, M, D}, {H, Mi, S}}) when is_integer(S) ->

+ 12 - 1
src/mysql_protocol.erl

@@ -1,5 +1,5 @@
 %% MySQL/OTP – MySQL client library for Erlang/OTP
-%% Copyright (C) 2014 Viktor Söderqvist
+%% Copyright (C) 2014-2021 Viktor Söderqvist
 %%               2017 Piotr Nosek, Michal Slaski
 %%
 %% This file is part of MySQL/OTP.
@@ -1078,6 +1078,14 @@ encode_param(Value) when is_integer(Value), Value < 0 ->
     end;
 encode_param(Value) when is_float(Value) ->
     {<<?TYPE_DOUBLE, 0>>, <<Value:64/float-little>>};
+encode_param({decimal, Value}) ->
+    Bin = if is_binary(Value) -> Value;
+             is_list(Value) -> list_to_binary(Value);
+             is_integer(Value) -> integer_to_binary(Value);
+             is_float(Value) -> list_to_binary(io_lib:format("~w", [Value]))
+          end,
+    EncLength = lenenc_int_encode(byte_size(Bin)),
+    {<<?TYPE_DECIMAL, 0>>, <<EncLength/binary, Bin/binary>>};
 encode_param(Value) when is_bitstring(Value) ->
     Binary = encode_bitstring(Value),
     EncLength = lenenc_int_encode(byte_size(Binary)),
@@ -1147,6 +1155,9 @@ is_valid_param(Value) when is_list(Value) ->
     end;
 is_valid_param(Value) when is_number(Value) ->
     true;
+is_valid_param({decimal, Value}) when is_binary(Value); is_list(Value);
+                                      is_float(Value); is_integer(Value) ->
+    true;
 is_valid_param(Value) when is_bitstring(Value) ->
     true;
 is_valid_param({Y, M, D}) ->

+ 3 - 0
test/mysql_encode_tests.erl

@@ -21,6 +21,9 @@ encode_test() ->
          {<<255, 0, 255, 0>>, <<"'", 255, 0, 255, 0, "'">>},
          %% BIT(N)
          {<<255, 2:3>>,   "b'11111111010'"},
+         %% Explicit decimal
+         {{decimal, 10.2}, "10.2"},
+         {{decimal, "10.2"}, "10.2"},
          %% DATE
          {{2014, 11, 03}, "'2014-11-03'"},
          {{0, 0, 0},      "'0000-00-00'"},

+ 92 - 0
test/mysql_tests.erl

@@ -328,6 +328,9 @@ query_test_() ->
           {"Binary protocol",       fun () -> binary_protocol(Pid) end},
           {"FLOAT rounding",        fun () -> float_rounding(Pid) end},
           {"DECIMAL",               fun () -> decimal(Pid) end},
+          {"DECIMAL truncated",     fun () -> decimal_trunc(Pid) end},
+          {"Float as decimal",      fun () -> float_as_decimal(Pid) end},
+          {"Float as decimal(2)",   fun () -> float_as_decimal_2(Pid) end},
           {"INT",                   fun () -> int(Pid) end},
           {"BIT(N)",                fun () -> bit(Pid) end},
           {"DATE",                  fun () -> date(Pid) end},
@@ -709,6 +712,95 @@ decimal(Pid) ->
                            <<"dec16">>, <<"d">>),
     ok = mysql:query(Pid, "DROP TABLE dec16").
 
+decimal_trunc(_Pid) ->
+    %% Create another connection with log_warnings enabled.
+    {ok, Pid} = mysql:start_link([{user, ?user}, {password, ?password},
+                                  {log_warnings, true}]),
+    ok = mysql:query(Pid, <<"USE otptest">>),
+    ok = mysql:query(Pid, <<"SET autocommit = 1">>),
+    ok = mysql:query(Pid, <<"SET SESSION sql_mode = ?">>, [?SQL_MODE]),
+    ok = mysql:query(Pid, <<"CREATE TABLE `test_decimals` ("
+                            "  `id` bigint(20) unsigned NOT NULL,"
+                            "  `balance` decimal(13,4) NOT NULL,"
+                            "  PRIMARY KEY (`id`)"
+                            ") ENGINE=InnoDB;">>),
+    ok = mysql:query(Pid, <<"INSERT INTO test_decimals (id, balance)"
+                            " VALUES (1, 5000), (2, 5000), (3, 5000);">>),
+    {ok, decr} = mysql:prepare(Pid, decr, <<"UPDATE test_decimals"
+                                            " SET balance = balance - ?"
+                                            " WHERE id = ?">>),
+    %% Decimal sent as float gives truncation warning.
+    {ok, ok, [{_, LoggedWarning1}|_]} = error_logger_acc:capture(fun () ->
+        ok = mysql:execute(Pid, decr, [10.2, 1]),
+        ok = mysql:execute(Pid, decr, [10.2, 1]),
+        ok = mysql:execute(Pid, decr, [10.2, 1]),
+        ok = mysql:execute(Pid, decr, [10.2, 1])
+    end),
+    ?assertMatch("Note 1265: Data truncated for column 'balance'" ++ _,
+                 LoggedWarning1),
+    %% Decimal sent as binary gives truncation warning.
+    {ok, ok, [{_, LoggedWarning2}|_]} = error_logger_acc:capture(fun () ->
+        ok = mysql:execute(Pid, decr, [<<"10.2">>, 2]),
+        ok = mysql:execute(Pid, decr, [<<"10.2">>, 2]),
+        ok = mysql:execute(Pid, decr, [<<"10.2">>, 2]),
+        ok = mysql:execute(Pid, decr, [<<"10.2">>, 2])
+    end),
+    ?assertMatch("Note 1265: Data truncated for column 'balance'" ++ _,
+                 LoggedWarning2),
+    %% Decimal sent as DECIMAL => no warning
+    {ok, ok, []} = error_logger_acc:capture(fun () ->
+        ok = mysql:execute(Pid, decr, [{decimal, <<"10.2">>}, 3]),
+        ok = mysql:execute(Pid, decr, [{decimal, "10.2"}, 3]),
+        ok = mysql:execute(Pid, decr, [{decimal, 10.2}, 3]),
+        ok = mysql:execute(Pid, decr, [{decimal, 10.2}, 3]),
+        ok = mysql:execute(Pid, decr, [{decimal, 0}, 3]) % <- integer coverage
+    end),
+    ?assertMatch({ok, _, [[1, 4959.2], [2, 4959.2], [3, 4959.2]]},
+                 mysql:query(Pid, <<"SELECT id, balance FROM test_decimals">>)),
+    ok = mysql:query(Pid, "DROP TABLE test_decimals"),
+    ok = mysql:stop(Pid).
+
+float_as_decimal(_Pid) ->
+    %% Create another connection with {float_as_decimal, true}
+    {ok, Pid} = mysql:start_link([{user, ?user}, {password, ?password},
+                                  {log_warnings, true},
+                                  {float_as_decimal, true}]),
+    ok = mysql:query(Pid, <<"USE otptest">>),
+    ok = mysql:query(Pid, <<"SET autocommit = 1">>),
+    ok = mysql:query(Pid, <<"SET SESSION sql_mode = ?">>, [?SQL_MODE]),
+    ok = mysql:query(Pid, <<"CREATE TABLE float_as_decimal ("
+                            "  balance decimal(13,4) NOT NULL"
+                            ") ENGINE=InnoDB;">>),
+    ok = mysql:query(Pid, <<"INSERT INTO float_as_decimal (balance)"
+                            " VALUES (5000);">>),
+    {ok, decr} = mysql:prepare(Pid, decr, <<"UPDATE float_as_decimal"
+                                            " SET balance = balance - ?">>),
+    %% Floats sent as decimal => no truncation warning.
+    {ok, ok, []} = error_logger_acc:capture(fun () ->
+        ok = mysql:execute(Pid, decr, [10.2]),
+        ok = mysql:execute(Pid, decr, [10.2]),
+        ok = mysql:execute(Pid, decr, [10.2]),
+        ok = mysql:execute(Pid, decr, [10.2])
+    end),
+    ok = mysql:query(Pid, "DROP TABLE float_as_decimal;"),
+    ok = mysql:stop(Pid).
+
+float_as_decimal_2(_Pid) ->
+    %% Create another connection with {float_as_decimal, 2}.
+    %% Check that floats are sent as DECIMAL with 2 decimals.
+    {ok, Pid} = mysql:start_link([{user, ?user}, {password, ?password},
+                                  {log_warnings, true},
+                                  {float_as_decimal, 2}]),
+    ok = mysql:query(Pid, <<"USE otptest">>),
+    ok = mysql:query(Pid, <<"SET autocommit = 1">>),
+    ok = mysql:query(Pid, <<"SET SESSION sql_mode = ?">>, [?SQL_MODE]),
+    ok = mysql:query(Pid, <<"CREATE TABLE dec13_4 (d DECIMAL(13,4))">>),
+    ok = mysql:query(Pid, <<"INSERT INTO dec13_4 (d) VALUES (?)">>, [3.14159]),
+    {ok, _, [[Value]]} = mysql:query(Pid, <<"SELECT d FROM dec13_4">>),
+    ?assertEqual(3.14, Value),
+    ok = mysql:query(Pid, <<"DROP TABLE dec13_4">>),
+    ok = mysql:stop(Pid).
+
 int(Pid) ->
     ok = mysql:query(Pid, "CREATE TABLE ints (i INT)"),
     write_read_text_binary(Pid, 42, <<"42">>, <<"ints">>, <<"i">>),