﻿using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net.NetworkInformation;
using System.Text;
using System.Threading;
using Curse.Logging;
using Curse.SocketInterface;
using Curse.SocketMessages;
using Curse.Voice.Contracts;

namespace Curse.Voice.Client
{
    /// <summary>
    /// Lifecycle states of a voice client connection. Values are explicitly numbered starting at 1.
    /// </summary>
    public enum ConnectionState
    {
        /// <summary>Initial state, and the state after the socket disconnects.</summary>
        Disconnected = 1,

        /// <summary>Set immediately before the join request is sent.</summary>
        Connecting = 2,

        /// <summary>Set at the start of the join procedure.</summary>
        Joining = 3,

        /// <summary>Join accepted; UDP handshakes are being exchanged.</summary>
        Handshaking = 4,

        /// <summary>A handshake reply arrived while handshaking; the session is live.</summary>
        Connected = 5
    }

    public class VoiceClient : ClientSocketInterface, IVoiceClient
    {
        // Log category shared by every VoiceClient instance.
        private static readonly LogCategory Logger = new LogCategory("VoiceClient");

        // Explicit IVoiceClient implementation; forwards to the base class's IsConnected.
        bool IVoiceClient.IsConnected { get { return IsConnected; } }

        // Last value handed out by GetNextSequenceNumber (pre-incremented, so the first outgoing packet uses 1).
        private uint _currentSequenceNumber;

        /// <summary>Session ID sent in the join request; initialised from VoiceConnectionInfo.InstanceID.</summary>
        public string SessionID { get; set; }

        /// <summary>Display name sent in the join request.</summary>
        public string DisplayName { get; set; }

        /// <summary>Avatar URL sent in the join request.</summary>
        public string AvatarUrl { get; set; }

        /// <summary>Optional user ID sent in the join request.</summary>
        public int? UserID { get; set; }

        /// <summary>In-game name sent in the join request.</summary>
        public string InGameName { get; set; }

        /// <summary>In-game region sent in the join request.</summary>
        public string InGameRegion { get; set; }

        /// <summary>Client version string sent in the join request.</summary>
        public string ClientVersion { get; set; }

        /// <summary>Optional access token sent (and logged) with the join request.</summary>
        public long? AccessToken { get; set; }

        /// <summary>Auth token sent with the join request (not included in the join log entry).</summary>
        public string AuthToken { get; set; }

        /// <summary>
        /// Codec settings sent in the join request. SampleRate and PacketMilliseconds also determine the
        /// RTP timestamp step used by Transmit.
        /// </summary>
        public CodecInfo CodecInfo { get; set; }

        /// <summary>Current lifecycle state; updated by the constructor, Join, OnHandshake and OnDisconnected.</summary>
        public ConnectionState ConnectionState { get; set; }

        /// <summary>Raised by OnDisconnected when the socket disconnects while the session is Connected.</summary>
        public event EventHandler<SocketDisconnectEventArgs> DisconnectedSession;

        /// <summary>Raised when a UserLeftNotification contract is received.</summary>
        public event EventHandler<EventArgs<UserLeftNotification>> UserLeft;

        /// <summary>Raised when a UserJoinedNotification contract is received.</summary>
        public event EventHandler<EventArgs<UserJoinedNotification>> UserJoined;

        /// <summary>Raised when a UserUpdatedNotification contract is received.</summary>
        public event EventHandler<EventArgs<UserUpdatedNotification>> UserUpdated;

        /// <summary>Raised when a TransmitStartNotification contract is received.</summary>
        public event EventHandler<EventArgs<TransmitStartNotification>> UserStartTransmission;

        /// <summary>Raised when a TransmitEndNotification contract is received.</summary>
        public event EventHandler<EventArgs<TransmitEndNotification>> UserEndTransmission;

        /// <summary>
        /// Raised for TransmitNotification contracts, for RTP audio parsed from UDP datagrams, and (when
        /// LocalLoopbackEnabled) for local echoes of outgoing frames. The misspelled name is public API and is kept.
        /// </summary>
        public event EventHandler<EventArgs<TransmitNotification>> VoiceTranmission;

        /// <summary>Raised when the voice host sends a FailoverNotification.</summary>
        public event EventHandler<EventArgs<FailoverNotification>> Failover;

        /// <summary>Raised when a ChatMessageNotification contract is received.</summary>
        public event EventHandler<EventArgs<ChatMessageNotification>> ChatMessageReceived;

        /// <summary>Raised when a ChatMessageResponse contract is received.</summary>
        public event EventHandler<EventArgs<ChatMessageResponse>> ChatMessageResponse;

        /// <summary>Raised when a GetUsersResponse contract is received (presumably the reply to GetUserList).</summary>
        public event EventHandler<EventArgs<GetUsersResponse>> UserListUpdated;

        /// <summary>Raised with the measured round-trip time when a Handshake reply matches an outstanding handshake.</summary>
        public event EventHandler<EventArgs<TimeSpan>> PingUpdated;

        /// <summary>Raised when a UserDisconnectNotification contract is received.</summary>
        public event EventHandler<EventArgs<UserDisconnectNotification>> UserDisconnected;

        /// <summary>Raised when an AddPendingUsersNotification contract is received.</summary>
        public event EventHandler<EventArgs<AddPendingUsersNotification>> AddPendingUsers;
        /// <summary>Raised when a RemovePendingUserNotification contract is received.</summary>
        public event EventHandler<EventArgs<RemovePendingUserNotification>> RemovePendingUsers;

        /// <summary>Raised when a VoiceInstanceChangedNotification contract is received.</summary>
        public event EventHandler<EventArgs<VoiceInstanceChangedNotification>> VoiceInstanceChanged; 

        /// <summary>Loopback flag passed to Connect. NOTE(review): only stored — nothing in this class reads it.</summary>
        public bool LoopbackEnabled { get; private set; }

        /// <summary>When true, Transmit also raises VoiceTranmission locally with the outgoing frame.</summary>
        public bool LocalLoopbackEnabled { get; private set; }

        // Keep-alive thread started by TryConnect after a successful join; stopped in Dispose.
        private Thread _handshakeThread;

        // UTC time the most recent handshake was sent; reset to default(DateTime) once OnHandshake consumes it.
        private DateTime _lastHandshakeSent;

        /// <summary>Round-trip time measured from the most recent matched handshake reply.</summary>
        public TimeSpan Ping { get; private set; }

        /// <summary>
        /// Tries each port in order until one produces a definitive join answer.
        /// </summary>
        /// <param name="ports">Ports to try, in order.</param>
        /// <param name="connectionInfo">Host address plus the identity, profile and credentials sent in the join request.</param>
        /// <param name="loopbackEnabled">Stored on the client as LoopbackEnabled.</param>
        /// <param name="clientVersion">Client version string sent in the join request.</param>
        /// <param name="joinSessionResponse">
        /// Response from the last attempt, or FailedUnhandledException when no port was tried.
        /// </param>
        /// <param name="localLoopbackEnabled">When true, outgoing voice frames are echoed back locally.</param>
        /// <returns>
        /// The client for the first port whose status was neither Timeout nor FailedUnhandledException. Returns null
        /// if every port failed that way; in that case an ICMP ping result is logged for diagnostics.
        /// NOTE(review): TryConnect disposes the client whenever the status is not Successful. A non-null return
        /// with a non-Successful status is therefore an already-disposed client, so callers must check the status.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="ports"/> is null.</exception>
        public static VoiceClient Connect(int[] ports, VoiceConnectionInfo connectionInfo, bool loopbackEnabled, string clientVersion, out JoinSessionResponse joinSessionResponse, bool localLoopbackEnabled = false)
        {
            if (ports == null)
            {
                throw new ArgumentNullException("ports");
            }

            joinSessionResponse = new JoinSessionResponse { Status = JoinSessionStatus.FailedUnhandledException };

            foreach (var port in ports)
            {
                Logger.Info("Attempting voice server connection on port " + port);
                var client = new VoiceClient(connectionInfo, port, loopbackEnabled, localLoopbackEnabled, clientVersion);
                joinSessionResponse = client.TryConnect();

                // Only a timeout or an unexpected exception is worth retrying on another port.
                var status = joinSessionResponse.Status;
                if (status != JoinSessionStatus.Timeout && status != JoinSessionStatus.FailedUnhandledException)
                {
                    return client;
                }

                Logger.Info("Attempt on port " + port + " failed");
            }

            // Every port failed: log whether the host answers ICMP at all, to help diagnose the failure.
            LogPingDiagnostics(connectionInfo);
            return null;
        }

        /// <summary>
        /// Best-effort diagnostic: sends one ICMP echo (500 ms timeout, don't-fragment) to the voice host
        /// and logs the reply status and round-trip time. Exceptions are logged as warnings and swallowed.
        /// </summary>
        private static void LogPingDiagnostics(VoiceConnectionInfo connectionInfo)
        {
            try
            {
                using (var pingTest = new Ping())
                {
                    var options = new PingOptions
                    {
                        DontFragment = true
                    };

                    // 18-byte ASCII payload identifying the probe.
                    var buffer = Encoding.ASCII.GetBytes("CurseVoicePingTest");
                    var reply = pingTest.Send(connectionInfo.IPAddress, 500, buffer, options);

                    if (reply != null)
                    {
                        Logger.Info("Ping test analysis of connectivity to " + connectionInfo.IPAddress, new { reply.Status, reply.RoundtripTime });
                    }
                    else
                    {
                        Logger.Warn("Ping test failed to return a reply.");
                    }
                }
            }
            catch (Exception ex)
            {
                Logger.Warn(ex, "Ping test failed with an unhandled exception!");
            }
        }

        /// <summary>
        /// Creates a client bound to one host/port. It copies identity, profile and credentials from
        /// <paramref name="connectionInfo"/>, registers a handler for every contract the server can send,
        /// and subscribes to the base socket's message and disconnect events. The constructor body does not
        /// connect; TryConnect does.
        /// </summary>
        private VoiceClient(VoiceConnectionInfo connectionInfo, int port, bool loopbackEnabled, bool localLoopbackEnabled, string clientVersion)
            : base(connectionInfo.IPAddress, port)
        {
            ConnectionState = ConnectionState.Disconnected;

            // Options supplied directly by the caller.
            ClientVersion = clientVersion;
            LoopbackEnabled = loopbackEnabled;
            LocalLoopbackEnabled = localLoopbackEnabled;

            // Identity, profile, credentials and codec forwarded in the join request.
            SessionID = connectionInfo.InstanceID;
            UserID = connectionInfo.UserID;
            DisplayName = connectionInfo.DisplayName;
            AvatarUrl = connectionInfo.AvatarUrl;
            InGameName = connectionInfo.InGameName;
            InGameRegion = connectionInfo.InGameRegion;
            AccessToken = connectionInfo.AccessToken;
            AuthToken = connectionInfo.AuthToken;
            CodecInfo = connectionInfo.CodecInfo;

            // Connection setup replies.
            AddContractDispatcher<JoinSessionResponse>(OnJoinSession);
            AddContractDispatcher<Handshake>(OnHandshake);

            // Roster and voice traffic notifications.
            AddContractDispatcher<UserLeftNotification>(OnUserLeft);
            AddContractDispatcher<UserJoinedNotification>(OnUserJoined);
            AddContractDispatcher<UserUpdatedNotification>(OnUserUpdated);
            AddContractDispatcher<TransmitNotification>(OnVoiceTransmission);
            AddContractDispatcher<TransmitStartNotification>(OnTransmitStart);
            AddContractDispatcher<TransmitEndNotification>(OnTransmitEnd);
            AddContractDispatcher<FailoverNotification>(OnFailover);
            AddContractDispatcher<ChatMessageResponse>(OnChatMessageResponse);
            AddContractDispatcher<ChatMessageNotification>(OnChatMessageNotification);
            AddContractDispatcher<GetUsersResponse>(OnGetUsersResponse);
            AddContractDispatcher<UserDisconnectNotification>(OnUserDisconnected);
            AddContractDispatcher<VoiceInstanceChangedNotification>(OnVoiceInstanceChanged);
            AddContractDispatcher<AddPendingUsersNotification>(OnAddPendingUsersNotification);
            AddContractDispatcher<RemovePendingUserNotification>(OnRemovePendingUsersNotification);

            // Hook the base socket's events.
            MessageReceived += OnMessageReceived;
            Disconnected += OnDisconnected;
        }

        
        /// <summary>
        /// Stamps the send time (later consumed by OnHandshake to compute Ping) and sends a Handshake packet.
        /// </summary>
        private void SendHandshake()
        {
            var sentAt = DateTime.UtcNow;
            _lastHandshakeSent = sentAt;

            SendPacket(new Handshake());
        }

        /// <summary>
        /// Handles a Handshake reply from the server.
        /// If a handshake is outstanding, records its round-trip time in Ping and raises PingUpdated.
        /// If the client is still Handshaking, moves it to Connected and releases the thread waiting in Join.
        /// </summary>
        private void OnHandshake(ISocketInterface socket, Handshake handshake)
        {
            // Snapshot the send time once. SendHandshake may write the field concurrently (e.g. from the
            // keep-alive thread), and the "is one outstanding?" check and the subtraction must use the same value.
            var sentAt = _lastHandshakeSent;
            if (sentAt != default(DateTime))
            {
                var roundTrip = DateTime.UtcNow - sentAt;
                Ping = roundTrip;
                _lastHandshakeSent = default(DateTime);

                var handler = PingUpdated;
                if (handler != null)
                {
                    handler(this, new EventArgs<TimeSpan>(roundTrip));
                }
            }

            if (ConnectionState == ConnectionState.Handshaking)
            {
                ConnectionState = ConnectionState.Connected;
                lock (_joinSync)
                {
                    if (_handshakeResetEvent != null)
                    {
                        _handshakeResetEvent.Set();
                    }
                }
            }
        }

        /// <summary>Forwards a ChatMessageNotification to ChatMessageReceived subscribers.</summary>
        private void OnChatMessageNotification(ISocketInterface socket, ChatMessageNotification chatMessageNotification)
        {
            // Copy the delegate first so an unsubscribe on another thread between the null check
            // and the invocation cannot cause a NullReferenceException.
            var handler = ChatMessageReceived;
            if (handler != null)
            {
                handler(this, new EventArgs<ChatMessageNotification>(chatMessageNotification));
            }
        }

        /// <summary>Forwards a ChatMessageResponse contract to ChatMessageResponse event subscribers.</summary>
        private void OnChatMessageResponse(ISocketInterface socket, ChatMessageResponse chatMessageResponse)
        {
            // Copy the delegate first so an unsubscribe on another thread between the null check
            // and the invocation cannot cause a NullReferenceException.
            var handler = ChatMessageResponse;
            if (handler != null)
            {
                handler(this, new EventArgs<ChatMessageResponse>(chatMessageResponse));
            }
        }

        /// <summary>
        /// Logs a FailoverNotification from the voice host and raises Failover.
        /// This method only raises the event. Per the original author, subscribers are expected to disconnect
        /// from the current host, clear the user list and reconnect to the new host.
        /// </summary>
        private void OnFailover(ISocketInterface socket, FailoverNotification failoverNotification)
        {
            Logger.Info("Received failover command from voice host. Failover to host '" + failoverNotification.HostName + "' with IP address '" + failoverNotification.IPAddress + "' and port '" + failoverNotification.Port + "'");

            // Copy the delegate first so an unsubscribe on another thread between the null check
            // and the invocation cannot cause a NullReferenceException.
            var handler = Failover;
            if (handler != null)
            {
                handler(this, new EventArgs<FailoverNotification>(failoverNotification));
            }
        }

        /// <summary>Forwards a TransmitStartNotification to UserStartTransmission subscribers.</summary>
        private void OnTransmitStart(ISocketInterface socket, TransmitStartNotification notification)
        {
            // Copy the delegate first so an unsubscribe on another thread between the null check
            // and the invocation cannot cause a NullReferenceException.
            var handler = UserStartTransmission;
            if (handler != null)
            {
                handler(this, new EventArgs<TransmitStartNotification>(notification));
            }
        }

        /// <summary>Forwards a TransmitEndNotification to UserEndTransmission subscribers.</summary>
        private void OnTransmitEnd(ISocketInterface socket, TransmitEndNotification notification)
        {
            // Copy the delegate first so an unsubscribe on another thread between the null check
            // and the invocation cannot cause a NullReferenceException.
            var handler = UserEndTransmission;
            if (handler != null)
            {
                handler(this, new EventArgs<TransmitEndNotification>(notification));
            }
        }

        /// <summary>Forwards a UserUpdatedNotification to UserUpdated subscribers.</summary>
        private void OnUserUpdated(ISocketInterface socket, UserUpdatedNotification userUpdatedNotification)
        {
            // Copy the delegate first so an unsubscribe on another thread between the null check
            // and the invocation cannot cause a NullReferenceException.
            var handler = UserUpdated;
            if (handler != null)
            {
                handler(this, new EventArgs<UserUpdatedNotification>(userUpdatedNotification));
            }
        }

        /// <summary>Forwards a UserJoinedNotification to UserJoined subscribers.</summary>
        private void OnUserJoined(ISocketInterface socket, UserJoinedNotification userJoinedNotification)
        {
            // Copy the delegate first so an unsubscribe on another thread between the null check
            // and the invocation cannot cause a NullReferenceException.
            var handler = UserJoined;
            if (handler != null)
            {
                handler(this, new EventArgs<UserJoinedNotification>(userJoinedNotification));
            }
        }

        /// <summary>
        /// Forwards voice data to VoiceTranmission subscribers.
        /// Called for TransmitNotification contracts, for RTP audio parsed by HandleUdpPacket, and for
        /// local-loopback echoes from Transmit (where <paramref name="socket"/> is null).
        /// </summary>
        private void OnVoiceTransmission(ISocketInterface socket, TransmitNotification voiceTransmissionResponse)
        {
            // Copy the delegate first so an unsubscribe on another thread between the null check
            // and the invocation cannot cause a NullReferenceException.
            var handler = VoiceTranmission;
            if (handler != null)
            {
                handler(this, new EventArgs<TransmitNotification>(voiceTransmissionResponse));
            }
        }

        /// <summary>
        /// Stores the server's JoinSessionResponse and wakes the thread waiting for it in Join, if one is waiting.
        /// </summary>
        private void OnJoinSession(ISocketInterface socket, JoinSessionResponse response)
        {
            // Publish the response before signalling so the waiter observes it when it wakes.
            _joinResponse = response;

            lock (_joinSync)
            {
                var waiter = _joinResetEvent;
                if (waiter == null)
                {
                    return;
                }

                waiter.Set();
            }
        }

        /// <summary>Forwards a UserLeftNotification to UserLeft subscribers.</summary>
        private void OnUserLeft(ISocketInterface socket, UserLeftNotification userLeftNotification)
        {
            // Copy the delegate first so an unsubscribe on another thread between the null check
            // and the invocation cannot cause a NullReferenceException.
            var handler = UserLeft;
            if (handler != null)
            {
                handler(this, new EventArgs<UserLeftNotification>(userLeftNotification));
            }
        }

        /// <summary>Forwards a UserDisconnectNotification to UserDisconnected subscribers.</summary>
        private void OnUserDisconnected(ISocketInterface socket, UserDisconnectNotification notification)
        {
            // Copy the delegate first so an unsubscribe on another thread between the null check
            // and the invocation cannot cause a NullReferenceException.
            var handler = UserDisconnected;
            if (handler != null)
            {
                handler(this, new EventArgs<UserDisconnectNotification>(notification));
            }
        }

        /// <summary>Forwards a VoiceInstanceChangedNotification to VoiceInstanceChanged subscribers.</summary>
        private void OnVoiceInstanceChanged(ISocketInterface socket, VoiceInstanceChangedNotification notification)
        {
            // Copy the delegate first so an unsubscribe on another thread between the null check
            // and the invocation cannot cause a NullReferenceException.
            var handler = VoiceInstanceChanged;
            if (handler != null)
            {
                handler(this, new EventArgs<VoiceInstanceChangedNotification>(notification));
            }
        }

        /// <summary>Forwards a GetUsersResponse to UserListUpdated subscribers.</summary>
        private void OnGetUsersResponse(ISocketInterface socket, GetUsersResponse getUsersResponse)
        {
            // Copy the delegate first so an unsubscribe on another thread between the null check
            // and the invocation cannot cause a NullReferenceException.
            var handler = UserListUpdated;
            if (handler != null)
            {
                handler(this, new EventArgs<GetUsersResponse>(getUsersResponse));
            }
        }

        /// <summary>Forwards an AddPendingUsersNotification to AddPendingUsers subscribers.</summary>
        private void OnAddPendingUsersNotification(ISocketInterface socket, AddPendingUsersNotification notification)
        {
            // Copy the delegate first so an unsubscribe on another thread between the null check
            // and the invocation cannot cause a NullReferenceException.
            var handler = AddPendingUsers;
            if (handler != null)
            {
                handler(this, new EventArgs<AddPendingUsersNotification>(notification));
            }
        }

        /// <summary>Forwards a RemovePendingUserNotification to RemovePendingUsers subscribers.</summary>
        private void OnRemovePendingUsersNotification(ISocketInterface socket, RemovePendingUserNotification notification)
        {
            // Copy the delegate first so an unsubscribe on another thread between the null check
            // and the invocation cannot cause a NullReferenceException.
            var handler = RemovePendingUsers;
            if (handler != null)
            {
                handler(this, new EventArgs<RemovePendingUserNotification>(notification));
            }
        }


        private readonly Dictionary<int, IContractDispatcher> _contractDispatchers = new Dictionary<int, IContractDispatcher>();

        /// <summary>
        /// Registers <paramref name="handler"/> for contract type <typeparamref name="T"/>, keyed by that contract's
        /// message type. OnMessageReceived uses this map to route incoming messages.
        /// </summary>
        /// <exception cref="ArgumentException">A dispatcher for the same message type is already registered (Dictionary.Add).</exception>
        public void AddContractDispatcher<T>(Action<ISocketInterface, T> handler) where T : Contract<T>, new()
        {
            _contractDispatchers.Add(Contract<T>.MessageType, new ContractDispatcher<T>(handler));
        }

        // Signalled by OnJoinSession when a join response arrives; assigned and cleared by Join. OnJoinSession reads it under _joinSync.
        private ManualResetEvent _joinResetEvent;
        // Signalled by OnHandshake when the client reaches Connected; assigned and cleared by Join. OnHandshake reads it under _joinSync.
        private ManualResetEvent _handshakeResetEvent;
        // Lock guarding the notification handlers' access to the two reset events above.
        private readonly object _joinSync = new object();
        // Most recent JoinSessionResponse delivered by OnJoinSession; cleared at the start of Join.
        private JoinSessionResponse _joinResponse;

        /// <summary>
        /// Opens the connection, joins the session and, on success, starts the keep-alive handshake thread.
        /// Any outcome other than JoinSessionStatus.Successful disposes this client.
        /// </summary>
        /// <param name="millisecondsTimeout">Timeout handed to the base Connect call.</param>
        /// <returns>
        /// The result of Join. Returns a Timeout response if the connection could not be opened, or if an exception
        /// was thrown before Join returned. Exceptions are logged, not rethrown.
        /// </returns>
        public JoinSessionResponse TryConnect(int millisecondsTimeout = 2000)
        {
            var result = new JoinSessionResponse { Status = JoinSessionStatus.Timeout };

            try
            {
                if (!Connect(millisecondsTimeout))
                {
                    return result;
                }

                // Let RTP audio datagrams be recognised before the default contract processing.
                IncomingDataChannelProcessor.HandlePacket = HandleUdpPacket;

                result = Join();
                if (result.Status != JoinSessionStatus.Successful)
                {
                    return result;
                }

                _handshakeThread = new Thread(HandshakeThread) { IsBackground = true, Name = "VoiceClientHandshake" };
                _handshakeThread.Start();
                return result;
            }
            catch (Exception ex)
            {
                Logger.Error(ex, "Unhandled exception occurred while connecting to a voice server.");
                return result;
            }
            finally
            {
                // A failed attempt must not leave the socket open.
                if (result.Status != JoinSessionStatus.Successful)
                {
                    Dispose();
                }
            }
        }

        /// <summary>
        /// Hook installed on IncomingDataChannelProcessor.HandlePacket by TryConnect. It recognises raw RTP audio
        /// datagrams, strips the RTP header, and raises the payload as a TransmitNotification.
        /// </summary>
        /// <returns>
        /// true when the datagram was consumed here as RTP audio. false when it should be handed back to the caller,
        /// presumably for default contract processing (TODO confirm against IncomingDataChannelProcessor).
        /// false covers: expected contract packets, anything that does not look like RTP, and truncated RTP.
        /// </returns>
        private bool HandleUdpPacket(BaseSocketInterface socketInterface, byte[] incomingData)
        {
            // If this is an expected Contract type, just use the default processing
            if (MessageFilter(incomingData))
            {
                return false;
            }

            // Otherwise, attempt to parse this as an RTP packet
            // (12 bytes = RTP fixed header: flags, PT, sequence, timestamp, SSRC).
            if (incomingData.Length < 12)
            {
                return false; // Not an RTP packet
            }

            // Demux various packet types (per RFC 5764 section 5.1.2)
            // RTP v2 check
            // (top two bits of byte 0 must be binary 10, i.e. the byte is 128..191)
            var b0 = incomingData[0];
            if ((b0 & 0xC0) != 0x80)
            {
                return false; // Not an RTP packet
            }

            // Byte 1 is marker bit + 7-bit payload type; a PT in 64..95 overlaps RTCP packet types 192..223.
            var pt = incomingData[1];
            if ((pt & 0x60) == 0x40) // 64 <= payload type < 96
            {
                return false; // RTCP packet?
            }

            // Looks like RTP (packet format: https://tools.ietf.org/html/rfc3550#section-5.1)
            // Version has already been checked above
            // P bit: the last byte holds the number of padding bytes (including itself) to drop from the end.
            // X bit: a 4-byte extension header follows the CSRC list (its length is added below).
            // CC: number of 4-byte CSRC identifiers after the fixed header.
            var padding = (b0 & 0x20) != 0 ? incomingData[incomingData.Length - 1] : 0;
            var extension = (b0 & 0x10) != 0 ? 4 : 0; 
            var csrcCount = b0 & 0x0F;
            // We don't care about the payload type and marker bit (just assume it's 48kHz Opus data)

            // Check the packet size
            var offset = 12 + csrcCount*4 + extension;
            if (offset > incomingData.Length)
            {
                Logger.Warn("Not enough data in RTP packet");
                return false;
            }
            if (extension > 0)
            {
                // The last 16 bits of the extension header are its length in 32-bit words (RFC 3550 §5.3.1).
                // The bounds check above guarantees offset - 2 is readable.
                offset += incomingData.ToUInt16BE(offset - 2)*4;
            }
            if (offset + padding > incomingData.Length)
            {
                Logger.Warn("Not enough data in RTP packet");
                return false;
            }

            // Get the data we really need
            // NOTE(review): an RTP packet with an empty payload is still raised, with zero-length VoiceData.
            var seqNum = incomingData.ToUInt16BE(2);
            var ts = incomingData.ToUInt32BE(4);
            var ssrc = incomingData.ToUInt32BE(8);
            var data = new byte[incomingData.Length - offset - padding];
            incomingData.BlockCopy(offset, data, 0, data.Length);

            // The SSRC is used as the sender ID (Transmit writes ClientID into the SSRC field of outgoing packets).
            OnVoiceTransmission(socketInterface, new TransmitNotification
            {
                SenderID = (int)ssrc,
                SequenceNumber = seqNum,
                Timestamp = ts,
                VoiceDataLength = data.Length,
                VoiceData = data,
            });
            return true;
        }

        /// <summary>
        /// Decides, from the framing header and declared lengths only, whether a raw UDP datagram is one of the
        /// contracts expected over UDP: a 2-byte-body Handshake or a well-formed TransmitRequest.
        /// Datagrams are untrusted network input, so every field is bounds-checked before it is read.
        /// Assumes MessageHeader.HeaderSize is at least 13, so the header reads at offsets 4, 8 and 12 are in
        /// bounds (TODO confirm).
        /// </summary>
        /// <returns>true if the datagram should go through default contract processing.</returns>
        private bool MessageFilter(byte[] incomingData)
        {
            if (incomingData.Length < MessageHeader.HeaderSize)
            {
                return false;
            }

            // Validate the declared body size against the actual datagram length.
            var bodySize = incomingData.ToInt32LE(4);
            if (bodySize < 0 || MessageHeader.HeaderSize + bodySize != incomingData.Length)
            {
                return false;
            }

            // Check the message type and serialization flag.
            var messageType = incomingData.ToInt32LE(8);
            var isSerialized = incomingData.ToBooleanLE(12);

            if (messageType == Handshake.MessageType && isSerialized == Handshake.IsSerialized && bodySize == 2)
            {
                return true;
            }

            if (messageType == TransmitRequest.MessageType && isSerialized == TransmitRequest.IsSerialized)
            {
                // The voice-data length is an Int32 at body offset 5. A body too short to hold it cannot be a
                // valid TransmitRequest, and reading it would run past the end of the datagram.
                const int lengthFieldOffset = 5;
                if (bodySize < lengthFieldOffset + sizeof(int))
                {
                    return false;
                }

                var length = incomingData.ToInt32LE(MessageHeader.HeaderSize + lengthFieldOffset);
                if (length < 0 || MessageHeader.HeaderSize + TransmitRequest.FixedSize + length != incomingData.Length)
                {
                    return false;
                }

                return true;
            }

            return false;
        }

        /// <summary>
        /// Joins the voice session over the open connection, then completes the UDP handshake.
        /// </summary>
        /// <returns>
        /// One of:
        /// - the server's response, when both steps succeed;
        /// - the server's non-successful response, when the join is rejected;
        /// - JoinSessionStatus.Timeout, when no join response arrives within 2 seconds or no handshake reply
        ///   arrives after 5 attempts;
        /// - JoinSessionStatus.FailedUnhandledException, when an exception is thrown (it is logged).
        /// </returns>
        private JoinSessionResponse Join()
        {
            ConnectionState = ConnectionState.Joining;

            _joinResponse = null;

            var joinSessionRequest = new JoinSessionRequest
            {
                SessionID = SessionID,
                DisplayName = DisplayName,
                AvatarUrl = AvatarUrl,
                UserID = UserID,
                InGameName = InGameName,
                InGameRegion = InGameRegion,
                ClientVersion = ClientVersion,
                AuthToken = AuthToken,
                AccessToken = AccessToken,
                CodecInfo = CodecInfo,
                WebRTCType = MultiStreamType.NullCipher,
            };

            try
            {
                var resp = SendJoinRequestAndWait(joinSessionRequest);
                if (resp == null)
                {
                    Logger.Warn("Join session request timed out.");
                    return new JoinSessionResponse { Status = JoinSessionStatus.Timeout };
                }

                if (resp.Status != JoinSessionStatus.Successful)
                {
                    Logger.Warn("Join session request failed.", new { resp.Status });
                    // Return the snapshot that was inspected; the field can be overwritten by another response.
                    return resp;
                }

                ClientID = resp.ClientID;
                ConnectionState = ConnectionState.Handshaking;

                if (!PerformUdpHandshake())
                {
                    Logger.Warn("Handshake timed out!");
                    return new JoinSessionResponse { Status = JoinSessionStatus.Timeout };
                }

                return resp;
            }
            catch (Exception ex)
            {
                Logger.Error(ex);

                // Never report success after an exception. A failure after a successful join response (e.g. during
                // the handshake) would otherwise make TryConnect keep a half-open session.
                return new JoinSessionResponse { Status = JoinSessionStatus.FailedUnhandledException };
            }
        }

        /// <summary>
        /// Sends the join request and blocks up to 2 seconds for OnJoinSession to deliver the response.
        /// Returns the delivered response, or null if none arrived in time.
        /// </summary>
        private JoinSessionResponse SendJoinRequestAndWait(JoinSessionRequest joinSessionRequest)
        {
            using (var joinEvent = new ManualResetEvent(false))
            {
                lock (_joinSync)
                {
                    _joinResetEvent = joinEvent;
                }

                try
                {
                    // NOTE(review): Join set Joining just before this; the state now moves "back" to Connecting.
                    // Kept as-is because external code may observe it.
                    ConnectionState = ConnectionState.Connecting;

                    Logger.Info("Sending join session request...", new
                    {
                        joinSessionRequest.SessionID,
                        joinSessionRequest.DisplayName,
                        joinSessionRequest.AvatarUrl,
                        joinSessionRequest.ClientVersion,
                        joinSessionRequest.InGameName,
                        joinSessionRequest.InGameRegion,
                        joinSessionRequest.AccessToken
                    });

                    SendContract(joinSessionRequest);
                    joinEvent.WaitOne(2000); // Wait up to 2 seconds for the join call to complete
                }
                finally
                {
                    // Unpublish the event under the lock before it is disposed. This way a late or duplicate
                    // JoinSessionResponse can never call Set() on a disposed handle.
                    lock (_joinSync)
                    {
                        _joinResetEvent = null;
                    }
                }
            }

            return _joinResponse;
        }

        /// <summary>
        /// Sends up to 5 UDP handshakes, waiting up to 1 second after each for OnHandshake to move the client to
        /// Connected. The incoming UDP processor is started after the first send.
        /// Returns whether the client reached Connected.
        /// </summary>
        private bool PerformUdpHandshake()
        {
            using (var handshakeEvent = new ManualResetEvent(false))
            {
                lock (_joinSync)
                {
                    _handshakeResetEvent = handshakeEvent;
                }

                try
                {
                    for (var i = 0; i < 5; i++)
                    {
                        handshakeEvent.Reset();
                        Logger.Info("Sending handshake packet...", new { Attempt = i + 1 });
                        SendHandshake();
                        if (i == 0)
                        {
                            IncomingDataChannelProcessor.Start();
                        }
                        handshakeEvent.WaitOne(1000); // Wait up to one second for the handshake to complete
                        if (ConnectionState == ConnectionState.Connected)
                        {
                            break;
                        }
                    }
                }
                finally
                {
                    // Unpublish before disposal so a late Handshake reply cannot Set() a disposed handle.
                    lock (_joinSync)
                    {
                        _handshakeResetEvent = null;
                    }
                }
            }

            return ConnectionState == ConnectionState.Connected;
        }

        /// <summary>Sends a TransmitStartRequest contract to the server.</summary>
        public void StartTransmit()
        {
            SendContract(new TransmitStartRequest());
        }

        /// <summary>Sends a TransmitEndRequest contract to the server.</summary>
        public void EndTransmit()
        {
            SendContract(new TransmitEndRequest());
        }

        /// <summary>
        /// Sends an UpdateUserRequest carrying the given profile fields.
        /// The local DisplayName/AvatarUrl/InGameName/InGameRegion properties are not modified.
        /// </summary>
        public void UpdateUser(string displayName, string avatarUrl, string ingameName, string ingameRegion)
        {
            SendContract(new UpdateUserRequest
            {
                DisplayName = displayName,
                AvatarUrl = avatarUrl,
                InGameName = ingameName,
                InGameRegion = ingameRegion
            });
        }

        /// <summary>
        /// Returns the next outgoing sequence number, starting at 1.
        /// Once the counter passes uint.MaxValue - 1000 it restarts at 1. Not synchronised.
        /// </summary>
        private uint GetNextSequenceNumber()
        {
            const uint wrapThreshold = uint.MaxValue - 1000;

            var next = ++_currentSequenceNumber;
            if (_currentSequenceNumber <= wrapThreshold)
            {
                return next;
            }

            _currentSequenceNumber = 1;
            return 1;
        }

        /// <summary>
        /// Wraps an encoded voice frame in a minimal RTP header (RFC 3550 §5.1) and sends it.
        /// When LocalLoopbackEnabled is set, the frame is also raised locally through VoiceTranmission.
        /// </summary>
        /// <param name="data">Encoded audio payload (HandleUdpPacket assumes Opus at 48 kHz — TODO confirm for senders).</param>
        public void Transmit(byte[] data)
        {
            const int rtpHeaderSize = 12;

            var seqNum = GetNextSequenceNumber();

            var packet = new byte[rtpHeaderSize + data.Length];
            var samplesPerPacket = CodecInfo.SampleRate*CodecInfo.PacketMilliseconds/1000;

            packet[0] = 0x80; // V=2, no padding, no extension, zero CSRCs
            packet[1] = 111; // marker clear, payload type 111
            packet.SetBytesBE((ushort)(seqNum & 0xFFFF), 2); // low 16 bits of the sequence number
            packet.SetBytesBE((uint)(seqNum * samplesPerPacket), 4); // timestamp in samples
            packet.SetBytesBE((uint)ClientID, 8); // SSRC identifies this client

            Buffer.BlockCopy(data, 0, packet, rtpHeaderSize, data.Length);

            SendPacket(Message.FromRawOutgoing(packet));

            if (!LocalLoopbackEnabled)
            {
                return;
            }

            OnVoiceTransmission(null, new TransmitNotification
            {
                SenderID = 1,
                SequenceNumber = seqNum,
                Timestamp = 0,
                VoiceData = data,
                VoiceDataLength = data.Length,
            });
        }

        /// <summary>
        /// Routes an incoming message to the dispatcher registered for its message type.
        /// Messages that arrive while not connected, or that have no registered dispatcher, are dropped.
        /// </summary>
        private void OnMessageReceived(object sender, MessageEventArgs messageEventArgs)
        {
            if (!IsConnected)
            {
                return;
            }

            var message = messageEventArgs.Message;
            IContractDispatcher dispatcher;
            if (!_contractDispatchers.TryGetValue(message.Header.MessageType, out dispatcher))
            {
                return;
            }

            dispatcher.Dispatch(this, message);
        }

        /// <summary>Sends a ChatMessageRequest with the given body text.</summary>
        public void SendChatMessage(string body)
        {
            SendContract(new ChatMessageRequest { Body = body });
        }

        /// <summary>Sends a GetUsersRequest; the reply is surfaced through UserListUpdated.</summary>
        public void GetUserList()
        {
            var request = new GetUsersRequest();
            SendContract(request);
        }

        /// <summary>Sends a LeaveSessionRequest contract to the server.</summary>
        public void LeaveSession()
        {
            SendContract(new LeaveSessionRequest());
        }

        /// <summary>
        /// Handles the base socket's Disconnected event.
        /// Raises DisconnectedSession only if the session had reached Connected, then resets the state to
        /// Disconnected.
        /// </summary>
        private void OnDisconnected(object sender, SocketDisconnectEventArgs e)
        {
            // Copy the delegate first so an unsubscribe on another thread between the null check
            // and the invocation cannot cause a NullReferenceException.
            var handler = DisconnectedSession;
            if (ConnectionState == ConnectionState.Connected && handler != null)
            {
                handler(this, e);
            }

            ConnectionState = ConnectionState.Disconnected;
        }

        /// <summary>
        /// Stops the keep-alive handshake thread, aborting it if it has not exited within 200 ms, and then disposes
        /// the underlying socket interface via base.Dispose().
        /// Returns immediately when _isDisposing is set; that field is not declared in this class and is presumably
        /// maintained by the base. Exceptions are logged, not thrown.
        /// </summary>
        public override void Dispose()
        {
            if (_isDisposing)
            {
                return;
            }

            try
            {
                var handshakeThread = _handshakeThread;
                if (handshakeThread != null)
                {
                    try
                    {
                        _stopHandshakeThread = true;

                        // Dispose may run on the handshake thread itself. For example, if the Disconnect call in
                        // DoHandshake synchronously raises DisconnectedSession and a subscriber disposes the client.
                        // Joining the current thread would stall for 200 ms and then Abort() this very call,
                        // skipping base.Dispose(). The stop flag alone lets the loop exit once this call returns.
                        if (handshakeThread != Thread.CurrentThread && !handshakeThread.Join(200))
                        {
                            Logger.Warn("Handshake thread failed to exit cleanly, aborting...");
                            handshakeThread.Abort();
                        }
                    }
                    catch (Exception ex)
                    {
                        Logger.Error(ex, "Failed to abort handshake thread!");
                    }
                }

                base.Dispose();
            }
            catch (Exception ex)
            {
                Logger.Error(ex, "Failed to cleanly dispose CurseVoiceClient!");
            }
        }

        // Set by Dispose to ask HandshakeThread to exit; volatile so the polling loop sees the update.
        private volatile bool _stopHandshakeThread;
        // Minimum time between keep-alive rounds sent by HandshakeThread.
        private readonly TimeSpan _handshakeInterval = TimeSpan.FromSeconds(10);

        /// <summary>
        /// Body of the background keep-alive thread started by TryConnect.
        /// While connected, it calls DoHandshake once per _handshakeInterval until Dispose sets
        /// _stopHandshakeThread. Errors are logged and the loop continues.
        /// </summary>
        [DebuggerStepThrough]
        private void HandshakeThread()
        {
            // Seeding lastHandshake 5 s in the future delays the first keep-alive until about
            // 5 s + _handshakeInterval (about 15 s) after the thread starts.
            // NOTE(review): an earlier comment claimed handshaking starts 5 s after connecting. If that was the
            // intent, seed with AddSeconds(-5) instead.
            var lastHandshake = DateTime.UtcNow.AddSeconds(5);

            while (!_stopHandshakeThread)
            {
                try
                {
                    // Poll the stop flag every 100 ms so Dispose's 200 ms Join normally succeeds without Abort().
                    Thread.Sleep(100);

                    if (DateTime.UtcNow - lastHandshake < _handshakeInterval)
                    {
                        continue;
                    }

                    if (!IsConnected)
                    {
                        continue;
                    }

                    lastHandshake = DateTime.UtcNow;
                    DoHandshake();
                }
                catch (ThreadAbortException)
                {
                    // Nothing to clean up; the runtime re-raises the abort at the end of this handler.
                }
                catch (Exception ex)
                {
                    Logger.Error(ex, "[VoiceClient] Handshake thread error");
                }
            }
        }

        /// <summary>
        /// Sends one keep-alive round: a UDP Handshake (per the original notes, for router hole punching and ping
        /// measurement), then a one-byte TestContract (the "TCP handshake" keep-alive).
        /// If either send throws, the error is logged and the client disconnects with HandshakeFailure.
        /// </summary>
        private void DoHandshake()
        {
            try
            {
                SendHandshake();

                var keepAlive = new TestContract { Message = new byte[1] };
                SendContract(keepAlive);
            }
            catch (Exception ex)
            {
                Logger.Error(ex, "Failed to send handshake! Disconnecting...");
                Disconnect(SocketDisconnectReason.HandshakeFailure);
            }
        }
    }
}
