import numpy as np


def find_thresholds(table, segments, needed_recalls, needed_recall=1., constant_for_full_coverage=1., max_iter=100):
    """
    Finds thresholds for classification with desirable classes distribution and coverage.

    The search is iterative. Each segment starts with a target share of the sample
    (``needed_recall * needed_recalls``). For every segment the histogram rows are sorted
    by that segment's column (highest first), and the rows whose cumulative share of
    ``cnt`` is strictly below the target are "taken". The segment's threshold is the
    value of the last (lowest) taken row, or 1.0 if nothing is taken.

    Every row is then assigned to the segment with the largest ``value / threshold``
    ratio. On ties, the first such segment in ``segments`` wins (pandas ``idxmax``).
    Rows whose best ratio is below 1 stay unassigned.

    Some rows are "lost" for a segment: the segment took them, at least one other
    segment also took them, and they were not assigned to it. Their share is
    subtracted from a running recall estimate. The shortfall is then added to the
    targets of the segments that lost rows and have not yet reached their own
    share, in proportion to the share each of them lost.

    The loop stops once the running recall is not below ``needed_recall`` (up to
    ``np.isclose`` tolerance), or after ``max_iter`` iterations.

    Parameters:
        table: pandas DataFrame (should contain column cnt for counts and columns with names of classes)
            a histogram of the probability distribution in the sample.
            NOTE: a ``pdf`` column (``cnt`` normalised to sum to 1) is added to this
            DataFrame in place.
        segments: list
            names of classes
        needed_recalls: array-like of numbers (numpy array, list, ...)
            indicates the desired classes distribution. Non-float input (an integer
            array, a list, ...) is converted to a float array.
        needed_recall: float (default 1.0)
            indicates the desired coverage
        constant_for_full_coverage: float (default 1.0)
            set 0.5 to ensure that final coverage is 100% (use with needed_recall=1)
        max_iter:  uint (default 100)
            maximum number of iterations to find thresholds.

    Returns:
        list: one threshold per segment, in ``segments`` order, multiplied by
        ``constant_for_full_coverage``. With ``max_iter=0`` every threshold is 0.
    """
    needed_recalls = np.asarray(needed_recalls)
    if not np.issubdtype(needed_recalls.dtype, np.floating):
        # Integer storage would silently truncate the fractional weights written into
        # current_needed_recalls below (and reject the in-place division). A plain list
        # cannot be scaled by a float at all.
        needed_recalls = needed_recalls.astype(float)

    # NOTE: mutates the caller's DataFrame (pre-existing behaviour, kept for
    # backward compatibility).
    table['pdf'] = table['cnt'] / float(table['cnt'].sum())

    # Per segment: rows sorted by that segment's value (highest first), together with
    # the cumulative share of the sample covered when taking rows in that order.
    tables = {}
    for segment in segments:
        table_segment = table.sort_values(segment, ascending=False)
        table_segment['cdf'] = table_segment['cnt'].cumsum() / float(table_segment['cnt'].sum())
        tables[segment] = table_segment

    # Share of the sample each segment is currently allowed to take.
    current_recall_thresholds = needed_recall * needed_recalls
    # Weights used to distribute a recall shortfall between segments.
    current_needed_recalls = needed_recalls.copy()

    real_recall = needed_recall
    current_needed_recall = needed_recall
    thresholds = np.zeros(len(segments))

    for _ in range(max_iter):
        taken = [
            tables[segment].loc[tables[segment]['cdf'] < current_recall_thresholds[segment_idx]]
            for segment_idx, segment in enumerate(segments)
        ]

        # Labels of rows taken by at least two segments.
        taken_labels = [set(segment_taken.index) for segment_taken in taken]
        overlapped_indices = set()
        for i, first in enumerate(taken_labels):
            for second in taken_labels[i + 1:]:
                overlapped_indices |= first & second

        # Threshold = value of the lowest taken row; 1.0 when the segment took nothing.
        for segment_idx, segment in enumerate(segments):
            tail = taken[segment_idx].tail(1)[segment]
            thresholds[segment_idx] = tail.item() if len(tail) > 0 else 1.

        # Assign each row to the segment with the largest value/threshold ratio;
        # rows whose best ratio is below 1 are left unassigned.
        scaled = table[segments] / thresholds
        covered = scaled.max(axis='columns') >= 1.
        predictions = scaled[covered].idxmax(axis='columns')
        covered_pdf = table.loc[covered, 'pdf']

        for idx, segment in enumerate(segments):
            is_predicted = predictions == segment
            # Overlapping rows that were not assigned to this segment.
            lost_labels = overlapped_indices.difference(predictions.index[is_predicted.to_numpy()])
            segment_taken = taken[idx]
            lost_mask = segment_taken.index.isin(list(lost_labels))
            if lost_mask.any():
                overlapped_pdf_accumulated = segment_taken.loc[lost_mask, 'pdf'].sum()
                real_recall -= overlapped_pdf_accumulated
                if covered_pdf[is_predicted].sum() >= needed_recalls[idx] * needed_recall:
                    # Segment already reached its share: no extra target for it.
                    current_needed_recalls[idx] = 0.
                else:
                    current_needed_recalls[idx] = overlapped_pdf_accumulated
            else:
                current_needed_recalls[idx] = 0.

        normalizing_sum = current_needed_recalls.sum()
        if not np.isclose(normalizing_sum, 0):
            current_needed_recalls /= normalizing_sum

        if real_recall < needed_recall and not np.isclose(real_recall, needed_recall):
            # Raise the targets by the shortfall, split by the normalised weights.
            recall_diff = needed_recall - real_recall
            current_needed_recall += recall_diff
            current_recall_thresholds += recall_diff * current_needed_recalls
            real_recall = current_needed_recall
        else:
            break

    return list(thresholds * constant_for_full_coverage)
