# coding: utf-8



import datetime
import logging
import time
from collections import defaultdict
from functools import reduce
from itertools import groupby
from json import JSONEncoder
from operator import attrgetter

import lxml.etree as ET
from pymongo import DESCENDING, errors

from at.aux_ import Accesses
from at.aux_ import entries
from at.common.utils import get_connection, et2xml, log_exception, Status
from at.aux_.MongoStorage import Storage


# Module-level logger; currently not referenced by any code in this module.
_log = logging.getLogger(__name__)


def node(tag, text = None, **attrs):
    """Create a detached XML element.

    :param tag: tag name of the element.
    :param text: optional content; stored as ``str(text)`` unless it is None.
    :param attrs: attribute name/value pairs forwarded to ``ET.Element``.
    :returns: the newly created element.
    """
    element = ET.Element(tag, **attrs)
    if text is None:
        return element
    element.text = str(text)
    return element


def retriable(exceptions, limit):
    """Decorator factory: retry the decorated call on selected exceptions.

    :param exceptions: an exception class, or an iterable of exception
        classes, that triggers a retry. Any other exception propagates
        immediately.
    :param limit: maximum number of attempts in total. The call is always
        attempted at least once; after the ``limit``-th failed attempt the
        last exception is re-raised.
    :returns: a decorator that preserves the wrapped function's metadata.

    Attempts are repeated immediately, without any delay between them.
    """
    from functools import wraps

    # Accept a bare exception class as well as a list/tuple of them, and build
    # the tuple once rather than on every failed attempt.
    if isinstance(exceptions, type):
        retry_on = (exceptions,)
    else:
        retry_on = tuple(exceptions)

    def wrapper(function):
        @wraps(function)
        def wrapped(*args, **kwargs):
            attempt = 1
            while True:
                try:
                    return function(*args, **kwargs)
                except retry_on:
                    if attempt >= limit:
                        raise
                    attempt += 1
        return wrapped
    return wrapper


class EventRegistry(type):
    """Metaclass that indexes event classes by their ``collection_name``.

    Every class built with this metaclass, except the one named 'EventBase',
    is recorded in ``types`` under the ``collection_name`` defined in its own
    class body. Classes that merely inherit a ``collection_name`` are not
    recorded.
    """

    # collection_name -> event class, shared by all classes using this metaclass.
    types = {}

    def __new__(mcs, name, bases, namespace):
        new_class = type.__new__(mcs, name, bases, namespace)

        registrable = name != 'EventBase' and 'collection_name' in namespace
        if registrable:
            mcs.types[namespace['collection_name']] = new_class

        # NOTE: no storage index is enforced here. A unique ensure_index() over
        # the class's ``__pk__`` fields (with dropDups) was left commented out.

        return new_class


class EventBase(object, metaclass=EventRegistry):
    """Base class for all user event types.

    Concrete subclasses define ``collection_name`` (the registry key and the
    suffix of the ``events.<collection_name>`` storage collection) and put the
    event payload into ``self.object``.
    """
    # Key used by EventList to order events of several types in memory. It
    # falls back to the event's own timestamp when the payload has no
    # 'timestamp' key (FrienderNotification payloads, for example); reading
    # e['timestamp'] directly would raise KeyError while sorting a mixed list.
    population_sort_key = staticmethod(
        lambda e: e.object.get('timestamp', e.timestamp))
    population_sort_reverse = True
    # Sort specification passed to the storage query for this event type.
    order = [
        ('object.timestamp', DESCENDING)
    ]

    @classmethod
    def storage(cls):
        """Return the Storage wrapper for this event type's collection."""
        return Storage('events.%s' % cls.collection_name)

    def __init__(self, id = None, timestamp = None, state = None, object = None):
        """Initialise the common event fields.

        :param id: storage id of the event; None for an event not stored yet.
        :param timestamp: creation time; defaults to ``datetime.datetime.now()``.
        :param state: iterable of state names (e.g. 'seen'); kept as a set.
        :param object: payload dict; defaults to an empty dict.
        """
        self.id = id
        self.timestamp = timestamp if timestamp is not None else datetime.datetime.now()
        self.object = object if object is not None else {}

        # Keep a copy of the initial state so `dirty` can tell whether it changed.
        self.state = set(state) if state is not None else set()
        self.__initial_state = self.state.copy()

    @property
    def dirty(self):
        """True if ``state`` differs from the state the event was created with."""
        return self.state != self.__initial_state

    def serialize(self):
        """Render the event as an lxml element.

        The root tag is the class's ``collection_name``, carrying ``id`` and
        ``timestamp`` (seconds since the epoch) attributes. Payload values are
        rendered recursively, dispatching on their exact type:

        * dict -> element named after the key, one child per item;
        * list -> ``<names count="N">`` wrapper with one ``<name>`` child per item;
        * datetime -> element whose text is seconds since the epoch;
        * anything else -> element whose text is ``str(value)``.

        A ``<state>`` child holds one empty element per state name.
        """
        def to_epoch(moment):
            # Portable equivalent of strftime('%s'), a platform-specific
            # (glibc) directive: naive datetimes are read as local time.
            return str(int(time.mktime(moment.timetuple())))

        def render_dict(name, mapping, **attrs):
            element = ET.Element(name, **attrs)
            for key, value in mapping.items():
                element.append(render(key, value))
            return element

        def render(name, value, **attrs):
            kind = type(value)
            if kind is dict:
                return render_dict(name, value, **attrs)
            if kind is list:
                wrapper = ET.Element('%ss' % name, count = str(len(value)), **attrs)
                for item in value:
                    wrapper.append(render(name, item))
                return wrapper
            if kind is datetime.datetime:
                return node(name, to_epoch(value), **attrs)
            return node(name, value, **attrs)

        # The payload is always rendered as a mapping at the root.
        root = render_dict(
            type(self).collection_name,
            self.object,
            id=str(self.id),
            timestamp=to_epoch(self.timestamp))
        state_node = ET.SubElement(root, 'state')
        for state in self.state:
            ET.SubElement(state_node, state)
        return root

    def __getitem__(self, key):
        """Return payload field ``key`` (``self.object[key]``)."""
        return self.object[key]

    def __repr__(self):
        return '%s(%s)' % (type(self).__name__, ', '.join(
            str('%s = %s' % (name, getattr(self, name) or '<none>'))
            for name in ('id', 'timestamp', 'state')))


class EventJSONEncoder(JSONEncoder):
    """JSON encoder for events and the non-JSON types they contain.

    * datetime -> integer seconds since the epoch (naive values read as local time);
    * set -> list;
    * EventBase -> dict with id, type (class name), timestamp, state and object.
    """

    def default(self, obj):
        if isinstance(obj, datetime.datetime):
            # Portable equivalent of int(obj.strftime('%s')); '%s' is a
            # platform-specific (glibc) directive.
            return int(time.mktime(obj.timetuple()))
        if isinstance(obj, set):
            return list(obj)
        if isinstance(obj, EventBase):
            return dict(
                id = str(obj.id),
                type = type(obj).__name__,
                timestamp = obj.timestamp,
                state = obj.state,
                object = obj.object)
        # Unsupported type: let JSONEncoder raise its standard TypeError.
        # Zero-argument super() stays correct in subclasses, whereas
        # super(type(self), self) would recurse forever there.
        return super().default(obj)


class EventList(object):
    """Lazily loaded, cached list of one user's events.

    The list spans one or more registered event types. Nothing is read from
    storage until the list is first iterated or indexed; afterwards the loaded
    events are cached in memory. Events in the 'deleted' state are never
    loaded.
    """

    def __init__(self, uid, types = None, exclude = None):
        """Prepare (but do not run) the query for ``uid``'s events.

        :param uid: owner of the events; matched against the stored 'uid'.
        :param types: optional iterable of collection names, each a key of
            ``EventRegistry.types`` (KeyError otherwise). When empty or
            omitted, every registered event type is included.
        :param exclude: optional iterable of extra states ('seen', 'sent' or
            'deleted'); events carrying any of them are left out.
        """
        self.uid = uid

        if types:
            self.__types = [EventRegistry.types[name] for name in types]
        else:
            self.__types = list(EventRegistry.types.values())

        # Base query: this user's events that are not deleted.
        self.__query = {
            'uid': uid,
            'state': {'$nin': ['deleted']}
        }

        if exclude:
            # NOTE(review): `assert` is stripped under `python -O`, in which
            # case unknown states are silently accepted here.
            assert all(state in ('seen', 'sent', 'deleted') for state in exclude), \
                "invalid event state specified for exclusion"
            self.__query['state']['$nin'].extend(exclude)

        # None until first population, then the cached list of events.
        self.__events = None
        # Passed to cursor.limit(); 0 means "no limit". Never changed here.
        self.__limit = 0

    @retriable([errors.AutoReconnect], limit = 3)
    def push(self, event):
        """Store a new event for this user and add it to the cache, if loaded.

        :param event: instance of one of this list's event types; its ``id``
            is set to the value returned by the storage insert.
        :returns: True if stored, False if the storage reported a duplicate key.
        :raises TypeError: if the event's type is outside this list's scope.
        """
        storage = event.storage()

        if type(event) not in self.__types:
            raise TypeError("event type is out of the current event list scope")

        try:
            event.id = storage.insert({
                'uid': self.uid,
                'timestamp': event.timestamp,
                'state': list(event.state),
                'object': event.object
            })
        except errors.DuplicateKeyError:
            return False

        # If the list is already loaded, keep the cache in sync by putting the
        # new event first; otherwise population will fetch it from storage.
        if self.__events is not None:
            self.__events.insert(0, event)

        return True

    @retriable([errors.AutoReconnect], limit = 3)
    def save(self):
        """Write back the state of every loaded event whose state changed.

        Does nothing if the list has never been loaded. Dirty events are
        bucketed by (event class, new state) so that each bucket is written
        with a single multi-document update.

        NOTE(review): events stay `dirty` after saving because their initial
        state is not reset, so a second save() rewrites them again.
        """
        if self.__events is None:
            return

        # A dict keyed by the exact class and a hashable snapshot of the state
        # avoids sorting sets: set `<` is a subset test, not a total order, so
        # sorting cannot be relied on to make equal states adjacent.
        buckets = defaultdict(list)
        for event in self:
            if event.dirty:
                buckets[(type(event), frozenset(event.state))].append(event.id)

        for (cls, state), ids in buckets.items():
            cls.storage().update(
                {'_id': {'$in': ids}},
                {'$set': {'state': list(state)}},
                multi=True)

    def serialize(self, sieve=None, offset=None, limit=None, escape_html=False):
        """Render the selected events as an ``<EventList>`` element.

        :param sieve: optional predicate selecting events (all when None).
        :param offset: number of selected events to skip.
        :param limit: maximum number of events rendered.
        :param escape_html: accepted for API compatibility; not used here.
        :returns: root element whose ``count`` attribute is the number of
            selected events before offset/limit are applied.
        """
        targets = list(filter(sieve, self))

        root = ET.Element('EventList',
            uid = str(self.uid),
            count = str(len(targets)))

        # Turn (offset, limit) into slice bounds.
        if offset and limit:
            limit = offset + limit

        targets = targets[offset:limit]

        for event in targets:
            root.append(event.serialize())

        # Let each event type post-process all of its rendered nodes at once.
        for cls in set(map(type, targets)):
            if hasattr(cls, 'finalize_bulk'):
                cls.finalize_bulk(self.uid, root.xpath('/EventList/' + cls.collection_name))

        return root

    @retriable([errors.AutoReconnect], limit = 3)
    def truncate(self, sieve = None):
        """Delete the selected events from storage and from this list.

        :param sieve: optional predicate selecting events (all when None).
        """
        targets = list(filter(sieve, self))

        # Remove them from the storage, one bulk delete per event class.
        # (Grouping via a dict: event classes cannot be sorted in Python 3.)
        ids_by_class = defaultdict(list)
        for event in targets:
            ids_by_class[type(event)].append(event.id)
        for cls, ids in ids_by_class.items():
            cls.storage().remove({'_id': {'$in': ids}})

        # And from the event list, in place, matching by identity.
        doomed = {id(event) for event in targets}
        self.__events[:] = [event for event in self.__events if id(event) not in doomed]

    def filter(self, **kwargs):
        """Restrict the query by payload fields (``object.<key> == value``).

        Only affects a list that has not been loaded yet.

        :returns: self, to allow chaining.
        """
        self.__query.update(('object.%s' % key, value)
            for key, value in kwargs.items())

        return self

    @retriable([errors.AutoReconnect], limit = 3)
    def __populate(self):
        """Load this list's events from storage into the in-memory cache."""
        if len(self.__types) == 1:
            sort_controller = self.__types[0]
        elif len(set(tuple(cls.order) for cls in self.__types)) > 1:
            # No common sort order: fall back to EventBase's ordering.
            sort_controller = EventBase
        else:
            # All event types share the same sort order.
            sort_controller = self.__types[0]

        # Collect into a local list and publish it only after loading
        # succeeded, so a failure does not leave a partial cache behind that
        # later calls would treat as fully loaded.
        events = []
        for cls in self.__types:
            cursor = cls.storage().find(self.__query).sort(sort_controller.order).limit(self.__limit)
            for doc in list(cursor):
                # Create an uninitialized instance, bypassing the subclass
                # constructor (which builds new events)...
                event = cls.__new__(cls)

                # ...and initialize it with the base class constructor.
                EventBase.__init__(event,
                    doc['_id'],
                    doc['timestamp'],
                    doc['state'],
                    doc['object'])

                if hasattr(event, "finalize"):
                    event.finalize()

                events.append(event)

        # With a single type the storage order is kept as is; several types
        # have to be merged in memory and re-limited.
        if len(self.__types) > 1:
            events.sort(
                key=sort_controller.population_sort_key,
                reverse=sort_controller.population_sort_reverse)
            events = events[ : self.__limit or None]

        self.__events = events

    def __getitem__(self, target):
        if self.__events is None:
            self.__populate()

        return self.__events[target]

    def __iter__(self):
        if self.__events is None:
            self.__populate()

        return iter(self.__events)

    @retriable([errors.AutoReconnect], limit = 3)
    def __len__(self):
        """Number of events: a storage count if not loaded, else the cache size."""
        if self.__events is None:
            return sum(cls.storage().find(self.__query).count() for cls in self.__types)
        else:
            return len(self.__events)

    def __repr__(self):
        return "EventList(uid = %s, count = %s)" % (self.uid, len(self))



def mark_seen(uid, types, **kwargs):
    """Add the 'seen' state to every not-yet-seen event and persist it.

    :param uid: owner of the events.
    :param types: collection names of the event types to consider.
    :param kwargs: payload-field filters passed to ``EventList.filter``.
    """
    unseen = EventList(uid, types, ['seen']).filter(**kwargs)
    for event in unseen:
        event.state.add('seen')
    unseen.save()


# --------------------
# UEL Public Interface
# --------------------

class UserEvents(object):
    """Public XML entry points for listing, counting and marking user events."""

    @staticmethod
    def can_access_events(ai, recipient_uid):
        """Return whether ``ai.uid`` may access ``recipient_uid``'s events.

        Granted exactly when ``Accesses.Access(ai.uid, recipient_uid)``
        reports ``is_moderator()``.
        """
        return Accesses.Access(ai.uid, recipient_uid).is_moderator()

    @classmethod
    @log_exception
    @et2xml
    def GetUserEventsXML2(cls, ai, recipient_uid, types, exclude, offset, limit, escape_html=False):
        """Return ``recipient_uid``'s events as an ``<EventList>`` tree.

        :param types: comma-separated collection names; empty means all types.
        :param exclude: comma-separated states to leave out ('seen', 'sent', 'deleted').
        :param offset: number of events to skip; falsy means none.
        :param limit: maximum number of events returned; falsy means no limit.
        :param escape_html: forwarded to ``EventList.serialize``.
        :returns: an ``<AccessDenied/>`` tree when access is not allowed.
        """
        if not cls.can_access_events(ai, recipient_uid):
            return ET.ElementTree(ET.Element("AccessDenied"))

        types = types.split(',') if types else []
        exclude = exclude.split(',') if exclude else []
        events = EventList(recipient_uid, types=types, exclude=exclude)
        return ET.ElementTree(events.serialize(
            escape_html=escape_html, offset=offset or None, limit=limit or None))

    @classmethod
    @log_exception
    @et2xml
    def GetUserEventCountXML2(cls, recipient_uid, types, exclude, tag):
        """Return ``<tag uid="recipient_uid">N</tag>`` with the number of events.

        NOTE(review): unlike the other entry points this performs no
        can_access_events() check; confirm callers enforce access.
        """
        types = types.split(',') if types else []
        exclude = exclude.split(',') if exclude else []
        count = len(EventList(recipient_uid, types=types, exclude=exclude))
        return ET.ElementTree(node(tag, str(count), uid=str(recipient_uid)))

    @classmethod
    @log_exception
    @et2xml
    def MarkUserEvents2(cls, ai, request):
        """Add states to selected events of ``request['recipient_uid']``.

        Reads from ``request``: 'recipient_uid', 'types' (comma-separated
        collection names), 'ids' (comma-separated event ids) and optionally
        'mark_str' (comma-separated states to add, default 'seen').
        """
        recipient_uid = request['recipient_uid']
        types = request['types']
        ids = request['ids']
        # get() rather than pop(): do not modify the caller's request mapping.
        mark_str = request.get('mark_str', 'seen')

        if not cls.can_access_events(ai, recipient_uid):
            return ET.ElementTree(ET.Element("AccessDenied"))

        types = types.split(',') if types else []
        wanted_ids = set(ids.split(',')) if ids else set()
        marks = mark_str.split(',')

        # Add marks to the specified events.
        # NOTE(review): events already marked 'seen' are excluded, so no
        # further marks can be added to them; confirm this is intended.
        events = EventList(recipient_uid, types = types, exclude=['seen'])
        for event in events:
            if str(event.id) in wanted_ids:
                event.state.update(marks)
        events.save()
        return Status('Success')


# -----------
# Event Types
# -----------

class ItemHasGone(Exception):
    """Raised when the data an event is built from no longer exists.

    Like() raises it with (feed_id, item_no, comment_id) as its arguments.
    """


class Like(EventBase):
    """Like event, built from a row of the relational ``Likes`` table.

    The payload is fetched at construction time so that the notification can
    be rendered without further database access.
    """
    collection_name = "Like"
    # Fields that together identify one Like event.
    __pk__ = [
        'uid',
        'object.uid',
        'object.feed_id',
        'object.item_no',
        'object.comment_id'
    ]

    def __init__(self, uid, feed_id, item_no, comment_id = 0):
        """Load the like matching all four arguments.

        :param uid: user id matched against ``Likes.uid``.
        :param feed_id: feed of the liked item.
        :param item_no: item number within the feed.
        :param comment_id: liked comment; 0 presumably means the item itself.
        :raises ItemHasGone: if the query does not report exactly one row.
        """
        # Common event fields (id, timestamp, state, empty payload).
        # Zero-argument super() stays correct in subclasses, whereas
        # super(type(self), self) would recurse forever there.
        super().__init__()

        # Fetch the event data
        query = """
            select title, url, timestamp, value
            from Likes
            where
                feed_id = %s and
                item_no = %s and
                comment_id = %s and
                uid = %s
            """

        with get_connection() as connection:
            cursor = connection.execute(query, (feed_id, item_no, comment_id, uid))
            # This can happen if the person liked and immediately unliked the item
            if cursor.rowcount != 1:
                raise ItemHasGone(feed_id, item_no, comment_id)
            title, url, timestamp, value = cursor.fetchone()

        # The object should contain enough information
        # to display the event notification without any
        # further access to the database
        self.object = dict(
            uid = uid,
            feed_id = feed_id,
            item_no = item_no,
            comment_id = comment_id,
            title = title,
            url = url,
            timestamp = timestamp,
            value = value)


class CommentNotification(EventBase):
    """Notification about the entry (feed_id, item_no, comment_id).

    When serialized, each node gets the rendered reply node attached by
    ``finalize_bulk``.
    """
    collection_name = "CommentNotification"
    # Fields that together identify one event.
    __pk__ = [
        'uid',
        'object.feed_id',
        'object.item_no',
        'object.comment_id'
    ]
    order = [
        ('object.timestamp', DESCENDING)
    ]

    def __init__(self, feed_id, item_no, comment_id, store_time):
        # Zero-argument super() stays correct in subclasses, whereas
        # super(type(self), self) would recurse forever there.
        super().__init__()
        self.object = dict(
            feed_id = feed_id,
            item_no = item_no,
            comment_id = comment_id,
            timestamp = store_time)

    @classmethod
    def finalize_bulk(cls, uid, nodes):
        """Append the rendered reply to every serialized node of this class.

        :param uid: owner of the event list (not used here).
        :param nodes: elements rendered by ``serialize`` for this class; each
            must contain integer ``feed_id``, ``item_no`` and ``comment_id``
            children.
        """
        def entry_key(element):
            return tuple(
                int(element.findtext(field))
                for field in ('feed_id', 'item_no', 'comment_id')
            )

        node_map = {entry_key(element): element for element in nodes}

        # Load all referenced comments in one call, including deleted ones.
        replies = {
            (e.feed_id, e.item_no, e.comment_id): e
            for e in entries.models.load_comments_by_ids(list(node_map), with_deleted=True)
        }
        for reply_id, element in node_map.items():
            if reply_id in replies:
                element.append(replies[reply_id].serializer.build_reply_node())


class MentionNotification(EventBase):
    """Mention notification.

    The payload holds the (feed_id, item_no, comment_id) of the mentioning
    entry, the author, the mention type and a timestamp. It holds the
    optional mentioned entry as ``linked_*`` fields.
    """
    collection_name = "MentionNotification"
    # Fields that together identify one event.
    __pk__ = [
        'uid',
        'object.feed_id',
        'object.item_no',
        'object.comment_id',
    ]

    def __init__(self, mention_entry_tuple, author_uid, mention_type, timestamp, mentioned_entry_tuple=None):
        """
        :param mention_entry_tuple: (feed_id, item_no, comment_id) of the entry.
        :param author_uid: stored as ``author_uid``.
        :param mention_type: stored as ``type``.
        :param timestamp: stored as ``timestamp``.
        :param mentioned_entry_tuple: optional (feed_id, item_no, comment_id)
            stored as ``linked_feed_id``/``linked_item_no``/``linked_comment_id``.
        """
        # Zero-argument super() stays correct in subclasses, whereas
        # super(type(self), self) would recurse forever there.
        super().__init__()
        feed_id, item_no, comment_id = mention_entry_tuple
        self.object = dict(
            feed_id=feed_id,
            item_no=item_no,
            comment_id=comment_id,
            author_uid=author_uid,
            type=mention_type,
            timestamp=timestamp)
        if mentioned_entry_tuple:
            linked_feed_id, linked_item_no, linked_comment_id = mentioned_entry_tuple
            self.object['linked_feed_id'] = linked_feed_id
            self.object['linked_item_no'] = linked_item_no
            self.object['linked_comment_id'] = linked_comment_id


class ClubNotification(EventBase):
    """Club notification: action ``type`` by ``sender_id`` on a feed/item."""
    collection_name = 'ClubNotification'
    # Fields that together identify one event.
    # NOTE(review): the original author marked 'object.sender_id' as "not sure".
    __pk__ = [
            'uid',
            'object.feed_id',
            'object.item_no',
            'object.type',
            'object.sender_id'
            ]
    # Payload fields that build_key() matches on.
    key_fields = ['feed_id', 'type', 'item_no', 'sender_id']
    # key_fields plus the timestamp.
    all_fields = key_fields + ['timestamp']

    def __init__(self, action_type, sender_id, feed_id, item_no=0):
        # Zero-argument super() stays correct in subclasses, whereas
        # super(type(self), self) would recurse forever there.
        super().__init__()
        self.object = {
            'feed_id': feed_id,
            'item_no': item_no,
            'type': action_type,
            'sender_id': sender_id,
            'timestamp': datetime.datetime.now()
        }

    @classmethod
    def build_key(cls, uid, obj):
        """Return a query-style dict: ``{'uid': uid, 'object.<field>': obj[field], ...}``.

        The fields are those in ``key_fields``; raises KeyError if ``obj`` lacks one.
        """
        key = {'uid': uid}
        for field in cls.key_fields:
            key['object.' + field] = obj[field]
        return key


class FrienderNotification(EventBase):
    """Friender event; the payload holds only ``friender_uid``.

    Note that the payload has no 'timestamp' key.
    """
    collection_name = "Friender"
    # Fields that together identify one event.
    __pk__ = [
        'uid',
        'object.friender_uid'
    ]

    def __init__(self, friender_uid):
        # Zero-argument super() stays correct in subclasses, whereas
        # super(type(self), self) would recurse forever there.
        super().__init__()
        self.object = dict(friender_uid = friender_uid)
