#include <algorithm>
#include <array>
#include <numeric>
#include <stdexcept>
#include <valarray>

#include <util/generic/yexception.h>
#include <util/generic/vector.h>
#include <util/generic/hash.h>
#include <util/generic/scope.h>
#include <util/stream/format.h>
#include <util/string/join.h>

#include <library/cpp/string_utils/base64/base64.h>
#include <library/cpp/json/json_writer.h>
#include <library/cpp/vowpalwabbit/vowpal_wabbit_predictor.h>
#include <library/cpp/http/server/response.h>
#include <mail/so/spamstop/tools/so-common/stats_consumer.h>
#include <mail/so/spamstop/tools/simple_shingler/tsr.h>

#include "TGetTopThemesHandler.h"

void TGetTopThemesHandler::Reply(THandleContext && handleContext, void* tsr) {

    TStatsList statsList;
    Y_DEFER {
        auto& executionContext = *reinterpret_cast<TProfiledContext::TExecutionContext*>(tsr);
        executionContext.Resource.Add(std::move(statsList));
    };

    const auto reqData = handleContext.ExtractData();
    auto g = Guard(statsList.emplace_back("gtt"));

    NLSA::TRequestData data;
    with_lock(statsList.emplace_back("json_parsing")) {
        data.FromJson(TStringBuf(static_cast<const char*>(reqData.Data()), reqData.Size()));
    }

    NLSA::TResponseData response;

    const TGodObject& godObject = *this->godObject;

    with_lock(statsList.emplace_back("old_themes")) {
        const auto& foundTraits = godObject.GetCoordinatesByIds(data.GetPrefixedWords(NLSA::NonSecureFields));

        if (!foundTraits.empty()) {
            const auto& t = CreateDocCoordinateFromWordCoordinates(foundTraits);
            const auto& docCoordinate = std::get<0>(t);

            const TChain& chain = godObject.Solve(docCoordinate);

            const float matchsPercent = float(foundTraits.size()) / (data.Size(NLSA::NonSecureFields));

            response.SetMatchsPercent(matchsPercent);
            {
                NLSA::TResponseData::TThemeChain themes;
                for (const auto& best : chain) {
                    themes.emplace_back(best.theme->description, static_cast<float>(best.weight));
                }
                response.SetThemes(std::move(themes));
            }
        }
    }
    with_lock(statsList.emplace_back("compls_distance")) {
        const auto& foundTraits = godObject.GetComplCoordinatesByIds(data.GetPrefixedWords(NLSA::NonSecureFields));

        if (!foundTraits.empty())
            response.SetDistancesToCompls( godObject.GetNearestCompl(foundTraits));
    }

    with_lock(statsList.emplace_back("tabs")) {
        std::array<float, 202 * 4> inputForCB{};
        {
            size_t i = 0;

            for (auto field : NLSA::NonSecureFields) {
                const auto& foundTraits = godObject.GetCoordinatesByIds(data.GetPrefixedWords(field));

                if (foundTraits.empty())
                    continue;

                const auto&[docCoordinate, docNorm, docHist, histNorm] = CreateDocCoordinateFromWordCoordinates(foundTraits);

                for (size_t k = 0; k < 100; k++)
                    inputForCB[i++] = docHist(0, k);
                inputForCB[i++] = histNorm;
                for (size_t k = 0; k < 100; k++)
                    inputForCB[i++] = docCoordinate(0, k);
                inputForCB[i++] = docNorm;
            }

        }

        for (const auto &[name, tabModelTraits] : godObject.GetTabTraitsByName()) {
            std::array<double, GetEnumItemsCount<NLSA::TTabTheme>()> predicted{};

            for (const auto &[tab, model] : tabModelTraits.modelsByName) {
                std::array<double, 1> p{};

                model.Calc(inputForCB, {}, p);

                predicted[static_cast<int>(tab)] = Sigmoid(p[0]);
            }

            NLSA::TResponseData::TThemeChain themes(Reserve(GetEnumItemsCount<NLSA::TTabTheme>()));
            for (const auto tab : GetEnumAllValues<NLSA::TTabTheme>())
                themes.emplace_back(ToString(tab), predicted[(int)tab]);

            response.SetTestingThemes(name, std::move(themes));
        }
    }
    with_lock(statsList.emplace_back("vw")) {
        for (auto&& [model, resolutionContext] : godObject.GetVWResolution(data)) {
            response.AddPrediction(model, std::move(resolutionContext.predictionsMap));
            response.SetRules(model, std::move(resolutionContext.rules));
        }
    }

    with_lock(statsList.emplace_back("dssm")) {
        for (auto&& [model, resolutionContext] : godObject.GetDSSMResolution(data)) {
            response.AddPrediction(model, std::move(resolutionContext.predictionsMap));
            response.SetRules(model, std::move(resolutionContext.rules));
        }
    }

    with_lock(statsList.emplace_back("ng_tab")) {
        response.AddRules(godObject.GetNGTabResolution(data));
    }

    THttpResponse(HTTP_OK).SetContent(response.ToJsonString()).OutTo(handleContext.output);
}

std::tuple<NLSA::TMatrix, double, NLSA::TMatrix, double> TGetTopThemesHandler::CreateDocCoordinateFromWordCoordinates(const TVector<const NLSA::TW2VTrait *> & coordinates) {

    NLSA::TMatrix hist(1, 100);

    auto it = coordinates.cbegin();

    NLSA::TMatrix sumCoordinate = (*it)->coordinate;
    hist(0, (*it)->cluster) ++;

    for(++it; it != coordinates.cend(); ++it) {
        sumCoordinate += (*it)->coordinate;
        hist(0, (*it)->cluster) ++;
    }

    const auto coordsNorm = sumCoordinate.norm();
    const auto histNorm = hist.norm();

    if(coordsNorm)
        sumCoordinate /= coordsNorm;
    if(histNorm)
        hist /= histNorm;


    return {std::move(sumCoordinate), coordsNorm, std::move(hist), histNorm};
}
