Browse Source

Possibility to apply a filtermap fun on results (#104)

The query/execute functions are provided with an additional parameter
to allow for an (optional) filtermap function to be passed in, for
filtering and/or mapping the rows in the results in a similar fashion
as lists:filtermap/2.

This fun is applied in the process of the connection, so that copying
to the caller process can be avoided for large result sets.
Jan Uhlig 6 years ago
parent
commit
a727c0db9a
4 changed files with 358 additions and 128 deletions
  1. 190 45
      src/mysql.erl
  2. 26 23
      src/mysql_conn.erl
  3. 94 60
      src/mysql_protocol.erl
  4. 48 0
      test/mysql_tests.erl

+ 190 - 45
src/mysql.erl

@@ -25,7 +25,8 @@
 %% gen_server is locally registered.
 %% gen_server is locally registered.
 -module(mysql).
 -module(mysql).
 
 
--export([start_link/1, query/2, query/3, query/4, execute/3, execute/4,
+-export([start_link/1, query/2, query/3, query/4, query/5,
+         execute/3, execute/4, execute/5,
          prepare/2, prepare/3, unprepare/2,
          prepare/2, prepare/3, unprepare/2,
          warning_count/1, affected_rows/1, autocommit/1, insert_id/1,
          warning_count/1, affected_rows/1, autocommit/1, insert_id/1,
          encode/2, in_transaction/1,
          encode/2, in_transaction/1,
@@ -45,7 +46,13 @@
                           Message :: binary()}.
                           Message :: binary()}.
 
 
 -type column_names() :: [binary()].
 -type column_names() :: [binary()].
--type rows() :: [[term()]].
+-type row() :: [term()].
+-type rows() :: [row()].
+
+-type query_filtermap_fun() :: fun((row()) -> query_filtermap_res())
+                             | fun((column_names(), row()) -> query_filtermap_res()).
+-type query_filtermap_res() :: boolean()
+                             | {true, term()}.
 
 
 -type query_result() :: ok
 -type query_result() :: ok
                       | {ok, column_names(), rows()}
                       | {ok, column_names(), rows()}
@@ -160,47 +167,75 @@ start_link(Options) ->
     end,
     end,
     Ret.
     Ret.
 
 
-%% @doc Executes a query with the query timeout as given to start_link/1.
-%%
-%% It is possible to execute multiple semicolon-separated queries.
-%%
-%% Results are returned in the form `{ok, ColumnNames, Rows}' if there is one
-%% result set. If there are more than one result sets, they are returned in the
-%% form `{ok, [{ColumnNames, Rows}, ...]}'.
-%%
-%% For queries that don't return any rows (INSERT, UPDATE, etc.) only the atom
-%% `ok' is returned.
+%% @see query/5.
 -spec query(Conn, Query) -> Result
 -spec query(Conn, Query) -> Result
     when Conn :: connection(),
     when Conn :: connection(),
          Query :: iodata(),
          Query :: iodata(),
          Result :: query_result().
          Result :: query_result().
 query(Conn, Query) ->
 query(Conn, Query) ->
-    query_call(Conn, {query, Query}).
+    query(Conn, Query, no_params, no_filtermap_fun, default_timeout).
 
 
-%% @doc Depending on the 3rd argument this function does different things.
-%%
-%% If the 3rd argument is a list, it executes a parameterized query. This is
-%% equivallent to query/4 with the query timeout as given to start_link/1.
-%%
-%% If the 3rd argument is a timeout, it executes a plain query with this
-%% timeout.
-%%
-%% The return value is the same as for query/2.
-%%
-%% @see query/2.
-%% @see query/4.
--spec query(Conn, Query, Params | Timeout) -> Result
+%% @see query/5.
+-spec query(Conn, Query, Params | FilterMap | Timeout) -> Result
     when Conn :: connection(),
     when Conn :: connection(),
          Query :: iodata(),
          Query :: iodata(),
-         Timeout :: timeout(),
-         Params :: [term()],
+         Timeout :: default_timeout | timeout(),
+         Params :: no_params | [term()],
+         FilterMap :: no_filtermap_fun | query_filtermap_fun(),
          Result :: query_result().
          Result :: query_result().
-query(Conn, Query, Params) when is_list(Params) ->
-    query_call(Conn, {param_query, Query, Params});
-query(Conn, Query, Timeout) when is_integer(Timeout); Timeout == infinity ->
-    query_call(Conn, {query, Query, Timeout}).
-
-%% @doc Executes a parameterized query with a timeout.
+query(Conn, Query, Params) when Params == no_params;
+                                is_list(Params) ->
+    query(Conn, Query, Params, no_filtermap_fun, default_timeout);
+query(Conn, Query, FilterMap) when FilterMap == no_filtermap_fun;
+                                   is_function(FilterMap, 1);
+                                   is_function(FilterMap, 2) ->
+    query(Conn, Query, no_params, FilterMap, default_timeout);
+query(Conn, Query, Timeout) when Timeout == default_timeout;
+                                 is_integer(Timeout);
+                                 Timeout == infinity ->
+    query(Conn, Query, no_params, no_filtermap_fun, Timeout).
+
+%% @see query/5.
+-spec query(Conn, Query, Params, Timeout) -> Result
+        when Conn :: connection(),
+             Query :: iodata(),
+             Timeout :: default_timeout | timeout(),
+             Params :: no_params | [term()],
+             Result :: query_result();
+    (Conn, Query, FilterMap, Timeout) -> Result
+        when Conn :: connection(),
+             Query :: iodata(),
+             Timeout :: default_timeout | timeout(),
+             FilterMap :: no_filtermap_fun | query_filtermap_fun(),
+             Result :: query_result();
+    (Conn, Query, Params, FilterMap) -> Result
+        when Conn :: connection(),
+             Query :: iodata(),
+             Params :: no_params | [term()],
+             FilterMap :: no_filtermap_fun | query_filtermap_fun(),
+             Result :: query_result().
+query(Conn, Query, Params, Timeout) when (Params == no_params orelse
+                                          is_list(Params)) andalso
+                                         (Timeout == default_timeout orelse
+                                          is_integer(Timeout) orelse
+                                          Timeout == infinity) ->
+    query(Conn, Query, Params, no_filtermap_fun, Timeout);
+query(Conn, Query, FilterMap, Timeout) when (FilterMap == no_filtermap_fun orelse
+                                             is_function(FilterMap, 1) orelse
+                                             is_function(FilterMap, 2)) andalso
+                                            (Timeout == default_timeout orelse
+                                             is_integer(Timeout) orelse
+                                             Timeout=:=infinity) ->
+    query(Conn, Query, no_params, FilterMap, Timeout);
+query(Conn, Query, Params, FilterMap) when (Params == no_params orelse
+                                            is_list(Params)) andalso
+                                           (FilterMap == no_filtermap_fun orelse
+                                            is_function(FilterMap, 1) orelse
+                                            is_function(FilterMap, 2)) ->
+    query(Conn, Query, Params, FilterMap, default_timeout).
+
+%% @doc Executes a parameterized query with a timeout and applies a filter/map
+%% function to the result rows.
 %%
 %%
 %% A prepared statement is created, executed and then cached for a certain
 %% A prepared statement is created, executed and then cached for a certain
 %% time. If the same query is executed again when it is already cached, it does
 %% time. If the same query is executed again when it is already cached, it does
@@ -209,40 +244,150 @@ query(Conn, Query, Timeout) when is_integer(Timeout); Timeout == infinity ->
 %% The minimum time the prepared statement is cached can be specified using the
 %% The minimum time the prepared statement is cached can be specified using the
 %% option `{query_cache_time, Milliseconds}' to start_link/1.
 %% option `{query_cache_time, Milliseconds}' to start_link/1.
 %%
 %%
-%% The return value is the same as for query/2.
--spec query(Conn, Query, Params, Timeout) -> Result
+%% Results are returned in the form `{ok, ColumnNames, Rows}' if there is one
+%% result set. If there is more than one result set, they are returned in the
+%% form `{ok, [{ColumnNames, Rows}, ...]}'.
+%%
+%% For queries that don't return any rows (INSERT, UPDATE, etc.) only the atom
+%% `ok' is returned.
+%%
+%% The `Params', `FilterMap' and `Timeout' arguments are optional.
+%% <ul>
+%%   <li>If the `Params' argument is the atom `no_params' or is omitted, a plain
+%%       query will be executed instead of a parameterized one.</li>
+%%   <li>If the `FilterMap' argument is the atom `no_filtermap_fun' or is
+%%       omitted, no row filtering/mapping will be applied and all result rows
+%%       will be returned unchanged.</li>
+%%   <li>If the `Timeout' argument is the atom `default_timeout' or is omitted,
+%%       the timeout given in `start_link/1' is used.</li>
+%% </ul>
+%%
+%% If the `FilterMap' argument is used, it must be a function of arity 1 or 2
+%% that returns either `true', `false', or `{true, Value}'.
+%%
+%% Each result row is handed to the given function as soon as it is received
+%% from the server, and only when the function has returned, the next row is
+%% fetched. This provides the ability to prevent memory exhaustion; on the
+%% other hand, it can cause the server to time out on sending if your function
+%% is doing something slow (see the MySQL documentation on `NET_WRITE_TIMEOUT').
+%%
+%% If the function is of arity 1, only the row is passed to it as the single
+%% argument, while if the function is of arity 2, the column names are passed
+%% in as the first argument and the row as the second.
+%%
+%% The value returned is then used to decide if the row is to be included in
+%% the result(s) returned from the `query' call (filtering), or if something
+%% else is to be included in the result instead (mapping). You may also use
+%% this function for side effects, like writing rows to disk or sending them
+%% to another process etc.
+%%
+%% Here is an example showing some of the things that are possible:
+%% ```
+%% Query = "SELECT a, b, c FROM foo",
+%% FilterMap = fun
+%%     %% Include all rows where the first column is < 10.
+%%     ([A|_]) when A < 10 ->
+%%         true;
+%%     %% Exclude all rows where the first column is >= 10 and < 20.
+%%     ([A|_]) when A < 20 ->
+%%         false;
+%%     %% For rows where the first column is >= 20 and < 30, include
+%%     %% the atom 'foo' in place of the row instead.
+%%     ([A|_]) when A < 30 ->
+%%         {true, foo};
+%%     %% For rows where the first column is >= 30 and < 40, send the
+%%     %% row to a gen_server via call (i.e., wait for a response),
+%%     %% and do not include the row in the result.
+%%     (R=[A|_]) when A < 40 ->
+%%         gen_server:call(Pid, R),
+%%         false;
+%%     %% For rows where the first column is >= 40 and < 50, send the
+%%     %% row to a gen_server via cast (i.e., do not wait for a reply),
+%%     %% and include the row in the result, also.
+%%     (R=[A|_]) when A < 50 ->
+%%         gen_server:cast(Pid, R),
+%%         true;
+%%     %% Exclude all other rows from the result.
+%%     (_) ->
+%%         false
+%% end,
+%% query(Conn, Query, no_params, FilterMap, default_timeout).
+%% '''
+-spec query(Conn, Query, Params, FilterMap, Timeout) -> Result
     when Conn :: connection(),
     when Conn :: connection(),
          Query :: iodata(),
          Query :: iodata(),
-         Timeout :: timeout(),
-         Params :: [term()],
+         Timeout :: default_timeout | timeout(),
+         Params :: no_params | [term()],
+         FilterMap :: no_filtermap_fun | query_filtermap_fun(),
          Result :: query_result().
          Result :: query_result().
-query(Conn, Query, Params, Timeout) ->
-    query_call(Conn, {param_query, Query, Params, Timeout}).
+query(Conn, Query, no_params, FilterMap, Timeout) ->
+    query_call(Conn, {query, Query, FilterMap, Timeout});
+query(Conn, Query, Params, FilterMap, Timeout) ->
+    query_call(Conn, {param_query, Query, Params, FilterMap, Timeout}).
 
 
 %% @doc Executes a prepared statement with the default query timeout as given
 %% @doc Executes a prepared statement with the default query timeout as given
 %% to start_link/1.
 %% to start_link/1.
 %% @see prepare/2
 %% @see prepare/2
 %% @see prepare/3
 %% @see prepare/3
+%% @see prepare/4
+%% @see execute/5
 -spec execute(Conn, StatementRef, Params) -> Result | {error, not_prepared}
 -spec execute(Conn, StatementRef, Params) -> Result | {error, not_prepared}
   when Conn :: connection(),
   when Conn :: connection(),
        StatementRef :: atom() | integer(),
        StatementRef :: atom() | integer(),
        Params :: [term()],
        Params :: [term()],
        Result :: query_result().
        Result :: query_result().
 execute(Conn, StatementRef, Params) ->
 execute(Conn, StatementRef, Params) ->
-    query_call(Conn, {execute, StatementRef, Params}).
+    execute(Conn, StatementRef, Params, no_filtermap_fun, default_timeout).
 
 
 %% @doc Executes a prepared statement.
 %% @doc Executes a prepared statement.
 %% @see prepare/2
 %% @see prepare/2
 %% @see prepare/3
 %% @see prepare/3
--spec execute(Conn, StatementRef, Params, Timeout) ->
+%% @see prepare/4
+%% @see execute/5
+-spec execute(Conn, StatementRef, Params, FilterMap | Timeout) ->
+    Result | {error, not_prepared}
+  when Conn :: connection(),
+       StatementRef :: atom() | integer(),
+       Params :: [term()],
+       FilterMap :: no_filtermap_fun | query_filtermap_fun(),
+       Timeout :: default_timeout | timeout(),
+       Result :: query_result().
+execute(Conn, StatementRef, Params, Timeout) when Timeout == default_timeout;
+                                                  is_integer(Timeout);
+                                                  Timeout=:=infinity ->
+    execute(Conn, StatementRef, Params, no_filtermap_fun, Timeout);
+execute(Conn, StatementRef, Params, FilterMap) when FilterMap == no_filtermap_fun;
+                                                    is_function(FilterMap, 1);
+                                                    is_function(FilterMap, 2) ->
+    execute(Conn, StatementRef, Params, FilterMap, default_timeout).
+
+%% @doc Executes a prepared statement.
+%% 
+%% The `FilterMap' and `Timeout' arguments are optional.
+%% <ul>
+%%   <li>If the `FilterMap' argument is the atom `no_filtermap_fun' or is
+%%       omitted, no row filtering/mapping will be applied and all result rows
+%%       will be returned unchanged.</li>
+%%   <li>If the `Timeout' argument is the atom `default_timeout' or is omitted,
+%%       the timeout given in `start_link/1' is used.</li>
+%% </ul>
+%%
+%% See `query/5' for an explanation of the `FilterMap' argument.
+%%
+%% @see prepare/2
+%% @see prepare/3
+%% @see prepare/4
+%% @see query/5
+-spec execute(Conn, StatementRef, Params, FilterMap, Timeout) ->
     Result | {error, not_prepared}
     Result | {error, not_prepared}
   when Conn :: connection(),
   when Conn :: connection(),
        StatementRef :: atom() | integer(),
        StatementRef :: atom() | integer(),
        Params :: [term()],
        Params :: [term()],
-       Timeout :: timeout(),
+       FilterMap :: no_filtermap_fun | query_filtermap_fun(),
+       Timeout :: default_timeout | timeout(),
        Result :: query_result().
        Result :: query_result().
-execute(Conn, StatementRef, Params, Timeout) ->
-    query_call(Conn, {execute, StatementRef, Params, Timeout}).
+execute(Conn, StatementRef, Params, FilterMap, Timeout) ->
+    query_call(Conn, {execute, StatementRef, Params, FilterMap, Timeout}).
 
 
 %% @doc Creates a prepared statement from the passed query.
 %% @doc Creates a prepared statement from the passed query.
 %% @see prepare/3
 %% @see prepare/3

+ 26 - 23
src/mysql_conn.erl

@@ -119,12 +119,9 @@ init(Opts) ->
 %% Query and execute calls:
 %% Query and execute calls:
 %%
 %%
 %% <ul>
 %% <ul>
-%%   <li>{query, Query}</li>
-%%   <li>{query, Query, Timeout}</li>
-%%   <li>{param_query, Query, Params}</li>
-%%   <li>{param_query, Query, Params, Timeout}</li>
-%%   <li>{execute, Stmt, Args}</li>
-%%   <li>{execute, Stmt, Args, Timeout}</li>
+%%   <li>{query, Query, FilterMap, Timeout}</li>
+%%   <li>{param_query, Query, Params, FilterMap, Timeout}</li>
+%%   <li>{execute, Stmt, Args, FilterMap, Timeout}</li>
 %% </ul>
 %% </ul>
 %%
 %%
 %% For the calls listed above, we return these values:
 %% For the calls listed above, we return these values:
@@ -154,15 +151,18 @@ init(Opts) ->
 %%       able to handle this in the caller's process, we also return the
 %%       able to handle this in the caller's process, we also return the
 %%       nesting level.</dd>
 %%       nesting level.</dd>
 %% </dl>
 %% </dl>
-handle_call({query, Query}, From, State) ->
-    handle_call({query, Query, State#state.query_timeout}, From, State);
-handle_call({query, Query, Timeout}, _From,
+handle_call({query, Query, FilterMap, default_timeout}, From, State) ->
+    handle_call({query, Query, FilterMap, State#state.query_timeout}, From,
+                State);
+handle_call({query, Query, FilterMap, Timeout}, _From,
             #state{sockmod = SockMod, socket = Socket} = State) ->
             #state{sockmod = SockMod, socket = Socket} = State) ->
     setopts(SockMod, Socket, [{active, false}]),
     setopts(SockMod, Socket, [{active, false}]),
-    {ok, Recs} = case mysql_protocol:query(Query, SockMod, Socket, Timeout) of
+    Result = mysql_protocol:query(Query, SockMod, Socket, FilterMap, Timeout),
+    {ok, Recs} = case Result of
         {error, timeout} when State#state.server_version >= [5, 0, 0] ->
         {error, timeout} when State#state.server_version >= [5, 0, 0] ->
             kill_query(State),
             kill_query(State),
-            mysql_protocol:fetch_query_response(SockMod, Socket, ?cmd_timeout);
+            mysql_protocol:fetch_query_response(SockMod, Socket, FilterMap,
+                                                ?cmd_timeout);
         {error, timeout} ->
         {error, timeout} ->
             %% For MySQL 4.x.x there is no way to recover from timeout except
             %% For MySQL 4.x.x there is no way to recover from timeout except
             %% killing the connection itself.
             %% killing the connection itself.
@@ -175,10 +175,11 @@ handle_call({query, Query, Timeout}, _From,
     State1#state.warning_count > 0 andalso State1#state.log_warnings
     State1#state.warning_count > 0 andalso State1#state.log_warnings
         andalso log_warnings(State1, Query),
         andalso log_warnings(State1, Query),
     handle_query_call_reply(Recs, Query, State1, []);
     handle_query_call_reply(Recs, Query, State1, []);
-handle_call({param_query, Query, Params}, From, State) ->
-    handle_call({param_query, Query, Params, State#state.query_timeout}, From,
-                State);
-handle_call({param_query, Query, Params, Timeout}, _From,
+handle_call({param_query, Query, Params, FilterMap, default_timeout}, From,
+            State) ->
+    handle_call({param_query, Query, Params, FilterMap,
+                State#state.query_timeout}, From, State);
+handle_call({param_query, Query, Params, FilterMap, Timeout}, _From,
             #state{socket = Socket, sockmod = SockMod} = State) ->
             #state{socket = Socket, sockmod = SockMod} = State) ->
     %% Parametrized query: Prepared statement cached with the query as the key
     %% Parametrized query: Prepared statement cached with the query as the key
     QueryBin = iolist_to_binary(Query),
     QueryBin = iolist_to_binary(Query),
@@ -207,16 +208,17 @@ handle_call({param_query, Query, Params, Timeout}, _From,
     case StmtResult of
     case StmtResult of
         {ok, StmtRec} ->
         {ok, StmtRec} ->
             State1 = State#state{query_cache = Cache1},
             State1 = State#state{query_cache = Cache1},
-            execute_stmt(StmtRec, Params, Timeout, State1);
+            execute_stmt(StmtRec, Params, FilterMap, Timeout, State1);
         PrepareError ->
         PrepareError ->
             {reply, PrepareError, State}
             {reply, PrepareError, State}
     end;
     end;
-handle_call({execute, Stmt, Args}, From, State) ->
-    handle_call({execute, Stmt, Args, State#state.query_timeout}, From, State);
-handle_call({execute, Stmt, Args, Timeout}, _From, State) ->
+handle_call({execute, Stmt, Args, FilterMap, default_timeout}, From, State) ->
+    handle_call({execute, Stmt, Args, FilterMap, State#state.query_timeout},
+        From, State);
+handle_call({execute, Stmt, Args, FilterMap, Timeout}, _From, State) ->
     case dict:find(Stmt, State#state.stmts) of
     case dict:find(Stmt, State#state.stmts) of
         {ok, StmtRec} ->
         {ok, StmtRec} ->
-            execute_stmt(StmtRec, Args, Timeout, State);
+            execute_stmt(StmtRec, Args, FilterMap, Timeout, State);
         error ->
         error ->
             {reply, {error, not_prepared}, State}
             {reply, {error, not_prepared}, State}
     end;
     end;
@@ -382,14 +384,15 @@ code_change(_OldVsn, _State, _Extra) ->
 %% --- Helpers ---
 %% --- Helpers ---
 
 
 %% @doc Executes a prepared statement and returns {Reply, NextState}.
 %% @doc Executes a prepared statement and returns {Reply, NextState}.
-execute_stmt(Stmt, Args, Timeout, State = #state{socket = Socket, sockmod = SockMod}) ->
+execute_stmt(Stmt, Args, FilterMap, Timeout,
+             State = #state{socket = Socket, sockmod = SockMod}) ->
     setopts(SockMod, Socket, [{active, false}]),
     setopts(SockMod, Socket, [{active, false}]),
     {ok, Recs} = case mysql_protocol:execute(Stmt, Args, SockMod, Socket,
     {ok, Recs} = case mysql_protocol:execute(Stmt, Args, SockMod, Socket,
-                                             Timeout) of
+                                             FilterMap, Timeout) of
         {error, timeout} when State#state.server_version >= [5, 0, 0] ->
         {error, timeout} when State#state.server_version >= [5, 0, 0] ->
             kill_query(State),
             kill_query(State),
             mysql_protocol:fetch_execute_response(SockMod, Socket,
             mysql_protocol:fetch_execute_response(SockMod, Socket,
-                                                  ?cmd_timeout);
+                                                  FilterMap, ?cmd_timeout);
         {error, timeout} ->
         {error, timeout} ->
             %% For MySQL 4.x.x there is no way to recover from timeout except
             %% For MySQL 4.x.x there is no way to recover from timeout except
             %% killing the connection itself.
             %% killing the connection itself.

+ 94 - 60
src/mysql_protocol.erl

@@ -28,8 +28,15 @@
 -module(mysql_protocol).
 -module(mysql_protocol).
 
 
 -export([handshake/7, quit/2, ping/2,
 -export([handshake/7, quit/2, ping/2,
-         query/4, fetch_query_response/3,
-         prepare/3, unprepare/3, execute/5, fetch_execute_response/3]).
+         query/4, query/5, fetch_query_response/3,
+         fetch_query_response/4, prepare/3, unprepare/3,
+         execute/5, execute/6, fetch_execute_response/3,
+         fetch_execute_response/4]).
+
+-type query_filtermap() :: no_filtermap_fun
+                         | fun(([term()]) -> query_filtermap_res())
+                         | fun(([term()], [term()]) -> query_filtermap_res()).
+-type query_filtermap_res() :: boolean() | {true, term()}.
 
 
 %% How much data do we want per packet?
 %% How much data do we want per packet?
 -define(MAX_BYTES_PER_PACKET, 16#1000000).
 -define(MAX_BYTES_PER_PACKET, 16#1000000).
@@ -113,15 +120,23 @@ ping(SockModule, Socket) ->
 -spec query(Query :: iodata(), atom(), term(), timeout()) ->
 -spec query(Query :: iodata(), atom(), term(), timeout()) ->
     {ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
     {ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
 query(Query, SockModule, Socket, Timeout) ->
 query(Query, SockModule, Socket, Timeout) ->
+    query(Query, SockModule, Socket, no_filtermap_fun, Timeout).
+
+-spec query(Query :: iodata(), atom(), term(), query_filtermap(), timeout()) ->
+    {ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
+query(Query, SockModule, Socket, FilterMap, Timeout) ->
     Req = <<?COM_QUERY, (iolist_to_binary(Query))/binary>>,
     Req = <<?COM_QUERY, (iolist_to_binary(Query))/binary>>,
     SeqNum0 = 0,
     SeqNum0 = 0,
     {ok, _SeqNum1} = send_packet(SockModule, Socket, Req, SeqNum0),
     {ok, _SeqNum1} = send_packet(SockModule, Socket, Req, SeqNum0),
-    fetch_query_response(SockModule, Socket, Timeout).
+    fetch_query_response(SockModule, Socket, FilterMap, Timeout).
 
 
 %% @doc This is used by query/4. If query/4 returns {error, timeout}, this
 %% @doc This is used by query/4. If query/4 returns {error, timeout}, this
 %% function can be called to retry to fetch the results of the query.
 %% function can be called to retry to fetch the results of the query.
 fetch_query_response(SockModule, Socket, Timeout) ->
 fetch_query_response(SockModule, Socket, Timeout) ->
-    fetch_response(SockModule, Socket, Timeout, text, []).
+    fetch_query_response(SockModule, Socket, no_filtermap_fun, Timeout).
+
+fetch_query_response(SockModule, Socket, FilterMap, Timeout) ->
+    fetch_response(SockModule, Socket, Timeout, text, FilterMap, []).
 
 
 %% @doc Prepares a statement.
 %% @doc Prepares a statement.
 -spec prepare(iodata(), atom(), term()) -> #error{} | #prepared{}.
 -spec prepare(iodata(), atom(), term()) -> #error{} | #prepared{}.
@@ -168,8 +183,15 @@ unprepare(#prepared{statement_id = Id}, SockModule, Socket) ->
 %% @doc Executes a prepared statement.
 %% @doc Executes a prepared statement.
 -spec execute(#prepared{}, [term()], atom(), term(), timeout()) ->
 -spec execute(#prepared{}, [term()], atom(), term(), timeout()) ->
     {ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
     {ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
+execute(PrepStmt, ParamValues, SockModule, Socket, Timeout) ->
+    execute(PrepStmt, ParamValues, SockModule, Socket, no_filtermap_fun,
+            Timeout).
+-spec execute(#prepared{}, [term()], atom(), term(), query_filtermap(),
+              timeout()) ->
+    {ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
 execute(#prepared{statement_id = Id, param_count = ParamCount}, ParamValues,
 execute(#prepared{statement_id = Id, param_count = ParamCount}, ParamValues,
-        SockModule, Socket, Timeout) when ParamCount == length(ParamValues) ->
+        SockModule, Socket, FilterMap, Timeout)
+  when ParamCount == length(ParamValues) ->
     %% Flags Constant Name
     %% Flags Constant Name
     %% 0x00 CURSOR_TYPE_NO_CURSOR
     %% 0x00 CURSOR_TYPE_NO_CURSOR
     %% 0x01 CURSOR_TYPE_READ_ONLY
     %% 0x01 CURSOR_TYPE_READ_ONLY
@@ -196,12 +218,15 @@ execute(#prepared{statement_id = Id, param_count = ParamCount}, ParamValues,
             iolist_to_binary([Req1, TypesAndSigns, EncValues])
             iolist_to_binary([Req1, TypesAndSigns, EncValues])
     end,
     end,
     {ok, _SeqNum1} = send_packet(SockModule, Socket, Req, 0),
     {ok, _SeqNum1} = send_packet(SockModule, Socket, Req, 0),
-    fetch_execute_response(SockModule, Socket, Timeout).
+    fetch_execute_response(SockModule, Socket, FilterMap, Timeout).
 
 
 %% @doc This is used by execute/5. If execute/5 returns {error, timeout}, this
 %% @doc This is used by execute/5. If execute/5 returns {error, timeout}, this
 %% function can be called to retry to fetch the results of the query.
 %% function can be called to retry to fetch the results of the query.
 fetch_execute_response(SockModule, Socket, Timeout) ->
 fetch_execute_response(SockModule, Socket, Timeout) ->
-    fetch_response(SockModule, Socket, Timeout, binary, []).
+    fetch_execute_response(SockModule, Socket, no_filtermap_fun, Timeout).
+
+fetch_execute_response(SockModule, Socket, FilterMap, Timeout) ->
+    fetch_response(SockModule, Socket, Timeout, binary, FilterMap, []).
 
 
 %% --- internal ---
 %% --- internal ---
 
 
@@ -413,9 +438,10 @@ parse_handshake_confirm(Packet) ->
 %% @doc Fetches one or more results and parses the result set(s) using
 %% @doc Fetches one or more results and parses the result set(s) using
 %% either the text format (for plain queries) or the binary format (for
 %% either the text format (for plain queries) or the binary format (for
 %% prepared statements).
 %% prepared statements).
--spec fetch_response(atom(), term(), timeout(), text | binary, list()) ->
+-spec fetch_response(atom(), term(), timeout(), text | binary,
+                     query_filtermap(), list()) ->
     {ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
     {ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
-fetch_response(SockModule, Socket, Timeout, Proto, Acc) ->
+fetch_response(SockModule, Socket, Timeout, Proto, FilterMap, Acc) ->
     case recv_packet(SockModule, Socket, Timeout, any) of
     case recv_packet(SockModule, Socket, Timeout, any) of
         {ok, Packet, SeqNum2} ->
         {ok, Packet, SeqNum2} ->
             Result = case Packet of
             Result = case Packet of
@@ -426,19 +452,14 @@ fetch_response(SockModule, Socket, Timeout, Proto, Acc) ->
                 ResultPacket ->
                 ResultPacket ->
                     %% The first packet in a resultset is only the column count.
                     %% The first packet in a resultset is only the column count.
                     {ColCount, <<>>} = lenenc_int(ResultPacket),
                     {ColCount, <<>>} = lenenc_int(ResultPacket),
-                    R0 = fetch_resultset(SockModule, Socket, ColCount, SeqNum2),
-                    case R0 of
-                        #error{} = E ->
-                            %% TODO: Find a way to get here + testcase
-                            E;
-                        #resultset{} = R ->
-                            parse_resultset(R, ColCount, Proto)
-                    end
+                    fetch_resultset(SockModule, Socket, ColCount, Proto,
+                                    FilterMap, SeqNum2)
             end,
             end,
             Acc1 = [Result | Acc],
             Acc1 = [Result | Acc],
             case more_results_exists(Result) of
             case more_results_exists(Result) of
                 true ->
                 true ->
-                    fetch_response(SockModule, Socket, Timeout, Proto, Acc1);
+                    fetch_response(SockModule, Socket, Timeout, Proto,
+                                   FilterMap, Acc1);
                 false ->
                 false ->
                     {ok, lists:reverse(Acc1)}
                     {ok, lists:reverse(Acc1)}
             end;
             end;
@@ -446,36 +467,60 @@ fetch_response(SockModule, Socket, Timeout, Proto, Acc) ->
             {error, timeout}
             {error, timeout}
     end.
     end.
 
 
-%% @doc Fetches packets for a result set. The column definitions are parsed but
-%% the rows are unparsed binary packages. This function is used for both the
-%% text protocol and the binary protocol. This affects the way the rows need to
-%% be parsed.
--spec fetch_resultset(atom(), term(), integer(), integer()) ->
+%% @doc Fetches a result set.
+-spec fetch_resultset(atom(), term(), integer(), text | binary,
+                      query_filtermap(), integer()) ->
     #resultset{} | #error{}.
     #resultset{} | #error{}.
-fetch_resultset(SockModule, Socket, FieldCount, SeqNum) ->
-    {ok, ColDefs, SeqNum1} = fetch_column_definitions(SockModule, Socket,
-                                                      SeqNum, FieldCount, []),
-    {ok, DelimiterPacket, SeqNum2} = recv_packet(SockModule, Socket, SeqNum1),
-    #eof{status = S, warning_count = W} = parse_eof_packet(DelimiterPacket),
-    case fetch_resultset_rows(SockModule, Socket, SeqNum2, []) of
+fetch_resultset(SockModule, Socket, FieldCount, Proto, FilterMap, SeqNum0) ->
+    {ok, ColDefs0, SeqNum1} = fetch_column_definitions(SockModule, Socket,
+                                                       SeqNum0, FieldCount, []),
+    {ok, DelimPacket, SeqNum2} = recv_packet(SockModule, Socket, SeqNum1),
+    #eof{status = S, warning_count = W} = parse_eof_packet(DelimPacket),
+    ColDefs1 = lists:map(fun parse_column_definition/1, ColDefs0),
+    case fetch_resultset_rows(SockModule, Socket, FieldCount, ColDefs1, Proto,
+                              FilterMap, SeqNum2, []) of
         {ok, Rows, _SeqNum3} ->
         {ok, Rows, _SeqNum3} ->
-            ColDefs1 = lists:map(fun parse_column_definition/1, ColDefs),
-            #resultset{cols = ColDefs1, rows = Rows,
-                       status = S, warning_count = W};
+            #resultset{cols = ColDefs1, rows = Rows, status = S,
+                       warning_count = W};
         #error{} = E ->
         #error{} = E ->
             E
             E
     end.
     end.
 
 
-parse_resultset(#resultset{cols = ColDefs, rows = Rows} = R, ColumnCount,
-                text) ->
-    %% Parse the rows according to the 'text protocol' representation.
-    Rows1 = [decode_text_row(ColumnCount, ColDefs, Row) || Row <- Rows],
-    R#resultset{rows = Rows1};
-parse_resultset(#resultset{cols = ColDefs, rows = Rows} = R, ColumnCount,
-                binary) ->
-    %% Parse the rows according to the 'binary protocol' representation.
-    Rows1 = [decode_binary_row(ColumnCount, ColDefs, Row) || Row <- Rows],
-    R#resultset{rows = Rows1}.
+%% @doc Fetches the rows for a result set and decodes them using either the text
+%% format (for plain queries) or binary format (for prepared statements).
+%%
+%% Every decoded row is passed through `FilterMap' (via
+%% filtermap_resultset_row/3) before it is accumulated: `false' drops the row,
+%% `true' keeps it as decoded and `{true, Term}' stores `Term' in its place.
+%%
+%% Returns `{ok, Rows, SeqNum}' with the rows in arrival order once the EOF
+%% packet has been received, or the parsed `#error{}' if an error packet is
+%% received instead of a row.
+-spec fetch_resultset_rows(atom(), term(), integer(), [#col{}], text | binary,
+                           query_filtermap(), integer(), [[term()]]) ->
+    {ok, [[term()]], integer()} | #error{}.
+fetch_resultset_rows(SockModule, Socket, FieldCount, ColDefs, Proto,
+                     FilterMap, SeqNum0, Acc) when is_function(FilterMap, 2) ->
+    %% The column names are identical for every row of the result set, so
+    %% extract them once and hand the row loop a one-ary fun instead of
+    %% rebuilding the name list for each row. The wrapped fun has arity 1, so
+    %% this clause is entered at most once per result set.
+    ColNames = [Col#col.name || Col <- ColDefs],
+    FilterMap1 = fun (Row) -> FilterMap(ColNames, Row) end,
+    fetch_resultset_rows(SockModule, Socket, FieldCount, ColDefs, Proto,
+                         FilterMap1, SeqNum0, Acc);
+fetch_resultset_rows(SockModule, Socket, FieldCount, ColDefs, Proto,
+                     FilterMap, SeqNum0, Acc) ->
+    {ok, Packet, SeqNum1} = recv_packet(SockModule, Socket, SeqNum0),
+    case Packet of
+        ?error_pattern ->
+            parse_error_packet(Packet);
+        ?eof_pattern ->
+            %% Rows were prepended while reading; restore arrival order.
+            {ok, lists:reverse(Acc), SeqNum1};
+        RowPacket ->
+            Row0 = decode_row(FieldCount, ColDefs, RowPacket, Proto),
+            Acc1 = case filtermap_resultset_row(FilterMap, ColDefs, Row0) of
+                false ->
+                    Acc;
+                true ->
+                    [Row0|Acc];
+                {true, Row1} ->
+                    [Row1|Acc]
+            end,
+            fetch_resultset_rows(SockModule, Socket, FieldCount, ColDefs,
+                                 Proto, FilterMap, SeqNum1, Acc1)
+    end.
+
+%% @doc Applies the filtermap setting to a single decoded row.
+%%
+%% With `no_filtermap_fun' every row is kept (`true'). A one-ary fun is called
+%% with the row only; a two-ary fun is called with the list of column names
+%% taken from the column definitions, followed by the row. The result of the
+%% fun is returned unchanged.
+-spec filtermap_resultset_row(query_filtermap(), [#col{}], [term()]) ->
+    query_filtermap_res().
+filtermap_resultset_row(no_filtermap_fun, _ColDefs, _Row) ->
+    true;
+filtermap_resultset_row(FilterMap, _ColDefs, Row)
+  when is_function(FilterMap, 1) ->
+    FilterMap(Row);
+filtermap_resultset_row(FilterMap, ColDefs, Row)
+  when is_function(FilterMap, 2) ->
+    ColumnNames = [ColDef#col.name || ColDef <- ColDefs],
+    FilterMap(ColumnNames, Row).
 
 
 more_results_exists(#ok{status = S}) ->
 more_results_exists(#ok{status = S}) ->
     S band ?SERVER_MORE_RESULTS_EXISTS /= 0;
     S band ?SERVER_MORE_RESULTS_EXISTS /= 0;
@@ -497,24 +542,6 @@ fetch_column_definitions(SockModule, Socket, SeqNum, NumLeft, Acc)
 fetch_column_definitions(_SockModule, _Socket, SeqNum, 0, Acc) ->
 fetch_column_definitions(_SockModule, _Socket, SeqNum, 0, Acc) ->
     {ok, lists:reverse(Acc), SeqNum}.
     {ok, lists:reverse(Acc), SeqNum}.
 
 
-%% @doc Fetches rows in a result set. There is a packet per row. The row packets
-%% are not decoded. This function can be used for both the binary and the text
-%% protocol result sets.
--spec fetch_resultset_rows(atom(), term(), SeqNum :: integer(), Acc) ->
-    {ok, Rows, integer()} | #error{}
-    when Acc :: [binary()],
-         Rows :: [binary()].
-fetch_resultset_rows(SockModule, Socket, SeqNum, Acc) ->
-    {ok, Packet, SeqNum1} = recv_packet(SockModule, Socket, SeqNum),
-    case Packet of
-        ?error_pattern ->
-            parse_error_packet(Packet);
-        ?eof_pattern ->
-            {ok, lists:reverse(Acc), SeqNum1};
-        Row ->
-            fetch_resultset_rows(SockModule, Socket, SeqNum1, [Row | Acc])
-    end.
-
 %% Parses a packet containing a column definition (part of a result set)
 %% Parses a packet containing a column definition (part of a result set)
 parse_column_definition(Data) ->
 parse_column_definition(Data) ->
     {<<"def">>, Rest1} = lenenc_str(Data),   %% catalog (always "def")
     {<<"def">>, Rest1} = lenenc_str(Data),   %% catalog (always "def")
@@ -540,6 +567,13 @@ parse_column_definition(Data) ->
     #col{name = Name, type = Type, charset = Charset, length = Length,
     #col{name = Name, type = Type, charset = Charset, length = Length,
          decimals = Decimals, flags = Flags}.
          decimals = Decimals, flags = Flags}.
 
 
+%% @doc Decodes one row packet into a list of values. The last argument selects
+%% the wire format: `text' delegates to decode_text_row/3 and `binary' to
+%% decode_binary_row/3.
+-spec decode_row(integer(), [#col{}], binary(), text | binary) -> [term()].
+decode_row(NumColumns, Columns, Data, text) ->
+    decode_text_row(NumColumns, Columns, Data);
+decode_row(NumColumns, Columns, Data, binary) ->
+    decode_binary_row(NumColumns, Columns, Data).
+
 %% -- text protocol --
 %% -- text protocol --
 
 
 -spec decode_text_row(NumColumns :: integer(),
 -spec decode_text_row(NumColumns :: integer(),

+ 48 - 0
test/mysql_tests.erl

@@ -200,6 +200,7 @@ query_test_() ->
           {"Autocommit",           fun () -> autocommit(Pid) end},
           {"Autocommit",           fun () -> autocommit(Pid) end},
           {"Encode",               fun () -> encode(Pid) end},
           {"Encode",               fun () -> encode(Pid) end},
           {"Basic queries",        fun () -> basic_queries(Pid) end},
           {"Basic queries",        fun () -> basic_queries(Pid) end},
+          {"Filtermap queries",    fun () -> filtermap_queries(Pid) end},
           {"FOUND_ROWS option",    fun () -> found_rows(Pid) end},
           {"FOUND_ROWS option",    fun () -> found_rows(Pid) end},
           {"Multi statements",     fun () -> multi_statements(Pid) end},
           {"Multi statements",     fun () -> multi_statements(Pid) end},
           {"Text protocol",        fun () -> text_protocol(Pid) end},
           {"Text protocol",        fun () -> text_protocol(Pid) end},
@@ -279,6 +280,53 @@ basic_queries(Pid) ->
 
 
     ok.
     ok.
 
 
+%% Checks that a filtermap fun given to query/3, query/4 and execute/4 is
+%% applied to every row of the result: the row with id 1 is kept unchanged, the
+%% row with id 2 is dropped and the row with id 3 is replaced by a tuple. Both
+%% the one-ary and the two-ary form of the fun are exercised.
+filtermap_queries(Pid) ->
+    ok = mysql:query(Pid, ?create_table_t),
+    lists:foreach(
+        fun (Insert) -> ok = mysql:query(Pid, Insert) end,
+        [<<"INSERT INTO t (id, tx) VALUES (1, 'text 1')">>,
+         <<"INSERT INTO t (id, tx) VALUES (2, 'text 2')">>,
+         <<"INSERT INTO t (id, tx) VALUES (3, 'text 3')">>]),
+
+    Select = <<"SELECT id, tx FROM t ORDER BY id">>,
+
+    %% one-ary filtermap fun: keep, drop or map a row depending on its id
+    ByRow = fun
+        ([1 | _]) -> true;
+        ([2 | _]) -> false;
+        ([3 | _] = Row) -> {true, list_to_tuple(Row)}
+    end,
+
+    %% two-ary filtermap fun: ignores the column names and reuses the one above
+    ByNamesAndRow = fun (_ColumnNames, Row) -> ByRow(Row) end,
+
+    FilterMaps = [ByRow, ByNamesAndRow],
+    Expected = [[1, <<"text 1">>], {3, <<"text 3">>}],
+
+    %% test with plain query
+    lists:foreach(
+        fun (FilterMap) ->
+            {ok, _, Rows} = mysql:query(Pid, Select, FilterMap),
+            ?assertEqual(Expected, Rows)
+        end,
+        FilterMaps),
+
+    %% test with parameterized query
+    lists:foreach(
+        fun (FilterMap) ->
+            {ok, _, Rows} = mysql:query(Pid, Select, [], FilterMap),
+            ?assertEqual(Expected, Rows)
+        end,
+        FilterMaps),
+
+    %% test with prepared statement
+    {ok, PrepStmt} = mysql:prepare(Pid, Select),
+    lists:foreach(
+        fun (FilterMap) ->
+            {ok, _, Rows} = mysql:execute(Pid, PrepStmt, [], FilterMap),
+            ?assertEqual(Expected, Rows)
+        end,
+        FilterMaps),
+
+    ok = mysql:query(Pid, <<"DROP TABLE t">>).
+
 found_rows(Pid) ->
 found_rows(Pid) ->
     Options = [{user, ?user}, {password, ?password}, {log_warnings, false},
     Options = [{user, ?user}, {password, ?password}, {log_warnings, false},
                {keep_alive, true}, {found_rows, true}],
                {keep_alive, true}, {found_rows, true}],