import gc
import os
import random
import tempfile
import tracemalloc
import weakref
from distutils.util import strtobool

import flask
import objgraph


# Baseline tracemalloc snapshot used for diffs. memview_handler assigns it in
# its "tracemalloc_start" and "tracemalloc_diff" methods and resets it to None
# in "tracemalloc_stop".
snapshot = None


def list_with_weakproxy(lst):
    """Return True if *lst* directly contains at least one weak proxy object.

    Only non-callable proxies (``weakref.ProxyType``) are detected, matching
    the historical "weakproxy" type-name check this function used to perform.

    Args:
        lst: The container to inspect; ``None`` and empty containers are
            treated as "no proxy present".

    Returns:
        bool: True when any element is a ``weakref.ProxyType`` instance.
    """
    if not lst:
        return False

    # Compare on type() identity rather than on the type's name: CPython
    # renamed the proxy type from "weakproxy" to "weakref.ProxyType", so a
    # name comparison silently stops matching on newer interpreters.
    # type() is used instead of isinstance() because a proxy forwards
    # __class__ to its referent, while type() reports the proxy wrapper
    # itself and never dereferences a possibly dead referent.
    return any(type(item) is weakref.ProxyType for item in lst)


def get_argument(name, default=None):
    """Fetch query-string parameter *name* from the current Flask request.

    Returns *default* when the parameter is not present.
    """
    query = flask.request.args
    return query.get(name, default)


def set_args(args):
    """Copy the recognised request parameters into *args*, converted.

    The parameters "limit", "max_depth" and "too_many" are converted with
    ``int``; "shortnames" is converted with ``strtobool``. Parameters missing
    from the request leave *args* untouched. *args* is mutated in place and
    nothing is returned.
    """
    # (query parameter name, converter), processed in this order.
    converters = (
        ("limit", int),
        ("shortnames", strtobool),
        ("max_depth", int),
        ("too_many", int),
    )
    for key, convert in converters:
        raw = get_argument(key, None)
        if raw is not None:
            args[key] = convert(raw)


def get_obj():
    """Pick the object the current request refers to.

    When the "addr" parameter is present and non-empty it is parsed as a
    hexadecimal address and resolved through ``objgraph.at``. Otherwise a
    random object is chosen among those ``objgraph.by_type`` returns for the
    "type" parameter.
    """
    addr = get_argument("addr", None)
    if not addr:
        candidates = objgraph.by_type(get_argument("type"))
        return random.choice(candidates)
    return objgraph.at(int(addr, 16))


def return_json(j):
    """Wrap *j* in a JSON response built by ``flask.jsonify``."""
    response = flask.jsonify(j)
    return response


def return_svg(fp):
    """Build a response from the contents of *fp*, labelled as SVG.

    *fp* is read from its current position to the end; the response's
    Content-Type header is set to "image/svg+xml".
    """
    payload = fp.read()
    response = flask.Response(payload)
    response.headers["Content-Type"] = "image/svg+xml"
    return response


def _largest_containers(type_name, limit, predicate=None):
    """Return ``[length, "0x<id>"]`` pairs for the longest tracked containers.

    Objects are taken from ``objgraph.by_type(type_name)``, optionally
    filtered by *predicate*, sorted by length in descending order and
    truncated to the first *limit* entries.
    """
    sized = [
        (len(obj), id(obj))
        for obj in objgraph.by_type(type_name)
        if predicate is None or predicate(obj)
    ]
    sized.sort(key=lambda pair: pair[0], reverse=True)
    return [[length, "0x{:x}".format(addr)] for length, addr in sized[:limit]]


def _tracemalloc_diff():
    """Diff a fresh tracemalloc snapshot against the stored baseline.

    Reads the "key_type" (default "lineno") and "limit" (default 0, meaning
    unlimited) request parameters, then replaces the module-level ``snapshot``
    with the freshly taken one.

    Returns:
        list[dict]: One entry per statistic with the keys ``count``,
        ``count_diff``, ``size``, ``size_diff`` and ``line``.
    """
    global snapshot

    # Parse the request first so a malformed parameter does not cost a
    # snapshot or advance the baseline.
    key_type = get_argument("key_type", "lineno")
    limit = int(get_argument("limit", "0"))

    new_snapshot = tracemalloc.take_snapshot()
    # There is no baseline when tracing was started outside this handler;
    # comparing the snapshot with itself yields zero diffs instead of failing.
    baseline = snapshot if snapshot is not None else new_snapshot
    stat_diff = new_snapshot.compare_to(baseline, key_type)
    snapshot = new_snapshot

    # Keep the first `limit` records; zero or a negative value keeps them all.
    if limit > 0:
        stat_diff = stat_diff[:limit]

    result = []
    for rec in stat_diff:
        frame = rec.traceback[0]
        # Keep only the last two path components to shorten the output.
        filename = os.sep.join(frame.filename.split(os.sep)[-2:])
        result.append(dict(
            count=rec.count,
            count_diff=rec.count_diff,
            size=rec.size,
            size_diff=rec.size_diff,
            line="{}:{}".format(filename, frame.lineno),
        ))
    return result


def memview_handler(method):
    """Dispatch one memory-introspection request identified by *method*.

    Supported methods:
        objgraph: "most_common_types", "count", "growth", "refs", "backrefs",
            "get_leaking_objects", "most_fat_lists",
            "most_fat_lists_with_weakproxy", "most_fat_dicts"
        tracemalloc: "tracemalloc_start", "tracemalloc_stop",
            "tracemalloc_clear", "tracemalloc_diff", "traced_memory"
        gc: "gc_collect", "gc_stats", "gc_threshold", "gc_garbage"

    Further inputs are read from the current Flask request's query string
    through ``get_argument`` / ``set_args`` / ``get_obj``.

    Args:
        method: Name of the operation to perform.

    Returns:
        A value suitable for returning from a Flask view: a response object
        (JSON or SVG), a plain string, or a ``(message, status)`` tuple. A
        missing or unsupported *method* yields status 400; "traced_memory"
        yields 403 while tracemalloc is not tracing.
    """
    global snapshot
    if not method:
        return "expected 'method' parameter", 400

    args = {}

    # --- objgraph -----------------------------------------------------
    if method == "most_common_types":
        set_args(args)
        return return_json(objgraph.most_common_types(**args))
    if method == "count":
        return return_json(objgraph.count(get_argument("type")))
    if method == "growth":
        set_args(args)
        return return_json(objgraph.growth(**args))
    if method in ("refs", "backrefs"):
        set_args(args)
        obj = get_obj()
        render = objgraph.show_refs if method == "refs" else objgraph.show_backrefs
        # objgraph writes the graph to the temporary file's path; the open
        # handle is then used to read the rendered SVG back.
        with tempfile.NamedTemporaryFile(mode="rb", suffix=".svg") as fp:
            render(obj, filename=fp.name, **args)
            return return_svg(fp)
    if method == "get_leaking_objects":
        return return_json([str(obj) for obj in objgraph.get_leaking_objects()])
    if method == "most_fat_lists":
        set_args(args)
        return return_json(_largest_containers("list", args.get("limit", 100)))
    if method == "most_fat_lists_with_weakproxy":
        set_args(args)
        return return_json(
            _largest_containers("list", args.get("limit", 100), list_with_weakproxy)
        )
    if method == "most_fat_dicts":
        set_args(args)
        return return_json(_largest_containers("dict", args.get("limit", 100)))

    # --- tracemalloc --------------------------------------------------
    if method == "tracemalloc_start":
        if tracemalloc.is_tracing():
            return "already started"
        nframe = int(get_argument("nframe", 1))
        tracemalloc.start(nframe)
        # Baseline for the first "tracemalloc_diff" request.
        snapshot = tracemalloc.take_snapshot()
        return "started(nframe={})".format(nframe)
    if method == "tracemalloc_stop":
        if not tracemalloc.is_tracing():
            return "already stopped"
        snapshot = None
        tracemalloc.stop()
        return "stopped"
    if method == "tracemalloc_clear":
        tracemalloc.clear_traces()
        return "cleared"
    if method == "tracemalloc_diff":
        if not tracemalloc.is_tracing():
            return "tracemalloc stopped"
        return return_json(_tracemalloc_diff())
    if method == "traced_memory":
        if not tracemalloc.is_tracing():
            return "tracemalloc stopped", 403
        return return_json(list(tracemalloc.get_traced_memory()))

    # --- gc -----------------------------------------------------------
    if method == "gc_collect":
        before = gc.get_count()
        generation = int(get_argument("generation", "2"))
        gc.collect(generation)
        after = gc.get_count()
        return return_json(dict(before=before, after=after))
    if method == "gc_stats":
        return return_json(gc.get_stats())
    if method == "gc_threshold":
        res = ""
        try:
            if get_argument("threshold0", None) is not None:
                threshold0 = int(get_argument("threshold0"))
                threshold1 = int(get_argument("threshold1"))
                threshold2 = int(get_argument("threshold2"))
                gc.set_threshold(threshold0, threshold1, threshold2)
                return "gc threshold updated"
        except Exception as exc:
            # Best effort: report the problem together with the current
            # thresholds instead of failing the request.
            res += str(exc)
        return res + "\ncurrent_threshold: {}".format(gc.get_threshold())
    if method == "gc_garbage":
        result = []
        for obj in gc.garbage:
            try:
                result.append([id(obj), str(obj)])
            except Exception as exc:
                # str() of an uncollectable object may itself fail; record
                # the error text in its place.
                result.append([0, str(exc)])
        return return_json(result)

    return "unsupported 'method' parameter value == {}".format(method), 400
