import json
from concurrent.futures import TimeoutError

import kikimr.public.sdk.python.persqueue.auth as auth
import kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api as pqlib
import tvmauth
from kikimr.public.sdk.python.persqueue.errors import SessionFailureResult

from intranet.yandex_directory.src import settings
from intranet.yandex_directory.src.yandex_directory import app
from intranet.yandex_directory.src.yandex_directory.common.db import get_meta_connection, get_main_connection, get_shard
from intranet.yandex_directory.src.yandex_directory.core.actions import (
    action_domain_add,
    action_domain_delete,
    action_domain_master_modify,
)
from intranet.yandex_directory.src.yandex_directory.core.models import OrganizationMetaModel
from intranet.yandex_directory.src.yandex_directory.core.task_queue import Task
from intranet.yandex_directory.src.yandex_directory.directory_logging.logger import log
from intranet.yandex_directory.src.yandex_directory.core.tasks import UpdateEmailFieldsTask


class DomainEventProcessor:
    """Dispatch domain events read from Logbroker to handler methods.

    A message body is a JSON object with a ``name`` key (the event name) and a
    ``data`` key (a mapping). The event name selects a method of this class
    with the same name, and ``data`` is passed to it as keyword arguments.
    """

    # Event names this processor accepts; anything else is logged as an error
    # and skipped. Accepted names that have no handler method on this class
    # (currently 'domain_occupied' and 'domain_changed') are logged and skipped.
    allowed_events = [
        'domain_occupied',
        'domain_deleted',
        'domain_added',
        'domain_changed',
        'domain_master_changed',
    ]

    def process_message(self, message):
        """Decode one Logbroker message and process the event it carries.

        :param message: Logbroker message whose ``data`` attribute is a JSON
            document with ``name`` and ``data`` keys.
        """
        payload = json.loads(message.data)
        self.process_event(payload['name'], payload['data'])

    def process_event(self, event_name, event_data):
        """Call the handler method named ``event_name`` with ``event_data``.

        :param event_name: event name; must be listed in ``allowed_events``.
        :param event_data: mapping passed to the handler as keyword arguments.
        """
        log.info('Processing event %s with data %s', event_name, event_data)
        if event_name not in self.allowed_events:
            log.error('Unknown event: %s, data: %s', event_name, event_data)
            return
        handler = getattr(self, event_name, None)
        if handler is None:
            # The event is expected but this processor has nothing to do for
            # it; skipping avoids an AttributeError that would abort the whole
            # Logbroker read loop.
            log.info('No handler for event %s, skipping', event_name)
            return
        handler(**event_data)

    def get_shard_by_org_id(self, org_id):
        """Return the shard of organization ``org_id`` from the meta database."""
        with get_meta_connection() as meta_connection:
            org_meta = OrganizationMetaModel(meta_connection).get(org_id)
            shard = org_meta['shard']
        return shard

    def domain_added(self, domain, org_id, author_id):
        """Record that ``domain`` was added to organization ``org_id``.

        :param domain: domain name from the event payload.
        :param org_id: organization the domain was added to.
        :param author_id: author of the change, passed to the action as-is.
        """
        # Keep the log fields bound while the action runs so that anything
        # logged inside it carries the same context.
        with log.fields(org_id=org_id, domain=domain, author_id=author_id):
            log.info('Domain added for org_id %s: %s', org_id, domain)
            shard = self.get_shard_by_org_id(org_id)
            with get_main_connection(for_write=True, shard=shard) as main_connection:
                action_domain_add(
                    main_connection,
                    org_id=org_id,
                    author_id=author_id,
                    object_value=domain,
                    old_object=None,
                )

    def domain_deleted(self, domain, org_id, author_id):
        """Record that ``domain`` was deleted from organization ``org_id``.

        :param domain: domain name from the event payload.
        :param org_id: organization the domain was removed from.
        :param author_id: author of the change, passed to the action as-is.
        """
        with log.fields(org_id=org_id, domain=domain, author_id=author_id):
            log.info('Domain deleted for org_id %s: %s', org_id, domain)
            shard = self.get_shard_by_org_id(org_id)
            with get_main_connection(for_write=True, shard=shard) as main_connection:
                action_domain_delete(
                    main_connection,
                    org_id=org_id,
                    author_id=author_id,
                    object_value=domain,
                    old_object=domain,
                )

    def domain_master_changed(self, domain, org_id, admin_uid, old_master_domain, old_master_admin_uid):
        """Handle a change of an organization's master domain.

        Schedules ``UpdateEmailFieldsTask`` for the new master domain and
        records the change as an action authored by ``admin_uid``.

        :param domain: new master domain.
        :param org_id: organization whose master domain changed.
        :param admin_uid: author of the change, passed to the action.
        :param old_master_domain: previous master domain, stored as the
            action's old object.
        :param old_master_admin_uid: part of the event payload; accepted so the
            payload can be passed as keyword arguments, otherwise unused.
        """
        with log.fields(org_id=org_id, domain=domain, author_id=admin_uid):
            log.info(
                'Master domain changed for org_id %s: %s -> %s',
                org_id, old_master_domain, domain,
            )
            shard = self.get_shard_by_org_id(org_id)
            # Enqueuing a task and recording an action presumably both write,
            # so request a write connection like the other handlers do.
            with get_main_connection(for_write=True, shard=shard) as main_connection:
                UpdateEmailFieldsTask(main_connection).delay(
                    master_domain=domain,
                    org_id=org_id,
                )

                action_domain_master_modify(
                    main_connection,
                    org_id=org_id,
                    author_id=admin_uid,
                    object_value=domain,
                    old_object=old_master_domain,
                )


class ReadDomainUpdatesFromLogbrokerTask(Task):
    """Read domain events from Logbroker and apply them.

    Each run does the following:
    - reads up to ``messages_per_run`` messages from
      ``settings.LOGBROKER_TOPIC``;
    - passes every message to ``DomainEventProcessor``;
    - commits each read result;
    - waits until the last commit is acknowledged before stopping.

    The run also ends early when no event arrives within
    ``logbroker_timeout`` seconds.
    """
    singleton = True

    # After this many messages the consumer signals ``reads_done``; messages
    # that were already delivered are still processed and committed.
    messages_per_run = 1000

    # Seconds to wait for each Logbroker future (API start, consumer start,
    # next event).
    logbroker_timeout = 10

    def do(self):
        api = pqlib.PQStreamingAPI(settings.LOGBROKER_ENDPOINT, settings.LOGBROKER_PORT)
        # ``finally`` guarantees the API (and, once started, the consumer) is
        # stopped on every exit path: start failure, timeouts, or an exception
        # raised while processing a message.
        try:
            api.start().result(timeout=self.logbroker_timeout)
            consumer = self._create_logbroker_consumer(api)
            start_result = consumer.start().result(timeout=self.logbroker_timeout)
            if isinstance(start_result, SessionFailureResult):
                log.error('Logbroker consumer could not start')
                return
            try:
                self._consume_domain_events(consumer)
            finally:
                consumer.stop()
        finally:
            api.stop()

    def _create_logbroker_consumer(self, api):
        """Build a TVM-authenticated consumer for the configured topic."""
        tvm_settings = tvmauth.TvmApiClientSettings(
            self_client_id=int(app.config['TVM_CLIENT_ID']),
            self_secret=app.config['TVM_SECRET'],
            dsts={'logbroker': app.config['LOGBROKER_TVM_ID']}
        )
        tvm_client = tvmauth.TvmClient(tvm_settings)
        credentials_provider = auth.TVMCredentialsProvider(tvm_client=tvm_client, destination_alias='logbroker')
        configurator = pqlib.ConsumerConfigurator(settings.LOGBROKER_TOPIC, settings.LOGBROKER_CONSUMER_NAME)
        return api.create_consumer(configurator, credentials_provider=credentials_provider)

    def _consume_domain_events(self, consumer):
        """Process events until the quota is reached and all commits are acked.

        Returns early (without waiting for outstanding commit
        acknowledgements) when no event arrives within ``logbroker_timeout``.
        """
        processor = DomainEventProcessor()
        messages_left = self.messages_per_run
        reads_done_sent = False
        last_received_cookie = None
        last_committed_cookie = None

        while messages_left > 0 or last_received_cookie != last_committed_cookie:
            try:
                event = consumer.next_event().result(timeout=self.logbroker_timeout)
            except TimeoutError:
                log.info('Timeout during logbroker read, no data to read')
                return

            if event.type == pqlib.ConsumerMessageType.MSG_DATA:
                read_result = event.message.data
                for batch in read_result.message_batch:
                    for message in batch.message:
                        processor.process_message(message)
                        messages_left -= 1
                        # Signal the end of reading only once, as soon as the
                        # quota is reached.
                        if messages_left <= 0 and not reads_done_sent:
                            consumer.reads_done()
                            reads_done_sent = True
                consumer.commit(read_result.cookie)
                # Remember the newest cookie so the loop keeps running until
                # its commit is acknowledged.
                last_received_cookie = read_result.cookie
            elif event.type == pqlib.ConsumerMessageType.MSG_COMMIT:
                last_committed_cookie = event.message.commit.cookie[-1]
