import operator
import os

import numpy as np

from crypta.lib.python.custom_ml import training_config

# Query template that builds a training sample table.
#
# The template is formatted here, at import time, with the two table paths
# taken from ``training_config`` (single-brace placeholders). The
# double-brace placeholders -- ``target_processing``, ``target``,
# ``input_table``, ``additional_conditions`` and ``output_table`` -- come out
# of this first ``.format`` call as single-brace placeholders, so the
# resulting string is itself a template (presumably filled in by a second
# ``.format`` call at the call site -- confirm against callers).
#
# The query joins the input table with the indevice yandexuid table on
# ``id``/``id_type``, then with the user data table on yandexuid, and writes
# the distinct rows into the output table (``WITH TRUNCATE``), ordered by
# yandexuid.
make_common_train_sample_query = """
$train_sample = (
    SELECT
        {{target_processing}} AS {{target}},
        indevice_yandexuid.yandexuid AS yandexuid,
        COALESCE(CAST(user_data.CryptaID AS Uint64), RandomNumber(user_data.CryptaID)) AS crypta_id,
        user_data.Vectors AS Vectors,
        user_data.Attributes AS Attributes,
        user_data.Segments AS Segments,
    FROM `{{input_table}}` AS train_sample
    INNER JOIN `{indevice_yandexuid_table}` AS indevice_yandexuid
    ON train_sample.id == indevice_yandexuid.id AND train_sample.id_type == indevice_yandexuid.id_type
    INNER JOIN `{user_data_table}` VIEW raw AS user_data
    ON indevice_yandexuid.yandexuid == CAST(user_data.yuid AS Uint64)
    {{additional_conditions}}
);
INSERT INTO `{{output_table}}`
WITH TRUNCATE
SELECT DISTINCT
    yandexuid,
    {{target}},
    crypta_id,
    Vectors,
    Attributes,
    Segments
FROM $train_sample
ORDER BY yandexuid;
""".format(
    indevice_yandexuid_table=training_config.INDEVICE_YANDEXUID,
    user_data_table=training_config.USER_DATA_TABLE,
)


def check_array_to_have_all_values(array, required_values):
    """Check that the non-NaN values of ``array`` are exactly ``required_values``.

    NaN entries of ``array`` are ignored. The remaining values and
    ``required_values`` are compared as sets, so order and multiplicity do
    not matter, but both missing and unexpected values fail the check.

    Args:
        array: array that supports ``np.isnan`` and boolean-mask indexing.
        required_values: iterable of values that must be present.

    Raises:
        AssertionError: if the set of non-NaN values differs from the set of
            required values.
    """
    present = set(array[~np.isnan(array)])
    required = set(required_values)
    if present != required:
        # Raised explicitly instead of via ``assert`` so that the validation
        # is not silently stripped when Python runs with ``-O``.
        raise AssertionError(
            'Required values assert failed. Required: {}, present: {}'.format(
                required, present
            )
        )


def to_categorical(y):
    """One-hot encode class labels, leaving rows with NaN labels all-zero.

    The number of columns is the largest non-NaN label (truncated to int)
    plus one. Each row with a non-NaN label gets a 1 in the column given by
    that label cast to int.

    Args:
        y: array of numeric class labels, possibly containing NaN.

    Returns:
        A ``(len(y), nb_classes)`` float array of zeros and ones.
    """
    labelled_mask = ~np.isnan(y)
    known_labels = y[labelled_mask]

    row_count = len(y)
    class_count = int(known_labels.max()) + 1

    one_hot = np.zeros((row_count, class_count))
    # Only rows that actually carry a label receive a 1.
    labelled_rows = np.arange(row_count)[labelled_mask]
    one_hot[labelled_rows, known_labels.astype(int)] = 1
    return one_hot


def revert_dict(mapping):
    """Return a new dict that maps each value of ``mapping`` back to its key.

    If several keys share a value, the key that comes last in iteration
    order wins.
    """
    reverted = {}
    for key, value in mapping.items():
        reverted[value] = key
    return reverted


def normalize_probabilities(numeric_values_dict):
    """Divide every value of the dict by the sum of all its values.

    Args:
        numeric_values_dict: mapping from keys to numbers.

    Returns:
        A new dict with the same keys and each value divided by the total.
        An empty input yields an empty dict.
    """
    normalizer = sum(numeric_values_dict.values())
    return {
        key: value / normalizer
        for key, value in numeric_values_dict.items()
    }


def get_item_with_max_value(numeric_values_dict):
    """Return the ``(key, value)`` pair that has the largest value.

    When several entries share the maximum, the first one in iteration
    order is returned. An empty dict makes ``max`` raise ``ValueError``.
    """
    by_value = operator.itemgetter(1)
    best_pair = max(numeric_values_dict.items(), key=by_value)
    return best_pair[0], best_pair[1]


def get_model_tag(resource_type=None, resource_id=None, yt_path=None):
    """Build a model location string from a YT path or a Sandbox resource.

    Precedence is ``yt_path``, then ``resource_type``, then ``resource_id``.

    Args:
        resource_type: Sandbox resource type; yields a ``/last/`` URL with
            the ``released: stable`` attribute filter.
        resource_id: Sandbox resource id; yields a direct proxy URL.
        yt_path: YT path; yields a ``yt://`` URL.

    Returns:
        The formatted URL string.
    """
    if yt_path is None:
        if resource_type is None:
            return 'https://proxy.sandbox.yandex-team.ru/{}'.format(resource_id)
        return 'https://proxy.sandbox.yandex-team.ru/last/{}?attrs={{"released":"stable"}}'.format(resource_type)

    running_locally = os.environ.get('CRYPTA_ENVIRONMENT') == 'local_testing'
    if running_locally:
        # The first two characters of the path are dropped here
        # (presumably a leading '//' -- confirm against callers).
        return 'yt://plato/{}'.format(yt_path[2:])

    return 'yt://hahn/{}'.format(yt_path)
