Browse Source

Allow setting CLIENT_FOUND_ROWS on handshake.

The option is useful for compatibility with other SQL databases and
drivers (ODBC, PostgreSQL).
Konrad Zemek 8 years ago
parent
commit
6d8eada81c
4 changed files with 40 additions and 24 deletions
  1. 3 2
      include/protocol.hrl
  2. 19 8
      src/mysql.erl
  3. 16 12
      src/mysql_protocol.erl
  4. 2 2
      test/mysql_protocol_tests.erl

+ 3 - 2
include/protocol.hrl

@@ -26,6 +26,9 @@
 
 %% --- Capability flags ---
 
+%% Server: sends found rows instead of affected rows in EOF_Packet
+-define(CLIENT_FOUND_ROWS, 16#00000002).
+
 %% Server: supports schema-name in Handshake Response Packet
 %% Client: Handshake Response Packet contains a schema-name
 -define(CLIENT_CONNECT_WITH_DB, 16#00000008).
@@ -124,5 +127,3 @@
 -define(TYPE_VAR_STRING, 16#fd).
 -define(TYPE_STRING, 16#fe).
 -define(TYPE_GEOMETRY, 16#ff).
-
-

+ 19 - 8
src/mysql.erl

@@ -112,6 +112,11 @@
 %%   <dd>The default time to wait for a response when executing a query or a
 %%       prepared statement. This can be given per query using `query/3,4' and
 %%       `execute/4'. The default is `infinity'.</dd>
+%%   <dt>`{found_rows, boolean()}'</dt>
+%%   <dd>If set to true, the connection will be established with
+%%       CLIENT_FOUND_ROWS capability. affected_rows/1 will now return the
+%%       number of found rows, not the number of rows changed by the
+%%       query.</dd>
 %%   <dt>`{query_cache_time, Timeout}'</dt>
 %%   <dd>The minimum number of milliseconds to cache prepared statements used
 %%       for parametrized queries with query/3.</dd>
@@ -122,7 +127,7 @@
 %% </dl>
 -spec start_link(Options) -> {ok, pid()} | ignore | {error, term()}
     when Options :: [Option],
-         Option :: {name, ServerName} | {host, iodata()} | {port, integer()} | 
+         Option :: {name, ServerName} | {host, iodata()} | {port, integer()} |
                    {user, iodata()} | {password, iodata()} |
                    {database, iodata()} |
                    {connect_timeout, timeout()} |
@@ -131,6 +136,7 @@
                    {prepare, NamedStatements} |
                    {queries, [iodata()]} |
                    {query_timeout, timeout()} |
+                   {found_rows, boolean()} |
                    {query_cache_time, non_neg_integer()},
          ServerName :: {local, Name :: atom()} |
                        {global, GlobalName :: term()} |
@@ -287,7 +293,9 @@ warning_count(Conn) ->
     gen_server:call(Conn, warning_count).
 
 %% @doc Returns the number of inserted, updated and deleted rows of the last
-%% executed query or prepared statement.
+%% executed query or prepared statement. If found_rows is set on the
+%% connection, for update operation the return value will equal to the number
+%% of rows matched by the query.
 -spec affected_rows(connection()) -> integer().
 affected_rows(Conn) ->
     gen_server:call(Conn, affected_rows).
@@ -327,7 +335,7 @@ transaction(Conn, Fun, Retries) ->
     transaction(Conn, Fun, [], Retries).
 
 %% @doc This function executes the functional object Fun with arguments Args as
-%% a transaction. 
+%% a transaction.
 %%
 %% The semantics are as close as possible to mnesia's transactions. Transactions
 %% can be nested and are restarted automatically when deadlocks are detected.
@@ -455,7 +463,7 @@ encode(Conn, Term) ->
                 query_timeout, query_cache_time,
                 affected_rows = 0, status = 0, warning_count = 0, insert_id = 0,
                 transaction_level = 0, ping_ref = undefined,
-                stmts = dict:new(), query_cache = empty}).
+                stmts = dict:new(), query_cache = empty, cap_found_rows = false}).
 
 %% @private
 init(Opts) ->
@@ -472,6 +480,7 @@ init(Opts) ->
     QueryCacheTime = proplists:get_value(query_cache_time, Opts,
                                          ?default_query_cache_time),
     TcpOpts        = proplists:get_value(tcp_options, Opts, []),
+    SetFoundRows   = proplists:get_value(found_rows, Opts, false),
 
     PingTimeout = case KeepAlive of
         true         -> ?default_ping_timeout;
@@ -486,7 +495,7 @@ init(Opts) ->
     %% Exchange handshake communication.
     inet:setopts(Socket, [{active, false}]),
     Result = mysql_protocol:handshake(User, Password, Database, gen_tcp,
-                                      Socket),
+                                      Socket, SetFoundRows),
     inet:setopts(Socket, [{active, once}]),
     case Result of
         #handshake{server_version = Version, connection_id = ConnId,
@@ -498,7 +507,8 @@ init(Opts) ->
                            log_warnings = LogWarn,
                            ping_timeout = PingTimeout,
                            query_timeout = Timeout,
-                           query_cache_time = QueryCacheTime},
+                           query_cache_time = QueryCacheTime,
+                           cap_found_rows = (SetFoundRows =:= true)},
             %% Trap exit so that we can properly disconnect when we die.
             process_flag(trap_exit, true),
             State1 = schedule_ping(State),
@@ -896,14 +906,15 @@ log_warnings(#state{socket = Socket}, Query) ->
 %% @doc Makes a separate connection and execute KILL QUERY. We do this to get
 %% our main connection back to normal. KILL QUERY appeared in MySQL 5.0.0.
 kill_query(#state{connection_id = ConnId, host = Host, port = Port,
-                  user = User, password = Password}) ->
+                  user = User, password = Password,
+                  cap_found_rows = SetFoundRows}) ->
     %% Connect socket
     SockOpts = [{active, false}, binary, {packet, raw}],
     {ok, Socket} = gen_tcp:connect(Host, Port, SockOpts),
 
     %% Exchange handshake communication.
     Result = mysql_protocol:handshake(User, Password, undefined, gen_tcp,
-                                      Socket),
+                                      Socket, SetFoundRows),
     case Result of
         #handshake{} ->
             %% Kill and disconnect

+ 16 - 12
src/mysql_protocol.erl

@@ -26,7 +26,7 @@
 %% @private
 -module(mysql_protocol).
 
--export([handshake/5, quit/2, ping/2,
+-export([handshake/6, quit/2, ping/2,
          query/4, fetch_query_response/3,
          prepare/3, unprepare/3, execute/5, fetch_execute_response/3]).
 
@@ -45,14 +45,14 @@
 %% @doc Performs a handshake using the supplied functions for communication.
 %% Returns an ok or an error record. Raises errors when various unimplemented
 %% features are requested.
--spec handshake(iodata(), iodata(), iodata() | undefined, atom(), term()) ->
-    #handshake{} | #error{}.
-handshake(Username, Password, Database, TcpModule, Socket) ->
+-spec handshake(iodata(), iodata(), iodata() | undefined, atom(),
+                term(), boolean()) -> #handshake{} | #error{}.
+handshake(Username, Password, Database, TcpModule, Socket, SetFoundRows) ->
     SeqNum0 = 0,
     {ok, HandshakePacket, SeqNum1} = recv_packet(TcpModule, Socket, SeqNum0),
     Handshake = parse_handshake(HandshakePacket),
     Response = build_handshake_response(Handshake, Username, Password,
-                                        Database),
+                                        Database, SetFoundRows),
     {ok, SeqNum2} = send_packet(TcpModule, Socket, Response, SeqNum1),
     {ok, ConfirmPacket, _SeqNum3} = recv_packet(TcpModule, Socket, SeqNum2),
     case parse_handshake_confirm(ConfirmPacket) of
@@ -86,7 +86,7 @@ query(Query, TcpModule, Socket, Timeout) ->
     fetch_query_response(TcpModule, Socket, Timeout).
 
 %% @doc This is used by query/4. If query/4 returns {error, timeout}, this
-%% function can be called to retry to fetch the results of the query. 
+%% function can be called to retry to fetch the results of the query.
 fetch_query_response(TcpModule, Socket, Timeout) ->
     fetch_response(TcpModule, Socket, Timeout, text, []).
 
@@ -226,16 +226,20 @@ server_version_to_list(ServerVersion) ->
 %% @doc The response sent by the client to the server after receiving the
 %% initial handshake from the server
 -spec build_handshake_response(#handshake{}, iodata(), iodata(),
-                               iodata() | undefined) -> binary().
-build_handshake_response(Handshake, Username, Password, Database) ->
+                               iodata() | undefined, boolean()) -> binary().
+build_handshake_response(Handshake, Username, Password, Database, SetFoundRows) ->
     %% We require these capabilities. Make sure the server handles them.
     CapabilityFlags0 = ?CLIENT_PROTOCOL_41 bor
                        ?CLIENT_TRANSACTIONS bor
                        ?CLIENT_SECURE_CONNECTION,
-    CapabilityFlags = case Database of
+    CapabilityFlags1 = case Database of
         undefined -> CapabilityFlags0;
         _         -> CapabilityFlags0 bor ?CLIENT_CONNECT_WITH_DB
     end,
+    CapabilityFlags = case SetFoundRows of
+        true -> CapabilityFlags1 bor ?CLIENT_FOUND_ROWS;
+        _    -> CapabilityFlags1
+    end,
     Handshake#handshake.capabilities band CapabilityFlags == CapabilityFlags
         orelse error(old_server_version),
     %% Add some extra capability flags only for signalling to the server what
@@ -404,9 +408,9 @@ fetch_resultset_rows(TcpModule, Socket, SeqNum, Acc) ->
 %% Parses a packet containing a column definition (part of a result set)
 parse_column_definition(Data) ->
     {<<"def">>, Rest1} = lenenc_str(Data),   %% catalog (always "def")
-    {_Schema, Rest2} = lenenc_str(Rest1),    %% schema-name 
-    {_Table, Rest3} = lenenc_str(Rest2),     %% virtual table-name 
-    {_OrgTable, Rest4} = lenenc_str(Rest3),  %% physical table-name 
+    {_Schema, Rest2} = lenenc_str(Rest1),    %% schema-name
+    {_Table, Rest3} = lenenc_str(Rest2),     %% virtual table-name
+    {_OrgTable, Rest4} = lenenc_str(Rest3),  %% physical table-name
     {Name, Rest5} = lenenc_str(Rest4),       %% virtual column name
     {_OrgName, Rest6} = lenenc_str(Rest5),   %% physical column name
     {16#0c, Rest7} = lenenc_int(Rest6),      %% length of the following fields

+ 2 - 2
test/mysql_protocol_tests.erl

@@ -110,11 +110,11 @@ prepare_test() ->
                            warning_count = 0} when is_integer(StmtId),
                  Result),
     ok.
-    
+
 bad_protocol_version_test() ->
     Sock = mock_tcp:create([{recv, <<2, 0, 0, 0, 9, 0>>}]),
     ?assertError(unknown_protocol,
-                 mysql_protocol:handshake("foo", "bar", "db", mock_tcp, Sock)),
+                 mysql_protocol:handshake("foo", "bar", "db", mock_tcp, Sock, false)),
     mock_tcp:close(Sock).
 
 %% --- Helper functions for the above tests ---