#!/usr/bin/env python
# -*- coding: utf-8 -*-

import json
import logging

from crypta.lib.python import templater
from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.lookalike.lib.python.utils import utils
from crypta.lookalike.lib.python.utils.config import config


# Module-level logger.
logger = logging.getLogger(__name__)

# YQL query template (Jinja-style `{{ ... }}` / `{% if %}` syntax, rendered by
# `templater.render_template` in `make`) that builds the lookalike train sample
# and, optionally, a validation sample.
#
# Template variables:
#   * `*_table` placeholders -- YT table paths, filled from `config` in `make`;
#   * `make_validation_sample` -- when true, rows whose segment has
#     `is_validation` set go to `validation_sample_table` and are excluded from
#     `train_sample_table`; when false, every row goes to the train table and
#     the validation table is not written.
#
# Query stages:
#   $segments_counts    -- per segment (GroupID): `ids_cnt` multiplied by 2.0
#                          (presumably because each segment contributes both
#                          positives and negatives -- TODO confirm) and the
#                          `is_validation` flag.
#   $users_train_sample -- rows of the positives table labelled `target = 1`,
#                          UNION ALL negatives joined to the user DSSM features
#                          on `yandexuid`, labelled `target = 0`.
#   $train_sample       -- adds the segment DSSM features, `is_validation`, a
#                          weight 1 / sqrt(ids_cnt) (using the doubled count)
#                          and a random `shuffling_number`. Both joins are
#                          INNER, so rows whose GroupID is missing from either
#                          segment table are dropped.
#
# Output tables are overwritten (WITH TRUNCATE) and ordered by
# `shuffling_number`; because of `SELECT *`, `shuffling_number` and
# `is_validation` are kept as columns in the written tables.
#
# NOTE(review): RANDOM() receives only `yandexuid`, and one user can appear in
# several rows (different GroupIDs, positive and negative). If YQL returns equal
# values for equal arguments, those rows will sort next to each other instead of
# being shuffled apart -- verify against YQL `Random` semantics.
make_train_sample_query_template = """
$segments_counts = (
    SELECT
        GroupID,
        2.0 * ids_cnt AS ids_cnt,
        is_validation,
    FROM `{{segments_with_counts_table}}`
);

$users_train_sample = (
    SELECT
        positives.*,
        1 AS target,
    FROM `{{positives_with_dssm_features_table}}` AS positives
UNION ALL
    SELECT
        user_dssm_features.*,
        negatives.GroupID AS GroupID,
        0 AS target,
    FROM `{{negatives_table}}` AS negatives
    INNER JOIN `{{user_dssm_features_table}}` AS user_dssm_features
    USING(yandexuid)
);

$train_sample = (
    SELECT
        users_train_sample.*,
        RANDOM(users_train_sample.yandexuid) AS shuffling_number,
        1.0 / Math::Sqrt(segments_counts.ids_cnt) AS weight,
        segments_dssm_features.segment_affinitive_sites_ids AS segment_affinitive_sites_ids,
        segments_dssm_features.segment_affinitive_apps AS segment_affinitive_apps,
        segments_dssm_features.segment_float_features AS segment_float_features,
        segments_counts.is_validation AS is_validation,
    FROM $users_train_sample AS users_train_sample
    INNER JOIN $segments_counts AS segments_counts
    ON users_train_sample.GroupID == segments_counts.GroupID
    INNER JOIN `{{segments_dssm_features_table}}` AS segments_dssm_features
    ON users_train_sample.GroupID == segments_dssm_features.GroupID
);

INSERT INTO `{{train_sample_table}}`
WITH TRUNCATE

SELECT *
FROM $train_sample
{% if make_validation_sample %}
WHERE NOT is_validation
{% endif %}
ORDER BY shuffling_number;

{% if make_validation_sample %}
INSERT INTO `{{validation_sample_table}}`
WITH TRUNCATE

SELECT *
FROM $train_sample
WHERE is_validation
ORDER BY shuffling_number;
{% endif %}
"""


def make(nv_params, inputs):
    """Render and run the YQL query that builds the LaL train sample.

    The JSON file whose path is returned by ``inputs.get('resource_info')``
    must contain a ``need_full_retrain`` key. When that flag is falsy, the
    rendered query additionally fills ``config.VALIDATION_SAMPLE_TABLE`` and
    leaves validation segments out of ``config.TRAIN_SAMPLE_TABLE``; when it is
    truthy, every row goes to the train table.

    The query is rendered and executed while a ``NirvanaTransaction`` on the YT
    client is open, and that transaction's id is passed to the YQL call.

    :param nv_params: Nirvana parameters used to build both the YT and the YQL
        client.
    :param inputs: Nirvana job inputs; ``inputs.get('resource_info')`` must yield
        the path of the resource-info JSON file.
    """
    yt_client = utils.get_yt_client(nv_params=nv_params)
    yql_client = utils.get_yql_client(nv_params=nv_params)

    resource_info_path = inputs.get('resource_info')
    with open(resource_info_path, 'r') as resource_info_file:
        resource_info = json.load(resource_info_file)
        need_full_retrain = resource_info['need_full_retrain']

    with NirvanaTransaction(yt_client) as transaction:
        # Values substituted into the template placeholders of the same names.
        template_vars = dict(
            segments_with_counts_table=config.TRAIN_SEGMENTS_WITH_COUNTS_TABLE,
            positives_with_dssm_features_table=config.POSITIVES_WITH_DSSM_FEATURES_TABLE,
            negatives_table=config.NEGATIVES_TABLE,
            user_dssm_features_table=config.USER_DSSM_FEATURES_TABLE,
            segments_dssm_features_table=config.SEGMENTS_DSSM_FEATURES_TABLE,
            train_sample_table=config.TRAIN_SAMPLE_TABLE,
            validation_sample_table=config.VALIDATION_SAMPLE_TABLE,
            # The validation split is produced only when need_full_retrain is falsy.
            make_validation_sample=not need_full_retrain,
        )
        query = templater.render_template(
            make_train_sample_query_template,
            vars=template_vars,
        )
        yql_client.execute(
            query=query,
            transaction=str(transaction.transaction_id),
            title='YQL LaL make train sample',
        )
