Browse Source

Prepared statements partially + various fixes

Viktor Söderqvist 10 years ago
parent
commit
3c9e335460
3 changed files with 76 additions and 10 deletions
  1. 6 1
      include/records.hrl
  2. 38 8
      src/mysql_protocol.erl
  3. 32 1
      test/mysql_protocol_tests.erl

+ 6 - 1
include/records.hrl

@@ -21,7 +21,6 @@
 %% EOF packet, commonly used in the protocol.
 -record(eof, {status, warning_count}).
 
-
 %% Column definition, used while parsing a result set.
 -record(column_definition, {name, type, charset}).
 
@@ -29,3 +28,9 @@
 %% All values are binary (SQL code) except NULL.
 -record(text_resultset, {column_definitions :: [#column_definition{}],
                          rows :: [[binary() | null]]}).
+
+%% Response of a successfull prepare call.
+-record(prepared, {statement_id :: integer(),
+                   params :: [#column_definition{}],
+                   columns :: [#column_definition{}],
+                   warning_count :: integer()}).

+ 38 - 8
src/mysql_protocol.erl

@@ -8,7 +8,8 @@
 -module(mysql_protocol).
 
 -export([handshake/5,
-         query/3]).
+         query/3,
+         prepare/3]).
 
 -export_type([sendfun/0, recvfun/0]).
 
@@ -184,10 +185,34 @@ query(Query, SendFun, RecvFun) ->
             fetch_resultset(RecvFun, FieldCount, SeqNum2)
     end.
 
+%% @doc Prepares a statement.
+-spec prepare(iodata(), sendfun(), recvfun()) -> #error{} | #prepared{}.
 prepare(Query, SendFun, RecvFun) ->
     Req = <<?COM_STMT_PREPARE, (iolist_to_binary(Query))/binary>>,
     {ok, SeqNum1} = send_packet(SendFun, Req, 0),
     {ok, Resp, SeqNum2} = recv_packet(RecvFun, SeqNum1),
+    case Resp of
+        ?error_pattern ->
+            parse_error_packet(Resp);
+        <<?OK,
+          StmtId:32/little,
+          NumColumns:16/little,
+          NumParams:16/little,
+          0, %% reserved_1 -- [00] filler
+          WarningCount:16/little>> ->
+            %% This was the first packet.
+            %% If NumParams > 0 more packets will follow:
+            {ok, ParamDefs, SeqNum3} =
+                fetch_column_definitions(RecvFun, SeqNum2, NumParams, []),
+            {ok, ?eof_pattern, SeqNum4} = recv_packet(RecvFun, SeqNum3),
+            {ok, ColDefs, SeqNum5} =
+                fetch_column_definitions(RecvFun, SeqNum4, NumColumns, []),
+            {ok, ?eof_pattern, _SeqNum6} = recv_packet(RecvFun, SeqNum5),
+            #prepared{statement_id = StmtId,
+                      params = ParamDefs,
+                      columns = ColDefs,
+                      warning_count = WarningCount}
+    end.
 
 -spec fetch_resultset(recvfun(), integer(), integer()) ->
     #text_resultset{} | #error{}.
@@ -195,14 +220,12 @@ fetch_resultset(RecvFun, FieldCount, SeqNum) ->
     {ok, ColDefs, SeqNum1} = fetch_column_definitions(RecvFun, SeqNum,
                                                       FieldCount, []),
     {ok, DelimiterPacket, SeqNum2} = recv_packet(RecvFun, SeqNum1),
-    case DelimiterPacket of
-        ?eof_pattern ->
-            #eof{} = parse_eof_packet(DelimiterPacket),
-            {ok, Rows, _SeqNum3} = fetch_resultset_rows(RecvFun, ColDefs,
-                                                        SeqNum2, []),
+    #eof{} = parse_eof_packet(DelimiterPacket),
+    case fetch_resultset_rows(RecvFun, ColDefs, SeqNum2, []) of
+        {ok, Rows, _SeqNum3} ->
             #text_resultset{column_definitions = ColDefs, rows = Rows};
-        ?error_pattern ->
-            parse_error_packet(DelimiterPacket)
+        #error{} = E ->
+            E
     end.
 
 %% Receives NumLeft packets and parses them as column definitions.
@@ -216,9 +239,16 @@ fetch_column_definitions(RecvFun, SeqNum, NumLeft, Acc) when NumLeft > 0 ->
 fetch_column_definitions(_RecvFun, SeqNum, 0, Acc) ->
     {ok, lists:reverse(Acc), SeqNum}.
 
+-spec fetch_resultset_rows(recvfun(), ColumnDefinitions, integer(),
+                           Acc) -> {ok, Rows, integer()} | #error{}
+    when ColumnDefinitions :: [#column_definition{}],
+         Acc :: [[binary() | null]],
+         Rows :: [[binary() | null]].
 fetch_resultset_rows(RecvFun, ColDefs, SeqNum, Acc) ->
     {ok, Packet, SeqNum1} = recv_packet(RecvFun, SeqNum),
     case Packet of
+        ?error_pattern ->
+            parse_error_packet(Packet);
         ?eof_pattern ->
             {ok, lists:reverse(Acc), SeqNum1};
         _AnotherRow ->

+ 32 - 1
test/protocol_tests.erl → test/mysql_protocol_tests.erl

@@ -1,5 +1,5 @@
 %% @doc Eunit test cases for the mysql_protocol module.
--module(protocol_tests).
+-module(mysql_protocol_tests).
 
 -include_lib("eunit/include/eunit.hrl").
 
@@ -69,6 +69,37 @@ resultset_error_test() ->
     fakesocket_close(Sock),
     ok.
 
+prepare_test() ->
+    %% Prepared statement. The example from "14.7.4 COM_STMT_PREPARE" in the
+    %% "MySQL Internals" guide.
+    Query = <<"SELECT CONCAT(?, ?) AS col1">>,
+    ExpectedReq = hexdump_to_bin(
+        "1c 00 00 00 16 53 45 4c    45 43 54 20 43 4f 4e 43    .....SELECT CONC"
+        "41 54 28 3f 2c 20 3f 29    20 41 53 20 63 6f 6c 31    AT(?, ?) AS col1"
+        ),
+    ExpectedResp = hexdump_to_bin(
+        "0c 00 00 01 00 01 00 00    00 01 00 02 00 00 00 00|   ................"
+        "17 00 00 02 03 64 65 66    00 00 00 01 3f 00 0c 3f    .....def....?..?"
+        "00 00 00 00 00 fd 80 00    00 00 00|17 00 00 03 03    ................"
+        "64 65 66 00 00 00 01 3f    00 0c 3f 00 00 00 00 00    def....?..?....."
+        "fd 80 00 00 00 00|05 00    00 04 fe 00 00 02 00|1a    ................"
+        "00 00 05 03 64 65 66 00    00 00 04 63 6f 6c 31 00    ....def....col1."
+        "0c 3f 00 00 00 00 00 fd    80 00 1f 00 00|05 00 00    .?.............."
+        "06 fe 00 00 02 00                                     ......"),
+    Sock = fakesocket_create([{send, ExpectedReq}, {recv, ExpectedResp}]),
+    SendFun = fun (Data) -> fakesocket_send(Sock, Data) end,
+    RecvFun = fun (Size) -> fakesocket_recv(Sock, Size) end,
+    Result = mysql_protocol:prepare(Query, SendFun, RecvFun),
+    fakesocket_close(Sock),
+    ?assertMatch(#prepared{statement_id = StmtId,
+                           params = [#column_definition{name = <<"?">>},
+                                     #column_definition{name = <<"?">>}],
+                           columns = [#column_definition{name = <<"col1">>}],
+                           warning_count = 0} when is_integer(StmtId),
+                 Result),
+    ok.
+    
+
 %% --- Helper functions for the above tests ---
 
 %% Convert hex dumps to binaries. This is a helper function for the tests.