#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import six

from crypta.lib.python import templater
from crypta.lib.python.yt import yt_helpers
from crypta.lookalike.lib.python.utils import fields

# Fractions (kept as strings) at which the TP/FP/TN/FN counts are evaluated.
# They are strings because the SQL template below interpolates them both as
# numeric literals and, with the dot removed, as column-name suffixes.
# The final entry '1' must stay last: the template handles it separately via
# `quantiles[:-1]`.
quantiles = (
    # 1e-5 .. 7e-5
    '0.00001', '0.00002', '0.00003', '0.00004', '0.00005', '0.00007',
    # 1e-4 .. 7e-4
    '0.0001', '0.0002', '0.0003', '0.0004', '0.0005', '0.0007',
    # 1e-3 .. 7e-3
    '0.001', '0.002', '0.003', '0.004', '0.005', '0.007',
    # 1e-2 .. 7e-2
    '0.01', '0.02', '0.03', '0.04', '0.05', '0.07',
    # 0.1 .. 0.9
    '0.1', '0.15', '0.2', '0.3', '0.4', '0.5', '0.6', '0.7', '0.8', '0.9',
    # the whole positive range
    '1',
)

# Two-stage query template producing per-quantile confusion-matrix counts.
#
# Stage 1 (import time): the Jinja-style `{% ... %}` / `{{ ... }}` constructs
# are expanded by `templater.render_template` below, generating one set of
# columns per entry of `quantiles`.
# Stage 2 (not visible in this module -- presumably `str.format` at the call
# site; confirm): the single-brace placeholders `{group_columns}`,
# `{input_table}`, `{order_by}` and `{output_table}` are substituted.
#
# What the query does:
#   1. $flattened_segments: numbers the rows of each `{group_columns}`
#      partition according to `{order_by}` (column `row_rank`).
#   2. $segments_positive_ranks: over rows with label == 1 only, computes for
#      every quantile q the rank threshold
#      MIN(row_rank) + (MAX(row_rank) - MIN(row_rank)) * q.
#      Quantile 1 is emitted separately as plain MAX(row_rank), which is why
#      the loop iterates over `almost_all_quantiles`.
#   3. $segments_with_positive_ranks: joins those thresholds back onto every
#      row of the group.
#   4. Final SELECT: for each quantile, rows with row_rank <= threshold are
#      treated as predicted positives and counted into tp/fp; the remaining
#      rows are counted into tn/fn.
#
# Column suffixes are the quantile string with the dot removed
# (e.g. '0.15' -> qnt_015_positive_rank, tp_015_qnt).
# NOTE(review): the last generated column is followed by a comma right before
# FROM -- presumably the target SQL dialect accepts a trailing comma; confirm.
calculate_tp_fp_tn_fn_query_unrendered_template = """
$flattened_segments = (
    SELECT
        {group_columns},
        label,
        ROW_NUMBER() OVER w AS row_rank
    FROM `{input_table}`
    WINDOW w AS (
        PARTITION BY {group_columns}
        ORDER BY {order_by}
    )
);

$segments_positive_ranks = (
    SELECT
        {group_columns},
        {% for q in almost_all_quantiles %}
        MIN(row_rank) + (MAX(row_rank) - MIN(row_rank)) * {{q}} AS qnt_{{q | replace('.', '')}}_positive_rank,
        {% endfor %}
        MAX(row_rank) AS qnt_1_positive_rank
    FROM $flattened_segments
    WHERE label == 1
    GROUP BY {group_columns}
);

$segments_with_positive_ranks = (
    SELECT
        segments_positive_ranks.*,
        label,
        row_rank
    FROM $flattened_segments AS flattened_segments
    INNER JOIN $segments_positive_ranks AS segments_positive_ranks
    USING ({group_columns})
);

INSERT INTO `{output_table}`
WITH TRUNCATE

SELECT
    {group_columns},
    SOME(qnt_1_positive_rank) AS last_positive_rank,

    {% for q in all_quantiles %}
    {% set quantile = q | replace('.', '') %}
    COUNT_IF(row_rank <= qnt_{{quantile}}_positive_rank AND label == 1) AS tp_{{quantile}}_qnt,
    COUNT_IF(row_rank <= qnt_{{quantile}}_positive_rank AND label == 0) AS fp_{{quantile}}_qnt,
    COUNT_IF(row_rank > qnt_{{quantile}}_positive_rank and label == 0) AS tn_{{quantile}}_qnt,
    COUNT_IF(row_rank > qnt_{{quantile}}_positive_rank AND label == 1) AS fn_{{quantile}}_qnt,
    {% endfor %}

FROM $segments_with_positive_ranks
GROUP BY {group_columns}
ORDER BY {group_columns};
"""

# Expand the per-quantile loops of the raw template once, at import time.
# `almost_all_quantiles` omits the trailing '1', which the template emits
# explicitly; `all_quantiles` drives the tp/fp/tn/fn column generation.
# The single-brace placeholders of the raw template are left untouched here.
calculate_tp_fp_tn_fn_query_template = templater.render_template(
    template_text=calculate_tp_fp_tn_fn_query_unrendered_template,
    vars=dict(
        almost_all_quantiles=quantiles[:-1],
        all_quantiles=quantiles,
    ),
)


def calculate_precision(tp, fp):
    """Return precision, i.e. tp / (tp + fp), as a float.

    Args:
        tp: number of true positives.
        fp: number of false positives.

    Raises:
        ZeroDivisionError: if ``tp + fp`` is zero.
    """
    predicted_positives = tp + fp
    # Explicit float() keeps true division on Python 2 as well.
    return float(tp) / predicted_positives


def calculate_recall(tp, fn):
    """Return recall, i.e. tp / (tp + fn), as a float.

    Args:
        tp: number of true positives.
        fn: number of false negatives.

    Raises:
        ZeroDivisionError: if ``tp + fn`` is zero.
    """
    actual_positives = tp + fn
    # Explicit float() keeps true division on Python 2 as well.
    return float(tp) / actual_positives


def calc_aupr(points_precision, points_recall):
    """Return the area under a precision-recall curve (trapezoidal rule).

    The curve is integrated over recall, taking the points in the order given:
    each pair of consecutive points contributes
    ``(recall_i - recall_{i-1}) * (precision_i + precision_{i-1}) / 2``.

    Points whose recall is NaN or negative are dropped before integrating, so
    their valid neighbours are connected directly. This is the same filtering
    that ``calculate_pr_stats`` applies to the curve it stores. Previously
    only the left endpoint of a segment was checked, so an invalid recall in
    the middle of the curve still entered the sum as a right endpoint and
    produced a NaN or negative contribution.

    Args:
        points_precision: sequence of precision values, aligned index by index
            with ``points_recall``.
        points_recall: sequence of recall values.

    Returns:
        The area as a float; ``0.`` when fewer than two valid points remain.

    Raises:
        IndexError: if a valid recall point has no precision at the same index.
    """
    valid_points = []
    for index, recall in enumerate(points_recall):
        if np.isnan(recall) or recall < 0:
            continue
        valid_points.append((recall, points_precision[index]))

    area = 0.
    for (prev_recall, prev_precision), (cur_recall, cur_precision) in zip(valid_points, valid_points[1:]):
        width = cur_recall - prev_recall
        # `2.` rather than `2`: avoids floor division of integer precisions on Python 2.
        height = (cur_precision + prev_precision) / 2.
        area += width * height

    return area


def calculate_pr_stats(yt_client, table_name, mobile=False):
    """Build a precision-recall curve and its area for every row of a table.

    For each row read from ``table_name`` and each entry of ``quantiles``, the
    tp/fp/fn counters are looked up by the field names provided by ``fields``
    and turned into a (recall, precision) point. Quantiles with ``tp == 0``
    are skipped. The curve always starts with the point (0., 0.), and points
    with NaN or negative recall are dropped.

    Args:
        yt_client: client object exposing ``read_table``.
        table_name: path of the table holding the per-quantile counters.
        mobile: when true, rows are identified by ``fields.app_id`` and
            ``fields.id_type``; otherwise by ``fields.group_id``.

    Returns:
        A list with one dict per input row: the identifying field(s),
        ``fields.pr_curve`` (dict of recall and precision lists) and
        ``fields.aupr``.
    """
    collected = []

    for row in yt_client.read_table(table_name):
        # (recall, precision) pairs; the curve is anchored at the origin.
        curve = [(0., 0.)]

        for quantile in quantiles:
            true_positives = row[fields.get_tp_qnt_field_name(qnt=quantile)]
            if true_positives == 0:
                continue

            precision = calculate_precision(
                tp=true_positives,
                fp=row[fields.get_fp_qnt_field_name(qnt=quantile)],
            )
            recall = calculate_recall(
                tp=true_positives,
                fn=row[fields.get_fn_qnt_field_name(qnt=quantile)],
            )
            curve.append((recall, precision))

        # Keep only points whose recall is a real, non-negative number.
        curve = [point for point in curve if not (np.isnan(point[0]) or point[0] < 0)]
        recalls = [point[0] for point in curve]
        precisions = [point[1] for point in curve]

        if mobile:
            segment_stats = {fields.app_id: row[fields.app_id], fields.id_type: row[fields.id_type]}
        else:
            segment_stats = {fields.group_id: row[fields.group_id]}
        segment_stats[fields.pr_curve] = {fields.recall: recalls, fields.precision: precisions}
        segment_stats[fields.aupr] = calc_aupr(points_precision=precisions, points_recall=recalls)

        collected.append(segment_stats)

    return collected


def get_pr_stats_table(yt_client, input_table, output_table):
    """Compute PR statistics for ``input_table`` and store them in ``output_table``.

    The statistics are computed with ``calculate_pr_stats`` in its default
    (non-mobile) mode, so rows are keyed by ``fields.group_id``. The output
    table is created via ``yt_helpers.create_empty_table`` with ``force=True``
    (presumably replacing an existing table -- confirm against the helper),
    filled with the computed rows and finally sorted by ``fields.group_id``.

    Args:
        yt_client: client object used for reading, writing and sorting.
        input_table: path of the table with per-quantile counters.
        output_table: path of the table to (re)create.
    """
    stats_rows = calculate_pr_stats(yt_client=yt_client, table_name=input_table)

    output_schema = {
        fields.group_id: 'string',
        fields.pr_curve: 'any',
        fields.aupr: 'double',
    }
    table_attributes = {'optimize_for': 'scan'}

    yt_helpers.create_empty_table(
        yt_client=yt_client,
        path=output_table,
        schema=output_schema,
        additional_attributes=table_attributes,
        force=True,
    )

    yt_client.write_table(output_table, stats_rows)
    yt_client.run_sort(output_table, sort_by=fields.group_id)


def get_ci_with_bootstrap(data, statistic=np.median, ci_type='basic', alpha=0.05, n_samples=10000):
    """Estimate a statistic and its bootstrap confidence interval.

    ``n_samples`` resamples of ``len(data)`` elements each are drawn from
    ``data`` with ``np.random.choice`` and ``statistic`` is evaluated on every
    resample. The ``alpha / 2`` and ``1 - alpha / 2`` quantiles of those values
    give the percentile interval; the basic interval reflects them around the
    point estimate.

    Args:
        data: sample accepted by both ``statistic`` and ``np.random.choice``.
        statistic: callable mapping a sample to a number.
        ci_type: ``'basic'`` or ``'percentile'``.
        alpha: significance level; the interval covers ``1 - alpha``.
        n_samples: number of bootstrap resamples.

    Returns:
        Tuple ``(estimate, lower_bound, upper_bound)``.

    Raises:
        AssertionError: if ``ci_type`` is not one of the supported values.
    """
    assert ci_type in {'basic', 'percentile'}, 'Only basic and percentile confidence interval types are supported'

    point_estimate = statistic(data)

    bootstrap_estimates = [
        statistic(np.random.choice(data, size=len(data)))
        for _ in range(n_samples)
    ]
    lower_quantile, upper_quantile = np.quantile(bootstrap_estimates, q=[alpha / 2., 1. - alpha / 2.])

    if ci_type != 'percentile':
        # Basic interval: mirror the bootstrap quantiles around the estimate.
        return point_estimate, 2 * point_estimate - upper_quantile, 2 * point_estimate - lower_quantile

    return point_estimate, lower_quantile, upper_quantile
