# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals

try:
    from yt_worker import YtWorker, yt
except ImportError:
    from .yt_worker import YtWorker, yt


def _string_to_dt(timestring):
    """
    Convert a timestamp such as ``2020-01-02T03:04:05.123456Z`` to a naive datetime.

    The part before the first dot is parsed with the format ``%Y-%m-%dT%H:%M:%S``.
    The part after it is treated as a decimal fraction of a second: it is padded
    or truncated to six digits, so ``.5`` means 500000 microseconds and anything
    finer than a microsecond is dropped. The fractional part is optional.

    A trailing ``Z`` is stripped and otherwise ignored: the returned datetime
    carries no tzinfo.

    Arguments:
    timestring -- string with the timestamp.

    Returns the parsed datetime.datetime.

    Raises ValueError if the date part does not match the format or the
    fractional part is not made of digits.
    """
    from datetime import datetime, timedelta

    parts = timestring.split(".")
    date_part = parts[0].rstrip("Z")
    fraction = parts[1].rstrip("Z") if len(parts) > 1 else ""

    dt = datetime.strptime(date_part, "%Y-%m-%dT%H:%M:%S")
    if fraction:
        # Scale the digits to microseconds instead of using them as a raw count:
        # "5" -> "500000", "123456789" -> "123456".
        microseconds = int(fraction[:6].ljust(6, "0"), 10)
        dt = dt + timedelta(microseconds=microseconds)
    return dt


class YtAclChecker(YtWorker):
    """
    Diff two consecutive dump tables per cluster and re-check the ACL of the difference.

    The work is split in two steps (see run_check):
    1. get_diff builds, for every cluster, a table with the paths that are present
       only in the older of the two dump tables.
    2. check_attributes requests the current attributes of each of those paths and
       keeps the rows whose effective ACL grants at least one permission to the
       "yandex" subject.

    Instance attributes set by this class:
    table_id -- marker ("old" / "new") written by the mapper into every row;
                switched by get_diff before each map operation.
    hec_token -- token passed to the constructor; stored, not used in this class.
    current_dump_table -- date suffix of the newer dump table.
    old_dump_table -- date suffix of the older dump table.
    dest_tables_list -- paths of the result tables produced by get_diff.
    results_paths -- value returned by the last run_check call (None before it).

    Everything else read through self (yt, clusters, source_path, dest_path,
    yt_spec, ttl, tmp_ttl, yt_token and the date helpers) is not defined here and
    is presumably provided by YtWorker.
    """

    def __init__(self, yt_token, hec_token, analyze_clusters=None, current_table=None):
        """
        Arguments:
        yt_token -- token forwarded to the YtWorker constructor.
        hec_token -- token stored on the instance as hec_token.
        analyze_clusters -- forwarded to the YtWorker constructor
                            (per the original docs: list of clusters for analysis).
        current_table -- optional date string (per the original docs: format %Y%m%d)
                         to analyze dumps in the past. It is validated with
                         check_timeformat and used as the newer dump; the older dump
                         is the date returned by get_yesterday_date for it.
                         When omitted, the current and yesterday's dates are used.
        """
        super(YtAclChecker, self).__init__(yt_token, analyze_clusters=analyze_clusters)
        self.table_id = None
        self.hec_token = hec_token

        if current_table is None:
            self.current_dump_table = self.get_current_date()
            self.old_dump_table = self.get_yesterday_date()
        else:
            self.check_timeformat(current_table)
            self.current_dump_table = current_table
            self.old_dump_table = self.get_yesterday_date(base_date=current_table)

        self.dest_tables_list = []
        self.results_paths = None

    # Mapper functions
    @yt.aggregator
    def _map_prepare_table(self, records):
        """
        Mapper: keep the columns needed for the diff and tag rows with self.table_id.

        Rows without "size" or without "path" are dropped. Every emitted row is a
        new dict, so a consumer that buffers rows never sees them overwritten by
        later ones.
        """
        table_id = self.table_id

        for record in records:
            size = record.get("size")
            path = record.get("path")
            if size is None or path is None:
                continue

            yield {
                "table_id": table_id,
                "path": path,
                "owner": record.get("owner"),
                "effective_acl": record.get("effective_acl"),
                "inherit_acl": record.get("inherit_acl"),
                "size": size,
                "expiration_time": record.get("expiration_time"),
            }

    def _reduce_result_table(self, key, records):
        """
        Reducer (by "path"): emit the path only if all its rows are tagged "old".

        The emitted row carries the path, the list of distinct table ids seen
        (always ["old"] for emitted rows) and the remaining columns taken from the
        last row of the group.
        """
        reduced_record = {"path": key["path"], "table_id": []}

        for record in records:
            table_id = record.get("table_id")
            if table_id not in reduced_record["table_id"]:
                reduced_record["table_id"].append(table_id)

            # Later rows overwrite earlier ones: the last row of the group wins.
            reduced_record["size"] = record.get("size")
            reduced_record["owner"] = record.get("owner")
            reduced_record["effective_acl"] = record.get("effective_acl")
            reduced_record["inherit_acl"] = record.get("inherit_acl")
            reduced_record["expiration_time"] = record.get("expiration_time")

        if len(reduced_record["table_id"]) == 1 and "old" in reduced_record["table_id"]:
            yield reduced_record

    def get_diff(self):
        """
        Build the per-cluster diff tables between the older and the newer dump.

        For every cluster in self.clusters whose two dump tables
        (<source_path>/<cluster>/<cluster>.<date>) both exist:
        - both dumps are mapped into the temporary tables "old-temp" and "new-temp"
          located next to the result table;
        - the temporary tables get an expiration time, are sorted by "path" and
          reduced into <dest_path>/<cluster>/<cluster>_<current date with "-">;
        - the result table gets an expiration time as well.
        Clusters with a missing dump are skipped silently.

        Returns self.dest_tables_list with the result table paths. A path is
        recorded once even if get_diff is called repeatedly.
        """
        for cluster in self.clusters:
            # Source dump tables: <source_path>/<cluster>/<cluster>.<date>
            source_old_table = "/".join(
                [self.source_path, cluster, ".".join([cluster, self.old_dump_table])])
            source_new_table = "/".join(
                [self.source_path, cluster, ".".join([cluster, self.current_dump_table])])

            # Nothing to compare unless both dumps are present
            if not (self.yt.exists(source_old_table) and self.yt.exists(source_new_table)):
                continue

            # Result table: <dest_path>/<cluster>/<cluster>_<YYYY-MM-DD style date>
            dest_table = "/".join(
                [self.dest_path, cluster, "".join([cluster, "_", self.get_current_date(delim="-")])])

            # Recording the same table twice would make check_attributes process
            # an already rewritten table again.
            if dest_table not in self.dest_tables_list:
                self.dest_tables_list.append(dest_table)

            # Create the node that holds the result and the temporary tables
            dest_node = dest_table.rsplit("/", 1)[0]
            self.yt.create("map_node", dest_node, ignore_existing=True, recursive=True)

            # Map both dumps; self.table_id is read by the mapper to tag the rows,
            # so it has to be switched before each operation is started.
            tmp_tables = []
            for table_id, source_table in (("old", source_old_table), ("new", source_new_table)):
                self.table_id = table_id
                tmp_table = "/".join([dest_node, table_id + "-temp"])
                self.yt.run_map(self._map_prepare_table, source_table, tmp_table, spec=self.yt_spec)
                tmp_tables.append(tmp_table)

            # Set ttl for tmp tables (5h according to the original comment;
            # the actual value is whatever self.tmp_ttl holds)
            for tmp_table in tmp_tables:
                self.yt.set(tmp_table + "/@expiration_time", self.tmp_ttl)

            # Sort in place: reduce needs its inputs sorted by the reduce key
            for tmp_table in tmp_tables:
                self.yt.run_sort(tmp_table, tmp_table, sort_by="path")

            # Reduce
            self.yt.run_reduce(self._reduce_result_table, tmp_tables, dest_table,
                               reduce_by="path", spec=self.yt_spec)

            # Set ttl on result table
            self.yt.set(dest_table + "/@expiration_time", self.ttl)

        return self.dest_tables_list

    @staticmethod
    def _get_cpu_cores_number(percentage):
        """
        Translate a percentage of the machine's CPUs into a number of cores.

        Arguments:
        percentage -- number in the range (0, 100].

        Returns the rounded share of multiprocessing.cpu_count(), but not less
        than 1, or None when percentage is outside of (0, 100].
        """
        from multiprocessing import cpu_count

        fraction = percentage / 100.0

        if fraction > 1 or fraction <= 0:
            return None

        return max(1, int(round(cpu_count() * fraction)))

    def check_attributes(self, table_format=None):
        """
        Filter every diff table down to the paths that are open to "yandex".

        For each table in self.dest_tables_list the rows are read, the attributes
        of every row's path are requested in one batch from the cluster named by
        the second-to-last component of the table path, and the table is rewritten
        with the matching rows only. A row matches when the "effective_acl" of the
        response has an entry with "yandex" among its subjects and a non-empty
        permission list. Paths whose request returned no result are skipped.

        In the kept rows "effective_acl" is renamed to "old_effective_acl",
        "table_id" is removed and "new_effective_acl" holds the whole response.

        Arguments:
        table_format -- format used to read the diff tables; a new
                        yt.YsonFormat() is used when omitted.

        Returns self.dest_tables_list.
        """
        import sys
        import time

        # Resolved per call rather than in the signature, so the format object is
        # neither created at class definition time nor shared between calls.
        if table_format is None:
            table_format = yt.YsonFormat()

        for path in self.dest_tables_list:
            # The table path is .../<cluster>/<table name>
            cluster = path.split("/")[-2]
            client = self.yt.YtClient(proxy=cluster + ".yt.yandex.net", token=self.yt_token)

            # NOTE(review): the table itself is read and written through self.yt,
            # only the attribute requests go through the per-cluster client.
            rows_list = list(self.yt.read_table(path, table_format))

            batch_client = client.create_batch_client()

            # Requests are only queued here; they are executed by commit_batch()
            responses = [
                (row, batch_client.get(path=row["path"] + "/@", attributes=["effective_acl"]))
                for row in rows_list
            ]

            sys.stdout.write('Committing the batch... ')
            sys.stdout.flush()
            started = time.time()
            batch_client.commit_batch()
            finished = time.time()
            sys.stdout.write('done (in {:.2f}s)\n'.format(finished - started))
            sys.stdout.flush()

            result = []
            for row, response in responses:
                attributes = response.get_result()
                if attributes is None:
                    continue

                table_is_open = any(
                    "yandex" in line["subjects"] and len(line["permissions"]) > 0
                    for line in attributes["effective_acl"]
                )
                if not table_is_open:
                    continue

                row["old_effective_acl"] = row.pop("effective_acl")
                # Tolerate rows that lack the column instead of raising KeyError
                row.pop("table_id", None)
                row["new_effective_acl"] = attributes
                result.append(row)

            # Replace the diff table content with the filtered rows
            self.yt.write_table(path, result, format=yt.YsonFormat())

        return self.dest_tables_list

    def run_check(self):
        """
        Run the whole check: build the diff tables, then filter them by ACL.

        Returns the list of result table paths, also stored in self.results_paths.
        """
        # Prepare diff table for last two dumps
        self.get_diff()

        # Check paths in diff table (check attributes)
        self.results_paths = self.check_attributes()
        return self.results_paths
