cowboy_websocket.erl 18 KB


  1. %% Copyright (c) 2011-2017, Loïc Hoguin <essen@ninenines.eu>
  2. %%
  3. %% Permission to use, copy, modify, and/or distribute this software for any
  4. %% purpose with or without fee is hereby granted, provided that the above
  5. %% copyright notice and this permission notice appear in all copies.
  6. %%
  7. %% THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  8. %% WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  9. %% MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
  10. %% ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  11. %% WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
  12. %% ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
  13. %% OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  14. %% Cowboy supports versions 7 through 17 of the Websocket drafts.
  15. %% It also supports RFC6455, the proposed standard for Websocket.
  16. -module(cowboy_websocket).
  17. -behaviour(cowboy_sub_protocol).
  18. -export([upgrade/4]).
  19. -export([upgrade/5]).
  20. -export([takeover/7]).
  21. -export([handler_loop/3]).
  %% Result of the websocket_init/1, websocket_handle/2 and websocket_info/2
  %% callbacks. `reply` sends one or more frames to the client, `hibernate`
  %% asks for the connection process to hibernate before it next waits for
  %% data, and `stop` closes the connection (close code 1000).
  -type call_result(State) :: {ok, State}
      | {ok, State, hibernate}
      | {reply, cow_ws:frame() | [cow_ws:frame()], State}
      | {reply, cow_ws:frame() | [cow_ws:frame()], State, hibernate}
      | {stop, State}.

  %% Reason passed to the handler's terminate/3 callback.
  -type terminate_reason() :: normal | stop | timeout
      | remote | {remote, cow_ws:close_code(), binary()}
      | {error, badencoding | badframe | closed | atom()}
      | {crash, error | exit | throw, any()}.

  -callback init(Req, any())
      -> {ok | module(), Req, any()}
      | {module(), Req, any(), any()}
      when Req::cowboy_req:req().

  -callback websocket_init(State)
      -> call_result(State) when State::any().
  -optional_callbacks([websocket_init/1]).

  %% Ping and pong frames without a payload are delivered as the bare atoms
  %% `ping` and `pong`; with a payload they are delivered as
  %% `{ping | pong, Payload}`.
  -callback websocket_handle(ping | pong | {text | binary | ping | pong, binary()}, State)
      -> call_result(State) when State::any().

  -callback websocket_info(any(), State)
      -> call_result(State) when State::any().

  -callback terminate(any(), cowboy_req:req(), any()) -> ok.
  -optional_callbacks([terminate/3]).
  %% Options accepted by upgrade/5.
  -type opts() :: #{
      %% Negotiate the permessage-deflate / x-webkit-deflate-frame extensions
      %% when the client offers them (default false).
      compress => boolean(),
      %% Milliseconds without incoming data before the connection is closed
      %% (default 60000).
      idle_timeout => timeout(),
      %% Builds the Req map later passed to terminate/3; by default only a
      %% subset of the request fields is kept.
      req_filter => fun((cowboy_req:req()) -> map())
  }.
  -export_type([opts/0]).

  -record(state, {
      %% Socket and transport are only known once takeover/7 runs.
      socket = undefined :: inet:socket() | undefined,
      %% The default is undefined, so the declared type must allow it.
      transport = undefined :: module() | undefined,
      handler :: module(),
      %% Client sec-websocket-key; cleared in takeover/7.
      key = undefined :: undefined | binary(),
      %% Idle timeout in milliseconds and the reference of the running timer.
      timeout = infinity :: timeout(),
      timeout_ref = undefined :: undefined | reference(),
      compress = false :: boolean(),
      %% {OK, Closed, Error} message tags returned by Transport:messages().
      messages = undefined :: undefined | {atom(), atom(), atom()},
      %% Set when the last callback asked to hibernate.
      hibernate = false :: boolean(),
      %% Fragmentation state and the payload of the fragments received so far.
      frag_state = undefined :: cow_ws:frag_state(),
      frag_buffer = <<>> :: binary(),
      %% UTF-8 state from cow_ws, carried between payload chunks.
      utf8_state = 0 :: cow_ws:utf8_state(),
      %% Extensions negotiated during the handshake.
      extensions = #{} :: map(),
      %% Filtered request passed to terminate/3.
      req = #{} :: map()
  }).
  %% Stream process.

  %% Equivalent to upgrade/5 with an empty options map, so every option
  %% keeps its default value.
  -spec upgrade(Req, Env, module(), any())
      -> {ok, Req, Env}
      when Req::cowboy_req:req(), Env::cowboy_middleware:env().
  upgrade(Req, Env, Handler, HandlerState) ->
      NoOpts = #{},
      upgrade(Req, Env, Handler, HandlerState, NoOpts).
  %% Validate a Websocket upgrade request and, when it is acceptable, ask the
  %% connection process to switch protocols (websocket_handshake/4).
  %%
  %% A request without the connection/upgrade tokens is answered with 426;
  %% any crash while validating the rest of the handshake yields a 400.
  -spec upgrade(Req, Env, module(), any(), opts())
      -> {ok, Req, Env}
      when Req::cowboy_req:req(), Env::cowboy_middleware:env().
  %% @todo Immediately crash if a response has already been sent.
  %% @todo Error out if HTTP/2.
  upgrade(Req0, Env, Handler, HandlerState, Opts) ->
      IdleTimeout = maps:get(idle_timeout, Opts, 60000),
      Compress = maps:get(compress, Opts, false),
      KeptReq = case maps:get(req_filter, Opts, undefined) of
          undefined ->
              %% Default: only these fields are kept for terminate/3.
              maps:with([method, version, scheme, host, port, path, qs, peer], Req0);
          ReqFilter ->
              ReqFilter(Req0)
      end,
      InitialState = #state{handler=Handler, timeout=IdleTimeout,
          compress=Compress, req=KeptReq},
      try websocket_upgrade(InitialState, Req0) of
          {error, upgrade_required} ->
              RespHeaders = #{
                  <<"connection">> => <<"upgrade">>,
                  <<"upgrade">> => <<"websocket">>
              },
              {ok, cowboy_req:reply(426, RespHeaders, Req0), Env};
          {ok, State, Req} ->
              websocket_handshake(State, Req, HandlerState, Env)
      catch _:_ ->
          %% @todo Probably log something here?
          %% @todo Test that we can have 2 /ws 400 status code in a row on the same connection.
          %% @todo Does this even work?
          {ok, cowboy_req:reply(400, Req0), Env}
      end.
  %% Check the headers of a Websocket upgrade request.
  %%
  %% Returns {error, upgrade_required} when the connection header lacks the
  %% upgrade token or the upgrade header lacks the websocket token. Otherwise
  %% requires a sec-websocket-version of 7, 8 or 13 and a sec-websocket-key,
  %% crashing (badarg/badmatch) when either is missing or invalid; upgrade/5
  %% turns such crashes into a 400 response.
  websocket_upgrade(State, Req) ->
      HasToken = fun(HeaderName, Token) ->
          lists:member(Token, cowboy_req:parse_header(HeaderName, Req, []))
      end,
      %% andalso keeps the original order: the upgrade header is only
      %% parsed once the connection header is known to be acceptable.
      case HasToken(<<"connection">>, <<"upgrade">>)
              andalso HasToken(<<"upgrade">>, <<"websocket">>) of
          false ->
              {error, upgrade_required};
          true ->
              VersionBin = cowboy_req:header(<<"sec-websocket-version">>, Req),
              Version = binary_to_integer(VersionBin),
              true = lists:member(Version, [7, 8, 13]),
              Key = cowboy_req:header(<<"sec-websocket-key">>, Req),
              false = Key =:= undefined,
              websocket_extensions(State#state{key=Key},
                  Req#{websocket_version => Version})
      end.
  %% Negotiate the extensions offered by the client, but only when the
  %% compress option is enabled and a sec-websocket-extensions header was
  %% sent; otherwise continue without any extension.
  websocket_extensions(State=#state{compress=Compress}, Req) ->
      %% @todo We want different options for this. For example
      %% * compress everything auto
      %% * compress only text auto
      %% * compress only binary auto
      %% * compress nothing auto (but still enabled it)
      %% * disable compression
      Offered = cowboy_req:parse_header(<<"sec-websocket-extensions">>, Req),
      if
          Compress =:= true, Offered =/= undefined ->
              websocket_extensions(State, Req, Offered, []);
          true ->
              {ok, State, Req}
      end.
  %% Walk the extensions offered by the client, negotiating the deflate-based
  %% ones and skipping everything else. RespHeader accumulates, in reverse
  %% order, the values for the sec-websocket-extensions response header, each
  %% preceded by a <<", ">> separator.
  websocket_extensions(State, Req, [], []) ->
      {ok, State, Req};
  websocket_extensions(State, Req, [], [<<", ">>|RespHeader]) ->
      %% Drop the leading separator and restore the client's order.
      {ok, State, cowboy_req:set_resp_header(<<"sec-websocket-extensions">>,
          lists:reverse(RespHeader), Req)};
  websocket_extensions(State, Req=#{pid := _},
          [{<<"permessage-deflate">>, Params}|Tail], RespHeader) ->
      websocket_deflate_extension(fun cow_ws:negotiate_permessage_deflate/3,
          State, Req, Params, Tail, RespHeader);
  websocket_extensions(State, Req=#{pid := _},
          [{<<"x-webkit-deflate-frame">>, Params}|Tail], RespHeader) ->
      websocket_deflate_extension(fun cow_ws:negotiate_x_webkit_deflate_frame/3,
          State, Req, Params, Tail, RespHeader);
  websocket_extensions(State, Req, [_|Tail], RespHeader) ->
      websocket_extensions(State, Req, Tail, RespHeader).

  %% Negotiate one deflate extension using Negotiate (a cow_ws negotiation
  %% function of arity 3), then continue with the remaining offers.
  websocket_deflate_extension(Negotiate, State=#state{extensions=Extensions},
          Req=#{pid := Pid}, Params, Tail, RespHeader) ->
      %% @todo Make deflate options configurable.
      Opts = #{level => best_compression, mem_level => 8, strategy => default},
      try Negotiate(Params, Extensions, Opts#{owner => Pid}) of
          {ok, RespExt, Extensions2} ->
              websocket_extensions(State#state{extensions=Extensions2},
                  Req, Tail, [<<", ">>, RespExt|RespHeader]);
          ignore ->
              websocket_extensions(State, Req, Tail, RespHeader)
      catch exit:{error, incompatible_zlib_version, _} ->
          %% An incompatible zlib only disables this extension; the
          %% handshake continues without it.
          websocket_extensions(State, Req, Tail, RespHeader)
      end.
  %% Compute the sec-websocket-accept value from the client key (SHA-1 of the
  %% key followed by the RFC 6455 GUID, base64-encoded) and send the
  %% connection process a switch_protocol message so that it hands the
  %% connection over to this module.
  -spec websocket_handshake(#state{}, Req, any(), Env)
      -> {ok, Req, Env}
      when Req::cowboy_req:req(), Env::cowboy_middleware:env().
  websocket_handshake(State=#state{key=Key},
          Req=#{pid := Pid, streamid := StreamID}, HandlerState, Env) ->
      Digest = crypto:hash(sha,
          << Key/binary, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" >>),
      Accept = base64:encode(Digest),
      %% @todo We don't want date and server headers.
      RespHeaders = cowboy_req:response_headers(#{
          <<"connection">> => <<"Upgrade">>,
          <<"upgrade">> => <<"websocket">>,
          <<"sec-websocket-accept">> => Accept
      }, Req),
      SwitchMsg = {switch_protocol, RespHeaders, ?MODULE, {State, HandlerState}},
      Pid ! {{Pid, StreamID}, SwitchMsg},
      {ok, Req, Env}.
  %% Connection process.

  %% Entry point in the connection process once the protocol is switched:
  %% stores the socket and transport, starts the idle timer, runs the
  %% optional websocket_init/1 callback and enters the receive loop with
  %% Buffer as the data already received.
  %% @todo Keep parent and handle system messages.
  -spec takeover(pid(), ranch:ref(), inet:socket(), module(), any(), binary(),
      {#state{}, any()}) -> ok.
  takeover(_Parent, Ref, Socket, Transport, _Opts, Buffer,
          {State0=#state{handler=Handler}, HandlerState}) ->
      %% @todo We should have an option to disable this behavior.
      ranch:remove_connection(Ref),
      Timed = handler_loop_timeout(State0#state{socket=Socket, transport=Transport}),
      State = Timed#state{key=undefined, messages=Transport:messages()},
      HasInit = erlang:function_exported(Handler, websocket_init, 1),
      if
          HasInit ->
              handler_call(State, HandlerState, Buffer, websocket_init,
                  undefined, fun handler_before_loop/3);
          true ->
              handler_before_loop(State, HandlerState, Buffer)
      end.
  %% Re-arm the socket for one message ({active, once}) and wait for the next
  %% message, hibernating first when the last callback asked for it. The
  %% hibernate flag is cleared so that it only applies once.
  %%
  %% Never returns: the loop either recurses or ends the process.
  -spec handler_before_loop(#state{}, any(), binary()) -> no_return().
  handler_before_loop(State=#state{socket=Socket, transport=Transport,
          hibernate=Hibernate}, HandlerState, SoFar) ->
      Transport:setopts(Socket, [{active, once}]),
      case Hibernate of
          true ->
              proc_lib:hibernate(?MODULE, handler_loop,
                  [State#state{hibernate=false}, HandlerState, SoFar]);
          false ->
              handler_loop(State, HandlerState, SoFar)
      end.
  %% (Re)start the idle timer, cancelling the previous one if any. With a
  %% timeout of infinity no timer is used. When it fires, the timer sends
  %% {timeout, TRef, ?MODULE} to this process.
  -spec handler_loop_timeout(#state{}) -> #state{}.
  handler_loop_timeout(State=#state{timeout=infinity}) ->
      State#state{timeout_ref=undefined};
  handler_loop_timeout(State=#state{timeout_ref=undefined, timeout=Timeout}) ->
      State#state{timeout_ref=erlang:start_timer(Timeout, self(), ?MODULE)};
  handler_loop_timeout(State=#state{timeout_ref=PrevRef, timeout=Timeout}) ->
      _ = erlang:cancel_timer(PrevRef),
      State#state{timeout_ref=erlang:start_timer(Timeout, self(), ?MODULE)}.
  %% Main receive loop. SoFar holds data received but not yet parsed (an
  %% incomplete frame header).
  %% - Socket data resets the idle timer and is parsed.
  %% - Socket close or error terminates.
  %% - Expiry of the idle timer closes the connection.
  %% - Supervisor calls are answered.
  %% - Any other message goes to websocket_info/2.
  %%
  %% Never returns: every path loops or ends the process.
  -spec handler_loop(#state{}, any(), binary()) -> no_return().
  handler_loop(State=#state{socket=Socket, messages={OK, Closed, Error},
          timeout_ref=TRef}, HandlerState, SoFar) ->
      receive
          {OK, Socket, Data} ->
              State2 = handler_loop_timeout(State),
              websocket_data(State2, HandlerState,
                  << SoFar/binary, Data/binary >>);
          {Closed, Socket} ->
              terminate(State, HandlerState, {error, closed});
          {Error, Socket, Reason} ->
              terminate(State, HandlerState, {error, Reason});
          {timeout, TRef, ?MODULE} ->
              websocket_close(State, HandlerState, timeout);
          %% Stale message from a timer that was cancelled after it fired.
          {timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) ->
              handler_loop(State, HandlerState, SoFar);
          %% Calls from supervisor module.
          {'$gen_call', From, Call} ->
              cowboy_children:handle_supervisor_call(Call, From, [], ?MODULE),
              handler_loop(State, HandlerState, SoFar);
          Message ->
              handler_call(State, HandlerState,
                  SoFar, websocket_info, Message, fun handler_before_loop/3)
      end.
  %% Parse the next frame header from Data. An unmasked frame or a header
  %% parse error closes the connection with {error, badframe}. An incomplete
  %% header waits for more data. A complete header continues with the
  %% payload (websocket_payload/10).
  %%
  %% Never returns: every path loops or ends the process.
  -spec websocket_data(#state{}, any(), binary()) -> no_return().
  websocket_data(State=#state{frag_state=FragState, extensions=Extensions},
          HandlerState, Data) ->
      case cow_ws:parse_header(Data, Extensions, FragState) of
          %% All frames sent from the client to the server are masked.
          {_, _, _, _, undefined, _} ->
              websocket_close(State, HandlerState, {error, badframe});
          {Type, FragState2, Rsv, Len, MaskKey, Rest} ->
              websocket_payload(State#state{frag_state=FragState2}, HandlerState,
                  Type, Len, MaskKey, Rsv, undefined, <<>>, 0, Rest);
          more ->
              handler_before_loop(State, HandlerState, Data);
          error ->
              websocket_close(State, HandlerState, {error, badframe})
      end.
  %% Parse (part of) the current frame's payload with cow_ws:parse_payload/9.
  %% - Unmasked holds the payload already parsed from earlier chunks.
  %% - UnmaskedLen is the number of payload bytes consumed so far.
  %% - Len is the number of payload bytes still expected.
  %% - CloseCode is the close code reported by cow_ws so far (undefined
  %%   initially).
  %% A complete payload is dispatched. A partial one is accumulated while
  %% waiting for more data. A parse error closes the connection.
  websocket_payload(State=#state{frag_state=FragState, utf8_state=Incomplete,
          extensions=Extensions},
          HandlerState, Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen, Data) ->
      Parsed = cow_ws:parse_payload(Data, MaskKey, Incomplete, UnmaskedLen,
          Type, Len, FragState, Extensions, Rsv),
      case Parsed of
          {ok, NewCloseCode, Chunk, Utf8State, Rest} ->
              websocket_dispatch(State#state{utf8_state=Utf8State}, HandlerState,
                  Type, << Unmasked/binary, Chunk/binary >>, NewCloseCode, Rest);
          {ok, Chunk, Utf8State, Rest} ->
              websocket_dispatch(State#state{utf8_state=Utf8State}, HandlerState,
                  Type, << Unmasked/binary, Chunk/binary >>, CloseCode, Rest);
          {more, NewCloseCode, Chunk, Utf8State} ->
              websocket_payload_loop(State#state{utf8_state=Utf8State}, HandlerState,
                  Type, Len - byte_size(Data), MaskKey, Rsv, NewCloseCode,
                  << Unmasked/binary, Chunk/binary >>, UnmaskedLen + byte_size(Data));
          {more, Chunk, Utf8State} ->
              websocket_payload_loop(State#state{utf8_state=Utf8State}, HandlerState,
                  Type, Len - byte_size(Data), MaskKey, Rsv, CloseCode,
                  << Unmasked/binary, Chunk/binary >>, UnmaskedLen + byte_size(Data));
          {error, _Reason} = ParseError ->
              websocket_close(State, HandlerState, ParseError)
      end.
  %% Wait for the rest of a frame's payload. Messages are handled as in
  %% handler_loop/3, except that other Erlang messages are passed to
  %% websocket_info/2, after which we resume waiting for the payload.
  websocket_payload_loop(State=#state{socket=Socket, transport=Transport,
          messages={OK, Closed, Error}, timeout_ref=TRef},
          HandlerState, Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen) ->
      Transport:setopts(Socket, [{active, once}]),
      receive
          {OK, Socket, Data} ->
              %% Incoming data resets the idle timer.
              State2 = handler_loop_timeout(State),
              websocket_payload(State2, HandlerState,
                  Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen, Data);
          {Closed, Socket} ->
              terminate(State, HandlerState, {error, closed});
          {Error, Socket, Reason} ->
              terminate(State, HandlerState, {error, Reason});
          {timeout, TRef, ?MODULE} ->
              websocket_close(State, HandlerState, timeout);
          %% Stale message from a timer that was cancelled after it fired.
          {timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) ->
              websocket_payload_loop(State, HandlerState,
                  Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen);
          %% Calls from supervisor module. Handled here exactly as in
          %% handler_loop/3; otherwise they would be forwarded to
          %% websocket_info/2 and the caller would not get a reply from Cowboy.
          {'$gen_call', From, Call} ->
              cowboy_children:handle_supervisor_call(Call, From, [], ?MODULE),
              websocket_payload_loop(State, HandlerState,
                  Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen);
          Message ->
              handler_call(State, HandlerState,
                  <<>>, websocket_info, Message,
                  fun (State2, HandlerState2, _) ->
                      websocket_payload_loop(State2, HandlerState2,
                          Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen)
                  end)
      end.
  %% Act on a fully received frame:
  %% - buffer non-final fragments;
  %% - pass completed fragmented messages and all other frames to
  %%   websocket_handle/2, first answering pings with a pong that carries the
  %%   same payload;
  %% - close the connection when the client sent a close frame.
  websocket_dispatch(State=#state{socket=Socket, transport=Transport,
          frag_state=FragState, frag_buffer=SoFar, extensions=Extensions},
          HandlerState, Type0, Payload0, CloseCode0, RemainingData) ->
      case cow_ws:make_frame(Type0, Payload0, CloseCode0, FragState) of
          %% @todo Allow receiving fragments.
          {fragment, nofin, _, FragPayload} ->
              Buffered = << SoFar/binary, FragPayload/binary >>,
              websocket_data(State#state{frag_buffer=Buffered},
                  HandlerState, RemainingData);
          {fragment, fin, MsgType, LastPayload} ->
              Message = {MsgType, << SoFar/binary, LastPayload/binary >>},
              handler_call(State#state{frag_state=undefined, frag_buffer= <<>>},
                  HandlerState, RemainingData, websocket_handle, Message,
                  fun websocket_data/3);
          close ->
              websocket_close(State, HandlerState, remote);
          {close, Code, Reason} ->
              websocket_close(State, HandlerState, {remote, Code, Reason});
          Frame ->
              _ = case Frame of
                  ping ->
                      Transport:send(Socket, cow_ws:frame(pong, Extensions));
                  {ping, PingPayload} ->
                      Transport:send(Socket, cow_ws:frame({pong, PingPayload}, Extensions));
                  _ ->
                      ok
              end,
              handler_call(State, HandlerState, RemainingData, websocket_handle,
                  Frame, fun websocket_data/3)
      end.
  %% erlang:get_stacktrace/0 is deprecated since OTP 21 and removed in later
  %% releases. The Class:Reason:Stacktrace catch pattern replaces it.
  %% OTP_RELEASE is only defined from OTP 21 onwards, so older releases keep
  %% using the old function.
  -ifdef(OTP_RELEASE).
  -define(WS_EXCEPTION(Class, Reason, Stacktrace), Class:Reason:Stacktrace).
  -define(WS_GET_STACK(Stacktrace), Stacktrace).
  -else.
  -define(WS_EXCEPTION(Class, Reason, _Stacktrace), Class:Reason).
  -define(WS_GET_STACK(_Stacktrace), erlang:get_stacktrace()).
  -endif.

  %% Run a handler callback and act on its result:
  %% - continue with NextState(State, HandlerState2, RemainingData);
  %% - for {reply, ...}, send the frames first;
  %% - for {stop, _}, close the connection.
  %% If the callback crashes, a 1011 close frame is sent, the handler is
  %% terminated with {crash, Class, Reason} and the exception is re-raised.
  -spec handler_call(#state{}, any(), binary(), atom(), any(), fun()) -> no_return().
  handler_call(State=#state{handler=Handler}, HandlerState,
          RemainingData, Callback, Message, NextState) ->
      try case Callback of
          websocket_init -> Handler:websocket_init(HandlerState);
          _ -> Handler:Callback(Message, HandlerState)
      end of
          {ok, HandlerState2} ->
              NextState(State, HandlerState2, RemainingData);
          {ok, HandlerState2, hibernate} ->
              NextState(State#state{hibernate=true}, HandlerState2, RemainingData);
          {reply, Payload, HandlerState2} ->
              handler_call_reply(State, HandlerState2, RemainingData,
                  Payload, NextState);
          {reply, Payload, HandlerState2, hibernate} ->
              handler_call_reply(State#state{hibernate=true}, HandlerState2,
                  RemainingData, Payload, NextState);
          {stop, HandlerState2} ->
              websocket_close(State, HandlerState2, stop)
      catch ?WS_EXCEPTION(Class, Reason, Stacktrace0) ->
          %% Grab the stacktrace before running anything else. The calls
          %% below could raise and catch exceptions of their own, which
          %% would replace what erlang:get_stacktrace/0 returns.
          Stacktrace = ?WS_GET_STACK(Stacktrace0),
          websocket_send_close(State, {crash, Class, Reason}),
          handler_terminate(State, HandlerState, {crash, Class, Reason}),
          erlang:raise(Class, Reason, Stacktrace)
      end.

  %% Send the frame(s) of a {reply, ...} callback result, then continue with
  %% NextState. Terminate instead when a close frame was sent (stop) or the
  %% send failed.
  handler_call_reply(State, HandlerState, RemainingData, Payload, NextState) ->
      case websocket_send(Payload, State) of
          ok ->
              NextState(State, HandlerState, RemainingData);
          stop ->
              terminate(State, HandlerState, stop);
          Error = {error, _} ->
              terminate(State, HandlerState, Error)
      end.
  %% Send one frame or a list of frames to the client. Returns stop when a
  %% close frame was sent; in a list, frames after it are not sent. Otherwise
  %% returns the result of Transport:send/2.
  -spec websocket_send(cow_ws:frame() | [cow_ws:frame()], #state{})
      -> ok | stop | {error, atom()}.
  websocket_send(Frames, State) when is_list(Frames) ->
      websocket_send_many(Frames, State, []);
  websocket_send(Frame, #state{socket=Socket, transport=Transport, extensions=Extensions}) ->
      Res = Transport:send(Socket, cow_ws:frame(Frame, Extensions)),
      case is_close_frame(Frame) of
          true -> stop;
          false -> Res
      end.
  %% Encode Frames one at a time onto Acc (kept in reverse order), stopping
  %% after the first close frame, and send everything in a single
  %% Transport:send/2 call. Returns stop if a close frame was included (the
  %% send result is then ignored), otherwise the result of the send.
  websocket_send_many([Frame|Rest], State=#state{socket=Socket, transport=Transport,
          extensions=Extensions}, Acc) ->
      Encoded = [cow_ws:frame(Frame, Extensions)|Acc],
      IsClose = is_close_frame(Frame),
      if
          IsClose ->
              _ = Transport:send(Socket, lists:reverse(Encoded)),
              stop;
          true ->
              websocket_send_many(Rest, State, Encoded)
      end;
  websocket_send_many([], #state{socket=Socket, transport=Transport}, Acc) ->
      Transport:send(Socket, lists:reverse(Acc)).
  %% True for every close frame representation: the atom close, or a 2- or
  %% 3-tuple tagged close.
  is_close_frame(Frame) ->
      case Frame of
          close -> true;
          {close, _} -> true;
          {close, _, _} -> true;
          _ -> false
      end.
  %% Send the close frame matching Reason, then terminate the connection.
  -spec websocket_close(#state{}, any(), terminate_reason()) -> no_return().
  websocket_close(State, HandlerState, Reason) ->
      ok = websocket_send_close(State, Reason),
      terminate(State, HandlerState, Reason).
  %% Send the close frame corresponding to Reason. Best-effort: the send
  %% result is ignored and ok is always returned.
  websocket_send_close(#state{socket=Socket, transport=Transport,
          extensions=Extensions}, Reason) ->
      CloseFrame = case Reason of
          stop -> {close, 1000, <<>>};
          timeout -> {close, 1000, <<>>};
          {error, badframe} -> {close, 1002, <<>>};
          {error, badencoding} -> {close, 1007, <<>>};
          {crash, _, _} -> {close, 1011, <<>>};
          %% The client closed without a status code: reply the same way.
          remote -> close;
          %% Echo the client's close code back.
          {remote, Code, _} -> {close, Code, <<>>}
      end,
      _ = Transport:send(Socket, cow_ws:frame(CloseFrame, Extensions)),
      ok.
  %% Let the handler clean up (via cowboy_handler:terminate/4), then end the
  %% connection process with reason normal.
  -spec terminate(#state{}, any(), terminate_reason()) -> no_return().
  terminate(State, HandlerState, Reason) ->
      _ = handler_terminate(State, HandlerState, Reason),
      exit(normal).
  %% Delegate to cowboy_handler:terminate/4 with the filtered request kept in
  %% the state (see the req_filter option).
  handler_terminate(#state{req=FilteredReq, handler=Mod}, ModState, Reason) ->
      cowboy_handler:terminate(Reason, FilteredReq, ModState, Mod).