#include #include #include #include #include #include #include #include #include #include "aes_4.cpp" #include "aes_5a.cpp" // AES constants constexpr size_t blockSize = 16; constexpr size_t keySize = 16; // Test constants constexpr size_t numTests = 1'000'000; constexpr size_t payloadSize = numTests * blockSize; using TestTimeUnit = std::milli; using CycleTimeUnit = std::nano; void xor_into_128bit_u(uint8_t *a, uint8_t *b) { __m128i vec_a = _mm_loadu_si128((__m128i*)a); __m128i vec_b = _mm_loadu_si128((__m128i*)b); __m128i vec_result = _mm_xor_si128(vec_a, vec_b); _mm_storeu_si128((__m128i*)a, vec_result); } // aligned version void xor_128bit(__m128i *a, __m128i *b, __m128i *c) { __m128i vec_a = _mm_load_si128(a); __m128i vec_b = _mm_load_si128(b); __m128i vec_result = _mm_xor_si128(vec_a, vec_b); _mm_store_si128(c, vec_result); } void mov_128bit(__m128i *a, __m128i *b) { __m128i tmp = _mm_load_si128(a); _mm_store_si128(b, tmp); } void test(void (*aes)(uint8_t *in, uint8_t *out, uint32_t *expKey), uint8_t *in, uint8_t *refOut, uint32_t *expandedKey, uint8_t *iv, const std::string& name) { std::cout << "\n\ntesting: " << name << '\n'; uint8_t* tmpBlock(static_cast(std::aligned_alloc(blockSize, blockSize))); uint8_t* outBuf(static_cast(std::aligned_alloc(blockSize, payloadSize))); mov_128bit(reinterpret_cast<__m128i*>(iv), reinterpret_cast<__m128i*>(tmpBlock)); uint64_t cycles = __rdtsc(); auto start = std::chrono::high_resolution_clock::now(); for (size_t t = 0; t < numTests; ++t) { aes(tmpBlock, tmpBlock, expandedKey); xor_128bit(reinterpret_cast<__m128i*>(tmpBlock), reinterpret_cast<__m128i*>(in + blockSize * t), reinterpret_cast<__m128i*>(outBuf + blockSize * t)); } cycles = __rdtsc() - cycles; auto end = std::chrono::high_resolution_clock::now(); if (std::memcmp(outBuf, refOut, payloadSize)) std::cout << "test failed!\n"; else std::cout << "test passed\n"; std::chrono::duration time = end - start; double timeAVG = time.count() / numTests; std::cout << "time :" << time.count()/std::ratio_divide::den << "ms\navg time: " << timeAVG << "ns\navg cpu cycles: " << cycles/numTests << std::endl; std::free(tmpBlock); std::free(outBuf); } int main() { uint8_t key[keySize]; uint8_t iv[blockSize]; uint32_t expandedKey[44]; AES_KEY opensslKey; uint8_t* input(static_cast(std::aligned_alloc(blockSize, payloadSize))); uint8_t* opensslOutput(static_cast(std::aligned_alloc(blockSize, payloadSize))); RAND_bytes(key, keySize); RAND_bytes(iv, blockSize); RAND_bytes(input, payloadSize); // OpenSSL std::cout << "testing: OpenSSL\n"; #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wdeprecated-declarations" AES_set_encrypt_key(key, 128, &opensslKey); uint8_t* tmpBlock(static_cast(std::aligned_alloc(blockSize, blockSize))); mov_128bit(reinterpret_cast<__m128i*>(iv), reinterpret_cast<__m128i*>(tmpBlock)); auto start = std::chrono::high_resolution_clock::now(); size_t opensslCycles = __rdtsc(); for (int test = 0; test < numTests; ++test) { AES_encrypt(tmpBlock, tmpBlock, &opensslKey); xor_128bit(reinterpret_cast<__m128i*>(tmpBlock), reinterpret_cast<__m128i*>(input + blockSize * test), reinterpret_cast<__m128i*>(opensslOutput + blockSize * test)); } opensslCycles = __rdtsc() - opensslCycles; auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration opensslTime = end - start; double timeAVG = opensslTime.count() / numTests; std::cout << "avg time: " << timeAVG << "ns\navg cycles: " << opensslCycles/numTests << std::endl; std::free(tmpBlock); #pragma GCC diagnostic pop expandKey(key, expandedKey); test(aes128_4, input, opensslOutput, expandedKey, iv, "My original implementation"); test(aes128_5a::aes128_5a, input, opensslOutput, expandedKey, iv, "My original implementation"); // test(aes128, input, opensslOutput, expandedKey, iv, "My original implementation"); std::free(input); std::free(opensslOutput); return 0; } // int main() { // uint8_t key[keySize]; // uint8_t* opensslOutput = new uint8_t[payloadSize]; // uint8_t* myOutput = new uint8_t[payloadSize]; // uint8_t* myOptimOutput = new uint8_t[payloadSize]; // uint8_t* opensslInput = new uint8_t[payloadSize]; // uint8_t* myInput = new uint8_t[payloadSize]; // uint8_t* myOptimInput = new uint8_t[payloadSize]; // std::unique_ptr iv(static_cast(std::aligned_alloc(blockSize, blockSize))); // uint32_t expandedKey[44]; // std::chrono::duration opensslTime(0); // std::chrono::duration myTime(0); // std::chrono::duration myOptimTime(0); // uint64_t opensslCycles; // uint64_t myCycles; // uint64_t myOptimCycles; // AES_KEY opensslKey; // RAND_bytes(key, keySize); // RAND_bytes(opensslInput, payloadSize); // RAND_bytes(iv.get(), blockSize); // xor_into_128bit_u(opensslInput, iv.get()); // memcpy(myInput, opensslInput, payloadSize); // memcpy(myOptimInput, opensslInput, payloadSize); // expandKey(key, expandedKey); // #pragma GCC diagnostic push // #pragma GCC diagnostic ignored "-Wdeprecated-declarations" // AES_set_encrypt_key(key, 128, &opensslKey); // // OPENSSL // auto start = std::chrono::high_resolution_clock::now(); // opensslCycles = __rdtsc(); // for (int test = 0; test < numTests; ++test) { // AES_encrypt(opensslInput + blockSize * test, opensslOutput + blockSize * test, &opensslKey); // xor_into_128bit_u(opensslInput + blockSize * test, opensslOutput + blockSize * test); // } // #pragma GCC diagnostic pop // opensslCycles = __rdtsc() - opensslCycles; // auto end = std::chrono::high_resolution_clock::now(); // opensslTime += end - start; // // My 4 // start = std::chrono::high_resolution_clock::now(); // myCycles = __rdtsc(); // for (int test = 0; test < numTests; ++test) { // aes128(myInput + blockSize * test, myOutput + blockSize * test, expandedKey); // xor_into_128bit_u(myInput + blockSize * test, myOutput + blockSize * test); // } // myCycles = __rdtsc() - myCycles; // end = std::chrono::high_resolution_clock::now(); // myTime += end - start; // // My 5a // start = std::chrono::high_resolution_clock::now(); // myOptimCycles = __rdtsc(); // for (int test = 0; test < numTests; ++test) { // aes128(myOptimInput + blockSize * test, myOptimOutput + blockSize * test, expandedKey); // xor_into_128bit_u(myOptimInput + blockSize * test, myOptimOutput + blockSize * test); // } // myOptimCycles = __rdtsc() - myOptimCycles; // end = std::chrono::high_resolution_clock::now(); // myOptimTime += end - start; // // Verify // if (std::memcmp(myOptimOutput, opensslOutput, payloadSize)) { // std::cout << "Output differs\n"; // for (int i = 0; i < 16; ++i) // std::cout << (int)myOutput[i] << "!=" << (int)opensslOutput[i] << '\n'; // } else { // std::cout << "Output same\n"; // } // // Print perf stats // double opensslTimeAVG = opensslTime.count() / numTests; // double myTimeAVG = myTime.count() / numTests; // double myOptimTimeAVG = myOptimTime.count() / numTests; // std::cout << "avg openssl time: " << opensslTimeAVG << "ns, cycles: " << opensslCycles/numTests << std::endl; // std::cout << "avg my time: " << myTimeAVG << "ns, cycles: " << myCycles/numTests << std::endl; // std::cout << "avg my optim time: " << myOptimTimeAVG << "ns, cycles: " << myOptimCycles/numTests << std::endl; // return opensslOutput[0] ^ myOutput[0] ^ myOptimOutput[0]; // }