# coding: utf-8

import logging
from time import perf_counter

from django.conf import settings
from django.contrib.auth import authenticate, get_user, login
from django.contrib.auth.hashers import make_password
from django.contrib.auth.models import AnonymousUser
from django.db.transaction import atomic
from django.db.utils import DatabaseError
from django.http.request import split_domain_port
from django.utils import translation
from django.utils.deprecation import MiddlewareMixin
from django.utils.translation import gettext as _
from django_replicated.utils import routers
from rest_framework import status
from rest_framework.exceptions import ValidationError
from rest_framework.permissions import SAFE_METHODS
from ylog.context import log_context

from procu.api.managers import UserinfoSerializer
from procu.api.models import User
from procu.api.utils import (
    ErrorResponse,
    get_maintenance,
    get_real_ip,
    is_internal,
    is_readonly,
)

logger = logging.getLogger(__name__)


class ExceptionHandler(MiddlewareMixin):
    """Translate database failures raised while handling a request into a 503.

    Any other exception is left alone (``process_exception`` returns ``None``),
    so Django's normal exception handling takes over.
    """

    @staticmethod
    def process_exception(request, exc):
        # Only database errors are handled here.
        if not isinstance(exc, DatabaseError):
            return None

        logger.exception('Database error')

        # Drop the current (possibly broken) database connection.
        from django.db import connection

        connection.close()

        return ErrorResponse(
            msg=_('ERRORS::MSG_DATABASE_FAILURE'),
            status=status.HTTP_503_SERVICE_UNAVAILABLE,
        )


class MixedAuthenticatonMiddleware(object):
    """Authenticate the request and attach ``request.user``.

    Requests for which ``is_internal(request)`` is true are authenticated via
    ``authenticate(request=request)``. The result is stored on
    ``request.yauser`` and the call duration (microseconds) on
    ``request.bb_time``. The result is then mapped to a staff ``User``.

    All other requests use the Django session user. They fall back to a token
    passed in the query string parameter ``settings.URL_PARAM_AUTH``.

    The resulting user is exposed to log records through ``log_context`` for
    the duration of the request.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def set_user(self, request):
        """Populate ``request.user`` (plus ``yauser``/``bb_time`` when internal)."""
        if is_internal(request):
            # Internal, blackbox authentication; time it in microseconds.
            started = perf_counter()
            request.yauser = authenticate(request=request)
            request.bb_time = int((perf_counter() - started) * 1000000.0)

            request.user = self._get_internal_user(request)
        else:
            # External: session user or capability-URL token.
            request.user = self._get_external_user(request)

    def __call__(self, request):
        self.set_user(request)

        with log_context(**self._build_log_context(request)):
            return self.get_response(request)

    @staticmethod
    def _build_log_context(request):
        """Return the keyword arguments passed to ``log_context`` for *request*.

        Always contains ``bb_time`` (``None`` for non-internal requests). For
        authenticated users it also contains a ``user`` dict with the username,
        staff flag, optional supplier id, and, when blackbox auth was used,
        the auth mechanism and OAuth client name.
        """
        context = {'bb_time': getattr(request, 'bb_time', None)}

        user = request.user
        if not user.is_authenticated:
            return context

        user_info = {
            'username': user.username,
            'is_staff': user.is_staff,
        }

        if user.supplier_id:
            user_info['supplier_id'] = user.supplier_id

        yauser = getattr(request, 'yauser', None)
        if yauser is not None:
            mechanism = yauser.authenticated_by.mechanism_name
            user_info['mechanism'] = mechanism

            if mechanism == 'oauth':
                client = yauser.blackbox_result.oauth.client_name
                user_info['application'] = client

        context['user'] = user_info
        return context

    @staticmethod
    def _get_external_user(request):
        """Return the session user, or log in via a capability-URL token.

        Returns ``AnonymousUser`` when there is no authenticated session user
        and either no token is present in the query string or the token does
        not authenticate. On a successful token login, the user's email is
        stored in the session under ``settings.SESSION_KEY_EMAIL``.
        """
        user = get_user(request)

        if user.is_authenticated:
            return user

        # External, capability URL. Use ``.get`` so that only a missing query
        # parameter means "no token"; a KeyError raised inside authentication
        # or session handling must not be silently turned into an anonymous
        # user.
        token = request.GET.get(settings.URL_PARAM_AUTH)
        if token is None:
            return AnonymousUser()

        user = authenticate(request, token=token)
        if user is None:
            return AnonymousUser()

        login(request, user)
        request.session[settings.SESSION_KEY_EMAIL] = user.email
        return user

    @staticmethod
    def _get_internal_user(request):
        """Map ``request.yauser`` to a staff ``User``, creating one if missing.

        Returns ``AnonymousUser`` when ``request.yauser`` is falsy or the
        matching user is marked ``is_deleted``.

        Profile fields named by the values of ``settings.YAUTH_PASSPORT_FIELDS``
        (plus ``email`` from ``yauser.default_email``) are validated with
        ``UserinfoSerializer``. Only the fields whose values differ are saved.
        If validation fails, the error is logged and the user is returned
        without updating those fields.
        """
        yauser = request.yauser
        if not yauser:
            return AnonymousUser()

        # Internal authentication
        lookup = {'username': yauser.username, 'is_staff': True}

        try:
            user = User.objects.get(**lookup)
        except User.DoesNotExist:
            with atomic():
                user = User.objects.create(
                    password=make_password(None), **lookup
                )
        else:
            if user.is_deleted:
                return AnonymousUser()

        fields = dict(settings.YAUTH_PASSPORT_FIELDS)

        user_fields = {
            field: getattr(yauser, field) for field in fields.values()
        }
        user_fields['email'] = yauser.default_email

        serializer = UserinfoSerializer(data=user_fields)
        try:
            serializer.is_valid(raise_exception=True)
        except ValidationError:
            logger.exception('Malformed data from blackbox')
            return user

        # Set or update userinfo from blackbox, writing only changed fields.
        update_fields = []
        for field, value in serializer.validated_data.items():
            if getattr(user, field) != value:
                update_fields.append(field)
                setattr(user, field, value)

        if update_fields:
            with atomic():
                user.save(update_fields=update_fields)

        return user


class LocaleMiddleware(object):
    """Activate the language requested via the ``lang`` query parameter.

    Falls back to ``settings.LANGUAGE_CODE`` when ``lang`` is absent or is not
    one of the configured languages. The chosen language is reported in the
    ``Content-Language`` response header.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    @staticmethod
    def _supported_languages():
        """Return the set of language codes declared in ``settings.LANGUAGES``.

        Django documents ``LANGUAGES`` as a list of ``(code, name)`` pairs, and
        a plain ``in`` test against such a list never matches a code string.
        Plain code strings are accepted too, so either configuration works.
        """
        return {
            entry[0] if isinstance(entry, (tuple, list)) else entry
            for entry in settings.LANGUAGES
        }

    def __call__(self, request):
        query_lang = request.GET.get('lang')

        if query_lang in self._supported_languages():
            language = query_lang
        else:
            # Default language
            language = settings.LANGUAGE_CODE

        translation.activate(language)
        request.LANGUAGE_CODE = translation.get_language()

        response = self.get_response(request)
        response['Content-Language'] = language

        return response


class DowntimeMiddleware(object):
    """Reject or restrict ``/api`` requests during maintenance or read-only mode.

    * While ``get_maintenance()`` returns a ``(started_at, downtime)`` window,
      every ``/api`` request gets a 503 with a ``Retry-After`` header set to
      ``started_at + downtime``.
    * While ``is_readonly()`` is true, non-safe HTTP methods get a 503. Safe
      ones run with the ``'slave'`` router state and receive a
      ``PROCU-ReadOnly: true`` response header.

    Paths outside ``/api`` are passed through untouched.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):

        if not request.path.startswith('/api'):
            return self.get_response(request)

        maintenance = get_maintenance()
        if maintenance:
            started_at, downtime = maintenance

            response = ErrorResponse(
                msg=_('ERRORS::SERVICE_UNDER_MAINTENANCE'),
                status=status.HTTP_503_SERVICE_UNAVAILABLE,
            )

            # NOTE(review): RFC 7231 expects an HTTP-date or delay-seconds in
            # Retry-After; ISO 8601 is kept here for existing clients —
            # confirm before changing the format.
            response['Retry-After'] = (started_at + downtime).isoformat()
            return response

        if not is_readonly():
            return self.get_response(request)

        # Reject writes before touching the router state, so the early return
        # leaves nothing to undo.
        if request.method.upper() not in SAFE_METHODS:
            return ErrorResponse(
                msg=_('ERRORS::READONLY'),
                status=status.HTTP_503_SERVICE_UNAVAILABLE,
            )

        routers.use_state('slave')
        try:
            response = self.get_response(request)
        finally:
            # Restore the router state even if the view raised; otherwise the
            # 'slave' state pushed above would stay in effect after this
            # request. The IndexError guard covers a revert with nothing left
            # to pop.
            try:
                routers.revert()
            except IndexError:
                pass

        response['PROCU-ReadOnly'] = 'true'
        return response


class PathPermissionMiddleware(object):
    """Hide ``/api/`` endpoints the current user is not allowed to reach.

    Anonymous users and authenticated non-staff users each get a whitelist of
    path prefixes; staff users may reach any path. A disallowed ``/api/``
    path is answered with a 404 ``ErrorResponse``. Paths outside ``/api/``
    are never filtered.
    """

    # Prefixes reachable without authentication.
    _ANONYMOUS_PREFIXES = (
        '/api/attachments',
        '/api/users',
        '/api/monitoring',
        '/api/misc',
        '/api/tracker/scs',
        '/api/reply',
    )

    # Prefixes reachable by authenticated non-staff users.
    _EXTERNAL_PREFIXES = (
        '/api/requests',
        '/api/suggests',
        '/api/attachments',
        '/api/users',
        '/api/account',
        '/api/monitoring',
    )

    # Staff: every path starts with '/'.
    _STAFF_PREFIXES = ('/',)

    def __init__(self, get_response):
        self.get_response = get_response

    @classmethod
    def _prefixes_for(cls, user):
        """Return the tuple of path prefixes *user* may request under ``/api/``."""
        if not user.is_authenticated:
            return cls._ANONYMOUS_PREFIXES
        if not user.is_staff:
            return cls._EXTERNAL_PREFIXES
        return cls._STAFF_PREFIXES

    def __call__(self, request):
        path, user = request.path, request.user

        # str.startswith accepts a tuple and matches if any prefix matches.
        if path.startswith('/api/') and not path.startswith(
            self._prefixes_for(user)
        ):
            return ErrorResponse(
                msg=_('ERRORS::UNKNOWN_METHOD'),
                status=status.HTTP_404_NOT_FOUND,
            )

        return self.get_response(request)


class QueryLogger:
    """``connection.execute_wrapper`` hook that records slow SQL queries.

    Every query whose execution takes longer than ``th`` seconds (measured
    with ``perf_counter``) is appended to ``self.queries`` as a dict with
    ``sql``, ``time`` (microseconds) and ``dbname``. Slow queries that raise
    are recorded as well, and their exception still propagates.
    """

    def __init__(self, th=0.1):
        self.th = th  # threshold in seconds
        self.queries = []

    def __call__(self, execute, sql, params, many, context):
        start = perf_counter()

        try:
            return execute(sql, params, many, context)
        finally:
            duration = perf_counter() - start

            if duration > self.th:
                # The DSN lookup is only needed for the record, so it runs for
                # slow queries only. A distinct name avoids shadowing the SQL
                # ``params``. Assumes the raw DB-API connection exposes
                # get_dsn_parameters() (psycopg2) — confirm if the backend
                # changes.
                dsn = context['connection'].connection.get_dsn_parameters()

                self.queries.append(
                    {
                        'sql': sql,
                        'time': int(duration * 1000000.),
                        'dbname': dsn['dbname'],
                    }
                )


class LogRequestContextMiddleware(object):
    """Attach basic request information to log records for the request.

    The ``request`` log context holds the client IP, HTTP method, domain,
    port, path and user agent (``'__empty__'`` when the header is missing).
    Requests under ``/api/monitoring/`` are passed through without it.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        path = request.path

        if path.startswith('/api/monitoring/'):
            return self.get_response(request)

        domain, port = split_domain_port(request.get_host())

        request_info = dict(
            ip=get_real_ip(request),
            method=request.method,
            domain=domain,
            port=port,
            path=path,
            user_agent=request.META.get('HTTP_USER_AGENT', '__empty__'),
        )

        # Slow-query logging (QueryLogger via connection.execute_wrapper) is
        # currently not enabled here.
        with log_context(request=request_info):
            return self.get_response(request)
