from ads.pytorch.lib.online_learning.preprocess import (
    format_ads_pytorch_input_table,
    map_date_field_logos,
)
from ads.pytorch.lib.pool.py_formatter.formatter import SkiffTensorSerializer
from ads.pytorch.lib.pool.py_formatter.inference import (
    infer_converters,
    ITensorConverter,
)

from crypta.poor_profiles.logos_tasks.features import (
    CAT_FEATURES,
    RV_FEATURES,
)


def prepare(
    yt_client,
    yql_client,
    is_logos_test,
    poor_profiles_st_table,
    poor_profiles_st_torch_table,
):
    """Convert the poor-profiles source table into an ads-pytorch input table.

    The work goes through a temporary YT table in two stages:

    1. ``map_date_field_logos`` reads only the ``CAT_FEATURES + RV_FEATURES``
       columns of ``poor_profiles_st_table`` and writes the result into the
       temporary table. It uses ``ShowTime`` as the time field and
       ``HitLogID`` as the pivot column, with shuffling disabled.
    2. Tensor converters are inferred from the temporary table. Then
       ``format_ads_pytorch_input_table`` writes that table into
       ``poor_profiles_st_torch_table`` using ``SkiffTensorSerializer`` and
       the C++ backend.

    Args:
        yt_client: YT client. Provides ``TempTable``, ``TablePath`` and the
            ``config["pool"]`` used for the formatting operation.
        yql_client: YQL client passed to ``map_date_field_logos``.
        is_logos_test: If truthy, format with a batch size of 32 instead of
            2048.
        poor_profiles_st_table: Path of the source table.
        poor_profiles_st_torch_table: Path of the formatted output table.

    Raises:
        TypeError: If ``infer_converters`` yields a converter that is not an
            ``ITensorConverter`` (i.e. a legacy Python converter).
    """
    all_columns = CAT_FEATURES + RV_FEATURES
    # Test runs use much smaller batches than production runs.
    batch_size = 32 if is_logos_test else 2048

    # NOTE(review): the intermediate table is presumably removed when the
    # `with` block exits — confirm against the YT client's TempTable semantics.
    with yt_client.TempTable() as torch_intermediate_table:

        map_date_field_logos(
            yql_client=yql_client,
            # Project only the feature columns out of the source table.
            src_table=yt_client.TablePath(poor_profiles_st_table, columns=all_columns),
            select_columns=all_columns,
            result_table=torch_intermediate_table,
            time_field="ShowTime",
            # NOTE(review): presumably seconds (i.e. 5-minute buckets) — confirm.
            granularity=300,
            pivot_column="HitLogID",
            shuffle=False,
            only_range=None,
        )

        converters = infer_converters(
            table=torch_intermediate_table,
            categorical_factors=CAT_FEATURES,
            realvalue_factors=RV_FEATURES,
            unravelled_categorical_factors=[],
            yt_client=yt_client,
        )
        # Use an explicit check rather than `assert`. Asserts are stripped
        # under `python -O`, which would let legacy converters through silently.
        for conv in converters:
            if not isinstance(conv, ITensorConverter):
                raise TypeError(
                    "Don't use legacy py converters (got {})".format(type(conv).__name__)
                )

        format_ads_pytorch_input_table(
            yt_client=yt_client,
            yt_pool=yt_client.config["pool"],
            src_table=torch_intermediate_table,
            result_table=poor_profiles_st_torch_table,
            converters=converters,
            batch_size=batch_size,
            pivot_column=None,
            serializer=SkiffTensorSerializer(),
            cpp_backend=True,
            # NOTE(review): unit not visible here (GB?) — confirm in the formatter.
            memory_limit=4,
        )
