import contextlib
import logging

from django.conf import settings
from django.utils.cache import patch_vary_headers
from smarttv.utils.unistat import path_to_signal

from smarttv.droideka import unistat

logger = logging.getLogger(__name__)

# Default host list for CorsMiddleware.CORS_HOSTS when settings.CORS_HOSTS is not configured:
# the usual loopback names, including bracketed IPv6 forms.
LOCALHOST = [
    'localhost',
    '127.0.0.1',
    '[::]',
    '[::1]',
]

# Response header names written by the middleware in this module.
ACCESS_CONTROL_ALLOW_ORIGIN = 'Access-Control-Allow-Origin'
ACCESS_CONTROL_ALLOW_CREDENTIALS = 'Access-Control-Allow-Credentials'
CONTENT_SECURITY_POLICY = 'Content-Security-Policy'

# request.META key under which Django exposes the incoming `Origin` request header.
HTTP_ORIGIN = 'HTTP_ORIGIN'


class CorsMiddleware:
    """Add CORS headers to responses for requests addressed to the configured hosts.

    CORS handling is applied only when the request's host (``request.get_host()``)
    is listed in ``CORS_HOSTS``; any other response is returned untouched.
    """

    # Both settings are read once, when this class body is executed (module import),
    # so later changes to django settings are not picked up.
    CORS_HOSTS = getattr(settings, 'CORS_HOSTS', LOCALHOST)
    CORS_ORIGIN_ALLOW_ALL = getattr(settings, 'CORS_ORIGIN_ALLOW_ALL', False)

    # Policy sent for the Swagger UI page: images may be loaded from any source.
    # This is an HTTP header value, so it must not carry the `content="..."`
    # wrapper used by an HTML <meta http-equiv> tag (that wrapper made the
    # directive name unrecognisable and the policy a no-op).
    SWAGGER_CSP = "img-src * 'self' data: https: http:;"

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        """Run the downstream handler, then decorate its response with CORS headers."""
        response = self.get_response(request)

        if not self.is_enabled(request):
            return response

        if self.CORS_ORIGIN_ALLOW_ALL:
            # NOTE(review): browsers reject credentialed requests when the response
            # carries the `*` wildcard together with Allow-Credentials: true.
            response[ACCESS_CONTROL_ALLOW_ORIGIN] = '*'
        else:
            request_origin = request.META.get(HTTP_ORIGIN)
            if request_origin is not None:
                # Reflect the caller's origin back to it.
                response[ACCESS_CONTROL_ALLOW_ORIGIN] = request_origin
                # The response now depends on the Origin request header; tell caches
                # so a response for one origin is not served to another.
                patch_vary_headers(response, ('Origin',))
        response[ACCESS_CONTROL_ALLOW_CREDENTIALS] = 'true'
        if request.path == '/swagger/':
            response[CONTENT_SECURITY_POLICY] = self.SWAGGER_CSP

        return response

    def is_enabled(self, request):
        """Return True when the request's host (``request.get_host()``) is in ``CORS_HOSTS``."""
        host = request.get_host()
        enabled = host in self.CORS_HOSTS
        logger.debug("Host is '%s', Enabled: %s, CORS_HOSTS: %s", host, enabled, self.CORS_HOSTS)
        return enabled


class UnistatMiddleware:
    """Record per-path request counters and timings via ``unistat.manager``.

    With ``<sig>`` being ``path_to_signal(request.path)``, each request:

    * increments ``call-<sig>-cnt``;
    * runs the downstream handler inside the ``call-<sig>-duration`` timer
      (or a no-op context when the manager has no such timer);
    * for a response status >= 400, increments ``call-<sig>-fail`` and then
      ``call-<sig>-fail-4xx`` (status < 500) or ``call-<sig>-fail-5xx``;
    * if anything inside the timed section raises, increments ``call-<sig>-err``
      and re-raises the exception.

    Counters the manager does not return are skipped silently.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    @staticmethod
    def _bump(name: str):
        """Increment the unistat counter ``name`` if the manager returns one."""
        counter = unistat.manager.get_counter(name)
        if counter:
            counter.increment()

    def __call__(self, request):
        prefix = f'call-{path_to_signal(request.path)}'

        self._bump(f'{prefix}-cnt')
        try:
            # The timer lookup stays inside the try so its failures count as -err too.
            timer = unistat.manager.get_timer(f'{prefix}-duration') or contextlib.nullcontext()
            with timer:
                response = self.get_response(request)
                status = response.status_code
                if status >= 400:
                    self._bump(f'{prefix}-fail')
                    self._bump(f'{prefix}-fail-5xx' if status >= 500 else f'{prefix}-fail-4xx')
                return response
        except Exception:
            self._bump(f'{prefix}-err')
            raise
