# -*- coding: utf-8 -*-

import json
import logging
import time

from flask import (
    g,
    jsonify,
    request,
    Response,
)
from flask.views import View
from passport.backend.core.builders.blackbox.constants import BLACKBOX_OAUTH_VALID_STATUS
from passport.backend.core.lazy_loader import LazyLoader
from passport.backend.utils.text import camel_to_snake
from passport.backend.vault.api import tvm
from passport.backend.vault.api.builders.blackbox import get_blackbox
from passport.backend.vault.api.db import get_db
from passport.backend.vault.api.errors import (
    AccessError,
    BaseError,
    DatabaseOperationalError,
    InvalidOauthTokenError,
    InvalidScopesError,
    LastOwnerError,
    LoginHeaderInRsaSignatureRequiredError,
    mask_oauth_token,
    OutdatedRsaSignatureError,
    RsaSignatureError,
    ServiceTicketParsingError,
    ServiceTicketRequiredError,
    TimestampHeaderInRsaSignatureRequiredError,
    TvmGrantRequiredError,
    UserAuthRequiredError,
    UserTicketParsingError,
    ValidationError,
    ZeroDefaultUidError,
)
from passport.backend.vault.api.models import (
    Roles,
    SecretVersion,
    TvmGrants,
    UserInfo,
    UserRole,
)
from passport.backend.vault.api.models.base import State
from passport.backend.vault.api.utils.secrets import (
    verify_signature_v1,
    verify_signature_v2,
    verify_signature_v3,
)
from sqlalchemy.exc import OperationalError as SAOperationalError
from ticket_parser2 import (
    ServiceTicket,
    UserTicket,
)
from ticket_parser2.exceptions import TicketParsingException
from werkzeug.datastructures import (
    CombinedMultiDict,
    ImmutableMultiDict,
)


# Scope an OAuth token must be granted for the Vault API to accept it.
VAULT_API_OAUTH_TOKEN_SCOPE = 'vault:use'


class cached_property(object):
    """
    Non-data descriptor that computes an attribute once per instance.

    On first access the wrapped function is called with the instance, and the
    result is stored in the instance ``__dict__`` under the function's name.
    Because this descriptor defines only ``__get__``, later lookups find the
    stored value directly and never call the function again. If the function
    raises, nothing is stored and the next access retries.
    """

    def __init__(self, func):
        self.func = func

    def __get__(self, instance, type=None):
        # Looked up on the class itself: hand back the descriptor object.
        if instance is None:
            return self
        computed = self.func(instance)
        instance.__dict__[self.func.__name__] = computed
        return computed


class BaseView(View):
    """
    Base class for API handlers
    -----
    returns = []  # Keys in the resulting JSON
    raises = []  # Errors
    example = {}  # Example
    """
    autodoc = True
    # Generated view functions keyed by snake_case class name. This is a single
    # class-level dict shared by every subclass, so each view is built once.
    _cached_views = {}
    # Form class (or a list/tuple of form classes to combine) used to validate
    # the request; None disables form processing.
    form = None

    # When True, check_user_access() forces authentication via validated_uid.
    required_user_auth = True

    # statbox "mode" for log records; statbox logging is off while falsy.
    statbox_mode = None
    # Copied to flask.g.use_slave before process_request() runs; presumably
    # routes DB reads to a replica — confirm in get_db().
    use_slave = False

    base_auth_errors = [
        AccessError,
        InvalidOauthTokenError,
        InvalidScopesError,
        LoginHeaderInRsaSignatureRequiredError,
        OutdatedRsaSignatureError,
        RsaSignatureError,
        ServiceTicketRequiredError,
        ServiceTicketParsingError,
        TimestampHeaderInRsaSignatureRequiredError,
        UserAuthRequiredError,
        UserTicketParsingError,
        ZeroDefaultUidError,
    ]
    base_common_errors = [
        DatabaseOperationalError,
    ]
    base_form_errors = [
        ValidationError,
    ]

    @classmethod
    def as_view(cls, *class_args, **class_kwargs):
        """
        Return the Flask view function for this class, creating it only once.

        The view is named after the class in snake_case and memoized in
        ``_cached_views``. Arguments passed on later calls are ignored because
        the cached view is returned as-is.
        """
        name = camel_to_snake(cls.__name__)
        if name not in cls._cached_views:
            cls._cached_views[name] = super(BaseView, cls).as_view(name, *class_args, **class_kwargs)
            cls.store_additional_data(cls._cached_views[name], cls)
        return cls._cached_views[name]

    @classmethod
    def store_additional_data(cls, view, view_class):
        """Copy error lists, form and auth requirement onto the view function."""
        view.base_auth_errors = view_class.base_auth_errors
        view.base_common_errors = view_class.base_common_errors
        view.base_form_errors = view_class.base_form_errors
        if hasattr(view_class, 'form'):
            view.form = view_class.form
        if hasattr(view_class, 'required_user_auth'):
            view.required_user_auth = view_class.required_user_auth

    def __init__(self):
        super(BaseView, self).__init__()
        # JSON payload assembled by process_request(); see dispatch_request().
        self.response_values = dict()
        self.response_status_code = 200
        self.processed_form = None
        # Authentication details merged into every statbox record and error.
        self.statbox_auth_info = dict()
        # Extra request details (e.g. consumer) merged into statbox records.
        self.statbox_extra_info = dict()

    def statbox_log(self, **kwargs):
        """
        Emit one record to the 'statbox' logger when ``statbox_mode`` is set.

        Later sources win on key clashes: mode < auth info < extra info < kwargs.
        """
        if self.statbox_mode:
            log_object = dict(mode=self.statbox_mode)
            log_object.update(self.statbox_auth_info)
            log_object.update(self.statbox_extra_info)
            log_object.update(kwargs)
            logging.getLogger('statbox').info(log_object)

    def _serialize_request_v1(self, url, request_data, timestamp):
        """
        Build the string covered by a v1 RSA signature.

        The string is the URL without trailing '?', then the key-sorted JSON
        dump of the parsed request data, then the timestamp, joined by newlines.
        """
        return '%s\n%s\n%s' % (url.rstrip('?'), json.dumps(request_data, sort_keys=True), timestamp)

    def _serialize_request_v2(self, method, path, data, timestamp, login):
        """Build the string covered by a v2 RSA signature (newline-terminated fields)."""
        return '%s\n%s\n%s\n%s\n%s\n' % (method.upper(), path, data, timestamp, login)

    def _serialize_request_v3(self, method, path, data, timestamp, login):
        """
        Build the string covered by a v3 RSA signature.

        The layout is identical to v2. Presumably the versions differ inside
        verify_signature_v3 / verify_signature_v2 — confirm there.
        """
        return '%s\n%s\n%s\n%s\n%s\n' % (method.upper(), path, data, timestamp, login)

    def _acquire_credentials(self, login):
        """
        Return ``(uid, keys)`` for ``login``, loading the user's keys.

        Uses ``raise_access_error=True``, which presumably makes an unknown
        login raise AccessError — confirm in UserInfo.get_by_login.
        """
        user_info = UserInfo.get_by_login(login=login, with_keys=True, raise_access_error=True)
        return user_info.uid, user_info.keys

    def add(self, *args):
        """Add every given object to the current DB session."""
        db = get_db()
        for arg in args:
            db.session.add(arg)

    def commit(self, *args):
        """Add the given objects to the session and commit it."""
        self.add(*args)
        get_db().session.commit()

    def flush(self, *args):
        """Add the given objects to the session and flush it (no commit)."""
        self.add(*args)
        get_db().session.flush()

    @property
    def config(self):
        """Application config object obtained from the LazyLoader registry."""
        return LazyLoader.get_instance('config')

    @property
    def authorization(self):
        """Authorization header as str, or None when missing or empty."""
        header = request.headers.get('Authorization')
        return str(header) if header else None

    @property
    def rsa_signature(self):
        """X-Ya-Rsa-Signature header as str, or None when missing or empty."""
        header = request.headers.get('X-Ya-Rsa-Signature')
        return str(header) if header else None

    @property
    def rsa_login(self):
        """X-Ya-Rsa-Login header as str, or None when missing or empty."""
        header = request.headers.get('X-Ya-Rsa-Login')
        return str(header) if header else None

    @property
    def rsa_timestamp(self):
        """
        X-Ya-Rsa-Timestamp header as int.

        Returns None when the header is missing, empty or not an integer.
        """
        header = request.headers.get('X-Ya-Rsa-Timestamp')
        if not header:
            return None
        try:
            return int(header)
        except ValueError:
            # Treat a malformed header like a missing one, so the client gets
            # TimestampHeaderInRsaSignatureRequiredError instead of an
            # unhandled ValueError.
            return None

    @property
    def user_ticket(self):
        """X-Ya-User-Ticket header as str, or None when missing or empty."""
        header = request.headers.get('X-Ya-User-Ticket')
        return str(header) if header else None

    @property
    def service_ticket(self):
        """X-Ya-Service-Ticket header as str, or None when missing or empty."""
        header = request.headers.get('X-Ya-Service-Ticket')
        return str(header) if header else None

    @property
    def fail_if_service_ticket_missing(self):
        """Config flag: reject user-ticket requests that carry no service ticket."""
        return self.config['tvm_grants']['fail_if_service_ticket_missing']

    @property
    def skip_grants(self):
        """Config flag: tolerate (only log) a missing TVM grant."""
        return self.config['tvm_grants']['skip_grants']

    def check_grants(self, source_tvm_client_id):
        """Check TVM grants of the calling application; TvmGrantRequiredError on failure."""
        TvmGrants.check(source_tvm_client_id)

    @cached_property
    def validated_uid(self):
        """
        Authenticate the request and return the user's uid.

        Methods are tried in this order: TVM user ticket, RSA signature, OAuth
        token. The result is cached on the instance. If authentication raises,
        nothing is cached and a later access retries.

        Raises:
            UserAuthRequiredError: no credentials of any kind were supplied.
            (plus whatever the selected ``_authorize_by_*`` method raises)
        """
        if self.user_ticket is not None:
            return self._authorize_by_user_ticket()
        elif self.rsa_signature:
            return self._authorize_by_rsa_signature()
        elif self.authorization:
            return self._authorize_by_oauth_token()

        raise UserAuthRequiredError()

    def _validate_service_ticket(self, service_ticket):
        """
        Validate an optional TVM service ticket and check the caller's grants.

        A present ticket is parsed with the service context. Its source client
        id is logged and passed to check_grants(). A missing grant is tolerated
        (logged as 'required') when ``skip_grants`` is set. A missing ticket is
        an error only when ``fail_if_service_ticket_missing`` is set.

        Raises:
            ServiceTicketParsingError: the ticket failed validation.
            TvmGrantRequiredError: the source lacks grants and grants are enforced.
            ServiceTicketRequiredError: no ticket while one is mandatory.
        """
        if service_ticket is not None:
            try:
                parsed_service_ticket = tvm.get_service_context().check(
                    service_ticket.encode('utf-8')
                )
                source_tvm_client_id = parsed_service_ticket.src
                self.statbox_auth_info.update({
                    'auth_service_ticket': ServiceTicket.remove_signature(service_ticket),
                    'auth_tvm_app_id': source_tvm_client_id,
                })
            except TicketParsingException as ex:
                raise ServiceTicketParsingError(
                    ticket_status=ex.status.value,
                    ticket_message=ex.message,
                )

            try:
                self.check_grants(source_tvm_client_id)
                self.statbox_auth_info.update({
                    'tvm_grants': 'granted',
                })
            except TvmGrantRequiredError:
                self.statbox_auth_info.update({
                    'tvm_grants': 'required',
                })
                if not self.skip_grants:
                    raise
        elif self.fail_if_service_ticket_missing:
            raise ServiceTicketRequiredError()

    def _authorize_by_user_ticket(self):
        """
        Authenticate by a TVM user ticket (the service ticket is validated first).

        The ticket must carry the 'bb:sessionid' scope or the Vault API scope.

        Returns:
            The ticket's default uid.

        Raises:
            InvalidScopesError: neither accepted scope is present.
            ZeroDefaultUidError: the ticket's default uid is 0.
            UserTicketParsingError: the ticket failed validation.
            (plus service-ticket errors from _validate_service_ticket)
        """
        self._validate_service_ticket(self.service_ticket)

        try:
            parsed_user_ticket = tvm.get_user_context(self.config).check(self.user_ticket)
            if not (
                parsed_user_ticket.has_scope('bb:sessionid') or
                parsed_user_ticket.has_scope(VAULT_API_OAUTH_TOKEN_SCOPE)
            ):
                raise InvalidScopesError(private_info={'user_ticket_scopes': parsed_user_ticket.scopes})
            uid = parsed_user_ticket.default_uid
            user_ticket_for_logging = UserTicket.remove_signature(self.user_ticket)
            self.statbox_auth_info.update({
                'auth_uid': uid,
                'auth_type': 'user_ticket',
                'auth_user_ticket': user_ticket_for_logging,
            })
            if uid == 0:
                raise ZeroDefaultUidError(
                    private_info={
                        'user_ticket': user_ticket_for_logging,
                    },
                )
            return uid
        except TicketParsingException as ex:
            raise UserTicketParsingError(ticket_status=ex.status.value, ticket_message=ex.message)

    def _authorize_by_rsa_signature(self):
        """
        Authenticate by an RSA signature over the request.

        Requires the X-Ya-Rsa-Login and X-Ya-Rsa-Timestamp headers. The
        timestamp must be within ``application.stale_time`` seconds of now.
        Signature versions are tried newest first (v3, v2, v1), and the one
        that matches is recorded in statbox.

        Returns:
            The uid of the signing user.

        Raises:
            LoginHeaderInRsaSignatureRequiredError: login header missing.
            TimestampHeaderInRsaSignatureRequiredError: timestamp missing/invalid.
            OutdatedRsaSignatureError: timestamp too far from the server clock.
            RsaSignatureError: no signature version verified.
            (plus whatever _acquire_credentials raises for the login)
        """
        if not self.rsa_login:
            raise LoginHeaderInRsaSignatureRequiredError()

        uid, keys = self._acquire_credentials(self.rsa_login)

        if not self.rsa_timestamp:
            raise TimestampHeaderInRsaSignatureRequiredError(uid=uid)

        rsa_timestamp_delta = abs(int(time.time()) - self.rsa_timestamp)
        stale_time = self.config['application']['stale_time']
        if rsa_timestamp_delta > stale_time:
            raise OutdatedRsaSignatureError(
                uid=uid,
                private_info=dict(
                    client_rsa_timestamp=self.rsa_timestamp,
                    # NOTE(review): key looks like a typo of "stale_time"; kept
                    # unchanged in case log consumers depend on it.
                    scale_time=stale_time,
                    rsa_timestamp_delta=rsa_timestamp_delta,
                ),
            )

        self.statbox_auth_info.update({
            'auth_type': 'rsa',
            'auth_uid': uid,
            'auth_login': self.rsa_login,
        })
        # The client signs the public URL: balancer prefix + path + query.
        signed_path = self.config['application']['balancer'] + request.full_path

        serialized_request_v3 = self._serialize_request_v3(
            method=request.method,
            path=signed_path,
            data=request.data,
            timestamp=self.rsa_timestamp,
            login=self.rsa_login,
        )
        if verify_signature_v3(serialized_request_v3, self.rsa_signature, keys):
            self.statbox_auth_info['auth_rsa_signature_version'] = 3
            return uid

        serialized_request_v2 = self._serialize_request_v2(
            method=request.method,
            path=signed_path,
            data=request.data,
            timestamp=self.rsa_timestamp,
            login=self.rsa_login,
        )
        if verify_signature_v2(serialized_request_v2, self.rsa_signature, keys):
            self.statbox_auth_info['auth_rsa_signature_version'] = 2
            return uid

        # v1 signs the parsed body (JSON or form fields), not the raw bytes.
        serialized_request_v1 = self._serialize_request_v1(
            signed_path,
            request.json if self.is_json_request else request.form,
            self.rsa_timestamp,
        )
        if verify_signature_v1(serialized_request_v1, self.rsa_signature, keys):
            self.statbox_auth_info['auth_rsa_signature_version'] = 1
            return uid

        raise RsaSignatureError(
            rsa_login=self.rsa_login,
            rsa_timestamp=self.rsa_timestamp,
            uid=uid,
        )

    def _authorize_by_oauth_token(self):
        """
        Authenticate by an OAuth token checked with blackbox.

        A bare token (no space in the header) is prefixed with 'Oauth '.

        Returns:
            The uid reported by blackbox.

        Raises:
            InvalidOauthTokenError: blackbox rejected the token or gave no uid.
            InvalidScopesError: the token lacks VAULT_API_OAUTH_TOKEN_SCOPE.
        """
        auth_header = self.authorization

        if ' ' not in auth_header:
            auth_header = 'Oauth %s' % auth_header
        bb_response = get_blackbox().oauth(
            headers={'Authorization': auth_header},
            ip=self.user_ip,
            attributes=[],
        )
        if (
            bb_response['status'] != BLACKBOX_OAUTH_VALID_STATUS or
            not bb_response.get('uid')
        ):
            raise InvalidOauthTokenError(
                oauth_token=auth_header,
                blackbox_error=bb_response.get('error'),
            )

        if VAULT_API_OAUTH_TOKEN_SCOPE not in bb_response.get('oauth', {}).get('scope', []):
            raise InvalidScopesError(
                oauth_token=auth_header,
                uid=bb_response['uid'],
            )

        self.statbox_auth_info.update({
            'auth_type': 'oauth',
            'auth_oauth_token': mask_oauth_token(auth_header),
            'auth_uid': bb_response['uid'],
        })

        return bb_response['uid']

    @property
    def uid(self):
        """Authenticated uid, or None when authentication fails with a BaseError."""
        try:
            return self.validated_uid
        except BaseError:
            return None

    @property
    def user_ip(self):
        """Client IP as exposed by the project's request object."""
        return request.user_ip

    @property
    def user_agent(self):
        """Client User-Agent as exposed by the project's request object."""
        return request.real_user_agent

    def check_user_access(self):
        """
        Enforce authentication when ``required_user_auth`` is set.

        Returns True when auth is not required or validated_uid is truthy.
        Authentication failures normally surface as exceptions from
        validated_uid.
        """
        return bool(not self.required_user_auth or self.validated_uid)

    def process_request(self, *args, **kwargs):
        """
        Handler body; subclasses must override.

        Either return a raw response body, or fill ``self.response_values``.
        """
        raise NotImplementedError()  # pragma: no cover

    def dispatch_request(self, *args, **kwargs):
        """
        Run the request pipeline.

        Steps: authenticate, log 'enter', validate the form, call
        process_request(), build the response, log 'exit'.

        Returns:
            ``(response, status_code)``. A truthy value returned by
            process_request() becomes the raw response body. Otherwise
            ``response_values`` (with a default ``status: ok``) is sent as JSON.

        Raises:
            BaseError: re-raised after the statbox mode and auth info are merged
                into its ``private_info``.
            DatabaseOperationalError: raised in place of a SQLAlchemy
                OperationalError from process_request().
        """
        try:
            consumer = request.args.get('consumer')
            if consumer:
                self.statbox_extra_info['consumer'] = consumer

            # Authentication failures surface as exceptions from validated_uid;
            # the boolean result itself is not inspected here.
            self.check_user_access()
            self.statbox_log(action='enter')
            if self.form is not None:
                real_form = self.form
                if isinstance(self.form, (list, tuple)):
                    # Combine several form classes into one. type() requires the
                    # bases as a tuple; passing a list raises TypeError.
                    real_form = type('DynamicForm', tuple(self.form), {})
                self.processed_form = self.process_form(real_form)

            g.use_slave = self.use_slave
            try:
                return_value = self.process_request(*args, **kwargs)
            except SAOperationalError as e:
                raise DatabaseOperationalError(e, self.use_slave)

            if return_value:
                response = Response(return_value)
            elif self.response_values:
                self.response_values.setdefault('status', 'ok')
                response = jsonify(self.response_values)
            else:
                response = jsonify({'status': 'ok'})

            if getattr(self, 'allow_origin', False):
                response.headers['Access-Control-Allow-Origin'] = '*'

            # Forbid search engines from indexing or archiving the pages.
            response.headers['X-Robots-Tag'] = 'noindex, noarchive'
            self.statbox_log(action='exit')
            return response, self.response_status_code

        except BaseError as e:
            e.private_info.update({'mode': self.statbox_mode})
            e.private_info.update(self.statbox_auth_info)
            # A bare raise keeps the original traceback (``raise e`` drops it
            # on Python 2).
            raise

    @property
    def is_json_request(self):
        """
        True when the request declares a JSON content type and has a non-blank body.

        Truthiness is used for the body check so it works for both str and
        bytes bodies (comparing bytes with '' is always unequal on Python 3).
        """
        return bool(
            request.content_type
            and request.content_type.lower().startswith('application/json')
            and request.data
            and request.data.strip()
        )

    def process_form(self, form, skip_exception=False):
        """
        Build and validate ``form`` from the request.

        JSON requests: the JSON object is updated with files, then query args,
        then URL view args (each overriding earlier keys), and the result is
        passed to ``form.from_json``. Other requests: view args, query args,
        files and form data are combined, in that lookup precedence, into a
        CombinedMultiDict.

        Returns:
            The form instance. It is invalid only when ``skip_exception`` is set.

        Raises:
            ValidationError: validation failed and ``skip_exception`` is False.
        """
        if self.is_json_request:
            # Copy before merging: get_json() caches the parsed body on the
            # request by default, so updating it in place would also change
            # request.json for every later reader.
            # NOTE(review): a non-object JSON body still fails on .update().
            raw_data = request.get_json().copy()
            raw_data.update(request.files)
            raw_data.update(request.args)
            raw_data.update(request.view_args)
            processed_form = form.from_json(raw_data)
        else:
            raw_data = CombinedMultiDict((
                ImmutableMultiDict(request.view_args),
                request.args,
                request.files,
                request.form,
            ))
            processed_form = form(raw_data)
        if not processed_form.validate() and not skip_exception:
            raise ValidationError(processed_form.errors)
        return processed_form

    def check_if_has_role(self, role, secret_uuid=None, bundle_uuid=None, uid=None, check_if_last=False, raises=True):
        """
        Check that ``uid`` (default: the authenticated user) holds ``role``.

        The role is checked on the given secret and/or bundle. With
        ``check_if_last``, at least two grants of ``role`` must exist on the
        object in total; otherwise this is the LastOwnerError case.

        Returns:
            True on success. False on failure when ``raises`` is False.

        Raises:
            AccessError: the user has no such role (when ``raises``).
            LastOwnerError: fewer than two grants exist (when ``raises``).
        """
        if uid is None:
            uid = self.validated_uid

        your_roles_count = UserRole.get_roles_count(
            uid=uid,
            roles=role,
            secret_uuid=secret_uuid,
            bundle_uuid=bundle_uuid,
        )
        if your_roles_count == 0:
            if raises:
                raise AccessError()
            return False
        if check_if_last:
            all_roles_count = UserRole.get_roles_count(
                roles=role,
                secret_uuid=secret_uuid,
                bundle_uuid=bundle_uuid,
            )
            if all_roles_count < 2:
                if raises:
                    raise LastOwnerError()
                return False
        return True

    def check_if_supervisor(self, raises=True):
        """
        Return whether the authenticated user has the SUPERVISOR role.

        Raises AccessError instead of returning False when ``raises`` is set.
        """
        supervisor_roles_count = UserRole.query.filter(
            UserRole.external_type == 'user',
            UserRole.role_id == Roles.SUPERVISOR.value,
            UserRole.uid == self.validated_uid,
        ).count()
        result = supervisor_roles_count > 0
        if not result and raises:
            raise AccessError()
        return result

    def check_state(self, version, response_dict):
        """
        Mark ``response_dict`` with a warning if ``version`` is hidden.

        ``version`` may be a dict or an object with ``transitive_state``.
        Returns that transitive state, or None if absent.
        """
        transitive_state = None
        if isinstance(version, dict):
            transitive_state = version.get('transitive_state')
        elif hasattr(version, 'transitive_state'):
            transitive_state = version.transitive_state
        if transitive_state is not None and transitive_state == State.hidden.value:
            logging.getLogger('exception_logger').warning('Acquiring hidden version')
            response_dict['status'] = 'warning'
            response_dict['warning_message'] = 'version is hidden'
        return transitive_state

    def check_version_expiration(self, version, response_dict):
        """
        Mark ``response_dict`` with a warning if ``version`` is expired.

        An already non-'ok' status (e.g. the hidden-version warning) is left
        untouched.
        """
        if SecretVersion.check_expired(version):
            logging.getLogger('exception_logger').warning('Acquiring expired version')
            if response_dict.get('status', 'ok') == 'ok':
                response_dict['status'] = 'warning'
                response_dict['warning_message'] = 'version is expired'
