import requests
import logging

import sandbox.sandboxsdk.task as sdk_task
import sandbox.sandboxsdk.parameters as sdk_parameters

from sandbox.projects.common.utils import receive_skynet_key


class Nanny(object):
    """Minimal read-only client for the Nanny v2 HTTP API.

    Wraps only the calls needed to compare the first two entries of a
    service's ``active_snapshots`` list and classify hosts from them.
    """

    url = 'https://nanny.yandex-team.ru/v2'
    # Per-request timeout in seconds; without one ``requests`` may block forever.
    timeout = 60

    def __init__(self, token):
        """Create a client.

        :param token: OAuth token sent as ``Authorization: OAuth <token>``.
        """
        self.token = token

    def _get(self, path):
        """GET ``self.url + path`` and return the decoded JSON body.

        :param path: API path starting with ``/``; appended to ``self.url``.
        :raises requests.HTTPError: on a non-2xx response, so auth/routing
            failures are reported as such instead of as a later ``KeyError``
            on the expected JSON fields.
        """
        response = requests.get(
            self.url + path,
            headers={'Authorization': 'OAuth ' + self.token},
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def load_snapshots(self, service):
        """Return ``content.active_snapshots`` from the service's current state."""
        return self._get(
            '/services/{}/current_state/'.format(service)
        )['content']['active_snapshots']

    def load_instances(self, service, snapshot_id, limit=200):
        """Return all instances of *snapshot_id*, fetched page by page.

        Pages of *limit* items are requested with an increasing ``skip``
        offset until the API returns an empty page.

        :param limit: page size, 1..200.
        :raises ValueError: if *limit* is out of range.
        """
        if not 1 <= limit <= 200:
            raise ValueError('limit must be between 1 and 200, got {}'.format(limit))
        instances = []
        while True:
            page = self._get(
                '/services/instances/{}/sn/{}/?limit={}&skip={}'.format(
                    service, snapshot_id, limit, len(instances))
            )['instances']
            if not page:
                return instances
            instances += page

    @classmethod
    def _validate_snapshots(cls, snapshots):
        """Return True if *snapshots* can be compared by all_and_suspicious_hosts.

        Requires at least two snapshots, the first one in state ``ACTIVE``
        and no other snapshot in state ``ACTIVE``.
        """
        if len(snapshots) < 2:
            return False
        if snapshots[0]['state'] != 'ACTIVE':
            return False
        active_count = sum(1 for sn in snapshots if sn['state'] == 'ACTIVE')
        return active_count == 1

    def all_and_suspicious_hosts(self, service):
        """Return ``(all_hosts, suspicious_hosts)`` for *service*.

        ``snapshots[0]`` is treated as the fresh snapshot and ``snapshots[1]``
        as the old one. A host is suspicious when it is ``UNKNOWN`` in the old
        snapshot and not ``ACTIVE``/``HOOK_SEMI_FAILED`` in the fresh one.
        ``all_hosts`` is the union of hosts from both snapshots.

        When the snapshots fail ``_validate_snapshots``, a pair of empty sets
        is returned so callers can always unpack the result.
        """
        snapshots = self.load_snapshots(service)
        if not self._validate_snapshots(snapshots):
            logging.info('%s: snapshots are not comparable, nothing to check', service)
            return set(), set()
        fresh = self.load_instances(service, snapshots[0]['snapshot_id'])
        old = self.load_instances(service, snapshots[1]['snapshot_id'])

        fresh_good_hosts = {
            inst['host'] for inst in fresh
            if inst['current_state'] in {'ACTIVE', 'HOOK_SEMI_FAILED'}
        }
        old_bad_hosts = {
            inst['host'] for inst in old
            if inst['current_state'] == 'UNKNOWN'
        }
        all_hosts = {inst['host'] for inst in old + fresh}
        return all_hosts, old_bad_hosts - fresh_good_hosts


def run_clean_up_psi(hosts, service_name, psi_namespace):
    """Run a CleanUpPsi job on every host in *hosts* through skynet cqueue.

    Per-host outcomes are only logged: a failed host is reported with
    ``logging.error`` and does not raise here.
    """
    # Imported locally; presumably because api.cqueue is only importable
    # where skynet is installed (NOTE(review): confirm).
    import api.cqueue

    job = CleanUpPsi(service_name, psi_namespace)
    client = api.cqueue.Client(implementation='cqudp')
    with client.run(hosts, job) as session:
        for host, created, error in session.wait():
            if not error:
                logging.info('%s: created %i files', host, created)
                continue
            logging.error('%s: skynet process failed: %s', host, error)


class CleanUpPsi(object):
    """Callable cleanup job: flag matching PSI instances and destroy a porto container.

    run_clean_up_psi hands an instance of this class to ``api.cqueue`` for
    each target host. ``os`` and ``porto`` are imported inside the methods
    rather than at module level (presumably so they resolve on the remote
    host; NOTE(review): confirm).
    """

    # Flag file created inside every matching instance directory.
    psi_disable_file_name = 'psi_disable.flag'
    # Directory whose entries are scanned for matching instance directories.
    iss_instances_root = '/place/db/iss3/instances/'

    def __init__(self, service_name, psi_namespace):
        self.service = service_name
        self.namespace = psi_namespace

    def _dir_match(self, name):
        """Return True if *name* has the form ``<digits>_<service>_<11 chars>``.

        The first ``_<service>_`` occurrence is used as the separator.
        """
        if self.service not in name:
            return False
        separator = '_' + self.service + '_'
        prefix, _sep, suffix = name.partition(separator)
        return prefix.isdigit() and len(suffix) == 11

    @staticmethod
    def _touch_file(filename):
        """Create *filename* if missing (opened in append mode, nothing written)."""
        logging.debug('touch %s', filename)
        with open(filename, 'a'):
            pass

    def put_psi_disable_file(self):
        """Create the disable flag in each matching instance dir lacking it.

        :returns: number of flag files newly created.
        """
        import os

        root = self.iss_instances_root
        count = 0
        for entry in os.listdir(root):
            if not self._dir_match(entry):
                continue
            flag_path = os.path.join(root, entry, self.psi_disable_file_name)
            if os.path.exists(flag_path):
                continue
            self._touch_file(flag_path)
            count += 1
        return count

    def destroy_namespace(self):
        """Destroy the porto container ``self.namespace``; a missing one is ignored."""
        import porto

        connection = porto.Connection()
        try:
            connection.Destroy(self.namespace)
        except porto.exceptions.ContainerDoesNotExist:
            pass

    def __call__(self):
        """Place flag files, then destroy the namespace; return the flag count."""
        flags_created = self.put_psi_disable_file()
        self.destroy_namespace()
        return flags_created


class CleanUpBadPsiContainers(sdk_task.SandboxTask):
    """Sandbox task that runs the PSI cleanup on hosts Nanny reports as suspicious.

    "Suspicious" is defined by ``Nanny.all_and_suspicious_hosts``. The cleanup
    is skipped when the suspicious set is too large (relative or absolute),
    empty, or when ``no_action`` is set.
    """

    class NannyToken(sdk_parameters.SandboxStringParameter):
        name = "nanny_oauth_token"
        description = "Nanny OAuth token name (in sandbox vault)"
        required = True

    class NannyTokenOwner(sdk_parameters.SandboxStringParameter):
        name = "nanny_token_owner"
        description = "nanny token owner"
        required = True

    class PsiNannyService(sdk_parameters.SandboxStringParameter):
        name = "psi_nanny_service"
        description = "psi service name in nanny"
        required = True

    class PsiNamespace(sdk_parameters.SandboxStringParameter):
        name = "psi_namespace"
        description = "psi porto namespace"
        required = True

    class SuspiciousThreshold(sdk_parameters.SandboxFloatParameter):
        name = "suspicious_threshold"
        description = "suspicious count threshold"
        required = True

    class MaxSuspiciousHosts(sdk_parameters.SandboxIntegerParameter):
        name = "max_suspicious_hosts"
        description = "suspicious count max"
        required = True

    class NoAction(sdk_parameters.SandboxBoolParameter):
        name = "no_action"
        description = "No action"
        required = True

    type = "CLEAN_UP_BAD_PSI_CONTAINERS"

    input_parameters = [
        NannyToken,
        NannyTokenOwner,
        PsiNannyService,
        PsiNamespace,
        SuspiciousThreshold,
        MaxSuspiciousHosts,
        NoAction
    ]

    @staticmethod
    def _skip_reason(suspicious_count, total_count, threshold, max_suspicious_hosts, no_action):
        """Return the log message explaining why cleanup must not run, or None.

        Checks are applied in order: fraction of suspicious hosts above
        *threshold*, absolute count above *max_suspicious_hosts*, nothing to
        clean, and finally the ``no_action`` dry-run switch.
        """
        if suspicious_count > total_count * threshold:
            return 'too many suspicious hosts (>threshold), exiting'
        if suspicious_count > max_suspicious_hosts:
            return 'too many suspicious hosts (>max), exiting'
        if not suspicious_count:
            return 'no suspicious hosts, exiting'
        if no_action:
            return 'no_action == True, exiting'
        return None

    def on_execute(self):
        """Find suspicious PSI hosts via Nanny and clean them up when it is safe.

        The suspicious hosts are written to ``ctx['suspicious_hosts']``
        (sorted, newline-separated) before the safety checks, so they are
        recorded even when the task decides not to act.
        """
        receive_skynet_key(self.owner)
        token = self.get_vault_data(
            self.ctx[self.NannyTokenOwner.name],
            self.ctx[self.NannyToken.name],
        )
        service_name = self.ctx[self.PsiNannyService.name]
        psi_namespace = self.ctx[self.PsiNamespace.name]
        threshold = self.ctx[self.SuspiciousThreshold.name]
        max_suspicious_hosts = self.ctx[self.MaxSuspiciousHosts.name]
        no_action = self.ctx[self.NoAction.name]
        nanny = Nanny(token)

        all_hosts, suspicious_hosts = nanny.all_and_suspicious_hosts(service_name)
        # Sorted so the log and the stored context are stable across runs
        # (set iteration order is arbitrary).
        hosts_text = '\n'.join(sorted(suspicious_hosts))
        logging.info('suspicious hosts (%i/%i): \n%s', len(suspicious_hosts), len(all_hosts), hosts_text)
        self.ctx['suspicious_hosts'] = hosts_text

        reason = self._skip_reason(
            len(suspicious_hosts), len(all_hosts), threshold, max_suspicious_hosts, no_action,
        )
        if reason is not None:
            logging.info(reason)
            return
        logging.info('running cleanup')
        run_clean_up_psi(suspicious_hosts, service_name, psi_namespace)


# Module-level alias for the task class; presumably the name the Sandbox
# task loader looks up in this module (NOTE(review): confirm against sandboxsdk).
__Task__ = CleanUpBadPsiContainers
