#! /usr/bin/env python

import os
import sys
import logging
import logging.handlers
from collections import defaultdict
from datetime import datetime, date, timedelta

import psycopg2
import yt.wrapper as yt
import shards
import yt_tables

# Per-(service, platform) subscription totals from xiva.subscriptions:
# row count and summed pg_column_size of the whole row.
# Rows with a non-empty platform are always counted; rows whose platform is
# NULL or empty are counted only when their callback starts with 'webpush:'
# (collect_xtable_stats reports those under platform "webpush").
# NOTE: the literal '%' in the LIKE pattern is safe because shard_execute
# calls cursor.execute without query parameters.
SUBSCRIPTIONS_REQUEST = """
select
    service,
    platform,
    count(1),
    sum(pg_column_size(xiva.subscriptions.*)) as size
from
    xiva.subscriptions
where
    (platform is not NULL and platform != '')
    or
    ((platform is NULL or platform = '') and callback like 'webpush:%')
group by
    service, platform
;
"""

# Rows of (service name, sid) from xiva.services; collect_xstore_stats turns
# them into a sid -> service name map for the sid-keyed aggregates below.
SERVICES_REQUEST = """
SELECT
    service_name as service, sid
FROM
    xiva.services
"""

# Per-sid row count and summed row size (pg_column_size) of xiva.counters.
COUNTERS_REQUEST = """
SELECT
    sid, count(1), sum(pg_column_size(xiva.counters.*)) as size
FROM
    xiva.counters
GROUP BY
    sid
"""

# Same per-sid aggregation as COUNTERS_REQUEST, but over a single
# notifications partition table. Fill with `% {"partition": "<schema.table>"}`.
# The name is spliced into the SQL text verbatim (it is not a bind
# parameter), so it must come from trusted code, never from user input.
NOTIFICATIONS_REQUEST_TEMPLATE = """
SELECT
    sid, count(1), sum(pg_column_size(%(partition)s.*)) as size
FROM
    %(partition)s
GROUP BY
    sid
"""

# Percentile levels (in percent) computed by percentiles_request. Their order
# fixes the column layout of that query's result rows, which
# accumulate_percentiles unpacks positionally.
PERCENTILES = [50, 75, 85, 90, 95, 99, 100]


def percentiles_request(partition):
    """Build SQL computing per-sid percentiles of per-uid activity.

    The query first aggregates ``partition`` per (sid, uid) into a row count
    and a summed row size (CTE ``uid_counts``). It then emits, for each sid,
    ``percentile_cont`` of those counts and sizes at every level listed in
    PERCENTILES.

    Result row layout: ``sid``, then one ``count_pNN`` column per
    PERCENTILES entry, then one ``size_pNN`` column per entry, and finally a
    constant ``1``. The constant exists only because every percentile item
    is emitted with a trailing comma; accumulate_percentiles skips it.

    :param partition: table name spliced verbatim into the SQL text; it must
        be a trusted identifier, not user input.
    :return: the query string.
    """
    header = """
        WITH uid_counts(sid, count, size) AS(
            SELECT
                sid, count(1), sum(pg_column_size(%(partition)s.*))
            FROM
                %(partition)s
            GROUP BY
                (sid, uid)
        )
        SELECT
            sid,
    """ % {"partition": partition}

    # One select-list item per (group, level), e.g.
    #   percentile_cont(0.5) within group (order by count) as count_p50,
    # grouped as all count levels first, then all size levels.
    item_template = (
        "percentile_cont(%(float_p)s) within group (order by %(group)s) as %(group)s_p%(p)s,"
    )
    select_items = [
        item_template % {"float_p": 0.01 * level, "group": column, "p": level}
        for column in ("count", "size")
        for level in PERCENTILES
    ]

    # The bare `1` absorbs the trailing comma left by the last item above.
    footer = """
            1
        FROM
            uid_counts
        GROUP BY
            sid
    """
    return header + "".join(select_items) + footer


def shard_execute(conninfo, request):
    """Run one read-only query on a shard and return all result rows.

    :param conninfo: libpq connection string of the shard (replica) to query.
    :param request: SQL text; executed without bind parameters, so any '%'
        in it is sent to the server literally.
    :return: list of row tuples as returned by ``cursor.fetchall()``.

    The connection is always closed before returning. psycopg2's
    ``with connection:`` block only ends the transaction (commit on success,
    rollback on error) and does not close the connection, so closing has to
    be done explicitly to avoid leaking one connection per call.
    """
    friendly_shard_name = str(shards.hosts(conninfo))

    connection = psycopg2.connect(conninfo)
    try:
        # Transaction scope only: commits or rolls back, does not close.
        with connection:
            logging.debug("shard %s got connection %s", friendly_shard_name, connection)
            # Set before the first execute so the query's transaction is read-only.
            connection.readonly = True
            with connection.cursor() as cursor:
                cursor.execute(request)
                return cursor.fetchall()
    finally:
        connection.close()


def collect_xtable_stats(date_string):
    """Sum subscription counts and sizes over all unique xtable shards.

    Runs SUBSCRIPTIONS_REQUEST on a read replica of every unique xtable
    shard and adds up ``count`` and ``size`` per (service, platform).
    A falsy platform (NULL or empty string) is reported as "webpush".

    :param date_string: copied verbatim into the "date" field of every row.
    :return: generator of dicts with keys date, service, platform, count,
        size — one per (service, platform).
    """
    replica_conninfos = [
        shards.conninfo_for_read(shard) for shard in shards.get_unique("xtable")
    ]

    totals = defaultdict(lambda: defaultdict(int))
    for conninfo in replica_conninfos:
        rows = shard_execute(conninfo, request=SUBSCRIPTIONS_REQUEST)
        for service, platform, row_count, row_size in rows:
            bucket = totals[(service, platform or "webpush")]
            bucket["count"] += row_count
            bucket["size"] += row_size

    return (
        {
            "date": date_string,
            "service": service,
            "platform": platform,
            "count": sums["count"],
            "size": sums["size"],
        }
        for (service, platform), sums in totals.iteritems()
    )


def accumulate_xstore_group(stats, rows, group, services):
    """Add per-sid (count, size) rows into ``stats`` under ``group``.

    :param stats: mapping (service name, group) -> {"count", "size"} whose
        missing entries default to zero (callers pass a nested
        ``defaultdict(int)``). It is mutated in place.
    :param rows: iterable of (sid, count, size) tuples.
    :param group: label stored as the second element of each key.
    :param services: mapping sid -> service name; an unknown sid raises
        KeyError.
    :return: the same ``stats`` object, for call chaining.
    """
    for service_id, row_count, row_size in rows:
        bucket = stats[(services[service_id], group)]
        bucket["count"] += row_count
        bucket["size"] += row_size
    return stats


def accumulate_percentiles(stats, rows, services):
    """Merge per-sid percentile rows into ``stats`` keyed by service.

    Each row must follow the layout produced by percentiles_request:
    ``sid``, then len(PERCENTILES) count percentiles, then len(PERCENTILES)
    size percentiles, then one trailing constant column (ignored).

    For every level p, the entry (service, "notifications_p<p>") keeps the
    maximum value seen so far. The float percentile values are truncated
    with ``int()``.

    NOTE(review): when rows come from several shards, the max of per-shard
    percentiles is only an upper-bound approximation, not the true global
    percentile.

    :param stats: mapping (service, group) -> {"count", "size"} whose missing
        entries default to 0 (callers use a nested defaultdict(int));
        mutated in place.
    :param rows: iterable of result rows in the layout described above.
    :param services: mapping sid -> service name; an unknown sid raises
        KeyError.
    :return: the same ``stats`` object.
    """
    levels = len(PERCENTILES)
    for row in rows:
        service = services[row[0]]
        count_percentiles = row[1 : levels + 1]
        # Stop before the trailing constant column added by percentiles_request.
        size_percentiles = row[levels + 1 : -1]
        for i, percentile in enumerate(PERCENTILES):
            bucket = stats[(service, "notifications_p%s" % (percentile,))]
            bucket["count"] = max(bucket["count"], int(count_percentiles[i]))
            bucket["size"] = max(bucket["size"], int(size_percentiles[i]))
    return stats


def collect_xstore_stats(date_string):
    """Sum xstore counters and one day's notifications per service.

    For every unique xstore shard (queried through a read replica) this:
      * loads the sid -> service name map from xiva.services;
      * adds per-sid totals of xiva.counters under group "counters";
      * adds per-sid totals of the partition named
        ``xiva.notifications_p<date_string with '-' replaced by '_'>``
        under group "notifications".

    :param date_string: the day to collect (presumably "YYYY-MM-DD", as
        produced by main); also copied into the "date" field of every row.
    :return: generator of dicts with keys date, service, group, count, size.
    """
    replica_conninfos = [
        shards.conninfo_for_read(shard) for shard in shards.get_unique("xstore")
    ]

    totals = defaultdict(lambda: defaultdict(int))
    for conninfo in replica_conninfos:
        sid_to_service = dict(
            (sid, name) for name, sid in shard_execute(conninfo, SERVICES_REQUEST)
        )

        counter_rows = shard_execute(conninfo, COUNTERS_REQUEST)
        totals = accumulate_xstore_group(totals, counter_rows, "counters", sid_to_service)

        partition = "xiva.notifications_p%s" % (date_string.replace("-", "_"),)
        notification_rows = shard_execute(
            conninfo, NOTIFICATIONS_REQUEST_TEMPLATE % {"partition": partition}
        )
        totals = accumulate_xstore_group(
            totals, notification_rows, "notifications", sid_to_service
        )

        # Per-uid notification percentiles are disabled (RTEC-4539). To
        # re-enable, run percentiles_request(partition) on this shard and
        # feed the rows to accumulate_percentiles(totals, rows, sid_to_service).

    return (
        {
            "date": date_string,
            "service": service,
            "group": group,
            "count": sums["count"],
            "size": sums["size"],
        }
        for (service, group), sums in totals.iteritems()
    )


def add_to_yt(
    date_string,
    stats,
    db,
    proxy="hahn.yt.yandex.net",
    token_path="/home/xiva/.yt/token",
):
    """Replace the ``date_string`` rows of the YT table for ``db`` with ``stats``.

    In a single YT transaction, rows whose key equals ``date_string`` are
    erased and ``stats`` is appended. Re-running for the same day therefore
    replaces that day's rows instead of adding a second copy.

    :param date_string: key of the rows to replace (presumably the table is
        sorted by the date column — TODO confirm).
    :param stats: iterable of row dicts to append; it is consumed inside the
        transaction.
    :param db: name passed to yt_tables.get to pick the destination table.
    :param proxy: YT proxy URL; defaults to the previously hard-coded cluster.
    :param token_path: file holding the YT token; defaults to the previously
        hard-coded path.
    """
    yt.config["proxy"]["url"] = proxy
    # Close the token file promptly. Also strip surrounding whitespace, such
    # as the trailing newline most editors add, because it is not part of
    # the token.
    with open(token_path) as token_file:
        yt.config["token"] = token_file.read().strip()
    table = yt_tables.get(db)
    with yt.Transaction():
        yt.run_erase(yt.TablePath(table, exact_key=date_string))
        yt.write_table(yt.TablePath(table, append=True), stats)


def main():
    """Collect yesterday's stats for the db named in argv[1] and upload them.

    Usage: ``<script> <db>``, where ``collect_<db>_stats`` is a function in
    this module (currently ``xtable`` or ``xstore``). Requires the
    QLOUD_APPLICATION environment variable and logs to
    /var/log/$QLOUD_APPLICATION/<db>.log.

    Exits with status 1 after logging the traceback if collection or upload
    fails, so whatever schedules this job can detect the failure.
    """
    if len(sys.argv) < 2:
        sys.exit("usage: %s <db>" % (sys.argv[0],))
    db = sys.argv[1]
    app = os.environ["QLOUD_APPLICATION"]
    collect_stats = globals().get("collect_%s_stats" % (db,))
    if not callable(collect_stats):
        sys.exit("unknown db %r: no collect_%s_stats function" % (db, db))

    logging.basicConfig(
        filename="/var/log/%(app)s/%(db)s.log" % {"app": app, "db": db},
        format="%(asctime)s %(message)s",
        level=logging.DEBUG,
    )

    try:
        date_string = str(date.today() - timedelta(days=1))
        logging.info("started")
        stats = collect_stats(date_string)
        logging.info("collected")
        add_to_yt(date_string, stats, db)
        logging.info("saved")
    except Exception as e:
        logging.exception(str(e))
        logging.info("failed")
        # Report the failure through the exit status instead of exiting 0.
        sys.exit(1)


# Script entry point: run one collection/upload when executed directly.
if __name__ == "__main__":
    main()
