Browse Source

Add missing client functionality to Websocket code

Loïc Hoguin 10 years ago
parent
commit
a8db5d9f7a
1 changed files with 156 additions and 34 deletions
  1. 156 34
      src/cow_ws.erl

+ 156 - 34
src/cow_ws.erl

@@ -14,33 +14,59 @@
 
 -module(cow_ws).
 
+-export([key/0]).
+-export([encode_key/1]).
+
 -export([negotiate_permessage_deflate/3]).
 -export([negotiate_x_webkit_deflate_frame/3]).
 
+-export([validate_permessage_deflate/3]).
+
 -export([parse_header/3]).
 -export([parse_payload/9]).
 -export([make_frame/4]).
+
 -export([frame/2]).
+-export([masked_frame/2]).
 
 -type close_code() :: 1000..1003 | 1006..1011 | 3000..4999.
 -export_type([close_code/0]).
 
+-type extensions() :: map().
+-export_type([extensions/0]).
+
 -type frag_state() :: undefined | {fin | nofin, text | binary, rsv()}.
 -export_type([frag_state/0]).
 
 -type frame() :: close | ping | pong
 	| {text | binary | close | ping | pong, iodata()}
 	| {close, close_code(), iodata()}
-	| {fragment, fin | nofin, text | binary, iodata()}.
+	| {fragment, fin | nofin, text | binary | continuation, iodata()}.
 -export_type([frame/0]).
 
--type utf8_state() :: 0..8.
--export_type([utf8_state/0]).
-
--type extensions() :: map().
 -type frame_type() :: fragment | text | binary | close | ping | pong.
+-export_type([frame_type/0]).
+
 -type mask_key() :: undefined | 0..16#ffffffff.
+-export_type([mask_key/0]).
+
 -type rsv() :: <<_:3>>.
+-export_type([rsv/0]).
+
+-type utf8_state() :: 0..8.
+-export_type([utf8_state/0]).
+
+%% @doc Generate a key for the Websocket handshake request.
+
+-spec key() -> binary().
+key() ->
+	base64:encode(crypto:rand_bytes(16)).
+
+%% @doc Encode the key into the accept value for the Websocket handshake response.
+
+-spec encode_key(binary()) -> binary().
+encode_key(Key) ->
+	base64:encode(crypto:hash(sha, [Key, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"])).
 
 %% @doc Negotiate the permessage-deflate extension.
 
@@ -54,7 +80,7 @@ negotiate_permessage_deflate(Params, Extensions, Opts) ->
 			ignore;
 		Params2 ->
 			%% @todo Might want to make these configurable defaults.
-			case parse_permessage_deflate_params(Params2, 15, takeover, 15, takeover, []) of
+			case parse_request_permessage_deflate_params(Params2, 15, takeover, 15, takeover, []) of
 				ignore ->
 					ignore;
 				{ClientWindowBits, ClientTakeOver, ServerWindowBits, ServerTakeOver, RespParams} ->
@@ -68,33 +94,33 @@ negotiate_permessage_deflate(Params, Extensions, Opts) ->
 			end
 	end.
 
-parse_permessage_deflate_params([], CB, CTO, SB, STO, RespParams) ->
+parse_request_permessage_deflate_params([], CB, CTO, SB, STO, RespParams) ->
 	{CB, CTO, SB, STO, RespParams};
-parse_permessage_deflate_params([<<"client_max_window_bits">>|Tail], CB, CTO, SB, STO, RespParams) ->
-	parse_permessage_deflate_params(Tail, CB, CTO, SB, STO,
+parse_request_permessage_deflate_params([<<"client_max_window_bits">>|Tail], CB, CTO, SB, STO, RespParams) ->
+	parse_request_permessage_deflate_params(Tail, CB, CTO, SB, STO,
 		[<<"; ">>, <<"client_max_window_bits=">>, integer_to_binary(CB)|RespParams]);
-parse_permessage_deflate_params([{<<"client_max_window_bits">>, Max}|Tail], _, CTO, SB, STO, RespParams) ->
+parse_request_permessage_deflate_params([{<<"client_max_window_bits">>, Max}|Tail], _, CTO, SB, STO, RespParams) ->
 	case parse_max_window_bits(Max) of
 		error ->
 			ignore;
 		CB ->
-			parse_permessage_deflate_params(Tail, CB, CTO, SB, STO,
+			parse_request_permessage_deflate_params(Tail, CB, CTO, SB, STO,
 				[<<"; ">>, <<"client_max_window_bits=">>, Max|RespParams])
 	end;
-parse_permessage_deflate_params([<<"client_no_context_takeover">>|Tail], CB, _, SB, STO, RespParams) ->
-	parse_permessage_deflate_params(Tail, CB, no_takeover, SB, STO, [<<"; ">>, <<"client_no_context_takeover">>|RespParams]);
-parse_permessage_deflate_params([{<<"server_max_window_bits">>, Max}|Tail], CB, CTO, _, STO, RespParams) ->
+parse_request_permessage_deflate_params([<<"client_no_context_takeover">>|Tail], CB, _, SB, STO, RespParams) ->
+	parse_request_permessage_deflate_params(Tail, CB, no_takeover, SB, STO, [<<"; ">>, <<"client_no_context_takeover">>|RespParams]);
+parse_request_permessage_deflate_params([{<<"server_max_window_bits">>, Max}|Tail], CB, CTO, _, STO, RespParams) ->
 	case parse_max_window_bits(Max) of
 		error ->
 			ignore;
 		SB ->
-			parse_permessage_deflate_params(Tail, CB, CTO, SB, STO,
+			parse_request_permessage_deflate_params(Tail, CB, CTO, SB, STO,
 				[<<"; ">>, <<"server_max_window_bits=">>, Max|RespParams])
 	end;
-parse_permessage_deflate_params([<<"server_no_context_takeover">>|Tail], CB, CTO, SB, _, RespParams) ->
-	parse_permessage_deflate_params(Tail, CB, CTO, SB, no_takeover, [<<"; ">>, <<"server_no_context_takeover">>|RespParams]);
-%% Ignore if unknown parameter; ignore if parameter with invalid value.
-parse_permessage_deflate_params(_, _, _, _, _, _) ->
+parse_request_permessage_deflate_params([<<"server_no_context_takeover">>|Tail], CB, CTO, SB, _, RespParams) ->
+	parse_request_permessage_deflate_params(Tail, CB, CTO, SB, no_takeover, [<<"; ">>, <<"server_no_context_takeover">>|RespParams]);
+%% Ignore if unknown parameter; ignore if parameter with invalid or missing value.
+parse_request_permessage_deflate_params(_, _, _, _, _, _) ->
 	ignore.
 
 parse_max_window_bits(<<"8">>) -> 8;
@@ -108,19 +134,19 @@ parse_max_window_bits(<<"15">>) -> 15;
 parse_max_window_bits(_) -> error.
 
 % A negative WindowBits value indicates that zlib headers are not used.
-init_permessage_deflate(ClientWindowBits, ServerWindowBits, Opts) ->
+init_permessage_deflate(InflateWindowBits, DeflateWindowBits, Opts) ->
 	Inflate = zlib:open(),
-	ok = zlib:inflateInit(Inflate, -ClientWindowBits),
+	ok = zlib:inflateInit(Inflate, -InflateWindowBits),
 	Deflate = zlib:open(),
 	%% @todo Remove this case .. of for OTP 18+ if PR https://github.com/erlang/otp/pull/633 gets merged.
-	ServerWindowBits2 = case ServerWindowBits of
+	DeflateWindowBits2 = case DeflateWindowBits of
 		8 -> 9;
-		_ -> ServerWindowBits
+		_ -> DeflateWindowBits
 	end,
 	ok = zlib:deflateInit(Deflate,
 		maps:get(level, Opts, best_compression),
 		deflated,
-		-ServerWindowBits2,
+		-DeflateWindowBits2,
 		maps:get(mem_level, Opts, 8),
 		maps:get(strategy, Opts, default)),
 	{Inflate, Deflate}.
@@ -144,6 +170,51 @@ negotiate_x_webkit_deflate_frame(_Params, Extensions, Opts) ->
 			inflate => Inflate,
 			inflate_takeover => takeover}}.
 
+%% @doc Validate the negotiated permessage-deflate extension.
+
+%% Error when more than one deflate extension was negotiated.
+validate_permessage_deflate(_, #{deflate := _}, _) ->
+	error;
+validate_permessage_deflate(Params, Extensions, Opts) ->
+	case lists:usort(Params) of
+		%% Error if multiple parameters with the same name.
+		Params2 when length(Params) =/= length(Params2) ->
+			error;
+		Params2 ->
+			%% @todo Might want to make some of these configurable defaults if at all possible.
+			case parse_response_permessage_deflate_params(Params2, 15, takeover, 15, takeover) of
+				error ->
+					error;
+				{ClientWindowBits, ClientTakeOver, ServerWindowBits, ServerTakeOver} ->
+					{Inflate, Deflate} = init_permessage_deflate(ServerWindowBits, ClientWindowBits, Opts),
+					{ok, Extensions#{
+						deflate => Deflate,
+						deflate_takeover => ClientTakeOver,
+						inflate => Inflate,
+						inflate_takeover => ServerTakeOver}}
+			end
+	end.
+
+parse_response_permessage_deflate_params([], CB, CTO, SB, STO) ->
+	{CB, CTO, SB, STO};
+parse_response_permessage_deflate_params([{<<"client_max_window_bits">>, Max}|Tail], _, CTO, SB, STO) ->
+	case parse_max_window_bits(Max) of
+		error -> error;
+		CB -> parse_response_permessage_deflate_params(Tail, CB, CTO, SB, STO)
+	end;
+parse_response_permessage_deflate_params([<<"client_no_context_takeover">>|Tail], CB, _, SB, STO) ->
+	parse_response_permessage_deflate_params(Tail, CB, no_takeover, SB, STO);
+parse_response_permessage_deflate_params([{<<"server_max_window_bits">>, Max}|Tail], CB, CTO, _, STO) ->
+	case parse_max_window_bits(Max) of
+		error -> error;
+		SB -> parse_response_permessage_deflate_params(Tail, CB, CTO, SB, STO)
+	end;
+parse_response_permessage_deflate_params([<<"server_no_context_takeover">>|Tail], CB, CTO, SB, _) ->
+	parse_response_permessage_deflate_params(Tail, CB, CTO, SB, no_takeover);
+%% Error if unknown parameter; error if parameter with invalid or missing value.
+parse_response_permessage_deflate_params(_, _, _, _, _) ->
+	error.
+
 %% @doc Parse and validate the Websocket frame header.
 %%
 %% This function also updates the fragmentation state according to
@@ -244,6 +315,7 @@ frag_state(_, 1, _, FragState) -> FragState.
 %% Empty last frame of compressed message.
 parse_payload(Data, _, Utf8State, _, _, 0, {fin, _, << 1:1, 0:2 >>},
 		#{inflate := Inflate, inflate_takeover := TakeOver}, _) ->
+	zlib:inflate(Inflate, << 0, 0, 255, 255 >>),
 	case TakeOver of
 		no_takeover -> zlib:inflateReset(Inflate);
 		takeover -> ok
@@ -307,34 +379,35 @@ validate_close_code(Code) ->
 		true -> ok
 	end.
 
+unmask(Data, undefined, _) ->
+	Data;
 unmask(Data, MaskKey, 0) ->
-	do_unmask(Data, MaskKey, <<>>);
+	mask(Data, MaskKey, <<>>);
 %% We unmask on the fly so we need to continue from the right mask byte.
 unmask(Data, MaskKey, UnmaskedLen) ->
 	Left = UnmaskedLen rem 4,
 	Right = 4 - Left,
 	MaskKey2 = (MaskKey bsl (Left * 8)) + (MaskKey bsr (Right * 8)),
-	do_unmask(Data, MaskKey2, <<>>).
+	mask(Data, MaskKey2, <<>>).
 
-do_unmask(<<>>, _, Unmasked) ->
+mask(<<>>, _, Unmasked) ->
 	Unmasked;
-do_unmask(<< O:32, Rest/bits >>, MaskKey, Acc) ->
+mask(<< O:32, Rest/bits >>, MaskKey, Acc) ->
 	T = O bxor MaskKey,
-	do_unmask(Rest, MaskKey, << Acc/binary, T:32 >>);
-do_unmask(<< O:24 >>, MaskKey, Acc) ->
+	mask(Rest, MaskKey, << Acc/binary, T:32 >>);
+mask(<< O:24 >>, MaskKey, Acc) ->
 	<< MaskKey2:24, _:8 >> = << MaskKey:32 >>,
 	T = O bxor MaskKey2,
 	<< Acc/binary, T:24 >>;
-do_unmask(<< O:16 >>, MaskKey, Acc) ->
+mask(<< O:16 >>, MaskKey, Acc) ->
 	<< MaskKey2:16, _:16 >> = << MaskKey:32 >>,
 	T = O bxor MaskKey2,
 	<< Acc/binary, T:16 >>;
-do_unmask(<< O:8 >>, MaskKey, Acc) ->
+mask(<< O:8 >>, MaskKey, Acc) ->
 	<< MaskKey2:8, _:24 >> = << MaskKey:32 >>,
 	T = O bxor MaskKey2,
 	<< Acc/binary, T:8 >>.
 
-%% @todo Try using iodata() and see if it improves anything.
 inflate_frame(Data, Inflate, TakeOver, FragState, true)
 		when FragState =:= undefined; element(1, FragState) =:= fin ->
 	Data2 = zlib:inflate(Inflate, << Data/binary, 0, 0, 255, 255 >>),
@@ -416,7 +489,6 @@ make_frame(pong, <<>>, _, _) -> pong;
 make_frame(pong, Payload, _, _) -> {pong, Payload}.
 
 %% @doc Construct an unmasked Websocket frame.
-%% @todo Add fragments support.
 
 -spec frame(frame(), extensions()) -> iodata().
 %% Control frames. Control packets must not be > 125 in length.
@@ -457,6 +529,56 @@ frame({binary, Payload}, _) ->
 	Len = payload_length(Payload),
 	[<< 1:1, 0:3, 2:4, 0:1, Len/bits >>, Payload].
 
+%% @doc Construct a masked Websocket frame.
+%%
+%% We use a mask key of 0 if there is no payload for close, ping and pong frames.
+
+-spec masked_frame(frame(), extensions()) -> iodata().
+%% Control frames. Control packets must not be > 125 in length.
+masked_frame(close, _) ->
+	<< 1:1, 0:3, 8:4, 1:1, 0:39 >>;
+masked_frame(ping, _) ->
+	<< 1:1, 0:3, 9:4, 1:1, 0:39 >>;
+masked_frame(pong, _) ->
+	<< 1:1, 0:3, 10:4, 1:1, 0:39 >>;
+masked_frame({close, Payload}, Extensions) ->
+	frame({close, 1000, Payload}, Extensions);
+masked_frame({close, StatusCode, Payload}, _) ->
+	Len = 2 + iolist_size(Payload),
+	true = Len =< 125,
+	MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
+	[<< 1:1, 0:3, 8:4, 1:1, Len:7 >>, MaskKeyBin, mask(iolist_to_binary([<< StatusCode:16 >>, Payload]), MaskKey, <<>>)];
+masked_frame({ping, Payload}, _) ->
+	Len = iolist_size(Payload),
+	true = Len =< 125,
+	MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
+	[<< 1:1, 0:3, 9:4, 1:1, Len:7 >>, MaskKeyBin, mask(iolist_to_binary(Payload), MaskKey, <<>>)];
+masked_frame({pong, Payload}, _) ->
+	Len = iolist_size(Payload),
+	true = Len =< 125,
+	MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
+	[<< 1:1, 0:3, 10:4, 1:1, Len:7 >>, MaskKeyBin, mask(iolist_to_binary(Payload), MaskKey, <<>>)];
+%% Data frames, deflate-frame extension.
+masked_frame({text, Payload}, #{deflate := Deflate, deflate_takeover := TakeOver}) ->
+	MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
+	Payload2 = mask(deflate_frame(Payload, Deflate, TakeOver), MaskKey, <<>>),
+	Len = payload_length(Payload2),
+	[<< 1:1, 1:1, 0:2, 1:4, 1:1, Len/bits >>, MaskKeyBin, Payload2];
+masked_frame({binary, Payload}, #{deflate := Deflate, deflate_takeover := TakeOver}) ->
+	MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
+	Payload2 = mask(deflate_frame(Payload, Deflate, TakeOver), MaskKey, <<>>),
+	Len = payload_length(Payload2),
+	[<< 1:1, 1:1, 0:2, 2:4, 1:1, Len/bits >>, MaskKeyBin, Payload2];
+%% Data frames.
+masked_frame({text, Payload}, _) ->
+	MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
+	Len = payload_length(Payload),
+	[<< 1:1, 0:3, 1:4, 1:1, Len/bits >>, MaskKeyBin, mask(iolist_to_binary(Payload), MaskKey, <<>>)];
+masked_frame({binary, Payload}, _) ->
+	MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
+	Len = payload_length(Payload),
+	[<< 1:1, 0:3, 2:4, 1:1, Len/bits >>, MaskKeyBin, mask(iolist_to_binary(Payload), MaskKey, <<>>)].
+
 payload_length(Payload) ->
 	case byte_size(Payload) of
 		N when N =< 125 -> << N:7 >>;