Browse Source

Add an 'onrequest' hook for HTTP

This new protocol option is a fun.

It expects a single arg, the Req, and should only return a possibly
modified Req. This can be used for many things like URL rewriting,
access logging or listener-wide authentication.

If a reply is sent inside the hook, then Cowboy will consider the
request handled and will move on to the next one.
Loïc Hoguin 13 years ago
parent
commit
8e2cc3d7f1
2 changed files with 72 additions and 16 deletions
  1. 24 13
      src/cowboy_http_protocol.erl
  2. 48 3
      test/http_SUITE.erl

+ 24 - 13
src/cowboy_http_protocol.erl

@@ -47,6 +47,7 @@
 	transport :: module(),
 	transport :: module(),
 	dispatch :: cowboy_dispatcher:dispatch_rules(),
 	dispatch :: cowboy_dispatcher:dispatch_rules(),
 	handler :: {module(), any()},
 	handler :: {module(), any()},
+	onrequest :: undefined | fun((#http_req{}) -> #http_req{}),
 	urldecode :: {fun((binary(), T) -> binary()), T},
 	urldecode :: {fun((binary(), T) -> binary()), T},
 	req_empty_lines = 0 :: integer(),
 	req_empty_lines = 0 :: integer(),
 	max_empty_lines :: integer(),
 	max_empty_lines :: integer(),
@@ -77,6 +78,7 @@ init(ListenerPid, Socket, Transport, Opts) ->
 	MaxEmptyLines = proplists:get_value(max_empty_lines, Opts, 5),
 	MaxEmptyLines = proplists:get_value(max_empty_lines, Opts, 5),
 	MaxKeepalive = proplists:get_value(max_keepalive, Opts, infinity),
 	MaxKeepalive = proplists:get_value(max_keepalive, Opts, infinity),
 	MaxLineLength = proplists:get_value(max_line_length, Opts, 4096),
 	MaxLineLength = proplists:get_value(max_line_length, Opts, 4096),
+	OnRequest = proplists:get_value(onrequest, Opts),
 	Timeout = proplists:get_value(timeout, Opts, 5000),
 	Timeout = proplists:get_value(timeout, Opts, 5000),
 	URLDecDefault = {fun cowboy_http:urldecode/2, crash},
 	URLDecDefault = {fun cowboy_http:urldecode/2, crash},
 	URLDec = proplists:get_value(urldecode, Opts, URLDecDefault),
 	URLDec = proplists:get_value(urldecode, Opts, URLDecDefault),
@@ -84,7 +86,7 @@ init(ListenerPid, Socket, Transport, Opts) ->
 	wait_request(#state{listener=ListenerPid, socket=Socket, transport=Transport,
 	wait_request(#state{listener=ListenerPid, socket=Socket, transport=Transport,
 		dispatch=Dispatch, max_empty_lines=MaxEmptyLines,
 		dispatch=Dispatch, max_empty_lines=MaxEmptyLines,
 		max_keepalive=MaxKeepalive, max_line_length=MaxLineLength,
 		max_keepalive=MaxKeepalive, max_line_length=MaxLineLength,
-		timeout=Timeout, urldecode=URLDec}).
+		timeout=Timeout, onrequest=OnRequest, urldecode=URLDec}).
 
 
 %% @private
 %% @private
 -spec parse_request(#state{}) -> ok.
 -spec parse_request(#state{}) -> ok.
@@ -170,11 +172,11 @@ header({http_header, _I, 'Host', _R, RawHost}, Req=#http_req{
 	case catch cowboy_dispatcher:split_host(RawHost2) of
 	case catch cowboy_dispatcher:split_host(RawHost2) of
 		{Host, RawHost3, undefined} ->
 		{Host, RawHost3, undefined} ->
 			Port = default_port(Transport:name()),
 			Port = default_port(Transport:name()),
-			dispatch(fun parse_header/2, Req#http_req{
+			parse_header(Req#http_req{
 				host=Host, raw_host=RawHost3, port=Port,
 				host=Host, raw_host=RawHost3, port=Port,
 				headers=[{'Host', RawHost3}|Req#http_req.headers]}, State);
 				headers=[{'Host', RawHost3}|Req#http_req.headers]}, State);
 		{Host, RawHost3, Port} ->
 		{Host, RawHost3, Port} ->
-			dispatch(fun parse_header/2, Req#http_req{
+			parse_header(Req#http_req{
 				host=Host, raw_host=RawHost3, port=Port,
 				host=Host, raw_host=RawHost3, port=Port,
 				headers=[{'Host', RawHost3}|Req#http_req.headers]}, State);
 				headers=[{'Host', RawHost3}|Req#http_req.headers]}, State);
 		{'EXIT', _Reason} ->
 		{'EXIT', _Reason} ->
@@ -201,24 +203,33 @@ header(http_eoh, #http_req{version={1, 1}, host=undefined}, State) ->
 header(http_eoh, Req=#http_req{version={1, 0}, transport=Transport,
 header(http_eoh, Req=#http_req{version={1, 0}, transport=Transport,
 		host=undefined}, State=#state{buffer=Buffer}) ->
 		host=undefined}, State=#state{buffer=Buffer}) ->
 	Port = default_port(Transport:name()),
 	Port = default_port(Transport:name()),
-	dispatch(fun handler_init/2, Req#http_req{host=[], raw_host= <<>>,
+	onrequest(Req#http_req{host=[], raw_host= <<>>,
 		port=Port, buffer=Buffer}, State#state{buffer= <<>>});
 		port=Port, buffer=Buffer}, State#state{buffer= <<>>});
 header(http_eoh, Req, State=#state{buffer=Buffer}) ->
 header(http_eoh, Req, State=#state{buffer=Buffer}) ->
-	handler_init(Req#http_req{buffer=Buffer}, State#state{buffer= <<>>});
+	onrequest(Req#http_req{buffer=Buffer}, State#state{buffer= <<>>});
 header(_Any, _Req, State) ->
 header(_Any, _Req, State) ->
 	error_terminate(400, State).
 	error_terminate(400, State).
 
 
--spec dispatch(fun((#http_req{}, #state{}) -> ok),
-	#http_req{}, #state{}) -> ok.
-dispatch(Next, Req=#http_req{host=Host, path=Path},
+%% Call the global onrequest callback. The callback can send a reply,
+%% in which case we consider the request handled and move on to the next
+%% one. Note that since we haven't dispatched yet, we don't know the
+%% handler, host_info, path_info or bindings yet.
+-spec onrequest(#http_req{}, #state{}) -> ok.
+onrequest(Req, State=#state{onrequest=undefined}) ->
+	dispatch(Req, State);
+onrequest(Req, State=#state{onrequest=OnRequest}) ->
+	Req2 = OnRequest(Req),
+	case Req2#http_req.resp_state of
+		waiting -> dispatch(Req2, State);
+		_ -> next_request(Req2, State, ok)
+	end.
+
+-spec dispatch(#http_req{}, #state{}) -> ok.
+dispatch(Req=#http_req{host=Host, path=Path},
 		State=#state{dispatch=Dispatch}) ->
 		State=#state{dispatch=Dispatch}) ->
-	%% @todo We should allow a configurable chain of handlers here to
-	%%       allow things like url rewriting, site-wide authentication,
-	%%       optional dispatching, and more. It would default to what
-	%%       we are doing so far.
 	case cowboy_dispatcher:match(Host, Path, Dispatch) of
 	case cowboy_dispatcher:match(Host, Path, Dispatch) of
 		{ok, Handler, Opts, Binds, HostInfo, PathInfo} ->
 		{ok, Handler, Opts, Binds, HostInfo, PathInfo} ->
-			Next(Req#http_req{host_info=HostInfo, path_info=PathInfo,
+			handler_init(Req#http_req{host_info=HostInfo, path_info=PathInfo,
 				bindings=Binds}, State#state{handler={Handler, Opts}});
 				bindings=Binds}, State#state{handler={Handler, Opts}});
 		{error, notfound, host} ->
 		{error, notfound, host} ->
 			error_terminate(400, State);
 			error_terminate(400, State);

+ 48 - 3
test/http_SUITE.erl

@@ -31,11 +31,13 @@
 -export([http_10_hostless/1, http_10_chunkless/1]). %% misc.
 -export([http_10_hostless/1, http_10_chunkless/1]). %% misc.
 -export([rest_simple/1, rest_keepalive/1, rest_keepalive_post/1,
 -export([rest_simple/1, rest_keepalive/1, rest_keepalive_post/1,
 	rest_nodelete/1, rest_resource_etags/1]). %% rest.
 	rest_nodelete/1, rest_resource_etags/1]). %% rest.
+-export([onrequest/1, onrequest_reply/1]). %% hooks.
 
 
 %% ct.
 %% ct.
 
 
 all() ->
 all() ->
-	[{group, http}, {group, https}, {group, misc}, {group, rest}].
+	[{group, http}, {group, https}, {group, misc}, {group, rest},
+		{group, hooks}].
 
 
 groups() ->
 groups() ->
 	BaseTests = [http_200, http_404, handler_errors,
 	BaseTests = [http_200, http_404, handler_errors,
@@ -49,7 +51,8 @@ groups() ->
 	{https, [], BaseTests},
 	{https, [], BaseTests},
 	{misc, [], [http_10_hostless, http_10_chunkless]},
 	{misc, [], [http_10_hostless, http_10_chunkless]},
 	{rest, [], [rest_simple, rest_keepalive, rest_keepalive_post,
 	{rest, [], [rest_simple, rest_keepalive, rest_keepalive_post,
-		rest_nodelete, rest_resource_etags]}].
+		rest_nodelete, rest_resource_etags]},
+	{hooks, [], [onrequest, onrequest_reply]}].
 
 
 init_per_suite(Config) ->
 init_per_suite(Config) ->
 	application:start(inets),
 	application:start(inets),
@@ -104,7 +107,16 @@ init_per_group(rest, Config) ->
 			{[<<"nodelete">>], rest_nodelete_resource, []},
 			{[<<"nodelete">>], rest_nodelete_resource, []},
 			{[<<"resetags">>], rest_resource_etags, []}
 			{[<<"resetags">>], rest_resource_etags, []}
 	]}]}]),
 	]}]}]),
-	[{scheme, "http"},{port, Port}|Config].
+	[{scheme, "http"},{port, Port}|Config];
+init_per_group(hooks, Config) ->
+	Port = 33084,
+	{ok, _} = cowboy:start_listener(hooks, 100,
+		cowboy_tcp_transport, [{port, Port}],
+		cowboy_http_protocol, [
+			{dispatch, init_http_dispatch(Config)},
+			{onrequest, fun onrequest_hook/1}
+		]),
+	[{scheme, "http"}, {port, Port}|Config].
 
 
 end_per_group(https, Config) ->
 end_per_group(https, Config) ->
 	cowboy:stop_listener(https),
 	cowboy:stop_listener(https),
@@ -691,3 +703,36 @@ rest_resource_etags(Config) ->
 		"Host: localhost\r\n", "Connection: close\r\n",
 		"Host: localhost\r\n", "Connection: close\r\n",
 		"If-None-Match: \"etag-header-value\"\r\n", "\r\n"], Config)
 		"If-None-Match: \"etag-header-value\"\r\n", "\r\n"], Config)
 	end().
 	end().
+
+onrequest(Config) ->
+	{port, Port} = lists:keyfind(port, 1, Config),
+	{ok, Socket} = gen_tcp:connect("localhost", Port,
+		[binary, {active, false}, {packet, raw}]),
+	ok = gen_tcp:send(Socket, "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"),
+	{ok, Data} = gen_tcp:recv(Socket, 0, 6000),
+	{_, _} = binary:match(Data, <<"Server: Serenity">>),
+	{_, _} = binary:match(Data, <<"http_handler">>),
+	gen_tcp:close(Socket).
+
+onrequest_reply(Config) ->
+	{port, Port} = lists:keyfind(port, 1, Config),
+	{ok, Socket} = gen_tcp:connect("localhost", Port,
+		[binary, {active, false}, {packet, raw}]),
+	ok = gen_tcp:send(Socket, "GET /?reply=1 HTTP/1.1\r\nHost: localhost\r\n\r\n"),
+	{ok, Data} = gen_tcp:recv(Socket, 0, 6000),
+	{_, _} = binary:match(Data, <<"Server: Cowboy">>),
+	nomatch = binary:match(Data, <<"http_handler">>),
+	{_, _} = binary:match(Data, <<"replied!">>),
+	gen_tcp:close(Socket).
+
+onrequest_hook(Req) ->
+	case cowboy_http_req:qs_val(<<"reply">>, Req) of
+		{undefined, Req2} ->
+			{ok, Req3} = cowboy_http_req:set_resp_header(
+				'Server', <<"Serenity">>, Req2),
+			Req3;
+		{_, Req2} ->
+			{ok, Req3} = cowboy_http_req:reply(
+				200, [], <<"replied!">>, Req2),
+			Req3
+	end.