Add multiplexing draft

parent 8dd45117
......@@ -3,7 +3,8 @@
{parse_transform, lager_transform}]}.
{deps, [{ranch, "1.7.0"},
{lager, "3.6.3"}
{lager, "3.6.3"},
{psq, {git, "https://github.com/eryx67/psq.git", {branch, "master"}}}
]}.
{xref_checks,
......
{"1.1.0",
[{<<"goldrush">>,{pkg,<<"goldrush">>,<<"0.1.9">>},1},
{<<"lager">>,{pkg,<<"lager">>,<<"3.6.3">>},0},
{<<"psq">>,
{git,"https://github.com/eryx67/psq.git",
{ref,"acf8cb6620a9f9cb6123cc45aeb8767fa1a2ab08"}},
0},
{<<"ranch">>,{pkg,<<"ranch">>,<<"1.7.0">>},0}]}.
[
{pkg_hash,[
......
......@@ -15,21 +15,34 @@
%% API
-export([start_link/0]).
-export([get_downstream/1,
get_downstream_safe/1,
-export([get_downstream_safe/2,
get_downstream_pool/1,
get_netloc/1,
get_netloc_safe/1,
get_secret/0]).
-export([register_name/2,
unregister_name/1,
whereis_name/1,
send/2]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3]).
-type dc_id() :: integer().
-type netloc() :: {inet:ip4_address(), inet:port_number()}.
-define(TAB, ?MODULE).
-define(IPS_KEY(DcId), {id, DcId}).
-define(POOL_KEY(DcId), {pool, DcId}).
-define(IDS_KEY, dc_ids).
-define(SECRET_URL, "https://core.telegram.org/getProxySecret").
-define(CONFIG_URL, "https://core.telegram.org/getProxyConfig").
-define(APP, mtproto_proxy).
-record(state, {tab :: ets:tid(),
monitors = #{} :: #{pid() => {reference(), dc_id()}},
timer :: gen_timeout:tout()}).
-ifndef(OTP_RELEASE). % pre-OTP21
......@@ -45,28 +58,74 @@
start_link() ->
gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
-spec get_downstream(integer()) -> {ok, {inet:ip4_address(), inet:port_number()}}.
get_downstream_safe(DcId) ->
case get_downstream(DcId) of
{ok, Addr} -> Addr;
-spec get_downstream_safe(dc_id(), mtp_down_conn:upstream_opts()) ->
{dc_id(), pid(), mtp_down_conn:handle()}.
get_downstream_safe(DcId, Opts) ->
case get_downstream_pool(DcId) of
{ok, Pool} ->
Downstream = mtp_dc_pool:get(Pool, self(), Opts),
{DcId, Pool, Downstream};
not_found ->
[{_, L}] = ets:lookup(?TAB, id_range),
[{?IDS_KEY, L}] = ets:lookup(?TAB, ?IDS_KEY),
NewDcId = random_choice(L),
get_downstream_safe(NewDcId, Opts)
end.
get_downstream_pool(DcId) ->
Key = ?POOL_KEY(DcId),
case ets:lookup(?TAB, Key) of
[] -> not_found;
[{Key, PoolPid}] ->
{ok, PoolPid}
end.
-spec get_netloc_safe(dc_id()) -> {dc_id(), netloc()}.
get_netloc_safe(DcId) ->
case get_netloc(DcId) of
{ok, Addr} -> {DcId, Addr};
not_found ->
[{?IDS_KEY, L}] = ets:lookup(?TAB, ?IDS_KEY),
NewDcId = random_choice(L),
%% Get a random DC; it might return 0 and recurse again
get_downstream_safe(NewDcId)
get_netloc_safe(NewDcId)
end.
get_downstream(DcId) ->
case ets:lookup(?TAB, {id, DcId}) of
get_netloc(DcId) ->
Key = ?IPS_KEY(DcId),
case ets:lookup(?TAB, Key) of
[] ->
not_found;
[{_, Ip, Port}] ->
{ok, {Ip, Port}};
L ->
{_, Ip, Port} = random_choice(L),
{ok, {Ip, Port}}
[{Key, [{_, _} = IpPort]}] ->
{ok, IpPort};
[{Key, L}] ->
IpPort = random_choice(L),
{ok, IpPort}
end.
register_name(DcId, Pid) ->
case ets:insert_new(?TAB, {?POOL_KEY(DcId), Pid}) of
true ->
gen_server:cast(?MODULE, {reg, DcId, Pid}),
yes;
false -> no
end.
unregister_name(DcId) ->
%% Doing the monitor bookkeeping asynchronously is a bad idea..
Pid = whereis_name(DcId),
gen_server:cast(?MODULE, {unreg, DcId, Pid}),
ets:delete(?TAB, ?POOL_KEY(DcId)).
whereis_name(DcId) ->
case get_downstream_pool(DcId) of
not_found -> undefined;
{ok, PoolPid} -> PoolPid
end.
send(Name, Msg) ->
whereis_name(Name) ! Msg.
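%% Illustration (not part of this diff): register_name/2, unregister_name/1,
%% whereis_name/1 and send/2 together satisfy the `{via, Module, Name}' registry
%% contract used by gen_server, which is how mtp_dc_pool registers itself
%% under its DC id, e.g.:
%%
%%   gen_server:start_link({via, mtp_config, DcId}, mtp_dc_pool, DcId, []),
%%   Pool = mtp_config:whereis_name(DcId).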
-spec get_secret() -> binary().
get_secret() ->
[{_, Key}] = ets:lookup(?TAB, key),
......@@ -79,8 +138,8 @@ init([]) ->
Timer = gen_timeout:new(
#{timeout => {env, ?APP, conf_refresh_interval, 3600},
unit => second}),
Tab = ets:new(?TAB, [bag,
protected,
Tab = ets:new(?TAB, [set,
public,
named_table,
{read_concurrency, true}]),
State = #state{tab = Tab,
......@@ -92,8 +151,14 @@ init([]) ->
handle_call(_Request, _From, State) ->
Reply = ok,
{reply, Reply, State}.
handle_cast(_Msg, State) ->
{noreply, State}.
handle_cast({reg, DcId, Pid}, #state{monitors = Mons} = State) ->
Ref = erlang:monitor(process, Pid),
Mons1 = Mons#{Pid => {Ref, DcId}},
{noreply, State#state{monitors = Mons1}};
handle_cast({unreg, DcId, Pid}, #state{monitors = Mons} = State) ->
{{Ref, DcId}, Mons1} = maps:take(Pid, Mons),
erlang:demonitor(Ref, [flush]),
{noreply, State#state{monitors = Mons1}}.
handle_info(timeout, #state{timer = Timer} =State) ->
case gen_timeout:is_expired(Timer) of
true ->
......@@ -105,8 +170,10 @@ handle_info(timeout, #state{timer = Timer} =State) ->
false ->
{noreply, State#state{timer = gen_timeout:reset(Timer)}}
end;
handle_info(_Info, State) ->
{noreply, State}.
handle_info({'DOWN', MonRef, process, Pid, _Reason}, #state{monitors = Mons} = State) ->
{{MonRef, DcId}, Mons1} = maps:take(Pid, Mons),
ets:delete(?TAB, ?POOL_KEY(DcId)),
{noreply, State#state{monitors = Mons1}}.
terminate(_Reason, _State) ->
ok.
code_change(_OldVsn, State, _Extra) ->
......@@ -117,9 +184,9 @@ code_change(_OldVsn, State, _Extra) ->
%%%===================================================================
update(#state{tab = Tab}, force) ->
update_ip(),
update_key(Tab),
update_config(Tab),
update_ip();
update_config(Tab);
update(State, _) ->
try update(State, force)
catch ?WITH_STACKTRACE(Class, Reason, Stack)
......@@ -135,9 +202,8 @@ update_key(Tab) ->
update_config(Tab) ->
{ok, Body} = http_get(?CONFIG_URL),
Downstreams = parse_config(Body),
Range = get_range(Downstreams),
update_downstreams(Downstreams, Tab),
update_range(Range, Tab).
update_ids(Downstreams, Tab).
parse_config(Body) ->
Lines = string:lexemes(Body, "\n"),
......@@ -158,15 +224,30 @@ parse_downstream(Line) ->
IpAddr,
Port}.
get_range(Downstreams) ->
[Id || {Id, _, _} <- Downstreams].
update_downstreams(Downstreams, Tab) ->
[true = ets:insert(Tab, {{id, Id}, Ip, Port})
|| {Id, Ip, Port} <- Downstreams].
ByDc = lists:foldl(
fun({DcId, Ip, Port}, Acc) ->
Netlocs = maps:get(DcId, Acc, []),
Acc#{DcId => [{Ip, Port} | Netlocs]}
end, #{}, Downstreams),
[true = ets:insert(Tab, {?IPS_KEY(DcId), Netlocs})
|| {DcId, Netlocs} <- maps:to_list(ByDc)],
lists:foreach(
fun(DcId) ->
case get_downstream_pool(DcId) of
not_found ->
%% process will be registered asynchronously by
%% gen_server:start_link({via, ..
{ok, _Pid} = mtp_dc_pool_sup:start_pool(DcId);
{ok, _} ->
ok
end
end,
maps:keys(ByDc)).
update_range(Range, Tab) ->
true = ets:insert(Tab, {id_range, Range}).
update_ids(Downstreams, Tab) ->
Ids = lists:usort([DcId || {DcId, _, _} <- Downstreams]),
true = ets:insert(Tab, {?IDS_KEY, Ids}).
update_ip() ->
case application:get_env(?APP, ip_lookup_services) of
......@@ -201,6 +282,7 @@ random_choice(L) ->
Idx = rand:uniform(length(L)),
lists:nth(Idx, L).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
......
%%%-------------------------------------------------------------------
%%% @author Sergey <me@seriyps.ru>
%%% @copyright (C) 2018, Sergey
%%% @doc
%%% Process that manages a pool of connections to a Telegram datacenter
%%% and is responsible for load-balancing between them
%%% @end
%%% TODO: monitoring of DC connections! Make 100% sure they are killed when pool
%%% is killed. Maybe link?
%%% Created : 14 Oct 2018 by Sergey <me@seriyps.ru>
%%%-------------------------------------------------------------------
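%% Rough usage sketch (illustration only; the Opts map follows
%% mtp_down_conn:upstream_opts() and the caller is an upstream handler):
%%
%%   {ok, Pool} = mtp_config:get_downstream_pool(DcId),
%%   Down = mtp_dc_pool:get(Pool, self(), #{addr => ClientAddr}),
%%   ok = mtp_down_conn:send(Down, Packet),
%%   %% ... and when the upstream is done:
%%   mtp_dc_pool:return(Pool, self()).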
-module(mtp_dc_pool).
-behaviour(gen_server).
%% API
-export([start_link/1,
get/3,
return/2,
add_connection/1,
ack_connected/2]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3]).
-define(SERVER, ?MODULE).
-define(APP, mtproto_proxy).
-type upstream() :: mtp_handler:handle().
-type downstream() :: mtp_down_conn:handle().
-type ds_store() :: psq:psq().
-record(state,
{dc_id :: mtp_config:dc_id(),
upstreams = #{} :: #{upstream() => downstream()},
pending_downstreams = [] :: [pid()],
downstreams :: ds_store()
}).
%%%===================================================================
%%% API
%%%===================================================================
start_link(DcId) ->
gen_server:start_link({via, mtp_config, DcId}, ?MODULE, DcId, []).
get(Pool, Upstream, #{addr := _} = Opts) ->
gen_server:call(Pool, {get, Upstream, Opts}).
return(Pool, Upstream) ->
gen_server:cast(Pool, {return, Upstream}).
add_connection(Pool) ->
gen_server:call(Pool, add_connection, 10000).
ack_connected(Pool, Downstream) ->
gen_server:cast(Pool, {connected, Downstream}).
%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
init(DcId) ->
InitConnections = application:get_env(mtproto_proxy, init_dc_connections, 4),
PendingConnections = [do_connect(DcId) || _ <- lists:seq(1, InitConnections)],
Connections = recv_pending(PendingConnections),
Downstreams = ds_new(Connections),
{ok, #state{dc_id = DcId, downstreams = Downstreams}}.
handle_call({get, Upstream, Opts}, _From, State) ->
{Downstream, State1} = handle_get(Upstream, Opts, State),
{reply, Downstream, State1};
handle_call(add_connection, _From, State) ->
State1 = connect(State),
{reply, ok, State1}.
handle_cast({return, Upstream}, State) ->
{noreply, handle_return(Upstream, State)};
handle_cast({connected, Pid}, State) ->
{noreply, handle_connected(Pid, State)}.
handle_info({'DOWN', MonitorRef, process, Pid, _Reason}, State) ->
%% TODO: monitor downstream connections as well
{noreply, handle_down(MonitorRef, Pid, State)}.
terminate(_Reason, #state{downstreams = Ds}) ->
ds_foreach(
fun(Pid) ->
mtp_down_conn:shutdown(Pid)
end, Ds),
%% upstreams will be killed by connection itself
ok.
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
%%%===================================================================
%%% Internal functions
%%%===================================================================
%% Handle async connection ack
handle_connected(Pid, #state{pending_downstreams = Pending,
downstreams = Ds} = St) ->
Pending1 = lists:delete(Pid, Pending),
Downstreams1 = ds_add_downstream(Pid, Ds),
St#state{pending_downstreams = Pending1,
downstreams = Downstreams1}.
handle_get(Upstream, Opts, #state{downstreams = Ds,
upstreams = Us} = St) ->
{Downstream, N, Ds1} = ds_get(Ds),
MonRef = erlang:monitor(process, Upstream),
%% if N > X and len(pending) < Y -> connect()
Us1 = Us#{Upstream => {Downstream, MonRef}},
ok = mtp_down_conn:upstream_new(Downstream, Upstream, Opts),
{Downstream, maybe_spawn_connection(
N,
St#state{downstreams = Ds1,
upstreams = Us1})}.
handle_return(Upstream, #state{downstreams = Ds,
upstreams = Us} = St) ->
{{Downstream, MonRef}, Us1} = maps:take(Upstream, Us),
ok = mtp_down_conn:upstream_closed(Downstream, Upstream),
erlang:demonitor(MonRef, [flush]),
Ds1 = ds_return(Downstream, Ds),
St#state{downstreams = Ds1,
upstreams = Us1}.
handle_down(MonRef, MaybeUpstream, #state{downstreams = Ds,
upstreams = Us} = St) ->
case maps:take(MaybeUpstream, Us) of
{{Downstream, MonRef}, Us1} ->
ok = mtp_down_conn:upstream_closed(Downstream, MaybeUpstream),
Ds1 = ds_return(Downstream, Ds),
St#state{downstreams = Ds1,
upstreams = Us1};
error ->
lager:warning("Unexpected DOWN. ref=~p, pid=~p", [MonRef, MaybeUpstream]),
St
end.
maybe_spawn_connection(CurrentMin, #state{pending_downstreams = Pending} = St) ->
%% TODO: shrinking (by timer)
case application:get_env(?APP, clients_per_dc_connection) of
{ok, N} when CurrentMin > N,
Pending == [] ->
ToSpawn = 2,
lists:foldl(
fun(_, S) ->
connect(S)
end, St, lists:seq(1, ToSpawn));
_ ->
St
end.
%% Initiate new async connection
connect(#state{pending_downstreams = Pending,
dc_id = DcId} = St) ->
%% Should monitor connection PIDs as well!
Pid = do_connect(DcId),
St#state{pending_downstreams = [Pid | Pending]}.
%% Asynchronous connect
do_connect(DcId) ->
{ok, Pid} = mtp_down_conn_sup:start_conn(self(), DcId),
Pid.
%% Block until all async connections are acked
recv_pending(Pids) ->
[receive
{'$gen_cast', {connected, Pid}} -> Pid
after 10000 ->
exit({timeout, receive Smth -> Smth after 0 -> none end})
end || Pid <- Pids].
%% New downstream connection storage
-spec ds_new([downstream()]) -> ds_store().
ds_new(Connections) ->
Psq = pid_psq:new(),
%% TODO: add `from_list` function
lists:foldl(
fun(Conn, Psq1) ->
pid_psq:add(Conn, Psq1)
end, Psq, Connections).
-spec ds_foreach(fun( (downstream()) -> any() ), ds_store()) -> ok.
ds_foreach(Fun, St) ->
psq:fold(
fun(_, _N, Pid, _) ->
Fun(Pid)
end, ok, St).
%% Add new downstream to storage
-spec ds_add_downstream(downstream(), ds_store()) -> ds_store().
ds_add_downstream(Conn, St) ->
pid_psq:add(Conn, St).
%% Get least loaded downstream connection
-spec ds_get(ds_store()) -> {downstream(), pos_integer(), ds_store()}.
ds_get(St) ->
%% TODO: should return real number of connections
{ok, {{Conn, N}, St1}} = pid_psq:get_min_priority(St),
{Conn, N, St1}.
%% Return connection back to storage
-spec ds_return(downstream(), ds_store()) -> ds_store().
ds_return(Pid, St) ->
{ok, St1} = pid_psq:dec_priority(Pid, St),
St1.
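%% Rough illustration of the intended balancing (pid_psq is treated here as a
%% priority queue keyed by the per-connection upstream count; its exact
%% internals are assumed):
%%
%%   Ds0 = ds_new([C1, C2]),
%%   {Conn, _Load, Ds1} = ds_get(Ds0),   % hands out the least-loaded connection
%%   Ds2 = ds_return(Conn, Ds1).         % one upstream released from it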
%%%-------------------------------------------------------------------
%%% @author Sergey <me@seriyps.ru>
%%% @copyright (C) 2018, Sergey
%%% @doc
%%% Supervisor for mtp_dc_pool processes
%%% @end
%%% Created : 14 Oct 2018 by Sergey <me@seriyps.ru>
%%%-------------------------------------------------------------------
-module(mtp_dc_pool_sup).
-behaviour(supervisor).
-export([start_link/0,
start_pool/1]).
-export([init/1]).
-define(SERVER, ?MODULE).
start_link() ->
supervisor:start_link({local, ?SERVER}, ?MODULE, []).
-spec start_pool(mtp_config:dc_id()) -> {ok, pid()}.
start_pool(DcId) ->
%% Or maybe it should read IPs from mtp_config by itself?
supervisor:start_child(?SERVER, [DcId]).
init([]) ->
SupFlags = #{strategy => simple_one_for_one,
intensity => 50,
period => 5},
AChild = #{id => mtp_dc_pool,
start => {mtp_dc_pool, start_link, []},
restart => permanent,
shutdown => 10000,
type => worker},
{ok, {SupFlags, [AChild]}}.
%%%-------------------------------------------------------------------
%%% @author Sergey <me@seriyps.ru>
%%% @copyright (C) 2018, Sergey
%%% @doc
%%% Process holding a connection to the downstream (Telegram server) and doing multiplexing
%%% @end
%%% Created : 14 Oct 2018 by Sergey <me@seriyps.ru>
%%%-------------------------------------------------------------------
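%% Multiplexing sketch (illustrative): every upstream handler gets a unique
%% conn_id; data from that upstream is wrapped into RPC_PROXY_REQ frames
%% tagged with that conn_id, and RPC_PROXY_ANS / RPC_CLOSE_EXT / RPC_SIMPLE_ACK
%% frames coming back from Telegram are routed to the matching upstream via
%% the upstreams_rev map:
%%
%%   ok = mtp_down_conn:upstream_new(Conn, UpstreamPid, #{addr => ClientAddr}),
%%   ok = mtp_down_conn:send(Conn, Packet),       % called from the upstream process
%%   ok = mtp_down_conn:upstream_closed(Conn, UpstreamPid).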
-module(mtp_down_conn).
-behaviour(gen_server).
%% API
-export([start_link/2,
upstream_new/3,
upstream_closed/2,
shutdown/1,
send/2]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3]).
-export_type([handle/0, upstream_opts/0]).
-define(SERVER, ?MODULE).
-define(APP, mtproto_proxy).
-define(CONN_TIMEOUT, 10000).
-define(SEND_TIMEOUT, 15000).
-define(MAX_SOCK_BUF_SIZE, 1024 * 300). % Decrease if CPU is cheaper than RAM
-type handle() :: pid().
-type upstream_opts() :: #{addr := mtp_config:netloc(), % IP/Port of TG client
ad_tag => binary()}.
-type upstream() :: {
_ConnId :: mtp_rpc:conn_id(),
_Addr :: binary(),
_AdTag :: binary() | undefined
}.
-type stage() :: init | handshake_1 | handshake_2 | tunnel.
-record(state, {stage = init :: stage(),
stage_state = [] :: any(),
sock :: gen_tcp:socket() | undefined,
addr_bin :: binary(), % my external ip:port
codec :: mtp_layer:layer() | undefined,
upstreams = #{} :: #{mtp_handler:handle() => upstream()},
upstreams_rev = #{} :: #{mtp_rpc:conn_id() => mtp_handler:handle()},
pool :: pid(),
dc_id :: mtp_config:dc_id(),
netloc :: mtp_config:netloc() % telegram server ip:port
}).
start_link(Pool, DcId) ->
gen_server:start_link(?MODULE, [Pool, DcId], []).
%% To be called by mtp_dc_pool
upstream_new(Conn, Upstream, #{addr := _} = Opts) ->
gen_server:cast(Conn, {upstream_new, Upstream, Opts}).
%% To be called by mtp_dc_pool
upstream_closed(Conn, Upstream) ->
gen_server:cast(Conn, {upstream_closed, Upstream}).
%% To be called by mtp_dc_pool
shutdown(Conn) ->
gen_server:cast(Conn, shutdown).
%% To be called by upstream
-spec send(handle(), iodata()) -> ok.
send(Conn, Data) ->
gen_server:call(Conn, {send, Data}, ?SEND_TIMEOUT * 2).
init([Pool, DcId]) ->
self() ! do_connect,
{ok, #state{pool = Pool,
dc_id = DcId}}.
handle_call({send, Data}, {Upstream, _}, State) ->
{Res, State1} = handle_send(Data, Upstream, State),
{reply, Res, State1}.
handle_cast({upstream_new, Upstream, Opts}, State) ->
{noreply, handle_upstream_new(Upstream, Opts, State)};
handle_cast({upstream_closed, Upstream}, State) ->
{ok, St} = handle_upstream_closed(Upstream, State),
{noreply, St};
handle_cast(shutdown, State) ->
{stop, shutdown, State}.
handle_info({tcp, Sock, Data}, #state{sock = Sock} = S) ->
case handle_downstream_data(Data, S) of
{ok, S1} ->
ok = inet:setopts(Sock, [{active, once}]),
{noreply, S1};
{error, Reason} ->
lager:error("Error sending tunnelled data to in socket: ~p", [Reason]),
{stop, normal, S}
end;
handle_info({tcp_closed, Sock}, #state{sock = Sock} = State) ->
{stop, normal, State};
handle_info({tcp_error, Sock, Reason}, #state{sock = Sock} = State) ->
{stop, Reason, State};
handle_info(do_connect, #state{dc_id = DcId} = State) ->
try
{ok, St1} = connect(DcId, State),
{noreply, St1}
catch T:R ->
lager:error("Down connect error: ~s",
[lager:pr_stacktrace(erlang:get_stacktrace(), {T, R})]),
erlang:send_after(300, self(), do_connect),
{noreply, State}
end.
terminate(_Reason, #state{upstreams = Ups}) ->
%% Should this be done here or in dc_pool? Maybe only when the reason is 'normal'?
Self = self(),
lists:foreach(
fun(Upstream) ->
ok = mtp_handler:send(Upstream, {close_ext, Self})
end, maps:keys(Ups)),
ok.
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
%% Send packet from upstream to downstream
handle_send(Data, Upstream, #state{upstreams = Ups,
addr_bin = ProxyAddr} = St) ->
UpstreamData = maps:get(Upstream, Ups),
Packet = mtp_rpc:encode_packet({data, Data}, {UpstreamData, ProxyAddr}),
down_send(Packet, St).
%% New upstream connected
handle_upstream_new(Upstream, Opts, #state{upstreams = Ups,
upstreams_rev = UpsRev} = St) ->
ConnId = erlang:unique_integer(),
{Ip, Port} = maps:get(addr, Opts),
AdTag = maps:get(ad_tag, Opts, undefined),
Ups1 = Ups#{Upstream => {ConnId, iolist_to_binary(mtp_rpc:encode_ip_port(Ip, Port)), AdTag}},
UpsRev1 = UpsRev#{ConnId => Upstream},
St#state{upstreams = Ups1,
upstreams_rev = UpsRev1}.
%% Upstream process has exited (or is about to exit)
handle_upstream_closed(Upstream, #state{upstreams = Ups,
upstreams_rev = UpsRev} = St) ->
case maps:take(Upstream, Ups) of
{{ConnId, _, _}, Ups1} ->
UpsRev1 = maps:remove(ConnId, UpsRev),
St1 = St#state{upstreams = Ups1,
upstreams_rev = UpsRev1},
Packet = mtp_rpc:encode_packet(remote_closed, ConnId),
down_send(Packet, St1);
error ->
lager:warning("Unknown upstream ~p", [Upstream]),
{ok, St}
end.
handle_downstream_data(Bin, #state{stage = tunnel,
codec = DownCodec} = S) ->
{ok, S3, DownCodec1} =
mtp_layer:fold_packets(
fun(Decoded, S1) ->
mtp_metric:histogram_observe(
[?APP, tg_packet_size, bytes],
byte_size(Decoded),
#{labels => [downstream_to_upstream]}),
handle_rpc(mtp_rpc:decode_packet(Decoded), S1)
end, S, Bin, DownCodec),
{ok, S3#state{codec = DownCodec1}};
handle_downstream_data(Bin, #state{stage = handshake_1,
codec = DownCodec} = S) ->
case mtp_layer:try_decode_packet(Bin, DownCodec) of
{ok, Packet, DownCodec1} ->
down_handshake2(Packet, S#state{codec = DownCodec1});
{incomplete, DownCodec1} ->
{ok, S#state{codec = DownCodec1}}
end;
handle_downstream_data(Bin, #state{stage = handshake_2,
codec = DownCodec} = S) ->
case mtp_layer:try_decode_packet(Bin, DownCodec) of
{ok, Packet, DownCodec1} ->
%% TODO: There might be something in downstream buffers after stage3,
%% would be nice to run foldl
down_handshake3(Packet, S#state{codec = DownCodec1});
{incomplete, DownCodec1} ->
{ok, S#state{codec = DownCodec1}}
end.
-spec handle_rpc(mtp_rpc:packet(), #state{}) -> #state{}.
handle_rpc({proxy_ans, ConnId, Data}, S) ->
up_send({proxy_ans, self(), Data}, ConnId, S);
handle_rpc({close_ext, ConnId}, S) ->
up_send({close_ext, self()}, ConnId, S);
handle_rpc({simple_ack, ConnId, Confirm}, S) ->
up_send({simple_ack, self(), Confirm}, ConnId, S).
-spec down_send(iodata(), #state{}) -> {ok, #state{}}.
down_send(Packet, #state{sock = Sock, codec = Codec} = St) ->
%% lager:debug("Up>Down: ~w", [Packet]),
{Encoded, Codec1} = mtp_layer:encode_packet(Packet, Codec),
mtp_metric:rt(
[?APP, downstream_send_duration, seconds],
fun() ->
ok = gen_tcp:send(Sock, Encoded)
end),
{ok, St#state{codec = Codec1}}.
up_send(Packet, ConnId, #state{upstreams_rev = UpsRev} = St) ->
%% lager:debug("Down>Up: ~w", [Packet]),
Upstream = maps:get(ConnId, UpsRev),
ok = mtp_handler:send(Upstream, Packet),
St.
connect(DcId, S) ->
{ok, {Host, Port}} = mtp_config:get_netloc(DcId),
{ok, Sock} = tcp_connect(Host, Port),
mtp_metric:count_inc([?APP, out_connect_ok, total], 1,
#{labels => [DcId]}),
AddrStr = inet:ntoa(Host),
lager:info("~s:~p: TCP connected", [AddrStr, Port]),
down_handshake1(S#state{sock = Sock,
netloc = {Host, Port}}).
tcp_connect(Host, Port) ->
SockOpts = [{active, once},
{packet, raw},
binary,
{send_timeout, ?SEND_TIMEOUT},
%% {nodelay, true},
{keepalive, true}],
case mtp_metric:rt([?APP, downstream_connect_duration, seconds],
fun() ->
gen_tcp:connect(Host, Port, SockOpts, ?CONN_TIMEOUT)
end) of
{ok, Sock} ->
ok = inet:setopts(Sock, [%% {recbuf, ?MAX_SOCK_BUF_SIZE},
%% {sndbuf, ?MAX_SOCK_BUF_SIZE},
{buffer, ?MAX_SOCK_BUF_SIZE}]),
{ok, Sock};
{error, _} = Err ->
Err
end.
-define(RPC_NONCE, <<170,135,203,122>>).
-define(RPC_HANDSHAKE, <<245,238,130,118>>).
-define(RPC_FLAGS, <<0, 0, 0, 0>>).
down_handshake1(S) ->
RpcNonce = ?RPC_NONCE,
<<KeySelector:4/binary, _/binary>> = Key = mtp_config:get_secret(),
CryptoTs = os:system_time(seconds),
Nonce = crypto:strong_rand_bytes(16),
Msg = <<RpcNonce/binary,
KeySelector/binary,
1:32/little, %AES
CryptoTs:32/little,
Nonce/binary>>,
Full = mtp_full:new(-2, -2),
S1 = S#state{codec = mtp_layer:new(mtp_full, Full),
stage = handshake_1,
stage_state = {KeySelector, Nonce, CryptoTs, Key}},
down_send(Msg, S1).
down_handshake2(<<Type:4/binary, KeySelector:4/binary, Schema:32/little, _CryptoTs:4/binary,
SrvNonce:16/binary>>, #state{stage_state = {MyKeySelector, CliNonce, MyTs, Key},
sock = Sock,
codec = DownCodec} = S) ->
(Type == ?RPC_NONCE) orelse error({wrong_rpc_type, Type}),
(Schema == 1) orelse error({wrong_schema, Schema}),
(KeySelector == MyKeySelector) orelse error({wrong_key_selector, KeySelector}),
{ok, {DownIp, DownPort}} = inet:peername(Sock),
{MyIp, MyPort} = get_external_ip(Sock),
DownIpBin = mtp_obfuscated:bin_rev(mtp_rpc:inet_pton(DownIp)),
MyIpBin = mtp_obfuscated:bin_rev(mtp_rpc:inet_pton(MyIp)),
Args = #{srv_n => SrvNonce, clt_n => CliNonce, clt_ts => MyTs,
srv_ip => DownIpBin, srv_port => DownPort,
clt_ip => MyIpBin, clt_port => MyPort, secret => Key},
{EncKey, EncIv} = get_middle_key(Args#{purpose => <<"CLIENT">>}),
{DecKey, DecIv} = get_middle_key(Args#{purpose => <<"SERVER">>}),
CryptoCodec = mtp_layer:new(mtp_aes_cbc, mtp_aes_cbc:new(EncKey, EncIv, DecKey, DecIv, 16)),
DownCodec1 = mtp_layer:new(mtp_wrap, mtp_wrap:new(DownCodec, CryptoCodec)),
SenderPID = PeerPID = <<"IPIPPRPDTIME">>,
Handshake = [?RPC_HANDSHAKE,
?RPC_FLAGS,
SenderPID,
PeerPID],
down_send(Handshake,
S#state{codec = DownCodec1,
stage = handshake_2,
addr_bin = iolist_to_binary(mtp_rpc:encode_ip_port(MyIp, MyPort)),
stage_state = SenderPID}).
get_middle_key(#{srv_n := Nonce, clt_n := MyNonce, clt_ts := MyTs, srv_ip := SrvIpBinBig, srv_port := SrvPort,
clt_ip := CltIpBinBig, clt_port := CltPort, secret := Secret, purpose := Purpose} = _Args) ->
Msg =
<<Nonce/binary,
MyNonce/binary,
MyTs:32/little,
SrvIpBinBig/binary,
CltPort:16/little,
Purpose/binary,
CltIpBinBig/binary,
SrvPort:16/little,
Secret/binary,
Nonce/binary,
%% IPv6
MyNonce/binary
>>,
<<_, ForMd51/binary>> = Msg,
<<_, _, ForMd52/binary>> = Msg,
<<Key1:12/binary, _/binary>> = crypto:hash(md5, ForMd51),
ShaSum = crypto:hash(sha, Msg),
Key = <<Key1/binary, ShaSum/binary>>,
IV = crypto:hash(md5, ForMd52),
{Key, IV}.
down_handshake3(<<Type:4/binary, _Flags:4/binary, _SenderPid:12/binary, PeerPid:12/binary>>,
#state{stage_state = PrevSenderPid, pool = Pool,
netloc = {Addr, Port}} = S) ->
(Type == ?RPC_HANDSHAKE) orelse error({wrong_rpc_type, Type}),
(PeerPid == PrevSenderPid) orelse error({wrong_sender_pid, PeerPid}),
ok = mtp_dc_pool:ack_connected(Pool, self()),
lager:info("~s:~w: handshake complete", [inet:ntoa(Addr), Port]),
{ok, S#state{stage = tunnel,
stage_state = undefined}}.
%% Internal
get_external_ip(Sock) ->
{ok, {MyIp, MyPort}} = inet:sockname(Sock),
case application:get_env(?APP, external_ip) of
{ok, IpStr} ->
{ok, IP} = inet:parse_ipv4strict_address(IpStr),
{IP, MyPort};
undefined ->
{MyIp, MyPort}
end.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-define(PROXY_SECRET,
<<196,249,250,202,150,120,230,187,72,173,108,126,44,229,192,210,68,48,100,
93,85,74,221,235,85,65,158,3,77,166,39,33,208,70,234,171,110,82,171,20,
169,90,68,62,207,179,70,62,121,160,90,102,97,42,223,156,174,218,139,233,
168,13,166,152,111,176,166,255,56,122,248,77,136,239,58,100,19,113,62,92,
51,119,246,225,163,212,125,153,245,224,197,110,236,232,240,92,84,196,144,
176,121,227,27,239,130,255,14,232,242,176,163,39,86,210,73,197,242,18,105,
129,108,183,6,27,38,93,178,18>>).
middle_key_test() ->
Args = #{srv_port => 80,
srv_ip => mtp_obfuscated:bin_rev(mtp_rpc:inet_pton({149, 154, 162, 38})),
srv_n => <<247,40,210,56,65,12,101,170,216,155,14,253,250,238,219,226>>,
clt_n => <<24,49,53,111,198,10,235,180,230,112,92,78,1,201,106,105>>,
clt_ip => mtp_obfuscated:bin_rev(mtp_rpc:inet_pton({80, 211, 29, 34})),
clt_ts => 1528396015,
clt_port => 54208,
purpose => <<"CLIENT">>,
secret => ?PROXY_SECRET
},
Key = <<165,158,127,49,41,232,187,69,38,29,163,226,183,146,28,67,225,224,134,191,207,152,255,166,152,66,169,196,54,135,50,188>>,
IV = <<33,110,125,221,183,121,160,116,130,180,156,249,52,111,37,178>>,
?assertEqual(
{Key, IV},
get_middle_key(Args)).
-endif.
%%%-------------------------------------------------------------------
%%% @author Sergey <me@seriyps.ru>
%%% @copyright (C) 2018, Sergey
%%% @doc
%%% Supervisor for mtp_down_conn processes
%%% @end
%%% Created : 14 Oct 2018 by Sergey <me@seriyps.ru>
%%%-------------------------------------------------------------------
-module(mtp_down_conn_sup).
-behaviour(supervisor).
-export([start_link/0,
start_conn/2]).
-export([init/1]).
-define(SERVER, ?MODULE).
start_link() ->
supervisor:start_link({local, ?SERVER}, ?MODULE, []).
-spec start_conn(pid(), mtp_config:dc_id()) -> {ok, pid()}.
start_conn(Pool, DcId) ->
supervisor:start_child(?SERVER, [Pool, DcId]).
init([]) ->
SupFlags = #{strategy => simple_one_for_one,
intensity => 50,
period => 5},
AChild = #{id => mtp_down_conn,
start => {mtp_down_conn, start_link, []},
restart => temporary,
shutdown => 2000,
type => worker},
{ok, {SupFlags, [AChild]}}.
......@@ -10,7 +10,7 @@
-behaviour(ranch_protocol).
%% API
-export([start_link/4]).
-export([start_link/4, send/2]).
-export([hex/1, unhex/1]).
-export([keys_str/0]).
......@@ -18,6 +18,9 @@
-export([ranch_init/1]).
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3]).
-export_type([handle/0]).
-type handle() :: pid().
-define(MAX_SOCK_BUF_SIZE, 1024 * 50). % Decrease if CPU is cheaper than RAM
-define(MAX_UP_INIT_BUF_SIZE, 1024 * 1024). %1mb
......@@ -27,18 +30,19 @@
-record(state,
{stage = init :: stage(),
stage_state = <<>> :: any(),
up_acc = <<>> :: any(),
acc = <<>> :: any(),
secret :: binary(),
proxy_tag :: binary(),
up_sock :: gen_tcp:socket(),
up_transport :: transport(),
up_codec = ident :: mtp_layer:layer(),
sock :: gen_tcp:socket(),
transport :: transport(),
codec = ident :: mtp_layer:layer(),
down_sock :: gen_tcp:socket(),
down_codec = ident :: mtp_layer:layer(),
down :: gen_tcp:socket(),
dc_id :: integer(),
ad_tag :: binary(),
addr :: mtp_config:netloc(), % IP/Port of remote side
started_at :: pos_integer(),
timer_state = init :: init | hibernate | stop,
timer :: gen_timeout:tout()}).
......@@ -56,6 +60,10 @@ keys_str() ->
[{Name, Port, hex(Secret)}
|| {Name, Port, Secret} <- application:get_env(?APP, ports, [])].
-spec send(pid(), mtp_rpc:packet()) -> ok.
send(Upstream, Packet) ->
gen_server:cast(Upstream, Packet).
%% Callbacks
%% Custom gen_server init
......@@ -87,10 +95,11 @@ init({Socket, Transport, [Name, Secret, Tag]}) ->
{TimeoutKey, TimeoutDefault} = state_timeout(init),
Timer = gen_timeout:new(
#{timeout => {env, ?APP, TimeoutKey, TimeoutDefault}}),
State = #state{up_sock = Socket,
State = #state{sock = Socket,
secret = unhex(Secret),
proxy_tag = unhex(Tag),
up_transport = Transport,
transport = Transport,
ad_tag = unhex(Tag),
addr = {Ip, Port},
started_at = erlang:system_time(millisecond),
timer = Timer},
{ok, State};
......@@ -103,11 +112,28 @@ handle_call(_Request, _From, State) ->
Reply = ok,
{reply, Reply, State}.
handle_cast(_Msg, State) ->
handle_cast({proxy_ans, Down, Data}, #state{down = Down} = S) ->
%% telegram server -> proxy
case up_send(Data, S) of
{ok, S1} ->
{noreply, bump_timer(S1)};
{error, Reason} ->
lager:error("Error sending tunnelled data to in socket: ~p", [Reason]),
{stop, normal, S}
end;
handle_cast({close_ext, Down}, #state{down = Down, sock = USock, transport = UTrans} = S) ->
lager:debug("asked to close connection by downstream"),
ok = UTrans:close(USock),
{stop, normal, S};
handle_cast({simple_ack, Down, Confirm}, #state{down = Down} = S) ->
lager:info("Simple ack: ~p, ~p", [Down, Confirm]),
{noreply, S};
handle_cast(Other, State) ->
lager:warning("Unexpected msg ~p", [Other]),
{noreply, State}.
handle_info({tcp, Sock, Data}, #state{up_sock = Sock,
up_transport = Transport} = S) ->
handle_info({tcp, Sock, Data}, #state{sock = Sock,
transport = Transport} = S) ->
%% client -> proxy
track(rx, Data),
case handle_upstream_data(Data, S) of
......@@ -118,41 +144,13 @@ handle_info({tcp, Sock, Data}, #state{up_sock = Sock,
lager:info("handle_data error ~p", [Reason]),
{stop, normal, S}
end;
handle_info({tcp_closed, Sock}, #state{up_sock = Sock} = S) ->
handle_info({tcp_closed, Sock}, #state{sock = Sock} = S) ->
lager:debug("upstream sock closed"),
{stop, normal, maybe_close_down(S)};
handle_info({tcp_error, Sock, Reason}, #state{up_sock = Sock} = S) ->
handle_info({tcp_error, Sock, Reason}, #state{sock = Sock} = S) ->
lager:info("upstream sock error: ~p", [Reason]),
{stop, Reason, maybe_close_down(S)};
handle_info({tcp, Sock, Data}, #state{down_sock = Sock} = S) ->
%% telegram server -> proxy
track(tx, Data),
try handle_downstream_data(Data, S) of
{ok, S1} ->
ok = inet:setopts(Sock, [{active, once}]),
{noreply, bump_timer(S1)};
{error, Reason} ->
lager:error("Error sending tunnelled data to in socket: ~p", [Reason]),
{stop, normal, S}
catch throw:rpc_close ->
lager:info("downstream closed by RPC"),
#state{up_sock = USock, up_transport = UTrans} = S,
ok = UTrans:close(USock),
{stop, normal, maybe_close_down(S)}
end;
handle_info({tcp_closed, Sock}, #state{down_sock = Sock,
up_sock = USock, up_transport = UTrans} = S) ->
lager:debug("downstream sock closed"),
ok = UTrans:close(USock),
{stop, normal, S};
handle_info({tcp_error, Sock, Reason}, #state{down_sock = Sock,
up_sock = USock, up_transport = UTrans} = S) ->
lager:info("downstream sock error: ~p", [Reason]),
ok = UTrans:close(USock),
{stop, Reason, S};
handle_info(timeout, #state{timer = Timer, timer_state = TState} = S) ->
case gen_timeout:is_expired(Timer) of
true when TState == stop;
......@@ -168,10 +166,11 @@ handle_info(timeout, #state{timer = Timer, timer_state = TState} = S) ->
{noreply, S#state{timer = Timer1}}
end;
handle_info(Other, S) ->
lager:warning("Unexpected handle_info ~p", [Other]),
lager:warning("Unexpected msg ~p", [Other]),
{noreply, S}.
terminate(_Reason, #state{started_at = Started}) ->
terminate(_Reason, #state{started_at = Started} = S) ->
maybe_close_down(S),
mtp_metric:count_inc([?APP, in_connection_closed, total], 1, #{}),
Lifetime = erlang:system_time(millisecond) - Started,
mtp_metric:histogram_observe(
......@@ -183,10 +182,11 @@ terminate(_Reason, #state{started_at = Started}) ->
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
maybe_close_down(#state{down_sock = undefined} = S) -> S;
maybe_close_down(#state{down_sock = Out} = S) ->
gen_tcp:close(Out),
S#state{down_sock = undefined}.
maybe_close_down(#state{down = undefined} = S) -> S;
maybe_close_down(#state{dc_id = DcId} = S) ->
{ok, Pool} = mtp_config:get_downstream_pool(DcId),
mtp_dc_pool:return(Pool, self()),
S#state{down = undefined}.
bump_timer(#state{timer = Timer, timer_state = TState} = S) ->
Timer1 = gen_timeout:bump(Timer),
......@@ -220,7 +220,7 @@ state_timeout(stop) ->
%% Handle telegram client -> proxy stream
handle_upstream_data(Bin, #state{stage = tunnel,
up_codec = UpCodec} = S) ->
codec = UpCodec} = S) ->
{ok, S3, UpCodec1} =
mtp_layer:fold_packets(
fun(Decoded, S1) ->
......@@ -231,7 +231,7 @@ handle_upstream_data(Bin, #state{stage = tunnel,
{ok, S2} = down_send(Decoded, S1),
S2
end, S, Bin, UpCodec),
{ok, S3#state{up_codec = UpCodec1}};
{ok, S3#state{codec = UpCodec1}};
handle_upstream_data(<<Header:64/binary, Rest/binary>>, #state{stage = init, stage_state = <<>>,
secret = Secret} = S) ->
case mtp_obfuscated:from_header(Header, Secret) of
......@@ -244,8 +244,8 @@ handle_upstream_data(<<Header:64/binary, Rest/binary>>, #state{stage = init, sta
ObfuscatedLayer)),
handle_upstream_header(
DcId,
S#state{up_codec = UpCodec,
up_acc = Rest,
S#state{codec = UpCodec,
acc = Rest,
stage_state = undefined});
{error, Reason} = Err ->
mtp_metric:count_inc([?APP, protocol_error, total],
......@@ -255,58 +255,14 @@ handle_upstream_data(<<Header:64/binary, Rest/binary>>, #state{stage = init, sta
handle_upstream_data(Bin, #state{stage = init, stage_state = <<>>} = S) ->
{ok, S#state{stage_state = Bin}};
handle_upstream_data(Bin, #state{stage = init, stage_state = Buf} = S) ->
handle_upstream_data(<<Buf/binary, Bin/binary>> , S#state{stage_state = <<>>});
handle_upstream_data(Bin, #state{stage = Stage, up_acc = Acc} = S) when Stage =/= init,
Stage =/= tunnel ->
%% We are in the downstream handshake; it would be better to leave the socket in passive mode,
%% but let's do that in the next iteration
((byte_size(Bin) + byte_size(Acc)) < ?MAX_UP_INIT_BUF_SIZE)
orelse error(upstream_buffer_overflow),
{ok, S#state{up_acc = <<Acc/binary, Bin/binary>>}}.
%% Handle telegram server -> proxy stream
handle_downstream_data(Bin, #state{stage = tunnel,
down_codec = DownCodec} = S) ->
{ok, S3, DownCodec1} =
mtp_layer:fold_packets(
fun(Decoded, S1) ->
mtp_metric:histogram_observe(
[?APP, tg_packet_size, bytes],
byte_size(Decoded),
#{labels => [downstream_to_upstream]}),
{ok, S2} = up_send(Decoded, S1),
S2
end, S, Bin, DownCodec),
{ok, S3#state{down_codec = DownCodec1}};
handle_downstream_data(Bin, #state{stage = down_handshake_1,
down_codec = DownCodec} = S) ->
case mtp_layer:try_decode_packet(Bin, DownCodec) of
{ok, Packet, DownCodec1} ->
down_handshake2(Packet, S#state{down_codec = DownCodec1});
{incomplete, DownCodec1} ->
{ok, S#state{down_codec = DownCodec1}}
end;
handle_downstream_data(Bin, #state{stage = down_handshake_2,
proxy_tag = ProxyTag,
down_codec = DownCodec} = S) ->
case mtp_layer:try_decode_packet(Bin, DownCodec) of
{ok, Packet, DownCodec1} ->
%% TODO: There might be something in downstream buffers after stage3,
%% would be nice to run foldl
{ok, S1} = down_handshake3(Packet, ProxyTag, S#state{down_codec = DownCodec1}),
S2 = #state{up_acc = UpAcc} = switch_timer(S1, hibernate),
%% Flush upstream accumulator
handle_upstream_data(UpAcc, S2#state{up_acc = []});
{incomplete, DownCodec1} ->
{ok, S#state{down_codec = DownCodec1}}
end.
handle_upstream_data(<<Buf/binary, Bin/binary>> , S#state{stage_state = <<>>}).
up_send(Packet, #state{stage = tunnel,
up_codec = UpCodec,
up_sock = Sock,
up_transport = Transport} = S) ->
codec = UpCodec,
sock = Sock,
transport = Transport} = S) ->
lager:debug(">TG: ~p", [Packet]),
{Encoded, UpCodec1} = mtp_layer:encode_packet(Packet, UpCodec),
mtp_metric:rt([?APP, upstream_send_duration, seconds],
fun() ->
......@@ -321,165 +277,27 @@ up_send(Packet, #state{stage = tunnel,
throw({stop, normal, S})
end
end),
{ok, S#state{up_codec = UpCodec1}}.
{ok, S#state{codec = UpCodec1}}.
down_send(Packet, #state{down_sock = Sock,
down_codec = DownCodec} = S) ->
{Encoded, DownCodec1} = mtp_layer:encode_packet(Packet, DownCodec),
mtp_metric:rt([?APP, downstream_send_duration, seconds],
fun() ->
case gen_tcp:send(Sock, Encoded) of
ok -> ok;
{error, Reason} ->
is_atom(Reason) andalso
mtp_metric:count_inc(
[?APP, downstream_send_error, total], 1,
#{labels => [Reason]}),
lager:warning("Downstream send error: ~p", [Reason]),
throw({stop, normal, S})
end
end),
{ok, S#state{down_codec = DownCodec1}}.
down_send(Packet, #state{down = Down} = S) ->
lager:debug("<TG: ~p", [Packet]),
ok = mtp_down_conn:send(Down, Packet),
{ok, S}.
%% Internal
handle_upstream_header(DcId, S) ->
{Addr, Port} = mtp_config:get_downstream_safe(DcId),
case connect(Addr, Port) of
{ok, Sock} ->
mtp_metric:count_inc([?APP, out_connect_ok, total], 1,
#{labels => [DcId]}),
AddrStr = inet:ntoa(Addr),
lager:info("Connected to dc_id=~w ~s:~w", [DcId, AddrStr, Port]),
down_handshake1(S#state{down_sock = Sock});
{error, Reason} = Err ->
mtp_metric:count_inc([?APP, out_connect_error, total], 1, #{labels => [Reason]}),
Err
end.
-define(CONN_TIMEOUT, 10000).
-define(SEND_TIMEOUT, 60 * 1000).
connect(Host, Port) ->
BufSize = application:get_env(?APP, downstream_socket_buffer_size,
?MAX_SOCK_BUF_SIZE),
SockOpts = [{active, once},
{packet, raw},
{mode, binary},
{send_timeout, ?SEND_TIMEOUT},
{buffer, BufSize},
%% {nodelay, true},
{keepalive, true}],
case mtp_metric:rt([?APP, downstream_connect_duration, seconds],
fun() ->
gen_tcp:connect(Host, Port, SockOpts, ?CONN_TIMEOUT)
end) of
{ok, Sock} ->
{ok, Sock};
{error, _} = Err ->
Err
end.
-define(RPC_NONCE, <<170,135,203,122>>).
-define(RPC_HANDSHAKE, <<245,238,130,118>>).
-define(RPC_FLAGS, <<0, 0, 0, 0>>).
down_handshake1(S) ->
RpcNonce = ?RPC_NONCE,
<<KeySelector:4/binary, _/binary>> = Key = mtp_config:get_secret(),
CryptoTs = os:system_time(seconds),
Nonce = crypto:strong_rand_bytes(16),
Msg = <<RpcNonce/binary,
KeySelector/binary,
1:32/little, %AES
CryptoTs:32/little,
Nonce/binary>>,
Full = mtp_full:new(-2, -2),
S1 = S#state{down_codec = mtp_layer:new(mtp_full, Full),
stage = down_handshake_1,
stage_state = {KeySelector, Nonce, CryptoTs, Key}},
down_send(Msg, S1).
down_handshake2(<<Type:4/binary, KeySelector:4/binary, Schema:32/little, _CryptoTs:4/binary,
SrvNonce:16/binary>>, #state{stage_state = {MyKeySelector, CliNonce, MyTs, Key},
down_sock = Sock,
down_codec = DownCodec} = S) ->
(Type == ?RPC_NONCE) orelse error({wrong_rpc_type, Type}),
(Schema == 1) orelse error({wrong_schema, Schema}),
(KeySelector == MyKeySelector) orelse error({wrong_key_selector, KeySelector}),
{ok, {DownIp, DownPort}} = inet:peername(Sock),
{MyIp, MyPort} = get_external_ip(Sock),
DownIpBin = mtp_obfuscated:bin_rev(mtp_rpc:inet_pton(DownIp)),
MyIpBin = mtp_obfuscated:bin_rev(mtp_rpc:inet_pton(MyIp)),
Args = #{srv_n => SrvNonce, clt_n => CliNonce, clt_ts => MyTs,
srv_ip => DownIpBin, srv_port => DownPort,
clt_ip => MyIpBin, clt_port => MyPort, secret => Key},
{EncKey, EncIv} = get_middle_key(Args#{purpose => <<"CLIENT">>}),
{DecKey, DecIv} = get_middle_key(Args#{purpose => <<"SERVER">>}),
CryptoCodec = mtp_layer:new(mtp_aes_cbc, mtp_aes_cbc:new(EncKey, EncIv, DecKey, DecIv, 16)),
DownCodec1 = mtp_layer:new(mtp_wrap, mtp_wrap:new(DownCodec, CryptoCodec)),
SenderPID = PeerPID = <<"IPIPPRPDTIME">>,
Handshake = [?RPC_HANDSHAKE,
?RPC_FLAGS,
SenderPID,
PeerPID],
down_send(Handshake, S#state{down_codec = DownCodec1,
stage = down_handshake_2,
stage_state = {MyIp, MyPort, SenderPID}}).
get_middle_key(#{srv_n := Nonce, clt_n := MyNonce, clt_ts := MyTs, srv_ip := SrvIpBinBig, srv_port := SrvPort,
clt_ip := CltIpBinBig, clt_port := CltPort, secret := Secret, purpose := Purpose} = _Args) ->
Msg =
<<Nonce/binary,
MyNonce/binary,
MyTs:32/little,
SrvIpBinBig/binary,
CltPort:16/little,
Purpose/binary,
CltIpBinBig/binary,
SrvPort:16/little,
Secret/binary,
Nonce/binary,
%% IPv6
MyNonce/binary
>>,
<<_, ForMd51/binary>> = Msg,
<<_, _, ForMd52/binary>> = Msg,
<<Key1:12/binary, _/binary>> = crypto:hash(md5, ForMd51),
ShaSum = crypto:hash(sha, Msg),
Key = <<Key1/binary, ShaSum/binary>>,
IV = crypto:hash(md5, ForMd52),
{Key, IV}.
down_handshake3(<<Type:4/binary, _Flags:4/binary, _SenderPid:12/binary, PeerPid:12/binary>>,
ProxyTag,
#state{stage_state = {MyIp, MyPort, PrevSenderPid},
down_codec = DownCodec,
up_sock = Sock,
up_transport = Transport} = S) ->
(Type == ?RPC_HANDSHAKE) orelse error({wrong_rpc_type, Type}),
(PeerPid == PrevSenderPid) orelse error({wrong_sender_pid, PeerPid}),
{ok, {ClientIp, ClientPort}} = Transport:peername(Sock),
RpcCodec = mtp_layer:new(mtp_rpc, mtp_rpc:new(ClientIp, ClientPort, MyIp, MyPort, ProxyTag)),
DownCodec1 = mtp_layer:new(mtp_wrap, mtp_wrap:new(RpcCodec, DownCodec)),
{ok, S#state{down_codec = DownCodec1,
stage = tunnel,
stage_state = undefined}}.
%% Internal
get_external_ip(Sock) ->
{ok, {MyIp, MyPort}} = inet:sockname(Sock),
case application:get_env(?APP, external_ip) of
{ok, IpStr} ->
{ok, IP} = inet:parse_ipv4strict_address(IpStr),
{IP, MyPort};
undefined ->
{MyIp, MyPort}
end.
handle_upstream_header(DcId, #state{acc = Acc, ad_tag = Tag, addr = Addr} = S) ->
Opts = #{ad_tag => Tag,
addr => Addr},
{DcId, _Pool, Downstream} = mtp_config:get_downstream_safe(DcId, Opts),
handle_upstream_data(
Acc,
S#state{down = Downstream,
dc_id = DcId,
acc = <<>>,
stage = tunnel}).
hex(Bin) ->
<<begin
......@@ -501,33 +319,3 @@ track(Direction, Data) ->
Size = byte_size(Data),
mtp_metric:count_inc([?APP, tracker, bytes], Size, #{labels => [Direction]}),
mtp_metric:histogram_observe([?APP, tracker_packet_size, bytes], Size, #{labels => [Direction]}).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-define(PROXY_SECRET,
<<196,249,250,202,150,120,230,187,72,173,108,126,44,229,192,210,68,48,100,
93,85,74,221,235,85,65,158,3,77,166,39,33,208,70,234,171,110,82,171,20,
169,90,68,62,207,179,70,62,121,160,90,102,97,42,223,156,174,218,139,233,
168,13,166,152,111,176,166,255,56,122,248,77,136,239,58,100,19,113,62,92,
51,119,246,225,163,212,125,153,245,224,197,110,236,232,240,92,84,196,144,
176,121,227,27,239,130,255,14,232,242,176,163,39,86,210,73,197,242,18,105,
129,108,183,6,27,38,93,178,18>>).
middle_key_test() ->
Args = #{srv_port => 80,
srv_ip => mtp_obfuscated:bin_rev(mtp_rpc:inet_pton({149, 154, 162, 38})),
srv_n => <<247,40,210,56,65,12,101,170,216,155,14,253,250,238,219,226>>,
clt_n => <<24,49,53,111,198,10,235,180,230,112,92,78,1,201,106,105>>,
clt_ip => mtp_obfuscated:bin_rev(mtp_rpc:inet_pton({80, 211, 29, 34})),
clt_ts => 1528396015,
clt_port => 54208,
purpose => <<"CLIENT">>,
secret => ?PROXY_SECRET
},
Key = <<165,158,127,49,41,232,187,69,38,29,163,226,183,146,28,67,225,224,134,191,207,152,255,166,152,66,169,196,54,135,50,188>>,
IV = <<33,110,125,221,183,121,160,116,130,180,156,249,52,111,37,178>>,
?assertEqual(
{Key, IV},
get_middle_key(Args)).
-endif.
......@@ -6,12 +6,11 @@
%%% Created : 6 Jun 2018 by Sergey <me@seriyps.ru>
-module(mtp_rpc).
-behaviour(mtp_layer).
-export([new/5,
try_decode_packet/2,
-export([decode_packet/1,
encode_packet/2]).
-export([inet_pton/1]).
-export([inet_pton/1,
encode_ip_port/2]).
-export_type([codec/0]).
-record(rpc_st,
......@@ -21,8 +20,12 @@
conn_id :: integer()}).
-define(APP, mtproto_proxy).
-define(RPC_PROXY_ANS, 13,218,3,68).
-define(RPC_CLOSE_EXT, 162,52,182,94).
-define(RPC_PROXY_REQ, 238,241,206,54). %0x36cef1ee
-define(RPC_PROXY_ANS, 13,218,3,68). %0x4403da0d
-define(RPC_CLOSE_CONN, 93,66,207,31). %0x1fcf425d
-define(RPC_CLOSE_EXT, 162,52,182,94). %0x5eb634a2
-define(RPC_SIMPLE_ACK, 155,64,172,59). %0x3bac409b
-define(TL_PROXY_TAG, 174,38,30,219).
-define(FLAG_NOT_ENCRYPTED , 16#2).
-define(FLAG_HAS_AD_TAG , 16#8).
......@@ -35,36 +38,43 @@
-opaque codec() :: #rpc_st{}.
-type conn_id() :: integer().
-type packet() :: {proxy_ans, conn_id(), binary()}
| {close_ext, conn_id()}
| {simple_ack, conn_id(), binary()}.
new(ClientIp, ClientPort, ProxyIp, ProxyPort, ProxyTag) ->
new(ClientIp, ClientPort, ProxyIp, ProxyPort, ProxyTag,
erlang:unique_integer()).
%% new(ClientIp, ClientPort, ProxyIp, ProxyPort, ProxyTag) ->
%% new(ClientIp, ClientPort, ProxyIp, ProxyPort, ProxyTag,
%% erlang:unique_integer()).
new(ClientIp, ClientPort, ProxyIp, ProxyPort, ProxyTag, ConnId) ->
#rpc_st{client_addr = iolist_to_binary(encode_ip_port(ClientIp, ClientPort)),
proxy_addr = iolist_to_binary(encode_ip_port(ProxyIp, ProxyPort)),
proxy_tag = ProxyTag,
conn_id = ConnId}.
%% new(ClientIp, ClientPort, ProxyIp, ProxyPort, ProxyTag, ConnId) ->
%% #rpc_st{client_addr = iolist_to_binary(encode_ip_port(ClientIp, ClientPort)),
%% proxy_addr = iolist_to_binary(encode_ip_port(ProxyIp, ProxyPort)),
%% proxy_tag = ProxyTag,
%% conn_id = ConnId}.
%% Expects that packet segmentation was already done by the previous layer
try_decode_packet(<<?RPC_PROXY_ANS, _AnsFlags:4/binary, _ConnId:8/binary, Data/binary>> = _Msg, S) ->
%% TODO: check if we can use downstream multiplexing using ConnId
{ok, Data, S};
try_decode_packet(<<?RPC_CLOSE_EXT, _/binary>> = _Msg, _S) ->
%% Use throw as short-circuit
throw(rpc_close);
try_decode_packet(<<>>, S) ->
{incomplete, S}.
encode_packet(Msg, #rpc_st{client_addr = ClientAddr, proxy_addr = ProxyAddr,
conn_id = ConnId, proxy_tag = ProxyTag} = S) ->
%% See mtproto/mtproto-proxy.c:process_client_packet
-spec decode_packet(binary()) -> packet() | error.
decode_packet(<<?RPC_PROXY_ANS, _AnsFlags:4/binary, ConnId:64/signed-little, Data/binary>>) ->
%% mtproto/mtproto-proxy.c:client_send_message
{proxy_ans, ConnId, Data};
decode_packet(<<?RPC_CLOSE_EXT, ConnId:64/signed-little>>) ->
{close_ext, ConnId};
decode_packet(<<?RPC_SIMPLE_ACK, ConnId:64/signed-little, Confirm:4/binary>>) ->
%% mtproto/mtproto-proxy.c:push_rpc_confirmation
{simple_ack, ConnId, Confirm}.
encode_packet({data, Msg}, {{ConnId, ClientAddr, ProxyTag}, ProxyAddr}) ->
%% See mtproto/mtproto-proxy.c:forward_mtproto_packet
((iolist_size(Msg) rem 4) == 0)
orelse error(not_aligned),
Flags1 = (?FLAG_HAS_AD_TAG
bor ?FLAG_MAGIC
bor ?FLAG_EXTMODE2
bor ?FLAG_ABRIDGED),
%% if (auth_key_id) ...
Flags = case Msg of
%% XXX: what if Msg is an iolist?
<<0, 0, 0, 0, 0, 0, 0, 0, _/binary>> ->
......@@ -72,20 +82,21 @@ encode_packet(Msg, #rpc_st{client_addr = ClientAddr, proxy_addr = ProxyAddr,
_ ->
Flags1
end,
Req =
[<<238,241,206,54, %RPC_PROXY_REQ
Flags:32/little, %Flags
ConnId:64/little-signed>>,
[<<?RPC_PROXY_REQ, %RPC_PROXY_REQ
Flags:32/little, %int: Flags
ConnId:64/little-signed>>, %long long:
ClientAddr, ProxyAddr,
<<24:32/little, %ExtraSize
174,38,30,219, %ProxyTag
(byte_size(ProxyTag)),
ProxyTag/binary,
0, 0, 0 %Padding
<<24:32/little, %int: ExtraSize
?TL_PROXY_TAG, %int: ProxyTag
(byte_size(ProxyTag)), %tls_string_len()
ProxyTag/binary, %tls_string_data()
0, 0, 0 %tls_pad()
>>
| Msg
],
{Req, S}.
];
encode_packet(remote_closed, ConnId) ->
<<?RPC_CLOSE_CONN, ConnId:64/little-signed>>.
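%% Framing example (illustrative; ConnId, AdTag, payload and address binaries
%% are placeholders built the same way mtp_down_conn builds them):
%%
%%   ReqIoList = encode_packet({data, Payload},
%%                             {{ConnId, ClientAddrBin, AdTag}, ProxyAddrBin}),
%%   {proxy_ans, ConnId, Answer} = decode_packet(AnsFrame).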
encode_ip_port(IPv4, Port) when tuple_size(IPv4) == 4 ->
IpBin = inet_pton(IPv4),
......@@ -106,17 +117,12 @@ inet_pton(IPv6) when tuple_size(IPv6) == 8 ->
-include_lib("eunit/include/eunit.hrl").
tst_new() ->
ClientIp = {109, 238, 131, 159},
ClientPort = 1128,
ProxyIp = {80, 211, 29, 34},
ProxyPort = 53634,
ProxyTag = <<220,190,143,20,147,250,76,217,171,48,8,145,192,181,179,38>>,
new(ClientIp, ClientPort, ProxyIp, ProxyPort, ProxyTag, 1).
{{1, encode_ip_port({109, 238, 131, 159}, 1128),
<<220,190,143,20,147,250,76,217,171,48,8,145,192,181,179,38>>},
encode_ip_port({80, 211, 29, 34}, 53634)}.
decode_none_test() ->
S = tst_new(),
?assertEqual(
{incomplete, S}, try_decode_packet(<<>>, S)).
?assertError(function_clause, decode_packet(<<>>)).
encode_test() ->
S = tst_new(),
......@@ -125,26 +131,23 @@ encode_test() ->
<<238,241,206,54,10,16,2,64,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,255,255,109,238,131,159,104,4,0,0,0,0,0,0,0,0,0,0,0,0,255,255,80,211,29,34,130,209,0,0,24,0,0,0,174,38,30,219,16,220,190,143,20,147,250,76,217,171,48,8,145,192,181,179,38,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,61,2,24,91,20,0,0,0,120,151,70,96,153,197,142,238,245,139,85,208,160,241,68,89,106,7,118,167>>},
{<<14,146,6,159,99,150,29,221,115,87,68,198,122,39,38,249,153,87,37,105,4,111,147,70,54,179,134,12,90,4,223,155,206,220,167,201,203,176,123,181,103,176,49,216,163,106,54,148,133,51,206,212,81,90,47,26,3,161,149,251,182,90,190,51,213,7,107,176,112,220,25,144,183,249,149,182,172,194,218,146,161,191,247,4,250,123,230,251,41,181,139,177,55,171,253,198,153,183,61,53,119,115,46,174,172,245,90,166,215,99,181,58,236,129,103,80,218,244,81,45,142,128,177,146,26,131,184,155,22,217,218,187,209,155,156,64,219,235,175,40,249,235,77,82,212,73,11,133,52,4,222,157,67,176,251,46,254,241,15,192,215,192,186,82,233,68,147,234,88,250,96,14,172,179,7,159,28,11,237,48,44,33,137,185,166,166,173,103,136,174,31,35,77,151,76,55,176,211,230,176,118,144,139,77,0,213,68,179,73,58,58,80,238,120,197,67,241,210,210,156,72,105,60,125,239,98,7,19,234,249,222,194,166,37,46,100,1,65,225,224,244,57,147,119,49,20,1,160,4,51,247,161,142,11,131,11,27,166,159,110,145,78,55,205,126,246,126,68,44,114,91,191,213,241,242,9,33,16,30,228>>,
<<238,241,206,54,8,16,2,64,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,255,255,109,238,131,159,104,4,0,0,0,0,0,0,0,0,0,0,0,0,255,255,80,211,29,34,130,209,0,0,24,0,0,0,174,38,30,219,16,220,190,143,20,147,250,76,217,171,48,8,145,192,181,179,38,0,0,0,14,146,6,159,99,150,29,221,115,87,68,198,122,39,38,249,153,87,37,105,4,111,147,70,54,179,134,12,90,4,223,155,206,220,167,201,203,176,123,181,103,176,49,216,163,106,54,148,133,51,206,212,81,90,47,26,3,161,149,251,182,90,190,51,213,7,107,176,112,220,25,144,183,249,149,182,172,194,218,146,161,191,247,4,250,123,230,251,41,181,139,177,55,171,253,198,153,183,61,53,119,115,46,174,172,245,90,166,215,99,181,58,236,129,103,80,218,244,81,45,142,128,177,146,26,131,184,155,22,217,218,187,209,155,156,64,219,235,175,40,249,235,77,82,212,73,11,133,52,4,222,157,67,176,251,46,254,241,15,192,215,192,186,82,233,68,147,234,88,250,96,14,172,179,7,159,28,11,237,48,44,33,137,185,166,166,173,103,136,174,31,35,77,151,76,55,176,211,230,176,118,144,139,77,0,213,68,179,73,58,58,80,238,120,197,67,241,210,210,156,72,105,60,125,239,98,7,19,234,249,222,194,166,37,46,100,1,65,225,224,244,57,147,119,49,20,1,160,4,51,247,161,142,11,131,11,27,166,159,110,145,78,55,205,126,246,126,68,44,114,91,191,213,241,242,9,33,16,30,228>>}],
lists:foldl(
fun({In, Out}, S1) ->
{Enc, S2} = encode_packet(In, S1),
?assertEqual(Out, iolist_to_binary(Enc)),
S2
end, S, Samples).
lists:foreach(
fun({In, Out}) ->
Enc = encode_packet({data, In}, S),
?assertEqual(Out, iolist_to_binary(Enc))
end, Samples).
decode_test() ->
S = tst_new(),
Samples =
[{<<13,218,3,68,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,52,62,238,60,2,24,91,64,0,0,0,99,36,22,5,153,197,142,238,245,139,85,208,160,241,68,89,106,7,118,167,146,202,163,241,63,158,32,27,246,203,226,70,177,46,106,225,8,34,202,206,241,19,38,121,245,0,0,0,21,196,181,28,1,0,0,0,33,107,232,108,2,43,180,195>>,
<<0,0,0,0,0,0,0,0,1,52,62,238,60,2,24,91,64,0,0,0,99,36,22,5,153,197,142,238,245,139,85,208,160,241,68,89,106,7,118,167,146,202,163,241,63,158,32,27,246,203,226,70,177,46,106,225,8,34,202,206,241,19,38,121,245,0,0,0,21,196,181,28,1,0,0,0,33,107,232,108,2,43,180,195>>},
{<<13,218,3,68,0,0,0,0,2,0,0,0,0,0,0,0,14,146,6,159,99,150,29,221,85,233,237,52,236,18,11,0,174,214,89,213,69,89,250,18,116,192,128,240,217,221,210,144,123,9,182,152,60,206,88,187,101,178,53,107,44,98,190,195,149,114,0,19,90,218,101,133,183,249,183,170,90,21,86,24,42,81,224,152,13,58,90,84,41,158,177,99,57,83,123,99,138,127,29,238,162,49,71,65,165,168,218,220,245,202,24,135,152,1,28,38,85,197,8,232,201,163,65,118,202,89,204,67,48,21,51,106,188,7,167,61,185,82,39,210,164,21,97,99,63,167,2,143,69,126,214,75,95,142,69,68,243,49,11,121,28,177,159,0,154,134,206,34>>,
<<14,146,6,159,99,150,29,221,85,233,237,52,236,18,11,0,174,214,89,213,69,89,250,18,116,192,128,240,217,221,210,144,123,9,182,152,60,206,88,187,101,178,53,107,44,98,190,195,149,114,0,19,90,218,101,133,183,249,183,170,90,21,86,24,42,81,224,152,13,58,90,84,41,158,177,99,57,83,123,99,138,127,29,238,162,49,71,65,165,168,218,220,245,202,24,135,152,1,28,38,85,197,8,232,201,163,65,118,202,89,204,67,48,21,51,106,188,7,167,61,185,82,39,210,164,21,97,99,63,167,2,143,69,126,214,75,95,142,69,68,243,49,11,121,28,177,159,0,154,134,206,34>>}],
lists:foldl(
fun({In, Out}, S1) ->
{ok, Dec, S2} = try_decode_packet(In, S1),
?assertEqual(Out, iolist_to_binary(Dec)),
S2
end, S, Samples).
lists:foreach(
fun({In, Out}) ->
{proxy_ans, _ConnId, Packet} = decode_packet(In),
?assertEqual(Out, iolist_to_binary(Packet))
end, Samples).
%% decode_close_test() ->
%% S = tst_new(),
......
......@@ -6,6 +6,7 @@
{applications,
[lager,
ranch,
psq,
crypto,
ssl,
inets,
......@@ -55,9 +56,12 @@
%% only `{allowed_protocols, [mtp_secure]}` if you want to only allow
%% connections to this proxy with "dd"-secrets. Connections by other
%% protocols will be immediately closed.
{allowed_protocols, [mtp_abridged, mtp_intermediate, mtp_secure]}
{allowed_protocols, [mtp_abridged, mtp_intermediate, mtp_secure]},
%% module with function `notify/4' exported.
{init_dc_connections, 2},
{clients_per_dc_connection, 300}
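%% Illustration: with init_dc_connections=2 and clients_per_dc_connection=300,
%% each DC pool starts with 2 multiplexed downstream connections and (roughly)
%% spawns additional ones once its least-loaded connection already serves
%% more than 300 upstreams.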
%% Should be module with function `notify/4' exported.
%% See mtp_metric:notify/4 for details
%% {metric_backend, my_metric_backend}
......
%%%-------------------------------------------------------------------
%% @doc mtproto_proxy top level supervisor.
%% @end
%% <pre>
%% dc_pool_sup (simple_one_for_one)
%% dc_pool_1 [conn1, conn3, conn4, ..]
%% dc_pool_-1 [conn2, ..]
%% dc_pool_2 [conn5, conn7, ..]
%% dc_pool_-2 [conn6, conn8, ..]
%% ...
%% down_conn_sup (simple_one_for_one)
%% conn1
%% conn2
%% conn3
%% conn4
%% ...
%% connN
%% </pre>
%%%-------------------------------------------------------------------
-module(mtproto_proxy_sup).
......@@ -26,16 +41,17 @@ start_link() ->
%% Supervisor callbacks
%%====================================================================
%% Child :: {Id,StartFunc,Restart,Shutdown,Type,Modules}
init([]) ->
Childs = [#{id => mtp_config,
start => {mtp_config, start_link, []}}
],
{ok, {#{strategy => rest_for_one,
SupFlags = #{strategy => one_for_all, %TODO: maybe change strategy
intensity => 50,
period => 5},
Childs} }.
%%====================================================================
%% Internal functions
%%====================================================================
Childs = [#{id => mtp_down_conn_sup,
type => supervisor,
start => {mtp_down_conn_sup, start_link, []}},
#{id => mtp_dc_pool_sup,
type => supervisor,
start => {mtp_dc_pool_sup, start_link, []}},
#{id => mtp_config,
start => {mtp_config, start_link, []}}
],
{ok, {SupFlags, Childs}}.