# -*- coding: utf-8 -*-
"""Реализация механизма dbqueue в Директе"""
from __future__ import absolute_import
import json
import logging
import os
import socket
import struct
import yaml
import zlib
import warnings
import MySQLdb
from contextlib import contextmanager
from sqlalchemy import MetaData, Table, Column, ForeignKey
from sqlalchemy.sql import select, literal, text, null, not_, and_, or_
from sqlalchemy.dialects import mysql
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import Update
from sqlalchemy.sql.base import _generative
from sqlalchemy.sql.elements import ClauseList
from direct.db import db_engine
from direct.dbshards import get_new_id

log = logging.getLogger(__name__)

# Project root (two directory levels above this module's directory) and the
# YAML file listing the known job types.
ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
DBQUEUE_TYPES_FILE = os.path.join(ROOT, 'etc', 'dbqueue-types.yaml')

# Sentinel for parameters whose default value is computed dynamically.
USE_DEFAULT = object()

# Global namespace prefix combined with every queue's own namespace
# (None means no global prefix).
NAMESPACE = None

# Default number of seconds a job stays grabbed (15 minutes).
DEFAULT_GRAB_FOR = 60 * 15

class schema(object):
    """SQLAlchemy descriptions of the dbqueue tables.

    ``dbqueue_jobs`` holds active jobs (status New or Grabbed);
    ``dbqueue_job_archive`` has the same columns plus ``result`` and holds jobs
    that were Finished, Failed or Revoked.
    """
    metadata = MetaData()
    # Registry of job types: numeric id <-> job type name
    dbqueue_job_types = Table('dbqueue_job_types', metadata,
        Column('job_type_id', mysql.INTEGER(10, unsigned=True), primary_key=True, autoincrement=False),
        Column('job_type', mysql.VARCHAR(30), nullable=False),
    )
    # Active (not yet archived) jobs
    dbqueue_jobs = Table('dbqueue_jobs', metadata,
        Column('job_id', mysql.BIGINT(20, unsigned=True), primary_key=True, autoincrement=False),
        Column('ClientID', mysql.INTEGER(10, unsigned=True), nullable=False),
        Column('job_type_id', mysql.INTEGER(10, unsigned=True), ForeignKey(dbqueue_job_types.c.job_type_id), nullable=False),
        Column('namespace', mysql.VARCHAR(20), server_default=text('NULL')),
        Column('status', mysql.ENUM('New', 'Grabbed', strict=True), nullable=False),
        Column('uid', mysql.BIGINT(20, unsigned=True), nullable=False, server_default='0'),
        Column('args', mysql.MEDIUMBLOB),
        Column('priority', mysql.INTEGER(11), nullable=False, server_default='0'),
        # CURRENT_TIMESTAMP rather than CURRENT_TIME: CURRENT_TIME yields a TIME
        # value and MySQL rejects it as a DATETIME column default, so DDL
        # generated from this metadata would fail.
        Column('create_time', mysql.DATETIME, nullable=False, server_default=text('CURRENT_TIMESTAMP')),
        Column('expiration_time', mysql.DATETIME, server_default=text('NULL')),
        Column('grabbed_by', mysql.VARCHAR(40), nullable=False, server_default=''),
        Column('grabbed_until', mysql.DATETIME, server_default=text('NULL')),
        Column('grabbed_at', mysql.DATETIME, server_default=text('NULL')),
        Column('trycount', mysql.INTEGER(11), nullable=False, server_default='0'),
        Column('run_after', mysql.DATETIME, server_default=text('NULL')),
    )
    # Archived jobs: the same columns as dbqueue_jobs plus the job result
    dbqueue_job_archive = Table('dbqueue_job_archive', metadata,
        Column('job_id', mysql.BIGINT(20, unsigned=True), primary_key=True, autoincrement=False),
        Column('ClientID', mysql.INTEGER(10, unsigned=True), nullable=False),
        Column('job_type_id', mysql.INTEGER(10, unsigned=True), ForeignKey(dbqueue_job_types.c.job_type_id), nullable=False),
        Column('namespace', mysql.VARCHAR(20), server_default=text('NULL')),
        Column('status', mysql.ENUM('Finished', 'Failed', 'Revoked', strict=True), nullable=False),
        Column('uid', mysql.BIGINT(20, unsigned=True), nullable=False, server_default='0'),
        Column('args', mysql.MEDIUMBLOB),
        Column('priority', mysql.INTEGER(11), nullable=False, server_default='0'),
        # See the note on dbqueue_jobs.create_time
        Column('create_time', mysql.DATETIME, nullable=False, server_default=text('CURRENT_TIMESTAMP')),
        Column('expiration_time', mysql.DATETIME, server_default=text('NULL')),
        Column('grabbed_by', mysql.VARCHAR(40), nullable=False, server_default=''),
        Column('grabbed_until', mysql.DATETIME, server_default=text('NULL')),
        Column('grabbed_at', mysql.DATETIME, server_default=text('NULL')),
        Column('trycount', mysql.INTEGER(11), nullable=False, server_default='0'),
        Column('run_after', mysql.DATETIME, server_default=text('NULL')),
        Column('result', mysql.MEDIUMBLOB),
    )

class UpdateOrderByLimit(Update):
    """UPDATE statement extended with MySQL ``ORDER BY .. LIMIT ..`` support.

    The extra clauses are rendered by ``visit_update_order_by_limit`` after
    the regular UPDATE text. Both methods are wrapped in SQLAlchemy's
    ``@_generative`` decorator, which applies them to a copy of the statement
    and returns that copy.
    """
    # ClauseList of ordering expressions, or None for no ORDER BY
    _order_by = None
    # Maximum number of rows to update, or None for no LIMIT
    _limit = None

    @_generative
    def order_by(self, *clauses):
        # An empty call resets the ordering: an empty ClauseList would render
        # as a dangling "ORDER BY " and produce invalid SQL.
        self._order_by = ClauseList(*clauses) if clauses else None

    @_generative
    def limit(self, limit):
        self._limit = limit

@compiles(UpdateOrderByLimit)
def visit_update_order_by_limit(element, compiler, **kwargs):
    """Compile an UpdateOrderByLimit: the regular UPDATE, then ORDER BY / LIMIT.

    Compiler keyword arguments are forwarded to the nested visits so that
    compile options (e.g. literal binds) also apply to the UPDATE body and
    to the ORDER BY expressions.
    """
    sql = compiler.visit_update(element, **kwargs)
    if element._order_by is not None:
        sql += ' \nORDER BY ' + compiler.process(element._order_by, **kwargs)
    if element._limit is not None:
        # %d forces an integer, so the limit cannot inject arbitrary SQL
        sql += ' \nLIMIT %d' % (element._limit,)
    return sql

# Column names of the active and archive job tables, in table order.
JOB_FIELDS = tuple(column.name for column in schema.dbqueue_jobs.columns)
JOB_ARCHIVE_FIELDS = tuple(column.name for column in schema.dbqueue_job_archive.columns)

def load_job_type_id_from_file(job_type):
    """Look up the numeric id of job_type in DBQUEUE_TYPES_FILE.

    The file is read as a YAML sequence of mappings with 'type' and 'id' keys.

    :raises ValueError: if job_type is not listed in the file
    """
    # NOTE(review): yaml.CLoader is the full (non-safe) loader; acceptable for
    # this project-owned file, but CSafeLoader would be sufficient.
    with open(DBQUEUE_TYPES_FILE, 'r') as types_file:
        entries = yaml.load(types_file, Loader=yaml.CLoader)
    for entry in entries:
        if entry['type'] == job_type:
            return entry['id']
    raise ValueError('Unknown job type: %r' % (job_type,))

def job_type_to_id(db, job_type):
    """Return the numeric job_type_id for job_type, registering it if missing.

    The id is first looked up in ``dbqueue_job_types``. If absent, it is taken
    from the types YAML file and inserted with INSERT IGNORE, so a concurrent
    insert of the same row is tolerated, and then read back.

    :raises ValueError: if the type is unknown or the row could not be inserted
    """
    types_table = schema.dbqueue_job_types
    lookup = (
        select([types_table.c.job_type_id])
        .where(types_table.c.job_type == job_type)
    )
    with db.connect() as conn:
        found = conn.scalar(lookup)
        if found is not None:
            return found
        job_type_id = load_job_type_id_from_file(job_type)
        conn.execute(
            types_table.insert()
            .prefix_with('IGNORE')
            .values(job_type_id=job_type_id, job_type=job_type)
        )
        found = conn.scalar(lookup)
        if found is None:
            raise ValueError('Failed to insert job_type=%r with job_type_id=%r' % (job_type, job_type_id))
    return found

def mysql_compress(data):
    """Compress data into the layout used by MySQL's COMPRESS().

    Layout: 4-byte little-endian length of the uncompressed data, followed by
    the zlib stream. Like MySQL, a '.' is appended when the zlib stream ends
    with a space, so trailing-space trimming cannot damage it.

    NOTE: unlike MySQL's COMPRESS(''), an empty input is not returned as an
    empty string; it still gets a header and a zlib stream.

    :param data: byte string to compress
    :returns: compressed byte string
    """
    compressed = zlib.compress(data, 9)
    # Bytes literals: identical to plain literals on Python 2, and they keep
    # the check valid (instead of raising TypeError) for Python 3 bytes.
    if compressed.endswith(b' '):
        compressed += b'.'
    return struct.pack('<I', len(data)) + compressed

def mysql_uncompress(data):
    """Inverse of mysql_compress (MySQL UNCOMPRESS() layout).

    Empty input (None or '') yields None.

    :raises ValueError: if the header is too short or the decompressed length
                        does not match the length stored in the header
    """
    if not data:
        return None
    if len(data) < 4:
        raise ValueError('Malformed compressed data')
    header, payload = data[:4], data[4:]
    (expected_size,) = struct.unpack('<I', header)
    plain = zlib.decompress(payload)
    if len(plain) != expected_size:
        raise ValueError('Malformed compressed data: decompressed size mismatch')
    return plain

def json_encode(data):
    """Serialize data to a UTF-8 encoded JSON byte string (Python 2).

    Non-ASCII characters are written as-is rather than \\u-escaped.
    """
    # NOTE(review): keep the exact 'utf8' spelling. On Python 2, json only
    # pre-decodes byte-string inputs with this encoding when it differs from
    # its default 'utf-8', so changing it may alter results for mixed
    # str/unicode input.
    serialized = json.dumps(data, encoding='utf8', ensure_ascii=False)
    return serialized.encode('utf8') if isinstance(serialized, unicode) else serialized

def json_decode(data):
    """Parse a JSON document; empty input (None or '') yields None.

    Byte strings (``str``) are decoded as UTF-8 before parsing.
    """
    if data:
        payload = data.decode('utf8') if isinstance(data, str) else data
        return json.loads(payload)
    return None

def pack_value(value):
    """Serialize value for a blob column: JSON-encode, then MySQL-style compress.

    None is passed through unchanged.
    """
    return None if value is None else mysql_compress(json_encode(value))

def unpack_value(value):
    """Inverse of pack_value: decompress a blob and decode its JSON; None -> None."""
    if value is not None:
        return json_decode(mysql_uncompress(value))
    return None

def get_global_process_id(max_length=40):
    """Return an identifier of the current process: '<hostname>:<pid>'.

    The identifier is stored in the ``grabbed_by`` column, which ``schema``
    declares as VARCHAR(40). A longer value would be truncated or rejected by
    MySQL, and ``grabbed_by == process_id`` comparisons could then never
    match. The hostname part is therefore shortened to keep the result within
    ``max_length`` characters. The ':<pid>' suffix is always kept intact, and
    short identifiers are returned unchanged.

    NOTE: shortening can make ids of hosts that share a long name prefix
    collide if they also share a pid.

    :param max_length: maximum length of the result, or None for no limit
    """
    suffix = ':{}'.format(os.getpid())
    hostname = socket.gethostname()
    if max_length is not None and len(hostname) + len(suffix) > max_length:
        hostname = hostname[:max(max_length - len(suffix), 0)]
    return hostname + suffix

def rich_equality(column, value):
    """Build ``column == value`` for scalars, ``column IN (value)`` for iterables.

    None, numbers, strings and non-iterable objects count as scalars.
    """
    if value is not None and not isinstance(value, (int, long, float, basestring)) and hasattr(value, '__iter__'):
        return column.in_(value)
    return column == value

def combine_namespaces(namespace1, namespace2):
    """Join two namespaces as 'ns1/ns2'; if either is empty, return the other."""
    if not (namespace1 and namespace2):
        return namespace1 or namespace2
    return '{}/{}'.format(namespace1, namespace2)

class DBQueueJob(object):
    """In-memory representation of a single queue job.

    Attributes mirror the columns of ``dbqueue_job_archive`` (the active
    table's columns plus ``result``). ``args`` and ``result`` hold decoded
    Python values and are packed/unpacked when converted to/from DB rows.
    """
    __slots__ = JOB_ARCHIVE_FIELDS

    def __init__(self,
                 job_id,
                 ClientID,
                 job_type_id,
                 namespace,
                 status='New',
                 uid=None,
                 args=None,
                 priority=None,
                 create_time=None,
                 expiration_time=None,
                 grabbed_by=None,
                 grabbed_until=None,
                 grabbed_at=None,
                 trycount=None,
                 run_after=None,
                 result=None):
        # Identity and ownership
        self.job_id = job_id
        self.ClientID = ClientID
        self.job_type_id = job_type_id
        self.namespace = namespace
        self.uid = uid
        # State and payload
        self.status = status
        self.args = args
        self.result = result
        self.priority = priority
        # Timestamps
        self.create_time = create_time
        self.expiration_time = expiration_time
        self.run_after = run_after
        # Grab bookkeeping
        self.grabbed_by = grabbed_by
        self.grabbed_until = grabbed_until
        self.grabbed_at = grabbed_at
        self.trycount = trycount

    @classmethod
    def from_db_row(cls, row):
        """Build a job from a DB result row, unpacking whichever blob columns it has."""
        fields = dict(row)
        for blob_name in ('args', 'result'):
            if blob_name in fields:
                fields[blob_name] = unpack_value(fields[blob_name])
        return cls(**fields)

    @classmethod
    def updates_to_db(cls, updates):
        """Convert a {field: value} update mapping into DB values.

        Non-None ``args``/``result`` values are packed; None and every other
        field are passed through unchanged.
        """
        return dict(
            (name, pack_value(value) if value is not None and name in ('args', 'result') else value)
            for name, value in updates.iteritems()
        )

    def to_db_row(self):
        """Return {column: value} for all non-None attributes, with blob fields packed.

        None attributes are omitted, so an INSERT built from the result leaves
        those columns to their server defaults.
        """
        row = {}
        for field in self.__slots__:
            current = getattr(self, field)
            if current is None:
                continue
            row[field] = pack_value(current) if field in ('args', 'result') else current
        return row

    def __repr__(self):
        parts = ['%s=%r' % (field, getattr(self, field)) for field in self.__slots__]
        return 'DBQueueJob(%s)' % (', '.join(parts),)

class DBQueue(object):
    """A queue of jobs of a single job type, optionally confined to a namespace.

    Active jobs live in ``dbqueue_jobs``. Jobs marked Finished, Failed or
    Revoked are moved to ``dbqueue_job_archive``.
    """
    __slots__ = ('db', 'job_type', 'job_type_id', 'namespace', 'minimum_priority')

    def __init__(self, db, job_type, job_type_id=None, namespace=None, minimum_priority=None):
        """Object for working with a job queue

        :param db: database holding the queue (a string, an Engine or a
                   Connection); a string is turned into an engine by ``db_engine``
        :param job_type: type of the jobs in the queue
        :param job_type_id: optional numeric id of the job type; when omitted it
                            is resolved (and registered if needed) by
                            ``job_type_to_id``
        :param namespace: namespace for the jobs, combined with the module-level
                          ``NAMESPACE`` (see ``absolute_namespace``)
        :param minimum_priority: when set, ``grab_jobs`` only grabs jobs whose
                                 priority is at least this value
        """
        if isinstance(db, basestring):
            db = db_engine(db)
        if job_type_id is None:
            job_type_id = job_type_to_id(db, job_type)
        self.db = db
        self.job_type = job_type
        self.job_type_id = job_type_id
        self.namespace = namespace
        self.minimum_priority = minimum_priority

    @property
    def absolute_namespace(self):
        """Namespace value stored in the DB: ``NAMESPACE`` combined with ``self.namespace``."""
        return combine_namespaces(NAMESPACE, self.namespace)

    def with_connection(self, conn):
        """Return a copy of this DBQueue that uses the given conn"""
        return DBQueue(conn, self.job_type, self.job_type_id, self.namespace, self.minimum_priority)

    def active_job_ids(self, job_ids=None, limit=None):
        """Return ids of jobs of this type that have not been moved to the archive yet

        :param job_ids: if given, only these ids are checked (in any namespace);
                        otherwise jobs of this queue's namespace are listed
        :param limit: maximum number of ids to return
        """
        where_clauses = [
            schema.dbqueue_jobs.c.job_type_id == self.job_type_id,
        ]
        if job_ids is not None:
            where_clauses.append(schema.dbqueue_jobs.c.job_id.in_(job_ids))
        else:
            where_clauses.append(schema.dbqueue_jobs.c.namespace == self.absolute_namespace)
        sql = select([schema.dbqueue_jobs.c.job_id]).where(and_(*where_clauses))
        if limit is not None:
            sql = sql.limit(limit)
        with self.db.connect() as conn:
            rows = conn.execute(sql).fetchall()
        return [row[0] for row in rows]

    def status_by_job_ids(self, job_ids):
        """Return {job_id: status} for each job_id found (active or archived) of this type"""
        sql = (
            select([schema.dbqueue_jobs.c.job_id, schema.dbqueue_jobs.c.status])
            .where(and_(
                schema.dbqueue_jobs.c.job_id.in_(job_ids),
                schema.dbqueue_jobs.c.job_type_id == self.job_type_id,
            ))
            .union_all(
                select([schema.dbqueue_job_archive.c.job_id, schema.dbqueue_job_archive.c.status])
                .where(and_(
                    schema.dbqueue_job_archive.c.job_id.in_(job_ids),
                    schema.dbqueue_job_archive.c.job_type_id == self.job_type_id,
                ))
            )
        )
        with self.db.connect() as conn:
            rows = conn.execute(sql).fetchall()
        return {row[0]: row[1] for row in rows}

    def find(self, ClientID=None, uid=None, status=None, status_not=None, job_id=None, skip_archived=False, limit=None):
        """Find jobs matching the given criteria

        ClientID, uid, status, status_not and job_id may each be a single value
        or an iterable of values (see ``rich_equality``). When job_id is given
        the namespace is not checked; otherwise only jobs of this queue's
        namespace match.

        :param skip_archived: search only the active table
        :param limit: maximum number of jobs to return
        :returns: list of DBQueueJob (active jobs have result=None)
        """
        # Each clause is a function of the table, so the same criteria can be
        # applied to both the active and the archive table.
        where_clauses = [
            lambda table: table.c.job_type_id == self.job_type_id,
        ]
        if ClientID is not None:
            where_clauses.append(lambda table: rich_equality(table.c.ClientID, ClientID))
        if uid is not None:
            where_clauses.append(lambda table: rich_equality(table.c.uid, uid))
        if status is not None:
            where_clauses.append(lambda table: rich_equality(table.c.status, status))
        if status_not is not None:
            where_clauses.append(lambda table: not_(rich_equality(table.c.status, status_not)))
        if job_id is not None:
            where_clauses.append(lambda table: rich_equality(table.c.job_id, job_id))
        else:
            where_clauses.append(lambda table: table.c.namespace == self.absolute_namespace)
        # The active table has no result column; a NULL 'result' keeps both
        # sides of the UNION ALL column-compatible.
        sql = select([schema.dbqueue_jobs, null().label('result')]).where(and_(*(
            clause(schema.dbqueue_jobs) for clause in where_clauses
        )))
        if not skip_archived:
            sql = sql.union_all(select([schema.dbqueue_job_archive]).where(and_(*(
                clause(schema.dbqueue_job_archive) for clause in where_clauses
            ))))
        if limit is not None:
            sql = sql.limit(limit)
        with self.db.connect() as conn:
            rows = conn.execute(sql).fetchall()
        return [DBQueueJob.from_db_row(row) for row in rows]

    def find_by_id(self, job_id):
        """Find a job of this type by its job_id in the active table, then the archive

        Returns None if the job is found in neither.
        """
        with self.db.connect() as conn, conn.begin():
            row = conn.execute(
                schema.dbqueue_jobs.select()
                .where(and_(
                    schema.dbqueue_jobs.c.job_id == job_id,
                    schema.dbqueue_jobs.c.job_type_id == self.job_type_id,
                ))
            ).first()
            if row is None:
                row = conn.execute(
                    schema.dbqueue_job_archive.select()
                    .where(and_(
                        schema.dbqueue_job_archive.c.job_id == job_id,
                        schema.dbqueue_job_archive.c.job_type_id == self.job_type_id,
                    ))
                ).first()
                if row is None:
                    return None
        return DBQueueJob.from_db_row(row)

    def insert_job(self, **kwargs):
        """Insert a job built from kwargs and return it as a DBQueueJob

        kwargs are DBQueueJob fields plus an optional ``grab_for`` (seconds,
        used only when status='Grabbed'; None means DEFAULT_GRAB_FOR).
        ``namespace`` may not be passed: the queue's absolute namespace is used.
        ``job_id`` is generated by get_new_id when omitted; ClientID defaults to 0.

        With status='Grabbed' the job is inserted already grabbed by this
        process: trycount defaults to 1, and grabbed_at/grabbed_until are
        computed by the server unless given.

        :raises ValueError: if namespace is passed
        """
        if 'namespace' in kwargs:
            raise ValueError('Cannot specify namespace when inserting jobs')
        # Treat an explicit None like grab_jobs/grab_extend do. Previously it
        # produced "NOW() + INTERVAL NULL SECOND", i.e. grabbed_until = NULL,
        # which no other process could ever consider expired.
        grab_for = kwargs.pop('grab_for', None)
        if grab_for is None:
            grab_for = DEFAULT_GRAB_FOR
        if 'job_id' not in kwargs:
            kwargs['job_id'] = get_new_id('job_id')
        kwargs.setdefault('ClientID', 0)
        kwargs['job_type_id'] = self.job_type_id
        kwargs['namespace'] = self.absolute_namespace
        job = DBQueueJob(**kwargs)
        if job.status == 'Grabbed':
            # Insert the job directly in the grabbed state
            process_id = get_global_process_id()
            job.grabbed_by = process_id
            if not job.trycount:
                job.trycount = 1
            data = job.to_db_row()
            if 'grabbed_at' not in data:
                data['grabbed_at'] = text('NOW()')
            if 'grabbed_until' not in data:
                data['grabbed_until'] = text('NOW() + INTERVAL :grab_for SECOND').bindparams(grab_for=grab_for)
        else:
            data = job.to_db_row()
        insert_statement = (
            schema.dbqueue_jobs.insert()
            .values(**data)
        )
        with self.db.connect() as conn:
            # MySQLdb warnings raised by this insert are deliberately suppressed
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category = MySQLdb.Warning)
                conn.execute(insert_statement)
        return job

    def update_job(self, job, **updates):
        """Update job, changing the fields given in updates

        Returns True if the job is still active (not archived) and was updated;
        the job object is then updated too.

        :raises ValueError: if job_id, status or namespace is among updates
        """
        for forbidden_key in ('job_id', 'status', 'namespace'):
            if forbidden_key in updates:
                raise ValueError('Cannot modify key %s' % (forbidden_key,))
        update_statement = (
            schema.dbqueue_jobs.update()
            .where(and_(
                schema.dbqueue_jobs.c.job_id == job.job_id,
                schema.dbqueue_jobs.c.job_type_id == self.job_type_id,
            ))
            .values(**DBQueueJob.updates_to_db(updates))
        )
        with self.db.connect() as conn:
            affected_rows = conn.execute(update_statement).rowcount
        if not affected_rows:
            return False
        for name, value in updates.iteritems():
            setattr(job, name, value)
        return True

    def grab_job(self, grab_for=None):
        """Grab a single job for grab_for seconds; return it, or None if there is none"""
        jobs = self.grab_jobs(limit=1, grab_for=grab_for)
        if jobs:
            return jobs[0]
        return None

    def grab_jobs(self, limit=1, grab_for=None):
        """Grab at most limit jobs for grab_for seconds from the current time

        Eligible jobs are New ones, ones whose grab has expired, and ones
        already grabbed by this process, provided run_after has passed and, when
        minimum_priority is set, priority is high enough. Lowest job_ids go first.

        :returns: list of grabbed DBQueueJob objects (possibly empty)
        :raises ValueError: if limit < 1
        :raises RuntimeError: if re-reading the grabbed rows does not match the
                              number of updated rows
        """
        if limit < 1:
            raise ValueError('limit must be a positive integer')
        if grab_for is None:
            grab_for = DEFAULT_GRAB_FOR
        process_id = get_global_process_id()
        table = schema.dbqueue_jobs
        where_clause = [
            table.c.job_type_id == self.job_type_id,
            table.c.namespace == self.absolute_namespace,
            or_(
                table.c.status == 'New',
                table.c.grabbed_until < text('NOW()'),
                table.c.grabbed_by == process_id,
            ),
            or_(
                table.c.run_after == None,  # SQLAlchemy renders this as IS NULL
                table.c.run_after < text('NOW()'),
            ),
        ]
        if self.minimum_priority is not None:
            # Columns live in `.c`; `table.priority` raised AttributeError.
            where_clause.append(table.c.priority >= self.minimum_priority)
        where_clause = and_(*where_clause)
        update_statement = (
            UpdateOrderByLimit(table)
            .where(where_clause)
            .values(
                status='Grabbed',
                grabbed_by=process_id,
                grabbed_at=text('NOW()'),
                grabbed_until=text('NOW() + INTERVAL :grab_for SECOND').bindparams(grab_for=grab_for),
                trycount=text('trycount + 1'),
            )
            .order_by(table.c.job_id)
            .limit(limit)
        )
        # Re-read the rows just grabbed. The filter repeats the UPDATE's
        # conditions (namespace, priority, run_after) and its ORDER BY. If this
        # process already holds other grabbed jobs, for example in another
        # namespace or more than `limit` of them, the result is then exactly
        # the rows the UPDATE touched, not an arbitrary subset of everything
        # owned by this process.
        select_statement = (
            table.select()
            .where(and_(
                where_clause,
                table.c.status == 'Grabbed',
                table.c.grabbed_by == process_id,
            ))
            .order_by(table.c.job_id)
            .limit(limit)
        )
        with self.db.connect() as conn, conn.begin():
            affected_rows = conn.execute(update_statement).rowcount
            log.debug('grab_jobs updated %s rows in the database', affected_rows)
            if not affected_rows:
                return []
            rows = conn.execute(select_statement).fetchall()
            if len(rows) != affected_rows:
                # The update and select conditions disagree somehow
                raise RuntimeError('Fetched %d rows, but expected %d rows' % (len(rows), affected_rows))
        return [DBQueueJob.from_db_row(row) for row in rows]

    def grab_extend(self, job, grab_for=None):
        """Extend the grab of job to grab_for seconds from the current time

        Returns True if the job was still grabbed by this process and the grab
        was extended.
        """
        if grab_for is None:
            grab_for = DEFAULT_GRAB_FOR
        process_id = get_global_process_id()
        where_clause = and_(
            schema.dbqueue_jobs.c.job_id == job.job_id,
            schema.dbqueue_jobs.c.status == 'Grabbed',
            schema.dbqueue_jobs.c.grabbed_by == process_id,
        )
        update_statement = (
            schema.dbqueue_jobs.update()
            .where(where_clause)
            .values(
                grabbed_until=text('NOW() + INTERVAL :grab_for SECOND').bindparams(grab_for=grab_for),
            )
        )
        with self.db.connect() as conn:
            affected_rows = conn.execute(update_statement).rowcount
            log.debug('grab_extend updated %s rows in the database', affected_rows)
        if not affected_rows:
            return False
        return True

    def grab_release(self, job, backoff=None):
        """Release job so that another process can handle it

        :param backoff: if given, run_after is set to NOW() + backoff seconds,
                        so grab_jobs skips the job until then
        Returns True if the job was still grabbed by this process and was released.
        """
        process_id = get_global_process_id()
        updates = {
            'status': 'New',
            'grabbed_by': '',
            'grabbed_until': null(),
        }
        if backoff is not None:
            updates['run_after'] = text('NOW() + INTERVAL :backoff SECOND').bindparams(backoff=backoff)
        update_statement = (
            schema.dbqueue_jobs.update()
            .where(and_(
                schema.dbqueue_jobs.c.job_id == job.job_id,
                schema.dbqueue_jobs.c.status == 'Grabbed',
                schema.dbqueue_jobs.c.grabbed_by == process_id,
            ))
            .values(**updates)
        )
        with self.db.connect() as conn:
            affected_rows = conn.execute(update_statement).rowcount
            log.debug('grab_release updated %s rows in the database', affected_rows)
        if not affected_rows:
            return False
        job.status = 'New'
        job.grabbed_by = ''
        job.grabbed_until = None
        return True

    def _archive_job(self, job, status, result=None):
        """Move jobs from the active queue to the archive, changing their status

        :param job: a DBQueueJob object, a job id, or a list/tuple of job ids.
        :param status: archive status to set ('Finished', 'Failed' or 'Revoked')
        :param result: value stored (packed) in the archive ``result`` column

        Only jobs of this queue's job type are affected.

        Returns True if at least one job was moved to the archive.

        :raises TypeError: if job is of an unsupported type
        :raises RuntimeError: if the DELETE did not match exactly the archived rows
        """
        if isinstance(job, DBQueueJob):
            job_id = job.job_id
        elif isinstance(job, (int, long)):
            job_id = job
            job = None
        elif isinstance(job, (list, tuple)):
            if not job:
                return False
            for job_id in job:
                if not isinstance(job_id, (int, long)):
                    raise TypeError('job must be either DBQueueJob or a job_id or a list of job_ids')
            job_id = job
            job = None
        else:
            raise TypeError('job must be either DBQueueJob or a job_id or a list of job_ids')
        active = schema.dbqueue_jobs
        archive = schema.dbqueue_job_archive
        # Scope to this queue's job type, like every other lookup in this class;
        # previously a bare job id could archive a job belonging to another queue.
        selected = and_(
            rich_equality(active.c.job_id, job_id),
            active.c.job_type_id == self.job_type_id,
        )
        packed_result = pack_value(result)
        copied_columns = [c for c in active.c if c.name not in ('status',)]
        sql_insert = (
            archive.insert()
            .from_select(
                copied_columns + ['status', 'result'],
                select(copied_columns + [
                    literal(status, archive.c.status.type).label('status'),
                    literal(packed_result, archive.c.result.type).label('result'),
                ], for_update=True)
                .where(selected)
            )
        )
        sql_delete = active.delete().where(selected)
        with self.db.connect() as conn, conn.begin():
            inserted = conn.execute(sql_insert).rowcount
            if not inserted:
                # No rows were inserted, so no row matched the condition. This
                # can normally happen if another process already called
                # _archive_job for the same jobs.
                return False
            # Delete the rows copied into the archive. They were selected FOR
            # UPDATE, so all of them should be deleted.
            deleted = conn.execute(sql_delete).rowcount
            if deleted != inserted:
                # The delete must remove exactly the rows that were selected;
                # anything else means something went badly wrong in MySQL.
                raise RuntimeError('Archived %r jobs, however delete matched %r jobs' % (inserted, deleted))
        if job is not None:
            # Update the passed object to reflect the archived state
            job.status = status
            job.result = result
        return True

    def mark_finished(self, job, result=None):
        """Mark a job or jobs as Finished with the given result"""
        return self._archive_job(job, 'Finished', result)

    def mark_failed(self, job, result=None):
        """Mark a job or jobs as Failed with the given result"""
        return self._archive_job(job, 'Failed', result)

    def mark_revoked(self, job, result=None):
        """Mark a job or jobs as Revoked with the given result"""
        return self._archive_job(job, 'Revoked', result)
