Browse Source

Merge branch 'response-hook'

Loïc Hoguin 13 years ago
parent
commit
0406a632dc
5 changed files with 98 additions and 37 deletions
  1. 2 0
      include/http.hrl
  2. 3 1
      src/cowboy_client.erl
  3. 12 5
      src/cowboy_http_protocol.erl
  4. 42 25
      src/cowboy_http_req.erl
  5. 39 6
      test/http_SUITE.erl

+ 2 - 0
include/http.hrl

@@ -52,5 +52,7 @@
 								fun(() -> {sent, non_neg_integer()})},
 
 	%% Functions.
+	onresponse = undefined :: undefined | fun((cowboy_http:status(),
+		cowboy_http:headers(), #http_req{}) -> #http_req{}),
 	urldecode :: {fun((binary(), T) -> binary()), T}
 }).

+ 3 - 1
src/cowboy_client.erl

@@ -158,7 +158,9 @@ response_body_loop(Client, Acc) ->
 		{ok, Data, Client2} ->
 			response_body_loop(Client2, << Acc/binary, Data/binary >>);
 		{done, Client2} ->
-			{ok, Acc, Client2}
+			{ok, Acc, Client2};
+		{error, Reason} ->
+			{error, Reason}
 	end.
 
 skip_body(Client=#client{state=response_body}) ->

+ 12 - 5
src/cowboy_http_protocol.erl

@@ -48,6 +48,8 @@
 	dispatch :: cowboy_dispatcher:dispatch_rules(),
 	handler :: {module(), any()},
 	onrequest :: undefined | fun((#http_req{}) -> #http_req{}),
+	onresponse = undefined :: undefined | fun((cowboy_http:status(),
+		cowboy_http:headers(), #http_req{}) -> #http_req{}),
 	urldecode :: {fun((binary(), T) -> binary()), T},
 	req_empty_lines = 0 :: integer(),
 	max_empty_lines :: integer(),
@@ -79,6 +81,7 @@ init(ListenerPid, Socket, Transport, Opts) ->
 	MaxKeepalive = proplists:get_value(max_keepalive, Opts, infinity),
 	MaxLineLength = proplists:get_value(max_line_length, Opts, 4096),
 	OnRequest = proplists:get_value(onrequest, Opts),
+	OnResponse = proplists:get_value(onresponse, Opts),
 	Timeout = proplists:get_value(timeout, Opts, 5000),
 	URLDecDefault = {fun cowboy_http:urldecode/2, crash},
 	URLDec = proplists:get_value(urldecode, Opts, URLDecDefault),
@@ -86,7 +89,8 @@ init(ListenerPid, Socket, Transport, Opts) ->
 	wait_request(#state{listener=ListenerPid, socket=Socket, transport=Transport,
 		dispatch=Dispatch, max_empty_lines=MaxEmptyLines,
 		max_keepalive=MaxKeepalive, max_line_length=MaxLineLength,
-		timeout=Timeout, onrequest=OnRequest, urldecode=URLDec}).
+		timeout=Timeout, onrequest=OnRequest, onresponse=OnResponse,
+		urldecode=URLDec}).
 
 %% @private
 -spec parse_request(#state{}) -> ok.
@@ -122,7 +126,7 @@ request({http_request, Method, {absoluteURI, _Scheme, _Host, _Port, Path},
 request({http_request, Method, {abs_path, AbsPath}, Version},
 		State=#state{socket=Socket, transport=Transport,
 		req_keepalive=Keepalive, max_keepalive=MaxKeepalive,
-		urldecode={URLDecFun, URLDecArg}=URLDec}) ->
+		onresponse=OnResponse, urldecode={URLDecFun, URLDecArg}=URLDec}) ->
 	URLDecode = fun(Bin) -> URLDecFun(Bin, URLDecArg) end,
 	{Path, RawPath, Qs} = cowboy_dispatcher:split_path(AbsPath, URLDecode),
 	ConnAtom = if Keepalive < MaxKeepalive -> version_to_connection(Version);
@@ -130,16 +134,19 @@ request({http_request, Method, {abs_path, AbsPath}, Version},
 	end,
 	parse_header(#http_req{socket=Socket, transport=Transport,
 		connection=ConnAtom, pid=self(), method=Method, version=Version,
-		path=Path, raw_path=RawPath, raw_qs=Qs, urldecode=URLDec}, State);
+		path=Path, raw_path=RawPath, raw_qs=Qs, onresponse=OnResponse,
+		urldecode=URLDec}, State);
 request({http_request, Method, '*', Version},
 		State=#state{socket=Socket, transport=Transport,
-		req_keepalive=Keepalive, max_keepalive=MaxKeepalive, urldecode=URLDec}) ->
+		req_keepalive=Keepalive, max_keepalive=MaxKeepalive,
+		onresponse=OnResponse, urldecode=URLDec}) ->
 	ConnAtom = if Keepalive < MaxKeepalive -> version_to_connection(Version);
 		true -> close
 	end,
 	parse_header(#http_req{socket=Socket, transport=Transport,
 		connection=ConnAtom, pid=self(), method=Method, version=Version,
-		path='*', raw_path= <<"*">>, raw_qs= <<>>, urldecode=URLDec}, State);
+		path='*', raw_path= <<"*">>, raw_qs= <<>>, onresponse=OnResponse,
+		urldecode=URLDec}, State);
 request({http_request, _Method, _URI, _Version}, State) ->
 	error_terminate(501, State);
 request({http_error, <<"\r\n">>},

+ 42 - 25
src/cowboy_http_req.erl

@@ -696,7 +696,7 @@ reply(Status, Headers, Req=#http_req{resp_body=Body}) ->
 -spec reply(cowboy_http:status(), cowboy_http:headers(), iodata(), #http_req{})
 	-> {ok, #http_req{}}.
 reply(Status, Headers, Body, Req=#http_req{socket=Socket, transport=Transport,
-		version=Version, connection=Connection, pid=ReqPid,
+		version=Version, connection=Connection,
 		method=Method, resp_state=waiting, resp_headers=RespHeaders}) ->
 	RespConn = response_connection(Headers, Connection),
 	ContentLen = case Body of {CL, _} -> CL; _ -> iolist_size(Body) end,
@@ -704,18 +704,20 @@ reply(Status, Headers, Body, Req=#http_req{socket=Socket, transport=Transport,
 		{1, 1} -> [{<<"Connection">>, atom_to_connection(Connection)}];
 		_ -> []
 	end,
-	response(Status, Headers, RespHeaders,  [
+	{ReplyType, Req2} = response(Status, Headers, RespHeaders,  [
 		{<<"Content-Length">>, integer_to_list(ContentLen)},
 		{<<"Date">>, cowboy_clock:rfc1123()},
 		{<<"Server">>, <<"Cowboy">>}
 	|HTTP11Headers], Req),
-	case {Method, Body} of
-		{'HEAD', _} -> ok;
-		{_, {_, StreamFun}} -> StreamFun();
-		{_, _} -> Transport:send(Socket, Body)
+	if	Method =:= 'HEAD' -> ok;
+		ReplyType =:= hook -> ok; %% Hook replied for us, stop there.
+		true ->
+			case Body of
+				{_, StreamFun} -> StreamFun();
+				_ -> Transport:send(Socket, Body)
+			end
 	end,
-	ReqPid ! {?MODULE, resp_sent},
-	{ok, Req#http_req{connection=RespConn, resp_state=done,
+	{ok, Req2#http_req{connection=RespConn, resp_state=done,
 		resp_headers=[], resp_body= <<>>}}.
 
 %% @equiv chunked_reply(Status, [], Req)
@@ -729,7 +731,7 @@ chunked_reply(Status, Req) ->
 	-> {ok, #http_req{}}.
 chunked_reply(Status, Headers, Req=#http_req{
 		version=Version, connection=Connection,
-		pid=ReqPid, resp_state=waiting, resp_headers=RespHeaders}) ->
+		resp_state=waiting, resp_headers=RespHeaders}) ->
 	RespConn = response_connection(Headers, Connection),
 	HTTP11Headers = case Version of
 		{1, 1} -> [
@@ -737,12 +739,11 @@ chunked_reply(Status, Headers, Req=#http_req{
 			{<<"Transfer-Encoding">>, <<"chunked">>}];
 		_ -> []
 	end,
-	response(Status, Headers, RespHeaders, [
+	{_, Req2} = response(Status, Headers, RespHeaders, [
 		{<<"Date">>, cowboy_clock:rfc1123()},
 		{<<"Server">>, <<"Cowboy">>}
 	|HTTP11Headers], Req),
-	ReqPid ! {?MODULE, resp_sent},
-	{ok, Req#http_req{connection=RespConn, resp_state=chunks,
+	{ok, Req2#http_req{connection=RespConn, resp_state=chunks,
 		resp_headers=[], resp_body= <<>>}}.
 
 %% @doc Send a chunk of data.
@@ -762,12 +763,11 @@ chunk(Data, #http_req{socket=Socket, transport=Transport, resp_state=chunks}) ->
 -spec upgrade_reply(cowboy_http:status(), cowboy_http:headers(), #http_req{})
 	-> {ok, #http_req{}}.
 upgrade_reply(Status, Headers, Req=#http_req{
-		pid=ReqPid, resp_state=waiting, resp_headers=RespHeaders}) ->
-	response(Status, Headers, RespHeaders, [
+		resp_state=waiting, resp_headers=RespHeaders}) ->
+	{_, Req2} = response(Status, Headers, RespHeaders, [
 		{<<"Connection">>, <<"Upgrade">>}
 	], Req),
-	ReqPid ! {?MODULE, resp_sent},
-	{ok, Req#http_req{resp_state=done, resp_headers=[], resp_body= <<>>}}.
+	{ok, Req2#http_req{resp_state=done, resp_headers=[], resp_body= <<>>}}.
 
 %% Misc API.
 
@@ -798,16 +798,33 @@ transport(#http_req{transport=Transport, socket=Socket}) ->
 %% Internal.
 
 -spec response(cowboy_http:status(), cowboy_http:headers(),
-	cowboy_http:headers(), cowboy_http:headers(), #http_req{}) -> ok.
-response(Status, Headers, RespHeaders, DefaultHeaders, #http_req{
-		socket=Socket, transport=Transport, version=Version}) ->
+	cowboy_http:headers(), cowboy_http:headers(), #http_req{})
+	-> {normal | hook, #http_req{}}.
+response(Status, Headers, RespHeaders, DefaultHeaders, Req=#http_req{
+		socket=Socket, transport=Transport, version=Version,
+		pid=ReqPid, onresponse=OnResponse}) ->
 	FullHeaders = response_merge_headers(Headers, RespHeaders, DefaultHeaders),
-	%% @todo 'onresponse' hook here.
-	HTTPVer = cowboy_http:version_to_binary(Version),
-	StatusLine = << HTTPVer/binary, " ", (status(Status))/binary, "\r\n" >>,
-	HeaderLines = [[Key, <<": ">>, Value, <<"\r\n">>]
-		|| {Key, Value} <- FullHeaders],
-	Transport:send(Socket, [StatusLine, HeaderLines, <<"\r\n">>]).
+	Req2 = case OnResponse of
+		undefined -> Req;
+		OnResponse -> OnResponse(Status, FullHeaders,
+			%% Don't call 'onresponse' from the hook itself.
+			Req#http_req{resp_headers=[], resp_body= <<>>,
+				onresponse=undefined})
+	end,
+	ReplyType = case Req2#http_req.resp_state of
+		waiting ->
+			HTTPVer = cowboy_http:version_to_binary(Version),
+			StatusLine = << HTTPVer/binary, " ",
+				(status(Status))/binary, "\r\n" >>,
+			HeaderLines = [[Key, <<": ">>, Value, <<"\r\n">>]
+				|| {Key, Value} <- FullHeaders],
+			Transport:send(Socket, [StatusLine, HeaderLines, <<"\r\n">>]),
+			ReqPid ! {?MODULE, resp_sent},
+			normal;
+		_ ->
+			hook
+	end,
+	{ReplyType, Req2}.
 
 -spec response_connection(cowboy_http:headers(), keepalive | close)
 	-> keepalive | close.

+ 39 - 6
test/http_SUITE.erl

@@ -44,6 +44,7 @@
 -export([nc_zero/1]).
 -export([onrequest/1]).
 -export([onrequest_reply/1]).
+-export([onresponse_reply/1]).
 -export([pipeline/1]).
 -export([rest_keepalive/1]).
 -export([rest_keepalive_post/1]).
@@ -66,7 +67,7 @@
 %% ct.
 
 all() ->
-	[{group, http}, {group, https}, {group, hooks}].
+	[{group, http}, {group, https}, {group, onrequest}, {group, onresponse}].
 
 groups() ->
 	Tests = [
@@ -108,9 +109,12 @@ groups() ->
 	[
 		{http, [], Tests},
 		{https, [], Tests},
-		{hooks, [], [
+		{onrequest, [], [
 			onrequest,
 			onrequest_reply
+		]},
+		{onresponse, [], [
+			onresponse_reply
 		]}
 	].
 
@@ -160,10 +164,10 @@ init_per_group(https, Config) ->
 	{ok, Client} = cowboy_client:init(Opts),
 	[{scheme, <<"https">>}, {port, Port}, {opts, Opts},
 		{transport, Transport}, {client, Client}|Config1];
-init_per_group(hooks, Config) ->
+init_per_group(onrequest, Config) ->
 	Port = 33082,
 	Transport = cowboy_tcp_transport,
-	{ok, _} = cowboy:start_listener(hooks, 100,
+	{ok, _} = cowboy:start_listener(onrequest, 100,
 		Transport, [{port, Port}],
 		cowboy_http_protocol, [
 			{dispatch, init_dispatch(Config)},
@@ -173,6 +177,20 @@ init_per_group(hooks, Config) ->
 		]),
 	{ok, Client} = cowboy_client:init([]),
 	[{scheme, <<"http">>}, {port, Port}, {opts, []},
+		{transport, Transport}, {client, Client}|Config];
+init_per_group(onresponse, Config) ->
+	Port = 33083,
+	Transport = cowboy_tcp_transport,
+	{ok, _} = cowboy:start_listener(onresponse, 100,
+		Transport, [{port, Port}],
+		cowboy_http_protocol, [
+			{dispatch, init_dispatch(Config)},
+			{max_keepalive, 50},
+			{onresponse, fun onresponse_hook/3},
+			{timeout, 500}
+		]),
+	{ok, Client} = cowboy_client:init([]),
+	[{scheme, <<"http">>}, {port, Port}, {opts, []},
 		{transport, Transport}, {client, Client}|Config].
 
 end_per_group(https, Config) ->
@@ -185,8 +203,8 @@ end_per_group(https, Config) ->
 end_per_group(http, Config) ->
 	cowboy:stop_listener(http),
 	end_static_dir(Config);
-end_per_group(hooks, _) ->
-	cowboy:stop_listener(hooks),
+end_per_group(Name, _) ->
+	cowboy:stop_listener(Name),
 	ok.
 
 %% Dispatch configuration.
@@ -570,6 +588,21 @@ onrequest_hook(Req) ->
 			Req3
 	end.
 
+onresponse_reply(Config) ->
+	Client = ?config(client, Config),
+	{ok, Client2} = cowboy_client:request(<<"GET">>,
+		build_url("/", Config), Client),
+	{ok, 777, Headers, Client3} = cowboy_client:response(Client2),
+	{<<"x-hook">>, <<"onresponse">>} = lists:keyfind(<<"x-hook">>, 1, Headers),
+	%% Make sure we don't get the body initially sent.
+	{error, closed} = cowboy_client:response_body(Client3).
+
+%% Hook for the above onresponse tests.
+onresponse_hook(_, Headers, Req) ->
+	{ok, Req2} = cowboy_http_req:reply(
+		<<"777 Lucky">>, [{<<"x-hook">>, <<"onresponse">>}|Headers], Req),
+	Req2.
+
 pipeline(Config) ->
 	Client = ?config(client, Config),
 	{ok, Client2} = cowboy_client:request(<<"GET">>,