/******************************************************************************
* Fast linear search (AVX-512 -> AVX2 -> scalar) *
* -> reports which implementation ran and how long it took. *
* *
* g++ -O3 -std=c++17 -mavx512f -mavx2 fast_find.cpp -pthread -o fast_find *
* ./fast_find # single-thread *
* ./fast_find --mt # multi-thread *
******************************************************************************/
#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <immintrin.h>
#include <iostream>
#include <random>
#include <string>
#include <thread>
#include <vector>
namespace fast_find {
// ---------------------------------------------------------------------------
// Which implementation was actually used?
// ---------------------------------------------------------------------------
enum class Impl { Scalar, AVX2, AVX512 };
// ---------------------------------------------------------------------------
// 1. Scalar fallback
// ---------------------------------------------------------------------------
/// Plain linear-scan fallback: returns the position of the first element of
/// a[0..n) that equals `key`, or -1 when no element matches.
template <typename T>
inline int scalar(const T* a, std::size_t n, T key) noexcept {
    const T* const last = a + n;
    for (const T* p = a; p != last; ++p)
        if (*p == key) return static_cast<int>(p - a);
    return -1;
}
// ---------------------------------------------------------------------------
// 2. AVX2 implementation
// ---------------------------------------------------------------------------
#ifdef __AVX2__
// Vectorised linear search for `key` in a[0..n): compares 32 ints per
// iteration (4 unrolled 8-lane vectors). Returns the index of the first
// match, or -1 if the key is absent.
//
// Fix vs. the previous version: the lane index is now derived with a real
// trailing-zero count. The old expression `i + ((m & -m) % 255) >> 2` was
// doubly broken: `+` binds tighter than `>>`, so the base index `i` was
// shifted too, and `(m & -m) % 255` yields 2^(k mod 8), not the bit
// position k.
inline int avx2(const int* a, std::size_t n, int key) noexcept {
    constexpr int W = 8;  // 32-bit lanes per __m256i vector
    const __m256i NEEDLE = _mm256_set1_epi32(key);
    std::size_t i = 0;
    // Round n down to a multiple of 32 (W * 4); the remainder is scanned
    // by the scalar tail loop below.
    const std::size_t limit = n & ~static_cast<std::size_t>(W * 4 - 1);
    // _mm256_movemask_epi8 yields one bit per BYTE; a 32-bit lane match sets
    // 4 consecutive bits, so (first set bit) / 4 is the matching lane.
    const auto lane = [](int mask) -> std::size_t {
        return static_cast<std::size_t>(
                   __builtin_ctz(static_cast<unsigned>(mask))) >> 2;
    };
    for (; i < limit; i += W * 4) {
        __m256i v0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i));
        __m256i v1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i + W));
        __m256i v2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i + W * 2));
        __m256i v3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i + W * 3));
        int m0 = _mm256_movemask_epi8(_mm256_cmpeq_epi32(v0, NEEDLE));
        if (m0) return static_cast<int>(i + lane(m0));
        int m1 = _mm256_movemask_epi8(_mm256_cmpeq_epi32(v1, NEEDLE));
        if (m1) return static_cast<int>(i + W + lane(m1));
        int m2 = _mm256_movemask_epi8(_mm256_cmpeq_epi32(v2, NEEDLE));
        if (m2) return static_cast<int>(i + W * 2 + lane(m2));
        int m3 = _mm256_movemask_epi8(_mm256_cmpeq_epi32(v3, NEEDLE));
        if (m3) return static_cast<int>(i + W * 3 + lane(m3));
    }
    // Scalar tail for the final n % 32 elements.
    for (; i < n; ++i)
        if (a[i] == key) return static_cast<int>(i);
    return -1;
}
#endif
// ---------------------------------------------------------------------------
// 3. AVX-512 implementation
// ---------------------------------------------------------------------------
#ifdef __AVX512F__
// Vectorised linear search for `key` in a[0..n): compares 64 ints per
// iteration (4 unrolled 16-lane vectors). Returns the index of the first
// match found by this routine, or -1 if the key is absent.
// NOTE(review): uses _tzcnt_u32, which needs BMI codegen support -- build
// flags show only -mavx512f/-mavx2; confirm the toolchain accepts this.
inline int avx512(const int* a, std::size_t n, int key) noexcept {
// 32-bit lanes per __m512i vector.
constexpr int W = 16;
const __m512i NEEDLE = _mm512_set1_epi32(key);
std::size_t i = 0;
// Round n down to a multiple of 64 (W * 4); remainder handled by the
// scalar tail loop below.
const std::size_t limit = n & ~(W * 4 - 1);
for (; i < limit; i += W * 4) {
// Prefetch one and two iterations ahead (offsets are in ints: 64 ints
// == 256 bytes == one full unrolled iteration).
_mm_prefetch(reinterpret_cast<const char*>(a + i + 64), _MM_HINT_T0);
_mm_prefetch(reinterpret_cast<const char*>(a + i + 128), _MM_HINT_T0);
// Each compare produces a 16-bit lane mask; _tzcnt_u32 of a non-zero
// mask is the first matching lane within that vector.
__mmask16 m0 = _mm512_cmpeq_epi32_mask(_mm512_loadu_si512(a + i), NEEDLE);
if (m0) return i + _tzcnt_u32(m0);
__mmask16 m1 = _mm512_cmpeq_epi32_mask(_mm512_loadu_si512(a + i + W), NEEDLE);
if (m1) return i + W + _tzcnt_u32(m1);
__mmask16 m2 = _mm512_cmpeq_epi32_mask(_mm512_loadu_si512(a + i + W * 2), NEEDLE);
if (m2) return i + W * 2 + _tzcnt_u32(m2);
__mmask16 m3 = _mm512_cmpeq_epi32_mask(_mm512_loadu_si512(a + i + W * 3), NEEDLE);
if (m3) return i + W * 3 + _tzcnt_u32(m3);
}
// Scalar tail for the final n % 64 elements.
for (; i < n; ++i)
if (a[i] == key) return static_cast<int>(i);
return -1;
}
#endif
// ---------------------------------------------------------------------------
// 4. Single-thread façade (returns index + impl used)
// ---------------------------------------------------------------------------
// Runtime dispatch: picks the widest SIMD path that is both compiled in
// (#ifdef) and supported by the executing CPU (__builtin_cpu_supports),
// falling back to the scalar loop. Reports the chosen path through `used`
// and returns the found index or -1.
inline int search(const int* data, std::size_t n, int value, Impl& used) noexcept {
#ifdef __AVX512F__
if (__builtin_cpu_supports("avx512f")) { used = Impl::AVX512; return avx512(data, n, value); }
#endif
#ifdef __AVX2__
if (__builtin_cpu_supports("avx2")) { used = Impl::AVX2; return avx2 (data, n, value); }
#endif
used = Impl::Scalar;
return scalar(data, n, value);
}
// convenience wrapper when caller doesn't care about impl
inline int search(const int* data, std::size_t n, int value) noexcept {
Impl dummy;
return search(data, n, value, dummy);
}
// ---------------------------------------------------------------------------
// 5. Multi-thread wrapper (returns index + impl used by *any* thread)
// ---------------------------------------------------------------------------
// Multi-threaded search: splits [0, n) into nThreads contiguous chunks and
// runs `search` on each in its own thread. Returns the LOWEST matching
// global index (or -1), so the result agrees with the single-threaded
// `search` even when the key occurs more than once. `usedImpl` reports the
// implementation chosen by (one of) the worker threads.
//
// Fix vs. the previous version: the old code did a single
// compare_exchange_strong against -1, so whichever thread found a match
// FIRST in wall-clock time won -- a nondeterministic result for duplicate
// keys. The CAS loop below keeps the minimum index instead.
inline int search_mt(const int* data, std::size_t n, int value,
                     unsigned nThreads,
                     Impl& usedImpl)
{
    if (nThreads == 0) nThreads = 1;
    if (nThreads == 1 || n < 16'384)  // ST faster for small inputs
        return search(data, n, value, usedImpl);
    // Ceiling division so the last chunk absorbs the remainder.
    const std::size_t chunk = (n + nThreads - 1) / nThreads;
    std::atomic<int> result{-1};
    std::atomic<Impl> implSeen{Impl::Scalar};
    std::vector<std::thread> pool;
    pool.reserve(nThreads);  // avoid reallocation while threads are live
    for (unsigned t = 0; t < nThreads; ++t) {
        const std::size_t start = t * chunk;
        if (start >= n) break;  // more threads than chunks
        const std::size_t end = std::min(start + chunk, n);
        pool.emplace_back([&, start, end]() {
            Impl localImpl;
            const int localIdx = search(data + start, end - start, value, localImpl);
            implSeen.store(localImpl, std::memory_order_relaxed);
            if (localIdx != -1) {
                const int global = static_cast<int>(start + localIdx);
                // Publish `global` only while it improves on the current
                // best (-1 means "no match yet"). compare_exchange_weak
                // reloads `cur` on failure, so the loop re-tests.
                int cur = result.load(std::memory_order_relaxed);
                while ((cur == -1 || global < cur) &&
                       !result.compare_exchange_weak(cur, global,
                                                     std::memory_order_relaxed)) {
                }
            }
        });
    }
    for (auto& th : pool) th.join();
    usedImpl = implSeen.load(std::memory_order_relaxed);
    return result.load();
}
} // namespace fast_find
// ═══════════════════════════════════ Demo main ═════════════════════════════
/// Human-readable name of the implementation that executed the search.
static std::string to_string(fast_find::Impl impl) {
    using fast_find::Impl;
    if (impl == Impl::Scalar) return "Scalar";
    if (impl == Impl::AVX2) return "AVX2";
    if (impl == Impl::AVX512) return "AVX-512";
    return "Unknown";
}
// Demo driver: fills a buffer with a deterministic pattern, picks a random
// element as the search key (so a hit is guaranteed), runs the ST or MT
// search depending on the --mt flag, and prints impl / key / index / time.
int main(int argc, char** argv) {
    constexpr std::size_t N = 10'000;
    std::vector<int> haystack(N);
    for (std::size_t i = 0; i < N; ++i)
        haystack[i] = static_cast<int>((i * 77 + 123) & 0x7FFF);
    // Pick a needle that is guaranteed to be present.
    std::random_device seed;
    std::mt19937 gen(seed());
    std::uniform_int_distribution<std::size_t> pick(0, N - 1);
    const int needle = haystack[pick(gen)];
    const bool multiThreaded = (argc > 1 && std::string(argv[1]) == "--mt");
    unsigned cores = std::thread::hardware_concurrency();
    if (cores == 0) cores = 1;  // hardware_concurrency may report 0
    fast_find::Impl implUsed;
    const auto begin = std::chrono::high_resolution_clock::now();
    int foundAt;
    if (multiThreaded) {
        foundAt = fast_find::search_mt(haystack.data(), haystack.size(),
                                       needle, cores, implUsed);
    } else {
        foundAt = fast_find::search(haystack.data(), haystack.size(),
                                    needle, implUsed);
    }
    const auto finish = std::chrono::high_resolution_clock::now();
    const double elapsedUs =
        std::chrono::duration<double, std::micro>(finish - begin).count();
    std::cout << (multiThreaded ? "[MT] " : "[ST] ")
              << "Impl: " << to_string(implUsed)
              << " | Key: " << needle
              << " | Index: " << foundAt
              << " | Time: " << elapsedUs << " µs"
              << " | Logical cores: " << cores
              << '\n';
}