Browse Source

Parametrized queries using cached prep. stmts

Viktor Söderqvist 10 years ago
parent
commit
db2191db31
3 changed files with 228 additions and 19 deletions
  1. 99 18
      src/mysql.erl
  2. 119 0
      src/mysql_cache.erl
  3. 10 1
      test/mysql_tests.erl

+ 99 - 18
src/mysql.erl

@@ -23,7 +23,8 @@
 %% gen_server is locally registered.
 -module(mysql).
 
--export([start_link/1, query/2, execute/3, prepare/2, prepare/3, unprepare/2,
+-export([start_link/1, query/2, query/3, execute/3,
+         prepare/2, prepare/3, unprepare/2,
          warning_count/1, affected_rows/1, autocommit/1, insert_id/1,
          in_transaction/1,
          transaction/2, transaction/3]).
@@ -39,6 +40,7 @@
 -define(default_user, <<>>).
 -define(default_password, <<>>).
 -define(default_timeout, infinity).
+-define(default_query_cache_time, 60000). %% Default for the query_cache_time option, in milliseconds; used by query/3.
 
 %% A connection is a ServerRef as in gen_server:call/2,3.
 -type connection() :: Name :: atom() |
@@ -72,12 +74,16 @@
 %%   <dt>`{database, Database}'</dt>
 %%   <dd>The name of the database AKA schema to use. This can be changed later
 %%       using the query `USE <database>'.</dd>
+%%   <dt>`{query_cache_time, Timeout}'</dt>
+%%   <dd>The minimum number of milliseconds to cache prepared statements used
+%%       for parametrized queries with query/3.</dd>
 %% </dl>
 -spec start_link(Options) -> {ok, pid()} | ignore | {error, term()}
     when Options :: [Option],
          Option :: {name, ServerName} | {host, iodata()} | {port, integer()} | 
                    {user, iodata()} | {password, iodata()} |
-                   {database, iodata()},
+                   {database, iodata()} |
+                   {query_cache_time, non_neg_integer()},
          ServerName :: {local, Name :: atom()} |
                        {global, GlobalName :: term()} |
                        {via, Module :: atom(), ViaName :: term()}.
@@ -99,6 +105,23 @@ start_link(Options) ->
 query(Conn, Query) ->
     gen_server:call(Conn, {query, Query}).
 
+%% @doc Runs a query with parameters. The query is turned into a prepared
+%% statement, which is executed with `Params' and then kept in a cache for
+%% some time. Running the same query again while it is still cached skips the
+%% prepare step.
+%%
+%% How long (at least) a prepared statement stays in the cache is controlled
+%% by the option `{query_cache_time, Milliseconds}' to start_link/1.
+-spec query(Conn, Query, Params) -> ok | {ok, ColumnNames, Rows} |
+                                    {error, Reason}
+    when Conn :: connection(),
+         Query :: iodata(),
+         Params :: [term()],
+         ColumnNames :: [binary()],
+         Rows :: [[term()]],
+         Reason :: server_reason().
+query(Conn, Query, Params) when is_list(Params) ->
+    Request = {query, Query, Params},
+    gen_server:call(Conn, Request).
+
+
 %% @doc Executes a prepared statement.
 %% @see prepare/2
 %% @see prepare/3
@@ -249,7 +272,8 @@ transaction(Conn, Fun, Args) when is_list(Args),
 
 %% Gen_server state
 -record(state, {socket, timeout = infinity, affected_rows = 0, status = 0,
-                warning_count = 0, insert_id = 0, stmts = dict:new()}).
+                warning_count = 0, insert_id = 0, stmts = dict:new(),
+                query_cache_time, query_cache = empty}).
 
 %% @private
 init(Opts) ->
@@ -260,6 +284,8 @@ init(Opts) ->
     Password = proplists:get_value(password, Opts, ?default_password),
     Database = proplists:get_value(database, Opts, undefined),
     Timeout  = proplists:get_value(timeout,  Opts, ?default_timeout),
+    QueryCacheTime = proplists:get_value(query_cache_time, Opts,
+                                         ?default_query_cache_time),
 
     %% Connect socket
     SockOpts = [{active, false}, binary, {packet, raw}],
@@ -272,7 +298,8 @@ init(Opts) ->
                                       RecvFun),
     case Result of
         #ok{} = OK ->
-            State = #state{socket = Socket, timeout = Timeout},
+            State = #state{socket = Socket, timeout = Timeout,
+                           query_cache_time = QueryCacheTime},
             State1 = update_state(State, OK),
             %% Trap exit so that we can properly disconnect when we die.
             process_flag(trap_exit, true),
@@ -298,24 +325,45 @@ handle_call({query, Query}, _From, State) when is_binary(Query);
             Names = [Def#col.name || Def <- ColDefs],
             {reply, {ok, Names, Rows}, State1}
     end;
+handle_call({query, Query, Params}, _From, State) when is_list(Params) ->
+    %% A parametrized query is executed as an anonymous prepared statement.
+    %% The statement is kept in the query cache under the query text (as a
+    %% binary), so repeated queries do not have to be prepared again.
+    Key = iolist_to_binary(Query),
+    #state{socket = Socket, timeout = Timeout, query_cache = Cache} = State,
+    case mysql_cache:lookup(Key, Cache) of
+        {found, CachedStmt, TouchedCache} ->
+            %% Cache hit: the lookup has refreshed the entry's 'last used' time.
+            execute_stmt(CachedStmt, Params,
+                         State#state{query_cache = TouchedCache});
+        not_found ->
+            %% Cache miss: prepare the statement on the server first.
+            SendFun = fun (Data) -> gen_tcp:send(Socket, Data) end,
+            RecvFun = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
+            case mysql_protocol:prepare(Query, SendFun, RecvFun) of
+                #error{} = E ->
+                    %% Nothing was prepared, so the cache is left untouched.
+                    {reply, {error, error_to_reason(E)}, State};
+                #prepared{} = Stmt ->
+                    %% The eviction timer only runs while the cache is
+                    %% nonempty, so the first entry has to start it.
+                    Cache == empty andalso
+                        erlang:send_after(State#state.query_cache_time * 2,
+                                          self(), query_cache),
+                    FilledCache = mysql_cache:store(Key, Stmt, Cache),
+                    execute_stmt(Stmt, Params,
+                                 State#state{query_cache = FilledCache})
+            end
+    end;
 handle_call({execute, Stmt, Args}, _From, State) when is_atom(Stmt);
                                                       is_integer(Stmt) ->
     case dict:find(Stmt, State#state.stmts) of
         {ok, StmtRec} ->
-            #state{socket = Socket, timeout = Timeout} = State,
-            SendFun = fun (Data) -> gen_tcp:send(Socket, Data) end,
-            RecvFun = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
-            Rec = mysql_protocol:execute(StmtRec, Args, SendFun, RecvFun),
-            State1 = update_state(State, Rec),
-            case Rec of
-                #ok{} ->
-                    {reply, ok, State1};
-                #error{} = E ->
-                    {reply, {error, error_to_reason(E)}, State1};
-                #resultset{cols = ColDefs, rows = Rows} ->
-                    Names = [Def#col.name || Def <- ColDefs],
-                    {reply, {ok, Names, Rows}, State1}
-            end;
+            execute_stmt(StmtRec, Args, State);
         error ->
             {reply, {error, not_prepared}, State}
     end;
@@ -384,6 +432,22 @@ handle_cast(_Msg, State) ->
     {noreply, State}.
 
 %% @private
+handle_info(query_cache, State = #state{query_cache = OldCache,
+                                        query_cache_time = MaxAge}) ->
+    %% Timer tick for the cache used by query/3: drop the statements that have
+    %% not been used for MaxAge milliseconds.
+    {Expired, KeptCache} = mysql_cache:evict_older_than(OldCache, MaxAge),
+    %% Each dropped statement is also unprepared on the server.
+    #state{socket = Socket, timeout = Timeout} = State,
+    Send = fun (Data) -> gen_tcp:send(Socket, Data) end,
+    Recv = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
+    Unprepare = fun ({_Query, Stmt}) ->
+                    mysql_protocol:unprepare(Stmt, Send, Recv)
+                end,
+    lists:foreach(Unprepare, Expired),
+    %% Keep ticking only while there is something left to evict.
+    case mysql_cache:size(KeptCache) of
+        0 -> ok;
+        _ -> erlang:send_after(MaxAge, self(), query_cache)
+    end,
+    {noreply, State#state{query_cache = KeptCache}};
 handle_info(_Info, State) ->
     {noreply, State}.
 
@@ -405,6 +469,23 @@ code_change(_OldVsn, _State, _Extra) ->
 
 %% --- Helpers ---
 
+%% @doc Executes the prepared statement `StmtRec' with the arguments `Args'
+%% and returns a `{reply, Reply, NewState}' tuple, i.e. a tuple of the same
+%% form as the return value of handle_call/3.
+execute_stmt(StmtRec, Args, State) ->
+    #state{socket = Socket, timeout = Timeout} = State,
+    Send = fun (Data) -> gen_tcp:send(Socket, Data) end,
+    Recv = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
+    Result = mysql_protocol:execute(StmtRec, Args, Send, Recv),
+    NewState = update_state(State, Result),
+    Reply = case Result of
+        #ok{} ->
+            ok;
+        #error{} = Err ->
+            {error, error_to_reason(Err)};
+        #resultset{cols = Cols, rows = Rows} ->
+            {ok, [Col#col.name || Col <- Cols], Rows}
+    end,
+    {reply, Reply, NewState}.
+
+
 %% @doc Produces a tuple to return as an error reason.
 -spec error_to_reason(#error{}) -> server_reason().
 error_to_reason(#error{code = Code, state = State, msg = Msg}) ->

+ 119 - 0
src/mysql_cache.erl

@@ -0,0 +1,119 @@
+%% Minicache. Feel free to rename this module and include it in other projects.
+%%-----------------------------------------------------------------------------
+%% Copyright 2014 Viktor Söderqvist
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+
+%% @doc A minimalistic time triggered dict based cache data structure.
+%%
+%% The cache keeps track of when each key was last used. Elements are evicted
+%% using manual calls to evict_older_than/2. Most of the functions return a new
+%% updated cache object which should be used in subsequent calls.
+%%
+%% A cache can be initialized to 'empty' which represents the empty cache.
+%% @private
+-module(mysql_cache).
+
+-export_type([cache/2]).
+-export([evict_older_than/2, lookup/2, new/0, size/1, store/3]).
+
+-type cache(K, V) ::
+    {cache, erlang:timestamp(), dict:dict(K, {V, non_neg_integer()})} | empty.
+
+%% @doc Deletes the entries that have not been used for `MaxAge' milliseconds
+%% and returns them along with the new state.
+%%
+%% Entries that are kept retain their 'last used' time, so the returned cache
+%% can be passed to lookup/2, store/3 and evict_older_than/2 again. If no
+%% entries remain, the returned cache is the atom `empty'.
+-spec evict_older_than(Cache :: cache(K, V), MaxAge :: non_neg_integer()) ->
+    {Evicted :: [{K, V}], NewCache :: cache(K, V)}.
+evict_older_than({cache, StartTs, Dict}, MaxAge) ->
+    %% 'Last used' times are stored as milliseconds elapsed since StartTs, so
+    %% MinTime is expressed on the same scale.
+    MinTime = timer:now_diff(os:timestamp(), StartTs) div 1000 - MaxAge,
+    {Evicted, Dict1} = dict:fold(
+        fun (Key, {Value, Time}, {EvictedAcc, DictAcc}) when Time =< MinTime ->
+                {[{Key, Value} | EvictedAcc], DictAcc};
+            (Key, {_Value, _Time} = Entry, {EvictedAcc, DictAcc}) ->
+                %% Keep the whole {Value, Time} entry. Storing only the value
+                %% would corrupt the entry for later lookups and evictions.
+                {EvictedAcc, dict:store(Key, Entry, DictAcc)}
+        end,
+        {[], dict:new()},
+        Dict),
+    Cache1 = case dict:is_empty(Dict1) of
+        true  -> empty;
+        false -> {cache, StartTs, Dict1}
+    end,
+    {Evicted, Cache1};
+evict_older_than(empty, _) ->
+    {[], empty}.
+
+%% @doc Fetches the value stored under `Key'. On a hit, the value is returned
+%% together with a cache in which the key's 'last used' time has been set to
+%% the current time. On a miss, `not_found' is returned.
+-spec lookup(Key :: K, Cache :: cache(K, V)) ->
+    {found, Value :: V, UpdatedCache :: cache(K, V)} | not_found.
+lookup(_Key, empty) ->
+    not_found;
+lookup(Key, {cache, StartTs, Entries}) ->
+    case dict:find(Key, Entries) of
+        error ->
+            not_found;
+        {ok, {Value, _LastUsed}} ->
+            %% Times are kept as milliseconds elapsed since StartTs.
+            ElapsedMs = timer:now_diff(os:timestamp(), StartTs) div 1000,
+            Touched = dict:store(Key, {Value, ElapsedMs}, Entries),
+            {found, Value, {cache, StartTs, Touched}}
+    end.
+
+%% @doc Creates a new cache. This is the atom `empty', which represents a cache without entries.
+-spec new() -> cache(K :: term(), V :: term()).
+new() ->
+    empty.
+
+%% @doc Counts the entries currently held by the cache. The empty cache has
+%% size 0.
+-spec size(cache(K :: term(), V :: term())) -> non_neg_integer().
+size(empty) ->
+    0;
+size({cache, _StartTs, Entries}) ->
+    dict:size(Entries).
+
+%% @doc Puts `Value' in the cache under `Key' and marks it as used just now.
+%% A value already stored under the same key is overwritten. Storing into the
+%% `empty' cache creates a cache whose clock starts at the current time.
+-spec store(Key :: K, Value :: V, Cache :: cache(K, V)) -> cache(K, V)
+    when K :: term(), V :: term().
+store(Key, Value, empty) ->
+    %% The first entry is by definition used at time 0 of the new cache.
+    StartTs = os:timestamp(),
+    Entries = dict:store(Key, {Value, 0}, dict:new()),
+    {cache, StartTs, Entries};
+store(Key, Value, {cache, StartTs, Entries}) ->
+    ElapsedMs = timer:now_diff(os:timestamp(), StartTs) div 1000,
+    Entries1 = dict:store(Key, {Value, ElapsedMs}, Entries),
+    {cache, StartTs, Entries1}.
+
+-ifdef(TEST).
+-include_lib("eunit/include/eunit.hrl").
+
+empty_test() ->
+    %% A new cache is the atom `empty'; every operation accepts it.
+    Empty = ?MODULE:new(),
+    ?assertEqual(empty, Empty),
+    ?assertEqual(0, ?MODULE:size(Empty)),
+    ?assertEqual(not_found, ?MODULE:lookup(foo, Empty)),
+    ?assertMatch({[], empty}, ?MODULE:evict_older_than(Empty, 10)).
+
+nonempty_test() ->
+    Stored = ?MODULE:store(foo, bar, empty),
+    %% Right after storing, the entry is found and too young to be evicted.
+    ?assertMatch({found, bar, _}, ?MODULE:lookup(foo, Stored)),
+    ?assertMatch(not_found, ?MODULE:lookup(baz, Stored)),
+    ?assertMatch({[], _}, ?MODULE:evict_older_than(Stored, 10)),
+    ?assertMatch({cache, _, _}, Stored),
+    ?assertEqual(1, ?MODULE:size(Stored)),
+    %% Wait until the entry is older than the 10 ms used below.
+    receive after 11 -> ok end,
+    ?assertEqual({[{foo, bar}], empty}, ?MODULE:evict_older_than(Stored, 10)),
+    %% A lookup refreshes the entry, so it is no longer expired.
+    {found, bar, Touched} = ?MODULE:lookup(foo, Stored),
+    ?assertMatch({[], {cache, _, _}}, ?MODULE:evict_older_than(Touched, 10)),
+    %% Storing the key again refreshes it as well.
+    Restored = ?MODULE:store(foo, baz, Stored),
+    ?assertMatch({[], {cache, _, _}}, ?MODULE:evict_older_than(Restored, 10)).
+
+-endif.

+ 10 - 1
test/mysql_tests.erl

@@ -351,7 +351,8 @@ write_read_text_binary(Conn, Term, SqlLiteral, Table, Column) ->
 with_table_foo_test_() ->
     {setup,
      fun () ->
-         {ok, Pid} = mysql:start_link([{user, ?user}, {password, ?password}]),
+         {ok, Pid} = mysql:start_link([{user, ?user}, {password, ?password},
+                                       {query_cache_time, 50}]),
          ok = mysql:query(Pid, <<"DROP DATABASE IF EXISTS otptest">>),
          ok = mysql:query(Pid, <<"CREATE DATABASE otptest">>),
          ok = mysql:query(Pid, <<"USE otptest">>),
@@ -363,6 +364,7 @@ with_table_foo_test_() ->
          exit(Pid, normal)
      end,
      {with, [fun prepared_statements/1,
+             fun parameterized_query/1,
              fun transaction_simple_success/1,
              fun transaction_simple_aborted/1]}}.
 
@@ -392,6 +394,13 @@ prepared_statements(Pid) ->
     ?assertEqual({error, not_prepared}, mysql:execute(Pid, not_a_stmt, [])),
     ok.
 
+parameterized_query(Conn) ->
+    %% Whether cache eviction works as expected is only visible in the code
+    %% coverage; this test just runs the same query before and after the wait.
+    Query = "SELECT * FROM foo WHERE bar = ?",
+    {ok, _, []} = mysql:query(Conn, Query, [1]),
+    {ok, _, []} = mysql:query(Conn, Query, [2]),
+    %% After this wait the query cache should have been emptied.
+    receive after 150 -> ok end,
+    {ok, _, []} = mysql:query(Conn, Query, [3]).
+
 transaction_simple_success(Pid) ->
     ?assertNot(mysql:in_transaction(Pid)),
     Result = mysql:transaction(Pid, fun () ->