import crypta.lib.python.bt.conf.conf as conf
import logging
import time
import os

from datetime import datetime

from yt.wrapper import YPath
from crypta.lib.python.native_yt import (
    run_native_map,
    run_native_map_reduce,
    run_native_map_reduce_with_combiner,
    run_native_reduce,
    run_native_join_reduce,
)
from crypta.lib.python.native_yt.proto import create_schema, extract_fields

from crypta.graph.export.lib.native import (
    graph_yql_proto_fields,
    TMapEdgesWithActiveIdentifiers,
    TReduceEdgesWithActiveIdentifiers,
    TFilterUnchangedGraphs,
    TReduceTVEdges,
    TEmptyMapper,
    TCombine,
    TProfileFilterMapper,
    TJoinTVEdgesWithCryptaID,
    TFilterUsedTVEdges,
)

from crypta.graph.export.lib.proto.messages_pb2 import (
    TGraphRecord,
    TTimestampState,
    TEdgeV2,
    TRecordWithProfile,
)
from yt.wrapper.common import (
    date_string_to_timestamp,
)

from crypta.graph.export.lib.python.spanning_graph import proto_reducer as spanning_reducer

logger = logging.getLogger(__name__)


def _compatible_version(version):
    # hacky hack to use proto
    if isinstance(version, int):
        mapping = {2: "V2", 3: "EXP"}
        return mapping[version]
    else:
        return version.upper()


def _get_path(path, columns):
    return "{}{{{}}}".format(path, ",".join(columns))


class ExportConf:
    """Attribute names and canonical version identifiers used by the export."""

    # Names of attributes read from / written to YT nodes.
    GENERATE_DATE = "generate_date"
    ORIGIN_GENERATE_DATE = "origin_generate_date"
    MODIFICATION_TIME = "modification_time"

    # Canonical (upper-case) export version names.
    VERSION_2 = "V2"
    EXPERIMENT = "EXP"
    HOUSEHOLDS = "HOUSEHOLDS"

    SUPPORTED_VERSIONS = {VERSION_2, EXPERIMENT}
    NEW_VERSIONS = {VERSION_2, EXPERIMENT}

    # Versions that have no graph path; their destination_root equals output_root.
    SPECIAL_VERSIONS = {HOUSEHOLDS}

    @classmethod
    def is_new_version(cls, version):
        """Return whether ``version`` belongs to ``NEW_VERSIONS``."""
        new_versions = cls.NEW_VERSIONS
        return version in new_versions

    @classmethod
    def is_special_version(cls, version):
        """Return whether ``version`` belongs to ``SPECIAL_VERSIONS``."""
        special_versions = cls.SPECIAL_VERSIONS
        return version in special_versions


class Paths(object):
    """Resolve the YT input and output locations for one export version.

    ``version`` may be an int (``2``/``3``) or a string of any case; it is
    normalized with ``_compatible_version`` into ``compatible_version``.

    ``sources_with_output_root`` is an optional ``(source_edges, output_root)``
    pair. It replaces the configured edges source and output root, and the
    output layout is then built as for ``ExportConf.VERSION_2``. Vertices are
    still taken from ``version``.
    """

    def __init__(self, version, sources_with_output_root=None):
        self.version = version
        self.compatible_version = _compatible_version(version)
        self.current_graphs = None
        self.rejected_graphs = None
        self.rejected_graphs_final = None
        self.active_ids = None
        self.tv_edges = None

        self.output_root = {
            ExportConf.VERSION_2: YPath(conf.Paths.Output.Main),
            ExportConf.EXPERIMENT: YPath(conf.Paths.Output.MainExp),
            # Households will not be uploaded to IS as is, so there is no config for ExportConf.HOUSEHOLDS
        }.get(self.compatible_version)

        self.source_edges = {
            ExportConf.VERSION_2: YPath(conf.Paths.Input.Edges),
            ExportConf.EXPERIMENT: YPath(conf.Paths.Input.ExpEdges),
            ExportConf.HOUSEHOLDS: YPath(conf.Paths.Input.HouseholdsEdges),
        }.get(self.compatible_version)

        self.source_vertices = {
            ExportConf.VERSION_2: YPath(conf.Paths.Input.Vertices),
            ExportConf.EXPERIMENT: YPath(conf.Paths.Input.ExpVertices),
        }.get(self.compatible_version)

        # The layout is chosen from the *normalized* version, so e.g. "households"
        # is treated the same as ExportConf.HOUSEHOLDS.
        layout_version = self.compatible_version
        if sources_with_output_root is not None:
            source_edges, output_root = sources_with_output_root
            layout_version = ExportConf.VERSION_2
            self.source_edges = YPath(source_edges)
            self.output_root = YPath(output_root)

        self._is_special = ExportConf.is_special_version(layout_version)
        if self._is_special:
            self.destination_root = self.output_root
        else:
            if self.output_root is None:
                raise ValueError("No output root is configured for export version {!r}".format(version))
            self.destination_root = self.output_root.join(conf.Paths.Output.RelativePaths.Vulture)
            self.current_graphs = self.output_root.join(conf.Paths.Output.RelativePaths.Graphs)
            self.rejected_graphs = self.output_root.join(conf.Paths.Output.RelativePaths.RejectedGraphs)
            self.rejected_graphs_final = self.output_root.join(conf.Paths.Output.RelativePaths.RejectedGraphsFinal)
            self.active_ids = self.output_root.join(conf.Paths.Output.RelativePaths.ActiveIdentifiers)
            self.tv_edges = self.output_root.join(conf.Paths.Output.RelativePaths.Households)

        # Tables that are (re)written on every export; entries are None for special versions.
        self.current_tables = [
            self.current_graphs,
            self.rejected_graphs,
            self.rejected_graphs_final,
            self.active_ids,
            self.tv_edges,
        ]

    def get_source_edges(self):
        return self.source_edges

    def get_source_vertices(self):
        return self.source_vertices

    def get_destination_root(self):
        return self.destination_root

    def get_destination(self, timestamp):
        """Return the export table path ``<destination_root>/<timestamp>``."""
        return self.destination_root.join(timestamp)

    def get_current_graphs(self):
        """Return the current-graphs table; only valid for non-special layouts."""
        assert not self._is_special
        return self.current_graphs

    def get_rejected_graphs(self):
        return self.rejected_graphs

    def get_rejected_graphs_final(self):
        return self.rejected_graphs_final

    def get_active_ids(self):
        return self.active_ids

    @staticmethod
    def get_identifiers_paths(client, columns=("id", "id_type", "dates")):
        """Return TablePaths (restricted to ``columns``) of the yuid, device_id and uuid tables."""
        yuids = client.TablePath(conf.Paths.Input.Identifiers.yuids, columns=columns)
        device_ids = client.TablePath(conf.Paths.Input.Identifiers.device_ids, columns=columns)
        uuids = client.TablePath(conf.Paths.Input.Identifiers.uuids, columns=columns)
        return [yuids, device_ids, uuids]


class Exporter(object):
    """Run the vulture graph export on YT.

    All Cypress / operation calls go through ``client``. Native operations
    additionally receive proxy, token, transaction id and pool through
    ``operation_kwargs``.
    """

    K_DAY_SECONDS = 24 * 60 * 60
    K_KYLOBYTE = 1 << 10  # bytes in a kilobyte (historical spelling kept for callers)
    K_MEGABYTE = 1 << 20  # bytes in a megabyte
    K_GYGABYTE = 1 << 30  # bytes in a gigabyte (historical spelling kept for callers)

    def __init__(self, client):
        self.client = client

    def extract_profiles(self, without_updating=False, timestamp=0, **operation_kwargs):
        """Refresh the active-profile tables and return the ``all_ids`` path.

        The ``yuids`` and ``all_ids`` tables from ``conf.Paths.ActiveProfiles``
        are created if missing. They are rebuilt with ``TProfileFilterMapper``
        from the newest ``n_tables`` tables of ``log_dir``, unless one of the
        following holds:

        * ``without_updating`` is set;
        * ``log_dir`` is missing or empty;
        * the ``sources_with_timestamp`` attribute stored by the previous
          rebuild names the same newest source;
        * that newest source was created less than two days ago.

        :param without_updating: skip any rebuild and just return ``all_ids``.
        :param timestamp: unix timestamp (int or numeric string) for the state.
        :param operation_kwargs: forwarded to the native map operation.
        :return: path of the ``all_ids`` table.
        """
        yuids = self._create_table(conf.Paths.ActiveProfiles.yuids, TRecordWithProfile)
        all_ids = self._create_table(conf.Paths.ActiveProfiles.all_ids, TRecordWithProfile)

        log_dir = conf.Paths.ActiveProfiles.log_dir
        if without_updating or not self.client.exists(log_dir):
            return all_ids

        sources = self._get_last_tables(log_dir, conf.Options.ActiveProfiles.n_tables)
        # (path, creation unix time) for every source, newest first.
        attr_values = [
            (str(table), date_string_to_timestamp(table.attributes.get("creation_time"))) for table in sources
        ]
        if not attr_values:
            return all_ids

        attr_key = "sources_with_timestamp"

        current_attr_values = self.client.get_attribute(all_ids, attr_key, [])
        if current_attr_values:
            if (
                current_attr_values[0][0] == attr_values[0][0]
                or current_attr_values[0][1] > time.time() - 2 * self.K_DAY_SECONDS
            ):
                return all_ids

        self._create_table(yuids, TRecordWithProfile, force=True)
        self._create_table(all_ids, TRecordWithProfile, force=True)

        fields = extract_fields(TRecordWithProfile)
        run_native_map(
            mapper_name=TProfileFilterMapper,
            source=sources,
            destination=[all_ids, yuids],
            reduce_by=fields.id,
            sort_by=fields.id,
            state=self._get_timestamp_state(timestamp),
            spec={"data_size_per_sort_job": 256 * self.K_MEGABYTE},
            **operation_kwargs
        )
        self.client.set_attribute(yuids, attr_key, attr_values)
        self.client.set_attribute(all_ids, attr_key, attr_values)
        # Sort both tables concurrently, then wait so that the yuids table is
        # sorted by the time this method returns (and before any enclosing
        # transaction is finished).
        yuids_sort = self.client.run_sort(yuids, yuids, sort_by=[fields.id, fields.id_type], sync=False)
        self.client.run_sort(all_ids, all_ids, sort_by=[fields.id, fields.id_type])
        yuids_sort.wait()

        return all_ids

    def export(self, version, diff_only=True, source_with_output_root=None):
        """Export graphs for ``version`` inside one YT transaction.

        :param version: export version accepted by ``Paths``.
        :param diff_only: when the current graphs table exists, write only the
            output of ``TFilterUnchangedGraphs`` instead of all final graphs.
        :param source_with_output_root: optional ``(source_edges, output_root)``
            pair; when given, ``version`` is ignored and the V2 layout is used.
        :return: path of the new export table ``<destination_root>/<unix time>``.
        """
        timestamp = str(int(time.time()))
        if source_with_output_root is not None:
            paths = Paths(ExportConf.VERSION_2, source_with_output_root)
        else:
            paths = Paths(version)
        destination = paths.get_destination(timestamp)
        logger.info("Source: %s", paths.get_source_edges())
        logger.info("Destination: %s", destination)

        with self.client.Transaction() as transaction:
            operation_kwargs = dict(
                proxy=self.client.config["proxy"]["url"],
                token=conf.Yt.Token,
                transaction=str(transaction.transaction_id),
                pool=conf.Yt.Pool,
            )

            return self._export_v2(timestamp, paths, diff_only, destination, operation_kwargs)

    def _prepare_households_to_v2_export(self, households_path, destination, vertices, operation_kwargs):
        """Reduce household (TV) edges and join them with ``vertices`` into ``destination``.

        The steps are:

        1. Map-reduce with ``TReduceTVEdges`` (also used as combiner), keyed by
           ``(id1, id1Type)``.
        2. Join-reduce in place against ``vertices`` on ``(id_type, id)`` with
           ``TFilterUsedTVEdges``.
        3. Join-reduce against ``vertices`` (including ``cryptaId``) with
           ``TJoinTVEdgesWithCryptaID``.

        ``vertices`` is used as a foreign table in both join steps and is
        presumably sorted by ``(id_type, id)`` — TODO confirm.
        """
        edge_fields = extract_fields(TEdgeV2)

        with self.client.TempTable(prefix="households_") as households_reduce:

            run_native_map_reduce_with_combiner(
                mapper_name="",
                reducer_name=TReduceTVEdges,
                combiner_name=TReduceTVEdges,
                source=str(households_path),
                destination=households_reduce,
                reduce_by=[edge_fields.id1, edge_fields.id1Type],
                sort_by=[edge_fields.id1, edge_fields.id1Type, edge_fields.id2, edge_fields.id2Type],
                **operation_kwargs
            )
            self.client.run_sort(households_reduce, households_reduce, sort_by=[edge_fields.id_type, edge_fields.id])

            run_native_join_reduce(
                reducer_name=TFilterUsedTVEdges,
                source=[
                    households_reduce,
                    "<foreign=%true>" + _get_path(vertices, [edge_fields.id_type, edge_fields.id]),
                ],
                destination=households_reduce,
                join_by=[edge_fields.id_type, edge_fields.id],
                **operation_kwargs
            )

            self.client.run_sort(households_reduce, households_reduce, sort_by=[edge_fields.id_type, edge_fields.id])

            run_native_join_reduce(
                reducer_name=TJoinTVEdgesWithCryptaID,
                source=[
                    households_reduce,
                    "<foreign=%true>"
                    + _get_path(vertices, [edge_fields.id_type, edge_fields.id, edge_fields.cryptaId]),
                ],
                destination=destination,
                join_by=[edge_fields.id_type, edge_fields.id],
                **operation_kwargs
            )

    def _export_v2(self, timestamp, paths, diff_only, destination, operation_kwargs):
        """Build graphs from the source edges and write them to ``destination``.

        The pipeline is:

        1. Edges (source + TV + active BigB ids + identifier tables) go through
           ``TMapEdgesWithActiveIdentifiers`` / ``TReduceEdgesWithActiveIdentifiers``.
        2. ``TCombine`` groups edges by ``cryptaId`` into graphs and rejected graphs.
        3. The spanning reducer limits components to ``max_component_size``.
        4. Graphs are written to ``destination``, either all of them or only
           the ``TFilterUnchangedGraphs`` output.
        5. The final graphs replace ``current_graphs``.
        6. Old export tables are removed.

        :return: ``destination``.
        """
        logger.info("Export v2")

        source_edges = paths.get_source_edges()
        source_vertices = paths.get_source_vertices()

        edge_fields = extract_fields(TEdgeV2)
        graph_fields = extract_fields(TGraphRecord)
        self._create_output(destination)
        origin_generate_date = self.client.get_attribute(source_edges, ExportConf.GENERATE_DATE, None)
        if origin_generate_date is None:
            origin_generate_date = self._get_modification_date(source_edges)

        generate_date = datetime.today().strftime("%Y-%m-%d")
        with self.client.TempTable(prefix="edges_") as edges:

            active_ids = paths.get_active_ids()

            households = _get_path(
                Paths(ExportConf.HOUSEHOLDS).get_source_edges(),
                [
                    edge_fields.id1,
                    edge_fields.id1Type,
                    edge_fields.id2,
                    edge_fields.id2Type,
                    edge_fields.logSource,
                    edge_fields.sourceType,
                ],
            )
            households_reduce = paths.tv_edges

            active_bigb_ids = self.extract_profiles(
                without_updating=(paths.compatible_version == ExportConf.EXPERIMENT),
                timestamp=timestamp,
                **operation_kwargs
            )

            self._prepare_households_to_v2_export(households, households_reduce, source_vertices, operation_kwargs)

            edges_with_identifiers_tables = [
                # input for TMapEdgesWithActiveIdentifiers::Do switch(tab_index)
                source_edges,  # v2_edges  tab_index 0
                households_reduce,  # tv_edges  tab_index 1
                active_bigb_ids,  # Active    tab_index 2
            ]
            edges_with_identifiers_tables.extend(paths.get_identifiers_paths(self.client))

            run_native_map_reduce(
                mapper_name=TMapEdgesWithActiveIdentifiers,
                reducer_name=TReduceEdgesWithActiveIdentifiers,
                source=edges_with_identifiers_tables,
                destination=[edges, active_ids],
                reduce_by=[edge_fields.id, edge_fields.id_type],
                sort_by=[edge_fields.id, edge_fields.id_type, edge_fields.cryptaId],
                mapper_state=self._get_timestamp_state(timestamp),
                **operation_kwargs
            )

            self._create_output(paths.get_rejected_graphs())

            with self.client.TempTable(prefix="graphs_") as graphs, self.client.TempTable(
                prefix="final_graphs"
            ) as final_graphs:

                self.client.alter_table(graphs, schema=create_schema(TGraphRecord))

                run_native_map_reduce(
                    mapper_name=TEmptyMapper,
                    reducer_name=TCombine,
                    source=[edges],
                    destination=[graphs, paths.get_rejected_graphs()],
                    reduce_by=edge_fields.cryptaId,
                    sort_by=[
                        edge_fields.cryptaId,
                        edge_fields.IsNotPrivate,
                        edge_fields.id2,
                        edge_fields.id2Type,
                        edge_fields.id,
                        edge_fields.id1,
                        edge_fields.id1Type,
                        edge_fields.sourceType,
                        edge_fields.logSource,
                    ],
                    spec={"reducer": {"memory_limit": 4 * self.K_GYGABYTE}},
                    **operation_kwargs
                )

                self.client.run_sort(graphs, sort_by=graph_fields.cryptaId, spec={"combine_chunks": True})
                self.client.run_sort(active_ids, sort_by=graph_fields.cryptaId, spec={"combine_chunks": True})

                # write sorted directly from reducer. Won't work for map-reduce.
                self.client.alter_table(
                    final_graphs,
                    schema=create_schema(
                        TGraphRecord, dynamic=False, strong=False, cryptaId={"sort_order": "ascending"}
                    ),
                )
                try:
                    self.client.run_reduce(
                        spanning_reducer(conf.Options.max_component_size),
                        [graphs, active_ids],
                        [final_graphs, paths.get_rejected_graphs_final()],
                        reduce_by=graph_fields.cryptaId,
                    )
                except Exception as e:
                    from yt.common import format_error

                    # Re-raise with the full (long) YT error text in the message.
                    raise Exception(format_error(e, 10000))

                current_graphs = paths.get_current_graphs()
                if diff_only and self.client.exists(current_graphs):
                    with self.client.TempTable(prefix="graphs_diff_") as graphs_diff:
                        # Both inputs are TGraphRecord tables, so key by the graph field.
                        run_native_reduce(
                            reducer_name=TFilterUnchangedGraphs,
                            source=[final_graphs, current_graphs],
                            destination=graphs_diff,
                            reduce_by=graph_fields.cryptaId,
                            **operation_kwargs
                        )

                        self.client.run_sort(
                            graphs_diff, destination, sort_by=graph_fields.cryptaId, spec={"combine_chunks": True}
                        )
                else:
                    self.client.copy(final_graphs, destination, force=True, recursive=True)

                self.client.run_merge(destination, destination, spec={"combine_chunks": True}, mode="ordered")
                self.client.move(final_graphs, current_graphs, force=True, recursive=True)

                self._add_graph_yql_attributes(path=destination)
                self.client.set_attribute(destination, ExportConf.GENERATE_DATE, generate_date)

                self._add_graph_yql_attributes(path=current_graphs)
                self.client.set_attribute(current_graphs, ExportConf.GENERATE_DATE, generate_date)

                self.client.set_attribute(destination, ExportConf.ORIGIN_GENERATE_DATE, origin_generate_date)

        self._remove_old_tables(paths.get_destination_root(), timestamp, int(conf.Options.max_vulture_tables))
        return destination

    def _create_table(self, path, proto, force=False, schema_kwargs=None):
        """Create a scan-optimized table at ``path`` with the schema of ``proto``.

        With ``force`` an existing table is removed first; otherwise an
        existing node is kept (``ignore_existing=True``). Returns ``path``.
        """
        schema_kwargs = schema_kwargs or {}
        if force and self.client.exists(path):
            self.client.remove(path)
        self.client.create(
            "table",
            path,
            ignore_existing=True,
            recursive=True,
            attributes=dict(
                schema=create_schema(proto, **schema_kwargs),
                optimize_for="scan",
            ),
        )
        return path

    def _add_graph_yql_attributes(self, path):
        """Set every attribute from ``graph_yql_proto_fields()`` on ``path``."""
        for key, value in graph_yql_proto_fields().items():
            self.client.set_attribute(path, key, value)

    def _create_output(self, path, schema_kwargs=None):
        """Create a TGraphRecord table at ``path`` (if missing) and tag it for YQL."""
        # schema_kwargs must be passed by keyword: the third positional
        # parameter of _create_table is ``force``.
        self._create_table(path, TGraphRecord, schema_kwargs=schema_kwargs)
        self._add_graph_yql_attributes(path)

    def _get_timestamp_state(self, timestamp):
        """Serialize a TTimestampState with the oldest "active" timestamps.

        Both limits are ``timestamp`` minus a configured number of days
        (``active_interval`` / ``ActiveProfiles.bb_active_interval``).
        """
        now = int(timestamp)
        state = TTimestampState()
        state.OldestActiveTimestamp = now - self.K_DAY_SECONDS * int(conf.Options.active_interval)
        state.OldestBBActiveTimestamp = now - self.K_DAY_SECONDS * int(
            conf.Options.ActiveProfiles.bb_active_interval
        )

        return state.SerializeToString()

    def _get_last_tables(self, _dir, n_tables):
        """Return the ``n_tables`` nodes of ``_dir`` with the latest ``creation_time``."""
        return sorted(
            self.client.list(_dir, absolute=True, attributes=["creation_time"]),
            key=lambda table: table.attributes.get("creation_time", ""),
            reverse=True,
        )[:n_tables]

    def _get_modification_date(self, path):
        """Return the ``modification_time`` of ``path`` as a ``YYYY-MM-DD`` string."""
        return datetime.strptime(
            self.client.get_attribute(path, ExportConf.MODIFICATION_TIME), "%Y-%m-%dT%H:%M:%S.%fZ"
        ).strftime("%Y-%m-%d")

    def _make_copy(self, destination, from_paths, to_paths):
        """Copy an export table and the existing "current" tables to another layout.

        ``destination`` is copied to ``to_paths`` under the same basename. Each
        existing table of ``from_paths.current_tables`` is copied over its
        counterpart in ``to_paths.current_tables``. Pairs where either side is
        unset (``None``, as in special layouts) are skipped.
        """
        basename = os.path.basename(str(destination))
        self.client.copy(destination, to_paths.get_destination(basename), recursive=True, force=True)
        for table_from, table_to in zip(from_paths.current_tables, to_paths.current_tables):
            if table_from is None or table_to is None:
                continue
            if self.client.exists(table_from):
                self.client.copy(table_from, table_to, force=True)

    def _remove_old_tables(self, root, timestamp, max_count):
        """Remove nodes of ``root`` modified more than ``max_count`` days before ``timestamp``.

        Nothing is removed while ``root`` is missing or holds fewer than
        ``max_count`` nodes.
        """
        timestamp = int(timestamp)

        if not self.client.exists(root):
            return
        tables = list(self.client.list(root, attributes=["modification_time"], absolute=True))
        if len(tables) < max_count:
            return

        oldest_timestamp = 0
        shift = max_count * self.K_DAY_SECONDS
        if shift < timestamp:
            oldest_timestamp = timestamp - shift
        for table in tables:
            if date_string_to_timestamp(table.attributes.get("modification_time")) < oldest_timestamp:
                self.client.remove(table)
