Browse Source

Implements selecting db in the connection phase

Viktor Söderqvist 10 years ago
parent
commit
5f057dda27
4 changed files with 48 additions and 24 deletions
  1. 4 0
      include/protocol.hrl
  2. 4 9
      src/mysql.erl
  3. 18 13
      src/mysql_protocol.erl
  4. 22 2
      test/mysql_tests.erl

+ 4 - 0
include/protocol.hrl

@@ -26,6 +26,10 @@
 
 %% --- Capability flags ---
 
+%% Server: supports schema-name in Handshake Response Packet
+%% Client: Handshake Response Packet contains a schema-name
+-define(CLIENT_CONNECT_WITH_DB, 16#00000008).
+
 %% Server: supports the 4.1 protocol 
 %% Client: uses the 4.1 protocol 
 -define(CLIENT_PROTOCOL_41, 16#00000200).

+ 4 - 9
src/mysql.erl

@@ -73,9 +73,6 @@
 %%   <dd>The name of the database AKA schema to use. This can be changed later
 %%       using the query `USE <database>'.</dd>
 %% </dl>
-%%
-%% TODO: Implement {database, Database}. Currently the database has to be
-%% selected using a `USE <database>' query after connecting.
 -spec start_link(Options) -> {ok, pid()} | ignore | {error, term()}
     when Options :: [Option],
          Option :: {name, ServerName} | {host, iodata()} | {port, integer()} | 
@@ -365,9 +362,6 @@ handle_call(autocommit, _From, State) ->
     {reply, State#state.status band ?SERVER_STATUS_AUTOCOMMIT /= 0, State};
 handle_call(in_transaction, _From, State) ->
     {reply, State#state.status band ?SERVER_STATUS_IN_TRANS /= 0, State}.
-%handle_call(get_state, _From, State) ->
-%    %% *** FOR DEBUGGING ***
-%    {reply, State, State}.
 
 %% @private
 handle_cast(_Msg, State) ->
@@ -387,9 +381,10 @@ terminate(Reason, State) when Reason == normal; Reason == shutdown ->
 terminate(_Reason, _State) ->
     ok.
 
-%% @private
-code_change(_OldVsn, State, _Extra) ->
-    {ok, State}.
+code_change(_OldVsn, State = #state{}, _Extra) ->
+    {ok, State};
+code_change(_OldVsn, _State, _Extra) ->
+    {error, incompatible_state}.
 
 %% --- Helpers ---
 

+ 18 - 13
src/mysql_protocol.erl

@@ -49,17 +49,14 @@
 %% @doc Performs a handshake using the supplied functions for communication.
 %% Returns an ok or an error record. Raises errors when various unimplemented
 %% features are requested.
-%%
-%% TODO: Implement setting the database in the handshake. Currently an error
-%% occurs if Database is anything other than undefined.
 -spec handshake(iodata(), iodata(), iodata() | undefined, sendfun(),
                 recvfun()) -> #ok{} | #error{}.
 handshake(Username, Password, Database, SendFun, RecvFun) ->
     SeqNum0 = 0,
-    Database == undefined orelse error(database_in_handshake),
     {ok, HandshakePacket, SeqNum1} = recv_packet(RecvFun, SeqNum0),
     Handshake = parse_handshake(HandshakePacket),
-    Response = build_handshake_response(Handshake, Username, Password),
+    Response = build_handshake_response(Handshake, Username, Password,
+                                        Database),
     {ok, SeqNum2} = send_packet(SendFun, Response, SeqNum1),
     {ok, ConfirmPacket, _SeqNum3} = recv_packet(RecvFun, SeqNum2),
     parse_handshake_confirm(ConfirmPacket).
@@ -236,12 +233,17 @@ parse_handshake(<<Protocol:8, _/binary>>) when Protocol /= 10 ->
 
 %% @doc The response sent by the client to the server after receiving the
 %% initial handshake from the server
--spec build_handshake_response(#handshake{}, iodata(), iodata()) -> binary().
-build_handshake_response(Handshake, Username, Password) ->
+-spec build_handshake_response(#handshake{}, iodata(), iodata(),
+                               iodata() | undefined) -> binary().
+build_handshake_response(Handshake, Username, Password, Database) ->
     %% We require these capabilities. Make sure the server handles them.
-    CapabilityFlags = ?CLIENT_PROTOCOL_41 bor
-                      ?CLIENT_TRANSACTIONS bor
-                      ?CLIENT_SECURE_CONNECTION,
+    CapabilityFlags0 = ?CLIENT_PROTOCOL_41 bor
+                       ?CLIENT_TRANSACTIONS bor
+                       ?CLIENT_SECURE_CONNECTION,
+    CapabilityFlags = case Database of
+        undefined -> CapabilityFlags0;
+        _         -> CapabilityFlags0 bor ?CLIENT_CONNECT_WITH_DB
+    end,
     Handshake#handshake.capabilities band CapabilityFlags == CapabilityFlags
         orelse error(old_server_version),
     Hash = case Handshake#handshake.auth_plugin_name of
@@ -256,6 +258,10 @@ build_handshake_response(Handshake, Username, Password) ->
     HashLength = size(Hash),
     CharacterSet = ?UTF8,
     UsernameUtf8 = unicode:characters_to_binary(Username),
+    DbBin = case Database of
+        undefined -> <<>>;
+        _         -> <<(iolist_to_binary(Database))/binary, 0>>
+    end,
     <<CapabilityFlags:32/little,
       ?MAX_BYTES_PER_PACKET:32/little,
       CharacterSet:8,
@@ -263,7 +269,8 @@ build_handshake_response(Handshake, Username, Password) ->
       UsernameUtf8/binary,
       0, %% NUL-terminator for the username
       HashLength,
-      Hash/binary>>.
+      Hash/binary,
+      DbBin/binary>>.
 
 %% @doc Handles the second packet from the server, when we have replied to the
 %% initial handshake. Returns an error if the server returns an error. Raises
@@ -675,8 +682,6 @@ floor(Value) ->
 %% @doc Encodes a term reprenting av value as a binary for use in the binary
 %% protocol. As this is used to encode parameters for prepared statements, the
 %% encoding is in its required form, namely `<<Type:8, Sign:8, Value/binary>>'.
-%%
-%% TODO: Maybe change Erlang representation of BIT to `<<_:1>>'.
 -spec encode_param(term()) -> {TypeAndSign :: binary(), Data :: binary()}.
 encode_param(null) ->
     {<<?TYPE_NULL, 0>>, <<>>};

+ 22 - 2
test/mysql_tests.erl

@@ -39,7 +39,11 @@
 connect_test() ->
     %% A connection with a registered name
     Options = [{name, {local, tardis}}, {user, ?user}, {password, ?password}],
-    ?assertMatch({ok, Pid} when is_pid(Pid), mysql:start_link(Options)),
+    {ok, Pid} = mysql:start_link(Options),
+    %% Test some gen_server callbacks not tested elsewhere
+    State = sys:get_state(Pid),
+    ?assertMatch({ok, State}, mysql:code_change("0.1.0", State, [])),
+    ?assertMatch({error, _}, mysql:code_change("2.0.0", unknown_state, [])),
     exit(whereis(tardis), normal).
 
 query_test_() ->
@@ -56,7 +60,8 @@ query_test_() ->
          ok = mysql:query(Pid, <<"DROP DATABASE otptest">>),
          exit(Pid, normal)
      end,
-     {with, [fun autocommit/1,
+     {with, [fun connect_with_db/1,
+             fun autocommit/1,
              fun basic_queries/1,
              fun text_protocol/1,
              fun binary_protocol/1,
@@ -67,6 +72,14 @@ query_test_() ->
              fun time/1,
              fun microseconds/1]}}.
 
+connect_with_db(_Pid) ->
+    %% Make another connection and set the db in the handshake phase
+    {ok, Pid} = mysql:start_link([{user, ?user}, {password, ?password},
+                                  {database, "otptest"}]),
+    ?assertMatch({ok, _, [[<<"otptest">>]]},
+                 mysql:query(Pid, "SELECT DATABASE()")),
+    exit(Pid, normal).
+
 autocommit(Pid) ->
     ?assert(mysql:autocommit(Pid)),
     ok = mysql:query(Pid, <<"SET autocommit = 0">>),
@@ -405,3 +418,10 @@ transaction_simple_aborted(Pid) ->
                  mysql:transaction(Pid, fun () -> throw(foo) end)),
     ?assertEqual({aborted, foo},
                  mysql:transaction(Pid, fun () -> exit(foo) end)).
+
+%% --- simple gen_server callbacks ---
+
+gen_server_coverage_test() ->
+    {noreply, state} = mysql:handle_cast(foo, state),
+    {noreply, state} = mysql:handle_info(foo, state),
+    ok = mysql:terminate(kill, state).