import logging
import os

from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.profile.lib.socdem_helpers import socdem_config
from crypta.profile.services.train_socdem_models.lib.common.utils import replace_table
from crypta.profile.utils.utils import report_ml_metrics_to_solomon
from crypta.profile.utils.config import config

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# YQL template that builds the plain socdem training sample.
#
# The placeholders are filled with ``str.format`` in ``get()``:
#   inputs:  {socdem_labels_for_learning}, {daily_yandexuid2vec},
#            {yandexuid_cryptaid}, {merged_segments}, {merged_hits_by_yandexuid}
#   outputs: {socdem_plain_training_sample}, {general_population},
#            {socdem_training_sample_counts}
#
# The labels table is INNER JOINed with every other input on ``yandexuid``, so
# only yandexuids present in all five inputs end up in the sample.
# The joined rows are then written (each time WITH TRUNCATE) to three tables:
#   * the full sample;
#   * the "general population": sample rows whose ``income_segment`` IS NULL;
#   * a single-row table with the number of non-NULL ``gender``,
#     ``age_segment`` and ``income_segment`` labels, exposed as the columns
#     ``gender``, ``age`` and ``income``.
plain_sample_query = """
$socdem_plain_training_sample = (
    SELECT
        socdem_labels_for_learning.*,
        yandexuid2vec.vector AS vector,
        yandexuid_cryptaid.crypta_id AS crypta_id,
        merged_segments.heuristic_common AS heuristic_common,
        merged_segments.longterm_interests AS longterm_interests,
        merged_hits_by_yandexuid.raw_site_weights AS raw_site_weights
    FROM `{socdem_labels_for_learning}` AS socdem_labels_for_learning
    INNER JOIN `{daily_yandexuid2vec}` AS yandexuid2vec
    ON socdem_labels_for_learning.yandexuid == yandexuid2vec.yandexuid
    INNER JOIN `{yandexuid_cryptaid}` AS yandexuid_cryptaid
    ON socdem_labels_for_learning.yandexuid == yandexuid_cryptaid.yandexuid
    INNER JOIN `{merged_segments}` AS merged_segments
    ON socdem_labels_for_learning.yandexuid == merged_segments.yandexuid
    INNER JOIN `{merged_hits_by_yandexuid}` AS merged_hits_by_yandexuid
    ON socdem_labels_for_learning.yandexuid == merged_hits_by_yandexuid.yandexuid
);

INSERT INTO `{socdem_plain_training_sample}`
WITH TRUNCATE

SELECT *
FROM $socdem_plain_training_sample;

INSERT INTO `{general_population}`
WITH TRUNCATE

SELECT *
FROM $socdem_plain_training_sample
WHERE income_segment IS NULL;

INSERT INTO `{socdem_training_sample_counts}`
WITH TRUNCATE

SELECT
    COUNT_IF(gender is not Null) AS gender,
    COUNT_IF(age_segment is not Null) AS age,
    COUNT_IF(income_segment is not Null) AS income
FROM $socdem_plain_training_sample;
"""


def get(yt_client, yql_client, date, common_train_sample, general_population):
    """Build the plain socdem training sample and report its size.

    Runs ``plain_sample_query`` inside a ``NirvanaTransaction``, reads the
    resulting per-label counts from a temporary table, sends them as
    ``sample_size`` metrics via ``report_ml_metrics_to_solomon`` and finally
    stamps both output tables with a ``generate_date`` attribute.

    Args:
        yt_client: client used for the transaction, the temporary table,
            reading the counts row and setting table attributes.
        yql_client: client whose ``execute`` runs the formatted query.
        date: value stored in the ``generate_date`` attribute of both
            output tables.
        common_train_sample: path of the table that receives the full
            joined sample.
        general_population: path of the table that receives sample rows
            with a NULL ``income_segment``.
    """
    with NirvanaTransaction(yt_client) as nirvana_tx:
        with yt_client.TempTable() as counts_table:
            query_text = plain_sample_query.format(
                socdem_labels_for_learning=config.SOCDEM_LABELS_FOR_LEARNING_TABLE,
                daily_yandexuid2vec=config.DAILY_YANDEXUID2VEC,
                yandexuid_cryptaid=config.YANDEXUID_CRYPTAID_TABLE,
                merged_segments=config.SEGMENTS_STORAGE_BY_YANDEXUID_TABLE,
                merged_hits_by_yandexuid=config.YANDEXUID_METRICS_MERGED_HITS_TABLE,
                socdem_plain_training_sample=common_train_sample,
                socdem_training_sample_counts=counts_table,
                general_population=general_population,
            )
            yql_client.execute(
                query=query_text,
                transaction=str(nirvana_tx.transaction_id),
                title='YQL get plain socdem training sample',
            )

            # The counts query is an ungrouped aggregate, so the table holds
            # exactly one row; only that first row is consumed.
            counts_row = next(yt_client.read_table(counts_table))

            # One metric per socdem type; the type name is used as the column
            # key of the counts row.
            # NOTE(review): presumably SOCDEM_TYPES matches the query's
            # gender/age/income columns — confirm against socdem_config.
            sample_size_metrics = [
                {
                    'labels': {
                        'sample': 'common',
                        'socdem': socdem_type,
                        'vectors': 'web',
                        'metric': 'sample_size',
                    },
                    'value': int(counts_row[socdem_type]),
                }
                for socdem_type in socdem_config.SOCDEM_TYPES
            ]
            report_ml_metrics_to_solomon(
                service=socdem_config.SOLOMON_SERVICE,
                metrics_to_send=sample_size_metrics,
            )

            # Only the testing environment passes the sample through
            # replace_table; the helper's effect is defined elsewhere.
            is_testing = os.environ.get('CRYPTA_ENVIRONMENT') == 'testing'
            if is_testing:
                replace_table(yt_client, common_train_sample)

            for output_table in (common_train_sample, general_population):
                yt_client.set_attribute(output_table, 'generate_date', date)
