Просмотр исходного кода

Allow HTTP protocol upgrades to use keepalive

REST needed this to be allowed to chain requests on the same connection.
Loïc Hoguin 13 лет назад
Родитель
Сommit
8d2102fe11
4 измененных файлов с 79 добавлено и 48 удалено
  1. 21 9
      src/cowboy_http_protocol.erl
  2. 5 3
      src/cowboy_http_rest.erl
  3. 28 33
      src/cowboy_http_websocket.erl
  4. 25 3
      test/http_SUITE.erl

+ 21 - 9
src/cowboy_http_protocol.erl

@@ -219,8 +219,8 @@ dispatch(Next, Req=#http_req{host=Host, path=Path},
 	end.
 
 -spec handler_init(#http_req{}, #state{}) -> ok | none().
-handler_init(Req, State=#state{listener=ListenerPid,
-		transport=Transport, handler={Handler, Opts}}) ->
+handler_init(Req, State=#state{transport=Transport,
+		handler={Handler, Opts}}) ->
 	try Handler:init({Transport:name(), http}, Req, Opts) of
 		{ok, Req2, HandlerState} ->
 			handler_handle(HandlerState, Req2, State);
@@ -239,7 +239,7 @@ handler_init(Req, State=#state{listener=ListenerPid,
 			handler_terminate(HandlerState, Req2, State);
 		%% @todo {upgrade, transport, Module}
 		{upgrade, protocol, Module} ->
-			Module:upgrade(ListenerPid, Handler, Opts, Req)
+			upgrade_protocol(Req, State, Module)
 	catch Class:Reason ->
 		error_terminate(500, State),
 		error_logger:error_msg(
@@ -250,11 +250,19 @@ handler_init(Req, State=#state{listener=ListenerPid,
 			[Handler, Class, Reason, Opts, Req, erlang:get_stacktrace()])
 	end.
 
+-spec upgrade_protocol(#http_req{}, #state{}, atom()) -> ok | none().
+upgrade_protocol(Req, State=#state{listener=ListenerPid,
+		handler={Handler, Opts}}, Module) ->
+	case Module:upgrade(ListenerPid, Handler, Opts, Req) of
+		{UpgradeRes, Req2} -> next_request(Req2, State, UpgradeRes);
+		_Any -> terminate(State)
+	end.
+
 -spec handler_handle(any(), #http_req{}, #state{}) -> ok | none().
 handler_handle(HandlerState, Req, State=#state{handler={Handler, Opts}}) ->
 	try Handler:handle(Req, HandlerState) of
 		{ok, Req2, HandlerState2} ->
-			next_request(HandlerState2, Req2, State)
+			terminate_request(HandlerState2, Req2, State)
 	catch Class:Reason ->
 		error_logger:error_msg(
 			"** Handler ~p terminating in handle/2~n"
@@ -294,7 +302,7 @@ handler_loop_timeout(State=#state{loop_timeout=Timeout,
 handler_loop(HandlerState, Req, State=#state{loop_timeout_ref=TRef}) ->
 	receive
 		{?MODULE, timeout, TRef} ->
-			next_request(HandlerState, Req, State);
+			terminate_request(HandlerState, Req, State);
 		{?MODULE, timeout, OlderTRef} when is_reference(OlderTRef) ->
 			handler_loop(HandlerState, Req, State);
 		Message ->
@@ -306,7 +314,7 @@ handler_call(HandlerState, Req, State=#state{handler={Handler, Opts}},
 		Message) ->
 	try Handler:info(Message, Req, HandlerState) of
 		{ok, Req2, HandlerState2} ->
-			next_request(HandlerState2, Req2, State);
+			terminate_request(HandlerState2, Req2, State);
 		{loop, Req2, HandlerState2} ->
 			handler_before_loop(HandlerState2, Req2, State);
 		{loop, Req2, HandlerState2, hibernate} ->
@@ -336,10 +344,14 @@ handler_terminate(HandlerState, Req, #state{handler={Handler, Opts}}) ->
 			 HandlerState, Req, erlang:get_stacktrace()])
 	end.
 
--spec next_request(any(), #http_req{}, #state{}) -> ok | none().
-next_request(HandlerState, Req=#http_req{connection=Conn, buffer=Buffer},
-		State) ->
+-spec terminate_request(any(), #http_req{}, #state{}) -> ok | none().
+terminate_request(HandlerState, Req, State) ->
 	HandlerRes = handler_terminate(HandlerState, Req, State),
+	next_request(Req, State, HandlerRes).
+
+-spec next_request(#http_req{}, #state{}, any()) -> ok | none().
+next_request(Req=#http_req{connection=Conn, buffer=Buffer},
+		State, HandlerRes) ->
 	BodyRes = ensure_body_processed(Req),
 	RespRes = ensure_response(Req),
 	case {HandlerRes, BodyRes, RespRes, Conn} of

+ 5 - 3
src/cowboy_http_rest.erl

@@ -53,7 +53,7 @@
 %% You do not need to call this function manually. To upgrade to the REST
 %% protocol, you simply need to return <em>{upgrade, protocol, {@module}}</em>
 %% in your <em>cowboy_http_handler:init/3</em> handler function.
--spec upgrade(pid(), module(), any(), #http_req{}) -> ok.
+-spec upgrade(pid(), module(), any(), #http_req{}) -> {ok, #http_req{}}.
 upgrade(_ListenerPid, Handler, Opts, Req) ->
 	try
 		case erlang:function_exported(Handler, rest_init, 2) of
@@ -753,6 +753,8 @@ respond(Req, State, StatusCode) ->
 
 terminate(Req, #state{handler=Handler, handler_state=HandlerState}) ->
 	case erlang:function_exported(Handler, rest_terminate, 2) of
-		true -> ok = Handler:rest_terminate(Req, HandlerState);
+		true -> ok = Handler:rest_terminate(
+			Req#http_req{resp_state=locked}, HandlerState);
 		false -> ok
-	end.
+	end,
+	{ok, Req}.

+ 28 - 33
src/cowboy_http_websocket.erl

@@ -64,7 +64,7 @@
 %% You do not need to call this function manually. To upgrade to the WebSocket
 %% protocol, you simply need to return <em>{upgrade, protocol, {@module}}</em>
 %% in your <em>cowboy_http_handler:init/3</em> handler function.
--spec upgrade(pid(), module(), any(), #http_req{}) -> ok | none().
+-spec upgrade(pid(), module(), any(), #http_req{}) -> closed | none().
 upgrade(ListenerPid, Handler, Opts, Req) ->
 	cowboy_listener:move_connection(ListenerPid, websocket, self()),
 	case catch websocket_upgrade(#state{handler=Handler, opts=Opts}, Req) of
@@ -113,7 +113,7 @@ websocket_upgrade(Version, State, Req)
 	IntVersion = list_to_integer(binary_to_list(Version)),
 	{ok, State#state{version=IntVersion, challenge=Challenge}, Req2}.
 
--spec handler_init(#state{}, #http_req{}) -> ok | none().
+-spec handler_init(#state{}, #http_req{}) -> closed | none().
 handler_init(State=#state{handler=Handler, opts=Opts},
 		Req=#http_req{transport=Transport}) ->
 	try Handler:websocket_init(Transport:name(), Req, Opts) of
@@ -139,31 +139,27 @@ handler_init(State=#state{handler=Handler, opts=Opts},
 			[Handler, Class, Reason, Opts, Req, erlang:get_stacktrace()])
 	end.
 
--spec upgrade_error(#http_req{}) -> ok.
+-spec upgrade_error(#http_req{}) -> closed.
 upgrade_error(Req) ->
-	{ok, Req2} = cowboy_http_req:reply(400, [], [],
+	{ok, _Req2} = cowboy_http_req:reply(400, [], [],
 		Req#http_req{resp_state=waiting}),
-	upgrade_terminate(Req2).
+	closed.
 
 %% @see cowboy_http_protocol:ensure_response/1
--spec upgrade_denied(#http_req{}) -> ok.
-upgrade_denied(Req=#http_req{resp_state=done}) ->
-	upgrade_terminate(Req);
+-spec upgrade_denied(#http_req{}) -> closed.
+upgrade_denied(#http_req{resp_state=done}) ->
+	closed;
 upgrade_denied(Req=#http_req{resp_state=waiting}) ->
-	{ok, Req2} = cowboy_http_req:reply(400, [], [], Req),
-	upgrade_terminate(Req2);
-upgrade_denied(Req=#http_req{method='HEAD', resp_state=chunks}) ->
-	upgrade_terminate(Req);
-upgrade_denied(Req=#http_req{socket=Socket, transport=Transport,
+	{ok, _Req2} = cowboy_http_req:reply(400, [], [], Req),
+	closed;
+upgrade_denied(#http_req{method='HEAD', resp_state=chunks}) ->
+	closed;
+upgrade_denied(#http_req{socket=Socket, transport=Transport,
 		resp_state=chunks}) ->
 	Transport:send(Socket, <<"0\r\n\r\n">>),
-	upgrade_terminate(Req).
+	closed.
 
--spec upgrade_terminate(#http_req{}) -> ok.
-upgrade_terminate(#http_req{socket=Socket, transport=Transport}) ->
-	Transport:close(Socket).
-
--spec websocket_handshake(#state{}, #http_req{}, any()) -> ok | none().
+-spec websocket_handshake(#state{}, #http_req{}, any()) -> closed | none().
 websocket_handshake(State=#state{version=0, origin=Origin,
 		challenge={Key1, Key2}}, Req=#http_req{socket=Socket,
 		transport=Transport, raw_host=Host, port=Port,
@@ -185,7 +181,7 @@ websocket_handshake(State=#state{version=0, origin=Origin,
 			handler_before_loop(State#state{messages=Transport:messages()},
 				Req3, HandlerState, <<>>);
 		_Any ->
-			ok %% If an error happened reading the body, stop there.
+			closed %% If an error happened reading the body, stop there.
 	end;
 websocket_handshake(State=#state{challenge=Challenge},
 		Req=#http_req{transport=Transport}, HandlerState) ->
@@ -197,7 +193,7 @@ websocket_handshake(State=#state{challenge=Challenge},
 	handler_before_loop(State#state{messages=Transport:messages()},
 		Req2, HandlerState, <<>>).
 
--spec handler_before_loop(#state{}, #http_req{}, any(), binary()) -> ok | none().
+-spec handler_before_loop(#state{}, #http_req{}, any(), binary()) -> closed | none().
 handler_before_loop(State=#state{hibernate=true},
 		Req=#http_req{socket=Socket, transport=Transport},
 		HandlerState, SoFar) ->
@@ -222,7 +218,7 @@ handler_loop_timeout(State=#state{timeout=Timeout, timeout_ref=PrevRef}) ->
 	State#state{timeout_ref=TRef}.
 
 %% @private
--spec handler_loop(#state{}, #http_req{}, any(), binary()) -> ok | none().
+-spec handler_loop(#state{}, #http_req{}, any(), binary()) -> closed | none().
 handler_loop(State=#state{messages={OK, Closed, Error}, timeout_ref=TRef},
 		Req=#http_req{socket=Socket}, HandlerState, SoFar) ->
 	receive
@@ -242,7 +238,7 @@ handler_loop(State=#state{messages={OK, Closed, Error}, timeout_ref=TRef},
 				SoFar, websocket_info, Message, fun handler_before_loop/4)
 	end.
 
--spec websocket_data(#state{}, #http_req{}, any(), binary()) -> ok | none().
+-spec websocket_data(#state{}, #http_req{}, any(), binary()) -> closed | none().
 %% No more data.
 websocket_data(State, Req, HandlerState, <<>>) ->
 	handler_before_loop(State, Req, HandlerState, <<>>);
@@ -296,14 +292,14 @@ websocket_data(State, Req, HandlerState, _Bad) ->
 
 %% hybi unmasking.
 -spec websocket_unmask(#state{}, #http_req{}, any(), binary(),
-	opcode(), binary(), mask_key()) -> ok | none().
+	opcode(), binary(), mask_key()) -> closed | none().
 websocket_unmask(State, Req, HandlerState, RemainingData,
 		Opcode, Payload, MaskKey) ->
 	websocket_unmask(State, Req, HandlerState, RemainingData,
 		Opcode, Payload, MaskKey, <<>>).
 
 -spec websocket_unmask(#state{}, #http_req{}, any(), binary(),
-	opcode(), binary(), mask_key(), binary()) -> ok | none().
+	opcode(), binary(), mask_key(), binary()) -> closed | none().
 websocket_unmask(State, Req, HandlerState, RemainingData,
 		Opcode, << O:32, Rest/bits >>, MaskKey, Acc) ->
 	T = O bxor MaskKey,
@@ -334,7 +330,7 @@ websocket_unmask(State, Req, HandlerState, RemainingData,
 
 %% hybi dispatching.
 -spec websocket_dispatch(#state{}, #http_req{}, any(), binary(),
-	opcode(), binary()) -> ok | none().
+	opcode(), binary()) -> closed | none().
 %% @todo Fragmentation.
 %~ websocket_dispatch(State, Req, HandlerState, RemainingData, 0, Payload) ->
 %% Text frame.
@@ -362,7 +358,7 @@ websocket_dispatch(State, Req, HandlerState, RemainingData, 10, Payload) ->
 		websocket_handle, {pong, Payload}, fun websocket_data/4).
 
 -spec handler_call(#state{}, #http_req{}, any(), binary(),
-	atom(), any(), fun()) -> ok | none().
+	atom(), any(), fun()) -> closed | none().
 handler_call(State=#state{handler=Handler, opts=Opts}, Req, HandlerState,
 		RemainingData, Callback, Message, NextState) ->
 	try Handler:Callback(Message, Req, HandlerState) of
@@ -391,7 +387,7 @@ handler_call(State=#state{handler=Handler, opts=Opts}, Req, HandlerState,
 		websocket_close(State, Req, HandlerState, {error, handler})
 	end.
 
--spec websocket_send(binary(), #state{}, #http_req{}) -> ok | ignore.
+-spec websocket_send(binary(), #state{}, #http_req{}) -> closed | ignore.
 %% hixie-76 text frame.
 websocket_send({text, Payload}, #state{version=0},
 		#http_req{socket=Socket, transport=Transport}) ->
@@ -411,21 +407,19 @@ websocket_send({Type, Payload}, _State,
 	Transport:send(Socket, [<< 1:1, 0:3, Opcode:4, 0:1, Len/bits >>,
 		Payload]).
 
--spec websocket_close(#state{}, #http_req{}, any(), {atom(), atom()}) -> ok.
+-spec websocket_close(#state{}, #http_req{}, any(), {atom(), atom()}) -> closed.
 websocket_close(State=#state{version=0}, Req=#http_req{socket=Socket,
 		transport=Transport}, HandlerState, Reason) ->
 	Transport:send(Socket, << 255, 0 >>),
-	Transport:close(Socket),
 	handler_terminate(State, Req, HandlerState, Reason);
 %% @todo Send a Payload? Using Reason is usually good but we're quite careless.
 websocket_close(State, Req=#http_req{socket=Socket,
 		transport=Transport}, HandlerState, Reason) ->
 	Transport:send(Socket, << 1:1, 0:3, 8:4, 0:8 >>),
-	Transport:close(Socket),
 	handler_terminate(State, Req, HandlerState, Reason).
 
 -spec handler_terminate(#state{}, #http_req{},
-	any(), atom() | {atom(), atom()}) -> ok.
+	any(), atom() | {atom(), atom()}) -> closed.
 handler_terminate(#state{handler=Handler, opts=Opts},
 		Req, HandlerState, TerminateReason) ->
 	try
@@ -438,7 +432,8 @@ handler_terminate(#state{handler=Handler, opts=Opts},
 			"** Request was ~p~n** Stacktrace: ~p~n~n",
 			[Handler, Class, Reason, TerminateReason, Opts,
 			 HandlerState, Req, erlang:get_stacktrace()])
-	end.
+	end,
+	closed.
 
 %% hixie-76 specific.
 

+ 25 - 3
test/http_SUITE.erl

@@ -25,7 +25,7 @@
 	set_resp_overwrite/1, set_resp_body/1, response_as_req/1]). %% http.
 -export([http_200/1, http_404/1]). %% http and https.
 -export([http_10_hostless/1]). %% misc.
--export([rest_simple/1]). %% rest.
+-export([rest_simple/1, rest_keepalive/1]). %% rest.
 
 %% ct.
 
@@ -41,7 +41,7 @@ groups() ->
 		set_resp_body, response_as_req] ++ BaseTests},
 	{https, [], BaseTests},
 	{misc, [], [http_10_hostless]},
-	{rest, [], [rest_simple]}].
+	{rest, [], [rest_simple, rest_keepalive]}].
 
 init_per_suite(Config) ->
 	application:start(inets),
@@ -299,7 +299,12 @@ ws0(Config) ->
 	{ok, << 0, "websocket_handle", 255 >>} = gen_tcp:recv(Socket, 0, 6000),
 	{ok, << 0, "websocket_handle", 255 >>} = gen_tcp:recv(Socket, 0, 6000),
 	{ok, << 0, "websocket_handle", 255 >>} = gen_tcp:recv(Socket, 0, 6000),
-	ok = gen_tcp:send(Socket, << 255, 0 >>),
+	%% We try to send another HTTP request to make sure
+	%% the server closed the request.
+	ok = gen_tcp:send(Socket, [
+		<< 255, 0 >>, %% Close websocket command.
+		"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n" %% Server should ignore it.
+	]),
 	{ok, << 255, 0 >>} = gen_tcp:recv(Socket, 0, 6000),
 	{error, closed} = gen_tcp:recv(Socket, 0, 6000),
 	ok.
@@ -574,3 +579,20 @@ http_10_hostless(Config) ->
 rest_simple(Config) ->
 	Packet = "GET /simple HTTP/1.1\r\nHost: localhost\r\n\r\n",
 	{Packet, 200} = raw_req(Packet, Config).
+
+rest_keepalive(Config) ->
+	{port, Port} = lists:keyfind(port, 1, Config),
+	{ok, Socket} = gen_tcp:connect("localhost", Port,
+		[binary, {active, false}, {packet, raw}]),
+	ok = rest_keepalive_loop(Socket, 100),
+	ok = gen_tcp:close(Socket).
+
+rest_keepalive_loop(_Socket, 0) ->
+	ok;
+rest_keepalive_loop(Socket, N) ->
+	ok = gen_tcp:send(Socket, "GET /simple HTTP/1.1\r\n"
+		"Host: localhost\r\nConnection: keep-alive\r\n\r\n"),
+	{ok, Data} = gen_tcp:recv(Socket, 0, 6000),
+	{0, 12} = binary:match(Data, <<"HTTP/1.1 200">>),
+	nomatch = binary:match(Data, <<"Connection: close">>),
+	rest_keepalive_loop(Socket, N - 1).