#include "factory.h"

#include <infra/yasm/aldan/common/exceptions.h>

using namespace NYasm::NAldan;

namespace {
    class TTreeVisitor final : public TNodeVisitor {
    public:
        TTreeVisitor(TVirtualMachine& virtualMachine, const TInstructionFactory::TDescriptorIndex& descriptorIndex)
            : VirtualMachine(virtualMachine)
            , DescriptorIndex(descriptorIndex)
        {
        }

        void EnterNode(const TNode& node) override {
            Stack.emplace(TFrame{
                .NodeType = node.GetInferredType()
            });
        }

        void VisitInteger(i64 number) override {
            VirtualMachine.AddInstruction(TConstantIntegerOp{.Value=number});
        }

        void VisitDouble(double number) override {
            VirtualMachine.AddInstruction(TConstantDoubleOp{.Value=number});
        }

        void VisitIdent(const TString& ident) override {
            Y_UNUSED(ident);
            Y_FAIL();
        }

        void VisitApply(const TString& name) override {
            Stack.top().FunctionName = name;
        }

        void FinishNode() override {
            const auto& frame(Stack.top());
            if (frame.FunctionName) {
                const auto functionIt = DescriptorIndex.find(frame.FunctionName);
                if (functionIt.IsEnd()) {
                    ythrow TSymbolNotFound(frame.FunctionName) << "function " << frame.FunctionName << " not defined";
                }
                const auto overloadIt = functionIt->second.find(frame.NodeType);
                if (overloadIt.IsEnd()) {
                    ythrow TSymbolNotFound(frame.FunctionName) << "no overload for function " << frame.FunctionName << " found";
                }
                Y_ASSERT(overloadIt->second != nullptr);
                overloadIt->second->GenerateInstructions(VirtualMachine);
            }
            Stack.pop();
        }

    private:
        TVirtualMachine& VirtualMachine;
        const TInstructionFactory::TDescriptorIndex& DescriptorIndex;

        struct TFrame {
            TTypeConstructor::TPtr NodeType;
            TString FunctionName;
        };

        TStack<TFrame> Stack;
    };
}

/// Type-checks @p root against the registered function types and appends the
/// compiled instructions for it to @p virtualMachine.
///
/// @param root            AST to compile; AugmentAST receives it by non-const
///                        reference and presumably annotates it in place.
/// @param virtualMachine  destination for the generated instructions.
/// @return the type reported by AugmentAST for the whole expression.
TTypeConstructor::TPtr TInstructionFactory::GenerateInstructions(TNode& root, TVirtualMachine& virtualMachine) const {
    // Inference must come first: the code generator picks overloads by each node's
    // inferred type (presumably filled in by AugmentAST).
    auto rootType = AugmentAST(root, TypeRegistry);

    TTreeVisitor codegen(virtualMachine, DescriptorIndex);
    root.Visit(codegen);

    return rootType;
}

/// Takes ownership of a function descriptor and registers every overload it declares.
///
/// For each type returned by GetFunctionTypes(), the type is appended to
/// TypeRegistry[name], which GenerateInstructions passes to AugmentAST. The same type
/// also maps DescriptorIndex[name][type] to the descriptor, which the code generator
/// uses. Registering the same (name, type) pair again appends a duplicate to
/// TypeRegistry and repoints DescriptorIndex at the newer descriptor.
///
/// @param descriptor  owning pointer to the descriptor; must not be null.
/// @throws yexception if @p descriptor is null (checked before anything is stored).
void TInstructionFactory::AddDescriptor(THolder<IFunctionDescriptor> descriptor) {
    Y_ENSURE(descriptor, "function descriptor must not be null");
    Descriptors.emplace_back(std::move(descriptor));
    // THolder keeps the descriptor object on the heap, so this raw pointer (stored in
    // DescriptorIndex) stays valid as long as Descriptors owns it, even if the
    // container reallocates.
    auto* desc = Descriptors.back().Get();
    // Bind by reference: no copy when the getters return references, and
    // lifetime-extended temporaries when they return by value.
    const auto& functionName = desc->GetFunctionName();
    const auto& functionTypes = desc->GetFunctionTypes();
    for (const auto& functionType : functionTypes) {
        TypeRegistry[functionName].emplace_back(functionType);
        DescriptorIndex[functionName][functionType] = desc;
    }
}
