Commit d3479a5b authored by levlam's avatar levlam

Allow to specify more than one TLS domain and check request SNI.

parent b2777943
......@@ -2232,12 +2232,12 @@ void mtfront_prepare_parse_options (void) {
parse_option ("http-stats", no_argument, 0, 2000, "allow http server to answer on stats queries");
parse_option ("mtproto-secret", required_argument, 0, 'S', "16-byte secret in hex mode");
parse_option ("proxy-tag", required_argument, 0, 'P', "16-byte proxy tag in hex mode to be passed along with all forwarded queries");
parse_option ("domain", required_argument, 0, 'D', "domain to which all requests unrecognized as TLS-transport requests will be proxied. If specified, value of 'slaves' option is ignored");
parse_option ("domain", required_argument, 0, 'D', "adds allowed domain for TLS-transport mode, disables other transports; can be specified more than once");
parse_option ("max-special-connections", required_argument, 0, 'C', "sets maximal number of accepted client connections per worker");
parse_option ("window-clamp", required_argument, 0, 'W', "sets window clamp for client TCP connections");
parse_option ("http-ports", required_argument, 0, 'H', "comma-separated list of client (HTTP) ports to listen");
// parse_option ("outbound-connections-ps", required_argument, 0, 'o', "limits creation rate of outbound connections to mtproto-servers (default %d)", DEFAULT_OUTBOUND_CONNECTION_CREATION_RATE);
parse_option ("slaves", required_argument, 0, 'M', "spawn several slave workers");
parse_option ("slaves", required_argument, 0, 'M', "spawn several slave workers; not supported for TLS-transport mode");
parse_option ("ping-interval", required_argument, 0, 'T', "sets ping interval in second for local TCP connections (default %.3lf)", PING_INTERVAL);
}
......
......@@ -115,16 +115,40 @@ struct domain_info {
short server_hello_encrypted_size;
char use_random_encrypted_size;
char is_reversed_extension_order;
struct domain_info *next;
};
static struct domain_info domain;
static struct domain_info *default_domain_info;
static int get_domain_server_hello_encrypted_size (const struct domain_info *domain) {
if (domain->use_random_encrypted_size) {
#define DOMAIN_HASH_MOD 257
static struct domain_info *domains[DOMAIN_HASH_MOD];
static struct domain_info **get_domain_info_bucket (const char *domain, size_t len) {
size_t i;
unsigned hash = 0;
for (i = 0; i < len; i++) {
hash = hash * 239017 + (unsigned char)domain[i];
}
return domains + hash % DOMAIN_HASH_MOD;
}
static const struct domain_info *get_domain_info (const char *domain, size_t len) {
struct domain_info *info = *get_domain_info_bucket (domain, len);
while (info != NULL) {
if (strlen (info->domain) == len && memcmp (domain, info->domain, len) == 0) {
return info;
}
info = info->next;
}
return NULL;
}
static int get_domain_server_hello_encrypted_size (const struct domain_info *info) {
if (info->use_random_encrypted_size) {
int r = rand();
return domain->server_hello_encrypted_size + ((r >> 1) & 1) - (r & 1);
return info->server_hello_encrypted_size + ((r >> 1) & 1) - (r & 1);
} else {
return domain->server_hello_encrypted_size;
return info->server_hello_encrypted_size;
}
}
......@@ -555,7 +579,7 @@ static int update_domain_info (struct domain_info *info) {
info->use_random_encrypted_size = 1;
}
vkprintf (1, "Successfully checked domain %s in %.3lf seconds: is_reversed_extension_order = %d, server_hello_encrypted_size = %d, use_random_encrypted_size = %d\n",
vkprintf (0, "Successfully checked domain %s in %.3lf seconds: is_reversed_extension_order = %d, server_hello_encrypted_size = %d, use_random_encrypted_size = %d\n",
domain, get_utime_monotonic() - (finish_time - 5.0), info->is_reversed_extension_order, info->server_hello_encrypted_size, info->use_random_encrypted_size);
if (info->is_reversed_extension_order && info->server_hello_encrypted_size <= 1250) {
kprintf ("Multiple encrypted client data packets are unsupported, so handshake with %s will not be fully emulated\n", domain);
......@@ -566,23 +590,85 @@ static int update_domain_info (struct domain_info *info) {
#undef TLS_REQUEST_LENGTH
void tcp_rpc_add_proxy_domain (const char *domain_url) {
assert (domain_url != NULL);
allow_only_tls = 1;
static const struct domain_info *get_sni_domain_info (const unsigned char *request, int len) {
#define CHECK_LENGTH(length) \
if (pos + (length) > len) { \
return NULL; \
}
int pos = 11 + 32 + 1 + 32;
CHECK_LENGTH(2);
int cipher_suites_length = read_length (request, &pos);
CHECK_LENGTH(cipher_suites_length + 4);
pos += cipher_suites_length + 4;
while (1) {
CHECK_LENGTH(4);
int extension_id = read_length (request, &pos);
int extension_length = read_length (request, &pos);
CHECK_LENGTH(extension_length);
if (extension_id == 0) {
// found SNI
CHECK_LENGTH(5);
int inner_length = read_length (request, &pos);
if (inner_length != extension_length - 2) {
return NULL;
}
if (request[pos++] != 0) {
return NULL;
}
int domain_length = read_length (request, &pos);
if (domain_length != extension_length - 5) {
return NULL;
}
int i;
for (i = 0; i < domain_length; i++) {
if (request[pos + i] == 0) {
return NULL;
}
}
const struct domain_info *info = get_domain_info ((const char *)(request + pos), domain_length);
if (info == NULL) {
vkprintf (1, "Receive request for unknown domain %.*s\n", domain_length, request + pos);
}
return info;
}
domain.domain = strdup (domain_url);
pos += extension_length;
}
#undef CHECK_LENGTH
}
void tcp_rpc_add_proxy_domain (const char *domain) {
assert (domain != NULL);
struct domain_info *info = malloc (sizeof (struct domain_info));
info->domain = strdup (domain);
struct domain_info **bucket = get_domain_info_bucket (domain, strlen (domain));
info->next = *bucket;
*bucket = info;
if (!allow_only_tls) {
allow_only_tls = 1;
default_domain_info = info;
}
}
void tcp_rpc_init_proxy_domains() {
if (domain.domain == NULL) {
return;
int i;
for (i = 0; i < DOMAIN_HASH_MOD; i++) {
struct domain_info *info = domains[i];
while (info != NULL) {
if (!update_domain_info (info)) {
kprintf ("Failed to update response data about %s, so default response settings wiil be used\n", info->domain);
info->is_reversed_extension_order = 0;
info->use_random_encrypted_size = 1;
info->server_hello_encrypted_size = 2500 + rand() % 1120;
}
if (!update_domain_info (&domain)) {
kprintf ("Failed to update response data about %s, so default response settings wiil be used\n", domain.domain);
domain.is_reversed_extension_order = 0;
domain.use_random_encrypted_size = 1;
domain.server_hello_encrypted_size = 2500 + rand() % 1120;
info = info->next;
}
}
}
......@@ -599,7 +685,7 @@ static struct client_random *client_randoms[1 << RANDOM_HASH_BITS];
static struct client_random *first_client_random;
static struct client_random *last_client_random;
static struct client_random **get_bucket (unsigned char random[16]) {
static struct client_random **get_client_random_bucket (unsigned char random[16]) {
int i = RANDOM_HASH_BITS;
int pos = 0;
int id = 0;
......@@ -613,7 +699,7 @@ static struct client_random **get_bucket (unsigned char random[16]) {
}
static int have_client_random (unsigned char random[16]) {
struct client_random *cur = *get_bucket (random);
struct client_random *cur = *get_client_random_bucket (random);
while (cur != NULL) {
if (memcmp (random, cur->random, 16) == 0) {
return 1;
......@@ -636,7 +722,7 @@ static void add_client_random (unsigned char random[16]) {
last_client_random = entry;
}
struct client_random **bucket = get_bucket (random);
struct client_random **bucket = get_client_random_bucket (random);
entry->next_by_hash = *bucket;
*bucket = entry;
}
......@@ -655,7 +741,7 @@ static void delete_old_client_randoms() {
first_client_random = first_client_random->next_by_time;
struct client_random **cur = get_bucket (entry->random);
struct client_random **cur = get_client_random_bucket (entry->random);
while (*cur != entry) {
cur = &(*cur)->next_by_hash;
}
......@@ -804,20 +890,27 @@ int tcp_rpcs_compact_parse_execute (connection_job_t C) {
if (len < min_len) {
return min_len - len;
}
int read_len = len <= 4096 ? len : 4096;
unsigned char client_hello[read_len + 1]; // VLA
assert (rwm_fetch_lookup (&c->in, client_hello, read_len) == read_len);
const struct domain_info *info = get_sni_domain_info (client_hello, read_len);
if (info == NULL) {
return (-1 << 28);
}
vkprintf (1, "TLS type with domain %s\n", info->domain);
if (len > min_len) {
vkprintf (1, "Too much data in ClientHello, receive %d instead of %d\n", len, min_len);
return (-1 << 28);
}
if (len > 1024) {
if (len != read_len) {
vkprintf (1, "Too big ClientHello: receive %d bytes\n", len);
return (-1 << 28);
}
vkprintf (1, "TLS type\n");
unsigned char client_hello[len]; // VLA
assert (rwm_fetch_lookup (&c->in, client_hello, len) == len);
unsigned char client_random[32];
memcpy (client_random, client_hello + 11, 32);
memset (client_hello + 11, '\0', 32);
......@@ -850,7 +943,7 @@ int tcp_rpcs_compact_parse_execute (connection_job_t C) {
c->flags |= C_IS_TLS;
c->left_tls_packet_length = -1;
int encrypted_size = get_domain_server_hello_encrypted_size (&domain);
int encrypted_size = get_domain_server_hello_encrypted_size (info);
int response_size = 127 + 6 + 5 + encrypted_size;
unsigned char *buffer = malloc (32 + response_size);
assert (buffer != NULL);
......@@ -864,7 +957,7 @@ int tcp_rpcs_compact_parse_execute (connection_job_t C) {
int pos = 81;
int tls_server_extensions[3] = {0x33, 0x2b, -1};
if (domain.is_reversed_extension_order) {
if (info->is_reversed_extension_order) {
int t = tls_server_extensions[0];
tls_server_extensions[0] = tls_server_extensions[1];
tls_server_extensions[1] = t;
......@@ -972,6 +1065,10 @@ int tcp_rpcs_compact_parse_execute (connection_job_t C) {
unsigned tag = *(unsigned *)(random_header + 56);
if (tag == 0xdddddddd || tag == 0xeeeeeeee || tag == 0xefefefef) {
if (tag != 0xdddddddd && allow_only_tls) {
vkprintf (1, "Expected random padding mode\n");
return (-1 << 28);
}
assert (rwm_skip_data (&c->in, 64) == 64);
rwm_union (&c->in_u, &c->in);
rwm_init (&c->in, 0);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment