# -*- coding: utf-8 -*-

import os
import re
import json
import shutil
import jinja2
import logging
import tarfile
import requests
import subprocess
from xml.dom import minidom
from datetime import datetime, timedelta

from sandbox import sdk2
import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt
from sandbox.sdk2.vcs.svn import Arcadia
from sandbox.sandboxsdk import environments
from sandbox.projects.yql.RunYQL2 import RunYQL2
from sandbox.sandboxsdk.paths import get_logs_folder
from sandbox.sandboxsdk.errors import SandboxTaskFailureError
from sandbox.projects.geosearch import resource_types as geotypes
from sandbox.projects.common.solomon import push_to_solomon_v2, create_sensors_from_file
from sandbox.projects.geosearch.snippets.AddrsSnipippetsPushToFerryman import AddrsSnipippetsPushToFerryman


# Fallback lifetime, in days, used by set_ttl for output YT tables when the
# snippet parameters do not provide a 'yt_ttl' value.
DEFAULT_YT_TTL = 2      # Days


def get_row_count(task, table):
    '''
        Return the 'row_count' attribute of a YT table as an int,
        or 0 when the table does not exist.

        The global yt.wrapper configuration is pointed at
        task.Parameters.cluster and authorised with task.yt_token.
    '''
    import yt.wrapper as yt_client

    yt_client.config['token'] = task.yt_token
    yt_client.config['proxy']['url'] = task.Parameters.cluster
    if not yt_client.exists(table):
        return 0
    row_count = yt_client.get_attribute(table, 'row_count')
    return int(row_count)


def post_stats(params):
    '''
        Push sensors from the statistics file to Solomon (best effort).

        Nothing is pushed when params has no 'solomon_labels', when
        params['test'] is truthy, or when the 'out_stats' file is not set
        or does not exist. Errors raised while pushing are logged and
        swallowed, so this function never fails the caller.
    '''
    labels = params.get('solomon_labels')
    if not labels or params.get('test'):
        return
    stats_file = params.get('out_stats')
    # Check for a missing path explicitly: os.path.exists(None) raises, which
    # used to be reported as a Solomon push failure.
    if not stats_file or not os.path.exists(stats_file):
        logging.info('No stats file to push to Solomon: %s', stats_file)
        return
    try:
        # NOTE(review): the vault key ends with a space - looks like a typo,
        # confirm against the actual Vault record before changing it.
        token = sdk2.Vault.data('GEOSEARCH_PROD', 'GEOSEARCH_SOLOMON_TOKEN ')
        push_to_solomon_v2(token,
                           params=labels,
                           sensors=create_sensors_from_file(stats_file))
    except Exception:
        # Deliberately best-effort, but keep the traceback in the log.
        logging.exception('Failed to push data to Solomon')


def get_ttl(number_of_days):
    '''
        Return the local time number_of_days from now, formatted with
        datetime.isoformat().
    '''
    expires_at = datetime.now() + timedelta(days=number_of_days)
    return expires_at.isoformat()


class AddrsSnippetsTask(sdk2.Task):
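    '''
        Generates addrs snippets: runs the optional "pre" stage, the
        "generation" stage and the optional "post" stage, adds service data,
        optionally uploads the resulting table to Ferryman, pushes statistics
        to Solomon and sets a TTL on the output YT tables.
    '''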

    class Parameters(sdk2.task.Parameters):
        # YT cluster proxy used for every YT call made by the task.
        cluster = sdk2.parameters.String(
            'YT cluster',
            default='hahn.yt.yandex.net',
        )
        # JSON text describing the snippet task; it is parsed with json.loads
        # in on_execute and upload_snippets.
        snippet_task = sdk2.parameters.String(
            'Snippet task dict',
            required=True,
        )
        # When true, on_execute starts the Ferryman upload subtask.
        run_upload_task = sdk2.parameters.Bool(
            'Run upload task after generation',
            default=True,
        )
        # Passed through to the Ferryman upload subtask as load_anyway.
        # NOTE(review): this parameter uses `default_value=` while the others
        # use `default=` - confirm both spellings are accepted by sdk2.
        load_anyway = sdk2.parameters.Bool(
            'Ignore too big row count change',
            default_value=False,
        )

    class Requirements(sdk2.Task.Requirements):
        cores = 1       # exactly 1 core
        ram = 8192      # 8GiB or less
        dns = ctm.DnsType.DNS64
        # Pip packages requested for the task's own interpreter; the scripts
        # started by run_py get a separate virtual environment.
        environments = (
            environments.PipEnvironment('yandex-yt', use_wheel=True),
            environments.PipEnvironment('yandex-yt-yson-bindings', use_wheel=True),
            environments.PipEnvironment('hashlib'),
            environments.PipEnvironment('lxml', use_wheel=True),
        )

        # Intentionally empty.
        # NOTE(review): presumably declares that no shared caches are needed - confirm.
        class Caches(sdk2.Requirements.Caches):
            pass

    def _yt_path_exists(self, path):
        '''
            Tell whether path exists on the cluster from Parameters.cluster.
            Configures the global yt.wrapper client with self.yt_token.
        '''
        import yt.wrapper as yt_client

        yt_client.config['token'] = self.yt_token
        yt_client.config['proxy']['url'] = self.Parameters.cluster
        path_exists = yt_client.exists(path)
        return path_exists

    def _get_generator_path(self, generator, test):
        '''
            Build the path of a generator script inside the exported
            task manager directory.

            For test runs the 'generators_test' variant is preferred when it
            exists on disk; otherwise the 'generators' path is returned.
        '''
        base_dir = self.Context.task_manager_dir
        if test:
            candidate = os.path.join(base_dir, 'generators_test', generator)
            if os.path.exists(candidate):
                return candidate
        # Production path, also the fallback when no test variant exists.
        return os.path.join(base_dir, 'generators', generator)

    @staticmethod
    def _download_resource_get_file_path(resource):
        '''
            Return a local file path for a Sandbox resource.

            When the resource data path is a directory, the first regular
            file met by os.walk is returned (None if there is none);
            otherwise the resource path itself is returned.
        '''
        local_path = str(sdk2.ResourceData(resource).path)
        if not os.path.isdir(local_path):
            return local_path
        for dir_path, _dir_names, file_names in os.walk(local_path):
            for file_name in file_names:
                candidate = os.path.join(dir_path, file_name)
                if os.path.isfile(candidate):
                    return candidate
        return None

    def _get_sb_resource(self, resource_dict):
        '''
            Resolve a resource description to a local file path.

            resource_dict may contain 'id' (a resource id) or 'type' with
            optional 'attrs' used to search for a resource. When neither key
            is set, resource_dict itself is returned unchanged.

            Raises SandboxTaskFailureError when the search by type and attrs
            finds nothing.
        '''
        resource_id = resource_dict.get('id')
        resource_type = resource_dict.get('type')
        if resource_id:
            resource = sdk2.Resource[resource_id]
        elif resource_type:
            attrs = resource_dict.get('attrs', {})
            resource = sdk2.Resource[resource_type].find(attrs=attrs).first()
            # first() yields a falsy value when nothing matched; fail with a
            # clear message instead of passing it on to ResourceData.
            if not resource:
                raise SandboxTaskFailureError(
                    'Resource of type {type} with attrs {attrs} not found'.format(type=resource_type,
                                                                                 attrs=attrs))
        else:
            # Not a resource description: hand the value back as is.
            return resource_dict
        return self._download_resource_get_file_path(resource)

    def make_stat_file(self, params):
        '''
            Write row-count statistics to the file named by params['out_stats'].

            The JSON file holds two counters keyed by the task table name:
            '<table_name>_generated' (rows of Context.output_table) and
            '<table_name>_validation_errors' (rows of params['error_log']).
        '''
        generated = get_row_count(self, self.Context.output_table)
        validation_errors = get_row_count(self, params.get('error_log', '//tmp/not_existing_table'))
        table_name = self.task.get('table_name')
        stats = {
            '{snpt}_generated'.format(snpt=table_name): generated,
            '{snpt}_validation_errors'.format(snpt=table_name): validation_errors,
        }
        logging.info('Solomon stats: %s' % stats)
        # Close the file deterministically: it is read back by post_stats
        # right after this method returns.
        with open(params.get('out_stats'), 'w') as stats_file:
            json.dump(stats,
                      stats_file,
                      sort_keys=True,
                      indent=2)

    def get_data_resources(self, task_params):
        '''
            Replace resource descriptions in task_params with local paths.

            Every dict value, except those under the generator stage keys
            ('pre', 'generation', 'post'), is resolved with _get_sb_resource.
            The value is kept when the resolution returns nothing.
            task_params is modified in place and also returned.
        '''
        generator_stages = ('pre', 'generation', 'post')
        # Iterate over a snapshot because the dict is updated inside the loop.
        for key, value in list(task_params.items()):
            if key in generator_stages or not isinstance(value, dict):
                continue
            task_params[key] = self._get_sb_resource(value) or value
        return task_params

    def get_generators(self, task_params):
        '''
            Map each stage ('pre', 'generation', 'post') to the script to run.

            A stage value may be a '.py' / '.yql' file name (resolved with
            _get_generator_path) or a dict describing a Sandbox resource
            (resolved with _get_sb_resource). Strings with any other suffix
            are ignored. Stages without a value fall back to the defaults:
            only 'generation' has one, 'snippet_processor.py'.
        '''
        fallback_scripts = {'generation': 'snippet_processor.py',
                            'pre': '',
                            'post': ''}
        is_test = task_params.get('test')
        resolved = {}
        for stage in ('pre', 'generation', 'post'):
            generator = task_params.get(stage)
            if not generator:
                fallback = fallback_scripts.get(stage)
                if fallback:
                    resolved[stage] = self._get_generator_path(fallback, is_test)
                continue
            if type(generator) in (str, unicode):   # Python-script or YQL-query
                if generator.endswith(('.py', '.yql')):
                    resolved[stage] = self._get_generator_path(generator, is_test)
            elif type(generator) is dict:   # Sandbox resource
                resolved[stage] = self._get_sb_resource(generator)
        return resolved

    def download_external_xsd_schemas(self):
        '''
            Download the external XSD schemas and publish them as a resource.

            Files are stored in './external_xsd'. A download failure is
            logged and skipped (best effort); no file is written for a
            failed download. The directory is then published as a
            GEOSEARCH_XSD_SCHEMAS resource with external=True and ttl=45.

            Returns the directory name.
        '''
        external_res = {
            'xAL.xsd': 'https://docs.oasis-open.org/election/external/xAL.xsd',
            'xml.xsd': 'https://www.w3.org/2001/xml.xsd'
        }
        dir_name = './external_xsd'
        if not os.path.exists(dir_name):
            os.mkdir(dir_name)
        for name, url in external_res.items():
            local_path = os.path.join(dir_name, name)
            try:
                # A timeout keeps the task from hanging on an unresponsive host.
                response = requests.get(url, timeout=60)
                if not response.ok:
                    raise Exception('Failed: {}'.format(response.status_code))
                # Open the file only after a successful download, so that a
                # failure does not leave an empty schema in the resource.
                # response.content is bytes, hence binary mode.
                with open(local_path, 'wb') as out:
                    out.write(response.content)
            except Exception as err:
                logging.warning('Failed to download {} from {}: {}'.format(name, url, err))
        resource = sdk2.Resource[geotypes.GEOSEARCH_XSD_SCHEMAS]
        schema_resource = resource(self, 'External XSD schemas', dir_name)
        schema_resource.external = True
        schema_resource.ttl = 45
        schema_data = sdk2.ResourceData(schema_resource)
        schema_data.ready()
        return dir_name

    def get_external_xsd_schemas(self):
        '''
            Return a local directory with the external XSD schemas.

            An existing GEOSEARCH_XSD_SCHEMAS resource with external=True is
            reused (and its ttl set to 45) when it is at most 30 days old;
            otherwise the schemas are downloaded again.
        '''
        cached = sdk2.Resource[geotypes.GEOSEARCH_XSD_SCHEMAS].find(attrs={'external': True}).first()
        if not cached:
            return self.download_external_xsd_schemas()
        # The creation time is made naive to compare with datetime.now().
        age = datetime.now() - cached.created.replace(tzinfo=None)
        if age.days > 30:
            return self.download_external_xsd_schemas()
        cached.ttl = 45
        return str(sdk2.ResourceData(cached).path)

    def patch_xsd_schema(self, path):
        '''
            Make the schema copy under path use local external schemas.

            Copies the external XSD files into 'ymaps/ymaps/1.x', then
            rewrites 'ymaps/snippets/examples/all.xsd' in place: an xs:import
            of the XML namespace pointing at the local xml.xsd is appended to
            the root element, and the xAL import is redirected to the local
            xAL.xsd.
        '''
        all_xsd = os.path.join(path, 'ymaps/snippets/examples/all.xsd')
        ymaps_1x = os.path.join(path, 'ymaps/ymaps/1.x')
        external_schemas = self.get_external_xsd_schemas()
        for root, _sub_dirs, filenames in os.walk(external_schemas):
            for filename in filenames:
                shutil.copy(
                    os.path.join(root, filename),
                    os.path.join(ymaps_1x, filename)
                )
        with open(all_xsd) as f:
            schema_doc = minidom.parseString(f.read())
        xml_xsd = schema_doc.createElement('xs:import')
        xml_xsd.setAttribute('namespace',
                             'http://www.w3.org/XML/1998/namespace')
        # The value used to end with a stray '"', which was later removed by
        # stripping every '&quot;' from the serialised document - that also
        # dropped legitimate quotes. Write the clean value instead.
        xml_xsd.setAttribute('schemaLocation', '../../ymaps/1.x/xml.xsd')
        # documentElement is the root element even when the document starts
        # with a comment or another non-element node.
        schema_doc.documentElement.appendChild(xml_xsd)
        for import_tag in schema_doc.getElementsByTagName('xs:import'):
            if import_tag.getAttribute('namespace') == 'urn:oasis:names:tc:ciq:xsdschema:xAL:2.0':
                import_tag.setAttribute('schemaLocation',
                                        '../../ymaps/1.x/xAL.xsd')
        with open(all_xsd, 'w') as f:
            f.write(schema_doc.toxml())

    def get_xsd_schema(self):
        '''
            Build './schemas.tar.gz' with the patched XSD schemas.

            The exported schemas are copied to '<task path>/schemas', patched
            with patch_xsd_schema and packed under the 'schemas' archive
            name, skipping every path that contains '.svn'.

            Returns the archive file name.
        '''
        schemas_copy = os.path.join(str(self.path), 'schemas')
        shutil.copytree(self.Context.xsd_schema_path, schemas_copy)
        self.patch_xsd_schema(schemas_copy)

        def _is_svn_entry(file_name):
            return '.svn' in file_name

        archive_name = './schemas.tar.gz'
        # NOTE(review): `exclude` is deprecated in favour of `filter` in newer
        # tarfile versions - revisit when moving off Python 2.
        with tarfile.open(archive_name, 'w:gz') as archive:
            archive.add(schemas_copy,
                        arcname=os.path.basename(schemas_copy),
                        exclude=_is_svn_entry)
        return archive_name

    def get_report_prefix(self):
        '''
            Return the report header: the table name followed by a link
            to this task in the Sandbox UI.
        '''
        header = 'Generating {tbl_name}: https://sandbox.yandex-team.ru/task/{task_id}/view'
        return header.format(task_id=self.id,
                             tbl_name=self.Context.table_name)

    def get_stats(self, params):
        '''
            Return a text summary with the generated and invalid row counts.

            Sets Context.validation_errors to True when the error log table
            (params['error_log']) has rows. The summary links to the output
            table and the error log in the YT UI.
        '''
        errors_num = get_row_count(self, params.get('error_log', '//tmp/not_existing_table'))
        if errors_num > 0:
            self.Context.validation_errors = True
        snippets_num = get_row_count(self, self.Context.output_table)
        # The UI link uses the short cluster name (the first host label).
        short_cluster = self.Parameters.cluster.split('.')[0]
        summary_lines = [
            'Generated {snippets_num} snippets\n',
            'https://yt.yandex-team.ru/{cluster}/navigation?path={tbl_path}\n',
            'Got {errors_num} validation errors\n',
            'https://yt.yandex-team.ru/{cluster}/navigation?path={err_path}\n\n',
        ]
        return ''.join(summary_lines).format(snippets_num=snippets_num,
                                             cluster=short_cluster,
                                             tbl_path=self.Context.table_path_without_schema,
                                             errors_num=errors_num,
                                             err_path=params.get('error_log'))

    def report(self, additional_line=''):
        '''
            Append the log of the last executed stage to Context.report.

            The appended chunk is the report prefix, additional_line on its
            own line and the content of self.log_file with the known YSON
            bindings import warning removed.
        '''
        if not self.Context.report:
            self.Context.report = ''
        header = self.get_report_prefix()
        separator = '\n%s\n' % additional_line
        yson_warning = re.compile(r'Warning! Failed to import YSON bindings:(.*)undefined symbol: PyUnicodeUCS4_AsUTF8String\n', re.MULTILINE)
        with open(self.log_file) as logfile:
            log_text = yson_warning.sub('', logfile.read())
        self.Context.report += '%s%s%s' % (header, separator, log_text)

    def run_py(self, script, params, stage):
        '''
            Run a Python generator script inside a fresh virtual environment.

            script: path of the script to run.
            params: JSON text passed to the script as '--parameters'
                    (any other value is serialised with json.dumps).
            stage:  stage name, used for the '<stage>.log' file name.

            The output of the script goes to self.log_file. On failure the
            log is added to the report, Context.failed is set and
            SandboxTaskFailureError is raised.
        '''
        with environments.VirtualEnvironment(use_system=False) as venv:
            proc_env = os.environ.copy()
            proc_env['YT_TOKEN'] = self.yt_token
            if not self.test_run:
                proc_env['YT_POOL'] = 'geosearch_high_priority'
            if self.yql_token:
                proc_env['YQL_TOKEN'] = self.yql_token
            venv.pip('pip==9.0.1')
            venv.pip('-i https://pypi.yandex-team.ru/simple yandex-yt==0.8.49')
            venv.pip('-i https://pypi.yandex-team.ru/simple yandex-yt-yson-bindings')
            venv.pip('requests')
            venv.pip('hashlib')
            venv.pip('lxml')
            venv.pip('jsonschema==2.6.0')
            if not isinstance(params, (str, type(u''))):
                params = json.dumps(params)
            # Pass an argument list instead of a shell string: the parameters
            # reach the script verbatim, with no shell expansion of '$' or
            # backticks and no quoting tricks.
            cmd_args = ['{}'.format(venv.executable),
                        script,
                        '--cluster', self.Parameters.cluster,
                        '--parameters', params]
            cmd = ' '.join(cmd_args)
            logging.info('Running: %s' % cmd)
            self.log_file = os.path.join(get_logs_folder(),
                                         '{stage}.log'.format(stage=stage))
            try:
                with open(self.log_file, 'w') as logfile:
                    subprocess.check_call(cmd_args,
                                          env=proc_env,
                                          stdout=logfile,
                                          stderr=subprocess.STDOUT)
            except (subprocess.CalledProcessError, OSError) as err:
                # OSError covers a command that cannot be started at all,
                # which the shell used to report as a non-zero exit code.
                logging.info(('Generating {snippets_name} '
                              'snippets {stage} failed').format(snippets_name=self.Context.table_name,
                                                                stage=stage))
                logging.info('Details: %s' % err)
                self.report('FAILED')
                self.Context.failed = True
                raise SandboxTaskFailureError('"%s" failed' % cmd)

    def run_yql(self, query_template_path, params, stage):
        '''
            Run a YQL generator as a RUN_YQL_2 subtask.

            First entry for a stage: the Jinja2 template at
            query_template_path is rendered with params (JSON text), the
            subtask is enqueued, its id is stored in Context.yql_tasks and
            sdk2.WaitTask is raised.

            Later entry for the same stage: the stored subtask is checked
            and SandboxTaskFailureError is raised when it did not succeed.
        '''
        if not self.Context.yql_tasks.get(stage):
            params = json.loads(params)
            template_name = os.path.basename(query_template_path)
            generators_dir = os.path.dirname(query_template_path)
            logging.info('Generators dir: {dir}'.format(dir=generators_dir))
            env = jinja2.Environment(loader=jinja2.FileSystemLoader(generators_dir),
                                     extensions=['jinja2.ext.do'])
            query_text = env.get_template(template_name).render(params)
            yql_task_class = sdk2.Task[RunYQL2.type]
            yql_task = yql_task_class(self,
                                      query=query_text,
                                      trace_query=True,
                                      owner=self.owner,
                                      create_subtask=True,
                                      use_v1_syntax=True,
                                      publish_query=True
                                      )
            yql_task.enqueue()
            self.Context.yql_tasks.update({stage: yql_task.id})
            raise sdk2.WaitTask(self.Context.yql_tasks.values(),
                                ctt.Status.Group.FINISH | ctt.Status.Group.BREAK,
                                wait_all=True)
        else:
            yql_task = sdk2.Task[self.Context.yql_tasks.get(stage)]
            # The wait above also ends on the FINISH group, so a subtask
            # that ended with FAILURE must be rejected as well as the BREAK
            # statuses. The status is compared with a plain string, as
            # on_execute does for the Ferryman subtask.
            subtask_failed = (yql_task.status in ctt.Status.Group.BREAK or
                              yql_task.status == 'FAILURE')
            if subtask_failed:
                msg = 'Failed task RUN_YQL_2 #{tid} for {stg} stage'.format(tid=yql_task.id, stg=stage)
                raise SandboxTaskFailureError(msg)

    def run_binary(self, binary, params, stage):
        '''
            Run a generator executable.

            binary: path of the executable.
            params: JSON text passed as '--parameters' (any other value is
                    serialised with json.dumps).
            stage:  stage name, used for the '<stage>.log' file name.

            The output goes to self.log_file. On failure the log is added to
            the report, Context.failed is set and SandboxTaskFailureError is
            raised.
        '''
        proc_env = os.environ.copy()
        proc_env['YT_TOKEN'] = self.yt_token
        proc_env['YT_POOL'] = '' if self.test_run else 'geosearch_high_priority'
        if self.yql_token:
            proc_env['YQL_TOKEN'] = self.yql_token
        if not isinstance(params, (str, type(u''))):
            params = json.dumps(params)
        # Pass an argument list instead of a shell string: the parameters
        # reach the binary verbatim, with no shell expansion of '$' or
        # backticks and no quoting tricks.
        cmd_args = ['{}'.format(binary),
                    '--cluster', self.Parameters.cluster,
                    '--parameters', params]
        cmd = ' '.join(cmd_args)
        logging.info('Running: %s' % cmd)
        self.log_file = os.path.join(get_logs_folder(),
                                     '{stage}.log'.format(stage=stage))
        try:
            with open(self.log_file, 'w') as logfile:
                subprocess.check_call(cmd_args,
                                      env=proc_env,
                                      stdout=logfile,
                                      stderr=subprocess.STDOUT)
        except (subprocess.CalledProcessError, OSError) as err:
            # OSError covers a binary that cannot be started at all, which
            # the shell used to report as a non-zero exit code.
            logging.info(('Generating {snippets_name} '
                          'snippets {stage} failed').format(snippets_name=self.Context.table_name,
                                                            stage=stage))
            logging.info('Details: %s' % err)
            self.report('FAILED')
            # Same failure bookkeeping as run_py.
            self.Context.failed = True
            raise SandboxTaskFailureError('"%s" failed' % cmd)

    def run(self, generator, params, stage):
        '''
            Dispatch a generator to its runner by file suffix:
            '.py' -> run_py, '.yql' -> run_yql, anything else -> run_binary.
        '''
        if generator.endswith('.py'):
            runner = self.run_py
        elif generator.endswith('.yql'):
            runner = self.run_yql
        else:
            runner = self.run_binary
        runner(generator, params, stage)

    def pre(self, params):
        '''
            Run the 'pre' stage script with params serialised to JSON.
        '''
        self.run(self.scripts.get('pre'), json.dumps(params), 'pre')

    def generate(self, params):
        '''
            Run the 'generation' stage script with params serialised to JSON.
        '''
        self.run(self.scripts.get('generation'), json.dumps(params), 'generation')

    def post(self, params):
        '''
            Run the 'post' stage script with params serialised to JSON.
        '''
        self.run(self.scripts.get('post'), json.dumps(params), 'post')

    def add_service_data(self, params):
        '''
            Run 'add_service_data.py' (test variant for test runs, when it
            exists) with params serialised to JSON.
        '''
        service_script = self._get_generator_path('add_service_data.py', params.get('test'))
        self.run(service_script, json.dumps(params), 'add_service_data')

    def upload_snippets(self):
        '''
            Start the Ferryman upload subtask and wait for it.

            The subtask id is stored in Context.ferryman_task and
            sdk2.WaitTask is raised, so this method never returns normally.

            Raises SandboxTaskFailureError when the output table does not
            exist, because there is nothing to upload and no subtask to
            wait for.
        '''
        ferryman_task_class = sdk2.Task[AddrsSnipippetsPushToFerryman.type]
        task_params = json.loads(self.Parameters.snippet_task)
        params = task_params.get('params')
        table_path = self.Context.table_path_without_schema
        if not self._yt_path_exists(table_path):
            # Without this check the code waited for a subtask that was
            # never created (Context.ferryman_task was not set).
            raise SandboxTaskFailureError(
                'Nothing to upload to Ferryman: table {table} does not exist'.format(table=table_path))
        namespace = params.get('namespace') or task_params.get('table_name')
        ferryman_task = ferryman_task_class(self,
                                            owner=self.owner,
                                            cluster=self.Parameters.cluster,
                                            description='{descr}: uploading "{table}" to Ferryman'.format(descr=self.Parameters.description, table=task_params.get('table_name')),
                                            table_name=task_params.get('table_name'),
                                            table_path=table_path,
                                            ferryman_url=params.get('saaskv_url'),
                                            namespace=namespace,
                                            snippet_params=json.dumps(params),
                                            load_anyway=self.Parameters.load_anyway)

        ferryman_task.enqueue()
        self.Context.ferryman_task = ferryman_task.id
        raise sdk2.WaitTask([self.Context.ferryman_task],
                            ctt.Status.Group.FINISH | ctt.Status.Group.BREAK,
                            wait_all=True)

    def set_ttl(self, params, table_keys):
        '''
            Set 'expiration_time' on the existing tables named in params.

            table_keys lists the params keys that hold table paths. The
            lifetime is params['yt_ttl'] days, or DEFAULT_YT_TTL when it is
            not set.
        '''
        import yt.wrapper as yt

        client = yt.YtClient(config={'proxy': {'url': self.Parameters.cluster},
                                     'token': self.yt_token})
        ttl_days = params.get('yt_ttl') or DEFAULT_YT_TTL
        for key in table_keys:
            table = params.get(key)
            if not table:
                continue
            if client.exists(table):
                client.set_attribute(table, 'expiration_time', get_ttl(ttl_days))

    def set_table_schema(self, params):
        '''
            Apply a schema to Context.output_table when the table exists.

            The schema is params['table_schema'] or, when it is not set, two
            string columns: 'Url' and one named after params['snippet_name'].
        '''
        import yt.wrapper as yt

        client = yt.YtClient(config={'proxy': {'url': self.Parameters.cluster},
                                     'token': self.yt_token})
        fallback_schema = [
            {'name': 'Url', 'type': 'string'},
            {'name': params.get('snippet_name'), 'type': 'string'}
        ]
        schema = params.get('table_schema') or fallback_schema
        output_table = self.Context.output_table
        if not client.exists(output_table):
            return
        logging.info('Applying schema %s to %s table' % (schema, self.Context.output_table))
        client.alter_table(self.Context.output_table, schema=schema)

    def set_tasks_tags(self, params):
        '''
            Add the upper-cased snippet name (or table name) to the task
            tags, plus 'TEST' for test runs. Best effort: any error is
            logged and ignored.
        '''
        try:
            tag_source = params.get('snippet_name') or self.task.get('table_name')
            updated_tags = set(self.Parameters.tags)
            updated_tags.add(tag_source.upper())
            if self.test_run:
                updated_tags.add('TEST')
            self.Parameters.tags = list(updated_tags)
        except Exception as err:
            logging.info('Failed to set tags. Details %s' % err)

    def create_symlink(self, params):
        '''
            Link params['symlink_path'] to the output table on YT, replacing
            an existing link (force=True). Best effort: a failure is logged
            with its traceback and ignored.
        '''
        symlink_path = params.get('symlink_path')
        try:
            import yt.wrapper as yt

            yt_client = yt.YtClient(config={'proxy': {'url': self.Parameters.cluster},
                                            'token': self.yt_token})
            yt_client.link(self.Context.table_path_without_schema,
                           symlink_path,
                           force=True)
        except Exception:
            table = params.get('processing_out') or params.get('generating_out')
            logging.exception('Failed to symlink {symlink} to table {table}'.format(symlink=symlink_path,
                                                                                   table=table))

    def on_timeout(self, prev_status):
        '''
            Run the base timeout handler, then set a TTL on the output tables.
        '''
        sdk2.Task.on_timeout(self, prev_status)
        task_params = self.task.get('params')
        self.set_ttl(task_params, self.output_tables)

    def on_failure(self, prev_status):
        '''
            Run the base failure handler, then set a TTL on the output tables.
        '''
        sdk2.Task.on_failure(self, prev_status)
        task_params = self.task.get('params')
        self.set_ttl(task_params, self.output_tables)

    def on_break(self, prev_status, status):
        '''
            Set a TTL on the output tables when the task is broken.
        '''
        # NOTE(review): unlike the sibling handlers this one does not call
        # the base class implementation - confirm that is intended.
        task_params = self.task.get('params')
        self.set_ttl(task_params, self.output_tables)

    def on_stop(self):
        '''
            Run the base stop handler, then set a TTL on the output tables.
        '''
        sdk2.Task.on_stop(self)
        task_params = self.task.get('params')
        self.set_ttl(task_params, self.output_tables)

    def on_execute(self):
        '''
            Run the snippet generation pipeline.

            Exports the task manager scripts and the XSD schemas from
            Arcadia, loads tokens from the Vault, parses
            Parameters.snippet_task and runs the memoized stages in order:
            XSD schema preparation, pre-processing, generation,
            post-processing, adding service data and upload. Then the report
            and the statistics are produced and a TTL is set on the output
            tables.

            run_yql and upload_snippets raise sdk2.WaitTask, so this method
            is presumably entered again after each wait; everything outside
            the memoize_stage blocks would then run on every entry
            (TODO confirm against the sdk2 memoize_stage documentation).
        '''
        self.Context.failed = False
        self.Context.validation_errors = False
        self.Context.task_manager_dir = Arcadia.export('arcadia:/arc/trunk/arcadia/search/geo/tools/task_manager',
                                                                    './task_manager')
        self.Context.xsd_schema_path = Arcadia.export('arcadia:/arc/trunk/arcadia/maps/doc/schemas',
                                                      './schemas')
        self.yt_token = sdk2.Vault.data('GEOMETA-SEARCH', 'yt-token')
        # Subtask statuses treated as a failed Ferryman upload below.
        self.bad_statuses = ['FAILURE',
                             'EXCEPTION',
                             'TIMEOUT']
        # Keys of params that may hold output table paths; used by set_ttl.
        self.output_tables = [
            'pre_processing_out',
            'processing_out',
            'generating_out',
        ]
        # The YQL token is optional: without it YQL_TOKEN is simply not
        # exported to the generator processes.
        try:
            self.yql_token = sdk2.Vault.data('GEOMETA-SEARCH', 'YQL_TOKEN')
        except Exception:
            self.yql_token = None
            logging.warning("Failed to load yql token from vault")
        # Maps stage name -> id of the RUN_YQL_2 subtask started by run_yql.
        if not self.Context.yql_tasks:
            self.Context.yql_tasks = {}
        self.task = json.loads(self.Parameters.snippet_task)
        # Resource descriptions in params are replaced by local file paths.
        params = self.get_data_resources(self.task.get('params'))
        # NOTE(review): when neither 'processing_out' nor 'output_path' is
        # set, output_table is None and the .split() below fails - confirm
        # one of them is always provided.
        self.Context.output_table = params.get('processing_out') or params.get('output_path')
        # Keep only the part after the last '//' of the output table value.
        self.Context.table_path_without_schema = '//%s' % self.Context.output_table.split('//')[-1]
        self.test_run = params.get('test')
        self.set_tasks_tags(params)
        self.scripts = self.get_generators(params)
        logging.info('Scripts: %s' % self.scripts)
        logging.info('Processing %s' % self.task.get('table_name'))
        self.Context.table_name = self.task.get('table_name')
        self.Context.snippets_path = params.get('processing_out')
        if 'provider_id' in params:
            params['provider_id'] = str(params['provider_id'])
        if params.get('schema'):    # Use JSON schema
            params.update({'schema': os.path.join(self.Context.task_manager_dir,
                                                  params.get('schema'))})
        # The banlist path is derived from the snippet name (or table name).
        params.update({'banlist': os.path.join(self.Context.task_manager_dir,
                                               'banlists',
                                               params.get('snippet_name') or self.task.get('table_name'),
                                               'banlist')})
        # NOTE(review): 'schema' is put into params only inside this memoized
        # block; if the block is skipped on a later entry, params presumably
        # lacks the XSD archive path - confirm.
        with self.memoize_stage.GET_XSD_SCHEMA(commit_on_entrance=False):
            if params.get('validate') and not params.get('schema'):     # Use XSD schema
                params.update({'schema': self.get_xsd_schema()})
        logging.info('Mapper parameters: %s' % self.task.get('params'))
        with self.memoize_stage.PRE_PROCESSING(commit_on_entrance=False):
            if params.get('pre'):
                self.pre(params)
                self.report()
        with self.memoize_stage.GENERATE(commit_on_entrance=False):
            self.generate(params)
            self.report()
        with self.memoize_stage.POST_PROCESSING(commit_on_entrance=False):
            if params.get('post'):
                self.post(params)
                self.report()
        with self.memoize_stage.ADD_SERVICE_DATA(commit_on_entrance=False):
            self.add_service_data(params)
        # upload_snippets raises sdk2.WaitTask; the result of the subtask is
        # checked right below once the task is executed again.
        with self.memoize_stage.UPLOAD_SNIPPETS(commit_on_entrance=False):
            if self.Parameters.run_upload_task:
                self.upload_snippets()
        if self.Parameters.run_upload_task:
            ferryman_task = sdk2.Task[self.Context.ferryman_task]
            if ferryman_task.status in self.bad_statuses:
                msg = 'Failed to upload {table} to Ferryman'.format(table=ferryman_task.Parameters.table_path)
                raise SandboxTaskFailureError(msg)
            else:
                logging.info('Ferryman upload report: %s' % ferryman_task.Context.upload_report)
                if params.get('symlink_path'):
                    self.create_symlink(params)
                # NOTE(review): assumes both Context.report and the subtask's
                # upload_report are strings here - confirm.
                self.Context.report += ferryman_task.Context.upload_report
        # The statistics summary is put in front of the accumulated stage logs.
        self.Context.report = self.get_stats(params) + self.Context.report
        # NOTE(review): os.path.exists fails when 'out_stats' is not set in
        # params - confirm it is always provided.
        if not os.path.exists(params.get('out_stats')):     # Trying not to overwrite
            self.make_stat_file(params)
        post_stats(params)
        # NOTE(review): the stage name mentions the schema, but only the TTL
        # is set here; set_table_schema is not called in this method.
        with self.memoize_stage.SET_TTL_AND_SCHEMA(commit_on_entrance=False):
            self.set_ttl(params, self.output_tables)
