"""Base for auth managers"""

import os
import re
import six
import errno

from collections import defaultdict
from itertools import chain

from kernel.util.sys.user import getUserName, getUserHome, getUserUID, userPrivileges

from . import log as root
from .key import AuthKey, CryptoKey, RSAKey, DSSKey, ECDSAKey


# Installation root lookup: use the server manager API when it is importable,
# otherwise install a stub that always raises, so callers can treat the root
# as "not available".
try:
    from api.srvmngr import getRoot as get_install_root
except ImportError:
    def get_install_root():
        """
        Fallback used when ``api.srvmngr`` cannot be imported.

        :raises RuntimeError: always, the installation path cannot be detected
        """
        raise RuntimeError("Cannot detect installation path")

getInstallRoot = get_install_root  # backward compatibility (camelCase alias)


class KeyFactory(object):
    """
    Static helpers to generate new keys and to deserialize existing ones.
    """

    @staticmethod
    def generateKey(bits=2048, keyType='rsa'):
        """
        Generate a fresh key.

        :param bits: key size forwarded to the key class ``generate``
        :param keyType: one of 'rsa', 'dsa'/'dss' or 'ecdsa' (the latter only
            when ``ECDSAKey`` is available)
        :raises ValueError: for an unsupported key type
        """
        if keyType == 'rsa':
            keyClass = RSAKey
        elif keyType in ('dsa', 'dss'):
            keyClass = DSSKey
        elif keyType == 'ecdsa' and ECDSAKey is not None:
            keyClass = ECDSAKey
        else:
            raise ValueError('Unknown key type: {0}'.format(keyType))
        return keyClass.generate(bits)

    @staticmethod
    def loadsKeys(data, comment='', log=None):
        """
        Parse keys from in-memory data via ``CryptoKey.loads``.

        :param data: serialized key material
        :param comment: comment forwarded to ``CryptoKey.loads``
        :param log: logger forwarded to ``CryptoKey.loads``
        """
        parsed = CryptoKey.loads(data, comment, log=log)
        return parsed

    @staticmethod
    def loadKeys(f, comment='', log=None):
        """
        Parse keys from ``f`` via ``CryptoKey.load``.

        :param f: source forwarded to ``CryptoKey.load``
        :param comment: comment forwarded to ``CryptoKey.load``
        :param log: logger forwarded to ``CryptoKey.load``
        """
        parsed = CryptoKey.load(f, comment, log=log)
        return parsed


class BaseAuthManager(KeyFactory):
    """
    Base for signer and verifier.

    Keeps an in-memory registry of keys indexed by fingerprint, plus a reverse
    index from user name to the fingerprints of the keys attributed to that user.
    """

    def __str__(self):
        return self.__class__.__name__

    def __init__(self, caStorage=None, log=None):
        """
        :param caStorage: stored as ``self._caStorage``; not used by this class
        :param log: logger; defaults to the ``base`` child of the package logger
        """
        # fingerprint -> key object
        self._keys = {}
        # user name -> set of fingerprints of that user's keys
        self._userKeys = defaultdict(set)
        self._caStorage = caStorage
        self.log = log or root.getChild('base')

    hash = AuthKey.hash

    def add_ca(self, ca):
        # FIXME (torkve) remove it: temporary backward compat noop method
        pass

    def fingerprints(self):
        """
        :returns: set with the fingerprints (key ids) of all known keys
        """
        return set(six.iterkeys(self._keys))

    def __iter__(self):
        """
        Iterate over all loaded keys
        """
        return six.itervalues(self._keys)

    def hexFingerprints(self):
        """
        :returns: set with the hex-encoded fingerprints of all known keys
        """
        return set(key.hexFingerprint() for key in six.itervalues(self._keys))

    def load(self, username=None):
        """
        Prepare for (re)loading keys; descendants extend this to actually load them.

        When ``username`` is None every known key is dropped. When a user name
        is given, already known keys are kept.

        :type username: str or None
        """
        if username is None:
            self._keys.clear()
            self._userKeys.clear()

    def filterKey(self, key):
        """
        Decide whether a loaded key should be registered.

        This base implementation rejects every key; override it in descendants.

        :type key: services.cqueue.src.common.cert.CryptoKey
        :returns: True to accept the key
        """
        return False

    def addKey(self, key):
        """
        Register ``key``.

        If a key with the same fingerprint is already known, that existing
        object is kept: its comment is replaced and its user names are extended.

        :type key: services.cqueue.src.auth.key.AuthKey
        """
        fingerprint = key.fingerprint()

        for userName in key.userNames:
            self._userKeys[userName].add(fingerprint)

        prevKey = self._keys.get(fingerprint)

        # Explicit None check: a key object must never count as "missing"
        # merely because it happens to be falsy.
        if prevKey is not None:
            prevKey.comment = key.comment
            prevKey.userNames.update(key.userNames)
        else:
            self._keys[fingerprint] = key
            self.log.info('{0} add {1}'.format(self, key))

    def removeKey(self, item):
        """
        Forget a key; unknown keys are ignored.

        :param item: key or key fingerprint
        """
        if isinstance(item, AuthKey):
            item = item.fingerprint()
        key = self._keys.pop(item, None)
        if key is None:
            return

        for userName in key.userNames:
            # .get() instead of [] so the defaultdict does not grow new entries here.
            userFingerprints = self._userKeys.get(userName)
            if userFingerprints is None:
                continue
            userFingerprints.discard(item)
            # Drop empty buckets so the reverse index does not accumulate
            # users that no longer own any key.
            if not userFingerprints:
                del self._userKeys[userName]

        self.log.info('{0} remove {1}'.format(self, key))


class ChainAuthManager(KeyFactory):
    """
    Composite manager presenting several auth managers as one.

    Behavior by operation:

    - ``fingerprints``, ``hexFingerprints`` and iteration aggregate over every
      chained manager, in order.
    - ``load`` is forwarded to all managers.
    - ``addKey`` and ``removeKey`` only touch the first manager of the chain.
    """

    def __init__(self, managers):
        """
        :param managers: iterable of managers; copied into ``self.managers`` (a list)
        """
        super(ChainAuthManager, self).__init__()
        self.managers = list(managers)

    hash = BaseAuthManager.hash

    def fingerprints(self):
        """
        :returns: union of the fingerprints of all chained managers
        """
        return {fingerprint
                for member in self.managers
                for fingerprint in member.fingerprints()}

    def hexFingerprints(self):
        """
        :returns: union of the hex fingerprints of all chained managers
        """
        return {hexFingerprint
                for member in self.managers
                for hexFingerprint in member.hexFingerprints()}

    def __iter__(self):
        """
        Iterate over the keys of every chained manager, manager by manager.
        """
        for member in self.managers:
            for key in member:
                yield key

    def load(self, username=None):
        """
        Forward ``load`` to each chained manager in order.
        """
        for member in self.managers:
            member.load(username)

    def filterKey(self, key):
        """
        :returns: True only if every chained manager accepts ``key``
            (True as well for an empty chain)
        """
        return all(member.filterKey(key) for member in self.managers)

    def addKey(self, key):
        """
        Delegate to the first manager of the chain.
        """
        head = self.managers[0]
        return head.addKey(key)

    def removeKey(self, item):
        """
        Delegate to the first manager of the chain.
        """
        head = self.managers[0]
        return head.removeKey(item)


class FileKeysAuthManager(BaseAuthManager):
    """
    Auth manager that loads keys from files found in configured directories.

    Directory entries are ``str.format`` templates. The available fields are
    ``userHome``, ``userName``, ``environ`` (``os.environ``) and, when it can be
    detected, ``installRoot``.

    Inside each directory, only files whose names match one of the ``keyFiles``
    regular expressions are read. Matching uses ``re.match``, so patterns are
    anchored at the start of the name.
    """

    def __init__(self, commonKeyDirs=None, userKeyDirs=None, keyFiles=None, caStorage=None, log=None):
        """
        :param commonKeyDirs: directory templates scanned only when loading without
            an explicit user name; their files are not owner-checked
        :param userKeyDirs: directory templates always scanned; their files must be
            owned by root or by the user the keys are loaded for
        :param keyFiles: regular expressions selecting key file names
        :param caStorage: forwarded to ``BaseAuthManager``
        :param log: forwarded to ``BaseAuthManager``
        """
        super(FileKeysAuthManager, self).__init__(caStorage=caStorage, log=log)
        self.commonKeyDirs = commonKeyDirs if commonKeyDirs else []
        self.userKeyDirs = userKeyDirs if userKeyDirs else []
        self.keyFiles = keyFiles if keyFiles else []

    def load(self, userName=None):
        """
        Load keys from the configured directories.

        :param userName: the two modes are:

            - None: the base class drops all known keys first. Both user and
              common directories are then scanned for the user returned by
              ``getUserName()``.
            - A user name: only the user directories are scanned, and the keys
              are added on top of those already known. Permission errors may be
              retried under ``userPrivileges()``.
        """
        super(FileKeysAuthManager, self).load(userName)
        keysDirs = [((i, True) for i in self.userKeyDirs)]
        if userName is None:
            userName = getUserName()
            userHome = getUserHome()
            keysDirs.append(((i, False) for i in self.commonKeyDirs))
            allowPrivilegesRaise = False
        else:
            userHome = getUserHome(userName)
            allowPrivilegesRaise = True

        formatArgs = dict(
            userHome=userHome,
            userName=userName,
            environ=os.environ
        )

        # installRoot is optional: when it cannot be detected, templates that
        # reference it fail to format and are skipped below.
        # noinspection PyBroadException
        try:
            formatArgs['installRoot'] = get_install_root()
        except Exception:
            pass

        for keysDir, permCheck in chain.from_iterable(keysDirs):
            try:
                keysDir = keysDir.format(**formatArgs)
            except (KeyError, IndexError, AttributeError, ValueError) as err:
                # Unknown field, positional field, bad attribute/index lookup or
                # malformed braces: skip this directory instead of aborting the
                # whole load, but leave a trace of why it was ignored.
                self.log.debug('Key directory template `{0}` skipped: {1!r}'.format(keysDir, err))
                continue
            self.__loadKeysFromDir(keysDir, userName, permCheck, allowPrivilegesRaise)

    def _checkPermissions(self, filePath, uid, st):
        """
        Validate the ownership of a key file.

        Only the owner is checked (it must be root or ``uid``); mode bits are
        not inspected.

        :param filePath: path, used in the warning message
        :param uid: uid of the user the keys are loaded for
        :param st: ``os.stat`` result for ``filePath``
        :returns: True if the owner is acceptable
        """
        if st.st_uid not in (0, uid):
            self.log.warning('Key `{}` owner is not valid: {} (0 or {} expected)'.format(filePath, st.st_uid, uid))
            return False

        return True

    def __loadKeysFromDir(self, keysDir, username, permCheck=False, allowPrivilegesRaise=False):
        """
        Load every matching key file from ``keysDir``.

        Keys accepted by ``filterKey`` are attributed to ``username``. An
        unreadable directory is skipped silently; unreadable or badly owned
        files are skipped with a warning.

        :param permCheck: when True, each file's owner is validated with ``_checkPermissions``
        :param allowPrivilegesRaise: when True, operations failing with EACCES/EPERM
            are retried inside ``userPrivileges()``
        """
        uid = getUserUID(username)

        def privilegedCall(func, *args, **kwargs):
            # Try with the current privileges first; escalate only on
            # permission errors and only when the caller allowed it.
            try:
                return func(*args, **kwargs)
            except EnvironmentError as err:
                if err.errno in (errno.EACCES, errno.EPERM) and allowPrivilegesRaise:
                    with userPrivileges():  # As root
                        return func(*args, **kwargs)
                else:
                    raise

        try:
            filesList = privilegedCall(os.listdir, keysDir)
        except EnvironmentError:
            # Missing or unreadable directory: nothing to load from it.
            return

        for fileName in filesList:
            if not any(re.match(pattern, fileName) for pattern in self.keyFiles):
                continue

            filePath = os.path.join(keysDir, fileName)
            try:
                if permCheck:
                    if not self._checkPermissions(filePath, uid, privilegedCall(os.stat, filePath)):
                        # Reported by the EnvironmentError handler below.
                        raise OSError(errno.EPERM, 'Invalid key permissions')

                data = privilegedCall(readFile, filePath)

                for key in CryptoKey.loads(data, filePath, log=self.log):
                    if self.filterKey(key):
                        key.userNames.add(username)
                        self.addKey(key)

            except EnvironmentError as err:
                self.log.warning('Key `{0}` was not loaded: {1}'.format(filePath, err))


def readFile(fileName):
    """
    Return the whole content of ``fileName``, opened in the default (text) mode.

    The file handle is always closed, even if reading fails.
    """
    handle = open(fileName)
    try:
        return handle.read()
    finally:
        handle.close()
