﻿using Curse.LoadTests.Contracts;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Curse.Friends.Data;
using System.Security.Cryptography;
using System.Text;
using System;
using System.Collections.Concurrent;
using System.Configuration;
using System.Data.SqlClient;

using System.Data;
using System.Data.SQLite;
using System.Reflection;
using Aerospike.Client;
using Curse.Aerospike;
using System.Diagnostics;
using System.IO;
using Curse.LoadTests.Coordinator.Configuration;
using Curse.Logging;
using Newtonsoft.Json;

namespace Curse.LoadTests.Coordinator
{
    /// <summary>
    /// Loads a pool of existing users for load testing, resets their passwords in the SQL Server
    /// <c>users</c> table to <see cref="LoadTestConstants.DefaultPassword"/>, and writes the pool to a
    /// SQLite file in the Curse.LoadTests.Client output folder.
    /// </summary>
    internal class LoadTestUser
    {
        /// <summary>Length, in bytes, of the random salt generated for the default password.</summary>
        private const int SaltLength = 5;

        /// <summary>Users loaded so far, keyed by region identifier.</summary>
        private static readonly Dictionary<int, UserInfo[]> CachedUsers = new Dictionary<int, UserInfo[]>();

        /// <summary>
        /// The "Users" connection string, or <c>null</c> when it is not configured. The missing case is
        /// reported by <see cref="GetConnection"/> so that a configuration problem does not turn into a
        /// <see cref="TypeInitializationException"/> on every member of this class.
        /// </summary>
        private static readonly string ConnectionString;

        /// <summary>
        /// All cached user IDs across every region. Lazily evaluated over <see cref="CachedUsers"/>, so
        /// it reflects the cache contents at enumeration time.
        /// </summary>
        public static IEnumerable<int> UserIDs
        {
            get { return CachedUsers.SelectMany(p => p.Value).Select(p => p.UserID); }
        }

        static LoadTestUser()
        {
            var setting = ConfigurationManager.ConnectionStrings["Users"];
            ConnectionString = setting != null ? setting.ConnectionString : null;
        }

        /// <summary>
        /// Loads up to the requested number of users per region, resets their passwords to the default
        /// load-test password, and writes the resulting user list to the client's users file.
        /// Each step checks <paramref name="token"/> and skips its remaining work once cancellation is requested.
        /// </summary>
        /// <param name="numUsersByRegion">Users to load, keyed by region identifier; regions not present are skipped.</param>
        /// <param name="token">Cancellation token checked between steps.</param>
        /// <exception cref="ArgumentNullException"><paramref name="numUsersByRegion"/> is <c>null</c>.</exception>
        public static void EnsureUsers(Dictionary<int, int> numUsersByRegion, CancellationToken token)
        {
            if (numUsersByRegion == null)
            {
                throw new ArgumentNullException("numUsersByRegion");
            }

            LoadUsers(numUsersByRegion, token);
            EnsureUserDefaultPasswords(token);
            WriteUsersToFile(token);
        }

        /// <summary>
        /// Opens a new connection to the "Users" database. The caller owns (and must dispose) the result.
        /// </summary>
        /// <exception cref="ConfigurationErrorsException">The "Users" connection string is not configured.</exception>
        private static SqlConnection GetConnection()
        {
            if (string.IsNullOrEmpty(ConnectionString))
            {
                throw new ConfigurationErrorsException("The 'Users' connection string is not configured.");
            }

            var conn = new SqlConnection(ConnectionString);
            try
            {
                conn.Open();
            }
            catch (Exception ex)
            {
                // The caller never receives the connection on failure, so release it here.
                conn.Dispose();
                Console.WriteLine("Unable to open connection to database: " + ex.Message);
                throw;
            }

            return conn;
        }

        /// <summary>
        /// Sets the password of every cached user to <see cref="LoadTestConstants.DefaultPassword"/>,
        /// using one freshly generated salt shared by all of them.
        /// </summary>
        /// <remarks>
        /// The user IDs are bulk-copied into a session temp table and applied with a single set-based
        /// UPDATE rather than one statement per user.
        /// </remarks>
        private static void EnsureUserDefaultPasswords(CancellationToken token)
        {
            var sw = Stopwatch.StartNew();

            var salt = new byte[SaltLength];
            using (var rng = RandomNumberGenerator.Create())
            {
                rng.GetBytes(salt);
            }
            var passwordHash = GetPasswordHash(LoadTestConstants.DefaultPassword, salt);

            if (token.IsCancellationRequested)
            {
                return;
            }

            Console.WriteLine(@"Overriding passwords for selected users.");

            using (var userIdTable = new DataTable())
            {
                userIdTable.Columns.Add("ID", typeof (int));
                userIdTable.BeginLoadData();
                foreach (var userID in UserIDs)
                {
                    userIdTable.Rows.Add(userID);
                }
                userIdTable.EndLoadData();

                using (var conn = GetConnection())
                {
                    // #TempUserIDs is a session temp table: every statement below must reuse this connection.
                    using (var cmd = conn.CreateCommand())
                    {
                        cmd.CommandText = "CREATE TABLE [#TempUserIDs]([ID] [int] NOT NULL)";
                        cmd.ExecuteNonQuery();
                    }

                    using (var bulk = new SqlBulkCopy(conn, SqlBulkCopyOptions.TableLock, null))
                    {
                        bulk.DestinationTableName = "#TempUserIDs";
                        // No timeout, consistent with the UPDATE below; large user pools can exceed the 30s default.
                        bulk.BulkCopyTimeout = 0;
                        bulk.WriteToServer(userIdTable);
                    }

                    using (var cmd = conn.CreateCommand())
                    {
                        cmd.CommandTimeout = 0;
                        cmd.UpdatedRowSource = UpdateRowSource.None;
                        cmd.CommandText = "UPDATE users SET _password = @Password, _salt = @Salt from [users] inner join [#TempUserIDs] on [#TempUserIDs].[ID] = [users].[_uid]";
                        cmd.Parameters.AddWithValue("@Password", passwordHash);
                        cmd.Parameters.AddWithValue("@Salt", salt);
                        var rowsAffected = cmd.ExecuteNonQuery();
                        Console.WriteLine("Reset " + rowsAffected + " passwords");
                    }
                }
            }

            sw.Stop();
            Console.WriteLine(@"Password override completed in {0} seconds", sw.Elapsed.TotalSeconds);
        }

        /// <summary>
        /// Computes SHA-1 over the salt bytes followed by the UTF-8 bytes of <paramref name="password"/>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): presumably this must match the scheme used to verify logins against the
        /// <c>users</c> table, so the algorithm should not be changed here in isolation.
        /// </remarks>
        /// <param name="password">Plain-text password.</param>
        /// <param name="salt">Salt prepended to the password bytes; any length is supported.</param>
        /// <returns>The 20-byte SHA-1 digest.</returns>
        private static byte[] GetPasswordHash(string password, byte[] salt)
        {
            var passwordBytes = Encoding.UTF8.GetBytes(password);
            var combined = new byte[salt.Length + passwordBytes.Length];

            Buffer.BlockCopy(salt, 0, combined, 0, salt.Length);
            // Offset by the actual salt length so salts of any size produce a correctly laid-out buffer.
            Buffer.BlockCopy(passwordBytes, 0, combined, salt.Length, passwordBytes.Length);

            using (var sha1 = SHA1.Create())
            {
                return sha1.ComputeHash(combined);
            }
        }

        /// <summary>
        /// Scans the local region and every remote region for users and caches up to the requested count
        /// per region in <see cref="CachedUsers"/>. Stops between regions once cancellation is requested.
        /// </summary>
        /// <exception cref="InvalidOperationException">The <c>User</c> model has no <see cref="TableDefinitionAttribute"/>.</exception>
        private static void LoadUsers(Dictionary<int, int> numUsersByRegion, CancellationToken token)
        {
            var tableDefinition = typeof (User).GetCustomAttribute<TableDefinitionAttribute>();
            if (tableDefinition == null)
            {
                throw new InvalidOperationException("The User model is missing its TableDefinitionAttribute.");
            }

            var regionList = new List<AerospikeConfiguration> {User.LocalConfiguration};
            regionList.AddRange(User.RemoteConfigurations);

            foreach (var configuration in regionList)
            {
                if (token.IsCancellationRequested)
                {
                    break;
                }

                int numUsers;
                if (!numUsersByRegion.TryGetValue(configuration.RegionIdentifier, out numUsers))
                {
                    continue;
                }

                Console.WriteLine(@"Loading Users for Region {0}", configuration.RegionKey);
                var watch = Stopwatch.StartNew();
                var users = GetSomeUsers(configuration, numUsers, tableDefinition);
                CachedUsers[configuration.RegionIdentifier] = users;
                watch.Stop();

                // Report what was actually loaded: a region can hold fewer users than requested.
                Console.WriteLine(@"{0} users loaded from Region {1} in {2} seconds.", users.Length,
                    configuration.RegionKey, watch.Elapsed.TotalSeconds);
            }

            // Alternative source (SQL Server instead of Aerospike) for region 1:
            //CachedUsers[1] = GetUsersFromSql(numUsersByRegion[1]);
        }

        /// <summary>
        /// Alternative to the Aerospike scan: reads the first <paramref name="numUsers"/> rows of the
        /// SQL Server <c>users</c> table. Only referenced from a commented-out line in <see cref="LoadUsers"/>.
        /// </summary>
        private static UserInfo[] GetUsersFromSql(int numUsers)
        {
            var users = new List<UserInfo>();
            using (var conn = GetConnection())
            using (var cmd = conn.CreateCommand())
            {
                cmd.CommandTimeout = 0;
                cmd.UpdatedRowSource = UpdateRowSource.None;
                cmd.CommandText = "SELECT TOP (@NumUsers) * FROM [users]";
                cmd.Parameters.AddWithValue("@NumUsers", numUsers);
                using (var reader = cmd.ExecuteReader())
                {
                    while (reader.Read())
                    {
                        // NOTE(review): assumes column 0 is the integer user ID and column 1 the username;
                        // selecting columns by name would be sturdier, but the username column name is not visible here.
                        users.Add(new UserInfo(reader.GetInt32(0), reader.GetString(1)));
                    }
                }
            }
            return users.ToArray();
        }

        /// <summary>
        /// Scans the user set of the region described by <paramref name="config"/> and returns up to
        /// <paramref name="numUsers"/> users whose records hydrate without errors.
        /// </summary>
        /// <remarks>
        /// The scan is stopped early by throwing <see cref="AerospikeException.ScanTerminated"/> from the
        /// callback once enough users have been collected; that exception is swallowed here.
        /// </remarks>
        private static UserInfo[] GetSomeUsers(AerospikeConfiguration config, int numUsers,
            TableDefinitionAttribute tableDefinition)
        {
            // TODO: Randomize at least a little bit?

            // The callback is written to tolerate concurrent invocation, hence the ConcurrentBag/Interlocked.
            var users = new ConcurrentBag<User>();
            var currentIndex = 0;

            // Keyspace names are identifiers: lower-case them culture-independently.
            var baseKeySpace = (tableDefinition.KeySpace ?? AerospikeConfiguration.DefaultKeySpace).ToLowerInvariant();
            var keyspace = tableDefinition.ReplicationMode == ReplicationMode.None
                ? baseKeySpace + "-" + config.RegionGroup.ToLowerInvariant()
                : baseKeySpace;

            try
            {
                config.Client.ScanAll(null, keyspace, tableDefinition.TableName, (key, record) =>
                {
                    var model = new User();
                    model.Hydrate(config.Client, key, record, config.RegionIdentifier);
                    if (model.HasHydrationErrors)
                    {
                        return;
                    }

                    var num = Interlocked.Increment(ref currentIndex);

                    if (num <= numUsers)
                    {
                        users.Add(model);
                    }

                    if (users.Count >= numUsers)
                    {
                        throw new AerospikeException.ScanTerminated();
                    }
                });
            }
            catch (AerospikeException ex)
            {
                if (!IsScanTerminated(ex))
                {
                    throw;
                }
            }
            return users.Select(u => new UserInfo(u.UserID, u.Username)).ToArray();
        }

        /// <summary>
        /// Returns <c>true</c> when <paramref name="ex"/> is, or directly wraps, a SCAN_TERMINATED result,
        /// i.e. the early-exit signal raised by the <see cref="GetSomeUsers"/> callback.
        /// </summary>
        private static bool IsScanTerminated(AerospikeException ex)
        {
            if (ex.Result == ResultCode.SCAN_TERMINATED)
            {
                return true;
            }

            // Null-checked so a non-Aerospike inner exception propagates as the real failure
            // instead of causing a NullReferenceException here.
            var inner = ex.InnerException as AerospikeException;
            return inner != null && inner.Result == ResultCode.SCAN_TERMINATED;
        }

        /// <summary>
        /// Returns the usernames cached for <paramref name="region"/>, paged by skip/take, or an empty
        /// array when the region has no cached users.
        /// </summary>
        public static string[] GetUsersInRegion(int region, int numberOfRecords, int numberToSkip)
        {
            UserInfo[] users;
            if (!CachedUsers.TryGetValue(region, out users))
            {
                return new string[0];
            }

            return users.Skip(numberToSkip).Take(numberOfRecords).Select(u => u.Username).ToArray();
        }

        /// <summary>
        /// Recreates the SQLite users file in the Curse.LoadTests.Client output folder for the current
        /// build mode, with one row per cached user (region, per-region row index, user ID, username).
        /// Does nothing once cancellation is requested.
        /// </summary>
        private static void WriteUsersToFile(CancellationToken token)
        {
            if (token.IsCancellationRequested)
            {
                return;
            }
            Console.WriteLine("writing users and regions to file");

            // Assumes the layout <root>/<project>/bin/<mode>/<assembly>: walk up from the assembly's
            // folder to <root>, then into the client project's output folder for the same mode.
            var assemblyDirectory = new FileInfo(Assembly.GetExecutingAssembly().Location).Directory;
            var loadTestingDirectory = assemblyDirectory.Parent.Parent.Parent;
            var clientBinDirectory = new DirectoryInfo(Path.Combine(loadTestingDirectory.FullName, "Curse.LoadTests.Client", "bin", TestCoordinatorConfiguration.Mode.ToString()));

            // SQLite does not create missing directories (e.g. when the client has not been built yet).
            clientBinDirectory.Create();
            var filename = Path.Combine(clientBinDirectory.FullName, LoadTestConstants.UsersFileName);

            if (File.Exists(filename))
            {
                File.Delete(filename);
            }

            using (var connection = new SQLiteConnection(string.Format(@"Data Source={0};", filename)))
            {
                connection.Open();

                using (var command = connection.CreateCommand())
                {
                    command.CommandText = "CREATE TABLE User (regionID INTEGER NOT NULL, rowIndex INTEGER NOT NULL, userID INTEGER UNIQUE NOT NULL, username VARCHAR(128) UNIQUE NOT NULL)";
                    command.ExecuteNonQuery();
                }

                using (var transaction = connection.BeginTransaction())
                {
                    using (var command = connection.CreateCommand())
                    {
                        command.Transaction = transaction;
                        command.CommandText = "INSERT INTO User (regionID, rowIndex, userID, username) VALUES (@regionID, @rowIndex, @userID, @username)";

                        // Create the parameters once and only update their values for each row.
                        var regionParameter = command.Parameters.Add("regionID", DbType.Int32);
                        var rowIndexParameter = command.Parameters.Add("rowIndex", DbType.Int32);
                        var userIdParameter = command.Parameters.Add("userID", DbType.Int32);
                        var usernameParameter = command.Parameters.Add("username", DbType.String);

                        foreach (var region in CachedUsers)
                        {
                            var rowIndex = 0;
                            foreach (var user in region.Value)
                            {
                                regionParameter.Value = region.Key;
                                rowIndexParameter.Value = rowIndex++;
                                userIdParameter.Value = user.UserID;
                                usernameParameter.Value = user.Username;
                                command.ExecuteNonQuery();
                            }
                        }
                    }

                    transaction.Commit();
                }
            }
        }

        /// <summary>Immutable ID/username pair for a cached load-test user.</summary>
        private sealed class UserInfo
        {
            public int UserID { get; private set; }
            public string Username { get; private set; }

            public UserInfo(int userID, string username)
            {
                UserID = userID;
                Username = username;
            }
        }
    }
}
