package ru.yandex.partner.libs.multistate.expression;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import javax.annotation.Nonnull;

import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.misc.ParseCancellationException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.MessageSource;
import org.springframework.context.MessageSourceAware;
import org.springframework.context.support.DelegatingMessageSource;
import org.springframework.context.support.MessageSourceAccessor;

import ru.yandex.partner.core.multistate.Multistate;
import ru.yandex.partner.core.multistate.StateFlag;
import ru.yandex.partner.lib.multistate.parser.MultistatePredicateBaseVisitor;
import ru.yandex.partner.lib.multistate.parser.MultistatePredicateLexer;
import ru.yandex.partner.lib.multistate.parser.MultistatePredicateParser;
import ru.yandex.partner.libs.i18n.MsgWithArgs;
import ru.yandex.partner.libs.multistate.MultistatePredicates;
import ru.yandex.partner.libs.multistate.expression.antlr.ThrowingErrorListener;
import ru.yandex.partner.libs.multistate.messages.ExpressionParsingMsg;

/**
 * Compiles textual multistate expressions into {@link Predicate}s over {@link Multistate}.
 *
 * <p>The text is parsed with the ANTLR-generated {@link MultistatePredicateParser}. The parse tree is
 * then translated by an internal visitor that handles parenthesised sub-expressions, negation,
 * conjunction, disjunction and state identifiers. The valid identifiers are:
 * <ul>
 *   <li>the lower-cased names of the non-private constants of the state flag enum, each mapped to
 *       {@code MultistatePredicates.has(flag)};</li>
 *   <li>the special identifier {@code __EMPTY__}, mapped to {@code MultistatePredicates.empty()}.</li>
 * </ul>
 *
 * <p>User-facing error texts are resolved through the {@link MessageSource} injected via
 * {@link #setMessageSource(MessageSource)}. Until then, a {@link DelegatingMessageSource} without a
 * parent is used.
 *
 * @param <T> the state flag type; must be an enum
 */
public class MultistateExpressionParser<T extends StateFlag> implements MessageSourceAware {

    private static final Logger LOGGER = LoggerFactory.getLogger(MultistateExpressionParser.class);

    /** Error listener installed on both the lexer and the parser created by {@link #getParser(String)}. */
    public static final ThrowingErrorListener ERROR_LISTENER = new ThrowingErrorListener();

    private final Visitor visitor;

    private MessageSourceAccessor messages = new MessageSourceAccessor(new DelegatingMessageSource());

    /**
     * Creates a parser whose identifiers are taken from the constants of {@code stateFlagClass}.
     *
     * @param stateFlagClass the enum class of state flags
     * @throws IllegalArgumentException if {@code stateFlagClass} is not an enum
     */
    public MultistateExpressionParser(Class<T> stateFlagClass) {
        this.visitor = new Visitor(stateFlagClass);
    }

    /**
     * Replaces the message source used to build localized error messages.
     *
     * @param messageSource the message source supplied by Spring
     */
    @Override
    public void setMessageSource(@Nonnull MessageSource messageSource) {
        messages = new MessageSourceAccessor(messageSource);
    }

    /**
     * Parses {@code expression} into a predicate over multistates.
     *
     * @param expression the expression text
     * @return the predicate described by the expression
     * @throws ExpressionParsingException if lexing or parsing is aborted with a
     *         {@link ParseCancellationException} (kept as the cause), or if the expression references
     *         an unknown state name
     */
    public Predicate<Multistate<T>> parseExpression(String expression) throws IllegalArgumentException,
            ParseCancellationException {
        try {
            MultistatePredicateParser parser = getParser(expression);
            return visitor.visit(parser.parse());
        } catch (ParseCancellationException e) {
            LOGGER.debug("Exception while parsing expression {}: {}", expression, e.getMessage());
            throw new ExpressionParsingException(
                    messages.getMessage(MsgWithArgs.of(ExpressionParsingMsg.SYNTAX_ERROR_IN_EXPRESSION, expression)),
                    e);
        }
    }

    /**
     * Builds a lexer and parser for {@code expression}, both reporting errors to {@link #ERROR_LISTENER}.
     *
     * @param expression the expression text
     * @return a fresh parser positioned at the start of the expression
     */
    public MultistatePredicateParser getParser(String expression) {
        MultistatePredicateLexer lexer = new MultistatePredicateLexer(CharStreams.fromString(expression));
        // ANTLR lexers default to a console listener that prints unrecognised characters to stderr
        // and drops them. Route lexer errors to the same listener as the parser so that invalid
        // input is rejected instead of being silently skipped.
        lexer.removeErrorListeners();
        lexer.addErrorListener(ERROR_LISTENER);
        MultistatePredicateParser parser = new MultistatePredicateParser(new CommonTokenStream(lexer));
        parser.removeErrorListeners();
        parser.addErrorListener(ERROR_LISTENER);
        return parser;
    }


    /**
     * Translates a parse tree into a composed predicate.
     *
     * <p>This is a non-static inner class because it reads the enclosing instance's {@code messages}
     * when reporting unknown identifiers. It reuses the outer type parameter {@code T} instead of
     * declaring (and shadowing) its own.
     */
    private class Visitor extends MultistatePredicateBaseVisitor<Predicate<Multistate<T>>> {

        public static final String EMPTY_PARAMETER = "__EMPTY__";

        /** Identifier -> predicate table. It is immutable after construction because the visitor is shared. */
        private final Map<String, Predicate<Multistate<T>>> statePredicates;

        Visitor(Class<T> stateFlagClass) {
            if (!stateFlagClass.isEnum()) {
                throw new IllegalArgumentException("stateFlagClass must be enum");
            }
            Map<String, Predicate<Multistate<T>>> predicates = new HashMap<>();
            predicates.put(EMPTY_PARAMETER, MultistatePredicates.empty());
            predicates.putAll(
                    Arrays.stream(stateFlagClass.getEnumConstants())
                            .filter(flag -> !flag.isPrivate())
                            // Locale.ROOT: the default locale could change letters (e.g. 'I' in Turkish).
                            .collect(Collectors.toMap(
                                    flag -> ((Enum<?>) flag).name().toLowerCase(Locale.ROOT),
                                    MultistatePredicates::has))
            );
            statePredicates = Map.copyOf(predicates);
        }

        @Override
        public Predicate<Multistate<T>> visitParse(MultistatePredicateParser.ParseContext ctx) {
            return visit(ctx.expression());
        }

        @Override
        public Predicate<Multistate<T>> visitParenExpression(MultistatePredicateParser.ParenExpressionContext ctx) {
            return visit(ctx.expression());
        }

        /**
         * Resolves an identifier to its predicate.
         *
         * @throws ExpressionParsingException if the identifier is not a known state name
         */
        @Override
        public Predicate<Multistate<T>> visitIdentifierExpression(
                MultistatePredicateParser.IdentifierExpressionContext ctx) {
            String stateName = ctx.IDENTIFIER().getText();
            // Map.copyOf forbids null values, so a null result means the key is absent.
            Predicate<Multistate<T>> predicate = statePredicates.get(stateName);
            if (predicate == null) {
                throw new ExpressionParsingException(
                        messages.getMessage(MsgWithArgs.of(ExpressionParsingMsg.STATUS_DOES_NOT_EXIST, stateName))
                );
            }
            return predicate;
        }

        @Override
        public Predicate<Multistate<T>> visitNotExpression(MultistatePredicateParser.NotExpressionContext ctx) {
            return Predicate.not(visit(ctx.expression()));
        }

        @Override
        public Predicate<Multistate<T>> visitAndExpression(MultistatePredicateParser.AndExpressionContext ctx) {
            return visit(ctx.left).and(visit(ctx.right));
        }

        @Override
        public Predicate<Multistate<T>> visitOrExpression(MultistatePredicateParser.OrExpressionContext ctx) {
            return visit(ctx.left).or(visit(ctx.right));
        }

    }
}
