# -*- coding: utf-8 -*-
import logging

from django.conf import settings
from django.core.management import BaseCommand

import ydb

from travel.avia.library.python.ticket_daemon.date import unixtime
from travel.avia.ticket_daemon.ticket_daemon.lib.ydb import cache as ydb_cache, delete_expired
from ._loggers import setup_logger

logger = logging.getLogger(__name__)


class GarbageCollector(delete_expired.BaseGarbageCollector):
    """Garbage collector for the YDB results cache table.

    Specialises ``delete_expired.BaseGarbageCollector`` by providing:
    the results table name, the number of TTL tables, the per-row delete
    operation, and the TTL table naming scheme.
    """

    # Table holding the cached results (name taken from the ydb cache module).
    TARGET_TABLE = ydb_cache.RESULTS_TABLE_NAME
    # Number of TTL tables that accompany the results table.
    TTL_TABLE_COUNT = ydb_cache.RESULTS_TTL_TABLES_COUNT

    def delete_row(self, session_pool, ttl_table_name, row):
        """Delete the entry described by ``row`` from the results table and a TTL table.

        Both DELETE statements run in one serializable read-write transaction
        that is committed at once, so either both rows go or neither does.

        :param session_pool: YDB session pool; the work is executed via its
            ``retry_operation_sync``.
        :param ttl_table_name: name of the TTL table to delete from.
        :param row: mapping of column name -> value. Every key is bound as a
            ``$``-prefixed query parameter. It presumably must supply exactly the
            DECLAREd columns (point_from ... expires_at) — NOTE(review): confirm.
        :return: the value returned by ``session_pool.retry_operation_sync``.
        """
        query_text = """
            PRAGMA TablePathPrefix("{path}");

            DECLARE $point_from AS Utf8;
            DECLARE $point_to AS Utf8;
            DECLARE $date_forward AS Uint32;
            DECLARE $date_backward AS Uint32;
            DECLARE $klass AS Uint8;
            DECLARE $passengers AS Uint32;
            DECLARE $national_version AS Utf8;
            DECLARE $lang AS Utf8;
            DECLARE $partner_code AS Utf8;
            DECLARE $expires_at AS Uint64;

            DELETE FROM {results}
            WHERE
                point_from = $point_from
                AND point_to = $point_to
                AND date_forward = $date_forward
                AND date_backward = $date_backward
                AND klass = $klass
                AND passengers = $passengers
                AND national_version = $national_version
                AND lang = $lang
                AND partner_code = $partner_code
                AND expires_at = $expires_at;

            DELETE FROM {result_ttl}
            WHERE
                point_from = $point_from
                AND point_to = $point_to
                AND date_forward = $date_forward
                AND date_backward = $date_backward
                AND klass = $klass
                AND passengers = $passengers
                AND national_version = $national_version
                AND lang = $lang
                AND partner_code = $partner_code
                AND expires_at = $expires_at;
        """.format(
            path=settings.DRIVER_CONFIG.database,
            results=self.TARGET_TABLE,
            result_ttl=ttl_table_name,
        )

        def _delete_in_session(ydb_session):
            # Built inside the retried callable, as before, so any failure
            # here surfaces through retry_operation_sync exactly as it did.
            query_params = dict(
                ('$' + column, value) for column, value in row.iteritems()
            )
            statement = ydb_session.prepare(query_text)
            tx = ydb_session.transaction(ydb.SerializableReadWrite())
            tx.execute(
                statement,
                commit_tx=True,
                parameters=query_params,
            )

        return session_pool.retry_operation_sync(_delete_in_session)

    @staticmethod
    def format_ttl_table_name(ttl_table_idx):
        """Return the TTL table name for index ``ttl_table_idx``.

        Delegates to ``ydb_cache.format_ttl_table_name``.
        """
        return ydb_cache.format_ttl_table_name(ttl_table_idx)


class Command(BaseCommand):
    """Management command that removes expired rows from the YDB results cache."""

    help = 'Delete expired YDB rows'

    def add_arguments(self, parser):
        """Register the command's options.

        --stdout          attach a stdout logging handler (off by default)
        --jobs N          parallelism passed to the collector (default 1)
        --start-ts-shift  look-back window in seconds (default: two days)
        """
        two_days_seconds = 2 * 24 * 60 * 60

        parser.add_argument(
            '--stdout',
            dest='add_stdout_handler',
            action='store_true',
            default=False,
            help='Add stdout handler',
        )
        parser.add_argument(
            '--jobs',
            dest='jobs',
            type=int,
            default=1,
            help='parallel deletion from max N ttl tables',
        )
        parser.add_argument(
            '--start-ts-shift',
            dest='start_ts_shift',
            type=int,
            default=two_days_seconds,
            help='Delete expired rows in the last N seconds',
        )

    def handle(self, *db_names, **options):
        """Run a single garbage-collection pass over the TTL tables.

        Positional ``db_names`` are accepted but not used.
        """
        verbosity = options.get('verbosity')
        with_stdout = options.get('add_stdout_handler')
        setup_logger(logger, verbosity, with_stdout)

        # NOTE(review): the meaning of the second argument (20) is defined by
        # BaseGarbageCollector, which is not visible here — confirm.
        collector = GarbageCollector(logger, 20)

        # Unix timestamp marking the start of the look-back window.
        window_start = unixtime() - options.get('start_ts_shift')

        logger.info('Start')
        collector.delete_expired_by_ttl_table(window_start, options.get('jobs'))
        logger.info('End')
