Fix bugs in policy, add tests

parent 5c8cc0c5
...@@ -318,10 +318,10 @@ If list is empty, no limits will be checked. ...@@ -318,10 +318,10 @@ If list is empty, no limits will be checked.
Following policies are supported: Following policies are supported:
* `{max_connections, KEYS, NUMBER}` - EXPERIMENTAL! if there are more than NUMBER connections with
KEYS to the proxy, new connections with those KEYS will be rejected
* `{in_table, KEY, TABLE_NAME}` - only allow connections if KEY is present in TABLE_NAME (whitelist) * `{in_table, KEY, TABLE_NAME}` - only allow connections if KEY is present in TABLE_NAME (whitelist)
* `{not_in_table, KEY, TABLE_NAME}` - only allow connections if KEY is *not* present in TABLE_NAME (blacklist) * `{not_in_table, KEY, TABLE_NAME}` - only allow connections if KEY is *not* present in TABLE_NAME (blacklist)
* `{max_connections, KEYS, NUMBER}` - EXPERIMENTAL! if there are more than NUMBER connections with
KEYS to the proxy, new connections with those KEYS will be rejected.
Where: Where:
......
...@@ -214,12 +214,18 @@ handle_info(Other, S) -> ...@@ -214,12 +214,18 @@ handle_info(Other, S) ->
{noreply, S}. {noreply, S}.
terminate(_Reason, #state{started_at = Started, listener = Listener, terminate(_Reason, #state{started_at = Started, listener = Listener,
addr = {Ip, _}, policy_state = MaybeSni} = S) -> addr = {Ip, _}, policy_state = PolicyState} = S) ->
try mtp_policy:dec( case PolicyState of
application:get_env(?APP, policy, []), {ok, TlsDomain} ->
Listener, Ip, MaybeSni) try mtp_policy:dec(
catch T:R -> application:get_env(?APP, policy, []),
?log(warning, "Failed to decrement policy: ~p:~p", [T, R]) Listener, Ip, TlsDomain)
catch T:R ->
?log(warning, "Failed to decrement policy: ~p:~p", [T, R])
end;
_ ->
%% Failed before policy was stored in state. Eg, because of "policy_error"
ok
end, end,
maybe_close_down(S), maybe_close_down(S),
mtp_metric:count_inc([?APP, in_connection_closed, total], 1, #{labels => [Listener]}), mtp_metric:count_inc([?APP, in_connection_closed, total], 1, #{labels => [Listener]}),
...@@ -306,25 +312,26 @@ parse_upstream_data(<<?TLS_START, _/binary>> = AllData, ...@@ -306,25 +312,26 @@ parse_upstream_data(<<?TLS_START, _/binary>> = AllData,
check_tls_policy(Listener, Ip, Meta), check_tls_policy(Listener, Ip, Meta),
Codec1 = mtp_codec:replace(tls, true, TlsCodec, Codec0), Codec1 = mtp_codec:replace(tls, true, TlsCodec, Codec0),
Codec = mtp_codec:push_back(tls, Tail, Codec1), Codec = mtp_codec:push_back(tls, Tail, Codec1),
ok = up_send_raw(Response, S), ok = up_send_raw(Response, S), %FIXME: if this send fail, we will get counter policy leak
{ok, S#state{codec = Codec, stage = init, {ok, S#state{codec = Codec, stage = init,
policy_state = maps:get(sni_domain, Meta, undefined)}}; policy_state = {ok, maps:get(sni_domain, Meta, undefined)}}};
parse_upstream_data(<<?TLS_START, _/binary>> = Data, #state{stage = init} = S) -> parse_upstream_data(<<?TLS_START, _/binary>> = Data, #state{stage = init} = S) ->
parse_upstream_data(Data, S#state{stage = tls_hello}); parse_upstream_data(Data, S#state{stage = tls_hello});
parse_upstream_data(<<Header:64/binary, Rest/binary>>, parse_upstream_data(<<Header:64/binary, Rest/binary>>,
#state{stage = init, secret = Secret, listener = Listener, codec = Codec0, #state{stage = init, secret = Secret, listener = Listener, codec = Codec0,
ad_tag = Tag, addr = {Ip, _} = Addr} = S) -> ad_tag = Tag, addr = {Ip, _} = Addr, policy_state = PState0} = S) ->
case mtp_obfuscated:from_header(Header, Secret) of case mtp_obfuscated:from_header(Header, Secret) of
{ok, DcId, PacketLayerMod, CryptoCodecSt} -> {ok, DcId, PacketLayerMod, CryptoCodecSt} ->
maybe_check_replay(Header), maybe_check_replay(Header),
ProtoToReport = {ProtoToReport, PState} =
case mtp_codec:info(tls, Codec0) of case mtp_codec:info(tls, Codec0) of
{true, _} when PacketLayerMod == mtp_secure -> {true, _} when PacketLayerMod == mtp_secure ->
mtp_secure_fake_tls; {mtp_secure_fake_tls, PState0};
{false, _} -> {false, _} ->
assert_protocol(PacketLayerMod), assert_protocol(PacketLayerMod),
check_policy(Listener, Ip, undefined), check_policy(Listener, Ip, undefined),
PacketLayerMod %FIXME: if any codebelow fail, we will get counter policy leak
{PacketLayerMod, {ok, undefined}}
end, end,
mtp_metric:count_inc([?APP, protocol_ok, total], mtp_metric:count_inc([?APP, protocol_ok, total],
1, #{labels => [Listener, ProtoToReport]}), 1, #{labels => [Listener, ProtoToReport]}),
...@@ -341,6 +348,7 @@ parse_upstream_data(<<Header:64/binary, Rest/binary>>, ...@@ -341,6 +348,7 @@ parse_upstream_data(<<Header:64/binary, Rest/binary>>,
S#state{down = Downstream, S#state{down = Downstream,
dc_id = {RealDcId, Pool}, dc_id = {RealDcId, Pool},
codec = Codec, codec = Codec,
policy_state = PState,
stage = tunnel}, stage = tunnel},
hibernate)); hibernate));
{error, Reason} when is_atom(Reason) -> {error, Reason} when is_atom(Reason) ->
......
...@@ -49,13 +49,12 @@ check(Rules, ListenerName, ClientIp, TlsDomain) -> ...@@ -49,13 +49,12 @@ check(Rules, ListenerName, ClientIp, TlsDomain) ->
end, Rules). end, Rules).
dec(Rules, ListenerName, ClientIp,TlsDomain) -> dec(Rules, ListenerName, ClientIp,TlsDomain) ->
%% FIXME: this is not idempotent if `check/4` returned `false`!
Vars = vars(ListenerName, ClientIp,TlsDomain), Vars = vars(ListenerName, ClientIp,TlsDomain),
lists:foreach( lists:foreach(
fun({max_connections, Keys, _Max}) -> fun({max_connections, Keys, _Max}) ->
try try
Key = [val(K, Vars) || K <- Keys], Key = [val(K, Vars) || K <- Keys],
mtp_policy_counter:increment(Key) mtp_policy_counter:decrement(Key)
catch throw:not_applicable -> catch throw:not_applicable ->
ok ok
end; end;
...@@ -77,6 +76,7 @@ check({max_connections, Keys, Max}, Vars) -> ...@@ -77,6 +76,7 @@ check({max_connections, Keys, Max}, Vars) ->
Key = [val(K, Vars) || K <- Keys], Key = [val(K, Vars) || K <- Keys],
case mtp_policy_counter:increment(Key) of case mtp_policy_counter:increment(Key) of
N when N > Max -> N when N > Max ->
mtp_policy_counter:decrement(Key),
false; false;
_ -> _ ->
true true
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
-export([start_link/0]). -export([start_link/0]).
-export([increment/1, -export([increment/1,
decrement/1, decrement/1,
get/1,
flush/0]). flush/0]).
%% gen_server callbacks %% gen_server callbacks
...@@ -35,7 +36,7 @@ increment(Key) -> ...@@ -35,7 +36,7 @@ increment(Key) ->
-spec decrement(key()) -> integer(). -spec decrement(key()) -> integer().
decrement(Key) -> decrement(Key) ->
try ets:update_counter(?TAB, Key, 1) of try ets:update_counter(?TAB, Key, -1) of
New when New =< 0 -> New when New =< 0 ->
ets:delete(?TAB, Key), ets:delete(?TAB, Key),
0; 0;
...@@ -45,6 +46,13 @@ decrement(Key) -> ...@@ -45,6 +46,13 @@ decrement(Key) ->
0 0
end. end.
-spec get(key()) -> non_neg_integer().
get(Key) ->
case ets:lookup(?TAB, Key) of
[] -> 0;
[{_, V}] -> V
end.
%% @doc Clean all counters %% @doc Clean all counters
flush() -> flush() ->
gen_server:call(?MODULE, flush). gen_server:call(?MODULE, flush).
......
...@@ -7,16 +7,18 @@ ...@@ -7,16 +7,18 @@
init_per_testcase/2, init_per_testcase/2,
end_per_testcase/2]). end_per_testcase/2]).
-export([echo_secure_case/1, -export([config_change_case/1,
downstream_size_backpressure_case/1,
downstream_qlen_backpressure_case/1,
echo_secure_case/1,
echo_abridged_many_packets_case/1, echo_abridged_many_packets_case/1,
echo_tls_case/1, echo_tls_case/1,
ipv6_connect_case/1,
packet_too_large_case/1, packet_too_large_case/1,
downstream_size_backpressure_case/1, policy_max_conns_case/1,
downstream_qlen_backpressure_case/1, policy_whitelist_case/1,
config_change_case/1,
replay_attack_case/1, replay_attack_case/1,
replay_attack_server_error_case/1, replay_attack_server_error_case/1
ipv6_connect_case/1
]). ]).
-export([set_env/2, -export([set_env/2,
...@@ -67,11 +69,11 @@ echo_secure_case(Cfg) when is_list(Cfg) -> ...@@ -67,11 +69,11 @@ echo_secure_case(Cfg) when is_list(Cfg) ->
Port = ?config(mtp_port, Cfg), Port = ?config(mtp_port, Cfg),
Secret = ?config(mtp_secret, Cfg), Secret = ?config(mtp_secret, Cfg),
Cli = mtp_test_client:connect(Host, Port, Secret, DcId, mtp_secure), Cli = mtp_test_client:connect(Host, Port, Secret, DcId, mtp_secure),
Data = crypto:strong_rand_bytes(64), Cli2 = ping(Cli),
Cli1 = mtp_test_client:send(Data, Cli), ?assertEqual(
{ok, Packet, Cli2} = mtp_test_client:recv_packet(Cli1, 1000), 1, mtp_test_metric:get_tags(
count, [?APP, protocol_ok, total], [?FUNCTION_NAME, mtp_secure])),
ok = mtp_test_client:close(Cli2), ok = mtp_test_client:close(Cli2),
?assertEqual(Data, Packet),
ok = mtp_test_metric:wait_for_value( ok = mtp_test_metric:wait_for_value(
count, [?APP, in_connection_closed, total], [?FUNCTION_NAME], 1, 5000), count, [?APP, in_connection_closed, total], [?FUNCTION_NAME], 1, 5000),
?assertEqual(1, mtp_test_metric:get_tags( ?assertEqual(1, mtp_test_metric:get_tags(
...@@ -129,11 +131,11 @@ echo_tls_case(Cfg) when is_list(Cfg) -> ...@@ -129,11 +131,11 @@ echo_tls_case(Cfg) when is_list(Cfg) ->
Port = ?config(mtp_port, Cfg), Port = ?config(mtp_port, Cfg),
Secret = ?config(mtp_secret, Cfg), Secret = ?config(mtp_secret, Cfg),
Cli0 = mtp_test_client:connect(Host, Port, Secret, DcId, {mtp_fake_tls, <<"example.com">>}), Cli0 = mtp_test_client:connect(Host, Port, Secret, DcId, {mtp_fake_tls, <<"example.com">>}),
Data = crypto:strong_rand_bytes(64), Cli1 = ping(Cli0),
Cli1 = mtp_test_client:send(Data, Cli0), ?assertEqual(
{ok, Packet, Cli2} = mtp_test_client:recv_packet(Cli1, 1000), 1, mtp_test_metric:get_tags(
ok = mtp_test_client:close(Cli2), count, [?APP, protocol_ok, total], [?FUNCTION_NAME, mtp_secure_fake_tls])),
?assertEqual(Data, Packet). ok = mtp_test_client:close(Cli1).
%% @doc test that client trying to send too big packets will be force-disconnected %% @doc test that client trying to send too big packets will be force-disconnected
...@@ -412,15 +414,88 @@ ipv6_connect_case(Cfg) when is_list(Cfg) -> ...@@ -412,15 +414,88 @@ ipv6_connect_case(Cfg) when is_list(Cfg) ->
?assertEqual(not_found, ConnCount()), ?assertEqual(not_found, ConnCount()),
?assertEqual(8, tuple_size(Host)), ?assertEqual(8, tuple_size(Host)),
Cli0 = mtp_test_client:connect(Host, Port, Secret, DcId, mtp_secure), Cli0 = mtp_test_client:connect(Host, Port, Secret, DcId, mtp_secure),
Data = crypto:strong_rand_bytes(64), Cli1 = ping(Cli0),
Cli1 = mtp_test_client:send(Data, Cli0), ok = mtp_test_client:close(Cli1),
{ok, Packet, Cli2} = mtp_test_client:recv_packet(Cli1, 1000),
ok = mtp_test_client:close(Cli2),
?assertEqual(Data, Packet),
?assertEqual(1, ConnCount()), ?assertEqual(1, ConnCount()),
ok = mtp_test_metric:wait_for_value( ok = mtp_test_metric:wait_for_value(
count, [?APP, in_connection_closed, total], [?FUNCTION_NAME], 1, 5000). count, [?APP, in_connection_closed, total], [?FUNCTION_NAME], 1, 5000).
%% @doc Test "max_connections" policy
policy_max_conns_case({pre, Cfg}) ->
Cfg1 = setup_single(?FUNCTION_NAME, 10000 + ?LINE, #{}, Cfg),
%% Allow max 2 connections from IP
set_env([{policy, [{max_connections, [port_name, client_ipv4], 2}]}], Cfg1);
policy_max_conns_case({post, Cfg}) ->
stop_single(Cfg),
reset_env(Cfg);
policy_max_conns_case(Cfg) when is_list(Cfg) ->
DcId = ?config(dc_id, Cfg),
Host = ?config(mtp_host, Cfg),
Port = ?config(mtp_port, Cfg),
Secret = ?config(mtp_secret, Cfg),
SureClose =
fun(Cli) ->
PreClosed =
mtp_test_metric:get_tags(
count, [?APP, in_connection_closed, total], [?FUNCTION_NAME]),
ok = mtp_test_client:close(Cli),
ok = mtp_test_metric:wait_for_value(
count, [?APP, in_connection_closed, total], [?FUNCTION_NAME], PreClosed + 1, 5000)
end,
Key = [?FUNCTION_NAME, mtp_policy:convert(client_ipv4, {127, 0, 0, 1})],
%% Open 2 connections, make sure 3rd one will be rejected
Cli10 = mtp_test_client:connect(Host, Port, Secret, DcId, mtp_secure),
Cli11 = ping(Cli10),
?assertEqual(1, mtp_policy_counter:get(Key)),
Cli20 = mtp_test_client:connect(Host, Port, Secret, DcId, mtp_secure),
_Cli21 = ping(Cli20),
?assertEqual(2, mtp_policy_counter:get(Key)),
Cli31 = mtp_test_client:connect(Host, Port, Secret, DcId, mtp_secure),
?assertError({badmatch, {error, closed}}, ping(Cli31)),
?assertEqual(
1, mtp_test_metric:get_tags(
count, [?APP, protocol_error, total], [?FUNCTION_NAME, policy_error])),
?assertEqual(2, mtp_policy_counter:get(Key)),
%% Close 1st connection and try to connect again. This should work.
SureClose(Cli11),
?assertEqual(1, mtp_policy_counter:get(Key)),
Cli40 = mtp_test_client:connect(Host, Port, Secret, DcId, mtp_secure),
_Cli41 = ping(Cli40),
?assertEqual(2, mtp_policy_counter:get(Key)),
ok.
%% @doc tests that connections to whitelistsed domains are allowed and not from the list disallowed
policy_whitelist_case({pre, Cfg}) ->
Cfg1 = setup_single(?FUNCTION_NAME, 10000 + ?LINE, #{}, Cfg),
%% Allow max 2 connections from IP
Domain = <<"allowed.example.com">>,
ok = mtp_policy_table:add(domain_whitelist, tls_domain, Domain),
set_env([{policy, [{in_table, tls_domain, domain_whitelist}]}],
[{domain, Domain} | Cfg1]);
policy_whitelist_case({post, Cfg}) ->
stop_single(Cfg),
reset_env(Cfg);
policy_whitelist_case(Cfg) when is_list(Cfg) ->
DcId = ?config(dc_id, Cfg),
Host = ?config(mtp_host, Cfg),
Port = ?config(mtp_port, Cfg),
Secret = ?config(mtp_secret, Cfg),
Domain = ?config(domain, Cfg),
Cli01 = mtp_test_client:connect(Host, Port, Secret, DcId,
{mtp_fake_tls, Domain}),
_Cli02 = ping(Cli01),
?assertError({badmatch, {error, closed}},
begin
Cli11 = mtp_test_client:connect(Host, Port, Secret, DcId,
{mtp_fake_tls, <<"not-", Domain/binary>>}),
ping(Cli11)
end),
?assertEqual(
1, mtp_test_metric:get_tags(
count, [?APP, protocol_error, total], [?FUNCTION_NAME, policy_error])),
ok.
%% Helpers %% Helpers
setup_single(Name, MtpPort, DcCfg0, Cfg) -> setup_single(Name, MtpPort, DcCfg0, Cfg) ->
...@@ -486,3 +561,10 @@ reset_env(Cfg) -> ...@@ -486,3 +561,10 @@ reset_env(Cfg) ->
{ok, Val} -> {ok, Val} ->
application:set_env(mtproto_proxy, K, Val) application:set_env(mtproto_proxy, K, Val)
end || {K, V} <- OldEnv]. end || {K, V} <- OldEnv].
ping(Cli0) ->
Data = crypto:strong_rand_bytes(64),
Cli1 = mtp_test_client:send(Data, Cli0),
{ok, Packet, Cli2} = mtp_test_client:recv_packet(Cli1, 1000),
?assertEqual(Data, Packet),
Cli2.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment