"""Juggler client."""

import json
import logging
import re
from contextlib import contextmanager

import juggler_sdk as sdk
from requests import RequestException

import walle.clients.utils as client_utils
from sepelib.core import config
from sepelib.core.constants import DAY_SECONDS, MINUTE_SECONDS
from sepelib.core.exceptions import LogicalError
from walle.errors import RecoverableError
from walle.models import timestamp
from walle.util import cloud_tools
from walle.util.misc import ellipsis_string

# Length, in days, of the downtime used for blocked/deleted hostnames.
_BLOCKED_HOSTNAME_DOWNTIME_TIMEOUT_DAYS = 30
# Human-readable form of the same duration, used in log messages.
BLOCKED_HOSTNAME_DOWNTIME_OFFSET = "%s days" % _BLOCKED_HOSTNAME_DOWNTIME_TIMEOUT_DAYS
# Downtime durations, in seconds.
BLOCKED_HOSTNAME_DOWNTIME_TIMEOUT = DAY_SECONDS * _BLOCKED_HOSTNAME_DOWNTIME_TIMEOUT_DAYS
DEFAULT_DOWNTIME_TIMEOUT = DAY_SECONDS * 90
TRANSITION_DOWNTIME_TIMEOUT = MINUTE_SECONDS * 30
# Event descriptions longer than this are shortened before being sent (see exception_monitor).
JUGGLER_MSG_MAX_SIZE = 1023


log = logging.getLogger(__name__)


class JugglerCheckStatus:
    """Status values attached to juggler events (see send_event / exception_monitor)."""

    OK, INFO, WARN, CRIT = "OK", "INFO", "WARN", "CRIT"

    # Every status, ordered from least to most severe as listed above.
    ALL = [OK, INFO, WARN, CRIT]


class JugglerDowntimeParams:
    """Parameters describing one kind of downtime.

    A truthy ``timeout`` (a duration added to ``timestamp()``) takes precedence over the
    absolute ``end_time`` given at construction.
    """

    def __init__(self, source_suffix, log_message, timeout=None, end_time=None):
        self.source_suffix = source_suffix
        self.timeout = timeout
        self.log_message = log_message
        self._end_time = end_time

    @property
    def end_time(self):
        """``timestamp() + timeout`` when timeout is truthy, otherwise the stored end time."""
        return timestamp() + self.timeout if self.timeout else self._end_time


class JugglerDowntimeName:
    """Suffixes appended to the stand prefix to form downtime sources."""

    DEFAULT = "default"
    DELETED_HOST = "deleted_host"
    TRANSITION_STATE_HOST = "transition_state"

    # Every known suffix.
    ALL = (DEFAULT, DELETED_HOST, TRANSITION_STATE_HOST)
    # Suffixes searched by default in JugglerClient.get_fqdn_to_downtimes_map (and thus cleared by clear_downtimes).
    REMOVABLE = (DEFAULT, DELETED_HOST)


class JugglerDowntimeTypes:
    """Predefined JugglerDowntimeParams, one per JugglerDowntimeName suffix."""

    DEFAULT = JugglerDowntimeParams(
        source_suffix=JugglerDowntimeName.DEFAULT, timeout=DEFAULT_DOWNTIME_TIMEOUT, log_message=""
    )

    DELETED_HOST = JugglerDowntimeParams(
        source_suffix=JugglerDowntimeName.DELETED_HOST,
        timeout=BLOCKED_HOSTNAME_DOWNTIME_TIMEOUT,
        log_message=" for {}".format(BLOCKED_HOSTNAME_DOWNTIME_OFFSET),
    )

    TRANSITION_STATE_HOST = JugglerDowntimeParams(
        source_suffix=JugglerDowntimeName.TRANSITION_STATE_HOST,
        timeout=TRANSITION_DOWNTIME_TIMEOUT,
        # Render the duration with its unit (like DELETED_HOST's " for 30 days"),
        # not as a bare number of seconds such as " for 1800".
        log_message=" for {} minutes".format(TRANSITION_DOWNTIME_TIMEOUT // MINUTE_SECONDS),
    )

    ALL = (DEFAULT, DELETED_HOST, TRANSITION_STATE_HOST)

    # Lookup table: source suffix -> downtime params.
    ALL_BY_SUFFIX = {t.source_suffix: t for t in ALL}

    @classmethod
    def by_suffix_name(cls, suffix):
        """Return the params registered for ``suffix``.

        :raises LogicalError: if ``suffix`` is not one of the known suffixes.
        """
        try:
            return cls.ALL_BY_SUFFIX[suffix]
        except KeyError:
            raise LogicalError


class JugglerError(RecoverableError):
    """Base class for juggler client errors; ``log()`` writes the error text to the module logger."""

    def log(self):
        text = str(self)
        log.error(text)


class JugglerCommunicationError(JugglerError):
    """The juggler agent answered the events request with an HTTP error.

    Keeps the error text, the message extracted from the response and the events that were sent.
    """

    def __init__(self, error, message, events):
        self.error, self.message, self.events = error, message, events
        super().__init__("Error in communication with juggler agent: {}: {}", error, message)

    def log(self):
        events_dump = json.dumps(self.events)
        log.error(
            "Error in communication with juggler agent: %s: %s. Events data was %s",
            self.error,
            self.message,
            events_dump,
        )


class JugglerConnectionError(JugglerError):
    """Sending events to the juggler agent failed with a non-HTTP error.

    ``log()`` uses ``log.exception``, so it is meant to be called while that error is being handled.
    """

    def __init__(self, error, events):
        self.error, self.events = error, events
        super().__init__("Error in communication with juggler agent: {}", error)

    def log(self):
        events_dump = json.dumps(self.events)
        log.exception("Error in communication with juggler agent: %s. Events data was %s", self.error, events_dump)


class JugglerEventsLostError(JugglerError):
    """The agent accepted fewer events than were sent and gave no per-event results."""

    def __init__(self, events):
        self.events = events
        super().__init__("Events has been lost by juggler agent.")

    def log(self):
        events_dump = json.dumps(self.events)
        log.error("Events has been lost by juggler agent. Event data was %s", events_dump)


class JugglerEventError(JugglerError):
    """A single event was rejected by the juggler agent; ``event`` must have "host" and "service" keys."""

    def __init__(self, error, event):
        self.event = event
        self.error = error

        host, service = event["host"], event["service"]
        super().__init__("Event {}:{} was rejected by juggler agent: {}.", host, service, error)

    def log(self):
        event = self.event
        log.error(
            "Event %s:%s was rejected by juggler: %s. Event data was %s",
            event["host"],
            event["service"],
            self.error,
            json.dumps(event),
        )


def send_event(juggler_service_name, status, message, tags=None, host_name=None):
    """Build one event and send it through send_batch.

    The event has ``status``, ``message`` as its description, the configured
    ``juggler.event_tags`` followed by ``tags``, plus the key fields returned by
    ``cloud_tools.get_juggler_event_key``. Returns send_batch's list of errors.
    """
    event_tags = config.get_value("juggler.event_tags") + list(tags or [])
    event_key = cloud_tools.get_juggler_event_key(juggler_service_name, host_name)
    event = dict(status=status, description=message, tags=event_tags, **event_key)
    return send_batch([event])


def _http_error_message(error):
    """Return the "message" field of an HTTP error response, or a placeholder.

    Falls back to "No error message" when the response is missing, its body is not JSON,
    or the JSON is not an object, so that reporting an HTTP error never raises itself.
    """
    try:
        return error.response.json().get("message", "No error message")
    except (AttributeError, ValueError):
        # ValueError covers JSON decoding failures; AttributeError covers a missing
        # response or a JSON body that is not an object.
        return "No error message"


def send_batch(events):
    """Send ``events`` to the local juggler agent in one request and return a list of errors.

    The list is empty when sending is disabled (``juggler.events_enabled``) or when the agent
    accepted every event. Otherwise it holds JugglerError instances that have already been logged.
    Failures of the request itself (HTTP errors, connection problems) are returned this way
    rather than raised.
    Be sure to check that the batch is not too big to send (e.g. limit it to 100 events).
    """
    if not config.get_value("juggler.events_enabled", False):
        log.info("Not sending events to juggler because juggler events are disabled.")
        return []

    try:
        result = _local_agent_request({"events": events})
    except client_utils.HttpClientError as e:
        return _log_errors(JugglerCommunicationError(str(e), _http_error_message(e), events))
    except Exception as e:
        return _log_errors(JugglerConnectionError(str(e), events))

    if result["accepted_events"] == len(events):
        return []

    if not result.get("events"):
        return _log_errors(JugglerEventsLostError(events))

    # Per-event results are positional: the i-th result describes the i-th submitted event.
    errors = [
        JugglerEventError(event_result["error"], event)
        for event, event_result in zip(events, result["events"])
        if event_result["code"] != 200
    ]

    return _log_errors(errors if errors else JugglerError("No error description from Juggler."))


_NON_WORD_RUN = re.compile(r"\W+")


def normalize(string):
    """Collapse every run of non-word characters into a single dash.

    Word characters (letters, digits, underscores) are kept as they are, so
    ``"host.example.com"`` becomes ``"host-example-com"`` and ``"a--b"`` becomes ``"a-b"``.
    """
    return _NON_WORD_RUN.sub("-", string)


def _log_errors(errors):
    """Call ``log()`` on each error and return the errors.

    A single JugglerError is wrapped into a one-element list first; any other value is
    iterated and returned as given.

    :type errors: Union[JugglerError, Iterable[JugglerError]]
    """
    error_list = [errors] if isinstance(errors, JugglerError) else errors
    for err in error_list:
        err.log()
    return error_list


class JugglerClient:
    """Manage juggler downtimes created by this stand.

    Every downtime created here gets the source ``"<stand_uid()>.<suffix>"``, and lookups
    only match downtimes whose source was built the same way.
    """

    def __init__(self, client=None):
        # Fall back to an SDK client built from config when none is injected.
        self._client = _get_juggler_client() if client is None else client
        self._prefix = stand_uid()

    def _get_source(self, suffix):
        """Build the downtime source string ``"<prefix>.<suffix>"``."""
        return "{prefix}.{suffix}".format(prefix=self._prefix, suffix=suffix)

    def remove_downtimes(self, downtime_ids):
        """Remove downtimes; ``downtime_ids`` may be a single id string or a collection of ids."""
        ids = [downtime_ids] if isinstance(downtime_ids, str) else downtime_ids
        return self._client.remove_downtimes(ids)

    def is_downtimed(self, fqdn):
        """Return whether ``fqdn`` has a downtime with one of the default (removable) source suffixes."""
        downtimes_by_host = self.get_fqdn_to_downtimes_map(fqdn=fqdn)
        return fqdn in downtimes_by_host

    def set_downtime(self, host, description, end_time=None, suffix=JugglerDowntimeName.DEFAULT):
        """Create a downtime for ``host`` with source ``"<prefix>.<suffix>"`` and return its id."""
        response = self._client.set_downtimes(
            filters=[sdk.DowntimeSelector(host=host)],
            end_time=end_time,
            description=description,
            source=self._get_source(suffix),
        )
        return response.downtime_id

    def edit_downtime(self, host, description, end_time, downtime_id, source):
        """Update downtime ``downtime_id`` of ``host`` with an explicit ``source`` and return its id."""
        response = self._client.set_downtimes(
            filters=[sdk.DowntimeSelector(host=host)],
            end_time=end_time,
            downtime_id=downtime_id,
            description=description,
            source=source,
        )
        return response.downtime_id

    def clear_downtimes(self, fqdn):
        """Remove the removable-suffix downtimes of ``fqdn``; return their ids, or None if there were none."""
        ids = self.get_fqdn_to_downtimes_map(fqdn=fqdn).get(fqdn)
        if ids is None:
            return None
        self.remove_downtimes(ids)
        return ids

    def get_fqdn_to_downtimes_map(self, fqdn=None, suffixes=JugglerDowntimeName.REMOVABLE, only_ids=True):
        """Map host name -> list of downtimes whose source matches one of ``suffixes``.

        :param fqdn: passed as ``host`` to every search filter.
        :param suffixes: source suffixes to search for.
        :param only_ids: store downtime ids instead of downtime objects.
        """
        by_host = {}
        search_filters = [sdk.DowntimeSearchFilter(source=self._get_source(sfx), host=fqdn) for sfx in suffixes]
        page_size = 100
        page = 1

        while True:
            response = self._client.get_downtimes(filters=search_filters, page=page, page_size=page_size)
            for downtime in response.items:
                # A downtime is recorded under every host named by its filters.
                for selector in downtime.filters:
                    by_host.setdefault(selector.host, []).append(downtime.downtime_id if only_ids else downtime)
            # A short (or empty) page means there is nothing left to fetch.
            if len(response.items) < page_size:
                return by_host
            page += 1


def _get_juggler_client():
    """Build a ``juggler_sdk.JugglerApi`` from the ``juggler.client_kwargs`` config section.

    When no ``mark`` is configured it defaults to ``stand_uid()``, the same value that
    JugglerClient uses as the prefix of downtime sources.
    """
    # Work on a copy: config.get_value may hand back the live config mapping, and the
    # default mark must not be written back into shared configuration.
    api_kwargs = dict(config.get_value("juggler.client_kwargs") or {})
    if not api_kwargs.get("mark"):
        api_kwargs["mark"] = stand_uid()

    return sdk.JugglerApi(**api_kwargs)


def stand_uid():
    """Return the ``juggler.source`` config value.

    It is used as the downtime-source prefix and as the default API ``mark``.
    """
    source = config.get_value("juggler.source")
    return source


@client_utils.retry(exceptions=[RequestException], skip=[client_utils.HttpClientError])
def _local_agent_request(data):
    """POST ``data`` to the ``/events`` endpoint of the juggler agent at ``juggler.agent_host_port``.

    Uses ``client_utils.json_request`` with status checking and a 5 second timeout. The
    ``client_utils.retry`` decorator is configured for RequestException, skipping HttpClientError.
    """
    agent_host_port = config.get_value("juggler.agent_host_port")
    events_url = "http://{}/events".format(agent_host_port)
    return client_utils.json_request("juggler-agent", "POST", events_url, data=data, check_status=True, timeout=5)


@contextmanager
def exception_monitor(event_name, err_msg_tmpl=None, exc_classes=(Exception,), log=None, on_exc=None, reraise=False):
    """Report the outcome of the wrapped block to juggler as a single event.

    If the block finishes without raising, ``send_event(event_name, "OK", "OK")`` is sent.
    If it raises one of ``exc_classes``:
      * ``on_exc(exc)`` is called, if given;
      * ``err_msg_tmpl`` (default ``"{exc}"``) is formatted with ``exc=<the exception>``;
      * the full message is passed to ``log.error``, if ``log`` is given;
      * the message is shortened with ``ellipsis_string`` to ``JUGGLER_MSG_MAX_SIZE`` and sent
        with status CRIT.
    The caught exception is suppressed unless ``reraise`` is true.

    :param event_name: juggler service name passed to ``send_event``.
    :param err_msg_tmpl: ``str.format`` template for the event description; ``{exc}`` is the exception.
    :param exc_classes: iterable of exception types treated as failures.
    :param log: optional object with an ``error`` method. This parameter shadows the
        module-level ``log`` inside this function.
    :param on_exc: optional callback invoked with the caught exception.
    :param reraise: re-raise the caught exception after it has been reported.

    NOTE(review): the event is sent from ``finally`` with whatever status/message were set by then.
    An exception that is not in ``exc_classes`` propagates but is still reported as OK. If
    ``on_exc``, the template formatting or ``log.error`` raises, the CRIT event carries either the
    default "OK" message or an untruncated one. Confirm this is intended.
    """
    err_msg_tmpl = err_msg_tmpl or "{exc}"
    # `except` accepts a class or a tuple of classes, not an arbitrary iterable.
    exc_classes = tuple(exc_classes)

    # Reported as-is when the block completes without a matching exception.
    status = JugglerCheckStatus.OK
    msg = "OK"

    try:
        yield
    except exc_classes as e:
        status = JugglerCheckStatus.CRIT

        if on_exc is not None:
            on_exc(e)

        msg = err_msg_tmpl.format(exc=e)
        if log is not None:
            log.error(msg)

        # Only the event text is truncated; the logged message above stays complete.
        msg = ellipsis_string(msg, JUGGLER_MSG_MAX_SIZE)

        if reraise:
            raise
    finally:
        # Runs on every exit path, including re-raised and non-matching exceptions.
        send_event(event_name, status, msg)
