import os
import stat
import socket
import hashlib
import logging
import requests

try:
    from typing import Optional, Generator  # noqa
except ImportError:
    pass

from . import s3_link
from .. import config as common_config
from ..types import misc as ctm
from ..types import resource as ctr


DEFAULT_CHUNK_SIZE = 0x4FFFFF  # Default size of a data chunk in bytes (5 MiB - 1)
TIMEOUT = 30  # Default mds request timeout in seconds, passed as `timeout` to requests


class MD5Mismatch(Exception):
    """ Raised when the md5 of downloaded data differs from the expected md5 """


def iter_content(response, logger=None, chunk_size=DEFAULT_CHUNK_SIZE, timeout=TIMEOUT):
    # type: (requests.Response, Optional[logging.Logger], Optional[int], Optional[int]) -> Generator[bytes]
    """
    Iterate over the body of a streamed mds response and yield data chunks.

    If fetching a chunk fails with a socket or I/O error, the same URL is requested
    again with the original request headers plus a Range header that starts at the
    first byte not yielded yet, and iteration continues from the new response.

    :param response: streamed response whose body is iterated
    :param logger: logger for retry messages; the "mds_stream" logger is used if omitted
    :param chunk_size: chunk size passed to `response.iter_content`
    :param timeout: timeout passed to `requests.get` when the data is requested again
    :return: generator of data chunks
    """
    logger = logger or logging.getLogger("mds_stream")
    url = response.request.url
    headers = response.request.headers
    response_raw_stream_func = response.raw.stream
    content_iter = iter(response.iter_content(chunk_size))
    # Number of bytes yielded so far, i.e. the zero-based offset of the next byte to fetch
    size = 0
    while True:
        try:
            chunk = next(content_iter)
        except StopIteration:
            break
        except (socket.error, IOError) as ex:
            logger.warning("Error while getting %s: %s", url, ex)
            logger.info("Getting %s from byte offset %s", url, size)
            req_headers = dict(headers)
            # HTTP byte ranges are zero-based: after `size` received bytes the next one is at offset `size`
            req_headers[ctm.HTTPHeader.RANGE] = "bytes={}-".format(size)
            response = requests.get(
                url,
                stream=True,
                headers=req_headers,
                timeout=timeout
            )
            # NOTE(review): this makes the new response read through the stream function of the
            # initial response; it looks like a hook for a patched `raw.stream` - confirm it is intended
            response.raw.stream = response_raw_stream_func
            content_iter = iter(response.iter_content(chunk_size))
            continue
        size += len(chunk)
        yield chunk


class Chunker(object):
    """
    Forward-only reader of byte ranges from a single streamed HTTP download.

    The body of the url is fetched once as a sequence of chunks. `read` serves byte ranges
    from that sequence and keeps the current chunk and the position in it between calls.
    """

    def __init__(self, url, logger=None, chunk_size=DEFAULT_CHUNK_SIZE, timeout=TIMEOUT):
        """
        Request the url as a stream and load the first chunk.

        :param url: url to download
        :param logger: logger passed to `iter_content`
        :param chunk_size: size of chunks requested from the stream
        :param timeout: timeout passed to `requests.get` and to `iter_content`
        """
        self.offset = 0  # stream offset of the first byte of `self.data`
        self.data_offset = 0  # position inside `self.data` set by the last read
        self.chunk_size = chunk_size
        self.timeout = timeout

        response = requests.get(url, stream=True, timeout=self.timeout)
        if response.status_code != requests.codes.OK:
            response.raise_for_status()
        self.content_iter = iter(
            iter_content(response, logger=logger, chunk_size=self.chunk_size, timeout=self.timeout)
        )

        try:
            self.data = next(self.content_iter)
        except StopIteration:
            # Empty body
            self.data = b""

    def read(self, offset, size=None):
        """
        Read size bytes from offset. If size is None read all bytes from stream.

        This is a generator yielding pieces of the requested range. The stream is only
        consumed forward: a range that ends at or before the current position yields
        nothing, and if the stream ends before the range is complete the generator just stops.
        NOTE(review): a range starting before the current position but ending after it
        is not handled specially - callers are expected to request ranges in stream order.

        :param offset: offset of the first byte of the range from the beginning of the stream
        :param size: length of the range in bytes or None to read up to the end of the stream
        """
        try:
            if size and self.offset + self.data_offset >= offset + size:
                return

            # Skip chunks which lie completely before the requested range
            while self.offset + len(self.data) <= offset:
                self.offset += len(self.data)
                self.data_offset = 0
                self.data = next(self.content_iter)

            self.data_offset = offset - self.offset

            # Yield chunks which end inside the requested range
            while size is None or self.offset + len(self.data) <= offset + size:
                if not self.data_offset:
                    yield self.data
                else:
                    yield self.data[self.data_offset:]
                self.offset += len(self.data)
                self.data_offset = 0
                self.data = next(self.content_iter)

            # The range ends inside the current chunk: yield only the head of the chunk
            new_data_offset = offset + size - self.offset
            final_data = self.data[self.data_offset:new_data_offset]
            self.data_offset = new_data_offset

            if final_data:
                yield final_data
        except StopIteration:
            # The stream is exhausted and `self.offset` already points past the last chunk.
            # Drop that chunk so subsequent reads cannot yield its bytes again.
            self.data = b""
            self.data_offset = 0
            return


def _write_with_md5(out_file, chunks):
    """ Write every chunk to out_file and return hex md5 digest of the written data """
    hasher = hashlib.md5()
    for chunk in chunks:
        hasher.update(chunk)
        out_file.write(chunk)
    return hasher.hexdigest()


def read_from_mds(
    resource, resource_path, config=None, logger=None, chunk_size=DEFAULT_CHUNK_SIZE,
    timeout=TIMEOUT
):
    """
    Download resource to resource_path from mds.
    Return path to downloaded resource.
    If resource with old directory format return None.

    None is also returned when the resource has no "mds" data and when a resource
    which is not multifile has no "executable" value.

    :param resource: resource description; "id", "file_name", "mds", "multifile", "executable" and "md5" are used
    :param resource_path: directory to download the resource to
    :param config: settings registry, `common_config.Registry()` is used if omitted
    :param logger: logger to use, the "mds_reader" logger is used if omitted
    :param chunk_size: size of chunks requested from mds
    :param timeout: timeout of requests to mds
    :raises MD5Mismatch: if md5 of downloaded data differs from the expected one
    """
    rid = resource["id"]
    mds = resource.get("mds")
    if not mds:
        return None

    config = config or common_config.Registry()
    logger = logger or logging.getLogger("mds_reader")

    mds_link = s3_link(mds["key"], mds.get("namespace"), mds_settings=config.common.mds)
    file_name = resource["file_name"]
    # Only the last component of the resource file name is used under resource_path
    file_path = os.path.join(resource_path, file_name.split("/")[-1])
    file_dir = os.path.dirname(file_path)
    if not os.path.exists(file_dir):
        os.makedirs(file_dir, mode=0o755)

    if not resource.get("multifile"):
        if resource.get("executable") is None:
            return None
        with open(str(file_path), "wb") as f:
            chunker = Chunker(mds_link, logger=logger, chunk_size=chunk_size, timeout=timeout)
            file_md5 = _write_with_md5(f, chunker.read(0, size=None))
            if resource.get("md5", "") != file_md5:
                raise MD5Mismatch(
                    "Md5 '{}' differs with original '{}' for resource {}.".format(
                        file_md5, resource.get("md5", ""), resource["id"]
                    )
                )
        if resource["executable"]:
            file_stat = os.stat(file_path).st_mode
            os.chmod(file_path, file_stat | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
        return file_path

    # Metadata of a multifile resource is stored under the key equal to the resource id
    metadata_link = s3_link(
        str(rid), mds.get("namespace"), mds_settings=config.common.mds
    )
    logger.debug("Downloading metadata for resource #%s from %s", rid, metadata_link)
    mds_metadata_resp = requests.get(
        metadata_link,
        timeout=timeout,
    )

    if mds_metadata_resp.status_code != requests.codes.OK:
        logger.debug("Multifile resource #%s uses old storage schema", rid)
        return None

    mds_metadata = mds_metadata_resp.json()

    # All regular files are cut from the single stream by offset and size from the metadata
    chunker = Chunker(mds_link, logger=logger, chunk_size=chunk_size, timeout=timeout)
    # NOTE(review): the first two metadata items are skipped - presumably they do not describe
    # resource content; confirm against the metadata format
    for item in mds_metadata[2:]:
        item_type = item["type"]
        path = os.path.join(str(resource_path), item["path"])
        if item_type == ctr.FileType.DIR:
            os.makedirs(path, mode=0o755)
        elif item_type == ctr.FileType.SYMLINK:
            os.symlink(item["symlink"], path)
        elif item_type == ctr.FileType.TOUCH:
            # Such items get no data from the stream, the file is only created
            open(path, "a").close()
            os.chmod(path, 0o755 if item.get("executable") else 0o644)
        elif item_type == ctr.FileType.FILE:
            with open(path, "wb") as f:
                file_md5 = _write_with_md5(f, chunker.read(item["offset"], item["size"]))
                if item.get("md5", "") != file_md5:
                    # Report md5 of the item, it is the value the data was compared with
                    raise MD5Mismatch(
                        "Md5 '{}' differs with original '{}' for resource {} file '{}'.".format(
                            file_md5, item.get("md5", ""), resource["id"], item["path"]
                        )
                    )
            os.chmod(path, 0o755 if item.get("executable") else 0o644)

    return file_path
