import functools
import furl
import logging
import time

from django.db import transaction

from cars.core.pusher import BasePusher, Xiva
from cars.core.util import import_class, datetime_helper
import cars.settings
from cars.users.models.user import User

from ..models import PushMessagesSet, PushMessagesHistory


LOGGER = logging.getLogger(__name__)


def handle_errors(f):
    """Decorate ``f`` so that an ``Exception`` it raises is logged instead of propagated.

    The wrapped callable returns ``f``'s result, or ``None`` when ``f`` raised an
    ``Exception``. ``BaseException`` subclasses that are not ``Exception`` (e.g.
    ``KeyboardInterrupt``, ``SystemExit``) are not caught and still propagate.
    """
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except Exception:
            # Deliberate best-effort: the traceback is logged and the failure
            # is swallowed rather than re-raised.
            LOGGER.exception('an unhandled exception occurred')
            return None

    return wrapped


class PushResultReporter(object):
    """Record push-delivery result codes as ``PushMessagesHistory`` rows.

    Each message text is bound to a ``PushMessagesSet`` row. Those rows are cached
    per reporter instance, keyed by message text, to avoid repeated lookups.
    """

    def __init__(self):
        # message text -> PushMessagesSet row
        self._cache = {}

    def update_cache(self, force=False):
        """Load every ``PushMessagesSet`` row into the cache.

        The load runs only when the cache is empty or ``force`` is true. Rows are
        merged into the existing cache and no entry is ever evicted.
        """
        if self._cache and not force:
            return
        all_messages = PushMessagesSet.objects.all()
        # Pass the mapping positionally: ``update(**mapping)`` would re-pack it as
        # keyword arguments (an extra copy) and raise TypeError on non-str keys.
        self._cache.update({x.message: x for x in all_messages})

    def get_message_binding(self, message):
        """Return the ``PushMessagesSet`` row for ``message``, creating it if absent.

        The cache is consulted first, then the database. A new row is created
        only when neither has it.
        """
        if message not in self._cache:
            # Populates the cache only if it is still empty (see update_cache).
            self.update_cache()

        if message not in self._cache:
            with transaction.atomic(savepoint=False):
                entry = PushMessagesSet.objects.filter(message=message).first()

                # TODO: add a unique constraint and an IntegrityError check; until
                # then, concurrent callers may each create a row for the same message.

                if entry is None:
                    entry = PushMessagesSet.objects.create(message=message)

            self._cache[message] = entry

        return self._cache[message]

    def report(self, message, uid, codes, sender=None):
        """Record ``codes`` for a single ``uid``; errors are logged, not raised."""
        self.report_batch(message, {uid: codes}, sender)

    @handle_errors
    def report_batch(self, message, uid_codes_mapping, sender=None):
        """Store one ``PushMessagesHistory`` row per uid, all stamped with the same time.

        ``uid_codes_mapping`` maps a uid to an iterable of codes. Each code is
        stored via its ``.value`` attribute, and the codes are joined with commas.
        Presumably the codes are ``Xiva.PushStatus`` members; confirm with callers.
        Any exception is logged and swallowed by ``handle_errors``.
        """
        now = datetime_helper.utc_now()
        message_binding = self.get_message_binding(message)

        log_entries = [
            PushMessagesHistory(
                time_id=now,
                uid=uid,
                message=message_binding,
                codes=','.join(str(c.value) for c in codes),
                sender=sender,
            )
            for uid, codes in uid_codes_mapping.items()
        ]

        PushMessagesHistory.objects.bulk_create(log_entries)


class SendPushManager:
    """Send push notifications to users and summarise batch delivery results.

    After a batch send, users whose delivery failed with a retryable status are
    written to an MDS object, one id per line. The object is exposed through a
    URL built from ``file_access_host``.
    """

    def __init__(self, *, pusher, mds_bucket_name, file_access_host, mds_client):
        self._pusher = pusher
        self._mds_bucket_name = mds_bucket_name
        self._file_access_host = furl.furl(file_access_host)
        self._mds_client = mds_client

    @classmethod
    def from_settings(cls, **kwargs):
        """Build an instance from ``cars.settings``; ``kwargs`` override any default.

        All defaults are constructed up front, including the ones that ``kwargs``
        then replaces.
        """
        kw = {
            'pusher': BasePusher.from_settings(),
            'mds_bucket_name': cars.settings.SEND_PUSH['mds']['mds_bucket_name'],
            'file_access_host': cars.settings.SEND_PUSH['mds']['file_access_host'],
            'mds_client': import_class(cars.settings.MDS['client_class']).from_settings(),
        }
        kw.update(kwargs)
        return cls(**kw)

    def send_push(self, user_id, message, sender=None):
        """Send ``message`` to the user whose primary key is ``user_id``.

        Raises:
            AssertionError: if no such user exists.
        """
        user = User.objects.filter(id=user_id).first()
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        # AssertionError is kept so callers see the same exception type as before.
        if user is None:
            raise AssertionError('user {} not found'.format(user_id))
        self._pusher.send(user.uid, message=message, sender=sender)

    def send_push_to_batch(self, user_ids, message, ttl=None, sender=None):
        """Send ``message`` to every existing user in ``user_ids`` and summarise the result.

        Returns a dict with these keys:
            stat: delivery counters (see ``_get_stat``).
            need_resend_count: number of users whose delivery should be retried.
            need_resend_url: URL of the uploaded retry list, or ``None`` if empty.
            bad_user_ids: requested ids (as ``str``) with no matching ``User``.
        """
        # Key of the retry-list object in MDS. ``len`` is taken before
        # de-duplication. For str messages, ``hash`` is randomized per process,
        # so the key is not reproducible across runs.
        action_key = '{}-{}-{}'.format(hash(message), len(user_ids), int(time.time()))
        user_ids = set(user_ids)
        users = User.objects.filter(id__in=user_ids).all()
        not_found_user_ids = list(set(map(str, user_ids)) - set(str(user.id) for user in users))
        uids = [u.uid for u in users]

        uids_results = self._send_batch(uids, message, ttl, sender)
        stat = self._get_stat(uids_results)

        need_resend_user_ids = self._need_resend_user_ids(uids_results)
        if need_resend_user_ids:
            need_resend_url = self._make_file_and_upload(need_resend_user_ids, key=action_key)
        else:
            need_resend_url = None

        return {
            'stat': stat,
            'need_resend_count': len(need_resend_user_ids),
            'need_resend_url': need_resend_url,
            'bad_user_ids': not_found_user_ids
        }

    def _get_stat(self, uids_results):
        """Count per-user delivery outcomes and per-status-code occurrences.

        Each value of ``uids_results`` is reduced to a list of status codes by the
        pusher. A user counts toward:
            all_devices_delivered: every code is OK_200.
            any_device_delivered: at least one code is OK_200.
            not_subscribed: every code is 204 or 205.
            subscribed_not_delivered: neither not_subscribed nor any_device_delivered.
        ``status_codes_stat`` maps each code's ``.value`` to its occurrence count.
        """
        stat = {
            'all_devices_delivered': 0,
            'any_device_delivered': 0,
            'not_subscribed': 0,
            'subscribed_not_delivered': 0,
            'status_codes_stat': {},
        }

        delivered_codes = [
            Xiva.PushStatus.OK_200,
        ]
        not_subscribed_codes = [
            Xiva.PushStatus.NOT_SUBSCRIBED_204,
            Xiva.PushStatus.SUBSCRIPTION_REMOVED_205,
        ]

        status_codes_stat = stat['status_codes_stat']
        for result in uids_results.values():
            codes = self._pusher.get_xiva_codes_from_user_result(result)
            # NOTE(review): ``all()`` is vacuously True for an empty ``codes``, so a
            # result with no codes counts as both "all delivered" and "not subscribed".
            if all(c in delivered_codes for c in codes):
                stat['all_devices_delivered'] += 1
            delivered = any(c in delivered_codes for c in codes)
            if delivered:
                stat['any_device_delivered'] += 1
            subscribed = not all(c in not_subscribed_codes for c in codes)
            if not subscribed:
                stat['not_subscribed'] += 1
            elif not delivered:
                stat['subscribed_not_delivered'] += 1
            for c in codes:
                status_codes_stat[c.value] = status_codes_stat.get(c.value, 0) + 1
        return stat

    def _send_batch(self, uids, message, ttl, sender=None):
        """Delegate a batch send to the pusher with no extra payload."""
        return self._pusher.send_batch(
            uids=uids,
            message=message,
            ttl=ttl,
            payload=None,
            sender=sender,
        )

    def _need_resend_user_ids(self, uid_results):
        """Return the ``User.id`` values whose batch result warrants a resend."""
        need_resend_uids = {
            uid
            for uid, result in uid_results.items()
            if self._need_resend_result(result)
        }
        return [u.id for u in User.objects.filter(uid__in=need_resend_uids)]

    def _need_resend_result(self, result):
        """Return True if no device got the push but at least one failure is retryable."""
        # all codes are described here: https://push.yandex-team.ru/doc/guide.html#api-reference-batch-send
        acceptable_codes = [
            Xiva.PushStatus.OK_200,
        ]

        need_resend_codes = [
            Xiva.PushStatus.BAD_REQUEST_400,
            Xiva.PushStatus.FORBIDDEN_403,
            Xiva.PushStatus.RATE_LIMIT_ERROR_429,
            Xiva.PushStatus.TRANSPORT_ERROR_500,
            Xiva.PushStatus.PUSH_SERVICE_ERROR_502,
            Xiva.PushStatus.PUSH_SERVICE_TIMEOUT_504,
        ]
        codes = self._pusher.get_xiva_codes_from_user_result(result)
        return (
            all(c not in acceptable_codes for c in codes)  # no notification reached the target
            and any(c in need_resend_codes for c in codes)  # there is a hope to resend
        )

    def _make_file_and_upload(self, user_ids, key):
        """Upload ``user_ids`` (newline-separated, ASCII) to MDS under ``key``; return its URL."""
        file_content = b'\n'.join(str(user_id).encode('ascii') for user_id in user_ids)
        self._mds_client.put_object(
            key=key,
            bucket=self._mds_bucket_name,
            body=file_content
        )
        # copy() so the shared base URL object is not mutated by join().
        return self._file_access_host.copy().join(key).url
