Browse Source

Send a 426 when Websocket is required and client didn't upgrade

Loïc Hoguin 7 years ago
parent
commit
c2b813684e
2 changed files with 46 additions and 13 deletions
  1. 25 12
      src/cowboy_websocket.erl
  2. 21 1
      test/rfc7231_SUITE.erl

+ 25 - 12
src/cowboy_websocket.erl

@@ -97,7 +97,12 @@ upgrade(Req0, Env, Handler, HandlerState, Opts) ->
 	State0 = #state{handler=Handler, timeout=Timeout, compress=Compress, req=FilteredReq},
 	try websocket_upgrade(State0, Req0) of
 		{ok, State, Req} ->
-			websocket_handshake(State, Req, HandlerState, Env)
+			websocket_handshake(State, Req, HandlerState, Env);
+		{error, upgrade_required} ->
+			{ok, cowboy_req:reply(426, #{
+				<<"connection">> => <<"upgrade">>,
+				<<"upgrade">> => <<"websocket">>
+			}, Req0), Env}
 	catch _:_ ->
 		%% @todo Probably log something here?
 		%% @todo Test that we can have 2 /ws 400 status code in a row on the same connection.
@@ -108,17 +113,25 @@ upgrade(Req0, Env, Handler, HandlerState, Opts) ->
 -spec websocket_upgrade(#state{}, Req)
 	-> {ok, #state{}, Req} when Req::cowboy_req:req().
 websocket_upgrade(State, Req) ->
-	ConnTokens = cowboy_req:parse_header(<<"connection">>, Req),
-	true = lists:member(<<"upgrade">>, ConnTokens),
-	%% @todo Should probably send a 426 if the Upgrade header is missing.
-	[<<"websocket">>] = cowboy_req:parse_header(<<"upgrade">>, Req),
-	Version = cowboy_req:header(<<"sec-websocket-version">>, Req),
-	IntVersion = binary_to_integer(Version),
-	true = (IntVersion =:= 7) orelse (IntVersion =:= 8)
-		orelse (IntVersion =:= 13),
-	Key = cowboy_req:header(<<"sec-websocket-key">>, Req),
-	false = Key =:= undefined,
-	websocket_extensions(State#state{key=Key}, Req#{websocket_version => IntVersion}).
+	ConnTokens = cowboy_req:parse_header(<<"connection">>, Req, []),
+	case lists:member(<<"upgrade">>, ConnTokens) of
+		false ->
+			{error, upgrade_required};
+		true ->
+			UpgradeTokens = cowboy_req:parse_header(<<"upgrade">>, Req, []),
+			case lists:member(<<"websocket">>, UpgradeTokens) of
+				false ->
+					{error, upgrade_required};
+				true ->
+					Version = cowboy_req:header(<<"sec-websocket-version">>, Req),
+					IntVersion = binary_to_integer(Version),
+					true = (IntVersion =:= 7) orelse (IntVersion =:= 8)
+						orelse (IntVersion =:= 13),
+					Key = cowboy_req:header(<<"sec-websocket-key">>, Req),
+					false = Key =:= undefined,
+					websocket_extensions(State#state{key=Key}, Req#{websocket_version => IntVersion})
+			end
+	end.
 
 -spec websocket_extensions(#state{}, Req)
 	-> {ok, #state{}, Req} when Req::cowboy_req:req().

+ 21 - 1
test/rfc7231_SUITE.erl

@@ -41,7 +41,8 @@ init_dispatch(_) ->
 		{"*", asterisk_h, []},
 		{"/", hello_h, []},
 		{"/echo/:key", echo_h, []},
-		{"/resp/:key[/:arg]", resp_h, []}
+		{"/resp/:key[/:arg]", resp_h, []},
+		{"/ws", ws_init_h, []}
 	]}]).
 
 %% @todo The documentation should list what methods, headers and status codes
@@ -514,6 +515,25 @@ status_code_426(Config) ->
 	{response, _, 426, _} = gun:await(ConnPid, Ref),
 	ok.
 
+status_code_426_upgrade_header(Config) ->
+	case config(protocol, Config) of
+		http ->
+			do_status_code_426_upgrade_header(Config);
+		http2 ->
+			doc("HTTP/2 does not support the HTTP/1.1 Upgrade mechanism.")
+	end.
+
+do_status_code_426_upgrade_header(Config) ->
+	doc("A 426 response must include a upgrade header. (RFC7231 6.5.15)"),
+	ConnPid = gun_open(Config),
+	Ref = gun:get(ConnPid, "/ws?ok", [
+		{<<"accept-encoding">>, <<"gzip">>}
+	]),
+	{response, _, 426, Headers} = gun:await(ConnPid, Ref),
+	{_, <<"upgrade">>} = lists:keyfind(<<"connection">>, 1, Headers),
+	{_, <<"websocket">>} = lists:keyfind(<<"upgrade">>, 1, Headers),
+	ok.
+
 status_code_500(Config) ->
 	doc("The 500 Internal Server Error status code can be sent. (RFC7231 6.6.1)"),
 	ConnPid = gun_open(Config),