import logging
import os
import time
import urlparse

import datetime
import requests

from multiprocessing.pool import ThreadPool
from functools import partial

import sandbox.sandboxsdk.parameters as sdk_parameters
from sandbox.sandboxsdk import environments
from sandbox.sandboxsdk import errors
from sandbox.sandboxsdk.task import SandboxTask

# URL path prefix of the SUP admin backup API; table names and job ids are
# appended to it when request URLs are built.
ADMIN_BACKUP_SQL = '/admin/backup/mongo'

# Configure the root logger at import time.
logging.basicConfig(
    level=logging.INFO,
    format='[%(levelname)s %(asctime)s]: %(message)s',
    # 24-hour clock: the previous '%I' (12-hour) without '%p' made AM and PM
    # timestamps indistinguishable.
    datefmt='%Y-%m-%d %H:%M:%S'
)

class YtTargetPathParam(sdk_parameters.SandboxStringParameter):
    """Task input: root YT path under which the backup directory is created."""

    required = True
    name = 'target_path'
    default_value = '//home/search-functionality/sup/backup'
    description = 'YT target path'


class BackupHostParam(sdk_parameters.SandboxStringParameter):
    """Task input: host that receives the backup HTTP requests."""

    required = True
    name = 'backup_host'
    default_value = 'push-beta.n.yandex-team.ru'
    description = 'Backup host'


class BackupTablesParam(sdk_parameters.ListRepeater, sdk_parameters.SandboxStringParameter):
    """Task input: list of table specs, each 'table[:chunk_size[:limit]]'."""

    required = True
    name = 'backup_tables'
    description = 'Backup tables'


class IsSeqParam(sdk_parameters.SandboxBoolParameter):
    """Task input: whether each backup call waits for its job to finish."""

    required = True
    name = 'backup_is_seq'
    default_value = True
    description = 'Sequential run'


class NumThreadsParam(sdk_parameters.SandboxIntegerParameter):
    """Task input: size of the thread pool that runs the per-table backups."""

    required = True
    name = 'num_threads'
    default_value = 1
    description = 'Num threads'


class YtClusterParam(sdk_parameters.SandboxStringParameter):
    """Task input: YT proxy URL used by the YT client."""

    required = True
    name = 'yt_cluster'
    default_value = 'hahn.yt.yandex.net'
    description = 'Yt cluster'


class ShardNumParam(sdk_parameters.SandboxIntegerParameter):
    """Task input: shard number to back up; a negative value means no shard."""

    required = True
    name = 'shard_num'
    default_value = -1
    description = 'Shard num'


class DaysToKeepParam(sdk_parameters.SandboxIntegerParameter):
    """Task input: days until the backup directory expires.

    A negative value (the default) or zero leaves the directory without an
    expiration time.
    """

    required = True
    name = 'days_to_keep'
    default_value = -1
    description = 'Days to keep'


class SupBackupTask(SandboxTask):
    """Sandbox task that starts SUP backup jobs for a list of tables.

    For every configured table spec the task POSTs to the backup host's admin
    API with a destination path inside a per-host, per-day YT directory, and
    optionally polls the job until it completes or fails.
    """

    type = 'SUP_BACKUP'
    input_parameters = (
        YtTargetPathParam, BackupHostParam, BackupTablesParam, IsSeqParam, NumThreadsParam, YtClusterParam,
        ShardNumParam,
        DaysToKeepParam
    )
    environment = (
        environments.PipEnvironment('yandex-yt', use_wheel=True),
    )

    def on_execute(self):
        """Create the YT target directory and run a backup for every table.

        Does nothing when the task context already holds a 'subtask_id'.
        Any exception raised by a single backup propagates out of the pool's
        ``map`` call.
        """
        if self.ctx.get('subtask_id'):
            return
        token = self.get_vault_data(self.owner, 'robot_sup_hahn_token')
        host = self.ctx[BackupHostParam.name]
        target = self.ctx[YtTargetPathParam.name]
        tables = self.ctx[BackupTablesParam.name]
        wait = self.ctx[IsSeqParam.name]
        num_threads = self.ctx[NumThreadsParam.name]
        # Backups are grouped as <target>/<host>/<YYYY-MM-DD>.
        base_path = os.path.join(target, host, datetime.datetime.now().strftime('%Y-%m-%d'))
        # Negative values are the "not set" markers of both parameters.
        shard = None if self.ctx[ShardNumParam.name] < 0 else self.ctx[ShardNumParam.name]
        days_to_keep = None if self.ctx[DaysToKeepParam.name] < 0 else self.ctx[DaysToKeepParam.name]

        # Imported here rather than at module level; presumably because the
        # package only becomes available through the task environment.
        import yt.wrapper as yt
        yt.config['token'] = token
        yt.config['proxy']['url'] = self.ctx[YtClusterParam.name]
        yt.mkdir(base_path, True)
        if days_to_keep:
            # NOTE(review): a naive local timestamp is sent; confirm how YT
            # interprets a value without a timezone.
            expiration_time = (datetime.datetime.now() + datetime.timedelta(days=days_to_keep)).isoformat("T")
            yt.set(base_path + '/@expiration_time', expiration_time)

        pool = ThreadPool(num_threads)
        try:
            pool.map(partial(self.backup, host, base_path, wait=wait, shard=shard), tables)
        finally:
            # Let the worker threads exit instead of leaking the pool.
            pool.close()

    @staticmethod
    def backup(host, base_path, tablespec, wait=True, shard=None):
        """Start a backup job for one table and optionally wait for it.

        :param host: host (optionally with port) serving the admin backup API.
        :param base_path: directory path the table backup is written under.
        :param tablespec: 'table[:chunk_size[:limit]]'; chunk_size and limit
            are integers and default to None when omitted.
        :param wait: when false, return right after the job has been started.
        :param shard: shard number, or None for the non-sharded endpoint.
        :return: the job execution id, or None if polling was interrupted by
            KeyboardInterrupt.
        :raises errors.SandboxTaskFailureError: on a non-200 answer, a missing
            job id, or a job that reports status FAILED.
        """
        parts = tablespec.split(':')
        table = parts.pop(0)  # str.split always yields at least one element
        chunk_size = int(parts.pop(0)) if parts else None
        limit = int(parts.pop(0)) if parts else None
        table_path = os.path.join(base_path, table)
        # Compare with None: shard 0 is a real shard and must use the sharded
        # endpoint as well.
        table_part = table if shard is None else os.path.join('sharded', table)
        url = urlparse.urlunsplit(('http', host, os.path.join(ADMIN_BACKUP_SQL, table_part), None, None))
        response = requests.post(url,
                                 json={
                                     'path': table_path,
                                     'chunk_size': chunk_size,
                                     'limit': limit,
                                     'shard': shard
                                 },
                                 headers={'Content-Type': 'application/json'})
        if response.status_code != 200:
            raise errors.SandboxTaskFailureError('SUP API Error (%d):\n%s' % (response.status_code, response.text))
        job_id = response.json().get('jobExecutionId')
        if job_id is None:
            raise errors.SandboxTaskFailureError('SUP API Error (%d):\n%s' % (response.status_code, response.text))
        logging.info('Starting backup of %s using %s into %s', table, job_id, table_path)
        if not wait:
            return job_id

        # Poll the front end named in the response headers when they are
        # present; both headers are optional, so use .get() to avoid a
        # KeyError when one is missing.
        front_host = response.headers.get('X-Yandex-Front')
        if front_host:
            host = front_host
        front_port = response.headers.get('X-Yandex-Front-Port')
        if front_port:
            host += ':' + front_port
        url = urlparse.urlunsplit(('http', host, os.path.join(ADMIN_BACKUP_SQL, str(job_id)), None, None))
        try:
            while True:
                try:
                    response = requests.get(url, headers={'Content-Type': 'application/json'})
                except requests.exceptions.ReadTimeout:
                    # Transient: ask again.
                    continue

                if response.status_code == 504:
                    # Gateway timeout is treated as transient as well.
                    continue

                if response.status_code != 200:
                    raise errors.SandboxTaskFailureError(
                        'SUP API Error (%d):\n%s' % (response.status_code, response.text))
                job_state = response.json()
                status = job_state.get('status')
                if status == 'COMPLETED':
                    logging.info('Finished backup of %s using %s into %s', table, job_id, table_path)
                    return job_state.get('jobExecutionId')
                elif status == 'FAILED':
                    raise errors.SandboxTaskFailureError(
                        'Backup failed (%s):\n%s' % (url, job_state.get('exitDescription')))
                time.sleep(5.0)
        except KeyboardInterrupt:
            return None


# Module-level alias for the task class defined in this file.
# NOTE(review): presumably the name the Sandbox task loader looks up - confirm.
__Task__ = SupBackupTask
