

#include <drive/backend/chat_robots/suggest/node_resolver.h>
#include <drive/backend/ut/library/helper.h>
#include <drive/backend/ut/library/script.h>

#include <library/cpp/testing/unittest/env.h>
#include <library/cpp/testing/unittest/registar.h>

Y_UNIT_TEST_SUITE(ChatSuggest) {

    // Builds a TTestNodeResolver from a JSON config given as text.
    // Fails the current test if the text cannot be read as JSON or if the resolver
    // rejects the config; in the latter case the collected error report is attached.
    TTestNodeResolver BuildNodeResolverFromString(const TString& str) {
        NJson::TJsonValue parsedConfig;
        UNIT_ASSERT(ReadJsonFastTree(str, &parsedConfig));

        TMessagesCollector deserializationErrors;
        TTestNodeResolver resolver;
        UNIT_ASSERT_C(
            resolver.DeserializeFromJson(parsedConfig, deserializationErrors),
            deserializationErrors.GetStringReport()
        );
        return resolver;
    }

    // Feeds two sequences of classification results into a resolver built from a
    // two-rule config and checks which node GetNextNode() selects for each sequence.
    Y_UNIT_TEST(Simple) {
        // Server and environment are brought up before the resolver is used.
        NDrive::TServerConfigGenerator configGenerator;
        TServerConfigConstructorParams params(configGenerator.GetString().data());
        NDrive::TServerConfig config(params);
        NDrive::TServerGuard server(config);
        TEnvironmentGenerator eGenerator(*server.Get());

        const TString resolverConfig = R"(
        {
            "default_node_id": "default",
            "fallback_node_id": "fallback",
            "resolve_parameters": [
                {
                    "classification_result": "result1",
                    "node_id": "node1",
                    "min_confidence": 35,
                    "max_confidence": 100
                },
                {
                    "classification_result": "result2",
                    "node_id": "node2",
                    "min_confidence": 50,
                    "max_confidence": 99
                }
            ],
            "features": {
                "user_tags": [
                    "simple_tag_1",
                    "simple_tag_2"
                ]
            }
        })";

        { // the last "result2" (70) and "result1" (55) both lie inside their configured ranges; "node2" is expected
            auto nodeResolver = BuildNodeResolverFromString(resolverConfig);
            // Both entries of "features.user_tags" must have been picked up.
            UNIT_ASSERT_VALUES_EQUAL(nodeResolver.GetClassificationFeatures().GetUserTags().size(), 2);

            nodeResolver.AddClassificationResult("a", 99);
            nodeResolver.AddClassificationResult("result1", 30);
            nodeResolver.AddClassificationResult("result2", 20);
            nodeResolver.AddClassificationResult("result2", 70);
            nodeResolver.AddClassificationResult("result1", 55);

            const auto nextNode = nodeResolver.GetNextNode(new TChatUserContext(), {});
            UNIT_ASSERT_VALUES_EQUAL(nextNode, "node2");
        }
        { // the last "result2" (100) is above its max of 99 and the last "result1" (34) is below its min of 35; "default" is expected
            auto nodeResolver = BuildNodeResolverFromString(resolverConfig);

            nodeResolver.AddClassificationResult("a", 99);
            nodeResolver.AddClassificationResult("result1", 30);
            nodeResolver.AddClassificationResult("result2", 20);
            nodeResolver.AddClassificationResult("result2", 100);
            nodeResolver.AddClassificationResult("result1", 34);

            const auto nextNode = nodeResolver.GetNextNode(new TChatUserContext(), {});
            UNIT_ASSERT_VALUES_EQUAL(nextNode, "default");
        }
    }

    // Checks that the support-suggest chat script from the test data directory
    // is readable as JSON and is accepted by TChatRobotScript::Parse().
    Y_UNIT_TEST(ClassificationScriptParsing) {
        // Server and environment are brought up before the script is parsed.
        NDrive::TServerConfigGenerator configGenerator;
        TServerConfigConstructorParams params(configGenerator.GetString().data());
        NDrive::TServerConfig config(params);
        NDrive::TServerGuard server(config);
        TEnvironmentGenerator eGenerator(*server.Get());

        const TString scriptPath = JoinFsPaths(ArcadiaSourceRoot(), "drive/backend/chat_robots/test_data/script_support_suggest.json");
        NJson::TJsonValue json;
        TFileInput in(scriptPath);
        // Fail right here on unreadable JSON instead of handing an empty tree to Parse().
        UNIT_ASSERT_C(ReadJsonTree(&in, &json), TStringBuilder() << "cannot read JSON from " << scriptPath);

        TChatRobotScript s;
        TMessagesCollector errors;
        UNIT_ASSERT_C(s.Parse(json, errors), errors.GetStringReport());
    }

    // Runs two classification scenarios through the chat robot and checks that each
    // chat ends up with five messages, the last of which is the expected node text.
    Y_UNIT_TEST(ChatClassification) {
        NDrive::TServerConfigGenerator gServer;
        gServer.SetSensorApiName({});
        TServerConfigConstructorParams params(gServer.GetString().data());
        NDrive::TServerConfig config(params);
        NDrive::TServerGuard server(config);

        TEnvironmentGenerator eGenerator(*server);
        eGenerator.BuildEnvironment(TEnvironmentGenerator::DefaultTraits);

        auto userId = eGenerator.CreateUser("skulik-was-not-there", false, "active");

        // Sends the start command followed by a free-text message to the chat and
        // verifies the resulting message list. The chat id is attached to every
        // assertion so that a failure tells which scenario broke.
        const auto expectClassifiedAs = [&](const TString& chatId, const TString& startCommand, const TString& expectedLastText) {
            UNIT_ASSERT_C(gServer.CommitChatAction(userId, chatId, startCommand, {}), chatId);
            UNIT_ASSERT_C(gServer.CommitChatAction(userId, chatId, "it's time to classify", {}), chatId);
            auto messagesMap = gServer.GetChatMessages(userId, chatId);
            auto messages = messagesMap["messages"].GetArray();
            UNIT_ASSERT_VALUES_EQUAL_C(messages.size(), 5, chatId);
            UNIT_ASSERT_VALUES_EQUAL_C(messages.back()["text"].GetString(), expectedLastText, chatId);
        };

        expectClassifiedAs("support_suggest.666", "!classify_default", "node_class_2");
        expectClassifiedAs("support_suggest.777", "!classify_override", "node_class_1");
    }

    // Returns the JSON text of a "taxi_support_chat_classification" resolver config with
    // four classes ("class_1".."class_4"). Classes 1-3 carry a "sure_schema" list (2, 4 and 2
    // entries respectively), class 4 has none; "append_suggest" is true for classes 1 and 2 only.
    // "suggest_options_count" is 4.
    TString GenerateStringWithSuggestOptions() {
        return R"(
        {
            "resolve_parameters": [
            {
                "node_id": "node_class_1",
                "max_confidence": 100,
                "classification_result": "class_1",
                "min_confidence": 10,
                "append_suggest": true,
                "schema": {
                    "node": "node_c_1",
                    "text": "go to node 1"
                },
                "sure_schema": [
                    {
                    "node": "node_1_opt_1",
                    "text": "go to node 1 option 1"
                    },
                    {
                    "node": "node_1_opt_2",
                    "text": "go to node 1 option 2"
                    }
                ]
            },
            {
                "node_id": "node_class_2",
                "max_confidence": 100,
                "classification_result": "class_2",
                "min_confidence": 50,
                "append_suggest": true,
                "schema": {
                    "node": "node_c_2",
                    "text": "go to node 2"
                },
                "sure_schema": [
                    {
                    "node": "node_2_opt_1",
                    "text": "go to node 2 option 1"
                    },
                    {
                    "node": "node_2_opt_2",
                    "text": "go to node 2 option 2"
                    },
                    {
                    "node": "node_2_opt_3",
                    "text": "go to node 2 option 3"
                    },
                    {
                    "node": "node_2_opt_4",
                    "text": "go to node 2 option 4"
                    }
                ]
            },
            {
                "node_id": "node_class_3",
                "max_confidence": 100,
                "classification_result": "class_3",
                "min_confidence": 10,
                "append_suggest": false,
                "schema": {
                    "node": "node_c_3",
                    "text": "go to node 3"
                },
                "sure_schema": [
                    {
                    "node": "node_3_opt_1",
                    "text": "go to node 3 option 1"
                    },
                    {
                    "node": "node_3_opt_2",
                    "text": "go to node 3 option 2"
                    }
                ]
            },
            {
                "node_id": "node_class_4",
                "max_confidence": 100,
                "classification_result": "class_4",
                "min_confidence": 10,
                "append_suggest": false,
                "schema": {
                    "node": "node_c_4",
                    "text": "go to node 4"
                }
            }
            ],
            "type": "taxi_support_chat_classification",
            "fallback_node_id": "classification_fallback",
            "default_node_id": "classification_default",
            "suggest_options_count": 4
        }
        )";
    }

    // Asserts that a single suggest option has `text` as its "text", `node` as its
    // "message_text", and "message" as its "type".
    void CheckSuggest(const NJson::TJsonValue& suggest, const TString& node, const TString& text) {
        const auto& actualText = suggest["text"].GetString();
        const auto& actualMessageText = suggest["message_text"].GetString();
        const auto& actualType = suggest["type"].GetString();

        UNIT_ASSERT_VALUES_EQUAL(actualText, text);
        UNIT_ASSERT_VALUES_EQUAL(actualMessageText, node);
        UNIT_ASSERT_VALUES_EQUAL(actualType, "message");
    }

    // Checks the options returned by TTaxiSupportChatNodeResolver::GetSuggestedChatOptions()
    // for the config from GenerateStringWithSuggestOptions(), both when the suggest response
    // names a sure topic and when it only carries per-topic probabilities.
    Y_UNIT_TEST(TaxiSuggestOptions) {
        NDrive::TServerConfigGenerator gServer;
        gServer.SetSensorApiName({});
        TServerConfigConstructorParams params(gServer.GetString().data());
        NDrive::TServerConfig config(params);
        NDrive::TServerGuard server(config);
        TEnvironmentGenerator eGenerator(*server);

        const TString configStr = GenerateStringWithSuggestOptions();
        NJson::TJsonValue jsonConfig;
        UNIT_ASSERT(ReadJsonFastTree(configStr, &jsonConfig));
        TMessagesCollector errors;
        TTaxiSupportChatNodeResolver nodeResolver;
        UNIT_ASSERT_C(nodeResolver.DeserializeFromJson(jsonConfig, errors), errors.GetStringReport());

        // One prediction element per configured class; only the probabilities change between cases.
        NDrive::TSupportPrediction::TElement topic1;
        topic1.Topic = "class_1";
        NDrive::TSupportPrediction::TElement topic2;
        topic2.Topic = "class_2";
        NDrive::TSupportPrediction::TElement topic3;
        topic3.Topic = "class_3";
        NDrive::TSupportPrediction::TElement topic4;
        topic4.Topic = "class_4";

        // Assigns the given probabilities to class_1..class_4 (in that order) and
        // returns a prediction holding the four elements in the same order.
        using TProbability = decltype(topic1.Probability);
        const auto withProbabilities = [&](TProbability p1, TProbability p2, TProbability p3, TProbability p4) {
            topic1.Probability = p1;
            topic2.Probability = p2;
            topic3.Probability = p3;
            topic4.Probability = p4;
            NDrive::TSupportPrediction prediction;
            prediction.Elements = { topic1, topic2, topic3, topic4 };
            return prediction;
        };

        { // sure topic, only overridden nodes: the four sure_schema entries of class_2 fill all slots
            TTaxiSupportChatSuggestClient::TSuggestResponse response;
            response.SetSureTopic("class_2");
            response.SetMostProbableTopic("class_2");
            NDrive::TSupportPrediction predictions = withProbabilities(0.30, 0.01, 0.60, 0.40);
            response.SetPredictions(predictions);

            auto options = nodeResolver.GetSuggestedChatOptions(response);
            UNIT_ASSERT(options.IsArray());
            UNIT_ASSERT_VALUES_EQUAL(options.GetArray().size(), 4);
            CheckSuggest(options[0], "node_2_opt_1", "go to node 2 option 1");
            CheckSuggest(options[1], "node_2_opt_2", "go to node 2 option 2");
            CheckSuggest(options[2], "node_2_opt_3", "go to node 2 option 3");
            CheckSuggest(options[3], "node_2_opt_4", "go to node 2 option 4");
        }
        { // sure topic, not enough overridden nodes, with append: remaining slots go to the most probable other classes
            TTaxiSupportChatSuggestClient::TSuggestResponse response;
            response.SetSureTopic("class_1");
            response.SetMostProbableTopic("class_1");
            NDrive::TSupportPrediction predictions = withProbabilities(0.65, 0.40, 0.60, 0.70);
            response.SetPredictions(predictions);

            auto options = nodeResolver.GetSuggestedChatOptions(response);
            UNIT_ASSERT(options.IsArray());
            UNIT_ASSERT_VALUES_EQUAL(options.GetArray().size(), 4);
            CheckSuggest(options[0], "node_1_opt_1", "go to node 1 option 1");
            CheckSuggest(options[1], "node_1_opt_2", "go to node 1 option 2");
            CheckSuggest(options[2], "node_c_4", "go to node 4");
            CheckSuggest(options[3], "node_c_3", "go to node 3");
        }
        { // sure topic, not enough overridden nodes, no append: only the two sure_schema entries are returned
            TTaxiSupportChatSuggestClient::TSuggestResponse response;
            response.SetSureTopic("class_3");
            response.SetMostProbableTopic("class_3");
            NDrive::TSupportPrediction predictions = withProbabilities(0.50, 0.40, 0.60, 0.70);
            response.SetPredictions(predictions);

            auto options = nodeResolver.GetSuggestedChatOptions(response);
            UNIT_ASSERT(options.IsArray());
            UNIT_ASSERT_VALUES_EQUAL(options.GetArray().size(), 2);
            CheckSuggest(options[0], "node_3_opt_1", "go to node 3 option 1");
            CheckSuggest(options[1], "node_3_opt_2", "go to node 3 option 2");
        }
        { // sure topic without override: its plain schema comes first, the rest follow by descending probability
            TTaxiSupportChatSuggestClient::TSuggestResponse response;
            response.SetSureTopic("class_4");
            response.SetMostProbableTopic("class_4");
            NDrive::TSupportPrediction predictions = withProbabilities(0.80, 0.40, 0.60, 0.20);
            response.SetPredictions(predictions);

            auto options = nodeResolver.GetSuggestedChatOptions(response);
            UNIT_ASSERT(options.IsArray());
            UNIT_ASSERT_VALUES_EQUAL(options.GetArray().size(), 4);
            CheckSuggest(options[0], "node_c_4", "go to node 4");
            CheckSuggest(options[1], "node_c_1", "go to node 1");
            CheckSuggest(options[2], "node_c_3", "go to node 3");
            CheckSuggest(options[3], "node_c_2", "go to node 2");
        }
        { // no sure topic, probabilities not in class order: options come out by descending probability
            TTaxiSupportChatSuggestClient::TSuggestResponse response;
            NDrive::TSupportPrediction predictions = withProbabilities(0.20, 0.05, 0.60, 0.50);
            response.SetPredictions(predictions);

            auto options = nodeResolver.GetSuggestedChatOptions(response);
            UNIT_ASSERT(options.IsArray());
            UNIT_ASSERT_VALUES_EQUAL(options.GetArray().size(), 4);
            CheckSuggest(options[0], "node_c_3", "go to node 3");
            CheckSuggest(options[1], "node_c_4", "go to node 4");
            CheckSuggest(options[2], "node_c_1", "go to node 1");
            CheckSuggest(options[3], "node_c_2", "go to node 2");
        }
        { // no sure topic, probabilities already descending in class order
            TTaxiSupportChatSuggestClient::TSuggestResponse response;
            NDrive::TSupportPrediction predictions = withProbabilities(0.80, 0.70, 0.60, 0.50);
            response.SetPredictions(predictions);

            auto options = nodeResolver.GetSuggestedChatOptions(response);
            UNIT_ASSERT(options.IsArray());
            UNIT_ASSERT_VALUES_EQUAL(options.GetArray().size(), 4);
            CheckSuggest(options[0], "node_c_1", "go to node 1");
            CheckSuggest(options[1], "node_c_2", "go to node 2");
            CheckSuggest(options[2], "node_c_3", "go to node 3");
            CheckSuggest(options[3], "node_c_4", "go to node 4");
        }
    }

    // POSTs a plaintext `message` to /api/yandex/chat/get_suggest for the given chat and user
    // (the user id is also sent as the Authorization header) and returns the reply parsed as JSON.
    // Fails the current test if the reply code is not 200 or the body cannot be read as JSON;
    // in both cases the raw reply body is attached to the failure.
    NJson::TJsonValue GetSuggest(const NDrive::TServerConfigGenerator& gServer, const TString& chatId, const TString& userId, const TString& message) {
        NJson::TJsonValue post;
        post["message"] = message;
        post["type"] = "plaintext";

        NNeh::THttpRequest request;
        request.SetUri("/api/yandex/chat/get_suggest")
            .SetCgiData("chat_id=" + chatId + "&user_id=" + userId)
            .SetPostData(TBlob::FromString(post.GetStringRobust()))
            .SetRequestType("POST");
        request.AddHeader("Authorization", userId);

        NUtil::THttpReply reply = gServer.GetSendReply(request);
        // Attach the body: a bare "N != 200" gives no hint about what the server objected to.
        UNIT_ASSERT_VALUES_EQUAL_C(reply.Code(), 200, reply.Content());

        NJson::TJsonValue resultReport = NJson::JSON_MAP;
        UNIT_ASSERT_C(NJson::ReadJsonFastTree(reply.Content(), &resultReport), reply.Content());
        return resultReport;
    }

    // End-to-end check of suggest options in a chat: the get_suggest handler must offer the
    // expected options, and the chat history must reflect picking an option by its message
    // text or sending an unrelated message instead.
    Y_UNIT_TEST(SuggestOptionsInChat) {
        NDrive::TServerConfigGenerator gServer;
        gServer.SetSensorApiName({});
        TServerConfigConstructorParams params(gServer.GetString().data());
        NDrive::TServerConfig config(params);
        NDrive::TServerGuard server(config);

        TEnvironmentGenerator eGenerator(*server);
        eGenerator.BuildEnvironment(TEnvironmentGenerator::DefaultTraits);

        auto userId = eGenerator.CreateUser("skulik-was-not-there", false, "active");

        // Returns a copy of the "messages" array currently reported for the given chat of the test user.
        const auto fetchMessages = [&](const TString& chatId) {
            auto messagesMap = gServer.GetChatMessages(userId, chatId);
            return messagesMap["messages"].GetArray();
        };

        { // the suggest handler offers two options
            UNIT_ASSERT(gServer.CommitChatAction(userId, "support_suggest.888", "!classify_suggest", {}));
            const auto suggest = GetSuggest(gServer, "support_suggest.888", userId, "it's time to classify");

            const auto& expectedAction = suggest["expected_action"];
            UNIT_ASSERT(expectedAction.IsMap());
            UNIT_ASSERT_VALUES_EQUAL(expectedAction["type"], "suggest");
            UNIT_ASSERT_VALUES_EQUAL(expectedAction["text"], "");

            const auto& schema = suggest["schema"];
            UNIT_ASSERT(schema.IsMap());
            const auto& options = schema["options"];
            UNIT_ASSERT(options.IsArray());
            UNIT_ASSERT_VALUES_EQUAL(options.GetArray().size(), 2);
            CheckSuggest(options[0], "node_class_1", "Option 1");
            CheckSuggest(options[1], "node_class_2", "");
        }
        { // picking the first option: "Option 1" and then "node_class_1" follow "node suggest" in the history
            UNIT_ASSERT(gServer.CommitChatAction(userId, "support_suggest.888", "node_class_1", {}));
            auto messages = fetchMessages("support_suggest.888");
            UNIT_ASSERT_VALUES_EQUAL(messages.size(), 5);
            UNIT_ASSERT_VALUES_EQUAL(messages[2]["text"].GetString(), "node suggest");
            UNIT_ASSERT_VALUES_EQUAL(messages[3]["text"].GetString(), "Option 1");
            UNIT_ASSERT_VALUES_EQUAL(messages[4]["text"].GetString(), "node_class_1");
        }
        { // picking the second option (offered with an empty text): only "node_class_2" follows "node suggest"
            UNIT_ASSERT(gServer.CommitChatAction(userId, "support_suggest.999", "!classify_suggest", {}));
            UNIT_ASSERT(gServer.CommitChatAction(userId, "support_suggest.999", "node_class_2", {}));
            auto messages = fetchMessages("support_suggest.999");
            UNIT_ASSERT_VALUES_EQUAL(messages.size(), 4);
            UNIT_ASSERT_VALUES_EQUAL(messages[2]["text"].GetString(), "node suggest");
            UNIT_ASSERT_VALUES_EQUAL(messages[3]["text"].GetString(), "node_class_2");
        }
        { // a message that matches no option is kept as is and is followed by "suggest message"
            UNIT_ASSERT(gServer.CommitChatAction(userId, "support_suggest.101010", "!classify_suggest", {}));
            UNIT_ASSERT(gServer.CommitChatAction(userId, "support_suggest.101010", "just a message", {}));
            auto messages = fetchMessages("support_suggest.101010");
            UNIT_ASSERT_VALUES_EQUAL(messages.size(), 5);
            UNIT_ASSERT_VALUES_EQUAL(messages[2]["text"].GetString(), "node suggest");
            UNIT_ASSERT_VALUES_EQUAL(messages[3]["text"].GetString(), "just a message");
            UNIT_ASSERT_VALUES_EQUAL(messages[4]["text"].GetString(), "suggest message");
        }
    }
}
