#!/usr/bin/env python
# -*- coding: utf-8 -*-

import json
import os
import logging

from crypta.lib.python import templater
from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.lookalike.lib.python.utils.config import config
from crypta.lookalike.lib.python.utils import (
    utils,
)


# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# Jinja-style template of three YQL statements, rendered with
# templater.render_template in add().
#
# Each statement rewrites a table in place: INSERT ... WITH TRUNCATE into the
# same table that is read as `base`. Each one LEFT JOINs the new feature
# tables onto that table:
#   1. segments DSSM features table, joined USING (GroupID);
#   2. user DSSM features table, joined USING (yandexuid);
#   3. train sample table, joined on both keys and ordered by shuffling_number.
#
# $prepare_features turns a features list into a comma-joined string and
# writes '0.' for NULL elements. The result is appended (after a ',') to the
# existing *_float_features string column. When the joined `features` value
# is NULL, including unmatched LEFT JOIN rows, the rendered default list is
# used via COALESCE.
#
# Columns named in additional_*_features are copied from the new tables, with
# '' as the fallback.
#
# Template variables:
#   segments_dssm_features_table, user_data_dssm_features_table,
#   train_sample_table
#       -- tables that are read and rewritten;
#   new_segments_features_table, new_user_features_table
#       -- tables holding the new features;
#   default_segment_features, default_user_features
#       -- Python lists rendered inline as YQL list literals; when a list is
#          empty, the `base.*` branch is taken and float features are left
#          untouched;
#   additional_segment_features, additional_user_features
#       -- lists of extra column names.
#
# NOTE(review): the train-sample statement checks only
# `default_segment_features|length` but also uses default_user_features.
# add() makes both lists n_features long, so the check holds for that caller.
add_features_query_template = """
$prepare_features = ($features_list) -> (
    String::JoinFromList(
        ListMap(
            $features_list,
            ($feature) -> (IF($feature IS NULL, '0.', CAST($feature AS String))),
        ), ',')
);

INSERT INTO `{{segments_dssm_features_table}}`
WITH TRUNCATE

SELECT
    {% for field_name in additional_segment_features %}
    COALESCE(segment_features.{{field_name}}, '') AS {{field_name}},
    {% endfor %}
    {% if default_segment_features|length > 0 %}
    base.segment_float_features || ',' ||
        $prepare_features(COALESCE(segment_features.features, {{default_segment_features}})) AS segment_float_features,
    base.* WITHOUT base.segment_float_features,
    {% else %}
    base.*,
    {% endif %}
FROM `{{segments_dssm_features_table}}` AS base
LEFT JOIN `{{new_segments_features_table}}` AS segment_features
USING (GroupID);

INSERT INTO `{{user_data_dssm_features_table}}`
WITH TRUNCATE

SELECT
    {% for field_name in additional_user_features %}
    COALESCE(user_features.{{field_name}}, '') AS {{field_name}},
    {% endfor %}
    {% if default_user_features|length > 0 %}
    base.user_float_features || ',' ||
        $prepare_features(COALESCE(user_features.features, {{default_user_features}})) AS user_float_features,
    base.* WITHOUT base.user_float_features,
    {% else %}
    base.*,
    {% endif %}
FROM `{{user_data_dssm_features_table}}` AS base
LEFT JOIN `{{new_user_features_table}}` AS user_features
USING (yandexuid);

INSERT INTO `{{train_sample_table}}`
WITH TRUNCATE

SELECT
    {% for field_name in additional_segment_features %}
    COALESCE(segment_features.{{field_name}}, '') AS {{field_name}},
    {% endfor %}
    {% for field_name in additional_user_features %}
    COALESCE(user_features.{{field_name}}, '') AS {{field_name}},
    {% endfor %}
    {% if default_segment_features|length > 0 %}
    base.segment_float_features || ',' ||
        $prepare_features(COALESCE(segment_features.features, {{default_segment_features}})) AS segment_float_features,
    base.user_float_features || ',' ||
        $prepare_features(COALESCE(user_features.features, {{default_user_features}})) AS user_float_features,
    base.* WITHOUT base.segment_float_features, base.user_float_features,
    {% else %}
    base.*,
    {% endif %}
FROM `{{train_sample_table}}` AS base
LEFT JOIN `{{new_segments_features_table}}` AS segment_features
ON segment_features.GroupID == base.GroupID
LEFT JOIN `{{new_user_features_table}}` AS user_features
ON user_features.yandexuid == base.yandexuid
ORDER BY shuffling_number;
"""


# str.format template (single-brace placeholders, unlike the Jinja-style
# template) executed by validate_features().
#
# It writes every user row (yandexuid) and segment row (GroupID) whose
# `features` list length differs from {features_number} into the matching
# "bad" table, together with the actual length. The caller then requires both
# bad tables to be empty.
#
# NOTE(review): rows whose `features` is NULL presumably produce a NULL
# comparison and are not flagged. The add query later substitutes defaults
# for NULL features via COALESCE.
validate_features_size_query = """
INSERT INTO `{bad_user_features_table}`
WITH TRUNCATE
SELECT
    yandexuid,
    ListLength(features) AS features_number,
FROM `{new_user_features_table}`
WHERE ListLength(features) != {features_number};


INSERT INTO `{bad_segment_features_table}`
WITH TRUNCATE
SELECT
    GroupID,
    ListLength(features) AS features_number,
FROM `{new_segments_features_table}`
WHERE ListLength(features) != {features_number};
"""


def validate_features(yt_client, yql_client, segments_features_path, user_features_path, features_number, transaction):
    """Fail if any new feature list does not contain exactly ``features_number`` items.

    The size-check query runs inside ``transaction``. It copies offending user
    and segment rows into two temporary tables, and both must end up empty.

    Args:
        yt_client: YT client providing the temporary tables and row counts.
        yql_client: YQL client that executes the validation query.
        segments_features_path: path of the new segment features table.
        user_features_path: path of the new user features table.
        features_number: expected length of every ``features`` list.
        transaction: object exposing ``transaction_id``; the query runs in it.

    Raises:
        AssertionError: if either table has a row with a wrongly sized list.
    """
    with yt_client.TempTable() as bad_user_features_table:
        with yt_client.TempTable() as bad_segment_features_table:
            validation_query = validate_features_size_query.format(
                new_user_features_table=user_features_path,
                new_segments_features_table=segments_features_path,
                bad_user_features_table=bad_user_features_table,
                bad_segment_features_table=bad_segment_features_table,
                features_number=features_number,
            )
            yql_client.execute(
                query=validation_query,
                transaction=str(transaction.transaction_id),
                title='YQL Validate features size',
            )

            # Any row left in a "bad" table means a features list of the wrong size.
            assert yt_client.row_count(bad_user_features_table) == 0, 'User features list has wrong size'
            assert yt_client.row_count(bad_segment_features_table) == 0, 'Segment features list has wrong size'


def _get_default_features(nv_params, kind, n_features):
    """Return the default ``kind`` feature vector used when new features are missing.

    Reads ``nv_params['default_<kind>_features']`` when present; otherwise
    falls back to ``n_features`` zeros. Provided values are coerced to
    ``float``, so user-supplied integers render like the zero-filled fallback
    (e.g. ``[1.0, 2.0]`` rather than ``[1, 2]``) inside the YQL COALESCE.

    Args:
        nv_params: operation parameters (mapping).
        kind: ``'segment'`` or ``'user'``; selects the parameter key and is
            used in the error message.
        n_features: expected number of new float features.

    Returns:
        A list of ``n_features`` floats.

    Raises:
        AssertionError: if the provided list does not have ``n_features`` items.
    """
    key = 'default_{}_features'.format(kind)
    if key not in nv_params:
        return [0.0] * n_features

    defaults = nv_params[key]
    assert len(defaults) == n_features, \
        'Default {} features list must have {} elements'.format(kind, n_features)
    return [float(value) for value in defaults]


def add(nv_params, output):
    """Append new float (and optionally bag-of-words) features to the DSSM tables.

    The new feature tables ``segment_features`` and ``user_features`` are
    looked up in the parent directory of ``nv_params['working-dir']``. Inside
    a single Nirvana transaction, this function:

    - optionally validates the feature list sizes;
    - renders ``add_features_query_template`` and runs it, rewriting the
      segments DSSM features, user DSSM features and train sample tables.

    Args:
        nv_params: operation parameters. Keys read here:
            ``working-dir`` (required), ``n_features`` (required),
            ``default_segment_features`` / ``default_user_features``
            (optional, default: zeros), ``validate`` (default True),
            ``use_bow_features`` (default False).
        output: path of the local JSON file to write. It receives the prod and
            new train sample paths plus the float feature counts before
            (``ParamsSize``) and after (``NewParamsSize``) this update.

    Raises:
        AssertionError: if the working directory or a new features table is
            missing, a default features list has the wrong length, or
            validation finds feature lists of the wrong size.
    """
    yt_client = utils.get_yt_client(nv_params=nv_params)
    yql_client = utils.get_yql_client(nv_params=nv_params)

    working_dir_path = nv_params['working-dir']
    assert yt_client.exists(working_dir_path), 'Path to working directory is invalid'

    # New feature tables live next to (not inside) the working directory.
    features_root = os.path.dirname(working_dir_path)
    new_segments_features_path = os.path.join(features_root, 'segment_features')
    new_user_features_path = os.path.join(features_root, 'user_features')
    assert yt_client.exists(new_segments_features_path), 'New segments features table does not exist'
    assert yt_client.exists(new_user_features_path), 'New user features table does not exist'

    n_new_float_features = nv_params['n_features']
    default_segment_features = _get_default_features(nv_params, 'segment', n_new_float_features)
    default_user_features = _get_default_features(nv_params, 'user', n_new_float_features)

    use_bow_features = nv_params.get('use_bow_features', False)

    with NirvanaTransaction(yt_client) as transaction:
        if nv_params.get('validate', True):
            validate_features(
                yt_client,
                yql_client,
                new_segments_features_path,
                new_user_features_path,
                n_new_float_features,
                transaction,
            )

        # Extra (bag-of-words) columns are only copied when explicitly requested.
        additional_user_features = []
        additional_segment_features = []
        if use_bow_features:
            additional_user_features = utils.get_additional_dssm_fields(
                yt_client,
                utils.get_production_path(config.USER_DSSM_FEATURES_TABLE, nv_params),
                new_user_features_path,
            )
            additional_segment_features = utils.get_additional_dssm_fields(
                yt_client,
                utils.get_production_path(config.TEST_SEGMENTS_DSSM_FEATURES_TABLE, nv_params),
                new_segments_features_path,
            )

        # Read before the query runs, so this is the feature count prior to the update.
        # NOTE(review): this function does not update the 'float_features_size'
        # attribute itself; confirm consumers rely on NewParamsSize in `output`.
        float_features_size = yt_client.get_attribute(config.USER_DSSM_FEATURES_TABLE, 'float_features_size')

        query = templater.render_template(
            add_features_query_template,
            vars={
                'segments_dssm_features_table': config.TEST_SEGMENTS_DSSM_FEATURES_TABLE,
                'user_data_dssm_features_table': config.USER_DSSM_FEATURES_TABLE,
                'train_sample_table': config.TRAIN_SAMPLE_TABLE,
                'new_segments_features_table': new_segments_features_path,
                'new_user_features_table': new_user_features_path,
                'default_segment_features': default_segment_features,
                'default_user_features': default_user_features,
                'additional_user_features': additional_user_features,
                'additional_segment_features': additional_segment_features,
            },
        )
        logger.info(
            'Adding %s new float features (bag-of-words features: %s)',
            n_new_float_features,
            use_bow_features,
        )
        yql_client.execute(query=query, transaction=str(transaction.transaction_id), title='YQL Add custom features')

        with open(output, 'w') as output_file:
            json.dump({
                'prod_train_sample_path': config.TRAIN_SAMPLE_TABLE.replace('new_model', 'prod'),
                'new_train_sample_path': config.TRAIN_SAMPLE_TABLE,
                'ParamsSize': float_features_size,
                'NewParamsSize': float_features_size + n_new_float_features,
            }, output_file)
