#include "sound_data_receiver.h"

#include "encoder_decoder.h"
#include "sound_utils.h"

#include <yandex_io/libs/base/utils.h>
#include <yandex_io/libs/device/device.h>
#include <yandex_io/libs/json_utils/json_utils.h>
#include <yandex_io/libs/logging/logging.h>

#include <vector>

// Declares the log module name "sound_init" for this file; presumably this is
// what the YIO_LOG_* macros used below attribute their records to (the macro
// comes from the logging header — confirm there).
YIO_DEFINE_LOG_MODULE("sound_init");

// Names from namespace quasar are used unqualified in the rest of this file.
using namespace quasar;

// Creates a receiver for audio sampled at `sampleRate`, configured from `device`.
//
// frameSize_ is sampleRate_ * SoundUtils::DURATION / SPLIT_CHUNKS (rounded): the
// number of samples analysed per step. maxPacketsCount_ is
// MAX_WAIT_SECONDS * SPLIT_CHUNKS / SoundUtils::DURATION (rounded): the largest
// number of frames a single message may span before it is abandoned.
// The amplitude threshold and frequency tolerance are read from the
// "sound_initd" service config, defaulting to 300 and SoundUtils::TOLERANCE.
SoundDataReceiver::SoundDataReceiver(std::shared_ptr<YandexIO::IDevice> device, size_t sampleRate)
    : device_(std::move(device))
    , sampleRate_(sampleRate)
    , frameSize_(static_cast<size_t>(round(sampleRate_ * SoundUtils::DURATION / SPLIT_CHUNKS)))
    , maxPacketsCount_(static_cast<size_t>(round(MAX_WAIT_SECONDS * SPLIT_CHUNKS / SoundUtils::DURATION)))
{
    auto serviceConfig = device_->configuration()->getServiceConfig("sound_initd");
    amplitudeThreshold_ = tryGetInt(serviceConfig, "amplitudeThreshold", 300);
    frequencyTolerance_ = tryGetInt(serviceConfig, "frequencyTolerance", SoundUtils::TOLERANCE);

    // Cache the lowest and highest frequency of the start-handshake sequence.
    const auto handshakeBegin = SoundUtils::HANDSHAKE_START;
    const auto handshakeEnd = SoundUtils::HANDSHAKE_START + SoundUtils::HANDSHAKE_START_LEN;
    handshakeStartMin_ = *std::min_element(handshakeBegin, handshakeEnd);
    handshakeStartMax_ = *std::max_element(handshakeBegin, handshakeEnd);
}

// Shuts the receiver down if it is still running: stops the audio stream and
// joins the listen thread. A failure to join is logged instead of propagated,
// since a destructor must not throw.
SoundDataReceiver::~SoundDataReceiver()
{
    const bool wasAlreadyStopped = stopped_.exchange(true);
    if (wasAlreadyStopped) {
        return;
    }

    audioStream_.stop();
    try {
        listenThread_.join();
    } catch (const std::exception& joinError) {
        YIO_LOG_ERROR_EVENT("SoundDataReceiver.FailedJoinListenThread", "Error while joining listenThread_ in ~SoundDataReceiver: " << joinError.what());
    }
}

void SoundDataReceiver::start()
{
    if (stopped_.exchange(false)) {
        audioStream_.start();
        listenThread_ = std::thread(&SoundDataReceiver::listen, this);

        YIO_LOG_INFO("Sound data receiver started.");
    }
}

// Hands a block of 16-bit samples to audioStream_; listen() later reads them
// back from the same stream in frameSize_-sized pieces.
void SoundDataReceiver::write(const std::vector<std::int16_t>& buffer)
{
    audioStream_.write(buffer);
}

// Stops a running receiver: stops the audio stream, waits for the listen
// thread to finish, then drops any samples still buffered in the stream.
// Does nothing when the receiver is already stopped.
void SoundDataReceiver::stop()
{
    if (stopped_.exchange(true)) {
        return;
    }

    // The stream is stopped before joining; presumably this is what lets a
    // pending audioStream_.read() in listen() return — confirm in the stream class.
    audioStream_.stop();

    // join() on a non-joinable thread throws std::system_error, which would skip
    // the cleanup below. The thread is non-joinable if its creation failed in
    // start() after stopped_ had already been cleared.
    if (listenThread_.joinable()) {
        listenThread_.join();
    }
    audioStream_.clear();

    YIO_LOG_INFO("Sound data receiver stopped.");
}

void SoundDataReceiver::listen()
{
    YIO_LOG_INFO("Sound data receiver started.");
    YIO_LOG_INFO("FRAME_SIZE: " << frameSize_);
    bool record = false;
    std::vector<std::int16_t> buffer;
    bool handshakeStartOccurred = false;
    std::vector<double> packet;
    std::vector<double> handshakeStart;
    while (!stopped_)
    {
        try {
            if (!audioStream_.read(buffer, frameSize_)) {
                continue;
            }
            std::pair<double, double> dominant;
            if (record) {
                dominant = SoundUtils::findDominant(
                    buffer, frameSize_, sampleRate_,
                    {
                        {SoundUtils::START_HZ - SoundUtils::STEP_HZ / 2,
                         SoundUtils::START_HZ + (1 << SoundUtils::BITS) * SoundUtils::STEP_HZ + SoundUtils::STEP_HZ / 2 - 1},
                        {SoundUtils::HANDSHAKE_END_HZ - frequencyTolerance_ + 1,
                         SoundUtils::HANDSHAKE_END_HZ + frequencyTolerance_ - 1},
                    });
            } else if (handshakeStartOccurred) {
                dominant = SoundUtils::findDominant(
                    buffer, frameSize_, sampleRate_,
                    {{SoundUtils::START_HZ - SoundUtils::STEP_HZ, SoundUtils::HANDSHAKE_START_HZ + SoundUtils::STEP_HZ}});
            } else {
                dominant = SoundUtils::findDominant(buffer, frameSize_, sampleRate_, {});
            }
            double dominantFrequency = dominant.first;
            double dominantAmplitude = dominant.second;

            if (dominantAmplitude <= amplitudeThreshold_ && !handshakeStartOccurred) {
                continue;
            }
            YIO_LOG_TRACE("Dominant received: " << dominantFrequency << " with amplitude " << dominantAmplitude);
            if (record && SoundUtils::match(dominantFrequency, SoundUtils::HANDSHAKE_END_HZ, frequencyTolerance_)) {
                onSoundMessageEnd(packet, handshakeStart, record, handshakeStartOccurred);
            } else if (record) {
                onNewMessageFrequency(dominantFrequency, packet, handshakeStart, record, handshakeStartOccurred);
            } else {
                onNewFrequency(dominantFrequency, dominantAmplitude, packet, handshakeStart, record, handshakeStartOccurred);
            }

        } catch (const std::exception& e)
        {
            if (!stopped_) {
                YIO_LOG_ERROR_EVENT("SoundDataReceiver.UnknownListenException", "IOException caught: " << e.what());
            }
        }
    }
}

// Handles a dominant frequency received while not recording.
// The frequency is appended to the sliding `handshakeStart` window; seeing the
// HANDSHAKE_START_HZ tone sets `handshakeStartOccurred`. Once the window matches
// the whole start handshake, recording begins: `record` is set, both buffers are
// cleared and onTransferStart (if assigned) is invoked.
void SoundDataReceiver::onNewFrequency(double dominantFrequency, double dominantAmplitude, std::vector<double>& packet,
                                       std::vector<double>& handshakeStart, bool& record, bool& handshakeStartOccurred) const {
    updateStartHandshake(handshakeStart, dominantFrequency);

    const bool isStartTone = SoundUtils::match(dominantFrequency, SoundUtils::HANDSHAKE_START_HZ, frequencyTolerance_);
    if (isStartTone) {
        handshakeStartOccurred = true;
    }

    if (!isHandshakeStartMatching(handshakeStart)) {
        return;
    }

    YIO_LOG_INFO("Started recording");
    YIO_LOG_INFO("Last dominantAmplitude: " << dominantAmplitude);
    record = true;
    packet.clear();
    handshakeStart.clear();
    if (onTransferStart) {
        onTransferStart();
    }
}

// Handles a dominant frequency received while recording: appends it to `packet`.
// If the packet grows beyond maxPacketsCount_ the transfer is abandoned: state
// flags and buffers are reset and onTransferError (if assigned) is invoked.
void SoundDataReceiver::onNewMessageFrequency(double dominantFrequency, std::vector<double>& packet,
                                              std::vector<double>& handshakeStart, bool& record, bool& handshakeStartOccurred) const {
    packet.push_back(dominantFrequency);

    const bool withinLimit = packet.size() <= maxPacketsCount_;
    if (withinLimit) {
        return;
    }

    YIO_LOG_WARN("Too much packets received");
    record = false;
    handshakeStartOccurred = false;
    packet.clear();
    handshakeStart.clear();
    if (onTransferError) {
        onTransferError();
    }
}

// Handles the end-handshake tone: decodes the recorded `packet` and reports it.
//
// The protocol version is read from the packet; -1 (unreadable) is treated as
// version 0. A version above SoundUtils::PROTOCOL_VERSION resets the state and
// reports onUnsupportedProtocol. Otherwise the payload is extracted:
// an empty payload reports onTransferError, a non-empty one onDataReceived.
// In every case the receiver ends up back in the idle state with empty buffers.
void SoundDataReceiver::onSoundMessageEnd(std::vector<double>& packet, std::vector<double>& handshakeStart,
                                          bool& record, bool& handshakeStartOccurred)
{
    YIO_LOG_DEBUG("PACKET:" << join(packet, ", "));
    YIO_LOG_DEBUG("PACKET LENGTH:" << packet.size());

    // Returns the state machine to idle and discards everything collected so far.
    const auto resetState = [&packet, &handshakeStart, &record, &handshakeStartOccurred]() {
        packet.clear();
        handshakeStart.clear();
        record = false;
        handshakeStartOccurred = false;
    };

    int protocolVersion = SoundUtils::getProtocolVersionFromFrequencies(packet, SPLIT_CHUNKS, frequencyTolerance_);

    if (protocolVersion == -1) {
        YIO_LOG_WARN("Can't read protocol version. Consider as version 0");
        protocolVersion = 0;
    } else if (protocolVersion > SoundUtils::PROTOCOL_VERSION) {
        YIO_LOG_WARN("Unsupported protocol version received: " << protocolVersion);

        // Here the state is reset before the callback runs.
        resetState();
        if (onUnsupportedProtocol) {
            onUnsupportedProtocol(protocolVersion);
        }
        return;
    }

    YIO_LOG_INFO("Source protocol version is: " << protocolVersion);

    std::vector<byte> payload = getPayload(packet);
    if (payload.empty()) {
        if (onTransferError) {
            onTransferError();
        }
    } else if (onDataReceived) {
        onDataReceived(payload, protocolVersion);
    }

    // On this path the callbacks run first and the state is reset afterwards.
    resetState();
}

std::vector<SoundDataReceiver::byte> SoundDataReceiver::getPayload(const std::vector<double>& packet) const {
    std::vector<std::vector<double>> split_packet = split(packet);
    for (std::vector<double>& splitPacket : split_packet)
    {
        try {
            splitPacket = std::vector<double>(splitPacket.begin() + 2, splitPacket.end());
            std::vector<byte> extracted = extractPacket(splitPacket);
            std::string s;
            for (int b : splitPacket) {
                s += std::to_string(b);
            }
            YIO_LOG_DEBUG("SPLITTED PACKET:" << s);
            YIO_LOG_DEBUG("SPLITTED PACKET SIZE:" << splitPacket.size());
            s = "";
            for (int b : extracted) {
                s += std::to_string(b) + ", ";
            }
            YIO_LOG_DEBUG("EXTRACTED: " << s);
            YIO_LOG_DEBUG("EXTRACTED SIZE: " << extracted.size());

            std::vector<byte> decoded = EncoderDecoder().decodeData(extracted, SoundUtils::FEC_BYTES);
            YIO_LOG_DEBUG("DECODED: " + std::string(decoded.begin(), decoded.end()));
            std::vector<byte> assembled = decodeBitChunks(SoundUtils::BITS, decoded);
            YIO_LOG_DEBUG("ASSEMBLED: " + std::string(assembled.begin(), assembled.end()));

            if (SoundUtils::isChecksumCorrect(assembled))
            {
                std::vector<byte> payload(assembled.begin(), assembled.begin() + assembled.size() -
                                                                 SoundUtils::CHECKSUM_BYTES);

                YIO_LOG_DEBUG("PAYLOAD:" + std::string(payload.begin(), payload.end()));
                YIO_LOG_DEBUG(join(payload, ", "));
                return payload;
            } else {
                YIO_LOG_DEBUG("Incorrect check sum");
            }
        } catch (std::exception& e) {
            YIO_LOG_INFO("Bad shift: " << e.what());
        }
    }
    return std::vector<SoundDataReceiver::byte>();
}

// Tells whether the sliding window `handshakeStart` contains the complete start
// handshake. The window must hold at least HANDSHAKE_START_LEN * SPLIT_CHUNKS
// entries, and for every handshake tone i at least one entry among the
// SPLIT_CHUNKS consecutive entries starting at i * SPLIT_CHUNKS must match it.
bool SoundDataReceiver::isHandshakeStartMatching(const std::vector<double>& handshakeStart) const {
    const bool windowTooShort = handshakeStart.size() < SoundUtils::HANDSHAKE_START_LEN * SPLIT_CHUNKS;
    if (windowTooShort) {
        return false;
    }

    size_t toneIdx = 0;
    while (toneIdx < SoundUtils::HANDSHAKE_START_LEN) {
        const bool toneFound = findInHandshake(handshakeStart, SoundUtils::HANDSHAKE_START[toneIdx],
                                               toneIdx * SPLIT_CHUNKS,
                                               (toneIdx + 1) * SPLIT_CHUNKS);
        if (!toneFound) {
            return false;
        }
        ++toneIdx;
    }

    return true;
}

// Returns true when any entry of `handshakeStart` with index in [start, end)
// matches `handshakeFrequency` within frequencyTolerance_.
// The caller must ensure the index range lies inside the vector.
bool SoundDataReceiver::findInHandshake(const std::vector<double>& handshakeStart, int handshakeFrequency, int start,
                                        int end) const {
    int position = start;
    while (position < end) {
        const bool matches = SoundUtils::match(handshakeFrequency, handshakeStart[position], frequencyTolerance_);
        if (matches) {
            return true;
        }
        ++position;
    }
    return false;
}

// Appends `dominant` to the sliding window `handshakeStart` and trims the oldest
// entries so that at most HANDSHAKE_START_LEN * SPLIT_CHUNKS remain.
void SoundDataReceiver::updateStartHandshake(std::vector<double>& handshakeStart, double dominant) const {
    handshakeStart.push_back(dominant);

    if (handshakeStart.size() > SoundUtils::HANDSHAKE_START_LEN * SPLIT_CHUNKS) {
        // Drop everything beyond the window capacity from the front in one go.
        const auto surplus = handshakeStart.size() - SoundUtils::HANDSHAKE_START_LEN * SPLIT_CHUNKS;
        handshakeStart.erase(handshakeStart.begin(), handshakeStart.begin() + surplus);
    }
}

// De-interleaves `packet` into SPLIT_CHUNKS streams: the element at index i goes
// to stream i % SPLIT_CHUNKS. Elements that SoundUtils::isPartOfStartHandshake
// recognises are skipped, but they still occupy their index, so the stream
// assignment of the remaining elements is unaffected.
std::vector<std::vector<double>> SoundDataReceiver::split(const std::vector<double>& packet) const {
    std::vector<std::vector<double>> result(SPLIT_CHUNKS);

    // Reserve for every stream based on the actual stream count rather than
    // assuming there are exactly two of them.
    for (std::vector<double>& stream : result) {
        stream.reserve(packet.size() / result.size() + 1);
    }

    for (size_t i = 0; i < packet.size(); ++i) {
        if (SoundUtils::isPartOfStartHandshake(packet[i], frequencyTolerance_)) {
            continue;
        }
        result[i % SPLIT_CHUNKS].push_back(packet[i]);
    }

    return result;
}

std::vector<SoundDataReceiver::byte> SoundDataReceiver::extractPacket(const std::vector<double>& packet) {
    std::vector<byte> bitChunks(packet.size());

    for (size_t i = 0; i < packet.size(); ++i)
    {
        bitChunks[i] = (byte)round((packet[i] - SoundUtils::START_HZ) / SoundUtils::STEP_HZ);
    }

    return bitChunks;
}

std::vector<SoundDataReceiver::byte>
SoundDataReceiver::decodeBitChunks(int bits, const std::vector<SoundDataReceiver::byte>& chunks) {
    std::vector<byte> out;

    size_t nextReadChunk = 0;
    byte nextReadBit = 0;

    int currentByte = 0;
    int bitsLeft = 8;
    while (nextReadChunk < chunks.size())
    {
        int canFill = bits - nextReadBit;
        int toFill = std::min(bitsLeft, canFill);

        int offset = bits - nextReadBit - toFill;

        currentByte <<= toFill;
        int shifted = chunks[nextReadChunk] & (((1 << toFill) - 1) << offset);
        currentByte |= shifted >> offset;

        bitsLeft -= toFill;
        nextReadBit += toFill;

        if (bitsLeft <= 0)
        {
            out.push_back((byte)currentByte);
            currentByte = 0;
            bitsLeft = 8;
        }

        if (nextReadBit >= bits)
        {
            nextReadChunk++;
            nextReadBit -= bits;
        }
    }

    return out;
}
