Browse Source

Move on-connect queries and prepares into mysql_conn

juhlig 6 years ago
parent
commit
aaca8d013a
4 changed files with 140 additions and 135 deletions
  1. 11 54
      src/mysql.erl
  2. 116 68
      src/mysql_conn.erl
  3. 5 5
      test/mysql_change_user_tests.erl
  4. 8 8
      test/mysql_tests.erl

+ 11 - 54
src/mysql.erl

@@ -66,7 +66,7 @@
 -include("exception.hrl").
 
 %% @doc Starts a connection gen_server process and connects to a database. To
-%% disconnect just do `exit(Pid, normal)'.
+%% disconnect use `mysql:stop/1,2'.
 %%
 %% Options:
 %%
@@ -147,20 +147,12 @@
 start_link(Options) ->
     GenSrvOpts = [{timeout, proplists:get_value(connect_timeout, Options,
                                                 ?default_connect_timeout)}],
-    Ret = case proplists:get_value(name, Options) of
+    case proplists:get_value(name, Options) of
         undefined ->
             gen_server:start_link(mysql_conn, Options, GenSrvOpts);
         ServerName ->
             gen_server:start_link(ServerName, mysql_conn, Options, GenSrvOpts)
-    end,
-    case Ret of
-        {ok, Pid} ->
-            execute_after_connect(Pid,
-                                  proplists:get_value(queries, Options, []),
-                                  proplists:get_value(prepare, Options, []));
-        _ -> ok
-    end,
-    Ret.
+    end.
  
 %% @see stop/2.
 -spec stop(Conn) -> ok
@@ -660,10 +652,13 @@ change_user(Conn, Username, Password) ->
 %% an error exception and `change_user_in_transaction' as the error
 %% message.
 %%
-%% If the change user operation fails for other reasons (eg authentication
-%% failure), an error exception occurs, and the connection process
-%% exits with reason `change_user_failed'. The connection can not be used
-%% any longer if this happens.
+%% If the change user operation fails, `{error, Reason}'  will be
+%% returned. Specifically, if the operation itself fails (eg
+%% authentication failure), `change_user_failed' will be returned as
+%% the reason, while if the operation itself succeeds but one of
+%% the given initial queries or prepares fails, the reason will
+%% reflect the cause for the failure. In any case, the connection
+%% process will exit with the same reason and cannot be used any longer.
 %%
 %% For a description of the `database', `queries' and `prepare'
 %% options, see `start_link/1'.
@@ -684,17 +679,7 @@ change_user(Conn, Username, Password, Options) ->
         true -> error(change_user_in_transaction);
         false -> ok
     end,
-    Database = proplists:get_value(database, Options, undefined),
-    Ret = gen_server:call(Conn, {change_user, Username, Password, Database}),
-    case Ret of
-        ok ->
-            execute_after_connect(Conn,
-                                  proplists:get_value(queries, Options, []),
-                                  proplists:get_value(prepare, Options, [])),
-            ok;
-        {error, Reason} ->
-            error(Reason)
-    end.
+    gen_server:call(Conn, {change_user, Username, Password, Options}).
 
 %% @doc Encodes a term as a MySQL literal so that it can be used to inside a
 %% query. If backslash escapes are enabled, backslashes and single quotes in
@@ -716,34 +701,6 @@ encode(Conn, Term) ->
 
 %% --- Helpers ---
 
-%% @doc Executes the given queries and prepares the given statements after a
-%% connection has been made.
-%%
-%% If any of the queries or prepares fails, the connection is closed and an
-%% exception is raised.
--spec execute_after_connect(connection(), [iodata()], [{atom(), iodata()}])
-    -> ok.
-execute_after_connect(Conn, Queries, Prepares) ->
-    try
-        lists:foreach(fun (Query) ->
-                          case query(Conn, Query) of
-                              ok -> ok;
-                              {ok, _} -> ok;
-                              {ok, _, _} -> ok
-                          end
-                      end,
-                      Queries),
-        lists:foreach(fun ({Name, Stmt}) ->
-                          {ok, Name} = prepare(Conn, Name, Stmt)
-                      end,
-                      Prepares),
-        ok
-    catch
-        ?EXCEPTION(Class, Reason, Stacktrace) ->
-            catch stop(Conn, ?default_connect_timeout),
-            erlang:raise(Class, Reason, ?GET_STACK(Stacktrace))
-    end.
-
 %% @doc Makes a gen_server call for a query (plain, parametrized or prepared),
 %% checks the reply and sometimes throws an exception when we need to jump out
 %% of a transaction.

+ 116 - 68
src/mysql_conn.erl

@@ -82,6 +82,9 @@ init(Opts) ->
     SSLOpts        = proplists:get_value(ssl, Opts, undefined),
     SockMod0       = gen_tcp,
 
+    Queries        = proplists:get_value(queries, Opts, []),
+    Prepares       = proplists:get_value(prepare, Opts, []),
+
     PingTimeout = case KeepAlive of
         true         -> ?default_ping_timeout;
         false        -> infinity;
@@ -126,14 +129,39 @@ init(Opts) ->
                            query_timeout = Timeout,
                            query_cache_time = QueryCacheTime,
                            cap_found_rows = (SetFoundRows =:= true)},
-            %% Trap exit so that we can properly disconnect when we die.
-            process_flag(trap_exit, true),
-            State1 = schedule_ping(State),
-            {ok, State1};
+            case execute_on_connect(Queries, Prepares, State) of
+                {ok, State1} ->
+                    process_flag(trap_exit, true),
+                    State2 = schedule_ping(State1),
+                    {ok, State2};
+                {error, Reason} ->
+                    {stop, Reason}
+            end;
         #error{} = E ->
             {stop, error_to_reason(E)}
     end.
 
+execute_on_connect([], [], State) ->
+    {ok, State};
+execute_on_connect([], [{Name, Stmt}|Prepares], State) ->
+    case do_named_prepare(Name, Stmt, State) of
+        {{ok, Name}, State1} ->
+            execute_on_connect([], Prepares, State1);
+        {{error, _} = E, _} ->
+            E
+    end;
+execute_on_connect([Query|Queries], Prepares, State) ->
+    case do_query(Query, no_filtermap_fun, default_timeout, State) of
+        {ok, State1} ->
+            execute_on_connect(Queries, Prepares, State1);
+        {{ok, _}, State1} ->
+            execute_on_connect(Queries, Prepares, State1);
+        {{ok, _, _}, State1} ->
+            execute_on_connect(Queries, Prepares, State1);
+        {{error, _} = E, _} ->
+            E
+    end.
+
 %% @private
 %% @doc
 %%
@@ -172,30 +200,9 @@ init(Opts) ->
 %%       able to handle this in the caller's process, we also return the
 %%       nesting level.</dd>
 %% </dl>
-handle_call({query, Query, FilterMap, default_timeout}, From, State) ->
-    handle_call({query, Query, FilterMap, State#state.query_timeout}, From,
-                State);
-handle_call({query, Query, FilterMap, Timeout}, _From,
-            #state{sockmod = SockMod, socket = Socket} = State) ->
-    setopts(SockMod, Socket, [{active, false}]),
-    Result = mysql_protocol:query(Query, SockMod, Socket, FilterMap, Timeout),
-    {ok, Recs} = case Result of
-        {error, timeout} when State#state.server_version >= [5, 0, 0] ->
-            kill_query(State),
-            mysql_protocol:fetch_query_response(SockMod, Socket, FilterMap,
-                                                ?cmd_timeout);
-        {error, timeout} ->
-            %% For MySQL 4.x.x there is no way to recover from timeout except
-            %% killing the connection itself.
-            exit(timeout);
-        QueryResult ->
-            QueryResult
-    end,
-    setopts(SockMod, Socket, [{active, once}]),
-    State1 = lists:foldl(fun update_state/2, State, Recs),
-    State1#state.warning_count > 0 andalso State1#state.log_warnings
-        andalso log_warnings(State1, Query),
-    handle_query_call_reply(Recs, Query, State1, []);
+handle_call({query, Query, FilterMap, Timeout}, _From, State) ->
+    {Reply, State1} = do_query(Query, FilterMap, Timeout, State),
+    {reply, Reply, State1};
 handle_call({param_query, Query, Params, FilterMap, default_timeout}, From,
             State) ->
     handle_call({param_query, Query, Params, FilterMap,
@@ -258,27 +265,8 @@ handle_call({prepare, Query}, _From, State) ->
             {reply, {ok, Id}, State2}
     end;
 handle_call({prepare, Name, Query}, _From, State) when is_atom(Name) ->
-    #state{socket = Socket, sockmod = SockMod} = State,
-    %% First unprepare if there is an old statement with this name.
-    setopts(SockMod, Socket, [{active, false}]),
-    State1 = case dict:find(Name, State#state.stmts) of
-        {ok, OldStmt} ->
-            mysql_protocol:unprepare(OldStmt, SockMod, Socket),
-            State#state{stmts = dict:erase(Name, State#state.stmts)};
-        error ->
-            State
-    end,
-    Rec = mysql_protocol:prepare(Query, SockMod, Socket),
-    setopts(SockMod, Socket, [{active, once}]),
-    State2 = update_state(Rec, State1),
-    case Rec of
-        #error{} = E ->
-            {reply, {error, error_to_reason(E)}, State2};
-        #prepared{} = Stmt ->
-            Stmts1 = dict:store(Name, Stmt, State2#state.stmts),
-            State3 = State2#state{stmts = Stmts1},
-            {reply, {ok, Name}, State3}
-    end;
+    {Reply, State1} = do_named_prepare(Name, Query, State),
+    {reply, Reply, State1};
 handle_call({unprepare, Stmt}, _From, State) when is_atom(Stmt);
                                                   is_integer(Stmt) ->
     case dict:find(Stmt, State#state.stmts) of
@@ -293,11 +281,14 @@ handle_call({unprepare, Stmt}, _From, State) when is_atom(Stmt);
         error ->
             {reply, {error, not_prepared}, State}
     end;
-handle_call({change_user, Username, Password, Database}, From,
+handle_call({change_user, Username, Password, Options}, From,
             State = #state{transaction_levels = []}) ->
     #state{socket = Socket, sockmod = SockMod,
            auth_plugin_data = AuthPluginData,
            server_version = ServerVersion} = State,
+    Database = proplists:get_value(database, Options, undefined),
+    Queries = proplists:get_value(queries, Options, []),
+    Prepares = proplists:get_value(prepare, Options, []),
     setopts(SockMod, Socket, [{active, false}]),
     Result = mysql_protocol:change_user(SockMod, Socket, Username, Password,
                                         AuthPluginData, Database, 
@@ -309,7 +300,14 @@ handle_call({change_user, Username, Password, Database}, From,
     State2 = State1#state{query_cache = empty, stmts = dict:new()},
     case Result of
         #ok{} ->
-            {reply, ok, State2#state{user = Username, password = Password}};
+            State3 = State2#state{user = Username, password = Password},
+            case execute_on_connect(Queries, Prepares, State3) of
+                {ok, State4} ->
+                    {reply, ok, State4};
+                {error, Reason} = E ->
+                    gen_server:reply(From, E),
+                    stop_server(Reason, State3)
+            end;
         #error{} = E ->
             gen_server:reply(From, {error, error_to_reason(E)}),
             stop_server(change_user_failed, State2)
@@ -425,7 +423,7 @@ code_change(_OldVsn, _State, _Extra) ->
 
 %% --- Helpers ---
 
-%% @doc Executes a prepared statement and returns {Reply, NextState}.
+%% @doc Executes a prepared statement and returns {reply, Reply, NextState}.
 execute_stmt(Stmt, Args, FilterMap, Timeout,
              State = #state{socket = Socket, sockmod = SockMod}) ->
     setopts(SockMod, Socket, [{active, false}]),
@@ -446,7 +444,8 @@ execute_stmt(Stmt, Args, FilterMap, Timeout,
     State1 = lists:foldl(fun update_state/2, State, Recs),
     State1#state.warning_count > 0 andalso State1#state.log_warnings
         andalso log_warnings(State1, Stmt#prepared.orig_query),
-    handle_query_call_reply(Recs, Stmt#prepared.orig_query, State1, []).
+    {Reply, State2} = handle_query_call_result(Recs, Stmt#prepared.orig_query, State1, []),
+    {reply, Reply, State2}.
 
 %% @doc Produces a tuple to return as an error reason.
 -spec error_to_reason(#error{}) -> mysql:server_reason().
@@ -472,17 +471,66 @@ update_state(Rec, State) ->
     end,
     schedule_ping(State1).
 
-%% @doc Produces a reply for handle_call/3 for queries and prepared statements.
-handle_query_call_reply([], _Query, State, ResultSetsAcc) ->
-    Reply = case ResultSetsAcc of
-        []                    -> ok;
-        [{ColumnNames, Rows}] -> {ok, ColumnNames, Rows};
-        [_|_]                 -> {ok, lists:reverse(ResultSetsAcc)}
+%% @doc executes an unparameterized query and returns {Reply, NewState}.
+do_query(Query, FilterMap, default_timeout,
+            #state{query_timeout = DefaultTimeout} = State) ->
+    do_query(Query, FilterMap, DefaultTimeout, State);
+do_query(Query, FilterMap, Timeout,
+            #state{sockmod = SockMod, socket = Socket} = State) ->
+    setopts(SockMod, Socket, [{active, false}]),
+    Result = mysql_protocol:query(Query, SockMod, Socket, FilterMap, Timeout),
+    {ok, Recs} = case Result of
+        {error, timeout} when State#state.server_version >= [5, 0, 0] ->
+            kill_query(State),
+            mysql_protocol:fetch_query_response(SockMod, Socket, FilterMap,
+                                                ?cmd_timeout);
+        {error, timeout} ->
+            %% For MySQL 4.x.x there is no way to recover from timeout except
+            %% killing the connection itself.
+            exit(timeout);
+        QueryResult ->
+            QueryResult
     end,
-    {reply, Reply, State};
-handle_query_call_reply([Rec|Recs], Query,
-                        #state{transaction_levels = L} = State,
-                        ResultSetsAcc) ->
+    setopts(SockMod, Socket, [{active, once}]),
+    State1 = lists:foldl(fun update_state/2, State, Recs),
+    State1#state.warning_count > 0 andalso State1#state.log_warnings
+        andalso log_warnings(State1, Query),
+    handle_query_call_result(Recs, Query, State1, []).
+
+%% @doc Prepares a named query and returns {{ok, Name}, NewState} or
+%% {{error, Reason}, NewState}.
+do_named_prepare(Name, Query, State) ->
+    #state{socket = Socket, sockmod = SockMod} = State,
+    %% First unprepare if there is an old statement with this name.
+    setopts(SockMod, Socket, [{active, false}]),
+    State1 = case dict:find(Name, State#state.stmts) of
+        {ok, OldStmt} ->
+            mysql_protocol:unprepare(OldStmt, SockMod, Socket),
+            State#state{stmts = dict:erase(Name, State#state.stmts)};
+        error ->
+            State
+    end,
+    Rec = mysql_protocol:prepare(Query, SockMod, Socket),
+    setopts(SockMod, Socket, [{active, once}]),
+    State2 = update_state(Rec, State1),
+    case Rec of
+        #error{} = E ->
+            {{error, error_to_reason(E)}, State2};
+        #prepared{} = Stmt ->
+            Stmts1 = dict:store(Name, Stmt, State2#state.stmts),
+            State3 = State2#state{stmts = Stmts1},
+            {{ok, Name}, State3}
+    end.
+
+%% @doc Transforms result sets into a structure appropriate to be returned
+%% to the client.
+handle_query_call_result([], _Query, State, []) ->
+    {ok, State};
+handle_query_call_result([], _Query, State, [{ColumnNames, Rows}]) ->
+    {{ok, ColumnNames, Rows}, State};
+handle_query_call_result([], _Query, State, ResultSetsAcc) ->
+    {{ok, lists:reverse(ResultSetsAcc)}, State};
+handle_query_call_result([Rec|Recs], Query, State = #state{transaction_levels = L}, ResultSetsAcc) ->
     case Rec of
         #ok{status = Status} when Status band ?SERVER_STATUS_IN_TRANS == 0,
                                   L /= [] ->
@@ -491,22 +539,22 @@ handle_query_call_reply([Rec|Recs], Query,
             Length = length(L),
             Reply = {implicit_commit, Length, Query},
             [] = demonitor_processes(L, Length),
-            {reply, Reply, State#state{transaction_levels = []}};
+            {Reply, State#state{transaction_levels = []}};
         #ok{} ->
-            handle_query_call_reply(Recs, Query, State, ResultSetsAcc);
+            handle_query_call_result(Recs, Query, State, ResultSetsAcc);
         #resultset{cols = ColDefs, rows = Rows} ->
             Names = [Def#col.name || Def <- ColDefs],
             ResultSetsAcc1 = [{Names, Rows} | ResultSetsAcc],
-            handle_query_call_reply(Recs, Query, State, ResultSetsAcc1);
+            handle_query_call_result(Recs, Query, State, ResultSetsAcc1);
         #error{code = ?ERROR_DEADLOCK} when L /= [] ->
             %% These errors result in an implicit rollback.
             Reply = {implicit_rollback, length(L), error_to_reason(Rec)},
             %% Everything in the transaction is rolled back, except the BEGIN
             %% statement itself. Thus, we are in transaction level 1.
             NewMonitors = demonitor_processes(L, length(L) - 1),
-            {reply, Reply, State#state{transaction_levels = NewMonitors}};
+            {Reply, State#state{transaction_levels = NewMonitors}};
         #error{} ->
-            {reply, {error, error_to_reason(Rec)}, State}
+            {{error, error_to_reason(Rec)}, State}
     end.
 
 %% @doc Schedules (or re-schedules) ping.

+ 5 - 5
test/mysql_change_user_tests.erl

@@ -40,7 +40,7 @@ correct_credentials_test() ->
 incorrect_credentials_fail_test() ->
     Pid = connect_db(?user1, ?password1),
     TrapExit = erlang:process_flag(trap_exit, true),
-    ?assertError({1045, <<"28000">>, <<"Access denied", _/binary>>},
+    ?assertMatch({error, {1045, <<"28000">>, <<"Access denied", _/binary>>}},
                  mysql:change_user(Pid, ?user2, ?password1)),
     ExitReason = receive {'EXIT', Pid, Reason} -> Reason after 1000 -> error(timeout) end,
     erlang:process_flag(trap_exit, TrapExit),
@@ -129,9 +129,9 @@ execute_queries_test() ->
 execute_queries_failure_test() ->
     Pid = connect_db(?user1, ?password1),
     erlang:process_flag(trap_exit, true),
-    ?assertError(_Reason, mysql:change_user(Pid, ?user2, ?password2, [{queries, [<<"foo">>]}])),
+    {error, Reason} = mysql:change_user(Pid, ?user2, ?password2, [{queries, [<<"foo">>]}]),
     receive
-        {'EXIT', Pid, normal} -> ok
+        {'EXIT', Pid, Reason} -> ok
     after 1000 ->
         error(no_exit_message)
     end,
@@ -151,9 +151,9 @@ prepare_statements_test() ->
 prepare_statements_failure_test() ->
     Pid = connect_db(?user1, ?password1),
     erlang:process_flag(trap_exit, true),
-    ?assertError(_Reason, mysql:change_user(Pid, ?user2, ?password2, [{prepare, [{foo, <<"foo">>}]}])),
+    {error, Reason} = mysql:change_user(Pid, ?user2, ?password2, [{prepare, [{foo, <<"foo">>}]}]),
     receive
-        {'EXIT', Pid, normal} -> ok
+        {'EXIT', Pid, Reason} -> ok
     after 1000 ->
         error(no_exit_message)
     end,

+ 8 - 8
test/mysql_tests.erl

@@ -208,23 +208,23 @@ unix_socket_test() ->
     
 connect_queries_failure_test() ->
     process_flag(trap_exit, true),
-    ?assertError(_Reason, mysql:start_link([{user, ?user}, {password, ?password},
-                                            {queries, ["foo"]}])),
+    {error, Reason} = mysql:start_link([{user, ?user}, {password, ?password},
+                                        {queries, ["foo"]}]),
     receive
-        {'EXIT', _Pid, normal} -> ok
+        {'EXIT', _Pid, Reason} -> ok
     after 1000 ->
-        error(no_exit_message)
+        exit(no_exit_message)
     end,
     process_flag(trap_exit, false).
 
 connect_prepare_failure_test() ->
     process_flag(trap_exit, true),
-    ?assertError(_Reason, mysql:start_link([{user, ?user}, {password, ?password},
-                                            {prepare, [{foo, "foo"}]}])),
+    {error, Reason} = mysql:start_link([{user, ?user}, {password, ?password},
+                                        {prepare, [{foo, "foo"}]}]),
     receive
-        {'EXIT', _Pid, normal} -> ok
+        {'EXIT', _Pid, Reason} -> ok
     after 1000 ->
-        error(no_exit_message)
+        exit(no_exit_message)
     end,
     process_flag(trap_exit, false).