#include "transfer.h"

#include <mail/so/spamstop/tools/so-common/kfunc.h>

#include <util/generic/buffer.h>
#include <util/stream/buffer.h>

#include <library/cpp/json/json_writer.h>
#include <library/cpp/json/json_reader.h>
#include <util/generic/serialized_enum.h>
#include <util/generic/guid.h>
#include <util/generic/bt_exception.h>
#include <util/string/join.h>

namespace NLSA {
    // Two requests compare equal only when their field values, model ids and
    // control sums are all equal.
    bool TRequestData::operator == (const TRequestData& data) const {
        if (!(fields == data.fields)) {
            return false;
        }
        if (!(modelsIds == data.modelsIds)) {
            return false;
        }
        return controlSums == data.controlSums;
    }

    size_t TRequestData::Size(std::initializer_list<TField> fieldsIds) const {
        return std::reduce(fieldsIds.begin(), fieldsIds.end(), size_t{}, [this](size_t s, TField field) {
            return s + fields[(int)field].size();
        });
    }
    // Control sums are equal when their accumulated hash values are equal.
    bool TControlSum::operator==(const TControlSum& cs) const {
        return GetSum() == cs.GetSum();
    }

    // Negation of operator==.
    bool TControlSum::operator!=(const TControlSum& cs) const {
        const bool same = (*this == cs);
        return !same;
    }

    bool TControlSum::Empty() const {
        return Sum == 0;
    }

    // Returns the raw accumulated hash value.
    ui64 TControlSum::GetSum() const {
        return this->Sum;
    }

    void TControlSum::Update(const TString& value) {
        Sum = FnvHash<ui64>(value.c_str(), value.size(), Sum == 0 ? FNV64INIT : Sum);
    }

    // File-local helper: prints a control sum as its numeric value.
    static IOutputStream& operator<<(IOutputStream& stream, const TControlSum& value) {
        stream << value.GetSum();
        return stream;
    }

    // Serializes a request as a single JSON object:
    //   "<field>"    : [values...]      -- one key per non-empty field
    //   "cs"         : {"<field>": sum} -- only non-empty control sums
    //   "models_ids" : [ids...]
    // Control sums are verified first, so an inconsistent request throws
    // before anything is written to the stream.
    IOutputStream& operator<<(IOutputStream& stream, const TRequestData& data) {
        data.VerifyControlSums();

        // The writer lives in its own scope so it is destroyed before the
        // stream is returned to the caller.
        {
            NJsonWriter::TBuf writer(NJsonWriter::HEM_DONT_ESCAPE_HTML, &stream);
            writer.BeginObject();

            // Field values; empty fields are omitted entirely.
            for (const auto fieldId : GetEnumAllValues<TField>()) {
                const auto& values = data.fields[(int)fieldId];
                if (values) {
                    writer.WriteKey(ToString(fieldId));
                    writer.BeginList();
                    for (const auto& value : values) {
                        writer.WriteString(value);
                    }
                    writer.EndList();
                }
            }

            // Control sums; the "cs" object is always present, possibly empty.
            writer.WriteKey("cs");
            writer.BeginObject();
            for (const auto fieldId : GetEnumAllValues<TField>()) {
                const auto& sum = data.controlSums[(int)fieldId];
                if (sum.Empty()) {
                    continue;
                }
                writer.WriteKey(ToString(fieldId));
                writer.WriteULongLong(sum.GetSum());
            }
            writer.EndObject();

            // Model identifiers.
            writer.WriteKey("models_ids");
            writer.BeginList();
            for (const auto& modelId : data.modelsIds) {
                writer.WriteString(modelId);
            }
            writer.EndList();

            writer.EndObject();
        }

        return stream;
    }

    // Replaces the contents of this request with the data parsed from `src`.
    //
    // Expected layout (all keys optional):
    //   "<field>"    : array of strings, one key per TField value
    //   "cs"         : object mapping field name -> unsigned control sum
    //   "models_ids" : array of strings
    //
    // Throws yexception when `src` is not valid JSON; the *Safe accessors throw
    // when a value has an unexpected type; VerifyControlSums() throws when the
    // "cs" section is present but does not match the parsed fields.
    void TRequestData::FromJson(const TStringBuf & src) {
        NJson::TJsonValue json;
        if(!NJson::ReadJsonTree(src, &json, false))
            ythrow yexception() << "cannot parse json from \"" << src << '"';

        const auto & map = json.GetMapSafe();

        for(const auto fieldId : GetEnumAllValues<TField>()) {
            auto& targetField = fields[(int)fieldId];

            targetField.clear();
            // Drop any sum left over from the previous contents of this object:
            // the field has just been cleared, so a stale sum would make a later
            // VerifyControlSums() fail for data this call never loaded.
            controlSums[(int)fieldId] = TControlSum();

            auto it = map.find(ToString(fieldId));
            if(map.cend() == it)
                continue;
            const auto &array = it->second.GetArraySafe();

            targetField.reserve(array.size());
            for (const auto &v : array) {
                targetField.emplace_back(v.GetStringSafe());
            }
        }

        if(auto controlSumsJs = MapFindPtr(map, "cs")) {
            const auto& controlSumsMap = controlSumsJs->GetMapSafe();
            for(const auto fieldId : GetEnumAllValues<TField>()) {
                if(auto it = MapFindPtr(controlSumsMap, ToString(fieldId))) {
                    controlSums[(int)fieldId] = TControlSum(it->GetUIntegerSafe());
                }
            }

            // Only verified when the sender supplied sums at all.
            VerifyControlSums();
        }

        modelsIds.clear();
        if(auto modelsIdsJs = MapFindPtr(map, "models_ids")) {
            const auto &array = modelsIdsJs->GetArraySafe();
            modelsIds.reserve(array.size());
            for (const auto &v : array)
                modelsIds.emplace(v.GetStringSafe());
        }
    }

    // Appends `value` to `field` and folds it into that field's control sum.
    void TRequestData::Add(TString value, TField field) {
        const auto index = static_cast<int>(field);
        // Hash first: `value` is moved-from after the next statement.
        controlSums[index].Update(value);
        fields[index].emplace_back(std::move(value));
    }

    // Checks, for every field in AllFields, that the stored control sum matches
    // the sum recomputed from the field's values. Throws a yexception with a
    // backtrace on the first inconsistency:
    //   - field is empty but its sum is not,
    //   - field has values but its sum is empty,
    //   - recomputed sum differs from the stored one.
    void TRequestData::VerifyControlSums() const {
        for(const TField fieldId : AllFields) {
            const auto index = static_cast<int>(fieldId);
            const auto& values = fields[index];
            const auto stored = controlSums[index];

            if(!values) {
                if(!stored.Empty())
                    ythrow TWithBackTrace<yexception>() << fieldId << " is empty: " << MakeRangeJoiner(",", values) << ", but controlsum isn't: " << stored;
                continue;
            }

            if(stored.Empty())
                ythrow TWithBackTrace<yexception>() << fieldId << " isn't empty: " << MakeRangeJoiner(",", values) << ", but controlsum is";

            // Recompute the sum in the same order the values were added.
            TControlSum recomputed;
            for(const auto& value : values)
                recomputed.Update(value);

            if(recomputed != stored)
                ythrow TWithBackTrace<yexception>() << fieldId << " cs mismatch: " << MakeRangeJoiner(",", values);
        }
    }

    // Read-only access to the values stored for `field`.
    const TVector<TString>& TRequestData::Get(TField field) const {
        const auto index = static_cast<int>(field);
        return fields[index];
    }

    // Single-field convenience overload; forwards to the list overload.
    TVector<TString> TRequestData::GetPrefixedWords(TField field) const {
        const std::initializer_list<TField> single = {field};
        return GetPrefixedWords(single);
    }

    // Returns every value of the listed fields as "<field name>_<value>",
    // in field order and then in insertion order within each field.
    TVector<TString> TRequestData::GetPrefixedWords(std::initializer_list<TField> fields) const {
        TVector<TString> result;
        // Note: the parameter shadows the `fields` member, so values are read
        // through Get().
        for (const auto field : fields) {
            for (const auto& word : Get(field)) {
                result.emplace_back(ToString(field) + "_" + word);
            }
        }
        return result;
    }

    // Serializes the response into a JSON string with the keys
    // "predictions", "features", "rules", "matches", "themes", "testing"
    // and "distances_to_compls", written in that order.
    TString TResponseData::ToJsonString() const {
        TStringStream ss;
        {
            NJson::TJsonWriter writer(&ss, false);

            // Writes each theme as {"description": ..., "distance": ...}
            // into the array that is currently open.
            const auto writeThemes = [&writer](const auto& themes) {
                for (const auto& theme : themes) {
                    writer.OpenMap();
                    writer.Write("description", theme.GetDescription());
                    writer.Write("distance", theme.GetDistance());
                    writer.CloseMap();
                }
            };

            writer.OpenMap();

            // model id -> {feature -> value}
            writer.OpenMap("predictions");
            for (const auto& [modelId, featuresMap] : modelsPredictions) {
                writer.OpenMap(modelId);
                for (const auto& [feature, val] : featuresMap) {
                    writer.Write(feature, val);
                }
                writer.CloseMap();
            }
            writer.CloseMap();

            // model id -> [rule names]
            writer.OpenMap("features");
            for (const auto& [modelId, features] : modelsRules) {
                writer.OpenArray(modelId);
                for (const auto& feature : features) {
                    writer.Write(feature);
                }
                writer.CloseArray();
            }
            writer.CloseMap();

            writer.OpenArray("rules");
            for (const auto& rule : indepRules) {
                writer.Write(rule);
            }
            writer.CloseArray();

            writer.Write("matches", MatchsPercent);

            writer.OpenArray("themes");
            writeThemes(Themes);
            writer.CloseArray();

            writer.OpenArray("testing");
            for (const auto& [name, themes] : Testing) {
                writer.OpenMap();
                writer.Write("name", name);
                writer.OpenArray("themes");
                writeThemes(themes);
                writer.CloseArray();
                writer.CloseMap();
            }
            writer.CloseArray();

            writer.OpenArray("distances_to_compls");
            for (const auto distance : minDistancesToCompls) {
                writer.Write(distance);
            }
            writer.CloseArray();

            writer.CloseMap();
            writer.Flush();
        }

        // Move the buffer out of the stream, which is about to be destroyed.
        return std::move(ss.Str());
    }

    // Parses a response from its JSON representation.
    //
    // "predictions", "features", "rules", "distances_to_compls" and "testing"
    // are read only when present. "matches" and "themes" are read
    // unconditionally, so the *Safe accessors throw when they are missing or
    // have an unexpected type. The parsed tree is also kept in `data.json`.
    //
    // Throws yexception when `src` is not valid JSON.
    TResponseData TResponseData::FromJson(const TStringBuf & src) {
        TResponseData data;
        data.json = {};

        if(!NJson::ReadJsonTree(src, &data.json, false))
            ythrow yexception() << "cannot parse json from " << src;

        // Appends every {"description", "distance"} entry of `jsonThemes` to `out`.
        const auto appendThemes = [](const NJson::TJsonValue& jsonThemes, auto& out) {
            for (const auto& jsonTheme : jsonThemes.GetArraySafe()) {
                out.emplace_back(
                        jsonTheme["description"].GetStringSafe(),
                        static_cast<float>(jsonTheme["distance"].GetDoubleSafe())
                );
            }
        };

        if(auto v = data.json.GetValueByPath("predictions")) {
            for (const auto&[modelId, featuresMap] : v->GetMapSafe()) {
                auto& modelsPrediction = data.modelsPredictions[modelId];
                for (const auto&[feature, val]: featuresMap.GetMapSafe())
                    modelsPrediction.emplace(feature, val.GetDoubleRobust());
            }
        }

        if(auto v = data.json.GetValueByPath("features")) {
            for (const auto&[modelId, featuresMap] : v->GetMapSafe()) {
                auto& modelFeatures = data.modelsRules[modelId];
                for (const auto&feature: featuresMap.GetArraySafe())
                    modelFeatures.emplace_back(feature.GetStringSafe());
            }
        }

        if(auto v = data.json.GetValueByPath("rules")) {
            const auto& rules = v->GetArraySafe();
            data.indepRules.reserve(rules.size());
            for (const auto& rule: rules) {
                data.indepRules.emplace_back(rule.GetStringSafe());
            }
        }

        if(auto v = data.json.GetValueByPath("distances_to_compls")) {
            const auto& arr = v->GetArraySafe();
            data.minDistancesToCompls.reserve(arr.size());
            for(const auto& d : arr)
                data.minDistancesToCompls.emplace_back(d.GetDoubleRobust());
        }

        data.MatchsPercent = static_cast<float>(data.json["matches"].GetDoubleSafe());

        appendThemes(data.json["themes"], data.Themes);

        if(auto v = data.json.GetValueByPath("testing")) {
            const auto& tests = v->GetArraySafe();
            // Size the container from the input instead of a fixed guess.
            data.Testing.reserve(tests.size());
            for (const auto &jsTest : tests) {
                TThemeChain themes;
                appendThemes(jsTest["themes"], themes);

                // `themes` is not used after this point, so hand it over
                // instead of copying the whole chain.
                data.Testing.emplace(
                        jsTest["name"].GetStringSafe(),
                        std::move(themes)
                );
            }
        }

        return data;
    }
}
