from .graph_render_base import SimpleGraphRender

# Columns summed per page label and drawn as one pie each; every entry must be
# a column of the DataFrame handed to iter_pmatch_fails_graphs.
NEEDED_COLUMNS = [
    'TotalPMatchBitsFails',
    'HalfPMatchFails',
    'AllPMatchFails',
]


def generate_graph(task, loc, graph_name, max_slices=10, coverage=0.95):
    """Build a row of pie charts showing how each fail counter splits across experiments.

    One pie is drawn for every column of ``NEEDED_COLUMNS`` whose total in
    ``loc`` is positive. Experiments are listed largest-first, and the tail is
    merged into a single ``'Other'`` slice. The merge happens once
    ``max_slices`` experiments have been listed, or once the listed
    experiments already cover ``coverage`` of the column total.

    :param task: object exposing ``get_exp_name(exp_id)``; used to turn
        ``ExpID`` values into slice labels.
    :param loc: DataFrame containing an ``ExpID`` column and the
        ``NEEDED_COLUMNS`` columns.
    :param graph_name: title of the resulting figure.
    :param max_slices: maximum number of experiments shown individually per pie.
    :param coverage: fraction (0..1) of a column's total after which the
        remaining experiments are lumped into ``'Other'``.
    :return: plotly figure with one pie subplot per column with a positive total.
    """
    import pandas as pd
    import plotly.graph_objs as go

    from plotly.subplots import make_subplots

    def get_data(column):
        """Return a DataFrame with ``ExpID`` (slice label) and ``Value`` columns.

        Only called for columns whose total is positive, so ``all_sum`` is
        never zero here.
        """
        loc_df = loc[['ExpID', column]].sort_values(column, ascending=False)
        all_sum = loc_df[column].sum()
        res = []
        # Running total of the slices already listed; drives the coverage
        # cut-off (it was previously never updated, so the cut-off never fired).
        current_sum = 0
        for i, (_, row) in enumerate(loc_df.iterrows()):
            if i >= max_slices or current_sum / all_sum >= coverage:
                other = loc_df[column].iloc[i:].sum()
                # Skip an empty tail (e.g. only zero-valued experiments remain)
                # instead of drawing a zero-sized "Other" legend entry.
                if other > 0:
                    res.append({'ExpID': 'Other', 'Value': other})
                break
            res.append({
                'ExpID': task.get_exp_name(row.ExpID),
                'Value': row[column],
            })
            current_sum += row[column]
        return pd.DataFrame(res)

    def get_figure(columns_prefiltered):
        """Lay out one pie subplot per column that has a positive total."""
        # Columns with nothing to show would produce an empty/degenerate pie.
        columns = [name for name in columns_prefiltered if loc[name].sum() > 0]
        fig = make_subplots(
            rows=1,
            cols=len(columns),
            specs=[[{"type": "pie"}] * len(columns)],
            subplot_titles=columns,
        )

        for i, name in enumerate(columns):
            frame = get_data(name)
            fig.add_trace(go.Pie(labels=frame.ExpID, values=frame.Value), col=i + 1, row=1)

        fig.update_layout(title=graph_name, autosize=True)
        return fig

    return get_figure(NEEDED_COLUMNS)


def iter_pmatch_fails_graphs(df, task):
    """Yield ``(figure, title)`` pairs, one per distinct ``PageLabel`` in *df*.

    Labels are visited in sorted order. A label is skipped when the combined
    total of all ``NEEDED_COLUMNS`` over its rows is zero.
    """
    for page_label in sorted(df.PageLabel.unique()):
        subset = df[df.PageLabel == page_label]
        total_fails = sum(subset[col].sum() for col in NEEDED_COLUMNS)
        if total_fails == 0:
            continue
        title = 'For label: "{}"'.format(page_label)
        yield generate_graph(task, subset, title), title


class PMatchFailsGraphRender(SimpleGraphRender):
    """Renderer plugin for the per-page-label PMatch fails pie charts.

    Purely declarative: the rendering machinery lives in SimpleGraphRender,
    which presumably reads the class attributes below (verify there).
    """
    # Presumably the human-readable graph title — confirm in SimpleGraphRender.
    Name = 'PMatch Fails By Experiments Graph'
    # Presumably the output file stem — confirm in SimpleGraphRender.
    FileName = 'pmatch_fails_by_exps_graph'
    # Generator yielding (figure, name) pairs. It is stored as a plain function,
    # so how the base class calls it (bound vs. via the class) is not visible
    # here — NOTE(review): confirm it is invoked as Func(df, task).
    Func = iter_pmatch_fails_graphs
