Browse Source

Tests and some of the basic functionality

Seth Falcon 14 years ago
parent
commit
6b9b942488
2 changed files with 131 additions and 17 deletions
  1. 42 15
      src/pidq.erl
  2. 89 2
      test/pidq_test.erl

+ 42 - 15
src/pidq.erl

@@ -29,7 +29,6 @@
 
 
 -export([start/1,
 -export([start/1,
          stop/0,
          stop/0,
-         stop/1,
          take_pid/0,
          take_pid/0,
          return_pid/2,
          return_pid/2,
          remove_pool/2,
          remove_pool/2,
@@ -52,14 +51,12 @@ start(Config) ->
 stop() ->
 stop() ->
     gen_server:call(?SERVER, stop).
     gen_server:call(?SERVER, stop).
 
 
-stop(_How) ->
-    stop().
-
 take_pid() ->
 take_pid() ->
     gen_server:call(?SERVER, take_pid).
     gen_server:call(?SERVER, take_pid).
 
 
 return_pid(Pid, Status) when Status == ok; Status == fail ->
 return_pid(Pid, Status) when Status == ok; Status == fail ->
-    gen_server:call(?SERVER, {return_pid, Pid, Status}).
+    gen_server:cast(?SERVER, {return_pid, Pid, Status}),
+    ok.
 
 
 remove_pool(Name, How) when How == graceful; How == immediate ->
 remove_pool(Name, How) when How == graceful; How == immediate ->
     gen_server:call(?SERVER, {remove_pool, Name, How}).
     gen_server:call(?SERVER, {remove_pool, Name, How}).
@@ -82,6 +79,7 @@ init(Config) ->
                                      {?MODULE, default_stopper}),
                                      {?MODULE, default_stopper}),
                    npools = length(Pools),
                    npools = length(Pools),
                    pools = dict:from_list(Pools)},
                    pools = dict:from_list(Pools)},
+    process_flag(trap_exit, true),
     {ok, State}.
     {ok, State}.
 
 
 handle_call(take_pid, {CPid, _Tag}, State) ->
 handle_call(take_pid, {CPid, _Tag}, State) ->
@@ -90,16 +88,25 @@ handle_call(take_pid, {CPid, _Tag}, State) ->
     {NewPid, NewState} = take_pid(PoolName, CPid, State),
     {NewPid, NewState} = take_pid(PoolName, CPid, State),
     {reply, NewPid, NewState};
     {reply, NewPid, NewState};
 handle_call(stop, _From, State) ->
 handle_call(stop, _From, State) ->
+    % FIXME:
+    % loop over in use and free pids and stop them?
+    % {M, F} = State#state.pid_stopper,
     {stop, normal, stop_ok, State};
     {stop, normal, stop_ok, State};
 handle_call(_Request, _From, State) ->
 handle_call(_Request, _From, State) ->
     {noreply, ok, State}.
     {noreply, ok, State}.
 
 
 
 
-handle_cast({return_pid, Pid, _Status}, State) ->
-    {noreply, do_return_pid(Pid, State)};
+handle_cast({return_pid, Pid, Status}, State) ->
+    {noreply, do_return_pid(Pid, Status, State)};
 handle_cast(_Msg, State) ->
 handle_cast(_Msg, State) ->
     {noreply, State}.
     {noreply, State}.
 
 
+handle_info({'EXIT', Pid, _Reason}, State) ->
+    State1 = case dict:find(Pid, State#state.in_use_pids) of
+                 {ok, _PName} -> do_return_pid(Pid, fail, State);
+                 error -> State
+             end,
+    {noreply, State1};
 handle_info(_Info, State) ->
 handle_info(_Info, State) ->
     {noreply, State}.
     {noreply, State}.
 
 
@@ -123,13 +130,15 @@ props_to_pool(P) ->
     Values = [ ?gv(Field, P2) || Field <- record_info(fields, pool) ],
     Values = [ ?gv(Field, P2) || Field <- record_info(fields, pool) ],
     list_to_tuple([pool|Values]).
     list_to_tuple([pool|Values]).
 
 
+add_pids(error, _N, State) ->
+    {bad_pool_name, State};
 add_pids(PoolName, N, State) ->
 add_pids(PoolName, N, State) ->
     #state{pools = Pools, pid_starter = {M, F}} = State,
     #state{pools = Pools, pid_starter = {M, F}} = State,
     Pool = dict:fetch(PoolName, Pools),
     Pool = dict:fetch(PoolName, Pools),
     #pool{max_pids = Max, free_pids = Free, in_use_count = NumInUse,
     #pool{max_pids = Max, free_pids = Free, in_use_count = NumInUse,
           pid_starter_args = Args} = Pool,
           pid_starter_args = Args} = Pool,
     Total = length(Free) + NumInUse,
     Total = length(Free) + NumInUse,
-    case Total + N < Max of
+    case Total + N =< Max of
         true ->
         true ->
             % FIXME: we'll want to link to these pids so we'll know if
             % FIXME: we'll want to link to these pids so we'll know if
             % they crash. Or should the starter function be expected
             % they crash. Or should the starter function be expected
@@ -149,8 +158,12 @@ take_pid(PoolName, From, State) ->
         [] when NumInUse == Max ->
         [] when NumInUse == Max ->
             {error_no_pids, State};
             {error_no_pids, State};
         [] when NumInUse < Max ->
         [] when NumInUse < Max ->
-            {_Status, State1} = add_pids(PoolName, 1, State),
-            take_pid(PoolName, From, State1);
+            case add_pids(PoolName, 1, State) of
+                {ok, State1} ->
+                    take_pid(PoolName, From, State1);
+                {max_pids_reached, _} ->
+                    {error_no_pids, State}
+            end;
         [Pid|Rest] ->
         [Pid|Rest] ->
             % FIXME: handle min_free here -- should adding pids
             % FIXME: handle min_free here -- should adding pids
             % to satisfy min_free be done in a spawned worker?
             % to satisfy min_free be done in a spawned worker?
@@ -161,16 +174,30 @@ take_pid(PoolName, From, State) ->
                               consumer_to_pid = CPMap1}}
                               consumer_to_pid = CPMap1}}
     end.
     end.
 
 
-do_return_pid(Pid, State) ->
+do_return_pid(Pid, Status, State) ->
     #state{in_use_pids = InUse, pools = Pools} = State,
     #state{in_use_pids = InUse, pools = Pools} = State,
     case dict:find(Pid, InUse) of
     case dict:find(Pid, InUse) of
         {ok, PoolName} ->
         {ok, PoolName} ->
             Pool = dict:fetch(PoolName, Pools),
             Pool = dict:fetch(PoolName, Pools),
-            #pool{free_pids = Free, in_use_count = NumInUse} = Pool,
-            Pool1 = Pool#pool{free_pids = [Pid|Free], in_use_count = NumInUse - 1},
-            State#state{in_use_pids = dict:erase(Pid, InUse),
-                        pools = dict:store(PoolName, Pool1, Pools)};
+            {Pool1, State1} =
+                case Status of
+                    ok -> {add_pid_to_free(Pid, Pool), State};
+                    fail -> handle_failed_pid(Pid, PoolName, Pool, State)
+                    end,
+            State1#state{in_use_pids = dict:erase(Pid, InUse),
+                         pools = dict:store(PoolName, Pool1, Pools)};
         error ->
         error ->
             error_logger:warning_report({return_pid_not_found, Pid}),
             error_logger:warning_report({return_pid_not_found, Pid}),
             State
             State
     end.
     end.
+
+add_pid_to_free(Pid, Pool) ->
+    #pool{free_pids = Free, in_use_count = NumInUse} = Pool,
+    Pool#pool{free_pids = [Pid|Free], in_use_count = NumInUse - 1}.
+
+handle_failed_pid(Pid, PoolName, Pool, State) ->
+    {M, F} = State#state.pid_stopper,
+    M:F(Pid),
+    {_, NewState} = add_pids(PoolName, 1, State),
+    NumInUse = Pool#pool.in_use_count,
+    {Pool#pool{in_use_count = NumInUse - 1}, NewState}.

+ 89 - 2
test/pidq_test.erl

@@ -72,8 +72,11 @@ stop_tc(Pid) ->
 
 
 tc_starter(Type) ->
 tc_starter(Type) ->
     Ref = make_ref(),
     Ref = make_ref(),
-    spawn(fun() -> tc_loop({Type, Ref}) end).
+    spawn_link(fun() -> tc_loop({Type, Ref}) end).
 
 
+assert_tc_valid(Pid) ->
+    ?assertMatch({_Type, _Ref}, get_tc_id(Pid)),
+    ok.
 
 
 tc_sanity_test() ->
 tc_sanity_test() ->
     Pid1 = tc_starter("1"),
     Pid1 = tc_starter("1"),
@@ -91,6 +94,78 @@ user_sanity_test() ->
     user_crash(User),
     user_crash(User),
     stop_tc(Pid1).
     stop_tc(Pid1).
 
 
+pidq_basics_test_() ->
+    {foreach,
+     % setup
+     fun() ->
+             Pools = [[{name, "p1"},
+                       {max_pids, 3}, {min_free, 1},
+                       {init_size, 2}, {pid_starter_args, ["type-0"]}]],
+
+             Config = [{pid_starter, {?MODULE, tc_starter}},
+                       {pid_stopper, {?MODULE, stop_tc}},
+                       {pools, Pools}],
+             pidq:start(Config)
+     end,
+     fun(_X) ->
+             pidq:stop()
+     end,
+     [
+      {"take and return one",
+       fun() ->
+               P = pidq:take_pid(),
+               ?assertMatch({"type-0", _Id}, get_tc_id(P)),
+               ok = pidq:return_pid(P, ok)
+       end},
+
+      {"pids are created on demand until max",
+       fun() ->
+               Pids = [pidq:take_pid(), pidq:take_pid(), pidq:take_pid()],
+               ?assertMatch(error_no_pids, pidq:take_pid()),
+               ?assertMatch(error_no_pids, pidq:take_pid()),
+               PRefs = [ R || {_T, R} <- [ get_tc_id(P) || P <- Pids ] ],
+               ?assertEqual(3, length(lists:usort(PRefs)))
+       end
+      },
+
+      {"pids are reused most recent return first",
+       fun() ->
+               P1 = pidq:take_pid(),
+               P2 = pidq:take_pid(),
+               ?assertNot(P1 == P2),
+               ok =  pidq:return_pid(P1, ok),
+               ok = pidq:return_pid(P2, ok),
+               % pids are reused most recent first
+               ?assertEqual(P2, pidq:take_pid()),
+               ?assertEqual(P1, pidq:take_pid())
+       end},
+
+      {"if a pid crashes it is replaced",
+       fun() ->
+               Pids0 = [pidq:take_pid(), pidq:take_pid(), pidq:take_pid()],
+               Ids0 = [ get_tc_id(P) || P <- Pids0 ],
+               % crash them all
+               [ P ! crash || P <- Pids0 ],
+               Pids1 = get_n_pids(3, []),
+               Ids1 = [ get_tc_id(P) || P <- Pids1 ],
+               [ ?assertNot(lists:member(I, Ids0)) || I <- Ids1 ]
+       end
+       }
+
+      % {"if a pid is returned with bad status it is replaced",
+      %  fun() ->
+      %          P1 = pidq:take_pid(),
+      %          P2 = pidq:take_pid(),
+      %          pidq:return_pid(P2, ok),
+      %          pidq:return_pid(P1, fail),
+      %          PN = pidq:take_pid(),
+      %          ?assertEqual(P2, pidq:take_pid()),
+      %          ?assertNot(PN == P1)
+      %  end
+      %  }
+      ]}.
+
+
 pidq_integration_test_() ->
 pidq_integration_test_() ->
     {foreach,
     {foreach,
      % setup
      % setup
@@ -152,4 +227,16 @@ pidq_integration_test_() ->
       %          TcIds3 = lists:sort([ user_id(UPid) || UPid <- Users ]),
       %          TcIds3 = lists:sort([ user_id(UPid) || UPid <- Users ]),
       %          ?assertEqual(lists:usort(TcIds3), TcIds3)
       %          ?assertEqual(lists:usort(TcIds3), TcIds3)
 
 
-             
+
+% testing crash recovery means race conditions when either pids
+% haven't yet crashed or pidq hasn't recovered.  So this helper loops
+% forver until N pids are obtained, ignoring error_no_pids.
+get_n_pids(0, Acc) ->
+    Acc;
+get_n_pids(N, Acc) ->
+    case pidq:take_pid() of
+        error_no_pids ->
+            get_n_pids(N, Acc);
+        Pid ->
+            get_n_pids(N - 1, [Pid|Acc])
+    end.