Browse Source

Fix websocket unmasking when compression is enabled

The unmasking logic was based on the length of inflated data instead
of the length of the deflated data. This meant data would get corrupted
when we receive a websocket frame split across multiple TCP packets.
Ali Sabil 12 years ago
parent
commit
a3b9438d16
1 changed files with 41 additions and 28 deletions
  1. 41 28
      src/cowboy_websocket.erl

+ 41 - 28
src/cowboy_websocket.erl

@@ -329,45 +329,49 @@ websocket_data(State, Req, HandlerState, Data) ->
 websocket_data(State=#state{frag_state=undefined}, Req, HandlerState,
 websocket_data(State=#state{frag_state=undefined}, Req, HandlerState,
 		Opcode, Len, MaskKey, Data, Rsv, 0) ->
 		Opcode, Len, MaskKey, Data, Rsv, 0) ->
 	websocket_payload(State#state{frag_state={nofin, Opcode, <<>>}},
 	websocket_payload(State#state{frag_state={nofin, Opcode, <<>>}},
-		Req, HandlerState, 0, Len, MaskKey, <<>>, Data, Rsv);
+		Req, HandlerState, 0, Len, MaskKey, <<>>, 0, Data, Rsv);
 %% Subsequent frame fragments.
 %% Subsequent frame fragments.
 websocket_data(State=#state{frag_state={nofin, _, _}}, Req, HandlerState,
 websocket_data(State=#state{frag_state={nofin, _, _}}, Req, HandlerState,
 		0, Len, MaskKey, Data, Rsv, 0) ->
 		0, Len, MaskKey, Data, Rsv, 0) ->
 	websocket_payload(State, Req, HandlerState,
 	websocket_payload(State, Req, HandlerState,
-		0, Len, MaskKey, <<>>, Data, Rsv);
+		0, Len, MaskKey, <<>>, 0, Data, Rsv);
 %% Final frame fragment.
 %% Final frame fragment.
 websocket_data(State=#state{frag_state={nofin, Opcode, SoFar}},
 websocket_data(State=#state{frag_state={nofin, Opcode, SoFar}},
 		Req, HandlerState, 0, Len, MaskKey, Data, Rsv, 1) ->
 		Req, HandlerState, 0, Len, MaskKey, Data, Rsv, 1) ->
 	websocket_payload(State#state{frag_state={fin, Opcode, SoFar}},
 	websocket_payload(State#state{frag_state={fin, Opcode, SoFar}},
-		Req, HandlerState, 0, Len, MaskKey, <<>>, Data, Rsv);
+		Req, HandlerState, 0, Len, MaskKey, <<>>, 0, Data, Rsv);
 %% Unfragmented frame.
 %% Unfragmented frame.
 websocket_data(State, Req, HandlerState, Opcode, Len, MaskKey, Data, Rsv, 1) ->
 websocket_data(State, Req, HandlerState, Opcode, Len, MaskKey, Data, Rsv, 1) ->
 	websocket_payload(State, Req, HandlerState,
 	websocket_payload(State, Req, HandlerState,
-		Opcode, Len, MaskKey, <<>>, Data, Rsv).
+		Opcode, Len, MaskKey, <<>>, 0, Data, Rsv).
 
 
 -spec websocket_payload(#state{}, Req, any(),
 -spec websocket_payload(#state{}, Req, any(),
-	opcode(), non_neg_integer(), mask_key(), binary(), binary(), rsv())
+	opcode(), non_neg_integer(), mask_key(), binary(), non_neg_integer(),
+	binary(), rsv())
 	-> {ok, Req, cowboy_middleware:env()}
 	-> {ok, Req, cowboy_middleware:env()}
 	| {suspend, module(), atom(), [any()]}
 	| {suspend, module(), atom(), [any()]}
 	when Req::cowboy_req:req().
 	when Req::cowboy_req:req().
 %% Close control frames with a payload MUST contain a valid close code.
 %% Close control frames with a payload MUST contain a valid close code.
 websocket_payload(State, Req, HandlerState,
 websocket_payload(State, Req, HandlerState,
-		Opcode=8, Len, MaskKey, <<>>, << MaskedCode:2/binary, Rest/bits >>, Rsv) ->
+		Opcode=8, Len, MaskKey, <<>>, 0,
+		<< MaskedCode:2/binary, Rest/bits >>, Rsv) ->
 	Unmasked = << Code:16 >> = websocket_unmask(MaskedCode, MaskKey, <<>>),
 	Unmasked = << Code:16 >> = websocket_unmask(MaskedCode, MaskKey, <<>>),
 	if	Code < 1000; Code =:= 1004; Code =:= 1005; Code =:= 1006;
 	if	Code < 1000; Code =:= 1004; Code =:= 1005; Code =:= 1006;
 				(Code > 1011) and (Code < 3000); Code > 4999 ->
 				(Code > 1011) and (Code < 3000); Code > 4999 ->
 			websocket_close(State, Req, HandlerState, {error, badframe});
 			websocket_close(State, Req, HandlerState, {error, badframe});
 		true ->
 		true ->
 			websocket_payload(State, Req, HandlerState,
 			websocket_payload(State, Req, HandlerState,
-				Opcode, Len - 2, MaskKey, Unmasked, Rest, Rsv)
+				Opcode, Len - 2, MaskKey, Unmasked, byte_size(MaskedCode),
+				Rest, Rsv)
 	end;
 	end;
 %% Text frames and close control frames MUST have a payload that is valid UTF-8.
 %% Text frames and close control frames MUST have a payload that is valid UTF-8.
 websocket_payload(State=#state{utf8_state=Incomplete},
 websocket_payload(State=#state{utf8_state=Incomplete},
-		Req, HandlerState, Opcode, Len, MaskKey, Unmasked, Data, Rsv)
+		Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen,
+		Data, Rsv)
 		when (byte_size(Data) < Len) andalso ((Opcode =:= 1) orelse
 		when (byte_size(Data) < Len) andalso ((Opcode =:= 1) orelse
 			((Opcode =:= 8) andalso (Unmasked =/= <<>>))) ->
 			((Opcode =:= 8) andalso (Unmasked =/= <<>>))) ->
 	Unmasked2 = websocket_unmask(Data,
 	Unmasked2 = websocket_unmask(Data,
-		rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>),
+		rotate_mask_key(MaskKey, UnmaskedLen), <<>>),
 	{Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State),
 	{Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State),
 	case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of
 	case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of
 		false ->
 		false ->
@@ -375,14 +379,16 @@ websocket_payload(State=#state{utf8_state=Incomplete},
 		Utf8State ->
 		Utf8State ->
 			websocket_payload_loop(State2#state{utf8_state=Utf8State},
 			websocket_payload_loop(State2#state{utf8_state=Utf8State},
 				Req, HandlerState, Opcode, Len - byte_size(Data), MaskKey,
 				Req, HandlerState, Opcode, Len - byte_size(Data), MaskKey,
-				<< Unmasked/binary, Unmasked3/binary >>, Rsv)
+				<< Unmasked/binary, Unmasked3/binary >>,
+				UnmaskedLen + byte_size(Data), Rsv)
 	end;
 	end;
 websocket_payload(State=#state{utf8_state=Incomplete},
 websocket_payload(State=#state{utf8_state=Incomplete},
-		Req, HandlerState, Opcode, Len, MaskKey, Unmasked, Data, Rsv)
+		Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen,
+		Data, Rsv)
 		when Opcode =:= 1; (Opcode =:= 8) and (Unmasked =/= <<>>) ->
 		when Opcode =:= 1; (Opcode =:= 8) and (Unmasked =/= <<>>) ->
 	<< End:Len/binary, Rest/bits >> = Data,
 	<< End:Len/binary, Rest/bits >> = Data,
 	Unmasked2 = websocket_unmask(End,
 	Unmasked2 = websocket_unmask(End,
-		rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>),
+		rotate_mask_key(MaskKey, UnmaskedLen), <<>>),
 	{Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State),
 	{Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State),
 	case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of
 	case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of
 		<<>> ->
 		<<>> ->
@@ -394,10 +400,11 @@ websocket_payload(State=#state{utf8_state=Incomplete},
 	end;
 	end;
 %% Fragmented text frames may cut payload in the middle of UTF-8 codepoints.
 %% Fragmented text frames may cut payload in the middle of UTF-8 codepoints.
 websocket_payload(State=#state{frag_state={_, 1, _}, utf8_state=Incomplete},
 websocket_payload(State=#state{frag_state={_, 1, _}, utf8_state=Incomplete},
-		Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, Data, Rsv)
+		Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, UnmaskedLen,
+		Data, Rsv)
 		when byte_size(Data) < Len ->
 		when byte_size(Data) < Len ->
 	Unmasked2 = websocket_unmask(Data,
 	Unmasked2 = websocket_unmask(Data,
-		rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>),
+		rotate_mask_key(MaskKey, UnmaskedLen), <<>>),
 	{Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State),
 	{Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State),
 	case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of
 	case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of
 		false ->
 		false ->
@@ -405,13 +412,15 @@ websocket_payload(State=#state{frag_state={_, 1, _}, utf8_state=Incomplete},
 		Utf8State ->
 		Utf8State ->
 			websocket_payload_loop(State2#state{utf8_state=Utf8State},
 			websocket_payload_loop(State2#state{utf8_state=Utf8State},
 				Req, HandlerState, Opcode, Len - byte_size(Data), MaskKey,
 				Req, HandlerState, Opcode, Len - byte_size(Data), MaskKey,
-				<< Unmasked/binary, Unmasked3/binary >>, Rsv)
+				<< Unmasked/binary, Unmasked3/binary >>,
+				UnmaskedLen + byte_size(Data), Rsv)
 	end;
 	end;
 websocket_payload(State=#state{frag_state={Fin, 1, _}, utf8_state=Incomplete},
 websocket_payload(State=#state{frag_state={Fin, 1, _}, utf8_state=Incomplete},
-		Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, Data, Rsv) ->
+		Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, UnmaskedLen,
+		Data, Rsv) ->
 	<< End:Len/binary, Rest/bits >> = Data,
 	<< End:Len/binary, Rest/bits >> = Data,
 	Unmasked2 = websocket_unmask(End,
 	Unmasked2 = websocket_unmask(End,
-		rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>),
+		rotate_mask_key(MaskKey, UnmaskedLen), <<>>),
 	{Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State),
 	{Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State),
 	case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of
 	case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of
 		<<>> ->
 		<<>> ->
@@ -427,20 +436,23 @@ websocket_payload(State=#state{frag_state={Fin, 1, _}, utf8_state=Incomplete},
 	end;
 	end;
 %% Other frames have a binary payload.
 %% Other frames have a binary payload.
 websocket_payload(State, Req, HandlerState,
 websocket_payload(State, Req, HandlerState,
-		Opcode, Len, MaskKey, Unmasked, Data, Rsv)
+		Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv)
 		when byte_size(Data) < Len ->
 		when byte_size(Data) < Len ->
 	Unmasked2 = websocket_unmask(Data,
 	Unmasked2 = websocket_unmask(Data,
-		rotate_mask_key(MaskKey, byte_size(Unmasked)), Unmasked),
+		rotate_mask_key(MaskKey, UnmaskedLen), <<>>),
 	{Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State),
 	{Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State),
 	websocket_payload_loop(State2, Req, HandlerState,
 	websocket_payload_loop(State2, Req, HandlerState,
-		Opcode, Len - byte_size(Data), MaskKey, Unmasked3, Rsv);
+		Opcode, Len - byte_size(Data), MaskKey,
+		<< Unmasked/binary, Unmasked3/binary >>, UnmaskedLen + byte_size(Data),
+		Rsv);
 websocket_payload(State, Req, HandlerState,
 websocket_payload(State, Req, HandlerState,
-		Opcode, Len, MaskKey, Unmasked, Data, Rsv) ->
+		Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv) ->
 	<< End:Len/binary, Rest/bits >> = Data,
 	<< End:Len/binary, Rest/bits >> = Data,
 	Unmasked2 = websocket_unmask(End,
 	Unmasked2 = websocket_unmask(End,
-		rotate_mask_key(MaskKey, byte_size(Unmasked)), Unmasked),
+		rotate_mask_key(MaskKey, UnmaskedLen), <<>>),
 	{Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State),
 	{Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State),
-	websocket_dispatch(State2, Req, HandlerState, Rest, Opcode, Unmasked3).
+	websocket_dispatch(State2, Req, HandlerState, Rest, Opcode,
+		<< Unmasked/binary, Unmasked3/binary >>).
 
 
 -spec websocket_inflate_frame(binary(), rsv(), boolean(), #state{}) ->
 -spec websocket_inflate_frame(binary(), rsv(), boolean(), #state{}) ->
 		{binary(), #state{}}.
 		{binary(), #state{}}.
@@ -513,19 +525,20 @@ is_utf8(_) ->
 	false.
 	false.
 
 
 -spec websocket_payload_loop(#state{}, Req, any(),
 -spec websocket_payload_loop(#state{}, Req, any(),
-		opcode(), non_neg_integer(), mask_key(), binary(), rsv())
+		opcode(), non_neg_integer(), mask_key(), binary(),
+		non_neg_integer(), rsv())
 	-> {ok, Req, cowboy_middleware:env()}
 	-> {ok, Req, cowboy_middleware:env()}
 	| {suspend, module(), atom(), [any()]}
 	| {suspend, module(), atom(), [any()]}
 	when Req::cowboy_req:req().
 	when Req::cowboy_req:req().
 websocket_payload_loop(State=#state{socket=Socket, transport=Transport,
 websocket_payload_loop(State=#state{socket=Socket, transport=Transport,
 		messages={OK, Closed, Error}, timeout_ref=TRef},
 		messages={OK, Closed, Error}, timeout_ref=TRef},
-		Req, HandlerState, Opcode, Len, MaskKey, Unmasked, Rsv) ->
+		Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Rsv) ->
 	Transport:setopts(Socket, [{active, once}]),
 	Transport:setopts(Socket, [{active, once}]),
 	receive
 	receive
 		{OK, Socket, Data} ->
 		{OK, Socket, Data} ->
 			State2 = handler_loop_timeout(State),
 			State2 = handler_loop_timeout(State),
 			websocket_payload(State2, Req, HandlerState,
 			websocket_payload(State2, Req, HandlerState,
-				Opcode, Len, MaskKey, Unmasked, Data, Rsv);
+				Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv);
 		{Closed, Socket} ->
 		{Closed, Socket} ->
 			handler_terminate(State, Req, HandlerState, {error, closed});
 			handler_terminate(State, Req, HandlerState, {error, closed});
 		{Error, Socket, Reason} ->
 		{Error, Socket, Reason} ->
@@ -534,13 +547,13 @@ websocket_payload_loop(State=#state{socket=Socket, transport=Transport,
 			websocket_close(State, Req, HandlerState, {normal, timeout});
 			websocket_close(State, Req, HandlerState, {normal, timeout});
 		{timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) ->
 		{timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) ->
 			websocket_payload_loop(State, Req, HandlerState,
 			websocket_payload_loop(State, Req, HandlerState,
-				Opcode, Len, MaskKey, Unmasked, Rsv);
+				Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Rsv);
 		Message ->
 		Message ->
 			handler_call(State, Req, HandlerState,
 			handler_call(State, Req, HandlerState,
 				<<>>, websocket_info, Message,
 				<<>>, websocket_info, Message,
 				fun (State2, Req2, HandlerState2, _) ->
 				fun (State2, Req2, HandlerState2, _) ->
 					websocket_payload_loop(State2, Req2, HandlerState2,
 					websocket_payload_loop(State2, Req2, HandlerState2,
-						Opcode, Len, MaskKey, Unmasked, Rsv)
+						Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Rsv)
 				end)
 				end)
 	end.
 	end.