Browse Source

Merge branch 'hstore' of git://github.com/bullno1/epgsql into devel

David N. Welton 11 years ago
parent
commit
c41fb595bd
4 changed files with 73 additions and 3 deletions
  1. 33 0
      src/pgsql_binary.erl
  2. 4 0
      src/pgsql_types.erl
  3. 33 2
      test/pgsql_tests.erl
  4. 3 1
      test_data/test_schema.sql

+ 33 - 0
src/pgsql_binary.erl

@@ -28,6 +28,7 @@ encode(bytea, B) when is_binary(B)          -> <<(byte_size(B)):?int32, B/binary
 encode(text, B) when is_binary(B)           -> <<(byte_size(B)):?int32, B/binary>>;
 encode(varchar, B) when is_binary(B)        -> <<(byte_size(B)):?int32, B/binary>>;
 encode(uuid, B) when is_binary(B)           -> encode_uuid(B);
+encode(hstore, {L}) when is_list(L)         -> encode_hstore(L);
 encode({array, char}, L) when is_list(L)    -> encode_array(bpchar, L);
 encode({array, Type}, L) when is_list(L)    -> encode_array(Type, L);
 encode(Type, L) when is_list(L)             -> encode(Type, list_to_binary(L));
@@ -49,6 +50,7 @@ decode(timestamp = Type, B)                 -> ?datetime:decode(Type, B);
 decode(timestamptz = Type, B)               -> ?datetime:decode(Type, B);
 decode(interval = Type, B)                  -> ?datetime:decode(Type, B);
 decode(uuid, B)                             -> decode_uuid(B);
+decode(hstore, Hstore)                      -> decode_hstore(Hstore);
 decode({array, _Type}, B)                   -> decode_array(B);
 decode(_Other, Bin)                         -> Bin.
 
@@ -83,6 +85,26 @@ encode_uuid(U) ->
     {ok, [Int], _} = io_lib:fread("~16u", Hex),
     <<16:?int32,Int:128>>.
 
+encode_hstore(HstoreEntries) ->
+    Body = << <<(encode_hstore_entry(Entry))/binary>> || Entry <- HstoreEntries >>,
+    <<(byte_size(Body) + 4):?int32, (length(HstoreEntries)):?int32, Body/binary>>.
+
+encode_hstore_entry({Key, Value}) ->
+    <<(encode_hstore_key(Key))/binary, (encode_hstore_value(Value))/binary>>.
+
+encode_hstore_key(Key) -> encode_hstore_string(Key).
+
+encode_hstore_value(null) -> <<-1:?int32>>;
+encode_hstore_value(Val) -> encode_hstore_string(Val).
+
+encode_hstore_string(Str) when is_list(Str) -> encode_hstore_string(list_to_binary(Str));
+encode_hstore_string(Str) when is_atom(Str) -> encode_hstore_string(atom_to_binary(Str, utf8));
+encode_hstore_string(Str) when is_integer(Str) ->
+    encode_hstore_string(erlang:integer_to_binary(Str));
+encode_hstore_string(Str) when is_float(Str) ->
+    encode_hstore_string(iolist_to_binary(io_lib:format("~w", [Str])));
+encode_hstore_string(Str) when is_binary(Str) -> <<(byte_size(Str)):?int32, Str/binary>>.
+
 decode_array(<<NDims:?int32, _HasNull:?int32, Oid:?int32, Rest/binary>>) ->
     {Dims, Data} = erlang:split_binary(Rest, NDims * 2 * 4),
     Lengths = [Len || <<Len:?int32, _LBound:?int32>> <= Dims],
@@ -118,6 +140,15 @@ decode_uuid(<<U0:32, U1:16, U2:16, U3:16, U4:48>>) ->
     Format = "~8.16.0b-~4.16.0b-~4.16.0b-~4.16.0b-~12.16.0b",
     iolist_to_binary(io_lib:format(Format, [U0, U1, U2, U3, U4])).
 
+decode_hstore(<<NumElements:?int32, Elements/binary>>) ->
+    {decode_hstore1(NumElements, Elements, [])}.
+
+decode_hstore1(0, _Elements, Acc) -> Acc;
+decode_hstore1(N, <<KeyLen:?int32, Key:KeyLen/binary, -1:?int32, Rest/binary>>, Acc) ->
+    decode_hstore1(N - 1, Rest, [{Key, null} | Acc]);
+decode_hstore1(N, <<KeyLen:?int32, Key:KeyLen/binary, ValLen:?int32, Value:ValLen/binary, Rest/binary>>, Acc) ->
+    decode_hstore1(N - 1, Rest, [{Key, Value} | Acc]).
+
 supports(bool)    -> true;
 supports(bpchar)  -> true;
 supports(int2)    -> true;
@@ -136,6 +167,7 @@ supports(timestamp)   -> true;
 supports(timestamptz) -> true;
 supports(interval)    -> true;
 supports(uuid)        -> true;
+supports(hstore)      -> true;
 supports({array, bool})   -> true;
 supports({array, int2})   -> true;
 supports({array, int4})   -> true;
@@ -150,6 +182,7 @@ supports({array, timetz}) -> true;
 supports({array, timestamp})     -> true;
 supports({array, timestamptz})   -> true;
 supports({array, interval})      -> true;
+supports({array, hstore})        -> true;
 supports({array, varchar}) -> true;
 supports({array, uuid})   -> true;
 supports(_Type)       -> false.

+ 4 - 0
src/pgsql_types.erl

@@ -93,6 +93,8 @@ oid2type(2776) -> anynonarray;
 oid2type(2950) -> uuid;
 oid2type(2951) -> {array, uuid};
 oid2type(3500) -> anyenum;
+oid2type(16831) -> hstore;
+oid2type(16836) -> {array, hstore};
 oid2type(Oid)  -> {unknown_oid, Oid}.
 
 type2oid(bool)                  -> 16;
@@ -186,4 +188,6 @@ type2oid(anynonarray)           -> 2776;
 type2oid(uuid)                  -> 2950;
 type2oid({array, uuid})         -> 2951;
 type2oid(anyenum)               -> 3500;
+type2oid(hstore)                -> 16831;
+type2oid({array, hstore})       -> 16836;
 type2oid(Type)                  -> {unknown_type, Type}.

+ 33 - 2
test/pgsql_tests.erl

@@ -545,6 +545,19 @@ misc_type_test(Module) ->
     check_type(Module, bool, "true", true, [true, false]),
     check_type(Module, bytea, "E'\001\002'", <<1,2>>, [<<>>, <<0,128,255>>]).
 
+hstore_type_test(Module) ->
+    Values = [
+        {[]},
+        {[{null, null}]},
+        {[{1, null}]},
+        {[{1.0, null}]},
+        {[{<<"a">>, <<"c">>}, {<<"c">>, <<"d">>}]},
+        {[{<<"a">>, <<"c">>}, {<<"c">>, null}]}
+    ],
+    check_type(Module, hstore, "''", {[]}, []),
+    check_type(Module, hstore, "'a => 1, b => 2.0, c => null'",
+               {[{<<"c">>, null}, {<<"b">>, <<"2.0">>}, {<<"a">>, <<"1">>}]}, Values).
+
 array_type_test(Module) ->
     with_connection(
       Module,
@@ -575,7 +588,9 @@ array_type_test(Module) ->
           Select(timetz, [{{0,1,2.0},1*60*60}, {{0,1,3.0},1*60*60}]),
           Select(timestamp, [{{2008,1,2},{3,4,5.0}}, {{2008,1,2},{3,4,6.0}}]),
           Select(timestamptz, [{{2008,1,2},{3,4,5.0}}, {{2008,1,2},{3,4,6.0}}]),
-          Select(interval, [{{1,2,3.1},0,0}, {{1,2,3.2},0,0}])
+          Select(interval, [{{1,2,3.1},0,0}, {{1,2,3.2},0,0}]),
+          Select(hstore, [{[{null, null}, {a, 1}, {1, 2}]}]),
+          Select(hstore, [[{[{null, null}, {a, 1}, {1, 2}]}, {[]}], [{[{a, 1}]}, {[{null, 2}]}]])
       end).
 
 text_format_test(Module) ->
@@ -815,7 +830,7 @@ check_type(Module, Type, In, Out, Values, Column) ->
               Sql = io_lib:format("insert into test_table2 (~s) values ($1) returning ~s", [Column, Column]),
               {ok, #statement{columns = [#column{type = Type}]} = S} = Module:parse(C, Sql),
               Insert = fun(V) ->
-                               Module:bind(C, S, [V]),
+                               ok = Module:bind(C, S, [V]),
                                {ok, 1, [{V2}]} = Module:execute(C, S),
                                case compare(Type, V, V2) of
                                    true  -> ok;
@@ -829,12 +844,28 @@ check_type(Module, Type, In, Out, Values, Column) ->
 compare(_Type, null, null) -> true;
 compare(float4, V1, V2)    -> abs(V2 - V1) < 0.000001;
 compare(float8, V1, V2)    -> abs(V2 - V1) < 0.000000000000001;
+compare(hstore, {V1}, V2)  -> compare(hstore, V1, V2);
+compare(hstore, V1, {V2})  -> compare(hstore, V1, V2);
+compare(hstore, V1, V2)    ->
+    orddict:from_list(format_hstore(V1)) =:= orddict:from_list(format_hstore(V2));
 compare(Type, V1 = {_, _, MS}, {D2, {H2, M2, S2}}) when Type == timestamp;
                                                         Type == timestamptz ->
     {D1, {H1, M1, S1}} = calendar:now_to_universal_time(V1),
     ({D1, H1, M1} =:= {D2, H2, M2}) and (abs(S1 + MS/1000000 - S2) < 0.000000000000001);
 compare(_Type, V1, V2)     -> V1 =:= V2.
 
+format_hstore({Hstore}) -> Hstore;
+format_hstore(Hstore) ->
+    [{format_hstore_key(Key), format_hstore_value(Value)} || {Key, Value} <- Hstore].
+
+format_hstore_key(Key) -> format_hstore_string(Key).
+
+format_hstore_value(null) -> null;
+format_hstore_value(Value) -> format_hstore_string(Value).
+
+format_hstore_string(Num) when is_number(Num) -> iolist_to_binary(io_lib:format("~w", [Num]));
+format_hstore_string(Str) -> iolist_to_binary(io_lib:format("~s", [Str])).
+
 %% flush mailbox
 flush() ->
     ?assertEqual([], flush([])).

+ 3 - 1
test_data/test_schema.sql

@@ -35,6 +35,7 @@ GRANT ALL ON DATABASE epgsql_test_db2 to epgsql_test;
 
 \c epgsql_test_db1;
 
+CREATE EXTENSION hstore;
 CREATE TABLE test_table1 (id integer primary key, value text);
 
 INSERT INTO test_table1 (id, value) VALUES (1, 'one');
@@ -57,7 +58,8 @@ CREATE TABLE test_table2 (
   c_timetz timetz,
   c_timestamp timestamp,
   c_timestamptz timestamptz,
-  c_interval interval);
+  c_interval interval,
+  c_hstore hstore);
 
 CREATE LANGUAGE plpgsql;