# -*- coding: utf-8 -*-

from .yt_subprocess_mapper import SubprocessMapper

import atexit
import hashlib
import os
import os.path
import shutil
import json
import subprocess
import logging

from tempfile import mkdtemp

import yt.wrapper as yt  # for @yt.aggregator decorator

from jinja2 import Environment, BaseLoader

from library.python import resource


class MapperWorkDir(object):
    """Temporary directory that removes itself when it is no longer needed.

    The directory is created with ``mkdtemp`` at construction time and its
    path is exposed as ``self.dir``. It is deleted when a ``with`` block is
    left, or by the ``atexit`` hook registered in ``__init__``.
    """

    def __init__(self, suffix='', prefix='mapper_work', dir=None):
        """Create the directory; all arguments are forwarded to ``mkdtemp``."""
        self.dir = mkdtemp(suffix=suffix, prefix=prefix, dir=dir)
        # Fallback for instances that are never used as context managers.
        atexit.register(MapperWorkDir._cleanup, self)

    def __enter__(self):
        """Return the work dir object itself for use in ``with ... as``."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Remove the directory; exceptions from the block are not suppressed."""
        self._cleanup()

    def _cleanup(self):
        """Delete the directory tree if it still exists (safe to call twice)."""
        # NOTE(review): presumably guards against module attributes having
        # been cleared when the hook runs at interpreter shutdown - confirm.
        helpers_available = os.path.exists and shutil.rmtree
        if not helpers_available:
            return

        target = self.dir
        if not os.path.exists(target):
            return
        shutil.rmtree(target)


@yt.aggregator
class PerlMapper(SubprocessMapper):
    """Subprocess mapper that runs user-supplied Perl code.

    At construction time a Perl script is rendered from the
    ``yt_perl_mapper_template`` resource, written into the work directory and
    put first in ``local_files``; the job command is ``perl <script name>``.
    """

    def __init__(self,
                 mapper=None, begin=None, end=None, top_level=None, constants=None, stash=None,
                 libs=None, env=None, work_dir=None, local_files=None, dst_names=None, archive_target_dir=None, default_target_dir=None, **kwargs):
        """Arguments:
            mapper      - string with perl code for main block with rows processing
                            in this block variable $r with input row is available
                            output with yield function: yield($output_row => $table_index);
            begin       - code for start block
            end         - code for finish block
            top_level   - top level code (subs)
            constants   - dict with global constants
            stash       - JSON-serializable value, passed to the template as 'stash_json' (optional)
            libs        - value passed to the template as 'libs' (optional, default: empty list)
            env         - dict with extra entries for the template's 'yt_env';
                            overrides values taken from YT_PERL_* environment variables
            work_dir    - object with a .dir attribute, the generated script is written there
                            (default: a new MapperWorkDir)
            local_files - extra local files, appended after the generated script
            dst_names   - aliases for output tables (optional)
            archive_target_dir - dict: archive file name -> directory to unpack it into
            default_target_dir - directory for archives missing from archive_target_dir;
                            when not set such archives are left packed
            dst_options - options for output tables (optional)
            ...         - other arguments for SubprocessMapper

        Code arguments may be byte strings (decoded as UTF-8) or unicode strings.
        If the BACKUP_GENERATED_PL environment variable is set, the generated
        script is also copied to that path ("1" means /tmp/latest_generated_pl).
        """
        super(PerlMapper, self).__init__('', **kwargs)

        self._begin_code = self._to_text(begin)
        self._mapper_code = self._to_text(mapper)
        self._end_code = self._to_text(end)
        self._top_level_code = self._to_text(top_level)
        self._constants = constants if constants is not None else {}
        self._stash = stash
        self._archive_target_dir = archive_target_dir if archive_target_dir is not None else {}
        self._default_target_dir = default_target_dir

        self._libs = libs if libs is not None else []
        self._env = env if env is not None else {}
        self._work_dir = work_dir if work_dir is not None else MapperWorkDir()
        self._dst_names = dst_names if dst_names is not None else [None for _ in self.dst_fields]

        self._backup_generated_pl = os.getenv("BACKUP_GENERATED_PL", False)
        if self._backup_generated_pl == "1":
            self._backup_generated_pl = "/tmp/latest_generated_pl"

        # The script is generated here, so every attribute read by
        # template_get_context() must already be set at this point.
        self.local_files = [self._perlRunFile()]
        if local_files is not None:
            self.local_files += local_files

        self.command_env = {
            'PERL5LIB': './lib',
            'LD_LIBRARY_PATH': './lib',
        }
        self.command = ['perl', os.path.basename(self._perlRunFile())]

    @staticmethod
    def _to_text(code):
        """Return code as a unicode string; byte strings are decoded as UTF-8.

        None and unicode strings are returned unchanged. Calling .decode() on
        a unicode string would raise for non-ASCII text on Python 2 and does
        not exist on Python 3.
        """
        if isinstance(code, bytes):
            return code.decode('utf-8')
        return code

    def init_command(self):
        """Unpack tar archives found in the current directory.

        Each ``.tar``, ``.tar.gz`` or ``.tgz`` file is extracted into the
        directory from ``archive_target_dir`` (keyed by file name) or into
        ``default_target_dir``; files with neither are skipped. Extracted
        archives are deleted, and the directory is rescanned until no more
        unpackable archives are found.

        Raises subprocess.CalledProcessError if tar fails.
        """
        # was in mr_perl
        while True:
            todo = []
            for fn in os.listdir('.'):
                if fn.endswith('.tar'):
                    gzipped = False
                elif fn.endswith(('.tar.gz', '.tgz')):
                    gzipped = True
                else:
                    continue

                if fn in self._archive_target_dir:
                    target_dir = self._archive_target_dir[fn]
                elif self._default_target_dir:
                    target_dir = self._default_target_dir
                else:
                    continue

                if not os.path.exists(target_dir):
                    # makedirs also handles nested targets such as 'a/b'
                    os.makedirs(target_dir)

                tar_opts = '-x' + ('z' if gzipped else '') + 'f'
                cmd = [
                    'tar', tar_opts,
                    fn,
                    '--exclude', '_mr_target_dir',  # historical
                    '-C', target_dir,
                ]
                todo.append((fn, cmd, os.path.getsize(fn)))

            if not todo:
                break

            # it is good to unpack large archives first, to save disk space
            todo.sort(key=lambda h: h[-1], reverse=True)

            for fn, cmd, _ in todo:
                logging.warning('unpack cmd: %s', cmd)
                subprocess.check_call(cmd)
                os.remove(fn)

    def _perlRunFile(self):
        """Generate the Perl script once and return its path.

        The file name contains the md5 of the rendered script. The path is
        cached in ``self._perl_run_file`` only after the file has been written.
        """
        if hasattr(self, '_perl_run_file'):
            return self._perl_run_file

        template_text = resource.find('yt_perl_mapper_template').decode('utf-8')
        template = Environment(loader=BaseLoader).from_string(template_text)
        perl_script = template.render(self.template_get_context()).encode('utf-8')

        script_name = 'generated_mapper.' + hashlib.md5(perl_script).hexdigest() + '.pl'
        run_file = os.path.join(self._work_dir.dir, script_name)
        # perl_script is UTF-8 encoded bytes, hence binary mode
        with open(run_file, 'wb') as fh:
            fh.write(perl_script)
        self._perl_run_file = run_file

        if self._backup_generated_pl:
            shutil.copy2(self._perl_run_file, self._backup_generated_pl)

        return self._perl_run_file

    def template_get_yt_env(self):
        """Return (name, value) pairs for the template's 'yt_env'.

        Values come from YT_PERL_* environment variables (prefix stripped)
        and are overridden by the ``env`` constructor argument.
        """
        yt_env_prefix = 'YT_PERL_'
        yt_env = {
            k[len(yt_env_prefix):]: v
            for k, v in os.environ.items()
            if k.startswith(yt_env_prefix)
        }
        yt_env.update(self._env)
        # a list works on Python 2 and 3 and can be iterated more than once
        return list(yt_env.items())

    def template_get_dst_names(self):
        """Return (table index, alias) pairs for output tables that have an alias."""
        return [(x, y) for x, y in enumerate(self._dst_names) if y is not None]

    def template_get_dst(self):
        """Return (index, fields, options) triples built from dst_fields and dst_options."""
        return list(zip(
            range(len(self.dst_fields)),
            self.dst_fields,
            self.dst_options,
        ))

    def template_get_stash_json(self):
        """Return the stash serialized to JSON, or '' when no stash was given."""
        return json.dumps(self._stash) if self._stash is not None else ''

    def template_get_context(self):
        """Build the dict of variables used to render the Perl script template."""
        return {
            'yt_env': self.template_get_yt_env(),
            'libs': self._libs,
            'dst_names': self.template_get_dst_names(),
            'none_value': self.NONE_VALUE,
            'dst': self.template_get_dst(),
            'begin_code': self._begin_code or '',
            'end_code': self._end_code or '',
            'mapper_code': self._mapper_code or '',
            'top_level_code': self._top_level_code or '',
            'constants': list(self._constants.items()),
            'stash_json': self.template_get_stash_json(),
        }
