# -*- coding: utf-8 -*-
import uuid
from collections import defaultdict
from datetime import timedelta

from intranet.yandex_directory.src.yandex_directory.common.utils import utcnow
from intranet.yandex_directory.src.yandex_directory.common.models.base import BaseModel


class DEPENDENCIES_STATE:
    """
    Plain string constants naming the states a task dependency is counted under.

    The values are the keys of ``DEPENDENCIES_COLUMN``, which picks the
    ``tasks`` counter column to update for each state.
    """
    new = 'new'
    successful = 'successful'
    failed = 'failed'


# Dependency state -> name of the ``tasks`` counter column that is changed
# for it (used by TaskModel.increment_dependencies_count).
DEPENDENCIES_COLUMN = dict((
    (DEPENDENCIES_STATE.new, 'dependencies_count'),
    (DEPENDENCIES_STATE.successful, 'successful_dependencies_count'),
    (DEPENDENCIES_STATE.failed, 'failed_dependencies_count'),
))


class TaskModel(BaseModel):
    """
    Data-access model for the ``tasks`` table that backs the task queue.

    On top of the generic helpers inherited from ``BaseModel`` it provides
    the queue primitives: creating tasks (optionally with dependencies),
    locking a task for a worker, releasing it, maintaining dependency
    counters and collecting statistics.
    """
    db_alias = 'main'
    table = 'tasks'
    # Columns stored as JSON / pickled; (de)serialization is presumably done
    # by BaseModel.
    json_fields = ['params']
    # NOTE(review): pickled values must never originate from untrusted
    # input - unpickling can execute arbitrary code.
    pickle_fields = ['metadata']
    all_fields = [
        'id',
        'task_name',
        'worker',
        'queue',
        'params',
        'state',
        'result',
        'exception',
        'ttl',
        'start_at',
        'locked_at',
        'free_lock_at',
        'created_at',
        'finished_at',
        'tries',
        'rollback_tries',
        'author_id',
        'priority',
        'metadata',
        'parent_task_id',
        'dependencies_count',
        'successful_dependencies_count',
        'failed_dependencies_count',
        'traceback',
        'ycrid'
    ]

    def create(self, task_name, params, queue, ttl, created_at=None, author_id=None,
               start_in=timedelta(), priority=0, parent_task_id=None, depends_on=None,
               metadata=None, state='free', ycrid=''):
        """
        Insert a new task row and, optionally, its dependency relations.

        :param task_name: name of the task
        :param params: task parameters (stored in the JSON ``params`` column)
        :param queue: name of the queue the task is put into
        :param ttl: lock duration in seconds (``lock_for_worker`` uses it as
            ``make_interval(secs => ttl)``)
        :param created_at: creation time; defaults to the current time
        :param start_in: delay before the task becomes eligible to run; a
            ``timedelta`` or a plain number of seconds (int or float)
        :param depends_on: iterable of ids (uuid strings) of the tasks this
            task depends on; stored through ``TaskRelationsModel``
        :return: the inserted row as returned by ``insert_into_db``
        """
        task_id = str(uuid.uuid4())

        # start_in may be passed as a plain number of seconds.
        if isinstance(start_in, (int, float)):
            start_in = timedelta(seconds=start_in)

        # Read the clock once so the default created_at and start_at agree.
        now = utcnow()
        task = self.insert_into_db(
            id=task_id,
            task_name=task_name,
            params=params,
            queue=queue,
            ttl=ttl,
            state=state,
            created_at=created_at or now,
            author_id=author_id,
            start_at=now + start_in,
            priority=priority,
            parent_task_id=parent_task_id,
            metadata=metadata,
            ycrid=ycrid,
        )
        if depends_on:
            # Imported here: a module-level import causes a circular import.
            from intranet.yandex_directory.src.yandex_directory.core.models import TaskRelationsModel
            TaskRelationsModel(self._connection).bulk_create(
                [{'task_id': task['id'], 'dependency_task_id': dep_id} for dep_id in depends_on])

        return task

    def get_filters_data(self, filter_data):
        """
        Translate ``filter_data`` into SQL fragments for the query builders
        (presumably consumed by BaseModel's find/filter).

        Supported keys: ``id``, ``state``, ``task_name`` (each may be a list),
        ``params``, ``author_id``, ``parent_task_id``, ``queue``,
        ``finished_at__lt/lte/gt/gte``, ``dependency``, ``dependent`` and
        ``org_id``.

        :return: tuple ``(distinct, filter_parts, joins, used_filters)``
        """
        distinct = False

        if not filter_data:
            return distinct, [], [], []

        filter_parts, joins, used_filters = [], [], []

        self.filter_by(filter_data, filter_parts, used_filters) \
            ('id', can_be_list=True) \
            ('state', can_be_list=True) \
            ('task_name', can_be_list=True) \
            ('params') \
            ('author_id') \
            ('parent_task_id') \
            ('queue') \
            ('finished_at__lt') \
            ('finished_at__lte') \
            ('finished_at__gt') \
            ('finished_at__gte')

        # Each relation join gets its own alias so that ``dependency`` and
        # ``dependent`` can be combined: joining tasks_relations twice without
        # aliases is an SQL error ("table specified more than once").
        if 'dependency' in filter_data:
            # Select the tasks that depend on the task given as ``dependency``.
            joins.append(
                'INNER JOIN tasks_relations AS dependency_rel '
                'ON tasks.id = dependency_rel.task_id'
            )
            filter_parts.append(
                self.mogrify(
                    "dependency_rel.dependency_task_id = %(dependency)s",
                    {'dependency': filter_data['dependency']},
                )
            )
            used_filters.append('dependency')

        if 'dependent' in filter_data:
            # Select the tasks that the task given as ``dependent`` depends on.
            joins.append(
                'INNER JOIN tasks_relations AS dependent_rel '
                'ON tasks.id = dependent_rel.dependency_task_id'
            )
            filter_parts.append(
                self.mogrify(
                    "dependent_rel.task_id = %(dependent)s",
                    {'dependent': filter_data['dependent']},
                )
            )
            used_filters.append('dependent')

        if 'org_id' in filter_data:
            filter_parts.append(
                self.mogrify(
                    "tasks.params -> 'org_id'=%(org_id)s",
                    {'org_id': str(filter_data['org_id'])},
                )
            )
            used_filters.append('org_id')

        return distinct, filter_parts, joins, used_filters

    def update_one(self, task_id, update_data):
        """Apply ``update_data`` (column -> value) to the task with id ``task_id``."""
        self.update(
            update_data=update_data,
            filter_data={'id': task_id}
        )

    def set_state(self, task_id, state):
        """Set the ``state`` column of the task."""
        self.update_one(task_id, {'state': state})

    def save_result(self, task_id, result):
        """Store ``result`` in the task's ``result`` column."""
        self.update_one(
            task_id=task_id,
            update_data={
                'result': result,
            }
        )

    def save_metadata(self, task_id, metadata):
        """Store ``metadata`` in the task's (pickled) ``metadata`` column."""
        self.update_one(
            task_id=task_id,
            update_data={
                'metadata': metadata,
            }
        )

    def get_metadata(self, task_id):
        """Return the task's ``metadata``, or None when no row is found."""
        row = self.get(task_id, fields=['metadata'])
        if row:
            return row['metadata']

    def get_params(self, task_id):
        """Return the task's ``params``, or None when no row is found."""
        row = self.get(task_id, fields=['params'])
        if row:
            return row['params']

    def save_params(self, task_id, params):
        """Store ``params`` in the task's JSON ``params`` column."""
        self.update_one(
            task_id=task_id,
            update_data={
                'params': params,
            }
        )

    def lock_for_worker(self, worker, queue):
        """
        Claim one task for execution by ``worker``.

        Picks the highest-priority (then oldest) task in ``queue`` that is not
        locked (or whose lock has expired), whose ``start_at`` has been
        reached, that is not finished and whose state is one of
        ``ACTIVE_STATES``. The chosen task is marked ``in-progress`` and
        locked for ``ttl`` seconds. ``FOR UPDATE SKIP LOCKED`` makes
        concurrent callers skip rows that another transaction is claiming.

        :param worker: name of the worker
        :param queue: task queue name
        :return: the claimed task row, or None if there are no free tasks
        :rtype: dict|None
        """
        # Imported here to avoid a circular import.
        from intranet.yandex_directory.src.yandex_directory.core.task_queue.base import ACTIVE_STATES

        query = """
            UPDATE tasks SET
              worker=%(worker)s,
              locked_at=NOW(),
              free_lock_at=NOW() + make_interval(secs=>ttl),
              state='in-progress'
            WHERE id=(
              SELECT id FROM tasks
              WHERE
                (worker IS NULL OR CURRENT_TIMESTAMP  >= free_lock_at) AND
                CURRENT_TIMESTAMP >= start_at AND
                finished_at IS NULL AND
                queue=%(queue)s AND
                state IN %(active_state)s
              ORDER BY priority DESC, created_at
              LIMIT 1
              FOR UPDATE SKIP LOCKED
            )
            RETURNING *
        """
        params = {
            'worker': worker,
            'queue': queue,
            'active_state': tuple(ACTIVE_STATES),
        }
        result = self._connection.execute(
            query,
            self.prepare_dict_for_db(params)
        ).fetchone()
        if result:
            return dict(result)
        return None

    def release_task(self, task_id):
        """
        Drop the worker lock on ``task_id`` and put the task back to ``free``.

        Tasks already in one of ``TERMINATE_STATES`` are left untouched.
        """
        from intranet.yandex_directory.src.yandex_directory.core.task_queue.base import TERMINATE_STATES

        query = """
            UPDATE
                tasks
            SET
                worker = NULL,
                locked_at = NULL,
                free_lock_at = NULL,
                state = 'free'
            WHERE id=%(id)s AND state NOT IN %(terminate)s
            """
        self._connection.execute(
            query,
            self.prepare_dict_for_db({'id': task_id, 'terminate': tuple(TERMINATE_STATES)})
        )

    def count_state(self):
        """
        Statistics on the state of tasks in the queue.

        :return: one dict per (state, queue, task_name) group with the
            group's task count
        :rtype: list[dict]
        """
        query = """
            SELECT state, queue, task_name, COUNT(state) FROM tasks
            GROUP BY state, queue, task_name;
        """
        result = self._connection.execute(query).fetchall()
        return list(map(dict, result))

    def count_state_created_recently(self, timelimit=60):
        """
        Task statistics for tasks created within the last ``timelimit`` minutes.

        :param timelimit: window size in minutes
        :return: one dict per (state, queue, task_name) group with the
            group's task count
        :rtype: list[dict]
        """
        # The window is bound as a query parameter instead of being formatted
        # into the SQL text, so a non-numeric value cannot inject SQL.
        query = """
            SELECT state, queue, task_name, COUNT(state) FROM tasks
            WHERE EXTRACT(EPOCH FROM CURRENT_TIMESTAMP - created_at) < %(seconds)s
            GROUP BY state, queue, task_name;
        """
        params = {'seconds': timelimit * 60}
        result = self._connection.execute(
            query,
            self.prepare_dict_for_db(params)
        ).fetchall()
        return list(map(dict, result))

    def _get_dependents(self, task_id, fields=None):
        """Return rows of the tasks that depend on ``task_id``.

        To work with task objects rather than rows, use the method of the
        same name on ``task_queue.base.Task``.
        """
        query = self.filter(dependency=task_id)
        if fields:
            query = query.fields(*fields)
        return query.all()

    def get_dependencies(self, task_id, task_name=None, state=None):
        """
        Return the tasks that ``task_id`` depends on.

        :param task_name: optional filter; either a task name string or an
            object providing ``get_task_name()`` (e.g. a task class)
        :param state: optional state filter (a single value or a list)
        """
        filter_data = {'dependent': task_id}
        if task_name:
            if not isinstance(task_name, str):
                task_name = task_name.get_task_name()
            filter_data['task_name'] = task_name
        if state:
            filter_data['state'] = state
        return self.find(filter_data=filter_data)

    def increment_dependencies_count(self, task_id, dependencies_state, increment_value=1):
        """
        Add ``increment_value`` to the dependency counter for ``dependencies_state``.

        The column name is taken from the ``DEPENDENCIES_COLUMN`` whitelist,
        which is what makes formatting it into the SQL text safe.

        :raises ValueError: if ``dependencies_state`` has no counter column
        """
        column = DEPENDENCIES_COLUMN.get(dependencies_state)
        if not column:
            raise ValueError('Unknown dependency task state {}'.format(dependencies_state))

        query = """
            UPDATE
                tasks
            SET
                {column} = {column} + %(increment_value)s
            WHERE id=%(id)s
            """.format(column=column)
        self._connection.execute(
            query,
            self.prepare_dict_for_db({
                'id': task_id,
                'increment_value': increment_value,
                }
            )
        )

    def get_state(self, task_id):
        """
        Return the task's state, or None when no row is found.

        :rtype: str
        """
        row = self.get(task_id, fields=['state'])
        if row:
            return row['state']

    def decrement_dependencies_count(self, task_id, dependencies_state, decrement_value=1):
        """Subtract ``decrement_value`` from the dependency counter for ``dependencies_state``."""
        self.increment_dependencies_count(task_id, dependencies_state, -decrement_value)

    def count_out_of_sync_maillist(self):
        """Return the number of ``maillist_checks`` rows with ``ml_is_ok = FALSE`` (out-of-sync mailing lists)."""
        query = """
                    SELECT COUNT(*) FROM maillist_checks
                    WHERE ml_is_ok = FALSE;
                """
        return dict(self._connection.execute(query).fetchone()).get('count')

    def get_mail_migration_stats(self, hours=1):
        """
        Count finished mail-migration tasks by state.

        Only tasks of the mail-migration task classes listed below are
        counted. They must have finished within the last ``hours`` hours and
        ended in ``TASK_STATES.success`` or ``TASK_STATES.failed``.

        :param hours: size of the look-back window in hours (int or float)
        :return: mapping state -> number of tasks; absent states read as 0
        :rtype: collections.defaultdict
        """
        from intranet.yandex_directory.src.yandex_directory.core.task_queue.base import TASK_STATES
        from intranet.yandex_directory.src.yandex_directory.core.mail_migration import (
            CreateAccountTask,
            CreateMailBoxesTask,
            CreateMailCollectorsTask,
            CreateMailCollectorTask,
            MailMigrationTask,
            DeleteCollectorTask,
            DeleteCollectorsTask,
            WaitingForMigrationsTask,
            WaitingForMigrationTask,
            SetAccountConsistencyTask,
        )
        tasks = [
            CreateAccountTask,
            CreateMailBoxesTask,
            CreateMailCollectorsTask,
            CreateMailCollectorTask,
            MailMigrationTask,
            DeleteCollectorTask,
            DeleteCollectorsTask,
            WaitingForMigrationsTask,
            WaitingForMigrationTask,
            SetAccountConsistencyTask,
        ]

        # Values are counts, so a state with no tasks must read as 0, not {}.
        stat = defaultdict(int)
        # The placeholder is kept outside any string literal: the old
        # '%(hours)s hour'::interval only worked for values that render
        # without quotes.
        query = """
            SELECT state, COUNT(*) as cnt
            FROM tasks
            WHERE finished_at > now() - %(hours)s * INTERVAL '1 hour'
            AND state IN %(states)s
            AND task_name IN %(task_names)s
            GROUP BY state
        """
        params = {
            'hours': hours,
            'states': tuple([TASK_STATES.success, TASK_STATES.failed]),
            'task_names': tuple(task.get_task_name() for task in tasks)
        }

        tasks_result = self._connection.execute(
            query,
            self.prepare_dict_for_db(params)
        ).fetchall()

        for state, cnt in tasks_result:
            stat[state] = cnt
        return stat
