import timestamp_pb2

from sandbox import common
import sandbox.common.types.client as ctc
from sandbox.sandboxsdk import process
from sandbox.projects import resource_types
from sandbox.projects.common import apihelpers
from sandbox.sandboxsdk.parameters import SandboxStringParameter, SandboxIntegerParameter, SandboxBoolParameter, SandboxFloatParameter
from sandbox.sandboxsdk.task import SandboxTask
from sandbox.common.errors import TaskFailure, TemporaryError
from sandbox.sandboxsdk.errors import SandboxSubprocessTimeoutError

import json
import logging
import urllib2
import time
import os
import shutil
from datetime import datetime
import google.protobuf

# Directory the index is downloaded into; the same path is published as the
# resource and handed to the iss_shards tool.
DOWNLOAD_PATH = 'backup'


def sendRequest(url, timeout=1):
    """
        GET ``url`` and return the response body.

        ``http://`` is prepended when the url does not start with a scheme.
        Up to three attempts are made, 5 seconds apart; every failure is logged.

        Returns the body string on success, or ``False`` when all attempts failed.
    """
    # Check the prefix only: a url may legitimately contain 'http://' further
    # inside (e.g. in a query parameter) and still lack its own scheme.
    if not url.startswith(('http://', 'https://')):
        url = 'http://' + url
    attempts = 3
    for attempt in range(attempts):
        try:
            reply = urllib2.urlopen(url, timeout=timeout)
            try:
                return reply.read()
            finally:
                # do not leak the connection when read() fails or succeeds
                reply.close()
        except Exception as e:
            logging.warning('url: ' + url)
            logging.exception(e)
        # no point in waiting after the last attempt
        if attempt + 1 < attempts:
            time.sleep(5)
    return False


def getDirSize(dir):
    """
        Return the total size in bytes of all files found under ``dir``
        (walked recursively with os.walk). An empty tree yields 0.
    """
    return sum(
        os.path.getsize(os.path.join(root, name))
        for root, _, names in os.walk(dir)
        for name in names
    )


class ServiceName(SandboxStringParameter):
    """
        Name of the service whose index is detached from a replica.
        The task reads it as ctx['service_name'].
    """
    description = 'service name for copy search index'
    name = 'service_name'


class ShardRange(SandboxStringParameter):
    """
        Shard range written as "ShardMin-ShardMax".
        The task splits it on '-' to fill min_shard / max_shard of the detach request.
    """
    description = 'shard range: "ShardMin-ShardMax"'
    name = 'shard_range'


class Replicas(SandboxStringParameter):
    """
        JSON object mapping a datacenter name to its list of "host:port" replicas.
        Example: {"sas": ["host1:port1", "host2:port2"], "man": []}
    """
    description = 'json with replicas by dc. Example: {"sas":["host1:port1","host2:port2"],"man":[]}'
    name = 'replicas'


class ResourceTTL(SandboxIntegerParameter):
    """
        TTL of the created resource, in days.
        Passed through as the 'ttl' attribute of the resource.
    """
    default_value = 7
    description = 'resource ttl in days'
    name = 'resource_ttl'


class MaxDownloadSpeed(SandboxStringParameter):
    """
        Download speed limit, e.g. "10M" (10 Mbytes) or "10mbps" (10 Mbits).
        The value is handed to `sky get` as --max-dl-speed.
    """
    default_value = '300mbps'
    description = 'Max download speed (bytes or mbps e.g. "10M (10Mbytes) or 10mbps (10 Mbits)")'
    name = 'max_download_speed'


class DelayThreshold(SandboxFloatParameter):
    """
        Maximum tolerated index delay of a replica, in hours.
        A replica whose delay exceeds it is rejected; the default is one week.
    """
    default_value = 7 * 24
    description = 'not create the resource if the delay is more than this (hours)'
    name = 'delay_threshold'


class SaveAllIndexes(SandboxBoolParameter):
    """
        Which downloaded indexes are kept in the resource:
        True  - every index of the shard
        False - only the largest index of the shard
    """
    default_value = True
    description = 'True - save all indexes from shard; False - save largest index from shard'
    name = 'save_all_indexes'


class DownloadTimeout(SandboxIntegerParameter):
    """
        Time limit for downloading the index, in minutes
        (converted to seconds for the `sky get` process timeout).
    """
    default_value = 60
    description = 'timeout for download index (min)'
    name = 'download_timeout'


class GetTorrentTimeout(SandboxIntegerParameter):
    """
        Time limit for obtaining the index torrent, in minutes
        (the task polls the replica once per minute this many times).
    """
    default_value = 60
    description = 'timeout for get index torrent (min)'
    name = 'get_torrent_timeout'


class RegisterIssShards(SandboxBoolParameter):
    """
        When set, the downloaded index is configured and registered
        with the iss_shards tool after a successful detach.
    """
    default_value = False
    description = 'Register shards in ISS'
    name = 'register_iss_shards'


class ShardIdNum(SandboxIntegerParameter):
    """
        Integer shard id; copied into the 'shard_id_num' attribute
        of the created resource.
    """
    default_value = -1
    description = 'Shard id to be progated to resource'
    name = 'shard_id_num'


class IssShardname(SandboxStringParameter):
    """
        ISS shard name; passed as --id to `iss_shards configure`.
    """
    description = 'ISS shard name'
    name = 'iss_shardname'


class ForceCopyAllIndexRange(SandboxBoolParameter):
    """
        Force the detach request to cover the full index range 0..65533
        instead of the configured shard range (REFRESH-379).
    """
    default_value = False
    description = 'force: copy full index range [0;65533] from backend REFRESH-379'
    name = 'copy_full_range'

class NotCheckReplica(SandboxBoolParameter):
    """
        Skip the replica health check: when set, every replica
        is accepted without querying its status.
    """
    default_value = False
    description = 'do not check that the replica is correct'
    name = 'not_check_replica'


class DetachRtyserverIndex(SandboxTask):
    """
        Detach an rtyserver index from one of the configured replicas, download
        it and publish it as an RTYSERVER_SEARCH_DATABASE resource.

        Per-replica progress is kept in self.ctx ('stage', 'selected_replica',
        'request_id', 'torrent', ...). The stages are
        SELECTED -> GETTING_TORRENT -> DOWNLOAD -> COMPLITE; when a replica fails
        at any stage its state is cleared and the next unused replica is tried.
    """

    type = 'DETACH_RTYSERVER_INDEX'

    cores = 1
    required_ram = 64 * 1024
    execution_space = 64 * 1024

    input_parameters = [
        ServiceName,
        ShardRange,
        Replicas,
        ResourceTTL,
        MaxDownloadSpeed,
        DelayThreshold,
        SaveAllIndexes,
        DownloadTimeout,
        GetTorrentTimeout,
        RegisterIssShards,
        IssShardname,
        ShardIdNum,
        ForceCopyAllIndexRange,
        NotCheckReplica,
    ]

    def on_enqueue(self):
        """
            Narrow the client tags to the datacenters that have at least one
            replica. Datacenters without a known tag are ignored; when none
            matches, the class-level tags are used unchanged.
        """
        replicasPerDc = json.loads(self.ctx['replicas'])
        dc_to_tag = {
            "iva": ctc.Tag.IVA,
            "man": ctc.Tag.MAN,
            "vla": ctc.Tag.VLA,
            "sas": ctc.Tag.SAS
        }
        # start from the class-level tags every time the task is enqueued
        self.client_tags = self.__class__.client_tags
        add_tags = None
        for dc in replicasPerDc:
            # skip datacenters with no replicas or without a matching tag
            if not replicasPerDc[dc] or dc not in dc_to_tag:
                continue
            if add_tags is None:
                add_tags = dc_to_tag[dc]
            else:
                add_tags |= dc_to_tag[dc]

        if add_tags is not None:
            self.client_tags &= add_tags

    def clearStage(self):
        """
            Drop all per-replica state from the context so that the next
            replica starts from scratch.
        """
        for key in (
            'selected_replica',
            'stage',
            'torrent',
            'request_id',
            'index_timestamp',
            'detach_timestamp',
            'download_retries',
        ):
            self.ctx.pop(key, None)

    def getCurrentDelay(self, replica):
        """
            Return how many seconds the replica lags behind now, based on the
            newest timestamp among its FINAL indexes.

            When the info cannot be fetched or parsed the timestamp stays 0,
            so the returned delay is the full current epoch time (i.e. huge).
        """
        url = replica + '/?command=get_info_server'
        data = sendRequest(url)
        timestamp = 0
        if data and 'result' in data:
            try:
                indexes = json.loads(data)['result']['indexes']
                newest = 0
                for index in indexes:
                    if indexes[index]['type'] == 'FINAL':
                        newest = max(newest, indexes[index]['timestamp'])
                timestamp = newest
            except (ValueError, KeyError, TypeError) as e:
                # a malformed answer must not crash the task: keep timestamp 0
                logging.warning('Can not parse server info from {}: {!r}'.format(url, e))
        delay = (time.time() - timestamp)
        logging.info('Current delay {:.1f}sec'.format(delay))
        return delay

    def replicaIsCorrect(self, replica):
        """
            Check that the replica is healthy enough to detach an index from.

            A replica is accepted when it is active, has fewer than 3
            consecutive crashes (and, if it crashed at all, has merged at least
            once) and its delay is within ctx['delay_threshold'] hours.
            The check is skipped entirely when ctx['not_check_replica'] is set.

            Returns True/False; an unreachable replica or an unparsable status
            page counts as not correct.
        """
        if self.ctx['not_check_replica']:
            return True
        url = replica + '/status'
        data = sendRequest(url)
        if not data:
            return False
        logging.debug(url + ' result: \n ' + data)
        # the status page is a list of "Key: value" lines
        status = {}
        for line in data.split('\n'):
            st = line.split(':')
            if len(st) == 2:
                status[st[0].strip()] = st[1].strip()

        try:
            active = int(status['Active'])
            crashes = int(status['Consecutive_crashes'])
            lastMerge = int(status['Index_Merge_last_time'])
        except (KeyError, ValueError) as e:
            # a missing or non-numeric field means this replica cannot be
            # trusted; reject it instead of failing the whole task
            logging.warning('Can not parse replica status: {!r}'.format(e))
            logging.info('Replica is not correct')
            return False

        isCorrect = active == 1 and crashes < 3
        isCorrect = isCorrect and (crashes == 0 or lastMerge != 0)
        # the delay request is only issued when the cheaper checks passed
        isCorrect = isCorrect and (self.getCurrentDelay(replica) <= self.ctx['delay_threshold'] * 60 * 60)
        if not isCorrect:
            logging.info('Replica is not correct')
        return isCorrect

    def getTorrent(self, replica):
        """
            Ask the replica to detach the index and wait for the torrent id.

            The async request id and the resulting torrent are stored in
            ctx['request_id'] / ctx['torrent'], so an already issued request is
            polled again rather than re-sent. The replica is polled once a
            minute, at most ctx['get_torrent_timeout'] times.

            Returns True when ctx['torrent'] is available, False otherwise.
        """
        logging.info('Getting torrent')
        if self.ctx.get('torrent'):
            return True
        if not self.ctx.get('request_id'):
            self.ctx['detach_timestamp'] = int(time.mktime(datetime.now().timetuple()))
            if self.ctx.get('copy_full_range', False):
                # the configured range is irrelevant (and need not be valid)
                # when the full range is forced
                shard_min, shard_max = (0, 65533)
            else:
                shard_min, shard_max = self.ctx.get('shard_range').split('-')
            url = '{}/?command=synchronizer&action=detach&sharding_type=url_hash&min_shard={}&max_shard={}&async=yes'.format(
                replica,
                shard_min,
                shard_max
            )
            logging.info('request: {}'.format(url))
            data = sendRequest(url)
            logging.debug('Detach request result:\n {}'.format(data))
            if not data:
                logging.warning('Can not request torrent from replica')
                return False
            try:
                result = json.loads(data)
            except ValueError:
                logging.warning('Can not request torrent')
                return False
            if result.get('task_status') != 'ENQUEUED':
                logging.warning('Can not request torrent')
                return False
            self.ctx['request_id'] = result.get('id')
        url = '{replica}/?command=get_async_command_info&id={id}'.format(replica=replica, id=self.ctx['request_id'])
        logging.info('Waiting request result... request_id={}'.format(self.ctx['request_id']))

        for _ in range(self.ctx['get_torrent_timeout']):
            time.sleep(60)
            data = sendRequest(url)
            logging.debug('Detach status:\n {}'.format(data))
            if not data:
                break
            try:
                detach = json.loads(data)['result']
            except (ValueError, KeyError, TypeError) as e:
                # treat an unparsable status like a failed request
                logging.warning('Can not parse detach status: {!r}'.format(e))
                break
            task_status = detach.get('task_status', '')
            logging.info('task_status={} stage={}'.format(task_status, detach.get('stage')))
            if task_status == 'FINISHED':
                if 'id_res' in detach:
                    self.ctx['torrent'] = detach.get('id_res')
                    return True
                # finished without a result id: nothing to download
                break
            if task_status == 'FAILED' or task_status == 'NOT_FOUND':
                break
        return False

    def getIndexTimestamp(self, indexDir):
        """
            Return the largest StreamTimestamp.MaxValue recorded in the
            text-format 'timestamp' file of ``indexDir``, or 0 when the file
            is missing (or lists no stream timestamps).
        """
        # Import the submodule explicitly: `import google.protobuf` alone does
        # not guarantee that `google.protobuf.text_format` is loaded.
        from google.protobuf import text_format

        timestampsFile = os.path.join(indexDir, 'timestamp')
        if not os.path.isfile(timestampsFile):
            logging.error('No timestamps file: {}'.format(timestampsFile))
            return 0
        with open(timestampsFile, 'r') as source:
            data = source.read()

        message = timestamp_pb2.TTimestamp()
        text_format.Merge(data, message)

        maxTimestamp = 0
        for streamTimestamp in message.StreamTimestamp:
            maxTimestamp = max(maxTimestamp, streamTimestamp.MaxValue)
        return maxTimestamp

    def prepareAllIndexesResource(self, prefixPath, dirs):
        """
            Keep every downloaded index; store the newest index timestamp
            among ``dirs`` (relative to ``prefixPath``) in ctx['index_timestamp'].
        """
        indexTimestamp = 0
        for name in dirs:
            indexTimestamp = max(indexTimestamp, self.getIndexTimestamp(os.path.join(prefixPath, name)))
        self.ctx['index_timestamp'] = indexTimestamp

    def prepareLargestIndexResource(self, prefixPath, dirs):
        """
            Keep only the largest (by total file size) index directory, delete
            the others and store its timestamp in ctx['index_timestamp'].
        """
        largestIndex = {'dir': '', 'size': 0}
        for name in dirs:
            size = getDirSize(os.path.join(prefixPath, name))
            # strict comparison: the first directory wins on equal sizes
            if size > largestIndex['size']:
                largestIndex = {'dir': name, 'size': size}
        for name in dirs:
            if name != largestIndex['dir']:
                shutil.rmtree(os.path.join(prefixPath, name))
        self.ctx['index_timestamp'] = self.getIndexTimestamp(os.path.join(prefixPath, largestIndex['dir']))

    def downloadIndex(self, torrent, downloadPath):
        """
            Download ``torrent`` with `sky get` into a freshly created
            ``downloadPath``.

            Returns True on success and False when the download timed out.
            Other errors are retried (5 attempts, one minute apart).

            Raises TemporaryError when every attempt failed.
        """
        logging.info('Downloading rtyserver index...')
        maxTries = 5
        for attempt in range(maxTries):
            try:
                # always start from an empty directory
                if os.path.exists(downloadPath):
                    shutil.rmtree(downloadPath)
                os.makedirs(downloadPath)
                # NOTE(review): a list command is combined with shell=True and
                # `torrent` comes from the replica's answer - confirm how
                # run_process builds the shell line.
                process.run_process(
                    ['sky', 'get', '-uwp', '--max-dl-speed={}'.format(self.ctx['max_download_speed']), torrent],
                    log_prefix='sky_get',
                    work_dir=downloadPath,
                    shell=True,
                    wait=True,
                    timeout=self.ctx['download_timeout'] * 60,
                    timeout_sleep=60
                )
                return True
            except SandboxSubprocessTimeoutError:
                logging.warning('download index timeout')
                return False
            except Exception as e:
                logging.exception(e)
            # no point in waiting after the last attempt
            if attempt + 1 < maxTries:
                time.sleep(60)
        raise TemporaryError('Can not download index from torrent: {}'.format(torrent))

    def prepareIndexes(self, downloadPath):
        """
            Select the indexes to publish (all or the largest one, depending on
            ctx['save_all_indexes']) and compute ctx['index_timestamp'].

            Returns False when nothing was downloaded or no timestamp was found.
        """
        downloadDirs = os.listdir(downloadPath)
        if not downloadDirs:
            logging.warning('Downloaded index is empty')
            return False
        if self.ctx['save_all_indexes']:
            self.prepareAllIndexesResource(downloadPath, downloadDirs)
        else:
            self.prepareLargestIndexResource(downloadPath, downloadDirs)
        return (self.ctx['index_timestamp'] != 0)

    def detachIndex(self, replica):
        """
            Run the stage machine for ``replica`` starting from ctx['stage']:
            check the replica, get the torrent, download the index and create
            the resource.

            Returns True when the resource was created (stage 'COMPLITE'),
            False when this replica has to be given up.
        """
        logging.info('Selected replica: {}'.format(self.ctx['selected_replica']))
        if self.ctx['stage'] == 'SELECTED':
            if not self.replicaIsCorrect(replica):
                return False
            self.ctx['stage'] = 'GETTING_TORRENT'
        if self.ctx['stage'] == 'GETTING_TORRENT':
            if not self.getTorrent(replica):
                logging.warning('Can not get torrent')
                return False
        if self.ctx['stage'] == 'GETTING_TORRENT' and self.ctx.get('torrent'):
            self.ctx['stage'] = 'DOWNLOAD'
            self.ctx['download_retries'] = 0
        if self.ctx['stage'] == 'DOWNLOAD':
            # the counter lives in ctx, so it also counts attempts made by
            # earlier executions that stopped at this stage
            self.ctx['download_retries'] += 1
            if self.ctx['download_retries'] > 3:
                logging.warning('Can not download indexes after {} tries'.format(self.ctx['download_retries']))
                return False
            downloadPath = DOWNLOAD_PATH
            if not self.downloadIndex(self.ctx['torrent'], downloadPath):
                logging.warning('Can not download indexes')
                return False
            if not self.prepareIndexes(downloadPath):
                return False
            description = ('all' if self.ctx['save_all_indexes'] else 'largest') + ' rtyserver index segments'
            # leading space keeps "segments" and "from" apart
            description += ' from {service}, shard={shard}'.format(service=self.ctx['service_name'], shard=self.ctx['shard_range'])
            self.create_resource(
                description=description,
                resource_path=downloadPath,
                resource_type=resource_types.RTYSERVER_SEARCH_DATABASE,
                attributes={
                    'service': self.ctx['service_name'],
                    'shard': self.ctx['shard_range'],
                    'detach_timestamp': self.ctx['detach_timestamp'],
                    'index_timestamp': self.ctx['index_timestamp'],
                    'ttl': self.ctx['resource_ttl'],
                    'full_index': self.ctx['save_all_indexes'],
                    'shard_id_num': self.ctx['shard_id_num']
                }
            )
            # (sic) the stage name is persisted in ctx and checked in on_execute
            self.ctx['stage'] = 'COMPLITE'
            return True
        return False

    def extract_replica(self, replicas):
        """
            Remove and return the last replica of the list, or None when the
            list is empty.
        """
        return replicas.pop() if replicas else None

    def next_replica(self):
        """
            Take the next unused replica, preferring the datacenter this host
            runs in. Returns None when every replica has been used.
        """
        dc = common.config.Registry().this.dc
        replica = self.extract_replica(self.ctx['unused_replicas'].get(dc, []))
        if replica:
            return replica

        # nothing left locally: fall back to any other datacenter
        for dc in self.ctx['unused_replicas']:
            replica = self.extract_replica(self.ctx['unused_replicas'][dc])
            if replica:
                return replica
        return None

    def on_prepare(self):
        """
            Initialise ctx['unused_replicas'] from the 'replicas' parameter
            unless it is already present in the context.
        """
        if not self.ctx.get('unused_replicas'):
            self.ctx['unused_replicas'] = json.loads(self.ctx['replicas'])

    def registerIssShard(self):
        """
            Configure and register the downloaded index with the latest stable
            iss_shards tool.

            Raises Exception when no released ISS_SHARDS resource exists.
        """
        iss_shards_resource = apihelpers.get_last_resource_with_attribute(resource_types.ISS_SHARDS, "released", "stable")
        if not iss_shards_resource:
            raise Exception('There is no iss_shards resource')

        iss_shards = self.sync_resource(iss_shards_resource)

        logging.info('iss_shard configure...')
        process.run_process(
            [iss_shards, 'configure', '--id', self.ctx['iss_shardname'], DOWNLOAD_PATH],
            log_prefix="iss_shards_configure",
            work_dir=self.path(),
        )

        logging.info('iss_shard register...')
        process.run_process(
            [iss_shards, 'register', '--with-dir', DOWNLOAD_PATH],
            log_prefix="iss_shards_register",
            work_dir=self.path(),
        )

    def on_execute(self):
        """
            Try replicas one by one until an index is detached and published,
            then optionally register it in ISS.

            Raises TaskFailure when no replica is left or the detach did not
            reach the final stage.
        """
        logging.info('service={service} shard={shard}'.format(service=self.ctx['service_name'], shard=self.ctx['shard_range']))
        logging.info('task stage: {}'.format(self.ctx.get('stage')))
        logging.info('unused replicas:\n{}'.format(self.ctx['unused_replicas']))
        while True:
            # a stage already present in ctx means a replica is in progress
            if not self.ctx.get('stage'):
                self.ctx['selected_replica'] = self.next_replica()
                if not self.ctx['selected_replica']:
                    raise TaskFailure('Connot get next replica')
                self.ctx['stage'] = 'SELECTED'
            if self.detachIndex(self.ctx['selected_replica']):
                break
            # this replica failed: forget its state and pick another one
            self.clearStage()

        if self.ctx.get('stage') != 'COMPLITE':
            raise TaskFailure('Can not detach index.')
        if self.ctx['register_iss_shards']:
            self.registerIssShard()


# Module-level alias of the task class
# (presumably the name the Sandbox task loader looks up - confirm).
__Task__ = DetachRtyserverIndex
