import collections
import itertools
import os
import random
import time
import traceback

import six
from six.moves import (
    http_client,
    map,
)

import gevent
import gevent.pool
import gevent.queue
import gevent.monkey


# On Python 2 only: replace blocking stdlib primitives with gevent's cooperative
# versions (leaving ``thread`` and ``select`` unpatched), presumably so the
# http_client requests issued from gevent greenlets can run concurrently.
# NOTE(review): nothing is patched on Python 3 — confirm the sockets used there
# cooperate with gevent, otherwise pooled requests would run one after another.
if six.PY2:
    gevent.monkey.patch_all(thread=False, select=False)


class ConnectionPool(object):
    """Per-host pool of reusable HTTP(S) connections.

    Idle connections live in one gevent queue per host. A connection goes back
    into its host's queue only after a request on it completed without error;
    a connection that failed is closed and discarded.
    """

    def __init__(self, timeout, hosts, is_https=False):
        self.connection_cls = http_client.HTTPSConnection if is_https else http_client.HTTPConnection
        self.connection_pool = {host: gevent.queue.Queue() for host in hosts}
        self.timeout = timeout

    def request(self, host, *args):
        """Send one request to ``host`` over a pooled or newly created connection.

        :param host: key into the pool; also the address a new connection is opened to.
        :param args: forwarded unchanged to ``connection.request``.
        :return: a dict with either ``"response"`` and ``"data"`` (the response object
            and its fully read body) or ``"exception"`` and ``"traceback"`` when any
            step failed (including an unknown ``host``). Errors are reported in the
            dict rather than raised.
        """
        result = {}
        connection = None

        try:
            queue = self.connection_pool[host]

            try:
                connection = queue.get_nowait()
            except gevent.queue.Empty:
                connection = self.connection_cls(host, timeout=self.timeout)

            connection.request(*args)
            result["response"] = connection.getresponse()
            result["data"] = result["response"].read()

        except Exception as e:
            result["exception"] = e
            result["traceback"] = traceback.format_exc()
            # A connection that failed mid-request may be in an undefined state, so it is
            # not returned to the pool; close it so its socket is not leaked.
            if connection is not None:
                try:
                    connection.close()
                except Exception:
                    # Best effort: the original error is what gets reported.
                    pass
            return result

        queue.put(connection)
        return result


class AsyncRequest(object):
    """One HTTP request bound to a host, plus the outcome of sending it.

    Instances start unsent; ``send()`` performs the request through the given
    connection pool and records either the response and body or the error.
    """

    def __init__(self, connection_pool, hostname, batch, *args):
        self.connection_pool = connection_pool
        self.hostname = hostname
        # The batch of input rows this request was built from.
        self.batch = batch
        # Positional args forwarded to ``connection_pool.request`` after the hostname.
        self.args = args
        self.data = None
        self.response = None
        self.exception = None
        self.traceback = None
        self.ok = False

    def send(self):
        """Perform the request and record its outcome; returns ``self``.

        ``ok`` becomes True only for an HTTP 200 response. Errors reported by the
        pool, or raised while talking to it, are stored in ``exception`` and
        ``traceback`` instead of propagating.
        """
        try:
            result = self.connection_pool.request(self.hostname, *self.args)

            if "exception" in result:
                self.exception = result["exception"]
                self.traceback = result["traceback"]
                return self

            self.data = result["data"]
            self.response = result["response"]
            self.ok = self.response.status == 200
        except Exception as e:
            self.exception = e
            self.traceback = traceback.format_exc()

        return self

    @property
    def status(self):
        """HTTP status as a string, or ``"exception"`` if no response was received."""
        return str(self.response.status) if self.response is not None else "exception"

    @property
    def response_body(self):
        """Response body if a response was received, otherwise the error traceback."""
        # Decide by the presence of a response, not the truthiness of the body: an
        # empty body (b"") is a valid result and must not fall through to the traceback.
        return self.data if self.response is not None else self.traceback

    def __str__(self):
        base = "Batch: {}\nHostname: {}\nRequest args: {}\n".format(self.batch, self.hostname, self.args)
        if self.response is not None:
            return base + "Response status: {}\nResponse data: {}".format(self.response.status, self.data)
        else:
            return base + "Exception: {}\nTraceback: {}".format(self.exception, self.traceback)


class ErrorStats(object):
    """Per-status counts of sent requests, keeping a few sample requests per status."""

    class HTTPCodeData(object):
        """Counter plus a bounded sample of the most recent requests for one status key."""

        def __init__(self, code, request_queue_size):
            self.code = code
            self.count = 0
            # Only the last ``request_queue_size`` requests are kept as samples.
            self.requests = collections.deque(maxlen=request_queue_size)

        def update(self, request):
            self.count += 1
            self.requests.append(request)

        def __str__(self):
            return "{}: {}\n{}".format(self.code, self.count, "\n\n".join(str(x) for x in self.requests))

    # Keys come from ``async_request.status``, which is a string (e.g. "200"), so the
    # success key must be a string too; comparing against the int 200 never matches.
    _SUCCESS_KEY = "200"

    def __init__(self, request_queue_size=20):
        self.total_count = 0
        self.successes = 0
        self.stats = {}
        self.request_queue_size = request_queue_size

    def update(self, async_request):
        """Count ``async_request`` under its status key and as a success if ``ok``."""
        self.total_count += 1

        key = async_request.status

        if key not in self.stats:
            self.stats[key] = self.HTTPCodeData(key, self.request_queue_size)
        self.stats[key].update(async_request)

        if async_request.ok:
            self.successes += 1

    def __str__(self):
        """Report totals, success percentage and every non-200 status with its samples."""
        # An empty round has no requests; report 0% instead of raising ZeroDivisionError
        # while an error report is being built.
        success_percent = 100. * self.successes / self.total_count if self.total_count else 0.
        return "Total {}, {:.2f}% successes\nStatuses:\n{}".format(
            self.total_count, success_percent,
            "\n\n".join(str(value) for key, value in self.stats.items() if key != self._SUCCESS_KEY)
        )


class HttpRequestsMapper(object):
    """Mapper that turns input rows into HTTP requests and yields processed responses.

    Rows are grouped into batches of ``batch_size``; ``get_request`` turns each batch
    into request args. The hook methods (``get_request_wo_headers``, ``get_headers``,
    ``process_response``, ``retry_failed_request``, ``is_request_successful``) have
    trivial defaults here and are presumably overridden by subclasses.

    Requests are sent in rounds once at least ``rps_limit`` are pending, with at least
    one second between round launches. A host whose requests all fail in
    ``max_fails`` consecutive rounds in which it was used is dropped; when no host is
    left, an exception carrying error statistics is raised.
    """

    def __init__(self, hosts, url, timeout, threads, rps_limit, max_fails, failed_request_queue_size, batch_size=1, is_https=False):
        self.hosts = list(hosts)
        self.is_https = is_https
        self.timeout = timeout
        self.url = url
        self.threads = threads
        self.rps_limit = rps_limit
        self.failed_request_queue_size = failed_request_queue_size

        self.reachable_hosts = list(self.hosts)
        self.failed_hosts = []
        # Built from the materialized list: ``hosts`` may be a one-shot iterator that
        # ``list(hosts)`` above has already exhausted.
        self.failures = {host: 0 for host in self.hosts}
        self.max_failures = max_fails
        self.connection_pool = None

        self.requests = []

        self.total_errors = ErrorStats(failed_request_queue_size)
        self.last_launch_ts = 0.
        # Stats of the last ``max_failures`` rounds, reported when every host is dropped.
        self.error_stats_queue = collections.deque(maxlen=self.max_failures)

        self.batch = []
        self.batch_size = batch_size

    def start(self):
        """Seed host selection from ``YT_JOB_ID`` and create the connection pool."""
        # Seeding with the job id makes the random host choice reproducible per job.
        random.seed(os.environ["YT_JOB_ID"])
        self.connection_pool = ConnectionPool(self.timeout, self.hosts, is_https=self.is_https)

    def get_request(self, batch):
        """Yield request args for ``batch``, each with the headers dict appended last."""
        for args in self.get_request_wo_headers(batch):
            yield args + (self.get_headers(),)

    def get_request_wo_headers(self, batch):
        return iter(())

    def get_headers(self):
        return {}

    def process_response(self, request):
        return iter(())

    def retry_failed_request(self, request):
        return True

    def is_request_successful(self, request):
        return request.ok

    def drop_host(self, host):
        """Move ``host`` from the reachable list to the failed list (no-op if already gone)."""
        if host in self.reachable_hosts:
            self.failed_hosts.append(host)
            self.reachable_hosts.remove(host)

    def create_requests(self, batch, prev_host=None):
        """Build one ``AsyncRequest`` per args from ``get_request(batch)`` on random reachable hosts.

        If ``prev_host`` is given and other reachable hosts exist, it is excluded.
        """
        hosts = list(self.reachable_hosts)
        if len(hosts) > 1 and prev_host is not None and prev_host in hosts:
            hosts.remove(prev_host)

        return [
            AsyncRequest(self.connection_pool, random.choice(hosts), batch, *args)
            for args in self.get_request(batch)
        ]

    def send_requests(self, requests):
        """Send one round of ``requests`` concurrently and account for the outcome.

        :return: ``(request_results, requests_to_retry)`` — a lazy iterable over what
            ``process_response`` yields for every request of the round, and freshly
            built requests for failed ones that ``retry_failed_request`` accepted.
        :raises Exception: if no reachable host remains after this round.
        """
        # Rate limit: rounds start at least one second apart.
        current_time = time.time()
        elapsed_time = current_time - self.last_launch_ts
        if elapsed_time < 1.:
            time.sleep(1. - elapsed_time)
        self.last_launch_ts = time.time()

        pool = gevent.pool.Pool(size=self.threads)
        jobs = [pool.spawn(request.send) for request in requests]
        gevent.joinall(jobs)

        retry_requests = []
        healthy_hosts = set()
        contacted_hosts = set()
        error_stats = ErrorStats(self.failed_request_queue_size)

        for request in requests:
            host = request.hostname
            contacted_hosts.add(host)
            if self.is_request_successful(request):
                healthy_hosts.add(host)
            elif self.retry_failed_request(request):
                retry_requests.append(request)

            error_stats.update(request)
            self.total_errors.update(request)

        # Lazy: process_response runs only as the caller consumes the results.
        request_results = itertools.chain.from_iterable(map(self.process_response, requests))

        self.error_stats_queue.append(error_stats)

        self._update_host_failures(contacted_hosts, healthy_hosts)

        self.raise_if_hosts_unreachable()

        # NOTE(review): each failed request rebuilds *all* requests of its batch, so a
        # batch that yields several requests may be re-sent more than once — confirm.
        requests_to_retry = [
            new_request
            for request in retry_requests
            for new_request in self.create_requests(request.batch, request.hostname)
        ]

        return request_results, requests_to_retry

    def _update_host_failures(self, contacted_hosts, healthy_hosts):
        """Update consecutive-failure counters after a round and drop hosts over the limit.

        Only hosts that were sent at least one request this round are touched: a host
        with any successful request has its counter reset; a host whose requests all
        failed gets +1 and is dropped once it reaches ``max_failures``. Hosts that
        random selection did not pick keep their counter unchanged.
        """
        # Iterate over a copy: drop_host removes from self.reachable_hosts, and mutating
        # the list being iterated would silently skip the host after a dropped one.
        for host in list(self.reachable_hosts):
            if host not in contacted_hosts:
                continue
            if host in healthy_hosts:
                self.failures[host] = 0
            else:
                self.failures[host] += 1
                if self.failures[host] >= self.max_failures:
                    self.drop_host(host)

    def raise_if_hosts_unreachable(self):
        """Raise with the recent per-round and total error stats if every host was dropped."""
        if not self.reachable_hosts:
            raise Exception("All hosts are unreachable: {}\nErrors stats:\n{}\nTotal Errors:\n{}".format(
                " ".join(self.failed_hosts),
                "\n".join(map(str, self.error_stats_queue)),
                str(self.total_errors)
            ))

    def __call__(self, row):
        """Accept one input row; yield processed results whenever a round is sent.

        Rows are buffered until ``batch_size`` of them form a batch; a round is sent
        once at least ``rps_limit`` requests are pending, and the retries it produces
        become the pending requests for the next round.
        """
        self.batch.append(row)

        if len(self.batch) >= self.batch_size:
            self.requests += self.create_requests(self.batch)
            self.batch = []

            if len(self.requests) >= self.rps_limit:
                request_results, requests_to_retry = self.send_requests(self.requests)
                self.requests = requests_to_retry
                for result in request_results:
                    yield result

    def finish(self):
        """Flush the partial batch, then send rounds until no retries remain."""
        if self.batch:
            self.requests += self.create_requests(self.batch)
            # Clear it like __call__ does, so a repeated finish() does not resend it.
            self.batch = []

        while self.requests:
            request_results, requests_to_retry = self.send_requests(self.requests)
            self.requests = requests_to_retry
            for result in request_results:
                yield result