# -*- coding: utf-8 -*-
import uuid
import logging
import sys

from flask_script import Command as FlaskCommand, Option

from intranet.yandex_directory.src.yandex_directory.common.db import (
    get_meta_connection,
    get_main_connection,
    get_shard_numbers,
    lock,
)
from intranet.yandex_directory.src.yandex_directory.directory_logging.logger import log


class CommandError(RuntimeError):
    """Если брошено это исключение, то команда выводит заданное
    сообщение об ошибке, без трейсбэка, и завершает работу с кодом 1.
    """


class BaseCommand(FlaskCommand):
    def __init__(self, *args, **kwargs):
        if not hasattr(self, 'name'):
            raise AssertionError(
                'Please, set \'name\' attribute for command class {0}'.format(
                    self.__class__.__name__
                )
            )

        self.log = logging.getLogger('command.' + self.name)
        super(BaseCommand, self).__init__(*args, **kwargs)

    def __call__(self, app=None, *args, **kwargs):
        with app.app_context():
            with log.fields(command='command.%s' % self.name,
                            command_id=uuid.uuid4().hex):
                return self.try_run(*args, **kwargs)

    def try_run(self, *args, **kwargs):
        try:
            log.info('Starting command execution')
            result = self.run(*args, **kwargs)

            log.info('Command executed successfully')
            return result
        except CommandError as e:
            log.trace().error('Command execution failed')
            print(str(e))
            sys.exit(1)
        except Exception:
            log.trace().error('Command execution failed')
            raise


class TransactionalCommand(BaseCommand):
    """
    Только для быстрых команд, у админов запущена килялка транзакций > 15 мин
    """

    def __init__(self, *args, **kwargs):
        super(TransactionalCommand, self).__init__(*args, **kwargs)
        # У всех команд должен быть параметр "--shard"
        shard_option_already_defined = False
        for opt in self.option_list:
            if opt.kwargs.get('dest') == 'shard':
                shard_option_already_defined = True

        if not shard_option_already_defined:
            self.option_list = (Option('--shard', '-s', dest='shard', type=int, required=True, help='Shard id'),) + self.option_list

    def try_run(self, *args, **kwargs):
        result = None
        try:
            log.info('Starting command execution')
            shard = kwargs.get('shard')
            with get_meta_connection(for_write=True) as self.meta_connection:
                lock_name = 'TransactionalCommand.{}.Shard{}'.format(self.__class__.__name__, shard)
                with lock(self.meta_connection, lock_name), \
                     get_main_connection(shard=shard, for_write=True) as self.main_connection:
                    result = self.run(*args, **kwargs)
        except CommandError as e:
            log.trace().error('Command execution failed')
            print(str(e))
            sys.exit(1)
        except Exception:
            log.trace().error('Command execution failed')
            raise
        else:
            log.info('Command executed successfully')
        return result


class AllShardsCommand(BaseCommand):
    """
    Только для быстрых команд, у админов запущена килялка транзакций > 15 мин
    Эта команда будет запускать метод run последовательно для каждого из шардов
    главной базы. Коннект к шарду будет доступен через self.main_connection,
    а сам id шарда в self.shard.
    """
    need_writable_database_connections = ['main', 'meta']

    def __for_write(self, db):
        return db in self.need_writable_database_connections

    def try_run(self, *args, **kwargs):
        with get_meta_connection(for_write=self.__for_write('meta')) as self.meta_connection:
            log.info('Starting command execution')

            shards = get_shard_numbers()
            for shard in shards:
                with log.fields(shard=shard):
                    self.shard = shard
                    try:
                        with get_main_connection(shard=shard, for_write=self.__for_write('main')) as self.main_connection:
                            self.run(*args, **kwargs)
                    except CommandError as e:
                        log.trace().error('Command execution on the shard failed')
                        print(str(e))
                        sys.exit(1)
                    except Exception:
                        log.trace().error('Command execution on the shard failed')
                        raise
                    else:
                        log.info('Command executed successfully')

            log.info('Command was executed on all shards')
