import os
from os.path import join as pj
import subprocess
import json
import logging
import pipes

import sandbox.sandboxsdk.util as sdk_util
from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk.errors import SandboxTaskFailureError, SandboxSubprocessError
from sandbox.sandboxsdk.paths import make_folder, get_unique_file_name
from sandbox.sandboxsdk.process import run_process
from sandbox.sandboxsdk.sandboxapi import SandboxResource

from sandbox.projects.common.ProcessPool import ProcessPool
from sandbox.projects import resource_types


class BaseClient(object):
    """Builds and runs shell command lines for a MapReduce-style binary.

    Every command produced by this client starts with
    ``<mr_env_str> <mr_binary> -server <mr_server> ``.
    """

    def __init__(self, mr_binary, mr_server, mr_env_str='', log_dir=""):
        """Remember the binary/server pair and prepare the log sub-folder.

        When ``log_dir`` is non-empty, a ``mapreduce_client/`` folder is
        created inside it and used as the prefix for all log names.
        """
        self._mr_binary = mr_binary
        self._mr_server = mr_server
        self._env_str = mr_env_str
        self._log_dir = log_dir
        self._prefix = "{env} {binary} -server {server} ".format(
            env=mr_env_str,
            binary=mr_binary,
            server=mr_server
        )
        self._log_prefix = "mapreduce_client/" if log_dir else ""
        if log_dir:
            make_folder(pj(log_dir, self._log_prefix))

    def get_binary_path(self):
        """Return the binary path given to the constructor."""
        return self._mr_binary

    def get_env_str(self):
        """Return the environment string given to the constructor."""
        return self._env_str

    def get_log_prefix(self, p):
        """Return ``p`` prepended with this client's log prefix."""
        return self._log_prefix + p

    def command(self, cmd):
        """Return the full shell command line for the argument string ``cmd``."""
        return self._prefix + cmd

    def run(self, cmd, *args, **kwargs):
        """Run ``cmd`` through ``run_process`` with ``shell=True``.

        A ``log_prefix`` keyword, when supplied, is rewritten to live under
        this client's log prefix; all other arguments are passed through.
        """
        if "log_prefix" in kwargs:
            kwargs["log_prefix"] = self.get_log_prefix(kwargs["log_prefix"])
        full_cmd = self.command(cmd)
        return run_process(full_cmd, shell=True, *args, **kwargs)


class MapreduceClient(BaseClient):
    """Client for the ``mapreduce`` / ``mapreduce-yt`` command-line binary."""

    def __init__(self, mr_binary, mr_server, mr_env_str='', log_dir=""):
        BaseClient.__init__(self, mr_binary, mr_server, mr_env_str, log_dir)
        # The YT flavour is recognised solely by the binary's name.
        self._yt = "mapreduce-yt" in mr_binary

    def get_yt_cli(self):
        """Return a ``BaseClient`` for the ``yt2`` binary next to this one.

        The binary path has ``mapreduce-yt`` replaced with ``yt2`` and the
        command prefix uses ``--proxy`` in place of ``-server``.
        Must only be called on a YT client.
        """
        assert self.is_yt()
        yt_binary = self.get_binary_path().replace("mapreduce-yt", "yt2")
        cli = BaseClient(yt_binary, self._mr_server, self.get_env_str(), self._log_dir)
        cli._prefix = cli._prefix.replace("-server", "--proxy")
        return cli

    def is_yt(self):
        """Return True when the binary name contains ``mapreduce-yt``."""
        return self._yt

    def _list_tables(self, prefix, want_json, file_prefix=None):
        """Run ``-list -prefix`` for ``prefix``.

        Without ``file_prefix`` the command's stdout is captured and returned.
        With ``file_prefix`` the command is run with that log prefix and
        nothing is returned.
        """
        cmd = "-list -prefix '" + prefix + "'"
        if want_json:
            cmd += ' -jsonoutput'

        if file_prefix is not None:
            self.run(
                cmd,
                wait=True,
                log_prefix=file_prefix,
                check=True
            )
            return None

        proc = self.run(
            cmd,
            stdout=subprocess.PIPE,
            wait=False,
            log_prefix="list_tables",
            outputs_to_one_file=False,
            check=True
        )
        stdout_data, _ = proc.communicate()
        return stdout_data

    def save_tables_list(self, prefix, want_json, file_prefix):
        """Run the table listing with ``file_prefix`` as its log prefix."""
        self._list_tables(prefix, want_json, file_prefix=file_prefix)

    def is_strictly_sorted(self, table):
        """Return True when ``table``'s ``@key_columns`` contain ``value``.

        Always False for non-YT clients.
        """
        if not self.is_yt():
            return False
        proc = self.run(
            "-get {}/@key_columns".format(table),
            stdout=subprocess.PIPE,
            outputs_to_one_file=False,
            check=True,
            wait=False,
            log_prefix="is_strictly_sorted"
        )
        stdout_data, _ = proc.communicate()
        key_columns = json.loads(stdout_data)
        return "value" in key_columns

    def is_yamr_compatible_format(self, table):
        """Return True when every schema column is key, subkey or value.

        Always True for non-YT clients.
        """
        if not self.is_yt():
            return True

        proc = self.get_yt_cli().run(
            "get {} --attribute schema --format json".format(table),
            stdout=subprocess.PIPE,
            outputs_to_one_file=False,
            check=True,
            wait=False,
            log_prefix="is_yamr_compatible_format"
        )
        stdout_data, _ = proc.communicate()
        schema = json.loads(stdout_data)['$attributes']['schema']['$value']
        # Collect every name first so a malformed column fails the same way
        # regardless of its position in the schema.
        column_names = [column['name'] for column in schema]
        return all(name in ('key', 'subkey', 'value') for name in column_names)

    def get_tables_list(self, prefix, json_format=False):
        """List tables under ``prefix``.

        Returns the parsed JSON listing when ``json_format`` is true,
        otherwise the non-empty lines of the plain listing.
        """
        if json_format:
            return json.loads(self._list_tables(prefix, want_json=True))
        listing = self._list_tables(prefix, want_json=False)
        return filter(None, listing.splitlines())

    def copy_table(self, src, dst):
        """Copy table ``src`` to ``dst`` and wait for completion."""
        copy_cmd = '-copy -src {src} -dst {dst} '.format(src=src, dst=dst)
        self.run(
            copy_cmd,
            log_prefix='copy_table',
            wait=True,
            check=True
        )


class TableNamer(object):
    """Conversions between table names and archive file names.

    A table ``a/b`` is stored in a file named
    ``a:b[.sorted|.strictly_sorted].lenval.gz``.
    """

    # Longest suffix first, so the sort marker is removed together with the
    # extension.
    _FILE_SUFFIXES = (
        ".strictly_sorted.lenval.gz",
        ".sorted.lenval.gz",
        ".lenval.gz",
    )

    @staticmethod
    def reprefix(name, prefix=None, new_prefix=None):
        """Optionally swap the leading ``prefix`` of ``name`` for ``new_prefix``.

        * ``prefix`` given: ``name`` must start with it (asserted).
        * ``new_prefix`` given: it is prepended, after removing ``prefix``
          when that was given too.
        * ``new_prefix`` is None: ``name`` is returned unchanged.
        """
        if prefix is not None:
            assert name.startswith(prefix)
            if new_prefix is not None:
                name = name[len(prefix):]
        if new_prefix is not None:
            name = new_prefix + name
        return name

    @staticmethod
    def file_to_table(name, prefix="", new_prefix=""):
        """Convert an archive file name back into a table name.

        Only the trailing ``[.sorted|.strictly_sorted].lenval.gz`` suffix is
        removed, so a table whose own path contains such a marker keeps it.
        ``:`` is then turned back into ``/`` and the result is re-prefixed
        via :meth:`reprefix`.
        """
        for suffix in TableNamer._FILE_SUFFIXES:
            if name.endswith(suffix):
                name = name[:-len(suffix)]
                break
        name = name.replace(":", "/")
        return TableNamer.reprefix(name, prefix, new_prefix)

    @staticmethod
    def get_sort_mode(filename):
        """Return the ``SortMode`` encoded in the suffix of ``filename``.

        The name must end with ``.lenval.gz`` (asserted).
        """
        if filename.endswith(".sorted.lenval.gz"):
            return SortMode.STANDART
        if filename.endswith(".strictly_sorted.lenval.gz"):
            return SortMode.STRICT
        assert filename.endswith(".lenval.gz")
        return SortMode.NONE

    @staticmethod
    def table_to_file(name, sort_mode, prefix=None, new_prefix=None):
        """Convert a table name into its archive file name.

        The name is re-prefixed, ``/`` becomes ``:``, and a suffix encoding
        ``sort_mode`` plus ``.lenval.gz`` is appended.
        """
        name = TableNamer.reprefix(name, prefix, new_prefix)
        name = name.replace("/", ":")
        if sort_mode == SortMode.STANDART:
            name += ".sorted"
        elif sort_mode == SortMode.STRICT:
            name += ".strictly_sorted"
        name += ".lenval.gz"
        return name


class SortMode:
    """Sort levels of a table; the values are ordered and compared numerically."""

    NONE = 0
    STANDART = 1  # key, subkey
    STRICT = 2  # key, subkey, value

    @staticmethod
    def get_testable(testable):
        """Return ``STRICT`` for a truthy ``testable``, else ``STANDART``."""
        if testable:
            return SortMode.STRICT
        return SortMode.STANDART


class MapreduceTablesIO(object):
    def __init__(self, client):
        """Bind to ``client`` and set default modes.

        Both sort modes start as ``SortMode.NONE``, safe mode is off, and the
        worker count is half of the reported CPU count (at least one).
        ``yt_cli`` is only created when the client is a YT client.
        """
        self.client = client
        self.yt_cli = client.get_yt_cli() if client.is_yt() else None
        self._sort_input_tables_mode = SortMode.NONE
        self._sort_output_tables_mode = SortMode.NONE
        self._safe_mode = False
        # Explicit floor division: the pool size must stay an integer even
        # where ``/`` means true division.
        self._nproc = max(1, sdk_util.system_info()['ncpu'] // 2)

    def _get_resource(self, resource_or_id):
        """Return a resource object for either a resource or its id.

        A ``SandboxResource`` is returned as is; anything else is converted
        with ``int`` and fetched. Raises ``SandboxTaskFailureError`` when the
        fetch yields nothing.
        """
        if not isinstance(resource_or_id, SandboxResource):
            res_id = int(resource_or_id)
            fetched = channel.sandbox.get_resource(res_id)
            if not fetched:
                raise SandboxTaskFailureError("Error fetching resource {}".format(res_id))
            return fetched
        return resource_or_id

    def set_safe_mode(self):
        """Turn on safe mode: ``call_safe`` then returns None instead of re-raising."""
        self._safe_mode = True

    def set_sort_input_tables_mode(self, mode):
        """Set the sort level enforced on uploaded tables.

        For non-YT clients the level is capped at ``SortMode.STANDART``.
        """
        if self.client.is_yt():
            self._sort_input_tables_mode = mode
        else:
            self._sort_input_tables_mode = min(mode, SortMode.STANDART)

    def set_sort_output_tables_mode(self, mode):
        """Set the sort level enforced on downloaded tables.

        For non-YT clients the level is capped at ``SortMode.STANDART``.
        """
        if self.client.is_yt():
            self._sort_output_tables_mode = mode
        else:
            self._sort_output_tables_mode = min(mode, SortMode.STANDART)

    def call_safe(self, fn, *args, **kwargs):
        """Call ``fn(*args, **kwargs)`` with retries and optional error swallowing.

        The ``retries`` keyword (default 1) is the total number of attempts;
        it is removed from ``kwargs`` before ``fn`` is called.

        Returns the result of the first successful attempt. When every
        attempt fails, the last exception is re-raised unless safe mode is
        on, in which case None is returned.

        Only ``Exception`` subclasses are handled, so ``KeyboardInterrupt``
        and ``SystemExit`` always propagate immediately.
        """
        retries = kwargs.pop("retries", 1)
        fn_name = getattr(fn, "__name__", repr(fn))

        while True:
            retries -= 1
            try:
                return fn(*args, **kwargs)
            except Exception:
                if retries > 0:
                    logging.warning(
                        "%s failed, %d attempt(s) left; retrying",
                        fn_name, retries, exc_info=True
                    )
                    continue
                if not self._safe_mode:
                    raise
                # Safe mode is a deliberate best effort: keep going, but
                # leave a trace of what was ignored.
                logging.exception("%s failed; ignored because safe mode is on", fn_name)
                return None

    def _download(self, infos, dir_name, prefix, new_prefix, table_to_file=None):
        """Dump the tables described by ``infos`` into gzipped files under ``dir_name``.

        ``infos`` is an iterable of dicts; ``info["name"]`` is the table name
        and the optional ``info["sorted"]`` flag marks an already sorted table.
        Tables that are not in key/subkey/value format are skipped. A table
        whose sort level is below the configured output sort mode is sorted
        before it is read. File names come from ``TableNamer.table_to_file``
        using ``prefix`` / ``new_prefix``.

        When ``table_to_file`` is given it is filled with
        ``table name -> file name`` (the name relative to ``dir_name``).

        Raises ``SandboxTaskFailureError`` when the format check for a table
        yields None.
        """
        read_cmd = self.client.command("-lenval -subkey -read ")
        sort_cmd = self.client.command("-sort ")
        cmds = []
        for info in infos:
            table = info["name"]

            # call_safe yields None only when the check itself failed and safe
            # mode swallowed the error; that is treated as fatal here.
            ret = self.call_safe(self.client.is_yamr_compatible_format, table)
            if ret is None:
                raise SandboxTaskFailureError(
                    'is_yamr_compatible_format for table {} failed'.format(table)
                )
            if not ret:
                logging.info(
                    "Table {} is not in YAMR-compatible format sorted, skip it".format(table)
                )
                continue

            # Work out the table's current sort level; a failed strictness
            # check in safe mode (None) leaves it at STANDART.
            cur_cmd = ""
            sort_mode = SortMode.NONE
            if info.get("sorted"):
                sort_mode = SortMode.STANDART
                if self.call_safe(self.client.is_strictly_sorted, table):
                    sort_mode = SortMode.STRICT
            # Prepend a sort command when the table is below the wanted level.
            if sort_mode < self._sort_output_tables_mode:
                logging.info("Table {} is not sorted, going to fix it".format(table))
                cur_cmd = sort_cmd + table
                if self._sort_output_tables_mode == SortMode.STRICT:
                    cur_cmd += " -sortby key -sortby subkey -sortby value "
                cur_cmd += " && "
                sort_mode = self._sort_output_tables_mode
            fname = TableNamer.table_to_file(table, sort_mode, prefix, new_prefix)

            if table_to_file is not None:
                table_to_file[table] = fname
            fname = pj(dir_name, fname)
            # The read | gzip pipeline is handed to ``bash -e -o pipefail`` so
            # that a failing read is not masked by gzip's exit status.
            cur_cmd += "echo {} | bash -e -o pipefail".format(
                pipes.quote(read_cmd + table + ' | gzip > ' + fname)
            )
            cmds.append(cur_cmd)

        # Run all commands through the pool; each one gets two attempts.
        p = ProcessPool(self._nproc)
        p.map(
            lambda cmd: self.call_safe(
                run_process,
                cmd, shell=True, wait=True, check=True,
                log_prefix=self.client.get_log_prefix('download.read_sort'),
                retries=2
            ),
            cmds
        )

    def download_tables(self, task, descr, prefix, new_prefix=None, attrs=None,
                        skip_unchanged=None, resource_id=None,
                        table_to_file=None):
        """Download all tables under ``prefix`` into a tables-archive resource.

        :param task: task object used for paths and resource operations.
        :param descr: description for a newly created resource.
        :param prefix: table name prefix to list and download.
        :param new_prefix: prefix to store the tables under; when None the
            original names are kept.
        :param attrs: optional dict of extra resource attributes; it is
            copied, never modified. ``tables_prefix`` is always set to
            ``new_prefix or prefix``.
        :param skip_unchanged: optional iterable of table info dicts; a table
            whose ``records``, ``creat_time`` and ``mod_time`` all match its
            entry here is not downloaded.
        :param resource_id: optional id of an existing resource to reuse; it
            is reused only when its type is ``USERDATA_TABLES_ARCHIVE``.
        :param table_to_file: optional dict filled with table -> file name.
        :return: the resource, marked ready.
        """
        logging.info("saving tables with prefix {}".format(prefix))
        to_skip = {t["name"]: t for t in skip_unchanged or []}

        infos = []
        for info in sorted(self.client.get_tables_list(prefix, json_format=True)):
            table = info["name"]
            if table in to_skip:
                old = to_skip[table]
                if all(old[f] == info[f] for f in ['records', 'creat_time', 'mod_time']):
                    logging.info("Skipping {}: it's unchanged from input".format(table))
                    continue
                logging.info("Table '{}' modified, was {}, now {}".format(table, repr(old), repr(info)))
            infos.append(info)

        dir_name = get_unique_file_name(task.abs_path(), 'mapreduce_tables')
        make_folder(dir_name)
        self._download(infos, dir_name, prefix, new_prefix, table_to_file)

        resource = None
        if resource_id:
            resource = channel.sandbox.get_resource(resource_id)
        if not (resource and resource.type == resource_types.USERDATA_TABLES_ARCHIVE):
            resource = task.create_resource(
                descr,
                dir_name,
                resource_types.USERDATA_TABLES_ARCHIVE,
                arch='any',
            )
        # ``attrs`` defaults to None, so it cannot be copied unconditionally.
        attrs = attrs.copy() if attrs is not None else {}
        attrs["tables_prefix"] = new_prefix or prefix
        for n, v in attrs.items():
            channel.sandbox.set_resource_attribute(resource.id, n, v)
        task.change_resource_basename(resource.id, dir_name)
        task.mark_resource_ready(resource.id)
        return resource

    def get_resource_table_descrs(self, task, resource_or_id, new_prefix=None):
        """Describe every ``*.lenval.gz`` file of a synced tables resource.

        Returns a list of dicts, ordered by file name, with the keys
        ``filename``, ``fullpath``, ``sort_mode``, ``tablename`` and
        ``shortname``. ``sort_mode`` is capped at ``STANDART`` for non-YT
        clients.
        """
        resource = self._get_resource(resource_or_id)
        dirname = task.sync_resource(resource.id)
        orig_prefix = resource.attributes['tables_prefix']
        archive_files = [f for f in os.listdir(dirname) if f.endswith(".lenval.gz")]  # XXX
        if self.client.is_yt():
            max_sort_mode = SortMode.STRICT
        else:
            max_sort_mode = SortMode.STANDART

        descrs = []
        for fname in sorted(archive_files):
            descr = {
                'filename': fname,
                'fullpath': pj(dirname, fname),
                'sort_mode': min(TableNamer.get_sort_mode(fname), max_sort_mode),
                'tablename': TableNamer.file_to_table(fname, orig_prefix, new_prefix),
                # NOTE(review): with new_prefix=None the original prefix is
                # kept, so this is not shortened - confirm that is intended.
                'shortname': TableNamer.file_to_table(fname, orig_prefix, None),
            }
            descrs.append(descr)
        return descrs

    def upload_table(self, table_descr):
        """Upload one archived table, then sort it if the input sort mode requires.

        ``table_descr`` is a mapping with at least ``sort_mode``,
        ``tablename`` and ``fullpath``; it is copied, not modified.

        A ``STRICT`` table is created and written through the yt CLI with
        ``sorted_by=[key;subkey;value]``; any other table is written through
        the mapreduce client with ``-writesorted`` or ``-write``. The write
        is attempted once for YT clients and twice otherwise.

        Raises ``SandboxSubprocessError`` when the last write attempt fails.
        """
        table_descr = dict(table_descr)
        cmd_before = ""
        cmd_after = ""
        if table_descr["sort_mode"] == SortMode.STRICT:
            cmd_before = self.yt_cli.command(
                "create table //{tablename} -r -i".format(**table_descr)
            ) + " && "
            cmd_after = self.yt_cli.command(
                "write '<sorted_by=[key;subkey;value]>//{tablename}' "
                "--format '<lenval=true;has_subkey=true>yamr' "
                "--table-writer '{{max_key_weight=220000}}'".format(**table_descr)
            )
        else:
            table_descr["writecmd"] = '-writesorted' if table_descr["sort_mode"] == SortMode.STANDART else '-write'
            table_descr["timeout"] = '-timeout 3600' if self.client.is_yt() else ''
            cmd_after = self.client.command(' {timeout} -lenval -subkey {writecmd} {tablename}'.format(**table_descr))

        tries = 1 if self.client.is_yt() else 2
        write_cmd = cmd_before + ' zcat ' + table_descr["fullpath"] + " | " + cmd_after
        for attempts_left in reversed(range(tries)):
            try:
                run_process(
                    write_cmd,
                    shell=True, check=True, wait=True,
                    log_prefix=self.client.get_log_prefix("upload_table.write")
                )
                break
            except SandboxSubprocessError:
                if attempts_left == 0:
                    raise
        if table_descr["sort_mode"] < self._sort_input_tables_mode:
            table_descr["sortby"] = ""
            # The sort columns must follow the same (input) mode that
            # triggered the sort, not the mode used for downloaded tables.
            if self._sort_input_tables_mode == SortMode.STRICT:
                table_descr["sortby"] = " -sortby key -sortby subkey -sortby value "

            self.client.run(
                '-sort -src {tablename} -dst {tablename} {sortby}'.format(**table_descr),
                check=True, wait=True, log_prefix='upload_one.sort'
            )

    def upload_tables(self, task, resource_or_id, new_prefix):
        """Upload every table of a tables resource under ``new_prefix``.

        The uploads are spread over a ``ProcessPool``; the resolved resource
        is returned.
        """
        resource = self._get_resource(resource_or_id)
        table_descrs = self.get_resource_table_descrs(task, resource, new_prefix)
        pool = ProcessPool(self._nproc)
        pool.map(lambda descr: self.upload_table(descr), table_descrs)
        return resource


# XXX: deprecated
class MapreduceStoredTables(object):
    # XXX: args
    # XXX: pass log_dir
    def __init__(self, mr_binary=None, mr_server=None, mr_env_str='', mr_filters='', client=None, log_dir=None):
        """Use ``client`` when given, otherwise build one from binary and server.

        Without ``client`` both ``mr_binary`` and ``mr_server`` are required
        (asserted). When ``log_dir`` is set, a ``mr_tables_ops/`` folder is
        created inside it and used as the log prefix.
        """
        if not client:
            assert mr_binary and mr_server
            client = MapreduceClient(mr_binary, mr_server, mr_env_str or "")
        self.client = client
        self.mr_filters = mr_filters or ''

        self.log_prefix = "mr_tables_ops/" if log_dir else ""
        if log_dir:
            make_folder(pj(log_dir, self.log_prefix))

        self._strictly_sorted_yt_tables = False

    def set_strictly_sorted_yt_tables(self):
        """Enable strict sorting of uploaded tables; only takes effect for a YT client."""
        self._strictly_sorted_yt_tables = self.client.is_yt()

    def _list_tables(self, prefix, want_json, file_prefix=None):
        """Run ``-list -prefix`` for ``prefix`` through the client.

        Without ``file_prefix`` the command's stdout is captured and
        returned. With ``file_prefix`` the command is run with that log
        prefix and None is returned.
        """
        capture_output = file_prefix is None
        if capture_output:
            run_kwargs = {
                'stdout': subprocess.PIPE,
                'wait': False,
                'log_prefix': self.log_prefix + "list_tables",
                'outputs_to_one_file': False,
                'check': True,
            }
        else:
            run_kwargs = {
                'wait': True,
                'log_prefix': file_prefix,
                'check': True,
            }

        list_cmd = "-list -prefix '" + prefix + "'"
        if want_json:
            list_cmd += ' -jsonoutput'

        proc = self.client.run(list_cmd, **run_kwargs)
        if not capture_output:
            return None
        tlist, _ = proc.communicate()
        return tlist

    def save_tables_list(self, prefix, want_json, file_prefix):
        """Run the table listing for ``prefix`` with ``file_prefix`` as its log prefix."""
        self._list_tables(prefix, want_json, file_prefix=file_prefix)

    def list_tables(self, prefix):
        """Return the non-empty lines of the plain table listing for ``prefix``."""
        listing = self._list_tables(prefix, False)
        return filter(None, listing.splitlines())

    def list_tables_long(self, prefix):
        """Return the parsed JSON table listing for ``prefix``."""
        return json.loads(self._list_tables(prefix, True))

    def copy_table(self, src, dst):
        """Copy table ``src`` to ``dst`` through the client and wait for completion."""
        copy_cmd = '-copy -src {src} -dst {dst} '.format(src=src, dst=dst)
        copy_log = self.log_prefix + 'copy_table'
        self.client.run(copy_cmd, log_prefix=copy_log, wait=True, check=True)

    def mangle_tname(self, tname, prefix, cut_prefix=False, add_prefix=None):
        """Optionally drop ``len(prefix)`` leading characters and prepend ``add_prefix``.

        The cut is purely by length: ``tname`` is not checked to actually
        start with ``prefix``.
        """
        base = tname[len(prefix):] if cut_prefix else tname
        return add_prefix + base if add_prefix else base

    def make_process_pool(self, nproc):
        """Return a ``ProcessPool`` with ``nproc`` workers.

        A falsy ``nproc`` means half of the reported CPU count, at least one.
        """
        if not nproc:
            # Explicit floor division: the pool size must stay an integer
            # even where ``/`` means true division.
            nproc = max(1, sdk_util.system_info()['ncpu'] // 2)
        return ProcessPool(nproc)

    def download(self, dirname, prefix, cut_prefix=False,
                 add_prefix=None, skip_unchanged=None, flt=None, sort=False,
                 table_to_file=None):
        """Dump all tables under ``prefix`` into gzipped files in ``dirname``.

        ``skip_unchanged`` is an optional iterable of table info dicts; a
        table whose ``records``, ``creat_time`` and ``mod_time`` all match its
        entry there is skipped. ``flt``, when callable, is called with the
        table name and a falsy result skips the table. With ``sort`` set,
        tables not flagged as sorted are sorted before being read.

        File names are the mangled table name (see ``mangle_tname``) with
        ``/`` replaced by ``:`` plus ``[.sorted].lenval.gz``. When
        ``table_to_file`` is given it is filled with
        ``table name -> file name`` (the name relative to ``dirname``).
        """
        make_folder(dirname)
        read_cmd = self.client.command("{filters} -lenval -subkey -read ".format(filters=self.mr_filters))
        sort_cmd = self.client.command("-sort ")
        # Index the previously seen table infos by name.
        to_skip = {}
        if skip_unchanged:
            for t in skip_unchanged:
                to_skip[t["name"]] = t
        cmds = []
        for tdata in sorted(self.list_tables_long(prefix)):
            table = tdata["name"]
            if table in to_skip:
                old = to_skip[table]
                if all(old[f] == tdata[f] for f in ['records', 'creat_time', 'mod_time']):
                    logging.info("Skipping %s: it's unchanged from input" % table)
                    continue
                else:
                    logging.info("Table '%s' modified, was %s, now %s" % (table, repr(old), repr(tdata)))
            if callable(flt) and not flt(table):
                continue

            cur_cmd = ""
            if sort and not tdata.get("sorted"):
                logging.info("Table {} is not sorted, going to fix it".format(table))
                cur_cmd = sort_cmd + table + " && "
                # Mark as sorted so the file name below gets the .sorted suffix.
                tdata["sorted"] = True
            issrt = '.sorted' if tdata.get("sorted") else ''
            fname = self.mangle_tname(table, prefix, cut_prefix, add_prefix)
            fname = fname.replace('/', ':') + issrt + '.lenval.gz'
            if table_to_file is not None:
                table_to_file[table] = fname
            fname = pj(dirname, fname)
            # The read | gzip pipeline is handed to ``bash -e -o pipefail`` so
            # that a failing read is not masked by gzip's exit status.
            cur_cmd += "echo {} | bash -e -o pipefail".format(
                pipes.quote(read_cmd + table + ' | gzip > ' + fname)
            )
            cmds.append(cur_cmd)

        # Run all collected commands through a default-sized process pool.
        p = self.make_process_pool(None)
        p.map(lambda cmd: run_process(
            cmd, log_prefix=self.log_prefix + 'download.read', shell=True, wait=True, check=True), cmds
        )

    def file2tablename(self, fname, prefix='', cut_prefix=False, add_prefix=None):
        """Convert an archive file name into a table name.

        Removes ``.sorted.lenval.gz`` and ``.lenval.gz``, turns ``:`` into
        ``/``, then applies ``mangle_tname``.
        """
        tname = fname
        for token, replacement in (('.sorted.lenval.gz', ''), ('.lenval.gz', ''), (':', '/')):
            tname = tname.replace(token, replacement)
        return self.mangle_tname(tname, prefix, cut_prefix, add_prefix)

    def tables2resource(self, task, descr, prefix, cut_prefix=False, add_prefix=None,
                        attrs=None, skip_unchanged=None, flt=None, resource_id=None,
                        sort_tables=False, table_to_file=None):
        """Download all tables under ``prefix`` into a tables-archive resource.

        ``cut_prefix``, ``add_prefix``, ``skip_unchanged``, ``flt``,
        ``sort_tables`` and ``table_to_file`` are passed on to ``download``.
        ``attrs`` is an optional dict of extra resource attributes; it is
        only read. ``tables_prefix`` is always set to ``prefix``, after the
        caller's attributes. An existing resource given by ``resource_id`` is
        reused only when its type is ``USERDATA_TABLES_ARCHIVE``.

        Returns the resource, marked ready.
        """
        logging.info("saving tables with prefix {} (cut_prefix={}, add_prefix={})".format(prefix, cut_prefix, add_prefix))
        dirname = get_unique_file_name(task.abs_path(), 'mapreduce_tables')
        make_folder(dirname)
        self.download(
            dirname, prefix,
            cut_prefix=cut_prefix,
            add_prefix=add_prefix,
            skip_unchanged=skip_unchanged,
            flt=flt,
            sort=sort_tables,
            table_to_file=table_to_file
        )
        resource = None
        if resource_id:
            resource = channel.sandbox.get_resource(resource_id)
        if not (resource and resource.type == resource_types.USERDATA_TABLES_ARCHIVE):
            resource = task.create_resource(
                descr,
                dirname,
                'USERDATA_TABLES_ARCHIVE',  # XXX
                arch='any',
            )
        # ``attrs`` defaults to None; treat that as "no extra attributes".
        for n, v in (attrs or {}).items():
            channel.sandbox.set_resource_attribute(resource.id, n, v)
        channel.sandbox.set_resource_attribute(resource.id, 'tables_prefix', prefix)
        task.change_resource_basename(resource.id, dirname)
        task.mark_resource_ready(resource.id)
        return resource

    def resid2resource(self, resource):
        """Resolve a resource id into a resource object.

        A value that ``int`` accepts and that is non-zero is fetched; a
        failed fetch raises ``SandboxTaskFailureError``. Anything else (a
        value ``int`` rejects with ``TypeError``, or a zero id) is returned
        unchanged.
        """
        try:
            resid = int(resource)
        except TypeError:
            resid = None

        if not resid:
            return resource

        fetched = channel.sandbox.get_resource(resid)
        if not fetched:
            raise SandboxTaskFailureError("Error fetching resource %d" % resid)
        return fetched

    def resource2tabledescs(self, task, resource, cut_prefix=False, add_prefix=None):
        """Describe every file of a synced tables resource.

        Returns a list of dicts, ordered by file name, with the keys
        ``filename``, ``fullpath``, ``sorted``, ``tablename`` and
        ``shortname`` (the table name with the original prefix cut).
        """
        resource = self.resid2resource(resource)
        dirname = task.sync_resource(resource.id)
        orig_prefix = resource.attributes['tables_prefix']

        descs = []
        for fname in sorted(os.listdir(dirname)):
            descs.append({
                'filename': fname,
                'fullpath': pj(dirname, fname),
                'sorted': fname.endswith('.sorted.lenval.gz'),
                'tablename': self.file2tablename(fname, orig_prefix, cut_prefix, add_prefix),
                'shortname': self.file2tablename(fname, orig_prefix, cut_prefix=True, add_prefix=''),
            })
        return descs

    def upload_one(self, tdesc, sort_tables=True):
        """Upload one archived table and sort it afterwards when needed.

        ``tdesc`` is a mapping with at least ``fullpath``, ``tablename`` and
        ``sorted``. The write is attempted twice; the second
        ``SandboxSubprocessError`` is re-raised. With ``sort_tables`` set,
        the table is sorted when it was not stored sorted or when strictly
        sorted YT tables are enabled.
        """
        ctx = dict(tdesc)
        ctx.update(
            writecmd='-writesorted' if tdesc['sorted'] else '-write',
            timeout='-timeout 3600' if self.client.is_yt() else '',
            sortby="",
        )

        template = 'zcat {fullpath} | ' + self.client.command(' {timeout} -lenval -subkey {writecmd} {tablename}')
        write_cmd = template.format(**ctx)
        last_attempt = 1
        for attempt in (0, 1):
            try:
                run_process(write_cmd, shell=True, check=True, wait=True, log_prefix=self.log_prefix + 'upload_one.write')
            except SandboxSubprocessError:
                if attempt == last_attempt:
                    raise
            else:
                break

        if sort_tables and (not tdesc["sorted"] or self._strictly_sorted_yt_tables):
            if self._strictly_sorted_yt_tables:
                ctx["sortby"] = " -sortby key -sortby subkey -sortby value "

            sort_cmd = '-sort -src {tablename} -dst {tablename} {sortby}'.format(**ctx)
            self.client.run(
                sort_cmd,
                check=True, wait=True, log_prefix=self.log_prefix + 'upload_one.sort'
            )

    def resource2tables(self, task, resource, cut_prefix=False, add_prefix=None, nproc=None):
        """Upload every table of a tables resource using a process pool.

        ``nproc`` sets the pool size (see ``make_process_pool``). Returns the
        resolved resource.
        """
        resource = self.resid2resource(resource)
        table_descs = self.resource2tabledescs(task, resource, cut_prefix, add_prefix)
        pool = self.make_process_pool(nproc)
        pool.map(lambda tdesc: self.upload_one(tdesc), table_descs)
        return resource
