#include "voice_stats.h"

#include <voicetech/vqe/libs/stft/stft.h>
#include <voicetech/vqe/libs/utils/utils.h>
#include <yandex_io/libs/logging/logging.h>

#include <deque>
#include <numeric>

using namespace quasar;

namespace {
    /// Compact per-channel RMS series: ten averaged values.
    using RmsArray = std::array<float, 10>;

    /// Averages every four consecutive source values into one output value.
    ///
    /// @param data  per-chunk RMS values of one channel.
    /// @return      averaged series; slots for which no complete group of four
    ///              source values exists are zero. An incomplete trailing group
    ///              is ignored.
    RmsArray squeezeChannel(const RMSChannel::DataType& data) {
        constexpr std::size_t groupSize = 4;
        static_assert(std::tuple_size<RmsArray>::value * groupSize >= std::tuple_size<RMSChannel::DataType>::value,
                      "squeezeChannel compact elements in proportion 1:4. Source array is too large");

        // Value-initialised: the static_assert only bounds the source from above,
        // so a shorter source must not leave indeterminate values in the tail.
        RmsArray result{};
        auto dst = std::begin(result);
        std::size_t groupFill = 0;
        float groupSum = 0.0f;

        for (const auto value : data) {
            groupSum += value;
            if (++groupFill == groupSize) {
                *dst = groupSum / float(groupSize);
                ++dst;
                groupSum = 0.0f;
                groupFill = 0;
            }
        }
        return result;
    }
} // unnamed namespace

/// Builds a channel record from its name, the two averages and the value series.
/// The name is moved into the record; the remaining arguments are copied.
RMSChannel::RMSChannel(std::string n, float avg, float rawAvg, DataType d)
    : name(std::move(n)),
      average(avg),
      rawAverage(rawAvg),
      data{d} {
}

namespace {
    // Name of the channel that pushAudioChannels() leaves out of the RMS statistics.
    constexpr const char* RCU_CHANNEL_NAME = "raw_rcu";

    /// Accumulates the PCM samples of one channel into per-chunk sums of squares
    /// and converts the most recent chunks into RMS values on demand.
    struct PreRms {
        std::int64_t cur = 0;   // sum of squares of the chunk currently being filled
        int counts = 0;         // number of samples already added to `cur`
        bool fresh = false;     // needs to remove outdated channels
        std::deque<float> data; // FIXME: use std::array as a circle

        static constexpr int frameSize = 512;
        static constexpr int chunkSize = frameSize * 2;

        PreRms(bool /*unused*/) {} // to implement emplace in unordered_map

        using DataType = RMSChannel::DataType;

        // Number of chunk values that fit into one result array.
        static constexpr std::size_t capacity = std::tuple_size<DataType>::value;

        /// Drops the oldest chunk values until at most `capacity` of them remain.
        void dropOverflow() {
            while (data.size() > capacity) {
                data.pop_front();
            }
        }

        /// Adds one sample. Every `chunkSize` samples the accumulated sum of squares
        /// is appended to `data` and the accumulator restarts from zero.
        PreRms& operator+=(std::int64_t pcm) {
            cur += pcm * pcm;
            if (++counts == chunkSize) {
                data.push_back(float(cur));
                cur = 0;
                counts = 0;
                // A loop rather than a single pop: `data` may have been extended
                // by calcWeightedSquareRootsOfAverages() since the last push.
                dropOverflow();
            }
            return *this;
        }

        /// Converts the stored sums of squares into RMS values.
        ///
        /// @param multiplier  correction factor applied to the per-chunk values and
        ///                    to the first returned average.
        /// @return {average of corrected RMS values, average of uncorrected RMS values,
        ///          corrected per-chunk RMS values, zero-padded at the end}.
        ///          Both averages are 0 when no complete chunk has been collected yet.
        std::tuple<float, float, DataType> calcSquareRootsOfAverages(float multiplier) {
            DataType rval = {};

            // Never write past the fixed-size result: use at most `capacity`
            // of the most recent chunk values.
            const std::size_t count = data.size() < capacity ? data.size() : capacity;
            if (count == 0) {
                // Avoids 0 / 0 (NaN) for a channel without a complete chunk.
                return {0.0f, 0.0f, rval};
            }

            float averageRms = 0.0f;
            auto src = data.cend() - static_cast<std::deque<float>::difference_type>(count);
            auto dst = rval.begin();
            for (; src != data.cend(); ++src, ++dst) {
                // The stored sum is consumed as the float it is; routing it through
                // an integer would truncate it and is undefined for out-of-range values.
                const float rmsValue = sqrtf(*src / float(chunkSize));
                averageRms += rmsValue;
                *dst = rmsValue * multiplier;
            }

            const float correctedAvgRms = (averageRms * multiplier) / float(count);
            averageRms /= float(count);
            return {correctedAvgRms, averageRms, rval};
        }

        /// Passes the stored chunk values through a forward STFT, scales the
        /// frequency-domain frame element-wise by `fftWindow` (only when both have
        /// the same size), transforms back, appends the resulting values to `data`
        /// and returns calcSquareRootsOfAverages(multiplier).
        ///
        /// NOTE(review): the inverse-transform output is appended to the existing
        /// values rather than substituted for them - confirm this is intended.
        std::tuple<float, float, DataType> calcWeightedSquareRootsOfAverages(float multiplier, const std::vector<float>& fftWindow) {
            vqe::FDFrame<float> fdFrame;
            vqe::Stft<float> stft(frameSize, vqe::StftWindowFunction::NONE, vqe::StftTransformType::STFT);
            stft.stft(data.begin(), data.end(), fdFrame);
            if (fftWindow.size() == fdFrame.size()) {
                for (size_t i = 0; i < fdFrame.size(); ++i) {
                    fdFrame[i] *= fftWindow[i];
                }
            }
            std::vector<float> tdFrame;
            stft.istft(fdFrame, tdFrame);
            std::move(begin(tdFrame), end(tdFrame), back_inserter(data));
            // Keep `data` bounded; otherwise every call would grow it further.
            dropOverflow();
            return calcSquareRootsOfAverages(multiplier);
        }
    };

    class Compressor {
        PreRms& dst;
        const int ratio = 1;
        std::int64_t current = 0;
        int passCount = 0;

    public:
        Compressor(PreRms& d, int r)
            : dst(d)
            , ratio(r)
                  {};

        Compressor& operator+=(std::int64_t pcm) {
            current += pcm;
            if (++passCount == ratio) {
                dst += current / ratio;
                passCount = 0;
                current = 0;
            }
            return *this;
        }
    };

    /// VoiceStats implementation keeping one PreRms accumulator per audio channel.
    /// mutex_ guards preRms_, calcedRMS_ and rmsCorrection_; calcOnVqe_ is atomic
    /// and accessed without the lock.
    class VoiceStatsImpl: public VoiceStats {
        std::mutex mutex_;

        std::unordered_map<std::string, PreRms> preRms_;
        RMSInfo calcedRMS_;
        float rmsCorrection_ = 1.0;
        std::atomic_bool calcOnVqe_{false};

        // Rebuilds calcedRMS_ from every accumulator. The caller must hold mutex_.
        void calcRMSInfo() {
            calcedRMS_.clear();
            calcedRMS_.reserve(preRms_.size());
            for (auto& [name, src] : preRms_) {
                auto [average, rawAverage, data] = src.calcSquareRootsOfAverages(rmsCorrection_);
                calcedRMS_.emplace_back(name, average, rawAverage, data);
            }
        }

        // Same as calcRMSInfo(), but through the weighted calculation.
        // The caller must hold mutex_.
        void calcWeightedRMSInfo(const std::vector<float>& fftWindow) {
            calcedRMS_.clear();
            calcedRMS_.reserve(preRms_.size());
            for (auto& [name, src] : preRms_) {
                auto [average, rawAverage, data] = src.calcWeightedSquareRootsOfAverages(rmsCorrection_, fftWindow);
                calcedRMS_.emplace_back(name, average, rawAverage, data);
            }
        }

    public:
        /// Returns a copy of the RMS info recalculated for all known channels.
        RMSInfo getRms() override {
            std::lock_guard<std::mutex> lock(mutex_);
            calcRMSInfo();
            return calcedRMS_;
        }

        /// Returns a copy of the RMS info recalculated with the weighted calculation.
        RMSInfo getWeightedRms(const std::vector<float>& fftWindow) override {
            std::lock_guard<std::mutex> lock(mutex_);
            calcWeightedRMSInfo(fftWindow);
            return calcedRMS_;
        }

        /// Feeds `size` samples starting at `pcmPtr` into `dst` one by one via operator+=.
        template <typename Result_>
        void updatePreRms(Result_& dst, const std::int16_t* pcmPtr, std::size_t size) {
            const std::int16_t* pcmEnd = pcmPtr + size;
            while (pcmPtr != pcmEnd) {
                dst += *pcmPtr;
                ++pcmPtr;
            }
        }

        /// Adds the samples of one channel to its accumulator (created on first use)
        /// and marks the channel as fresh. When sampleRate / 16000 is greater than 1,
        /// the samples are averaged in groups of that size first.
        void pushAudioChannel(const YandexIO::ChannelData& channel) {
            std::lock_guard<std::mutex> lock(mutex_);
            // try_emplace does not construct anything when the channel is already known.
            PreRms& accumulator = preRms_.try_emplace(channel.name, true).first->second;
            accumulator.fresh = true;

            const std::size_t sampleCount = channel.data.size();
            if (sampleCount == 0) {
                // Nothing to add; also avoids taking the address of element 0 of an empty buffer.
                return;
            }

            const int compressRatio = channel.sampleRate / 16000;
            if (compressRatio > 1) {
                Compressor dst(accumulator, compressRatio);
                updatePreRms(dst, &channel.data[0], sampleCount);
            } else {
                updatePreRms(accumulator, &channel.data[0], sampleCount);
            }
        }

        void calcOnVqe(bool v) override {
            calcOnVqe_ = v;
        }

        /// Pushes every RAW channel (and every VQE channel when enabled through
        /// calcOnVqe()) except the one named RCU_CHANNEL_NAME, then drops the
        /// accumulators that were not refreshed.
        void pushAudioChannels(const YandexIO::ChannelsData& data) override {
            using ChannelType = YandexIO::ChannelData::Type;
            const bool calcOnVqe = calcOnVqe_.load();
            for (auto& channel : data) {
                const bool typeAccepted = channel.type == ChannelType::RAW || (channel.type == ChannelType::VQE && calcOnVqe);
                if (typeAccepted && channel.name != RCU_CHANNEL_NAME) {
                    pushAudioChannel(channel);
                }
            }
            removeOutdatedChannels();
        }

        /// Erases the accumulators whose `fresh` flag is not set and clears the
        /// flag on the remaining ones. Takes mutex_ itself, so it must not be
        /// called with the lock already held.
        void removeOutdatedChannels() {
            // preRms_ is read and modified under mutex_ everywhere else; erasing
            // without the lock would race with getRms() / pushAudioChannel().
            std::lock_guard<std::mutex> lock(mutex_);
            for (auto cur = preRms_.begin(); cur != preRms_.end();) {
                if (!cur->second.fresh) {
                    cur = preRms_.erase(cur);
                } else {
                    cur->second.fresh = false;
                    ++cur;
                }
            }
        }

        /// Sets the correction factor used by subsequent RMS calculations and logs
        /// the change when the value differs from the current one.
        void setRMSCorrection(float rmsCorrection) override {
            std::lock_guard<std::mutex> lock(mutex_);
            if (rmsCorrection_ != rmsCorrection) {
                YIO_LOG_INFO("RMScorrection factor changed from " << rmsCorrection_ << " to " << rmsCorrection);
                rmsCorrection_ = rmsCorrection;
            }
        }
    };
} // namespace

/// Factory: creates the file-local VoiceStatsImpl and hands it out through the
/// VoiceStats interface.
std::shared_ptr<VoiceStats> VoiceStats::create() {
    std::shared_ptr<VoiceStats> instance = std::make_shared<VoiceStatsImpl>();
    return instance;
}

/// Serialises RMS info: one {"name", "data"} object is appended per channel,
/// where "data" is the array of the channel's values. The result stays a
/// default-constructed Json::Value when `rms` has no channels.
Json::Value VoiceStats::rmsToJson(const RMSInfo& rms) {
    Json::Value channelsJson;
    for (const auto& rmsChannel : rms) {
        Json::Value samples(Json::arrayValue);
        for (const auto value : rmsChannel.data) {
            samples.append(value);
        }

        Json::Value entry;
        entry["name"] = rmsChannel.name;
        entry["data"] = std::move(samples);
        channelsJson.append(entry);
    }
    return channelsJson;
}

/* Problem: decrease the size of the RMS values sent to VINS.
   1. There is no way to separately specify the precision of floats in Json::Value.
   2. SpeechKit accepts JSON from us as a string, so it parses our floats into Json::Value and serializes them back.
   So we use a "fixed-point" approach: we send values 100 times bigger than the original ones. That is enough to compare two RMS values on the server side.
   P.S. A float as a string takes more bytes: "32.25" is 3 bytes more. For 10 values per channel across 4 mic channels that is 120 bytes more.
 */

namespace {

    std::tuple<Json::Value, float, float> rmsToJsonPackedChannels(const RMSInfo& rms) {
        Json::Value rval(Json::arrayValue);
        float averageRms = 0.0;
        float rawAvgRms = 0.0;
        for (auto& channel : rms) {
            Json::Value channelVal;
            channelVal["name"] = channel.name;
            Json::Value data(Json::arrayValue);
            auto squeezedChannel = squeezeChannel(channel.data);
            for (const auto floatVal : squeezedChannel) {
                data.append(int(roundf(floatVal * 100.0)));
            }
            channelVal["data"] = std::move(data);
            rval.append(channelVal);
            averageRms += channel.average;
            rawAvgRms += channel.rawAverage;
        }
        if (rms.size() && averageRms > 0.0) {
            averageRms /= float(rms.size());
            rawAvgRms /= float(rms.size());
        }
        return {rval, averageRms, rawAvgRms};
    }

} // namespace

/// Serialises RMS info in the packed format: "version" (always 1), "channels"
/// (packed per-channel array), "AvgRMS" and "RawAvgRMS" (averages across channels).
Json::Value VoiceStats::rmsToJsonPacked(const RMSInfo& rms) {
    const auto packed = rmsToJsonPackedChannels(rms);

    Json::Value result;
    result["version"] = 1;
    result["channels"] = std::get<0>(packed);
    result["AvgRMS"] = std::get<1>(packed);
    result["RawAvgRMS"] = std::get<2>(packed);
    return result;
}
