# -*- coding: utf-8 -*-

from __future__ import unicode_literals

import atexit
from collections import defaultdict
from cProfile import Profile as Cprofile
import logging
from os import environ
from pstats import Stats as ClassicProfilerStats
import signal
from StringIO import StringIO
import tempfile

from passport.backend.core.lazy_loader import LazyLoader
from passport.backend.social.common.chrono import now
from pyprof2calltree import CalltreeConverter
from werkzeug.serving import (
    BaseWSGIServer,
    WSGIRequestHandler,
)
from werkzeug.wrappers import (
    Request,
    Response,
)


logger = logging.getLogger(__name__)


def create_profiler_from_environment():
    """Build a profiler chosen by the SOCIALISM_PROFILER environment variable.

    'nylas' selects NylasProfiler, 'classic' selects ClassicProfiler; any
    other value (or an unset variable) falls back to DummyProfiler.
    """
    known_profilers = {
        'nylas': NylasProfiler,
        'classic': ClassicProfiler,
    }
    selected = environ.get('SOCIALISM_PROFILER')
    profiler_class = known_profilers.get(selected, DummyProfiler)
    return profiler_class()


def get_profiler():
    """Return the profiler instance LazyLoader provides under the name 'profiler'.

    NOTE(review): presumably registered via create_profiler_from_environment;
    the registration is not visible here.
    """
    profiler_instance = LazyLoader.get_instance('profiler')
    return profiler_instance


def start_profiler_manager():
    """Serve the active profiler's stats over HTTP on 127.0.0.1:16384.

    Does nothing when the active profiler is a DummyProfiler. Otherwise this
    call blocks, because _Emitter.run() serves requests forever.
    """
    active = get_profiler()
    if isinstance(active, DummyProfiler):
        return
    _Emitter(active, '127.0.0.1', 16384).run()


class ClassicProfiler(object):
    """Deterministic profiler built on cProfile.

    Collected stats are exported in calltree format via pyprof2calltree's
    CalltreeConverter.
    """

    def __init__(self):
        # builtins=False: calls into builtin functions are not recorded.
        self._profiler = Cprofile(builtins=False)

    def enable(self):
        """Start collecting profile data."""
        self._profiler.enable()

    def disable(self):
        """Stop collecting profile data (collected data is kept)."""
        self._profiler.disable()

    def get_file(self):
        """Return a readable file object holding the stats in calltree format.

        If pstats refuses to build Stats from the profiler (it raises
        TypeError, presumably when nothing has been recorded yet), an empty
        StringIO is returned instead.

        The result is an anonymous temporary file rewound to its start; it is
        deleted automatically when the caller closes it.

        NOTE(review): cProfile's create_stats() also disables the profiler,
        so profiling stays off after this call until enable() is called
        again -- confirm this is intended.
        """
        self._profiler.create_stats()
        try:
            stats = ClassicProfilerStats(self._profiler)
        except TypeError:
            return StringIO()
        converter = CalltreeConverter(stats)
        # An anonymous temp file instead of a fixed /tmp path: a well-known
        # shared name races with other processes, is open to symlink attacks
        # and is left behind on disk.
        output = tempfile.TemporaryFile(mode='w+b')
        try:
            converter.output(output)
            output.seek(0)
        except Exception:
            # Do not leak the descriptor if the conversion fails.
            output.close()
            raise
        return output

    def reset(self):
        """Stop profiling and discard all collected data.

        The fresh profiler is not enabled; call enable() to resume.
        """
        self._profiler.disable()
        self._profiler = Cprofile(builtins=False)


class NylasProfiler(object):
    """Profiler facade over _NylasStackSampler (signal-driven stack sampling)."""

    def __init__(self):
        self._sampler = _NylasStackSampler()

    def enable(self):
        """Begin sampling stacks."""
        self._sampler.start()

    def disable(self):
        """Pause sampling; gathered counts are kept."""
        self._sampler.stop()

    def reset(self):
        """Pause sampling and drop every gathered count."""
        sampler = self._sampler
        sampler.stop()
        sampler.reset()

    def get_file(self):
        """Return the sampler's text report wrapped in a StringIO."""
        report = self._sampler.output_stats()
        return StringIO(report)


class DummyProfiler(object):
    """No-op profiler; create_profiler_from_environment falls back to it."""

    def enable(self):
        """Do nothing."""
        return None

    def disable(self):
        """Do nothing."""
        return None

    def reset(self):
        """Do nothing."""
        return None

    def get_file(self):
        """Return a fixed placeholder payload."""
        placeholder = 'dummy\n'
        return StringIO(placeholder)


class _NylasStackSampler(object):
    """
    A simple stack sampler for low-overhead CPU profiling: samples the call
    stack every `interval` seconds and keeps track of counts by frame. Because
    this uses signals, it only works on the main thread.
    """
    def __init__(self, interval=0.005):
        self.interval = interval
        self._started = None
        self._stack_counts = defaultdict(int)
        self._elapsed = 0.0

    def start(self):
        self._started = now.f()
        try:
            signal.signal(signal.SIGVTALRM, self._sample)
        except ValueError:
            raise ValueError('Can only sample on the main thread')

        signal.setitimer(signal.ITIMER_VIRTUAL, self.interval)
        atexit.register(self.stop)

    def _sample(self, signum, frame):
        stack = []
        while frame is not None:
            stack.append(self._format_frame(frame))
            frame = frame.f_back

        stack = ';'.join(reversed(stack))
        self._stack_counts[stack] += 1
        signal.setitimer(signal.ITIMER_VIRTUAL, self.interval)

    def _format_frame(self, frame):
        return '{}({})'.format(frame.f_code.co_name,
                               frame.f_globals.get('__name__'))

    def output_stats(self):
        elapsed = self._elapsed
        if self._started is not None:
            elapsed += now.f() - self._started

        lines = ['elapsed {}'.format(elapsed),
                 'granularity {}'.format(self.interval)]
        ordered_stacks = sorted(self._stack_counts.items(),
                                key=lambda kv: kv[1], reverse=True)
        lines.extend(['{} {}'.format(frame, count)
                      for frame, count in ordered_stacks])
        return '\n'.join(lines) + '\n'

    def reset(self):
        self._elapsed = 0.0
        self._stack_counts = defaultdict(int)
        if self._started is not None:
            self._started = now.f()

    def stop(self):
        if self._started is not None:
            self._elapsed += now.f() - self._started
        self._started = None
        signal.setitimer(signal.ITIMER_VIRTUAL, 0)

    def __del__(self):
        self.stop()


class _Emitter(object):
    """Tiny HTTP endpoint returning the profiler's current stats.

    Every request is answered with profiler.get_file(); when the query string
    contains reset=1, the profiler is reset after the stats were captured.
    """

    def __init__(self, profiler, host, port):
        self.profiler = profiler
        self.host = host
        self.port = port

    def handle_request(self, environ, start_response):
        """WSGI application: respond with the stats, then optionally reset."""
        # Grab the stats first so the reset below cannot discard them.
        stats_file = self.profiler.get_file()
        response = Response(stats_file)
        wants_reset = Request(environ).args.get('reset') == '1'
        if wants_reset:
            self.profiler.reset()
        return response(environ, start_response)

    def run(self):
        """Bind to host:port and serve requests forever (blocks the caller)."""
        def _discard_log(*args, **kwargs):
            return None

        http_server = BaseWSGIServer(
            self.host, self.port, self.handle_request, _QuietHandler)
        http_server.log = _discard_log
        logger.info('Serving profiler on port {}'.format(self.port))
        http_server.serve_forever()


class _QuietHandler(WSGIRequestHandler):
    """Request handler that never writes per-request access log lines."""

    def log_request(self, *args, **kwargs):
        """Drop the access-log entry to keep application logs clean."""
        return None
