from yp_proto.yp.client.api.proto import autogen_pb2

from sandbox.common import rest

import logging
import socket
import subprocess
import urllib2


# NOTE(review): neither constant is referenced in this chunk; the meanings below are
# inferred from the names only — confirm against the subclasses/callers that use them.
# Presumably a fixed disk overhead ("tax"), in GB, added to migrated disk quotas.
DISK_TAX_IN_GB = 1
# Presumably the disk size, in GB, relative to which DISK_TAX_IN_GB is (or is not) applied.
BOUND_FOR_DISK_TAX_IN_GB = 10


def enum(**enums):
    """Create an ad-hoc enum-like class.

    Every keyword argument becomes a class attribute of a new class named
    ``Enum``; e.g. ``enum(A='A').A == 'A'``.
    """
    attributes = dict(enums)
    enum_class = type('Enum', (), attributes)
    return enum_class


class MigratorBase(object):
    """Shared base for migrators that build a Deploy ``TStage`` from a source service spec.

    Holds the source ``spec`` and the ``TStage`` being built, checks which parts of
    the source spec are covered by an "implemented map", provides helpers to find or
    create boxes/workloads/secrets inside the stage, and wraps the external calls a
    migration needs (DNS, HTTP, Sandbox, YAV, ssh). When ``mock_clients`` is true the
    DNS/HTTP/Sandbox helpers return canned values instead of touching the network.
    """

    # Stage ids longer than this are truncated by ``_get_stage_id``.
    MAX_STAGE_ID_LENGTH = 40

    class NotImplementedException(Exception):
        """A part of the source spec the migrator does not support.

        ``path`` is the list of keys/indices leading to the offending field.
        """

        def __init__(self, message, path):
            super(MigratorBase.NotImplementedException, self).__init__('%s at %s' % (message, path))
            self.path = path

    class NotImplementedExceptionList(Exception):
        """Aggregate of several ``NotImplementedException`` errors, joined with '; '."""

        def __init__(self, errors):
            super(MigratorBase.NotImplementedExceptionList, self).__init__(u'; '.join(unicode(e) for e in errors))
            self.errors = errors

    # String leaves / special keys of an implemented map (see _migrate_with_implementation_map).
    EImplemented = enum(
        UNKNOWN='UNKNOWN',  # as a leaf: reported as an error (any value except SKIP/IMPLEMENTED is)
        NOT_IMPLEMENTED='NOT_IMPLEMENTED',  # as a leaf: reported as an error
        SKIP='SKIP',  # as a leaf: field ignored, only logged at INFO level
        IMPLEMENTED='IMPLEMENTED',  # as a leaf: accepted; as a key next to a list map: check for the whole list
        IMPLEMENTED_LIST_FOR_EACH='IMPLEMENTED_LIST_FOR_EACH',  # key: sub-map applied to every list element
    )

    def __init__(self, migrator_name, spec, stage_id, raise_on_error, implemented_map, mock_clients):
        """
        :param migrator_name: written to the stage as the ``migration_type`` label.
        :param spec: source spec (nested dicts/lists) to migrate.
        :param stage_id: explicit stage id; when falsy it is derived from the object id.
        :param raise_on_error: if true, ``_migrate_spec`` raises on unsupported fields
            instead of only logging them.
        :param implemented_map: coverage map that ``_migrate_spec`` checks ``spec`` against.
        :param mock_clients: if true, DNS/HTTP/Sandbox helpers return canned values.
        """
        self.migrator_name = migrator_name
        self.spec = spec
        self.stage = autogen_pb2.TStage()
        self.__stage_id = stage_id
        self.__raise_on_error = raise_on_error
        self.__implemented_map = implemented_map
        self.__mock_clients = mock_clients

    def _assert_equal_one_of(self, expected_set):
        """Return ``f(value, path)`` raising ``NotImplementedException`` if ``value`` is not in ``expected_set``."""
        def f(value, path):
            if value not in expected_set:
                msg = "%s: field = '%s', supported only %s" % (self.EImplemented.NOT_IMPLEMENTED, value, expected_set)
                raise self.NotImplementedException(msg, path)
        return f

    def _warn_if_not_equal_one_of(self, expected_set, message):
        """Like ``_assert_equal_one_of``, but logs a warning prefixed with ``message`` instead of raising."""
        def f(value, path):
            try:
                self._assert_equal_one_of(expected_set)(value, path)
            except self.NotImplementedException as e:
                logging.warning('%s: %s', message, e)
        return f

    def _assert_equal_one_of_enum(self, expected_set):
        """Like ``_assert_equal_one_of``, but appends the offending value to the exception's ``path``."""
        def f(value, path):
            try:
                self._assert_equal_one_of(expected_set)(value, path)
            except self.NotImplementedException as e:
                # Rebind rather than ``+=``: ``e.path`` is the caller's list object, and an
                # in-place extend would corrupt paths the caller keeps using afterwards.
                e.path = e.path + [value]
                raise
        return f

    def _assert_equal(self, expected):
        """Single-value shorthand for ``_assert_equal_one_of``."""
        return self._assert_equal_one_of([expected])

    def _warn_if_not_equal(self, expected, message):
        """Single-value shorthand for ``_warn_if_not_equal_one_of``."""
        return self._warn_if_not_equal_one_of([expected], message)

    def _get_pod_agent_payload_spec(self, unit_id):
        """Return the pod agent payload spec of deploy unit ``unit_id``."""
        return self._get_pod_template_spec_spec(unit_id).pod_agent_payload.spec

    def _get_pod_template_spec_spec(self, unit_id):
        """Return the pod template spec of deploy unit ``unit_id`` (multi-cluster replica set)."""
        return self.stage.spec.deploy_units[unit_id].multi_cluster_replica_set.replica_set.pod_template_spec.spec

    def _get_box(self, box_id):
        """Return the box ``box_id``, creating it if absent.

        ``box_id`` is also used as the deploy unit id the box is looked up in.
        """
        payload_spec = self._get_pod_agent_payload_spec(box_id)
        box = None
        for candidate in payload_spec.boxes:
            if candidate.id == box_id:
                box = candidate
        if box is None:
            box = payload_spec.boxes.add()
        box.id = box_id
        return box

    def _get_workload(self, unit_id, workload_id=None):
        """Return workload ``workload_id`` (default: ``unit_id``) of unit ``unit_id``, creating it if absent.

        The workload's ``box_ref`` is always (re)set to ``unit_id``.
        """
        if workload_id is None:
            workload_id = unit_id

        payload_spec = self._get_pod_agent_payload_spec(unit_id)
        workload = None
        for candidate in payload_spec.workloads:
            if candidate.id == workload_id:
                workload = candidate
        if workload is None:
            workload = payload_spec.workloads.add()
        workload.id = workload_id
        workload.box_ref = unit_id
        return workload

    def _resolve_ip6(self, host):
        """Return the first IPv6 address of ``host`` (``'::1'`` when clients are mocked)."""
        if self.__mock_clients:
            return '::1'
        return socket.getaddrinfo(host, None, socket.AF_INET6, 0, socket.IPPROTO_TCP)[0][4][0]

    def _get_url_content(self, url):
        """Return the body fetched from ``url`` (``''`` when clients are mocked)."""
        if self.__mock_clients:
            return ''
        response = urllib2.urlopen(url)
        try:
            return response.read()
        finally:
            # Release the underlying connection even if read() fails.
            response.close()

    def _get_sandbox_resource(self, id):
        """Return the Sandbox resource description for resource ``id``."""
        if self.__mock_clients:
            return {
                'http': {
                    'proxy': 'https://proxy.sandbox.yandex-team.ru/' + str(id),
                },
            }
        return rest.Client().resource[id].read()

    def _get_sandbox_resources_from_task_by_type(self, id, type):
        """Return the first resource of type ``type`` produced by Sandbox task ``id``.

        :raises Exception: if the task has no resource of that type.
        """
        if self.__mock_clients:
            return {
                'http': {
                    'proxy': 'https://proxy.sandbox.yandex-team.ru/' + str(id),
                },
                'skynet_id': 'rbtorrent:blabla' + str(id),
                'id': id,
                'type': type
            }
        resources = rest.Client().task[id].resources.read()['items']
        for resource in resources:
            if resource['type'] == type:
                return resource
        raise Exception("Could not find resource (maybe, it has been removed?) type %s at resources list %s" % (type, resources))

    def _migrate_with_implementation_map(self, spec, implemented_map, path):
        """Check ``spec`` against ``implemented_map`` and return the list of unsupported fields.

        ``implemented_map`` is interpreted as follows:

        * callable: called as ``implemented_map(spec, path)``; a raised
          ``NotImplementedException``/``NotImplementedExceptionList`` becomes the result;
        * str: ``SKIP`` is logged and accepted, ``IMPLEMENTED`` is accepted, anything
          else is an error;
        * dict, with a dict ``spec``: every key of ``spec`` must be in the map and is
          checked recursively;
        * dict, with a list ``spec``: the map must contain ``IMPLEMENTED_LIST_FOR_EACH``
          (applied to every element) and may contain ``IMPLEMENTED`` (applied to the
          whole list).

        Any other ``spec`` type is an error.

        :param path: list of keys/indices leading to ``spec``; not modified.
        :returns: list of ``NotImplementedException``.
        """
        if callable(implemented_map):
            try:
                implemented_map(spec, path)
            except self.NotImplementedException as e:
                return [e]
            except self.NotImplementedExceptionList as e:
                return e.errors
            return []
        elif isinstance(implemented_map, str):
            if implemented_map == self.EImplemented.SKIP:
                logging.info('%s: field %s', implemented_map, path)
            elif implemented_map != self.EImplemented.IMPLEMENTED:
                return [self.NotImplementedException('%s: field %s' % (implemented_map, path), path)]
            return []

        errors = []
        if isinstance(spec, dict):
            if not isinstance(implemented_map, dict):
                return [self.NotImplementedException('unexpected dict at %s' % path, path)]
            for field in spec:
                sub_path = path + [field]
                if field not in implemented_map:
                    errors.append(self.NotImplementedException('field %s is not at IMPLEMENTED_MAP' % (sub_path), sub_path))
                    continue
                errors.extend(self._migrate_with_implementation_map(spec[field], implemented_map[field], sub_path))
        elif isinstance(spec, list):
            if not isinstance(implemented_map, dict) or self.EImplemented.IMPLEMENTED_LIST_FOR_EACH not in implemented_map:
                return [self.NotImplementedException('unexpected list at %s' % path, path)]
            if self.EImplemented.IMPLEMENTED in implemented_map:
                errors.extend(self._migrate_with_implementation_map(spec, implemented_map[self.EImplemented.IMPLEMENTED], path))
            for_each_map = implemented_map[self.EImplemented.IMPLEMENTED_LIST_FOR_EACH]
            for index, item in enumerate(spec):
                errors.extend(self._migrate_with_implementation_map(item, for_each_map, path + [index]))
        else:
            return [self.NotImplementedException('unexpected field %s' % path, path)]
        return errors

    def _get_stage_id(self, object_id):
        """Return the explicit stage id, or ``object_id`` with '.' replaced by '_'.

        The result is truncated to ``MAX_STAGE_ID_LENGTH`` characters (with a warning).
        """
        result = self.__stage_id or object_id.replace('.', '_')
        if len(result) > self.MAX_STAGE_ID_LENGTH:
            result = result[:self.MAX_STAGE_ID_LENGTH]
            logging.warning("Name shortened to %s, to change use --stage-id", result)
        return result

    def _set_stage_id(self, object_id, path):
        """Fill stage id, project id and the ``migration_from``/``migration_type`` labels.

        The project id is the part of ``object_id`` before the first '.'.
        ``path`` is not used by this implementation.
        """
        self.stage.meta.id = self._get_stage_id(object_id)
        self.stage.meta.project_id = str(object_id.split('.')[0])

        label = self.stage.labels.attributes.add()
        label.key = 'migration_from'
        label.value = str(object_id)
        label = self.stage.labels.attributes.add()
        label.key = 'migration_type'
        label.value = str(self.migrator_name)

    def _create_yav_token(self, client, secret_uuid, object_id, unit_id):
        """Create a YAV delegation token for ``secret_uuid`` bound to the unit's pod set and return it."""
        secret_delegation_token, _ = client.create_token(
            secret_uuid,
            tvm_client_id=2001151,  # YP
            signature=self._get_stage_id(object_id) + '.' + unit_id,  # pod_set_id
            comment='token for Nanny/Qloud -> Deploy migrated env %s' % unit_id
        )
        return secret_delegation_token

    def _create_yav_secret(self, client, yav_secret_alias, path):
        """Return the uuid of the secret named ``yav_secret_alias``, creating the secret if none exists.

        :raises NotImplementedException: if more than one secret has that name.
        """
        sec_list = client.list_secrets()
        secret_uuid = None
        for sec in sec_list:
            if sec['name'] == yav_secret_alias:
                if secret_uuid:
                    raise self.NotImplementedException('found two secret uuid with "%s" alias, delete one or specify another' % yav_secret_alias, path)
                secret_uuid = sec['uuid']
        if not secret_uuid:
            secret_uuid = client.create_secret(yav_secret_alias)
        return secret_uuid

    def _create_vault_token(self, client, secret_id, secret_version, object_id, unit_id, deploy_secret_id):
        """Register secret ``deploy_secret_id`` in the unit's pod spec with a fresh delegation token."""
        secret_delegation_token = self._create_yav_token(
            client=client,
            secret_uuid=secret_id,
            object_id=object_id,
            unit_id=unit_id
        )
        secret = self._get_pod_template_spec_spec(unit_id).secrets[deploy_secret_id]
        secret.secret_id = secret_id
        secret.secret_version = secret_version
        secret.delegation_token = secret_delegation_token

    def _run_remote_call_on_instance(self, instance, command):
        """Run ``command`` on ``instance`` over ssh and return its stdout.

        :raises Exception: if ssh cannot be started or exits with a non-zero status.
        """
        sub_command = ['ssh', instance, command]
        logging.info(sub_command)
        try:
            return subprocess.check_output(sub_command)
        except (subprocess.CalledProcessError, OSError) as e:
            # Only wrap ssh launch/exit failures; programming errors propagate unchanged.
            raise Exception(
                'Error at ssh call. Please check that instance is active and accessible from current working machine (check Puncher rules). Error: %s' % e,
            )

    def _migrate_spec(self):
        """Check ``self.spec`` against the implemented map; raise or log the unsupported fields."""
        errors = self._migrate_with_implementation_map(self.spec, self.__implemented_map, path=[])
        if errors:
            if self.__raise_on_error:
                raise self.NotImplementedExceptionList(errors)
            for e in errors:
                logging.warning(e)
