218 lines
8.1 KiB
C++
218 lines
8.1 KiB
C++
#include <iostream>
|
|
#include <iomanip>
|
|
#include <openssl/aes.h>
|
|
#include <openssl/rand.h>
|
|
#include <cstring>
|
|
#include <chrono>
|
|
#include <immintrin.h>
|
|
#include <memory>
|
|
#include <type_traits>
|
|
#include "aes_4.cpp"
|
|
#include "aes_5a.cpp"
|
|
|
|
// AES constants (sizes in bytes)
constexpr size_t blockSize{16};
constexpr size_t keySize{16};

// Test constants: number of blocks processed per benchmark run and the
// total size in bytes of the buffers holding that many blocks
constexpr size_t numTests{1'000'000};
constexpr size_t payloadSize{numTests * blockSize};

// std::ratio periods used when reporting timings: the total run time is
// shown with TestTimeUnit, the per-block average with CycleTimeUnit
using TestTimeUnit = std::milli;
using CycleTimeUnit = std::nano;
|
|
|
|
// XORs the 16 bytes at `b` into the 16 bytes at `a` (a ^= b).
// Uses unaligned loads/stores, so neither pointer needs any particular
// alignment. `b` is only read; `a` is read and then overwritten.
void xor_into_128bit_u(uint8_t *a, const uint8_t *b) {
    const __m128i vec_a = _mm_loadu_si128(reinterpret_cast<const __m128i*>(a));
    const __m128i vec_b = _mm_loadu_si128(reinterpret_cast<const __m128i*>(b));

    _mm_storeu_si128(reinterpret_cast<__m128i*>(a), _mm_xor_si128(vec_a, vec_b));
}
|
|
|
|
// Aligned version: writes a XOR b into c, one 128-bit block each.
// All three pointers are accessed with aligned load/store intrinsics and
// therefore must be 16-byte aligned. `a` and `b` are only read. `c` may
// be the same block as `a` or `b`, because both inputs are loaded before
// the result is stored.
void xor_128bit(const __m128i *a, const __m128i *b, __m128i *c) {
    const __m128i vec_a = _mm_load_si128(a);
    const __m128i vec_b = _mm_load_si128(b);

    _mm_store_si128(c, _mm_xor_si128(vec_a, vec_b));
}
|
|
|
|
// Copies one 128-bit block from `a` (source) to `b` (destination).
// Both pointers are accessed with aligned load/store intrinsics and
// therefore must be 16-byte aligned. `a` is only read.
void mov_128bit(const __m128i *a, __m128i *b) {
    _mm_store_si128(b, _mm_load_si128(a));
}
|
|
|
|
void test(void (*aes)(uint8_t *in, uint8_t *out, uint32_t *expKey), uint8_t *in, uint8_t *refOut,
|
|
uint32_t *expandedKey, uint8_t *iv,
|
|
const std::string& name) {
|
|
std::cout << "\n\ntesting: " << name << '\n';
|
|
|
|
uint8_t* tmpBlock(static_cast<uint8_t*>(std::aligned_alloc(blockSize, blockSize)));
|
|
uint8_t* outBuf(static_cast<uint8_t*>(std::aligned_alloc(blockSize, payloadSize)));
|
|
mov_128bit(reinterpret_cast<__m128i*>(iv), reinterpret_cast<__m128i*>(tmpBlock));
|
|
|
|
uint64_t cycles = __rdtsc();
|
|
auto start = std::chrono::high_resolution_clock::now();
|
|
for (size_t t = 0; t < numTests; ++t) {
|
|
aes(tmpBlock, tmpBlock, expandedKey);
|
|
xor_128bit(reinterpret_cast<__m128i*>(tmpBlock),
|
|
reinterpret_cast<__m128i*>(in + blockSize * t),
|
|
reinterpret_cast<__m128i*>(outBuf + blockSize * t));
|
|
}
|
|
cycles = __rdtsc() - cycles;
|
|
auto end = std::chrono::high_resolution_clock::now();
|
|
|
|
if (std::memcmp(outBuf, refOut, payloadSize)) std::cout << "test failed!\n";
|
|
else std::cout << "test passed\n";
|
|
|
|
std::chrono::duration<double, CycleTimeUnit> time = end - start;
|
|
double timeAVG = time.count() / numTests;
|
|
|
|
std::cout << "time :" << time.count()/std::ratio_divide<CycleTimeUnit, TestTimeUnit>::den << "ms\navg time: " << timeAVG << "ns\navg cpu cycles: " << cycles/numTests << std::endl;
|
|
std::free(tmpBlock);
|
|
std::free(outBuf);
|
|
}
|
|
|
|
|
|
int main() {
|
|
uint8_t key[keySize];
|
|
uint8_t iv[blockSize];
|
|
uint32_t expandedKey[44];
|
|
AES_KEY opensslKey;
|
|
|
|
uint8_t* input(static_cast<uint8_t*>(std::aligned_alloc(blockSize, payloadSize)));
|
|
uint8_t* opensslOutput(static_cast<uint8_t*>(std::aligned_alloc(blockSize, payloadSize)));
|
|
|
|
RAND_bytes(key, keySize);
|
|
RAND_bytes(iv, blockSize);
|
|
RAND_bytes(input, payloadSize);
|
|
|
|
// OpenSSL
|
|
std::cout << "testing: OpenSSL\n";
|
|
#pragma GCC diagnostic push
|
|
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
|
AES_set_encrypt_key(key, 128, &opensslKey);
|
|
uint8_t* tmpBlock(static_cast<uint8_t*>(std::aligned_alloc(blockSize, blockSize)));
|
|
mov_128bit(reinterpret_cast<__m128i*>(iv), reinterpret_cast<__m128i*>(tmpBlock));
|
|
auto start = std::chrono::high_resolution_clock::now();
|
|
size_t opensslCycles = __rdtsc();
|
|
for (int test = 0; test < numTests; ++test) {
|
|
AES_encrypt(tmpBlock, tmpBlock, &opensslKey);
|
|
xor_128bit(reinterpret_cast<__m128i*>(tmpBlock),
|
|
reinterpret_cast<__m128i*>(input + blockSize * test),
|
|
reinterpret_cast<__m128i*>(opensslOutput + blockSize * test));
|
|
}
|
|
opensslCycles = __rdtsc() - opensslCycles;
|
|
auto end = std::chrono::high_resolution_clock::now();
|
|
std::chrono::duration<double, std::nano> opensslTime = end - start;
|
|
double timeAVG = opensslTime.count() / numTests;
|
|
std::cout << "avg time: " << timeAVG << "ns\navg cycles: " << opensslCycles/numTests << std::endl;
|
|
std::free(tmpBlock);
|
|
#pragma GCC diagnostic pop
|
|
|
|
expandKey(key, expandedKey);
|
|
|
|
test(aes128_4, input, opensslOutput, expandedKey, iv, "My original implementation");
|
|
test(aes128_5a::aes128_5a, input, opensslOutput, expandedKey, iv, "My original implementation");
|
|
// test(aes128, input, opensslOutput, expandedKey, iv, "My original implementation");
|
|
|
|
std::free(input);
|
|
std::free(opensslOutput);
|
|
return 0;
|
|
}
|
|
|
|
// int main() {
|
|
|
|
// uint8_t key[keySize];
|
|
|
|
// uint8_t* opensslOutput = new uint8_t[payloadSize];
|
|
// uint8_t* myOutput = new uint8_t[payloadSize];
|
|
// uint8_t* myOptimOutput = new uint8_t[payloadSize];
|
|
|
|
// uint8_t* opensslInput = new uint8_t[payloadSize];
|
|
// uint8_t* myInput = new uint8_t[payloadSize];
|
|
// uint8_t* myOptimInput = new uint8_t[payloadSize];
|
|
|
|
// std::unique_ptr<uint8_t> iv(static_cast<uint8_t*>(std::aligned_alloc(blockSize, blockSize)));
|
|
// uint32_t expandedKey[44];
|
|
|
|
// std::chrono::duration<double, std::nano> opensslTime(0);
|
|
// std::chrono::duration<double, std::nano> myTime(0);
|
|
// std::chrono::duration<double, std::nano> myOptimTime(0);
|
|
|
|
// uint64_t opensslCycles;
|
|
// uint64_t myCycles;
|
|
// uint64_t myOptimCycles;
|
|
|
|
// AES_KEY opensslKey;
|
|
|
|
// RAND_bytes(key, keySize);
|
|
// RAND_bytes(opensslInput, payloadSize);
|
|
|
|
// RAND_bytes(iv.get(), blockSize);
|
|
// xor_into_128bit_u(opensslInput, iv.get());
|
|
// memcpy(myInput, opensslInput, payloadSize);
|
|
// memcpy(myOptimInput, opensslInput, payloadSize);
|
|
|
|
// expandKey(key, expandedKey);
|
|
// #pragma GCC diagnostic push
|
|
// #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
|
// AES_set_encrypt_key(key, 128, &opensslKey);
|
|
|
|
// // OPENSSL
|
|
// auto start = std::chrono::high_resolution_clock::now();
|
|
// opensslCycles = __rdtsc();
|
|
// for (int test = 0; test < numTests; ++test) {
|
|
// AES_encrypt(opensslInput + blockSize * test, opensslOutput + blockSize * test, &opensslKey);
|
|
// xor_into_128bit_u(opensslInput + blockSize * test, opensslOutput + blockSize * test);
|
|
// }
|
|
// #pragma GCC diagnostic pop
|
|
// opensslCycles = __rdtsc() - opensslCycles;
|
|
// auto end = std::chrono::high_resolution_clock::now();
|
|
// opensslTime += end - start;
|
|
|
|
// // My 4
|
|
// start = std::chrono::high_resolution_clock::now();
|
|
// myCycles = __rdtsc();
|
|
// for (int test = 0; test < numTests; ++test) {
|
|
// aes128(myInput + blockSize * test, myOutput + blockSize * test, expandedKey);
|
|
// xor_into_128bit_u(myInput + blockSize * test, myOutput + blockSize * test);
|
|
// }
|
|
// myCycles = __rdtsc() - myCycles;
|
|
// end = std::chrono::high_resolution_clock::now();
|
|
// myTime += end - start;
|
|
|
|
// // My 5a
|
|
// start = std::chrono::high_resolution_clock::now();
|
|
// myOptimCycles = __rdtsc();
|
|
// for (int test = 0; test < numTests; ++test) {
|
|
// aes128(myOptimInput + blockSize * test, myOptimOutput + blockSize * test, expandedKey);
|
|
// xor_into_128bit_u(myOptimInput + blockSize * test, myOptimOutput + blockSize * test);
|
|
// }
|
|
// myOptimCycles = __rdtsc() - myOptimCycles;
|
|
// end = std::chrono::high_resolution_clock::now();
|
|
// myOptimTime += end - start;
|
|
|
|
// // Verify
|
|
// if (std::memcmp(myOptimOutput, opensslOutput, payloadSize)) {
|
|
// std::cout << "Output differs\n";
|
|
// for (int i = 0; i < 16; ++i)
|
|
// std::cout << (int)myOutput[i] << "!=" << (int)opensslOutput[i] << '\n';
|
|
// } else {
|
|
// std::cout << "Output same\n";
|
|
// }
|
|
|
|
// // Print perf stats
|
|
// double opensslTimeAVG = opensslTime.count() / numTests;
|
|
// double myTimeAVG = myTime.count() / numTests;
|
|
// double myOptimTimeAVG = myOptimTime.count() / numTests;
|
|
|
|
// std::cout << "avg openssl time: " << opensslTimeAVG << "ns, cycles: " << opensslCycles/numTests << std::endl;
|
|
// std::cout << "avg my time: " << myTimeAVG << "ns, cycles: " << myCycles/numTests << std::endl;
|
|
// std::cout << "avg my optim time: " << myOptimTimeAVG << "ns, cycles: " << myOptimCycles/numTests << std::endl;
|
|
|
|
// return opensslOutput[0] ^ myOutput[0] ^ myOptimOutput[0];
|
|
// }
|