cowboy_websocket.erl 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. %% Copyright (c) 2011-2017, Loïc Hoguin <essen@ninenines.eu>
  2. %%
  3. %% Permission to use, copy, modify, and/or distribute this software for any
  4. %% purpose with or without fee is hereby granted, provided that the above
  5. %% copyright notice and this permission notice appear in all copies.
  6. %%
  7. %% THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  8. %% WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  9. %% MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
  10. %% ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  11. %% WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
  12. %% ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
  13. %% OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  14. %% Cowboy supports versions 7 through 17 of the Websocket drafts.
  15. %% It also supports RFC6455, the proposed standard for Websocket.
  16. -module(cowboy_websocket).
  17. -behaviour(cowboy_sub_protocol).
  18. -export([upgrade/4]).
  19. -export([upgrade/5]).
  20. -export([takeover/7]).
  21. -export([handler_loop/3]).
%% Value a handler may return from websocket_init/1, websocket_handle/2 and
%% websocket_info/2 (see handler_call/6):
%%  * {ok, State}             - continue with the new handler state;
%%  * {ok, State, hibernate}  - same, and hibernate before the next receive;
%%  * {reply, Frames, State}  - send one frame or a list of frames, then continue
%%                              (sending a close frame stops the connection);
%%  * {reply, Frames, State, hibernate} - same, and hibernate afterwards;
%%  * {stop, State}           - send a close frame and terminate.
-type call_result(State) :: {ok, State}
    | {ok, State, hibernate}
    | {reply, cow_ws:frame() | [cow_ws:frame()], State}
    | {reply, cow_ws:frame() | [cow_ws:frame()], State, hibernate}
    | {stop, State}.
%% Reason handed to the handler's terminate/3 callback.
%% In this module: `stop' comes from a handler {stop, _} return or a close frame
%% sent by the handler, `timeout' from the idle timer, `remote' and
%% {remote, Code, Payload} from a close frame received from the peer,
%% {error, _} from socket or framing errors, and {crash, Class, Reason}
%% from an exception raised inside a handler callback.
-type terminate_reason() :: normal | stop | timeout
    | remote | {remote, cow_ws:close_code(), binary()}
    | {error, badencoding | badframe | closed | atom()}
    | {crash, error | exit | throw, any()}.
%% Handler entry point. It is not invoked from this module.
%% NOTE(review): presumably called by the generic handler machinery before
%% upgrade/4,5 is reached - confirm against the caller.
-callback init(Req, any())
    -> {ok | module(), Req, any()}
    | {module(), Req, any(), any()}
    when Req::cowboy_req:req().
%% Optional. When exported, it is called once from takeover/7, in the
%% connection process, before the receive loop starts.
-callback websocket_init(State)
    -> call_result(State) when State::any().
-optional_callbacks([websocket_init/1]).
%% Called for every complete frame received from the peer
%% (fragmented messages are reassembled first).
-callback websocket_handle({text | binary | ping | pong, binary()}, State)
    -> call_result(State) when State::any().
%% Called for every Erlang message that is neither socket data nor the
%% idle timeout of this module.
-callback websocket_info(any(), State)
    -> call_result(State) when State::any().
%% Optional cleanup callback; receives the terminate reason, the filtered
%% request map kept in the state and the handler state.
-callback terminate(any(), cowboy_req:req(), any()) -> ok.
-optional_callbacks([terminate/3]).
%% Options accepted by upgrade/5.
%%  * compress     - when true, compression extensions offered by the client
%%                   are negotiated (defaults to false);
%%  * idle_timeout - time without incoming data before the connection is
%%                   closed, passed to erlang:start_timer/3 (defaults to 60000);
%%  * req_filter   - fun producing the reduced request map kept for terminate/3;
%%                   by default only method, version, scheme, host, port, path,
%%                   qs and peer are kept.
-type opts() :: #{
    compress => boolean(),
    idle_timeout => timeout(),
    req_filter => fun((cowboy_req:req()) -> map())
}.
-export_type([opts/0]).
%% Internal state carried from the upgrade (stream process) into the
%% receive loop (connection process).
-record(state, {
    %% Socket and transport module; both are filled in by takeover/7.
    socket = undefined :: inet:socket() | undefined,
    transport = undefined :: module(),
    %% User handler module implementing the callbacks of this behaviour.
    handler :: module(),
    %% Value of the sec-websocket-key request header; reset to undefined
    %% in takeover/7 once the handshake has been sent.
    key = undefined :: undefined | binary(),
    %% Idle timeout and the reference of the currently armed timer.
    timeout = infinity :: timeout(),
    timeout_ref = undefined :: undefined | reference(),
    %% Whether compression extensions should be negotiated.
    compress = false :: boolean(),
    %% Result of Transport:messages(): the {OK, Closed, Error} message tags.
    messages = undefined :: undefined | {atom(), atom(), atom()},
    %% Set when the handler asked to hibernate before the next receive.
    hibernate = false :: boolean(),
    %% Fragmentation state and the payload accumulated from the fragments so far.
    frag_state = undefined :: cow_ws:frag_state(),
    frag_buffer = <<>> :: binary(),
    %% UTF-8 validation state threaded through cow_ws:parse_payload/9.
    utf8_state = 0 :: cow_ws:utf8_state(),
    %% Negotiated extensions, passed to the cow_ws parsing and framing functions.
    extensions = #{} :: map(),
    %% Filtered copy of the request, handed to the handler's terminate/3.
    req = #{} :: map()
}).
  66. %% Stream process.
  67. -spec upgrade(Req, Env, module(), any())
  68. -> {ok, Req, Env}
  69. when Req::cowboy_req:req(), Env::cowboy_middleware:env().
  70. upgrade(Req, Env, Handler, HandlerState) ->
  71. upgrade(Req, Env, Handler, HandlerState, #{}).
  72. -spec upgrade(Req, Env, module(), any(), opts())
  73. -> {ok, Req, Env}
  74. when Req::cowboy_req:req(), Env::cowboy_middleware:env().
  75. %% @todo Immediately crash if a response has already been sent.
  76. %% @todo Error out if HTTP/2.
  77. upgrade(Req0, Env, Handler, HandlerState, Opts) ->
  78. Timeout = maps:get(idle_timeout, Opts, 60000),
  79. Compress = maps:get(compress, Opts, false),
  80. FilteredReq = case maps:get(req_filter, Opts, undefined) of
  81. undefined -> maps:with([method, version, scheme, host, port, path, qs, peer], Req0);
  82. FilterFun -> FilterFun(Req0)
  83. end,
  84. State0 = #state{handler=Handler, timeout=Timeout, compress=Compress, req=FilteredReq},
  85. try websocket_upgrade(State0, Req0) of
  86. {ok, State, Req} ->
  87. websocket_handshake(State, Req, HandlerState, Env)
  88. catch _:_ ->
  89. %% @todo Probably log something here?
  90. %% @todo Test that we can have 2 /ws 400 status code in a row on the same connection.
  91. %% @todo Does this even work?
  92. {ok, cowboy_req:reply(400, Req0), Env}
  93. end.
  94. -spec websocket_upgrade(#state{}, Req)
  95. -> {ok, #state{}, Req} when Req::cowboy_req:req().
  96. websocket_upgrade(State, Req) ->
  97. ConnTokens = cowboy_req:parse_header(<<"connection">>, Req),
  98. true = lists:member(<<"upgrade">>, ConnTokens),
  99. %% @todo Should probably send a 426 if the Upgrade header is missing.
  100. [<<"websocket">>] = cowboy_req:parse_header(<<"upgrade">>, Req),
  101. Version = cowboy_req:header(<<"sec-websocket-version">>, Req),
  102. IntVersion = binary_to_integer(Version),
  103. true = (IntVersion =:= 7) orelse (IntVersion =:= 8)
  104. orelse (IntVersion =:= 13),
  105. Key = cowboy_req:header(<<"sec-websocket-key">>, Req),
  106. false = Key =:= undefined,
  107. websocket_extensions(State#state{key=Key}, Req#{websocket_version => IntVersion}).
  108. -spec websocket_extensions(#state{}, Req)
  109. -> {ok, #state{}, Req} when Req::cowboy_req:req().
  110. websocket_extensions(State=#state{compress=Compress}, Req) ->
  111. %% @todo We want different options for this. For example
  112. %% * compress everything auto
  113. %% * compress only text auto
  114. %% * compress only binary auto
  115. %% * compress nothing auto (but still enabled it)
  116. %% * disable compression
  117. case {Compress, cowboy_req:parse_header(<<"sec-websocket-extensions">>, Req)} of
  118. {true, Extensions} when Extensions =/= undefined ->
  119. websocket_extensions(State, Req, Extensions, []);
  120. _ ->
  121. {ok, State, Req}
  122. end.
  123. websocket_extensions(State, Req, [], []) ->
  124. {ok, State, Req};
  125. websocket_extensions(State, Req, [], [<<", ">>|RespHeader]) ->
  126. {ok, State, cowboy_req:set_resp_header(<<"sec-websocket-extensions">>, lists:reverse(RespHeader), Req)};
  127. websocket_extensions(State=#state{extensions=Extensions}, Req=#{pid := Pid},
  128. [{<<"permessage-deflate">>, Params}|Tail], RespHeader) ->
  129. %% @todo Make deflate options configurable.
  130. Opts = #{level => best_compression, mem_level => 8, strategy => default},
  131. try cow_ws:negotiate_permessage_deflate(Params, Extensions, Opts#{owner => Pid}) of
  132. {ok, RespExt, Extensions2} ->
  133. websocket_extensions(State#state{extensions=Extensions2},
  134. Req, Tail, [<<", ">>, RespExt|RespHeader]);
  135. ignore ->
  136. websocket_extensions(State, Req, Tail, RespHeader)
  137. catch exit:{error, incompatible_zlib_version, _} ->
  138. websocket_extensions(State, Req, Tail, RespHeader)
  139. end;
  140. websocket_extensions(State=#state{extensions=Extensions}, Req=#{pid := Pid},
  141. [{<<"x-webkit-deflate-frame">>, Params}|Tail], RespHeader) ->
  142. %% @todo Make deflate options configurable.
  143. Opts = #{level => best_compression, mem_level => 8, strategy => default},
  144. try cow_ws:negotiate_x_webkit_deflate_frame(Params, Extensions, Opts#{owner => Pid}) of
  145. {ok, RespExt, Extensions2} ->
  146. websocket_extensions(State#state{extensions=Extensions2},
  147. Req, Tail, [<<", ">>, RespExt|RespHeader]);
  148. ignore ->
  149. websocket_extensions(State, Req, Tail, RespHeader)
  150. catch exit:{error, incompatible_zlib_version, _} ->
  151. websocket_extensions(State, Req, Tail, RespHeader)
  152. end;
  153. websocket_extensions(State, Req, [_|Tail], RespHeader) ->
  154. websocket_extensions(State, Req, Tail, RespHeader).
  155. -spec websocket_handshake(#state{}, Req, any(), Env)
  156. -> {ok, Req, Env}
  157. when Req::cowboy_req:req(), Env::cowboy_middleware:env().
  158. websocket_handshake(State=#state{key=Key},
  159. Req=#{pid := Pid, streamid := StreamID}, HandlerState, Env) ->
  160. Challenge = base64:encode(crypto:hash(sha,
  161. << Key/binary, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" >>)),
  162. %% @todo We don't want date and server headers.
  163. Headers = cowboy_req:response_headers(#{
  164. <<"connection">> => <<"Upgrade">>,
  165. <<"upgrade">> => <<"websocket">>,
  166. <<"sec-websocket-accept">> => Challenge
  167. }, Req),
  168. Pid ! {{Pid, StreamID}, {switch_protocol, Headers, ?MODULE, {State, HandlerState}}},
  169. {ok, Req, Env}.
  170. %% Connection process.
  171. %% @todo Keep parent and handle system messages.
  172. -spec takeover(pid(), ranch:ref(), inet:socket(), module(), any(), binary(),
  173. {#state{}, any()}) -> ok.
  174. takeover(_Parent, Ref, Socket, Transport, _Opts, Buffer,
  175. {State0=#state{handler=Handler}, HandlerState}) ->
  176. %% @todo We should have an option to disable this behavior.
  177. ranch:remove_connection(Ref),
  178. State1 = handler_loop_timeout(State0#state{socket=Socket, transport=Transport}),
  179. State = State1#state{key=undefined, messages=Transport:messages()},
  180. case erlang:function_exported(Handler, websocket_init, 1) of
  181. true -> handler_call(State, HandlerState, Buffer, websocket_init, undefined, fun handler_before_loop/3);
  182. false -> handler_before_loop(State, HandlerState, Buffer)
  183. end.
  184. -spec handler_before_loop(#state{}, any(), binary())
  185. %% @todo Yeah not env.
  186. -> {ok, cowboy_middleware:env()}.
  187. handler_before_loop(State=#state{
  188. socket=Socket, transport=Transport, hibernate=true},
  189. HandlerState, SoFar) ->
  190. Transport:setopts(Socket, [{active, once}]),
  191. proc_lib:hibernate(?MODULE, handler_loop,
  192. [State#state{hibernate=false}, HandlerState, SoFar]);
  193. handler_before_loop(State=#state{socket=Socket, transport=Transport},
  194. HandlerState, SoFar) ->
  195. Transport:setopts(Socket, [{active, once}]),
  196. handler_loop(State, HandlerState, SoFar).
  197. -spec handler_loop_timeout(#state{}) -> #state{}.
  198. handler_loop_timeout(State=#state{timeout=infinity}) ->
  199. State#state{timeout_ref=undefined};
  200. handler_loop_timeout(State=#state{timeout=Timeout, timeout_ref=PrevRef}) ->
  201. _ = case PrevRef of undefined -> ignore; PrevRef ->
  202. erlang:cancel_timer(PrevRef) end,
  203. TRef = erlang:start_timer(Timeout, self(), ?MODULE),
  204. State#state{timeout_ref=TRef}.
  205. -spec handler_loop(#state{}, any(), binary())
  206. -> {ok, cowboy_middleware:env()}.
  207. handler_loop(State=#state{socket=Socket, messages={OK, Closed, Error},
  208. timeout_ref=TRef}, HandlerState, SoFar) ->
  209. receive
  210. {OK, Socket, Data} ->
  211. State2 = handler_loop_timeout(State),
  212. websocket_data(State2, HandlerState,
  213. << SoFar/binary, Data/binary >>);
  214. {Closed, Socket} ->
  215. terminate(State, HandlerState, {error, closed});
  216. {Error, Socket, Reason} ->
  217. terminate(State, HandlerState, {error, Reason});
  218. {timeout, TRef, ?MODULE} ->
  219. websocket_close(State, HandlerState, timeout);
  220. {timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) ->
  221. handler_loop(State, HandlerState, SoFar);
  222. Message ->
  223. handler_call(State, HandlerState,
  224. SoFar, websocket_info, Message, fun handler_before_loop/3)
  225. end.
  226. -spec websocket_data(#state{}, any(), binary())
  227. -> {ok, cowboy_middleware:env()}.
  228. websocket_data(State=#state{frag_state=FragState, extensions=Extensions}, HandlerState, Data) ->
  229. case cow_ws:parse_header(Data, Extensions, FragState) of
  230. %% All frames sent from the client to the server are masked.
  231. {_, _, _, _, undefined, _} ->
  232. websocket_close(State, HandlerState, {error, badframe});
  233. {Type, FragState2, Rsv, Len, MaskKey, Rest} ->
  234. websocket_payload(State#state{frag_state=FragState2}, HandlerState, Type, Len, MaskKey, Rsv, undefined, <<>>, 0, Rest);
  235. more ->
  236. handler_before_loop(State, HandlerState, Data);
  237. error ->
  238. websocket_close(State, HandlerState, {error, badframe})
  239. end.
  240. websocket_payload(State=#state{frag_state=FragState, utf8_state=Incomplete, extensions=Extensions},
  241. HandlerState, Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen, Data) ->
  242. case cow_ws:parse_payload(Data, MaskKey, Incomplete, UnmaskedLen, Type, Len, FragState, Extensions, Rsv) of
  243. {ok, CloseCode2, Payload, Utf8State, Rest} ->
  244. websocket_dispatch(State#state{utf8_state=Utf8State},
  245. HandlerState, Type, << Unmasked/binary, Payload/binary >>, CloseCode2, Rest);
  246. {ok, Payload, Utf8State, Rest} ->
  247. websocket_dispatch(State#state{utf8_state=Utf8State},
  248. HandlerState, Type, << Unmasked/binary, Payload/binary >>, CloseCode, Rest);
  249. {more, CloseCode2, Payload, Utf8State} ->
  250. websocket_payload_loop(State#state{utf8_state=Utf8State},
  251. HandlerState, Type, Len - byte_size(Data), MaskKey, Rsv, CloseCode2,
  252. << Unmasked/binary, Payload/binary >>, UnmaskedLen + byte_size(Data));
  253. {more, Payload, Utf8State} ->
  254. websocket_payload_loop(State#state{utf8_state=Utf8State},
  255. HandlerState, Type, Len - byte_size(Data), MaskKey, Rsv, CloseCode,
  256. << Unmasked/binary, Payload/binary >>, UnmaskedLen + byte_size(Data));
  257. Error = {error, _Reason} ->
  258. websocket_close(State, HandlerState, Error)
  259. end.
  260. websocket_payload_loop(State=#state{socket=Socket, transport=Transport,
  261. messages={OK, Closed, Error}, timeout_ref=TRef},
  262. HandlerState, Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen) ->
  263. Transport:setopts(Socket, [{active, once}]),
  264. receive
  265. {OK, Socket, Data} ->
  266. State2 = handler_loop_timeout(State),
  267. websocket_payload(State2, HandlerState,
  268. Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen, Data);
  269. {Closed, Socket} ->
  270. terminate(State, HandlerState, {error, closed});
  271. {Error, Socket, Reason} ->
  272. terminate(State, HandlerState, {error, Reason});
  273. {timeout, TRef, ?MODULE} ->
  274. websocket_close(State, HandlerState, timeout);
  275. {timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) ->
  276. websocket_payload_loop(State, HandlerState,
  277. Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen);
  278. Message ->
  279. handler_call(State, HandlerState,
  280. <<>>, websocket_info, Message,
  281. fun (State2, HandlerState2, _) ->
  282. websocket_payload_loop(State2, HandlerState2,
  283. Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen)
  284. end)
  285. end.
  286. websocket_dispatch(State=#state{socket=Socket, transport=Transport, frag_state=FragState, frag_buffer=SoFar, extensions=Extensions},
  287. HandlerState, Type0, Payload0, CloseCode0, RemainingData) ->
  288. case cow_ws:make_frame(Type0, Payload0, CloseCode0, FragState) of
  289. %% @todo Allow receiving fragments.
  290. {fragment, nofin, _, Payload} ->
  291. websocket_data(State#state{frag_buffer= << SoFar/binary, Payload/binary >>}, HandlerState, RemainingData);
  292. {fragment, fin, Type, Payload} ->
  293. handler_call(State#state{frag_state=undefined, frag_buffer= <<>>}, HandlerState, RemainingData,
  294. websocket_handle, {Type, << SoFar/binary, Payload/binary >>}, fun websocket_data/3);
  295. close ->
  296. websocket_close(State, HandlerState, remote);
  297. {close, CloseCode, Payload} ->
  298. websocket_close(State, HandlerState, {remote, CloseCode, Payload});
  299. Frame = ping ->
  300. Transport:send(Socket, cow_ws:frame(pong, Extensions)),
  301. handler_call(State, HandlerState, RemainingData, websocket_handle, Frame, fun websocket_data/3);
  302. Frame = {ping, Payload} ->
  303. Transport:send(Socket, cow_ws:frame({pong, Payload}, Extensions)),
  304. handler_call(State, HandlerState, RemainingData, websocket_handle, Frame, fun websocket_data/3);
  305. Frame ->
  306. handler_call(State, HandlerState, RemainingData, websocket_handle, Frame, fun websocket_data/3)
  307. end.
  308. -spec handler_call(#state{}, any(), binary(), atom(), any(), fun()) -> no_return().
  309. handler_call(State=#state{handler=Handler}, HandlerState,
  310. RemainingData, Callback, Message, NextState) ->
  311. try case Callback of
  312. websocket_init -> Handler:websocket_init(HandlerState);
  313. _ -> Handler:Callback(Message, HandlerState)
  314. end of
  315. {ok, HandlerState2} ->
  316. NextState(State, HandlerState2, RemainingData);
  317. {ok, HandlerState2, hibernate} ->
  318. NextState(State#state{hibernate=true}, HandlerState2, RemainingData);
  319. {reply, Payload, HandlerState2} ->
  320. case websocket_send(Payload, State) of
  321. ok ->
  322. NextState(State, HandlerState2, RemainingData);
  323. stop ->
  324. terminate(State, HandlerState2, stop);
  325. Error = {error, _} ->
  326. terminate(State, HandlerState2, Error)
  327. end;
  328. {reply, Payload, HandlerState2, hibernate} ->
  329. case websocket_send(Payload, State) of
  330. ok ->
  331. NextState(State#state{hibernate=true},
  332. HandlerState2, RemainingData);
  333. stop ->
  334. terminate(State, HandlerState2, stop);
  335. Error = {error, _} ->
  336. terminate(State, HandlerState2, Error)
  337. end;
  338. {stop, HandlerState2} ->
  339. websocket_close(State, HandlerState2, stop)
  340. catch Class:Reason ->
  341. websocket_send_close(State, {crash, Class, Reason}),
  342. handler_terminate(State, HandlerState, {crash, Class, Reason}),
  343. erlang:raise(Class, Reason, erlang:get_stacktrace())
  344. end.
  345. -spec websocket_send(cow_ws:frame(), #state{}) -> ok | stop | {error, atom()}.
  346. websocket_send(Frames, State) when is_list(Frames) ->
  347. websocket_send_many(Frames, State, []);
  348. websocket_send(Frame, #state{socket=Socket, transport=Transport, extensions=Extensions}) ->
  349. Res = Transport:send(Socket, cow_ws:frame(Frame, Extensions)),
  350. case is_close_frame(Frame) of
  351. true -> stop;
  352. false -> Res
  353. end.
  354. websocket_send_many([], #state{socket=Socket, transport=Transport}, Acc) ->
  355. Transport:send(Socket, lists:reverse(Acc));
  356. websocket_send_many([Frame|Tail], State=#state{socket=Socket, transport=Transport,
  357. extensions=Extensions}, Acc0) ->
  358. Acc = [cow_ws:frame(Frame, Extensions)|Acc0],
  359. case is_close_frame(Frame) of
  360. true ->
  361. _ = Transport:send(Socket, lists:reverse(Acc)),
  362. stop;
  363. false ->
  364. websocket_send_many(Tail, State, Acc)
  365. end.
  366. is_close_frame(close) -> true;
  367. is_close_frame({close, _}) -> true;
  368. is_close_frame({close, _, _}) -> true;
  369. is_close_frame(_) -> false.
  370. -spec websocket_close(#state{}, any(), terminate_reason()) -> no_return().
  371. websocket_close(State, HandlerState, Reason) ->
  372. websocket_send_close(State, Reason),
  373. terminate(State, HandlerState, Reason).
  374. websocket_send_close(#state{socket=Socket, transport=Transport,
  375. extensions=Extensions}, Reason) ->
  376. _ = case Reason of
  377. Normal when Normal =:= stop; Normal =:= timeout ->
  378. Transport:send(Socket, cow_ws:frame({close, 1000, <<>>}, Extensions));
  379. {error, badframe} ->
  380. Transport:send(Socket, cow_ws:frame({close, 1002, <<>>}, Extensions));
  381. {error, badencoding} ->
  382. Transport:send(Socket, cow_ws:frame({close, 1007, <<>>}, Extensions));
  383. {crash, _, _} ->
  384. Transport:send(Socket, cow_ws:frame({close, 1011, <<>>}, Extensions));
  385. remote ->
  386. Transport:send(Socket, cow_ws:frame(close, Extensions));
  387. {remote, Code, _} ->
  388. Transport:send(Socket, cow_ws:frame({close, Code, <<>>}, Extensions))
  389. end,
  390. ok.
  391. -spec terminate(#state{}, any(), terminate_reason()) -> no_return().
  392. terminate(State, HandlerState, Reason) ->
  393. handler_terminate(State, HandlerState, Reason),
  394. exit(normal).
  395. handler_terminate(#state{handler=Handler, req=Req}, HandlerState, Reason) ->
  396. cowboy_handler:terminate(Reason, Req, HandlerState, Handler).