Просмотр исходного кода

add support for encoding and decoding text and int arrays

Will 14 лет назад
Родитель
Сommit
688a55dc0c
3 измененных файлов с 117 добавлено и 23 удалено
  1. 89 22
      src/pgsql_binary.erl
  2. 10 1
      src/pgsql_types.erl
  3. 18 0
      test_src/pgsql_tests.erl

+ 89 - 22
src/pgsql_binary.erl

@@ -7,24 +7,33 @@
 -define(int32, 1/big-signed-unit:32).
 -define(datetime, (get(datetime_mod))).
 
-encode(_Any, null)  -> <<-1:?int32>>;
-encode(bool, true)  -> <<1:?int32, 1:1/big-signed-unit:8>>;
-encode(bool, false) -> <<1:?int32, 0:1/big-signed-unit:8>>;
-encode(int2, N)     -> <<2:?int32, N:1/big-signed-unit:16>>;
-encode(int4, N)     -> <<4:?int32, N:1/big-signed-unit:32>>;
-encode(int8, N)     -> <<8:?int32, N:1/big-signed-unit:64>>;
-encode(float4, N)   -> <<4:?int32, N:1/big-float-unit:32>>;
-encode(float8, N)   -> <<8:?int32, N:1/big-float-unit:64>>;
-encode(bpchar, C) when is_integer(C) -> <<1:?int32, C:1/big-unsigned-unit:8>>;
-encode(bpchar, B) when is_binary(B)  -> <<(byte_size(B)):?int32, B/binary>>;
-encode(Type, B) when Type == time; Type == timetz          -> ?datetime:encode(Type, B);
-encode(Type, B) when Type == date; Type == timestamp       -> ?datetime:encode(Type, B);
-encode(Type, B) when Type == timestamptz; Type == interval -> ?datetime:encode(Type, B);
-encode(bytea, B) when is_binary(B)   -> <<(byte_size(B)):?int32, B/binary>>;
-encode(text, B) when is_binary(B)    -> <<(byte_size(B)):?int32, B/binary>>;
-encode(varchar, B) when is_binary(B) -> <<(byte_size(B)):?int32, B/binary>>;
-encode(Type, L) when is_list(L)      -> encode(Type, list_to_binary(L));
-encode(_Type, _Value)                -> {error, unsupported}.
+encode(_Any, null)                          -> <<-1:?int32>>;
+encode(bool, true)                          -> <<1:?int32, 1:1/big-signed-unit:8>>;
+encode(bool, false)                         -> <<1:?int32, 0:1/big-signed-unit:8>>;
+encode(int2, N)                             -> <<2:?int32, N:1/big-signed-unit:16>>;
+encode(int4, N)                             -> <<4:?int32, N:1/big-signed-unit:32>>;
+encode(int8, N)                             -> <<8:?int32, N:1/big-signed-unit:64>>;
+encode(float4, N)                           -> <<4:?int32, N:1/big-float-unit:32>>;
+encode(float8, N)                           -> <<8:?int32, N:1/big-float-unit:64>>;
+encode(bpchar, C) when is_integer(C)        -> <<1:?int32, C:1/big-unsigned-unit:8>>;
+encode(bpchar, B) when is_binary(B)         -> <<(byte_size(B)):?int32, B/binary>>;
+encode(time = Type, B)                      -> ?datetime:encode(Type, B);
+encode(timetz = Type, B)                    -> ?datetime:encode(Type, B);
+encode(date = Type, B)                      -> ?datetime:encode(Type, B);
+encode(timestamp = Type, B)                 -> ?datetime:encode(Type, B);
+encode(timestamptz = Type, B)               -> ?datetime:encode(Type, B);
+encode(interval = Type, B)                  -> ?datetime:encode(Type, B);
+encode(bytea, B) when is_binary(B)          -> <<(byte_size(B)):?int32, B/binary>>;
+encode(text, B) when is_binary(B)           -> <<(byte_size(B)):?int32, B/binary>>;
+encode(varchar, B) when is_binary(B)        -> <<(byte_size(B)):?int32, B/binary>>;
+encode(boolarray, L) when is_list(L)        -> encode_array(bool, L);
+encode(int2array, L) when is_list(L)        -> encode_array(int2, L);
+encode(int4array, L) when is_list(L)        -> encode_array(int4, L);
+encode(int8array, L) when is_list(L)        -> encode_array(int8, L);
+encode(chararray, L) when is_list(L)        -> encode_array(bpchar, L);
+encode(textarray, L) when is_list(L)        -> encode_array(text, L);
+encode(Type, L) when is_list(L)             -> encode(Type, list_to_binary(L));
+encode(_Type, _Value)                       -> {error, unsupported}.
 
 decode(bool, <<1:1/big-signed-unit:8>>)     -> true;
 decode(bool, <<0:1/big-signed-unit:8>>)     -> false;
@@ -35,10 +44,62 @@ decode(int8, <<N:1/big-signed-unit:64>>)    -> N;
 decode(float4, <<N:1/big-float-unit:32>>)   -> N;
 decode(float8, <<N:1/big-float-unit:64>>)   -> N;
 decode(record, <<_:?int32, Rest/binary>>)   -> list_to_tuple(decode_record(Rest, []));
-decode(Type, B) when Type == time; Type == timetz          -> ?datetime:decode(Type, B);
-decode(Type, B) when Type == date; Type == timestamp       -> ?datetime:decode(Type, B);
-decode(Type, B) when Type == timestamptz; Type == interval -> ?datetime:decode(Type, B);
-decode(_Other, Bin) -> Bin.
+decode(time = Type, B)                      -> ?datetime:decode(Type, B);
+decode(timetz = Type, B)                    -> ?datetime:decode(Type, B);
+decode(date = Type, B)                      -> ?datetime:decode(Type, B);
+decode(timestamp = Type, B)                 -> ?datetime:decode(Type, B);
+decode(timestamptz = Type, B)               -> ?datetime:decode(Type, B);
+decode(interval = Type, B)                  -> ?datetime:decode(Type, B);
+decode(boolarray, B)                        -> decode_array(B);
+decode(int2array, B)                        -> decode_array(B);
+decode(int4array, B)                        -> decode_array(B);
+decode(int8array, B)                        -> decode_array(B);
+decode(chararray, B)                        -> decode_array(B);
+decode(textarray, B)                        -> decode_array(B);
+decode(_Other, Bin)                         -> Bin.
+
+encode_array(Type, A) ->
+    {Data, {NDims, Lengths}} = encode_array(Type, A, 0, []),
+    Oid  = pgsql_types:type2oid(Type),
+    Lens = [<<N:?int32, 0:?int32>> || N <- lists:reverse(Lengths)],
+    Hdr  = <<NDims:?int32, 0:?int32, Oid:?int32>>,
+    Bin  = iolist_to_binary([Hdr, Lens, Data]),
+    <<(byte_size(Bin)):?int32, Bin/binary>>.
+
+encode_array(_Type, [], NDims, Lengths) ->
+    {<<>>, {NDims, Lengths}};
+encode_array(Type, [H | _] = Array, NDims, Lengths) when not is_list(H) ->
+    F = fun(E, Len) -> {encode(Type, E), Len + 1} end,
+    {Data, Len} = lists:mapfoldl(F, 0, Array),
+    {Data, {NDims + 1, [Len | Lengths]}};
+encode_array(Type, Array, NDims, Lengths) ->
+    Lengths2 = [length(Array) | Lengths],
+    F = fun(A2, {_NDims, _Lengths}) -> encode_array(Type, A2, NDims, Lengths2) end,
+    {Data, {NDims2, Lengths3}} = lists:mapfoldl(F, {NDims, Lengths2}, Array),
+    {Data, {NDims2 + 1, Lengths3}}.
+
+decode_array(<<NDims:?int32, _HasNull:?int32, Oid:?int32, Rest/binary>>) ->
+    {Dims, Data} = erlang:split_binary(Rest, NDims * 2 * 4),
+    Lengths = [Len || <<Len:?int32, _LBound:?int32>> <= Dims],
+    Type = pgsql_types:oid2type(Oid),
+    {Array, <<>>} = decode_array(Data, Type, Lengths),
+    Array.
+
+decode_array(Data, _Type, [])  ->
+    {[], Data};
+decode_array(Data, Type, [Len]) ->
+    decode_elements(Data, Type, [], Len);
+decode_array(Data, Type, [Len | T]) ->
+    F = fun(_N, Rest) -> decode_array(Rest, Type, T) end,
+    lists:mapfoldl(F, Data, lists:seq(1, Len)).
+
+decode_elements(Rest, _Type, Acc, 0) ->
+    {lists:reverse(Acc), Rest};
+decode_elements(<<-1:?int32, Rest/binary>>, Type, Acc, N) ->
+    decode_elements(Rest, Type, [null | Acc], N - 1);
+decode_elements(<<Len:?int32, Value:Len/binary, Rest/binary>>, Type, Acc, N) ->
+    Value2 = decode(Type, Value),
+    decode_elements(Rest, Type, [Value2 | Acc], N - 1).
 
 decode_record(<<>>, Acc) ->
     lists:reverse(Acc);
@@ -65,4 +126,10 @@ supports(timetz)  -> true;
 supports(timestamp)   -> true;
 supports(timestamptz) -> true;
 supports(interval)    -> true;
+supports(boolarray)   -> true;
+supports(int2array)   -> true;
+supports(int4array)   -> true;
+supports(int8array)   -> true;
+supports(chararray)   -> true;
+supports(textarray)   -> true;
 supports(_Type)       -> false.

+ 10 - 1
src/pgsql_types.erl

@@ -39,7 +39,12 @@ oid2type(790)  -> cash;
 oid2type(829)  -> macaddr;
 oid2type(869)  -> inet;
 oid2type(650)  -> cidr;
+oid2type(1000) -> boolarray;
+oid2type(1005) -> int2array;
 oid2type(1007) -> int4array;
+oid2type(1009) -> textarray;
+oid2type(1014) -> chararray;
+oid2type(1016) -> int8array;
 oid2type(1021) -> float4array;
 oid2type(1033) -> aclitem;
 oid2type(1263) -> cstringarray;
@@ -117,7 +122,12 @@ type2oid(cash)                  -> 790;
 type2oid(macaddr)               -> 829;
 type2oid(inet)                  -> 869;
 type2oid(cidr)                  -> 650;
+type2oid(boolarray)             -> 1000;
+type2oid(int2array)             -> 1005;
 type2oid(int4array)             -> 1007;
+type2oid(textarray)             -> 1009;
+type2oid(chararray)             -> 1014;
+type2oid(int8array)             -> 1016;
 type2oid(float4array)           -> 1021;
 type2oid(aclitem)               -> 1033;
 type2oid(cstringarray)          -> 1263;
@@ -157,4 +167,3 @@ type2oid(anyelement)            -> 2283;
 type2oid(anynonarray)           -> 2776;
 type2oid(anyenum)               -> 3500;
 type2oid(Type)                  -> {unknown_type, Type}.
-

+ 18 - 0
test_src/pgsql_tests.erl

@@ -409,6 +409,24 @@ misc_type_test() ->
     check_type(bool, "true", true, [true, false]),
     check_type(bytea, "E'\001\002'", <<1,2>>, [<<>>, <<0,128,255>>]).
 
+array_type_test() ->
+    with_connection(
+      fun(C) ->
+          Select = fun(Type, V) ->
+                       Query = "select $1::" ++ Type,
+                       {ok, _Cols, [{V}]} = pgsql:equery(C, Query, [V])
+                   end,
+          Select("int2[]", []),
+          Select("int2[]", [1, 2, 3, 4]),
+          Select("int2[]", [[1], [2], [3], [4]]),
+          Select("int2[]", [[[[[[1, 2]]]]]]),
+          Select("bool[]", [true]),
+          Select("char[]", [$a, $b, $c]),
+          Select("int4[]", [[1, 2]]),
+          Select("int8[]", [[[[1, 2]], [[3, 4]]]]),
+          Select("text[]", [<<"one">>, <<"two>">>])
+      end).
+
 text_format_test() ->
     with_connection(
       fun(C) ->