from django.core.exceptions import ObjectDoesNotExist
from django.db import IntegrityError, transaction
from django.db.models import Model
from mptt.exceptions import InvalidMove
from django.db import connection as db_conn
from psycopg2 import errorcodes as pg_errorcodes

from wiki.sync.staff.logger import sync_logger
from wiki.sync.staff.mapping.datasource import DataSource
from wiki.sync.staff.mapping.exceptions import MappingPathNotFound, RemoteObjectNotFound
from wiki.sync.staff.mapping.field_mapping import FieldMapping
from wiki.sync.staff.mapping.sync_node import SyncNode
from wiki.sync.staff.utils import update_model_fields


class RepositoryMapper(object):
    """
    Синхронизует репозиторий staff api с какой-то определенной локальной моделью.
    Наследуем от этого класса и регистрируем в mapper_repository

    repository_name - имя репозитория
    extra_lookups - дополнительные параметры которые надо добавить к запросу, например {'type': 'department'}
    field_mapping - список из Mapping - как связать поле модели с путем к remote данным
    pk_mappings - маппинги, по которому мы ассоциируем локальную модель с remote данными. Не обязательно (id <-> id)
    field_filters_on_create - как преобразовать словарь fields при создании новой модели
    field_filters_on_update - как преобразовать словарь fields при обновлении модели

    local_model - джанговская модель с которой мы синкаемся
    deps - словарь (локальное поле - fk на модель) для того чтобы обработать кейс IntegrityError при создании.
        Перечисляем только те поля которые хотим чтобы автомагически отрезолвились через мапперы и были взяты из
        удаленного репозиторя

    """

    repository_name = ''
    extra_lookups = {}
    field_mapping = []
    field_filters_on_create = []
    field_filters_on_update = []
    mptt_parent_field = False

    default_on_create = {}

    pk_mappings = [FieldMapping('id', 'id')]

    local_model = Model
    PAGE_SIZE = 500
    deps = {}

    @classmethod
    def all_mappings(cls, include_pk=True):
        if include_pk:
            for mapping in cls.pk_mappings:
                yield mapping
        for m in cls.field_mapping:
            yield m

    @classmethod
    def get_model_pk_tuple(cls, mdl):
        """
        Возвращает тапл для того чтобы в словаре локальных моделей
        можно было найти, есть ли такая модель -- т.е. ее надо обновить или создать.

        Тапл потому что некоторые модели могут описываться парами индексов {group_id, staff_id}
        :param mdl:
        :return:
        """
        return tuple([getattr(mdl, pk_mapping.field_name) for pk_mapping in cls.pk_mappings])

    @classmethod
    def get_remote_fields(cls):
        return [m.get_remote_field() for m in cls.all_mappings()]

    @classmethod
    def get_dict_from_batch(cls, batch):
        d = {}
        for obj in batch:
            idx = tuple([pk_mapping.get_value(obj) for pk_mapping in cls.pk_mappings])
            d[idx] = obj
        return d

    @classmethod
    def update_local_model(cls, model, fields):
        _fields = fields.copy()
        mptt_updated = False
        if cls.mptt_parent_field:
            """
            Чтобы MPTT корректно пересчитал свои временные поля нужно чтобы мы устанавливали не mdl.parent_id = new_val
            а саму модельку mdl.parent = parent
            Поэтому мы удаляем поле parent_id и выставляем поле parent
            Причем и parent и сама моделька должна быть самая свежая из БД -- иначе перерасчет сломает дерево.
            Можно рейзить DoesNotExits - будет обработано выше.
            """
            new_parent_id = _fields['parent_id']
            old_parent_id = getattr(model, 'parent_id')

            if new_parent_id != old_parent_id:
                del _fields['parent_id']
                model.refresh_from_db()
                if new_parent_id is not None:
                    try:
                        parent_model = cls.local_model.objects.get(pk=new_parent_id)
                    except Exception:
                        sync_logger.exception(
                            f'model={cls.local_model} new_parent_id={new_parent_id} old_parent_id={old_parent_id}'
                        )
                        raise
                else:
                    parent_model = None
                _fields['parent'] = parent_model
                mptt_updated = True

        for fn in cls.field_filters_on_update:
            fn(_fields)

        return update_model_fields(model, _fields, update_whole_model=mptt_updated)

    @classmethod
    def create_local_model(cls, fields, exist_ok=True):
        _fields = cls.default_on_create.copy()
        _fields.update(fields)
        for fn in cls.field_filters_on_create:
            fn(_fields)

        try:
            with transaction.atomic():
                cls.local_model.objects.create(**_fields)
        except IntegrityError as exc:
            if not (exist_ok and exc.__cause__.pgcode == pg_errorcodes.UNIQUE_VIOLATION):
                raise

    @classmethod
    def convert_remote_to_model_fields(cls, remote, include_pk=True):
        try:
            fields = {}
            for mapping in cls.all_mappings(include_pk):
                mapping.contribute_to_fields_dict(fields, remote)
        except MappingPathNotFound:
            sync_logger.exception('Cant map remote to local %s, remote is %s' % (cls.__name__, remote))
            raise
        return fields

    @classmethod
    def process_batch(cls, remote_objects, mapper_registry):
        if len(remote_objects) == 0:
            return 0, 0

        pk_to_remote_obj = cls.get_dict_from_batch(remote_objects)

        # маппинги задаются из кода, поэтому можно считать что поля и название таблицы безопасны

        table_name = cls.local_model._meta.db_table
        model_fields = ','.join([f"\"{pk_mapping.field_name}\"" for pk_mapping in cls.pk_mappings])

        with db_conn.cursor() as cursor:

            cursor.execute(
                f"SELECT {model_fields} FROM \"{table_name}\" WHERE ({model_fields}) IN %s",
                [tuple(pk_to_remote_obj.keys())],
            )
            local_objects = set(cursor.fetchall())

        failures_count = 0
        success_count = 0
        not_changed = 0

        CREATED = 0
        UPDATED = 1
        NOT_CHANGED = 2
        FAILED = -1

        for pks_tuple, remote in pk_to_remote_obj.items():
            model_fields = cls.convert_remote_to_model_fields(remote)
            if pks_tuple in local_objects:

                attempts = 1

                while attempts >= 0:
                    attempts -= 1

                    # 1. прогон "на удачу"
                    # 0. прогон после ремонта

                    try:
                        model_instance = cls.get_local_model_instance(pks_tuple)
                        result, updated_fields = cls.update_local_model(model_instance, model_fields)
                        if result:
                            sync_logger.debug(
                                'Updating %s %s [%s]...'
                                % (cls.local_model.__name__, pks_tuple, ', '.join(updated_fields))
                            )
                            result = UPDATED
                        else:
                            result = NOT_CHANGED

                        break
                    except ObjectDoesNotExist:
                        if cls.fix_local_model_without_deps(model_fields, remote, mapper_registry):
                            continue
                        else:
                            result = FAILED
                            break
                    except InvalidMove:
                        sync_logger.error('Hierarchy problem. It usually must not be happening')
                        result = FAILED
                        break

            else:
                sync_logger.debug('Creating %s %s...' % (cls.local_model.__name__, pks_tuple))
                try:
                    cls.create_local_model(model_fields, exist_ok=True)
                    result = CREATED
                except ObjectDoesNotExist:
                    if cls.fix_local_model_without_deps(model_fields, remote, mapper_registry):
                        result = CREATED
                    else:
                        result = FAILED
                except Exception:
                    sync_logger.exception('Fatal on creation %s, %s' % (cls.local_model.__name__, model_fields))
                    raise

            if result == NOT_CHANGED:
                not_changed += 1
            else:
                if result == FAILED:
                    failures_count += 1
                else:
                    success_count += 1

        return success_count, failures_count

    @classmethod
    def find_missing_related_models(cls, node, mapper_registry):
        """
        @param DataSource datasource:
        @param List missing_dependencies:
        @param MapperRegistry mapper_registry:
        @param SyncNode node:
        @return:
        """

        fields = cls.convert_remote_to_model_fields(node.raw)
        my_pk = {pk_mapper.field_name: fields[pk_mapper.field_name] for pk_mapper in cls.pk_mappings}

        # локальное поле с FK -> Модель
        for local_field, model_klass in list(cls.deps.items()):  # type: (str, Model)
            if fields[local_field]:
                exists = model_klass.objects.filter(pk=fields[local_field]).exists()
                if not exists:
                    mapper = mapper_registry.get_mapper(model_klass)  # type: RepositoryMapper

                    sync_logger.debug(
                        "%s %s doesn't exist (foreign key %s of %s %s), resolving using %s"
                        % (
                            model_klass.__name__,
                            fields[local_field],
                            local_field,
                            cls.local_model.__name__,
                            my_pk,
                            mapper.__name__,
                        )
                    )

                    missing_node = mapper.get_node(fields[local_field])
                    node.missing_dependencies.append(missing_node)
                    mapper.find_missing_related_models(missing_node, mapper_registry)

        node.graph_built = True

    _datasource = None

    @classmethod
    def get_default_datasource(cls):
        if cls._datasource is None:
            cls._datasource = DataSource(
                repository_name=cls.repository_name, extra_lookups=cls.extra_lookups, fields=cls.get_remote_fields()
            )
        return cls._datasource

    @classmethod
    def get_node(cls, pk):
        if len(cls.pk_mappings) > 1:
            raise NotImplementedError('Can\'t resolve model if PK consist of composite key')
        return SyncNode(cls, cls.get_default_datasource().get_object({cls.pk_mappings[0].remote_field: pk}))

    @classmethod
    def update_models(cls, queryset, mapper_registry):
        ds = cls.get_default_datasource()
        raw = []
        for model in queryset:
            raw.append(ds.get_object(cls.get_remote_pk_lookup_expression(model)))

        cls.process_batch(raw, mapper_registry)

    @classmethod
    def fix_local_model_without_deps(cls, pks_tuple, remote_dict, mapper_registry):
        sync_logger.debug('Dependency is missing, trying to resolve')
        node = SyncNode(cls, remote_dict)
        try:
            cls.find_missing_related_models(node, mapper_registry)
            if len(node.missing_dependencies) > 0:
                node.apply_dependency_graph()
            return True
        except RemoteObjectNotFound as s:
            sync_logger.warn('Resolve failed %s' % s)

        return False

    @classmethod
    def get_joined_datasource(cls, *other_mappers):
        e_l = cls.extra_lookups.copy()
        fields = cls.get_remote_fields()

        for mapper in other_mappers:
            e_l.update(mapper.extra_lookups)
            fields += mapper.get_remote_fields()

        fields = set(fields)

        return DataSource(repository_name=cls.repository_name, extra_lookups=e_l, fields=fields)

    @classmethod
    def get_remote_pk_lookup_expression(cls, mdl):
        lookup = {}
        for mapping in cls.pk_mappings:
            lookup[mapping.remote_field] = getattr(mdl, mapping.field_name)
        return lookup

    @classmethod
    def get_local_pk_lookup_expression(cls, pk_tuple):
        lookup = {}
        for pk_value, mapping in zip(pk_tuple, cls.pk_mappings):
            lookup[mapping.field_name] = pk_value
        return lookup

    @classmethod
    def get_local_model_instance(cls, pks_tuple):
        return cls.local_model.objects.get(**cls.get_local_pk_lookup_expression(pks_tuple))
