import shards_pb2

from kazoo.client import KazooClient
from random import shuffle
from sandbox.projects import resource_types
from sandbox.projects.common import utils
from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk.task import SandboxTask
from sandbox.sandboxsdk.parameters import SandboxStringParameter, SandboxIntegerParameter, SandboxBoolParameter, SandboxFloatParameter
from sandbox.sandboxsdk.errors import SandboxTaskFailureError
from sandbox.sandboxsdk import environments
from sandbox.common.errors import TemporaryError
from sandbox.common.types.task import Status
from sandbox.common import rest

import collections
import math
import time
import datetime
import json
import logging
import urllib2
import os

# Nanny API: current instances of a service; %-format with the service name.
instances_url_nanny = 'http://nanny.yandex-team.ru/v2/services/%s/current_state/instances/'
# SaaS deploy-manager proxy: slots of a service; str.format with (ctype, service).
# NOTE(review): the query key is spelled "service_trype" -- looks like a typo of
# "service_type"; confirm against the DM proxy API before changing it.
instances_url_saas_dm = 'http://saas-dm-proxy.n.yandex-team.ru/api/slots_by_interval/?service_trype=rtyserver&ctype={}&service={}'

# Services for which a DEPLOY_REFRESH_FROZEN subtask is spawned on success
# (only when the deploy_refresh_frozen parameter is enabled).
REFRESH_SERVICES = (
    'saas_refresh_3day_production_base_multilang',
    'saas_refresh_10day_production_base_multilang',
    'saas_refresh_production_video_base',
    'saas_refresh_quick2_experiment_base',
    'saas_yp_samohod',
    'saas_yp_samohod_beta',
    'saas_yp_samohod_beta2',
    'samohod',
)


try:
    from collections.abc import MutableMapping as _MutableMapping
except ImportError:  # Python 2: the ABCs live directly in ``collections``
    from collections import MutableMapping as _MutableMapping


class ShardInfo(_MutableMapping):
    """Replicas of one shard, grouped by datacenter.

    Behaves like a dict mapping a datacenter name to a list of ``host:port``
    instance strings; the backing dict is exposed as ``store``.

    Attributes:
        shardId: numeric shard id, -1 until a caller assigns one.
        copyFullRange: whether the whole shard range is copied (set to False
            when only a smaller sub-range of the shard is detached).
        store: the backing ``{dc: [instance, ...]}`` dict.
    """

    def __init__(self):
        self.shardId = -1
        self.copyFullRange = True
        self.store = {}

    def __getitem__(self, key):
        return self.store[key]

    def __setitem__(self, key, value):
        self.store[key] = value

    def __delitem__(self, key):
        del self.store[key]

    def __iter__(self):
        return iter(self.store)

    def __len__(self):
        return len(self.store)

    def __repr__(self):
        # Makes error/log messages that stringify shard maps readable.
        return 'ShardInfo(shardId=%r, copyFullRange=%r, replicas=%r)' % (
            self.shardId, self.copyFullRange, self.store)


def sendRequest(url, timeout=1, retries=3, retry_delay=5):
    """Fetch ``url`` with retries and return the response body.

    A ``url`` that does not start with ``http://`` or ``https://`` (e.g.
    ``host:port/path``) is prefixed with ``http://``.

    Args:
        url: target URL, with or without an http(s) scheme.
        timeout: per-attempt timeout in seconds passed to ``urlopen``.
        retries: number of attempts before giving up.
        retry_delay: seconds to sleep between consecutive failed attempts.

    Returns:
        The body read from the response, or ``False`` if every attempt failed.
    """
    # Only a leading scheme counts; a scheme elsewhere (e.g. in a query
    # parameter) must not suppress the prefix.
    if not url.startswith(('http://', 'https://')):
        url = 'http://' + url
    for attempt in range(retries):
        if attempt:
            # Back off between attempts, but not after the last one.
            time.sleep(retry_delay)
        try:
            reply = urllib2.urlopen(url, timeout=timeout)
            try:
                return reply.read()
            finally:
                reply.close()
        except Exception as e:
            logging.warning('url: ' + url)
            logging.exception(e)
    return False


class ZKClient(object):
    """Context manager that yields a started ``KazooClient`` for the SaaS ZooKeeper ensemble.

    ``__enter__`` tries to start the client up to three times and raises
    ``SandboxTaskFailureError`` if every attempt fails. ``__exit__`` stops the
    client and then closes it so its resources are released.
    """
    ZK_HOSTS = ('saas-zookeeper1.search.yandex.net:14880,'
                'saas-zookeeper2.search.yandex.net:14880,'
                'saas-zookeeper3.search.yandex.net:14880,'
                'saas-zookeeper4.search.yandex.net:14880,'
                'saas-zookeeper5.search.yandex.net:14880')

    def __enter__(self):
        self.kz = KazooClient(hosts=self.ZK_HOSTS, read_only=False)
        for _ in range(3):
            try:
                self.kz.start()
            except Exception as e:
                logging.warning(e)
            else:
                return self.kz
        raise SandboxTaskFailureError('Failed to establish zookeeper connection')

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            self.kz.stop()
        finally:
            # stop() only ends the session; close() frees the client's handles.
            self.kz.close()


class ServiceName(SandboxStringParameter):
    """Name of the service whose index is detached."""
    name = 'service_name'
    description = 'service name for detach index'


class IsBackup(SandboxBoolParameter):
    """Store the detached shards as a backup in ZooKeeper.

    When enabled, backups older than the resource TTL are also pruned, for
    _ALL_ backups of the service (a minimum number of them is always kept).
    """
    name = 'is_backup'
    description = 'is backup (if True, ttl will be set for _ALL_ backups)'
    default_value = False


class BackupsMinCount(SandboxIntegerParameter):
    """Number of most recent backups that are never pruned."""
    name = 'backups_min_count'
    description = 'minimum number of backups'
    default_value = 4


class ResourceTTL(SandboxIntegerParameter):
    """TTL in days for detached resources (and for backups when is_backup is set)."""
    name = 'resource_ttl'
    description = 'resource ttl in days'
    default_value = 7


class MaxDownloadSpeed(SandboxStringParameter):
    """Download speed limit, in bytes ("10M" = 10 MB) or bits ("10mbps" = 10 Mbit)."""
    name = 'max_download_speed'
    description = 'Max download speed (bytes or mbps e.g. "10M (10Mbytes) or 10mbps (10 Mbits)")'
    default_value = '300mbps'


class DelayThreshold(SandboxFloatParameter):
    """Do not create the resource if the index delay exceeds this many hours (default: one week)."""
    name = 'delay_threshold'
    description = 'not create the resource if the delay is more than this (hours)'
    default_value = 7 * 24


class SaveAllIndexes(SandboxBoolParameter):
    """Which indexes of a shard to save.

    True saves every index of the shard; False saves only the largest one.
    """
    name = 'save_all_indexes'
    description = 'True - save all indexes from shard; False - save largest index from shard'
    default_value = True


class CorrectReplicasPattern(SandboxStringParameter):
    """Regular expression that a replica's ``host:port`` must match to be used.

    The empty default matches every replica.
    """
    name = 'correct_replica_pattern'
    description = 'RegExp of correct replicas'
    default_value = ''


class DeployRefreshFrozen(SandboxBoolParameter):
    """Spawn a frozen-refresh deploy task after a successful detach.

    Only applies to services listed in REFRESH_SERVICES.
    """
    name = 'deploy_refresh_frozen'
    description = 'True - create task for deploy frozen refresh; False - do not create task for deploy frozen refresh'
    default_value = False


class DownloadTimeout(SandboxIntegerParameter):
    """Index download timeout in minutes, forwarded to each child task."""
    name = 'download_timeout'
    description = 'timeout for download index (min)'
    default_value = 60


class GetTorrentTimeout(SandboxIntegerParameter):
    """Timeout in minutes for obtaining the index torrent, forwarded to each child task."""
    name = 'get_torrent_timeout'
    description = 'timeout for get index torrent (min)'
    default_value = 60


class RegisterIssShards(SandboxBoolParameter):
    """Register the detached shards in ISS and record the state in the YT table."""
    name = 'register_iss_shards'
    description = 'Register shards in ISS'
    default_value = False


class IssShardnameTemplate(SandboxStringParameter):
    """str.format template for ISS shard names; must use the {shard} and {state} fields.

    Example: SaasFrozenSamovar-{shard}-{state} -> SaasFrozenSamovar-0-6552-1533653730
    """
    name = 'iss_shardname_template'
    description = 'ISS shard name template, must have variables: shard, state'


class YtTable(SandboxStringParameter):
    """Path of the YT table that receives the detached shard states."""
    name = 'yt_table'
    description = 'Yt table path for shards states'


class DetachSmallerShard(SandboxIntegerParameter):
    """When > 0, detach only one sub-range of the first shard, this many times smaller.

    For example, with 10 shards the first shard is 0-6552; a value of 2
    detaches 0-3276 instead of all shards.
    """
    name = 'smaller_shard'
    description = 'Detach a single shard from the first shard if |smaller_shard| > 0. Detached shard will be |smaller_shard| times smaller than the first shard'
    default_value = 0


class ZookeeperPrefix(SandboxStringParameter):
    """ZooKeeper directory under which per-service backups are stored."""
    name = 'zk_backup_path_prefix'
    description = 'Zookeeper path prefix for backups'
    default_value = '/indexBackups/'

class YtTokenVaultName(SandboxStringParameter):
    """Name of the vault secret that holds the YT token."""
    name = 'yt_token_vault_name'
    description = 'name of vault secret for yt token'
    default_value = 'YT_TOKEN_ARNOLD'

class YtTokenVaultOwner(SandboxStringParameter):
    """Owner of the YT token vault secret; empty means the task owner."""
    name = 'yt_token_vault_owner'
    description = 'owner of vault secret for yt token or empty for task owner'
    default_value = ''

class YtCluster(SandboxStringParameter):
    """YT cluster (proxy) that hosts the shard-state table."""
    name = 'yt_cluster'
    description = 'yt cluster'
    default_value = 'arnold'

class UseDeployManager(SandboxBoolParameter):
    """Discover instances via the SaaS deploy manager instead of Nanny."""
    name = 'use_deploy_manager'
    description = 'Get instances from deploy_manager'
    default_value = False

class ServiceCtype(SandboxStringParameter):
    """Service ctype used when querying the deploy manager (also part of the backup path)."""
    name = 'service_ctype'
    description = 'service ctype'
    default_value = ''

class NotCheckReplica(SandboxBoolParameter):
    """Skip the /status health check and treat every replica as correct."""
    name = 'not_check_replica'
    description = 'do not check that the replica is correct'
    default_value = False


class DetachServiceIndex(SandboxTask):
    """Detach the rtyserver index of a SaaS service, one subtask per shard.

    ``on_execute`` discovers the service instances (from Nanny, or from the
    SaaS deploy manager when ``use_deploy_manager`` is set) and spawns one
    ``DETACH_RTYSERVER_INDEX`` subtask per shard range. It retries failed or
    broken subtasks up to three times. It then optionally records a backup in
    ZooKeeper (``is_backup``) and appends the new state to a YT table
    (``register_iss_shards``).
    """
    type = 'DETACH_SERVICE_INDEX'

    environment = [environments.PipEnvironment('yandex-yt')]
    cores = 1
    required_ram = 2048
    execution_space = 1024

    input_parameters = [
        ServiceName,
        IsBackup,
        ResourceTTL,
        BackupsMinCount,
        MaxDownloadSpeed,
        DelayThreshold,
        SaveAllIndexes,
        CorrectReplicasPattern,
        DeployRefreshFrozen,
        DownloadTimeout,
        GetTorrentTimeout,
        RegisterIssShards,
        IssShardnameTemplate,
        YtTable,
        DetachSmallerShard,
        ZookeeperPrefix,
        YtTokenVaultName,
        YtTokenVaultOwner,
        YtCluster,
        UseDeployManager,
        ServiceCtype,
        NotCheckReplica,
    ]

    def get_zk_backup_path(self):
        """Return the ZooKeeper directory that holds this service's backups.

        The path is ``zk_backup_path_prefix`` joined with the service name. When
        the deploy manager is used, ``#<service_ctype>`` is appended to the name.
        """
        service_id = self.ctx['service_name']
        if self.ctx['use_deploy_manager']:
            service_id += '#' + self.ctx['service_ctype']
        return os.path.join(self.ctx['zk_backup_path_prefix'], service_id)

    def addBackupToStorage(self, backup):
        """Write ``backup['data']`` to the znode ``<backup path>/<backup['timestamp']>``.

        Raises:
            SandboxTaskFailureError: if the znode cannot be created or written.
        """
        znode = os.path.join(self.get_zk_backup_path(), str(backup['timestamp']))
        with ZKClient() as zk:
            try:
                if not zk.exists(znode):
                    zk.create(znode, makepath=True)
                zk.set(znode, backup['data'])
                logging.info('Backup added: ' + znode)
            except Exception as e:
                logging.exception(e)
                raise SandboxTaskFailureError('Can not save backup in zookeeper')

    def createBackupData(self, resources):
        """Describe ``resources`` as a serialized ``shards_pb2.TShards`` message.

        Each resource becomes one shard entry (name, key range, index timestamp,
        skynet id).

        Returns:
            dict with ``'timestamp'`` (the smallest index timestamp among the
            resources, 0 if there are none) and ``'data'`` (serialized bytes).
        """
        shards = shards_pb2.TShards()
        timestamp = 0
        for resource in resources:
            shard = shards.Shard.add()
            shard.Name = 'backup_' + resource.attributes['shard'] + '_' + resource.attributes['index_timestamp']
            shardMin, shardMax = resource.attributes['shard'].split('-')[:2]
            shard.ShardMin = int(shardMin)
            shard.ShardMax = int(shardMax)
            shard.Timestamp = long(resource.attributes['index_timestamp'])
            shard.Torrent = resource.skynet_id
            if timestamp == 0 or timestamp > shard.Timestamp:
                timestamp = shard.Timestamp
        logging.info('backup data:\n' + str(shards))
        return {'timestamp': timestamp, 'data': shards.SerializeToString()}

    def doBackup(self):
        """Store the first READY index resource of every child task as one ZooKeeper backup.

        Raises:
            SandboxTaskFailureError: if some child task has no ready resource.
        """
        resources = []
        for shardRange in self.ctx['task_by_shard']:
            childTaskResources = channel.rest.list_resources(
                resource_type=resource_types.RTYSERVER_SEARCH_DATABASE,
                status='READY',
                task_id=self.ctx['task_by_shard'][shardRange]
            )
            if not childTaskResources:
                raise SandboxTaskFailureError('Can not find ready resource for backup for shard: ' + shardRange)
            resources.append(childTaskResources[0])
        backup = self.createBackupData(resources)
        self.addBackupToStorage(backup)

    def deleteOldBackups(self, backupTTL):
        """Delete this service's backups that are older than ``backupTTL`` days.

        Backups are znodes named by timestamp. The newest ``backups_min_count``
        backups are always kept; only older ones past the TTL are deleted.
        """
        expireBefore = time.time() - backupTTL * 24 * 60 * 60
        backupsPathForService = self.get_zk_backup_path()
        with ZKClient() as zk:
            # Children are timestamp strings: order them numerically, oldest first.
            backupTimestamps = sorted(zk.get_children(backupsPathForService), key=long)
            logging.info('List of backups:\n' + str(backupTimestamps))
            countToDelete = max(0, len(backupTimestamps) - self.ctx['backups_min_count'])
            for backupTimestamp in backupTimestamps[:countToDelete]:
                if long(backupTimestamp) < expireBefore:
                    znode = os.path.join(backupsPathForService, backupTimestamp)
                    if zk.delete(znode):
                        logging.info('Old backup deleted: ' + znode)
                    else:
                        logging.warning('Can not delete old backup: ' + znode)

    def getShardMinValue(self, shardId):
        """First key of shard ``shardId`` when keys 0..65533 are split into ``shard_count`` parts."""
        return 65533 * shardId // self.ctx['shard_count']

    def getShardMaxValue(self, shardId):
        """Last key of shard ``shardId``.

        Each shard ends one key before the next shard's first key. The extra
        ``(shardId + 1) // shardCount`` term is 1 only for the last shard, so
        the last range ends at 65533.
        """
        shardCount = self.ctx['shard_count']
        return 65533 * (shardId + 1) // shardCount - 1 + (shardId + 1) // shardCount

    def getShardRangeString(self, shardId):
        """Return the ``'<min>-<max>'`` key range of shard ``shardId``."""
        shardMin = self.getShardMinValue(shardId)
        shardMax = self.getShardMaxValue(shardId)
        return str(shardMin) + '-' + str(shardMax)

    def checkShardRanges(self, shards):
        """Check that ``shards`` covers exactly the ``shard_count`` expected ranges.

        The check is skipped (always True) in smaller-shard mode.
        """
        if not self.ctx.get('shard_count'):
            logging.error('Incorrect shards count value: ' + str(self.ctx.get('shard_count')))
            return False
        if self.ctx['smaller_shard'] > 0:
            return True
        if self.ctx['shard_count'] != len(shards):
            logging.error('Incorrect shards count: ' + str(self.ctx['shard_count']) + ' vs ' + str(len(shards)))
            return False
        for shardId in range(self.ctx['shard_count']):
            if not shards.get(self.getShardRangeString(shardId)):
                logging.error('Not found shard: shardId=' + str(shardId) + ' shardValue=' + self.getShardRangeString(shardId))
                return False
        return True

    def get_instances_from_nanny(self):
        """Group the service's Nanny instances as ``{shard range: ShardInfo}``.

        The shard id is taken from the ``OPT_shardid=`` tag and the datacenter
        from the ``a_dc_`` tag. Instance ports are offset by 3 (presumably the
        rtyserver controller port -- confirm). Sets ``ctx['shard_count']``.
        In smaller-shard mode, returns only the reduced sub-range of shard 0.

        Returns:
            the shard map, or False if Nanny returned nothing usable.

        Raises:
            TemporaryError: in smaller-shard mode when shard 0 is missing.
        """
        url = instances_url_nanny % self.ctx['service_name']
        data = sendRequest(url, timeout=15)
        if not data:
            return False
        data = json.loads(data)
        if 'error' in data:
            logging.warning('nanny error: %s', data['error'])
        if 'result' not in data:
            return False

        shardsById = {}
        for inst in data['result']:
            host = inst['hostname']
            if inst.get('network_settings', '') == 'MTN_ENABLED':
                host = inst['container_hostname']
            instance = '{host}:{port}'.format(host=host, port=inst['port'] + 3)
            shardId = 0
            dc = 'unknown'
            for tag in inst['itags']:
                if tag.startswith('OPT_shardid'):
                    shardId = int(tag.split('=')[-1])
                if tag.startswith('a_dc_'):
                    dc = tag.split('_')[-1]
            shardInfo = shardsById.setdefault(shardId, ShardInfo())
            shardInfo.setdefault(dc, []).append(instance)

        self.ctx['shard_count'] = len(shardsById)
        shards = {}
        for shardId, shardInfo in shardsById.items():
            shardInfo.shardId = shardId
            shards[self.getShardRangeString(shardId)] = shardInfo

        if self.ctx['smaller_shard'] <= 0:
            return shards
        if 0 not in shardsById:
            raise TemporaryError('Cannot create a smaller shard')
        smallerShard = '0-' + str(self.getShardMaxValue(0) // self.ctx['smaller_shard'])
        shardsById[0].copyFullRange = False
        return {smallerShard: shardsById[0]}

    def get_instances_from_saas_dm(self):
        """Group the service's deploy-manager slots as ``{shard range: ShardInfo}``.

        Service-discovery slots (``is_sd``) are ignored. Slots whose controller
        is not ``Active`` are skipped, but their datacenter key is still created.
        Sets ``ctx['shard_count']``.

        Returns:
            the shard map, or False if the deploy manager returned nothing.
        """
        url = instances_url_saas_dm.format(self.ctx['service_ctype'], self.ctx['service_name'])
        data = sendRequest(url, timeout=15)
        if not data:
            return False
        data = json.loads(data)
        shards = {}
        for shard in data:
            shard_range = shard['id']
            shards[shard_range] = ShardInfo()
            for slot in shard['slots']:
                if slot.get('is_sd', False):
                    continue
                dc = slot['$datacenter$'].lower()
                if dc not in shards[shard_range]:
                    shards[shard_range][dc] = []
                host, port = slot['slot'].split(':')
                instance = '{}:{}'.format(host, int(port) + 3)
                if slot['result.controller_status'] != 'Active':
                    logging.debug('Skip instance {}: {}'.format(instance, slot['result.controller_status']))
                    continue
                shards[shard_range][dc].append(instance)
        self.ctx['shard_count'] = len(shards)
        return shards

    def prepare_replicas(self, shards):
        """Keep only replicas matching ``correct_replica_pattern`` and shuffle each DC list.

        Raises:
            SandboxTaskFailureError: if a shard is left without any replica.
        """
        import re
        pattern = re.compile(self.ctx['correct_replica_pattern'])
        for shard in shards:
            replicasCount = 0
            for dc in shards[shard]:
                replicas = []
                for instance in shards[shard][dc]:
                    if pattern.match(instance):
                        replicas.append(instance)
                    else:
                        logging.debug('Skip replica: %s' % instance)
                shards[shard][dc] = replicas
                shuffle(shards[shard][dc])
                replicasCount += len(replicas)
            if replicasCount == 0:
                raise SandboxTaskFailureError('No replicas for shard: {}'.format(shard))
        return shards

    def get_instances(self):
        """Discover, validate and filter the service's replicas, keyed by shard range.

        Raises:
            TemporaryError: if instances cannot be fetched or the ranges are wrong.
        """
        if self.ctx['use_deploy_manager']:
            shards = self.get_instances_from_saas_dm()
        else:
            shards = self.get_instances_from_nanny()
        if not shards:
            raise TemporaryError('Can not get instances')
        if not self.checkShardRanges(shards):
            raise TemporaryError('Incorrect shard ranges:\n' + str(shards))
        return self.prepare_replicas(shards)

    def replicaIsCorrect(self, replica):
        """Return True if ``replica`` looks healthy according to its ``/status`` page.

        A replica is accepted when it is Active, has fewer than 3 consecutive
        crashes, and either has no crashes or has a non-zero last merge time.
        An unreachable replica, or a status page missing any of these fields,
        is treated as incorrect. Always True when ``not_check_replica`` is set.
        """
        if self.ctx['not_check_replica']:
            return True
        url = replica + '/status'
        data = sendRequest(url)
        if not data:
            return False
        logging.info(url + ' result: \n ' + data)
        status = {}
        for line in data.split('\n'):
            st = line.split(':')
            if len(st) == 2:
                status[st[0].strip()] = st[1].strip()

        try:
            active = int(status['Active'])
            consecutiveCrashes = int(status['Consecutive_crashes'])
            lastMergeTime = int(status['Index_Merge_last_time'])
        except (KeyError, ValueError) as e:
            # An incomplete status page disqualifies this replica only;
            # callers move on to the next one.
            logging.warning('Can not parse status of %s: %r', replica, e)
            return False

        if active != 1 or consecutiveCrashes >= 3:
            return False
        return consecutiveCrashes == 0 or lastMergeTime != 0

    def getIndexSize(self, replicasPerDc):
        """Return the index size in MB reported by the first correct replica.

        Args:
            replicasPerDc: mapping ``{dc: [host:port, ...]}`` for one shard.

        Raises:
            SandboxTaskFailureError: if no replica reports a usable size.
        """
        for dc in replicasPerDc:
            for replica in replicasPerDc[dc]:
                if not self.replicaIsCorrect(replica):
                    continue
                url = replica + '/?command=get_info_server'
                data = sendRequest(url)
                if not data or 'result' not in data:
                    continue
                try:
                    indexSize = json.loads(data)['result']['files_size']['__SUM']
                except (ValueError, KeyError, TypeError) as e:
                    logging.warning('Can not parse info_server of %s: %r', replica, e)
                    continue
                logging.info('Index size ' + str(indexSize) + ' from ' + replica)
                return long(indexSize / 1024 / 1024)     # return size in MB
        raise SandboxTaskFailureError('Can not get index size for shard')

    def createChildTask(self, shardRange, shards, extDescription=''):
        """Spawn a DETACH_RTYSERVER_INDEX subtask for ``shardRange``.

        The subtask's execution space is 1.25x the reported index size. Its id
        is recorded in ``ctx['task_by_shard'][shardRange]``.
        """
        indexSize = self.getIndexSize(shards[shardRange])
        logging.info('Shard=' + shardRange + ' indexSize=' + str(indexSize) + 'MB')
        iss_shardname = ''
        if self.ctx['register_iss_shards']:
            iss_shardname = self.ctx['iss_shardname_template'].format(shard=shardRange, state=self.ctx['state'])

        taskId = self.create_subtask(
            task_type='DETACH_RTYSERVER_INDEX',
            description='detach index for service "' + self.ctx['service_name'] + '" from shard: ' + shardRange + ' ' + extDescription,
            input_parameters={
                'kill_timeout': 12 * 60 * 60,
                'service_name': self.ctx['service_name'],
                'shard_range': shardRange,
                'replicas': json.dumps(shards[shardRange].store),
                'resource_ttl': self.ctx['resource_ttl'],
                'max_download_speed': self.ctx['max_download_speed'],
                'delay_threshold': self.ctx['delay_threshold'],
                'save_all_indexes': self.ctx['save_all_indexes'],
                'get_torrent_timeout': self.ctx['get_torrent_timeout'],
                'download_timeout': self.ctx['download_timeout'],
                'register_iss_shards': self.ctx['register_iss_shards'],
                'iss_shardname': iss_shardname,
                'shard_id_num': shards[shardRange].shardId,
                'copy_full_range': shards[shardRange].copyFullRange,
                'not_check_replica': self.ctx['not_check_replica']
            },
            execution_space=long(indexSize * 1.25),
            inherit_notifications=True
        ).id
        self.ctx['task_by_shard'][shardRange] = taskId

    def getSubTasksInfo(self):
        """Summarize child task statuses.

        Returns:
            dict with ``success_count``, the shard ranges whose task ended in
            FAILURE (``failure``) or in a BREAK status (``break``), and
            ``failed_count`` = len(failure) + len(break).
        """
        client = rest.Client()
        subTasksInfo = {'success_count': 0, 'failed_count': 0, 'failure': [], 'break': []}
        for shardRange, taskId in self.ctx['task_by_shard'].items():
            taskStatus = client.task[taskId].read()['status']
            if taskStatus == Status.SUCCESS:
                subTasksInfo['success_count'] += 1
            if taskStatus == Status.FAILURE:
                subTasksInfo['failure'].append(shardRange)
            if taskStatus in Status.Group.BREAK:
                subTasksInfo['break'].append(shardRange)
        subTasksInfo['failed_count'] = len(subTasksInfo['failure']) + len(subTasksInfo['break'])
        return subTasksInfo

    def deployRefreshFrozen(self):
        """Spawn a DEPLOY_REFRESH_FROZEN subtask if the service is in REFRESH_SERVICES."""
        if self.ctx['service_name'] in REFRESH_SERVICES:
            logging.info('Deploying Frozen')
            self.create_subtask(
                'DEPLOY_REFRESH_FROZEN',
                'Deploy refresh frozen from {}'.format(self.id),
                input_parameters={
                    'detach_task': self.id,
                    'save_previous': True
                }
            )
        else:
            logging.info('Not deploying Frozen because %s not in %s', self.ctx['service_name'], REFRESH_SERVICES)

    def registerIssShards(self):
        """Append a row describing ``ctx['state']`` to the YT table ``yt_table``.

        The YT token is read from the vault secret ``yt_token_vault_name``.
        Its owner is ``yt_token_vault_owner``, or the task owner if that is empty.
        """
        import yt.wrapper as yt

        state = datetime.datetime.fromtimestamp(self.ctx['state'])
        data = {
            "Time": state.isoformat(),
            "State": state.strftime('%Y%m%d-%H%M%S'),
            "JupiterBundleId": "",
            "MrPrefix": "",
            "MrServer": ""
        }

        yt.config.set_proxy(self.ctx['yt_cluster'])
        vault_owner = self.ctx['yt_token_vault_owner'] or self.owner
        yt.config.config['token'] = self.get_vault_data(vault_owner, self.ctx['yt_token_vault_name'])
        yt.insert_rows(yt.TablePath(self.ctx['yt_table'], append=True), [data], format=yt.JsonFormat())

    def on_execute(self):
        """Detach every shard, retrying failures, then back up / register the result.

        The first run spawns the child tasks and sets ``ctx['initialized']``.
        ``utils.wait_all_subtasks_stop()`` presumably suspends the task until the
        subtasks finish (confirm), and later runs then skip straight to checking
        the subtasks' results.
        """
        logging.info('Service: ' + self.ctx['service_name'])
        if 'initialized' not in self.ctx:
            self.ctx['state'] = int(time.time())
            self.ctx['task_by_shard'] = {}
            self.ctx['retry_count'] = 0
            shards = self.get_instances()
            for shardRange in shards:
                self.createChildTask(shardRange, shards)
            self.ctx['initialized'] = True
            utils.wait_all_subtasks_stop()

        subTasksInfo = self.getSubTasksInfo()
        logging.info('Failed shards count: ' + str(subTasksInfo['failed_count']))
        # Every shard range that got a child task must succeed. This is not
        # ``shard_count``: in smaller-shard mode only one child task exists.
        expectedSuccessCount = len(self.ctx['task_by_shard'])
        if subTasksInfo['success_count'] != expectedSuccessCount:
            if self.ctx['retry_count'] < 3:
                shards = self.get_instances()
                for shardRange in subTasksInfo['failure']:
                    logging.info('Retry #' + str(self.ctx['retry_count']) + ': create new child task for: ' + shardRange)
                    self.createChildTask(shardRange, shards, 'retry #' + str(self.ctx['retry_count']))
                for shardRange in subTasksInfo['break']:
                    logging.info('Retry #' + str(self.ctx['retry_count']) + ': restart child task #' + str(self.ctx['task_by_shard'][shardRange]) + ' for: ' + shardRange)
                    channel.sandbox.server.restart_task(self.ctx['task_by_shard'][shardRange])
                self.ctx['retry_count'] += 1
                utils.wait_all_subtasks_stop()
            else:
                raise SandboxTaskFailureError('Can not detach index after ' + str(self.ctx['retry_count']) + ' retries')

        if self.ctx['is_backup']:
            self.doBackup()
            self.deleteOldBackups(self.ctx['resource_ttl'])

        if self.ctx['register_iss_shards']:
            self.registerIssShards()

    def on_success(self):
        """Optionally spawn the frozen-refresh deploy task (``deploy_refresh_frozen``)."""
        if self.ctx.get(DeployRefreshFrozen.name, DeployRefreshFrozen.default_value):
            self.deployRefreshFrozen()


# Module-level alias of the task class; presumably the name the Sandbox task
# loader looks up in this module -- confirm before renaming.
__Task__ = DetachServiceIndex
