import collections
import io
import logging
import time

import ijson
import six
from lxml import etree
from requests.exceptions import ChunkedEncodingError

# Module-wide logger shared by every parser class below.
LOG = logging.getLogger(__name__)

# 2**30. Not referenced in this part of the file; presumably used elsewhere
# as an "effectively unlimited" item count.
NO_LIMIT = 1 << 30

# Base back-off unit: MultiStreamXmlParser sleeps num_errors * this many
# seconds before retrying an interrupted page.
RETRY_TIMEOUT_SECONDS = 10

# MultiStreamXmlParser keeps retrying until its error counter exceeds this
# value, then re-raises the connection error.
MAX_ERRORS = 5


class XmlParser(object):
    """Stream-parse an XML document with lxml and yield selected elements as dicts.

    Elements are selected by nesting level: with ``depth=1`` the direct
    children of the document root are yielded, with ``depth=2`` their
    children, and so on. Comments found directly at that level are collected
    and reported in the metadata of every item that follows them.
    """

    def __init__(self, depth=1, logging_batch_size=1000, extra_info=None, track_parents=False, recover=None, prefetch=False):
        """
        :param depth: nesting level whose children are yielded (1 = children of the root).
        :param logging_batch_size: log progress every this many items.
        :param extra_info: dict merged into the metadata of every item.
        :param track_parents: add the attributes of each item's enclosing
            elements (all ancestors except the document root) to its metadata,
            keyed by their namespace-stripped tag name.
        :param recover: forwarded to ``lxml.etree.iterparse``.
        :param prefetch: read the whole input into memory before parsing.
        """
        self.depth = depth
        self.logging_batch_size = logging_batch_size
        self.extra_info = extra_info or dict()
        self.track_parents = track_parents
        self.recover = recover
        self.prefetch = prefetch

    @staticmethod
    def _local_name(tag):
        """Return ``tag`` without its ``{namespace}`` prefix, if any."""
        if tag.startswith("{"):
            return tag[tag.index("}") + 1:]
        return tag

    def get_elements(self, input_file):
        """Yield elements and comments located directly under nesting level ``self.depth``.

        Elements are yielded on their "end" event, i.e. once their subtree
        has been parsed.
        """
        current_depth = 0
        if self.prefetch:
            # Use read_all() when the input provides it; plain file objects
            # (e.g. zipfile members) only offer read().
            read_all = getattr(input_file, "read_all", None)
            data = read_all() if read_all is not None else input_file.read()
            if len(data) == 0:
                return
            input_file = io.BytesIO(data)
        for event, element in etree.iterparse(input_file, ["start", "end", "comment"], recover=self.recover):
            if event == "start":
                current_depth += 1
            elif event == "end":
                current_depth -= 1
            if current_depth == self.depth and event in ("end", "comment"):
                yield element

    @staticmethod
    def elem_to_dict(elem, preserve_order):
        """Convert an lxml node to a ``(name, value)`` pair.

        Comment nodes become ``("comment", stripped_text)``. For elements,
        ``name`` is the namespace-stripped tag and ``value`` is:

        * ``None`` for a leaf with no attributes and no text,
        * the stripped text for a text-only leaf without attributes,
        * otherwise a dict (``OrderedDict`` if ``preserve_order``) holding a
          ``"#text"`` entry, one ``"@<attr>"`` entry per attribute and one
          entry per child tag; repeated child tags are collected into a list.
        """
        name = elem.tag
        text = elem.text
        if text is not None:
            text = text.strip()
        # Comment nodes carry a non-string tag (lxml uses its Comment factory).
        # six.string_types also accepts Python 2 unicode tag names, which a
        # plain ``str`` check would misreport as comments.
        if not isinstance(name, six.string_types):
            return "comment", text
        name = XmlParser._local_name(name)

        has_children = len(elem) > 0
        has_attributes = len(elem.attrib) > 0
        has_text = text is not None and len(text) > 0
        if not (has_children or has_attributes):
            return name, (text if has_text else None)

        value = collections.OrderedDict() if preserve_order else {}
        if has_text:
            value["#text"] = text
        for attr_key, attr_val in six.iteritems(elem.attrib):
            value["@" + attr_key] = attr_val
        for child_elem in elem:
            child_key, child_value = XmlParser.elem_to_dict(child_elem, preserve_order)
            if child_key not in value:
                value[child_key] = child_value
            elif isinstance(value[child_key], list):
                value[child_key].append(child_value)
            else:
                value[child_key] = [value[child_key], child_value]
        return name, value

    def parse(self, input_file, limit=None):
        """Yield ``(value, info)`` pairs for the items of ``input_file``.

        ``info`` is a new dict for every item. It contains ``extra_info``, the
        comments seen so far under key ``"comments"`` (only once at least one
        comment has been seen, in first-seen order) and, with
        ``track_parents``, the attributes of the enclosing elements.

        :param limit: stop after this many items (``None`` = no limit).
        """
        index = 0
        # Ordered de-duplication keeps the reported comment order deterministic.
        comments = collections.OrderedDict()
        for elem in self.get_elements(input_file):
            _, value = self.elem_to_dict(elem, False)
            # Decide on the node type, not on the returned name: a real
            # element whose tag is "comment" must still be yielded as an item.
            if not isinstance(elem.tag, six.string_types):
                comments[value] = None
            else:
                # A fresh dict per item, so consumers that keep or mutate it
                # never observe data belonging to later items.
                item_info = {}
                if comments:
                    item_info["comments"] = list(comments)
                item_info.update(self.extra_info)
                if self.track_parents:
                    parent = elem.getparent()
                    for _ in range(self.depth - 1):
                        item_info[self._local_name(parent.tag)] = dict(parent.attrib)
                        parent = parent.getparent()
                yield value, item_info
                index += 1
                if index % self.logging_batch_size == 0:
                    LOG.info("Parsed %s XML items", index)
            # Free the parsed subtree; only the converted value is needed.
            elem.clear()
            if limit is not None and index >= limit:
                break
        LOG.info("Parsed %s XML items; done!", index)
        LOG.info("Parsed %s XML items; done!", index)


class MultiStreamXmlParser(XmlParser):
    """Parse a sequence of paged XML streams and yield their items in order.

    Each element of ``input_files`` is one page. A page interrupted by
    ``ChunkedEncodingError`` is parsed again from its first item, and the
    items already yielded from it are skipped.
    """

    def __init__(self, depth=1, logging_batch_size=1000, items_per_page=1000, extra_info=None, stop_on_partial_file=True, prefetch=False):
        """
        :param items_per_page: number of items on a full page.
        :param stop_on_partial_file: if true, stop after the first page that
            holds fewer than ``items_per_page`` items; otherwise stop only
            after an empty page.

        The remaining parameters are forwarded to ``XmlParser``.
        """
        super(MultiStreamXmlParser, self).__init__(depth, logging_batch_size, extra_info=extra_info, prefetch=prefetch)
        self.items_per_page = items_per_page
        # Connection errors retried so far. It is not reset here, so it acts as
        # a retry budget for the parser's lifetime and also scales the back-off.
        self.num_errors = 0
        self.stop_on_partial_file = stop_on_partial_file

    def parse(self, input_files, limit=None):
        """Yield ``(item, info)`` pairs across all pages.

        ``info`` extends the base-class metadata with ``"page"`` (0-based page
        counter) and ``"index_on_page"``.

        :param input_files: iterable of page streams; each is entered as a
            context manager and handed to ``XmlParser.parse``.
        :param limit: stop after this many items in total (``None`` = no limit).
        :raises ChunkedEncodingError: when an interruption occurs after more
            than ``MAX_ERRORS`` interruptions have already been retried.
        """
        page = 0
        total = 0
        for input_file in input_files:
            # Items of this page already yielded; skipped if the page is re-read.
            reached_index_on_page = 0
            while True:
                try:
                    with input_file:
                        index_on_page = 0
                        LOG.info("Processing page %s", page)
                        for item, extra in super(MultiStreamXmlParser, self).parse(input_file):
                            if index_on_page < reached_index_on_page:
                                index_on_page += 1
                                continue
                            info = {}
                            info.update(self.extra_info)
                            info.update(extra)
                            info["page"] = page
                            info["index_on_page"] = index_on_page
                            yield item, info
                            reached_index_on_page += 1
                            index_on_page += 1
                            total += 1
                            if limit is not None and total >= limit:
                                LOG.info("Parsed %s XML items total; done!", total)
                                return
                        if self.stop_on_partial_file:
                            done = index_on_page < self.items_per_page
                        else:
                            done = index_on_page == 0
                        if done:
                            LOG.info("Parsed %s XML items total; done!", total)
                            return
                        page += 1
                    break
                except ChunkedEncodingError as err:
                    if self.num_errors > MAX_ERRORS:
                        raise
                    # Logger.warn is a deprecated alias; use Logger.warning.
                    LOG.warning("Connection interrupted, will retry: %s", err)
                    # Reset the stream's download state; presumably this makes
                    # the stream fetch the page again from the start (the
                    # attribute semantics live in the stream class).
                    input_file.content_iterator = None
                    input_file.downloaded = 0
                    input_file.data = ""
                    self.num_errors += 1
                    # Linear back-off: 1x, 2x, 3x ... RETRY_TIMEOUT_SECONDS.
                    time.sleep(self.num_errors * RETRY_TIMEOUT_SECONDS)


class ZippedXmlParser(XmlParser):
    """Parse every XML member of a ``zipfile.ZipFile`` archive, one after another."""

    def parse(self, zipFile, limit=None):
        """Yield ``(value, info)`` pairs for the items of all archive members.

        ``info`` extends the base-class metadata with ``"filename"`` (the
        member name) and ``"index_in_file"`` (0-based position in that member).

        :param zipFile: an open zip archive; its ``filelist``, ``filename``
            and ``open()`` are used.
        :param limit: stop after this many items in total (``None`` = no limit).
        """
        # Directory entries (names ending in "/") have no content. Without
        # prefetch, lxml rejects the resulting empty document, so skip them.
        members = [member for member in zipFile.filelist if not member.filename.endswith("/")]
        total_files = len(members)
        total = 0
        for file_index, fileInfo in enumerate(members, 1):
            LOG.info("Processing file %r of archive %r (%s/%s)",
                     fileInfo.filename, zipFile.filename, file_index, total_files)
            with zipFile.open(fileInfo) as f:
                items = super(ZippedXmlParser, self).parse(f)
                for index_in_file, (item, extra) in enumerate(items):
                    info = {}
                    info.update(self.extra_info)
                    info.update(extra)
                    info["filename"] = fileInfo.filename
                    info["index_in_file"] = index_in_file
                    yield item, info
                    total += 1
                    if limit is not None and total >= limit:
                        return


class DelimitedStreamParser(object):
    """Split a text stream into records separated by ``sep`` (newline by default)."""

    def __init__(self, stream_chunk_size=256 * 1024, logging_batch_size=1000, sep='\n', extra_info=None):
        """
        :param stream_chunk_size: size passed to each ``input_file.read`` call.
        :param logging_batch_size: log progress every this many records.
        :param sep: non-empty record separator; may be longer than one character.
        :param extra_info: dict copied into the metadata of every record.
        :raises ValueError: if ``sep`` is empty.
        """
        if not sep:
            raise ValueError("sep must be a non-empty string")
        self.logging_batch_size = logging_batch_size
        self.stream_chunk_size = stream_chunk_size
        self.sep = sep
        self.extra_info = extra_info or dict()

    def parse(self, input_file, limit=None):
        """Yield ``(record, info)`` pairs; ``info`` is a fresh copy of ``extra_info``.

        :param limit: stop after this many records (``None`` = no limit).
        """
        index = 0
        for elem in self.get_elements(input_file):
            # Copy per record, so consumers mutating one info dict do not
            # affect the others.
            yield elem, dict(self.extra_info)
            index += 1
            if index % self.logging_batch_size == 0:
                LOG.info("Parsed %s JSON items", index)
            if limit is not None and index >= limit:
                break
        LOG.info("Parsed %s JSON items; done!", index)

    def get_elements(self, input_file):
        """Yield the records of ``input_file`` without their separators.

        Consecutive separators produce empty records; a trailing separator
        does not produce a final empty record. ``input_file`` must provide
        ``read(size=...)`` returning an empty string at end of stream.
        """
        sep = self.sep
        sep_len = len(sep)
        buff = ""
        start = 0        # offset of the first unconsumed character in buff
        search_from = 0  # offset from which buff may still contain a separator
        while True:
            index = buff.find(sep, search_from)
            if index != -1:
                yield buff[start:index]
                start = search_from = index + sep_len
                continue
            data = input_file.read(size=self.stream_chunk_size)
            if len(data) == 0:
                if start < len(buff):
                    yield buff[start:]
                return
            # Trim consumed records only when refilling, instead of
            # re-slicing the whole buffer after every record.
            buff = buff[start:]
            start = 0
            # A separator may straddle the old tail and the new chunk, so
            # re-scan the last len(sep) - 1 characters of the old buffer.
            search_from = max(0, len(buff) - sep_len + 1)
            buff += data


class JsonParser(object):
    """Yield items from a JSON document selected by an ijson prefix."""

    def __init__(self, ijson_path="item", logging_batch_size=1000, extra_info=None):
        """
        :param ijson_path: ijson prefix of the items to yield. For the document
            ``[{"name": "test", "data": {"subname": "test1"}},
            {"name": "test", "data": {"subname": "test2"}}]``,
            ``"item"`` yields both top-level objects, while ``"item.data"``
            yields ``{"subname": "test1"}`` and ``{"subname": "test2"}``.
        :param logging_batch_size: log progress every this many items.
        :param extra_info: dict copied into the metadata of every item.

        TODO: numbers have been reported to come back as ``decimal.Decimal``,
        which the yson format cannot handle. ``get_elements`` requests
        ``use_float=True``; confirm whether the problem still occurs.
        """
        self.extra_info = extra_info if extra_info else dict()
        self.logging_batch_size = logging_batch_size
        self.json_path = ijson_path

    def get_elements(self, input_file):
        """Return ijson's lazy iterator over the items at ``self.json_path``."""
        prefix = self.json_path
        return ijson.items(input_file, prefix, use_float=True)

    def parse(self, input_file, limit=None):
        """Yield ``(item, info)`` pairs, ``info`` being a fresh copy of ``extra_info``.

        :param limit: stop after this many items (``None`` = no limit).
        """
        count = 0
        for element in self.get_elements(input_file):
            yield element, dict(self.extra_info)
            count += 1
            if not count % self.logging_batch_size:
                LOG.info("Parsed %s JSON items", count)
            if limit is not None and count >= limit:
                break
        LOG.info("Parsed %s JSON items; done!", count)
