# -*- coding: utf-8 -*-

import base64
from functools import partial
import json
import logging
import zlib

import formencode.validators
from passport.backend.social.common import (
    oauth2,
    validators,
)
from passport.backend.social.common.exception import (
    InvalidResponseAttributeProxylibError,
    InvalidTokenProxylibError,
    ProviderCommunicationProxylibError,
    ResponseDecodeProxylibError,
    TooLongUnexpectedResponseProxylibError,
    UnexpectedResponseProxylibError,
)
from passport.backend.social.common.misc import (
    dump_to_json_string,
    expires_in_to_expires_at,
    trim_message,
)
from passport.backend.social.common.provider_settings import (
    get_profile_addresses,
    providers,
)
from passport.backend.social.common.social_config import social_config
from passport.backend.utils.common import (
    from_base64_url,
    to_base64_url,
)
from werkzeug.urls import url_decode


# Logger shared by every proxy helper in this module.
logger = logging.getLogger('proxylib.proxy')


def paginated(func):
    """
    Turn a generator of item blocks into a function returning one flat list.

    ``func`` must yield iterables of items (one block per request to the
    provider). The decorated callable concatenates all blocks into a single
    list. When the 100th block arrives, the pagination is considered runaway.
    ProviderCommunicationProxylibError is then raised and that block is
    discarded, so at most 99 blocks contribute to a result.
    """
    def pagination_processor(*args, **kwargs):
        result = []
        for request_number, items_block in enumerate(func(*args, **kwargs), 1):
            if request_number == 100:
                msg = 'Failed to get all the data in 100 requests'
                logger.warning(msg)
                raise ProviderCommunicationProxylibError(msg)

            result += items_block
        return result

    # Keep the argument names of the original function: the wrapper itself
    # only has (*args, **kwargs). If func is already a decorated wrapper,
    # reuse the names it propagated. ``__code__`` (available on Python 2.6+
    # and 3) replaces the Python-2-only ``func_code`` alias.
    if hasattr(func, 'real_co_varnames'):
        pagination_processor.real_co_varnames = func.real_co_varnames
    else:
        pagination_processor.real_co_varnames = func.__code__.co_varnames
    return pagination_processor


def save_method_info(func):
    """
    Decorate a proxy method so that every call records itself on ``self.r``.

    Before delegating to ``func``, the wrapper stores the method name in
    ``self.r.method_name``, the positional arguments (without ``self``) in
    ``self.r.method_args`` and the keyword arguments in
    ``self.r.method_kwargs``.
    """
    def processor(self, *args, **kwargs):
        self.r.method_name = func.__name__
        self.r.method_args = args
        self.r.method_kwargs = kwargs
        return func(self, *args, **kwargs)

    # Keep the argument names of the wrapped function. The wrapper only has
    # (self, *args, **kwargs). As in ``paginated``, names already propagated
    # by an inner decorator win. ``__code__`` works on Python 2.6+ and 3,
    # unlike the Python-2-only ``func_code``.
    if hasattr(func, 'real_co_varnames'):
        processor.real_co_varnames = func.real_co_varnames
    else:
        processor.real_co_varnames = func.__code__.co_varnames
    return processor


class RefreshTokenGetter(object):
    """
    Obtain a new access token from a provider by a refresh token.

    All request work goes through the repository object passed as ``repo``
    (kept as ``self.r``):
    - ``compose_request`` builds the request;
    - ``execute_request_basic`` sends it;
    - ``extract_response_data`` runs before ``get_refresh_token`` returns
      ``self.r.context['processed_data']``.

    Each step is a separate method, so a subclass may override individual
    steps (see SocialProxy.REFRESH_TOKEN_GETTER_CLASS).
    """

    def __init__(self, repo, refresh_token, detect_error=None, auth_type_is_basic=False):
        """
        repo
            Request repository used to compose, execute and parse the request.
        refresh_token
            Refresh token to exchange.
        detect_error
            Callable that inspects the decoded response and raises
            ``oauth2.refresh_token`` exceptions. Defaults to
            ``oauth2.refresh_token.detect_error`` (applied in _parse_result).
        auth_type_is_basic
            If true, client credentials go into an HTTP Basic
            ``Authorization`` header; otherwise into the request body.
        """
        self.r = repo
        self.refresh_token = refresh_token
        self.detect_error = detect_error
        self.auth_type_is_basic = auth_type_is_basic

    def _build_request_data(self):
        """Return the form fields of the refresh_token grant request."""
        return {
            'grant_type': 'refresh_token',
            'refresh_token': self.refresh_token,
        }

    def _add_auth(self, data, headers):
        """
        Add client credentials to ``data`` or ``headers`` (mutated in place).

        The provider-specific client id (``custom_provider_client_id``) takes
        precedence over the application id.
        """
        client_id = self.r.app.custom_provider_client_id or self.r.app.id
        if self.auth_type_is_basic:
            headers.update({
                'Authorization': 'Basic %s' % base64.b64encode(
                    '%s:%s' % (client_id, self.r.app.secret),
                ),
            })
        else:
            data.update({
                'client_id': client_id,
                'client_secret': self.r.app.secret,
            })

    def _build_request(self):
        """Return ``(data, headers)`` for the refresh request."""
        data = self._build_request_data()
        headers = {}
        self._add_auth(data, headers)
        return data, headers

    def _parse_json(self):
        """Reject oversized responses, then deserialize the JSON body."""
        _check_token_response_length(self.r.context['raw_response'].data)
        # Presumably stores the decoded body in self.r.context['data'],
        # which the following steps read.
        self.r.deserialize_json()

    def _detect_error(self, detect_error):
        """
        Translate OAuth2 errors from ``detect_error`` into proxylib errors.

        invalid_grant always means an invalid token. Other client-side OAuth2
        errors are re-raised unchanged unless the provider settings contain a
        true ``treat_<error>_as_invalid_token`` flag. In that case they also
        become InvalidTokenProxylibError.
        """
        try:
            detect_error(self.r.context['data'])
        except oauth2.refresh_token.InvalidGrant as e:
            raise InvalidTokenProxylibError(e.error_description)
        except (
            oauth2.refresh_token.InvalidClient,
            oauth2.refresh_token.InvalidRequest,
            oauth2.refresh_token.InvalidScope,
            oauth2.refresh_token.UnauthorizedClient,
            oauth2.refresh_token.UnsupportedGrantType,
        ) as e:
            conf_attr = 'treat_%s_as_invalid_token' % e.error
            if not self.r.settings.get(conf_attr):
                raise
            # Lazy %-arguments: the message is only formatted when DEBUG is on.
            logger.debug('Failed to refresh token: %s (%s)', e.error, e.error_description)
            raise InvalidTokenProxylibError(e.error_description)
        except oauth2.refresh_token.UnexpectedException:
            raise UnexpectedResponseProxylibError()

    def _validate_result(self):
        """Validate token fields; a malformed id_token is dropped, not fatal."""
        self.r.context['data'] = _validate_token_response_dict(self.r.context['data'], skip_if_invalid={'id_token'})
        if 'access_token' not in self.r.context['data']:
            raise InvalidResponseAttributeProxylibError(
                'Refresh token response does not contain token value',
                attr_name='access_token',
                max_length=_token_validator.max,
                min_length=1,
            )

    def _build_parser(self):
        """Return ``parse(detect_error)`` running decode, error check and validation."""
        def parse(detect_error):
            self._parse_json()
            self._detect_error(detect_error)
            self._validate_result()

        return parse

    def _parse_result(self):
        """Execute the composed request and extract the token fields."""
        if self.detect_error is None:
            self.detect_error = oauth2.refresh_token.detect_error

        parse = self._build_parser()
        parse = partial(parse, self.detect_error)

        self.r.execute_request_basic(parser=parse)
        # NOTE(review): the parser is run again after execution. Presumably
        # execute_request_basic only uses it to judge the response (e.g. for
        # retries), and this call leaves the final parsed data in context.
        parse()

        self.r.extract_response_data(
            {
                'access_token': 'access_token',
                'expires_in': 'expires_in',
                'refresh_token': 'refresh_token',
            },
            one=True,
            listed=False,
        )

    def get_refresh_token(self):
        """
        Return an access token obtained by the refresh_token.

        The SETTINGS attribute must contain the oauth_refresh_token_url key.
        """
        data, headers = self._build_request()

        self.r.compose_request(
            url_name='oauth_refresh_token_url',
            data=data,
            additional_headers=headers,
            add_access_token=False,
        )

        self._parse_result()

        return self.r.context['processed_data']


class SocialProxy(object):
    """
    Base class of provider proxies.

    ``self.r`` holds the request repository used to talk to the provider. It
    is None until a caller assigns it.
    """

    SETTINGS = {}
    REFRESH_TOKEN_GETTER_CLASS = RefreshTokenGetter

    def __init__(self):
        self.r = None

    def _refresh_token(self, refresh_token, detect_error=None, auth_type_is_basic=False):
        """Exchange ``refresh_token`` via REFRESH_TOKEN_GETTER_CLASS and return its result."""
        getter_kwargs = dict(
            repo=self.r,
            refresh_token=refresh_token,
            detect_error=detect_error,
            auth_type_is_basic=auth_type_is_basic,
        )
        return self.REFRESH_TOKEN_GETTER_CLASS(**getter_kwargs).get_refresh_token()

    def get_profile_links(self, userid, username=None):
        """
        Return a (possibly empty) list of profile URLs of the user.

        An empty list is returned when the proxy has no ``code`` or the
        provider is unknown.
        """
        provider_info = providers.get_provider_info_by_name(self.code) if self.code else None
        if provider_info:
            return get_profile_addresses(provider_info['id'], userid, username)
        return list()

    def update_request_privacy(self, request_args, privacy, arg_name):
        """
        Put the provider-specific privacy value into ``request_args[arg_name]``.

        Only proxies defining PRIVACY_LOCAL_BY_GLOBAL take part. The global
        ``privacy`` value is looked up by its string form, and a missing or
        empty mapping leaves ``request_args`` untouched.
        """
        if hasattr(self, 'PRIVACY_LOCAL_BY_GLOBAL') and privacy is not None:
            mapped_privacy = self.PRIVACY_LOCAL_BY_GLOBAL.get(str(privacy))
            if mapped_privacy:
                request_args[arg_name] = mapped_privacy

    def _authorization_code_to_token(self, code):
        """
        Exchange an OAuth2 authorization code for an access token.

        Returns the token dict built by
        parse_access_token_from_authorization_code_response. The same dict
        is also stored in ``self.r.context['data']``.
        """
        # NOTE(review): RefreshTokenGetter prefers
        # app.custom_provider_client_id over app.id, but this exchange always
        # sends app.id. Confirm this is intended.
        token_request = oauth2.token.build_authorization_code_request(
            endpoint=self.SETTINGS['oauth_access_token_url'],
            code=code,
            client_id=self.r.app.id,
            client_secret=self.r.app.secret,
        )
        self.r.compose_request(
            base_url=token_request.endpoint,
            method=token_request.method,
            additional_args=token_request.query,
            data=token_request.data,
            additional_headers=token_request.headers,
            add_access_token=False,
        )

        def _parse_token_response():
            raw_body = self.r.context['raw_response'].decoded_data
            self.r.context['data'] = parse_access_token_from_authorization_code_response(raw_body)

        self.r.execute_request_basic(parser=_parse_token_response)
        # Parse once more so that context['data'] reflects the final response.
        _parse_token_response()
        return self.r.context['data']

    def parse_raw_token(self, raw_token):
        """Use ``raw_token`` as is for the repository's access token."""
        self.r.access_token = raw_token


# Minimal schema that only converts the "version" field of a packed token
# (with validators.Int). The parser and serializer use the result to pick the
# version-specific schema from their ``version_forms`` mapping.
class SocialPackedTokenVersionForm(validators.Schema):
    version = validators.Int()


class SocialPackedTokenParser(object):
    """
    Decode a packed token from the representation used in Socialism.

    Decoding proceeds in four steps:
    1. the token string is decoded with from_base64_url;
    2. it is zlib-decompressed when ``compression`` is enabled;
    3. it is parsed as a JSON object;
    4. it is converted with the schema registered for its "version".

    Every decoding failure is reported as InvalidTokenProxylibError.
    """

    def __init__(
        self,
        version_forms,
        compression=None,
    ):
        """
        version_forms
            Mapping from a version (number) to its validation form (a
            formencode.Schema subclass).
        compression
            If true, the base64 payload is zlib-compressed JSON.
        """
        self.compression = compression
        self.version_forms = version_forms

    def parse(self, doc):
        """Return the typed token dict for the packed token string ``doc``."""
        doc = self.str_to_json_doc(doc)
        token_dict = self.json_doc_to_dict(doc)
        return self.dict_to_typed_dict(token_dict)

    def str_to_json_doc(self, doc):
        """Undo the base64 (and, if enabled, zlib) encoding of ``doc``."""
        try:
            bytes_ = from_base64_url(doc)
        except (TypeError, ValueError):
            # Base64 decoding reports malformed input with different errors:
            # - TypeError for bad padding on Python 2;
            # - UnicodeEncodeError (a ValueError) for non-ASCII text;
            # - binascii.Error (a ValueError) on Python 3.
            # All of them mean the token is malformed.
            # NOTE(review): assumes from_base64_url wraps the stdlib base64
            # decoder.
            raise InvalidTokenProxylibError()
        if not self.compression:
            return bytes_
        try:
            return zlib.decompress(bytes_)
        except zlib.error:
            raise InvalidTokenProxylibError()

    def json_doc_to_dict(self, doc):
        """Parse ``doc`` as JSON; only a JSON object is accepted."""
        try:
            token_dict = json.loads(doc)
        except ValueError:
            raise InvalidTokenProxylibError()
        if not isinstance(token_dict, dict):
            raise InvalidTokenProxylibError()
        return token_dict

    def dict_to_typed_dict(self, token_dict):
        """Validate ``token_dict`` with the form registered for its version."""
        try:
            version = SocialPackedTokenVersionForm().to_python(token_dict)
        except validators.Invalid:
            raise InvalidTokenProxylibError()

        version_form = self.version_forms.get(version['version'])
        if not version_form:
            raise InvalidTokenProxylibError()

        try:
            token_dict = version_form().to_python(token_dict)
        except validators.Invalid:
            raise InvalidTokenProxylibError()

        return token_dict


class SocialPackedTokenSerializer(object):
    """
    Encode a typed token dict into the packed representation.

    This is the inverse of SocialPackedTokenParser:
    1. the dict goes through the version's form (``from_python``);
    2. it is dumped to compact JSON;
    3. it is optionally zlib-compressed;
    4. it is encoded with to_base64_url.
    """

    def __init__(
        self,
        version_forms,
        compression=None,
    ):
        """
        version_forms
            Mapping from a version (number) to its validation form (a
            formencode.Schema subclass).
        compression
            If true, the JSON document is zlib-compressed before encoding.
        """
        self.compression = compression
        self.version_forms = version_forms

    def serialize(self, token_dict):
        """Return the packed string for the typed ``token_dict``."""
        token_dict = self.typed_dict_to_dict(token_dict)
        doc = self.dict_to_json_doc(token_dict)
        return self.json_doc_to_str(doc)

    def typed_dict_to_dict(self, token_dict):
        """
        Convert ``token_dict`` with the form registered for its version.

        Raises ValueError when no form is registered for the version. The
        previous behavior was an opaque "NoneType is not callable" TypeError.
        """
        version = SocialPackedTokenVersionForm().from_python(token_dict)
        version_form = self.version_forms.get(version['version'])
        if not version_form:
            raise ValueError('Unsupported packed token version: %r' % (version['version'],))
        return version_form().from_python(token_dict)

    def dict_to_json_doc(self, token_dict):
        """Dump ``token_dict`` to a minimal JSON string."""
        return dump_to_json_string(token_dict, minimal=True)

    def json_doc_to_str(self, doc):
        """Optionally compress ``doc`` (fastest level), then base64-url encode it."""
        if self.compression:
            doc = zlib.compress(doc, 1)
        return to_base64_url(doc)


def parse_access_token_from_authorization_code_response(response, _format='json', detect_error=None):
    """
    Parse a token endpoint response of the authorization-code grant.

    response
        Raw response body. For 'urlencoded' it is encoded to UTF-8 before
        decoding.
    _format
        'json' or 'urlencoded'; any other value raises NotImplementedError.
    detect_error
        Callable raising oauth2.token exceptions for error responses.
        Defaults to oauth2.token.detect_error.

    Returns a dict with 'value' (the access token) and 'expires' (from
    expires_in_to_expires_at). It also has 'refresh' and 'id_token' when the
    response carries non-empty values for them.

    Raises TooLongUnexpectedResponseProxylibError, ResponseDecodeProxylibError,
    InvalidTokenProxylibError, UnexpectedResponseProxylibError and
    InvalidResponseAttributeProxylibError.
    """
    detect_error = oauth2.token.detect_error if detect_error is None else detect_error

    _check_token_response_length(response)

    if _format not in ('json', 'urlencoded'):
        raise NotImplementedError()

    try:
        if _format == 'json':
            decoded = json.loads(response)
        else:
            decoded = url_decode(response.encode('utf-8'))
    except ValueError:
        raise ResponseDecodeProxylibError(
            'Response is not in a %s format: %s' % (_format, response),
        )

    try:
        detect_error(decoded)
    except oauth2.token.InvalidGrant as e:
        raise InvalidTokenProxylibError(e.error_description)
    except oauth2.token.UnexpectedException:
        raise UnexpectedResponseProxylibError()

    fields = _validate_token_response_dict(decoded)

    if fields.get('expires') and not fields.get('expires_in'):
        # Grep the logs for this line later; if it never appears, support for
        # the non-standard "expires" field can finally be removed.
        logger.debug('Key "expires" still exists in token responses')

    if 'access_token' not in fields:
        raise InvalidResponseAttributeProxylibError(
            'Access token response does not contain token value',
            attr_name='access_token',
            max_length=_token_validator.max,
            min_length=1,
        )

    # Standard "expires_in" wins; the legacy "expires" is only a fallback.
    lifetime = fields.get('expires_in') or fields.get('expires')
    token = {
        'value': fields['access_token'],
        'expires': expires_in_to_expires_at(lifetime),
    }

    for response_key, token_key in (('refresh_token', 'refresh'), ('id_token', 'id_token')):
        optional_value = fields.get(response_key)
        if optional_value:
            token[token_key] = optional_value

    return token


def _check_token_response_length(response):
    """
    Raise TooLongUnexpectedResponseProxylibError if ``response`` is longer
    than social_config.max_token_response_length.
    """
    limit = social_config.max_token_response_length
    length = len(response)
    if length <= limit:
        return

    message = 'Response is too long: %s' % trim_message(response)
    logger.debug(message)
    raise TooLongUnexpectedResponseProxylibError(message, length, limit)


def _validate_token_response_dict(response, skip_if_invalid=None):
    """
    Return a new dict containing the validated token fields of ``response``.

    Only attributes listed in _TOKEN_RESPONSE_VALIDATORS are copied. Missing
    or empty values are skipped, but 0 is kept. An invalid value raises,
    unless its attribute is in ``skip_if_invalid``, in which case it is
    dropped:
    - InvalidResponseAttributeProxylibError for token and expiry fields;
    - UnexpectedResponseProxylibError for other fields.

    A non-dict ``response`` raises UnexpectedResponseProxylibError.
    """
    if not isinstance(response, dict):
        message = 'Response is not a dict: %s' % trim_message(repr(response))
        logger.debug(message)
        raise UnexpectedResponseProxylibError(message)

    if skip_if_invalid is None:
        skip_if_invalid = set()

    validated = {}
    for attr, validator in _TOKEN_RESPONSE_VALIDATORS:
        raw_value = response.get(attr)
        if not raw_value and raw_value != 0:
            continue

        try:
            validated[attr] = validator.to_python(raw_value)
        except validators.Invalid:
            message = 'Invalid %s value: %s' % (attr, trim_message(raw_value))
            logger.debug(message)

            if attr in skip_if_invalid:
                continue
            if attr in ('access_token', 'refresh_token'):
                raise InvalidResponseAttributeProxylibError(
                    message,
                    attr_name=attr,
                    max_length=_token_validator.max,
                    min_length=1,
                )
            if attr in ('expires', 'expires_in'):
                raise InvalidResponseAttributeProxylibError(
                    message,
                    attr_name=attr,
                )
            raise UnexpectedResponseProxylibError(message)

    return validated


# Validator for "expires_in"/"expires" values in token responses. It is a
# formencode All chain of three validators:
# - an int() wrapper;
# - a Number with min=0 and if_empty=None;
# - a String capped at len(str(2 ** 32)) == 10 characters.
# NOTE(review): formencode's All applies to_python in reverse list order, so
# the length check presumably runs first and int() last. Confirm against the
# project's validators module.
_expires_in_validator = formencode.compound.All(
    validators=[
        formencode.validators.Wrapper(to_python=int),
        validators.Number(min=0, if_empty=None),
        validators.String(max=len(str(2 ** 32))),
    ],
)

# Shared validator for token strings. Its ``max`` attribute is reported as
# max_length in InvalidResponseAttributeProxylibError.
_token_validator = validators.Token()

# (attribute name, validator) pairs checked by _validate_token_response_dict,
# in this order.
_TOKEN_RESPONSE_VALIDATORS = [
    ('expires_in', _expires_in_validator),
    ('expires', _expires_in_validator),
    ('access_token', _token_validator),
    ('refresh_token', _token_validator),
    ('id_token', _token_validator),
]
