package ru.yandex.qe.logging.turbo;

import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Pattern;

import ch.qos.logback.classic.Level;
import ch.qos.logback.classic.Logger;
import ch.qos.logback.classic.turbo.TurboFilter;
import ch.qos.logback.core.spi.FilterReply;
import org.slf4j.MDC;
import org.slf4j.Marker;

/**
 * Logback turbo filter driven by two MDC entries of the calling thread.
 *
 * <p>The entry stored under {@link #getThresholdKey()} names a level; events at or above that
 * level are answered with {@link FilterReply#ACCEPT}. The optional entry stored under
 * {@link #getPatternKey()} holds a regular expression; when present, only events whose logger
 * name contains a match for it are accepted. In every other case the filter answers
 * {@link FilterReply#NEUTRAL} and leaves the decision to the rest of the chain. The filter never
 * answers {@code DENY}.
 *
 * <p>The most recently compiled pattern is cached in a single slot. Threads that use different
 * patterns concurrently simply overwrite each other's entry, which costs a recompilation but does
 * not affect the reply.
 *
 * @author intr13
 */
public class MDCLogThresholdFilter extends TurboFilter {

    /** Default MDC key under which the threshold level name is looked up. */
    public static final String MDC_THRESHOLD_KEY = "LOG_THRESHOLD";

    /** Default MDC key under which the logger-name regular expression is looked up. */
    public static final String MDC_PATTERN_KEY = "LOG_PATTERN";

    private String thresholdKey = MDC_THRESHOLD_KEY;

    private String patternKey = MDC_PATTERN_KEY;

    private Level defaultThreshold = Level.ERROR;

    /**
     * Single-slot cache pairing the last seen raw pattern text with its compiled form (or with
     * {@code null} when that text failed to compile).
     */
    AtomicReference<PatternWithRaw> pattern = new AtomicReference<>();

    /**
     * @return the MDC key that is read to obtain the threshold level name
     */
    public String getThresholdKey() {
        return thresholdKey;
    }

    /**
     * @param thresholdKey the MDC key to read the threshold level name from; when {@code null}
     *                     the filter always replies {@link FilterReply#NEUTRAL}
     */
    public void setThresholdKey(String thresholdKey) {
        this.thresholdKey = thresholdKey;
    }

    /**
     * @return the MDC key that is read to obtain the logger-name regular expression
     */
    public String getPatternKey() {
        return patternKey;
    }

    /**
     * @param patternKey the MDC key to read the logger-name regular expression from; when
     *                   {@code null} no pattern restriction is applied
     */
    public void setPatternKey(String patternKey) {
        this.patternKey = patternKey;
    }

    /**
     * @return the level used when the MDC threshold value is not a recognised level name
     */
    public Level getDefaultThreshold() {
        return defaultThreshold;
    }

    /**
     * @param defaultThreshold the level to fall back to when the MDC threshold value cannot be
     *                         parsed; when {@code null}, unparsable values yield
     *                         {@link FilterReply#NEUTRAL}
     */
    public void setDefaultThreshold(Level defaultThreshold) {
        this.defaultThreshold = defaultThreshold;
    }

    /**
     * Decides whether the event should be force-accepted based on the current thread's MDC.
     *
     * @return {@link FilterReply#ACCEPT} when the filter is started, the MDC carries a threshold,
     *         {@code level} is at or above it, and either no pattern is present in the MDC or the
     *         pattern is found in the logger name; {@link FilterReply#NEUTRAL} otherwise
     */
    @Override
    public FilterReply decide(Marker marker, Logger logger, Level level, String format, Object[] params, Throwable t) {
        // MDC.get rejects null keys with IllegalArgumentException; a misconfigured key must not
        // make every logging call throw, so the filter simply abstains.
        if (!isStarted() || thresholdKey == null) {
            return FilterReply.NEUTRAL;
        }
        String rawThreshold = MDC.get(thresholdKey);
        if (rawThreshold == null || rawThreshold.isEmpty()) {
            return FilterReply.NEUTRAL;
        }
        // toLevel returns defaultThreshold for unrecognised names, which may have been set to null.
        Level threshold = Level.toLevel(rawThreshold, defaultThreshold);
        if (threshold == null || !level.isGreaterOrEqual(threshold)) {
            return FilterReply.NEUTRAL;
        }
        // A null pattern key is treated the same as an absent pattern: no name restriction.
        String rawPattern = patternKey == null ? null : MDC.get(patternKey);
        if (rawPattern == null || rawPattern.isEmpty()) {
            return FilterReply.ACCEPT;
        }
        Pattern loggerNamePattern = resolvePattern(rawPattern);
        if (loggerNamePattern != null && loggerNamePattern.matcher(logger.getName()).find()) {
            return FilterReply.ACCEPT;
        }
        return FilterReply.NEUTRAL;
    }

    /**
     * Returns the compiled form of {@code rawPattern}, reusing the cached entry when its raw text
     * is equal and replacing the cached entry otherwise.
     *
     * @return the compiled pattern, or {@code null} when {@code rawPattern} does not compile
     */
    private Pattern resolvePattern(String rawPattern) {
        PatternWithRaw cached = this.pattern.get();
        if (cached == null || !cached.getRaw().equals(rawPattern)) {
            // Failed compilations are cached too (as null) so a bad pattern is not recompiled
            // on every event.
            cached = new PatternWithRaw(rawPattern, safeCompilePattern(rawPattern));
            this.pattern.set(cached);
        }
        return cached.getPattern();
    }

    /**
     * Compiles {@code rawPattern}, returning {@code null} instead of throwing when compilation
     * fails.
     */
    private Pattern safeCompilePattern(String rawPattern) {
        try {
            return Pattern.compile(rawPattern);
        } catch (Exception ignored) {
            // Deliberately best-effort: the pattern text comes from the MDC and a malformed value
            // must not break the logging call; the caller treats null as "never matches".
            return null;
        }
    }

    /**
     * Immutable pair of a raw pattern string and its compiled form. Equality and hash code are
     * based on the raw string only.
     */
    static class PatternWithRaw {

        private final String raw;

        /** Compiled form of {@link #raw}; {@code null} when compilation failed. */
        private final Pattern pattern;

        public PatternWithRaw(String raw, Pattern pattern) {
            this.raw = raw;
            this.pattern = pattern;
        }

        public String getRaw() {
            return raw;
        }

        public Pattern getPattern() {
            return pattern;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            PatternWithRaw that = (PatternWithRaw) o;
            return raw.equals(that.raw);
        }

        @Override
        public int hashCode() {
            return raw.hashCode();
        }
    }
}
