#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <random>
#include <type_traits>
#include <utility>
#include <vector>

// =====================================================================
// Custom Template Radix Sort - O(N) Complexity
// Byte-wise LSD radix sort for built-in integer types of any width
// (signed and unsigned).
// =====================================================================
//
// Sorts the range [first, last) into ascending order.
//
// Requirements:
//   * The iterator's value type must be an integral type other than bool.
//   * The range must be stored contiguously (std::vector, std::array, a raw
//     array, ...): the elements are accessed through a raw pointer taken
//     from `first`.
//
// Cost: one scratch buffer of N elements, one histogram pass over the input
// and sizeof(value_type) scatter passes (one per key byte).
template<typename RandomIt>
void custom_radix_sort(RandomIt first, RandomIt last) {
    using T = typename std::iterator_traits<RandomIt>::value_type;

    static_assert(std::is_integral<T>::value && !std::is_same<T, bool>::value,
                  "custom_radix_sort requires a range of integers.");

    // All byte extraction is done on the unsigned counterpart of T so that
    // right shifts are logical and no bits are lost for wide types.
    using Key = typename std::make_unsigned<T>::type;

    constexpr std::size_t kKeyBytes = sizeof(Key);
    constexpr std::size_t kBuckets = 256;

    // XOR-ing a two's-complement value with its top bit maps signed order
    // onto unsigned order (negatives end up below non-negatives). Unsigned
    // types are already in key order, so their mask is zero.
    constexpr Key kSignFlip = std::is_signed<T>::value
        ? static_cast<Key>(static_cast<Key>(1) << (kKeyBytes * 8 - 1))
        : static_cast<Key>(0);

    const auto length = std::distance(first, last);
    if (length <= 1) return;
    const std::size_t n = static_cast<std::size_t>(length);

    // Buffer for the scatter phase
    std::vector<T> buffer(n);

    // Raw pointers into the caller's storage and the scratch buffer
    T* const origin = &(*first);
    T* src = origin;
    T* dst = buffer.data();

    // 1. One-pass histogram generation: count the frequency of every byte
    //    value at every byte position while reading the input only once.
    //    Counters are size_t so they cannot overflow for any valid n.
    std::size_t counts[kKeyBytes][kBuckets] = {};
    for (std::size_t i = 0; i < n; ++i) {
        const Key key = static_cast<Key>(static_cast<Key>(src[i]) ^ kSignFlip);
        for (std::size_t b = 0; b < kKeyBytes; ++b) {
            ++counts[b][(key >> (b * 8)) & 0xFF];
        }
    }

    // 2. Radix passes, least significant byte first. Each pass is a stable
    //    scatter, so the order established by earlier passes is preserved.
    for (std::size_t b = 0; b < kKeyBytes; ++b) {
        // Exclusive prefix sum: first output index of each bucket
        std::size_t pos[kBuckets];
        std::size_t running = 0;
        for (std::size_t bucket = 0; bucket < kBuckets; ++bucket) {
            pos[bucket] = running;
            running += counts[b][bucket];
        }

        // Scatter elements into the destination buffer
        const std::size_t shift = b * 8;
        for (std::size_t i = 0; i < n; ++i) {
            const T value = src[i];
            const Key key = static_cast<Key>(static_cast<Key>(value) ^ kSignFlip);
            dst[pos[(key >> shift) & 0xFF]++] = value;
        }

        // The freshly written buffer becomes the input of the next pass.
        std::swap(src, dst);
    }

    // After an even number of passes the result is already in the caller's
    // storage; after an odd number (1-byte types) it is in the scratch
    // buffer and must be copied back.
    if (src != origin) {
        std::copy(src, src + n, origin);
    }
}

// =====================================================================
// Benchmark Engine
// =====================================================================
// Benchmarks custom_radix_sort against std::sort on the same pseudo-random
// input (fixed seed, so every run sorts identical data), prints both timings
// and checks the radix-sorted output against the std::sort result.
//
// Returns 0 when both sorts produce the same sequence, 1 otherwise.
int main() {
    const int SIZE = 100000;

    std::cout << "Generating " << SIZE << " random elements...\n";
    std::vector<int> data_std(SIZE);

    // Generate random data (including negative numbers)
    std::mt19937 rng(42);
    std::uniform_int_distribution<int> dist(-1000000, 1000000);
    for (int& value : data_std) {
        value = dist(rng);
    }

    // Both algorithms must sort the exact same input
    std::vector<int> data_custom = data_std;

    // steady_clock is monotonic, so the measured intervals cannot be
    // distorted by system clock adjustments.
    using Clock = std::chrono::steady_clock;
    using Millis = std::chrono::duration<double, std::milli>;

    // ---------------------------------------------------------
    // Benchmark 1: std::sort (O(N log N))
    // ---------------------------------------------------------
    const auto start1 = Clock::now();
    std::sort(data_std.begin(), data_std.end());
    const auto end1 = Clock::now();
    const Millis time_std = end1 - start1;

    // ---------------------------------------------------------
    // Benchmark 2: Custom Radix Sort (O(N))
    // ---------------------------------------------------------
    const auto start2 = Clock::now();
    custom_radix_sort(data_custom.begin(), data_custom.end());
    const auto end2 = Clock::now();
    const Millis time_custom = end2 - start2;

    // ---------------------------------------------------------
    // Results
    // ---------------------------------------------------------
    std::cout << "\n--- Benchmark Results (" << SIZE << " elements) ---\n";
    std::cout << "std::sort time:          " << time_std.count() << " ms\n";
    std::cout << "custom_radix_sort:       " << time_custom.count() << " ms\n";

    // Verify against the reference result. Checking only that the output is
    // ordered would not catch lost or altered elements; comparing with the
    // std::sort output checks both the order and the contents.
    const bool correct = (data_custom == data_std);
    std::cout << "\nCustom sort is correct:  " << (correct ? "True" : "False") << "\n";

    // Speed comparison
    if (time_custom.count() < time_std.count()) {
        std::cout << "🏆 RESULT: custom_radix_sort won!\n";
        // Skip the ratio if the radix sort finished below timer resolution.
        if (time_custom.count() > 0.0) {
            std::cout << "It is " << (time_std.count() / time_custom.count())
                      << "x faster than std::sort.\n";
        }
    } else {
        std::cout << "RESULT: std::sort was at least as fast on this run.\n";
    }

    // Report a failed verification through the exit status as well.
    return correct ? 0 : 1;
}