from sandbox import sdk2
from sandbox.projects.common.wizard.current_production import get_current_production_resource_id
from sandbox.projects.common.wizard.providers import EntitySearchProvider
from sandbox.projects.EntitySearch import resource_types as es_resource_types
from sandbox.projects.geosearch import resource_types as geotypes
from sandbox.sandboxsdk import environments
from sandbox.sandboxsdk.errors import SandboxTaskFailureError
from sandbox.sandboxsdk.paths import get_logs_folder

import os
import logging
import subprocess
from md5 import md5

# Nanny service name passed both to get_current_production_resource_id() and
# to EntitySearchProvider (as parent_nanny_service).
PARENT_SERVICE = 'sas-production-entitysearch-yp'


def get_prod_res_id(res_type, nanny_token):
    """Return the id that get_current_production_resource_id() reports for ``res_type``.

    The lookup is always made against ``PARENT_SERVICE``.

    :param res_type: resource type to look up.
    :param nanny_token: token forwarded to the lookup helper.
    """
    resource_id = get_current_production_resource_id(
        PARENT_SERVICE,
        res_type,
        nanny_token,
    )
    return resource_id


def read_to_tsv(yt_client, table_path, columns, key_prefix, cnv, file_name):
    """Append two columns of a YT table to a TSV file, one line per row.

    Each output line is ``<row[columns[0]]>\\t<key_prefix><value>``, where
    ``value`` is ``row[columns[1]]``, passed through ``cnv`` when one is given.

    :param yt_client: client object providing ``read_table``.
    :param table_path: path of the table to read.
    :param columns: sequence whose first two items name the columns to dump.
    :param key_prefix: string put in front of the second column's value.
    :param cnv: optional callable applied to the second column's value;
        a falsy value (e.g. ``None``) means "write the value as is".
    :param file_name: output file; opened in append mode so that several
        calls can accumulate into the same file.
    """
    import yt.wrapper as yt

    with open(file_name, 'a') as tsv:
        for record in yt_client.read_table(table_path, format=yt.JsonFormat()):
            left = record[columns[0]]
            right = record[columns[1]]
            if cnv:
                right = cnv(right)
            tsv.write('{}\t{}{}\n'.format(left, key_prefix, right))


def calc_hash(uri):
    """Return the hexadecimal MD5 digest of ``uri``.

    MD5 works on bytes. Unicode text is therefore encoded as UTF-8 first;
    without that, non-ASCII text fails inside ``md5()`` (implicit ASCII
    encoding on Python 2, ``TypeError`` on Python 3). Byte strings are hashed
    unchanged, so digests of byte-string and ASCII inputs stay the same.

    :param uri: byte string or unicode text to hash.
    :return: 32-character lowercase hex digest.
    """
    # type(u'') is ``unicode`` on Python 2 and ``str`` on Python 3.
    if isinstance(uri, type(u'')):
        uri = uri.encode('utf-8')
    return md5(uri).hexdigest()


class UpdateGeosearchEntities(sdk2.Task):
    """Build entity snippet data for geosearch and upload it to a YT table.

    ``on_execute`` performs three steps:
      1. dumps the configured ``ontoid`` mapping tables from YT into ``input.tsv``;
      2. runs the ``data_builder`` binary inside an ``EntitySearchProvider``
         context that is configured with the resource ids looked up for
         ``PARENT_SERVICE``;
      3. writes the builder output (``saas_dump.tsv``) to ``snippets_table``
         and merges the table's chunks.
    """

    class Parameters(sdk2.Task.Parameters):
        # YT cluster the client connects to.
        cluster = sdk2.parameters.String('YT cluster',
                                         default='hahn')

        # Read as columns ('ontoid', 'ysr_id'); always processed.
        org_entities_table = sdk2.parameters.String('Table with permalink <-> onto_id mapping',
                                                    default='//home/geosearch/entity_search_permalinks')

        # Read as columns ('ontoid', 'source_id'); skipped when left empty.
        toponym_entities_table = sdk2.parameters.String('Table with source_id <-> onto_id mapping',
                                                        default='')

        # Read as columns ('ontoid', 'uri'); skipped when left empty.
        toponym_uri_entities_table = sdk2.parameters.String('Table with uri <-> onto_id mapping',
                                                            default='')

        # Destination table for the builder output.
        snippets_table = sdk2.parameters.String('Table with prepared snippets',
                                                default='//home/geosearch/mesherin/entities_temp')
        # Executable invoked with --server/--input_file/--output_file.
        data_builder = sdk2.parameters.Resource('Entitysearh data builder',
                                                resource_type=geotypes.GEOSEARCH_ENTITY_DATA_PREPARER)

    class Requirements(sdk2.Task.Requirements):
        # YT client packages requested for the task environment.
        environments = (
            environments.PipEnvironment('yandex-yt', use_wheel=True),
            environments.PipEnvironment('yandex-yt-yson-bindings-skynet')
        )

    def on_execute(self):
        """Run the dump -> build -> upload pipeline.

        :raises SandboxTaskFailureError: if the data builder exits with a
            non-zero code, or if writing/merging the result table fails.
        """
        nanny_token = sdk2.Vault.data('GEOMETA-SEARCH', 'nanny_token')
        yt_token = sdk2.Vault.data('GEOMETA-SEARCH', 'yt-token')
        # Set in the process environment, so child processes started below
        # inherit them (presumably consumed by the data builder -- confirm).
        os.environ['MR_RUNTIME'] = 'YT'
        os.environ['YT_TOKEN'] = yt_token

        # Imported here rather than at module level; presumably the package is
        # only available once the PipEnvironment requirements are prepared.
        import yt.wrapper as yt
        yt_client = yt.YtClient(self.Parameters.cluster, yt_token)

        # All mapping tables are appended to one TSV file for the builder.
        input_file = os.path.abspath('input.tsv')
        read_to_tsv(yt_client, self.Parameters.org_entities_table, ['ontoid', 'ysr_id'], '', None, input_file)
        # Both toponym tables default to '' and are optional: reading an empty
        # table path can only fail, so each is skipped when not configured.
        if self.Parameters.toponym_entities_table:
            read_to_tsv(yt_client, self.Parameters.toponym_entities_table, ['ontoid', 'source_id'], 'geocoder_id_', None, input_file)
        if self.Parameters.toponym_uri_entities_table:
            read_to_tsv(yt_client, self.Parameters.toponym_uri_entities_table, ['ontoid', 'uri'], 'geocoder_uri_', calc_hash, input_file)

        output_file = os.path.abspath('saas_dump.tsv')

        # The builder runs only while the provider context is active.
        with EntitySearchProvider(
            fresh_id=get_prod_res_id(es_resource_types.ENTITY_SEARCH_FRESH, nanny_token),
            ner_id=get_prod_res_id(es_resource_types.ENTITY_SEARCH_NER_DATA, nanny_token),
            data_id=get_prod_res_id(es_resource_types.ENTITY_SEARCH_DATA, nanny_token),
            binary_id=get_prod_res_id(es_resource_types.ENTITY_SEARCH_EXECUTABLE, nanny_token),
            config_id=get_prod_res_id(es_resource_types.ENTITY_SEARCH_CONFIG, nanny_token),
            parent_nanny_service=PARENT_SERVICE,
            start_timeout=60 * 60,
            use_mmap=True,
        ):
            data_builder_resource = str(sdk2.ResourceData(self.Parameters.data_builder).path)
            cmd = [
                data_builder_resource,
                # NOTE(review): presumably the address the provider listens on -- confirm.
                '--server', 'localhost:8895',
                '--input_file', input_file,
                '--output_file', output_file
            ]
            try:
                # Builder stdout and stderr both go to a file in the task logs folder.
                log_path = os.path.join(get_logs_folder(), 'prepare_entity.log')
                with open(log_path, 'w') as logfile:
                    subprocess.check_call(
                        cmd,
                        stdout=logfile,
                        stderr=subprocess.STDOUT
                    )
            except subprocess.CalledProcessError:
                logging.exception('Command "{}" failed.'.format(' '.join(cmd)))
                raise SandboxTaskFailureError('Building snippet data failed')
        try:
            # The dump is streamed line by line; the ``with`` block guarantees
            # the file is closed even when the upload fails.
            with open(output_file, 'r') as dump:
                yt_client.write_table(self.Parameters.snippets_table,
                                      (line for line in dump),
                                      format="dsv",
                                      raw=True)
            # Merge the table into itself with combine_chunks enabled.
            yt_client.run_merge(self.Parameters.snippets_table, self.Parameters.snippets_table, spec={'combine_chunks': True})
        except Exception:
            logging.exception('Failed to write data to YT table {}'.format(self.Parameters.snippets_table))
            # Worded differently from the builder failure above so the two
            # failure causes can be told apart from the task error alone.
            raise SandboxTaskFailureError('Writing snippet data to YT failed')
