Browse Source

Initial commit

Alexander Verbitsky 10 years ago
commit
b27afcc821
6 changed files with 261 additions and 0 deletions
  1. 4 0
      README.md
  2. 12 0
      include/epgsql_pool.hrl
  3. 12 0
      src/epgsql_pool.app.src
  4. 44 0
      src/epgsql_pool.erl
  5. 53 0
      src/epgsql_pool_app.erl
  6. 136 0
      src/epgsql_pool_worker.erl

+ 4 - 0
README.md

@@ -0,0 +1,4 @@
+# Connection pool for PostgreSQL
+
+Connection pool for PostgreSQL based on [epgsql](https://github.com/epgsql/epgsql)
+and [pooler](https://github.com/seth/pooler)

+ 12 - 0
include/epgsql_pool.hrl

@@ -0,0 +1,12 @@
+
+-record(epgsql_params, {
+    host                :: string(),
+    port                :: non_neg_integer(),
+    username            :: string(),
+    password            :: string(),
+    database            :: string(),
+    connection_timeout  :: non_neg_integer(),
+    query_timeout       :: non_neg_integer()
+}).
+
+-define(POOL_NAME, epgsql_pool).

+ 12 - 0
src/epgsql_pool.app.src

@@ -0,0 +1,12 @@
+{application, epgsql_pool, [
+    {description, "Connection pool for PostgreSQL"},
+    {vsn, "1"},
+    {registered, []},
+    {applications, [
+        wgconfig,
+        pooler,
+        epgsql
+    ]},
+    {mod, { epgsql_pool_app, []}},
+    {env, []}
+]}.

+ 44 - 0
src/epgsql_pool.erl

@@ -0,0 +1,44 @@
+-module(epgsql_pool).
+
+-export([
+    transaction/1,
+    equery/2
+]).
+
+-include("epgsql_pool.hrl").
+
+-define(TIMEOUT, 1000).
+
+equery(Stmt, Params) ->
+    % Example
+    % epgsql_pool:equery("SELECT NOW() as now", []).
+    transaction(
+        fun(Worker) ->
+            equery({worker, Worker}, Stmt, Params)
+        end).
+
+equery({worker, Worker}, Stmt, Params) ->
+    gen_server:call(Worker, {equery, Stmt, Params}, infinity).
+
+transaction(Fun) ->
+    % TODO: logging with time of execution
+    case pooler:take_member(?POOL_NAME, ?TIMEOUT) of
+        Worker when is_pid(Worker) ->
+            W = {worker, Worker},
+            try
+                equery(W, "BEGIN", []),
+                Result = Fun(Worker),
+                equery(W, "COMMIT", []),
+                Result
+            catch
+                Err:Reason ->
+                    equery(W, "ROLLBACK", []),
+                    erlang:raise(Err, Reason, erlang:get_stacktrace())
+            after
+                pooler:return_member(?POOL_NAME, Worker, ok)
+            end;
+        error_no_members ->
+            PoolStats = pooler:pool_stats(?POOL_NAME),
+            lager:warning("Pool overload: ~p", [PoolStats]),
+            {error, no_members}
+    end.

+ 53 - 0
src/epgsql_pool_app.erl

@@ -0,0 +1,53 @@
+-module(epgsql_pool_app).
+
+-behaviour(application).
+
+%% Application callbacks
+-export([start/2, stop/1]).
+
+%% ===================================================================
+%% Application callbacks
+%% ===================================================================
+
+start(_StartType, _StartArgs) ->
+    register(?MODULE, self()),
+    start_pool(),
+    {ok, self()}.
+
+stop(_State) ->
+    stop_pool(),
+    ok.
+
+-include("epgsql_pool.hrl").
+
+start_pool() ->
+    InitCount = wgconfig:get_int(?POOL_NAME, init_count),
+    MaxCount = wgconfig:get_int(?POOL_NAME, max_count),
+
+    % Connection parameters
+    Hots = wgconfig:get_string(?POOL_NAME, host),
+    Port = wgconfig:get_int(?POOL_NAME, port),
+    Username = wgconfig:get_string(?POOL_NAME, username),
+    Password = wgconfig:get_string(?POOL_NAME, password),
+    Database = wgconfig:get_string(?POOL_NAME, database),
+    ConnectionTimeout = wgconfig:get_int(?POOL_NAME, connection_timeout),
+    QueryTimeout = wgconfig:get_int(?POOL_NAME, query_timeout),
+
+    Params = #epgsql_params{
+        host=Hots, port=Port,
+        username=Username, password=Password,
+        database=Database,
+        connection_timeout=ConnectionTimeout,
+        query_timeout=QueryTimeout
+    },
+
+    PoolConfig = [
+        {name, ?POOL_NAME},
+        {init_count, InitCount},
+        {max_count, MaxCount},
+        {start_mfa, {epgsql_pool_worker, start_link, [Params]}}
+    ],
+    pooler:new_pool(PoolConfig).
+
+stop_pool() ->
+    pooler:rm_pool(?POOL_NAME).

+ 136 - 0
src/epgsql_pool_worker.erl

@@ -0,0 +1,136 @@
+-module(epgsql_pool_worker).
+
+-behaviour(gen_server).
+
+-export([start_link/1]).
+
+-export([
+    init/1,
+    handle_call/3,
+    handle_cast/2,
+    handle_info/2,
+    terminate/2,
+    code_change/3
+]).
+
+-include("epgsql_pool.hrl").
+-include_lib("epgsql/include/epgsql.hrl").
+
+-define(MAX_RECONNECT_TIMEOUT, 1000*30).
+-define(MIN_RECONNECT_TIMEOUT, 200).
+
+-record(state, {
+    connection            :: pid(),
+    params                :: #epgsql_params{},
+    reconnect_attempt = 0 :: non_neg_integer(),
+    reconnect_timeout = 0 :: non_neg_integer()
+}).
+
+start_link(Params) ->
+    gen_server:start_link(?MODULE, Params, []).
+
+init(Params) -> 
+    process_flag(trap_exit, true),
+    random:seed(now()),
+    self() ! open_connection,
+    {ok, #state{params=Params}}.
+
+handle_call({_Message}, _From, #state{connection = undefined} = State) ->
+    {reply, {error, no_connection}, State};
+handle_call({equery, Stmt, Params}, From, State) ->
+    TStart = now(),
+    Result = epgsql:equery(State#state.connection, Stmt, Params),
+    Time = timer:now_diff(now(), TStart),
+    lager:debug(
+        "Stmt=~p, Params=~p, Time=~p ms, Result=~p",
+        [Stmt, Params, Time / 1.0e3, Result]),
+    {reply, Result, State};
+handle_call(Message, From, State) ->
+    lager:info(
+        "Call / Message: ~p, From: ~p, State: ~p", [Message, From, State]),
+    {reply, ok, State}.
+
+handle_cast(Message, State) ->
+    lager:info("Cast / Message: ~p, State: ~p", [Message, State]),
+    {noreply, State}.
+
+handle_info(open_connection, State) ->
+    case open_connection(State) of
+        {ok, UpdState} ->
+            {noreply, UpdState};
+        {error, UpdState} ->
+            {noreply, reconnect(UpdState)}
+    end;
+
+handle_info({'EXIT', Pid, Reason}, #state{connection = C} = State) when Pid == C ->
+    lager:error("Exit with reason: ~p", [Reason]),
+    close_connection(State),
+    {noreply, reconnect(State)};
+
+handle_info(Message, State) ->
+    lager:debug("Info / Msg: ~p, State: ~p", [Message, State]),
+    {noreply, State}.
+
+terminate(Reason, State) ->
+    lager:debug("Terminate / Reason: ~p, State: ~p", [Reason, State]),
+    normal.
+
+code_change(_OldVsn, State, _Extra) ->
+    {ok, State}.
+
+%% -- internal functions --
+
+open_connection(#state{params = Params} = State) ->
+
+    #epgsql_params{
+        host               = Host,
+        port               = Port,
+        username           = Username,
+        password           = Password,
+        database           = Database,
+        connection_timeout = ConnectionTimeout
+    } = Params,
+
+    Res = epgsql:connect(Host, Username, Password, [        
+        {port, Port},
+        {database, Database},
+        {timeout, ConnectionTimeout}
+    ]),
+    case Res of
+        {ok, Sock} ->
+            {ok, State#state{
+                connection=Sock,
+                reconnect_attempt=0}};
+        {error, Reason} ->
+            lager:error("Connect fail: ~p", [Reason]),
+            {error, State}
+    end.
+
+close_connection(State) ->
+    Connection = State#state.connection,
+    epgsql:close(Connection),
+    #state{connection = undefined}.
+
+reconnect(#state{
+        reconnect_attempt = R,
+        reconnect_timeout = T} = State) ->
+    case T > ?MAX_RECONNECT_TIMEOUT of
+        true ->
+            reconnect_after(R, ?MIN_RECONNECT_TIMEOUT, T),
+            State#state{reconnect_attempt = R + 1};
+        _ ->
+            T2 = exponential_backoff(R, ?MIN_RECONNECT_TIMEOUT),
+            reconnect_after(R, ?MIN_RECONNECT_TIMEOUT, T2),
+            State#state{reconnect_attempt=R + 1, reconnect_timeout=T2}
+    end.
+
+reconnect_after(R, Tmin, Tmax) ->
+    Delay = rand_range(Tmin, Tmax),
+    lager:error("Reconnect after ~w ms (attempt ~w)", [Delay, R]),
+    erlang:send_after(Delay, self(), open_connection).
+
+rand_range(Min, Max) ->
+    max(random:uniform(Max), Min).
+
+exponential_backoff(N, T) ->
+    erlang:round(math:pow(2, N)) * T.