Browse Source

Handle supervisor calls properly everywhere

Loïc Hoguin 7 years ago
parent
commit
b9c8d86502
5 changed files with 35 additions and 47 deletions
  1. 19 2
      src/cowboy_children.erl
  2. 2 8
      src/cowboy_http.erl
  3. 2 8
      src/cowboy_http2.erl
  4. 4 0
      src/cowboy_websocket.erl
  5. 8 29
      test/sys_SUITE.erl

+ 19 - 2
src/cowboy_children.erl

@@ -20,8 +20,7 @@
 -export([shutdown/2]).
 -export([shutdown/2]).
 -export([shutdown_timeout/3]).
 -export([shutdown_timeout/3]).
 -export([terminate/1]).
 -export([terminate/1]).
--export([which_children/2]).
--export([count_children/1]).
+-export([handle_supervisor_call/4]).
 
 
 -record(child, {
 -record(child, {
 	pid :: pid(),
 	pid :: pid(),
@@ -160,6 +159,24 @@ longest_shutdown_time([#child{shutdown=ChildTime}|Tail], Time) when ChildTime >
 longest_shutdown_time([_|Tail], Time) ->
 longest_shutdown_time([_|Tail], Time) ->
 	longest_shutdown_time(Tail, Time).
 	longest_shutdown_time(Tail, Time).
 
 
+-spec handle_supervisor_call(any(), {pid(), any()}, children(), module()) -> ok.
+handle_supervisor_call(which_children, {From, Tag}, Children, Module) ->
+	From ! {Tag, which_children(Children, Module)},
+	ok;
+handle_supervisor_call(count_children, {From, Tag}, Children, _) ->
+	From ! {Tag, count_children(Children)},
+	ok;
+%% We disable start_child since only incoming requests
+%% end up creating a new process.
+handle_supervisor_call({start_child, _}, {From, Tag}, _, _) ->
+	From ! {Tag, {error, start_child_disabled}},
+	ok;
+%% All other calls refer to children. We act in a similar way
+%% to a simple_one_for_one so we never find those.
+handle_supervisor_call(_, {From, Tag}, _, _) ->
+	From ! {Tag, {error, not_found}},
+	ok.
+
 -spec which_children(children(), module()) -> [{module(), pid(), worker, [module()]}].
 -spec which_children(children(), module()) -> [{module(), pid(), worker, [module()]}].
 which_children(Children, Module) ->
 which_children(Children, Module) ->
 	[{Module, Pid, worker, [Module]} || #child{pid=Pid} <- Children].
 	[{Module, Pid, worker, [Module]} || #child{pid=Pid} <- Children].

+ 2 - 8
src/cowboy_http.erl

@@ -202,14 +202,8 @@ loop(State=#state{parent=Parent, socket=Socket, transport=Transport, opts=Opts,
 		Msg = {'EXIT', Pid, _} ->
 		Msg = {'EXIT', Pid, _} ->
 			loop(down(State, Pid, Msg), Buffer);
 			loop(down(State, Pid, Msg), Buffer);
 		%% Calls from supervisor module.
 		%% Calls from supervisor module.
-		{'$gen_call', {From, Tag}, which_children} ->
-			From ! {Tag, cowboy_children:which_children(Children, ?MODULE)},
-			loop(State, Buffer);
-		{'$gen_call', {From, Tag}, count_children} ->
-			From ! {Tag, cowboy_children:count_children(Children)},
-			loop(State, Buffer);
-		{'$gen_call', {From, Tag}, _} ->
-			From ! {Tag, {error, ?MODULE}},
+		{'$gen_call', From, Call} ->
+			cowboy_children:handle_supervisor_call(Call, From, Children, ?MODULE),
 			loop(State, Buffer);
 			loop(State, Buffer);
 		%% Unknown messages.
 		%% Unknown messages.
 		Msg ->
 		Msg ->

+ 2 - 8
src/cowboy_http2.erl

@@ -247,14 +247,8 @@ loop(State=#state{parent=Parent, socket=Socket, transport=Transport,
 		Msg = {'EXIT', Pid, _} ->
 		Msg = {'EXIT', Pid, _} ->
 			loop(down(State, Pid, Msg), Buffer);
 			loop(down(State, Pid, Msg), Buffer);
 		%% Calls from supervisor module.
 		%% Calls from supervisor module.
-		{'$gen_call', {From, Tag}, which_children} ->
-			From ! {Tag, cowboy_children:which_children(Children, ?MODULE)},
-			loop(State, Buffer);
-		{'$gen_call', {From, Tag}, count_children} ->
-			From ! {Tag, cowboy_children:count_children(Children)},
-			loop(State, Buffer);
-		{'$gen_call', {From, Tag}, _} ->
-			From ! {Tag, {error, ?MODULE}},
+		{'$gen_call', From, Call} ->
+			cowboy_children:handle_supervisor_call(Call, From, Children, ?MODULE),
 			loop(State, Buffer);
 			loop(State, Buffer);
 		Msg ->
 		Msg ->
 			error_logger:error_msg("Received stray message ~p.", [Msg]),
 			error_logger:error_msg("Received stray message ~p.", [Msg]),

+ 4 - 0
src/cowboy_websocket.erl

@@ -250,6 +250,10 @@ handler_loop(State=#state{socket=Socket, messages={OK, Closed, Error},
 			websocket_close(State, HandlerState, timeout);
 			websocket_close(State, HandlerState, timeout);
 		{timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) ->
 		{timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) ->
 			handler_loop(State, HandlerState, SoFar);
 			handler_loop(State, HandlerState, SoFar);
+		%% Calls from supervisor module.
+		{'$gen_call', From, Call} ->
+			cowboy_children:handle_supervisor_call(Call, From, [], ?MODULE),
+			handler_loop(State, HandlerState, SoFar);
 		Message ->
 		Message ->
 			handler_call(State, HandlerState,
 			handler_call(State, HandlerState,
 				SoFar, websocket_info, Message, fun handler_before_loop/3)
 				SoFar, websocket_info, Message, fun handler_before_loop/3)

+ 8 - 29
test/sys_SUITE.erl

@@ -670,10 +670,7 @@ supervisor_count_children_ws(Config) ->
 	{ok, {http_response, {1, 1}, 101, _}, _} = erlang:decode_packet(http, Handshake, []),
 	{ok, {http_response, {1, 1}, 101, _}, _} = erlang:decode_packet(http, Handshake, []),
 	timer:sleep(100),
 	timer:sleep(100),
 	Pid = do_get_remote_pid_tcp(Socket),
 	Pid = do_get_remote_pid_tcp(Socket),
-	%% We use gen_server:call directly because the supervisor:count_children
-	%% function has a timeout of infinity.
-	%% @todo This can be changed to supervisor:count_children/1 once it is fixed.
-	Counts = gen_server:call(Pid, count_children, 1000),
+	Counts = supervisor:count_children(Pid),
 	1 = proplists:get_value(specs, Counts),
 	1 = proplists:get_value(specs, Counts),
 	0 = proplists:get_value(active, Counts),
 	0 = proplists:get_value(active, Counts),
 	0 = proplists:get_value(supervisors, Counts),
 	0 = proplists:get_value(supervisors, Counts),
@@ -741,10 +738,7 @@ supervisor_delete_child_not_found_ws(Config) ->
 	{ok, {http_response, {1, 1}, 101, _}, _} = erlang:decode_packet(http, Handshake, []),
 	{ok, {http_response, {1, 1}, 101, _}, _} = erlang:decode_packet(http, Handshake, []),
 	timer:sleep(100),
 	timer:sleep(100),
 	Pid = do_get_remote_pid_tcp(Socket),
 	Pid = do_get_remote_pid_tcp(Socket),
-	%% We use gen_server:call directly because the supervisor:delete_child
-	%% function has a timeout of infinity.
-	%% @todo This can be changed to supervisor:delete_child/2 once it is fixed.
-	{error, not_found} = gen_server:call(Pid, {delete_child, cowboy_websocket}, 1000),
+	{error, not_found} = supervisor:delete_child(Pid, cowboy_websocket),
 	ok.
 	ok.
 
 
 %% supervisor:get_childspec/2.
 %% supervisor:get_childspec/2.
@@ -808,10 +802,7 @@ supervisor_get_childspec_not_found_ws(Config) ->
 	{ok, {http_response, {1, 1}, 101, _}, _} = erlang:decode_packet(http, Handshake, []),
 	{ok, {http_response, {1, 1}, 101, _}, _} = erlang:decode_packet(http, Handshake, []),
 	timer:sleep(100),
 	timer:sleep(100),
 	Pid = do_get_remote_pid_tcp(Socket),
 	Pid = do_get_remote_pid_tcp(Socket),
-	%% We use gen_server:call directly because the supervisor:get_childspec
-	%% function has a timeout of infinity.
-	%% @todo This can be changed to supervisor:get_childspec/2 once it is fixed.
-	{error, not_found} = gen_server:call(Pid, {get_childspec, cowboy_websocket}, 1000),
+	{error, not_found} = supervisor:get_childspec(Pid, cowboy_websocket),
 	ok.
 	ok.
 
 
 %% supervisor:restart_child/2.
 %% supervisor:restart_child/2.
@@ -875,10 +866,7 @@ supervisor_restart_child_not_found_ws(Config) ->
 	{ok, {http_response, {1, 1}, 101, _}, _} = erlang:decode_packet(http, Handshake, []),
 	{ok, {http_response, {1, 1}, 101, _}, _} = erlang:decode_packet(http, Handshake, []),
 	timer:sleep(100),
 	timer:sleep(100),
 	Pid = do_get_remote_pid_tcp(Socket),
 	Pid = do_get_remote_pid_tcp(Socket),
-	%% We use gen_server:call directly because the supervisor:restart_child
-	%% function has a timeout of infinity.
-	%% @todo This can be changed to supervisor:restart_child/2 once it is fixed.
-	{error, not_found} = gen_server:call(Pid, {restart_child, cowboy_websocket}, 1000),
+	{error, not_found} = supervisor:restart_child(Pid, cowboy_websocket),
 	ok.
 	ok.
 
 
 %% supervisor:start_child/2 must return {error, start_child_disabled}
 %% supervisor:start_child/2 must return {error, start_child_disabled}
@@ -929,13 +917,10 @@ supervisor_start_child_not_found_ws(Config) ->
 	{ok, {http_response, {1, 1}, 101, _}, _} = erlang:decode_packet(http, Handshake, []),
 	{ok, {http_response, {1, 1}, 101, _}, _} = erlang:decode_packet(http, Handshake, []),
 	timer:sleep(100),
 	timer:sleep(100),
 	Pid = do_get_remote_pid_tcp(Socket),
 	Pid = do_get_remote_pid_tcp(Socket),
-	%% We use gen_server:call directly because the supervisor:start_child
-	%% function has a timeout of infinity.
-	%% @todo This can be changed to supervisor:start_child/2 once it is fixed.
-	{error, start_child_disabled} = gen_server:call(Pid, {start_child, #{
+	{error, start_child_disabled} = supervisor:start_child(Pid, #{
 		id => error,
 		id => error,
 		start => {error, error, []}
 		start => {error, error, []}
-	}}, 1000),
+	}),
 	ok.
 	ok.
 
 
 %% supervisor:terminate_child/2.
 %% supervisor:terminate_child/2.
@@ -999,10 +984,7 @@ supervisor_terminate_child_not_found_ws(Config) ->
 	{ok, {http_response, {1, 1}, 101, _}, _} = erlang:decode_packet(http, Handshake, []),
 	{ok, {http_response, {1, 1}, 101, _}, _} = erlang:decode_packet(http, Handshake, []),
 	timer:sleep(100),
 	timer:sleep(100),
 	Pid = do_get_remote_pid_tcp(Socket),
 	Pid = do_get_remote_pid_tcp(Socket),
-	%% We use gen_server:call directly because the supervisor:terminate_child
-	%% function has a timeout of infinity.
-	%% @todo This can be changed to supervisor:terminate_child/2 once it is fixed.
-	{error, not_found} = gen_server:call(Pid, {terminate_child, cowboy_websocket}, 1000),
+	{error, not_found} = supervisor:terminate_child(Pid, cowboy_websocket),
 	ok.
 	ok.
 
 
 %% supervisor:which_children/1.
 %% supervisor:which_children/1.
@@ -1072,8 +1054,5 @@ supervisor_which_children_ws(Config) ->
 	{ok, {http_response, {1, 1}, 101, _}, _} = erlang:decode_packet(http, Handshake, []),
 	{ok, {http_response, {1, 1}, 101, _}, _} = erlang:decode_packet(http, Handshake, []),
 	timer:sleep(100),
 	timer:sleep(100),
 	Pid = do_get_remote_pid_tcp(Socket),
 	Pid = do_get_remote_pid_tcp(Socket),
-	%% We use gen_server:call directly because the supervisor:which_children
-	%% function has a timeout of infinity.
-	%% @todo This can be changed to supervisor:which_children/1 once it is fixed.
-	[] = gen_server:call(Pid, which_children, 1000),
+	[] = supervisor:which_children(Pid),
 	ok.
 	ok.