Просмотр исходного кода

Don't discard data following a Websocket upgrade request

While the protocol does not allow sending data before
receiving a successful Websocket upgrade response, we
do not want to discard that data if it does come in.
Loïc Hoguin 5 лет назад
Родитель
Коммит
c50d6aa09c
4 изменённых файла: 81 добавление и 60 удалений
  1. 49 51
      src/cowboy_http.erl
  2. 6 2
      src/cowboy_websocket.erl
  3. 2 4
      test/sys_SUITE.erl
  4. 24 3
      test/ws_SUITE.erl

+ 49 - 51
src/cowboy_http.erl

@@ -111,6 +111,7 @@
 	transport :: module(),
 	proxy_header :: undefined | ranch_proxy_header:proxy_info(),
 	opts = #{} :: cowboy:opts(),
+	buffer = <<>> :: binary(),
 
 	%% Some options may be overriden for the current stream.
 	overriden_opts = #{} :: cowboy:opts(),
@@ -175,7 +176,7 @@ init(Parent, Ref, Socket, Transport, ProxyHeader, Opts) ->
 				parent=Parent, ref=Ref, socket=Socket,
 				transport=Transport, proxy_header=ProxyHeader, opts=Opts,
 				peer=Peer, sock=Sock, cert=Cert,
-				last_streamid=LastStreamID}), <<>>);
+				last_streamid=LastStreamID}));
 		{{error, Reason}, _, _} ->
 			terminate(undefined, {socket_error, Reason,
 				'A socket error occurred when retrieving the peer name.'});
@@ -187,22 +188,22 @@ init(Parent, Ref, Socket, Transport, ProxyHeader, Opts) ->
 				'A socket error occurred when retrieving the client TLS certificate.'})
 	end.
 
-before_loop(State=#state{socket=Socket, transport=Transport}, Buffer) ->
+before_loop(State=#state{socket=Socket, transport=Transport}) ->
 	%% @todo disable this when we get to the body, until the stream asks for it?
 	%% Perhaps have a threshold for how much we're willing to read before waiting.
 	Transport:setopts(Socket, [{active, once}]),
-	loop(State, Buffer).
+	loop(State).
 
 loop(State=#state{parent=Parent, socket=Socket, transport=Transport, opts=Opts,
-		timer=TimerRef, children=Children, in_streamid=InStreamID,
-		last_streamid=LastStreamID, streams=Streams}, Buffer) ->
+		buffer=Buffer, timer=TimerRef, children=Children, in_streamid=InStreamID,
+		last_streamid=LastStreamID, streams=Streams}) ->
 	Messages = Transport:messages(),
 	InactivityTimeout = maps:get(inactivity_timeout, Opts, 300000),
 	receive
 		%% Discard data coming in after the last request
 		%% we want to process was received fully.
 		{OK, Socket, _} when OK =:= element(1, Messages), InStreamID > LastStreamID ->
-			before_loop(State, Buffer);
+			before_loop(State);
 		%% Socket messages.
 		{OK, Socket, Data} when OK =:= element(1, Messages) ->
 			%% Only reset the timeout if it is idle_timeout (active streams).
@@ -218,30 +219,30 @@ loop(State=#state{parent=Parent, socket=Socket, transport=Transport, opts=Opts,
 		%% Timeouts.
 		{timeout, Ref, {shutdown, Pid}} ->
 			cowboy_children:shutdown_timeout(Children, Ref, Pid),
-			loop(State, Buffer);
+			loop(State);
 		{timeout, TimerRef, Reason} ->
 			timeout(State, Reason);
 		{timeout, _, _} ->
-			loop(State, Buffer);
+			loop(State);
 		%% System messages.
 		{'EXIT', Parent, Reason} ->
 			terminate(State, {stop, {exit, Reason}, 'Parent process terminated.'});
 		{system, From, Request} ->
-			sys:handle_system_msg(Request, From, Parent, ?MODULE, [], {State, Buffer});
+			sys:handle_system_msg(Request, From, Parent, ?MODULE, [], State);
 		%% Messages pertaining to a stream.
 		{{Pid, StreamID}, Msg} when Pid =:= self() ->
-			loop(info(State, StreamID, Msg), Buffer);
+			loop(info(State, StreamID, Msg));
 		%% Exit signal from children.
 		Msg = {'EXIT', Pid, _} ->
-			loop(down(State, Pid, Msg), Buffer);
+			loop(down(State, Pid, Msg));
 		%% Calls from supervisor module.
 		{'$gen_call', From, Call} ->
 			cowboy_children:handle_supervisor_call(Call, From, Children, ?MODULE),
-			loop(State, Buffer);
+			loop(State);
 		%% Unknown messages.
 		Msg ->
 			cowboy:log(warning, "Received stray message ~p.~n", [Msg], Opts),
-			loop(State, Buffer)
+			loop(State)
 	after InactivityTimeout ->
 		terminate(State, {internal_error, timeout, 'No message or data received before timeout.'})
 	end.
@@ -293,12 +294,12 @@ timeout(State, idle_timeout) ->
 		'Connection idle longer than configuration allows.'}).
 
 parse(<<>>, State) ->
-	before_loop(State, <<>>);
+	before_loop(State#state{buffer= <<>>});
 %% Do not process requests that come in after the last request
 %% and discard the buffer if any to save memory.
 parse(_, State=#state{in_streamid=InStreamID, in_state=#ps_request_line{},
 		last_streamid=LastStreamID}) when InStreamID > LastStreamID ->
-	before_loop(State, <<>>);
+	before_loop(State#state{buffer= <<>>});
 parse(Buffer, State=#state{in_state=#ps_request_line{empty_lines=EmptyLines}}) ->
 	after_parse(parse_request(Buffer, State, EmptyLines));
 parse(Buffer, State=#state{in_state=PS=#ps_header{headers=Headers, name=undefined}}) ->
@@ -317,7 +318,7 @@ parse(Buffer, State=#state{in_state=#ps_body{}}) ->
 
 after_parse({request, Req=#{streamid := StreamID, method := Method,
 		headers := Headers, version := Version},
-		State0=#state{opts=Opts, streams=Streams0}, Buffer}) ->
+		State0=#state{opts=Opts, buffer=Buffer, streams=Streams0}}) ->
 	try cowboy_stream:init(StreamID, Req, Opts) of
 		{Commands, StreamState} ->
 			TE = maps:get(<<"te">>, Headers, undefined),
@@ -339,8 +340,8 @@ after_parse({request, Req=#{streamid := StreamID, method := Method,
 	end;
 %% Streams are sequential so the body is always about the last stream created
 %% unless that stream has terminated.
-after_parse({data, StreamID, IsFin, Data, State=#state{opts=Opts,
-		streams=Streams0=[Stream=#stream{id=StreamID, state=StreamState0}|_]}, Buffer}) ->
+after_parse({data, StreamID, IsFin, Data, State=#state{opts=Opts, buffer=Buffer,
+		streams=Streams0=[Stream=#stream{id=StreamID, state=StreamState0}|_]}}) ->
 	try cowboy_stream:data(StreamID, IsFin, Data, StreamState0) of
 		{Commands, StreamState} ->
 			Streams = lists:keyreplace(StreamID, #stream.id, Streams0,
@@ -355,17 +356,17 @@ after_parse({data, StreamID, IsFin, Data, State=#state{opts=Opts,
 	end;
 %% No corresponding stream. We must skip the body of the previous request
 %% in order to process the next one.
-after_parse({data, _, _, _, State, Buffer}) ->
-	before_loop(State, Buffer);
-after_parse({more, State, Buffer}) ->
-	before_loop(State, Buffer).
+after_parse({data, _, _, _, State}) ->
+	before_loop(State);
+after_parse({more, State}) ->
+	before_loop(State).
 
 %% Request-line.
 
 -spec parse_request(Buffer, State, non_neg_integer())
-	-> {request, cowboy_req:req(), State, Buffer}
-	| {data, cowboy_stream:streamid(), cowboy_stream:fin(), binary(), State, Buffer}
-	| {more, State, Buffer}
+	-> {request, cowboy_req:req(), State}
+	| {data, cowboy_stream:streamid(), cowboy_stream:fin(), binary(), State}
+	| {more, State}
 	when Buffer::binary(), State::#state{}.
 %% Empty lines must be using \r\n.
 parse_request(<< $\n, _/bits >>, State, _) ->
@@ -384,7 +385,7 @@ parse_request(Buffer, State=#state{opts=Opts, in_streamid=InStreamID}, EmptyLine
 			error_terminate(414, State, {connection_error, limit_reached,
 				'The request-line length is larger than configuration allows. (RFC7230 3.1.1)'});
 		nomatch ->
-			{more, State#state{in_state=#ps_request_line{empty_lines=EmptyLines}}, Buffer};
+			{more, State#state{buffer=Buffer, in_state=#ps_request_line{empty_lines=EmptyLines}}};
 		1 when EmptyLines =:= MaxEmptyLines ->
 			error_terminate(400, State, {connection_error, limit_reached,
 				'More empty lines were received than configuration allows. (RFC7230 3.5)'});
@@ -527,7 +528,7 @@ before_parse_headers(Rest, State, M, A, P, Q, V) ->
 
 %% We need two or more bytes in the buffer to continue.
 parse_header(Rest, State=#state{in_state=PS}, Headers) when byte_size(Rest) < 2 ->
-	{more, State#state{in_state=PS#ps_header{headers=Headers}}, Rest};
+	{more, State#state{buffer=Rest, in_state=PS#ps_header{headers=Headers}}};
 parse_header(<< $\r, $\n, Rest/bits >>, S, Headers) ->
 	request(Rest, S, Headers);
 parse_header(Buffer, State=#state{opts=Opts, in_state=PS}, Headers) ->
@@ -554,7 +555,7 @@ parse_header_colon(Buffer, State=#state{opts=Opts, in_state=PS}, Headers) ->
 			%% so check if we have an LF and abort with an error if we do.
 			case match_eol(Buffer, 0) of
 				nomatch ->
-					{more, State#state{in_state=PS#ps_header{headers=Headers}}, Buffer};
+					{more, State#state{buffer=Buffer, in_state=PS#ps_header{headers=Headers}}};
 				_ ->
 					error_terminate(400, State#state{in_state=PS#ps_header{headers=Headers}},
 						{connection_error, protocol_error,
@@ -596,7 +597,7 @@ parse_hd_before_value(Buffer, State=#state{opts=Opts, in_state=PS}, H, N) ->
 				{connection_error, limit_reached,
 					'A header value is larger than configuration allows. (RFC7230 3.2.5, RFC6585 5)'});
 		nomatch ->
-			{more, State#state{in_state=PS#ps_header{headers=H, name=N}}, Buffer};
+			{more, State#state{buffer=Buffer, in_state=PS#ps_header{headers=H, name=N}}};
 		_ ->
 			parse_hd_value(Buffer, State, H, N, <<>>)
 	end.
@@ -766,7 +767,7 @@ request(Buffer, State0=#state{ref=Ref, transport=Transport, peer=Peer, sock=Sock
 				false ->
 					State0#state{in_streamid=StreamID + 1, in_state=#ps_request_line{}}
 			end,
-			{request, Req, State, Buffer};
+			{request, Req, State#state{buffer=Buffer}};
 		{true, HTTP2Settings} ->
 			%% We save the headers in case the upgrade will fail
 			%% and we need to pass them to cowboy_stream:early_error.
@@ -835,28 +836,28 @@ parse_body(Buffer, State=#state{in_streamid=StreamID, in_state=
 	try TDecode(Buffer, TState0) of
 		more ->
 			%% @todo Asks for 0 or more bytes.
-			{more, State, Buffer};
+			{more, State#state{buffer=Buffer}};
 		{more, Data, TState} ->
 			%% @todo Asks for 0 or more bytes.
-			{data, StreamID, nofin, Data, State#state{in_state=
-				PS#ps_body{received=Received + byte_size(Data),
-					transfer_decode_state=TState}}, <<>>};
+			{data, StreamID, nofin, Data, State#state{buffer= <<>>,
+				in_state=PS#ps_body{received=Received + byte_size(Data),
+					transfer_decode_state=TState}}};
 		{more, Data, _Length, TState} when is_integer(_Length) ->
 			%% @todo Asks for Length more bytes.
-			{data, StreamID, nofin, Data, State#state{in_state=
-				PS#ps_body{received=Received + byte_size(Data),
-					transfer_decode_state=TState}}, <<>>};
+			{data, StreamID, nofin, Data, State#state{buffer= <<>>,
+				in_state=PS#ps_body{received=Received + byte_size(Data),
+					transfer_decode_state=TState}}};
 		{more, Data, Rest, TState} ->
 			%% @todo Asks for 0 or more bytes.
-			{data, StreamID, nofin, Data, State#state{in_state=
-				PS#ps_body{received=Received + byte_size(Data),
-					transfer_decode_state=TState}}, Rest};
+			{data, StreamID, nofin, Data, State#state{buffer=Rest,
+				in_state=PS#ps_body{received=Received + byte_size(Data),
+					transfer_decode_state=TState}}};
 		{done, _HasTrailers, Rest} ->
 			{data, StreamID, fin, <<>>, set_timeout(
-				State#state{in_streamid=StreamID + 1, in_state=#ps_request_line{}}), Rest};
+				State#state{buffer=Rest, in_streamid=StreamID + 1, in_state=#ps_request_line{}})};
 		{done, Data, _HasTrailers, Rest} ->
 			{data, StreamID, fin, Data, set_timeout(
-				State#state{in_streamid=StreamID + 1, in_state=#ps_request_line{}}), Rest}
+				State#state{buffer=Rest, in_streamid=StreamID + 1, in_state=#ps_request_line{}})}
 	catch _:_ ->
 		Reason = {connection_error, protocol_error,
 			'Failure to decode the content. (RFC7230 4)'},
@@ -1094,7 +1095,7 @@ commands(State=#state{socket=Socket, transport=Transport, streams=Streams, out_s
 	commands(State#state{out_state=done}, StreamID, Tail);
 %% Protocol takeover.
 commands(State0=#state{ref=Ref, parent=Parent, socket=Socket, transport=Transport,
-		out_state=OutState, opts=Opts, children=Children}, StreamID,
+		out_state=OutState, opts=Opts, buffer=Buffer, children=Children}, StreamID,
 		[{switch_protocol, Headers, Protocol, InitialState}|_Tail]) ->
 	%% @todo This should be the last stream running otherwise we need to wait before switching.
 	%% @todo If there's streams opened after this one, fail instead of 101.
@@ -1117,10 +1118,7 @@ commands(State0=#state{ref=Ref, parent=Parent, socket=Socket, transport=Transpor
 	%% Terminate children processes and flush any remaining messages from the mailbox.
 	cowboy_children:terminate(Children),
 	flush(Parent),
-	%% @todo This is no good because commands return a state normally and here it doesn't
-	%% we need to let this module go entirely. Perhaps it should be handled directly in
-	%% cowboy_clear/cowboy_tls?
-	Protocol:takeover(Parent, Ref, Socket, Transport, Opts, <<>>, InitialState);
+	Protocol:takeover(Parent, Ref, Socket, Transport, Opts, Buffer, InitialState);
 %% Set options dynamically.
 commands(State0=#state{overriden_opts=Opts},
 		StreamID, [{set_options, SetOpts}|Tail]) ->
@@ -1446,12 +1444,12 @@ terminate_linger_loop(State=#state{socket=Socket, transport=Transport}, TimerRef
 
 %% System callbacks.
 
--spec system_continue(_, _, {#state{}, binary()}) -> ok.
-system_continue(_, _, {State, Buffer}) ->
-	loop(State, Buffer).
+-spec system_continue(_, _, #state{}) -> ok.
+system_continue(_, _, State) ->
+	loop(State).
 
 -spec system_terminate(any(), _, _, {#state{}, binary()}) -> no_return().
-system_terminate(Reason, _, _, {State, _}) ->
+system_terminate(Reason, _, _, State) ->
 	terminate(State, {stop, {exit, Reason}, 'sys:terminate/2,3 was called.'}).
 
 -spec system_code_change(Misc, _, _, _) -> {ok, Misc} when Misc::{#state{}, binary()}.

+ 6 - 2
src/cowboy_websocket.erl

@@ -291,10 +291,14 @@ takeover(Parent, Ref, Socket, Transport, _Opts, Buffer,
 	State = loop_timeout(State0#state{parent=Parent,
 		ref=Ref, socket=Socket, transport=Transport,
 		key=undefined, messages=Messages}),
+	%% We call parse_header/3 immediately because there might be
+	%% some data in the buffer that was sent along with the handshake.
+	%% While it is not allowed by the protocol to send frames immediately,
+	%% we still want to process that data if any.
 	case erlang:function_exported(Handler, websocket_init, 1) of
 		true -> handler_call(State, HandlerState, #ps_header{buffer=Buffer},
-			websocket_init, undefined, fun before_loop/3);
-		false -> before_loop(State, HandlerState, #ps_header{buffer=Buffer})
+			websocket_init, undefined, fun parse_header/3);
+		false -> parse_header(State, HandlerState, #ps_header{buffer=Buffer})
 	end.
 
 before_loop(State=#state{active=false}, HandlerState, ParseState) ->

+ 2 - 4
test/sys_SUITE.erl

@@ -602,9 +602,8 @@ sys_get_state_h1(Config) ->
 	{ok, Socket} = gen_tcp:connect("localhost", config(clear_port, Config), []),
 	timer:sleep(100),
 	Pid = get_remote_pid_tcp(Socket),
-	{State, Buffer} = sys:get_state(Pid),
+	State = sys:get_state(Pid),
 	state = element(1, State),
-	true = is_binary(Buffer),
 	ok.
 
 sys_get_state_h2(Config) ->
@@ -726,9 +725,8 @@ sys_replace_state_h1(Config) ->
 	{ok, Socket} = gen_tcp:connect("localhost", config(clear_port, Config), []),
 	timer:sleep(100),
 	Pid = get_remote_pid_tcp(Socket),
-	{State, Buffer} = sys:replace_state(Pid, fun(S) -> S end),
+	State = sys:replace_state(Pid, fun(S) -> S end),
 	state = element(1, State),
-	true = is_binary(Buffer),
 	ok.
 
 sys_replace_state_h2(Config) ->

+ 24 - 3
test/ws_SUITE.erl

@@ -304,6 +304,18 @@ do_ws_deflate_opts_z(Path, Config) ->
 	{error, closed} = gen_tcp:recv(Socket, 0, 6000),
 	ok.
 
+ws_first_frame_with_handshake(Config) ->
+	doc("Client sends the first frame immediately with the handshake. "
+		"This is invalid according to the protocol but we still want "
+		"to accept it if the handshake is successful."),
+	Mask = 16#37fa213d,
+	MaskedHello = do_mask(<<"Hello">>, Mask, <<>>),
+	{ok, Socket, _} = do_handshake("/ws_echo", "",
+		<<1:1, 0:3, 1:4, 1:1, 5:7, Mask:32, MaskedHello/binary>>,
+		Config),
+	{ok, <<1:1, 0:3, 1:4, 0:1, 5:7, "Hello">>} = gen_tcp:recv(Socket, 0, 6000),
+	ok.
+
 ws_init_return_ok(Config) ->
 	doc("Handler does nothing."),
 	{ok, Socket, _} = do_handshake("/ws_init?ok", Config),
@@ -636,9 +648,12 @@ ws_webkit_deflate_single_bytes(Config) ->
 %% Internal.
 
 do_handshake(Path, Config) ->
-	do_handshake(Path, "", Config).
+	do_handshake(Path, "", "", Config).
 
 do_handshake(Path, ExtraHeaders, Config) ->
+	do_handshake(Path, ExtraHeaders, "", Config).
+
+do_handshake(Path, ExtraHeaders, ExtraData, Config) ->
 	{ok, Socket} = gen_tcp:connect("localhost", config(port, Config),
 		[binary, {active, false}]),
 	ok = gen_tcp:send(Socket, [
@@ -650,10 +665,16 @@ do_handshake(Path, ExtraHeaders, Config) ->
 		"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
 		"Upgrade: websocket\r\n",
 		ExtraHeaders,
-		"\r\n"]),
+		"\r\n",
+		ExtraData]),
 	{ok, Handshake} = gen_tcp:recv(Socket, 0, 6000),
 	{ok, {http_response, {1, 1}, 101, _}, Rest} = erlang:decode_packet(http, Handshake, []),
-	[Headers, <<>>] = do_decode_headers(erlang:decode_packet(httph, Rest, []), []),
+	[Headers, Data] = do_decode_headers(erlang:decode_packet(httph, Rest, []), []),
+	%% Queue extra data back, if any. We don't want to receive it yet.
+	case Data of
+		<<>> -> ok;
+		_ -> gen_tcp:unrecv(Socket, Data)
+	end,
 	{_, "Upgrade"} = lists:keyfind('Connection', 1, Headers),
 	{_, "websocket"} = lists:keyfind('Upgrade', 1, Headers),
 	{_, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="} = lists:keyfind("sec-websocket-accept", 1, Headers),