import logging
from datetime import datetime
from typing import Any, Dict, Tuple

from django.db import IntegrityError, models, transaction

from staff.lib.bulk_update import execute_bulk_update
from staff.lib.sync_tools import DataGenerator
from staff.lib.sync_tools.updater import Updater
from staff.lib.sync_tools.diff_merger import DataDiffMerger

from staff.oebs.constants import ACTIVE_STATUS
from staff.oebs.exceptions import OEBSError


# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


class OEBSDataDiffMerger(DataDiffMerger):
    """Merge OEBS data into local models, batching creates, updates and deletes.

    ``execute`` walks ``(ext_data, int_data, sync_key)`` triples produced by the
    data generator, picks an action per record via ``detect_action`` and queues
    it. A queue is flushed as soon as it reaches ``batch_size`` entries, and all
    queues are flushed once more at the end of ``execute``.

    Every flush (and every row saved by the STAFF-17876 fallback) runs in its own
    ``transaction.atomic()`` block. A failing batch is rolled back on its own and
    counted in ``self.errors``. It no longer leaves an enclosing transaction
    unusable for the batches that follow.
    """

    # Number of queued entries that triggers a non-forced flush (see _is_required).
    batch_size = 1000

    def __init__(self, data_generator: DataGenerator, logger=None):
        # sync_key -> field values for objects to create (sync fields removed).
        self.to_create = {}
        # sync_key -> changed field values as returned by ``self.diff``.
        self.to_update = {}
        # sync_key -> True; used as an insertion-ordered set of keys to delete.
        self.to_delete = {}
        super().__init__(data_generator, logger)

    def execute(self, create=False, update=False, delete=False):
        """Diff every record from the data generator and apply accepted actions.

        :param create: passed to ``_accepted_actions``; controls whether
            'create' actions are applied.
        :param update: same, for 'update' actions.
        :param delete: same, for 'delete' actions.
        :raises RuntimeError: if ``detect_action`` returns an unknown action.

        Records whose action is not accepted, and updates with an empty diff,
        are counted in ``self.skipped``.
        """
        accepted_actions = self._accepted_actions(create, update, delete)
        sync_fields = tuple(self.data_generator.get_sync_fields())

        for ext_data, int_data, sync_key in self.data_generator:
            self.logger.debug('Merging %s with %s by %s', ext_data, int_data, sync_key)
            self.all += 1

            action = self.detect_action(ext_data, int_data)
            self.logger.debug('Gonna %s', action)
            if action not in accepted_actions:
                self.logger.debug('%s is not accepted action, so skip', action)
                self.skipped += 1
                continue

            if action == 'delete':
                self.logger.debug('Gonna %s %s', action, sync_key)
                self.to_delete[sync_key] = True
                self.perform_deletes()
            elif action == 'update':
                diff_data = self.diff(ext_data, int_data)
                if diff_data:
                    self.logger.debug('Data differs by %s', diff_data)
                    self.to_update[sync_key] = diff_data
                    self.perform_updates()
                else:
                    self.logger.debug('Object %s is up to date already', sync_key)
                    self.skipped += 1
            elif action == 'create':
                # Build a filtered copy instead of deleting keys from ``ext_data``
                # in place, so the generator's record is not mutated. Sync fields
                # are dropped here; presumably create_object(sync_key) sets them.
                self.to_create[sync_key] = {
                    field: value
                    for field, value in ext_data.items()
                    if field not in sync_fields
                }
                self.perform_creates()
            else:
                raise RuntimeError(f'Unknown action "{action}"')

        self.perform_all_actions()

    def perform_all_actions(self):
        """Flush every queue regardless of size: deletes, then updates, then creates."""
        self.perform_deletes(force=True)
        self.perform_updates(force=True)
        self.perform_creates(force=True)

    def perform_deletes(self, force: bool = False) -> None:
        """Delete all queued keys with one OR-ed query, if a flush is due.

        A failure is logged and the whole batch is counted in ``self.errors``.
        The queue is cleared in either case.
        """
        if not self._is_required(force, self.to_delete):
            return None

        sync_fields = list(self.data_generator.get_sync_fields())
        query = models.Q()
        for sync_key in self.to_delete:
            query |= models.Q(**dict(zip(sync_fields, sync_key)))

        try:
            # Savepoint: a DB error must not poison an enclosing transaction.
            with transaction.atomic():
                self.data_generator.get_queryset().filter(query).delete()
            self.deleted += len(self.to_delete)
        except Exception:
            self.errors += len(self.to_delete)
            self.logger.exception('Bulk delete failed')

        self.to_delete = {}

    def perform_updates(self, force: bool = False) -> None:
        """Apply queued diffs via ``execute_bulk_update``, if a flush is due.

        Diffs are grouped by their exact set of field names, because one bulk
        update statement needs the same columns for every row. Each group is
        written separately. A failing group is logged and counted in
        ``self.errors``. Processed keys are removed from the queue either way.
        """
        if not self._is_required(force, self.to_update):
            return None

        sync_fields = list(self.data_generator.get_sync_fields())
        now = datetime.now()
        # sorted field names -> {sync_key: row values}
        update_groups: Dict[Tuple[str, ...], Dict[Tuple, Dict[str, Any]]] = {}

        for sync_key, diff_data in self.to_update.items():
            diff_data['last_sync'] = now
            # Sync-key values identify the row to update in execute_bulk_update.
            diff_data.update(zip(sync_fields, sync_key))
            update_groups.setdefault(tuple(sorted(diff_data)), {})[sync_key] = diff_data

        model = self.data_generator.get_model()
        for all_fields, update_group in update_groups.items():
            # all_fields is sorted, so the column order is deterministic.
            updatable_fields = [field for field in all_fields if field not in sync_fields]
            values = [
                [diff_data[field] for field in updatable_fields + sync_fields]
                for diff_data in update_group.values()
            ]

            try:
                with transaction.atomic():
                    execute_bulk_update(model, sync_fields, updatable_fields, values)
                self.updated += len(update_group)
            except Exception:
                self.errors += len(update_group)
                self.logger.exception('Bulk update failed')

            for sync_key in update_group:
                del self.to_update[sync_key]

    def perform_creates(self, force: bool = False) -> None:
        """Create queued objects with ``bulk_create``, if a flush is due.

        On ``IntegrityError`` the batch is retried row by row (STAFF-17876).
        Any other failure is logged and the whole batch is counted in
        ``self.errors``. The queue is cleared in every case.
        """
        if not self._is_required(force, self.to_create):
            return None

        now = datetime.now()
        objects = []
        for sync_key, create_data in self.to_create.items():
            obj = self.data_generator.create_object(sync_key)
            create_data['last_sync'] = now
            self.set_data(obj, create_data)
            objects.append(obj)

        try:
            with transaction.atomic():
                self.data_generator.get_model().objects.bulk_create(objects)
            self.created += len(objects)
        except IntegrityError:
            # STAFF-17876: workaround for duplicate records coming from OEBS.
            # Save rows one at a time so one duplicate does not sink the batch.
            self.logger.exception('Bulk create failed, running STAFF-17876 w/a')
            self._save_one_by_one(objects)
        except Exception:
            self.errors += len(objects)
            self.logger.exception('Bulk create failed')

        self.to_create = {}

    def _save_one_by_one(self, objects: list) -> None:
        """Save each object in its own savepoint, counting successes and failures."""
        for obj in objects:
            try:
                with transaction.atomic():
                    obj.save()
                self.created += 1
            except Exception:
                self.errors += 1
                self.logger.exception('STAFF-17876 w/a failed')

    def _is_required(self, force: bool, action_data: dict) -> bool:
        """Return True if ``action_data`` is non-empty and is forced or at batch size."""
        return bool(action_data) and (force or len(action_data) >= self.batch_size)


class OEBSPlacementDataDiffMerger(OEBSDataDiffMerger):
    """Diff merger for placements: queued deletes become status updates.

    Rows are not removed. Every queued delete key is re-queued as an update
    that sets ``active_status`` to ``ACTIVE_STATUS[0]``. Presumably that is
    the "inactive" value; confirm in staff.oebs.constants. The delete queue
    is emptied and the update path decides whether to flush now.
    """

    def perform_deletes(self, force: bool = False) -> None:
        # The generator builds a fresh dict per key; the update path mutates
        # each diff in place, so the entries must not share one dict.
        self.to_update.update(
            (deleted_key, {'active_status': ACTIVE_STATUS[0]})
            for deleted_key in self.to_delete
        )
        self.to_delete = {}
        super().perform_updates(force)


class OEBSUpdater(Updater):
    """Updater for OEBS sources that logs OEBS failures before propagating them."""

    def run_sync(self):
        try:
            sync_result = super().run_sync()
        except OEBSError:
            # Record which source failed (with traceback), then let callers handle it.
            source_type = self.get_source_type()
            self.logger.info('%s sync failed for some reason', source_type, exc_info=True)
            raise
        return sync_result
