Browse Source

Add SSL connection support

Piotr Nosek 7 years ago
parent
commit
8d322d48e4
13 changed files with 440 additions and 176 deletions
  1. 4 0
      .gitignore
  2. 9 0
      .travis.yml
  3. 3 1
      Makefile
  4. 16 4
      README.md
  5. 3 0
      include/protocol.hrl
  6. 91 75
      src/mysql.erl
  7. 135 84
      src/mysql_protocol.erl
  8. 64 0
      src/mysql_sock_ssl.erl
  9. 42 0
      src/mysql_sock_tcp.erl
  10. 3 1
      test/mysql_protocol_tests.erl
  11. 39 11
      test/mysql_tests.erl
  12. 27 0
      test/ssl/Makefile
  13. 4 0
      test/ssl/my-ssl.cnf.template

+ 4 - 0
.gitignore

@@ -16,3 +16,7 @@ cover
 *.coverdata
 tests.output
 test/ct.cover.spec
+test/ssl/ca*
+test/ssl/server*
+test/ssl/my-ssl.cnf
+test/ssl/my-ssl.cnf-e

+ 9 - 0
.travis.yml

@@ -2,7 +2,16 @@ language: erlang
 services:
   - mysql
 before_script:
+  - sudo service mysql stop
+  - SSLDIR=/etc/mysql/ make tests-prep
+  - sudo cp test/ssl/*.pem /etc/mysql/
+  - sudo chmod -R 660 /etc/mysql/*.pem
+  - sudo chown -R mysql:mysql /etc/mysql/*.pem
+  - cat test/ssl/my-ssl.cnf | sudo tee -a /etc/mysql/conf.d/my-ssl.cnf
+  - sudo service mysql start
+  - sleep 5
   - mysql -uroot -e "grant all privileges on otptest.* to otptest@localhost identified by 'otptest'"
+  - mysql -uroot -e "grant all privileges on otptestssl.* to otptestssl@localhost identified by 'otptestssl' require ssl"
 script: 'make tests'
 otp_release:
   - 19.0

+ 3 - 1
Makefile

@@ -8,6 +8,7 @@
 #  - tests-report:   Creates doc/eunit.html with the coverage and eunit output.
 #  - gh-pages:       Generates docs and eunit reports and commits these in the
 #                    gh-pages which Github publishes automatically when pushed.
+.PHONY: gh-pages tests-report tests-prep CHANGELOG.md
 
 PROJECT = mysql
 EDOC_OPTS = {stylesheet_file,"priv/edoc-style.css"},{todo,true}
@@ -16,7 +17,8 @@ SHELL_PATH = -pa ebin
 
 include erlang.mk
 
-.PHONY: gh-pages tests-report CHANGELOG.md
+tests-prep:
+	cd test/ssl && $(MAKE)
 
 CHANGELOG.md:
 	./changelog.sh > $@

+ 16 - 4
README.md

@@ -32,9 +32,10 @@ Synopsis
 --------
 
 ```Erlang
-%% Connect
+%% Connect (ssl option is not mandatory)
 {ok, Pid} = mysql:start_link([{host, "localhost"}, {user, "foo"},
-                              {password, "hello"}, {database, "test"}]),
+                              {password, "hello"}, {database, "test"},
+                              {ssl, [{cacertfile, "/path/to/ca.pem"}]}]),
 
 %% Select
 {ok, ColumnNames, Rows} =
@@ -88,13 +89,24 @@ Using *rebar*:
 Contributing
 ------------
 
-Run the eunit tests with `make tests`. For the suite `mysql_tests` you
-need MySQL running on localhost and give privileges to the `otptest` user:
+Before running the tests you'll need to generate SSL files and MySQL extra config file.
+In order to do so, please execute `make tests-prep`.
+
+The MySQL server configuration must include `my-ssl.cnf` file,
+which can be found in `test/ssl/`.
+**Do not run** `make tests-prep` after you start MySQL,
+because CA certificates will no longer match.
+
+For the suite `mysql_tests` you need to start MySQL on localhost and give
+privileges to the `otptest` and `otptestssl` users:
 
 ```SQL
 grant all privileges on otptest.* to otptest@localhost identified by 'otptest';
+grant all privileges on otptest.* to otptestssl@localhost identified by 'otptestssl' require ssl;
 ```
 
+EUnit tests are executed with `make tests`.
+
 If you run `make tests COVER=1` a coverage report will be generated. Open
 `cover/index.html` to see that any lines you have added or modified are covered
 by a test.

+ 3 - 0
include/protocol.hrl

@@ -37,6 +37,9 @@
 %% Client: uses the 4.1 protocol
 -define(CLIENT_PROTOCOL_41, 16#00000200).
 
+%% Client: supports SSL
+-define(CLIENT_SSL, 16#00000800).
+
 %% Server: can send status flags in EOF_Packet
 %% Client: expects status flags in EOF_Packet
 -define(CLIENT_TRANSACTIONS, 16#00002000).

+ 91 - 75
src/mysql.erl

@@ -1,6 +1,7 @@
 %% MySQL/OTP – MySQL client library for Erlang/OTP
 %% Copyright (C) 2014-2015 Viktor Söderqvist,
 %%               2016 Johan Lövdahl
+%%               2017 Piotr Nosek, Michal Slaski
 %%
 %% This file is part of MySQL/OTP.
 %%
@@ -457,7 +458,7 @@ encode(Conn, Term) ->
 -include("server_status.hrl").
 
 %% Gen_server state
--record(state, {server_version, connection_id, socket,
+-record(state, {server_version, connection_id, socket, sockmod, ssl_opts,
                 host, port, user, password, log_warnings,
                 ping_timeout,
                 query_timeout, query_cache_time,
@@ -481,6 +482,8 @@ init(Opts) ->
                                          ?default_query_cache_time),
     TcpOpts        = proplists:get_value(tcp_options, Opts, []),
     SetFoundRows   = proplists:get_value(found_rows, Opts, false),
+    SSLOpts        = proplists:get_value(ssl, Opts, undefined),
+    SockMod0       = mysql_sock_tcp,
 
     PingTimeout = case KeepAlive of
         true         -> ?default_ping_timeout;
@@ -489,19 +492,21 @@ init(Opts) ->
     end,
 
     %% Connect socket
-    SockOpts = [binary, {packet, raw},{active, false} | TcpOpts],
-    {ok, Socket} = gen_tcp:connect(Host, Port, SockOpts),
+    SockOpts = [binary, {packet, raw}, {active, false} | TcpOpts],
+    {ok, Socket0} = SockMod0:connect(Host, Port, SockOpts),
 
     %% Exchange handshake communication.
-    inet:setopts(Socket, [{active, false}]),
-    Result = mysql_protocol:handshake(User, Password, Database, gen_tcp,
-                                      Socket, SetFoundRows),
-    inet:setopts(Socket, [{active, once}]),
+    Result = mysql_protocol:handshake(User, Password, Database, SockMod0, SSLOpts,
+                                      Socket0, SetFoundRows),
     case Result of
-        #handshake{server_version = Version, connection_id = ConnId,
-                   status = Status} ->
+        {ok, Handshake, SockMod, Socket} ->
+            SockMod:setopts(Socket, [{active, once}]),
+            #handshake{server_version = Version, connection_id = ConnId,
+                       status = Status} = Handshake,
             State = #state{server_version = Version, connection_id = ConnId,
+                           sockmod = SockMod,
                            socket = Socket,
+                           ssl_opts = SSLOpts,
                            host = Host, port = Port, user = User,
                            password = Password, status = Status,
                            log_warnings = LogWarn,
@@ -565,12 +570,13 @@ init(Opts) ->
 handle_call({query, Query}, From, State) ->
     handle_call({query, Query, State#state.query_timeout}, From, State);
 handle_call({query, Query, Timeout}, _From, State) ->
+    SockMod = State#state.sockmod,
     Socket = State#state.socket,
-    inet:setopts(Socket, [{active, false}]),
-    {ok, Recs} = case mysql_protocol:query(Query, gen_tcp, Socket, Timeout) of
+    SockMod:setopts(Socket, [{active, false}]),
+    {ok, Recs} = case mysql_protocol:query(Query, SockMod, Socket, Timeout) of
         {error, timeout} when State#state.server_version >= [5, 0, 0] ->
             kill_query(State),
-            mysql_protocol:fetch_query_response(gen_tcp, Socket, ?cmd_timeout);
+            mysql_protocol:fetch_query_response(SockMod, Socket, ?cmd_timeout);
         {error, timeout} ->
             %% For MySQL 4.x.x there is no way to recover from timeout except
             %% killing the connection itself.
@@ -578,7 +584,7 @@ handle_call({query, Query, Timeout}, _From, State) ->
         QueryResult ->
             QueryResult
     end,
-    inet:setopts(Socket, [{active, once}]),
+    SockMod:setopts(Socket, [{active, once}]),
     State1 = lists:foldl(fun update_state/2, State, Recs),
     State1#state.warning_count > 0 andalso State1#state.log_warnings
         andalso log_warnings(State1, Query),
@@ -589,7 +595,7 @@ handle_call({param_query, Query, Params}, From, State) ->
 handle_call({param_query, Query, Params, Timeout}, _From, State) ->
     %% Parametrized query: Prepared statement cached with the query as the key
     QueryBin = iolist_to_binary(Query),
-    #state{socket = Socket} = State,
+    #state{socket = Socket, sockmod = SockMod} = State,
     Cache = State#state.query_cache,
     {StmtResult, Cache1} = case mysql_cache:lookup(QueryBin, Cache) of
         {found, FoundStmt, NewCache} ->
@@ -597,9 +603,10 @@ handle_call({param_query, Query, Params, Timeout}, _From, State) ->
             {{ok, FoundStmt}, NewCache};
         not_found ->
             %% Prepare
-            inet:setopts(Socket, [{active, false}]),
-            Rec = mysql_protocol:prepare(Query, gen_tcp, Socket),
-            inet:setopts(Socket, [{active, once}]),
+            SockMod:setopts(Socket, [{active, false}]),
+	    SockMod = State#state.sockmod,
+            Rec = mysql_protocol:prepare(Query, SockMod, Socket),
+            SockMod:setopts(Socket, [{active, once}]),
             %State1 = update_state(Rec, State),
             case Rec of
                 #error{} = E ->
@@ -630,10 +637,11 @@ handle_call({execute, Stmt, Args, Timeout}, _From, State) ->
             {reply, {error, not_prepared}, State}
     end;
 handle_call({prepare, Query}, _From, State) ->
-    #state{socket = Socket} = State,
-    inet:setopts(Socket, [{active, false}]),
-    Rec = mysql_protocol:prepare(Query, gen_tcp, Socket),
-    inet:setopts(Socket, [{active, once}]),
+    #state{socket = Socket, sockmod = SockMod} = State,
+    SockMod:setopts(Socket, [{active, false}]),
+    SockMod = State#state.sockmod,
+    Rec = mysql_protocol:prepare(Query, SockMod, Socket),
+    SockMod:setopts(Socket, [{active, once}]),
     State1 = update_state(Rec, State),
     case Rec of
         #error{} = E ->
@@ -644,18 +652,19 @@ handle_call({prepare, Query}, _From, State) ->
             {reply, {ok, Id}, State2}
     end;
 handle_call({prepare, Name, Query}, _From, State) when is_atom(Name) ->
-    #state{socket = Socket} = State,
+    #state{socket = Socket, sockmod = SockMod} = State,
     %% First unprepare if there is an old statement with this name.
-    inet:setopts(Socket, [{active, false}]),
+    SockMod:setopts(Socket, [{active, false}]),
+    SockMod = State#state.sockmod,
     State1 = case dict:find(Name, State#state.stmts) of
         {ok, OldStmt} ->
-            mysql_protocol:unprepare(OldStmt, gen_tcp, Socket),
+            mysql_protocol:unprepare(OldStmt, SockMod, Socket),
             State#state{stmts = dict:erase(Name, State#state.stmts)};
         error ->
             State
     end,
-    Rec = mysql_protocol:prepare(Query, gen_tcp, Socket),
-    inet:setopts(Socket, [{active, once}]),
+    Rec = mysql_protocol:prepare(Query, SockMod, Socket),
+    SockMod:setopts(Socket, [{active, once}]),
     State2 = update_state(Rec, State1),
     case Rec of
         #error{} = E ->
@@ -669,10 +678,11 @@ handle_call({unprepare, Stmt}, _From, State) when is_atom(Stmt);
                                                   is_integer(Stmt) ->
     case dict:find(Stmt, State#state.stmts) of
         {ok, StmtRec} ->
-            #state{socket = Socket} = State,
-            inet:setopts(Socket, [{active, false}]),
-            mysql_protocol:unprepare(StmtRec, gen_tcp, Socket),
-            inet:setopts(Socket, [{active, once}]),
+            #state{socket = Socket, sockmod = SockMod} = State,
+            SockMod:setopts(Socket, [{active, false}]),
+            SockMod = State#state.sockmod,
+            mysql_protocol:unprepare(StmtRec, SockMod, Socket),
+            SockMod:setopts(Socket, [{active, once}]),
             State1 = State#state{stmts = dict:erase(Stmt, State#state.stmts)},
             State2 = schedule_ping(State1),
             {reply, ok, State2};
@@ -692,44 +702,47 @@ handle_call(backslash_escapes_enabled, _From, State = #state{status = S}) ->
 handle_call(in_transaction, _From, State) ->
     {reply, State#state.status band ?SERVER_STATUS_IN_TRANS /= 0, State};
 handle_call(start_transaction, _From,
-            State = #state{socket = Socket, transaction_level = L,
-                           status = Status})
+            State = #state{socket = Socket, sockmod = SockMod,
+                           transaction_level = L, status = Status})
   when Status band ?SERVER_STATUS_IN_TRANS == 0, L == 0;
        Status band ?SERVER_STATUS_IN_TRANS /= 0, L > 0 ->
     Query = case L of
         0 -> <<"BEGIN">>;
         _ -> <<"SAVEPOINT s", (integer_to_binary(L))/binary>>
     end,
-    inet:setopts(Socket, [{active, false}]),
-    {ok, [Res = #ok{}]} = mysql_protocol:query(Query, gen_tcp, Socket,
+    SockMod:setopts(Socket, [{active, false}]),
+    SockMod = State#state.sockmod,
+    {ok, [Res = #ok{}]} = mysql_protocol:query(Query, SockMod, Socket,
                                                ?cmd_timeout),
-    inet:setopts(Socket, [{active, once}]),
+    SockMod:setopts(Socket, [{active, once}]),
     State1 = update_state(Res, State),
     {reply, ok, State1#state{transaction_level = L + 1}};
-handle_call(rollback, _From, State = #state{socket = Socket, status = Status,
-                                            transaction_level = L})
+handle_call(rollback, _From, State = #state{socket = Socket, sockmod = SockMod,
+                                            status = Status, transaction_level = L})
   when Status band ?SERVER_STATUS_IN_TRANS /= 0, L >= 1 ->
     Query = case L of
         1 -> <<"ROLLBACK">>;
         _ -> <<"ROLLBACK TO s", (integer_to_binary(L - 1))/binary>>
     end,
-    inet:setopts(Socket, [{active, false}]),
-    {ok, [Res = #ok{}]} = mysql_protocol:query(Query, gen_tcp, Socket,
+    SockMod:setopts(Socket, [{active, false}]),
+    SockMod = State#state.sockmod,
+    {ok, [Res = #ok{}]} = mysql_protocol:query(Query, SockMod, Socket,
                                                ?cmd_timeout),
-    inet:setopts(Socket, [{active, once}]),
+    SockMod:setopts(Socket, [{active, once}]),
     State1 = update_state(Res, State),
     {reply, ok, State1#state{transaction_level = L - 1}};
-handle_call(commit, _From, State = #state{socket = Socket, status = Status,
-                                          transaction_level = L})
+handle_call(commit, _From, State = #state{socket = Socket, sockmod = SockMod,
+                                          status = Status, transaction_level = L})
   when Status band ?SERVER_STATUS_IN_TRANS /= 0, L >= 1 ->
     Query = case L of
         1 -> <<"COMMIT">>;
         _ -> <<"RELEASE SAVEPOINT s", (integer_to_binary(L - 1))/binary>>
     end,
-    inet:setopts(Socket, [{active, false}]),
-    {ok, [Res = #ok{}]} = mysql_protocol:query(Query, gen_tcp, Socket,
+    SockMod:setopts(Socket, [{active, false}]),
+    SockMod = State#state.sockmod,
+    {ok, [Res = #ok{}]} = mysql_protocol:query(Query, SockMod, Socket,
                                                ?cmd_timeout),
-    inet:setopts(Socket, [{active, once}]),
+    SockMod:setopts(Socket, [{active, once}]),
     State1 = update_state(Res, State),
     {reply, ok, State1#state{transaction_level = L - 1}}.
 
@@ -743,21 +756,23 @@ handle_info(query_cache, #state{query_cache = Cache,
     %% Evict expired queries/statements in the cache used by query/3.
     {Evicted, Cache1} = mysql_cache:evict_older_than(Cache, CacheTime),
     %% Unprepare the evicted statements
-    #state{socket = Socket} = State,
-    inet:setopts(Socket, [{active, false}]),
+    #state{socket = Socket, sockmod = SockMod} = State,
+    SockMod:setopts(Socket, [{active, false}]),
+    SockMod = State#state.sockmod,
     lists:foreach(fun ({_Query, Stmt}) ->
-                      mysql_protocol:unprepare(Stmt, gen_tcp, Socket)
+                      mysql_protocol:unprepare(Stmt, SockMod, Socket)
                   end,
                   Evicted),
-    inet:setopts(Socket, [{active, once}]),
+    SockMod:setopts(Socket, [{active, once}]),
     %% If nonempty, schedule eviction again.
     mysql_cache:size(Cache1) > 0 andalso
         erlang:send_after(CacheTime, self(), query_cache),
     {noreply, State#state{query_cache = Cache1}};
-handle_info(ping, #state{socket = Socket} = State) ->
-    inet:setopts(Socket, [{active, false}]),
-    Ok = mysql_protocol:ping(gen_tcp, Socket),
-    inet:setopts(Socket, [{active, once}]),
+handle_info(ping, #state{socket = Socket, sockmod = SockMod} = State) ->
+    SockMod:setopts(Socket, [{active, false}]),
+    SockMod = State#state.sockmod,
+    Ok = mysql_protocol:ping(SockMod, Socket),
+    SockMod:setopts(Socket, [{active, once}]),
     {noreply, update_state(Ok, State)};
 handle_info({tcp_closed, _Socket}, State) ->
     stop_server(tcp_closed, State);
@@ -767,12 +782,12 @@ handle_info(_Info, State) ->
     {noreply, State}.
 
 %% @private
-terminate(Reason, #state{socket = Socket})
+terminate(Reason, #state{socket = Socket, sockmod = SockMod})
   when Reason == normal; Reason == shutdown ->
       %% Send the goodbye message for politeness.
-      inet:setopts(Socket, [{active, false}]),
-      R = mysql_protocol:quit(gen_tcp, Socket),
-      inet:setopts(Socket, [{active, once}]),
+      SockMod:setopts(Socket, [{active, false}]),
+      R = mysql_protocol:quit(SockMod, Socket),
+      SockMod:setopts(Socket, [{active, once}]),
       R;
 terminate(_Reason, _State) ->
     ok.
@@ -799,13 +814,14 @@ query_call(Conn, CallReq) ->
     end.
 
 %% @doc Executes a prepared statement and returns {Reply, NextState}.
-execute_stmt(Stmt, Args, Timeout, State = #state{socket = Socket}) ->
-    inet:setopts(Socket, [{active, false}]),
-    {ok, Recs} = case mysql_protocol:execute(Stmt, Args, gen_tcp, Socket,
+execute_stmt(Stmt, Args, Timeout, State = #state{socket = Socket, sockmod = SockMod}) ->
+    SockMod:setopts(Socket, [{active, false}]),
+    SockMod = State#state.sockmod,
+    {ok, Recs} = case mysql_protocol:execute(Stmt, Args, SockMod, Socket,
                                              Timeout) of
         {error, timeout} when State#state.server_version >= [5, 0, 0] ->
             kill_query(State),
-            mysql_protocol:fetch_execute_response(gen_tcp, Socket,
+            mysql_protocol:fetch_execute_response(SockMod, Socket,
                                                   ?cmd_timeout);
         {error, timeout} ->
             %% For MySQL 4.x.x there is no way to recover from timeout except
@@ -814,7 +830,7 @@ execute_stmt(Stmt, Args, Timeout, State = #state{socket = Socket}) ->
         QueryResult ->
             QueryResult
     end,
-    inet:setopts(Socket, [{active, once}]),
+    SockMod:setopts(Socket, [{active, once}]),
     State1 = lists:foldl(fun update_state/2, State, Recs),
     State1#state.warning_count > 0 andalso State1#state.log_warnings
         andalso log_warnings(State1, Stmt#prepared.orig_query),
@@ -893,12 +909,13 @@ clear_transaction_status(State = #state{status = Status}) ->
                 transaction_level = 0}.
 
 %% @doc Fetches and logs warnings. Query is the query that gave the warnings.
-log_warnings(#state{socket = Socket}, Query) ->
-    inet:setopts(Socket, [{active, false}]),
+log_warnings(#state{socket = Socket, sockmod = SockMod} = State, Query) ->
+    SockMod:setopts(Socket, [{active, false}]),
+    SockMod = State#state.sockmod,
     {ok, [#resultset{rows = Rows}]} = mysql_protocol:query(<<"SHOW WARNINGS">>,
-                                                           gen_tcp, Socket,
+                                                           SockMod, Socket,
                                                            ?cmd_timeout),
-    inet:setopts(Socket, [{active, once}]),
+    SockMod:setopts(Socket, [{active, once}]),
     Lines = [[Level, " ", integer_to_binary(Code), ": ", Message, "\n"]
              || [Level, Code, Message] <- Rows],
     error_logger:warning_msg("~s in ~s~n", [Lines, Query]).
@@ -906,23 +923,22 @@ log_warnings(#state{socket = Socket}, Query) ->
 %% @doc Makes a separate connection and execute KILL QUERY. We do this to get
 %% our main connection back to normal. KILL QUERY appeared in MySQL 5.0.0.
 kill_query(#state{connection_id = ConnId, host = Host, port = Port,
-                  user = User, password = Password,
+                  user = User, password = Password, ssl_opts = SSLOpts,
                   cap_found_rows = SetFoundRows}) ->
     %% Connect socket
     SockOpts = [{active, false}, binary, {packet, raw}],
-    {ok, Socket} = gen_tcp:connect(Host, Port, SockOpts),
+    {ok, Socket0} = mysql_sock_tcp:connect(Host, Port, SockOpts),
 
     %% Exchange handshake communication.
-    Result = mysql_protocol:handshake(User, Password, undefined, gen_tcp,
-                                      Socket, SetFoundRows),
+    Result = mysql_protocol:handshake(User, Password, undefined, mysql_sock_tcp,
+                                      SSLOpts, Socket0, SetFoundRows),
     case Result of
-        #handshake{} ->
+        {ok, #handshake{}, SockMod, Socket} ->
             %% Kill and disconnect
             IdBin = integer_to_binary(ConnId),
-            {ok, [#ok{}]} = mysql_protocol:query(<<"KILL QUERY ",
-                                                   IdBin/binary>>, gen_tcp,
-                                                 Socket, ?cmd_timeout),
-            mysql_protocol:quit(gen_tcp, Socket);
+            {ok, [#ok{}]} = mysql_protocol:query(<<"KILL QUERY ", IdBin/binary>>,
+                                                 SockMod, Socket, ?cmd_timeout),
+            mysql_protocol:quit(SockMod, Socket);
         #error{} = E ->
             error_logger:error_msg("Failed to connect to kill query: ~p",
                                    [error_to_reason(E)])

+ 135 - 84
src/mysql_protocol.erl

@@ -1,5 +1,6 @@
 %% MySQL/OTP – MySQL client library for Erlang/OTP
 %% Copyright (C) 2014 Viktor Söderqvist
+%%               2017 Piotr Nosek, Michal Slaski
 %%
 %% This file is part of MySQL/OTP.
 %%
@@ -26,7 +27,7 @@
 %% @private
 -module(mysql_protocol).
 
--export([handshake/6, quit/2, ping/2,
+-export([handshake/7, quit/2, ping/2,
          query/4, fetch_query_response/3,
          prepare/3, unprepare/3, execute/5, fetch_execute_response/3]).
 
@@ -45,23 +46,28 @@
 %% @doc Performs a handshake using the supplied functions for communication.
 %% Returns an ok or an error record. Raises errors when various unimplemented
 %% features are requested.
--spec handshake(iodata(), iodata(), iodata() | undefined, atom(),
-                term(), boolean()) -> #handshake{} | #error{}.
-handshake(Username, Password, Database, TcpModule, Socket, SetFoundRows) ->
+-spec handshake(Username :: iodata(), Password :: iodata(), Database :: iodata() | undefined,
+                SockModule :: module(), SSLOpts :: list() | undefined, Socket :: term(),
+                SetFoundRows :: boolean()) ->
+    {ok, #handshake{}, SockModule :: module(), Socket :: term()} | #error{}.
+handshake(Username, Password, Database, SockModule0, SSLOpts, Socket0, SetFoundRows) ->
     SeqNum0 = 0,
-    {ok, HandshakePacket, SeqNum1} = recv_packet(TcpModule, Socket, SeqNum0),
+    {ok, HandshakePacket, SeqNum1} = recv_packet(SockModule0, Socket0, SeqNum0),
     Handshake = parse_handshake(HandshakePacket),
+    {ok, SockModule, Socket, SeqNum2}
+    = maybe_do_ssl_upgrade(SockModule0, Socket0, SeqNum1, Handshake,
+                           SSLOpts, Database, SetFoundRows),
     Response = build_handshake_response(Handshake, Username, Password,
                                         Database, SetFoundRows),
-    {ok, SeqNum2} = send_packet(TcpModule, Socket, Response, SeqNum1),
-    handshake_finish_or_switch_auth(Handshake, Password, TcpModule, Socket, SeqNum2).
+    {ok, SeqNum3} = send_packet(SockModule, Socket, Response, SeqNum2),
+    handshake_finish_or_switch_auth(Handshake, Password, SockModule, Socket, SeqNum3).
 
-handshake_finish_or_switch_auth(Handshake, Password, TcpModule, Socket, SeqNum0) ->
-    {ok, ConfirmPacket, SeqNum1} = recv_packet(TcpModule, Socket, SeqNum0),
+handshake_finish_or_switch_auth(Handshake, Password, SockModule, Socket, SeqNum0) ->
+    {ok, ConfirmPacket, SeqNum1} = recv_packet(SockModule, Socket, SeqNum0),
     case parse_handshake_confirm(ConfirmPacket) of
         #ok{status = OkStatus} ->
             OkStatus = Handshake#handshake.status,
-            Handshake;
+            {ok, Handshake, SockModule, Socket};
         #auth_method_switch{auth_plugin_name = AuthPluginName, auth_plugin_data = AuthPluginData} ->
             Hash = case AuthPluginName of
                        <<>> ->
@@ -71,45 +77,45 @@ handshake_finish_or_switch_auth(Handshake, Password, TcpModule, Socket, SeqNum0)
                        UnknownAuthMethod ->
                            error({auth_method, UnknownAuthMethod})
                    end,
-            {ok, SeqNum2} = send_packet(TcpModule, Socket, Hash, SeqNum1),
-            handshake_finish_or_switch_auth(Handshake, Password, TcpModule, Socket, SeqNum2);
+            {ok, SeqNum2} = send_packet(SockModule, Socket, Hash, SeqNum1),
+            handshake_finish_or_switch_auth(Handshake, Password, SockModule, Socket, SeqNum2);
         Error ->
             Error
     end.
 
 -spec quit(atom(), term()) -> ok.
-quit(TcpModule, Socket) ->
-    {ok, SeqNum1} = send_packet(TcpModule, Socket, <<?COM_QUIT>>, 0),
-    case recv_packet(TcpModule, Socket, SeqNum1) of
+quit(SockModule, Socket) ->
+    {ok, SeqNum1} = send_packet(SockModule, Socket, <<?COM_QUIT>>, 0),
+    case recv_packet(SockModule, Socket, SeqNum1) of
         {error, closed} -> ok;            %% MySQL 5.5.40 and more
         {ok, ?ok_pattern, _SeqNum2} -> ok %% Some older MySQL versions?
     end.
 
 -spec ping(atom(), term()) -> #ok{}.
-ping(TcpModule, Socket) ->
-    {ok, SeqNum1} = send_packet(TcpModule, Socket, <<?COM_PING>>, 0),
-    {ok, OkPacket, _SeqNum2} = recv_packet(TcpModule, Socket, SeqNum1),
+ping(SockModule, Socket) ->
+    {ok, SeqNum1} = send_packet(SockModule, Socket, <<?COM_PING>>, 0),
+    {ok, OkPacket, _SeqNum2} = recv_packet(SockModule, Socket, SeqNum1),
     parse_ok_packet(OkPacket).
 
 -spec query(Query :: iodata(), atom(), term(), timeout()) ->
     {ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
-query(Query, TcpModule, Socket, Timeout) ->
+query(Query, SockModule, Socket, Timeout) ->
     Req = <<?COM_QUERY, (iolist_to_binary(Query))/binary>>,
     SeqNum0 = 0,
-    {ok, _SeqNum1} = send_packet(TcpModule, Socket, Req, SeqNum0),
-    fetch_query_response(TcpModule, Socket, Timeout).
+    {ok, _SeqNum1} = send_packet(SockModule, Socket, Req, SeqNum0),
+    fetch_query_response(SockModule, Socket, Timeout).
 
 %% @doc This is used by query/4. If query/4 returns {error, timeout}, this
 %% function can be called to retry to fetch the results of the query.
-fetch_query_response(TcpModule, Socket, Timeout) ->
-    fetch_response(TcpModule, Socket, Timeout, text, []).
+fetch_query_response(SockModule, Socket, Timeout) ->
+    fetch_response(SockModule, Socket, Timeout, text, []).
 
 %% @doc Prepares a statement.
 -spec prepare(iodata(), atom(), term()) -> #error{} | #prepared{}.
-prepare(Query, TcpModule, Socket) ->
+prepare(Query, SockModule, Socket) ->
     Req = <<?COM_STMT_PREPARE, (iolist_to_binary(Query))/binary>>,
-    {ok, SeqNum1} = send_packet(TcpModule, Socket, Req, 0),
-    {ok, Resp, SeqNum2} = recv_packet(TcpModule, Socket, SeqNum1),
+    {ok, SeqNum1} = send_packet(SockModule, Socket, Req, 0),
+    {ok, Resp, SeqNum2} = recv_packet(SockModule, Socket, SeqNum1),
     case Resp of
         ?error_pattern ->
             parse_error_packet(Resp);
@@ -125,13 +131,13 @@ prepare(Query, TcpModule, Socket) ->
             %% with charset 'binary' so we have to select a type ourselves for
             %% the parameters we have in execute/4.
             {_ParamDefs, SeqNum3} =
-                fetch_column_definitions_if_any(NumParams, TcpModule, Socket,
+                fetch_column_definitions_if_any(NumParams, SockModule, Socket,
                                                 SeqNum2),
             %% Column Definition Block. We get column definitions in execute
             %% too, so we don't need them here. We *could* store them to be able
             %% to provide the user with some info about a prepared statement.
             {_ColDefs, _SeqNum4} =
-                fetch_column_definitions_if_any(NumColumns, TcpModule, Socket,
+                fetch_column_definitions_if_any(NumColumns, SockModule, Socket,
                                                 SeqNum3),
             #prepared{statement_id = StmtId,
                       orig_query = Query,
@@ -141,8 +147,8 @@ prepare(Query, TcpModule, Socket) ->
 
 %% @doc Deallocates a prepared statement.
 -spec unprepare(#prepared{}, atom(), term()) -> ok.
-unprepare(#prepared{statement_id = Id}, TcpModule, Socket) ->
-    {ok, _SeqNum} = send_packet(TcpModule, Socket,
+unprepare(#prepared{statement_id = Id}, SockModule, Socket) ->
+    {ok, _SeqNum} = send_packet(SockModule, Socket,
                                 <<?COM_STMT_CLOSE, Id:32/little>>, 0),
     ok.
 
@@ -150,7 +156,7 @@ unprepare(#prepared{statement_id = Id}, TcpModule, Socket) ->
 -spec execute(#prepared{}, [term()], atom(), term(), timeout()) ->
     {ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
 execute(#prepared{statement_id = Id, param_count = ParamCount}, ParamValues,
-        TcpModule, Socket, Timeout) when ParamCount == length(ParamValues) ->
+        SockModule, Socket, Timeout) when ParamCount == length(ParamValues) ->
     %% Flags Constant Name
     %% 0x00 CURSOR_TYPE_NO_CURSOR
     %% 0x01 CURSOR_TYPE_READ_ONLY
@@ -176,13 +182,13 @@ execute(#prepared{statement_id = Id, param_count = ParamCount}, ParamValues,
             {TypesAndSigns, EncValues} = lists:unzip(EncodedParams),
             iolist_to_binary([Req1, TypesAndSigns, EncValues])
     end,
-    {ok, _SeqNum1} = send_packet(TcpModule, Socket, Req, 0),
-    fetch_execute_response(TcpModule, Socket, Timeout).
+    {ok, _SeqNum1} = send_packet(SockModule, Socket, Req, 0),
+    fetch_execute_response(SockModule, Socket, Timeout).
 
 %% @doc This is used by execute/5. If execute/5 returns {error, timeout}, this
 %% function can be called to retry to fetch the results of the query.
-fetch_execute_response(TcpModule, Socket, Timeout) ->
-    fetch_response(TcpModule, Socket, Timeout, binary, []).
+fetch_execute_response(SockModule, Socket, Timeout) ->
+    fetch_response(SockModule, Socket, Timeout, binary, []).
 
 %% --- internal ---
 
@@ -237,32 +243,50 @@ server_version_to_list(ServerVersion) ->
                             [{capture, all_but_first, binary}]),
     lists:map(fun binary_to_integer/1, Parts).
 
+-spec maybe_do_ssl_upgrade(SockModule0 :: module(),
+                           Socket0 :: term(),
+                           SeqNum1 :: non_neg_integer(),
+                           Handshake :: #handshake{},
+                           SSLOpts :: undefined | list(),
+                           Database :: iodata() | undefined,
+                           SetFoundRows :: boolean()) ->
+    {ok, SockModule :: module(), Socket :: term(), SeqNum2 :: non_neg_integer()}.
+maybe_do_ssl_upgrade(SockModule0, Socket0, SeqNum1, _Handshake, undefined,
+                     _Database, _SetFoundRows) ->
+    {ok, SockModule0, Socket0, SeqNum1};
+maybe_do_ssl_upgrade(SockModule0, Socket0, SeqNum1, Handshake, SSLOpts, Database, SetFoundRows) ->
+    Response = build_handshake_response(Handshake, Database, SetFoundRows),
+    {ok, SeqNum2} = send_packet(SockModule0, Socket0, Response, SeqNum1),
+    case mysql_sock_ssl:connect(Socket0, SSLOpts, 5000) of
+        {ok, SSLSocket} ->
+            {ok, ssl, SSLSocket, SeqNum2};
+        {error, Reason} ->
+            exit({failed_to_upgrade_socket, Reason})
+    end.
+
+-spec build_handshake_response(#handshake{}, iodata() | undefined, boolean()) -> binary().
+build_handshake_response(Handshake, Database, SetFoundRows) ->
+    CapabilityFlags = basic_capabilities(Database /= undefined, SetFoundRows),
+    verify_server_capabilities(Handshake, CapabilityFlags),
+    ClientCapabilities = add_client_capabilities(CapabilityFlags),
+    ClientSSLCapabilities = ClientCapabilities bor ?CLIENT_SSL,
+    CharacterSet = ?UTF8,
+    <<ClientSSLCapabilities:32/little,
+      ?MAX_BYTES_PER_PACKET:32/little,
+      CharacterSet:8,
+      0:23/unit:8>>.
+
 %% @doc The response sent by the client to the server after receiving the
 %% initial handshake from the server
 -spec build_handshake_response(#handshake{}, iodata(), iodata(),
                                iodata() | undefined, boolean()) -> binary().
 build_handshake_response(Handshake, Username, Password, Database, SetFoundRows) ->
-    %% We require these capabilities. Make sure the server handles them.
-    CapabilityFlags0 = ?CLIENT_PROTOCOL_41 bor
-                       ?CLIENT_TRANSACTIONS bor
-                       ?CLIENT_SECURE_CONNECTION,
-    CapabilityFlags1 = case Database of
-        undefined -> CapabilityFlags0;
-        _         -> CapabilityFlags0 bor ?CLIENT_CONNECT_WITH_DB
-    end,
-    CapabilityFlags = case SetFoundRows of
-        true -> CapabilityFlags1 bor ?CLIENT_FOUND_ROWS;
-        _    -> CapabilityFlags1
-    end,
-    Handshake#handshake.capabilities band CapabilityFlags == CapabilityFlags
-        orelse error(old_server_version),
+    CapabilityFlags = basic_capabilities(Database /= undefined, SetFoundRows),
+    verify_server_capabilities(Handshake, CapabilityFlags),
     %% Add some extra capability flags only for signalling to the server what
     %% the client wants to do. The server doesn't say it handles them although
     %% it does. (http://bugs.mysql.com/bug.php?id=42268)
-    ClientCapabilityFlags = CapabilityFlags bor
-                            ?CLIENT_MULTI_STATEMENTS bor
-                            ?CLIENT_MULTI_RESULTS bor
-                            ?CLIENT_PS_MULTI_RESULTS,
+    ClientCapabilityFlags = add_client_capabilities(CapabilityFlags),
     Hash = case Handshake#handshake.auth_plugin_name of
         <<>> ->
             %% Server doesn't know auth plugins
@@ -289,6 +313,33 @@ build_handshake_response(Handshake, Username, Password, Database, SetFoundRows)
       Hash/binary,
       DbBin/binary>>.
 
+-spec verify_server_capabilities(Handshake :: #handshake{}, CapabilityFlags :: integer()) -> true | no_return().
+verify_server_capabilities(Handshake, CapabilityFlags) ->
+    %% We require these capabilities. Make sure the server handles them.
+    Handshake#handshake.capabilities band CapabilityFlags == CapabilityFlags
+        orelse error(old_server_version).
+
+-spec basic_capabilities(ConnectWithDB :: boolean(), SetFoundRows :: boolean()) -> integer().
+basic_capabilities(ConnectWithDB, SetFoundRows) ->
+    CapabilityFlags0 = ?CLIENT_PROTOCOL_41 bor
+                       ?CLIENT_TRANSACTIONS bor
+                       ?CLIENT_SECURE_CONNECTION,
+    CapabilityFlags1 = case ConnectWithDB of
+                           true -> CapabilityFlags0 bor ?CLIENT_CONNECT_WITH_DB;
+                           _ -> CapabilityFlags0
+                       end,
+    case SetFoundRows of
+        true -> CapabilityFlags1 bor ?CLIENT_FOUND_ROWS;
+        _    -> CapabilityFlags1
+    end.
+
+-spec add_client_capabilities(Caps :: integer()) -> integer().
+add_client_capabilities(Caps) ->
+    Caps bor
+    ?CLIENT_MULTI_STATEMENTS bor
+    ?CLIENT_MULTI_RESULTS bor
+    ?CLIENT_PS_MULTI_RESULTS.
+
 %% @doc Handles the second packet from the server, when we have replied to the
 %% initial handshake. Returns an error if the server returns an error. Raises
 %% an error if unimplemented features are required.
@@ -321,8 +372,8 @@ parse_handshake_confirm(Packet) ->
 %% prepared statements).
 -spec fetch_response(atom(), term(), timeout(), text | binary, list()) ->
     {ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
-fetch_response(TcpModule, Socket, Timeout, Proto, Acc) ->
-    case recv_packet(TcpModule, Socket, Timeout, any) of
+fetch_response(SockModule, Socket, Timeout, Proto, Acc) ->
+    case recv_packet(SockModule, Socket, Timeout, any) of
         {ok, Packet, SeqNum2} ->
             Result = case Packet of
                 ?ok_pattern ->
@@ -332,7 +383,7 @@ fetch_response(TcpModule, Socket, Timeout, Proto, Acc) ->
                 ResultPacket ->
                     %% The first packet in a resultset is only the column count.
                     {ColCount, <<>>} = lenenc_int(ResultPacket),
-                    R0 = fetch_resultset(TcpModule, Socket, ColCount, SeqNum2),
+                    R0 = fetch_resultset(SockModule, Socket, ColCount, SeqNum2),
                     case R0 of
                         #error{} = E ->
                             %% TODO: Find a way to get here + testcase
@@ -344,7 +395,7 @@ fetch_response(TcpModule, Socket, Timeout, Proto, Acc) ->
             Acc1 = [Result | Acc],
             case more_results_exists(Result) of
                 true ->
-                    fetch_response(TcpModule, Socket, Timeout, Proto, Acc1);
+                    fetch_response(SockModule, Socket, Timeout, Proto, Acc1);
                 false ->
                     {ok, lists:reverse(Acc1)}
             end;
@@ -358,12 +409,12 @@ fetch_response(TcpModule, Socket, Timeout, Proto, Acc) ->
 %% be parsed.
 -spec fetch_resultset(atom(), term(), integer(), integer()) ->
     #resultset{} | #error{}.
-fetch_resultset(TcpModule, Socket, FieldCount, SeqNum) ->
-    {ok, ColDefs, SeqNum1} = fetch_column_definitions(TcpModule, Socket, SeqNum,
+fetch_resultset(SockModule, Socket, FieldCount, SeqNum) ->
+    {ok, ColDefs, SeqNum1} = fetch_column_definitions(SockModule, Socket, SeqNum,
                                                       FieldCount, []),
-    {ok, DelimiterPacket, SeqNum2} = recv_packet(TcpModule, Socket, SeqNum1),
+    {ok, DelimiterPacket, SeqNum2} = recv_packet(SockModule, Socket, SeqNum1),
     #eof{status = S, warning_count = W} = parse_eof_packet(DelimiterPacket),
-    case fetch_resultset_rows(TcpModule, Socket, SeqNum2, []) of
+    case fetch_resultset_rows(SockModule, Socket, SeqNum2, []) of
         {ok, Rows, _SeqNum3} ->
             ColDefs1 = lists:map(fun parse_column_definition/1, ColDefs),
             #resultset{cols = ColDefs1, rows = Rows,
@@ -393,12 +444,12 @@ more_results_exists(#resultset{status = S}) ->
 -spec fetch_column_definitions(atom(), term(), SeqNum :: integer(),
                                NumLeft :: integer(), Acc :: [binary()]) ->
     {ok, ColDefPackets :: [binary()], NextSeqNum :: integer()}.
-fetch_column_definitions(TcpModule, Socket, SeqNum, NumLeft, Acc)
+fetch_column_definitions(SockModule, Socket, SeqNum, NumLeft, Acc)
   when NumLeft > 0 ->
-    {ok, Packet, SeqNum1} = recv_packet(TcpModule, Socket, SeqNum),
-    fetch_column_definitions(TcpModule, Socket, SeqNum1, NumLeft - 1,
+    {ok, Packet, SeqNum1} = recv_packet(SockModule, Socket, SeqNum),
+    fetch_column_definitions(SockModule, Socket, SeqNum1, NumLeft - 1,
                              [Packet | Acc]);
-fetch_column_definitions(_TcpModule, _Socket, SeqNum, 0, Acc) ->
+fetch_column_definitions(_SockModule, _Socket, SeqNum, 0, Acc) ->
     {ok, lists:reverse(Acc), SeqNum}.
 
 %% @doc Fetches rows in a result set. There is a packet per row. The row packets
@@ -408,15 +459,15 @@ fetch_column_definitions(_TcpModule, _Socket, SeqNum, 0, Acc) ->
     {ok, Rows, integer()} | #error{}
     when Acc :: [binary()],
          Rows :: [binary()].
-fetch_resultset_rows(TcpModule, Socket, SeqNum, Acc) ->
-    {ok, Packet, SeqNum1} = recv_packet(TcpModule, Socket, SeqNum),
+fetch_resultset_rows(SockModule, Socket, SeqNum, Acc) ->
+    {ok, Packet, SeqNum1} = recv_packet(SockModule, Socket, SeqNum),
     case Packet of
         ?error_pattern ->
             parse_error_packet(Packet);
         ?eof_pattern ->
             {ok, lists:reverse(Acc), SeqNum1};
         Row ->
-            fetch_resultset_rows(TcpModule, Socket, SeqNum1, [Row | Acc])
+            fetch_resultset_rows(SockModule, Socket, SeqNum1, [Row | Acc])
     end.
 
 %% Parses a packet containing a column definition (part of a result set)
@@ -543,12 +594,12 @@ decode_text(#col{type = T}, Text) when T == ?TYPE_FLOAT;
 
 %% @doc If NumColumns is non-zero, fetches this number of column definitions
 %% and an EOF packet. Used by prepare/3.
-fetch_column_definitions_if_any(0, _TcpModule, _Socket, SeqNum) ->
+fetch_column_definitions_if_any(0, _SockModule, _Socket, SeqNum) ->
     {[], SeqNum};
-fetch_column_definitions_if_any(N, TcpModule, Socket, SeqNum) ->
-    {ok, Defs, SeqNum1} = fetch_column_definitions(TcpModule, Socket, SeqNum,
+fetch_column_definitions_if_any(N, SockModule, Socket, SeqNum) ->
+    {ok, Defs, SeqNum1} = fetch_column_definitions(SockModule, Socket, SeqNum,
                                                    N, []),
-    {ok, ?eof_pattern, SeqNum2} = recv_packet(TcpModule, Socket, SeqNum1),
+    {ok, ?eof_pattern, SeqNum2} = recv_packet(SockModule, Socket, SeqNum1),
     {Defs, SeqNum2}.
 
 %% @doc Decodes a packet representing a row in a binary result set.
@@ -884,40 +935,40 @@ decode_decimal(Bin, P, S) when P >= 16, S > 0 ->
 
 %% -- Protocol basics: packets --
 
-%% @doc Wraps Data in packet headers, sends it by calling TcpModule:send/2 with
+%% @doc Wraps Data in packet headers, sends it by calling SockModule:send/2 with
 %% Socket and returns {ok, SeqNum1} where SeqNum1 is the next sequence number.
 -spec send_packet(atom(), term(), Data :: binary(), SeqNum :: integer()) ->
     {ok, NextSeqNum :: integer()}.
-send_packet(TcpModule, Socket, Data, SeqNum) ->
+send_packet(SockModule, Socket, Data, SeqNum) ->
     {WithHeaders, SeqNum1} = add_packet_headers(Data, SeqNum),
-    ok = TcpModule:send(Socket, WithHeaders),
+    ok = SockModule:send(Socket, WithHeaders),
     {ok, SeqNum1}.
 
 %% @see recv_packet/4
-recv_packet(TcpModule, Socket, SeqNum) ->
-    recv_packet(TcpModule, Socket, infinity, SeqNum).
+recv_packet(SockModule, Socket, SeqNum) ->
+    recv_packet(SockModule, Socket, infinity, SeqNum).
 
-%% @doc Receives data by calling TcpModule:recv/2 and removes the packet
+%% @doc Receives data by calling SockModule:recv/2 and removes the packet
 %% headers. Returns the packet contents and the next packet sequence number.
 -spec recv_packet(atom(), term(), timeout(), integer() | any) ->
     {ok, Data :: binary(), NextSeqNum :: integer()} | {error, term()}.
-recv_packet(TcpModule, Socket, Timeout, SeqNum) ->
-    recv_packet(TcpModule, Socket, Timeout, SeqNum, <<>>).
+recv_packet(SockModule, Socket, Timeout, SeqNum) ->
+    recv_packet(SockModule, Socket, Timeout, SeqNum, <<>>).
 
 %% @doc Accumulating helper for recv_packet/4
 -spec recv_packet(atom(), term(), timeout(), integer() | any, binary()) ->
     {ok, Data :: binary(), NextSeqNum :: integer()} | {error, term()}.
-recv_packet(TcpModule, Socket, Timeout, ExpectSeqNum, Acc) ->
-    case TcpModule:recv(Socket, 4, Timeout) of
+recv_packet(SockModule, Socket, Timeout, ExpectSeqNum, Acc) ->
+    case SockModule:recv(Socket, 4, Timeout) of
         {ok, Header} ->
             {Size, SeqNum, More} = parse_packet_header(Header),
             true = SeqNum == ExpectSeqNum orelse ExpectSeqNum == any,
-            {ok, Body} = TcpModule:recv(Socket, Size),
+            {ok, Body} = SockModule:recv(Socket, Size),
             Acc1 = <<Acc/binary, Body/binary>>,
             NextSeqNum = (SeqNum + 1) band 16#ff,
             case More of
                 false -> {ok, Acc1, NextSeqNum};
-                true  -> recv_packet(TcpModule, Socket, Timeout, NextSeqNum,
+                true  -> recv_packet(SockModule, Socket, Timeout, NextSeqNum,
                                      Acc1)
             end;
         {error, Reason} ->

+ 64 - 0
src/mysql_sock_ssl.erl

@@ -0,0 +1,64 @@
+%% MySQL/OTP – MySQL client library for Erlang/OTP
+%% Copyright (C) 2017 Piotr Nosek, Michal Slaski
+%%
+%% This file is part of MySQL/OTP.
+%%
+%% MySQL/OTP is free software: you can redistribute it and/or modify it under
+%% the terms of the GNU Lesser General Public License as published by the Free
+%% Software Foundation, either version 3 of the License, or (at your option)
+%% any later version.
+%%
+%% This program is distributed in the hope that it will be useful, but WITHOUT
+%% ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+%% FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+%% more details.
+%%
+%% You should have received a copy of the GNU Lesser General Public License
+%% along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+%% @doc This module provides SSL socket interface, i.e. is a proxy to ssl module.
+%% @private
+-module(mysql_sock_ssl).
+
+-export([connect/3, close/1, send/2, recv/2, recv/3]).
+-export([setopts/2]).
+
+%% --------------------------------------------------
+%% API
+%% --------------------------------------------------
+
+connect(Port, ConfigSSLOpts, Timeout) ->
+    DefaultSSLOpts = [{versions, [tlsv1]}, {verify, verify_peer}],
+    MandatorySSLOpts = [{active, false}],
+    MergedSSLOpts = merge_ssl_options(DefaultSSLOpts, MandatorySSLOpts, ConfigSSLOpts),
+    ssl:connect(Port, MergedSSLOpts, Timeout).
+
+close(Socket) ->
+    ssl:close(Socket).
+
+send(Socket, Packet) ->
+    ssl:send(Socket, Packet).
+
+recv(Socket, Length) ->
+    ssl:recv(Socket, Length).
+
+recv(Socket, Length, Timeout) ->
+    ssl:recv(Socket, Length, Timeout).
+
+setopts(Socket, SockOpts) ->
+    ssl:setopts(Socket, SockOpts).
+
+%% --------------------------------------------------
+%% Internal functions
+%% --------------------------------------------------
+
+-spec merge_ssl_options(list(), list(), list()) -> list().
+merge_ssl_options(DefaultSSLOpts, MandatorySSLOpts, ConfigSSLOpts) ->
+    SSLOpts1 =
+    lists:foldl(fun({Key, _} = Opt, OptsAcc) ->
+                        lists:keystore(Key, 1, OptsAcc, Opt)
+                end, DefaultSSLOpts, ConfigSSLOpts),
+    lists:foldl(fun({Key, _} = Opt, OptsAcc) ->
+                        lists:keystore(Key, 1, OptsAcc, Opt)
+                end, SSLOpts1, MandatorySSLOpts).
+

+ 42 - 0
src/mysql_sock_tcp.erl

@@ -0,0 +1,42 @@
+%% MySQL/OTP – MySQL client library for Erlang/OTP
+%% Copyright (C) 2017 Piotr Nosek, Michal Slaski
+%%
+%% This file is part of MySQL/OTP.
+%%
+%% MySQL/OTP is free software: you can redistribute it and/or modify it under
+%% the terms of the GNU Lesser General Public License as published by the Free
+%% Software Foundation, either version 3 of the License, or (at your option)
+%% any later version.
+%%
+%% This program is distributed in the hope that it will be useful, but WITHOUT
+%% ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+%% FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+%% more details.
+%%
+%% You should have received a copy of the GNU Lesser General Public License
+%% along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+%% @doc This module provides TCP socket interface, i.e. is a proxy to gen_tcp and inet.
+%% @private
+-module(mysql_sock_tcp).
+
+-export([connect/3, close/1, send/2, recv/2, recv/3]).
+-export([setopts/2]).
+
+connect(Host, Port, SockOpts) ->
+    gen_tcp:connect(Host, Port, SockOpts).
+
+close(Socket) ->
+    gen_tcp:close(Socket).
+
+send(Socket, Packet) ->
+    gen_tcp:send(Socket, Packet).
+
+recv(Socket, Length) ->
+    gen_tcp:recv(Socket, Length).
+
+recv(Socket, Length, Timeout) ->
+    gen_tcp:recv(Socket, Length, Timeout).
+
+setopts(Socket, SockOpts) ->
+    inet:setopts(Socket, SockOpts).

+ 3 - 1
test/mysql_protocol_tests.erl

@@ -1,5 +1,6 @@
 %% MySQL/OTP – MySQL client library for Erlang/OTP
 %% Copyright (C) 2014 Viktor Söderqvist
+%%               2017 Piotr Nosek
 %%
 %% This file is part of MySQL/OTP.
 %%
@@ -113,8 +114,9 @@ prepare_test() ->
 
 bad_protocol_version_test() ->
     Sock = mock_tcp:create([{recv, <<2, 0, 0, 0, 9, 0>>}]),
+    UndefSSLOpts = undefined,
     ?assertError(unknown_protocol,
-                 mysql_protocol:handshake("foo", "bar", "db", mock_tcp, Sock, false)),
+                 mysql_protocol:handshake("foo", "bar", "db", mock_tcp, UndefSSLOpts, Sock, false)),
     mock_tcp:close(Sock).
 
 %% --- Helper functions for the above tests ---

+ 39 - 11
test/mysql_tests.erl

@@ -1,5 +1,6 @@
 %% MySQL/OTP – MySQL client library for Erlang/OTP
 %% Copyright (C) 2014-2016 Viktor Söderqvist
+%%               2017 Piotr Nosek
 %%
 %% This file is part of MySQL/OTP.
 %%
@@ -21,8 +22,10 @@
 
 -include_lib("eunit/include/eunit.hrl").
 
--define(user,     "otptest").
--define(password, "otptest").
+-define(user,         "otptest").
+-define(password,     "otptest").
+-define(ssl_user,     "otptestssl").
+-define(ssl_password, "otptestssl").
 
 %% We need to set a the SQL mode so it is consistent across MySQL versions
 %% and distributions.
@@ -49,26 +52,51 @@ failing_connect_test() ->
     receive
         {'EXIT', _Pid, {1045, <<"28000">>, <<"Access denie", _/binary>>}} -> ok
     after 1000 ->
-        ?assertEqual(ok, no_exit_message)
+        error(no_exit_message)
     end,
     process_flag(trap_exit, false).
 
 successful_connect_test() ->
     %% A connection with a registered name and execute initial queries and
     %% create prepared statements.
-    Options = [{name, {local, tardis}}, {user, ?user}, {password, ?password},
+    Pid = common_basic_check([{user, ?user}, {password, ?password}]),
+
+    %% Test some gen_server callbacks not tested elsewhere
+    State = get_state(Pid),
+    ?assertMatch({ok, State}, mysql:code_change("0.1.0", State, [])),
+    ?assertMatch({error, _}, mysql:code_change("2.0.0", unknown_state, [])),
+    common_conn_close().
+
+successful_ssl_connect_test() ->
+    %% The same test as successful_connect_test(), minus gen_server checks,
+    %% plus SSL
+    [ application:start(App) || App <- [crypto, asn1, public_key, ssl] ],
+    common_basic_check([{ssl, [{cacertfile, "test/ssl/ca.pem"}]},
+                        {user, ?ssl_user}, {password, ?ssl_password}]),
+    common_conn_close(),
+    ok.
+
+common_basic_check(ExtraOpts) ->
+    Options = [{name, {local, tardis}},
                {queries, ["SET @foo = 'bar'", "SELECT 1",
                           "SELECT 1; SELECT 2"]},
-               {prepare, [{foo, "SELECT @foo"}]}],
+               {prepare, [{foo, "SELECT @foo"}]} | ExtraOpts],
     {ok, Pid} = mysql:start_link(Options),
     %% Check that queries and prepare has been done.
     ?assertEqual({ok, [<<"@foo">>], [[<<"bar">>]]},
                  mysql:execute(Pid, foo, [])),
-    %% Test some gen_server callbacks not tested elsewhere
-    State = get_state(Pid),
-    ?assertMatch({ok, State}, mysql:code_change("0.1.0", State, [])),
-    ?assertMatch({error, _}, mysql:code_change("2.0.0", unknown_state, [])),
-    exit(whereis(tardis), normal).
+    Pid.
+
+common_conn_close() ->
+    Pid = whereis(tardis),
+    process_flag(trap_exit, true),
+    exit(Pid, normal),
+    receive
+        {'EXIT', Pid, normal} -> ok
+    after
+        5000 -> error({cant_stop_connection, Pid})
+    end,
+    process_flag(trap_exit, false).
 
 server_disconnect_test() ->
     process_flag(trap_exit, true),
@@ -102,7 +130,7 @@ tcp_error_test() ->
         receive
             {'EXIT', Pid, {tcp_error, tcp_reason}} -> ok
         after 1000 ->
-            ?assertEqual(ok, no_exit_message)
+            error(no_exit_message)
         end
     end),
     process_flag(trap_exit, false),

+ 27 - 0
test/ssl/Makefile

@@ -0,0 +1,27 @@
+.PHONY: all clean
+
+SSLDIR ?= $(shell pwd)
+CAKEY = ca.key
+CACERT = ca.pem
+SERVERKEY = server-key.pem
+SERVERCSR = server.csr
+SERVERCERT = server-cert.pem
+
+CASTRING = "/C=PL/L=Krakow/CN=MYSQL CA"
+SERVERSTRING = "/C=PL/L=Krakow/CN=localhost"
+
+all:
+	openssl genrsa -out $(CAKEY) 2048
+	openssl req -x509 -new -nodes -key $(CAKEY) -sha256 -days 1024 -out $(CACERT) -subj $(CASTRING)
+	openssl genrsa -out $(SERVERKEY) 2048
+	openssl req -new -key $(SERVERKEY) -out $(SERVERCSR) -subj $(SERVERSTRING)
+	openssl x509 -req -in $(SERVERCSR) -CA $(CACERT) -CAkey $(CAKEY) -CAcreateserial -out $(SERVERCERT) -days 500 -sha256
+	cp my-ssl.cnf.template my-ssl.cnf
+	sed -i -e "s~%%CACERT%%~$(SSLDIR)/$(CACERT)~g" my-ssl.cnf
+	sed -i -e "s~%%SERVERCERT%%~$(SSLDIR)/$(SERVERCERT)~g" my-ssl.cnf
+	sed -i -e "s~%%SERVERKEY%%~$(SSLDIR)/$(SERVERKEY)~g" my-ssl.cnf
+
+clean:
+	rm -f ca*
+	rm -f server*
+	rm -f my-ssl.cnf my-ssl.cnf-e

+ 4 - 0
test/ssl/my-ssl.cnf.template

@@ -0,0 +1,4 @@
+[mysqld]
+ssl-ca=%%CACERT%%
+ssl-cert=%%SERVERCERT%%
+ssl-key=%%SERVERKEY%%