# -*- coding: utf-8 -*-

import os
import json
import logging
import datetime
from time import sleep
from sandbox import sdk2
import sandbox.common.types.task as ctt
from sandbox.sdk2.vcs.svn import Arcadia
from sandbox.sandboxsdk import environments
import sandbox.common.types.notification as ctn
from sandbox.sandboxsdk.errors import SandboxTaskFailureError
from sandbox.projects.geosearch.snippets.AddrsSnippetsTask import AddrsSnippetsTask


# Path of an Altay table (temporary one, judging by the name); it is not
# referenced anywhere in the visible part of this module.
ALTAY_TABLE = '//home/sprav/providers/task_manager_altay_temp'
# Directory under which the per-snippet `sprav_feed` path is built in
# AddrsSnippetsTaskManager._add_params: <dir>/<table_name>/<table_name>.
SPRAV_FEEDS_DIR = '//home/altay/providers/tables'


class TaskParametersParserException(Exception):
    '''
        Error in the parameters of a snippet task config
    '''


# Sandbox task that reads a JSON config describing snippet tables, enriches
# every entry with derived paths and defaults (_add_params), starts one
# AddrsSnippetsTask subtask per entry, waits for them, mails the reports and
# finally fails if any subtask ended in a "bad" status.
class AddrsSnippetsTaskManager(sdk2.Task):
    '''
        Geosearch snippets task manager
    '''

    class Parameters(sdk2.task.Parameters):
        # Arcadia URL of the JSON config. When non-empty it is exported and
        # parsed by get_task_manager_confg and wins over the inline JSON below.
        task_manager_config = sdk2.parameters.String('Arcadia path to task manager`s config',
                                                     default_value='arcadia:/arc/trunk/arcadia/search/geo/tools/snippets_processors/task_manager.json')
        # Inline JSON config; consulted only when `task_manager_config` is empty.
        task_manager_config_text = sdk2.parameters.String('JSON with task manager`s config',
                                                          default_value='')
        # Base directory: the results, prev_results and validation_errors
        # paths of every snippet task are built under it.
        task_manager_path = sdk2.parameters.String('YT path to task manager directory',
                                                   default_value='//home/geosearch-prod/snippets')
        # When set, generation subtasks are created with run_upload_task=False.
        geobasesearch_db_sync = sdk2.parameters.Bool('Sync snippets upload with geosearch DB release',
                                                     default_value=False)
        # When set, snippet tasks without an explicit `yt_ttl` get yt_ttl=1.
        fast = sdk2.parameters.Bool('Fast task manager (just generate and upload snippets)',
                                    default_value=False)
        # Extra addressees that _get_addressees appends to notification lists.
        mail_list = sdk2.parameters.List('Mail list')

    class Requirements(sdk2.Task.Requirements):
        cores = 1  # exactly 1 core
        ram = 8192  # 8GiB or less
        # The methods of this task import `yt.wrapper` / `yt.yson` lazily,
        # so the 'yandex-yt' package is requested as a pip environment.
        environments = (environments.PipEnvironment('yandex-yt', use_wheel=True),)

        # Empty caches declaration.
        # NOTE(review): presumably tells Sandbox that no shared caches are
        # needed by this task - confirm against the Sandbox documentation.
        class Caches(sdk2.Requirements.Caches):
            pass

    def _add_params(self, table_name, params):
        '''
            Enrich the config of one snippet task with derived paths and defaults.

            `params` is updated in place and also returned. Keys written here:
            `test`, `cluster`, `pre_processing_out` (only with `pre`),
            `generating_out` (only with `post`), `processing_out`, `output_path`,
            `prev_result`, `error_log`, `saaskv_url`, `out_stats`,
            `solomon_labels`, `yt_ttl` (fast mode only), `sprav_feed` and
            `feed_original_id_field` (only with `make_sprav_feed`).

            :param table_name: snippet table name (the key of the config entry)
            :param params: dict with the task settings taken from the config
            :return: the same `params` dict
            :raises TaskParametersParserException: if `key_type` is "provider"
                but `provider_name` or `original_id_field` is missing
        '''
        import yt.yson as yson

        tm_home = self.Parameters.task_manager_path
        results_dir = os.path.join(tm_home, 'results')
        prev_results_dir = os.path.join(tm_home, 'prev_results')
        # 'fast_task_manager/' is cut out of the validation log directory only
        error_log_dir = os.path.join(tm_home, 'validation_errors').replace('fast_task_manager/', '')

        if self.Parameters.task_manager_config.endswith('task_manager_test.json'):
            params['test'] = True
        params.setdefault('cluster', 'hahn.yt.yandex.net')
        test_suffix = '_test' if params.get('test') else ''

        def stage_table(template):
            # Per-run intermediate table: the name carries this task's id
            return template.format(tm_home=tm_home,
                                   table_name=table_name,
                                   tid=self.id,
                                   test=test_suffix)

        result_table = '{results_dir}/{table_name}'.format(results_dir=results_dir, table_name=table_name)

        if params.get('pre'):
            params['pre_processing_out'] = stage_table('{tm_home}/geosearch-snippets-{table_name}-{tid}-pre{test}')
        if params.get('post'):
            params['generating_out'] = stage_table('{tm_home}/gen_geosearch-snippets-{table_name}-{tid}{test}')
        if params.get('direct_reduce'):    # Don`t merge generated table with altay
            processing_out = result_table
        else:
            processing_out = stage_table('{tm_home}/geosearch-snippets-{table_name}-{tid}-finished{test}')

        # A missing `table_schema` means "use the default schema"; an explicitly
        # empty (or null) one means "write the tables without a schema attribute".
        schema = ''
        table_schema = params.get('table_schema')
        if 'table_schema' not in params or table_schema:
            default_schema = [
                {'name': 'Url', 'type': 'any'},
                {'name': params.get('snippet_name'), 'type': 'any'}
            ]
            yson_schema = yson.YsonList(table_schema or default_schema)
            yson_schema.attributes['strict'] = False
            schema = '<schema={schema}>'.format(schema=yson.dumps(yson_schema))

        params['processing_out'] = '{schema}{output}'.format(schema=schema, output=processing_out)
        params['output_path'] = '{schema}{output}'.format(schema=schema, output=result_table)
        params['prev_result'] = '{prev_results}/{table_name}'.format(prev_results=prev_results_dir, table_name=table_name)
        params['error_log'] = '{error_log_dir}/{table_name}{test}'.format(error_log_dir=error_log_dir,
                                                                          table_name=table_name,
                                                                          test=test_suffix)

        if params.get('test'):
            # Test runs always go to the testing endpoint, whatever the config says
            params['saaskv_url'] = 'http://geosnippets-t.ferryman.n.yandex-team.ru/'
        else:
            params['test'] = False
        if not params.get('saaskv_url'):
            params['saaskv_url'] = 'http://geo-snippets.ferryman.n.yandex-team.ru/'
        if not params.get('out_stats'):
            params['out_stats'] = 'solomon_stats.json'
        if not params.get('solomon_labels'):
            params['solomon_labels'] = {'project': 'geosearch_snippents',
                                        'cluster': '0',
                                        'service': 'push'}

        # Explicit validation instead of `assert`: asserts vanish under `python -O`
        errors = []
        if params.get('key_type') == 'provider':
            for required in ('provider_name', 'original_id_field'):
                if not params.get(required):
                    errors.append('Parameter "{name}" is required for "provider" key_type'.format(name=required))
        if errors:
            raise TaskParametersParserException('\n'.join(errors))

        if self.Parameters.fast and not params.get('yt_ttl'):   # Store tables from fast task manager for 24 hours
            params['yt_ttl'] = 1
        if params.get('make_sprav_feed'):
            params.update({
                'sprav_feed': os.path.join(SPRAV_FEEDS_DIR, table_name, table_name),
                'feed_original_id_field': params.get("original_id_field") or params.get("permalink_field")
            })
        return params

    def get_task_manager_confg(self):
        if self.Parameters.task_manager_config:
            arcadia_path, script_name = os.path.split(self.Parameters.task_manager_config)
            checkout_path = Arcadia.export(arcadia_path, './config')
            json_file = os.path.join(checkout_path, script_name)
            return json.load(open(json_file))
        if self.Parameters.task_manager_config_text:
            return json.loads(self.Parameters.task_manager_config_text)

    def get_tasks(self):
        return [{'table_name': key, 'params': self._add_params(key, value)} for key, value in self.Context.configs.iteritems()]

    def _check_subtasks(self):
        logging.info('Checking subtasks')
        err_list = []
        for task in self.find():
            if str(task.status) in self.bad_statuses:
                err_list.append('Task #{task_id} {description} failed'.format(task_id=task.id,
                                                                              description=task.Parameters.description))
        logging.info('Error messages: %s' % err_list)
        if err_list:
            raise SandboxTaskFailureError('\n'.join(err_list))

    def _yt_path_exists(self, path):
        '''
            Tell whether `path` exists, asking the hahn proxy with `self.yt_token`.

            The token and the proxy url are written to the module-level
            `yt.wrapper` config before the request.
        '''
        import yt.wrapper as yt_client
        yt_client.config['token'] = self.yt_token
        proxy_settings = yt_client.config['proxy']
        proxy_settings['url'] = 'hahn.yt.yandex.net'
        found = yt_client.exists(path)
        return found

    def _get_table_timedelta(self, table_path):
        '''
            Return the age of `table_path` in whole days.

            The age is the difference between the current UTC time and the
            `modification_time` attribute of the node.

            :param table_path: path of the table on the hahn cluster
            :return: number of days (int); 0 if the attribute could not be
                read or parsed (best effort, the failure is only logged)
        '''
        import yt.wrapper as yt
        yt.config['token'] = self.yt_token
        yt.config['proxy']['url'] = 'hahn.yt.yandex.net'
        try:
            mtime = yt.get_attribute(table_path, 'modification_time')
            modified_at = datetime.datetime.strptime(mtime,
                                                     '%Y-%m-%dT%H:%M:%S.%fZ')
        except Exception:
            logging.warning('Failed to get modification time of %s', table_path, exc_info=True)
            return 0
        # The timestamp ends with 'Z' (UTC), so it must be compared with UTC
        # "now"; local time would shift the age by the host's UTC offset.
        return (datetime.datetime.utcnow() - modified_at).days

    def get_input_tables_mtime(self):
        self.Context.mtimes = {}
        for task in self.Context.snippet_tasks:
            params = task.get('params')
            input_path = params.get('input_table')
            self.Context.mtimes.update({task.get('table_name'):
                                        self._get_table_timedelta(input_path)})

    def launch_snippet_generation(self):
        '''
            Start one AddrsSnippetsTask subtask per snippet task and wait for all.

            Ids of the started subtasks are collected in
            `self.Context.gen_snippets_tasks`.

            :raises sdk2.WaitTask: always, to wait until every subtask reaches
                a FINISH or BREAK status
        '''
        self.Context.gen_snippets_tasks = []
        subtask_cls = sdk2.Task[AddrsSnippetsTask.type]
        for snippet in self.Context.snippet_tasks:
            snippet_params = snippet.get('params')
            title = '{descr}: generating "{snippet}"'.format(descr=self.Parameters.description,
                                                             snippet=snippet.get('table_name'))
            # Keyword order matters: the task is serialized before
            # _sb_notifications gets hold of the `notify` list.
            subtask = subtask_cls(
                self,
                owner=self.owner,
                description=title,
                kill_timeout=snippet_params.get('sb_task_kill_timeout') or 7200,
                cluster=snippet_params.get('cluster', 'hahn.yt.yandex.net'),
                snippet_task=json.dumps(snippet),
                run_upload_task=(not self.Parameters.geobasesearch_db_sync),
                notifications=self._sb_notifications(snippet_params.get('notify', []))
            )
            subtask.enqueue()
            self.Context.gen_snippets_tasks.append(subtask.id)
        awaited_statuses = ctt.Status.Group.FINISH | ctt.Status.Group.BREAK
        raise sdk2.WaitTask(self.Context.gen_snippets_tasks,
                            awaited_statuses,
                            wait_all=True)

    def _get_report_from_stage(self, task_list, snippet_name):
        for task_id in task_list:
            task = sdk2.Task[task_id]
            try:
                snippet_task = json.loads(task.Parameters.snippet_task)
            except Exception:
                snippet_task = {}
            if snippet_task.get('table_name') == snippet_name:
                return {'report': task.Context.report,
                        'failed': task.Context.failed,
                        'validation_errors': task.Context.validation_errors}

    def validation_errors_report(self, snippet_name):
        '''
            Describe the state of the validation error log of a snippet task.

            :param snippet_name: table name of the snippet task
            :return: a message with a link to the `error_log` table (different
                texts for an empty and a non-empty log), '' if the table does
                not exist, None if no snippet task has this name
        '''
        import yt.wrapper as yt
        yt.config['token'] = self.yt_token
        yt.config['proxy']['url'] = 'hahn.yt.yandex.net'
        for snippet_task in self.Context.snippet_tasks:
            if snippet_task.get('table_name') != snippet_name:
                continue
            error_log = snippet_task.get('params').get('error_log')
            msg = ''
            # One existence check and one emptiness check are enough to pick
            # the message; each of them is a remote call.
            if self._yt_path_exists(error_log):
                if yt.is_empty(error_log):
                    msg = ('\nThere is no validation errors in '
                           'https://yt.yandex-team.ru/hahn/'
                           '?page=navigation&path={path}')
                else:
                    msg = ('\nGot some validation errors. Check:  '
                           'https://yt.yandex-team.ru/hahn/'
                           '?page=navigation&path={path} for details')
            return msg.format(path=error_log)
        return None

    def get_timedelta_warning(self, snippet_name):
        for task in self.Context.snippet_tasks:
            if task.get('table_name') == snippet_name:
                params = task.get('params')
                input_path = params.get('input_table')
        if self.Context.mtimes.get(snippet_name, 0) > 3:
            return '\nWARNING! Input table {table} has not been updated for 3 days\n'.format(table=input_path)
        return ''

    def get_report(self, snippet_name):
        return self._get_report_from_stage(self.Context.gen_snippets_tasks, snippet_name)

    def _get_addressees(self, addressees, y_team=True, extend_default=True):
        if extend_default:
            addressees.extend(self.Parameters.mail_list)
        if y_team:
            return {'{login}@yandex-team.ru'.format(login=login) if '@yandex-team.ru' not in login else login for login in addressees}
        return {login.replace('@yandex-team.ru', '') for login in addressees}

    def _sb_notifications(self, notify_list):
        addressees = self._get_addressees(notify_list,
                                          y_team=False,
                                          extend_default=True)
        if addressees:
            return sdk2.Notification(tuple(self.bad_statuses),
                                     addressees,
                                     ctn.Transport.EMAIL)
        return ''

    def notify(self):
        '''
            Mail the generation report of every snippet task.

            For each snippet task the report of its generation subtask is sent
            to `notify_on_error` (if the subtask reported a failure or
            validation errors and that list is set) or to `notify`; the default
            mail list is added in both cases. Sending is best effort: any
            error is logged and the next snippet task is processed.
        '''
        import smtplib
        from email.mime.text import MIMEText
        sender = 'sandbox-urgent@yandex-team.ru'
        for snippet_task in self.Context.snippet_tasks:
            try:
                snippet_type = snippet_task.get('table_name')
                snippet_params = snippet_task.get('params')
                logging.info('Trying to send report on %s snippets' % snippet_type)
                report_data = self.get_report(snippet_type)
                if report_data is None:
                    # No subtask matched: say so instead of tripping over None
                    logging.warning('No generation report found for %s snippets', snippet_type)
                    continue
                failed = report_data.get('failed') or report_data.get('validation_errors')
                if failed and snippet_params.get('notify_on_error'):
                    addressees = self._get_addressees(snippet_params.get('notify_on_error', []))
                else:
                    addressees = self._get_addressees(snippet_params.get('notify', []))
                if not addressees:
                    continue
                recipients = sorted(addressees)
                # 'plain' subtype: an empty one yields the malformed
                # "Content-Type: text/" header
                message = MIMEText(report_data.get('report'), 'plain', 'utf-8')
                message['From'] = sender
                message['To'] = ', '.join(recipients)
                message['Subject'] = 'Snippets task manager report on %s' % snippet_type
                logging.info('Trying to send message to {addressees}'.format(addressees=addressees))
                server = smtplib.SMTP('omail.yandex.ru')
                try:
                    server.sendmail(sender, recipients, message.as_string())
                    server.quit()
                finally:
                    # Releases the socket even if sendmail/quit failed
                    server.close()
                # Pause between messages.
                # NOTE(review): presumably throttling for the SMTP relay - confirm
                sleep(10)
            except Exception:
                logging.exception('Failed to send report')

    def on_execute(self):
        '''
            Task entry point.

            Loads the config, prepares the snippet tasks and helper attributes,
            records the age of the input tables, then runs two memoized
            stages - subtask launch (which waits for the subtasks) and report
            mailing - and finally checks the subtask statuses.
        '''
        self.Context.configs = self.get_task_manager_confg()

        tm_home = self.Parameters.task_manager_path
        self.base_path = tm_home
        self.tasks_table_path = os.path.join(tm_home, 'tasks')
        self.results_path = os.path.join(tm_home, 'results')
        self.prev_results_path = os.path.join(tm_home, 'prev_results')
        self.yt_token = sdk2.Vault.data('GEOMETA-SEARCH', 'yt-token')

        self.Context.snippet_tasks = self.get_tasks()
        logging.info('Tasks: %s' % self.Context.snippet_tasks)

        self.provider_ids = {snippet["params"].get("provider_id")
                             for snippet in self.Context.snippet_tasks
                             if snippet["params"].get("key_type") == "provider"}
        logging.info('Provider IDs: %s' % self.provider_ids)

        self.bad_statuses = ['FAILURE',
                             'EXCEPTION',
                             'TIMEOUT']
        self.get_input_tables_mtime()
        # True only when every snippet task is keyed by permalink
        self.with_permalink = all(snippet['params'].get('key_type') == "permalink"
                                  for snippet in self.Context.snippet_tasks)

        with self.memoize_stage.GENERATE_SNIPPETS(commit_on_entrance=False):
            self.launch_snippet_generation()
        with self.memoize_stage.NOTIFY(commit_on_entrance=False):
            self.notify()
        self._check_subtasks()
