package ru.yandex.sanitizer2;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;

import javax.annotation.Nonnull;

import com.helger.css.decl.CSSSelector;
import com.helger.css.decl.CSSStyleRule;
import com.helger.css.decl.ECSSSelectorCombinator;
import com.helger.css.decl.ICSSSelectorMember;

import ru.yandex.collection.IntList;
import ru.yandex.sanitizer2.config.ImmutableSanitizingConfig;

/**
 * {@link StyleProcessor} that applies CSS rules built from compound
 * selectors joined by descendant (" ") and child (">") combinators to the
 * stack of currently open tags.
 *
 * <p>For every open tag the processor records its stack position in
 * per-tag-name, per-id and per-class position lists. Selector matchers use
 * these lists (via binary search) to locate the nearest matching ancestor
 * instead of walking the whole path.
 *
 * <p>Instances hold mutable scratch state (path, index lists, temporary
 * buffers) without any synchronization, so an instance should be confined
 * to a single thread.
 */
public class ComplexStyleProcessor implements StyleProcessor {
    // Matcher "precision" ranks: higher values are evaluated first inside
    // conjunctions (see CssMatcher.precision()).
    private static final int PRECISION_STEP = 1024;
    private static final int ASTERISK_PRECISION = 0;
    private static final int ANY_PARENT_PRECISION =
        ASTERISK_PRECISION + PRECISION_STEP;
    private static final int DIRECT_PARENT_PRECISION =
        ANY_PARENT_PRECISION + PRECISION_STEP;
    private static final int TAG_PRECISION =
        DIRECT_PARENT_PRECISION + PRECISION_STEP;
    private static final int SHORT_TAG_PRECISION =
        TAG_PRECISION + PRECISION_STEP;
    private static final int CLASS_PRECISION =
        SHORT_TAG_PRECISION + PRECISION_STEP;
    private static final int ID_PRECISION = CLASS_PRECISION + PRECISION_STEP;
    private static final int INITIAL_STACK_CAPACITY = 16;
    private static final Function<String, IntList> INT_LIST_FACTORY =
        x -> new IntList(INITIAL_STACK_CAPACITY);
    // Single-character tag names up to this char are indexed in a flat array
    private static final char MAX_SHORT_TAG_CHAR = 'z';

    // List of all rules and their selectors
    private final List<CssRule> css = new ArrayList<>();
    // Temporary buffer for matched css rules
    private final List<CssRule> matchedRules = new ArrayList<>();
    private final BasicStyle tmpStyle = new BasicStyle();
    // Keyed by identity: tag names are expected to be interned by the config
    private final IdentityHashMap<String, IntList> tagsIndexes =
        new IdentityHashMap<>();
    private final IntList[] shortTagsIndexes =
        new IntList[MAX_SHORT_TAG_CHAR + 1];
    private final Map<String, IntList> idsIndexes = new HashMap<>();
    private final Map<String, IntList> classesIndexes = new HashMap<>();
    private final ImmutableSanitizingConfig config;
    // Stack of currently open tags; only the first pathSize slots are live
    private AbstractHtmlTag[] path =
        new AbstractHtmlTag[INITIAL_STACK_CAPACITY];
    private int pathSize = 0;
    // Monotonic counter giving each rule its declaration order
    private int rulesCounter = 0;
    private int minRuleDepth = Integer.MAX_VALUE;
    // Specificity accumulators filled by createMatcher() while
    // processStyleRule() builds a single rule
    private int idWeight = 0;
    private int classWeight = 0;
    private int tagWeight = 0;

    public ComplexStyleProcessor(final ImmutableSanitizingConfig config) {
        this.config = config;
    }

    // Returns (lazily creating) the position list for a one-char tag name
    private IntList computeShortTagIndexes(final char tagName) {
        IntList indexes = shortTagsIndexes[tagName];
        if (indexes == null) {
            indexes = new IntList(INITIAL_STACK_CAPACITY);
            shortTagsIndexes[tagName] = indexes;
        }
        return indexes;
    }

    // Returns the existing position list for a tag name, or null if none
    private IntList getTagIndexes(final String tagName) {
        if (tagName.length() == 1) {
            char c = tagName.charAt(0);
            if (c <= MAX_SHORT_TAG_CHAR) {
                return shortTagsIndexes[c];
            }
        }
        return tagsIndexes.get(tagName);
    }

    // Returns (lazily creating) the position list for a tag name
    private IntList computeTagIndexes(final String tagName) {
        if (tagName.length() == 1) {
            char c = tagName.charAt(0);
            if (c <= MAX_SHORT_TAG_CHAR) {
                return computeShortTagIndexes(c);
            }
        }
        return tagsIndexes.computeIfAbsent(tagName, INT_LIST_FACTORY);
    }

    /**
     * Pushes {@code tag} onto the open-tag path and records its position in
     * the tag name, id and class position lists.
     *
     * @param tag tag being opened
     */
    @Override
    public void pushTag(@Nonnull final AbstractHtmlTag tag) {
        computeTagIndexes(tag.tagName()).addInt(pathSize);
        String id = tag.id();
        if (id != null) {
            idsIndexes.computeIfAbsent(id, INT_LIST_FACTORY).addInt(pathSize);
        }
        List<String> classes = tag.classes();
        int size = classes.size();
        for (int i = 0; i < size; ++i) {
            classesIndexes.computeIfAbsent(classes.get(i), INT_LIST_FACTORY)
                .addInt(pathSize);
        }
        if (pathSize == path.length) {
            path = Arrays.copyOf(path, pathSize << 1);
        }
        path[pathSize++] = tag;
    }

    /**
     * Computes the style of {@code tag} (the top of the path) by merging all
     * matching rules in ascending specificity order, followed by the tag's own
     * inline style, and stores the result back into the tag. The tag is left
     * untouched if no rule matches.
     *
     * @param tag tag currently on top of the path
     */
    @Override
    public void applyStyle(@Nonnull final AbstractHtmlTag tag) {
        if (minRuleDepth > pathSize) {
            // Path is too shallow for any rule to match
            return;
        }
        matchedRules.clear();
        int cssSize = css.size();
        int pos = pathSize - 1;
        for (int i = 0; i < cssSize; ++i) {
            CssRule rule = css.get(i);
            if (rule.depth <= pathSize
                && rule.matcher.matches(path, pos))
            {
                matchedRules.add(rule);
            }
        }

        int matchedRulesSize = matchedRules.size();
        if (matchedRulesSize > 0) {
            matchedRules.sort(CssRuleComparator.INSTANCE);
            tmpStyle.clear();
            for (int i = 0; i < matchedRulesSize; ++i) {
                tmpStyle.merge(matchedRules.get(i).style);
            }
            tmpStyle.merge(tag.style());
            Style calculateStyle = tmpStyle.compact();
            if (calculateStyle == tmpStyle) {
                // Not compacted
                // Copy tmpStyle, so it won't be externally modified
                calculateStyle = new BasicStyle(calculateStyle);
            }
            tag.style(calculateStyle);
        }
    }

    /**
     * Pops the top tag from the path and removes its position from the tag
     * name, id and class position lists.
     */
    @Override
    public void popTag() {
        AbstractHtmlTag tag = path[--pathSize];
        // Clear the slot so the closed tag is not kept reachable by the stack
        path[pathSize] = null;
        getTagIndexes(tag.tagName()).removeLast();
        String id = tag.id();
        if (id != null) {
            idsIndexes.get(id).removeLast();
        }
        List<String> classes = tag.classes();
        int size = classes.size();
        for (int i = size; i-- > 0;) {
            classesIndexes.get(classes.get(i)).removeLast();
        }
    }

    // Registers a rule and tracks the smallest depth any rule requires
    private void addCssRule(final CssRule rule) {
        css.add(rule);
        if (rule.depth < minRuleDepth) {
            minRuleDepth = rule.depth;
        }
    }

    /**
     * Adds a rule for a simple selector of the form {@code tag.class#id},
     * where every part is optional ({@code null} means absent; all absent is
     * equivalent to {@code *}).
     *
     * <p>Tag name, class name, and id name expected to be sanitized, but not
     * obfuscated yet.
     *
     * @param tagName tag name or {@code null}
     * @param className class name or {@code null}
     * @param idName id or {@code null}
     * @param declarations declarations applied when the selector matches
     */
    public void processSimpleStyleRule(
        final String tagName,
        final String className,
        final String idName,
        final List<CssDeclaration> declarations)
    {
        // Locals deliberately not named like the weight fields to avoid
        // shadowing them
        int ruleIdWeight = 0;
        int ruleClassWeight = 0;
        int ruleTagWeight = 0;
        CssMatcher matcher = CssAsteriskMatcher.INSTANCE;
        if (tagName != null) {
            String internedTagName = config.internTag(tagName);
            if (internedTagName == null) {
                internedTagName = tagName;
            }
            if (internedTagName.length() == 1) {
                IntList tagIndexes;
                char c = internedTagName.charAt(0);
                if (c <= MAX_SHORT_TAG_CHAR) {
                    tagIndexes = computeShortTagIndexes(c);
                } else {
                    // Impossible, but just in case
                    tagIndexes = tagsIndexes.computeIfAbsent(
                        internedTagName,
                        INT_LIST_FACTORY);
                }
                matcher = new CssShortTagMatcher(c, tagIndexes);
            } else {
                matcher =
                    new CssTagMatcher(
                        internedTagName,
                        tagsIndexes.computeIfAbsent(
                            internedTagName,
                            INT_LIST_FACTORY));
            }
            ruleTagWeight = 1;
        }
        if (className != null) {
            matcher =
                matcher.and(
                    new CssClassMatcher(
                        classesIndexes.computeIfAbsent(
                            className,
                            INT_LIST_FACTORY)));
            ruleClassWeight = 1;
        }
        if (idName != null) {
            matcher =
                matcher.and(
                    new CssIdMatcher(
                        idName,
                        idsIndexes.computeIfAbsent(idName, INT_LIST_FACTORY)));
            ruleIdWeight = 1;
        }
        addCssRule(
            new CssRule(
                matcher,
                declarations,
                ruleIdWeight,
                ruleClassWeight,
                ruleTagWeight,
                rulesCounter++));
    }

    /**
     * Builds a matcher for one compound selector, i.e. the selector members
     * in {@code [start, end)} between two combinators, and adds the
     * specificity of each member to the idWeight/classWeight/tagWeight fields.
     *
     * <p>Members that are neither {@code *}, {@code .class} nor {@code #id}
     * are treated as tag names. NOTE(review): this includes pseudo-classes
     * and attribute selectors; their names are presumably never produced by
     * internTag, so such rules should never match — confirm.
     */
    private CssMatcher createMatcher(
        final CSSSelector selector,
        final int start,
        final int end)
    {
        CssMatcher matcher = CssAsteriskMatcher.INSTANCE;
        for (int i = start; i < end; ++i) {
            String member = selector.getMemberAtIndex(i).getAsCSSString();
            switch (member.length()) {
                case 0:
                    break;
                case 1:
                    char c = member.charAt(0);
                    if (c != '*') {
                        IntList tagIndexes;
                        c = Character.toLowerCase(c);
                        if (c <= MAX_SHORT_TAG_CHAR) {
                            tagIndexes = computeShortTagIndexes(c);
                        } else {
                            // Impossible, but just in case
                            String rawTagName = Character.toString(c);
                            String tagName = config.internTag(rawTagName);
                            if (tagName == null) {
                                // Same fallback as the multi-char branch, so
                                // the identity map never gets a null key
                                tagName = rawTagName;
                            }
                            tagIndexes = tagsIndexes.computeIfAbsent(
                                tagName,
                                INT_LIST_FACTORY);
                        }
                        matcher =
                            matcher.and(new CssShortTagMatcher(c, tagIndexes));
                        ++tagWeight;
                    }
                    break;
                default:
                    switch (member.charAt(0)) {
                        case '.':
                            member = member.substring(1);
                            matcher =
                                matcher.and(
                                    new CssClassMatcher(
                                        classesIndexes.computeIfAbsent(
                                            member,
                                            INT_LIST_FACTORY)));
                            ++classWeight;
                            break;
                        case '#':
                            member = member.substring(1);
                            matcher =
                                matcher.and(
                                    new CssIdMatcher(
                                        member,
                                        idsIndexes.computeIfAbsent(
                                            member,
                                            INT_LIST_FACTORY)));
                            ++idWeight;
                            break;
                        default:
                            member = member.toLowerCase(Locale.ROOT);
                            String tagName = config.internTag(member);
                            if (tagName == null) {
                                tagName = member;
                            }
                            matcher =
                                matcher.and(
                                    new CssTagMatcher(
                                        tagName,
                                        tagsIndexes.computeIfAbsent(
                                            tagName,
                                            INT_LIST_FACTORY)));
                            ++tagWeight;
                            break;
                    }
                    break;
            }
        }
        return matcher;
    }

    /**
     * Adds a rule for a complex selector. Only the descendant (blank) and
     * child ({@code >}) combinators are supported; selectors using any other
     * combinator are silently ignored.
     *
     * @param selector parsed selector
     * @param declarations declarations applied when the selector matches
     * @param context sanitizing context (not used here)
     */
    public void processStyleRule(
        final CSSSelector selector,
        final List<CssDeclaration> declarations,
        final SanitizingContext context)
    {
        idWeight = 0;
        classWeight = 0;
        tagWeight = 0;
        CssMatcher matcher = CssAsteriskMatcher.INSTANCE;
        int membersCount = selector.getMemberCount();
        int prev = 0;
        for (int i = 0; i < membersCount; ++i) {
            ICSSSelectorMember member = selector.getMemberAtIndex(i);
            if (member instanceof ECSSSelectorCombinator) {
                // Everything matched so far must hold for an ancestor of the
                // element matched by the remaining part of the selector
                matcher = matcher.and(createMatcher(selector, prev, i));
                prev = i + 1;
                switch ((ECSSSelectorCombinator) member) {
                    case BLANK:
                        matcher = new CssAnyParentMatcher(matcher);
                        break;
                    case GREATER:
                        matcher = CssDirectParentMatcher.create(matcher);
                        break;
                    default:
                        // ignore unsupported selector
                        return;
                }
            }
        }
        matcher = matcher.and(createMatcher(selector, prev, membersCount));
        addCssRule(
            new CssRule(
                matcher,
                declarations,
                idWeight,
                classWeight,
                tagWeight,
                rulesCounter++));
    }

    /**
     * Converts the declarations of {@code styleRule} once and registers one
     * rule per selector, all sharing the same declaration list.
     *
     * @return always {@code true}
     */
    @Override
    public boolean processStyleRule(
        final CSSStyleRule styleRule,
        final SanitizingContext context)
    {
        List<CssDeclaration> declarations =
            new ArrayList<>(styleRule.getDeclarationCount());
        BasicStyle.convertDeclarations(
            declarations,
            config,
            styleRule,
            context);
        int selectorsCount = styleRule.getSelectorCount();
        for (int i = 0; i < selectorsCount; ++i) {
            processStyleRule(
                styleRule.getSelectorAtIndex(i),
                declarations,
                context);
        }
        return true;
    }

    /** Not supported by this processor. */
    @Override
    public StyleProcessor upgrade(final ImmutableSanitizingConfig config) {
        throw new UnsupportedOperationException();
    }

    private interface CssMatcher {
        // Checks whether the tag at tags[pos] satisfies this matcher
        boolean matches(final AbstractHtmlTag[] tags, int pos);

        // Higher value means higher probability of being "false", so in "and"
        // statements this matcher should be evaluated first
        int precision();

        // Minimal number of tags required for this rule to match current tag.
        // NOTE(review): single-element matchers report 2, so position 0 is
        // never the styled element — presumably a synthetic root; confirm.
        int depth();

        // find max position <= pos which is matching this matcher
        // return -1 if no matching tag found
        int findMatchingPos(AbstractHtmlTag[] tags, int pos);

        default CssMatcher and(CssMatcher matcher) {
            if (matcher == CssAsteriskMatcher.INSTANCE) {
                return this;
            } else {
                return new CssAndMatcher(this, matcher);
            }
        }
    }

    // Universal selector "*": matches every position
    private enum CssAsteriskMatcher implements CssMatcher {
        INSTANCE;

        @Override
        public boolean matches(final AbstractHtmlTag[] tags, int pos) {
            return true;
        }

        @Override
        public int precision() {
            return ASTERISK_PRECISION;
        }

        @Override
        public int depth() {
            return 2;
        }

        @Override
        public int findMatchingPos(
            final AbstractHtmlTag[] tags,
            final int pos)
        {
            return pos;
        }

        @Override
        public CssMatcher and(final CssMatcher matcher) {
            return matcher;
        }
    }

    // Descendant combinator: matches at pos if the wrapped matcher matches
    // some position at or below pos - offset - 1. A positive offset comes
    // from folding nested descendant/child combinators such as "A * B".
    private static class CssAnyParentMatcher implements CssMatcher {
        private final CssMatcher matcher;
        private final int matcherDepth;
        private final int offset;
        private final int depth;

        CssAnyParentMatcher(final CssMatcher matcher) {
            if (matcher instanceof CssAnyParentMatcher) {
                // Fold "X <any> <any>" into a single matcher with offset + 1
                CssAnyParentMatcher anyMatcher = (CssAnyParentMatcher) matcher;
                this.matcher = anyMatcher.matcher;
                matcherDepth = anyMatcher.matcherDepth;
                offset = anyMatcher.offset + 1;
            } else {
                this.matcher = matcher;
                matcherDepth = matcher.depth();
                offset = 0;
            }
            depth = matcherDepth + offset + 1;
        }

        @Override
        public boolean matches(final AbstractHtmlTag[] tags, int pos) {
            return findMatchingPos(tags, pos) >= 0;
        }

        @Override
        public int precision() {
            return ANY_PARENT_PRECISION;
        }

        @Override
        public int depth() {
            return depth;
        }

        @Override
        public int findMatchingPos(
            final AbstractHtmlTag[] tags,
            final int pos)
        {
            int start = pos - offset;
            // Matching is monotone in pos: if it holds at pos it holds at
            // every greater position, and if it fails at pos it fails at
            // every lower one. So the max matching position <= pos is either
            // pos itself or none. Returning start here would shift the
            // answer by offset and break CssAndMatcher's position agreement.
            if (start >= matcherDepth
                && matcher.findMatchingPos(tags, start - 1) >= 0)
            {
                return pos;
            } else {
                return -1;
            }
        }
    }

    // Child combinator: matches at pos if the wrapped matcher matches pos - 1
    private static class CssDirectParentMatcher implements CssMatcher {
        private final CssMatcher matcher;
        private final int matcherDepth;
        private final int depth;

        CssDirectParentMatcher(final CssMatcher matcher) {
            this.matcher = matcher;
            matcherDepth = matcher.depth();
            depth = matcherDepth + 1;
        }

        public static CssMatcher create(final CssMatcher matcher) {
            if (matcher instanceof CssAnyParentMatcher) {
                // CssAnyParentMatcher folding will do everything we need
                return new CssAnyParentMatcher(matcher);
            } else {
                return new CssDirectParentMatcher(matcher);
            }
        }

        @Override
        public boolean matches(final AbstractHtmlTag[] tags, int pos) {
            return pos >= matcherDepth && matcher.matches(tags, pos - 1);
        }

        @Override
        public int precision() {
            // Ranked just below the wrapped matcher, so in a conjunction
            // the wrapped matcher's kind is evaluated on the current tag
            // first
            return matcher.precision() - 1;
        }

        @Override
        public int depth() {
            return depth;
        }

        @Override
        public int findMatchingPos(
            final AbstractHtmlTag[] tags,
            final int pos)
        {
            if (pos <= 0) {
                return -1;
            }
            // This matcher holds at p iff the wrapped matcher holds at p - 1,
            // so the max matching p <= pos is the wrapped matcher's nearest
            // match at or below pos - 1, shifted by one. Checking only pos
            // itself would violate the findMatchingPos contract and make
            // CssAndMatcher give up too early.
            int parentPos = matcher.findMatchingPos(tags, pos - 1);
            if (parentPos < 0) {
                return -1;
            } else {
                return parentPos + 1;
            }
        }
    }

    // Base for matchers backed by a sorted list of path positions that is
    // maintained by pushTag()/popTag()
    private static abstract class CssIndexesMatcher implements CssMatcher {
        protected final IntList indexes;

        CssIndexesMatcher(final IntList indexes) {
            this.indexes = indexes;
        }

        @Override
        public int depth() {
            return 2;
        }

        @Override
        public int findMatchingPos(
            final AbstractHtmlTag[] tags,
            final int pos)
        {
            if (indexes.isEmpty()) {
                return -1;
            } else {
                int end = indexes.binarySearch(pos);
                if (end >= 0) {
                    return indexes.getInt(end);
                } else {
                    // Decode insertion point; the element before it is the
                    // greatest recorded position below pos
                    end = -end;
                    --end;
                    if (end > 0) {
                        return indexes.getInt(end - 1);
                    } else {
                        return -1;
                    }
                }
            }
        }
    }

    // Tag name selector; relies on tag names being interned (identity check)
    private static class CssTagMatcher extends CssIndexesMatcher {
        private final String tagName;

        CssTagMatcher(final String tagName, final IntList tagsIndexes) {
            super(tagsIndexes);
            this.tagName = tagName;
        }

        @SuppressWarnings("ReferenceEquality")
        @Override
        public boolean matches(final AbstractHtmlTag[] tags, int pos) {
            return tagName == tags[pos].tagName();
        }

        @Override
        public int precision() {
            return TAG_PRECISION;
        }
    }

    // Tag name selector for one-character tag names
    private static class CssShortTagMatcher extends CssIndexesMatcher {
        private final char tagName;

        CssShortTagMatcher(final char tagName, final IntList tagsIndexes) {
            super(tagsIndexes);
            this.tagName = tagName;
        }

        @Override
        public boolean matches(final AbstractHtmlTag[] tags, int pos) {
            String tagName = tags[pos].tagName();
            return tagName.length() == 1 && tagName.charAt(0) == this.tagName;
        }

        @Override
        public int precision() {
            return TAG_PRECISION;
        }
    }

    // ".class" selector; a tag matches if its position is in the class list
    private static class CssClassMatcher extends CssIndexesMatcher {
        CssClassMatcher(final IntList classesIndexes) {
            super(classesIndexes);
        }

        @Override
        public boolean matches(
            final AbstractHtmlTag[] tags,
            final int pos)
        {
            return indexes.binarySearch(pos) >= 0;
        }

        @Override
        public int precision() {
            return CLASS_PRECISION;
        }
    }

    // "#id" selector
    private static class CssIdMatcher extends CssIndexesMatcher {
        private final String id;

        CssIdMatcher(final String id, final IntList idsIndexes) {
            super(idsIndexes);
            this.id = id;
        }

        @Override
        public boolean matches(final AbstractHtmlTag[] tags, int pos) {
            return id.equals(tags[pos].id());
        }

        @Override
        public int precision() {
            return ID_PRECISION;
        }
    }

    // Conjunction of two matchers evaluated at the same position
    private static class CssAndMatcher implements CssMatcher {
        private final CssMatcher first;
        private final CssMatcher second;
        private final int depth;

        CssAndMatcher(final CssMatcher first, final CssMatcher second) {
            // Select most precise matcher to be checked first
            if (first.precision() >= second.precision()) {
                this.first = first;
                this.second = second;
            } else {
                this.first = second;
                this.second = first;
            }
            depth = Math.max(first.depth(), second.depth());
        }

        @Override
        public boolean matches(final AbstractHtmlTag[] tags, int pos) {
            return first.matches(tags, pos) && second.matches(tags, pos);
        }

        @Override
        public int precision() {
            // Decrement by 2, so CssIdMatcher will be checked before
            // CssAndMatcher matcher containing id matcher
            // Also, CssDirectParentMatcher is more precise than CssAndMatcher
            return first.precision() - 2;
        }

        @Override
        public int depth() {
            return depth;
        }

        @Override
        public int findMatchingPos(
            final AbstractHtmlTag[] tags,
            final int pos)
        {
            // Alternately step each matcher down to the other's candidate
            // until both agree; positions strictly decrease, so this ends
            int firstPos = first.findMatchingPos(tags, pos);
            while (firstPos >= 0) {
                int secondPos = second.findMatchingPos(tags, firstPos);
                if (secondPos < 0) {
                    return -1;
                }
                if (secondPos == firstPos) {
                    return firstPos;
                }
                firstPos = first.findMatchingPos(tags, secondPos);
            }
            return -1;
        }

        @Override
        public CssMatcher and(final CssMatcher matcher) {
            if (matcher == CssAsteriskMatcher.INSTANCE) {
                // "*" adds no constraint, same as the default and()
                return this;
            }
            // Insert the new matcher keeping descending precision order
            if (first.precision() < matcher.precision()) {
                return new CssTripleAndMatcher(matcher, first, second);
            } else if (second.precision() < matcher.precision()) {
                return new CssTripleAndMatcher(first, matcher, second);
            } else {
                return new CssTripleAndMatcher(first, second, matcher);
            }
        }
    }

    // Conjunction of three matchers evaluated at the same position
    private static class CssTripleAndMatcher implements CssMatcher {
        private final CssMatcher first;
        private final CssMatcher second;
        private final CssMatcher third;
        private final int depth;

        // Matchers already ordered by precision
        CssTripleAndMatcher(
            final CssMatcher first,
            final CssMatcher second,
            final CssMatcher third)
        {
            this.first = first;
            this.second = second;
            this.third = third;
            depth =
                Math.max(
                    first.depth(),
                    Math.max(second.depth(), third.depth()));
        }

        @Override
        public int precision() {
            return first.precision() - 2;
        }

        @Override
        public int depth() {
            return depth;
        }

        @Override
        public boolean matches(
            final AbstractHtmlTag[] tags,
            final int pos)
        {
            return first.matches(tags, pos)
                && second.matches(tags, pos)
                && third.matches(tags, pos);
        }

        @Override
        public int findMatchingPos(
            final AbstractHtmlTag[] tags,
            final int pos)
        {
            int firstPos = first.findMatchingPos(tags, pos);
            while (firstPos >= 0) {
                int secondPos = second.findMatchingPos(tags, firstPos);
                if (secondPos < 0) {
                    return -1;
                }
                if (secondPos == firstPos) {
                    int thirdPos = third.findMatchingPos(tags, secondPos);
                    if (thirdPos < 0) {
                        return -1;
                    }
                    if (thirdPos == firstPos) {
                        return firstPos;
                    }
                    firstPos = first.findMatchingPos(tags, thirdPos);
                } else {
                    firstPos = first.findMatchingPos(tags, secondPos);
                }
            }
            return -1;
        }

        @Override
        public CssMatcher and(final CssMatcher matcher) {
            if (matcher == CssAsteriskMatcher.INSTANCE) {
                return this;
            }
            // CssAndMatcher will sort things out
            if (second.precision() < matcher.precision()) {
                return new CssAndMatcher(
                    new CssAndMatcher(first, matcher),
                    new CssAndMatcher(second, third));
            } else {
                return new CssAndMatcher(
                    new CssAndMatcher(first, second),
                    new CssAndMatcher(third, matcher));
            }
        }
    }

    // A compiled selector together with its declarations and specificity
    private static class CssRule {
        private final CssMatcher matcher;
        private final List<CssDeclaration> style;
        private final int idWeight;
        private final int classWeight;
        private final int tagWeight;
        // Declaration order, used as the final tie-breaker
        private final int id;
        private final int depth;

        CssRule(
            final CssMatcher matcher,
            final List<CssDeclaration> style,
            final int idWeight,
            final int classWeight,
            final int tagWeight,
            final int id)
        {
            this.matcher = matcher;
            this.style = style;
            this.idWeight = idWeight;
            this.classWeight = classWeight;
            this.tagWeight = tagWeight;
            this.id = id;
            depth = matcher.depth();
        }
    }

    // Orders rules by ascending (id, class, tag) specificity, then by
    // declaration order. applyStyle() merges in this order; presumably
    // BasicStyle.merge lets later declarations win — confirm.
    private enum CssRuleComparator implements Comparator<CssRule> {
        INSTANCE;

        @Override
        public int compare(final CssRule lhs, final CssRule rhs) {
            int cmp = Integer.compare(lhs.idWeight, rhs.idWeight);
            if (cmp == 0) {
                cmp = Integer.compare(lhs.classWeight, rhs.classWeight);
                if (cmp == 0) {
                    cmp = Integer.compare(lhs.tagWeight, rhs.tagWeight);
                    if (cmp == 0) {
                        cmp = Integer.compare(lhs.id, rhs.id);
                    }
                }
            }
            return cmp;
        }
    }
}

