# coding: utf-8

from collections import defaultdict

from .graph_render_base import SimpleGraphRender

import logging

# Stage hierarchy used to build the timings sunburst.
# Every entry MUST be a tuple: ``(stage, parent)`` for a nested stage, or
# ``(stage,)`` for a root stage. Note the trailing comma on 1-tuples:
# ``('Stage')`` is just a string, which previously turned root stages into
# single-character keys ('P', 'C', 'B', 'S') in PARENTS_MAP.
TIMING_STAGES = (
    ('Bscount', 'BannerMatch'),
    ('ComputeCosts', 'BannerMatch'),
    ('L1', 'BannerMatch'),
    ('L2', 'BannerMatch'),
    ('MatchedBannerPhrases', 'BannerMatch'),
    ('LimitSendingBanners', 'BannerMatch'),
    ('PatternPrematch',),
    ('ContextFinding',),
    ('MatchCommonPhrases', 'ContextFinding'),
    ('MatchBroadPhrases', 'ContextFinding'),
    ('BroadPhrasesWordnet', 'MatchBroadPhrases'),
    ('BroadMatchAdvMachine', 'MatchBroadPhrases'),
    ('BroadMatchOffer', 'MatchBroadPhrases'),
    ('BroadMatchPrefilterContext', 'MatchBroadPhrases'),
    ('ProcessMetagroups', 'ContextFinding'),
    ('BannerMatch',),
    ('ResetPlace', 'BannerMatch'),
    ('InitOrderData', 'BannerMatch'),
    ('MainContextStage', 'BannerMatch'),
    ('InitBannerCategories', 'BannerMatch'),
    ('BannerCategoryMatch', 'BannerMatch'),
    ('BannerPhraseFilter', 'BannerMatch'),
    ('ComputeFeatures', 'BannerMatch'),
    ('AddBannerMxFeatures', 'BannerMatch'),
    ('AddBannerGPUFeatures', 'BannerMatch'),
    ('FillMxFeaturesLog', 'BannerMatch'),
    ('FillBannerMatchAtoms', 'BannerMatch'),
    ('SerializePmatchResponse',),
)

# Reject malformed entries (e.g. a bare string missing its trailing comma)
# before they are silently mis-parsed below.
for _entry in TIMING_STAGES:
    if not isinstance(_entry, tuple) or len(_entry) not in (1, 2):
        raise AssertionError('Bad TIMING_STAGES entry: {!r}'.format(_entry))

# Stage name -> parent stage name; '' marks a root stage.
PARENTS_MAP = {tup[0]: tup[1] if len(tup) == 2 else '' for tup in TIMING_STAGES}

# A repeated stage would silently overwrite its earlier parent.
if len(PARENTS_MAP) != len(TIMING_STAGES):
    raise AssertionError('Duplicate stage in TIMING_STAGES')

# Every non-root parent must itself be a declared stage, otherwise the
# sunburst would reference a node that has no entry of its own.
for key, val in PARENTS_MAP.items():
    if val and val not in PARENTS_MAP:
        raise AssertionError('Unknown parent {!r} for stage {!r}'.format(val, key))


def get_children_map():
    """Invert PARENTS_MAP into a mapping of parent stage -> child stages.

    Root stages (parent '') are collected under the '' key. Children are
    listed in the iteration order of PARENTS_MAP. The result is a
    ``defaultdict(list)``, so indexing a stage with no children gives [].
    """
    inverse = defaultdict(list)
    for stage in PARENTS_MAP:
        inverse[PARENTS_MAP[stage]].append(stage)
    return inverse


# Parent stage -> list of child stages, computed once at import time.
CHILDREN_MAP = get_children_map()


class ChildBiggerParentException(Exception):
    """Raised when the summed value of a stage's children exceeds the stage's own value."""


def iter_stage_timings_graph(df, task):
    """Yield a plotly sunburst of average stage durations for each HardHit group.

    ``df`` is read through its ``Stage``, ``SumDuration``, ``Count`` and
    ``HardHit`` columns (NOTE(review): schema inferred from the accesses
    below). A ``Parent`` column derived from PARENTS_MAP is written into
    ``df`` in place; stages not listed in PARENTS_MAP get '' (root node).

    For HardHit False ('< 100 ms') and True ('> 100ms'), if that group has
    rows, yields ``(figure, title)``. If building a figure raises, the error
    is reported via ``task.set_info`` and that group is skipped.
    """
    import plotly.graph_objs as go

    # Fill missing parents on the mapped Series itself. The former chained
    # assignment (df.Parent[mask] = '') emits SettingWithCopyWarning and,
    # under pandas copy-on-write, does not modify df at all.
    df['Parent'] = df['Stage'].map(PARENTS_MAP).fillna('')

    def relax(d, col):
        """Raise each stage's ``col`` to its children's total, in place.

        Repeats full passes until a pass changes nothing, so increases
        propagate up the stage tree. Raises ChildBiggerParentException when
        the children total exceeds the stage's value by ``threshold`` or more.
        """
        threshold = 500
        updated = True
        while updated:
            updated = False
            for _, row in d.iterrows():
                stage = row['Stage']
                children_sum = d[d['Parent'] == stage][col].sum()

                if row[col] < children_sum:
                    logging.debug('UPDATE: %s, %s, %f', stage, row[col], children_sum)
                    difference = children_sum - row[col]
                    if difference >= threshold:
                        raise ChildBiggerParentException('ChildrenSum({}) - row[{}]({}) = {} >= {} for {}'.format(
                            children_sum,
                            col,
                            row[col],
                            difference,
                            threshold,
                            stage
                        ))
                    updated = True
                    d.loc[d['Stage'] == stage, col] = children_sum

    def get_graph(d):
        """Compute per-stage averages on ``d`` and build the Sunburst trace."""
        value_col = 'AvgDuration'
        d[value_col] = d['SumDuration'] / d['Count']
        # With branchvalues='total' a parent's value is meant to cover its
        # children, so parents are first lifted to at least their children sum.
        relax(d, value_col)
        logging.debug('\n%s', d)

        return go.Sunburst(
            values=d[value_col],
            labels=d['Stage'],
            parents=d['Parent'],
            branchvalues='total',
        )

    layout = go.Layout(
        autosize=False,
        width=1000,
        height=1000,
    )
    for hard_hit, title in zip((False, True), ('< 100 ms', '> 100ms')):
        # .copy(): get_graph adds a column and relax() writes with .loc, so
        # work on an independent frame rather than a filtered slice of df.
        d = df[df['HardHit'] == hard_hit].copy()
        if d.empty:
            continue
        try:
            figure = go.Figure(data=get_graph(d), layout=layout)
        except Exception as e:
            # Best-effort: report the failure for this group and keep going.
            task.set_info('Error with graph: `{}`: {}'.format(title, e))
            continue
        yield figure, title


class StageTimingsGraphRender(SimpleGraphRender):
    """Render that plugs iter_stage_timings_graph into SimpleGraphRender.

    How ``Name``, ``FileName`` and ``Func`` are consumed is defined by the
    base class (NOTE(review): not visible here).
    """

    # Human-readable graph name.
    Name = 'StageTimings Graph'
    # Base name for the produced output.
    FileName = 'stage_timings_graph'
    # NOTE(review): stored as a plain function; if the base class calls it
    # through an instance (self.Func) it would be bound as a method — confirm.
    Func = iter_stage_timings_graph
