cowboy_websocket.erl 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638
  1. %% Copyright (c) 2011-2017, Loïc Hoguin <essen@ninenines.eu>
  2. %%
  3. %% Permission to use, copy, modify, and/or distribute this software for any
  4. %% purpose with or without fee is hereby granted, provided that the above
  5. %% copyright notice and this permission notice appear in all copies.
  6. %%
  7. %% THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  8. %% WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  9. %% MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
  10. %% ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  11. %% WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
  12. %% ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
  13. %% OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  14. %% Cowboy supports versions 7 through 17 of the Websocket drafts.
  15. %% It also supports RFC6455, the proposed standard for Websocket.
  16. -module(cowboy_websocket).
  17. -behaviour(cowboy_sub_protocol).
  18. -ifdef(OTP_RELEASE).
  19. -compile({nowarn_deprecated_function, [{erlang, get_stacktrace, 0}]}).
  20. -endif.
  21. -export([is_upgrade_request/1]).
  22. -export([upgrade/4]).
  23. -export([upgrade/5]).
  24. -export([takeover/7]).
  25. -export([loop/3]).
  26. -export([system_continue/3]).
  27. -export([system_terminate/4]).
  28. -export([system_code_change/4]).
%% Commands a handler callback may return. Plain cow_ws frames are
%% encoded and sent to the client; the tuples are interpreted by
%% commands/3: {active, Bool} toggles whether more data is requested,
%% {deflate, Bool} toggles compression of outgoing frames, and
%% {set_options, Map} updates options at runtime (only idle_timeout is
%% currently applied).
-type commands() :: [cow_ws:frame()
    | {active, boolean()}
    | {deflate, boolean()}
    | {set_options, map()}
].
-export_type([commands/0]).

%% Current return value of the handler callbacks. The 'hibernate' form
%% asks for the process to hibernate before waiting for more data.
-type call_result(State) :: {commands(), State} | {commands(), State, hibernate}.

%% Older return values, still accepted by handler_call/6 for backward
%% compatibility.
-type deprecated_call_result(State) :: {ok, State}
    | {ok, State, hibernate}
    | {reply, cow_ws:frame() | [cow_ws:frame()], State}
    | {reply, cow_ws:frame() | [cow_ws:frame()], State, hibernate}
    | {stop, State}.

%% Reason handed to cowboy_handler:terminate/4 when the connection ends.
-type terminate_reason() :: normal | stop | timeout
    | remote | {remote, cow_ws:close_code(), binary()}
    | {error, badencoding | badframe | closed | atom()}
    | {crash, error | exit | throw, any()}.
%% Regular cowboy_handler entry point, run before the upgrade. The
%% 4-tuple form carries an extra term (presumably the opts() later given
%% to upgrade/5 -- confirm in cowboy_handler).
-callback init(Req, any())
    -> {ok | module(), Req, any()}
    | {module(), Req, any(), any()}
    when Req::cowboy_req:req().

%% Optional. Called once in the Websocket process right after takeover,
%% when the handler exports it.
-callback websocket_init(State)
    -> call_result(State) | deprecated_call_result(State) when State::any().
-optional_callbacks([websocket_init/1]).

%% Called for every complete data or ping/pong frame received from the
%% client (close frames are handled internally).
-callback websocket_handle(ping | pong | {text | binary | ping | pong, binary()}, State)
    -> call_result(State) | deprecated_call_result(State) when State::any().

%% Called for every Erlang message the receive loop does not handle itself.
-callback websocket_info(any(), State)
    -> call_result(State) | deprecated_call_result(State) when State::any().

%% Optional. Invoked through cowboy_handler:terminate/4 when the
%% connection ends, with a terminate_reason().
-callback terminate(any(), cowboy_req:req(), any()) -> ok.
-optional_callbacks([terminate/3]).
%% Options accepted by upgrade/5:
%%   compress       - negotiate permessage-deflate / x-webkit-deflate-frame
%%                    when the client offers them (default false).
%%   deflate_opts   - options passed to the cow_ws deflate negotiation
%%                    functions (default #{}).
%%   idle_timeout   - milliseconds without incoming data before the
%%                    connection is closed (default 60000; infinity disables).
%%   max_frame_size - maximum payload size of a frame, or of a reassembled
%%                    fragmented message; larger ones close the connection
%%                    with {error, badsize} (default infinity).
%%   req_filter     - fun selecting what is kept of the Req for the
%%                    handler's terminate callback (default keeps method,
%%                    version, scheme, host, port, path, qs and peer).
-type opts() :: #{
    compress => boolean(),
    deflate_opts => cow_ws:deflate_opts(),
    idle_timeout => timeout(),
    max_frame_size => non_neg_integer() | infinity,
    req_filter => fun((cowboy_req:req()) -> map())
}.
-export_type([opts/0]).
%% State of the Websocket process.
-record(state, {
    %% Process given as first argument to takeover/7; an 'EXIT' from it
    %% ends the loop.
    parent :: undefined | pid(),
    %% Ranch listener reference (used for ranch:remove_connection/1).
    ref :: ranch:ref(),
    %% Raw socket for HTTP/1.1, or {ConnPid, StreamID} for HTTP/2.
    socket = undefined :: inet:socket() | {pid(), cowboy_stream:streamid()} | undefined,
    %% Ranch transport module; undefined for HTTP/2 streams.
    transport = undefined :: module() | undefined,
    opts = #{} :: opts(),
    %% When false, before_loop/3 does not request more data.
    active = true :: boolean(),
    handler :: module(),
    %% Client's sec-websocket-key (HTTP/1.1 handshake only); cleared on takeover.
    key = undefined :: undefined | binary(),
    %% Currently armed idle timer, see loop_timeout/1.
    timeout_ref = undefined :: undefined | reference(),
    %% {OK, Closed, Error} message tags from Transport:messages().
    messages = undefined :: undefined | {atom(), atom(), atom()},
    %% Set when a handler result asks to hibernate; reset by before_loop/3.
    hibernate = false :: boolean(),
    %% Fragmentation state returned by cow_ws:parse_header/3 and the
    %% payload of the fragments received so far.
    frag_state = undefined :: cow_ws:frag_state(),
    frag_buffer = <<>> :: binary(),
    %% Incremental UTF-8 state threaded through cow_ws:parse_payload/9
    %% (presumably for text frames -- confirm in cow_ws).
    utf8_state = 0 :: cow_ws:utf8_state(),
    %% Toggled by the {deflate, Bool} command; false disables compression
    %% of outgoing frames in frame/2.
    deflate = true :: boolean(),
    %% Extensions negotiated during the handshake.
    extensions = #{} :: map(),
    %% Filtered copy of the request, given to the handler's terminate.
    req = #{} :: map()
}).
  85. %% Because the HTTP/1.1 and HTTP/2 handshakes are so different,
  86. %% this function is necessary to figure out whether a request
  87. %% is trying to upgrade to the Websocket protocol.
  88. -spec is_upgrade_request(cowboy_req:req()) -> boolean().
  89. is_upgrade_request(#{version := 'HTTP/2', method := <<"CONNECT">>, protocol := Protocol}) ->
  90. <<"websocket">> =:= cowboy_bstr:to_lower(Protocol);
  91. is_upgrade_request(Req=#{version := 'HTTP/1.1', method := <<"GET">>}) ->
  92. ConnTokens = cowboy_req:parse_header(<<"connection">>, Req, []),
  93. case lists:member(<<"upgrade">>, ConnTokens) of
  94. false ->
  95. false;
  96. true ->
  97. UpgradeTokens = cowboy_req:parse_header(<<"upgrade">>, Req),
  98. lists:member(<<"websocket">>, UpgradeTokens)
  99. end;
  100. is_upgrade_request(_) ->
  101. false.
  102. %% Stream process.
  103. -spec upgrade(Req, Env, module(), any())
  104. -> {ok, Req, Env}
  105. when Req::cowboy_req:req(), Env::cowboy_middleware:env().
  106. upgrade(Req, Env, Handler, HandlerState) ->
  107. upgrade(Req, Env, Handler, HandlerState, #{}).
  108. -spec upgrade(Req, Env, module(), any(), opts())
  109. -> {ok, Req, Env}
  110. when Req::cowboy_req:req(), Env::cowboy_middleware:env().
  111. %% @todo Immediately crash if a response has already been sent.
  112. upgrade(Req0=#{version := Version}, Env, Handler, HandlerState, Opts) ->
  113. FilteredReq = case maps:get(req_filter, Opts, undefined) of
  114. undefined -> maps:with([method, version, scheme, host, port, path, qs, peer], Req0);
  115. FilterFun -> FilterFun(Req0)
  116. end,
  117. State0 = #state{opts=Opts, handler=Handler, req=FilteredReq},
  118. try websocket_upgrade(State0, Req0) of
  119. {ok, State, Req} ->
  120. websocket_handshake(State, Req, HandlerState, Env);
  121. %% The status code 426 is specific to HTTP/1.1 connections.
  122. {error, upgrade_required} when Version =:= 'HTTP/1.1' ->
  123. {ok, cowboy_req:reply(426, #{
  124. <<"connection">> => <<"upgrade">>,
  125. <<"upgrade">> => <<"websocket">>
  126. }, Req0), Env};
  127. %% Use a generic 400 error for HTTP/2.
  128. {error, upgrade_required} ->
  129. {ok, cowboy_req:reply(400, Req0), Env}
  130. catch _:_ ->
  131. %% @todo Probably log something here?
  132. %% @todo Test that we can have 2 /ws 400 status code in a row on the same connection.
  133. %% @todo Does this even work?
  134. {ok, cowboy_req:reply(400, Req0), Env}
  135. end.
  136. websocket_upgrade(State, Req=#{version := Version}) ->
  137. case is_upgrade_request(Req) of
  138. false ->
  139. {error, upgrade_required};
  140. true when Version =:= 'HTTP/1.1' ->
  141. Key = cowboy_req:header(<<"sec-websocket-key">>, Req),
  142. false = Key =:= undefined,
  143. websocket_version(State#state{key=Key}, Req);
  144. true ->
  145. websocket_version(State, Req)
  146. end.
  147. websocket_version(State, Req) ->
  148. WsVersion = cowboy_req:parse_header(<<"sec-websocket-version">>, Req),
  149. case WsVersion of
  150. 7 -> ok;
  151. 8 -> ok;
  152. 13 -> ok
  153. end,
  154. websocket_extensions(State, Req#{websocket_version => WsVersion}).
  155. websocket_extensions(State=#state{opts=Opts}, Req) ->
  156. %% @todo We want different options for this. For example
  157. %% * compress everything auto
  158. %% * compress only text auto
  159. %% * compress only binary auto
  160. %% * compress nothing auto (but still enabled it)
  161. %% * disable compression
  162. Compress = maps:get(compress, Opts, false),
  163. case {Compress, cowboy_req:parse_header(<<"sec-websocket-extensions">>, Req)} of
  164. {true, Extensions} when Extensions =/= undefined ->
  165. websocket_extensions(State, Req, Extensions, []);
  166. _ ->
  167. {ok, State, Req}
  168. end.
  169. websocket_extensions(State, Req, [], []) ->
  170. {ok, State, Req};
  171. websocket_extensions(State, Req, [], [<<", ">>|RespHeader]) ->
  172. {ok, State, cowboy_req:set_resp_header(<<"sec-websocket-extensions">>, lists:reverse(RespHeader), Req)};
  173. %% For HTTP/2 we ARE on the controlling process and do NOT want to update the owner.
  174. websocket_extensions(State=#state{opts=Opts, extensions=Extensions},
  175. Req=#{pid := Pid, version := Version},
  176. [{<<"permessage-deflate">>, Params}|Tail], RespHeader) ->
  177. DeflateOpts0 = maps:get(deflate_opts, Opts, #{}),
  178. DeflateOpts = case Version of
  179. 'HTTP/1.1' -> DeflateOpts0#{owner => Pid};
  180. _ -> DeflateOpts0
  181. end,
  182. try cow_ws:negotiate_permessage_deflate(Params, Extensions, DeflateOpts) of
  183. {ok, RespExt, Extensions2} ->
  184. websocket_extensions(State#state{extensions=Extensions2},
  185. Req, Tail, [<<", ">>, RespExt|RespHeader]);
  186. ignore ->
  187. websocket_extensions(State, Req, Tail, RespHeader)
  188. catch exit:{error, incompatible_zlib_version, _} ->
  189. websocket_extensions(State, Req, Tail, RespHeader)
  190. end;
  191. websocket_extensions(State=#state{opts=Opts, extensions=Extensions},
  192. Req=#{pid := Pid, version := Version},
  193. [{<<"x-webkit-deflate-frame">>, Params}|Tail], RespHeader) ->
  194. DeflateOpts0 = maps:get(deflate_opts, Opts, #{}),
  195. DeflateOpts = case Version of
  196. 'HTTP/1.1' -> DeflateOpts0#{owner => Pid};
  197. _ -> DeflateOpts0
  198. end,
  199. try cow_ws:negotiate_x_webkit_deflate_frame(Params, Extensions, DeflateOpts) of
  200. {ok, RespExt, Extensions2} ->
  201. websocket_extensions(State#state{extensions=Extensions2},
  202. Req, Tail, [<<", ">>, RespExt|RespHeader]);
  203. ignore ->
  204. websocket_extensions(State, Req, Tail, RespHeader)
  205. catch exit:{error, incompatible_zlib_version, _} ->
  206. websocket_extensions(State, Req, Tail, RespHeader)
  207. end;
  208. websocket_extensions(State, Req, [_|Tail], RespHeader) ->
  209. websocket_extensions(State, Req, Tail, RespHeader).
  210. -spec websocket_handshake(#state{}, Req, any(), Env)
  211. -> {ok, Req, Env}
  212. when Req::cowboy_req:req(), Env::cowboy_middleware:env().
  213. websocket_handshake(State=#state{key=Key},
  214. Req=#{version := 'HTTP/1.1', pid := Pid, streamid := StreamID},
  215. HandlerState, Env) ->
  216. Challenge = base64:encode(crypto:hash(sha,
  217. << Key/binary, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" >>)),
  218. %% @todo We don't want date and server headers.
  219. Headers = cowboy_req:response_headers(#{
  220. <<"connection">> => <<"Upgrade">>,
  221. <<"upgrade">> => <<"websocket">>,
  222. <<"sec-websocket-accept">> => Challenge
  223. }, Req),
  224. Pid ! {{Pid, StreamID}, {switch_protocol, Headers, ?MODULE, {State, HandlerState}}},
  225. {ok, Req, Env};
  226. %% For HTTP/2 we do not let the process die, we instead keep it
  227. %% for the Websocket stream. This is because in HTTP/2 we only
  228. %% have a stream, it doesn't take over the whole connection.
  229. websocket_handshake(State, Req=#{ref := Ref, pid := Pid, streamid := StreamID},
  230. HandlerState, _Env) ->
  231. %% @todo We don't want date and server headers.
  232. Headers = cowboy_req:response_headers(#{}, Req),
  233. Pid ! {{Pid, StreamID}, {switch_protocol, Headers, ?MODULE, {State, HandlerState}}},
  234. takeover(Pid, Ref, {Pid, StreamID}, undefined, undefined, <<>>,
  235. {State, HandlerState}).
%% Connection process.

%% Parser state while waiting for a complete frame header. The buffer
%% holds the bytes received so far.
-record(ps_header, {
    buffer = <<>> :: binary()
}).

%% Parser state while reading a frame payload, once the header has been
%% parsed.
-record(ps_payload, {
    type :: cow_ws:frame_type(),
    %% Payload bytes still expected; reduced as partial data arrives.
    len :: non_neg_integer(),
    mask_key :: cow_ws:mask_key(),
    rsv :: cow_ws:rsv(),
    close_code = undefined :: undefined | cow_ws:close_code(),
    %% Payload unmasked so far.
    unmasked = <<>> :: binary(),
    %% Number of payload bytes already consumed; passed back to
    %% cow_ws:parse_payload/9 (presumably to resume at the right mask offset).
    unmasked_len = 0 :: non_neg_integer(),
    %% Bytes prepended to the next chunk by parse/4.
    buffer = <<>> :: binary()
}).

-type parse_state() :: #ps_header{} | #ps_payload{}.
  251. -spec takeover(pid(), ranch:ref(), inet:socket() | {pid(), cowboy_stream:streamid()},
  252. module() | undefined, any(), binary(),
  253. {#state{}, any()}) -> no_return().
  254. takeover(Parent, Ref, Socket, Transport, _Opts, Buffer,
  255. {State0=#state{handler=Handler}, HandlerState}) ->
  256. %% @todo We should have an option to disable this behavior.
  257. ranch:remove_connection(Ref),
  258. Messages = case Transport of
  259. undefined -> undefined;
  260. _ -> Transport:messages()
  261. end,
  262. State = loop_timeout(State0#state{parent=Parent,
  263. ref=Ref, socket=Socket, transport=Transport,
  264. key=undefined, messages=Messages}),
  265. case erlang:function_exported(Handler, websocket_init, 1) of
  266. true -> handler_call(State, HandlerState, #ps_header{buffer=Buffer},
  267. websocket_init, undefined, fun before_loop/3);
  268. false -> before_loop(State, HandlerState, #ps_header{buffer=Buffer})
  269. end.
  270. before_loop(State=#state{active=false}, HandlerState, ParseState) ->
  271. loop(State, HandlerState, ParseState);
  272. %% @todo We probably shouldn't do the setopts if we have not received a socket message.
  273. %% @todo We need to hibernate when HTTP/2 is used too.
  274. before_loop(State=#state{socket=Stream={Pid, _}, transport=undefined},
  275. HandlerState, ParseState) ->
  276. %% @todo Keep Ref around.
  277. ReadBodyRef = make_ref(),
  278. Pid ! {Stream, {read_body, self(), ReadBodyRef, auto, infinity}},
  279. loop(State, HandlerState, ParseState);
  280. before_loop(State=#state{socket=Socket, transport=Transport, hibernate=true},
  281. HandlerState, ParseState) ->
  282. Transport:setopts(Socket, [{active, once}]),
  283. proc_lib:hibernate(?MODULE, loop,
  284. [State#state{hibernate=false}, HandlerState, ParseState]);
  285. before_loop(State=#state{socket=Socket, transport=Transport},
  286. HandlerState, ParseState) ->
  287. Transport:setopts(Socket, [{active, once}]),
  288. loop(State, HandlerState, ParseState).
  289. -spec loop_timeout(#state{}) -> #state{}.
  290. loop_timeout(State=#state{opts=Opts, timeout_ref=PrevRef}) ->
  291. _ = case PrevRef of
  292. undefined -> ignore;
  293. PrevRef -> erlang:cancel_timer(PrevRef)
  294. end,
  295. case maps:get(idle_timeout, Opts, 60000) of
  296. infinity ->
  297. State#state{timeout_ref=undefined};
  298. Timeout ->
  299. TRef = erlang:start_timer(Timeout, self(), ?MODULE),
  300. State#state{timeout_ref=TRef}
  301. end.
-spec loop(#state{}, any(), parse_state()) -> no_return().
%% Main receive loop of the Websocket process. Incoming data (socket
%% messages for HTTP/1.1, request_body messages for HTTP/2) re-arms the
%% idle timer and is fed to the parser. A socket close or error ends the
%% connection, and so does the idle timer. System and supervisor
%% messages are handled here. Any other message goes to the handler's
%% websocket_info/2.
loop(State=#state{parent=Parent, socket=Socket, messages=Messages,
        timeout_ref=TRef}, HandlerState, ParseState) ->
    receive
        %% Socket messages. (HTTP/1.1)
        %% Messages is the {OK, Closed, Error} tag triple from
        %% Transport:messages(). For HTTP/2 it is undefined, so element/2
        %% fails inside the guard and these clauses never match.
        {OK, Socket, Data} when OK =:= element(1, Messages) ->
            State2 = loop_timeout(State),
            parse(State2, HandlerState, ParseState, Data);
        {Closed, Socket} when Closed =:= element(2, Messages) ->
            terminate(State, HandlerState, {error, closed});
        {Error, Socket, Reason} when Error =:= element(3, Messages) ->
            terminate(State, HandlerState, {error, Reason});
        %% Body reading messages. (HTTP/2)
        {request_body, _Ref, nofin, Data} ->
            State2 = loop_timeout(State),
            parse(State2, HandlerState, ParseState, Data);
        %% @todo We need to handle this case as if it was an {error, closed}
        %% but not before we finish processing frames. We probably should have
        %% a check in before_loop to let us stop looping if a flag is set.
        {request_body, _Ref, fin, _, Data} ->
            State2 = loop_timeout(State),
            parse(State2, HandlerState, ParseState, Data);
        %% Timeouts.
        %% Only the currently armed timer closes the connection. A message
        %% from a timer that loop_timeout/1 has since replaced is dropped.
        {timeout, TRef, ?MODULE} ->
            websocket_close(State, HandlerState, timeout);
        {timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) ->
            %% @todo This should call before_loop.
            loop(State, HandlerState, ParseState);
        %% System messages.
        {'EXIT', Parent, Reason} ->
            %% @todo We should exit gracefully.
            exit(Reason);
        {system, From, Request} ->
            sys:handle_system_msg(Request, From, Parent, ?MODULE, [],
                {State, HandlerState, ParseState});
        %% Calls from supervisor module.
        {'$gen_call', From, Call} ->
            cowboy_children:handle_supervisor_call(Call, From, [], ?MODULE),
            %% @todo This should call before_loop.
            loop(State, HandlerState, ParseState);
        %% Everything else is for the handler.
        Message ->
            handler_call(State, HandlerState, ParseState,
                websocket_info, Message, fun before_loop/3)
    end.
  346. parse(State, HandlerState, PS=#ps_header{buffer=Buffer}, Data) ->
  347. parse_header(State, HandlerState, PS#ps_header{
  348. buffer= <<Buffer/binary, Data/binary>>});
  349. parse(State, HandlerState, PS=#ps_payload{buffer=Buffer}, Data) ->
  350. parse_payload(State, HandlerState, PS#ps_payload{buffer= <<>>},
  351. <<Buffer/binary, Data/binary>>).
%% Try to parse a frame header from the buffered bytes.
%% - Unmasked frames and malformed headers close with {error, badframe}.
%% - A declared length above max_frame_size closes with {error, badsize}.
%% - An incomplete header waits for more data through before_loop/3.
%% - Otherwise parsing continues with the payload.
parse_header(State=#state{opts=Opts, frag_state=FragState, extensions=Extensions},
        HandlerState, ParseState=#ps_header{buffer=Data}) ->
    MaxFrameSize = maps:get(max_frame_size, Opts, infinity),
    case cow_ws:parse_header(Data, Extensions, FragState) of
        %% All frames sent from the client to the server are masked.
        {_, _, _, _, undefined, _} ->
            websocket_close(State, HandlerState, {error, badframe});
        %% With the default infinity this never matches: numbers sort
        %% before atoms in Erlang term order.
        {_, _, _, Len, _, _} when Len > MaxFrameSize ->
            websocket_close(State, HandlerState, {error, badsize});
        {Type, FragState2, Rsv, Len, MaskKey, Rest} ->
            parse_payload(State#state{frag_state=FragState2}, HandlerState,
                #ps_payload{type=Type, len=Len, mask_key=MaskKey, rsv=Rsv}, Rest);
        more ->
            before_loop(State, HandlerState, ParseState);
        error ->
            websocket_close(State, HandlerState, {error, badframe})
    end.
%% Unmask the payload bytes in Data with cow_ws:parse_payload/9,
%% threading the utf8_state through (presumably for text frame
%% validation). When the payload is complete it is dispatched, and any
%% leftover bytes (Rest) follow. A partial payload is accumulated:
%% len is reduced and unmasked_len increased by the bytes consumed, then
%% the process waits for more data. Errors close the connection with the
%% returned {error, Reason}.
parse_payload(State=#state{frag_state=FragState, utf8_state=Incomplete, extensions=Extensions},
        HandlerState, ParseState=#ps_payload{
            type=Type, len=Len, mask_key=MaskKey, rsv=Rsv,
            unmasked=Unmasked, unmasked_len=UnmaskedLen}, Data) ->
    case cow_ws:parse_payload(Data, MaskKey, Incomplete, UnmaskedLen,
            Type, Len, FragState, Extensions, Rsv) of
        {ok, CloseCode, Payload, Utf8State, Rest} ->
            dispatch_frame(State#state{utf8_state=Utf8State}, HandlerState,
                ParseState#ps_payload{unmasked= <<Unmasked/binary, Payload/binary>>,
                    close_code=CloseCode}, Rest);
        {ok, Payload, Utf8State, Rest} ->
            dispatch_frame(State#state{utf8_state=Utf8State}, HandlerState,
                ParseState#ps_payload{unmasked= <<Unmasked/binary, Payload/binary>>},
                Rest);
        %% Partial payload: all of Data belongs to this frame.
        {more, CloseCode, Payload, Utf8State} ->
            before_loop(State#state{utf8_state=Utf8State}, HandlerState,
                ParseState#ps_payload{len=Len - byte_size(Data), close_code=CloseCode,
                    unmasked= <<Unmasked/binary, Payload/binary>>,
                    unmasked_len=UnmaskedLen + byte_size(Data)});
        {more, Payload, Utf8State} ->
            before_loop(State#state{utf8_state=Utf8State}, HandlerState,
                ParseState#ps_payload{len=Len - byte_size(Data),
                    unmasked= <<Unmasked/binary, Payload/binary>>,
                    unmasked_len=UnmaskedLen + byte_size(Data)});
        Error = {error, _Reason} ->
            websocket_close(State, HandlerState, Error)
    end.
%% Turn a fully received payload into a frame with cow_ws:make_frame/4
%% and act on it:
%% - fragments accumulate in frag_buffer until the final one, which is
%%   delivered to websocket_handle/2 as the whole message. The
%%   accumulated size is checked against max_frame_size ({error, badsize});
%% - close frames end the connection (reason remote or {remote, Code, Payload});
%% - pings are answered with a pong before being passed to the handler;
%% - every other frame is passed to websocket_handle/2.
%% Parsing then continues on RemainingData with parse_header/3.
dispatch_frame(State=#state{opts=Opts, frag_state=FragState, frag_buffer=SoFar}, HandlerState,
        #ps_payload{type=Type0, unmasked=Payload0, close_code=CloseCode0}, RemainingData) ->
    MaxFrameSize = maps:get(max_frame_size, Opts, infinity),
    case cow_ws:make_frame(Type0, Payload0, CloseCode0, FragState) of
        %% @todo Allow receiving fragments.
        %% Never true with the default infinity (numbers sort before atoms).
        {fragment, _, _, Payload} when byte_size(Payload) + byte_size(SoFar) > MaxFrameSize ->
            websocket_close(State, HandlerState, {error, badsize});
        {fragment, nofin, _, Payload} ->
            parse_header(State#state{frag_buffer= << SoFar/binary, Payload/binary >>},
                HandlerState, #ps_header{buffer=RemainingData});
        {fragment, fin, Type, Payload} ->
            handler_call(State#state{frag_state=undefined, frag_buffer= <<>>}, HandlerState,
                #ps_header{buffer=RemainingData},
                websocket_handle, {Type, << SoFar/binary, Payload/binary >>},
                fun parse_header/3);
        close ->
            websocket_close(State, HandlerState, remote);
        {close, CloseCode, Payload} ->
            websocket_close(State, HandlerState, {remote, CloseCode, Payload});
        Frame = ping ->
            transport_send(State, nofin, frame(pong, State)),
            handler_call(State, HandlerState,
                #ps_header{buffer=RemainingData},
                websocket_handle, Frame, fun parse_header/3);
        Frame = {ping, Payload} ->
            transport_send(State, nofin, frame({pong, Payload}, State)),
            handler_call(State, HandlerState,
                #ps_header{buffer=RemainingData},
                websocket_handle, Frame, fun parse_header/3);
        Frame ->
            handler_call(State, HandlerState,
                #ps_header{buffer=RemainingData},
                websocket_handle, Frame, fun parse_header/3)
    end.
  430. handler_call(State=#state{handler=Handler}, HandlerState,
  431. ParseState, Callback, Message, NextState) ->
  432. try case Callback of
  433. websocket_init -> Handler:websocket_init(HandlerState);
  434. _ -> Handler:Callback(Message, HandlerState)
  435. end of
  436. {Commands, HandlerState2} when is_list(Commands) ->
  437. handler_call_result(State,
  438. HandlerState2, ParseState, NextState, Commands);
  439. {Commands, HandlerState2, hibernate} when is_list(Commands) ->
  440. handler_call_result(State#state{hibernate=true},
  441. HandlerState2, ParseState, NextState, Commands);
  442. %% The following call results are deprecated.
  443. {ok, HandlerState2} ->
  444. NextState(State, HandlerState2, ParseState);
  445. {ok, HandlerState2, hibernate} ->
  446. NextState(State#state{hibernate=true}, HandlerState2, ParseState);
  447. {reply, Payload, HandlerState2} ->
  448. case websocket_send(Payload, State) of
  449. ok ->
  450. NextState(State, HandlerState2, ParseState);
  451. stop ->
  452. terminate(State, HandlerState2, stop);
  453. Error = {error, _} ->
  454. terminate(State, HandlerState2, Error)
  455. end;
  456. {reply, Payload, HandlerState2, hibernate} ->
  457. case websocket_send(Payload, State) of
  458. ok ->
  459. NextState(State#state{hibernate=true},
  460. HandlerState2, ParseState);
  461. stop ->
  462. terminate(State, HandlerState2, stop);
  463. Error = {error, _} ->
  464. terminate(State, HandlerState2, Error)
  465. end;
  466. {stop, HandlerState2} ->
  467. websocket_close(State, HandlerState2, stop)
  468. catch Class:Reason ->
  469. StackTrace = erlang:get_stacktrace(),
  470. websocket_send_close(State, {crash, Class, Reason}),
  471. handler_terminate(State, HandlerState, {crash, Class, Reason}),
  472. erlang:raise(Class, Reason, StackTrace)
  473. end.
  474. -spec handler_call_result(#state{}, any(), parse_state(), fun(), commands()) -> no_return().
  475. handler_call_result(State0, HandlerState, ParseState, NextState, Commands) ->
  476. case commands(Commands, State0, []) of
  477. {ok, State} ->
  478. NextState(State, HandlerState, ParseState);
  479. {stop, State} ->
  480. terminate(State, HandlerState, stop);
  481. {Error = {error, _}, State} ->
  482. terminate(State, HandlerState, Error)
  483. end.
  484. commands([], State, []) ->
  485. {ok, State};
  486. commands([], State, Data) ->
  487. Result = transport_send(State, nofin, lists:reverse(Data)),
  488. {Result, State};
  489. commands([{active, Active}|Tail], State, Data) when is_boolean(Active) ->
  490. commands(Tail, State#state{active=Active}, Data);
  491. commands([{deflate, Deflate}|Tail], State, Data) when is_boolean(Deflate) ->
  492. commands(Tail, State#state{deflate=Deflate}, Data);
  493. commands([{set_options, SetOpts}|Tail], State0=#state{opts=Opts}, Data) ->
  494. State = case SetOpts of
  495. #{idle_timeout := IdleTimeout} ->
  496. loop_timeout(State0#state{opts=Opts#{idle_timeout => IdleTimeout}});
  497. _ ->
  498. State0
  499. end,
  500. commands(Tail, State, Data);
  501. commands([Frame|Tail], State, Data0) ->
  502. Data = [frame(Frame, State)|Data0],
  503. case is_close_frame(Frame) of
  504. true ->
  505. _ = transport_send(State, fin, lists:reverse(Data)),
  506. {stop, State};
  507. false ->
  508. commands(Tail, State, Data)
  509. end.
  510. transport_send(#state{socket=Stream={Pid, _}, transport=undefined}, IsFin, Data) ->
  511. Pid ! {Stream, {data, IsFin, Data}},
  512. ok;
  513. transport_send(#state{socket=Socket, transport=Transport}, _, Data) ->
  514. Transport:send(Socket, Data).
  515. -spec websocket_send(cow_ws:frame(), #state{}) -> ok | stop | {error, atom()}.
  516. websocket_send(Frames, State) when is_list(Frames) ->
  517. websocket_send_many(Frames, State, []);
  518. websocket_send(Frame, State) ->
  519. Data = frame(Frame, State),
  520. case is_close_frame(Frame) of
  521. true ->
  522. _ = transport_send(State, fin, Data),
  523. stop;
  524. false ->
  525. transport_send(State, nofin, Data)
  526. end.
  527. websocket_send_many([], State, Acc) ->
  528. transport_send(State, nofin, lists:reverse(Acc));
  529. websocket_send_many([Frame|Tail], State, Acc0) ->
  530. Acc = [frame(Frame, State)|Acc0],
  531. case is_close_frame(Frame) of
  532. true ->
  533. _ = transport_send(State, fin, lists:reverse(Acc)),
  534. stop;
  535. false ->
  536. websocket_send_many(Tail, State, Acc)
  537. end.
  538. is_close_frame(close) -> true;
  539. is_close_frame({close, _}) -> true;
  540. is_close_frame({close, _, _}) -> true;
  541. is_close_frame(_) -> false.
  542. -spec websocket_close(#state{}, any(), terminate_reason()) -> no_return().
  543. websocket_close(State, HandlerState, Reason) ->
  544. websocket_send_close(State, Reason),
  545. terminate(State, HandlerState, Reason).
  546. websocket_send_close(State, Reason) ->
  547. _ = case Reason of
  548. Normal when Normal =:= stop; Normal =:= timeout ->
  549. transport_send(State, fin, frame({close, 1000, <<>>}, State));
  550. {error, badframe} ->
  551. transport_send(State, fin, frame({close, 1002, <<>>}, State));
  552. {error, badencoding} ->
  553. transport_send(State, fin, frame({close, 1007, <<>>}, State));
  554. {error, badsize} ->
  555. transport_send(State, fin, frame({close, 1009, <<>>}, State));
  556. {crash, _, _} ->
  557. transport_send(State, fin, frame({close, 1011, <<>>}, State));
  558. remote ->
  559. transport_send(State, fin, frame(close, State));
  560. {remote, Code, _} ->
  561. transport_send(State, fin, frame({close, Code, <<>>}, State))
  562. end,
  563. ok.
  564. %% Don't compress frames while deflate is disabled.
  565. frame(Frame, #state{deflate=false, extensions=Extensions}) ->
  566. cow_ws:frame(Frame, Extensions#{deflate => false});
  567. frame(Frame, #state{extensions=Extensions}) ->
  568. cow_ws:frame(Frame, Extensions).
  569. -spec terminate(#state{}, any(), terminate_reason()) -> no_return().
  570. terminate(State, HandlerState, Reason) ->
  571. handler_terminate(State, HandlerState, Reason),
  572. exit(normal).
  573. handler_terminate(#state{handler=Handler, req=Req}, HandlerState, Reason) ->
  574. cowboy_handler:terminate(Reason, Req, HandlerState, Handler).
  575. %% System callbacks.
  576. -spec system_continue(_, _, {#state{}, any(), parse_state()}) -> no_return().
  577. system_continue(_, _, {State, HandlerState, ParseState}) ->
  578. loop(State, HandlerState, ParseState).
  579. -spec system_terminate(any(), _, _, {#state{}, any(), parse_state()}) -> no_return().
  580. system_terminate(Reason, _, _, {State, HandlerState, _}) ->
  581. %% @todo We should exit gracefully, if possible.
  582. terminate(State, HandlerState, Reason).
  583. -spec system_code_change(Misc, _, _, _)
  584. -> {ok, Misc} when Misc::{#state{}, any(), parse_state()}.
  585. system_code_change(Misc, _, _, _) ->
  586. {ok, Misc}.