﻿using System;
using System.Net;
using System.Text;
using Curse.Logging;
using Curse.SocketInterface;

namespace Curse.WebRTC.Stun
{
    /// <summary>
    /// Validation, authentication and construction of STUN messages (binding requests and
    /// binding success responses) exchanged during WebRTC connectivity checks.
    /// </summary>
    public static class StunMessage
    {
        private static readonly LogCategory Logger = new LogCategory("StunMessage")
            { Throttle = TimeSpan.FromSeconds(30) };

        /// <summary>Size in bytes of the fixed STUN message header.</summary>
        public const int HeaderSize = 20;

        /// <summary>Size in bytes of an attribute header (2-byte type followed by 2-byte length).</summary>
        public const int AttrHeaderSize = 4;

        private const uint MagicCookie = 0x2112A442;
        private const uint FingerprintXor = 0x5354554E;

        // The two most significant bits of the first header byte must be zero.
        private const int LeadingBitsMask = 0xC0;

        private const int MinMessageSize = HeaderSize + FingerprintSize;

        // Random material drawn for each outgoing request: the 12-byte transaction ID
        // (header bytes 8..19) followed by the 8-byte ICE-CONTROLLING tie-breaker.
        private const int TransactionIdLength = 12;
        private const int TieBreakerLength = 8;

        /// <summary>
        /// Performs structural checks on a datagram to decide whether it looks like a STUN message.
        /// </summary>
        /// <param name="buf">The received datagram.</param>
        /// <returns>
        /// <c>true</c> when the length is a multiple of 4, is at least header + FINGERPRINT, and matches
        /// the header's length field; the two leading bits are zero; the magic cookie is present; and the
        /// message ends with a 4-byte FINGERPRINT attribute whose value (XOR 0x5354554E) equals
        /// <c>Crc32.Compute</c> over all preceding bytes. Otherwise <c>false</c>.
        /// </returns>
        public static bool IsValid(ByteBuffer buf)
        {
            // Make sure the packet length seems reasonable and matches the embedded length.
            // The size checks come first so the length field is only read from a big-enough buffer.
            if ((buf.Count & 3) != 0 || buf.Count < MinMessageSize ||
                HeaderSize + buf.ToUInt16BE(2) != buf.Count)
            {
                return false;
            }

            // Make sure the first two bits (mask 0xC0) are 0 and the magic cookie (0x2112A442) is present
            if ((buf[0] & LeadingBitsMask) != 0 || buf.ToUInt32BE(4) != MagicCookie)
            {
                return false;
            }

            // Make sure the packet ends with a fingerprint attribute
            var offset = buf.Count - FingerprintSize;
            if (buf.ToUInt16BE(offset) != (ushort)StunAttributeType.Fingerprint ||
                buf.ToUInt16BE(offset + 2) != sizeof(uint))
            {
                return false;
            }

            // Check the fingerprint value
            var fingerprint = buf.ToUInt32BE(offset + 4) ^ FingerprintXor;
            var expected = Crc32.Compute(buf, 0, offset);
            return fingerprint == expected;
        }

        /// <summary>
        /// Authenticates an incoming STUN message and resolves the peer connection it belongs to.
        /// </summary>
        /// <param name="buf">
        /// The received message (expected to have passed <see cref="IsValid"/> already). Header bytes 2-3
        /// are temporarily rewritten while MESSAGE-INTEGRITY is hashed and restored afterwards.
        /// </param>
        /// <returns>
        /// The connection returned by <c>RTCPeerConnection.Lookup</c> for the first USERNAME attribute,
        /// provided a following MESSAGE-INTEGRITY attribute matches that connection's local hash;
        /// otherwise <c>null</c> (the reason is logged at debug level).
        /// </returns>
        public static RTCPeerConnection Process(ByteBuffer buf)
        {
            if (!HasValidAttributeFraming(buf))
            {
                Logger.Debug("Bad attribute framing detected");
                return null;
            }

            // Lookup the Connection object using the USERNAME field and then check MESSAGE_INTEGRITY (if present)
            var conn = default(RTCPeerConnection);
            var offset = HeaderSize;
            while (offset < buf.Count)
            {
                var attrType = (StunAttributeType)buf.ToUInt16BE(offset);
                var attrLen = buf.ToUInt16BE(offset + 2);

                if (attrType == StunAttributeType.Username && conn == null)
                {
                    var user = Encoding.UTF8.GetString(buf.Buffer, buf.Offset + offset + AttrHeaderSize, attrLen);

                    conn = RTCPeerConnection.Lookup(user);
                    if (conn == null)
                    {
                        Logger.Debug("No connection found", user);
                        return null;
                    }
                }
                else if (attrType == StunAttributeType.MessageIntegrity)
                {
                    // Anything after MESSAGE_INTEGRITY is ignored
                    return CheckMessageIntegrity(buf, offset, attrLen, conn) ? conn : null;
                }
                // We don't really care about other attributes for now

                offset += GetAttrSize(attrLen);
            }

            // If we get this far the message must not have had a MESSAGE_INTEGRITY attribute
            Logger.Debug("Invalid binding request (no MESSAGE_INTEGRITY)");
            return null;
        }

        // Returns true when the attributes after the header tile the message exactly,
        // i.e. no attribute is truncated or runs past the end of the buffer.
        private static bool HasValidAttributeFraming(ByteBuffer buf)
        {
            var offset = HeaderSize;
            while (offset + AttrHeaderSize <= buf.Count)
            {
                var attrLen = buf.ToUInt16BE(offset + 2);
                offset += GetAttrSize(attrLen);
            }
            return offset == buf.Count;
        }

        // Verifies the MESSAGE-INTEGRITY attribute whose header starts at `offset` against the
        // connection's local hash. Logs the reason and returns false on any failure.
        private static bool CheckMessageIntegrity(ByteBuffer buf, int offset, int attrLen, RTCPeerConnection conn)
        {
            if (conn == null)
            {
                Logger.Debug("Can't check MESSAGE_INTEGRITY without a connection");
                return false;
            }

            // The comparison below reads attrLen bytes; a short attribute could otherwise make it
            // read past the end of the message.
            if (attrLen != MessageIntegrityLength)
            {
                Logger.Debug("Bad MESSAGE_INTEGRITY length", attrLen);
                return false;
            }

            // MESSAGE_INTEGRITY is computed as if the packet ended after this attribute, so the header
            // length field is patched for hashing and restored afterwards to leave the caller's buffer intact.
            var originalLength = (ushort)buf.ToUInt16BE(2);
            byte[] expectedHash;
            buf.SetBytesBE((ushort)(offset + MessageIntegritySize - HeaderSize), 2);
            try
            {
                // Validate (hash up to the start of this attribute)
                expectedHash = conn.ComputeLocalHash(buf, 0, offset);
            }
            finally
            {
                buf.SetBytesBE(originalLength, 2);
            }

            if (expectedHash == null)
            {
                Logger.Debug("Can't check MESSAGE_INTEGRITY without local hash");
                return false;
            }
            if (expectedHash.Length != attrLen)
            {
                Logger.Debug("Unexpected local hash length", expectedHash.Length);
                return false;
            }

            // Accumulate differences over every byte instead of exiting early, so response timing
            // does not reveal how many leading bytes of a forged hash were correct.
            var valueOffset = offset + AttrHeaderSize;
            var diff = 0;
            for (var i = 0; i < expectedHash.Length; i++)
            {
                diff |= buf[valueOffset + i] ^ expectedHash[i];
            }
            if (diff != 0)
            {
                Logger.Debug("Wrong hash in MESSAGE_INTEGRITY");
                return false;
            }
            return true;
        }

        // Total on-wire size of an attribute: header plus value padded up to a multiple of 4 bytes.
        private static int GetAttrSize(int attrLen)
        {
            var paddedLen = (attrLen + 3) & ~3; // round up to nearest multiple of 4
            return AttrHeaderSize + paddedLen;
        }

        private const int ReplyMessageLength = HeaderSize
                                               + XorMappedAddressIpv4Size
                                               + MessageIntegritySize
                                               + FingerprintSize;

        /// <summary>
        /// Builds a binding success response to <paramref name="request"/>.
        /// </summary>
        /// <param name="conn">Connection whose <c>ComputeLocalHash</c> signs the MESSAGE-INTEGRITY attribute.</param>
        /// <param name="request">
        /// The received request; its magic cookie and transaction ID (bytes 4..19) are echoed and its
        /// <c>Sender</c> is reported in XOR-MAPPED-ADDRESS.
        /// </param>
        /// <returns>
        /// A buffer holding header, XOR-MAPPED-ADDRESS (IPv4), MESSAGE-INTEGRITY and FINGERPRINT.
        /// </returns>
        public static ByteBuffer Reply(RTCPeerConnection conn, SocketBuffer request)
        {
            // Compose a response
            var buf = new ByteBuffer(new byte[ReplyMessageLength], 0, ReplyMessageLength);

            // StunHeader
            buf.SetBytesBE((ushort)StunMessageType.BindingResponse, 0);
            // 2-byte length (filled out by the attribute builders below)
            request.BlockCopy(4, buf, 4, HeaderSize - 4); // MagicCookie and TransactionID should remain the same
            var offset = HeaderSize;

            // Attributes
            offset = XorMappedAddressIpv4(ref buf, offset, request.Sender);
            offset = MessageIntegrity(ref buf, offset, conn.ComputeLocalHash);
            offset = Fingerprint(ref buf, offset);

#if DEBUG
            buf = new ByteBuffer(buf.Buffer, buf.Offset, offset);
            if (!IsValid(buf))
            {
                throw new InvalidOperationException("Invalid STUN response");
            }
#endif
            return buf;
        }

        private const int RequestMessageLength = HeaderSize
                                                 //+ Username (variable length)
                                                 + IceControllingSize
                                                 + UseCandidateSize
                                                 + PrioritySize
                                                 + MessageIntegritySize
                                                 + FingerprintSize;

        /// <summary>
        /// Builds a binding request carrying USERNAME, ICE-CONTROLLING, USE-CANDIDATE, PRIORITY,
        /// MESSAGE-INTEGRITY and FINGERPRINT attributes.
        /// </summary>
        /// <param name="conn">
        /// Connection supplying <c>RemoteUsername</c> and the <c>ComputeRemoteHash</c> used for MESSAGE-INTEGRITY.
        /// </param>
        /// <param name="destination">Not used by the current implementation.</param>
        /// <returns>The serialized request with a freshly randomized transaction ID and tie-breaker.</returns>
        public static ByteBuffer Request(RTCPeerConnection conn, IPEndPoint destination)
        {
            var username = conn.RemoteUsername;
            var length = RequestMessageLength + GetAttrSize(username.Length);

            // RFC 5389 requires a uniformly random transaction ID, so draw it (and the tie-breaker)
            // from a cryptographic RNG.
            var random = new byte[TransactionIdLength + TieBreakerLength];
            using (var rng = global::System.Security.Cryptography.RandomNumberGenerator.Create())
            {
                rng.GetBytes(random);
            }
            // Reading all 8 bytes at once avoids sign-extending a 32-bit half into the high word.
            var tieBreaker = BitConverter.ToInt64(random, TransactionIdLength);

            // Compose a request
            var buf = new ByteBuffer(new byte[length], 0, length);

            // StunHeader
            buf.SetBytesBE((ushort)StunMessageType.BindingRequest, 0);
            // 2-byte length (filled out by the attribute builders below)
            buf.SetBytesBE(MagicCookie, 4);
            random.BlockCopy(0, buf, 8, TransactionIdLength);
            var offset = HeaderSize;

            // Attributes
            offset = Username(ref buf, offset, username);
            offset = IceControlling(ref buf, offset, tieBreaker);
            offset = UseCandidate(ref buf, offset);
            offset = Priority(ref buf, offset);
            offset = MessageIntegrity(ref buf, offset, conn.ComputeRemoteHash);
            offset = Fingerprint(ref buf, offset);

#if DEBUG
            buf = new ByteBuffer(buf.Buffer, buf.Offset, offset);
            if (!IsValid(buf))
            {
                throw new InvalidOperationException("Invalid STUN request");
            }
#endif
            return buf;
        }

        #region Builder Methods

        // Each builder writes one attribute at `offset` and returns the offset just past it.

        private const int XorMappedAddressIpv4Length = 8;
        private const int XorMappedAddressIpv4Size = AttrHeaderSize + XorMappedAddressIpv4Length;

        private static int XorMappedAddressIpv4(ref ByteBuffer buf, int offset, IPEndPoint sender)
        {
            // XOR-MAPPED-ADDRESS (IPv4 = 4 + 8 bytes)
            buf.SetBytesBE((ushort)StunAttributeType.XorMappedAddress, offset);
            buf.SetBytesBE((ushort)XorMappedAddressIpv4Length, offset + 2);
            buf.SetBytesBE((ushort)StunAddressFamily.IPv4, offset + 4);
            // Port is XOR'd with the most significant 16 bits of the magic cookie
            buf.SetBytesBE((ushort)(sender.Port ^ (MagicCookie >> 16)), offset + 6);
#pragma warning disable 618 // Obsolete
            var ip4 = (uint)IPAddress.HostToNetworkOrder((int)sender.Address.Address);
#pragma warning restore 618
            buf.SetBytesBE(ip4 ^ MagicCookie, offset + 8);

            return offset + XorMappedAddressIpv4Size;
        }

        private const int MessageIntegrityLength = 20;
        private const int MessageIntegritySize = AttrHeaderSize + MessageIntegrityLength;

        private static int MessageIntegrity(ref ByteBuffer buf, int offset, Func<ByteBuffer, int, int, byte[]> computeHash)
        {
            // Hash is computed assuming the message ends after this attribute
            var length = offset + MessageIntegritySize;
            buf.SetBytesBE((ushort)(length - HeaderSize), 2);

            // MESSAGE-INTEGRITY (4 + 20 bytes)
            buf.SetBytesBE((ushort)StunAttributeType.MessageIntegrity, offset);
            buf.SetBytesBE((ushort)MessageIntegrityLength, offset + 2);
            var hash = computeHash(buf, 0, offset);
            // A missing or wrongly sized hash would otherwise crash or overwrite the following attribute
            if (hash == null || hash.Length != MessageIntegrityLength)
            {
                throw new InvalidOperationException("MESSAGE-INTEGRITY hash is missing or has an unexpected length");
            }
            hash.BlockCopy(0, buf, offset + AttrHeaderSize, hash.Length);

            return length;
        }

        private const int FingerprintLength = 4;
        private const int FingerprintSize = AttrHeaderSize + FingerprintLength;

        private static int Fingerprint(ref ByteBuffer buf, int offset)
        {
            // CRC is computed assuming the message ends after this attribute
            var length = offset + FingerprintSize;
            buf.SetBytesBE((ushort)(length - HeaderSize), 2);

            // FINGERPRINT (4 + 4 bytes)
            buf.SetBytesBE((ushort)StunAttributeType.Fingerprint, offset);
            buf.SetBytesBE((ushort)FingerprintLength, offset + 2);
            var check = Crc32.Compute(buf, 0, offset) ^ FingerprintXor;
            buf.SetBytesBE((uint)check, offset + AttrHeaderSize);

            return length;
        }

        private static int Username(ref ByteBuffer buf, int offset, byte[] username)
        {
            // USERNAME (4 + X bytes); padding bytes are left as they are in the buffer
            buf.SetBytesBE((ushort)StunAttributeType.Username, offset);
            buf.SetBytesBE((ushort)username.Length, offset + 2);
            username.BlockCopy(0, buf, offset + AttrHeaderSize, username.Length);

            return offset + GetAttrSize(username.Length);
        }

        private const int IceControllingLength = 8;
        private const int IceControllingSize = AttrHeaderSize + IceControllingLength;

        private static int IceControlling(ref ByteBuffer buf, int offset, long tieBreaker)
        {
            // ICE-CONTROLLING (4 + 8 bytes)
            buf.SetBytesBE((ushort)StunAttributeType.IceControlling, offset);
            buf.SetBytesBE((ushort)IceControllingLength, offset + 2);
            buf.SetBytesBE(tieBreaker, offset + AttrHeaderSize);

            return offset + IceControllingSize;
        }

        private const int UseCandidateLength = 0;
        private const int UseCandidateSize = AttrHeaderSize + UseCandidateLength;

        private static int UseCandidate(ref ByteBuffer buf, int offset)
        {
            // USE-CANDIDATE (4 + 0 bytes)
            buf.SetBytesBE((ushort)StunAttributeType.UseCandidate, offset);
            buf.SetBytesBE((ushort)UseCandidateLength, offset + 2);

            return offset + UseCandidateSize;
        }

        private const int PriorityLength = 4;
        private const int PrioritySize = AttrHeaderSize + PriorityLength;

        private const int PriorityValue = 1853817087;

        private static int Priority(ref ByteBuffer buf, int offset)
        {
            // PRIORITY (4 + 4 bytes)
            buf.SetBytesBE((ushort)StunAttributeType.Priority, offset);
            buf.SetBytesBE((ushort)PriorityLength, offset + 2);
            buf.SetBytesBE((uint)PriorityValue, offset + AttrHeaderSize);

            return offset + PrioritySize;
        }

        #endregion
    }
}
