import os
import sys
import datetime
import json
import logging

import boto3
from psycopg2 import connect

import datasources_config_wrapper
import solomon

import yt.wrapper as yt


logging.basicConfig()


def load_settings():
    """Create the datasources configuration wrapper used by this script.

    Returns:
        A ``DatasourcesConfigWrapper`` built with ``prefix='dogma'`` and
        ``suppress_warning=True``.
    """
    wrapper_options = {
        'prefix': 'dogma',
        'suppress_warning': True,
    }
    return datasources_config_wrapper.DatasourcesConfigWrapper(**wrapper_options)


class Query:
    """Callable that runs a SQL statement on a cursor and returns all rows."""

    def __init__(self, cursor):
        """Remember the cursor every statement will be executed on.

        Args:
            cursor: object exposing ``execute(sql)`` and ``fetchall()``.
        """
        self.cursor = cursor

    def __call__(self, sql):
        """Execute ``sql`` and return whatever ``fetchall()`` yields.

        Args:
            sql: complete SQL statement text.

        Returns:
            The result of ``cursor.fetchall()`` for the executed statement.
        """
        # Use the cursor handed to the constructor. Referring to a bare
        # ``cursor`` name would only work if a module-level global of that
        # name happened to exist, and would raise NameError otherwise.
        self.cursor.execute(sql)
        return self.cursor.fetchall()


class Solomon:
    """Callable facade over a ``solomon.BasePushApiReporter`` instance."""

    def __init__(self, project, cluster, service, token):
        """Create the underlying reporter.

        Args:
            project: Solomon project passed to the reporter.
            cluster: Solomon cluster passed to the reporter.
            service: Solomon service passed to the reporter.
            token: OAuth token wrapped into ``solomon.OAuthProvider``.
        """
        reporter_options = dict(
            project=project,
            cluster=cluster,
            service=service,
            url='http://solomon.yandex.net',
            auth_provider=solomon.OAuthProvider(token),
            common_labels={},
        )
        self.client = solomon.BasePushApiReporter(**reporter_options)

    def __call__(self, sensor, value, labels):
        """Forward one sensor value with its labels to the reporter."""
        push = self.client.set_value
        push(sensor=sensor, value=value, labels=labels)


def unfinished_clones(query, solomon, delay_hours):
    """Report, per node, clones of active repos not modified recently.

    Counts ``core_clone`` rows of active repos whose ``modified`` value is
    older than ``delay_hours`` before the shifted current time (and newer
    than 2019-10-01), grouped by node. One ``unfinished_clones`` value is
    pushed per node, followed by the sum under the name ``"all"``.

    Args:
        query: callable taking SQL text and returning rows of
            ``(count, node name, node hostname)``.
        solomon: callable ``(sensor, value, labels=...)`` receiving metrics.
        delay_hours: age threshold in hours; also used for the ``delay`` label.

    Returns:
        The rows returned by ``query``.
    """
    # UTC shifted by +3 hours; presumably matches the "+03" offset used in
    # the SQL literal below -- TODO confirm against the DB session timezone.
    shifted_now = datetime.datetime.utcnow() + datetime.timedelta(hours=3)
    threshold = shifted_now - datetime.timedelta(hours=delay_hours)

    statement = """
        SELECT count(*), core_node.name, core_node.hostname
        FROM core_clone
        join core_node on core_clone.node_id = core_node.id, core_repo
        WHERE core_clone.modified < '%s' AND core_clone.modified > '2019-10-01 00:00:00+03'
        AND core_clone.repo_id = core_repo.id AND core_repo.is_active=true
        GROUP BY 2, 3
        ORDER BY 3;

    """ % (threshold,)

    rows = query(statement)
    delay_label = str(delay_hours) + "h"

    # The hostname column is selected for grouping/ordering only.
    for clone_count, node_name, _hostname in rows:
        solomon(
            "unfinished_clones",
            clone_count,
            labels={"name": node_name, "delay": delay_label},
        )

    grand_total = sum(row[0] for row in rows)
    solomon(
        "unfinished_clones",
        grand_total,
        labels={"name": "all", "delay": delay_label},
    )

    return rows


def failed_clones(query, solomon):
    """Report the number of clones in ``fail`` status, per node and overall.

    Args:
        query: callable taking SQL text and returning ``(node name, count)`` rows.
        solomon: callable ``(sensor, value, labels=...)`` receiving metrics.

    Returns:
        The rows returned by ``query``.
    """
    statement = """
        SELECT core_node.name, COUNT(core_clone.id) FROM core_node
        JOIN core_clone ON core_clone.node_id = core_node.id
        WHERE core_clone.status = 'fail'
        GROUP BY core_node.name
    """

    rows = query(statement)

    for node_name, failed in rows:
        solomon("failed_clones", failed, labels={"node": node_name})

    # Aggregate across nodes is published under the pseudo-node "all".
    overall = sum(failed for _, failed in rows)
    solomon("failed_clones", overall, labels={"node": "all"})

    return rows


def arcadia_delay(query, solomon):
    """Report the age, in hours, of the last contiguous commit of two repos.

    The repos are selected by the hard-coded ids 50016 and 177242; each value
    is labelled with the repo's ``vcs_name``.

    Args:
        query: callable taking SQL text and returning
            ``(timedelta, vcs_name)`` rows.
        solomon: callable ``(sensor, value, labels=...)`` receiving metrics.
    """
    statement = """
        SELECT CURRENT_TIMESTAMP - commit_time, core_repo.vcs_name
        FROM core_repo
        JOIN core_pushedcommit ON core_repo.contiguous_chain_of_commits_ends_at_id = core_pushedcommit.id
        WHERE core_repo.id = 50016 OR core_repo.id = 177242
    """

    seconds_per_hour = 3600.

    for lag, vcs_name in query(statement):
        solomon(
            "arcadia_delay",
            lag.total_seconds() / seconds_per_hour,
            labels={"name": vcs_name},
        )


def pushed_commits(query, solomon, intervals):
    """Report how many commits were pushed within each trailing interval.

    A single statement scans ``core_pushedcommit`` once: rows are restricted
    to the widest interval, which is counted with ``COUNT(*)``, while every
    narrower interval is counted with a conditional ``SUM``.

    Args:
        query: callable taking SQL text and returning rows; the first row
            holds one value per interval, in ascending interval order.
        solomon: callable ``(sensor, value, labels=...)`` receiving metrics.
        intervals: non-empty iterable of interval lengths in hours.

    Returns:
        The rows returned by ``query``, unmodified.

    Raises:
        IndexError: if ``intervals`` is empty.
    """
    hours_sorted = sorted(intervals)
    widest = hours_sorted[-1]

    predicate = "created > current_timestamp - interval '{} hours'"
    selections = [
        "SUM(case when {} then 1 else 0 end) as interval{}h".format(predicate.format(hours), hours)
        for hours in hours_sorted[:-1]
    ]
    selections.append("COUNT(*) as interval{}h".format(widest))

    # core_pushedcommit is one of the biggest tables - accumulate the required data within one pass
    sql = """
        SELECT
            {}
        FROM core_pushedcommit
        WHERE {};
    """.format(',\n'.join(selections), predicate.format(widest))

    res = query(sql)
    counts = res[0]

    for position, hours in enumerate(hours_sorted):
        value = counts[position]
        # SUM() over zero matching rows is NULL in SQL (COUNT(*) is 0), so a
        # quiet period would otherwise push None instead of a number.
        if value is None:
            value = 0
        solomon(
            "pushed_commits",
            value,
            labels={
                "interval": str(hours) + "h",
            }
        )

    return res


def unflushed_repos(query, solomon):
    """Report how many active repos have commits newer than their YT sync.

    For every active repo whose last contiguous commit was created after
    ``last_yt_sync_time``, the gap between the two timestamps is measured.
    The ``0h`` value is the number of such repos; the ``4h``, ``12h`` and
    ``24h`` values count repos whose gap is at least that long.

    Args:
        query: callable taking SQL text and returning
            ``(repo id, last_yt_sync_time, created)`` rows.
        solomon: callable ``(sensor, value, labels=...)`` receiving metrics.

    Returns:
        The rows returned by ``query``.
    """
    statement = """
        SELECT
            core_repo.id, core_repo.last_yt_sync_time, core_pushedcommit.created
        FROM
            core_repo
        LEFT JOIN
            core_pushedcommit
        ON
            core_repo.contiguous_chain_of_commits_ends_at_id = core_pushedcommit.id
        WHERE
            core_repo.contiguous_chain_of_commits_ends_at_id IS NOT NULL
        AND
            core_repo.is_active = TRUE
        AND
            core_pushedcommit.created IS NOT NULL
        AND
            core_pushedcommit.created > core_repo.last_yt_sync_time
    """

    rows = query(statement)

    gaps = [created - synced for _repo_id, synced, created in rows]

    # Bucket 0 is simply "every unflushed repo"; the others are cumulative
    # "gap is at least N hours" counters.
    buckets = [(0, len(rows))]
    for threshold_hours in (4, 12, 24):
        limit = datetime.timedelta(hours=threshold_hours)
        buckets.append((threshold_hours, sum(1 for gap in gaps if gap >= limit)))

    for threshold_hours, repo_count in buckets:
        solomon(
            "unflushed_repos",
            repo_count,
            labels={"interval": str(threshold_hours) + "h"},
        )

    return rows


def repos_count(query, solomon):
    """Report the total and the active number of rows in ``core_repo``.

    Args:
        query: callable taking SQL text and returning rows; the count is read
            from the first column of the first row.
        solomon: callable ``(sensor, value, labels=...)`` receiving metrics.

    Returns:
        Tuple ``(total_repos_count, active_repos_count)``.
    """
    total_statement = """
        SELECT COUNT(*) FROM core_repo
    """
    total_repos_count = query(total_statement)[0][0]

    active_statement = """
        SELECT COUNT(*) FROM core_repo WHERE is_active = true
    """
    active_repos_count = query(active_statement)[0][0]

    # Both queries run before anything is pushed, then values go out in
    # the order "all", "active".
    for kind, amount in (("all", total_repos_count), ("active", active_repos_count)):
        solomon("repos_count", amount, labels={"type": kind})

    return total_repos_count, active_repos_count


def sources_status(query, solomon):
    """Report a 0/1 health flag per source plus the number of unhealthy ones.

    The SQL maps status ``'success'`` to 0 and anything else to 1, so the
    ``"total"`` value is the count of sources that are not in ``'success'``.

    Args:
        query: callable taking SQL text and returning ``(flag, name)`` rows.
        solomon: callable ``(sensor, value, labels=...)`` receiving metrics.

    Returns:
        The rows returned by ``query``.
    """
    statement = """
        SELECT CASE WHEN status = 'success' THEN 0 ELSE 1 END, name FROM core_source
    """

    rows = query(statement)

    for failure_flag, source_name in rows:
        solomon("sources_status", failure_flag, labels={"name": source_name})

    failing = sum(flag for flag, _name in rows)
    solomon("sources_status", failing, labels={"name": "total"})

    return rows


def nodes_free_space(query, solomon):
    """Report the ``space_available`` column of every ``core_node`` row.

    Args:
        query: callable taking SQL text and returning
            ``(name, space_available)`` rows.
        solomon: callable ``(sensor, value, labels=...)`` receiving metrics.

    Returns:
        The rows returned by ``query``.
    """
    statement = """
        SELECT name, space_available FROM core_node
    """
    rows = query(statement)

    for node_name, available in rows:
        solomon("space_available", available, labels={"name": node_name})

    return rows


def repos_without_clones(query, solomon):
    """Report how many active repos have no clone in ``active`` status.

    Args:
        query: callable taking SQL text and returning single-column rows of
            repo ids.
        solomon: callable ``(sensor, value, labels=...)`` receiving metrics.
    """
    active_statement = """
        SELECT id FROM core_repo WHERE is_active = True
    """
    active_ids = {row[0] for row in query(active_statement)}

    cloned_statement = """
        SELECT DISTINCT repo_id FROM core_clone WHERE status = 'active'
    """
    cloned_ids = {row[0] for row in query(cloned_statement)}

    # Set difference: active repos that never show up among active clones.
    missing = active_ids - cloned_ids

    solomon("not_cloned_repos", len(missing), labels={"name": "count"})


def queues_count(solomon):
    """Report the number of SQS queues visible to the ``dogma`` account.

    Args:
        solomon: callable ``(sensor, value, labels=...)`` receiving metrics.
    """
    endpoint = 'http://sqs.yandex.net:8771'
    access_key = 'dogma'
    secret_key = 'not used yet'

    sqs = boto3.client(
        'sqs',
        region_name='yandex',
        endpoint_url=endpoint,
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
    )

    queues = sqs.list_queues()

    # The "QueueUrls" key is omitted from the response when there are no
    # queues, so default to an empty list and report 0 instead of raising
    # KeyError.
    queue_urls = queues.get('QueueUrls', [])

    solomon(
        "sqs_queues",
        len(queue_urls),
        labels={
            "name": "count",
        }
    )


def _yt_modification_delay_hours(yt_client, path):
    """Return the hours elapsed since ``path`` was last modified in YT."""
    raw = yt_client.get_attribute(path, "modification_time")
    modified_at = datetime.datetime.strptime(raw, "%Y-%m-%dT%H:%M:%S.%fZ")
    # The attribute carries a "Z" (UTC) suffix, so compare against UTC "now".
    # Subtracting it from local time and compensating with a fixed 3-hour
    # offset is only correct on hosts whose local timezone is UTC+3.
    elapsed = datetime.datetime.utcnow() - modified_at
    return elapsed.total_seconds() / 3600.0


def yt_push_delay(yt_client, solomon):
    """Report how long ago the exported YT tables were last modified.

    One ``yt_push_delay`` value (in hours) is pushed for each of the two
    fixed export tables, and one for the lexicographically last table in
    ``//home/dogma/new_export/by_date`` (labelled with the directory path
    plus a trailing slash).

    Args:
        yt_client: object exposing ``get_attribute(path, name)`` and
            ``list(path)``.
        solomon: callable ``(sensor, value, labels=...)`` receiving metrics.
    """
    for path in ["//statbox/qpulse/dogma/commits", "//home/dogma/export/all_commits"]:
        solomon(
            "yt_push_delay",
            _yt_modification_delay_hours(yt_client, path),
            labels={
                "name": path,
            }
        )

    by_date = "//home/dogma/new_export/by_date"
    by_date_tables = yt_client.list(by_date)
    if not by_date_tables:
        # Nothing to measure; avoid an IndexError on the empty listing.
        logging.warning("No tables found in %s", by_date)
        return

    # YT paths always use "/" as the separator, independent of the host OS.
    last_date_table = "{}/{}".format(by_date, max(by_date_tables))
    solomon(
        "yt_push_delay",
        _yt_modification_delay_hours(yt_client, last_date_table),
        labels={
            "name": by_date + "/",
        }
    )


if __name__ == "__main__":
    # Script entry point: decide whether this host should report at all,
    # connect to PostgreSQL, pick the Solomon service from the database name,
    # then compute every metric and push it.

    debug = "debug" in sys.argv

    is_monitoring = "IS_MONITORING" in os.environ

    node = None

    # When the Qloud metadata file exists, the component name is read from it.
    meta_file = "/etc/qloud/meta.json"
    if os.path.exists(meta_file):
        with open(meta_file) as f:
            meta = json.load(f)
        node = meta["user_environment"]["QLOUD_COMPONENT"]

    # Run only in debug mode, when IS_MONITORING is set, or on one of the
    # designated components; every other host exits silently.
    if not (debug or is_monitoring or node in [
        "indexer",  # testing sas
        "indexer3",  # testing man
        "indexer-15",  # prod sas
        "indexer-9",  # prod man
    ]):
        sys.exit()

    ds = load_settings()

    hosts = "host=" + ",".join(host[0] for host in ds.database_pgaas_hosts)

    connect_string = "{} port={} dbname={} user={} password={}".format(
        hosts,
        ds.database_pgaas_port,
        ds.database_db,
        ds.database_user,
        ds.database_password,
    )

    connection = connect(connect_string)
    cursor = connection.cursor()
    query = Query(cursor)

    solomon_project = "dogma_testing"
    solomon_cluster = "default"
    is_b2b = False

    if ds.database_db == "dogma_test":
        solomon_service = "dogma_testing"
    elif ds.database_db == "dogma":
        solomon_service = "dogma"
    elif ds.database_db == "test_dogma_pg_db":
        solomon_service = "dogma_b2b_testing"
        is_b2b = True
    elif ds.database_db == "dogma_pg_db":
        solomon_service = "dogma_b2b"
        is_b2b = True
    else:
        # Raise explicitly: an ``assert`` is stripped under ``python -O``,
        # which would let execution continue with ``solomon_service`` unset.
        raise RuntimeError("No solomon env yet")

    solomon = Solomon(  # noqa: F811
        solomon_project,
        solomon_cluster,
        solomon_service,
        ds.solomon_token,
    )

    intervals = [1, 2, 4, 12, 24]  # hours

    for delay_hours in intervals:
        try:
            unfinished_clones(query, solomon, delay_hours)
        except Exception as e:
            # Leave the failed transaction so later queries can still run.
            connection.rollback()
            logging.error(e, exc_info=True)

    unflushed_repos(query, solomon)
    if not is_b2b:
        arcadia_delay(query, solomon)
    repos_count(query, solomon)
    failed_clones(query, solomon)
    sources_status(query, solomon)
    queues_count(solomon)
    nodes_free_space(query, solomon)
    repos_without_clones(query, solomon)

    if not is_b2b:
        try:
            yt_client = yt.YtClient(proxy="hahn", token=ds.yt_oauth_token)
            yt_push_delay(yt_client, solomon)
        except Exception as e:
            # Still best-effort (a YT failure must not stop the remaining
            # metrics), but record the failure instead of hiding it.
            logging.error(e, exc_info=True)

    try:
        pushed_commits(query, solomon, intervals)
    except Exception as e:
        connection.rollback()
        logging.error(e, exc_info=True)

    if debug:
        import pdb
        pdb.set_trace()
