# -*- coding: utf-8 -*-
import functools
import time
import random
from abc import ABCMeta, abstractproperty, abstractmethod
from multiprocessing.dummy import Pool as ThreadPool

from django.conf import settings
from more_itertools import chunked

import ydb

from travel.avia.library.python.ticket_daemon.date import unixtime
from travel.avia.library.python.ydb.session_manager import YdbSessionManager


class BaseGarbageCollector(object):
    """Base class for deleting expired rows tracked in auxiliary YDB TTL tables.

    Every TTL table holds rows with an ``expires_at`` unix timestamp. For each
    TTL table the collector finds a starting timestamp (``get_last_row_timestamp``)
    and then walks forward in time. It fetches batches with ``get_rows`` and hands
    every fetched row to ``delete_row``, stopping once the window start is in the
    future. The walk runs asynchronously on a thread pool: each finished job
    schedules the next one from its success callback.

    Subclasses provide ``TARGET_TABLE``, ``TTL_TABLE_COUNT``, ``delete_row`` and
    ``format_ttl_table_name``.
    """
    SESSION_POOL_SIZE = 50
    __metaclass__ = ABCMeta

    @abstractproperty
    def TARGET_TABLE(self):
        """Table holding the data."""

    @abstractproperty
    def TTL_TABLE_COUNT(self):
        """Number of auxiliary TTL tables."""

    @abstractmethod
    def delete_row(self, session_pool, ttl_table_name, row):
        """Delete the data for one expired row (runs in a pool worker thread)."""

    @staticmethod
    @abstractmethod
    def format_ttl_table_name(ttl_table_idx):
        """Return the name of the auxiliary TTL table with index ``ttl_table_idx``."""

    def __init__(self, logger, threads, pool_size=SESSION_POOL_SIZE):
        """
        :param logger: logger for progress and error messages.
        :param threads: number of worker threads in the job pool.
        :param pool_size: size of each YDB session pool. The default is bound when
            the class body runs, so overriding SESSION_POOL_SIZE in a subclass
            does not change it.
        """
        self.logger = logger
        self._pool = ThreadPool(threads)
        self._session_manager = YdbSessionManager(driver_config=settings.DRIVER_CONFIG, pool_size=pool_size)

    def get_rows(self, session_pool, ttl_table_name, timestamp, ts_delta=60, limit=1000):
        """Fetch up to ``limit`` rows of a TTL table whose ``expires_at`` is in a window.

        :param session_pool: YDB session pool used to run the query.
        :param ttl_table_name: TTL table to read.
        :param timestamp: window start (inclusive), unix seconds.
        :param ts_delta: window length; the end ``timestamp + ts_delta`` is inclusive.
        :param limit: maximum number of rows returned.
        :return: rows of the first result set.
        """
        # NOTE(review): there is no ORDER BY. If a window holds more than ``limit``
        # rows, which rows come back depends on the table's key order. Callers
        # continue from the largest ``expires_at`` they saw, so rows could be
        # skipped unless the TTL table's primary key starts with ``expires_at``.
        # Confirm against the schema.
        request = """
            PRAGMA TablePathPrefix("{path}");

            DECLARE $timestamp AS Uint64;
            DECLARE $prev_timestamp AS Uint64;

            SELECT * from {ttl_table}
            WHERE expires_at <= $timestamp AND expires_at >= $prev_timestamp
            LIMIT {limit};
        """.format(
            path=settings.DRIVER_CONFIG.database,
            ttl_table=ttl_table_name,
            limit=limit,
        )

        def callee(session):
            prepared_query = session.prepare(request)
            result_sets = session.transaction(
                ydb.SerializableReadWrite()
            ).execute(
                prepared_query,
                commit_tx=True,
                parameters={
                    '$timestamp': timestamp + ts_delta,
                    '$prev_timestamp': timestamp,
                }
            )
            return result_sets[0].rows

        return session_pool.retry_operation_sync(callee)

    def delete_rows(self, session_pool, ttl_table_name, current_ts, rows):
        """Delete a fetched batch, then schedule fetching the next batch.

        This is the success callback of a ``get_rows`` job that read the window
        starting at ``current_ts``.
        """
        if current_ts > unixtime():
            # The window starts in the future, so nothing in it has expired yet.
            self.logger.info('Stop deleting from %s', ttl_table_name)
            return

        if rows:
            self.logger.info('Delete %d rows from %s', len(rows), ttl_table_name)
            # The next fetch starts at the newest expiry seen in this batch.
            next_ts = max(row['expires_at'] for row in rows)
            self._pool.map_async(
                self._log_errors(functools.partial(self.delete_row, session_pool, ttl_table_name)),
                rows,
                callback=functools.partial(
                    self.delete_rows_callback, session_pool, ttl_table_name, next_ts
                )
            )
        else:
            # Empty window: advance by one window (get_rows' default ts_delta)
            # and read that next window rather than the one just found empty.
            next_ts = current_ts + 60
            self._pool.apply_async(
                self._log_errors(self.get_rows),
                args=(session_pool, ttl_table_name, next_ts),
                callback=functools.partial(self.delete_rows, session_pool, ttl_table_name, next_ts)
            )

    def delete_rows_callback(self, session_pool, ttl_table_name, timestamp, *args):
        """Success callback of a batch deletion: fetch the batch starting at ``timestamp``.

        Extra positional arguments (the ``map_async`` result list) are ignored.
        """
        self.logger.info('Get rows for deletion from %s, Start timestamp: %d', ttl_table_name, timestamp)
        self._pool.apply_async(
            self._log_errors(self.get_rows),
            args=(session_pool, ttl_table_name, timestamp),
            callback=functools.partial(self.delete_rows, session_pool, ttl_table_name, timestamp)
        )

    def delete_expired_by_ttl_table(self, start_ts, jobs=1):
        """Delete expired rows from every TTL table, ``jobs`` tables at a time.

        Tables are visited in random order. Each chunk of ``jobs`` tables is
        fully drained before the next chunk starts. When everything is done
        the thread pool is closed, so an instance runs this only once.

        :param start_ts: unix timestamp to start searching each TTL table from.
        :param jobs: number of TTL tables processed concurrently.
        """
        ttl_tables = [self.format_ttl_table_name(idx) for idx in range(self.TTL_TABLE_COUNT)]
        random.shuffle(ttl_tables)
        for tables in chunked(ttl_tables, jobs):
            self._process_tables(tables, start_ts)
        self._pool.close()
        self._pool.join()

    def _process_tables(self, tables, start_ts, index=0):
        """Open a session pool per table, start its jobs, and wait for all of them.

        The ``with`` blocks are nested through recursion. This keeps every
        table's session pool open until the asynchronous jobs using it have
        finished. Recursion depth equals ``len(tables)``, which is at most ``jobs``.
        """
        if index >= len(tables):
            self._wait_for_pending_jobs()
            return
        with self._session_manager.get_session_pool() as session_pool:
            self._start_table(session_pool, tables[index], start_ts)
            self._process_tables(tables, start_ts, index + 1)

    def _start_table(self, session_pool, ttl_table_name, start_ts):
        """Find where to begin in ``ttl_table_name`` and submit its first fetch job."""
        current_ts = self.get_last_row_timestamp(session_pool, ttl_table_name, start_ts)
        if current_ts is None:
            self.logger.info('Nothing to delete from %s', ttl_table_name)
            return
        self.logger.info('Add %s to pool. Start timestamp: %d', ttl_table_name, current_ts)
        self._pool.apply_async(
            self._log_errors(self.get_rows),
            args=(session_pool, ttl_table_name, current_ts),
            callback=functools.partial(self.delete_rows, session_pool, ttl_table_name, current_ts)
        )

    def _wait_for_pending_jobs(self, poll_interval=10):
        """Block until the job pool reports no pending jobs, logging progress."""
        # NOTE(review): this polls the private ``Pool._cache``. It relies on a
        # callback that submits a follow-up job running before the finished job
        # leaves the cache, so a live chain never makes the cache look empty.
        # Confirm for the Python version in use.
        while self._pool._cache:
            self.logger.info('Number of jobs pending: %d', len(self._pool._cache))
            time.sleep(poll_interval)

    def _log_errors(self, func):
        """Wrap ``func`` so an exception raised in a pool worker is logged, then re-raised.

        On Python 2, ``apply_async``/``map_async`` take no error callback. A job
        that raises just skips its success callback, which ends that table's
        chain of jobs without any trace. Logging makes the stop visible.
        """
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception:
                self.logger.exception('Garbage collector job failed')
                raise
        return wrapper

    def get_last_row_timestamp(self, session_pool, ttl_table_name, start_ts):
        """Return ``expires_at`` of a row in the first non-empty 20-minute window.

        Scans forward from ``start_ts`` in 20-minute steps while the window start
        is before the current time. The query uses ``LIMIT 1`` without ORDER BY,
        so the returned row is not guaranteed to be the earliest in its window.

        :return: a unix timestamp, or None if no such window contains rows.
        """
        self.logger.info('Get %s last row timestamp', ttl_table_name)
        step = 60 * 20  # 20 minutes

        while start_ts < unixtime():
            rows = self.get_rows(session_pool, ttl_table_name, start_ts, ts_delta=step, limit=1)
            if rows:
                return rows[0]['expires_at']

            start_ts += step

        return None
