# -*- coding: utf-8 -*-
import time
import json
import logging
import datetime
import requests

from sandbox import sdk2
from sandbox.sandboxsdk import environments
from sandbox.sandboxsdk.errors import SandboxTaskFailureError

from sandbox.projects.geosearch.tools.misc import retry


# Threshold, in percent, for the relative difference between the row count
# reported by SaaS and the row count of the table being pushed. A difference
# whose absolute value is at or above this threshold makes the task refuse
# the upload unless the check is explicitly disabled.
CRITICAL_ROW_COUNT_CHANGE_PERC = 45.0


class AddrsSnipippetsPushToFerryman(sdk2.Task):
    '''
        Push table with geosearch snippets to Ferryman

        The task optionally checks that the table row count did not change
        too much compared to SaaS, asks Ferryman to load the table
        (or an empty table when erasing a namespace) and then polls the batch
        status until it reaches the target status or the push has failed
        too many times.

        Progress is kept in ``self.Context`` (``batch``, ``status``,
        ``RESTART_COUNT``, ``failed_batch``, ``failed_batch_error``,
        ``upload_report`` and the row count fields).
    '''

    class Parameters(sdk2.task.Parameters):
        kill_timeout = 240
        cluster = sdk2.parameters.String('YT cluster',
                                         default_value='hahn.yt.yandex.net')
        table_name = sdk2.parameters.String('Snippet name',
                                            required=True)
        table_path = sdk2.parameters.String('Table path',
                                            required=True)
        ferryman_url = sdk2.parameters.String('Ferryman URL',
                                              required=True)
        namespace = sdk2.parameters.String('Namespace',
                                           required=True)
        # JSON document; the keys read by this task are
        # 'ignore_row_count_change' and 'snippet_name'.
        snippet_params = sdk2.parameters.String('Snippet parameters',
                                                required=True)
        erase = sdk2.parameters.Bool('Remove all snippets from namespace',
                                     default_value=False)
        load_anyway = sdk2.parameters.Bool('Ignore too big row count change',
                                           default_value=False)

    class Requirements(sdk2.Task.Requirements):
        cores = 1
        ram = 2048
        environments = (environments.PipEnvironment('yandex-yt', use_wheel=True),)

        class Caches(sdk2.Requirements.Caches):
            pass

    def _configured_yt(self):
        '''
            Import ``yt.wrapper``, point its global config at the task's
            cluster and token, and return the module.

            The import is local because the package is installed through
            the task's ``environments`` requirement.
        '''
        import yt.wrapper as yt
        yt.config['token'] = self.yt_token
        yt.config['proxy']['url'] = self.Parameters.cluster
        return yt

    @staticmethod
    def _log_response(resp):
        '''
            Log body and status code of a Ferryman response.

            ``resp`` may be None when the request itself failed before any
            response was received.
        '''
        if resp is None:
            logging.info('No response was received from Ferryman')
            return
        logging.info('Ferryman response text: {resp_text}'.format(resp_text=resp.text))
        logging.info('Ferryman response code: {resp_code}'.format(resp_code=resp.status_code))

    def _batch_status_url(self, batch):
        '''Build the Ferryman ``get-batch-status`` URL for ``batch``.'''
        return '{url}/{uri}?batch={batch}'.format(url=self.Parameters.ferryman_url,
                                                  uri='get-batch-status',
                                                  batch=batch)

    def get_table_timestamp(self):
        '''
            Return the current Unix time as whole seconds multiplied
            by 1000000.
        '''
        # time.time() avoids the naive local-time round trip
        # (datetime.now() -> mktime), which is ambiguous around DST changes.
        return int(time.time()) * 1000000

    def get_row_count(self, path):
        '''Return the ``@row_count`` attribute of the YT node at ``path``.'''
        yt = self._configured_yt()
        return yt.get('{tbl}/@row_count'.format(tbl=path))

    @retry(tries=5, delay=10)
    def get_snippets_count_from_saas(self):
        '''
            Return the row count Ferryman reports for the task's namespace.

            Returns 0 when the namespace is absent from the response and
            None when the request or the parsing of the response failed.
        '''
        url = '{url}/{uri}'.format(
            url=self.Parameters.ferryman_url,
            uri='get_namespaces',
        )
        # NOTE(review): the exception is swallowed here, so `retry` can only
        # help if it also retries on a None result - confirm against the
        # helper's implementation before changing this.
        try:
            data = requests.get(url).json()
            row_counts = {}
            # The response maps arbitrary keys to lists of namespace records.
            for entries in data.values():
                for entry in entries:
                    row_counts[entry['namespace']] = entry['rowCount']
            return row_counts.get(self.Parameters.namespace, 0)
        except Exception:
            logging.exception('Failed to get rowCount for {} from SaaS'.format(self.Parameters.namespace))
            return None

    def get_critical_row_count_change(self, params):
        '''
            Tell whether the table row count differs from the SaaS row count
            by CRITICAL_ROW_COUNT_CHANGE_PERC percent or more.

            The check is skipped (False is returned) when ``params`` has a
            truthy 'ignore_row_count_change', when SaaS reports 0 rows
            (new snippet), or when either count is unavailable or zero.
            The counts and the computed percentage are stored in Context.
        '''
        if params.get('ignore_row_count_change'):
            return False
        self.Context.saas_row_count = self.get_snippets_count_from_saas()
        if self.Context.saas_row_count == 0:     # New snippet
            return False
        self.Context.current_row_count = self.get_row_count(self.Parameters.table_path)
        if not (self.Context.saas_row_count and self.Context.current_row_count):
            return False
        changes = self.Context.saas_row_count - self.Context.current_row_count
        self.Context.changed_perc = changes * 100.0 / self.Context.saas_row_count
        return abs(self.Context.changed_perc) >= CRITICAL_ROW_COUNT_CHANGE_PERC

    def non_empty_table_exists(self, path):
        '''Return True if ``path`` exists in YT and has at least one row.'''
        yt = self._configured_yt()
        if not yt.exists(path):
            return False
        return int(self.get_row_count(path)) > 0

    def create_empty_table(self):
        '''
            Write an empty table to ``//tmp/erase_<namespace>`` and return
            its path.
        '''
        import yt.wrapper as yt
        yt_config = {
            'proxy': {'url': self.Parameters.cluster},
            'token': self.yt_token,
        }
        client = yt.YtClient(config=yt_config)
        table_path = '//tmp/erase_{}'.format(self.Parameters.namespace)
        client.write_table(table_path, [], raw=False)
        return table_path

    def push_to_ferryman(self):
        '''
            Ask Ferryman to load a full table into the task's namespace.

            With ``erase`` set, a freshly created empty table is pushed;
            otherwise ``table_path`` is pushed.

            Returns the value of 'batch' from the Ferryman response, or None
            when the request failed or the response could not be parsed.

            Raises SandboxTaskFailureError when ``table_path`` is missing
            or empty (and ``erase`` is not set).
        '''
        if self.Parameters.erase:
            source_path = self.create_empty_table()
        elif self.non_empty_table_exists(self.Parameters.table_path):
            source_path = self.Parameters.table_path
        else:
            msg = '{table} not exists or is empty'.format(table=self.Parameters.table_path)
            raise SandboxTaskFailureError(msg)
        tables = [{
            'Namespace': str(self.Parameters.namespace),
            # Short cluster name, e.g. 'hahn' from 'hahn.yt.yandex.net'.
            'Cluster': self.Parameters.cluster.split('.')[0],
            'Path': str(source_path),
            'Timestamp': self.get_table_timestamp(),
        }]
        logging.info('Trying to push: {data}'.format(data=tables))
        url = '{url}/{uri}?tables={data}&delta=false'.format(
            url=self.Parameters.ferryman_url,
            uri='add-full-tables',
            data=json.dumps(tables),
        )
        # Stays None if the request itself raises, so that the error handler
        # below never touches an unbound name.
        resp = None
        try:
            resp = requests.get(url)
            batch = resp.json().get('batch')
        except Exception:
            logging.exception('Failed to push table:  {url}'.format(url=url))
            self._log_response(resp)
            return None
        logging.info('Ferryman URL:  {url}'.format(url=url))
        self._log_response(resp)
        return batch

    def get_batch_status(self, batch):
        '''
            Return the 'status' Ferryman reports for ``batch``.

            A missing batch id is reported as 'error' (and also stored in
            ``Context.status``). None is returned when the response could
            not be parsed.
        '''
        if not batch:
            # Returned as well as stored: callers assign the return value
            # to Context.status, which would otherwise wipe the 'error'.
            self.Context.status = 'error'
            return 'error'
        url = self._batch_status_url(batch)
        resp = requests.get(url)
        try:
            return resp.json().get('status')
        except Exception:
            logging.exception('Failed to get status from  {url}'.format(url=url))
            self._log_response(resp)
            return None

    def get_batch_error(self, batch):
        '''
            Return the 'error' Ferryman reports for ``batch``.

            Returns None when the batch is not in the 'error' status and a
            fixed explanation when there is no batch id to ask about.

            Raises SandboxTaskFailureError when the response could not be
            parsed.
        '''
        if not batch:
            return 'Ferryman did not return a batch id'
        url = self._batch_status_url(batch)
        resp = requests.get(url)
        try:
            data = resp.json()
            if data.get('status') == 'error':
                return data.get('error')
        except Exception:
            logging.exception('Failed to get status from  {url}'.format(url=url))
            self._log_response(resp)
            raise SandboxTaskFailureError('Failed to get batch status')
        return None

    def set_tasks_tags(self, params):
        '''
            Add the upper-cased snippet name to the task tags.

            The name is taken from ``params['snippet_name']`` and falls back
            to the ``table_name`` parameter. Failures are logged and ignored.
        '''
        try:
            snippet_name = params.get('snippet_name') or self.Parameters.table_name
            tags = set(self.Parameters.tags)
            tags.add(snippet_name.upper())
            self.Parameters.tags = list(tags)
        except Exception as err:
            logging.info('Failed to set tags. Details %s' % err)

    def on_execute(self):
        '''
            Run the upload: validate the row count, push the table and poll
            the batch until it reaches the target status.

            Raises SandboxTaskFailureError when the row count changed too
            much or when the push failed more than three times, and
            sdk2.WaitTime while the batch is still being processed.
        '''
        self.Context.upload_report = '\nUpload task: https://sandbox.yandex-team.ru/task/{tid}/view\n'.format(tid=self.id)
        if not self.Context.RESTART_COUNT:
            self.Context.RESTART_COUNT = 0
        self.yt_token = sdk2.Vault.data('GEOMETA-SEARCH', 'yt-token')
        params = json.loads(self.Parameters.snippet_params)
        self.set_tasks_tags(params)

        ferryman_url = self.Parameters.ferryman_url
        if 'geo-fast-export' in ferryman_url or 'geo-recom-profiles' in ferryman_url:
            target_status = 'final'
        else:
            target_status = 'searchable'

        if not self.Parameters.load_anyway and self.get_critical_row_count_change(params):
            message = ('{table} row count ({current_count}) changed more than '
                       '{crit} percents compared to number of snippets '
                       'in SaaS ({saas_count})').format(table=self.Parameters.table_path,
                                                        crit=CRITICAL_ROW_COUNT_CHANGE_PERC,
                                                        current_count=self.Context.current_row_count,
                                                        saas_count=self.Context.saas_row_count)
            raise SandboxTaskFailureError(message)

        while (
            (
                not self.Context.batch or
                self.Context.status != target_status
            ) and
            self.Context.RESTART_COUNT <= 3
        ):
            if not self.Context.batch:
                self.Context.batch = self.push_to_ferryman()
                if self.Context.batch:
                    # Pause before the first status request. Skipped after a
                    # failed push: repeated failures would otherwise spend
                    # the whole kill_timeout sleeping.
                    time.sleep(60)
            self.Context.status = self.get_batch_status(self.Context.batch)
            if self.Context.status == 'error':
                error_details = self.get_batch_error(self.Context.batch)
                logging.info('Batch {id} failed because of {err}'.format(id=self.Context.batch, err=error_details))
                logging.info('Task will be restarted')
                # Keep the failed batch for the final report: the batch id
                # itself is cleared to trigger a new push.
                self.Context.failed_batch = self.Context.batch
                self.Context.failed_batch_error = error_details
                self.Context.batch = ''
                self.Context.RESTART_COUNT += 1
            elif self.Context.status != target_status:
                msg = 'Batch {batch_id} status {status}'.format(batch_id=self.Context.batch,
                                                                status=self.Context.status)
                logging.info(msg)
                # NOTE(review): presumably the task is resumed by running
                # on_execute again, which is why progress lives in Context -
                # confirm against the Sandbox documentation.
                raise sdk2.WaitTime(300)

        if self.Context.status == 'error':
            failed_batch = self.Context.failed_batch or self.Context.batch
            details = self.Context.failed_batch_error or self.get_batch_error(failed_batch)
            msg = 'Batch {batch_id} failed. Details from Ferryman:\n{details}'
            self.Context.upload_report += msg.format(batch_id=failed_batch,
                                                     details=details)
            raise SandboxTaskFailureError(self.Context.upload_report)
        elif self.Context.status == 'searchable':
            msg = 'Batch {batch_id} became searchable in SaaS-KV at {now}'
            self.Context.upload_report += msg.format(batch_id=self.Context.batch,
                                                     now=datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))