#include "patterns.h"

#include <infra/yasm/aldan/common/exceptions.h>

#include <contrib/libs/antlr4_cpp_runtime/src/antlr4-runtime.h>
#include <infra/yasm/aldan/patterns/TPatternsBaseVisitor.h>

#include <util/generic/algorithm.h>
#include <util/generic/vector.h>

using namespace NYasm::NAldan;

namespace {
    // Upper bound on the number of beams a single cartesian-product step in
    // TExpressionConsumer::AddConsumer is allowed to produce.
    constexpr size_t MAX_COMBINATIONS = 100;

    // A beam is one candidate pattern, kept as the ordered list of slices
    // (views into the parsed expression text) that concatenate into it.
    using TBeams = TVector<TVector<TStringBuf>>;

    /// Abstract accumulator of beams for one syntactic construct.
    ///
    /// Concrete consumers decide how literal slices and the beams of a finished
    /// nested consumer are merged into their own beam list.
    class TConsumer {
    private:
        TBeams Beams;

    public:
        virtual ~TConsumer() = default;

        /// Feeds a literal piece of the expression text into this consumer.
        virtual void AddSlice(TStringBuf content) = 0;

        /// Merges the beams of a completed nested consumer into this one.
        /// Implementations may move data out of `consumer`'s beams.
        virtual void AddConsumer(TConsumer& consumer) = 0;

        /// Mutable access to the beams collected so far.
        TBeams& GetBeams() { return Beams; }
    };

    /// Consumer for a concatenation: every piece fed to it is appended to the end
    /// of every beam collected so far, so nested alternatives multiply out into a
    /// cartesian product.
    class TExpressionConsumer final : public TConsumer {
    public:
        /// Appends `content` to every existing beam, or starts the first beam
        /// with it when nothing has been collected yet.
        void AddSlice(TStringBuf content) override {
            auto& target(GetBeams());
            if (target.empty()) {
                target.emplace_back().emplace_back(content);
            } else {
                for (auto& beam : target) {
                    beam.emplace_back(content);
                }
            }
        }

        /// Replaces the current beams with the product (current x other): each
        /// current beam is followed by each of `other`'s beams.
        ///
        /// @throws TParsingError if the product would contain more than
        ///         MAX_COMBINATIONS beams; the current beams are left untouched then.
        void AddConsumer(TConsumer& other) override {
            const auto& source(other.GetBeams());
            auto& target(GetBeams());

            Y_ASSERT(!source.empty());
            // An empty target acts as a single empty beam. A plain ternary keeps the
            // arithmetic in size_t on every platform; `Max(size(), 1UL)` fails to
            // deduce where size_t is not `unsigned long` (e.g. LLP64, 32-bit).
            const size_t leftCount = target.empty() ? size_t(1) : target.size();
            const size_t newSize = source.size() * leftCount;
            if (newSize > MAX_COMBINATIONS) {
                ythrow TParsingError() << "too many combinations generated: " << newSize << " > " << MAX_COMBINATIONS;
            }

            TBeams product(Reserve(newSize));
            if (target.empty()) {
                target.emplace_back();
            }
            for (const auto& leftBeam : target) {
                for (const auto& rightBeam : source) {
                    auto& newBeam(product.emplace_back(Reserve(leftBeam.size() + rightBeam.size())));
                    // Copy rather than move: every left and right beam is reused
                    // for several products, so its elements must stay intact.
                    newBeam.insert(newBeam.end(), leftBeam.begin(), leftBeam.end());
                    newBeam.insert(newBeam.end(), rightBeam.begin(), rightBeam.end());
                }
            }
            target.swap(product);
        }
    };

    class TSpreadConsumer final : public TConsumer {
    public:
        void AddSlice(TStringBuf) override {
            Y_FAIL();
        }

        void AddConsumer(TConsumer& other) override {
            auto& source(other.GetBeams());
            auto& target(GetBeams());
            target.reserve(source.size() + target.size());
            std::move(std::begin(source), std::end(source), std::back_inserter(target));
        }
    };

    class TPatternsBuilder {
    public:
        void EnterExpression() {
            Y_ASSERT(Result.empty());
            Consumers.emplace_back(MakeHolder<TExpressionConsumer>());
        }

        void EnterSpread() {
            Y_ASSERT(Result.empty());
            Consumers.emplace_back(MakeHolder<TSpreadConsumer>());
        }

        void AddSlice(TStringBuf content) {
            Y_ASSERT(!Consumers.empty());
            Consumers.back()->AddSlice(content);
        }

        void LeaveSpread() {
            RemoveLastConsumer();
        }

        void LeaveExpression() {
            RemoveLastConsumer();
        }

        TVector<TString> GetResult() {
            TVector<TString> result(Reserve(Result.size()));
            for (const auto& beam : Result) {
                auto& target(result.emplace_back());
                target.reserve(Accumulate(beam.begin(), beam.end(), size_t(), [](size_t acc, const auto& x) {
                    return acc + x.size();
                }));
                for (const auto& part : beam) {
                    target.append(part);
                }
            }
            return result;
        }

    private:
        void RemoveLastConsumer() {
            Y_ASSERT(!Consumers.empty());
            auto last = std::move(Consumers.back());
            Consumers.pop_back();

            if (Consumers.empty()) {
                Result.swap(last->GetBeams());
            } else {
                Consumers.back()->AddConsumer(*last);
            }
        }

        TVector<THolder<TConsumer>> Consumers;
        TBeams Result;
    };

    /// Walks the parse tree and translates it into TPatternsBuilder calls.
    /// `content` must be the exact text that was parsed, because slices are
    /// cut out of it by token index.
    class TPatternsVisitorImpl : public TPatternsBaseVisitor {
    public:
        TPatternsVisitorImpl(TStringBuf content, TPatternsBuilder& builder)
            : Content(content)
            , Builder(builder)
        {
        }

        /// Emits the source text spanned by the slice's first through last token.
        antlrcpp::Any visitSlice(TPatternsParser::SliceContext* context) override {
            // NOTE(review): ANTLR's input stream indexes decoded characters (code
            // points), while Content is indexed by bytes; these agree only for
            // ASCII input — confirm the grammar never accepts non-ASCII text.
            const auto begin = context->getStart()->getStartIndex();
            const auto end = context->getStop()->getStopIndex() + 1; // stop index is inclusive
            Builder.AddSlice(Content.SubStr(begin, end - begin));
            return nullptr;
        }

        /// Brackets the children of an expression node with Enter/LeaveExpression.
        antlrcpp::Any visitExpression(TPatternsParser::ExpressionContext* context) override {
            Builder.EnterExpression();
            antlrcpp::Any visited = visitChildren(context);
            Builder.LeaveExpression();
            return visited;
        }

        /// Brackets the children of a spread node with Enter/LeaveSpread.
        antlrcpp::Any visitSpread(TPatternsParser::SpreadContext* context) override {
            Builder.EnterSpread();
            antlrcpp::Any visited = visitChildren(context);
            Builder.LeaveSpread();
            return visited;
        }

    private:
        const TStringBuf Content;
        TPatternsBuilder& Builder;
    };
}

/// Parses a pattern expression and expands it into the list of concrete patterns.
///
/// @param expression  Source text of the pattern expression.
/// @return            One string per expanded beam, in generation order.
/// @throws TParsingError on any lexer or parser diagnostic, when an expansion
///         step exceeds MAX_COMBINATIONS, or (wrapping CurrentExceptionMessage())
///         for any other exception raised during parsing.
TVector<TString> NYasm::NAldan::ParsePatterns(const TString& expression) {
    try {
        // Rethrows the first lexer/parser diagnostic as TParsingError, aborting the parse.
        struct TPatternsErrorListener final : public antlr4::BaseErrorListener {
            void syntaxError(antlr4::Recognizer* /* recognizer */,
                             antlr4::Token* /* offendingSymbol */,
                             size_t line,
                             size_t column,
                             const std::string& message,
                             std::exception_ptr /* e */) override
            {
                ythrow TParsingError() << message << " at line " << line << ":" << column;
            }
        };

        // Declared first so it outlives the lexer and parser that point at it.
        TPatternsErrorListener errorListener;

        antlr4::ANTLRInputStream input(expression.ConstRef());

        TPatternsLexer lexer(&input);
        lexer.removeErrorListeners();
        lexer.addErrorListener(&errorListener);

        // Tokenize the whole input up-front so lexer errors surface before parsing.
        antlr4::CommonTokenStream tokens(&lexer);
        tokens.fill();

        TPatternsParser parser(&tokens);
        parser.removeErrorListeners();
        parser.addErrorListener(&errorListener);

        TPatternsBuilder builder;
        TPatternsVisitorImpl visitor(expression, builder);
        visitor.visit(parser.root());
        return builder.GetResult();
    } catch (const TParsingError&) {
        // Already a parsing error with a meaningful message: propagate as-is.
        throw;
    } catch (...) {
        ythrow TParsingError() << CurrentExceptionMessage();
    }
}
