import typing
import logging
import yt.wrapper
from customer_service.ml.lib.data.knowledges import prepare_actual_knowledges, map_knowledges

# Configure the root logger at INFO level. This runs at import time, so
# importing this module is enough to apply the configuration.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


@yt.wrapper.yt_dataclass
class PreprocessedDataRow:
    """Typed row that ``PrepareData`` declares as the type of its input.

    Every field is declared optional.
    """

    # NOTE(review): the camelCase names presumably mirror the column names of
    # the source table — confirm against the table schema before renaming.
    ticketId: typing.Optional[int]
    messageTs: typing.Optional[int]
    # Passed as the first argument to ``map_knowledges`` by ``PrepareData``.
    knowledgeId: typing.Optional[str]
    knowledgeTitle: typing.Optional[str]
    text: typing.Optional[str]


@yt.wrapper.yt_dataclass
class PreparedDataRow:
    """Typed row that ``PrepareData`` declares as the type of its output.

    Carries the same five fields as ``PreprocessedDataRow`` plus two fields
    derived from the result of ``map_knowledges``.
    """

    # Fields copied unchanged from the input row.
    ticketId: typing.Optional[int]
    messageTs: typing.Optional[int]
    knowledgeId: typing.Optional[str]
    knowledgeTitle: typing.Optional[str]
    text: typing.Optional[str]
    # First value returned by ``map_knowledges``, converted with ``float()``.
    best_distance: typing.Optional[float]
    # Second value returned by ``map_knowledges``, joined with '/'.
    best_match: typing.Optional[str]

class PrepareData(yt.wrapper.TypedJob):
    """Typed YT job that adds the ``map_knowledges`` result to each row.

    For every input row the job calls ``map_knowledges`` with the row's
    ``knowledgeId`` and the knowledges given to the constructor, then emits
    one ``PreparedDataRow`` holding the original fields plus that result.
    """

    def __init__(self, actual_knowledges) -> None:
        """Remember the knowledges that every row is matched against.

        Args:
            actual_knowledges: Object forwarded as the second argument to
                ``map_knowledges`` for each processed row.
        """
        super().__init__()
        self.actual_knowledges = actual_knowledges

    def prepare_operation(self, context, preparer):
        """Declare the row types for input index 0 and output index 0."""
        with_input = preparer.input(0, type=PreprocessedDataRow)
        with_input.output(0, type=PreparedDataRow)

    def __call__(self, input_row):
        """Yield exactly one ``PreparedDataRow`` built from ``input_row``."""
        distance, match_parts = map_knowledges(
            input_row.knowledgeId,
            self.actual_knowledges,
        )
        # The match arrives as an iterable of strings and is stored as a
        # single '/'-separated string.
        distance_value = float(distance)
        match_path = '/'.join(match_parts)
        yield PreparedDataRow(
            ticketId=input_row.ticketId,
            messageTs=input_row.messageTs,
            knowledgeId=input_row.knowledgeId,
            knowledgeTitle=input_row.knowledgeTitle,
            text=input_row.text,
            best_distance=distance_value,
            best_match=match_path,
        )


def run_yt_job(
    cluster: str,
    input_data: str,
    output_data: str,
    product_tag: str,
    s3_bucket: str,
    s3_filepath: str,
):
    """Load the actual knowledges and run the ``PrepareData`` map operation.

    Args:
        cluster: Value passed as ``proxy`` when creating the YT client.
        input_data: Source table of the map operation.
        output_data: Destination table of the map operation.
        product_tag: Forwarded to ``prepare_actual_knowledges``.
        s3_bucket: Forwarded to ``prepare_actual_knowledges``.
        s3_filepath: Forwarded to ``prepare_actual_knowledges``.
    """
    # Knowledges are loaded once here and handed to the job instance.
    knowledges = prepare_actual_knowledges(s3_bucket, s3_filepath, product_tag)
    yt_client = yt.wrapper.YtClient(proxy=cluster)

    logger.info('Calculating targets')
    mapper = PrepareData(knowledges)
    yt_client.run_map(
        mapper,
        source_table=input_data,
        destination_table=output_data,
    )