# -*- coding: utf-8 -*-
import time
from contextlib import contextmanager
from yt.wrapper import ypath_join, YtError
from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.config.yt import PRODUCTION_ROOT
from datacloud.dev_utils.yt import yt_utils
from datacloud.dev_utils.yt.yt_config_table import ConfigTable
from datacloud.dev_utils.status_db.task import Task, Status
from datacloud.dev_utils.time.utils import now_str
from datacloud.dev_utils.status_db.time_log import TimeDB, TimeRecord
from datetime import datetime
from datacloud.dev_utils.time.patterns import FMT_DATE_HMS


# Names exported by `from <this module> import *`.
__all__ = [
    'StatusDB'
]


logger = get_basic_logger(__name__)
# Prefix put in front of this module's log messages.
STATUS_DB_TAG = 'STATUS_DB'

# Default YT table used by StatusDB when no table_path is given.
DEFAULT_STATUS_DB_PATH = ypath_join(PRODUCTION_ROOT, 'new-status-db')
# Node on which lock_manager takes its per-name child-key locks.
DEFAULT_LOCK_PATH = ypath_join(PRODUCTION_ROOT, 'locks')


@contextmanager
def lock_manager(lock_name, yt_client=None, lock_path=None, wait_for=1000):
    """Hold a named YT lock for the duration of a ``with`` block.

    Opens a YT transaction and, inside it, takes a ``shared`` lock with
    ``child_key=lock_name`` on ``lock_path``. The transaction is left when the
    ``with`` block exits (normally or by exception). Yields ``None``.

    :param lock_name: name of the lock, passed to YT as ``child_key``.
    :param yt_client: YT client to use; ``yt_utils.get_yt_client()`` if not
        given. Its ``config`` is modified in place (see below).
    :param lock_path: node to lock; ``DEFAULT_LOCK_PATH`` if not given.
    :param wait_for: how long the waitable lock request may wait before it
        fails; passed straight to ``yt_client.lock`` (presumably milliseconds
        per the yt.wrapper API -- confirm).
    :raises LockError: if ``yt_client.lock`` raises ``YtError``.
    """
    yt_client = yt_client or yt_utils.get_yt_client()
    lock_path = lock_path or DEFAULT_LOCK_PATH
    # Kept from the old implementation; presumably makes YT signal the process
    # when a transaction ping fails. TODO(review): confirm it is still needed.
    yt_client.config['transaction_use_signal_if_ping_failed'] = True
    with yt_client.Transaction():
        try:
            yt_client.lock(lock_path, mode='shared', child_key=lock_name, waitable=True, wait_for=wait_for)
        except YtError:
            raise LockError("[LOCK]: FAIL! {}".format(lock_name))
        logger.info('[LOCK]: SUCCESS! {}'.format(lock_name))
        # The caller's block runs here, still inside the transaction.
        yield


class LockError(RuntimeError):
    """Raised by ``lock_manager`` when the requested YT lock cannot be taken."""


class StatusDB(ConfigTable):
    """Task-status storage backed by a sorted YT table.

    Each row describes one task identified by ``(program, key)`` and holds its
    ``status``, free-form ``data`` and ``create_time``/``update_time`` strings.
    Reads go through YT query strings built from the templates set up in
    ``__init__``; reads and writes use the ``ConfigTable`` helpers
    (``get_record``, ``request_records``, ``insert_records_with_retry``,
    ``update_record``).
    """

    def __init__(self, table_path=DEFAULT_STATUS_DB_PATH, yt_client=None):
        schema = [
            {'name': 'program', 'type': 'string', 'sort_order': 'ascending', 'required': True},
            {'name': 'key', 'type': 'string', 'sort_order': 'ascending', 'required': True},
            {'name': 'status', 'type': 'string', 'required': True},
            {'name': 'data', 'type': 'any'},
            {'name': 'create_time', 'type': 'string', 'required': True},
            {'name': 'update_time', 'type': 'string', 'required': True},
        ]

        super(StatusDB, self).__init__(
            table_path, schema, yt_client
        )

        # Query templates, filled in later with str.format.
        # NOTE(review): values are substituted verbatim between double quotes,
        # so a program/key/status containing '"' would break the query --
        # escape or validate them if such values can occur.
        select_all = '* FROM [' + self.table_path + ']'
        self._get_tasks_request = select_all + ' WHERE program = "{program}" AND key = "{key}"'
        self._get_program_request = select_all + ' WHERE program = "{program}"'
        self._get_program_with_status_request = (
            select_all + ' WHERE program = "{program}" AND status = "{status}"')
        self._get_all_with_status_request = select_all + ' WHERE status = "{status}"'

    def _build_get_task_request(self, program, key):
        """Return the query selecting the row of task ``program#key``."""
        return self._get_tasks_request.format(program=program, key=key)

    def _build_get_program_request(self, program):
        """Return the query selecting all rows of ``program``."""
        return self._get_program_request.format(program=program)

    def get_task(self, program, key):
        """Return task ``program#key`` as a ``Task``, or ``None`` if no row is found."""
        raw_record = self.get_task_raw(program, key)
        if raw_record:
            return Task(**raw_record)
        return None

    def get_task_raw(self, program, key):
        """Return the raw row of task ``program#key`` as returned by ``get_record``."""
        return self.get_record(self._build_get_task_request(program, key))

    def get_tasks_with_status(self, program, status):
        """Yield ``Task`` objects of ``program`` whose status equals ``status``.

        Rows are queried by ``str(status)`` and re-checked after conversion to
        ``Task`` (presumably because ``Task`` normalizes the status value --
        confirm).
        """
        for raw_record in self.get_tasks_with_status_raw(program, str(status)):
            task = Task(**raw_record)
            if task.status == status:
                yield task

    def get_tasks_with_status_raw(self, program, status='READY'):
        """Lazily yield raw rows of ``program`` whose status is ``str(status)``."""
        request = self._get_program_with_status_request.format(
            program=program, status=str(status))
        for raw_record in self.request_records(request):
            yield raw_record

    def get_all_with_status_raw(self, status='READY'):
        """Lazily yield raw rows of every program whose status is ``str(status)``."""
        request = self._get_all_with_status_request.format(status=str(status))
        for raw_record in self.request_records(request):
            yield raw_record

    def add_task(self, task):
        """Insert (or overwrite) a single task.

        :param task: task to store
        :type task: dev_utils.task.Task
        """
        self.add_multiple_tasks([task])

    def add_multiple_tasks(self, task_list):
        """Insert the given tasks in one batch, retrying up to 10 times.

        :param task_list: tasks to store; each must provide ``to_json()``.
        """
        logger.info('Add tasks: {}'.format([str(task) for task in task_list]))
        records = [task.to_json() for task in task_list]
        self.insert_records_with_retry(records, n_retry=10)

    def update_status(self, program, key, new_status):
        """Set the status column of task ``program#key`` to ``str(new_status)``."""
        self.update_record(self._build_get_task_request(program, key), {'status': str(new_status)})

    def check_task_exists(self, program, key):
        """Return ``True`` if task ``program#key`` is present in the table."""
        return self.get_task(program, key) is not None

    def get_task_status(self, program, key):
        """Return the status of task ``program#key``, or ``None`` if it is absent."""
        # Previously this raised AttributeError on a missing task, which
        # escaped run_task's LockError handler and aborted the whole run.
        task = self.get_task(program, key)
        if task is None:
            return None
        return task.status

    def detect_ready(self, program, tables_func, date_str, status=Status.READY, force=False):
        """Register new tasks of ``program`` discovered by ``tables_func``.

        :param program: program name the tasks belong to.
        :param tables_func: callable ``tables_func(date_str)`` returning an
            iterable of ``(key, data)`` pairs; ``data`` must be a dict.
        :param date_str: argument forwarded to ``tables_func``.
        :param status: status given to newly added tasks.
        :param force: if true, (re)write tasks even when they already exist.
        """
        if force:
            logger.info('[{}] FORCE ON'.format(STATUS_DB_TAG))
        for key, data in tables_func(date_str):
            # `force` first: skips the existence query when it is not needed.
            if not (force or not self.check_task_exists(program, key)):
                continue  # already registered; leave the existing row untouched
            current_time = now_str()
            assert isinstance(data, dict), '`data` for task must be instance of dict, your data is: {}'.format(data)
            self.add_task(Task(program, key, status, data, current_time, current_time))
        if force:
            logger.info('[{}] DONE FORCE'.format(STATUS_DB_TAG))
        logger.info('[{}] DONE detect ready for `{}`'.format(STATUS_DB_TAG, program))

    def run_task(self, program, step_call, from_status=Status.READY):
        """Process every task of ``program`` that is currently in ``from_status``.

        Tasks are handled in ``str(task)`` order. For each one a lock named
        ``program#key`` is taken; tasks whose lock cannot be taken are skipped
        with a warning.

        :param step_call: callable ``step_call(task)`` returning the tasks to
            store afterwards (or a falsy value when there are none).
        """
        ready_tasks = sorted(self.get_tasks_with_status(program, from_status), key=str)
        total = len(ready_tasks)
        for idx, task in enumerate(ready_tasks, 1):
            logger.info('Task {}/{}'.format(idx, total))
            lock_name = program + '#' + str(task.key)
            try:
                # NOTE(review): lock_manager uses the default YT client, not
                # the one this table was created with -- confirm intended.
                with lock_manager(lock_name):
                    self._run_locked_task(task, step_call, from_status)
            except LockError:
                logger.warning('[%s] SKIP Can not take lock: `%s`', STATUS_DB_TAG, lock_name)
        logger.info('[%s] DONE all programs `%s` with status `%s`', STATUS_DB_TAG, program, from_status)

    def _run_locked_task(self, task, step_call, from_status):
        """Run ``step_call`` on ``task`` (lock already held) and store its result."""
        # The status may have changed between the unlocked read and the lock.
        actual_status = self.get_task_status(task.program, task.key)
        if actual_status != from_status:
            logger.warning('[%s] Task `%s` status was changed before lock from `%s` to `%s`. Skip task.', STATUS_DB_TAG, task, from_status, actual_status)
            return

        start_time = time.time()
        new_tasks = step_call(task)
        duration = int(time.time() - start_time)

        if new_tasks:
            self.add_multiple_tasks(new_tasks)
            # Distinct loop name: under Python 2 a comprehension variable
            # called `task` would overwrite the task being processed, and the
            # DONE message below would name the wrong task.
            logger.info('[{}] Set new tasks: {}'.format(STATUS_DB_TAG, [str(new_task) for new_task in new_tasks]))
            self._try_log_done_tasks(new_tasks, duration)
        else:
            logger.warning('[%s] no new task for %s', STATUS_DB_TAG, task)
        logger.info('[{}] DONE {} ({} seconds)'.format(STATUS_DB_TAG, task, duration))

    def _try_log_done_tasks(self, tasks, duration):
        """Best-effort ``log_time`` for every DONE task; failures are only logged."""
        for task in tasks:
            if task.status == Status.DONE:
                try:
                    self.log_time(task, duration)
                except Exception as ex:
                    # Timing is auxiliary; never let it break task processing.
                    logger.warning('[%s] Can not log time for %s : %s', STATUS_DB_TAG, task, ex)

    def log_time(self, task, duration):
        """Store ``duration`` for a DONE ``task`` in ``TimeDB``; other statuses are ignored.

        The record's timestamp is ``task.update_time`` (parsed with
        ``FMT_DATE_HMS``) converted to a Unix timestamp as local time
        (``time.mktime``).
        """
        if task.status == Status.DONE:
            updated_at = datetime.strptime(task.update_time, FMT_DATE_HMS)
            timestamp = int(time.mktime(updated_at.timetuple()))
            TimeDB().add(TimeRecord(task.program, timestamp, task.key, task.status, duration))
