Browse Source

mysql:encode/2 with tests

Viktor Söderqvist 10 years ago
parent
commit
4643b47956
4 changed files with 146 additions and 1 deletions
  1. 21 1
      src/mysql.erl
  2. 67 0
      src/mysql_encode.erl
  3. 46 0
      test/mysql_encode_tests.erl
  4. 12 0
      test/mysql_tests.erl

+ 21 - 1
src/mysql.erl

@@ -26,7 +26,7 @@
 -export([start_link/1, query/2, query/3, query/4, execute/3, execute/4,
          prepare/2, prepare/3, unprepare/2,
          warning_count/1, affected_rows/1, autocommit/1, insert_id/1,
-         in_transaction/1,
+         encode/2, in_transaction/1,
          transaction/2, transaction/3, transaction/4]).
 
 -export_type([connection/0, server_reason/0]).
@@ -379,6 +379,24 @@ transaction(Conn, Fun, Args, Retries) when is_list(Args),
             {aborted, Aborted}
     end.
 
+%% @doc Encodes a term as a MySQL literal so that it can be used to inside a
+%% query. If backslash escapes are enabled, backslashes and single quotes in
+%% strings and binaries are escaped. Otherwise only single quotes are escaped.
+%%
+%% Note that the preferred way of sending values is by prepared statements or
+%% parametrized queries with placeholders.
+%%
+%% @see query/3
+%% @see execute/30
+-spec encode(connection(), term()) -> iodata().
+encode(Conn, Term) ->
+    Term1 = case (is_list(Term) orelse is_binary(Term)) andalso
+                 gen_server:call(Conn, backslash_escapes_enabled) of
+        true  -> mysql_encode:backslash_escape(Term);
+        false -> Term
+    end,
+    mysql_encode:encode(Term1).
+
 %% --- Gen_server callbacks ---
 
 -include("records.hrl").
@@ -628,6 +646,8 @@ handle_call(affected_rows, _From, State) ->
     {reply, State#state.affected_rows, State};
 handle_call(autocommit, _From, State) ->
     {reply, State#state.status band ?SERVER_STATUS_AUTOCOMMIT /= 0, State};
+handle_call(backslash_escapes_enabled, _From, State = #state{status = S}) ->
+    {reply, S band ?SERVER_STATUS_NO_BACKSLASH_ESCAPES == 0, State};
 handle_call(in_transaction, _From, State) ->
     {reply, State#state.status band ?SERVER_STATUS_IN_TRANS /= 0, State};
 handle_call(start_transaction, _From,

+ 67 - 0
src/mysql_encode.erl

@@ -0,0 +1,67 @@
+%% @private
+%% @doc Functions for encoding a term as an SQL literal. This is not really
+%% part of the protocol; thus the separate module.
+-module(mysql_encode).
+
+-export([encode/1, backslash_escape/1]).
+
+%% @doc Encodes a term as an ANSI SQL literal so that it can be used to inside
+%% a query. In strings only single quotes (') are escaped. If backslash escapes
+%% are enabled for the connection, you should first use backslash_escape/1 to
+%% escape backslashes in strings.
+-spec encode(term()) -> iodata().
+encode(null) -> <<"NULL">>;
+encode(Int) when is_integer(Int) ->
+    integer_to_binary(Int);
+encode(Float) when is_float(Float) ->
+    %% "floats are printed accurately as the shortest, correctly rounded string"
+    io_lib:format("~w", [Float]);
+encode(String) when is_list(String); is_binary(String) ->
+    Bin = iolist_to_binary(String),
+    Escaped = binary:replace(Bin, <<"'">>, <<"''">>),
+    [$', Escaped, $'];
+encode(Bitstring) when is_bitstring(Bitstring) ->
+    ["b'", [ case B of 0 -> $0; 1 -> $1 end || <<B:1>> <= Bitstring ], $'];
+encode({Y, M, D}) ->
+    io_lib:format("'~4..0b-~2..0b-~2..0b'", [Y, M, D]);
+encode({{Y, M, D}, {H, Mi, S}}) when is_integer(S) ->
+    io_lib:format("'~4..0b-~2..0b-~2..0b ~2..0b:~2..0b:~2..0b'",
+                  [Y, M, D, H, Mi, S]);
+encode({{Y, M, D}, {H, Mi, S}}) when is_float(S) ->
+    io_lib:format("'~4..0b-~2..0b-~2..0b ~2..0b:~2..0b:~9.6.0f'",
+                  [Y, M, D, H, Mi, S]);
+encode({D, {H, M, S}}) when D >= 0 ->
+    Args = [H1 = D * 24 + H, M, S],
+    if
+        H1 > 99, is_integer(S) -> io_lib:format("'~b:~2..0b:~2..0b'", Args);
+        H1 > 99, is_float(S)   -> io_lib:format("'~b:~2..0b:~9.6.0f'", Args);
+        is_integer(S)          -> io_lib:format("'~2..0b:~2..0b:~2..0b'", Args);
+        is_float(S)            -> io_lib:format("'~2..0b:~2..0b:~9.6.0f'", Args)
+    end;
+encode({D, {H, M, S}}) when D < 0, is_integer(S) ->
+    Sec = (D * 24 + H) * 3600 + M * 60 + S,
+    {D1, {H1, M1, S1}} = calendar:seconds_to_daystime(-Sec),
+    Args = [H2 = D1 * 24 + H1, M1, S1],
+    if
+        H2 > 99 -> io_lib:format("'-~b:~2..0b:~2..0b'", Args);
+        true    -> io_lib:format("'-~2..0b:~2..0b:~2..0b'", Args)
+    end;
+encode({D, {H, M, S}}) when D < 0, is_float(S) ->
+    SInt = trunc(S), % trunc(57.654321) = 57
+    {SInt1, Frac} = case S - SInt of % 57.6543 - 57 = 0.654321
+        0.0  -> {SInt, 0.0};
+        Rest -> {SInt + 1, 1 - Rest} % {58, 0.345679}
+    end,
+    Sec = (D * 24 + H) * 3600 + M * 60 + SInt1,
+    {D1, {H1, M1, S1}} = calendar:seconds_to_daystime(-Sec),
+    Args = [H2 = D1 * 24 + H1, M1, S1 + Frac],
+    if
+        H2 > 99 -> io_lib:format("'-~b:~2..0b:~9.6.0f'", Args);
+        true    -> io_lib:format("'-~2..0b:~2..0b:~9.6.0f'", Args)
+    end.
+
+%% @doc Escapes backslashes with an extra backslash. This is necessary if
+%% backslash escapes are enabled in the session.
+backslash_escape(String) ->
+    Bin = iolist_to_binary(String),
+    binary:replace(Bin, <<"\\">>, <<"\\\\">>).

+ 46 - 0
test/mysql_encode_tests.erl

@@ -0,0 +1,46 @@
+%% @doc This test suite does not require an actual MySQL connection.
+-module(mysql_encode_tests).
+-include_lib("eunit/include/eunit.hrl").
+
+encode_test() ->
+    lists:foreach(
+        fun ({Term, Sql}) ->
+            ?assertEqual(iolist_to_binary(Sql),
+                         iolist_to_binary(mysql_encode:encode(Term)))
+        end,
+        [{null,    "NULL"},
+         {42,      "42"},
+         {3.14,    "3.14"},
+         {"don't", "'don''t'"}, %% Escape single quote using single quote.
+         {"\\n",   "'\\n'"},    %% Don't escape backslash.
+         %% BIT(N)
+         {<<255, 2:3>>,   "b'11111111010'"},
+         %% DATE
+         {{2014, 11, 03}, "'2014-11-03'"},
+         {{0, 0, 0},      "'0000-00-00'"},
+         %% TIME
+         {{0, {10, 11, 12}},   "'10:11:12'"},
+         {{5, {0, 0, 1}},     "'120:00:01'"},
+         {{-1, {23, 59, 59}}, "'-00:00:01'"},
+         {{-1, {23, 59, 0}},  "'-00:01:00'"},
+         {{-1, {23, 0, 0}},   "'-01:00:00'"},
+         {{-1, {0, 0, 0}},    "'-24:00:00'"},
+         {{-5, {10, 0, 0}},  "'-110:00:00'"},
+         {{0, {0, 0, 0}},      "'00:00:00'"},
+         %% TIME with microseconds
+         {{0, {23, 59, 57.654321}},   "'23:59:57.654321'"},
+         {{5, {0, 0, 1.1}},          "'120:00:01.100000'"},
+         {{-1, {23, 59, 57.654321}}, "'-00:00:02.345679'"},
+         {{-1, {23, 59,  0.0}},      "'-00:01:00.000000'"},
+         {{-6, {23, 59, 57.0}},     "'-120:00:03.000000'"},
+         %% DATETIME
+         {{{2014, 12, 14}, {19, 39, 20}},   "'2014-12-14 19:39:20'"},
+         {{{2014, 12, 14}, {0, 0, 0}},      "'2014-12-14 00:00:00'"},
+         {{{0, 0, 0}, {0, 0, 0}},           "'0000-00-00 00:00:00'"},
+         %% DATETIME with microseconds
+         {{{2014, 11, 23}, {23, 59, 57.654321}}, "'2014-11-23 23:59:57.654321'"}]
+    ).
+
+backslash_escape_test() ->
+    ?assertEqual(<<"a'b\\\\c">>,
+                 iolist_to_binary(mysql_encode:backslash_escape("a'b\\c"))).

+ 12 - 0
test/mysql_tests.erl

@@ -108,6 +108,7 @@ query_test_() ->
      fun (Pid) ->
          [{"Select db on connect", fun () -> connect_with_db(Pid) end},
           {"Autocommit",           fun () -> autocommit(Pid) end},
+          {"Encode",               fun () -> encode(Pid) end},
           {"Basic queries",        fun () -> basic_queries(Pid) end},
           {"Text protocol",        fun () -> text_protocol(Pid) end},
           {"Binary protocol",      fun () -> binary_protocol(Pid) end},
@@ -157,6 +158,17 @@ autocommit(Pid) ->
     ok = mysql:query(Pid, <<"SET autocommit = 1">>),
     ?assert(mysql:autocommit(Pid)).
 
+encode(Pid) ->
+    %% Test with backslash escapes enabled and disabled.
+    {ok, _, [[OldMode]]} = mysql:query(Pid, "SELECT @@sql_mode"),
+    ok = mysql:query(Pid, "SET sql_mode = ''"),
+    ?assertEqual(<<"'foo\\\\bar''baz'">>,
+                 iolist_to_binary(mysql:encode(Pid, "foo\\bar'baz"))),
+    ok = mysql:query(Pid, "SET sql_mode = 'NO_BACKSLASH_ESCAPES'"),
+    ?assertEqual(<<"'foo\\bar''baz'">>,
+                 iolist_to_binary(mysql:encode(Pid, "foo\\bar'baz"))),
+    ok = mysql:query(Pid, "SET sql_mode = ?", [OldMode]).
+
 basic_queries(Pid) ->
 
     %% warning count