#include "node_resolver.h"

#include <drive/backend/support_center/categorizer/model.h>
#include <drive/backend/support_center/ifaces.h>
#include <drive/library/cpp/taxi/support_classifier/client.h>

template <>
NJson::TJsonValue NJson::ToJson(const TSuggestSchema& object) {
    // Serializes a suggest schema as {"node": <node>, "text": <text>}.
    NJson::TJsonValue json;
    NJson::InsertField(json, "node", object.GetNode());
    NJson::InsertField(json, "text", object.GetText());
    return json;
}

template <>
bool NJson::TryFromJson(const NJson::TJsonValue& value, TSuggestSchema& result) {
    // "node" is mandatory; "text" is optional. Parsing stops at the first failure.
    if (!NJson::ParseField(value, "node", result.MutableNode(), true)) {
        return false;
    }
    return NJson::ParseField(value, "text", result.MutableText(), false);
}

template <>
NJson::TJsonValue NJson::ToJson(const TNodeResolveParameters& object) {
    // Scalar fields are always written; the two schema fields only when present.
    NJson::TJsonValue json;
    NJson::InsertField(json, "classification_result", object.GetClassificationResult());
    NJson::InsertField(json, "node_id", object.GetNodeId());
    NJson::InsertField(json, "min_confidence", object.GetMinConfidence());
    NJson::InsertField(json, "max_confidence", object.GetMaxConfidence());
    NJson::InsertField(json, "append_suggest", object.GetAppendSuggest());
    const bool withSchema = object.HasSchema();
    if (withSchema) {
        NJson::InsertField(json, "schema", object.GetSchemaRef());
    }
    const bool withSureSchema = object.HasSureSchema();
    if (withSureSchema) {
        NJson::InsertField(json, "sure_schema", object.GetSureSchemaRef());
    }
    return json;
}

template <>
NJson::TJsonValue NJson::ToJson(const TClassificationFeatures& object) {
    // Only the configured user tag names are serialized (key "user_tags").
    NJson::TJsonValue json;
    NJson::InsertField(json, "user_tags", object.GetUserTags());
    return json;
}

NJson::TJsonValue TSuggestSchema::GetChatReport() const {
    // Builds a chat entry of type "message".
    // The schema node is placed under "message_text" and the display text under "text".
    NJson::TJsonValue report = NJson::JSON_MAP;
    NJson::InsertField(report, "message_text", Node);
    NJson::InsertField(report, "text", Text);
    NJson::InsertField(report, "type", TString("message"));
    return report;
}

bool TNodeResolveParameters::DeserializeFromJson(const NJson::TJsonValue& json, TMessagesCollector& errors) {
    // Mandatory identification fields first; any failure is reported into `errors` and stops parsing.
    const bool identified =
        NJson::ParseField(json, "classification_result", ClassificationResult, true, errors) &&
        NJson::ParseField(json, "node_id", NodeId, true, errors);
    if (!identified) {
        return false;
    }
    // Optional confidence range, suggest flag and schemas, parsed in the same order as before.
    return
        NJson::ParseField(json, "min_confidence", MinConfidence, false, errors) &&
        NJson::ParseField(json, "max_confidence", MaxConfidence, false, errors) &&
        NJson::ParseField(json, "append_suggest", AppendSuggest, false, errors) &&
        NJson::ParseField(json, "sure_schema", SureSchema, false, errors) &&
        NJson::ParseField(json, "schema", Schema, false, errors);
}

bool TNodeResolveParameters::IsMatching(const TString& classificationResult, const i32 confidence) const {
    // The topic must match exactly, and the confidence must lie in the inclusive range
    // [MinConfidence, MaxConfidence].
    if (ClassificationResult != classificationResult) {
        return false;
    }
    return confidence >= MinConfidence && confidence <= MaxConfidence;
}

void TNodeResolveParameters::AddNodeTextMappings(TMap<TString, TString>& map) const {
    // Records node -> text for the main schema, then for every sure schema.
    // Later entries overwrite earlier ones that share a node.
    if (HasSchema()) {
        const auto& schema = GetSchemaRef();
        map[schema.GetNode()] = schema.GetText();
    }
    if (!HasSureSchema()) {
        return;
    }
    for (const auto& sureSchema : GetSureSchemaRef()) {
        map[sureSchema.GetNode()] = sureSchema.GetText();
    }
}

// Builds the classifier feature list as a JSON array of {"key": ..., "value": ...} objects:
// - one entry per configured user tag, whose value is the context's value for that tag, or 0 when the user lacks it;
// - plus {"key": "suggest", "value": true} when `userOptionsSuggest` is set.
// `context` is only dereferenced when user tags are configured.
NJson::TJsonValue TClassificationFeatures::GenerateFeaturesJson(const IChatUserContext::TPtr context, bool userOptionsSuggest) const {
    NJson::TJsonValue result = NJson::JSON_ARRAY;
    if (!UserTags.empty()) {
        // Bind once: find() and end() must refer to the same container, which is only
        // guaranteed this way if GetUserTags() returns by value.
        const auto& userTags = context->GetUserTags();
        for (auto&& tagName : UserTags) {
            NJson::TJsonValue feature = NJson::JSON_MAP;
            feature["key"] = tagName;
            auto it = userTags.find(tagName);
            if (it != userTags.end()) {
                feature["value"] = it->second;
            } else {
                feature["value"] = 0;
            }
            result.AppendValue(feature);
        }
    }
    if (userOptionsSuggest) {
        NJson::TJsonValue isSuggest = NJson::JSON_MAP;
        isSuggest["key"] = "suggest";
        isSuggest["value"] = true;
        result.AppendValue(isSuggest);
    }
    return result;
}

bool TClassificationFeatures::DeserializeFromJson(const NJson::TJsonValue& json, TMessagesCollector& errors) {
    // Parses the optional "user_tags" field into UserTags; parse errors go to `errors`.
    const bool parsed = NJson::ParseField(json, "user_tags", UserTags, false, errors);
    return parsed;
}

TExpected<INodeResolver::TPtr, TString> INodeResolver::Construct(const NJson::TJsonValue& json, const INodeResolver::TPtr toCopy) {
    // A non-object JSON value yields a null resolver, not an error.
    if (!json.IsMap()) {
        return TPtr();
    }
    TMessagesCollector errors;
    TString classifierType;
    if (!NJson::ParseField(json, "type", classifierType, true, errors)) {
        return MakeUnexpected(errors.GetStringReport());
    }
    // "type" selects the concrete resolver registered in the factory.
    TPtr nodeResolver(TFactory::Construct(classifierType));
    if (!nodeResolver) {
        return MakeUnexpected(classifierType + " is unknown");
    }
    // Start from the settings of `toCopy` (no-op when null), then let the JSON override them.
    nodeResolver->MergeFields(toCopy);
    if (!nodeResolver->DeserializeFromJson(json, errors)) {
        return MakeUnexpected(errors.GetStringReport());
    }
    return nodeResolver;
}

// Copies the configurable settings of `toCopy` into this resolver, so that a later
// DeserializeFromJson() only needs to override what changed. Does nothing when `toCopy` is null.
void INodeResolver::MergeFields(INodeResolver::TPtr toCopy) {
    if (!toCopy) {
        return;
    }
    SetDefaultNode(toCopy->GetDefaultNode());
    SetFallbackNode(toCopy->GetFallbackNode());
    SetRequestTimeout(toCopy->GetRequestTimeout());
    SetResolveParameters(toCopy->GetResolveParameters());
    SetSuggestOptionsCount(toCopy->GetSuggestOptionsCount());
    // These are serialized and deserialized like the fields above, so they are inherited too.
    MessageLimit = toCopy->GetMessageLimit();
    ClassificationFeatures = toCopy->GetClassificationFeatures();
    // SuggestNodeToText is derived from the resolve parameters (DeserializeFromJson fills it
    // via AddNodeTextMappings). Rebuild it so GetSuggestText() also works for inherited parameters.
    SuggestNodeToText.clear();
    for (auto&& parameters : GetResolveParameters()) {
        parameters.AddNodeTextMappings(SuggestNodeToText);
    }
}

// Reads the resolver configuration from `json`, reporting problems into `errors`.
// - A non-empty "resolve_parameters" array replaces the current parameters.
// - An absent or empty one is accepted only if parameters are already present (e.g. inherited via MergeFields).
// - "default_node_id" and "fallback_node_id" are required only while the corresponding node is still empty.
bool INodeResolver::DeserializeFromJson(const NJson::TJsonValue& json, TMessagesCollector& errors) {
    // Per the error message below, a missing or non-array value is treated as an empty array.
    const auto& resolveParametersJson = json["resolve_parameters"].GetArray();
    if (resolveParametersJson.empty() && ResolveParameters.empty()) {
        errors.AddMessage("INodeResolver::DeserializeFromJson", "'resolve_parameters' field is missing or not an array");
        return false;
    }
    if (!resolveParametersJson.empty()) {
        // Explicit parameters replace inherited ones. The node -> text mapping is derived from
        // the parameters, so drop stale entries together with them.
        ResolveParameters.clear();
        SuggestNodeToText.clear();
        for (auto&& parametersJson : resolveParametersJson) {
            TNodeResolveParameters resolveParameters;
            if (!resolveParameters.DeserializeFromJson(parametersJson, errors)) {
                return false;
            }
            resolveParameters.AddNodeTextMappings(SuggestNodeToText);
            ResolveParameters.push_back(std::move(resolveParameters));
        }
    }
    // "features" is only parsed when it is a JSON object; otherwise the current features are kept.
    if (json["features"].IsMap() && !ClassificationFeatures.DeserializeFromJson(json["features"], errors)) {
        return false;
    }
    return
        NJson::ParseField(json, "default_node_id", DefaultNode, DefaultNode.empty(), errors) &&
        NJson::ParseField(json, "fallback_node_id", FallbackNode, FallbackNode.empty(), errors) &&
        NJson::ParseField(json, "suggest_options_count", SuggestOptionsCount, false, errors) &&
        NJson::ParseField(json, "message_limit", MessageLimit, false, errors) &&
        NJson::ParseField(json, "request_timeout", RequestTimeout, false, errors);
}

NJson::TJsonValue INodeResolver::SerializeToJson() const {
    // Writes the same settings that DeserializeFromJson reads (the "type" key is not written here).
    NJson::TJsonValue json;
    json["default_node_id"] = NJson::ToJson(GetDefaultNode());
    json["fallback_node_id"] = NJson::ToJson(GetFallbackNode());
    json["request_timeout"] = NJson::ToJson(GetRequestTimeout());
    json["resolve_parameters"] = NJson::ToJson(GetResolveParameters());
    json["features"] = NJson::ToJson(GetClassificationFeatures());
    json["suggest_options_count"] = NJson::ToJson(GetSuggestOptionsCount());
    json["message_limit"] = NJson::ToJson(GetMessageLimit());
    return json;
}

TString INodeResolver::GetSuggestText(const TString& node) const {
    // Returns the suggest text registered for `node`, or an empty string if there is none.
    const auto found = SuggestNodeToText.find(node);
    return found == SuggestNodeToText.end() ? TString() : found->second;
}

TMaybe<TNodeResolveParameters> INodeResolver::TryFindResolveParameters(const TString& classificationResult, const i32 confidence, const bool withSchema) const {
    // Returns a copy of the first parameters (in configuration order) that:
    // - match the topic and confidence, and
    // - have a schema, when `withSchema` is set.
    for (const auto& parameters : GetResolveParameters()) {
        if (!parameters.IsMatching(classificationResult, confidence)) {
            continue;
        }
        if (withSchema && !parameters.HasSchema()) {
            continue;
        }
        return parameters;
    }
    return Nothing();
}

TMaybe<TNodeResolveParameters> INodeResolver::TryFindResolveParameters(const TString& classificationResult, const bool withSchema) const {
    // Like the confidence-aware overload, but matches on the topic alone (the confidence range is ignored).
    for (const auto& parameters : GetResolveParameters()) {
        if (parameters.GetClassificationResult() != classificationResult) {
            continue;
        }
        if (withSchema && !parameters.HasSchema()) {
            continue;
        }
        return parameters;
    }
    return Nothing();
}

// Sort comparator: true when `lhs` is more probable than `rhs`, so sorting puts the most likely topics first.
// NOTE(review): this has external linkage; make it static if no header declares it.
bool ComparePredictions(const NDrive::TSupportPrediction::TElement& lhs, const NDrive::TSupportPrediction::TElement& rhs) {
    return rhs.Probability < lhs.Probability;
}

TString TSupportPredictionNodeResolver::MatchPredictions(const TVector<NDrive::TSupportPrediction::TElement>& predictions, const TString& logEvent) const {
    // The chosen node is always logged under `logEvent` before being returned.
    const auto logChosen = [&logEvent](const TString& node) {
        NDrive::TEventLog::Log(logEvent, NJson::TMapBuilder("node", node));
        return node;
    };
    // Predictions are scanned in the given order; the first resolvable one decides the node.
    for (const auto& prediction : predictions) {
        // Probability is scaled to a percentage and truncated toward zero for the i32 confidence argument.
        const i32 confidence = static_cast<i32>(prediction.Probability * 100);
        if (auto parameters = TryFindResolveParameters(prediction.Topic, confidence, false)) {
            return logChosen(parameters->GetNodeId());
        }
    }
    return logChosen(GetDefaultNode());
}

TVector<TNodeResolveParameters> TSupportPredictionNodeResolver::GetTopPredictions(const TVector<NDrive::TSupportPrediction::TElement>& predictions, const TSet<TString>& skipTopics, const ui32 count, const bool withSchema) const {
    // Collects, in prediction order, at most `count` resolve parameters.
    // Topics in `skipTopics` and topics with no matching parameters are ignored.
    TVector<TNodeResolveParameters> result;
    for (auto prediction = predictions.begin(); prediction != predictions.end() && result.size() < count; ++prediction) {
        if (skipTopics.contains(prediction->Topic)) {
            continue;
        }
        if (auto parameters = TryFindResolveParameters(prediction->Topic, withSchema)) {
            result.emplace_back(std::move(*parameters));
        }
    }
    return result;
}

// Synchronously resolves the next chat node from the taxi support classifier.
// Resolution order:
// - the fallback node when the classifier fails or does not answer within GetRequestTimeout();
// - otherwise the node of the "sure" topic, if one resolves;
// - otherwise the best matching prediction (default node if none match).
TString TTaxiSupportChatNodeResolver::GetNextNode(const IChatUserContext::TPtr context, const NDrive::NChat::TMessageEvents& messages) const {
    auto suggestResult = GetSuggest(context, messages);
    suggestResult.Wait(GetRequestTimeout());
    // Wait() also returns when the timeout expires without a result. Reading a still-pending future
    // throws instead of returning, so both an exception and a timeout must fall back here
    // (GetNextNodeFuture applies the same HasValue() check).
    if (!suggestResult.HasValue()) {
        if (suggestResult.HasException()) {
            NDrive::TEventLog::Log("TTaxiSupportChatNodeResolverError", NJson::TMapBuilder("error", "Taxi classifier exception " + NThreading::GetExceptionMessage(suggestResult)));
        } else {
            NDrive::TEventLog::Log("TTaxiSupportChatNodeResolverError", NJson::TMapBuilder("error", TString("Taxi classifier timeout")));
        }
        return GetFallbackNode();
    }
    const auto& response = suggestResult.GetValue();
    if (TString sureTopic = response.GetSureTopic()) {
        auto sureNode = TryFindResolveParameters(sureTopic, false);
        if (sureNode) {
            return sureNode->GetNodeId();
        }
    }
    auto predictions = response.GetPredictions().Elements;
    Sort(predictions.begin(), predictions.end(), ComparePredictions);
    return MatchPredictions(predictions, "TTaxiSupportChatNodeResolver");
}

NThreading::TFuture<TString> TTaxiSupportChatNodeResolver::GetNextNodeFuture(const IChatUserContext::TPtr context, const NDrive::NChat::TMessageEvents& messages) const {
    // Asynchronous counterpart of GetNextNode. The continuation captures a copy of this resolver
    // ([*this]), so it does not reference the original object when the classifier answers.
    auto resolveNode = [*this](const NThreading::TFuture<TTaxiSupportChatSuggestClient::TSuggestResponse>& fut) {
        if (fut.HasException() || !fut.HasValue()) {
            NDrive::TEventLog::Log("TTaxiSupportChatNodeResolverError", NJson::TMapBuilder("error", "Taxi classifier exception " + NThreading::GetExceptionMessage(fut)));
            return GetFallbackNode();
        }
        const auto& response = fut.GetValue();
        // A resolvable "sure" topic wins over the ranked predictions.
        if (TString sureTopic = response.GetSureTopic()) {
            if (auto sureNode = TryFindResolveParameters(sureTopic, false)) {
                return sureNode->GetNodeId();
            }
        }
        auto predictions = response.GetPredictions().Elements;
        Sort(predictions.begin(), predictions.end(), ComparePredictions);
        return MatchPredictions(predictions, "TTaxiSupportChatNodeResolver");
    };
    return GetSuggest(context, messages).Apply(std::move(resolveNode));
}

NThreading::TFuture<TTaxiSupportChatSuggestClient::TSuggestResponse> TTaxiSupportChatNodeResolver::GetSuggest(const IChatUserContext::TPtr context, const NDrive::NChat::TMessageEvents& messages, bool userOptionsSuggest) const {
    using TResponse = TTaxiSupportChatSuggestClient::TSuggestResponse;
    // A missing chat robot or suggest client is reported as a failed future, not thrown.
    if (!context->GetChatRobot()) {
        return NThreading::MakeErrorFuture<TResponse>(std::make_exception_ptr(yexception() << "No chat robot"));
    }
    const auto& client = context->GetServer().GetTaxiSupportChatSuggestClient();
    if (!client) {
        return NThreading::MakeErrorFuture<TResponse>(std::make_exception_ptr(yexception() << "Taxi support chat suggest client not configured"));
    }
    // The dialog carries the chat history, the generated classification features and the message limit.
    return client->GetChatSuggest(TTaxiSupportChatSuggestClient::TDialog(context, messages, GetClassificationFeatures().GenerateFeaturesJson(context, userOptionsSuggest), GetMessageLimit()));
}

// Builds the JSON array of chat options to show the user.
// - The "sure" topic's schemas come first, all of them.
// - The best-ranked predictions with a schema then fill the remaining GetSuggestOptionsCount() budget.
// - With sure schemas present and append_suggest disabled, no prediction options are added.
NJson::TJsonValue TTaxiSupportChatNodeResolver::GetSuggestedChatOptions(const TTaxiSupportChatSuggestClient::TSuggestResponse& suggestResponse) const {
    NJson::TJsonValue result = NJson::JSON_ARRAY;
    auto predictions = suggestResponse.GetPredictions().Elements;
    Sort(predictions.begin(), predictions.end(), ComparePredictions);
    // Remaining number of options that may still be taken from the ranked predictions.
    ui32 suggestCount = GetSuggestOptionsCount();
    // Saturating decrement. The sure topic may contribute more options than the budget (or the
    // budget may be 0), and wrapping the unsigned counter would make GetTopPredictions return every match.
    const auto consumeSuggestSlot = [&suggestCount]() {
        if (suggestCount > 0) {
            --suggestCount;
        }
    };
    TSet<TString> skipTopics;
    if (TString sureTopic = suggestResponse.GetSureTopic()) {
        // The sure topic is shown via its own schemas, so GetTopPredictions must not repeat it.
        skipTopics.emplace(sureTopic);
        // NOTE(review): withSchema=true skips parameters that only have a sure_schema — confirm intended.
        auto sureNode = TryFindResolveParameters(sureTopic, true);
        if (sureNode) {
            if (sureNode->HasSureSchema()) {
                for (auto&& schema : sureNode->GetSureSchemaRef()) {
                    result.AppendValue(schema.GetChatReport());
                    consumeSuggestSlot();
                }
                // Without append_suggest the sure schemas are the only options offered.
                if (!sureNode->GetAppendSuggest()) {
                    suggestCount = 0;
                }
            } else if (sureNode->HasSchema()) {
                result.AppendValue(sureNode->GetSchemaRef().GetChatReport());
                consumeSuggestSlot();
            }
        }
    }

    if (suggestCount > 0) {
        auto matchingResolvers = GetTopPredictions(predictions, skipTopics, suggestCount, true);
        for (auto&& resolver : matchingResolvers) {
            result.AppendValue(resolver.GetSchemaRef().GetChatReport());
        }
    }

    return result;
}

// Registers TTaxiSupportChatNodeResolver in the INodeResolver factory under the key
// "taxi_support_chat_classification". That is the "type" value INodeResolver::Construct passes to TFactory::Construct.
INodeResolver::TFactory::TRegistrator<TTaxiSupportChatNodeResolver> TTaxiSupportChatNodeResolver::Registrator("taxi_support_chat_classification");
