import datetime  # noqa
import logging
import json

from collections import defaultdict
from retry import retry

import yt.wrapper
from yt.wrapper.common import MB
from yt.wrapper.ypath import ypath_join

# Tab-separated log record layout: timestamp, level name, message.
FORMAT = "%(asctime)s\t%(levelname)s\t%(message)s"
# Configure the root logger with the layout above (runs at import time).
logging.basicConfig(format=FORMAT)
# Module-level logger used by every function and method in this file.
logger = logging.getLogger(__name__)
# Let INFO and more severe records through this module's logger.
logger.setLevel(logging.INFO)


def closure_reducer(field, date):  # c901 # noqa
    """Build a reducer that emits rows which differ between two tables.

    The returned reducer expects at most one row per key from each of two
    input tables: table index 0 is treated as the base (old) table and
    table index 1 as the source (new) table.

    :param field: name of the column whose value is compared between the
        base row and the source row.
    :param date: value stored under the ``date`` key of every emitted row.
    :raises RuntimeError: if ``field`` is empty.
    """
    if not field:
        raise RuntimeError("VerifyField should be provided")

    @yt.wrapper.with_context
    def table_reducer(key, input_row_iterator, context):
        old_row = None
        new_row = None

        for row in input_row_iterator:
            table_index = context.table_index
            if table_index == 0 and old_row is None:
                old_row = row
                continue
            if table_index == 1 and new_row is None:
                new_row = row
                continue
            if table_index > 1:
                raise RuntimeError("Unknown table index")
            # a second row for the same key arrived from the same table
            raise RuntimeError("Not uniq keys")

        if new_row is None:
            # the key disappeared from the new table: emit the bare key
            removed = dict(key)
            removed.update(date=date)
            yield removed
            return

        unchanged = old_row is not None and new_row.get(field) == old_row.get(field)
        if not unchanged:
            # the key is new, or the watched field changed in the new table
            new_row.update(date=date)
            yield new_row

    return table_reducer


def simple_mapper(row):
    """Identity mapper: emit the incoming row unchanged."""
    yield row


def closure_mapper(date):  # c901 # noqa
    """Build a mapper that stamps every row with ``date``.

    :param date: value written under the ``date`` key of each row.
    :return: a mapper that mutates the row in place and yields it.
    """

    def table_mapper(row):
        row["date"] = date
        yield row

    return table_mapper


def get_date(date):
    """Parse a ``YYYY-MM-DD`` string into a naive ``datetime.datetime``."""
    parsed = datetime.datetime.strptime(date, "%Y-%m-%d")
    return parsed


def get_month_begin(date):
    """Return the first day of the month containing ``date``.

    :param date: date string understood by ``get_date``.
    :return: the parsed datetime with its day replaced by 1.
    """
    parsed = get_date(date)
    return parsed.replace(day=1)


def get_internal_date(date):
    """Snap ``date`` back to the 1st or the 15th of its month.

    Days 1-14 map to the 1st; days 15 and later map to the 15th.

    :param date: date string understood by ``get_date``.
    :return: the snapped datetime.
    """
    parsed = get_date(date)
    anchor_day = 1 if parsed.day < 15 else 15
    return parsed.replace(day=anchor_day)


def nonrequire(field):
    """Relax a schema column description in place.

    Sets ``required`` to ``False`` and removes every key whose name starts
    with ``type_v``.

    :param field: mutable mapping describing one schema column.
    :return: the same ``field`` object, for chaining.
    """
    field["required"] = False
    # Iterate over a snapshot of the keys: popping while walking the live
    # key view raises RuntimeError on Python 3.
    for key in list(field.keys()):
        if key.startswith("type_v"):
            field.pop(key)
    return field


def nonsorted(field):
    """Drop the ``sort_order`` key from a schema column, if present.

    :param field: mutable mapping describing one schema column.
    :return: the same ``field`` object, for chaining.
    """
    if "sort_order" in field:
        del field["sort_order"]
    return field


class BackupTask(object):
    """Back up one YT table into a set of backup directories.

    Each element of ``backupers`` passed to :meth:`run` is read through the
    attributes ``DirName``, ``DiffBy``, ``VerifyField``, ``Master``,
    ``Predicate``, ``PredicateGap`` and ``Repeats``.  A backuper with a
    non-empty ``DiffBy`` is stored as base/delta tables (see :meth:`_diff`);
    any other backuper receives a full copy of the table.

    ``reactions`` maps ``(kind, date)`` to the list of paths written for that
    date, where ``kind`` is ``"full"`` or ``"difflog"``.
    """

    OPTIMIZE_FOR = "scan"
    COMPRESSION_CODEC = "brotli_6"
    ERASURE_CODEC = "isa_lrc_12_2_2"
    GENERATE_DATE_FORMAT = "%Y-%m-%d"
    GENERATE_DATE_ATTR = "generate_date"
    # Source table attributes that are carried over to the backup.
    NEEDED_ATTR = {"generate_date", "_yql_proto_field_graph"}
    COMPRESS_ENABLED = False  # used nightly compression process

    def __init__(self, yt_client):
        """
        :param yt_client: YT client object used for every cluster operation.
        """
        self.yt = yt_client
        self.reactions = defaultdict(list)
        # The attributes below are filled in by run().
        self.all_attrs = {}
        self.backup_path = None
        self.date = None

    @retry(tries=3, delay=1)
    def run(self, source, backup_path, backupers):
        """Back up ``source`` under ``backup_path`` for every backuper.

        The whole operation runs inside one YT transaction and is retried
        by the ``retry`` decorator (3 tries, 1 second delay).

        :param source: path of the table to back up.
        :param backup_path: root directory that holds the backup directories.
        :param backupers: iterable of backuper descriptions (see class doc).
        """
        source_attrs = self.yt.get("{path}/@".format(path=source))
        self.all_attrs = {
            key: value
            for key, value in source_attrs.items()
            if key in self.NEEDED_ATTR
        }
        if self.GENERATE_DATE_ATTR not in self.all_attrs:
            # The source carries no generation date: fall back to today.
            self.all_attrs[self.GENERATE_DATE_ATTR] = datetime.datetime.now().strftime(
                self.GENERATE_DATE_FORMAT
            )
        self.backup_path = backup_path
        self.date = self.all_attrs[self.GENERATE_DATE_ATTR]
        logger.info("Run backup for `%s` at `%s`", source, self.date)

        with self.yt.Transaction():
            # A real list is required here: on Python 3 `map` is a lazy iterator.
            temp_schema = [
                nonsorted(nonrequire(column)) for column in self._load_schema(source)
            ]
            with self.yt.TempTable(attributes={"schema": temp_schema}) as compressed:
                self.compress(source, compressed)
                self.copy(compressed, backupers)
                self.trim(backupers)

    def _load_schema(self, path):
        """Return the ``@schema`` attribute of ``path`` as a list of column dicts."""
        raw_schema = self.yt.get("{path}/@schema".format(path=path), format="json")
        return json.loads(raw_schema)["$value"]

    def compress(self, source, target):
        """Compress the source table to optimize its chunks.

        Creates the backup root if needed, transforms ``source`` into
        ``target`` and copies the collected attributes onto ``target``.

        :return: ``target``.
        """
        self.yt.create(
            "map_node", self.backup_path, recursive=True, ignore_existing=True
        )
        logger.info("Compress `%s` into `%s`", source, target)
        self.transform(source, target)
        for attribute, value in self.all_attrs.items():
            self.yt.set("{path}/@{attr}".format(path=target, attr=attribute), value)
        return target

    def transform(self, source, target):
        """Materialize ``source`` into ``target``.

        * dynamic table: copied row by row with an identity map operation;
        * ``COMPRESS_ENABLED``: recompressed with the class-level codecs;
        * otherwise: plain copy, skipped when ``source`` equals ``target``.
        """
        dynamic = self.yt.get("{path}/@dynamic".format(path=source))

        if dynamic:
            # self.yt.unmount_table(source, sync=True)
            self.yt.run_map(
                simple_mapper,
                source_table=source,
                destination_table=target,
            )
            # self.yt.mount_table(source, sync=True)
        elif self.COMPRESS_ENABLED:
            self.yt.transform(
                source_table=source,
                destination_table=target,
                compression_codec=self.COMPRESSION_CODEC,
                erasure_codec=self.ERASURE_CODEC,
                desired_chunk_size=3 * 512 * MB,
                optimize_for=self.OPTIMIZE_FOR,
            )
        elif source != target:
            self.yt.copy(
                source_path=source,
                destination_path=target,
                recursive=True,
                force=True,
            )

    def copy(self, compressed, backupers):
        """Copy the compressed table into every backup directory."""
        # copy the same table as is into each directory to allow Cypress chunk
        # optimisation
        for backup in backupers:
            if backup.DiffBy:
                self._diff(compressed, backup)
                self._fill_diff_missing_days(backup)
            else:
                self._copy(compressed, backup)

    def _copy(self, compressed, backup):
        """Make a full copy for ``backup`` when its predicates ask for one.

        The copy source is the oldest table of ``backup.Master`` when that
        is available, otherwise the freshly compressed table.
        """
        source, date = self._source_path(compressed, backup.Master or None)
        target = self._table_path(backup.DirName)
        # NOTE(review): `Predicate` is evaluated with eval(); it must only ever
        # come from trusted configuration.
        need_backuped = eval(backup.Predicate)(date) or self._fill_gap(backup, date)
        logger.info(
            "Copy backup from `%s` to `%s` [%s]",
            source,
            target,
            # The predicate may return any truthy/falsy value, not only a bool.
            "run" if need_backuped else "skip",
        )
        if need_backuped:
            self.yt.copy(
                source,
                target,
                recursive=True,
                force=True,
            )
            self.reactions[("full", date)].append(target)

    def _fill_gap(self, backup, date):
        """Evaluate ``backup.PredicateGap`` against the latest backed-up table.

        :return: ``False`` when no gap predicate is configured or nothing
            has been backed up yet; otherwise the predicate's result for
            ``date`` and the generation date of the latest table.
        """
        backup_dir_path = self._path(backup.DirName)
        if not backup.PredicateGap or not self.yt.exists(backup_dir_path):
            # no gap until there are tables
            return False
        tables = list(self.yt.list(path=backup_dir_path, absolute=True, sort=True))
        if not tables:
            # the directory exists but is empty: still nothing to compare with
            return False
        last_backuped_table = max(tables)
        attr_path = "{path}/@{attr}".format(
            path=last_backuped_table, attr=self.GENERATE_DATE_ATTR
        )
        # NOTE(review): `PredicateGap` is evaluated with eval(); trusted
        # configuration only.
        return eval(backup.PredicateGap)(date, self.yt.get(attr_path))

    def _create(self, path, schema):
        """(Re)create the table ``path`` with ``schema``, replacing any existing node."""
        self.yt.create(
            "table",
            path,
            attributes=dict(schema=schema),
            force=True,
            recursive=True,
        )

    def _diff(self, compressed, backup):
        """Store ``compressed`` as a base table plus a per-day delta.

        Paths used under ``backup.DirName``:

        * ``internal/<internal date>`` - materialized base table, where the
          internal date is computed by ``get_internal_date``;
        * ``base/<date>`` - link to the internal base table;
        * ``delta/<date>`` - rows changed or removed relative to the base.

        When the internal base table does not exist yet it is built from
        the source and the delta stays empty; otherwise the delta is
        produced by reducing the base against the source.
        """
        source, date = self._source_path(compressed)
        internal_date = get_internal_date(date).strftime("%Y-%m-%d")

        internal = self._path(ypath_join(backup.DirName, "internal", internal_date))
        baselink = self._path(ypath_join(backup.DirName, "base", date))
        delta = self._path(ypath_join(backup.DirName, "delta", date))

        schema = self._load_schema(source)
        schema.append({"name": "date", "type": "string"})

        # nonrequire/nonsorted mutate the column dicts in place; build real
        # lists because `map` is lazy on Python 3.
        self._create(delta, [nonrequire(column) for column in schema])
        # make internal base link
        self.yt.link(internal, baselink, force=True, recursive=True)
        if not self.yt.exists(internal):
            self._create(
                internal, [nonsorted(nonrequire(column)) for column in schema]
            )
            self.yt.run_map(
                closure_mapper(date), source_table=source, destination_table=internal
            )
            self.yt.run_sort(
                source_table=internal,
                destination_table=internal,
                sort_by=list(backup.DiffBy),
            )
            self.transform(internal, internal)
            self.yt.run_merge(
                source_table=internal,
                destination_table=internal,
                mode="sorted",
                spec=dict(combine_chunks=True, merge_by=list(backup.DiffBy)),
            )
            self.yt.set(
                "{path}/@{attr}".format(path=internal, attr=self.GENERATE_DATE_ATTR),
                date,
            )
        else:
            # reduce base to delta
            self.yt.run_reduce(
                closure_reducer(backup.VerifyField, date),
                source_table=[baselink, source],
                destination_table=delta,
                reduce_by=list(backup.DiffBy),
                sort_by=list(backup.DiffBy),
                # spec={"auto_merge": {"mode": "relaxed"}},
            )
            # TODO: YT-8024 use sorted automerge
            self.yt.run_merge(
                source_table=delta,
                destination_table=delta,
                mode="sorted",
                spec=dict(combine_chunks=True, merge_by=list(backup.DiffBy)),
            )

        # set attributes
        self.yt.set(
            "{path}/@{attr}".format(path=delta, attr=self.GENERATE_DATE_ATTR), date
        )
        self.reactions[("difflog", date)].extend([baselink, delta])

    def _fill_diff_missing_days(self, backup):
        """Fill date gaps between consecutive delta tables with links.

        For every missing day between two neighbouring deltas, a delta link
        and a base link are created that point at the earlier neighbour's
        delta and base.
        """
        tables = list(
            self.yt.list(
                path=self._path(ypath_join(backup.DirName, "delta")),
                absolute=True,
                sort=True,
            )
        )
        if not tables:
            return
        last_delta = tables[0]
        for table in tables[1:]:
            # Table names are dates; the last path component is parsed.
            last_table_dt = get_date(last_delta.split("/")[-1])
            gap_days = (get_date(table.split("/")[-1]) - last_table_dt).days
            last_date = last_table_dt.strftime("%Y-%m-%d")
            last_baselink = self._path(ypath_join(backup.DirName, "base", last_date))

            for offset in range(1, gap_days):
                offset_date = (
                    last_table_dt + datetime.timedelta(days=offset)
                ).strftime("%Y-%m-%d")

                deltalink = self._path(ypath_join(backup.DirName, "delta", offset_date))
                baselink = self._path(ypath_join(backup.DirName, "base", offset_date))
                self.yt.link(last_delta, deltalink, force=True, recursive=True)
                self.yt.link(last_baselink, baselink, force=True, recursive=True)

                self.reactions[("difflog", offset_date)].extend([baselink, deltalink])

            last_delta = table

    def trim(self, backupers):
        """Clean up the tail of each backup directory."""
        for backup in backupers:
            if backup.Repeats == 0:
                # no trim if no limit
                continue
            if backup.DiffBy:
                self._trim_diff(backup)
            else:
                self._trim_daily(backup)

    def _trim_diff(self, backup):
        """Remove base/delta/internal tables at or before the cut-off month."""
        # NOTE(review): a month is approximated as 30 days, so the cut-off can
        # land one calendar month earlier than `Repeats` months back (e.g. when
        # counting from March 1st) - confirm this retention is intended.
        clear_month = (
            get_month_begin(self.date) - datetime.timedelta(days=30 * backup.Repeats)
        ).replace(day=1)

        def month_filter(table):
            return get_month_begin(str(table).split("/")[-1]) <= clear_month

        for keypath in ("internal", "base", "delta"):
            listed = self.yt.list(
                path=self._path(ypath_join(backup.DirName, keypath)),
                absolute=True,
                sort=True,
            )
            tables = [table for table in listed if month_filter(table)]
            logger.info("Trim backups in `%s/%s` %s", backup.DirName, keypath, tables)
            for table in tables:
                self.yt.remove(table, force=True)

    def _trim_daily(self, backup):
        """Keep only the last ``backup.Repeats`` tables of a full-copy directory."""
        if not self.yt.exists(self._path(backup.DirName)):
            return
        tables = list(
            self.yt.list(path=self._path(backup.DirName), absolute=True, sort=True)
        )[: -backup.Repeats]
        logger.info("Trim backups in `%s` %s", backup.DirName, tables)
        for table in tables:
            self.yt.remove(table, force=True)

    def _path(self, name):
        """Return ``name`` joined onto the backup root.

        :raises ValueError: if ``name`` is ``None``.
        """
        if name is None:
            raise ValueError("Name should be string")
        return ypath_join(self.backup_path, name)

    def _table_path(self, directory):
        """Return the path of the table for the current run date in ``directory``."""
        return ypath_join(self._path(directory), self.date)

    def _source_path(self, compressed, master=None):
        """Pick the table to copy from and read its generation date.

        The oldest table of the ``master`` directory is preferred; when no
        master is given, it cannot be listed, or it is empty, ``compressed``
        is used.

        :return: ``(source_path, generate_date)``.
        """
        source = compressed
        if master is not None:
            try:
                tables = list(self.yt.list(self._path(master), absolute=True, sort=True))
            except Exception:
                # Best-effort lookup: a missing master directory is not an
                # error.  Unlike a bare `except`, this lets KeyboardInterrupt
                # and SystemExit propagate.
                logger.info(
                    "Master `%s` is not available, use `%s`", master, compressed
                )
                tables = []
            if tables:
                source = min(tables)
        date = self.yt.get(
            "{path}/@{attr}".format(path=source, attr=self.GENERATE_DATE_ATTR)
        )
        return source, date