# coding: utf-8



import os
import logging
from datetime import datetime

from django.conf import settings
from django.contrib.auth.middleware import AuthenticationMiddleware
from django.http import HttpResponseForbidden
from django_yauth import middleware
from django.shortcuts import get_object_or_404
from requests import HTTPError

from at.common.objects import AuthInfo
from at.common import utils
from at.aux_ import models
from at.common.tvm2_client import get_tvm2_client

log = logging.getLogger(__name__)
dig_log = logging.getLogger(__name__ + '.dig')
trace_log = logging.getLogger(__name__ + '.trace')


class AuthInfoMiddleware(object):
    """Populate ``request.ai`` (an ``AuthInfo``) and ``request.user``.

    Relies on ``request.yauser`` having been set by an earlier yauth
    middleware. When ``yauser.uid`` is truthy, an ``AuthInfo`` is built from
    the uid, login, ``Session_id`` cookie, OAuth token and raw user ticket,
    and the matching ``models.Person`` is loaded into ``request.user`` via
    ``get_object_or_404`` (so a missing person raises ``Http404``).
    Otherwise ``request.ai`` is set to None.
    """

    # Authorization scheme prefix whose remainder is the OAuth token.
    OAUTH_PREFIX = 'OAuth '

    @classmethod
    def _extract_oauth_token(cls, meta):
        """Return the token from an ``Authorization: OAuth <token>`` header.

        :param meta: a ``request.META``-like mapping.
        :returns: the token string, or None when the header is absent, empty,
            or uses a different auth scheme.
        """
        # Django exposes request headers in META as ``HTTP_<NAME>``; the bare
        # ``Authorization`` key is still consulted for backward compatibility.
        auth_header = meta.get('HTTP_AUTHORIZATION') or meta.get('Authorization')
        if not auth_header:
            return None

        prefix_len = len(cls.OAUTH_PREFIX)
        # Auth schemes are case-insensitive. Other schemes (e.g. Bearer) must
        # not be sliced into a bogus "OAuth token".
        if auth_header[:prefix_len].lower() != cls.OAUTH_PREFIX.lower():
            return None

        return auth_header[prefix_len:].strip() or None

    def process_request(self, request):
        # Already populated (e.g. by an earlier pass): nothing to do.
        if hasattr(request, 'ai'):
            return

        # The tracelist is attached by DebugMiddleware; tolerate its absence.
        tracelist = getattr(request, '_tracelist', None)
        if tracelist is not None:
            tracelist.append('AuthInfoMiddleware')

        if not request.yauser.uid:
            request.ai = None
            return

        uid = int(request.yauser.uid)
        request.ai = AuthInfo(
            uid=uid,
            login=getattr(request.yauser, 'login', None),
            session_id=request.COOKIES.get('Session_id'),
            oauth_token=self._extract_oauth_token(request.META),
            user_ticket=request.yauser.raw_user_ticket,
        )
        request.user = get_object_or_404(models.Person, person_id=uid)


class DebugMiddleware(object):
    """Dump incoming requests and log per-request timing traces.

    ``process_request`` attaches a fresh ``Tracelist`` to the request (other
    middlewares append their own marks to it) and logs the request to the
    ``.dig`` logger. ``process_response`` logs a one-line timing summary to
    the ``.trace`` logger: total time, each mark's offset from the previous
    one, and the slowest step.
    """

    def process_request(self, request):
        request._tracelist = Tracelist()
        request._tracelist.append('dbgmw.process_request')

        # Formatting the dump evaluates request.POST, which makes Django read
        # and parse the request body. Only pay for that (and for the string
        # building) when the message will actually be emitted.
        if not dig_log.isEnabledFor(logging.DEBUG):
            return

        # Full headers are dumped only on explicit opt-in via LOG_HEADERS
        # (they may include cookies and tokens).
        headers = request.META if 'LOG_HEADERS' in os.environ else '*'

        msg = '\n'.join([
            'PATH: %s [%s]' % (request.path, request.get_full_path()),
            'METHOD: %s' % request.method,
            'GET: %s' % request.GET,
            'POST: %s' % request.POST,
            'META: %s' % headers,
        ])
        dig_log.debug(msg)

    def process_response(self, request, response):
        tracelist = getattr(request, '_tracelist', None)
        if tracelist is None:
            # process_request did not run for this request.
            return response

        tracelist.append('dbgmw.process_response')
        if not trace_log.isEnabledFor(logging.DEBUG):
            return response

        method = get_method_from_path(request.path)
        parts = ['FULLTIME for %s is %s total.' % (
            method,
            utils.delta_in_ms(tracelist.delta_total)
        )]

        prev = None
        max_mark = None
        max_diff = None
        for mark, timestamp in tracelist:
            if prev is None:
                # First mark: absolute wall-clock time.
                parts.append(' %s [%s]' % (timestamp.time().isoformat(), mark))
            else:
                # Later marks: time elapsed since the previous mark.
                diff = timestamp - prev
                if max_diff is None or diff > max_diff:
                    max_mark, max_diff = mark, diff
                parts.append(' +%s [%s]' % (utils.delta_in_ms(diff), mark))
            prev = timestamp

        if max_diff is not None:
            parts.append('. max %s (%s)' % (
                max_mark,
                utils.delta_in_ms(max_diff)
            ))

        trace_log.debug(''.join(parts))
        return response


class Tracelist(list):
    """A list of ``(mark, timestamp)`` pairs.

    Callers ``append`` only a mark; the current local time
    (``datetime.now()``) is attached to it automatically.
    """

    def append(self, p_object):
        stamped_entry = (p_object, datetime.now())
        super().append(stamped_entry)

    @property
    def delta_total(self):
        """Time between the last and the first recorded marks.

        Raises IndexError when nothing has been recorded yet.
        """
        _, last_time = self[-1]
        _, first_time = self[0]
        return last_time - first_time


def get_method_from_path(path):
    """Return the last component of a URL path, ignoring outer slashes.

    ``'/api/v1/get_feed_info/'`` -> ``'get_feed_info'``; a path with no inner
    slash is returned stripped; ``'/'`` yields ``''``.
    """
    # rpartition yields ('', '', s) when there is no separator, so index 2 is
    # the whole string in that case and the tail otherwise.
    return path.strip('/').rpartition('/')[2]


# Last path components (as returned by get_method_from_path) for which
# AtYandexAuthRequiredMiddleware skips the forced Yandex authentication.
NO_FORCED_AUTH_FOR_METHODS = (
    'ping_as_xml', 'authenticate', 'get_feed_info',
    'get_feed_id', 'get_page_layout', 'get_feed_light',
)


class AtYandexAuthRequiredMiddleware(middleware.YandexAuthRequiredMiddleware):
    """Yandex auth-required middleware with a whitelist of open methods.

    Requests whose last path component is in ``NO_FORCED_AUTH_FOR_METHODS``
    bypass the parent's forced-auth check.
    """

    def process_request(self, request):
        """Delegate to the parent check unless the method is whitelisted."""
        method_name = get_method_from_path(request.path)
        if method_name not in NO_FORCED_AUTH_FOR_METHODS:
            return super().process_request(request)
        return None

    def process_response(self, request, response):
        """Warn when a 401 is returned to a request lacking ``Session_id``.

        NOTE(review): this override does not call the parent's
        ``process_response`` (if the parent defines one) — confirm intended.
        """
        if response.status_code != 401:
            return response
        if 'Session_id' not in request.COOKIES:
            log.warning('Session_id cookie missing for %s', request.path)
        return response


class TVMMiddleware(AuthenticationMiddleware):
    """Validate TVM2 service/user tickets and expose their identities.

    On success sets ``request.tvm_service_id`` (``src`` of the parsed service
    ticket) and, when a user ticket header is present, ``request.tvm_uid``
    (``default_uid`` of the parsed user ticket). On failure the request is
    short-circuited with an error response.

    NOTE(review): ``process_request`` does not call
    ``AuthenticationMiddleware.process_request``, so this middleware does not
    set ``request.user`` itself — confirm that is intended.
    """
    user_ticket_header = 'HTTP_X_YA_USER_TICKET'
    service_ticket_headers = ['HTTP_X_YA_SERVICE_TICKET', 'HTTP_X_TVM2_TICKET']

    @property
    def is_required(self):
        """Whether a missing service ticket is rejected.

        Reads ``settings.TVM_TICKET_REQUIRED`` and defaults to True.
        """
        return getattr(settings, 'TVM_TICKET_REQUIRED', True)

    def get_service_ticket(self, request):
        """Return the first non-empty service ticket among the known headers.

        Returns a falsy value (None or '') when no header carries a ticket.
        """
        service_ticket = None
        for header in self.service_ticket_headers:
            service_ticket = request.META.get(header)
            if service_ticket:
                break

        return service_ticket

    def check_tickets(self, request):
        """Validate the request's TVM tickets.

        :returns: None when the request may proceed; otherwise an error
            response — 401 when a required service ticket is missing, 403
            when a ticket is present but cannot be validated.
        """
        service_ticket = self.get_service_ticket(request)

        if not service_ticket:
            if not self.is_required:
                return None

            err_msg = 'TVM Service ticket is required, but not provided'
            log.error(err_msg)
            # 401 rather than 403: the caller did not authenticate at all.
            return HttpResponseForbidden(err_msg, status=401)

        # Only build the client once there is actually a ticket to parse.
        tvm2_client = get_tvm2_client()

        try:
            parsed_service_ticket = tvm2_client.parse_service_ticket(service_ticket)
        except HTTPError as exc:
            # Handled like the user-ticket path below: reject with 403
            # instead of letting the error escape as a 500.
            err_msg = str(exc) or 'Failed to check TVM service ticket'
            log.error(err_msg)
            return HttpResponseForbidden(err_msg)

        if not parsed_service_ticket:
            err_msg = 'Unknown TVM service ticket'
            log.warning(err_msg)
            return HttpResponseForbidden(err_msg)

        request.tvm_service_id = getattr(parsed_service_ticket, 'src', None)

        user_ticket = request.META.get(self.user_ticket_header)
        if not user_ticket:
            # The user ticket is optional.
            return None

        try:
            parsed_user_ticket = tvm2_client.parse_user_ticket(user_ticket)
        except HTTPError as exc:
            # str() of a bare HTTPError is '' — never let an empty message
            # turn a failed check into an accepted request.
            err_msg = str(exc) or 'Failed to check TVM user ticket'
        else:
            if parsed_user_ticket:
                request.tvm_uid = getattr(parsed_user_ticket, 'default_uid', None)
                return None
            err_msg = 'Unknown TVM user ticket'

        log.error(err_msg)
        return HttpResponseForbidden(err_msg)

    def process_request(self, request):
        """Run TVM checks unless the path is excluded or DEBUG allows bypass."""
        request.tvm_service_id = None
        request.tvm_uid = None

        # normpath drops the trailing slash; appending one makes '/ping' and
        # '/ping/' compare identically against the TVM_EXCLUDE_PATHS prefixes.
        normal_path = os.path.normpath(request.path) + '/'

        exclude_paths = tuple(getattr(settings, 'TVM_EXCLUDE_PATHS', []))
        if normal_path.startswith(exclude_paths):
            return None

        http_accept = request.META.get('HTTP_ACCEPT', '')
        if settings.DEBUG and 'text/html' in http_accept:
            log.debug('TVM checks is skipped. DEBUG=%s, HTTP_ACCEPT=%s', settings.DEBUG, http_accept)
            # With DEBUG enabled, a browser (Accept: text/html) may reach any
            # endpoint without TVM tickets.
            return None

        return self.check_tickets(request)


class TVMDebugMiddleware(TVMMiddleware):
    """TVMMiddleware variant that accepts injected identities for debugging.

    The service id and uid come from ``settings.DEBUG_TVM_SERVICE_ID`` /
    ``settings.DEBUG_TVM_UID``. They can be overridden per request by the
    ``HTTP_DEBUG_SERVICE_ID`` / ``HTTP_DEBUG_UID`` META keys (request
    headers). When a debug service id is available, real ticket validation
    is skipped entirely.

    NOTE(review): any client can send these headers, so this class should
    only be enabled in non-production configurations — confirm deployment.
    """
    debug_uid_header = 'HTTP_DEBUG_UID'
    debug_service_id_header = 'HTTP_DEBUG_SERVICE_ID'

    def get_debug_service_id(self, request):
        """Return the debug service id as an int, or None if unset/invalid."""
        service_id = getattr(settings, 'DEBUG_TVM_SERVICE_ID', None)
        service_id = request.META.get(self.debug_service_id_header, service_id)

        if not service_id:
            return None

        log.debug('TVM_SERVICE_ID=%s', service_id)
        try:
            return int(service_id)
        except (TypeError, ValueError):
            # A malformed debug value must not turn into a 500; fall back to
            # the regular ticket checks instead.
            log.warning('Ignoring invalid debug TVM service id: %r', service_id)
            return None

    def get_debug_uid(self, request):
        """Return the debug uid (header value wins over the setting).

        NOTE(review): unlike the service id this is not converted to int, so
        a header-supplied uid is a str — confirm callers accept that.
        """
        debug_uid = getattr(settings, 'DEBUG_TVM_UID', None)
        debug_uid = request.META.get(self.debug_uid_header, debug_uid)

        if debug_uid:
            log.debug('TVM_UID=%s', debug_uid)

        return debug_uid

    def check_tickets(self, request):
        """Use debug identities if a service id is given, else real checks."""
        request.tvm_service_id = self.get_debug_service_id(request)
        request.tvm_uid = self.get_debug_uid(request)

        if request.tvm_service_id is None:
            return super().check_tickets(request)

        # Debug service id supplied: accept the request without tickets.
        return None
