import gzip
import json
import logging
import time
from collections import defaultdict

from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist

from concurrent.futures import TimeoutError
import kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api as pqlib
import kikimr.public.sdk.python.persqueue.auth as auth
from kikimr.public.sdk.python.persqueue.errors import SessionFailureResult
from ylog.context import log_context

from intranet.search.core import celery, storages
from intranet.search.core.sources.st.utils import is_issue_in_blacklist
from intranet.search.core.swarm.pushes import log_push, handle_single_push, handle_indexation_push
from intranet.search.core.utils.config import get_default_service
from intranet.search.core.utils import log_time
from intranet.search.core.tvm import tvm2_client


log = logging.getLogger(__name__)
CLOUD_TYPES = ('cloud', 'cloud_partner')


@celery.global_app.task(name='isearch.tasks.listen_tracker', bind=True)
def listen_tracker(self):
    """ Celery task: read Tracker messages from Logbroker and start the matching indexations.

    The handler and consumer are built from the ``tracker`` section of
    ``settings.ISEARCH['logbroker']``.
    """
    with log_context(task_id=self.request.id, task_type='logbroker'):
        logbroker_settings = settings.ISEARCH['logbroker']
        log.info('Start listen_tracker. Config: %s', logbroker_settings)

        tracker_settings = logbroker_settings['tracker']
        push_handler = TrackerTaskLogbrokerHandler(**tracker_settings['handler'])
        LogbrokerConsumer(**tracker_settings['consumer']).read(push_handler)


@celery.global_app.task(name='isearch.tasks.listen_docs', bind=True)
def listen_docs(self):
    """ Celery task: read docs messages from Logbroker and start the matching indexations.

    Does nothing (apart from logging the start) when the ``docs`` section of
    ``settings.ISEARCH['logbroker']`` is missing or empty.
    """
    with log_context(task_id=self.request.id, task_type='logbroker'):
        logbroker_settings = settings.ISEARCH['logbroker']
        log.info('Start listen_docs. Config: %s', logbroker_settings)

        docs_settings = logbroker_settings.get('docs')
        if docs_settings:
            push_handler = DocsLogbrokerHandler(**docs_settings['handler'])
            LogbrokerConsumer(**docs_settings['consumer']).read(push_handler)


class LogbrokerConsumer:
    """ Reads messages from Logbroker topics and feeds them to a handler object. """

    def __init__(self, host, port, client, topics, tvm_destination='logbroker', timeout=10):
        # Timeout (passed to ``Future.result``) used for every blocking wait on the API.
        self.timeout = timeout
        self.api = pqlib.PQStreamingAPI(host, port)
        configurator_options = dict(
            topics=topics,
            client_id=client,
            read_infly_count=3,
            use_client_locks=True,
        )
        self.configurator = pqlib.ConsumerConfigurator(**configurator_options)
        self.auth = auth.TVMCredentialsProvider(tvm2_client, destination_alias=tvm_destination)

    def read(self, handler, messages_limit=1000, flush_limit=300):
        """ Read the Logbroker message stream and pass every message to ``handler``.

        :param handler: handler object; must provide ``process`` and ``flush`` methods
        (its ``partition`` attribute is set when a partition is assigned).
        :param messages_limit: soft limit on the number of messages to process.
        The limit is soft because messages arrive in batches and a batch can only be
        committed as a whole: if a batch holds more messages than the limit allows,
        all of them are still processed, and the next batch is left for the next run.
        :param flush_limit: ``handler.flush`` is called every time the remaining limit
        becomes a multiple of this value.
        """
        consumer = self._start_consumer()
        received_cookie = None
        committed_cookie = None
        events_seen = 0
        remaining = messages_limit
        event_kinds = pqlib.ConsumerMessageType

        try:
            # Keep going while there is budget left or a sent commit is not yet confirmed.
            while remaining > 0 or received_cookie != committed_cookie:
                try:
                    event = consumer.next_event().result(timeout=self.timeout)
                except TimeoutError:
                    log.info('Does not have any unread messages. Stop reading.')
                    if events_seen > 0:
                        log.warning('Disconnecting from the working handler')
                    break
                events_seen += 1

                kind = event.type
                if kind == event_kinds.MSG_COMMIT:
                    log.info('Message commited successfully: %s', event.message)
                    committed_cookie = event.message.commit.cookie[-1]
                elif kind == event_kinds.MSG_LOCK:
                    event.ready_to_read()
                    lock = event.message.lock
                    log.info('Got partition assignment: topic %s, partition %s',
                             lock.topic, lock.partition)
                    handler.partition = event.message.lock.partition
                elif kind == event_kinds.MSG_RELEASE:
                    release = event.message.release
                    log.info('Partition revoked. Topic %s, partition %s',
                             release.topic, release.partition)
                elif kind == event_kinds.MSG_DATA:
                    data = event.message.data
                    log.info('Read batch %s, messages_limit %s', data.cookie, remaining)
                    remaining = self._process_data(consumer, handler, data, remaining, flush_limit)

                    # Flush whatever is buffered before acknowledging the batch.
                    handler.flush()
                    consumer.commit([data.cookie])
                    received_cookie = data.cookie
                    log.info('Commit result %s, messages_limit %s', data.cookie, remaining)
        finally:
            log.info('Stopping consumer')
            consumer.stop()

    def _process_data(self, consumer, handler, data, remaining, flush_limit):
        """ Hand every message of one data event to ``handler``; return the updated budget. """
        for batch in data.message_batch:
            for message in batch.message:
                handler.process(message)
                remaining -= 1
                if remaining == 0:
                    # Budget exhausted: tell the server no more reads are wanted.
                    consumer.reads_done()
                if remaining % flush_limit == 0:
                    log.info('Flush messages, messages_limit %s', remaining)
                    handler.flush()
        return remaining

    def _start_consumer(self):
        """ Start the API and a consumer session; raise RuntimeError if the session fails. """
        api_result = self.api.start().result(timeout=self.timeout)
        log.info('Logbroker api started with result: %s', api_result)

        consumer = self.api.create_consumer(self.configurator, self.auth)
        start_result = consumer.start().result(timeout=self.timeout)

        if isinstance(start_result, SessionFailureResult):
            raise RuntimeError(f'Error occurred on start of consumer: {start_result}')
        if not start_result.HasField('init'):
            raise RuntimeError(f'Bad consumer start result from server: {start_result}')

        log.info('Consumer started with result: %s', start_result)
        log.debug('Consumer started')
        return consumer


class TrackerTaskLogbrokerHandler:
    """ Handler for Tracker (Startrek) events received from Logbroker.

    ``process`` collects events into a buffer keyed by organization, object type and
    object id (a later event for the same object replaces the earlier one);
    ``flush`` turns the buffered events into indexation pushes and empties the buffer.
    """
    QUEUES_KEY = 'queues'
    ISSUES_KEY = 'issues'
    QUEUES_ISSUES_KEY = 'queues_issues'
    COMPONENTS_ISSUES_KEY = 'components_issues'

    def __init__(self, search='st', issue_index='', queue_index='queue',
                 organization_required=False):
        self.search = search
        self.issue_index = issue_index
        self.queue_index = queue_index
        self.organization_required = organization_required
        self.revision_storage = storages.RevisionStorage()
        # Set by the consumer when a partition is assigned; used only in log output.
        self.partition = None
        self._clean_buffer()

    def process(self, message):
        """ Handle a single Logbroker event.

        Events of unknown type, events without a resolvable organization (when
        ``organization_required`` is set) and events of cloud organizations are skipped.

        :param message: event object; ``message.data`` must hold a JSON document
        """
        data = json.loads(message.data)
        log.info('Got logborker message. Meta: %s, data: %s', message.meta, data)

        event_type = data['type']
        if not self.is_known_event(event_type):
            log.info('Skip unknown message: %s', data)
            return

        organization_id = self.get_organization_id(data)
        if not organization_id and self.organization_required:
            log.error('Got event without known organization: event=%s, orgId=%s',
                      data['meta'].get('eventId'), data['meta'].get('orgId'))
            return

        # The organization may be unresolved here (it is optional for this handler),
        # so only look it up and inspect its type when there is something to look up.
        organization = None
        if organization_id is not None:
            organization = storages.OrganizationStorage().get_or_none(organization_id)
        if organization is not None and organization['organization_type'] in CLOUD_TYPES:
            log.info('Got event from the cloud organization: orgId=%s skipping', data['meta'].get('orgId'))
            return

        if event_type.startswith('Issue'):
            self.process_issue(data, organization_id)
        elif event_type.startswith('Queue'):
            self.process_queue(data, organization_id)
        elif event_type.startswith('Component'):
            self.process_component(data, organization_id)

    def flush(self):
        """ Run indexation for everything collected in the buffer, then empty the buffer.
        """
        log.debug('Collect objects: %s', self._buffer)
        log.info('Flushing objects')
        start = time.time()
        issues_num = components_num = queues_issues_num = queues_num = 0
        for organization_id, data in self._buffer.items():
            issues = data.get(self.ISSUES_KEY, {})
            issues_num += len(issues)
            queues_issues = data.get(self.QUEUES_ISSUES_KEY, {})
            queues_issues_num += len(queues_issues)
            components_issues = data.get(self.COMPONENTS_ISSUES_KEY, {})
            components_num += len(components_issues)
            self.run_issues_indexation(organization_id, issues=issues,
                                       queues=queues_issues, components=components_issues)
            queues = data.get(self.QUEUES_KEY, {})
            queues_num += len(queues)
            self.run_queues_indexation(organization_id, queues=queues)

        # The per-type counters are already totals over all organizations,
        # so the overall number is computed once, after the loop.
        obj_num = issues_num + queues_issues_num + components_num + queues_num

        self._clean_buffer()
        end = time.time()
        log.info('Finished flushing objects. Partition: %s. Time spent: %s, collected objects: %s, '
                 'issues: %s, queues_issues: %s, components: %s, queues: %s',
                 self.partition, end - start, obj_num, issues_num, queues_issues_num,
                 components_num, queues_num)

    def process_issue(self, data, organization_id):
        """ Buffer an issue event, keyed by issue key, so the same issue is not
        indexed several times within one flush.
        """
        log.debug('Handle issue push for organization %s', organization_id)
        self._save_to_buffer(organization_id, self.ISSUES_KEY, data['meta']['issueKey'], data)

    def process_queue(self, data, organization_id):
        """ Buffer a queue event.

        The queue's issues are scheduled for reindexing only when the queue name or
        permissions changed; the queue itself is buffered whenever ``queue_index`` is set.
        """
        if {'name', 'permissions'} & self.changed_fields(data):
            self._save_to_buffer(organization_id, self.QUEUES_ISSUES_KEY, data['meta']['key'], data)
        if self.queue_index:
            self._save_to_buffer(organization_id, self.QUEUES_KEY, data['meta']['key'], data)

    def process_component(self, data, organization_id):
        """ Buffer a component event when an existing component's name or permissions
        changed (created/deleted events are ignored).
        """
        if (data['type'] not in ('ComponentDeleted', 'ComponentCreated')
                and {'name', 'permissions'} & self.changed_fields(data)):
            self._save_to_buffer(organization_id, self.COMPONENTS_ISSUES_KEY, data['subject']['id'], data)

    def _save_to_buffer(self, organization_id, object_type, object_id, data):
        """ Store ``data`` under organization / object type / object id, replacing older data. """
        self._buffer[organization_id].setdefault(object_type, {})[object_id] = data

    def _clean_buffer(self):
        """ Reset the buffer to an empty state. """
        self._buffer = defaultdict(dict)

    def changed_fields(self, data):
        """ Return the set of names listed in ``data['meta']['fields']`` (empty if absent). """
        fields = data['meta'].get('fields') if 'meta' in data else None
        return set(fields or ())

    def get_organization_id(self, data):
        """ Resolve our organization id from the directory id given in the event.

        :return: organization id, or None when the event names an unknown organization
        """
        storage = storages.OrganizationStorage()
        directory_id = data['meta'].get('orgId')
        if directory_id:
            try:
                org = storage.get_by_directory_or_label(data['meta']['orgId'])
            except ObjectDoesNotExist:
                return None
        else:
            # the internal search does not specify an organization
            org = storage.get_by_directory_or_label('yandex')
        return org['id']

    def get_revisions(self, organization_id):
        """ Return the actual revisions of the issue index for the organization.

        The result is cached per organization for the lifetime of the handler. When no
        revision exists an active one is created.
        """
        if not hasattr(self, '_revision_cache'):
            self._revision_cache = {}

        if organization_id not in self._revision_cache:
            params = dict(search=self.search, index=self.issue_index, backend='platform',
                          organization_id=organization_id)
            revisions = self.revision_storage.get_actual(**params)
            if not revisions:
                service = get_default_service(params['search'], params['index'])
                self.revision_storage.create(status='active', service=service, **params)
                # NOTE(review): the (empty) lookup result is cached, not the revision
                # just created - confirm that callers cope with empty revisions.
            self._revision_cache[organization_id] = revisions

        return self._revision_cache[organization_id]

    def is_known_event(self, event_type):
        """ Tell whether the event type is one this handler processes.
        """
        return event_type.startswith(('Issue', 'Component', 'Queue'))

    def run_issues_indexation(self, organization_id, queues, components, issues):
        """ Index the changed issues one by one, then start one batch reindexing of
        the issues that belong to the updated queues and components.
        """
        logging_time = 0
        handling_time = 0

        if issues:
            revisions = self.get_revisions(organization_id)
            for issue, data in issues.items():

                if is_issue_in_blacklist(issue):
                    log.warning('Issue %s is in black list. Skip indexing.', issue)
                    continue

                with log_time() as time_logger:
                    push_id = log_push(data, self.search, self.issue_index,
                                       push_type=data['type'],
                                       organization_id=organization_id)
                logging_time += time_logger.duration

                with log_time() as time_logger:
                    handle_single_push(push_id, data=data, search=self.search, index=self.issue_index,
                                       organization_id=organization_id, revisions=revisions)
                handling_time += time_logger.duration

        log.info('Logging time: %s, handling time: %s', logging_time, handling_time)
        if not queues and not components:
            log.debug('Nothing to run: empty queues and components')
            return

        # Component ids are prefixed with '#' in the key list; queue keys are passed as is.
        keys = list(queues) + ['#%s' % c for c in components]
        data = list(queues.values()) + list(components.values())
        with log_time() as time_logger:
            push_id = log_push({'keys': data}, self.search, self.issue_index,
                               push_type='batch_queues_components_push',
                               organization_id=organization_id)
        batch_logging_time = time_logger.duration

        with log_time() as time_logger:
            handle_indexation_push(push_id, search=self.search, index=self.issue_index,
                                   organization_id=organization_id, keys=keys, user='logbroker')
        batch_handling_time = time_logger.duration
        log.info('Batch queue logging time: %s, handling time: %s', batch_logging_time, batch_handling_time)

    def run_queues_indexation(self, organization_id, queues):
        """ Log and handle one push per changed queue in the queue index.

        Does nothing when there are no queues or ``queue_index`` is not configured.
        """
        if not queues or not self.queue_index:
            log.debug('Not run queues indexation: queue_index=%s, queues=%s',
                      self.queue_index, queues)
            return

        for data in queues.values():
            push_id = log_push(data, self.search, self.queue_index,
                               push_type=data['type'],
                               organization_id=organization_id)
            handle_single_push(push_id, search=self.search, index=self.queue_index,
                               data=data, organization_id=organization_id)


class DocsLogbrokerHandler:
    """ Handler for docs pushes received from Logbroker.

    ``process`` collects pushes into a buffer keyed by organization and document id
    (a later push for the same id replaces the earlier one); ``flush`` runs indexation
    for the buffered pushes and empties the buffer.
    """

    def __init__(self, search='doc', index='external'):
        self.search = search
        self.index = index
        self.revision_storage = storages.RevisionStorage()
        # Set by the consumer when a partition is assigned; used only in log output.
        self.partition = None
        self._clean_buffer()

    def process(self, message):
        """ Handle a single Logbroker event.

        :param message: event object; ``message.data`` holds a JSON document,
        gzip-compressed when ``message.meta.codec`` equals 1
        """
        data = message.data
        if message.meta.codec == 1:
            # if the codec is GZIP we have to decompress the data first
            data = gzip.decompress(data)
        data = json.loads(data)
        log.info('Got logborker message. Meta: %s, data: %s', message.meta, data)

        if not self._assert_data_fmt_is_correct(data):
            log.info('Skip unknown message: %s', data)
            return

        organization_id = self.get_organization_id()
        self.process_push(data, organization_id)

    def flush(self):
        """ Run indexation for everything collected in the buffer, then empty the buffer.
        """
        log.debug('Collect objects: %s', self._buffer)
        log.info('Flushing objects')
        obj_num, start = 0, time.time()
        for organization_id, id2pushes in self._buffer.items():
            self.run_indexation(organization_id, id2pushes)
            obj_num += len(id2pushes)
        self._clean_buffer()
        end = time.time()
        log.info(
            'Finished flushing objects. Partition: %s. Time spent: %s, collected objects: %s',
            self.partition, end - start, obj_num,
        )

    @staticmethod
    def get_organization_id():
        """ Return the id of the organization stored under the 'yandex' label. """
        org = storages.OrganizationStorage().get_by_directory_or_label('yandex')
        return org['id']

    def get_revisions(self, organization_id):
        """ Return the actual revisions of this handler's index for the organization.

        The lookup uses the handler's own ``search``/``index`` and the requested
        organization, so the cache key and the query always agree. The result is cached
        per organization; when no revision exists an active one is created.
        """
        if not hasattr(self, '_revision_cache'):
            self._revision_cache = {}

        if organization_id not in self._revision_cache:
            params = dict(
                search=self.search,
                index=self.index,
                backend='platform',
                organization_id=organization_id,
            )
            revisions = self.revision_storage.get_actual(**params)
            if not revisions:
                service = get_default_service(params['search'], params['index'])
                self.revision_storage.create(status='active', service=service, **params)
                # NOTE(review): the (empty) lookup result is cached, not the revision
                # just created - confirm that callers cope with empty revisions.
            self._revision_cache[organization_id] = revisions

        return self._revision_cache[organization_id]

    def run_indexation(self, organization_id, id2pushes):
        """ Log and handle every buffered push of one organization.

        Pushes with a truthy ``is_deleted`` are handled with ``action='delete'``.

        :param id2pushes: mapping of document id to push data
        """
        logging_time = 0
        handling_time = 0

        if id2pushes:
            revisions = self.get_revisions(organization_id)
            for push in id2pushes.values():

                with log_time() as time_logger:
                    push_id = log_push(
                        push, self.search, self.index, push_type='docs_push', organization_id=organization_id
                    )
                logging_time += time_logger.duration

                with log_time() as time_logger:
                    kwargs = {}
                    if push['is_deleted']:
                        kwargs['action'] = 'delete'
                    handle_single_push(
                        push_id, data=push, search=self.search, index=self.index,
                        organization_id=organization_id, revisions=revisions, **kwargs,
                    )
                handling_time += time_logger.duration

        log.info('Logging time: %s, handling time: %s', logging_time, handling_time)

    @staticmethod
    def _assert_data_fmt_is_correct(data):
        """ Tell whether ``data`` has all the keys required of a docs push. """
        return {'id', 'url', 'apiurl', 'is_deleted'}.issubset(data.keys())

    def _clean_buffer(self):
        """ Reset the buffer to an empty state. """
        self._buffer = defaultdict(dict)

    def process_push(self, data, organization_id):
        """ Buffer the push under its document id. """
        self._save_to_buffer(organization_id, data['id'], data)

    def _save_to_buffer(self, organization_id, object_id, data):
        """ Store ``data`` under organization / object id, replacing older data. """
        self._buffer[organization_id][object_id] = data
