import os
import sys
import traceback
from socket import gethostname, getfqdn
import argparse
import errno
import contextlib
import fcntl

import gevent
from gevent import socket

from kernel.util.console import setProcTitle
from kernel.util.errors import formatException
from kernel.util.sys.user import userPrivileges

from api.logger import constructLogger

from .cauth import CauthKeyUpdater


# Module-wide logger, built with app='keychain' and filename='keychain.log'.
log = constructLogger(filename='keychain.log', app='keychain')


class Server(object):
    """Lifecycle wrapper for the keychain service.

    Resolves the host name on construction. When used as a context manager it
    logs start/stop and, on an abnormal stop, the reason. Exceptions are never
    suppressed (``__exit__`` returns None).
    """

    def __init__(self):
        hostname = gethostname()
        # A short (dot-less) name is replaced by the fully-qualified one.
        self.host = hostname if '.' in hostname else getfqdn()

    def __enter__(self):
        log.info('Started')
        log.info('Current host: "{0}"'.format(self.host))
        return self

    def __exit__(self, exc_type, exc_value, tb):
        if exc_value is not None:
            log.error(self._halt_message(exc_type, exc_value, tb))
        log.info('Stopped')

    @staticmethod
    def _halt_message(exc_type, exc_value, tb):
        """Build the error line explaining why the service is halting."""
        if exc_type is KeyboardInterrupt:
            return 'Caught SIGINT. Halting service now'
        formatted_tb = ''.join(traceback.format_tb(tb))
        template = 'Unhandled exception in Server:\n{0}\n{1}\nHalting service now'
        return template.format(exc_value, formatted_tb)

    def loop(self):
        """Idle forever, sleeping 60 seconds per iteration via gevent."""
        while True:
            gevent.sleep(60)


@contextlib.contextmanager
def lock(path):
    """Hold an exclusive, non-blocking ``flock`` on *path* for the ``with`` body.

    The lock file is opened for writing, which creates or truncates it. If its
    parent directory is missing, the directory is created first. The file is
    closed when the ``with`` block exits, normally or via an exception, and
    closing it releases the flock.

    :param path: path of the lock file.
    :raises Exception: if the lock is already held through another open file.
    :raises IOError/OSError: if the file or its directory cannot be created.
    """
    try:
        fp = open(path, 'wb')
    except IOError as ex:
        if ex.errno != errno.ENOENT:
            raise
        try:
            os.makedirs(os.path.dirname(path))
        except OSError as mkdir_ex:
            # Another process may have created the directory concurrently.
            if mkdir_ex.errno != errno.EEXIST:
                raise
        fp = open(path, 'wb')

    try:
        try:
            fcntl.flock(fp, fcntl.LOCK_EX | fcntl.LOCK_NB)
        except IOError:
            raise Exception('Lock %s is held by somebody else' % (path, ))

        yield
    finally:
        # Closing the descriptor releases the flock held on it; this also
        # avoids leaking the file when flock() fails.
        fp.close()


def ping_responder(sock, client_timeout=10.0):
    """Serve a line-based liveness check on the bound stream socket *sock*.

    Clients are handled one at a time. For each connection, bytes are read one
    at a time until a newline, EOF, or a short length limit. If the line is
    exactly ``PING``, the reply ``OKAY`` plus a newline is sent. Every
    connection is then closed. This function never returns.

    :param sock: bound stream socket to listen on.
    :param client_timeout: seconds a single client may stall before it is
        dropped (``None`` disables the timeout). Without it, one idle client
        would block every later ping indefinitely.
    """
    max_line = 64  # "PING" is 4 bytes; anything much longer is garbage
    sock.listen(8)

    while True:
        conn, _peer = sock.accept()
        try:
            conn.settimeout(client_timeout)
            buff = []
            while len(buff) < max_line:
                ch = conn.recv(1)
                # Bytes literals equal str on Python 2 and match the bytes
                # returned by recv() on Python 3.
                if ch == b'\n' or not ch:
                    break
                buff.append(ch)

            if b''.join(buff) == b'PING':
                conn.sendall(b'OKAY\n')
                conn.shutdown(socket.SHUT_RDWR)
        except socket.error as ex:
            # A misbehaving client (reset, timeout, early close) must not
            # kill the responder loop.
            log.error('Ping responder client error: {0}'.format(ex))
        finally:
            conn.close()


def main():
    """Entry point of the keychain service.

    Parses ``--sock`` and ``--lock``, takes the lock via ``lock(args.lock)``,
    binds a UNIX stream socket at ``--sock``, and serves it with
    ``ping_responder`` in a gevent greenlet. When started as uid 0 on POSIX,
    it enters ``userPrivileges(user, limit=True)`` for the user named by
    ``SKYNET_PROCMANUSER`` (default ``'skynet'``). It then sets the process
    title and umask and runs ``Server.loop()`` inside the ``cauth()`` context
    of a ``CauthKeyUpdater``.

    Any exception that escapes the setup/server section (including failure to
    take the lock) is logged and ``1`` is returned; the socket and lock files
    are not removed on that path. Otherwise both files are unlinked (errors
    ignored) and the process ends via ``os._exit(result)``, so the function
    does not return on that path.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--sock', required=True, help='socket path')
    parser.add_argument('--lock', required=True, help='socket lock file path')
    args = parser.parse_args()

    # Exit code used on the normal (cleanup + os._exit) path.
    result = 0

    try:
        # lock() raises if the lock is already held; that lands in the outer
        # except below.
        with lock(args.lock):
            # Liveness socket: ping_responder answers "PING\n" with "OKAY\n".
            lock_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            try:
                lock_socket.bind(args.sock)
            except socket.error as ex:
                # Retry the bind once after handling one of two causes: a
                # missing parent directory, or an existing socket file
                # (presumably stale, since the lock is held).
                if ex.errno == errno.ENOENT:
                    os.makedirs(os.path.dirname(args.sock))
                elif ex.errno == errno.EADDRINUSE:
                    os.unlink(args.sock)
                else:
                    raise

                lock_socket.bind(args.sock)

            gevent.spawn(ping_responder, lock_socket)

            # True means the process *started* as root; after the
            # userPrivileges switch below it presumably no longer has root
            # rights (TODO confirm against userPrivileges).
            has_root_privileges = False
            if os.name == 'posix' and os.getuid() == 0:
                user = os.getenv('SKYNET_PROCMANUSER', 'skynet')
                priv = userPrivileges(user, limit=True)
                has_root_privileges = True
                # Entered by hand; this function has no matching __exit__ call.
                priv.__enter__()

            server_name = 'skynet.keychain'

            setProcTitle(server_name)
            os.umask(0)  # we will control mode explicitly

            if has_root_privileges:
                log.info('Started with ROOT')
            else:
                log.info('Started without ROOT')

            with Server() as server:
                cauth = CauthKeyUpdater(log)
                with cauth():
                    try:
                        server.loop()
                    except KeyboardInterrupt:
                        log.error('Keyboard interrupt - exit now')
                        result = 1
                    except BaseException as ex:
                        # Logged and not re-raised: execution continues to
                        # the file cleanup below with exit code 1.
                        log.error('Server error: "{0}"'.format(formatException(ex)))
                        result = 1

    except BaseException as ex:
        # Setup failures (including a held lock) end here. The sock/lock files
        # are not removed on this path, so another holder's files are not
        # deleted.
        log.error('Server error: "{0}"'.format(formatException(ex)))
        return 1

    # Clean up sock + lock files and exit
    # NOTE(review): the `with lock(...)` block has already exited here, so a
    # new instance could take the lock and bind the socket before these
    # unlinks run; consider removing the files while the lock is still held.
    try:
        try:
            os.unlink(args.sock)
        except:
            pass

        try:
            os.unlink(args.lock)
        except:
            pass
    finally:
        # os._exit terminates immediately without normal interpreter shutdown
        # (no atexit handlers, no stdio flush).
        os._exit(result)


def control():
    """Send a one-line command to a running server over its UNIX socket.

    Usage: ``<prog> <command> <socket path>`` (taken from ``sys.argv``). The
    command is upper-cased, newline-terminated, and sent. The reply is read up
    to the first newline or EOF. The whole exchange is bounded by a 30-second
    ``gevent.Timeout``.

    :returns: 0 if the reply line is ``OKAY``; otherwise 1. On timeout,
        ``Timeout: ...`` is also written to stderr. Socket errors (for example,
        connecting to a missing socket) propagate to the caller.
    """
    cmd = sys.argv[1]
    sockfile = sys.argv[2]

    request = cmd.upper() + '\n'
    if not isinstance(request, bytes):
        # Python 3: sockets carry bytes, not text (on Python 2 str is bytes).
        request = request.encode('utf-8')

    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        with gevent.Timeout(30) as tout:
            try:
                sock.connect(sockfile)
                sock.sendall(request)
                buff = []
                while True:
                    ch = sock.recv(1)
                    if ch == b'\n' or not ch:
                        break
                    buff.append(ch)

                if b''.join(buff) == b'OKAY':
                    return 0

            except gevent.Timeout as ex:
                # Only our own timeout is handled; foreign timeouts propagate.
                if ex is tout:
                    sys.stderr.write('Timeout: %s\n' % (ex, ))
                    return 1
                raise
    finally:
        sock.close()

    return 1
