Browse Source

Merge pull request #41 from johlo/detect-server-close, fixes #33

Stop the gen_server when mysql server closes connection
Viktor Söderqvist 9 years ago
parent
commit
14e2c709fa
1 changed files with 49 additions and 9 deletions
  1. 49 9
      src/mysql.erl

+ 49 - 9
src/mysql.erl

@@ -479,12 +479,14 @@ init(Opts) ->
     end,
 
     %% Connect socket
-    SockOpts = [{active, false}, binary, {packet, raw} | TcpOpts],
+    SockOpts = [binary, {packet, raw} | TcpOpts],
     {ok, Socket} = gen_tcp:connect(Host, Port, SockOpts),
 
     %% Exchange handshake communication.
+    inet:setopts(Socket, [{active, false}]),
     Result = mysql_protocol:handshake(User, Password, Database, gen_tcp,
                                       Socket),
+    inet:setopts(Socket, [{active, once}]),
     case Result of
         #handshake{server_version = Version, connection_id = ConnId,
                    status = Status} ->
@@ -553,6 +555,7 @@ handle_call({query, Query}, From, State) ->
     handle_call({query, Query, State#state.query_timeout}, From, State);
 handle_call({query, Query, Timeout}, _From, State) ->
     Socket = State#state.socket,
+    inet:setopts(Socket, [{active, false}]),
     {ok, Recs} = case mysql_protocol:query(Query, gen_tcp, Socket, Timeout) of
         {error, timeout} when State#state.server_version >= [5, 0, 0] ->
             kill_query(State),
@@ -564,6 +567,7 @@ handle_call({query, Query, Timeout}, _From, State) ->
         QueryResult ->
             QueryResult
     end,
+    inet:setopts(Socket, [{active, once}]),
     State1 = lists:foldl(fun update_state/2, State, Recs),
     State1#state.warning_count > 0 andalso State1#state.log_warnings
         andalso log_warnings(State1, Query),
@@ -582,7 +586,9 @@ handle_call({param_query, Query, Params, Timeout}, _From, State) ->
             {{ok, FoundStmt}, NewCache};
         not_found ->
             %% Prepare
+            inet:setopts(Socket, [{active, false}]),
             Rec = mysql_protocol:prepare(Query, gen_tcp, Socket),
+            inet:setopts(Socket, [{active, once}]),
             %State1 = update_state(Rec, State),
             case Rec of
                 #error{} = E ->
@@ -614,7 +620,9 @@ handle_call({execute, Stmt, Args, Timeout}, _From, State) ->
     end;
 handle_call({prepare, Query}, _From, State) ->
     #state{socket = Socket} = State,
+    inet:setopts(Socket, [{active, false}]),
     Rec = mysql_protocol:prepare(Query, gen_tcp, Socket),
+    inet:setopts(Socket, [{active, once}]),
     State1 = update_state(Rec, State),
     case Rec of
         #error{} = E ->
@@ -627,6 +635,7 @@ handle_call({prepare, Query}, _From, State) ->
 handle_call({prepare, Name, Query}, _From, State) when is_atom(Name) ->
     #state{socket = Socket} = State,
     %% First unprepare if there is an old statement with this name.
+    inet:setopts(Socket, [{active, false}]),
     State1 = case dict:find(Name, State#state.stmts) of
         {ok, OldStmt} ->
             mysql_protocol:unprepare(OldStmt, gen_tcp, Socket),
@@ -635,6 +644,7 @@ handle_call({prepare, Name, Query}, _From, State) when is_atom(Name) ->
             State
     end,
     Rec = mysql_protocol:prepare(Query, gen_tcp, Socket),
+    inet:setopts(Socket, [{active, once}]),
     State2 = update_state(Rec, State1),
     case Rec of
         #error{} = E ->
@@ -649,7 +659,9 @@ handle_call({unprepare, Stmt}, _From, State) when is_atom(Stmt);
     case dict:find(Stmt, State#state.stmts) of
         {ok, StmtRec} ->
             #state{socket = Socket} = State,
+            inet:setopts(Socket, [{active, false}]),
             mysql_protocol:unprepare(StmtRec, gen_tcp, Socket),
+            inet:setopts(Socket, [{active, once}]),
             State1 = State#state{stmts = dict:erase(Stmt, State#state.stmts)},
             State2 = schedule_ping(State1),
             {reply, ok, State2};
@@ -677,8 +689,10 @@ handle_call(start_transaction, _From,
         0 -> <<"BEGIN">>;
         _ -> <<"SAVEPOINT s", (integer_to_binary(L))/binary>>
     end,
+    inet:setopts(Socket, [{active, false}]),
     {ok, [Res = #ok{}]} = mysql_protocol:query(Query, gen_tcp, Socket,
                                                ?cmd_timeout),
+    inet:setopts(Socket, [{active, once}]),
     State1 = update_state(Res, State),
     {reply, ok, State1#state{transaction_level = L + 1}};
 handle_call(rollback, _From, State = #state{socket = Socket, status = Status,
@@ -688,8 +702,10 @@ handle_call(rollback, _From, State = #state{socket = Socket, status = Status,
         1 -> <<"ROLLBACK">>;
         _ -> <<"ROLLBACK TO s", (integer_to_binary(L - 1))/binary>>
     end,
+    inet:setopts(Socket, [{active, false}]),
     {ok, [Res = #ok{}]} = mysql_protocol:query(Query, gen_tcp, Socket,
                                                ?cmd_timeout),
+    inet:setopts(Socket, [{active, once}]),
     State1 = update_state(Res, State),
     {reply, ok, State1#state{transaction_level = L - 1}};
 handle_call(commit, _From, State = #state{socket = Socket, status = Status,
@@ -699,8 +715,10 @@ handle_call(commit, _From, State = #state{socket = Socket, status = Status,
         1 -> <<"COMMIT">>;
         _ -> <<"RELEASE SAVEPOINT s", (integer_to_binary(L - 1))/binary>>
     end,
+    inet:setopts(Socket, [{active, false}]),
     {ok, [Res = #ok{}]} = mysql_protocol:query(Query, gen_tcp, Socket,
                                                ?cmd_timeout),
+    inet:setopts(Socket, [{active, once}]),
     State1 = update_state(Res, State),
     {reply, ok, State1#state{transaction_level = L - 1}}.
 
@@ -709,31 +727,42 @@ handle_cast(_Msg, State) ->
     {noreply, State}.
 
 %% @private
-handle_info(query_cache, State = #state{query_cache = Cache,
-                                        query_cache_time = CacheTime}) ->
+handle_info(query_cache, #state{query_cache = Cache,
+                                query_cache_time = CacheTime} = State) ->
     %% Evict expired queries/statements in the cache used by query/3.
     {Evicted, Cache1} = mysql_cache:evict_older_than(Cache, CacheTime),
     %% Unprepare the evicted statements
     #state{socket = Socket} = State,
+    inet:setopts(Socket, [{active, false}]),
     lists:foreach(fun ({_Query, Stmt}) ->
                       mysql_protocol:unprepare(Stmt, gen_tcp, Socket)
                   end,
                   Evicted),
+    inet:setopts(Socket, [{active, once}]),
     %% If nonempty, schedule eviction again.
     mysql_cache:size(Cache1) > 0 andalso
         erlang:send_after(CacheTime, self(), query_cache),
     {noreply, State#state{query_cache = Cache1}};
-handle_info(ping, State) ->
-    Ok = mysql_protocol:ping(gen_tcp, State#state.socket),
+handle_info(ping, #state{socket = Socket} = State) ->
+    inet:setopts(Socket, [{active, false}]),
+    Ok = mysql_protocol:ping(gen_tcp, Socket),
+    inet:setopts(Socket, [{active, once}]),
     {noreply, update_state(Ok, State)};
+handle_info({tcp_closed, _Socket}, State) ->
+    stop_server(tcp_closed, State);
+handle_info({tcp_error, _Socket, Reason}, State) ->
+    stop_server({tcp_error, Reason}, State);
 handle_info(_Info, State) ->
     {noreply, State}.
 
 %% @private
-terminate(Reason, State) when Reason == normal; Reason == shutdown ->
-    %% Send the goodbye message for politeness.
-    #state{socket = Socket} = State,
-    mysql_protocol:quit(gen_tcp, Socket);
+terminate(Reason, #state{socket = Socket})
+  when Reason == normal; Reason == shutdown ->
+      %% Send the goodbye message for politeness.
+      inet:setopts(Socket, [{active, false}]),
+      R = mysql_protocol:quit(gen_tcp, Socket),
+      inet:setopts(Socket, [{active, once}]),
+      R;
 terminate(_Reason, _State) ->
     ok.
 
@@ -760,6 +789,7 @@ query_call(Conn, CallReq) ->
 
 %% @doc Executes a prepared statement and returns {Reply, NextState}.
 execute_stmt(Stmt, Args, Timeout, State = #state{socket = Socket}) ->
+    inet:setopts(Socket, [{active, false}]),
     {ok, Recs} = case mysql_protocol:execute(Stmt, Args, gen_tcp, Socket,
                                              Timeout) of
         {error, timeout} when State#state.server_version >= [5, 0, 0] ->
@@ -773,6 +803,7 @@ execute_stmt(Stmt, Args, Timeout, State = #state{socket = Socket}) ->
         QueryResult ->
             QueryResult
     end,
+    inet:setopts(Socket, [{active, false}]),
     State1 = lists:foldl(fun update_state/2, State, Recs),
     State1#state.warning_count > 0 andalso State1#state.log_warnings
         andalso log_warnings(State1, Stmt#prepared.orig_query),
@@ -852,9 +883,11 @@ clear_transaction_status(State = #state{status = Status}) ->
 
 %% @doc Fetches and logs warnings. Query is the query that gave the warnings.
 log_warnings(#state{socket = Socket}, Query) ->
+    inet:setopts(Socket, [{active, false}]),
     {ok, [#resultset{rows = Rows}]} = mysql_protocol:query(<<"SHOW WARNINGS">>,
                                                            gen_tcp, Socket,
                                                            ?cmd_timeout),
+    inet:setopts(Socket, [{active, once}]),
     Lines = [[Level, " ", integer_to_binary(Code), ": ", Message, "\n"]
              || [Level, Code, Message] <- Rows],
     error_logger:warning_msg("~s in ~s~n", [Lines, Query]).
@@ -882,3 +915,10 @@ kill_query(#state{connection_id = ConnId, host = Host, port = Port,
             error_logger:error_msg("Failed to connect to kill query: ~p",
                                    [error_to_reason(E)])
     end.
+
+stop_server(Reason,
+            #state{socket = Socket, connection_id = ConnId} = State) ->
+  error_logger:error_msg("Connection Id ~p closing with reason: ~p~n",
+                         [ConnId, Reason]),
+  ok = gen_tcp:close(Socket),
+  {stop, Reason, State#state{socket = undefined, connection_id = undefined}}.