# coding: utf-8
"""
A wrapper around the Flask-Principal extension that automates saving, loading
and caching of session information via blackbox and the auth backend.
"""
from logging import getLogger

import inject
from cachetools import TTLCache
from flask import redirect, request, abort, current_app
from flask_principal import Principal, Identity, PermissionDenied, AnonymousIdentity
try:
    from mongoengine.connection import ConnectionError
except ImportError:
    from mongoengine.connection import MongoEngineConnectionError as ConnectionError
from pymongo.errors import ConnectionFailure
from infra.swatlib import metrics
from sepelib.yandex.oauth import IOAuth, OAuthException
from infra.swatlib.auth.dao import User, Group
from infra.swatlib.auth.passport import IPassportClient

from .roles import AccessNeed, AccessRight
from .staff import ICachingStaffClient
from . import staff


class AuthenticatorException(Exception):
    """
    Error raised by the authenticator, e.g. when authorization info
    cannot be loaded from the database.
    """


class NeedRedirectError(Exception):
    """
    Signals that request processing must stop and a redirect must be sent instead.

    :param response: response object to be returned to the client; it is kept
                     on the exception as the ``response`` attribute.
    """

    def __init__(self, response):
        # Keep the prepared response so that an error handler can return it.
        self.response = response


class AllCanIdentity(Identity):
    """
    Stand-in identity used when authentication is disabled.

    Any permission check made through :meth:`can` succeeds.
    """

    def can(self, _):
        # The permission argument is deliberately ignored: everything is allowed.
        return True


class LazyProvidesIdentity(Identity):
    """
    Identity which loads our API permissions lazily.

    The permissions are not needed for every action, so they are fetched only
    on the first access to :attr:`provides` and then reused for the lifetime
    of this object.
    """

    def __init__(self, id, authenticator, auth_type=None):
        # The base class __init__ is intentionally not called:
        # `provides` is a property in this class.
        self._provides = None
        self.auth_type = auth_type
        self.authenticator = authenticator
        self.id = id

    def _load_permissions(self):
        """
        Collect access needs granted to the user directly and via staff groups.

        :return: set of ``AccessNeed`` objects; empty for dummy authentication
        :raises AuthenticatorException: if the database connection fails
        """
        if self.auth_type == Authenticator.DUMMY_AUTH:
            # Dummy authentication carries no stored permissions.
            return set()

        needs = set()
        try:
            # Roles assigned to the user personally.
            account = User.objects(id=self.id).first()
            if account:
                needs.update(AccessNeed(r.module, r.access_right) for r in account.roles)

            # Roles inherited from the groups the user is a member of.
            staff_group_ids = self.authenticator.get_group_ids(self.id)
            for grp in Group.objects.filter(id__in=staff_group_ids):
                needs.update(AccessNeed(r.module, r.access_right) for r in grp.roles)
        except (ConnectionFailure, ConnectionError) as e:
            # Replace the low-level connection error with one that explains
            # why a seemingly unrelated request failed because of mongo.
            raise AuthenticatorException('Failed to load authorization info from database: {}'.format(e))
        return needs

    @property
    def provides(self):
        """Set of access needs of this identity, loaded on first access."""
        loaded = self._provides
        if loaded is None:
            loaded = self._provides = self._load_permissions()
        return loaded


# Shared identity returned by Authenticator.load_identity when authentication is disabled.
# Besides allowing everything via `can`, it explicitly carries the write access need
# for the 'api' module in its `provides` set.
all_can_identity = AllCanIdentity(id='anonymous')
all_can_identity.provides.add(AccessNeed('api', AccessRight.Write))


class _Stats(object):
    """
    Small helper that owns the authenticator's cache-hit counters.

    All counters are created under the 'web-authenticator' path of the root
    metrics registry; every ``mark_*`` method bumps its counter by one.
    """

    def __init__(self):
        new_counter = metrics.ROOT_REGISTRY.path('web-authenticator').get_counter
        self._session_id_blackbox_cache_hit = new_counter('session_id_blackbox_cache_hit')
        self._oauth_token_blackbox_cache_hit = new_counter('oauth_token_blackbox_cache_hit')
        self._staff_cache_hit = new_counter('staff_cache_hit')
        self._staff_database_cache_hit = new_counter('staff_database_cache_hit')

    def mark_session_id_blackbox_cache_hit(self):
        """Count a login resolved from the Session_id cache."""
        counter = self._session_id_blackbox_cache_hit
        counter.inc(1)

    def mark_oauth_token_blackbox_cache_hit(self):
        """Count a login resolved from the OAuth token cache."""
        counter = self._oauth_token_blackbox_cache_hit
        counter.inc(1)

    def mark_staff_cache_hit(self):
        """Count a staff groups lookup served from the in-memory cache."""
        counter = self._staff_cache_hit
        counter.inc(1)

    def mark_staff_database_cache_hit(self):
        """Count a staff groups lookup served from the database-backed cache."""
        counter = self._staff_database_cache_hit
        counter.inc(1)


class Authenticator(object):
    """
    The one that hooks into request processor and manages permissions
    by querying:
        * blackbox (for authentication)
        * staff (for getting users' groups)
        * roles stored in storage

    Logins resolved from the ``Authorization`` header and from the ``Session_id``
    cookie, as well as users' staff group ids, are kept in in-memory TTL caches.

    ``handle_permission_denied`` and ``handle_need_redirect_error`` are registered
    as flask error handlers but raise ``NotImplementedError`` here, so subclasses
    are expected to override them.
    """
    # NOTE(review): not referenced anywhere in this class — confirm it is used elsewhere.
    RETRY_TIMEOUT = 10
    # Identities whose auth_type equals this value get an empty set of permissions
    # (see LazyProvidesIdentity._load_permissions).
    DUMMY_AUTH = 'dummy_auth'

    CACHE_MAX_SIZE = 500  # Should be more than enough (we don't have many users).
    # This should also work - don't think that using invalid cookie for some time is a crime.
    # But this will allow us not to issue blackbox requests on subsequent ajax requests to nanny.
    # This will help in case blackbox and/or staff API blackouts. There have been incidents already.
    CACHE_TTL = 60 * 60 * 24  # https://st.yandex-team.ru/SWAT-1762

    # Class used by load_identity to build identities; called as IDENTITY_CLS(login, authenticator).
    IDENTITY_CLS = LazyProvidesIdentity

    # Dependencies resolved through the `inject` library.
    caching_staff = inject.attr(ICachingStaffClient)
    oauth = inject.attr(IOAuth)
    passport = inject.attr(IPassportClient)

    @classmethod
    def _init_ttl_cache(cls, ttl=None):
        """
        Create a TTL cache limited to ``CACHE_MAX_SIZE`` entries.

        :param ttl: entry time-to-live; ``CACHE_TTL`` is used when None
        :return: new ``TTLCache`` instance
        """
        ttl = cls.CACHE_TTL if ttl is None else ttl
        return TTLCache(maxsize=cls.CACHE_MAX_SIZE,
                        ttl=ttl)

    def __init__(self, app, enable=True, groups_cache_ttl=None):
        """
        Create caches and hook the authenticator into the flask application.

        :param app: flask application to set up Flask-Principal and error handlers on
        :param enable: when false, ``load_identity`` returns ``all_can_identity``
                       without performing any checks
        :param groups_cache_ttl: TTL of the in-memory user groups cache;
                                 ``CACHE_TTL`` is used when None
        """
        self._log = getLogger(__name__)
        self._enable = enable
        # Authorization header value -> login and Session_id cookie value -> login.
        self._session_id_passport_cache = self._init_ttl_cache()
        self._oauth_token_passport_cache = self._init_ttl_cache()
        # login -> staff group ids.
        self._user_groups_cache = self._init_ttl_cache(groups_cache_ttl)
        self._stats = _Stats()

        # setup flask application
        principal = Principal(app, skip_static=True, use_sessions=False)
        # use lambda to let tests mock load_identity later
        principal.identity_loader(lambda: self.load_identity())
        app.errorhandler(NeedRedirectError)(self.handle_need_redirect_error)
        app.errorhandler(PermissionDenied)(self.handle_permission_denied)

    def handle_permission_denied(self, e):
        """
        Flask error handler for ``PermissionDenied``; must be overridden.

        :param e: the raised ``PermissionDenied`` exception
        """
        raise NotImplementedError()

    def handle_need_redirect_error(self, e):
        """
        Flask error handler for ``NeedRedirectError``; must be overridden.

        :param e: the raised ``NeedRedirectError`` exception
        """
        raise NotImplementedError()

    def load_identity(self):
        """
        Build the identity for the current flask request.

        Order of checks:
            * authentication disabled -> ``all_can_identity``;
            * no matching endpoint, or view has falsy ``need_auth`` attribute
              -> ``AnonymousIdentity``;
            * ``Authorization`` header -> login from cache or from the OAuth client;
            * otherwise -> login from the ``Session_id`` cache or from the passport
              cookie check.

        :return: identity object for the request
        :raises NeedRedirectError: if the passport check result contains a redirect url

        The request is aborted with HTTP 401 if the OAuth client raises ``OAuthException``.
        """
        if not self._enable:
            return all_can_identity

        if request.endpoint:
            view = current_app.view_functions.get(request.endpoint)
            # Authentication disabled upon handler request
            # via magic attribute.
            if not getattr(view, 'need_auth', True):
                return AnonymousIdentity()
        else:
            # This means that rule didn't match and flask will return 404.
            # So don't bother checking anything.
            return AnonymousIdentity()

        login = None
        # First address of the access route, falling back to localhost when the route is empty.
        user_ip = request.access_route[0] if request.access_route else '127.0.0.1'

        authorization_header = request.headers.get('Authorization')
        if authorization_header:
            # The whole header value is used as the cache key.
            login = self._oauth_token_passport_cache.get(authorization_header)
            if login is not None:
                self._stats.mark_oauth_token_blackbox_cache_hit()
                return self.IDENTITY_CLS(login, self)
            try:
                login = self.oauth.get_user_login_by_authorization_header(
                    self.passport,
                    authorization_header,
                    user_ip)
            except OAuthException:
                abort(401)
            else:
                self._oauth_token_passport_cache[authorization_header] = login

        # Fall back to cookie authentication when the header did not yield a login.
        if not login:
            # Actually we're exploiting our knowledge about cookies, that's not good
            # but should do.
            session_id = request.cookies.get('Session_id')
            # Let us check if Session_id cookie value is cached.
            if session_id:
                login = self._session_id_passport_cache.get(session_id)
                if login is not None:
                    self._stats.mark_session_id_blackbox_cache_hit()
                    return self.IDENTITY_CLS(login, self)
            result = self.passport.check_passport_cookie(
                cookies=request.cookies,
                host=request.host,
                user_ip=user_ip,
                request_url=request.url
            )
            if result.redirect_url:
                # this method must return identity
                # so we break call stack by raising exception
                # and catching it using flask error handlers mechanism
                # and responding with redirect
                self._log.info('Could not authenticate req "%s", authorization header given: %s, session_id given: %s',
                               user_ip,
                               bool(authorization_header),
                               bool(session_id))
                response = redirect(result.redirect_url)
                raise NeedRedirectError(response)
            elif result.login:
                login = result.login
                # Put login into cache
                # NOTE(review): session_id may be None here if the cookie is absent; such an
                # entry is never read back because the lookup above is guarded by `if session_id`.
                self._session_id_passport_cache[session_id] = login
        # NOTE(review): login can still be None/empty here when the passport result has
        # neither redirect_url nor login — confirm that an identity with such id is intended.
        return self.IDENTITY_CLS(login, self)

    def user_belongs_to_one_of_groups(self, login, appr_group_ids):
        """
        Check whether the user is a member of at least one of the given groups.

        Sources are consulted in order: in-memory cache, ``get_cached_group_ids``
        of the caching staff client, then ``get_group_ids_from_staff``. Only a
        positive match returns early; a source without a match falls through
        to the next one.
        NOTE(review): this means a negative answer always reaches staff —
        presumably to avoid denying access based on stale data; confirm.

        :param login: user login
        :param appr_group_ids: group ids to match against; must support ``&`` with a set
        :return: True if the user belongs to one of the groups; False if not,
                 or if ``staff.StaffError`` is raised while querying staff
        """
        group_ids = self._user_groups_cache.get(login)
        if group_ids is not None:
            self._stats.mark_staff_cache_hit()
            if appr_group_ids & set(group_ids):
                return True

        group_ids = self.caching_staff.get_cached_group_ids(login)
        if group_ids:
            self._user_groups_cache[login] = group_ids
            if appr_group_ids & set(group_ids):
                return True
        try:
            group_ids = self.caching_staff.get_group_ids_from_staff(login)
        except staff.StaffError:
            self._log.exception('Failed to resolve %s\'s staff groups', login)
            return False
        self._user_groups_cache[login] = group_ids
        return bool(appr_group_ids & set(group_ids))

    def get_group_ids(self, login):
        """
        Return staff group ids of the user.

        The in-memory cache is checked first, then the caching staff client is asked.
        If the client raises ``staff.CachingStaffError``, the ``cached`` value attached
        to the error is used; if that is None, an empty list is returned and
        nothing is stored in the in-memory cache.

        :param login: user login
        :return: group ids of the user (possibly an empty list)
        """
        # check in-memory cache first
        group_ids = self._user_groups_cache.get(login)
        if group_ids is not None:
            self._stats.mark_staff_cache_hit()
            return group_ids
        # request staff to get groups
        try:
            group_ids = self.caching_staff.get_group_ids(login)
        except staff.CachingStaffError as e:
            self._log.error('Failed to resolve %s\'s staff groups: %s', login, str(e.error))
            if e.cached is None:
                self._log.error('Using an empty list of group ids for %s, '
                                'authorization may not work', login)
                return []
            else:
                self._log.warning('Using the cached list of group ids for %s', login)
                self._stats.mark_staff_database_cache_hit()
                group_ids = e.cached
        # fill in-memory cache
        self._user_groups_cache[login] = group_ids
        return group_ids
