﻿using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using Curse.Logging;
using Curse.SocketInterface;
using Curse.Voice.Contracts;
using Curse.WebRTC.Stun;

namespace Curse.WebRTC
{
    /// <summary>
    /// One WebRTC peer connection as seen by this (ICE-lite) server. It owns the local/remote ICE
    /// credentials, the DTLS handshake and the SRTP contexts for a single remote peer, and registers itself
    /// in static registries so incoming packets can be routed to it by ICE username or by source endpoint.
    /// </summary>
    public class RTCPeerConnection : IDisposable
    {
        private static readonly LogCategory Logger = new LogCategory("RTCPeerConnection")
            { Throttle = TimeSpan.FromSeconds(30) };

        /// <summary>
        /// Initializes the DTLS and SRTP layers. Presumably required once per process before any
        /// connection is used — confirm against the Dtls/Srtp implementations.
        /// </summary>
        public static void Init()
        {
            Dtls.Init();
            Srtp.Init();
        }

        #region Connection Registry

        // Keyed by "<localUFrag>:<remoteUFrag>" (see UpdateRemoteUFrag).
        private static readonly ConcurrentDictionary<string, RTCPeerConnection> _connectionsByUsername =
            new ConcurrentDictionary<string, RTCPeerConnection>();

        // Every remote endpoint a live connection has accepted (see UpdateEndPoint).
        private static readonly ConcurrentDictionary<IPEndPoint, RTCPeerConnection> _connectionsByEndPoint =
            new ConcurrentDictionary<IPEndPoint, RTCPeerConnection>();

        /// <summary>
        /// Finds the connection registered under the given ICE username.
        /// </summary>
        /// <param name="username">Key of the form <c>localUFrag:remoteUFrag</c>.</param>
        /// <returns>The matching connection, or <c>null</c> if there is none or <paramref name="username"/> is null.</returns>
        public static RTCPeerConnection Lookup(string username)
        {
            // ConcurrentDictionary throws on null keys; treat a missing username as "no match".
            if (username == null)
                return null;

            RTCPeerConnection conn;
            _connectionsByUsername.TryGetValue(username, out conn);
            return conn;
        }

        /// <summary>Snapshot of every connection currently registered by ICE username.</summary>
        public static RTCPeerConnection[] All
        {
            get { return _connectionsByUsername.Values.ToArray(); }
        }

        /// <summary>Number of connections currently registered by ICE username.</summary>
        public static int Count
        {
            get { return _connectionsByUsername.Count; }
        }

        /// <summary>
        /// Finds the connection that has registered the given remote endpoint.
        /// </summary>
        /// <param name="endPoint">Remote address a packet arrived from.</param>
        /// <returns>The matching connection, or <c>null</c> if there is none or <paramref name="endPoint"/> is null.</returns>
        public static RTCPeerConnection Lookup(IPEndPoint endPoint)
        {
            // ConcurrentDictionary throws on null keys; treat a missing endpoint as "no match".
            if (endPoint == null)
                return null;

            RTCPeerConnection conn;
            _connectionsByEndPoint.TryGetValue(endPoint, out conn);
            return conn;
        }

        private static readonly byte[] Empty = new byte[0];

        /// <summary>
        /// Re-keys this connection in the username registry for a new remote ICE ufrag.
        /// Passing <c>null</c> just unregisters the current one.
        /// </summary>
        /// <param name="ufrag">New remote ufrag, or <c>null</c> to unregister.</param>
        /// <returns>
        /// <c>true</c> if <paramref name="ufrag"/> is now the registered value; <c>false</c> on any conflict,
        /// in which case this connection is left unregistered with a null remote ufrag.
        /// </returns>
        private bool UpdateRemoteUFrag(string ufrag)
        {
            // If it's the same as the current value, do nothing
            if (_remoteIceUFrag == ufrag)
                return true;

            // Only assigned once registration has fully succeeded; every failure path leaves it null,
            // and the finally block commits it to _remoteIceUFrag / RemoteUsername.
            string registered = null;
            try
            {
                // Unregister the existing value
                var prev = _remoteIceUFrag;
                if (prev != null)
                {
                    var prevKey = _localIceUFrag + ":" + prev;

                    RTCPeerConnection conn;
                    if (_connectionsByUsername.TryRemove(prevKey, out conn) && conn != this)
                    {
                        Logger.Warn("Another connection was registered for our ufrag (this shouldn't happen)",
                            new { Key = prevKey });

                        // Put the other connection's entry back; if that fails, it has lost its registration too.
                        if (!_connectionsByUsername.TryAdd(prevKey, conn))
                        {
                            Logger.Warn("Failed to re-add conflicting connection");
                            conn._remoteIceUFrag = null;
                            conn.RemoteUsername = Empty;
                        }
                        return false;
                    }
                }

                // Register the new value
                if (ufrag != null)
                {
                    var ufragKey = _localIceUFrag + ":" + ufrag;
                    if (!_connectionsByUsername.TryAdd(ufragKey, this))
                    {
                        Logger.Warn("Another connection is already using this ufrag", new { Key = ufragKey });
                        return false;
                    }
                }

                registered = ufrag;
                return true;
            }
            finally
            {
                _remoteIceUFrag = registered;
                RemoteUsername = registered != null
                    ? Encoding.UTF8.GetBytes(registered + ":" + _localIceUFrag)
                    : Empty;
            }
        }

        /// <summary>
        /// Records <paramref name="endPoint"/> as a remote address of this connection, registering it in
        /// the endpoint registry, and optionally makes it the default send target (<see cref="EndPoint"/>).
        /// Passing <c>null</c> unregisters every endpoint and clears <see cref="EndPoint"/>.
        /// </summary>
        /// <param name="endPoint">Remote address; a copy is stored, not the caller's instance.</param>
        /// <param name="makeDefault">
        /// Whether to make <paramref name="endPoint"/> the new default. The first endpoint added always
        /// becomes the default regardless.
        /// </param>
        /// <returns>
        /// <c>false</c> if the endpoint is registered to another connection, an unregistration conflict
        /// occurred, or an exception was caught and logged; otherwise <c>true</c>.
        /// </returns>
        public bool UpdateEndPoint(IPEndPoint endPoint, bool makeDefault = false)
        {
            try
            {
                // Bail quickly in the common case. A null argument always takes the slow path, because
                // _endPoints may still hold entries even when EndPoint has been cleared.
                if (endPoint != null && Equals(EndPoint, endPoint))
                    return true;

                lock (_endPoints)
                {
                    // This means we should un-register all endpoints
                    if (endPoint == null)
                        return RemoveAllEndPoints();

                    // Always make the first end point the default
                    if (EndPoint == null)
                        makeDefault = true;

                    // Check if there's anything we actually need to do
                    var shouldAdd = !_endPoints.Contains(endPoint);
                    if (!makeDefault && !shouldAdd)
                        return true;

                    // Make a copy so later mutation of the caller's instance can't change a stored key
                    endPoint = new IPEndPoint(endPoint.Address, endPoint.Port);

                    // Add to our list of allowed endpoints if needed
                    if (shouldAdd && _endPoints.Add(endPoint))
                    {
                        // Register the new value
                        if (!_connectionsByEndPoint.TryAdd(endPoint, this))
                        {
                            Logger.Warn("Another connection is already using this IPEndPoint " + endPoint);
                            _endPoints.Remove(endPoint);
                            return false;
                        }
                    }

                    // Update the default route
                    if (makeDefault)
                        EndPoint = endPoint;

                    return true;
                }
            }
            catch (Exception ex)
            {
                Logger.Error(ex, "Failed to update end point " + endPoint);
                return false;
            }
        }

        /// <summary>
        /// Unregisters every endpoint of this connection and clears <see cref="EndPoint"/>.
        /// Must be called with the <c>_endPoints</c> lock held (its only caller is UpdateEndPoint).
        /// </summary>
        /// <returns><c>false</c> if any endpoint could not be cleanly unregistered.</returns>
        private bool RemoveAllEndPoints()
        {
            var success = true;
            foreach (var ep in _endPoints)
            {
                try
                {
                    RTCPeerConnection conn;
                    if (_connectionsByEndPoint.TryRemove(ep, out conn) && conn != this)
                    {
                        // This shouldn't happen
                        Logger.Error("Another connection was registered for our IPEndPoint " + ep);

                        // Put the other connection's entry back; if that fails, strip its default route.
                        if (!_connectionsByEndPoint.TryAdd(ep, conn))
                        {
                            // This really shouldn't happen
                            Logger.Error("Failed to re-add conflicting connection");
                            if (Equals(conn.EndPoint, ep))
                            {
                                conn.EndPoint = null;
                            }
                            success = false;
                        }
                    }
                }
                catch (Exception ex)
                {
                    Logger.Error(ex, "Failed to un-register IPEndPoint " + ep);
                    success = false;
                }
            }

            _endPoints.Clear();
            EndPoint = null;
            return success;
        }

        #endregion Connection Registry

        // Source of per-process unique SDP session ids.
        private static long _sessionCount;

        // Guards connection state (disposal, ICE credentials, HMACs, DTLS/SRTP contexts).
        private readonly object _syncRoot = new object();
        private readonly IRTCStream _owner;
        private readonly MultiStreamType _type;
        private readonly RTCChannelProcessor _processor;

        private readonly long _sessionId;
        private int _sessionVersion = 1;

        // Only generated when UseSdp is true; null otherwise.
        private readonly string _localIceUFrag;
        private readonly string _localIcePwd;

        private string _remoteIceUFrag;
        private string _remoteIcePwd;

        private HMACSHA1 _localHmac;
        private HMACSHA1 _remoteHmac;

        private Dtls _dtls;
        private Srtp _srtpIn;
        private Srtp _srtpOut;

        // Every remote endpoint this connection has registered; also serves as the lock for endpoint updates.
        private readonly HashSet<IPEndPoint> _endPoints = new HashSet<IPEndPoint>();

        /// <summary>
        /// UTF-8 bytes of <c>remoteUFrag:localUFrag</c> once a remote ufrag is registered; empty otherwise.
        /// </summary>
        public byte[] RemoteUsername { get; private set; }

        public uint AudioSourceSSRC { get; set; }

        /// <summary>Media source SSRC written into outgoing PLI and forwarded NACK packets.</summary>
        public uint VideoSourceSSRC { get; set; }

        /// <summary>
        /// Creates a connection. For SDP-based types this also allocates a session id and a random
        /// ICE ufrag/pwd; <see cref="MultiStreamType.NullCipher"/> connections skip that and are
        /// considered connected immediately (no DTLS).
        /// </summary>
        public RTCPeerConnection(IRTCStream owner, MultiStreamType type, RTCChannelProcessor processor)
        {
            _owner = owner;
            _type = type;
            _processor = processor;

            RemoteUsername = Empty;

            IsConnected = _type == MultiStreamType.NullCipher; // No DTLS required

            // The rest is only needed for SDP generation
            if (!UseSdp)
            {
                return;
            }

            _sessionId = Interlocked.Increment(ref _sessionCount) & 0x7FFFFFFFFFFFFFFFL; // non-negative, unique to this machine

            using (var rng = new RNGCryptoServiceProvider())
            {
                var bytes = new byte[30];
                rng.GetBytes(bytes);

                // 12 and 18 bytes are multiples of 3, so the Base64 output has no '=' padding
                // (16-char ufrag, 24-char pwd).
                _localIceUFrag = Convert.ToBase64String(bytes, 0, 12);
                _localIcePwd = Convert.ToBase64String(bytes, 12, 18);
            }
        }

        /// <summary>True for every connection type except <see cref="MultiStreamType.NullCipher"/>.</summary>
        public bool UseSdp { get { return _type != MultiStreamType.NullCipher; } }

        public IRTCStream Owner { get { return _owner; } }

        /// <summary>Default remote address used for sending; null until an endpoint is registered.</summary>
        public IPEndPoint EndPoint { get; private set; }

        /// <summary>
        /// True once DTLS has completed and SRTP contexts exist (immediately for NullCipher);
        /// false after <see cref="Close"/>.
        /// </summary>
        public bool IsConnected { get; private set; }

        public bool IsDisposed { get; private set; }

        /// <summary>
        /// Unregisters the connection from both registries and releases its DTLS, SRTP and HMAC
        /// resources. A second call only logs a warning.
        /// </summary>
        public void Close()
        {
            lock (_syncRoot)
            {
                if (IsDisposed)
                {
                    Logger.Warn("Close called when connection was already closed");
                    return;
                }

                IsDisposed = true;
                IsConnected = false;

                // Unregister
                UpdateRemoteUFrag(null);
                UpdateEndPoint(null);

                // Cleanup unmanaged resources
                if (_srtpIn != null)
                {
                    _srtpIn.Dispose();
                    _srtpIn = null;
                }
                if (_srtpOut != null)
                {
                    _srtpOut.Dispose();
                    _srtpOut = null;
                }
                if (_dtls != null)
                {
                    _dtls.Dispose();
                    _dtls = null;
                }

                if (_localHmac != null)
                {
                    _localHmac.Dispose();
                    _localHmac = null;
                }
                if (_remoteHmac != null)
                {
                    _remoteHmac.Dispose();
                    _remoteHmac = null;
                }
            }
        }

        void IDisposable.Dispose()
        {
            Close();
        }

        /// <summary>
        /// Builds an SDP offer advertising the given sockets as host candidates and the given streams.
        /// </summary>
        /// <param name="sockets">Candidate socket pairs; must be non-empty (the first supplies the m= line ports).</param>
        /// <param name="streams">Streams to advertise.</param>
        /// <returns>The SDP text, or an empty string when SDP is disabled for this connection.</returns>
        public string CreateOffer(IList<RTCSocketPair> sockets, IEnumerable<IRTCStream> streams)
        {
            if (!UseSdp)
            {
                Logger.Warn("Attempt to create offer when SDP is disabled");
                return "";
            }

            switch (_type)
            {
                case MultiStreamType.Unified:
                    return CreateUnifiedOffer(sockets, streams);

                case MultiStreamType.PlanB:
                    return CreatePlanBOffer(sockets, streams);
            }

            // Fallback to PlanB for now
            return CreatePlanBOffer(sockets, streams);
        }

        // Session-level SDP lines shared by both offer formats; each call bumps the session version.
        private void AppendSessionHeader(StringBuilder sb)
        {
            sb.Append("v=0\r\n");
            sb.AppendFormat("o=- {0} {1} IN IP4 0.0.0.0\r\n", _sessionId, Interlocked.Increment(ref _sessionVersion));
            sb.Append("s=-\r\n");
            sb.Append("t=0 0\r\n");
            sb.Append("a=ice-lite\r\n");
        }

        // "Unified" (aka Firefox) offer: one bundled audio m= section per stream.
        private string CreateUnifiedOffer(IList<RTCSocketPair> sockets, IEnumerable<IRTCStream> streams)
        {
            // Enumerate once, so the BUNDLE group and the m= sections always describe the same streams.
            var streamList = streams.ToList();

            var sb = new StringBuilder(2000);

            AppendSessionHeader(sb);
            sb.Append("a=group:BUNDLE");
            foreach (var stream in streamList)
            {
                sb.AppendFormat(" audio_u{0}", stream.SSRC);
            }
            sb.Append("\r\n");
            sb.Append("a=msid-semantic: WMS *\r\n");

            // The media section template is identical for every stream, so build it once.
            var media = new StringBuilder(2000);

            var first = sockets[0];

            // Connection info for the audio stream(s)
            media.AppendFormat("m=audio {0} UDP/TLS/RTP/SAVPF 111\r\n", first.RtpPort);
            media.Append("c=IN IP4 0.0.0.0\r\n");
            media.AppendFormat("a=rtcp:{0} IN IP4 0.0.0.0\r\n", first.RtcpPort);

            var i = 0;
            foreach (var socket in sockets)
            {
                ++i;
                media.AppendFormat("a=candidate:{0} 1 UDP 2122260223 0.0.0.0 {1} typ host\r\n", i, socket.RtpPort);
            }

            // Generic RTP audio settings
            media.Append("a=setup:actpass\r\n");
            media.Append("a=extmap:1 urn:ietf:params:rtp-hdrext:ssrc-audio-level\r\n");
            media.Append("a=extmap:3 http://www.webrtc.org/experiments/rtp-hdrext/abs-send-time\r\n");
            media.Append("a=rtcp-mux\r\n");
            media.Append("a=rtpmap:111 opus/48000/2\r\n");
            media.Append("a=fmtp:111 minptime=40; useinbandfec=1\r\n");
            media.Append("a=ptime:40\r\n");
            media.Append("a=maxptime:40\r\n");
            media.Append("a=sendrecv\r\n");

            var mediaStr = media.ToString();

            foreach (var stream in streamList)
            {
                sb.Append(mediaStr);

                // Authentication info
                sb.AppendFormat("a=ice-ufrag:{0}\r\n", _localIceUFrag);
                sb.AppendFormat("a=ice-pwd:{0}\r\n", _localIcePwd);
                sb.AppendFormat("a=fingerprint:{0}\r\n", Dtls.Fingerprint);

                sb.AppendFormat("a=mid:audio_u{0}\r\n", stream.SSRC);
                sb.AppendFormat("a=msid:audio u{0}\r\n", stream.SSRC);
                sb.AppendFormat("a=ssrc:{0} cname:curse.com\r\n", stream.SSRC);
            }

            return sb.ToString();
        }

        // "Plan B" (aka Chrome) offer: a single audio m= section listing every other stream's SSRC.
        private string CreatePlanBOffer(IList<RTCSocketPair> sockets, IEnumerable<IRTCStream> streams)
        {
            var first = sockets[0];

            var sb = new StringBuilder(2000);

            AppendSessionHeader(sb);
            sb.Append("a=group:BUNDLE audio\r\n");
            sb.Append("a=msid-semantic: WMS *\r\n");

            // Authentication info
            sb.AppendFormat("a=ice-ufrag:{0}\r\n", _localIceUFrag);
            sb.AppendFormat("a=ice-pwd:{0}\r\n", _localIcePwd);
            sb.AppendFormat("a=fingerprint:{0}\r\n", Dtls.Fingerprint);

            // Connection info for the audio stream(s)
            sb.AppendFormat("m=audio {0} UDP/TLS/RTP/SAVPF 111\r\n", first.RtpPort);
            sb.Append("c=IN IP4 0.0.0.0\r\n");
            sb.AppendFormat("a=rtcp:{0} IN IP4 0.0.0.0\r\n", first.RtcpPort);

            var i = 0;
            foreach (var socket in sockets)
            {
                ++i;
                sb.AppendFormat("a=candidate:{0} 1 udp 2122260223 0.0.0.0 {1} typ host generation 0\r\n", i, socket.RtpPort);
                sb.AppendFormat("a=candidate:{0} 2 udp 2122260222 0.0.0.0 {1} typ host generation 0\r\n", i, socket.RtcpPort);
            }

            // Generic RTP audio settings
            sb.Append("a=setup:actpass\r\n");
            sb.Append("a=mid:audio\r\n");
            sb.Append("a=extmap:1 urn:ietf:params:rtp-hdrext:ssrc-audio-level\r\n");
            sb.Append("a=extmap:3 http://www.webrtc.org/experiments/rtp-hdrext/abs-send-time\r\n");
            sb.Append("a=rtcp-mux\r\n");
            sb.Append("a=rtpmap:111 opus/48000/2\r\n");
            sb.Append("a=fmtp:111 minptime=40; useinbandfec=1\r\n");
            sb.Append("a=ptime:40\r\n");
            sb.Append("a=maxptime:40\r\n");
            sb.Append("a=sendrecv\r\n");

            // Advertise every stream except the owner's own
            foreach (var stream in streams)
            {
                if (stream.SSRC == Owner.SSRC)
                    continue;

                sb.AppendFormat("a=ssrc:{0} cname:curse.com\r\n", stream.SSRC);
                sb.AppendFormat("a=ssrc:{0} msid:{1}\r\n", stream.SSRC, stream.MSID);
            }

            return sb.ToString();
        }

        /// <summary>
        /// Applies the remote SDP answer. It extracts the remote ICE ufrag/pwd, registers the connection
        /// under the resulting username and (re)creates the HMACs used for STUN message integrity.
        /// An answer missing either credential is logged and ignored.
        /// </summary>
        public void SetAnswer(string answer)
        {
            if (!UseSdp)
            {
                Logger.Warn("Attempt to set answer when SDP is disabled");
                return;
            }

            var ufrag = GetSdpValue(answer, "ice-ufrag:");
            var pwd = GetSdpValue(answer, "ice-pwd:");
            if (ufrag == null || pwd == null)
            {
                Logger.Warn("Answer is missing ice-ufrag or ice-pwd");
                return;
            }

            lock (_syncRoot)
            {
                if (IsDisposed)
                {
                    Logger.Warn("Attempt to set answer when connection is closed");
                    return;
                }

                if (!UpdateRemoteUFrag(ufrag))
                    return;
                _remoteIcePwd = pwd;

                // Release HMACs left over from a previous answer before replacing them.
                if (_localHmac != null)
                    _localHmac.Dispose();
                if (_remoteHmac != null)
                    _remoteHmac.Dispose();

                _localHmac = new HMACSHA1(Encoding.ASCII.GetBytes(_localIcePwd));
                _remoteHmac = new HMACSHA1(Encoding.ASCII.GetBytes(_remoteIcePwd));
            }
        }

        private static readonly char[] SdpLineTerminators = { '\r', '\n' };

        /// <summary>
        /// Returns the text after the first occurrence of <paramref name="key"/>, up to the end of its line
        /// (CR or LF, or the end of the text).
        /// </summary>
        /// <returns>The value, or <c>null</c> if <paramref name="sdp"/> is null or does not contain the key.</returns>
        private static string GetSdpValue(string sdp, string key)
        {
            if (sdp == null)
                return null;

            // SDP is protocol text: match bytes exactly rather than with culture-aware rules.
            var keyIndex = sdp.IndexOf(key, StringComparison.Ordinal);
            if (keyIndex < 0)
                return null;

            var start = keyIndex + key.Length;
            var end = sdp.IndexOfAny(SdpLineTerminators, start);
            if (end < 0)
                end = sdp.Length;

            return sdp.Substring(start, end - start);
        }

        /// <summary>
        /// HMAC-SHA1 of <paramref name="length"/> bytes at <paramref name="offset"/> within
        /// <paramref name="buf"/>, keyed with the local ICE pwd.
        /// </summary>
        /// <returns>The hash, or <c>null</c> if the connection is closed or no answer has been set yet.</returns>
        public byte[] ComputeLocalHash(ByteBuffer buf, int offset, int length)
        {
            lock (_syncRoot)
            {
                if (IsDisposed || _localHmac == null)
                    return null;

                return _localHmac.ComputeHash(buf.Buffer, buf.Offset + offset, length);
            }
        }

        /// <summary>
        /// HMAC-SHA1 of <paramref name="length"/> bytes at <paramref name="offset"/> within
        /// <paramref name="buf"/>, keyed with the remote ICE pwd.
        /// </summary>
        /// <returns>The hash, or <c>null</c> if the connection is closed or no answer has been set yet.</returns>
        public byte[] ComputeRemoteHash(ByteBuffer buf, int offset, int length)
        {
            lock (_syncRoot)
            {
                if (IsDisposed || _remoteHmac == null)
                    return null;

                return _remoteHmac.ComputeHash(buf.Buffer, buf.Offset + offset, length);
            }
        }

        /// <summary>
        /// Answers a STUN binding request, then sends a binding request of our own. The sender
        /// is recorded as an endpoint first (it becomes the default only if none exists yet).
        /// </summary>
        public void SendStunReply(RTCSocketPair sockets, SocketBuffer request, bool rtcp)
        {
            lock (_syncRoot)
            {
                if (IsDisposed)
                    return;

                _processor.Sockets = sockets;
                UpdateEndPoint(request.Sender);

                // NOTE(review): both messages are sent to the default EndPoint, not request.Sender. If the
                // request arrived on a non-default candidate, the peer won't see a response on that path —
                // confirm this is intended.

                // Reply to the binding request
                var response = StunMessage.Reply(this, request);
                _processor.StartSend(EndPoint, response, rtcp);

                // And then send a binding request of our own
                var bindingRequest = StunMessage.Request(this, request.Sender);
                _processor.StartSend(EndPoint, bindingRequest, rtcp);
            }
        }

        /// <summary>
        /// Feeds a DTLS record into the handshake, creating the DTLS context on first use. Once the
        /// handshake completes, it derives the inbound/outbound SRTP contexts and marks the connection
        /// connected. (<paramref name="sockets"/> is not used here.)
        /// </summary>
        public void ProcessDtls(RTCSocketPair sockets, SocketBuffer msg)
        {
            if (_type == MultiStreamType.NullCipher)
            {
                Logger.Warn("Received DTLS message for unencrypted connection");
                return;
            }

            lock (_syncRoot)
            {
                if (IsDisposed)
                    return;

                if (_dtls == null)
                    _dtls = new Dtls();

                _dtls.HandleTraffic(_processor, msg);

                if (_dtls.IsConnected && !IsConnected)
                {
                    _dtls.CreateSrtp(out _srtpIn, out _srtpOut);

                    IsConnected = true;
                }
            }
        }

        /// <summary>
        /// Decrypts (unless NullCipher) an incoming RTP or RTCP packet and hands it to the owner stream.
        /// The sender of RTP packets becomes the default send endpoint.
        /// </summary>
        public void ProcessRtp(RTCSocketPair sockets, SocketBuffer msg)
        {
            if (msg.Count < 2)
                return;

            // RTP/RTCP demultiplexing on a shared port (RFC 5761): with the marker bit masked off,
            // RTCP packet types land in 64..95.
            var payloadType = msg[1] & 0x7F;
            var isRtcp = 63 < payloadType && payloadType < 96;

            var length = msg.Count;
            lock (_syncRoot)
            {
                if (!IsConnected || IsDisposed)
                    return;

                if (_type != MultiStreamType.NullCipher)
                {
                    if (isRtcp)
                        _srtpIn.UnprotectRtcp(msg.Buffer, msg.Offset, ref length);
                    else
                        _srtpIn.Unprotect(msg.Buffer, msg.Offset, ref length);
                }

                // Done under _syncRoot so a concurrent Close() cannot unregister our endpoints and then
                // have this call re-register one into the static registry after disposal.
                if (!isRtcp)
                    UpdateEndPoint(msg.Sender, true);
            }

            // Hand off outside the lock
            var buf = new ByteBuffer(msg.Buffer, msg.Offset, length);
            if (isRtcp)
                Owner.ProcessRtcp(buf);
            else
                Owner.ProcessRtp(buf);
        }

        /// <summary>Encrypts (unless NullCipher) and sends an RTP packet to the default endpoint.</summary>
        public void SendRtp(ByteBuffer data)
        {
            SendProtected(data, false);
        }

        /// <summary>Encrypts (unless NullCipher) and sends an RTCP packet to the default endpoint.</summary>
        public void SendRtcp(ByteBuffer data)
        {
            SendProtected(data, true);
        }

        // Copies the packet into a buffer with room for the SRTP auth tag, protects it and sends it.
        private void SendProtected(ByteBuffer data, bool rtcp)
        {
            lock (_syncRoot)
            {
                if (!IsConnected || IsDisposed)
                    return;

                var length = data.Count;
                var buf = new byte[length + Srtp.MaxChecksumSize];
                data.BlockCopy(0, buf, 0, length);

                if (_type != MultiStreamType.NullCipher)
                {
                    if (rtcp)
                        _srtpOut.ProtectRtcp(buf, 0, ref length);
                    else
                        _srtpOut.Protect(buf, 0, ref length);
                }

                _processor.StartSend(EndPoint, new ByteBuffer(buf, 0, length));
            }
        }

        /// <summary>
        /// Sends an RTCP Picture Loss Indication for <see cref="VideoSourceSSRC"/>.
        /// </summary>
        public void SendPli()
        {
            lock (_syncRoot)
            {
                if (!IsConnected || IsDisposed)
                    return;

                // Manually build the packet data
                var len = 12; // PLI = 12, FIR = 20
                var buf = new byte[len + Srtp.MaxChecksumSize];

                // RTCP Header
                buf[0] = 0x81; // V=2, P=0, FMT=1 (PLI)
                buf[1] = 206;  // Payload Specific feedback
                buf[2] = 0;
                buf[3] = 2;    // length in 32-bit words minus one (12 bytes)
                buf.SetBytesBE((uint)0, 4); // Sender doesn't matter
                buf.SetBytesBE((uint)VideoSourceSSRC, 8);

                if (_type != MultiStreamType.NullCipher)
                {
                    _srtpOut.ProtectRtcp(buf, 0, ref len);
                }

                _processor.StartSend(EndPoint, new ByteBuffer(buf, 0, len));
            }
        }

        /// <summary>
        /// Forwards an RTCP NACK to this peer, rewriting its sender SSRC to 0 and its media SSRC to
        /// <see cref="VideoSourceSSRC"/>. Packets shorter than an RTCP feedback header are dropped.
        /// </summary>
        public void ForwardNack(ByteBuffer nack)
        {
            // 4-byte RTCP header + sender SSRC + media SSRC. The SSRC rewrite below writes bytes 4..11,
            // so anything shorter would have those writes fall outside the packet's own data.
            const int MinFeedbackLength = 12;

            lock (_syncRoot)
            {
                if (!IsConnected || IsDisposed)
                    return;

                var len = nack.Count;
                if (len < MinFeedbackLength)
                {
                    Logger.Warn("Dropping NACK shorter than an RTCP feedback header");
                    return;
                }

                var buf = new byte[len + Srtp.MaxChecksumSize];
                nack.BlockCopy(0, buf, 0, len);

                buf.SetBytesBE((uint)0, 4); // Sender doesn't matter
                buf.SetBytesBE((uint)VideoSourceSSRC, 8);

                if (_type != MultiStreamType.NullCipher)
                {
                    _srtpOut.ProtectRtcp(buf, 0, ref len);
                }

                _processor.StartSend(EndPoint, new ByteBuffer(buf, 0, len));
            }
        }
    }
}
