Browse Source

More tests fixes #2

Viktor Söderqvist 10 years ago
parent
commit
67b61d6496
6 changed files with 61 additions and 27 deletions
  1. 3 1
      include/records.hrl
  2. 4 9
      src/mysql.erl
  3. 8 8
      src/mysql_protocol.erl
  4. 9 5
      test/mysql_protocol_tests.erl
  5. 35 2
      test/mysql_tests.erl
  6. 2 2
      test/transaction_tests.erl

+ 3 - 1
include/records.hrl

@@ -46,7 +46,9 @@
 %% received from the server using either the text protocol or the binary
 %% protocol.
 -record(resultset, {cols :: [#col{}],
-                    rows :: [[term()] | binary()]}).
+                    rows :: [[term()] | binary()],
+                    status :: integer(),
+                    warning_count :: integer()}).
 
 %% Response of a successfull prepare call.
 -record(prepared, {statement_id :: integer(),

+ 4 - 9
src/mysql.erl

@@ -661,12 +661,7 @@ handle_call(commit, _From, State = #state{socket = Socket, status = Status,
     end,
     Res = #ok{} = mysql_protocol:query(Query, gen_tcp, Socket, ?cmd_timeout),
     State1 = update_state(State, Res),
-    {reply, ok, State1#state{transaction_level = L - 1}};
-handle_call(Trans, _From, State) when Trans == start_transaction;
-                                      Trans == rollback;
-                                      Trans == commit ->
-    %% The 'in transaction' flag doesn't match the level we have in the state.
-    {reply, {error, incorrectly_nested}, State}.
+    {reply, ok, State1#state{transaction_level = L - 1}}.
 
 %% @private
 handle_cast(_Msg, State) ->
@@ -767,12 +762,12 @@ update_state(State, Rec) ->
         #ok{status = S, affected_rows = R, insert_id = Id, warning_count = W} ->
             State#state{status = S, affected_rows = R, insert_id = Id,
                         warning_count = W};
-        %#eof{status = S, warning_count = W} ->
-        %    State#state{status = S, warning_count = W, affected_rows = 0};
+        #resultset{status = S, warning_count = W} ->
+            State#state{status = S, warning_count = W};
         #prepared{warning_count = W} ->
             State#state{warning_count = W};
         _Other ->
-            %% This includes errors, resultsets, etc.
+            %% This includes errors.
             %% Reset some things. (Note: We don't reset status and insert_id.)
             State#state{warning_count = 0, affected_rows = 0}
     end,

+ 8 - 8
src/mysql_protocol.erl

@@ -30,8 +30,8 @@
          query/4, fetch_query_response/3,
          prepare/3, unprepare/3, execute/5, fetch_execute_response/3]).
 
-%% How much data do we want to send at most?
--define(MAX_BYTES_PER_PACKET, 50000000).
+%% How much data do we want per packet?
+-define(MAX_BYTES_PER_PACKET, 16#1000000).
 
 -include("records.hrl").
 -include("protocol.hrl").
@@ -237,9 +237,8 @@ parse_handshake(<<10, Rest/binary>>) ->
     %% "Due to Bug#59453 the auth-plugin-name is missing the terminating
     %% NUL-char in versions prior to 5.5.10 and 5.6.2."
     %% Strip the final NUL byte if any.
-    NameLen = size(AuthPluginName) - 1,
-    AuthPluginName1 = case AuthPluginName of
-        <<NameNoNul:NameLen/binary-unit:8, 0>> -> NameNoNul;
+    AuthPluginName1 = case binary:last(AuthPluginName) of
+        0 -> binary:part(AuthPluginName, 0, byte_size(AuthPluginName) - 1);
         _ -> AuthPluginName
     end,
     #handshake{server_version = server_version_to_list(ServerVersion),
@@ -311,7 +310,7 @@ parse_handshake_confirm(Packet) ->
             %% Connection complete.
             parse_ok_packet(Packet);
         ?error_pattern ->
-            %% "Insufficient Client Capabilities"
+            %% Access denied, insufficient client capabilities, etc.
             parse_error_packet(Packet);
         <<?EOF>> ->
             %% "Old Authentication Method Switch Request Packet consisting of a
@@ -338,11 +337,12 @@ fetch_resultset(TcpModule, Socket, FieldCount, SeqNum) ->
     {ok, ColDefs, SeqNum1} = fetch_column_definitions(TcpModule, Socket, SeqNum,
                                                       FieldCount, []),
     {ok, DelimiterPacket, SeqNum2} = recv_packet(TcpModule, Socket, SeqNum1),
-    #eof{} = parse_eof_packet(DelimiterPacket),
+    #eof{status = S, warning_count = W} = parse_eof_packet(DelimiterPacket),
     case fetch_resultset_rows(TcpModule, Socket, SeqNum2, []) of
         {ok, Rows, _SeqNum3} ->
             ColDefs1 = lists:map(fun parse_column_definition/1, ColDefs),
-            #resultset{cols = ColDefs1, rows = Rows};
+            #resultset{cols = ColDefs1, rows = Rows,
+                       status = S, warning_count = W};
         #error{} = E ->
             E
     end.

+ 9 - 5
test/mysql_protocol_tests.erl

@@ -17,6 +17,10 @@
 %% along with this program. If not, see <https://www.gnu.org/licenses/>.
 
 %% @doc Eunit test cases for the mysql_protocol module.
+%% Most of the hexdump tests are from examples in the protocol documentation.
+%%
+%% TODO: Use ngrep -x -q -d lo '' 'port 3306' to dump traffic using various
+%% server versions.
 -module(mysql_protocol_tests).
 
 -include_lib("eunit/include/eunit.hrl").
@@ -107,6 +111,11 @@ prepare_test() ->
                  Result),
     ok.
     
+bad_protocol_version_test() ->
+    Sock = mock_tcp:create([{recv, <<2, 0, 0, 0, 9, 0>>}]),
+    ?assertError(unknown_protocol,
+                 mysql_protocol:handshake("foo", "bar", "db", mock_tcp, Sock)),
+    mock_tcp:close(Sock).
 
 %% --- Helper functions for the above tests ---
 
@@ -141,8 +150,3 @@ hexdump_to_bin_test() ->
                16#65, 16#63, 16#74, 16#20, 16#55, 16#53, 16#45, 16#52,
                16#28, 16#29>>,
     ?assertEqual(Expect, hexdump_to_bin(HexDump)).
-
-%% --- Fake socket ---
-%%
-
-

+ 35 - 2
test/mysql_tests.erl

@@ -38,7 +38,18 @@
                           "  c CHAR(2)"
                           ") ENGINE=InnoDB">>).
 
-connect_test() ->
+failing_connect_test() ->
+    process_flag(trap_exit, true),
+    ?assertMatch({error, {1045, <<"28000">>, <<"Access denied", _/binary>>}},
+                 mysql:start_link([{user, "dummy"}, {password, "junk"}])),
+    receive
+        {'EXIT', _Pid, {1045, <<"28000">>, <<"Access denie", _/binary>>}} -> ok
+    after 1000 ->
+        ?assertEqual(ok, no_exit_message)
+    end,
+    process_flag(trap_exit, false).
+
+successful_connect_test() ->
     %% A connection with a registered name
     Options = [{name, {local, tardis}}, {user, ?user}, {password, ?password}],
     {ok, Pid} = mysql:start_link(Options),
@@ -125,6 +136,27 @@ connect_with_db(_Pid) ->
                  LoggedErrors),
     exit(Pid, normal).
 
+log_warnings_test() ->
+    {ok, Pid} = mysql:start_link([{user, ?user}, {password, ?password}]),
+    ok = mysql:query(Pid, <<"CREATE DATABASE otptest">>),
+    ok = mysql:query(Pid, <<"USE otptest">>),
+    %% Capture error log to check that we get a warning logged
+    ok = mysql:query(Pid, "CREATE TABLE foo (x INT NOT NULL)"),
+    {ok, insrt} = mysql:prepare(Pid, insrt, "INSERT INTO foo () VALUES ()"),
+    {ok, ok, LoggedErrors} = error_logger_acc:capture(fun () ->
+        ok = mysql:query(Pid, "INSERT INTO foo () VALUES ()"),
+        ok = mysql:query(Pid, "INSeRT INtO foo () VaLUeS ()", []),
+        ok = mysql:execute(Pid, insrt, [])
+    end),
+    [{_, Log1}, {_, Log2}, {_, Log3}] = LoggedErrors,
+    ?assertEqual("Warning 1364: Field 'x' doesn't have a default value\n"
+                 " in INSERT INTO foo () VALUES ()\n", Log1),
+    ?assertEqual("Warning 1364: Field 'x' doesn't have a default value\n"
+                 " in INSeRT INtO foo () VaLUeS ()\n", Log2),
+    ?assertEqual("Warning 1364: Field 'x' doesn't have a default value\n"
+                 " in prepared statement insrt\n", Log3),
+    exit(Pid, normal).
+
 autocommit(Pid) ->
     ?assert(mysql:autocommit(Pid)),
     ok = mysql:query(Pid, <<"SET autocommit = 0">>),
@@ -505,7 +537,8 @@ parameterized_query(Conn) ->
     {ok, _, []} = mysql:query(Conn, "SELECT * FROM foo WHERE bar = ?", [1]),
     {ok, _, []} = mysql:query(Conn, "SELECT * FROM foo WHERE bar = ?", [2]),
     receive after 150 -> ok end, %% Now the query cache should emptied
-    {ok, _, []} = mysql:query(Conn, "SELECT * FROM foo WHERE bar = ?", [3]).
+    {ok, _, []} = mysql:query(Conn, "SELECT * FROM foo WHERE bar = ?", [3]),
+    {error, {_, _, _}} = mysql:query(Conn, "Lorem ipsum dolor sit amet", [x]).
 
 %% --- simple gen_server callbacks ---
 

+ 2 - 2
test/transaction_tests.erl

@@ -225,7 +225,7 @@ deadlock_prepared_statements({Conn1, Conn2}) ->
                 ok = mysql:execute(Conn2, upd, [2, 1])
             end),
             ok
-        end),
+        end, 2),
         MainPid ! done
     end),
 
@@ -241,7 +241,7 @@ deadlock_prepared_statements({Conn1, Conn2}) ->
             ok = mysql:execute(Conn1, upd, [1, 2])
         end),
         ok
-    end),
+    end, 2),
 
     %% Wait for a reply from worker 2.
     receive done -> ok end,