#!/usr/bin/env python
# -*- coding: utf-8 -*-

from functools import partial
import logging

import numpy as np

from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.lib.python.yt import yt_helpers
from crypta.lookalike.lib.python.utils import (
    fields,
    mobile_utils,
)
from crypta.lookalike.lib.python.utils.mobile_config import config as mobile_config
from crypta.lookalike.lib.python.utils.utils import normalize

# Module-level logger. NOTE(review): not referenced by the code visible in this section.
logger = logging.getLogger(__name__)

# YQL prelude that only defines the named expression ``$apps``: the UNION ALL of
# the training apps table ({train_apps}) and the validation apps table
# ({val_apps}), keeping ``app_id`` and ``id_type`` and tagging every row with an
# ``app_type`` of 'training' or 'validation'.
# It produces no output by itself; it is meant to be prepended to a query that
# reads from ``$apps`` (merge_apps_features passes input_table='$apps' to the
# apps features query). Filled via str.format, so any literal braces added to
# the SQL must be doubled.
train_val_apps_query = """
$apps = (
    SELECT
        app_id,
        id_type,
        'training' AS app_type
    FROM `{train_apps}` AS train_apps
UNION ALL
    SELECT
        app_id,
        id_type,
        'validation' AS app_type
    FROM `{val_apps}` AS val_apps
);

"""

# YQL script template (filled via str.format) that builds the training sample.
#
# Placeholders:
#   {train_pairs}                 - input pairs table (columns used: id, id_type, target_app, ...)
#   {counts}                      - output table with per-(target_app, id_type) pair counts
#   {apps_by_devid_and_crypta_id} - id -> cryptaId matching table
#   {user_web_features}           - per-cryptaId user web features table
#   {default_user_web_vector}     - fallback web vector, inserted verbatim inside single
#                                   quotes, so it must not contain a quote character
#   {train_sample}                - output training sample table
#   {dssm_app_features}           - per-(app_id, id_type) app features table
#
# Steps:
#   1. $counts: number of {train_pairs} rows per (target_app, id_type); written
#      to {counts} (WITH TRUNCATE replaces its previous contents).
#   2. $user_web_features: every row of the matching table LEFT JOINed with the
#      web features on cryptaId (the features' cryptaId is cast to String), so
#      ids without web features keep NULL feature columns.
#      NOTE(review): ``id``/``id_type`` are unqualified here; this assumes the
#      web features table has no columns with those names.
#   3. {train_sample}: each pair that has app features, a row in
#      $user_web_features and a count (all INNER JOINs). ``target`` is always 1,
#      ``weight`` is 1 / sqrt(devids_cnt) (smaller for apps with many pairs), and
#      NULL web features are replaced with the default vector.
#      Rows are ordered by RANDOM(target_app) to shuffle the output.
#      NOTE(review): RANDOM takes only target_app as its argument. Confirm whether
#      YQL evaluates it per row or per distinct target_app; in the latter case
#      only whole app groups are shuffled.
get_train_sample_query = """
$counts = (
    SELECT
        target_app,
        id_type,
        COUNT(*) AS devids_cnt
    FROM `{train_pairs}`
    GROUP BY target_app, id_type
);

INSERT INTO `{counts}`
WITH TRUNCATE

SELECT *
FROM $counts;

$user_web_features = (
    SELECT
        id,
        id_type,
        user_web_features.*
    FROM `{apps_by_devid_and_crypta_id}` AS matching
    LEFT JOIN `{user_web_features}` AS user_web_features
    ON matching.cryptaId == CAST(user_web_features.cryptaId AS String)
);

$default_user_vector = '{default_user_web_vector}';

INSERT INTO `{train_sample}`
WITH TRUNCATE

SELECT
    train_pairs.id AS id,
    train_pairs.id_type AS id_type,
    train_pairs.target_app AS target_app,
    1 AS target,
    1.0 / Math::Sqrt(CAST(devids_cnt AS Double)) AS weight,
    devids_cnt,
    -- USER
    user_apps_features_from_stores,
    user_apps_vector_features,
    installed_apps,
    COALESCE(user_web_features, $default_user_vector) AS user_web_features,
    -- APP
    app_segment_features,
    app_features_from_stores,
    app_vector_features,
    app_publisher_vector,
    affinitive_apps,
    RANDOM(train_pairs.target_app) AS shuffling_number
FROM `{train_pairs}` AS train_pairs
INNER JOIN `{dssm_app_features}` AS app_features
ON train_pairs.target_app == app_features.app_id AND train_pairs.id_type == app_features.id_type
INNER JOIN $user_web_features AS user_web_features
ON train_pairs.id == user_web_features.id AND train_pairs.id_type == user_web_features.id_type
INNER JOIN $counts AS counts
ON train_pairs.target_app == counts.target_app AND train_pairs.id_type == counts.id_type
ORDER BY shuffling_number;
"""


def merge_apps_features(nv_params):
    """Build the per-app feature table ``mobile_config.APP_DSSM_FEATURES``.

    All work happens inside a single Nirvana transaction:

    1. A YQL query (the ``$apps`` prelude for training + validation apps,
       followed by ``mobile_utils.apps_features_query`` reading ``$apps``)
       writes raw app features into a temporary table.
    2. ``APP_DSSM_FEATURES`` is (re)created with ``force=True`` and a fixed
       schema: string feature columns plus a uint64 ``MD5Hash``.
    3. Three category -> vector lookups are loaded from YT.
    4. The temporary table is mapped through
       ``mobile_utils.process_apps_features`` into ``APP_DSSM_FEATURES``, with
       ``app_type`` and ``MD5Hash`` passed as ``additional_columns``.

    :param nv_params: Nirvana parameters used to build the YT and YQL clients.
    """
    yt_client = mobile_utils.get_yt_client(nv_params=nv_params)
    yql_client = mobile_utils.get_yql_client(nv_params=nv_params)

    with NirvanaTransaction(yt_client) as transaction:
        with yt_client.TempTable() as app_features_table:
            # The prelude defines $apps; the shared features query consumes it.
            apps_prelude = train_val_apps_query.format(
                train_apps=mobile_config.TRAIN_APPS_TABLE,
                val_apps=mobile_config.VALIDATION_APPS_TABLE,
            )
            apps_features = mobile_utils.apps_features_query.format(
                input_table='$apps',
                apps_affinities=mobile_config.APPS_AFFINITIES,
                app2vec=mobile_config.APP2VEC_TABLE,
                publisher_vectors=mobile_config.APPS_VECTORS_BY_PUBLISHER,
                apps_features_from_stores=mobile_config.APPS_FEATURES_FROM_STORES,
                segments_dssm_features=mobile_config.APP_SEGMENTS_DSSM_WEB_FEATURES_TABLE,
                merged_stores=mobile_config.MERGED_STORES,
                output_table=app_features_table,
            )
            yql_client.execute(
                query='{}{}'.format(apps_prelude, apps_features),
                transaction=str(transaction.transaction_id),
                title='YQL get final apps features table',
            )

            # Every feature column is a string; only the hash column is numeric.
            string_columns = (
                fields.app_id,
                fields.id_type,
                fields.app_type,
                fields.app_vector_features,
                fields.app_features_from_stores,
                fields.app_segment_features,
                fields.app_publisher_vector,
                fields.affinitive_apps,
            )
            output_schema = dict.fromkeys(string_columns, 'string')
            output_schema[fields.MD5Hash] = 'uint64'

            yt_helpers.create_empty_table(
                yt_client=yt_client,
                path=mobile_config.APP_DSSM_FEATURES,
                schema=output_schema,
                force=True,
            )

            category_lookups = dict(
                category_to_app_vector=mobile_utils.get_category_to_vector_dict(
                    yt_client, mobile_config.CATEGORY2VEC_TABLE, fields.category, fields.vector,
                ),
                category_to_segment_vector=mobile_utils.get_category_to_vector_dict(
                    yt_client,
                    mobile_config.CATEGORY_SEGMENTS_DSSM_WEB_FEATURES_TABLE,
                    fields.group_id,
                    fields.app_segment_features,
                ),
                category_to_url_vector=mobile_utils.get_category_to_vector_dict(
                    yt_client, mobile_config.CATEGORY_VECTORS_BY_PUBLISHER, fields.category, fields.vector,
                ),
            )
            apps_mapper = partial(
                mobile_utils.process_apps_features,
                additional_columns=[fields.app_type, fields.MD5Hash],
                **category_lookups
            )

            mapper_spec = {
                'mapper': {
                    'memory_limit': 2 * 1024 ** 3,  # 2 GiB
                    'memory_reserve_factor': 1,
                },
            }
            yt_client.run_map(
                apps_mapper,
                app_features_table,
                mobile_config.APP_DSSM_FEATURES,
                spec=mapper_spec,
            )


def calculate_mean_vector(yt, path, column_name, max_number_of_rows=mobile_config.USERS_TO_SAMPLE_CNT):
    """Sum comma-separated float vectors from a YT table and normalize the sum.

    Rows of ``path`` are read in table order. For each of the first
    ``max_number_of_rows`` rows, ``row[column_name]`` is parsed as a
    comma-separated list of numbers and added element-wise to a running sum.
    The sum is then passed to ``normalize``. If ``normalize`` rescales to unit
    length (not visible here; confirm), this equals the normalized mean vector.

    :param yt: YT client used to read the table.
    :param path: path of the table to read.
    :param column_name: column holding the comma-separated vector string.
    :param max_number_of_rows: maximum number of leading rows to use.
    :return: ``normalize`` applied to the summed vector.
    :raises ValueError: if no rows were used (empty table or
        ``max_number_of_rows`` equal to 0).
    """
    vector_sum = None
    for idx, row in enumerate(yt.read_table(path)):
        if idx == max_number_of_rows:
            break

        # Build a real float list: on Python 3 ``map`` is a lazy iterator, and
        # ``np.array(map(...))`` would produce a 0-d object array, not a vector.
        row_vector = np.array([float(value) for value in row[column_name].split(',')])
        if vector_sum is None:
            vector_sum = row_vector
        else:
            vector_sum += row_vector

    if vector_sum is None:
        # Previously this surfaced as an UnboundLocalError on ``vector``.
        raise ValueError(
            'No rows read from {} to compute a mean vector (max_number_of_rows={})'.format(
                path, max_number_of_rows,
            )
        )
    return normalize(vector_sum)


def make(nv_params):
    """Compute the default user web vector and build the training sample.

    1. ``calculate_mean_vector`` over ``USERS_DSSM_FEATURES_WEB`` gives a
       vector, which is serialized as a comma-joined string.
    2. Inside a Nirvana transaction, ``DEFAULT_USER_DSSM_FEATURES_WEB`` is
       (re)created and receives that vector as a single row.
    3. ``get_train_sample_query`` runs in the same transaction. It writes the
       per-app counts and the training sample, using the vector as the
       fallback for users without web features.

    :param nv_params: Nirvana parameters used to build the YT and YQL clients.
    """
    yt_client = mobile_utils.get_yt_client(nv_params=nv_params)
    yql_client = mobile_utils.get_yql_client(nv_params=nv_params)

    mean_user_vector = calculate_mean_vector(
        yt_client, mobile_config.USERS_DSSM_FEATURES_WEB, fields.user_web_features
    )
    default_user_web_vector = ','.join(str(component) for component in mean_user_vector)

    default_vector_table = mobile_config.DEFAULT_USER_DSSM_FEATURES_WEB
    with NirvanaTransaction(yt_client) as transaction:
        # Persist the fallback vector as a one-row table.
        yt_helpers.create_empty_table(
            yt_client=yt_client,
            path=default_vector_table,
            schema={fields.user_web_features: 'string'},
            force=True,
        )
        yt_client.write_table(
            default_vector_table,
            [{fields.user_web_features: default_user_web_vector}],
        )

        train_sample_query = get_train_sample_query.format(
            train_pairs=mobile_config.USERS_TRAIN_FEATURES_MOBILE,
            counts=mobile_config.TRAIN_DEVIDS_COUNTS_BY_APP,
            apps_by_devid_and_crypta_id=mobile_config.APPS_BY_DEVID_AND_CRYPTA_ID,
            default_user_web_vector=default_user_web_vector,
            user_web_features=mobile_config.USERS_DSSM_FEATURES_WEB,
            train_sample=mobile_config.TRAIN_SAMPLE,
            dssm_app_features=mobile_config.APP_DSSM_FEATURES,
        )
        yql_client.execute(
            query=train_sample_query,
            transaction=str(transaction.transaction_id),
            title='YQL get train sample',
        )
