# coding: utf-8


import itertools
import logging
from weakref import proxy

import attr
from django.conf import settings
from django.db import transaction
from django_pgaas.transaction import atomic_retry

from idm.nodes.canonical import list_of, dict_of, SELF
from collections import Counter


# Module logger; the Queue.push_* methods log every queued change at INFO level.
log = logging.getLogger(__name__)

@attr.attrs(slots=True)
class DiffAddition:
    """A child value present only in the new canonical data."""

    new = attr.attrib()  # the added value


@attr.attrs(slots=True)
class DiffRemoval:
    """A child value present only in the current canonical data."""

    old = attr.attrib()  # the removed value


@attr.attrs(slots=True)
class DiffModification:
    """A single scalar field whose value differs between two canonicals."""

    # Positional order (old, new, name) is part of the constructor signature.
    old = attr.attrib()  # value in the current canonical
    new = attr.attrib()  # value in the new canonical
    name = attr.attrib()  # attribute name of the changed field


@attr.attrs(slots=True)
class NestedDiff:
    """Differences between two dicts of child values of one canonical field."""

    additions = attr.attrib(
        validator=list_of(DiffAddition),
        default=attr.Factory(list),
    )
    removals = attr.attrib(
        validator=list_of(DiffRemoval),
        default=attr.Factory(list),
    )
    # Holds DiffSet objects; list_of(DiffSet) cannot be used here because
    # DiffSet is defined further down in this module.
    modifications = attr.attrib(default=attr.Factory(list))


@attr.attrs(slots=True)
class DiffSet:
    """Full comparison result for one pair of canonical objects."""

    modified_value = attr.attrib()  # the current canonical that was compared
    new_value = attr.attrib()  # the canonical it was compared against
    # Flat field changes; the validator also admits SELF (presumably DiffSet itself).
    flat = attr.attrib(validator=list_of((DiffModification, SELF)))
    # Field name -> NestedDiff for dict-valued (child collection) fields.
    nested = attr.attrib(validator=dict_of(NestedDiff), default=attr.Factory(dict))


@attr.s(slots=True)
class BaseItem(object):
    """Base class for one pending change stored in a ``Queue``.

    Subclasses set the class-level ``type`` tag (used by ``Queue.get_summary``
    and ``Queue.get_of_type``) and are expected to implement ``apply``.
    """

    queue = attr.ib()  # owning queue (the Queue.push_* methods pass a weakref proxy)
    node = attr.ib()  # node the change refers to; may be None (e.g. queued child additions)
    extra = attr.ib(init=False, default=attr.Factory(dict))  # free-form data merged in via update()
    counter = attr.ib(default=0)  # sequence number assigned by the queue
    type = 'unknown'

    def update(self, extra=None):
        """Merge ``extra`` into ``self.extra``; ``None`` is a no-op."""
        if extra is not None:
            self.extra.update(extra)

    def apply(self, **kwargs):
        """Perform the change. Must be implemented by subclasses."""
        raise NotImplementedError()

    def get_canonical_diff(self, current_canonical, new_canonical, cmp_attrs=None):
        """Compare two attrs-based canonical objects field by field.

        Fields declared with ``cmp=False`` are skipped. ``cmp_attrs``, when
        given, restricts the comparison to those field names at the top level
        only; recursive calls for nested children compare every field.

        If ``current_canonical`` has a ``hash`` attribute, each dict-valued field
        is treated as a collection of child canonicals (keys presumably
        identify children -- confirm). Such a field always gets a
        ``NestedDiff`` entry in ``nested``, even when nothing changed. Every
        other differing field becomes a ``DiffModification`` in ``flat``.

        Returns:
            DiffSet with ``modified_value=current_canonical`` and
            ``new_value=new_canonical``.
        """
        flat = []
        nested = {}

        # ``field`` (not ``attr``) so the imported ``attr`` module is not shadowed.
        for field in current_canonical.__attrs_attrs__:
            if not field.cmp:
                continue
            if cmp_attrs is not None and field.name not in cmp_attrs:
                continue
            current = getattr(current_canonical, field.name)
            new = getattr(new_canonical, field.name)
            if hasattr(current_canonical, 'hash') and isinstance(current, dict):
                nested[field.name] = self._get_nested_diff(current, new)
            elif current != new:
                flat.append(DiffModification(name=field.name, old=current, new=new))
        return DiffSet(flat=flat, nested=nested, modified_value=current_canonical, new_value=new_canonical)

    def _get_nested_diff(self, current, new):
        """Diff two dicts of child canonicals: additions, removals, recursive modifications."""
        nested_diff = NestedDiff()
        # Iterate dicts (not sets) so the resulting lists have a deterministic order.
        for key in new.keys():
            if key not in current:
                nested_diff.additions.append(DiffAddition(new=new[key]))
        for key in current.keys():
            if key not in new:
                nested_diff.removals.append(DiffRemoval(old=current[key]))
                continue
            old_value = current[key]
            new_value = new[key]
            if old_value != new_value:
                nested_diff.modifications.append(self.get_canonical_diff(old_value, new_value))
        return nested_diff


@attr.attrs(slots=True)
class RestoringItem(BaseItem):
    """Queue item: restore a node (ExternalIDQueue queues one for inactive nodes)."""

    type = 'restore'


@attr.attrs(slots=True)
class AdditionItem(BaseItem):
    """Queue item: add a node built from the canonical ``child_data``."""

    type = 'add'
    parent_item = attr.attrib(default=None)  # addition item of the parent when the parent is also queued
    new_node = attr.attrib(default=None)  # node created for this item (presumably set when applied)
    child_data = attr.attrib(default=None)  # canonical data describing the node to add

    def get_diff(self, cmp_attrs=None):
        """Diff the created node's canonical form against ``child_data``.

        ``cmp_attrs`` optionally limits which top-level fields are compared.
        """
        return self.get_canonical_diff(
            self.new_node.as_canonical(),
            self.child_data,
            cmp_attrs,
        )


@attr.attrs(slots=True)
class MovementItem(BaseItem):
    """Queue item: move an existing node under a different parent."""

    type = 'move'
    parent_item = attr.attrib(default=None)  # queued item for the destination parent, if any
    new_parent = attr.attrib(default=None)  # destination parent node, when already known


@attr.attrs(slots=True)
class ModificationItem(BaseItem):
    """Queue item: bring an existing node's data in line with ``new_data``."""

    type = 'modify'
    new_data = attr.attrib(default=None)  # target canonical data for ``node``

    def get_diff(self):
        """Diff the node's current canonical form against ``new_data``."""
        current = self.node.as_canonical()
        return self.get_canonical_diff(current, self.new_data)


@attr.attrs(slots=True)
class RemovalItem(BaseItem):
    """Queue item: remove a node."""

    type = 'remove'


@attr.attrs(slots=True)
class HashUpdate(BaseItem):
    """Queue item: persist a new ``hash`` value on a node."""

    type = 'hashupdate'
    hash = attr.attrib(default=None)  # value to store in the node's ``hash`` field
    parent_item = attr.attrib(default=None)  # addition item whose new_node is used when ``node`` is None

    def get_parent(self):
        """Return ``node`` if set, otherwise the node created by ``parent_item``."""
        return self.parent_item.new_node if self.node is None else self.node

    def apply(self, **kwargs):
        """Write ``self.hash`` to the target node, saving only the ``hash`` field."""
        target = self.get_parent()
        target.hash = self.hash
        target.save(update_fields=('hash',))


@attr.s(slots=True)
class Queue(object):
    """Ordered collection of pending node changes ("items"), applied in batches.

    The ``push_*`` methods instantiate the item classes configured as class
    attributes below (subclasses may swap them), stamp each item with the next
    ``counter`` value and append it to ``items``.
    """

    addition = AdditionItem
    removal = RemovalItem
    modification = ModificationItem
    restore = RestoringItem
    movement = MovementItem
    hash_update = HashUpdate
    items = attr.ib(default=attr.Factory(list))
    counter = attr.ib(default=0)  # last sequence number handed out by push_*
    external_id_to_item = attr.ib(default=attr.Factory(dict))  # filled by ExternalIDQueue.apply
    system = attr.ib(init=True, default=None)
    # Slot so that weakref.proxy(self), used by the push_* methods, works on this slotted class.
    __weakref__ = attr.ib(init=False)

    def __iter__(self):
        return iter(self.items)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, item):
        return self.items[item]

    def get_summary(self):
        """Return a Counter of item ``type`` tags."""
        return Counter(item.type for item in self.items)

    def get_of_type(self, type_):
        """Return the items whose ``type`` tag equals ``type_``, in queue order."""
        return [item for item in self.items if item.type == type_]

    def push_addition(self, child_data, parent_item=None, node=None, extra=None, **kwargs):
        """Queue an addition for ``child_data`` and, recursively, for its children.

        Children are queued right after their parent with ``parent_item`` set to
        the parent's addition item. Returns the top-level item.
        """
        log.info('Push add node %s', node)
        self.counter += 1
        item = self.addition(
            queue=proxy(self),
            node=node,
            parent_item=parent_item,
            child_data=child_data,
            counter=self.counter,
        )
        item.update(extra)
        self.items.append(item)
        for child in item.child_data.children:
            self.push_addition(
                parent_item=item,
                child_data=child,
                extra=extra,
                **kwargs
            )
        return item

    def push_modification(self, node, new_data, extra=None, **kwargs):
        """Queue a modification of ``node`` to ``new_data``; returns the item."""
        log.info('Push modification to node %s', node)
        self.counter += 1
        item = self.modification(
            queue=proxy(self),
            node=node,
            new_data=new_data,
            counter=self.counter
        )
        item.update(extra)
        self.items.append(item)
        return item

    def push_removal(self, node, extra=None, **kwargs):
        """Queue removal of ``node``; its active descendants are queued before it.

        Returns the item for ``node`` itself.
        """
        log.info('Push remove node %s', node)
        self.counter += 1
        item = self.removal(
            queue=proxy(self),
            node=node,
            counter=self.counter
        )
        item.update(extra)
        # NOTE(review): if get_descendants() returns *all* descendants (not only
        # direct children), this recursion queues deeper nodes more than once.
        # Confirm the semantics of get_descendants().
        for descendant in item.node.get_descendants().active():
            self.push_removal(
                node=descendant,
                extra=extra,
                **kwargs
            )
        self.items.append(item)
        return item

    def push_movement(self, node, parent_item, new_parent, extra=None, **kwargs):
        """Queue a move of ``node`` under ``new_parent``; returns the item.

        Unlike the other push_* methods, extra ``kwargs`` are forwarded to the
        item constructor (e.g. for subclass-specific fields).
        """
        log.info('Push move node %s', node)
        # Number movement items like every other item, unless the caller set it.
        if 'counter' not in kwargs:
            self.counter += 1
            kwargs['counter'] = self.counter
        item = self.movement(
            queue=proxy(self),
            node=node,
            parent_item=parent_item,
            new_parent=new_parent,
            **kwargs
        )
        item.update(extra)
        self.items.append(item)
        return item

    def push_restore(self, node, extra=None, **kwargs):
        """Queue restoration of ``node``; returns the item."""
        log.info('Push restore node %s', node)
        self.counter += 1
        item = self.restore(
            queue=proxy(self),
            node=node,
            counter=self.counter,
        )
        item.update(extra)
        self.items.append(item)
        return item

    def push_hash_update(self, node, hash, parent_item=None, extra=None, **kwargs):
        """Queue storing ``hash`` on ``node`` (or on ``parent_item.new_node``); returns the item."""
        log.info('Push hash update to node %s', node)
        self.counter += 1
        item = self.hash_update(
            queue=proxy(self),
            node=node,
            hash=hash,
            parent_item=parent_item,
            counter=self.counter,
        )
        item.update(extra)
        self.items.append(item)
        return item

    def extend(self, other_queue):
        """Append all items of ``other_queue`` (their ``queue``/``counter`` are left as-is)."""
        self.items.extend(other_queue.items)

    # atomic_retry presumably re-runs this body on retryable DB errors, so
    # item.apply() should have no side effects outside the database.
    @atomic_retry
    def apply_batch(self, batch, **kwargs):
        """Apply every item of ``batch`` inside one ``atomic_retry`` call."""
        for item in batch:
            item.apply(**kwargs)

    def apply(self, **kwargs):
        """Apply all items and return ``self``.

        Non-hash-update items run first, in queue order, followed by all
        HashUpdate items. Each group is cut into batches of
        ``settings.IDM_NODES_SYNC_BATCH_SIZE`` items and each batch goes
        through ``apply_batch``. ``kwargs`` are passed to every ``item.apply``.
        """
        nonhashupdates = (item for item in self.items if not isinstance(item, self.hash_update))
        hashupdates = (item for item in self.items if isinstance(item, self.hash_update))
        for items in (nonhashupdates, hashupdates):
            while True:
                batch = list(itertools.islice(items, settings.IDM_NODES_SYNC_BATCH_SIZE))
                if not batch:
                    break
                self.apply_batch(batch, **kwargs)
        return self


class ExternalIDQueue(Queue):
    """Queue that matches nodes by ``external_id`` before applying.

    In ``apply``, an addition whose external id matches a node queued for
    removal, or an already inactive node, becomes a move of that existing
    node. The move is followed by a restore (if the node is inactive) and a
    modification to the new data. The matching removal is dropped.
    """

    def get_item_class(self):
        """Return the node model class; its ``objects.inactive()`` lists deleted nodes."""
        raise NotImplementedError()

    def apply(self, **kwargs):
        """Rewrite add/remove pairs into moves, then apply the rewritten queue.

        Returns the new queue, after ``Queue.apply`` has run on it.
        """
        item_class = self.get_item_class()
        added_external_id_to_item = {}
        # Candidate nodes to reuse: inactive in the DB, plus nodes queued for removal.
        external_id_to_deleted_node = {node.external_id: node for node in item_class.objects.inactive()}
        for item in self.items:
            if item.type == 'add':
                added_external_id_to_item[item.child_data.external_id] = item
            elif item.type == 'remove':
                external_id_to_deleted_node[item.node.external_id] = item.node
        moved_external_ids = added_external_id_to_item.keys() & external_id_to_deleted_node.keys()

        # Items are re-pointed to this queue, so it must keep ``system``. It also
        # continues the counter so newly pushed items don't reuse existing numbers.
        new_queue = self.__class__(system=self.system)
        new_queue.counter = self.counter
        for item in self.items:
            item.queue = new_queue
            if item.type == 'remove':
                if item.node.external_id in moved_external_ids:
                    # The node is reused by a move below instead of being removed.
                    continue
                new_queue.external_id_to_item[item.node.external_id] = item
            elif item.type == 'add':
                if item.child_data.external_id in moved_external_ids:
                    moved_node = external_id_to_deleted_node[item.child_data.external_id]
                    movement = new_queue.push_movement(node=moved_node,
                                                       parent_item=item.parent_item,
                                                       new_parent=item.node)
                    if not moved_node.is_active():
                        restoring_info = {'data': item.child_data}
                        new_queue.push_restore(node=moved_node, extra=restoring_info)
                    new_queue.push_modification(node=moved_node, new_data=item.child_data)
                    # Children of this addition resolve their parent through the movement.
                    new_queue.external_id_to_item[moved_node.external_id] = movement
                    continue
                new_queue.external_id_to_item[item.child_data.external_id] = item
            new_queue.items.append(item)
        return super(ExternalIDQueue, new_queue).apply(**kwargs)


@attr.attrs(slots=True)
class ExternalIDAdditionItem(AdditionItem):
    """Addition item that resolves its parent node through external ids."""

    external_id = attr.attrib(default=None)

    def get_parent_node(self):
        """Return the node the new node should be created under.

        Uses ``node`` when set; otherwise looks up the parent's current item in
        ``queue.external_id_to_item`` (KeyError if absent) and returns its new node.
        """
        if self.node is not None:
            return self.node
        # Go through the mapping: parent_item may refer to an object that is no longer current.
        parent_key = self.parent_item.get_external_id()
        return self.queue.external_id_to_item[parent_key].get_new_node()

    def get_external_id(self):
        """External id of the node being added."""
        return self.child_data.external_id

    def get_new_node(self):
        """Node created by this item."""
        return self.new_node


@attr.s(slots=True)
class ExternalIDMovementItem(MovementItem):
    """Movement item that resolves its destination parent through external ids."""

    external_id = attr.ib(default=None)

    def get_parent_node(self):
        """Return the node this item's node should be moved under.

        Uses ``new_parent`` when set. Otherwise it looks up the parent's current
        item in ``queue.external_id_to_item`` by ``parent_item``'s external id and
        returns that item's new node.

        Raises:
            KeyError: the parent's external id is missing from the mapping.
                Previously this surfaced as an opaque AttributeError on ``None``.
        """
        if self.new_parent is not None:
            return self.new_parent
        external_id = self.parent_item.get_external_id()
        # Index rather than .get(): a missing entry fails naming the id, matching
        # ExternalIDAdditionItem.get_parent_node.
        parent_item = self.queue.external_id_to_item[external_id]
        return parent_item.get_new_node()

    def get_external_id(self):
        """External id of the node being moved."""
        return self.node.external_id

    def get_new_node(self):
        """The moved node itself; it keeps its identity after the move."""
        return self.node


@attr.attrs(slots=True)
class ExternalIDRemovalItem(RemovalItem):
    """Removal item carrying an optional ``external_id`` field."""

    external_id = attr.attrib(default=None)
