import errno
import gevent
import msgpack
import struct
import time

try:
    import gevent.coros as coros
except ImportError:
    import gevent.lock as coros

from . import compression
from . import encryption
from .dc import determine_dc
from .logger import ResourceIdLoggerAdapter


class PeerLoggerAdapter(ResourceIdLoggerAdapter):
    """Logger adapter that prefixes messages with connection direction and world.

    Keys read from ``self.extra``:

    * ``out``   -- ``True`` renders as ``ou``, ``False`` as ``in``, anything
                   else (including a missing key) as ``??``;
    * ``wdesc`` -- world description; when non-empty it is shown (truncated
                   to 40 chars) and the uid column is omitted;
    * ``wuid``  -- world uid, shown truncated to 8 chars only when there is
                   no description (defaults to 8 spaces).
    """

    def process(self, msg, kwargs):
        extra = self.extra

        direction = extra.get('out', None)
        # Identity checks on purpose: only the real booleans map to a flag.
        if direction is True:
            marker = 'ou'
        elif direction is False:
            marker = 'in'
        else:
            marker = '??'

        description = extra.get('wdesc', None) or ''

        if not description:
            uid = extra.get('wuid', ' ' * 8)
            text = '[%8s] [%s %-40s]  %s' % (uid[:8], marker, description[:40], msg)
            return text, kwargs

        text = '[%s %-40s]  %s' % (marker, description[:40], msg)
        return text, kwargs


class PeerConnection(object):
    """A single skybit peer connection running over an already-created socket.

    Responsibilities visible in this class:

    * writing/reading the SKYBIT handshake;
    * sending/receiving length-prefixed msgpack control messages;
    * sending/receiving raw piece payloads, optionally compressed and/or
      encrypted;
    * dispatching received messages to callbacks registered with
      :meth:`register_cb` / :meth:`subscribe`.
    """

    class EPIPE(Exception):
        """Raised by :meth:`send_message` when the socket is gone or was closed by the peer."""
        pass

    def __init__(self, swarm, addr, sock, connected=False, log=None):
        """Create a connection object (no network I/O except DC detection).

        :param swarm: owning swarm. Used for ``handle.world`` (rate-limit
            buckets, ``connect_socket``, ``current_dc``), byte/speed counters,
            compression/encryption settings and its exception types.
        :param addr: ``(ip, port)`` of the remote peer.
        :param sock: socket object to communicate over.
        :param connected: ``True`` if the socket is already connected. The
            connection is then logged as incoming (``in``). ``False`` means
            :meth:`connect` still has to be called (logged as ``ou``).
        :param log: underlying logger, wrapped in :class:`PeerLoggerAdapter`.
        """
        # Tuple parameters in signatures are Python 2 only (PEP 3113); unpacking
        # here keeps the positional call signature identical for callers.
        ip, port = addr

        self.swarm = swarm
        self.log = log

        self.ip = ip
        self.port = port

        self.sock = sock
        self.lock = coros.Semaphore(1)

        self._connected = connected
        self._outgoing = False

        self.activated = False

        # Held for the duration of one framed receive / send so that bytes of
        # a single frame are not interleaved with another reader/writer.
        self.no_interrupt_rcv = coros.Semaphore(1)
        self.no_interrupt_snd = coros.Semaphore(1)

        # Incremental decoder for the msgpack body of control messages.
        self.msg_unpacker = msgpack.Unpacker()

        self.compressor = None
        self.decompressor = None
        self.encryptor = None
        self.decryptor = None

        self.log = PeerLoggerAdapter(self.log, {'peer': (ip, port), 'out': not connected})

        # Non-default ports are included in the log description.
        if port not in (6881, 16881):
            self.log.extra['wdesc'] = '%s : %r' % (ip, port)
        else:
            self.log.extra['wdesc'] = ip

        # Must be an Event *instance*; the original assigned the class itself.
        self.state_changed = gevent.event.Event()

        # Initial state
        self.state = {
            'head_has': False,
            'head_want': False,
        }

        # Filled from the remote handshake.
        self.capabilities = set()
        self.snd_compression_mode = None

        # msgtype -> list of callables, see register_cb / _call_cb.
        self.callbacks = {}

        # When enabled, every received raw chunk is appended to self.payload.
        self.store_payload = False
        self.payload = []

        self.verified = gevent.event.AsyncResult()

        remote_dc = determine_dc(self.ip, self.log)
        current_dc = self.swarm.handle.world.current_dc
        self.is_local = current_dc == remote_dc

    def register_cb(self, typ, meth):
        """Append ``meth`` to the callbacks invoked for event ``typ``."""
        self.callbacks.setdefault(typ, []).append(meth)

    def _call_cb(self, typ, *args, **kwargs):
        """Invoke every callback registered for ``typ``, in registration order."""
        for cb in self.callbacks.get(typ, ()):
            cb(*args, **kwargs)

    def init_compressor(self, mode):
        """Compress outgoing piece data with ``mode`` from now on."""
        self.compressor = compression.make_compressor(mode, self.log)

    def init_decompressor(self, mode):
        """Decompress incoming piece data with ``mode`` from now on."""
        self.decompressor = compression.make_decompressor(mode, self.log)

    def init_encryptor(self, enc_params, key):
        """Encrypt outgoing piece data from now on."""
        self.encryptor = encryption.make_encryptor(enc_params, key, self.log)

    def init_decryptor(self, enc_params, key):
        """Decrypt incoming piece data from now on."""
        self.decryptor = encryption.make_decryptor(enc_params, key, self.log)

    def connect(self):
        """Connect the socket via the world (outgoing connection)."""
        assert not self._connected, 'already connected'

        self.swarm.handle.world.connect_socket(
            self.sock, self.swarm.uid, self.ip, self.port,
            self.swarm.handle.net_priority
        )
        self._connected = True
        self._outgoing = True

    def start(self):
        return self

    def write_handshake(self, uid, world_uid, world_desc, dfs):
        """Send the SKYBIT handshake.

        Wire format: the 6-byte literal ``SKYBIT``, a 4-byte big-endian length,
        then a msgpack-encoded ``('SKYBIT_HANDSHAKE_1', {...})`` message.

        :param uid: hex string, sent hex-decoded as ``uid``.
        :param world_uid: sent hex-decoded when it is 40 chars long (a hex
            digest), otherwise verbatim as ``world_uid_str``.
        :param world_desc: human-readable world description.
        :param dfs: advertise dfs capabilities instead of ``http``.
        :raises swarm.Deactivate: re-raised after logging if sending is interrupted.
        """
        if len(world_uid) == 40:
            world_uid_hex = world_uid.decode('hex')
            world_uid_str = None
        else:
            world_uid_hex = None
            world_uid_str = world_uid

        # Capabilities:
        # - skybit: support for skybit proto
        # - http_1: support for downloading files via http
        # - dfs   : peer is a dfs peer, which is able to send http links only
        # - dfs_http_range: all http links support Range header

        cap = ['skybit']

        if dfs:
            cap.extend(('dfs', 'dfs_http_range'))
        else:
            cap.append('http')

        cap.append(compression.compression_codecs_to_capability(
            self.swarm.compression_params['reply']
        ))
        cap.append(encryption.encryption_modes_to_capability(
            self.swarm.encryption_config
        ))

        handshake_data = {
            'version': 1,
            'uid': uid.decode('hex'),
            'world_uid_hex': world_uid_hex,
            'world_uid_str': world_uid_str,
            'world_desc': world_desc,
            'capabilities': cap
        }

        if self.swarm.compression_params['request']:
            handshake_data['compression'] = self.swarm.compression_params['request']

        handshake_msg = self.make_message('SKYBIT_HANDSHAKE_1', handshake_data)

        buff = struct.pack(
            '!6sI',
            'SKYBIT', len(handshake_msg)
        )

        self.log.debug('MSG: HANDSHAKE header=%d payload=%d', len(buff) - 4, len(handshake_msg) + 4)

        with self.no_interrupt_snd:
            try:
                self._sendall(buff)
                self._sendall(handshake_msg)
            except self.swarm.Deactivate as ex:
                self.log.critical('write_handshake: interrupted with Deactivate exception: %s', str(ex))
                raise

    def read_handshake(self, payload=None):
        """Read and validate the SKYBIT handshake.

        :param payload: optional bytes already read from the stream. They are
            consumed before reading from the socket.
        :returns: ``(infohash, world_uid, world_desc)``, or ``(None, None, None)``
            when the handshake is missing or invalid.
        :raises swarm.NoSkybit: if the peer closed/reset before sending anything.
        """
        nullres = None, None, None

        try:
            if payload:
                buff = payload[:6]
                payload = payload[6:]
            else:
                # NOTE(review): a single recv() may legitimately return < 6
                # bytes; that is currently reported as "too short" below.
                buff = self._recv(6)
        except gevent.socket.error as ex:
            if ex.errno == errno.ECONNRESET:
                raise self.swarm.NoSkybit()
            else:
                raise

        if len(buff) == 0:
            raise self.swarm.NoSkybit()

        if len(buff) < 6:
            self.log.debug('RCV: HANDSHAKE HEADER TOO SHORT (%d bytes)', len(buff))
            return nullres

        if buff == 'aaaaaa':
            # Pings from skycore
            return nullres

        if buff != 'SKYBIT':
            self.log.debug('RCV: HANDSHAKE HEADER INVALID PROTO (%r)', buff)
            return nullres

        if self.store_payload:
            self.payload.append(buff)

        return self.read_extended_handshake(payload)

    def read_extended_handshake(self, payload=None):
        """Read the handshake message that follows the ``SKYBIT`` magic.

        Records the peer's capabilities and requested compression mode and
        switches the log description to the remote world description.

        :returns: ``(infohash, world_uid, world_desc)`` or ``(None, None, None)``.
        """
        nullres = None, None, None
        msgtype, message = self.get_message(payload)

        if msgtype != 'SKYBIT_HANDSHAKE_1':
            return nullres

        if message.get('version', None) != 1:
            return nullres

        world_uid_hex = message.get('world_uid_hex')
        world_uid_str = message.get('world_uid_str')
        if not (world_uid_hex or world_uid_str):
            # Network input: reject explicitly (an assert vanishes under -O
            # and would let a None world uid into the logger).
            self.log.debug('RCV: HANDSHAKE WITHOUT WORLD UID')
            return nullres

        infohash = message['uid'].encode('hex')
        if world_uid_hex is not None:
            world_uid = world_uid_hex.encode('hex')
        else:
            world_uid = world_uid_str
        world_desc = message['world_desc']

        self.capabilities = set(message['capabilities'])
        self.snd_compression_mode = message.get('compression')

        self.log.extra['wuid'] = world_uid

        if 'wdesc' in self.log.extra:
            self.log.debug('switch desc to %s', world_desc)

        self.log.extra['wdesc'] = world_desc
        self.log.debug('RCV: HANDSHAKE OK')

        return infohash, world_uid, world_desc

    def make_message(self, msgtype, msg):
        """Serialize a ``(msgtype, msg)`` control message with msgpack."""
        return msgpack.dumps((msgtype, msg))

    def _sendall(self, data, payload=False):
        """Send all of ``data``, looping over partial sends.

        Each chunk actually sent is leaked into the world's outgoing bucket
        (if any). The total is then reported to the swarm's sent-bytes counter.
        """
        view = memoryview(data)
        total = len(data)
        sent = 0
        bucket = self.swarm.handle.world.bucket_ou

        while sent < total:
            sent_now = self.sock.send(view[sent:])
            if bucket:
                bucket.leak(sent_now)
            sent += sent_now

        self.swarm.count_sent_bytes(len(data), payload=payload)

    def _recv(self, count, payload=False):
        """Single ``recv`` of up to ``count`` bytes.

        The requested ``count`` is leaked into the world's incoming bucket (if
        any) before reading. The received length is counted on the swarm.
        """
        bucket = self.swarm.handle.world.bucket_in
        if bucket:
            bucket.leak(count)

        ret = self.sock.recv(count)
        self.swarm.count_recv_bytes(len(ret), payload=payload)
        return ret

    def _recv_exact(self, count, payload=False):
        """Receive exactly ``count`` bytes, looping over short reads.

        Returns fewer bytes only if the peer closed the connection (a ``recv``
        returned an empty chunk).
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = self._recv(remaining, payload=payload)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def send_raw_data(self, data):
        """Send ``data`` as-is, counted as payload bytes."""
        with self.no_interrupt_snd:
            try:
                self._sendall(data, payload=True)
            except self.swarm.Deactivate as ex:
                self.log.critical('send_raw_data: interrupted with Deactivate exception: %s', str(ex))
                raise

    def send_raw_message(self, message):
        """Send an already-serialized message framed by a 4-byte big-endian length."""
        with self.no_interrupt_snd:
            try:
                self._sendall(struct.pack('!I', len(message)))
                self._sendall(message)
            except self.swarm.Deactivate as ex:
                self.log.critical('send_raw_message: interrupted with Deactivate exception: %s', str(ex))
                raise

    def send_message(self, msgtype, msg, hint=''):
        """Serialize and send a control message.

        :param hint: extra text for the debug log line only.
        :raises PeerConnection.EPIPE: if there is no socket, or on EPIPE/ECONNRESET.

        Other socket errors and ``GreenletExit`` propagate. Any other exception
        is logged as a warning and swallowed.
        """
        if not self.sock:
            raise self.EPIPE('Connection closed (no sock)')

        self.log.debug('SND: %s %s', msgtype.upper(), hint)
        try:
            return self.send_raw_message(self.make_message(msgtype, msg))

        except gevent.socket.error as ex:
            if ex.errno == errno.EPIPE:
                raise self.EPIPE('Connection closed (broken pipe)')
            elif ex.errno == errno.ECONNRESET:
                raise self.EPIPE('Connection closed (reset by peer)')
            else:
                raise

        except gevent.GreenletExit:
            raise

        except BaseException as ex:
            # NOTE(review): this also swallows swarm.Deactivate that
            # send_raw_message deliberately re-raises -- confirm intended.
            self.log.warning('Unable to send message %r with %s', msgtype, ex)

    def msg_ping(self):
        self.send_message('PING', 'PING')

    def msg_stop(self, reason_code, reason_text):
        self.send_message('STOP', (reason_code, reason_text), hint=reason_code)

    def msg_stop_abort(self, reason_code, reason_text):
        self.msg_stop(reason_code, reason_text)

    def msg_head_has(self):
        self.send_message('HEAD_HAS', 1)

    def msg_head_want(self):
        # For now we are assuming we already asked
        self.send_message('HEAD_WANT', 1)

    def msg_head(self, head):
        self.send_message('HEAD', head)

    def msg_piece_map_req(self):
        self.send_message('PIECEMAPREQ', 1)

    def msg_piece_map(self, pieces):
        self.send_message('PIECEMAP', pieces)

    def msg_piece_req(self, idx, buf):
        # ``buf`` is unused; kept for interface compatibility.
        self.send_message('PIECEREQ', idx, idx)

    def msg_piece(self, idx, data):
        """Send piece ``idx``: a PIECE header with the byte count, then the raw data.

        Data is compressed first and then encrypted, if either is configured.
        """
        if self.compressor:
            data = self.compressor.compress(data)
        if self.encryptor:
            data = self.encryptor.encrypt(data)
        self.send_message('PIECE', (idx, len(data)), '%d (%d bytes)' % (idx, len(data)))
        self.send_raw_data(data)

    def msg_have(self, idx):
        self.send_message('HAVE', idx, idx)

    def msg_file_links(self, links):
        self.send_message('LINKS', links)

    def msg_auth(self, data):
        self.send_message('AUTH', data)

    def msg_verify(self, data):
        self.send_message('VERIFY', data)

    def subscribe(self, msgtype, cb):
        """Alias of :meth:`register_cb`."""
        self.register_cb(msgtype, cb)

    def process_messages(self, piece_shmem_segments):
        """Receive and dispatch messages until the connection ends.

        :param piece_shmem_segments: indexable by piece index; PIECE payloads
            are written into ``piece_shmem_segments[idx]``.
        :returns: ``(reason_code, reason_text)`` when the peer closes or stops,
            ``None`` on ``GreenletExit`` or on an unexpected error (logged).
            The connection is always closed on exit.
        """
        try:
            while True:
                msgtype, payload = self.get_message()
                if not msgtype:
                    if not self.sock:
                        return 'CLOSED_BOTH', 'connection closed by both sides'
                    else:
                        return 'CLOSED_REMOTE', 'connection closed by remote side'

                elif msgtype == 'PING':
                    self.log.debug('RCV: PING')

                # Legacy stoppers:
                elif msgtype == 'NO_SLOT':
                    self.log.debug('RCV: NO_SLOT')
                    return 'NO_SLOT', 'connection closed, because of no_slot'

                elif msgtype == 'NO_RESOURCE':
                    self.log.debug('RCV: NO_RESOURCE')
                    return 'NO_RESOURCE', 'connection closed, because of no_resource'

                elif msgtype == 'NO_NEED':
                    self.log.debug('RCV: NO_NEED')
                    return 'NO_NEED', 'connection closed, because of no_need'

                elif msgtype == 'NO_ACTIVATE':
                    self.log.debug('RCV: NO_ACTIVATE')
                    return 'CLOSED_NOACT', 'connection closed, because of no_activate'

                elif msgtype == 'DE_ACTIVATE':
                    self.log.debug('RCV: DE_ACTIVATE')
                    return 'CLOSED_DEACT', 'connection closed, because of de_activate'

                elif msgtype == 'STOP':
                    reason_code, reason_text = payload
                    self.log.debug('RCV: STOP (%s, %s)', reason_code, reason_text)
                    return reason_code, 'connection closed: %s: %s' % (reason_code, reason_text)

                elif msgtype == 'HEAD_HAS':
                    self.log.debug('RCV: HEAD_HAS')
                    self.state['head_has'] = True
                    self.state['head_want'] = False
                    self._call_cb('head_has')

                elif msgtype == 'HEAD_WANT':
                    self.log.debug('RCV: HEAD_WANT')
                    self.state['head_want'] = True
                    self.state['head_has'] = False
                    self._call_cb('head_want')

                elif msgtype == 'HEAD':
                    self.log.debug('RCV: HEAD')
                    self._call_cb('head', head=payload)

                elif msgtype == 'LINKS':
                    self.log.debug('RCV: LINKS (count=%d)', len(payload))
                    self._call_cb('links', payload)

                elif msgtype == 'PIECEMAPREQ':
                    self.log.debug('RCV: PIECEMAPREQ')
                    self._call_cb('piecemapreq')

                elif msgtype == 'PIECEMAP':
                    self.log.debug('RCV: PIECEMAP')
                    self._call_cb('piecemap', pieces=payload)

                elif msgtype == 'PIECEREQ':
                    self.log.debug('RCV: PIECEREQ %r', payload)
                    self._call_cb('piecereq', idx=payload)

                elif msgtype == 'PIECE':
                    # Header carries (idx, byte count); raw data follows.
                    idx = payload[0]
                    num_bytes = payload[1]
                    self.log.debug('RCV: PIECE %r (%d bytes, start)', idx, num_bytes)
                    done = self.get_data(num_bytes, piece_shmem_segments[idx])
                    if done:
                        self._call_cb('piece', idx, num_bytes)
                    self.log.debug('RCV: PIECE %r (%d bytes, finish)', idx, num_bytes)

                elif msgtype == 'HAVE':
                    self.log.debug('RCV: HAVE %r', payload)
                    self._call_cb('have', idx=payload)

                elif msgtype == 'AUTH':
                    self.log.debug('RCV: AUTH')
                    self._call_cb('auth', data=payload)

                elif msgtype == 'VERIFY':
                    self.log.debug('RCV: VERIFY')
                    self._call_cb('verify', sign=payload)

                else:
                    self.log.debug('RCV: %s (unknown)', msgtype)
        except gevent.GreenletExit:
            return
        except Exception as ex:
            self.log.error('got error while processing messages: %s', ex)
        finally:
            self.close()

    def get_message(self, payload=None):
        """Read one length-prefixed msgpack control message.

        :param payload: optional bytes already read from the stream. They are
            consumed first; once exhausted, reading continues from the socket.
        :returns: ``(msgtype, message)``, or ``(None, None)`` on EOF or any
            error (logged as a warning while a socket is still set).
        :raises swarm.Deactivate: re-raised if reading the body is interrupted.
            It is not caught by the generic handler because that handler is
            declared for ``Exception`` only if Deactivate is not one.
            NOTE(review): if Deactivate subclasses Exception it is swallowed
            by the outer handler -- confirm.
        """
        try:
            if payload:
                llen = payload[:4]
                payload = payload[4:]
                if len(llen) < 4:
                    # The buffered prefix ended mid-header; the rest is on the socket.
                    llen += self._recv_exact(4 - len(llen))
            else:
                # A single recv(4) may return fewer bytes on TCP; read exactly.
                llen = self._recv_exact(4)
            if not llen:
                return None, None
            if self.store_payload:
                self.payload.append(llen)
            if len(llen) < 4:
                self.log.error('get_message: connection broken while reading message length')
                return None, None
            llen = struct.unpack('!I', llen)[0]

            with self.no_interrupt_rcv:
                try:
                    while llen:
                        if payload:
                            data = payload[:min(64 * 1024, llen)]
                            payload = payload[len(data):]
                        else:
                            data = self._recv(min(64 * 1024, llen))
                        if not data:
                            self.log.error('get_message: connection broken while reading message')
                            return None, None
                        if self.store_payload:
                            self.payload.append(data)
                        llen -= len(data)
                        self.msg_unpacker.feed(data)
                except self.swarm.Deactivate as ex:
                    self.log.critical('get_message: interrupted with Deactivate exception: %s', str(ex))
                    raise

                msgtype, message = self.msg_unpacker.next()

            return msgtype, message
        except Exception as ex:
            if self.sock:
                self.log.warning('get_message: error: %s', ex)
            return None, None

    def get_data(self, num_bytes, memory_segment):
        """Receive ``num_bytes`` of raw piece data into ``memory_segment``.

        Data flows socket -> decryptor (if any) -> decompressor (if any) ->
        memory segment. The segment must be unused (offset == size == 0) and
        is rewound afterwards.

        :returns: ``True`` if all bytes were received, ``None`` if the peer
            closed the connection first.
        :raises swarm.Deactivate: re-raised after logging.
        """
        with self.no_interrupt_rcv:
            assert memory_segment.offset == memory_segment.size == 0, 'Memory segment already used'

            sock_out_stream = memory_segment
            if self.decompressor:
                self.decompressor.start(out_stream=sock_out_stream)
                sock_out_stream = self.decompressor
            if self.decryptor:
                self.decryptor.start(out_stream=sock_out_stream)
                sock_out_stream = self.decryptor

            total_bytes = num_bytes
            time_start = time.time()
            try:
                while num_bytes:
                    data = self._recv(min(8 * 1024, num_bytes), payload=True)
                    if not data:
                        return
                    num_bytes -= len(data)
                    sock_out_stream.write(data)
            except self.swarm.Deactivate as ex:
                self.log.critical('get_data: interrupted with Deactivate exception: %s', str(ex))
                raise
            finally:
                # NOTE(review): reports total_bytes even when the transfer
                # ended early -- confirm intended.
                self.swarm.register_recv_speed(total_bytes, time.time() - time_start, self.is_local)
                if self.decryptor:
                    self.decryptor.finish()
                if self.decompressor:
                    self.decompressor.finish()
                memory_segment.rewind()

            return True

    def close(self):
        """Close the socket once; later calls are no-ops.

        ``self.sock`` is cleared before closing, so a failing ``close()``
        (logged as a warning) does not leave a broken socket behind for
        later sends.
        """
        sock, self.sock = self.sock, None
        if not sock:
            return

        try:
            self.log.debug('CLOSE')
            sock.close()
        except Exception as ex:
            self.log.warning('Unable to close connection: %s', ex)