Browse Source

implement COM_CHANGE_USER (#105)

Implement COM_CHANGE_USER to enable switching to a different
user without closing the connection. This is useful when mysql-otp
is to be used in a connection pool to access a database with
many different users (eg a web hosting environment).
Jan Uhlig 6 years ago
parent
commit
892858bef6
6 changed files with 301 additions and 24 deletions
  1. 2 0
      .travis.yml
  2. 5 2
      README.md
  3. 86 17
      src/mysql.erl
  4. 26 4
      src/mysql_conn.erl
  5. 29 1
      src/mysql_protocol.erl
  6. 153 0
      test/mysql_change_user_tests.erl

+ 2 - 0
.travis.yml

@@ -12,6 +12,8 @@ before_script:
   - sleep 5
   - sleep 5
   - mysql -uroot -e "CREATE USER otptest@localhost IDENTIFIED BY 'otptest';"
   - mysql -uroot -e "CREATE USER otptest@localhost IDENTIFIED BY 'otptest';"
   - mysql -uroot -e "GRANT ALL PRIVILEGES ON otptest.* TO otptest@localhost;"
   - mysql -uroot -e "GRANT ALL PRIVILEGES ON otptest.* TO otptest@localhost;"
+  - mysql -uroot -e "CREATE USER otptest2@localhost IDENTIFIED BY 'otptest2';"
+  - mysql -uroot -e "GRANT ALL PRIVILEGES ON otptest.* TO otptest2@localhost;"
   - mysql -uroot -e "CREATE USER otptestssl@localhost IDENTIFIED BY 'otptestssl';"
   - mysql -uroot -e "CREATE USER otptestssl@localhost IDENTIFIED BY 'otptestssl';"
   - mysql -uroot -e "GRANT ALL PRIVILEGES ON otptest.* TO otptestssl@localhost REQUIRE SSL;"
   - mysql -uroot -e "GRANT ALL PRIVILEGES ON otptest.* TO otptestssl@localhost REQUIRE SSL;"
 script: 'make tests'
 script: 'make tests'

+ 5 - 2
README.md

@@ -111,13 +111,16 @@ The encode and protocol test suites does not require a
 running MySQL server on localhost.
 running MySQL server on localhost.
 
 
 For the suites `mysql_tests`, `ssl_tests` and `transaction_tests` you need to
 For the suites `mysql_tests`, `ssl_tests` and `transaction_tests` you need to
-start MySQL on localhost and give privileges to the user `otptest` and (for
-`ssl_tests`) to the user `otptestssl`:
+start MySQL on localhost and give privileges to the users `otptest`, `otptest2`
+and (for `ssl_tests`) to the user `otptestssl`:
 
 
 ```SQL
 ```SQL
 CREATE USER otptest@localhost IDENTIFIED BY 'otptest';
 CREATE USER otptest@localhost IDENTIFIED BY 'otptest';
 GRANT ALL PRIVILEGES ON otptest.* TO otptest@localhost;
 GRANT ALL PRIVILEGES ON otptest.* TO otptest@localhost;
 
 
+CREATE USER otptest2@localhost IDENTIFIED BY 'otptest2';
+GRANT ALL PRIVILEGES ON otptest.* TO otptest2@localhost;
+
 CREATE USER otptestssl@localhost IDENTIFIED BY 'otptestssl';
 CREATE USER otptestssl@localhost IDENTIFIED BY 'otptestssl';
 GRANT ALL PRIVILEGES ON otptest.* TO otptestssl@localhost REQUIRE SSL;
 GRANT ALL PRIVILEGES ON otptest.* TO otptestssl@localhost REQUIRE SSL;
 ```
 ```

+ 86 - 17
src/mysql.erl

@@ -30,7 +30,8 @@
          prepare/2, prepare/3, unprepare/2,
          prepare/2, prepare/3, unprepare/2,
          warning_count/1, affected_rows/1, autocommit/1, insert_id/1,
          warning_count/1, affected_rows/1, autocommit/1, insert_id/1,
          encode/2, in_transaction/1,
          encode/2, in_transaction/1,
-         transaction/2, transaction/3, transaction/4]).
+         transaction/2, transaction/3, transaction/4,
+         change_user/3, change_user/4]).
 
 
 -export_type([connection/0, server_reason/0, query_result/0]).
 -export_type([connection/0, server_reason/0, query_result/0]).
 
 
@@ -147,22 +148,9 @@ start_link(Options) ->
     end,
     end,
     case Ret of
     case Ret of
         {ok, Pid} ->
         {ok, Pid} ->
-            %% Initial queries
-            Queries = proplists:get_value(queries, Options, []),
-            lists:foreach(fun (Query) ->
-                              case mysql:query(Pid, Query) of
-                                  ok -> ok;
-                                  {ok, _, _} -> ok;
-                                  {ok, _} -> ok
-                              end
-                          end,
-                          Queries),
-            %% Prepare
-            Prepare = proplists:get_value(prepare, Options, []),
-            lists:foreach(fun ({Name, Stmt}) ->
-                              {ok, Name} = mysql:prepare(Pid, Name, Stmt)
-                          end,
-                          Prepare);
+            execute_after_connect(Pid,
+                                  proplists:get_value(queries, Options, []),
+                                  proplists:get_value(prepare, Options, []));
         _ -> ok
         _ -> ok
     end,
     end,
     Ret.
     Ret.
@@ -570,6 +558,12 @@ execute_transaction(Conn, Fun, Args, Retries) ->
             %% Returning 'atomic' or 'aborted' would both be wrong. Raise an
             %% Returning 'atomic' or 'aborted' would both be wrong. Raise an
             %% exception is the best we can do.
             %% exception is the best we can do.
             erlang:raise(error, E, ?GET_STACK(Stacktrace));
             erlang:raise(error, E, ?GET_STACK(Stacktrace));
+        ?EXCEPTION(error, change_user_in_transaction = E, Stacktrace) ->
+            %% The called tried to change user inside the transaction, which
+            %% is not allowed and a serious mistake. We roll back and raise
+            %% an error.
+            ok = gen_server:call(Conn, rollback, infinity),
+            erlang:raise(error, E, ?GET_STACK(Stacktrace));
         ?EXCEPTION(Class, Reason, Stacktrace) ->
         ?EXCEPTION(Class, Reason, Stacktrace) ->
             %% We must be able to rollback. Otherwise let's crash.
             %% We must be able to rollback. Otherwise let's crash.
             ok = gen_server:call(Conn, rollback, infinity),
             ok = gen_server:call(Conn, rollback, infinity),
@@ -582,6 +576,62 @@ execute_transaction(Conn, Fun, Args, Retries) ->
             {aborted, Aborted}
             {aborted, Aborted}
     end.
     end.
 
 
+%% @doc Equivalent to `change_user(Conn, Username, Password, [])'.
+%% @see change_user/4
+-spec change_user(Conn, Username, Password) -> Result
+    when Conn :: connection(),
+         Username :: iodata(),
+         Password :: iodata(),
+         Result :: ok.
+change_user(Conn, Username, Password) ->
+    change_user(Conn, Username, Password, []).
+
+%% @doc Changes the user of the active connection without closing and
+%% and re-opening it. The currently active session will be reset (ie,
+%% user variables, temporary tables, prepared statements, etc will
+%% be lost) independent of whether the operation succeeds or fails.
+%%
+%% If change user is called when a transaction is active (ie, neither
+%% committed nor rolled back), calling `change_user' will fail with
+%% an error exception and `change_user_in_transaction' as the error
+%% message.
+%%
+%% If the change user operation fails for other reasons (eg authentication
+%% failure), an error exception occurs, and the connection process
+%% exits with reason `change_user_failed'. The connection can not be used
+%% any longer if this happens.
+%%
+%% For a description of the `database', `queries' and `prepare'
+%% options, see `start_link/1'.
+%%
+%% @see start_link/1
+-spec change_user(Conn, Username, Password, Options) -> Result
+    when Conn :: connection(),
+         Username :: iodata(),
+         Password :: iodata(),
+         Options :: [Option],
+         Result :: ok,
+         Option :: {database, iodata()}
+                 | {queries, [iodata()]}
+                 | {prepare, [NamedStatement]},
+         NamedStatement :: {StatementName :: atom(), Statement :: iodata()}.
+change_user(Conn, Username, Password, Options) ->
+    case in_transaction(Conn) of
+        true -> error(change_user_in_transaction);
+        false -> ok
+    end,
+    Database = proplists:get_value(database, Options, undefined),
+    Ret = gen_server:call(Conn, {change_user, Username, Password, Database}),
+    case Ret of
+        ok ->
+            execute_after_connect(Conn,
+                                  proplists:get_value(queries, Options, []),
+                                  proplists:get_value(prepare, Options, [])),
+            ok;
+        {error, Reason} ->
+            error(Reason)
+    end.
+
 %% @doc Encodes a term as a MySQL literal so that it can be used to inside a
 %% @doc Encodes a term as a MySQL literal so that it can be used to inside a
 %% query. If backslash escapes are enabled, backslashes and single quotes in
 %% query. If backslash escapes are enabled, backslashes and single quotes in
 %% strings and binaries are escaped. Otherwise only single quotes are escaped.
 %% strings and binaries are escaped. Otherwise only single quotes are escaped.
@@ -602,6 +652,25 @@ encode(Conn, Term) ->
 
 
 %% --- Helpers ---
 %% --- Helpers ---
 
 
+%% @doc Executes the given queries and prepares the given statements after a
+%% connection has been made.
+-spec execute_after_connect(connection(), [iodata()], [{atom(), iodata()}])
+    -> ok.
+execute_after_connect(Conn, Queries, Prepares) ->
+    lists:foreach(fun (Query) ->
+                      case query(Conn, Query) of
+                          ok -> ok;
+                          {ok, _} -> ok;
+                          {ok, _, _} -> ok
+                      end
+                  end,
+                  Queries),
+    lists:foreach(fun ({Name, Stmt}) ->
+                      {ok, Name} = prepare(Conn, Name, Stmt)
+                  end,
+                  Prepares),
+    ok.
+
 %% @doc Makes a gen_server call for a query (plain, parametrized or prepared),
 %% @doc Makes a gen_server call for a query (plain, parametrized or prepared),
 %% checks the reply and sometimes throws an exception when we need to jump out
 %% checks the reply and sometimes throws an exception when we need to jump out
 %% of a transaction.
 %% of a transaction.

+ 26 - 4
src/mysql_conn.erl

@@ -50,7 +50,7 @@
 
 
 %% Gen_server state
 %% Gen_server state
 -record(state, {server_version, connection_id, socket, sockmod, ssl_opts,
 -record(state, {server_version, connection_id, socket, sockmod, ssl_opts,
-                host, port, user, password, log_warnings,
+                host, port, user, password, auth_plugin_data, log_warnings,
                 ping_timeout,
                 ping_timeout,
                 query_timeout, query_cache_time,
                 query_timeout, query_cache_time,
                 affected_rows = 0, status = 0, warning_count = 0, insert_id = 0,
                 affected_rows = 0, status = 0, warning_count = 0, insert_id = 0,
@@ -93,13 +93,16 @@ init(Opts) ->
         {ok, Handshake, SockMod, Socket} ->
         {ok, Handshake, SockMod, Socket} ->
             setopts(SockMod, Socket, [{active, once}]),
             setopts(SockMod, Socket, [{active, once}]),
             #handshake{server_version = Version, connection_id = ConnId,
             #handshake{server_version = Version, connection_id = ConnId,
-                       status = Status} = Handshake,
+                       status = Status,
+                       auth_plugin_data = AuthPluginData} = Handshake,
             State = #state{server_version = Version, connection_id = ConnId,
             State = #state{server_version = Version, connection_id = ConnId,
                            sockmod = SockMod,
                            sockmod = SockMod,
                            socket = Socket,
                            socket = Socket,
                            ssl_opts = SSLOpts,
                            ssl_opts = SSLOpts,
-                           host = Host, port = Port, user = User,
-                           password = Password, status = Status,
+                           host = Host, port = Port,
+                           user = User, password = Password,
+                           auth_plugin_data = AuthPluginData,
+                           status = Status,
                            log_warnings = LogWarn,
                            log_warnings = LogWarn,
                            ping_timeout = PingTimeout,
                            ping_timeout = PingTimeout,
                            query_timeout = Timeout,
                            query_timeout = Timeout,
@@ -272,6 +275,25 @@ handle_call({unprepare, Stmt}, _From, State) when is_atom(Stmt);
         error ->
         error ->
             {reply, {error, not_prepared}, State}
             {reply, {error, not_prepared}, State}
     end;
     end;
+handle_call({change_user, Username, Password, Database}, From,
+            State = #state{transaction_levels = []}) ->
+    #state{socket = Socket, sockmod = SockMod,
+           auth_plugin_data = AuthPluginData} = State,
+    setopts(SockMod, Socket, [{active, false}]),
+    Result = mysql_protocol:change_user(SockMod, Socket, Username, Password,
+                                        AuthPluginData, Database),
+    setopts(SockMod, Socket, [{active, once}]),
+    State1 = update_state(Result, State),
+    State1#state.warning_count > 0 andalso State1#state.log_warnings
+        andalso log_warnings(State1, "CHANGE USER"),
+    State2 = State1#state{query_cache = empty, stmts = dict:new()},
+    case Result of
+        #ok{} ->
+            {reply, ok, State2#state{user = Username, password = Password}};
+        #error{} = E ->
+            gen_server:reply(From, {error, error_to_reason(E)}),
+            stop_server(change_user_failed, State2)
+    end;
 handle_call(warning_count, _From, State) ->
 handle_call(warning_count, _From, State) ->
     {reply, State#state.warning_count, State};
     {reply, State#state.warning_count, State};
 handle_call(insert_id, _From, State) ->
 handle_call(insert_id, _From, State) ->

+ 29 - 1
src/mysql_protocol.erl

@@ -27,7 +27,7 @@
 %% @private
 %% @private
 -module(mysql_protocol).
 -module(mysql_protocol).
 
 
--export([handshake/7, quit/2, ping/2,
+-export([handshake/7, change_user/6, quit/2, ping/2,
          query/4, query/5, fetch_query_response/3,
          query/4, query/5, fetch_query_response/3,
          fetch_query_response/4, prepare/3, unprepare/3,
          fetch_query_response/4, prepare/3, unprepare/3,
          execute/5, execute/6, fetch_execute_response/3,
          execute/5, execute/6, fetch_execute_response/3,
@@ -228,6 +228,27 @@ fetch_execute_response(SockModule, Socket, Timeout) ->
 fetch_execute_response(SockModule, Socket, FilterMap, Timeout) ->
 fetch_execute_response(SockModule, Socket, FilterMap, Timeout) ->
     fetch_response(SockModule, Socket, Timeout, binary, FilterMap, []).
     fetch_response(SockModule, Socket, Timeout, binary, FilterMap, []).
 
 
+%% @doc Changes the user of the connection.
+-spec change_user(atom(), term(), iodata(), iodata(), binary(),
+                  undefined | iodata()) -> #ok{} | #error{}.
+change_user(SockMod, Socket, Username, Password, Salt, Database) ->
+    DbBin = case Database of
+        undefined -> <<>>;
+        _ -> iolist_to_binary(Database)
+    end,
+    Hash = hash_password(Password, Salt),
+    Req = <<?COM_CHANGE_USER, (iolist_to_binary(Username))/binary, 0,
+            (lenenc_str_encode(Hash))/binary,
+            DbBin/binary, 0, ?UTF8:16/little>>,
+    {ok, _SeqNum1} = send_packet(SockMod, Socket, Req, 0),
+    {ok, Packet, _SeqNum2} = recv_packet(SockMod, Socket, infinity, any),
+    case Packet of
+        ?ok_pattern ->
+            parse_ok_packet(Packet);
+        ?error_pattern ->
+            parse_error_packet(Packet)
+    end.
+
 %% --- internal ---
 %% --- internal ---
 
 
 %% @doc Parses a handshake. This is the first thing that comes from the server
 %% @doc Parses a handshake. This is the first thing that comes from the server
@@ -1203,6 +1224,13 @@ lenenc_str(Bin) ->
     <<String:Length/binary, Rest1/binary>> = Rest,
     <<String:Length/binary, Rest1/binary>> = Rest,
     {String, Rest1}.
     {String, Rest1}.
 
 
+%% Length-encoded-string encode. Prefixes the value with a
+%% length-encoded-integer denoting its size.
+-spec lenenc_str_encode(Input :: binary()) -> binary().
+lenenc_str_encode(Bin) ->
+    Length = byte_size(Bin),
+    <<(lenenc_int_encode(Length))/binary, Bin:Length/binary>>.
+
 %% nts/1 decodes a nul-terminated string
 %% nts/1 decodes a nul-terminated string
 -spec nulterm_str(Input :: binary()) -> {String :: binary(), Rest :: binary()}.
 -spec nulterm_str(Input :: binary()) -> {String :: binary(), Rest :: binary()}.
 nulterm_str(Bin) ->
 nulterm_str(Bin) ->

+ 153 - 0
test/mysql_change_user_tests.erl

@@ -0,0 +1,153 @@
+%% MySQL/OTP – MySQL client library for Erlang/OTP
+%% Copyright (C) 2019 Jan Uhlig
+%%
+%% This file is part of MySQL/OTP.
+%%
+%% MySQL/OTP is free software: you can redistribute it and/or modify it under
+%% the terms of the GNU Lesser General Public License as published by the Free
+%% Software Foundation, either version 3 of the License, or (at your option)
+%% any later version.
+%%
+%% This program is distributed in the hope that it will be useful, but WITHOUT
+%% ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+%% FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+%% more details.
+%%
+%% You should have received a copy of the GNU Lesser General Public License
+%% along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+%% @doc This module performs test to an actual database.
+-module(mysql_change_user_tests).
+
+-include_lib("eunit/include/eunit.hrl").
+
+-define(user1,     "otptest").
+-define(password1, "otptest").
+-define(user2,     "otptest2").
+-define(password2, "otptest2").
+
+%% Ensure that the current user can be changed to another user
+%% when given correct credentials.
+correct_credentials_test() ->
+    Pid = connect_db(?user1),
+    ?assertEqual(ok, mysql:change_user(Pid, ?user2, ?password2)),
+    ?assert(is_current_user(Pid, ?user2)),
+    close_conn(Pid),
+    ok.
+
+%% Ensure that change user fails when given incorrect credentials,
+%% and that the current user still works.
+incorrect_credentials_fail_test() ->
+    Pid = connect_db(?user1),
+    TrapExit = erlang:process_flag(trap_exit, true),
+    ?assertError({1045, <<"28000">>, <<"Access denied", _/binary>>},
+                 mysql:change_user(Pid, ?user2, ?password1)),
+    ExitReason = receive {'EXIT', Pid, Reason} -> Reason after 1000 -> error(timeout) end,
+    erlang:process_flag(trap_exit, TrapExit),
+    ?assertEqual(change_user_failed, ExitReason),
+    close_conn(Pid),
+    ok.
+
+%% Ensure that user variables are reset after a successful change user
+%% operation.
+reset_variables_test() ->
+    Pid = connect_db(?user1),
+    ok = mysql:query(Pid, <<"SET @foo=123">>),
+    ?assertEqual(ok, mysql:change_user(Pid, ?user2, ?password2)),
+    ?assert(is_current_user(Pid, ?user2)),
+    ?assertEqual({ok,
+                  [<<"@foo">>],
+                  [[null]]},
+                 mysql:query(Pid, <<"SELECT @foo">>)),
+    close_conn(Pid),
+    ok.
+
+%% Ensure that temporary tables are reset after a successful change user
+%% operation.
+reset_temptables_test() ->
+    Pid = connect_db(?user1),
+    ok = mysql:query(Pid, <<"CREATE DATABASE IF NOT EXISTS otptest">>),
+    ok = mysql:query(Pid, <<"CREATE TEMPORARY TABLE otptest.foo (bar INT)">>),
+    ?assertEqual(ok, mysql:change_user(Pid, ?user2, ?password2)),
+    ?assert(is_current_user(Pid, ?user2)),
+    ?assertMatch({error,
+                  {1146, <<"42S02">>, _}},
+                 mysql:query(Pid, <<"SELECT * FROM otptest.foo">>)),
+    ok = mysql:query(Pid, <<"DROP DATABASE IF EXISTS otptest">>),
+    close_conn(Pid),
+    ok.
+
+%% Ensure that change user fails when inside an unmanaged transaction.
+fail_in_unmanaged_transaction_test() ->
+    Pid = connect_db(?user1),
+    ok = mysql:query(Pid, <<"BEGIN">>),
+    ?assert(mysql:in_transaction(Pid)),
+    ?assertError(change_user_in_transaction,
+                 mysql:change_user(Pid, ?user2, ?password2)),
+    ?assert(is_current_user(Pid, ?user1)),
+    ?assert(mysql:in_transaction(Pid)),
+    close_conn(Pid),
+    ok.
+
+%% Ensure that change user fails when inside a managed transaction.
+fail_in_managed_transaction_test() ->
+    Pid = connect_db(?user1),
+    ?assertError(change_user_in_transaction,
+                 mysql:transaction(Pid,
+                                   fun () -> mysql:change_user(Pid,
+                                                               ?user2,
+                                                               ?password2)
+                                   end)),
+    ?assert(is_current_user(Pid, ?user1)),
+    close_conn(Pid),
+    ok.
+
+with_db_test() ->
+    Pid = connect_db(?user1),
+    ok = mysql:query(Pid, <<"CREATE DATABASE IF NOT EXISTS otptest">>),
+    ?assertEqual(ok, mysql:change_user(Pid, ?user2, ?password2, [{database, <<"otptest">>}])),
+    ?assert(is_current_user(Pid, ?user2)),
+    ?assertEqual({ok,
+                  [<<"DATABASE()">>],
+                  [[<<"otptest">>]]},
+                 mysql:query(Pid, <<"SELECT DATABASE()">>)),
+    ok = mysql:query(Pid, <<"DROP DATABASE IF EXISTS otptest">>),
+    close_conn(Pid),
+    ok.
+
+execute_queries_test() ->
+    Pid = connect_db(?user1),
+    ?assertEqual(ok, mysql:change_user(Pid, ?user2, ?password2, [{queries, [<<"SET @foo=123">>]}])),
+    ?assert(is_current_user(Pid, ?user2)),
+    ?assertEqual({ok,
+                  [<<"@foo">>],
+                  [[123]]},
+                 mysql:query(Pid, <<"SELECT @foo">>)),
+    close_conn(Pid),
+    ok.
+
+prepare_statements_test() ->
+    Pid = connect_db(?user1),
+    ?assertEqual(ok, mysql:change_user(Pid, ?user2, ?password2, [{prepare, [{foo, <<"SELECT ? AS foo">>}]}])),
+    ?assert(is_current_user(Pid, ?user2)),
+    ?assertEqual({ok,
+                  [<<"foo">>],
+                  [[123]]},
+                 mysql:execute(Pid, foo, [123])),
+    close_conn(Pid),
+    ok.
+
+
+connect_db(User) ->
+    {ok, Pid} = mysql:start_link([{user, User}, {password, ?password1},
+                                  {log_warnings, false}]),
+    Pid.
+
+close_conn(Pid) ->
+    exit(Pid, normal).
+
+is_current_user(Pid, User) when is_binary(User) ->
+    {ok, [<<"CURRENT_USER()">>], [[CurUser]]}=mysql:query(Pid, <<"SELECT CURRENT_USER()">>),
+    <<User/binary, "@localhost">> =:= CurUser;
+is_current_user(Pid, User) ->
+    is_current_user(Pid, iolist_to_binary(User)).