# -*- coding: utf-8 -*-

import copy
from json import loads as load_from_json_string
import logging
from operator import attrgetter
import uuid

from passport.backend.social.common.chrono import now
from passport.backend.social.common.exception import ApplicationUnknown
from passport.backend.social.common.misc import (
    dump_to_json_string,
    FACEBOOK_BUSINESS_ID,
    remove_undefined_values,
    trim_message,
)
from passport.backend.social.common.provider_settings import providers
from passport.backend.social.common.random import (
    get_randomizer,
    urandom,
)
from passport.backend.social.common.redis_client import RedisError
from passport.backend.social.common.refresh_token.domain import RefreshToken
from passport.backend.social.common.social_config import social_config
from passport.backend.social.common.token.domain import Token
from passport.backend.social.common.token.utils import build_token_dict_for_proxy
from passport.backend.utils.common import remove_none_values


logger = logging.getLogger(__name__)


# Default Redis serialization format version used by save_task_to_redis() and
# Task._get_redis_data(); Task._update_from_redis_data() only accepts version 2.
TASK_SERIALIZATOR_VERSION = 2


def generate_task_id():
    """
    Return a fresh task id.

    A wrapper around ``_generate_task_id`` so tests can conveniently
    substitute the generated value.
    """
    new_id = _generate_task_id()
    return new_id


def _generate_task_id():
    """
    Build a random TaskId tagged with this environment's id.

    Returns its hex representation as a unicode string.
    """
    random_bytes = urandom(get_randomizer(), 16)
    fresh_id = TaskId(
        _bytes=random_bytes,
        environment_id=social_config.environment_id,
    )
    return unicode(fresh_id.hex)


def create_task():
    """Return a new Task with a freshly generated id and creation timestamp set."""
    fresh = Task()
    # Right-hand side is evaluated left to right: id first, then timestamp.
    fresh.task_id, fresh.created = generate_task_id(), now.f()
    return fresh


def build_provider_for_task(code, name, id):
    """
    Build a provider descriptor dict with keys ``code``, ``name`` and ``id``.

    Presumably the shape kept in ``Task.provider``. The parameter ``id``
    shadows the builtin but is part of the public signature, so it stays.
    """
    return {'code': code, 'name': name, 'id': id}


def save_task_to_redis(redis_client, task_id, task, version=TASK_SERIALIZATOR_VERSION):
    """
    Store ``task`` in Redis as a hash under key ``task_id`` and set its TTL.

    :param redis_client: client exposing ``hmset`` and ``expire``.
    :param task_id: Redis key to write.
    :param task: Task serialized via ``Task._get_redis_data(version)``.
    :param version: serialization format version.
    :raises RedisError: if ``hmset`` returns a falsy result.

    On success sets ``task.in_redis = True``. The TTL is
    ``social_config.redis_task_expiration_time``.
    """
    # TODO: remove from Redis the attributes that are absent in this task
    # (HMSET only writes the given fields; stale ones are left in place).
    task_data = task._get_redis_data(version)

    # Lazy %-args: the message is only formatted if the record is emitted.
    logger.info('Saving data to redis: %s', trim_message(task_data))

    res_hmset = redis_client.hmset(task_id, task_data)
    logger.debug('[redis]result = %s', res_hmset)
    if not res_hmset:
        raise RedisError()

    redis_client.expire(task_id, social_config.redis_task_expiration_time)

    task.in_redis = True


def load_task_from_redis(redis_client, task_id):
    """
    Load the task stored in Redis under ``task_id``.

    :returns: a Task with ``task_id`` set and ``in_redis = True``, or None
        when ``hgetall`` returns nothing for the key.
    :raises ValueError: for an unknown serialization version, and
        ApplicationUnknown for an unknown application id (both propagated
        from ``Task._update_from_redis_data``).
    """
    logger.info('Loading data from redis: %s', task_id)
    task_data = redis_client.hgetall(task_id)
    if not task_data:
        return None
    task = Task()
    task._update_from_redis_data(task_data)
    task.task_id = task_id
    task.in_redis = True
    return task


def delete_task_from_redis(redis_client, task_id):
    """Delete the Redis key ``task_id``; return whatever the client's ``delete`` returns."""
    deleted = redis_client.delete(task_id)
    return deleted


class Task(object):
    """
    State of one social-authorization task.

    A task is serialized in two independent ways:

    * session JSON (``dump_session_data`` / ``parse_session_data``): the
      attributes listed in ``dump_fields_mapping``, each stored under the key
      given as the mapping value;
    * a Redis hash (``_get_redis_data`` / ``_update_from_redis_data``): a
      fixed, versioned subset of attributes; only truthy values are written,
      each encoded with ``dump_to_json_string``.
    """

    # Task attribute name -> key under which it is stored in session JSON.
    # Most attributes are stored under their own name...
    dump_fields_mapping = {
        k: k for k in [
            'application',
            'callback_url',
            'code_challenge',
            'code_challenge_method',
            'collect_diagnostics',
            'consumer',
            'experiments',
            'in_redis',
            'nonce',
            'passthrough_errors',
            'place',
            'provider',
            'retpath',
            'scope',
            'sid',
            'state',
            'uid',
            'user_param',
            'yandexuid',
        ]
    }
    # ...while a few are stored under different keys.
    dump_fields_mapping.update(
        {
            'created': 'ts',
            'exchange': 'code',
            'request_token': 'req_t',
            'start_args': 'args',
            'task_id': 'tid',
        },
    )

    # Attributes that must not be None in session data (checked both when
    # dumping and when parsing).
    required_fields = {'created', 'task_id', 'start_args', 'state'}
    # Profile keys exposed by get_social_userinfo().
    profile_fields = {
        'userid',
        'username',
        'firstname',
        'lastname',
        'gender',
        'birthday',
        'avatar',
        'email',
        'links',
        'phone',
    }

    def __init__(self):
        # dict with token data (get_token/get_refresh_token read 'value',
        # 'secret', 'scope', 'expires', 'refresh' from it)
        self.access_token = None
        # Application
        self.application = None
        self.callback_url = None
        self.code_challenge = None
        self.code_challenge_method = None
        self.collect_diagnostics = None
        self.consumer = None
        # float timestamp
        self.created = None
        self.exchange = None
        self.experiments = []
        # float timestamp
        self.finished = None
        self.in_redis = None
        self.nonce = None
        self.passthrough_errors = []
        self.place = None
        # dict
        self.provider = None
        # dict
        self.request_token = None
        self.retpath = None
        self.scope = None
        # dict
        self.start_args = None
        self.state = None
        self.task_id = None
        self.user_param = None
        self.yandexuid = None

        self.profile = {}
        self.profile_id = None
        # Backing storage for the ``uid`` property (kept as int).
        self._uid = None
        self.sid = None

    @classmethod
    def from_social_userinfo(cls, social_userinfo, uid, binding_data):
        """
        Build an already finished task from a social profile and binding data.

        :param social_userinfo: profile dict, stored by reference as
            ``profile``; must contain ``provider``.
        :param uid: user id (coerced to int by the ``uid`` setter).
        :param binding_data: object exposing ``token``, ``refresh_token``,
            ``consumer``, ``yandexuid`` and ``sid``; when ``token`` is set,
            ``application`` and ``access_token`` are filled from it.
        """
        task = cls()

        task.created = task.finished = now.f()
        task.uid = uid
        task.profile = social_userinfo

        task.provider = social_userinfo['provider']

        if binding_data.token:
            app = providers.get_application_by_id(binding_data.token.application_id)
            task.application = app
            task.access_token = build_token_dict_for_proxy(
                binding_data.token,
                binding_data.refresh_token,
            )

        task.consumer = binding_data.consumer
        task.yandexuid = binding_data.yandexuid
        task.sid = binding_data.sid

        return task

    def to_json_dict(self):
        """Return all task attributes as a dict; ``application`` becomes its identifier."""
        return dict(
            access_token=self.access_token,
            application=self.application and self.application.identifier,
            code_challenge_method=self.code_challenge_method,
            code_challenge=self.code_challenge,
            callback_url=self.callback_url,
            collect_diagnostics=self.collect_diagnostics,
            consumer=self.consumer,
            created=self.created,
            exchange=self.exchange,
            experiments=self.experiments,
            finished=self.finished,
            in_redis=self.in_redis,
            nonce=self.nonce,
            passthrough_errors=self.passthrough_errors,
            place=self.place,
            profile_id=self.profile_id,
            profile=self.profile,
            provider=self.provider,
            request_token=self.request_token,
            retpath=self.retpath,
            scope=self.scope,
            sid=self.sid,
            start_args=self.start_args,
            state=self.state,
            task_id=self.task_id,
            uid=self.uid,
            user_param=self.user_param,
            yandexuid=self.yandexuid,
        )

    def parse_session_data(self, json_data):
        """
        Fill this task from a session JSON string (see dump_session_data).

        Every attribute in ``dump_fields_mapping`` is assigned, with None when
        its key is absent (``uid`` keeps its old value, because its setter
        ignores None). ``application``, ``passthrough_errors`` and
        ``experiments`` are decoded back from their serialized forms.

        :raises InvalidTaskDataError: if ``json_data`` is empty or a required
            field is missing.
        :raises ApplicationUnknown: if the stored application id is unknown.
        """
        logger.debug('Getting session data from request parameters')
        if not json_data:
            raise InvalidTaskDataError('No task data was passed into the request. Maybe session expired?')

        data = load_from_json_string(json_data)

        data['application'] = self._parse_application(data)
        data['passthrough_errors'] = self._parse_passthrough_errors(data)
        data['experiments'] = self._parse_experiments(data)

        for key, cookie_key in self.dump_fields_mapping.iteritems():
            value = data.get(cookie_key)

            if value is None and key in self.required_fields:
                raise InvalidTaskDataError('No required field %s in session data' % cookie_key)
            setattr(self, key, value)

    def dump_session_data(self):
        """
        Serialize the session attributes to a JSON string.

        Keys come from ``dump_fields_mapping``; None values are omitted.

        :raises AttributeError: if a required field is None.
        """
        logger.debug('Dumping session data from request parameters')
        data = {}

        for key, cookie_key in self.dump_fields_mapping.iteritems():
            value = getattr(self, key)
            if key in self.required_fields and value is None:
                raise AttributeError(key)
            data[cookie_key] = value

        data['application'] = self._dump_application()
        data['passthrough_errors'] = self._dump_passthrough_errors()
        data['experiments'] = self._dump_experiments()

        data = {k: v for k, v in data.iteritems() if v is not None}

        return dump_to_json_string(data)

    def get_social_userinfo(self):
        """
        Return a deep copy of the whitelisted profile fields plus ``provider``.

        If the profile carries a truthy ``token_for_business``, a ``business``
        entry with FACEBOOK_BUSINESS_ID and that token is added.
        """
        # Filter first, then deep-copy only what is kept (the original copied
        # the whole profile and discarded most of it).
        profile = copy.deepcopy(
            {k: v for k, v in self.profile.iteritems() if k in self.profile_fields},
        )
        profile['provider'] = self.provider

        token_for_business = self.profile.get('token_for_business')
        if token_for_business:
            profile['business'] = {'id': FACEBOOK_BUSINESS_ID, 'token': token_for_business}
        return profile

    def get_token(self):
        """
        Build a Token from ``access_token`` and ``application``.

        :returns: Token, or None when there is no access token.

        NOTE(review): expects ``application`` to be set whenever
        ``access_token`` is (otherwise AttributeError).
        """
        if not self.access_token:
            return
        token_kwargs = dict(
            application=self.application,
            application_id=self.application.identifier,
            value=self.access_token['value'],
            secret=self.access_token.get('secret'),
            scopes=self.access_token.get('scope'),
            expired=self.access_token.get('expires'),
        )
        if self.uid:
            token_kwargs['uid'] = self.uid
        return Token(**token_kwargs)

    def get_refresh_token(self):
        """Return a RefreshToken from ``access_token['refresh']``, or None if absent."""
        if not self.access_token or 'refresh' not in self.access_token:
            return
        return RefreshToken(
            value=self.access_token['refresh'],
            scopes=self.access_token.get('scope'),
            expired=None,
        )

    def get_dict_for_response(
        self,
        with_related_yandex_client_secret=False,
        with_token=True,
    ):
        """
        Build a dict with token, profile and task metadata.

        :param with_related_yandex_client_secret: also put the application's
            ``related_yandex_client_secret`` into
            ``token['application_attributes']``.
        :param with_token: copy the ``access_token`` contents into ``token``;
            otherwise ``token`` carries only application info.
        :returns: dict cleaned with remove_none_values/remove_undefined_values.
        """
        profile = self.get_social_userinfo()

        token = dict()
        if with_token:
            token.update(self.access_token or dict())
        token['application'] = self.application.name if self.application else None

        if self.application:
            app_attrs = dict(
                id=self.application.name,
                third_party=self.application.is_third_party or False,
                related_yandex_client_id=self.application.related_yandex_client_id,
            )
            if with_related_yandex_client_secret:
                app_attrs['related_yandex_client_secret'] = self.application.related_yandex_client_secret
            app_attrs = remove_undefined_values(remove_none_values(app_attrs))
        else:
            app_attrs = dict()
        token['application_attributes'] = app_attrs

        data = {
            'token': token,
            'profile': profile,
            'sid': self.sid,
            'task_id': self.task_id,
            'created': self.created,
            'finished': self.finished,
            'uid': self.uid,
            'yandexuid': self.yandexuid,
            'consumer': self.consumer,
            'code_challenge': self.code_challenge,
            'code_challenge_method': self.code_challenge_method,
        }
        return remove_undefined_values(remove_none_values(data))

    def update(self, task):
        """
        Copy into this task the Redis-persisted attributes of ``task``.

        Goes through the Redis serialization round-trip, so only attributes
        that ``task`` has set to a truthy value are copied.

        :returns: self
        """
        task_data = task._get_redis_data()
        self._update_from_redis_data(task_data)
        return self

    def _update_from_redis_data_v2(self, data):
        """
        Apply a version-2 Redis hash to this task.

        Only keys present with a truthy value are applied; other attributes
        keep their current values.

        :raises ApplicationUnknown: if ``application_id`` is unknown.
        """
        def _update_attr(name, converter=None, converted_name=None):
            value = data.get(converted_name or name)
            if value:
                # Every value goes through json.loads to restore its original
                # type: anything read back from Redis is a string.
                value = load_from_json_string(value)
                if converter:
                    value = converter(value)
                setattr(self, name, value)

        def _deconvert_app(value):
            app = providers.get_application_by_id(value)
            if not app:
                raise ApplicationUnknown()
            return app

        _update_attr('access_token')
        _update_attr('application', _deconvert_app, 'application_id')
        _update_attr('callback_url')
        _update_attr('code_challenge')
        _update_attr('code_challenge_method')
        _update_attr('consumer')
        _update_attr('created')
        _update_attr('finished')
        _update_attr('profile')
        _update_attr('provider')
        _update_attr('sid')
        _update_attr('uid')
        _update_attr('yandexuid')

    def _update_from_redis_data(self, data):
        """
        Apply a Redis hash to this task, dispatching on its ``version`` key.

        A missing ``version`` is treated as 1.

        :raises ValueError: for any version other than 2.
        """
        version = int(data.get('version', 1))
        if version == 2:
            self._update_from_redis_data_v2(data)
        else:
            raise ValueError('Unknown version: %s' % version)

    def _get_redis_data_v2(self):
        """
        Serialize the version-2 attribute subset into a Redis hash dict.

        Falsy attributes are skipped; ``version`` holds the int 2 as is.
        """
        def _serialize_attr(name, converter=None, converted_name=None):
            value = getattr(self, name, None)
            if value:
                if converter:
                    value = converter(value)
                value = dump_to_json_string(value)
                data[converted_name or name] = value

        data = dict(version=2)

        _serialize_attr('access_token')
        _serialize_attr('application', attrgetter('identifier'), 'application_id')
        _serialize_attr('callback_url')
        _serialize_attr('code_challenge')
        _serialize_attr('code_challenge_method')
        _serialize_attr('consumer')
        _serialize_attr('created')
        _serialize_attr('finished')
        _serialize_attr('profile')
        _serialize_attr('provider')
        _serialize_attr('sid')
        _serialize_attr('uid')
        _serialize_attr('yandexuid')
        return data

    def _get_redis_data(self, version=TASK_SERIALIZATOR_VERSION):
        """
        Serialize this task for Redis in the given format ``version``.

        :raises ValueError: for any version other than 2.
        """
        if version == 2:
            return self._get_redis_data_v2()
        # A real exception, not ``assert``: asserts are stripped under
        # ``python -O`` and the method would silently return None. Mirrors
        # _update_from_redis_data.
        raise ValueError('Unknown version: %s' % version)

    @staticmethod
    def _join_sorted(items):
        # Sorted comma-separated string for a stable dump; None when empty.
        if items:
            return ','.join(sorted(items))
        return None

    @staticmethod
    def _split_comma_list(data, key):
        # Inverse of _join_sorted: a missing or empty value gives [].
        raw = data.get(key) or ''
        return raw.split(',') if raw else []

    def _dump_application(self):
        if self.application:
            return self.application.identifier

    def _dump_passthrough_errors(self):
        return self._join_sorted(self.passthrough_errors)

    def _dump_experiments(self):
        return self._join_sorted(self.experiments)

    def _parse_application(self, data):
        # Returns None when the key is absent.
        if 'application' in data:
            app = providers.get_application_by_id(data['application'])
            if not app:
                raise ApplicationUnknown()
            return app

    def _parse_passthrough_errors(self, data):
        return self._split_comma_list(data, 'passthrough_errors')

    def _parse_experiments(self, data):
        return self._split_comma_list(data, 'experiments')

    def set_uid(self, value):
        # None is ignored (the previous uid is kept); anything else is
        # coerced to int.
        if value is not None:
            self._uid = int(value)

    def get_uid(self):
        return self._uid

    uid = property(get_uid, set_uid)


class TaskId(uuid.UUID):
    """
    UUID4 whose most significant bits carry the id of the issuing environment.

    Construct either from a hex string (``_hex``) or from 16 raw bytes plus
    ``environment_id`` (these two must be given together). In the latter
    case the top ``environment_id_length`` bits of the first byte are
    replaced by ``environment_id``. The result is passed to ``uuid.UUID``
    with ``version=4``, which rewrites the version/variant bits; those lie
    in bytes 6 and 8, so they never clash with the environment bits.
    """

    # Number of task_id bits reserved for environment_id
    environment_id_length = 3

    def __init__(self, _hex=None, _bytes=None, environment_id=None):
        # _bytes and environment_id must be both given or both omitted.
        if (_bytes is None) != (environment_id is None):
            raise ValueError('bytes and environment_id are related')

        raw = None
        if _bytes is not None:
            if len(_bytes) != 16:
                raise ValueError('bytes is not a 16-char string')
            if not (0 <= environment_id < 2 ** self.environment_id_length):
                raise ValueError('environment_id is too large')

            # bytearray yields mutable integer octets on both Python 2 (str)
            # and Python 3 (bytes), unlike map(ord, ...) + long(...).
            octets = bytearray(_bytes)

            # Replace the high bits of byte 0 with environment_id.
            shift = 8 - self.environment_id_length
            octets[0] &= (1 << shift) - 1
            octets[0] |= environment_id << shift

            raw = bytes(octets)

        super(TaskId, self).__init__(hex=_hex, bytes=raw, version=4)

    @property
    def environment_id(self):
        """The environment id stored in the top bits of the UUID."""
        return self.int >> (128 - self.environment_id_length)


class BaseTaskException(Exception):
    """Base class for task-related errors in this module."""


class InvalidTaskDataError(BaseTaskException):
    """
    The given data cannot be turned into a Task object.
    """
