#pragma once

#include <util/generic/vector.h>

#include <algorithm>
#include <cmath>
#include <numeric>

#if __x86_64__
#   include <xmmintrin.h>
#   include <emmintrin.h>
#endif

namespace NSrv {
    // Select the `k` best indices out of [0, n) and return them sorted by `cmp`.
    //
    // `cmp(a, b)` takes two indices and returns true when index `a` must come before index `b`
    // in the result. It is handed to std::make_heap / std::sort, so it has to be a strict weak
    // ordering, and it is only ever called with indices below `n`.
    //
    // `k` is clamped to `n`, so the result holds min(k, n) indices; it is empty when either is 0.
    // Only that many indices are kept in memory: they form a heap whose top is the worst
    // candidate kept so far, and every later index that beats the top replaces it.
    //
    // NOTE: the function is declared noexcept, so an exception escaping from `cmp` (or from the
    // vector allocation, if it throws) terminates the program.
    template <typename F>
    TVector<size_t> PartialArgSort(size_t n, size_t k, F&& cmp) noexcept {
        // Never keep more candidates than there are indices; otherwise the result would
        // contain (and `cmp` would be called with) indices outside [0, n).
        k = std::min(k, n);
        TVector<size_t> r(k);
        if (k == 0) {
            // Nothing to select; also keeps r.front() below from touching an empty vector.
            return r;
        }
        // Seed with a size_t so that the running counter has the same type as the elements.
        std::iota(r.begin(), r.end(), size_t{0});
        std::make_heap(r.begin(), r.end(), cmp);
        for (size_t i = k; i < n; i++) {
            // r.front() is the candidate that sorts last among those kept.
            if (cmp(i, r.front())) {
                std::pop_heap(r.begin(), r.end(), cmp);
                r.back() = i;
                std::push_heap(r.begin(), r.end(), cmp);
            }
        }
        std::sort(r.begin(), r.end(), cmp);
        return r;
    }

    // Sample from a weighted set of items K times without repetitions, return the indices.
    //
    // Every item i gets a random key u^invWeights[i], and the `k` items with the largest keys
    // are returned, ordered by descending key. This has the shape of Efraimidis-Spirakis weighted
    // sampling -- presumably invWeights[i] == 1 / weight[i]; confirm against the callers.
    // The x86-64 path ranks items by an approximation of log2(u) * invWeights[i] instead of the
    // power itself.
    //
    // `rng` has to provide GenRand() (x86-64 path) or GenRandReal1() (portable path).
    // `k` is clamped to invWeights.size(); the result is empty if there is nothing to sample.
    template <typename T>
    TVector<size_t> RandomSample(T&& rng, const TVector<float>& invWeights, size_t k) noexcept {
        // Clamp the request so that the comparator at the bottom is never called with an index
        // outside `order`, and skip all work (including drawing from `rng`) for an empty sample.
        k = std::min<size_t>(k, invWeights.size());
        if (k == 0) {
            return TVector<size_t>();
        }
        TVector<float> order(invWeights.size());
#if __x86_64__
        auto transform = [&](__m128 w) {
            // SSE pow(x, y) is generally exp2(log2(x) * y); monotonically increasing exp() can
            // be discarded, and log2(x) = exponent(x) + polynomial_log2(mantissa(x) + 1).
            // Four random integers are converted to float and their bits are then picked apart,
            // so the exponent and mantissa fields below describe the magnitude of each integer.
            auto ix = _mm_castps_si128(_mm_cvtepi32_ps(_mm_set_epi32(rng.GenRand(), rng.GenRand(), rng.GenRand(), rng.GenRand())));
            auto ie = _mm_and_si128(ix, _mm_set1_epi32(0x7F800000));
            auto im = _mm_and_si128(ix, _mm_set1_epi32(0x007FFFFF)); // discard the sign
            // 127 for the IEEE754 offset, 31 more to transform [0, 2**31) into [0, 1)
            auto e = _mm_cvtepi32_ps(_mm_sub_epi32(_mm_srli_epi32(ie, 23), _mm_set1_epi32(127 + 31)));
            // OR-ing the mantissa bits into 1.0f yields a float in [1, 2).
            auto m = _mm_or_ps(_mm_castsi128_ps(im), _mm_set1_ps(1.0f));
            // P2[x] approximation of log2(x) on [1, 2]. Abs error <= 0.004, max at points {1, 1.3, 1.7, 2}.
            auto p = _mm_add_ps(e, _mm_set1_ps(-1.672031296744f));
            p = _mm_add_ps(p, _mm_mul_ps(_mm_set1_ps(2.01963344234f), m));
            p = _mm_add_ps(p, _mm_mul_ps(_mm_set1_ps(-0.3429675467f), _mm_mul_ps(m, m)));
            return _mm_mul_ps(_mm_min_ps(p, _mm_set1_ps(-1e-44f)), w); // clip to avoid log2(x) >= 0
        };
        // Whole groups of four weights first.
        // NOTE(review): _mm_load_ps / _mm_store_ps require 16-byte aligned addresses, i.e. this
        // relies on both vectors' buffers being 16-byte aligned -- confirm for TVector's allocator.
        size_t rounded = order.size() - order.size() % 4;
        for (size_t i = 0; i < rounded; i += 4) {
            _mm_store_ps(&order[i], transform(_mm_load_ps(&invWeights[i])));
        }
        // Remaining 0..3 weights one by one. Each call still evaluates all four lanes (and draws
        // four random numbers); only the lowest lane is loaded and stored.
        for (size_t i = rounded; i < order.size(); i++) {
            _mm_store_ss(&order[i], transform(_mm_load_ss(&invWeights[i])));
        }
#else
        for (size_t i = 0; i < order.size(); i++) {
            // The float overload of std::pow is used instead of std::powf, which not every
            // standard library declares inside namespace std.
            order[i] = std::pow(static_cast<float>(1. - rng.GenRandReal1()), invWeights[i]);
        }
#endif
        // Larger key first: the result lists the sampled indices by descending key.
        return PartialArgSort(order.size(), k, [&](size_t a, size_t b) {
            return order[a] > order[b];
        });
    }
}
