﻿using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using Curse.Logging;
using Curse.ServiceClients.Contracts;
using Curse.SocketInterface;
using Curse.SocketMessages;
using Curse.Voice.Contracts;
using Curse.WebRTC;
using Curse.WebRTC.Stun;

namespace Curse.Voice.HostRuntime
{
    /// <summary>
    /// Receives datagrams on a set of <see cref="RTCSocketPair"/> instances and dispatches them:
    /// custom TransmitRequest/Handshake messages first, then STUN, DTLS and RTP/RTCP traffic
    /// demultiplexed by first byte (RFC 5764 section 5.1.2).
    /// </summary>
    public class RTCRelayServer
    {
        private static readonly LogCategory Logger = new LogCategory("RTCRelayServer") 
            { Throttle = TimeSpan.FromSeconds(30) };
        
        /// <summary>RTP payload type written by <see cref="OnTransmitRequest"/> and recognized by <see cref="IsAudioPacket"/>.</summary>
        private const byte OpusAudioPayloadType = 111;

        /// <summary>
        /// Smallest TransmitRequest body from which the fixed fields can be read:
        /// loopback flag (1 byte at body offset 0), sequence number (4 bytes at offset 1)
        /// and payload length (4 bytes at offset 5).
        /// </summary>
        private const int TransmitRequestMinBodySize = 9;

        /// <summary>
        /// Returns true when the RTP payload type (low 7 bits of byte 1; the high bit is the marker)
        /// equals <see cref="OpusAudioPayloadType"/>.
        /// </summary>
        /// <param name="rtpPacket">An RTP packet. Packets shorter than 2 bytes are reported as not audio.</param>
        public static bool IsAudioPacket(ByteBuffer rtpPacket)
        {
            // Runt packet: there is no byte 1 to inspect.
            if (rtpPacket.Count < 2)
                return false;

            return (rtpPacket[1] & 0x7F) == OpusAudioPayloadType;
        }

        /// <summary>Resolves a client ID to its socket interface, or null if unknown.</summary>
        public delegate ISocketInterface ClientLookupFunc(int clientID);

        private ClientLookupFunc _clientLookup;
        private RTCSocketPair[] _sockets;

        /// <summary>First IPv4 address of this host, resolved in <see cref="Start"/>; null if none was found.</summary>
        public IPAddress Address { get; private set; }

        /// <summary>Socket pairs created by <see cref="Start"/>; null before <see cref="Start"/> is called.</summary>
        public IList<RTCSocketPair> Sockets { get { return _sockets; } }

        /// <summary>
        /// Creates one <see cref="RTCSocketPair"/> per endpoint (on the endpoint's port and port + 1)
        /// with <see cref="Receive"/> as the receive callback, then resolves <see cref="Address"/>.
        /// </summary>
        /// <param name="bufferManager">Buffer manager passed to every socket pair.</param>
        /// <param name="endPoints">Endpoints whose ports are used; only <c>Port</c> is read.</param>
        /// <param name="maxConnections">Passed through to every socket pair.</param>
        /// <param name="clientLookup">Primary lookup used to map client IDs in custom messages to sockets.</param>
        public void Start(BufferManager bufferManager, IPEndPoint[] endPoints, int maxConnections, ClientLookupFunc clientLookup)
        {
            _clientLookup = clientLookup;

            _sockets = new RTCSocketPair[endPoints.Length];
            for (var i = 0; i < endPoints.Length; ++i)
            {
                var port = endPoints[i].Port;
                _sockets[i] = new RTCSocketPair(port, port + 1, bufferManager, maxConnections, Receive);
            }

            var ipHostInfo = Dns.GetHostEntry(Dns.GetHostName());
            var ipAddress = ipHostInfo.AddressList.FirstOrDefault(addr => addr.AddressFamily == AddressFamily.InterNetwork);
            Address = ipAddress;
        }

        /// <summary>
        /// Shuts down every socket pair created by <see cref="Start"/>. Does nothing if
        /// <see cref="Start"/> has not been called.
        /// </summary>
        public void ShutdownAll()
        {
            // Without this guard, calling ShutdownAll before Start throws a NullReferenceException.
            if (_sockets == null)
                return;

            foreach (var pair in _sockets)
            {
                pair.Shutdown();
            }
        }

        /// <summary>
        /// Receive callback for every socket pair. Dispatches one datagram.
        /// </summary>
        /// <param name="sockets">The socket pair the datagram arrived on.</param>
        /// <param name="msg">The received datagram; empty datagrams are ignored.</param>
        /// <param name="rtcp">Flag supplied by the socket pair and forwarded to STUN replies and Handshake
        /// echoes (presumably true when received on the RTCP socket — confirm against RTCSocketPair).</param>
        private void Receive(RTCSocketPair sockets, SocketBuffer msg, bool rtcp)
        {
            if (msg.Count <= 0)
                return;

            try
            {
                // Check for TransmitRequest packets
                if (MessageFilter(sockets, msg, rtcp))
                    return;
            }
            catch (Exception ex)
            {
                Logger.Warn(ex, "Error in MessageFilter");
                return;
            }

            // Demux various packet types (per RFC 5764 section 5.1.2)
            var b0 = msg[0];
            if (b0 < 2)
            {
                // STUN
                if (!StunMessage.IsValid(msg.Bytes))
                {
                    Logger.Debug("Dropping invalid STUN message", msg.Sender);
                    return;
                }

                // Ignore everything except binding requests for now
                var type = (StunMessageType)msg.Bytes.ToUInt16BE(0);
                if (type != StunMessageType.BindingRequest)
                {
#if CONFIG_STAGING || CONFIG_DEBUG
                    if (type != StunMessageType.BindingResponse)
                    {
                        Logger.Debug("Dropping STUN packet: " + type);
                    }
#endif
                    return;
                }

                // For now, we really only care about responding to binding requests
                var conn = StunMessage.Process(msg.Bytes);
                if (conn == null)
                {
                    Logger.Debug("Dropping STUN with no associated peer connection", msg.Sender);
                    return;
                }
                try
                {
                    // NOTE(review): if OnHandshake throws (e.g. the owner's ServerSocket is null),
                    // no STUN reply is sent; the failure is only logged below.
                    OnHandshake(((VoiceSession)conn.Owner).ServerSocket);
                    conn.SendStunReply(sockets, msg, rtcp);
                }
                catch (Exception ex)
                {
                    Logger.Warn(ex, "Error processing STUN message", conn.GetErrorDetails());
                }
            }
            else if (19 < b0 && b0 < 64)
            {
                // DTLS
                var conn = RTCPeerConnection.Lookup(msg.Sender);
                if (conn == null)
                {
                    Logger.Debug("Dropping DTLS with no associated peer connection", msg.Sender);
                    return;
                }
                try
                {
                    conn.ProcessDtls(sockets, msg);
                }
                catch (Exception ex)
                {
                    Logger.Warn(ex, "Error in ProcessDtls", conn.GetErrorDetails());
                }
            }
            else if (127 < b0 && b0 < 192)
            {
                // RTP/RTCP
                var conn = RTCPeerConnection.Lookup(msg.Sender);
                if (conn == null)
                {
                    Logger.Debug("Dropping RTP with no associated peer connection", msg.Sender);
                    return;
                }
                try
                {
                    conn.ProcessRtp(sockets, msg);

                    // Repeated binding requests may not be sent during transmission, so update handshake timestamp here as well
                    OnHandshake(((VoiceSession)conn.Owner).ServerSocket);
                }
                catch (Exception ex)
                {
                    Logger.Warn(ex, "Error in ProcessRtp", conn.GetErrorDetails());
                }
            }
            else
            {
                // Unknown
#if CONFIG_STAGING || CONFIG_DEBUG
                Logger.Debug("Dropping unknown packet", new { msg.Sender, msg.Count, FirstByte = b0 });
#endif
            }
        }

        /// <summary>
        /// Recognizes and handles the custom TransmitRequest and Handshake messages that share
        /// the UDP sockets with WebRTC traffic.
        /// Header fields read here: client ID (int32 at 0), body size (int32 at 4),
        /// message type (int32 at 8) and serialized flag (at 12).
        /// </summary>
        /// <returns>
        /// true if the datagram was recognized as one of these messages and consumed (also when the
        /// client ID is unknown); false to let <see cref="Receive"/> demultiplex it as STUN/DTLS/RTP.
        /// </returns>
        private bool MessageFilter(RTCSocketPair sockets, SocketBuffer msg, bool rtcp)
        {
            // Check for TransmitRequest packets
            if (msg.Count < MessageHeader.HeaderSize)
                return false;

            // Validated the message size and encryption status
            var bodySize = msg.ToInt32LE(4);

            if (bodySize < 0 || MessageHeader.HeaderSize + bodySize != msg.Count)
                return false;

            // Check the message type
            var messageType = msg.ToInt32LE(8);
            var isSerialized = msg.ToBooleanLE(12);

            if (messageType == TransmitRequest.MessageType && isSerialized == TransmitRequest.IsSerialized)
            {
                // The fixed body fields below are read before the full-length check, so make sure
                // they are inside the datagram; otherwise the reads run past msg.Count.
                if (bodySize < TransmitRequestMinBodySize)
                    return false;

                // See if this looks like a valid TransmitRequest
                var length = msg.ToInt32LE(MessageHeader.HeaderSize + 5);

                if (length < 0 || MessageHeader.HeaderSize + TransmitRequest.FixedSize + length != msg.Count)
                    return false;

                // Process message
                var clientID = msg.ToInt32LE(0);

                var client = _clientLookup(clientID) ?? VoiceWebSocketService.Lookup(clientID);
                if (client == null)
                {
#if CONFIG_STAGING || CONFIG_DEBUG
                    Logger.Debug("TransmitRequest received for unknown client: " + clientID);
#endif
                }
                else
                {
                    var loopbackEnabled = msg.ToBooleanLE(MessageHeader.HeaderSize);

                    var session = (VoiceSession)client.Session;
                    session.LoopbackEnabled = loopbackEnabled;
                    session.UpdateSockets(sockets, msg.Sender);

                    OnTransmitRequest(session, msg.Bytes);
                }
                return true;
            }
            
            if (messageType == Handshake.MessageType && isSerialized == Handshake.IsSerialized && bodySize == 2)
            {
                // Process message
                var clientID = msg.ToInt32LE(0);

                var client = _clientLookup(clientID) ?? VoiceWebSocketService.Lookup(clientID);
                if (client == null)
                {
#if CONFIG_STAGING || CONFIG_DEBUG
                    Logger.Debug("Handshake received for unknown client: " + clientID);
#endif
                }
                else
                {
                    OnHandshake(client);

                    var session = (VoiceSession)client.Session;
                    session.UpdateSockets(sockets, msg.Sender);

                    // Echo a copy of the handshake back to the sender. The copy is needed because
                    // msg's buffer belongs to the receive path.
                    session.UdpChannel.Sockets = sockets;
                    var buf = new byte[msg.Count];
                    msg.BlockCopy(0, buf, 0, msg.Count);
                    session.UdpChannel.StartSend(msg.Sender, new ByteBuffer(buf, 0, buf.Length), rtcp);
                }
                return true;
            }

            return false;
        }

        /// <summary>
        /// Marks <paramref name="client"/> as handshaken and stamps <c>DateLastHandshake</c> with
        /// <see cref="Environment.TickCount"/>. On the first handshake, raises
        /// <c>ClientConnected</c> on the session's parent.
        /// </summary>
        /// <param name="client">The client socket; must not be null (it is dereferenced unconditionally).</param>
        private static void OnHandshake(ISocketInterface client)
        {
            var isFirstHandshake = !client.IsHandshaken;

            client.IsHandshaken = true;
            client.DateLastHandshake = Environment.TickCount;

            if (isFirstHandshake) // If this is the first handshake, we must raise the client connected event
            {
                var session = (VoiceSession)client.Session;
                session.Parent.ClientConnected(session);
            }
        }

        /// <summary>
        /// Converts a validated TransmitRequest into an RTP packet in place and passes it to
        /// <paramref name="stream"/>.
        /// </summary>
        /// <param name="stream">Receiving stream; supplies <c>SamplesPerPacket</c> and <c>SSRC</c>.</param>
        /// <param name="buf">The full TransmitRequest datagram (header + body). Its bytes are overwritten.</param>
        private void OnTransmitRequest(IRTCStream stream, ByteBuffer buf)
        {
            // Read the sequence number before the header rewrite below can overwrite it.
            var seqNum = buf.ToUInt32LE(MessageHeader.HeaderSize + 1);

            // The payload starts at HeaderSize + FixedSize. Step back 12 bytes so an RTP header
            // can be written directly in front of it.
            const int offset = MessageHeader.HeaderSize + TransmitRequest.FixedSize - 12; // RTP header size
            buf = new ByteBuffer(buf.Buffer, buf.Offset + offset, buf.Count - offset);

            // Add a minimal, but valid, RTP header
            buf[0] = 0x80; // Version = 2, Padding = False, Extension = False, #CSID = 0
            buf[1] = OpusAudioPayloadType; // Marker = False, Payload Type = 111
            buf.SetBytesBE((ushort)(seqNum & 0xFFFF), 2); // Sequence Number
            buf.SetBytesBE((uint)(seqNum * stream.SamplesPerPacket), 4); // Timestamp (sample count)
            buf.SetBytesBE((uint)stream.SSRC, 8); // SSRC

            stream.ProcessRtp(buf);
        }
    }

    /// <summary>
    /// Diagnostic helpers for <see cref="RTCPeerConnection"/>, used to attach context to log entries.
    /// </summary>
    static class RTCPeerConnectionEx
    {
        /// <summary>
        /// Builds an anonymous object that describes the connection and, when available, its owning
        /// <see cref="VoiceSession"/> and that session's server socket.
        /// </summary>
        /// <param name="conn">The connection to describe; may be null.</param>
        /// <returns>
        /// null when <paramref name="conn"/> is null. Otherwise one of three snapshots:
        /// connection fields only (owner is not a <see cref="VoiceSession"/>); connection plus
        /// session fields (session has no server socket); or connection, session and socket fields.
        /// </returns>
        public static object GetErrorDetails(this RTCPeerConnection conn)
        {
            if (conn == null)
                return null;

            var session = conn.Owner as VoiceSession;
            if (session == null)
            {
                return new
                {
                    conn.EndPoint,
                    conn.IsConnected,
                    conn.IsDisposed,
                };
            }

            // Read ServerSocket once and use that instance for both the null check and the projection.
            // Re-reading the property after the check could return null if it changes in between
            // (e.g. on another thread), and this error-reporting path would then throw.
            var socket = session.ServerSocket;
            if (socket == null)
            {
                return new
                {
                    conn.EndPoint,
                    conn.IsConnected,
                    conn.IsDisposed,
                    conn.Owner.SSRC,
                    conn.Owner.MSID,

                    SessionID = session.ID,
                    SessionIsConnected = session.IsConnected,
                    SessionIsPendingDisconnect = session.IsPendingDisconnect,
                    SessionUseWebRTC = session.UseWebRTC,
                    SessionUseSdp = session.UseSdp,
                };
            }

            return new
            {
                conn.EndPoint,
                conn.IsConnected,
                conn.IsDisposed,
                conn.Owner.SSRC,
                conn.Owner.MSID,

                SessionID = session.ID,
                SessionIsConnected = session.IsConnected,
                SessionIsPendingDisconnect = session.IsPendingDisconnect,
                SessionUseWebRTC = session.UseWebRTC,
                SessionUseSdp = session.UseSdp,

                socket.ClientID,
                socket.DateLastHandshake,
                socket.IsAuthenticated,
                socket.IsHandshaken,
                socket.RemoteAddress,
            };
        }
    }
}
