#include "mtp_period_ticker.h"

#include <infra/pod_agent/libs/behaviour/bt/nodes/base/basic_leaf.h>
#include <infra/pod_agent/libs/behaviour/bt/template_tree/template_tree.h>
#include <infra/pod_agent/libs/behaviour/loaders/behavior3_editor_json_reader.h>
#include <infra/pod_agent/libs/behaviour/loaders/proto/behavior3.pb.h>
#include <infra/pod_agent/libs/pod_agent/object_meta/test_lib/test_functions.h>

#include <library/cpp/testing/unittest/env.h>
#include <library/cpp/testing/unittest/registar.h>

#include <util/stream/file.h>

namespace NInfra::NPodAgent::NTestMtpPeriodTicker {

// Logger shared by every test in this file: it spawns the log frames handed to the
// tickers and is passed to every TTree the tests create.
static TLogger logger({});

Y_UNIT_TEST_SUITE(TMtpPeriodTickerSuite) {

Y_UNIT_TEST(TestSimpleTree) {
    // Registers one tree built from the test template, checks the tree-id bookkeeping,
    // lets the started ticker run for a while and then checks the layer state after removal.
    const TString treeId = "treeId";
    const TString unknownTreeId = treeId + "_2";

    TBehavior3 templateTree = CreateTestTemplateBehaviorTree();

    TMtpPeriodTicker ticker(TBehaviorTickerConfig{}, logger.SpawnFrame(), logger.SpawnFrame());
    ticker.Start();

    TLayerStatusRepositoryPtr statusRepository = new TLayerStatusRepository();
    statusRepository->AddObject(NObjectMetaTestLib::CreateLayerMetaSimple("my_layer", "my_layer_download_hash"));

    ticker.AddTree(
        new TTree(
            logger,
            treeId,
            TBehavior3EditorJsonReader(templateTree).WithLayerStatusRepository(statusRepository).BuildRootNode()
        ),
        "hash"
    );

    // Only the id that was just added is known to the ticker.
    UNIT_ASSERT_EQUAL(true, ticker.HasTree(treeId));
    UNIT_ASSERT_EQUAL(false, ticker.HasTree(unknownTreeId));
    UNIT_ASSERT_EQUAL(TVector<TString>({treeId}), ticker.GetTreeIds());

    // Leave the running ticker time to tick the tree.
    Sleep(TDuration::MilliSeconds(300));

    ticker.RemoveTree(treeId);
    UNIT_ASSERT_EQUAL(false, ticker.HasTree(treeId));
    UNIT_ASSERT_EQUAL(API::ELayerState_DOWNLOADING, statusRepository->GetObjectStatus("my_layer").state());

    ticker.Stop();
}

Y_UNIT_TEST(TestThrowsOnNonExist) {
    // Re-adding a tree id with a different hash and removing an unknown id must both throw,
    // and neither failed call may disturb the set of registered trees.
    TMtpPeriodTicker ticker(TBehaviorTickerConfig{}, logger.SpawnFrame(), logger.SpawnFrame());

    const TString treeId = "treeId";
    const TString missingTreeId = treeId + "_2";

    // Checks that exactly one tree, `treeId`, is registered.
    const auto assertOnlyTreeIdRegistered = [&]() {
        UNIT_ASSERT_EQUAL(true, ticker.HasTree(treeId));
        UNIT_ASSERT_EQUAL(false, ticker.HasTree(missingTreeId));
        UNIT_ASSERT_EQUAL(TVector<TString>({treeId}), ticker.GetTreeIds());
    };

    ticker.AddTree(new TTree(logger, treeId, nullptr), "hash");
    assertOnlyTreeIdRegistered();

    UNIT_ASSERT_EXCEPTION_CONTAINS(ticker.AddTree(new TTree(logger, treeId, nullptr), "different_hash"), yexception, "different hash");
    UNIT_ASSERT_EXCEPTION_CONTAINS(ticker.RemoveTree(missingTreeId), yexception, "doesn't have tree with id");

    // The rejected calls above left the registration untouched.
    assertOnlyTreeIdRegistered();

    ticker.RemoveTree(treeId);
    UNIT_ASSERT_EQUAL(false, ticker.HasTree(treeId));

    ticker.Wait();
}

Y_UNIT_TEST(TestTreeWithMultipleNumberOfOccurrences) {
    // The same tree id may be added repeatedly with the same hash; the test expects the id
    // to stay registered until RemoveTree has been called once per AddTree.
    constexpr size_t numberOfOccurrences = 10;
    const TString treeId = "treeId";

    TMtpPeriodTicker ticker(TBehaviorTickerConfig{}, logger.SpawnFrame(), logger.SpawnFrame());

    for (size_t added = 1; added <= numberOfOccurrences; ++added) {
        ticker.AddTree(new TTree(logger, treeId, nullptr), "hash");
        UNIT_ASSERT_EQUAL(true, ticker.HasTree(treeId));
    }

    for (size_t remaining = numberOfOccurrences; remaining > 0; --remaining) {
        // Still present before each of the removals.
        UNIT_ASSERT_EQUAL(true, ticker.HasTree(treeId));
        ticker.RemoveTree(treeId);
    }

    // Gone only after the last matching removal.
    UNIT_ASSERT_EQUAL(false, ticker.HasTree(treeId));

    ticker.Wait();
}

Y_UNIT_TEST(TestRestartTicker) {
    // A ticker that is stopped and started again keeps ticking its trees: after the layer
    // is reset to UNKNOWN while stopped, the restarted ticker brings it back to DOWNLOADING.
    TBehavior3 templateTree = CreateTestTemplateBehaviorTree();

    TMtpPeriodTicker ticker(TBehaviorTickerConfig{}, logger.SpawnFrame(), logger.SpawnFrame());
    ticker.Start();

    TLayerStatusRepositoryPtr statusRepository = new TLayerStatusRepository();
    statusRepository->AddObject(NObjectMetaTestLib::CreateLayerMetaSimple("my_layer", "my_layer_download_hash"));

    const TString treeId = "treeId";
    ticker.AddTree(new TTree(logger, treeId, TBehavior3EditorJsonReader(templateTree).WithLayerStatusRepository(statusRepository).BuildRootNode()), "hash");

    Sleep(TDuration::MilliSeconds(300));
    UNIT_ASSERT_EQUAL(API::ELayerState_DOWNLOADING, statusRepository->GetObjectStatus("my_layer").state());
    ticker.Stop();

    // Reset the layer while the ticker is stopped so the second run has to change it again.
    statusRepository->UpdateObjectState("my_layer_download_hash", API::ELayerState_UNKNOWN);

    ticker.Start();
    Sleep(TDuration::MilliSeconds(300));
    UNIT_ASSERT_EQUAL(API::ELayerState_DOWNLOADING, statusRepository->GetObjectStatus("my_layer").state());

    // Shut the restarted ticker down explicitly, like every other test that starts one,
    // instead of leaving it running until the destructor.
    ticker.Stop();
}

Y_UNIT_TEST(TestDoubleStartException) {
    // Starting an already started ticker must be rejected with a yexception.
    TMtpPeriodTicker ticker(TBehaviorTickerConfig{}, logger.SpawnFrame(), logger.SpawnFrame());
    const auto start = [&ticker] {
        ticker.Start();
    };

    start();
    UNIT_ASSERT_EXCEPTION_CONTAINS(start(), yexception, "TMtpPeriodTicker already started");

    ticker.Stop();
}

Y_UNIT_TEST(TestWaitRunningTrees) {
    // Stop() is expected not to return while a tree still reports RUNNING: the mock needs
    // three ticks to finish, and the test expects all three to have happened once Stop() returns.

    // Reports RUNNING on its first two ticks and SUCCESS from the third tick on.
    class TMyMockNode : public TBasicLeaf {
    public:
        TMyMockNode() : TBasicLeaf({1, "title"}) {}

        ENodeType GetType() const final {
            return ENodeType::MOCK;
        }

        TTickResult TickImpl(TTickContextPtr) final {
            // The ticker ticks this node in the background while the test thread reads
            // Calls_, so the counter must be atomic (a plain size_t here is a data race).
            return TNodeSuccess(AtomicIncrement(Calls_) < 3 ? ENodeStatus::RUNNING : ENodeStatus::SUCCESS);
        }

    public:
        TAtomic Calls_ = 0;
    };

    TMtpPeriodTicker ticker(TBehaviorTickerConfig{}, logger.SpawnFrame(), logger.SpawnFrame());
    ticker.Start();

    // `node` owns the mock; `mockNode` is a typed, non-owning view of it (no downcast needed).
    TMyMockNode* mockNode = new TMyMockNode();
    TBasicTreeNodePtr node = mockNode;
    TTreePtr tree = new TTree(logger, "treeId", node);
    ticker.AddTree(tree, "hash");

    Sleep(TDuration::MilliSeconds(300));
    // Read once so the failure message reports exactly the value that was checked.
    const TAtomicBase callsWhileRunning = AtomicGet(mockNode->Calls_);
    UNIT_ASSERT_C(callsWhileRunning < 3, callsWhileRunning);

    ticker.Stop();
    const TAtomicBase callsAfterStop = AtomicGet(mockNode->Calls_);
    UNIT_ASSERT_EQUAL_C(callsAfterStop, 3, callsAfterStop);
}

Y_UNIT_TEST(TestWaitRunningTreeOnRemove) {
    // RemoveTree() is expected not to return while the removed tree still reports RUNNING:
    // the mock needs three ticks to finish, and the test expects all three once RemoveTree() returns.

    // Reports RUNNING on its first two ticks and SUCCESS from the third tick on.
    class TMyMockNode : public TBasicLeaf {
    public:
        TMyMockNode() : TBasicLeaf({1, "title"}) {}

        ENodeType GetType() const final {
            return ENodeType::MOCK;
        }

        TTickResult TickImpl(TTickContextPtr) final {
            // The ticker ticks this node in the background while the test thread reads
            // Calls_, so the counter must be atomic (a plain size_t here is a data race).
            return TNodeSuccess(AtomicIncrement(Calls_) < 3 ? ENodeStatus::RUNNING : ENodeStatus::SUCCESS);
        }

    public:
        TAtomic Calls_ = 0;
    };

    TMtpPeriodTicker ticker(TBehaviorTickerConfig{}, logger.SpawnFrame(), logger.SpawnFrame());
    ticker.Start();

    // `node` owns the mock; `mockNode` is a typed, non-owning view of it (no downcast needed).
    TMyMockNode* mockNode = new TMyMockNode();
    TBasicTreeNodePtr node = mockNode;
    TTreePtr tree = new TTree(logger, "treeId", node);
    ticker.AddTree(tree, "hash");

    Sleep(TDuration::MilliSeconds(300));
    // Read once so the failure message reports exactly the value that was checked.
    const TAtomicBase callsWhileRunning = AtomicGet(mockNode->Calls_);
    UNIT_ASSERT_C(callsWhileRunning < 3, callsWhileRunning);

    ticker.RemoveTree("treeId");
    const TAtomicBase callsAfterRemove = AtomicGet(mockNode->Calls_);
    UNIT_ASSERT_EQUAL_C(callsAfterRemove, 3, callsAfterRemove);

    ticker.Stop();
}

Y_UNIT_TEST(TestFailingTree) {
    // A tree that first reports FAILURE and then only returns errors keeps being ticked:
    // after half a second of running, at least one of the error-returning ticks happened.

    // First tick: FAILURE; every later tick: TNodeError, counted in Errors_.
    class TMyMockNode : public TBasicLeaf {
    public:
        TMyMockNode() : TBasicLeaf({1, "title"}) {}

        ENodeType GetType() const final {
            return ENodeType::MOCK;
        }

        TTickResult TickImpl(TTickContextPtr) final {
            // Written by the ticker and read by the test thread, hence atomic, matching the
            // counters in the compacting tests rather than relying on Stop() to synchronize.
            if (AtomicIncrement(Calls_) == 1) {
                return TNodeSuccess(ENodeStatus::FAILURE);
            }
            AtomicIncrement(Errors_);
            return TNodeError{"Error"};
        }

    public:
        TAtomic Calls_ = 0;
        TAtomic Errors_ = 0;
    };

    TMtpPeriodTicker ticker(TBehaviorTickerConfig{}, logger.SpawnFrame(), logger.SpawnFrame());
    ticker.Start();

    // `node` owns the mock; `mockNode` is a typed, non-owning view of it (no downcast needed).
    TMyMockNode* mockNode = new TMyMockNode();
    TBasicTreeNodePtr node = mockNode;
    TTreePtr tree = new TTree(logger, "treeId", node);
    ticker.AddTree(tree, "hash");

    Sleep(TDuration::MilliSeconds(500));
    ticker.Stop();

    const TAtomicBase calls = AtomicGet(mockNode->Calls_);
    const TAtomicBase errors = AtomicGet(mockNode->Errors_);
    UNIT_ASSERT_C(calls > 0, calls);
    UNIT_ASSERT_C(errors > 0, errors);
}

Y_UNIT_TEST(TestErrorCompacting) {
    class TMyMockNode : public TBasicLeaf {
    public:
        TMyMockNode() : TBasicLeaf({1, "title"}) {}

        ENodeType GetType() const final {
            return ENodeType::MOCK;
        }

        TTickResult TickImpl(TTickContextPtr) final {
            AtomicIncrement(Calls_);
            if (AtomicGet(Calls_) == 1 || AtomicGet(Calls_) > 20) {
                return TNodeSuccess(ENodeStatus::SUCCESS);
            }
            return TNodeError{"Error"};
        }

    public:
        TAtomic Calls_ = 0;
    };

    class TMyMockBackend: public TLogBackend {
    public:
        TMyMockBackend(TAtomic* countOfTBehaviourTreeTickError, TAtomic* countOfTBehaviourTreeTickV2, TAtomic* countOther)
            : CountOfTBehaviourTreeTickError_(countOfTBehaviourTreeTickError)
            , CountOfTBehaviourTreeTickV2_(countOfTBehaviourTreeTickV2)
            , CountOther_(countOther)
        {}

        void WriteData(const TLogRecord& logRecord) final {
            TStringBuf logStr = TStringBuf{logRecord.Data, logRecord.Len};
            if (logStr.find("TBehaviourTreeTickError") != TStringBuf::npos) {
                AtomicIncrement(*CountOfTBehaviourTreeTickError_);
            } else if (logStr.find("TBehaviourTreeTickV2") != TStringBuf::npos) {
                AtomicIncrement(*CountOfTBehaviourTreeTickV2_);
            } else {
                AtomicIncrement(*CountOther_);
            }
        }

        void ReopenLog() final {
        }

    public:
        TAtomic* CountOfTBehaviourTreeTickError_;
        TAtomic* CountOfTBehaviourTreeTickV2_;
        TAtomic* CountOther_;
    };

    const size_t tryIterations = 10;
    TAtomic countOfTBehaviourTreeTickError = 0;
    TAtomic countOfTBehaviourTreeTickV2 = 0;
    TAtomic countOther = 0;
    TAutoPtr<TLogBackend> backend(new TMyMockBackend{&countOfTBehaviourTreeTickError, &countOfTBehaviourTreeTickV2, &countOther});
    TBehaviorTickerConfig config;
    config.SetTreeTickPeriodMs(5);
    TMtpPeriodTicker ticker(config, MakeAtomicShared<TLogFrame>(MakeAtomicShared<TLog>(std::move(backend)), TAtomicSharedPtr<NErrorBoosterClient::TErrorBoosterClient>{}, 0, MakeAtomicShared<TAtomic>(ELogPriority::TLOG_CRIT), TLoggerConfig::JSON), logger.SpawnFrame());
    ticker.Start();

    TBasicTreeNodePtr node = new TMyMockNode();
    TTreePtr tree = new TTree(logger, "treeId", node);
    ticker.AddTree(tree, "hash");

    for (size_t i = 0; i < tryIterations; ++i) {
        Sleep(TDuration::MilliSeconds(500));
        if (AtomicGet(((TMyMockNode*)node.Get())->Calls_) > 50) {
            break;
        }
    }
    ticker.Stop();
    UNIT_ASSERT_C(AtomicGet(((TMyMockNode*)node.Get())->Calls_) > 50, AtomicGet(((TMyMockNode*)node.Get())->Calls_));
    UNIT_ASSERT_EQUAL_C(AtomicGet(countOfTBehaviourTreeTickError), 10, AtomicGet(countOfTBehaviourTreeTickError));
    UNIT_ASSERT_EQUAL_C(AtomicGet(countOfTBehaviourTreeTickV2), 10, AtomicGet(countOfTBehaviourTreeTickV2));
    UNIT_ASSERT_EQUAL_C(AtomicGet(countOther), 0, AtomicGet(countOther));
}

Y_UNIT_TEST(TestVariousErrorCompacting) {
    class TMyMockNode : public TBasicLeaf {
    public:
        TMyMockNode() : TBasicLeaf({1, "title"}) {}

        ENodeType GetType() const final {
            return ENodeType::MOCK;
        }

        TTickResult TickImpl(TTickContextPtr) final {
            AtomicIncrement(Calls_);
            if (AtomicGet(Calls_) < 20) {
                return TNodeError{"Error"};
            }
            if (AtomicGet(Calls_) < 40) {
                return TNodeError{"OtherError"};
            }
            if (AtomicGet(Calls_) < 60) {
                return TNodeSuccess(ENodeStatus::FAILURE);
            }
            if (AtomicGet(Calls_) == 60) {
                return TNodeSuccess(ENodeStatus::SUCCESS);
            }
            if (AtomicGet(Calls_) < 80) {
                return TNodeSuccess(ENodeStatus::FAILURE);
            }
            return TNodeSuccess(ENodeStatus::SUCCESS);
        }

    public:
        TAtomic Calls_ = 0;
    };

    class TMyMockBackend: public TLogBackend {
    public:
        TMyMockBackend(TAtomic* countOfTBehaviourTreeTickError, TAtomic* countOfTBehaviourTreeTickV2, TAtomic* countOther)
            : CountOfTBehaviourTreeTickError_(countOfTBehaviourTreeTickError)
            , CountOfTBehaviourTreeTickV2_(countOfTBehaviourTreeTickV2)
            , CountOther_(countOther)
        {}

        void WriteData(const TLogRecord& logRecord) final {
            TStringBuf logStr = TStringBuf{logRecord.Data, logRecord.Len};
            if (logStr.find("TBehaviourTreeTickError") != TStringBuf::npos) {
                AtomicIncrement(*CountOfTBehaviourTreeTickError_);
            } else if (logStr.find("TBehaviourTreeTickV2") != TStringBuf::npos) {
                AtomicIncrement(*CountOfTBehaviourTreeTickV2_);
            } else {
                AtomicIncrement(*CountOther_);
            }
        }

        void ReopenLog() final {
        }

    public:
        TAtomic* CountOfTBehaviourTreeTickError_;
        TAtomic* CountOfTBehaviourTreeTickV2_;
        TAtomic* CountOther_;
    };

    const size_t tryIterations = 100;
    TAtomic countOfTBehaviourTreeTickError = 0;
    TAtomic countOfTBehaviourTreeTickV2 = 0;
    TAtomic countOther = 0;
    TAutoPtr<TLogBackend> backend(new TMyMockBackend{&countOfTBehaviourTreeTickError, &countOfTBehaviourTreeTickV2, &countOther});
    TBehaviorTickerConfig config;
    config.SetTreeTickPeriodMs(3);
    TMtpPeriodTicker ticker(config, MakeAtomicShared<TLogFrame>(MakeAtomicShared<TLog>(std::move(backend)), TAtomicSharedPtr<NErrorBoosterClient::TErrorBoosterClient>{}, 0, MakeAtomicShared<TAtomic>(ELogPriority::TLOG_CRIT), TLoggerConfig::JSON), logger.SpawnFrame());
    ticker.Start();

    TBasicTreeNodePtr node = new TMyMockNode();
    TTreePtr tree = new TTree(logger, "treeId", node);
    ticker.AddTree(tree, "hash");

    for (size_t i = 0; i < tryIterations; ++i) {
        Sleep(TDuration::MilliSeconds(500));
        if (AtomicGet(((TMyMockNode*)node.Get())->Calls_) > 200) {
            break;
        }

    }
    ticker.Stop();
    UNIT_ASSERT_C(AtomicGet(((TMyMockNode*)node.Get())->Calls_) > 200, AtomicGet(((TMyMockNode*)node.Get())->Calls_));
    UNIT_ASSERT_EQUAL_C(AtomicGet(countOfTBehaviourTreeTickError), 42, AtomicGet(countOfTBehaviourTreeTickError));
    UNIT_ASSERT_EQUAL_C(AtomicGet(countOfTBehaviourTreeTickV2), 40, AtomicGet(countOfTBehaviourTreeTickV2));
    UNIT_ASSERT_EQUAL_C(AtomicGet(countOther), 0, AtomicGet(countOther));
}

Y_UNIT_TEST(TestErrorCompactingLogInterval) {
    class TMyMockNode : public TBasicLeaf {
    public:
        TMyMockNode() : TBasicLeaf({1, "title"}) {}

        ENodeType GetType() const final {
            return ENodeType::MOCK;
        }

        TTickResult TickImpl(TTickContextPtr) final {
            AtomicIncrement(Calls_);
            if (AtomicGet(Calls_) > 2000) {
                return TNodeSuccess(ENodeStatus::SUCCESS);
            }
            return TNodeError{"Error"};
        }

    public:
        TAtomic Calls_ = 0;
    };

    class TMyMockBackend: public TLogBackend {
    public:
        TMyMockBackend(TAtomic* countOfTBehaviourTreeTickError, TAtomic* countOfTBehaviourTreeTickV2, TAtomic* countOther)
            : CountOfTBehaviourTreeTickError_(countOfTBehaviourTreeTickError)
            , CountOfTBehaviourTreeTickV2_(countOfTBehaviourTreeTickV2)
            , CountOther_(countOther)
        {}

        void WriteData(const TLogRecord& logRecord) final {
            TStringBuf logStr = TStringBuf{logRecord.Data, logRecord.Len};
            if (logStr.find("TBehaviourTreeTickError") != TStringBuf::npos) {
                AtomicIncrement(*CountOfTBehaviourTreeTickError_);
            } else if (logStr.find("TBehaviourTreeTickV2") != TStringBuf::npos) {
                AtomicIncrement(*CountOfTBehaviourTreeTickV2_);
            } else {
                AtomicIncrement(*CountOther_);
            }
        }

        void ReopenLog() final {
        }

    public:
        TAtomic* CountOfTBehaviourTreeTickError_;
        TAtomic* CountOfTBehaviourTreeTickV2_;
        TAtomic* CountOther_;
    };

    const size_t tryIterations = 100;
    TAtomic countOfTBehaviourTreeTickError = 0;
    TAtomic countOfTBehaviourTreeTickV2 = 0;
    TAtomic countOther = 0;
    TAutoPtr<TLogBackend> backend(new TMyMockBackend{&countOfTBehaviourTreeTickError, &countOfTBehaviourTreeTickV2, &countOther});
    TBehaviorTickerConfig config;
    config.SetTreeTickPeriodMs(5);
    TMtpPeriodTicker ticker(config, MakeAtomicShared<TLogFrame>(MakeAtomicShared<TLog>(std::move(backend)), TAtomicSharedPtr<NErrorBoosterClient::TErrorBoosterClient>{}, 0, MakeAtomicShared<TAtomic>(ELogPriority::TLOG_CRIT), TLoggerConfig::JSON), logger.SpawnFrame());
    ticker.Start();

    TBasicTreeNodePtr node = new TMyMockNode();
    TTreePtr tree = new TTree(logger, "treeId", node);
    ticker.AddTree(tree, "hash");

    for (size_t i = 0; i < tryIterations; ++i) {
        Sleep(TDuration::MilliSeconds(500));
        if (AtomicGet(((TMyMockNode*)node.Get())->Calls_) > 2500) {
            break;
        }
    }
    ticker.Stop();
    UNIT_ASSERT_C(AtomicGet(((TMyMockNode*)node.Get())->Calls_) > 2500, AtomicGet(((TMyMockNode*)node.Get())->Calls_));
    UNIT_ASSERT_EQUAL_C(AtomicGet(countOfTBehaviourTreeTickError), 1989, AtomicGet(countOfTBehaviourTreeTickError));
    UNIT_ASSERT_EQUAL_C(AtomicGet(countOfTBehaviourTreeTickV2), 12, AtomicGet(countOfTBehaviourTreeTickV2));
    UNIT_ASSERT_EQUAL_C(AtomicGet(countOther), 0, AtomicGet(countOther));
}

}

} // namespace NInfra::NPodAgent::NTestMtpPeriodTicker
