Browse Source

Enable connect modes: synchronous (default), asynchronous and lazy

An option to select connect mode is added (#142)
Jan Uhlig 5 years ago
parent
commit
1c8475001a
3 changed files with 198 additions and 53 deletions
  1. 31 6
      src/mysql.erl
  2. 125 45
      src/mysql_conn.erl
  3. 42 2
      test/mysql_tests.erl

+ 31 - 6
src/mysql.erl

@@ -26,6 +26,7 @@
 -module(mysql).
 
 -export([start_link/1, stop/1, stop/2,
+         is_connected/1,
          query/2, query/3, query/4, query/5,
          execute/3, execute/4, execute/5,
          prepare/2, prepare/3, unprepare/2,
@@ -61,8 +62,6 @@
                       | {ok, [{column_names(), rows()}, ...]}
                       | {error, server_reason()}.
 
--define(default_connect_timeout, 5000).
-
 -include("exception.hrl").
 
 %% @doc Starts a connection gen_server process and connects to a database. To
@@ -88,6 +87,28 @@
 %%   <dt>`{database, Database}'</dt>
 %%   <dd>The name of the database AKA schema to use. This can be changed later
 %%       using the query `USE <database>'.</dd>
+%%   <dt>`{connect_mode, synchronous | asynchronous | lazy}'</dt>
+%%   <dd>Specifies how and when the connection process should establish a connection
+%%       to the MySQL server.
+%%       <dl>
+%%         <dt>`synchronus' (default)</dt>
+%%         <dd>The connection will be established as part of the connection process'
+%%             start routine, ie the returned connection process will already be
+%%             connected and ready to use, and any on-connect prepares and queries
+%%             will have been executed.</dd>
+%%         <dt>`asynchronous'</dt>
+%%         <dd>The connection process will be started and returned to the caller
+%%             before really establishing a connection to the server and executing
+%%             the on-connect prepares and executes. This will instead be done
+%%             immediately afterwards as the first action of the connection
+%%             process.</dd>
+%%         <dt>`lazy'</dt>
+%%         <dd>Similar to `asynchronous' mode, but an actual connection will be
+%%             esatblished and the on-connect prepares and queries executed only
+%%             when a connection is needed for the first time, eg. to execute a
+%%             query.</dd>
+%%      </dl>
+%%   </dd>
 %%   <dt>`{connect_timeout, Timeout}'</dt>
 %%   <dd>The maximum time to spend for start_link/1.</dd>
 %%   <dt>`{log_warnings, boolean()}'</dt>
@@ -130,6 +151,7 @@
                    {host, inet:socket_address() | inet:hostname()} | {port, integer()} |
                    {user, iodata()} | {password, iodata()} |
                    {database, iodata()} |
+                   {connect_mode, synchronous | asynchronous | lazy} |
                    {connect_timeout, timeout()} |
                    {log_warnings, boolean()} |
                    {keep_alive, boolean() | timeout()} |
@@ -145,13 +167,11 @@
                        {via, Module :: atom(), ViaName :: term()},
          NamedStatements :: [{StatementName :: atom(), Statement :: iodata()}].
 start_link(Options) ->
-    GenSrvOpts = [{timeout, proplists:get_value(connect_timeout, Options,
-                                                ?default_connect_timeout)}],
     case proplists:get_value(name, Options) of
         undefined ->
-            gen_server:start_link(mysql_conn, Options, GenSrvOpts);
+            gen_server:start_link(mysql_conn, Options, []);
         ServerName ->
-            gen_server:start_link(ServerName, mysql_conn, Options, GenSrvOpts)
+            gen_server:start_link(ServerName, mysql_conn, Options, [])
     end.
 
 %% @see stop/2.
@@ -198,6 +218,11 @@ backported_gen_server_stop(Conn, Reason, Timeout) ->
         end
     end.
 
+%% @private
+-spec is_connected(Conn) -> boolean()
+    when Conn :: connection().
+is_connected(Conn) ->
+    gen_server:call(Conn, is_connected).
 
 %% @see query/5.
 -spec query(Conn, Query) -> Result

+ 125 - 45
src/mysql_conn.erl

@@ -34,6 +34,7 @@
 -define(default_port, 3306).
 -define(default_user, <<>>).
 -define(default_password, <<>>).
+-define(default_connect_timeout, 5000).
 -define(default_query_timeout, infinity).
 -define(default_query_cache_time, 60000). %% for query/3.
 -define(default_ping_timeout, 60000).
@@ -49,10 +50,9 @@
 -include("server_status.hrl").
 
 %% Gen_server state
--record(state, {server_version, connection_id, socket, sockmod, ssl_opts,
-                host, port, user, password, auth_plugin_data, log_warnings,
-                ping_timeout,
-                query_timeout, query_cache_time,
+-record(state, {server_version, connection_id, socket, sockmod, tcp_opts, ssl_opts,
+                host, port, user, password, database, queries, prepares, auth_plugin_data,
+                log_warnings, connect_timeout, ping_timeout, query_timeout, query_cache_time,
                 affected_rows = 0, status = 0, warning_count = 0, insert_id = 0,
                 transaction_levels = [], ping_ref = undefined,
                 stmts = dict:new(), query_cache = empty, cap_found_rows = false}).
@@ -73,14 +73,15 @@ init(Opts) ->
     Database       = proplists:get_value(database, Opts, undefined),
     LogWarn        = proplists:get_value(log_warnings, Opts, true),
     KeepAlive      = proplists:get_value(keep_alive, Opts, false),
-    Timeout        = proplists:get_value(query_timeout, Opts,
+    ConnectTimeout = proplists:get_value(connect_timeout, Opts,
+                                         ?default_connect_timeout),
+    QueryTimeout   = proplists:get_value(query_timeout, Opts,
                                          ?default_query_timeout),
     QueryCacheTime = proplists:get_value(query_cache_time, Opts,
                                          ?default_query_cache_time),
     TcpOpts        = proplists:get_value(tcp_options, Opts, []),
     SetFoundRows   = proplists:get_value(found_rows, Opts, false),
     SSLOpts        = proplists:get_value(ssl, Opts, undefined),
-    SockMod0       = gen_tcp,
 
     Queries        = proplists:get_value(queries, Opts, []),
     Prepares       = proplists:get_value(prepare, Opts, []),
@@ -91,10 +92,66 @@ init(Opts) ->
         N when N > 0 -> N
     end,
 
+    State0 = #state{
+        tcp_opts = TcpOpts,
+        ssl_opts = SSLOpts,
+        host = Host, port = Port,
+        user = User, password = Password,
+        database = Database,
+        queries = Queries, prepares = Prepares,
+        log_warnings = LogWarn,
+        connect_timeout = ConnectTimeout,
+        ping_timeout = PingTimeout,
+        query_timeout = QueryTimeout,
+        query_cache_time = QueryCacheTime,
+        cap_found_rows = (SetFoundRows =:= true)
+    },
+
+    case proplists:get_value(connect_mode, Opts, synchronous) of
+        synchronous ->
+            case connect(State0) of
+                {ok, State1} ->
+                    {ok, State1};
+                {error, Reason} ->
+                    {stop, Reason}
+            end;
+        asynchronous ->
+            gen_server:cast(self(), connect),
+            {ok, State0};
+        lazy ->
+            {ok, State0}
+    end.
+
+connect(#state{connect_timeout = ConnectTimeout} = State) ->
+    MainPid = self(),
+    Pid = spawn_link(
+        fun () ->
+            {ok, State1}=connect_socket(State),
+            case handshake(State1) of
+                {ok, #state{sockmod = SockMod, socket = Socket} = State2} ->
+                    SockMod:controlling_process(Socket, MainPid),
+                    MainPid ! {self(), {ok, State2}};
+                {error, _} = E ->
+                    MainPid ! {self(), E}
+            end
+        end
+    ),
+    receive
+        {Pid, {ok, State3}} ->
+            post_connect(State3);
+        {Pid, {error, _} = E} ->
+            E
+    after ConnectTimeout ->
+        unlink(Pid),
+        exit(Pid, kill),
+        {error, timeout}
+    end.
+
+connect_socket(#state{tcp_opts = TcpOpts, host = Host, port = Port} = State) ->
     %% Connect socket
     SockOpts = [binary, {packet, raw}, {active, false}, {nodelay, true}
                 | TcpOpts],
-    {ok, Socket0} = SockMod0:connect(Host, Port, SockOpts),
+    {ok, Socket} = gen_tcp:connect(Host, Port, SockOpts),
 
     %% If buffer wasn't specifically defined make it at least as
     %% large as recbuf, as suggested by the inet:setopts() docs.
@@ -102,13 +159,18 @@ init(Opts) ->
         true ->
             ok;
         false ->
-            {ok, [{buffer, Buffer}]} = inet:getopts(Socket0, [buffer]),
-            {ok, [{recbuf, Recbuf}]} = inet:getopts(Socket0, [recbuf]),
-            ok = inet:setopts(Socket0,[{buffer, max(Buffer, Recbuf)}])
+            {ok, [{buffer, Buffer}]} = inet:getopts(Socket, [buffer]),
+            {ok, [{recbuf, Recbuf}]} = inet:getopts(Socket, [recbuf]),
+            ok = inet:setopts(Socket, [{buffer, max(Buffer, Recbuf)}])
     end,
 
+    {ok, State#state{socket = Socket}}.
+
+handshake(#state{socket = Socket0, ssl_opts = SSLOpts,
+        user = User, password = Password, database = Database,
+        cap_found_rows = SetFoundRows} = State0) ->
     %% Exchange handshake communication.
-    Result = mysql_protocol:handshake(User, Password, Database, SockMod0, SSLOpts,
+    Result = mysql_protocol:handshake(User, Password, Database, gen_tcp, SSLOpts,
                                       Socket0, SetFoundRows),
     case Result of
         {ok, Handshake, SockMod, Socket} ->
@@ -116,42 +178,37 @@ init(Opts) ->
             #handshake{server_version = Version, connection_id = ConnId,
                        status = Status,
                        auth_plugin_data = AuthPluginData} = Handshake,
-            State = #state{server_version = Version, connection_id = ConnId,
+            State1 = State0#state{server_version = Version, connection_id = ConnId,
                            sockmod = SockMod,
                            socket = Socket,
-                           ssl_opts = SSLOpts,
-                           host = Host, port = Port,
-                           user = User, password = Password,
                            auth_plugin_data = AuthPluginData,
-                           status = Status,
-                           log_warnings = LogWarn,
-                           ping_timeout = PingTimeout,
-                           query_timeout = Timeout,
-                           query_cache_time = QueryCacheTime,
-                           cap_found_rows = (SetFoundRows =:= true)},
-            case execute_on_connect(Queries, Prepares, State) of
-                {ok, State1} ->
-                    process_flag(trap_exit, true),
-                    State2 = schedule_ping(State1),
-                    {ok, State2};
-                {error, Reason} ->
-                    {stop, Reason}
-            end;
+                           status = Status},
+            {ok, State1};
         #error{} = E ->
-            {stop, error_to_reason(E)}
+            {error, error_to_reason(E)}
+    end.
+
+post_connect(#state{queries = Queries, prepares = Prepares} = State) ->
+    case execute_on_connect(Queries, Prepares, State) of
+        {ok, State1} ->
+            process_flag(trap_exit, true),
+            State2 = schedule_ping(State1),
+            {ok, State2};
+        {error, _} = E ->
+            E
     end.
 
 execute_on_connect([], [], State) ->
     {ok, State};
 execute_on_connect([], [{Name, Stmt}|Prepares], State) ->
-    case do_named_prepare(Name, Stmt, State) of
+    case named_prepare(Name, Stmt, State) of
         {{ok, Name}, State1} ->
             execute_on_connect([], Prepares, State1);
         {{error, _} = E, _} ->
             E
     end;
 execute_on_connect([Query|Queries], Prepares, State) ->
-    case do_query(Query, no_filtermap_fun, default_timeout, State) of
+    case query(Query, no_filtermap_fun, default_timeout, State) of
         {ok, State1} ->
             execute_on_connect(Queries, Prepares, State1);
         {{ok, _}, State1} ->
@@ -200,8 +257,17 @@ execute_on_connect([Query|Queries], Prepares, State) ->
 %%       able to handle this in the caller's process, we also return the
 %%       nesting level.</dd>
 %% </dl>
+handle_call(is_connected, _, #state{socket = Socket} = State) ->
+    {reply, Socket =/= undefined, State};
+handle_call(Msg, From, #state{socket = undefined} = State) ->
+    case connect(State) of
+        {ok, State1} ->
+            handle_call(Msg, From, State1);
+        {error, _} = E ->
+            {stop, E, State}
+    end;
 handle_call({query, Query, FilterMap, Timeout}, _From, State) ->
-    {Reply, State1} = do_query(Query, FilterMap, Timeout, State),
+    {Reply, State1} = query(Query, FilterMap, Timeout, State),
     {reply, Reply, State1};
 handle_call({param_query, Query, Params, FilterMap, default_timeout}, From,
             State) ->
@@ -236,7 +302,8 @@ handle_call({param_query, Query, Params, FilterMap, Timeout}, _From,
     case StmtResult of
         {ok, StmtRec} ->
             State1 = State#state{query_cache = Cache1},
-            execute_stmt(StmtRec, Params, FilterMap, Timeout, State1);
+            {Reply, State2} = execute_stmt(StmtRec, Params, FilterMap, Timeout, State1),
+            {reply, Reply, State2};
         PrepareError ->
             {reply, PrepareError, State}
     end;
@@ -246,7 +313,8 @@ handle_call({execute, Stmt, Args, FilterMap, default_timeout}, From, State) ->
 handle_call({execute, Stmt, Args, FilterMap, Timeout}, _From, State) ->
     case dict:find(Stmt, State#state.stmts) of
         {ok, StmtRec} ->
-            execute_stmt(StmtRec, Args, FilterMap, Timeout, State);
+            {Reply, State1} = execute_stmt(StmtRec, Args, FilterMap, Timeout, State),
+            {reply, Reply, State1};
         error ->
             {reply, {error, not_prepared}, State}
     end;
@@ -265,7 +333,7 @@ handle_call({prepare, Query}, _From, State) ->
             {reply, {ok, Id}, State2}
     end;
 handle_call({prepare, Name, Query}, _From, State) when is_atom(Name) ->
-    {Reply, State1} = do_named_prepare(Name, Query, State),
+    {Reply, State1} = named_prepare(Name, Query, State),
     {reply, Reply, State1};
 handle_call({unprepare, Stmt}, _From, State) when is_atom(Stmt);
                                                   is_integer(Stmt) ->
@@ -300,8 +368,10 @@ handle_call({change_user, Username, Password, Options}, From,
     State2 = State1#state{query_cache = empty, stmts = dict:new()},
     case Result of
         #ok{} ->
-            State3 = State2#state{user = Username, password = Password},
-            case execute_on_connect(Queries, Prepares, State3) of
+            State3 = State2#state{user = Username, password = Password,
+                                  database=Database, queries=Queries,
+                                  prepares=Prepares},
+            case post_connect(State3) of
                 {ok, State4} ->
                     {reply, ok, State4};
                 {error, Reason} = E ->
@@ -386,6 +456,15 @@ handle_call(commit, {FromPid, _},
     {reply, ok, State1#state{transaction_levels = L}}.
 
 %% @private
+handle_cast(connect, #state{socket = undefined} = State) ->
+    case connect(State) of
+        {ok, State1} ->
+            {noreply, State1};
+        {error, _} = E ->
+            {stop, E, State}
+    end;
+handle_cast(connect, State) ->
+    {noreply, State};
 handle_cast(_Msg, State) ->
     {noreply, State}.
 
@@ -437,7 +516,7 @@ code_change(_OldVsn, _State, _Extra) ->
 
 %% --- Helpers ---
 
-%% @doc Executes a prepared statement and returns {reply, Reply, NextState}.
+%% @doc Executes a prepared statement and returns {Reply, NewState}.
 execute_stmt(Stmt, Args, FilterMap, Timeout,
              State = #state{socket = Socket, sockmod = SockMod}) ->
     setopts(SockMod, Socket, [{active, false}]),
@@ -458,8 +537,7 @@ execute_stmt(Stmt, Args, FilterMap, Timeout,
     State1 = lists:foldl(fun update_state/2, State, Recs),
     State1#state.warning_count > 0 andalso State1#state.log_warnings
         andalso log_warnings(State1, Stmt#prepared.orig_query),
-    {Reply, State2} = handle_query_call_result(Recs, Stmt#prepared.orig_query, State1, []),
-    {reply, Reply, State2}.
+    handle_query_call_result(Recs, Stmt#prepared.orig_query, State1, []).
 
 %% @doc Produces a tuple to return as an error reason.
 -spec error_to_reason(#error{}) -> mysql:server_reason().
@@ -486,10 +564,10 @@ update_state(Rec, State) ->
     schedule_ping(State1).
 
 %% @doc executes an unparameterized query and returns {Reply, NewState}.
-do_query(Query, FilterMap, default_timeout,
+query(Query, FilterMap, default_timeout,
             #state{query_timeout = DefaultTimeout} = State) ->
-    do_query(Query, FilterMap, DefaultTimeout, State);
-do_query(Query, FilterMap, Timeout,
+    query(Query, FilterMap, DefaultTimeout, State);
+query(Query, FilterMap, Timeout,
             #state{sockmod = SockMod, socket = Socket} = State) ->
     setopts(SockMod, Socket, [{active, false}]),
     Result = mysql_protocol:query(Query, SockMod, Socket, FilterMap, Timeout),
@@ -513,7 +591,7 @@ do_query(Query, FilterMap, Timeout,
 
 %% @doc Prepares a named query and returns {{ok, Name}, NewState} or
 %% {{error, Reason}, NewState}.
-do_named_prepare(Name, Query, State) ->
+named_prepare(Name, Query, State) ->
     #state{socket = Socket, sockmod = SockMod} = State,
     %% First unprepare if there is an old statement with this name.
     setopts(SockMod, Socket, [{active, false}]),
@@ -613,6 +691,8 @@ kill_query(#state{connection_id = ConnId, host = Host, port = Port,
                                    [error_to_reason(E)])
     end.
 
+stop_server(Reason, #state{socket = undefined} = State) ->
+  {stop, Reason, State};
 stop_server(Reason,
             #state{socket = Socket, connection_id = ConnId} = State) ->
   error_logger:error_msg("Connection Id ~p closing with reason: ~p~n",

+ 42 - 2
test/mysql_tests.erl

@@ -45,6 +45,46 @@
                           "  c CHAR(2)"
                           ") ENGINE=InnoDB">>).
 
+connect_synchronous_test() ->
+    {ok, Pid} = mysql:start_link([{user, ?user}, {password, ?password},
+                                  {connect_mode, synchronous}]),
+    ?assert(mysql:is_connected(Pid)),
+    mysql:stop(Pid),
+    ok.
+
+connect_asynchronous_successful_test() ->
+    {ok, Pid} = mysql:start_link([{user, ?user}, {password, ?password},
+                                  {connect_mode, asynchronous}]),
+    ?assert(mysql:is_connected(Pid)),
+    mysql:stop(Pid),
+    ok.
+
+connect_asynchronous_failing_test() ->
+    process_flag(trap_exit, true),
+    {ok, Ret, _Logged} = error_logger_acc:capture(
+        fun () ->
+            {ok, Pid} = mysql:start_link([{user, "dummy"}, {password, "junk"},
+                                          {connect_mode, asynchronous}]),
+            receive
+                {'EXIT', Pid, {error, {1045, <<"28000">>, _}}} -> ok
+            after 1000 ->
+                error(no_exit_message)
+            end
+        end
+    ),
+    ?assertEqual(ok, Ret),
+    process_flag(trap_exit, false),
+    ok.
+
+connect_lazy_test() ->
+    {ok, Pid} = mysql:start_link([{user, ?user}, {password, ?password},
+                                  {connect_mode, lazy}]),
+    ?assertNot(mysql:is_connected(Pid)),
+    {ok, [<<"1">>], [[1]]} = mysql:query(Pid, <<"SELECT 1">>),
+    ?assert(mysql:is_connected(Pid)),
+    mysql:stop(Pid),
+    ok.
+
 failing_connect_test() ->
     process_flag(trap_exit, true),
     {ok, Ret, Logged} = error_logger_acc:capture(
@@ -120,7 +160,7 @@ server_disconnect_test() ->
     process_flag(trap_exit, true),
     Options = [{user, ?user}, {password, ?password}],
     {ok, Pid} = mysql:start_link(Options),
-    {ok, ok, LoggedErrors} = error_logger_acc:capture(fun () ->
+    {ok, ok, _LoggedErrors} = error_logger_acc:capture(fun () ->
         %% Make the server close the connection after 1 second of inactivity.
         ok = mysql:query(Pid, <<"SET SESSION wait_timeout = 1">>),
         receive
@@ -162,7 +202,7 @@ keep_alive_test() ->
      receive after 70 -> ok end,
      State = get_state(Pid),
      [state, _Version, _ConnectionId, Socket | _] = tuple_to_list(State),
-     {ok, ExitMessage, LoggedErrors} = error_logger_acc:capture(fun () ->
+     {ok, ExitMessage, _LoggedErrors} = error_logger_acc:capture(fun () ->
          gen_tcp:close(Socket),
          receive
             Message -> Message