#! /usr/bin/env python3

import copy
import logging
import os
import psycopg2
import shards
import tempfile
import time
import yt.wrapper as yt
from multiprocessing.pool import ThreadPool

# Services whose subscriptions are exported. Each service is queried and
# uploaded separately for every shard (see process_shard/process_service).
SERVICES = [
    "appmetrica_10267",
    "appmetrica_10324",
    "appmetrica_25378",
    "mobile-kinopoisk",
    "music",
    "tickets",
    "plus-city",
    "eda-client",
]

# YT table paths, keyed first by stage and then by QLOUD_ENVIRONMENT value.
# "temp" is the staging table filled during a run; "final" is the table it is
# moved onto once every shard has been dumped (see commit_table).
YT_TABLES = {
    "temp": {
        "production": "//home/xiva/dump-subscriptions/production-temp",
        "corp": "//home/xiva/dump-subscriptions/corp-temp",
        "sandbox": "//home/xiva/dump-subscriptions/sandbox-temp",
    },
    "final": {
        "production": "//home/xiva/dump-subscriptions/production",
        "corp": "//home/xiva/dump-subscriptions/corp",
        "sandbox": "//home/xiva/dump-subscriptions/sandbox",
    },
}

# Schema of the exported table. The column order must match the SELECT list of
# SUBSCRIPTIONS_SQL (rows are written in that order) and the column list of the
# schemaful_dsv format used on upload.
YT_SCHEME = [
    {"name": "service", "type": "string"},
    {"name": "transport", "type": "string"},
    {"name": "platform", "type": "string"},
    {"name": "uid", "type": "string"},
    {"name": "device", "type": "string"},
]

# Width (number of gids) of each inclusive gid range fetched by a single
# PostgreSQL query, keyed by QLOUD_ENVIRONMENT value.
PG_GIDS_BATCH = {"production": 100, "corp": 1000, "sandbox": 10000}

# Selects the subscriptions of one service within an inclusive gid range.
# - transport is 'webpush' when the callback starts with 'webpush:', else 'mobile';
# - NULL platform/device are returned as '';
# - only rows with a non-empty platform, or with a webpush callback, are kept.
# '%%' is a literal '%': the query is executed with named %(...)s parameters,
# so psycopg2 treats single '%' as placeholder syntax.
SUBSCRIPTIONS_SQL = """
    SELECT
        service,
        CASE
            WHEN callback LIKE 'webpush:%%' THEN 'webpush'
            ELSE 'mobile'
        END AS transport,
        COALESCE(platform, ''),
        uid,
        COALESCE(device, '')
    FROM
        xiva.subscriptions
    WHERE
        gid BETWEEN %(start_gid)s AND %(end_gid)s
        AND service = %(service)s
        AND ((platform IS NOT NULL AND platform != '')
            OR ((platform IS NULL OR platform = '') AND callback LIKE 'webpush:%%'))
"""


# Escapes for characters that would otherwise break a schemaful_dsv row:
# tab is the field separator, newline the row separator, and backslash the
# escape character itself.
_DSV_ESCAPES = str.maketrans({"\\": "\\\\", "\t": "\\t", "\n": "\\n"})


def write_dsv(temp_file, values):
    """Write *values* as one tab-separated, newline-terminated row.

    The row is UTF-8 encoded and written to the binary file object *temp_file*.
    Each value is escaped for YT schemaful_dsv, so a tab, newline or backslash
    inside a value no longer shifts columns or splits the row.
    """
    s = "\t".join(value.translate(_DSV_ESCAPES) for value in values) + "\n"
    temp_file.write(s.encode())


def qloud_env():
    """Return the Qloud environment name taken from QLOUD_ENVIRONMENT.

    Raises KeyError when the variable is not set.
    """
    environment = os.environ
    return environment["QLOUD_ENVIRONMENT"]


def yt_table_path(temp=False):
    """Return the YT table path for the current Qloud environment.

    With a truthy *temp* the staging table path is returned; otherwise the
    final (committed) table path.
    """
    stage = "final"
    if temp:
        stage = "temp"
    return YT_TABLES[stage][qloud_env()]


def yt_token_path():
    """Return the path to the YT token file in the executing user's home.

    The user name comes from QLOUD_EXEC_USER when present, otherwise from USER.
    """
    try:
        user = os.environ["QLOUD_EXEC_USER"]
    except KeyError:
        user = os.environ["USER"]
    return "/home/{}/.yt/token".format(user)


def yt_client():
    """Create a YT client for hahn.yt.yandex.net authenticated with the local token.

    A deep copy of the library's default config is modified, so the shared
    default is left untouched. The token is the first line of the file at
    yt_token_path(), stripped of surrounding whitespace.
    """
    config = copy.deepcopy(yt.default_config.get_default_config())
    config["proxy"]["url"] = "hahn.yt.yandex.net"
    # Close the token file immediately instead of leaking the handle until GC.
    with open(yt_token_path()) as token_file:
        config["token"] = token_file.readline().strip()
    return yt.YtClient(config=config)


def prepare_table(yt_client):
    """Recreate an empty staging table for the current environment.

    A staging table left over from an earlier run is removed first; then a new
    table with the YT_SCHEME schema is created.
    """
    temp_path = yt_table_path(temp=True)
    if not yt_client.exists(temp_path):
        logging.info("no old temp table")
    else:
        logging.info("removing old temp table")
        yt_client.remove(temp_path)
        logging.info("removed old temp table")
    yt_client.create("table", temp_path, attributes={"schema": YT_SCHEME})
    logging.info("temp table created")


def commit_table(yt_client):
    """Move the staging table onto the final table path, replacing it (force=True)."""
    logging.info("commiting table")
    staging_path = yt_table_path(temp=True)
    final_path = yt_table_path()
    yt_client.move(staging_path, final_path, force=True)
    logging.info("table commited")


def upload_subscriptions(temp_file, yt_client, logger):
    """Append the rows dumped into *temp_file* to the staging YT table.

    The file is re-opened by name for reading (callers create it write-only)
    and streamed raw as schemaful_dsv, with columns in YT_SCHEME order.
    *logger* is unused and is kept for interface compatibility.
    """
    path = yt_client.TablePath(yt_table_path(temp=True), append=True)
    fmt = "<columns=[service;transport;platform;uid;device]>schemaful_dsv"
    # Close the read handle once the upload finishes instead of leaking it.
    with open(temp_file.name, mode="rb") as data:
        yt_client.write_table(path, data, format=fmt, raw=True)


def gids_batch():
    """Return the width of the gid range each PostgreSQL query covers in this environment."""
    environment = qloud_env()
    return PG_GIDS_BATCH[environment]


def fetch_subscriptions_range(start_gid, end_gid, shard, service, logger):
    """Fetch subscriptions of *service* with gid in [start_gid, end_gid] from a shard replica.

    Opens a read-only connection to the replica given by
    shards.conninfo_for_read(shard), runs SUBSCRIPTIONS_SQL and returns every
    fetched row. The connection is always closed before returning or raising.
    *logger* is unused and is kept for interface compatibility.
    """
    replica = shards.conninfo_for_read(shard)
    connection = psycopg2.connect(replica)
    try:
        connection.readonly = True
        # `with connection` only ends the transaction (commit/rollback); it does
        # not close the connection, so it is closed explicitly in `finally`.
        with connection:
            with connection.cursor() as cursor:
                cursor.execute(
                    SUBSCRIPTIONS_SQL,
                    {"start_gid": start_gid, "end_gid": end_gid, "service": service},
                )
                return cursor.fetchall()
    finally:
        connection.close()


def try_fetch_subscriptions_range(
    start_gid, end_gid, shard, service, logger, *, retry_count=10, retry_pause=5
):
    """Call fetch_subscriptions_range, retrying on any exception.

    Makes up to *retry_count* attempts and sleeps *retry_pause* seconds between
    them (never after the last one). Every failure is logged through *logger*.

    Raises:
        RuntimeError: when all attempts failed; chained to the last error.
    """
    last_error = None
    for retry_attempt in range(retry_count):
        try:
            return fetch_subscriptions_range(start_gid, end_gid, shard, service, logger)
        except Exception as e:
            last_error = e
            logger.exception(str(e))
            if retry_attempt + 1 >= retry_count:
                # No further attempt follows, so don't sleep before giving up.
                break
            logger.debug(
                "failed to download subscriptions range (attempt %s to retry after %s seconds)",
                retry_attempt,
                retry_pause,
            )
            time.sleep(retry_pause)
    raise RuntimeError(
        "failed to download subscriptions range (retry count exceeded)"
    ) from last_error


def progress(processed, gids_range):
    """Return how far *processed* is through the inclusive *gids_range*, in whole percent.

    *gids_range* is indexed as (first_gid, last_gid). The result is rounded down
    to an int: 0 at first_gid and 100 at last_gid. A range holding a single gid
    (first_gid == last_gid) reports 100 instead of raising ZeroDivisionError.
    """
    first_gid, last_gid = gids_range[0], gids_range[1]
    span = last_gid - first_gid
    if span <= 0:
        return 100
    # Integer arithmetic avoids float rounding (0.29 * 100 == 28.999...).
    return int((processed - first_gid) * 100 // span)


def splitted_gids_range(gids_range, step):
    """Yield consecutive inclusive (start_gid, end_gid) chunks covering *gids_range*.

    Each chunk spans at most *step* gids; the last chunk is clipped to
    gids_range[1].
    """
    first, last = gids_range[0], gids_range[1]
    starts = range(first, last + 1, step)
    yield from ((begin, min(begin + step - 1, last)) for begin in starts)


def download_subscriptions(temp_file, shard, service, logger):
    """Dump all subscriptions of *service* stored on *shard* into *temp_file*.

    The shard's gid range is fetched in chunks of gids_batch() gids (each with
    retries), every row is written with write_dsv, and progress is logged after
    each chunk. The file is flushed at the end so that the data is on disk
    before it is re-opened by name for upload.
    """
    shard_gids = shards.gids_range(shard)
    for start_gid, end_gid in splitted_gids_range(shard_gids, gids_batch()):
        subs_range = try_fetch_subscriptions_range(start_gid, end_gid, shard, service, logger)
        for sub in subs_range:
            write_dsv(temp_file, sub)
        # Lazy %-style arguments, consistent with the other log calls here.
        logger.debug(
            "downloaded %s%% subscriptions for service %s",
            progress(end_gid, shard_gids),
            service,
        )
    temp_file.flush()
    logger.debug("service %s subscriptions downloaded successfully", service)


def process_service(shard, service, logger):
    """Export one service's subscriptions from *shard* and append them to the staging table.

    Rows are staged in a NamedTemporaryFile, which is removed when the `with`
    block exits. A new YT client is created only after the download finishes.
    """
    logger.debug("start processing service %s", service)
    with tempfile.NamedTemporaryFile(mode="wb") as dump_file:
        logger.debug("created temporary file %s", dump_file.name)
        download_subscriptions(dump_file, shard, service, logger)
        client = yt_client()
        upload_subscriptions(dump_file, client, logger)
    logger.debug("finished processing service %s", service)


def process_shard(shard, logger):
    """Export every service listed in SERVICES from *shard*, sequentially."""
    logger.debug("start processing shard")
    for service_name in SERVICES:
        process_service(shard, service_name, logger)
    logger.debug("finished processing shard")


def process_shards():
    """Export all unique "xtable" shards in parallel, one worker thread per shard.

    Each shard is processed with a LoggerShardNameAdapter tagged with the
    shard's friendly name. An exception raised while processing any shard is
    re-raised to the caller.

    Raises:
        ValueError: if no "xtable" shards are returned.
    """
    shards_list = shards.get_unique("xtable")
    if not shards_list:
        # ThreadPool(processes=0) would fail with a less helpful message.
        raise ValueError("no xtable shards found")
    tasks = [
        (
            shard,
            shards.LoggerShardNameAdapter(
                logging.getLogger(), {"shard_name": shards.friendly_name(shard)}
            ),
        )
        for shard in shards_list
    ]
    # The context manager shuts down the pool's worker threads once all
    # shards are done (or one failed), instead of leaving the pool open.
    with ThreadPool(processes=len(shards_list)) as pool:
        pool.starmap(process_shard, tasks, chunksize=1)


def main():
    """Dump subscriptions of all shards into the YT table of the current environment.

    Logging goes to /var/log/xivadb-stats/dump-subscriptions.log. The staging
    table is recreated and filled from every shard. Only then is it moved over
    the final table. On any error the exception is logged and the process exits
    with status 1, so the failure is visible to whatever runs the script.
    """
    log_format = "%(asctime)s %(message)s"
    logging.basicConfig(
        filename="/var/log/xivadb-stats/dump-subscriptions.log",
        format=log_format,
        level=logging.DEBUG,
    )
    try:
        logging.info("started")
        prepare_table(yt_client())
        process_shards()
        commit_table(yt_client())
        logging.info("finished successfully")
    except Exception as e:
        logging.exception(str(e))
        logging.info("failed")
        # Report the failure through the exit status instead of exiting with 0.
        raise SystemExit(1)


# Run the dump only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
