Browse Source

Graceful timeout handling using KILL QUERY

Viktor Söderqvist 10 years ago
parent
commit
20b1a8c660
6 changed files with 339 additions and 282 deletions
  1. 5 0
      README.md
  2. 1 1
      include/records.hrl
  3. 162 74
      src/mysql.erl
  4. 131 103
      src/mysql_protocol.erl
  5. 9 104
      test/mysql_protocol_tests.erl
  6. 31 0
      test/mysql_tests.erl

+ 5 - 0
README.md

@@ -16,6 +16,7 @@ Features:
 * Each connection is a gen_server, which makes it compatible with Poolboy (for
 * Each connection is a gen_server, which makes it compatible with Poolboy (for
   connection pooling) and ordinary OTP supervisors.
   connection pooling) and ordinary OTP supervisors.
 * No records in the public API.
 * No records in the public API.
+* Query timeouts don't kill the connection (MySQL version ≥ 5.0.0).
 
 
 See also:
 See also:
 
 
@@ -60,6 +61,10 @@ case Result of
     {aborted, Reason} ->
     {aborted, Reason} ->
         io:format("Inserted 0 rows.~n")
         io:format("Inserted 0 rows.~n")
 end
 end
+
+%% Graceful timeout handling: SLEEP() returns 1 when interrupted
+{ok, [<<"SLEEP(5)">>], [[1]]} =
+    mysql:query(Pid, <<"SELECT SLEEP(5)">>, 1000),
 ```
 ```
 
 
 Tests
 Tests

+ 1 - 1
include/records.hrl

@@ -19,7 +19,7 @@
 %% --- Records ---
 %% --- Records ---
 
 
 %% Returned by parse_handshake/1.
 %% Returned by parse_handshake/1.
--record(handshake, {server_version :: binary(),
+-record(handshake, {server_version :: [integer()],
                     connection_id :: integer(),
                     connection_id :: integer(),
                     capabilities :: integer(),
                     capabilities :: integer(),
                     character_set :: integer(),
                     character_set :: integer(),

+ 162 - 74
src/mysql.erl

@@ -23,7 +23,7 @@
 %% gen_server is locally registered.
 %% gen_server is locally registered.
 -module(mysql).
 -module(mysql).
 
 
--export([start_link/1, query/2, query/3, execute/3,
+-export([start_link/1, query/2, query/3, query/4, execute/3, execute/4,
          prepare/2, prepare/3, unprepare/2,
          prepare/2, prepare/3, unprepare/2,
          warning_count/1, affected_rows/1, autocommit/1, insert_id/1,
          warning_count/1, affected_rows/1, autocommit/1, insert_id/1,
          in_transaction/1,
          in_transaction/1,
@@ -39,7 +39,8 @@
 -define(default_port, 3306).
 -define(default_port, 3306).
 -define(default_user, <<>>).
 -define(default_user, <<>>).
 -define(default_password, <<>>).
 -define(default_password, <<>>).
--define(default_timeout, infinity).
+-define(default_connect_timeout, 5000).
+-define(default_query_timeout, infinity).
 -define(default_query_cache_time, 60000). %% for query/3.
 -define(default_query_cache_time, 60000). %% for query/3.
 
 
 %% A connection is a ServerRef as in gen_server:call/2,3.
 %% A connection is a ServerRef as in gen_server:call/2,3.
@@ -74,6 +75,12 @@
 %%   <dt>`{database, Database}'</dt>
 %%   <dt>`{database, Database}'</dt>
 %%   <dd>The name of the database AKA schema to use. This can be changed later
 %%   <dd>The name of the database AKA schema to use. This can be changed later
 %%       using the query `USE <database>'.</dd>
 %%       using the query `USE <database>'.</dd>
+%%   <dt>`{connect_timeout, Timeout}'</dt>
+%%   <dd>The maximum time to spend for start_link/1.</dd>
+%%   <dt>`{query_timeout, Timeout}'</dt>
+%%   <dd>The default time to wait for a response when executing a query or a
+%%       prepared statement. This can be given per query using `query/3,4' and
+%%       `execute/4'. The default is `infinity'.</dd>
 %%   <dt>`{query_cache_time, Timeout}'</dt>
 %%   <dt>`{query_cache_time, Timeout}'</dt>
 %%   <dd>The minimum number of milliseconds to cache prepared statements used
 %%   <dd>The minimum number of milliseconds to cache prepared statements used
 %%       for parametrized queries with query/3.</dd>
 %%       for parametrized queries with query/3.</dd>
@@ -83,19 +90,23 @@
          Option :: {name, ServerName} | {host, iodata()} | {port, integer()} | 
          Option :: {name, ServerName} | {host, iodata()} | {port, integer()} | 
                    {user, iodata()} | {password, iodata()} |
                    {user, iodata()} | {password, iodata()} |
                    {database, iodata()} |
                    {database, iodata()} |
+                   {connect_timeout, timeout()} |
+                   {query_timeout, timeout()} |
                    {query_cache_time, non_neg_integer()},
                    {query_cache_time, non_neg_integer()},
          ServerName :: {local, Name :: atom()} |
          ServerName :: {local, Name :: atom()} |
                        {global, GlobalName :: term()} |
                        {global, GlobalName :: term()} |
                        {via, Module :: atom(), ViaName :: term()}.
                        {via, Module :: atom(), ViaName :: term()}.
 start_link(Options) ->
 start_link(Options) ->
+    GenSrvOpts = [{timeout, proplists:get_value(connect_timeout, Options,
+                                                ?default_connect_timeout)}],
     case proplists:get_value(name, Options) of
     case proplists:get_value(name, Options) of
         undefined ->
         undefined ->
-            gen_server:start_link(?MODULE, Options, []);
+            gen_server:start_link(?MODULE, Options, GenSrvOpts);
         ServerName ->
         ServerName ->
-            gen_server:start_link(ServerName, ?MODULE, Options, [])
+            gen_server:start_link(ServerName, ?MODULE, Options, GenSrvOpts)
     end.
     end.
 
 
-%% @doc Executes a query.
+%% @doc Executes a query with the query timeout as given to start_link/1.
 -spec query(Conn, Query) -> ok | {ok, ColumnNames, Rows} | {error, Reason}
 -spec query(Conn, Query) -> ok | {ok, ColumnNames, Rows} | {error, Reason}
     when Conn :: connection(),
     when Conn :: connection(),
          Query :: iodata(),
          Query :: iodata(),
@@ -103,26 +114,53 @@ start_link(Options) ->
          Rows :: [[term()]],
          Rows :: [[term()]],
          Reason :: server_reason().
          Reason :: server_reason().
 query(Conn, Query) ->
 query(Conn, Query) ->
-    gen_server:call(Conn, {query, Query}).
+    gen_server:call(Conn, {query, Query}, infinity).
 
 
-%% @doc Executes a parameterized query. A prepared statement is created,
-%% executed and then cached for a certain time. If the same query is executed
-%% again when it is already cached, it does not need to be prepared again.
+%% @doc Depending on the 3rd argument this function does different things.
+%%
+%% If the 3rd argument is a list, it executes a parameterized query. This is
+%% equivallent to query/4 with the query timeout as given to start_link/1.
+%%
+%% If the 3rd argument is a timeout, it executes a plain query with this
+%% timeout.
+%% @see query/2.
+%% @see query/4.
+-spec query(Conn, Query, Params | Timeout) -> ok | {ok, ColumnNames, Rows} |
+                                              {error, Reason}
+    when Conn :: connection(),
+         Query :: iodata(),
+         Timeout :: timeout(),
+         Params :: [term()],
+         ColumnNames :: [binary()],
+         Rows :: [[term()]],
+         Reason :: server_reason().
+query(Conn, Query, Params) when is_list(Params) ->
+    gen_server:call(Conn, {param_query, Query, Params}, infinity);
+query(Conn, Query, Timeout) when is_integer(Timeout); Timeout == infinity ->
+    gen_server:call(Conn, {query, Query, Timeout}, infinity).
+
+%% @doc Executes a parameterized query with a timeout.
+%%
+%% A prepared statement is created, executed and then cached for a certain
+%% time. If the same query is executed again when it is already cached, it does
+%% not need to be prepared again.
 %%
 %%
 %% The minimum time the prepared statement is cached can be specified using the
 %% The minimum time the prepared statement is cached can be specified using the
 %% option `{query_cache_time, Milliseconds}' to start_link/1.
 %% option `{query_cache_time, Milliseconds}' to start_link/1.
--spec query(Conn, Query, Params) -> ok | {ok, ColumnNames, Rows} |
-                                    {error, Reason}
+-spec query(Conn, Query, Params, Timeout) -> ok | {ok, ColumnNames, Rows} |
+                                             {error, Reason}
     when Conn :: connection(),
     when Conn :: connection(),
          Query :: iodata(),
          Query :: iodata(),
+         Timeout :: timeout(),
          Params :: [term()],
          Params :: [term()],
          ColumnNames :: [binary()],
          ColumnNames :: [binary()],
          Rows :: [[term()]],
          Rows :: [[term()]],
          Reason :: server_reason().
          Reason :: server_reason().
-query(Conn, Query, Params) when is_list(Params) ->
-    gen_server:call(Conn, {query, Query, Params}).
+query(Conn, Query, Params, Timeout) ->
+    gen_server:call(Conn, {param_query, Query, Params, Timeout}, infinity).
 
 
-%% @doc Executes a prepared statement.
+%% @doc Executes a prepared statement with the default query timeout as given
+%% to start_link/1.
 %% @see prepare/2
 %% @see prepare/2
 %% @see prepare/3
 %% @see prepare/3
 -spec execute(Conn, StatementRef, Params) ->
 -spec execute(Conn, StatementRef, Params) ->
@@ -134,7 +172,22 @@ query(Conn, Query, Params) when is_list(Params) ->
        Rows :: [[term()]],
        Rows :: [[term()]],
        Reason :: server_reason() | not_prepared.
        Reason :: server_reason() | not_prepared.
 execute(Conn, StatementRef, Params) ->
 execute(Conn, StatementRef, Params) ->
-    gen_server:call(Conn, {execute, StatementRef, Params}).
+    gen_server:call(Conn, {execute, StatementRef, Params}, infinity).
+
+%% @doc Executes a prepared statement.
+%% @see prepare/2
+%% @see prepare/3
+-spec execute(Conn, StatementRef, Params, Timeout) ->
+    ok | {ok, ColumnNames, Rows} | {error, Reason}
+  when Conn :: connection(),
+       StatementRef :: atom() | integer(),
+       Params :: [term()],
+       Timeout :: timeout(),
+       ColumnNames :: [binary()],
+       Rows :: [[term()]],
+       Reason :: server_reason() | not_prepared.
+execute(Conn, StatementRef, Params, Timeout) ->
+    gen_server:call(Conn, {execute, StatementRef, Params, Timeout}, infinity).
 
 
 %% @doc Creates a prepared statement from the passed query.
 %% @doc Creates a prepared statement from the passed query.
 %% @see prepare/3
 %% @see prepare/3
@@ -271,19 +324,21 @@ transaction(Conn, Fun, Args) when is_list(Args),
 -include("server_status.hrl").
 -include("server_status.hrl").
 
 
 %% Gen_server state
 %% Gen_server state
--record(state, {socket, timeout = infinity, affected_rows = 0, status = 0,
-                warning_count = 0, insert_id = 0, stmts = dict:new(),
-                query_cache_time, query_cache = empty}).
+-record(state, {server_version, connection_id, socket,
+                host, port, user, password,
+                query_timeout, query_cache_time,
+                affected_rows = 0, status = 0, warning_count = 0, insert_id = 0,
+                stmts = dict:new(), query_cache = empty}).
 
 
 %% @private
 %% @private
 init(Opts) ->
 init(Opts) ->
     %% Connect
     %% Connect
-    Host     = proplists:get_value(host,     Opts, ?default_host),
-    Port     = proplists:get_value(port,     Opts, ?default_port),
-    User     = proplists:get_value(user,     Opts, ?default_user),
-    Password = proplists:get_value(password, Opts, ?default_password),
-    Database = proplists:get_value(database, Opts, undefined),
-    Timeout  = proplists:get_value(timeout,  Opts, ?default_timeout),
+    Host     = proplists:get_value(host,          Opts, ?default_host),
+    Port     = proplists:get_value(port,          Opts, ?default_port),
+    User     = proplists:get_value(user,          Opts, ?default_user),
+    Password = proplists:get_value(password,      Opts, ?default_password),
+    Database = proplists:get_value(database,      Opts, undefined),
+    Timeout  = proplists:get_value(query_timeout, Opts, ?default_query_timeout),
     QueryCacheTime = proplists:get_value(query_cache_time, Opts,
     QueryCacheTime = proplists:get_value(query_cache_time, Opts,
                                          ?default_query_cache_time),
                                          ?default_query_cache_time),
 
 
@@ -292,29 +347,40 @@ init(Opts) ->
     {ok, Socket} = gen_tcp:connect(Host, Port, SockOpts),
     {ok, Socket} = gen_tcp:connect(Host, Port, SockOpts),
 
 
     %% Exchange handshake communication.
     %% Exchange handshake communication.
-    SendFun = fun (Data) -> gen_tcp:send(Socket, Data) end,
-    RecvFun = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
-    Result = mysql_protocol:handshake(User, Password, Database, SendFun,
-                                      RecvFun),
+    Result = mysql_protocol:handshake(User, Password, Database, gen_tcp,
+                                      Socket),
     case Result of
     case Result of
-        #ok{} = OK ->
-            State = #state{socket = Socket, timeout = Timeout,
+        #handshake{server_version = Version, connection_id = ConnId,
+                   status = Status} ->
+            State = #state{server_version = Version, connection_id = ConnId,
+                           socket = Socket,
+                           host = Host, port = Port, user = User,
+                           password = Password, status = Status,
+                           query_timeout = Timeout,
                            query_cache_time = QueryCacheTime},
                            query_cache_time = QueryCacheTime},
-            State1 = update_state(State, OK),
             %% Trap exit so that we can properly disconnect when we die.
             %% Trap exit so that we can properly disconnect when we die.
             process_flag(trap_exit, true),
             process_flag(trap_exit, true),
-            {ok, State1};
+            {ok, State};
         #error{} = E ->
         #error{} = E ->
             {stop, error_to_reason(E)}
             {stop, error_to_reason(E)}
     end.
     end.
 
 
 %% @private
 %% @private
-handle_call({query, Query}, _From, State) when is_binary(Query);
-                                               is_list(Query) ->
-    #state{socket = Socket, timeout = Timeout} = State,
-    SendFun = fun (Data) -> gen_tcp:send(Socket, Data) end,
-    RecvFun = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
-    Rec = mysql_protocol:query(Query, SendFun, RecvFun),
+handle_call({query, Query}, From, State) ->
+    handle_call({query, Query, State#state.query_timeout}, From, State);
+handle_call({query, Query, Timeout}, _From, State) ->
+    Socket = State#state.socket,
+    Rec = case mysql_protocol:query(Query, gen_tcp, Socket, Timeout) of
+        {error, timeout} when State#state.server_version >= [5, 0, 0] ->
+            kill_query(State),
+            mysql_protocol:fetch_query_response(gen_tcp, Socket, infinity);
+        {error, timeout} ->
+            %% For MySQL 4.x.x there is no way to recover from timeout except
+            %% killing the connection itself.
+            exit(timeout);
+        QueryResult ->
+            QueryResult
+    end,
     State1 = update_state(State, Rec),
     State1 = update_state(State, Rec),
     case Rec of
     case Rec of
         #ok{} ->
         #ok{} ->
@@ -325,12 +391,13 @@ handle_call({query, Query}, _From, State) when is_binary(Query);
             Names = [Def#col.name || Def <- ColDefs],
             Names = [Def#col.name || Def <- ColDefs],
             {reply, {ok, Names, Rows}, State1}
             {reply, {ok, Names, Rows}, State1}
     end;
     end;
-handle_call({query, Query, Params}, _From, State) when is_list(Params) ->
-    %% Parametrized query = anonymous prepared statement
+handle_call({param_query, Query, Params}, From, State) ->
+    handle_call({param_query, Query, Params, State#state.query_timeout}, From,
+                State);
+handle_call({param_query, Query, Params, Timeout}, _From, State) ->
+    %% Parametrized query: Prepared statement cached with the query as the key
     QueryBin = iolist_to_binary(Query),
     QueryBin = iolist_to_binary(Query),
-    #state{socket = Socket, timeout = Timeout} = State,
-    SendFun = fun (Data) -> gen_tcp:send(Socket, Data) end,
-    RecvFun = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
+    #state{socket = Socket} = State,
     Cache = State#state.query_cache,
     Cache = State#state.query_cache,
     {StmtResult, Cache1} = case mysql_cache:lookup(QueryBin, Cache) of
     {StmtResult, Cache1} = case mysql_cache:lookup(QueryBin, Cache) of
         {found, FoundStmt, NewCache} ->
         {found, FoundStmt, NewCache} ->
@@ -338,7 +405,7 @@ handle_call({query, Query, Params}, _From, State) when is_list(Params) ->
             {{ok, FoundStmt}, NewCache};
             {{ok, FoundStmt}, NewCache};
         not_found ->
         not_found ->
             %% Prepare
             %% Prepare
-            Rec = mysql_protocol:prepare(Query, SendFun, RecvFun),
+            Rec = mysql_protocol:prepare(Query, gen_tcp, Socket),
             %State1 = update_state(State, Rec),
             %State1 = update_state(State, Rec),
             case Rec of
             case Rec of
                 #error{} = E ->
                 #error{} = E ->
@@ -355,23 +422,22 @@ handle_call({query, Query, Params}, _From, State) when is_list(Params) ->
     case StmtResult of
     case StmtResult of
         {ok, StmtRec} ->
         {ok, StmtRec} ->
             State1 = State#state{query_cache = Cache1},
             State1 = State#state{query_cache = Cache1},
-            execute_stmt(StmtRec, Params, State1);
+            execute_stmt(StmtRec, Params, Timeout, State1);
         PrepareError ->
         PrepareError ->
             {reply, PrepareError, State}
             {reply, PrepareError, State}
     end;
     end;
-handle_call({execute, Stmt, Args}, _From, State) when is_atom(Stmt);
-                                                      is_integer(Stmt) ->
+handle_call({execute, Stmt, Args}, From, State) ->
+    handle_call({execute, Stmt, Args, State#state.query_timeout}, From, State);
+handle_call({execute, Stmt, Args, Timeout}, _From, State) ->
     case dict:find(Stmt, State#state.stmts) of
     case dict:find(Stmt, State#state.stmts) of
         {ok, StmtRec} ->
         {ok, StmtRec} ->
-            execute_stmt(StmtRec, Args, State);
+            execute_stmt(StmtRec, Args, Timeout, State);
         error ->
         error ->
             {reply, {error, not_prepared}, State}
             {reply, {error, not_prepared}, State}
     end;
     end;
 handle_call({prepare, Query}, _From, State) ->
 handle_call({prepare, Query}, _From, State) ->
-    #state{socket = Socket, timeout = Timeout} = State,
-    SendFun = fun (Data) -> gen_tcp:send(Socket, Data) end,
-    RecvFun = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
-    Rec = mysql_protocol:prepare(Query, SendFun, RecvFun),
+    #state{socket = Socket} = State,
+    Rec = mysql_protocol:prepare(Query, gen_tcp, Socket),
     State1 = update_state(State, Rec),
     State1 = update_state(State, Rec),
     case Rec of
     case Rec of
         #error{} = E ->
         #error{} = E ->
@@ -382,18 +448,16 @@ handle_call({prepare, Query}, _From, State) ->
             {reply, {ok, Id}, State2}
             {reply, {ok, Id}, State2}
     end;
     end;
 handle_call({prepare, Name, Query}, _From, State) when is_atom(Name) ->
 handle_call({prepare, Name, Query}, _From, State) when is_atom(Name) ->
-    #state{socket = Socket, timeout = Timeout} = State,
-    SendFun = fun (Data) -> gen_tcp:send(Socket, Data) end,
-    RecvFun = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
+    #state{socket = Socket} = State,
     %% First unprepare if there is an old statement with this name.
     %% First unprepare if there is an old statement with this name.
     State1 = case dict:find(Name, State#state.stmts) of
     State1 = case dict:find(Name, State#state.stmts) of
         {ok, OldStmt} ->
         {ok, OldStmt} ->
-            mysql_protocol:unprepare(OldStmt, SendFun, RecvFun),
+            mysql_protocol:unprepare(OldStmt, gen_tcp, Socket),
             State#state{stmts = dict:erase(Name, State#state.stmts)};
             State#state{stmts = dict:erase(Name, State#state.stmts)};
         error ->
         error ->
             State
             State
     end,
     end,
-    Rec = mysql_protocol:prepare(Query, SendFun, RecvFun),
+    Rec = mysql_protocol:prepare(Query, gen_tcp, Socket),
     State2 = update_state(State1, Rec),
     State2 = update_state(State1, Rec),
     case Rec of
     case Rec of
         #error{} = E ->
         #error{} = E ->
@@ -407,10 +471,8 @@ handle_call({unprepare, Stmt}, _From, State) when is_atom(Stmt);
                                                   is_integer(Stmt) ->
                                                   is_integer(Stmt) ->
     case dict:find(Stmt, State#state.stmts) of
     case dict:find(Stmt, State#state.stmts) of
         {ok, StmtRec} ->
         {ok, StmtRec} ->
-            #state{socket = Socket, timeout = Timeout} = State,
-            SendFun = fun (Data) -> gen_tcp:send(Socket, Data) end,
-            RecvFun = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
-            mysql_protocol:unprepare(StmtRec, SendFun, RecvFun),
+            #state{socket = Socket} = State,
+            mysql_protocol:unprepare(StmtRec, gen_tcp, Socket),
             Stmts1 = dict:erase(Stmt, State#state.stmts),
             Stmts1 = dict:erase(Stmt, State#state.stmts),
             {reply, ok, State#state{stmts = Stmts1}};
             {reply, ok, State#state{stmts = Stmts1}};
         error ->
         error ->
@@ -437,11 +499,9 @@ handle_info(query_cache, State = #state{query_cache = Cache,
     %% Evict expired queries/statements in the cache used by query/3.
     %% Evict expired queries/statements in the cache used by query/3.
     {Evicted, Cache1} = mysql_cache:evict_older_than(Cache, CacheTime),
     {Evicted, Cache1} = mysql_cache:evict_older_than(Cache, CacheTime),
     %% Unprepare the evicted statements
     %% Unprepare the evicted statements
-    #state{socket = Socket, timeout = Timeout} = State,
-    SendFun = fun (Data) -> gen_tcp:send(Socket, Data) end,
-    RecvFun = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
+    #state{socket = Socket} = State,
     lists:foreach(fun ({_Query, Stmt}) ->
     lists:foreach(fun ({_Query, Stmt}) ->
-                      mysql_protocol:unprepare(Stmt, SendFun, RecvFun)
+                      mysql_protocol:unprepare(Stmt, gen_tcp, Socket)
                   end,
                   end,
                   Evicted),
                   Evicted),
     %% If nonempty, schedule eviction again.
     %% If nonempty, schedule eviction again.
@@ -454,10 +514,8 @@ handle_info(_Info, State) ->
 %% @private
 %% @private
 terminate(Reason, State) when Reason == normal; Reason == shutdown ->
 terminate(Reason, State) when Reason == normal; Reason == shutdown ->
     %% Send the goodbye message for politeness.
     %% Send the goodbye message for politeness.
-    #state{socket = Socket, timeout = Timeout} = State,
-    SendFun = fun (Data) -> gen_tcp:send(Socket, Data) end,
-    RecvFun = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
-    mysql_protocol:quit(SendFun, RecvFun);
+    #state{socket = Socket} = State,
+    mysql_protocol:quit(gen_tcp, Socket);
 terminate(_Reason, _State) ->
 terminate(_Reason, _State) ->
     ok.
     ok.
 
 
@@ -470,11 +528,18 @@ code_change(_OldVsn, _State, _Extra) ->
 %% --- Helpers ---
 %% --- Helpers ---
 
 
 %% @doc Returns a tuple on the the same form as handle_call/3.
 %% @doc Returns a tuple on the the same form as handle_call/3.
-execute_stmt(StmtRec, Args, State) ->
-    #state{socket = Socket, timeout = Timeout} = State,
-    SendFun = fun (Data) -> gen_tcp:send(Socket, Data) end,
-    RecvFun = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
-    Rec = mysql_protocol:execute(StmtRec, Args, SendFun, RecvFun),
+execute_stmt(Stmt, Args, Timeout, State = #state{socket = Socket}) ->
+    Rec = case mysql_protocol:execute(Stmt, Args, gen_tcp, Socket, Timeout) of
+        {error, timeout} when State#state.server_version >= [5, 0, 0] ->
+            kill_query(State),
+            mysql_protocol:fetch_execute_response(gen_tcp, Socket, infinity);
+        {error, timeout} ->
+            %% For MySQL 4.x.x there is no way to recover from timeout except
+            %% killing the connection itself.
+            exit(timeout);
+        QueryResult ->
+            QueryResult
+    end,
     State1 = update_state(State, Rec),
     State1 = update_state(State, Rec),
     case Rec of
     case Rec of
         #ok{} ->
         #ok{} ->
@@ -505,3 +570,26 @@ update_state(State, _Other) ->
     %% This includes errors, resultsets, etc.
     %% This includes errors, resultsets, etc.
     %% Reset warnings, etc. (Note: We don't reset status and insert_id.)
     %% Reset warnings, etc. (Note: We don't reset status and insert_id.)
     State#state{warning_count = 0, affected_rows = 0}.
     State#state{warning_count = 0, affected_rows = 0}.
+
+%% @doc Makes a separate connection and execute KILL QUERY. We do this to get
+%% our main connection back to normal. KILL QUERY appeared in MySQL 5.0.0.
+kill_query(#state{connection_id = ConnId, host = Host, port = Port,
+                  user = User, password = Password}) ->
+    %% Connect socket
+    SockOpts = [{active, false}, binary, {packet, raw}],
+    {ok, Socket} = gen_tcp:connect(Host, Port, SockOpts),
+
+    %% Exchange handshake communication.
+    Result = mysql_protocol:handshake(User, Password, undefined, gen_tcp,
+                                      Socket),
+    case Result of
+        #handshake{} ->
+            %% Kill and disconnect
+            IdBin = integer_to_binary(ConnId),
+            #ok{} = mysql_protocol:query(<<"KILL QUERY ", IdBin/binary>>,
+                                         gen_tcp, Socket, 3000),
+            mysql_protocol:quit(gen_tcp, Socket);
+        #error{} = E ->
+            error_logger:error_msg("Failed to connect to kill query: ~p",
+                                   [error_to_reason(E)])
+    end.

+ 131 - 103
src/mysql_protocol.erl

@@ -27,13 +27,8 @@
 -module(mysql_protocol).
 -module(mysql_protocol).
 
 
 -export([handshake/5, quit/2,
 -export([handshake/5, quit/2,
-         query/3,
-         prepare/3, unprepare/3, execute/4]).
-
--export_type([sendfun/0, recvfun/0]).
-
--type sendfun() :: fun((binary()) -> ok).
--type recvfun() :: fun((integer()) -> {ok, binary()}).
+         query/4, fetch_query_response/3,
+         prepare/3, unprepare/3, execute/5, fetch_execute_response/3]).
 
 
 %% How much data do we want to send at most?
 %% How much data do we want to send at most?
 -define(MAX_BYTES_PER_PACKET, 50000000).
 -define(MAX_BYTES_PER_PACKET, 50000000).
@@ -49,41 +44,51 @@
 %% @doc Performs a handshake using the supplied functions for communication.
 %% @doc Performs a handshake using the supplied functions for communication.
 %% Returns an ok or an error record. Raises errors when various unimplemented
 %% Returns an ok or an error record. Raises errors when various unimplemented
 %% features are requested.
 %% features are requested.
--spec handshake(iodata(), iodata(), iodata() | undefined, sendfun(),
-                recvfun()) -> #ok{} | #error{}.
-handshake(Username, Password, Database, SendFun, RecvFun) ->
+-spec handshake(iodata(), iodata(), iodata() | undefined, atom(), term()) ->
+    #handshake{} | #error{}.
+handshake(Username, Password, Database, TcpModule, Socket) ->
     SeqNum0 = 0,
     SeqNum0 = 0,
-    {ok, HandshakePacket, SeqNum1} = recv_packet(RecvFun, SeqNum0),
+    {ok, HandshakePacket, SeqNum1} = recv_packet(TcpModule, Socket, SeqNum0),
     Handshake = parse_handshake(HandshakePacket),
     Handshake = parse_handshake(HandshakePacket),
     Response = build_handshake_response(Handshake, Username, Password,
     Response = build_handshake_response(Handshake, Username, Password,
                                         Database),
                                         Database),
-    {ok, SeqNum2} = send_packet(SendFun, Response, SeqNum1),
-    {ok, ConfirmPacket, _SeqNum3} = recv_packet(RecvFun, SeqNum2),
-    parse_handshake_confirm(ConfirmPacket).
+    {ok, SeqNum2} = send_packet(TcpModule, Socket, Response, SeqNum1),
+    {ok, ConfirmPacket, _SeqNum3} = recv_packet(TcpModule, Socket, SeqNum2),
+    case parse_handshake_confirm(ConfirmPacket) of
+        #ok{status = OkStatus} ->
+            OkStatus = Handshake#handshake.status,
+            Handshake;
+        Error ->
+            Error
+    end.
 
 
-quit(SendFun, RecvFun) ->
-    {ok, SeqNum1} = send_packet(SendFun, <<?COM_QUIT>>, 0),
-    case recv_packet(RecvFun, SeqNum1) of
+quit(TcpModule, Socket) ->
+    {ok, SeqNum1} = send_packet(TcpModule, Socket, <<?COM_QUIT>>, 0),
+    case recv_packet(TcpModule, Socket, SeqNum1) of
         {error, closed} -> ok;
         {error, closed} -> ok;
         {ok, ?ok_pattern, _SeqNum2} -> ok
         {ok, ?ok_pattern, _SeqNum2} -> ok
     end.
     end.
 
 
--spec query(Query :: iodata(), sendfun(), recvfun()) ->
-    #ok{} | #error{} | #resultset{}.
-query(Query, SendFun, RecvFun) ->
+-spec query(Query :: iodata(), atom(), term(), timeout()) ->
+    #ok{} | #resultset{} | #error{} | {error, timeout}.
+query(Query, TcpModule, Socket, Timeout) ->
     Req = <<?COM_QUERY, (iolist_to_binary(Query))/binary>>,
     Req = <<?COM_QUERY, (iolist_to_binary(Query))/binary>>,
     SeqNum0 = 0,
     SeqNum0 = 0,
-    {ok, SeqNum1} = send_packet(SendFun, Req, SeqNum0),
-    {ok, Resp, SeqNum2} = recv_packet(RecvFun, SeqNum1),
-    case Resp of
-        ?ok_pattern ->
-            parse_ok_packet(Resp);
-        ?error_pattern ->
-            parse_error_packet(Resp);
-        _ResultSet ->
+    {ok, _SeqNum1} = send_packet(TcpModule, Socket, Req, SeqNum0),
+    fetch_query_response(TcpModule, Socket, Timeout).
+
+%% @doc This is used by query/4. If query/4 returns {error, timeout}, this
+%% function can be called to retry to fetch the results of the query. 
+fetch_query_response(TcpModule, Socket, Timeout) ->
+    case recv_packet(TcpModule, Socket, Timeout, any) of
+        {ok, ?ok_pattern = Ok, _} ->
+            parse_ok_packet(Ok);
+        {ok, ?error_pattern = Error, _} ->
+            parse_error_packet(Error);
+        {ok, ResultPacket, SeqNum2} ->
             %% The first packet in a resultset is only the column count.
             %% The first packet in a resultset is only the column count.
-            {ColumnCount, <<>>} = lenenc_int(Resp),
-            case fetch_resultset(RecvFun, ColumnCount, SeqNum2) of
+            {ColumnCount, <<>>} = lenenc_int(ResultPacket),
+            case fetch_resultset(TcpModule, Socket, ColumnCount, SeqNum2) of
                 #error{} = E ->
                 #error{} = E ->
                     E;
                     E;
                 #resultset{cols = ColDefs, rows = Rows} = R ->
                 #resultset{cols = ColDefs, rows = Rows} = R ->
@@ -92,15 +97,17 @@ query(Query, SendFun, RecvFun) ->
                     Rows1 = [decode_text_row(ColumnCount, ColDefs, Row)
                     Rows1 = [decode_text_row(ColumnCount, ColDefs, Row)
                              || Row <- Rows],
                              || Row <- Rows],
                     R#resultset{rows = Rows1}
                     R#resultset{rows = Rows1}
-            end
+            end;
+        {error, timeout} ->
+            {error, timeout}
     end.
     end.
 
 
 %% @doc Prepares a statement.
 %% @doc Prepares a statement.
--spec prepare(iodata(), sendfun(), recvfun()) -> #error{} | #prepared{}.
-prepare(Query, SendFun, RecvFun) ->
+-spec prepare(iodata(), atom(), term()) -> #error{} | #prepared{}.
+prepare(Query, TcpModule, Socket) ->
     Req = <<?COM_STMT_PREPARE, (iolist_to_binary(Query))/binary>>,
     Req = <<?COM_STMT_PREPARE, (iolist_to_binary(Query))/binary>>,
-    {ok, SeqNum1} = send_packet(SendFun, Req, 0),
-    {ok, Resp, SeqNum2} = recv_packet(RecvFun, SeqNum1),
+    {ok, SeqNum1} = send_packet(TcpModule, Socket, Req, 0),
+    {ok, Resp, SeqNum2} = recv_packet(TcpModule, Socket, SeqNum1),
     case Resp of
     case Resp of
         ?error_pattern ->
         ?error_pattern ->
             parse_error_packet(Resp);
             parse_error_packet(Resp);
@@ -116,27 +123,30 @@ prepare(Query, SendFun, RecvFun) ->
             %% with charset 'binary' so we have to select a type ourselves for
             %% with charset 'binary' so we have to select a type ourselves for
             %% the parameters we have in execute/4.
             %% the parameters we have in execute/4.
             {_ParamDefs, SeqNum3} =
             {_ParamDefs, SeqNum3} =
-                fetch_column_definitions_if_any(NumParams, RecvFun, SeqNum2),
+                fetch_column_definitions_if_any(NumParams, TcpModule, Socket,
+                                                SeqNum2),
             %% Column Definition Block. We get column definitions in execute
             %% Column Definition Block. We get column definitions in execute
             %% too, so we don't need them here. We *could* store them to be able
             %% too, so we don't need them here. We *could* store them to be able
             %% to provide the user with some info about a prepared statement.
             %% to provide the user with some info about a prepared statement.
             {_ColDefs, _SeqNum4} =
             {_ColDefs, _SeqNum4} =
-                fetch_column_definitions_if_any(NumColumns, RecvFun, SeqNum3),
+                fetch_column_definitions_if_any(NumColumns, TcpModule, Socket, SeqNum3),
             #prepared{statement_id = StmtId,
             #prepared{statement_id = StmtId,
                       param_count = NumParams,
                       param_count = NumParams,
                       warning_count = WarningCount}
                       warning_count = WarningCount}
     end.
     end.
 
 
 %% @doc Deallocates a prepared statement.
 %% @doc Deallocates a prepared statement.
--spec unprepare(#prepared{}, sendfun(), recvfun()) -> ok.
-unprepare(#prepared{statement_id = Id}, SendFun, _RecvFun) ->
-    {ok, _SeqNum} = send_packet(SendFun, <<?COM_STMT_CLOSE, Id:32/little>>, 0),
+-spec unprepare(#prepared{}, atom(), term()) -> ok.
+unprepare(#prepared{statement_id = Id}, TcpModule, Socket) ->
+    {ok, _SeqNum} = send_packet(TcpModule, Socket,
+                                <<?COM_STMT_CLOSE, Id:32/little>>, 0),
     ok.
     ok.
 
 
 %% @doc Executes a prepared statement.
 %% @doc Executes a prepared statement.
--spec execute(#prepared{}, [term()], sendfun(), recvfun()) -> #resultset{}.
+-spec execute(#prepared{}, [term()], atom(), term(), timeout()) ->
+    #ok{} | #resultset{} | #error{} | {error, timeout}.
 execute(#prepared{statement_id = Id, param_count = ParamCount}, ParamValues,
 execute(#prepared{statement_id = Id, param_count = ParamCount}, ParamValues,
-        SendFun, RecvFun) when ParamCount == length(ParamValues) ->
+        TcpModule, Socket, Timeout) when ParamCount == length(ParamValues) ->
     %% Flags Constant Name
     %% Flags Constant Name
     %% 0x00 CURSOR_TYPE_NO_CURSOR
     %% 0x00 CURSOR_TYPE_NO_CURSOR
     %% 0x01 CURSOR_TYPE_READ_ONLY
     %% 0x01 CURSOR_TYPE_READ_ONLY
@@ -162,21 +172,23 @@ execute(#prepared{statement_id = Id, param_count = ParamCount}, ParamValues,
             {TypesAndSigns, EncValues} = lists:unzip(EncodedParams),
             {TypesAndSigns, EncValues} = lists:unzip(EncodedParams),
             iolist_to_binary([Req1, TypesAndSigns, EncValues])
             iolist_to_binary([Req1, TypesAndSigns, EncValues])
     end,
     end,
-    {ok, SeqNum1} = send_packet(SendFun, Req, 0),
-    {ok, Resp, SeqNum2} = recv_packet(RecvFun, SeqNum1),
-    case Resp of
-        ?ok_pattern ->
-            parse_ok_packet(Resp);
-        ?error_pattern ->
-            parse_error_packet(Resp);
-        _ResultPacket ->
+    {ok, _SeqNum1} = send_packet(TcpModule, Socket, Req, 0),
+    fetch_execute_response(TcpModule, Socket, Timeout).
+
+%% @doc This is used by execute/5. If execute/5 returns {error, timeout}, this
+%% function can be called to retry to fetch the results of the query.
+fetch_execute_response(TcpModule, Socket, Timeout) ->
+    case recv_packet(TcpModule, Socket, Timeout, any) of
+        {ok, ?ok_pattern = Ok, _} ->
+            parse_ok_packet(Ok);
+        {ok, ?error_pattern = Error, _} ->
+            parse_error_packet(Error);
+        {ok, ResultPacket, SeqNum2} ->
             %% The first packet in a resultset is only the column count.
             %% The first packet in a resultset is only the column count.
-            {ColumnCount, <<>>} = lenenc_int(Resp),
-            case fetch_resultset(RecvFun, ColumnCount, SeqNum2) of
+            {ColumnCount, <<>>} = lenenc_int(ResultPacket),
+            case fetch_resultset(TcpModule, Socket, ColumnCount, SeqNum2) of
                 #error{} = E ->
                 #error{} = E ->
                     %% TODO: Find a way to get here and write a testcase.
                     %% TODO: Find a way to get here and write a testcase.
-                    %% This can happen for the text protocol but maybe not for
-                    %% the binary protocol.
                     E;
                     E;
                 #resultset{cols = ColDefs, rows = Rows} = R ->
                 #resultset{cols = ColDefs, rows = Rows} = R ->
                     %% Parse the rows according to the 'binary protocol'
                     %% Parse the rows according to the 'binary protocol'
@@ -184,7 +196,9 @@ execute(#prepared{statement_id = Id, param_count = ParamCount}, ParamValues,
                     Rows1 = [decode_binary_row(ColumnCount, ColDefs, Row)
                     Rows1 = [decode_binary_row(ColumnCount, ColDefs, Row)
                              || Row <- Rows],
                              || Row <- Rows],
                     R#resultset{rows = Rows1}
                     R#resultset{rows = Rows1}
-            end
+            end;
+        {error, timeout} ->
+            {error, timeout}
     end.
     end.
 
 
 %% --- internal ---
 %% --- internal ---
@@ -221,16 +235,24 @@ parse_handshake(<<10, Rest/binary>>) ->
         <<NameNoNul:NameLen/binary-unit:8, 0>> -> NameNoNul;
         <<NameNoNul:NameLen/binary-unit:8, 0>> -> NameNoNul;
         _ -> AuthPluginName
         _ -> AuthPluginName
     end,
     end,
-    #handshake{server_version = ServerVersion,
-              connection_id = ConnectionId,
-              capabilities = Capabilities,
-              character_set = CharacterSet,
-              status = StatusFlags,
-              auth_plugin_data = AuthPluginData,
-              auth_plugin_name = AuthPluginName1};
+    #handshake{server_version = server_version_to_list(ServerVersion),
+               connection_id = ConnectionId,
+               capabilities = Capabilities,
+               character_set = CharacterSet,
+               status = StatusFlags,
+               auth_plugin_data = AuthPluginData,
+               auth_plugin_name = AuthPluginName1};
 parse_handshake(<<Protocol:8, _/binary>>) when Protocol /= 10 ->
 parse_handshake(<<Protocol:8, _/binary>>) when Protocol /= 10 ->
     error(unknown_protocol).
     error(unknown_protocol).
 
 
+%% @doc Converts a version on the form `<<"5.6.21">' to a list `[5, 6, 21]'.
+-spec server_version_to_list(binary()) -> [integer()].
+server_version_to_list(ServerVersion) ->
+    %% Remove stuff after dash for e.g. "5.5.40-0ubuntu0.12.04.1-log"
+    [ServerVersion1 | _] = binary:split(ServerVersion, <<"-">>),
+    lists:map(fun binary_to_integer/1,
+              binary:split(ServerVersion1, <<".">>, [global])).
+
 %% @doc The response sent by the client to the server after receiving the
 %% @doc The response sent by the client to the server after receiving the
 %% initial handshake from the server
 %% initial handshake from the server
 -spec build_handshake_response(#handshake{}, iodata(), iodata(),
 -spec build_handshake_response(#handshake{}, iodata(), iodata(),
@@ -303,14 +325,14 @@ parse_handshake_confirm(Packet) ->
 %% the rows are unparsed binary packages. This function is used for both the
 %% the rows are unparsed binary packages. This function is used for both the
 %% text protocol and the binary protocol. This affects the way the rows need to
 %% text protocol and the binary protocol. This affects the way the rows need to
 %% be parsed.
 %% be parsed.
--spec fetch_resultset(recvfun(), integer(), integer()) ->
+-spec fetch_resultset(atom(), term(), integer(), integer()) ->
     #resultset{} | #error{}.
     #resultset{} | #error{}.
-fetch_resultset(RecvFun, FieldCount, SeqNum) ->
-    {ok, ColDefs, SeqNum1} = fetch_column_definitions(RecvFun, SeqNum,
+fetch_resultset(TcpModule, Socket, FieldCount, SeqNum) ->
+    {ok, ColDefs, SeqNum1} = fetch_column_definitions(TcpModule, Socket, SeqNum,
                                                       FieldCount, []),
                                                       FieldCount, []),
-    {ok, DelimiterPacket, SeqNum2} = recv_packet(RecvFun, SeqNum1),
+    {ok, DelimiterPacket, SeqNum2} = recv_packet(TcpModule, Socket, SeqNum1),
     #eof{} = parse_eof_packet(DelimiterPacket),
     #eof{} = parse_eof_packet(DelimiterPacket),
-    case fetch_resultset_rows(RecvFun, SeqNum2, []) of
+    case fetch_resultset_rows(TcpModule, Socket, SeqNum2, []) of
         {ok, Rows, _SeqNum3} ->
         {ok, Rows, _SeqNum3} ->
             ColDefs1 = lists:map(fun parse_column_definition/1, ColDefs),
             ColDefs1 = lists:map(fun parse_column_definition/1, ColDefs),
             #resultset{cols = ColDefs1, rows = Rows};
             #resultset{cols = ColDefs1, rows = Rows};
@@ -320,31 +342,33 @@ fetch_resultset(RecvFun, FieldCount, SeqNum) ->
 
 
 %% @doc Receives NumLeft column definition packets. They are not parsed.
 %% @doc Receives NumLeft column definition packets. They are not parsed.
 %% @see parse_column_definition/1
 %% @see parse_column_definition/1
--spec fetch_column_definitions(recvfun(), SeqNum :: integer(),
+-spec fetch_column_definitions(atom(), term(), SeqNum :: integer(),
                                NumLeft :: integer(), Acc :: [binary()]) ->
                                NumLeft :: integer(), Acc :: [binary()]) ->
     {ok, ColDefPackets :: [binary()], NextSeqNum :: integer()}.
     {ok, ColDefPackets :: [binary()], NextSeqNum :: integer()}.
-fetch_column_definitions(RecvFun, SeqNum, NumLeft, Acc) when NumLeft > 0 ->
-    {ok, Packet, SeqNum1} = recv_packet(RecvFun, SeqNum),
-    fetch_column_definitions(RecvFun, SeqNum1, NumLeft - 1, [Packet | Acc]);
-fetch_column_definitions(_RecvFun, SeqNum, 0, Acc) ->
+fetch_column_definitions(TcpModule, Socket, SeqNum, NumLeft, Acc)
+  when NumLeft > 0 ->
+    {ok, Packet, SeqNum1} = recv_packet(TcpModule, Socket, SeqNum),
+    fetch_column_definitions(TcpModule, Socket, SeqNum1, NumLeft - 1,
+                             [Packet | Acc]);
+fetch_column_definitions(_TcpModule, _Socket, SeqNum, 0, Acc) ->
     {ok, lists:reverse(Acc), SeqNum}.
     {ok, lists:reverse(Acc), SeqNum}.
 
 
 %% @doc Fetches rows in a result set. There is a packet per row. The row packets
 %% @doc Fetches rows in a result set. There is a packet per row. The row packets
 %% are not decoded. This function can be used for both the binary and the text
 %% are not decoded. This function can be used for both the binary and the text
 %% protocol result sets.
 %% protocol result sets.
--spec fetch_resultset_rows(recvfun(), SeqNum :: integer(), Acc) ->
+-spec fetch_resultset_rows(atom(), term(), SeqNum :: integer(), Acc) ->
     {ok, Rows, integer()} | #error{}
     {ok, Rows, integer()} | #error{}
     when Acc :: [binary()],
     when Acc :: [binary()],
          Rows :: [binary()].
          Rows :: [binary()].
-fetch_resultset_rows(RecvFun, SeqNum, Acc) ->
-    {ok, Packet, SeqNum1} = recv_packet(RecvFun, SeqNum),
+fetch_resultset_rows(TcpModule, Socket, SeqNum, Acc) ->
+    {ok, Packet, SeqNum1} = recv_packet(TcpModule, Socket, SeqNum),
     case Packet of
     case Packet of
         ?error_pattern ->
         ?error_pattern ->
             parse_error_packet(Packet);
             parse_error_packet(Packet);
         ?eof_pattern ->
         ?eof_pattern ->
             {ok, lists:reverse(Acc), SeqNum1};
             {ok, lists:reverse(Acc), SeqNum1};
         Row ->
         Row ->
-            fetch_resultset_rows(RecvFun, SeqNum1, [Row | Acc])
+            fetch_resultset_rows(TcpModule, Socket, SeqNum1, [Row | Acc])
     end.
     end.
 
 
 %% Parses a packet containing a column definition (part of a result set)
 %% Parses a packet containing a column definition (part of a result set)
@@ -474,11 +498,12 @@ decode_text(#col{type = T}, Text) when T == ?TYPE_FLOAT;
 
 
 %% @doc If NumColumns is non-zero, fetches this number of column definitions
 %% @doc If NumColumns is non-zero, fetches this number of column definitions
 %% and an EOF packet. Used by prepare/3.
 %% and an EOF packet. Used by prepare/3.
-fetch_column_definitions_if_any(0, _RecvFun, SeqNum) ->
+fetch_column_definitions_if_any(0, _TcpModule, _Socket, SeqNum) ->
     {[], SeqNum};
     {[], SeqNum};
-fetch_column_definitions_if_any(N, RecvFun, SeqNum) ->
-    {ok, Defs, SeqNum1} = fetch_column_definitions(RecvFun, SeqNum, N, []),
-    {ok, ?eof_pattern, SeqNum2} = recv_packet(RecvFun, SeqNum1),
+fetch_column_definitions_if_any(N, TcpModule, Socket, SeqNum) ->
+    {ok, Defs, SeqNum1} = fetch_column_definitions(TcpModule, Socket, SeqNum,
+                                                   N, []),
+    {ok, ?eof_pattern, SeqNum2} = recv_packet(TcpModule, Socket, SeqNum1),
     {Defs, SeqNum2}.
     {Defs, SeqNum2}.
 
 
 %% @doc Decodes a packet representing a row in a binary result set.
 %% @doc Decodes a packet representing a row in a binary result set.
@@ -493,7 +518,7 @@ decode_binary_row(NumColumns, ColumnDefs, <<0, Data/binary>>) ->
     decode_binary_row_acc(ColumnDefs, NullBitMap, Rest, []).
     decode_binary_row_acc(ColumnDefs, NullBitMap, Rest, []).
 
 
 %% @doc Accumulating helper for decode_binary_row/3.
 %% @doc Accumulating helper for decode_binary_row/3.
-decode_binary_row_acc([_ | ColDefs], <<1:1, NullBitMap/bitstring>>, Data, Acc) ->
+decode_binary_row_acc([_|ColDefs], <<1:1, NullBitMap/bitstring>>, Data, Acc) ->
     %% NULL
     %% NULL
     decode_binary_row_acc(ColDefs, NullBitMap, Data, [null | Acc]);
     decode_binary_row_acc(ColDefs, NullBitMap, Data, [null | Acc]);
 decode_binary_row_acc([ColDef | ColDefs], <<0:1, NullBitMap/bitstring>>, Data,
 decode_binary_row_acc([ColDef | ColDefs], <<0:1, NullBitMap/bitstring>>, Data,
@@ -802,40 +827,43 @@ set_to_binary(Set) ->
 
 
 %% -- Protocol basics: packets --
 %% -- Protocol basics: packets --
 
 
-%% @doc Wraps Data in packet headers, sends it by calling SendFun and returns
-%% {ok, SeqNum1} where SeqNum1 is the next sequence number.
--spec send_packet(sendfun(), Data :: binary(), SeqNum :: integer()) ->
+%% @doc Wraps Data in packet headers, sends it by calling TcpModule:send/2 with
+%% Socket and returns {ok, SeqNum1} where SeqNum1 is the next sequence number.
+-spec send_packet(atom(), term(), Data :: binary(), SeqNum :: integer()) ->
     {ok, NextSeqNum :: integer()}.
     {ok, NextSeqNum :: integer()}.
-send_packet(SendFun, Data, SeqNum) ->
+send_packet(TcpModule, Socket, Data, SeqNum) ->
     {WithHeaders, SeqNum1} = add_packet_headers(Data, SeqNum),
     {WithHeaders, SeqNum1} = add_packet_headers(Data, SeqNum),
-    ok = SendFun(WithHeaders),
+    ok = TcpModule:send(Socket, WithHeaders),
     {ok, SeqNum1}.
     {ok, SeqNum1}.
 
 
-%% @doc Receives data by calling RecvFun and removes the packet headers. Returns
-%% the packet contents and the next packet sequence number.
--spec recv_packet(RecvFun :: recvfun(), SeqNum :: integer()) ->
+%% @see recv_packet/4
+recv_packet(TcpModule, Socket, SeqNum) ->
+    recv_packet(TcpModule, Socket, infinity, SeqNum).
+
+%% @doc Receives data by calling TcpModule:recv/2 and removes the packet
+%% headers. Returns the packet contents and the next packet sequence number.
+-spec recv_packet(atom(), term(), timeout(), integer() | any) ->
     {ok, Data :: binary(), NextSeqNum :: integer()}.
     {ok, Data :: binary(), NextSeqNum :: integer()}.
-recv_packet(RecvFun, SeqNum) ->
-    recv_packet(RecvFun, SeqNum, <<>>).
+recv_packet(TcpModule, Socket, Timeout, SeqNum) ->
+    recv_packet(TcpModule, Socket, Timeout, SeqNum, <<>>).
 
 
-%% @doc Receives data by calling RecvFun and removes packet headers. Returns the
-%% data and the next packet sequence number.
--spec recv_packet(RecvFun :: recvfun(), ExpectSeqNum :: integer(),
-                  Acc :: binary()) ->
+%% @doc Accumulating helper for recv_packet/4
+-spec recv_packet(atom(), term(), timeout(), integer() | any, binary()) ->
     {ok, Data :: binary(), NextSeqNum :: integer()}.
     {ok, Data :: binary(), NextSeqNum :: integer()}.
-recv_packet(RecvFun, ExpectSeqNum, Acc) ->
-    case RecvFun(4) of
+recv_packet(TcpModule, Socket, Timeout, ExpectSeqNum, Acc) ->
+    case TcpModule:recv(Socket, 4, Timeout) of
         {ok, Header} ->
         {ok, Header} ->
-            {Size, ExpectSeqNum, More} = parse_packet_header(Header),
-            {ok, Body} = RecvFun(Size),
+            {Size, SeqNum, More} = parse_packet_header(Header),
+            true = SeqNum == ExpectSeqNum orelse ExpectSeqNum == any,
+            {ok, Body} = TcpModule:recv(Socket, Size),
             Acc1 = <<Acc/binary, Body/binary>>,
             Acc1 = <<Acc/binary, Body/binary>>,
-            NextSeqNum = (ExpectSeqNum + 1) band 16#ff,
+            NextSeqNum = (SeqNum + 1) band 16#ff,
             case More of
             case More of
                 false -> {ok, Acc1, NextSeqNum};
                 false -> {ok, Acc1, NextSeqNum};
-                true  -> recv_packet(RecvFun, NextSeqNum, Acc1)
+                true  -> recv_packet(TcpModule, Socket, NextSeqNum, Acc1)
             end;
             end;
-        {error, closed} ->
-            {error, closed}
+        {error, Reason} ->
+            {error, Reason}
     end.
     end.
 
 
 %% @doc Parses a packet header (32 bits) and returns a tuple.
 %% @doc Parses a packet header (32 bits) and returns a tuple.

+ 9 - 104
test/mysql_protocol_tests.erl

@@ -38,11 +38,9 @@ resultset_test() ->
         "00 02 00                                              ..."),
         "00 02 00                                              ..."),
     ExpectedCommunication = [{send, ExpectedReq},
     ExpectedCommunication = [{send, ExpectedReq},
                              {recv, ExpectedResponse}],
                              {recv, ExpectedResponse}],
-    FakeSock = fakesocket_create(ExpectedCommunication),
-    SendFun = fun (Data) -> fakesocket_send(FakeSock, Data) end,
-    RecvFun = fun (Size) -> fakesocket_recv(FakeSock, Size) end,
-    ResultSet = mysql_protocol:query(Query, SendFun, RecvFun),
-    fakesocket_close(FakeSock),
+    Sock = mock_tcp:create(ExpectedCommunication),
+    ResultSet = mysql_protocol:query(Query, mock_tcp, Sock, infinity),
+    mock_tcp:close(Sock),
     ?assertMatch(#resultset{cols = [#col{name = <<"@@version_comment">>}],
     ?assertMatch(#resultset{cols = [#col{name = <<"@@version_comment">>}],
                             rows = [[<<"MySQL Community Server (GPL)">>]]},
                             rows = [[<<"MySQL Community Server (GPL)">>]]},
                  ResultSet),
                  ResultSet),
@@ -77,12 +75,10 @@ resultset_error_test() ->
         "00 00 05 00 00 0c fe 00    00 02 00 17 00 00 0d ff    ................"
         "00 00 05 00 00 0c fe 00    00 02 00 17 00 00 0d ff    ................"
         "48 04 23 48 59 30 30 30    4e 6f 20 74 61 62 6c 65    H.#HY000No table"
         "48 04 23 48 59 30 30 30    4e 6f 20 74 61 62 6c 65    H.#HY000No table"
         "73 20 75 73 65 64                                     s used"),
         "73 20 75 73 65 64                                     s used"),
-    Sock = fakesocket_create([{send, ExpectedReq}, {recv, ExpectedResponse}]),
-    SendFun = fun (Data) -> fakesocket_send(Sock, Data) end,
-    RecvFun = fun (Size) -> fakesocket_recv(Sock, Size) end,
-    Result = mysql_protocol:query(Query, SendFun, RecvFun),
+    Sock = mock_tcp:create([{send, ExpectedReq}, {recv, ExpectedResponse}]),
+    Result = mysql_protocol:query(Query, mock_tcp, Sock, infinity),
     ?assertMatch(#error{}, Result),
     ?assertMatch(#error{}, Result),
-    fakesocket_close(Sock),
+    mock_tcp:close(Sock),
     ok.
     ok.
 
 
 prepare_test() ->
 prepare_test() ->
@@ -102,11 +98,9 @@ prepare_test() ->
         "00 00 05 03 64 65 66 00    00 00 04 63 6f 6c 31 00    ....def....col1."
         "00 00 05 03 64 65 66 00    00 00 04 63 6f 6c 31 00    ....def....col1."
         "0c 3f 00 00 00 00 00 fd    80 00 1f 00 00|05 00 00    .?.............."
         "0c 3f 00 00 00 00 00 fd    80 00 1f 00 00|05 00 00    .?.............."
         "06 fe 00 00 02 00                                     ......"),
         "06 fe 00 00 02 00                                     ......"),
-    Sock = fakesocket_create([{send, ExpectedReq}, {recv, ExpectedResp}]),
-    SendFun = fun (Data) -> fakesocket_send(Sock, Data) end,
-    RecvFun = fun (Size) -> fakesocket_recv(Sock, Size) end,
-    Result = mysql_protocol:prepare(Query, SendFun, RecvFun),
-    fakesocket_close(Sock),
+    Sock = mock_tcp:create([{send, ExpectedReq}, {recv, ExpectedResp}]),
+    Result = mysql_protocol:prepare(Query, mock_tcp, Sock),
+    mock_tcp:close(Sock),
     ?assertMatch(#prepared{statement_id = StmtId,
     ?assertMatch(#prepared{statement_id = StmtId,
                            param_count = 2,
                            param_count = 2,
                            warning_count = 0} when is_integer(StmtId),
                            warning_count = 0} when is_integer(StmtId),
@@ -150,94 +144,5 @@ hexdump_to_bin_test() ->
 
 
 %% --- Fake socket ---
 %% --- Fake socket ---
 %%
 %%
-%% A "fake socket" is used in test where we need to mock socket communication.
-%% It is a pid maintaining a list of expected send and recv events.
 
 
-%% @doc Creates a fakesocket process with a buffer of expected recv and send
-%% calls. The pid of the fakesocket process is returned.
--spec fakesocket_create([{recv, binary()} | {send, binary()}]) -> pid().
-fakesocket_create(ExpectedEvents) ->
-    spawn_link(fun () -> fakesocket_loop(ExpectedEvents) end).
 
 
-%% @doc Receives NumBytes bytes from fakesocket Pid. This function can be used
-%% as a replacement for gen_tcp:recv/2 in unit tests. If there not enough data
-%% in the fakesocket's buffer, an error is raised.
-fakesocket_recv(Pid, NumBytes) ->
-    Pid ! {recv, NumBytes, self()},
-    receive
-        {ok, Data} -> {ok, Data};
-        error -> error({unexpected_recv, NumBytes})
-    after 100 ->
-        error(noreply)
-    end.
-
-%% @doc Sends data to fa fakesocket. This can be used as replacement for
-%% gen_tcp:send/2 in unit tests. If the data sent is not what the fakesocket
-%% expected, an error is raised.
-fakesocket_send(Pid, Data) ->
-    Pid ! {send, iolist_to_binary(Data), self()},
-    receive
-        ok -> ok;
-        error -> error({unexpected_send, Data})
-    after 100 ->
-        error(noreply)
-    end.
-
-%% Stops the fakesocket process. If the fakesocket's buffer is not empty,
-%% an error is raised.
-fakesocket_close(Pid) ->
-    Pid ! {done, self()},
-    receive
-        ok -> ok;
-        {remains, Remains} -> error({unexpected_close, Remains})
-    after 100 ->
-        error(noreply)
-    end.
-
-%% Used by fakesocket_create/1.
-fakesocket_loop(AllEvents = [{Func, Data} | Events]) ->
-    receive
-        {recv, NumBytes, FromPid} when Func == recv, NumBytes == size(Data) ->
-            FromPid ! {ok, Data},
-            fakesocket_loop(Events);
-        {recv, NumBytes, FromPid} when Func == recv, NumBytes < size(Data) ->
-            <<Data1:NumBytes/binary, Rest/binary>> = Data,
-            FromPid ! {ok, Data1},
-            fakesocket_loop([{recv, Rest} | Events]);
-        {send, Bytes, FromPid} when Func == send, Bytes == Data ->
-            FromPid ! ok,
-            fakesocket_loop(Events);
-        {send, Bytes, FromPid} when Func == send, size(Bytes) < size(Data) ->
-            Size = size(Bytes),
-            case Data of
-                <<Bytes:Size/binary, Rest/binary>> ->
-                    FromPid ! ok,
-                    fakesocket_loop([{send, Rest} | Events]);
-                _ ->
-                    FromPid ! error
-            end;
-        {_, _, FromPid} ->
-            FromPid ! error;
-        {done, FromPid} ->
-            FromPid ! {remains, AllEvents}
-    end;
-fakesocket_loop([]) ->
-    receive
-        {done, FromPid} -> FromPid ! ok;
-        {_, _, FromPid} -> FromPid ! error
-    end.
-
-%% Tests for the fakesocket functions.
-fakesocket_bad_recv_test() ->
-    Pid = fakesocket_create([{recv, <<"foobar">>}]),
-    ?assertError(_, fakesocket_recv(Pid, 10)).
-
-fakesocket_success_test() ->
-    Pid = fakesocket_create([{recv, <<"foobar">>}, {send, <<"baz">>}]),
-    %?assertError({unexpected_close, _}, fakesocket_close(Pid)),
-    ?assertEqual({ok, <<"foo">>}, fakesocket_recv(Pid, 3)),
-    ?assertEqual({ok, <<"bar">>}, fakesocket_recv(Pid, 3)),
-    ?assertEqual(ok, fakesocket_send(Pid, <<"baz">>)),
-    ?assertEqual(ok, fakesocket_close(Pid)),
-    %% The process will exit after close. Another recv will raise noreply.
-    ?assertError(noreply, fakesocket_recv(Pid, 3)).

+ 31 - 0
test/mysql_tests.erl

@@ -346,6 +346,37 @@ write_read_text_binary(Conn, Term, SqlLiteral, Table, Column) ->
 
 
 %% --------------------------------------------------------------------------
 %% --------------------------------------------------------------------------
 
 
+timeout_test_() ->
+    {setup,
+     fun () ->
+         {ok, Pid} = mysql:start_link([{user, ?user}, {password, ?password}]),
+         Pid
+     end,
+     fun (Pid) ->
+         exit(Pid, normal)
+     end,
+     {with, [fun (Pid) ->
+                 %% SLEEP was added in MySQL 5.0.12
+                 ?assertEqual({ok, [<<"SLEEP(5)">>], [[1]]},
+                              mysql:query(Pid, <<"SELECT SLEEP(5)">>, 40)),
+
+                 %% A query after an interrupted query shouldn't get a timeout.
+                 ?assertMatch({ok,[<<"42">>], [[42]]},
+                              mysql:query(Pid, <<"SELECT 42">>)),
+
+                 %% Parametrized query
+                 ?assertEqual({ok, [<<"SLEEP(?)">>], [[1]]},
+                              mysql:query(Pid, <<"SELECT SLEEP(?)">>, [5], 40)),
+
+                 %% Prepared statement
+                 {ok, Stmt} = mysql:prepare(Pid, <<"SELECT SLEEP(?)">>),
+                 ?assertEqual({ok, [<<"SLEEP(?)">>], [[1]]},
+                              mysql:execute(Pid, Stmt, [5], 40)),
+                 ok = mysql:unprepare(Pid, Stmt)
+             end]}}.
+
+%% --------------------------------------------------------------------------
+
 %% Prepared statements and transactions
 %% Prepared statements and transactions
 
 
 with_table_foo_test_() ->
 with_table_foo_test_() ->