Browse Source

Merge the two separate receive loops in cowboy_websocket

Also rename a bunch of functions to make the code easier to read.
Loïc Hoguin 7 years ago
parent
commit
21c9c66971
2 changed files with 109 additions and 103 deletions
  1. 109 100
      src/cowboy_websocket.erl
  2. 0 3
      test/sys_SUITE.erl

+ 109 - 100
src/cowboy_websocket.erl

@@ -20,7 +20,7 @@
 -export([upgrade/4]).
 -export([upgrade/5]).
 -export([takeover/7]).
--export([handler_loop/3]).
+-export([loop/3]).
 
 -export([system_continue/3]).
 -export([system_terminate/4]).
@@ -202,53 +202,64 @@ websocket_handshake(State=#state{key=Key},
 
 %% Connection process.
 
-%% @todo Keep parent and handle system messages.
+-record(ps_header, {
+	buffer = <<>> :: binary()
+}).
+
+-record(ps_payload, {
+	type :: cow_ws:frame_type(),
+	len :: non_neg_integer(),
+	mask_key :: cow_ws:mask_key(),
+	rsv :: cow_ws:rsv(),
+	close_code = undefined :: undefined | cow_ws:close_code(),
+	unmasked = <<>> :: binary(),
+	unmasked_len = 0 :: non_neg_integer(),
+	buffer = <<>> :: binary()
+}).
+
+-type parse_state() :: #ps_header{} | #ps_payload{}.
+
 -spec takeover(pid(), ranch:ref(), inet:socket(), module(), any(), binary(),
-	{#state{}, any()}) -> ok.
+	{#state{}, any()}) -> no_return().
 takeover(Parent, Ref, Socket, Transport, _Opts, Buffer,
 		{State0=#state{handler=Handler}, HandlerState}) ->
 	%% @todo We should have an option to disable this behavior.
 	ranch:remove_connection(Ref),
-	State1 = handler_loop_timeout(State0#state{parent=Parent,
-		ref=Ref, socket=Socket, transport=Transport}),
-	State = State1#state{key=undefined, messages=Transport:messages()},
+	State = loop_timeout(State0#state{parent=Parent,
+		ref=Ref, socket=Socket, transport=Transport,
+		key=undefined, messages=Transport:messages()}),
 	case erlang:function_exported(Handler, websocket_init, 1) of
-		true -> handler_call(State, HandlerState, Buffer, websocket_init, undefined, fun handler_before_loop/3);
-		false -> handler_before_loop(State, HandlerState, Buffer)
+		true -> handler_call(State, HandlerState, #ps_header{buffer=Buffer},
+			websocket_init, undefined, fun before_loop/3);
+		false -> before_loop(State, HandlerState, #ps_header{buffer=Buffer})
 	end.
 
--spec handler_before_loop(#state{}, any(), binary())
-%% @todo Yeah not env.
-	-> {ok, cowboy_middleware:env()}.
-handler_before_loop(State=#state{
-			socket=Socket, transport=Transport, hibernate=true},
-		HandlerState, SoFar) ->
+before_loop(State=#state{socket=Socket, transport=Transport, hibernate=true},
+		HandlerState, ParseState) ->
 	Transport:setopts(Socket, [{active, once}]),
-	proc_lib:hibernate(?MODULE, handler_loop,
-		[State#state{hibernate=false}, HandlerState, SoFar]);
-handler_before_loop(State=#state{socket=Socket, transport=Transport},
-		HandlerState, SoFar) ->
+	proc_lib:hibernate(?MODULE, loop,
+		[State#state{hibernate=false}, HandlerState, ParseState]);
+before_loop(State=#state{socket=Socket, transport=Transport},
+		HandlerState, ParseState) ->
 	Transport:setopts(Socket, [{active, once}]),
-	handler_loop(State, HandlerState, SoFar).
+	loop(State, HandlerState, ParseState).
 
--spec handler_loop_timeout(#state{}) -> #state{}.
-handler_loop_timeout(State=#state{timeout=infinity}) ->
+-spec loop_timeout(#state{}) -> #state{}.
+loop_timeout(State=#state{timeout=infinity}) ->
 	State#state{timeout_ref=undefined};
-handler_loop_timeout(State=#state{timeout=Timeout, timeout_ref=PrevRef}) ->
+loop_timeout(State=#state{timeout=Timeout, timeout_ref=PrevRef}) ->
 	_ = case PrevRef of undefined -> ignore; PrevRef ->
 		erlang:cancel_timer(PrevRef) end,
 	TRef = erlang:start_timer(Timeout, self(), ?MODULE),
 	State#state{timeout_ref=TRef}.
 
--spec handler_loop(#state{}, any(), binary())
-	-> {ok, cowboy_middleware:env()}.
-handler_loop(State=#state{parent=Parent, socket=Socket, messages={OK, Closed, Error},
-		timeout_ref=TRef}, HandlerState, SoFar) ->
+-spec loop(#state{}, any(), parse_state()) -> no_return().
+loop(State=#state{parent=Parent, socket=Socket, messages={OK, Closed, Error},
+		timeout_ref=TRef}, HandlerState, ParseState) ->
 	receive
 		{OK, Socket, Data} ->
-			State2 = handler_loop_timeout(State),
-			websocket_data(State2, HandlerState,
-				<< SoFar/binary, Data/binary >>);
+			State2 = loop_timeout(State),
+			parse(State2, HandlerState, ParseState, Data);
 		{Closed, Socket} ->
 			terminate(State, HandlerState, {error, closed});
 		{Error, Socket, Reason} ->
@@ -256,124 +267,121 @@ handler_loop(State=#state{parent=Parent, socket=Socket, messages={OK, Closed, Er
 		{timeout, TRef, ?MODULE} ->
 			websocket_close(State, HandlerState, timeout);
 		{timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) ->
-			handler_loop(State, HandlerState, SoFar);
+			loop(State, HandlerState, ParseState);
 		%% System messages.
 		{'EXIT', Parent, Reason} ->
 			%% @todo We should exit gracefully.
 			exit(Reason);
 		{system, From, Request} ->
 			sys:handle_system_msg(Request, From, Parent, ?MODULE, [],
-				{State, HandlerState, SoFar});
+				{State, HandlerState, ParseState});
 		%% Calls from supervisor module.
 		{'$gen_call', From, Call} ->
 			cowboy_children:handle_supervisor_call(Call, From, [], ?MODULE),
-			handler_loop(State, HandlerState, SoFar);
+			loop(State, HandlerState, ParseState);
 		Message ->
-			handler_call(State, HandlerState,
-				SoFar, websocket_info, Message, fun handler_before_loop/3)
+			handler_call(State, HandlerState, ParseState,
+				websocket_info, Message, fun before_loop/3)
 	end.
 
--spec websocket_data(#state{}, any(), binary())
-	-> {ok, cowboy_middleware:env()}.
-websocket_data(State=#state{frag_state=FragState, extensions=Extensions}, HandlerState, Data) ->
+parse(State, HandlerState, PS=#ps_header{buffer=Buffer}, Data) ->
+	parse_header(State, HandlerState, PS#ps_header{
+		buffer= <<Buffer/binary, Data/binary>>});
+parse(State, HandlerState, PS=#ps_payload{buffer=Buffer}, Data) ->
+	parse_payload(State, HandlerState, PS#ps_payload{buffer= <<>>},
+		<<Buffer/binary, Data/binary>>).
+
+parse_header(State=#state{frag_state=FragState, extensions=Extensions}, HandlerState,
+		ParseState=#ps_header{buffer=Data}) ->
 	case cow_ws:parse_header(Data, Extensions, FragState) of
 		%% All frames sent from the client to the server are masked.
 		{_, _, _, _, undefined, _} ->
 			websocket_close(State, HandlerState, {error, badframe});
 		{Type, FragState2, Rsv, Len, MaskKey, Rest} ->
-			websocket_payload(State#state{frag_state=FragState2}, HandlerState, Type, Len, MaskKey, Rsv, undefined, <<>>, 0, Rest);
+			parse_payload(State#state{frag_state=FragState2}, HandlerState,
+				#ps_payload{type=Type, len=Len, mask_key=MaskKey, rsv=Rsv}, Rest);
 		more ->
-			handler_before_loop(State, HandlerState, Data);
+			before_loop(State, HandlerState, ParseState);
 		error ->
 			websocket_close(State, HandlerState, {error, badframe})
 	end.
 
-websocket_payload(State=#state{frag_state=FragState, utf8_state=Incomplete, extensions=Extensions},
-		HandlerState, Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen, Data) ->
-	case cow_ws:parse_payload(Data, MaskKey, Incomplete, UnmaskedLen, Type, Len, FragState, Extensions, Rsv) of
-		{ok, CloseCode2, Payload, Utf8State, Rest} ->
-			websocket_dispatch(State#state{utf8_state=Utf8State},
-				HandlerState, Type, << Unmasked/binary, Payload/binary >>, CloseCode2, Rest);
+parse_payload(State=#state{frag_state=FragState, utf8_state=Incomplete, extensions=Extensions},
+		HandlerState, ParseState=#ps_payload{
+			type=Type, len=Len, mask_key=MaskKey, rsv=Rsv,
+			unmasked=Unmasked, unmasked_len=UnmaskedLen}, Data) ->
+	case cow_ws:parse_payload(Data, MaskKey, Incomplete, UnmaskedLen,
+			Type, Len, FragState, Extensions, Rsv) of
+		{ok, CloseCode, Payload, Utf8State, Rest} ->
+			dispatch_frame(State#state{utf8_state=Utf8State}, HandlerState,
+				ParseState#ps_payload{unmasked= <<Unmasked/binary, Payload/binary>>,
+					close_code=CloseCode}, Rest);
 		{ok, Payload, Utf8State, Rest} ->
-			websocket_dispatch(State#state{utf8_state=Utf8State},
-				HandlerState, Type, << Unmasked/binary, Payload/binary >>, CloseCode, Rest);
-		{more, CloseCode2, Payload, Utf8State} ->
-			websocket_payload_loop(State#state{utf8_state=Utf8State},
-				HandlerState, Type, Len - byte_size(Data), MaskKey, Rsv, CloseCode2,
-				<< Unmasked/binary, Payload/binary >>, UnmaskedLen + byte_size(Data));
+			dispatch_frame(State#state{utf8_state=Utf8State}, HandlerState,
+				ParseState#ps_payload{unmasked= <<Unmasked/binary, Payload/binary>>},
+				Rest);
+		{more, CloseCode, Payload, Utf8State} ->
+			before_loop(State#state{utf8_state=Utf8State}, HandlerState,
+				ParseState#ps_payload{len=Len - byte_size(Data), close_code=CloseCode,
+					unmasked= <<Unmasked/binary, Payload/binary>>,
+					unmasked_len=UnmaskedLen + byte_size(Data)});
 		{more, Payload, Utf8State} ->
-			websocket_payload_loop(State#state{utf8_state=Utf8State},
-				HandlerState, Type, Len - byte_size(Data), MaskKey, Rsv, CloseCode,
-				<< Unmasked/binary, Payload/binary >>, UnmaskedLen + byte_size(Data));
+			before_loop(State#state{utf8_state=Utf8State}, HandlerState,
+				ParseState#ps_payload{len=Len - byte_size(Data),
+					unmasked= <<Unmasked/binary, Payload/binary>>,
+					unmasked_len=UnmaskedLen + byte_size(Data)});
 		Error = {error, _Reason} ->
 			websocket_close(State, HandlerState, Error)
 	end.
 
-websocket_payload_loop(State=#state{socket=Socket, transport=Transport,
-		messages={OK, Closed, Error}, timeout_ref=TRef},
-		HandlerState, Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen) ->
-	Transport:setopts(Socket, [{active, once}]),
-	receive
-		{OK, Socket, Data} ->
-			State2 = handler_loop_timeout(State),
-			websocket_payload(State2, HandlerState,
-				Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen, Data);
-		{Closed, Socket} ->
-			terminate(State, HandlerState, {error, closed});
-		{Error, Socket, Reason} ->
-			terminate(State, HandlerState, {error, Reason});
-		{timeout, TRef, ?MODULE} ->
-			websocket_close(State, HandlerState, timeout);
-		{timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) ->
-			websocket_payload_loop(State, HandlerState,
-				Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen);
-		Message ->
-			handler_call(State, HandlerState,
-				<<>>, websocket_info, Message,
-				fun (State2, HandlerState2, _) ->
-					websocket_payload_loop(State2, HandlerState2,
-						Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen)
-				end)
-	end.
-
-websocket_dispatch(State=#state{socket=Socket, transport=Transport, frag_state=FragState, frag_buffer=SoFar, extensions=Extensions},
-		HandlerState, Type0, Payload0, CloseCode0, RemainingData) ->
+dispatch_frame(State=#state{socket=Socket, transport=Transport,
+		frag_state=FragState, frag_buffer=SoFar, extensions=Extensions},
+		HandlerState, #ps_payload{type=Type0, unmasked=Payload0, close_code=CloseCode0},
+		RemainingData) ->
 	case cow_ws:make_frame(Type0, Payload0, CloseCode0, FragState) of
 		%% @todo Allow receiving fragments.
 		{fragment, nofin, _, Payload} ->
-			websocket_data(State#state{frag_buffer= << SoFar/binary, Payload/binary >>}, HandlerState, RemainingData);
+			parse_header(State#state{frag_buffer= << SoFar/binary, Payload/binary >>},
+				HandlerState, #ps_header{buffer=RemainingData});
 		{fragment, fin, Type, Payload} ->
-			handler_call(State#state{frag_state=undefined, frag_buffer= <<>>}, HandlerState, RemainingData,
-				websocket_handle, {Type, << SoFar/binary, Payload/binary >>}, fun websocket_data/3);
+			handler_call(State#state{frag_state=undefined, frag_buffer= <<>>}, HandlerState,
+				#ps_header{buffer=RemainingData},
+				websocket_handle, {Type, << SoFar/binary, Payload/binary >>},
+				fun parse_header/3);
 		close ->
 			websocket_close(State, HandlerState, remote);
 		{close, CloseCode, Payload} ->
 			websocket_close(State, HandlerState, {remote, CloseCode, Payload});
 		Frame = ping ->
 			Transport:send(Socket, cow_ws:frame(pong, Extensions)),
-			handler_call(State, HandlerState, RemainingData, websocket_handle, Frame, fun websocket_data/3);
+			handler_call(State, HandlerState,
+				#ps_header{buffer=RemainingData},
+				websocket_handle, Frame, fun parse_header/3);
 		Frame = {ping, Payload} ->
 			Transport:send(Socket, cow_ws:frame({pong, Payload}, Extensions)),
-			handler_call(State, HandlerState, RemainingData, websocket_handle, Frame, fun websocket_data/3);
+			handler_call(State, HandlerState,
+				#ps_header{buffer=RemainingData},
+				websocket_handle, Frame, fun parse_header/3);
 		Frame ->
-			handler_call(State, HandlerState, RemainingData, websocket_handle, Frame, fun websocket_data/3)
+			handler_call(State, HandlerState,
+				#ps_header{buffer=RemainingData},
+				websocket_handle, Frame, fun parse_header/3)
 	end.
 
--spec handler_call(#state{}, any(), binary(), atom(), any(), fun()) -> no_return().
 handler_call(State=#state{handler=Handler}, HandlerState,
-		RemainingData, Callback, Message, NextState) ->
+		ParseState, Callback, Message, NextState) ->
 	try case Callback of
 		websocket_init -> Handler:websocket_init(HandlerState);
 		_ -> Handler:Callback(Message, HandlerState)
 	end of
 		{ok, HandlerState2} ->
-			NextState(State, HandlerState2, RemainingData);
+			NextState(State, HandlerState2, ParseState);
 		{ok, HandlerState2, hibernate} ->
-			NextState(State#state{hibernate=true}, HandlerState2, RemainingData);
+			NextState(State#state{hibernate=true}, HandlerState2, ParseState);
 		{reply, Payload, HandlerState2} ->
 			case websocket_send(Payload, State) of
 				ok ->
-					NextState(State, HandlerState2, RemainingData);
+					NextState(State, HandlerState2, ParseState);
 				stop ->
 					terminate(State, HandlerState2, stop);
 				Error = {error, _} ->
@@ -383,7 +391,7 @@ handler_call(State=#state{handler=Handler}, HandlerState,
 			case websocket_send(Payload, State) of
 				ok ->
 					NextState(State#state{hibernate=true},
-						HandlerState2, RemainingData);
+						HandlerState2, ParseState);
 				stop ->
 					terminate(State, HandlerState2, stop);
 				Error = {error, _} ->
@@ -458,15 +466,16 @@ handler_terminate(#state{handler=Handler, req=Req}, HandlerState, Reason) ->
 
 %% System callbacks.
 
--spec system_continue(_, _, {#state{}, any(), binary()}) -> ok.
-system_continue(_, _, {State, HandlerState, SoFar}) ->
-	handler_loop(State, HandlerState, SoFar).
+-spec system_continue(_, _, {#state{}, any(), parse_state()}) -> no_return().
+system_continue(_, _, {State, HandlerState, ParseState}) ->
+	loop(State, HandlerState, ParseState).
 
--spec system_terminate(any(), _, _, {#state{}, any(), binary()}) -> no_return().
+-spec system_terminate(any(), _, _, {#state{}, any(), parse_state()}) -> no_return().
 system_terminate(Reason, _, _, {State, HandlerState, _}) ->
 	%% @todo We should exit gracefully, if possible.
 	terminate(State, HandlerState, Reason).
 
--spec system_code_change(Misc, _, _, _) -> {ok, Misc} when Misc::{#state{}, any(), binary()}.
+-spec system_code_change(Misc, _, _, _)
+	-> {ok, Misc} when Misc::{#state{}, any(), parse_state()}.
 system_code_change(Misc, _, _, _) ->
 	{ok, Misc}.

+ 0 - 3
test/sys_SUITE.erl

@@ -112,9 +112,6 @@ proc_lib_initial_call_tls(Config) ->
 %% so that it doesn't eat up system messages. It should only
 %% flush messages that are specific to cowboy_http.
 
-%% @todo The cowboy_websocket module needs to have the functions
-%% handler_loop and websocket_payload_loop merged into one.
-
 bad_system_from_h1(Config) ->
 	doc("h1: Sending a system message with a bad From value results in a process crash."),
 	{ok, Socket} = gen_tcp:connect("localhost", config(clear_port, Config), [{active, false}]),