# -*- coding: utf-8 -*-
import logging
import os
import subprocess
import tempfile

from sandbox import common
from sandbox import sdk2
import sandbox.common.types.resource as ctr
from sandbox.projects.images.tags import resources as tags_resources
from sandbox.sandboxsdk.copy import RemoteCopy


class ImagesTagsReleaseAliceBans(sdk2.Task):
    """
        Compiles GZT and regex (PIRE) bans for Alice that are applied by the ImgTagsBan Begemot rule.
        By default the ban sources are listed in arcadia/search/wizard/data/fresh/ImgTagsBan/sources.tsv
    """
    class Requirements(sdk2.Task.Requirements):
        disk_space = 1 * 1024  # 1 Gb
        ram = 1 * 1024  # 1 Gb
        cores = 1

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Task.Parameters):
        kill_timeout = 10600
        sourcesPath = sdk2.parameters.SvnUrl("Path to tags models&data on svn",
                                              required=True,
                                              default_value='svn+ssh://arcadia.yandex.ru/arc/trunk/arcadia/search/wizard/data/fresh/ImgTagsBan/sources.tsv')
        bansCompilerResource = sdk2.parameters.LastReleasedResource('Bans compiler resource',
                                             resource_type=tags_resources.IMAGES_TAGS_BAN_COMPILER_EXECUTABLE,
                                             state=(ctr.State.READY),
                                             required=True)

    @staticmethod
    def _CutColumnToTmpFile(filename, column):
        """
            Extract one column of a tab-separated file into a new temporary file.

            :param filename: path of the tab-separated source file.
            :param column: zero-based index of the column to keep.
            :return: path of the temporary file, holding one extracted value per line.
                     The file is created with delete=False, so the caller owns it.

            Lines that have too few columns are logged and skipped.
        """
        # Open the source first, so no temp file is created when it is missing.
        # Closing the temp file here guarantees its contents are flushed before an
        # external process (the ban compiler) reads it.
        with open(filename, 'rb') as f, tempfile.NamedTemporaryFile(delete=False) as tmpFile:
            for lineNumber, line in enumerate(f, 1):
                fields = line.rstrip().split('\t')

                if column >= len(fields):
                    # column is zero-based, so the line needs column + 1 fields.
                    logging.info('Skip line {} of {} as it has less than {} columns.'.format(lineNumber, filename, column + 1))
                    continue
                tmpFile.write('{}\n'.format(fields[column]))
        return tmpFile.name

    @staticmethod
    def _GetFilesForBanType(sources, banType):
        """
            In the blacklist files the words are located in different columns: sometimes in the first,
            sometimes in the second. For every source of the given ban type, pick the configured column
            from the file downloaded by the task and put the result into a temporary file.

            :param sources: rows [ban_id, local_path, ban_type, zero_based_column], as returned by
                            _PrepareSourcesList.
            :param banType: ban type to select (on_execute passes 'GZT' or 'PIRE').
            :return: tuple (banSources, whitelistPath):
                     banSources - list of 'tmp_path:ban_id' strings for sources whose id does not
                                  start with 'WHITELIST';
                     whitelistPath - path of one temporary file that concatenates the extracted
                                     column of every 'WHITELIST*' source of this type (it may be empty).
        """
        typedSources = [src for src in sources if src[2] == banType]

        banSources = []
        for srcId, srcFileName, _, srcColToUse in typedSources:
            if srcId.startswith('WHITELIST'):
                continue
            tmpFileName = ImagesTagsReleaseAliceBans._CutColumnToTmpFile(srcFileName, srcColToUse)
            banSources.append('{}:{}'.format(tmpFileName, srcId))

        with tempfile.NamedTemporaryFile(mode='ab', delete=False) as whitelistFile:
            for srcId, srcFileName, _, srcColToUse in typedSources:
                if not srcId.startswith('WHITELIST'):
                    continue
                tmpFileName = ImagesTagsReleaseAliceBans._CutColumnToTmpFile(srcFileName, srcColToUse)
                with open(tmpFileName, 'rb') as columnFile:
                    whitelistFile.write(columnFile.read())
                whitelistFile.write('\n')  # just in case
                # The per-source column file has been merged into whitelistFile and is not needed anymore.
                os.remove(tmpFileName)
        return banSources, whitelistFile.name

    def _GetBanCompilerExecutable(self):
        """Return the local filesystem path of the ban compiler resource selected in the parameters."""
        return str(sdk2.ResourceData(self.Parameters.bansCompilerResource).path)

    def _CompileBanResource(self, sources, whitelist, banResource, bansCompilerExecutableMode):
        """
            Run the ban compiler and mark the output resource as ready.

            :param sources: list of 'path:ban_id' strings, each passed after the mode-specific input key.
            :param whitelist: path of the whitelist file, passed as --whitelist.
            :param banResource: output resource; the compiler writes to its data path.
            :param bansCompilerExecutableMode: 'CompileBanGzt' or 'CompileBanPire'.
            :raises common.errors.TaskFailure: on an unknown mode, when the compiler cannot be
                    started, or when it exits with a non-zero status.
        """
        outputKeysDict = {'CompileBanGzt': '--gzt', 'CompileBanPire': '--pire'}
        inputKeysDict = {'CompileBanGzt': '--ban', 'CompileBanPire': '--regex'}

        resultBanKey = outputKeysDict.get(bansCompilerExecutableMode, '')
        if not resultBanKey:
            message = 'bansCompilerExecutableMode is equal to {}, but should be one of: {}'.format(bansCompilerExecutableMode, outputKeysDict.keys())
            logging.error(message)
            raise common.errors.TaskFailure(message)
        inputBanKey = inputKeysDict[bansCompilerExecutableMode]

        bansCompilerExecutable = self._GetBanCompilerExecutable()

        banResourceData = sdk2.ResourceData(banResource)
        banResourceData.path.parents[0].mkdir(parents=True, exist_ok=True)

        # Build an argument vector instead of a shell string. Ban ids and paths come from
        # sources.tsv, so spaces or shell metacharacters in them must not be reinterpreted.
        commandArgs = [bansCompilerExecutable, bansCompilerExecutableMode]
        for source in sources:
            commandArgs.extend([inputBanKey, source])
        commandArgs.extend(['--whitelist', whitelist])
        commandArgs.extend([resultBanKey, str(banResourceData.path)])
        commandCall = ' '.join(commandArgs)  # for logging only

        logging.info('Running ban compiler: {}'.format(commandCall))
        try:
            returnCode = subprocess.call(commandArgs)
        except OSError as e:
            logging.error('Ban compiler could not be started ({}): {}'.format(e, commandCall))
            raise common.errors.TaskFailure("Error when compiling bans.")
        if returnCode:
            logging.error('Ban compiler failed: {}'.format(commandCall))
            raise common.errors.TaskFailure("Error when compiling bans.")
        banResourceData.ready()

    def _PrepareSourcesList(self):
        """
            Download the sources list and every ban file it references.

            The sources list (self.Parameters.sourcesPath) is a tab-separated file. Each non-empty line
            that does not start with '#' has 4 columns: ban_id svn_path ban_type tsv_column, where
            tsv_column is one-based.

            :return: list of [ban_id, local_path, ban_type, zero_based_column] rows.
            :raises common.errors.TaskFailure: when a line does not have exactly 4 columns or its
                    tsv_column is not a positive integer.
        """
        DATA_DIR = 'ImgTagsBan'
        SOURCES_LIST = os.path.join(DATA_DIR, 'sources.tsv')
        if not os.path.exists(DATA_DIR):
            os.makedirs(DATA_DIR)

        # Copy ban sources file and parse
        RemoteCopy(self.Parameters.sourcesPath, SOURCES_LIST)()
        with open(SOURCES_LIST, 'rb') as fin:
            # A real list (not a lazy filter/map), because it is iterated twice below.
            sources_remote = [line.rstrip().split('\t') for line in fin
                              if line.strip() and not line.startswith('#')]

        # Validate everything before downloading anything.
        for s in sources_remote:
            if len(s) != 4:
                raise common.errors.TaskFailure('Each non-empty line of {} have to contain 4 columns: ban_id svn_path ban_type tsv_column'.format(SOURCES_LIST))
            try:
                column = int(s[3])
            except ValueError:
                column = 0
            # A column of 0 would become index -1 and silently select the last column.
            if column < 1:
                raise common.errors.TaskFailure('tsv_column in {} must be a positive one-based integer, got {!r}'.format(SOURCES_LIST, s[3]))

        # Download bans
        sources_local = []
        for index, (srcId, srcPath, banType, srcColumn) in enumerate(sources_remote):
            # Prefix with the row index so that sources sharing a file name
            # (e.g. .../a/bans.txt and .../b/bans.txt) do not overwrite each other.
            local_path = os.path.join(DATA_DIR, '{}_{}'.format(index, os.path.basename(srcPath)))
            logging.info("Download source: {} to {}".format([srcId, srcPath, banType, srcColumn], local_path))
            RemoteCopy(srcPath, local_path)()
            sources_local.append([srcId, local_path, banType, int(srcColumn) - 1])  # one-based indexing in input file was used
        return sources_local

    def on_execute(self):
        """Download the ban sources and build the GZT and PIRE ban resources."""
        sources = self._PrepareSourcesList()

        # build GZT bans
        gztSources, gztWhitelist = self._GetFilesForBanType(sources, 'GZT')
        gztBanResource = tags_resources.IMAGES_TAGS_BAN_GZT(self, 'Compiled GZT bans for tags.', 'banned-tags.gzt.bin')
        self._CompileBanResource(gztSources, gztWhitelist, gztBanResource, 'CompileBanGzt')

        # build PIRE bans
        pireSources, pireWhitelist = self._GetFilesForBanType(sources, 'PIRE')
        pireBanResource = tags_resources.IMAGES_TAGS_BAN_PIRE(self, 'Compiled PIRE bans for tags.', 'banned-tags.pire.bin')
        self._CompileBanResource(pireSources, pireWhitelist, pireBanResource, 'CompileBanPire')
