#include <library/cpp/config/config.h>
#include <library/cpp/logger/priority.h>
#include <kernel/matrixnet/mn_dynamic.h>
#include <catboost/libs/model/model.h>
#include <util/string/join.h>
#include <util/string/split.h>
#include <util/folder/path.h>
#include <util/stream/file.h>
#include <util/generic/adaptor.h>
#include <mail/so/libs/talkative_config/config.h>
#include <mail/so/spamstop/tools/so-common/so_log.h>
#include <mail/so/libs/syslog/so_log.h>
#include "models.h"

// Hash value of a model config. Only the slot name takes part, which keeps the hash
// consistent with operator== (that one compares slots only).
TModelConfig::operator size_t() const {
    const THash<TString> slotHasher{};
    const size_t slotHash = slotHasher(slot);
    return slotHash;
}

// Returns the optional "formula_id" read from the config (empty when the key was absent).
const TMaybe<ui64>& TModelConfig::GetFormulaId() const {
    return this->formulaId;
}

// Two model configs are equal iff they describe the same slot; every other field is ignored.
bool TModelConfig::operator==(const TModelConfig& config) const {
    const bool sameSlot = (config.slot == slot);
    return sameSlot;
}

// Debug representation: the labeled slot name followed by the comma-joined slave model ids.
// NOTE: LabeledOutput stringifies its argument, so the expression text `model.slot` is part of the output.
IOutputStream& operator<<(IOutputStream& stream, const TModelConfig & model) {
    stream << LabeledOutput(model.slot);
    stream << ", ids:[";
    stream << MakeRangeJoiner(",", model.slaveModelsIds);
    stream << ']';
    return stream;
}

/**
 * Builds a model description from one entry of the models config.
 *
 * Keys read unconditionally (through NTalkativeConfig): "slot", "slot_brief_name",
 * "binary_file", "features_file", "resource_id", "model_type".
 * Optional keys:
 *   - "threshold"  - defaults to 0 when absent;
 *   - "formula_id" - left empty when absent;
 *   - "vw_model_resource_id", "dssm_model_resource_id", "aux_models_ids" - each may be
 *     null, a single number or an array of numbers; all ids are merged into slaveModelsIds;
 *   - "one_vs_all" - parsed with TMapper::FromConfig;
 *   - "priority";
 *   - "available-experiments" - array of experiment ids consulted by AcceptExperimentId.
 *
 * Throws TWithBackTrace<yexception> when features_file / binary_file is empty, or when an
 * ids field (or an element of an ids array) is not numeric.
 */
TModelConfig::TModelConfig(const NConfig::TConfig& config) :
        slot(NTalkativeConfig::Get<TString>(config, "slot")),
        briefName(NTalkativeConfig::Get<TString>(config, "slot_brief_name")),
        binaryFile(NTalkativeConfig::Get<TString>(config, "binary_file")),
        featuresFile(NTalkativeConfig::Get<TString>(config, "features_file")),
        resourceId(NTalkativeConfig::As<ui64>(config, "resource_id")),
        threshold(config.Has("threshold") ? config["threshold"].As<double>() : 0),
        formulaId(config.Has("formula_id") ? MakeMaybe(NTalkativeConfig::As<ui64>(config, "formula_id")) : Nothing()),
        modelType(FromString(NTalkativeConfig::Get<TString>(config, "model_type")))
{
    if(!featuresFile) {
        ythrow TWithBackTrace<yexception>() << "features_file is not set: " << config;
    }
    if(!binaryFile) {
        // binaryFile is empty here, so print the whole config entry to make the error actionable.
        ythrow TWithBackTrace<yexception>() << "binary_file is not set: " << config;
    }

    for(const auto& idsFieldName: {"vw_model_resource_id", "dssm_model_resource_id", "aux_models_ids"}) {
        if (!config.Has(idsFieldName))
            continue;
        const auto& idsField = config[idsFieldName];

        if(idsField.IsNull()){
            continue;
        } else if (idsField.IsA<NConfig::TArray>()) {
            for (const auto& id : idsField.Get<NConfig::TArray>()) {
                // A malformed config is reported by exception (as in the scalar branch below)
                // rather than by aborting the whole process.
                if (!id.IsNumeric()) {
                    ythrow TWithBackTrace<yexception>() << "ids field " << idsFieldName << " contains non-numeric id: " << id;
                }
                slaveModelsIds.emplace(id.As<ui64>());
            }
        } else if(idsField.IsNumeric()) {
            slaveModelsIds.emplace(idsField.As<ui64>());
        } else {
            ythrow TWithBackTrace<yexception>() << "ids field " << idsFieldName << " isn't array or num";
        }
    }

    // Look the dictionary view up once for all the optional sections below.
    const auto& configDict = NTalkativeConfig::Get<NConfig::TDict>(config);

    if(const auto* oneVsAll = MapFindPtr(configDict, "one_vs_all")) {
        OneVsAllMapper.ConstructInPlace(TModelConfig::TMapper::FromConfig(*oneVsAll));
    }

    if(const auto* priority = MapFindPtr(configDict, "priority")) {
        Priority = NTalkativeConfig::As<size_t>(*priority);
    }

    if(const auto* availableExperiments = MapFindPtr(configDict, "available-experiments")) {
        const NConfig::TArray& experiments = NTalkativeConfig::Get<NConfig::TArray>(*availableExperiments);

        for(const NConfig::TConfig& experiment : experiments) {
            AvailableExperiments.emplace(NTalkativeConfig::As<uint64_t>(experiment));
        }
    }
}

bool TModelConfig::AcceptExperimentId(const TExpBoxes& expBoxes) const {
    return AvailableExperiments.empty() ||
           expBoxes.Experiments.empty() ||
           AnyOf(expBoxes.Experiments, [&](const TExpBoxes::TExperiment& experiment){
               return AvailableExperiments.contains(experiment.ExperimentId);
           });
}

// Turns raw model outputs into (feature name, value) pairs using the configured
// index -> name mapping. Indices outside `values` are logged at TLOG_ERR and skipped.
TVector<std::pair<TString, double>> TModelConfig::TMapper::Map(const TLog& logger, const TVector<double>& values) const {
    TVector<std::pair<TString, double>> result;
    result.reserve(values.size());
    for (const auto& [outputIndex, featureName] : IndexToFeatureMap) {
        if (Y_LIKELY(outputIndex < values.size())) {
            result.emplace_back(featureName, values[outputIndex]);
        } else {
            logger << (TLOG_ERR) << outputIndex << " is not valid for features with size " << values.size();
        }
    }
    return result;
}

// Builds a mapper from a dict whose keys are output indices (as strings) and whose
// values are feature names. Entries are appended in the dict's iteration order.
TModelConfig::TMapper TModelConfig::TMapper::FromConfig(const NConfig::TConfig& config) {
    TMapper result;
    const auto& entries = NTalkativeConfig::Get<NConfig::TDict>(config);
    for (const auto& [key, value] : entries) {
        auto featureName = NTalkativeConfig::Get<TString>(value);
        result.IndexToFeatureMap.emplace_back(FromString(key), std::move(featureName));
    }
    return result;
}

// Dumps the dictionary as one "feature:index" line per entry, in hash-map iteration order.
IOutputStream& operator<<(IOutputStream& stream, const IModelApplier::TDict& dict) {
    for (const auto& entry : dict.featuresIndexesMap) {
        stream << entry.first << ':' << entry.second << '\n';
    }
    return stream;
}


// Builds a dense feature vector of size maxIndex + 1 (zero-filled). Each dictionary
// feature gets its value from the first map in `featuresMaps` that contains it;
// features absent from every map stay 0.
TVector<float> IModelApplier::TDict::MakeFeatures(std::initializer_list<std::reference_wrapper<const TFeaturesMap>> featuresMaps) const {
    TVector<float> dense(maxIndex + 1, 0.0f);
    for (const auto& [name, position] : featuresIndexesMap) {
        for (const TFeaturesMap& source : featuresMaps) {
            if (const float* found = MapFindPtr(source, name)) {
                dense[position] = *found;
                break; // earlier maps take precedence
            }
        }
    }
    return dense;
}

// Applier for Matrixnet (NMatrixnet::TMnSseDynamic) models.
// Apply returns a one-element vector holding the model's CalcRelev result.
class TMatrixnetApplier : public IModelApplier{
public:
    TVector<double> Apply(const TVector<float>& features) const final {
        return {model.CalcRelev(features)};
    }

    // Deliberately not noexcept: the member initializer copies `model`, and that copy may
    // throw (e.g. std::bad_alloc). Under noexcept it would call std::terminate instead of
    // propagating to the caller's error handling.
    explicit TMatrixnetApplier(TDict dict, const NMatrixnet::TMnSseDynamic& model)
            : IModelApplier(std::move(dict)), model(model) {}

private:
    // Owned copy of the model, independent of the caller's instance.
    const NMatrixnet::TMnSseDynamic model;
};


// Applier for CatBoost (TFullModel) models: returns one value per output dimension
// reported by the model, computed with TFullModel::CalcFlat.
class TCatboostApplier : public IModelApplier{
public:
    explicit TCatboostApplier(TDict dict, TFullModel model) noexcept
            : IModelApplier(std::move(dict))
            , model(std::move(model)) {}

    TVector<double> Apply(const TVector<float>& features) const final {
        const auto dimensions = model.GetDimensionsCount();
        TVector<double> scores(dimensions, 0.);
        model.CalcFlat(features, scores);
        return scores;
    }

private:
    const TFullModel model;
};

// Parses a Matrixnet feature dictionary: one "<index>\t<name>\t<flag>" row per line.
// Only rows whose flag equals 1 are kept. Throws on a repeated feature name or when no
// feature is enabled. The returned maxIndex is the largest index among kept rows.
IModelApplier::TDict LoadMatrixnetDict(IInputStream& dictStream) {
    THashMap<TString, size_t> indexByName;
    size_t highestIndex = 0;

    for (TString line; dictStream.ReadLine(line);) {
        size_t index = 0;
        TStringBuf featureName;
        size_t onOff = 0;
        Split(line, "\t", index, featureName, onOff);

        if (onOff != 1) {
            continue;
        }

        const bool inserted = indexByName.emplace(featureName, index).second;
        if (!inserted) {
            ythrow TWithBackTrace<yexception>() << "repeated feature " << featureName << " in " << line;
        }
        highestIndex = Max(index, highestIndex);
    }

    if (indexByName.empty()) {
        ythrow TWithBackTrace<yexception>() << "empty features map";
    }

    return {highestIndex, std::move(indexByName)};
}

/**
 * Parses a CatBoost column-description style file into a feature dictionary.
 *
 * The first line is skipped (it is expected to describe the label column). Each further
 * line must be "<column index>\t<type>\t<feature name>". Categorical features are rejected.
 * Column indices are shifted down by one to obtain the flat float-feature index
 * (presumably because column 0 holds the label — confirm against the .cd files in use).
 *
 * @param numFloatFeatures float feature count of the loaded model; used as the minimum maxIndex.
 * @return dictionary whose maxIndex is at least numFloatFeatures and at least every listed
 *         index, so a vector of size maxIndex + 1 can hold every feature.
 * Throws on a categorical feature, a column index of 0, a repeated feature, or an empty map.
 */
IModelApplier::TDict LoadCatboostDict(IInputStream& stream, size_t numFloatFeatures) {
    THashMap<TString, size_t> localFeaturesIndexesMap;
    size_t localMaxIndex = numFloatFeatures;
    TString line;
    stream.ReadLine(line); // skip Label
    while(stream.ReadLine(line)) {
        TStringBuf type, featureName;
        size_t index{};
        Split(line, '\t', index, type, featureName);

        if(type == "Categ")
            ythrow TWithBackTrace<yexception>() << "categ feature in " << line;

        // index is unsigned: shifting 0 down would wrap to SIZE_MAX and later index out of bounds.
        if(index == 0)
            ythrow TWithBackTrace<yexception>() << "feature column index must be positive in " << line;

        index -= 1;

        if(!localFeaturesIndexesMap.emplace(featureName, index).second)
            ythrow TWithBackTrace<yexception>() << "repeated feature " << featureName << " in " << line;

        // Feature vectors are sized maxIndex + 1, so maxIndex must cover every listed index,
        // even when the file lists more columns than the model uses.
        if(index > localMaxIndex)
            localMaxIndex = index;
    }

    if(localFeaturesIndexesMap.empty())
        ythrow TWithBackTrace<yexception>() << "empty features map";
    return {localMaxIndex, std::move(localFeaturesIndexesMap)};
}

/**
 * Loads every model listed in `config`, which must be an array of TModelConfig entries.
 *
 * For each entry, binary_file and features_file are resolved relative to `workingDir` and
 * must be existing regular files. Matrixnet models are paired with a dictionary from
 * LoadMatrixnetDict, CatBoost models with one from LoadCatboostDict. Any failure while
 * loading an entry (including an unsupported model type) is rethrown with the entry's
 * slot info attached.
 *
 * @return (config, applier) pairs, stably sorted by ascending Priority.
 */
TAppliersMap LoadAppliers(const TFsPath& workingDir, const NConfig::TConfig& config) {
    workingDir.CheckExists();
    Y_VERIFY(workingDir.IsDirectory(), "%s", workingDir.c_str());
    Y_VERIFY(config.IsA<NConfig::TArray>());

    TAppliersMap applyers;
    for(const auto& c : config.Get<NConfig::TArray>()) {
        TModelConfig modelConfig(c);

        try {
            const auto fullPath = workingDir / modelConfig.binaryFile;
            const auto pathToDict = workingDir / modelConfig.featuresFile;

            fullPath.CheckExists();
            if(!fullPath.IsFile()) {
                ythrow TWithBackTrace<yexception>() << "path is not file  " << fullPath << ' ' << c;
            }
            pathToDict.CheckExists();
            if(!pathToDict.IsFile()) {
                ythrow TWithBackTrace<yexception>() << "path is not file  " << pathToDict << ' ' << c;
            }

            TMappedFileInput dictInput(pathToDict);

            switch (modelConfig.modelType) {
                case TModelConfig::MATRIXNET: {
                    NMatrixnet::TMnSseDynamic model;
                    {
                        TMappedFileInput fileInput(fullPath);
                        model.Load(&fileInput);
                    }

                    auto dict = LoadMatrixnetDict(dictInput);

                    applyers.emplace_back(std::move(modelConfig), MakeSimpleShared<TMatrixnetApplier>(std::move(dict), model));
                    break;
                }
                case TModelConfig::CATBOOST: {
                    auto model = ReadModel(fullPath);
                    auto dict = LoadCatboostDict(dictInput, model.GetNumFloatFeatures());
                    applyers.emplace_back(std::move(modelConfig), MakeSimpleShared<TCatboostApplier>(std::move(dict), std::move(model)));
                    break;
                }
                default:
                    // Never skip a configured model silently: report it as a load error.
                    ythrow TWithBackTrace<yexception>() << "unsupported model type " << static_cast<int>(modelConfig.modelType) << ' ' << c;
            }
        } catch (...) {
            ythrow TWithBackTrace<yexception>() << "error while loading models in slot " << modelConfig << ':' << CurrentExceptionMessageWithBt();
        }
    }

    // Ascending by Priority; StableSort keeps config order among equal priorities.
    StableSort(applyers, [](const auto& p1, const auto& p2){
        return p2.first.Priority > p1.first.Priority;
    });

    return applyers;
}



