#include "augment.h"
#include "definitions.h"
#include "builtin_types.h"

#include <infra/yasm/aldan/common/exceptions.h>
#include <infra/yasm/common/labels/signal/signal_name.h>

using namespace NYasm::NAldan;
using namespace NZoom::NSignal;
using namespace NZoom::NAccumulators;

namespace {
    class IConsumer {
    public:
        virtual ~IConsumer() = default;

        virtual void Consume(TDefinition::TPtr) = 0;
        virtual TDefinition::TPtr Exhaust(TNode&) = 0;
    };

    // Consumer that keeps exactly one definition: each Consume() call replaces
    // whatever was stored before.
    class TSimpleConsumer final : public IConsumer {
    public:
        // Stores the given definition, overwriting the previous one (if any).
        void Consume(TDefinition::TPtr definition) override {
            Definition = std::move(definition);
        }

        // Attaches the stored definition to the node and returns it; the
        // member is moved from, so the consumer no longer holds it afterwards.
        TDefinition::TPtr Exhaust(TNode& node) override {
            TDefinition::TPtr held(std::move(Definition));
            node.SetDefinition(held);
            return held;
        }

    private:
        TDefinition::TPtr Definition;
    };

    // Consumer used for function application: it is created with the
    // definition of the function and accumulates every consumed definition
    // as an argument, in the order of consumption.
    class TApplyConsumer final : public IConsumer {
    public:
        // explicit: a definition pointer must never silently convert into a consumer.
        explicit TApplyConsumer(TDefinition::TPtr function)
            : Function(std::move(function))
        {
        }

        // Appends the definition to the argument list.
        void Consume(TDefinition::TPtr definition) override {
            Arguments.emplace_back(std::move(definition));
        }

        // Attaches the function definition to the node, then returns a
        // TApplyDefinition built from the function and a TTupleDefinition of
        // the collected arguments. Both members are moved from, so the
        // consumer is empty afterwards.
        TDefinition::TPtr Exhaust(TNode& node) override {
            // The node gets the function definition, not the apply definition.
            node.SetDefinition(Function);
            return MakeIntrusive<TApplyDefinition>(
                std::move(Function), MakeIntrusive<TTupleDefinition>(std::move(Arguments))
            );
        }

    private:
        TDefinition::TPtr Function;
        TVector<TDefinition::TPtr> Arguments;
    };

    // Per-node consumer that delegates to a replaceable implementation.
    // It starts out as a TSimpleConsumer and can be switched to a
    // TApplyConsumer via ToApplyConsumer().
    class TConsumer final : public IConsumer {
    public:
        // explicit: a node reference must never silently convert into a consumer.
        // The node is stored by reference, so it has to outlive this object.
        explicit TConsumer(TNode& node)
            : Node(node)
            , Impl(MakeHolder<TSimpleConsumer>())
        {
        }

        // Replaces the current implementation with a TApplyConsumer for the
        // given function. Anything held by the previous implementation is
        // destroyed together with it.
        void ToApplyConsumer(TDefinition::TPtr function) {
            Impl = MakeHolder<TApplyConsumer>(std::move(function));
        }

        // Forwards the definition to the current implementation.
        void Consume(TDefinition::TPtr definition) override {
            Impl->Consume(std::move(definition));
        }

        // Not supported on the wrapper: always throws TInferError.
        // Use Extract() instead, which exhausts into the node bound at construction.
        TDefinition::TPtr Exhaust(TNode&) override {
            ythrow TInferError() << "consumer can't be exhausted";
        }

        // Exhausts the current implementation into the bound node and returns
        // the definition it produced.
        TDefinition::TPtr Extract() {
            return Impl->Exhaust(Node);
        }

    private:
        TNode& Node;
        THolder<IConsumer> Impl;
    };

    // Visitor that builds a tree of definitions while the AST is traversed.
    // A TConsumer is pushed for every entered node and popped when the node is
    // finished; the definition extracted from it is handed to the parent's
    // consumer, or becomes the root when there is no parent.
    class TTypeAugmentingVisitor final : public TNodeVisitor {
    public:
        // Opens a consumer bound to the entered node.
        void EnterNode(TNode& node) override {
            ConsumerStack.emplace_back(node);
        }

        // Const nodes are not supported: TConsumer needs a mutable TNode&.
        void EnterNode(const TNode&) override {
            Y_FAIL();
        }

        // Feeds an identifier definition built from INTEGER to the current consumer.
        void VisitInteger(i64) override {
            Peek().Consume(MakeIntrusive<TIdentifierDefinition>(INTEGER));
        }

        // Feeds an identifier definition built from DOUBLE to the current consumer.
        void VisitDouble(double) override {
            Peek().Consume(MakeIntrusive<TIdentifierDefinition>(DOUBLE));
        }

        // Feeds the definition produced by ConvertIdent() to the current consumer.
        void VisitIdent(const TString& ident) override {
            Peek().Consume(ConvertIdent(ident));
        }

        // Switches the current consumer into function application mode for
        // the function with the given name.
        void VisitApply(const TString& name) override {
            Peek().ToApplyConsumer(MakeIntrusive<TIdentifierDefinition>(name));
        }

        // Closes the current node: extracts its definition, pops its consumer
        // and passes the definition to the parent consumer, or stores it as
        // the root if the stack became empty.
        // Throws TInferError (via Peek) when no node is open.
        void FinishNode() override {
            auto definition = Peek().Extract();
            ConsumerStack.pop_back();
            if (ConsumerStack.empty()) {
                Root = std::move(definition);
            } else {
                Peek().Consume(std::move(definition));
            }
        }

        // Returns the root definition.
        // Throws TInferError when no root definition has been produced.
        TDefinition& GetRoot() const {
            if (!Root) {
                ythrow TInferError() << "given AST is empty or invalid";
            }
            return *Root;
        }

    private:
        // Builds the definition for an identifier. When the identifier parses
        // as a signal name that has aggregation rules, the definition is built
        // from the type inferred for the MetaGroup accumulator; otherwise it is
        // built from the identifier text itself.
        TDefinition::TPtr ConvertIdent(const TString& ident) {
            TMaybe<TSignalName> signalName(TSignalName::TryNew(ident));
            if (signalName.Defined() && signalName->GetAggregationRules() != nullptr) {
                const auto& rules(*signalName->GetAggregationRules());
                return MakeIntrusive<TIdentifierDefinition>(InferType(rules.GetAccumulatorType(EAggregationMethod::MetaGroup)));
            }
            return MakeIntrusive<TIdentifierDefinition>(ident);
        }

        // Maps an accumulator type to the type constructor of its value.
        // Throws TInferError for a value that is not one of the handled enumerators.
        TTypeConstructor::TPtr InferType(EAccumulatorType accumulatorType) {
            // No default label on purpose: the compiler can still warn about
            // enumerators that are not handled here.
            switch (accumulatorType) {
                case EAccumulatorType::Average: {
                    return COUNTED_SUM_VALUE;
                }
                case EAccumulatorType::Counter: {
                    return COUNTER_VALUE;
                }
                case EAccumulatorType::List:
                case EAccumulatorType::Hgram: {
                    return HISTOGRAM_VALUE;
                }
                case EAccumulatorType::Avg:
                case EAccumulatorType::Max:
                case EAccumulatorType::Min:
                case EAccumulatorType::Summ:
                case EAccumulatorType::SummNone:
                case EAccumulatorType::Last: {
                    return DOUBLE_VALUE;
                }
            }
            // Reached only for a value outside the enumerators above. Without
            // this statement control would flow off the end of a non-void
            // function, which is undefined behaviour.
            ythrow TInferError() << "unsupported accumulator type: " << static_cast<int>(accumulatorType);
        }

        // Returns the consumer of the innermost open node.
        // Throws TInferError when no node is open.
        TConsumer& Peek() {
            if (ConsumerStack.empty()) {
                ythrow TInferError() << "given AST is invalid";
            }
            return ConsumerStack.back();
        }

        TDefinition::TPtr Root;
        TVector<TConsumer> ConsumerStack;
    };
}

// Runs TTypeAugmentingVisitor over the tree rooted at `root`, then passes the
// resulting root definition together with `registry` to AnalyzeType() and
// returns its result. TInferError is thrown by the visitor when no root
// definition has been produced.
TTypeConstructor::TPtr NYasm::NAldan::AugmentAST(TNode& root, const TRegistry& registry) {
    TTypeAugmentingVisitor augmenter;
    root.Transform(augmenter);

    TDefinition& rootDefinition = augmenter.GetRoot();
    return AnalyzeType(rootDefinition, registry);
}
