Browse Source

Add functions to build the PROXY protocol header

Also add tests of the type parse(build(Info)), including
for testing the TLVs and the padding/checksum verification
options.
Loïc Hoguin 6 years ago
parent
commit
1cc7de15b6
1 changed files with 321 additions and 14 deletions
  1. 321 14
      src/ranch_proxy_header.erl

+ 321 - 14
src/ranch_proxy_header.erl

@@ -15,6 +15,8 @@
 -module(ranch_proxy_header).
 
 -export([parse/1]).
+-export([header/1]).
+-export([header/2]).
 
 -type proxy_info() :: #{
 	%% Mandatory part.
@@ -45,6 +47,13 @@
 }.
 -export_type([proxy_info/0]).
 
+-type build_opts() :: #{
+	checksum => crc32,
+	padding => pos_integer() %% >= 3
+}.
+
+%% Parsing.
+
 -spec parse(Data) -> {ok, proxy_info(), Data} | {error, atom()} when Data::binary().
 parse(<<"\r\n\r\n\0\r\nQUIT\n", Rest/bits>>) ->
 	parse_v2(Rest);
@@ -123,7 +132,9 @@ parse_ip(<<Addr:12/binary, $\s, Rest/binary>>, ipv4) -> parse_ipv4(Addr, Rest);
 parse_ip(<<Addr:13/binary, $\s, Rest/binary>>, ipv4) -> parse_ipv4(Addr, Rest);
 parse_ip(<<Addr:14/binary, $\s, Rest/binary>>, ipv4) -> parse_ipv4(Addr, Rest);
 parse_ip(<<Addr:15/binary, $\s, Rest/binary>>, ipv4) -> parse_ipv4(Addr, Rest);
-parse_ip(<<Addr:39/binary, $\s, Rest/binary>>, ipv6) -> parse_ipv6(Addr, Rest).
+parse_ip(Data, ipv6) ->
+	[Addr, Rest] = binary:split(Data, <<$\s>>),
+	parse_ipv6(Addr, Rest).
 
 parse_ipv4(Addr0, Rest) ->
 	case inet:parse_ipv4strict_address(binary_to_list(Addr0)) of
@@ -145,7 +156,7 @@ parse_port(<<Port:5/binary, C, Rest/bits>>, C) -> parse_port(Port, Rest);
 
 parse_port(Port0, Rest) ->
 	try binary_to_integer(Port0) of
-		Port when Port >= 0, Port =< 65535 ->
+		Port when Port > 0, Port =< 65535 ->
 			{ok, Port, Rest};
 		_ ->
 			throw(parse_port_error)
@@ -271,7 +282,7 @@ parse_v2(<<2:4, 1:4, Family:4, Protocol:4, Len:16, Rest/bits>>)
 		when Family =< 3, Protocol =< 2 ->
 	case Rest of
 		<<Header:Len/binary, _/bits>> ->
-			parse_v2(Rest, Len, family(Family), protocol(Protocol),
+			parse_v2(Rest, Len, parse_family(Family), parse_protocol(Protocol),
 				<<Family:4, Protocol:4, Len:16, Header:Len/binary>>);
 		_ ->
 			{error, 'Missing data in the PROXY protocol binary header. (PP 2.2)'}
@@ -286,14 +297,14 @@ parse_v2(<<_:8, Family:4, _/bits>>) when Family > 3 ->
 parse_v2(<<_:12, Protocol:4, _/bits>>) when Protocol > 2 ->
 	{error, 'Invalid transport protocol in the PROXY protocol binary header. (PP 2.2)'}.
 
-family(0) -> undefined;
-family(1) -> ipv4;
-family(2) -> ipv6;
-family(3) -> unix.
+parse_family(0) -> undefined;
+parse_family(1) -> ipv4;
+parse_family(2) -> ipv6;
+parse_family(3) -> unix.
 
-protocol(0) -> undefined;
-protocol(1) -> stream;
-protocol(2) -> dgram.
+parse_protocol(0) -> undefined;
+parse_protocol(1) -> stream;
+parse_protocol(2) -> dgram.
 
 parse_v2(Data, Len, Family, Protocol, _)
 		when Family =:= undefined; Protocol =:= undefined ->
@@ -484,12 +495,12 @@ parse_tlv(<<16#2, TLVLen:16, Authority:TLVLen/binary, Rest/bits>>, Len, Info, He
 %% PP2_TYPE_CRC32C.
 parse_tlv(<<16#3, TLVLen:16, CRC32C:32, Rest/bits>>, Len0, Info, Header) when TLVLen =:= 4 ->
 	Len = Len0 - TLVLen - 3,
-	BeforeLen = byte_size(Header) - Len - 7, %% 3 Family/Protocol/Len, 4 CRC32C
+	BeforeLen = byte_size(Header) - Len - TLVLen,
 	<<Before:BeforeLen/binary, _:32, After:Len/binary>> = Header,
 	%% The initial CRC is erlang:crc32(<<"\r\n\r\n\0\r\nQUIT\n", 2:4, 1:4>>).
 	case erlang:crc32(1302506282, [Before, <<0:32>>, After]) of
 		CRC32C ->
-			parse_tlv(Rest, Len - TLVLen - 3, Info, Header);
+			parse_tlv(Rest, Len, Info, Header);
 		_ ->
 			{error, 'Failed CRC32C verification in PROXY protocol binary header. (PP 2.2)'}
 	end;
@@ -502,7 +513,7 @@ parse_tlv(<<16#20, TLVLen:16, Client, Verify:32, Rest0/bits>>, Len, Info, Header
 	case Rest0 of
 		<<Subs:SubsLen/binary, Rest/bits>> ->
 			SSL0 = #{
-				client => client(<<Client>>),
+				client => parse_client(<<Client>>),
 				verified => Verify =:= 0
 			},
 			case parse_ssl_tlv(Subs, SubsLen, SSL0) of
@@ -525,7 +536,7 @@ parse_tlv(<<TLVType, TLVLen:16, TLVValue:TLVLen/binary, Rest/bits>>, Len, Info,
 parse_tlv(_, _, _, _) ->
 	{error, 'Invalid TLV length in the PROXY protocol binary header. (PP 2.2)'}.
 
-client(<<_:5, ClientCertSess:1, ClientCertConn:1, ClientSSL:1>>) ->
+parse_client(<<_:5, ClientCertSess:1, ClientCertConn:1, ClientSSL:1>>) ->
 	Client0 = case ClientCertSess of
 		0 -> [];
 		1 -> [cert_sess]
@@ -559,3 +570,299 @@ ssl_subtype(16#23) -> cipher;
 ssl_subtype(16#24) -> sig_alg;
 ssl_subtype(16#25) -> key_alg;
 ssl_subtype(_) -> undefined.
+
+%% Building.
+
+-spec header(proxy_info()) -> iodata().
+header(ProxyInfo) ->
+	header(ProxyInfo, #{}).
+
+-spec header(proxy_info(), build_opts()) -> iodata().
+header(#{version := 2, command := local}, _) ->
+	<<"\r\n\r\n\0\r\nQUIT\n", 2:4, 0:28>>;
+header(#{version := 2, command := proxy,
+		transport_family := Family,
+		transport_protocol := Protocol}, _)
+		when Family =:= undefined; Protocol =:= undefined ->
+	<<"\r\n\r\n\0\r\nQUIT\n", 2:4, 1:4, 0:24>>;
+header(ProxyInfo=#{version := 2, command := proxy,
+		transport_family := Family,
+		transport_protocol := Protocol}, Opts) ->
+	Addresses = addresses(ProxyInfo),
+	TLVs = tlvs(ProxyInfo, Opts),
+	ExtraLen = case Opts of
+		#{checksum := crc32} -> 7;
+		_ -> 0
+	end,
+	Len = iolist_size(Addresses) + iolist_size(TLVs) + ExtraLen,
+	Header = [
+		<<"\r\n\r\n\0\r\nQUIT\n", 2:4, 1:4>>,
+		<<(family(Family)):4, (protocol(Protocol)):4>>,
+		<<Len:16>>,
+		Addresses,
+		TLVs
+	],
+	case Opts of
+		#{checksum := crc32} ->
+			CRC32C = erlang:crc32([Header, <<16#3, 4:16, 0:32>>]),
+			[Header, <<16#3, 4:16, CRC32C:32>>];
+		_ ->
+			Header
+	end;
+header(#{version := 1, command := proxy,
+		transport_family := undefined,
+		transport_protocol := undefined}, _) ->
+	<<"PROXY UNKNOWN\r\n">>;
+header(#{version := 1, command := proxy,
+		transport_family := Family0,
+		transport_protocol := stream,
+		src_address := SrcAddress, src_port := SrcPort,
+		dest_address := DestAddress, dest_port := DestPort}, _)
+		when SrcPort > 0, SrcPort =< 65535, DestPort > 0, DestPort =< 65535 ->
+	[
+		<<"PROXY ">>,
+		case Family0 of
+			ipv4 when tuple_size(SrcAddress) =:= 4, tuple_size(DestAddress) =:= 4 ->
+				[<<"TCP4 ">>, inet:ntoa(SrcAddress), $\s, inet:ntoa(DestAddress)];
+			ipv6 when tuple_size(SrcAddress) =:= 8, tuple_size(DestAddress) =:= 8 ->
+				[<<"TCP6 ">>, inet:ntoa(SrcAddress), $\s, inet:ntoa(DestAddress)]
+		end,
+		$\s,
+		integer_to_binary(SrcPort),
+		$\s,
+		integer_to_binary(DestPort),
+		$\r, $\n
+	].
+
+family(ipv4) -> 1;
+family(ipv6) -> 2;
+family(unix) -> 3.
+
+protocol(stream) -> 1;
+protocol(dgram) -> 2.
+
+addresses(#{transport_family := ipv4,
+		src_address := {S1, S2, S3, S4}, src_port := SrcPort,
+		dest_address := {D1, D2, D3, D4}, dest_port := DestPort})
+		when SrcPort > 0, SrcPort =< 65535, DestPort > 0, DestPort =< 65535 ->
+	<<S1, S2, S3, S4, D1, D2, D3, D4, SrcPort:16, DestPort:16>>;
+addresses(#{transport_family := ipv6,
+		src_address := {S1, S2, S3, S4, S5, S6, S7, S8}, src_port := SrcPort,
+		dest_address := {D1, D2, D3, D4, D5, D6, D7, D8}, dest_port := DestPort})
+		when SrcPort > 0, SrcPort =< 65535, DestPort > 0, DestPort =< 65535 ->
+	<<
+		S1:16, S2:16, S3:16, S4:16, S5:16, S6:16, S7:16, S8:16,
+		D1:16, D2:16, D3:16, D4:16, D5:16, D6:16, D7:16, D8:16,
+		SrcPort:16, DestPort:16
+	>>;
+addresses(#{transport_family := unix,
+		src_address := SrcAddress, dest_address := DestAddress})
+		when byte_size(SrcAddress) =< 108, byte_size(DestAddress) =< 108 ->
+	SrcPadding = 8 * (108 - byte_size(SrcAddress)),
+	DestPadding = 8 * (108 - byte_size(DestAddress)),
+	<<
+		SrcAddress/binary, 0:SrcPadding,
+		DestAddress/binary, 0:DestPadding
+	>>.
+
+tlvs(ProxyInfo, Opts) ->
+	[
+		binary_tlv(ProxyInfo, alpn, 16#1),
+		binary_tlv(ProxyInfo, authority, 16#2),
+		ssl_tlv(ProxyInfo),
+		binary_tlv(ProxyInfo, netns, 16#30),
+		raw_tlvs(ProxyInfo),
+		noop_tlv(Opts)
+	].
+
+binary_tlv(Info, Key, Type) ->
+	case Info of
+		#{Key := Bin} ->
+			Len = byte_size(Bin),
+			<<Type, Len:16, Bin/binary>>;
+		_ ->
+			<<>>
+	end.
+
+noop_tlv(#{padding := Len0}) when Len0 >= 3 ->
+	Len = Len0 - 3,
+	<<16#4, Len:16, 0:Len/unit:8>>;
+noop_tlv(_) ->
+	<<>>.
+
+ssl_tlv(#{ssl := Info=#{client := Client0, verified := Verify0}}) ->
+	Client = client(Client0, 0),
+	Verify = if
+		Verify0 -> 0;
+		not Verify0 -> 1
+	end,
+	TLVs = [
+		binary_tlv(Info, version, 16#21),
+		binary_tlv(Info, cn, 16#22),
+		binary_tlv(Info, cipher, 16#23),
+		binary_tlv(Info, sig_alg, 16#24),
+		binary_tlv(Info, key_alg, 16#25)
+	],
+	Len = iolist_size(TLVs) + 5,
+	[<<16#20, Len:16, Client, Verify:32>>, TLVs];
+ssl_tlv(_) ->
+	<<>>.
+
+client([], Client) -> Client;
+client([ssl|Tail], Client) -> client(Tail, Client bor 16#1);
+client([cert_conn|Tail], Client) -> client(Tail, Client bor 16#2);
+client([cert_sess|Tail], Client) -> client(Tail, Client bor 16#4).
+
+raw_tlvs(Info) ->
+	[begin
+		Len = byte_size(Bin),
+		<<Type, Len:16, Bin/binary>>
+	end || {Type, Bin} <- maps:get(raw_tlvs, Info, [])].
+
+-ifdef(TEST).
+v1_test() ->
+	Test1 = #{
+		version => 1,
+		command => proxy,
+		transport_family => undefined,
+		transport_protocol => undefined
+	},
+	{ok, Test1, <<>>} = parse(iolist_to_binary(header(Test1))),
+	Test2 = #{
+		version => 1,
+		command => proxy,
+		transport_family => ipv4,
+		transport_protocol => stream,
+		src_address => {127, 0, 0, 1},
+		src_port => 1234,
+		dest_address => {10, 11, 12, 13},
+		dest_port => 23456
+	},
+	{ok, Test2, <<>>} = parse(iolist_to_binary(header(Test2))),
+	Test3 = #{
+		version => 1,
+		command => proxy,
+		transport_family => ipv6,
+		transport_protocol => stream,
+		src_address => {1, 2, 3, 4, 5, 6, 7, 8},
+		src_port => 1234,
+		dest_address => {65535, 55555, 2222, 333, 1, 9999, 777, 8},
+		dest_port => 23456
+	},
+	{ok, Test3, <<>>} = parse(iolist_to_binary(header(Test3))),
+	ok.
+
+v2_test() ->
+	Test0 = #{
+		version => 2,
+		command => local
+	},
+	{ok, Test0, <<>>} = parse(iolist_to_binary(header(Test0))),
+	Test1 = #{
+		version => 2,
+		command => proxy,
+		transport_family => undefined,
+		transport_protocol => undefined
+	},
+	{ok, Test1, <<>>} = parse(iolist_to_binary(header(Test1))),
+	Test2 = #{
+		version => 2,
+		command => proxy,
+		transport_family => ipv4,
+		transport_protocol => stream,
+		src_address => {127, 0, 0, 1},
+		src_port => 1234,
+		dest_address => {10, 11, 12, 13},
+		dest_port => 23456
+	},
+	{ok, Test2, <<>>} = parse(iolist_to_binary(header(Test2))),
+	Test3 = #{
+		version => 2,
+		command => proxy,
+		transport_family => ipv6,
+		transport_protocol => stream,
+		src_address => {1, 2, 3, 4, 5, 6, 7, 8},
+		src_port => 1234,
+		dest_address => {65535, 55555, 2222, 333, 1, 9999, 777, 8},
+		dest_port => 23456
+	},
+	{ok, Test3, <<>>} = parse(iolist_to_binary(header(Test3))),
+	Test4 = #{
+		version => 2,
+		command => proxy,
+		transport_family => unix,
+		transport_protocol => dgram,
+		src_address => <<"/run/source.sock">>,
+		dest_address => <<"/run/destination.sock">>
+	},
+	{ok, Test4, <<>>} = parse(iolist_to_binary(header(Test4))),
+	ok.
+
+v2_tlvs_test() ->
+	Common = #{
+		version => 2,
+		command => proxy,
+		transport_family => ipv4,
+		transport_protocol => stream,
+		src_address => {127, 0, 0, 1},
+		src_port => 1234,
+		dest_address => {10, 11, 12, 13},
+		dest_port => 23456
+	},
+	Test1 = Common#{alpn => <<"h2">>},
+	{ok, Test1, <<>>} = parse(iolist_to_binary(header(Test1))),
+	Test2 = Common#{authority => <<"internal.example.org">>},
+	{ok, Test2, <<>>} = parse(iolist_to_binary(header(Test2))),
+	Test3 = Common#{netns => <<"/var/run/netns/example">>},
+	{ok, Test3, <<>>} = parse(iolist_to_binary(header(Test3))),
+	Test4 = Common#{ssl => #{
+		client => [ssl, cert_conn, cert_sess],
+		verified => true,
+		version => <<"TLSv1.3">>, %% Note that I'm not sure this example value is correct.
+		cipher => <<"ECDHE-RSA-AES128-GCM-SHA256">>,
+		sig_alg => <<"SHA256">>,
+		key_alg => <<"RSA2048">>,
+		cn => <<"example.com">>
+	}},
+	{ok, Test4, <<>>} = parse(iolist_to_binary(header(Test4))),
+	%% Note that the raw_tlvs order is not relevant and therefore
+	%% the parser does not reverse the list it builds.
+	Test5In = Common#{raw_tlvs => RawTLVs=[
+		%% The only custom TLV I am aware of is defined at:
+		%% https://docs.aws.amazon.com/elasticloadbalancing/latest/network/load-balancer-target-groups.html#proxy-protocol
+		{16#ea, <<16#1, "instance-id">>},
+		%% This TLV is entirely fictional.
+		{16#ff, <<1, 2, 3, 4, 5, 6, 7, 8, 9, 0>>}
+	]},
+	Test5Out = Test5In#{raw_tlvs => lists:reverse(RawTLVs)},
+	{ok, Test5Out, <<>>} = parse(iolist_to_binary(header(Test5In))),
+	ok.
+
+v2_checksum_test() ->
+	Test = #{
+		version => 2,
+		command => proxy,
+		transport_family => ipv4,
+		transport_protocol => stream,
+		src_address => {127, 0, 0, 1},
+		src_port => 1234,
+		dest_address => {10, 11, 12, 13},
+		dest_port => 23456
+	},
+	{ok, Test, <<>>} = parse(iolist_to_binary(header(Test, #{checksum => crc32}))),
+	ok.
+
+v2_padding_test() ->
+	Test = #{
+		version => 2,
+		command => proxy,
+		transport_family => ipv4,
+		transport_protocol => stream,
+		src_address => {127, 0, 0, 1},
+		src_port => 1234,
+		dest_address => {10, 11, 12, 13},
+		dest_port => 23456
+	},
+	{ok, Test, <<>>} = parse(iolist_to_binary(header(Test, #{padding => 123}))),
+	ok.
+-endif.