#include #include #include #include #include #include #include #include #include #include "aes_4.cpp" #include "aes_5a.cpp" #include "aes_5b.cpp" #include "aes_6.cpp" // AES constants constexpr size_t blockSize = 16; constexpr size_t keySize = 16; size_t numTests = 1'000'000; size_t payloadSize = numTests * blockSize; using TestTimeUnit = std::milli; using CycleTimeUnit = std::nano; void xor_into_128bit_u(uint8_t *a, uint8_t *b) { __m128i vec_a = _mm_loadu_si128((__m128i*)a); __m128i vec_b = _mm_loadu_si128((__m128i*)b); __m128i vec_result = _mm_xor_si128(vec_a, vec_b); _mm_storeu_si128((__m128i*)a, vec_result); } void xor_128bit(__m128i *a, __m128i *b, __m128i *c) { __m128i vec_a = _mm_load_si128(a); __m128i vec_b = _mm_load_si128(b); __m128i vec_result = _mm_xor_si128(vec_a, vec_b); _mm_store_si128(c, vec_result); } void mov_128bit(__m128i *a, __m128i *b) { __m128i tmp = _mm_load_si128(a); _mm_store_si128(b, tmp); } void test(void (*aes)(uint8_t *in, uint8_t *out, uint32_t *expKey), uint8_t *in, uint8_t *refOut, uint32_t *expandedKey, uint8_t *iv, const std::string& name) { std::cout << "\n\ntesting: " << name << '\n'; uint8_t* tmpBlock(static_cast(std::aligned_alloc(blockSize, blockSize))); uint8_t* outBuf(static_cast(std::aligned_alloc(blockSize, payloadSize))); mov_128bit(reinterpret_cast<__m128i*>(iv), reinterpret_cast<__m128i*>(tmpBlock)); uint64_t cycles = __rdtsc(); auto start = std::chrono::high_resolution_clock::now(); for (size_t t = 0; t < numTests; ++t) { aes(tmpBlock, tmpBlock, expandedKey); xor_128bit(reinterpret_cast<__m128i*>(tmpBlock), reinterpret_cast<__m128i*>(in + blockSize * t), reinterpret_cast<__m128i*>(outBuf + blockSize * t)); } cycles = __rdtsc() - cycles; auto end = std::chrono::high_resolution_clock::now(); if (std::memcmp(outBuf, refOut, payloadSize)) std::cout << "test failed!\n"; else std::cout << "test passed\n"; std::chrono::duration time = end - start; double timeAVG = time.count() / numTests; std::cout << "total time: " << time.count()/std::ratio_divide::den << "ms\n" << "avg time pro block: " << timeAVG << "ns\n" << "avg cpu cycles per block: " << cycles/numTests << std::endl; std::free(tmpBlock); std::free(outBuf); } int main(int argc, char *argv[]) { uint8_t key[keySize]; uint8_t iv[blockSize]; uint32_t expandedKey[44]; AES_KEY opensslKey; if (argc > 1) { numTests = strtoull(argv[1], nullptr, 10); if ((int64_t)numTests < 1) { std::cout << "ivalid param" << std::endl; __builtin_trap(); } payloadSize = numTests * blockSize; } uint8_t* input( static_cast(std::aligned_alloc(blockSize, payloadSize))); uint8_t* opensslOutput(static_cast(std::aligned_alloc(blockSize, payloadSize))); RAND_bytes(key, keySize); RAND_bytes(iv, blockSize); RAND_bytes(input, payloadSize); // OpenSSL beg ############################################################################# std::cout << "measuring reference: OpenSSL\n"; #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wdeprecated-declarations" AES_set_encrypt_key(key, 128, &opensslKey); uint8_t* tmpBlock(static_cast(std::aligned_alloc(blockSize, blockSize))); mov_128bit(reinterpret_cast<__m128i*>(iv), reinterpret_cast<__m128i*>(tmpBlock)); auto start = std::chrono::high_resolution_clock::now(); size_t opensslCycles = __rdtsc(); for (int test = 0; test < numTests; ++test) { AES_encrypt(tmpBlock, tmpBlock, &opensslKey); xor_128bit(reinterpret_cast<__m128i*>(tmpBlock), reinterpret_cast<__m128i*>(input + blockSize * test), reinterpret_cast<__m128i*>(opensslOutput + blockSize * test)); } opensslCycles = __rdtsc() - opensslCycles; auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration opensslTime = end - start; double timeAVG = opensslTime.count() / numTests; std::cout << "time: " << opensslTime.count()/std::ratio_divide::den << "ms\navg time: " << timeAVG << "ns\navg cycles: " << opensslCycles/numTests << std::endl; std::free(tmpBlock); #pragma GCC diagnostic pop // OpenSSL end ############################################################################# aes128_4::expandKey(key, expandedKey); test(aes128_4::aes128_4, input, opensslOutput, expandedKey, iv, "Naive implementation (4)"); test(aes128_5a::aes128_5, input, opensslOutput, expandedKey, iv, "With macro (5a)"); test(aes128_5b::aes128_5, input, opensslOutput, expandedKey, iv, "With T-Box (5b)"); __m128i expandedKey_128[10]; aes128_6::expandKey(key, expandedKey_128); test(aes128_6::aes128_6, input, opensslOutput, (uint32_t *)expandedKey_128, iv, "Intrinsics (6)"); std::free(input); std::free(opensslOutput); return 0; }