Make it possible to restrict connections by TLS SNI domains

parent 35c8f0da
...@@ -11,11 +11,12 @@ ...@@ -11,11 +11,12 @@
-behaviour(mtp_codec). -behaviour(mtp_codec).
-export([format_secret/2]).
-export([from_client_hello/2, -export([from_client_hello/2,
new/0, new/0,
try_decode_packet/2, try_decode_packet/2,
encode_packet/2]). encode_packet/2]).
-export_type([codec/0]). -export_type([codec/0, meta/0]).
-include_lib("hut/include/hut.hrl"). -include_lib("hut/include/hut.hrl").
...@@ -56,16 +57,39 @@ ...@@ -56,16 +57,39 @@
-define(TLS_CHANGE_CIPHER, ?TLS_REC_CHANGE_CIPHER, ?TLS_12_VERSION, 0, 1, 1). -define(TLS_CHANGE_CIPHER, ?TLS_REC_CHANGE_CIPHER, ?TLS_12_VERSION, 0, 1, 1).
-define(EXT_SNI, 0). -define(EXT_SNI, 0).
-define(EXT_SNI_HOST_NAME, 0).
-define(APP, mtproto_proxy). -define(APP, mtproto_proxy).
-opaque codec() :: #st{}. -opaque codec() :: #st{}.
-type meta() :: #{session_id := binary(),
timestamp := non_neg_integer(),
sni_domain => binary()}.
%% @doc format TLS secret
-spec format_secret(binary(), binary()) -> binary().
format_secret(Secret, Domain) when byte_size(Secret) == 16 ->
base64url(<<16#ee, Secret/binary, Domain/binary>>);
format_secret(HexSecret, Domain) when byte_size(HexSecret) == 32 ->
format_secret(mtp_handler:unhex(HexSecret), Domain).
base64url(Bin) ->
%% see https://hex.pm/packages/base64url
<< << (urlencode_digit(D)) >> || <<D>> <= base64:encode(Bin), D =/= $= >>.
urlencode_digit($/) -> $_;
urlencode_digit($+) -> $-;
urlencode_digit(D) -> D.
-spec from_client_hello(binary(), binary()) -> -spec from_client_hello(binary(), binary()) ->
{ok, iodata(), binary(), non_neg_integer(), codec()}. {ok, iodata(), meta(), codec()}.
from_client_hello(Data, Secret) -> from_client_hello(Data, Secret) ->
#client_hello{pseudorandom = ClientDigest, #client_hello{pseudorandom = ClientDigest,
session_id = SessionId} = CliHlo = parse_client_hello(Data), session_id = SessionId,
extensions = Extensions} = CliHlo = parse_client_hello(Data),
?log(debug, "TLS ClientHello=~p", [CliHlo]), ?log(debug, "TLS ClientHello=~p", [CliHlo]),
ServerDigest = make_server_digest(Data, Secret), ServerDigest = make_server_digest(Data, Secret),
<<Zeroes:(?DIGEST_LEN - 4)/binary, _/binary>> = XoredDigest = <<Zeroes:(?DIGEST_LEN - 4)/binary, _/binary>> = XoredDigest =
...@@ -84,7 +108,15 @@ from_client_hello(Data, Secret) -> ...@@ -84,7 +108,15 @@ from_client_hello(Data, Secret) ->
Response = [as_tls_frame(?TLS_REC_HANDSHAKE, SrvHello), Response = [as_tls_frame(?TLS_REC_HANDSHAKE, SrvHello),
CC, CC,
DD], DD],
{ok, Response, SessionId, Timestamp, new()}. Meta0 = #{session_id => SessionId,
timestamp => Timestamp},
Meta = case lists:keyfind(?EXT_SNI, 1, Extensions) of
{_, [{?EXT_SNI_HOST_NAME, Domain}]} ->
Meta0#{sni_domain => Domain};
_ ->
Meta0
end,
{ok, Response, Meta, new()}.
parse_client_hello(<<?TLS_REC_HANDSHAKE, ?TLS_10_VERSION, 512:16/unsigned-big, %Frame parse_client_hello(<<?TLS_REC_HANDSHAKE, ?TLS_10_VERSION, 512:16/unsigned-big, %Frame
......
...@@ -288,12 +288,13 @@ handle_upstream_data(Bin, #state{codec = Codec0} = S0) -> ...@@ -288,12 +288,13 @@ handle_upstream_data(Bin, #state{codec = Codec0} = S0) ->
parse_upstream_data(<<?TLS_START, _/binary>> = AllData, parse_upstream_data(<<?TLS_START, _/binary>> = AllData,
#state{stage = tls_hello, secret = Secret, codec = Codec0} = S) when #state{stage = tls_hello, secret = Secret, codec = Codec0,
addr = {Ip, _}, listener = Listener} = S) when
byte_size(AllData) >= (?TLS_CLIENT_HELLO_LEN + 5) -> byte_size(AllData) >= (?TLS_CLIENT_HELLO_LEN + 5) ->
assert_protocol(mtp_fake_tls), assert_protocol(mtp_fake_tls),
<<Data:(?TLS_CLIENT_HELLO_LEN + 5)/binary, Tail/binary>> = AllData, <<Data:(?TLS_CLIENT_HELLO_LEN + 5)/binary, Tail/binary>> = AllData,
{ok, Response, SessionId, Timestamp, TlsCodec} = mtp_fake_tls:from_client_hello(Data, Secret), {ok, Response, Meta, TlsCodec} = mtp_fake_tls:from_client_hello(Data, Secret),
maybe_check_tls_replay(SessionId, Timestamp), check_tls_access(Listener, Ip, Meta),
Codec1 = mtp_codec:replace(tls, true, TlsCodec, Codec0), Codec1 = mtp_codec:replace(tls, true, TlsCodec, Codec0),
Codec = mtp_codec:push_back(tls, Tail, Codec1), Codec = mtp_codec:push_back(tls, Tail, Codec1),
ok = up_send_raw(Response, S), ok = up_send_raw(Response, S),
...@@ -354,9 +355,15 @@ maybe_check_replay(Packet) -> ...@@ -354,9 +355,15 @@ maybe_check_replay(Packet) ->
ok ok
end. end.
maybe_check_tls_replay(_SessionId, _Timestamp) -> check_tls_access(_Listener, _Ip, #{sni_domain := Domain}) ->
%% TODO %% TODO validate timestamp!
ok. %% TODO some more scalable solution
AllowedDomains = application:get_env(?APP, tls_allowed_domains, []),
lists:member(Domain, AllowedDomains)
orelse error({protocol_error, tls_sni_domain_not_allowed, Domain});
check_tls_access(_, Ip, Meta) ->
error({protocol_error, tls_no_sni, {Ip, Meta}}).
up_send(Packet, #state{stage = tunnel, codec = UpCodec} = S) -> up_send(Packet, #state{stage = tunnel, codec = UpCodec} = S) ->
%% ?log(debug, ">Up: ~p", [Packet]), %% ?log(debug, ">Up: ~p", [Packet]),
......
...@@ -58,6 +58,10 @@ ...@@ -58,6 +58,10 @@
%% protocols will be immediately disallowed. %% protocols will be immediately disallowed.
{allowed_protocols, [mtp_fake_tls, mtp_secure, mtp_abridged, mtp_intermediate]}, {allowed_protocols, [mtp_fake_tls, mtp_secure, mtp_abridged, mtp_intermediate]},
%% Which domains to allow in TLS SNI
%% XXX: this option is experimental and will be removed later!
{tls_allowed_domains, [<<"en.wikipedia.org">>]},
{init_dc_connections, 2}, {init_dc_connections, 2},
{clients_per_dc_connection, 300}, {clients_per_dc_connection, 300},
......
...@@ -81,6 +81,7 @@ running_ports() -> ...@@ -81,6 +81,7 @@ running_ports() ->
end end
end, mtp_listeners()). end, mtp_listeners()).
%%==================================================================== %%====================================================================
%% Internal functions %% Internal functions
%%==================================================================== %%====================================================================
...@@ -171,8 +172,9 @@ build_urls(Host, Port, Secret, Protocols) -> ...@@ -171,8 +172,9 @@ build_urls(Host, Port, Secret, Protocols) ->
end, Protocols)), end, Protocols)),
lists:map( lists:map(
fun(mtp_fake_tls) -> fun(mtp_fake_tls) ->
RawSecret = mtp_handler:unhex(Secret), %% Print just for 1st domain as example
ProtoSecret = base64url(<<16#ee, RawSecret/binary, "en.wikipedia.org">>), {ok, [Domain | _]} = application:get_env(?APP, tls_allowed_domains),
ProtoSecret = mtp_fake_tls:format_secret(Secret, Domain),
MkUrl(ProtoSecret); MkUrl(ProtoSecret);
(mtp_secure) -> (mtp_secure) ->
ProtoSecret = ["dd", Secret], ProtoSecret = ["dd", Secret],
...@@ -181,14 +183,6 @@ build_urls(Host, Port, Secret, Protocols) -> ...@@ -181,14 +183,6 @@ build_urls(Host, Port, Secret, Protocols) ->
MkUrl(Secret) MkUrl(Secret)
end, UrlTypes). end, UrlTypes).
base64url(Bin) ->
%% see https://hex.pm/packages/base64url
<< << (urlencode_digit(D)) >> || <<D>> <= base64:encode(Bin), D =/= $= >>.
urlencode_digit($/) -> $_;
urlencode_digit($+) -> $-;
urlencode_digit(D) -> D.
-ifdef(TEST). -ifdef(TEST).
report(Fmt, Args) -> report(Fmt, Args) ->
?log(debug, Fmt, Args). ?log(debug, Fmt, Args).
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment