# -*- coding: utf-8 -*-
from collections import (
    defaultdict,
    namedtuple,
)
from functools import partial
import logging

from passport.backend.core.am_pushes.common import (
    get_am_capabilities_manager,
    get_platform_by_push_service,
    Platforms,
)
from passport.backend.core.builders.blackbox import get_blackbox
from passport.backend.core.builders.push_api import get_push_api
from passport.backend.core.conf import settings
from passport.backend.core.logging_utils.helpers import trim_message
from passport.backend.core.utils.decorators import cached_property
from passport.backend.core.utils.version import parse_am_version


log = logging.getLogger('passport.am_pushes')

# AM capability required by the "AM-compatible" subscriptions filter.
AM_CAPABILITY_PUSH_PROTOCOL = 'push:passport_protocol'

# A subscription rejected by a filter, paired with the rejection reason.
RemovedSubscription = namedtuple('RemovedSubscription', ['subscription', 'message'])


class SubscriptionFilterResult(object):
    """
    Result of running subscriptions through a chain of filters.

    Behaves like the wrapped list of subscriptions (len, iteration, indexing,
    truthiness, equality with a plain list) and additionally carries
    ``removed``: the RemovedSubscription records accumulated by the filters
    that produced this result.
    """
    def __init__(self, subscriptions, removed=None):
        self._subscriptions = subscriptions
        self.removed = [] if removed is None else removed

    def __eq__(self, other):
        # Delegate to the wrapped list so a result compares equal to a plain list.
        return other == self._subscriptions

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; keep both operators consistent.
        return not self == other

    def __getitem__(self, item):
        return self._subscriptions.__getitem__(item)

    def __iter__(self):
        return self._subscriptions.__iter__()

    def __len__(self):
        return len(self._subscriptions)

    def __bool__(self):
        return bool(self._subscriptions)

    __nonzero__ = __bool__

    def replace(self, subscriptions):
        """Swap the wrapped list for ``subscriptions`` (``removed`` is kept)."""
        self._subscriptions = subscriptions

    def make_next_result(self):
        """
        Create an empty successor result for the next filter of a chain.

        :return: ``(result, new_list)`` where ``new_list`` is the list wrapped
            by ``result``; the caller appends the subscriptions it keeps to it.

        ``result.removed`` starts as a *copy* of ``self.removed``: the history
        of rejections still accumulates along the chain, but this result is
        not mutated. That matters because one result may start several chains
        (e.g. a manager's base ``subscriptions`` feeds both the AM and the
        Yakey filter chains), and rejections of one chain must not leak into
        the other or into the source.
        """
        new_list = []
        return SubscriptionFilterResult(new_list, list(self.removed)), new_list

    def __repr__(self):
        return '<{} {}>'.format(self.__class__.__name__, self._subscriptions)


class IFilter(object):
    """
    Interface of a subscription filter.

    ``filter`` receives the previous SubscriptionFilterResult and must return
    the next one; ``push_service`` and ``event`` are passed through by the
    caller unchanged.
    """

    def filter(self, prev_result, push_service=None, event=None):
        raise NotImplementedError


class BasePerSubscriptionFilter(IFilter):
    """
    Base class for filters that judge every subscription independently.

    Subclasses implement ``check_subscription``; ``filter`` builds the next
    SubscriptionFilterResult, keeping accepted subscriptions in their original
    order and recording rejected ones as RemovedSubscription entries.
    """

    def check_subscription(self, subscription, push_service, event):
        """
        Decide whether one subscription passes this filter.

        :return: ``(is_ok, message)``; ``message`` is stored as
            ``RemovedSubscription.message`` when ``is_ok`` is falsy.
        """
        raise NotImplementedError()

    def filter(self, prev_result, push_service=None, event=None):
        next_result, kept = prev_result.make_next_result()
        for candidate in prev_result:
            passed, reason = self.check_subscription(candidate, push_service, event)
            if not passed:
                next_result.removed.append(RemovedSubscription(candidate, reason))
                continue
            kept.append(candidate)
        return next_result


class CapsFilter(BasePerSubscriptionFilter):
    """
    Keeps only subscriptions whose AM version is new enough for ``required_caps``.

    The minimal AM version for every platform is resolved once, in the
    constructor, through the AM capabilities manager. A falsy minimal version
    means the platform imposes no version requirement; empty ``required_caps``
    disables the filter entirely.
    """

    def __init__(self, required_caps):
        self.required_caps = required_caps
        self.min_versions_by_platform = dict(
            (
                platform,
                get_am_capabilities_manager().get_min_am_version_by_caps(platform, required_caps),
            )
            for platform in Platforms
        )

    def check_subscription(self, subscription, push_service, event):
        if self.required_caps:
            push_platform = subscription.platform
            platform = get_platform_by_push_service(push_platform)
            if platform is None:
                message = 'Unknown push service {}'.format(push_platform)
                log.warning(message)
                return False, message

            min_version = self.min_versions_by_platform[platform]
            if min_version:
                raw_version = subscription.extra.am_version
                if not raw_version:
                    return False, 'unknown AM version'
                version = tuple(parse_am_version(raw_version))
                if version < min_version:
                    return False, 'AM version {} < required {}'.format(version, min_version)

        return True, None

    def __str__(self):
        return '<CapsFilter {}>'.format(self.required_caps)


class PlatformFilter(BasePerSubscriptionFilter):
    """
    Keeps only subscriptions whose platform is one of ``required_platforms``.

    A falsy ``required_platforms`` disables the filter.
    """

    def __init__(self, required_platforms):
        self.required_platforms = required_platforms

    def check_subscription(self, subscription, push_service, event):
        if not self.required_platforms:
            return True, None

        resolved = get_platform_by_push_service(subscription.platform)
        if resolved is None:
            message = 'Unknown push service {}'.format(subscription.platform)
            log.warning(message)
            return False, message

        if resolved in self.required_platforms:
            return True, None
        return False, 'wrong platform'

    def __str__(self):
        return '<PlatformFilter {}>'.format(self.required_platforms)


class DeviceIdFilter(BasePerSubscriptionFilter):
    """
    Rejects subscriptions bound to a device that is not in ``device_ids``.

    Subscriptions with an empty device are always kept; a falsy
    ``device_ids`` rejects every subscription that does have a device.
    """

    def __init__(self, device_ids):
        self.device_ids = device_ids

    def check_subscription(self, subscription, push_service, event):
        device = subscription.device
        # NOTE(review): a subscription without a device passes this filter, so it
        # also lands in PushesSubscriptionManager.trusted_subscriptions — confirm
        # that this is intended.
        if not device:
            return True, None
        allowed = self.device_ids or ()
        if device in allowed:
            return True, None
        return False, 'wrong device'

    def __str__(self):
        return '<DeviceIdFilter {}>'.format(self.device_ids)


class ComplexRulesetFilter(BasePerSubscriptionFilter):
    """
    Applies an ordered ruleset to every subscription.

    A rule may constrain ``platform``, ``app``, ``push_service``, ``event``,
    ``version_from`` and ``version_to``; a ``None`` attribute means "any".
    The first rule whose constraints all hold decides via ``rule.allow``
    (the rule's ``str`` becomes the message); when no rule matches the
    subscription is kept.

    NOTE(review): the version checks parse ``subscription.extra.am_version``
    without checking it is set (CapsFilter does check) — behaviour for an
    empty version depends on parse_am_version; confirm.
    """

    def __init__(self, ruleset):
        self.ruleset = ruleset

    def _check_platform(self, subscription, expected_val):
        # expected_val is used as a key into Platforms (presumably a member name).
        actual_platform = get_platform_by_push_service(subscription.platform)
        return actual_platform == Platforms[expected_val]

    def _check_app(self, subscription, expected_val):
        return subscription.app == expected_val

    def _check_push_service(self, push_service, expected_val):
        return push_service == expected_val

    def _check_event(self, event, expected_val):
        return event == expected_val

    def _check_version_from(self, subscription, expected_val):
        """Inclusive lower bound on the subscription's AM version."""
        actual = tuple(parse_am_version(subscription.extra.am_version))
        return actual >= tuple(parse_am_version(expected_val))

    def _check_version_to(self, subscription, expected_val):
        """Inclusive upper bound on the subscription's AM version."""
        actual = tuple(parse_am_version(subscription.extra.am_version))
        return actual <= tuple(parse_am_version(expected_val))

    def _is_rule_matched(self, subscription, rule, push_service, event):
        """True when every constraint set on ``rule`` holds."""
        dispatch = (
            ('platform', self._check_platform, {'subscription': subscription}),
            ('app', self._check_app, {'subscription': subscription}),
            ('push_service', self._check_push_service, {'push_service': push_service}),
            ('event', self._check_event, {'event': event}),
            ('version_from', self._check_version_from, {'subscription': subscription}),
            ('version_to', self._check_version_to, {'subscription': subscription}),
        )
        for attr_name, check, subject in dispatch:
            if getattr(rule, attr_name) is None:
                # Unset constraint matches anything.
                continue
            if not check(expected_val=getattr(rule, attr_name), **subject):
                return False
        return True

    def check_subscription(self, subscription, push_service, event):
        for rule in self.ruleset:
            if self._is_rule_matched(subscription, rule, push_service, event):
                break
        else:
            return True, None
        return rule.allow, str(rule)

    def __str__(self):
        return '<ComplexRulesetFilter {} rules>'.format(len(self.ruleset))


class PushesSubscriptionManager(object):
    """
    Access point to the push subscriptions of one account (``uid``).

    Subscriptions are fetched from the push API on first access; subscriptions
    of the ``passport_autotests`` client are kept apart in
    ``test_subscriptions``. Filtered views (AM-compatible, Yakey-compatible,
    trusted) are computed on demand and memoized with ``cached_property``.

    :param uid: account uid whose subscriptions are managed.
    :param push_service: forwarded to every filter (used by ruleset filters).
    :param event: forwarded to every filter (used by ruleset filters).
    """

    def __init__(self, uid, push_service=None, event=None):
        self.uid = uid
        self.push_service = push_service
        self.event = event
        # Filled lazily by _load_subscriptions().
        self._subscriptions = None
        self._test_subscriptions = None
        # NOTE(review): not used within this class.
        self._app_blacklist = defaultdict(set)

    @cached_property
    def blackbox(self):
        return get_blackbox()

    @cached_property
    def push_api(self):
        return get_push_api()

    def _debug_list(self, message, data, postfix=None):
        """Debug-log ``data`` (or its absence) under ``message``, trimmed by trim_message."""
        if data:
            message = 'Got {}: {}'.format(message, data)
        else:
            message = 'No {}'.format(message)
        if postfix:
            message = '{} {}'.format(message, postfix)
        log.debug(trim_message(message))

    def _debug_filter_result(self, comment, collection):
        """Debug-log the ids kept in ``collection`` and why the others were removed."""
        removed_message = None
        if collection.removed:
            removed_message = 'Removed: {}'.format(
                ', '.join(
                    '{}: {}'.format(removed.subscription.id, removed.message)
                    for removed in collection.removed
                ),
            )
        self._debug_list(
            'subscriptions {}'.format(comment),
            [s.id for s in collection],
            removed_message,
        )

    @staticmethod
    def _format_token(token):
        """Compact representation of a blackbox token for logging."""
        oauth = token.get('oauth', {})
        return dict(
            login_id=oauth.get('login_id'),
            token_id=oauth.get('token_id'),
        )

    @cached_property
    def trusted_xtokens(self):
        """xtokens of the account whose ``oauth.is_xtoken_trusted`` flag is set."""
        tokens = self.blackbox.get_oauth_tokens(
            self.uid,
            xtoken_only=True,
            get_is_xtoken_trusted=True,
        )
        res = [t for t in tokens if t.get('oauth', {}).get('is_xtoken_trusted')]
        self._debug_list('xtokens', [self._format_token(t) for t in res])
        return res

    @cached_property
    def trusted_login_ids(self):
        """Set of ``login_id`` values of the trusted xtokens (missing ones skipped)."""
        # login_id lives in the token's 'oauth' section, exactly like in
        # _format_token and trusted_device_ids; the top-level lookup used
        # before never found it, so the set was always empty.
        res = {
            t.get('oauth', {}).get('login_id')
            for t in self.trusted_xtokens
            if t.get('oauth', {}).get('login_id') is not None
        }
        self._debug_list('trusted login ids', res)
        return res

    @cached_property
    def trusted_device_ids(self):
        """Set of non-empty ``device_id`` values of the trusted xtokens."""
        res = {
            t.get('oauth', {}).get('device_id')
            for t in self.trusted_xtokens
            if t.get('oauth', {}).get('device_id')
        }
        self._debug_list('trusted device_ids', res)
        return res

    @staticmethod
    def _is_test_subscription(subscription):
        return subscription.client == 'passport_autotests'

    def filter_subscriptions(self, subscriptions, filters):
        """
        Run ``subscriptions`` through ``filters`` in order and log the outcome.

        :param subscriptions: SubscriptionFilterResult to start from.
        :param filters: iterable of IFilter implementations.
        :return: the SubscriptionFilterResult produced by the last filter
            (``subscriptions`` itself when ``filters`` is empty).
        """
        for filter_obj in filters:
            subscriptions = filter_obj.filter(
                subscriptions,
                push_service=self.push_service,
                event=self.event,
            )
        self._debug_filter_result(
            'filtered: {}'.format(', '.join(str(f) for f in filters)),
            subscriptions,
        )

        return subscriptions

    def _load_subscriptions(self):
        """Fetch subscriptions from the push API and split off the autotest ones."""
        raw_subs = self.push_api.list(self.uid)
        subscriptions = []
        test_subscriptions = []
        for subscription in raw_subs:
            if self._is_test_subscription(subscription):
                test_subscriptions.append(subscription)
            else:
                subscriptions.append(subscription)
        self._debug_list('subscriptions', subscriptions)
        if test_subscriptions:
            self._debug_list('test subscriptions', test_subscriptions)

        self._subscriptions = SubscriptionFilterResult(subscriptions)
        self._test_subscriptions = test_subscriptions

    @property
    def subscriptions(self):
        """All non-test subscriptions, as an unfiltered SubscriptionFilterResult."""
        if self._subscriptions is None:
            self._load_subscriptions()
        return self._subscriptions

    @property
    def test_subscriptions(self):
        """Plain list of the ``passport_autotests`` subscriptions."""
        if self._subscriptions is None:
            self._load_subscriptions()
        return self._test_subscriptions

    @cached_property
    def yakey_compatible_subscriptions(self):
        return self.filter_subscriptions(
            self.subscriptions,
            [
                ComplexRulesetFilter(settings.YAKEY_SUBSCRIPTION_APP_RULES),
            ],
        )

    @cached_property
    def am_compatible_subscriptions(self):
        return self.filter_subscriptions(
            self.subscriptions,
            [
                CapsFilter([AM_CAPABILITY_PUSH_PROTOCOL]),
                ComplexRulesetFilter(settings.AM_SUBSCRIPTION_APP_RULES),
            ],
        )

    @cached_property
    def trusted_subscriptions(self):
        """AM-compatible subscriptions bound to devices of trusted xtokens."""
        if not self.trusted_device_ids:
            return SubscriptionFilterResult([])

        return self.filter_subscriptions(
            self.am_compatible_subscriptions,
            [
                DeviceIdFilter(self.trusted_device_ids),
            ],
        )

    def has_subscriptions(self):
        return bool(self.subscriptions)

    def has_trusted_subscriptions(self):
        return bool(self.trusted_subscriptions)

    def has_am_compatible_subscriptions(self):
        return bool(self.am_compatible_subscriptions)

    def has_yakey_compatible_subscriptions(self):
        return bool(self.yakey_compatible_subscriptions)


def get_pushes_subscription_manager(uid, push_service=None, event=None):
    """
    Build a PushesSubscriptionManager for ``uid``.

    ``push_service`` and ``event`` are handed to the manager unchanged.
    """
    manager = PushesSubscriptionManager(
        uid=uid,
        push_service=push_service,
        event=event,
    )
    return manager


class SubscriptionsCollection(object):
    """
    Cursor over an ordered sequence of subscriptions.

    ``idx`` points at the current subscription. ``num_tries`` is reset to 0
    whenever the cursor moves (it is not incremented by this class).
    """

    def __init__(self, subscriptions):
        self.subscriptions = subscriptions
        self.idx = 0
        self.num_tries = 0

    @property
    def current(self):
        """Subscription under the cursor (for a list, IndexError once finished)."""
        return self.subscriptions[self.idx]

    @property
    def is_finished(self):
        """True once the cursor has moved past the last subscription."""
        return not self.idx < len(self.subscriptions)

    def set_finished(self):
        """Jump past the end of the sequence."""
        self.idx, self.num_tries = len(self.subscriptions), 0

    def set_next(self):
        """Advance to the following subscription."""
        self.idx, self.num_tries = self.idx + 1, 0


class DeviceSubscriptions(object):
    """
    Mapping ``device id -> SubscriptionsCollection``.

    Item access, assignment and deletion go straight to the underlying dict;
    ``values()``, ``items()`` and truthiness only consider collections that
    are not finished yet.
    """

    def __init__(self):
        # A missing device yields an empty (therefore already finished)
        # collection. SubscriptionsCollection requires a list argument, so it
        # cannot be used as the default factory directly.
        self.by_device = defaultdict(lambda: SubscriptionsCollection([]))

    def __getitem__(self, item):
        return self.by_device[item]

    def __setitem__(self, key, value):
        self.by_device[key] = value

    def __delitem__(self, key):
        del self.by_device[key]

    def __bool__(self):
        """True while at least one device still has unfinished subscriptions."""
        return not all(v.is_finished for v in self.by_device.values())

    # Python 2 spelling of __bool__ (as in SubscriptionFilterResult); without it
    # the object would always be truthy on Python 2.
    __nonzero__ = __bool__

    def values(self):
        """Generator over the unfinished collections."""
        return (v for v in self.by_device.values() if not v.is_finished)

    def items(self):
        """Generator over ``(device_id, collection)`` pairs with unfinished collections."""
        return ((k, v) for k, v in self.by_device.items() if not v.is_finished)


class SubscriptionRatingPerDeviceProcessor(object):
    """
    Groups subscriptions by device and orders each group from best to worst.

    :param app_priority: apps ordered from the most to the least preferred.
    """

    def __init__(self, app_priority):
        self.app_priority = app_priority

    def _subscription_rating(self, subscription):
        """
        Sort key of a subscription (larger is better).

        Criteria, in decreasing order of importance:
        - the app rating from ``app_priority`` (the first app listed gets the
          highest rating, unlisted apps get -1);
        - ``init_time`` (the most recent / largest is the best; -1 when it is
          missing or cannot be converted to an integer).

        :return: ``(app_rating, init_time_rating)``.
        """
        rating_by_app = {
            name: rating for rating, name in enumerate(reversed(self.app_priority))
        }
        app_rating = rating_by_app.get(subscription.app, -1)

        try:
            # The except clause only makes sense around a conversion; without
            # int() a non-numeric init_time would produce keys of mixed types,
            # which Python 3 cannot compare inside sorted().
            init_time_rating = int(subscription.init_time or -1)
        except (ValueError, TypeError):
            init_time_rating = -1

        return app_rating, init_time_rating

    def generate(self, subscriptions):
        """
        Build a DeviceSubscriptions in which each device's subscriptions are
        sorted by ``_subscription_rating``, best first.
        """
        subscriptions_by_device = defaultdict(list)
        for subscription in subscriptions:
            subscriptions_by_device[subscription.device].append(subscription)

        device_subscriptions = DeviceSubscriptions()
        for device_id, device_subs in subscriptions_by_device.items():
            device_subscriptions[device_id] = SubscriptionsCollection(sorted(
                device_subs, key=self._subscription_rating, reverse=True,
            ))

        return device_subscriptions

    def __str__(self):
        return '<{} {}>'.format(self.__class__.__name__, self.app_priority)
