from __future__ import absolute_import, print_function, division

import os
import psutil
import Queue
import random
import requests
import resource
import signal
import simplejson
import socket
import subprocess as subproc
import sys
import threading
import time

from kernel.util.errors import formatException


# Module-level profiling flag.
# NOTE(review): not read anywhere in the code visible here; profiling of the child
# processes is actually switched on by the PROFILE=1 environment variable inside
# SkyboneCoordDaemonV2.start(). Confirm whether this constant is still needed.
PROFILE = False


class SkyboneCoordDaemonV2(object):
    # router: udp in, parse, grab TID and (ip, port), mrequest (msgpacked-request)
    # router: => shard 4b TID, 4b len, mrequest
    # shard => router: 4b TID, mresponse
    # router: find (ip, port) by TID, wait more responses if needed, send response

    # todo: info, proto v2
    '''
    Supervisor for the coordinator's router and shard child processes.

    Each child runs under its own supervisor thread (see ``_process``), which restarts it
    when it exits. ``STATS: `` lines printed by the children are merged into ``self.stats``.
    Once per second those stats are turned into sensors (``statsthread``), and every ~15
    seconds the sensors are pushed to Solomon (``solomon_pusher``).
    '''

    # Last letter of "<name><digit><letter>" short hostnames -> datacenter name.
    _DC_BY_LETTER = {
        'k': 'man',
        'i': 'sas',
        'e': 'iva',
        'f': 'myt',
    }

    # Per-section stats that are gauges/cumulative values and must survive the per-tick
    # reset; every other stat is treated as a per-tick counter and zeroed after reading.
    _PERSISTENT_STATS = {
        'shards': frozenset(('rss', 'cpu_utime', 'cpu_stime', 'tracker_hashes', 'tracker_peers')),
        'routers': frozenset(('rss', 'cpu_utime', 'cpu_stime')),
    }

    def __init__(self, path, wpath, cfg, log):
        '''
        :param path: application directory; py.path-like (``.join(...).strpath`` is used).
        :param wpath: working directory for shard state files; py.path-like as well.
        :param cfg: configuration object supporting both attribute and item access.
        :param log: logger used for all daemon and child output.
        '''
        self.cfg = cfg
        self.log = log
        self.path = path
        self.wpath = wpath

        self._stats_thread = None
        self._solomon_thread = None
        self._child_procs = []
        self._process_threads = []

        # Set while the daemon is not running; cleared by start().
        # NOTE(review): nothing in this class sets it again after start() — confirm intent.
        self._stopped = threading.Event()
        self._stopped.set()
        # Global stop request observed by every worker thread.
        self._stopev = threading.Event()

        self.sensorsqueue = Queue.Queue(maxsize=600)
        self.statslock = threading.Lock()
        self.stats = {
            'main': {'shards': 0, 'routers': 0},
            'shards': {},
            'routers': {}
        }

        self.hostname = socket.gethostname()
        self.dc = self._detect_dc(self.hostname)

        self.log.info('Detected hostname: %s, dc: %s', self.hostname, self.dc)

    @classmethod
    def _detect_dc(cls, hostname):
        '''
        Guess the datacenter name from ``hostname``; return None when it cannot be derived.

        * "<dc>[digit]-..." -> the part before the first '-', minus one trailing digit;
        * "<name><digit><letter>[.domain]" -> looked up by that final letter.
        '''
        if '-' in hostname:
            dc = hostname.split('-', 1)[0]
            # Guard against an empty prefix (hostname starting with '-').
            if dc and dc[-1].isdigit():
                dc = dc[:-1]
            return dc or None

        short = hostname.split('.', 1)[0]
        if len(short) > 2 and short[-2].isdigit():
            return cls._DC_BY_LETTER.get(short[-1])
        return None

    def _process(self, name, args, stopit, handle_stats='unknown'):
        '''
        Keep the child process described by ``args`` running until ``stopit`` is set.

        :param name: label used in log lines and as the key under ``self.stats[handle_stats]``.
        :param args: argv list passed to ``subprocess.Popen``.
        :param stopit: ``threading.Event``; once set, the child is sent SIGINT and reaped.
        :param handle_stats: ``self.stats`` section ('shards' / 'routers') that ``STATS: ``
            lines from the child's stdout are merged into.

        stdout lines are forwarded to ``self.log.info`` and stderr lines to ``self.log.error``.
        A child that exits on its own is logged and restarted after a one-second pause.
        Spawn/supervision errors are logged as critical and retried after one second.
        '''
        while not stopit.is_set():
            try:
                proc = subproc.Popen(args, stdout=subproc.PIPE, stderr=subproc.PIPE, close_fds=True)
                self._child_procs.append(proc)

                readers = []
                try:
                    for stream, log, stats_section in (
                        (proc.stdout, self.log.info, handle_stats),
                        (proc.stderr, self.log.error, False),
                    ):
                        reader = threading.Thread(
                            target=self._process_stream,
                            args=(name, stream, log, stats_section),
                        )
                        reader.daemon = True
                        reader.start()
                        readers.append(reader)

                    while proc.poll() is None and not stopit.is_set():
                        stopit.wait(1)
                finally:
                    self._shutdown_proc(name, proc, readers)

                if not stopit.is_set():
                    self.log.warning('%s: process exited with code %s, restarting', name, proc.returncode)
                    stopit.wait(1)

            except KeyboardInterrupt:
                return
            except Exception as ex:
                self.log.critical(
                    'Failed to start proc %r: %s: %s',
                    args, ex.__class__.__name__, str(ex)
                )
                time.sleep(1)

    def _shutdown_proc(self, name, proc, readers):
        '''
        Stop ``proc`` if it is still running, reap it, drop it from ``_child_procs`` and
        wait for its output reader threads to finish.

        A running child gets SIGINT and a ~300 second grace period (3000 polls, 0.1s apart);
        after that a warning is logged and we keep blocking in ``wait()``.
        NOTE(review): there is no SIGKILL escalation; presumably intentional (e.g. to let the
        child persist its state) — confirm.
        '''
        try:
            if proc.poll() is None:
                proc.send_signal(signal.SIGINT)

                for _ in range(3000):
                    if proc.poll() is not None:
                        break
                    time.sleep(0.1)
                else:
                    self.log.warning('%s: still running 300s after SIGINT, waiting for it to exit', name)

                proc.wait()
        finally:
            # Always forget the proc, including when it exited on its own; otherwise every
            # restart would leave a stale Popen object behind in the list.
            if proc in self._child_procs:
                self._child_procs.remove(proc)

        for reader in readers:
            while reader.is_alive():
                reader.join(timeout=1)

    def _process_stream(self, name, stream, log, handle_stats=False):
        '''
        Forward every line of a child's ``stream`` to ``log`` until EOF.

        When ``handle_stats`` names a ``self.stats`` section, lines starting with ``STATS: ``
        are parsed as JSON and merged into ``self.stats[handle_stats][name]`` instead of
        being logged. A stats line that cannot be parsed or merged is logged as a warning and
        skipped, so this reader keeps draining the pipe (a dead reader would eventually
        block the child once the pipe buffer fills).
        '''
        while True:
            data = stream.readline()
            if not data:
                break

            if handle_stats and data.startswith('STATS: '):
                try:
                    stats = simplejson.loads(data[7:])
                    with self.statslock:
                        self.stats[handle_stats].setdefault(name, {}).update(stats)
                except Exception as ex:
                    self.log.warning(
                        '%s: unable to handle stats line %r: %s: %s',
                        name, data.strip(), type(ex).__name__, ex
                    )
                continue

            log('%s: %s' % (name, data.strip()))

    @staticmethod
    def _avg(values):
        '''Arithmetic mean of ``values`` (any iterable); 0 for an empty one.'''
        values = list(values)
        if not values:
            return 0
        return sum(values) / len(values)

    @staticmethod
    def _cpu_percent(current, previous, cap):
        '''
        CPU usage over one tick from cumulative CPU seconds, scaled by 100 and clamped to
        ``[0, cap]``. Returns 0 when there is no previous sample yet (first tick), and never
        goes negative when a restarted child resets its cumulative counter.
        '''
        if previous is None:
            return 0
        return max(0, min((current - previous) * 100, cap))

    def _snapshot_and_reset_stats(self):
        '''
        Atomically copy the per-child stats and zero their per-tick counters.

        :returns: ``{'shards': [stats dict, ...], 'routers': [stats dict, ...]}`` with one
            independent copy per child; persistent keys (see ``_PERSISTENT_STATS``) are kept
            in ``self.stats``, everything else is reset to 0.
        '''
        snapshot = {}
        with self.statslock:
            for section in ('shards', 'routers'):
                keep = self._PERSISTENT_STATS[section]
                per_child = self.stats[section]
                snapshot[section] = [dict(childstats) for childstats in per_child.values()]
                for childstats in per_child.values():
                    for statname in childstats:
                        if statname not in keep:
                            childstats[statname] = 0
        return snapshot

    def _build_sensors(self, snapshot, rusage, tsint):
        '''
        Turn one stats snapshot into a list of Solomon sensors.

        Every sensor is emitted once per node label (the hostname and, when known, the dc).
        Aggregated values skip children that did not report the key; an aggregate over no
        values at all (``max`` of nothing) is reported as 0.

        :param snapshot: result of ``_snapshot_and_reset_stats``.
        :param rusage: dict with 'user', 'system' and 'rss' values for the whole daemon.
        :param tsint: integer unix timestamp put on every sensor.
        '''
        avg = self._avg
        sensors_to_send = {
            'packets_in': {
                'announce_leech': ('routers', 'in_rps_state_leech', sum),
                'announce_seed': ('routers', 'in_rps_state_seed', sum),
                'announce_stop': ('routers', 'in_rps_state_stop', sum),
                'connect': ('routers', 'in_rps_connect', sum),
                'info': ('routers', 'in_rps_info', sum),
                'invalid': ('routers', 'in_rps_invalid', sum),
            },
            'packets': {
                'incoming': ('routers', 'in_rps', sum),
                'outgoing': ('routers', 'ou_rps', sum),
                'dropped': ('routers', 'in_drop', sum),             # dropped right after recv, cant process
                'dropped_job': ('routers', 'in_drop_job', sum),     # dropped while forwarding job to the shard
                'dropped_out': ('routers', 'ou_drop', sum),         # dropped after processing, bad req
            },
            'protocol': {
                'announce_legacy': ('routers', 'in_rps_announce_legacy', sum),
                'announce_v3': ('routers', 'in_rps_announce', sum),
                'connect_v1': ('routers', 'in_rps_connect_v1', sum),
                'connect_v2': ('routers', 'in_rps_connect_v2', sum),
                'connect_v3': ('routers', 'in_rps_connect_v3', sum),
            },
            'queues': {
                'shard_in': ('routers', 'shard_queue', avg),
                'udp': ('routers', 'udp_q', avg),
                'shard_jobs': ('shards', 'job_q', avg),
                'shard_results': ('shards', 'result_queue', avg),
            },
            'counts': {
                'hashes': ('shards', 'tracker_hashes', sum),
                'peers': ('shards', 'tracker_peers', max),
            },
            'cpu': {
                'user': (rusage, 'user', None),
                'system': (rusage, 'system', None),
            },
            'memory': {
                'rss': (rusage, 'rss', None),
            }
        }

        nodes = [node for node in (self.hostname, self.dc) if node]

        sensors = []
        for label1, desc1 in sensors_to_send.items():
            for label2, (where, key, aggr) in desc1.items():
                if isinstance(where, dict):
                    value = where[key]
                else:
                    try:
                        value = aggr(v[key] for v in snapshot[where] if key in v)
                    except ValueError:
                        # max() over no values
                        value = 0

                for node in nodes:
                    sensors.append({
                        'labels': {
                            label1: label2,
                            'node': node
                        },
                        'ts': tsint,
                        'value': value
                    })
        return sensors

    def statsthread(self):
        '''
        Once per second turn the children's stats into sensors and enqueue them for
        ``solomon_pusher``; runs until the stop event is set. Errors in a tick are logged
        and the next tick proceeds (the thread does not die).

        Example of stats reported by children:

        router:
            {'in_drop': 0,
             'in_drop_job': 0,
             'in_rps': 0,
             'in_rps_announce': 0,
             'in_rps_announce_legacy': 0,
             'in_rps_connect': 0,
             'in_rps_info': 0,
             'in_rps_invalid': 0,
             'in_rps_misc': 0,
             'in_rps_state_leech': 0,
             'in_rps_state_seed': 0,
             'in_rps_state_stop': 0,
             'ou_drop': 0,
             'ou_rps': 0,
             'shard_queue': 0,
             'shard_queue_max': 20,
             'squeue': '  0/ 20',
             'trans': 0,
             'udp_q': 0}

        shard:
            {'job_q': 0,
            'packets_cnt': 0,
            'res_q': '  0/  6',
            'result_queue': 0,
            'result_queue_max': 6}
        '''
        # Cap max possible cpu usage (cpu-seconds per ~1s tick, scaled by 100).
        max_possible_cpu = 3000

        main_proc = psutil.Process(os.getpid())
        prev_utime = None
        prev_stime = None

        while not self._stopev.is_set():
            ts = time.time()
            try:
                snapshot = self._snapshot_and_reset_stats()
                children = snapshot['shards'] + snapshot['routers']

                rusage_self = resource.getrusage(resource.RUSAGE_SELF)
                all_utime = rusage_self.ru_utime + sum(v.get('cpu_utime', 0) for v in children)
                all_stime = rusage_self.ru_stime + sum(v.get('cpu_stime', 0) for v in children)

                try:
                    rss = main_proc.memory_info().rss
                except AttributeError:
                    # Older psutil releases only provide get_memory_info().
                    rss = main_proc.get_memory_info().rss

                rusage = {
                    'user': self._cpu_percent(all_utime, prev_utime, max_possible_cpu),
                    'system': self._cpu_percent(all_stime, prev_stime, max_possible_cpu),
                    'rss': rss + sum(v.get('rss', 0) for v in children),
                }
                prev_utime = all_utime
                prev_stime = all_stime

                # Never block the stats loop on a full queue; Queue.Full is logged below.
                self.sensorsqueue.put(self._build_sensors(snapshot, rusage, int(ts)), block=False)
            except Exception as ex:
                self.log.warning('Unable to send statistics %s: %s', type(ex).__name__, ex)
            finally:
                # Align ticks to whole seconds, but wake up immediately on stop.
                self._stopev.wait(1 - ts % 1)

    def solomon_pusher(self):
        '''
        Every ~15 seconds drain ``sensorsqueue`` and POST the accumulated sensors to Solomon.

        The OAuth token is read once from the secrets file; if it cannot be read the literal
        'fake' token is used. A failed push is logged and that batch is dropped (best
        effort). The loop ends once the stop event is set.
        '''
        session = requests.session()

        try:
            with open('/var/lib/skybone-coord/secrets/solomon_oauth', 'rb') as fp:
                solomon_oauth = fp.read().strip()
        except (IOError, OSError):
            solomon_oauth = 'fake'

        while not self._stopev.is_set():
            try:
                to_send = {
                    'sensors': []
                }

                while True:
                    try:
                        sensors_data = self.sensorsqueue.get(block=False)
                    except Queue.Empty:
                        break

                    to_send['sensors'].extend(sensors_data)

                response = session.post(
                    'http://solomon.yandex.net/api/v2/push?'
                    'project=skybonecoord&'
                    'cluster=%s&'
                    'service=tracker' % (self.cfg.main.cluster, ),
                    data=simplejson.dumps(to_send),
                    headers={
                        'Content-Type': 'application/json',
                        'Authorization': 'OAuth %s' % (solomon_oauth, )
                    },
                    timeout=10
                )
                response.raise_for_status()
            except Exception as ex:
                self.log.warning('Unable to send statistics: %s: %s', type(ex).__name__, ex)
            finally:
                self._stopev.wait(15)

    def start(self):
        '''
        Start the stats collector, the Solomon pusher and one supervisor thread per shard
        and per router.

        Shards are started first because every router receives the full list of shard
        socket names via ``--shards``. If anything here raises, the stop event is set (so
        supervisors already started wind down) and the exception is re-raised.
        '''
        self._stopped.clear()

        thr = self._stats_thread = threading.Thread(target=self.statsthread)
        thr.daemon = True
        thr.start()

        thr = self._solomon_thread = threading.Thread(target=self.solomon_pusher)
        thr.daemon = True
        thr.start()

        try:
            if 'python' in os.path.basename(sys.executable):
                # Under a python interpreter the entry points live in bin/ and are run via it.
                base_args = [sys.executable, '-Bttu']
                router_path = 'bin'
                shard_path = 'bin'
            else:
                # Otherwise they are standalone executables in bin/router and bin/shard.
                base_args = []
                router_path = os.path.join('bin','router')
                shard_path = os.path.join('bin','shard')
            router_args = base_args + [self.path.join(router_path,'skybone-coord-router').strpath]
            router_args += [
                '--port', str(self.cfg.tracker.port),
                '--bind-v4', self.cfg.tracker.get('listen', {}).get('v4', 'any'),
                '--bind-v6', self.cfg.tracker.get('listen', {}).get('v6', 'any'),
            ]

            shard_args = base_args + [self.path.join(shard_path,'skybone-coord-shard').strpath]
            shard_args += [
                '--interval-min', str(self.cfg.tracker.announce.interval[0]),
                '--interval-max', str(self.cfg.tracker.announce.interval[1]),
                '--interval-leech-min', str(self.cfg.tracker.announce.interval_leech[0]),
                '--interval-leech-max', str(self.cfg.tracker.announce.interval_leech[1]),
                '--return-peers', str(self.cfg.tracker.announce.return_peers),
                '--return-seeders-ratio', '%.2f' % (self.cfg.tracker.announce.return_seeders_ratio, ),
                '--save-state-every', '300',
                '--save-state-rnd', str(int(random.random() * 300)),
            ]
            if 'blocked_nets' in self.cfg.tracker:
                shard_args += [
                    '--blocked-nets', ';'.join(self.cfg.tracker.blocked_nets),
                ]

            if os.environ.get('PROFILE', '0') == '1':
                router_args += ['--profile']
                shard_args += ['--profile']
                self.log.warning('ENABLED PROFILING')

            all_shards = []

            for i in range(self.cfg['main']['shards']):
                name = 'shard_%03d' % (i, )
                sockname = 'skybone_coord_%s' % (name, )

                args = shard_args + [
                    '--sockname', sockname,
                    '--state-file', self.wpath.join('%02d_%02d_state.msgp' % (self.cfg.main.shards, i + 1)).strpath,
                    '--logprefix', '%d.%d' % (self.cfg['main']['shards'], i + 1)
                ]
                if self.cfg.get('logpath') is not None:
                    args = args + ['--log-path', self.cfg['logpath']]

                thr = threading.Thread(
                    target=self._process, name=name,
                    args=(name, args, self._stopev, 'shards')
                )
                thr.daemon = True

                self._process_threads.append(thr)

                thr.start()

                all_shards.append(sockname)

            for i in range(self.cfg['main']['routers']):
                name = 'router_%02d' % (i, )

                args = router_args + [
                    '--shards', ','.join(all_shards),
                    '--logprefix', '%d' % (i + 1, )
                ]
                if self.cfg.get('logpath') is not None:
                    args = args + ['--log-path', self.cfg['logpath']]

                thr = threading.Thread(
                    target=self._process, name=name,
                    args=(name, args, self._stopev, 'routers')
                )
                thr.daemon = True

                self._process_threads.append(thr)

                thr.start()

            self.log.info('Started %d routers and %d shards', self.cfg['main']['routers'], self.cfg['main']['shards'])

        except BaseException:
            self._stopev.set()
            raise

    def stop(self):
        '''Request shutdown: set the stop event and send SIGINT to every child still running.'''
        self._stopev.set()

        # Iterate over a copy: supervisor threads remove their procs from the list concurrently.
        for proc in list(self._child_procs):
            if proc.poll() is None:
                try:
                    proc.send_signal(signal.SIGINT)
                except OSError:
                    # The child went away between poll() and the signal.
                    pass

    def join(self, ignore_sigint=False):
        '''
        Wait for every supervisor thread and then the stats thread to finish.

        :param ignore_sigint: when True, KeyboardInterrupt raised while waiting is swallowed
            and waiting continues; otherwise it propagates to the caller.
        '''
        threads = list(self._process_threads)
        if self._stats_thread:
            threads.append(self._stats_thread)

        for thr in threads:
            self._join_thread(thr, ignore_sigint)

    @staticmethod
    def _join_thread(thr, ignore_sigint):
        '''Join ``thr`` in 1-second slices so KeyboardInterrupt can be delivered in between.'''
        while thr.is_alive():
            try:
                thr.join(timeout=1)
            except KeyboardInterrupt:
                if not ignore_sigint:
                    raise


def main(ctx, fastexit=2):
    '''
    Servlet entry point: run the coordinator daemon until SIGINT, then stop it.

    ``fastexit`` is accepted but not used here. After a clean stop (join returning, or the
    SIGINT path) the process ends via ``os._exit(0)``. Any other exception from ``join`` is
    logged and ends the process via ``os._exit(1)``. SystemExit is logged and re-raised.
    Exceptions raised by ``start()`` propagate to the caller.
    '''
    log = ctx.log
    log.info('Initializing servlet')

    random.seed()

    daemon = SkyboneCoordDaemonV2(
        path=ctx.appdir,
        wpath=ctx.workdir,
        cfg=ctx.cfg,
        log=log,
    )
    daemon.start()

    try:
        daemon.join()
    except KeyboardInterrupt:
        # Ignore any further Ctrl-C while the children are being shut down.
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        log.info('Caught SIGINT, stopping daemon')
        daemon.stop()
        daemon.join(ignore_sigint=True)
    except SystemExit as ex:
        log.warning(ex if ex.args else 'Got SystemExit exception in main loop')
        raise
    except BaseException:
        log.critical('Unhandled exception: %s, exit immidiately!' % (formatException(), ))
        os._exit(1)

    log.info('Stopped')
    os._exit(0)
