#include "parser.h"

#include <infra/yasm/aldan/common/exceptions.h>
#include <infra/yasm/aldan/patterns/patterns.h>

#include <contrib/libs/antlr4_cpp_runtime/src/antlr4-runtime.h>
#include <infra/yasm/aldan/parser/TAldanBaseVisitor.h>

#include <util/generic/map.h>
#include <util/generic/hash.h>
#include <util/system/type_name.h>
#include <util/string/cast.h>

using namespace NYasm::NAldan;

namespace {
    // Common base of the concrete AST nodes in this file. It wraps every
    // traversal into the visitor's EnterNode()/FinishNode() pair and leaves the
    // node-specific callbacks to the *Impl hooks below.
    class TVisitingNode : public TNode {
    public:
        // Node-specific part of a read-only traversal.
        virtual void VisitImpl(TNodeVisitor& visitor) const = 0;

        // Node-specific part of a mutable traversal. Unless overridden it
        // simply reuses the read-only hook.
        virtual void TransformImpl(TNodeVisitor& visitor) {
            VisitImpl(visitor);
        }

        // Read-only traversal: EnterNode(), VisitImpl(), FinishNode().
        void Visit(TNodeVisitor& visitor) const override {
            visitor.EnterNode(*this);
            VisitImpl(visitor);
            visitor.FinishNode();
        }

        // Mutable traversal: EnterNode(), TransformImpl(), FinishNode().
        void Transform(TNodeVisitor& visitor) override {
            visitor.EnterNode(*this);
            TransformImpl(visitor);
            visitor.FinishNode();
        }
    };

    // Leaf AST node that stores an unsigned integer literal.
    class TIntegerNode final : public TVisitingNode {
    public:
        // explicit: a plain ui64 must not convert into an AST node implicitly.
        explicit TIntegerNode(ui64 integer)
            : Integer(integer)
        {
        }

        // Passes the stored value to TNodeVisitor::VisitInteger().
        void VisitImpl(TNodeVisitor& visitor) const override {
            visitor.VisitInteger(Integer);
        }

    private:
        ui64 Integer;
    };

    // Leaf AST node that stores a floating point literal.
    class TDoubleNode final : public TVisitingNode {
    public:
        // explicit: a plain double must not convert into an AST node implicitly.
        explicit TDoubleNode(double number)
            : Number(number)
        {
        }

        // Passes the stored value to TNodeVisitor::VisitDouble().
        void VisitImpl(TNodeVisitor& visitor) const override {
            visitor.VisitDouble(Number);
        }

    private:
        double Number;
    };

    // Leaf AST node that stores an identifier.
    class TIdentNode final : public TVisitingNode {
    public:
        // explicit: a plain string must not convert into an AST node implicitly.
        // The argument is taken by value and moved into the node.
        explicit TIdentNode(TString ident)
            : Ident(std::move(ident))
        {
        }

        // Passes the stored identifier to TNodeVisitor::VisitIdent().
        void VisitImpl(TNodeVisitor& visitor) const override {
            visitor.VisitIdent(Ident);
        }

    private:
        TString Ident;
    };

    class TApplyNode final : public TVisitingNode {
    public:
        const TString& GetName() const {
            if (Name.empty()) {
                ythrow TParsingError() << "name not set";
            }
            return Name;
        }

        void AddChild(THolder<TNode> node) override {
            if (Name.empty()) {
                TNameVisitor visitor(this);
                node->Visit(visitor);
            } else {
                Arguments.emplace_back(std::move(node));
            }
        }

        void VisitImpl(TNodeVisitor& visitor) const override {
            visitor.VisitApply(GetName());
            for (const auto& argument : Arguments) {
                argument->Visit(visitor);
            }
        }

        void TransformImpl(TNodeVisitor& visitor) override {
            visitor.VisitApply(GetName());
            for (const auto& argument : Arguments) {
                argument->Transform(visitor);
            }
        }

    private:
        class TNameVisitor : public TNodeVisitor {
        public:
            TNameVisitor(TApplyNode* parent)
                : Parent(parent)
            {
            }

            void VisitIdent(const TString& name) override {
                Parent->Name = name;
            }

        private:
            TApplyNode* Parent;
        };

        TString Name;
        TVector<THolder<TNode>> Arguments;
    };

    class TRawNode {
    public:
        virtual ~TRawNode() = default;

        virtual void AddChild(THolder<TRawNode>) {
            ythrow TParsingError() << ToString() << " can't have children";
        }

        virtual void Negate() {
            ythrow TParsingError() << ToString() << " can't be negated";
        }

        virtual THolder<TRawNode> Transform() {
            ythrow TParsingError() << ToString() << " can't be transformed";
        }

        virtual THolder<TNode> CreateNode() {
            ythrow TParsingError() << ToString() << " can't be converted to AST";
        }

        virtual void AddToNode(TNode& node) {
            node.AddChild(CreateNode());
        }

        virtual TString ToString() = 0;
    };

    // Raw node for a single identifier.
    class TIdentifier final : public TRawNode {
    public:
        // explicit: a plain string must not convert into a raw node implicitly.
        explicit TIdentifier(TString name)
            : Name(std::move(name))
        {
        }

        // Builds the TIdentNode. The name is moved out, so Name is left in a
        // moved-from state after the call.
        THolder<TNode> CreateNode() override {
            return MakeHolder<TIdentNode>(std::move(Name));
        }

        TString ToString() override {
            return "identifier";
        }

    private:
        TString Name;
    };

    // Raw node for a numeric literal. The literal is kept as text until the AST
    // node is created, so that a sign can still be applied to it.
    class TNumber final : public TRawNode {
    public:
        // explicit: a plain string must not convert into a raw node implicitly.
        explicit TNumber(TString number)
            : Number(std::move(number))
        {
        }

        // Flips the sign of the literal text. An existing leading '-' is removed
        // instead of adding a second one: "--N" is not a valid number and would
        // make CreateNode() fail on a doubly negated literal.
        void Negate() override {
            if (Number.StartsWith('-')) {
                Number.erase(0, 1);
            } else {
                Number.prepend('-');
            }
        }

        // Creates a TIntegerNode when the text is a valid ui64 and a TDoubleNode
        // otherwise. The double conversion throws if the text is not a number.
        THolder<TNode> CreateNode() override {
            // TryFromString reports failure via its return value, so the
            // fallback no longer relies on catching every exception.
            ui64 integer = 0;
            if (::TryFromString(Number, integer)) {
                return MakeHolder<TIntegerNode>(integer);
            }
            return MakeHolder<TDoubleNode>(::FromString<double>(Number));
        }

        TString ToString() override {
            return "number";
        }

    private:
        TString Number;
    };

    // Raw node for a function application. All children, including the
    // identifier that names the function, are kept in one ordered list.
    class TFunctionCall final : public TRawNode {
    public:
        void AddChild(THolder<TRawNode> node) override {
            Arguments.emplace_back(std::move(node));
        }

        // Hands the collected children over to a fresh TFunctionCall and returns
        // it; this instance is left without arguments.
        THolder<TRawNode> Transform() override {
            THolder<TFunctionCall> replacement = MakeHolder<TFunctionCall>();
            Arguments.swap(replacement->Arguments);
            return replacement;
        }

        // Creates a TApplyNode and lets every child add itself to it, in order.
        THolder<TNode> CreateNode() override {
            THolder<TApplyNode> apply = MakeHolder<TApplyNode>();
            for (const auto& argument : Arguments) {
                argument->AddToNode(*apply);
            }
            return apply;
        }

        TString ToString() override {
            return "function call";
        }

    private:
        TVector<THolder<TRawNode>> Arguments;
    };

    // Raw node holding a group of child raw nodes that stand in one place.
    class TPatternTransform final : public TRawNode {
    public:
        // explicit: a vector of nodes must not convert into a raw node implicitly.
        explicit TPatternTransform(TVector<THolder<TRawNode>> children)
            : Children(std::move(children))
        {
        }

        // A group can become a single AST node only when it has exactly one
        // child; otherwise the base implementation throws TParsingError.
        THolder<TNode> CreateNode() override {
            if (Children.size() != 1) {
                return TRawNode::CreateNode();
            }
            return Children.back()->CreateNode();
        }

        // Adds every child of the group to the given AST node, in order.
        void AddToNode(TNode& node) override {
            for (auto& child : Children) {
                child->AddToNode(node);
            }
        }

        TString ToString() override {
            return "pattern";
        }

    private:
        TVector<THolder<TRawNode>> Children;
    };

    // Raw node for negation. It accepts exactly one operand and, on Transform(),
    // negates that operand and returns it in place of itself.
    class TNegateTransform final : public TRawNode {
    public:
        // Stores the operand; a second operand is rejected with TParsingError.
        void AddChild(THolder<TRawNode> node) override {
            if (!Argument) {
                Argument = std::move(node);
                return;
            }
            ythrow TParsingError() << "only one argument supported";
        }

        // Calls Negate() on the operand and gives up ownership of it.
        // Throws TParsingError when no operand was added.
        THolder<TRawNode> Transform() override {
            if (Argument) {
                Argument->Negate();
                return std::move(Argument);
            }
            ythrow TParsingError() << "no argument given";
        }

        TString ToString() override {
            return "negate";
        }

    private:
        THolder<TRawNode> Argument;
    };

    // Raw node for one expression of the grammar. Its children are collected in
    // source order and folded into a single raw node by Transform().
    class TExpression final : public TRawNode {
    public:
        void AddChild(THolder<TRawNode> node) override {
            Children.emplace_back(std::move(node));
        }

        // Folds the children depending on how many there are:
        //   1 - the child itself is returned;
        //   2 - the first child receives the second one and is transformed;
        //   3 - the middle child receives the first and the third one
        //       (in that order) and is transformed.
        // Any other count results in TParsingError.
        THolder<TRawNode> Transform() override {
            switch (Children.size()) {
                case 0:
                    ythrow TParsingError() << "empty expression given";
                case 1:
                    return std::move(Children[0]);
                case 2: {
                    auto& head = Children[0];
                    head->AddChild(std::move(Children[1]));
                    return head->Transform();
                }
                case 3: {
                    auto& middle = Children[1];
                    middle->AddChild(std::move(Children[0]));
                    middle->AddChild(std::move(Children[2]));
                    return middle->Transform();
                }
                default:
                    ythrow TParsingError() << "can't convert expression with " << Children.size() << " to AST";
            }
        }

        TString ToString() override {
            return "expression";
        }

    private:
        TVector<THolder<TRawNode>> Children;
    };

    // Bottom element of the builder stack: holds the single top-level raw node.
    class TRoot final : public TRawNode {
    public:
        // Accepts the top-level node once; a second one causes TParsingError.
        void AddChild(THolder<TRawNode> node) override {
            if (!Root) {
                Root = std::move(node);
                return;
            }
            ythrow TParsingError() << "root already exists";
        }

        // Converts the top-level node into the AST; throws TParsingError when
        // nothing was added.
        THolder<TNode> CreateNode() override {
            if (Root) {
                return Root->CreateNode();
            }
            ythrow TParsingError() << "root not exists";
        }

        TString ToString() override {
            return "root";
        }

    private:
        THolder<TRawNode> Root;
    };

    class TAstBuilder : public TNonCopyable {
    public:
        TAstBuilder() {
            NodeStack.emplace_back(MakeHolder<TRoot>());
        }

        void EnterFunctionCall() {
            NodeStack.emplace_back(MakeHolder<TFunctionCall>());
        }

        void EnterExpression() {
            NodeStack.emplace_back(MakeHolder<TExpression>());
        }

        void Leave() {
            auto last = std::move(Peek());
            NodeStack.pop_back();
            Peek()->AddChild(last->Transform());
        }


        void OnIdent(TString name) {
            Peek()->AddChild(CreateIdentifier(std::move(name)));
        }

        void OnPattern(TVector<TString> names) {
            TVector<THolder<TRawNode>> children(Reserve(names.size()));
            for (auto& name : names) {
                children.emplace_back(CreateIdentifier(std::move(name)));
            }
            Peek()->AddChild(MakeHolder<TPatternTransform>(std::move(children)));
        }

        void OnNumber(TString number) {
            Peek()->AddChild(MakeHolder<TNumber>(std::move(number)));
        }

        void OnUnaryOperator(TString name) {
            if (name == TStringBuf("minus")) {
                Peek()->AddChild(MakeHolder<TNegateTransform>());
            } else {
                OnFunctionCall(std::move(name));
            }
        }

        void OnFunctionCall(TString name) {
            auto node = MakeHolder<TFunctionCall>();
            node->AddChild(MakeHolder<TIdentifier>(std::move(name)));
            Peek()->AddChild(std::move(node));
        }

        THolder<TNode> GetAST() {
            return Peek()->CreateNode();
        }

    private:
        THolder<TRawNode> CreateIdentifier(TString name) {
            return MakeHolder<TIdentifier>(std::move(name));
        }

        THolder<TRawNode>& Peek() {
            if (NodeStack.empty()) {
                ythrow TParsingError() << "stack empty";
            }
            return NodeStack.back();
        }

        TVector<THolder<TRawNode>> NodeStack;
    };

    // Parse tree visitor that feeds TAstBuilder: scopes are opened and closed
    // around function calls and expressions, leaves and operators are reported
    // as they are met. Operator tokens are translated into function names.
    class TAldanVisitorImpl : public TAldanBaseVisitor {
    public:
        // explicit: a builder reference must not convert into a visitor implicitly.
        // The builder is stored by reference and has to outlive the visitor.
        explicit TAldanVisitorImpl(TAstBuilder& builder)
            : Builder(builder)
        {
        }

        antlrcpp::Any visitFunctioncall(TAldanParser::FunctioncallContext *ctx) override {
            Builder.EnterFunctionCall();
            auto result = visitChildren(ctx);
            Builder.Leave();
            return result;
        }

        antlrcpp::Any visitExp(TAldanParser::ExpContext *ctx) override {
            Builder.EnterExpression();
            auto result = visitChildren(ctx);
            Builder.Leave();
            return result;
        }

        antlrcpp::Any visitFuncname(TAldanParser::FuncnameContext *ctx) override {
            Builder.OnIdent(TString(ctx->getText()));
            return defaultResult();
        }

        // The text of the variable is expanded by ParsePatterns() and the
        // resulting names are reported as one pattern.
        antlrcpp::Any visitVar(TAldanParser::VarContext *ctx) override {
            TString name(ctx->getText());
            Builder.OnPattern(ParsePatterns(name));
            return defaultResult();
        }

        antlrcpp::Any visitNumber(TAldanParser::NumberContext *ctx) override {
            Builder.OnNumber(TString(ctx->getText()));
            return defaultResult();
        }

        antlrcpp::Any visitOperatorOr(TAldanParser::OperatorOrContext *ctx) override {
            return VisitBinaryOperator(ctx);
        }

        antlrcpp::Any visitOperatorAnd(TAldanParser::OperatorAndContext *ctx) override {
            return VisitBinaryOperator(ctx);
        }

        antlrcpp::Any visitOperatorComparison(TAldanParser::OperatorComparisonContext *ctx) override {
            return VisitBinaryOperator(ctx);
        }

        antlrcpp::Any visitOperatorAddSub(TAldanParser::OperatorAddSubContext *ctx) override {
            return VisitBinaryOperator(ctx);
        }

        antlrcpp::Any visitOperatorMulDivMod(TAldanParser::OperatorMulDivModContext *ctx) override {
            return VisitBinaryOperator(ctx);
        }

        antlrcpp::Any visitOperatorBitwise(TAldanParser::OperatorBitwiseContext *ctx) override {
            return VisitBinaryOperator(ctx);
        }

        antlrcpp::Any visitOperatorPower(TAldanParser::OperatorPowerContext *ctx) override {
            return VisitBinaryOperator(ctx);
        }

        antlrcpp::Any visitOperatorUnary(TAldanParser::OperatorUnaryContext *ctx) override {
            OnUnaryOperator(ctx->getText());
            return visitChildren(ctx);
        }

    private:
        // Operator token -> function name. The table is immutable, so it is built
        // once on first use and shared by all visitor instances instead of being
        // refilled for every parsed expression.
        static const THashMap<std::string, TString>& GetNaming() {
            static const THashMap<std::string, TString> naming = {
                {"+", "plus"},
                {"-", "minus"},
                {"/", "divide"},
                {"//", "integer_divide"},
                {"*", "multiply"},
                {"%", "modulo"},
                {"^", "power"},
                {"<<", "bitwise_left_shift"},
                {">>", "bitwise_right_shift"},
                {"&", "bitwise_and"},
                {"|", "bitwise_or"},
                {"not", "not"},
                {"and", "and"},
                {"or", "or"},
                {"<", "less"},
                {">", "greater"},
                {"<=", "less_or_equal"},
                {">=", "greater_or_equal"},
                {"~=", "like"},
                {"==", "equal"},
            };
            return naming;
        }

        // Shared body of all binary operator handlers: report the operator as a
        // function call, then continue with the children of the context.
        antlrcpp::Any VisitBinaryOperator(antlr4::tree::ParseTree* ctx) {
            OnBinaryOperator(ctx->getText());
            return visitChildren(ctx);
        }

        // Translates an operator token; throws TParsingError for unknown tokens.
        TString GetName(const std::string& name) {
            const auto& naming = GetNaming();
            const auto it = naming.find(name);
            if (it == naming.end()) {
                ythrow TParsingError() << "unknown operator " << name << " given";
            }
            return it->second;
        }

        void OnUnaryOperator(const std::string& name) {
            Builder.OnUnaryOperator(GetName(name));
        }

        void OnBinaryOperator(const std::string& name) {
            Builder.OnFunctionCall(GetName(name));
        }

        TAstBuilder& Builder;
    };
}

// Parses `expression` and returns the root of the resulting AST.
// Every failure is reported as TParsingError: syntax errors found by the lexer
// or the parser, errors raised while building the AST, and any other exception,
// whose message is wrapped into a TParsingError.
THolder<TNode> NYasm::NAldan::ParseExpression(const TString& expression) {
    try {
        // Replaces the default ANTLR listeners: a syntax error is turned into an
        // exception right away.
        struct TConfReaderErrorListener: public antlr4::BaseErrorListener {
            void syntaxError(antlr4::Recognizer* /* recognizer */,
                             antlr4::Token* /* token */,
                             size_t line,
                             size_t column,
                             const std::string& message,
                             std::exception_ptr /* e */) override
            {
                ythrow TParsingError() << message << " at line " << line << ":" << column;
            }
        } errorListener;

        antlr4::ANTLRInputStream input(expression.ConstRef());
        TAldanLexer lexer(&input);
        lexer.removeErrorListeners();
        lexer.addErrorListener(&errorListener);

        antlr4::CommonTokenStream tokens(&lexer);
        tokens.fill();

        TAldanParser parser(&tokens);
        parser.removeErrorListeners();
        parser.addErrorListener(&errorListener);

        TAstBuilder builder;
        TAldanVisitorImpl visitor(builder);
        visitor.visit(parser.root());
        return builder.GetAST();

    } catch (const TParsingError&) {
        // Already of the expected type: pass through unchanged.
        throw;
    } catch (...) {
        // No Endl here: an exception message should not end with a line break.
        ythrow TParsingError() << CurrentExceptionMessage();
    }
}
