Просмотр исходного кода

Named prepared statements + unprepare/2 + {error, not_prepared} for execute/3

Viktor Söderqvist 10 лет назад
Родитель
Сommit
4a1abc48bc
4 измененных файлов с 108 добавлено и 23 удалено
  1. 18 2
      src/mysql.erl
  2. 53 17
      src/mysql_connection.erl
  3. 7 1
      src/mysql_protocol.erl
  4. 30 3
      test/mysql_tests.erl

+ 18 - 2
src/mysql.erl

@@ -22,8 +22,9 @@
 %% `connection()' type is the same as returned by `gen_server:start_link/2,3'.
 -module(mysql).
 
--export([start_link/1, query/2, execute/3, prepare/2, warning_count/1,
-         affected_rows/1, autocommit/1, insert_id/1, in_transaction/1,
+-export([start_link/1, query/2, execute/3, prepare/2, prepare/3, unprepare/2,
+         warning_count/1, affected_rows/1, autocommit/1, insert_id/1,
+         in_transaction/1,
          transaction/2, transaction/3]).
 
 -export_type([connection/0]).
@@ -74,6 +75,21 @@ execute(Conn, StatementId, Args) ->
 prepare(Conn, Query) ->
     gen_server:call(Conn, {prepare, Query}).
 
+%% @doc Creates a prepared statement from the passed query and associates it
+%% with the given name.
+%% @see execute/3
+-spec prepare(Conn :: connection(), Name :: term(), Query :: iodata()) ->
+    {ok, Name :: term()} | {error, Reason :: reason()}.
+prepare(Conn, Name, Query) ->
+    gen_server:call(Conn, {prepare, Name, Query}).
+
+%% @doc Deallocates a prepared statement.
+%% @see prepare/3
+-spec unprepare(Conn :: connection(), StatementRef :: term()) ->
+    ok | {error, not_prepared} | {error, Reason :: reason()}.
+unprepare(Conn, StatementRef) ->
+    gen_server:call(Conn, {unprepare, StatementRef}).
+
 %% @doc Returns the number of warnings generated by the last query/2 or
 %% execute/3 calls.
 -spec warning_count(connection()) -> integer().

+ 53 - 17
src/mysql_connection.erl

@@ -85,23 +85,25 @@ handle_call({query, Query}, _From, State) when is_binary(Query);
             Names = [Def#column_definition.name || Def <- ColDefs],
             {reply, {ok, Names, Rows}, State1}
     end;
-handle_call({execute, Stmt, Args}, _From, State) when is_integer(Stmt);
-                                                      is_atom(Stmt) ->
-    %% TODO: Return {error, not_prepared} instead of crashing if not found.
-    StmtRec = dict:fetch(Stmt, State#state.stmts),
-    #state{socket = Socket, timeout = Timeout} = State,
-    SendFun = fun (Data) -> gen_tcp:send(Socket, Data) end,
-    RecvFun = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
-    Rec = mysql_protocol:execute(StmtRec, Args, SendFun, RecvFun),
-    State1 = update_state(State, Rec),
-    case Rec of
-        #ok{} ->
-            {reply, ok, State1};
-        #error{} = E ->
-            {reply, {error, error_to_reason(E)}, State1};
-        #resultset{column_definitions = ColDefs, rows = Rows} ->
-            Names = [Def#column_definition.name || Def <- ColDefs],
-            {reply, {ok, Names, Rows}, State1}
+handle_call({execute, Stmt, Args}, _From, State) ->
+    case dict:find(Stmt, State#state.stmts) of
+        {ok, StmtRec} ->
+            #state{socket = Socket, timeout = Timeout} = State,
+            SendFun = fun (Data) -> gen_tcp:send(Socket, Data) end,
+            RecvFun = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
+            Rec = mysql_protocol:execute(StmtRec, Args, SendFun, RecvFun),
+            State1 = update_state(State, Rec),
+            case Rec of
+                #ok{} ->
+                    {reply, ok, State1};
+                #error{} = E ->
+                    {reply, {error, error_to_reason(E)}, State1};
+                #resultset{column_definitions = ColDefs, rows = Rows} ->
+                    Names = [Def#column_definition.name || Def <- ColDefs],
+                    {reply, {ok, Names, Rows}, State1}
+            end;
+        error ->
+            {reply, {error, not_prepared}, State}
     end;
 handle_call({prepare, Query}, _From, State) ->
     #state{socket = Socket, timeout = Timeout} = State,
@@ -117,6 +119,40 @@ handle_call({prepare, Query}, _From, State) ->
             State2 = State#state{stmts = Stmts1},
             {reply, {ok, Id}, State2}
     end;
+handle_call({prepare, Name, Query}, _From, State) when is_atom(Name) ->
+    #state{socket = Socket, timeout = Timeout} = State,
+    SendFun = fun (Data) -> gen_tcp:send(Socket, Data) end,
+    RecvFun = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
+    %% First unprepare if there is an old statement with this name.
+    State1 = case dict:find(Name, State#state.stmts) of
+        {ok, OldStmt} ->
+            mysql_protocol:unprepare(OldStmt, SendFun, RecvFun),
+            State#state{stmts = dict:erase(Name, State#state.stmts)};
+        error ->
+            State
+    end,
+    Rec = mysql_protocol:prepare(Query, SendFun, RecvFun),
+    State2 = update_state(State1, Rec),
+    case Rec of
+        #error{} = E ->
+            {reply, {error, error_to_reason(E)}, State2};
+        #prepared{} = Stmt ->
+            Stmts1 = dict:store(Name, Stmt, State2#state.stmts),
+            State3 = State2#state{stmts = Stmts1},
+            {reply, {ok, Name}, State3}
+    end;
+handle_call({unprepare, Name}, _From, State) ->
+    case dict:find(Name, State#state.stmts) of
+        {ok, StmtRec} ->
+            #state{socket = Socket, timeout = Timeout} = State,
+            SendFun = fun (Data) -> gen_tcp:send(Socket, Data) end,
+            RecvFun = fun (Size) -> gen_tcp:recv(Socket, Size, Timeout) end,
+            mysql_protocol:unprepare(StmtRec, SendFun, RecvFun),
+            Stmts1 = dict:erase(Name, State#state.stmts),
+            {reply, ok, State#state{stmts = Stmts1}};
+        error ->
+            {reply, {error, not_prepared}, State}
+    end;
 handle_call(warning_count, _From, State) ->
     {reply, State#state.warning_count, State};
 handle_call(insert_id, _From, State) ->

+ 7 - 1
src/mysql_protocol.erl

@@ -27,7 +27,7 @@
 
 -export([handshake/5,
          query/3,
-         prepare/3, execute/4]).
+         prepare/3, unprepare/3, execute/4]).
 
 -export_type([sendfun/0, recvfun/0]).
 
@@ -124,6 +124,12 @@ prepare(Query, SendFun, RecvFun) ->
                       warning_count = WarningCount}
     end.
 
+%% @doc Deallocates a prepared statement.
+-spec unprepare(#prepared{}, sendfun(), recvfun()) -> ok.
+unprepare(#prepared{statement_id = Id}, SendFun, _RecvFun) ->
+    {ok, _SeqNum} = send_packet(SendFun, <<?COM_STMT_CLOSE, Id:32/little>>, 0),
+    ok.
+
 %% @doc Executes a prepared statement.
 -spec execute(#prepared{}, [term()], sendfun(), recvfun()) -> #resultset{}.
 execute(#prepared{statement_id = Id, param_count = ParamCount}, ParamValues,

+ 30 - 3
test/mysql_tests.erl

@@ -286,9 +286,9 @@ run_test_microseconds(Pid) ->
 
 %% --------------------------------------------------------------------------
 
-%% Transaction tests
+%% Prepared statements and transactions
 
-transaction_single_connection_test_() ->
+with_table_foo_test_() ->
     {setup,
      fun () ->
          {ok, Pid} = mysql:start_link([{user, ?user}, {password, ?password}]),
@@ -302,9 +302,36 @@ transaction_single_connection_test_() ->
          ok = mysql:query(Pid, <<"DROP DATABASE otptest">>),
          exit(Pid, normal)
      end,
-     {with, [fun transaction_simple_success/1,
+     {with, [fun prepared_statements/1,
+             fun transaction_simple_success/1,
              fun transaction_simple_aborted/1]}}.
 
+prepared_statements(Pid) ->
+    %% Unnamed
+    ?assertEqual({error,{1146, <<"42S02">>,
+                         <<"Table 'otptest.tab' doesn't exist">>}},
+                 mysql:prepare(Pid, "SELECT * FROM tab WHERE id = ?")),
+    {ok, StmtId} = mysql:prepare(Pid, "SELECT * FROM foo WHERE bar = ?"),
+    ?assert(is_integer(StmtId)),
+    ?assertEqual(ok, mysql:unprepare(Pid, StmtId)),
+    ?assertEqual({error, not_prepared}, mysql:unprepare(Pid, StmtId)),
+
+    %% Named
+    ?assertEqual({error,{1146, <<"42S02">>,
+                         <<"Table 'otptest.tab' doesn't exist">>}},
+                 mysql:prepare(Pid, tab, "SELECT * FROM tab WHERE id = ?")),
+    ?assertEqual({ok, foo},
+                 mysql:prepare(Pid, foo, "SELECT * FROM foo WHERE bar = ?")),
+    %% Prepare again unprepares the old stmt associated with this name.
+    ?assertEqual({ok, foo},
+                 mysql:prepare(Pid, foo, "SELECT bar FROM foo WHERE bar = ?")),
+    ?assertEqual(ok, mysql:unprepare(Pid, foo)),
+    ?assertEqual({error, not_prepared}, mysql:unprepare(Pid, foo)),
+
+    %% Execute when not prepared
+    ?assertEqual({error, not_prepared}, mysql:execute(Pid, not_a_stmt, [])),
+    ok.
+
 transaction_simple_success(Pid) ->
     ?assertNot(mysql:in_transaction(Pid)),
     Result = mysql:transaction(Pid, fun () ->