# coding: utf-8

from collections import defaultdict
from itertools import product
from math import isnan
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type

from pandas import DataFrame
from plotly.graph_objs import Figure

from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs import YabsServerRankingGraphs
from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.aggregations import AggregationPrefixes
from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.aggregations import Percentiles

from .graph_render_base import GraphRenderBase
from .graph_render_base import TGraphRenderBase
from .stage_stats import StagesSchemas


# Column that transform_dataframe adds with each row's index in ``Stages``;
# it is the primary sort key of the transformed frame.
ORDER_COLUMN = 'StageOrder'
# Suffix of the per-stage total column: funnel_graph divides every value by
# ``aggregation + OVERALL_COLUMN`` (of the first row and of the current row)
# to build its hover percentages.
OVERALL_COLUMN = 'Overall'


# Display order of the ranking stages. transform_dataframe sorts rows by the
# position of their StageName in this tuple (through InvertStages) and raises
# KeyError for any StageName that is not listed here.
Stages = (
    'InitPhraseInfos',
    'Bscount',
    'ComputeCosts',
    'L1',
    'L2',
    'MatchedBannerPhrases',
    'LimitSendingBanners',
    'PMatchResult',
    'L3Candidates',
    'L3',
    'BannerSelectBeforePrefilter',
    'BannerSelectAfterPrefilter',
    'BannerSelectCalculateStatistics',
    'BannerSelectDone',
)


# Display names for numeric PlaceSelect values, used for section titles and
# graph ids in StageGraphsRender.render. Values missing from this table are
# shown as 'PlaceSelect: #<n>' instead.
PLACE_SELECT_TO_NAME = {
    0: 'Normal',
    2: 'OrderFair',
    3: 'Ppc',
    4: 'PpcPage',
    5: 'Premium',
    6: 'Dynamic',
    7: 'Rank',
    8: 'Rumain',
    9: 'PpcMkb',
    10: 'Vote',
    12: 'GeoPpc',
    13: 'GeoPremium',
    14: 'GeoRange',
    17: 'FakePremiumFull',
    19: 'FakePpcFull',
    20: 'PremiumExtra',
    23: 'Performance',
    24: 'PpcSmart',
    25: 'PpcSmartCluster',
    26: 'Distribution',
    27: 'SmartWithMedia',
    28: 'Creative',
    29: 'MediaSearch',
    30: 'DistributionFake',
    31: 'ReachMedia',
    32: 'Trafaret',
    33: 'RankTrafaret',
    34: 'MediaRankTrafaret',
    35: 'AdvertisingZen',
    36: 'Gallery',
    37: 'GeoTrafaret',
}

# Stage name -> its position in ``Stages``; transform_dataframe uses it as the
# sort key so rows follow the declared stage order.
InvertStages = dict(zip(Stages, range(len(Stages))))


def get_linear_color(base, i):
    """Lighten a hex colour by ``i`` fixed steps.

    ``base`` is ``'RRGGBB'`` or ``'#RRGGBB'``. Every channel is increased by
    ``12 * i`` and clamped at 255; the result is an upper-case ``'#RRGGBB'``.
    """
    digits = base[1:] if base.startswith('#') else base
    red, green, blue = (int(digits[start:start + 2], 16) for start in (0, 2, 4))
    shift = 12 * i
    return '#' + ''.join(
        '{:02X}'.format(min(channel + shift, 255)) for channel in (red, green, blue)
    )


def generate_graph(df: DataFrame, trace_generator, title, schema=None, trace_generator_kwargs=None, iter_buttons=None) -> Optional[Tuple[Figure, str]]:
    """Build a plotly figure whose traces are switched per schema group.

    ``trace_generator(df, column, idx, **trace_generator_kwargs)`` is called for
    every column of every group in ``schema`` and must yield
    ``(trace, trace_id)`` pairs. The traces of the first group that produced
    anything are visible initially; a button menu switches between groups.

    Args:
        df: Stage statistics, passed through to ``trace_generator``.
        trace_generator: Generator function yielding ``(trace, trace_id)`` pairs.
        title: Returned unchanged next to the figure.
        schema: Object exposing ``groups`` (each with ``name`` and ``columns``);
            defaults to ``StagesSchemas.BannerCounts``.
        trace_generator_kwargs: Extra keyword arguments for ``trace_generator``.
        iter_buttons: Optional callable
            ``(slice_buttons, groups_indexes, schema) -> iterable of menus``
            producing the complete ``updatemenus`` list; when omitted only the
            group-switch menu is used.

    Returns:
        ``(figure, title)``, or ``None`` when no group produced a single trace.
    """
    import plotly.graph_objs as go
    # NOTE(review): shadows the module-level StagesSchemas import; presumably the
    # same module — confirm before removing.
    from sandbox.projects.yabs.ranking_group.YabsServerRankingGraphs.lib.stage_stats import StagesSchemas
    if schema is None:
        schema = StagesSchemas.BannerCounts
    if trace_generator_kwargs is None:
        trace_generator_kwargs = {}

    traces = []
    # group name -> {trace id -> position of that trace in ``traces``}
    groups_indexes = {}
    default_group = None
    for group in schema.groups:
        indexes = {}
        for idx, column in enumerate(group.columns):
            for trace, trace_id in trace_generator(df, column, idx, **trace_generator_kwargs):
                indexes[trace_id] = len(traces)
                traces.append(trace)
        if indexes:
            groups_indexes[group.name] = indexes
            if default_group is None:
                # The first group with traces is the one shown initially.
                default_group = group

    if not groups_indexes:
        # Nothing to draw for any group; callers check for None.
        return None

    def get_title(group):
        return '<b>{}</b>'.format(group.name)

    def get_visible(group):
        # Visibility mask over all traces that shows only ``group``'s traces.
        visibles = [False] * len(traces)
        for index in groups_indexes[group.name].values():
            visibles[index] = True
        return visibles

    for trace, value in zip(traces, get_visible(default_group)):
        trace.visible = value

    layout = go.Layout(
        title=get_title(default_group),
        yaxis=dict(
            visible=True,
            ticks='outside',
        ),
        boxmode='group',
        autosize=True,
    )

    # One button per group that has traces, in schema order; button 0 is the
    # default group, matching ``active=0``.
    slice_buttons = dict(
        type='buttons',
        active=0,
        buttons=[{
            'label': group.name,
            'method': 'update',
            'args': [{
                'visible': get_visible(group),
            }, {
                'title': get_title(group),
            }]
        } for group in schema.groups if group.name in groups_indexes
        ],
    )
    if iter_buttons is None:
        layout['updatemenus'] = [slice_buttons]
    else:
        layout['updatemenus'] = list(iter_buttons(slice_buttons, groups_indexes, schema))

    return go.Figure(data=traces, layout=layout), title


def funnel_graph(df, column, idx, aggregation):
    """Yield at most one ``(Funnel trace, trace id)`` pair for ``column``.

    The trace plots ``df[aggregation + column.table_column]`` against
    ``df.StageName``. Nothing is yielded when that column is missing, empty or
    entirely zero. The hover text also shows each value as a share of the first
    row's ``aggregation + OVERALL_COLUMN`` total and of its own row's total
    (1.0 when the corresponding total is zero).

    ``idx`` is part of the generate_graph trace-generator protocol and unused.
    """
    import plotly.graph_objs as go

    value_column = aggregation + column.table_column
    if value_column not in df.columns:
        return
    values = df[value_column]
    if values.empty or not any(values != 0):
        return

    stage_names = df.StageName
    stage_totals = df[aggregation + OVERALL_COLUMN]
    # The first row's total is the shared denominator for "of initial total".
    initial_total = stage_totals.iloc[0]
    shares = []
    for value, stage_total in zip(df[value_column], stage_totals):
        share_of_initial = float(value) / initial_total if initial_total != 0 else 1.
        share_of_stage = float(value) / stage_total if stage_total != 0 else 1.
        shares.append((share_of_initial, share_of_stage))

    yield go.Funnel(
        name=column.graph_column,
        y=stage_names,
        x=values,
        visible=False,
        hovertemplate=(
            '(%{y}, %{x})<br>'
            '%{percentInitial} of initial<br>'
            '%{percentPrevious} of previous<br>'
            '%{customdata.0:.01%} of initial total<br>'
            '%{customdata.1:.01%} of current stage total'
        ),
        customdata=shares,
    ), column.graph_column


def percentile_graph(df, column, idx, x=None, delete_first_zeros=False):
    """Yield one ``(Scatter trace, trace id)`` line per percentile of ``column``.

    For each entry of ``Percentiles`` the column
    ``percentile.ColumnPrefix + column.table_column`` is plotted against ``x``.
    Percentiles whose column is missing are skipped. The trace id is
    ``'<graph_column> <percentile label>'``, which percentile_iter_buttons
    relies on.

    Args:
        df: Stage statistics.
        column: Schema column exposing ``table_column`` and ``graph_column``.
        idx: Position of ``column`` in its group; selects the base line colour.
        x: X values; defaults to ``df.StageName``.
        delete_first_zeros: Drop leading points whose value is None, NaN or 0.
    """
    import plotly.graph_objs as go
    import plotly.express as px
    import math

    if x is None:
        x = df.StageName
    palette = px.colors.qualitative.Dark24
    # Wrap around the palette so a group with more columns than colours does
    # not raise IndexError.
    base_color = palette[idx % len(palette)]

    is_full = True
    for i, percentile in enumerate(Percentiles):
        if not is_full:
            break
        column_name = percentile.ColumnPrefix + column.table_column
        if column_name not in df.columns:
            continue
        loc = df[column_name]
        if all(loc == 0):
            # NOTE(review): an all-zero percentile is still drawn (unless it is
            # the first one) but ends the loop; presumably later percentiles are
            # assumed to be zero too — confirm against the Percentiles order.
            is_full = False
            if i == 0:
                break
        if delete_first_zeros:
            x_vec = []
            y_vec = []
            for row_x, row_y in zip(x, loc):
                # Start collecting at the first value that is not None/NaN/0,
                # then keep everything after it.
                if y_vec or (row_y is not None and not math.isnan(row_y) and row_y != 0):
                    y_vec.append(row_y)
                    x_vec.append(row_x)
        else:
            x_vec = x
            y_vec = loc
        if len(x_vec) == 0:
            continue
        trace_id = column.graph_column + ' ' + percentile.Label
        yield go.Scatter(
            name=trace_id,
            mode='lines',
            legendgroup=column.graph_column,
            x=x_vec,
            y=y_vec,
            line_color=get_linear_color(base_color, i),
        ), trace_id


def percentile_iter_buttons(slice_buttons, groups_indexes, schema):
    """Yield the ``updatemenus`` for a percentile figure built by generate_graph.

    For every schema group that produced traces, yield a horizontal menu with
    one toggle button per percentile that has traces in that group; a button
    switches those traces between visible and legend-only. Finally yield
    ``slice_buttons`` itself, after patching each group button so that
    selecting a group shows only that group's percentile menu plus the
    group-switch menu.

    Args:
        slice_buttons: Group-switch menu built by generate_graph; its buttons
            are in the same order as the groups iterated here. Mutated in place.
        groups_indexes: ``{group name: {trace id: trace index}}``.
        schema: Object exposing ``groups`` with ``name`` and ``columns``.
    """
    # One percentile menu per group, followed by the group-switch menu.
    menu_count = len(groups_indexes)
    group_idx = 0
    for group in schema.groups:
        if group.name not in groups_indexes:
            continue
        group_traces = groups_indexes[group.name]

        traces_by_label = defaultdict(set)
        for column, percentile in product(group.columns, Percentiles):
            trace_id = column.graph_column + ' ' + percentile.Label
            if trace_id in group_traces:
                traces_by_label[percentile.Label].add(group_traces[trace_id])

        buttons = []
        for percentile in Percentiles:
            trace_indices = sorted(traces_by_label.get(percentile.Label, ()))
            if not trace_indices:
                # Plotly's restyle applies an empty trace list to *every* trace,
                # so a button with nothing to toggle would show/hide the traces
                # of all groups.
                continue
            buttons.append(dict(
                label=percentile.Label,
                method='restyle',
                args=(
                    dict(
                        visible=True,
                    ),
                    trace_indices,
                ),
                args2=(
                    dict(
                        visible='legendonly',
                    ),
                    list(trace_indices),
                ),
            ))

        # Menu i is visible only for its own group; the trailing entry is the
        # group-switch menu itself, which always stays visible.
        menu_visibility = [False] * menu_count + [True]
        menu_visibility[group_idx] = True
        relayout_args = slice_buttons['buttons'][group_idx]['args'][1]
        for menu_idx, visible in enumerate(menu_visibility):
            relayout_args['updatemenus[{}].visible'.format(menu_idx)] = visible

        yield dict(
            type='buttons',
            direction='right',
            x=0.5,
            xanchor='left',
            y=1.15,
            yanchor='bottom',
            active=0,
            visible=group_idx == 0,
            buttons=buttons,
        )

        group_idx += 1

    yield slice_buttons


def transform_dataframe(df: DataFrame, main_stages=None):
    """Collapse stage rows to one row per stage (and PlaceSelect), ordered by ``Stages``.

    Rows without a PlaceSelect are grouped by ``main_stages``; rows with one are
    grouped by ``main_stages`` plus ``PlaceSelect``. Within each group the last
    non-NaN value of every non-key column wins. The result gains an
    ``ORDER_COLUMN`` holding each stage's index in ``Stages`` and is sorted by
    it and then by ``PlaceSelect``.

    Args:
        df: Raw stage statistics; must contain ``StageName`` and ``PlaceSelect``.
        main_stages: Grouping columns; defaults to ``['StageName']``.

    Returns:
        The merged and sorted DataFrame.

    Raises:
        KeyError: if a ``StageName`` is not listed in ``Stages``.
    """
    import pandas

    # Fresh list every call: no shared mutable default, and a tuple argument is
    # not mistaken by groupby for a single composite key.
    main_stages = ['StageName'] if main_stages is None else list(main_stages)
    place_keys = main_stages + ['PlaceSelect']

    def transform(group, add_place_select=False, grouping_by=()):
        # Merge the group's rows into a single-row frame, keeping the last
        # non-NaN value per column and dropping the grouping keys.
        merged = {}
        for _, row in group.iterrows():
            # Series.iteritems() was removed in pandas 2.0; items() works in all versions.
            for key, val in row.items():
                if key in grouping_by:
                    continue
                if isinstance(val, float) and isnan(val):
                    continue
                merged[key] = [val]
        if add_place_select:
            merged['PlaceSelect'] = [None]
        return pandas.DataFrame(merged)

    df = pandas.concat([
        df[df.PlaceSelect.isna()].groupby(main_stages).apply(transform, add_place_select=True, grouping_by=main_stages).reset_index(),
        df[df.PlaceSelect.notna()].groupby(place_keys).apply(transform, grouping_by=place_keys).reset_index(),
    ], sort=False)
    df[ORDER_COLUMN] = df['StageName'].apply(lambda x: InvertStages[x])
    df.sort_values([ORDER_COLUMN, 'PlaceSelect'], inplace=True)
    return df


def get_graphs(render: Type[TGraphRenderBase], graph_name: str, graphs: List[Tuple[str, Figure]]):
    """Render ``graphs`` as consecutive HTML snippets, each under an ``<h2>`` caption.

    Each figure is converted with ``render.figure_to_str`` using the id
    ``'{graph_name}_{position}'``. A falsy ``graphs`` yields ``''``.
    """
    chunks = []
    for position, (caption, figure) in enumerate(graphs or ()):
        chunks.append(f'<h2>{caption}</h2>\n')
        chunks.append(render.figure_to_str(f'{graph_name}_{position}', figure) + '\n')
    return ''.join(chunks)


class StagesGraphsHolder(object):
    """Collects rendered figures, split by PlaceSelect.

    Figures added with ``select_name=None`` go to ``main_graphs``; all others
    are grouped per select name in ``graphs_by_place_select``.
    """

    def __init__(self):
        # (name, figure) pairs that are not tied to a particular PlaceSelect.
        self.main_graphs: List[Tuple[str, Figure]] = []
        # select name -> (name, figure) pairs, in the order they were added.
        self.graphs_by_place_select: Dict[str, List[Tuple[str, Figure]]] = defaultdict(list)

    def add(self, select_name: Optional[str], name, fig: Figure):
        """Store ``(name, fig)`` under ``select_name``; ``None`` means the main section."""
        if select_name is None:
            self.main_graphs.append((name, fig))
        else:
            self.graphs_by_place_select[select_name].append((name, fig))


class StageGraphsRender(GraphRenderBase):
    """Renders the per-stage funnel and percentile graphs as an HTML fragment."""

    Name = 'Stages Graph'
    FileName = 'stages_stats_graph'

    @classmethod
    def render(cls, task: YabsServerRankingGraphs, df: DataFrame) -> str:
        """Return HTML with the stage graphs, overall and per PlaceSelect.

        ``df`` is first normalised with transform_dataframe. For every distinct
        PlaceSelect (missing values form the main section) three figures are
        generated: an average banner count funnel, banner count percentiles and
        average value percentiles. Figures for which generate_graph returns
        None are skipped, and sections without figures are omitted.
        ``task`` is not used here.
        """
        df = transform_dataframe(df)
        holder = StagesGraphsHolder()

        for select in df.PlaceSelect.unique():
            if select is None or isnan(select):
                subset = df[df.PlaceSelect.isna()]
                select_name = None
                postfix = ''
            else:
                subset = df[df.PlaceSelect == select]
                select_name = PLACE_SELECT_TO_NAME.get(select, f'PlaceSelect: #{int(select)}')
                postfix = f' For {select_name}'
            if subset.empty:
                continue

            # (graph name, trace generator, extra generate_graph keyword arguments)
            graph_specs = (
                (
                    'Funnel Avg Banners Count',
                    funnel_graph,
                    dict(trace_generator_kwargs=dict(aggregation=AggregationPrefixes.AVG)),
                ),
                (
                    'Percentiles Banners Count',
                    percentile_graph,
                    dict(iter_buttons=percentile_iter_buttons),
                ),
                (
                    'Percentiles average values (SourceCost, RealCost, CTR, PCTR, ABConversionCostCoef, BidCorrection, L1Value, L2Value, L3Value AS Value)',
                    percentile_graph,
                    dict(
                        schema=StagesSchemas.AverageValues,
                        iter_buttons=percentile_iter_buttons,
                        trace_generator_kwargs=dict(delete_first_zeros=True),
                    ),
                ),
            )
            for graph_name, trace_generator, options in graph_specs:
                generated = generate_graph(subset, trace_generator, graph_name + postfix, **options)
                if generated is not None:
                    holder.add(select_name, graph_name, generated[0])

        parts = []
        if holder.main_graphs:
            parts.append('<h1>Stages Graph</h1>\n')
            parts.append(get_graphs(cls, 'stages_stats', holder.main_graphs))
        if holder.graphs_by_place_select:
            parts.append('<h1>Stages Graphs By PlaceSelect</h1>\n')
            for select_name, graphs in holder.graphs_by_place_select.items():
                parts.append(f'<details><summary>{select_name}</summary>')
                parts.append(get_graphs(cls, f'stages_graph_by_{select_name}', graphs))
                parts.append('</details>\n')
        return ''.join(parts)
