"""CMS client."""

import logging
import re
from functools import partial

import requests
import simplejson as json
from requests import RequestException
from six import wraps

import walle.clients.utils as utils
from object_validator import String, List
from sepelib.core import config
from sepelib.core.constants import MINUTE_SECONDS, HOUR_SECONDS, WEEK_SECONDS
from sepelib.core.exceptions import LogicalError
from walle.clients import juggler, tvm
from walle.errors import RecoverableError
from walle.statbox.loggers import cms_logger
from walle.stats import stats_manager
from walle.util import notifications
from walle.util.misc import drop_none, StopWatch, filter_dict_keys
from walle.util.validation import ApiDictScheme

log = logging.getLogger(__name__)

# Default timeout (seconds) for CMS API requests. NOTE: this seems too big, see WALLESUPPORT-1096.
CMS_API_REQUEST_TIMEOUT = 60
# Timeout (seconds) for dry-run task submissions: no need to wait for the full default timeout for these.
CMS_API_DRY_RUN_TIMEOUT = 1


class CmsError(RecoverableError):
    """Base class for every error raised by the CMS client; all of them are recoverable."""


class CmsConnectionError(CmsError):
    """A request to CMS could not be performed, or a successful CMS response could not be parsed."""


class CmsTaskRejectedError(CmsError):
    """CMS answered a task submission with the "rejected" status."""


class CmsApiError(CmsError):
    """CMS answered with an HTTP status the client does not accept.

    Besides the usual error arguments (forwarded as-is to the base class), the HTTP status code and the error
    text are kept as attributes so that callers can inspect them.
    """

    def __init__(self, status_code, message, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Error text reported by CMS (or a generic description of the response).
        self.message = message
        # HTTP status code of the CMS response.
        self.status_code = status_code


class CmsApiVersion:
    """Versions of the CMS API protocol the client knows how to speak."""

    V1_0 = "v1.0"
    V1_1 = "v1.1"
    V1_2 = "v1.2"
    V1_3 = "v1.3"
    V1_4 = "v1.4"

    # Ordered from the oldest to the newest version. The UI lists them the other way round,
    # so the last entry here is shown first in the UI and is its default value.
    ALL_CMS_API = [V1_0, V1_1, V1_2, V1_3, V1_4]
    # Alias kept for callers that use the generic name (same list object).
    ALL = ALL_CMS_API


class CmsTaskType:
    """Task types understood by CMS."""

    MANUAL = "manual"
    AUTOMATED = "automated"

    # Every known task type, in declaration order.
    ALL = [MANUAL, AUTOMATED]


class CmsTaskAction:
    """Actions a CMS task can request, the subsets each protocol version supports and YP maintenance parameters."""

    # Host lifecycle actions (introduced in v1.1).
    PREPARE = "prepare"
    DEACTIVATE = "deactivate"

    # Actions that make the host unavailable for a while.
    POWER_OFF = "power-off"
    REBOOT = "reboot"
    PROFILE = "profile"
    TEMPORARY_UNREACHABLE = "temporary-unreachable"
    REDEPLOY = "redeploy"

    # Hardware repair actions.
    REPAIR_LINK = "repair-link"
    CHANGE_DISK = "change-disk"

    # Actions supported by each protocol version; every version extends (or equals) the previous one.
    V1_0_ACTIONS = [POWER_OFF, REBOOT, PROFILE, REDEPLOY, REPAIR_LINK, CHANGE_DISK]
    V1_1_ACTIONS = [PREPARE, DEACTIVATE] + V1_0_ACTIONS
    V1_2_ACTIONS = V1_1_ACTIONS
    V1_3_ACTIONS = V1_2_ACTIONS
    V1_4_ACTIONS = V1_3_ACTIONS + [TEMPORARY_UNREACHABLE]

    # NB: ALL != VX_Y_ACTIONS (different order than any per-version list).
    ALL = [
        PREPARE,
        DEACTIVATE,
        POWER_OFF,
        REBOOT,
        PROFILE,
        TEMPORARY_UNREACHABLE,
        REDEPLOY,
        REPAIR_LINK,
        CHANGE_DISK,
    ]

    # YP maintenance parameters per action: whether it is destructive and its expected duration in seconds.
    # Enough for a prototype; more parameters may be needed later.
    TASK_MAINTENANCE_MAPPING = {
        POWER_OFF: {"destructive": False, "period": WEEK_SECONDS},
        REBOOT: {"destructive": False, "period": 20 * MINUTE_SECONDS},
        REDEPLOY: {"destructive": True, "period": 40 * MINUTE_SECONDS},
        PROFILE: {"destructive": False, "period": 3 * HOUR_SECONDS},
        TEMPORARY_UNREACHABLE: {"destructive": False, "period": HOUR_SECONDS},
        REPAIR_LINK: {"destructive": False, "period": 0},  # no downtime is expected
        CHANGE_DISK: {"destructive": True, "period": 2 * HOUR_SECONDS},
        PREPARE: {"destructive": True, "period": 0},
        DEACTIVATE: {"destructive": True, "period": 0},
    }


class CmsTaskStatus:
    """Task statuses CMS may report back."""

    REJECTED = "rejected"
    IN_PROCESS = "in-process"
    OK = "ok"

    # Every status accepted by the task validation scheme.
    ALL = [REJECTED, IN_PROCESS, OK]


class MaintenancePriority:
    """Priorities of YP maintenance requests."""

    NONE = "none"
    NORMAL = "normal"
    HIGH = "high"

    # NOTE(review): NONE is not part of ALL — presumably intentional (not a selectable priority); confirm.
    ALL = [NORMAL, HIGH]


# Validation scheme for a task object returned by CMS.
_CMS_TASK_SCHEME = ApiDictScheme(
    dict(
        id=String(min_length=1),
        type=String(choices=CmsTaskType.ALL, optional=True),
        issuer=String(min_length=1, optional=True),
        # TODO: Should be a list. Sometimes we need to process a sequence of actions: for example, change disk and
        # then redeploy.
        action=String(choices=CmsTaskAction.ALL, optional=True),
        hosts=List(String(min_length=1)),
        status=String(choices=CmsTaskStatus.ALL),
        message=String(optional=True),
    )
)


def _normalize_cms_name(prefix, cms_url, cms_action):
    """Build metric/event names for a CMS.

    The URL scheme is removed, every run of ``/ \\ : # % & ?`` characters is collapsed into a single dash, and
    two names are produced, each truncated to 128 characters:
    ``<prefix>.<name>`` (common for the CMS) and ``<prefix>.<cms_action>.<name>`` (per action).

    :returns: tuple ``(common_name, action_name)``.
    """
    name_limit = 128

    bare_name = cms_url
    for scheme in ("https://", "http://"):
        bare_name = bare_name.replace(scheme, "")
    bare_name = re.sub(r"[/\\:#%&?]+", "-", bare_name)

    common_name = ".".join((prefix, bare_name))
    action_name = ".".join((prefix, cms_action, bare_name))
    return common_name[:name_limit], action_name[:name_limit]


def _classify_cms_failure(exc):
    """Return the yasm result label ("connection_error", "timeout" or "error") for a failed CMS call.

    The client wraps `requests` exceptions into `CmsError` subclasses, so the original transport error is only
    visible through the exception chain (``__cause__``/``__context__``); walk it to find connection errors and
    timeouts. Connection errors are checked first, like in the original branch order.
    """
    seen = set()
    current = exc
    while current is not None and id(current) not in seen:
        if isinstance(current, requests.ConnectionError):
            return "connection_error"
        if isinstance(current, requests.Timeout):
            return "timeout"
        seen.add(id(current))  # guard against (manually built) cyclic chains
        current = current.__cause__ or current.__context__
    return "error"


def _yasm_stats(label):
    """Return a decorator that records timing and count signals for every call of a CMS client method.

    Signals are recorded both per CMS and per CMS + method (see `_normalize_cms_name`), labelled by the call
    result: "success" (also used when CMS rejects a task, which is a valid CMS answer), "connection_error",
    "timeout" or "error". Exceptions are always re-raised.

    :param label: signal name prefix (e.g. "cms").
    """

    def _decorator(method):
        @wraps(method)
        def _decorated_method(self, *args, **kwargs):
            # All auto-created "default.*" CMSes are aggregated under a single name.
            cms_name = "default" if self.name.startswith("default.") else self.name
            cms_signal_name_common, cms_signal_name_action = _normalize_cms_name(label, cms_name, method.__name__)

            result = "success"
            stopwatch = StopWatch()
            try:
                return method(self, *args, **kwargs)

            except CmsTaskRejectedError:
                raise
            except Exception as e:
                # A bare `except requests.ConnectionError` never matched errors coming from the client's
                # HTTP layer, because those are re-raised as CmsConnectionError; look through the chain instead.
                result = _classify_cms_failure(e)
                raise

            finally:
                # Read the stopwatch once so that both signals get the same duration.
                elapsed = stopwatch.get()
                stats_manager.add_sample((cms_signal_name_common, result, "time"), elapsed)
                stats_manager.add_sample((cms_signal_name_action, result, "time"), elapsed)
                stats_manager.increment_counter((cms_signal_name_common, result, "count"))
                stats_manager.increment_counter((cms_signal_name_action, result, "count"))

        return _decorated_method

    return _decorator


def _send_to_juggler(label):
    """Return a decorator that reports the outcome of every call of a CMS client method to Juggler.

    Two "wall-e-cms-request" events are sent per call: one for a host name common to the CMS and one specific
    to the called method (see `_normalize_cms_name`). The status is OK when the call returns or when CMS rejects
    the task (a rejection is a valid CMS answer) and CRIT when any other exception escapes. The exception is
    always re-raised.

    :param label: prefix for Juggler host names and tags (e.g. "cms").
    """

    def _decorator(method):
        @wraps(method)
        def _decorated_method(self, *args, **kwargs):
            cms_action = method.__name__
            cms_version = self.api_version
            cms_project = getattr(self, "project_id", None)

            event_host_name_common, event_host_name_action = _normalize_cms_name(
                "wall-e.{}".format(label), self.name, cms_action
            )

            tags = [
                event_host_name_common,
                "wall-e.{}.version.{}".format(label, cms_version),
                "wall-e.{}.{}".format(label, cms_action),
            ]
            if self.name.startswith("default."):
                tags.append("wall-e.default-cms")
            if cms_project is not None:
                tags.extend(
                    [
                        "wall-e.{}.{}".format(label, cms_project),
                        "wall-e.project.{}".format(cms_project),
                    ]
                )

            status = juggler.JugglerCheckStatus.OK
            message = "Request to {} to CMS {} was successful.".format(cms_action, self.name)

            send = partial(juggler.send_event, "wall-e-cms-request", tags=tags)

            try:
                return method(self, *args, **kwargs)

            except CmsTaskRejectedError:
                raise

            except Exception as e:
                # Format in a single pass: pre-formatting the CMS name into a template and formatting it again
                # would treat any braces in the name as placeholders, fail inside this handler, mask the real
                # error and leave the event status OK.
                message = "Request to {} to CMS {} failed: {}.".format(cms_action, self.name, str(e))
                status = juggler.JugglerCheckStatus.CRIT
                raise

            finally:
                log.info("Sending events for %s %s to juggler...", event_host_name_common, event_host_name_action)
                send(status, host_name=event_host_name_common, message=message)
                send(status, host_name=event_host_name_action, message=message)
                log.info("Sent events for %s %s to juggler", event_host_name_common, event_host_name_action)

        return _decorated_method

    return _decorator


class CmsNamespace:
    """Helpers for prefixing CMS task ids with the namespace configured under "cms.namespace"."""

    @classmethod
    def namespace(cls):
        """Return the configured namespace."""
        return config.get_value("cms.namespace")

    @classmethod
    def add_namespace(cls, task_id):
        """Return ``task_id`` prefixed with ``<namespace>-``."""
        parts = (cls.namespace(), task_id)
        return "-".join(parts)

    @classmethod
    def namespace_filter(cls, field_name="id"):
        """Return a predicate telling whether ``task[field_name]`` starts with ``<namespace>-``.

        The namespace is read once, when the predicate is created.
        """
        prefix = cls.namespace() + "-"

        def _in_namespace(task):
            return task[field_name].startswith(prefix)

        return _in_namespace


class _CmsClientMeta(type):
    """Metaclass keeping a registry of CMS client classes keyed by their ``api_version``.

    Every class created with this metaclass is registered automatically; classes whose ``api_version`` is None
    are skipped. A later class with the same version replaces the earlier one.
    """

    _VERSIONS_MAPPING = {}

    def __new__(mcs, *args, **kwargs):
        client_class = super().__new__(mcs, *args, **kwargs)
        mcs.register(client_class)
        return client_class

    @classmethod
    def register(mcs, cls):
        version = cls.api_version
        if version is None:
            return
        mcs._VERSIONS_MAPPING[version] = cls

    @classmethod
    def get_client_class(mcs, version):
        """Return the client class registered for ``version``; raises KeyError for unknown versions."""
        return mcs._VERSIONS_MAPPING[version]


class _BaseCmsClient(metaclass=_CmsClientMeta):
    """'Base CMS API client' implements all underlying logic and is basically a client for the latest
    version of CMS api protocol. The purpose is to have single common interface for clients for any CMS API version.

    Version-specific clients should encapsulate version-specific logic and expose
    same common interface so that stages don't have to check project's CMS API version.
    Most often this means that version-specific clients just silently discard some extra
    fields of data and provide fake values for unsupported requests.

    This class exposes a reference implementation of the said interface. It has no ``api_version`` (so it is not
    registered in `_CmsClientMeta`) and no ``_supported_actions``, so `add_task` only works on the
    version-specific subclasses that set them.
    """

    # CMS API protocol version this client speaks; None keeps the class out of the `_CmsClientMeta` registry.
    api_version = None
    # Extra task fields (keyword arguments of `add_task`) accepted by this protocol version;
    # None means "forward all of them".
    _supported_kwargs = None
    # Actions accepted by this protocol version (see `_downscale_action`); set by version-specific subclasses.
    _supported_actions = None

    def __init__(self, name, project_id, url, url_opener=None, query_params=None, context_logger=None, use_tvm=False):
        """Create a client for a single CMS.

        :param name: CMS name used in logs, yasm signals and Juggler events.
        :param project_id: Wall-E project id; used for the TVM alias lookup, error notifications and Juggler tags.
        :param url: CMS API base URL; a trailing slash is stripped.
        :param url_opener: transport with an ``open(method, url, params, headers, data, timeout)`` method;
            defaults to `_UrlOpener`.
        :param query_params: query parameters added to every request.
        :param context_logger: logger to derive the per-CMS child logger from; defaults to ``cms_logger()``.
        :param use_tvm: authenticate with a TVM service ticket instead of the OAuth token from the config.
        """
        self.name = name
        self.project_id = project_id
        self.use_tvm = use_tvm
        self.__base_url = url.rstrip("/")
        self.__query_params = {} if query_params is None else query_params
        self.__url_opener = _UrlOpener() if url_opener is None else url_opener
        self.context_logger = (context_logger or cms_logger()).get_child(cms_name=name)

    @_send_to_juggler("cms")
    @_yasm_stats("cms")
    def add_task(self, id, type, issuer, action, hosts, dry_run=False, **task_kwargs):
        """Submit a task for ``hosts`` to CMS and return the task as CMS reports it.

        Extra task fields not supported by this protocol version are dropped and the action is downscaled to one
        the version supports. Every host is written to the CMS log before the request and, together with the CMS
        answer, after it.

        :param id: task id.
        :param type: task type (see `CmsTaskType`).
        :param issuer: task issuer.
        :param action: requested action (see `CmsTaskAction`).
        :param hosts: names of the hosts the task is about.
        :param dry_run: send the request with ``dry_run=true`` and the short `CMS_API_DRY_RUN_TIMEOUT`.
        :param task_kwargs: extra task fields, filtered by ``_supported_kwargs``.
        :returns: the task returned by CMS, validated against `_CMS_TASK_SCHEME`.
        :raises CmsTaskRejectedError: CMS answered with the "rejected" status.
        :raises CmsError: the request failed or CMS answered with an error.
        """
        if self._supported_kwargs is not None:
            task_kwargs = filter_dict_keys(task_kwargs, self._supported_kwargs)

        downscaled_action = self._downscale_action(action)
        # Named `task_logger` so that it does not shadow the module-level `cms_logger` factory.
        task_logger = self.context_logger.get_child(
            cms_task_id=id,
            type=type,
            issuer=issuer,
            cms_name=self.name,
            cms_action=downscaled_action,
            walle_action="cms_add_task",
            dry_run=dry_run,
        )
        for host in hosts:
            task_logger.log(hostname=host)

        try:
            result = self.__post(
                "/tasks",
                drop_none(dict(id=id, type=type, issuer=issuer, action=downscaled_action, hosts=hosts, **task_kwargs)),
                params={"dry_run": "true" if dry_run else None},
                scheme=_CMS_TASK_SCHEME,
                timeout=CMS_API_DRY_RUN_TIMEOUT if dry_run else None,
            )
        except CmsError as e:
            task_logger.log(exception=str(e))
            raise

        message = result.get("message")
        response_status = result["status"]
        response_message = message or "-"
        for host in hosts:
            task_logger.log(hostname=host, response_status=response_status, response_message=response_message)

        if response_status == CmsTaskStatus.REJECTED:
            # The CMS message is arbitrary remote text that may contain braces: pass it as an argument of a
            # format template (like every other error raised here), never as the template itself.
            raise CmsTaskRejectedError("{}", message or "[no message provided]", result=result)

        return result

    @_send_to_juggler("cms")
    @_yasm_stats("cms")
    def get_task(self, task_id):
        """Return the CMS task with ``task_id``, or None if CMS answers 404."""
        return self.__get("/tasks/" + task_id, scheme=_CMS_TASK_SCHEME, allowed_errors=(requests.codes.not_found,))

    @_send_to_juggler("cms")
    @_yasm_stats("cms")
    def get_tasks(self):
        """Return the list of all tasks CMS reports (the "result" field of the response)."""
        return self.__get("/tasks", scheme=ApiDictScheme({"result": List(_CMS_TASK_SCHEME)}))["result"]

    @_send_to_juggler("cms")
    @_yasm_stats("cms")
    def delete_task(self, task_id):
        """Delete the CMS task with ``task_id``; 204 and 404 answers are treated as success.

        :raises CmsError: the request failed or CMS answered with another error.
        """
        task_logger = self.context_logger.get_child(
            cms_task_id=task_id, cms_name=self.name, walle_action="cms_delete_task"
        )
        allowed_errors = [requests.codes.no_content, requests.codes.not_found]
        try:
            task_logger.log()
            self.__delete("/tasks/" + task_id, allowed_errors=allowed_errors)
        except CmsError as e:
            task_logger.log(exception=str(e))
            raise

    def __get(self, method, params=None, scheme=None, allowed_errors=None):
        """GET ``method`` (a path relative to the base URL); see `__call`."""
        return self.__call("GET", method, params=params, scheme=scheme, allowed_errors=allowed_errors)

    def __post(self, method, request, params=None, scheme=None, allowed_errors=None, timeout=None):
        """POST ``request`` as JSON to ``method`` (a path relative to the base URL); see `__call`."""
        return self.__call(
            "POST",
            method,
            params=params,
            request=request,
            scheme=scheme,
            allowed_errors=allowed_errors,
            timeout=timeout,
        )

    def __delete(self, method, request=None, params=None, scheme=None, allowed_errors=None):
        """DELETE ``method`` (a path relative to the base URL); see `__call`."""
        return self.__call(
            "DELETE", method, params=params, request=request, scheme=scheme, allowed_errors=allowed_errors
        )

    def __call(self, method, path, params=None, request=None, scheme=None, allowed_errors=None, timeout=None):
        """Perform an HTTP request to CMS and return the parsed response.

        :param method: HTTP method name.
        :param path: path relative to the base URL (e.g. "/tasks").
        :param params: extra query parameters merged over the client's ones; None values are dropped.
        :param request: JSON-serializable request body, or None for no body.
        :param scheme: validation scheme for a 200/201 response.
        :param allowed_errors: status codes answered with None instead of an error; only 404 and 204 are honoured.
        :param timeout: request timeout; None leaves the choice to the URL opener.
        :returns: the parsed 200/201 response, or None for an allowed 404/204.
        :raises CmsConnectionError: the request could not be performed, or parsing a 200/201 response raised
            ``RequestException``.
        :raises CmsApiError: CMS answered with any other status.
        """
        url = self.__base_url + path
        params = drop_none(dict(self.__query_params, **({} if params is None else params))) or None
        headers = {}
        data = None

        headers.update(self._get_auth_headers())

        if request is not None:
            headers["Content-Type"] = "application/json"
            data = json.dumps(request)

        try:
            response = self.__url_opener.open(method, url, params=params, headers=headers, data=data, timeout=timeout)
        except RequestException as e:
            raise CmsConnectionError(
                "{} {} request to {} CMS failed: {}",
                method,
                path,
                self.name,
                e,
                method=method,
                path=path,
                project=self.project_id,
                name=self.name,
            ) from e

        try:
            if response.status_code in (requests.codes.ok, requests.codes.created):
                return utils.get_json_response(
                    response, scheme=scheme, check_status=False, errors_log_level=utils.logging.DEBUG
                )
        except RequestException as e:
            notifications.on_cms_api_error(e, url, self.project_id, response)
            raise CmsConnectionError(
                "{} {} request to {} CMS failed: {}",
                method,
                path,
                self.name,
                e,
                response=response,
                method=method,
                path=path,
                project=self.project_id,
                name=self.name,
            ) from e

        if (
            response.status_code in (requests.codes.not_found, requests.codes.no_content)
            and allowed_errors
            and response.status_code in allowed_errors
        ):
            return None

        # Any other status is an error: prefer the message CMS put into the body, fall back to the HTTP reason.
        try:
            error = utils.get_json_response(
                response,
                scheme=ApiDictScheme({"message": String()}),
                check_status=False,
                errors_log_level=utils.logging.DEBUG,
            )["message"]
        except RequestException:
            error = "Server returned an error: {}.".format(response.reason)

        notifications.on_cms_api_error(error, url, self.project_id, response)
        raise CmsApiError(
            response.status_code,
            error,
            "{} {} request to {} CMS failed: {}",
            method,
            path,
            self.name,
            error,
            response=response,
            method=method,
            path=path,
            project=self.project_id,
            name=self.name,
        )

    def _get_auth_headers(self):
        """Return authentication headers: a TVM service ticket when ``use_tvm`` is set, otherwise the OAuth token
        from the "cms.access_token" config value (no headers if it is not configured)."""
        auth_headers = {}

        if self.use_tvm:
            from walle.projects import project_id_to_cms_tvm_alias

            tvm_alias = project_id_to_cms_tvm_alias(self.project_id, self.__base_url)
            auth_headers["X-Ya-Service-Ticket"] = tvm.get_ticket_for_service(tvm_alias)
        else:
            access_token = config.get_value("cms.access_token", None)
            if access_token is not None:
                auth_headers["Authorization"] = "OAuth {}".format(access_token)

        return auth_headers

    def _downscale_action(self, action):
        """Return the action to send for ``action`` under this protocol version.

        Supported actions are returned unchanged; "temporary-unreachable" is downscaled to "reboot" for versions
        that lack it; any other unsupported action raises LogicalError.
        """
        if action not in self._supported_actions:
            if action == CmsTaskAction.TEMPORARY_UNREACHABLE:
                log.info("Downscaling action %s to %s", CmsTaskAction.TEMPORARY_UNREACHABLE, CmsTaskAction.REBOOT)
                return CmsTaskAction.REBOOT
            else:
                raise LogicalError()

        return action


class CmsClientV1_4(_BaseCmsClient):
    """Client for CMS API v1.4, which added the "temporary-unreachable" action.

    Extra task fields are not filtered (``_supported_kwargs`` is inherited as None).
    """

    _supported_actions = CmsTaskAction.V1_4_ACTIONS
    api_version = CmsApiVersion.V1_4


class CmsClientV1_3(_BaseCmsClient):
    """Client for CMS API v1.3.

    v1.3 added the optional task field "extra", which again broke backward compatibility, so a new protocol
    version was introduced to keep existing implementations working. Implementations of v1.3 (and later) must
    accept and ignore unknown task fields, which is why ``_supported_kwargs`` is not set here: every extra task
    field is forwarded.
    """

    _supported_actions = CmsTaskAction.V1_3_ACTIONS
    api_version = CmsApiVersion.V1_3


class CmsClientV1_2(_BaseCmsClient):
    """Client for CMS API v1.2.

    v1.2 added the optional task field "comment". That unexpectedly broke backward compatibility, so a new
    protocol version was introduced to keep existing implementations working. Only "comment" is forwarded
    among the extra task fields.
    """

    _supported_actions = CmsTaskAction.V1_2_ACTIONS
    _supported_kwargs = {"comment"}
    api_version = CmsApiVersion.V1_2


class CmsClientV1_1(_BaseCmsClient):
    """Client for CMS API v1.1.

    Some new actions were added, and this was important enough to bump the protocol version: older CMS API
    implementations could reject actions they did not know how to deal with, which was not acceptable.

    The new actions were "prepare" and "deactivate" (originally described as "add host" and "delete host").
    No extra task fields are forwarded.
    """

    api_version = CmsApiVersion.V1_1
    # An empty *set* of allowed extra task fields (`{}` would be an empty dict), consistent with CmsClientV1_2.
    _supported_kwargs = set()
    _supported_actions = CmsTaskAction.V1_1_ACTIONS


class CmsClientV1_0(_BaseCmsClient):
    """Client for the very first cms-api protocol version.

    This class knows that it does not support some cms-api actions from newer cms-api protocol versions:
    "prepare" and "deactivate" tasks are answered locally with a fake "ok" task instead of being sent to CMS.
    No extra task fields are forwarded.
    """

    api_version = CmsApiVersion.V1_0
    # An empty *set* of allowed extra task fields (`{}` would be an empty dict), consistent with CmsClientV1_2.
    _supported_kwargs = set()
    _supported_actions = CmsTaskAction.V1_0_ACTIONS

    def add_task(self, id, type, issuer, action, hosts, dry_run=False, **task_kwargs):
        """Submit a task to CMS, or allow it locally if v1.0 cannot express the action.

        For "prepare" and "deactivate" a fake allowed task is returned without contacting CMS (so no request
        metrics or Juggler events are produced); everything else goes through the base implementation.
        """
        if action in [CmsTaskAction.DEACTIVATE, CmsTaskAction.PREPARE]:
            message = "CMS API {version} does not support '{action}' host tasks. Allowing by default.".format(
                version=self.api_version, action=action
            )

            return _fake_allow_task(id, type, issuer, action, hosts, message, **task_kwargs)

        return super().add_task(id, type, issuer, action, hosts, dry_run=dry_run, **task_kwargs)


def get_cms_task_type(walle_task_type):
    """Map a Wall-E task type to the corresponding CMS task type.

    :raises LogicalError: the type is neither manual nor one of the automated Wall-E task types.
    """
    from walle.hosts import TaskType

    if walle_task_type == TaskType.MANUAL:
        return CmsTaskType.MANUAL
    if walle_task_type in TaskType.ALL_AUTOMATED:
        return CmsTaskType.AUTOMATED
    raise LogicalError()


def _fake_allow_task(id, type, issuer, action, hosts, message, status=CmsTaskStatus.OK, **kwargs):
    """Build a fake CMS task that allows the operation, without submitting anything to the real CMS.

    Extra keyword arguments (task fields) are accepted and ignored. A later 'DELETE' for such a task still goes
    to the real CMS, though (it should get a 404, which is OK by design).
    """
    fake_task = dict(id=id, type=type, issuer=issuer, action=action, hosts=hosts)
    fake_task.update(status=status, message=message)
    return fake_task


class _UrlOpener:
    """Default CMS transport: sends requests through `walle.clients.utils.request` without status checks."""

    def open(self, method, url, params=None, headers=None, data=None, timeout=None):
        """Perform the request; a None ``timeout`` means `CMS_API_REQUEST_TIMEOUT`."""
        if timeout is None:
            timeout = CMS_API_REQUEST_TIMEOUT

        request_options = dict(
            params=params,
            headers=headers,
            data=data,
            timeout=timeout,
            check_status=False,
            errors_log_level=utils.logging.DEBUG,
        )
        return utils.request("cms", method, url, **request_options)


def get_cms_client(version=None):
    """Return the CMS client *class* (not an instance) for the given CMS API version.

    When no version is given, the most conservative one (v1.0) is assumed.

    :raises LogicalError: the version is not one of `CmsApiVersion.ALL_CMS_API`.
    """
    chosen_version = CmsApiVersion.V1_0 if version is None else version

    if chosen_version in CmsApiVersion.ALL_CMS_API:
        return _CmsClientMeta.get_client_class(chosen_version)
    raise LogicalError


def make_maintenance_info(cms_task_id, action, hosts, node_set_id=None, comment=None, issuer=None):
    """Build a YP maintenance description for a CMS task.

    Disruptiveness and estimated duration (seconds) come from `CmsTaskAction.TASK_MAINTENANCE_MAPPING`
    (KeyError for unknown actions). The node set id defaults to the task id. Fields with None values
    (no comment, no issuer) are dropped.

    NOTE(review): ``hosts`` is currently not used.
    """
    maintenance_params = CmsTaskAction.TASK_MAINTENANCE_MAPPING[action]
    task_id = str(cms_task_id)

    maintenance_info = {
        "id": task_id,
        "kind": action,
        "message": comment,
        "disruptive": maintenance_params["destructive"],
        "estimated_duration": maintenance_params["period"],
        "node_set_id": node_set_id or task_id,
        "priority": MaintenancePriority.NORMAL,
        "labels": {"issuer": issuer} if issuer else None,
    }
    return drop_none(maintenance_info)
