Browse Source

properly decode all command complete tags

Will 16 years ago
parent
commit
7e04153e57
2 changed files with 40 additions and 15 deletions
  1. 20 14
      src/pgsql_connection.erl
  2. 20 1
      test_src/pgsql_tests.erl

+ 20 - 14
src/pgsql_connection.erl

@@ -83,7 +83,7 @@ handle_event({notice, Notice}, State_Name, State) ->
 handle_event({parameter_status, Name, Value}, State_Name, State) ->
 handle_event({parameter_status, Name, Value}, State_Name, State) ->
     Parameters2 = lists:keystore(Name, 1, State#state.parameters, {Name, Value}),
     Parameters2 = lists:keystore(Name, 1, State#state.parameters, {Name, Value}),
     {next_state, State_Name, State#state{parameters = Parameters2}};
     {next_state, State_Name, State#state{parameters = Parameters2}};
-    
+
 handle_event(stop, _State_Name, State) ->
 handle_event(stop, _State_Name, State) ->
     {stop, normal, State};
     {stop, normal, State};
 
 
@@ -98,8 +98,8 @@ handle_info({'EXIT', Pid, Reason}, _State_Name, State = #state{reader = Pid}) ->
 
 
 handle_info(Info, _State_Name, State) ->
 handle_info(Info, _State_Name, State) ->
     {stop, {unsupported_info, Info}, State}.
     {stop, {unsupported_info, Info}, State}.
-    
-terminate(_Reason, _State_Name, State = #state{sock = Sock}) 
+
+terminate(_Reason, _State_Name, State = #state{sock = Sock})
   when Sock =/= undefined ->
   when Sock =/= undefined ->
     send(State, $X, []),
     send(State, $X, []),
     gen_tcp:close(Sock);
     gen_tcp:close(Sock);
@@ -113,7 +113,7 @@ code_change(_Old_Vsn, State_Name, State, _Extra) ->
 %% -- states --
 %% -- states --
 
 
 startup({connect, Host, Username, Password, Opts}, From, State) ->
 startup({connect, Host, Username, Password, Opts}, From, State) ->
-    Port      = proplists:get_value(port, Opts, 5432),    
+    Port      = proplists:get_value(port, Opts, 5432),
     Sock_Opts = [{active, false}, {packet, raw}, binary],
     Sock_Opts = [{active, false}, {packet, raw}, binary],
     case gen_tcp:connect(Host, Port, Sock_Opts) of
     case gen_tcp:connect(Host, Port, Sock_Opts) of
         {ok, Sock} ->
         {ok, Sock} ->
@@ -124,14 +124,14 @@ startup({connect, Host, Username, Password, Opts}, From, State) ->
                 undefined -> Opts3 = Opts2;
                 undefined -> Opts3 = Opts2;
                 Database  -> Opts3 = [Opts2 | ["database", 0, Database, 0]]
                 Database  -> Opts3 = [Opts2 | ["database", 0, Database, 0]]
             end,
             end,
-            
+
             put(username, Username),
             put(username, Username),
             put(password, Password),
             put(password, Password),
             State2 = State#state{reader   = Reader,
             State2 = State#state{reader   = Reader,
                                  sock     = Sock,
                                  sock     = Sock,
                                  reply_to = From},
                                  reply_to = From},
             send(State2, [<<196608:32>>, Opts3, 0]),
             send(State2, [<<196608:32>>, Opts3, 0]),
-    
+
             {next_state, auth, State2};
             {next_state, auth, State2};
         Error ->
         Error ->
             {stop, normal, Error, State}
             {stop, normal, Error, State}
@@ -284,7 +284,7 @@ querying({$3, <<>>}, State) ->
 %% RowDescription
 %% RowDescription
 querying({$T, <<Count:?int16, Bin/binary>>}, State) ->
 querying({$T, <<Count:?int16, Bin/binary>>}, State) ->
     Columns = decode_columns(Count, Bin),
     Columns = decode_columns(Count, Bin),
-    S2 = (State#state.statement)#statement{columns = Columns},    
+    S2 = (State#state.statement)#statement{columns = Columns},
     notify(State, {columns, Columns}),
     notify(State, {columns, Columns}),
     {next_state, querying, State#state{statement = S2}};
     {next_state, querying, State#state{statement = S2}};
 
 
@@ -496,12 +496,18 @@ decode_columns(N, Bin, Acc) ->
     decode_columns(N - 1, Rest2, [Desc | Acc]).
     decode_columns(N - 1, Rest2, [Desc | Acc]).
 
 
 %% decode command complete msg
 %% decode command complete msg
+decode_complete(<<"SELECT", 0>>)   -> select;
+decode_complete(<<"BEGIN", 0>>)    -> 'begin';
+decode_complete(<<"ROLLBACK", 0>>) -> rollback;
 decode_complete(Bin) ->
 decode_complete(Bin) ->
     {Str, _} = decode_string(Bin),
     {Str, _} = decode_string(Bin),
     case string:tokens(binary_to_list(Str), " ") of
     case string:tokens(binary_to_list(Str), " ") of
-        [Type]             -> lower_atom(Type);
-        [Type, _Oid, Rows] -> {lower_atom(Type), list_to_integer(Rows)};
-        [Type, Rows]       -> {lower_atom(Type), list_to_integer(Rows)}
+        ["INSERT", _Oid, Rows] -> {insert, list_to_integer(Rows)};
+        ["UPDATE", Rows]       -> {update, list_to_integer(Rows)};
+        ["DELETE", Rows]       -> {delete, list_to_integer(Rows)};
+        ["MOVE", Rows]         -> {move, list_to_integer(Rows)};
+        ["FETCH", _Rows]       -> fetch;
+        [Type | _Rest]         -> lower_atom(Type)
     end.
     end.
 
 
 %% decode ErrorResponse
 %% decode ErrorResponse
@@ -517,7 +523,7 @@ decode_error(Bin) ->
 decode_error_extra(Fields) ->
 decode_error_extra(Fields) ->
     Types = [{$D, detail}, {$H, hint}, {$P, position}],
     Types = [{$D, detail}, {$H, hint}, {$P, position}],
     decode_error_extra(Types, Fields, []).
     decode_error_extra(Types, Fields, []).
-    
+
 decode_error_extra([], _Fields, Extra) ->
 decode_error_extra([], _Fields, Extra) ->
     Extra;
     Extra;
 decode_error_extra([{Type, Name} | T], Fields, Extra) ->
 decode_error_extra([{Type, Name} | T], Fields, Extra) ->
@@ -608,7 +614,7 @@ hex(Bin) ->
 
 
 send(#state{sock = Sock}, Type, Data) ->
 send(#state{sock = Sock}, Type, Data) ->
     Bin = iolist_to_binary(Data),
     Bin = iolist_to_binary(Data),
-    gen_tcp:send(Sock, <<Type:8, (byte_size(Bin) + 4):?int32, Bin/binary>>).    
+    gen_tcp:send(Sock, <<Type:8, (byte_size(Bin) + 4):?int32, Bin/binary>>).
 
 
 send(#state{sock = Sock}, Data) ->
 send(#state{sock = Sock}, Data) ->
     Bin = iolist_to_binary(Data),
     Bin = iolist_to_binary(Data),
@@ -636,7 +642,7 @@ decode(Fsm, Sock, <<Type:8, Len:?int32, Rest/binary>> = Bin) ->
             gen_fsm:send_event(Fsm, {Type, Data}),
             gen_fsm:send_event(Fsm, {Type, Data}),
             decode(Fsm, Sock, Tail);
             decode(Fsm, Sock, Tail);
         _Other ->
         _Other ->
-            read(Fsm, Sock, Bin)
+            ?MODULE:read(Fsm, Sock, Bin)
     end;
     end;
 decode(Fsm, Sock, Bin) ->
 decode(Fsm, Sock, Bin) ->
-    read(Fsm, Sock, Bin).
+    ?MODULE:read(Fsm, Sock, Bin).

+ 20 - 1
test_src/pgsql_tests.erl

@@ -66,6 +66,25 @@ update_test() ->
               {ok, _, [{<<"2">>}]} = pgsql:squery(C, "select count(*) from test_table1 where value = 'foo'")
               {ok, _, [{<<"2">>}]} = pgsql:squery(C, "select count(*) from test_table1 where value = 'foo'")
       end).
       end).
 
 
+create_and_drop_table_test() ->
+    with_rollback(
+      fun(C) ->
+              {ok, [], []} = pgsql:squery(C, "create table test_table3 (id int4)"),
+              {ok, [#column{type = int4}], []} = pgsql:squery(C, "select * from test_table3"),
+              {ok, [], []} = pgsql:squery(C, "drop table test_table3")
+      end).
+
+cursor_test() ->
+    with_connection(
+      fun(C) ->
+              {ok, [], []} = pgsql:squery(C, "begin"),
+              {ok, [], []} = pgsql:squery(C, "declare c cursor for select id from test_table1"),
+              {ok, 2} = pgsql:squery(C, "move forward 2 from c"),
+              {ok, 1} = pgsql:squery(C, "move backward 1 from c"),
+              {ok, _Cols, [{<<"2">>}]} = pgsql:squery(C, "fetch next from c"),
+              {ok, [], []} = pgsql:squery(C, "close c")
+              end).
+
 multiple_result_test() ->
 multiple_result_test() ->
     with_connection(
     with_connection(
       fun(C) ->
       fun(C) ->
@@ -219,7 +238,7 @@ describe_error_test() ->
               {ok, S} = pgsql:parse(C, "select * from test_table1"),
               {ok, S} = pgsql:parse(C, "select * from test_table1"),
               {ok, S} = pgsql:describe(C, statement, ""),
               {ok, S} = pgsql:describe(C, statement, ""),
               ok = pgsql:sync(C)
               ok = pgsql:sync(C)
-      
+
       end).
       end).
 
 
 portal_test() ->
 portal_test() ->