import time
import threading
import psycopg2
from collections import namedtuple, defaultdict
from multiprocessing.pool import ThreadPool

import shards
from db_requests import *

# One user of one service; built in _distribute_by_xtable_conninfo from a
# counter row's uid and the service name looked up by the counter's sid.
User = namedtuple('User', [
    'service',
    'uid'
])

# Per-thread state. `logger` is set by remove_unused_counters and by
# _init_worker; pool workers additionally keep `conninfo`, `connection`
# and `cursor` here (see _init_worker).
tls = threading.local()
# Number of worker threads created by _try_delete_counters.
DELETE_POOL_SIZE = 6
# Share, in percent, of the uid ranges kept by _apply_user_cleanup_limit.
USERS_TO_CLEAN_PERCENT = 1
# Service names whose sids are excluded when counters are selected
# (see _blacklisted_sids and _select_counters).
SERVICE_BLACKLIST = [
    'apns_queue',
    'passport-stream',
    'webpushapi'
]

def _profiled(func):
    def wrapper(*args, **kwargs):
        try:
            start_time = time.time()
            result = func(*args, **kwargs)
            duration = time.time() - start_time
            tls.logger.info('%s done in %.6f sec', func.__name__, duration)
            return result
        finally:
            pass
    return wrapper

def remove_unused_counters(xstore_shard, logger):
    """Run one cleanup pass over the counters stored on ``xstore_shard``.

    For each uid range that is still pending, counters are selected, the
    ones belonging to users with acknowledged subscriptions are filtered
    out, and the rest are handed to ``_try_delete_counters``. After every
    range the number of processed ranges is persisted. The function ends
    with a summary log line and a 30 second sleep.

    ``logger`` is stored in thread-local state for the helpers and is also
    passed on to the delete workers.
    """
    tls.logger = logger
    sid_to_service = _select_services(xstore_shard)
    excluded_sids = _blacklisted_sids(sid_to_service)
    all_ranges = _select_uid_ranges(xstore_shard)
    limited_ranges = _apply_user_cleanup_limit(all_ranges)
    pending_ranges, already_done = _skip_processed_ranges(xstore_shard, limited_ranges)

    total_deleted = 0
    # Start counting after the ranges handled by earlier runs so that the
    # loop variable is directly the value to persist.
    for ranges_done, uid_range in enumerate(pending_ranges, start=already_done + 1):
        counters = _select_counters(xstore_shard, uid_range, excluded_sids)
        users_by_conninfo = _distribute_by_xtable_conninfo(counters, sid_to_service)
        acked = _select_acked_subscriptions(users_by_conninfo)
        candidates = _filter_used_counters(sid_to_service, counters, acked)
        total_deleted += _try_delete_counters(candidates, xstore_shard, logger)
        _update_ranges_processed(xstore_shard, ranges_done)

    logger.info('removed %s counters among %s%% of users, taking a short break', total_deleted, USERS_TO_CLEAN_PERCENT)
    time.sleep(30)

@_profiled
def _select_services(xstore_shard):
    """Return a ``{sid: service_name}`` dict built from ``xiva.services``.

    The rows are read through the shard's read conninfo.
    """
    service_rows = execute(
        shards.conninfo_for_read(xstore_shard),
        'SELECT sid, service_name FROM xiva.services',
    )
    return dict((row.sid, row.service_name) for row in service_rows)

def _blacklisted_sids(services):
    """Return the sids whose service name is listed in SERVICE_BLACKLIST.

    ``services`` is a mapping of sid to service name.
    """
    blacklisted = []
    for sid, service_name in services.items():
        if service_name in SERVICE_BLACKLIST:
            blacklisted.append(sid)
    return blacklisted

@_profiled
def _select_uid_ranges(xstore_shard):
    """Return the uid ranges for ``xstore_shard``, filling the cache if empty.

    Cached ranges are used when present; otherwise ranges are obtained from
    ``_ranges_from_histogram_bounds``, stored in the cache and returned.
    """
    cached = _select_cached_ranges(xstore_shard)
    if len(cached) != 0:
        return cached
    tls.logger.info('range cache empty, updating from histogram bounds')
    fresh = _ranges_from_histogram_bounds(xstore_shard)
    _set_cached_ranges(xstore_shard, fresh)
    return fresh

def _select_cached_ranges(xstore_shard):
    """Return the rows of ``xiva.cached_counter_ranges`` ordered by ``first``.

    The query runs against the shard's read conninfo; the result is whatever
    ``execute`` returns for it.
    """
    read_conninfo = shards.conninfo_for_read(xstore_shard)
    query = '''
        SELECT first, last
        FROM xiva.cached_counter_ranges
        ORDER BY first
    '''
    return execute(read_conninfo, query)

def _set_cached_ranges(xstore_shard, ranges):
    """Replace the contents of ``xiva.cached_counter_ranges`` with ``ranges``.

    The table is truncated first and then refilled, both through the shard's
    write conninfo. The two statements are issued as separate calls.
    """
    write_conninfo = shards.conninfo_for_write(xstore_shard)
    refill_sql = '''
        INSERT INTO xiva.cached_counter_ranges(first, last)
        VALUES %s
    '''
    execute_no_result(write_conninfo, 'TRUNCATE xiva.cached_counter_ranges')
    execute_values(write_conninfo, refill_sql, ranges, page_size=10000)

def _ranges_from_histogram_bounds(xstore_shard):
    """Return ranges of the ``uid`` attribute of ``xiva.counters``.

    Delegates to ``select_attribute_ranges`` using the shard's read conninfo.
    """
    return select_attribute_ranges(
        shards.conninfo_for_read(xstore_shard), 'xiva', 'counters', 'uid')

def _apply_user_cleanup_limit(uid_ranges):
    if USERS_TO_CLEAN_PERCENT == 100:
        return uid_ranges
    ranges_to_clean = int(len(uid_ranges) * (USERS_TO_CLEAN_PERCENT / 100.0))
    return uid_ranges[:ranges_to_clean]

def _range_file_path(xstore_shard):
    """Return the path of the progress file for ``xstore_shard``.

    The file lives under ``/var/xivadba`` and is named after the database
    name taken from the shard's write conninfo.
    """
    write_conninfo = shards.conninfo_for_write(xstore_shard)
    return '/var/xivadba/%s' % (shards.dbname(write_conninfo),)

def _skip_processed_ranges(xstore_shard, uid_ranges):
    """Drop the ranges that the progress file reports as already processed.

    Returns:
        A ``(remaining_ranges, ranges_processed)`` tuple. When the progress
        file is missing, unreadable, not an integer or negative, a warning
        is logged and processing starts from the first range
        (``ranges_processed`` is 0).
    """
    try:
        with open(_range_file_path(xstore_shard)) as f:
            ranges_processed = int(f.read())
        if ranges_processed < 0:
            # A negative value would make the slice below keep the *last*
            # N ranges and hand a negative offset back to the caller.
            raise ValueError('negative processed range count: %d' % ranges_processed)
    except (OSError, ValueError) as e:
        ranges_processed = 0
        tls.logger.warning('failed to read number of processed ranges,'
            ' starting from scratch; error %s: %s', type(e).__name__, str(e))
    return uid_ranges[ranges_processed:], ranges_processed

@_profiled
def _select_counters(xstore_shard, uid_range, sid_blacklist):
    """Return ``(uid, sid)`` counter rows for one uid range.

    Selects from ``xiva.counters`` the rows whose uid lies between
    ``uid_range.first`` and ``uid_range.last`` (SQL BETWEEN) and whose sid is
    not in ``sid_blacklist``. The query uses the shard's read conninfo.
    """
    read_conninfo = shards.conninfo_for_read(xstore_shard)
    query = '''
        SELECT uid, sid
        FROM xiva.counters
        WHERE uid BETWEEN %s AND %s
            AND NOT sid = ANY(%s)
    '''
    query_args = (uid_range.first, uid_range.last, sid_blacklist)
    return execute(read_conninfo, query, *query_args)

@_profiled
def _distribute_by_xtable_conninfo(counters, services):
    """Group the users behind ``counters`` by the xtable conninfo to query.

    Each counter becomes a ``User(service_name, uid)``; the user is filed
    under the read conninfo of the xtable shard that ``shards.find`` returns
    for its uid.

    Returns:
        A ``defaultdict(list)`` mapping conninfo to a list of ``User``.
    """
    xtable_shards = shards.get('xtable')
    # shards.conninfo_for_read is expensive: every call determines the role
    # of each host in the conninfo by executing an equivalent of
    # 'SHOW transaction_read_only', i.e. 3 established connections plus
    # 3 requests. Resolve it once per shard here instead of once per counter.
    read_conninfo_by_shard_id = dict(
        (shard["id"], shards.conninfo_for_read(shard)) for shard in xtable_shards
    )
    grouped = defaultdict(list)
    for counter in counters:
        user = User(services[counter.sid], counter.uid)
        owner_shard = shards.find(user.uid, xtable_shards)
        grouped[read_conninfo_by_shard_id[owner_shard["id"]]].append(user)
    return grouped

@_profiled
def _select_acked_subscriptions(users_by_conninfo):
    """Return distinct ``(service, uid)`` rows that have an acked subscription.

    For every conninfo the associated users are joined against
    ``xiva.subscriptions`` and only rows with ``ack_local_id > 0`` are kept.
    All users of a conninfo are sent as a single page. Results from all
    conninfos are concatenated into one list.
    """
    select_sql = '''
        SELECT DISTINCT service, uid
        FROM xiva.subscriptions
            INNER JOIN (VALUES %s) as users(service, uid)
            USING (service, uid)
        WHERE ack_local_id > 0
    '''
    found = []
    for conninfo, users in users_by_conninfo.items():
        rows = execute_values(conninfo, select_sql, users, len(users), expect_result=True)
        found.extend(rows)
    return found

@_profiled
def _filter_used_counters(services, counters, acked_subscriptions):
    """Return the counters whose user has no acknowledged subscription.

    Args:
        services: mapping of sid to service name.
        counters: rows exposing ``sid`` and ``uid``.
        acked_subscriptions: rows exposing ``service`` and ``uid``.

    Returns:
        List of the counters for which ``(service name, uid)`` does not
        appear among ``acked_subscriptions``. Totals are logged at INFO.
    """
    # Tuple keys give set lookup without the pitfalls of a concatenated
    # 'service:uid' string: no ambiguity when a part contains ':' and no
    # TypeError when uid is not a string.
    acked_keys = {(s.service, s.uid) for s in acked_subscriptions}
    unused_counters = [c for c in counters if (services[c.sid], c.uid) not in acked_keys]
    tls.logger.info('counters: total=%s, used=%s, to_remove=%s',
        len(counters), len(counters) - len(unused_counters), len(unused_counters))
    return unused_counters

@_profiled
def _try_delete_counters(counters, xstore_shard, logger):
    """Delete ``counters`` in parallel and return how many rows were removed.

    Args:
        counters: sized sequence of ``(uid, sid)`` pairs; each item is
            unpacked into the arguments of ``_try_delete_counter``.
        xstore_shard: shard whose write conninfo the workers connect to.
        logger: logger installed into every worker thread.

    Returns:
        Sum of the per-counter results reported by the workers.
    """
    if len(counters) == 0:
        # Nothing to delete: skip resolving the write conninfo and
        # spinning up a thread pool.
        return 0
    master_conninfo = shards.conninfo_for_write(xstore_shard)
    # Create the pool before entering `try`: if construction fails there is
    # nothing to clean up, and `finally` must not touch an unbound name.
    worker_pool = ThreadPool(DELETE_POOL_SIZE, _init_worker, (master_conninfo, logger))
    try:
        # Default chunk size (1) increases CPU consumption,
        # putting all counters in 1 chunk per worker may consume too much memory.
        worker_chunk_size = max(1, len(counters) // (DELETE_POOL_SIZE * 100))
        result = worker_pool.starmap_async(_try_delete_counter, counters, chunksize=worker_chunk_size)
        return sum(result.get())
    finally:
        worker_pool.terminate()
        worker_pool.join()

def _init_worker(conninfo, logger):
    """Initialise thread-local state of a delete worker.

    Stores the conninfo and logger; the connection and cursor start as None
    and are created by ``_try_delete_counter`` on first use.
    """
    tls.conninfo, tls.logger = conninfo, logger
    tls.connection = tls.cursor = None

def _try_delete_counter(uid, sid):
    """Delete one counter using the worker's thread-local connection.

    The connection and cursor are created on first use and reused for later
    calls in the same thread.

    Returns:
        Number of rows deleted; 0 when the delete failed. Failures are
        logged and never raised. After a non-integrity error the
        thread-local cursor and connection are closed and reset, so the next
        call reconnects.
    """
    delete_sql = '''
        DELETE FROM xiva.counters
        WHERE uid=%s
            AND sid=%s
        RETURNING 1
    '''
    try:
        if tls.cursor is None:
            tls.logger.info('connecting to %s', tls.conninfo)
            tls.connection = _connect_autocommit(tls.conninfo)
            tls.cursor = tls.connection.cursor()
        tls.cursor.execute(delete_sql, (uid, sid))
        return len(tls.cursor.fetchall())
    except psycopg2.IntegrityError:
        tls.logger.error('counter uid=%s sid=%s has notifications', uid, sid)
        return 0
    except Exception as e:
        tls.logger.error('failed to delete %s, %s, %s: %s', uid, sid, type(e).__name__, str(e))
        # Either object may still be None when connecting or creating the
        # cursor was what failed; dereferencing it here would raise out of
        # the worker and abort the whole batch.
        if tls.cursor is not None and not tls.cursor.closed:
            tls.cursor.close()
        tls.cursor = None
        if tls.connection is not None and not tls.connection.closed:
            tls.connection.close()
        tls.connection = None
        return 0

def _update_ranges_processed(xstore_shard, range_index):
    """Overwrite the shard's progress file with ``range_index`` as text."""
    progress_path = _range_file_path(xstore_shard)
    with open(progress_path, 'w') as progress_file:
        progress_file.write(str(range_index))

def _connect_autocommit(conninfo):
    """Open a psycopg2 connection to ``conninfo`` with autocommit enabled."""
    conn = psycopg2.connect(conninfo)
    conn.autocommit = True
    return conn
