import contextlib
import json
import logging
import sys
from collections import defaultdict
from functools import update_wrapper
from threading import Lock
from typing import Optional
from urllib.parse import urlencode, urljoin

import requests
from django.conf import settings
from rest_framework import status as http_status

from smarttv.droideka.proxy import exceptions
from smarttv.droideka.unistat import manager, metrics
from plus.utils.http import make_requests_session

# Module logger; also bound as the default ``logger`` argument of log_and_raise().
logger = logging.getLogger(__name__)


class DummyCtx:
    """Context manager that does nothing.

    ``with DummyCtx() as value`` binds ``None`` to ``value``; exceptions raised
    inside the ``with`` body propagate unchanged.
    """

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        # A falsy return value tells the interpreter not to suppress exceptions.
        return None


def log_and_raise(exception_cls, log_msg, log_msg_args=(), exception_msg=None, logger=logger):
    """Log an error message and raise ``exception_cls``.

    :param exception_cls: exception class to instantiate and raise.
    :param log_msg: %-style format string for the log record.
    :param log_msg_args: arguments substituted into ``log_msg``.
    :param exception_msg: message for the raised exception; when ``None`` the
        formatted ``log_msg`` is used instead.
    :param logger: logger to write to (defaults to this module's logger).
    :raises exception_cls: always.
    """
    # Logger.exception() is meant to be called from an exception handler.
    # Outside one (e.g. from BaseApi.handle_status) it appends a meaningless
    # "NoneType: None" traceback, so attach exc_info only when an exception
    # is actually being handled. Both branches log at ERROR level.
    if sys.exc_info()[0] is not None:
        logger.exception(log_msg, *log_msg_args)
    else:
        logger.error(log_msg, *log_msg_args)
    if exception_msg is None:
        # Format the way logging does: apply % only when arguments were given,
        # so a literal '%' in an argument-less message cannot break the raise.
        exception_msg = log_msg % log_msg_args if log_msg_args else log_msg
    raise exception_cls(exception_msg)


# noinspection PyMethodMayBeStatic
class BaseApi:
    """Base HTTP client for an upstream service.

    Subclasses must define ``unistat_suffix`` (``__init__`` raises
    ``ValueError`` otherwise). Requests go through a pooled session built by
    ``make_requests_session``; responses with status >= 400 are translated into
    the nested exception classes below by ``handle_status``.
    """
    # When not None, sent as the Content-Type / Accept header on every request.
    content_type_header = None
    accept_header = None

    # Metrics name suffix; counters/timer are registered only when it is listed
    # in metrics.EXTERNAL_SERVICES.
    unistat_suffix = None
    # Header names, compared against the lower-cased response header name,
    # that get_response_headers() copies into error messages.
    LOG_ERROR_RESPONSE_HEADERS = tuple()

    class RequestError(exceptions.InternalServerError):
        pass

    class BadGatewayError(exceptions.BadGatewayError):
        pass

    class GatewayTimeoutError(exceptions.GatewayTimeoutError):
        pass

    class NotFoundError(exceptions.NotFoundError):
        pass

    class Forbidden(exceptions.ForbiddenError):
        pass

    class ConflictError(exceptions.APIException):
        status_code = http_status.HTTP_409_CONFLICT
        default_detail = 'Can not perform action because of conflict'
        default_code = 'conflict'

    class PermissionDeniedError(exceptions.PermissionDeniedError):
        pass

    class BadRequestError(exceptions.BadRequestError):
        pass

    def __init__(self, url, timeout, retries):
        """
        :param url: base URL that endpoints are resolved against with urljoin().
        :param timeout: default timeout used when _request() gets a falsy one.
        :param retries: ``max_retries`` for the session; also the amount added
            to the 'retry' counter when retries are exhausted.
        """
        self.api_url = url
        self.timeout = timeout
        self.retries = retries

        logger.info('Creating session with pool max size: %s,'
                    'pool connections: %s for %s',
                    settings.NETWORK_CONNECTION_POOL_MAX_SIZE,
                    settings.NETWORK_CONNECTION_POOL_SIZE,
                    self.__class__.__name__)
        self.session = make_requests_session(max_retries=retries,
                                             pool_maxsize=int(settings.NETWORK_CONNECTION_POOL_MAX_SIZE),
                                             pool_connections=int(settings.NETWORK_CONNECTION_POOL_SIZE))
        self.unistat_counters = {}
        self.unistat_timer = None
        self.init_unistat_counters()

    def init_unistat_counters(self):
        """Register per-status request counters and a timing timer.

        :raises ValueError: if the subclass did not define ``unistat_suffix``.
        """
        if not self.unistat_suffix:
            raise ValueError(f'You forget to define unistat_suffix in {self.__class__}')
        if self.unistat_suffix not in metrics.EXTERNAL_SERVICES:
            # Not a tracked service: counters stay empty and the timer stays None.
            return
        for status in metrics.RESPONSE_STATUSES:
            counter_name = metrics.get_request_counter_name(self.unistat_suffix, status)
            self.unistat_counters[status] = manager.get_counter(counter_name)

        timer_name = metrics.get_request_counter_name(self.unistat_suffix, 'timing')
        self.unistat_timer = manager.get_timer(timer_name)

    @property
    def proxies(self) -> Optional[dict]:
        """Proxies passed to ``session.request``; None by default."""
        return None

    @property
    def verify_ssl(self) -> bool:
        """Value passed as ``verify`` to ``session.request``; True by default."""
        return True

    def increment_unistat_counter(self, status, value=1):
        """Add ``value`` to the counter for ``status``; no-op if it is not registered."""
        counter = self.unistat_counters.get(status)
        if counter is not None:
            counter.increment(value)

    def get_response_headers(self, response):
        """Return the response headers whose lower-cased name is in LOG_ERROR_RESPONSE_HEADERS."""
        return {
            header_name: header_value
            for header_name, header_value in response.headers.items()
            if header_name.lower() in self.LOG_ERROR_RESPONSE_HEADERS
        }

    def update_params(self, params):
        """Hook for subclasses to adjust query parameters; returns them unchanged here."""
        return params

    def update_headers(self, headers):
        """Hook for subclasses to adjust request headers; returns them unchanged here."""
        return headers

    def make_request(self, method, url, headers, params, timeout, data, json, **kwargs):
        """Send one HTTP request through the session and update unistat counters.

        Extra ``kwargs`` are accepted so subclasses can receive extra options;
        they are ignored here. The '2xx'/'4xx'/'5xx' counter is incremented for
        the response status, and 'retry' by the number of recorded retries.
        """
        try:
            proxies = self.proxies  # don't measure time of getting tvm tokens
            with self.unistat_timer or contextlib.nullcontext():
                response = self.session.request(
                    method=method,
                    url=url,
                    headers=headers,
                    params=params,
                    data=data,
                    json=json,
                    timeout=timeout,
                    proxies=proxies,
                    verify=self.verify_ssl,
                )
        except requests.exceptions.RetryError:
            # Retries exhausted: add the configured retry amount to the counter.
            self.increment_unistat_counter('retry', self.retries)
            raise
        except requests.exceptions.BaseHTTPError:
            # NOTE(review): this is urllib3's HTTPError re-exported by requests.
            # requests wraps most transport failures (ConnectionError, Timeout,
            # ...) in RequestException subclasses, which are NOT caught here and
            # therefore are neither counted as 'error' nor converted by
            # _request() -- confirm this is intended.
            self.increment_unistat_counter('error')
            raise
        else:
            logger.debug('request elapsed %s', response.elapsed)
        if 200 <= response.status_code < 300:
            self.increment_unistat_counter('2xx')
        elif 400 <= response.status_code < 500:
            self.increment_unistat_counter('4xx')
        elif 500 <= response.status_code:
            self.increment_unistat_counter('5xx')
        # Retries urllib3 performed before this response arrived, if any.
        retry_count = len(response.raw.retries.history) if response.raw.retries else 0
        if retry_count > 0:
            self.increment_unistat_counter('retry', retry_count)
        return response

    def handle_status(self, response, **kwargs):
        """Raise the exception matching an error status (called for status >= 400).

        404 -> NotFoundError (not logged), 409 -> ConflictError (not logged),
        >= 500 -> BadGatewayError, anything else -> RequestError (both logged).
        ``kwargs`` (``params`` from _request) is unused here; it is for subclasses.
        """
        response_headers = self.get_response_headers(response)
        msg_args = (response.content, response_headers)
        if response.status_code == 404:
            # We don't want to log this error because it is not really an error;
            # we just raise it and DRF will send 404 to the user.
            msg = f'{self.__class__.__name__} responded with code 404, ' \
                  f'response body: %s, response_headers: %s'
            raise self.NotFoundError(msg % msg_args)
        if response.status_code == 409:
            raise self.ConflictError(f'{self.__class__.__name__} responded with code 409, '
                                     f'response body: {response.content}, response_headers: {response_headers}')
        if response.status_code >= 500:
            msg = f'{response.status_code} response code from upstream server in {self.__class__.__name__}, ' \
                  f'response body: %s, response_headers: %s'
            exception_cls = self.BadGatewayError
        else:
            msg = f'Unhandled exception: {self.__class__.__name__} responded with code {response.status_code}, ' \
                  f'response body: %s, response_headers: %s'
            exception_cls = self.RequestError

        log_and_raise(exception_cls, msg, log_msg_args=msg_args, logger=logger)

    def _request(self, method='GET', endpoint='', headers=None, params=None, timeout=None, data=None,
                 json=None, **make_request_kwargs):
        """Send a request to ``endpoint`` (resolved with urljoin) and return the response.

        ``headers`` is copied before use, so the caller's mapping is never
        modified. A falsy ``timeout`` falls back to ``self.timeout``.

        :raises GatewayTimeoutError: when ``requests`` raises RetryError.
        :raises RequestError: when ``requests.exceptions.BaseHTTPError`` is raised.
        :raises: whatever handle_status() raises for status codes >= 400.
        """
        # Work on a copy: update_headers() and the Accept/Content-Type defaults
        # below must not mutate a dict owned by the caller, which may be reused
        # across requests or threads.
        headers = self.update_headers(dict(headers) if headers else {})
        params = params or {}

        if self.accept_header:
            headers['Accept'] = self.accept_header
        if self.content_type_header:
            headers['Content-Type'] = self.content_type_header
        params = self.update_params(params)

        # NOTE(review): urljoin drops the last path segment of api_url unless it
        # ends with '/'; confirm configured URLs account for this.
        url = urljoin(self.api_url, endpoint)

        logger.debug('%s %s%s', method, url, f'?{urlencode(params, doseq=True)}' if params else '')
        try:
            response = self.make_request(
                method=method,
                url=url,
                headers=headers,
                params=params,
                timeout=timeout or self.timeout,
                data=data,
                json=json,
                **make_request_kwargs
            )
            if response.status_code >= 400:
                self.handle_status(response, params=params)
            return response
        except requests.exceptions.RetryError:
            msg = f'Request to {self.__class__.__name__} exceeded max_retries'
            log_and_raise(self.GatewayTimeoutError, msg, logger=logger)
        except requests.exceptions.BaseHTTPError as e:
            msg = f'Request to {self.__class__.__name__} failed due to {e.__class__}'
            log_and_raise(self.RequestError, msg, logger=logger)

        # Not reached in practice: log_and_raise() always raises.
        return None


# noinspection PyMethodMayBeStatic
class BaseJsonApi(BaseApi):
    """Base client for upstreams that exchange JSON.

    ``_request`` decodes the response body with ``json.loads`` and returns the
    decoded value; HEAD requests return ``None``.
    """
    content_type_header = 'application/json'
    accept_header = 'application/json'

    def handle_response(self, json_response, params, headers):
        """Hook for subclasses to inspect a decoded response; does nothing here.

        ``headers`` is the value the caller passed to ``_request`` (may be None).
        """
        pass

    def _request(self, method='GET', endpoint='', headers=None, params=None, timeout=None, data=None,
                 **make_request_kwargs):
        """Perform the request and return its JSON-decoded body.

        A TypeError or ValueError raised while decoding -- or by
        ``handle_response`` -- increments the 'parseerror' counter and is
        reported as ``BadGatewayError``.
        """
        response = super()._request(method, endpoint, headers, params, timeout, data=data, **make_request_kwargs)
        if response is None or method == 'HEAD':
            return None
        try:
            decoded = json.loads(response.content)
            self.handle_response(decoded, params, headers)
            return decoded
        except (TypeError, ValueError):
            self.increment_unistat_counter('parseerror')
            log_and_raise(
                self.BadGatewayError,
                f'{self.__class__.__name__} responded with bad json %s',
                log_msg_args=(response.content,),
                logger=logger,
            )
        return None


class BaseRpcApi(BaseApi):
    """Base client for upstreams that exchange protobuf messages over HTTP.

    ``_request`` requires a ``response_class`` keyword argument (a generated
    protobuf message class); the response body is decoded with
    ``response_class.FromString``.
    """
    content_type_header = 'application/protobuf'
    accept_header = 'application/protobuf'

    def _parse_response(self, response, response_class):
        # Unimplemented placeholder; nothing in this class calls it.
        pass

    def handle_response(self, json_response, params, headers):
        # Hook kept for parity with BaseJsonApi; nothing in this class calls it.
        pass

    def _request(self, method='GET', endpoint='', headers=None, params=None, timeout=None, data=None,
                 **make_request_kwargs):
        """Perform the request and return the body decoded as ``response_class``.

        ``response_class`` is read from ``make_request_kwargs`` and is also
        forwarded to BaseApi._request (make_request ignores it there).

        :raises KeyError: if ``response_class`` was not supplied.
        :raises BadGatewayError: if the body cannot be decoded.
        :raises: everything BaseApi._request raises.
        """
        response_class = make_request_kwargs['response_class']
        response = super()._request(method, endpoint, headers, params, timeout, data=data, **make_request_kwargs)
        if response is None:
            # Nothing to decode (same guard as BaseJsonApi._request).
            return None
        decode = response_class.FromString
        try:
            return decode(response.content)
        except Exception:
            # protobuf's DecodeError derives from google.protobuf.message.Error,
            # an Exception subclass -- not ValueError/TypeError -- so a narrower
            # clause never sees real decode failures. protobuf is not imported
            # by this module, hence the broad catch around the single decode
            # call; log_and_raise() runs inside this handler and logs the
            # original traceback.
            self.increment_unistat_counter('parseerror')
            msg = f'{self.__class__.__name__} responded with bad protobuf %s'
            msg_args = (response.content,)
            log_and_raise(self.BadGatewayError, msg, log_msg_args=msg_args, logger=logger)
        return None


class CacheItem:
    """One sampler slot holding a cached result or exception behind a lock.

    ``with item:`` acquires the lock and counts the call. Once ``max_count``
    calls have been counted since the last ``set``/``set_exc``, ``is_ready``
    is cleared so that the holder refreshes the value.
    """

    def __init__(self, max_count):
        self.max_count = max_count
        self._lock = Lock()
        self.count = 0
        self.data = None
        self.exc = None
        self.exc_tb = None
        self.is_ready = False

    def __enter__(self):
        self._lock.acquire()
        self.count += 1
        if self.max_count <= self.count:
            # Sampling budget used up: force the next holder to refresh.
            self.is_ready = False
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._lock.release()

    def get(self):
        """Return the cached data, or re-raise the cached exception."""
        if not self.exc:
            return self.data
        raise self.exc.with_traceback(self.exc_tb)

    def _store(self, data, exc, tb):
        # Common bookkeeping for set()/set_exc(): store the outcome, mark the
        # slot ready and restart the call counter.
        self.data, self.exc, self.exc_tb = data, exc, tb
        self.is_ready = True
        self.count = 0

    def set(self, data):
        """Cache a successful result."""
        self._store(data, None, None)

    def set_exc(self, exc, tb):
        """Cache an exception (with its traceback) to be re-raised by get()."""
        self._store(None, exc, tb)


class ApiSampler:
    """Serve most calls of a wrapped method from a per-key cached outcome.

    Each cache key maps to a ``CacheItem(fake_to_real_rate)``. The wrapped
    function actually runs only while that item is not ready; otherwise the
    stored return value is returned (or the stored exception re-raised).
    Calls sharing a key are serialized by the item's lock.
    """

    def __init__(self, fake_to_real_rate=100):
        self._cache = defaultdict(lambda: CacheItem(fake_to_real_rate))

    def wrap(self, by_endpoint=False):
        """Build a decorator for a *bound* method (``f.__self__`` is read).

        The cache key is ``id(f.__self__)`` followed by ``f.__name__``. With
        ``by_endpoint`` the endpoint -- the ``endpoint=`` keyword, else the
        second positional argument -- is appended when present and truthy.
        """
        def outer(f):
            def inner(*args, **kwargs):
                cache_key = f'{id(f.__self__)}{f.__name__}'
                if by_endpoint:
                    try:
                        endpoint = kwargs.get('endpoint') or args[1]
                    except IndexError:
                        # Neither a truthy keyword nor a second positional argument.
                        endpoint = None
                    if endpoint:
                        cache_key += endpoint

                with self._cache[cache_key] as slot:
                    if slot.is_ready:
                        return slot.get()
                    try:
                        outcome = f(*args, **kwargs)
                    except Exception as exc:
                        # Remember the failure so sampled calls re-raise it too.
                        slot.set_exc(exc, exc.__traceback__)
                        raise
                    slot.set(outcome)
                    return slot.get()

            return update_wrapper(inner, f)

        return outer

    def patch_request(self, obj, func_name='_request', by_endpoint=False):
        """Replace ``obj.<func_name>`` with its sampled wrapper via setattr."""
        original = getattr(obj, func_name)
        sampled = self.wrap(by_endpoint=by_endpoint)(original)
        setattr(obj, func_name, sampled)
