Fix bugs in policy, add tests

parent 5c8cc0c5
......@@ -318,10 +318,10 @@ If list is empty, no limits will be checked.
Following policies are supported:
* `{max_connections, KEYS, NUMBER}` - EXPERIMENTAL! if there are more than NUMBER connections with
KEYS to the proxy, new connections with those KEYS will be rejected
* `{in_table, KEY, TABLE_NAME}` - only allow connections if KEY is present in TABLE_NAME (whitelist)
* `{not_in_table, KEY, TABLE_NAME}` - only allow connections if KEY is *not* present in TABLE_NAME (blacklist)
* `{max_connections, KEYS, NUMBER}` - EXPERIMENTAL! if there are more than NUMBER connections with
KEYS to the proxy, new connections with those KEYS will be rejected.
Where:
......
......@@ -214,12 +214,18 @@ handle_info(Other, S) ->
{noreply, S}.
terminate(_Reason, #state{started_at = Started, listener = Listener,
addr = {Ip, _}, policy_state = MaybeSni} = S) ->
addr = {Ip, _}, policy_state = PolicyState} = S) ->
case PolicyState of
{ok, TlsDomain} ->
try mtp_policy:dec(
application:get_env(?APP, policy, []),
Listener, Ip, MaybeSni)
Listener, Ip, TlsDomain)
catch T:R ->
?log(warning, "Failed to decrement policy: ~p:~p", [T, R])
end;
_ ->
%% Failed before policy was stored in state. Eg, because of "policy_error"
ok
end,
maybe_close_down(S),
mtp_metric:count_inc([?APP, in_connection_closed, total], 1, #{labels => [Listener]}),
......@@ -306,25 +312,26 @@ parse_upstream_data(<<?TLS_START, _/binary>> = AllData,
check_tls_policy(Listener, Ip, Meta),
Codec1 = mtp_codec:replace(tls, true, TlsCodec, Codec0),
Codec = mtp_codec:push_back(tls, Tail, Codec1),
ok = up_send_raw(Response, S),
ok = up_send_raw(Response, S), %FIXME: if this send fail, we will get counter policy leak
{ok, S#state{codec = Codec, stage = init,
policy_state = maps:get(sni_domain, Meta, undefined)}};
policy_state = {ok, maps:get(sni_domain, Meta, undefined)}}};
parse_upstream_data(<<?TLS_START, _/binary>> = Data, #state{stage = init} = S) ->
parse_upstream_data(Data, S#state{stage = tls_hello});
parse_upstream_data(<<Header:64/binary, Rest/binary>>,
#state{stage = init, secret = Secret, listener = Listener, codec = Codec0,
ad_tag = Tag, addr = {Ip, _} = Addr} = S) ->
ad_tag = Tag, addr = {Ip, _} = Addr, policy_state = PState0} = S) ->
case mtp_obfuscated:from_header(Header, Secret) of
{ok, DcId, PacketLayerMod, CryptoCodecSt} ->
maybe_check_replay(Header),
ProtoToReport =
{ProtoToReport, PState} =
case mtp_codec:info(tls, Codec0) of
{true, _} when PacketLayerMod == mtp_secure ->
mtp_secure_fake_tls;
{mtp_secure_fake_tls, PState0};
{false, _} ->
assert_protocol(PacketLayerMod),
check_policy(Listener, Ip, undefined),
PacketLayerMod
%FIXME: if any codebelow fail, we will get counter policy leak
{PacketLayerMod, {ok, undefined}}
end,
mtp_metric:count_inc([?APP, protocol_ok, total],
1, #{labels => [Listener, ProtoToReport]}),
......@@ -341,6 +348,7 @@ parse_upstream_data(<<Header:64/binary, Rest/binary>>,
S#state{down = Downstream,
dc_id = {RealDcId, Pool},
codec = Codec,
policy_state = PState,
stage = tunnel},
hibernate));
{error, Reason} when is_atom(Reason) ->
......
......@@ -49,13 +49,12 @@ check(Rules, ListenerName, ClientIp, TlsDomain) ->
end, Rules).
dec(Rules, ListenerName, ClientIp,TlsDomain) ->
%% FIXME: this is not idempotent if `check/4` returned `false`!
Vars = vars(ListenerName, ClientIp,TlsDomain),
lists:foreach(
fun({max_connections, Keys, _Max}) ->
try
Key = [val(K, Vars) || K <- Keys],
mtp_policy_counter:increment(Key)
mtp_policy_counter:decrement(Key)
catch throw:not_applicable ->
ok
end;
......@@ -77,6 +76,7 @@ check({max_connections, Keys, Max}, Vars) ->
Key = [val(K, Vars) || K <- Keys],
case mtp_policy_counter:increment(Key) of
N when N > Max ->
mtp_policy_counter:decrement(Key),
false;
_ ->
true
......
......@@ -15,6 +15,7 @@
-export([start_link/0]).
-export([increment/1,
decrement/1,
get/1,
flush/0]).
%% gen_server callbacks
......@@ -35,7 +36,7 @@ increment(Key) ->
-spec decrement(key()) -> integer().
decrement(Key) ->
try ets:update_counter(?TAB, Key, 1) of
try ets:update_counter(?TAB, Key, -1) of
New when New =< 0 ->
ets:delete(?TAB, Key),
0;
......@@ -45,6 +46,13 @@ decrement(Key) ->
0
end.
-spec get(key()) -> non_neg_integer().
get(Key) ->
case ets:lookup(?TAB, Key) of
[] -> 0;
[{_, V}] -> V
end.
%% @doc Clean all counters
flush() ->
gen_server:call(?MODULE, flush).
......
......@@ -7,16 +7,18 @@
init_per_testcase/2,
end_per_testcase/2]).
-export([echo_secure_case/1,
-export([config_change_case/1,
downstream_size_backpressure_case/1,
downstream_qlen_backpressure_case/1,
echo_secure_case/1,
echo_abridged_many_packets_case/1,
echo_tls_case/1,
ipv6_connect_case/1,
packet_too_large_case/1,
downstream_size_backpressure_case/1,
downstream_qlen_backpressure_case/1,
config_change_case/1,
policy_max_conns_case/1,
policy_whitelist_case/1,
replay_attack_case/1,
replay_attack_server_error_case/1,
ipv6_connect_case/1
replay_attack_server_error_case/1
]).
-export([set_env/2,
......@@ -67,11 +69,11 @@ echo_secure_case(Cfg) when is_list(Cfg) ->
Port = ?config(mtp_port, Cfg),
Secret = ?config(mtp_secret, Cfg),
Cli = mtp_test_client:connect(Host, Port, Secret, DcId, mtp_secure),
Data = crypto:strong_rand_bytes(64),
Cli1 = mtp_test_client:send(Data, Cli),
{ok, Packet, Cli2} = mtp_test_client:recv_packet(Cli1, 1000),
Cli2 = ping(Cli),
?assertEqual(
1, mtp_test_metric:get_tags(
count, [?APP, protocol_ok, total], [?FUNCTION_NAME, mtp_secure])),
ok = mtp_test_client:close(Cli2),
?assertEqual(Data, Packet),
ok = mtp_test_metric:wait_for_value(
count, [?APP, in_connection_closed, total], [?FUNCTION_NAME], 1, 5000),
?assertEqual(1, mtp_test_metric:get_tags(
......@@ -129,11 +131,11 @@ echo_tls_case(Cfg) when is_list(Cfg) ->
Port = ?config(mtp_port, Cfg),
Secret = ?config(mtp_secret, Cfg),
Cli0 = mtp_test_client:connect(Host, Port, Secret, DcId, {mtp_fake_tls, <<"example.com">>}),
Data = crypto:strong_rand_bytes(64),
Cli1 = mtp_test_client:send(Data, Cli0),
{ok, Packet, Cli2} = mtp_test_client:recv_packet(Cli1, 1000),
ok = mtp_test_client:close(Cli2),
?assertEqual(Data, Packet).
Cli1 = ping(Cli0),
?assertEqual(
1, mtp_test_metric:get_tags(
count, [?APP, protocol_ok, total], [?FUNCTION_NAME, mtp_secure_fake_tls])),
ok = mtp_test_client:close(Cli1).
%% @doc test that client trying to send too big packets will be force-disconnected
......@@ -412,15 +414,88 @@ ipv6_connect_case(Cfg) when is_list(Cfg) ->
?assertEqual(not_found, ConnCount()),
?assertEqual(8, tuple_size(Host)),
Cli0 = mtp_test_client:connect(Host, Port, Secret, DcId, mtp_secure),
Data = crypto:strong_rand_bytes(64),
Cli1 = mtp_test_client:send(Data, Cli0),
{ok, Packet, Cli2} = mtp_test_client:recv_packet(Cli1, 1000),
ok = mtp_test_client:close(Cli2),
?assertEqual(Data, Packet),
Cli1 = ping(Cli0),
ok = mtp_test_client:close(Cli1),
?assertEqual(1, ConnCount()),
ok = mtp_test_metric:wait_for_value(
count, [?APP, in_connection_closed, total], [?FUNCTION_NAME], 1, 5000).
%% @doc Test "max_connections" policy
policy_max_conns_case({pre, Cfg}) ->
Cfg1 = setup_single(?FUNCTION_NAME, 10000 + ?LINE, #{}, Cfg),
%% Allow max 2 connections from IP
set_env([{policy, [{max_connections, [port_name, client_ipv4], 2}]}], Cfg1);
policy_max_conns_case({post, Cfg}) ->
stop_single(Cfg),
reset_env(Cfg);
policy_max_conns_case(Cfg) when is_list(Cfg) ->
DcId = ?config(dc_id, Cfg),
Host = ?config(mtp_host, Cfg),
Port = ?config(mtp_port, Cfg),
Secret = ?config(mtp_secret, Cfg),
SureClose =
fun(Cli) ->
PreClosed =
mtp_test_metric:get_tags(
count, [?APP, in_connection_closed, total], [?FUNCTION_NAME]),
ok = mtp_test_client:close(Cli),
ok = mtp_test_metric:wait_for_value(
count, [?APP, in_connection_closed, total], [?FUNCTION_NAME], PreClosed + 1, 5000)
end,
Key = [?FUNCTION_NAME, mtp_policy:convert(client_ipv4, {127, 0, 0, 1})],
%% Open 2 connections, make sure 3rd one will be rejected
Cli10 = mtp_test_client:connect(Host, Port, Secret, DcId, mtp_secure),
Cli11 = ping(Cli10),
?assertEqual(1, mtp_policy_counter:get(Key)),
Cli20 = mtp_test_client:connect(Host, Port, Secret, DcId, mtp_secure),
_Cli21 = ping(Cli20),
?assertEqual(2, mtp_policy_counter:get(Key)),
Cli31 = mtp_test_client:connect(Host, Port, Secret, DcId, mtp_secure),
?assertError({badmatch, {error, closed}}, ping(Cli31)),
?assertEqual(
1, mtp_test_metric:get_tags(
count, [?APP, protocol_error, total], [?FUNCTION_NAME, policy_error])),
?assertEqual(2, mtp_policy_counter:get(Key)),
%% Close 1st connection and try to connect again. This should work.
SureClose(Cli11),
?assertEqual(1, mtp_policy_counter:get(Key)),
Cli40 = mtp_test_client:connect(Host, Port, Secret, DcId, mtp_secure),
_Cli41 = ping(Cli40),
?assertEqual(2, mtp_policy_counter:get(Key)),
ok.
%% @doc tests that connections to whitelistsed domains are allowed and not from the list disallowed
policy_whitelist_case({pre, Cfg}) ->
Cfg1 = setup_single(?FUNCTION_NAME, 10000 + ?LINE, #{}, Cfg),
%% Allow max 2 connections from IP
Domain = <<"allowed.example.com">>,
ok = mtp_policy_table:add(domain_whitelist, tls_domain, Domain),
set_env([{policy, [{in_table, tls_domain, domain_whitelist}]}],
[{domain, Domain} | Cfg1]);
policy_whitelist_case({post, Cfg}) ->
stop_single(Cfg),
reset_env(Cfg);
policy_whitelist_case(Cfg) when is_list(Cfg) ->
DcId = ?config(dc_id, Cfg),
Host = ?config(mtp_host, Cfg),
Port = ?config(mtp_port, Cfg),
Secret = ?config(mtp_secret, Cfg),
Domain = ?config(domain, Cfg),
Cli01 = mtp_test_client:connect(Host, Port, Secret, DcId,
{mtp_fake_tls, Domain}),
_Cli02 = ping(Cli01),
?assertError({badmatch, {error, closed}},
begin
Cli11 = mtp_test_client:connect(Host, Port, Secret, DcId,
{mtp_fake_tls, <<"not-", Domain/binary>>}),
ping(Cli11)
end),
?assertEqual(
1, mtp_test_metric:get_tags(
count, [?APP, protocol_error, total], [?FUNCTION_NAME, policy_error])),
ok.
%% Helpers
setup_single(Name, MtpPort, DcCfg0, Cfg) ->
......@@ -486,3 +561,10 @@ reset_env(Cfg) ->
{ok, Val} ->
application:set_env(mtproto_proxy, K, Val)
end || {K, V} <- OldEnv].
ping(Cli0) ->
Data = crypto:strong_rand_bytes(64),
Cli1 = mtp_test_client:send(Data, Cli0),
{ok, Packet, Cli2} = mtp_test_client:recv_packet(Cli1, 1000),
?assertEqual(Data, Packet),
Cli2.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment