import os

import numpy as np
import pandas as pd

from crypta.lib.python import templater
from crypta.lib.python.custom_ml import training_config
from crypta.lib.python.custom_ml.proto import classification_pb2
from crypta.lib.python.custom_ml.tools import fields
from crypta.lib.python.yt import yt_helpers


# Jinja template of the textual segments report; it is rendered with
# templater.render_template by the compute_segments_metrics_for_* functions.
#
# Variables read by the template:
#   existing_model              - enables the comparison part of the report
#   initial_sample_stats_df     - pre-rendered table for the initial sample
#   segments_scores             - iterable of (segment_type, score) pairs   (existing model only)
#   new_sample_stats_df         - pre-rendered table for the new sample     (existing model only)
#   if_sample_should_be_added   - decision flag shown as YES / NO           (existing model only)
#   mixed_up_targets            - enables the mixed-up targets warning      (existing model only)
#   score, positive_class_score, negative_class_score - model scores
#   MODEL_APPROVING_SCORE_THRESHOLD, MODEL_UNSURE_SCORE_THRESHOLD - colouring thresholds for `score`
#
# The 0.1 / -0.05 boundaries used for per-segment colouring are hard-coded in the template itself.
segments_metrics_template = """
<{Initial sample in segments by classes
{{ initial_sample_stats_df }}
}>
{% if existing_model %}
{% for segment_type, score in segments_scores %}
{% if score > 0.1 %}
Metrics for the {{ segment_type }} segment have improved: !!(green){{ score }}!!
{% elif score > -0.05 %}
Metrics for the {{ segment_type }} segment have not changed: !!(grey){{ score }}!!
{% else %}
Metrics for the {{ segment_type }} segment have deteriorated: !!(red){{ score }}!!
{% endif %}
{% endfor %}

<{New sample in segments by classes
{{ new_sample_stats_df }}
}>
{% if if_sample_should_be_added %}
{% set decision="YES" %}
{% else %}
{% set decision="NO" %}
{% endif %}
**If sample should be added: {{ decision }}**
{% if mixed_up_targets %}
Warning: targets in the new sample may be mixed up.
{% endif %}
{% endif %}

Model score:
{% if score > MODEL_APPROVING_SCORE_THRESHOLD %}
!!(green){{ score }}!!
{% elif score > MODEL_UNSURE_SCORE_THRESHOLD %}
!!(grey){{ score }}!!
{% else %}
!!(red){{ score }}!!
{% endif %}

{% if positive_class_score < 0 %}
Warning: model cannot distinguish positive class well
{% endif %}
{% if negative_class_score < 0 %}
Warning: model cannot distinguish negative class well
{% endif %}

"""


def sure_classes_ratio(y_pred, thresholds, sample_weight=None):
    """
    Share of rows in which at least one class probability reaches its threshold.

    :param y_pred: 2-D array of predicted probabilities; a single column is treated
        as a sigmoid output and expanded to two columns [1 - p, p].
    :param thresholds: per-class thresholds, one per column of (expanded) y_pred.
    :param sample_weight: optional per-row weights; when given, the weighted share is returned.
    :return: plain or weighted fraction of "sure" rows.
    """
    is_single_column = y_pred.shape[1] == 1
    if is_single_column:
        # Sigmoid classification: make both classes explicit.
        y_pred = np.hstack([1 - y_pred, y_pred])

    assert y_pred.shape[1] == len(thresholds)

    # A row is "sure" as soon as any class reaches its own threshold.
    is_sure = (y_pred >= thresholds).any(axis=1)

    if sample_weight is None:
        return is_sure.mean()
    return (sample_weight * is_sure).sum() / sample_weight.sum()


def classes_accuracy(y_true, y_pred, thresholds, sample_weight=None):
    """
    Accuracy over the rows where both the prediction and the true label are defined.

    A prediction is defined when at least one probability strictly exceeds its class
    threshold; the predicted class is the one with the largest probability/threshold
    ratio. A true label is defined when the row of y_true has maximum equal to 1.0.

    The computation is vectorized and the validity mask is built once; on non-empty
    input the result is the same as with a per-row evaluation. Empty input produces
    NaN (mean of an empty selection) instead of raising.

    :param y_true: 2-D array of true labels, one column per class.
    :param y_pred: 2-D array of predicted probabilities, one column per class.
    :param thresholds: per-class thresholds, one per column of y_pred.
    :param sample_weight: optional per-row weights (must support boolean-mask indexing).
    :return: plain or weighted accuracy; NaN when no row is valid.
    """
    assert y_pred.shape[1] == len(thresholds)

    # Probability divided by its class threshold: a value above 1.0 means the class is predicted surely.
    ratio = np.asarray(y_pred) / np.asarray(thresholds)
    predicted_class = np.where(ratio.max(axis=1) > 1.0, ratio.argmax(axis=1), np.nan)

    # There are possible nans in true labels because age and gender are trained together
    y_true = np.asarray(y_true)
    true_class = np.where(y_true.max(axis=1) == 1.0, y_true.argmax(axis=1), np.nan)

    # Keep only the rows where both the prediction and the label are defined.
    valid = ~(np.isnan(predicted_class) | np.isnan(true_class))
    is_correct = predicted_class[valid] == true_class[valid]

    if sample_weight is None:
        return is_correct.mean()

    valid_weight = sample_weight[valid]
    return (valid_weight * is_correct).sum() / valid_weight.sum()


def calculate_base_metrics(
        target_type,
        thresholds,
        predictions,
        labels,
        training=False,
        weights=None,
        is_socdem=False,
        is_mobile=False,
):
    """
    Compute logloss, sure ratio and accuracy and wrap each value with descriptive labels.

    :param target_type: stored under the 'socdem' label when is_socdem is set, otherwise under 'model'.
    :param thresholds: per-class thresholds passed to sure_classes_ratio and classes_accuracy.
    :param predictions: predicted probabilities.
    :param labels: true labels.
    :param training: selects the 'sample' label value: 'training' or 'test'.
    :param weights: optional sample weights passed to every metric.
    :param is_socdem: switches the label set to 'socdem' + 'vectors'.
    :param is_mobile: selects the 'vectors' label value: 'mobile' or 'web' (used only with is_socdem).
    :return: list of {'labels': {...}, 'value': ...} dicts, one per metric.
    """
    from sklearn.metrics import log_loss

    sample_kind = 'training' if training else 'test'
    if is_socdem:
        shared_labels = {
            'sample': sample_kind,
            'socdem': target_type,
            'vectors': 'mobile' if is_mobile else 'web',
        }
    else:
        shared_labels = {
            'sample': sample_kind,
            'model': target_type,
        }

    computed_metrics = [
        ('logloss', log_loss(labels, predictions, sample_weight=weights)),
        ('sure_ratio', sure_classes_ratio(predictions, thresholds, sample_weight=weights)),
        ('accuracy', classes_accuracy(labels, predictions, thresholds, sample_weight=weights)),
    ]

    labelled_metrics = []
    for name, value in computed_metrics:
        # The 'metric' label goes first, the shared labels follow in their own order.
        metric_labels = {'metric': name}
        metric_labels.update(shared_labels)
        labelled_metrics.append({
            'labels': metric_labels,
            'value': value,
        })

    return labelled_metrics


def calculate_metrics(
        target_type,
        y_predicted_train, y_train,
        y_predicted_test, y_test,
        ordered_thresholds,
        logger,
        weight_train=None, weight_test=None,
):
    """
    Compute base metrics for the training and the test samples and log every one of them.

    :return: list of labelled metrics, training sample entries first, test sample entries after.
    """
    train_metrics = calculate_base_metrics(
        target_type, ordered_thresholds, y_predicted_train, y_train,
        training=True, weights=weight_train,
    )
    test_metrics = calculate_base_metrics(
        target_type, ordered_thresholds, y_predicted_test, y_test,
        training=False, weights=weight_test,
    )
    all_metrics = train_metrics + test_metrics

    for entry in all_metrics:
        logger.info(entry)

    return all_metrics


def calculate_roc_auc_scores(y_true, y_pred, target_type, id_to_segment_name):
    """
    Compute a ROC AUC for every class present in y_true (that class against the rest)
    and their average weighted by the number of rows of each class.

    :param y_true: class ids, or a 2-D array that is reduced to class ids with argmax.
    :param y_pred: 2-D array of scores; column i is used as the score of class i.
    :param target_type: value of the 'model' label.
    :param id_to_segment_name: mapping from class id to the value of the 'segment' label.
    :return: list of labelled metrics: one 'ovr_roc_auc' per class, then the weighted 'roc_auc'.
    """
    from sklearn.metrics import roc_auc_score

    if len(y_true.shape) == 2:
        y_true = np.argmax(y_true, axis=1)

    scores = []
    support_weighted_sum = 0
    for class_id in np.unique(y_true):
        is_class = y_true == class_id
        class_auc = roc_auc_score(is_class, y_pred[:, class_id])
        scores.append({
            'labels': {
                'metric': 'ovr_roc_auc',
                'model': target_type,
                'segment': id_to_segment_name[class_id],
            },
            'value': class_auc,
        })
        # Weight the class score by the number of rows that belong to the class.
        support_weighted_sum += float(np.sum(is_class)) * class_auc

    scores.append({
        'labels': {
            'metric': 'roc_auc',
            'model': target_type,
        },
        'value': support_weighted_sum / len(y_true),
    })
    return scores


def convert_solomon_metrics_to_dict(solomon_metrics):
    """
    Flatten labelled metrics into a {dotted_name: value} dict.

    The name is built from the label values joined with '.', in the order the labels are stored;
    when a 'model' label is present its value is moved to the front of the name.
    """
    flat_metrics = {}
    for entry in solomon_metrics:
        entry_labels = entry['labels']
        if 'model' not in entry_labels:
            name = '.'.join(entry_labels.values())
        else:
            other_values = [value for key, value in entry_labels.items() if key != 'model']
            name = '{}.{}'.format(entry_labels['model'], '.'.join(other_values))
        flat_metrics[name] = entry['value']

    return flat_metrics


def default_float_formatter(col, value):
    """
    Render a numeric cell rounded to 3 decimal places; the column name is not used.
    """
    rounded = round(value, 3)
    return '{}'.format(rounded)


def pandas_to_startrek(df, add_columns=True, add_index=False, formatter=default_float_formatter):
    """
    Converts pandas_yt DataFrame to startrek format.

    :param df: DataFrame to render.
    :param add_columns: emit a header row with column names.
    :param add_index: emit the row index as the first cell (headed 'index').
    :param formatter: callable (column, value) -> str applied to every non-string cell.
    :return: table text delimited by '#|' and '|#', one '|| ... ||' line per row.
    """
    chunks = ['#|\n']

    if add_columns:
        header = df.columns.insert(0, 'index') if add_index else df.columns
        chunks.append('|| ' + ' | '.join(header) + ' ||\n')

    for row_label, row in df.iterrows():
        # String cells are emitted as is, everything else goes through the formatter.
        cells = [
            cell if isinstance(cell, str) else formatter(column, cell)
            for column, cell in row.items()
        ]
        index_prefix = '{} | '.format(row_label) if add_index else ''
        chunks.append('|| ' + index_prefix + ' | '.join(cells) + ' ||\n')

    chunks.append('|#')
    return ''.join(chunks)


def format_metrics_comparison(existing_metrics=None, existing_top_features=None,
                              new_metrics=None, new_top_features=None):
    """
    Build a textual report with the metrics table and the top features of the models.

    The table rows are collected first and the DataFrame is built in one go,
    so the function does not rely on DataFrame.append (removed in pandas 2.0).

    :param existing_metrics: flat metrics dict with 'existing_model_segments.<metric>' keys,
        or None when there is no existing model.
    :param existing_top_features: iterable of feature names of the existing model;
        read only when existing_metrics is given.
    :param new_metrics: flat metrics dict with 'new_model_segments.<metric>' keys;
        it is read unconditionally, so it must be provided despite the None default.
    :param new_top_features: iterable of feature names of the new model; must be provided as well.
    :return: report text: 'Metrics:' table followed by one collapsible feature list per model.
    """
    metric_names = list(training_config.metrics_to_show)

    # (row name, metrics dict) for every model shown in the table, existing model first.
    models = []
    if existing_metrics is not None:
        models.append(('existing', existing_metrics))
    models.append(('new', new_metrics))

    rows = []
    for model_name, model_metrics in models:
        row = {'sample': model_name}
        for metric in metric_names:
            row[metric] = np.round(model_metrics['{}_model_segments.{}'.format(model_name, metric)], 3)
        rows.append(row)
    metrics_df = pd.DataFrame(rows, columns=['sample'] + metric_names)

    result = 'Metrics:\n{}\n'.format(pandas_to_startrek(metrics_df))

    # (title suffix, features) for every feature list in the report.
    if existing_metrics is None:
        feature_lists = [('', new_top_features)]
    else:
        feature_lists = [
            (' for existing model', existing_top_features),
            (' for new model', new_top_features),
        ]

    for title_suffix, features in feature_lists:
        feature_lines = ['- {}'.format(feature) for feature in features]
        result += '<{{Top-{} features{}\n{}\n}}>\n'.format(
            len(feature_lines), title_suffix, '\n'.join(feature_lines)
        )

    return result


# YQL query template used by validate_predictions; placeholders are filled with str.format:
#   {train_sample_by_yuid} - train sample table with `yandexuid` and `segment_name` columns,
#   {predictions_table}    - predictions table with `id` and `segment_name` columns,
#   {output_table}         - table that is overwritten with the result.
#
# The query counts 'positive' and 'negative' rows of the train sample, inner-joins predictions
# with the train sample (predictions.id against yandexuid cast to string) and, for every
# predicted segment, outputs which share of all positive / all negative train rows fell into it.
# The result is ordered by predicted_segment.
check_intersection_query = """
$positive_cnt = SELECT CAST(COUNT_IF(segment_name == 'positive') AS Double) FROM `{train_sample_by_yuid}`;
$negative_cnt = SELECT CAST(COUNT_IF(segment_name == 'negative') AS Double) FROM `{train_sample_by_yuid}`;

$joined = (
    SELECT
        predictions.id AS yandexuid,
        predictions.segment_name AS predicted_segment,
        train_sample.segment_name AS initial_segment,
    FROM `{predictions_table}` AS predictions
    INNER JOIN `{train_sample_by_yuid}` AS train_sample
    ON predictions.id == cast(train_sample.yandexuid as string)
);

INSERT INTO `{output_table}`
WITH TRUNCATE

SELECT
    predicted_segment,
    COUNT_IF(initial_segment == 'positive') / $positive_cnt AS positive_ratio,
    COUNT_IF(initial_segment == 'negative') / $negative_cnt AS negative_ratio,
FROM $joined
GROUP BY predicted_segment
ORDER BY predicted_segment;
"""


def validate_predictions(yt_client, yql_client, predictions_table_path, train_sample_path):
    """
    Run check_intersection_query for the given predictions and train sample
    and return its result as a DataFrame.

    The query writes into a temporary table and is executed within a YT transaction.

    :return: DataFrame with predicted segment, positive ratio and negative ratio columns.
    """
    with yt_client.Transaction() as transaction, yt_client.TempTable() as metrics_table:
        query = check_intersection_query.format(
            predictions_table=predictions_table_path,
            train_sample_by_yuid=train_sample_path,
            output_table=metrics_table,
        )
        yql_client.execute(
            query=query,
            transaction=str(transaction.transaction_id),
            title='YQL compute intersection between predictions and train sample',
        )

        # Materialize the rows while the temporary table still exists.
        rows = list(yt_client.read_table(metrics_table))
        result_columns = [fields.predicted_segment, fields.positive_ratio, fields.negative_ratio]
        return pd.DataFrame(rows, columns=result_columns)


def compare_models_predictions(yt_client, yql_client, train_sample_path,
                               new_predictions_table_path, existing_predictions_table_path):
    """
    Compute train sample ratios per predicted segment for the existing and the new model
    and put them side by side in one table.

    Segment names of the existing model are reduced to their last two '_'-separated tokens
    before the tables are merged on the predicted segment.

    :return: DataFrame with the predicted segment, positive ratios (existing, new)
        and negative ratios (existing, new).
    """
    def last_two_tokens(segment_name):
        return '_'.join(segment_name.split('_')[-2:])

    existing_stats = validate_predictions(
        yt_client=yt_client,
        yql_client=yql_client,
        predictions_table_path=existing_predictions_table_path,
        train_sample_path=train_sample_path,
    )
    new_stats = validate_predictions(
        yt_client=yt_client,
        yql_client=yql_client,
        predictions_table_path=new_predictions_table_path,
        train_sample_path=train_sample_path,
    )

    existing_stats[fields.predicted_segment] = existing_stats[fields.predicted_segment].apply(last_two_tokens)

    merged = pd.merge(existing_stats, new_stats, on=fields.predicted_segment)
    # Columns come out of the merge as: segment, existing (positive, negative), new (positive, negative).
    merged.columns = [
        fields.predicted_segment,
        fields.positive_for_existing_model, fields.negative_for_existing_model,
        fields.positive_for_new_model, fields.negative_for_new_model,
    ]

    # Reorder so that the same class of both models stands next to each other.
    report_columns = [
        fields.predicted_segment,
        fields.positive_for_existing_model, fields.positive_for_new_model,
        fields.negative_for_existing_model, fields.negative_for_new_model,
    ]
    return merged[report_columns]


def get_model_scores(df, positive_ratio_column, negative_ratio_column):
    """
    Model score is computed as positive/negative ratio difference between top and bottom segments.

    The top segment is the first row of df and the bottom segment is the last one.

    :return: (positive_class_score, negative_class_score)
    """
    positive_ratios = df[positive_ratio_column].values
    negative_ratios = df[negative_ratio_column].values

    top, bottom = 0, -1
    return (
        positive_ratios[top] - positive_ratios[bottom],
        negative_ratios[bottom] - negative_ratios[top],
    )


def get_ratio_change_for_segment(df, segment_idx):
    """
    Accumulate positive and negative classes ratio change for a given segment to a single score.

    The score is the growth of the positive ratio (new minus existing model)
    plus the drop of the negative ratio (existing minus new model) in the row segment_idx.
    """
    def ratio_in_segment(column):
        return df[column].values[segment_idx]

    return (
        ratio_in_segment(fields.positive_for_new_model)
        - ratio_in_segment(fields.positive_for_existing_model)
        + ratio_in_segment(fields.negative_for_existing_model)
        - ratio_in_segment(fields.negative_for_new_model)
    )


def get_decision_for_new_sample(df):
    """
    Check that there is no significant degradation in train sample distribution (existing/new) in computed segments.

    The top segment is the first row of df, the bottom segment is the last one;
    the sign of the bottom segment score is flipped.

    :return: (if_sample_should_be_added, score_for_top_segment, score_for_bottom_segment)
    """
    score_for_top_segment = get_ratio_change_for_segment(df, 0)
    score_for_bottom_segment = -get_ratio_change_for_segment(df, -1)

    # Both segments have to clear the threshold, so the worse of the two decides.
    worst_score = min(score_for_top_segment, score_for_bottom_segment)
    if_sample_should_be_added = bool(worst_score > training_config.THRESHOLD_FOR_NEW_SAMPLE_ADDING)

    return if_sample_should_be_added, score_for_top_segment, score_for_bottom_segment


def compute_segments_metrics_for_new_model(yt_client, yql_client, predictions_table_path, train_sample_path):
    """
    Render the segments report for a model that has no existing counterpart.

    :return: text rendered from segments_metrics_template.
    """
    stats_df = validate_predictions(
        yt_client=yt_client,
        yql_client=yql_client,
        predictions_table_path=predictions_table_path,
        train_sample_path=train_sample_path,
    )
    positive_class_score, negative_class_score = get_model_scores(
        stats_df,
        fields.positive_ratio,
        fields.negative_ratio,
    )

    template_vars = {
        'existing_model': False,
        'initial_sample_stats_df': pandas_to_startrek(stats_df),
        # The overall score is the sum of both class scores.
        'score': round(positive_class_score + negative_class_score, 3),
        'positive_class_score': round(positive_class_score, 3),
        'negative_class_score': round(negative_class_score, 3),
        'MODEL_APPROVING_SCORE_THRESHOLD': training_config.MODEL_APPROVING_SCORE_THRESHOLD,
        'MODEL_UNSURE_SCORE_THRESHOLD': training_config.MODEL_UNSURE_SCORE_THRESHOLD,
    }
    return templater.render_template(
        template_text=segments_metrics_template,
        vars=template_vars,
    )


def compute_segments_metrics_for_existing_model(
    yt_client,
    yql_client,
    existing_predictions_table_path,
    new_predictions_table_path,
    initial_train_sample_path,
    new_train_sample_path,
):
    """
    Render the segments report comparing the existing and the new model
    and decide whether the new sample should be added.

    :return: (text rendered from segments_metrics_template, if_sample_should_be_added)
    """
    models_stats = {}
    for sample_name, sample_path in (
        (fields.initial_sample, initial_train_sample_path),
        (fields.new_sample, new_train_sample_path),
    ):
        models_stats[sample_name] = compare_models_predictions(
            yt_client=yt_client,
            yql_client=yql_client,
            train_sample_path=sample_path,
            new_predictions_table_path=new_predictions_table_path,
            existing_predictions_table_path=existing_predictions_table_path,
        )
    initial_stats = models_stats[fields.initial_sample]
    new_stats = models_stats[fields.new_sample]

    if_sample_should_be_added, score_for_top_segment, score_for_bottom_segment = get_decision_for_new_sample(
        initial_stats,
    )

    # The reported model score is taken from the new model when the sample is accepted
    # and from the existing model otherwise.
    if if_sample_should_be_added:
        positive_column = fields.positive_for_new_model
        negative_column = fields.negative_for_new_model
    else:
        positive_column = fields.positive_for_existing_model
        negative_column = fields.negative_for_existing_model
    positive_class_score, negative_class_score = get_model_scores(
        df=initial_stats,
        positive_ratio_column=positive_column,
        negative_ratio_column=negative_column,
    )

    # Scores of the existing model on the new sample: a too low value raises
    # the mixed-up targets warning and vetoes adding the sample.
    positive_class_score_on_new_sample, negative_class_score_on_new_sample = get_model_scores(
        df=new_stats,
        positive_ratio_column=fields.positive_for_existing_model,
        negative_ratio_column=fields.negative_for_existing_model,
    )
    worst_score_on_new_sample = min(positive_class_score_on_new_sample, negative_class_score_on_new_sample)
    mixed_up_targets = bool(worst_score_on_new_sample < training_config.MIXED_UP_CLASSES_THRESHOLD)
    if mixed_up_targets:
        if_sample_should_be_added = False

    template_vars = {
        'existing_model': True,
        'initial_sample_stats_df': pandas_to_startrek(initial_stats),
        'segments_scores': (
            ('top', round(score_for_top_segment, 3)),
            ('bottom', round(score_for_bottom_segment, 3)),
        ),
        'new_sample_stats_df': pandas_to_startrek(new_stats),
        'if_sample_should_be_added': if_sample_should_be_added,
        'score': round(positive_class_score + negative_class_score, 3),
        'positive_class_score': round(positive_class_score, 3),
        'negative_class_score': round(negative_class_score, 3),
        'MODEL_APPROVING_SCORE_THRESHOLD': training_config.MODEL_APPROVING_SCORE_THRESHOLD,
        'MODEL_UNSURE_SCORE_THRESHOLD': training_config.MODEL_UNSURE_SCORE_THRESHOLD,
        'mixed_up_targets': mixed_up_targets,
    }
    report = templater.render_template(
        template_text=segments_metrics_template,
        vars=template_vars,
    )
    return report, if_sample_should_be_added


def write_metrics(yt_client, output_dir, output):
    """
    Upload the content of the `output` stream to the 'metrics' file inside output_dir.
    """
    # Rewind first so that the stream is read from its beginning.
    output.seek(0)
    metrics_path = os.path.join(output_dir, 'metrics')
    yt_client.write_file(metrics_path, output, force_create=True)


def save_metrics(yt_client, output_dir, matching_stats, model_metrics, model_top_features, model_type='new'):
    """
    Pack model metrics into a TMetrics proto and store it serialized in the
    'model_metrics' table inside output_dir.

    :param output_dir: directory for the table; its base name becomes sample_id.
    :param matching_stats: object with 'matched_cnt' and 'cnt' entries supporting .sum();
        their totals give matched_ids_ratio.
    :param model_metrics: flat metrics dict with '<model_type>_model_segments.*' keys.
    :param model_top_features: features appended to the proto top_features field.
    :param model_type: prefix of the metric keys to read.
    :return: the filled TMetrics proto.
    """
    def metric(key_template):
        return model_metrics[key_template.format(model_type)]

    metrics_proto = classification_pb2.TMetrics(
        sample_id=os.path.basename(output_dir),
        roc_auc=metric('{}_model_segments.roc_auc'),
        accuracy=metric('{}_model_segments.accuracy.test'),
        positive_class_ratio=metric('{}_model_segments.train_distribution.positive'),
        negative_class_ratio=metric('{}_model_segments.train_distribution.negative'),
        train_sample_size=metric('{}_model_segments.train_sample_size'),
        matched_ids_ratio=float(matching_stats['matched_cnt'].sum()) / matching_stats['cnt'].sum(),
    )
    metrics_proto.top_features.extend(model_top_features)

    table_path = os.path.join(output_dir, 'model_metrics')
    yt_helpers.create_empty_table(
        yt_client=yt_client,
        path=table_path,
        schema={'Metrics': 'string'},
    )
    serialized_row = {'Metrics': metrics_proto.SerializeToString()}
    yt_client.write_table(table_path, [serialized_row])
    yt_helpers.set_yql_proto_field(table_path, 'Metrics', classification_pb2.TMetrics, yt_client)

    return metrics_proto
