import datetime
import functools
import os
import queue
import re
import time
from collections import Counter, namedtuple, defaultdict
from multiprocessing.pool import ThreadPool

import psycopg2
import requests

import shards
import db_requests
from db_requests import *

# Target length of one cleanup pass including its trailing sleep, in seconds
# (consumed by compute_sleep_time).
CLEANUP_INTERVAL_SEC = 30.0
# Expired subscriptions whose ttl (in hours, cf. make_interval(hours=>ttl) in
# remove_old_subscriptions) is at most this value are deleted directly in the
# DB; longer-lived ones are unsubscribed through the hub instead.
SHORT_LIVED_SUBSCRIPTION_TTL = 24
# Passed to execute_values by delete_all; presumably the page size of one
# DELETE statement — confirm against db_requests.execute_values.
MAX_SUBS_PER_DELETE = 50000

# Slice of each shard's gid range scanned by remove_bad_fcm_subscriptions,
# expressed as percentages of the range width (see scale_gid_range).
REMOVE_BAD_FCM_SUBSCRIPTIONS_GID_OFFSET_PERCENT = 0
REMOVE_BAD_FCM_SUBSCRIPTIONS_GID_SPAN_PERCENT = 1
# Services whose FCM/GCM subscriptions are checked, in processing order.
REMOVE_BAD_FCM_SUBSCRIPTIONS_SERVICES = [
    'taxi',
    'taximeter',
    'yamb',
    'autoru',
    'maps-push-proxy',
    'appmetrica_25378',
    'mail',
]
# A 404 "No information found" check result only marks a subscription as
# invalid if it was created longer ago than this (about a year and a half).
REMOVE_BAD_FCM_SUBSCRIPTIONS_404_MIN_AGE = datetime.timedelta(days=1.5 * 365)

# Services excluded from sync_acks_with_counters.
SYNC_ACK_SERVICE_BLACKLIST = ('apns_queue', 'passport-stream', 'webpushapi', 'bass', 'disk-json', 'fake')

def remove_bad_fcm_subscriptions(shard, logger):
    """Unsubscribe FCM/GCM subscriptions that xivamob reports as invalid.

    For every service in REMOVE_BAD_FCM_SUBSCRIPTIONS_SERVICES and every gid in
    the slice of the shard's gid range selected by
    REMOVE_BAD_FCM_SUBSCRIPTIONS_GID_OFFSET_PERCENT / _SPAN_PERCENT (both ends
    inclusive), subscriptions are read from the shard replica and checked on
    four FCMSanitizer worker threads. Each gid is fully drained before the next
    one is read.

    Never returns: after the single pass it sleeps forever ("Do not start one
    more time") — presumably so that the scheduling loop invoking it does not
    run the pass again.

    Args:
        shard: shard descriptor understood by the ``shards`` module.
        logger: logger used for progress and summary messages.
    """
    select_fcm_subs_sql = '''
        SELECT gid, service, id, callback, uid, init_time
        FROM xiva.subscriptions
        WHERE (platform = 'gcm' OR platform = 'fcm')
            AND gid = %s
            AND service = %s;
    '''

    logger.info('start processing')
    start_tm = datetime.datetime.now()

    conninfo = shards.conninfo_for_read(shard)
    shard_gids = shards.gids_range(shard)
    gids_range = scale_gid_range(shard_gids,
                                 REMOVE_BAD_FCM_SUBSCRIPTIONS_GID_OFFSET_PERCENT,
                                 REMOVE_BAD_FCM_SUBSCRIPTIONS_GID_SPAN_PERCENT)
    services = REMOVE_BAD_FCM_SUBSCRIPTIONS_SERVICES

    logger.info('shard gids [%s:%s], offset %s, span %s, gids [%s:%s]',
                shard_gids[0], shard_gids[1],
                REMOVE_BAD_FCM_SUBSCRIPTIONS_GID_OFFSET_PERCENT,
                REMOVE_BAD_FCM_SUBSCRIPTIONS_GID_SPAN_PERCENT,
                gids_range[0], gids_range[1])
    logger.info('preprocess taken %.2fs with %s services and %s gids',
                (datetime.datetime.now() - start_tm).total_seconds(),
                len(services),
                gids_range[1] - gids_range[0] + 1)

    sanitizer = FCMSanitizer(logger, concurrency=4)

    # clean each subset gid in each service
    for service in services:
        start_service_tm = datetime.datetime.now()
        for gid in range(gids_range[0], gids_range[1] + 1):
            subs = execute(conninfo, select_fcm_subs_sql, gid, service)
            for sub in subs:
                sanitizer.async_process_sub(sub)
            # Block until every subscription of this gid has been handled.
            sanitizer.wait_finish()
            logger.info('processed service %s, gid %s', service, gid)
        logger.info('finish processing service %s in %.2fs, %s unsubscribed / %s processed',
                    service,
                    (datetime.datetime.now() - start_service_tm).total_seconds(),
                    sanitizer.counters[service]['unsubscribed'],
                    sanitizer.counters[service]['processed'])

    total_counters = sanitizer.total_counters()
    logger.info('finish processing shard in %.2fs with %s services, total subscriptions: %s unsubscribed / %s processed',
                (datetime.datetime.now() - start_tm).total_seconds(),
                len(services),
                total_counters['unsubscribed'],
                total_counters['processed'])

    # Do not start one more time
    while True:
        time.sleep(3600)

def scale_gid_range(range, offset_percent, span_percent):
    """Select a sub-range of a gid range, sized in percent of its width.

    Args:
        range: indexable ``[first, last]`` pair. (The name shadows the builtin;
            it is kept so keyword callers keep working.)
        offset_percent: start of the sub-range, as a percentage of
            ``range[1] - range[0]``, measured from ``range[0]``.
        span_percent: width of the sub-range, as a percentage of the same width.

    Returns:
        ``[beg, end]``, each clamped into ``[range[0], range[1]]``. Both values
        are truncated to ints, so a span that rounds down to 0 yields
        ``beg == end`` (remove_bad_fcm_subscriptions iterates it inclusively).

    Raises:
        ValueError: if ``range[1] <= range[0]``. This was an ``assert`` before,
            which is silently skipped under ``python -O``.
    """
    full_span = range[1] - range[0]
    if full_span <= 0:
        raise ValueError('empty or inverted gid range [%s:%s]' % (range[0], range[1]))
    span = int(full_span * span_percent / 100)
    offset = int(full_span * offset_percent / 100)
    beg = bound(range[0], range[0] + offset, range[1])
    end = bound(range[0], range[0] + offset + span, range[1])
    return [beg, end]

def bound(low, value, high):
    """Clamp ``value`` into ``[low, high]`` (assumes ``low <= high``)."""
    return max(low, min(high, value))

class FCMSanitizer:
    """Fan FCM subscriptions out to a fixed set of worker threads.

    The ThreadPool is used only to own ``concurrency`` threads. Each thread runs
    ``_worker`` as its initializer, and ``_worker`` loops forever pulling tasks
    from ``task_queue``. The pool's own task API is not used in this file.

    ``counters`` maps service -> Counter with 'processed' / 'unsubscribed'
    entries, filled in by the workers.

    NOTE(review): workers increment ``counters`` from several threads without a
    lock, and ``+=`` on a dict item is not atomic. Totals may therefore
    undercount slightly; confirm whether that matters for these stats.
    """

    def __init__(self, logger, concurrency=1):
        self.logger = logger
        self.counters = defaultdict(Counter)
        self.task_queue = queue.Queue()
        worker_args = [self.task_queue, self.counters, self.logger]
        self.pool = ThreadPool(processes=concurrency,
                               initializer=_worker,
                               initargs=worker_args)

    def async_process_sub(self, sub):
        """Enqueue one subscription; workers unpack the list as ``*task``."""
        task = [sub]
        self.task_queue.put(task)

    def wait_finish(self):
        """Block until every enqueued subscription has been marked done."""
        self.task_queue.join()

    def total_counters(self):
        """Sum the 'unsubscribed' and 'processed' counts over all services."""
        totals = Counter()
        for per_service in self.counters.values():
            totals['unsubscribed'] += per_service['unsubscribed']
            totals['processed'] += per_service['processed']
        return totals

def _worker(task_queue, counters, logger):
    """Worker-thread loop, installed as the ThreadPool initializer by FCMSanitizer.

    Never returns. It takes ``[sub]`` tasks from ``task_queue`` forever and runs
    ``unsubscribe_if_bad`` on each one, using a per-thread requests session and
    a TVM ticket that is refreshed when it gets old.

    ``task_done()`` is always called, in ``finally``. Without that, an exception
    such as a failed TVM ticket refresh would skip it and leave
    ``task_queue.join()`` (FCMSanitizer.wait_finish) blocked forever. Such
    errors are logged and the thread keeps serving the queue.

    Args:
        task_queue: queue.Queue of single-element task lists.
        counters: service -> Counter mapping shared with FCMSanitizer.
        logger: logger for errors.
    """
    session = requests.Session()
    tvm_ticket_hub = TVMTicket(get_tvm_ticket, logger)
    while True:
        task = task_queue.get()
        try:
            tvm_ticket_hub.update()
            unsubscribe_if_bad(session, tvm_ticket_hub.ticket, counters, logger, *task)
        except Exception as e:
            logger.error('FCM worker task failed, %s: %s', type(e).__name__, str(e))
        finally:
            task_queue.task_done()

def unsubscribe_if_bad(session, tvm_ticket_hub, counters, logger, sub):
    """Check one FCM subscription and unsubscribe it if xivamob reports it invalid.

    Never raises for a per-subscription failure: errors are logged and the
    subscription is still counted as processed.

    Args:
        session: requests session used for the check and the unsubscribe call.
        tvm_ticket_hub: despite its name, the TVM service ticket value
            (``_worker`` passes ``tvm_ticket_hub.ticket``). It is forwarded as
            ``unsubscribe``'s ``tvm_ticket``.
        counters: service -> Counter mapping. 'processed' is incremented for
            every call. 'unsubscribed' is incremented once ``unsubscribe()``
            returns; that function logs and swallows its own HTTP failures, so
            this effectively counts unsubscribe attempts.
        logger: logger for errors.
        sub: row with ``uid``, ``service``, ``id``, ``callback`` and
            ``init_time`` attributes.
    """
    # %-formatting rather than '+' so a non-str uid/id cannot raise TypeError
    # here, outside the try block, and kill the worker thread. The output is
    # identical for str values.
    sub_tskv = 'uid=%s service=%s id=%s' % (sub.uid, sub.service, sub.id)
    try:
        app, token = app_and_token_from_callback(sub.callback)
        is_invalid, reason = sub_is_invalid(session, sub.service, sub.init_time, app, token)
        if is_invalid:
            unsubscribe(sub.uid, sub.service, sub.id, session, 'janitor_bad_fcm_token', reason,
                        tvm_ticket_hub, logger)
            counters[sub.service]['unsubscribed'] += 1
    except Exception as e:
        logger.error('%s FCM subscription processing failed, %s: %s',
                     sub_tskv, type(e).__name__, str(e))
    counters[sub.service]['processed'] += 1

def sub_is_invalid(session, service, init_time, app, token):
    """Ask xivamob whether an FCM token should be dropped.

    Returns:
        ``(True, reason)`` in these cases:
        - the check answers code '400' with error 'InvalidToken' or
          'InvalidTokenVersion' (the error is the reason);
        - the check answers code '404' "No information found about this
          instance id." and the subscription is older than
          REMOVE_BAD_FCM_SUBSCRIPTIONS_404_MIN_AGE (reason 'NoInformation').
        Otherwise ``(False, 'sub is valid')``.

    Raises:
        requests.HTTPError on an error HTTP status, and Exception when the
        response is empty, echoes a different token, or lacks code/error.
    """
    response = check_fcm_request(session, service, app, token)
    response.raise_for_status()
    result = response.json()

    # Validate the response shape before trusting it.
    if not result:
        raise Exception('check res is None or empty')
    if 'token' not in result or result['token'] != token:
        raise Exception('token mismatch')
    if 'code' not in result or 'error' not in result:
        raise Exception('missed code or error field')

    code = result['code']
    error = result['error']
    if code == '400' and error in ('InvalidToken', 'InvalidTokenVersion'):
        return True, error

    # A 404 alone is not conclusive; only act on it for old subscriptions.
    # init_time is compared with an aware UTC datetime, so it is presumably
    # timezone-aware (timestamptz) — confirm.
    is_unknown_instance = (code == '404'
                           and error == 'No information found about this instance id.')
    if is_unknown_instance and \
            init_time < datetime.datetime.now(datetime.timezone.utc) - REMOVE_BAD_FCM_SUBSCRIPTIONS_404_MIN_AGE:
        return True, 'NoInformation'
    return False, 'sub is valid'

def retry_check_request(request, retry_statuses=(429, 500, 502, 504), retry_delay_sec=0.1,
                        max_attempts=None):
    """Decorate ``request`` so it is re-issued while it returns a retryable status.

    Works as a bare decorator (``@retry_check_request``) with the defaults.
    The wrapped callable is invoked again with the same arguments while the
    returned object's ``status_code`` is in ``retry_statuses``, sleeping
    ``retry_delay_sec`` between attempts. Exceptions raised by ``request`` are
    not retried; they propagate.

    Args:
        request: callable returning an object with a ``status_code`` attribute.
        retry_statuses: status codes that trigger another attempt
            (default 429, 500, 502, 504).
        retry_delay_sec: pause between attempts, in seconds (default 0.1).
        max_attempts: total number of calls before giving up and returning the
            last response. ``None`` (the default) means retry forever, which
            blocks the caller while the endpoint keeps failing.

    Returns:
        The wrapper, which carries ``request``'s name and docstring (``functools.wraps``).
    """
    retryable = frozenset(retry_statuses)

    @functools.wraps(request)
    def wrapper(*args, **kwargs):
        attempts = 0
        while True:
            response = request(*args, **kwargs)
            attempts += 1
            if response.status_code not in retryable:
                return response
            if max_attempts is not None and attempts >= max_attempts:
                return response
            time.sleep(retry_delay_sec)
    return wrapper

@retry_check_request
def check_fcm_request(session, service, app, token):
    """POST ``token`` to xivamob's ``/check/fcm`` for the given service and app.

    Retried by ``retry_check_request`` on retryable HTTP statuses; returns the
    final ``requests`` response without checking its status.
    """
    endpoint = xivamob_url_from_env() + '/check/fcm'
    return session.post(
        endpoint,
        params={'service': service, 'app': app},
        headers={'user-agent': 'janitor_bad_fcm_token'},
        data={'token': token},
    )

def app_and_token_from_callback(callback):
    """Split an ``xivamob:<app>/<token>`` callback into ``(app, token)``.

    The first ':' and the first '/' after it are the separators, so the token
    itself may contain '/' or ':'.

    Raises:
        ValueError: if the callback has no ':' at all, or no '/' after the
            prefix. The message names the missing separator (previously an
            opaque tuple-unpacking ValueError).
        Exception: 'wrong subscription type' for a non-'xivamob' prefix,
            'empty app' / 'empty token' for empty parts.
    """
    sub_type, sep, app_and_token = callback.partition(':')
    if not sep:
        raise ValueError('malformed callback: missing ":"')
    if sub_type != 'xivamob':
        raise Exception('wrong subscription type')
    app, sep, token = app_and_token.partition('/')
    if not sep:
        raise ValueError('malformed callback: missing "/"')
    if not app:
        raise Exception('empty app')
    if not token:
        raise Exception('empty token')
    return app, token

def xivamob_url_from_env():
    """Return the xivamob base URL for the current Qloud environment.

    Outside Qloud the local development URL is returned. Raises Exception for
    a QLOUD_ENVIRONMENT value other than production, corp or sandbox.
    """
    if not running_in_qloud():
        # Assume local environment.
        return 'http://localhost:12080'
    base_urls = {
        'production': 'http://xivamob.mail.yandex.net',
        'corp': 'http://xivamobcorp.mail.yandex.net',
        'sandbox': 'http://xivamob-sandbox.mail.yandex.net',
    }
    envtype = os.environ['QLOUD_ENVIRONMENT']
    if envtype in base_urls:
        return base_urls[envtype]
    raise Exception('unknown QLOUD_ENVIRONMENT: ' + envtype)

def running_in_qloud():
    """Return True when QLOUD_ENVIRONMENT is set (to any value, even empty)."""
    return os.environ.get('QLOUD_ENVIRONMENT') is not None

def compute_sleep_time(start_time):
    """Return the seconds left until CLEANUP_INTERVAL_SEC has passed since ``start_time``.

    ``start_time`` is a ``time.time()`` value. The result is never negative:
    0 is returned once the interval has already elapsed.
    """
    elapsed = time.time() - start_time
    return max(0, CLEANUP_INTERVAL_SEC - elapsed)

class TVMTicket:
    """Cache a TVM ticket and re-fetch it once it is more than an hour old.

    ``getter(*args)`` is called immediately on construction, and again from
    ``update()`` whenever the cached value is older than one hour. Age is
    measured with local wall-clock ``datetime.now()``. Callers read
    ``self.ticket``.
    """

    _MAX_AGE = datetime.timedelta(hours=1)

    def __init__(self, getter, *args):
        self.getter = getter
        self.args = args
        self.__refresh()

    def __refresh(self):
        # If getter raises, neither ticket nor timestamp is changed.
        self.ticket = self.getter(*self.args)
        self.updated_at = datetime.datetime.now()

    def update(self):
        """Re-fetch the ticket if the cached one is older than an hour."""
        age = datetime.datetime.now() - self.updated_at
        if age > self._MAX_AGE:
            self.__refresh()

def get_tvm_ticket(logger):
    """Fetch a TVM service ticket for ``xivadba-<env>`` -> ``xivahub-<env>``.

    Outside Qloud a fixed placeholder string is returned. Inside Qloud the
    ticket is requested from ``http://localhost:1/tvm/tickets`` (presumably the
    local TVM daemon — confirm), authorized with QLOUD_TVM_TOKEN. Any failure
    is logged and re-raised.
    """
    if not running_in_qloud():
        return 'no_tvm_in_local_runs'
    try:
        environment = os.environ['QLOUD_ENVIRONMENT']
        headers = {'authorization': os.environ['QLOUD_TVM_TOKEN']}
        source, destination = 'xivadba-' + environment, 'xivahub-' + environment
        url = 'http://localhost:1/tvm/tickets?src=%s&dsts=%s' % (source, destination)
        response = requests.get(url, headers=headers)
        response.raise_for_status()
        return response.json()[destination]['ticket']
    except Exception as e:
        logger.error('failed to get tvm ticket, %s: %s', type(e).__name__, str(e))
        raise

def unsubscribe(uid, service, id, session, user_agent, reason, tvm_ticket, logger):
    """Ask the hub to remove one subscription and log the outcome.

    POSTs to ``<hub>/unsubscribe`` with uid/service/subscription-id query
    parameters and the TVM ticket in ``x-ya-service-ticket``. Errors from the
    request (including a non-2xx status) are logged and not raised.

    Args:
        uid, service, id: subscription key. ``id`` shadows the builtin; the
            name is kept for keyword callers.
        session: requests session to post with.
        user_agent: value for the ``user-agent`` header.
        reason: free-form reason, only written to the log.
        tvm_ticket: TVM service ticket string.
        logger: logger for the result line.
    """
    # %-formatting rather than '+' so a non-str uid/id cannot raise TypeError
    # here, outside the try block. The output is identical for str values.
    sub_tskv = 'uid=%s service=%s id=%s' % (uid, service, id)
    try:
        hub_url = shards.hub_from_env()
        unsub_url = hub_url + '/unsubscribe'
        params = {
            'uid': uid,
            'service': service,
            'subscription-id': id
        }
        headers = {
            'user-agent': user_agent,
            'x-ya-service-ticket': tvm_ticket
        }
        response = session.post(unsub_url, params=params, headers=headers)
        response.raise_for_status()
        logger.info('%s unsubscribe OK, reason %s', sub_tskv, reason)
    except Exception as e:
        logger.error('%s unsubscribe failed, %s: %s', sub_tskv, type(e).__name__, str(e))

def unsubscribe_all(subs, session, user_agent, reason, tvm_ticket, logger):
    """Call unsubscribe() for each row in ``subs``, in order.

    Rows need ``uid``, ``service`` and ``id`` attributes. unsubscribe() logs
    and swallows failures of its HTTP request, so one such failure does not
    stop the loop.
    """
    shared_args = (session, user_agent, reason, tvm_ticket, logger)
    for row in subs:
        unsubscribe(row.uid, row.service, row.id, *shared_args)

def delete_all(subs, conninfo, logger):
    """Delete the given ``(uid, service, id, ttl)`` rows from xiva.subscriptions.

    Does nothing for an empty ``subs``. Rows are sent through execute_values
    with MAX_SUBS_PER_DELETE. Any error, including a non-sized ``subs``, is
    logged and not raised.
    """
    delete_sql = '''
        WITH keys_to_delete AS (
            SELECT * FROM (VALUES %s) AS x(i_uid, i_service, i_id, i_ttl))
        DELETE FROM xiva.subscriptions
        USING keys_to_delete
        WHERE uid=i_uid and service=i_service and id=i_id;
    '''
    try:
        # len() stays inside the try so bad input is logged, not raised.
        if len(subs) == 0:
            return
        execute_values(conninfo, delete_sql, subs, MAX_SUBS_PER_DELETE)
    except Exception as e:
        logger.error('delete failed, %s: %s', type(e).__name__, str(e))

def remove_old_subscriptions(shard, logger):
    """Run one pass of expired-subscription cleanup for ``shard``, then sleep.

    ``code.return_uid_ranges()`` on the replica returns a list of uid
    boundaries; each pair of consecutive boundaries is treated as one range.
    For each range, up to 300 expired subscriptions without a platform are
    selected (expired means ``now() - init_time`` exceeds ``ttl`` hours).
    - Rows with ``ttl <= SHORT_LIVED_SUBSCRIPTION_TTL`` are deleted directly on
      the shard master (delete_all).
    - Longer-lived rows are unsubscribed via the hub (unsubscribe_all, reason 'old').
    One TVM ticket and one requests session are used for the whole pass.
    Afterwards the function sleeps for the rest of CLEANUP_INTERVAL_SEC.
    """
    select_expired_subs_sql = '''
        SELECT uid, service, id, ttl
        FROM xiva.subscriptions
        WHERE uid BETWEEN %s AND %s
            AND now() - init_time > make_interval(hours=>ttl)
            AND (platform IS NULL OR platform = '')
        LIMIT 300;
    '''
    start_time = time.time()
    replica_conn_string = shards.conninfo_for_read(shard)

    logger.info('remove_old_subscriptions')

    uids = execute(replica_conn_string, 'SELECT * FROM code.return_uid_ranges();')[0][0]
    requests_session = requests.Session()
    tvm_ticket = get_tvm_ticket(logger)
    # BETWEEN is inclusive, so a boundary uid belongs to two adjacent ranges.
    for range_start, range_end in zip(uids, uids[1:]):
        expired_subs = execute(replica_conn_string, select_expired_subs_sql, range_start, range_end)
        short_lived_subs = [s for s in expired_subs if s.ttl <= SHORT_LIVED_SUBSCRIPTION_TTL]
        long_lived_subs = [s for s in expired_subs if s.ttl > SHORT_LIVED_SUBSCRIPTION_TTL]
        logger.info('selected %s expired subs %s short-lived %s long-lived',
                    len(expired_subs), len(short_lived_subs), len(long_lived_subs))
        delete_all(short_lived_subs, shards.conninfo_for_write(shard), logger)
        unsubscribe_all(long_lived_subs, requests_session, 'janitor', 'old', tvm_ticket, logger)

    sleep_time = compute_sleep_time(start_time)
    logger.info('sleeping for %s before next remove_old_subscriptions', sleep_time)
    time.sleep(sleep_time)

def remove_broken_subscriptions(shard, logger):
    """Run ``code.remove_broken_subscriptions('off')`` on the shard master, then sleep.

    Logs the value the stored function returns, as the removed count. The
    meaning of the 'off' argument is defined by that SQL function (not visible
    here). Afterwards sleeps for the rest of CLEANUP_INTERVAL_SEC.
    """
    started_at = time.time()
    master_conninfo = shards.conninfo_for_write(shard)

    logger.info('remove_broken_subscriptions')
    rows = execute(master_conninfo, 'SELECT code.remove_broken_subscriptions(%s);', 'off')
    logger.info('removed %s broken subscriptions', rows[0][0])

    pause = compute_sleep_time(started_at)
    logger.info('sleeping for %s before next remove_broken_subscriptions', pause)
    time.sleep(pause)

# Unit of ack/counter sync work for one xstore shard: the shard's gid bounds
# (inclusive), its replica/master conninfo strings, and the acked
# subscriptions whose gid falls inside those bounds (see _make_xstore_task).
XstoreTask = namedtuple('XstoreTask', 'start_gid end_gid replica_conninfo master_conninfo subs')

def sync_acks_with_counters(shard, logger):
    """Sync xstore counters with subscription acks over all uid ranges of ``shard``, then sleep.

    The uid ranges of xiva.subscriptions come from select_attribute_ranges on
    the replica. Each range is handled by _sync_acks_between. Afterwards the
    function sleeps for the rest of CLEANUP_INTERVAL_SEC.
    """
    started_at = time.time()
    replica_conninfo = shards.conninfo_for_read(shard)
    ranges = select_attribute_ranges(replica_conninfo, 'xiva', 'subscriptions', 'uid')
    for first_uid, last_uid in ranges:
        logger.info('syncing acks between %s and %s', first_uid, last_uid)
        _sync_acks_between(first_uid, last_uid, shard, logger)
    pause = compute_sleep_time(started_at)
    logger.info('sleeping for %s before next sync_acks_with_counters', pause)
    time.sleep(pause)

def _sync_acks_between(start_uid, end_uid, shard, logger):
    """Repair xstore counters for acked subscriptions with uid in [start_uid, end_uid].

    Acked subscriptions (excluding SYNC_ACK_SERVICE_BLACKLIST) are grouped by
    xstore shard. For every shard that received at least one of them, missing
    counters are inserted and lagging counters are advanced on that shard's
    master.
    """
    acked_subs = _select_acked_subscriptions(shard, SYNC_ACK_SERVICE_BLACKLIST, start_uid, end_uid)
    for task in _distribute_between_xstore_shards(acked_subs):
        if len(task.subs) == 0:
            continue
        shard_name = shards.dbname(task.master_conninfo)
        logger.info('%s subscriptions to check in shard %s', len(task.subs), shard_name)
        missing, lagging = _select_unsynced_counters(task.replica_conninfo, task.subs)
        logger.info('counters broken in shard %s: missing=%s lagging=%s',
                    shard_name, len(missing), len(lagging))
        _restore_counters(task.master_conninfo, missing)
        _advance_counters(task.master_conninfo, lagging)

def _select_acked_subscriptions(shard, service_blacklist, uid_start, uid_end):
    """Return ``(service, uid, gid, ack_local_id)`` rows with the max positive ack per group.

    Reads from the shard replica and considers uids in [uid_start, uid_end]
    only. Services in ``service_blacklist`` are skipped; it is passed as one
    ``NOT IN %s`` parameter, so it must be a non-empty tuple.
    """
    query = '''
        SELECT service, uid, gid, max(ack_local_id) as ack_local_id
        FROM xiva.subscriptions
        WHERE service NOT IN %s
            AND uid BETWEEN %s AND %s
            AND ack_local_id > 0
        GROUP BY (service, uid, gid)
    '''
    replica_conninfo = shards.conninfo_for_read(shard)
    return execute(replica_conninfo, query, service_blacklist, uid_start, uid_end)

def _distribute_between_xstore_shards(subs):
    """Return one XstoreTask per xstore shard; a task's ``subs`` may be empty."""
    tasks = []
    for xstore_shard in shards.get('xstore'):
        tasks.append(_make_xstore_task(subs, xstore_shard))
    return tasks

def _make_xstore_task(subs, shard):
    """Build the XstoreTask for one xstore shard.

    Only subscriptions with ``start_gid <= gid <= end_gid`` (inclusive
    bounds taken from the shard dict) are kept.
    """
    low = shard['start_gid']
    high = shard['end_gid']
    replica_conninfo = shards.conninfo_for_read(shard)
    master_conninfo = shards.conninfo_for_write(shard)
    return XstoreTask(
        start_gid=low,
        end_gid=high,
        replica_conninfo=replica_conninfo,
        master_conninfo=master_conninfo,
        subs=[s for s in subs if low <= s.gid <= high],
    )

def _select_unsynced_counters(xstore_replica_conninfo, subs):
    select_sql = '''
        WITH sid_subs AS (
            SELECT sid, service_name, uid, ack_local_id
            FROM xiva.services
                INNER JOIN (VALUES %s) as users(service_name, uid, gid, ack_local_id)
                USING (service_name))
        SELECT service_name as service, uid, sid, total_count, ack_local_id
            FROM sid_subs
                LEFT JOIN xiva.counters
                USING (sid, uid)
        WHERE total_count IS NULL
            OR total_count < ack_local_id
    '''
    # Force all results to fit on one page.
    unsynced_counters = execute_values(xstore_replica_conninfo, select_sql,
        subs, page_size=len(subs), expect_result=True)
    missing_counters = [(c.sid, c.uid, c.ack_local_id) for c in unsynced_counters if c.total_count is None]
    lagging_counters = [(c.sid, c.uid, c.ack_local_id) for c in unsynced_counters if c.total_count is not None]
    return (missing_counters, lagging_counters)

def _restore_counters(xstore_conninfo, counters):
    """Insert a fresh xiva.counters row for each ``(sid, uid, ack_local_id)``.

    Existing rows are left untouched (ON CONFLICT DO NOTHING).
    """
    sql = '''
        INSERT INTO xiva.counters(total_count, unseen_count, next_local_id, last_seen_id, sid, uid)
        VALUES (%(total)s, %(unseen)s, %(next)s, %(last_seen)s, %(sid)s, %(uid)s)
        ON CONFLICT DO NOTHING
    '''
    _execute_for_each_counter(sql=sql, counters=counters, xstore_conninfo=xstore_conninfo)

def _advance_counters(xstore_conninfo, counters):
    """Move existing xiva.counters rows up to the acked position.

    For each ``(sid, uid, ack_local_id)``, total/unseen are set to
    ack_local_id and next_local_id to ack_local_id + 1.
    """
    sql = '''
        UPDATE xiva.counters
        SET total_count = %(total)s,
            unseen_count = %(unseen)s,
            next_local_id = %(next)s
        WHERE sid = %(sid)s AND uid = %(uid)s
    '''
    _execute_for_each_counter(sql=sql, counters=counters, xstore_conninfo=xstore_conninfo)

def _execute_for_each_counter(sql, counters, xstore_conninfo):
    with psycopg2.connect(xstore_conninfo) as connection:
        connection.autocommit = True
        with connection.cursor() as cursor:
            for sid, uid, ack_local_id in counters:
                cursor.execute(sql, {
                    "total": ack_local_id,
                    "unseen": ack_local_id,
                    "next": ack_local_id + 1,
                    "last_seen": 0,
                    "sid": sid,
                    "uid": uid
                })
