cowboy_websocket.erl 32 KB


  1. %% Copyright (c) 2011-2013, Loïc Hoguin <essen@ninenines.eu>
  2. %%
  3. %% Permission to use, copy, modify, and/or distribute this software for any
  4. %% purpose with or without fee is hereby granted, provided that the above
  5. %% copyright notice and this permission notice appear in all copies.
  6. %%
  7. %% THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  8. %% WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  9. %% MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
  10. %% ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  11. %% WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
  12. %% ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
  13. %% OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  14. %% @doc Websocket protocol implementation.
  15. %%
  16. %% Cowboy supports versions 7 through 17 of the Websocket drafts.
  17. %% It also supports RFC6455, the proposed standard for Websocket.
  18. -module(cowboy_websocket).
  19. -behaviour(cowboy_sub_protocol).
  20. %% Ignore the deprecation warning for crypto:sha/1.
  21. %% @todo Remove when we support only R16B+.
  22. -compile({nowarn_deprecated_function, {crypto, sha, 1}}).
  23. %% API.
  24. -export([upgrade/4]).
  25. %% Internal.
  26. -export([handler_loop/4]).
  %% Status codes that may be carried by a close frame.
  -type close_code() :: 1000..4999.

  %% Frames a handler may send: a bare control frame, a frame with a
  %% payload, or a close frame with an explicit status code.
  -type frame() ::
      close
      | ping
      | pong
      | {text | binary | close | ping | pong, iodata()}
      | {close, close_code(), iodata()}.

  -export_type([close_code/0, frame/0]).

  %% Opcodes accepted by this module (0 is the continuation opcode).
  -type opcode() :: 0 | 1 | 2 | 8 | 9 | 10.

  %% 32-bit masking key sent by the client with every frame.
  -type mask_key() :: 0..16#ffffffff.

  %% Fragmentation state: the opcode of the first fragment together with
  %% the payload accumulated so far, tagged with whether the final
  %% fragment has been seen.
  -type frag_state() ::
      undefined
      | {nofin, opcode(), binary()}
      | {fin, opcode(), binary()}.

  %% The three reserved bits of a frame header.
  -type rsv() :: <<_:3>>.

  -record(state, {
      %% Middleware environment received by upgrade/4.
      env :: cowboy_middleware:env(),
      %% Socket and transport module taken from the request.
      socket = undefined :: inet:socket(),
      transport = undefined :: module(),
      %% Handler module and the options passed to its websocket_init/3.
      handler :: module(),
      handler_opts :: any(),
      %% Value of the sec-websocket-key header; cleared after the handshake.
      key = undefined :: undefined | binary(),
      %% Timeout requested by the handler and the timer currently armed for it.
      timeout = infinity :: timeout(),
      timeout_ref = undefined :: undefined | reference(),
      %% Message tags returned by Transport:messages().
      messages = undefined :: undefined | {atom(), atom(), atom()},
      %% Whether the process must be suspended before the next loop.
      hibernate = false :: boolean(),
      frag_state = undefined :: frag_state(),
      %% Incomplete UTF-8 sequence carried over between payload chunks.
      utf8_state = <<>> :: binary(),
      %% Set when the x-webkit-deflate-frame extension was negotiated.
      deflate_frame = false :: boolean(),
      %% zlib stream used to inflate incoming frames.
      inflate_state :: any(),
      %% Compressed bytes of the current frame, kept until it is complete.
      inflate_buffer = <<>> :: binary(),
      %% zlib stream used to deflate outgoing frames.
      deflate_state :: any()
  }).
  %% @doc Upgrade an HTTP request to the Websocket protocol.
  %%
  %% You do not need to call this function manually. To upgrade to the Websocket
  %% protocol, you simply need to return <em>{upgrade, protocol, {@module}}</em>
  %% in your <em>cowboy_http_handler:init/3</em> handler function.
  %%
  %% The connection is removed from the listener's connection count, the
  %% upgrade headers are validated and the handler is initialized. Any
  %% exception raised while validating the headers results in the error
  %% reply produced by upgrade_error/2.
  -spec upgrade(Req, Env, module(), any())
      -> {ok, Req, Env} | {error, 400, Req}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req(), Env::cowboy_middleware:env().
  upgrade(Req, Env, Handler, HandlerOpts) ->
      {_, Ref} = lists:keyfind(listener, 1, Env),
      ranch:remove_connection(Ref),
      [Socket, Transport] = cowboy_req:get([socket, transport], Req),
      State = #state{env=Env, socket=Socket, transport=Transport,
          handler=Handler, handler_opts=HandlerOpts},
      %% Only the header validation is protected; handler_init/2 runs in the
      %% unprotected 'of' section. Using try instead of the old-style catch
      %% also turns a thrown value into an error reply instead of letting it
      %% escape as a case_clause crash.
      try websocket_upgrade(State, Req) of
          {ok, State2, Req2} ->
              handler_init(State2, Req2)
      catch _:_ ->
          upgrade_error(Req, Env)
      end.
  %% Validate the upgrade request headers. Every check is an assertive
  %% match: a request that does not satisfy it makes this function crash,
  %% which the caller turns into an error reply.
  -spec websocket_upgrade(#state{}, Req)
      -> {ok, #state{}, Req} when Req::cowboy_req:req().
  websocket_upgrade(State, Req0) ->
      %% The connection header must list the "upgrade" token.
      {ok, ConnTokens, Req1} = cowboy_req:parse_header(<<"connection">>, Req0),
      true = lists:member(<<"upgrade">>, ConnTokens),
      %% @todo Should probably send a 426 if the Upgrade header is missing.
      {ok, [<<"websocket">>], Req2} = cowboy_req:parse_header(<<"upgrade">>, Req1),
      %% Only protocol versions 7, 8 and 13 are accepted.
      {Version, Req3} = cowboy_req:header(<<"sec-websocket-version">>, Req2),
      IntVersion = list_to_integer(binary_to_list(Version)),
      true = lists:member(IntVersion, [7, 8, 13]),
      %% The key is mandatory; it is needed to compute the handshake challenge.
      {Key, Req4} = cowboy_req:header(<<"sec-websocket-key">>, Req3),
      false = Key =:= undefined,
      Req5 = cowboy_req:set_meta(websocket_version, IntVersion, Req4),
      websocket_extensions(State#state{key=Key}, Req5).
  %% Negotiate the x-webkit-deflate-frame extension. It is enabled only when
  %% the client offers it without parameters and the request has
  %% resp_compress set to true; otherwise the state is returned untouched.
  -spec websocket_extensions(#state{}, Req)
      -> {ok, #state{}, Req} when Req::cowboy_req:req().
  websocket_extensions(State, Req) ->
      ExtName = <<"x-webkit-deflate-frame">>,
      case cowboy_req:parse_header(<<"sec-websocket-extensions">>, Req) of
          {ok, undefined, _} ->
              {ok, State, Req};
          {ok, Extensions, Req2} ->
              [Compress] = cowboy_req:get([resp_compress], Req),
              case {lists:keyfind(ExtName, 1, Extensions), Compress} of
                  {{ExtName, []}, true} ->
                      Inflate = zlib:open(),
                      Deflate = zlib:open(),
                      %% The negotiated deflate-frame is unconstrained, so
                      %% incoming frames may use the maximum window size of
                      %% 2^15 bytes. The negative value means that no zlib
                      %% headers are used.
                      ok = zlib:inflateInit(Inflate, -15),
                      %% Same window size and no zlib headers for the
                      %% outgoing stream.
                      ok = zlib:deflateInit(Deflate, best_compression,
                          deflated, -15, 8, default),
                      State2 = State#state{
                          deflate_frame = true,
                          inflate_state = Inflate,
                          inflate_buffer = <<>>,
                          deflate_state = Deflate
                      },
                      {ok, State2, Req2};
                  _ ->
                      {ok, State, Req2}
              end;
          _ ->
              {ok, State, Req}
      end.
  %% Call the handler's websocket_init/3 and continue with the handshake,
  %% applying the timeout and/or hibernate option it returned. A shutdown
  %% return ensures a 400 response was sent and closes the connection; an
  %% exception in the callback is logged and produces an upgrade error.
  -spec handler_init(#state{}, Req)
      -> {ok, Req, cowboy_middleware:env()} | {error, 400, Req}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  handler_init(State=#state{env=Env, transport=Transport,
          handler=Handler, handler_opts=Opts}, Req) ->
      try Handler:websocket_init(Transport:name(), Req, Opts) of
          {ok, Req2, HState} ->
              websocket_handshake(State, Req2, HState);
          {ok, Req2, HState, hibernate} ->
              State2 = State#state{hibernate=true},
              websocket_handshake(State2, Req2, HState);
          {ok, Req2, HState, Timeout} ->
              State2 = State#state{timeout=Timeout},
              websocket_handshake(State2, Req2, HState);
          {ok, Req2, HState, Timeout, hibernate} ->
              State2 = State#state{timeout=Timeout, hibernate=true},
              websocket_handshake(State2, Req2, HState);
          {shutdown, Req2} ->
              cowboy_req:ensure_response(Req2, 400),
              {ok, Req2, [{result, closed}|Env]}
      catch Class:Reason ->
          PLReq = cowboy_req:to_list(Req),
          Stacktrace = erlang:get_stacktrace(),
          error_logger:error_msg(
              "** Cowboy handler ~p terminating in ~p/~p~n"
              " for the reason ~p:~p~n** Options were ~p~n"
              "** Request was ~p~n** Stacktrace: ~p~n~n",
              [Handler, websocket_init, 3, Class, Reason, Opts,
                  PLReq, Stacktrace]),
          upgrade_error(Req, Env)
      end.
  %% Report a failed upgrade. If a resp_sent message is already waiting in
  %% the mailbox a response has been sent, so the connection is simply
  %% marked as closed; otherwise a 400 error is returned to the caller.
  -spec upgrade_error(Req, Env) -> {ok, Req, Env} | {error, 400, Req}
      when Req::cowboy_req:req(), Env::cowboy_middleware:env().
  upgrade_error(Req, Env) ->
      receive
          {cowboy_req, resp_sent} ->
              ClosedEnv = [{result, closed}|Env],
              {ok, Req, ClosedEnv}
      after 0 ->
          {error, 400, Req}
      end.
  %% Send the 101 reply that completes the handshake, arm the inactivity
  %% timer and enter the receive loop with an empty buffer.
  -spec websocket_handshake(#state{}, Req, any())
      -> {ok, Req, cowboy_middleware:env()}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  websocket_handshake(State=#state{transport=Transport, key=Key,
          deflate_frame=DeflateFrame}, Req, HandlerState) ->
      %% @todo Change into crypto:hash/2 for R17B+ or when supporting only R16B+.
      Digest = crypto:sha(
          << Key/binary, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" >>),
      AcceptHeader = {<<"sec-websocket-accept">>, base64:encode(Digest)},
      %% Advertise the extension only when it was negotiated.
      ExtensionHeaders = case DeflateFrame of
          true ->
              [{<<"sec-websocket-extensions">>, <<"x-webkit-deflate-frame">>}];
          false ->
              []
      end,
      Headers = [{<<"upgrade">>, <<"websocket">>}, AcceptHeader
          |ExtensionHeaders],
      {ok, Req2} = cowboy_req:upgrade_reply(101, Headers, Req),
      %% Flush the resp_sent message before moving on.
      receive {cowboy_req, resp_sent} -> ok after 0 -> ok end,
      State2 = handler_loop_timeout(State),
      %% The key is no longer needed once the challenge has been sent.
      State3 = State2#state{key=undefined, messages=Transport:messages()},
      handler_before_loop(State3, Req2, HandlerState, <<>>).
  %% Re-arm the socket for one message, then either ask the caller to
  %% suspend the process (when hibernation was requested) or enter the
  %% receive loop directly.
  -spec handler_before_loop(#state{}, Req, any(), binary())
      -> {ok, Req, cowboy_middleware:env()}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  handler_before_loop(State=#state{socket=Socket, transport=Transport,
          hibernate=Hibernate}, Req, HandlerState, SoFar) ->
      Transport:setopts(Socket, [{active, once}]),
      case Hibernate of
          true ->
              %% The flag is cleared so that the loop resumes normally.
              Args = [State#state{hibernate=false}, Req, HandlerState, SoFar],
              {suspend, ?MODULE, handler_loop, Args};
          _ ->
              handler_loop(State, Req, HandlerState, SoFar)
      end.
  %% Restart the inactivity timer. With an infinite timeout no timer is
  %% kept; otherwise any previous timer is cancelled and a new one started.
  -spec handler_loop_timeout(#state{}) -> #state{}.
  handler_loop_timeout(State=#state{timeout=infinity}) ->
      State#state{timeout_ref=undefined};
  handler_loop_timeout(State=#state{timeout=Timeout, timeout_ref=OldRef}) ->
      _ = if
          OldRef =:= undefined -> ignore;
          true -> erlang:cancel_timer(OldRef)
      end,
      NewRef = erlang:start_timer(Timeout, self(), ?MODULE),
      State#state{timeout_ref=NewRef}.
  %% @private
  %% Wait for the next event: socket data, socket close or error, the
  %% inactivity timeout, or any other message, which is forwarded to the
  %% handler's websocket_info callback.
  -spec handler_loop(#state{}, Req, any(), binary())
      -> {ok, Req, cowboy_middleware:env()}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  handler_loop(State=#state{socket=Socket,
          messages={DataTag, ClosedTag, ErrorTag}, timeout_ref=TimerRef},
          Req, HandlerState, Buffer) ->
      receive
          {DataTag, Socket, Data} ->
              %% Incoming data counts as activity: restart the timer.
              State2 = handler_loop_timeout(State),
              websocket_data(State2, Req, HandlerState,
                  << Buffer/binary, Data/binary >>);
          {ClosedTag, Socket} ->
              handler_terminate(State, Req, HandlerState, {error, closed});
          {ErrorTag, Socket, Reason} ->
              handler_terminate(State, Req, HandlerState, {error, Reason});
          {timeout, TimerRef, ?MODULE} ->
              websocket_close(State, Req, HandlerState, {normal, timeout});
          %% Timeout from a timer that has since been replaced: ignore it.
          {timeout, StaleRef, ?MODULE} when is_reference(StaleRef) ->
              handler_loop(State, Req, HandlerState, Buffer);
          Info ->
              handler_call(State, Req, HandlerState, Buffer,
                  websocket_info, Info, fun handler_before_loop/4)
      end.
  %% Parse a frame header from the buffered data. Clauses are tried in
  %% order: invalid headers close the connection with badframe, complete
  %% headers are handed to websocket_data/9 and anything else waits for
  %% more data.
  %%
  %% All frames passing through this function are considered valid,
  %% with the only exception of text and close frames with a payload
  %% which may still contain errors.
  -spec websocket_data(#state{}, Req, any(), binary())
      -> {ok, Req, cowboy_middleware:env()}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  %% RSV bits MUST be 0 unless an extension is negotiated
  %% that defines meanings for non-zero values.
  websocket_data(State=#state{deflate_frame=false}, Req, HandlerState,
          << _:1, Rsv:3, _/bits >>)
          when Rsv =/= 0 ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  %% Invalid opcode. Note that these opcodes may be used by extensions.
  websocket_data(State, Req, HandlerState, << _:4, Opcode:4, _/bits >>)
          when (Opcode > 2 andalso Opcode < 8) orelse Opcode > 10 ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  %% Control frames MUST NOT be fragmented.
  websocket_data(State, Req, HandlerState, << 0:1, _:3, Opcode:4, _/bits >>)
          when Opcode >= 8 ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  %% A frame MUST NOT use the zero opcode unless fragmentation was initiated.
  websocket_data(State=#state{frag_state=undefined}, Req, HandlerState,
          << _:4, 0:4, _/bits >>) ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  %% Non-control opcode when expecting control message or next fragment.
  websocket_data(State=#state{frag_state={nofin, _, _}}, Req, HandlerState,
          << _:4, Opcode:4, _/bits >>)
          when Opcode >= 1, Opcode =< 7 ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  %% Close control frame length MUST be 0 or >= 2.
  websocket_data(State, Req, HandlerState,
          << _:4, 8:4, _:1, 1:7, _/bits >>) ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  %% Close control frame with incomplete close code. Need more data.
  websocket_data(State, Req, HandlerState,
          Buffer = << _:4, 8:4, 1:1, PayloadLen:7, _/bits >>)
          when PayloadLen > 1, byte_size(Buffer) < 8 ->
      handler_before_loop(State, Req, HandlerState, Buffer);
  %% 7 bits payload length.
  websocket_data(State, Req, HandlerState,
          << Fin:1, Rsv:3/bits, Opcode:4, 1:1, PayloadLen:7,
              MaskKey:32, Payload/bits >>)
          when PayloadLen < 126 ->
      websocket_data(State, Req, HandlerState,
          Opcode, PayloadLen, MaskKey, Payload, Rsv, Fin);
  %% 16 bits payload length.
  websocket_data(State, Req, HandlerState,
          << Fin:1, Rsv:3/bits, Opcode:4, 1:1, 126:7, PayloadLen:16,
              MaskKey:32, Payload/bits >>)
          when PayloadLen > 125, Opcode < 8 ->
      websocket_data(State, Req, HandlerState,
          Opcode, PayloadLen, MaskKey, Payload, Rsv, Fin);
  %% 63 bits payload length.
  websocket_data(State, Req, HandlerState,
          << Fin:1, Rsv:3/bits, Opcode:4, 1:1, 127:7, 0:1, PayloadLen:63,
              MaskKey:32, Payload/bits >>)
          when PayloadLen > 16#ffff, Opcode < 8 ->
      websocket_data(State, Req, HandlerState,
          Opcode, PayloadLen, MaskKey, Payload, Rsv, Fin);
  %% When payload length is over 63 bits, the most significant bit MUST be 0.
  websocket_data(State, Req, HandlerState,
          << _:8, 1:1, 127:7, 1:1, _:7, _/binary >>) ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  %% All frames sent from the client to the server are masked.
  websocket_data(State, Req, HandlerState, << _:8, 0:1, _/bits >>) ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  %% For the next two clauses, it can be one of the following:
  %%
  %% * The minimal number of bytes MUST be used to encode the length
  %% * All control frames MUST have a payload length of 125 bytes or less
  websocket_data(State, Req, HandlerState,
          << _:9, 126:7, _:48, _/bits >>) ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  websocket_data(State, Req, HandlerState,
          << _:9, 127:7, _:96, _/bits >>) ->
      websocket_close(State, Req, HandlerState, {error, badframe});
  %% Need more data.
  websocket_data(State, Req, HandlerState, Buffer) ->
      handler_before_loop(State, Req, HandlerState, Buffer).
  %% Initialize or update the fragmentation state according to the FIN bit
  %% and the opcode of the frame, then start reading its payload.
  -spec websocket_data(#state{}, Req, any(),
      opcode(), non_neg_integer(), mask_key(), binary(), rsv(), 0 | 1)
      -> {ok, Req, cowboy_middleware:env()}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  %% First fragment: remember its opcode, because the opcode is only
  %% included in the first frame fragment.
  websocket_data(State=#state{frag_state=undefined}, Req, HandlerState,
          Opcode, Len, MaskKey, Payload, Rsv, 0) ->
      State2 = State#state{frag_state={nofin, Opcode, <<>>}},
      websocket_payload(State2, Req, HandlerState,
          0, Len, MaskKey, <<>>, Payload, Rsv);
  %% Subsequent frame fragments.
  websocket_data(State=#state{frag_state={nofin, _, _}}, Req, HandlerState,
          0, Len, MaskKey, Payload, Rsv, 0) ->
      websocket_payload(State, Req, HandlerState,
          0, Len, MaskKey, <<>>, Payload, Rsv);
  %% Final frame fragment.
  websocket_data(State=#state{frag_state={nofin, FirstOpcode, Acc}},
          Req, HandlerState, 0, Len, MaskKey, Payload, Rsv, 1) ->
      State2 = State#state{frag_state={fin, FirstOpcode, Acc}},
      websocket_payload(State2, Req, HandlerState,
          0, Len, MaskKey, <<>>, Payload, Rsv);
  %% Unfragmented frame.
  websocket_data(State, Req, HandlerState,
          Opcode, Len, MaskKey, Payload, Rsv, 1) ->
      websocket_payload(State, Req, HandlerState,
          Opcode, Len, MaskKey, <<>>, Payload, Rsv).
  %% Unmask, inflate and validate the payload of the current frame.
  %%
  %% Arguments after the handler state: the frame opcode, the number of
  %% payload bytes still expected, the frame's mask key, the payload already
  %% unmasked, the newly received data and the frame's RSV bits. When the
  %% data does not yet hold the whole payload the function continues in
  %% websocket_payload_loop/8; otherwise the frame is dispatched and the
  %% bytes following it are kept for the next frame.
  -spec websocket_payload(#state{}, Req, any(),
      opcode(), non_neg_integer(), mask_key(), binary(), binary(), rsv())
      -> {ok, Req, cowboy_middleware:env()}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  %% Close control frames with a payload MUST contain a valid close code.
  %% The Len guard prevents an empty close frame from interpreting bytes
  %% that follow it in the buffer as its close code.
  websocket_payload(State, Req, HandlerState,
          Opcode=8, Len, MaskKey, <<>>, << MaskedCode:2/binary, Rest/bits >>, Rsv)
          when Len >= 2 ->
      Unmasked = << Code:16 >> = websocket_unmask(MaskedCode, MaskKey, <<>>),
      if
          Code < 1000; Code =:= 1004; Code =:= 1005; Code =:= 1006;
                  (Code > 1011) and (Code < 3000); Code > 4999 ->
              websocket_close(State, Req, HandlerState, {error, badframe});
          true ->
              websocket_payload(State, Req, HandlerState,
                  Opcode, Len - 2, MaskKey, Unmasked, Rest, Rsv)
      end;
  %% Text frames and close control frames MUST have a payload that is valid UTF-8.
  %% Partial payload: validate what we have and wait for the rest.
  websocket_payload(State=#state{utf8_state=Incomplete},
          Req, HandlerState, Opcode, Len, MaskKey, Unmasked, Data, Rsv)
          when (byte_size(Data) < Len) andalso ((Opcode =:= 1) orelse
              ((Opcode =:= 8) andalso (Unmasked =/= <<>>))) ->
      Unmasked2 = websocket_unmask(Data,
          payload_mask_key(MaskKey, Unmasked, State), <<>>),
      {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State),
      case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of
          false ->
              websocket_close(State2, Req, HandlerState, {error, badencoding});
          Utf8State ->
              websocket_payload_loop(State2#state{utf8_state=Utf8State},
                  Req, HandlerState, Opcode, Len - byte_size(Data), MaskKey,
                  << Unmasked/binary, Unmasked3/binary >>, Rsv)
      end;
  %% Complete payload: it must end on a codepoint boundary.
  websocket_payload(State=#state{utf8_state=Incomplete},
          Req, HandlerState, Opcode, Len, MaskKey, Unmasked, Data, Rsv)
          when Opcode =:= 1; (Opcode =:= 8) and (Unmasked =/= <<>>) ->
      << End:Len/binary, Rest/bits >> = Data,
      Unmasked2 = websocket_unmask(End,
          payload_mask_key(MaskKey, Unmasked, State), <<>>),
      {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State),
      case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of
          <<>> ->
              websocket_dispatch(State2#state{utf8_state= <<>>},
                  Req, HandlerState, Rest, Opcode,
                  << Unmasked/binary, Unmasked3/binary >>);
          _ ->
              websocket_close(State2, Req, HandlerState, {error, badencoding})
      end;
  %% Fragmented text frames may cut payload in the middle of UTF-8 codepoints.
  websocket_payload(State=#state{frag_state={_, 1, _}, utf8_state=Incomplete},
          Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, Data, Rsv)
          when byte_size(Data) < Len ->
      Unmasked2 = websocket_unmask(Data,
          payload_mask_key(MaskKey, Unmasked, State), <<>>),
      {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State),
      case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of
          false ->
              websocket_close(State2, Req, HandlerState, {error, badencoding});
          Utf8State ->
              websocket_payload_loop(State2#state{utf8_state=Utf8State},
                  Req, HandlerState, Opcode, Len - byte_size(Data), MaskKey,
                  << Unmasked/binary, Unmasked3/binary >>, Rsv)
      end;
  websocket_payload(State=#state{frag_state={Fin, 1, _}, utf8_state=Incomplete},
          Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, Data, Rsv) ->
      << End:Len/binary, Rest/bits >> = Data,
      Unmasked2 = websocket_unmask(End,
          payload_mask_key(MaskKey, Unmasked, State), <<>>),
      {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State),
      case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of
          <<>> ->
              websocket_dispatch(State2#state{utf8_state= <<>>},
                  Req, HandlerState, Rest, Opcode,
                  << Unmasked/binary, Unmasked3/binary >>);
          %% An incomplete codepoint is only acceptable when more
          %% fragments are expected.
          Utf8State when is_binary(Utf8State), Fin =:= nofin ->
              websocket_dispatch(State2#state{utf8_state=Utf8State},
                  Req, HandlerState, Rest, Opcode,
                  << Unmasked/binary, Unmasked3/binary >>);
          _ ->
              websocket_close(State2, Req, HandlerState, {error, badencoding})
      end;
  %% Other frames have a binary payload.
  websocket_payload(State, Req, HandlerState,
          Opcode, Len, MaskKey, Unmasked, Data, Rsv)
          when byte_size(Data) < Len ->
      Unmasked2 = websocket_unmask(Data,
          payload_mask_key(MaskKey, Unmasked, State), Unmasked),
      {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State),
      websocket_payload_loop(State2, Req, HandlerState,
          Opcode, Len - byte_size(Data), MaskKey, Unmasked3, Rsv);
  websocket_payload(State, Req, HandlerState,
          Opcode, Len, MaskKey, Unmasked, Data, Rsv) ->
      << End:Len/binary, Rest/bits >> = Data,
      Unmasked2 = websocket_unmask(End,
          payload_mask_key(MaskKey, Unmasked, State), Unmasked),
      {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State),
      websocket_dispatch(State2, Req, HandlerState, Rest, Opcode, Unmasked3).

  %% Mask key rotated to the position of the next payload byte of the
  %% current frame. Bytes of a compressed frame are moved to the inflate
  %% buffer as they arrive and do not appear in Unmasked until the frame is
  %% complete, so both sizes are needed to know how many payload bytes of
  %% this frame were already unmasked.
  -spec payload_mask_key(mask_key(), binary(), #state{}) -> mask_key().
  payload_mask_key(MaskKey, Unmasked, #state{inflate_buffer=Buffer}) ->
      rotate_mask_key(MaskKey, byte_size(Unmasked) + byte_size(Buffer)).
  %% Inflate the payload of a frame when needed. The boolean tells whether
  %% Data completes the payload of the frame.
  -spec websocket_inflate_frame(binary(), rsv(), boolean(), #state{}) ->
      {binary(), #state{}}.
  %% Nothing to do when the extension is off or the first RSV bit is unset.
  websocket_inflate_frame(Data, << Rsv1:1, _:2 >>, _,
          State=#state{deflate_frame=Enabled})
          when Enabled =:= false orelse Rsv1 =:= 0 ->
      {Data, State};
  %% Incomplete compressed payload: keep the bytes and return nothing yet.
  websocket_inflate_frame(Data, << 1:1, _:2 >>, false,
          State=#state{inflate_buffer=Pending}) ->
      Pending2 = << Pending/binary, Data/binary >>,
      {<<>>, State#state{inflate_buffer=Pending2}};
  %% Complete compressed payload: append the 00 00 FF FF trailer to the
  %% buffered bytes, inflate everything and empty the buffer.
  websocket_inflate_frame(Data, << 1:1, _:2 >>, true,
          State=#state{inflate_state=Inflate, inflate_buffer=Pending}) ->
      Compressed = << Pending/binary, Data/binary, 0, 0, 255, 255 >>,
      Inflated = iolist_to_binary(zlib:inflate(Inflate, Compressed)),
      {Inflated, State#state{inflate_buffer = <<>>}}.
  %% XOR the data with the mask key, four bytes at a time, appending the
  %% result to the accumulator. A trailing chunk of one to three bytes is
  %% combined with the leading bytes of the key.
  -spec websocket_unmask(B, mask_key(), B) -> B when B::binary().
  websocket_unmask(<<>>, _, Acc) ->
      Acc;
  websocket_unmask(<< Word:32, Rest/bits >>, MaskKey, Acc) ->
      websocket_unmask(Rest, MaskKey, << Acc/binary, (Word bxor MaskKey):32 >>);
  websocket_unmask(Tail, MaskKey, Acc) when is_binary(Tail) ->
      %% Tail holds 1, 2 or 3 bytes at this point.
      Bits = bit_size(Tail),
      << Chunk:Bits >> = Tail,
      << KeyPart:Bits, _/bits >> = << MaskKey:32 >>,
      << Acc/binary, (Chunk bxor KeyPart):Bits >>.
  %% Because we unmask on the fly we need to continue from the right mask byte.
  %%
  %% Rotates the 32-bit mask key left by (UnmaskedLen rem 4) bytes, so that
  %% its first byte is the one to apply to the next payload byte. The result
  %% is truncated to 32 bits so that it stays within mask_key().
  -spec rotate_mask_key(mask_key(), non_neg_integer()) -> mask_key().
  rotate_mask_key(MaskKey, UnmaskedLen) ->
      LeftBits = (UnmaskedLen rem 4) * 8,
      RightBits = 32 - LeftBits,
      ((MaskKey bsl LeftBits) bor (MaskKey bsr RightBits)) band 16#ffffffff.
  %% Incremental UTF-8 validation. Returns <<>> when the whole argument is
  %% valid UTF-8, false when it is not, or the trailing bytes of a sequence
  %% that is cut short and needs more data to be judged.
  -spec is_utf8(binary()) -> false | binary().
  is_utf8(<<>>) ->
      <<>>;
  %% Skip one complete, valid codepoint.
  is_utf8(<< _/utf8, Tail/binary >>) ->
      is_utf8(Tail);
  %% 2 bytes. Codepages C0 and C1 are invalid; fail early.
  is_utf8(<< 2#1100000:7, _/bits >>) ->
      false;
  is_utf8(Partial = << 2#110:3, _:5 >>) ->
      Partial;
  %% 3 bytes: lead byte alone, or followed by one continuation byte.
  is_utf8(Partial = << 2#1110:4, _:4 >>) ->
      Partial;
  is_utf8(Partial = << 2#1110:4, _:4, 2#10:2, _:6 >>) ->
      Partial;
  %% 4 bytes. Codepage F4 may have invalid values greater than 0x10FFFF.
  is_utf8(<< 2#11110100:8, 2#10:2, High:6, _/bits >>) when High >= 2#10000 ->
      false;
  %% Lead byte alone, or followed by one or two continuation bytes.
  is_utf8(Partial = << 2#11110:5, _:3 >>) ->
      Partial;
  is_utf8(Partial = << 2#11110:5, _:3, 2#10:2, _:6 >>) ->
      Partial;
  is_utf8(Partial = << 2#11110:5, _:3, 2#10:2, _:6, 2#10:2, _:6 >>) ->
      Partial;
  %% Invalid.
  is_utf8(_) ->
      false.
  %% Wait for the rest of the payload of the current frame. Behaves like
  %% the main receive loop, except that incoming data continues the current
  %% payload and other messages resume this loop once the handler has
  %% processed them.
  -spec websocket_payload_loop(#state{}, Req, any(),
      opcode(), non_neg_integer(), mask_key(), binary(), rsv())
      -> {ok, Req, cowboy_middleware:env()}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  websocket_payload_loop(State=#state{socket=Socket, transport=Transport,
          messages={DataTag, ClosedTag, ErrorTag}, timeout_ref=TimerRef},
          Req, HandlerState, Opcode, Len, MaskKey, Unmasked, Rsv) ->
      Transport:setopts(Socket, [{active, once}]),
      receive
          {DataTag, Socket, Data} ->
              %% Incoming data counts as activity: restart the timer.
              State2 = handler_loop_timeout(State),
              websocket_payload(State2, Req, HandlerState,
                  Opcode, Len, MaskKey, Unmasked, Data, Rsv);
          {ClosedTag, Socket} ->
              handler_terminate(State, Req, HandlerState, {error, closed});
          {ErrorTag, Socket, Reason} ->
              handler_terminate(State, Req, HandlerState, {error, Reason});
          {timeout, TimerRef, ?MODULE} ->
              websocket_close(State, Req, HandlerState, {normal, timeout});
          %% Timeout from a timer that has since been replaced: ignore it.
          {timeout, StaleRef, ?MODULE} when is_reference(StaleRef) ->
              websocket_payload_loop(State, Req, HandlerState,
                  Opcode, Len, MaskKey, Unmasked, Rsv);
          Info ->
              %% Come back to this loop with the updated state once the
              %% handler has processed the message.
              Resume = fun (State2, Req2, HandlerState2, _) ->
                  websocket_payload_loop(State2, Req2, HandlerState2,
                      Opcode, Len, MaskKey, Unmasked, Rsv)
              end,
              handler_call(State, Req, HandlerState,
                  <<>>, websocket_info, Info, Resume)
      end.
  %% Act on a complete frame: accumulate fragments, forward data, ping and
  %% pong frames to the handler's websocket_handle callback, answer pings
  %% and close the connection on close frames.
  -spec websocket_dispatch(#state{}, Req, any(), binary(), opcode(), binary())
      -> {ok, Req, cowboy_middleware:env()}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  %% Continuation frame.
  websocket_dispatch(State=#state{frag_state={nofin, FirstOpcode, Acc}},
          Req, HandlerState, Rest, 0, Payload) ->
      Acc2 = << Acc/binary, Payload/binary >>,
      websocket_data(State#state{frag_state={nofin, FirstOpcode, Acc2}},
          Req, HandlerState, Rest);
  %% Last continuation frame: dispatch the whole message under the opcode
  %% of its first fragment.
  websocket_dispatch(State=#state{frag_state={fin, FirstOpcode, Acc}},
          Req, HandlerState, Rest, 0, Payload) ->
      Message = << Acc/binary, Payload/binary >>,
      websocket_dispatch(State#state{frag_state=undefined}, Req, HandlerState,
          Rest, FirstOpcode, Message);
  %% Text frame.
  websocket_dispatch(State, Req, HandlerState, Rest, 1, Payload) ->
      handler_call(State, Req, HandlerState, Rest,
          websocket_handle, {text, Payload}, fun websocket_data/4);
  %% Binary frame.
  websocket_dispatch(State, Req, HandlerState, Rest, 2, Payload) ->
      handler_call(State, Req, HandlerState, Rest,
          websocket_handle, {binary, Payload}, fun websocket_data/4);
  %% Close control frame, without and with a close code.
  websocket_dispatch(State, Req, HandlerState, _Rest, 8, <<>>) ->
      websocket_close(State, Req, HandlerState, {remote, closed});
  websocket_dispatch(State, Req, HandlerState, _Rest, 8,
          << Code:16, Reason/bits >>) ->
      websocket_close(State, Req, HandlerState, {remote, Code, Reason});
  %% Ping control frame. Send a pong back and forward the ping to the handler.
  websocket_dispatch(State=#state{socket=Socket, transport=Transport},
          Req, HandlerState, Rest, 9, Payload) ->
      LenBits = payload_length_to_binary(byte_size(Payload)),
      %% FIN set, RSV clear, opcode 10 (pong), unmasked.
      Pong = << 1:1, 0:3, 10:4, 0:1, LenBits/bits, Payload/binary >>,
      Transport:send(Socket, Pong),
      handler_call(State, Req, HandlerState, Rest,
          websocket_handle, {ping, Payload}, fun websocket_data/4);
  %% Pong control frame.
  websocket_dispatch(State, Req, HandlerState, Rest, 10, Payload) ->
      handler_call(State, Req, HandlerState, Rest,
          websocket_handle, {pong, Payload}, fun websocket_data/4).
  %% Invoke a handler callback (websocket_handle or websocket_info) with
  %% Message and act on what it returns:
  %%
  %% * {ok, ...} continues with NextState, optionally requesting hibernation;
  %% * {reply, ...} first sends one frame or a list of frames;
  %% * {shutdown, ...} closes the connection.
  %%
  %% An exception raised by the callback is logged and the connection is
  %% closed with the {error, handler} reason. NextState is the continuation
  %% called with the state, request, handler state and remaining data.
  -spec handler_call(#state{}, Req, any(), binary(), atom(), any(), fun())
      -> {ok, Req, cowboy_middleware:env()}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  handler_call(State=#state{handler=Handler, handler_opts=HandlerOpts}, Req,
          HandlerState, RemainingData, Callback, Message, NextState) ->
      try Handler:Callback(Message, Req, HandlerState) of
          {ok, Req2, HandlerState2} ->
              NextState(State, Req2, HandlerState2, RemainingData);
          {ok, Req2, HandlerState2, hibernate} ->
              NextState(State#state{hibernate=true},
                  Req2, HandlerState2, RemainingData);
          %% A reply is a list of frames or a single frame. Bare atom frames
          %% (close, ping, pong) are part of frame() and are accepted too.
          {reply, Payload, Req2, HandlerState2}
                  when is_list(Payload); is_tuple(Payload); is_atom(Payload) ->
              handler_call_reply(State, Req2, HandlerState2, RemainingData,
                  Payload, false, NextState);
          {reply, Payload, Req2, HandlerState2, hibernate}
                  when is_list(Payload); is_tuple(Payload); is_atom(Payload) ->
              handler_call_reply(State, Req2, HandlerState2, RemainingData,
                  Payload, true, NextState);
          {shutdown, Req2, HandlerState2} ->
              websocket_close(State, Req2, HandlerState2, {normal, shutdown})
      catch Class:Reason ->
          PLReq = cowboy_req:to_list(Req),
          error_logger:error_msg(
              "** Cowboy handler ~p terminating in ~p/~p~n"
              " for the reason ~p:~p~n** Message was ~p~n"
              "** Options were ~p~n** Handler state was ~p~n"
              "** Request was ~p~n** Stacktrace: ~p~n~n",
              [Handler, Callback, 3, Class, Reason, Message, HandlerOpts,
                  HandlerState, PLReq, erlang:get_stacktrace()]),
          websocket_close(State, Req, HandlerState, {error, handler})
      end.

  %% Send the frame(s) returned by a handler callback, then continue with
  %% NextState, or terminate if a close frame was sent or sending failed.
  %% Hibernate tells whether the callback requested hibernation; when it did
  %% not, the flag already present in the state is left untouched.
  -spec handler_call_reply(#state{}, Req, any(), binary(),
      frame() | [frame()], boolean(), fun())
      -> {ok, Req, cowboy_middleware:env()}
      | {suspend, module(), atom(), [any()]}
      when Req::cowboy_req:req().
  handler_call_reply(State, Req, HandlerState, RemainingData,
          Payload, Hibernate, NextState) ->
      SendResult = case is_list(Payload) of
          true -> websocket_send_many(Payload, State);
          false -> websocket_send(Payload, State)
      end,
      case SendResult of
          {ok, State2} when Hibernate ->
              NextState(State2#state{hibernate=true},
                  Req, HandlerState, RemainingData);
          {ok, State2} ->
              NextState(State2, Req, HandlerState, RemainingData);
          {shutdown, State2} ->
              handler_terminate(State2, Req, HandlerState, {normal, shutdown});
          {{error, _} = Error, State2} ->
              handler_terminate(State2, Req, HandlerState, Error)
      end.
  %% Map a frame type to its numeric opcode.
  websocket_opcode(Type) ->
      case Type of
          text -> 1;
          binary -> 2;
          close -> 8;
          ping -> 9;
          pong -> 10
      end.
  %% Compress the payload of an outgoing frame when the extension is
  %% enabled. Returns the payload to send, the RSV bits to set in the frame
  %% header and the state.
  -spec websocket_deflate_frame(opcode(), binary(), #state{}) -> {binary(), <<_:3>>, #state{}}.
  %% Control frames, and every frame when the extension is off, are sent as is.
  websocket_deflate_frame(Opcode, Payload, State=#state{deflate_frame=Enabled})
          when Enabled =:= false orelse Opcode >= 8 ->
      {Payload, << 0:3 >>, State};
  websocket_deflate_frame(_, Payload, State=#state{deflate_state=Deflate}) ->
      Compressed = iolist_to_binary(zlib:deflate(Deflate, Payload, sync)),
      BodyLen = byte_size(Compressed) - 4,
      %% Drop the 00 00 FF FF trailer when the output ends with it.
      Body = case Compressed of
          << Head:BodyLen/binary, 0, 0, 255, 255 >> -> Head;
          _ -> Compressed
      end,
      %% The first RSV bit marks the frame as compressed.
      {Body, << 1:1, 0:2 >>, State}.
  652. -spec websocket_send(frame(), #state{})
  653. -> {ok, #state{}} | {shutdown, #state{}} | {{error, atom()}, #state{}}.
  654. websocket_send(Type, State=#state{socket=Socket, transport=Transport})
  655. when Type =:= close ->
  656. Opcode = websocket_opcode(Type),
  657. case Transport:send(Socket, << 1:1, 0:3, Opcode:4, 0:8 >>) of
  658. ok -> {shutdown, State};
  659. Error -> {Error, State}
  660. end;
  661. websocket_send(Type, State=#state{socket=Socket, transport=Transport})
  662. when Type =:= ping; Type =:= pong ->
  663. Opcode = websocket_opcode(Type),
  664. {Transport:send(Socket, << 1:1, 0:3, Opcode:4, 0:8 >>), State};
  665. websocket_send({close, Payload}, State) ->
  666. websocket_send({close, 1000, Payload}, State);
  667. websocket_send({Type = close, StatusCode, Payload}, State=#state{
  668. socket=Socket, transport=Transport}) ->
  669. Opcode = websocket_opcode(Type),
  670. Len = 2 + iolist_size(Payload),
  671. %% Control packets must not be > 125 in length.
  672. true = Len =< 125,
  673. BinLen = payload_length_to_binary(Len),
  674. Transport:send(Socket,
  675. [<< 1:1, 0:3, Opcode:4, 0:1, BinLen/bits, StatusCode:16 >>, Payload]),
  676. {shutdown, State};
  677. websocket_send({Type, Payload0}, State=#state{socket=Socket, transport=Transport}) ->
  678. Opcode = websocket_opcode(Type),
  679. {Payload, Rsv, State2} = websocket_deflate_frame(Opcode, iolist_to_binary(Payload0), State),
  680. Len = iolist_size(Payload),
  681. %% Control packets must not be > 125 in length.
  682. true = if Type =:= ping; Type =:= pong ->
  683. Len =< 125;
  684. true ->
  685. true
  686. end,
  687. BinLen = payload_length_to_binary(Len),
  688. {Transport:send(Socket,
  689. [<< 1:1, Rsv/bits, Opcode:4, 0:1, BinLen/bits >>, Payload]), State2}.
  690. -spec websocket_send_many([frame()], #state{})
  691. -> {ok, #state{}} | {shutdown, #state{}} | {{error, atom()}, #state{}}.
  692. websocket_send_many([], State) ->
  693. {ok, State};
  694. websocket_send_many([Frame|Tail], State) ->
  695. case websocket_send(Frame, State) of
  696. {ok, State2} -> websocket_send_many(Tail, State2);
  697. {shutdown, State2} -> {shutdown, State2};
  698. {Error, State2} -> {Error, State2}
  699. end.
  700. -spec websocket_close(#state{}, Req, any(),
  701. {atom(), atom()} | {remote, close_code(), binary()})
  702. -> {ok, Req, cowboy_middleware:env()}
  703. when Req::cowboy_req:req().
  704. websocket_close(State=#state{socket=Socket, transport=Transport},
  705. Req, HandlerState, Reason) ->
  706. case Reason of
  707. {normal, _} ->
  708. Transport:send(Socket, << 1:1, 0:3, 8:4, 0:1, 2:7, 1000:16 >>);
  709. {error, badframe} ->
  710. Transport:send(Socket, << 1:1, 0:3, 8:4, 0:1, 2:7, 1002:16 >>);
  711. {error, badencoding} ->
  712. Transport:send(Socket, << 1:1, 0:3, 8:4, 0:1, 2:7, 1007:16 >>);
  713. {error, handler} ->
  714. Transport:send(Socket, << 1:1, 0:3, 8:4, 0:1, 2:7, 1011:16 >>);
  715. {remote, closed} ->
  716. Transport:send(Socket, << 1:1, 0:3, 8:4, 0:8 >>);
  717. {remote, Code, _} ->
  718. Transport:send(Socket, << 1:1, 0:3, 8:4, 0:1, 2:7, Code:16 >>)
  719. end,
  720. handler_terminate(State, Req, HandlerState, Reason).
  721. -spec handler_terminate(#state{}, Req, any(), atom() | {atom(), atom()})
  722. -> {ok, Req, cowboy_middleware:env()}
  723. when Req::cowboy_req:req().
  724. handler_terminate(#state{env=Env, handler=Handler, handler_opts=HandlerOpts},
  725. Req, HandlerState, TerminateReason) ->
  726. try
  727. Handler:websocket_terminate(TerminateReason, Req, HandlerState)
  728. catch Class:Reason ->
  729. PLReq = cowboy_req:to_list(Req),
  730. error_logger:error_msg(
  731. "** Cowboy handler ~p terminating in ~p/~p~n"
  732. " for the reason ~p:~p~n** Initial reason was ~p~n"
  733. "** Options were ~p~n** Handler state was ~p~n"
  734. "** Request was ~p~n** Stacktrace: ~p~n~n",
  735. [Handler, websocket_terminate, 3, Class, Reason, TerminateReason,
  736. HandlerOpts, HandlerState, PLReq, erlang:get_stacktrace()])
  737. end,
  738. {ok, Req, [{result, closed}|Env]}.
  739. -spec payload_length_to_binary(0..16#7fffffffffffffff)
  740. -> << _:7 >> | << _:23 >> | << _:71 >>.
  741. payload_length_to_binary(N) ->
  742. case N of
  743. N when N =< 125 -> << N:7 >>;
  744. N when N =< 16#ffff -> << 126:7, N:16 >>;
  745. N when N =< 16#7fffffffffffffff -> << 127:7, N:64 >>
  746. end.