Browse Source

Fixes various client issues in cow_http2_machine

Loïc Hoguin 6 years ago
parent
commit
616b3b4015
1 changed files with 85 additions and 37 deletions
  1. 85 37
      src/cow_http2_machine.erl

+ 85 - 37
src/cow_http2_machine.erl

@@ -15,6 +15,7 @@
 -module(cow_http2_machine).
 
 -export([init/2]).
+-export([init_stream/2]).
 -export([init_upgrade_stream/2]).
 -export([frame/2]).
 -export([ignored_frame/1]).
@@ -29,6 +30,7 @@
 -export([get_local_setting/2]).
 -export([get_last_streamid/1]).
 -export([get_stream_local_state/2]).
+-export([get_stream_remote_state/2]).
 
 -type opts() :: #{
 	enable_connect_protocol => boolean(),
@@ -80,6 +82,7 @@
 	remote_read_size = 0 :: non_neg_integer(),
 
 	%% Unparsed te header. Used to know if we can send trailers.
+	%% Note that we can always send trailers to the server.
 	te :: undefined | binary()
 }).
 
@@ -133,15 +136,15 @@
 	%% by the client or by the server through PUSH_PROMISE frames.
 	streams = [] :: [stream()],
 
-	%% HTTP/2 streams that have been reset recently by the server.
+	%% HTTP/2 streams that have recently been reset locally.
 	%% We are expected to keep receiving additional frames after
 	%% sending an RST_STREAM.
-	lingering_streams = [] :: [cow_http2:streamid()],
+	local_lingering_streams = [] :: [cow_http2:streamid()],
 
-	%% HTTP/2 streams that have been reset recently by the client.
+	%% HTTP/2 streams that have recently been reset remotely.
 	%% We keep a few of these around in order to reject subsequent
 	%% frames on these streams.
-	rst_lingering_streams = [] :: [cow_http2:streamid()],
+	remote_lingering_streams = [] :: [cow_http2:streamid()],
 
 	%% HPACK decoding and encoding state.
 	decode_state = cow_hpack:init() :: cow_hpack:state(),
@@ -252,6 +255,16 @@ setting_from_opt(Settings, Opts, OptName, SettingName, Default) ->
 		Value -> Settings#{SettingName => Value}
 	end.
 
+-spec init_stream(binary(), State)
+	-> {ok, cow_http2:streamid(), State} when State::http2_machine().
+init_stream(Method, State=#http2_machine{mode=client, local_streamid=LocalStreamID,
+		local_settings=#{initial_window_size := RemoteWindow},
+		remote_settings=#{initial_window_size := LocalWindow}}) ->
+	Stream = #stream{id=LocalStreamID, method=Method,
+		local_window=LocalWindow, remote_window=RemoteWindow},
+	{ok, LocalStreamID, stream_store(Stream, State#http2_machine{
+		local_streamid=LocalStreamID + 2})}.
+
 -spec init_upgrade_stream(binary(), State)
 	-> {ok, cow_http2:streamid(), State} when State::http2_machine().
 init_upgrade_stream(Method, State=#http2_machine{mode=server, remote_streamid=0,
@@ -318,7 +331,7 @@ data_frame({data, _, _, Data}, State=#http2_machine{remote_window=ConnWindow})
 		'DATA frame overflowed the connection flow control window. (RFC7540 6.9, RFC7540 6.9.1)'},
 		State};
 data_frame(Frame={data, StreamID, _, Data}, State0=#http2_machine{
-		remote_window=ConnWindow, lingering_streams=Lingering}) ->
+		remote_window=ConnWindow, local_lingering_streams=Lingering}) ->
 	DataLen = byte_size(Data),
 	State = State0#http2_machine{remote_window=ConnWindow - DataLen},
 	case stream_get(StreamID, State) of
@@ -335,7 +348,7 @@ data_frame(Frame={data, StreamID, _, Data}, State0=#http2_machine{
 				'DATA frame received for a half-closed (remote) stream. (RFC7540 5.1)');
 		undefined ->
 			%% After we send an RST_STREAM frame and terminate a stream,
-			%% the client still might be sending us some more frames
+			%% the remote endpoint still might be sending us some more frames
 			%% until it can process this RST_STREAM. We therefore ignore
 			%% DATA frames received for such lingering streams.
 			case lists:member(StreamID, Lingering) of
@@ -514,7 +527,7 @@ headers_enforce_concurrency_limit(Frame=#headers{id=StreamID},
 	end.
 
 headers_pseudo_headers(Frame, State=#http2_machine{local_settings=LocalSettings},
-		Type=request, Stream, Headers0) ->
+		Type, Stream, Headers0) when Type =:= request; Type =:= push_promise ->
 	IsExtendedConnectEnabled = maps:get(enable_connect_protocol, LocalSettings, false),
 	case request_pseudo_headers(Headers0, #{}) of
 		%% Extended CONNECT method (RFC8441).
@@ -622,6 +635,8 @@ headers_regular_headers(Frame=#headers{id=StreamID},
 	case regular_headers(Headers, Type) of
 		ok when Type =:= request ->
 			request_expected_size(Frame, State, Type, Stream, PseudoHeaders, Headers);
+		ok when Type =:= push_promise ->
+			push_promise_frame(Frame, State, Stream, PseudoHeaders, Headers);
 		ok when Type =:= response ->
 			response_expected_size(Frame, State, Type, Stream, PseudoHeaders, Headers);
 		ok when Type =:= trailers ->
@@ -735,24 +750,24 @@ headers_frame(#headers{id=StreamID, fin=IsFin}, State0=#http2_machine{
 		local_settings=#{initial_window_size := RemoteWindow},
 		remote_settings=#{initial_window_size := LocalWindow}},
 		Type, Stream0, PseudoHeaders, Headers, Len) ->
-	Stream = case Stream0 of
-		undefined ->
+	{Stream, State1} = case Type of
+		request ->
 			TE = case lists:keyfind(<<"te">>, 1, Headers) of
 				{_, TE0} -> TE0;
 				false -> undefined
 			end,
-			#stream{id=StreamID, method=maps:get(method, PseudoHeaders),
+			{#stream{id=StreamID, method=maps:get(method, PseudoHeaders),
 				remote=IsFin, remote_expected_size=Len,
-				local_window=LocalWindow, remote_window=RemoteWindow, te=TE};
-		_ ->
-			case {Type, PseudoHeaders} of
-				{response, #{status := Status}} when Status >= 100, Status =< 199 ->
-					Stream0;
-				_ ->
-					Stream0#stream{remote=IsFin, remote_expected_size=Len}
-			end
+				local_window=LocalWindow, remote_window=RemoteWindow, te=TE},
+				State0#http2_machine{remote_streamid=StreamID}};
+		response ->
+			Stream1 = case PseudoHeaders of
+				#{status := Status} when Status >= 100, Status =< 199 -> Stream0;
+				_ -> Stream0#stream{remote=IsFin, remote_expected_size=Len}
+			end,
+			{Stream1, State0}
 	end,
-	State = stream_store(Stream, State0#http2_machine{remote_streamid=StreamID}),
+	State = stream_store(Stream, State1),
 	{ok, {headers, StreamID, IsFin, Headers, PseudoHeaders, Len}, State}.
 
 trailers_frame(#headers{id=StreamID}, State0, Stream0, Headers) ->
@@ -783,12 +798,12 @@ rst_stream_frame({rst_stream, StreamID, _}, State=#http2_machine{mode=Mode,
 		'RST_STREAM frame received on a stream in idle state. (RFC7540 5.1)'},
 		State};
 rst_stream_frame({rst_stream, StreamID, Reason}, State=#http2_machine{
-		streams=Streams0, rst_lingering_streams=Lingering0}) ->
+		streams=Streams0, remote_lingering_streams=Lingering0}) ->
 	Streams = lists:keydelete(StreamID, #stream.id, Streams0),
 	%% We only keep up to 10 streams in this state. @todo Make it configurable?
 	Lingering = [StreamID|lists:sublist(Lingering0, 10 - 1)],
 	{ok, {rst_stream, StreamID, Reason},
-		State#http2_machine{streams=Streams, rst_lingering_streams=Lingering}}.
+		State#http2_machine{streams=Streams, remote_lingering_streams=Lingering}}.
 
 %% SETTINGS frame.
 
@@ -868,7 +883,7 @@ streams_update_remote_window(State=#http2_machine{streams=Streams0}, Increment)
 
 push_promise_frame(_, State=#http2_machine{mode=server}) ->
 	{error, {connection_error, protocol_error,
-		'PUSH_PROMISE frames MUST only be sent on a peer-initiated stream. (RFC7540 6.6)'},
+		'PUSH_PROMISE frames MUST NOT be sent by the client. (RFC7540 6.6)'},
 		State};
 push_promise_frame(_, State=#http2_machine{local_settings=#{enable_push := false}}) ->
 	{error, {connection_error, protocol_error,
@@ -888,14 +903,12 @@ push_promise_frame(#push_promise{id=StreamID}, State)
 push_promise_frame(Frame=#push_promise{id=StreamID, head=IsHeadFin,
 		promised_id=PromisedStreamID, data=HeaderData}, State) ->
 	case stream_get(StreamID, State) of
-		#stream{remote=idle} ->
+		Stream=#stream{remote=idle} ->
 			case IsHeadFin of
 				head_fin ->
-					%% @todo Gotta make sure the headers_* functions
-					%% will work properly for PUSH_PROMISE requests.
 					headers_decode(#headers{id=PromisedStreamID,
 						fin=fin, head=IsHeadFin, data=HeaderData},
-						State, push_promise, undefined);
+						State, push_promise, Stream);
 				head_nofin ->
 					{ok, State#http2_machine{state={continuation, push_promise, Frame}}}
 			end;
@@ -911,6 +924,22 @@ push_promise_frame(Frame=#push_promise{id=StreamID, head=IsHeadFin,
 				State}
 	end.
 
+push_promise_frame(#headers{id=PromisedStreamID},
+		State0=#http2_machine{
+			local_settings=#{initial_window_size := RemoteWindow},
+			remote_settings=#{initial_window_size := LocalWindow}},
+		#stream{id=StreamID}, PseudoHeaders=#{method := Method}, Headers) ->
+	TE = case lists:keyfind(<<"te">>, 1, Headers) of
+		{_, TE0} -> TE0;
+		false -> undefined
+	end,
+	PromisedStream = #stream{id=PromisedStreamID, method=Method,
+		local=fin, local_window=LocalWindow,
+		remote_window=RemoteWindow, te=TE},
+	State = stream_store(PromisedStream,
+		State0#http2_machine{remote_streamid=PromisedStreamID}),
+	{ok, {push_promise, StreamID, PromisedStreamID, Headers, PseudoHeaders}, State}.
+
 %% PING frame.
 
 ping_frame({ping, _}, State) ->
@@ -939,14 +968,15 @@ window_update_frame({window_update, Increment}, State=#http2_machine{local_windo
 window_update_frame({window_update, Increment}, State=#http2_machine{local_window=ConnWindow}) ->
 	send_data(State#http2_machine{local_window=ConnWindow + Increment});
 %% Stream-specific WINDOW_UPDATE frame.
-window_update_frame({window_update, StreamID, _},
-		State=#http2_machine{remote_streamid=RemoteStreamID})
-		when StreamID > RemoteStreamID ->
+window_update_frame({window_update, StreamID, _}, State=#http2_machine{mode=Mode,
+		local_streamid=LocalStreamID, remote_streamid=RemoteStreamID})
+		when (?IS_LOCAL(Mode, StreamID) andalso (StreamID >= LocalStreamID))
+		orelse ((not ?IS_LOCAL(Mode, StreamID)) andalso (StreamID > RemoteStreamID)) ->
 	{error, {connection_error, protocol_error,
 		'WINDOW_UPDATE frame received on a stream in idle state. (RFC7540 5.1)'},
 		State};
 window_update_frame({window_update, StreamID, Increment},
-		State0=#http2_machine{rst_lingering_streams=RstLingering}) ->
+		State0=#http2_machine{remote_lingering_streams=Lingering}) ->
 	case stream_get(StreamID, State0) of
 		#stream{local_window=StreamWindow} when StreamWindow + Increment > 16#7fffffff ->
 			stream_reset(StreamID, State0, flow_control_error,
@@ -956,7 +986,7 @@ window_update_frame({window_update, StreamID, Increment},
 		undefined ->
 			%% WINDOW_UPDATE frames may be received for a short period of time
 			%% after a stream is closed. They must be ignored.
-			case lists:member(StreamID, RstLingering) of
+			case lists:member(StreamID, Lingering) of
 				false -> {ok, State0};
 				true -> stream_reset(StreamID, State0, stream_closed,
 					'WINDOW_UPDATE frame received after the stream was reset. (RFC7540 5.1)')
@@ -1247,14 +1277,16 @@ queue_data(Stream=#stream{local_buffer=Q0, local_buffer_size=Size0}, IsFin, Data
 
 %% Public interface to update the flow control window.
 
--spec update_window(0..16#7fffffff, State)
+-spec update_window(1..16#7fffffff, State)
 	-> State when State::http2_machine().
-update_window(Size, State=#http2_machine{remote_window=RemoteWindow}) ->
+update_window(Size, State=#http2_machine{remote_window=RemoteWindow})
+		when Size > 0 ->
 	State#http2_machine{remote_window=RemoteWindow + Size}.
 
--spec update_window(cow_http2:streamid(), 0..16#7fffffff, State)
+-spec update_window(cow_http2:streamid(), 1..16#7fffffff, State)
 	-> State when State::http2_machine().
-update_window(StreamID, Size, State) ->
+update_window(StreamID, Size, State)
+		when Size > 0 ->
 	Stream = #stream{remote_window=RemoteWindow} = stream_get(StreamID, State),
 	stream_store(Stream#stream{remote_window=RemoteWindow + Size}, State).
 
@@ -1314,6 +1346,22 @@ get_stream_local_state(StreamID, State=#http2_machine{mode=Mode,
 			{error, not_found}
 	end.
 
+%% Retrieve the remote state for a stream.
+
+-spec get_stream_remote_state(cow_http2:streamid(), http2_machine())
+	-> {ok, idle | cow_http2:fin()} | {error, not_found | closed}.
+get_stream_remote_state(StreamID, State=#http2_machine{mode=Mode,
+		local_streamid=LocalStreamID, remote_streamid=RemoteStreamID}) ->
+	case stream_get(StreamID, State) of
+		#stream{remote=IsFin} ->
+			{ok, IsFin};
+		undefined when (?IS_LOCAL(Mode, StreamID) andalso (StreamID < LocalStreamID))
+				orelse ((not ?IS_LOCAL(Mode, StreamID)) andalso (StreamID =< RemoteStreamID)) ->
+			{error, closed};
+		undefined ->
+			{error, not_found}
+	end.
+
 %% Stream-related functions.
 
 stream_get(StreamID, #http2_machine{streams=Streams}) ->
@@ -1336,7 +1384,7 @@ stream_reset(StreamID, State, Reason, HumanReadable) ->
 	{error, {stream_error, StreamID, Reason, HumanReadable},
 		stream_linger(StreamID, State)}.
 
-stream_linger(StreamID, State=#http2_machine{lingering_streams=Lingering0}) ->
+stream_linger(StreamID, State=#http2_machine{local_lingering_streams=Lingering0}) ->
 	%% We only keep up to 100 streams in this state. @todo Make it configurable?
 	Lingering = [StreamID|lists:sublist(Lingering0, 100 - 1)],
-	State#http2_machine{lingering_streams=Lingering}.
+	State#http2_machine{local_lingering_streams=Lingering}.