from __future__ import absolute_import
from .session import Session
from .client import Client
from .metrics import CqueueMetrics
from .poll import Poll
from .handle import RHandle, _cqueue_unwrap
from .. import cfg, msgpackutils as msgpack
from ..utils import configure_log, monotime, log as root
from ..poll import Selectable
from ..mocksoul_rpc.server import Server as RPCServer, RPC as MRPC

from ya.skynet.library.auth.sign import FileKeysSignManager, ChainSignManager
from ya.skynet.library.auth.signagent import SignAgentClient
from ya.skynet.util.console import setProcTitle
from ya.skynet.util import pickle

import argparse
import tempfile
import os
import six
import sys
import stat
import gevent.select


def fix_hosts(hosts):
    """Return a copy of ``hosts`` whose non-string values are tuples.

    String values are copied unchanged; every other value is converted with
    ``tuple()`` (presumably to turn lists decoded from the wire into hashable
    tuples — confirm with callers). Keys are preserved.
    """
    fixed = {}
    for key, value in list(hosts.items()):
        if not isinstance(value, six.string_types):
            value = tuple(value)
        fixed[key] = value
    return fixed


class Daemon(object):
    """Thin-client daemon serving cqudp tasks over a unix-socket RPC server.

    The RPC endpoints ``run_task``, ``rpc_send``, ``stop_server``,
    ``stop_task`` and ``take_incoming`` are mounted on ``self.rpc`` and served
    by ``self.rpc_server`` bound to ``rpc_path``. Running task sessions
    (``DaemonSession`` objects) are kept in ``self.tasks`` keyed by task id.
    """

    def __init__(self, rpc_path, allow_no_keys=False):
        super(Daemon, self).__init__()
        # FIXME configurable set of clients is needed
        self.log = root().getChild('daemon')
        opts = {
            'session_class': ClientSession,
            'allow_no_keys': allow_no_keys,
            'select_function': gevent.select.select,
        }
        self.clients = {
            'msgpack': Client(transport='msgpack', **opts),
        }
        # The netlibus transport is optional: when it cannot be imported the
        # daemon runs with the msgpack transport only.
        try:
            self.clients['netlibus'] = Client(transport='netlibus', **opts)
        except ImportError:
            self.use_netlibus = False
        else:
            self.use_netlibus = True

        self.running = False
        # task id -> DaemonSession
        self.tasks = {}
        self.rpc_path = rpc_path
        self.rpc = MRPC(self.log.getChild('rpc'))
        self.rpc_server = RPCServer(
            log=self.log.getChild('rpcserver'),
            backlog=10,
            max_conns=100,
            unix=rpc_path,
        )
        self.rpc_server.register_connection_handler(self.rpc.get_connection_handler())
        self.rpc.mount(self.run_task, name='run_task', typ='full')
        self.rpc.mount(self.rpc_send, name='rpc_send')
        self.rpc.mount(self.stop_server, name='stop_server')
        self.rpc.mount(self.stop_task, name='stop_task')
        self.rpc.mount(self.take_incoming, name='take_incoming', typ='generator')

    def __del__(self):
        # If __init__ failed before the RPC server was created there is
        # nothing to stop; don't raise AttributeError from a finalizer.
        # (self.rpc is assigned before self.rpc_server, so checking the
        # latter covers both.)
        if getattr(self, 'rpc_server', None) is None:
            return
        self.stop_server()

    def create_signer(self, socket_path=None, peer_uid=None):
        """Build the signer used to authenticate a task.

        File-based keys from ``cfg.client.Auth.KeyFiles`` are always loaded.
        A sign agent at ``socket_path`` is chained after them only if the
        path is absolute, is a socket owned by ``peer_uid``, and lives in a
        directory owned by ``peer_uid`` that grants no permissions to
        group/others. If any check fails a warning is logged and only the
        file-key manager is returned.

        :param socket_path: optional path to the peer's sign-agent socket.
        :param peer_uid: uid of the requesting peer.
        :returns: ``FileKeysSignManager`` or ``ChainSignManager``.
        """
        fksm = FileKeysSignManager(
            commonKeyDirs=None,
            userKeyDirs=None,
            keyFiles=cfg.client.Auth.KeyFiles
        )
        fksm.load()

        if not socket_path:
            return fksm

        if not os.path.isabs(socket_path):
            self.log.warning("path must be absolute, got: %r", socket_path)
            return fksm

        # Single stat instead of exists()+stat() to avoid a check/use race.
        try:
            st = os.stat(socket_path)
        except (OSError, ValueError):
            self.log.warning("path %r doesn't exist", socket_path)
            return fksm

        if not stat.S_ISSOCK(st.st_mode):
            self.log.warning("path %r is not a socket", socket_path)
            return fksm

        dir_st = os.stat(os.path.dirname(socket_path))
        if (
            st.st_uid != peer_uid
            or dir_st.st_uid != peer_uid
            # the directory must not grant any permission to group/others
            or (dir_st.st_mode & 0o777) | 0o700 != 0o700
        ):
            self.log.warning("socket %r doesn't belong to uid %s or has wrong rights", socket_path, peer_uid)
            # Never use an agent the requesting peer does not own: it could
            # sign on behalf of another user.
            return fksm

        sa = SignAgentClient(socket_path)
        sa.load()

        return ChainSignManager([fksm, sa])

    def run_task(self, job, uuid, task, description, hosts, egg, user, socket_path=None, iter=False, netlibus=None):
        """RPC ``run_task``: start a task session (or reuse one) and return its id.

        :param job: RPC job; ``job.peer_id[0]`` is passed as ``peer_uid`` to
            :meth:`create_signer`.
        :param uuid: task uuid; a repeated uuid returns the existing session.
        :param task: task payload, packed together with ``egg`` via msgpack.
        :param hosts: host mapping, normalized with ``fix_hosts``.
        :param socket_path: optional sign-agent socket of the peer.
        :param netlibus: transport selector; ``None`` means the daemon default.
        """
        peer = job.peer_id
        hosts = fix_hosts(hosts)
        signer = self.create_signer(socket_path, peer[0])
        if netlibus is None:
            netlibus = self.use_netlibus
        options = dict(
            user=user,
            aggregate=True,  # forced, no change!
            pipeline=True,  # forced, no change!
            netlibus=netlibus,
            objdumped=True,
        )

        task = msgpack.dumps((task, egg))  # FIXME we're emulating serialize() task format here, it's wrong

        return self.execute(uuid, task, description, hosts, options, signer, iter=iter, netlibus=netlibus)

    def execute(self, uuid, task, description, hosts, options, signer, iter, netlibus=None):
        """Start a session on the selected transport unless ``uuid`` is already known.

        :returns: the id of the (new or existing) session.
        """
        if netlibus is None:
            netlibus = self.use_netlibus

        client = self.clients['netlibus' if netlibus else 'msgpack']

        # Plain lookup: setdefault() here used to leave a ``None`` placeholder
        # in self.tasks whenever start_session() raised or the session id
        # differed from uuid, which later broke rpc_send().
        session = self.tasks.get(uuid)
        if session is not None:
            return session.id

        session = DaemonSession(
            hosts,
            self,
            client.start_session(
                hosts,
                task,
                uuid=uuid,
                params=None,  # FIXME
                signer=signer,
                description=description,
                **options
            ),
        )
        self.tasks[session.id] = session

        return session.id

    def rpc_send(self, taskid, typ, uuid, msg, addrs, tree_data=None):
        """RPC ``rpc_send``: forward a message to task ``taskid``.

        Raises ``KeyError`` for an unknown ``taskid``.
        """
        session = self.tasks[taskid]
        session.rpc_send(typ, uuid, msg, addrs, tree_data=tree_data)
        return True

    def take_incoming(self, uuids, block=False, timeout=None):
        """RPC ``take_incoming``: yield ``(task_id, addr, label, msg)`` tuples.

        Unknown ids in ``uuids`` are skipped. With ``block`` false the poll
        does not wait; otherwise it waits up to ``timeout``.
        """
        # Build a concrete list: on Python 3 filter() returns a one-shot
        # iterator, unlike the list Python 2 produced.
        tasks = [task for task in (self.tasks.get(uuid) for uuid in uuids) if task]
        poller = Poll(
            tasks,
            select_function=gevent.select.select,
        )

        for task in poller.poll(timeout if block else 0):
            for addr, label, msg in task.poll(0):
                yield task.id, addr, label, msg

    def stop_server(self):
        """Stop the RPC layer and the server, then wait for both to finish."""
        self.rpc_server.stop()
        self.rpc.stop()
        self.rpc.join()
        self.rpc_server.join()

    def stop_task(self, taskid):
        """RPC ``stop_task``: forget and shut down a task; False if unknown."""
        task = self.tasks.pop(taskid, None)
        if not task:
            return False
        task.shutdown()
        return True

    def serve_forever(self):
        """Run the RPC server until it stops, SIGINT arrives, or it fails.

        Reentrant calls while already running return immediately. Failures
        are logged, not propagated; the server is stopped on exit.
        """
        if self.running:
            return
        try:
            self.log.info('thinclient is starting on %s', self.rpc_path)
            self.running = True
            self.rpc_server.start()
            self.rpc_server.join()
        except KeyboardInterrupt:
            self.log.info('serve_forever interrupted by SIGINT, stopping')
        except Exception as e:
            self.log.warning('serve_forever failed: %s', e, exc_info=sys.exc_info())
        finally:
            self.rpc_server.stop()
            self.rpc.stop()
            self.running = False


class ClientSession(Session):
    """Session that queues raw result payloads instead of unpacking them.

    Every incoming result/response is appended to ``self.results`` as an
    ``(addr, label, index, data)`` tuple and ``_notify()`` is called. When a
    ``'result'`` payload carries an error, the host is reported to the
    scheduler as done.
    """

    def handle_result(self, hostid, idx, data):
        # FIXME again loads
        _, err = self.loads(data)
        if err:
            self.scheduler.task_done(hostid, monotime())

        address = (hostid, self.hosts[hostid].actual_addr)
        entry = (address, 'result', idx, data)
        self.results.append(entry)

        self._notify()

    def handle_response(self, hostid, addr, label, index, data):
        is_result = label == 'result'
        if is_result:
            # FIXME duplicate load on this and client's side
            _, err = self.loads(data)
            if err:
                self.scheduler.task_done(hostid, monotime())
        entry = (addr, label, index, data)
        self.results.append(entry)
        self._notify()

    def make_results_handle(self):
        """Return a ``Handle`` reading from this session's result queue."""
        results = self.results
        return Handle(self, results)


class DaemonSession(Selectable):
    """Daemon-side wrapper around a client session.

    Delegates event-fd/readiness checks to the wrapped session, forwards RPC
    sends, and re-pickles ``'result'`` payloads for the thin client.
    ``client`` is the owner object passed at construction (``Daemon.execute``
    passes the daemon itself); it is dropped on shutdown.
    """

    def __init__(self, hosts, client, session):
        super(DaemonSession, self).__init__()
        # `hosts` is accepted for interface compatibility; it is not stored.
        self.client = client
        self.session = session
        self.handle = session.make_results_handle()
        host_count = len(session.hosts)
        self.metrics = CqueueMetrics(host_count)

    def _get_event_fd(self):
        return self.session._get_event_fd()

    def _is_data_ready(self):
        return self.session._is_data_ready()

    @property
    def id(self):
        """Task id of the wrapped session."""
        return self.session.taskid

    def rpc_send(self, typ, uuid, msg, addrs, tree_data):
        call_kwargs = dict(
            rpctype=typ,
            rpcid=uuid,
            data=msg,
            user_hosts=addrs,
            tree_data=tree_data,
        )
        return self.session.rpc_send(**call_kwargs)

    @property
    def running(self):
        """True while the wrapped session is not empty."""
        return not self.session.is_empty()

    def shutdown(self):
        """Stop the wrapped session once; later calls do nothing."""
        if self.client is None:
            return
        self.session.stop()
        self.client = None  # we no longer need to keep reference

    def poll(self, timeout=None):
        """Yield ``(addr, label, msg)``, hiding the index field of each entry.

        ``'result'`` payloads are decoded, fed to the metrics, stripped of
        metric data and re-pickled as ``(res, unwrapped_err)``.
        """
        for addr, label, _index, payload in self.handle.poll(timeout=timeout):
            if not self.running:
                self.shutdown()
            if label == 'result':
                payload = self._repack_result(payload)
            yield addr, label, payload

    def _repack_result(self, payload):
        # Decode a 'result' payload, record metrics, and re-pickle it.
        res, err = self.session.loads(payload)  # FIXME triple loads!
        self.metrics.update(err, res)
        if res is not None:
            res = res[0]  # drop metric data
        return pickle.dumps((res, _cqueue_unwrap(err)))


class Handle(RHandle):
    """Results handle that returns stored entries without unwrapping them."""

    def _pop_and_unwrap(self):
        # Entries come back exactly as stored; pop() removes the most
        # recently appended one.
        # NOTE(review): that is LIFO order — confirm RHandle expects it.
        results = self._results
        entry = results.pop()
        return entry


def parse_args(argv=None):
    """Parse command-line options for the thin-client daemon.

    :param argv: arguments to parse (without the program name). ``None``
        keeps the previous behaviour of reading ``sys.argv[1:]``.
    :returns: ``argparse.Namespace`` with ``debug``, ``loglevel`` and
        ``rpcpath`` attributes.
    """
    parser = argparse.ArgumentParser(description='cqudp client-server for thin client')
    parser.add_argument('-d', '--debug', action='store_true', default=False, help='Log to terminal for debugging')
    parser.add_argument('-l', '--loglevel',
                        default=cfg.thinclient.LogLevel,
                        choices=('DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'),
                        help='log messages level')
    parser.add_argument('--rpcpath',
                        default=cfg.thinclient.RpcPath.format(TempDir=tempfile.gettempdir()),
                        help='unix path to bind')

    return parser.parse_args(argv)


def main():
    """Entry point: set the process title, configure logging, run the daemon."""
    setProcTitle('cqudp-client')
    options = parse_args()
    configure_log(
        'thinclient',
        level=options.loglevel,
        debug=options.debug,
    )

    server = Daemon(options.rpcpath)
    server.serve_forever()


if __name__ == '__main__':
    # Run the thin-client daemon when this module is executed directly.
    main()
