from sandbox import sdk2
from sandbox.common.errors import TaskFailure, TemporaryError
from sandbox.common.types.misc import NotExists
from sandbox.common.types.task import Status
from sandbox.projects import resource_types
from sandbox.projects.saas.common.classes import SaasBinaryTask
from sandbox.sandboxsdk.channel import channel

import sandbox.common.types.task as ctt

from kazoo.client import KazooClient
from random import shuffle

import collections
import math
import time
import datetime
import json
import logging
import urllib2
import os

# Nanny endpoint used to list the current instances of a service; '%s' is replaced
# with the service name in get_instances_from_nanny().
instances_url_nanny = 'http://nanny.yandex-team.ru/v2/services/%s/current_state/instances/'
# saas-dm-proxy endpoint used by get_instances_from_saas_dm(); the two '{}' placeholders
# are filled with (service ctype, service name), in that order.
# NOTE(review): the query parameter is spelled 'service_trype' - this looks like a typo
# for 'service_type'; confirm against the saas-dm-proxy API before changing it.
instances_url_saas_dm = 'http://saas-dm-proxy.n.yandex-team.ru/api/slots_by_interval/?service_trype=rtyserver&ctype={}&service={}'

# Services for which deployRefreshFrozen() creates a DEPLOY_REFRESH_FROZEN subtask.
REFRESH_SERVICES = (
    'saas_refresh_3day_production_base_multilang',
    'saas_refresh_10day_production_base_multilang',
    'saas_refresh_production_video_base',
    'saas_refresh_quick2_experiment_base',
)


try:
    # Python 3.3+: the container ABCs live in ``collections.abc``; the aliases
    # on ``collections`` itself were removed in Python 3.10.
    from collections import abc as _collections_abc
except ImportError:
    # Python 2: the ABCs are exposed directly on the ``collections`` module.
    _collections_abc = collections


class ShardInfo(_collections_abc.MutableMapping):
    """Replicas of one shard grouped by datacenter, plus shard metadata.

    Behaves like a dict mapping a datacenter name to a list of
    ``'host:port'`` replica addresses. The mapping itself is kept in
    ``store`` so that it can be serialized separately from the metadata.
    """

    def __init__(self):
        # Numeric shard id; -1 means "not assigned".
        self.shardId = -1
        # Passed to the child task as ``copy_full_range``; set to False for
        # the "smaller shard" mode.
        self.copyFullRange = True
        # datacenter -> list of replica addresses
        self.store = {}

    def __getitem__(self, key):
        return self.store[key]

    def __setitem__(self, key, value):
        self.store[key] = value

    def __delitem__(self, key):
        del self.store[key]

    def __iter__(self):
        return iter(self.store)

    def __len__(self):
        return len(self.store)

    def __repr__(self):
        # Readable representation: instances of this class end up in error
        # messages via ``str()`` of the enclosing dict.
        return 'ShardInfo(shardId={!r}, copyFullRange={!r}, store={!r})'.format(
            self.shardId, self.copyFullRange, self.store)


def sendRequest(url, timeout=1):
    """Fetch ``url`` with a GET request and return the response body.

    ``http://`` is prepended when ``url`` does not start with an HTTP(S)
    scheme. The request is attempted up to three times with a 5 second
    pause between attempts; every failure is logged.

    :param url: full URL or ``host:port/path`` without a scheme
    :param timeout: per-attempt timeout in seconds
    :return: response body as a string, or ``False`` if all attempts failed
    """
    # Check the prefix only: a URL may legitimately contain 'http://' later
    # on (e.g. inside a query string) and still lack a scheme.
    if not url.startswith(('http://', 'https://')):
        url = 'http://' + url
    attempts = 3
    for attempt in range(attempts):
        try:
            reply = urllib2.urlopen(url, timeout=timeout)
            try:
                return reply.read()
            finally:
                # Release the underlying socket even if read() fails.
                reply.close()
        except Exception as e:
            logging.warning('url: ' + url)
            logging.exception(e)
        # No point in waiting when there is no further attempt.
        if attempt + 1 < attempts:
            time.sleep(5)
    return False


class ZKClient(object):
    """Context manager that yields a started ``KazooClient``.

    The client is created on entry (up to three start attempts) and stopped
    on exit. ``TaskFailure`` is raised when no attempt succeeds.
    """

    # Comma-separated list of zookeeper hosts passed to KazooClient.
    ZK_HOSTS = ','.join((
        'saas-zookeeper1.search.yandex.net:14880',
        'saas-zookeeper2.search.yandex.net:14880',
        'saas-zookeeper3.search.yandex.net:14880',
        'saas-zookeeper4.search.yandex.net:14880',
        'saas-zookeeper5.search.yandex.net:14880',
    ))

    def __enter__(self):
        self.kz = KazooClient(hosts=self.ZK_HOSTS, read_only=False)
        attempts_left = 3
        while attempts_left > 0:
            attempts_left -= 1
            try:
                self.kz.start()
                return self.kz
            except Exception as err:
                # Log the failure and fall through to the next attempt.
                logging.warning(err)
        raise TaskFailure('Failed to establish zookeeper connection')

    def __exit__(self, type, value, traceback):
        # Exceptions from the ``with`` body are not suppressed.
        self.kz.stop()


def parseServiceDatacenters(serviceState):
    """Collect the datacenter names mentioned in instance itags.

    :param serviceState: iterable of instance dicts, each with an ``'itags'``
        list of strings
    :return: set of the last ``_``-separated component of every tag that
        starts with ``a_dc_``
    """
    found = set()
    for instance in serviceState:
        found.update(
            itag.rsplit('_', 1)[-1]
            for itag in instance['itags']
            if itag.startswith('a_dc_')
        )
    return found


def getServicePods(service_id, datacenters):
    """List all pods of a service across the given datacenters.

    Pods are requested page by page (500 per request) for every datacenter.

    :param service_id: service id passed to ``Pod.list``
    :param datacenters: iterable of datacenter names (e.g. parsed from itags)
    :return: list with the pods of all datacenters
    """
    from saas.library.python.nanny_proto import Pod

    limit = 500
    pods = []
    for dc in datacenters:
        # The upper-cased name was computed but never passed in the original
        # code. NOTE(review): assumes Pod.list expects upper-case cluster
        # names - confirm against the nanny_proto library.
        cluster = dc.upper()
        offset = 0

        while True:
            page = list(Pod.list(cluster=cluster, service_id=service_id, limit=limit, offset=offset))
            pods.extend(page)
            # Decide on the size of the current page only: ``pods`` also holds
            # the pods of previously processed datacenters, so its total
            # length says nothing about whether this cluster has more pages.
            if len(page) < limit:
                break

            offset += limit

    return pods


class DetachServiceIndex2(SaasBinaryTask):
    """ Detach rtyserver index for some service

    The task discovers the shards and replicas of the service, starts one
    DETACH_RTYSERVER_INDEX child task per shard, waits for the children and
    retries failed ones. Optionally it stores a backup record in zookeeper
    and writes the detached state into a YT table.
    """
    TASKS_RESOURCE_NAME = 'SaasSandboxTasks'

    class Requirements(sdk2.Task.Requirements):
        cores = 1
        ram = 2048
        disk_space = 1024

        class Caches(sdk2.Requirements.Caches):
            pass

    class Context(sdk2.Task.Context):
        # Set to True once the child tasks have been created.
        initialized = False

    class Parameters(sdk2.Task.Parameters):
        service_name = sdk2.parameters.String('service name for detach index')
        is_backup = sdk2.parameters.Bool('is backup (if True, ttl will be set for _ALL_ backups)', default=False)
        backups_min_count = sdk2.parameters.Integer('minimum number of backups', default=4)
        resource_ttl = sdk2.parameters.Integer('resource ttl in days', default=7)
        max_download_speed = sdk2.parameters.String('Max download speed (bytes or mbps e.g. "10M (10Mbytes) or 10mbps (10 Mbits)")', default='300mbps')
        delay_threshold = sdk2.parameters.Float('not create the resource if the delay is more than this (hours)', default=7*24)
        save_all_indexes = sdk2.parameters.Bool('True - save all indexes from shard; False - save largest index from shard', default=True)
        correct_replica_pattern = sdk2.parameters.String('RegExp of correct replicas', default='')
        deploy_refresh_frozen = sdk2.parameters.Bool('True - create task for deploy frozen refresh; False - do not create task for deploy frozen refresh', default=False)
        download_timeout = sdk2.parameters.Integer('timeout for download index (min)', default=60)
        get_torrent_timeout = sdk2.parameters.Integer('timeout for get index torrent (min)', default=60)
        register_iss_shards = sdk2.parameters.Bool('Register shards in ISS', default=False)
        iss_shardname_template = sdk2.parameters.String('ISS shard name template, must have variables: shard, state')
        yt_table = sdk2.parameters.String('Yt table path for shards states')
        smaller_shard = sdk2.parameters.Integer('Detach a single shard from the first shard if |smaller_shard| > 0. Detached shard will be |smaller_shard| times smaller than the first shard',
                                                default=0)
        zk_backup_path_prefix = sdk2.parameters.String('Zookeeper path prefix for backups', default='/indexBackups/')
        yt_token_vault_name = sdk2.parameters.String('name of vault secret for yt token', default='YT_TOKEN_ARNOLD')
        yt_token_vault_owner = sdk2.parameters.String('owner of vault secret for yt token or empty for task owner', default='')
        yt_cluster = sdk2.parameters.String('yt_cluster', default='arnold')
        use_deploy_manager = sdk2.parameters.Bool('Get instances from deploy_manager', default=False)
        service_ctype = sdk2.parameters.String('service ctype', default='')
        yp_pod_label_with_shardid = sdk2.parameters.String('Custom yp pod label with shardid. default=shard', default='shard')
        not_check_replica = sdk2.parameters.Bool('do not check that the replica is correct', default=False)
        nanny_token_vault_name = sdk2.parameters.String('name of vault secret for nanny token', default='')
        nanny_token_vault_owner = sdk2.parameters.String('owner of vault secret for nanny token or empty for task owner', default='')

    def get_zk_backup_path(self):
        """Return the zookeeper node under which backups of the service live.

        The node name is the service name; in deploy-manager mode
        ``'#' + service_ctype`` is appended to it.
        """
        service_id = self.Parameters.service_name
        if self.Parameters.use_deploy_manager:
            service_id += '#' + self.Parameters.service_ctype

        znode = os.path.join(self.Parameters.zk_backup_path_prefix, service_id)
        return znode

    def addBackupToStorage(self, backup):
        """Store a backup record in zookeeper.

        :param backup: dict with ``'timestamp'`` (used as the node name) and
            ``'data'`` (the node payload), as built by createBackupData()
        :raises TaskFailure: if the node can not be created or written
        """
        with ZKClient() as zk:
            # Only path components belong here; the original code also passed
            # ``self`` as the first component, which made the call raise.
            znode = os.path.join(self.get_zk_backup_path(), str(backup['timestamp']))
            try:
                if not zk.exists(znode):
                    zk.create(znode, makepath=True)
                zk.set(znode, backup['data'])
                logging.info('Backup added: ' + znode)
            except Exception as e:
                logging.exception(e)
                raise TaskFailure('Can not save backup in zookeeper')

    def createBackupData(self, resources):
        """Build the backup record for the given index resources.

        :param resources: resources with ``attributes['shard']`` (``'min-max'``),
            ``attributes['index_timestamp']`` and ``skynet_id``
        :return: dict with ``'timestamp'`` (the smallest index timestamp among
            the resources, 0 if there are none) and ``'data'`` (serialized
            ``TShards`` message)
        """
        from saas.protos.shards_pb2 import TShards
        shards = TShards()
        timestamp = 0
        for resource in resources:
            shard = shards.Shard.add()
            shard.Name = 'backup_' + resource.attributes['shard'] + '_' + resource.attributes['index_timestamp']
            shardValue = resource.attributes['shard'].split('-')
            shard.ShardMin = int(shardValue[0])
            shard.ShardMax = int(shardValue[1])
            shard.Timestamp = long(resource.attributes['index_timestamp'])
            shard.Torrent = resource.skynet_id
            # The backup is as old as its oldest shard.
            if timestamp == 0 or timestamp > shard.Timestamp:
                timestamp = shard.Timestamp
        logging.info('backup data:\n' + str(shards))
        return {'timestamp': timestamp, 'data': shards.SerializeToString()}

    def doBackup(self):
        """Register the resources produced by the child tasks as one backup.

        :raises TaskFailure: if a child task has no READY
            RTYSERVER_SEARCH_DATABASE resource
        """
        resources = []
        for shardRange, taskId in self.Context.task_by_shard.items():
            childTaskResources = channel.rest.list_resources(
                resource_types.RTYSERVER_SEARCH_DATABASE,
                status='READY',
                task_id=taskId
            )
            if not childTaskResources:
                raise TaskFailure('Can not find ready resource for backup for shard: ' + shardRange)
            resources.append(childTaskResources[0])
        backup = self.createBackupData(resources)
        self.addBackupToStorage(backup)

    def deleteOldBackups(self, backupTTL):
        """Delete expired backup records from zookeeper.

        At least ``backups_min_count`` newest records are always kept; of the
        remaining (oldest) ones only those older than ``backupTTL`` days are
        deleted.

        :param backupTTL: record time to live in days
        """
        now = time.time()
        ttlInSec = backupTTL * 24 * 60 * 60
        backupsPathForService = self.get_zk_backup_path()
        with ZKClient() as zk:
            # Node names are timestamps, so sorting puts the oldest first.
            backupTimestamps = sorted(zk.get_children(backupsPathForService))
            backupsCount = len(backupTimestamps)
            logging.info('List of backups:\n' + str(backupTimestamps))
            countToDelete = max(0, backupsCount - self.Parameters.backups_min_count)
            for backupTimestamp in backupTimestamps[:countToDelete]:
                if long(backupTimestamp) < now - ttlInSec:
                    znode = os.path.join(backupsPathForService, backupTimestamp)
                    if zk.delete(znode):
                        logging.info('Old backup deleted: ' + znode)
                    else:
                        logging.warning('Can not delete old backup: ' + znode)

    def getShardMinValue(self, shardId):
        """Return the lower bound of the range of shard ``shardId``."""
        # Explicit floor division: the result must be an integer regardless of
        # how ``/`` behaves for ints in the running interpreter.
        return int(65533 * shardId // self.Context.shard_count)

    def getShardMaxValue(self, shardId):
        """Return the upper bound (inclusive) of the range of shard ``shardId``."""
        shardCount = self.Context.shard_count
        # The last term is 1 only when shardId + 1 == shardCount, which makes
        # the last shard end at 65533 instead of 65532.
        return int(65533 * (shardId + 1) // shardCount - 1 + (shardId + 1) // shardCount)

    def getShardRangeString(self, shardId):
        """Return the range of shard ``shardId`` formatted as ``'min-max'``."""
        shardMin = self.getShardMinValue(shardId)
        shardMax = self.getShardMaxValue(shardId)
        return str(shardMin) + '-' + str(shardMax)

    def checkShardRanges(self, shards):
        """Check that ``shards`` contains every expected shard range.

        The check is skipped (True is returned) in the "smaller shard" mode.

        :param shards: mapping of shard range string to ShardInfo
        :return: True if the ranges are consistent with
            ``Context.shard_count``, False otherwise
        """
        if self.Context.shard_count == NotExists:
            logging.error('Incorrect shards count value: ' + str(self.Context.shard_count))
            return False
        if self.Parameters.smaller_shard > 0:
            return True
        if self.Context.shard_count != len(shards):
            logging.error('Incorrect shards count: ' + str(self.Context.shard_count) + ' vs ' + str(len(shards)))
            return False
        for shardId in range(self.Context.shard_count):
            # A ShardInfo without any datacenter is falsy and therefore is
            # reported as missing as well.
            if not shards.get(self.getShardRangeString(shardId)):
                logging.error('Not found shard: shardId=' + str(shardId) + ' shardValue=' + self.getShardRangeString(shardId))
                return False
        return True

    def get_instances_from_nanny(self):
        """Build the shard map from the Nanny current state of the service.

        Sets ``Context.shard_count`` to the number of distinct shard ids.

        :return: mapping of shard range string to ShardInfo, or False when
            the instances can not be obtained. In the "smaller shard" mode
            the mapping has a single entry made from shard 0.
        :raises TemporaryError: in the "smaller shard" mode when there is no
            shard with id 0
        """
        serviceName = self.Parameters.service_name
        url = instances_url_nanny % serviceName
        data = sendRequest(url, timeout=15)
        if not data:
            return False
        data = json.loads(data)
        if 'error' in data:
            # ``data`` is a dict here, so it must be formatted rather than
            # concatenated to the message.
            logging.info('nanny error: %s', data)
        if 'result' not in data:
            return False
        data = data.get('result', [])
        pods = getServicePods(serviceName, parseServiceDatacenters(data))
        pods_shardmap = {pod.hostname: pod.labels.get(self.Parameters.yp_pod_label_with_shardid, '0') for pod in pods}

        shardsById = {}
        for inst in data:
            host = inst['hostname']
            if inst.get('network_settings', '') == 'MTN_ENABLED':
                host = inst['container_hostname']
            instance = '{host}:{port}'.format(host=host, port=inst['port'] + 3)
            shardId = 0
            dc = 'unknown'
            # NOTE(review): when an instance has both an OPT_shardid tag and a
            # pod label, the resulting shard id depends on the order of the
            # tags; this keeps the original behavior - confirm the intended
            # precedence.
            for tag in inst['itags']:
                if tag.startswith('OPT_shardid'):
                    shardId = int(tag.split('=').pop())
                elif host in pods_shardmap:
                    shardId = int(pods_shardmap[host])
                if tag.startswith('a_dc_'):
                    dc = tag.split('_').pop()
            if shardId not in shardsById:
                shardsById[shardId] = ShardInfo()
            if dc not in shardsById[shardId]:
                shardsById[shardId][dc] = []
            shardsById[shardId][dc].append(instance)

        self.Context.shard_count = len(shardsById)
        shards = {}
        for shardId in shardsById:
            shardRange = self.getShardRangeString(shardId)
            shards[shardRange] = shardsById[shardId]
            shards[shardRange].shardId = shardId

        if self.Parameters.smaller_shard <= 0:
            return shards
        if 0 not in shardsById:
            raise TemporaryError('Cannot create a smaller shard')
        smallShards = {}
        smallerShard = '0-' + str(self.getShardMaxValue(0) // self.Parameters.smaller_shard)
        smallShards[smallerShard] = shardsById[0]
        smallShards[smallerShard].copyFullRange = False
        return smallShards

    def get_instances_from_saas_dm(self):
        """Build the shard map from the deploy manager (saas-dm-proxy).

        Slots flagged ``is_sd`` and slots whose controller status is not
        ``'Active'`` are skipped. Sets ``Context.shard_count`` to the number
        of returned shards.

        :return: mapping of shard range string to ShardInfo, or False when
            the request failed
        """
        # sdk2 tasks expose their input through ``self.Parameters`` (as the
        # rest of this class does), not through ``self.ctx``.
        url = instances_url_saas_dm.format(self.Parameters.service_ctype, self.Parameters.service_name)
        data = sendRequest(url, timeout=15)
        if not data:
            return False
        data = json.loads(data)
        shards = {}
        for shard in data:
            shard_range = shard['id']
            shards[shard_range] = ShardInfo()
            for slot in shard['slots']:
                if slot.get('is_sd', False):
                    continue
                dc = slot['$datacenter$'].lower()
                if dc not in shards[shard_range]:
                    shards[shard_range][dc] = []
                host, port = slot['slot'].split(':')
                instance = '{}:{}'.format(host, int(port) + 3)
                if slot['result.controller_status'] != 'Active':
                    logging.debug('Skip instance {}: {}'.format(instance, slot['result.controller_status']))
                    continue
                shards[shard_range][dc].append(instance)
        self.Context.shard_count = len(shards)
        return shards

    def prepare_replicas(self, shards):
        """Filter the replicas by ``correct_replica_pattern`` and shuffle them.

        :param shards: mapping of shard range string to ShardInfo; modified
            in place
        :return: the same ``shards`` mapping
        :raises TaskFailure: if a shard has no matching replica left
        """
        import re
        pattern = re.compile(self.Parameters.correct_replica_pattern)
        for shard in shards:
            replicasCount = 0
            for dc in shards[shard]:
                replicas = []
                for instance in shards[shard][dc]:
                    if pattern.match(instance):
                        replicas.append(instance)
                    else:
                        logging.debug('Skip replica: %s', instance)
                # Random order, so that not always the same replica is used.
                shuffle(replicas)
                shards[shard][dc] = replicas
                replicasCount += len(replicas)
            if replicasCount == 0:
                raise TaskFailure('No replicas for shard: {}'.format(shard))
        return shards

    def get_instances(self):
        """Return the validated shard map of the service.

        The source is the deploy manager when ``use_deploy_manager`` is set,
        Nanny otherwise.

        :raises TemporaryError: if the instances can not be obtained or the
            shard ranges are inconsistent
        :raises TaskFailure: if some shard has no usable replica
        """
        if self.Parameters.use_deploy_manager:
            shards = self.get_instances_from_saas_dm()
        else:
            shards = self.get_instances_from_nanny()
        if not shards:
            raise TemporaryError('Can not get instances')
        if not self.checkShardRanges(shards):
            raise TemporaryError('Incorrect shard ranges:\n' + str(shards))
        return self.prepare_replicas(shards)

    def replicaIsCorrect(self, replica):
        """Tell whether the index may be taken from ``replica``.

        The decision is based on the ``/status`` page of the replica: it must
        be active, have fewer than 3 consecutive crashes and, if it crashed
        at all, must have a non-zero last merge time. Always True when
        ``not_check_replica`` is set.

        :param replica: replica address (``'host:port'``)
        :return: True if the replica is usable; False otherwise, including
            when the status page is unavailable or incomplete
        """
        if self.Parameters.not_check_replica:
            return True
        url = replica + '/status'
        data = sendRequest(url)
        if not data:
            return False
        logging.info(url + ' result: \n ' + data)
        status = {}
        for line in data.split('\n'):
            st = line.split(':')
            if len(st) == 2:
                status[st[0].strip()] = st[1].strip()

        try:
            active = int(status['Active'])
            crashes = int(status['Consecutive_crashes'])
            lastMerge = int(status['Index_Merge_last_time'])
        except (KeyError, ValueError) as e:
            # A replica with an unreadable status must be skipped, not crash
            # the whole task: there may be other healthy replicas.
            logging.warning('Can not parse status of %s: %r', replica, e)
            return False

        if active != 1 or crashes >= 3:
            return False
        return crashes == 0 or lastMerge != 0

    def getIndexSize(self, replicasPerDc):
        """Return the index size in MB reported by the first usable replica.

        :param replicasPerDc: mapping of datacenter to list of replica
            addresses
        :raises TaskFailure: if no replica returned the size
        """
        for dc in replicasPerDc:
            for replica in replicasPerDc[dc]:
                if self.replicaIsCorrect(replica):
                    url = replica + '/?command=get_info_server'
                    data = sendRequest(url)
                    if data and 'result' in data:
                        indexSize = json.loads(data)['result']['files_size']['__SUM']
                        logging.info('Index size ' + str(indexSize) + ' from ' + replica)
                        return long(indexSize / 1024 / 1024)     # return size in MB
        raise TaskFailure('Can not get index size for shard')

    def createChildTask(self, shardRange, shards, extDescription=''):
        """Create and enqueue a DETACH_RTYSERVER_INDEX task for one shard.

        The id of the new task is stored in ``Context.task_by_shard``.

        :param shardRange: shard range string, a key of ``shards``
        :param shards: mapping of shard range string to ShardInfo
        :param extDescription: text appended to the child task description
        """
        indexSize = self.getIndexSize(shards[shardRange])
        logging.info('Shard=' + shardRange + ' indexSize=' + str(indexSize) + 'MB')
        iss_shardname = ''
        if self.Parameters.register_iss_shards:
            iss_shardname = self.Parameters.iss_shardname_template.format(shard=shardRange, state=self.Context.state)

        childTask = sdk2.Task["DETACH_RTYSERVER_INDEX"]
        taskId = childTask(
            self,
            description='detach index for service "' + self.Parameters.service_name + '" from shard: ' + shardRange + ' ' + extDescription,
            kill_timeout=12 * 60 * 60,
            service_name=self.Parameters.service_name,
            shard_range=shardRange,
            replicas=json.dumps(shards[shardRange].store),
            resource_ttl=self.Parameters.resource_ttl,
            max_download_speed=self.Parameters.max_download_speed,
            delay_threshold=self.Parameters.delay_threshold,
            save_all_indexes=self.Parameters.save_all_indexes,
            get_torrent_timeout=self.Parameters.get_torrent_timeout,
            download_timeout=self.Parameters.download_timeout,
            register_iss_shards=self.Parameters.register_iss_shards,
            iss_shardname=iss_shardname,
            shard_id_num=shards[shardRange].shardId,
            copy_full_range=shards[shardRange].copyFullRange,
            not_check_replica=self.Parameters.not_check_replica,
            # 25% on top of the index size
            execution_space=long(indexSize * 1.25),
            inherit_notifications=True
        ).enqueue().id
        self.Context.task_by_shard[shardRange] = taskId

    def getSubTasksInfo(self):
        """Summarize the statuses of the child tasks.

        :return: dict with ``'success_count'``, ``'failed_count'`` and the
            lists of shard ranges whose task is in FAILURE (``'failure'``) or
            in one of the BREAK statuses (``'break'``)
        """
        subTasksInfo = {'success_count': 0, 'failed_count': 0, 'failure': [], 'break': []}
        for shardRange in self.Context.task_by_shard:
            taskId = self.Context.task_by_shard[shardRange]
            taskStatus = sdk2.Task[taskId].status
            if taskStatus == Status.SUCCESS:
                subTasksInfo['success_count'] += 1
            if taskStatus == Status.FAILURE:
                subTasksInfo['failure'].append(shardRange)
            if taskStatus in Status.Group.BREAK:
                subTasksInfo['break'].append(shardRange)
        subTasksInfo['failed_count'] = len(subTasksInfo['failure']) + len(subTasksInfo['break'])
        return subTasksInfo

    def deployRefreshFrozen(self):
        """Create a DEPLOY_REFRESH_FROZEN subtask for refresh services.

        Does nothing when the service is not listed in REFRESH_SERVICES.
        """
        if self.Parameters.service_name in REFRESH_SERVICES:
            self.create_subtask(
                'DEPLOY_REFRESH_FROZEN',
                'Deploy refresh frozen from {}'.format(self.id),
                input_parameters={
                    'detach_task': self.id,
                    'save_previous': True
                }
            )

    def registerIssShards(self):
        """Write a row describing the detached state into the YT table.

        The state time is ``Context.state``; the YT token is read from the
        vault of ``yt_token_vault_owner`` or, if that is empty, of the task
        owner.
        """
        import yt.wrapper as yt

        state = datetime.datetime.fromtimestamp(self.Context.state)
        data = {
            "Time": state.isoformat(),
            "State": state.strftime('%Y%m%d-%H%M%S'),
            "JupiterBundleId": "",
            "MrPrefix": "",
            "MrServer": ""
        }

        yt.config.set_proxy(self.Parameters.yt_cluster)
        vault_owner = self.owner
        if self.Parameters.yt_token_vault_owner:
            vault_owner = self.Parameters.yt_token_vault_owner
        yt.config.config['token'] = sdk2.Vault.data(vault_owner, self.Parameters.yt_token_vault_name)
        yt.insert_rows(yt.TablePath(self.Parameters.yt_table, append=True), [data], format=yt.JsonFormat())

    def on_execute(self):
        """Run the task.

        On the first execution one child task per shard is created and the
        task goes to wait for them. On later executions failed children are
        re-created and broken ones restarted (up to 3 retries); once all
        children have succeeded, the optional backup and ISS registration
        steps are performed.

        :raises TaskFailure: if the children did not all succeed after the
            retries
        """
        logging.info('Service: ' + self.Parameters.service_name)
        from saas.library.python.token_store import TokenStore
        vault_owner = self.owner
        if self.Parameters.nanny_token_vault_owner:
            # The nanny token owner, not the YT one, selects the vault here.
            vault_owner = self.Parameters.nanny_token_vault_owner
        nanny_token = sdk2.Vault.data(vault_owner, self.Parameters.nanny_token_vault_name)
        TokenStore.add_token('nanny', nanny_token)

        waitStatuses = ctt.Status.Group.FINISH | ctt.Status.Group.BREAK

        if not self.Context.initialized:
            self.Context.state = int(time.time())
            self.Context.task_by_shard = {}
            self.Context.retry_count = 0
            shards = self.get_instances()
            for shardRange in shards:
                self.createChildTask(shardRange, shards)
            self.Context.initialized = True
            raise sdk2.WaitTask(self.Context.task_by_shard.values(), waitStatuses)

        subTasksInfo = self.getSubTasksInfo()
        logging.info('Failed shards count: ' + str(subTasksInfo['failed_count']))
        # Compare with the number of child tasks rather than with the number
        # of shards: in the "smaller shard" mode only one child is created,
        # whatever the number of shards of the service is.
        if subTasksInfo['success_count'] != len(self.Context.task_by_shard):
            if self.Context.retry_count < 3:
                shards = self.get_instances()
                for shardRange in subTasksInfo['failure']:
                    logging.info('Retry #' + str(self.Context.retry_count) + ': create new child task for: ' + shardRange)
                    self.createChildTask(shardRange, shards, 'retry #' + str(self.Context.retry_count))
                for shardRange in subTasksInfo['break']:
                    logging.info('Retry #' + str(self.Context.retry_count) + ': restart child task #' + str(self.Context.task_by_shard[shardRange]) + ' for: ' + shardRange)
                    sdk2.Task[self.Context.task_by_shard[shardRange]].enqueue()
                self.Context.retry_count += 1
                raise sdk2.WaitTask(self.Context.task_by_shard.values(), waitStatuses)
            else:
                raise TaskFailure('Can not detach index after ' + str(self.Context.retry_count) + ' retries')

        if self.Parameters.is_backup:
            self.doBackup()
            self.deleteOldBackups(self.Parameters.resource_ttl)

        if self.Parameters.register_iss_shards:
            self.registerIssShards()

    def on_success(self, prev_status):
        """Start the frozen refresh deployment if it was requested."""
        if self.Parameters.deploy_refresh_frozen:
            self.deployRefreshFrozen()
