Commit 1b916c40 authored by levlam's avatar levlam

Simplify and optimize encryption.

parent 7bd17464
...@@ -22,115 +22,20 @@ ...@@ -22,115 +22,20 @@
*/ */
#include "crypto/aesni256.h" #include "crypto/aesni256.h"
#include <assert.h>
#include <string.h>
#include <stdint.h>
#include "common/cpuid.h"
#include <openssl/opensslv.h>
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
#include <openssl/modes.h>
void AES_ctr128_encrypt(
const unsigned char *in,
unsigned char *out,
size_t length,
const AES_KEY *key,
unsigned char ivec[AES_BLOCK_SIZE],
unsigned char ecount_buf[AES_BLOCK_SIZE],
unsigned int *num) {
CRYPTO_ctr128_encrypt(in, out, length, key, ivec, ecount_buf, num, (block128_f)AES_encrypt);
}
#endif
void tg_ssl_aes_ctr_crypt (tg_aes_ctx_t *ctx, const unsigned char *in, unsigned char *out, int size, unsigned char iv[16], unsigned long long offset) {
unsigned char iv_copy[16];
memcpy (iv_copy, iv, 16);
unsigned long long *p = (unsigned long long *) (iv_copy + 8);
(*p) += offset >> 4;
union {
unsigned char c[16];
unsigned long long d[2];
} u;
int i = offset & 15, l;
if (i) {
AES_encrypt (iv_copy, u.c, &ctx->u.key);
(*p)++;
l = i + size;
if (l > 16) {
l = 16;
}
size -= l - i;
do {
*out++ = (*in++) ^ u.c[i++];
} while (i < l);
}
const unsigned long long *I = (const unsigned long long *) in;
unsigned long long *O = (unsigned long long *) out;
int n = size >> 4;
while (--n >= 0) {
AES_encrypt (iv_copy, (unsigned char *) u.d, &ctx->u.key);
(*p)++;
*O++ = (*I++) ^ u.d[0];
*O++ = (*I++) ^ u.d[1];
}
l = size & 15;
if (l) {
AES_encrypt (iv_copy, u.c, &ctx->u.key);
in = (const unsigned char *) I;
out = (unsigned char *) O;
i = 0;
do {
*out++ = (*in++) ^ u.c[i++];
} while (i < l);
}
}
static void tg_ssl_aes_cbc_encrypt (tg_aes_ctx_t *ctx, const unsigned char *in, unsigned char *out, int size, unsigned char iv[16]) {
AES_cbc_encrypt (in, out, size, &ctx->u.key, iv, AES_ENCRYPT);
}
static void tg_ssl_aes_cbc_decrypt (tg_aes_ctx_t *ctx, const unsigned char *in, unsigned char *out, int size, unsigned char iv[16]) {
AES_cbc_encrypt (in, out, size, &ctx->u.key, iv, AES_DECRYPT);
}
static void tg_ssl_aes_ige_encrypt (tg_aes_ctx_t *ctx, const unsigned char *in, unsigned char *out, int size, unsigned char iv[32]) {
AES_ige_encrypt (in, out, size, &ctx->u.key, iv, AES_ENCRYPT);
}
static void tg_ssl_aes_ige_decrypt (tg_aes_ctx_t *ctx, const unsigned char *in, unsigned char *out, int size, unsigned char iv[32]) {
AES_ige_encrypt (in, out, size, &ctx->u.key, iv, AES_DECRYPT);
}
void tg_ssl_aes_ctr128_crypt (struct tg_aes_ctx *ctx, const unsigned char *in, unsigned char *out, int size, unsigned char iv[16], unsigned char ecount_buf[16], unsigned int *num) { #include <assert.h>
AES_ctr128_encrypt (in, out, size, &ctx->u.key, iv, ecount_buf, num);
}
static const struct tg_aes_methods ssl_aes_encrypt_methods = {
.cbc_crypt = tg_ssl_aes_cbc_encrypt,
.ige_crypt = tg_ssl_aes_ige_encrypt,
.ctr_crypt = tg_ssl_aes_ctr_crypt,
.ctr128_crypt = tg_ssl_aes_ctr128_crypt
};
void tg_aes_set_encrypt_key (tg_aes_ctx_t *ctx, unsigned char *key, int bits) {
AES_set_encrypt_key (key, bits, &ctx->u.key);
ctx->type = &ssl_aes_encrypt_methods;
}
static const struct tg_aes_methods ssl_aes_decrypt_methods = { EVP_CIPHER_CTX *evp_cipher_ctx_init (const EVP_CIPHER *cipher, unsigned char *key, unsigned char iv[16], int is_encrypt) {
.cbc_crypt = tg_ssl_aes_cbc_decrypt, EVP_CIPHER_CTX *evp_ctx = EVP_CIPHER_CTX_new();
.ige_crypt = tg_ssl_aes_ige_decrypt, assert(evp_ctx);
.ctr_crypt = NULL,
.ctr128_crypt = NULL
};
void tg_aes_set_decrypt_key (tg_aes_ctx_t *ctx, unsigned char *key, int bits) { assert(EVP_CipherInit(evp_ctx, cipher, key, iv, is_encrypt) == 1);
AES_set_decrypt_key (key, bits, &ctx->u.key); assert(EVP_CIPHER_CTX_set_padding(evp_ctx, 0) == 1);
ctx->type = &ssl_aes_decrypt_methods; return evp_ctx;
} }
void tg_aes_ctx_cleanup (tg_aes_ctx_t *ctx) { void evp_crypt (EVP_CIPHER_CTX *evp_ctx, const void *in, void *out, int size) {
memset (ctx, 0, sizeof (tg_aes_ctx_t)); int len;
assert (EVP_CipherUpdate(evp_ctx, out, &len, in, size) == 1);
assert (len == size);
} }
...@@ -23,30 +23,8 @@ ...@@ -23,30 +23,8 @@
#pragma once #pragma once
#include <openssl/aes.h> #include <openssl/evp.h>
struct aesni256_ctx { EVP_CIPHER_CTX *evp_cipher_ctx_init (const EVP_CIPHER *cipher, unsigned char *key, unsigned char iv[16], int is_encrypt);
unsigned char a[256];
}; void evp_crypt (EVP_CIPHER_CTX *evp_ctx, const void *in, void *out, int size);
//TODO: move cbc_crypt, ige_crypt, ctr_crypt to the virtual method table
struct tg_aes_ctx;
struct tg_aes_methods {
void (*cbc_crypt) (struct tg_aes_ctx *ctx, const unsigned char *in, unsigned char *out, int size, unsigned char iv[16]);
void (*ige_crypt) (struct tg_aes_ctx *ctx, const unsigned char *in, unsigned char *out, int size, unsigned char iv[32]);
void (*ctr_crypt) (struct tg_aes_ctx *ctx, const unsigned char *in, unsigned char *out, int size, unsigned char iv[16], unsigned long long offset);
void (*ctr128_crypt) (struct tg_aes_ctx *ctx, const unsigned char *in, unsigned char *out, int size, unsigned char iv[16], unsigned char ecount_buf[16], unsigned int *num);
};
typedef struct tg_aes_ctx {
union {
AES_KEY key;
struct aesni256_ctx ctx;
} u;
const struct tg_aes_methods *type;
} tg_aes_ctx_t;
void tg_aes_set_encrypt_key (tg_aes_ctx_t *ctx, unsigned char *key, int bits);
void tg_aes_set_decrypt_key (tg_aes_ctx_t *ctx, unsigned char *key, int bits);
void tg_aes_ctx_cleanup (tg_aes_ctx_t *ctx);
...@@ -86,12 +86,8 @@ int aes_crypto_init (connection_job_t c, void *key_data, int key_data_len) { ...@@ -86,12 +86,8 @@ int aes_crypto_init (connection_job_t c, void *key_data, int key_data_len) {
MODULE_STAT->allocated_aes_crypto ++; MODULE_STAT->allocated_aes_crypto ++;
tg_aes_set_decrypt_key (&T->read_aeskey, D->read_key, 256); T->read_aeskey = evp_cipher_ctx_init (EVP_aes_256_cbc(), D->read_key, D->read_iv, 0);
memcpy (T->read_iv, D->read_iv, 16); T->write_aeskey = evp_cipher_ctx_init (EVP_aes_256_cbc(), D->write_key, D->write_iv, 1);
tg_aes_set_encrypt_key (&T->write_aeskey, D->write_key, 256);
memcpy (T->write_iv, D->write_iv, 16);
// T->read_pos = T->write_pos = 0;
T->read_num = T->write_num = 0;
CONN_INFO(c)->crypto = T; CONN_INFO(c)->crypto = T;
return 0; return 0;
} }
...@@ -105,19 +101,19 @@ int aes_crypto_ctr128_init (connection_job_t c, void *key_data, int key_data_len ...@@ -105,19 +101,19 @@ int aes_crypto_ctr128_init (connection_job_t c, void *key_data, int key_data_len
MODULE_STAT->allocated_aes_crypto ++; MODULE_STAT->allocated_aes_crypto ++;
tg_aes_set_encrypt_key (&T->read_aeskey, D->read_key, 256); // NB: *_encrypt_key here! T->read_aeskey = evp_cipher_ctx_init (EVP_aes_256_ctr(), D->read_key, D->read_iv, 1); // NB: is_encrypt == 1 here!
memcpy (T->read_iv, D->read_iv, 16); T->write_aeskey = evp_cipher_ctx_init (EVP_aes_256_ctr(), D->write_key, D->write_iv, 1);
tg_aes_set_encrypt_key (&T->write_aeskey, D->write_key, 256);
memcpy (T->write_iv, D->write_iv, 16);
// T->read_pos = T->write_pos = 0;
T->read_num = T->write_num = 0;
CONN_INFO(c)->crypto = T; CONN_INFO(c)->crypto = T;
return 0; return 0;
} }
int aes_crypto_free (connection_job_t c) { int aes_crypto_free (connection_job_t c) {
if (CONN_INFO(c)->crypto) { struct aes_crypto *crypto = CONN_INFO(c)->crypto;
free (CONN_INFO(c)->crypto); if (crypto) {
EVP_CIPHER_CTX_free (crypto->read_aeskey);
EVP_CIPHER_CTX_free (crypto->write_aeskey);
free (crypto);
CONN_INFO(c)->crypto = 0; CONN_INFO(c)->crypto = 0;
MODULE_STAT->allocated_aes_crypto --; MODULE_STAT->allocated_aes_crypto --;
} }
......
...@@ -70,12 +70,8 @@ struct aes_key_data { ...@@ -70,12 +70,8 @@ struct aes_key_data {
/* for c->crypto */ /* for c->crypto */
struct aes_crypto { struct aes_crypto {
unsigned char read_iv[16], write_iv[16]; EVP_CIPHER_CTX *read_aeskey;
unsigned char read_ebuf[16], write_ebuf[16]; /* for AES-CTR modes */ EVP_CIPHER_CTX *write_aeskey;
tg_aes_ctx_t read_aeskey __attribute__ ((aligned (16)));
tg_aes_ctx_t write_aeskey __attribute__ ((aligned (16)));
unsigned int read_num, write_num; /* for AES-CTR modes */
// long long read_pos, write_pos; /* for AES-CTR modes */
}; };
extern int aes_initialized; extern int aes_initialized;
......
...@@ -1229,11 +1229,7 @@ struct rwm_encrypt_decrypt_tmp { ...@@ -1229,11 +1229,7 @@ struct rwm_encrypt_decrypt_tmp {
int left; int left;
int block_size; int block_size;
struct raw_message *raw; struct raw_message *raw;
struct tg_aes_ctx *ctx; EVP_CIPHER_CTX *evp_ctx;
void (*crypt)(struct tg_aes_ctx *, const void *, void *, int, unsigned char *, void *, void *);
unsigned char *iv;
void *extra;
void *extra2;
char buf[16] __attribute__((aligned(16))); char buf[16] __attribute__((aligned(16)));
}; };
...@@ -1261,12 +1257,12 @@ int rwm_process_encrypt_decrypt (struct rwm_encrypt_decrypt_tmp *x, const void * ...@@ -1261,12 +1257,12 @@ int rwm_process_encrypt_decrypt (struct rwm_encrypt_decrypt_tmp *x, const void *
data += to_fill; data += to_fill;
x->bp = 0; x->bp = 0;
if (x->buf_left >= bsize) { if (x->buf_left >= bsize) {
x->crypt (x->ctx, x->buf, res->last->part->data + res->last_offset, bsize, x->iv, x->extra, x->extra2); evp_crypt (x->evp_ctx, x->buf, res->last->part->data + res->last_offset, bsize);
res->last->data_end += bsize; res->last->data_end += bsize;
res->last_offset += bsize; res->last_offset += bsize;
x->buf_left -= bsize; x->buf_left -= bsize;
} else { } else {
x->crypt (x->ctx, x->buf, x->buf, bsize, x->iv, x->extra, x->extra2); evp_crypt (x->evp_ctx, x->buf, x->buf, bsize);
memcpy (res->last->part->data + res->last_offset, x->buf, x->buf_left); memcpy (res->last->part->data + res->last_offset, x->buf, x->buf_left);
int t = x->buf_left; int t = x->buf_left;
res->last->data_end += t; res->last->data_end += t;
...@@ -1316,7 +1312,7 @@ int rwm_process_encrypt_decrypt (struct rwm_encrypt_decrypt_tmp *x, const void * ...@@ -1316,7 +1312,7 @@ int rwm_process_encrypt_decrypt (struct rwm_encrypt_decrypt_tmp *x, const void *
assert (x->buf_left + res->last_offset <= res->last->part->chunk->buffer_size); assert (x->buf_left + res->last_offset <= res->last->part->chunk->buffer_size);
if (len <= x->buf_left) { if (len <= x->buf_left) {
assert (!(len & (bsize - 1))); assert (!(len & (bsize - 1)));
x->crypt (x->ctx, data, (res->last->part->data + res->last_offset), len, x->iv, x->extra, x->extra2); evp_crypt (x->evp_ctx, data, (res->last->part->data + res->last_offset), len);
res->last->data_end += len; res->last->data_end += len;
res->last_offset += len; res->last_offset += len;
res->total_bytes += len; res->total_bytes += len;
...@@ -1324,7 +1320,7 @@ int rwm_process_encrypt_decrypt (struct rwm_encrypt_decrypt_tmp *x, const void * ...@@ -1324,7 +1320,7 @@ int rwm_process_encrypt_decrypt (struct rwm_encrypt_decrypt_tmp *x, const void *
return 0; return 0;
} else { } else {
int t = x->buf_left & -bsize; int t = x->buf_left & -bsize;
x->crypt (x->ctx, data, res->last->part->data + res->last_offset, t, x->iv, x->extra, x->extra2); evp_crypt (x->evp_ctx, data, res->last->part->data + res->last_offset, t);
res->last->data_end += t; res->last->data_end += t;
res->last_offset += t; res->last_offset += t;
res->total_bytes += t; res->total_bytes += t;
...@@ -1336,7 +1332,7 @@ int rwm_process_encrypt_decrypt (struct rwm_encrypt_decrypt_tmp *x, const void * ...@@ -1336,7 +1332,7 @@ int rwm_process_encrypt_decrypt (struct rwm_encrypt_decrypt_tmp *x, const void *
} }
int rwm_encrypt_decrypt_to (struct raw_message *raw, struct raw_message *res, int bytes, struct tg_aes_ctx *ctx, void (*crypt)(struct tg_aes_ctx *ctx, const void *src, void *dst, int l, unsigned char *iv, void *extra, void *extra2), unsigned char *iv, int block_size, void *extra, void *extra2) { int rwm_encrypt_decrypt_to (struct raw_message *raw, struct raw_message *res, int bytes, EVP_CIPHER_CTX *evp_ctx, int block_size) {
assert (bytes >= 0); assert (bytes >= 0);
assert (block_size && !(block_size & (block_size - 1))); assert (block_size && !(block_size & (block_size - 1)));
if (bytes > raw->total_bytes) { if (bytes > raw->total_bytes) {
...@@ -1365,18 +1361,14 @@ int rwm_encrypt_decrypt_to (struct raw_message *raw, struct raw_message *res, in ...@@ -1365,18 +1361,14 @@ int rwm_encrypt_decrypt_to (struct raw_message *raw, struct raw_message *res, in
} }
struct rwm_encrypt_decrypt_tmp t; struct rwm_encrypt_decrypt_tmp t;
t.bp = 0; t.bp = 0;
t.crypt = crypt;
if (res->last->part->refcnt == 1) { if (res->last->part->refcnt == 1) {
t.buf_left = res->last->part->chunk->buffer_size - res->last_offset; t.buf_left = res->last->part->chunk->buffer_size - res->last_offset;
} else { } else {
t.buf_left = 0; t.buf_left = 0;
} }
t.raw = res; t.raw = res;
t.ctx = ctx; t.evp_ctx = evp_ctx;
t.iv = iv;
t.left = bytes; t.left = bytes;
t.extra = extra;
t.extra2 = extra2;
t.block_size = block_size; t.block_size = block_size;
int r = rwm_process_and_advance (raw, bytes, (void *)rwm_process_encrypt_decrypt, &t); int r = rwm_process_and_advance (raw, bytes, (void *)rwm_process_encrypt_decrypt, &t);
if (locked) { if (locked) {
......
...@@ -145,9 +145,7 @@ int rwm_process_from_offset (struct raw_message *raw, int bytes, int offset, int ...@@ -145,9 +145,7 @@ int rwm_process_from_offset (struct raw_message *raw, int bytes, int offset, int
int rwm_transform_from_offset (struct raw_message *raw, int bytes, int offset, int (*transform_block)(void *extra, void *data, int len), void *extra); int rwm_transform_from_offset (struct raw_message *raw, int bytes, int offset, int (*transform_block)(void *extra, void *data, int len), void *extra);
int rwm_process_and_advance (struct raw_message *raw, int bytes, int (*process_block)(void *extra, const void *data, int len), void *extra); int rwm_process_and_advance (struct raw_message *raw, int bytes, int (*process_block)(void *extra, const void *data, int len), void *extra);
int rwm_sha1 (struct raw_message *raw, int bytes, unsigned char output[20]); int rwm_sha1 (struct raw_message *raw, int bytes, unsigned char output[20]);
// int rwm_encrypt_decrypt (struct raw_message *raw, int bytes, tg_aes_ctx_t *ctx, unsigned char iv[32]); int rwm_encrypt_decrypt_to (struct raw_message *raw, struct raw_message *res, int bytes, EVP_CIPHER_CTX *evp_ctx, int block_size);
// int rwm_encrypt_decrypt_cbc (struct raw_message *raw, int bytes, tg_aes_ctx_t *ctx, unsigned char iv[16]);
int rwm_encrypt_decrypt_to (struct raw_message *raw, struct raw_message *res, int bytes, tg_aes_ctx_t *ctx, void (*crypt)(tg_aes_ctx_t *ctx, const void *src, void *dst, int l, unsigned char *iv, void *extra, void *extra2), unsigned char *iv, int block_size, void *extra, void *extra2);
void *rwm_get_block_ptr (struct raw_message *raw); void *rwm_get_block_ptr (struct raw_message *raw);
int rwm_get_block_ptr_bytes (struct raw_message *raw); int rwm_get_block_ptr_bytes (struct raw_message *raw);
......
...@@ -201,7 +201,7 @@ int cpu_tcp_aes_crypto_encrypt_output (connection_job_t C) /* {{{ */ { ...@@ -201,7 +201,7 @@ int cpu_tcp_aes_crypto_encrypt_output (connection_job_t C) /* {{{ */ {
int l = out->total_bytes; int l = out->total_bytes;
l &= ~15; l &= ~15;
if (l) { if (l) {
assert (rwm_encrypt_decrypt_to (&c->out, &c->out_p, l, &T->write_aeskey, (void *)T->write_aeskey.type->cbc_crypt, T->write_iv, 16, 0, 0) == l); assert (rwm_encrypt_decrypt_to (&c->out, &c->out_p, l, T->write_aeskey, 16) == l);
} }
return (-out->total_bytes) & 15; return (-out->total_bytes) & 15;
...@@ -218,7 +218,7 @@ int cpu_tcp_aes_crypto_decrypt_input (connection_job_t C) /* {{{ */ { ...@@ -218,7 +218,7 @@ int cpu_tcp_aes_crypto_decrypt_input (connection_job_t C) /* {{{ */ {
int l = in->total_bytes; int l = in->total_bytes;
l &= ~15; l &= ~15;
if (l) { if (l) {
assert (rwm_encrypt_decrypt_to (&c->in_u, &c->in, l, &T->read_aeskey, (void *)T->read_aeskey.type->cbc_crypt, T->read_iv, 16, 0, 0) == l); assert (rwm_encrypt_decrypt_to (&c->in_u, &c->in, l, T->read_aeskey, 16) == l);
} }
return (-in->total_bytes) & 15; return (-in->total_bytes) & 15;
...@@ -253,7 +253,7 @@ int cpu_tcp_aes_crypto_ctr128_encrypt_output (connection_job_t C) /* {{{ */ { ...@@ -253,7 +253,7 @@ int cpu_tcp_aes_crypto_ctr128_encrypt_output (connection_job_t C) /* {{{ */ {
vkprintf (2, "Send TLS-packet of length %d\n", len); vkprintf (2, "Send TLS-packet of length %d\n", len);
} }
assert (rwm_encrypt_decrypt_to (&c->out, &c->out_p, len, &T->write_aeskey, (void *)T->write_aeskey.type->ctr128_crypt, T->write_iv, 1, T->write_ebuf, &T->write_num) == len); assert (rwm_encrypt_decrypt_to (&c->out, &c->out_p, len, T->write_aeskey, 1) == len);
} }
return 0; return 0;
...@@ -295,7 +295,7 @@ int cpu_tcp_aes_crypto_ctr128_decrypt_input (connection_job_t C) /* {{{ */ { ...@@ -295,7 +295,7 @@ int cpu_tcp_aes_crypto_ctr128_decrypt_input (connection_job_t C) /* {{{ */ {
c->left_tls_packet_length -= len; c->left_tls_packet_length -= len;
} }
vkprintf (2, "Read %d bytes out of %d available\n", len, c->in_u.total_bytes); vkprintf (2, "Read %d bytes out of %d available\n", len, c->in_u.total_bytes);
assert (rwm_encrypt_decrypt_to (&c->in_u, &c->in, len, &T->read_aeskey, (void *)T->read_aeskey.type->ctr128_crypt, T->read_iv, 1, T->read_ebuf, &T->read_num) == len); assert (rwm_encrypt_decrypt_to (&c->in_u, &c->in, len, T->read_aeskey, 1) == len);
} }
return 0; return 0;
......
...@@ -1091,7 +1091,7 @@ int tcp_rpcs_compact_parse_execute (connection_job_t C) { ...@@ -1091,7 +1091,7 @@ int tcp_rpcs_compact_parse_execute (connection_job_t C) {
assert (c->crypto); assert (c->crypto);
struct aes_crypto *T = c->crypto; struct aes_crypto *T = c->crypto;
T->read_aeskey.type->ctr128_crypt (&T->read_aeskey, random_header, random_header, 64, T->read_iv, T->read_ebuf, &T->read_num); evp_crypt (T->read_aeskey, random_header, random_header, 64);
unsigned tag = *(unsigned *)(random_header + 56); unsigned tag = *(unsigned *)(random_header + 56);
if (tag == 0xdddddddd || tag == 0xeeeeeeee || tag == 0xefefefef) { if (tag == 0xdddddddd || tag == 0xeeeeeeee || tag == 0xefefefef) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment