﻿using Curse.CloudServices.Models;
using System;
using System.ServiceModel;
using System.ServiceModel.Channels;
using System.ServiceModel.Description;
using System.ServiceModel.Dispatcher;
using Curse.Logging;
using Curse.ServiceEncryption;

namespace Curse.CloudServices.Authentication
{
    /// <summary>
    /// Access levels that can be demanded by <c>RequiresAuthenticationAttribute</c>.
    /// Declared as a bit-flag set (one bit per member), so values may be combined with <c>|</c>.
    /// </summary>
    [Flags]
    public enum AccessLevel
    {
        /// <summary>
        /// No bits set: the operation requires no credentials and the attribute lets the call through.
        /// </summary>
        Anonymous = 0,

        /// <summary>Caller must present a valid, unexpired session token.</summary>
        Authenticated = 1 << 0,

        /// <summary>
        /// NOTE(review): presumably premium access granted through a promotion. The attribute applies
        /// only session-token validation for this level; confirm where the premium check happens.
        /// </summary>
        PremiumPromotion = 1 << 1,

        /// <summary>
        /// NOTE(review): presumably premium access through a paid subscription. The attribute applies
        /// only session-token validation for this level; confirm where the premium check happens.
        /// </summary>
        PremiumSubscription = 1 << 2,

        /// <summary>
        /// NOTE(review): moderator role. The attribute applies only session-token validation for this
        /// level; no role-membership check is visible here, so confirm where it is enforced.
        /// </summary>
        Moderator = 1 << 3,

        /// <summary>
        /// NOTE(review): administrator role. The attribute applies only session-token validation for
        /// this level; no role-membership check is visible here, so confirm where it is enforced.
        /// </summary>
        Administrator = 1 << 4,

        /// <summary>
        /// When used on its own, the caller must supply the configured shared API key instead of a
        /// session token.
        /// </summary>
        ApiKey = 1 << 5
    }

    /// <summary>
    /// WCF operation behavior that checks the caller's credentials before a service operation runs.
    /// When applied to an operation, it registers itself as a parameter inspector, so
    /// <see cref="BeforeCall"/> executes ahead of every dispatch of that operation.
    /// </summary>
    /// <remarks>
    /// Failures reach the client as a <see cref="FaultException"/> with fault code
    /// "FailedAuthentication". The reason text is the numeric value of the failing
    /// <c>AuthenticationStatus</c>.
    /// </remarks>
    public class RequiresAuthenticationAttribute : Attribute, IOperationBehavior, IParameterInspector
    {
        private static bool _levelsDisabled = false;

        /// <summary>
        /// Turns off access-level enforcement for every instance of this attribute. After this call,
        /// <see cref="BeforeCall"/> lets all requests through. Nothing in this class re-enables it.
        /// </summary>
        public static void DisableLevels()
        {
            _levelsDisabled = true;
        }

        private readonly AccessLevel _accessLevels;

        /// <param name="roles">Access level(s) required to invoke the decorated operation.</param>
        public RequiresAuthenticationAttribute(AccessLevel roles)
        {
            _accessLevels = roles;
        }

        #region IOperationBehavior Members

        public void AddBindingParameters(OperationDescription operationDescription, BindingParameterCollection bindingParameters) { }

        public void ApplyClientBehavior(OperationDescription operationDescription, ClientOperation clientOperation) { }

        /// <summary>Hooks this attribute into the dispatcher so <see cref="BeforeCall"/> runs server-side.</summary>
        public void ApplyDispatchBehavior(OperationDescription operationDescription, DispatchOperation dispatchOperation)
        {
            dispatchOperation.ParameterInspectors.Add(this);
        }

        public void Validate(OperationDescription operationDescription) { }

        #endregion

        #region Authentication Failure Methods

        /// <summary>
        /// Builds the SOAP fault returned to the client when authentication fails.
        /// The fault is returned rather than thrown, so call sites read
        /// <c>throw CreateAuthenticationFailure(...)</c> and both the compiler and the reader
        /// can see that execution stops there.
        /// </summary>
        /// <param name="status">Failure reason; its numeric value becomes the fault reason text.</param>
        private static FaultException CreateAuthenticationFailure(AuthenticationStatus status)
        {
            var code = new FaultCode("FailedAuthentication");
            var reason = new FaultReason(((int)status).ToString());
            return new FaultException(reason, code);
        }

        #endregion

        #region Credential Helpers

        /// <summary>
        /// Ordinal comparison of a caller-supplied API key against the configured key. It never exits
        /// early on the first differing character, so response timing does not reveal how much of a
        /// guessed key was correct.
        /// </summary>
        /// <returns><c>false</c> when either key is null or empty, or when they differ.</returns>
        private static bool ApiKeysMatch(string supplied, string expected)
        {
            if (string.IsNullOrEmpty(supplied) || string.IsNullOrEmpty(expected))
            {
                return false;
            }

            // A length mismatch is folded into the accumulator instead of short-circuiting. The loop
            // always walks the whole supplied key, and the expected index wraps so it never goes out
            // of range. When the lengths are equal, i % expected.Length == i.
            var difference = supplied.Length ^ expected.Length;
            for (var i = 0; i < supplied.Length; i++)
            {
                difference |= supplied[i] ^ expected[i % expected.Length];
            }

            return difference == 0;
        }

        /// <summary>
        /// Decrypts a raw session token into its user ID, username and timestamp.
        /// </summary>
        /// <exception cref="FaultException">
        /// <c>InvalidSession</c> when the token cannot be decrypted.
        /// </exception>
        private static AuthenticationTokenData DecryptToken(string rawToken)
        {
            var token = EncryptionToken.FromValue(rawToken);

            if (!token.IsValid)
            {
                if (!string.IsNullOrWhiteSpace(rawToken))
                {
                    // Log only the length. The raw value is a bearer credential and must not be
                    // written to logs.
                    Logger.Warn("Failed authentication due to an invalid token.", new { TokenLength = rawToken.Length });
                }
                throw CreateAuthenticationFailure(AuthenticationStatus.InvalidSession);
            }

            return new AuthenticationTokenData
            {
                UserID = token.GetInteger("UserID"),
                Username = token.GetValue("Username"),
                Timestamp = token.GetDateTime("Timestamp"),
            };
        }

        #endregion

        #region IParameterInspector Members

        private static readonly LogCategory Logger = new LogCategory("RequiresAuthenticationAttribute") { Throttle = TimeSpan.FromMinutes(10) };

        /// <summary>
        /// Validates the current request's credentials against this attribute's required access level.
        /// On success, copies the token's username onto the authentication context and, if token
        /// persistence is enabled, stores a freshly decrypted token for reuse.
        /// </summary>
        /// <returns>Always <c>null</c>. No correlation state is passed to <see cref="AfterCall"/>.</returns>
        /// <exception cref="FaultException">
        /// Reason <c>InvalidApiKey</c> for a bad API key. Reason <c>InvalidSession</c> for a missing,
        /// undecryptable, incomplete or user-mismatched session token. Reason <c>ExpiredSession</c>
        /// for a token older than the configured lifespan.
        /// </exception>
        public object BeforeCall(string operationName, object[] inputs)
        {
            var authContext = AuthenticationContext.Current;

            // Enforcement is switched off globally, or the operation is open to anonymous callers.
            if (_levelsDisabled || _accessLevels == AccessLevel.Anonymous)
            {
                return null;
            }

            // Only an operation marked with exactly AccessLevel.ApiKey is gated by the shared API key.
            // Combinations such as ApiKey | Authenticated fall through to session-token validation.
            // NOTE(review): confirm that combined flags are meant to require a session, not a key.
            if (_accessLevels == AccessLevel.ApiKey)
            {
                if (!ApiKeysMatch(authContext.ApiKey, AuthenticationConfiguration.ApiKey))
                {
                    Logger.Warn("Attempt to access API method with an invalid API key.");
                    throw CreateAuthenticationFailure(AuthenticationStatus.InvalidApiKey);
                }

                return null;
            }

            // Every other level requires a session token.
            if (authContext.IsAnonymous || authContext.Token == null)
            {
                Logger.Warn("Attempt to access an access restricted endpoint with a missing or anonymous token.", new { Anonymous = authContext.IsAnonymous, Null = authContext.Token == null });
                throw CreateAuthenticationFailure(AuthenticationStatus.InvalidSession);
            }

            // Use the persisted copy of a previously validated token if there is one. Otherwise
            // decrypt the token, and persist it once it passes the checks below.
            AuthenticationTokenData tokenData = null;
            if (AuthenticationConfiguration.HasTokenPersistence)
            {
                tokenData = AuthenticationConfiguration.AuthTokenStorageProvider.FindToken(authContext.Token);
            }

            var storeToken = tokenData == null;
            if (storeToken)
            {
                tokenData = DecryptToken(authContext.Token);
            }

            if (tokenData.UserID == 0 || string.IsNullOrEmpty(tokenData.Username))
            {
                Logger.Warn("Failed authentication due to invalid token values.", tokenData);
                throw CreateAuthenticationFailure(AuthenticationStatus.InvalidSession);
            }

            // The token must belong to the user the request claims to act as.
            if (authContext.UserID != tokenData.UserID)
            {
                Logger.Warn("Failed authentication due to a mismatch in token user ID");
                throw CreateAuthenticationFailure(AuthenticationStatus.InvalidSession);
            }

            // Expiry is checked on cached tokens as well, because the cache holds the original timestamp.
            if (DateTime.UtcNow - tokenData.Timestamp > AuthenticationConfiguration.TokenLifespan)
            {
                Logger.Warn("Failed authentication due to an expired token.");
                throw CreateAuthenticationFailure(AuthenticationStatus.ExpiredSession);
            }

            authContext.Username = tokenData.Username;

            if (AuthenticationConfiguration.HasTokenPersistence && storeToken)
            {
                AuthenticationConfiguration.AuthTokenStorageProvider.StoreToken(authContext.Token, tokenData);
            }

            return null;
        }

        public void AfterCall(string operationName, object[] outputs, object returnValue, object correlationState) { }

        #endregion
    }
}