Просмотр исходного кода

Default SNI to hostname for SSL connections

The ssl option `server_name_indication` did not have a default value,
which, in conjunction with the `verify` option defaulting to `verify_peer`,
meant that the IP address will be used in order to check the identities
in the certificate.
This change sets the default value for `server_name_indication` to the
value of the `host` option given in `mysql:start_link/1` if it is a hostname
string (ie, not a IP address tuple or string or an atom). Otherwise, no
default value will be set.
juhlig 4 лет назад
Родитель
Сommit
8d30a851a9
4 измененных файлов с 32 добавлено и 20 удалено
  1. 6 1
      src/mysql.erl
  2. 5 5
      src/mysql_conn.erl
  3. 19 12
      src/mysql_protocol.erl
  4. 2 2
      test/mysql_protocol_tests.erl

+ 6 - 1
src/mysql.erl

@@ -147,7 +147,12 @@
 %%       `{recbuf, Size}' and `{sndbuf, Size}' if you send or receive more than
 %%       the default (typically 8K) per query.</dd>
 %%   <dt>`{ssl, Options}'</dt>
-%%   <dd>Additional options for `ssl:connect/3'.</dd>
+%%   <dd>Additional options for `ssl:connect/3'.<br />
+%%       The `verify' option, if not given explicitly, defaults to
+%%       `verify_peer'.<br />
+%%       The `server_name_indication' option, if omitted, defaults to the value
+%%       of the `host' option if it is a hostname string, otherwise no default
+%%       value is set.</dd>
 %% </dl>
 -spec start_link(Options) -> {ok, pid()} | ignore | {error, term()}
     when Options :: [Option],

+ 5 - 5
src/mysql_conn.erl

@@ -192,11 +192,11 @@ sanitize_tcp_opts(TcpOpts0) ->
     [binary, {packet, raw}, {active, false} | TcpOpts2].
 
 handshake(#state{socket = Socket0, ssl_opts = SSLOpts,
-        user = User, password = Password, database = Database,
-        cap_found_rows = SetFoundRows} = State0) ->
+          host = Host, user = User, password = Password, database = Database,
+          cap_found_rows = SetFoundRows} = State0) ->
     %% Exchange handshake communication.
-    Result = mysql_protocol:handshake(User, Password, Database, gen_tcp, SSLOpts,
-                                      Socket0, SetFoundRows),
+    Result = mysql_protocol:handshake(Host, User, Password, Database, gen_tcp,
+                                      SSLOpts, Socket0, SetFoundRows),
     case Result of
         {ok, Handshake, SockMod, Socket} ->
             setopts(SockMod, Socket, [{active, once}]),
@@ -741,7 +741,7 @@ kill_query(#state{connection_id = ConnId, host = Host, port = Port,
     {ok, Socket0} = gen_tcp:connect(Host, Port, SockOpts),
 
     %% Exchange handshake communication.
-    Result = mysql_protocol:handshake(User, Password, undefined, gen_tcp,
+    Result = mysql_protocol:handshake(Host, User, Password, undefined, gen_tcp,
                                       SSLOpts, Socket0, SetFoundRows),
     case Result of
         {ok, #handshake{}, SockMod, Socket} ->

+ 19 - 12
src/mysql_protocol.erl

@@ -27,7 +27,7 @@
 %% @private
 -module(mysql_protocol).
 
--export([handshake/7, change_user/8, quit/2, ping/2,
+-export([handshake/8, change_user/8, quit/2, ping/2,
          query/4, query/5, fetch_query_response/3,
          fetch_query_response/4, prepare/3, unprepare/3,
          execute/5, execute/6, fetch_execute_response/3,
@@ -66,21 +66,22 @@
 %% @doc Performs a handshake using the supplied socket and socket module for
 %% communication. Returns an ok or an error record. Raises errors when various
 %% unimplemented features are requested.
--spec handshake(Username :: iodata(), Password :: iodata(),
+-spec handshake(Host :: inet:socket_address() | inet:hostname(),
+                Username :: iodata(), Password :: iodata(),
                 Database :: iodata() | undefined,
                 SockModule :: module(), SSLOpts :: list() | undefined,
                 Socket :: term(),
                 SetFoundRows :: boolean()) ->
     {ok, #handshake{}, SockModule :: module(), Socket :: term()} |
     #error{}.
-handshake(Username, Password, Database, SockModule0, SSLOpts, Socket0,
+handshake(Host, Username, Password, Database, SockModule0, SSLOpts, Socket0,
           SetFoundRows) ->
     SeqNum0 = 0,
     {ok, HandshakePacket, SeqNum1} = recv_packet(SockModule0, Socket0, SeqNum0),
     case parse_handshake(HandshakePacket) of
         #handshake{} = Handshake ->
             {ok, SockModule, Socket, SeqNum2} =
-                maybe_do_ssl_upgrade(SockModule0, Socket0, SeqNum1, Handshake,
+                maybe_do_ssl_upgrade(Host, SockModule0, Socket0, SeqNum1, Handshake,
                                      SSLOpts, Database, SetFoundRows),
             Response = build_handshake_response(Handshake, Username, Password,
                                                 Database, SetFoundRows),
@@ -397,7 +398,8 @@ server_version_to_list(ServerVersion) ->
                             [{capture, all_but_first, binary}]),
     lists:map(fun binary_to_integer/1, Parts).
 
--spec maybe_do_ssl_upgrade(SockModule0 :: module(),
+-spec maybe_do_ssl_upgrade(Host :: inet:socket_address() | inet:hostname(),
+                           SockModule0 :: module(),
                            Socket0 :: term(),
                            SeqNum1 :: non_neg_integer(),
                            Handshake :: #handshake{},
@@ -406,24 +408,29 @@ server_version_to_list(ServerVersion) ->
                            SetFoundRows :: boolean()) ->
     {ok, SockModule :: module(), Socket :: term(),
      SeqNum2 :: non_neg_integer()}.
-maybe_do_ssl_upgrade(SockModule0, Socket0, SeqNum1, _Handshake, undefined,
-                     _Database, _SetFoundRows) ->
+maybe_do_ssl_upgrade(_Host, SockModule0, Socket0, SeqNum1, _Handshake,
+                     undefined, _Database, _SetFoundRows) ->
     {ok, SockModule0, Socket0, SeqNum1};
-maybe_do_ssl_upgrade(gen_tcp, Socket0, SeqNum1, Handshake, SSLOpts,
+maybe_do_ssl_upgrade(Host, gen_tcp, Socket0, SeqNum1, Handshake, SSLOpts,
                      Database, SetFoundRows) ->
     Response = build_handshake_response(Handshake, Database, SetFoundRows),
     {ok, SeqNum2} = send_packet(gen_tcp, Socket0, Response, SeqNum1),
-    case ssl_connect(Socket0, SSLOpts, 5000) of
+    case ssl_connect(Host, Socket0, SSLOpts, 5000) of
         {ok, SSLSocket} ->
             {ok, ssl, SSLSocket, SeqNum2};
         {error, Reason} ->
             exit({failed_to_upgrade_socket, Reason})
     end.
 
-ssl_connect(Port, ConfigSSLOpts, Timeout) ->
-    DefaultSSLOpts = [{versions, [tlsv1]}, {verify, verify_peer}],
+ssl_connect(Host, Port, ConfigSSLOpts, Timeout) ->
+    DefaultSSLOpts0 = [{versions, [tlsv1]}, {verify, verify_peer}],
+    DefaultSSLOpts1 = case is_list(Host) andalso inet:parse_address(Host) of
+        false -> DefaultSSLOpts0;
+        {ok, _} -> DefaultSSLOpts0;
+        {error, einval} -> [{server_name_indication, Host} | DefaultSSLOpts0]
+    end,
     MandatorySSLOpts = [{active, false}],
-    MergedSSLOpts = merge_ssl_options(DefaultSSLOpts, MandatorySSLOpts, ConfigSSLOpts),
+    MergedSSLOpts = merge_ssl_options(DefaultSSLOpts1, MandatorySSLOpts, ConfigSSLOpts),
     ssl:connect(Port, MergedSSLOpts, Timeout).
 
 -spec merge_ssl_options(list(), list(), list()) -> list().

+ 2 - 2
test/mysql_protocol_tests.erl

@@ -116,7 +116,7 @@ bad_protocol_version_test() ->
     Sock = mock_tcp:create([{recv, <<2, 0, 0, 0, 9, 0>>}]),
     SSLOpts = undefined,
     ?assertError(unknown_protocol,
-                 mysql_protocol:handshake("foo", "bar", "db", mock_tcp,
+                 mysql_protocol:handshake("foo", "bar", "baz", "db", mock_tcp,
                                           SSLOpts, Sock, false)),
     mock_tcp:close(Sock).
 
@@ -129,7 +129,7 @@ error_as_initial_packet_test() ->
     Sock = mock_tcp:create([{recv, Packet}]),
     SSLOpts = undefined,
     ?assertMatch(#error{code = 1040, msg = <<"Too many connections">>},
-                 mysql_protocol:handshake("foo", "bar", "db", mock_tcp,
+                 mysql_protocol:handshake("foo", "bar", "baz", "db", mock_tcp,
                                           SSLOpts, Sock, false)),
     mock_tcp:close(Sock).