import logging
import math

from crypta.lib.python import templater
from crypta.lib.python.juggler.juggler_helpers import report_event_to_juggler
from crypta.profile.lib import date_helpers
from crypta.siberia.bin.custom_audience.lib.python.clustering import config as clustering_config


# Module-level logger: update() logs progress and bound checks through it and
# passes it on to report_event_to_juggler.
logger = logging.getLogger(__name__)


update_metrics_table_template = """
-- Cosine distance helpers: distance = 1 - <a / |a|, b / |b|>, which lies in [0, 2].
-- NOTE(review): a zero vector has zero norm, so normalizing it likely yields NaN/Inf components.
$normalize_vector = ($vector) -> {
    $norm = Math::Sqrt(ListSum(ListMap($vector, ($elem) -> { RETURN Math::Pow($elem, 2); })));
    $normalized_vector = ListMap(
        $vector,
        ($elem) -> {RETURN $elem / $norm;}
    );

    RETURN $normalized_vector;
};

$dot_product = ($vec1, $vec2) -> {
    RETURN ListSum(ListMap(ListZip($vec1, $vec2), ($elem) -> { RETURN $elem.0 * $elem.1; }));
};

-- One row per segment present in both clusterings: its old and new cluster ids,
-- its users count and the cosine distance between the old and new cluster centroids.
$merged_clustering_table = (
    SELECT
        new_clustering_table.{{clustering_fields.name}} AS segment_name,
        old_clustering_table.{{clustering_fields.cluster_id}} AS old_cluster_id,
        new_clustering_table.{{clustering_fields.cluster_id}} AS new_cluster_id,
        new_clustering_table.{{clustering_fields.users_count}} AS users_count,
        1. - $dot_product(
            $normalize_vector(
                Yson::ConvertToDoubleList(old_centroids_table.{{clustering_fields.vector}})
            ),
            $normalize_vector(
                Yson::ConvertToDoubleList(new_centroids_table.{{clustering_fields.vector}})
            )) AS distance,
    FROM `{{old_clustering_table}}` AS old_clustering_table
    INNER JOIN `{{new_clustering_table}}` AS new_clustering_table
    ON old_clustering_table.{{clustering_fields.name}} == new_clustering_table.{{clustering_fields.name}}
    INNER JOIN `{{centroids_table}}` AS old_centroids_table
    ON old_clustering_table.{{clustering_fields.cluster_id}} == old_centroids_table.{{clustering_fields.cluster_id}}
    INNER JOIN `{{centroids_table}}` AS new_centroids_table
    ON new_clustering_table.{{clustering_fields.cluster_id}} == new_centroids_table.{{clustering_fields.cluster_id}}
    {% if id_type is not none %}
    WHERE old_clustering_table.{{clustering_fields.id_type}} == '{{id_type}}'
        AND new_clustering_table.{{clustering_fields.id_type}} == '{{id_type}}'
        AND old_centroids_table.{{clustering_fields.id_type}} == '{{id_type}}'
        AND new_centroids_table.{{clustering_fields.id_type}} == '{{id_type}}'
    {% endif %}
);

-- Per new cluster: total users count and the users count of its largest segment.
$new_clusterid_info_table = (
    SELECT
        new_cluster_id,
        SUM(users_count) AS users_count_sum,
        MAX(users_count) AS users_count_max,
    FROM $merged_clustering_table
    GROUP BY new_cluster_id
);

-- Users count shared by every (old cluster, new cluster) pair.
$old_new_clusterids_table = (
    SELECT
        merged_clustering_table.old_cluster_id AS old_cluster_id,
        merged_clustering_table.new_cluster_id AS new_cluster_id,
        SUM(merged_clustering_table.users_count) AS users_count,
    FROM $merged_clustering_table AS merged_clustering_table
    GROUP BY old_cluster_id, new_cluster_id
);

-- Per new cluster: the largest share of its users that came from a single old cluster.
$old_clusterid_share_table = (
    SELECT
        old_new_clusterids_table.new_cluster_id AS new_cluster_id,
        MAX(
            CAST(old_new_clusterids_table.users_count AS Double) / new_clusterid_info_table.users_count_sum
        ) as old_cluster_id_share,
        SOME(users_count_sum) AS users_count_sum,
    FROM $old_new_clusterids_table AS old_new_clusterids_table
    INNER JOIN $new_clusterid_info_table AS new_clusterid_info_table
    USING (new_cluster_id)
    GROUP BY old_new_clusterids_table.new_cluster_id
);

-- Total users count over all matched segments (equals the sum of users_count_sum over new clusters).
$users_count_sum = (
    SELECT SUM(users_count)
    FROM $merged_clustering_table
);

-- Users-weighted average of old_cluster_id_share over new clusters.
$old_clusterid_share = (
    SELECT
        SUM(old_cluster_id_share * (CAST(users_count_sum AS Double) / $users_count_sum)),
    FROM $old_clusterid_share_table
);

INSERT INTO `{{metrics_table}}`

SELECT
    '{{date}}' AS fielddate,
    $old_clusterid_share AS old_clusterid_share,
    SUM(merged_clustering_table.distance * (CAST(merged_clustering_table.users_count AS Double) / $users_count_sum)) AS distance,
    MAX(CAST(merged_clustering_table.users_count AS Double) / new_clusterid_info_table.users_count_sum) AS largest_segment_share_in_cluster,
    {% if id_type is not none %}
    '{{id_type}}' AS {{clustering_fields.id_type}},
    {% endif %}
FROM $merged_clustering_table AS merged_clustering_table
INNER JOIN $new_clusterid_info_table AS new_clusterid_info_table
USING (new_cluster_id)
"""


def _is_unverifiable(value):
    """Return True if ``value`` cannot be meaningfully compared against a bound.

    An aggregate over empty input is NULL (presumably read back as ``None``), and a
    zero-norm centroid can make the cosine distance NaN. ``None < 0.`` raises
    ``TypeError`` on Python 3, and every comparison with NaN is False, so either
    value would otherwise crash the check or silently pass it.
    """
    return value is None or (isinstance(value, float) and math.isnan(value))


def _collect_bound_violations(metrics, lower_bounds, upper_bounds):
    """Log every bounded metric and return messages for the metrics that violate their bound.

    :param metrics: metrics row (column name -> value).
    :param lower_bounds: metric name -> minimal allowed value.
    :param upper_bounds: metric name -> maximal allowed value.
    :return: list of messages, empty when every bound holds.
    """
    checks = (
        (
            'Lower bound for {metric} is {bound_value}, current value is {value}',
            lower_bounds,
            lambda value, bound: value < bound,
        ),
        (
            'Upper bound for {metric} is {bound_value}, current value is {value}',
            upper_bounds,
            lambda value, bound: value > bound,
        ),
    )

    violations = []
    for metric_name, metric_value in metrics.items():
        for message_template, bounds, is_violated in checks:
            if metric_name not in bounds:
                continue
            bound_value = bounds[metric_name]
            message = message_template.format(
                metric=metric_name,
                bound_value=bound_value,
                value=metric_value,
            )
            logger.info(message)
            # A missing or NaN value cannot be shown to satisfy the bound, so it counts as a violation.
            if _is_unverifiable(metric_value) or is_violated(metric_value, bound_value):
                violations.append(message)
    return violations


def update(yt_client, transaction, yql_client, centroids_table, monthly_clustering_dir, metrics_table,
           service_name, clustering_fields, id_types=None, lower_bounds=None, upper_bounds=None):
    """Append stability metrics comparing the two latest clusterings and check them against bounds.

    For every id type, this renders and runs ``update_metrics_table_template``, which appends one
    metrics row to ``metrics_table``. It then reads that row back and checks it against the bounds.
    The outcome is reported to Juggler under the service ``<service_name>_metrics``.

    :param yt_client: YT client used to list ``monthly_clustering_dir`` and read ``metrics_table``.
    :param transaction: transaction whose ``transaction_id`` the YQL query runs in.
    :param yql_client: client used to execute the YQL query.
    :param centroids_table: path to the table with cluster centroid vectors.
    :param monthly_clustering_dir: directory with clustering tables; the last two in sorted order
        are compared.
    :param metrics_table: path to the table the metrics row is appended to.
    :param service_name: prefix of the Juggler service name.
    :param clustering_fields: object exposing the column names (``name``, ``cluster_id``,
        ``users_count``, ``vector``, ``id_type``) used in the query.
    :param id_types: id types to compute metrics for separately; ``None`` means one unfiltered run.
    :param lower_bounds: metric name -> minimal allowed value;
        defaults to ``{'old_clusterid_share': 0.}``.
    :param upper_bounds: metric name -> maximal allowed value; defaults to ``{'distance': 2.}``.
    :raises Exception: when a metric violates its bound or is missing/NaN. A WARN event is sent
        to Juggler before raising.
    """
    monthly_clustering_tables = yt_client.list(monthly_clustering_dir, absolute=True, sort=True)
    if len(monthly_clustering_tables) < 2:
        logger.info('{dir} has less than 2 tables to compare'.format(dir=monthly_clustering_dir))
        return
    # Presumably table names sort chronologically, so these are the previous and latest clusterings.
    old_clustering_table, new_clustering_table = monthly_clustering_tables[-2:]

    id_types = id_types if id_types is not None else [None]

    if lower_bounds is None:
        lower_bounds = {
            'old_clusterid_share': 0.,
        }
    if upper_bounds is None:
        upper_bounds = {
            'distance': 2.,
        }

    # Computed once so all id types of a single run share the same fielddate.
    date = date_helpers.get_today_date_string()
    service = '{service_name}_metrics'.format(service_name=service_name)

    for id_type in id_types:
        yql_client.execute(
            query=templater.render_template(
                update_metrics_table_template,
                vars={
                    'old_clustering_table': old_clustering_table,
                    'new_clustering_table': new_clustering_table,
                    'centroids_table': centroids_table,
                    'clustering_fields': clustering_fields,
                    'id_type': id_type,
                    'date': date,
                    'metrics_table': metrics_table,
                },
            ),
            transaction=str(transaction.transaction_id),
            title='YQL update_metrics_table',
        )

        # The query has just appended one row. Assuming the table is unsorted (so INSERT appends
        # at the end), the last row holds the fresh metrics.
        new_metrics = next(yt_client.read_table(
            yt_client.TablePath(
                name=metrics_table,
                start_index=yt_client.row_count(metrics_table) - 1,
            ),
        ))

        bad_metrics_messages = _collect_bound_violations(new_metrics, lower_bounds, upper_bounds)

        if not bad_metrics_messages:
            report_event_to_juggler(
                status='OK',
                service=service,
                host=clustering_config.CRYPTA_ML_JUGGLER_HOST,
                tags=[service],
                logger=logger,
            )
        else:
            report_event_to_juggler(
                status='WARN',
                service=service,
                host=clustering_config.CRYPTA_ML_JUGGLER_HOST,
                description='. '.join(bad_metrics_messages),
                tags=[service],
                logger=logger,
            )
            raise Exception('Metrics do not satisfy bounds')
