Просмотр исходного кода

Merge pull request #212 from seriyps/configurable-null

Make it possible to configure NULL representation.
Sergey Prokhorov 5 лет назад
Родитель
Сommit
31a413190f
6 измененных файлов с 182 добавлено и 70 удалено
  1. 9 0
      README.md
  2. 3 2
      src/commands/epgsql_cmd_connect.erl
  3. 2 0
      src/epgsql.erl
  4. 111 50
      src/epgsql_binary.erl
  5. 14 12
      src/epgsql_wire.erl
  6. 43 6
      test/epgsql_SUITE.erl

+ 9 - 0
README.md

@@ -74,6 +74,7 @@ connect(Opts) -> {ok, Connection :: epgsql:connection()} | {error, Reason :: epg
       timeout =>  timeout(),             % socket connect timeout, default: 5000 ms
       async =>    pid() | atom(),        % process to receive LISTEN/NOTIFY msgs
       codecs =>   [{epgsql_codec:codec_mod(), any()}]}
+      nulls =>    [any(), ...],          % NULL terms
       replication => Replication :: string()} % Pass "database" to connect in replication mode
     | list().
 
@@ -104,6 +105,12 @@ Only `host` and `username` are mandatory, but most likely you would need `databa
 - `ssl_opts` will be passed as is to `ssl:connect/3`
 - `async` see [Server notifications](#server-notifications)
 - `codecs` see [Pluggable datatype codecs](#pluggable-datatype-codecs)
+- `nulls` terms which will be used to represent SQL `NULL`. If any of those has been encountered in
+   placeholder parameters (`$1`, `$2` etc values), it will be interpreted as `NULL`.
+   1st element of the list will be used to represent NULLs received from the server. It's not recommended
+   to use `"string"`s or lists. Try to keep this list short for performance!
+   Default is `[null, undefined]`, i.e. encode `null` or `undefined` in parameters as `NULL`
+   and decode `NULL`s as atom `null`.
 - `replication` see [Streaming replication protocol](#streaming-replication-protocol)
 
 Options may be passed as proplist or as map with the same key names.
@@ -469,6 +476,8 @@ PG type       | Representation
   tstzrange   | `{{Hour, Minute, Second.Microsecond}, {Hour, Minute, Second.Microsecond}}`
   daterange   | `{{Year, Month, Day}, {Year, Month, Day}}`
 
+`null` can be configured. See `nulls` `connect/1` option.
+
 `timestamp` and `timestamptz` parameters can take `erlang:now()` format: `{MegaSeconds, Seconds, MicroSeconds}`
 
 `int4range` is a range type for ints that obeys inclusive/exclusive semantics,

+ 3 - 2
src/commands/epgsql_cmd_connect.erl

@@ -242,8 +242,9 @@ handle_message(?CANCELLATION_KEY, <<Pid:?int32, Key:?int32>>, Sock, _State) ->
     {noaction, epgsql_sock:set_attr(backend, {Pid, Key}, Sock)};
 
 %% ReadyForQuery
-handle_message(?READY_FOR_QUERY, _, Sock, _State) ->
-    Codec = epgsql_binary:new_codec(Sock, []),
+handle_message(?READY_FOR_QUERY, _, Sock, #connect{opts = Opts}) ->
+    CodecOpts = maps:with([nulls], Opts),
+    Codec = epgsql_binary:new_codec(Sock, CodecOpts),
     Sock1 = epgsql_sock:set_attr(codec, Codec, Sock),
     {finish, connected, connected, Sock1};
 

+ 2 - 0
src/epgsql.erl

@@ -57,6 +57,7 @@
     {timeout,  TimeoutMs  :: timeout()}            | % default: 5000 ms
     {async,    Receiver   :: pid() | atom()}       | % process to receive LISTEN/NOTIFY msgs
     {codecs,   Codecs     :: [{epgsql_codec:codec_mod(), any()}]} |
+    {nulls,    Nulls      :: [any(), ...]} |    % terms to be used as NULL
     {replication, Replication :: string()}. % Pass "database" to connect in replication mode
 
 -type connect_opts() ::
@@ -71,6 +72,7 @@
           timeout => timeout(),
           async => pid() | atom(),
           codecs => [{epgsql_codec:codec_mod(), any()}],
+          nulls => [any(), ...],
           replication => string()}.
 
 -type connect_error() :: epgsql_cmd_connect:connect_error().

+ 111 - 50
src/epgsql_binary.erl

@@ -4,6 +4,8 @@
 
 -export([new_codec/2,
          update_codec/2,
+         null/1,
+         is_null/2,
          type_to_oid/2,
          typeinfo_to_name_array/2,
          typeinfo_to_oid_info/2,
@@ -17,10 +19,24 @@
 -export_type([codec/0, decoder/0]).
 
 -include("protocol.hrl").
+-define(DEFAULT_NULLS, [null, undefined]).
 
 -record(codec,
-        {opts = [] :: list(),                   % not used yet
+        {opts = #{} :: opts(),                   % not used yet
+         nulls = ?DEFAULT_NULLS :: nulls(),
          oid_db :: epgsql_oid_db:db()}).
+-record(array_decoder,
+        {element_decoder :: decoder(),
+         null_term :: any() }).
+-record(array_encoder,
+        {element_encoder :: epgsql_codec:codec_entry(),
+         n_dims = 0 :: non_neg_integer(),
+         lengths = [] :: [non_neg_integer()],
+         has_null = false :: boolean(),
+         codec :: codec()}).
+
+-type nulls() :: [any(), ...].
+-type opts() :: #{nulls => nulls()}.
 
 -opaque codec() :: #codec{}.
 -opaque decoder() :: {fun((binary(), epgsql:type_name(), epgsql_codec:codec_state()) -> any()),
@@ -36,7 +52,7 @@
 %% Codec is used to convert data (result rows and query parameters) between Erlang and postgresql formats
 %% It uses mappings between OID, type names and `epgsql_codec_*' modules (epgsql_oid_db)
 
--spec new_codec(epgsql_sock:pg_sock(), list()) -> codec().
+-spec new_codec(epgsql_sock:pg_sock(), opts()) -> codec().
 new_codec(PgSock, Opts) ->
     Codecs = default_codecs(),
     Oids = default_oids(),
@@ -45,7 +61,9 @@ new_codec(PgSock, Opts) ->
 new_codec(PgSock, Codecs, Oids, Opts) ->
     CodecEntries = epgsql_codec:init_mods(Codecs, PgSock),
     Types = epgsql_oid_db:join_codecs_oids(Oids, CodecEntries),
-    #codec{oid_db = epgsql_oid_db:from_list(Types), opts = Opts}.
+    #codec{oid_db = epgsql_oid_db:from_list(Types),
+           nulls = maps:get(nulls, Opts, ?DEFAULT_NULLS),
+           opts = Opts}.
 
 -spec update_codec([epgsql_oid_db:type_info()], codec()) -> codec().
 update_codec(TypeInfos, #codec{oid_db = Db} = Codec) ->
@@ -63,6 +81,16 @@ oid_to_name(Oid, Codec) ->
             end
     end.
 
+%% @doc Return the value that represents NULL (1st element of `nulls' list)
+-spec null(codec()) -> any().
+null(#codec{nulls = [Null | _]}) ->
+    Null.
+
+%% @doc Returns `true' if `Value' is a term representing `NULL'
+-spec is_null(any(), codec()) -> boolean().
+is_null(Value, #codec{nulls = Nulls}) ->
+    lists:member(Value, Nulls).
+
 -spec type_to_oid(type(), codec()) -> epgsql_oid_db:oid().
 type_to_oid({array, Name}, Codec) ->
     type_to_oid(Name, true, Codec);
@@ -117,28 +145,30 @@ decode(Bin, {Fun, TypeName, State}) ->
 oid_to_decoder(?RECORD_OID, binary, Codec) ->
     {fun ?MODULE:decode_record/3, record, Codec};
 oid_to_decoder(?RECORD_ARRAY_OID, binary, Codec) ->
-    %% See `make_array_decoder/3'
-    {fun ?MODULE:decode_array/3, [], oid_to_decoder(?RECORD_OID, binary, Codec)};
-oid_to_decoder(Oid, Format, #codec{oid_db = Db}) ->
+    {fun ?MODULE:decode_array/3, array,
+     #array_decoder{
+        element_decoder = oid_to_decoder(?RECORD_OID, binary, Codec),
+        null_term = null(Codec)}};
+oid_to_decoder(Oid, Format, #codec{oid_db = Db} = Codec) ->
     case epgsql_oid_db:find_by_oid(Oid, Db) of
         undefined when Format == binary ->
             {fun epgsql_codec_noop:decode/3, undefined, []};
         undefined when Format == text ->
             {fun epgsql_codec_noop:decode_text/3, undefined, []};
         Type ->
-            make_decoder(Type, Format)
+            make_decoder(Type, Format, Codec)
     end.
 
--spec make_decoder(epgsql_oid_db:type_info(), binary | text) -> decoder().
-make_decoder(Type, Format) ->
+-spec make_decoder(epgsql_oid_db:type_info(), binary | text, codec()) -> decoder().
+make_decoder(Type, Format, Codec) ->
     {Name, Mod, State} = epgsql_oid_db:type_to_codec_entry(Type),
     {_Oid, Name, IsArray} = epgsql_oid_db:type_to_oid_info(Type),
-    make_decoder(Name, Mod, State, Format, IsArray).
+    make_decoder(Name, Mod, State, Codec, Format, IsArray).
 
-make_decoder(_Name, _Mod, _State, text, true) ->
+make_decoder(_Name, _Mod, _State, _Codec, text, true) ->
     %% Don't try to decode text arrays
     {fun epgsql_codec_noop:decode_text/3, undefined, []};
-make_decoder(Name, Mod, State, text, false) ->
+make_decoder(Name, Mod, State, _Codec, text, false) ->
     %% decode_text/3 is optional callback. If it's not defined, do NOOP.
     case erlang:function_exported(Mod, decode_text, 3) of
         true ->
@@ -146,18 +176,18 @@ make_decoder(Name, Mod, State, text, false) ->
         false ->
             {fun epgsql_codec_noop:decode_text/3, undefined, []}
     end;
-make_decoder(Name, Mod, State, binary, true) ->
-    make_array_decoder(Name, Mod, State);
-make_decoder(Name, Mod, State, binary, false) ->
+make_decoder(Name, Mod, State, Codec, binary, true) ->
+    {fun ?MODULE:decode_array/3, array,
+     #array_decoder{
+        element_decoder = {fun Mod:decode/3, Name, State},
+        null_term = null(Codec)}};
+make_decoder(Name, Mod, State, _Codec, binary, false) ->
     {fun Mod:decode/3, Name, State}.
 
 
 %% Array decoding
 %%% $PG$/src/backend/utils/adt/arrayfuncs.c
-make_array_decoder(Name, Mod, State) ->
-    {fun ?MODULE:decode_array/3, [], {fun Mod:decode/3, Name, State}}.
-
-decode_array(<<NDims:?int32, _HasNull:?int32, _Oid:?int32, Rest/binary>>, _, ElemDecoder) ->
+decode_array(<<NDims:?int32, _HasNull:?int32, _Oid:?int32, Rest/binary>>, _, ArrayDecoder) ->
     %% 4b: n_dimensions;
     %% 4b: flags;
     %% 4b: Oid // should be the same as in column spec;
@@ -168,27 +198,29 @@ decode_array(<<NDims:?int32, _HasNull:?int32, _Oid:?int32, Rest/binary>>, _, Ele
     %% https://www.postgresql.org/docs/current/static/arrays.html#arrays-io
     {Dims, Data} = erlang:split_binary(Rest, NDims * 2 * 4),
     Lengths = [Len || <<Len:?int32, _LBound:?int32>> <= Dims],
-    {Array, <<>>} = decode_array1(Data, Lengths, ElemDecoder),
+    {Array, <<>>} = decode_array1(Data, Lengths, ArrayDecoder),
     Array.
 
 decode_array1(Data, [], _)  ->
     %% zero-dimensional array
     {[], Data};
-decode_array1(Data, [Len], ElemDecoder) ->
+decode_array1(Data, [Len], ArrayDecoder) ->
     %% 1-dimensional array
-    decode_elements(Data, [], Len, ElemDecoder);
-decode_array1(Data, [Len | T], ElemDecoder) ->
+    decode_elements(Data, [], Len, ArrayDecoder);
+decode_array1(Data, [Len | T], ArrayDecoder) ->
     %% multidimensional array
-    F = fun(_N, Rest) -> decode_array1(Rest, T, ElemDecoder) end,
+    F = fun(_N, Rest) -> decode_array1(Rest, T, ArrayDecoder) end,
     lists:mapfoldl(F, Data, lists:seq(1, Len)).
 
-decode_elements(Rest, Acc, 0, _ElDec) ->
+decode_elements(Rest, Acc, 0, _ArDec) ->
     {lists:reverse(Acc), Rest};
-decode_elements(<<-1:?int32, Rest/binary>>, Acc, N, ElDec) ->
-    decode_elements(Rest, [null | Acc], N - 1, ElDec);
-decode_elements(<<Len:?int32, Value:Len/binary, Rest/binary>>, Acc, N, ElemDecoder) ->
+decode_elements(<<-1:?int32, Rest/binary>>, Acc, N,
+                #array_decoder{null_term = Null} = ArDec) ->
+    decode_elements(Rest, [Null | Acc], N - 1, ArDec);
+decode_elements(<<Len:?int32, Value:Len/binary, Rest/binary>>, Acc, N,
+                #array_decoder{element_decoder = ElemDecoder} = ArDecoder) ->
     Value2 = decode(Value, ElemDecoder),
-    decode_elements(Rest, [Value2 | Acc], N - 1, ElemDecoder).
+    decode_elements(Rest, [Value2 | Acc], N - 1, ArDecoder).
 
 
 
@@ -199,7 +231,7 @@ decode_record(<<Size:?int32, Bin/binary>>, record, Codec) ->
 
 decode_record1(<<>>, 0, _Codec) -> [];
 decode_record1(<<_Type:?int32, -1:?int32, Rest/binary>>, Size, Codec) ->
-    [null | decode_record1(Rest, Size - 1, Codec)];
+    [null(Codec) | decode_record1(Rest, Size - 1, Codec)];
 decode_record1(<<Oid:?int32, Len:?int32, ValueBin:Len/binary, Rest/binary>>, Size, Codec) ->
     Value = decode(ValueBin, oid_to_decoder(Oid, binary, Codec)),
     [Value | decode_record1(Rest, Size - 1, Codec)].
@@ -213,44 +245,73 @@ decode_record1(<<Oid:?int32, Len:?int32, ValueBin:Len/binary, Rest/binary>>, Siz
 -spec encode(epgsql:type_name() | {array, epgsql:type_name()}, any(), codec()) -> iolist().
 encode(TypeName, Value, Codec) ->
     Type = type_to_type_info(TypeName, Codec),
-    encode_with_type(Type, Value).
+    encode_with_type(Type, Value, Codec).
 
-encode_with_type(Type, Value) ->
-    {Name, Mod, State} = epgsql_oid_db:type_to_codec_entry(Type),
+encode_with_type(Type, Value, Codec) ->
+    NameModState = epgsql_oid_db:type_to_codec_entry(Type),
     case epgsql_oid_db:type_to_oid_info(Type) of
         {_ArrayOid, _, true} ->
             %FIXME: check if this OID is the same as was returned by 'Describe'
             ElementOid = epgsql_oid_db:type_to_element_oid(Type),
-            encode_array(Value, ElementOid, {Mod, Name, State});
+            encode_array(Value, ElementOid,
+                         #array_encoder{
+                            element_encoder = NameModState,
+                            codec = Codec});
         {_Oid, _, false} ->
-            encode_value(Value, {Mod, Name, State})
+            encode_value(Value, NameModState)
     end.
 
-encode_value(Value, {Mod, Name, State}) ->
+encode_value(Value, {Name, Mod, State}) ->
     Payload = epgsql_codec:encode(Mod, Value, Name, State),
     [<<(iolist_size(Payload)):?int32>> | Payload].
 
 
 %% Number of dimensions determined at encode-time by introspection of data, so,
 %% we can't encode array of lists (eg. strings).
-encode_array(Array, Oid, ValueEncoder) ->
-    {Data, {NDims, Lengths}} = encode_array(Array, 0, [], ValueEncoder),
+encode_array(Array, Oid, ArrayEncoder) ->
+    {Data, {NDims, Lengths, HasNull}} = encode_array_dims(Array, ArrayEncoder),
     Lens = [<<N:?int32, 1:?int32>> || N <- lists:reverse(Lengths)],
-    Hdr  = <<NDims:?int32, 0:?int32, Oid:?int32>>,
+    HasNullInt = case HasNull of
+                     true -> 1;
+                     false -> 0
+                 end,
+    Hdr  = <<NDims:?int32, HasNullInt:?int32, Oid:?int32>>,
     Payload  = [Hdr, Lens, Data],
     [<<(iolist_size(Payload)):?int32>> | Payload].
 
-encode_array([], NDims, Lengths, _Codec) ->
-    {[], {NDims, Lengths}};
-encode_array([H | _] = Array, NDims, Lengths, ValueEncoder) when not is_list(H) ->
-    F = fun(E, Len) -> {encode_value(E, ValueEncoder), Len + 1} end,
-    {Data, Len} = lists:mapfoldl(F, 0, Array),
-    {Data, {NDims + 1, [Len | Lengths]}};
-encode_array(Array, NDims, Lengths, Codec) ->
-    Lengths2 = [length(Array) | Lengths],
-    F = fun(A2, {_NDims, _Lengths}) -> encode_array(A2, NDims, Lengths2, Codec) end,
-    {Data, {NDims2, Lengths3}} = lists:mapfoldl(F, {NDims, Lengths2}, Array),
-    {Data, {NDims2 + 1, Lengths3}}.
+encode_array_dims([], #array_encoder{n_dims = NDims,
+                                     lengths = Lengths,
+                                     has_null = HasNull}) ->
+    {[], {NDims, Lengths, HasNull}};
+encode_array_dims([H | _] = Array,
+                  #array_encoder{n_dims = NDims0,
+                                 lengths = Lengths0,
+                                 has_null = HasNull0,
+                                 codec = Codec,
+                                 element_encoder = ValueEncoder}) when not is_list(H) ->
+    F = fun(El, {Len, HasNull1}) ->
+                case is_null(El, Codec) of
+                    false ->
+                        {encode_value(El, ValueEncoder), {Len + 1, HasNull1}};
+                    true ->
+                        {<<-1:?int32>>, {Len + 1, true}}
+                end
+        end,
+    {Data, {Len, HasNull2}} = lists:mapfoldl(F, {0, HasNull0}, Array),
+    {Data, {NDims0 + 1, [Len | Lengths0], HasNull2}};
+encode_array_dims(Array, #array_encoder{lengths = Lengths0,
+                                        n_dims = NDims0,
+                                        has_null = HasNull0} = ArrayEncoder) ->
+    Lengths1 = [length(Array) | Lengths0],
+    F = fun(A2, {_NDims, _Lengths, HasNull1}) ->
+                encode_array_dims(A2, ArrayEncoder#array_encoder{
+                                   n_dims = NDims0,
+                                   has_null = HasNull1,
+                                   lengths = Lengths1})
+        end,
+    {Data, {NDims2, Lengths2, HasNull2}} =
+        lists:mapfoldl(F, {NDims0, Lengths1, HasNull0}, Array),
+    {Data, {NDims2 + 1, Lengths2, HasNull2}}.
 
 
 %% Supports

+ 14 - 12
src/epgsql_wire.erl

@@ -143,12 +143,12 @@ decode_data(Bin, {Decoders, _Columns, Codec}) ->
 
 decode_data(_, [], _) -> [];
 decode_data(<<-1:?int32, Rest/binary>>, [_Dec | Decs], Codec) ->
-    [null | decode_data(Rest, Decs, Codec)];
+    [epgsql_binary:null(Codec) | decode_data(Rest, Decs, Codec)];
 decode_data(<<Len:?int32, Value:Len/binary, Rest/binary>>, [Decoder | Decs], Codec) ->
     [epgsql_binary:decode(Value, Decoder)
      | decode_data(Rest, Decs, Codec)].
 
-%% @doc decode column information
+%% @doc decode RowDescription column information
 -spec decode_columns(non_neg_integer(), binary(), epgsql_binary:codec()) -> [epgsql:column()].
 decode_columns(0, _Bin, _Codec) -> [];
 decode_columns(Count, Bin, Codec) ->
@@ -177,7 +177,7 @@ decode_parameters(<<_Count:?int16, Bin/binary>>, Codec) ->
          TypeInfo -> TypeInfo
      end || <<Oid:?int32>> <= Bin].
 
-%% @doc decode command complete msg
+%% @doc decode CcommandComplete msg
 decode_complete(<<"SELECT", 0>>)        -> select;
 decode_complete(<<"SELECT", _/binary>>) -> select;
 decode_complete(<<"BEGIN", 0>>)         -> 'begin';
@@ -246,16 +246,18 @@ encode_parameters([P | T], Count, Formats, Values, Codec) ->
       Type :: epgsql:type_name()
             | {array, epgsql:type_name()}
             | {unknown_oid, epgsql_oid_db:oid()}.
-encode_parameter({T, undefined}, Codec) ->
-    encode_parameter({T, null}, Codec);
-encode_parameter({_, null}, _Codec) ->
-    {1, <<-1:?int32>>};
-encode_parameter({{unknown_oid, _Oid}, Value}, _Codec) ->
-    {0, encode_text(Value)};
 encode_parameter({Type, Value}, Codec) ->
-    {1, epgsql_binary:encode(Type, Value, Codec)};
-encode_parameter(Value, _Codec) ->
-    {0, encode_text(Value)}.
+    case epgsql_binary:is_null(Value, Codec) of
+        false ->
+            encode_parameter(Type, Value, Codec);
+        true ->
+            {1, <<-1:?int32>>}
+    end.
+
+encode_parameter({unknown_oid, _Oid}, Value, _Codec) ->
+    {0, encode_text(Value)};
+encode_parameter(Type, Value, Codec) ->
+    {1, epgsql_binary:encode(Type, Value, Codec)}.
 
 encode_text(B) when is_binary(B)  -> encode_bin(B);
 encode_text(A) when is_atom(A)    -> encode_bin(atom_to_binary(A, utf8));

+ 43 - 6
test/epgsql_SUITE.erl

@@ -67,7 +67,8 @@ groups() ->
             range_type,
             range8_type,
             date_time_range_type,
-            custom_types
+            custom_types,
+            custom_null
         ]},
         {generic, [parallel], [
             with_transaction
@@ -931,18 +932,24 @@ array_type(Config) ->
         {ok, _, [{[1, 2]}]} = Module:equery(C, "select ($1::int[])[1:2]", [[1, 2, 3]]),
         {ok, _, [{[{1, <<"one">>}, {2, <<"two">>}]}]} =
             Module:equery(C, "select Array(select (id, value) from test_table1)", []),
-        Select = fun(Type, A) ->
+        {ok, _, [{ [[1], [null], [3], [null]] }]} =
+            Module:equery(C, "select $1::int2[]", [ [[1], [null], [3], [undefined]] ]),
+        Select = fun(Type, AIn) ->
             Query = "select $1::" ++ atom_to_list(Type) ++ "[]",
-            {ok, _Cols, [{A2}]} = Module:equery(C, Query, [A]),
-            case lists:all(fun({V, V2}) -> compare(Type, V, V2) end, lists:zip(A, A2)) of
+            {ok, _Cols, [{AOut}]} = Module:equery(C, Query, [AIn]),
+            case lists:all(fun({VIn, VOut}) ->
+                                   compare(Type, VIn, VOut)
+                           end, lists:zip(AIn, AOut)) of
                 true  -> ok;
-                false -> ?assertMatch(A, A2)
+                false -> ?assertEqual(AIn, AOut)
             end
         end,
         Select(int2,   []),
         Select(int2,   [1, 2, 3, 4]),
         Select(int2,   [[1], [2], [3], [4]]),
         Select(int2,   [[[[[[1, 2]]]]]]),
+        Select(int2,   [1, null, 3, undefined]),
+        Select(int2,   [[1], [null], [3], [null]]),
         Select(bool,   [true]),
         Select(char,   [$a, $b, $c]),
         Select(int4,   [[1, 2]]),
@@ -983,7 +990,10 @@ record_type(Config) ->
         Select("select (1, '{2,3}'::int[])", {{1, [2, 3]}}),
 
         %% Array of records inside record
-        Select("select (0, ARRAY(select (id, value) from test_table1))", {{0,[{1,<<"one">>},{2,<<"two">>}]}})
+        Select("select (0, ARRAY(select (id, value) from test_table1))", {{0,[{1,<<"one">>},{2,<<"two">>}]}}),
+
+        %% Record with NULLs
+        Select("select (1, NULL::integer, 2)", {{1, null, 2}})
     end).
 
 custom_types(Config) ->
@@ -1000,6 +1010,33 @@ custom_types(Config) ->
         ?assertMatch({ok, _, [{bar}]}, Module:equery(C, "SELECT col FROM t_foo"))
     end).
 
+custom_null(Config) ->
+    Module = ?config(module, Config),
+    epgsql_ct:with_connection(Config, fun(C) ->
+        Test3 = fun(Type, In, Out) ->
+                        Q = ["SELECT $1::", Type],
+                        {ok, _, [{Res}]} = Module:equery(C, Q, [In]),
+                        ?assertEqual(Out, Res)
+                end,
+        Test = fun(Type, In) ->
+                       Test3(Type, In, In)
+               end,
+        Test("int2", nil),
+        Test3("int2", 'NULL', nil),
+        Test("text", nil),
+        Test3("text", 'NULL', nil),
+        Test3("text", null, <<"null">>),
+        Test("int2[]", [nil, 1, nil, 2]),
+        Test3("text[]", [null, <<"ok">>], [<<"null">>, <<"ok">>]),
+        Test3("int2[]", ['NULL', 1, nil, 2], [nil, 1, nil, 2]),
+        Test("int2[]", [[nil], [1], [nil], [2]]),
+        Test3("int2[]", [['NULL'], [1], [nil], [2]], [[nil], [1], [nil], [2]]),
+        ?assertMatch(
+           {ok, _, [{ {1, nil, {2, nil, 3}} }]},
+           Module:equery(C, "SELECT (1, NULL, (2, NULL, 3))", []))
+    end,
+    [{nulls, [nil, 'NULL']}]).
+
 text_format(Config) ->
     Module = ?config(module, Config),
     epgsql_ct:with_connection(Config, fun(C) ->