Browse Source

Introduce the req_filter Websocket option

This option allows customizing the compacting of the Req object
when using Websocket. By default it will keep most public fields
excluding headers of course, since those can be large.
Loïc Hoguin 8 years ago
parent
commit
5f421f93bc

+ 9 - 1
doc/src/manual/cowboy_websocket.asciidoc

@@ -134,7 +134,8 @@ timeout::
 ----
 opts() :: #{
     compress => boolean(),
-    idle_timeout => timeout()
+    idle_timeout => timeout(),
+    req_filter => fun((cowboy_req:req()) -> map())
 }
 ----
 
@@ -162,6 +163,13 @@ idle_timeout (60000)::
     connection open without receiving anything from
     the client.
 
+req_filter::
+    A function applied to the Req to compact it and
+    only keep required information. The Req is only
+    given back in the `terminate/3` callback. By default
+    it keeps the method, version, URI components and peer
+    information.
+
 == Changelog
 
 * *2.0*: The Req object is no longer passed to Websocket callbacks.

+ 11 - 7
src/cowboy_websocket.erl

@@ -47,14 +47,13 @@
 -callback websocket_info(any(), State)
 	-> call_result(State) when State::any().
 
-%% @todo OK this I am not sure what to do about it. We don't have a Req anymore.
-%% We probably should have a websocket_terminate instead.
 -callback terminate(any(), cowboy_req:req(), any()) -> ok.
 -optional_callbacks([terminate/3]).
 
 -type opts() :: #{
+	compress => boolean(),
 	idle_timeout => timeout(),
-	compress => boolean()
+	req_filter => fun((cowboy_req:req()) -> map())
 }.
 -export_type([opts/0]).
 
@@ -71,7 +70,8 @@
 	frag_state = undefined :: cow_ws:frag_state(),
 	frag_buffer = <<>> :: binary(),
 	utf8_state = 0 :: cow_ws:utf8_state(),
-	extensions = #{} :: map()
+	extensions = #{} :: map(),
+	req = #{} :: map()
 }).
 
 %% Stream process.
@@ -90,7 +90,11 @@ upgrade(Req, Env, Handler, HandlerState) ->
 upgrade(Req0, Env, Handler, HandlerState, Opts) ->
 	Timeout = maps:get(idle_timeout, Opts, 60000),
 	Compress = maps:get(compress, Opts, false),
-	State0 = #state{handler=Handler, timeout=Timeout, compress=Compress},
+	FilteredReq = case maps:get(req_filter, Opts, undefined) of
+		undefined -> maps:with([method, version, scheme, host, port, path, qs, peer], Req0);
+		FilterFun -> FilterFun(Req0)
+	end,
+	State0 = #state{handler=Handler, timeout=Timeout, compress=Compress, req=FilteredReq},
 	try websocket_upgrade(State0, Req0) of
 		{ok, State, Req} ->
 			websocket_handshake(State, Req, HandlerState, Env)
@@ -417,5 +421,5 @@ terminate(State, HandlerState, Reason) ->
 	handler_terminate(State, HandlerState, Reason),
 	exit(normal).
 
-handler_terminate(#state{handler=Handler}, HandlerState, Reason) ->
-	cowboy_handler:terminate(Reason, undefined, HandlerState, Handler).
+handler_terminate(#state{handler=Handler, req=Req}, HandlerState, Reason) ->
+	cowboy_handler:terminate(Reason, Req, HandlerState, Handler).

+ 34 - 0
test/handlers/ws_terminate_h.erl

@@ -0,0 +1,34 @@
+%% This module sends a message with terminate arguments to the test case process.
+
+-module(ws_terminate_h).
+-behavior(cowboy_websocket).
+
+-export([init/2]).
+-export([websocket_init/1]).
+-export([websocket_handle/2]).
+-export([websocket_info/2]).
+-export([terminate/3]).
+
+-record(state, {
+	pid
+}).
+
+init(Req, _) ->
+	Pid = list_to_pid(binary_to_list(cowboy_req:header(<<"x-test-pid">>, Req))),
+	Opts = case cowboy_req:qs(Req) of
+		<<"req_filter">> -> #{req_filter => fun(_) -> filtered end};
+		_ -> #{}
+	end,
+	{cowboy_websocket, Req, #state{pid=Pid}, Opts}.
+
+websocket_init(State) ->
+	{ok, State}.
+
+websocket_handle(_, State) ->
+	{ok, State}.
+
+websocket_info(_, State) ->
+	{ok, State}.
+
+terminate(Reason, Req, #state{pid=Pid}) ->
+	Pid ! {terminate, Reason, Req}.

+ 34 - 0
test/ws_SUITE.erl

@@ -79,6 +79,7 @@ init_dispatch() ->
 					{text, <<"won't be received">>}]}
 			]},
 			{"/ws_subprotocol", ws_subprotocol, []},
+			{"/terminate", ws_terminate_h, []},
 			{"/ws_timeout_hibernate", ws_timeout_hibernate, []},
 			{"/ws_timeout_cancel", ws_timeout_cancel, []}
 		]}
@@ -355,6 +356,39 @@ ws_subprotocol(Config) ->
 	{_, "foo"} = lists:keyfind("sec-websocket-protocol", 1, Headers),
 	ok.
 
+ws_terminate(Config) ->
+	doc("The Req object is kept in a more compact form by default."),
+	{ok, Socket, _} = do_handshake("/terminate",
+		"x-test-pid: " ++ pid_to_list(self()) ++ "\r\n", Config),
+	%% Send a close frame.
+	ok = gen_tcp:send(Socket, << 1:1, 0:3, 8:4, 1:1, 0:7, 0:32 >>),
+	{ok, << 1:1, 0:3, 8:4, 0:8 >>} = gen_tcp:recv(Socket, 0, 6000),
+	{error, closed} = gen_tcp:recv(Socket, 0, 6000),
+	%% Confirm terminate/3 was called with a compacted Req.
+	receive {terminate, _, Req} ->
+		true = maps:is_key(path, Req),
+		false = maps:is_key(headers, Req),
+		ok
+	after 1000 ->
+		error(timeout)
+	end.
+
+ws_terminate_fun(Config) ->
+	doc("A function can be given to filter the Req object."),
+	{ok, Socket, _} = do_handshake("/terminate?req_filter",
+		"x-test-pid: " ++ pid_to_list(self()) ++ "\r\n", Config),
+	%% Send a close frame.
+	ok = gen_tcp:send(Socket, << 1:1, 0:3, 8:4, 1:1, 0:7, 0:32 >>),
+	{ok, << 1:1, 0:3, 8:4, 0:8 >>} = gen_tcp:recv(Socket, 0, 6000),
+	{error, closed} = gen_tcp:recv(Socket, 0, 6000),
+	%% Confirm terminate/3 was called with a compacted Req.
+	receive {terminate, _, Req} ->
+		filtered = Req,
+		ok
+	after 1000 ->
+		error(timeout)
+	end.
+
 ws_text_fragments(Config) ->
 	doc("Client sends fragmented text frames."),
 	{ok, Socket, _} = do_handshake("/ws_echo", Config),