#!/usr/bin/python
# -*- coding: utf-8 -*-

import sys

sys.path.append("/usr/share/xeno/")
from xeno_monitor import *
from time import sleep
from threading import Thread, Lock
import logging
import json

# URL template requested by purge_user() for every uid that has to be purged.
PURGE_USER_URL = "http://husky-api.mail.yandex.net/delete_user/{uid}/right_now"
# Picks, out of the given uid list, the mailish.accounts rows that have no
# matching row in mailish.auth_data. The "{}" placeholder receives a
# comma-separated uid list via str.format().
SELECT_MAILISH_WITHOUT_AUTH_BY_UIDS_QUERY = """
    SELECT
        uid
    FROM
        mailish.accounts
    WHERE
        uid IN ({}) AND
        uid NOT IN (SELECT uid FROM mailish.auth_data)"""
# Picks, out of the given uid list, the uids that still have rows in
# mailish.folders.
SELECT_MAILISH_WITH_FOLDERS_BY_UIDS_QUERY = """
    SELECT
        DISTINCT uid
    FROM
        mailish.folders
    WHERE
        uid in ({})
"""
# Counts how many of the given uids still have rows in mailish.folders;
# check_purged() compares the result with WAIT_PURGE_USERS_COUNT.
CHECK_PURGED_QUERY = """
    SELECT
        count(DISTINCT uid)
    FROM
        mailish.folders
    WHERE
        uid in ({})"""
# File that setup_log() appends log records to.
LOG_FILE = "/ephemeral/logs/purge_users.log"
# Backing file of the per-uid status storage created below.
STATE_FILE = "/ephemeral/xeno/purge_users_state"
# File read by load_tvm_ticket(); its stripped content is sent as the
# X-Ya-Service-Ticket header.
TVM_TICKET_FILE = "/root/.husky_tvm_ticket"
# Maximum number of uids handled in one pass over the shards (see split_to_chunks).
CHUNK_SIZE = 10000
# check_purged() reports a shard as done once at most this many of the
# requested uids are still present in mailish.folders.
WAIT_PURGE_USERS_COUNT = 5
# Run-wide counters; modified only through update_stats_on_purge_*() under `lock`.
purged_users_count = 0
purge_failed_count = 0
# Per-uid state storage. LocalStorage comes from the xeno_monitor star import;
# this file only uses its get() and set() methods.
storage = LocalStorage(STATE_FILE)
# Guards `storage` and the two counters above, which worker threads share.
lock = Lock()

# Status values stored under the "status" key of a uid's state record.
# NOTE(review): nothing in this file writes STATUS_INITIAL; it is only compared
# against in is_already_purged().
STATUS_INITIAL = "initial"
STATUS_STARTED = "started"
STATUS_ERROR = "error"
STATUS_PURGED = "purged"


def main():
    """Script entry point: purge the users listed in a file.

    Expects exactly two command-line arguments: the path of a file with one
    uid per line and a concurrency hint (an integer). Prints a usage line and
    exits with status 1 when the argument count is wrong. Otherwise it drops
    the uids already recorded in the state storage, splits the rest into
    chunks and purges the chunks one after another.
    """
    args = sys.argv
    if len(args) != 3:
        print("usage {} <uids_path> <concurrency_hint>".format(args[0]))
        sys.exit(1)
    uids_path, concurrency_hint = args[1], int(args[2])

    setup_log()
    load_tvm_ticket()

    pending = get_not_already_purged_users(read_uids(uids_path))
    # Split to chunks to minimize time spent in husky queue.
    batches = split_to_chunks(pending)
    logging.info("start purge, users={} chunks={}".format(len(pending), len(batches)))
    for batch in batches:
        purge_users_with_concurrency(batch, concurrency_hint)
    logging.info("purged_users={} failed_users={}".format(purged_users_count, purge_failed_count))


def setup_log():
    """Configure the root logger to append INFO-and-above records to LOG_FILE.

    Each record carries a timestamp, the emitting thread's name and the level.
    """
    options = {
        "filename": LOG_FILE,
        "filemode": "a",
        "format": "[%(asctime)s] %(threadName)s %(levelname)s: %(message)s",
        "level": logging.INFO,
        "datefmt": "%Y-%b-%d %H:%M:%S",
    }
    logging.basicConfig(**options)


def load_tvm_ticket():
    """Read the TVM ticket from TVM_TICKET_FILE into the global TVM_TICKET.

    Surrounding whitespace is stripped. purge_user() reads the global when it
    builds its request headers, so this must run before any purge request.
    """
    global TVM_TICKET
    with open(TVM_TICKET_FILE, "r") as ticket_file:
        TVM_TICKET = ticket_file.read().strip()


def read_uids(filename):
    """Read uids from a text file, one per line.

    Args:
        filename: path of the file to read.

    Returns:
        A list of uid strings with surrounding whitespace removed, in file
        order. Blank lines are skipped: an empty uid would otherwise end up
        in the comma-joined list substituted into the SQL queries and make
        them syntactically invalid (e.g. ``IN (1,,2)``).
    """
    with open(filename, "r") as file:
        # Iterate the file object directly instead of materializing every
        # line with readlines() first.
        stripped = (line.strip() for line in file)
        return [uid for uid in stripped if uid]


def get_not_already_purged_users(uids):
    """Return the uids, in their original order, that is_already_purged() rejects."""
    return [candidate for candidate in uids if not is_already_purged(candidate)]


def is_already_purged(uid):
    """Tell whether the state storage already holds a non-initial status for `uid`.

    Returns False when there is no (or an empty) record for the uid or the
    record has no "status" key. Any stored status other than STATUS_INITIAL
    counts as already handled, which includes STATUS_STARTED and STATUS_ERROR.
    """
    with lock:
        record = storage.get(uid)
    if record and "status" in record:
        return record["status"] != STATUS_INITIAL
    return False


def split_to_chunks(uids, chunk_size=None):
    """Split `uids` into consecutive slices of at most `chunk_size` items.

    Args:
        uids: sequence of uids to split; it is not modified.
        chunk_size: maximum slice length. Defaults to the module-level
            CHUNK_SIZE when omitted or None.

    Returns:
        A list of slices of `uids` in original order. Every slice except
        possibly the last has exactly `chunk_size` items. An empty input
        gives an empty list.

    Raises:
        ValueError: if the effective chunk size is not positive (such a size
            could never consume the input).
    """
    if chunk_size is None:
        chunk_size = CHUNK_SIZE
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive, got {}".format(chunk_size))
    # Slice by index so the remaining tail is not re-copied on every step.
    return [uids[start:start + chunk_size] for start in range(0, len(uids), chunk_size)]


def purge_users_with_concurrency(uids, concurrency):
    """Purge `uids` on every shard, spreading the shards over worker threads.

    The shard ids are the keys of the mapping returned by get_sharpei_stat().
    They are cut into consecutive groups of ``len(shards) // concurrency``
    shards (at least one), and each group is handled by its own thread running
    purge_users_from_shards(). The call blocks until all threads finish.

    Args:
        uids: uids to look for and purge on each shard.
        concurrency: desired number of threads. It is a hint: because the
            group size is rounded down, more threads than requested may be
            started when the shard count is not a multiple of it.
    """
    # list() keeps the keys sliceable whether keys() returns a list or a view.
    shards = list(get_sharpei_stat().keys())
    # Clamp to 1: a zero group size (concurrency larger than the shard count)
    # would never consume `shards`, and a zero concurrency would divide by zero.
    # `//` keeps the result an integer on every Python version.
    chunk_size = max(1, len(shards) // max(1, concurrency))
    threads = [
        Thread(
            target=purge_users_from_shards,
            args=(
                shards[start:start + chunk_size],
                uids,
            ),
        )
        for start in range(0, len(shards), chunk_size)
    ]
    logging.info("start purge chunk, concurrency={}".format(len(threads)))
    for thread in threads:
        thread.start()
    logging.info("wait...")
    for thread in threads:
        thread.join()
    logging.info("purge chunk completed")


def purge_users_from_shards(shards, uids):
    """Thread target: run purge_users_from_shard() for each shard id, in order.

    An exception raised while handling one shard propagates, so the remaining
    shards of this group are not processed.
    """
    for current_shard in shards:
        purge_users_from_shard(current_shard, uids)


def purge_users_from_shard(shard_id, uids):
    """Purge, on one shard, the given users that have no auth data but still have folders.

    The uid list is narrowed by two queries against the shard. A purge request
    is then sent for every remaining uid. On success the uid is recorded as
    STATUS_STARTED; on failure the error is logged and stored as STATUS_ERROR
    and processing continues with the next uid. Finally the function waits
    until the shard reports the requested users gone (see check_purged for
    the tolerance) and marks all of them STATUS_PURGED.
    """
    stats_json = json.dumps(get_stats())
    logging.info("processing shard {}... {}".format(shard_id, stats_json))

    without_auth = select_users_without_auth_data(shard_id, uids)
    candidates = select_users_with_folders(shard_id, without_auth)

    requested = []
    for uid in candidates:
        logging.info("purge user uid={}".format(uid))
        try:
            purge_user(uid)
            update_stats_on_purge_success()
            requested.append(uid)
            set_status(uid, STATUS_STARTED)
        except Exception as err:
            reason = str(err)
            logging.error("purge error: uid={} msg={}".format(uid, reason))
            update_stats_on_purge_error()
            set_status(uid, STATUS_ERROR, reason)

    wait_purge(shard_id, requested)
    mark_purged(requested)


def get_stats():
    """Return a snapshot dict of the success and failure counters, read under `lock`."""
    with lock:
        return {
            "purged_users_count": purged_users_count,
            "purge_failed_count": purge_failed_count,
        }


def select_users_without_auth_data(shard_id, uids):
    """Return, as strings, the uids from `uids` that the shard lists without auth data.

    Runs SELECT_MAILISH_WITHOUT_AUTH_BY_UIDS_QUERY with the uids joined by
    commas. `uids` must contain strings. An empty list returns [] without
    touching the database.
    """
    if not uids:
        return []
    cursor = get_db_cursor(shard_id)
    query = SELECT_MAILISH_WITHOUT_AUTH_BY_UIDS_QUERY.format(",".join(uids))
    return [str(row["uid"]) for row in fetch(cursor, query)]


def select_users_with_folders(shard_id, uids):
    """Return, as strings, the uids from `uids` that still have folders on the shard.

    Runs SELECT_MAILISH_WITH_FOLDERS_BY_UIDS_QUERY with the uids joined by
    commas. `uids` must contain strings. An empty list returns [] without
    touching the database.
    """
    if not uids:
        return []
    cursor = get_db_cursor(shard_id)
    query = SELECT_MAILISH_WITH_FOLDERS_BY_UIDS_QUERY.format(",".join(uids))
    return [str(row["uid"]) for row in fetch(cursor, query)]


@retry()
def purge_user(uid):
    """Send the purge request for `uid` to the PURGE_USER_URL endpoint.

    The request carries the TVM ticket loaded by load_tvm_ticket() in the
    X-Ya-Service-Ticket header and an empty post body. Retry behaviour comes
    from the `retry` decorator, which is defined outside this file.
    """
    load_url(
        PURGE_USER_URL.format(uid=uid),
        add_headers={"X-Ya-Service-Ticket": TVM_TICKET},
        post_data="",
    )


def update_stats_on_purge_success():
    """Increment the shared success counter while holding `lock`."""
    global purged_users_count
    with lock:
        purged_users_count += 1


def update_stats_on_purge_error():
    """Increment the shared failure counter while holding `lock`."""
    global purge_failed_count
    with lock:
        purge_failed_count += 1


def set_status(uid, status, comment=None):
    """Store `status` (and optionally `comment`) in the state record of `uid`.

    The record is fetched, updated and written back while holding `lock`.
    A falsy `comment` leaves any previously stored comment untouched.
    """
    with lock:
        record = storage.get(uid, {})
        record["status"] = status
        if comment:
            record["comment"] = comment
        storage.set(uid, record)


def wait_purge(shard_id, uids):
    """Block until check_purged() reports the given uids purged on the shard.

    Polls every 5 seconds with no upper time limit. Returns immediately when
    `uids` is empty.
    """
    if not uids:
        return
    while True:
        if check_purged(shard_id, uids):
            break
        sleep(5)


def check_purged(shard_id, uids):
    """Return True once at most WAIT_PURGE_USERS_COUNT of `uids` remain on the shard.

    The remaining count comes from CHECK_PURGED_QUERY and is logged on every
    call. `uids` must be a non-empty list of strings.
    """
    # NOTE(review): assumes fetch_one() returns the bare count value so the
    # comparison below is numeric - confirm against xeno_monitor.
    cursor = get_db_cursor(shard_id)
    query = CHECK_PURGED_QUERY.format(",".join(uids))
    remaining = fetch_one(cursor, query)
    logging.info("wait for {} users in shard {}".format(remaining, shard_id))
    return remaining <= WAIT_PURGE_USERS_COUNT


def mark_purged(uids):
    """Record STATUS_PURGED in the state storage for every uid in `uids`."""
    for purged_uid in uids:
        set_status(purged_uid, STATUS_PURGED)


# Run the purge only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
