-- Default cluster for every table path and operation in this script.
use hahn;

-- Pool trees for the first (tokenization) stage: `physical` is the primary tree and
-- `cloud` is listed as a tentative one. yt.PoolTrees is set again further down, after
-- `commit`, for the GPU stage.
pragma yt.PoolTrees = 'physical';
pragma yt.TentativePoolTrees = 'cloud';

-- Script parameters supplied by the caller.
-- Table the raw queries are read from.
declare $input_path as String;
-- Table that receives the final model predictions (overwritten, see `with truncate`).
declare $output_path as String;
-- Table with a `query` column; queries found there are excluded from processing.
declare $query_to_cluster_path as String;
-- Optional YT pool for the first stage; a fallback of "default" is applied below.
declare $yt_pool as String?;
-- Column names handed to $get_query / $get_region when extracting values from a row.
declare $query_column as String;
declare $region_column as String;

-- Use the caller-provided pool, or "default" when the parameter was left empty.
$yt_pool = coalesce($yt_pool, "default");
pragma yt.Pool = $yt_pool;

-- Attach the BERT models UDF library from YT under a local alias and register it as a UDF
-- module; the script later calls functions from the BertModels:: and SearchRequest:: namespaces.
pragma File("libbert_models_udf.so", "yt://hahn/home/ranking/prod_build_artifacts_storage/latest_libbert_models_udf.so");
pragma udf("libbert_models_udf.so");
-- Attach the model file (fetched by URL) that both BertModels::LoadModel calls open via FilePath.
pragma File("split.bert.byparts.htxt", "https://proxy.sandbox.yandex-team.ru/1660373591");

pragma library("common.sql");
pragma library("define_cluster_lib.sql");

import define_cluster_lib symbols $get_query, $get_region, $normalize_dopp;

-- (query, region) pairs extracted from the input table.
-- The query/region values are produced by the imported $get_query / $get_region helpers
-- from the whole row and the configured column names.
--
-- The extraction is done in an inner subquery so that the NULL filter is applied to the
-- *extracted* `query` value. WHERE is evaluated before the SELECT projection, so a filter
-- placed next to the projection would look at an input column literally named `query`
-- (and would fail or test the wrong value when $query_column points elsewhere).
$raw_queries = (
    select
        query,
        region
    from (
        select
            $get_query(TableRow(), $query_column) as query,
            $get_region(TableRow(), $region_column) as region
        from $input_path
    )
    where query is not null
);

-- Queries from $raw_queries that have no row with the same `query` in
-- $query_to_cluster_path. `left only join` keeps just the unmatched left rows;
-- `any` on the left side reduces the left input to one row per join key.
$new_queries = (
    select
        fresh.*
    from any $raw_queries as fresh
    left only join $query_to_cluster_path as known
        on fresh.query = known.query
);

-- New queries enriched with the Ext_Qfuf_Top10 value looked up by normalized query.
-- The lookup key is $normalize_dopp(query) matched against Qdopp; rows without a match
-- keep a NULL _Ext_Qfuf_Top10 (left join), and `any` on the right side takes at most one
-- lookup row per key so the left rows are not multiplied.
$new_queries_with_qfuf_top_10 = (
    select
        nq.*,
        top10.Ext_Qfuf_Top10 as _Ext_Qfuf_Top10
    from $new_queries as nq
    left join any `//home/searchshare/common/source_queries_exp_top10` as top10
        on top10.Qdopp = $normalize_dopp(nq.query)
);

-- Casts the argument to Utf8, passes it through SearchRequest::NormalizeBert and
-- substitutes an empty string when the result is NULL.
$BNorm = ($text) -> {
    $normalized = SearchRequest::NormalizeBert(cast($text as Utf8));
    return coalesce($normalized, "");
};

-- Converts a region value with BertModels::RegionToGeoNamesRus_Beta and applies the same
-- normalization as $BNorm to the result (empty string when nothing comes back).
$RegionNamesRus_BetaBertNorm = ($region) -> {
    $geo_names = BertModels::RegionToGeoNamesRus_Beta($region);
    return $BNorm($geo_names);
};

-- Model handle for the first stage: the attached model file opened with the
-- "BegemotPartTokenization" layout and GPU usage switched off ("UseGPU": false).
-- Batch size and thread count are passed through the JSON config as-is.
$model_begemot_tokenization = BertModels::LoadModel(
    FilePath("split.bert.byparts.htxt"),
    '{"UseGPU": false, "OpenLayout": "BegemotPartTokenization", "MaxBatchSize": 512, "NumThreads": 32}'
);

-- Stage 1: run the tokenization model over the new queries and store the result in the
-- temporary table @tokenization.
--
-- Each input row carries:
--   Features    - list of (name, value) pairs fed to the model: normalized query,
--                 normalized region names and Ext_Qfuf_Top10 (empty string when NULL);
--   PassThrough - the original values, carried alongside so they can be written to the
--                 final output next to the prediction.
-- The model output (`Result`) is re-wrapped as a single named feature for stage 2.
insert into @tokenization
with truncate
select
    PassThrough,
    [AsTuple('begemot_tokenization', Result)] as Features
from (
    process (
        select
            [
                ("QueryBertNormed", $BNorm(query)),
                ("BaseRegionNamesRus_BetaBertNormed", $RegionNamesRus_BetaBertNorm(region)),
                ("Ext_Qfuf_Top10", coalesce(_Ext_Qfuf_Top10, ""))
            ] as Features,
            AsStruct(
                _Ext_Qfuf_Top10 as Ext_Qfuf_Top10,
                region as BaseRegion,
                query as QueryRaw
            ) as PassThrough
        from $new_queries_with_qfuf_top_10
    )
    using BertModels::ApplyForPortableDataStream(
        $model_begemot_tokenization,
        TableRows()
    )
);

-- Statement boundary between the two stages: the tokenization insert above is committed
-- here, and the pragmas below are set only for what follows.
commit;

-- Operation settings for the second (GPU) stage.
-- Allow up to 50 failed jobs, size jobs at 1280M of input each, and give each job 32G of memory.
pragma yt.DefaultMaxJobFails = "50";
pragma yt.DataSizePerJob="1280M";
pragma yt.DefaultMemoryLimit="32G";
-- Auto-merge of output chunks and the query cache are both switched off for this stage.
pragma yt.AutoMerge = "disabled";
pragma yt.QueryCacheMode = "disable";
-- From here on the pool and pool tree chosen at the top of the script are replaced
-- with the GPU ones.
pragma yt.Pool = "research_gpu";
pragma yt.PoolTrees = "gpu_geforce_1080ti";
-- Mapper spec: one GPU per job, GPU layers enabled, explicit porto layers for CUDA, driver and base image.
-- NOTE(review): cuda_toolkit_version is "9.1" while the CUDA layer path is .../cuda/11.3 —
-- looks inconsistent; confirm which version the UDF build actually needs.
pragma yt.OperationSpec = '{mapper = {gpu_limit = 1; enable_setup_commands = true; enable_gpu_layers = true; cuda_toolkit_version = "9.1"; layer_paths = [
            "//porto_layers/delta/gpu/cuda/11.3";
            "//porto_layers/delta/gpu/driver/450.119.04";
            "//porto_layers/cached_in_tmpfs/bionic_base"
        ]; };}';

-- Model handle for the second stage: the same attached model file, opened with the
-- "BegemotPartWithoutTokenization" layout and GPU usage switched on ("UseGPU": true).
-- It is applied to the rows of @tokenization in the final insert below.
$model_begemot_calculation = BertModels::LoadModel(
    FilePath("split.bert.byparts.htxt"),
    '{"UseGPU": true, "OpenLayout": "BegemotPartWithoutTokenization", "MaxBatchSize": 1024, "NumThreads": 32}'
);

-- Stage 2: apply the GPU model to the tokenized rows stored in @tokenization.
$bert_prediction_rows = (
    process @tokenization
    using BertModels::ApplyOnPortableForPortableDataStream(
        $model_begemot_calculation,
        TableRows(),
        '{"BatchSize":1024}'
    )
);

-- Write the final table: the pass-through fields are unpacked back into plain columns
-- and the model output is stored as model_1660373591_query_predict.
-- The target table is overwritten.
insert into $output_path
with truncate
select
    PassThrough.Ext_Qfuf_Top10 as Ext_Qfuf_Top10,
    PassThrough.BaseRegion as BaseRegion,
    PassThrough.QueryRaw as QueryRaw,
    Result as model_1660373591_query_predict
from $bert_prediction_rows;
