import os
import sys
import six
import tempfile
import errno
import socket

from kernel.util.streams.streambase import StreamBase
from kernel.util.net.socketstream import SocketStream
from kernel.util.errors import sAssert
from kernel.util.misc import daemonthr


from . import log as root
from .key import CryptoKey, bytesio


class SocketServer(six.moves.socketserver.ThreadingMixIn, six.moves.socketserver.UnixStreamServer):
    """Threaded UNIX-domain stream server for the ssh-agent socket.

    Besides the usual socketserver arguments it carries the auth manager and the
    logger, which request handlers reach through ``self.server``.
    """
    # TODO: consider a proper polling implementation instead of one thread per connection

    # Seconds; socketserver consults this in handle_request(), not in serve_forever().
    timeout = 60
    # Handler threads must not keep the interpreter alive on exit.
    daemon_threads = True

    def __init__(self, authManager, log, *args, **kwargs):
        # Stored before the base constructor binds and activates the socket.
        self.authManager = authManager
        self.log = log
        baseServer = six.moves.socketserver.UnixStreamServer
        baseServer.__init__(self, *args, **kwargs)


class SshConnectionHandler(six.moves.socketserver.BaseRequestHandler):
    """Serve ssh-agent requests arriving on one accepted client connection.

    Identity listing and signing are answered through ``self.server.authManager``;
    any other message, and any sign request that produced no signature, gets a
    failure reply. Expects ``self.server`` to be a :class:`SocketServer`.
    """

    def handle(self):
        """Process requests until reading or writing the connection raises.

        The loop has no normal exit. Connection reset, broken pipe and aborted
        connection errors end it quietly; any other error is logged and re-raised.
        The client socket is always shut down and closed afterwards.
        """
        stream = SocketStream(self.request)
        try:
            while True:
                # NOTE(review): how SocketStream reports a clean EOF is not visible here; if it
                # is not a socket.error with one of the errnos below, every client disconnect
                # is logged as a warning and re-raised — confirm.
                message = SshMessage.Load(stream, self.server.log)
                if isinstance(message, SignRequestMessage):
                    key = CryptoKey.fromNetworkRepresentation(message.key_blob, log=self.server.log)
                    fp = key.fingerprint()
                    signs = list(self.server.authManager.sign(hash=message.data, fingerprints=fp))
                    if not signs:
                        SshOutgoingMessage().write(stream)
                        continue
                    response = SignResponseMessage()
                    # Strip prepended size header (4 bytes)
                    response.signature = StreamBase(bytesio(signs[0][1])).readBEStr()
                    response.signatureType = key.type
                    # Legacy clients may request the old (SSH_AGENT_OLD_SIGNATURE) output format for DSS keys
                    if key.keyType == b'ssh-dss' and message.flags & SshGeneric.OldSignature == SshGeneric.OldSignature:
                        response.oldSignatureMode = True
                    response.write(stream)
                elif isinstance(message, RequestIdentitiesMessage):
                    response = IdentitiesResponseMessage()
                    response.keys = [(key.publicKey().networkRepresentation(), key.comment) for key in self.server.authManager]
                    response.write(stream)
                else:
                    self.server.log.warning("Unexpected message read from socket: %s", message)
                    SshOutgoingMessage().write(stream)
        except socket.error as e:
            if e.errno not in (errno.ECONNRESET, errno.EPIPE, errno.ECONNABORTED):
                self.server.log.warning("network error during processing message from socket: %s", e, exc_info=sys.exc_info())
                raise
        except Exception as e:
            self.server.log.warning("error during processing message from socket: %s", e, exc_info=sys.exc_info())
            raise
        finally:
            self._closeRequest()

    def _closeRequest(self):
        """Shut down and close the client socket, tolerating a peer that is already gone."""
        try:
            self.request.shutdown(socket.SHUT_RDWR)
        except socket.error as e:
            # shutdown() may fail (e.g. ENOTCONN) once the peer disconnected; raising here
            # would mask the original error and skip close(), leaking the descriptor.
            self.server.log.debug("socket shutdown failed: %s", e)
        self.request.close()


class SshAgent(object):  # pragma: nocoverage
    """
    Serve an ssh-agent compatible UNIX domain socket backed by an AuthManager.
    Only listing identities and signing data are supported.
    """

    def __init__(self, authManager, socketPath=None, log=None):
        """
        Create new agent (the socket is not created until start()).
        :param authManager: AuthManager that handles requests coming to the ssh agent
        :param socketPath: optional path to create the socket on. If not set, a fresh temporary
            directory is created by start() and the socket is placed inside it
        :param log: logger to use; defaults to the 'ssh-agent' child of the package logger
        """
        self.authManager = authManager
        self.socketPath = socketPath
        self.log = log or root.getChild('ssh-agent')
        self.__sockDir = None
        self.__server = None
        self.__started = False

    def start(self):
        """Start serving on the *self.socketPath* socket; does nothing if already started."""
        if self.__started:
            return
        self.__started = True
        try:
            if self.socketPath is None:
                # mkdtemp creates a directory accessible only by the current user.
                self.__sockDir = tempfile.mkdtemp(suffix='_ssh', prefix='cqueue_')
                self.socketPath = os.path.join(self.__sockDir, str(os.getpid()))
            if self.__server is None:
                self.__server = SocketServer(self.authManager, self.log, self.socketPath, SshConnectionHandler)
            # NOTE(review): daemonthr presumably runs serve_forever in a background daemon thread — confirm.
            daemonthr(self.__server.serve_forever)
            self.log.info("SshAgent started on socket %s", self.socketPath)
        except BaseException:
            self.__started = False
            raise

    def stop(self):
        """Stop the server, close its listening socket and remove the socket file.

        If start() created the socket directory, the directory is removed as well and
        socketPath is reset so that a later start() creates a fresh one.
        """
        if not self.__started:
            return
        self.log.debug("stopping sshagent on %s", self.socketPath)
        self.__started = False
        server, self.__server = self.__server, None
        server.shutdown()
        # shutdown() only stops serve_forever(); server_close() releases the listening socket.
        server.server_close()
        try:
            os.unlink(self.socketPath)
        except EnvironmentError:
            pass
        if self.__sockDir is not None:
            try:
                os.rmdir(self.__sockDir)
            except EnvironmentError:
                pass
            self.__sockDir = None
            self.socketPath = None

    def __del__(self):
        # __init__ may have failed before the private state was assigned.
        if getattr(self, '_SshAgent__started', False):
            self.stop()


class Ssh1Codes(object):
    """Message numbers of the legacy SSHv1 agent protocol.

    Kept for reference only; no message class here handles them.
    """
    RequestRsaIdentities = 1    # request
    RsaIdentitiesAnswer = 2     # response
    RsaChallenge = 3            # request
    RsaResponse = 4             # response
    AddRsaIdentity = 7          # request
    RemoveRsaIdentity = 8       # request
    RemoveAllRsaIdentities = 9  # request
    AddRsaIdConstrained = 24    # request


class Ssh2Requests(object):
    """Client-to-agent request numbers of the SSHv2 agent protocol.

    Only Identities and Sign are handled by this agent; the rest are listed for
    completeness. (The "Constrainted" spellings are kept as-is for compatibility.)
    """
    Identities = 11                    # handled
    Sign = 13                          # handled
    AddIdentity = 17
    RemoveIdentity = 18
    RemoveAllIdentities = 19
    AddSmartcardKey = 20
    RemoveSmartcardKey = 21
    Lock = 22
    Unlock = 23
    AddIdConstrainted = 25
    AddSmartcardKeyConstrainted = 26
    Extension = 27


class Ssh2Responses(object):
    """Agent-to-client reply numbers of the SSHv2 agent protocol."""
    Identities = 12  # answer to Ssh2Requests.Identities
    Sign = 14        # answer to Ssh2Requests.Sign


class SshGeneric(object):
    """Protocol-independent agent replies, sign-request flag bits and internal codes."""
    # Generic agent replies
    Failure = 5
    Success = 6
    ExtensionFailure = 28

    # Bit flags carried in a sign request
    OldSignature = 1
    RsaSha2_256 = 2
    RsaSha2_512 = 4

    # Internal pseudo message type; larger than a one-byte wire type code
    UnsupportedMessage = 1024


class _SshMessageMetaclass(type):
    _MessageTypes_ = {}

    def __new__(mcs, name, bases, dct):
        newType = type.__new__(mcs, name, bases, dct)
        mt = dct.get('MessageType', None)
        if mt is not None:
            prevType = _SshMessageMetaclass._MessageTypes_.get(mt)
            sAssert(prevType is None, 'MessageType collision in {0} and {1}'.format(name, str(prevType)))
            _SshMessageMetaclass._MessageTypes_[mt] = newType
            newType._MessageTypes_ = _SshMessageMetaclass._MessageTypes_
        return newType


@six.add_metaclass(_SshMessageMetaclass)
class SshMessage(object):
    """Base class for ssh-agent protocol messages.

    Subclasses that declare ``MessageType`` are registered by the metaclass, which lets
    :meth:`Load` dispatch on the type byte of an incoming frame.
    """

    @classmethod
    def Load(cls, source, log=None):
        """Read one length-framed message and decode it.

        :param source: raw message data (``bytes`` or ``str``), or a stream object providing
            ``readBEInt``, ``readFInt`` and ``read``
        :param log: logger for protocol warnings; defaults to the 'ssh' child of the package logger
        :return: an instance of the registered subclass for the type byte, or a bare
            SshIncomingMessage when the type is unknown (its payload is skipped)
        """
        log = log or root.getChild('ssh')
        # six.string_types is only (str,) on Python 3, so raw bytes must be accepted explicitly.
        if isinstance(source, (six.binary_type, six.text_type)):
            stream = StreamBase(bytesio(source))
        else:
            stream = source
        length = stream.readBEInt()
        sAssert(length > 0, 'Protocol error')
        msgType = stream.readFInt(1)
        length -= 1  # the frame length includes the type byte
        if msgType not in SshMessage._MessageTypes_:
            log.warning("Unknown message type: %s", msgType)
            stream.read(length)  # Ignore message
            return SshIncomingMessage()
        msg = SshMessage._MessageTypes_[msgType]()
        msg.load(stream, length)
        return msg

    def dump(self):
        """Return the complete framed wire representation of this message."""
        ioStream = bytesio()
        stream = StreamBase(ioStream)
        self.write(stream)
        return ioStream.getvalue()

    def write(self, stream):
        """Write the framed message to *stream*; implemented by subclasses."""
        raise NotImplementedError

    def load(self, stream, length):
        """Read the payload (*length* bytes after the type byte) from *stream*; implemented by subclasses."""
        raise NotImplementedError


class SshOutgoingMessage(SshMessage):
    """Payload-less SSH_AGENT_FAILURE reply; also the base class of agent responses."""
    MessageType = SshGeneric.Failure

    def load(self, stream, length):
        # A failure reply consists of the type byte alone.
        sAssert(length == 0, 'Protocol error')

    def write(self, stream):
        frameLength = 1  # just the type byte
        stream.writeBEInt(frameLength)
        stream.writeFInt(SshGeneric.Failure, 1)


class SshIncomingMessage(SshMessage):
    """Base class for client requests; a bare instance stands for an unrecognised message."""


class SshUnsupportedMessage(SshIncomingMessage):
    """Marker request type registered under SshGeneric.UnsupportedMessage (1024).

    SshMessage.Load reads the type with readFInt(1) — presumably a single byte — so this
    code is not expected to match any wire message.
    """
    MessageType = SshGeneric.UnsupportedMessage


class SignResponseMessage(SshOutgoingMessage):
    """SSH2_AGENT_SIGN_RESPONSE: signature produced for a sign request.

    Normal wire payload after the type byte: string(string(signatureType) + string(signature)).
    In old-signature mode the signature bytes are written without any length prefix.
    """
    MessageType = Ssh2Responses.Sign

    @property
    def signature(self):
        """Signature bytes (see the write/load notes for whether a length prefix is included)."""
        return self._signature

    @signature.setter
    def signature(self, val):
        self._signature = val

    @property
    def oldSignatureMode(self):
        """Whether the legacy DSS signature format is used; 0 if never set."""
        return getattr(self, '_oldSignatureMode', 0)

    @oldSignatureMode.setter
    def oldSignatureMode(self, value):
        self._oldSignatureMode = value

    @property
    def signatureType(self):
        """Signature algorithm name, e.g. b'ssh-dss' in old-signature mode."""
        return self._signatureType

    @signatureType.setter
    def signatureType(self, val):
        # sAssert(val == 'ssh-rsa', 'Unsupported key type')
        self._signatureType = val

    def write(self, stream):
        """Write the framed response; ``signature`` is expected without a length prefix."""
        if self.oldSignatureMode:
            # NOTE(review): the raw signature is written without a string length prefix. OpenSSH's
            # agent appears to wrap even old-style DSS signatures in a string — confirm clients accept this.
            stream.writeBEInt(1 + len(self._signature))
            stream.writeFInt(self.MessageType, 1)
            stream.write(self._signature)
            return

        # Build the inner signature blob: string(type) + string(signature).
        sBlobIO = bytesio()
        sBlobStream = StreamBase(sBlobIO)
        sBlobStream.writeBEStr(self._signatureType)
        sBlobStream.writeBEStr(self._signature)
        signature = sBlobIO.getvalue()
        # Frame length: type byte + 4-byte blob length + blob.
        stream.writeBEInt(1 + 4 + len(signature))
        stream.writeFInt(self.MessageType, 1)
        stream.writeBEStr(signature)

    def load(self, stream, length):
        """Read the response payload of *length* bytes (type byte already consumed)."""
        # Instances created by SshMessage.Load never have oldSignatureMode set, so this branch
        # is only taken when a caller sets the flag before calling load() directly.
        if self.oldSignatureMode:
            sAssert(length == 40, 'Protocol error')
            self.signatureType = b'ssh-dss'
            self.signature = stream.read(40)
            return

        sAssert(length >= 4, 'Protocol error')
        signatureBlob = stream.readBEStr()
        sAssert(length == len(signatureBlob) + 4, 'Protocol error')

        sigStream = StreamBase(bytesio(signatureBlob))
        self.signatureType = sigStream.readBEStr()

        signature = sigStream.readBEStr()
        # The signature is stored re-prefixed with its length, unlike what write() expects.
        # NOTE(review): write() then load() is therefore not symmetric; this presumably mirrors
        # the length-prefixed form returned by authManager.sign — confirm intended.
        io = bytesio()
        StreamBase(io).writeBEStr(signature)
        self.signature = io.getvalue()


class IdentitiesResponseMessage(SshOutgoingMessage):
    """SSH2_AGENT_IDENTITIES_ANSWER: the list of keys held by the agent.

    Wire payload after the type byte: uint32 count, then ``count`` pairs of
    length-prefixed strings (public key blob, comment).
    """
    MessageType = Ssh2Responses.Identities

    @property
    def keys(self):
        """List of ``(keyBlob, comment)`` tuples."""
        return self._keys

    @keys.setter
    def keys(self, val):
        # A missing (None/empty) comment is sent as an empty string.
        self._keys = [(key, (comment or "")) for key, comment in val]

    def __init__(self):
        self._keys = []

    def write(self, stream):
        """Write the framed identities answer to *stream*."""
        # Serialize the body first so the frame length is the exact number of bytes
        # writeBEStr emits; len() of a text comment counts characters, not encoded bytes.
        bodyIO = bytesio()
        body = StreamBase(bodyIO)
        body.writeBEInt(len(self._keys))
        for key, comment in self._keys:
            body.writeBEStr(key)
            body.writeBEStr(comment)
        payload = bodyIO.getvalue()
        stream.writeBEInt(1 + len(payload))  # type byte + body
        stream.writeFInt(self.MessageType, 1)
        stream.write(payload)

    def load(self, stream, length):
        """Read the answer payload of *length* bytes (type byte already consumed)."""
        sAssert(length >= 4, 'Protocol Error')
        count = stream.readBEInt()
        length -= 4
        self._keys = []
        for i in six.moves.xrange(count):
            # Each remaining pair needs at least two 4-byte length headers.
            sAssert(length >= (count - i) * 8, 'Protocol error')
            key = stream.readBEStr()
            length -= 4 + len(key)
            sAssert(length >= (count - i - 1) * 8 + 4, 'Protocol error')
            comment = stream.readBEStr()
            length -= 4 + len(comment)
            self._keys.append((key, comment))
        sAssert(length == 0, 'Protocol error')


class RequestIdentitiesMessage(SshIncomingMessage):
    """SSH2_AGENTC_REQUEST_IDENTITIES: client asks for the agent's keys; carries no payload."""
    MessageType = Ssh2Requests.Identities

    def load(self, stream, length):
        # Nothing follows the type byte.
        sAssert(length == 0, 'Protocol error')

    def write(self, stream):
        frameLength = 1  # just the type byte
        stream.writeBEInt(frameLength)
        stream.writeFInt(self.MessageType, 1)


class SignRequestMessage(SshIncomingMessage):
    """SSH2_AGENTC_SIGN_REQUEST: ask the agent to sign ``data`` with the key ``key_blob``.

    Wire payload after the type byte: string key_blob, string data, uint32 flags.
    """
    MessageType = Ssh2Requests.Sign

    # Flag bits accepted by load(). SshConnectionHandler acts on OldSignature (for ssh-dss keys);
    # the RSA SHA-2 flags are rejected because signing is not selected by flag.
    SupportedFlags = SshGeneric.OldSignature

    def __init__(self):
        self._flags = 0
        self._keyBlob = b''
        self._data = b''

    @property
    def key_blob(self):
        """Public key blob identifying the key to sign with."""
        return self._keyBlob

    @key_blob.setter
    def key_blob(self, val):
        self._keyBlob = val

    @property
    def data(self):
        """Data to be signed."""
        return self._data

    @data.setter
    def data(self, val):
        self._data = val

    @property
    def flags(self):
        """Sign-request flag bits (see SshGeneric)."""
        return self._flags

    @flags.setter
    def flags(self, val):
        self._flags = val

    def write(self, stream):
        """Write the framed sign request to *stream*."""
        # type byte + (4 + key blob) + (4 + data) + 4-byte flags
        length = 1 + 4 + len(self._keyBlob) + 4 + len(self._data) + 4
        stream.writeBEInt(length)
        stream.writeFInt(self.MessageType, 1)
        stream.writeBEStr(self._keyBlob)
        stream.writeBEStr(self._data)
        stream.writeBEInt(self._flags)

    def load(self, stream, length):
        """Read the request payload of *length* bytes (type byte already consumed)."""
        sAssert(length > 12, 'Protocol error')  # 3 uint32
        self._keyBlob = stream.readBEStr()
        length -= len(self._keyBlob) + 4
        sAssert(length > 8, 'Protocol error')
        self._data = stream.readBEStr()
        length -= len(self._data) + 4
        sAssert(length == 4, 'Protocol error')
        self._flags = stream.readBEInt()
        # Rejecting every non-zero flag would make the handler's OldSignature branch unreachable.
        sAssert((self._flags & ~self.SupportedFlags) == 0, 'Protocol error')
