package ru.yandex.sanitizer2;

import java.io.IOException;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.Charset;
import java.nio.charset.CodingErrorAction;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Consumer;
import java.util.logging.Logger;

import org.apache.http.HttpEntity;
import org.apache.http.HttpException;
import org.apache.http.HttpRequest;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.mime.FormBodyPartBuilder;
import org.apache.http.entity.mime.MultipartEntityBuilder;
import org.apache.http.entity.mime.content.ByteArrayBody;
import org.apache.http.entity.mime.content.StringBody;
import org.apache.http.nio.entity.NByteArrayEntity;
import org.apache.http.nio.protocol.HttpAsyncExchange;
import org.apache.http.nio.protocol.HttpAsyncRequestConsumer;
import org.apache.http.nio.protocol.HttpAsyncRequestHandler;
import org.apache.http.protocol.HttpContext;
import org.owasp.html.Encoding;
import org.owasp.html.HtmlLexer;
import org.owasp.html.HtmlSanitizer;
import org.owasp.html.HtmlStreamRenderer;
import org.owasp.html.HtmlTextEscapingMode;
import org.owasp.html.HtmlToken;
import org.owasp.html.HtmlTokenType;
import org.owasp.html.PolicyFactory;
import org.owasp.html.TagBalancingHtmlStreamEventReceiver;

import ru.yandex.charset.Encoder;
import ru.yandex.charset.FilterCharset;
import ru.yandex.function.ByteArrayProcessor;
import ru.yandex.function.CharArrayProcessable;
import ru.yandex.function.CharArrayProcessor;
import ru.yandex.function.GenericAutoCloseable;
import ru.yandex.http.util.CharsetUtils;
import ru.yandex.http.util.nio.AsyncCharArrayProcessableConsumer;
import ru.yandex.http.util.nio.NByteArrayEntityFactory;
import ru.yandex.io.DecodableByteArrayOutputStream;
import ru.yandex.json.writer.JsonType;
import ru.yandex.parser.uri.QueryParser;
import ru.yandex.parser.uri.ScanningCgiParams;
import ru.yandex.sanitizer2.config.ImmutablePageHeaderConfig;
import ru.yandex.sanitizer2.config.ImmutableSanitizingConfig;
import ru.yandex.sanitizer2.config.ImmutableTagConfig;
import ru.yandex.util.timesource.TimeSource;

public class SanitizingHandler
    implements GenericAutoCloseable<RuntimeException>,
        HttpAsyncRequestHandler<CharArrayProcessable>
{
    private static final String MS = " ms";

    private final ImmutableSanitizingConfig config;
    private final PolicyFactory policy;
    private final PageHeaderApplier pageHeaderApplier;
    private final Runnable phishingLinksCallback;

    /**
     * Creates a handler bound to the given configuration.
     *
     * <p>The sanitizing policy is created from {@code config}. Every entry
     * of {@code config.pageHeaders()} is registered, through its type, in a
     * {@code MultiPageHeaderApplier}, and the compacted form of that applier
     * is kept for request processing.
     *
     * @param config sanitizing configuration used by this handler
     * @param pageHeadersAccountant passed as is to the
     *     {@code MultiPageHeaderApplier} constructor
     * @param phishingLinksCallback callback run by {@code handle} when a
     *     collected URL has the phishing type
     * @throws PageHeaderException if the page header applier cannot be built
     */
    public SanitizingHandler(
        final ImmutableSanitizingConfig config,
        final Consumer<String> pageHeadersAccountant,
        final Runnable phishingLinksCallback)
        throws PageHeaderException
    {
        this.config = config;
        this.policy = config.createPolicyFactory();
        this.phishingLinksCallback = phishingLinksCallback;

        MultiPageHeaderApplier applier =
            new MultiPageHeaderApplier(pageHeadersAccountant);
        for (Map.Entry<String, ImmutablePageHeaderConfig> header
            : config.pageHeaders().entrySet())
        {
            String headerName = header.getKey();
            ImmutablePageHeaderConfig headerConfig = header.getValue();
            headerConfig.type().addPredicate(
                headerName,
                headerConfig,
                applier);
        }
        this.pageHeaderApplier = applier.compact();
    }

    /**
     * Returns the sanitizing configuration this handler was created with.
     *
     * @return the configuration passed to the constructor
     */
    public ImmutableSanitizingConfig config() {
        return config;
    }

    /**
     * Returns the policy factory created from the configuration in the
     * constructor.
     *
     * @return the policy factory used for HTML sanitizing
     */
    public PolicyFactory policy() {
        return policy;
    }

    /**
     * Closes the page header applier owned by this handler.
     */
    @Override
    public void close() {
        pageHeaderApplier.close();
    }

    /**
     * Creates the consumer which collects the request body.
     *
     * <p>A new {@link AsyncCharArrayProcessableConsumer} is returned on
     * every call; neither {@code request} nor {@code context} is inspected
     * here.
     *
     * @param request incoming request, not used
     * @param context request context, not used
     * @return a fresh request body consumer
     */
    @Override
    public HttpAsyncRequestConsumer<CharArrayProcessable> processRequest(
        final HttpRequest request,
        final HttpContext context)
        throws HttpException, IOException
    {
        return new AsyncCharArrayProcessableConsumer();
    }

    /**
     * Sanitizes the request body and submits the result as the response.
     *
     * <p>The text to sanitize is either the request body itself or, when
     * {@code config.textUrlencoded()} is set, the {@code text} parameter of
     * the body parsed with {@link QueryParser}. The {@code mimetype} request
     * parameter equal to {@code text/plain} selects plain text processing.
     *
     * <p>Depending on {@code config.urlSanitizingConfig().jsonMarkup()} the
     * response is either a {@code multipart/mixed} entity with the
     * {@code markup.json} and {@code sanitized.html} parts, or a single
     * {@code text/html} entity. In both cases the output is encoded with the
     * charset accepted by the client, replacing characters which can't be
     * encoded.
     *
     * @param request request body collected by the request consumer
     * @param exchange exchange used to read the request and submit the
     *     response
     * @param context request context, expected to hold the logger under the
     *     {@code Sanitizer2.LOGGER} attribute
     * @throws HttpException declared by the handler interface
     * @throws IOException if page headers can't be applied or the response
     *     can't be produced
     */
    @Override
    public void handle(
        final CharArrayProcessable request,
        final HttpAsyncExchange exchange,
        final HttpContext context)
        throws HttpException, IOException
    {
        Charset acceptedCharset =
            CharsetUtils.acceptedCharset(exchange.getRequest());
        if (config.urlSanitizingConfig().jsonMarkup()) {
            // WMI can't parse  Content-Type: text/html; charset=UTF-8
            // So we form       Content-Type: text/html; charset=utf-8
            // by wrapping the charset into one with a lowercase name
            acceptedCharset = new FilterCharset(
                acceptedCharset.name().toLowerCase(Locale.ROOT),
                acceptedCharset.aliases(),
                acceptedCharset);
        }
        CharArrayProcessable body;
        if (config.textUrlencoded()) {
            // Body is a query string, the text to sanitize is its "text"
            // parameter, empty string is used if the parameter is missing
            String text =
                new ScanningCgiParams(
                    request.processWith(QueryParser.Factory.INSTANCE))
                    .getString("text", "");
            body = new CharArrayProcessable(text.toCharArray());
        } else {
            body = request;
        }
        // null if "mimetype" parameter is absent, which means HTML input
        String contentType =
            new ScanningCgiParams(exchange.getRequest())
                .getString("mimetype", null);
        // Only the sanitizing itself is timed, not the output rendering
        long start = TimeSource.INSTANCE.currentTimeMillis();
        HtmlNode root = sanitize(body, "text/plain".equals(contentType));
        long timeTaken = TimeSource.INSTANCE.currentTimeMillis() - start;
        Logger logger = (Logger) context.getAttribute(Sanitizer2.LOGGER);
        logger.info("HTML sanitized in " + timeTaken + MS);
        NByteArrayEntity entity;
        if (config.urlSanitizingConfig().jsonMarkup()) {
            // HTML is encoded to bytes while it is printed, input length is
            // passed to the collector, presumably as initial capacity
            ByteArrayHtmlCollector htmlCollector =
                new ByteArrayHtmlCollector(
                    new Encoder(
                        acceptedCharset.newEncoder()
                            .onMalformedInput(CodingErrorAction.REPLACE)
                            .onUnmappableCharacter(CodingErrorAction.REPLACE)),
                    body.length());
            BasicUrlCollector urlCollector = new BasicUrlCollector();
            HtmlPrinter<CharacterCodingException> printer =
                new HtmlPrinter<>(
                    config,
                    htmlCollector,
                    urlCollector,
                    IdentityAttrPostProcessor.INSTANCE,
                    IdentityCssPostProcessor.INSTANCE);
            try {
                // Page headers are applied before the sanitized tree is
                // printed
                pageHeaderApplier.accept(body, printer);
            } catch (PageHeaderException e) {
                throw new IOException("Failed to apply page headers", e);
            }
            // Input is not needed anymore, drop the reference early
            body = null;
            root.accept(printer);
            printer.done();
            List<UrlInfo> urls = urlCollector.urls();
            if (urls != null) {
                // Callback is run at most once per response, no matter how
                // many phishing links were found
                for (UrlInfo url: urls) {
                    if (url.type() == UrlType.PHISHING_URL_TYPE) {
                        phishingLinksCallback.run();
                        logger.info("Phishing links found");
                        break;
                    }
                }
            }
            MultipartEntityBuilder builder = MultipartEntityBuilder.create();
            builder.setMimeSubtype("mixed");
            // First part: collected URLs serialized as JSON
            builder.addPart(
                FormBodyPartBuilder.create()
                    .setBody(
                        new StringBody(
                            JsonType.NORMAL.toString(urls),
                            ContentType.APPLICATION_JSON
                                .withCharset(acceptedCharset)))
                    .setName("markup.json")
                    .build());
            // Second part: sanitized HTML, already encoded, without filename
            builder.addPart(
                FormBodyPartBuilder.create()
                    .setBody(
                        htmlCollector.data().processWith(
                            new ByteArrayBodyFactory(
                                ContentType.TEXT_HTML.withCharset(
                                    acceptedCharset),
                                null)))
                    .setName("sanitized.html")
                    .build());
            htmlCollector = null;
            HttpEntity tmpEntity = builder.build();
            builder = null;
            // For some reason, WMI expects the prologue to be non-empty,
            // so we explicitly write CRLF before the payload.
            // Buffer is sized by the multipart length plus two bytes for
            // CRLF, with 1024 as a lower bound for the multipart length
            DecodableByteArrayOutputStream out =
                new DecodableByteArrayOutputStream(
                    Math.max(1024, (int) tmpEntity.getContentLength()) + 2);
            out.write('\r');
            out.write('\n');
            tmpEntity.writeTo(out);
            entity = out.processWith(NByteArrayEntityFactory.INSTANCE);
            // Keep multipart content type, it carries the boundary
            entity.setContentType(tmpEntity.getContentType());
        } else {
            // HTML is collected as chars first and encoded in one pass below
            StringBuilderHtmlCollector htmlCollector =
                new StringBuilderHtmlCollector(body.length());
            HtmlPrinter<RuntimeException> printer =
                new HtmlPrinter<>(
                    config,
                    htmlCollector,
                    NullUrlCollector.INSTANCE,
                    IdentityAttrPostProcessor.INSTANCE,
                    IdentityCssPostProcessor.INSTANCE);
            try {
                // Page headers are applied before the sanitized tree is
                // printed
                pageHeaderApplier.accept(body, printer);
            } catch (PageHeaderException e) {
                throw new IOException("Failed to apply page headers", e);
            }
            // Input is not needed anymore, drop the reference early
            body = null;
            root.accept(printer);
            printer.done();
            StringBuilder sb = htmlCollector.sb();
            htmlCollector = null;
            // Copy chars out of the builder, so they can be passed to the
            // encoder as an array
            int len = sb.length();
            char[] buf = new char[len];
            sb.getChars(0, len, buf, 0);
            sb = null;
            Encoder encoder = new Encoder(
                acceptedCharset.newEncoder()
                    .onMalformedInput(CodingErrorAction.REPLACE)
                    .onUnmappableCharacter(CodingErrorAction.REPLACE));
            encoder.process(buf, 0, len);
            entity = encoder.processWith(NByteArrayEntityFactory.INSTANCE);
            entity.setContentType(
                ContentType.TEXT_HTML.withCharset(acceptedCharset).toString());
        }
        logger.info("Response length: " + entity.getContentLength());
        exchange.getResponse().setEntity(entity);
        exchange.submitResponse();
    }

    /**
     * Builds the sanitized tree for the given text.
     *
     * <p>Plain text is converted with {@code TextPlainProcessor} and always
     * gets plain links wrapped. HTML is parsed through the sanitizing
     * policy into a DOM, collected styles are applied if there are any,
     * then the tree is compacted if {@code config.compactHtml()} is set and
     * plain links are wrapped if {@code config.wrapPlainLinks()} is set.
     *
     * @param text text to sanitize
     * @param plainText {@code true} if {@code text} is plain text rather
     *     than HTML
     * @return root node of the resulting tree
     */
    public HtmlNode sanitize(
        final CharArrayProcessable text,
        final boolean plainText)
    {
        if (plainText) {
            // Plain text is never compacted, but links are always wrapped
            HtmlTag plainTag =
                new TextPlainProcessor(config).apply(text.toString());
            HtmlNode plainRoot = plainTag;
            plainRoot.accept(new PlainLinksWrappingVisitor(config));
            return plainRoot;
        }

        HtmlDomBuilder domBuilder = new HtmlDomBuilder(config);
        sanitize(text, policy.apply(domBuilder));
        HtmlTag domRoot = domBuilder.root();
        StyleProcessor styles = domBuilder.styleProcessor();
        if (styles != EmptyStyleProcessor.INSTANCE) {
            domRoot.accept(new StyleApplyingVisitor(styles));
        }

        HtmlNode result = domRoot;
        if (config.compactHtml()) {
            result = domRoot.accept(new HtmlCompactor(config));
        }
        if (config.wrapPlainLinks()) {
            result.accept(new PlainLinksWrappingVisitor(config));
        }
        return result;
    }

    /**
     * Tokenizes the HTML text and feeds resulting events to the policy.
     *
     * <p>The text is split into tokens by {@link HtmlLexer}. Text tokens,
     * open tags with their attributes and close tags are passed to a
     * {@link TagBalancingHtmlStreamEventReceiver} wrapping {@code policy}.
     * Token types other than {@code TEXT}, {@code UNESCAPED} and
     * {@code TAGBEGIN} are ignored at the top level.
     *
     * @param text HTML text to process
     * @param policy receiver of the balanced stream of HTML events
     */
    public void sanitize(
        final CharArrayProcessable text,
        final HtmlSanitizer.Policy policy)
    {
        TagBalancingHtmlStreamEventReceiver receiver =
            new TagBalancingHtmlStreamEventReceiver(policy);
        receiver.setNestingLimit(512);

        receiver.openDocument();

        // Decoder and lexer are created over the same text, token positions
        // reported by the lexer are used to address chars in the decoder
        HtmlDecoder decoder = text.processWith(HtmlDecoderFactory.INSTANCE);
        HtmlLexer lexer = text.processWith(HtmlLexerFactory.INSTANCE);
        // Flat list of attributes of the current tag: name, value, name, ...
        // Reused between tags
        List<String> attrs = new ArrayList<>();
        while (lexer.hasNext()) {
            HtmlToken token = lexer.next();
            switch (token.type) {
                case TEXT:
                    receiver.text(decoder.decodeHtml(token.start, token.end));
                    break;
                case UNESCAPED:
                    // No entities decoding here, only banned code units are
                    // removed
                    receiver.text(
                        decoder.stripBannedCodeunits(token.start, token.end));
                    break;
                case TAGBEGIN:
                    if (decoder.charAt(token.start + 1) == '/') {
                        // close tag
                        // name starts after the two leading chars of the
                        // token, i.e. after "</"
                        String tagName = HtmlStreamRenderer.safeName(
                            decoder.substring(token.start + 2, token.end));
                        // Prefer the instance returned by config, fall back
                        // to the parsed name if config returns null
                        String internedTagName = config.internTag(tagName);
                        if (internedTagName == null) {
                            internedTagName = tagName;
                        }
                        receiver.closeTag(internedTagName);
                        while (lexer.hasNext()
                            && lexer.next().type != HtmlTokenType.TAGEND)
                        {
                            // skip everything until we meet '>'
                        }
                    } else {
                        attrs.clear();
                        // true if the next attribute token is expected to be
                        // a name, false if the last name has no value yet
                        boolean expectAttrName = true;
                        boolean done = false;
                        // start is moved to the end of the last attribute
                        // token, end is moved to the end of the tag
                        int start = token.start;
                        int end = token.end;
                        String tagName = HtmlStreamRenderer.safeName(
                            decoder.substring(start + 1, end));
                        while (!done && lexer.hasNext()) {
                            token = lexer.next();
                            switch (token.type) {
                                case ATTRNAME:
                                    start = token.end;
                                    if (expectAttrName) {
                                        expectAttrName = false;
                                    } else {
                                        // previous attr has no value,
                                        // add its name as the value
                                        attrs.add(attrs.get(attrs.size() - 1));
                                    }
                                    attrs.add(
                                        HtmlLexer.canonicalName(
                                            decoder.substring(
                                                token.start,
                                                token.end)));
                                    break;
                                case ATTRVALUE:
                                    start = token.end;
                                    attrs.add(
                                        decoder.stripQuotesDecodeHtml(
                                            token.start,
                                            token.end));
                                    expectAttrName = true;
                                    break;
                                case TAGEND:
                                    end = token.end;
                                    done = true;
                                    break;
                                default:
                                    // Ignore everything else
                                    break;
                            }
                        }
                        boolean selfClosed = false;
                        if (!expectAttrName) {
                            // Last attribute name has no value: it is either
                            // the slash of a self-closed tag, which is
                            // dropped, or a valueless attribute, which gets
                            // its name as the value
                            String last = attrs.get(attrs.size() - 1);
                            selfClosed =
                                last.length() == 1 && last.charAt(0) == '/';
                            if (selfClosed) {
                                attrs.remove(attrs.size() - 1);
                            } else {
                                attrs.add(last);
                            }
                        }
                        if (!selfClosed) {
                            // Slash was not reported as attribute name, so
                            // look for it between the last attribute token
                            // and the end of the tag, skipping whitespace
                            // backwards
                            --end; // ignore '>'
                            while (start < end) {
                                char c = decoder.charAt(end - 1);
                                if (c == '/') {
                                    selfClosed = true;
                                    break;
                                } else if (Character.isWhitespace(c)) {
                                    --end;
                                } else {
                                    break;
                                }
                            }
                        }
                        String internedTagName = config.internTag(tagName);
                        if (internedTagName == null) {
                            receiver.openTag(tagName, attrs);
                        } else {
                            receiver.openTag(internedTagName, attrs);
                        }
                        if (selfClosed) {
                            // Check if we should emit closeTag event
                            // Void elements never get one. Tags for which
                            // config returns no interned name always get
                            // one, the rest get it only if their config
                            // neither ignores self close nor requires
                            // content
                            if (!HtmlTextEscapingMode.isVoidElement(tagName)) {
                                if (internedTagName == null) {
                                    receiver.closeTag(tagName);
                                } else {
                                    ImmutableTagConfig tagConfig =
                                        config.tags().get(internedTagName);
                                    if (!tagConfig.ignoreSelfClose()
                                        && !tagConfig.requireContent())
                                    {
                                        receiver.closeTag(internedTagName);
                                    }
                                }
                            }
                        }
                    }
                    break;
                default:
                    break;
            }
        }

        receiver.closeDocument();
    }

    /**
     * Stateless processor which creates an {@link HtmlLexer} over the given
     * region of a char array.
     */
    private enum HtmlLexerFactory
        implements CharArrayProcessor<HtmlLexer, RuntimeException>
    {
        INSTANCE;

        /**
         * Creates a lexer for the given chars.
         *
         * @param buf chars to tokenize
         * @param off offset passed to the lexer
         * @param len length passed to the lexer
         * @return new lexer instance
         */
        @Override
        public HtmlLexer process(
            final char[] buf,
            final int off,
            final int len)
        {
            return new HtmlLexer(buf, off, len);
        }
    }

    /**
     * Extracts strings from a char array by absolute positions.
     *
     * <p>All positions are indices into the array passed to the
     * constructor, {@code start} is inclusive and {@code end} is exclusive.
     * A single {@link StringBuilder} is shared between calls, so an
     * instance must not be used concurrently.
     */
    private static class HtmlDecoder {
        // Shared buffer, handed to Encoding and reused for stripping
        private final StringBuilder sb = new StringBuilder();
        private final char[] buf;

        HtmlDecoder(final char[] buf) {
            this.buf = buf;
        }

        /**
         * Passes the given range to {@code Encoding.decodeHtml} and returns
         * its result.
         */
        public String decodeHtml(final int start, final int end) {
            int length = end - start;
            return Encoding.decodeHtml(buf, start, length, sb);
        }

        /**
         * Removes quotes around an attribute value and decodes the rest.
         *
         * <p>The trailing quote is removed if present. The leading char is
         * removed only if it is the same quote as the trailing one.
         */
        public String stripQuotesDecodeHtml(final int start, final int end) {
            int from = start;
            int to = end;
            if (to > from) {
                char closing = buf[to - 1];
                boolean quoted = closing == '"' || closing == '\'';
                if (quoted) {
                    // Browsers work fine with a missing left quote,
                    // so the right one is stripped anyway
                    --to;
                    if (to > from && buf[from] == closing) {
                        // Left quote is also present and matches
                        ++from;
                    }
                }
            }
            return decodeHtml(from, to);
        }

        /**
         * Returns the given range with banned code units removed.
         *
         * <p>If {@code Encoding.longestPrefixOfGoodCodeunits} returns a
         * negative value, the range is returned as is.
         */
        public String stripBannedCodeunits(final int start, final int end) {
            int len = end - start;
            int safeLimit =
                Encoding.longestPrefixOfGoodCodeunits(buf, start, len);
            if (safeLimit >= 0) {
                sb.setLength(0);
                sb.append(buf, start, len);
                Encoding.stripBannedCodeunits(sb, safeLimit);
                return sb.toString();
            }
            return new String(buf, start, len);
        }

        /**
         * Returns chars of the given range as a string, without any
         * decoding.
         */
        public String substring(final int start, final int end) {
            return String.valueOf(buf, start, end - start);
        }

        /**
         * Returns the char at the given position of the array.
         */
        public char charAt(final int pos) {
            return buf[pos];
        }
    }

    /**
     * Stateless processor which creates an {@link HtmlDecoder} over a char
     * array.
     */
    private enum HtmlDecoderFactory
        implements CharArrayProcessor<HtmlDecoder, RuntimeException>
    {
        INSTANCE;

        /**
         * Creates a decoder for the given array.
         *
         * <p>{@code off} and {@code len} are not used: the decoder gets the
         * whole array and addresses it by absolute positions.
         *
         * @param buf chars to decode
         * @param off ignored
         * @param len ignored
         * @return new decoder instance
         */
        @Override
        public HtmlDecoder process(
            final char[] buf,
            final int off,
            final int len)
        {
            return new HtmlDecoder(buf);
        }
    }

    /**
     * Processor which wraps a region of a byte array into a
     * {@link ByteArrayBody} with the content type and filename given at
     * construction.
     */
    private static class ByteArrayBodyFactory
        implements ByteArrayProcessor<ByteArrayBody, RuntimeException>
    {
        private final ContentType contentType;
        private final String filename;

        /**
         * @param contentType content type of the created bodies
         * @param filename filename of the created bodies, {@code null} is
         *     passed by the only caller in this file
         */
        public ByteArrayBodyFactory(
            final ContentType contentType,
            final String filename)
        {
            this.contentType = contentType;
            this.filename = filename;
        }

        /**
         * Creates a body for the given bytes.
         *
         * @param buf array holding the data
         * @param off offset of the data in {@code buf}
         * @param len length of the data
         * @return new body over the given region
         */
        @Override
        public ByteArrayBody process(
            final byte[] buf,
            final int off,
            final int len)
        {
            return new ByteArrayBody(buf, off, len, contentType, filename);
        }
    }
}

