Просмотр исходного кода

Monitor calling process during transaction

Monitoring the caller allows the connection to kill itself if the
calling process is killed (#96)

When using Poolboy (or similar) for connection pooling and a when
process who has checked out a connection is killed, the connection
goes back into the pool, even if it is in an ongoing transaction.
Monitoring the calling process when a transaction is started, the
connection can kill itself if the caller is killed and thus prevent
connections in a bad state from being put back into the pool.
Silviu Caragea 6 лет назад
Родитель
Сommit
395efb0925
2 измененных файлов с 85 добавлено и 12 удалено
  1. 33 12
      src/mysql.erl
  2. 52 0
      test/transaction_tests.erl

+ 33 - 12
src/mysql.erl

@@ -483,6 +483,7 @@ encode(Conn, Term) ->
                 query_timeout, query_cache_time,
                 affected_rows = 0, status = 0, warning_count = 0, insert_id = 0,
                 transaction_level = 0, ping_ref = undefined,
+                monitors = [],
                 stmts = dict:new(), query_cache = empty, cap_found_rows = false}).
 
 %% @private
@@ -715,11 +716,14 @@ handle_call(backslash_escapes_enabled, _From, State = #state{status = S}) ->
     {reply, S band ?SERVER_STATUS_NO_BACKSLASH_ESCAPES == 0, State};
 handle_call(in_transaction, _From, State) ->
     {reply, State#state.status band ?SERVER_STATUS_IN_TRANS /= 0, State};
-handle_call(start_transaction, _From,
+handle_call(start_transaction, {FromPid, _},
             State = #state{socket = Socket, sockmod = SockMod,
-                           transaction_level = L, status = Status})
+                           transaction_level = L, status = Status, monitors = Monitors})
   when Status band ?SERVER_STATUS_IN_TRANS == 0, L == 0;
        Status band ?SERVER_STATUS_IN_TRANS /= 0, L > 0 ->
+
+    MRef = erlang:monitor(process, FromPid),
+
     Query = case L of
         0 -> <<"BEGIN">>;
         _ -> <<"SAVEPOINT s", (integer_to_binary(L))/binary>>
@@ -730,10 +734,13 @@ handle_call(start_transaction, _From,
                                                ?cmd_timeout),
     SockMod:setopts(Socket, [{active, once}]),
     State1 = update_state(Res, State),
-    {reply, ok, State1#state{transaction_level = L + 1}};
-handle_call(rollback, _From, State = #state{socket = Socket, sockmod = SockMod,
-                                            status = Status, transaction_level = L})
+    {reply, ok, State1#state{transaction_level = L + 1, monitors = [{FromPid, MRef} | Monitors]}};
+handle_call(rollback, {FromPid, _}, State = #state{socket = Socket, sockmod = SockMod,
+                                                   status = Status, transaction_level = L,
+                                                   monitors = [{FromPid, MRef}|NewMonitors]})
   when Status band ?SERVER_STATUS_IN_TRANS /= 0, L >= 1 ->
+    erlang:demonitor(MRef),
+
     Query = case L of
         1 -> <<"ROLLBACK">>;
         _ -> <<"ROLLBACK TO s", (integer_to_binary(L - 1))/binary>>
@@ -744,10 +751,13 @@ handle_call(rollback, _From, State = #state{socket = Socket, sockmod = SockMod,
                                                ?cmd_timeout),
     SockMod:setopts(Socket, [{active, once}]),
     State1 = update_state(Res, State),
-    {reply, ok, State1#state{transaction_level = L - 1}};
-handle_call(commit, _From, State = #state{socket = Socket, sockmod = SockMod,
-                                          status = Status, transaction_level = L})
+    {reply, ok, State1#state{transaction_level = L - 1, monitors = NewMonitors}};
+handle_call(commit, {FromPid, _}, State = #state{socket = Socket, sockmod = SockMod,
+                                                 status = Status, transaction_level = L,
+                                                 monitors = [{FromPid, MRef}|NewMonitors]})
   when Status band ?SERVER_STATUS_IN_TRANS /= 0, L >= 1 ->
+    erlang:demonitor(MRef),
+
     Query = case L of
         1 -> <<"COMMIT">>;
         _ -> <<"RELEASE SAVEPOINT s", (integer_to_binary(L - 1))/binary>>
@@ -758,7 +768,7 @@ handle_call(commit, _From, State = #state{socket = Socket, sockmod = SockMod,
                                                ?cmd_timeout),
     SockMod:setopts(Socket, [{active, once}]),
     State1 = update_state(Res, State),
-    {reply, ok, State1#state{transaction_level = L - 1}}.
+    {reply, ok, State1#state{transaction_level = L - 1, monitors = NewMonitors}}.
 
 %% @private
 handle_cast(_Msg, State) ->
@@ -782,6 +792,8 @@ handle_info(query_cache, #state{query_cache = Cache,
     mysql_cache:size(Cache1) > 0 andalso
         erlang:send_after(CacheTime, self(), query_cache),
     {noreply, State#state{query_cache = Cache1}};
+handle_info({'DOWN', _MRef, _, Pid, _Info}, State) ->
+    stop_server({application_process_died, Pid}, State);
 handle_info(ping, #state{socket = Socket, sockmod = SockMod} = State) ->
     SockMod:setopts(Socket, [{active, false}]),
     SockMod = State#state.sockmod,
@@ -880,14 +892,15 @@ handle_query_call_reply([], _Query, State, ResultSetsAcc) ->
         [_|_]                 -> {ok, lists:reverse(ResultSetsAcc)}
     end,
     {reply, Reply, State};
-handle_query_call_reply([Rec|Recs], Query, State, ResultSetsAcc) ->
+handle_query_call_reply([Rec|Recs], Query, #state{monitors = Monitors} = State, ResultSetsAcc) ->
     case Rec of
         #ok{status = Status} when Status band ?SERVER_STATUS_IN_TRANS == 0,
                                   State#state.transaction_level > 0 ->
             %% DDL statements (e.g. CREATE TABLE, ALTER TABLE, etc.) result in
             %% an implicit commit.
             Reply = {implicit_commit, State#state.transaction_level, Query},
-            {reply, Reply, State#state{transaction_level = 0}};
+            NewMonitors = demonitor_processes(Monitors, length(Monitors)),
+            {reply, Reply, State#state{transaction_level = 0, monitors = NewMonitors}};
         #ok{} ->
             handle_query_call_reply(Recs, Query, State, ResultSetsAcc);
         #resultset{cols = ColDefs, rows = Rows} ->
@@ -900,7 +913,8 @@ handle_query_call_reply([Rec|Recs], Query, State, ResultSetsAcc) ->
                      error_to_reason(Rec)},
             %% Everything in the transaction is rolled back, except the BEGIN
             %% statement itself. Thus, we are in transaction level 1.
-            {reply, Reply, State#state{transaction_level = 1}};
+            NewMonitors = demonitor_processes(Monitors, length(Monitors) -1),
+            {reply, Reply, State#state{transaction_level = 1, monitors = NewMonitors}};
         #error{} ->
             {reply, {error, error_to_reason(Rec)}, State}
     end.
@@ -954,3 +968,10 @@ stop_server(Reason,
                          [ConnId, Reason]),
   ok = gen_tcp:close(Socket),
   {stop, Reason, State#state{socket = undefined, connection_id = undefined}}.
+
+demonitor_processes(List, 0) ->
+    List;
+demonitor_processes([{_FromPid, MRef}|T], Count) ->
+    erlang:demonitor(MRef),
+    demonitor_processes(T, Count -1).
+

+ 52 - 0
test/transaction_tests.erl

@@ -48,6 +48,58 @@ single_connection_test_() ->
           {"Implicit commit",      fun () -> implicit_commit(Pid) end}]
      end}.
 
+application_process_kill_test() ->
+    {ok, Pid} = mysql:start_link([
+        {user, ?user},
+        {password, ?password},
+        {query_cache_time, 50},
+        {log_warnings, false}
+    ]),
+
+    unlink(Pid),
+    Mref = erlang:monitor(process, Pid),
+
+    ok = mysql:query(Pid, <<"DROP DATABASE IF EXISTS otptest">>),
+    ok = mysql:query(Pid, <<"CREATE DATABASE otptest">>),
+    ok = mysql:query(Pid, <<"USE otptest">>),
+    ok = mysql:query(Pid, <<"CREATE TABLE foo (bar INT) engine=InnoDB">>),
+
+    ?assertNot(mysql:in_transaction(Pid)),
+    ?assert(is_process_alive(Pid)),
+
+    Self = self(),
+
+    AppPid = spawn(fun() ->
+        mysql:transaction(Pid, fun () ->
+            ok = mysql:query(Pid, "INSERT INTO foo (bar) VALUES (42)"),
+            Self! killme,
+            receive after 10000 -> throw(too_long) end,
+            ok
+        end)
+    end),
+
+    receive killme -> exit(AppPid, kill) end,
+
+    receive
+        {'DOWN', Mref, process, Pid, {application_process_died, AppPid}}->
+            ok
+        after 10000 ->
+            throw(too_long)
+    end,
+
+    ?assertNot(is_process_alive(Pid)),
+
+    {ok, Pid2} = mysql:start_link([
+        {user, ?user},
+        {password, ?password},
+        {query_cache_time, 50},
+        {log_warnings, false}
+    ]),
+    ok = mysql:query(Pid2, <<"USE otptest">>),
+    ?assertMatch({ok, _, []}, mysql:query(Pid2, <<"SELECT * from foo where bar = 42">>)),
+    ok = mysql:query(Pid2, <<"DROP DATABASE otptest">>),
+    exit(Pid2, normal).
+
 simple_atomic(Pid) ->
     ?assertNot(mysql:in_transaction(Pid)),
     Result = mysql:transaction(Pid, fun () ->