import re
import time
from logging import Logger
from typing import Any, Dict, List, Optional, Tuple

import attr
import boto3
from botocore.exceptions import ClientError
from sagemaker.exceptions import UnexpectedStatusException
from sagemaker.session import Session, production_variant
from sagemaker.vpc_utils import sanitize

from octarine.clients.TwitchVXModelRegistryTwirp import (
    DEPLOYED,
    DeploymentConfig,
    Model,
    SageMakerDeploymentConfig,
    TwitchVXModelRegistryLambdaClient,
    UpdateModelRequest,
)

from .logger import _get_module_logger

# Maximum lengths SageMaker accepts for the model name, the endpoint config
# name and the endpoint name. These values are fixed by the SageMaker API; see
# the API docs, e.g.:
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpointConfig.html#sagemaker-CreateEndpointConfig-request-EndpointConfigName
MODEL_NAME_MAX_LENGTH = 63
ENDPOINT_CONFIG_NAME_MAX_LENGTH = 63
ENDPOINT_NAME_MAX_LENGTH = 63


class EndpointNotReady(Exception):
    """Raised when a SageMaker endpoint exists but cannot be updated right now.

    Attributes:
        endpoint_name (str): name of the SageMaker endpoint that is not ready
    """

    def __init__(self, endpoint_name: str):
        self.endpoint_name = endpoint_name
        super().__init__(
            f"the SageMaker endpoint {endpoint_name} is not ready to be updated"
        )


def _fits_length(_instance, _attribute, value):
    """attrs validator rejecting SageMaker names that are too long.

    The limit and the error label are chosen from the validated attribute's
    name. ``model_name`` is checked against MODEL_NAME_MAX_LENGTH and reported
    as a "model name"; every other attribute (e.g. ``endpoint_name``) is checked
    against ENDPOINT_NAME_MAX_LENGTH and labelled after the attribute, so the
    message for ``endpoint_name`` remains "endpoint name exceeds length
    limitation".

    Arguments:
        _instance: the attrs instance being validated (unused)
        _attribute: the attrs ``Attribute`` being validated; only ``.name`` is read
        value (str): the candidate name

    Raises:
        ValueError: if ``value`` is longer than the applicable limit
    """
    attribute_name = getattr(_attribute, "name", None) or "endpoint_name"
    if attribute_name == "model_name":
        max_length = MODEL_NAME_MAX_LENGTH
    else:
        max_length = ENDPOINT_NAME_MAX_LENGTH
    if len(value) > max_length:
        label = attribute_name.replace("_", " ")
        raise ValueError(f"{label} exceeds length limitation")


def _autoscaling_policy_name(endpoint_name: str) -> str:
    return f"{endpoint_name}-scaling-policy"


def _autoscaling_resource_id(endpoint_name: str, variant_name: str) -> str:
    return f"endpoint/{endpoint_name}/variant/{variant_name}"


@attr.s(kw_only=True, auto_attribs=True, frozen=True)
class AutoScalingConfig:
    """Immutable autoscaling settings for a SageMaker endpoint variant.

    Scaling is driven by a target-tracking policy on the predefined
    ``SageMakerVariantInvocationsPerInstance`` metric, using
    ``num_invocations_per_instance`` as the target value.

    Attributes:
        num_invocations_per_instance (int): target number of invocations per instance that drives scale-out
        min_instance_count (int): Optional; lowest instance count autoscaling may scale in to, default 1
        max_instance_count (int): highest instance count autoscaling may scale out to
        scale_in_cooldown (int): Optional; cooldown in seconds between scale-in activities, default 300
        scale_out_cooldown (int): Optional; cooldown in seconds between scale-out activities, default 300
        disable_scale_in (bool): Optional; whether to disable scaling in on the endpoint, default False
    """

    num_invocations_per_instance: int = attr.ib()
    min_instance_count: int = attr.ib(default=1)
    max_instance_count: int = attr.ib()
    scale_in_cooldown: int = attr.ib(default=300)
    scale_out_cooldown: int = attr.ib(default=300)
    disable_scale_in: bool = attr.ib(default=False)


@attr.s(kw_only=True, auto_attribs=True, frozen=True)
class Config:
    """The deployment config to deploy a model to a SageMaker endpoint.

    The deployment config is immutable once initialized. It is passed to the
    ``deploy_model`` function to deploy a model to a SageMaker endpoint.

    Attributes:
        endpoint_name (str): the name of the SageMaker endpoint
        execution_role_arn (str): the role SageMaker will assume when creating the endpoint
        model_s3_location (str): Optional; the S3 location of the model artifact. When empty (the default), no ModelDataUrl is sent
        container_url (str): the inference image url
        variant_name (str): Optional; the name of the variant for this model, the default value is 'default'
        model_name (str): Optional; the name of the model we will create on SageMaker. If not passed in, it will be a concatenation of model family and instance id
        initial_instance_count (int): Optional; the number of instances that host the model (converted with int()), default 1
        instance_type (str): Optional; the type of the instance that hosts the model, default ml.m5.large
        env_variables (dict): Optional; the environment variables that will be set on the container
        retain_all_variant_properties (bool): Optional; whether to retain variant properties (number of instances and the weight of the variant) when updating an existing endpoint, default True
        vpc_config (dict): Optional; the VPC config of the model. It is passed through sagemaker's ``vpc_utils.sanitize`` and then to the CreateModel API call
        tags (dict or list): Optional; the tags passed to the CreateModel, CreateEndpointConfig and CreateEndpoint API calls. A dict is converted to a list of {"Key": ..., "Value": ...} entries; any other value is used as-is
        autoscaling_config (AutoScalingConfig): Optional; the config used to configure autoscaling for the SageMaker endpoint. Must be an AutoScalingConfig or None
    """

    endpoint_name: str = attr.ib(
        validator=[
            attr.validators.instance_of(str),
            attr.validators.matches_re("^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}"),
            _fits_length,
        ]
    )
    execution_role_arn: str = attr.ib(validator=attr.validators.instance_of(str))
    model_s3_location: str = attr.ib(
        default="", validator=attr.validators.instance_of(str)
    )
    container_url: str = attr.ib(validator=attr.validators.instance_of(str))
    env_variables: Dict = attr.ib(
        default=attr.Factory(dict), validator=attr.validators.instance_of(dict)
    )
    model_name: str = attr.ib(
        default="", validator=[attr.validators.instance_of(str), _fits_length]
    )
    variant_name: str = attr.ib(
        default="default", validator=attr.validators.instance_of(str)
    )
    initial_instance_count: int = attr.ib(default=1, converter=int)
    instance_type: str = attr.ib(
        default="ml.m5.large", validator=attr.validators.instance_of(str)
    )
    vpc_config: Optional[Dict] = attr.ib(default=None, converter=sanitize)
    retain_all_variant_properties: bool = attr.ib(default=True)
    tags: List = attr.ib(
        default=attr.Factory(list),
        converter=lambda d: [{"Key": k, "Value": v} for k, v in d.items()]
        if isinstance(d, dict)
        else d,
    )
    # Validated eagerly: the autoscaling config is only read after the endpoint
    # has been created/updated, so a wrong type would otherwise fail the
    # deployment halfway through (after the previous policy was removed).
    autoscaling_config: Optional[AutoScalingConfig] = attr.ib(
        default=None,
        validator=attr.validators.optional(
            attr.validators.instance_of(AutoScalingConfig)
        ),
    )


def _construct_create_model_request(deploy_config: Config, model_name: str) -> Dict:
    """Build the keyword arguments for ``sagemaker.Session.create_model``.

    The container definition always carries the image and environment, and
    gets a ``ModelDataUrl`` only when ``model_s3_location`` is non-empty.
    """
    container = {
        "Image": deploy_config.container_url,
        "Environment": deploy_config.env_variables,
    }
    if deploy_config.model_s3_location != "":
        container["ModelDataUrl"] = deploy_config.model_s3_location
    return dict(
        name=model_name,
        role=deploy_config.execution_role_arn,
        container_defs=container,
        vpc_config=deploy_config.vpc_config,
        tags=deploy_config.tags,
    )


def _construct_create_endpoint_config_request(
    deploy_config: Config, model_name: str, endpoint_config_name: str
) -> Dict:
    """Build the keyword arguments for the boto3 ``create_endpoint_config`` call.

    The endpoint config contains a single production variant that hosts
    ``model_name`` with the instance type, count and variant name from
    ``deploy_config``.
    """
    variant = production_variant(
        model_name,
        deploy_config.instance_type,
        deploy_config.initial_instance_count,
        deploy_config.variant_name,
    )
    request = {"EndpointConfigName": endpoint_config_name}
    request["ProductionVariants"] = [variant]
    request["Tags"] = deploy_config.tags
    return request


def _construct_create_endpoint_request(
    deploy_config: Config, endpoint_config_name: str
) -> Dict:
    """Build the keyword arguments for ``sagemaker.Session.create_endpoint``."""
    return dict(
        endpoint_name=deploy_config.endpoint_name,
        config_name=endpoint_config_name,
        tags=deploy_config.tags,
    )


def _construct_update_endpoint_request(
    deploy_config: Config, endpoint_config_name: str
) -> Dict:
    """Build the keyword arguments for the boto3 ``update_endpoint`` call."""
    return dict(
        EndpointName=deploy_config.endpoint_name,
        EndpointConfigName=endpoint_config_name,
        RetainAllVariantProperties=deploy_config.retain_all_variant_properties,
    )


def _construct_describe_endpoint_request(deploy_config: Config) -> Dict:
    """Build the keyword arguments for the boto3 ``describe_endpoint`` call."""
    return dict(EndpointName=deploy_config.endpoint_name)


def _sanitize_model_name(input_string: str) -> str:
    """Sanitize the model name since SageMaker API only accepts model name in certain pattern."""
    pattern = re.compile("[^a-zA-Z0-9]")
    output_string = pattern.sub("-", input_string)
    return output_string


def deploy_model(
    model_family: str,
    instance_id: str,
    model_producer_identity: int,
    deploy_config: Config,
    model_registry_client: Optional[TwitchVXModelRegistryLambdaClient] = None,
    region: Optional[str] = "us-west-2",
    logger: Optional[Logger] = None,
    environment: Optional[str] = "",
    boto_session: Optional[boto3.Session] = None,
):
    """Deploy the model to a SageMaker endpoint.

    Creates the SageMaker model (unless a model with the same name already
    exists) and a new, timestamped endpoint config. It then either updates the
    existing endpoint or creates a new one. If ``deploy_config`` carries an
    autoscaling config, the matching scaling policy is created afterwards.
    Finally, when a Model Registry client is supplied and the endpoint is
    ready, the model's status is updated to DEPLOYED in Model Registry.

    Arguments:
        model_family (str): the model family name of the model
        instance_id (str): the instance id of the model
        model_producer_identity (int): the identity of the model producer
        deploy_config (Config): the deployment configuration
        model_registry_client (TwitchVXModelRegistryLambdaClient): Optional; the client to connect to Model Registry service, if not supplied, Model Registry will not be updated
        region (str): Optional; the region used to create a boto3 session when ``boto_session`` is not supplied, default us-west-2
        logger (Logger): Optional; a logger to log out information throughout deployment process
        environment (str): Optional; forwarded as ``environment`` in the Model Registry update request, default ""
        boto_session (boto3.Session): Optional; a boto3 session that is used by sagemaker client, will create a default if not offered

    Raises:
        EndpointNotReady: if the endpoint exists but is not ``InService``.
        UnexpectedStatusException, ClientError: re-raised when updating an
            existing endpoint fails. Before the error is re-raised, any
            autoscaling policy removed for the update is restored. Other
            SageMaker / boto errors propagate unchanged.
    """
    if not logger:
        logger = _get_module_logger("INFO", "DEPLOY")

    logger.info("Deployment Config: {}".format(deploy_config))
    if not boto_session:
        boto_session = boto3.Session(region_name=region)
    session = Session(boto_session)

    # An explicitly configured model name always wins over the derived one.
    model_name = deploy_config.model_name or _default_model_name(
        model_family, instance_id
    )
    logger.info("Model Name: {}".format(model_name))
    endpoint_config_name = _timestamped_endpoint_config_name(model_name)
    logger.info("Endpoint Config Name: {}".format(endpoint_config_name))

    if not _is_endpoint_ready(session, deploy_config):
        raise EndpointNotReady(deploy_config.endpoint_name)
    _, endpoint_exists = _is_endpoint_present(session, deploy_config)

    # Existing models are reused; the endpoint config is always new because
    # its name carries a timestamp.
    if not _is_model_present(session, model_name):
        logger.info("Creating model on SageMaker")
        session.create_model(
            **_construct_create_model_request(deploy_config, model_name)
        )

    logger.info("Creating endpoint config on SageMaker")
    session.sagemaker_client.create_endpoint_config(
        **_construct_create_endpoint_config_request(
            deploy_config, model_name, endpoint_config_name
        )
    )

    if endpoint_exists:
        _update_endpoint_preserving_autoscaling(
            session, boto_session, deploy_config, endpoint_config_name, logger
        )
    else:
        logger.info("Creating endpoint on SageMaker")
        session.create_endpoint(
            **_construct_create_endpoint_request(deploy_config, endpoint_config_name)
        )

    if deploy_config.autoscaling_config is not None:
        _create_autoscaling_policy(
            boto_session,
            deploy_config.endpoint_name,
            deploy_config.variant_name,
            deploy_config.autoscaling_config,
            logger,
        )
        session.wait_for_endpoint(deploy_config.endpoint_name)

    if model_registry_client is not None and _is_endpoint_ready(session, deploy_config):
        update_model_request = UpdateModelRequest(
            model_family=model_family,
            instance_id=instance_id,
            metadata=Model(
                model_state=DEPLOYED,
                deployment_config=DeploymentConfig(
                    sagemaker_deployment_config=SageMakerDeploymentConfig(
                        endpoint_name=deploy_config.endpoint_name,
                        variant_name=deploy_config.variant_name,
                    )
                ),
            ),
            model_producer_id=model_producer_identity,
            environment=environment,
        )
        model_registry_client.update_model(update_model_request=update_model_request)


def _default_model_name(model_family: str, instance_id: str) -> str:
    """Derive a SageMaker model name from the model family and instance id.

    Both parts are sanitized, joined with a hyphen and truncated to
    MODEL_NAME_MAX_LENGTH. Hyphens are then stripped from BOTH ends. The
    SageMaker name pattern (the same one Config validates endpoint names
    against) requires an alphanumeric first and last character, so a family
    such as "_foo" must not produce a name starting with "-".
    """
    model_name = "{model_family}-{instance_id}".format(
        model_family=_sanitize_model_name(model_family),
        instance_id=_sanitize_model_name(instance_id),
    )
    return model_name[:MODEL_NAME_MAX_LENGTH].strip("-")


def _timestamped_endpoint_config_name(model_name: str) -> str:
    """Return ``<model_name>-<YYYYmmddTHHMMSS>``, truncating the model name so
    the whole name fits ENDPOINT_CONFIG_NAME_MAX_LENGTH."""
    suffix = "-{}".format(time.strftime("%Y%m%dT%H%M%S"))
    return model_name[: ENDPOINT_CONFIG_NAME_MAX_LENGTH - len(suffix)] + suffix


def _update_endpoint_preserving_autoscaling(
    session: Session,
    boto_session: boto3.Session,
    deploy_config: Config,
    endpoint_config_name: str,
    logger: Logger,
) -> None:
    """Switch an existing endpoint to a new endpoint config.

    Any autoscaling policy on the variant is removed before the update. If the
    update fails, that policy is re-created before the error is re-raised. The
    restore covers both a rejected UpdateEndpoint request (ClientError) and a
    failed rollout (UnexpectedStatusException). In either case the endpoint
    would otherwise be left without its autoscaling policy.
    """
    current_autoscaling_config = _delete_autoscaling_policy(
        boto_session,
        deploy_config.endpoint_name,
        deploy_config.variant_name,
        logger,
    )
    try:
        logger.info("Updating endpoint on SageMaker")
        session.sagemaker_client.update_endpoint(
            **_construct_update_endpoint_request(deploy_config, endpoint_config_name)
        )
        session.wait_for_endpoint(deploy_config.endpoint_name)
    except (UnexpectedStatusException, ClientError):
        logger.warning("Failed to update the endpoint")
        if current_autoscaling_config is not None:
            logger.info(
                f"Reverting to the previous autoscaling config {current_autoscaling_config}"
            )
            _create_autoscaling_policy(
                boto_session,
                deploy_config.endpoint_name,
                deploy_config.variant_name,
                current_autoscaling_config,
                logger,
            )
            session.wait_for_endpoint(deploy_config.endpoint_name)
        raise


def _create_autoscaling_policy(
    session: boto3.Session,
    endpoint_name: str,
    variant_name: str,
    config: AutoScalingConfig,
    logger: Logger,
):
    """Register an endpoint variant for autoscaling and attach a target-tracking policy.

    The variant's DesiredInstanceCount is registered as a scalable target
    bounded by the config's min/max instance counts. A ``TargetTrackingScaling``
    policy on ``SageMakerVariantInvocationsPerInstance`` is then put on it.

    Arguments:
        session (boto3.Session): session used to create the application-autoscaling client
        endpoint_name (str): the SageMaker endpoint name
        variant_name (str): the production variant to scale
        config (AutoScalingConfig): target value, capacity bounds and cooldowns
        logger (Logger): logger for progress messages
    """
    autoscaling = session.client("application-autoscaling")
    target = _autoscaling_resource_id(endpoint_name, variant_name)
    dimension = "sagemaker:variant:DesiredInstanceCount"
    logger.info(
        f"Trying to create autoscaling on the endpoint {endpoint_name}, autoscaling config: {config}"
    )

    autoscaling.register_scalable_target(
        ServiceNamespace="sagemaker",
        ResourceId=target,
        ScalableDimension=dimension,
        MinCapacity=config.min_instance_count,
        MaxCapacity=config.max_instance_count,
    )

    tracking = {
        "TargetValue": config.num_invocations_per_instance,
        "PredefinedMetricSpecification": {
            "PredefinedMetricType": "SageMakerVariantInvocationsPerInstance",
        },
        "ScaleOutCooldown": config.scale_out_cooldown,
        "ScaleInCooldown": config.scale_in_cooldown,
        "DisableScaleIn": config.disable_scale_in,
    }
    autoscaling.put_scaling_policy(
        PolicyName=_autoscaling_policy_name(endpoint_name),
        ServiceNamespace="sagemaker",
        ResourceId=target,
        ScalableDimension=dimension,
        PolicyType="TargetTrackingScaling",
        TargetTrackingScalingPolicyConfiguration=tracking,
    )


def _delete_autoscaling_policy(
    session: boto3.Session, endpoint_name: str, variant_name: str, logger: Logger
) -> Optional[AutoScalingConfig]:
    """Remove the autoscaling set-up of an endpoint variant and return what was removed.

    The function looks up the scaling policies and scalable targets registered
    for the variant's resource id. When exactly one of each exists:

    - their settings are captured in an AutoScalingConfig;
    - the policy is deleted;
    - the scalable target is deregistered.

    The returned config can then be used to re-create the set-up later.

    Arguments:
        session (boto3.Session): session used to create the application-autoscaling client
        endpoint_name (str): the SageMaker endpoint name
        variant_name (str): the production variant whose autoscaling is removed
        logger (Logger): logger for progress messages and warnings

    Returns:
        The captured AutoScalingConfig. Returns None (and deletes nothing) when
        the variant does not have exactly one scaling policy and one scalable
        target.
    """
    client = session.client("application-autoscaling")
    policy_name = _autoscaling_policy_name(endpoint_name)
    resource_id = _autoscaling_resource_id(endpoint_name, variant_name)
    scalable_dimension = "sagemaker:variant:DesiredInstanceCount"
    logger.info(
        f"Getting the current scaling policy, name: {policy_name}, resource_id: {resource_id}"
    )
    policy_resp = client.describe_scaling_policies(
        ResourceId=resource_id,
        ServiceNamespace="sagemaker",
    )
    target_resp = client.describe_scalable_targets(
        ResourceIds=[resource_id],
        ServiceNamespace="sagemaker",
    )
    scaling_policies = policy_resp.get("ScalingPolicies", [])
    scalable_targets = target_resp.get("ScalableTargets", [])
    if len(scaling_policies) == 1 and len(scalable_targets) == 1:
        policy = scaling_policies[0]
        target = scalable_targets[0]
        tracking = policy["TargetTrackingScalingPolicyConfiguration"]
        # The cooldowns and DisableScaleIn are optional in a target-tracking
        # configuration and may be absent from the response. Only forward the
        # ones that are present so AutoScalingConfig's defaults apply instead
        # of failing with KeyError.
        optional_settings = {
            "scale_in_cooldown": "ScaleInCooldown",
            "scale_out_cooldown": "ScaleOutCooldown",
            "disable_scale_in": "DisableScaleIn",
        }
        current_scaling_config = AutoScalingConfig(
            num_invocations_per_instance=tracking["TargetValue"],
            min_instance_count=target["MinCapacity"],
            max_instance_count=target["MaxCapacity"],
            **{
                field: tracking[key]
                for field, key in optional_settings.items()
                if key in tracking
            },
        )
        logger.info(f"Current scaling config {current_scaling_config}")
        # The lookup is by resource id, so the policy found may not follow this
        # module's naming convention; delete it under its actual name.
        existing_policy_name = policy.get("PolicyName", policy_name)
        logger.info(
            f"Attempting to delete policy, name: {existing_policy_name}, resource_id: {resource_id}"
        )
        client.delete_scaling_policy(
            PolicyName=existing_policy_name,
            ResourceId=resource_id,
            ScalableDimension=scalable_dimension,
            ServiceNamespace="sagemaker",
        )
        client.deregister_scalable_target(
            ServiceNamespace="sagemaker",
            ResourceId=resource_id,
            ScalableDimension=scalable_dimension,
        )
        return current_scaling_config

    if len(scalable_targets) > 1:
        logger.warning(
            f"Expected to have at most 1 scalable target but getting {len(scalable_targets)}, targets: {scalable_targets}"
        )
    if len(scaling_policies) > 1:
        logger.warning(
            f"Expected to have at most 1 scaling policy but getting {len(scaling_policies)}, policies: {scaling_policies}"
        )
    return None


def _is_endpoint_present(session: Session, config: Config) -> Tuple[Any, bool]:
    """Describe the configured endpoint, reporting whether it exists.

    Returns:
        ``(describe_endpoint response, True)`` when the endpoint exists, or
        ``(None, False)`` when SageMaker's error message starts with
        "could not find endpoint" (case-insensitive).

    Raises:
        ClientError: any other error returned by DescribeEndpoint.
    """
    try:
        description = session.sagemaker_client.describe_endpoint(
            **_construct_describe_endpoint_request(config)
        )
    except ClientError as err:
        message = err.response["Error"]["Message"]
        if not re.match(r"could not find endpoint", message, flags=re.IGNORECASE):
            raise
        return None, False
    return description, True


def _is_endpoint_ready(session: Session, config: Config) -> bool:
    """Return False only when the endpoint exists and its status is not ``InService``.

    A missing endpoint counts as ready.
    """
    description, exists = _is_endpoint_present(session, config)
    return not exists or description["EndpointStatus"] == "InService"


def _is_model_present(session: Session, model_name: str) -> bool:
    """Return whether a SageMaker model called ``model_name`` exists.

    A ClientError whose message starts with "could not find model"
    (case-insensitive) means the model is absent. Any other ClientError is
    re-raised.
    """
    try:
        session.describe_model(model_name)
    except ClientError as err:
        missing = re.match(
            r"could not find model",
            err.response["Error"]["Message"],
            flags=re.IGNORECASE,
        )
        if missing:
            return False
        raise
    else:
        return True
