import argparse
import fcntl
import gc
import logging
import msgpack
import os
import psutil
import Queue
import resource
import select
import signal
import simplejson
import socket
import sys
import threading
import time
import traceback
import errno

from setproctitle import setproctitle

from api.logger import SkynetLoggingHandler

from .tracker import Tracker, MallformedPacket

from core.daemon import logger


# Base process title of a shard.  statsprinter() formats the live status after
# it, save_state() uses it for the forked child's title and main() appends
# ' (PROF)' to it when profiling is requested.
BASE_PROCTITLE = 'skybone-coord-shard '


def pwatch():
    """Exit the process once its parent PID becomes 1.

    The parent PID is polled once a second.  As soon as it equals 1 (which
    normally means the original parent is gone and the process was adopted by
    init) the process terminates immediately with status 1 through
    ``os._exit``.  The function never returns.
    """
    while os.getppid() != 1:
        time.sleep(1)
    os._exit(1)


def setup_logger(logprefix, logpath=None):
    """Attach a DEBUG-level handler to the root logger and return a child logger.

    :param logprefix: name of the innermost child logger; the returned logger
        is named ``shard.<logprefix>``.
    :param logpath: directory for ``skybone-coord-shard.log``.  When given, a
        size-rotated file (about 1.5 GiB per file, 2 backups) with the 'full'
        formatter is used; when ``None`` the records go to a
        ``SkynetLoggingHandler`` instead.
    :return: the ``shard.<logprefix>`` logger.
    """
    # ``import logging`` alone does not guarantee that the ``handlers``
    # submodule has been loaded, so import it explicitly before using it.
    import logging.handlers

    log = logging.getLogger('')
    log.setLevel(logging.DEBUG)
    if logpath is not None:
        handler = logging.handlers.RotatingFileHandler(
            os.path.join(logpath, 'skybone-coord-shard.log'),
            maxBytes=int(1.5 * (1 << 30)),  # byte counts are integers
            backupCount=2
        )
        handler.setFormatter(logger.get_formatter('full'))
    else:
        handler = SkynetLoggingHandler(app='skybone-coord', filename='skybone-coord-shard.log')
    handler.setLevel(logging.DEBUG)
    log.addHandler(handler)
    return log.getChild('shard').getChild(logprefix)


def avg(v):
    """Return the mean of the values produced by iterable ``v``.

    The iterable is materialised first, so generators are accepted.  An empty
    input yields ``0``.  The division uses the interpreter's ``/`` operator
    unchanged.
    """
    items = list(v)
    count = len(items)
    return sum(items) / count if count else 0


class SkyboneCoordShard(object):
    def __init__(
        self, state_file, sockname,
        interval, interval_leech, return_peers, return_seeders_ratio,
        save_state_every, save_state_rnd, blocked_nets,
        logger
    ):
        """Prepare a shard; nothing is started until ``start()`` is called.

        :param state_file: path of the persisted tracker state.
        :param sockname: name of the abstract unix socket to listen on.
        :param interval: forwarded to ``Tracker`` as ``interval``.
        :param interval_leech: forwarded to ``Tracker`` as ``interval_leech``.
        :param return_peers: forwarded to ``Tracker``.
        :param return_seeders_ratio: forwarded to ``Tracker``.
        :param save_state_every: minimal number of seconds between two
            periodic state saves.
        :param save_state_rnd: number of seconds by which the "last saved"
            timestamp is moved into the past at construction time.
        :param blocked_nets: forwarded to ``Tracker``.
        :param logger: logger used by the shard (this parameter shadows the
            module-level ``logger`` import inside this method); its ``trk``
            child is handed to the tracker.
        """
        self.log = logger
        self.state_file = state_file
        self.sockname = sockname

        # Shutdown flag shared by all threads; a new Event starts cleared.
        self._stopev = threading.Event()

        # Listening socket and service threads, populated by start().
        self._sock = None
        self._acceptor_thread = None
        self._tracker_thread = None
        self._conn_threads = []

        # Job batches travel connection workers -> tracker worker through
        # ``jobs``; every connection registers its own result queue here.
        self.jobs = Queue.Queue(maxsize=10)
        self._result_queues = []

        tracker_options = dict(
            interval=interval,
            interval_leech=interval_leech,
            return_peers=return_peers,
            return_seeders_ratio=return_seeders_ratio,
            blocked_nets=blocked_nets,
            logger=self.log.getChild('trk'),
        )
        self.trk = Tracker(**tracker_options)

        # Counters/gauges published by statsprinter(); guarded by statslock
        # in savestat().
        self.stats = {
            'packets_cnt': 0,
            'job_q': 0,
            'result_queue': 0
        }
        self.statslock = threading.Lock()

        self.stage = 'init'
        self.saving_state = False

        self._state_save_every = save_state_every
        self._state_save_ts = int(time.time()) - save_state_rnd

    def statsprinter(self):
        while not self._stopev.isSet():
            try:
                ts = time.time()

                self.stats['job_q'] = self.jobs.qsize()
                self.stats['result_queue'] = avg(q.qsize() for q in self._result_queues)
                self.stats['result_queue_max'] = avg(q.maxsize for q in self._result_queues)

                proc = psutil.Process(os.getpid())

                resusage = resource.getrusage(resource.RUSAGE_SELF)
                try:
                    self.stats['rss'] = proc.memory_info().rss
                except:
                    self.stats['rss'] = proc.get_memory_info().rss
                self.stats['cpu_utime'] = resusage.ru_utime
                self.stats['cpu_stime'] = resusage.ru_stime

                self.stats['res_q'] = '%3d/%3d' % (
                    avg(q.qsize() for q in self._result_queues),
                    avg(q.maxsize for q in self._result_queues),
                )

                trk_stats = self.trk.get_stats()
                self.stats['tracker_hashes'] = trk_stats[2]
                self.stats['tracker_peers'] = trk_stats[0]

                sys.stdout.write('STATS: %s\n' % (simplejson.dumps(self.stats), ))
                sys.stdout.flush()

                if self.stage == 'ok':
                    proctitle = '%s [rps:%5d|queue:%4d|rqueue:%s|peers:%6d|conns:%6d|hashes:%d]'
                    if self.saving_state:
                        proctitle = proctitle[:-1] + '|savestate]'

                    setproctitle(proctitle % (
                        BASE_PROCTITLE,
                        self.stats['packets_cnt'], self.stats['job_q'], self.stats['res_q'],
                        trk_stats[0], trk_stats[1], trk_stats[2]
                    ))
                else:
                    proctitle = '%s [stage:%s]' % (BASE_PROCTITLE, self.stage)
                    setproctitle(proctitle)

                for k in self.stats:
                    self.stats[k] = 0

                if self.stage == 'ok' and ts - self._state_save_ts >= self._state_save_every:
                    self.save_state()
                    self._state_save_ts = ts

                time.sleep(1 - ts % 1)
            except Exception as ex:
                self.log.critical('statsprinter died')
                self.log.critical(traceback.format_exc())
                time.sleep(60)

    def savestat(self, n, v):
        with self.statslock:
            self.stats[n] += v

    def logerr(self, msg, *args):
        try:
            if args:
                msg = msg % args

            sys.stderr.write(msg + '\n')
            self.log.error(msg)
        except Exception:
            sys.stderr.write('Unable to deliver error log\n')
            sys.stderr.write(traceback.format_exc())

    def _tracker_worker(self):
        while not self._stopev.isSet():
            try:
                while not self._stopev.isSet():
                    try:
                        result_queue, jobs = self.jobs.get(timeout=1)
                        break
                    except Queue.Empty:
                        pass
                else:
                    # stopev is set probably
                    continue

                if result_queue is None:
                    continue

                if self._stopev.isSet():
                    return

                responses = []
                for tid, peerip, peerport, data in jobs:
                    try:
                        response = self.trk.handle_packet(
                            data, peerip, peerport
                        )

                        self.savestat('packets_cnt', 1)

                        responses.append((tid, response))

                    except MallformedPacket:
                        # should be already logged by tracker itself
                        continue
                    except Exception as ex:
                        self.logerr(
                            'Unable to process request from peer %r: %s: %s',
                            (peerip, peerport), type(ex).__name__, ex
                        )
                        continue

                if responses:
                    while not self._stopev.isSet():
                        try:
                            result_queue.put(responses, timeout=1)
                            break
                        except Queue.Full:
                            pass
                    else:
                        # Finished, put to queue
                        return

            except Exception as ex:
                self.logerr('Unhandled exception in tracker thread: %s: %s', type(ex).__name__, ex)
                self.logerr(traceback.format_exc())
                time.sleep(1)

    def _connresulter(self, conn, queue, state):
        while not self._stopev.isSet():
            try:
                responses = queue.get(timeout=1)
                state[0] -= 1
            except Queue.Empty:
                continue

            if responses is None:
                return

            try:
                for idx, (tid, response) in enumerate(responses):
                    if response is not None:
                        responses[idx] = (tid, self.trk.msg_packer.pack(response))

                responses_msgpack = self.trk.msg_packer.pack(responses)

                # router => shard conn has timeout for 1 sec usually
                # so we are able to check stopflag if no data is coming
                #
                # but to send results back we dont need timeout at all
                # it could block forever if router will not accept data back
                #
                # but this is okay
                old_conn_timeout = conn.gettimeout()
                conn.settimeout(None)
                try:
                    conn.sendall(responses_msgpack)
                finally:
                    conn.settimeout(old_conn_timeout)
            except:
                self.logerr(traceback.format_exc())
                return

    def _connworker(self, conn, peer):
        """Serve one accepted connection: read job batches, queue them for the tracker.

        :param conn: accepted socket.
        :param peer: peer address as returned by ``accept()`` (not used here).

        A companion ``_connresulter`` thread sends the results back over the
        same socket.  Incoming bytes are fed to a msgpack unpacker and every
        decoded object is put on ``self.jobs`` together with this
        connection's result queue.  The method returns when the peer closes
        the connection or the stop event is set.

        On an unexpected error the connection is shut down and the results of
        all still-pending jobs are consumed, so the tracker worker cannot
        block on this connection's result queue; if that takes more than 60
        seconds per result the whole process exits.

        On every exit path the result queue is unregistered, the current
        thread is removed from ``_conn_threads`` and the socket is closed.
        """
        state = [0]  # number of pending jobs
        results_queue = Queue.Queue(maxsize=3)
        self._result_queues.append(results_queue)

        try:
            thr = threading.Thread(target=self._connresulter, args=(conn, results_queue, state))
            thr.daemon = True
            thr.start()

            try:
                unpacker = msgpack.Unpacker()

                # This timeout is used for asking new jobs from router
                # timeout is set to 1, so we can check stopev flag every second
                conn.settimeout(1)

                while True:
                    try:
                        data = conn.recv(1024 * 8)
                    except socket.timeout:
                        if self._stopev.isSet():
                            return
                        continue
                    except socket.error as ex:
                        if self._stopev.isSet():
                            return
                        if ex.errno in (errno.EINTR, errno.EAGAIN):
                            continue
                        # A hard error (e.g. connection reset) would fail
                        # again immediately on every retry and spin the CPU,
                        # so give the connection up instead.
                        raise

                    if self._stopev.isSet():
                        conn.shutdown(socket.SHUT_RDWR)
                        return

                    if not data:
                        break

                    if not thr.isAlive():
                        raise Exception('Resulter thread died')

                    unpacker.feed(data)

                    for jobs in unpacker:
                        state[0] += 1  # increase pending jobs before adding to queue to avoid possible negative nums
                        self.jobs.put((results_queue, jobs))
            finally:
                # Ask the resulter to finish and wait until it is gone.
                while thr.isAlive():
                    try:
                        results_queue.put(None, timeout=0.1)
                    except Queue.Full:
                        pass

                    thr.join(timeout=1)

        except Exception as ex:
            self.logerr('Connection worker died with %s: %s', type(ex).__name__, str(ex))
            self.logerr(traceback.format_exc())

            try:
                conn.shutdown(socket.SHUT_RDWR)
            except socket.error:
                # already disconnected; must not prevent the drain below
                pass

            try:
                # Wait all scheduled jobs, to prevent tracker working blocking on
                # result_queue put.
                # NOTE(review): a clean peer disconnect does not drain pending
                # results -- confirm that this is intended.
                for cnt in range(state[0]):
                    results_queue.get(timeout=60)
            except Exception:
                self.logerr('Unable to wait all scheduled jobs, exiting')
                os._exit(1)

        finally:
            self._result_queues.remove(results_queue)

            # Drop the bookkeeping entry on every exit path, not only after
            # an error, so the list does not grow with each connection.
            try:
                self._conn_threads.remove(threading.currentThread())
            except ValueError:
                pass

            try:
                conn.close()
            except socket.error:
                pass

    def _acceptor(self):
        self._sock.settimeout(1)

        while True:
            try:
                try:
                    conn, peer = self._sock.accept()
                except socket.error:
                    if self._stopev.isSet():
                        return
                    continue

                if self._stopev.isSet():
                    return

                self.log.info('Accepted new conn from peer %r', peer)
                conn.send('\x42')

                thr = threading.Thread(target=self._connworker, args=(conn, peer))
                thr.daemon = True

                self._conn_threads.append(thr)

                thr.start()

            except Exception as ex:
                self.logerr('Unhandled exception in acceptor: %s: %s', type(ex).__name__, ex)
                conn.shutdown(socket.SHUT_RDWR)
                time.sleep(1)

    def start(self):
        """Bring the shard up: stats thread, listening socket, state, workers.

        The abstract unix socket ``'\\0' + sockname`` is bound with retries
        every 0.1 second for up to 120 seconds while the address is still in
        use; any other bind error is raised at once.  After the state is
        loaded the tracker worker and the acceptor threads are started and
        the stage becomes ``'ok'``.

        If anything fails the stop event is set and the exception is
        re-raised.
        """
        def daemon_thread(target):
            worker = threading.Thread(target=target)
            worker.daemon = True
            return worker

        self.stage = 'starting'

        try:
            daemon_thread(self.statsprinter).start()

            self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)

            addr_in_use = getattr(errno, 'EADDRINUSE', 98)
            give_up_at = time.time() + 120
            while True:
                try:
                    self._sock.bind('\0' + self.sockname)
                    break
                except socket.error as ex:
                    if ex.errno != addr_in_use or time.time() >= give_up_at:
                        raise
                # address still busy: wait a bit and try again
                time.sleep(0.1)

            self._sock.listen(socket.SOMAXCONN)

            self.load_state()

            self._tracker_thread = daemon_thread(self._tracker_worker)
            self._tracker_thread.start()

            self._acceptor_thread = daemon_thread(self._acceptor)
            self._acceptor_thread.start()

            self.log.info('Waiting for connections')
            self.stage = 'ok'
        except:
            # catch-all on purpose: flag the failure for the running threads
            self._stopev.set()
            raise

    def stop(self):
        if not self._stopev.isSet():
            save_state = True
            self._stopev.set()
        else:
            save_state = False

        if self._tracker_thread:
            try:
                self.jobs.put((None, None), timeout=1)
            except Queue.Full:
                pass

        self._sock.close()

        if save_state:
            self.log.info('Joining self before saving state')
            self.join()
            self.save_state()

    def join(self, ignore_sigint=False):
        """Block until the acceptor and tracker threads have finished.

        Threads that were never created are skipped.  Each thread is joined
        in 1 second slices; a ``KeyboardInterrupt`` raised meanwhile is
        swallowed when ``ignore_sigint`` is true and re-raised otherwise.
        """
        workers = (self._acceptor_thread, self._tracker_thread)

        for worker in workers:
            while worker is not None and worker.isAlive():
                try:
                    worker.join(timeout=1)
                except KeyboardInterrupt:
                    if ignore_sigint:
                        continue
                    raise

    def save_state(self):
        """Write the tracker state to ``self.state_file`` from a forked child.

        Does nothing when a save is already in progress.

        Parent side: forks, then waits up to 60 seconds for the child to
        write one byte to (or close) a pipe.  The child does that right after
        it has taken the exclusive lock on ``<state_file>.lock``.  On timeout
        the child is killed with SIGKILL and reaped synchronously; otherwise
        it is reaped by a background thread.  ``saving_state`` is cleared
        once the child has been reaped.

        Child side: becomes its own process group leader, arms a 600 second
        self-kill timer, closes the standard streams, takes the lock, dumps
        ``{'version': 1, 'tracker': ...}`` with msgpack to
        ``<state_file>.tmp`` and renames it over the state file.  The child
        always terminates with ``os._exit`` (status 0 on success, 1 on any
        failure) and never returns into the caller's code.

        :raises OSError: in the parent, when the pipe or the fork cannot be
            created; the "save in progress" markers are rolled back first.
        """
        if self.saving_state:
            return

        # Create the pipe before touching any flags, so a failure here leaves
        # the object untouched.
        pipe = os.pipe()

        prev_stage = self.stage
        self.stage = 'save_state'
        self.saving_state = True

        self.log.info('Forking (pipes %r)...', pipe)

        # Hold the logging module lock across fork() so the child does not
        # inherit it in a locked state.
        logging._acquireLock()

        try:
            pid = os.fork()
        except OSError:
            # No child exists: release the pipe and allow a later retry
            # instead of staying in "saving" mode forever.
            os.close(pipe[0])
            os.close(pipe[1])
            self.saving_state = False
            self.stage = prev_stage
            raise
        finally:
            logging._releaseLock()

        if pid:
            os.close(pipe[1])

            rlist = select.select([pipe[0]], [], [], 60)[0]
            killed = False
            if pipe[0] not in rlist:
                self.log.critical('UNABLE TO SAVE STATE -- FORK PIPE CLOSE TIMED OUT (PID=%d)', pid)
                os.kill(pid, signal.SIGKILL)
                killed = True

            # After a kill this read returns once the child's end is closed.
            os.read(pipe[0], 1)
            os.close(pipe[0])

            def _collect(pid):
                os.waitpid(pid, 0)
                self.saving_state = False

            if not killed:
                self.stage = 'ok'
                thr = threading.Thread(target=_collect, args=(pid, ))
                thr.daemon = True
                thr.start()
            else:
                self.stage = 'killed'
                _collect(pid)
                self.stage = 'ok'

            return

        # ------------------------------------------------------------------
        # Child process from here on.  Whatever happens it must end in
        # os._exit(); otherwise an exception would unwind into the caller's
        # frames and leave a second copy of the shard running.
        # ------------------------------------------------------------------
        exit_code = 1
        try:
            def _harakiri():
                wait = 600

                time.sleep(wait)

                pid = os.getpid()
                self.log.critical('UNABLE TO SAVE STATE IN %d SECS, KILLING MYSELF (PID=%d)', wait, pid)
                os.kill(pid, signal.SIGKILL)

            os.setpgrp()  # be the process leader, so we will not be killed by skycore upon restart

            thr = threading.Thread(target=_harakiri)
            thr.daemon = True
            thr.start()

            self.stage = 'save_state_fork'
            setproctitle('%s [fork: save_state]' % (BASE_PROCTITLE, ))

            os.close(sys.stdin.fileno())
            os.close(sys.stdout.fileno())
            os.close(sys.stderr.fileno())

            sys.stdin.close()
            sys.stdout.close()
            sys.stderr.close()

            os.close(pipe[0])

            self.log = self.log.getChild('fork')
            self.log.info('Saving state...')

            try:
                os.makedirs(os.path.dirname(self.state_file))
            except OSError:
                # usually "already exists"; a really unusable directory makes
                # the open() below fail
                pass

            lockfile = self.state_file + '.lock'
            lockfile_fp = open(lockfile, 'wb')

            fcntl.lockf(lockfile_fp.fileno(), fcntl.LOCK_EX)
            # tell the parent that the lock is held and it may carry on
            os.write(pipe[1], '1')
            os.close(pipe[1])

            self.log.info('Acquired state lock')

            state = {
                'version': 1,
                'tracker': self.trk.get_state()
            }

            state_file_tmp = self.state_file + '.tmp'

            self.log.debug('Saving state to %s...', state_file_tmp)
            ts = time.time()

            with open(state_file_tmp, 'wb') as fp:
                msgpack.dump(state, fp)

            # rename() replaces an existing destination atomically on POSIX.
            # Unlinking the old file first would open a window in which no
            # state file exists at all.
            self.log.debug('Moving state %s => %s', state_file_tmp, self.state_file)
            os.rename(state_file_tmp, self.state_file)

            self.log.info('Saved state in %0.4fs', time.time() - ts)

            os.unlink(lockfile)
            exit_code = 0
        except Exception:
            try:
                self.log.critical('Unable to save state: %s', traceback.format_exc())
            except Exception:
                pass
        finally:
            os._exit(exit_code)

    def load_state(self):
        self.stage = 'load_state_init'

        lockfile = self.state_file + '.lock'
        if os.path.exists(lockfile):
            lockfile_fp = open(lockfile, 'wb')
            try:
                fcntl.lockf(lockfile_fp, fcntl.LOCK_EX | fcntl.LOCK_NB)
            except IOError:
                self.log.info('State lock is being held, wait for it...')
                self.stage = 'load_state_wait'
                fcntl.lockf(lockfile_fp, fcntl.LOCK_EX)

            self.log.info('Acquired state lock')
        else:
            self.log.info('No state lock file -- just load state')

        self.stage = 'load_state'

        try:
            state = msgpack.load(open(self.state_file, 'rb'))
        except Exception as ex:
            self.log.info('Unable to load state: %s', ex)
            return

        if not isinstance(state, dict):
            self.log.warning('Unable to load state: is not a dictionary')
            return

        if state.get('version', None) != 1:
            self.log.warning('Unable to load state: dont know how to load version: %r', state.get('version', None))
            return

        self.trk.set_state(state['tracker'])


def main():
    """Command-line entry point of a shard process.

    Sets the process title, starts the parent watchdog thread, parses the
    command line, configures logging, builds and starts the shard and then
    waits for it to finish.  SIGINT triggers an orderly ``stop()``; any other
    unexpected exception terminates the process immediately with status 1.
    With ``--profile`` yappi profiling is enabled and its results are saved
    to ``/tmp/shard_<pid>.prof`` on exit.  The function always ends the
    process through ``os._exit`` (or by re-raising ``SystemExit``).
    """
    global BASE_PROCTITLE
    setproctitle(BASE_PROCTITLE)

    watchdog = threading.Thread(target=pwatch)
    watchdog.daemon = True
    watchdog.start()

    parser = argparse.ArgumentParser()
    add = parser.add_argument

    add('--sockname', required=True)
    add('--logprefix', default=str(os.getpid()))
    add('--log-path', default=None, required=False)
    add('--state-file', required=True)
    for flag in (
        '--interval-min', '--interval-max',
        '--interval-leech-min', '--interval-leech-max',
        '--return-peers',
    ):
        add(flag, type=int, required=True)
    add('--return-seeders-ratio', type=float, required=True)
    for flag in ('--save-state-every', '--save-state-rnd'):
        add(flag, type=int, required=True)
    add('--profile', action='store_true')
    add('--blocked-nets')

    args = parser.parse_args()
    log = setup_logger(args.logprefix, args.log_path)

    if args.profile:
        BASE_PROCTITLE += ' (PROF)'
        setproctitle(BASE_PROCTITLE)

    log.info('Start with args: %r', vars(args))

    # --blocked-nets is a ';'-separated list
    if args.blocked_nets is not None:
        blocked_nets = args.blocked_nets.split(';')
    else:
        blocked_nets = []

    shard = SkyboneCoordShard(
        state_file=args.state_file,
        sockname=args.sockname,
        interval=(args.interval_min, args.interval_max),
        interval_leech=(args.interval_leech_min, args.interval_leech_max),
        return_peers=args.return_peers,
        return_seeders_ratio=args.return_seeders_ratio,
        save_state_every=args.save_state_every,
        save_state_rnd=args.save_state_rnd,
        blocked_nets=blocked_nets,
        logger=log
    )

    # disable garbage collection so that gc.enable() in 3rd party modules
    # won't turn it back on
    gc.set_threshold(0)

    shard.start()

    if args.profile:
        import yappi
        yappi.set_clock_type('cpu')
        yappi.start(builtins=True)
        log.warning('ENABLED PROFILING')

    try:
        shard.join()
    except KeyboardInterrupt:
        # further Ctrl-C presses must not interrupt the shutdown sequence
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        sys.stderr.write('Caught SIGINT, stopping shard\n')
        log.info('Caught SIGINT, stopping shard')
        shard.stop()
        shard.join(ignore_sigint=True)
    except SystemExit as ex:
        if not ex.args:
            sys.stderr.write('Got SystemExit exception in main loop\n')
            log.warning('Got SystemExit exception in main loop')
        else:
            sys.stderr.write(str(ex) + '\n')
            log.warning(ex)
        raise
    except BaseException:
        trace = traceback.format_exc()
        sys.stderr.write('Unhandled exception: %s, exit immidiately!\n' % (trace, ))
        log.critical('Unhandled exception: %s, exit immidiately!', trace)
        os._exit(1)

    log.info('Stopped')

    if args.profile:
        profile_path = '/tmp/shard_%d.prof' % (os.getpid(), )
        yappi.get_func_stats().save(profile_path, type='callgrind')

    os._exit(0)
