# coding: utf-8
import logging
import random
import time

import passport.backend.core.ydb_client as ydb
from passport.backend.profile.utils.yt import ExclusiveLock
from retrying import retry


# Module-wide logger; also handed to the retry policy of YDB requests below.
log = logging.getLogger('passport.profile')

# (left limit, right limit) pairs of the uid ranges that are scanned for rows.
# The range lookup queries compare strictly (uid > left and uid < right),
# so the limits themselves are never returned as range boundaries.
UID_RANGES = (
    (0, 5 * 10 ** 9),  # ordinary
    (1110000000000000, 1110000100000000),  # kinopoisk
    (1130000000000000, 1130000200000000),  # pdd
)


# Template of the cleanup statement. Placeholders: table_name, beginning,
# ending, unixtime. Removes rows whose uid lies in [beginning, ending]
# (BETWEEN includes both ends) and that are either older than `unixtime`
# by `updated_at` or have a NULL `value`.
DELETE_QUERY = """
--!syntax_v1

DELETE FROM {table_name} WHERE
uid between {beginning} and {ending} and (updated_at < {unixtime} or value IS NULL);
"""

# Template that selects the smallest uid strictly between min_uid and max_uid.
# Placeholders: table_name, min_uid, max_uid.
FIND_UIDS_RANGE_MIN_QUERY = """
--!syntax_v1

SELECT uid FROM {table_name}
WHERE uid > {min_uid} and uid < {max_uid}
ORDER BY uid asc
LIMIT 1
"""

# Template that selects the largest uid strictly between min_uid and max_uid.
# Placeholders: table_name, min_uid, max_uid.
FIND_UIDS_RANGE_MAX_QUERY = """
--!syntax_v1

SELECT uid FROM {table_name}
WHERE uid > {min_uid} and uid < {max_uid}
ORDER BY uid desc
LIMIT 1
"""


def get_ydb(config):
    """Build a YDB driver from the ``config['ydb']`` section and return it.

    The ``endpoint``, ``database`` and ``token`` keys of ``config['ydb']``
    are used as connection parameters. ``driver.wait(timeout=5)`` is called
    before the driver is handed back to the caller.
    """
    ydb_settings = config['ydb']
    params = ydb.ConnectionParams(
        ydb_settings['endpoint'],
        database=ydb_settings['database'],
        auth_token=ydb_settings['token'],
    )

    ydb_driver = ydb.Driver(params)
    ydb_driver.wait(timeout=5)
    return ydb_driver


def retry_policy(exceptions, retry_logger):
    """Return a predicate that decides whether a raised exception is retriable.

    The returned callable takes the raised exception, writes its text to
    ``retry_logger`` at info level and answers ``True`` when the exception is
    an instance of ``exceptions`` (a class or a tuple of classes).
    """
    def _should_retry(error):
        # Every failure is logged, whether or not it ends up being retried.
        retry_logger.info(str(error))
        retriable = isinstance(error, exceptions)
        return retriable

    return _should_retry


@retry(
    retry_on_exception=retry_policy(Exception, log),
    stop_max_attempt_number=3,
    wait_fixed=2000,
    wrap_exception=True,
)
def do_ydb_request(session, query):
    """Run ``query`` in a fresh transaction of ``session`` and commit it.

    Returns whatever the transaction's ``execute`` call returns.

    The call is wrapped with ``retrying.retry``: any ``Exception`` is logged
    and retried, at most 3 attempts are made with a fixed wait of 2000
    (milliseconds, per the ``retrying`` documentation) between them. With
    ``wrap_exception=True`` the final failure is raised wrapped in the
    library's ``RetryError``.
    """
    transaction = session.transaction()
    return transaction.execute(query, commit_tx=True)


def find_uids_range_limits_from_ydb(session, config, min_uid, max_uid):
    """Find the smallest and the largest uid stored strictly inside a range.

    Two queries are issued against ``config['ydb']['table_name']``: one for
    the lowest and one for the highest uid with ``min_uid < uid < max_uid``.

    Returns a ``(lowest_uid, highest_uid)`` tuple of ints, or ``(0, 0)`` when
    either query yields no rows.
    """
    # Both queries are executed before any result is inspected.
    results = [
        do_ydb_request(
            session,
            template.format(
                table_name=config['ydb']['table_name'],
                min_uid=min_uid,
                max_uid=max_uid,
            ),
        )
        for template in (FIND_UIDS_RANGE_MIN_QUERY, FIND_UIDS_RANGE_MAX_QUERY)
    ]

    row_counts = [len(result[0].rows) for result in results]
    if 0 in row_counts:
        return 0, 0

    lowest, highest = results
    return int(lowest[0].rows[0].uid), int(highest[0].rows[0].uid)


def get_uid_ranges_and_generate_partitions(config):
    """Build a shuffled list of uid partitions covering the rows present in YDB.

    For every range in ``UID_RANGES`` the actual lowest and highest stored
    uids are looked up, and the span between them is cut into partitions of
    ``config['ydb']['clearing_chunk_size']`` uids.

    Returns a list of ``(begin, end)`` tuples in random order.
    """
    driver = get_ydb(config)
    session = driver.table_client.session().create()
    chunk_size = config['ydb']['clearing_chunk_size']

    partitions = []

    for left_limit, right_limit in UID_RANGES:
        min_uid, max_uid = find_uids_range_limits_from_ydb(session, config, left_limit, right_limit)
        if (min_uid, max_uid) == (0, 0):
            # (0, 0) is the "no rows in this range" marker of the lookup; a real
            # boundary is always greater than the (non-negative) left limit.
            # Without this guard every empty range would add a bogus
            # (0, chunk_size - 1) partition.
            continue
        partitions.extend(generate_partitions(min_uid, max_uid, chunk_size))

    # NOTE(review): shuffling presumably spreads deletes across the key space
    # instead of walking it sequentially; confirm the intent.
    random.shuffle(partitions)
    return partitions


def generate_partitions(begin, end, size):
    """Split the uid span ``[begin, end]`` into ``(first, last)`` partitions.

    Each partition starts ``size`` uids after the previous one and covers
    ``size`` uids. The last partition may extend past ``end``. An empty list
    is returned when ``begin`` is greater than ``end``.
    """
    # The delete query uses BETWEEN, which includes both boundaries, so the
    # right edge of a partition is ``size - 1`` past its left edge.
    last_offset = size - 1

    partitions = []
    # ``end + 1`` as the stop value: otherwise, when ``end`` falls exactly on
    # a partition start, that final uid would not get into any partition.
    for first in range(begin, end + 1, size):
        partitions.append((first, first + last_offset))
    return partitions


def delete_rows_partition(config, session, unixtime, partition):
    """Delete outdated rows of a single uid partition.

    ``partition`` is a ``(begin, end)`` pair of uids, both boundaries
    included. Rows of ``config['ydb']['table_name']`` within that uid span
    are removed when their ``updated_at`` is below ``unixtime`` or their
    ``value`` is NULL.
    """
    first_uid, last_uid = partition
    query = DELETE_QUERY.format(
        table_name=config['ydb']['table_name'],
        unixtime=unixtime,
        beginning=first_uid,
        ending=last_uid,
    )
    do_ydb_request(session, query)


def delete_rows_partitions(config, unixtime, partitions):
    """Delete outdated rows partition by partition, skipping failed partitions.

    A driver and a session are created from ``config``; every ``(begin, end)``
    partition is then cleaned with ``delete_rows_partition``. A partition whose
    deletion raises is logged (with the traceback) and skipped, and a new
    session is created for the remaining partitions. Progress is logged after
    every 100 successfully processed partitions.
    """
    driver = get_ydb(config)
    session = driver.table_client.session().create()

    # Counts only partitions that were actually processed without an error, so
    # the progress message is neither off by one nor inflated by skipped ones.
    deleted = 0
    for partition in partitions:
        try:
            delete_rows_partition(config, session, unixtime, partition)
            deleted += 1
            if deleted % 100 == 0:
                log.info('Deleted %s partitions of size %s', deleted, config['ydb']['clearing_chunk_size'])
        except Exception:
            # Best effort: one broken partition must not stop the whole run,
            # but the reason has to be visible in the log.
            log.warning('Partition %s omitted due issues with removal', partition, exc_info=True)
            # Start over with a new session for the remaining partitions.
            session = driver.table_client.session().create()


def delete_old_rows(config):
    """Delete rows older than one week while holding an exclusive lock.

    The lock path is taken from ``config['yt']['delete_old_rows_daily_lock']``.
    Inside the lock the uid partitions are computed and cleaned one by one.
    """
    lock_path = config['yt']['delete_old_rows_daily_lock']
    with ExclusiveLock(config=config, lock_path=lock_path):
        week_in_seconds = 7 * 24 * 60 * 60
        threshold_seconds = int(time.time()) - week_in_seconds
        # NOTE(review): the factor of 1000000 presumably converts seconds to
        # microseconds to match the units of `updated_at`; confirm.
        unixtime_week_ago = threshold_seconds * 1000000

        log.info('Starting old rows deletion before %s unixtime...', unixtime_week_ago)
        delete_rows_partitions(
            config,
            unixtime_week_ago,
            get_uid_ranges_and_generate_partitions(config),
        )
        log.info('Deletion finished!')
