Browse Source

return rows when 'returning' clause is used

Will 16 years ago
parent
commit
0a19e07b97
3 changed files with 46 additions and 14 deletions
  1. 11 5
      src/pgsql.erl
  2. 1 1
      src/pgsql_connection.erl
  3. 34 8
      test_src/pgsql_tests.erl

+ 11 - 5
src/pgsql.erl

@@ -24,7 +24,7 @@ connect(Host, Username, Opts) ->
 connect(Host, Username, Password, Opts) ->
 connect(Host, Username, Password, Opts) ->
     {ok, C} = pgsql_connection:start_link(),
     {ok, C} = pgsql_connection:start_link(),
     pgsql_connection:connect(C, Host, Username, Password, Opts).
     pgsql_connection:connect(C, Host, Username, Password, Opts).
-    
+
 close(C) when is_pid(C) ->
 close(C) when is_pid(C) ->
     catch pgsql_connection:stop(C),
     catch pgsql_connection:stop(C),
     ok.
     ok.
@@ -119,13 +119,13 @@ receive_result(C) ->
     receive
     receive
         {pgsql, C, done} -> R
         {pgsql, C, done} -> R
     end.
     end.
-            
+
 receive_results(C, Results) ->
 receive_results(C, Results) ->
     case receive_result(C, [], []) of
     case receive_result(C, [], []) of
         done -> lists:reverse(Results);
         done -> lists:reverse(Results);
         R    -> receive_results(C, [R | Results])
         R    -> receive_results(C, [R | Results])
     end.
     end.
-            
+
 receive_result(C, Cols, Rows) ->
 receive_result(C, Cols, Rows) ->
     receive
     receive
         {pgsql, C, {columns, Cols2}} ->
         {pgsql, C, {columns, Cols2}} ->
@@ -135,7 +135,10 @@ receive_result(C, Cols, Rows) ->
         {pgsql, C, {error, _E} = Error} ->
         {pgsql, C, {error, _E} = Error} ->
             Error;
             Error;
         {pgsql, C, {complete, {_Type, Count}}} ->
         {pgsql, C, {complete, {_Type, Count}}} ->
-            {ok, Count};
+            case Rows of
+                [] -> {ok, Count};
+                _L -> {ok, Count, Cols, lists:reverse(Rows)}
+            end;
         {pgsql, C, {complete, _Type}} ->
         {pgsql, C, {complete, _Type}} ->
             {ok, Cols, lists:reverse(Rows)};
             {ok, Cols, lists:reverse(Rows)};
         {pgsql, C, {notice, _N}} ->
         {pgsql, C, {notice, _N}} ->
@@ -158,7 +161,10 @@ receive_extended_result(C, Rows) ->
         {pgsql, C, suspended} ->
         {pgsql, C, suspended} ->
             {partial, lists:reverse(Rows)};
             {partial, lists:reverse(Rows)};
         {pgsql, C, {complete, {_Type, Count}}} ->
         {pgsql, C, {complete, {_Type, Count}}} ->
-            {ok, Count};
+            case Rows of
+                [] -> {ok, Count};
+                _L -> {ok, Count, lists:reverse(Rows)}
+            end;
         {pgsql, C, {complete, _Type}} ->
         {pgsql, C, {complete, _Type}} ->
             {ok, lists:reverse(Rows)};
             {ok, lists:reverse(Rows)};
         {pgsql, C, {notice, _N}} ->
         {pgsql, C, {notice, _N}} ->

+ 1 - 1
src/pgsql_connection.erl

@@ -506,7 +506,7 @@ decode_complete(Bin) ->
         ["UPDATE", Rows]       -> {update, list_to_integer(Rows)};
         ["UPDATE", Rows]       -> {update, list_to_integer(Rows)};
         ["DELETE", Rows]       -> {delete, list_to_integer(Rows)};
         ["DELETE", Rows]       -> {delete, list_to_integer(Rows)};
         ["MOVE", Rows]         -> {move, list_to_integer(Rows)};
         ["MOVE", Rows]         -> {move, list_to_integer(Rows)};
-        ["FETCH", _Rows]       -> fetch;
+        ["FETCH", Rows]        -> {fetch, list_to_integer(Rows)};
         [Type | _Rest]         -> lower_atom(Type)
         [Type | _Rest]         -> lower_atom(Type)
     end.
     end.
 
 

+ 34 - 8
test_src/pgsql_tests.erl

@@ -48,22 +48,22 @@ insert_test() ->
               {ok, 1} = pgsql:squery(C, "insert into test_table1 (id, value) values (3, 'three')")
               {ok, 1} = pgsql:squery(C, "insert into test_table1 (id, value) values (3, 'three')")
       end).
       end).
 
 
-delete_test() ->
+update_test() ->
     with_rollback(
     with_rollback(
       fun(C) ->
       fun(C) ->
               {ok, 1} = pgsql:squery(C, "insert into test_table1 (id, value) values (3, 'three')"),
               {ok, 1} = pgsql:squery(C, "insert into test_table1 (id, value) values (3, 'three')"),
               {ok, 1} = pgsql:squery(C, "insert into test_table1 (id, value) values (4, 'four')"),
               {ok, 1} = pgsql:squery(C, "insert into test_table1 (id, value) values (4, 'four')"),
-              {ok, 2} = pgsql:squery(C, "delete from test_table1 where id > 2"),
-              {ok, _, [{<<"2">>}]} = pgsql:squery(C, "select count(*) from test_table1")
+              {ok, 2} = pgsql:squery(C, "update test_table1 set value = 'foo' where id > 2"),
+              {ok, _, [{<<"2">>}]} = pgsql:squery(C, "select count(*) from test_table1 where value = 'foo'")
       end).
       end).
 
 
-update_test() ->
+delete_test() ->
     with_rollback(
     with_rollback(
       fun(C) ->
       fun(C) ->
               {ok, 1} = pgsql:squery(C, "insert into test_table1 (id, value) values (3, 'three')"),
               {ok, 1} = pgsql:squery(C, "insert into test_table1 (id, value) values (3, 'three')"),
               {ok, 1} = pgsql:squery(C, "insert into test_table1 (id, value) values (4, 'four')"),
               {ok, 1} = pgsql:squery(C, "insert into test_table1 (id, value) values (4, 'four')"),
-              {ok, 2} = pgsql:squery(C, "update test_table1 set value = 'foo' where id > 2"),
-              {ok, _, [{<<"2">>}]} = pgsql:squery(C, "select count(*) from test_table1 where value = 'foo'")
+              {ok, 2} = pgsql:squery(C, "delete from test_table1 where id > 2"),
+              {ok, _, [{<<"2">>}]} = pgsql:squery(C, "select count(*) from test_table1")
       end).
       end).
 
 
 create_and_drop_table_test() ->
 create_and_drop_table_test() ->
@@ -81,9 +81,9 @@ cursor_test() ->
               {ok, [], []} = pgsql:squery(C, "declare c cursor for select id from test_table1"),
               {ok, [], []} = pgsql:squery(C, "declare c cursor for select id from test_table1"),
               {ok, 2} = pgsql:squery(C, "move forward 2 from c"),
               {ok, 2} = pgsql:squery(C, "move forward 2 from c"),
               {ok, 1} = pgsql:squery(C, "move backward 1 from c"),
               {ok, 1} = pgsql:squery(C, "move backward 1 from c"),
-              {ok, _Cols, [{<<"2">>}]} = pgsql:squery(C, "fetch next from c"),
+              {ok, 1, _Cols, [{<<"2">>}]} = pgsql:squery(C, "fetch next from c"),
               {ok, [], []} = pgsql:squery(C, "close c")
               {ok, [], []} = pgsql:squery(C, "close c")
-              end).
+      end).
 
 
 multiple_result_test() ->
 multiple_result_test() ->
     with_connection(
     with_connection(
@@ -115,6 +115,23 @@ extended_sync_error_test() ->
               {ok, _Cols, [{<<"one">>}]} = pgsql:equery(C, "select value from test_table1 where id = $1", [1])
               {ok, _Cols, [{<<"one">>}]} = pgsql:equery(C, "select value from test_table1 where id = $1", [1])
       end).
       end).
 
 
+returning_from_insert_test() ->
+    with_rollback(
+      fun(C) ->
+              {ok, 1, _Cols, [{3}]} = pgsql:equery(C, "insert into test_table1 (id) values (3) returning id")
+      end).
+
+returning_from_update_test() ->
+    with_rollback(
+      fun(C) ->
+              {ok, 2, _Cols, [{1}, {2}]} = pgsql:equery(C, "update test_table1 set value = 'hi' returning id")
+      end).
+
+returning_from_delete_test() ->
+    with_rollback(
+      fun(C) ->
+              {ok, 2, _Cols, [{1}, {2}]} = pgsql:equery(C, "delete from test_table1 returning id")
+      end).
 
 
 parse_test() ->
 parse_test() ->
     with_connection(
     with_connection(
@@ -253,6 +270,15 @@ portal_test() ->
               ok = pgsql:sync(C)
               ok = pgsql:sync(C)
       end).
       end).
 
 
+returning_test() ->
+    with_rollback(
+      fun(C) ->
+              {ok, S} = pgsql:parse(C, "update test_table1 set value = $1 returning id"),
+              ok = pgsql:bind(C, S, ["foo"]),
+              {ok, 2, [{1}, {2}]} = pgsql:execute(C, S),
+              ok = pgsql:sync(C)
+      end).
+
 multiple_statement_test() ->
 multiple_statement_test() ->
     with_connection(
     with_connection(
       fun(C) ->
       fun(C) ->