import logging

from enum import Enum
from functools import wraps

from sandbox import sdk2
from sandbox.sandboxsdk import environments
from sandbox.projects.common.nanny.client import (
    NannyClient,
    NannyApiException,
    SandboxFileUpdate,
    StaticFileUpdate,
)
from sandbox.projects.release_machine.core import const as rm_const
from sandbox.projects.release_machine import security as rm_sec
from sandbox.projects.release_machine import input_params2 as rm_params

# Module-level logger named after this module; used by the task methods below.
logger = logging.getLogger(__name__)


class FilterType(Enum):
    """How CreateNannySnapshots selects the Nanny services to update.

    Member names and values are used as the choices of the ``filter_type``
    task parameter, and the values are what that parameter is compared against.
    """
    service_id = "service_id"  # explicit list of Nanny service ids
    labels = "labels"  # services of a dashboard, filtered by their labels


def filter_function(use_conjunction, use_negation):
    """Build a predicate that reduces an iterable of booleans to a single verdict.

    :param use_conjunction: combine with ``all`` if true, with ``any`` otherwise
    :param use_negation: if true, the returned predicate negates the combined result
    :return: a callable taking an iterable of booleans; one of ``all``, ``any``,
        ``not_all`` (``not all(...)``) or ``not_any`` (``not any(...)``)

    Note the empty-iterable results: ``all([])`` is True, ``any([])`` is False,
    and the negated variants invert those.
    """
    combine = all if use_conjunction else any
    if not use_negation:
        return combine

    def negated(*args, **kwargs):
        return not combine(*args, **kwargs)

    # Give the closure an honest name instead of copying the builtin's metadata
    # (functools.wraps made the negated predicate report itself as all/any).
    negated.__name__ = "not_" + combine.__name__
    return negated


class CreateNannySnapshots(sdk2.Task):
    """Create new snapshots of Nanny services with updated sandbox and static files.

    Target services are selected either by an explicit list of service ids or by
    filtering the services of a Nanny dashboard by their labels. The ids of the
    created snapshots are published in the ``created_snapshots`` output parameter,
    keyed by service id.
    """

    class Requirements(sdk2.Requirements):
        cores = 1  # exactly 1 core
        ram = 4096  # 4GiB or less

        environments = (
            environments.PipEnvironment('retrying'),
        )

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Parameters):
        component_name = rm_params.ComponentName2.component_name()
        release_type = sdk2.parameters.String(
            'Release type',
            choices=list((t, t) for t in (rm_const.ReleaseStatus.testing, rm_const.ReleaseStatus.stable)),
            default=rm_const.ReleaseStatus.testing,
        )

        with sdk2.parameters.Group("Updates") as updates:
            comment = sdk2.parameters.String("Snapshot comment")
            sandbox_files_dict = sdk2.parameters.Dict("Dict with sandbox file changes, mapping local_path -> resource_id", default={})
            static_files_dict = sdk2.parameters.Dict("Dict with static file changes, mapping local_path -> content_string", default={})

        with sdk2.parameters.Group("Nanny services filter") as services_filter:
            filter_type = sdk2.parameters.String("Filter type", choices=[(filter_type.name, filter_type.value) for filter_type in FilterType], default=FilterType.labels.value)

            with filter_type.value[FilterType.labels.value]:
                dashboard_id = sdk2.parameters.String("Dashboard identifier", required=True)
                filter_services_by_labels = sdk2.parameters.Dict("Filter dashboard services by labels", description="If empty, all services from dashboard will be updated")
                use_conjunction = sdk2.parameters.Bool("filter1 & filter2 & ...", default=True)
                use_negation = sdk2.parameters.Bool("!filter", default=False)

            with filter_type.value[FilterType.service_id.value]:
                service_ids = sdk2.parameters.List("List of Nanny services to update")

        with sdk2.parameters.Group("Misc") as misc_options:
            nanny_oauth_token_name = sdk2.parameters.String("Nanny OAuth token", default=None, description="If empty, token for robot-srch-releaser will be used")

        with sdk2.parameters.Output:
            created_snapshots = sdk2.parameters.Dict("Created snapshots")

    def get_service_ids(self, nanny_client):
        """Return the ids of the Nanny services that should be updated.

        :param nanny_client: NannyClient used to query dashboards and services
        :return: non-empty sequence of service ids
        :raises NannyApiException: if the dashboard lookup returns no list of services
        :raises RuntimeError: if no services are selected or the filter type is unknown
        """
        filter_type = self.Parameters.filter_type

        if filter_type == FilterType.service_id.value:
            service_ids = self.Parameters.service_ids
            # Fail loudly, like the labels branch, instead of "succeeding" with no snapshots.
            if not service_ids:
                raise RuntimeError("No services to update")
            return service_ids

        if filter_type == FilterType.labels.value:
            return self._get_service_ids_by_labels(nanny_client)

        raise RuntimeError("Unknown filter type: {}".format(filter_type))

    def _get_service_ids_by_labels(self, nanny_client):
        """Return the dashboard services whose labels satisfy the configured label filter."""
        label_filters = list((self.Parameters.filter_services_by_labels or {}).items())
        logger.info("Filter services by following labels: %s", label_filters)

        service_ids = nanny_client.get_dashboard_services(self.Parameters.dashboard_id)
        if not isinstance(service_ids, list) or not service_ids:
            logger.error("Expected non-empty list of service_ids, got this: %s", service_ids)
            raise NannyApiException("Expected non-empty list of service_ids")

        if not label_filters:
            # Documented contract of `filter_services_by_labels`: an empty filter selects every
            # dashboard service. Without this shortcut, any([]) / not all([]) / not any([])
            # would select nothing whenever conjunction is off or negation is on.
            return service_ids

        matches = filter_function(self.Parameters.use_conjunction, self.Parameters.use_negation)
        filtered_service_ids = []
        for service_id in service_ids:
            labels = self._get_service_labels(nanny_client, service_id)
            if matches([label_key_value in labels for label_key_value in label_filters]):
                filtered_service_ids.append(service_id)

        if not filtered_service_ids:
            raise RuntimeError("No services to update")
        return filtered_service_ids

    @staticmethod
    def _get_service_labels(nanny_client, service_id):
        """Return the (key, value) label pairs attached to a Nanny service."""
        service_info = nanny_client.get_service(service_id)
        content = service_info['info_attrs']['content']
        # A service without a labels entry simply has no labels to match.
        return [(label['key'], label.get('value')) for label in content.get('labels') or []]

    def _build_sandbox_updates(self):
        """Build SandboxFileUpdate objects (keyed by local path) from ``sandbox_files_dict``."""
        sandbox_updates = {}
        for local_path, resource_id in (self.Parameters.sandbox_files_dict or {}).items():
            resource = sdk2.Resource[resource_id]
            sandbox_updates[local_path] = SandboxFileUpdate(
                local_path=local_path,
                resource_id=str(resource_id),
                resource_type=str(resource.type),
                task_id=str(resource.task.id),
                task_type=str(resource.task.type),
            )
        return sandbox_updates

    def _build_static_updates(self):
        """Build StaticFileUpdate objects (keyed by local path) from ``static_files_dict``."""
        return {
            local_path: StaticFileUpdate(local_path=local_path, content=file_content)
            for local_path, file_content in (self.Parameters.static_files_dict or {}).items()
        }

    def create_snapshots(self, nanny_client, comment=None):
        """Apply the configured file updates to every selected service.

        :param nanny_client: NannyClient used to update services
        :param comment: snapshot comment passed to Nanny
        :return: dict mapping service id -> ``runtime_attrs._id`` of the created snapshot
        """
        service_ids = self.get_service_ids(nanny_client)
        sandbox_updates = self._build_sandbox_updates()
        static_updates = self._build_static_updates()

        snapshots = {}
        for service_id in service_ids:
            result = nanny_client.update_service_files(service_id, sandbox_updates, static_updates, comment=comment)
            snapshots[service_id] = result['runtime_attrs']['_id']
            # Log each snapshot as it is created so that a mid-loop failure still leaves a trace.
            logger.info("Service %s: created snapshot %s", service_id, snapshots[service_id])
        return snapshots

    def on_execute(self):
        """Resolve the Nanny token, create the snapshots and publish them as task output."""
        if self.Parameters.nanny_oauth_token_name:
            nanny_token = sdk2.Vault.data(self.Parameters.nanny_oauth_token_name)
        else:
            nanny_token = rm_sec.get_rm_token(self)
        nanny_client = NannyClient(rm_const.Urls.NANNY_BASE_URL, nanny_token)

        # NOTE(review): with commit_on_entrance=False the stage is presumably committed only
        # after it completes, so a failed attempt is re-run on restart (and may re-update
        # services that already got a snapshot) -- confirm against sdk2 memoize_stage docs.
        with self.memoize_stage.create_snapshots(commit_on_entrance=False):
            self.Parameters.created_snapshots = self.create_snapshots(nanny_client, comment=self.Parameters.comment)
