Use standard process registration for dc pools instead of ETS

There was a race-condition, that caused crashes.
We could use gproc, but it feels a bit overkill for ~10-20 processes
parent 8669db42
......@@ -21,10 +21,6 @@
get_netloc_safe/1,
get_secret/0,
status/0]).
-export([register_name/2,
unregister_name/1,
whereis_name/1,
send/2]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
......@@ -35,7 +31,6 @@
-define(TAB, ?MODULE).
-define(IPS_KEY(DcId), {id, DcId}).
-define(POOL_KEY(DcId), {pool, DcId}).
-define(IDS_KEY, dc_ids).
-define(SECRET_URL, "https://core.telegram.org/getProxySecret").
-define(CONFIG_URL, "https://core.telegram.org/getProxyConfig").
......@@ -43,7 +38,6 @@
-define(APP, mtproto_proxy).
-record(state, {tab :: ets:tid(),
monitors = #{} :: #{pid() => {reference(), dc_id()}},
timer :: gen_timeout:tout()}).
-ifndef(OTP_RELEASE). % pre-OTP21
......@@ -73,11 +67,11 @@ get_downstream_safe(DcId, Opts) ->
end.
get_downstream_pool(DcId) ->
Key = ?POOL_KEY(DcId),
case ets:lookup(?TAB, Key) of
[] -> not_found;
[{Key, PoolPid}] ->
{ok, PoolPid}
try whereis(mtp_dc_pool:dc_to_pool_name(DcId)) of
undefined -> not_found;
Pid when is_pid(Pid) -> {ok, Pid}
catch error:invalid_dc_id ->
not_found
end.
-spec get_netloc_safe(dc_id()) -> {dc_id(), netloc()}.
......@@ -103,29 +97,6 @@ get_netloc(DcId) ->
{ok, IpPort}
end.
register_name(DcId, Pid) ->
case ets:insert_new(?TAB, {?POOL_KEY(DcId), Pid}) of
true ->
gen_server:cast(?MODULE, {reg, DcId, Pid}),
yes;
false -> no
end.
unregister_name(DcId) ->
%% making async monitors is a bad idea..
Pid = whereis_name(DcId),
gen_server:cast(?MODULE, {unreg, DcId, Pid}),
ets:delete(?TAB, ?POOL_KEY(DcId)).
whereis_name(DcId) ->
case get_downstream_pool(DcId) of
not_found -> undefined;
{ok, PoolPid} -> PoolPid
end.
send(Name, Msg) ->
whereis_name(Name) ! Msg.
-spec get_secret() -> binary().
get_secret() ->
......@@ -136,7 +107,8 @@ status() ->
[{?IDS_KEY, L}] = ets:lookup(?TAB, ?IDS_KEY),
lists:map(
fun(DcId) ->
DcPoolStatus = mtp_dc_pool:status(whereis_name(DcId)),
{ok, Pid} = get_downstream_pool(DcId),
DcPoolStatus = mtp_dc_pool:status(Pid),
DcPoolStatus#{dc_id => DcId}
end, L).
......@@ -161,14 +133,8 @@ init([]) ->
handle_call(_Request, _From, State) ->
Reply = ok,
{reply, Reply, State}.
handle_cast({reg, DcId, Pid}, #state{monitors = Mons} = State) ->
Ref = erlang:monitor(process, Pid),
Mons1 = Mons#{Pid => {Ref, DcId}},
{noreply, State#state{monitors = Mons1}};
handle_cast({unreg, DcId, Pid}, #state{monitors = Mons} = State) ->
{{Ref, DcId}, Mons1} = maps:take(Pid, Mons),
erlang:demonitor(Ref, [flush]),
{noreply, State#state{monitors = Mons1}}.
handle_cast(_Request, State) ->
{noreply, State}.
handle_info(timeout, #state{timer = Timer} =State) ->
case gen_timeout:is_expired(Timer) of
true ->
......@@ -179,11 +145,7 @@ handle_info(timeout, #state{timer = Timer} =State) ->
{noreply, State#state{timer = Timer1}};
false ->
{noreply, State#state{timer = gen_timeout:reset(Timer)}}
end;
handle_info({'DOWN', MonRef, process, Pid, _Reason}, #state{monitors = Mons} = State) ->
{{MonRef, DcId}, Mons1} = maps:take(Pid, Mons),
ets:delete(?TAB, ?POOL_KEY(DcId)),
{noreply, State#state{monitors = Mons1}}.
end.
terminate(_Reason, _State) ->
ok.
code_change(_OldVsn, State, _Extra) ->
......@@ -246,8 +208,6 @@ update_downstreams(Downstreams, Tab) ->
fun(DcId) ->
case get_downstream_pool(DcId) of
not_found ->
%% process will be registered asynchronously by
%% gen_server:start_link({via, ..
{ok, _Pid} = mtp_dc_pool_sup:start_pool(DcId);
{ok, _} ->
ok
......
......@@ -19,7 +19,9 @@
return/2,
add_connection/1,
ack_connected/2,
status/1]).
status/1,
valid_dc_id/1,
dc_to_pool_name/1]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
......@@ -50,7 +52,17 @@
%%% API
%%%===================================================================
start_link(DcId) ->
gen_server:start_link({via, mtp_config, DcId}, ?MODULE, DcId, []).
gen_server:start_link({local, dc_to_pool_name(DcId)}, ?MODULE, DcId, []).
valid_dc_id(DcId) ->
is_integer(DcId) andalso
-10 < DcId andalso
10 > DcId.
dc_to_pool_name(DcId) ->
valid_dc_id(DcId) orelse error(invalid_dc_id, [DcId]),
binary_to_atom(<<"mtp_dc_pool_", (integer_to_binary(DcId))/binary>>, utf8).
get(Pool, Upstream, #{addr := _} = Opts) ->
gen_server:call(Pool, {get, Upstream, Opts}).
......@@ -257,6 +269,7 @@ ds_get(St) ->
%% Return connection back to storage
-spec ds_return(downstream(), ds_store()) -> ds_store().
ds_return(Pid, St) ->
%% It may return 'undefined' if down_conn crashed
{ok, St1} = pid_psq:dec_priority(Pid, St),
St1.
......
......@@ -39,7 +39,7 @@
transport :: transport(),
codec = ident :: mtp_layer:layer(),
down :: gen_tcp:socket(),
down :: mtp_down_conn:handle(),
dc_id :: integer(),
ad_tag :: binary(),
......@@ -113,7 +113,7 @@ handle_call(_Request, _From, State) ->
Reply = ok,
{reply, Reply, State}.
handle_cast({proxy_ans, Down, Data}, #state{down = Down, listener = Listener} = S) ->
handle_cast({proxy_ans, Down, Data}, #state{down = Down} = S) ->
%% telegram server -> proxy
case up_send(Data, S) of
{ok, S1} ->
......@@ -223,8 +223,7 @@ state_timeout(stop) ->
%% Handle telegram client -> proxy stream
handle_upstream_data(Bin, #state{stage = tunnel,
codec = UpCodec,
listener = Listener} = S) ->
codec = UpCodec} = S) ->
{ok, S3, UpCodec1} =
mtp_layer:fold_packets(
fun(Decoded, S1) ->
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment