package ru.yandex.webmaster3.validator.xml;

import org.apache.xerces.dom.DOMInputImpl;
import org.apache.xerces.jaxp.validation.XMLSchema11Factory;
import org.apache.xerces.util.SecurityManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xml.sax.*;
import org.xml.sax.helpers.DefaultHandler;
import ru.yandex.common.util.IOUtils;
import ru.yandex.common.util.collections.Cf;
import ru.yandex.webmaster3.core.validator.model.ErrorDescription;
import ru.yandex.webmaster3.core.validator.model.ErrorSource;

import javax.xml.parsers.ParserConfigurationException;
import javax.xml.parsers.SAXParserFactory;
import javax.xml.transform.Source;
import javax.xml.transform.stream.StreamSource;
import javax.xml.validation.Schema;
import javax.xml.validation.SchemaFactory;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

/**
 * SAX-based XML parser with optional XSD (XML Schema 1.1) validation.
 * <p>
 * Every call builds its own SAX parser (external DTD loading and external entities disabled, Xerces
 * {@code SecurityManager} installed) and its own error collector; instances hold no mutable state.
 * Problems are collected as {@link ErrorDescription}s instead of being thrown, and collection stops once
 * more than {@code MAX_ERRORS} distinct errors have been recorded. Every stream passed in is closed
 * before the public methods return.
 * <p>
 * User: Vladislav Dolbilov (darl@yandex-team.ru)
 * Date: 14.02.12 17:12
 */
public class XmlParser {
    private static final Logger log = LoggerFactory.getLogger(XmlParser.class);

    /** No-op content handler used when only the collected errors are of interest. */
    private static final DefaultHandler EMPTY_CONTENT_HANDLER = new DefaultHandler();
    /** Once more than this many distinct errors are stored, parsing is aborted. */
    private static final int MAX_ERRORS = 100;
    /** Shared zero-length array meaning "parse without schema validation". */
    private static final InputStream[] NO_SCHEMAS = new InputStream[0];

    /**
     * Parses {@code xml} without schema validation, streaming its SAX events to {@code contentHandler}.
     * <p>
     * Parse errors are collected internally but not returned; use
     * {@link #validate(InputStream, InputStream...)} when they are needed.
     * {@code xml} is closed before this method returns.
     *
     * @param xml            the document to parse
     * @param contentHandler receiver of the document's SAX events
     */
    public void parse(final InputStream xml, ContentHandler contentHandler) {
        // An empty (not null) schema array selects the "no validation" path.
        validate(xml, NO_SCHEMAS, contentHandler);
    }

    /**
     * Checks that {@code xml} is well-formed and, when {@code xsds} are given, valid against them.
     * <p>
     * If the schemas themselves cannot be compiled, only the schema errors are returned and the document
     * is not parsed. {@code xml} and every schema stream are closed before this method returns.
     *
     * @param xml  the document to check
     * @param xsds zero or more XSD documents; {@code null} is treated as none
     * @return the distinct errors found, in the order they were first reported (empty when there are none)
     */
    public List<ErrorDescription> validate(final InputStream xml, final InputStream ...xsds) {
        return validate(xml, xsds, EMPTY_CONTENT_HANDLER);
    }

    /**
     * Shared implementation: compiles the schemas (if any), parses the document and collects errors.
     * All streams are closed on every path, including the early return on schema failure.
     */
    private List<ErrorDescription> validate(final InputStream xml, final InputStream[] xsds, ContentHandler contentHandler) {
        final InputStream[] schemas = xsds == null ? NO_SCHEMAS : xsds;
        final XmlErrorHandler errorHandler = new XmlErrorHandler();
        try {
            final Schema schema = schemas.length == 0 ? null : createSchema(schemas, errorHandler);
            if (schemas.length != 0 && schema == null) {
                // Schema compilation failed (reported as XSD errors); validating against nothing would mislead.
                return errorHandler.getErrors();
            }
            parseInto(xml, schema, contentHandler, errorHandler);
        } catch (TooManyErrors e) {
            // Also covers addError() calls made from the catch clauses inside parseInto().
            log.warn("Too many errors. Validation stopped");
        } finally {
            IOUtils.closeInput(xml);
            for (InputStream xsd : schemas) {
                IOUtils.closeInput(xsd);
            }
        }
        return errorHandler.getErrors();
    }

    /**
     * Runs the SAX parse of {@code xml} and turns any exception that ends it into a recorded error.
     * May throw {@link TooManyErrors} when the error limit is exceeded.
     *
     * @param schema schema to validate against, or {@code null} for a well-formedness check only
     */
    private static void parseInto(InputStream xml, Schema schema, ContentHandler contentHandler, XmlErrorHandler errorHandler) {
        try {
            final XMLReader reader = createReader(schema, contentHandler, errorHandler);
            reader.parse(new InputSource(xml));
        } catch (ParserConfigurationException e) {
            log.error("Parser configuration exception", e);
            errorHandler.addError(e.getLocalizedMessage(), ErrorSource.INTERNAL);
        } catch (SAXParseException e) { //subclass of SAXException, but contains (line, column) info
            errorHandler.addError(e, ErrorSource.XML);
        } catch (SAXException e) {
            errorHandler.addError(e, ErrorSource.XML);
        } catch (IOException e) {
            errorHandler.addError(e.getLocalizedMessage(), ErrorSource.XML);
        }
    }

    /**
     * Compiles all {@code xsds} into a single XML Schema 1.1 {@link Schema}.
     *
     * @return the compiled schema, or {@code null} if compilation failed (the failure is recorded as an XSD error)
     */
    private static Schema createSchema(final InputStream[] xsds, final XmlErrorHandler errorHandler) {
        final SchemaFactory factory = new XMLSchema11Factory();
        // Every referenced resource resolves to an empty input, presumably so that compilation never loads
        // external documents. NOTE(review): confirm how Xerces treats an LSInput with no content.
        factory.setResourceResolver((type, namespaceURI, publicId, systemId, baseURI) -> new DOMInputImpl());
        try {
            return factory.newSchema(Arrays.stream(xsds).map(StreamSource::new).toArray(Source[]::new));
        } catch (SAXParseException e) {
            log.error("Schema error", e);
            errorHandler.addError(e, ErrorSource.XSD);
        } catch (SAXException e) {
            log.error("Schema error", e);
            errorHandler.addError(e, ErrorSource.XSD);
        }
        return null;
    }

    /**
     * Builds a namespace-aware, non-DTD-validating SAX reader that reports to {@code errorHandler}.
     * External DTDs, external entities and XInclude are disabled; the Xerces security manager is installed.
     *
     * @param schema schema to validate against, or {@code null} for none
     */
    private static XMLReader createReader(Schema schema, ContentHandler contentHandler, XmlErrorHandler errorHandler) throws SAXException, ParserConfigurationException {
        final SAXParserFactory factory = SAXParserFactory.newInstance();
        factory.setNamespaceAware(true);
        factory.setXIncludeAware(false);
        factory.setValidating(false);
        if (schema != null) {
            factory.setSchema(schema);
        }

        final XMLReader reader = factory.newSAXParser().getXMLReader();
        reader.setContentHandler(contentHandler);
        reader.setErrorHandler(errorHandler);

        reader.setFeature("http://apache.org/xml/features/validation/schema/augment-psvi", false);
        // Never fetch or apply DTDs / external entities referenced by untrusted documents (XXE hardening).
        reader.setFeature("http://apache.org/xml/features/nonvalidating/load-dtd-grammar", false);
        reader.setFeature("http://apache.org/xml/features/nonvalidating/load-external-dtd", false);
        reader.setFeature("http://xml.org/sax/features/external-parameter-entities", false);
        reader.setFeature("http://xml.org/sax/features/external-general-entities", false);
        reader.setFeature("http://xml.org/sax/features/validation", false);
        reader.setProperty("http://apache.org/xml/properties/security-manager", new SecurityManager());
        return reader;
    }

    /**
     * Collects distinct errors; warnings, errors and fatal errors reported by the parser are all recorded
     * with source {@link ErrorSource#XML}.
     */
    private static class XmlErrorHandler implements ErrorHandler {
        // Insertion-ordered so results come back in first-reported order; the Set still drops duplicates.
        private final Set<ErrorDescription> errors = new LinkedHashSet<>();

        @Override
        public void warning(SAXParseException exception) throws SAXException {
            addError(exception, ErrorSource.XML);
        }

        @Override
        public void error(SAXParseException exception) throws SAXException {
            addError(exception, ErrorSource.XML);
        }

        @Override
        public void fatalError(SAXParseException exception) throws SAXException {
            addError(exception, ErrorSource.XML);
        }

        /** Records an error with its line and column. */
        void addError(SAXParseException e, ErrorSource source) {
            addError(new ErrorDescription(e.getLineNumber(), e.getColumnNumber(), e.getLocalizedMessage(), source));
        }

        /** Records an error without position information. */
        void addError(SAXException e, ErrorSource source) {
            addError(new ErrorDescription(e.getLocalizedMessage(), source));
        }

        /** Records a plain-message error. */
        void addError(String message, ErrorSource source) {
            addError(new ErrorDescription(message, source));
        }

        /** Stores the error and throws {@link TooManyErrors} once more than {@code MAX_ERRORS} are stored. */
        private void addError(ErrorDescription errorDescription) {
            errors.add(errorDescription);
            if (errors.size() > MAX_ERRORS) {
                throw new TooManyErrors();
            }
        }

        /** Returns a copy of the collected errors. */
        List<ErrorDescription> getErrors() {
            return Cf.newList(errors);
        }
    }

    /** Unchecked signal that aborts parsing once the error limit is exceeded; caught inside {@code validate}. */
    private static class TooManyErrors extends RuntimeException {
    }
}
