﻿using System;
using System.IO;
using System.Security.Cryptography;
using Curse.Logging;

namespace Curse.SocketInterface
{
    /// <summary>
    /// Static helpers for an RSA-based key exchange, for encrypting and decrypting
    /// messages with a <see cref="SymmetricAlgorithm"/>, and for serializing RSA public
    /// keys in a length-prefixed binary format that starts with an 'ssh-rsa' header.
    /// </summary>
    public static class CryptoHelper
    {
        #region RSA Key Exchange

        // XML form of the private key (as produced by ToXmlString(true)), cached after the
        // first successful load or generation so later providers reuse the same key.
        private static string _xmlPrivateKey;

        /// <summary>
        /// Creates an RSA provider, reusing a previously cached key when one is available.
        /// </summary>
        /// <remarks>
        /// The key is looked up in this order: the in-memory cache, then the file at
        /// <paramref name="privateKeyPath"/>, and finally a newly generated key, which is
        /// cached and (best effort) written to <paramref name="privateKeyPath"/>.
        /// Failures while loading or saving are logged and do not propagate.
        /// </remarks>
        /// <param name="privateKeyPath">
        /// Optional path of a file holding the private key in XML form; when null, no file
        /// is read or written.
        /// </param>
        /// <returns>A provider initialised with the cached, loaded or newly generated key.</returns>
        public static RSACryptoServiceProvider NewRsaProvider(string privateKeyPath = null)
        {
            var rsa = new RSACryptoServiceProvider();

            // Try to load an existing key
            if (_xmlPrivateKey != null)
            {
                try
                {
                    rsa.FromXmlString(_xmlPrivateKey);
                    return rsa;
                }
                catch (Exception ex)
                {
                    // The cached key is unusable; drop it so it is not retried on every call
                    Logger.Error(ex);
                    _xmlPrivateKey = null;
                }
            }

            // No (valid) key loaded, try reading from disk
            if (privateKeyPath != null && File.Exists(privateKeyPath))
            {
                try
                {
                    var key = File.ReadAllText(privateKeyPath);
                    rsa.FromXmlString(key);

                    // Only cache the key once the provider has accepted it
                    _xmlPrivateKey = key;
                    return rsa;
                }
                catch (Exception ex)
                {
                    Logger.Error(ex);
                }
            }

            // Generate a new key and try to save it
            _xmlPrivateKey = rsa.ToXmlString(true);
            if (privateKeyPath != null)
            {
                try
                {
                    File.WriteAllText(privateKeyPath, _xmlPrivateKey);
                }
                catch (Exception ex)
                {
                    // Saving is best effort: the key stays usable from the in-memory cache
                    Logger.Error(ex);
                }
            }
            return rsa;
        }

        /// <summary>
        /// Generates a fresh key and IV on <paramref name="algo"/> and returns the key
        /// encrypted with the given RSA public key.
        /// </summary>
        /// <param name="algo">Symmetric algorithm that receives the new key and IV.</param>
        /// <param name="publicKey">RSA public key in the format written by <see cref="ExportPublicKey"/>.</param>
        /// <returns>The new symmetric key, RSA-encrypted (OAEP disabled).</returns>
        /// <exception cref="ArgumentNullException"><paramref name="algo"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="publicKey"/> is null or malformed.</exception>
        public static byte[] CreateSecret(SymmetricAlgorithm algo, byte[] publicKey)
        {
            if (algo == null)
            {
                throw new ArgumentNullException("algo");
            }

            // Generate a secret key (and a starting IV) for the symmetric algorithm
            algo.GenerateKey();
            algo.GenerateIV();
            var secret = algo.Key;

            using (var rsa = new RSACryptoServiceProvider())
            {
                rsa.ImportPublicKey(publicKey);
                return rsa.Encrypt(secret, false);
            }
        }

        /// <summary>
        /// Decrypts a secret produced by <see cref="CreateSecret"/> and installs it as the
        /// key of <paramref name="algo"/>.
        /// </summary>
        /// <param name="algo">Symmetric algorithm whose key is replaced.</param>
        /// <param name="rsa">RSA provider holding the private key.</param>
        /// <param name="data">The RSA-encrypted symmetric key.</param>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static void SetSecret(SymmetricAlgorithm algo, RSACryptoServiceProvider rsa, byte[] data)
        {
            if (algo == null)
            {
                throw new ArgumentNullException("algo");
            }
            if (rsa == null)
            {
                throw new ArgumentNullException("rsa");
            }
            if (data == null)
            {
                throw new ArgumentNullException("data");
            }

            algo.Key = rsa.Decrypt(data, false);
        }

        #endregion RSA Key Exchange

        #region SymmetricAlgorithm

        // Buffer size used by ReadToEnd when the caller does not supply a positive size hint.
        private const int DefaultReadBufferSize = 256;

        /// <summary>
        /// Encrypts <paramref name="message"/> and returns the current IV followed by the
        /// cipher text.
        /// </summary>
        /// <remarks>
        /// After encrypting, the last IV-sized chunk of the output becomes the new
        /// <see cref="SymmetricAlgorithm.IV"/> of <paramref name="algo"/>, so consecutive
        /// messages use different IVs. The whole operation runs under a lock on
        /// <paramref name="algo"/>.
        /// </remarks>
        /// <param name="algo">Algorithm providing the key and the current IV.</param>
        /// <param name="message">Plain text to encrypt.</param>
        /// <returns>IV followed by the cipher text.</returns>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public static byte[] Encrypt(this SymmetricAlgorithm algo, byte[] message)
        {
            if (algo == null)
            {
                throw new ArgumentNullException("algo");
            }
            if (message == null)
            {
                throw new ArgumentNullException("message");
            }

            lock (algo)
            {
                // Calculate the encrypted size to save a copy at the end:
                // one block for the IV plus the message rounded up to a whole extra block
                var iv = algo.IV;
                var size = (message.Length + iv.Length*2)/iv.Length*iv.Length;

                using (var memoryStream = new MemoryStream(size))
                {
                    // Include the IV at the start
                    memoryStream.Write(iv, 0, iv.Length);

                    // Generate the cypher text
                    using (var encryptor = algo.CreateEncryptor())
                    using (var cryptoStream = new CryptoStream(memoryStream, encryptor, CryptoStreamMode.Write))
                    {
                        cryptoStream.Write(message, 0, message.Length);
                        cryptoStream.FlushFinalBlock();

                        // Get the output, avoiding a copy if our calculation was correct
                        var result = memoryStream.GetBuffer();
                        if (result.Length != memoryStream.Length)
                        {
                            Array.Resize(ref result, (int) memoryStream.Length);
                        }

                        // Use the last block as the next IV
                        Array.Copy(result, result.Length - iv.Length, iv, 0, iv.Length);
                        algo.IV = iv;

                        return result;
                    }
                }
            }
        }

        /// <summary>
        /// Decrypts a message produced by <see cref="Encrypt"/>: the leading block is taken
        /// as the IV and the remainder as the cipher text.
        /// </summary>
        /// <remarks>
        /// The IV read from the message is assigned to <paramref name="algo"/>. The whole
        /// operation runs under a lock on <paramref name="algo"/>.
        /// </remarks>
        /// <param name="algo">Algorithm providing the key.</param>
        /// <param name="message">IV followed by the cipher text.</param>
        /// <returns>The decrypted plain text.</returns>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        /// <exception cref="CryptographicException">
        /// <paramref name="message"/> is too short to contain an IV, or decryption fails.
        /// </exception>
        public static byte[] Decrypt(this SymmetricAlgorithm algo, byte[] message)
        {
            if (algo == null)
            {
                throw new ArgumentNullException("algo");
            }
            if (message == null)
            {
                throw new ArgumentNullException("message");
            }

            lock (algo)
            {
                // Reject truncated input up front; otherwise a partially filled (zero padded)
                // IV would silently be installed on the algorithm and used for decryption
                var ivLength = algo.BlockSize/8;
                if (message.Length < ivLength)
                {
                    throw new CryptographicException("message is too short to contain an IV");
                }

                using (var memoryStream = new MemoryStream(message))
                {
                    // Read in the IV (always complete, the length was validated above)
                    var iv = new byte[ivLength];
                    memoryStream.Read(iv, 0, iv.Length);
                    algo.IV = iv;

                    // Generate the plain text
                    using (var decryptor = algo.CreateDecryptor())
                    using (var cryptoStream = new CryptoStream(memoryStream, decryptor, CryptoStreamMode.Read))
                    {
                        // The plain text cannot be longer than the message, so this buffer never has to grow
                        return ReadToEnd(cryptoStream, message.Length);
                    }
                }
            }
        }

        /// <summary>
        /// Reads <paramref name="stream"/> until it reports end of stream and returns the
        /// bytes read, trimmed to the exact length.
        /// </summary>
        /// <param name="stream">Stream to drain.</param>
        /// <param name="bufferSize">
        /// Initial buffer size; the buffer doubles whenever it fills up. Non-positive values
        /// fall back to a default size.
        /// </param>
        private static byte[] ReadToEnd(Stream stream, int bufferSize)
        {
            // Reading into an empty buffer returns 0, which the loop below would mistake
            // for end of stream and drop all data; always start with some capacity
            var buffer = new byte[bufferSize > 0 ? bufferSize : DefaultReadBufferSize];
            var offset = 0;

            int read;
            while ((read = stream.Read(buffer, offset, buffer.Length - offset)) != 0)
            {
                offset += read;
                if (offset >= buffer.Length)
                {
                    // Buffer is full: grow it so the next Read has room to report progress
                    Array.Resize(ref buffer, buffer.Length * 2);
                }
            }

            // Trim the unused tail
            if (offset < buffer.Length)
            {
                Array.Resize(ref buffer, offset);
            }

            return buffer;
        }

        #endregion SymmetricAlgorithm

        #region RSA Extensions

        // Header of a serialized key: big-endian length 7 followed by the ASCII bytes "ssh-rsa"
        private static readonly byte[] SshRsaPrefix = { 0x00, 0x00, 0x00, 0x07, 0x73, 0x73, 0x68, 0x2d, 0x72, 0x73, 0x61 };

        /// <summary>
        /// Serializes the public part of <paramref name="rsa"/> as the 'ssh-rsa' header
        /// followed by the length-prefixed exponent and the length-prefixed modulus.
        /// </summary>
        /// <param name="rsa">Provider whose public key is exported.</param>
        /// <returns>The serialized public key.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="rsa"/> is null.</exception>
        public static byte[] ExportPublicKey(this RSACryptoServiceProvider rsa)
        {
            if (rsa == null)
            {
                throw new ArgumentNullException("rsa");
            }

            var info = rsa.ExportParameters(false);

            // header + (4-byte length + exponent) + (4-byte length + modulus)
            var data = new byte[SshRsaPrefix.Length + 4 + info.Exponent.Length + 4 + info.Modulus.Length];
            Array.Copy(SshRsaPrefix, data, SshRsaPrefix.Length);

            var offset = SshRsaPrefix.Length;
            WriteWithPrefix(info.Exponent, data, ref offset);
            WriteWithPrefix(info.Modulus, data, ref offset);

            return data;
        }

        /// <summary>
        /// Writes the big-endian length of <paramref name="source"/> followed by its bytes
        /// into <paramref name="destination"/> and advances <paramref name="offset"/> past them.
        /// </summary>
        private static void WriteWithPrefix(byte[] source, byte[] destination, ref int offset)
        {
            WriteBigEndian(source.Length, destination, offset);
            Array.Copy(source, 0, destination, offset + 4, source.Length);

            offset += 4 + source.Length;
        }

        /// <summary>
        /// Writes <paramref name="value"/> as four bytes, most significant first, at
        /// <paramref name="offset"/> in <paramref name="destination"/>.
        /// </summary>
        private static void WriteBigEndian(int value, byte[] destination, int offset)
        {
            destination[offset + 0] = (byte)((value >> 24) & 0xFF);
            destination[offset + 1] = (byte)((value >> 16) & 0xFF);
            destination[offset + 2] = (byte)((value >> 8) & 0xFF);
            destination[offset + 3] = (byte)(value & 0xFF);
        }

        /// <summary>
        /// Loads a public key in the format written by <see cref="ExportPublicKey"/> into
        /// <paramref name="rsa"/>.
        /// </summary>
        /// <exception cref="ArgumentException">
        /// <paramref name="publicKey"/> is null, lacks the 'ssh-rsa' header, or its
        /// length-prefixed fields do not fit inside the array.
        /// </exception>
        private static void ImportPublicKey(this RSACryptoServiceProvider rsa, byte[] publicKey)
        {
            if (!CheckHeader(publicKey))
            {
                throw new ArgumentException("missing 'ssh-rsa' header", "publicKey");
            }

            // Fields follow the header in the same order ExportPublicKey writes them
            var offset = SshRsaPrefix.Length;
            var e = ReadWithPrefix(publicKey, ref offset);
            var m = ReadWithPrefix(publicKey, ref offset);

            if (e == null || m == null)
            {
                throw new ArgumentException("invalid key format", "publicKey");
            }

            rsa.ImportParameters(new RSAParameters
            {
                Exponent = e,
                Modulus = m,
            });
        }

        /// <summary>
        /// Returns true when <paramref name="publicKey"/> is non-null and starts with the
        /// 'ssh-rsa' header bytes.
        /// </summary>
        private static bool CheckHeader(byte[] publicKey)
        {
            if (publicKey == null || publicKey.Length < SshRsaPrefix.Length)
            {
                return false;
            }

            for (var i = 0; i < SshRsaPrefix.Length; i++)
            {
                if (publicKey[i] != SshRsaPrefix[i])
                {
                    return false;
                }
            }

            return true;
        }

        /// <summary>
        /// Reads a four-byte big-endian length followed by that many bytes, starting at
        /// <paramref name="offset"/>, and advances <paramref name="offset"/> past what was consumed.
        /// </summary>
        /// <returns>
        /// The field's bytes, or null when the length prefix or the announced data does
        /// not fit inside <paramref name="source"/>.
        /// </returns>
        private static byte[] ReadWithPrefix(byte[] source, ref int offset)
        {
            if (source.Length - offset < 4)
            {
                return null;
            }
            var length = ReadBigEndian(source, offset);
            offset += 4;

            // The length comes from untrusted input. Compare it against the remaining byte
            // count instead of computing offset + length, which wraps around for values
            // close to int.MaxValue and would let an oversized length through to the
            // allocation below.
            if (length < 0 || length > source.Length - offset)
            {
                return null;
            }
            var data = new byte[length];
            Array.Copy(source, offset, data, 0, length);
            offset += length;

            return data;
        }

        /// <summary>
        /// Reads four bytes at <paramref name="offset"/> as a big-endian 32-bit integer.
        /// </summary>
        private static int ReadBigEndian(byte[] source, int offset)
        {
            return (source[offset + 0] << 24) |
                   (source[offset + 1] << 16) |
                   (source[offset + 2] << 8) |
                   (source[offset + 3]);
        }

        #endregion RSA Extensions
    }
}
