cowboy_websocket.erl 31 KB


  1. %% Copyright (c) 2011-2014, Loïc Hoguin <essen@ninenines.eu>
  2. %%
  3. %% Permission to use, copy, modify, and/or distribute this software for any
  4. %% purpose with or without fee is hereby granted, provided that the above
  5. %% copyright notice and this permission notice appear in all copies.
  6. %%
  7. %% THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  8. %% WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  9. %% MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
  10. %% ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  11. %% WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
  12. %% ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
  13. %% OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  %% Cowboy supports versions 7 through 17 of the Websocket drafts.
  %% It also supports RFC6455, the proposed standard for Websocket.
  -module(cowboy_websocket).
  -behaviour(cowboy_sub_protocol).

  %% Ignore the deprecation warning for crypto:sha/1.
  %% @todo Remove when we support only R16B+.
  -compile(nowarn_deprecated_function).

  -export([upgrade/4]).
  -export([handler_loop/4]).

  -type close_code() :: 1000..4999.
  -export_type([close_code/0]).

  -type frame() :: close | ping | pong
      | {text | binary | close | ping | pong, iodata()}
      | {close, close_code(), iodata()}.
  -export_type([frame/0]).

  %% Opcodes accepted on the wire: 0 continuation, 1 text, 2 binary,
  %% 8 close, 9 ping, 10 pong.
  -type opcode() :: 0 | 1 | 2 | 8 | 9 | 10.
  %% The 32-bit masking key of a client frame, as an unsigned integer.
  -type mask_key() :: 0..16#ffffffff.
  %% Reassembly state of a fragmented message: the opcode of its first
  %% frame and the payload received so far. nofin while more fragments
  %% are expected, fin while the final fragment is being read.
  -type frag_state() :: undefined
      | {nofin, opcode(), binary()} | {fin, opcode(), binary()}.
  %% The RSV1, RSV2 and RSV3 bits of a frame header.
  -type rsv() :: << _:3 >>.
  -type terminate_reason() :: {normal | error | remote, atom()}
      | {remote, close_code(), binary()}.

  %% Connection state threaded through the Websocket loop.
  -record(state, {
      env :: cowboy_middleware:env(),
      %% socket and transport default to undefined, so their types must
      %% admit it: since OTP 19 undefined is no longer implicitly added
      %% to record field types.
      socket = undefined :: undefined | inet:socket(),
      transport = undefined :: undefined | module(),
      handler :: module(),
      %% Sec-Websocket-Key; reset to undefined once the handshake is sent.
      key = undefined :: undefined | binary(),
      timeout = infinity :: timeout(),
      timeout_ref = undefined :: undefined | reference(),
      %% {OK, Closed, Error} message tags returned by Transport:messages().
      messages = undefined :: undefined | {atom(), atom(), atom()},
      hibernate = false :: boolean(),
      frag_state = undefined :: frag_state(),
      %% Trailing bytes of a UTF-8 sequence split across data chunks.
      utf8_state = <<>> :: binary(),
      %% Whether x-webkit-deflate-frame was negotiated.
      deflate_frame = false :: boolean(),
      %% zlib streams used when deflate_frame is true.
      inflate_state :: undefined | port(),
      deflate_state :: undefined | port()
  }).
  -spec upgrade(Req, Env, module(), any())
      -> {ok, Req, Env} | {error, 400, Req}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req(), Env::cowboy_middleware:env().
  %% Sub-protocol entry point. Calls ranch:remove_connection/1 for the
  %% listener (presumably so this long-lived connection stops counting
  %% against the listener's limits), validates the handshake request and
  %% initializes the handler. Any failure while validating the handshake
  %% gets a 400 reply, after which the process exits normally.
  upgrade(Req, Env, Handler, HandlerOpts) ->
      {listener, ListenerRef} = lists:keyfind(listener, 1, Env),
      ranch:remove_connection(ListenerRef),
      [Socket, Transport] = cowboy_req:get([socket, transport], Req),
      Initial = #state{env=Env, handler=Handler,
          socket=Socket, transport=Transport},
      try websocket_upgrade(Initial, Req) of
          {ok, Upgraded, Req2} ->
              handler_init(Upgraded, Req2, HandlerOpts)
      catch _:_ ->
          cowboy_req:maybe_reply(400, Req),
          exit(normal)
      end.
  -spec websocket_upgrade(#state{}, Req)
      -> {ok, #state{}, Req} when Req::cowboy_req:req().
  %% Validate the handshake headers. Every check is an assertive match, so
  %% a missing or unexpected value crashes and upgrade/4 answers 400.
  websocket_upgrade(State, Req) ->
      {ok, ConnectionTokens, Req2}
          = cowboy_req:parse_header(<<"connection">>, Req),
      true = lists:member(<<"upgrade">>, ConnectionTokens),
      %% @todo Should probably send a 426 if the Upgrade header is missing.
      {ok, [<<"websocket">>], Req3}
          = cowboy_req:parse_header(<<"upgrade">>, Req2),
      {RawVersion, Req4} = cowboy_req:header(<<"sec-websocket-version">>, Req3),
      Version = list_to_integer(binary_to_list(RawVersion)),
      %% Sec-Websocket-Version values of the supported drafts and RFC6455.
      true = lists:member(Version, [7, 8, 13]),
      {Key, Req5} = cowboy_req:header(<<"sec-websocket-key">>, Req4),
      true = Key =/= undefined,
      Req6 = cowboy_req:set_meta(websocket_version, Version, Req5),
      websocket_extensions(State#state{key=Key}, Req6).
  -spec websocket_extensions(#state{}, Req)
      -> {ok, #state{}, Req} when Req::cowboy_req:req().
  %% Decide whether to enable x-webkit-deflate-frame. The client must offer
  %% it without parameters and resp_compress must be true on the request.
  %% The outcome is stored in the websocket_compress meta value.
  websocket_extensions(State, Req) ->
      case cowboy_req:parse_header(<<"sec-websocket-extensions">>, Req) of
          {ok, Extensions, Req2} when Extensions =/= undefined ->
              [Compress] = cowboy_req:get([resp_compress], Req),
              Offer = lists:keyfind(<<"x-webkit-deflate-frame">>, 1, Extensions),
              if
                  Compress =:= true,
                          Offer =:= {<<"x-webkit-deflate-frame">>, []} ->
                      {ok, enable_deflate_frame(State),
                          cowboy_req:set_meta(websocket_compress, true, Req2)};
                  true ->
                      {ok, State,
                          cowboy_req:set_meta(websocket_compress, false, Req2)}
              end;
          _ ->
              {ok, State, cowboy_req:set_meta(websocket_compress, false, Req)}
      end.

  -spec enable_deflate_frame(#state{}) -> #state{}.
  %% Open one raw-deflate zlib stream per direction. Negative window bits
  %% disable the zlib header. An unconstrained deflate-frame negotiation
  %% means accepting the maximum 2^15 window, so both streams use 15 bits.
  enable_deflate_frame(State) ->
      Inflate = zlib:open(),
      Deflate = zlib:open(),
      ok = zlib:inflateInit(Inflate, -15),
      ok = zlib:deflateInit(Deflate, best_compression, deflated, -15, 8, default),
      State#state{deflate_frame=true, inflate_state=Inflate,
          deflate_state=Deflate}.
  -spec handler_init(#state{}, Req, any())
      -> {ok, Req, cowboy_middleware:env()} | {error, 400, Req}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  %% Call the handler's websocket_init/3.
  %% - {shutdown, Req} ends the request through cowboy_req:ensure_response/2
  %%   with status 400.
  %% - The ok tuples may ask for hibernation and/or an idle timeout before
  %%   the handshake is sent.
  %% - If websocket_init/3 raises, a 400 is sent and the exception is
  %%   re-raised with context.
  handler_init(State=#state{env=Env, transport=Transport,
          handler=Handler}, Req, HandlerOpts) ->
      try Handler:websocket_init(Transport:name(), Req, HandlerOpts) of
          {shutdown, Req2} ->
              cowboy_req:ensure_response(Req2, 400),
              {ok, Req2, [{result, closed}|Env]};
          {ok, Req2, HandlerState} ->
              websocket_handshake(State, Req2, HandlerState);
          {ok, Req2, HandlerState, hibernate} ->
              websocket_handshake(State#state{hibernate=true},
                  Req2, HandlerState);
          {ok, Req2, HandlerState, Timeout} ->
              websocket_handshake(State#state{timeout=Timeout},
                  Req2, HandlerState);
          {ok, Req2, HandlerState, Timeout, hibernate} ->
              websocket_handshake(State#state{hibernate=true, timeout=Timeout},
                  Req2, HandlerState)
      catch Class:Reason ->
          cowboy_req:maybe_reply(400, Req),
          erlang:Class([
              {reason, Reason},
              {mfa, {Handler, websocket_init, 3}},
              {stacktrace, erlang:get_stacktrace()},
              {req, cowboy_req:to_list(Req)},
              {opts, HandlerOpts}
          ])
      end.
  -spec websocket_handshake(#state{}, Req, any())
      -> {ok, Req, cowboy_middleware:env()}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  %% Send the 101 response. It carries Sec-Websocket-Accept (the base64
  %% SHA-1 of the client key plus the RFC6455 GUID) and, if negotiated,
  %% the deflate-frame extension header. Then arm the idle timer and start
  %% reading frames.
  websocket_handshake(State=#state{transport=Transport, key=Key,
          deflate_frame=DeflateFrame}, Req, HandlerState) ->
      %% @todo Change into crypto:hash/2 for R17B+ or when supporting only R16B+.
      Digest = crypto:sha(
          << Key/binary, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" >>),
      ExtensionHeaders = case DeflateFrame of
          true -> [{<<"sec-websocket-extensions">>, <<"x-webkit-deflate-frame">>}];
          false -> []
      end,
      Headers = [{<<"upgrade">>, <<"websocket">>},
          {<<"sec-websocket-accept">>, base64:encode(Digest)}
          |ExtensionHeaders],
      {ok, Req2} = cowboy_req:upgrade_reply(101, Headers, Req),
      %% Drop the resp_sent notification so the receive loop does not
      %% forward it to the handler's websocket_info/3.
      receive {cowboy_req, resp_sent} -> ok after 0 -> ok end,
      TimedState = handler_loop_timeout(State),
      handler_before_loop(TimedState#state{key=undefined,
          messages=Transport:messages()}, Req2, HandlerState, <<>>).
  -spec handler_before_loop(#state{}, Req, any(), binary())
      -> {ok, Req, cowboy_middleware:env()}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  %% Ask for the next socket message ({active, once}) and enter the loop.
  %% When hibernation was requested, a suspend tuple naming handler_loop/4
  %% is returned instead, presumably so the caller can hibernate before
  %% resuming. The flag is cleared so the request applies only once.
  handler_before_loop(State=#state{socket=Socket, transport=Transport,
          hibernate=Hibernate}, Req, HandlerState, SoFar) ->
      Transport:setopts(Socket, [{active, once}]),
      case Hibernate of
          true ->
              {suspend, ?MODULE, handler_loop,
                  [State#state{hibernate=false}, Req, HandlerState, SoFar]};
          _ ->
              handler_loop(State, Req, HandlerState, SoFar)
      end.
  -spec handler_loop_timeout(#state{}) -> #state{}.
  %% Restart the idle timer after activity. Cancel the pending timer, if
  %% any, and start a new one that delivers {timeout, Ref, ?MODULE}.
  %% No timer is used when the timeout is infinity.
  handler_loop_timeout(State=#state{timeout=infinity}) ->
      State#state{timeout_ref=undefined};
  handler_loop_timeout(State=#state{timeout=Timeout, timeout_ref=OldRef}) ->
      _ = OldRef =:= undefined orelse erlang:cancel_timer(OldRef),
      State#state{timeout_ref=erlang:start_timer(Timeout, self(), ?MODULE)}.
  -spec handler_loop(#state{}, Req, any(), binary())
      -> {ok, Req, cowboy_middleware:env()}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  %% Receive loop between frames. SoFar holds bytes received but not yet
  %% parsed as a complete frame.
  %% - Socket data restarts the idle timer and is appended to SoFar.
  %% - Socket close/error terminates the handler.
  %% - The current idle timer firing closes the connection.
  %% - Any other message goes to websocket_info/3.
  handler_loop(State=#state{socket=Socket, messages={OK, Closed, Error},
          timeout_ref=TRef}, Req, HandlerState, SoFar) ->
      receive
          {OK, Socket, Data} ->
              websocket_data(handler_loop_timeout(State), Req, HandlerState,
                  << SoFar/binary, Data/binary >>);
          {Closed, Socket} ->
              handler_terminate(State, Req, HandlerState, {error, closed});
          {Error, Socket, Reason} ->
              handler_terminate(State, Req, HandlerState, {error, Reason});
          {timeout, TRef, ?MODULE} ->
              websocket_close(State, Req, HandlerState, {normal, timeout});
          {timeout, StaleRef, ?MODULE} when is_reference(StaleRef) ->
              %% Timeout from an earlier, already replaced timer: ignore it.
              handler_loop(State, Req, HandlerState, SoFar);
          Info ->
              handler_call(State, Req, HandlerState,
                  SoFar, websocket_info, Info, fun handler_before_loop/4)
      end.
  %% Parse and validate the header of the frame at the start of the buffer.
  %% Invalid headers close the connection with {error, badframe}; a header
  %% that is not complete yet waits for more data.
  %%
  %% All frames passing through this function are considered valid,
  %% with the only exception of text and close frames with a payload
  %% which may still contain errors.
  -spec websocket_data(#state{}, Req, any(), binary())
      -> {ok, Req, cowboy_middleware:env()}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  %% RSV bits MUST be 0 unless an extension is negotiated
  %% that defines meanings for non-zero values.
  websocket_data(State, Req, HandlerState, << _:1, Rsv:3, _/bits >>)
          when Rsv =/= 0, State#state.deflate_frame =:= false ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  %% x-webkit-deflate-frame only gives a meaning to RSV1, so RSV2 and RSV3
  %% MUST still be 0 when it is negotiated.
  websocket_data(State, Req, HandlerState, << _:2, Rsv23:2, _/bits >>)
          when Rsv23 =/= 0 ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  %% Invalid opcode. Note that these opcodes may be used by extensions.
  websocket_data(State, Req, HandlerState, << _:4, Opcode:4, _/bits >>)
          when Opcode > 2, Opcode =/= 8, Opcode =/= 9, Opcode =/= 10 ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  %% Control frames MUST NOT be fragmented.
  websocket_data(State, Req, HandlerState, << 0:1, _:3, Opcode:4, _/bits >>)
          when Opcode >= 8 ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  %% A frame MUST NOT use the zero opcode unless fragmentation was initiated.
  websocket_data(State=#state{frag_state=undefined}, Req, HandlerState,
          << _:4, 0:4, _/bits >>) ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  %% Non-control opcode when expecting control message or next fragment.
  websocket_data(State=#state{frag_state={nofin, _, _}}, Req, HandlerState,
          << _:4, Opcode:4, _/bits >>)
          when Opcode =/= 0, Opcode < 8 ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  %% Close control frame length MUST be 0 or >= 2.
  websocket_data(State, Req, HandlerState, << _:4, 8:4, _:1, 1:7, _/bits >>) ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  %% Close control frame with incomplete close code. Need more data.
  %% 8 bytes = 2 header bytes + 4 mask bytes + 2 close code bytes.
  websocket_data(State, Req, HandlerState,
          Data = << _:4, 8:4, 1:1, Len:7, _/bits >>)
          when Len > 1, byte_size(Data) < 8 ->
      handler_before_loop(State, Req, HandlerState, Data);
  %% 7 bits payload length.
  websocket_data(State, Req, HandlerState, << Fin:1, Rsv:3/bits, Opcode:4, 1:1,
          Len:7, MaskKey:32, Rest/bits >>)
          when Len < 126 ->
      websocket_data(State, Req, HandlerState,
          Opcode, Len, MaskKey, Rest, Rsv, Fin);
  %% 16 bits payload length.
  websocket_data(State, Req, HandlerState, << Fin:1, Rsv:3/bits, Opcode:4, 1:1,
          126:7, Len:16, MaskKey:32, Rest/bits >>)
          when Len > 125, Opcode < 8 ->
      websocket_data(State, Req, HandlerState,
          Opcode, Len, MaskKey, Rest, Rsv, Fin);
  %% 63 bits payload length.
  websocket_data(State, Req, HandlerState, << Fin:1, Rsv:3/bits, Opcode:4, 1:1,
          127:7, 0:1, Len:63, MaskKey:32, Rest/bits >>)
          when Len > 16#ffff, Opcode < 8 ->
      websocket_data(State, Req, HandlerState,
          Opcode, Len, MaskKey, Rest, Rsv, Fin);
  %% When payload length is over 63 bits, the most significant bit MUST be 0.
  websocket_data(State, Req, HandlerState, << _:8, 1:1, 127:7, 1:1, _:7, _/binary >>) ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  %% All frames sent from the client to the server are masked.
  websocket_data(State, Req, HandlerState, << _:8, 0:1, _/bits >>) ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  %% For the next two clauses, it can be one of the following:
  %%
  %% * The minimal number of bytes MUST be used to encode the length
  %% * All control frames MUST have a payload length of 125 bytes or less
  websocket_data(State, Req, HandlerState, << _:9, 126:7, _:48, _/bits >>) ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  websocket_data(State, Req, HandlerState, << _:9, 127:7, _:96, _/bits >>) ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  %% Need more data.
  websocket_data(State, Req, HandlerState, Data) ->
      handler_before_loop(State, Req, HandlerState, Data).
  %% Initialize or update fragmentation state.
  -spec websocket_data(#state{}, Req, any(),
      opcode(), non_neg_integer(), mask_key(), binary(), rsv(), 0 | 1)
      -> {ok, Req, cowboy_middleware:env()}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  %% Update the fragmentation state from the frame's opcode and FIN bit,
  %% then read its payload. The opcode is only included in the first frame
  %% fragment. Fragments are therefore read as opcode 0 while the message
  %% opcode is kept in frag_state; unfragmented frames keep their own opcode.
  websocket_data(State, Req, HandlerState, Opcode, Len, MaskKey, Data, Rsv, Fin) ->
      {State2, PayloadOpcode} = fragment_transition(State, Opcode, Fin),
      websocket_payload(State2, Req, HandlerState,
          PayloadOpcode, Len, MaskKey, <<>>, 0, Data, Rsv).

  -spec fragment_transition(#state{}, opcode(), 0 | 1) -> {#state{}, opcode()}.
  %% First frame fragment: remember the message opcode.
  fragment_transition(State=#state{frag_state=undefined}, Opcode, 0) ->
      {State#state{frag_state={nofin, Opcode, <<>>}}, 0};
  %% Subsequent frame fragments.
  fragment_transition(State=#state{frag_state={nofin, _, _}}, 0, 0) ->
      {State, 0};
  %% Final frame fragment.
  fragment_transition(State=#state{frag_state={nofin, Opcode, SoFar}}, 0, 1) ->
      {State#state{frag_state={fin, Opcode, SoFar}}, 0};
  %% Unfragmented frame, including control frames between fragments.
  fragment_transition(State, Opcode, 1) ->
      {State, Opcode}.
  -spec websocket_payload(#state{}, Req, any(),
      opcode(), non_neg_integer(), mask_key(), binary(), non_neg_integer(),
      binary(), rsv())
      -> {ok, Req, cowboy_middleware:env()}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  %% Unmask the payload of the current frame, and inflate it when
  %% deflate-frame applies.
  %% - Len: payload bytes still to read.
  %% - Unmasked: payload decoded so far.
  %% - UnmaskedLen: raw bytes already consumed, which keeps the mask key
  %%   aligned.
  %% If Data does not hold the rest of the payload, wait for more through
  %% websocket_payload_loop/9. Otherwise dispatch the payload and keep the
  %% bytes that follow it for the next frame.
  %%
  %% Close control frames with a payload MUST contain a valid close code.
  %% The Len guard stops an empty close frame from reading a "code" out of
  %% the bytes of whatever follows it. Before the guard, Len - 2 went
  %% negative and the process crashed with badmatch.
  websocket_payload(State, Req, HandlerState,
          Opcode=8, Len, MaskKey, <<>>, 0,
          << MaskedCode:2/binary, Rest/bits >>, Rsv) when Len >= 2 ->
      Unmasked = << Code:16 >> = websocket_unmask(MaskedCode, MaskKey, <<>>),
      if Code < 1000; Code =:= 1004; Code =:= 1005; Code =:= 1006;
              (Code > 1011) and (Code < 3000); Code > 4999 ->
          websocket_close(State, Req, HandlerState, {error, badframe});
      true ->
          websocket_payload(State, Req, HandlerState,
              Opcode, Len - 2, MaskKey, Unmasked, byte_size(MaskedCode),
              Rest, Rsv)
      end;
  %% Text frames and close control frames MUST have a payload that is valid UTF-8.
  websocket_payload(State=#state{utf8_state=Incomplete},
          Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen,
          Data, Rsv)
          when (byte_size(Data) < Len) andalso ((Opcode =:= 1) orelse
              ((Opcode =:= 8) andalso (Unmasked =/= <<>>))) ->
      Unmasked2 = websocket_unmask(Data,
          rotate_mask_key(MaskKey, UnmaskedLen), <<>>),
      {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State),
      case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of
          false ->
              websocket_close(State2, Req, HandlerState, {error, badencoding});
          Utf8State ->
              websocket_payload_loop(State2#state{utf8_state=Utf8State},
                  Req, HandlerState, Opcode, Len - byte_size(Data), MaskKey,
                  << Unmasked/binary, Unmasked3/binary >>,
                  UnmaskedLen + byte_size(Data), Rsv)
      end;
  websocket_payload(State=#state{utf8_state=Incomplete},
          Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen,
          Data, Rsv)
          when Opcode =:= 1; (Opcode =:= 8) and (Unmasked =/= <<>>) ->
      << End:Len/binary, Rest/bits >> = Data,
      Unmasked2 = websocket_unmask(End,
          rotate_mask_key(MaskKey, UnmaskedLen), <<>>),
      {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State),
      case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of
          <<>> ->
              websocket_dispatch(State2#state{utf8_state= <<>>},
                  Req, HandlerState, Rest, Opcode,
                  << Unmasked/binary, Unmasked3/binary >>);
          _ ->
              websocket_close(State2, Req, HandlerState, {error, badencoding})
      end;
  %% Fragmented text frames may cut payload in the middle of UTF-8 codepoints.
  websocket_payload(State=#state{frag_state={_, 1, _}, utf8_state=Incomplete},
          Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, UnmaskedLen,
          Data, Rsv)
          when byte_size(Data) < Len ->
      Unmasked2 = websocket_unmask(Data,
          rotate_mask_key(MaskKey, UnmaskedLen), <<>>),
      {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State),
      case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of
          false ->
              websocket_close(State2, Req, HandlerState, {error, badencoding});
          Utf8State ->
              websocket_payload_loop(State2#state{utf8_state=Utf8State},
                  Req, HandlerState, Opcode, Len - byte_size(Data), MaskKey,
                  << Unmasked/binary, Unmasked3/binary >>,
                  UnmaskedLen + byte_size(Data), Rsv)
      end;
  %% x-webkit-deflate-frame works per frame, the same way the binary path
  %% below and websocket_deflate_frame/3 treat it. The sync-flush trailer
  %% is therefore restored at the end of every text fragment, not only the
  %% final one; otherwise the inflate stream gets out of step for later
  %% frames.
  websocket_payload(State=#state{frag_state={Fin, 1, _}, utf8_state=Incomplete},
          Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, UnmaskedLen,
          Data, Rsv) ->
      << End:Len/binary, Rest/bits >> = Data,
      Unmasked2 = websocket_unmask(End,
          rotate_mask_key(MaskKey, UnmaskedLen), <<>>),
      {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State),
      case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of
          <<>> ->
              websocket_dispatch(State2#state{utf8_state= <<>>},
                  Req, HandlerState, Rest, Opcode,
                  << Unmasked/binary, Unmasked3/binary >>);
          Utf8State when is_binary(Utf8State), Fin =:= nofin ->
              websocket_dispatch(State2#state{utf8_state=Utf8State},
                  Req, HandlerState, Rest, Opcode,
                  << Unmasked/binary, Unmasked3/binary >>);
          _ ->
              websocket_close(State2, Req, HandlerState, {error, badencoding})
      end;
  %% Other frames have a binary payload.
  websocket_payload(State, Req, HandlerState,
          Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv)
          when byte_size(Data) < Len ->
      Unmasked2 = websocket_unmask(Data,
          rotate_mask_key(MaskKey, UnmaskedLen), <<>>),
      {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State),
      websocket_payload_loop(State2, Req, HandlerState,
          Opcode, Len - byte_size(Data), MaskKey,
          << Unmasked/binary, Unmasked3/binary >>, UnmaskedLen + byte_size(Data),
          Rsv);
  websocket_payload(State, Req, HandlerState,
          Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv) ->
      << End:Len/binary, Rest/bits >> = Data,
      Unmasked2 = websocket_unmask(End,
          rotate_mask_key(MaskKey, UnmaskedLen), <<>>),
      {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State),
      websocket_dispatch(State2, Req, HandlerState, Rest, Opcode,
          << Unmasked/binary, Unmasked3/binary >>).
  -spec websocket_inflate_frame(binary(), rsv(), boolean(), #state{}) ->
      {binary(), #state{}}.
  %% Inflate a chunk of frame payload when deflate-frame is negotiated and
  %% the frame has RSV1 set; otherwise return it as is. IsFrameEnd is true
  %% for the last chunk of a frame. For that chunk, the
  %% 0x00 0x00 0xff 0xff sync-flush trailer that deflate-frame senders
  %% strip is appended back before inflating. The state is returned
  %% unchanged.
  websocket_inflate_frame(Data, << 0:1, _:2 >>, _, State=#state{}) ->
      {Data, State};
  websocket_inflate_frame(Data, << _:3 >>, _, State=#state{deflate_frame=false}) ->
      {Data, State};
  websocket_inflate_frame(Data, << 1:1, _:2 >>, IsFrameEnd,
          State=#state{inflate_state=Inflate}) ->
      Input = case IsFrameEnd of
          true -> << Data/binary, 0, 0, 255, 255 >>;
          false -> Data
      end,
      {iolist_to_binary(zlib:inflate(Inflate, Input)), State}.
  -spec websocket_unmask(B, mask_key(), B) -> B when B::binary().
  %% XOR masked payload bytes with the mask key and append the result to
  %% Acc. Whole 32-bit words are handled first. A trailing 1-3 byte tail
  %% is XORed with the same number of leading bytes of the key.
  websocket_unmask(<<>>, _, Acc) ->
      Acc;
  websocket_unmask(<< Word:32, Rest/bits >>, MaskKey, Acc) ->
      websocket_unmask(Rest, MaskKey, << Acc/binary, (Word bxor MaskKey):32 >>);
  websocket_unmask(Tail, MaskKey, Acc)
          when bit_size(Tail) =:= 24; bit_size(Tail) =:= 16; bit_size(Tail) =:= 8 ->
      Bits = bit_size(Tail),
      << TailValue:Bits >> = Tail,
      << KeyPrefix:Bits, _/bits >> = << MaskKey:32 >>,
      << Acc/binary, (TailValue bxor KeyPrefix):Bits >>.
  %% Because we unmask on the fly we need to continue from the right mask byte.
  -spec rotate_mask_key(mask_key(), non_neg_integer()) -> mask_key().
  %% Rotate the 32-bit key left by (UnmaskedLen rem 4) bytes, so that the
  %% key byte for the next payload byte comes first. Erlang integers are
  %% unbounded, so bsl would produce values wider than 32 bits. The result
  %% is truncated back to 32 bits to stay within mask_key().
  rotate_mask_key(MaskKey, UnmaskedLen) ->
      Left = UnmaskedLen rem 4,
      Right = 4 - Left,
      ((MaskKey bsl (Left * 8)) bor (MaskKey bsr (Right * 8))) band 16#ffffffff.
  %% Returns <<>> if the argument is valid UTF-8, false if not,
  %% or the incomplete part of the argument if we need more data.
  -spec is_utf8(binary()) -> false | binary().
  is_utf8(<<>>) ->
      <<>>;
  is_utf8(<< _/utf8, Rest/binary >>) ->
      is_utf8(Rest);
  is_utf8(Undecodable) ->
      utf8_tail(Undecodable).

  %% Classify bytes that do not start a complete, valid UTF-8 sequence.
  %% A truncated but so far plausible sequence is returned unchanged, so
  %% the caller can prepend it to the next chunk. Anything else is false.
  -spec utf8_tail(binary()) -> false | binary().
  %% 2 bytes. Codepages C0 and C1 are invalid; fail early.
  utf8_tail(<< 2#1100000:7, _/bits >>) ->
      false;
  utf8_tail(Partial = << 2#110:3, _:5 >>) ->
      Partial;
  %% 3 bytes.
  utf8_tail(Partial = << 2#1110:4, _:4 >>) ->
      Partial;
  utf8_tail(Partial = << 2#1110:4, _:4, 2#10:2, _:6 >>) ->
      Partial;
  %% 4 bytes. Codepage F4 may have invalid values greater than 0x10FFFF.
  utf8_tail(<< 2#11110100:8, 2#10:2, High:6, _/bits >>) when High >= 2#10000 ->
      false;
  utf8_tail(Partial = << 2#11110:5, _:3 >>) ->
      Partial;
  utf8_tail(Partial = << 2#11110:5, _:3, 2#10:2, _:6 >>) ->
      Partial;
  utf8_tail(Partial = << 2#11110:5, _:3, 2#10:2, _:6, 2#10:2, _:6 >>) ->
      Partial;
  %% Invalid.
  utf8_tail(_) ->
      false.
  -spec websocket_payload_loop(#state{}, Req, any(),
      opcode(), non_neg_integer(), mask_key(), binary(),
      non_neg_integer(), rsv())
      -> {ok, Req, cowboy_middleware:env()}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  %% Wait for the rest of a partially received frame payload.
  %% - Socket data resumes websocket_payload/10.
  %% - Socket close/error or the idle timeout ends the connection.
  %% - Any other message goes to websocket_info/3; the wait then resumes
  %%   with the same payload progress.
  websocket_payload_loop(State=#state{socket=Socket, transport=Transport,
          messages={OK, Closed, Error}, timeout_ref=TRef},
          Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Rsv) ->
      Transport:setopts(Socket, [{active, once}]),
      Resume = fun (ResumeState, ResumeReq, ResumeHandlerState, _) ->
          websocket_payload_loop(ResumeState, ResumeReq, ResumeHandlerState,
              Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Rsv)
      end,
      receive
          {OK, Socket, Data} ->
              websocket_payload(handler_loop_timeout(State), Req, HandlerState,
                  Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv);
          {Closed, Socket} ->
              handler_terminate(State, Req, HandlerState, {error, closed});
          {Error, Socket, Reason} ->
              handler_terminate(State, Req, HandlerState, {error, Reason});
          {timeout, TRef, ?MODULE} ->
              websocket_close(State, Req, HandlerState, {normal, timeout});
          {timeout, StaleRef, ?MODULE} when is_reference(StaleRef) ->
              %% Timeout from an earlier, already replaced timer: ignore it.
              Resume(State, Req, HandlerState, <<>>);
          Info ->
              handler_call(State, Req, HandlerState,
                  <<>>, websocket_info, Info, Resume)
      end.
  -spec websocket_dispatch(#state{}, Req, any(), binary(), opcode(), binary())
      -> {ok, Req, cowboy_middleware:env()}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  %% Act on a fully read frame payload. RemainingData holds the bytes that
  %% followed the frame; they are parsed once this frame is handled.
  %%
  %% Intermediate fragment: accumulate it and read the next frame.
  websocket_dispatch(State=#state{frag_state={nofin, Opcode, SoFar}},
          Req, HandlerState, RemainingData, 0, Payload) ->
      Acc = << SoFar/binary, Payload/binary >>,
      websocket_data(State#state{frag_state={nofin, Opcode, Acc}},
          Req, HandlerState, RemainingData);
  %% Final fragment: dispatch the reassembled message under its own opcode.
  websocket_dispatch(State=#state{frag_state={fin, Opcode, SoFar}},
          Req, HandlerState, RemainingData, 0, Payload) ->
      Message = << SoFar/binary, Payload/binary >>,
      websocket_dispatch(State#state{frag_state=undefined}, Req, HandlerState,
          RemainingData, Opcode, Message);
  %% Close control frame, without or with a status code.
  websocket_dispatch(State, Req, HandlerState, _RemainingData, 8, <<>>) ->
      websocket_close(State, Req, HandlerState, {remote, closed});
  websocket_dispatch(State, Req, HandlerState, _RemainingData, 8,
          << Code:16, Payload/bits >>) ->
      websocket_close(State, Req, HandlerState, {remote, Code, Payload});
  %% Ping: echo the payload back in a pong, then let the handler see it.
  websocket_dispatch(State=#state{socket=Socket, transport=Transport},
          Req, HandlerState, RemainingData, 9, Payload) ->
      PongLen = payload_length_to_binary(byte_size(Payload)),
      Transport:send(Socket, << 1:1, 0:3, 10:4, 0:1, PongLen/bits, Payload/binary >>),
      handler_call(State, Req, HandlerState, RemainingData,
          websocket_handle, {ping, Payload}, fun websocket_data/4);
  %% Text, binary and pong frames go straight to websocket_handle/3.
  websocket_dispatch(State, Req, HandlerState, RemainingData, Opcode, Payload)
          when Opcode =:= 1; Opcode =:= 2; Opcode =:= 10 ->
      Type = case Opcode of
          1 -> text;
          2 -> binary;
          10 -> pong
      end,
      handler_call(State, Req, HandlerState, RemainingData,
          websocket_handle, {Type, Payload}, fun websocket_data/4).
  -spec handler_call(#state{}, Req, any(), binary(), atom(), any(), fun())
      -> {ok, Req, cowboy_middleware:env()}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  %% Invoke Handler:Callback(Message, Req, HandlerState) and act on its
  %% result. The handler can continue with NextState, reply with one frame
  %% or a list of frames first, or close the connection with shutdown.
  %% Only the callback itself is guarded: a crash there closes the
  %% connection and is re-raised with context.
  handler_call(State=#state{handler=Handler}, Req, HandlerState,
          RemainingData, Callback, Message, NextState) ->
      try Handler:Callback(Message, Req, HandlerState) of
          {ok, Req2, HandlerState2} ->
              NextState(State, Req2, HandlerState2, RemainingData);
          {ok, Req2, HandlerState2, hibernate} ->
              NextState(State#state{hibernate=true},
                  Req2, HandlerState2, RemainingData);
          {reply, Payload, Req2, HandlerState2} ->
              handler_send_reply(State, Req2, HandlerState2, RemainingData,
                  Payload, false, NextState);
          {reply, Payload, Req2, HandlerState2, hibernate} ->
              handler_send_reply(State, Req2, HandlerState2, RemainingData,
                  Payload, true, NextState);
          {shutdown, Req2, HandlerState2} ->
              websocket_close(State, Req2, HandlerState2, {normal, shutdown})
      catch Class:Reason ->
          _ = websocket_close(State, Req, HandlerState, {error, handler}),
          erlang:Class([
              {reason, Reason},
              {mfa, {Handler, Callback, 3}},
              {stacktrace, erlang:get_stacktrace()},
              {msg, Message},
              {req, cowboy_req:to_list(Req)},
              {state, HandlerState}
          ])
      end.

  -spec handler_send_reply(#state{}, Req, any(), binary(), any(), boolean(), fun())
      -> {ok, Req, cowboy_middleware:env()}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  %% Send the frame, or list of frames, from a reply tuple. Then continue
  %% with NextState, setting the hibernate flag when requested. Sending a
  %% close frame (shutdown) or a send error terminates the handler instead.
  handler_send_reply(State, Req, HandlerState, RemainingData,
          Payload, Hibernate, NextState) ->
      SendResult = case is_list(Payload) of
          true -> websocket_send_many(Payload, State);
          false -> websocket_send(Payload, State)
      end,
      case SendResult of
          {ok, State2} when Hibernate ->
              NextState(State2#state{hibernate=true},
                  Req, HandlerState, RemainingData);
          {ok, State2} ->
              NextState(State2, Req, HandlerState, RemainingData);
          {shutdown, State2} ->
              handler_terminate(State2, Req, HandlerState, {normal, shutdown});
          {{error, _} = Error, State2} ->
              handler_terminate(State2, Req, HandlerState, Error)
      end.
  -spec websocket_opcode(text | binary | close | ping | pong) -> 1 | 2 | 8 | 9 | 10.
  %% Opcode written in the header of an outgoing frame of the given type.
  %% Control frames:
  websocket_opcode(close) -> 8;
  websocket_opcode(ping) -> 9;
  websocket_opcode(pong) -> 10;
  %% Data frames:
  websocket_opcode(text) -> 1;
  websocket_opcode(binary) -> 2.
  -spec websocket_deflate_frame(opcode(), binary(), #state{}) ->
      {binary(), rsv(), #state{}}.
  %% Compress an outgoing payload when deflate-frame was negotiated and the
  %% frame is a data frame. Control frames (opcode >= 8) are sent as is.
  %% Returns the payload, the RSV bits for the header (RSV1 set when
  %% compressed) and the unchanged state.
  websocket_deflate_frame(Opcode, Payload, State=#state{deflate_frame=DeflateFrame})
          when DeflateFrame =:= false; Opcode >= 8 ->
      {Payload, << 0:3 >>, State};
  websocket_deflate_frame(_, Payload, State=#state{deflate_state=Deflate}) ->
      Flushed = iolist_to_binary(zlib:deflate(Deflate, Payload, sync)),
      {strip_sync_trailer(Flushed), << 1:1, 0:2 >>, State}.

  -spec strip_sync_trailer(binary()) -> binary().
  %% Drop the 0x00 0x00 0xff 0xff that ends a sync flush; it is left out of
  %% deflate-frame payloads. Data without that suffix is returned whole.
  strip_sync_trailer(Flushed) ->
      BodySize = byte_size(Flushed) - 4,
      case Flushed of
          << Body:BodySize/binary, 0, 0, 255, 255 >> -> Body;
          _ -> Flushed
      end.
  -spec websocket_send(frame(), #state{})
      -> {ok, #state{}} | {shutdown, #state{}} | {{error, atom()}, #state{}}.
  %% Encode and send a single frame.
  %% - Close frames end the session: {shutdown, State} once written, or
  %%   {Error, State} if the write failed.
  %% - Other frames return the Transport:send/2 result with the state.
  %% - Data frame payloads go through websocket_deflate_frame/3.
  %% - Control frame payloads must be at most 125 bytes.
  websocket_send(Type, State=#state{socket=Socket, transport=Transport})
          when Type =:= close ->
      Opcode = websocket_opcode(Type),
      case Transport:send(Socket, << 1:1, 0:3, Opcode:4, 0:8 >>) of
          ok -> {shutdown, State};
          Error -> {Error, State}
      end;
  websocket_send(Type, State=#state{socket=Socket, transport=Transport})
          when Type =:= ping; Type =:= pong ->
      Opcode = websocket_opcode(Type),
      {Transport:send(Socket, << 1:1, 0:3, Opcode:4, 0:8 >>), State};
  websocket_send({close, Payload}, State) ->
      websocket_send({close, 1000, Payload}, State);
  websocket_send({Type = close, StatusCode, Payload}, State=#state{
          socket=Socket, transport=Transport}) ->
      Opcode = websocket_opcode(Type),
      Len = 2 + iolist_size(Payload),
      %% Control packets must not be > 125 in length.
      true = Len =< 125,
      BinLen = payload_length_to_binary(Len),
      %% Report a failed write the same way as the bare close clause,
      %% instead of claiming a clean shutdown.
      case Transport:send(Socket,
              [<< 1:1, 0:3, Opcode:4, 0:1, BinLen/bits, StatusCode:16 >>, Payload]) of
          ok -> {shutdown, State};
          Error -> {Error, State}
      end;
  websocket_send({Type, Payload0}, State=#state{socket=Socket, transport=Transport}) ->
      Opcode = websocket_opcode(Type),
      {Payload, Rsv, State2} = websocket_deflate_frame(Opcode,
          iolist_to_binary(Payload0), State),
      Len = iolist_size(Payload),
      %% Control packets must not be > 125 in length.
      true = if Type =:= ping; Type =:= pong ->
              Len =< 125;
          true ->
              true
      end,
      BinLen = payload_length_to_binary(Len),
      {Transport:send(Socket,
          [<< 1:1, Rsv/bits, Opcode:4, 0:1, BinLen/bits >>, Payload]), State2}.
  %% Send a list of frames in order, threading the state through each send.
  %%
  %% Stops at the first send that does not return {ok, _} and hands that
  %% result ({shutdown, State} or {{error, Reason}, State}) back unchanged.
  %% Frames after it are not sent. An empty list yields {ok, State}.
  -spec websocket_send_many([frame()], #state{})
      -> {ok, #state{}} | {shutdown, #state{}} | {{error, atom()}, #state{}}.
  websocket_send_many([], State) ->
      {ok, State};
  websocket_send_many([Frame|Rest], State) ->
      case websocket_send(Frame, State) of
          {ok, NextState} ->
              websocket_send_many(Rest, NextState);
          Stop = {_, _} ->
              Stop
      end.
  %% Send a close frame matching Reason, then terminate the handler through
  %% handler_terminate/4.
  %%
  %% Status codes sent: {normal, _} -> 1000, {error, badframe} -> 1002,
  %% {error, badencoding} -> 1007, {error, handler} -> 1011. For
  %% {remote, Code, _} the received Code is echoed back. {remote, closed}
  %% gets a close frame with an empty body. The socket write is best-effort
  %% and its result is ignored. Any other Reason crashes with case_clause
  %% before anything is sent.
  -spec websocket_close(#state{}, Req, any(), terminate_reason())
      -> {ok, Req, cowboy_middleware:env()}
      when Req::cowboy_req:req().
  websocket_close(State=#state{socket=Socket, transport=Transport},
          Req, HandlerState, Reason) ->
      CloseFrame = case Reason of
          {remote, closed} ->
              << 1:1, 0:3, 8:4, 0:8 >>;
          {remote, Code, _} ->
              << 1:1, 0:3, 8:4, 0:1, 2:7, Code:16 >>;
          {normal, _} ->
              << 1:1, 0:3, 8:4, 0:1, 2:7, 1000:16 >>;
          {error, badframe} ->
              << 1:1, 0:3, 8:4, 0:1, 2:7, 1002:16 >>;
          {error, badencoding} ->
              << 1:1, 0:3, 8:4, 0:1, 2:7, 1007:16 >>;
          {error, handler} ->
              << 1:1, 0:3, 8:4, 0:1, 2:7, 1011:16 >>
      end,
      _ = Transport:send(Socket, CloseFrame),
      handler_terminate(State, Req, HandlerState, Reason).
  %% Call Handler:websocket_terminate/3 and report the connection as closed.
  %%
  %% The callback's return value is discarded. If the callback raises, the
  %% exception is re-raised in the same class. Its payload is a proplist
  %% carrying the original reason, the callback MFA, the stacktrace, the
  %% request, the handler state and the terminate reason. On success the
  %% result is {ok, Req, Env} with {result, closed} prepended to Env.
  -spec handler_terminate(#state{}, Req, any(), terminate_reason())
      -> {ok, Req, cowboy_middleware:env()}
      when Req::cowboy_req:req().
  handler_terminate(#state{env=Env, handler=Handler},
          Req, HandlerState, TerminateReason) ->
      try
          Handler:websocket_terminate(TerminateReason, Req, HandlerState)
      catch Kind:Why ->
          %% Captured first so nothing evaluated below can replace it.
          Stacktrace = erlang:get_stacktrace(),
          erlang:Kind([
              {reason, Why},
              {mfa, {Handler, websocket_terminate, 3}},
              {stacktrace, Stacktrace},
              {req, cowboy_req:to_list(Req)},
              {state, HandlerState},
              {terminate_reason, TerminateReason}
          ])
      end,
      {ok, Req, [{result, closed}|Env]}.
  %% Encode a payload length using the variable-width frame header form
  %% (RFC 6455, section 5.2):
  %%  - up to 125: the length itself in 7 bits;
  %%  - up to 16#ffff: the marker 126 followed by a 16-bit length;
  %%  - up to 16#7fffffffffffffff: the marker 127 followed by a 64-bit length.
  %% Larger values crash with case_clause.
  -spec payload_length_to_binary(0..16#7fffffffffffffff)
      -> << _:7 >> | << _:23 >> | << _:71 >>.
  payload_length_to_binary(Len) ->
      case Len of
          _ when Len =< 125 ->
              << Len:7 >>;
          _ when Len =< 16#ffff ->
              << 126:7, Len:16 >>;
          _ when Len =< 16#7fffffffffffffff ->
              << 127:7, Len:64 >>
      end.