import logging

from crypta.lib.python import native_yt
from crypta.lib.python.yt import (
    schema_utils,
    yt_helpers,
)
from crypta.siberia.bin.common.yt_describer.proto.group_stats_pb2 import TGroupStats
from crypta.siberia.bin.common.yt_describer.proto.sampler_state_pb2 import TSamplerState
from crypta.siberia.bin.common.yt_describer.py import native_operations


# Key column sets used as sort / reduce keys by the operations in describe().
BY_CRYPTA_ID = [
    "CryptaId",
]
BY_GROUP_ID = [
    "GroupID",
]
BY_ID = [
    "IdValue",
    "IdType",
]


def describe(yt_client, tx, config):
    """Run the describer pipeline and write the result to ``config.OutputTable``.

    The steps run strictly in this order, each writing into a temporary table
    created under ``config.TmpDir``:

    1. ``TSampler`` map-reduce over ``config.InputTable``, keyed by ``GroupID``.
    2. Sort by (``IdValue``, ``IdType``), then ``TMatcher`` reduce together with
       ``config.IdToCryptaIdTable``.
    3. Sort by ``CryptaId``, then ``TStatsGetter`` reduce together with
       ``config.CryptaIdUserDataTable``.
    4. ``TStatsMerger`` map-reduce, keyed by ``GroupID``.
    5. ``config.OutputTable`` is removed, then recreated by sorting the merged
       table by ``GroupID`` with the schema from ``get_output_schema()``.

    Args:
        yt_client: YT client; its ``config`` is read for ``proxy.url``, ``pool``
            and ``token``, and it is used for temp tables, sorts and removal.
        tx: Object whose ``transaction_id`` is forwarded to every native operation.
        config: Object providing ``TmpDir``, ``InputTable``, ``IdToCryptaIdTable``,
            ``CryptaIdUserDataTable``, ``MaxSampleSize`` and ``OutputTable``.
    """
    log = logging.getLogger(__name__)

    # Connection settings forwarded unchanged to every native operation.
    native_kwargs = dict(
        proxy=str(yt_client.config["proxy"]["url"]),
        pool=str(yt_client.config["pool"]),
        token=str(yt_client.config["token"]),
        transaction=str(tx.transaction_id),
    )

    def new_tmp_table():
        return yt_client.TempTable(path=config.TmpDir)

    with new_tmp_table() as sampler_out, \
            new_tmp_table() as matcher_out, \
            new_tmp_table() as stats_getter_out, \
            new_tmp_table() as stats_merger_out:
        # The TStatsGetter output table is stored with the zstd_3 codec.
        yt_helpers.set_attribute(stats_getter_out, "compression_codec", "zstd_3", client=yt_client)

        # Input tables are read with their columns renamed to the names used
        # as reduce keys below.
        id_to_crypta_id = yt_client.TablePath(
            config.IdToCryptaIdTable,
            rename_columns={"id": "IdValue", "id_type": "IdType", "crypta_id": "CryptaId"},
        )
        crypta_id_user_data = yt_client.TablePath(
            config.CryptaIdUserDataTable,
            rename_columns={"crypta_id": "CryptaId", "stats": "Stats"},
        )

        log.info("Run TSampler...")
        sampler_state = TSamplerState(MaxSampleSize=config.MaxSampleSize)
        native_yt.run_native_map_reduce(
            mapper_name=native_operations.TEmptyMapper,
            reducer_name=native_operations.TSampler,
            source=[config.InputTable],
            destination=[sampler_out],
            reduce_by=BY_GROUP_ID,
            sort_by=BY_GROUP_ID,
            reducer_state=sampler_state.SerializeToString(),
            **native_kwargs
        )

        # The sampled table is sorted in place so it can be reduced by id.
        log.info("Run sort...")
        yt_client.run_sort(sampler_out, sampler_out, BY_ID)

        log.info("Run TMatcher...")
        matcher_sources = [sampler_out, id_to_crypta_id]
        native_yt.run_native_reduce(
            reducer_name=native_operations.TMatcher,
            source=matcher_sources,
            destination=[matcher_out],
            reduce_by=BY_ID,
            **native_kwargs
        )

        # Same in-place sort, this time by CryptaId for the next reduce.
        log.info("Run sort...")
        yt_client.run_sort(matcher_out, matcher_out, BY_CRYPTA_ID)

        log.info("Run TStatsGetter...")
        stats_getter_sources = [matcher_out, crypta_id_user_data]
        native_yt.run_native_reduce(
            reducer_name=native_operations.TStatsGetter,
            source=stats_getter_sources,
            destination=[stats_getter_out],
            reduce_by=BY_CRYPTA_ID,
            **native_kwargs
        )

        log.info("Run TStatsMerger...")
        native_yt.run_native_map_reduce(
            mapper_name=native_operations.TEmptyMapper,
            reducer_name=native_operations.TStatsMerger,
            source=[stats_getter_out],
            destination=[stats_merger_out],
            reduce_by=BY_GROUP_ID,
            sort_by=BY_GROUP_ID,
            **native_kwargs
        )

        # The previous output is dropped first; the final sort then writes to a
        # path that carries the output schema as an attribute.
        log.info("Remove old output table...")
        yt_client.remove(config.OutputTable, force=True)

        log.info("Run sort...")
        output_table = yt_client.TablePath(config.OutputTable, attributes={"schema": get_output_schema()})
        yt_client.run_sort(stats_merger_out, output_table, BY_GROUP_ID)


def get_output_schema():
    """Return the output table schema derived from ``TGroupStats``, keyed by ``GroupID``."""
    return schema_utils.get_schema_from_proto(
        TGroupStats,
        key_columns=BY_GROUP_ID,
    )
