#pragma once

#include "case_node.h"

#include <infra/pod_agent/libs/behaviour/bt/nodes/base/basic_composite_node.h>

#include <util/generic/serialized_enum.h>
#include <util/string/cast.h>
#include <util/system/type_name.h>

#include <type_traits>

namespace NInfra::NPodAgent {

// Outcome of a boolean-like condition node, for use as TEnum of TSwitchNode.
// TSwitchNode<EConditionReturn> does not parse TNodeSuccess.Message: its
// ExtractFromNodeSuccess specialization maps the condition child's SUCCESS
// status to CR_SUCCESS and FAILURE to CR_FAILURE.
// NOTE(review): the quoted names in the value comments are presumably consumed
// by the enum serialization code generator (ToString/FromString names) — keep
// them intact when editing.
enum EConditionReturn {
    CR_SUCCESS /* "success" */,
    CR_FAILURE /* "failure" */
};

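/**
 * Switch
 *
 * The first child is the condition. If it returns an error or RUNNING,
 * that TTickResult is returned as is. Otherwise its TNodeSuccess is
 * converted to the enum type TEnum: by default TNodeSuccess.Message is
 * parsed with FromString, which requires Status == SUCCESS; for
 * EConditionReturn the SUCCESS/FAILURE status itself is mapped to
 * CR_SUCCESS/CR_FAILURE. The subtree whose CaseNode has the corresponding
 * value (or the default CaseNode, if no case matches) is then executed and
 * its TTickResult is returned. A failed conversion is reported as a TNodeError.
 */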

// Forward declaration, needed so the pointer alias below can precede the class body.
template <typename TEnum>
class TSwitchNode;

// Shared pointer alias for a switch node over the enum TEnum.
template <typename TEnum>
using TSwitchNodePtr = TSimpleSharedPtr<TSwitchNode<TEnum>>;

template <class TEnum>
class TSwitchNode : public TBasicCompositeNode {
public:
    TSwitchNode(
        const TBasicTreeNodeDescriptor& descriptor
    )
        : TBasicCompositeNode(descriptor, {})
    {}

    virtual ENodeType GetType() const override final {
        return TSwitchNode::NODE_TYPE;
    }

    virtual void SetChildren(TVector<TBasicTreeNodePtr>&& children) override final {
        Switch_.clear();
        auto enumAllValues = GetEnumAllValues<TEnum>();
        for (const auto& name: enumAllValues) {
            Y_ENSURE(ToString(name) != TCaseNode::DEFAULT_CASE_VALUE, TStringBuilder() << "enum " << TypeName<TEnum>() << " has '" << TCaseNode::DEFAULT_CASE_VALUE << "' value");
        }

        Y_ENSURE(children.size() > 1, "switch node must have at least 2 children, but children count is " << children.size());
        ConditionChild_ = children.front();
        Y_ENSURE(dynamic_cast<TCaseNode*>(ConditionChild_.Get()) == nullptr, "condition child node can't have type TCaseNode");

        for (size_t i = 1; i < children.size(); ++i) {
            auto& node = children[i];
            TCaseNode *caseNode = dynamic_cast<TCaseNode*>(node.Get());
            Y_ENSURE(caseNode, "all children of Switch node must have type TCaseNode");

            if (caseNode->IsDefault()) {
                Y_ENSURE(!DefaultChild_, "children array of Switch node have more than 1 default cases");
                DefaultChild_ = node;
            } else {
                TEnum value = FromString(caseNode->GetCaseValue());
                Y_ENSURE(!Switch_.contains(value), "children array for Switch node has two Case nodes with same value: " << value);
                Switch_[value] = node;
            }
        }

        Y_ENSURE(DefaultChild_ || enumAllValues.size() + 1 == children.size(),
            "enum " << TypeName<TEnum>() << " option count " << enumAllValues.size() << " plus 1"
                    << " is unequal to passed children count " << children.size());

        TBasicCompositeNode::SetChildren(std::move(children));
    }

private:
    TEnum ExtractFromNodeSuccess(const TNodeSuccess& result) const {
        Y_ENSURE(result.Status == ENodeStatus::SUCCESS, "Unexpected node status for condition node");
        return FromString(result.Message);
    }

    virtual TTickResult TickImpl(TTickContextPtr tickContext) override final {
        if (RunningChild_ == nullptr) {
            RunningChild_ = ConditionChild_;
        }
        auto tickResult = RunningChild_->Tick(tickContext);
        if (!tickResult) {
            RunningChild_ = nullptr;
            return tickResult.Error();
        }
        if (tickResult.Success().Status == ENodeStatus::RUNNING)
            return tickResult.Success();

        if (RunningChild_ != ConditionChild_) {
            RunningChild_ = nullptr;
            return tickResult.Success();
        }
        RunningChild_ = nullptr;
        TEnum condition;
        try {
            condition = ExtractFromNodeSuccess(tickResult.Success());
        } catch (yexception& e) {
            return TNodeError{"switch child condition cast failed: " + TString(e.what())};
        }

        auto child = Switch_.FindPtr(condition);
        RunningChild_ = (child ? *child : DefaultChild_);

        tickResult = RunningChild_->Tick(tickContext);
        if (!tickResult) {
            RunningChild_ = nullptr;
            return tickResult.Error();
        }
        if (tickResult.Success().Status == ENodeStatus::RUNNING)
            return tickResult.Success();
        RunningChild_ = nullptr;
        return tickResult.Success();
    }

public:
    static constexpr const ENodeType NODE_TYPE = ENodeType::SWITCH;

private:
    TBasicTreeNodePtr ConditionChild_;
    TMap<TEnum, TBasicTreeNodePtr> Switch_;
    TBasicTreeNodePtr RunningChild_;
    TBasicTreeNodePtr DefaultChild_ = nullptr;
};

// EConditionReturn specialization: the condition child's terminal status is
// the switch value itself, so both SUCCESS and FAILURE are valid outcomes and
// TNodeSuccess.Message is ignored. Any other status throws yexception, which
// TickImpl reports as a TNodeError.
template <>
inline EConditionReturn TSwitchNode<EConditionReturn>::ExtractFromNodeSuccess(const TNodeSuccess& result) const {
    const auto status = result.Status;
    if (status == ENodeStatus::SUCCESS) {
        return CR_SUCCESS;
    }
    if (status == ENodeStatus::FAILURE) {
        return CR_FAILURE;
    }
    ythrow yexception() << "Unexpected ENodeStatus value for EConditionReturn switch";
}

} // namespace NInfra::NPodAgent
