|
@@ -14,33 +14,59 @@
|
|
|
|
|
|
-module(cow_ws).
|
|
|
|
|
|
+-export([key/0]).
|
|
|
+-export([encode_key/1]).
|
|
|
+
|
|
|
-export([negotiate_permessage_deflate/3]).
|
|
|
-export([negotiate_x_webkit_deflate_frame/3]).
|
|
|
|
|
|
+-export([validate_permessage_deflate/3]).
|
|
|
+
|
|
|
-export([parse_header/3]).
|
|
|
-export([parse_payload/9]).
|
|
|
-export([make_frame/4]).
|
|
|
+
|
|
|
-export([frame/2]).
|
|
|
+-export([masked_frame/2]).
|
|
|
|
|
|
-type close_code() :: 1000..1003 | 1006..1011 | 3000..4999.
|
|
|
-export_type([close_code/0]).
|
|
|
|
|
|
+-type extensions() :: map().
|
|
|
+-export_type([extensions/0]).
|
|
|
+
|
|
|
-type frag_state() :: undefined | {fin | nofin, text | binary, rsv()}.
|
|
|
-export_type([frag_state/0]).
|
|
|
|
|
|
-type frame() :: close | ping | pong
|
|
|
| {text | binary | close | ping | pong, iodata()}
|
|
|
| {close, close_code(), iodata()}
|
|
|
- | {fragment, fin | nofin, text | binary, iodata()}.
|
|
|
+ | {fragment, fin | nofin, text | binary | continuation, iodata()}.
|
|
|
-export_type([frame/0]).
|
|
|
|
|
|
--type utf8_state() :: 0..8.
|
|
|
--export_type([utf8_state/0]).
|
|
|
-
|
|
|
--type extensions() :: map().
|
|
|
-type frame_type() :: fragment | text | binary | close | ping | pong.
|
|
|
+-export_type([frame_type/0]).
|
|
|
+
|
|
|
-type mask_key() :: undefined | 0..16#ffffffff.
|
|
|
+-export_type([mask_key/0]).
|
|
|
+
|
|
|
-type rsv() :: <<_:3>>.
|
|
|
+-export_type([rsv/0]).
|
|
|
+
|
|
|
+-type utf8_state() :: 0..8.
|
|
|
+-export_type([utf8_state/0]).
|
|
|
+
|
|
|
+%% @doc Generate a key for the Websocket handshake request.
|
|
|
+
|
|
|
+-spec key() -> binary().
|
|
|
+key() ->
|
|
|
+ base64:encode(crypto:rand_bytes(16)).
|
|
|
+
|
|
|
+%% @doc Encode the key into the accept value for the Websocket handshake response.
|
|
|
+
|
|
|
+-spec encode_key(binary()) -> binary().
|
|
|
+encode_key(Key) ->
|
|
|
+ base64:encode(crypto:hash(sha, [Key, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"])).
|
|
|
|
|
|
%% @doc Negotiate the permessage-deflate extension.
|
|
|
|
|
@@ -54,7 +80,7 @@ negotiate_permessage_deflate(Params, Extensions, Opts) ->
|
|
|
ignore;
|
|
|
Params2 ->
|
|
|
%% @todo Might want to make these configurable defaults.
|
|
|
- case parse_permessage_deflate_params(Params2, 15, takeover, 15, takeover, []) of
|
|
|
+ case parse_request_permessage_deflate_params(Params2, 15, takeover, 15, takeover, []) of
|
|
|
ignore ->
|
|
|
ignore;
|
|
|
{ClientWindowBits, ClientTakeOver, ServerWindowBits, ServerTakeOver, RespParams} ->
|
|
@@ -68,33 +94,33 @@ negotiate_permessage_deflate(Params, Extensions, Opts) ->
|
|
|
end
|
|
|
end.
|
|
|
|
|
|
-parse_permessage_deflate_params([], CB, CTO, SB, STO, RespParams) ->
|
|
|
+parse_request_permessage_deflate_params([], CB, CTO, SB, STO, RespParams) ->
|
|
|
{CB, CTO, SB, STO, RespParams};
|
|
|
-parse_permessage_deflate_params([<<"client_max_window_bits">>|Tail], CB, CTO, SB, STO, RespParams) ->
|
|
|
- parse_permessage_deflate_params(Tail, CB, CTO, SB, STO,
|
|
|
+parse_request_permessage_deflate_params([<<"client_max_window_bits">>|Tail], CB, CTO, SB, STO, RespParams) ->
|
|
|
+ parse_request_permessage_deflate_params(Tail, CB, CTO, SB, STO,
|
|
|
[<<"; ">>, <<"client_max_window_bits=">>, integer_to_binary(CB)|RespParams]);
|
|
|
-parse_permessage_deflate_params([{<<"client_max_window_bits">>, Max}|Tail], _, CTO, SB, STO, RespParams) ->
|
|
|
+parse_request_permessage_deflate_params([{<<"client_max_window_bits">>, Max}|Tail], _, CTO, SB, STO, RespParams) ->
|
|
|
case parse_max_window_bits(Max) of
|
|
|
error ->
|
|
|
ignore;
|
|
|
CB ->
|
|
|
- parse_permessage_deflate_params(Tail, CB, CTO, SB, STO,
|
|
|
+ parse_request_permessage_deflate_params(Tail, CB, CTO, SB, STO,
|
|
|
[<<"; ">>, <<"client_max_window_bits=">>, Max|RespParams])
|
|
|
end;
|
|
|
-parse_permessage_deflate_params([<<"client_no_context_takeover">>|Tail], CB, _, SB, STO, RespParams) ->
|
|
|
- parse_permessage_deflate_params(Tail, CB, no_takeover, SB, STO, [<<"; ">>, <<"client_no_context_takeover">>|RespParams]);
|
|
|
-parse_permessage_deflate_params([{<<"server_max_window_bits">>, Max}|Tail], CB, CTO, _, STO, RespParams) ->
|
|
|
+parse_request_permessage_deflate_params([<<"client_no_context_takeover">>|Tail], CB, _, SB, STO, RespParams) ->
|
|
|
+ parse_request_permessage_deflate_params(Tail, CB, no_takeover, SB, STO, [<<"; ">>, <<"client_no_context_takeover">>|RespParams]);
|
|
|
+parse_request_permessage_deflate_params([{<<"server_max_window_bits">>, Max}|Tail], CB, CTO, _, STO, RespParams) ->
|
|
|
case parse_max_window_bits(Max) of
|
|
|
error ->
|
|
|
ignore;
|
|
|
SB ->
|
|
|
- parse_permessage_deflate_params(Tail, CB, CTO, SB, STO,
|
|
|
+ parse_request_permessage_deflate_params(Tail, CB, CTO, SB, STO,
|
|
|
[<<"; ">>, <<"server_max_window_bits=">>, Max|RespParams])
|
|
|
end;
|
|
|
-parse_permessage_deflate_params([<<"server_no_context_takeover">>|Tail], CB, CTO, SB, _, RespParams) ->
|
|
|
- parse_permessage_deflate_params(Tail, CB, CTO, SB, no_takeover, [<<"; ">>, <<"server_no_context_takeover">>|RespParams]);
|
|
|
-%% Ignore if unknown parameter; ignore if parameter with invalid value.
|
|
|
-parse_permessage_deflate_params(_, _, _, _, _, _) ->
|
|
|
+parse_request_permessage_deflate_params([<<"server_no_context_takeover">>|Tail], CB, CTO, SB, _, RespParams) ->
|
|
|
+ parse_request_permessage_deflate_params(Tail, CB, CTO, SB, no_takeover, [<<"; ">>, <<"server_no_context_takeover">>|RespParams]);
|
|
|
+%% Ignore if unknown parameter; ignore if parameter with invalid or missing value.
|
|
|
+parse_request_permessage_deflate_params(_, _, _, _, _, _) ->
|
|
|
ignore.
|
|
|
|
|
|
parse_max_window_bits(<<"8">>) -> 8;
|
|
@@ -108,19 +134,19 @@ parse_max_window_bits(<<"15">>) -> 15;
|
|
|
parse_max_window_bits(_) -> error.
|
|
|
|
|
|
% A negative WindowBits value indicates that zlib headers are not used.
|
|
|
-init_permessage_deflate(ClientWindowBits, ServerWindowBits, Opts) ->
|
|
|
+init_permessage_deflate(InflateWindowBits, DeflateWindowBits, Opts) ->
|
|
|
Inflate = zlib:open(),
|
|
|
- ok = zlib:inflateInit(Inflate, -ClientWindowBits),
|
|
|
+ ok = zlib:inflateInit(Inflate, -InflateWindowBits),
|
|
|
Deflate = zlib:open(),
|
|
|
%% @todo Remove this case .. of for OTP 18+ if PR https://github.com/erlang/otp/pull/633 gets merged.
|
|
|
- ServerWindowBits2 = case ServerWindowBits of
|
|
|
+ DeflateWindowBits2 = case DeflateWindowBits of
|
|
|
8 -> 9;
|
|
|
- _ -> ServerWindowBits
|
|
|
+ _ -> DeflateWindowBits
|
|
|
end,
|
|
|
ok = zlib:deflateInit(Deflate,
|
|
|
maps:get(level, Opts, best_compression),
|
|
|
deflated,
|
|
|
- -ServerWindowBits2,
|
|
|
+ -DeflateWindowBits2,
|
|
|
maps:get(mem_level, Opts, 8),
|
|
|
maps:get(strategy, Opts, default)),
|
|
|
{Inflate, Deflate}.
|
|
@@ -144,6 +170,51 @@ negotiate_x_webkit_deflate_frame(_Params, Extensions, Opts) ->
|
|
|
inflate => Inflate,
|
|
|
inflate_takeover => takeover}}.
|
|
|
|
|
|
+%% @doc Validate the negotiated permessage-deflate extension.
|
|
|
+
|
|
|
+%% Error when more than one deflate extension was negotiated.
|
|
|
+validate_permessage_deflate(_, #{deflate := _}, _) ->
|
|
|
+ error;
|
|
|
+validate_permessage_deflate(Params, Extensions, Opts) ->
|
|
|
+ case lists:usort(Params) of
|
|
|
+ %% Error if multiple parameters with the same name.
|
|
|
+ Params2 when length(Params) =/= length(Params2) ->
|
|
|
+ error;
|
|
|
+ Params2 ->
|
|
|
+ %% @todo Might want to make some of these configurable defaults if at all possible.
|
|
|
+ case parse_response_permessage_deflate_params(Params2, 15, takeover, 15, takeover) of
|
|
|
+ error ->
|
|
|
+ error;
|
|
|
+ {ClientWindowBits, ClientTakeOver, ServerWindowBits, ServerTakeOver} ->
|
|
|
+ {Inflate, Deflate} = init_permessage_deflate(ServerWindowBits, ClientWindowBits, Opts),
|
|
|
+ {ok, Extensions#{
|
|
|
+ deflate => Deflate,
|
|
|
+ deflate_takeover => ClientTakeOver,
|
|
|
+ inflate => Inflate,
|
|
|
+ inflate_takeover => ServerTakeOver}}
|
|
|
+ end
|
|
|
+ end.
|
|
|
+
|
|
|
+parse_response_permessage_deflate_params([], CB, CTO, SB, STO) ->
|
|
|
+ {CB, CTO, SB, STO};
|
|
|
+parse_response_permessage_deflate_params([{<<"client_max_window_bits">>, Max}|Tail], _, CTO, SB, STO) ->
|
|
|
+ case parse_max_window_bits(Max) of
|
|
|
+ error -> error;
|
|
|
+ CB -> parse_response_permessage_deflate_params(Tail, CB, CTO, SB, STO)
|
|
|
+ end;
|
|
|
+parse_response_permessage_deflate_params([<<"client_no_context_takeover">>|Tail], CB, _, SB, STO) ->
|
|
|
+ parse_response_permessage_deflate_params(Tail, CB, no_takeover, SB, STO);
|
|
|
+parse_response_permessage_deflate_params([{<<"server_max_window_bits">>, Max}|Tail], CB, CTO, _, STO) ->
|
|
|
+ case parse_max_window_bits(Max) of
|
|
|
+ error -> error;
|
|
|
+ SB -> parse_response_permessage_deflate_params(Tail, CB, CTO, SB, STO)
|
|
|
+ end;
|
|
|
+parse_response_permessage_deflate_params([<<"server_no_context_takeover">>|Tail], CB, CTO, SB, _) ->
|
|
|
+ parse_response_permessage_deflate_params(Tail, CB, CTO, SB, no_takeover);
|
|
|
+%% Error if unknown parameter; error if parameter with invalid or missing value.
|
|
|
+parse_response_permessage_deflate_params(_, _, _, _, _) ->
|
|
|
+ error.
|
|
|
+
|
|
|
%% @doc Parse and validate the Websocket frame header.
|
|
|
%%
|
|
|
%% This function also updates the fragmentation state according to
|
|
@@ -244,6 +315,7 @@ frag_state(_, 1, _, FragState) -> FragState.
|
|
|
%% Empty last frame of compressed message.
|
|
|
parse_payload(Data, _, Utf8State, _, _, 0, {fin, _, << 1:1, 0:2 >>},
|
|
|
#{inflate := Inflate, inflate_takeover := TakeOver}, _) ->
|
|
|
+ zlib:inflate(Inflate, << 0, 0, 255, 255 >>),
|
|
|
case TakeOver of
|
|
|
no_takeover -> zlib:inflateReset(Inflate);
|
|
|
takeover -> ok
|
|
@@ -307,34 +379,35 @@ validate_close_code(Code) ->
|
|
|
true -> ok
|
|
|
end.
|
|
|
|
|
|
+unmask(Data, undefined, _) ->
|
|
|
+ Data;
|
|
|
unmask(Data, MaskKey, 0) ->
|
|
|
- do_unmask(Data, MaskKey, <<>>);
|
|
|
+ mask(Data, MaskKey, <<>>);
|
|
|
%% We unmask on the fly so we need to continue from the right mask byte.
|
|
|
unmask(Data, MaskKey, UnmaskedLen) ->
|
|
|
Left = UnmaskedLen rem 4,
|
|
|
Right = 4 - Left,
|
|
|
MaskKey2 = (MaskKey bsl (Left * 8)) + (MaskKey bsr (Right * 8)),
|
|
|
- do_unmask(Data, MaskKey2, <<>>).
|
|
|
+ mask(Data, MaskKey2, <<>>).
|
|
|
|
|
|
-do_unmask(<<>>, _, Unmasked) ->
|
|
|
+mask(<<>>, _, Unmasked) ->
|
|
|
Unmasked;
|
|
|
-do_unmask(<< O:32, Rest/bits >>, MaskKey, Acc) ->
|
|
|
+mask(<< O:32, Rest/bits >>, MaskKey, Acc) ->
|
|
|
T = O bxor MaskKey,
|
|
|
- do_unmask(Rest, MaskKey, << Acc/binary, T:32 >>);
|
|
|
-do_unmask(<< O:24 >>, MaskKey, Acc) ->
|
|
|
+ mask(Rest, MaskKey, << Acc/binary, T:32 >>);
|
|
|
+mask(<< O:24 >>, MaskKey, Acc) ->
|
|
|
<< MaskKey2:24, _:8 >> = << MaskKey:32 >>,
|
|
|
T = O bxor MaskKey2,
|
|
|
<< Acc/binary, T:24 >>;
|
|
|
-do_unmask(<< O:16 >>, MaskKey, Acc) ->
|
|
|
+mask(<< O:16 >>, MaskKey, Acc) ->
|
|
|
<< MaskKey2:16, _:16 >> = << MaskKey:32 >>,
|
|
|
T = O bxor MaskKey2,
|
|
|
<< Acc/binary, T:16 >>;
|
|
|
-do_unmask(<< O:8 >>, MaskKey, Acc) ->
|
|
|
+mask(<< O:8 >>, MaskKey, Acc) ->
|
|
|
<< MaskKey2:8, _:24 >> = << MaskKey:32 >>,
|
|
|
T = O bxor MaskKey2,
|
|
|
<< Acc/binary, T:8 >>.
|
|
|
|
|
|
-%% @todo Try using iodata() and see if it improves anything.
|
|
|
inflate_frame(Data, Inflate, TakeOver, FragState, true)
|
|
|
when FragState =:= undefined; element(1, FragState) =:= fin ->
|
|
|
Data2 = zlib:inflate(Inflate, << Data/binary, 0, 0, 255, 255 >>),
|
|
@@ -416,7 +489,6 @@ make_frame(pong, <<>>, _, _) -> pong;
|
|
|
make_frame(pong, Payload, _, _) -> {pong, Payload}.
|
|
|
|
|
|
%% @doc Construct an unmasked Websocket frame.
|
|
|
-%% @todo Add fragments support.
|
|
|
|
|
|
-spec frame(frame(), extensions()) -> iodata().
|
|
|
%% Control frames. Control packets must not be > 125 in length.
|
|
@@ -457,6 +529,56 @@ frame({binary, Payload}, _) ->
|
|
|
Len = payload_length(Payload),
|
|
|
[<< 1:1, 0:3, 2:4, 0:1, Len/bits >>, Payload].
|
|
|
|
|
|
+%% @doc Construct a masked Websocket frame.
|
|
|
+%%
|
|
|
+%% We use a mask key of 0 if there is no payload for close, ping and pong frames.
|
|
|
+
|
|
|
+-spec masked_frame(frame(), extensions()) -> iodata().
|
|
|
+%% Control frames. Control packets must not be > 125 in length.
|
|
|
+masked_frame(close, _) ->
|
|
|
+ << 1:1, 0:3, 8:4, 1:1, 0:39 >>;
|
|
|
+masked_frame(ping, _) ->
|
|
|
+ << 1:1, 0:3, 9:4, 1:1, 0:39 >>;
|
|
|
+masked_frame(pong, _) ->
|
|
|
+ << 1:1, 0:3, 10:4, 1:1, 0:39 >>;
|
|
|
+masked_frame({close, Payload}, Extensions) ->
|
|
|
+ frame({close, 1000, Payload}, Extensions);
|
|
|
+masked_frame({close, StatusCode, Payload}, _) ->
|
|
|
+ Len = 2 + iolist_size(Payload),
|
|
|
+ true = Len =< 125,
|
|
|
+ MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
|
|
|
+ [<< 1:1, 0:3, 8:4, 1:1, Len:7 >>, MaskKeyBin, mask(iolist_to_binary([<< StatusCode:16 >>, Payload]), MaskKey, <<>>)];
|
|
|
+masked_frame({ping, Payload}, _) ->
|
|
|
+ Len = iolist_size(Payload),
|
|
|
+ true = Len =< 125,
|
|
|
+ MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
|
|
|
+ [<< 1:1, 0:3, 9:4, 1:1, Len:7 >>, MaskKeyBin, mask(iolist_to_binary(Payload), MaskKey, <<>>)];
|
|
|
+masked_frame({pong, Payload}, _) ->
|
|
|
+ Len = iolist_size(Payload),
|
|
|
+ true = Len =< 125,
|
|
|
+ MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
|
|
|
+ [<< 1:1, 0:3, 10:4, 1:1, Len:7 >>, MaskKeyBin, mask(iolist_to_binary(Payload), MaskKey, <<>>)];
|
|
|
+%% Data frames, deflate-frame extension.
|
|
|
+masked_frame({text, Payload}, #{deflate := Deflate, deflate_takeover := TakeOver}) ->
|
|
|
+ MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
|
|
|
+ Payload2 = mask(deflate_frame(Payload, Deflate, TakeOver), MaskKey, <<>>),
|
|
|
+ Len = payload_length(Payload2),
|
|
|
+ [<< 1:1, 1:1, 0:2, 1:4, 1:1, Len/bits >>, MaskKeyBin, Payload2];
|
|
|
+masked_frame({binary, Payload}, #{deflate := Deflate, deflate_takeover := TakeOver}) ->
|
|
|
+ MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
|
|
|
+ Payload2 = mask(deflate_frame(Payload, Deflate, TakeOver), MaskKey, <<>>),
|
|
|
+ Len = payload_length(Payload2),
|
|
|
+ [<< 1:1, 1:1, 0:2, 2:4, 1:1, Len/bits >>, MaskKeyBin, Payload2];
|
|
|
+%% Data frames.
|
|
|
+masked_frame({text, Payload}, _) ->
|
|
|
+ MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
|
|
|
+ Len = payload_length(Payload),
|
|
|
+ [<< 1:1, 0:3, 1:4, 1:1, Len/bits >>, MaskKeyBin, mask(iolist_to_binary(Payload), MaskKey, <<>>)];
|
|
|
+masked_frame({binary, Payload}, _) ->
|
|
|
+ MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
|
|
|
+ Len = payload_length(Payload),
|
|
|
+ [<< 1:1, 0:3, 2:4, 1:1, Len/bits >>, MaskKeyBin, mask(iolist_to_binary(Payload), MaskKey, <<>>)].
|
|
|
+
|
|
|
payload_length(Payload) ->
|
|
|
case byte_size(Payload) of
|
|
|
N when N =< 125 -> << N:7 >>;
|