#include "god_object.h"

#include <mail/so/spamstop/tools/lsa/data/clusterify.h>
#include <mail/so/libs/syslog/so_log.h>
#include <library/cpp/vowpalwabbit/vowpal_wabbit_predictor.h>
#include <util/generic/bt_exception.h>
#include <util/string/join.h>

#include <utility>
#include <library/cpp/accurate_accumulate/accurate_accumulate.h>
#include <util/generic/iterator_range.h>
#include <util/generic/serialized_enum.h>

#include <catboost/libs/eval_result/eval_helpers.h>
#include <library/cpp/config/config.h>
#include <mail/so/libs/talkative_config/config.h>
#include <mlp/mail/tabs/tabs_model.h>
#include <util/string/strip.h>

#include "mlp_solver.h"
#include "greedy_solver.h"

/// Builds a feature-name -> prediction map from a raw prediction vector using
/// the configured index -> feature-name mapping.
/// @param predictions raw model outputs, addressed by mapper index
/// @throws yexception (with backtrace) if a mapped index is past the end of @p predictions
NLSA::TPredictionsMap TFeaturesMapper::MakePredictionsMap(const TVector<float>& predictions) const {
    NLSA::TPredictionsMap result;
    for (const auto& entry : map) {
        const auto& position = entry.first;
        if (position < predictions.size()) {
            result.emplace(entry.second, predictions[position]);
            continue;
        }
        ythrow TWithBackTrace<yexception>() << "there are " << predictions.size()
                                            << " predictions, but index in mapper is " << position;
    }
    return result;
}

/// Parses a dict of "<index>": "<feature name>" pairs into a mapper.
/// @param config   dict config; keys are parsed as size_t indices
/// @param maxIndex largest index accepted (inclusive)
/// @throws yexception (with backtrace) if any index exceeds @p maxIndex
TFeaturesMapper TFeaturesMapper::FromConfig(const NConfig::TConfig& config, size_t maxIndex) {
    TFeaturesMapper result;
    const auto& entries = NTalkativeConfig::Get<NConfig::TDict>(config);
    for (const auto& [key, value] : entries) {
        const size_t position = FromString(key);
        if (position <= maxIndex) {
            result.map.emplace(position, NTalkativeConfig::Get<TString>(value));
            continue;
        }
        ythrow TWithBackTrace<yexception>() << "index=" << position << " maxIndex=" << maxIndex << " config=" << config;
    }
    return result;
}


/// Builds a range projector from a config of the form
/// {"range": [lower, upper, ...], "feature": "<name>"}.
/// Only the first two elements of "range" are used.
/// @throws yexception (with backtrace) if "range" has fewer than two elements
TRangeProjector TRangeProjector::FromConfig(const NConfig::TConfig& config) {
    const auto& bounds = NTalkativeConfig::Get<NConfig::TArray>(config, "range");

    if (bounds.size() >= 2) {
        TRangeProjector result;
        result.LowerBound = NTalkativeConfig::As<float>(bounds[0]);
        result.UpperBound = NTalkativeConfig::As<float>(bounds[1]);
        result.Feature = NTalkativeConfig::Get<TString>(config, "feature");
        return result;
    }

    ythrow TWithBackTrace<yexception>() << "wrong format for range " << config;
}

/// Returns Feature when @p value lies in the half-open range [LowerBound, UpperBound),
/// otherwise an empty string (NaN never matches).
const TString& TRangeProjector::Apply(float value) const {
    if (value >= LowerBound && value < UpperBound)
        return Feature;
    return Default<TString>();
}


/// Runs every configured projector against its prediction and collects the
/// non-empty feature names the projectors produce.
/// @param predictions raw model outputs, addressed by mapper index
/// @throws yexception (with backtrace) if a mapped index is past the end of @p predictions
TVector<TString> TProjectorsMapper::MakeFeatures(const TVector<float>& predictions) const {
    TVector<TString> result;
    for (const auto& entry : map) {
        const auto position = entry.first;
        if (position >= predictions.size()) {
            ythrow TWithBackTrace<yexception>() << "there are " << predictions.size()
                                                << " predictions, but index in mapper is " << position;
        }

        const float value = predictions[position];
        for (const auto& projector : entry.second) {
            const auto& feature = projector->Apply(value);
            // Projectors signal "no feature" with an empty result.
            if (feature)
                result.emplace_back(feature);
        }
    }
    return result;
}

/// Parses a dict of "<index>": [<range projector config>, ...] into a mapper.
/// @param config   dict config; keys are parsed as size_t indices
/// @param maxIndex largest index accepted (inclusive)
/// @throws yexception (with backtrace) if any index exceeds @p maxIndex
TProjectorsMapper TProjectorsMapper::FromConfig(const NConfig::TConfig& config, size_t maxIndex) {
    TProjectorsMapper result;
    for (const auto& [key, value] : NTalkativeConfig::Get<NConfig::TDict>(config)) {
        const size_t position = FromString(key);
        if (position > maxIndex) {
            ythrow TWithBackTrace<yexception>() << "index=" << position << " maxIndex=" << maxIndex << " config=" << config;
        }

        auto& bucket = result.map[position];
        const auto& projectorConfigs = NTalkativeConfig::Get<NConfig::TArray>(value);
        for (const auto& projectorConfig : projectorConfigs)
            bucket.emplace_back(MakeHolder<TRangeProjector>(TRangeProjector::FromConfig(projectorConfig)));
    }
    return result;
}

/// Loads the underlying tabs model from @p root / @p config and reads the
/// "features" dict, which maps a tab (parsed from the key) to a rule name.
TMailTabModelWithInfo::TMailTabModelWithInfo(const TFsPath& root, const NConfig::TConfig& config)
    : NMailTabs::TMailTabsModel(config, root)
{
    const auto& features = NTalkativeConfig::Get<NConfig::TDict>(config, "features");
    for (const auto& entry : features)
        mapper.emplace(FromString(entry.first), NTalkativeConfig::Get<TString>(entry.second));
}

/// Identifies a VW model by its sandbox resource id and a decision threshold.
/// Ordering (operator<) compares sandboxId only.
struct TVWModelInfo{
    /// File name of the form "<sandboxId>_<threshold>".
    TString GetFilename() const {
        return sandboxId + '_' + ToString(threshold);
    }
    bool operator < (const TVWModelInfo & i) const {
        return sandboxId < i.sandboxId;
    }
    friend IOutputStream & operator << (IOutputStream & stream, const TVWModelInfo & info) {
        return stream << info.sandboxId << ' ' << info.threshold;
    }

    TVWModelInfo() = default;
    TVWModelInfo(TVWModelInfo&&) noexcept = default;
    TVWModelInfo(const TVWModelInfo&) = default;
    // The user-declared move constructor deletes the implicit copy assignment,
    // and the user-declared copy/move constructors suppress the implicit move
    // assignment. Declare both explicitly so the type is assignable (needed by
    // swap / std::sort, which operator< above suggests is intended).
    TVWModelInfo& operator=(TVWModelInfo&&) noexcept = default;
    TVWModelInfo& operator=(const TVWModelInfo&) = default;
    TVWModelInfo(TString sandboxId, double threshold) : sandboxId(std::move(sandboxId)), threshold(threshold) {}

    TString sandboxId;
    double threshold{};
};

/// Looks up each token in the w2v dictionary and returns pointers to the
/// traits of the tokens that were found, in input order; unknown tokens are skipped.
/// The pointers refer into w2vDictionary and stay valid while it is unchanged.
TVector<const NLSA::TW2VTrait *> TGodObject::GetCoordinatesByIds(const TVector<TString>& tokens) const {
    TVector<const NLSA::TW2VTrait *> found;
    found.reserve(tokens.size());

    for (const auto& token : tokens) {
        const auto it = w2vDictionary.find(token);
        if (it == w2vDictionary.cend())
            continue;
        found.push_back(&it->second);
    }

    return found;
}

/// Looks up each token in the w2v "compl" dictionary and returns pointers to the
/// traits of the tokens that were found, in input order; unknown tokens are skipped.
/// The pointers refer into w2vComplsDictionary and stay valid while it is unchanged.
TVector<const NLSA::TW2VViewTrait *> TGodObject::GetComplCoordinatesByIds(const TVector<TString>& tokens) const {
    TVector<const NLSA::TW2VViewTrait *> found;
    found.reserve(tokens.size());

    for (const auto& token : tokens) {
        const auto it = w2vComplsDictionary.find(token);
        if (it == w2vComplsDictionary.cend())
            continue;
        found.push_back(&it->second);
    }

    return found;
}

/// For each trait, queries the HNSW "compl" index for its nearest neighbour and
/// returns the smallest of those distances (at most five), sorted ascending.
/// Returns an empty vector when no HNSW index was loaded ("hnsw_compl" absent).
/// Traits for which the index yields no neighbour are skipped.
TVector<NLSA::TDistance::TResult> TGodObject::GetNearestCompl(const TVector<const NLSA::TW2VViewTrait *>& traits) const {
    if(!hnswComplIndex)
        return {};

    // Number of closest distances reported to the caller.
    constexpr size_t maxDistances = 5;

    TVector<NLSA::TDistance::TResult> distances(Reserve(traits.size()));
    for(const auto* trait: traits) {
        const auto neighbors = hnswComplIndex->GetNearestNeighbors(trait->coordinate, 1, 2);
        // The index may return nothing (e.g. when it is empty); front() would be UB then.
        if(neighbors.empty())
            continue;
        distances.emplace_back(neighbors.front().Dist);
    }

    if(distances.size() > maxDistances) {
        const auto e = distances.begin() + maxDistances;
        std::partial_sort(distances.begin(), e, distances.end());
        distances.erase(e, distances.end());
    } else {
        std::sort(distances.begin(), distances.end());
    }

    return distances;
}

/// Scores the request with every requested VW model that is loaded.
/// For each model, the predictions vector is [sigmoid(total score), sigmoid(cluster_0), ...],
/// where the clusters come from ClusterifyValues over the per-hash weights.
/// Requested ids without a loaded VW model are silently skipped.
/// @returns resolution context (named predictions + projected features) keyed by model id
THashMap<TString, NLSA::TResolutionContext> TGodObject::GetVWResolution(const NLSA::TRequestData &requestData) const {
    THashMap<TString, NLSA::TResolutionContext> result;
    result.reserve(vwModelsById.size());

    for (const auto& modelId : requestData.GetModelsIds()) {
        const auto found = vwModelsById.find(modelId);
        if (Y_UNLIKELY(found == vwModelsById.cend()))
            continue;

        const auto& vwModel = found->second;
        const TVowpalWabbitPredictor predictor(vwModel);

        // Collect the model weight of every hashed token/ngram over all configured headers.
        TVector<double> weights(Reserve(requestData.Size(NLSA::NonSecureFields)));
        for (const auto& header : vwModel.GetHeaders()) {
            const TVector<TString>& tokens = header.Apply(requestData.Get(header.GetField()));
            TVector<ui32> hashes(Reserve(tokens.size() * vwModel.GetNgram()));
            NVowpalWabbit::THashCalcer::CalcHashes(header.GetRepresentation(), tokens, vwModel.GetNgram(), hashes);

            for (const auto hash : hashes)
                weights.emplace_back(vwModel[hash]);
        }

        const auto score = predictor.GetConstPrediction() + FastAccumulate(weights);
        TVector<float> predictions{float(Sigmoid(score))};
        for (const auto cluster : ClusterifyValues<clustersNum>(weights.cbegin(), weights.cend()))
            predictions.emplace_back(Sigmoid(cluster));

        result.emplace(
            modelId,
            NLSA::TResolutionContext{
                vwModel.mapper.MakePredictionsMap(predictions),
                vwModel.projectorsMapper.MakeFeatures(predictions)
            }
        );
    }

    return result;
}


/// Applies every requested DSSM model that is loaded to the request.
/// For each model header, the header's tokens are joined with spaces into one
/// input string annotated with the header's representation.
/// Requested ids without a loaded DSSM model are silently skipped.
/// @returns resolution context (named predictions + projected features) keyed by model id
THashMap<TString, NLSA::TResolutionContext> TGodObject::GetDSSMResolution(const NLSA::TRequestData &requestData) const {
    THashMap<TString, NLSA::TResolutionContext> predictionsByModel;
    predictionsByModel.reserve(dssmModelsById.size());

    for(const auto & id : requestData.GetModelsIds()) {
        auto it = dssmModelsById.find(id);
        if(Y_UNLIKELY(dssmModelsById.cend() == it))
            continue;

        const auto & model = it->second;

        TVector<TString> dssmInputs(Reserve(model.GetHeaders().size()));
        TVector<TString> annotations(Reserve(model.GetHeaders().size()));

        for (const auto& header : model.GetHeaders()) {
            dssmInputs.emplace_back(JoinSeq(" ", header.Apply(requestData.Get(header.GetField()))));
            annotations.emplace_back(header.GetRepresentation());
        }

        TVector<float> dssmOutput;

        // The inputs are not used after building the sample, so hand them over
        // instead of copying both vectors for every request.
        model.Apply(
            MakeAtomicShared<NNeuralNetApplier::TSample>(
                    std::move(annotations),
                    std::move(dssmInputs)),
            model.GetOutputs(),
            dssmOutput
        );

        predictionsByModel.emplace(
            id,
            NLSA::TResolutionContext{
                model.mapper.MakePredictionsMap(dssmOutput),
                model.projectorsMapper.MakeFeatures(dssmOutput)
            }
        );
    }

    return predictionsByModel;
}

/// Runs every loaded new-generation tab model on the request and returns, per
/// model, the rule name configured for the resolved tab.
/// NOTE(review): model.mapper.at() presumably throws when the resolved tab has
/// no configured rule — confirm against the mapper type in the header.
TVector<TString> TGodObject::GetNGTabResolution(const NLSA::TRequestData &requestData) const {
    TVector<TString> rules(Reserve(ngTabModels.size()));

    for(const auto & idAndModel : ngTabModels) {
        const auto & model = idAndModel.second;
        const auto resolution = model.GetMailTabsResolution(requestData);
        rules.emplace_back(model.mapper.at(resolution));
    }

    return rules;
}

/// Returns, for every loaded VW model, its prediction for @p word placed in the
/// "body" namespace (ngram = 2), keyed by model id.
THashMap<TString, double> TGodObject::GetVWBodyWeight(const TString & word) const {
    THashMap<TString, double> weights;
    for (const auto& entry : vwModelsById) {
        const auto& modelId = entry.first;
        const auto& model = entry.second;
        weights[modelId] = TVowpalWabbitPredictor(model).Predict("body", TVector<TString>{word}, 2);
    }
    return weights;
}

/// Delegates to the configured solver, using the MLP optimizer loaded from "mlp".
TChain TGodObject::Solve(const NLSA::TMatrix & docCoordinate) const {
    auto& mlpOptimizer = *optimizer;
    return solver->Solve(mlpOptimizer, docCoordinate);
}

/// Writes @p config to @p stream as JSON; used to embed configs in error messages.
IOutputStream& operator<<(IOutputStream& stream, const NConfig::TConfig& config) {
    config.ToJson(stream);
    return stream;
}

/// Loads every model and dictionary referenced by @p config.
///
/// Keys read unconditionally: "mlp", "themes", "vw_configs", "dssm_configs",
/// "w2v", "cb_model_for_checkform", "tab_models".
/// Keys read only when present: "w2v_compl", "hnsw_compl", "ng_tab_config".
/// Model binaries listed in "vw_configs"/"dssm_configs" files are resolved
/// relative to the directory of the listing file; a missing binary fails Y_VERIFY.
///
/// @throws yexception (with backtrace) on malformed entries, e.g. an "ngram"
///         value that does not fit into ui8, or a failing new tab model.
TGodObject::TGodObject(const NConfig::TConfig& config) {
    const auto & map = NTalkativeConfig::Get<NConfig::TDict>(config);
    {
        const TFsPath pathToML = NTalkativeConfig::As<TFsPath>(config, "mlp");
        TIFStream f(pathToML);

        NLSA::TMlByIds ml;
        ::Load(&f, ml);
        optimizer = MakeHolder<TMLPSolver>(ml);
    }
    {
        const TFsPath path = NTalkativeConfig::As<TFsPath>(config, "themes");
        TIFStream f(path);
        ::Load(&f, themes);
    }
    {
        solver = MakeHolder<TGreedySolver>(themes);
    }
    {
        for(const auto& rawPath : NTalkativeConfig::Get<NConfig::TArray>(config, "vw_configs")) {
            const TFsPath path(NTalkativeConfig::As<TFsPath>(rawPath));

            const NConfig::TConfig conf = [&path]{
                TIFStream input(path);
                return NConfig::TConfig::FromJson(input);
            }();

            const auto& workdir = TFsPath(path.Dirname());
            for(const auto& vwConf : NTalkativeConfig::Get<NConfig::TArray>(conf)) {
                const auto resourceId = NTalkativeConfig::As<ui64>(vwConf, "resource_id");

                const TFsPath pathToModel = workdir / NTalkativeConfig::Get<TString>(vwConf, "binary_file");
                Y_VERIFY(pathToModel.Exists(), "%s", pathToModel.GetPath().c_str());


                TVector<NLSA::TProjectorField> headers;
                if(auto it = NTalkativeConfig::Find<NConfig::TArray>(vwConf, "headers")) {
                    for(const auto& header : *it) {
                        headers.emplace_back(NLSA::TProjectorField::Parse(StripString(NTalkativeConfig::Get<TString>(header))));
                    }
                } else {
                    headers = {
                        NLSA::TProjectorField(NLSA::TField::Body),
                        NLSA::TProjectorField(NLSA::TField::Subject),
                        NLSA::TProjectorField(NLSA::TField::Fromname),
                        NLSA::TProjectorField(NLSA::TField::Fromaddr),};
                }

                ui8 ngram = 2;
                if(auto it = NTalkativeConfig::Find(vwConf, "ngram")) {
                    const size_t configuredNgram = NTalkativeConfig::As<size_t>(*it);
                    ngram = static_cast<ui8>(configuredNgram);
                    // ngram is stored as ui8: reject values that would otherwise be silently truncated.
                    if(static_cast<size_t>(ngram) != configuredNgram)
                        ythrow TWithBackTrace<yexception>() << "ngram=" << configuredNgram << " does not fit into ui8, config=" << vwConf;
                }

                vwModelsById.emplace(
                    std::piecewise_construct,
                    std::forward_as_tuple(ToString(resourceId)),
                    std::forward_as_tuple(
                        TBlob::PrechargedFromFile(pathToModel),
                        TFeaturesMapper::FromConfig(NTalkativeConfig::Get(vwConf, "features"), 3),
                        vwConf.Has("projectors") ? TProjectorsMapper::FromConfig(NTalkativeConfig::Get(vwConf, "projectors"), 3) : TProjectorsMapper{},
                        std::move(headers),
                        ngram));
            }
        }
    }
    {
        for(const auto& rawPath : NTalkativeConfig::Get<NConfig::TArray>(config, "dssm_configs")) {
            const TFsPath path(NTalkativeConfig::As<TFsPath>(rawPath));

            const NConfig::TConfig conf = [&path]{
                TIFStream input(path);
                return NConfig::TConfig::FromJson(input);
            }();

            const auto& workdir = TFsPath(path.Dirname());
            for(const auto& dssmConf : NTalkativeConfig::Get<NConfig::TArray>(conf)) {
                const auto resourceId = NTalkativeConfig::As<ui64>(dssmConf, "resource_id");
                TVector<NLSA::TProjectorField> headers;
                TVector<TString> outputs;

                for (const auto& header : NTalkativeConfig::Get<NConfig::TArray>(dssmConf, "headers")) {
                    headers.emplace_back(NLSA::TProjectorField::Parse(StripString(NTalkativeConfig::Get<TString>(header))));
                }

                for (const auto& output : NTalkativeConfig::Get<NConfig::TArray>(dssmConf, "outputs")) {
                    outputs.emplace_back(NTalkativeConfig::As<TString>(output));
                }

                const TFsPath pathToModel = workdir / NTalkativeConfig::Get<TString>(dssmConf, "binary_file");
                Y_VERIFY(pathToModel.Exists(), "%s", pathToModel.GetPath().c_str());

                dssmModelsById.emplace(
                        ToString(resourceId),
                        TDSSMModelWithInfo(
                                TBlob::PrechargedFromFile(pathToModel),
                                std::move(headers),
                                std::move(outputs),
                                TFeaturesMapper::FromConfig(NTalkativeConfig::Get(dssmConf, "features")),
                                dssmConf.Has("projectors") ? TProjectorsMapper::FromConfig(NTalkativeConfig::Get(dssmConf, "projectors")) : TProjectorsMapper{}
                        )
                );
            }
        }
    }

    {
        const TFsPath path = NTalkativeConfig::As<TFsPath>(config, "w2v");
        TIFStream f(path);
        ::Load(&f, w2vDictionary);
    }

    {
        const TFsPath & path = NTalkativeConfig::As<TFsPath>(config, "cb_model_for_checkform");
        catboostModelForCheckForm = ReadModel(path.GetPath());
    }
    {
        for(const auto& [name, traitsConf]: NTalkativeConfig::Get<NConfig::TDict>(config, "tab_models")) {
            TTabModels tabModels;
            for(const auto& [tabName, pathConf]: NTalkativeConfig::Get<NConfig::TDict>(traitsConf, "models")) {
                const TFsPath path = NTalkativeConfig::As<TFsPath>(pathConf);
                tabModels.emplace(FromString(tabName), ReadModel(path));

                Cerr << "loaded " << name << " " << tabName << " from " << path << Endl;
            }
            tabTraitsByName.emplace(name, TModelTraits{std::move(tabModels)});
        }
    }

    {
        if(config.Has("w2v_compl")) {
            const TFsPath path = NTalkativeConfig::As<TFsPath>(config, "w2v_compl");
            TIFStream f(path);
            ::Load(&f, w2vComplsDictionary);

            NLSA::NormalizeDict(w2vComplsDictionary);
        }
    }

    {
        if(config.Has("hnsw_compl")) {
            const TFsPath path = NTalkativeConfig::As<TFsPath>(config, "hnsw_compl");
            NLSA::THnswContext context;

            TIFStream f(path);
            ::Load(&f, context);

            hnswComplIndex = MakeHolder<NLSA::THnswIndex>(TBlob::FromBuffer(context.indexData), context.storage);
        }
    }

    if(auto it = MapFindPtr(map, "ng_tab_config"))  {
        for(const auto& rawPath: NTalkativeConfig::Get<NConfig::TArray>(*it)) {
            const TFsPath path = NTalkativeConfig::As<TFsPath>(rawPath);
            const NConfig::TConfig conf = [&path] {
                TIFStream input(path);
                return NConfig::TConfig::FromJson(input);
            }();

            const auto workingDir = path.Parent();

            for(const auto& [id, localConfig]: NTalkativeConfig::Get<NConfig::TDict>(conf)) {
                try {
                    ngTabModels.emplace(std::piecewise_construct,
                                        std::forward_as_tuple(id),
                                        std::forward_as_tuple(workingDir, localConfig));
                } catch (...) {
                    // Re-throw with the model id and its config so the failing entry is identifiable.
                    ythrow TWithBackTrace<yexception>() << "error while loading new tab model " << id << " from config:" << localConfig << ":" << CurrentExceptionMessageWithBt();
                }
            }
        }
    }
}

// Defaulted out of line in this .cpp. NOTE(review): presumably so that the
// THolder members (optimizer, solver, hnswComplIndex) are destroyed where their
// types are complete (mlp_solver.h / greedy_solver.h are included here) — confirm.
TGodObject::~TGodObject() = default;
