"""Project-level automation management."""

import logging

import mongoengine
from cachetools.func import ttl_cache

import walle.util.notifications  # noqa
from sepelib.core.exceptions import Error, LogicalError
from walle import util, authorization, audit_log
from walle.errors import ResourceConflictError, DNSAutomationValidationError
from walle.models import timestamp
from walle.projects import (
    Project,
    get_by_id,
    ProjectNotFoundError,
    check_id,
    HEALING_AUTOMATION_LIMIT_NAMES,
    DNS_AUTOMATION_LIMIT_NAMES,
)
from walle.util.misc import drop_none

log = logging.getLogger(__name__)


class _Automation:
    """Base manager for one project automation type.

    The automation settings live in a ``Project`` field whose name is given to the constructor;
    every query and update below is built by prefixing mongoengine lookups with that name.
    Subclasses must implement ``_audit_log``, ``_on_enabled`` and ``_on_disabled``.
    """

    def __init__(self, field, label, limit_names):
        """
        :param field: name of the ``Project`` field that holds this automation's settings.
        :param label: human-readable automation name used in log and error messages.
        :param limit_names: limit and credit names for this automation type.
        """
        self._automation_field = field
        self._label = label
        self._limit_names = limit_names

    def get_project_ids_with_automation_enabled(self):
        """Return IDs of projects that have this automation type enabled (uncached database query)."""

        project_ids = Project.objects(**{self._field_name("enabled"): True}).get_ids()
        return project_ids

    # NOTE: the cache belongs to the decorated function, not to an instance: it is shared by all
    # instances and keyed on ``self``, so it needs room for one entry per automation type.
    # With ``maxsize=1`` the healing and DNS instances kept evicting each other's entry and
    # alternating calls always went to the database.
    @ttl_cache(maxsize=16, ttl=5)
    def get_project_ids_with_automation_enabled_cached(self):
        """Return a set of IDs of projects with this automation type enabled, cached for 5 seconds.

        The same set object is returned until the entry expires; methods of this class add/discard
        project IDs in it in place so that it reflects their own updates immediately.
        """
        return set(self.get_project_ids_with_automation_enabled())

    def enabled_for_project_id(self, project_id):
        """Return True if this automation type is enabled for the project.

        Reads the current state from the database and refreshes the cached set with it.

        :raises Error: if the project doesn't exist.
        """

        try:
            project = get_by_id(project_id, fields=(self._field_name("enabled"),))
        except ProjectNotFoundError:
            raise Error("Failed to get automation settings for '{}' project: the project doesn't exist.", project_id)

        return self.enabled_for_project(project)

    def enabled_for_project(self, project):
        """Return True if this automation type is enabled for an already loaded project.

        Also refreshes the cached set with the project's state.
        """
        return self._update_cached_project_automation(project.id, getattr(project, self._automation_field).enabled)

    def enabled_for_project_cached(self, project_id):
        """Return True if this automation type is enabled for the project.

        Uses the cached set of enabled projects, filling/refreshing it if needed, so changes made
        by other processes may be seen only after the cache TTL expires.
        Returns False for an unknown project.
        """
        return project_id in self.get_project_ids_with_automation_enabled_cached()

    def enable_automation(self, issuer, project_id, credit=None, credit_time=None, reason=None):
        """Enable this automation type for the project.

        Overwrites the whole automation settings field; ``None`` values are omitted from it.

        :param issuer: who requests the change; passed to the audit log and the notification.
        :param credit: optional credit values stored with the settings.
        :param credit_time: optional credit lifetime, added to the current ``timestamp()``.
        :param reason: optional reason, stored as the status message.
        :raises ResourceConflictError: if the automation is already enabled for the project.
        """

        # One timestamp for both fields so the credit period and the failure log start coincide.
        now = timestamp()
        credit_end_time = None if credit_time is None else now + credit_time
        query = {
            "id": project_id,
            self._field_name("enabled__ne"): True,
        }
        update = {
            "set__" + self._automation_field: drop_none(
                {
                    "enabled": True,
                    "status_message": reason,
                    "failure_log_start_time": now,
                    "credit": credit,
                    "credit_end_time": credit_end_time,
                }
            ),
        }

        with self._audit_log(
            issuer, project_id=project_id, enable=True, reason=reason, credit=credit, credit_time=credit_time
        ):
            updated = Project.objects(**query).update(**update)

            if not updated:
                # Nothing matched: the project is missing (check_id presumably raises for that)
                # or the automation is already enabled.
                check_id(project_id)
                raise ResourceConflictError(
                    "Rejecting to enable {}: already enabled for the project.".format(self._label)
                )

            self.get_project_ids_with_automation_enabled_cached().add(project_id)

        if updated:
            self._on_enabled(issuer, project_id, reason)

    def disable_automation(self, issuer, project_id, reason):
        """Disable this automation type for the project and drop its credit.

        Disabling an already disabled automation is not rejected: nothing is updated and no
        notification is sent.
        """

        if issuer == authorization.ISSUER_WALLE:
            log.warning("Disable %s for '%s' project: %s", self._label, project_id, reason)

        query = {
            "id": project_id,
            self._field_name("enabled__ne"): False,
        }
        update = {
            "set__" + self._field_name("enabled"): False,
            "set__" + self._field_name("status_message"): reason,
            "unset__" + self._field_name("credit"): True,
            "unset__" + self._field_name("credit_end_time"): True,
        }
        with self._audit_log(issuer, project_id=project_id, enable=False, reason=reason):
            updated = Project.objects(**query).update(**update)
            self.get_project_ids_with_automation_enabled_cached().discard(project_id)

            if not updated:
                # The project is missing (check_id presumably raises for that)
                # or the automation is already disabled.
                check_id(project_id)

        if updated:
            self._on_disabled(issuer, project_id, reason)

    def acquire_credit(self, project_id, limit_name):
        """Try to spend one unit of the project's ``limit_name`` credit.

        The credit is decremented only if the automation is enabled, the credit hasn't expired
        and some of it is left; the conditions and the decrement form one conditional update
        of a single document.

        :return: True if a credit unit was acquired, False otherwise.
        """
        query = {
            "id": project_id,
            self._field_name("enabled"): True,
            self._field_name("credit_end_time__gt"): timestamp(),
            self._field_name("credit__" + limit_name + "__gt"): 0,
        }

        update = {"dec__" + self._field_name("credit__" + limit_name): 1}

        return bool(Project.objects(**query).update(multi=False, **update))

    def increase_credit(self, project_id, credit_name, credit_time=None, by_amount=1, reason=None):
        """Increase the project's ``credit_name`` credit by ``by_amount`` and optionally extend it.

        :param credit_time: if set, the credit end time is moved to ``timestamp() + credit_time``
            unless it is already later than that.
        :return: True if the credit was increased, False if no project with this automation
            enabled matched (automation disabled or project missing).
        """
        # Increase the credit amount and set the credit end time in different requests.
        # Don't force the "increment" for the credit end time field because we don't want
        #  to increase the credit end time twice if some other process already increased it
        # Also don't force the "set" because we don't want to decrease the credit time
        #  if some other process already increased it
        # so we try to set it and we are Ok if it fails.

        updated = Project.objects(**{"id": project_id, self._field_name("enabled"): True}).update_one(
            **{"inc__" + self._field_name("credit") + "__" + credit_name: by_amount}
        )
        if updated:
            log.info(
                "Increased credit %s for %s project's %s by %d: %s",
                credit_name,
                project_id,
                self._label,
                by_amount,
                reason,
            )
        else:
            log.info(
                "Won't increase automation credit %s for project %s: %s is disabled.",
                credit_name,
                project_id,
                self._label,
            )
            return False

        if credit_time:
            credit_end_time = timestamp() + credit_time
            # We are ok if this fails, it means we have enough credit time already or automation was disabled
            # while we were calculating or some other stuff that we don't care happened.
            Project.objects(
                mongoengine.Q(**{self._field_name("credit_end_time__lte"): credit_end_time})
                | mongoengine.Q(**{self._field_name("credit_end_time__exists"): False}),
                **{"id": project_id, self._field_name("enabled"): True}
            ).update_one(**{"set__" + self._field_name("credit_end_time"): credit_end_time})

        return True

    def get_project_limits(self, project_id):
        """Return project's automation limits and the beginning time for these limits."""

        project = get_by_id(project_id, fields=("automation_limits", self._field_name("failure_log_start_time")))

        return project.automation_limits, getattr(project, self._automation_field).failure_log_start_time

    def get_limit_names(self):
        """Return limit and credit names for this automation type."""
        return self._limit_names

    def get_automation_label(self):
        """Return API automation label for this automation type."""
        return self._label

    def _update_cached_project_automation(self, project_id, enabled):
        """Add/remove the project in the cached enabled set according to ``enabled``; return ``enabled``."""
        if enabled:
            self.get_project_ids_with_automation_enabled_cached().add(project_id)
        else:
            self.get_project_ids_with_automation_enabled_cached().discard(project_id)

        return enabled

    def _audit_log(self, issuer, project_id, enable, reason, credit=None, credit_time=None):
        """Create audit log entry (used as a context manager). Implemented in child class."""
        raise LogicalError

    def _on_enabled(self, issuer, project_id, reason):
        """Send notification when automation type enabled. Implemented in child class."""
        raise LogicalError

    def _on_disabled(self, issuer, project_id, reason):
        """Send notification when automation type disabled. Implemented in child class."""
        raise LogicalError

    def _field_name(self, name):
        """Return the mongoengine lookup path for ``name`` inside this automation's settings field."""
        return self._automation_field + "__" + name


class HealingAutomation(_Automation):
    """Automated healing: binds the generic manager to healing-specific audit log and notifications."""

    def _audit_log(self, issuer, project_id, enable, reason, credit=None, credit_time=None):
        """Return the audit log entry for a healing automation status change (used as a context manager)."""
        entry = audit_log.on_change_healing_status(
            issuer,
            project_id=project_id,
            enable=enable,
            reason=reason,
            credit=credit,
            credit_time=credit_time,
        )
        return entry

    def _on_enabled(self, issuer, project_id, reason):
        """Send the "automated healing enabled" notification."""
        notify = util.notifications.on_healing_automation_enabled
        return notify(issuer, project_id, reason=reason)

    def _on_disabled(self, issuer, project_id, reason):
        """Send the "automated healing disabled" notification."""
        notify = util.notifications.on_healing_automation_disabled
        return notify(issuer, project_id, reason=reason)


class DnsAutomation(_Automation):
    """DNS automation: project DNS settings must be filled in before it can be enabled."""

    def _audit_log(self, issuer, project_id, enable, reason, credit=None, credit_time=None):
        """Return the audit log entry for a DNS automation status change (used as a context manager)."""
        entry = audit_log.on_change_dns_automation_status(
            issuer,
            project_id=project_id,
            enable=enable,
            reason=reason,
            credit=credit,
            credit_time=credit_time,
        )
        return entry

    def _on_enabled(self, issuer, project_id, reason):
        """Send the "DNS automation enabled" notification."""
        notify = util.notifications.on_dns_automation_enabled
        return notify(issuer, project_id, reason=reason)

    def _on_disabled(self, issuer, project_id, reason):
        """Send the "DNS automation disabled" notification."""
        notify = util.notifications.on_dns_automation_disabled
        return notify(issuer, project_id, reason=reason)

    def enable_automation(self, issuer, project_id, credit=None, credit_time=None, reason=None):
        """Validate the project's DNS settings, then enable DNS automation for it.

        :raises DNSAutomationValidationError: if ``dns_domain`` or ``vlan_scheme`` is not set.
        """
        project = get_by_id(project_id, fields=("dns_domain", "vlan_scheme"))
        self.validate_project_dns_settings(project)
        super().enable_automation(issuer, project_id, credit=credit, credit_time=credit_time, reason=reason)

    @staticmethod
    def validate_project_dns_settings(project):
        """Raise ``DNSAutomationValidationError`` for the first required DNS field that is None."""
        missing = next((name for name in ("dns_domain", "vlan_scheme") if project[name] is None), None)
        if missing is not None:
            raise DNSAutomationValidationError(missing, "is not set")


# Module-level managers, one per automation type; each is bound to the ``Project`` field
# that stores its settings.
PROJECT_HEALING_AUTOMATION = HealingAutomation(
    label="automated healing",
    field="healing_automation",
    limit_names=HEALING_AUTOMATION_LIMIT_NAMES,
)
PROJECT_DNS_AUTOMATION = DnsAutomation(
    label="DNS automation",
    field="dns_automation",
    limit_names=DNS_AUTOMATION_LIMIT_NAMES,
)
