#include "sound_utils.h"

#include "encoder_decoder.h"
#include "wifi_type.h"

#include <yandex_io/libs/base/crc32.h>
#include <yandex_io/libs/base/utils.h>

#include <cstdint>
#include <limits>
#include <stdexcept>

YIO_DEFINE_LOG_MODULE("sound_init");

using namespace quasar;
using namespace quasar::SoundUtils;

bool SoundUtils::match(double frequency1, double frequency2, int tolerance)
{
    // Two frequencies match when their absolute difference is strictly below `tolerance`
    // (presumably Hz, like the rest of this file).
    const double distance = fabs(frequency2 - frequency1);
    return distance < tolerance;
}

int SoundUtils::getCheckSum(const std::vector<unsigned char>& payload, int offset, size_t length)
{
    // CRC32 over payload[offset, offset + length).
    // NOTE(review): no bounds check here; callers must guarantee offset + length <= payload.size().
    quasar::Crc32 crc;
    crc.extend(reinterpret_cast<const char*>(payload.data()) + offset, length);

    // Only CHECKSUM_BYTES bytes of checksum are stored after the payload (see addChecksum),
    // so keep just the low 8 * CHECKSUM_BYTES bits of the CRC.
    // The mask is built in 64-bit unsigned arithmetic: a plain `int` shift `1 << (8 * CHECKSUM_BYTES)`
    // is undefined behaviour as soon as CHECKSUM_BYTES >= 4.
    const unsigned int checksumBits = 8u * static_cast<unsigned int>(CHECKSUM_BYTES);
    const std::uint64_t mask = checksumBits >= 64u
                                   ? ~std::uint64_t{0}
                                   : (std::uint64_t{1} << checksumBits) - 1u;

    return static_cast<int>(static_cast<std::uint64_t>(crc.checksum()) & mask);
}

bool SoundUtils::isChecksumCorrect(const std::vector<unsigned char>& decoded)
{
    // Validates a frame laid out as [payload][CHECKSUM_BYTES checksum bytes], as built by addChecksum().
    // Returns false when the frame is too short to even contain the checksum.
    if (decoded.size() < static_cast<size_t>(CHECKSUM_BYTES))
    {
        return false;
    }

    const size_t trailerStart = decoded.size() - CHECKSUM_BYTES;

    // The trailer is little-endian: trailer byte k holds bits [8k, 8k + 8) of the checksum.
    // Accumulate in an unsigned 64-bit value: shifting a promoted `unsigned char` (an `int`)
    // left by 24 overflows for bytes >= 0x80, which is undefined behaviour.
    std::uint64_t stored = 0;
    for (int k = 0; k < CHECKSUM_BYTES; ++k)
    {
        stored |= static_cast<std::uint64_t>(decoded[trailerStart + k]) << (8 * k);
    }

    const int decodedChecksum = static_cast<int>(stored);
    const int calculatedChecksum = getCheckSum(decoded, 0, trailerStart);

    return calculatedChecksum == decodedChecksum;
}

std::vector<unsigned char> SoundUtils::addChecksum(const std::vector<unsigned char>& payload)
{
    // Returns a copy of `payload` followed by CHECKSUM_BYTES checksum bytes (see getCheckSum).
    const int checksum = SoundUtils::getCheckSum(payload, 0, payload.size());

    std::vector<unsigned char> framed(payload);
    framed.resize(payload.size() + CHECKSUM_BYTES, 0);

    // Little-endian trailer: trailer byte k carries bits [8k, 8k + 8) of the checksum.
    const size_t trailerStart = payload.size();
    for (int k = 0; k < CHECKSUM_BYTES; ++k)
    {
        framed[trailerStart + k] = static_cast<unsigned char>(checksum >> (8 * k));
    }

    return framed;
}

bool SoundUtils::isPartOfStartHandshake(double frequency, int tolerance)
{
    // True when `frequency` is within `tolerance` of any tone in HANDSHAKE_START.
    bool found = false;
    for (const int handshakeTone : HANDSHAKE_START)
    {
        if (match(handshakeTone, frequency, tolerance))
        {
            found = true;
            break;
        }
    }
    return found;
}

std::vector<double> SoundUtils::filter(const std::vector<double>& frequencies, int tolerance)
{
    // Returns `frequencies` in original order with every start-handshake tone removed.
    std::vector<double> kept;
    kept.reserve(frequencies.size());

    for (size_t idx = 0; idx < frequencies.size(); ++idx)
    {
        const double candidate = frequencies[idx];
        if (isPartOfStartHandshake(candidate, tolerance))
        {
            continue;
        }
        kept.push_back(candidate);
    }

    return kept;
}

int SoundUtils::getProtocolVersionFromFrequencies(const std::vector<double>& frequencies, int chunks, int tolerance)
{
    std::vector<double> filtered_frequencies = filter(frequencies, tolerance);

    if (const std::size_t neededAmount = chunks * 2; filtered_frequencies.size() < neededAmount)
    {
        YIO_LOG_ERROR_EVENT("SoundUtils.PrematureEndOfStream",
                            "Can't find protocol version. Received " << frequencies.size()
                                                                     << " items left after filter only " << filtered_frequencies.size()
                                                                     << " items but need at least " << neededAmount);
        return -1;
    }

    std::vector<int> candidates;
    for (int i = 0; i < chunks; ++i)
    {
        int first = (int)round((filtered_frequencies[i] - START_HZ) / STEP_HZ);
        int second = (int)round((filtered_frequencies[i + chunks] - START_HZ) / STEP_HZ);

        int payload = (first << BITS) | second;

        int version = payload >> 3;
        int checksum = payload & 7;

        if ((version & 7) != checksum)
        {
            YIO_LOG_WARN("Version checksum failed. Version: " << version << " Received checksum: " << checksum);
            continue;
        }

        YIO_LOG_INFO("Protocol version candidate: " << version);
        candidates.push_back(version);
    }

    for (int candidate : candidates)
    {
        if (candidate <= PROTOCOL_VERSION)
        {
            return candidate;
        }
    }

    YIO_LOG_ERROR_EVENT("SoundUtils.ProtocolVersionNotFound", "Can't find any supported version of protocol.");

    if (!candidates.empty())
    {
        return candidates[0];
    }

    return -1;
}

std::vector<std::complex<double>> SoundUtils::toComplex(const std::vector<std::int16_t>& chunk, int size)
{
    // Copies the first `size` samples of `chunk` into a zero-padded complex buffer. The buffer length
    // is the smallest power of two >= size, never less than 2, as required by fft().
    const size_t requested = size > 0 ? static_cast<size_t>(size) : 0;

    // Grow in size_t so large sizes cannot overflow an `int` shift.
    size_t paddedLength = 2;
    while (paddedLength < requested)
    {
        paddedLength <<= 1;
    }

    std::vector<std::complex<double>> result(paddedLength);

    // Never read past the end of `chunk`. Any requested-but-missing samples stay zero.
    const size_t copyCount = requested < chunk.size() ? requested : chunk.size();
    for (size_t i = 0; i < copyCount; ++i)
    {
        result[i] = chunk[i];
    }

    return result;
}

int SoundUtils::argmax(const std::vector<std::complex<double>>& c, const std::vector<FrequencyRange>& frequencyRanges,
                       const std::vector<double>& freqs, int sampleRate)
{
    // Index of the highest-magnitude bin (bin 0 excluded) whose frequency, freqs[bin] * sampleRate,
    // falls inside one of `frequencyRanges` (inclusive), or inside any frequency if no ranges are given.
    // Returns 0 when no bin qualifies. Throws std::runtime_error on an empty spectrum.
    if (c.empty())
    {
        throw std::runtime_error("zero length");
    }

    const auto isEligible = [&frequencyRanges](double hz) {
        if (frequencyRanges.empty()) {
            return true;
        }
        for (const auto& range : frequencyRanges) {
            if (range.min <= hz && hz <= range.max) {
                return true;
            }
        }
        return false;
    };

    // 0 doubles as "no eligible bin yet": the first eligible bin is taken regardless of magnitude.
    int best = 0;
    for (size_t bin = 1; bin < c.size(); ++bin)
    {
        const double hz = freqs[bin] * sampleRate;
        const bool louder = best == 0 || std::abs(c[bin]) > std::abs(c[best]);
        if (louder && isEligible(hz)) {
            best = static_cast<int>(bin);
        }
    }

    return best;
}

// In-place iterative radix-2 FFT of `a`.
// NOTE(review): a.size() must be a power of two; nothing here checks it. toComplex() always pads to one.
// Presumably other sizes give a meaningless result.
// Sign convention: the forward pass (invert == false) uses twiddles exp(+2*pi*i/len). That is the
// opposite sign of the common exp(-2*pi*i*k*n/N) DFT, so for real input the result is the complex
// conjugate of that DFT. Magnitudes, which are all findDominant() uses, are identical.
// With invert == true the opposite sign is used and every output element is divided by n.
void SoundUtils::fft(std::vector<std::complex<double>>& a, bool invert)
{
    int n = (int)a.size();

    // Permute elements into bit-reversed index order so the butterflies below can run in place.
    // `j` tracks the bit-reversed counterpart of `i`.
    for (int i = 1, j = 0; i < n; ++i)
    {
        int bit = n >> 1;
        for (; j >= bit; bit >>= 1) {
            j -= bit;
        }
        j += bit;
        if (i < j) { // swap each pair only once
            swap(a[i], a[j]);
        }
    }

    // Merge adjacent sub-transforms of length len/2 into transforms of length len.
    for (int len = 2; len <= n; len <<= 1)
    {
        double ang = 2 * M_PI / len * (invert ? -1 : 1);
        std::complex<double> wlen(cos(ang), sin(ang)); // principal len-th root of unity (sign per `invert`)
        for (int i = 0; i < n; i += len)
        {
            std::complex<double> w(1);
            for (int j = 0; j < len / 2; ++j)
            {
                // Butterfly: combine the even half (u) with the twiddled odd half (v).
                std::complex<double> u = a[i + j];
                std::complex<double> v = a[i + j + len / 2] * w;
                a[i + j] = u + v;
                a[i + j + len / 2] = u - v;
                w *= wlen; // advance twiddle incrementally
            }
        }
    }
    if (invert) {
        // Normalize the inverse transform.
        for (int i = 0; i < n; ++i) {
            a[i] /= n;
        }
    }
}

std::vector<double> SoundUtils::fftfreq(int n)
{
    // Per-bin sample frequencies in cycles per sample (numpy.fft.fftfreq-style with d = 1).
    // Bins 0..n/2 are non-negative and the remaining bins are negative.
    // NOTE: for even n, bin n/2 is reported as +0.5 here, whereas numpy reports -0.5.
    // Throws std::runtime_error for negative n. n == 0 yields an empty vector.
    if (n < 0)
    {
        throw std::runtime_error("n must not be negative");
    }

    std::vector<double> freqs(n);
    for (int bin = 0; bin < n; ++bin)
    {
        const int signedBin = (bin <= n / 2) ? bin : bin - n;
        freqs[bin] = static_cast<double>(signedBin) / n;
    }

    return freqs;
}

std::pair<double, double>
SoundUtils::findDominant(const std::vector<std::int16_t>& chunk, int size, int sampleRate, const std::vector<FrequencyRange>& frequencyRanges)
{
    // Returns {|peak frequency| * ..., peak magnitude}: the strongest eligible FFT bin of the
    // zero-padded window (see argmax), with its frequency scaled by sampleRate.
    auto spectrum = toComplex(chunk, size);
    fft(spectrum);

    const auto binFrequencies = fftfreq(spectrum.size());
    const int peak = argmax(spectrum, frequencyRanges, binFrequencies, sampleRate);

    const double peakFrequency = std::abs(binFrequencies[peak] * sampleRate);
    const double peakMagnitude = std::abs(spectrum[peak]);
    return {peakFrequency, peakMagnitude};
}

std::vector<unsigned char> SoundUtils::encodePayload(const std::vector<unsigned char>& payload)
{
    std::vector<unsigned char> asBytes;
    for (const unsigned char b : payload)
    {
        asBytes.push_back(b >> 4);
        asBytes.push_back(b & ((unsigned char)0xF));
    }

    EncoderDecoder encoder;
    std::vector<unsigned char> encoded = encoder.encodeData(asBytes, FEC_BYTES);
    std::vector<unsigned char> result = encoder.encodeData(asBytes, FEC_BYTES);
    return result;
}

std::vector<int> SoundUtils::getProtocolVersionAsFrequencies(int protocolVersion)
{
    // The version word is (version << 3) | (version & 7): the low 3 bits repeat the version's own
    // low bits as a checksum. The word is sent as two tones, first its bits above bit 3
    // (truncated to a byte), then its low 4 bits, each mapped to START_HZ + value * STEP_HZ.
    const int versionWord = (protocolVersion << 3) | (protocolVersion & 7);
    const unsigned char highPart = static_cast<unsigned char>(versionWord >> 4);
    const unsigned char lowPart = static_cast<unsigned char>(versionWord & 0xF);

    std::vector<int> tones;
    tones.reserve(2);
    tones.push_back(START_HZ + highPart * STEP_HZ);
    tones.push_back(START_HZ + lowPart * STEP_HZ);
    return tones;
}

BitStreamToneGenerator::BitStreamToneGenerator(std::vector<unsigned char> payload, int protocolVersion)
    : versionFrequencies(SoundUtils::getProtocolVersionAsFrequencies(protocolVersion))
    , encodedPayload(SoundUtils::encodePayload(payload))
{
    // Version tones and the FEC-encoded payload symbols are computed once here;
    // hasNext()/next() only walk these cached sequences plus the handshake tables.
}

bool BitStreamToneGenerator::hasNext() {
    // More tones remain while any section (handshakes, version, payload) is not fully emitted.
    const bool handshakesPending = startHandshakeCount < SoundUtils::HANDSHAKE_START_LEN
                                   || endHandshakeCount < SoundUtils::HANDSHAKE_END_LEN;
    const bool versionPending = versionCount < versionFrequencies.size();
    const bool payloadPending = payloadCount < encodedPayload.size();

    return handshakesPending || versionPending || payloadPending;
}

int BitStreamToneGenerator::next() {
    if (startHandshakeCount < SoundUtils::HANDSHAKE_START_LEN) {
        int frequency = SoundUtils::HANDSHAKE_START[startHandshakeCount];

        startHandshakeCount++;
        return frequency;
    }

    if (versionCount < versionFrequencies.size()) {
        int frequency = versionFrequencies[versionCount];

        versionCount++;
        return frequency;
    }

    if (endHandshakeCount < SoundUtils::HANDSHAKE_END_LEN && payloadCount >= encodedPayload.size()) {
        int frequency = SoundUtils::HANDSHAKE_END[endHandshakeCount];

        endHandshakeCount++;
        return frequency;
    }

    int step = encodedPayload[payloadCount];

    payloadCount++;
    return START_HZ + step * STEP_HZ;
}

size_t BitStreamToneGenerator::size() {
    // Total number of tones the generator emits: both handshakes, version tones and payload symbols.
    size_t total = SoundUtils::HANDSHAKE_START_LEN;
    total += versionFrequencies.size();
    total += encodedPayload.size();
    total += SoundUtils::HANDSHAKE_END_LEN;
    return total;
}

std::vector<std::int16_t> SoundUtils::samplesGeneratorGenerateSamples(int frequency, int sampleRate) {
    // One full-scale sine tone of `frequency`, sampled at `sampleRate` for
    // round(sampleRate * DURATION) samples. The phase starts at 0 and advances by a fixed step.
    const auto sampleCount = static_cast<size_t>(round(sampleRate * SoundUtils::DURATION));
    const double phaseStep = 2 * M_PI * frequency / sampleRate;
    const double amplitude = std::numeric_limits<std::int16_t>::max();

    std::vector<std::int16_t> tone;
    tone.reserve(sampleCount);

    double phase = 0;
    for (size_t idx = 0; idx < sampleCount; ++idx) {
        tone.push_back(static_cast<std::int16_t>(sin(phase) * amplitude));
        phase += phaseStep;
    }
    return tone;
}

std::vector<std::vector<std::int16_t>> SoundUtils::samplesGeneratorGenerate(BitStreamToneGenerator frequencies, int sampleRate) {
    // Drains the generator (taken by value, so the caller's copy is untouched) and renders
    // each emitted frequency into its own block of PCM samples.
    std::vector<std::vector<std::int16_t>> tones;
    tones.reserve(frequencies.size());

    while (frequencies.hasNext()) {
        const int toneFrequency = frequencies.next();
        tones.emplace_back(samplesGeneratorGenerateSamples(toneFrequency, sampleRate));
    }
    return tones;
}

std::vector<std::vector<std::int16_t>> SoundUtils::soundDataSourceGenerateSamples(int protocolVersion, std::vector<unsigned char> payload, int sampleRate)
{
    // Full transmit pipeline: append checksum, build the tone sequence, render PCM blocks.
    // An empty payload produces no audio at all (not even handshakes).
    if (payload.empty()) {
        return {};
    }

    return samplesGeneratorGenerate(
        BitStreamToneGenerator(SoundUtils::addChecksum(payload), protocolVersion),
        sampleRate);
}

std::vector<unsigned char> SoundUtils::convert(const std::vector<std::int16_t>& sample) {
    // Serializes 16-bit samples as little-endian bytes (low byte first).
    std::vector<unsigned char> bytes;
    bytes.reserve(sample.size() * 2);

    for (const std::int16_t value : sample)
    {
        bytes.push_back(static_cast<unsigned char>(value & 0xFF));
        bytes.push_back(static_cast<unsigned char>((value & 0xFF00) >> 8));
    }

    return bytes;
}

std::vector<std::int16_t> SoundUtils::convert(const std::vector<unsigned char>& data) {
    // Parses little-endian 16-bit samples (low byte first). A trailing odd byte is ignored.
    std::vector<std::int16_t> result;

    result.reserve(data.size() / 2);
    // `i + 1 < size` rather than `i < size - 1`: the latter underflows to SIZE_MAX for empty
    // input and reads out of bounds.
    for (size_t i = 0; i + 1 < data.size(); i += 2) {
        const auto lo = static_cast<std::uint16_t>(data[i]);
        const auto hi = static_cast<std::uint16_t>(data[i + 1]);
        // Reinterpret the 16-bit pattern as a signed sample (two's complement).
        result.push_back(static_cast<std::int16_t>(static_cast<std::uint16_t>(lo | (hi << 8))));
    }

    return result;
}
