# -*- coding: utf-8 -*-
import logging
import time

from passport.backend.core.builders.passport import (
    Passport,
    PassportAccountNotFoundError,
    PassportActionNotRequiredError,
)
from passport.backend.core.logbroker.logbroker_base import format_protobuf_safe
from passport.backend.core.logging_utils.loggers.statbox import StatboxLogger
from passport.backend.core.logging_utils.request_id import RequestIdManager
from passport.backend.logbroker_client.core.consumers.simple.native_worker import NativeLogbrokerWorker
from passport.backend.logbroker_client.core.handlers.protobuf import BaseProtobufHandler
from passport.backend.logbroker_client.takeout_tasks.exceptions import LogbrokerWriterError
from passport.backend.takeout.common.conf.services import get_service_configs
from passport.backend.takeout.common.logbroker import get_takeout_logbroker_writer
from passport.backend.takeout.common.utils import get_task_id
from passport.backend.takeout.logbroker_client.cleanup import (
    cleanup_task,
    send_cleanup,
)
from passport.backend.takeout.logbroker_client.download_user_data import send_extract_task_message
from passport.backend.takeout.logbroker_client.exception_handler import process_exception
from passport.backend.takeout.logbroker_client.extract_service_tasks import (
    async_extract_service_get_task,
    async_extract_service_start_task,
    async_upload_service_get_task,
    async_upload_service_start_task,
    send_extract_service_task_message,
    sync_extract_service_task,
)
from passport.backend.takeout.logbroker_client.make_archive import (
    check_touch_files_done,
    make_archive_task,
    send_make_archive,
)
from passport.backend.takeout.logbroker_client.resend_message import resend_message


log = logging.getLogger(__name__)


class TakeoutTaskHandler(BaseProtobufHandler):
    """Logbroker handler that executes takeout pipeline tasks.

    Every incoming protobuf message carries a ``task_base_message`` whose
    ``task_name`` selects the processor:

    * ``extract``         -> :meth:`process_download_user_data`
    * ``extract_service`` -> :meth:`process_extract_service_tasks`
    * ``make_archive``    -> :meth:`process_make_archive`
    * ``cleanup``         -> :meth:`process_cleanup`

    Expired messages are dropped, messages whose ``delay_until`` has not been
    reached yet are re-sent, and failed tasks are re-queued with a pause that
    grows geometrically with the retry count (see :meth:`get_task_pause`).
    Task outcomes are written to the statbox (``takeout-log``) logger.
    """

    handler_name = 'takeout_tasks'
    statbox = StatboxLogger(
        log_source='logbroker',
        tskv_format='takeout-log',
    )

    def __init__(self, config, **kwargs):
        """Initialise the handler from ``config``.

        Keys read here: ``explicit_partitions_in_writer`` (optional),
        ``passport.use_tvm``, ``archive.cooking_directory`` (optional),
        ``DEFAULT_TASK_PAUSE``, ``task_deltas.task_pauses``,
        ``TASK_MAX_PAUSE``, ``TASK_PAUSE_MULTIPLIER``, ``TASK_EXPIRE`` and
        ``delay_throttling_min_exec_time`` (optional, seconds, default 0).
        """
        super(TakeoutTaskHandler, self).__init__(config=config, **kwargs)
        if config.get('explicit_partitions_in_writer'):
            self.override_logbroker_writer_settings()

        self.use_tvm = config['passport']['use_tvm']

        # A missing/empty ``archive`` section or ``cooking_directory`` is normalised to None.
        archive_config = config.get('archive') or {}
        self.cooking_directory = archive_config.get('cooking_directory') or None

        self.default_pause = int(config['DEFAULT_TASK_PAUSE'])
        self.task_pauses = config['task_deltas']['task_pauses']
        self.max_pause = int(config['TASK_MAX_PAUSE'])
        self.pause_multiplier = int(config['TASK_PAUSE_MULTIPLIER'])
        self.expire_delta = int(config['TASK_EXPIRE'])
        self.delay_throttling_min_exec_time = float(config.get('delay_throttling_min_exec_time', 0.0))

    def override_logbroker_writer_settings(self):
        """Request the takeout writer for the worker's current host/partition group.

        The returned writer is discarded here; presumably
        ``get_takeout_logbroker_writer`` configures/caches it as a side
        effect -- TODO confirm.
        """
        task = NativeLogbrokerWorker.current_task
        get_takeout_logbroker_writer('takeout_tasks', task['host'], task['partition_group'])

    def is_message_delayed(self, message):
        """Return True while ``message.delay_until`` (non-zero) has not passed yet."""
        return message.delay_until != 0 and int(time.time()) <= message.delay_until

    def is_message_expired(self, message):
        """Return True once ``TASK_EXPIRE`` seconds have elapsed since ``message.unixtime``."""
        return int(time.time()) >= message.unixtime + self.expire_delta

    def get_task_pause(self, retries, service_type=None, step='start'):
        """Return the pause (seconds) before the next retry of a task.

        The base pause is ``task_pauses[service_type][step]`` when
        ``service_type`` is given, otherwise ``DEFAULT_TASK_PAUSE``. It is
        multiplied by ``TASK_PAUSE_MULTIPLIER ** retries`` and capped at
        ``TASK_MAX_PAUSE``.
        """
        pause = self.default_pause
        if service_type:
            # process_message passes ``step or None``, so a message without a
            # step would otherwise look up a ``None`` key and raise KeyError
            # in the middle of the retry path.
            pause = int(self.task_pauses[service_type][step or 'start'])

        return min(pause * self.pause_multiplier ** retries, self.max_pause)

    def process_download_user_data(
        self,
        task_id,
        uid,
        extract_id,
        unixtime,
        services=None,
        retries=0,
        max_retries=0,
        manual_takeout=None,
    ):
        """Run the ``extract`` task: fan out one ``extract_service`` message per service.

        On the first attempt (``retries == 0``) a ``make_archive`` message for
        all enabled services is queued first. Then a ``start`` step message is
        sent for every service returned by ``get_service_configs(services)``.
        Services whose message could not be sent are retried by re-queuing an
        ``extract`` task restricted to them; ``max_retries == 0`` means retry
        without limit.

        If every message was sent and this is not a manual takeout, passport
        is notified via ``takeout_start_extract`` and an ``ok`` statbox record
        is written. A manual takeout with every message sent needs nothing
        further here.

        :raises LogbrokerWriterError: if the ``make_archive`` message could not
            be sent on the first attempt.
        """
        service_configs = get_service_configs(services=services)

        if retries == 0:
            enabled_service_configs = {
                service_name: service_config
                for service_name, service_config in service_configs.items() if service_config.enabled
            }

            is_sent = send_make_archive(
                extract_id=extract_id,
                manual_takeout=manual_takeout,
                services=enabled_service_configs.keys(),
                task_id=get_task_id(),
                uid=uid,
                unixtime=unixtime,
                source='download',
            )
            if not is_sent:
                raise LogbrokerWriterError("Force error raise: send 'make_archive' message failed during extract task")

        failed_services = []
        for service_name, service_config in service_configs.items():
            is_sent = send_extract_service_task_message(
                task_id=get_task_id(),
                service_name=service_name,
                uid=uid,
                extract_id=extract_id,
                unixtime=unixtime,
                max_retries=max_retries,
                service_enabled=service_config.enabled,
                step='start',
                source='download',
            )
            if not is_sent:
                failed_services.append(service_name)

        if failed_services:
            if max_retries == 0 or retries < max_retries:
                # NOTE(review): unlike the other re-queue calls, the result of
                # send_extract_task_message is not checked here.
                send_extract_task_message(
                    delay_until=self.get_task_pause(retries) + int(time.time()),
                    extract_id=extract_id,
                    manual_takeout=manual_takeout,
                    max_retries=max_retries,
                    retries=retries + 1,
                    services=failed_services,
                    task_id=task_id,
                    uid=uid,
                    unixtime=unixtime,
                    source='download:failed_service',
                )
            else:
                # Not inside an ``except`` block, so log.exception would attach
                # a meaningless "NoneType: None" traceback; log a plain error.
                log.error('Task download user data failed and will not be retried anymore')
                self.statbox.log(
                    task_id=task_id,
                    uid=uid,
                    extract_id=extract_id,
                    task_name='extract',
                    status='failure',
                    archive_requested_unixtime=unixtime,
                    failed_services=', '.join(sorted(failed_services)),
                )
        elif not manual_takeout:
            try:
                Passport(use_tvm=self.use_tvm).takeout_start_extract(uid=uid)
            except (PassportActionNotRequiredError, PassportAccountNotFoundError) as e:
                log.debug('Ignoring error from passport: %s', e.__class__.__name__)

            self.statbox.log(
                task_id=task_id,
                uid=uid,
                extract_id=extract_id,
                task_name='extract',
                status='ok',
                archive_requested_unixtime=unixtime,
            )

    def _send_extract_service_get_step(self, uid, extract_id, service_name, unixtime, max_retries, job_id):
        """Queue the ``get`` step of a two-step (async) service extraction.

        :raises LogbrokerWriterError: if the message could not be written.
        """
        is_sent = send_extract_service_task_message(
            task_id=get_task_id(),
            service_name=service_name,
            step='get',
            uid=uid,
            extract_id=extract_id,
            unixtime=unixtime,
            max_retries=max_retries,
            job_id=job_id,
            source='extract',
        )
        if not is_sent:
            raise LogbrokerWriterError("Force error raise: send second step 'extract_service' message failed")

    def _process_extract_service(self, uid, extract_id, service_name, unixtime, step, job_id, max_retries):
        """Execute one step of a service extraction according to the service's config type.

        * ``sync`` -- call ``sync_extract_service_task`` (``step`` is ignored);
        * ``async`` / ``async_upload`` -- on ``step == 'start'`` start the job
          and queue a ``get`` step carrying the returned ``job_id``; on
          ``step == 'get'`` call the matching ``*_get_task`` function.

        Any other type/step combination does nothing.

        :raises LogbrokerWriterError: if the ``get`` step message could not be sent.
        """
        service_config = get_service_configs()[service_name]

        if service_config.type == 'sync':
            sync_extract_service_task(
                uid=uid,
                extract_id=extract_id,
                service_name=service_name,
                unixtime=unixtime,
            )

        elif service_config.type == 'async':
            if step == 'start':
                job_id = async_extract_service_start_task(
                    uid=uid,
                    extract_id=extract_id,
                    service_name=service_name,
                    unixtime=unixtime,
                )
                self._send_extract_service_get_step(
                    uid=uid,
                    extract_id=extract_id,
                    service_name=service_name,
                    unixtime=unixtime,
                    max_retries=max_retries,
                    job_id=job_id,
                )

            elif step == 'get':
                async_extract_service_get_task(
                    uid=uid,
                    extract_id=extract_id,
                    service_name=service_name,
                    job_id=job_id,
                )

        elif service_config.type == 'async_upload':
            if step == 'start':
                job_id = async_upload_service_start_task(
                    uid=uid,
                    extract_id=extract_id,
                    service_name=service_name,
                    unixtime=unixtime,
                )
                self._send_extract_service_get_step(
                    uid=uid,
                    extract_id=extract_id,
                    service_name=service_name,
                    unixtime=unixtime,
                    max_retries=max_retries,
                    job_id=job_id,
                )

            elif step == 'get':
                async_upload_service_get_task(
                    uid=uid,
                    extract_id=extract_id,
                    service_name=service_name,
                )

    def process_extract_service_tasks(self, task_id, uid, extract_id, service_name, unixtime, step=None, retries=0,
                                      max_retries=0, job_id=None):
        """Run the ``extract_service`` task and log its outcome to statbox.

        On any exception the error is passed to ``process_exception``; if it
        reports the task can be retried, the same step is re-queued with
        ``retries + 1`` and a service-type/step specific pause.

        :raises LogbrokerWriterError: if the retry message could not be sent.
        """
        try:
            self._process_extract_service(
                uid=uid,
                extract_id=extract_id,
                service_name=service_name,
                unixtime=unixtime,
                step=step,
                job_id=job_id,
                max_retries=max_retries,
            )

            self.statbox.log(
                task_id=task_id,
                uid=uid,
                extract_id=extract_id,
                task_name='extract_service',
                service_name=service_name,
                step=step,
                status='ok',
                archive_requested_unixtime=unixtime,
            )

        except Exception as e:
            statbox_params = {
                'service_name': service_name,
                'archive_requested_unixtime': unixtime,
                'step': step,
                'job_id': job_id,
            }

            can_be_retried = process_exception(
                task_id=task_id,
                uid=uid,
                extract_id=extract_id,
                task_name='extract_service',
                statbox_params=statbox_params,
                e=e,
                retries=retries,
                max_retries=max_retries,
            )

            if can_be_retried:
                service_type = get_service_configs()[service_name].type

                is_sent = send_extract_service_task_message(
                    task_id=task_id,
                    uid=uid,
                    extract_id=extract_id,
                    service_name=service_name,
                    step=step,
                    unixtime=unixtime,
                    retries=retries + 1,
                    max_retries=max_retries,
                    job_id=job_id,
                    delay_until=self.get_task_pause(retries, service_type=service_type, step=step) + int(time.time()),
                    source='extract:exc',
                )

                if not is_sent:
                    raise LogbrokerWriterError("Force error raise: send retry 'extract_service' message failed")

    def process_make_archive(
        self,
        task_id,
        uid,
        extract_id,
        unixtime,
        services,
        retries,
        manual_takeout=None,
    ):
        """Run the ``make_archive`` task.

        Calls ``check_touch_files_done`` for ``services`` and then
        ``make_archive_task``; for non-manual takeouts a ``cleanup`` message is
        queued afterwards. On any exception the error is reported through
        ``process_exception`` and the task is re-queued unconditionally (its
        return value is not consulted here); the loop is presumably bounded by
        the ``TASK_EXPIRE`` check in :meth:`process_message`.

        :raises LogbrokerWriterError: if the retry message could not be sent.
        """
        try:
            ts = time.time()
            check_touch_files_done(
                uid=uid,
                extract_id=extract_id,
                services=services,
                retries=retries,
            )
            log.debug('Checked touch files in %.4f', time.time() - ts)

            ts = time.time()
            make_archive_task(
                cooking_directory=self.cooking_directory,
                extract_id=extract_id,
                manual_takeout=manual_takeout,
                uid=uid,
                unixtime=unixtime,
            )
            log.debug('Made archive in %.4f', time.time() - ts)

            if not manual_takeout:
                is_sent = send_cleanup(
                    task_id=get_task_id(),
                    uid=uid,
                    extract_id=extract_id,
                    unixtime=unixtime,
                    source='make_archive',
                )
            else:
                is_sent = False

            self.statbox.log(
                task_id=task_id,
                uid=uid,
                extract_id=extract_id,
                task_name='make_archive',
                status='ok',
                archive_requested_unixtime=unixtime,
                cleanup_task_sent=is_sent,
            )

        except Exception as e:
            statbox_params = {
                'archive_requested_unixtime': unixtime,
            }
            process_exception(
                task_id=task_id,
                uid=uid,
                extract_id=extract_id,
                task_name='make_archive',
                statbox_params=statbox_params,
                retries=retries,
                e=e,
            )
            is_sent = send_make_archive(
                delay_until=self.get_task_pause(retries) + int(time.time()),
                extract_id=extract_id,
                manual_takeout=manual_takeout,
                retries=retries + 1,
                services=services,
                task_id=task_id,
                uid=uid,
                unixtime=unixtime,
                source='make_archive:exc',
            )
            if not is_sent:
                raise LogbrokerWriterError("Force error raise: send retry 'make_archive' message failed")

    def process_cleanup(self, task_id, uid, extract_id, unixtime, retries=0, max_retries=0):
        """Run the ``cleanup`` task, re-queuing it when ``process_exception`` allows a retry.

        :raises LogbrokerWriterError: if the retry message could not be sent.
        """
        try:
            cleanup_task(
                uid=uid,
                extract_id=extract_id,
            )

            self.statbox.log(
                task_id=task_id,
                uid=uid,
                extract_id=extract_id,
                task_name='cleanup',
                status='ok',
                archive_requested_unixtime=unixtime,
            )

        except Exception as e:
            statbox_params = {
                'archive_requested_unixtime': unixtime,
                'unixtime': unixtime,
            }
            can_be_retried = process_exception(
                task_id=task_id,
                uid=uid,
                extract_id=extract_id,
                task_name='cleanup',
                statbox_params=statbox_params,
                e=e,
                retries=retries,
                max_retries=max_retries,
            )
            if can_be_retried:
                is_sent = send_cleanup(
                    task_id=task_id,
                    uid=uid,
                    extract_id=extract_id,
                    unixtime=unixtime,
                    retries=retries + 1,
                    max_retries=max_retries,
                    delay_until=self.get_task_pause(retries) + int(time.time()),
                    source='cleanup:exc',
                )
                if not is_sent:
                    raise LogbrokerWriterError("Force error raise: send retry 'cleanup' message failed")

    def push_task_info_to_request_id(self, task_message):
        """Append task identifiers to the current request_id.

        The resulting request_id looks like ``...,uid,extract_id,<tail>``
        where ``<tail>`` is the service name for ``extract_service`` tasks or
        the task name for ``make_archive``, ``cleanup`` and ``extract``.
        """
        task_base = task_message.task_base_message
        request_id_bits = [
            task_base.uid,
            task_base.extract_id,
        ]
        if task_base.task_name in ['cleanup', 'extract', 'make_archive']:
            request_id_bits.append(task_base.task_name)
        elif task_base.task_name == 'extract_service':
            request_id_bits.append(task_message.extract_service.service)
        for request_id_bit in request_id_bits:
            RequestIdManager.push_request_id(request_id_bit)

    def _throttle_delay(self, start_time: float):
        """Sleep so that handling took at least ``delay_throttling_min_exec_time`` seconds."""
        if self.delay_throttling_min_exec_time:
            delta = time.time() - start_time
            if delta < self.delay_throttling_min_exec_time:
                time.sleep(self.delay_throttling_min_exec_time - delta)

    def process_message(self, header, message):
        """Handle one takeout task message.

        Expired messages get an ``expired`` statbox record and are dropped.
        Delayed messages are re-sent (with optional throttling) and not
        executed now. Otherwise the message is dispatched on ``task_name``;
        unknown task names are logged and skipped.
        """
        start_time = time.time()
        task_base = message.task_base_message
        log.debug(
            '[%s:%s] TASK %s uid=%s',
            task_base.task_id, task_base.seq, task_base.task_name, task_base.uid,
        )

        self.push_task_info_to_request_id(message)

        if self.is_message_expired(task_base):
            log.warning(
                'Task %s for uid %s is expired and will not be retried',
                task_base.task_name,
                task_base.uid,
            )
            self.statbox.log(
                task_id=task_base.task_id,
                uid=task_base.uid,
                extract_id=task_base.extract_id,
                task_name=task_base.task_name,
                status='expired',
                delay_until=task_base.delay_until,
                archive_requested_unixtime=task_base.unixtime,
            )
            return

        if self.is_message_delayed(task_base):
            resend_message(message=message, source='delay')
            # Keep a delayed message from being re-sent in a tight loop.
            self._throttle_delay(start_time)
            log.debug('Task delayed')
            return

        if log.isEnabledFor(logging.DEBUG):
            # Skip serialising the whole protobuf when debug logging is off.
            log.debug('EXECUTE TASK %s', format_protobuf_safe(message))

        if task_base.task_name == 'extract':
            task_detail = message.extract
            self.process_download_user_data(
                extract_id=task_base.extract_id,
                manual_takeout=task_detail.is_manual_takeout,
                max_retries=task_base.max_retries,
                retries=task_base.retries,
                services=task_detail.services or None,
                task_id=task_base.task_id,
                uid=task_base.uid,
                unixtime=task_base.unixtime,
            )
        elif task_base.task_name == 'extract_service':
            task_detail = message.extract_service
            self.process_extract_service_tasks(
                task_id=task_base.task_id,
                uid=task_base.uid,
                extract_id=task_base.extract_id,
                unixtime=task_base.unixtime,
                service_name=task_detail.service,
                # Empty protobuf fields are normalised to None.
                step=task_detail.step or None,
                retries=task_base.retries,
                max_retries=task_base.max_retries,
                job_id=task_detail.job_id or None,
            )

        elif task_base.task_name == 'make_archive':
            task_detail = message.make_archive
            self.process_make_archive(
                extract_id=task_base.extract_id,
                manual_takeout=task_detail.is_manual_takeout,
                retries=task_base.retries,
                services=task_detail.services,
                task_id=task_base.task_id,
                uid=task_base.uid,
                unixtime=task_base.unixtime,
            )

        elif task_base.task_name == 'cleanup':
            self.process_cleanup(
                task_id=task_base.task_id,
                uid=task_base.uid,
                extract_id=task_base.extract_id,
                unixtime=task_base.unixtime,
                retries=task_base.retries,
                max_retries=task_base.max_retries,
            )

        else:
            log.warning('Unknown task name %r; message skipped', task_base.task_name)
