# -*- coding: utf-8 -*-

import re
from xml.sax.saxutils import escape as escape_html_str

import sandbox.sandboxsdk.parameters as sdk_parameters

from sandbox.projects import resource_types
from sandbox.sandboxsdk.task import SandboxTask
from sandbox.sandboxsdk.errors import SandboxTaskFailureError


class CompareProfileStats(SandboxTask):
    """Sandbox task that diffs two PROFILE_STAT resources into an HTML report.

    The report is written to a PROFILE_STATS_DIFF resource
    ('profile_stats_diff.html') that is created when the task is enqueued.
    """

    class ResourcesInfoField(sdk_parameters.SandboxInfoParameter):
        name = 'resources_info_field'
        description = 'Task resources'

    # Baseline profile: becomes the "1st" column of the report.
    class ProfileStats1Id(sdk_parameters.ResourceSelector):
        name = 'profile_stats1_id'
        description = 'Profile stats 1st'
        resource_type = resource_types.PROFILE_STAT
        required = True

    # Compared profile: becomes the "2nd" column of the report.
    class ProfileStats2Id(sdk_parameters.ResourceSelector):
        name = 'profile_stats2_id'
        description = 'Profile stats 2nd'
        resource_type = resource_types.PROFILE_STAT
        required = True

    # When set, rows with a non-zero absolute diff get a background colour.
    class ColoredReport(sdk_parameters.SandboxBoolParameter):
        name = 'colored_report'
        description = 'Colored report'

    # When set, 2nd-profile times are NOT rescaled to the 1st profile's sample count.
    class DontNormalize(sdk_parameters.SandboxBoolParameter):
        name = 'dont_normalize'
        description = 'Don\'t normalize ticks'

    # When set, one warning block before the 'Modules:' section is tolerated.
    class IgnoreWarnings(sdk_parameters.SandboxBoolParameter):
        name = 'ignore_warnings'
        description = 'Ignore warnings'

    type = 'COMPARE_PROFILE_STATS'
    input_parameters = [
        ResourcesInfoField,
        ProfileStats1Id,
        ProfileStats2Id,
        ColoredReport,
        DontNormalize,
        IgnoreWarnings
    ]

    def on_enqueue(self):
        """Create the output report resource and remember its id in the context."""
        SandboxTask.on_enqueue(self)
        resource = self._create_resource(self.descr, 'profile_stats_diff.html', resource_types.PROFILE_STATS_DIFF)
        self.ctx['out_resource_id'] = resource.id

    def on_execute(self):
        """Parse both input profiles, diff them and write the HTML report."""
        stats1_r = self._read_resource(self.ctx[self.ProfileStats1Id.name])
        stats2_r = self._read_resource(self.ctx[self.ProfileStats2Id.name])

        normalize = not self.ctx[self.DontNormalize.name]
        ignore_warnings = self.ctx[self.IgnoreWarnings.name]

        # Plain 'r': the previous 'ro' is not a valid open() mode (the 'o'
        # flag does not exist and is rejected by Python 3).
        with open(stats1_r.abs_path(), 'r') as stats1_f:
            stats1 = _parse_stats(stats1_f, ignore_warnings)
        with open(stats2_r.abs_path(), 'r') as stats2_f:
            stats2 = _parse_stats(stats2_f, ignore_warnings)

        # _calc_stats_diff always returns a _Diff object, so the report is
        # always written.
        stats_diff = _calc_stats_diff(stats1, stats2, normalize)

        stats_diff_r = self._read_resource(self.ctx['out_resource_id'], sync=False)
        colored = self.ctx[self.ColoredReport.name]
        with open(stats_diff_r.abs_path(), 'w') as stats_diff_f:
            stats_diff_f.write(_dump_stats_diff(stats_diff, colored, normalize))


def _ignore_warning(stats_f):
    line = stats_f.next()
    if not line.startswith('!!!!WARNING: '):
        raise UnexpectedLineError(line)
    stats_f.next()
    stats_f.next()


def _parse_stats(stats_f, ignore_warnings):
    """Parse one complete profile report from the line iterator ``stats_f``.

    Sections are read in this order: the header (`_parse_stats_top`), then
    'Modules:', 'Hot functions:' and 'Hotspots:' (`_parse_stats_impl`).

    If the modules section cannot be parsed and ``ignore_warnings`` is true,
    a warning block is skipped with `_ignore_warning` and the modules section
    is parsed again. Otherwise the UnexpectedLineError propagates.

    Returns:
        _Stats built from the header dict and the three section dicts.
    """
    header = _parse_stats_top(stats_f)
    total_samples = header['samples'][0]

    try:
        modules = _parse_stats_impl(stats_f, _ModuleStat, total_samples)
    except UnexpectedLineError:
        if not ignore_warnings:
            raise
        # NOTE(review): the line that failed the 'Modules:' header check has
        # already been consumed, so _ignore_warning starts on the line after
        # it. Confirm this matches the real warning layout.
        _ignore_warning(stats_f)
        modules = _parse_stats_impl(stats_f, _ModuleStat, total_samples)

    functions = _parse_stats_impl(stats_f, _FunctionStat, total_samples)
    hotspots = _parse_stats_impl(stats_f, _HotspotStat, total_samples)
    return _Stats(header, modules, functions, hotspots)


def _calc_stats_diff(stats1st, stats2nd, normalize):
    """Compute the difference between two parsed profiles.

    Args:
        stats1st: baseline _Stats.
        stats2nd: _Stats compared against the baseline.
        normalize: if true, every time in ``stats2nd``'s sections is scaled by
            samples(1st) / samples(2nd) and truncated to int, so both profiles
            are compared as if they had the same sample count. The scaling is
            applied to shallow copies; ``stats2nd`` itself is not modified.

    Returns:
        _Diff whose ``top`` maps ``(key, index)`` to
        ``(value1st, value2nd, value2nd - value1st, percent diff or None)``.
        Only header keys present in both profiles are included, and the
        percent is None when ``value1st`` is zero. ``modules``,
        ``functions`` and ``hotspots`` are lists of _OneDiff.
    """
    import copy

    top = {}
    for key, (value1st, index) in stats1st.top.items():
        if key not in stats2nd.top:
            continue
        value2nd = stats2nd.top[key][0]
        delta = value2nd - value1st
        percent = 100. * delta / value1st if value1st else None
        top[(key, index)] = (value1st, value2nd, delta, percent)

    modules2nd = stats2nd.modules
    functions2nd = stats2nd.functions
    hotspots2nd = stats2nd.hotspots

    if normalize:
        norm_coef = 1. * stats1st.top['samples'][0] / stats2nd.top['samples'][0]

        def scaled(section):
            # Copy each stat so the caller's stats2nd objects keep their raw times.
            result = {}
            for key, stat in section.items():
                stat = copy.copy(stat)
                stat.time = int(stat.time * norm_coef)
                result[key] = stat
            return result

        modules2nd = scaled(modules2nd)
        functions2nd = scaled(functions2nd)
        hotspots2nd = scaled(hotspots2nd)

    modules = _calc_diff_impl(stats1st.modules, modules2nd)
    functions = _calc_diff_impl(stats1st.functions, functions2nd)
    hotspots = _calc_diff_impl(stats1st.hotspots, hotspots2nd)

    return _Diff(top, modules, functions, hotspots)


def _dump_stats_diff(stats_diff, colored, normalized):
    """Render a _Diff as a standalone HTML document.

    The page contains four tables: the header counters (in file order), then
    modules, hot functions and hotspots. Section rows are sorted by absolute
    time diff, then by the 2nd and 1st stat's ``sort_key()``.

    Args:
        stats_diff: _Diff produced by _calc_stats_diff.
        colored: if true, each section row gets the background colour returned
            by _get_color for its absolute diff.
        normalized: only changes the caption of the 2nd-profile column.

    Returns:
        The HTML document as one string.
    """
    lrrrr = [0, 2, 2, 2, 2]
    llrrrr = [0, 0, 2, 2, 2, 2]
    caption2nd = '2nd [normalized]\n(% of total)' if normalized else '2nd\n(% of total)'
    stat_captions = ['1st\n(% of total)', caption2nd, 'Diff\n(% of total)', 'Diff, %']

    parts = ['<html>\n<body>\n', '<br>\n', _dump_top_table(stats_diff.top, lrrrr)]
    parts.append(_dump_diff_section(
        'Modules:', ['Module'] + stat_captions, lrrrr, stats_diff.modules, colored,
        lambda one: [one.first.module or one.second.module]))
    parts.append(_dump_diff_section(
        'Hot functions:', ['Function'] + stat_captions, lrrrr, stats_diff.functions, colored,
        lambda one: [one.first.function or one.second.function]))
    parts.append(_dump_diff_section(
        'Hotspots:', ['Address', 'Place'] + stat_captions, llrrrr, stats_diff.hotspots, colored,
        lambda one: [one.first.address or one.second.address,
                     one.first.place or one.second.place]))
    parts.append('<br>\n')
    parts.append('</body>\n</html>')
    return ''.join(parts)


def _dump_top_table(top, aligns):
    """Render the header-counter table, rows ordered by their original index."""
    rows = [_begin_table(), _make_htr(['', '1st', '2nd', 'Diff', 'Diff, %'], aligns)]
    for key in sorted(top, key=lambda k: k[1]):
        name, index = key
        value1st, value2nd, delta, percent = top[key]
        rows.append(_make_tr([str(index) + ' ' + name, value1st, value2nd, _format_abs_diff(delta),
                              _format_percent(percent) if percent is not None else ''], aligns))
    rows.append(_end_table())
    return ''.join(rows)


def _dump_diff_section(heading, captions, aligns, diffs, colored, label_cells):
    """Render one _OneDiff section: an <h3> heading followed by its table.

    ``label_cells(one)`` returns the leading identifying cells of a row; the
    four stat/diff cells are appended after them.
    """
    rows = ['<br>\n<h3>' + heading + '</h3>', _begin_table(), _make_htr(captions, aligns)]
    for one in sorted(diffs, key=lambda x: (x.abs, x.second.sort_key(), x.first.sort_key())):
        cells = label_cells(one) + [_format_stat(one.first), _format_stat(one.second),
                                     _format_complex_diff(one), _format_percent(one.rel)]
        rows.append(_make_tr(cells, aligns, colored and _get_color(one.abs)))
    rows.append(_end_table())
    return ''.join(rows)


def _get_color(value):
    if value:
        if value < 0:
            return '#ccffcc'
        elif value > 0:
            return '#ffcccc'
    return None


def _format_percent(value):
    if type(value) == float:
        return  '{0:+.2f}%'.format(value)
    elif type(value) == int:
        return '{0:+d}'.format(value)
    else:
        return value


def _format_abs_diff(value):
    if type(value) == float:
        return '{0:+.2f}'.format(value)
    elif type(value) == int:
        return '{0:+d}'.format(value)
    else:
        return value


def _format_complex_diff(diff):
    return diff.abs and '{0}\n({1:+.4f}%)'.format(_format_abs_diff(diff.abs), diff.abs_percent) or ''


def _format_stat(one_stat):
    return one_stat.time and '{0}\n({1:.3f}%)'.format(one_stat.time, one_stat.percent) or ''


def _format_cell(value, align=0, tag='td'):
    align = ['left', 'center', 'right'][align]
    if value == '':
        value = '&nbsp;'
    else:
        value = escape_html_str(str(value))
    return '<{2} align={1}>{0}</{2}>'.format(value, align, tag)


def _make_tr(cells, aligns, color=None):
    if color:
        tag = '<tr style="background-color: {0};">'.format(color)
    else:
        tag = '<tr>'

    return tag + ''.join([_format_cell(cell, aligns[i] if aligns else 0) for i, cell in enumerate(cells)]) + '</tr>\n'


def _make_htr(cells, aligns):
    return '<tr>' + ''.join([_format_cell(cell, aligns[i] if aligns else 0, 'th') for i, cell in enumerate(cells)]) + '</tr>\n'


def _begin_table():
    return '<table border=1 cellspacing=0 cellpadding=4 style="font-family: monospace;">\n'


def _end_table():
    return '</table>\n'


def _parse_stats_top(stats_f):
    stats_top = {}

    index = 0

    while True:
        line = stats_f.next().strip()
        if not line: break
        k, v = line.split(':')
        try: v = int(v)
        except ValueError:
            try: v = float(v)
            except: raise UnexpectedLineError(line)
        stats_top[k] = (v, index)
        index += 1

    while True:
        line = stats_f.next().strip()
        if not line: break
        parts = line.split()
        for i in xrange(len(parts)/2):
            k = parts[2*i].strip()[:-1]
            v = int(parts[2*i+1])
            stats_top[k] = (v, index)
            index += 1

    return stats_top


def _calc_diff_impl(stats1st, stats2nd):
    """Pair entries of two stat dicts by key into a list of _OneDiff.

    A key present on only one side is paired with None on the other side.
    The list follows set-iteration (i.e. unspecified) order.
    """
    paired = []
    for shared_key in set(stats1st.keys()) | set(stats2nd.keys()):
        paired.append(_OneDiff(stats1st.get(shared_key), stats2nd.get(shared_key)))
    return paired


def _parse_stats_impl(stats_f, cls, total):
    line = stats_f.next().strip()
    if line != cls.head:
        raise UnexpectedLineError(line)

    stats = {}
    r = re.compile(cls.re)
    while True:
        line = stats_f.next().strip()
        if not line: break
        m = r.match(line)
        if m:
            init_params = m.groupdict().copy()
            init_params['total'] = total
            stat = cls(**init_params)
            stats[stat.key()] = stat
    return stats


class _ModuleStat(object):
    def __init__(self, module, time, percent, total):
        self.module = module
        self.time = int(time)
        self.percent = 100. * self.time / total

    def sort_key(self):
        return -self.time

    def key(self):
        return self.module

    head = 'Modules:'
    re = '^\s*(?P<time>\d+)\s+(?P<percent>[0-9.]{4,5})%\s+\|X*\s+(?P<module>.+)\s*$'


class _FunctionStat(object):
    def __init__(self, function, time, percent, total):
        self.function = function
        self.time = int(time)
        self.percent = 100. * self.time / total

    def key(self):
        return self.function

    def sort_key(self):
        return -self.time

    head = 'Hot functions:'
    re = '^\s*(?P<time>\d+)\s+(?P<percent>[0-9.]{4,5})%\s+\|X*\s+(?P<function>.+)\s*$'


class _HotspotStat(object):
    def __init__(self, place, address, time, percent, total):
        self.place = place
        self.address = address
        self.time = int(time)
        self.percent = 100. * self.time / total

    def key(self):
        return self.place

    def sort_key(self):
        return -self.time

    head = 'Hotspots:'
    re = '^\s*(?P<time>\d+)\s+(?P<percent>[0-9.]{4,5})%\s+\|\s+(?P<address>0x[0-9a-f]{16})\s+(?P<place>.+)\s*$'


class _Stats(object):
    def __init__(self, top, modules, functions, hotspots):
        self.top = top
        self.modules = modules
        self.functions = functions
        self.hotspots = hotspots


class _OneDiff(object):

    class _StatStub(object):
        def __getattr__(self, attr):
            return ''

        def key(self):
            return ''

        def sort_key(self):
            return ''

    _stub = _StatStub()

    def __init__(self, first, second):
        if first is not None and second is not None:
            self.abs = second.time - first.time
            self.abs_percent = second.percent - first.percent
            self.rel = 100. * self.abs / first.time
        else:
            self.abs = ''
            self.abs_percent = ''
            self.rel = ''

        self.first = first or _OneDiff._stub
        self.second = second or _OneDiff._stub


class _Diff(_Stats):
    """Result of comparing two profiles; same attributes as _Stats.

    Here ``top`` maps ``(key, index)`` to diff tuples, and ``modules``,
    ``functions`` and ``hotspots`` hold lists of _OneDiff.
    """

    def __init__(self, top, modules, functions, hotspots):
        super(_Diff, self).__init__(top, modules, functions, hotspots)


class UnexpectedLineError(SandboxTaskFailureError):
    """Task failure raised when a profile report line has an unexpected format."""

    def __init__(self, line):
        message = 'Parsing error: unexpected line - "{0}"'.format(line)
        SandboxTaskFailureError.__init__(self, message)


# Module-level alias for the task class; presumably the Sandbox framework
# discovers the task through this name. Confirm against the loader.
__Task__ = CompareProfileStats
