Browse Source

Keep alive (ping)

Viktor Söderqvist 10 years ago
parent
commit
9b9915e93f
3 changed files with 93 additions and 28 deletions
  1. 56 24
      src/mysql.erl
  2. 10 3
      src/mysql_protocol.erl
  3. 27 1
      test/mysql_tests.erl

+ 56 - 24
src/mysql.erl

@@ -42,6 +42,7 @@
 -define(default_connect_timeout, 5000).
 -define(default_connect_timeout, 5000).
 -define(default_query_timeout, infinity).
 -define(default_query_timeout, infinity).
 -define(default_query_cache_time, 60000). %% for query/3.
 -define(default_query_cache_time, 60000). %% for query/3.
+-define(default_ping_timeout, 60000).
 
 
 -define(cmd_timeout, 3000). %% Timeout used for various commands to the server
 -define(cmd_timeout, 3000). %% Timeout used for various commands to the server
 
 
@@ -86,6 +87,11 @@
 %%   <dt>`{log_warnings, boolean()}'</dt>
 %%   <dt>`{log_warnings, boolean()}'</dt>
 %%   <dd>Whether to fetch warnings and log them using error_logger; default
 %%   <dd>Whether to fetch warnings and log them using error_logger; default
 %%       true.</dd>
 %%       true.</dd>
+%%   <dt>`{keep_alive, boolean() | timeout()}'</dt>
+%%   <dd>Send ping when unused for a certain time. Possible values are `true',
+%%       `false' and `integer() > 0' for an explicit interval in milliseconds.
+%%       The default is `false'. For `true' a default ping timeout is used.
+%%       </dt>
 %%   <dt>`{query_timeout, Timeout}'</dt>
 %%   <dt>`{query_timeout, Timeout}'</dt>
 %%   <dd>The default time to wait for a response when executing a query or a
 %%   <dd>The default time to wait for a response when executing a query or a
 %%       prepared statement. This can be given per query using `query/3,4' and
 %%       prepared statement. This can be given per query using `query/3,4' and
@@ -381,24 +387,33 @@ transaction(Conn, Fun, Args, Retries) when is_list(Args),
 %% Gen_server state
 %% Gen_server state
 -record(state, {server_version, connection_id, socket,
 -record(state, {server_version, connection_id, socket,
                 host, port, user, password, log_warnings,
                 host, port, user, password, log_warnings,
+                ping_timeout,
                 query_timeout, query_cache_time,
                 query_timeout, query_cache_time,
                 affected_rows = 0, status = 0, warning_count = 0, insert_id = 0,
                 affected_rows = 0, status = 0, warning_count = 0, insert_id = 0,
-                transaction_level = 0,
+                transaction_level = 0, ping_ref = undefined,
                 stmts = dict:new(), query_cache = empty}).
                 stmts = dict:new(), query_cache = empty}).
 
 
 %% @private
 %% @private
 init(Opts) ->
 init(Opts) ->
     %% Connect
     %% Connect
-    Host     = proplists:get_value(host,          Opts, ?default_host),
-    Port     = proplists:get_value(port,          Opts, ?default_port),
-    User     = proplists:get_value(user,          Opts, ?default_user),
-    Password = proplists:get_value(password,      Opts, ?default_password),
-    Database = proplists:get_value(database,      Opts, undefined),
-    LogWarn  = proplists:get_value(log_warnings,  Opts, true),
-    Timeout  = proplists:get_value(query_timeout, Opts, ?default_query_timeout),
+    Host           = proplists:get_value(host, Opts, ?default_host),
+    Port           = proplists:get_value(port, Opts, ?default_port),
+    User           = proplists:get_value(user, Opts, ?default_user),
+    Password       = proplists:get_value(password, Opts, ?default_password),
+    Database       = proplists:get_value(database, Opts, undefined),
+    LogWarn        = proplists:get_value(log_warnings, Opts, true),
+    KeepAlive      = proplists:get_value(keep_alive, Opts, false),
+    Timeout        = proplists:get_value(query_timeout, Opts,
+                                         ?default_query_timeout),
     QueryCacheTime = proplists:get_value(query_cache_time, Opts,
     QueryCacheTime = proplists:get_value(query_cache_time, Opts,
                                          ?default_query_cache_time),
                                          ?default_query_cache_time),
 
 
+    PingTimeout = case KeepAlive of
+        true         -> ?default_ping_timeout;
+        false        -> infinity;
+        N when N > 0 -> N
+    end,
+
     %% Connect socket
     %% Connect socket
     SockOpts = [{active, false}, binary, {packet, raw}],
     SockOpts = [{active, false}, binary, {packet, raw}],
     {ok, Socket} = gen_tcp:connect(Host, Port, SockOpts),
     {ok, Socket} = gen_tcp:connect(Host, Port, SockOpts),
@@ -414,11 +429,13 @@ init(Opts) ->
                            host = Host, port = Port, user = User,
                            host = Host, port = Port, user = User,
                            password = Password, status = Status,
                            password = Password, status = Status,
                            log_warnings = LogWarn,
                            log_warnings = LogWarn,
+                           ping_timeout = PingTimeout,
                            query_timeout = Timeout,
                            query_timeout = Timeout,
                            query_cache_time = QueryCacheTime},
                            query_cache_time = QueryCacheTime},
             %% Trap exit so that we can properly disconnect when we die.
             %% Trap exit so that we can properly disconnect when we die.
             process_flag(trap_exit, true),
             process_flag(trap_exit, true),
-            {ok, State};
+            State1 = schedule_ping(State),
+            {ok, State1};
         #error{} = E ->
         #error{} = E ->
             {stop, error_to_reason(E)}
             {stop, error_to_reason(E)}
     end.
     end.
@@ -597,8 +614,9 @@ handle_call({unprepare, Stmt}, _From, State) when is_atom(Stmt);
         {ok, StmtRec} ->
         {ok, StmtRec} ->
             #state{socket = Socket} = State,
             #state{socket = Socket} = State,
             mysql_protocol:unprepare(StmtRec, gen_tcp, Socket),
             mysql_protocol:unprepare(StmtRec, gen_tcp, Socket),
-            Stmts1 = dict:erase(Stmt, State#state.stmts),
-            {reply, ok, State#state{stmts = Stmts1}};
+            State1 = State#state{stmts = dict:erase(Stmt, State#state.stmts)},
+            State2 = schedule_ping(State1),
+            {reply, ok, State2};
         error ->
         error ->
             {reply, {error, not_prepared}, State}
             {reply, {error, not_prepared}, State}
     end;
     end;
@@ -669,6 +687,9 @@ handle_info(query_cache, State = #state{query_cache = Cache,
     mysql_cache:size(Cache1) > 0 andalso
     mysql_cache:size(Cache1) > 0 andalso
         erlang:send_after(CacheTime, self(), query_cache),
         erlang:send_after(CacheTime, self(), query_cache),
     {noreply, State#state{query_cache = Cache1}};
     {noreply, State#state{query_cache = Cache1}};
+handle_info(ping, State) ->
+    Ok = mysql_protocol:ping(gen_tcp, State#state.socket),
+    {noreply, update_state(State, Ok)};
 handle_info(_Info, State) ->
 handle_info(_Info, State) ->
     {noreply, State}.
     {noreply, State}.
 
 
@@ -738,20 +759,31 @@ execute_stmt(Stmt, Args, Timeout, State = #state{socket = Socket}) ->
 error_to_reason(#error{code = Code, state = State, msg = Msg}) ->
 error_to_reason(#error{code = Code, state = State, msg = Msg}) ->
     {Code, State, Msg}.
     {Code, State, Msg}.
 
 
-%% @doc Updates a state with information from a response.
+%% @doc Updates a state with information from a response. Also re-schedules
+%% ping.
 -spec update_state(#state{}, #ok{} | #eof{} | any()) -> #state{}.
 -spec update_state(#state{}, #ok{} | #eof{} | any()) -> #state{}.
-update_state(State, #ok{status = S, affected_rows = R,
-                        insert_id = Id, warning_count = W}) ->
-    State#state{status = S, affected_rows = R, insert_id = Id,
-                warning_count = W};
-%update_state(State, #eof{status = S, warning_count = W}) ->
-%    State#state{status = S, warning_count = W, affected_rows = 0};
-update_state(State, #prepared{warning_count = W}) ->
-    State#state{warning_count = W};
-update_state(State, _Other) ->
-    %% This includes errors, resultsets, etc.
-    %% Reset warnings, etc. (Note: We don't reset status and insert_id.)
-    State#state{warning_count = 0, affected_rows = 0}.
+update_state(State, Rec) ->
+    State1 = case Rec of
+        #ok{status = S, affected_rows = R, insert_id = Id, warning_count = W} ->
+            State#state{status = S, affected_rows = R, insert_id = Id,
+                        warning_count = W};
+        %#eof{status = S, warning_count = W} ->
+        %    State#state{status = S, warning_count = W, affected_rows = 0};
+        #prepared{warning_count = W} ->
+            State#state{warning_count = W};
+        _Other ->
+            %% This includes errors, resultsets, etc.
+            %% Reset some things. (Note: We don't reset status and insert_id.)
+            State#state{warning_count = 0, affected_rows = 0}
+    end,
+    schedule_ping(State1).
+
+%% @doc Schedules (or re-schedules) ping.
+schedule_ping(State = #state{ping_timeout = infinity}) ->
+    State;
+schedule_ping(State = #state{ping_timeout = Timeout, ping_ref = Ref}) ->
+    is_reference(Ref) andalso erlang:cancel_timer(Ref),
+    State#state{ping_ref = erlang:send_after(Timeout, self(), ping)}.
 
 
 %% @doc Since errors don't return a status but some errors cause an implicit
 %% @doc Since errors don't return a status but some errors cause an implicit
 %% rollback, we use this function to clear fix the transaction bit in the
 %% rollback, we use this function to clear fix the transaction bit in the

+ 10 - 3
src/mysql_protocol.erl

@@ -26,7 +26,7 @@
 %% @private
 %% @private
 -module(mysql_protocol).
 -module(mysql_protocol).
 
 
--export([handshake/5, quit/2,
+-export([handshake/5, quit/2, ping/2,
          query/4, fetch_query_response/3,
          query/4, fetch_query_response/3,
          prepare/3, unprepare/3, execute/5, fetch_execute_response/3]).
          prepare/3, unprepare/3, execute/5, fetch_execute_response/3]).
 
 
@@ -62,13 +62,20 @@ handshake(Username, Password, Database, TcpModule, Socket) ->
             Error
             Error
     end.
     end.
 
 
+-spec quit(atom(), term()) -> ok.
 quit(TcpModule, Socket) ->
 quit(TcpModule, Socket) ->
     {ok, SeqNum1} = send_packet(TcpModule, Socket, <<?COM_QUIT>>, 0),
     {ok, SeqNum1} = send_packet(TcpModule, Socket, <<?COM_QUIT>>, 0),
     case recv_packet(TcpModule, Socket, SeqNum1) of
     case recv_packet(TcpModule, Socket, SeqNum1) of
-        {error, closed} -> ok;
-        {ok, ?ok_pattern, _SeqNum2} -> ok
+        {error, closed} -> ok;            %% MySQL 5.5.40 and more
+        {ok, ?ok_pattern, _SeqNum2} -> ok %% Some older MySQL versions?
     end.
     end.
 
 
+-spec ping(atom(), term()) -> #ok{}.
+ping(TcpModule, Socket) ->
+    {ok, SeqNum1} = send_packet(TcpModule, Socket, <<?COM_PING>>, 0),
+    {ok, OkPacket, _SeqNum2} = recv_packet(TcpModule, Socket, SeqNum1),
+    parse_ok_packet(OkPacket).
+
 -spec query(Query :: iodata(), atom(), term(), timeout()) ->
 -spec query(Query :: iodata(), atom(), term(), timeout()) ->
     #ok{} | #resultset{} | #error{} | {error, timeout}.
     #ok{} | #resultset{} | #error{} | {error, timeout}.
 query(Query, TcpModule, Socket, Timeout) ->
 query(Query, TcpModule, Socket, Timeout) ->

+ 27 - 1
test/mysql_tests.erl

@@ -48,6 +48,31 @@ connect_test() ->
     ?assertMatch({error, _}, mysql:code_change("2.0.0", unknown_state, [])),
     ?assertMatch({error, _}, mysql:code_change("2.0.0", unknown_state, [])),
     exit(whereis(tardis), normal).
     exit(whereis(tardis), normal).
 
 
+keep_alive_test() ->
+     %% Let the connection send a few pings.
+     process_flag(trap_exit, true),
+     Options = [{user, ?user}, {password, ?password}, {keep_alive, 20}],
+     {ok, Pid} = mysql:start_link(Options),
+     receive after 70 -> ok end,
+     State = get_state(Pid),
+     [state, _Version, _ConnectionId, Socket | _] = tuple_to_list(State),
+     {ok, ExitMessage, LoggedErrors} = error_logger_acc:capture(fun () ->
+         gen_tcp:close(Socket),
+         receive
+            Message -> Message
+         after 1000 ->
+             ping_didnt_crash_connection
+         end
+     end),
+     process_flag(trap_exit, false),
+     %% Check that we got the expected crash report in the error log.
+     ?assertMatch({'EXIT', Pid, _Reason}, ExitMessage),
+     [{error, LoggedMsg}, {error_report, LoggedReport}] = LoggedErrors,
+     ExpectedPrefix = io_lib:format("** Generic server ~p terminating", [Pid]),
+     ?assert(lists:prefix(lists:flatten(ExpectedPrefix), LoggedMsg)),
+     ?assertMatch({crash_report, _}, LoggedReport),
+     exit(Pid, normal).
+
 %% For R16B where sys:get_state/1 is not available.
 %% For R16B where sys:get_state/1 is not available.
 get_state(Process) ->
 get_state(Process) ->
     {status,_,_,[_,_,_,_,Misc]} = sys:get_status(Process),
     {status,_,_,[_,_,_,_,Misc]} = sys:get_status(Process),
@@ -57,7 +82,8 @@ query_test_() ->
     {setup,
     {setup,
      fun () ->
      fun () ->
          {ok, Pid} = mysql:start_link([{user, ?user}, {password, ?password},
          {ok, Pid} = mysql:start_link([{user, ?user}, {password, ?password},
-                                       {log_warnings, false}]),
+                                       {log_warnings, false},
+                                       {keep_alive, true}]),
          ok = mysql:query(Pid, <<"DROP DATABASE IF EXISTS otptest">>),
          ok = mysql:query(Pid, <<"DROP DATABASE IF EXISTS otptest">>),
          ok = mysql:query(Pid, <<"CREATE DATABASE otptest">>),
          ok = mysql:query(Pid, <<"CREATE DATABASE otptest">>),
          ok = mysql:query(Pid, <<"USE otptest">>),
          ok = mysql:query(Pid, <<"USE otptest">>),