Browse Source

Merge pull request #147 from seriyps/with-transaction-3

Add with_transaction/3.
Sergey Prokhorov 7 years ago
parent
commit
0c50beae3b
4 changed files with 96 additions and 34 deletions
  1. 44 5
      src/epgsql.erl
  2. 52 1
      test/epgsql_SUITE.erl
  3. 0 14
      test/epgsql_cast.erl
  4. 0 14
      test/epgsql_incremental.erl

+ 44 - 5
src/epgsql.erl

@@ -22,6 +22,7 @@
          update_type_cache/1,
          update_type_cache/2,
          with_transaction/2,
+         with_transaction/3,
          sync_on_error/2,
          standby_status_update/3,
          start_replication/5,
@@ -328,17 +329,55 @@ cancel(C) ->
                                   when
       Reply :: any().
 with_transaction(C, F) ->
-    try {ok, [], []} = squery(C, "BEGIN"),
+    with_transaction(C, F, [{reraise, false}]).
+
+%% @doc Execute callback function with connection in a transaction.
+%% Transaction will be rolled back in case of exception.
+%% Options (proplist or map):
+%% - reraise (true): when set to true, exception will be re-thrown, otherwise
+%%   {rollback, ErrorReason} will be returned
+%% - ensure_comitted (false): even when callback returns without exception,
+%%   check that transaction was comitted by checking CommandComplete status
+%%   of "COMMIT" command. In case when transaction was rolled back, status will be
+%%   "rollback" instead of "commit".
+%% - begin_opts (""): append extra options to "BEGIN" command (see
+%%   https://www.postgresql.org/docs/current/static/sql-begin.html)
+%%   Beware of SQL injections! No escaping is made on begin_opts!
+-spec with_transaction(
+        connection(), fun((connection()) -> Reply), Opts) -> Reply | {rollback, any()} | no_return() when
+      Reply :: any(),
+      Opts :: [{reraise, boolean()} |
+               {ensure_committed, boolean()} |
+               {begin_opts, iodata()}].
+with_transaction(C, F, Opts0) ->
+    Opts = to_proplist(Opts0),
+    Begin = case proplists:get_value(begin_opts, Opts) of
+                undefined -> <<"BEGIN">>;
+                BeginOpts ->
+                    [<<"BEGIN ">> | BeginOpts]
+            end,
+    try
+        {ok, [], []} = squery(C, Begin),
         R = F(C),
-        {ok, [], []} = squery(C, "COMMIT"),
+        {ok, [], []} = squery(C, <<"COMMIT">>),
+        case proplists:get_value(ensure_committed, Opts, false) of
+            true ->
+                {ok, CmdStatus} = get_cmd_status(C),
+                (commit == CmdStatus) orelse error({ensure_committed_failed, CmdStatus});
+            false -> ok
+        end,
         R
     catch
-        _:Why ->
+        Type:Reason ->
             squery(C, "ROLLBACK"),
-            %% TODO hides error stacktrace
-            {rollback, Why}
+            handle_error(Type, Reason, proplists:get_value(reraise, Opts, true))
     end.
 
+handle_error(_, Reason, false) ->
+    {rollback, Reason};
+handle_error(Type, Reason, true) ->
+    erlang:raise(Type, Reason, erlang:get_stacktrace()).
+
 sync_on_error(C, Error = {error, _}) ->
     ok = sync(C),
     Error;

+ 52 - 1
test/epgsql_SUITE.erl

@@ -68,6 +68,9 @@ groups() ->
             range_type,
             range8_type,
             custom_types
+        ]},
+        {generic, [parallel], [
+            with_transaction
         ]}
     ],
 
@@ -125,7 +128,12 @@ groups() ->
         set_notice_receiver,
         get_cmd_status
     ],
-    Groups ++ [{Module, [], Tests} || Module <- modules()].
+    Groups ++ [case Module of
+                   epgsql ->
+                       {Module, [], [{group, generic} | Tests]};
+                   _ ->
+                       {Module, [], Tests}
+               end || Module <- modules()].
 
 end_per_suite(_Config) ->
     ok.
@@ -1043,6 +1051,49 @@ range8_type(Config) ->
         ])
     end, []).
 
+
+with_transaction(Config) ->
+    Module = ?config(module, Config),
+    epgsql_ct:with_connection(
+      Config,
+      fun(C) ->
+              %% Success case
+              ?assertEqual(
+                 success, Module:with_transaction(C, fun(_) -> success end)),
+              ?assertEqual(
+                 success, Module:with_transaction(C, fun(_) -> success end,
+                                                  [{ensure_committed, true}])),
+              %% begin_opts
+              ?assertMatch(
+                 [{ok, _, [{<<"serializable">>}]},
+                  {ok, _, [{<<"on">>}]}],
+                 Module:with_transaction(
+                   C, fun(C1) ->
+                              Module:squery(C1, ("SHOW transaction_isolation; "
+                                                 "SHOW transaction_read_only"))
+                      end,
+                   [{begin_opts, "READ ONLY ISOLATION LEVEL SERIALIZABLE"}])),
+              %% ensure_committed failure
+              ?assertError(
+                 {ensure_committed_failed, rollback},
+                 Module:with_transaction(
+                   C, fun(C1) ->
+                              {error, _} = Module:squery(C1, "SELECT col FROM _nowhere_"),
+                              ok
+                      end,
+                   [{ensure_committed, true}])),
+              %% reraise
+              ?assertEqual(
+                 {rollback, my_err},
+                 Module:with_transaction(
+                   C, fun(_) -> error(my_err) end,
+                   [{reraise, false}])),
+              ?assertError(
+                 my_err,
+                 Module:with_transaction(
+                   C, fun(_) -> error(my_err) end, []))
+      end, []).
+
 %% =============================================================================
 %% Internal functions
 %% ============================================================================

+ 0 - 14
test/epgsql_cast.erl

@@ -11,7 +11,6 @@
 -export([parse/2, parse/3, parse/4, describe/2, describe/3]).
 -export([bind/3, bind/4, execute/2, execute/3, execute/4, execute_batch/2]).
 -export([close/2, close/3, sync/1]).
--export([with_transaction/2]).
 -export([receive_result/2, sync_on_error/2]).
 
 -include("epgsql.hrl").
@@ -143,19 +142,6 @@ sync(C) ->
     Ref = epgsqla:sync(C),
     receive_result(C, Ref).
 
-%% misc helper functions
-with_transaction(C, F) ->
-    try {ok, [], []} = squery(C, "BEGIN"),
-        R = F(C),
-        {ok, [], []} = squery(C, "COMMIT"),
-        R
-    catch
-        _:Why ->
-            squery(C, "ROLLBACK"),
-            %% TODO hides error stacktrace
-            {rollback, Why}
-    end.
-
 receive_result(C, Ref) ->
     %% TODO timeout
     receive

+ 0 - 14
test/epgsql_incremental.erl

@@ -11,7 +11,6 @@
 -export([parse/2, parse/3, parse/4, describe/2, describe/3]).
 -export([bind/3, bind/4, execute/2, execute/3, execute/4, execute_batch/2]).
 -export([close/2, close/3, sync/1]).
--export([with_transaction/2]).
 
 -include("epgsql.hrl").
 
@@ -147,19 +146,6 @@ sync(C) ->
     Ref = epgsqli:sync(C),
     receive_atom(C, Ref, ok, ok).
 
-%% misc helper functions
-with_transaction(C, F) ->
-    try {ok, [], []} = squery(C, "BEGIN"),
-        R = F(C),
-        {ok, [], []} = squery(C, "COMMIT"),
-        R
-    catch
-        _:Why ->
-            squery(C, "ROLLBACK"),
-            %% TODO hides error stacktrace
-            {rollback, Why}
-    end.
-
 %% -- internal functions --
 
 receive_result(C, Ref, Result) ->