from __future__ import absolute_import

import sys
import time
import socket

import six

from ya.skynet.util import logging
from ya.skynet.util.sys.user import getUserName
from ya.skynet.library.portoshell import make_message, sign_message
from ya.skynet.library.auth.sign import FileKeysSignManager, ChainSignManager
from ya.skynet.library.auth.signagent import SignAgentClient

from ..exceptions import CQueueRuntimeError
from ..utils import log as root, reconfigure_log, poll_select, getaddrinfo, gencid, fqdn
from ..rpc.pipe import Pipe
from .. import cfg
from .client import Client
from .session import ResultsQueue, MultiTaskSession, PortoshellSession
from .metrics import CqueueMetrics
from .poll import Poll
from ..poll import Selectable


# True when ``sys.is_standalone_binary`` is set to a truthy value (presumably by the
# Arcadia standalone-binary build -- TODO confirm); False otherwise.
IN_ARCADIA = True if getattr(sys, 'is_standalone_binary', False) else False


class CqueueClient(object):
    """High-level cqueue client that starts remote sessions over a low-level ``Client``.

    All keyword arguments are optional:

    * ``aggregate``, ``pipeline``, ``msgpack``, ``netlibus``, ``send_own_ip``:
      transport flags, defaulting to the matching ``cfg.client.Transport`` values;
    * ``debuglog``: enable debug logging (default ``False``);
    * ``log``: logger to use instead of the ``client`` child of the root cqueue log;
    * ``port_range``: a list/tuple of ports, or any other truthy value to use
      ``cfg.client.Transport.SpecialPortRange``;
    * ``signer``: sign manager; a default one is built by ``create_signer``;
    * ``native_python``: use the interpreter-specific task type
      (``task_py2``/``task_py3``) instead of ``task``.

    Flag values given as strings are true only when they equal ``'true'``
    (case-insensitive, surrounding whitespace ignored). Every remaining keyword
    is stored in ``extra_args`` and forwarded to the underlying sessions. This
    class itself reads ``loglevel``, ``ip``, ``default_port`` and
    ``select_function`` from ``extra_args``.
    """

    def __init__(self, **kwargs):
        def pop_flag(name, default):
            # Strings (e.g. from config files or CLI) are coerced to bool;
            # any other value is returned unchanged.
            flag = kwargs.pop(name, default)
            if isinstance(flag, six.string_types):
                flag = (flag.lower().strip() == 'true')
            return flag

        self.extra_args = {
            'aggregate': pop_flag('aggregate', cfg.client.Transport.Aggregate),
            'pipeline': pop_flag('pipeline', cfg.client.Transport.Pipeline),
            'msgpack': pop_flag('msgpack', cfg.client.Transport.Msgpack),
            'netlibus': pop_flag('netlibus', cfg.client.Transport.Netlibus),
            'debuglog': pop_flag('debuglog', False),
            'send_own_ip': pop_flag('send_own_ip', cfg.client.Transport.SendOwnIp),
        }

        log = kwargs.pop('log', None)
        logger = log if log is not None else root().getChild('client')

        # A truthy value that is not an explicit list/tuple of ports selects
        # the configured special port range.
        port_range = pop_flag('port_range', None)
        if port_range and not isinstance(port_range, (list, tuple)):
            port_range = cfg.client.Transport.SpecialPortRange
        self.extra_args['port_range'] = port_range

        signer = kwargs.pop('signer', None)

        system_task_type = pop_flag('native_python', False)
        if system_task_type:
            self.default_task_type = 'task_py2' if six.PY2 else 'task_py3'
        else:
            self.default_task_type = 'task'

        # Everything not consumed above is kept and passed on to sessions.
        self.extra_args.update(kwargs)

        level = self.extra_args.get('loglevel', cfg.client.LogLevel)
        debug = self.extra_args['debuglog']
        if log is None:
            # Only touch the level of our own logger, never a caller-supplied one.
            logger.setLevel(level)
        reconfigure_log('client', logger=logger, level=level, debug=debug, rename_levels=log is None)
        self.log = logging.MessageAdapter(
            logger,
            fmt="[x%(cid)s] %(message)s",
            data={'cid': gencid()},
        )

        self._signer = signer or create_signer(log=self.log.getChild('auth'))
        self.client = Client(ip=self._get_ip(),
                             transport=self._get_transport(),
                             send_own_ip=self.extra_args['send_own_ip'],
                             log=self.log,
                             )

        self._process_extra_args()
        self._initialize_heartbeat_api()

    def shutdown(self):
        """Shut down the underlying ``Client`` once; later calls are no-ops."""
        # we may have encountered exception creating client
        if getattr(self, 'client', None) is not None:
            self.client.shutdown()
            self.client = None

    @property
    def signer(self):
        """Sign manager used to sign tasks and portoshell messages."""
        return self._signer

    def _wrap_str_iterable(self, hosts, params):
        """Wrap a single host string (and its params, if any) into one-element lists.

        :return: ``(hosts, params)``; non-string ``hosts`` are returned unchanged.
        """
        if isinstance(hosts, six.string_types):
            hosts = [hosts]
            if params is not None:
                params = [params]
        return hosts, params

    def _task_kwargs(self):
        """Return a copy of ``extra_args`` adjusted for standalone (Arcadia) binaries.

        When ``IN_ARCADIA`` is true, ``exec_fn`` defaults to ``'arcadia_binary'``
        and ``arcadia_serialize`` is forced on.
        """
        kwargs = self.extra_args.copy()
        if IN_ARCADIA:
            kwargs.setdefault('exec_fn', 'arcadia_binary')
            kwargs['arcadia_serialize'] = True
        return kwargs

    def _start_task_session(self, hosts, remote_object, params, callname):
        """Start a regular task session and schedule its heartbeat report.

        Shared implementation of ``run`` and ``iter``.

        :return: ``(hosts, session)`` where ``hosts`` has been normalized by
                 ``_wrap_str_iterable``.
        """
        hosts, params = self._wrap_str_iterable(hosts, params)
        kwargs = self._task_kwargs()

        session = self.client.start_session(
            hosts,
            Wrapper(remote_object),
            params,
            signer=self.signer,
            task_type=self.default_task_type,
            **kwargs
        )
        self._schedule_heartbeat_report(
            session.taskid, callname, len(hosts),
            object=str(remote_object),
            params_used=params is not None,
            **kwargs
        )
        return hosts, session

    def run(self, hosts, remote_object, params=None):
        """Run ``remote_object`` on ``hosts``.

        :param hosts: a host string or an iterable of hosts
        :param remote_object: object providing ``run`` or ``__call__``
        :param params: optional task parameters passed to ``Client.start_session``;
                       wrapped into a list when ``hosts`` is a single string
        :return: CqueueSession (StopIteration markers are not yielded)
        """
        hosts, session = self._start_task_session(hosts, remote_object, params, 'run')
        return CqueueSession(
            hosts,
            self,
            session,
            remote_object
        )

    def iter(self, hosts, remote_object, params=None):
        """Like ``run``, but the returned session also yields StopIteration markers.

        :return: CqueueSession created with ``is_iter=True``
        """
        hosts, session = self._start_task_session(hosts, remote_object, params, 'iter')
        return CqueueSession(
            hosts,
            self,
            session,
            remote_object,
            is_iter=True,
        )

    def ping(self, hosts, *args, **kwargs):
        """Start a ping session to ``hosts``.

        Extra positional and keyword arguments are accepted for compatibility
        but ignored; ``extra_args`` are used instead.

        :return: PingSession
        """
        hosts, _ = self._wrap_str_iterable(hosts, None)
        # sending heartbeat report for ping is skipped to not flood the heartbeat
        return PingSession(
            hosts,
            self,
            self.client.start_ping_session(hosts, **self.extra_args),
            remote_object=None,
        )

    def createPipe(self):
        """Create a ``Pipe`` backed by a ``ResultsQueue`` using the client's select function."""
        return Pipe(ResultsQueue(self.client.select_function))

    def createPoll(self, iterable=None):
        """Create a ``Poll`` over ``iterable`` using the client's select function."""
        return Poll(iterable=iterable, select_function=self.client.select_function)

    def createQueue(self, maxsize=0):
        """Not supported by this client."""
        raise NotImplementedError

    def set_signer(self, value):
        """Replace the sign manager used for subsequent calls."""
        self._signer = value

    def register_safe_unpickle(self, module_name=None, attr_name=None, obj=None):
        """Allow unpickling of one extra global.

        Either pass both ``module_name`` and ``attr_name``, or pass ``obj``
        (its ``__module__`` and ``__name__`` are used). When both forms are
        given, the explicit names win.

        :raises ValueError: if neither form is fully specified
        """
        if not ((module_name and attr_name) or obj is not None):
            raise ValueError("either `module_name` and `attr_name`, or `obj` must be specified")

        # Dispatch on the same condition that was validated above: a lone
        # `module_name` without `attr_name` must not whitelist `None`.
        if module_name and attr_name:
            self.client.allowed_unpickles[module_name].add(attr_name)
        else:
            self.client.allowed_unpickles[obj.__module__].add(obj.__name__)

    def send_shutdown(self, host):
        """Ask ``host`` to shut down via the underlying client."""
        self.client.send_shutdown(host)

    def run_shell(self, hosts, cmd, user, timeout, **extra_opts):
        """Run shell command ``cmd`` on ``hosts`` through a ``'shell'`` custom session.

        :param hosts: a host string or an iterable of hosts
        :param cmd: command passed as ``exec_args['cmd']``
        :param user: user name; the current user (``getUserName()``) when None
        :param timeout: passed to ``Client.start_custom_session``
        :param extra_opts: merged over ``extra_args``; ``extra_env`` is never
                           included in the heartbeat report
        :return: CqueueSession
        """
        hosts, _ = self._wrap_str_iterable(hosts, None)

        session = self._create_custom_session(hosts, cmd, None, user, timeout, **extra_opts)

        # for security reasons do not report environment
        extra_opts.pop('extra_env', None)
        self._schedule_heartbeat_report(
            session.taskid, 'run_shell', len(hosts),
            object=str(cmd),
            user=user,
            timeout=timeout,
            **extra_opts
        )
        return CqueueSession(
            hosts,
            self,
            session,
            remote_object=None,
        )

    def run_in_porto(self, hosts, remote_object, params=None, is_iter=False):
        """
        Run task in porto container(s) on host

        Unlike ``run``/``iter``, ``hosts`` is not normalized, so a bare string is
        not wrapped into a list.

        :param iterable hosts: pairs (hostname, container_name), where:
                hostname: name of the host, can be string with optional port, or tuple (host, port)
                container_name: string with existing container to run in. Task will be executed in unique subcontainer
        :param callable remote_object: object to execute
        :param iterable params: parameters for each task. Each param should be a dict with two optional keys:
            optional 'porto_params': special params for the porto container. See `portoctl` for list of available ones
            optional 'task_params': arguments to pass to callable on the remote side
            example::

                params=[
                    {
                        'porto_params': {
                            'cpu_priority': 'rt',
                            'memory_guarantee': 256 * 1024 * 1024,
                        },
                    },
                    {
                        'task_params': [1, 2, []]
                    },
                    {},
                    {
                        'porto_params': {
                            'cpu_priority': 'rt',
                        },
                        'task_params': ['param1']
                    }
                ]
        :param bool is_iter: is the callable iterable or not
        :return: PortoSession
        """
        kwargs = self._task_kwargs()

        session = self.client.start_custom_session(
            hosts,
            function=None,
            params=params,
            signer=self.signer,
            runnable=Wrapper(remote_object),
            task_type='multi_task',
            session_class=MultiTaskSession,
            exec_args=kwargs,
            **kwargs
        )
        self._schedule_heartbeat_report(
            session.taskid, 'run_in_porto', len(hosts),
            object=str(remote_object),
            params_used=params is not None,
            is_iter=bool(is_iter),
            **kwargs
        )
        return PortoSession(
            hosts,
            self,
            session,
            remote_object,
            is_iter=is_iter,
        )

    def run_in_portoshell(self, hosts, command, user=None, streaming=True, is_iter=False, extra_env=None, **kwargs):
        """
        Run a shell command through portoshell in slot(s) on the given hosts.

        One signed portoshell message is built per entry of ``hosts``.

        :param iterable hosts: tuples (hostname, slot[, configuration_id]), where:
                hostname: name of the host, can be string with optional port, or tuple (host, port)
                slot: string with slot name to run in. Task will be executed in unique subcontainer
                configuration_id: optional string with configuration id of the slot to use
        :param str command: shell command to execute
        :param str user: user name put into each message (empty string when omitted)
        :param bool streaming: value of the ``streaming`` field of each message
        :param bool is_iter: whether the returned session also yields StopIteration markers
        :param dict extra_env: additional environment variables to set
        :param kwargs: extra options merged over ``extra_args`` into the session's ``exec_args``
        :return: PortoSession
        """
        params = [
            make_message(
                user=user or '',
                slot=item[1],
                configuration_id='' if len(item) < 3 else item[2],
                command=command,
                api_mode=True,
                streaming=streaming,
                extra_env=list(extra_env.items()) if extra_env else None,
            )
            for item in hosts
        ]
        for msg in params:
            sign_message(msg, self.signer)

        exec_args = self.extra_args.copy()
        exec_args.update(kwargs)

        session = self.client.start_custom_session(
            hosts,
            function='portoshell_slow',
            params=params,
            signer=self.signer,
            task_type='task_portoshell_slow',
            runnable=None,
            exec_args=dict(command=command, **exec_args),
            session_class=PortoshellSession,
            # FIXME extra_args are currently duplicated
            **self.extra_args
        )
        # extra_env travels only inside the signed messages, so it is not
        # part of the heartbeat report built from exec_args.
        self._schedule_heartbeat_report(
            session.taskid, 'run_in_portoshell', len(hosts),
            object=str(command),
            params_used=params is not None,
            user=user,
            streaming=bool(streaming),
            is_iter=bool(is_iter),
            **exec_args
        )
        return PortoSession(
            hosts,
            self,
            session,
            None,
            is_iter=is_iter,
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.shutdown()

    def __del__(self):
        self.shutdown()

    def _create_custom_session(self, hosts, cmd, params, user, timeout, session_type='shell', **extra_args):
        """Start a custom session running ``session_type`` with ``exec_args={'cmd': cmd}``.

        ``user`` defaults to the current user; ``extra_args`` are merged over
        ``self.extra_args`` and forwarded to ``Client.start_custom_session``.
        """
        username = user if user is not None else getUserName()

        opts = dict(self.extra_args)
        opts.update(extra_args)
        return self.client.start_custom_session(
            hosts,
            function=session_type,
            params=params,
            signer=self.signer,
            task_type='task',
            username=username,
            timeout=timeout,
            exec_args=dict(cmd=cmd),
            **opts
        )

    def _get_transport(self):
        """Return ``'netlibus'`` when the netlibus flag is set, otherwise ``'msgpack'``."""
        if self.extra_args['netlibus']:
            return 'netlibus'
        return 'msgpack'

    def _get_ip(self):
        """Resolve the optional ``ip`` extra argument to a numeric address.

        :return: None when no ``ip`` was given, otherwise the address of the
                 first ``getaddrinfo`` result.
        :raises CQueueRuntimeError: if the address cannot be resolved
        """
        ip = self.extra_args.get('ip')
        if not ip:
            return None

        try:
            ip = getaddrinfo(ip, 0, 0, socket.SOCK_DGRAM)[0][4][0]
            return ip
        except (socket.gaierror, IndexError):
            # `ip` still holds the user-supplied value here.
            raise CQueueRuntimeError('Provided ip `{}` is not valid'.format(ip))

    def _process_extra_args(self):
        """Apply ``default_port`` and ``select_function`` from ``extra_args`` to the client."""
        default_port = self.extra_args.get('default_port')
        if default_port:
            self.client.set_default_port(int(default_port))

        self.client.select_function = self.extra_args.get('select_function', poll_select)

    def _initialize_heartbeat_api(self):
        """Look up the optional heartbeat API; reporting is disabled if it cannot be imported."""
        self._heartbeat_schedule = None
        try:
            from api.heartbeat.client import schedule_report_async
            self._heartbeat_schedule = schedule_report_async
        except Exception as e:
            self.log.warning("failed to initialize heartbeat api: %s", e)

    def _schedule_heartbeat_report(self, uuid, callname, hosts_count, **report):
        """Best-effort scheduling of a ``cqudp-api-call`` heartbeat report.

        ``report`` is extended with call metadata (uuid, timestamp, fqdn,
        method, remote host count, accounting user). Failures are logged and
        never propagated to the caller.
        """
        if self._heartbeat_schedule is None:
            return

        try:
            report.update(
                uuid=uuid,
                timestamp=time.time(),
                fqdn=fqdn(),
                method=callname,
                remote_hosts=hosts_count,
                accounting_user=getUserName(),
            )

            self._heartbeat_schedule(name='cqudp-api-call', report=report, incremental=True)
        except Exception as e:
            self.log.warning("failed to schedule heartbeat report: %s", e)


# Legacy name of CqueueClient, kept so that old imports keep working.
ICQClientImpl = CqueueClient  # backward compatibility just to be sure


class ICQSessionImpl(Selectable):
    """Base adapter exposing a low-level client session through the public session API.

    Subclasses implement ``_wait(fn, timeout)``, which post-processes the
    results produced by the session's results handle (``wait``/``poll``).
    """

    def __init__(self, hosts, client, session, remote_object):
        super(ICQSessionImpl, self).__init__()

        # `hosts` is accepted for interface compatibility; it is not stored here.
        self.client = client
        self.session = session
        self.handle = session.make_results_handle()
        self.remote_object = remote_object

    # Selectable hooks are delegated straight to the wrapped session.
    def _get_event_fd(self):
        return self.session._get_event_fd()

    def _is_data_ready(self):
        return self.session._is_data_ready()

    @property
    def id(self):
        """Task id of the wrapped session."""
        return self.session.taskid

    @property
    def running(self):
        """True while the wrapped session still reports itself as non-empty."""
        return not self.session.is_empty()

    @property
    def remoteObject(self):
        """The remote object this session was started with (may be None)."""
        return self.remote_object

    def _wait(self, fn, timeout=None):
        raise NotImplementedError

    def wait(self, timeout=None):
        """Iterate results using the handle's blocking ``wait``."""
        return self._wait(self.handle.wait, timeout=timeout)

    def poll(self, timeout=None):
        """Iterate results using the handle's ``poll``."""
        return self._wait(self.handle.poll, timeout=timeout)

    def shutdown(self):
        """Stop the session and report type stats; only the first call has an effect."""
        if self.client is None:
            return
        self.session.stop()
        self.session.report_type_stats()
        # Drop the client reference; this also marks the session as shut down.
        self.client = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.shutdown()


class CqueueSession(ICQSessionImpl):
    """Session returned by ``run``/``iter``/``run_shell``; tracks per-session metrics.

    Results are yielded as ``(addr, result, error)``. StopIteration markers are
    yielded only when ``is_iter`` is true.
    """

    def __init__(self, hosts, client, session, remote_object, is_iter=False):
        super(CqueueSession, self).__init__(hosts, client, session, remote_object)

        self.yield_stopiter = is_iter
        self.metrics = CqueueMetrics(len(session.hosts))

    def _wait(self, fn, timeout=None):
        for addr, _index, res, err in fn(timeout=timeout):
            # Shut down as soon as the underlying session has drained.
            if not self.running:
                self.shutdown()
            self.metrics.update(err, res)

            if not isinstance(err, StopIteration):
                # The first element of a non-None result is the payload; the rest is metric data.
                yield addr, (None if res is None else res[0]), err
            elif self.yield_stopiter:
                yield addr, res, err


class PingSession(ICQSessionImpl):
    """Session returned by ``ping``; yields ``(addr, result, error)`` and skips StopIteration markers."""

    def _wait(self, fn, timeout=None):
        for addr, _index, res, err in fn(timeout=timeout):
            if isinstance(err, StopIteration):
                continue
            yield addr, res, err


class PortoSession(ICQSessionImpl):
    """Session returned by ``run_in_porto``/``run_in_portoshell``.

    Like ``CqueueSession`` it updates metrics and strips metric data from
    results, but it does not shut itself down when the session drains.
    """

    def __init__(self, hosts, client, session, remote_object, is_iter=False):
        super(PortoSession, self).__init__(hosts, client, session, remote_object)
        self.yield_stopiter = is_iter
        self.metrics = CqueueMetrics(len(session.hosts))

    def _wait(self, fn, timeout=None):
        for addr, _index, res, err in fn(timeout=timeout):
            self.metrics.update(err, res)

            if not isinstance(err, StopIteration):
                # The first element of a non-None result is the payload; the rest is metric data.
                yield addr, (None if res is None else res[0]), err
            elif self.yield_stopiter:
                yield addr, res, err


class Wrapper(object):
    """Callable proxy around a remote object.

    Calling the wrapper invokes the object's ``run`` method if present, and
    otherwise its ``__call__``. Other attribute lookups are forwarded to the
    wrapped object.
    """

    def __init__(self, obj):
        self.obj = obj
        self.check()

    def __str__(self):
        return "{}[{}]".format(type(self).__name__, self.obj)

    def __repr__(self):
        return "{}({!r})".format(type(self).__name__, self.obj)

    def __getattr__(self, name):
        # TODO take default values from remote object API

        # `obj` itself missing means the instance has not been initialized yet
        # (e.g. during unpickling). Fail instead of recursing forever.
        if name != 'obj':
            return getattr(self.obj, name)
        raise AttributeError(name)

    def __call__(self, *args, **kwargs):
        target = getattr(self.obj, 'run', None)
        if not target:
            target = getattr(self.obj, '__call__', None)
        return target(*args, **kwargs)

    def check(self):
        """Raise CQueueRuntimeError unless the object has ``__call__`` or ``run``."""
        entry_point = getattr(self.obj, '__call__', None) or getattr(self.obj, 'run', None)
        if entry_point:
            return
        raise CQueueRuntimeError('remote object has no run or __call__ method')

# Legacy name; `Wrapper` is now used for both `run` and `iter` tasks.
IterWrapper = Wrapper  # Backward compatibility.


def create_signer(log=None):
    """Build the default sign manager.

    Chains a ``FileKeysSignManager``, configured from ``cfg.client.Auth`` key
    directories and files, with a ``SignAgentClient``. ``load()`` is called on
    both before the chain is returned.

    :param log: optional logger handed to both sign managers
    :return: ChainSignManager
    """
    auth_cfg = cfg.client.Auth
    file_signer = FileKeysSignManager(
        commonKeyDirs=auth_cfg.CommonKeyDirs,
        userKeyDirs=auth_cfg.UserKeyDirs,
        keyFiles=auth_cfg.KeyFiles,
        log=log,
    )
    file_signer.load()

    agent_signer = SignAgentClient(log=log)
    agent_signer.load()

    return ChainSignManager([file_signer, agent_signer])
