import logging
import requests
import os
import sys
from datetime import datetime
from requests.packages.urllib3.util.retry import Retry
from yp.client import find_token
requests.packages.urllib3.disable_warnings()
from yp.client import YpClient


def setup_custom_logger(name):
    """Return the logger called *name*, configured to print DEBUG+ records to stdout.

    Handlers already attached to that logger are dropped before the stdout
    handler is added, so calling this repeatedly for the same name leaves
    exactly one handler.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    stdout_handler.setFormatter(
        logging.Formatter(
            fmt='%(asctime)s %(levelname)-8s %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
        )
    )

    # Replace (not extend) the handler list so repeated setup does not duplicate output.
    logger.handlers = []
    logger.addHandler(stdout_handler)
    return logger


# Shared module-level logger writing DEBUG+ records to stdout. setup_custom_logger()
# clears existing handlers first, so re-running this line does not duplicate output.
LOGGER = setup_custom_logger('binary')


def get_current_cluster():
    """Return the list of YP cluster names to monitor.

    If the YP_MONITORING_CLUSTERS environment variable holds a non-empty
    comma-separated list, its entries are returned with surrounding whitespace
    stripped and empty entries dropped (e.g. "sas, man," -> ["sas", "man"]).
    Otherwise the cluster is taken from this host's name: its second
    dot-separated label (e.g. "host.sas.example" -> ["sas"]).

    Raises:
        ValueError: if the hostname has no second label to derive a cluster from.
    """
    raw_clusters = os.environ.get("YP_MONITORING_CLUSTERS", "")
    clusters = [name.strip() for name in raw_clusters.split(",")]
    clusters = [name for name in clusters if name]
    if clusters:
        return clusters

    hostname = os.uname()[1]
    # Parenthesized single-argument print behaves the same on Python 2 and 3.
    print("Host={}".format(hostname))
    labels = hostname.split(".")
    if len(labels) < 2 or not labels[1]:
        raise ValueError(
            "Cannot derive YP cluster from hostname {!r}; "
            "set YP_MONITORING_CLUSTERS".format(hostname))
    return [labels[1]]


def create_session(oauth=None):
    """Build a requests.Session with retries and optional OAuth authorization.

    Both http:// and https:// are mounted with their own HTTPAdapter sharing a
    Retry policy of up to 10 attempts with a 0.3 backoff factor.

    Args:
        oauth: OAuth token; when given, sent as "Authorization: OAuth <token>".

    Returns:
        The configured requests.Session (the caller owns it and should close it).
    """
    retry_policy = Retry(total=10, backoff_factor=0.3)

    session = requests.Session()
    # NOTE(review): TLS certificate verification is deliberately disabled
    # (warnings are silenced at import time) — confirm this is still required.
    session.verify = False

    for scheme_prefix in ('http://', 'https://'):
        session.mount(scheme_prefix, requests.adapters.HTTPAdapter(max_retries=retry_policy))

    if oauth is not None:
        session.headers['Authorization'] = 'OAuth {}'.format(oauth)
    return session


def send_solomon_metrics(cluster, elapsed_milliseconds, solomon_token):
    """Push the measured SD response time for *cluster* to Solomon.

    The value is sent as the "Elapsed time, ms" sensor of project "yp",
    service "sd_response_monitoring", cluster "yp-<cluster>".

    Args:
        cluster: cluster name (prefixed with "yp-" for the Solomon cluster label).
        elapsed_milliseconds: measured query time in milliseconds.
        solomon_token: OAuth token for the Solomon push API.

    Raises:
        requests.HTTPError: if Solomon replies with an error status
            (via raise_for_status); connection errors propagate as well.
    """
    solomon_packet = {
        "sensors": [
            {
                "labels": {"sensor": "Elapsed time, ms"},
                "value": elapsed_milliseconds
            },
        ],
    }

    params = dict(
        project="yp",
        service="sd_response_monitoring",
        cluster="yp-{}".format(cluster),
    )

    session = create_session(solomon_token)
    try:
        result = session.post('https://api.solomon.search.yandex.net/api/v2/push', params=params, json=solomon_packet, timeout=60)
        result.raise_for_status()
    finally:
        # Release pooled connections even when the push fails.
        session.close()


def notify_results(cluster, status, description):
    """Push a check result for *cluster* to Juggler as a single event.

    The event is reported for host "yp-<cluster>.yandex.net" and service
    "yp_sd_response_speed_monitoring".

    Args:
        cluster: cluster name used to build the event host.
        status: Juggler status string (e.g. "OK" / "CRIT" as produced by Reason.status()).
        description: human-readable event description.

    Raises:
        requests.HTTPError: if Juggler replies with an error status
            (via raise_for_status); connection errors propagate as well.
    """
    LOGGER.info("Notifying cluster {} status={} description={}".format(cluster, status, description))

    notify_data = {
        "source": "yp_sd_response_speed_monitoring",
        "events": [
            {
                "description": description,
                "host": "yp-{}.yandex.net".format(cluster),
                "instance": "",
                "service": "yp_sd_response_speed_monitoring",
                "status": status
            }
        ]
    }

    session = create_session()
    try:
        result = session.post("http://juggler-push.search.yandex.net/events", json=notify_data, timeout=10)
        result.raise_for_status()
    finally:
        # Release pooled connections even when the push fails.
        session.close()
    LOGGER.info("Notify sent {}, data={}".format(result.ok, notify_data))


class Reason(object):
    """Outcome of a single SD check.

    Attributes:
        reason: reason code; the factory methods use "OK", "Timeout",
            "No results" or "UnknownError".
        description: human-readable text describing the outcome.
        elapsed_time: elapsed time in milliseconds as passed by the caller,
            or None when it was not measured.
    """

    _OK = "OK"
    _TIMEOUT = "Timeout"
    _NO_RESULTS = "No results"
    _UNKNOWN = "UnknownError"

    def __init__(self, reason, description, elapsed_time=None):
        self.reason = reason
        self.description = description
        self.elapsed_time = elapsed_time

    # --- factories -------------------------------------------------------

    @staticmethod
    def ok(description, elapsed_time):
        return Reason(Reason._OK, description, elapsed_time=elapsed_time)

    @staticmethod
    def timeout_error(description, elapsed_time):
        return Reason(Reason._TIMEOUT, description, elapsed_time=elapsed_time)

    @staticmethod
    def no_results_error(description, elapsed_time):
        return Reason(Reason._NO_RESULTS, description, elapsed_time=elapsed_time)

    @staticmethod
    def unknown_error(description, elapsed_time):
        return Reason(Reason._UNKNOWN, description, elapsed_time=elapsed_time)

    # --- predicates ------------------------------------------------------

    def is_ok(self):
        return self.reason == Reason._OK

    def is_timeout_error(self):
        return self.reason == Reason._TIMEOUT

    def is_no_results_error(self):
        return self.reason == Reason._NO_RESULTS

    def is_unknown_error(self):
        return self.reason == Reason._UNKNOWN

    # --- reporting -------------------------------------------------------

    def status(self):
        """Juggler status: "OK" for a passing check, "CRIT" for anything else."""
        if self.is_ok():
            return "OK"
        return "CRIT"

    def get_elapsed_milliseconds(self):
        return self.elapsed_time

    def __repr__(self):
        return "status={}, description={}".format(self.status(), self.description)


def get_podset_pods_count(client, podset_id, cluster):
    """Count the pods that belong to a pod set.

    Args:
        client: YP client used to run the select_objects query.
        podset_id: id of the pod set whose pods are counted.
        cluster: cluster name, used only in the error description.

    Returns:
        A (pods_count, Reason) tuple:
        - (N, OK) with an empty description when N > 0 pods were found;
        - (0, "No results") when the pod set has no pods;
        - (0, UnknownError) when the query raised (traceback is logged).
    """
    try:
        pods = client.select_objects(
            "pod",
            selectors=["/meta/id"],
            filter="[/meta/pod_set_id]='{}'".format(podset_id))
    except Exception as ex:
        # Any failure of the YP query is reported as an UnknownError (with the
        # traceback logged) instead of aborting the whole monitoring run.
        LOGGER.exception(ex)
        return 0, Reason.unknown_error(
            "Unknown error while receiving pods count for {} pod_set={}, see traces".format(
                cluster, podset_id
            ),
            None
        )

    if not pods:
        return 0, Reason.no_results_error("No pods found in podset {}".format(podset_id), None)

    pods_count = len(pods)
    LOGGER.info("Pods found in podset {} count {}".format(podset_id, pods_count))
    return pods_count, Reason.ok("", None)


def check_endpoint_set(client, endpoint_set_id, pods_count, cluster):
    """Query an endpoint set and judge whether service discovery answered well.

    Args:
        client: YP client used to run the select_objects query.
        endpoint_set_id: id of the endpoint set to resolve.
        pods_count: number of pods in the backing pod set.
        cluster: cluster name, used only in the error description.

    Returns:
        A Reason, checked in this order:
        - UnknownError (elapsed time None) if the query raised;
        - Timeout if the query took longer than 15 000 ms;
        - "No results" if no endpoints were returned;
        - OK if pods_count < 10, or endpoints exceed 60% of pods_count;
        - UnknownError otherwise.
        Every Reason except the first carries the elapsed time in milliseconds.
    """
    query_time_limit_milliseconds = 15 * 1000
    start_time = datetime.now()

    try:
        endpoints = client.select_objects(
            "endpoint",
            selectors=["/meta/id"],
            filter="[/meta/endpoint_set_id]='{}'".format(endpoint_set_id))
    except Exception as ex:
        # Any failure of the YP query is reported as an UnknownError (with the
        # traceback logged) instead of aborting the whole monitoring run.
        LOGGER.exception(ex)
        return Reason.unknown_error(
            "Unknown error while resolving eps={} for {}, see traces".format(
                endpoint_set_id, cluster), None)

    # total_seconds() covers the full duration; timedelta.microseconds is only
    # the sub-second component and would wrap every second.
    elapsed_milliseconds = int((datetime.now() - start_time).total_seconds() * 1000)

    if elapsed_milliseconds > query_time_limit_milliseconds:
        return Reason.timeout_error("Too long SD query: {}ms".format(elapsed_milliseconds),
                                    elapsed_milliseconds)
    if not endpoints:
        return Reason.no_results_error("Empty SD response",
                                       elapsed_milliseconds)

    endpoints_count = len(endpoints)
    LOGGER.debug("endpoints_count={}".format(endpoints_count))

    # Endpoints are known to be non-empty here. Small pod sets only need that;
    # larger ones need more than 60% of pods resolved. The integer form avoids
    # the int/float division difference between Python 2 and 3.
    enough_endpoints = endpoints_count * 100 > pods_count * 60
    if pods_count < 10 or enough_endpoints:
        return Reason.ok("Total {} milliseconds for endpoint_set={}".format(elapsed_milliseconds, endpoint_set_id),
                         elapsed_milliseconds)
    return Reason.unknown_error("Total {} milliseconds, pods={}, endpoints={}"
                                .format(elapsed_milliseconds, pods_count, endpoints_count),
                                elapsed_milliseconds)


# Each check pairs an endpoint set with the pod set backing it; both ids are
# formatted with the cluster name. Checks are tried in order until one passes.
_SD_CHECKS = (
    {
        "endpoint_set_template": "yp-rtc-sla-tentacles-production-{}",
        "pod_set_template": "yp-rtc-sla-tentacles-production-{}"
    },
    {
        "endpoint_set_template": "infra-monitoring-{}",
        "pod_set_template": "infra-monitoring-{}"
    },
)


def _create_yp_client(cluster):
    """Create a YpClient for *cluster* authenticated with find_token().

    Retry settings (enabled, count 3, "constant_time" backoff of 1000) are passed
    through to the YP client config; units are as the YP client defines them
    (presumably milliseconds — confirm).
    """
    return YpClient(address=cluster,
                    config={
                        "token": find_token(),
                        "retries": {
                            "enable": True,
                            "count": 3,
                            "backoff": {
                                "policy": "constant_time",
                                "constant_time": 1000
                            }
                        }
                    })


def _run_checks(client, cluster, checks):
    """Run *checks* in order for *cluster*, stopping at the first passing one.

    Returns the Reason of every check that was run, in order. If a check passed,
    its Reason is the last element.
    """
    reasons = []
    for check in checks:
        pod_set_id = check["pod_set_template"].format(cluster)
        endpoint_set_id = check["endpoint_set_template"].format(cluster)

        pods_count, reason = get_podset_pods_count(client, pod_set_id, cluster)
        if reason.is_ok():
            # Only resolve endpoints when the pod count is known.
            reason = check_endpoint_set(client, endpoint_set_id, pods_count, cluster)
            LOGGER.debug("Checking cluster {} eps={} reason={}".format(cluster, endpoint_set_id, reason))

        reasons.append(reason)
        if reason.is_ok():
            break
    return reasons


def _report_reasons(cluster, reasons):
    """Report check results for *cluster* to Solomon and Juggler.

    On success, the first OK reason is reported: its elapsed time goes to Solomon
    (if known) and its status to Juggler. Otherwise each failing reason is sent
    to Juggler.
    """
    ok_reasons = [reason for reason in reasons if reason.is_ok()]
    if ok_reasons:
        reason = ok_reasons[0]
        elapsed_milliseconds = reason.get_elapsed_milliseconds()
        if elapsed_milliseconds is not None:
            send_solomon_metrics(cluster, elapsed_milliseconds, os.environ["SOLOMON_TOKEN"])
        notify_results(cluster, reason.status(), reason.description)
        return

    # No check passed, so every collected reason is a failure.
    for reason in reasons:
        notify_results(cluster, reason.status(), reason.description)


def main():
    """Run the SD response checks for every configured cluster and report them.

    A failure while processing one cluster (client creation, Solomon or Juggler
    push) is logged and does not stop the remaining clusters from being
    processed. If any cluster failed, RuntimeError is raised at the end so the
    script still exits with an error.
    """
    failed_clusters = []
    for cluster in get_current_cluster():
        LOGGER.info("Processing {}".format(cluster))
        try:
            client = _create_yp_client(cluster)
            reasons = _run_checks(client, cluster, _SD_CHECKS)
            _report_reasons(cluster, reasons)
        except Exception:
            LOGGER.exception("Failed to process cluster {}".format(cluster))
            failed_clusters.append(cluster)

    if failed_clusters:
        raise RuntimeError("Failed to process clusters: {}".format(", ".join(failed_clusters)))


if __name__ == '__main__':
    # Script entry point: run the checks for every configured cluster.
    main()
