import logging
import os

from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.profile.services.train_socdem_models.lib.common.utils import replace_table
from crypta.profile.utils.config import config

# Module-level logger (not referenced by the code in this file as shown).
logger = logging.getLogger(__name__)

# YQL template that builds the plain mobile socdem training sample.
#
# Stages:
#   $sample -- inner-joins the socdem labels table with the in-device
#       yandexuid table and the yandexuid -> crypta_id table (both joins on
#       yandexuid), keeping only device ids whose id_type is 'idfa' or 'gaid'.
#   $unique_socdem_sample -- groups by (id, id_type) and keeps only groups in
#       which gender, age_segment and income_segment each have at most one
#       distinct aggregated value; emits that value (or NULL when the list is
#       empty) together with SOME(crypta_id).
#   $mobile_socdem_plain_training_sample -- inner-joins the result with the
#       daily devid vectors table and the app metrica table on (id, id_type),
#       adding vector, model, manufacturer, main_region_obl (converted from
#       Yson to Uint64) and categories.
#
# Outputs (both written WITH TRUNCATE):
#   {output_table}       -- every row of $mobile_socdem_plain_training_sample.
#   {general_population} -- only the rows whose income_segment IS NULL.
#   NOTE(review): confirm that restricting general_population to rows with a
#   NULL income_segment is the intended filter.
#
# The template is rendered with str.format, so any literal '{' or '}' added
# to the query text must be doubled ('{{' / '}}').
# The placeholder is spelled `yandexuid_crytaid` (sic); the keyword passed to
# .format() in get() must use the same spelling.
plain_sample_query = """
$sample = (
    SELECT
        indevice_yandexuid.id AS id,
        indevice_yandexuid.id_type AS id_type,
        labels.gender AS gender,
        labels.age_segment AS age_segment,
        labels.income_segment AS income_segment,
        yandexuid_cryptaid.crypta_id AS crypta_id
    FROM `{socdem_labels_table}` AS labels
    INNER JOIN `{indevice_yandexuid_table}` AS indevice_yandexuid
    ON labels.yandexuid == indevice_yandexuid.yandexuid
    INNER JOIN `{yandexuid_crytaid}` AS yandexuid_cryptaid
    ON labels.yandexuid == yandexuid_cryptaid.yandexuid
    WHERE indevice_yandexuid.id_type == 'idfa' OR indevice_yandexuid.id_type == 'gaid'
);

$unique_socdem_sample = (
    SELECT
        id,
        id_type,
        SOME(crypta_id) AS crypta_id,
        CASE
            WHEN ListLength(ListUniq(AGGREGATE_LIST(gender))) < 1
            THEN NULL
            ELSE ListUniq(AGGREGATE_LIST(gender))[0]
        END AS gender,
        CASE
            WHEN ListLength(ListUniq(AGGREGATE_LIST(age_segment))) < 1
            THEN NULL
            ELSE ListUniq(AGGREGATE_LIST(age_segment))[0]
        END AS age_segment,
        CASE
            WHEN ListLength(ListUniq(AGGREGATE_LIST(income_segment))) < 1
            THEN NULL
            ELSE ListUniq(AGGREGATE_LIST(income_segment))[0]
        END AS income_segment
    FROM $sample
    GROUP BY id, id_type
    HAVING ListLength(ListUniq(AGGREGATE_LIST(gender))) <= 1
        AND ListLength(ListUniq(AGGREGATE_LIST(age_segment))) <= 1
        AND ListLength(ListUniq(AGGREGATE_LIST(income_segment))) <= 1
);

$mobile_socdem_plain_training_sample = (
    SELECT
        labels.id AS id,
        labels.id_type AS id_type,
        labels.crypta_id AS crypta_id,
        labels.gender AS gender,
        labels.age_segment AS age_segment,
        labels.income_segment AS income_segment,
        devid_vectors.vector AS vector,
        app_metrica.model AS model,
        app_metrica.manufacturer AS manufacturer,
        Yson::ConvertToUint64(app_metrica.main_region_obl) AS main_region_obl,
        app_metrica.categories AS categories
    FROM $unique_socdem_sample AS labels
    INNER JOIN `{daily_devid_vectors_table}` AS devid_vectors
    ON labels.id == devid_vectors.id AND labels.id_type == devid_vectors.id_type
    INNER JOIN `{app_metrica_table}` as app_metrica
    ON labels.id == app_metrica.id AND labels.id_type == app_metrica.id_type
);

INSERT INTO `{output_table}`
WITH TRUNCATE

SELECT *
FROM $mobile_socdem_plain_training_sample;

INSERT INTO `{general_population}`
WITH TRUNCATE

SELECT *
FROM $mobile_socdem_plain_training_sample
WHERE income_segment IS NULL;
"""


def get(yt_client, yql_client, date, common_train_sample, general_population):
    """Build the plain mobile socdem training sample inside a Nirvana transaction.

    Renders ``plain_sample_query`` with the source tables taken from ``config``
    and the two destination paths, executes it through ``yql_client`` within
    the transaction, and then stamps both destination tables with a
    ``generate_date`` attribute.

    Args:
        yt_client: YT client used to open the NirvanaTransaction, passed to
            ``replace_table`` and used to set table attributes.
        yql_client: client whose ``execute`` method runs the rendered query.
        date: value stored as the ``generate_date`` attribute on both output
            tables (its exact type is not visible here -- presumably a date
            string; confirm against callers).
        common_train_sample: path of the table that receives every sample row.
        general_population: path of the table that receives the sample rows
            whose ``income_segment`` is NULL.
    """
    with NirvanaTransaction(yt_client) as txn:
        # The `yandexuid_crytaid` keyword (sic) must match the placeholder
        # spelled the same way inside plain_sample_query.
        rendered_query = plain_sample_query.format(
            socdem_labels_table=config.SOCDEM_LABELS_FOR_LEARNING_TABLE,
            indevice_yandexuid_table=config.INDEVICE_YANDEXUID,
            daily_devid_vectors_table=config.DAILY_DEVID2VEC,
            app_metrica_table=config.APP_BY_DEVID_MONTHLY_TABLE,
            yandexuid_crytaid=config.YANDEXUID_CRYPTAID_TABLE,
            output_table=common_train_sample,
            general_population=general_population,
        )
        yql_client.execute(
            query=rendered_query,
            transaction=str(txn.transaction_id),
            title='YQL get plain mobile socdem training sample',
        )

        # Only the common sample goes through replace_table, and only in the
        # testing environment.
        running_in_testing = os.environ.get('CRYPTA_ENVIRONMENT') == 'testing'
        if running_in_testing:
            replace_table(yt_client, common_train_sample)

        for target_table in (common_train_sample, general_population):
            yt_client.set_attribute(target_table, 'generate_date', date)
