from typing import Callable, Dict, Optional, Tuple, Union

from lxml.etree import (
    Element,
    Resolver,
    XMLParser,
    XMLSchema,
    _Element as XmlElement,
    parse as xml_parse,
)


class XmlDoc:
    """Incremental XML builder on top of lxml elements.

    The document tracks a "current" element. ``tag()`` appends a new child to
    it, and entering the returned ``Tag`` with ``with`` makes that child the
    current element for the duration of the block::

        doc, tag, text, attrs, append = XmlDoc("root").ttaa()
        with tag("item", id="1"):
            text("hello")

    A ``Tag`` whose ``with`` block finishes normally with neither child
    elements nor text is removed from its parent, so optional containers
    vanish when nothing was put into them.
    """

    def __init__(self, name: str, nsmap: Optional[dict] = None):
        """Create the root element ``name`` with an optional namespace map."""
        self.current_tag = Element(name, nsmap=nsmap)

    def tag(self, _name_: str, **kwargs) -> "XmlDoc.Tag":
        """Append a child element ``_name_`` to the current element.

        A ``text`` keyword sets the element's text; all other keywords are
        forwarded to ``Element`` (normally becoming attributes).
        """
        return self.__class__.Tag(self, _name_, **kwargs)

    def text(self, text: Union[str, int, float, None]):
        """Set the current element's text; non-None values are stringified."""
        self.current_tag.text = None if text is None else str(text)

    def attrs(self, **kwargs):
        """Set attributes on the current element.

        Values are passed to ``set`` unchanged, so they should already be
        strings.
        """
        for key, value in kwargs.items():
            self.current_tag.set(key, value)

    def append(self, tag: XmlElement):
        """Append an already-built element as a child of the current element."""
        self.current_tag.append(tag)

    def tag_conditional(self, _name_: str, _conditional_, **kwargs) -> Optional["XmlDoc.Tag"]:
        """Like ``tag()``, but only when ``_conditional_`` is truthy.

        Returns None otherwise. None is not a context manager, so guard the
        ``with`` yourself if you need one.
        """
        if _conditional_:
            return self.__class__.Tag(self, _name_, **kwargs)
        return None

    # noinspection SpellCheckingInspection
    def ttaa(self) -> Tuple["XmlDoc", Callable, Callable, Callable, Callable]:
        """Return ``(doc, tag, text, attrs, append)`` for terse unpacking."""
        return self, self.tag, self.text, self.attrs, self.append

    @property
    def result(self) -> XmlElement:
        """Root element of the tree the current element belongs to."""
        return self.current_tag.getroottree().getroot()

    class Tag:
        """Child element created by ``XmlDoc.tag``; usable as a context manager."""

        # Element that was current before ``__enter__``; restored on exit.
        parent_tag = None

        def __init__(
            self,
            doc: "XmlDoc",
            _name_: str,
            text: Union[str, int, float, None] = None,
            **attrs,
        ):
            self.tag = Element(_name_, **attrs)
            if text is not None:
                self.tag.text = str(text)

            self.doc = doc
            # The element is attached immediately, so a Tag that is never
            # entered with ``with`` is still part of the document.
            self.doc.current_tag.append(self.tag)
            self._added_to = self.doc.current_tag

        def __enter__(self) -> "XmlDoc.Tag":
            self.parent_tag = self.doc.current_tag
            self.doc.current_tag = self.tag

            return self

        def __exit__(self, exc_type, exc_value, traceback):
            # Always restore the previous current element, even if the block
            # raised; otherwise a caller that handles the exception would keep
            # building inside this element.
            self.doc.current_tag = self.parent_tag
            # ``len()`` only counts child elements, so text must be checked
            # separately or ``with tag(...): text(...)`` would be erased.
            if exc_value is None and len(self.tag) == 0 and not self.tag.text:
                self.erase()
            # Returns None (falsy), so any exception propagates.

        def erase(self):
            """Detach this element from the parent it was appended to (idempotent)."""
            if self._added_to is not None:
                self._added_to.remove(self.tag)
                self._added_to = None


class UrlResolver(Resolver):
    """lxml resolver that serves selected URLs from in-memory bytes.

    URLs are matched case-insensitively against ``url_map``. Any other URL is
    handed back to lxml via ``resolve_filename``.
    """

    def __init__(self, url_map: Dict[str, bytes]):
        super().__init__()
        # Lower-case keys once: lookups lower-case the requested URL, so
        # keys supplied with upper-case characters would otherwise never match.
        self._url_map = {url.lower(): data for url, data in (url_map or {}).items()}

    def resolve(self, url, _id, context):
        """Return an lxml input document for ``url``, or None if there is no URL."""
        if url is None:
            # Nothing to match on; let other resolvers / default loading decide.
            return None

        data = self._url_map.get(url.lower())
        if data is not None:
            # Pass the raw bytes so the parser honours the document's own
            # encoding declaration instead of us assuming UTF-8.
            return self.resolve_string(data, context)

        return self.resolve_filename(url, context)


class XmlSchema:
    """Wrapper around a compiled lxml ``XMLSchema`` object."""

    def __init__(self, schema):
        self._schema = schema

    @classmethod
    def from_config(cls, config) -> "XmlSchema":
        """Build from the files at ``config.XML_SCHEMA_PATH`` and ``config.W3C_SCHEMA_PATH``."""
        with open(config.XML_SCHEMA_PATH, "rb") as schema_fh, open(
            config.W3C_SCHEMA_PATH, "rb"
        ) as xml_xsd_fh:
            return cls.from_file(schema_fh, xml_xsd_fh)

    @classmethod
    def from_file(cls, xml_file, w3c_file) -> "XmlSchema":
        """Compile the XSD read from ``xml_file``.

        References to ``http://www.w3.org/2001/xml.xsd`` are resolved to the
        bytes read from ``w3c_file`` rather than loaded from that URL.
        """
        xsd_parser = XMLParser()
        local_copies = {"http://www.w3.org/2001/xml.xsd": w3c_file.read()}
        xsd_parser.resolvers.add(UrlResolver(local_copies))

        return cls(XMLSchema(xml_parse(xml_file, xsd_parser)))

    def validate(self, xml: XmlElement, raises: bool = True):
        """Validate ``xml`` against the wrapped schema.

        When ``raises`` is false, returns the schema's ``validate`` result.
        Otherwise delegates to ``assertValid`` and returns whatever it returns.
        """
        schema = self._schema
        if not raises:
            return schema.validate(xml)
        return schema.assertValid(xml)


# Names exported by ``from <module> import *``. UrlResolver is not listed,
# presumably because it is only an implementation detail of
# XmlSchema.from_file.
__all__ = ["XmlDoc", "XmlSchema"]
