Use standard process registration for dc pools instead of ETS

There was a race condition that caused crashes.
We could use gproc, but it feels like overkill for ~10-20 processes.
parent 8669db42
...@@ -21,10 +21,6 @@ ...@@ -21,10 +21,6 @@
get_netloc_safe/1, get_netloc_safe/1,
get_secret/0, get_secret/0,
status/0]). status/0]).
-export([register_name/2,
unregister_name/1,
whereis_name/1,
send/2]).
%% gen_server callbacks %% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, -export([init/1, handle_call/3, handle_cast/2, handle_info/2,
...@@ -35,7 +31,6 @@ ...@@ -35,7 +31,6 @@
-define(TAB, ?MODULE). -define(TAB, ?MODULE).
-define(IPS_KEY(DcId), {id, DcId}). -define(IPS_KEY(DcId), {id, DcId}).
-define(POOL_KEY(DcId), {pool, DcId}).
-define(IDS_KEY, dc_ids). -define(IDS_KEY, dc_ids).
-define(SECRET_URL, "https://core.telegram.org/getProxySecret"). -define(SECRET_URL, "https://core.telegram.org/getProxySecret").
-define(CONFIG_URL, "https://core.telegram.org/getProxyConfig"). -define(CONFIG_URL, "https://core.telegram.org/getProxyConfig").
...@@ -43,7 +38,6 @@ ...@@ -43,7 +38,6 @@
-define(APP, mtproto_proxy). -define(APP, mtproto_proxy).
-record(state, {tab :: ets:tid(), -record(state, {tab :: ets:tid(),
monitors = #{} :: #{pid() => {reference(), dc_id()}},
timer :: gen_timeout:tout()}). timer :: gen_timeout:tout()}).
-ifndef(OTP_RELEASE). % pre-OTP21 -ifndef(OTP_RELEASE). % pre-OTP21
...@@ -73,11 +67,11 @@ get_downstream_safe(DcId, Opts) -> ...@@ -73,11 +67,11 @@ get_downstream_safe(DcId, Opts) ->
end. end.
get_downstream_pool(DcId) -> get_downstream_pool(DcId) ->
Key = ?POOL_KEY(DcId), try whereis(mtp_dc_pool:dc_to_pool_name(DcId)) of
case ets:lookup(?TAB, Key) of undefined -> not_found;
[] -> not_found; Pid when is_pid(Pid) -> {ok, Pid}
[{Key, PoolPid}] -> catch error:invalid_dc_id ->
{ok, PoolPid} not_found
end. end.
-spec get_netloc_safe(dc_id()) -> {dc_id(), netloc()}. -spec get_netloc_safe(dc_id()) -> {dc_id(), netloc()}.
...@@ -103,29 +97,6 @@ get_netloc(DcId) -> ...@@ -103,29 +97,6 @@ get_netloc(DcId) ->
{ok, IpPort} {ok, IpPort}
end. end.
register_name(DcId, Pid) ->
case ets:insert_new(?TAB, {?POOL_KEY(DcId), Pid}) of
true ->
gen_server:cast(?MODULE, {reg, DcId, Pid}),
yes;
false -> no
end.
unregister_name(DcId) ->
%% making async monitors is a bad idea..
Pid = whereis_name(DcId),
gen_server:cast(?MODULE, {unreg, DcId, Pid}),
ets:delete(?TAB, ?POOL_KEY(DcId)).
whereis_name(DcId) ->
case get_downstream_pool(DcId) of
not_found -> undefined;
{ok, PoolPid} -> PoolPid
end.
send(Name, Msg) ->
whereis_name(Name) ! Msg.
-spec get_secret() -> binary(). -spec get_secret() -> binary().
get_secret() -> get_secret() ->
...@@ -136,7 +107,8 @@ status() -> ...@@ -136,7 +107,8 @@ status() ->
[{?IDS_KEY, L}] = ets:lookup(?TAB, ?IDS_KEY), [{?IDS_KEY, L}] = ets:lookup(?TAB, ?IDS_KEY),
lists:map( lists:map(
fun(DcId) -> fun(DcId) ->
DcPoolStatus = mtp_dc_pool:status(whereis_name(DcId)), {ok, Pid} = get_downstream_pool(DcId),
DcPoolStatus = mtp_dc_pool:status(Pid),
DcPoolStatus#{dc_id => DcId} DcPoolStatus#{dc_id => DcId}
end, L). end, L).
...@@ -161,14 +133,8 @@ init([]) -> ...@@ -161,14 +133,8 @@ init([]) ->
handle_call(_Request, _From, State) -> handle_call(_Request, _From, State) ->
Reply = ok, Reply = ok,
{reply, Reply, State}. {reply, Reply, State}.
handle_cast({reg, DcId, Pid}, #state{monitors = Mons} = State) -> handle_cast(_Request, State) ->
Ref = erlang:monitor(process, Pid), {noreply, State}.
Mons1 = Mons#{Pid => {Ref, DcId}},
{noreply, State#state{monitors = Mons1}};
handle_cast({unreg, DcId, Pid}, #state{monitors = Mons} = State) ->
{{Ref, DcId}, Mons1} = maps:take(Pid, Mons),
erlang:demonitor(Ref, [flush]),
{noreply, State#state{monitors = Mons1}}.
handle_info(timeout, #state{timer = Timer} =State) -> handle_info(timeout, #state{timer = Timer} =State) ->
case gen_timeout:is_expired(Timer) of case gen_timeout:is_expired(Timer) of
true -> true ->
...@@ -179,11 +145,7 @@ handle_info(timeout, #state{timer = Timer} =State) -> ...@@ -179,11 +145,7 @@ handle_info(timeout, #state{timer = Timer} =State) ->
{noreply, State#state{timer = Timer1}}; {noreply, State#state{timer = Timer1}};
false -> false ->
{noreply, State#state{timer = gen_timeout:reset(Timer)}} {noreply, State#state{timer = gen_timeout:reset(Timer)}}
end; end.
handle_info({'DOWN', MonRef, process, Pid, _Reason}, #state{monitors = Mons} = State) ->
{{MonRef, DcId}, Mons1} = maps:take(Pid, Mons),
ets:delete(?TAB, ?POOL_KEY(DcId)),
{noreply, State#state{monitors = Mons1}}.
terminate(_Reason, _State) -> terminate(_Reason, _State) ->
ok. ok.
code_change(_OldVsn, State, _Extra) -> code_change(_OldVsn, State, _Extra) ->
...@@ -246,8 +208,6 @@ update_downstreams(Downstreams, Tab) -> ...@@ -246,8 +208,6 @@ update_downstreams(Downstreams, Tab) ->
fun(DcId) -> fun(DcId) ->
case get_downstream_pool(DcId) of case get_downstream_pool(DcId) of
not_found -> not_found ->
%% process will be registered asynchronously by
%% gen_server:start_link({via, ..
{ok, _Pid} = mtp_dc_pool_sup:start_pool(DcId); {ok, _Pid} = mtp_dc_pool_sup:start_pool(DcId);
{ok, _} -> {ok, _} ->
ok ok
......
...@@ -19,7 +19,9 @@ ...@@ -19,7 +19,9 @@
return/2, return/2,
add_connection/1, add_connection/1,
ack_connected/2, ack_connected/2,
status/1]). status/1,
valid_dc_id/1,
dc_to_pool_name/1]).
%% gen_server callbacks %% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, -export([init/1, handle_call/3, handle_cast/2, handle_info/2,
...@@ -50,7 +52,17 @@ ...@@ -50,7 +52,17 @@
%%% API %%% API
%%%=================================================================== %%%===================================================================
start_link(DcId) -> start_link(DcId) ->
gen_server:start_link({via, mtp_config, DcId}, ?MODULE, DcId, []). gen_server:start_link({local, dc_to_pool_name(DcId)}, ?MODULE, DcId, []).
valid_dc_id(DcId) ->
is_integer(DcId) andalso
-10 < DcId andalso
10 > DcId.
dc_to_pool_name(DcId) ->
valid_dc_id(DcId) orelse error(invalid_dc_id, [DcId]),
binary_to_atom(<<"mtp_dc_pool_", (integer_to_binary(DcId))/binary>>, utf8).
get(Pool, Upstream, #{addr := _} = Opts) -> get(Pool, Upstream, #{addr := _} = Opts) ->
gen_server:call(Pool, {get, Upstream, Opts}). gen_server:call(Pool, {get, Upstream, Opts}).
...@@ -257,6 +269,7 @@ ds_get(St) -> ...@@ -257,6 +269,7 @@ ds_get(St) ->
%% Return connection back to storage %% Return connection back to storage
-spec ds_return(downstream(), ds_store()) -> ds_store(). -spec ds_return(downstream(), ds_store()) -> ds_store().
ds_return(Pid, St) -> ds_return(Pid, St) ->
%% It may return 'undefined' if down_conn crashed
{ok, St1} = pid_psq:dec_priority(Pid, St), {ok, St1} = pid_psq:dec_priority(Pid, St),
St1. St1.
......
...@@ -39,7 +39,7 @@ ...@@ -39,7 +39,7 @@
transport :: transport(), transport :: transport(),
codec = ident :: mtp_layer:layer(), codec = ident :: mtp_layer:layer(),
down :: gen_tcp:socket(), down :: mtp_down_conn:handle(),
dc_id :: integer(), dc_id :: integer(),
ad_tag :: binary(), ad_tag :: binary(),
...@@ -113,7 +113,7 @@ handle_call(_Request, _From, State) -> ...@@ -113,7 +113,7 @@ handle_call(_Request, _From, State) ->
Reply = ok, Reply = ok,
{reply, Reply, State}. {reply, Reply, State}.
handle_cast({proxy_ans, Down, Data}, #state{down = Down, listener = Listener} = S) -> handle_cast({proxy_ans, Down, Data}, #state{down = Down} = S) ->
%% telegram server -> proxy %% telegram server -> proxy
case up_send(Data, S) of case up_send(Data, S) of
{ok, S1} -> {ok, S1} ->
...@@ -223,8 +223,7 @@ state_timeout(stop) -> ...@@ -223,8 +223,7 @@ state_timeout(stop) ->
%% Handle telegram client -> proxy stream %% Handle telegram client -> proxy stream
handle_upstream_data(Bin, #state{stage = tunnel, handle_upstream_data(Bin, #state{stage = tunnel,
codec = UpCodec, codec = UpCodec} = S) ->
listener = Listener} = S) ->
{ok, S3, UpCodec1} = {ok, S3, UpCodec1} =
mtp_layer:fold_packets( mtp_layer:fold_packets(
fun(Decoded, S1) -> fun(Decoded, S1) ->
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment