|
@@ -28,16 +28,11 @@
|
|
|
-module(mysql_protocol).
|
|
|
|
|
|
-export([handshake/8, change_user/8, quit/2, ping/2,
|
|
|
- query/4, query/5, fetch_query_response/3,
|
|
|
- fetch_query_response/4, prepare/3, unprepare/3,
|
|
|
- execute/5, execute/6, fetch_execute_response/3,
|
|
|
- fetch_execute_response/4, reset_connnection/2,
|
|
|
- valid_params/1]).
|
|
|
+ query/6, fetch_query_response/5, prepare/3, unprepare/3,
|
|
|
+ execute/7, fetch_execute_response/5, reset_connnection/2,
|
|
|
+ valid_params/1, valid_path/1]).
|
|
|
|
|
|
--type query_filtermap() :: no_filtermap_fun
|
|
|
- | fun(([term()]) -> query_filtermap_res())
|
|
|
- | fun(([term()], [term()]) -> query_filtermap_res()).
|
|
|
--type query_filtermap_res() :: boolean() | {true, term()}.
|
|
|
+-type query_filtermap() :: no_filtermap_fun | mysql:query_filtermap_fun().
|
|
|
|
|
|
-type auth_more_data() :: fast_auth_completed
|
|
|
| full_auth_requested
|
|
@@ -56,6 +51,7 @@
|
|
|
-define(ok_pattern, <<?OK, _/binary>>).
|
|
|
-define(error_pattern, <<?ERROR, _/binary>>).
|
|
|
-define(eof_pattern, <<?EOF, _:4/binary>>).
|
|
|
+-define(local_infile_pattern, <<?LOCAL_INFILE_REQUEST, _/binary>>).
|
|
|
|
|
|
%% Macros for auth methods.
|
|
|
-define(authmethod_none, <<>>).
|
|
@@ -74,6 +70,7 @@
|
|
|
SetFoundRows :: boolean()) ->
|
|
|
{ok, #handshake{}, SockModule :: module(), Socket :: term()} |
|
|
|
#error{}.
|
|
|
+
|
|
|
handshake(Host, Username, Password, Database, SockModule0, SSLOpts, Socket0,
|
|
|
SetFoundRows) ->
|
|
|
SeqNum0 = 0,
|
|
@@ -196,26 +193,19 @@ ping(SockModule, Socket) ->
|
|
|
{ok, OkPacket, _SeqNum2} = recv_packet(SockModule, Socket, SeqNum1),
|
|
|
parse_ok_packet(OkPacket).
|
|
|
|
|
|
--spec query(Query :: iodata(), module(), term(), timeout()) ->
|
|
|
- {ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
|
|
|
-query(Query, SockModule, Socket, Timeout) ->
|
|
|
- query(Query, SockModule, Socket, no_filtermap_fun, Timeout).
|
|
|
-
|
|
|
--spec query(Query :: iodata(), module(), term(), query_filtermap(), timeout()) ->
|
|
|
+-spec query(Query :: iodata(), module(), term(), [binary()], query_filtermap(),
|
|
|
+ timeout()) ->
|
|
|
{ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
|
|
|
-query(Query, SockModule, Socket, FilterMap, Timeout) ->
|
|
|
+query(Query, SockModule, Socket, AllowedPaths, FilterMap, Timeout) ->
|
|
|
Req = <<?COM_QUERY, (iolist_to_binary(Query))/binary>>,
|
|
|
SeqNum0 = 0,
|
|
|
{ok, _SeqNum1} = send_packet(SockModule, Socket, Req, SeqNum0),
|
|
|
- fetch_query_response(SockModule, Socket, FilterMap, Timeout).
|
|
|
+ fetch_query_response(SockModule, Socket, AllowedPaths, FilterMap, Timeout).
|
|
|
|
|
|
%% @doc This is used by query/4. If query/4 returns {error, timeout}, this
|
|
|
%% function can be called to retry to fetch the results of the query.
|
|
|
-fetch_query_response(SockModule, Socket, Timeout) ->
|
|
|
- fetch_query_response(SockModule, Socket, no_filtermap_fun, Timeout).
|
|
|
-
|
|
|
-fetch_query_response(SockModule, Socket, FilterMap, Timeout) ->
|
|
|
- fetch_response(SockModule, Socket, Timeout, text, FilterMap, []).
|
|
|
+fetch_query_response(SockModule, Socket, AllowedPaths, FilterMap, Timeout) ->
|
|
|
+ fetch_response(SockModule, Socket, Timeout, text, AllowedPaths, FilterMap, []).
|
|
|
|
|
|
%% @doc Prepares a statement.
|
|
|
-spec prepare(iodata(), module(), term()) -> #error{} | #prepared{}.
|
|
@@ -260,16 +250,11 @@ unprepare(#prepared{statement_id = Id}, SockModule, Socket) ->
|
|
|
ok.
|
|
|
|
|
|
%% @doc Executes a prepared statement.
|
|
|
--spec execute(#prepared{}, [term()], module(), term(), timeout()) ->
|
|
|
- {ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
|
|
|
-execute(PrepStmt, ParamValues, SockModule, Socket, Timeout) ->
|
|
|
- execute(PrepStmt, ParamValues, SockModule, Socket, no_filtermap_fun,
|
|
|
- Timeout).
|
|
|
--spec execute(#prepared{}, [term()], module(), term(), query_filtermap(),
|
|
|
- timeout()) ->
|
|
|
+-spec execute(#prepared{}, [term()], module(), term(), [binary()],
|
|
|
+ query_filtermap(), timeout()) ->
|
|
|
{ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
|
|
|
execute(#prepared{statement_id = Id, param_count = ParamCount}, ParamValues,
|
|
|
- SockModule, Socket, FilterMap, Timeout)
|
|
|
+ SockModule, Socket, AllowedPaths, FilterMap, Timeout)
|
|
|
when ParamCount == length(ParamValues) ->
|
|
|
%% Flags Constant Name
|
|
|
%% 0x00 CURSOR_TYPE_NO_CURSOR
|
|
@@ -297,15 +282,12 @@ execute(#prepared{statement_id = Id, param_count = ParamCount}, ParamValues,
|
|
|
iolist_to_binary([Req1, TypesAndSigns, EncValues])
|
|
|
end,
|
|
|
{ok, _SeqNum1} = send_packet(SockModule, Socket, Req, 0),
|
|
|
- fetch_execute_response(SockModule, Socket, FilterMap, Timeout).
|
|
|
+ fetch_execute_response(SockModule, Socket, AllowedPaths, FilterMap, Timeout).
|
|
|
|
|
|
%% @doc This is used by execute/5. If execute/5 returns {error, timeout}, this
|
|
|
%% function can be called to retry to fetch the results of the query.
|
|
|
-fetch_execute_response(SockModule, Socket, Timeout) ->
|
|
|
- fetch_execute_response(SockModule, Socket, no_filtermap_fun, Timeout).
|
|
|
-
|
|
|
-fetch_execute_response(SockModule, Socket, FilterMap, Timeout) ->
|
|
|
- fetch_response(SockModule, Socket, Timeout, binary, FilterMap, []).
|
|
|
+fetch_execute_response(SockModule, Socket, AllowedPaths, FilterMap, Timeout) ->
|
|
|
+ fetch_response(SockModule, Socket, Timeout, binary, AllowedPaths, FilterMap, []).
|
|
|
|
|
|
%% @doc Changes the user of the connection.
|
|
|
-spec change_user(module(), term(), iodata(), iodata(), binary(), binary(),
|
|
@@ -461,7 +443,8 @@ build_handshake_response(Handshake, Database, SetFoundRows) ->
|
|
|
%% @doc The response sent by the client to the server after receiving the
|
|
|
%% initial handshake from the server
|
|
|
-spec build_handshake_response(#handshake{}, iodata(), iodata(),
|
|
|
- iodata() | undefined, boolean()) -> binary().
|
|
|
+ iodata() | undefined, boolean()) ->
|
|
|
+ binary().
|
|
|
build_handshake_response(Handshake, Username, Password, Database,
|
|
|
SetFoundRows) ->
|
|
|
CapabilityFlags = basic_capabilities(Database /= undefined, SetFoundRows),
|
|
@@ -519,7 +502,8 @@ add_client_capabilities(Caps) ->
|
|
|
?CLIENT_MULTI_RESULTS bor
|
|
|
?CLIENT_PS_MULTI_RESULTS bor
|
|
|
?CLIENT_PLUGIN_AUTH bor
|
|
|
- ?CLIENT_LONG_PASSWORD.
|
|
|
+ ?CLIENT_LONG_PASSWORD bor
|
|
|
+ ?CLIENT_LOCAL_FILES.
|
|
|
|
|
|
-spec character_set([integer()]) -> integer().
|
|
|
character_set(ServerVersion) when ServerVersion >= [5, 5, 3] ->
|
|
@@ -564,11 +548,29 @@ parse_handshake_confirm(<<?MORE_DATA, MoreData/binary>>) ->
|
|
|
%% @doc Fetches one or more results and and parses the result set(s) using
|
|
|
%% either the text format (for plain queries) or the binary format (for
|
|
|
%% prepared statements).
|
|
|
--spec fetch_response(module(), term(), timeout(), text | binary,
|
|
|
+-spec fetch_response(module(), term(), timeout(), text | binary, [binary()],
|
|
|
query_filtermap(), list()) ->
|
|
|
{ok, [#ok{} | #resultset{} | #error{}]} | {error, timeout}.
|
|
|
-fetch_response(SockModule, Socket, Timeout, Proto, FilterMap, Acc) ->
|
|
|
+fetch_response(SockModule, Socket, Timeout, Proto, AllowedPaths, FilterMap, Acc) ->
|
|
|
case recv_packet(SockModule, Socket, Timeout, any) of
|
|
|
+ {ok, ?local_infile_pattern = Packet, SeqNum2} ->
|
|
|
+ Filename = parse_local_infile_packet(Packet),
|
|
|
+ Acc1 = case send_file(SockModule, Socket, Filename, AllowedPaths, SeqNum2) of
|
|
|
+ {ok, _SeqNum3} ->
|
|
|
+ Acc;
|
|
|
+ {{error, not_allowed}, _SeqNum3} ->
|
|
|
+ ErrorMsg = <<"The server requested a file not permitted by the client: ",
|
|
|
+ Filename/binary>>,
|
|
|
+ [#error{code = -1, msg = ErrorMsg}|Acc];
|
|
|
+ {{error, FileError}, _SeqNum3} ->
|
|
|
+ FileErrorMsg = list_to_binary(file:format_error(FileError)),
|
|
|
+ ErrorMsg = <<"The server requested a file which could not be opened "
|
|
|
+ "by the client: ", Filename/binary,
|
|
|
+ " (", FileErrorMsg/binary, ")">>,
|
|
|
+ [#error{code = -2, msg = ErrorMsg}|Acc]
|
|
|
+ end,
|
|
|
+ fetch_response(SockModule, Socket, Timeout, Proto, AllowedPaths,
|
|
|
+ FilterMap, Acc1);
|
|
|
{ok, Packet, SeqNum2} ->
|
|
|
Result = case Packet of
|
|
|
?ok_pattern ->
|
|
@@ -585,7 +587,7 @@ fetch_response(SockModule, Socket, Timeout, Proto, FilterMap, Acc) ->
|
|
|
case more_results_exists(Result) of
|
|
|
true ->
|
|
|
fetch_response(SockModule, Socket, Timeout, Proto,
|
|
|
- FilterMap, Acc1);
|
|
|
+ AllowedPaths, FilterMap, Acc1);
|
|
|
false ->
|
|
|
{ok, lists:reverse(Acc1)}
|
|
|
end;
|
|
@@ -641,7 +643,7 @@ fetch_resultset_rows(SockModule, Socket, FieldCount, ColDefs, Proto,
|
|
|
end.
|
|
|
|
|
|
-spec filtermap_resultset_row(query_filtermap(), [#col{}], [term()]) ->
|
|
|
- query_filtermap_res().
|
|
|
+ boolean() | {true, term()}.
|
|
|
filtermap_resultset_row(no_filtermap_fun, _, _) ->
|
|
|
true;
|
|
|
filtermap_resultset_row(Fun, _, Row) when is_function(Fun, 1) ->
|
|
@@ -1220,6 +1222,94 @@ recv_packet(SockModule, Socket, Timeout, ExpectSeqNum, Acc) ->
|
|
|
{error, Reason}
|
|
|
end.
|
|
|
|
|
|
+-spec send_file(module(), term(), Filename :: binary(), AllowedPaths :: [binary()],
|
|
|
+ SeqNum :: integer()) ->
|
|
|
+ {ok | {error, Reason}, NextSeqNum :: integer()}
|
|
|
+ when Reason :: not_allowed
|
|
|
+ | file:posix()
|
|
|
+ | badarg
|
|
|
+ | system_limit.
|
|
|
+send_file(SockModule, Socket, Filename, AllowedPaths, SeqNum0) ->
|
|
|
+ {Result, SeqNum1} = case allowed_path(Filename, AllowedPaths) andalso
|
|
|
+ file:open(Filename, [read, raw, binary]) of
|
|
|
+ false ->
|
|
|
+ {{error, not_allowed}, SeqNum0};
|
|
|
+ {ok, Handle} ->
|
|
|
+ {ok, SeqNum2} = send_file_chunk(SockModule, Socket, Handle, SeqNum0),
|
|
|
+ ok = file:close(Handle),
|
|
|
+ {ok, SeqNum2};
|
|
|
+ {error, _Reason} = E ->
|
|
|
+ {E, SeqNum0}
|
|
|
+ end,
|
|
|
+ {ok, SeqNum3} = send_packet(SockModule, Socket, <<>>, SeqNum1),
|
|
|
+ {Result, SeqNum3}.
|
|
|
+
|
|
|
+-spec allowed_path(binary(), [binary()]) -> boolean().
|
|
|
+allowed_path(Path, AllowedPaths) ->
|
|
|
+ valid_path(Path) andalso
|
|
|
+ binary:last(Path) =/= $/ andalso
|
|
|
+ lists:any(
|
|
|
+ fun
|
|
|
+ (AllowedPath) when Path =:= AllowedPath ->
|
|
|
+ true;
|
|
|
+ (AllowedPath) ->
|
|
|
+ Size = byte_size(AllowedPath),
|
|
|
+ HasSlash = binary:last(AllowedPath) =:= $/,
|
|
|
+ case Path of
|
|
|
+ <<AllowedPath:Size/binary, _/binary>> when HasSlash -> true;
|
|
|
+ <<AllowedPath:Size/binary, $/, _/binary>> -> true;
|
|
|
+ _ -> false
|
|
|
+ end
|
|
|
+ end,
|
|
|
+ AllowedPaths
|
|
|
+ ).
|
|
|
+
|
|
|
+%% @doc Checks if the argument is a valid path.
|
|
|
+%%
|
|
|
+%% Returns `true' if the argument is an absolute path that does not contain
|
|
|
+%% any relative components like `..' or `.', otherwise `false'.
|
|
|
+-spec valid_path(term()) -> boolean().
|
|
|
+valid_path(Path) when is_binary(Path), byte_size(Path) > 0 ->
|
|
|
+ case filename:pathtype(Path) of
|
|
|
+ absolute ->
|
|
|
+ valid_abspath(Path);
|
|
|
+ volumerelative ->
|
|
|
+ case Path of
|
|
|
+ <<$/, _/binary>> ->
|
|
|
+ false;
|
|
|
+ _ ->
|
|
|
+ valid_abspath(Path)
|
|
|
+ end;
|
|
|
+ relative ->
|
|
|
+ false
|
|
|
+ end;
|
|
|
+valid_path(_Path) ->
|
|
|
+ false.
|
|
|
+
|
|
|
+-spec valid_abspath(<<_:8, _:_*8>>) -> boolean().
|
|
|
+valid_abspath(Path) ->
|
|
|
+ lists:all(
|
|
|
+ fun
|
|
|
+ (<<".">>) -> false;
|
|
|
+ (<<"..">>) -> false;
|
|
|
+ (_) -> true
|
|
|
+ end,
|
|
|
+ filename:split(Path)
|
|
|
+ ).
|
|
|
+
|
|
|
+-spec send_file_chunk(module(), term(), Handle :: file:io_device(), SeqNum :: integer()) ->
|
|
|
+ {ok, NextSeqNum :: integer()}.
|
|
|
+send_file_chunk(SockModule, Socket, Handle, SeqNum0) ->
|
|
|
+ case file:read(Handle, 16#ffffff) of
|
|
|
+ eof ->
|
|
|
+ {ok, SeqNum0};
|
|
|
+ {ok, <<>>} ->
|
|
|
+ send_file_chunk(SockModule, Socket, Handle, SeqNum0);
|
|
|
+ {ok, Data} ->
|
|
|
+ {ok, SeqNum1} = send_packet(SockModule, Socket, Data, SeqNum0),
|
|
|
+ send_file_chunk(SockModule, Socket, Handle, SeqNum1)
|
|
|
+ end.
|
|
|
+
|
|
|
%% @doc Parses a packet header (32 bits) and returns a tuple.
|
|
|
%%
|
|
|
%% The client should first read a header and parse it. Then read PacketLength
|
|
@@ -1293,6 +1383,9 @@ parse_eof_packet(<<?EOF:8, NumWarnings:16/little, StatusFlags:16/little>>) ->
|
|
|
%% (Older protocol: <<?EOF:8>>)
|
|
|
#eof{status = StatusFlags, warning_count = NumWarnings}.
|
|
|
|
|
|
+parse_local_infile_packet(<<?LOCAL_INFILE_REQUEST:8, FileName/binary>>) ->
|
|
|
+ FileName.
|
|
|
+
|
|
|
-spec parse_auth_method_switch(binary()) -> #auth_method_switch{}.
|
|
|
parse_auth_method_switch(AMSData) ->
|
|
|
{AuthPluginName, AuthPluginData} = get_null_terminated_binary(AMSData),
|
|
@@ -1688,5 +1781,69 @@ valid_params_test() ->
|
|
|
?assertNot(valid_params(InvalidParams)),
|
|
|
?assertNot(valid_params(ValidParams ++ InvalidParams)).
|
|
|
|
|
|
+valid_path_test() ->
|
|
|
+ ValidPaths = [
|
|
|
+ <<"/">>,
|
|
|
+ <<"/tmp">>,
|
|
|
+ <<"/tmp/">>,
|
|
|
+ <<"/tmp/foo">>
|
|
|
+ ],
|
|
|
+ InvalidPaths = [
|
|
|
+ <<>>,
|
|
|
+ <<"tmp">>,
|
|
|
+ <<"tmp/">>,
|
|
|
+ <<"tmp/foo">>,
|
|
|
+ <<"../tmp">>,
|
|
|
+ <<"/tmp/..">>,
|
|
|
+ <<"/tmp/foo/../bar">>,
|
|
|
+ "/tmp"
|
|
|
+ ],
|
|
|
+ lists:foreach(
|
|
|
+ fun (ValidPath) ->
|
|
|
+ ?assert(valid_path(ValidPath))
|
|
|
+ end,
|
|
|
+ ValidPaths
|
|
|
+ ),
|
|
|
+ lists:foreach(
|
|
|
+ fun (InvalidPath) ->
|
|
|
+ ?assertNot(valid_path(InvalidPath))
|
|
|
+ end,
|
|
|
+ InvalidPaths
|
|
|
+ ).
|
|
|
+
|
|
|
+allowed_path_test() ->
|
|
|
+ AllowedPaths = [
|
|
|
+ <<"/tmp/foo/file.csv">>,
|
|
|
+ <<"/tmp/foo/bar/">>,
|
|
|
+ <<"/tmp/foo/baz">>
|
|
|
+ ],
|
|
|
+ ValidPaths = [
|
|
|
+ <<"/tmp/foo/file.csv">>,
|
|
|
+ <<"/tmp/foo/bar/file.csv">>,
|
|
|
+ <<"/tmp/foo/baz/file.csv">>,
|
|
|
+ <<"/tmp/foo/baz">>
|
|
|
+ ],
|
|
|
+ InvalidPaths = [
|
|
|
+ <<"/tmp/file.csv">>,
|
|
|
+ <<"/tmp/foo/other_file.csv">>,
|
|
|
+ <<"/tmp/foo/other_dir/file.csv">>,
|
|
|
+ <<"/tmp/foo/../file.csv">>,
|
|
|
+ <<"/tmp/foo/../bar/file.csv">>,
|
|
|
+ <<"/tmp/foo/bar/">>,
|
|
|
+ <<"/tmp/foo/barbaz">>
|
|
|
+ ],
|
|
|
+ lists:foreach(
|
|
|
+ fun (ValidPath) ->
|
|
|
+ ?assert(allowed_path(ValidPath, AllowedPaths))
|
|
|
+ end,
|
|
|
+ ValidPaths
|
|
|
+ ),
|
|
|
+ lists:foreach(
|
|
|
+ fun (InvalidPath) ->
|
|
|
+ ?assertNot(allowed_path(InvalidPath, AllowedPaths))
|
|
|
+ end,
|
|
|
+ InvalidPaths
|
|
|
+ ).
|
|
|
+
|
|
|
-endif.
|
|
|
|