package ru.yandex.crypta.lab.formatters;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.antlr.v4.runtime.BailErrorStrategy;
import org.antlr.v4.runtime.BaseErrorListener;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.Parser;
import org.antlr.v4.runtime.RecognitionException;
import org.antlr.v4.runtime.Recognizer;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.misc.ParseCancellationException;
import org.jetbrains.annotations.NotNull;

import ru.yandex.crypta.lib.word.rule.WordRuleBaseVisitor;
import ru.yandex.crypta.lib.word.rule.WordRuleLexer;
import ru.yandex.crypta.lib.word.rule.WordRuleParser;

public abstract class BooleanRuleFormatter implements RuleConditionFormatter {
    protected final static String OR = "OR";
    final private static String AND = "AND";

    protected static class ExpressionNode {
        public String lemma;
        public List<ExpressionNode> operands;
        public String operator;
        public List<String> tags;

        public ExpressionNode(@NotNull String lemma) {
            this.lemma = lemma;
        }

        public ExpressionNode(@NotNull String operator, @NotNull ExpressionNode operand) {
            this.operator = operator;
            this.operands = List.of(operand);
        }

        public ExpressionNode(@NotNull String operator, @NotNull List<ExpressionNode> operands) {
            this.operator = operator.equals("-") ? AND : operator;
            this.operands = new ArrayList<>();
            for (var operand: operands) {
                if (this.operator.equals(operand.operator)) {
                    this.operands.addAll(operand.operands);
                } else {
                    this.operands.add(operand);
                }
            }
        }

        public ExpressionNode(@NotNull String operator, @NotNull List<ExpressionNode> operands, @NotNull List<String> tags) {
            this(operator, operands);
            this.tags = tags;
        }

        private String serialize(boolean isRoot) {
            if (lemma != null) {
                return lemma;
            } else if (operands.size() == 1) {
                return String.format("%s %s", operator, operands.get(0).serialize(false));
            } else {
                var delimiter = String.format(" %s ", operator);
                var serialized = operands.stream().map(x -> x.serialize(false)).collect(Collectors.joining(delimiter));
                if (!isRoot) {
                    serialized = String.format("(%s)", serialized);
                }
                return serialized;
            }
        }

        @Override
        public String toString() {
            return serialize(true);
        }

        public Set<String> getTags() {
            var tags = new HashSet<String>();
            if (this.tags != null) {
                tags.addAll(this.tags);
            }
            if (operands != null) {
                for (var operand: operands) {
                    tags.addAll(operand.getTags());
                }
            }
            return tags;
        }
    }

    private static class Visitor extends WordRuleBaseVisitor<ExpressionNode> {
        private final Function<WordRuleParser.LemmaExprContext, ExpressionNode> lemmaVisitor;

        public Visitor(Function<WordRuleParser.LemmaExprContext, ExpressionNode> lemmaVisitor) {
            this.lemmaVisitor = lemmaVisitor;
        }

        @Override
        public ExpressionNode visitLemmaExpr(@NotNull WordRuleParser.LemmaExprContext ctx) {
            return lemmaVisitor.apply(ctx);
        }

        @Override
        public ExpressionNode visitParenthesesExpr(@NotNull WordRuleParser.ParenthesesExprContext ctx) {
            return this.visit(ctx.expr);
        }

        @Override
        public ExpressionNode visitUnaryExpr(@NotNull WordRuleParser.UnaryExprContext ctx) {
            return new ExpressionNode(
                ctx.op.getText(),
                this.visit(ctx.expr)
            );
        }

        @Override
        public ExpressionNode visitBinaryExpr(@NotNull WordRuleParser.BinaryExprContext ctx) {
            return new ExpressionNode(
                ctx.op.getText(),
                List.of(this.visit(ctx.left), this.visit(ctx.right))
            );
        }

        @Override
        public ExpressionNode visitRoot(@NotNull WordRuleParser.RootContext ctx) {
            return this.visit(ctx.expression());
        }
    }

    private static class ErrorListener extends BaseErrorListener {
        private final List<String> errors;

        public ErrorListener() {
            errors = new ArrayList<>();
        }

        @Override
        public void syntaxError(Recognizer recognizer,
                                Object offendingSymbol,
                                int line,
                                int charPositionInLine,
                                String msg,
                                RecognitionException e)
        {
            errors.add(String.format("position %d: %s", charPositionInLine, msg));
        }

        public String getFirstError() {
            return errors.get(0);
        }
    }

    private static class ReportingInLineBailErrorStrategy extends BailErrorStrategy {
        @Override
        public Token recoverInline(Parser recognizer)
                throws RecognitionException
        {
            try {
                return super.recoverInline(recognizer);
            } catch (ParseCancellationException e) {
                reportError(recognizer, (RecognitionException) e.getCause());
                throw e;
            }
        }
    }

    protected static class ParsingException extends Exception {
        public ParsingException(String error) {
            super(error);
        }
    }

    public ExpressionNode getNormalizedRule(final String wordRule) throws ParsingException {
        var listener = new ErrorListener();
        var strategy = new ReportingInLineBailErrorStrategy();
        try {
            var lexer = new WordRuleLexer(CharStreams.fromString(wordRule));
            lexer.removeErrorListeners();
            lexer.addErrorListener(listener);

            var tokens = new CommonTokenStream(lexer);

            var parser = new WordRuleParser(tokens);
            parser.setErrorHandler(strategy);
            parser.removeErrorListeners();
            parser.addErrorListener(listener);

            return new Visitor(this::visitLemmaExpr).visit(parser.root());
        } catch (ParseCancellationException e) {
            throw new ParsingException(listener.getFirstError());
        }
    }

    protected abstract ExpressionNode visitLemmaExpr(@NotNull WordRuleParser.LemmaExprContext ctx);
}
