Refactor codecs

* Get rid of per-codec internal read buffer
parent 6590a56e
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
-dialyzer(no_improper_lists). -dialyzer(no_improper_lists).
-record(st, -record(st,
{buffer = <<>> :: binary()}). {}).
-define(MAX_PACKET_SIZE, 1 * 1024 * 1024). % 1mb -define(MAX_PACKET_SIZE, 1 * 1024 * 1024). % 1mb
-define(APP, mtproto_proxy). -define(APP, mtproto_proxy).
...@@ -25,33 +25,31 @@ ...@@ -25,33 +25,31 @@
new() -> new() ->
#st{}. #st{}.
-spec try_decode_packet(binary(), codec()) -> {ok, binary(), codec()} -spec try_decode_packet(binary(), codec()) -> {ok, binary(), binary(), codec()}
| {incomplete, codec()}. | {incomplete, codec()}.
try_decode_packet(<<Flag, Len:24/unsigned-little-integer, Rest/binary>> = Data, try_decode_packet(<<Flag, Len:24/unsigned-little-integer, Rest/binary>>,
#st{buffer = <<>>} = St) when Flag == 127; Flag == 255 -> #st{} = St) when Flag == 127; Flag == 255 ->
Len1 = Len * 4, Len1 = Len * 4,
try_decode_packet_len(Len1, Rest, Data, St); try_decode_packet_len(Len1, Rest, St);
try_decode_packet(<<Len, Rest/binary>> = Data, try_decode_packet(<<Len, Rest/binary>>,
#st{buffer = <<>>} = St) when Len >= 128 -> #st{} = St) when Len >= 128 ->
Len1 = (Len - 128) * 4, Len1 = (Len - 128) * 4,
try_decode_packet_len(Len1, Rest, Data, St); try_decode_packet_len(Len1, Rest, St);
try_decode_packet(<<Len, Rest/binary>> = Data, try_decode_packet(<<Len, Rest/binary>>,
#st{buffer = <<>>} = St) when Len < 127 -> #st{} = St) when Len < 127 ->
Len1 = Len * 4, Len1 = Len * 4,
try_decode_packet_len(Len1, Rest, Data, St); try_decode_packet_len(Len1, Rest, St);
try_decode_packet(Bin, #st{buffer = Buf} = St) when byte_size(Buf) > 0 -> try_decode_packet(_, St) ->
try_decode_packet(<<Buf/binary, Bin/binary>>, St#st{buffer = <<>>}); {incomplete, St}.
try_decode_packet(Bin, #st{buffer = <<>>} = St) ->
{incomplete, St#st{buffer = Bin}}.
try_decode_packet_len(Len, LenStripped, Data, St) -> try_decode_packet_len(Len, LenStripped, St) ->
(Len < ?MAX_PACKET_SIZE) (Len < ?MAX_PACKET_SIZE)
orelse error({protocol_error, abridged_max_size, Len}), orelse error({protocol_error, abridged_max_size, Len}),
case LenStripped of case LenStripped of
<<Packet:Len/binary, Rest/binary>> -> <<Packet:Len/binary, Rest/binary>> ->
{ok, Packet, St#st{buffer = Rest}}; {ok, Packet, Rest, St};
_ -> _ ->
{incomplete, St#st{buffer = Data}} {incomplete, St}
end. end.
-spec encode_packet(binary(), codec()) -> {iodata(), codec()}. -spec encode_packet(binary(), codec()) -> {iodata(), codec()}.
......
...@@ -18,8 +18,7 @@ ...@@ -18,8 +18,7 @@
-export_type([codec/0]). -export_type([codec/0]).
-record(baes_st, -record(baes_st,
{decode_buf :: binary(), {block_size :: pos_integer(),
block_size :: pos_integer(),
encrypt :: any(), % aes state encrypt :: any(), % aes state
decrypt :: any() % aes state decrypt :: any() % aes state
}). }).
...@@ -30,7 +29,6 @@ ...@@ -30,7 +29,6 @@
new(EncKey, EncIv, DecKey, DecIv, BlockSize) -> new(EncKey, EncIv, DecKey, DecIv, BlockSize) ->
#baes_st{ #baes_st{
decode_buf = <<>>,
block_size = BlockSize, block_size = BlockSize,
encrypt = {EncKey, EncIv}, encrypt = {EncKey, EncIv},
decrypt = {DecKey, DecIv} decrypt = {DecKey, DecIv}
...@@ -45,40 +43,37 @@ encrypt(Data, #baes_st{block_size = BSize, ...@@ -45,40 +43,37 @@ encrypt(Data, #baes_st{block_size = BSize,
{Encrypted, S#baes_st{encrypt = {EncKey, crypto:next_iv(aes_cbc, Encrypted)}}}. {Encrypted, S#baes_st{encrypt = {EncKey, crypto:next_iv(aes_cbc, Encrypted)}}}.
-spec decrypt(binary(), codec()) -> {binary(), codec()}. -spec decrypt(binary(), codec()) -> {Data :: binary(), Tail :: binary(), codec()}.
decrypt(Data, #baes_st{block_size = BSize, decrypt(Data, #baes_st{block_size = BSize} = S) ->
decode_buf = <<>>} = S) ->
Size = byte_size(Data), Size = byte_size(Data),
Div = Size div BSize, Div = Size div BSize,
Rem = Size rem BSize, Rem = Size rem BSize,
case {Div, Rem} of case {Div, Rem} of
{0, _} -> {0, _} ->
%% Not enough bytes %% Not enough bytes
{<<>>, S#baes_st{decode_buf = Data}}; {<<>>, Data, S};
{_, 0} -> {_, 0} ->
%% Aligned %% Aligned
do_decrypt(Data, S); do_decrypt(Data, <<>>, S);
{_, Tail} -> {_, Tail} ->
%% N blocks + reminder %% N blocks + reminder
Head = Size - Tail, Head = Size - Tail,
<<ToDecode:Head/binary, Reminder/binary>> = Data, <<ToDecode:Head/binary, Reminder/binary>> = Data,
do_decrypt(ToDecode, S#baes_st{decode_buf = Reminder}) do_decrypt(ToDecode, Reminder, S)
end; end.
decrypt(Data, #baes_st{decode_buf = Buf} = S) ->
decrypt(<<Buf/binary, Data/binary>>, S#baes_st{decode_buf = <<>>}).
do_decrypt(Data, #baes_st{decrypt = {DecKey, DecIv}} = S) -> do_decrypt(Data, Tail, #baes_st{decrypt = {DecKey, DecIv}} = S) ->
Decrypted = crypto:block_decrypt(aes_cbc, DecKey, DecIv, Data), Decrypted = crypto:block_decrypt(aes_cbc, DecKey, DecIv, Data),
NewDecIv = crypto:next_iv(aes_cbc, Data), NewDecIv = crypto:next_iv(aes_cbc, Data),
{Decrypted, S#baes_st{decrypt = {DecKey, NewDecIv}}}. {Decrypted, Tail, S#baes_st{decrypt = {DecKey, NewDecIv}}}.
%% To comply mtp_layer interface %% To comply mtp_layer interface
try_decode_packet(Bin, S) -> try_decode_packet(Bin, S) ->
case decrypt(Bin, S) of case decrypt(Bin, S) of
{<<>>, S1} -> {<<>>, _Tail, S1} ->
{incomplete, S1}; {incomplete, S1};
{Dec, S1} -> {Dec, Tail, S1} ->
{ok, Dec, S1} {ok, Dec, Tail, S1}
end. end.
encode_packet(Bin, S) -> encode_packet(Bin, S) ->
...@@ -107,7 +102,7 @@ decrypt_test() -> ...@@ -107,7 +102,7 @@ decrypt_test() ->
], ],
lists:foldl( lists:foldl(
fun({In, Out}, S1) -> fun({In, Out}, S1) ->
{Dec, S2} = decrypt(In, S1), {Dec, <<>>, S2} = decrypt(In, S1),
?assertEqual(Out, Dec), ?assertEqual(Out, Dec),
S2 S2
end, S, Samples). end, S, Samples).
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
decompose/1, decompose/1,
try_decode_packet/2, try_decode_packet/2,
encode_packet/2, encode_packet/2,
fold_packets/4]). fold_packets/4,
is_empty/1]).
-export_type([codec/0]). -export_type([codec/0]).
-type state() :: any(). -type state() :: any().
...@@ -30,13 +31,15 @@ ...@@ -30,13 +31,15 @@
-record(codec, -record(codec,
{crypto_mod :: crypto_codec(), {crypto_mod :: crypto_codec(),
crypto_state :: any(), crypto_state :: any(),
crypto_buf = <<>> :: binary(),
packet_mod :: packet_codec(), packet_mod :: packet_codec(),
packet_state :: any()}). packet_state :: any(),
packet_buf = <<>> :: binary()}).
-define(APP, mtproto_proxy). -define(APP, mtproto_proxy).
-callback try_decode_packet(binary(), state()) -> -callback try_decode_packet(binary(), state()) ->
{ok, binary(), state()} {ok, Packet :: binary(), Tail :: binary(), state()}
| {incomplete, state()}. | {incomplete, state()}.
-callback encode_packet(iodata(), state()) -> -callback encode_packet(iodata(), state()) ->
...@@ -60,26 +63,48 @@ decompose(#codec{crypto_mod = CryptoMod, crypto_state = CryptoState, ...@@ -60,26 +63,48 @@ decompose(#codec{crypto_mod = CryptoMod, crypto_state = CryptoState,
%% try_decode_packet(Inner) |> try_decode_packet(Outer) %% try_decode_packet(Inner) |> try_decode_packet(Outer)
-spec try_decode_packet(binary(), codec()) -> {ok, binary(), codec()} | {incomplete, codec()}. -spec try_decode_packet(binary(), codec()) -> {ok, binary(), codec()} | {incomplete, codec()}.
try_decode_packet(Bin, #codec{crypto_mod = CryptoMod, try_decode_packet(Bin, S) ->
crypto_state = CryptoSt, decode_crypto(Bin, S).
packet_mod = PacketMod,
packet_state = PacketSt} = S) -> decode_crypto(<<>>, #codec{crypto_state = CS, crypto_buf = <<>>} = S) ->
{Dec1, CryptoSt1} = %% There is smth in packet buffer
case CryptoMod:try_decode_packet(Bin, CryptoSt) of decode_packet(<<>>, CS, <<>>, S);
{incomplete, PacketSt1_} -> decode_crypto(Bin, #codec{crypto_mod = CryptoMod,
%% We have to check if something is left in packet's buffers crypto_state = CryptoSt,
{<<>>, PacketSt1_}; crypto_buf = <<>>} = S) ->
{ok, Dec1_, PacketSt1_} -> case CryptoMod:try_decode_packet(Bin, CryptoSt) of
{Dec1_, PacketSt1_} {incomplete, CryptoSt1} ->
end, decode_packet(<<>>, CryptoSt1, <<>>, S);
case PacketMod:try_decode_packet(Dec1, PacketSt) of {ok, Dec1, Tail1, CryptoSt1} ->
decode_packet(Dec1, CryptoSt1, Tail1, S)
end;
decode_crypto(Bin, #codec{crypto_buf = Buf} = S) ->
decode_crypto(<<Buf/binary, Bin/binary>>, S#codec{crypto_buf = <<>>}).
decode_packet(<<>>, CryptoSt, CryptoTail, #codec{packet_buf = <<>>} = S) ->
%% Crypto produced nothing and there is nothing in packet buf
{incomplete, S#codec{crypto_state = CryptoSt, crypto_buf = CryptoTail}};
decode_packet(Bin, CryptoSt, CryptoTail, #codec{packet_mod = PacketMod,
packet_state = PacketSt,
packet_buf = <<>>} = S) ->
%% Crypto produced smth, and there is nothing in pkt buf
case PacketMod:try_decode_packet(Bin, PacketSt) of
{incomplete, PacketSt1} -> {incomplete, PacketSt1} ->
{incomplete, S#codec{crypto_state = CryptoSt1, {incomplete, S#codec{crypto_state = CryptoSt,
packet_state = PacketSt1}}; crypto_buf = CryptoTail,
{ok, Dec2, PacketSt1} -> packet_state = PacketSt1,
{ok, Dec2, S#codec{crypto_state = CryptoSt1, packet_buf = Bin
packet_state = PacketSt1}} }};
end. {ok, Dec2, Tail, PacketSt1} ->
{ok, Dec2, S#codec{crypto_state = CryptoSt,
crypto_buf = CryptoTail,
packet_state = PacketSt1,
packet_buf = Tail}}
end;
decode_packet(Bin, CSt, CTail, #codec{packet_buf = Buf} = S) ->
decode_packet(<<Buf/binary, Bin/binary>>, CSt, CTail, S#codec{packet_buf = <<>>}).
%% encode_packet(Outer) |> encode_packet(Inner) %% encode_packet(Outer) |> encode_packet(Inner)
-spec encode_packet(iodata(), codec()) -> {iodata(), codec()}. -spec encode_packet(iodata(), codec()) -> {iodata(), codec()}.
...@@ -105,3 +130,7 @@ fold_packets(Fun, FoldSt, Data, Codec) -> ...@@ -105,3 +130,7 @@ fold_packets(Fun, FoldSt, Data, Codec) ->
{incomplete, Codec1} -> {incomplete, Codec1} ->
{ok, FoldSt, Codec1} {ok, FoldSt, Codec1}
end. end.
-spec is_empty(codec()) -> boolean().
is_empty(#codec{packet_buf = <<>>, crypto_buf = <<>>}) -> true;
is_empty(_) -> false.
...@@ -19,8 +19,7 @@ ...@@ -19,8 +19,7 @@
-dialyzer(no_improper_lists). -dialyzer(no_improper_lists).
-record(full_st, -record(full_st,
{decode_buf = <<>> :: binary(), {enc_seq_no :: integer(),
enc_seq_no :: integer(),
dec_seq_no :: integer()}). dec_seq_no :: integer()}).
-define(MIN_MSG_LEN, 12). -define(MIN_MSG_LEN, 12).
-define(MAX_MSG_LEN, 16777216). %2^24 - 16mb -define(MAX_MSG_LEN, 16777216). %2^24 - 16mb
...@@ -39,11 +38,8 @@ new(EncSeqNo, DecSeqNo) -> ...@@ -39,11 +38,8 @@ new(EncSeqNo, DecSeqNo) ->
#full_st{enc_seq_no = EncSeqNo, #full_st{enc_seq_no = EncSeqNo,
dec_seq_no = DecSeqNo}. dec_seq_no = DecSeqNo}.
try_decode_packet(<<4:32/little, Bin/binary>>, #full_st{decode_buf = <<>>} = S) -> try_decode_packet(<<Len:32/little, PktSeqNo:32/signed-little, Tail/binary>>,
%% Skip padding #full_st{dec_seq_no = SeqNo} = S) ->
try_decode_packet(Bin, S);
try_decode_packet(<<Len:32/little, PktSeqNo:32/signed-little, Tail/binary>> = Bin,
#full_st{decode_buf = <<>>, dec_seq_no = SeqNo} = S) ->
((Len rem byte_size(?PAD)) == 0) ((Len rem byte_size(?PAD)) == 0)
orelse error({wrong_alignement, Len}), orelse error({wrong_alignement, Len}),
((?MIN_MSG_LEN =< Len) and (Len =< ?MAX_MSG_LEN)) ((?MIN_MSG_LEN =< Len) and (Len =< ?MAX_MSG_LEN))
...@@ -56,14 +52,18 @@ try_decode_packet(<<Len:32/little, PktSeqNo:32/signed-little, Tail/binary>> = Bi ...@@ -56,14 +52,18 @@ try_decode_packet(<<Len:32/little, PktSeqNo:32/signed-little, Tail/binary>> = Bi
PacketCrc = erlang:crc32([<<Len:32/little, PktSeqNo:32/little>> | Body]), PacketCrc = erlang:crc32([<<Len:32/little, PktSeqNo:32/little>> | Body]),
(CRC == PacketCrc) (CRC == PacketCrc)
orelse error({wrong_checksum, CRC, PacketCrc}), orelse error({wrong_checksum, CRC, PacketCrc}),
{ok, Body, S#full_st{decode_buf = Rest, dec_seq_no = SeqNo + 1}}; {ok, Body, skip_padding(Len, Rest), S#full_st{dec_seq_no = SeqNo + 1}};
_ -> _ ->
{incomplete, S#full_st{decode_buf = Bin}} {incomplete, S}
end; end;
try_decode_packet(Bin, #full_st{decode_buf = Buf} = S) when byte_size(Buf) > 0 -> try_decode_packet(_, S) ->
try_decode_packet(<<Buf/binary, Bin/binary>>, S#full_st{decode_buf = <<>>}); {incomplete, S}.
try_decode_packet(Bin, #full_st{decode_buf = <<>>} = S) ->
{incomplete, S#full_st{decode_buf = Bin}}. skip_padding(PktLen, Bin) ->
PaddingSize = padding_size(PktLen),
<<_:PaddingSize/binary, Tail/binary>> = Bin,
Tail.
encode_packet(Bin, #full_st{enc_seq_no = SeqNo} = S) -> encode_packet(Bin, #full_st{enc_seq_no = SeqNo} = S) ->
BodySize = iolist_size(Bin), BodySize = iolist_size(Bin),
...@@ -77,12 +77,14 @@ encode_packet(Bin, #full_st{enc_seq_no = SeqNo} = S) -> ...@@ -77,12 +77,14 @@ encode_packet(Bin, #full_st{enc_seq_no = SeqNo} = S) ->
CheckSum = erlang:crc32(MsgNoChecksum), CheckSum = erlang:crc32(MsgNoChecksum),
FullMsg = [MsgNoChecksum | <<CheckSum:32/unsigned-little-integer>>], FullMsg = [MsgNoChecksum | <<CheckSum:32/unsigned-little-integer>>],
Len = iolist_size(FullMsg), Len = iolist_size(FullMsg),
%% XXX: is there a cleaner way? NPaddings = padding_size(Len) div byte_size(?PAD),
PaddingSize = (?BLOCK_SIZE - (Len rem ?BLOCK_SIZE)) rem ?BLOCK_SIZE,
NPaddings = PaddingSize div byte_size(?PAD),
Padding = lists:duplicate(NPaddings, ?PAD), Padding = lists:duplicate(NPaddings, ?PAD),
{[FullMsg | Padding], S#full_st{enc_seq_no = SeqNo + 1}}. {[FullMsg | Padding], S#full_st{enc_seq_no = SeqNo + 1}}.
padding_size(Len) ->
%% XXX: is there a cleaner way?
(?BLOCK_SIZE - (Len rem ?BLOCK_SIZE)) rem ?BLOCK_SIZE.
-ifdef(TEST). -ifdef(TEST).
-include_lib("eunit/include/eunit.hrl"). -include_lib("eunit/include/eunit.hrl").
...@@ -132,6 +134,7 @@ decode_none_test() -> ...@@ -132,6 +134,7 @@ decode_none_test() ->
{incomplete, S}, try_decode_packet(<<>>, S)). {incomplete, S}, try_decode_packet(<<>>, S)).
codec_test() -> codec_test() ->
%% Overhead is 12b per-packet
S = new(), S = new(),
Packets = [ Packets = [
binary:copy(<<0>>, 4), %non-padded binary:copy(<<0>>, 4), %non-padded
...@@ -143,7 +146,7 @@ codec_test() -> ...@@ -143,7 +146,7 @@ codec_test() ->
fun(B, S1) -> fun(B, S1) ->
{Encoded, S2} = encode_packet(B, S1), {Encoded, S2} = encode_packet(B, S1),
BinEncoded = iolist_to_binary(Encoded), BinEncoded = iolist_to_binary(Encoded),
{ok, Decoded, S3} = try_decode_packet(BinEncoded, S2), {ok, Decoded, <<>>, S3} = try_decode_packet(BinEncoded, S2),
?assertEqual(B, Decoded, {BinEncoded, S2, S3}), ?assertEqual(B, Decoded, {BinEncoded, S2, S3}),
S3 S3
end, S, Packets). end, S, Packets).
...@@ -164,9 +167,9 @@ codec_stream_test() -> ...@@ -164,9 +167,9 @@ codec_stream_test() ->
end, {[], S}, Packets), end, {[], S}, Packets),
lists:foldl( lists:foldl(
fun(B, {Enc, S1}) -> fun(B, {Enc, S1}) ->
{ok, Dec, S2} = try_decode_packet(Enc, S1), {ok, Dec, Rest, S2} = try_decode_packet(Enc, S1),
?assertEqual(B, Dec), ?assertEqual(B, Dec),
{<<>>, S2} {Rest, S2}
end, {iolist_to_binary(Encoded), SS}, Packets). end, {iolist_to_binary(Encoded), SS}, Packets).
-endif. -endif.
...@@ -18,8 +18,7 @@ ...@@ -18,8 +18,7 @@
-dialyzer(no_improper_lists). -dialyzer(no_improper_lists).
-record(int_st, -record(int_st,
{padding = false :: boolean(), {padding = false :: boolean()}).
buffer = <<>> :: binary()}).
-define(MAX_PACKET_SIZE, 1 * 1024 * 1024). % 1mb -define(MAX_PACKET_SIZE, 1 * 1024 * 1024). % 1mb
-define(APP, mtproto_proxy). -define(APP, mtproto_proxy).
-define(MAX_SIZE, 16#80000000). -define(MAX_SIZE, 16#80000000).
...@@ -32,10 +31,9 @@ new() -> ...@@ -32,10 +31,9 @@ new() ->
new(Opts) -> new(Opts) ->
#int_st{padding = maps:get(padding, Opts, false)}. #int_st{padding = maps:get(padding, Opts, false)}.
-spec try_decode_packet(binary(), codec()) -> {ok, binary(), codec()} -spec try_decode_packet(binary(), codec()) -> {ok, binary(), binary(), codec()}
| {incomplete, codec()}. | {incomplete, codec()}.
try_decode_packet(<<Len:32/unsigned-little, _/binary>> = Data, try_decode_packet(<<Len:32/unsigned-little, Tail/binary>>, St) ->
#int_st{buffer = <<>>} = St) ->
Len1 = case Len < ?MAX_SIZE of Len1 = case Len < ?MAX_SIZE of
true -> Len; true -> Len;
false -> Len - ?MAX_SIZE false -> Len - ?MAX_SIZE
...@@ -45,11 +43,9 @@ try_decode_packet(<<Len:32/unsigned-little, _/binary>> = Data, ...@@ -45,11 +43,9 @@ try_decode_packet(<<Len:32/unsigned-little, _/binary>> = Data,
begin begin
error({protocol_error, intermediate_max_size, Len1}) error({protocol_error, intermediate_max_size, Len1})
end, end,
try_decode_packet_len(Len1, Data, St); try_decode_packet_len(Len1, Tail, St);
try_decode_packet(Bin, #int_st{buffer = Buf} = St) when byte_size(Buf) > 0 -> try_decode_packet(_, St) ->
try_decode_packet(<<Buf/binary, Bin/binary>>, St#int_st{buffer = <<>>}); {incomplete, St}.
try_decode_packet(Bin, #int_st{buffer = <<>>} = St) ->
{incomplete, St#int_st{buffer = Bin}}.
try_decode_packet_len(Len, Data, #int_st{padding = Pad} = St) -> try_decode_packet_len(Len, Data, #int_st{padding = Pad} = St) ->
Padding = case Pad of Padding = case Pad of
...@@ -58,10 +54,10 @@ try_decode_packet_len(Len, Data, #int_st{padding = Pad} = St) -> ...@@ -58,10 +54,10 @@ try_decode_packet_len(Len, Data, #int_st{padding = Pad} = St) ->
end, end,
NopadLen = Len - Padding, NopadLen = Len - Padding,
case Data of case Data of
<<_:4/binary, Packet:NopadLen/binary, _Padding:Padding/binary, Rest/binary>> -> <<Packet:NopadLen/binary, _Padding:Padding/binary, Rest/binary>> ->
{ok, Packet, St#int_st{buffer = Rest}}; {ok, Packet, Rest, St};
_ -> _ ->
{incomplete, St#int_st{buffer = Data}} {incomplete, St}
end. end.
-spec encode_packet(iodata(), codec()) -> {iodata(), codec()}. -spec encode_packet(iodata(), codec()) -> {iodata(), codec()}.
......
...@@ -19,9 +19,9 @@ ...@@ -19,9 +19,9 @@
new() -> new() ->
?MODULE. ?MODULE.
-spec try_decode_packet(binary(), codec()) -> {ok, binary(), codec()}. -spec try_decode_packet(binary(), codec()) -> {ok, binary(), binary(), codec()}.
try_decode_packet(Data, ?MODULE) -> try_decode_packet(Data, ?MODULE) ->
{ok, Data, ?MODULE}. {ok, Data, <<>>, ?MODULE}.
-spec encode_packet(binary(), codec()) -> {binary(), codec()}. -spec encode_packet(binary(), codec()) -> {binary(), codec()}.
encode_packet(Data, ?MODULE) -> encode_packet(Data, ?MODULE) ->
......
...@@ -119,7 +119,7 @@ from_header(Header, Secret, AllowedProtocols) when byte_size(Header) == 64 -> ...@@ -119,7 +119,7 @@ from_header(Header, Secret, AllowedProtocols) when byte_size(Header) == 64 ->
{EncKey, EncIV} = init_up_encrypt(Header, Secret), {EncKey, EncIV} = init_up_encrypt(Header, Secret),
{DecKey, DecIV} = init_up_decrypt(Header, Secret), {DecKey, DecIV} = init_up_decrypt(Header, Secret),
St = new(EncKey, EncIV, DecKey, DecIV), St = new(EncKey, EncIV, DecKey, DecIV),
{<<_:56/binary, Bin1:6/binary, _:2/binary>>, St1} = decrypt(Header, St), {<<_:56/binary, Bin1:6/binary, _:2/binary>>, <<>>, St1} = decrypt(Header, St),
case get_protocol(Bin1) of case get_protocol(Bin1) of
{error, unknown_protocol} = Err -> {error, unknown_protocol} = Err ->
Err; Err;
...@@ -168,17 +168,17 @@ encrypt(Data, #st{encrypt = Enc} = St) -> ...@@ -168,17 +168,17 @@ encrypt(Data, #st{encrypt = Enc} = St) ->
{Enc1, Encrypted} = crypto:stream_encrypt(Enc, Data), {Enc1, Encrypted} = crypto:stream_encrypt(Enc, Data),
{Encrypted, St#st{encrypt = Enc1}}. {Encrypted, St#st{encrypt = Enc1}}.
-spec decrypt(iodata(), codec()) -> {binary(), codec()}. -spec decrypt(iodata(), codec()) -> {binary(), binary(), codec()}.
decrypt(Encrypted, #st{decrypt = Dec} = St) -> decrypt(Encrypted, #st{decrypt = Dec} = St) ->
{Dec1, Data} = crypto:stream_encrypt(Dec, Encrypted), {Dec1, Data} = crypto:stream_encrypt(Dec, Encrypted),
{Data, St#st{decrypt = Dec1}}. {Data, <<>>, St#st{decrypt = Dec1}}.
%% To comply with mtp_layer interface %% To comply with mtp_layer interface
-spec try_decode_packet(iodata(), codec()) -> {ok, Decoded :: binary(), codec()} -spec try_decode_packet(iodata(), codec()) -> {ok, Decoded :: binary(), Tail :: binary(), codec()}
| {incomplete, codec()}. | {incomplete, codec()}.
try_decode_packet(Encrypted, St) -> try_decode_packet(Encrypted, St) ->
{Decrypted, St1} = decrypt(Encrypted, St), {Decrypted, Tail, St1} = decrypt(Encrypted, St),
{ok, Decrypted, St1}. {ok, Decrypted, Tail, St1}.
-spec encode_packet(iodata(), codec()) -> {iodata(), codec()}. -spec encode_packet(iodata(), codec()) -> {iodata(), codec()}.
encode_packet(Msg, S) -> encode_packet(Msg, S) ->
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
new() -> new() ->
mtp_intermediate:new(#{padding => true}). mtp_intermediate:new(#{padding => true}).
-spec try_decode_packet(binary(), codec()) -> {ok, binary(), codec()} -spec try_decode_packet(binary(), codec()) -> {ok, binary(), binary(), codec()}
| {incomplete, codec()}. | {incomplete, codec()}.
try_decode_packet(Data, St) -> try_decode_packet(Data, St) ->
mtp_intermediate:try_decode_packet(Data, St). mtp_intermediate:try_decode_packet(Data, St).
......
...@@ -13,7 +13,7 @@ prop_codec() -> ...@@ -13,7 +13,7 @@ prop_codec() ->
codec(Bin) -> codec(Bin) ->
Codec = mtp_abridged:new(), Codec = mtp_abridged:new(),
{Data, Codec1} = mtp_abridged:encode_packet(Bin, Codec), {Data, Codec1} = mtp_abridged:encode_packet(Bin, Codec),
{ok, Decoded, _} = mtp_abridged:try_decode_packet(iolist_to_binary(Data), Codec1), {ok, Decoded, <<>>, _} = mtp_abridged:try_decode_packet(iolist_to_binary(Data), Codec1),
Decoded == Bin. Decoded == Bin.
...@@ -39,6 +39,6 @@ decode_stream(BinStream, Codec, Acc) -> ...@@ -39,6 +39,6 @@ decode_stream(BinStream, Codec, Acc) ->
case mtp_abridged:try_decode_packet(BinStream, Codec) of case mtp_abridged:try_decode_packet(BinStream, Codec) of
{incomplete, _} -> {incomplete, _} ->
lists:reverse(Acc); lists:reverse(Acc);
{ok, DecPacket, Codec1} -> {ok, DecPacket, Tail, Codec1} ->
decode_stream(<<>>, Codec1, [DecPacket | Acc]) decode_stream(Tail, Codec1, [DecPacket | Acc])
end. end.
...@@ -28,6 +28,6 @@ stream_codec(Key, Iv, Stream) -> ...@@ -28,6 +28,6 @@ stream_codec(Key, Iv, Stream) ->
{<<Acc/binary, (iolist_to_binary(Data))/binary>>, {<<Acc/binary, (iolist_to_binary(Data))/binary>>,
Codec2} Codec2}
end, {<<>>, Codec}, Stream), end, {<<>>, Codec}, Stream),
{Decrypted, _Codec3} = mtp_aes_cbc:decrypt(BinStream, Codec2), {Decrypted, <<>>, _Codec3} = mtp_aes_cbc:decrypt(BinStream, Codec2),
%% io:format("Dec: ~p~nOrig: ~p~nCodec: ~p~n", [Decrypted, Stream, _Codec3]), %% io:format("Dec: ~p~nOrig: ~p~nCodec: ~p~n", [Decrypted, Stream, _Codec3]),
Decrypted == iolist_to_binary(Stream). Decrypted == iolist_to_binary(Stream).
...@@ -14,7 +14,7 @@ prop_codec() -> ...@@ -14,7 +14,7 @@ prop_codec() ->
codec(Bin) -> codec(Bin) ->
Codec = mtp_full:new(), Codec = mtp_full:new(),
{Data, Codec1} = mtp_full:encode_packet(Bin, Codec), {Data, Codec1} = mtp_full:encode_packet(Bin, Codec),
{ok, Decoded, _} = mtp_full:try_decode_packet(iolist_to_binary(Data), Codec1), {ok, Decoded, <<>>, _} = mtp_full:try_decode_packet(iolist_to_binary(Data), Codec1),
Decoded == Bin. Decoded == Bin.
...@@ -40,6 +40,6 @@ decode_stream(BinStream, Codec, Acc) -> ...@@ -40,6 +40,6 @@ decode_stream(BinStream, Codec, Acc) ->
case mtp_full:try_decode_packet(BinStream, Codec) of case mtp_full:try_decode_packet(BinStream, Codec) of
{incomplete, _} -> {incomplete, _} ->
lists:reverse(Acc); lists:reverse(Acc);
{ok, DecPacket, Codec1} -> {ok, DecPacket, Tail, Codec1} ->
decode_stream(<<>>, Codec1, [DecPacket | Acc]) decode_stream(Tail, Codec1, [DecPacket | Acc])
end. end.
...@@ -14,7 +14,7 @@ prop_codec() -> ...@@ -14,7 +14,7 @@ prop_codec() ->
codec(Bin) -> codec(Bin) ->
Codec = mtp_intermediate:new(), Codec = mtp_intermediate:new(),
{Data, Codec1} = mtp_intermediate:encode_packet(Bin, Codec), {Data, Codec1} = mtp_intermediate:encode_packet(Bin, Codec),
{ok, Decoded, _} = mtp_intermediate:try_decode_packet(iolist_to_binary(Data), Codec1), {ok, Decoded, <<>>, _} = mtp_intermediate:try_decode_packet(iolist_to_binary(Data), Codec1),
Decoded == Bin. Decoded == Bin.
...@@ -40,8 +40,8 @@ decode_stream(BinStream, Codec, Acc) -> ...@@ -40,8 +40,8 @@ decode_stream(BinStream, Codec, Acc) ->
case mtp_intermediate:try_decode_packet(BinStream, Codec) of case mtp_intermediate:try_decode_packet(BinStream, Codec) of
{incomplete, _} -> {incomplete, _} ->
lists:reverse(Acc); lists:reverse(Acc);
{ok, DecPacket, Codec1} -> {ok, DecPacket, Tail, Codec1} ->
decode_stream(<<>>, Codec1, [DecPacket | Acc]) decode_stream(Tail, Codec1, [DecPacket | Acc])
end. end.
......
...@@ -30,7 +30,7 @@ stream_codec(Key, Iv, Stream) -> ...@@ -30,7 +30,7 @@ stream_codec(Key, Iv, Stream) ->
{<<Acc/binary, (iolist_to_binary(Data))/binary>>, {<<Acc/binary, (iolist_to_binary(Data))/binary>>,
Codec2} Codec2}
end, {<<>>, Codec}, Stream), end, {<<>>, Codec}, Stream),
{Decrypted, _Codec3} = mtp_obfuscated:decrypt(BinStream, Codec2), {Decrypted, <<>>, _Codec3} = mtp_obfuscated:decrypt(BinStream, Codec2),
%% io:format("Dec: ~p~nOrig: ~p~nCodec: ~p~n", [Decrypted, Stream, _Codec3]), %% io:format("Dec: ~p~nOrig: ~p~nCodec: ~p~n", [Decrypted, Stream, _Codec3]),
Decrypted == iolist_to_binary(Stream). Decrypted == iolist_to_binary(Stream).
...@@ -99,7 +99,7 @@ transmit_stream(EncCodec, DecCodec, Stream) -> ...@@ -99,7 +99,7 @@ transmit_stream(EncCodec, DecCodec, Stream) ->
{<<Acc/binary, (iolist_to_binary(Data))/binary>>, {<<Acc/binary, (iolist_to_binary(Data))/binary>>,
CliCodec2} CliCodec2}
end, {<<>>, EncCodec}, Stream), end, {<<>>, EncCodec}, Stream),
{Decrypted, DecCodec2} = mtp_obfuscated:decrypt(EncStream, DecCodec), {Decrypted, <<>>, DecCodec2} = mtp_obfuscated:decrypt(EncStream, DecCodec),
{EncCodec3, {EncCodec3,
DecCodec2, DecCodec2,
Decrypted}. Decrypted}.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment