import contextlib
import datetime
import logging
import os
import sys
import tarfile
import time
import zipfile
import zlib

import requests

# Module-wide logger, named after this module.
LOG = logging.getLogger(__name__)

# Number of bytes in one megabyte (2**20).
MB = 1 << 20
# Chunk size, in bytes, requested when iterating over a response body.
CHUNK_SIZE = 64 * 1024
# Progress-logging interval used by download_to_file, in megabytes.
BATCH_SIZE = 100
# Error-count threshold checked before a streaming read is retried.
MAX_ERRORS = 5
# Base delay between retries in seconds; multiplied by the error count so far.
RETRY_TIMEOUT_SECONDS = 1


def download_to_file(url, out, headers, auth):
    """Download ``url`` into the local file ``out``, logging progress.

    The body is streamed in CHUNK_SIZE pieces, so the whole payload is never
    held in memory. A progress line is logged each time another BATCH_SIZE
    megabytes have been written, and a summary line once the download ends.

    Args:
        url: address of the resource to fetch.
        out: path of the file to (over)write with the response body.
        headers: HTTP headers passed through to ``requests.get`` (may be None).
        auth: authentication object passed through to ``requests.get``
            (may be None).

    Raises:
        requests.HTTPError: if the server answers with a status other than 200.
    """
    response = requests.get(url, stream=True, headers=headers, auth=auth)
    try:
        if response.status_code != 200:
            raise requests.HTTPError(
                "Non-200 status code returned ({})".format(response.status_code),
                response=response)
        LOG.info("Downloading feed '%s' to '%s'", url, out)
        started = time.time()
        downloaded = 0
        previous_batch = 0
        with open(out, "wb") as out_handle:
            for block in response.iter_content(CHUNK_SIZE):
                if not block:
                    continue
                out_handle.write(block)
                downloaded += len(block)
                # Floor division keeps the batch index an integer on every
                # Python version, so progress is logged once per batch rather
                # than once per chunk.
                current_batch = downloaded // (BATCH_SIZE * MB)
                if current_batch > previous_batch:
                    LOG.info("Downloaded %d MB", current_batch * BATCH_SIZE)
                    previous_batch = current_batch
    finally:
        # A streamed response holds its connection until it is closed.
        response.close()
    elapsed = time.time() - started
    megabytes = float(downloaded) / MB
    # A very fast download (or a coarse clock) can yield a zero interval.
    rate = megabytes / elapsed if elapsed > 0 else 0.0
    LOG.info("Download completed: {:.2f} MB in {:.2f}s (~{:.2f} MB/s)".format(
        megabytes,
        float(elapsed),
        rate))


class StreamingDownloader(object):
    """Factory for file-like readers that fetch an HTTP resource lazily.

    Constructor arguments:
        decompress: when true, the body is passed through a zlib decompressor
            created with ``32 + zlib.MAX_WBITS`` (automatic gzip/zlib header
            detection).
        untar: when true, the (optionally decompressed) body is read as a tar
            stream and reads are served from its first member.
        filename: optional path the first tar member is expected to have.
    """

    def __init__(self, decompress=False, untar=False, filename=None):
        self.decompress = decompress
        self.untar = untar
        self.filename = filename

    def get(self, url, headers=None, auth=None):
        """Return a file-like object streaming the content of ``url``.

        No request is sent until data is first read, except when ``untar`` is
        set: opening the tar stream reads from the feed immediately.

        Args:
            url: address of the resource to stream.
            headers: optional dict of HTTP headers sent with every request.
            auth: optional authentication object handed to ``requests.get``.

        Returns:
            An object offering ``read(size)``, ``read_all()``, ``close()`` and
            the context-manager protocol.
        """
        class FileLikeIterator(object):
            """Buffered reader over a streamed HTTP response.

            Interrupted connections are resumed with an HTTP Range request
            when the server supports byte ranges.
            """

            def __init__(self, url, decompressobj, untar=False, filename=None):
                LOG.info("Streaming feed '%s'", url)
                self.url = url
                self.decompressobj = decompressobj
                self.http_response = None
                self.content_iterator = None
                self.started = time.time()
                # Payload bytes produced so far (after optional decompression);
                # reported in the completion statistics.
                self.downloaded = 0
                # Bytes received from the response body before decompression.
                # Range offsets address this representation, so resuming must
                # use this counter and not the decompressed one.
                self.raw_downloaded = 0
                self.num_errors = 0
                # Data already fetched but not yet handed to the caller.
                self.data = b""
                self.supports_retries = True
                self.content_length = None
                self.untar = untar
                self.filename = filename
                if self.untar:
                    try:
                        self.untarred_feed = self._open_first_tar_member()
                    except Exception:
                        # The caller never receives an object it could close,
                        # so release the connection before propagating.
                        self._close_response_quietly()
                        raise

            def _open_first_tar_member(self):
                """Open the feed as a tar stream and return its first member."""
                class Reader(object):
                    """Minimal read-only adapter exposing the raw feed to tarfile."""

                    def __init__(self, source):
                        self.source = source

                    def read(self, size):
                        return self.source.read_feed(size)

                tar_feed = tarfile.open(fileobj=Reader(self), mode='r|')
                tar_file = tar_feed.next()
                if self.filename is not None and tar_file.path != self.filename:
                    # Raised explicitly (same exception type as before) so the
                    # check also runs when assertions are disabled with -O.
                    raise AssertionError(
                        "Feed filename mismatch: {} != {}".format(tar_file.path, self.filename))
                return tar_feed.extractfile(tar_file)

            def _close_response_quietly(self):
                """Best-effort close of the current response."""
                try:
                    if self.http_response is not None:
                        self.http_response.close()
                except Exception:
                    # Deliberately ignored: the connection is being discarded
                    # anyway, so a failing close must not mask the real error.
                    LOG.debug("Ignoring error while closing response", exc_info=True)

            def get_content_iterator(self):
                """Return the chunk iterator, (re)opening the request if needed.

                Raises:
                    requests.HTTPError: on a status other than 200/206, or when
                        a resume request is not answered with 206.
                """
                if self.content_iterator is None:
                    resuming = self.raw_downloaded != 0 and self.supports_retries
                    if resuming:
                        new_headers = dict(headers) if headers is not None else {}
                        # Open-ended range: everything from the first byte not
                        # yet received up to the end of the resource.
                        new_headers["Range"] = "bytes={}-".format(self.raw_downloaded)
                    else:
                        new_headers = headers
                    self.http_response = requests.get(self.url, stream=True, headers=new_headers, auth=auth)
                    status = self.http_response.status_code
                    if status not in {200, 206}:
                        raise requests.HTTPError("Non-200 status code returned ({}) for request on url '{}' (headers: {})"
                                                 .format(status, self.url, new_headers), response=self.http_response)
                    if resuming and status != 206:
                        # The server replied with the full body again; reading
                        # it would duplicate data that was already delivered.
                        self.supports_retries = False
                        raise requests.HTTPError("Range request was not honoured (status {}) for url '{}'"
                                                 .format(status, self.url), response=self.http_response)
                    # A 206 reply proves range support even without the header.
                    self.supports_retries = (status == 206
                                             or self.http_response.headers.get("Accept-Ranges") == "bytes")
                    # sys.maxint only exists on Python 2; sys.maxsize is the
                    # fallback used elsewhere.
                    unknown_length = getattr(sys, "maxint", sys.maxsize)
                    self.content_length = self.raw_downloaded + int(
                        self.http_response.headers.get("Content-Length", unknown_length))
                    self.content_iterator = self.http_response.iter_content(chunk_size=CHUNK_SIZE)
                return self.content_iterator

            def unpack(self, chunk):
                """Decompress ``chunk`` if a decompressor is configured."""
                if self.decompressobj:
                    return self.decompressobj.decompress(chunk)
                return chunk

            def read_feed(self, size=None):
                """Read up to ``size`` bytes of the (decompressed) HTTP body.

                Returns fewer bytes than requested only once the stream is
                exhausted. Connection errors are retried with an increasing
                delay while the server supports ranges and the error count has
                not exceeded MAX_ERRORS; otherwise they propagate.
                """
                assert size is not None
                while len(self.data) < size:
                    try:
                        raw_chunk = next(self.get_content_iterator())
                    except StopIteration:
                        break
                    except requests.RequestException:
                        if not self.supports_retries or self.num_errors > MAX_ERRORS:
                            raise
                        LOG.warning("Connection interrupted, will retry")
                        self.content_iterator = None
                        self._close_response_quietly()
                        self.num_errors += 1
                        time.sleep(self.num_errors * RETRY_TIMEOUT_SECONDS)
                        continue
                    self.raw_downloaded += len(raw_chunk)
                    chunk = self.unpack(raw_chunk)
                    self.downloaded += len(chunk)
                    self.data = b"".join((self.data, chunk))

                result, self.data = self.data[:size], self.data[size:]
                return result

            def read(self, size):
                """Read up to ``size`` bytes from the tar member or the raw feed."""
                if self.untar:
                    return self.untarred_feed.read(size)
                return self.read_feed(size)

            def read_all(self):
                """Read the remaining content, one megabyte at a time."""
                chunks = []
                chunk = self.read(MB)
                while chunk:
                    chunks.append(chunk)
                    chunk = self.read(MB)
                return b''.join(chunks)

            def close(self):
                """Close the HTTP response and log throughput statistics."""
                if self.http_response is not None:
                    self.http_response.close()
                elapsed = time.time() - self.started
                megabytes = float(self.downloaded) / MB
                # Guard against a zero interval on very short-lived streams.
                rate = megabytes / elapsed if elapsed > 0 else 0.0
                LOG.info("Stream completed: {:.2f} MB in {:.2f}s (~{:.2f} MB/s)".format(
                    megabytes,
                    float(elapsed),
                    rate))

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                self.close()

        decompressobj = zlib.decompressobj(32 + zlib.MAX_WBITS) if self.decompress else None
        return FileLikeIterator(url, decompressobj=decompressobj, untar=self.untar, filename=self.filename)


class PagedStreamingDownloader(StreamingDownloader):
    """Streaming downloader that walks a paginated feed page by page.

    Each page is requested by appending ``&<offset_key>=<offset>`` and
    ``&<rows_key>=<step>`` to the base URL. Because the parameters are joined
    with '&', the URL presumably already carries a query string - confirm
    against callers.

    Constructor arguments:
        rows_key: name of the query parameter carrying the page size.
        offset_key: name of the query parameter carrying the offset.
        step: page size sent with every request.
        limit: paging stops once the offset is no longer below this value.
        decompress: forwarded to StreamingDownloader.
        initial_offset: offset used for the first page.
        offset_in_pages: when true the offset advances by 1 per page,
            otherwise by ``step``.
    """

    def __init__(self, rows_key="rows", offset_key="offset", step=1000, limit=100000, decompress=False, initial_offset=0, offset_in_pages=False):
        super(PagedStreamingDownloader, self).__init__(decompress)
        self.initial_offset = initial_offset
        self.offset_in_pages = offset_in_pages
        self.limit = limit
        self.step = step
        self.offset_key = offset_key
        self.rows_key = rows_key

    @contextlib.contextmanager
    def get(self, url, headers=None, auth=None):
        """Context manager yielding a generator of per-page streams.

        The generator is lazy: the stream for a page is only created when the
        consumer advances to it.
        """
        def iterate_pages():
            position = self.initial_offset
            while True:
                if not position < self.limit:
                    break
                page_url = "{}&{}={}&{}={}".format(url, self.offset_key, position, self.rows_key, self.step)
                yield super(PagedStreamingDownloader, self).get(page_url, headers, auth)
                if self.offset_in_pages:
                    position += 1
                else:
                    position += self.step

        yield iterate_pages()


class TmpFileDownloader(object):
    """Downloads a resource to a temporary file and exposes it for reading."""

    @contextlib.contextmanager
    def get(self, url, headers=None, auth=None):
        """Context manager yielding an open handle on the downloaded file.

        The file is written to the current working directory under a
        timestamp-based name and is removed when the context exits, including
        when the download or the consumer's code raises.

        Args:
            url: address of the resource to download.
            headers: optional HTTP headers for the request.
            auth: optional authentication object for the request.
        """
        tmp_file = "tmp_{0}".format(datetime.datetime.now().isoformat())
        try:
            download_to_file(url, tmp_file, headers, auth)
            with open(tmp_file) as f:
                yield f
        finally:
            # Runs after the handle is closed, and also when an exception is
            # thrown into the generator at the yield.
            if os.path.exists(tmp_file):
                os.remove(tmp_file)


class ZippedFileDownloader(object):
    """Downloads a zip archive to a temporary file and exposes it as a ZipFile."""

    @contextlib.contextmanager
    def get(self, url, headers=None, auth=None):
        """Context manager yielding a ``zipfile.ZipFile`` over the download.

        The archive is written to the current working directory under a
        timestamp-based name and is removed when the context exits, including
        when the download, the archive opening or the consumer's code raises.

        Args:
            url: address of the archive to download.
            headers: optional HTTP headers for the request.
            auth: optional authentication object for the request.
        """
        tmp_file = "tmp_{0}".format(datetime.datetime.now().isoformat())
        try:
            download_to_file(url, tmp_file, headers, auth)
            with zipfile.ZipFile(tmp_file) as f:
                yield f
        finally:
            # Runs after the archive is closed, and also when an exception is
            # thrown into the generator at the yield.
            if os.path.exists(tmp_file):
                os.remove(tmp_file)
