import copy
import logging
from collections import OrderedDict
from typing import NamedTuple, Optional

from yt.wrapper.schema.table_schema import TableSchema
from yt.wrapper.schema.types import ti

from travel.hotels.content_manager.lib.persistence_manager import Node, YtPersistenceManager, TableRow


class TableInfo(NamedTuple):
    """Rows gathered for one merged table, paired with the schema used to write them."""

    data: list[TableRow]  # table rows to be written
    schema: TableSchema  # schema handed to the persistence manager on write


class HistoryMerger:
    """Compact per-job history tables of each stage into a ``merged`` directory.

    Paths used by the methods below::

        <history_path>/<stage>/<job>/{input,output}/<table>
        <history_path>/<stage>/merged/{input,output}/<table>

    When a stage has at least ``__job_to_merge_count__ + __job_to_keep_count__``
    job directories, the ``__job_to_merge_count__`` oldest ones (by
    ``created_at``) are appended to the merged tables, each row tagged with
    ``created_at``/``job_id`` columns, and the merged job directories are deleted.
    """

    # Name of the per-stage directory holding already-merged tables.
    __merged_dir__ = 'merged'
    # Number of oldest jobs folded into the merged tables per run.
    __job_to_merge_count__ = 10
    # Minimum number of jobs that remain untouched after a merge.
    __job_to_keep_count__ = 5

    def __init__(self, persistence_manager: YtPersistenceManager):
        self.persistence_manager = persistence_manager

    def do_merge(self, history_path: str) -> None:
        """Run :meth:`merge_stage` for every node listed under ``history_path``."""
        for stage_node in self.persistence_manager.list(history_path):
            self.merge_stage(stage_node)

    def merge_stage(self, stage_node: Node) -> None:
        """Merge the oldest jobs of one stage into its ``merged`` directory.

        Job directories are the ``map_node`` children of the stage other than
        ``__merged_dir__``. Nothing happens unless there are enough of them to
        merge ``__job_to_merge_count__`` while keeping ``__job_to_keep_count__``.
        """
        logging.info(f'Merging {stage_node}')

        merged_node = None
        jobs = []
        for job_node in self.persistence_manager.list(stage_node.path):
            if job_node.name == self.__merged_dir__:
                merged_node = job_node
            elif job_node.type == 'map_node':
                jobs.append(job_node)

        if len(jobs) < self.__job_to_merge_count__ + self.__job_to_keep_count__:
            logging.info('Not enough jobs to merge')
            return

        # Oldest jobs first: these are the ones folded into the merged tables.
        jobs.sort(key=lambda x: x.created_at)
        jobs = jobs[:self.__job_to_merge_count__]

        with self.persistence_manager.transaction():
            self.merge_artifact(stage_node.path, jobs, merged_node, 'input')
            self.merge_artifact(stage_node.path, jobs, merged_node, 'output')

        # NOTE(review): job directories are removed outside the transaction, and only
        # if the block above did not raise. If a delete fails midway, the surviving
        # jobs would be merged again on the next run (duplicate rows) — confirm this
        # trade-off is intended.
        for job in jobs:
            logging.info(f'Deleting {job.path}')
            self.persistence_manager.delete(job.path)

    def merge_artifact(self, stage_path: str, jobs: list[Node], merged_node: Optional[Node], artifact: str) -> None:
        """Rebuild the merged tables of one artifact from the old merged data plus ``jobs``.

        Rows already in ``merged_node`` are kept as read; rows from ``jobs`` get
        ``created_at``/``job_id`` filled in. Each table is first written to
        ``<table>_temp``, then copied over ``merged/<artifact>/<table>`` (replacing
        any existing table), and the temp table is deleted.
        """
        logging.info(f'Merging {artifact}')

        schemas = self.get_artifact_schemas(jobs, merged_node, artifact)
        logging.info(f'{schemas=}')

        table_data = dict()
        if merged_node:
            self.fill_job_data(merged_node, artifact, schemas, table_data, fill_meta=False)
        for job in jobs:
            self.fill_job_data(job, artifact, schemas, table_data, fill_meta=True)

        for table_name, table_info in table_data.items():
            logging.info(f'Merging {table_name}')
            merged_path = self.persistence_manager.join(stage_path, self.__merged_dir__, artifact, table_name)
            temp_path = self.persistence_manager.join(stage_path, self.__merged_dir__, artifact, table_name + '_temp')
            self.persistence_manager.write(temp_path, table_info.data, table_info.schema)

            if self.persistence_manager.exists(merged_path):
                self.persistence_manager.delete(merged_path)
            self.persistence_manager.copy(temp_path, merged_path)
            self.persistence_manager.delete(temp_path)

    def get_artifact_schemas(
        self,
        jobs: list[Node],
        merged_node: Optional[Node],
        artifact: str,
    ) -> dict[str, TableSchema]:
        """Return, per table name, the union schema across ``jobs`` and ``merged_node``.

        Schemas are read from the ``schema`` attribute of every table under
        ``<node>/<artifact>``, combined with :meth:`get_merged_schema`, and
        finally extended with ``created_at`` (Uint64) and ``job_id`` (String).
        All resulting columns are optional and the schemas are non-strict.
        """
        # Jobs first, merged node last: later schemas override column types.
        sources = jobs + [merged_node] if merged_node else jobs

        schemas = {}
        for source in sources:
            tables_path = self.persistence_manager.join(source.path, artifact)
            for table_node in self.persistence_manager.list(tables_path):
                if table_node.type != 'table':
                    continue
                raw_schema = self.persistence_manager.yt_client.get_attribute(table_node.path, 'schema')
                table_schema = TableSchema.from_yson_type(raw_schema)
                known_schema = schemas.get(table_node.name)
                schemas[table_node.name] = (
                    table_schema if known_schema is None
                    else self.get_merged_schema(known_schema, table_schema)
                )

        schema_extension = TableSchema()
        schema_extension.add_column('created_at', ti.Uint64)
        schema_extension.add_column('job_id', ti.String)

        return {
            table_name: self.get_merged_schema(schema, schema_extension)
            for table_name, schema in schemas.items()
        }

    @staticmethod
    def get_merged_schema(left: TableSchema, right: TableSchema) -> TableSchema:
        """Return a new non-strict schema with the union of columns of ``left`` and ``right``.

        Column order follows ``left``; a column present in both takes its
        definition from ``right`` but keeps its position in ``left``; columns
        only in ``right`` are appended. Every column type is wrapped in
        ``Optional`` (if not already) and ``unique_keys`` is taken from ``left``.
        The input schemas are not modified.
        """
        columns = OrderedDict()
        for column in left.columns:
            columns[column.name] = column
        for column in right.columns:
            columns[column.name] = column

        merged_schema = TableSchema()
        merged_schema.strict = False
        # NOTE(review): sort_order on key columns is carried over although rows of
        # several jobs are concatenated unsorted — confirm the writer tolerates this.
        merged_schema.unique_keys = left.unique_keys
        for column in columns.values():
            # Copy first: widening the type in place would mutate the caller's schemas.
            column = copy.copy(column)
            if column.type.name != 'Optional':
                column.type = ti.Optional[column.type]
            merged_schema.add_column(column)
        return merged_schema

    def fill_job_data(
        self, job: Node,
        artifact: str,
        schemas: dict[str, TableSchema],
        table_data: dict[str, TableInfo],
        fill_meta: bool,
    ) -> None:
        """Append the rows of every table under ``<job>/<artifact>`` to ``table_data``.

        ``table_data`` is keyed by table name; a missing entry is created with
        the schema from ``schemas``. When ``fill_meta`` is true each row gets
        ``created_at`` (``job.created_at`` as integer seconds) and ``job_id``
        (``job.name``) set before it is stored.
        """
        if fill_meta:
            # Same for every row of this job, so compute once.
            created_at = int(job.created_at.timestamp())
            job_id = job.name

        path = self.persistence_manager.join(job.path, artifact)
        for table_node in self.persistence_manager.list(path):
            if table_node.type != 'table':
                continue
            logging.info(f'Reading {table_node.path}')
            table_info = table_data.setdefault(
                table_node.name, TableInfo([], schemas[table_node.name]),
            )
            for row in self.persistence_manager.read(table_node.path):
                if fill_meta:
                    row['created_at'] = created_at
                    row['job_id'] = job_id
                table_info.data.append(row)
