from __future__ import absolute_import

import six
import os
import threading

from .auth import create_auth, create_empty_auth, AuthSyslog
from ..utils import short, sleep, log as root
from ..exceptions import (
    CQueueError,
    CQueueNetworkError,
    CQueueRuntimeError,
    CQueueAuthenticationError,
    CQueueAuthorizationError,
    ReadableKeyError,
)
from ya.skynet.library.auth.key import Certificate
from ya.skynet.util.functional import Cache
from ya.skynet.util.sys.user import getUserName, UserPrivileges
from ya.skynet.util.errors import saveTraceback, setTraceback, getTraceback

try:
    from ya.skynet.services.portoshell.slots.exceptions import AuthError as PortoshellAuthError
except ImportError:
    PortoshellAuthError = Exception


class Manager(object):
    """Authenticate and authorize incoming cqudp messages.

    Key/certificate verification is delegated to an authenticator
    (``self.auth``); every decision is also written to an audit syslog
    channel, and successful PAM account checks are cached per (user, host).
    """

    def __init__(
        self,
        privileges_lock=None,
        auth=None,
        log=None,
        keys_storage=None,
        ca_storage=None,
    ):
        """
        :param privileges_lock: lock serializing privileged operations (key
            verification and PAM checks). When omitted a private
            ``threading.RLock`` is created and owned by this manager
            (see :meth:`shutdown`).
        :param auth: authenticator; built with ``create_auth`` when omitted.
        :param log: logger; defaults to the ``manager`` child of the root log.
        :param keys_storage: forwarded to ``slot_info.get_auth_keys`` in
            :meth:`mtn_verify_message`.
        :param ca_storage: certificate-authority storage forwarded to the
            authenticators created by this manager.
        """
        # Successful PAM checks are remembered per (user, host); cachePeriod is
        # presumably in seconds (i.e. one hour) -- confirm against Cache.
        self.pam_cache = Cache(cachePeriod=60 * 60)

        # Verify attribute is used only for testing
        self.verify = True
        self.user_lock = privileges_lock or threading.RLock()
        self.lock_owned = privileges_lock is None
        self.log = log or root().getChild('manager')
        self.keys_storage = keys_storage
        self.ca_storage = ca_storage
        self.auth = auth or create_auth(ca_storage=self.ca_storage, log=self.log)
        self.auth_syslog = AuthSyslog('skynet.cqudp')

        self.auth_syslog.open()

    def _fixup_sign(self, signature):
        """Normalize a legacy certificate-signed signature tuple.

        NOTE: for backward compatibility reasons certificate-signed messages come
        in the form ((fingerprint, certificate), signature) which we need to
        translate into (fingerprint, signature, certificate).

        Anything else, including a legacy-looking tuple whose payload does not
        load as a :class:`Certificate`, is returned unchanged.
        """
        if isinstance(signature[0], tuple) and len(signature[0]) == 2:
            try:
                cert = next(self.auth.loadsKeys(signature[0][1], log=self.auth.log))
                if isinstance(cert, Certificate):
                    return signature[0][0], signature[1], cert
            except Exception:
                # Deliberate best effort: an unparsable payload just leaves the
                # signature as-is for the authenticator to judge.
                pass
        return signature

    def verify_message(self, msg, user, signs, hash, iface):
        """Authenticate ``msg`` as local ``user`` and, when needed, check PAM.

        :param msg: message mapping; ``'uuid'`` and ``'acc_host'`` are read.
        :param user: local account the sender wants to act as.
        :param signs: signatures over ``hash`` (legacy certificate forms are
            normalized with :meth:`_fixup_sign`).
        :param hash: digest the signatures are checked against.
        :param iface: not used by this method.
        :raises CQueueAuthenticationError: no signature verified, even after
            reloading the keys once.
        :raises CQueueAuthorizationError: the PAM account check failed.
        """
        if self.auth is None:
            self.log.info('[%s] authentication disabled', short(msg['uuid']))
            return

        signs = [self._fixup_sign(signature) for signature in signs]

        # A failed verification triggers one key reload and a retry.
        attempts = 2
        res = None
        for attempt in six.moves.xrange(attempts):
            with self.user_lock:
                res = self.auth.verify(hash, signs, user)

            if res:
                break

            last_attempt = attempt == attempts - 1
            self.auth_syslog.info("Authentication as %s for %s from %s has failed",
                                  user,
                                  short(msg['uuid']),
                                  msg['acc_host']
                                  )
            self.log.info("[%s] authentication as %s@%s failed%s",
                          short(msg['uuid']),
                          user,
                          msg['acc_host'],
                          "" if last_attempt else ", will reload keys and retry",
                          )
            if not last_attempt:
                self.auth.load()

        if not res:
            raise CQueueAuthenticationError('invalid credentials', user=user)

        self.auth_syslog.info("Authenticated as %s for %s from %s with %s",
                              user,
                              short(msg['uuid']),
                              msg['acc_host'],
                              res
                              )

        # PAM is consulted only on Linux and only when switching to another user.
        if user and user != getUserName() and os.uname()[0].lower() == 'linux':
            # checking PAM
            self._pam_check(user, msg['acc_host'], msg['uuid'])

        self.log.info("[%s] authenticated as %s@%s: %s", short(msg['uuid']), user, msg['acc_host'], res)

    def _pam_check(self, user, host, uuid, attempts=3):
        """Authorize ``user`` connecting from ``host`` via a PAM account check.

        Successes are cached per (user, host). Failures are retried up to
        ``attempts`` times (at least once) with a short pause in between.

        :raises CQueueAuthorizationError: every attempt failed.
        """
        if self.pam_cache.get((user, host)):
            return

        # Never skip the check: with attempts <= 0 the loop would not run and
        # the account would be accepted without consulting PAM.
        attempts = max(1, attempts)
        for i in six.moves.xrange(attempts):
            try:
                with self.user_lock:
                    from ya.skynet.util.sys import pam
                    with UserPrivileges(modifyGreenlet=False):
                        pam.check_account(user, host=host)
            except Exception as e:
                # Exception, not BaseException: KeyboardInterrupt, SystemExit and
                # greenlet termination (GreenletExit) must propagate instead of
                # being retried and reported as an authorization failure.
                if i == attempts - 1:
                    self.log.info(
                        "[%s] PAM authorization as %s@%s failed", short(uuid), user, host
                    )
                    raise CQueueAuthorizationError(str(e), user=user)
                sleep(0.05)
            else:
                self.pam_cache[(user, host)] = True
                return

    def shutdown(self):
        """Close the audit syslog and, if this manager owns the lock, take it.

        The owned lock is acquired here and never released, so other threads
        waiting on it for privileged work block from now on.
        NOTE(review): presumably intentional (no verification after shutdown).
        """
        self.auth_syslog.close()
        if self.lock_owned:
            self.user_lock.acquire(True)

    @staticmethod
    def prepare_exception(ex, message):
        """Wrap ``ex`` via ``_wrap_exception``, keeping its saved traceback."""
        saveTraceback(ex)
        new_exc = _wrap_exception(ex, message=message)
        setTraceback(new_exc, getTraceback(ex))
        return new_exc

    def mtn_verify_message(self, msg, user, signs, hash, slot_info):
        """Authenticate ``msg`` as ``user`` inside the slot ``slot_info``.

        :param msg: message mapping; ``'uuid'``, ``'acc_user'`` and
            ``'acc_host'`` are read.
        :param user: account inside the slot to authenticate as.
        :param signs: signatures over ``hash``.
        :param hash: digest the signatures are checked against.
        :param slot_info: slot descriptor, or ``None`` when the slot is unknown.
        :returns: the verifier's (truthy) result.
        :raises CQueueAuthenticationError: the slot is unknown or refused to
            provide keys for ``user``.
        :raises CQueueAuthorizationError: no signature matched the slot keys.
        """
        if slot_info is None:
            self.log.info(
                "[%s] authentication into unknown slot as %s failed for %s@%s",
                short(msg['uuid']), user, msg['acc_user'], msg['acc_host'],
            )
            self.auth_syslog.info(
                "Authentication into unknown slot as %s failed for %s@%s",
                user,
                msg['acc_user'], msg['acc_host'],
            )
            raise CQueueAuthenticationError("Do not know how to authenticate in this slot", user=user)

        signs = [self._fixup_sign(signature) for signature in signs]

        try:
            user_keys = slot_info.get_auth_keys(user, keys_storage=self.keys_storage)
        except PortoshellAuthError as e:
            self.log.info(
                "[%s] authentication as %s@%s failed for %s@%s",
                short(msg['uuid']),
                user, slot_info.identifier(),
                msg['acc_user'], msg['acc_host'],
            )
            self.auth_syslog.info(
                "authentication as %s@%s failed for %s@%s",
                user, slot_info.identifier(),
                msg['acc_user'], msg['acc_host'],
            )
            raise CQueueAuthenticationError(str(e), user=user)

        # A fresh authenticator, loaded for ``user`` and extended with the
        # slot-provided keys, is used instead of the shared ``self.auth``.
        vm = create_empty_auth(ca_storage=self.ca_storage, log=self.log)
        vm.load(username=user)
        for key in user_keys:
            # NOTE(review): mutates key objects returned by the slot; if they
            # are shared or cached, user names accumulate across calls.
            key.userNames.add(user)
            vm.addKey(key)

        res = vm.verify(hash, signs, user)
        if not res:
            self.log.info(
                "[%s] Authentication into %s@%s failed for %s@%s",
                short(msg['uuid']),
                user, slot_info.identifier(),
                msg['acc_user'], msg['acc_host'],
            )
            self.auth_syslog.info(
                "Authentication into %s@%s failed for %s@%s",
                user, slot_info.identifier(),
                msg['acc_user'], msg['acc_host'],
            )
            raise CQueueAuthorizationError(
                "Authorization as %s@%s failed" % (user, slot_info.identifier()),
                user=user,
            )

        self.auth_syslog.info(
            "Authenticated %s@%s as %r by %s into %s",
            msg['acc_user'],
            msg['acc_host'],
            user,
            res,
            slot_info.as_auth_info(),
        )
        self.log.info(
            "[%s] Authenticated %s@%s as %r by %s into %s",
            short(msg['uuid']),
            msg['acc_user'],
            msg['acc_host'],
            user,
            res,
            slot_info.as_auth_info(),
        )

        return res


def _wrap_exception(e, message=None):
    """Convert an arbitrary exception into the matching CQueue exception.

    :param e: exception to convert; ``CQueueError`` instances are returned
        unchanged.
    :param message: message override, used only for environment (OS/network)
        errors; defaults to the error's ``strerror``, or to its string form
        when ``strerror`` is unset.
    :returns: ``CQueueNetworkError`` for environment errors (carrying over
        ``filename`` when set), ``CQueueRuntimeError`` for runtime and other
        errors, ``ReadableKeyError`` for ``KeyError``.
    """
    if isinstance(e, CQueueError):
        return e

    if isinstance(e, EnvironmentError):
        if message is None:
            # strerror is None when the error was built from a plain message
            # (e.g. OSError('boom')); fall back to its text.
            message = e.strerror if e.strerror is not None else str(e)
        new_err = CQueueNetworkError(*e.args, message=message)
        if e.filename is not None:
            new_err.filename = e.filename
        return new_err

    if isinstance(e, RuntimeError):
        # Python 3 exceptions have no ``message`` attribute; falling back to
        # ``e.args`` would make a tuple the error text, so use str(e) instead.
        return CQueueRuntimeError(getattr(e, 'message', None) or str(e))

    if isinstance(e, KeyError):
        return ReadableKeyError(e)

    return CQueueRuntimeError(str(e))
