cowboy_websocket.erl 18 KB


  1. %% Copyright (c) 2011-2014, Loïc Hoguin <essen@ninenines.eu>
  2. %%
  3. %% Permission to use, copy, modify, and/or distribute this software for any
  4. %% purpose with or without fee is hereby granted, provided that the above
  5. %% copyright notice and this permission notice appear in all copies.
  6. %%
  7. %% THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  8. %% WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  9. %% MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
  10. %% ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  11. %% WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
  12. %% ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
  13. %% OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  14. %% Cowboy supports versions 7 through 17 of the Websocket drafts.
  15. %% It also supports RFC6455, the proposed standard for Websocket.
  %% Websocket sub-protocol: performs the HTTP upgrade, then runs the
  %% frame loop that drives the user handler's websocket_* callbacks.
  -module(cowboy_websocket).
  -behaviour(cowboy_sub_protocol).

  %% Entry points called from outside this module.
  -export([upgrade/6, takeover/7]).
  %% Exported because the loop is re-entered through proc_lib:hibernate/3.
  -export([handler_loop/4]).

  %% Reason passed along when the connection is being terminated.
  -type terminate_reason() ::
      normal
      | stop
      | timeout
      | remote
      | {remote, cow_ws:close_code(), binary()}
      | {error, badencoding | badframe | closed | atom()}
      | {crash, error | exit | throw, any()}.
  %% Callbacks implemented by Websocket handler modules.

  %% Called for the initial HTTP request; the optional 4th/5th tuple
  %% elements carry a timeout and/or the hibernate flag.
  -callback init(Req, any()) ->
      {ok | module(), Req, any()}
      | {module(), Req, any(), hibernate}
      | {module(), Req, any(), timeout()}
      | {module(), Req, any(), timeout(), hibernate}
      when Req::cowboy_req:req().

  %% Called for every frame received from the client.
  -callback websocket_handle({text | binary | ping | pong, binary()}, Req, State) ->
      {ok, Req, State}
      | {ok, Req, State, hibernate}
      | {reply, cow_ws:frame() | [cow_ws:frame()], Req, State}
      | {reply, cow_ws:frame() | [cow_ws:frame()], Req, State, hibernate}
      | {stop, Req, State}
      when Req::cowboy_req:req(), State::any().

  %% Called for every Erlang message received by the connection process
  %% that is not socket data or one of this module's own timers.
  -callback websocket_info(any(), Req, State) ->
      {ok, Req, State}
      | {ok, Req, State, hibernate}
      | {reply, cow_ws:frame() | [cow_ws:frame()], Req, State}
      | {reply, cow_ws:frame() | [cow_ws:frame()], Req, State, hibernate}
      | {stop, Req, State}
      when Req::cowboy_req:req(), State::any().

  %% Optional cleanup hook.
  -callback terminate(any(), cowboy_req:req(), any()) -> ok.
  -optional_callbacks([terminate/3]).
  %% Per-connection state threaded through every loop function.
  %%
  %% socket and transport are only filled in by takeover/7; until then they
  %% hold their `undefined' default, so the declared types must allow it
  %% (as the other defaulted fields already do).
  -record(state, {
      %% Set in takeover/7.
      socket = undefined :: undefined | inet:socket(),
      transport = undefined :: undefined | module(),
      %% Module implementing the websocket_* callbacks.
      handler :: module(),
      %% Value of the sec-websocket-key request header; reset to
      %% undefined in takeover/7 once the handshake has been sent.
      key = undefined :: undefined | binary(),
      %% Timeout value and the reference of the timer currently running.
      %% The timer is restarted whenever socket data is received.
      timeout = infinity :: timeout(),
      timeout_ref = undefined :: undefined | reference(),
      %% Result of Transport:messages(), matched as {OK, Closed, Error}.
      messages = undefined :: undefined | {atom(), atom(), atom()},
      %% When true, the process hibernates before the next loop iteration.
      hibernate = false :: boolean(),
      %% Fragmentation state and the payload accumulated from fragments.
      frag_state = undefined :: cow_ws:frag_state(),
      frag_buffer = <<>> :: binary(),
      %% UTF-8 validation state passed to cow_ws:parse_payload/9.
      utf8_state = 0 :: cow_ws:utf8_state(),
      %% Negotiated extensions, passed to cow_ws when parsing/building frames.
      extensions = #{} :: map()
  }).
  61. %% Stream process.
  %% Validates the upgrade request and, when it is acceptable, asks the
  %% connection process to switch protocols. Any failure while validating
  %% the request results in a 400 reply.
  -spec upgrade(Req, Env, module(), any(), timeout(), run | hibernate)
      -> {ok, Req, Env}
      when Req::cowboy_req:req(), Env::cowboy_middleware:env().
  upgrade(Req, Env, Handler, HandlerState, Timeout, Hibernate) ->
      InitialState = #state{
          handler = Handler,
          timeout = Timeout,
          hibernate = (Hibernate =:= hibernate)
      },
      %% @todo We need to fail if HTTP/2.
      try websocket_upgrade(InitialState, Req) of
          {ok, UpgradedState, UpgradedReq} ->
              %% Not protected by the catch below: only validation errors
              %% are turned into a 400.
              websocket_handshake(UpgradedState, UpgradedReq, HandlerState, Env)
      catch
          _:_ ->
              %% @todo Test that we can have 2 /ws 400 status code in a row on the same connection.
              cowboy_req:reply(400, Req),
              {ok, Req, Env}
      end.
  %% Checks the request headers required for a Websocket upgrade.
  %% Every check is an assertive match: a failing one crashes, and the
  %% caller (upgrade/6) converts the crash into a 400 reply.
  -spec websocket_upgrade(#state{}, Req)
      -> {ok, #state{}, Req} when Req::cowboy_req:req().
  websocket_upgrade(State, Req) ->
      true = lists:member(<<"upgrade">>,
          cowboy_req:parse_header(<<"connection">>, Req)),
      %% @todo Should probably send a 426 if the Upgrade header is missing.
      [<<"websocket">>] = cowboy_req:parse_header(<<"upgrade">>, Req),
      IntVersion = binary_to_integer(
          cowboy_req:header(<<"sec-websocket-version">>, Req)),
      %% Only protocol versions 7, 8 and 13 are accepted.
      true = lists:member(IntVersion, [7, 8, 13]),
      Key = cowboy_req:header(<<"sec-websocket-key">>, Req),
      true = Key =/= undefined,
      websocket_extensions(State#state{key=Key},
          Req#{websocket_version => IntVersion}).
  %% Starts extension negotiation. Negotiation only happens when the
  %% request map has websocket_compress set to true and the client sent a
  %% sec-websocket-extensions header; in every case the returned request
  %% has websocket_compress reset to false before negotiation begins.
  -spec websocket_extensions(#state{}, Req)
      -> {ok, #state{}, Req} when Req::cowboy_req:req().
  websocket_extensions(State, Req) ->
      %% @todo We want different options for this. For example
      %% * compress everything auto
      %% * compress only text auto
      %% * compress only binary auto
      %% * compress nothing auto (but still enabled it)
      %% * disable compression
      Compress = maps:get(websocket_compress, Req, false),
      Req2 = Req#{websocket_compress => false},
      %% The header is parsed even when compression is disabled.
      Offered = cowboy_req:parse_header(<<"sec-websocket-extensions">>, Req2),
      if
          Compress =:= true, Offered =/= undefined ->
              websocket_extensions(State, Req2, Offered, []);
          true ->
              {ok, State, Req2}
      end.
  %% Walks the extensions offered by the client and negotiates the
  %% supported ones (permessage-deflate and x-webkit-deflate-frame);
  %% unknown extensions are skipped.
  %%
  %% RespHeader accumulates, in reverse order, the accepted extensions each
  %% preceded by a <<", ">> separator. When the input is exhausted the
  %% leading separator is dropped and the list is reversed to form the
  %% sec-websocket-extensions response header.
  %%
  %% A successful negotiation is recorded by setting the websocket_compress
  %% key of the request map to true for both extensions. That is the key
  %% websocket_extensions/2 reads and resets.
  websocket_extensions(State, Req, [], []) ->
      {ok, State, Req};
  websocket_extensions(State, Req, [], [<<", ">>|RespHeader]) ->
      {ok, State, cowboy_req:set_resp_header(<<"sec-websocket-extensions">>,
          lists:reverse(RespHeader), Req)};
  websocket_extensions(State=#state{extensions=Extensions}, Req,
          [{<<"permessage-deflate">>, Params}|Tail], RespHeader) ->
      Result = cow_ws:negotiate_permessage_deflate(Params, Extensions,
          websocket_deflate_opts()),
      websocket_extension_result(Result, State, Req, Tail, RespHeader);
  websocket_extensions(State=#state{extensions=Extensions}, Req,
          [{<<"x-webkit-deflate-frame">>, Params}|Tail], RespHeader) ->
      Result = cow_ws:negotiate_x_webkit_deflate_frame(Params, Extensions,
          websocket_deflate_opts()),
      websocket_extension_result(Result, State, Req, Tail, RespHeader);
  websocket_extensions(State, Req, [_|Tail], RespHeader) ->
      websocket_extensions(State, Req, Tail, RespHeader).

  %% Deflate options shared by both compression extensions.
  %% @todo Make deflate options configurable.
  websocket_deflate_opts() ->
      #{level => best_compression, mem_level => 8, strategy => default}.

  %% Continues the walk after one negotiation attempt: an accepted
  %% extension updates the state and response header, an ignored one
  %% leaves everything untouched.
  websocket_extension_result({ok, RespExt, Extensions}, State, Req, Tail, RespHeader) ->
      websocket_extensions(State#state{extensions=Extensions},
          Req#{websocket_compress => true},
          Tail, [<<", ">>, RespExt|RespHeader]);
  websocket_extension_result(ignore, State, Req, Tail, RespHeader) ->
      websocket_extensions(State, Req, Tail, RespHeader).
  %% Computes the sec-websocket-accept challenge from the client key and
  %% sends a switch_protocol request, tagged with {Pid, StreamID}, to the
  %% process found under the pid key of the request.
  -spec websocket_handshake(#state{}, Req, any(), Env)
      -> {ok, Req, Env}
      when Req::cowboy_req:req(), Env::cowboy_middleware:env().
  websocket_handshake(State=#state{key=Key},
          Req=#{pid := Pid, streamid := StreamID}, HandlerState, Env) ->
      Guid = <<"258EAFA5-E914-47DA-95CA-C5AB0DC85B11">>,
      Digest = crypto:hash(sha, <<Key/binary, Guid/binary>>),
      Headers = #{
          %% @todo Hmm should those be here or in cowboy_http?
          <<"connection">> => <<"Upgrade">>,
          <<"upgrade">> => <<"websocket">>,
          <<"sec-websocket-accept">> => base64:encode(Digest)
      },
      Switch = {switch_protocol, Headers, ?MODULE, {Req, State, HandlerState}},
      Pid ! {{Pid, StreamID}, Switch},
      {ok, Req, Env}.
  151. %% Connection process.
  152. %% @todo Keep parent and handle system messages.
  %% Takes over the socket after the protocol switch: removes the
  %% connection from ranch's count, runs the optional websocket_init/2
  %% callback, starts the timeout timer and enters the loop with whatever
  %% data was already buffered.
  -spec takeover(pid(), ranch:ref(), inet:socket(), module(), any(), binary(),
      {cowboy_req:req(), #state{}, any()}) -> ok.
  takeover(_Parent, Ref, Socket, Transport, _Opts, Buffer,
          {Req0, State0=#state{handler=Handler}, HandlerState0}) ->
      ranch:remove_connection(Ref),
      %% @todo Remove Req from Websocket callbacks.
      %% @todo Allow sending a reply from websocket_init.
      %% @todo Try/catch.
      {ok, Req, HandlerState} =
          case erlang:function_exported(Handler, websocket_init, 2) of
              false -> {ok, Req0, HandlerState0};
              true -> Handler:websocket_init(Req0, HandlerState0)
          end,
      State1 = handler_loop_timeout(
          State0#state{socket=Socket, transport=Transport}),
      %% The key is only needed for the handshake; drop it.
      State = State1#state{key=undefined, messages=Transport:messages()},
      handler_before_loop(State, Req, HandlerState, Buffer).
  %% Re-arms the socket for one more message, then either hibernates
  %% (resuming in handler_loop/4 with the hibernate flag cleared) or enters
  %% handler_loop/4 directly.
  -spec handler_before_loop(#state{}, Req, any(), binary())
      -> {ok, Req, cowboy_middleware:env()}
      when Req::cowboy_req:req().
  handler_before_loop(State=#state{socket=Socket, transport=Transport,
          hibernate=Hibernate}, Req, HandlerState, SoFar) ->
      Transport:setopts(Socket, [{active, once}]),
      case Hibernate of
          true ->
              proc_lib:hibernate(?MODULE, handler_loop,
                  [State#state{hibernate=false}, Req, HandlerState, SoFar]);
          _ ->
              handler_loop(State, Req, HandlerState, SoFar)
      end.
  %% (Re)starts the timeout timer. With an infinite timeout no timer is
  %% used; otherwise any previous timer is cancelled and a new one, whose
  %% message is tagged with ?MODULE, is started.
  -spec handler_loop_timeout(#state{}) -> #state{}.
  handler_loop_timeout(State=#state{timeout=infinity}) ->
      State#state{timeout_ref=undefined};
  handler_loop_timeout(State=#state{timeout=Timeout, timeout_ref=undefined}) ->
      State#state{timeout_ref=erlang:start_timer(Timeout, self(), ?MODULE)};
  handler_loop_timeout(State=#state{timeout=Timeout, timeout_ref=PrevRef}) ->
      _ = erlang:cancel_timer(PrevRef),
      State#state{timeout_ref=erlang:start_timer(Timeout, self(), ?MODULE)}.
  %% Main receive loop. Socket data restarts the timeout timer and is
  %% appended to the buffered bytes before parsing; socket close/error
  %% terminate the handler; expiry of the current timer closes the
  %% connection; timer messages with any other reference are dropped; any
  %% other message is handed to the websocket_info callback.
  -spec handler_loop(#state{}, Req, any(), binary())
      -> {ok, Req, cowboy_middleware:env()}
      when Req::cowboy_req:req().
  handler_loop(State=#state{socket=Socket,
          messages={DataTag, ClosedTag, ErrorTag}, timeout_ref=TimerRef},
          Req, HandlerState, Buffer) ->
      receive
          {DataTag, Socket, Data} ->
              websocket_data(handler_loop_timeout(State), Req, HandlerState,
                  <<Buffer/binary, Data/binary>>);
          {ClosedTag, Socket} ->
              handler_terminate(State, Req, HandlerState, {error, closed});
          {ErrorTag, Socket, Reason} ->
              handler_terminate(State, Req, HandlerState, {error, Reason});
          {timeout, TimerRef, ?MODULE} ->
              websocket_close(State, Req, HandlerState, timeout);
          {timeout, StaleRef, ?MODULE} when is_reference(StaleRef) ->
              %% Timer message with a reference other than the current one.
              handler_loop(State, Req, HandlerState, Buffer);
          Info ->
              handler_call(State, Req, HandlerState, Buffer,
                  websocket_info, Info, fun handler_before_loop/4)
      end.
  %% Parses a frame header out of the buffered data. Incomplete headers go
  %% back to the loop for more data; malformed or unmasked frames close the
  %% connection with a badframe error; otherwise payload parsing starts.
  -spec websocket_data(#state{}, Req, any(), binary())
      -> {ok, Req, cowboy_middleware:env()}
      when Req::cowboy_req:req().
  websocket_data(State=#state{frag_state=FragState, extensions=Extensions},
          Req, HandlerState, Data) ->
      case cow_ws:parse_header(Data, Extensions, FragState) of
          more ->
              handler_before_loop(State, Req, HandlerState, Data);
          error ->
              websocket_close(State, Req, HandlerState, {error, badframe});
          %% All frames sent from the client to the server are masked.
          {_, _, _, _, undefined, _} ->
              websocket_close(State, Req, HandlerState, {error, badframe});
          {Type, NextFragState, Rsv, Len, MaskKey, Rest} ->
              websocket_payload(State#state{frag_state=NextFragState},
                  Req, HandlerState, Type, Len, MaskKey, Rsv,
                  undefined, <<>>, 0, Rest)
      end.
  %% Parses (part of) a frame payload. A complete payload is dispatched
  %% together with the data that follows it; an incomplete one is
  %% accumulated in Unmasked and more data is awaited in
  %% websocket_payload_loop/10. The variants of the parser result that
  %% carry a close code replace the CloseCode passed in; the others keep it.
  websocket_payload(State=#state{frag_state=FragState, utf8_state=Incomplete,
          extensions=Extensions}, Req, HandlerState, Type, Len, MaskKey, Rsv,
          CloseCode, Unmasked, UnmaskedLen, Data) ->
      Parsed = cow_ws:parse_payload(Data, MaskKey, Incomplete, UnmaskedLen,
          Type, Len, FragState, Extensions, Rsv),
      case Parsed of
          {ok, NewCloseCode, Payload, Utf8State, Rest} ->
              websocket_dispatch(State#state{utf8_state=Utf8State},
                  Req, HandlerState, Type,
                  <<Unmasked/binary, Payload/binary>>, NewCloseCode, Rest);
          {ok, Payload, Utf8State, Rest} ->
              websocket_dispatch(State#state{utf8_state=Utf8State},
                  Req, HandlerState, Type,
                  <<Unmasked/binary, Payload/binary>>, CloseCode, Rest);
          {more, NewCloseCode, Payload, Utf8State} ->
              Consumed = byte_size(Data),
              websocket_payload_loop(State#state{utf8_state=Utf8State},
                  Req, HandlerState, Type, Len - Consumed, MaskKey, Rsv,
                  NewCloseCode, <<Unmasked/binary, Payload/binary>>,
                  UnmaskedLen + Consumed);
          {more, Payload, Utf8State} ->
              Consumed = byte_size(Data),
              websocket_payload_loop(State#state{utf8_state=Utf8State},
                  Req, HandlerState, Type, Len - Consumed, MaskKey, Rsv,
                  CloseCode, <<Unmasked/binary, Payload/binary>>,
                  UnmaskedLen + Consumed);
          {error, _Reason} = Error ->
              websocket_close(State, Req, HandlerState, Error)
      end.
  %% Waits for more socket data while in the middle of a frame payload.
  %% Mirrors handler_loop/4, except that other messages are delivered to
  %% websocket_info with an empty buffer and then resume this payload loop
  %% instead of going back to the main loop.
  websocket_payload_loop(State=#state{socket=Socket, transport=Transport,
          messages={DataTag, ClosedTag, ErrorTag}, timeout_ref=TimerRef},
          Req, HandlerState, Type, Len, MaskKey, Rsv, CloseCode,
          Unmasked, UnmaskedLen) ->
      Transport:setopts(Socket, [{active, once}]),
      receive
          {DataTag, Socket, Data} ->
              websocket_payload(handler_loop_timeout(State), Req, HandlerState,
                  Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen,
                  Data);
          {ClosedTag, Socket} ->
              handler_terminate(State, Req, HandlerState, {error, closed});
          {ErrorTag, Socket, Reason} ->
              handler_terminate(State, Req, HandlerState, {error, Reason});
          {timeout, TimerRef, ?MODULE} ->
              websocket_close(State, Req, HandlerState, timeout);
          {timeout, StaleRef, ?MODULE} when is_reference(StaleRef) ->
              %% Timer message with a reference other than the current one.
              websocket_payload_loop(State, Req, HandlerState,
                  Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen);
          Info ->
              Resume = fun (NextState, NextReq, NextHandlerState, _) ->
                  websocket_payload_loop(NextState, NextReq, NextHandlerState,
                      Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen)
              end,
              handler_call(State, Req, HandlerState,
                  <<>>, websocket_info, Info, Resume)
      end.
  %% Turns a fully received payload into a frame and acts on it:
  %% non-final fragments are buffered, a final fragment is delivered to the
  %% handler together with the buffered fragments, close frames close the
  %% connection, pings are answered with a pong before being delivered,
  %% and every other frame is delivered to websocket_handle as is.
  websocket_dispatch(State=#state{socket=Socket, transport=Transport,
          frag_state=FragState, frag_buffer=Buffered, extensions=Extensions},
          Req, HandlerState, Type0, Payload0, CloseCode0, RemainingData) ->
      case cow_ws:make_frame(Type0, Payload0, CloseCode0, FragState) of
          %% @todo Allow receiving fragments.
          {fragment, nofin, _, Payload} ->
              Buffered2 = <<Buffered/binary, Payload/binary>>,
              websocket_data(State#state{frag_buffer=Buffered2},
                  Req, HandlerState, RemainingData);
          {fragment, fin, Type, Payload} ->
              handler_call(State#state{frag_state=undefined, frag_buffer= <<>>},
                  Req, HandlerState, RemainingData, websocket_handle,
                  {Type, <<Buffered/binary, Payload/binary>>},
                  fun websocket_data/4);
          close ->
              websocket_close(State, Req, HandlerState, remote);
          {close, CloseCode, Payload} ->
              websocket_close(State, Req, HandlerState,
                  {remote, CloseCode, Payload});
          Frame ->
              %% Pings get an automatic pong before reaching the handler.
              case Frame of
                  ping ->
                      Transport:send(Socket, cow_ws:frame(pong, Extensions));
                  {ping, PingData} ->
                      Transport:send(Socket,
                          cow_ws:frame({pong, PingData}, Extensions));
                  _ ->
                      ok
              end,
              handler_call(State, Req, HandlerState, RemainingData,
                  websocket_handle, Frame, fun websocket_data/4)
      end.
  %% Invokes Handler:Callback(Message, Req, HandlerState) and acts on the
  %% returned tuple:
  %%   {ok, ...}    -> continue with NextState
  %%   {reply, ...} -> send the frame(s), then continue with NextState
  %%   {stop, ...}  -> close the connection
  %% A trailing `hibernate' element sets the hibernate flag before
  %% continuing. NextState is a fun of arity 4 taking
  %% (State, Req, HandlerState, RemainingData).
  %%
  %% Only the callback itself is protected by the try. If it raises, the
  %% stacktrace is captured immediately (before any other call can replace
  %% it), a 1011 close frame is sent, the handler is terminated with
  %% {crash, Class, Reason}, and the process exits with a
  %% {cowboy_handler, Details} reason. The close/terminate steps are done
  %% here rather than through websocket_close/4, because that function ends
  %% in exit(normal) and would make the crash exit unreachable.
  -spec handler_call(#state{}, Req, any(), binary(), atom(), any(), fun())
      -> {ok, Req, cowboy_middleware:env()}
      when Req::cowboy_req:req().
  handler_call(State=#state{handler=Handler}, Req, HandlerState,
          RemainingData, Callback, Message, NextState) ->
      try Handler:Callback(Message, Req, HandlerState) of
          {ok, Req2, HandlerState2} ->
              NextState(State, Req2, HandlerState2, RemainingData);
          {ok, Req2, HandlerState2, hibernate} ->
              NextState(State#state{hibernate=true},
                  Req2, HandlerState2, RemainingData);
          {reply, Payload, Req2, HandlerState2} ->
              handler_call_reply(State, Req2, HandlerState2,
                  RemainingData, Payload, NextState);
          {reply, Payload, Req2, HandlerState2, hibernate} ->
              handler_call_reply(State#state{hibernate=true}, Req2, HandlerState2,
                  RemainingData, Payload, NextState);
          {stop, Req2, HandlerState2} ->
              websocket_close(State, Req2, HandlerState2, stop)
      catch Class:Reason ->
          Stacktrace = erlang:get_stacktrace(),
          handler_call_crash(State, Req, HandlerState,
              Callback, Message, Class, Reason, Stacktrace)
      end.

  %% Sends a single frame or a list of frames returned by the handler, then
  %% continues with NextState, or terminates the handler when a close frame
  %% was sent (stop) or the transport reported an error.
  handler_call_reply(State, Req, HandlerState, RemainingData, Payload, NextState) ->
      Result = case is_list(Payload) of
          true -> websocket_send_many(Payload, State);
          false -> websocket_send(Payload, State)
      end,
      case Result of
          ok ->
              NextState(State, Req, HandlerState, RemainingData);
          stop ->
              handler_terminate(State, Req, HandlerState, stop);
          Error = {error, _} ->
              handler_terminate(State, Req, HandlerState, Error)
      end.

  %% Reports a crashed handler callback: notifies the client with close
  %% code 1011, runs the handler's terminate logic and exits the process
  %% with the crash details.
  -spec handler_call_crash(#state{}, cowboy_req:req(), any(), atom(), any(),
      error | exit | throw, any(), any()) -> no_return().
  handler_call_crash(#state{socket=Socket, transport=Transport, handler=Handler,
          extensions=Extensions}, Req, HandlerState,
          Callback, Message, Class, Reason, Stacktrace) ->
      Transport:send(Socket, cow_ws:frame({close, 1011, <<>>}, Extensions)),
      cowboy_handler:terminate({crash, Class, Reason}, Req, HandlerState, Handler),
      exit({cowboy_handler, [
          {class, Class},
          {reason, Reason},
          {mfa, {Handler, Callback, 3}},
          {stacktrace, Stacktrace},
          {msg, Message},
          {req, Req},
          {state, HandlerState}
      ]}).
  %% Encodes and sends one frame. Sending any form of close frame yields
  %% `stop' regardless of the transport result; for every other frame the
  %% transport's own result is returned.
  -spec websocket_send(cow_ws:frame(), #state{}) -> ok | stop | {error, atom()}.
  websocket_send(Frame, #state{socket=Socket, transport=Transport,
          extensions=Extensions}) ->
      SendResult = Transport:send(Socket, cow_ws:frame(Frame, Extensions)),
      websocket_send_result(Frame, SendResult).

  %% Maps the frame that was just sent and the transport result to the
  %% value returned by websocket_send/2.
  websocket_send_result(close, _) -> stop;
  websocket_send_result({close, _}, _) -> stop;
  websocket_send_result({close, _, _}, _) -> stop;
  websocket_send_result(_, SendResult) -> SendResult.
  %% Sends frames in order, stopping at the first one whose send does not
  %% return `ok' (a close frame or a transport error) and returning that
  %% result. Returns `ok' when every frame was sent.
  -spec websocket_send_many([cow_ws:frame()], #state{}) -> ok | stop | {error, atom()}.
  websocket_send_many([Frame|Rest], State) ->
      case websocket_send(Frame, State) of
          ok -> websocket_send_many(Rest, State);
          NotOk -> NotOk
      end;
  websocket_send_many([], _State) ->
      ok.
  %% Sends the close frame matching Reason, then terminates the handler.
  %% Close codes used: 1000 for stop/timeout, 1002 for a bad frame, 1007 for
  %% bad encoding, 1011 for a handler crash. A remote close is answered with
  %% a bare close frame, or with the remote's own code when it sent one.
  -spec websocket_close(#state{}, Req, any(), terminate_reason())
      -> {ok, Req, cowboy_middleware:env()}
      when Req::cowboy_req:req().
  websocket_close(State=#state{socket=Socket, transport=Transport,
          extensions=Extensions}, Req, HandlerState, Reason) ->
      CloseFrame = case Reason of
          stop -> {close, 1000, <<>>};
          timeout -> {close, 1000, <<>>};
          {error, badframe} -> {close, 1002, <<>>};
          {error, badencoding} -> {close, 1007, <<>>};
          {crash, _, _} -> {close, 1011, <<>>};
          remote -> close;
          {remote, Code, _} -> {close, Code, <<>>}
      end,
      Transport:send(Socket, cow_ws:frame(CloseFrame, Extensions)),
      handler_terminate(State, Req, HandlerState, Reason).
  %% Runs the handler's termination through cowboy_handler:terminate/4 and
  %% then ends the connection process with reason `normal'. This function
  %% does not return to its caller.
  -spec handler_terminate(#state{}, Req, any(), terminate_reason())
      -> {ok, Req, cowboy_middleware:env()}
      when Req::cowboy_req:req().
  handler_terminate(#state{handler=HandlerMod}, Req, HandlerState, Reason) ->
      cowboy_handler:terminate(Reason, Req, HandlerState, HandlerMod),
      exit(normal).