import logging
import os

from ads.bsyeti.libs.primitives.counter_proto import counter_ids_pb2
from yabs.server.proto.keywords import keywords_data_pb2

from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.prism.lib.config import config
from crypta.prism.services.training.lib import utils


logger = logging.getLogger(__name__)

# YQL query template that builds the raw train sample table.
#
# The template is rendered with str.format() in get(): `{name}` fields are
# placeholders, while doubled braces (`{{` / `}}`) become literal braces of
# YQL lambdas in the final query text.
#
# Stages of the query, in order:
#   1. $bb_profiles: for every UniqID keep the ProfileDump with the greatest
#      TimeStamp (and that TimeStamp as update_time) among the sampled rows of
#      the tables in `{beh_hit_hour_log_dir}` from `{date}T00:00:00` to
#      `{date}T23:00:00`.
#   2. $with_targets: inner-join those profiles with
#      `{input_offline_prism_table}` by yandexuid, parse the profile with the
#      Bigb UDF and translate prism_segment into a target through
#      {prism_segment_target_mapping}; segments absent from the mapping get -1.
#      crypta_id falls back to yandexuid when it is missing.
#   3. $parsed: drop rows with target == -1; turn the first counter whose id is
#      {bm_cats_counter_id} into a serialized Yson dict ("bindings", keys cast
#      to String) and pick the items of the three requested keyword ids.
#   4. Final SELECT: flatten the three keyword lists, aggregate per crypta_id
#      (latest profile/bindings by update_time, distinct keyword values) and
#      write the first {train_sample_size} rows ordered by crypta_id into
#      `{output_table}`, truncating it first.
get_raw_train_sample_query = """
PRAGMA File('bigb.so', '{bigb_udf_url}');
PRAGMA udf('bigb.so');

$bm_cats_counter = ($dump) -> {{
    RETURN ListFilter(
        $dump.counters,
        ($counter) -> {{
            return $counter.counter_id == {bm_cats_counter_id}
        }}
    )
}};

$get_keyword_by_id = ($dump, $keyword_id) -> {{
    RETURN ListFilter(
        $dump.items,
        ($item) -> {{
            return $item.keyword_id == $keyword_id
        }}
    )
}};

$bb_profiles = (
    SELECT
        CAST(UniqID AS String) AS yandexuid,
        MAX_BY(ProfileDump, `TimeStamp`) AS profile,
        MAX(`TimeStamp`) AS update_time
    FROM RANGE(
        `{beh_hit_hour_log_dir}`,
        `{date}T00:00:00`,
        `{date}T23:00:00`,
    )
    TABLESAMPLE SYSTEM({bigb_sampling_rate})
    GROUP BY UniqID
);

$prism_segment_target_mapping = AsDict({prism_segment_target_mapping});

$with_targets = (
    SELECT
        profiles.update_time as update_time,
        profiles.profile AS profile,
        Bigb::ParseProfile(profiles.profile) as profile_parsed,
        COALESCE($prism_segment_target_mapping[prism.prism_segment], -1) AS target,
        COALESCE(CAST(prism.crypta_id AS Uint64), CAST(prism.yandexuid AS Uint64)) AS crypta_id
    FROM $bb_profiles as profiles
    INNER JOIN `{input_offline_prism_table}` as prism
    USING(yandexuid)
);


$parsed = (
    SELECT
        update_time,
        profile,
        Yson::Serialize(Yson::FromDoubleDict(COALESCE(ToDict(ListZip(
            ListMap($bm_cats_counter(profile_parsed)[0].key, ($key) -> {{ RETURN CAST($key AS String) }}),
            $bm_cats_counter(profile_parsed)[0].value)
        ), AsDict()))) AS bindings,
        COALESCE($get_keyword_by_id(profile_parsed, {operating_systems_keyword_id}), [NULL]) AS operating_systems_keyword,
        COALESCE($get_keyword_by_id(profile_parsed, {mobile_models_keyword_id}), [NULL]) AS mobile_models_keyword,
        COALESCE($get_keyword_by_id(profile_parsed, {region_keyword_id}), [NULL]) AS region_keyword,
        crypta_id,
        target
    FROM $with_targets
    WHERE target != -1
);

INSERT INTO `{output_table}`
WITH TRUNCATE

SELECT
    crypta_id,
    SOME(target) AS target,
    MAX_BY(profile, update_time) AS profile,
    MAX_BY(bindings, update_time) AS bindings,
    AGGREGATE_LIST_DISTINCT(ListHead(operating_systems_keyword.uint_values)) AS operating_systems,
    AGGREGATE_LIST_DISTINCT(mobile_models_keyword.string_value) AS mobile_models,
    AGGREGATE_LIST_DISTINCT(ListHead(region_keyword.uint_values)) AS regions,
FROM $parsed
FLATTEN LIST BY (operating_systems_keyword, mobile_models_keyword, region_keyword)
GROUP BY crypta_id
ORDER BY crypta_id
LIMIT {train_sample_size};
"""


def get_last_offline_prism_date(yt_client):
    """Return the greatest node name under the offline prism user weights directory.

    The directory ``config.PRISM_OFFLINE_USER_WEIGHTS_DIR`` is listed through
    ``yt_client.list`` and the maximum of the returned names is taken; the
    caller uses that name as a date.

    :param yt_client: client object exposing ``list(path)``.
    :return: the maximal entry of the directory listing.
    :raises ValueError: if the listing is empty (raised by ``max``).
    """
    available_dates = yt_client.list(config.PRISM_OFFLINE_USER_WEIGHTS_DIR)
    latest_date = max(available_dates)
    return latest_date


def get(yt_client, yql_client):
    """Build the raw train sample table by running the YQL query template.

    Inside a ``NirvanaTransaction`` opened on ``yt_client`` the latest offline
    prism date is looked up, the query template is filled with values from
    ``config``, the counter/keyword enum ids and the segment-to-target mapping,
    and the resulting query is executed through ``yql_client`` within that
    transaction.

    :param yt_client: YT client used for the transaction and directory listing.
    :param yql_client: client exposing ``execute(query=..., transaction=..., title=...)``.
    """
    with NirvanaTransaction(yt_client) as transaction:
        prism_date = get_last_offline_prism_date(yt_client)
        prism_table = os.path.join(config.PRISM_OFFLINE_USER_WEIGHTS_DIR, prism_date)

        # Rendered as "(key, value), (key, value), ..." so that it can be
        # dropped straight into the AsDict(...) call of the query template.
        segment_target_pairs = ', '.join(
            repr(pair) for pair in utils.PRISM_SEGMENT_TARGET_MAPPING.items()
        )

        query_params = {
            'bigb_udf_url': config.BIGB_UDF_URL,
            'bigb_sampling_rate': config.BIGB_SAMPLING_RATE,
            'bm_cats_counter_id': counter_ids_pb2.ECounterId.CI_QUERY_CATEGORIES_INTEREST,
            'operating_systems_keyword_id': keywords_data_pb2.EKeyword.KW_DETAILED_DEVICE_TYPE_BT,
            'mobile_models_keyword_id': keywords_data_pb2.EKeyword.KW_DEVICE_MODEL_BT,
            'region_keyword_id': keywords_data_pb2.EKeyword.KW_USER_REGION,
            'train_sample_size': config.TRAIN_SAMPLE_SIZE,
            'prism_segment_target_mapping': segment_target_pairs,
            'date': prism_date,
            'beh_hit_hour_log_dir': config.BEH_HIT_HOUR_LOG_DIR,
            'input_offline_prism_table': prism_table,
            'output_table': config.RAW_TRAIN_SAMPLE_TABLE,
        }
        rendered_query = get_raw_train_sample_query.format(**query_params)

        yql_client.execute(
            query=rendered_query,
            transaction=str(transaction.transaction_id),
            title='YQL realtime prism get raw train sample',
        )
