import logging
import time
import threading
from multiprocessing.dummy import Pool

from django.apps import apps
from django.core.management.base import BaseCommand
from django.db import transaction

from intranet.femida.src.wf.models import WFModelMixin
from intranet.femida.src.permissions.helpers import get_manager
from intranet.femida.src.communications.models import Message
from intranet.femida.src.utils.itertools import get_chunks


logger = logging.getLogger(__name__)
# NOTE(review): this lock is not used anywhere in the visible part of this
# module; kept in case other code imports it.
lock = threading.Lock()


def get_saver(manager):
    """Build a callback that reformats and saves one object of ``manager``'s model.

    The returned ``saver(_id)`` locks the row with primary key ``_id``,
    re-renders its wiki fields and saves only those fields. It is intended to
    be mapped over a list of ids by a worker pool.

    :param manager: model manager used to fetch the objects.
    :return: callable ``saver(_id) -> int`` returning 1 on success and 0 on
        any error; errors are logged, never raised.
    """

    def saver(_id):
        try:
            # The atomic block is entered *inside* the try: if a query fails,
            # atomic() rolls the transaction back while the exception
            # propagates, and only then is the error swallowed below.
            # Catching the exception inside atomic() would make Django try to
            # commit a broken transaction.
            with transaction.atomic():
                # Lock the row for the duration of the transaction so that
                # concurrent writers do not interleave with the reformat.
                instance = manager.select_for_update().get(id=_id)
                instance.format_wiki_fields(force=True, timeout=15)
                # NOTE(review): ignore_wiki_format=True presumably stops save()
                # from formatting again, and WIKI_FIELDS_MAP values presumably
                # name the rendered fields — confirm against WFModelMixin.
                instance.save(
                    ignore_wiki_format=True,
                    update_fields=instance.WIKI_FIELDS_MAP.values(),
                )
        except Exception:
            logger.exception('Error during format %s with id %d', manager.model.__name__, _id)
            return 0
        return 1

    return saver


class Command(BaseCommand):
    """Re-render the wiki-formatted fields of stored objects.

    Objects are processed in descending id order, in chunks. Each chunk is
    fanned out over a thread pool (``multiprocessing.dummy.Pool``), where every
    object is handled by the callback returned from ``get_saver``.
    """

    help = 'Reformat all wiki-fields.'

    def add_arguments(self, parser):
        # ``app_label.ModelName`` passed to apps.get_model(); when omitted,
        # every installed model subclassing WFModelMixin is processed.
        parser.add_argument('--model')
        # Only objects with pk <= max-pk are processed; 0 disables the filter.
        parser.add_argument('--max-pk', action='store', type=int, default=0)
        parser.add_argument('--chunk-size', action='store', type=int, default=5000)
        parser.add_argument('--pool-size', action='store', type=int, default=12)
        # Keep processing a model after a chunk that had failures.
        parser.add_argument('--ignore-errors', action='store_true')

    def handle_model(self, model, max_pk, chunk_size, ignore_errors=False):
        """Reformat the wiki fields of the objects of ``model``.

        Uses the pool stored in ``self.pool``.

        :param model: model class whose objects are reformatted.
        :param max_pk: upper bound on primary keys; falsy means no bound.
        :param chunk_size: number of ids submitted to the pool at once.
        :param ignore_errors: when False, stop processing this model after the
            first chunk in which at least one object failed.
        """
        manager = get_manager(model, unsafe=True)
        ids = manager.order_by('-id').values_list('id', flat=True)
        if max_pk:
            ids = ids.filter(pk__lte=max_pk)
        if model is Message:
            # Messages flagged with ignore_wiki_format are left untouched.
            ids = ids.exclude(ignore_wiki_format=True)
        ids = list(ids)

        saver = get_saver(manager)
        start = time.time()

        self.stdout.write('Start reformatting {} of {}'.format(len(ids), model.__name__))
        for i, chunk in enumerate(get_chunks(ids, chunk_size)):
            # Blocks until every id in the chunk has been processed; each
            # result is 1 (success) or 0 (failure).
            results = self.pool.map(saver, chunk)
            success_count = sum(results)
            current_chunk_size = len(chunk)
            if success_count < current_chunk_size:
                self.stdout.write('{} of {} was formatted in chunk #{}'.format(
                    success_count, current_chunk_size, i + 1
                ))
                if not ignore_errors:
                    break
            self.stdout.write('{} of {} was handled in {} seconds'.format(
                i * chunk_size + current_chunk_size, model.__name__, time.time() - start
            ))

    def handle(self, *args, **kwargs):
        model_name = kwargs.get('model')
        # int() is kept although argparse already converts these options:
        # keyword options given to call_command() may bypass argparse's type.
        max_pk = int(kwargs.get('max_pk'))
        pool_size = int(kwargs.get('pool_size'))
        chunk_size = int(kwargs.get('chunk_size'))
        ignore_errors = kwargs.get('ignore_errors')

        # Resolve models before starting threads so a bad --model leaks nothing.
        if model_name:
            wiki_models = [apps.get_model(model_name)]
        else:
            wiki_models = [m for m in apps.get_models() if issubclass(m, WFModelMixin)]

        # handle_model() reads the pool from self.pool.
        self.pool = Pool(pool_size)
        try:
            for model in wiki_models:
                self.handle_model(model, max_pk, chunk_size, ignore_errors)
        finally:
            # On the normal path every map() has already returned, so nothing
            # is pending. After an exception (e.g. Ctrl-C), terminate() discards
            # queued ids instead of processing them. join() then waits for the
            # worker threads to exit.
            self.pool.terminate()
            self.pool.join()
