import logging
from typing import TypeVar, Optional, Set, Tuple
from datetime import datetime

from django.db import models
from django.utils import translation

from staff.lib.db import atomic


class RolluperError(Exception):
    """Base class for every error raised by the OEBS -> DIS roll-up machinery."""


class ObjectNotValid(RolluperError):
    """The object did not pass validation."""


class ConfigurationError(RolluperError):
    """The synchronization (roll-up) parameters are configured incorrectly."""


class Rollupper:
    """
    Roller-upper: applies data that arrived from OEBS onto the DIS tables.

    Subclasses configure a sync through class attributes:

    * ``queryset`` / ``model`` -- source of OEBS records (``queryset`` wins when set);
    * ``link_field_name`` -- relation field on the OEBS model pointing at the DIS instance;
    * ``data_mapper_class`` -- called with ``vars(oebs_instance)``; iterating the result
      yields ``(field_name, new_value)`` pairs for plain DIS fields;
    * ``validator_form_class`` -- optional; called with ``dry(oebs_instance)`` and must
      expose ``errors`` and ``warnings``;
    * ``key_field_name`` -- lookup used to restrict a run to a single ``object_id``;
    * ``rollup_rel_fields`` -- ``(field_name, kwargs)`` pairs handed to ``generic_rollup``.
    """
    queryset = None
    model: Optional[models.Model] = None
    link_field_name: Optional[str] = None
    data_mapper_class = None  # type: type
    validator_form_class = None  # type: type
    key_field_name: Optional[str] = None
    # NOTE: __init__ always assigns self.create_absent, so this class-level value
    # is shadowed on every instance; pass create_absent to rollup()/__init__ instead.
    create_absent = False
    # Only ever read (never mutated) by this class.
    rollup_rel_fields = []

    @classmethod
    def rollup(cls, logger=None, object_id=None, debug=False, dry_run=False, create_absent=False):
        """Build a roller-upper with the given options and run a single roll-up pass."""
        self = cls(logger=logger, debug=debug, create_absent=create_absent)
        self.run_rollup(object_id, dry_run=dry_run)

    def __init__(self, logger=None, debug=False, create_absent=False):
        """
        :param logger: logger to report through; ``logging.getLogger(__name__)`` when omitted.
        :param debug: open an ``ipdb`` prompt on every exception caught during a run.
        :param create_absent: create DIS instances for OEBS records that have none, and
            stop filtering out OEBS records with an empty link in ``_make_queryset``.
        """
        self.logger = self.LoggerWrapper(logger, debug=debug)
        self.create_absent = create_absent
        # Not used by this class; presumably a scratch cache for subclasses.
        self.cache = {}

    def dry(self, oebs_instance):
        """Return a shallow copy of the instance ``__dict__`` (input for the validator)."""
        return dict(vars(oebs_instance))

    def validate_instance(self, oebs_instance):
        """
        Validate an OEBS record with ``validator_form_class`` (no-op when it is unset).

        Errors and warnings are recorded on the record through ``mark_error``, which
        saves it -- also during a dry run.

        :raises ObjectNotValid: when the validator reports any errors.
        """
        if not self.validator_form_class:
            return
        validator = self.validator_form_class(self.dry(oebs_instance))

        if validator.errors or validator.warnings:
            # Render (presumably lazily translated) messages in Russian before storing them.
            with translation.override('ru'):  # beware! 1.4 feature
                self.mark_error(
                    oebs_instance,
                    errors=validator.errors,
                    warnings=validator.warnings
                )

        if validator.errors:
            raise ObjectNotValid(repr(validator.errors))

    def run_rollup(self, object_id=None, dry_run=False):
        """
        Roll every OEBS record from ``get_queryset(object_id)`` onto its DIS instance.

        Records are processed one by one; any exception is reported through
        ``self.logger.exception`` and the loop continues with the next record.
        Records processed without an exception are passed to ``mark_rollupped``
        (which saves the OEBS record even during a dry run).

        :param object_id: optional ``key_field_name`` value restricting the run.
        :param dry_run: when true, DIS instances are not saved and the link on the
            OEBS record is not reassigned.
        """
        # NOTE: vars() hands every local of this method to the logger, whose messages
        # call str.format(model=..., id=..., **ctx); never add locals named
        # ``model`` or ``id`` here, and keep existing names stable.
        queryset = self.get_queryset(object_id)
        for oebs_instance in queryset:
            # Reset per-record state so the logger never reports values left over
            # from the previous record (e.g. a stale DIS id when validation fails).
            created = False
            dis_instance = None
            updated_fields = set()
            cached_dis_instance = None
            roll_up_needed = False
            try:
                self.validate_instance(oebs_instance)
                cached_dis_instance = getattr(oebs_instance, self.get_link_field_name())
                # Cheap pre-check on the select_related copy: only lock and re-fetch the
                # DIS row when something would change. Plain fields are assigned even
                # with dry_run=True, so the cached copy is modified in memory.
                roll_up_needed = (
                    cached_dis_instance is None
                    or self.roll_up_instance(cached_dis_instance, oebs_instance, True)
                )

                with atomic():
                    if roll_up_needed:
                        dis_instance, created = self.get_dis_instance(oebs_instance)

                    if dis_instance is None:
                        # Either nothing changed, or get_dis_instance() found no DIS
                        # instance and was not allowed to create one.
                        self.logger.skip(vars())
                    else:
                        updated_fields = self.roll_up_instance(dis_instance, oebs_instance, dry_run)

                        if updated_fields or created:
                            if not dry_run:
                                dis_instance.save()
                                setattr(oebs_instance, self.get_link_field_name(), dis_instance)  # django 1.7 bug
                            self.logger.success(vars())
                        else:
                            self.logger.skip(vars())

            except Exception:
                self.logger.exception(vars())
            else:
                self.mark_rollupped(oebs_instance, is_changed=created)

        self.logger.results(vars())

    def roll_up_instance(self, dis_instance, oebs_instance, dry_run):
        """Apply plain and relation fields; return the set of changed field names."""
        plain_fields = self.rollup_plain_fields(oebs_instance, dis_instance, dry_run)
        relation_fields = self.rollup_relation_fields(oebs_instance, dis_instance, dry_run)

        return plain_fields | relation_fields

    def _make_queryset(self, object_id=None):
        """
        Build the default queryset from ``model``.

        Unless ``create_absent`` is set, only records with a non-empty link are
        selected; a truthy ``object_id`` further filters on ``key_field_name``.
        """
        lookup = {}

        if not self.create_absent:
            lookup['%s__isnull' % self.get_link_field_name()] = False
        if object_id:
            lookup[self.key_field_name] = object_id

        return self.model.objects.filter(**lookup)

    def get_queryset(self, object_id=None):
        """
        Return the OEBS records to process, with the DIS link select_related.

        A class-level ``queryset`` is used as-is (cloned), additionally filtered by
        ``object_id`` when both ``object_id`` and ``key_field_name`` are set;
        otherwise the queryset is built from ``model`` by ``_make_queryset``.

        :raises ConfigurationError: when neither ``queryset`` nor ``model`` is defined.
        """
        if self.queryset is not None:
            queryset = self.queryset.all()  # .all() clones the shared class-level queryset
            if object_id and self.key_field_name:
                queryset = queryset.filter(**{self.key_field_name: object_id})
        elif self.model is not None:
            queryset = self._make_queryset(object_id)
        else:
            raise ConfigurationError('Neither queryset nor model defined')

        return queryset.select_related(self.get_link_field_name())

    def generate_update_values(self, oebs_instance: models.Model, dis_instance: models.Model):
        """
        Yield ``(field_name, new_value)`` for each mapped field whose value differs.

        A value counts as unchanged when it is equal to the current one or has the
        same ``str()`` form. Each change is logged at debug level once the consumer
        resumes the generator.
        """
        data_mapper_class = self.get_data_mapper()
        mapper = data_mapper_class(vars(oebs_instance))

        for field_name, new_value in mapper:
            old_value = getattr(dis_instance, field_name)
            if new_value != old_value and str(new_value) != str(old_value):
                yield field_name, new_value
                self.logger.field_changed(vars())

    def rollup_plain_fields(self, oebs_instance: models.Model, dis_instance: models.Model, dry_run: bool) -> Set[str]:
        """
        Assign changed mapped values onto ``dis_instance`` and return their names.

        ``dry_run`` is not consulted here: attributes are always assigned in memory;
        only saving (in ``run_rollup``) is skipped during a dry run.
        """
        changed_fields: Set[str] = set()
        for field_name, new_value in self.generate_update_values(oebs_instance, dis_instance):
            setattr(dis_instance, field_name, new_value)
            changed_fields.add(field_name)

        return changed_fields

    def get_dis_instance(self, oebs_instance: models.Model) -> Tuple[Optional[models.Model], bool]:
        """
        Return ``(dis_instance, created)`` for an OEBS record.

        A linked instance is re-fetched with ``select_for_update()``; otherwise a new
        one is built by ``create_dis_instance`` when ``create_absent`` is set.
        Returns ``(None, False)`` when there is no link and creation is disabled.
        """
        dis_instance_id = getattr(oebs_instance, self.get_link_field_name() + '_id')

        if dis_instance_id is not None:
            link_field = getattr(oebs_instance.__class__, self.get_link_field_name())
            # NOTE(review): ``rel.to`` is the pre-Django-2.0 relation API.
            dis_model: models.Model = link_field.field.rel.to
            return dis_model.objects.select_for_update().get(pk=dis_instance_id), False

        if self.create_absent:
            return self.create_dis_instance(oebs_instance), True

        self.logger.logger.info('staff instance was not found and not created')
        return None, False

    def create_dis_instance(self, oebs_instance: models.Model) -> models.Model:
        """Build a new DIS instance for ``oebs_instance``; must be overridden."""
        raise ConfigurationError('Method is abstract and must be implemented in subclass')

    def generic_rollup(self, oebs_instance, dis_instance, field_name, dry_run, **kwargs) -> bool:
        """
        Roll the value of <field_name> from <oebs_instance> onto <dis_instance>.

        Returns True if changes were made, False otherwise. Must be overridden by
        subclasses that declare ``rollup_rel_fields``.
        """
        raise ConfigurationError('Method is abstract and must be implemented in subclass')

    def get_link_field_name(self):
        """Name of the relation field on the OEBS model that points at the DIS instance."""
        return self.link_field_name

    def get_data_mapper(self):
        """
        Return ``data_mapper_class``.

        :raises ConfigurationError: when no data mapper is configured.
        """
        if not self.data_mapper_class:
            raise ConfigurationError('Subclass should set up data mapper')

        return self.data_mapper_class

    def get_model(self):
        """Return the OEBS model (from ``model`` or ``queryset``), or None if neither is set."""
        if self.model is not None:
            return self.model

        # Compare with None: truth-testing a QuerySet evaluates it (fetches all rows).
        if self.queryset is not None:
            return self.queryset.model

        return None

    def rollup_relation_fields(self, oebs_instance, dis_instance, dry_run) -> Set[str]:
        """Run ``generic_rollup`` for every entry of ``rollup_rel_fields``; return changed names."""
        changed_fields: Set[str] = set()
        for field_name, kwargs in self.rollup_rel_fields or ():
            if self.generic_rollup(oebs_instance, dis_instance, field_name, dry_run, **kwargs):
                changed_fields.add(field_name)

        if changed_fields:
            self.logger.log_rel_fields_changed(vars())

        return changed_fields

    class LoggerWrapper(object):
        """
        Counts processed records and formats log messages from a caller's ``vars()``.

        Most methods take ``ctx`` -- the locals of the calling method -- and read keys
        such as ``self``, ``oebs_instance`` and ``dis_instance`` from it.
        """

        def __init__(self, logger=None, debug=False):
            self.logger = logger or logging.getLogger(__name__)
            self.all = self.rolled = self.errors = self.skipped = 0
            self.debug_on_exceptions = debug

        def _model_name(self, ctx):
            """Class name of the DIS model the OEBS link field points to."""
            attr = ctx['self'].get_link_field_name()
            field_declaration = getattr(ctx['oebs_instance'].__class__, attr)
            return field_declaration.field.rel.to.__name__

        def success(self, ctx):
            """Count a successful roll-up and log it at info level."""
            self.rolled += 1
            self.all += 1

            if ctx['created']:
                msg = 'Rollup to {model}[{id}] successful: created'
            else:
                msg = 'Rollup to {model}[{id}] successful: updated {updated_fields}'

            self.logger.info(msg.format(
                model=self._model_name(ctx),
                id=ctx['dis_instance'].pk,
                **ctx,
            ))

        def exception(self, ctx):
            """Count a failed record and log the active exception with its traceback."""
            self.errors += 1
            self.all += 1
            oebs_instance = ctx['oebs_instance']

            self.logger.exception(
                f'Exception during rollup from {oebs_instance.__class__.__name__}[%s] to %s[%s]',
                oebs_instance.pk,
                self._model_name(ctx),
                getattr(ctx.get('dis_instance'), 'id', '?'),
            )

            if self.debug_on_exceptions:
                import ipdb
                ipdb.set_trace()

        def results(self, ctx=None):
            """Log the run summary; the model name comes from ``ctx['queryset']``."""
            msg = ' '.join([
                'Rolled up {model}: {rolled},',
                'errors: {errors},',
                'skipped: {skipped}',
                'total: {all}',
            ])
            model_name = ctx['queryset'].model.__name__ if ctx else '?'
            self.logger.info(msg.format(model=model_name, **vars(self)))

        def field_changed(self, ctx):
            """Log a single plain-field change at debug level."""
            msg = 'Gonna rollup {model}[{id}].{field_name}: "{old_value}" -> "{new_value}"'
            self.logger.debug(msg.format(
                model=self._model_name(ctx),
                id=ctx['dis_instance'].pk,
                **ctx,
            ))

        def log_rel_fields_changed(self, ctx):
            """Log the set of changed relation fields at debug level."""
            msg = 'Rolled up {model}[{id}]: changed {changed_fields} relations'
            self.logger.debug(msg.format(
                model=self._model_name(ctx),
                id=ctx['dis_instance'].pk,
                **ctx,
            ))

        def skip(self, ctx):
            """Count a record that needed no changes."""
            self.all += 1
            self.skipped += 1

    def mark_rollupped(self, oebs_instance, is_changed=False):
        """
        Stamp a successfully processed OEBS record.

        Sets ``last_rollup`` to the current (naive) time and clears
        ``last_rollup_error`` when the model has those attributes, then saves the
        record if anything was set or ``is_changed`` was passed.

        @type oebs_instance: OEBSModelBase
        """
        if hasattr(oebs_instance, 'last_rollup'):
            oebs_instance.last_rollup = datetime.now()
            is_changed = True
        if hasattr(oebs_instance, 'last_rollup_error'):
            oebs_instance.last_rollup_error = ''
            is_changed = True
        if is_changed:
            oebs_instance.save()

    def mark_error(self, oebs_instance, errors=None, warnings=None):
        """
        Store validation errors and warnings on an OEBS record and save it.

        Each mapping of ``field -> list of messages`` is flattened to
        ``"field: msg1, msg2; other: msg"``. Always saves the record.

        @type oebs_instance: OEBSModelBase
        """
        assert errors is not None and warnings is not None

        def format_errors_dict(d):
            return '; '.join(
                '%s: %s' % (field, ', '.join(errors_list))
                for field, errors_list in d.items()
            )

        if hasattr(oebs_instance, 'last_rollup_error'):
            oebs_instance.last_rollup_error = format_errors_dict(errors)

        if hasattr(oebs_instance, 'rollup_warnings'):
            oebs_instance.rollup_warnings = format_errors_dict(warnings)

        oebs_instance.save()


# Type variable bound to Rollupper (any roll-up subclass); presumably used elsewhere
# to annotate code that is generic over concrete Rollupper subclasses.
RollupperT = TypeVar('RollupperT', bound=Rollupper)
