#!/usr/bin/env python3

import argparse
import logging
import logging.handlers
import os
import socket
import subprocess as subproc
import sys
import threading
import time
import pwd

import yaml
import json

# Default config file and helper executable names; main() resolves the
# executables relative to the directory this script was started from.
DEFAULT_CFG_PATH = '/opt/genisys/etc/production.conf'
DEFAULT_TOILER_NAME, DEFAULT_WEB_NAME = 'genisys-run-toiler', 'genisys-run-wsgi'
DEFAULT_STATSITE_NAME = 'statsite'
DEFAULT_STATSITE_SOLOMON_SINC_NAME = 'genisys-statsite-solomon-sinc'

# Binary write handle to the null device, opened once at import time.
# NOTE(review): nothing in this file appears to use it -- confirm before removing.
DEVNULL = open('/dev/null', mode='wb')

# First port of each consecutive per-worker range: worker i (0-based) gets
# START + i, which is handed to the worker as WSGI_LOCK_PORT.
API_PORTS_START = 13250
UI_PORTS_START = 13350

# statsite listener ports, one per metrics source.
STATSITE_PORTS = (
    13450,  # toiler
    13451,  # api
    13452,  # ui
)


def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    :returns: an ``argparse.Namespace`` with the attributes ``config``,
        ``log``, ``quiet``, ``toiler_name``, ``web_name``, ``statsite_name``,
        ``statsite_solomon_sinc_name``, ``chuid`` and ``vardir``.
    """
    # Only a root invocation defaults to running children as 'nobody';
    # otherwise --chuid defaults to None (no uid switch is requested).
    chuid_default = 'nobody' if os.getuid() == 0 else None

    parser = argparse.ArgumentParser()
    for flags, options in (
        (('-c', '--config'), {'default': DEFAULT_CFG_PATH}),
        (('--log',), {'default': None, 'help': 'Log directory'}),
        (('-q', '--quiet'), {'action': 'store_true'}),
        (('--toiler-name',), {'default': DEFAULT_TOILER_NAME}),
        (('--web-name',), {'default': DEFAULT_WEB_NAME}),
        (('--statsite-name',), {'default': DEFAULT_STATSITE_NAME}),
        (('--statsite-solomon-sinc-name',), {'default': DEFAULT_STATSITE_SOLOMON_SINC_NAME}),
        (('--chuid',), {'default': chuid_default}),
        (('--vardir',), {'default': '/home/genisys/var'}),
    ):
        parser.add_argument(*flags, **options)

    return parser.parse_args()


def configure_logging(quiet, main_logfile, toiler_logfile, api_logfile, ui_logfile, statsite_logfile):
    """Set up logging for the daemon and its child processes.

    The root logger is set to DEBUG.  Four per-stream loggers are created under
    ``genisys.daemon.proc`` (``toiler``, ``web.api``, ``web.ui``, ``statsite``).
    They do not propagate to the root logger, so child output is written only
    to that logger's own handlers.  Every handler logs at DEBUG with the format
    ``"<asctime>  <logger name>  <message>"``.

    :param quiet: when true, no console (stderr) handler is installed.
    :param main_logfile: file for the root logger, or a false value for none.
    :param toiler_logfile: file for ``genisys.daemon.proc.toiler``, or none.
    :param api_logfile: file for ``genisys.daemon.proc.web.api``, or none.
    :param ui_logfile: file for ``genisys.daemon.proc.web.ui``, or none.
    :param statsite_logfile: file for ``genisys.daemon.proc.statsite``, or none.
    :returns: the ``genisys.daemon`` logger.
    """
    root = logging.getLogger('')
    root.setLevel(logging.DEBUG)

    formatter = logging.Formatter('%(asctime)s  %(name)s  %(message)s')

    def _prepare(handler):
        # All handlers share the same level and format.
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(formatter)
        return handler

    log = root.getChild('genisys.daemon')

    # (logger, logfile) for each child-process output stream.
    streams = [
        (log.getChild('proc.toiler'), toiler_logfile),
        (log.getChild('proc.web.api'), api_logfile),
        (log.getChild('proc.web.ui'), ui_logfile),
        (log.getChild('proc.statsite'), statsite_logfile),
    ]
    for stream_log, _ in streams:
        # Keep child output away from the root logger's handlers (e.g. main.log).
        stream_log.propagate = False

    if not quiet:
        # Non-propagating loggers need the console handler attached directly.
        console = _prepare(logging.StreamHandler())
        for target in [root] + [stream_log for stream_log, _ in streams]:
            target.addHandler(console)

    if main_logfile:
        root.addHandler(_prepare(logging.handlers.WatchedFileHandler(main_logfile)))

    for stream_log, logfile in streams:
        if logfile:
            stream_log.addHandler(_prepare(logging.handlers.WatchedFileHandler(logfile)))

    return log


class ProcMan(object):
    """Supervise child processes, restarting each one whenever it exits.

    run() starts one daemon thread per requested copy of a command.  Each
    thread launches the command with stdin/stdout/stderr connected to pipes,
    forwards every output line to a logger at DEBUG level, and relaunches the
    command after it exits, until stop() is called.
    """

    def __init__(self, user, log, restart_delay=1.0):
        """Create a manager whose children optionally run as *user*.

        :param user: account name for the children (resolved with
            ``pwd.getpwnam``), or a false value to keep the current uid.
        :param log: parent logger; this manager logs under its ``proc`` child.
        :param restart_delay: seconds to wait before relaunching a command that
            exited or failed to start, so a crashing child cannot spin in a
            tight fork/exec loop.  0 restarts immediately.
        :raises KeyError: if *user* does not exist (from ``pwd.getpwnam``).
        """
        self.log = log.getChild('proc')
        self.stopflag = threading.Event()
        self.restart_delay = restart_delay

        self.process_threads = []

        self.uid = pwd.getpwnam(user).pw_uid if user else None

    def _preexec(self):
        """Run in the forked child before exec: switch to the configured uid."""
        # A uid of 0 (root) is falsy and therefore treated as "no switch".
        # NOTE(review): only the uid changes; the child keeps the parent's gid and
        # supplementary groups (root's when started as root). Consider
        # os.setgroups()/os.setgid() before setuid once it is confirmed that the
        # children do not rely on the inherited groups.
        # NOTE(review): preexec_fn is documented as unsafe in multithreaded
        # programs; Popen(user=...) (Python >= 3.9) is the safer alternative.
        if self.uid:
            os.setuid(self.uid)

    def _runner(self, stdin, args, env, log):
        """Thread body: keep *args* running until stop() is called."""
        while not self.stopflag.is_set():
            try:
                proc = subproc.Popen(
                    args,
                    stdin=subproc.PIPE, stderr=subproc.PIPE, stdout=subproc.PIPE,
                    env=env,
                    preexec_fn=self._preexec,
                )
            except (OSError, subproc.SubprocessError):
                # E.g. a missing binary or a failing setuid: report it and retry
                # later instead of letting this thread die silently.
                log.exception('failed to start %r', args)
            else:
                self._supervise(proc, stdin, log)

            # Throttle restarts; returns immediately once stop() sets the flag.
            self.stopflag.wait(self.restart_delay)

    def _supervise(self, proc, stdin, log):
        """Feed *stdin* to *proc*, pump its output to *log*, and wait for it.

        Returns when the child exits or stop() is requested.  In both cases the
        child is killed if still alive, reaped, and all of its pipes are closed.
        """
        stream_threads = []

        try:
            try:
                if stdin:
                    proc.stdin.write(stdin.encode('utf-8'))
                proc.stdin.close()
            except BrokenPipeError:
                # The child is gone or closed its stdin early; keep supervising
                # and let wait() report how it ended.
                log.debug('child closed stdin before reading all input')

            for stream in (proc.stdout, proc.stderr):
                thr = threading.Thread(target=self._stream_log, args=(stream, log))
                thr.start()
                stream_threads.append(thr)

            log.debug('started, waiting to finish/die')

            # Wake up every second so a stop() request is noticed promptly.
            while not self.stopflag.is_set():
                try:
                    result = proc.wait(timeout=1)
                except subproc.TimeoutExpired:
                    continue
                log.debug('finished with %r', result)
                break
            else:
                log.debug('shouldstop flag set, killing child...')
        finally:
            # poll() is None only while the child is still running; an exit
            # status of 0 is falsy and must not be mistaken for "running".
            if proc.poll() is None:
                try:
                    proc.kill()
                except OSError:
                    pass

            proc.wait()
            for thr in stream_threads:
                thr.join()

            # wait() does not close the pipes; close them so that every restart
            # does not leak file descriptors.
            for pipe in (proc.stdin, proc.stdout, proc.stderr):
                try:
                    pipe.close()
                except OSError:
                    pass

    def _stream_log(self, stream, log, prefix=None):
        """Log every line of the binary *stream* at DEBUG until EOF.

        :param prefix: optional label; when given each line is logged as
            ``"<prefix>: <line>"``.
        """
        # Computed once; rebuilding it per line would stack ': ' onto it.
        prefix = '%s: ' % (prefix, ) if prefix else ''

        while True:
            line = stream.readline()
            if not line:
                break

            try:
                logline = line.decode('utf-8').rstrip()
            except UnicodeDecodeError:
                logline = 'unable to utf-8 decode: %r' % (line, )

            try:
                log.debug('%s%s', prefix, logline)
            except Exception:
                # A logging failure must not end this thread: if nobody drains
                # the pipe, the child eventually blocks writing to it.
                import traceback
                sys.stderr.write(traceback.format_exc())

                try:
                    log.critical('%s%s', prefix, 'unable to log (see stderr)')
                except Exception:
                    pass

    def run(self, name, args, count, stdin=None, env=None, log=None):
        """Start *count* supervised copies of *args*, each in a daemon thread.

        :param name: logger suffix for this group; when *count* > 1 copy *idx*
            (starting at 1) logs under ``<name>.<idx>``.
        :param args: argv list passed to ``subprocess.Popen``.
        :param count: number of copies to run.
        :param stdin: optional text written (UTF-8) to each child's stdin,
            which is then closed.
        :param env: optional mapping merged over ``os.environ`` for the
            children; when empty or None they inherit the current environment.
        :param log: base logger for the group; defaults to this manager's
            ``proc`` logger.
        """
        base_log = (log if log is not None else self.log).getChild(name)

        if env:
            environ = os.environ.copy()
            environ.update(env)
        else:
            environ = None

        for idx in range(1, count + 1):
            worker_log = base_log.getChild(str(idx)) if count > 1 else base_log

            thr = threading.Thread(target=self._runner, args=(stdin, args, environ, worker_log))
            thr.daemon = True
            thr.start()

            self.process_threads.append(thr)

    def stop(self):
        """Ask every runner to stop (killing its child) and wait for them."""
        self.stopflag.set()

        for thr in self.process_threads:
            thr.join()


def run_statsite(procman, vardir, statsite_binary, statsite_sinc_binary, cfg, log):
    """Start one statsite instance each for toiler, api and ui metrics.

    For every service an INI file ``statsite_<service>.ini`` is written to
    *vardir*.  The instance listens (TCP and UDP) on the matching
    STATSITE_PORTS entry and flushes every 10 s through ``stream_cmd``.
    ``stream_cmd`` runs *statsite_sinc_binary* with a JSON label set
    (project, cluster, service, host) and the JSON-encoded solomon headers.

    :param procman: ProcMan used to launch and supervise the instances.
    :param statsite_binary: path of the statsite executable.
    :param statsite_sinc_binary: path of the solomon sink used as stream_cmd.
    :param cfg: the ``statsite`` section of the daemon config; reads
        ``cluster`` and ``solomon_headers``.
    """
    import shlex

    log.info('Running statsite for toilers')

    config_tpl = '''\
[statsite]
port = {port}
udp_port = {port}
extended_counters = 1
extended_counters_include = lower,upper,sum
timers_include = count,mean,stdev,sum,sum_sq,lower,upper
flush_interval = 10
use_type_prefix = 0
stream_cmd = {stream_cmd}
'''

    cluster = str(cfg['cluster'])
    hostname = socket.gethostname()
    # stream_cmd is presumably run by statsite through a shell (its arguments
    # are shell-quoted).  json.dumps + shlex.quote keep it valid even when a
    # value contains quotes or braces; for ordinary values it still reads
    # <sinc> -c '{"project": ...}' -H '<headers json>'.
    headers_arg = shlex.quote(json.dumps(cfg['solomon_headers']))

    instances = []
    for service, port in zip(('toiler', 'api', 'ui'), STATSITE_PORTS):
        labels = json.dumps({
            'project': 'genisys',
            'cluster': cluster,
            'service': service,
            'host': hostname,
        })
        stream_cmd = '%s -c %s -H %s' % (
            shlex.quote(statsite_sinc_binary), shlex.quote(labels), headers_arg,
        )

        config_fn = os.path.join(vardir, 'statsite_%s.ini' % (service, ))
        with open(config_fn, 'w') as config_file:
            config_file.write(config_tpl.format(port=port, stream_cmd=stream_cmd))

        instances.append((service, port, config_fn))

    # All config files are written before any instance is started.
    for service, port, config_fn in instances:
        procman.run(
            'statsite.%s.%d' % (service, port), [statsite_binary, '-f', config_fn], 1
        )


def run_toilers(procman, vardir, toiler_binary, cfg, log):
    """Write the toiler config and start ``cfg['toiler']['workers']`` toilers.

    The config goes to ``<vardir>/genisys_toiler.cfg``; its path is passed to
    every toiler through the GENISYS_TOILER_CONFIG environment variable.
    Metrics are sent to the toiler statsite port (STATSITE_PORTS[0]).
    Nothing is started when the worker count is 0 or empty.

    :param procman: ProcMan used to launch and supervise the toilers.
    :param toiler_binary: path of the toiler executable.
    :param cfg: full daemon config; reads mongo.uri, web.api.flask.secret_key,
        mail.smtp.host/port and toiler.workers/staff_headers.
    """
    workers_count = cfg['toiler']['workers']
    if not workers_count:
        log.info('Will run 0 toilers (as wanted by config)')
        return

    log.info('Will run %d toilers', workers_count)

    # The file consists of Python-style assignments (presumably loaded as a
    # config module -- confirm).  Text values are rendered with repr(str(v)) so
    # a quote or backslash (e.g. in the secret key) cannot break the syntax;
    # ordinary values still come out as plain '...' literals.
    toiler_config = '''\
MONGODB_URI = [{mongo_uri!r}]
SECRET_KEY = {secret_key!r}
STATSD_HOST = '127.0.0.1'
STATSD_PORT = {statsd_port}
SMTP_SERVER = {smtp_server!r}
SMTP_PORT = {smtp_port}
STAFF_HEADERS = {staff_headers!r}
'''.format(
        mongo_uri=str(cfg['mongo']['uri']),
        secret_key=str(cfg['web']['api']['flask']['secret_key']),
        statsd_port=STATSITE_PORTS[0],
        smtp_server=str(cfg['mail']['smtp']['host']),
        smtp_port=cfg['mail']['smtp']['port'],
        staff_headers=cfg['toiler']['staff_headers'],
    )

    toiler_conf_fn = os.path.join(vardir, 'genisys_toiler.cfg')
    with open(toiler_conf_fn, 'w') as conf_file:
        conf_file.write(toiler_config)

    procman.run('toiler', [toiler_binary], workers_count, env={'GENISYS_TOILER_CONFIG': toiler_conf_fn})


def run_wsgi_api(procman, vardir, web_binary, cfg, log):
    """Write per-worker configs and start ``cfg['web']['api']['workers']`` api backends.

    Worker *i* (1-based) listens on the unix socket ``/tmp/genisys_api_<i>.sock``
    and gets WSGI_LOCK_PORT ``API_PORTS_START + i - 1``.  Its config is written
    to ``<vardir>/genisys_api_<i>.cfg`` and passed via GENISYS_API_CONFIG.
    Metrics go to the api statsite port (STATSITE_PORTS[1]).

    :param procman: ProcMan used to launch and supervise the workers.
    :param web_binary: path of the WSGI runner, invoked as ``<binary> api --gevent``.
    :param cfg: full daemon config; reads mongo.uri, web.blackbox_uri and
        web.api.workers/flask.secret_key.
    :returns: list of the workers' unix socket paths (empty when no workers
        are configured).
    """
    workers_count = cfg['web']['api']['workers']
    if not workers_count:
        log.info('Will run 0 api backends')
        # An empty list, not None: main() passes the result straight to
        # add_nginx_upstream(), which iterates it.
        return []

    api_ports = list(range(API_PORTS_START, API_PORTS_START + workers_count))

    log.info('Will run %d api backends', len(api_ports))

    all_sockets = []

    for worker_no, port in enumerate(api_ports, start=1):
        uds = '/tmp/genisys_api_%d.sock' % (worker_no, )
        # Text values are rendered with repr(str(v)) so quotes or backslashes
        # cannot break the generated Python-style config.
        web_config = '''\
MONGODB_URI = [{mongo_uri!r}]
BLACKBOX_URI = {blackbox_uri!r}
SECRET_KEY = {secret_key!r}
STATSD_HOST = '127.0.0.1'
STATSD_PORT = {statsd_port}
PROPAGATE_EXCEPTIONS = 1
WSGI_LISTEN_ADDRESS = {uds!r}
WSGI_LOCK_PORT = {port}
'''.format(
            mongo_uri=str(cfg['mongo']['uri']),
            blackbox_uri=str(cfg['web']['blackbox_uri']),
            secret_key=str(cfg['web']['api']['flask']['secret_key']),
            statsd_port=STATSITE_PORTS[1],
            port=port,
            uds=uds,
        )

        worker_conf_fn = os.path.join(vardir, 'genisys_api_%d.cfg' % (worker_no, ))
        with open(worker_conf_fn, 'w') as conf_file:
            conf_file.write(web_config)

        all_sockets.append(uds)

        procman.run(
            'web.api.%d' % (worker_no, ), [web_binary, 'api', '--gevent'], 1,
            env={
                'GENISYS_API_CONFIG': worker_conf_fn,
            }
        )

    return all_sockets


def run_wsgi_ui(procman, vardir, web_binary, cfg, log):
    """Write per-worker configs and start ``cfg['web']['ui']['workers']`` ui backends.

    Worker *i* (1-based) listens on the unix socket ``/tmp/genisys_ui_<i>.sock``
    and gets WSGI_LOCK_PORT ``UI_PORTS_START + i - 1``.  Its config is written
    to ``<vardir>/genisys_ui_<i>.cfg`` and passed via GENISYS_WEB_CONFIG.
    Metrics go to the ui statsite port (STATSITE_PORTS[2]).

    :param procman: ProcMan used to launch and supervise the workers.
    :param web_binary: path of the WSGI runner, invoked as ``<binary> ui --gevent``.
    :param cfg: full daemon config; reads mongo.uri, web.blackbox_uri and
        web.ui.workers/staff_headers/flask.secret_key/flask.force_url_scheme.
    :returns: list of the workers' unix socket paths (empty when no workers
        are configured).
    """
    workers_count = cfg['web']['ui']['workers']
    if not workers_count:
        log.info('Will run 0 ui backends')
        # An empty list, not None: main() passes the result straight to
        # add_nginx_upstream(), which iterates it.
        return []

    app_ports = list(range(UI_PORTS_START, UI_PORTS_START + workers_count))

    log.info('Will run %d app backends', len(app_ports))

    all_sockets = []
    for worker_no, port in enumerate(app_ports, start=1):
        uds = '/tmp/genisys_ui_%d.sock' % (worker_no, )
        # Text values are rendered with repr(str(v)) so quotes or backslashes
        # cannot break the generated Python-style config.
        web_config = '''\
MONGODB_URI = [{mongo_uri!r}]
BLACKBOX_URI = {blackbox_uri!r}
SECRET_KEY = {secret_key!r}
STAFF_HEADERS = {staff_headers!r}
FORCE_URL_SCHEME = {force_url_scheme!r}
STATSD_HOST = '127.0.0.1'
STATSD_PORT = {statsd_port}
PROPAGATE_EXCEPTIONS = 1
WSGI_LISTEN_ADDRESS = {uds!r}
WSGI_LOCK_PORT = {port}
GENISYS_API_URI = 'http://genisys.yandex-team.ru'
'''.format(
            mongo_uri=str(cfg['mongo']['uri']),
            blackbox_uri=str(cfg['web']['blackbox_uri']),
            secret_key=str(cfg['web']['ui']['flask']['secret_key']),
            staff_headers=cfg['web']['ui']['staff_headers'],
            force_url_scheme=str(cfg['web']['ui']['flask']['force_url_scheme']),
            statsd_port=STATSITE_PORTS[2],
            port=port,
            uds=uds,
        )

        worker_conf_fn = os.path.join(vardir, 'genisys_ui_%d.cfg' % (worker_no, ))
        with open(worker_conf_fn, 'w') as conf_file:
            conf_file.write(web_config)

        all_sockets.append(uds)
        procman.run(
            'web.ui.%d' % (worker_no, ), [web_binary, 'ui', '--gevent'], 1,
            env={
                'GENISYS_WEB_CONFIG': worker_conf_fn,
            }
        )

    return all_sockets


def add_nginx_upstream(api_sockets, ui_sockets,
                       sites_dir='/etc/nginx/sites-enabled',
                       reload_cmd='service nginx reload'):
    """Publish the backend sockets as nginx upstreams and reload nginx.

    Writes ``genisys_upstream_api.conf`` and ``genisys_upstream_ui.conf`` into
    *sites_dir*.  They declare the upstreams ``genisys_api`` / ``genisys_ui``
    with one ``server unix:<path> fail_timeout=1;`` line per socket.
    Afterwards *reload_cmd* is run through the shell.  Does nothing at all
    when *sites_dir* does not exist.

    :param api_sockets: unix socket paths of the api workers; None means none.
    :param ui_sockets: unix socket paths of the ui workers; None means none.
    :param sites_dir: directory nginx includes site configs from.
    :param reload_cmd: shell command run after writing; a false value skips it.
    """
    if not os.path.exists(sites_dir):
        return

    # NOTE(review): nginx may reject an upstream with no servers -- confirm
    # what should happen when a worker group is configured with 0 workers.
    upstreams = (
        ('genisys_api', 'genisys_upstream_api.conf', api_sockets),
        ('genisys_ui', 'genisys_upstream_ui.conf', ui_sockets),
    )
    for upstream_name, filename, sockets in upstreams:
        servers = '\n'.join(
            '  server unix:%s fail_timeout=1;' % (path, ) for path in sockets or ()
        )
        with open(os.path.join(sites_dir, filename), 'w') as conf_file:
            conf_file.write('upstream %s {\n%s\n}\n' % (upstream_name, servers))

    if reload_cmd:
        os.system(reload_cmd)


def main():
    """Entry point: start statsite, toilers and web workers, then idle.

    Children are supervised by a ProcMan until the process is interrupted;
    on the way out (Ctrl-C or any other exception) they are stopped.
    """
    args = parse_args()

    def _logfile(basename):
        # Per-stream log files live in the --log directory; none without it.
        return os.path.join(args.log, basename) if args.log else None

    log = configure_logging(
        args.quiet,
        _logfile('main.log'),
        _logfile('toiler.log'),
        _logfile('api.log'),
        _logfile('ui.log'),
        _logfile('statsite.log'),
    )

    # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
    # and a TypeError since 6.0.  safe_load assumes the config uses only
    # standard YAML tags -- confirm if it ever relies on Python-object tags.
    with open(args.config, mode='rb') as cfg_file:
        cfg = yaml.safe_load(cfg_file)

    procman = ProcMan(args.chuid, log)

    # exist_ok avoids the race of a separate exists() check.
    os.makedirs(args.vardir, exist_ok=True)

    # Helper executables are looked up next to this script.
    bindir = os.path.dirname(sys.argv[0])
    statsite_binary = os.path.join(bindir, args.statsite_name)
    statsite_sinc_binary = os.path.join(bindir, args.statsite_solomon_sinc_name)
    toiler_binary = os.path.join(bindir, args.toiler_name)
    web_binary = os.path.join(bindir, args.web_name)

    run_statsite(procman, args.vardir, statsite_binary, statsite_sinc_binary, cfg['statsite'], log)
    run_toilers(procman, args.vardir, toiler_binary, cfg, log)
    api_sockets = run_wsgi_api(procman, args.vardir, web_binary, cfg, log)
    ui_sockets = run_wsgi_ui(procman, args.vardir, web_binary, cfg, log)

    add_nginx_upstream(api_sockets, ui_sockets)

    try:
        while True:
            time.sleep(60)
    except KeyboardInterrupt:
        log.info('interrupted, stopping child processes')
    finally:
        # Always stop (and kill) the children, whatever ends the idle loop.
        procman.stop()


if __name__ == '__main__':
    main()
