Browse Source

Parameter validation for prepared statements and parameterized queries

The connection process crashed when invalid parameters were given to a
parameterized query or to the execution of a prepared statement.
juhlig 6 years ago
parent
commit
d3a4c78ea9
3 changed files with 122 additions and 7 deletions
  1. 14 4
      src/mysql.erl
  2. 99 1
      src/mysql_protocol.erl
  3. 9 2
      test/mysql_tests.erl

+ 14 - 4
src/mysql.erl

@@ -357,7 +357,13 @@ query(Conn, Query, Params, FilterMap) when (Params == no_params orelse
 query(Conn, Query, no_params, FilterMap, Timeout) ->
     query_call(Conn, {query, Query, FilterMap, Timeout});
 query(Conn, Query, Params, FilterMap, Timeout) ->
-    query_call(Conn, {param_query, Query, Params, FilterMap, Timeout}).
+    case mysql_protocol:valid_params(Params) of
+        true ->
+            query_call(Conn,
+                       {param_query, Query, Params, FilterMap, Timeout});
+        false ->
+            error(badarg)
+    end.
 
 %% @doc Executes a prepared statement with the default query timeout as given
 %% to start_link/1.
@@ -421,7 +427,13 @@ execute(Conn, StatementRef, Params, FilterMap) when FilterMap == no_filtermap_fu
        Timeout :: default_timeout | timeout(),
        Result :: query_result().
 execute(Conn, StatementRef, Params, FilterMap, Timeout) ->
-    query_call(Conn, {execute, StatementRef, Params, FilterMap, Timeout}).
+    case mysql_protocol:valid_params(Params) of
+        true ->
+            query_call(Conn,
+                       {execute, StatementRef, Params, FilterMap, Timeout});
+        false ->
+            error(badarg)
+    end.
 
 %% @doc Creates a prepared statement from the passed query.
 %% @see prepare/3
@@ -729,5 +741,3 @@ query_call(Conn, CallReq) ->
         Result ->
             Result
     end.
-
-

+ 99 - 1
src/mysql_protocol.erl

@@ -31,7 +31,7 @@
          query/4, query/5, fetch_query_response/3,
          fetch_query_response/4, prepare/3, unprepare/3,
          execute/5, execute/6, fetch_execute_response/3,
-         fetch_execute_response/4]).
+         fetch_execute_response/4, valid_params/1]).
 
 -type query_filtermap() :: no_filtermap_fun
                          | fun(([term()]) -> query_filtermap_res())
@@ -1013,6 +1013,45 @@ encode_param({D, {H, M, S}}) when is_float(S), S > 0.0, D < 0 ->
 encode_param({D, {H, M, 0.0}}) ->
     encode_param({D, {H, M, 0}}).
 
+%% @doc Checks if the given Parameters can be encoded for use in the
+%% binary protocol. Returns `true' if all of the parameters can be
+%% encoded, `false' if any of them cannot be encoded.
+-spec valid_params([term()]) -> boolean().
+valid_params(Values) when is_list(Values) ->
+    lists:all(fun is_valid_param/1, Values).
+
+%% @doc Checks if the given parameter can be encoded for use in the
+%% binary protocol.
+-spec is_valid_param(term()) -> boolean().
+is_valid_param(null) ->
+    true;
+is_valid_param(Value) when is_list(Value) ->
+    try
+        unicode:characters_to_binary(Value)
+    of
+        Value1 when is_binary(Value1) ->
+            true;
+        _ErrorOrIncomplete ->
+            false
+    catch
+        error:badarg ->
+            false
+    end;
+is_valid_param(Value) when is_number(Value) ->
+    true;
+is_valid_param(Value) when is_bitstring(Value) ->
+    true;
+is_valid_param({Y, M, D}) ->
+    is_integer(Y) andalso is_integer(M) andalso is_integer(D);
+is_valid_param({{Y, M, D}, {H, Mi, S}}) ->
+    is_integer(Y) andalso is_integer(M) andalso is_integer(D) andalso
+    is_integer(H) andalso is_integer(Mi) andalso is_number(S);
+is_valid_param({D, {H, M, S}}) ->
+    is_integer(D) andalso
+    is_integer(H) andalso is_integer(M) andalso is_number(S);
+is_valid_param(_) ->
+    false.
+
 %% -- Value representation in both the text and binary protocols --
 
 %% @doc Convert to `<<_:Length/bitstring>>'
@@ -1402,5 +1441,64 @@ hash_password_test() ->
                  hash_password(<<"foo">>, <<"abcdefghijklmnopqrst">>)),
     ?assertEqual(<<>>, hash_password(<<>>, <<"abcdefghijklmnopqrst">>)).
 
+valid_params_test() ->
+    ValidParams = [
+        null,
+        1,
+        0.5,
+        <<>>, <<$x>>, <<0:1>>,
+
+        %% valid unicode
+        [], [$x], [16#E4],
+
+        %% valid date
+        {1, 2, 3},
+
+        %% valid time
+        {1, {2, 3, 4}}, {1, {2, 3, 4.5}},
+
+        %% valid datetime
+        {{1, 2, 3}, {4, 5, 6}}, {{1, 2, 3}, {4, 5, 6.5}}
+    ],
+
+    InvalidParams = [
+        x,
+        [x],
+        {},
+        self(),
+        make_ref(),
+        fun () -> ok end,
+
+        %% invalid unicode
+        [16#FFFFFFFF],
+
+        %% invalid date
+        {x, 1, 2}, {1, x, 2}, {1, 2, x},
+
+        %% invalid time
+        {x, {1, 2, 3}}, {1, {x, 2, 3}},
+        {1, {2, x, 3}}, {1, {2, 3, x}},
+
+        %% invalid datetime
+        {{x, 1, 2}, {3, 4, 5}}, {{1, x, 2}, {3, 4, 5}},
+        {{1, 2, x}, {3, 4, 5}}, {{1, 2, 3}, {x, 4, 5}},
+        {{1, 2, 3}, {4, x, 5}}, {{1, 2, 3}, {4, 5, x}}
+    ],
+
+    lists:foreach(
+        fun (ValidParam) ->
+            ?assert(is_valid_param(ValidParam))
+        end,
+        ValidParams),
+    ?assert(valid_params(ValidParams)),
+
+    lists:foreach(
+        fun (InvalidParam) ->
+            ?assertNot(is_valid_param(InvalidParam))
+        end,
+        InvalidParams),
+    ?assertNot(valid_params(InvalidParams)),
+    ?assertNot(valid_params(ValidParams ++ InvalidParams)).
+
 -endif.
 

+ 9 - 2
test/mysql_tests.erl

@@ -213,7 +213,8 @@ query_test_() ->
           {"TIME",                 fun () -> time(Pid) end},
           {"DATETIME",             fun () -> datetime(Pid) end},
           {"JSON",                 fun () -> json(Pid) end},
-          {"Microseconds",         fun () -> microseconds(Pid) end}]
+          {"Microseconds",         fun () -> microseconds(Pid) end},
+          {"Invalid params",       fun () -> invalid_params(Pid) end}]
      end}.
 
 connect_with_db(_Pid) ->
@@ -684,6 +685,12 @@ test_datetime_microseconds(Pid) ->
                            <<"dt">>),
     ok = mysql:query(Pid, "DROP TABLE dt").
 
+invalid_params(Pid) ->
+    {ok, StmtId} = mysql:prepare(Pid, "SELECT ?"),
+    ?assertError(badarg, mysql:execute(Pid, StmtId, [x])),
+    ?assertError(badarg, mysql:query(Pid, "SELECT ?", [x])),
+    ok = mysql:unprepare(Pid, StmtId).
+
 %% @doc Tests write and read in text and the binary protocol, all combinations.
 %% This helper function assumes an empty table with a single column.
 write_read_text_binary(Conn, Term, SqlLiteral, Table, Column) ->
@@ -801,7 +808,7 @@ parameterized_query(Conn) ->
     {ok, _, []} = mysql:query(Conn, "SELECT * FROM foo WHERE bar = ?", [2]),
     receive after 150 -> ok end, %% Now the query cache should emptied
     {ok, _, []} = mysql:query(Conn, "SELECT * FROM foo WHERE bar = ?", [3]),
-    {error, {_, _, _}} = mysql:query(Conn, "Lorem ipsum dolor sit amet", [x]).
+    {error, {_, _, _}} = mysql:query(Conn, "Lorem ipsum dolor sit amet", [4]).
 
 %% --- simple gen_server callbacks ---