Browse Source

Support LOAD DATA LOCAL revised (#168)

* Support for LOAD DATA LOCAL INFILE
* Allow specifying a list of paths from which to send local files
Jan Uhlig 4 years ago
parent
commit
392c443822
7 changed files with 374 additions and 93 deletions
  1. 1 0
      .travis.yml
  2. 5 0
      include/protocol.hrl
  3. 8 0
      src/mysql.erl
  4. 45 30
      src/mysql_conn.erl
  5. 199 42
      src/mysql_protocol.erl
  6. 4 2
      test/mysql_protocol_tests.erl
  7. 112 19
      test/mysql_tests.erl

+ 1 - 0
.travis.yml

@@ -15,6 +15,7 @@ before_script:
   - sudo chmod -R 660 /etc/mysql/*.pem
   - sudo chown -R mysql:mysql /etc/mysql/*.pem
   - cat test/ssl/my-ssl.cnf | sudo tee -a /etc/mysql/conf.d/my-ssl.cnf
+  - (echo '[mysqld]'; echo 'local_infile=ON') | sudo tee -a /etc/mysql/conf.d/my-otp.cnf
   - sudo service mysql start
   - sleep 5
   - sudo mysql -uroot -e "CREATE USER otptest@localhost IDENTIFIED BY 'otptest';"

+ 5 - 0
include/protocol.hrl

@@ -21,6 +21,7 @@
 -define(EOF, 16#fe).
 -define(MORE_DATA, 16#01).
 -define(ERROR, 16#ff).
+-define(LOCAL_INFILE_REQUEST, 16#fb).
 
 %% Character sets
 -define(UTF8MB3, 16#21). %% utf8_general_ci
@@ -39,6 +40,10 @@
 %% Client: Handshake Response Packet contains a schema-name
 -define(CLIENT_CONNECT_WITH_DB, 16#00000008).
 
+%% Server: Enables the LOCAL INFILE request of LOAD DATA|XML
+%% Client: Will handle LOCAL INFILE request
+-define(CLIENT_LOCAL_FILES, 16#00000080).
+
 %% Server: supports the 4.1 protocol
 %% Client: uses the 4.1 protocol
 -define(CLIENT_PROTOCOL_41, 16#00000200).

+ 8 - 0
src/mysql.erl

@@ -82,6 +82,7 @@
                 | {database, iodata()}
                 | {connect_mode, synchronous | asynchronous | lazy}
                 | {connect_timeout, timeout()}
+                | {allowed_local_paths, [binary()]}
                 | {log_warnings, boolean()}
                 | {log_slow_queries, boolean()}
                 | {keep_alive, boolean() | timeout()}
@@ -142,6 +143,13 @@
 %%   </dd>
 %%   <dt>`{connect_timeout, Timeout}'</dt>
 %%   <dd>The maximum time to spend for start_link/1.</dd>
+%%   <dt>`{allowed_local_paths, [binary()]}'</dt>
+%%   <dd>This option allows you to specify a list of directories or individual
+%%       files on the client machine which the server may request, for example
+%%       when executing a `LOAD DATA LOCAL INFILE' query. Only absolute paths
+%%       without relative components such as `..' and `.' are allowed.
+%%       The default is an empty list, meaning the client will not send any
+%%       local files to the server.</dd>
 %%   <dt>`{log_warnings, boolean()}'</dt>
 %%   <dd>Whether to fetch warnings and log them using error_logger; default
 %%       true.</dd>

+ 45 - 30
src/mysql_conn.erl

@@ -52,7 +52,7 @@
 %% Gen_server state
 -record(state, {server_version, connection_id, socket, sockmod, tcp_opts, ssl_opts,
                 host, port, user, password, database, queries, prepares,
-                auth_plugin_name, auth_plugin_data,
+                auth_plugin_name, auth_plugin_data, allowed_local_paths,
                 log_warnings, log_slow_queries,
                 connect_timeout, ping_timeout, query_timeout, query_cache_time,
                 affected_rows = 0, status = 0, warning_count = 0, insert_id = 0,
@@ -68,26 +68,29 @@ init(Opts) ->
         {local, _LocalAddr} -> 0;
         _NonLocalAddr -> ?default_port
     end,
-    Port           = proplists:get_value(port, Opts, DefaultPort),
-
-    User           = proplists:get_value(user, Opts, ?default_user),
-    Password       = proplists:get_value(password, Opts, ?default_password),
-    Database       = proplists:get_value(database, Opts, undefined),
-    LogWarn        = proplists:get_value(log_warnings, Opts, true),
-    LogSlow        = proplists:get_value(log_slow_queries, Opts, false),
-    KeepAlive      = proplists:get_value(keep_alive, Opts, false),
-    ConnectTimeout = proplists:get_value(connect_timeout, Opts,
-                                         ?default_connect_timeout),
-    QueryTimeout   = proplists:get_value(query_timeout, Opts,
-                                         ?default_query_timeout),
-    QueryCacheTime = proplists:get_value(query_cache_time, Opts,
-                                         ?default_query_cache_time),
-    TcpOpts        = proplists:get_value(tcp_options, Opts, []),
-    SetFoundRows   = proplists:get_value(found_rows, Opts, false),
-    SSLOpts        = proplists:get_value(ssl, Opts, undefined),
-
-    Queries        = proplists:get_value(queries, Opts, []),
-    Prepares       = proplists:get_value(prepare, Opts, []),
+    Port              = proplists:get_value(port, Opts, DefaultPort),
+
+    User              = proplists:get_value(user, Opts, ?default_user),
+    Password          = proplists:get_value(password, Opts, ?default_password),
+    Database          = proplists:get_value(database, Opts, undefined),
+    AllowedLocalPaths = proplists:get_value(allowed_local_paths, Opts, []),
+    LogWarn           = proplists:get_value(log_warnings, Opts, true),
+    LogSlow           = proplists:get_value(log_slow_queries, Opts, false),
+    KeepAlive         = proplists:get_value(keep_alive, Opts, false),
+    ConnectTimeout    = proplists:get_value(connect_timeout, Opts,
+                                            ?default_connect_timeout),
+    QueryTimeout      = proplists:get_value(query_timeout, Opts,
+                                            ?default_query_timeout),
+    QueryCacheTime    = proplists:get_value(query_cache_time, Opts,
+                                            ?default_query_cache_time),
+    TcpOpts           = proplists:get_value(tcp_options, Opts, []),
+    SetFoundRows      = proplists:get_value(found_rows, Opts, false),
+    SSLOpts           = proplists:get_value(ssl, Opts, undefined),
+
+    Queries           = proplists:get_value(queries, Opts, []),
+    Prepares          = proplists:get_value(prepare, Opts, []),
+
+    true = lists:all(fun mysql_protocol:valid_path/1, AllowedLocalPaths),
 
     PingTimeout = case KeepAlive of
         true         -> ?default_ping_timeout;
@@ -101,6 +104,7 @@ init(Opts) ->
         host = Host, port = Port,
         user = User, password = Password,
         database = Database,
+        allowed_local_paths = AllowedLocalPaths,
         queries = Queries, prepares = Prepares,
         log_warnings = LogWarn, log_slow_queries = LogSlow,
         connect_timeout = ConnectTimeout,
@@ -448,6 +452,7 @@ handle_call(start_transaction, {FromPid, _},
     end,
     setopts(SockMod, Socket, [{active, false}]),
     {ok, [Res = #ok{}]} = mysql_protocol:query(Query, SockMod, Socket,
+                                               [], no_filtermap_fun,
                                                ?cmd_timeout),
     setopts(SockMod, Socket, [{active, once}]),
     State1 = update_state(Res, State),
@@ -463,6 +468,7 @@ handle_call(rollback, {FromPid, _},
     end,
     setopts(SockMod, Socket, [{active, false}]),
     {ok, [Res = #ok{}]} = mysql_protocol:query(Query, SockMod, Socket,
+                                               [], no_filtermap_fun,
                                                ?cmd_timeout),
     setopts(SockMod, Socket, [{active, once}]),
     State1 = update_state(Res, State),
@@ -478,6 +484,7 @@ handle_call(commit, {FromPid, _},
     end,
     setopts(SockMod, Socket, [{active, false}]),
     {ok, [Res = #ok{}]} = mysql_protocol:query(Query, SockMod, Socket,
+                                               [], no_filtermap_fun,
                                                ?cmd_timeout),
     setopts(SockMod, Socket, [{active, once}]),
     State1 = update_state(Res, State),
@@ -545,15 +552,17 @@ code_change(_OldVsn, _State, _Extra) ->
 %% --- Helpers ---
 
 %% @doc Executes a prepared statement and returns {Reply, NewState}.
-execute_stmt(Stmt, Args, FilterMap, Timeout,
-             State = #state{socket = Socket, sockmod = SockMod}) ->
+execute_stmt(Stmt, Args, FilterMap, Timeout, State) ->
+    #state{socket = Socket, sockmod = SockMod,
+           allowed_local_paths = AllowedPaths} = State,
     setopts(SockMod, Socket, [{active, false}]),
     {ok, Recs} = case mysql_protocol:execute(Stmt, Args, SockMod, Socket,
-                                             FilterMap, Timeout) of
+                                             AllowedPaths, FilterMap,
+                                             Timeout) of
         {error, timeout} when State#state.server_version >= [5, 0, 0] ->
             kill_query(State),
             mysql_protocol:fetch_execute_response(SockMod, Socket,
-                                                  FilterMap, ?cmd_timeout);
+                                                  [], FilterMap, ?cmd_timeout);
         {error, timeout} ->
             %% For MySQL 4.x.x there is no way to recover from timeout except
             %% killing the connection itself.
@@ -595,14 +604,17 @@ update_state(Rec, State) ->
 query(Query, FilterMap, default_timeout,
       #state{query_timeout = DefaultTimeout} = State) ->
     query(Query, FilterMap, DefaultTimeout, State);
-query(Query, FilterMap, Timeout,
-      #state{sockmod = SockMod, socket = Socket} = State) ->
+query(Query, FilterMap, Timeout, State) ->
+    #state{sockmod = SockMod, socket = Socket,
+           allowed_local_paths = AllowedPaths} = State,
     setopts(SockMod, Socket, [{active, false}]),
-    Result = mysql_protocol:query(Query, SockMod, Socket, FilterMap, Timeout),
+    Result = mysql_protocol:query(Query, SockMod, Socket, AllowedPaths,
+                                  FilterMap, Timeout),
     {ok, Recs} = case Result of
         {error, timeout} when State#state.server_version >= [5, 0, 0] ->
             kill_query(State),
-            mysql_protocol:fetch_query_response(SockMod, Socket, FilterMap,
+            mysql_protocol:fetch_query_response(SockMod, Socket,
+                                                [], FilterMap,
                                                 ?cmd_timeout);
         {error, timeout} ->
             %% For MySQL 4.x.x there is no way to recover from timeout except
@@ -703,6 +715,7 @@ log_warnings(#state{socket = Socket, sockmod = SockMod}, Query) ->
     setopts(SockMod, Socket, [{active, false}]),
     {ok, [#resultset{rows = Rows}]} = mysql_protocol:query(<<"SHOW WARNINGS">>,
                                                            SockMod, Socket,
+                                                           [], no_filtermap_fun,
                                                            ?cmd_timeout),
     setopts(SockMod, Socket, [{active, once}]),
     Lines = [[Level, " ", integer_to_binary(Code), ": ", Message, "\n"]
@@ -748,7 +761,9 @@ kill_query(#state{connection_id = ConnId, host = Host, port = Port,
             %% Kill and disconnect
             IdBin = integer_to_binary(ConnId),
             {ok, [#ok{}]} = mysql_protocol:query(<<"KILL QUERY ", IdBin/binary>>,
-                                                 SockMod, Socket, ?cmd_timeout),
+                                                 SockMod, Socket,
+                                                 [], no_filtermap_fun,
+                                                 ?cmd_timeout),
             mysql_protocol:quit(SockMod, Socket);
         #error{} = E ->
             error_logger:error_msg("Failed to connect to kill query: ~p",

+ 199 - 42
src/mysql_protocol.erl

@@ -28,16 +28,11 @@
 -module(mysql_protocol).
 
 -export([handshake/8, change_user/8, quit/2, ping/2,
-         query/4, query/5, fetch_query_response/3,
-         fetch_query_response/4, prepare/3, unprepare/3,
-         execute/5, execute/6, fetch_execute_response/3,
-         fetch_execute_response/4, reset_connnection/2,
-         valid_params/1]).
+         query/6, fetch_query_response/5, prepare/3, unprepare/3,
+         execute/7, fetch_execute_response/5, reset_connnection/2,
+         valid_params/1, valid_path/1]).
 
--type query_filtermap() :: no_filtermap_fun
-                         | fun(([term()]) -> query_filtermap_res())
-                         | fun(([term()], [term()]) -> query_filtermap_res()).
--type query_filtermap_res() :: boolean() | {true, term()}.
+-type query_filtermap() :: no_filtermap_fun | mysql:query_filtermap_fun().
 
 -type auth_more_data() :: fast_auth_completed
                         | full_auth_requested
@@ -56,6 +51,7 @@
 -define(ok_pattern, <<?OK, _/binary>>).
 -define(error_pattern, <<?ERROR, _/binary>>).
 -define(eof_pattern, <<?EOF, _:4/binary>>).
+-define(local_infile_pattern, <<?LOCAL_INFILE_REQUEST, _/binary>>).
 
 %% Macros for auth methods.
 -define(authmethod_none, <<>>).
@@ -74,6 +70,7 @@
                 SetFoundRows :: boolean()) ->
     {ok, #handshake{}, SockModule :: module(), Socket :: term()} |
     #error{}.
+
 handshake(Host, Username, Password, Database, SockModule0, SSLOpts, Socket0,
           SetFoundRows) ->
     SeqNum0 = 0,
@@ -196,26 +193,19 @@ ping(SockModule, Socket) ->
     {ok, OkPacket, _SeqNum2} = recv_packet(SockModule, Socket, SeqNum1),
     parse_ok_packet(OkPacket).
 
--spec query(Query :: iodata(), module(), term(), timeout()) ->
-    {ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
-query(Query, SockModule, Socket, Timeout) ->
-    query(Query, SockModule, Socket, no_filtermap_fun, Timeout).
-
--spec query(Query :: iodata(), module(), term(), query_filtermap(), timeout()) ->
+-spec query(Query :: iodata(), module(), term(), [binary()], query_filtermap(),
+            timeout()) ->
     {ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
-query(Query, SockModule, Socket, FilterMap, Timeout) ->
+query(Query, SockModule, Socket, AllowedPaths, FilterMap, Timeout) ->
     Req = <<?COM_QUERY, (iolist_to_binary(Query))/binary>>,
     SeqNum0 = 0,
     {ok, _SeqNum1} = send_packet(SockModule, Socket, Req, SeqNum0),
-    fetch_query_response(SockModule, Socket, FilterMap, Timeout).
+    fetch_query_response(SockModule, Socket, AllowedPaths, FilterMap, Timeout).
 
 %% @doc This is used by query/4. If query/4 returns {error, timeout}, this
 %% function can be called to retry to fetch the results of the query.
-fetch_query_response(SockModule, Socket, Timeout) ->
-    fetch_query_response(SockModule, Socket, no_filtermap_fun, Timeout).
-
-fetch_query_response(SockModule, Socket, FilterMap, Timeout) ->
-    fetch_response(SockModule, Socket, Timeout, text, FilterMap, []).
+fetch_query_response(SockModule, Socket, AllowedPaths, FilterMap, Timeout) ->
+    fetch_response(SockModule, Socket, Timeout, text, AllowedPaths, FilterMap, []).
 
 %% @doc Prepares a statement.
 -spec prepare(iodata(), module(), term()) -> #error{} | #prepared{}.
@@ -260,16 +250,11 @@ unprepare(#prepared{statement_id = Id}, SockModule, Socket) ->
     ok.
 
 %% @doc Executes a prepared statement.
--spec execute(#prepared{}, [term()], module(), term(), timeout()) ->
-    {ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
-execute(PrepStmt, ParamValues, SockModule, Socket, Timeout) ->
-    execute(PrepStmt, ParamValues, SockModule, Socket, no_filtermap_fun,
-            Timeout).
--spec execute(#prepared{}, [term()], module(), term(), query_filtermap(),
-              timeout()) ->
+-spec execute(#prepared{}, [term()], module(), term(), [binary()],
+              query_filtermap(), timeout()) ->
     {ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
 execute(#prepared{statement_id = Id, param_count = ParamCount}, ParamValues,
-        SockModule, Socket, FilterMap, Timeout)
+        SockModule, Socket, AllowedPaths, FilterMap, Timeout)
   when ParamCount == length(ParamValues) ->
     %% Flags Constant Name
     %% 0x00 CURSOR_TYPE_NO_CURSOR
@@ -297,15 +282,12 @@ execute(#prepared{statement_id = Id, param_count = ParamCount}, ParamValues,
             iolist_to_binary([Req1, TypesAndSigns, EncValues])
     end,
     {ok, _SeqNum1} = send_packet(SockModule, Socket, Req, 0),
-    fetch_execute_response(SockModule, Socket, FilterMap, Timeout).
+    fetch_execute_response(SockModule, Socket, AllowedPaths, FilterMap, Timeout).
 
 %% @doc This is used by execute/5. If execute/5 returns {error, timeout}, this
 %% function can be called to retry to fetch the results of the query.
-fetch_execute_response(SockModule, Socket, Timeout) ->
-    fetch_execute_response(SockModule, Socket, no_filtermap_fun, Timeout).
-
-fetch_execute_response(SockModule, Socket, FilterMap, Timeout) ->
-    fetch_response(SockModule, Socket, Timeout, binary, FilterMap, []).
+fetch_execute_response(SockModule, Socket, AllowedPaths, FilterMap, Timeout) ->
+    fetch_response(SockModule, Socket, Timeout, binary, AllowedPaths, FilterMap, []).
 
 %% @doc Changes the user of the connection.
 -spec change_user(module(), term(), iodata(), iodata(), binary(), binary(),
@@ -461,7 +443,8 @@ build_handshake_response(Handshake, Database, SetFoundRows) ->
 %% @doc The response sent by the client to the server after receiving the
 %% initial handshake from the server
 -spec build_handshake_response(#handshake{}, iodata(), iodata(),
-                               iodata() | undefined, boolean()) -> binary().
+                               iodata() | undefined, boolean()) ->
+    binary().
 build_handshake_response(Handshake, Username, Password, Database,
                          SetFoundRows) ->
     CapabilityFlags = basic_capabilities(Database /= undefined, SetFoundRows),
@@ -519,7 +502,8 @@ add_client_capabilities(Caps) ->
     ?CLIENT_MULTI_RESULTS bor
     ?CLIENT_PS_MULTI_RESULTS bor
     ?CLIENT_PLUGIN_AUTH bor
-    ?CLIENT_LONG_PASSWORD.
+    ?CLIENT_LONG_PASSWORD bor
+    ?CLIENT_LOCAL_FILES.
 
 -spec character_set([integer()]) -> integer().
 character_set(ServerVersion) when ServerVersion >= [5, 5, 3] ->
@@ -564,11 +548,29 @@ parse_handshake_confirm(<<?MORE_DATA, MoreData/binary>>) ->
 %% @doc Fetches one or more results and and parses the result set(s) using
 %% either the text format (for plain queries) or the binary format (for
 %% prepared statements).
--spec fetch_response(module(), term(), timeout(), text | binary,
+-spec fetch_response(module(), term(), timeout(), text | binary, [binary()],
                      query_filtermap(), list()) ->
     {ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
-fetch_response(SockModule, Socket, Timeout, Proto, FilterMap, Acc) ->
+fetch_response(SockModule, Socket, Timeout, Proto, AllowedPaths, FilterMap, Acc) ->
     case recv_packet(SockModule, Socket, Timeout, any) of
+        {ok, ?local_infile_pattern = Packet, SeqNum2} ->
+            Filename = parse_local_infile_packet(Packet),
+            Acc1 = case send_file(SockModule, Socket, Filename, AllowedPaths, SeqNum2) of
+                {ok, _SeqNum3} ->
+                    Acc;
+                {{error, not_allowed}, _SeqNum3} ->
+                    ErrorMsg = <<"The server requested a file not permitted by the client: ",
+                                 Filename/binary>>,
+                    [#error{code = -1, msg = ErrorMsg}|Acc];
+                {{error, FileError}, _SeqNum3} ->
+                    FileErrorMsg = list_to_binary(file:format_error(FileError)),
+                    ErrorMsg = <<"The server requested a file which could not be opened "
+                                 "by the client: ", Filename/binary,
+                                 " (", FileErrorMsg/binary, ")">>,
+                    [#error{code = -2, msg = ErrorMsg}|Acc]
+            end,
+            fetch_response(SockModule, Socket, Timeout, Proto, AllowedPaths,
+                           FilterMap, Acc1);
         {ok, Packet, SeqNum2} ->
             Result = case Packet of
                 ?ok_pattern ->
@@ -585,7 +587,7 @@ fetch_response(SockModule, Socket, Timeout, Proto, FilterMap, Acc) ->
             case more_results_exists(Result) of
                 true ->
                     fetch_response(SockModule, Socket, Timeout, Proto,
-                                   FilterMap, Acc1);
+                                   AllowedPaths, FilterMap, Acc1);
                 false ->
                     {ok, lists:reverse(Acc1)}
             end;
@@ -641,7 +643,7 @@ fetch_resultset_rows(SockModule, Socket, FieldCount, ColDefs, Proto,
     end.
 
 -spec filtermap_resultset_row(query_filtermap(), [#col{}], [term()]) ->
-    query_filtermap_res().
+    boolean() | {true, term()}.
 filtermap_resultset_row(no_filtermap_fun, _, _) ->
     true;
 filtermap_resultset_row(Fun, _, Row) when is_function(Fun, 1) ->
@@ -1220,6 +1222,94 @@ recv_packet(SockModule, Socket, Timeout, ExpectSeqNum, Acc) ->
             {error, Reason}
     end.
 
+-spec send_file(module(), term(), Filename :: binary(), AllowedPaths :: [binary()],
+                SeqNum :: integer()) ->
+    {ok | {error, Reason}, NextSeqNum :: integer()}
+    when Reason :: not_allowed
+	         | file:posix()
+		 | badarg
+		 | system_limit.
+send_file(SockModule, Socket, Filename, AllowedPaths, SeqNum0) ->
+    {Result, SeqNum1} = case allowed_path(Filename, AllowedPaths) andalso
+                             file:open(Filename, [read, raw, binary]) of
+        false ->
+            {{error, not_allowed}, SeqNum0};
+        {ok, Handle} ->
+            {ok, SeqNum2} = send_file_chunk(SockModule, Socket, Handle, SeqNum0),
+            ok = file:close(Handle),
+            {ok, SeqNum2};
+        {error, _Reason} = E ->
+            {E, SeqNum0}
+    end,
+    {ok, SeqNum3} = send_packet(SockModule, Socket, <<>>, SeqNum1),
+    {Result, SeqNum3}.
+
+-spec allowed_path(binary(), [binary()]) -> boolean().
+allowed_path(Path, AllowedPaths) ->
+    valid_path(Path) andalso
+    binary:last(Path) =/= $/ andalso
+    lists:any(
+        fun
+            (AllowedPath) when Path =:= AllowedPath ->
+                true;
+            (AllowedPath) ->
+                Size = byte_size(AllowedPath),
+                HasSlash = binary:last(AllowedPath) =:= $/,
+                case Path of
+                    <<AllowedPath:Size/binary, _/binary>> when HasSlash -> true;
+                    <<AllowedPath:Size/binary, $/, _/binary>> -> true;
+                    _ -> false
+                end
+        end,
+        AllowedPaths
+    ).
+
+%% @doc Checks if the argument is a valid path.
+%%
+%% Returns `true' if the argument is an absolute path that does not contain
+%% any relative components like `..' or `.', otherwise `false'.
+-spec valid_path(term()) -> boolean().
+valid_path(Path) when is_binary(Path), byte_size(Path) > 0 ->
+    case filename:pathtype(Path) of
+        absolute ->
+            valid_abspath(Path);
+        volumerelative ->
+            case Path of
+                <<$/, _/binary>> ->
+                    false;
+                _ ->
+                    valid_abspath(Path)
+            end;
+        relative ->
+            false
+    end;
+valid_path(_Path) ->
+    false.
+
+-spec valid_abspath(<<_:8, _:_*8>>) -> boolean().
+valid_abspath(Path) ->
+    lists:all(
+        fun
+            (<<".">>) -> false;
+            (<<"..">>) -> false;
+            (_) -> true
+        end,
+        filename:split(Path)
+    ).
+
+-spec send_file_chunk(module(), term(), Handle :: file:io_device(), SeqNum :: integer()) ->
+    {ok, NextSeqNum :: integer()}.
+send_file_chunk(SockModule, Socket, Handle, SeqNum0) ->
+    case file:read(Handle, 16#ffffff) of
+        eof ->
+            {ok, SeqNum0};
+        {ok, <<>>} ->
+            send_file_chunk(SockModule, Socket, Handle, SeqNum0);
+        {ok, Data} ->
+            {ok, SeqNum1} = send_packet(SockModule, Socket, Data, SeqNum0),
+            send_file_chunk(SockModule, Socket, Handle, SeqNum1)
+    end.
+
 %% @doc Parses a packet header (32 bits) and returns a tuple.
 %%
 %% The client should first read a header and parse it. Then read PacketLength
@@ -1293,6 +1383,9 @@ parse_eof_packet(<<?EOF:8, NumWarnings:16/little, StatusFlags:16/little>>) ->
     %% (Older protocol: <<?EOF:8>>)
     #eof{status = StatusFlags, warning_count = NumWarnings}.
 
+parse_local_infile_packet(<<?LOCAL_INFILE_REQUEST:8, FileName/binary>>) ->
+    FileName.
+
 -spec parse_auth_method_switch(binary()) -> #auth_method_switch{}.
 parse_auth_method_switch(AMSData) ->
     {AuthPluginName, AuthPluginData} = get_null_terminated_binary(AMSData),
@@ -1688,5 +1781,69 @@ valid_params_test() ->
     ?assertNot(valid_params(InvalidParams)),
     ?assertNot(valid_params(ValidParams ++ InvalidParams)).
 
+valid_path_test() ->
+    ValidPaths = [
+        <<"/">>,
+        <<"/tmp">>,
+        <<"/tmp/">>,
+        <<"/tmp/foo">>
+    ],
+    InvalidPaths = [
+        <<>>,
+        <<"tmp">>,
+        <<"tmp/">>,
+        <<"tmp/foo">>,
+        <<"../tmp">>,
+        <<"/tmp/..">>,
+        <<"/tmp/foo/../bar">>,
+        "/tmp"
+    ],
+    lists:foreach(
+        fun (ValidPath) ->
+            ?assert(valid_path(ValidPath))
+        end,
+        ValidPaths
+    ),
+    lists:foreach(
+        fun (InvalidPath) ->
+            ?assertNot(valid_path(InvalidPath))
+        end,
+        InvalidPaths
+    ).
+
+allowed_path_test() ->
+    AllowedPaths = [
+        <<"/tmp/foo/file.csv">>,
+        <<"/tmp/foo/bar/">>,
+        <<"/tmp/foo/baz">>
+    ],
+    ValidPaths = [
+        <<"/tmp/foo/file.csv">>,
+        <<"/tmp/foo/bar/file.csv">>,
+        <<"/tmp/foo/baz/file.csv">>,
+        <<"/tmp/foo/baz">>
+    ],
+    InvalidPaths = [
+        <<"/tmp/file.csv">>,
+        <<"/tmp/foo/other_file.csv">>,
+        <<"/tmp/foo/other_dir/file.csv">>,
+        <<"/tmp/foo/../file.csv">>,
+        <<"/tmp/foo/../bar/file.csv">>,
+        <<"/tmp/foo/bar/">>,
+        <<"/tmp/foo/barbaz">>
+    ],
+    lists:foreach(
+        fun (ValidPath) ->
+            ?assert(allowed_path(ValidPath, AllowedPaths))
+        end,
+        ValidPaths
+    ),
+    lists:foreach(
+        fun (InvalidPath) ->
+            ?assertNot(allowed_path(InvalidPath, AllowedPaths))
+        end,
+        InvalidPaths
+    ).
+
 -endif.
 

+ 4 - 2
test/mysql_protocol_tests.erl

@@ -44,7 +44,8 @@ resultset_test() ->
     ExpectedCommunication = [{send, ExpectedReq},
                              {recv, ExpectedResponse}],
     Sock = mock_tcp:create(ExpectedCommunication),
-    {ok, [ResultSet]} = mysql_protocol:query(Query, mock_tcp, Sock, infinity),
+    {ok, [ResultSet]} = mysql_protocol:query(Query, mock_tcp, Sock, [],
+                                             no_filtermap_fun, infinity),
     mock_tcp:close(Sock),
     ?assertMatch(#resultset{cols = [#col{name = <<"@@version_comment">>}],
                             rows = [[<<"MySQL Community Server (GPL)">>]]},
@@ -81,7 +82,8 @@ resultset_error_test() ->
         "48 04 23 48 59 30 30 30    4e 6f 20 74 61 62 6c 65    H.#HY000No table"
         "73 20 75 73 65 64                                     s used"),
     Sock = mock_tcp:create([{send, ExpectedReq}, {recv, ExpectedResponse}]),
-    {ok, [Result]} = mysql_protocol:query(Query, mock_tcp, Sock, infinity),
+    {ok, [Result]} = mysql_protocol:query(Query, mock_tcp, Sock, [],
+                                          no_filtermap_fun, infinity),
     ?assertMatch(#error{}, Result),
     mock_tcp:close(Sock),
     ok.

+ 112 - 19
test/mysql_tests.erl

@@ -317,25 +317,55 @@ query_test_() ->
          mysql:stop(Pid)
      end,
      fun (Pid) ->
-         [{"Select db on connect", fun () -> connect_with_db(Pid) end},
-          {"Autocommit",           fun () -> autocommit(Pid) end},
-          {"Encode",               fun () -> encode(Pid) end},
-          {"Basic queries",        fun () -> basic_queries(Pid) end},
-          {"Filtermap queries",    fun () -> filtermap_queries(Pid) end},
-          {"FOUND_ROWS option",    fun () -> found_rows(Pid) end},
-          {"Multi statements",     fun () -> multi_statements(Pid) end},
-          {"Text protocol",        fun () -> text_protocol(Pid) end},
-          {"Binary protocol",      fun () -> binary_protocol(Pid) end},
-          {"FLOAT rounding",       fun () -> float_rounding(Pid) end},
-          {"DECIMAL",              fun () -> decimal(Pid) end},
-          {"INT",                  fun () -> int(Pid) end},
-          {"BIT(N)",               fun () -> bit(Pid) end},
-          {"DATE",                 fun () -> date(Pid) end},
-          {"TIME",                 fun () -> time(Pid) end},
-          {"DATETIME",             fun () -> datetime(Pid) end},
-          {"JSON",                 fun () -> json(Pid) end},
-          {"Microseconds",         fun () -> microseconds(Pid) end},
-          {"Invalid params",       fun () -> invalid_params(Pid) end}]
+         [{"Select db on connect",  fun () -> connect_with_db(Pid) end},
+          {"Autocommit",            fun () -> autocommit(Pid) end},
+          {"Encode",                fun () -> encode(Pid) end},
+          {"Basic queries",         fun () -> basic_queries(Pid) end},
+          {"Filtermap queries",     fun () -> filtermap_queries(Pid) end},
+          {"FOUND_ROWS option",     fun () -> found_rows(Pid) end},
+          {"Multi statements",      fun () -> multi_statements(Pid) end},
+          {"Text protocol",         fun () -> text_protocol(Pid) end},
+          {"Binary protocol",       fun () -> binary_protocol(Pid) end},
+          {"FLOAT rounding",        fun () -> float_rounding(Pid) end},
+          {"DECIMAL",               fun () -> decimal(Pid) end},
+          {"INT",                   fun () -> int(Pid) end},
+          {"BIT(N)",                fun () -> bit(Pid) end},
+          {"DATE",                  fun () -> date(Pid) end},
+          {"TIME",                  fun () -> time(Pid) end},
+          {"DATETIME",              fun () -> datetime(Pid) end},
+          {"JSON",                  fun () -> json(Pid) end},
+          {"Microseconds",          fun () -> microseconds(Pid) end},
+          {"Invalid params",        fun () -> invalid_params(Pid) end}]
+     end}.
+
+local_files_test_() ->
+    {setup,
+     fun () ->
+         {ok, Cwd0} = file:get_cwd(),
+         Cwd1 = iolist_to_binary(Cwd0),
+         Cwd2 = case binary:last(Cwd1) of
+             $/ -> Cwd1;
+             _ -> <<Cwd1/binary, $/>>
+         end,
+         {ok, Pid} = mysql:start_link([{user, ?user}, {password, ?password},
+                                       {log_warnings, false},
+                                       {keep_alive, true}, {allowed_local_paths, [Cwd2]}]),
+         ok = mysql:query(Pid, <<"DROP DATABASE IF EXISTS otptest">>),
+         ok = mysql:query(Pid, <<"CREATE DATABASE otptest">>),
+         ok = mysql:query(Pid, <<"USE otptest">>),
+         ok = mysql:query(Pid, <<"SET autocommit = 1">>),
+         ok = mysql:query(Pid, <<"SET SESSION sql_mode = ?">>, [?SQL_MODE]),
+         {Pid, Cwd2}
+     end,
+     fun ({Pid, _Cwd}) ->
+         ok = mysql:query(Pid, <<"DROP DATABASE otptest">>),
+         mysql:stop(Pid)
+     end,
+     fun ({Pid, Cwd}) ->
+          [{"Single statement", fun () -> load_data_local_infile(Pid, Cwd) end},
+          {"Missing file", fun () -> load_data_local_infile_missing(Pid, Cwd) end},
+          {"Not allowed", fun () -> load_data_local_infile_not_allowed(Pid, Cwd) end},
+          {"Multi statements", fun () -> load_data_local_infile_multi(Pid, Cwd) end}]
      end}.
 
 connect_with_db(_Pid) ->
@@ -861,6 +891,69 @@ invalid_params(Pid) ->
     ?assertError(badarg, mysql:query(Pid, "SELECT ?", [x])),
     ok = mysql:unprepare(Pid, StmtId).
 
+load_data_local_infile(Pid, Cwd) ->
+    File = iolist_to_binary(filename:join([Cwd, "load_local_infile_test.csv"])),
+    ok = file:write_file(File, <<"1;value 1\n2;value 2\n">>),
+    ok = mysql:query(Pid, <<"CREATE TABLE load_local_test (id int, value blob)">>),
+    ok = mysql:query(Pid, <<"LOAD DATA LOCAL "
+                            "INFILE '", File/binary, "' "
+                            "INTO TABLE load_local_test "
+                            "FIELDS TERMINATED BY ';' "
+                            "LINES TERMINATED BY '\\n'">>),
+    ok = file:delete(File),
+    {ok, Columns, Rows} = mysql:query(Pid,
+                                      <<"SELECT * FROM load_local_test ORDER BY id">>),
+    ?assertEqual([<<"id">>, <<"value">>], Columns),
+    ?assertEqual([[1, <<"value 1">>], [2, <<"value 2">>]], Rows),
+    ok = mysql:query(Pid, <<"DROP TABLE load_local_test">>).
+
+load_data_local_infile_missing(Pid, Cwd) ->
+    File = iolist_to_binary(filename:join([Cwd, "load_local_infile_missing_test.csv"])),
+    ok = mysql:query(Pid, <<"CREATE TABLE load_local_test (id int, value blob)">>),
+    Result = mysql:query(Pid, <<"LOAD DATA LOCAL "
+                                "INFILE '", File/binary, "' "
+                                "INTO TABLE load_local_test "
+                                "FIELDS TERMINATED BY ';' "
+                                "LINES TERMINATED BY '\\n'">>),
+    FilenameSize=byte_size(File),
+    ?assertMatch({error, {-2, undefined, <<"The server requested a file which could "
+                                           "not be opened by the client: ",
+                                           File:FilenameSize/binary, _/binary>>}},
+                 Result),
+    ok = mysql:query(Pid, <<"DROP TABLE load_local_test">>).
+
+load_data_local_infile_not_allowed(Pid, Cwd) ->
+    File = iolist_to_binary(filename:join([Cwd, "../load_local_infile_not_allowed_test.csv"])),
+    ok = mysql:query(Pid, <<"CREATE TABLE load_local_test (id int, value blob)">>),
+    Result = mysql:query(Pid, <<"LOAD DATA LOCAL "
+                                "INFILE '", File/binary, "' "
+                                "INTO TABLE load_local_test "
+                                "FIELDS TERMINATED BY ';' "
+                                "LINES TERMINATED BY '\\n'">>),
+    ?assertEqual({error, {-1, undefined, <<"The server requested a file not permitted "
+                                           "by the client: ", File/binary>>}}, Result),
+    ok = mysql:query(Pid, <<"DROP TABLE load_local_test">>).
+
+load_data_local_infile_multi(Pid, Cwd) ->
+    File = iolist_to_binary(filename:join([Cwd, "load_local_infile_test.csv"])),
+    ok = file:write_file(File, <<"1;value 1\n2;value 2\n">>),
+    ok = mysql:query(Pid, <<"CREATE TABLE load_local_test (id int, value blob)">>),
+    {ok, [Res1, Res2]} = mysql:query(Pid, <<"SELECT 'foo'; "
+                                            "LOAD DATA LOCAL "
+                                            "INFILE '", File/binary, "' "
+                                            "INTO TABLE load_local_test "
+                                            "FIELDS TERMINATED BY ';' "
+                                            "LINES TERMINATED BY '\\n'; "
+                                            "SELECT 'bar'">>),
+    ok = file:delete(File),
+    ?assertEqual({[<<"foo">>], [[<<"foo">>]]}, Res1),
+    ?assertEqual({[<<"bar">>], [[<<"bar">>]]}, Res2),
+    {ok, Columns, Rows} = mysql:query(Pid,
+                                      <<"SELECT * FROM load_local_test ORDER BY id">>),
+    ?assertEqual([<<"id">>, <<"value">>], Columns),
+    ?assertEqual([[1, <<"value 1">>], [2, <<"value 2">>]], Rows),
+    ok = mysql:query(Pid, <<"DROP TABLE load_local_test">>).
+
 %% @doc Tests write and read in text and the binary protocol, all combinations.
 %% This helper function assumes an empty table with a single column.
 write_read_text_binary(Conn, Term, SqlLiteral, Table, Column) ->