#include "switch_node.h"
#include "switch_node_ut.h"

#include <infra/pod_agent/libs/behaviour/bt/nodes/base/mock_node.h>
#include <infra/pod_agent/libs/behaviour/bt/nodes/base/test/mock_tick_context.h>

#include <infra/libs/logger/logger.h>

#include <library/cpp/testing/unittest/registar.h>
#include <library/cpp/testing/unittest/tests_data.h>

namespace NInfra::NPodAgent::NTestSwitchNode {

using namespace NTest;

Y_UNIT_TEST_SUITE(SwitchNodeSuite) {

// Suite-wide logger; every test wraps it in MockTickContext(logger) before ticking a node.
// Constructed from an empty brace-initializer argument (exact meaning depends on TLogger's ctor).
static TLogger logger({});

/// Creates a leaf TMockNode configured with `tickResult`
/// (presumably returned from every Tick — see mock_node.h).
TBasicTreeNodePtr NodeWithResult(TTickResult tickResult) {
    TBasicTreeNodePtr mock(new TMockNode({1, "title"}, tickResult));
    return mock;
}

/// Builds a Case node labelled `enumValue` whose only child is a mock leaf
/// configured with `tickResult`.
TBasicTreeNodePtr SingleChildCase(TTickResult tickResult, const TString& enumValue) {
    TCaseNodePtr caseNode(new TCaseNode({1, "title"}, enumValue));
    TVector<TBasicTreeNodePtr> caseChildren = {NodeWithResult(tickResult)};
    caseNode->SetChildren(std::move(caseChildren));
    return caseNode;
}

/// Creates a TSwitchNode<TEnum> and hands it `children` (condition first, then Case nodes).
/// The child list is validated during SetChildren; the tests below expect yexception
/// from this call for malformed layouts.
template <class TEnum>
TSwitchNodePtr<TEnum> GetSwitch(TVector<TBasicTreeNodePtr>&& children) {
    TSwitchNodePtr<TEnum> switchNode(new TSwitchNode<TEnum>({1, "title"}));
    switchNode->SetChildren(std::move(children));
    return switchNode;
}

/// Scripted test leaf: every TickImpl call bumps CallCount_ and returns whatever
/// `Runner_` produces for the new, 1-based call number. Reports ENodeType::UNKNOWN.
template <class TTickRunner>
class TFailingNode : public TBasicLeaf {
public:
    TFailingNode(const TBasicTreeNodeDescriptor& descriptor, const TTickRunner& runner)
        : TBasicLeaf(descriptor)
        , Runner_(runner)
    {
    }

    ENodeType GetType() const final {
        return ENodeType::UNKNOWN;
    }

private:
    TTickResult TickImpl(TTickContextPtr) final {
        // Pre-increment: the runner sees 1 on the first tick, 2 on the second, ...
        return Runner_(++CallCount_);
    }

public:
    // Number of TickImpl invocations so far.
    size_t CallCount_{0};
    // Callable mapping the call number to the tick result.
    TTickRunner Runner_;
};

Y_UNIT_TEST(TestEverythingSuccessful) {
    // The condition reports TE_FAIL, so the Switch result must come from the TE_FAIL case.
    TTickResult condition = TNodeSuccess{ENodeStatus::SUCCESS, ToString(TE_FAIL)};

    TVector<TBasicTreeNodePtr> children;
    children.push_back(NodeWithResult(condition));
    children.push_back(SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "ok"}, ToString(TE_OK)));
    children.push_back(SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "fail"}, ToString(TE_FAIL)));
    children.push_back(SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "semi_fail"}, ToString(TE_SEMI_FAIL)));

    TSwitchNodePtr<ETestEnum> node;
    UNIT_ASSERT_NO_EXCEPTION(node = GetSwitch<ETestEnum>(std::move(children)));

    auto result = node->Tick(MockTickContext(logger));
    UNIT_ASSERT_EQUAL((TTickResult(TNodeSuccess{ENodeStatus::SUCCESS, "fail"})), result);
}

Y_UNIT_TEST(TestWrongCondition) {
    // The first child occupies the condition slot; a Case node there must be rejected.
    // (The unused `condition` local from the original version was dropped.)
    TVector<TBasicTreeNodePtr> children = {
        SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "ok"}, "")
        , SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "ok"}, ToString(TE_OK))
        , SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "fail"}, ToString(TE_FAIL))
        , SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "semi_fail"}, ToString(TE_SEMI_FAIL))
    };
    UNIT_ASSERT_EXCEPTION_CONTAINS(GetSwitch<ETestEnum>(std::move(children)), yexception, "condition child node can't have type TCaseNode");
}

Y_UNIT_TEST(TestWrongCaseChildren) {
    // Every child after the condition must be a Case node; here the third child is a plain leaf.
    TTickResult condition = TNodeSuccess{ENodeStatus::SUCCESS, ToString(TE_FAIL)};

    TBasicTreeNodePtr conditionLeaf = NodeWithResult(condition);
    TBasicTreeNodePtr strayLeaf = NodeWithResult(condition);

    TVector<TBasicTreeNodePtr> children = {
        conditionLeaf,
        SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "ok"}, ToString(TE_OK)),
        strayLeaf,
        SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "semi_fail"}, ToString(TE_SEMI_FAIL)),
    };
    UNIT_ASSERT_EXCEPTION_CONTAINS(GetSwitch<ETestEnum>(std::move(children)), yexception, "all children of Switch node must have type TCaseNode");
}

Y_UNIT_TEST(TestWrongCaseValue) {
    // "wrong_value" is presumably not the string form of any ETestEnum value, so building
    // the Switch must throw. Only the exception type is checked, not its message.
    TTickResult condition = TNodeSuccess{ENodeStatus::SUCCESS, ToString(TE_FAIL)};

    TVector<TBasicTreeNodePtr> children;
    children.push_back(NodeWithResult(condition));
    children.push_back(SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "ok"}, ToString(TE_OK)));
    children.push_back(SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "fail"}, "wrong_value"));
    children.push_back(SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "semi_fail"}, ToString(TE_SEMI_FAIL)));

    UNIT_ASSERT_EXCEPTION(GetSwitch<ETestEnum>(std::move(children)), yexception);
}

Y_UNIT_TEST(TestSameCases) {
    // Two Case nodes carry the same TE_OK label; the error must name the duplicated value.
    const TString duplicatedValue = ToString(TE_OK);
    TTickResult condition = TNodeSuccess{ENodeStatus::SUCCESS, ToString(TE_FAIL)};

    TVector<TBasicTreeNodePtr> children = {
        NodeWithResult(condition),
        SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "ok"}, duplicatedValue),
        SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "ok"}, duplicatedValue),
        SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "semi_fail"}, ToString(TE_SEMI_FAIL)),
    };
    UNIT_ASSERT_EXCEPTION_CONTAINS(GetSwitch<ETestEnum>(std::move(children)), yexception, "children array for Switch node has two Case nodes with same value: " + ToString(TE_OK));
}

Y_UNIT_TEST(TestConditionReturn) {
    // Builds the standard TE_OK / TE_FAIL / TE_SEMI_FAIL switch around a condition leaf
    // that is configured with `condition`.
    auto makeSwitch = [](const TTickResult& condition) {
        TVector<TBasicTreeNodePtr> children = {
            NodeWithResult(condition)
            , SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "ok"}, ToString(TE_OK))
            , SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "fail"}, ToString(TE_FAIL))
            , SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "semi_fail"}, ToString(TE_SEMI_FAIL))
        };
        return GetSwitch<ETestEnum>(std::move(children));
    };

    {
        // SUCCESS with an empty reason: the switch reports a cast failure.
        auto node = makeSwitch(TNodeSuccess{ENodeStatus::SUCCESS, ""});
        auto result = node->Tick(MockTickContext(logger));
        UNIT_ASSERT_STRING_CONTAINS(result.Error().Message, "switch child condition cast failed");
    }
    {
        // FAILURE from the condition is reported as an unexpected status.
        auto node = makeSwitch(TNodeSuccess{ENodeStatus::FAILURE, ""});
        auto result = node->Tick(MockTickContext(logger));
        UNIT_ASSERT_STRING_CONTAINS(result.Error().Message, "Unexpected node status for condition node");
    }

    // Errors and RUNNING from the condition are propagated unchanged.
    TVector<TTickResult> testConditions = {
        TNodeError{"some_error"},
        TNodeSuccess{ENodeStatus::RUNNING, "reason"}
    };
    for (const auto& condition : testConditions) {
        auto node = makeSwitch(condition);
        auto result = node->Tick(MockTickContext(logger));
        UNIT_ASSERT_EQUAL(result, condition);
    }
}

Y_UNIT_TEST(TestChildReturn) {
    // The selected (TE_FAIL) case's result — error, FAILURE or RUNNING — is propagated unchanged.
    const TTickResult condition = TNodeSuccess{ENodeStatus::SUCCESS, ToString(TE_FAIL)};
    const TVector<TTickResult> testResults = {
        TNodeError{"some_error"},
        TNodeSuccess{ENodeStatus::FAILURE, "reason"},
        TNodeSuccess{ENodeStatus::RUNNING, "reason"}
    };

    for (const TTickResult& tick : testResults) {
        TVector<TBasicTreeNodePtr> children;
        children.push_back(NodeWithResult(condition));
        children.push_back(SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "ok"}, ToString(TE_OK)));
        children.push_back(SingleChildCase(tick, ToString(TE_FAIL)));
        children.push_back(SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "semi_fail"}, ToString(TE_SEMI_FAIL)));

        auto node = GetSwitch<ETestEnum>(std::move(children));
        auto result = node->Tick(MockTickContext(logger));
        UNIT_ASSERT_EQUAL(result, tick);
    }
}

Y_UNIT_TEST(TestClearRunningOnError) {
    // Scenario, as pinned by the expected results below:
    //   tick 1: condition (call 1) -> TE_OK; the TE_OK case child (call 1) -> RUNNING "ok running".
    //   tick 2: the running TE_OK child is resumed without re-ticking the condition and errors.
    //   tick 3: the error cleared the running state, so the condition is ticked again (call 2)
    //           and its RUNNING result is returned.
    // Any further call to either scripted node fails the test.
    // (The unused `condition` local from the original version was dropped.)
    TBasicTreeNodePtr conditionNode = new TFailingNode({1, "title"}, [](size_t callCount) -> TTickResult {
        if (callCount == 1) {
            return TNodeSuccess{ENodeStatus::SUCCESS, ToString(TE_OK)};
        }
        if (callCount == 2) {
            return TNodeSuccess{ENodeStatus::RUNNING, "condition running"};
        }
        UNIT_ASSERT_C(callCount <= 2, "unexpected call count");
        return TNodeError{""};
    });
    TCaseNodePtr okCase = new TCaseNode({1, "title"}, ToString(TE_OK));
    okCase->SetChildren({TBasicTreeNodePtr(new TFailingNode({2, "title"}, [](size_t callCount) -> TTickResult {
        if (callCount == 1) {
            return TNodeSuccess{ENodeStatus::RUNNING, "ok running"};
        }
        if (callCount == 2) {
            return TNodeError{"failure"};
        }
        UNIT_ASSERT_C(callCount <= 2, "unexpected call count");
        return TNodeError{""};
    }))});

    TVector<TBasicTreeNodePtr> children = {
        conditionNode
        , okCase
        , SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "fail"}, ToString(TE_FAIL))
        , SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "semi_fail"}, ToString(TE_SEMI_FAIL))
    };
    auto node = GetSwitch<ETestEnum>(std::move(children));
    auto result = node->Tick(MockTickContext(logger));
    UNIT_ASSERT_EQUAL(result, TNodeSuccess(ENodeStatus::RUNNING, "ok running"));
    result = node->Tick(MockTickContext(logger));
    UNIT_ASSERT_EQUAL(result, TNodeError{"failure"});
    result = node->Tick(MockTickContext(logger));
    UNIT_ASSERT_EQUAL(result, TNodeSuccess(ENodeStatus::RUNNING, "condition running"));
}

Y_UNIT_TEST(TestConditionReturnEnum) {
    // With EConditionReturn the condition's own status picks the case: SUCCESS -> CR_SUCCESS,
    // FAILURE -> CR_FAILURE. Its empty reason string does not prevent the selection.
    auto buildChildren = [](ENodeStatus conditionStatus) {
        TVector<TBasicTreeNodePtr> children;
        children.push_back(NodeWithResult(TNodeSuccess{conditionStatus, ""}));
        children.push_back(SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "ok"}, ToString(CR_SUCCESS)));
        children.push_back(SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "fail"}, ToString(CR_FAILURE)));
        return children;
    };

    {
        TVector<TBasicTreeNodePtr> children = buildChildren(ENodeStatus::SUCCESS);
        TSwitchNodePtr<EConditionReturn> node;
        UNIT_ASSERT_NO_EXCEPTION(node = GetSwitch<EConditionReturn>(std::move(children)));

        auto result = node->Tick(MockTickContext(logger));
        UNIT_ASSERT_EQUAL((TTickResult(TNodeSuccess{ENodeStatus::SUCCESS, "ok"})), result);
    }
    {
        TVector<TBasicTreeNodePtr> children = buildChildren(ENodeStatus::FAILURE);
        TSwitchNodePtr<EConditionReturn> node;
        UNIT_ASSERT_NO_EXCEPTION(node = GetSwitch<EConditionReturn>(std::move(children)));

        auto result = node->Tick(MockTickContext(logger));
        UNIT_ASSERT_EQUAL((TTickResult(TNodeSuccess{ENodeStatus::SUCCESS, "fail"})), result);
    }
}

Y_UNIT_TEST(TestDefaultInEnum) {
    // EWithDefaultField presumably declares a value spelled like TCaseNode::DEFAULT_CASE_VALUE;
    // a Switch over such an enum must be rejected at construction.
    TTickResult condition = TNodeSuccess{ENodeStatus::SUCCESS, ""};

    TVector<TBasicTreeNodePtr> children;
    children.push_back(NodeWithResult(condition));
    children.push_back(SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "ok"}, ToString(EWithDefaultField::OK)));
    children.push_back(SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "fail"}, ToString(EWithDefaultField::FAIL)));

    UNIT_ASSERT_EXCEPTION_CONTAINS(GetSwitch<EWithDefaultField>(std::move(children)), yexception, TStringBuilder() << " has '" << TCaseNode::DEFAULT_CASE_VALUE << "' value");
}

Y_UNIT_TEST(TestMoreThanOneDefault) {
    // Two Case nodes labelled with the default-case marker must be rejected.
    TTickResult condition = TNodeSuccess{ENodeStatus::FAILURE, ""};

    // The default marker comes from TCaseNode::DEFAULT_CASE_VALUE rather than a hard-coded
    // "default" literal, so the test keeps exercising the real marker if it ever changes.
    TVector<TBasicTreeNodePtr> children = {
        NodeWithResult(condition)
        , SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "ok"}, ToString(CR_SUCCESS))
        , SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "fail"}, ToString(CR_FAILURE))
        , SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "ok"}, TCaseNode::DEFAULT_CASE_VALUE)
        , SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "fail"}, TCaseNode::DEFAULT_CASE_VALUE)
    };
    UNIT_ASSERT_EXCEPTION_CONTAINS(GetSwitch<EConditionReturn>(std::move(children)), yexception, "more than 1 default cases");
}

Y_UNIT_TEST(TestTooFewChildren) {
    // A Switch built from zero or one child must be rejected, and the message reports the count.
    TTickResult condition = TNodeSuccess{ENodeStatus::FAILURE, ""};

    {
        TVector<TBasicTreeNodePtr> children;
        UNIT_ASSERT_EXCEPTION_CONTAINS(GetSwitch<EConditionReturn>(std::move(children)), yexception, "children count is 0");
    }
    {
        // Fresh vector: the previous one was moved from, and a moved-from vector is only in a
        // valid-but-unspecified state, so appending to it would not guarantee exactly one child.
        TVector<TBasicTreeNodePtr> children = {NodeWithResult(condition)};
        UNIT_ASSERT_EXCEPTION_CONTAINS(GetSwitch<EConditionReturn>(std::move(children)), yexception, "children count is 1");
    }
}

Y_UNIT_TEST(TestEverythingSuccessfulWithDefaultCase) {
    // No Case is labelled TE_FAIL, so the condition's TE_FAIL must be served by the default case.
    TTickResult condition = TNodeSuccess{ENodeStatus::SUCCESS, ToString(TE_FAIL)};

    TVector<TBasicTreeNodePtr> children;
    children.push_back(NodeWithResult(condition));
    children.push_back(SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "ok"}, ToString(TE_OK)));
    children.push_back(SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "fail"}, TCaseNode::DEFAULT_CASE_VALUE));
    children.push_back(SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "semi_fail"}, ToString(TE_SEMI_FAIL)));

    TSwitchNodePtr<ETestEnum> node;
    UNIT_ASSERT_NO_EXCEPTION(node = GetSwitch<ETestEnum>(std::move(children)));

    auto result = node->Tick(MockTickContext(logger));
    UNIT_ASSERT_EQUAL((TTickResult(TNodeSuccess{ENodeStatus::SUCCESS, "fail"})), result);
}

Y_UNIT_TEST(TestCasesLessThanEnumsWithoutDefault) {
    // Only TE_OK and TE_SEMI_FAIL have Case nodes and there is no default case, so the
    // case count does not match the enum and construction must fail.
    TTickResult condition = TNodeSuccess{ENodeStatus::SUCCESS, ToString(TE_FAIL)};

    TVector<TBasicTreeNodePtr> children;
    children.push_back(NodeWithResult(condition));
    children.push_back(SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "ok"}, ToString(TE_OK)));
    children.push_back(SingleChildCase(TNodeSuccess{ENodeStatus::SUCCESS, "semi_fail"}, ToString(TE_SEMI_FAIL)));

    UNIT_ASSERT_EXCEPTION_CONTAINS(GetSwitch<ETestEnum>(std::move(children)), yexception, "is unequal to passed children count");
}

}

} // namespace NInfra::NPodAgent::NTestSwitchNode
