import logging
import requests

from sandbox.common import auth
from sandbox.common import errors
from sandbox.common import rest
from sandbox.common import urls as common_urls
from sandbox.common.types import resource as ctr
from sandbox.common.types import misc as ctm

from sandbox import sdk2

from sandbox.projects.sandbox import create_resource as cr_task


# Fields copied from a production vault record into the payload that
# PreprodSync.sync_vaults writes to preprod; every value starts out as None.
VAULT_TEMPLATE = dict.fromkeys(("data", "description", "name", "owner", "shared"))

# NOTE(review): not referenced by any code visible in this file — confirm before removing.
MAX_RESOURCE_LIMIT = 1000
# Key looked up in the `sandbox_token` Yav secret when the parameter has no default key.
DEFAULT_SECRET_KEY = "sandbox-oauth-token"
# Resource attribute that stores the id of the production resource a preprod copy was made from.
PRODUCTION_ORIGIN_RID_ATTR = "origin_production_rid"


class RequestedResources(sdk2.parameters.JSON):
    """JSON parameter holding a list of resource filters.

    Every filter must be a dict with either a truthy ``id`` or both a truthy
    ``limit`` and a truthy ``type``. An empty (falsy) value is accepted as-is.
    """

    @classmethod
    def cast(cls, value):
        """Parse ``value`` with the base JSON parameter, then validate the filters.

        :return: the parsed value, unchanged
        :raises ValueError: if a non-empty value is not a list, contains a
            non-dict item, or has a filter with neither ``id`` nor both
            ``limit`` and ``type``
        """
        parsed = super(RequestedResources, cls).cast(value)
        if not parsed:
            return parsed

        if not isinstance(parsed, list):
            raise ValueError("Unrecognized format, list of dicts with filters required")

        for resource_filter in parsed:
            if not isinstance(resource_filter, dict):
                raise ValueError("Filters must be dict")

            has_id = bool(resource_filter.get("id"))
            has_query = bool(resource_filter.get("limit")) and bool(resource_filter.get("type"))
            if not (has_id or has_query):
                raise ValueError("Each filter must contain 'id' or 'limit' and 'type' params")

        return parsed


class PreprodSync(sdk2.Task):
    """Synchronize groups, vault secrets and resources from production Sandbox to preprod.

    * Groups: each production group is created on preprod, or overwritten there
      when its ``sources`` or ``abc`` differ.
    * Secrets: each ``"owner, name"`` pair from ``requested_secrets`` is created
      or updated on preprod.
    * Resources: resources selected by ``requested_resources`` that have no READY
      preprod copy (tagged with ``PRODUCTION_ORIGIN_RID_ATTR``) are re-created by
      CREATE_RESOURCE tasks started on preprod.

    Group sync errors are collected in ``exception_logs`` instead of aborting the
    run. They are published through the ``exceptions`` output parameter, and the
    task fails when more than ``MAX_EXCEPTIONS`` were collected.
    """

    MAX_GROUPS_PER_ITERATION = 1000
    # Number of collected non-fatal errors above which the task fails.
    MAX_EXCEPTIONS = 30
    # REST API endpoint of the preprod Sandbox installation.
    PREPROD_API_URL = "https://www-sandbox1.n.yandex-team.ru/api/v1.0"

    # Class-level fallback only: a mutable class attribute is shared by every
    # instance, so on_execute binds a fresh per-run list before syncing.
    exception_logs = []

    class Parameters(sdk2.Parameters):
        enable_groups_sync = sdk2.parameters.Bool("Enable groups synchronization", default=True)
        sandbox_token = sdk2.parameters.YavSecret(
            "Yav secret with Sandbox OAuth token. Default key: '{}'".format(DEFAULT_SECRET_KEY),
            required=True,
        )
        requested_resources = RequestedResources("List of filters (dicts) to find resources to sync", default=[])
        requested_secrets = sdk2.parameters.List("Pairs 'owner, secret_name' of secrets to sync", default=[])

        with sdk2.parameters.Output:
            exceptions = sdk2.parameters.String("All exceptions")

    @staticmethod
    def _http_error_text(ex):
        """Return the response body of ``ex`` with trailing whitespace stripped.

        Falls back to ``str(ex)`` when the error carries no response object.
        """
        response = getattr(ex, "response", None)
        if response is None:
            return str(ex)
        return str(response.text).rstrip()

    def get_all_groups(self, api):
        """Fetch all groups from ``api``, paging by ``MAX_GROUPS_PER_ITERATION``.

        The comma-separated ``group`` field of every group source is rewritten
        with its entries sorted. This lets identical sources compare equal
        between installations regardless of entry order.

        :param api: REST client of the installation to read from
        :return: dict mapping group name to the group record
        """
        all_groups = []
        offset = 0
        total = 0
        # ``offset == 0`` forces the first request, before ``total`` is known.
        while offset < total or offset == 0:
            response = api.group.read(limit=self.MAX_GROUPS_PER_ITERATION, offset=offset)
            total = response.get("total", 0)
            all_groups.extend(response.get("items", []))
            offset += self.MAX_GROUPS_PER_ITERATION

        # Sort sources' groups so prod and preprod records can be compared.
        for group in all_groups:
            for source in group["sources"]:
                source["group"] = ", ".join(sorted(source["group"].split(", ")))

        return {group["name"]: group for group in all_groups}

    def sync_groups(self, api_prod, api_preprod):
        """Create or overwrite preprod groups so they match production.

        A group missing on preprod is created. An existing one is overwritten
        when its ``sources`` or ``abc`` differ. HTTP errors and client timeouts
        from either call are recorded in ``exception_logs`` and do not stop the
        sync of the remaining groups.
        """
        prod_groups = self.get_all_groups(api_prod)
        preprod_groups = self.get_all_groups(api_preprod)

        for name, group in prod_groups.items():
            logging.info("Syncing group %s", name)
            preprod_group = preprod_groups.get(name)
            try:
                if preprod_group is None:
                    logging.info("Didn't find group in preprod. Creating new one")
                    api_preprod.group.create(group)
                elif group["sources"] != preprod_group["sources"] or group["abc"] != preprod_group["abc"]:
                    logging.info(
                        "Found in preprod. Diff:\nSources in prod: %s\nSources in preprod: %s"
                        "\nABC in prod: %s\nABC in preprod: %s",
                        group["sources"],
                        preprod_group["sources"],
                        group["abc"],
                        preprod_group["abc"],
                    )
                    api_preprod.group[name] = group
            except requests.HTTPError as ex:
                self.exception_logs.append("{} (group {})".format(self._http_error_text(ex), name))
                logging.exception(ex)
            except api_preprod.TimeoutExceeded:
                self.exception_logs.append("Timeout exceeded (group {})".format(name))
                logging.exception("Timeout exceeded (group %s)", name)

            logging.info("Group %s processed", name)

    def sync_vaults(self, api_prod, api_preprod):
        """Copy every secret listed in ``requested_secrets`` to preprod.

        Each entry is an ``"owner, name"`` string (spaces are ignored). The
        metadata fields listed in ``VAULT_TEMPLATE`` come from the production
        API, and the secret value comes from ``sdk2.Vault.data``. A preprod
        secret with the same owner and name is updated; otherwise a new one is
        created. Errors are not caught here.
        """
        for secret_data in self.Parameters.requested_secrets:
            logging.info("Syncing secret %s", secret_data)
            owner, name = secret_data.replace(" ", "").split(",")
            secret_prod = api_prod.vault.data.read({"name": name, "owner": owner})

            secret_to_put = {field: secret_prod[field] for field in VAULT_TEMPLATE}
            # The value read through sdk2.Vault overrides the "data" field copied above.
            secret_to_put["data"] = sdk2.Vault.data(owner, name)

            existing = api_preprod.vault.read({"limit": 1, "name": name, "owner": owner})["items"]
            if existing:
                logging.info(
                    "Found a secret with the same pair owner: name. Id: %s. Updating it's value",
                    existing[0]["id"],
                )
                api_preprod.vault[existing[0]["id"]] = secret_to_put
            else:
                logging.info(
                    "Didn't find secrets with the same pair owner: name. Creating new secret."
                )
                api_preprod.vault.create(secret_to_put)

            logging.info("Secret %s successfully synced", secret_data)

    def get_resources_ids(self, api_prod):
        """Resolve ``requested_resources`` filters into a set of production resource ids.

        A filter with a truthy ``id`` contributes that id, or each id of a list.
        Any other filter is passed as-is to the production resource listing and
        contributes the ids of the returned items.
        """
        resources_ids = []

        for filter_ in self.Parameters.requested_resources:
            logging.info("Getting id's of resources with filter %s", filter_)
            requested_id = filter_.get("id")
            if requested_id:
                if isinstance(requested_id, list):
                    resources_ids.extend(requested_id)
                else:
                    resources_ids.append(requested_id)
            else:
                resources_list = api_prod.resource.read(filter_)["items"]
                resources_ids.extend(resource["id"] for resource in resources_list)

            logging.info("Id's to sync: %s", ", ".join(set(map(str, resources_ids))))

        return set(resources_ids)

    def get_backup_resource_on_preprod_id(self, api_preprod, resource_data):
        """Return the id of a READY preprod copy of ``resource_data``, or None.

        A copy is a preprod resource of the same type whose
        ``PRODUCTION_ORIGIN_RID_ATTR`` attribute equals the production id.
        """
        resource_query = {
            "limit": 1,
            "type": resource_data["type"],
            "attrs": {
                PRODUCTION_ORIGIN_RID_ATTR: resource_data["id"],
            },
            "state": ctr.State.READY,
        }
        items = api_preprod.resource.read(resource_query)["items"]
        return items[0]["id"] if items else None

    @staticmethod
    def _group_resources_by_path(resources):
        """Split ``resources`` into consecutive chunks with unique ``resource_name``.

        A new chunk starts whenever the next item's name already appears in the
        current chunk. Empty input yields nothing.
        """
        chunk = []
        unique_paths = set()
        for res_info in resources:
            if res_info.resource_name in unique_paths:
                yield chunk
                chunk = []
                unique_paths = set()
            chunk.append(res_info)
            unique_paths.add(res_info.resource_name)
        if chunk:
            yield chunk

    def _start_create_resource_task(self, api_preprod, owner, chunk):
        """Create and start one preprod CREATE_RESOURCE task uploading ``chunk`` under ``owner``.

        :raises errors.TaskFailure: if preprod reports an error starting the task
        """
        description_lines = ["Resources migration. Original resources:"]
        multi_config = []
        for res_info in chunk:
            description_lines.append(common_urls.get_resource_link(res_info.id))
            multi_config.append(res_info.__dict__)

        custom_fields = {
            cr_task.CreateResource.Parameters.create_multiple_resources.name: True,
            cr_task.CreateResource.Parameters.multi_config.name: multi_config,
        }

        data = {
            "type": "CREATE_RESOURCE",
            "description": "\n".join(description_lines),
            "owner": owner,
            "custom_fields": [{"name": name, "value": value} for name, value in custom_fields.items()],
        }
        task = api_preprod.task(data)
        logging.info("Synchronization task created: %s", task["id"])
        resp = api_preprod.batch.tasks.start.update(task["id"])[0]

        if resp["status"] == ctm.BatchResultStatus.ERROR:
            raise errors.TaskFailure("Failed to start task #{}: {}".format(task["id"], resp))

        if resp["status"] != ctm.BatchResultStatus.SUCCESS:
            logging.warning("Task #%s started with warning: %s", task["id"], resp)

        logging.info("Synchronization started for resources: %s", [res_info.id for res_info in chunk])

    def sync_resources(self, api_prod, api_preprod):
        """Re-create requested production resources on preprod.

        Resources that already have a READY preprod copy are skipped and logged.
        The rest are grouped by owner and, within an owner, split into chunks
        with unique file names. Each chunk is uploaded by one CREATE_RESOURCE
        task created and started on preprod under that owner.

        :raises errors.TaskFailure: if preprod reports an error starting a task
        """
        resources_ids = self.get_resources_ids(api_prod)

        resources_by_owner = {}
        already_in_preprod = {}
        for rid in resources_ids:
            resource = api_prod.resource[rid].read()

            preprod_rid = self.get_backup_resource_on_preprod_id(api_preprod, resource)
            if preprod_rid is not None:
                already_in_preprod[rid] = preprod_rid
                continue

            attrs = resource["attributes"]
            attrs["backup_task"] = True
            # Tag the copy so get_backup_resource_on_preprod_id finds it on later runs.
            attrs[PRODUCTION_ORIGIN_RID_ATTR] = rid

            resources_by_owner.setdefault(resource["owner"], []).append(
                cr_task.MetaInfo(
                    protocol=cr_task.Protocol.SKYNET,
                    src_path=resource["skynet_id"],
                    description=resource["description"],
                    resource_name=resource["file_name"],
                    resource_type=resource["type"],
                    attributes=attrs,
                    resource_arch=resource["arch"],
                    rsync_copy_links=None,
                    id=rid,
                )
            )

        if already_in_preprod:
            logging.info("Resource already in preprod (prod_rid => preprod_rid): %s", already_in_preprod)

        for owner, owner_resources in resources_by_owner.items():
            for chunk in self._group_resources_by_path(owner_resources):
                self._start_create_resource_task(api_preprod, owner, chunk)

    def on_execute(self):
        """Run the enabled synchronizations and publish the collected errors.

        The ``exceptions`` output is set even if a sync stage raises.

        :raises errors.TaskFailure: if more than ``MAX_EXCEPTIONS`` errors were
            collected, or a preprod CREATE_RESOURCE task failed to start
        """
        # Fresh per-run list instead of mutating the shared class-level default.
        self.exception_logs = []

        api_prod = rest.Client(total_wait=20)
        secret_key = self.Parameters.sandbox_token.default_key or DEFAULT_SECRET_KEY
        oauth = auth.OAuth(self.Parameters.sandbox_token.data()[secret_key])
        api_preprod = rest.Client(
            base_url=self.PREPROD_API_URL,
            auth=oauth,
            total_wait=60,
        )

        try:
            if self.Parameters.enable_groups_sync:
                self.sync_groups(api_prod, api_preprod)
            self.sync_vaults(api_prod, api_preprod)
            self.sync_resources(api_prod, api_preprod)
        finally:
            # Publish whatever was collected, even when a later stage failed.
            report = "\n".join(self.exception_logs)
            logging.info(report)
            self.Parameters.exceptions = report

        if len(self.exception_logs) > self.MAX_EXCEPTIONS:
            raise errors.TaskFailure("Too many exceptions")
