import os
import base64
import hashlib
import binascii

import six

from kernel.util.errors import sAssert
from kernel.util.streams.streambase import StreamBase
from kernel.util.functional import memoized

from . import log as root


try:
    import paramiko
except ImportError:
    paramiko = None


log = root.getChild('key')


# Byte-string helpers with the same contract on Python 2 and Python 3.
if not six.PY2:
    bytesio = six.BytesIO

    def b(s):
        """Encode text as latin-1 bytes; return any other value unchanged."""
        if not isinstance(s, str):
            return s
        return s.encode('latin-1')

    def tobytes(s):
        """Convert text (latin-1) or any ``bytes()``-compatible object to bytes."""
        if not isinstance(s, str):
            return bytes(s)
        return bytes(s, 'latin-1')

else:
    bytesio = six.moves.cStringIO

    def b(s):
        """On Python 2 ``str`` already is a byte string: return *s* as is."""
        return s

    def tobytes(s):
        """Encode unicode as latin-1; join any other iterable of chars into a str."""
        if not isinstance(s, six.text_type):
            return ''.join(s)
        return s.encode('latin-1')


class AuthKey(object):
    """
    Base class for authentication keys.

    Subclasses supply the cryptography (signing, verification, export,
    network representation). This class keeps the associated user names,
    comment and options, and computes fingerprints over
    ``publicKey().networkRepresentation()``.
    """

    def __init__(self):
        # user names associated with this key (listed by __str__)
        self.userNames = set()
        self._comment = ''
        # key options; empty by default
        self.options = {}

    @property
    def comment(self):
        return self._comment

    @comment.setter
    def comment(self, val):
        # On Python 3 a bytes comment is decoded as UTF-8 so it is stored as text.
        if six.PY3 and isinstance(val, six.binary_type):
            self._comment = val.decode('utf-8')
        else:
            self._comment = val

    def __str__(self):
        description = [self.hexFingerprint()]
        if len(self.userNames) == 1:
            description.append(next(iter(self.userNames)))
        elif self.userNames:
            # sorted so the output does not depend on set iteration order
            description.append('(' + ', '.join(sorted(self.userNames)) + ')')

        if self.comment:
            description.append(self.comment)
        return '{0} [{1}]'.format(self.__class__.__name__, ', '.join(description))

    def sign(self, hash):
        raise NotImplementedError

    def verify(self, hash, sign):
        raise NotImplementedError

    def has_private(self):
        raise NotImplementedError

    def hasPrivate(self):
        """
        camelCase alias of has_private().

        Delegates at call time, so a subclass that only overrides
        has_private() gets a working hasPrivate() as well. A plain class-level
        alias would stay bound to this base implementation.
        """
        return self.has_private()

    def publicKey(self):
        raise NotImplementedError

    def exportKey(self, keyformat='PEM'):
        raise NotImplementedError

    def public(self):
        return self.publicKey()

    def export(self):
        return self.exportKey()

    def networkRepresentation(self):
        raise NotImplementedError

    def _publicKeyHash(self, hash_function):
        """
        Hash the public key's network representation.

        :param hash_function: callable, hashlib function name, or None for self.hash
        :return: hash object (supports .digest() / .hexdigest())
        """
        if hash_function is None:
            hash_function = self.hash
        elif isinstance(hash_function, six.string_types):
            hash_function = getattr(hashlib, hash_function)
        return hash_function(self.publicKey().networkRepresentation())

    def fingerprint(self, hash_function=None):
        """
        :param hash_function: callable or hashlib func name
        :return: ssh compatible public key fingerprint (raw digest bytes)
        """
        return self._publicKeyHash(hash_function).digest()

    def hexFingerprint(self, hash_function=None):
        """
        :param hash_function: callable or hashlib func name
        :return: ssh compatible public key fingerprint (hex string)
        """
        return self._publicKeyHash(hash_function).hexdigest()

    def size(self):
        raise NotImplementedError

    @property
    def type(self):
        raise NotImplementedError

    def description(self):
        """Return the comment, or '' when it is None."""
        return self.comment if self.comment is not None else ''

    def correctedHash(self, hash):
        """Hook for subclasses to transform data before signing/verifying; identity here."""
        return hash

    # default hash used for fingerprints (see hash())
    _hashModule = hashlib.md5

    @classmethod
    def hash(cls, data=b''):
        """
        :param data: starting hash data
        :return: hash function with .update(data) and .digest() methods
        """
        return cls._hashModule(data)

    @classmethod
    def generate(cls, bits=2048):
        """
        :param int bits: key bits
        :return: new generated key
        """
        raise NotImplementedError

    @classmethod
    def load(cls, f, log=None):
        """
        :param f: file or str
        :param logging.Logger log: log to use
        :rtype: AuthKey
        """
        raise NotImplementedError

    @classmethod
    def fromNetworkRepresentation(cls, data, log=None):
        raise NotImplementedError


class PKey(AuthKey):
    """compatibility with paramiko.PKey"""

    def asbytes(self):
        """paramiko ``asbytes``: return networkRepresentation()."""
        return self.networkRepresentation()

    def can_sign(self):
        return self.has_private()

    def get_base64(self):
        """Return the base64 of the public key's network representation (text on Python 3)."""
        data = base64.b64encode(self.publicKey().networkRepresentation())
        if six.PY3:
            data = data.decode()
        return data

    def get_bits(self):
        return self.size()

    def get_fingerprint(self):
        return self.fingerprint()

    def get_name(self):
        # self.type is bytes on Python 3; paramiko expects text here
        return self.type if six.PY2 else self.type.decode()

    def sign_ssh_data(self, data):
        """
        Sign *data* and pack the result into a paramiko.Message.

        Relies on encodeSign(), which PKey itself does not define; subclasses
        must provide it.

        :return: paramiko.Message with the key type string followed by the
            encoded signature
        :raises RuntimeError: if paramiko is not installed
        """
        if paramiko is None:
            raise RuntimeError("paramiko is not available")
        digest = self.correctedHash(data)
        signature = self.encodeSign(self.sign(digest))
        message = paramiko.Message()
        message.add_string(self.type)
        # written raw: encodeSign() is expected to already produce the wire encoding
        message.packet.write(signature)
        return message

    def verify_ssh_sig(self, data, msg):
        """
        Verify the signature carried by paramiko message *msg* over *data*.

        Relies on decodeSign(), which subclasses must provide.

        :return: False if the message's key type differs from ours, otherwise
            the result of self.verify()
        """
        key_type = msg.get_string()
        if key_type != self.type:
            return False

        digest = self.correctedHash(data)
        sign = self.decodeSign(msg.get_remainder())
        return self.verify(digest, sign)

    def write_private_key(self, file_obj, password=None):
        """
        Write exportKey() output to *file_obj*.

        :raises NotImplementedError: if *password* is given
        """
        if password:
            raise NotImplementedError("password encryption is not supported")
        file_obj.write(self.exportKey())

    def write_private_key_file(self, filename, password=None):
        """
        Write the private key to *filename* with permissions 0600.

        :raises NotImplementedError: if *password* is given. This is checked
            before the file is touched, so an existing key file is not
            truncated.
        """
        if password:
            raise NotImplementedError("password encryption is not supported")
        # Pass the mode to os.open so a newly created file never exists with
        # looser umask-derived permissions. open()'s third positional argument
        # is *buffering*, not a mode.
        fd = os.open(filename, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'w') as f:
            # os.open's mode only applies on creation; also tighten an existing file
            os.chmod(filename, 0o600)
            self.write_private_key(f, password=password)


class CryptoKeyMetaclass(type):
    """
    Register every CryptoKey subclass in ``CryptoKey._KeyTypes_``.

    Registration covers the ``keyType`` and ``keyTypes`` the subclass defines
    in its own class body. A key type that is already claimed by a different
    class fails the sAssert.
    """

    def __new__(mcs, name, bases, clsDict):
        newType = type.__new__(mcs, name, bases, clsDict)
        if name == 'CryptoKey':
            # The registry lives on CryptoKey itself, which is not bound yet.
            return newType

        # keyTypes may be omitted, None, a list or a tuple
        keyTypes = [clsDict.get('keyType', None)]
        keyTypes.extend(clsDict.get('keyTypes') or ())

        for keyType in keyTypes:
            if keyType is None:
                continue
            prevType = CryptoKey._KeyTypes_.get(keyType)
            if prevType is newType:
                # listed both as keyType and in keyTypes: not a collision
                continue
            sAssert(prevType is None, 'keyType collision in {0} and {1}'.format(name, str(prevType)))
            CryptoKey._KeyTypes_[keyType] = newType

        return newType


@six.add_metaclass(CryptoKeyMetaclass)
class CryptoKey(PKey):
    """
    Wrapper around Crypto.PublicKey API.

    Concrete subclasses set ``cryptoKeyClass`` and ``keyType`` (optionally
    ``keyTypes``) and implement the network (de)serialization hooks.
    CryptoKeyMetaclass registers each subclass in ``_KeyTypes_`` so that
    ``CryptoKey.loads`` and ``CryptoKey.fromNetworkRepresentation`` can
    dispatch on the key type.
    """

    # keyType -> CryptoKey subclass; filled in by CryptoKeyMetaclass
    _KeyTypes_ = {}

    # Should be defined in subclasses
    cryptoKeyClass = None
    keyType = None
    keyTypes = None
    defaultSize = None
    # tokens that mark where a public key starts on an authorized_keys-style line
    _sshPublicKeyPrefixes = ()
    # errors meaning "this data is not a (valid) key for this class"
    KeyLoadErrors = (ValueError, IndexError, binascii.Error)
    if paramiko is not None:
        KeyLoadErrors = KeyLoadErrors + (paramiko.SSHException, paramiko.PasswordRequiredException)

    def __init__(self, *args):
        """
        :param args: either one ready key object (anything with ``sign`` and
            ``verify``), which is wrapped as is, or the arguments for
            ``cryptoKeyClass.construct``
        """
        super(CryptoKey, self).__init__()
        if args and hasattr(args[0], 'verify') and hasattr(args[0], 'sign'):
            self._key = args[0]
        else:
            self._key = self.cryptoKeyClass.construct(*args)

    def encodeSign(self, sign):
        """Return signature representable as network string"""
        raise NotImplementedError

    def decodeSign(self, sign):
        """Read signature from network string"""
        raise NotImplementedError

    def has_private(self):
        return self._key.has_private()

    hasPrivate = has_private

    @memoized
    def publicKey(self):
        """Return a key of the same class wrapping ``self._key.publickey()`` (cached via @memoized)."""
        return self.__class__(self._key.publickey())

    @classmethod
    def fromNetworkRepresentation(cls, data, log=None):
        """
        Read key from network representation (which should contain keyType as first param).

        On a concrete subclass only that class's keyType is accepted. On
        CryptoKey itself the registered class for the key type is used.

        :return: key, or None if the key type is not accepted/registered
        """
        stream = StreamBase(bytesio(data))
        keyType = stream.readBEStr()
        if cls.keyType is not None:
            if keyType != cls.keyType:
                return None
            return cls.fromNetworkStream(stream)

        if keyType not in CryptoKey._KeyTypes_:
            return None
        try:
            return CryptoKey._KeyTypes_[keyType].fromNetworkStream(stream)
        except NotImplementedError:
            # subclass without a stream reader: let it parse the whole blob itself
            return CryptoKey._KeyTypes_[keyType].fromNetworkRepresentation(data, log=log)

    @memoized
    def networkRepresentation(self):
        """
        Export key in the format SSH uses to transmit key
        through the network. It means that format is binary
        and that there's only public key part is exported.
        """
        io = bytesio()
        stream = StreamBase(io)
        stream.writeBEStr(self.keyType)
        self.dumpToNetworkStream(stream)
        io.seek(0, 0)
        return io.read()

    @classmethod
    def fromNetworkStream(cls, stream, log=None):
        raise NotImplementedError

    def dumpToNetworkStream(self, stream):
        raise NotImplementedError

    def size(self):
        return self._key.size()

    @property
    def type(self):
        return self.keyType

    @classmethod
    def generate(cls, bits=None):
        """
        :param int bits: key bits (``cls.defaultSize`` when None)
        :return: new generated key
        """
        if bits is None:
            bits = cls.defaultSize
        return cls(cls.cryptoKeyClass.generate(bits))

    @classmethod
    def load(cls, f, comment='', log=None):
        """
        Load keys from a file.

        :param f: path or readable file-like object
        :param comment: passed on to loads(). When *f* is a path and no
            comment is given, the path is used.
        :param logging.Logger log: log to use
        :rtype: CryptoKey generator (empty if the file cannot be read)
        """
        try:
            if isinstance(f, six.string_types):
                if not comment:
                    comment = f
                # Always open a path, even when an explicit comment is given.
                # The file is closed before any key is yielded.
                with open(f) as fileObj:
                    data = fileObj.read()
            else:
                data = f.read()
        except EnvironmentError:
            # plain return in a generator: callers get an empty iterator
            return

        for key in cls.loads(data, comment, log=log):
            yield key

    @classmethod
    def loads(cls, data, comment='', log=None):
        """
        Load keys from *data*.

        On CryptoKey itself every registered key class is tried. On a concrete
        subclass only that class is tried. If nothing was loaded and some
        class failed, a warning is logged instead of raising.

        :param data: key material (bytes or text)
        :param comment: passed to each _doLoads call and used in the warning
        :param logging.Logger log: log to use
        :rtype: CryptoKey generator
        """
        log = log or root.getChild('auth')
        classes = set(six.itervalues(CryptoKey._KeyTypes_)) if cls is CryptoKey else {cls}
        errs = []
        empty = True
        for c in classes:
            try:
                for key in c._doMultiLoads(data, comment, log=log):
                    yield key
                    empty = False
            # the error tuple of the class being tried, which may extend ours
            except c.KeyLoadErrors as err:
                errs.append(err)

        if empty and errs:
            log.warning('Malformed key `{0}`: {1}'.format(
                comment if comment else data[:10],
                ', '.join((str(err) for err in errs)))
            )

    @classmethod
    def construct(cls, *args, **kwargs):
        """
        Build a key via ``cls.cryptoKeyClass.construct(*args)``.

        :param log: (keyword) logger for the failure warning
        :return: new key, or None if the arguments were rejected
        """
        log = kwargs.pop('log', None) or root.getChild('auth')
        try:
            return cls(cls.cryptoKeyClass.construct(*args))
        except cls.KeyLoadErrors as err:
            # The arguments are not interpolated. ', '.join(args) raised TypeError
            # for any non-string argument, and the arguments may be private
            # key components that should not reach the log.
            log.warning('Malformed key `{0}`: {1}'.format(cls.__name__, err))
            return None

    @classmethod
    def _parse_options(cls, options_string):
        """
        Parse the options field that precedes the key type on an
        authorized_keys line, e.g. ``no-pty,from="10.0.0.1"``.

        Commas and ``=`` inside double quotes are not separators. Surrounding
        double quotes are stripped from values.

        :param options_string: raw options text (bytes)
        :return: dict of option name -> value (None for options without ``=``),
            or None if the string is empty
        :raises ValueError: on an unbalanced double quote
        """
        options_string = options_string.strip()
        if not options_string:
            return None
        ret = {}
        if six.PY3:
            options_string = options_string.decode('utf-8')

        in_quote = False
        key = None
        has_value = False
        start_index = 0
        for idx, c in enumerate(options_string):
            if c == '"':
                in_quote = not in_quote
            elif in_quote:
                continue
            elif c == '=':
                key = options_string[start_index:idx]
                start_index = idx + 1
                has_value = True
            elif c == ',':
                token = options_string[start_index:idx]
                if has_value:
                    ret[key] = token.strip('"')
                else:
                    ret[token] = None
                key = None
                start_index = idx + 1
                has_value = False

        if in_quote:
            raise ValueError("Invalid options string")

        # Flush the last option. Any non-empty remainder counts, including a
        # single character (e.g. the value in "a=1"). The old check
        # `start_index < len - 1` silently dropped such options.
        tail = options_string[start_index:]
        if has_value:
            ret[key] = tail.strip('"')
        elif tail:
            ret[tail] = None

        return ret

    @classmethod
    def _doMultiLoads(cls, data, comment='', log=None):
        """
        Wrapper for _doLoads which tries to handle several keys in data.

        Each line containing one of ``cls._sshPublicKeyPrefixes`` as a
        space-separated token is loaded as its own key. The tokens before the
        prefix are parsed as the options field. If no line contains a prefix,
        the whole of *data* is handed to _doLoads at once.

        :raises: the last load error, if no key at all could be loaded
        """
        sAssert(cls is not CryptoKey)

        err = None
        empty = True
        multiKey = False

        if six.PY3 and not isinstance(data, six.binary_type):
            data = b(data)

        for line in data.split(b'\n'):
            parts = line.replace(b'\t', b' ' * 8).split(b' ')
            for prefix in cls._sshPublicKeyPrefixes:
                try:
                    prefixPos = parts.index(prefix)
                except ValueError:
                    continue

                multiKey = True
                keyLine = b' '.join(parts[prefixPos:])
                try:
                    # Options are parsed inside the try, so a malformed options
                    # field only skips this line instead of aborting the file.
                    options = cls._parse_options(b' '.join(parts[:prefixPos]).strip())
                    key = cls._doLoads(keyLine, comment, options, log=log)
                except cls.KeyLoadErrors as e:
                    err = e
                    continue

                empty = False
                # yield outside the try: errors thrown into the generator propagate
                yield key
                break

        if not multiKey:
            try:
                key = cls._doLoads(data, comment, log=log)
            except cls.KeyLoadErrors as e:
                err = e
            else:
                empty = False
                yield key

        if empty and err is not None:
            raise err

    @classmethod
    def _doLoads(cls, data, comment='', options=None, log=None):
        """
        Should be overridden in children, should try load single key.

        Declared as a classmethod because _doMultiLoads calls it on the class
        (``cls._doLoads(...)``).
        """
        raise NotImplementedError
