import datetime
import time

import boto3
import pytest
from botocore.exceptions import ClientError
from sagemaker.session import Session

from octarine.clients.TwitchVXModelRegistryTwirp import (
    MODEL_PRODUCER_VX_CAMI,
    MODEL_SCHEMA_JSON_DICTIONARY,
    SAGEMAKER_DEPLOYMENT,
    TRAINED,
    Model,
    ModelFamilyConfig,
    ModelValidationConfig,
    RegisterModelRequest,
    TwitchVXModelRegistryLambdaClient,
    UpdateModelRequest,
)

from .deploy import AutoScalingConfig, Config, deploy_model

# Name of the SageMaker endpoint that the test deploys to and that the
# clean-up fixture deletes afterwards.
TEST_ENDPOINT_NAME = "integ-test-endpoint"
# IAM role ARN passed as ``execution_role_arn`` in the deploy Config.
TEST_EXECUTION_ROLE = "arn:aws:iam::830245545714:role/TwitchVXModelRegistryLamb-IntegrationTestingPassin-1Q419QD6QDVVS"
# Environment name passed to deploy_model and to the registry update request.
TEST_ENV = "staging"
# S3 location of the model artifact used for the first deployment.
TEST_MODEL_1 = "s3://twitchvxmodelregistrylam-integrationtestings3buck-z4usxp01w0bd/xgb-churn-prediction-model.tar.gz"
# S3 location of the model artifact used for the second deployment; the test
# asserts that the endpoint ends up serving this one.
TEST_MODEL_2 = "s3://twitchvxmodelregistrylam-integrationtestings3buck-z4usxp01w0bd/xgb-churn-prediction-model2.tar.gz"
# Lambda function ARN passed as ``lambda_endpoint`` to the model registry client.
MODEL_REGISTRY_BETA_ENDPOINT = "arn:aws:lambda:us-west-2:830245545714:function:TwitchVXModelRegistryLambda-LambdaFunction-1GSV33ZHUWPQN:live"


@pytest.mark.usefixtures("clean_up_test_resources")
def test_deploy_model():
    """Deploy two model artifacts to the same endpoint, one after the other.

    After each deployment the endpoint must report ``InService``.  After the
    second deployment the model referenced by the first production variant of
    the endpoint's current endpoint config must point at ``TEST_MODEL_2``.
    """
    session = Session(boto3.Session(region_name="us-west-2"))
    model_registry_client = TwitchVXModelRegistryLambdaClient(
        lambda_endpoint=MODEL_REGISTRY_BETA_ENDPOINT,
        boto_lambda_client=boto3.client("lambda", region_name="us-west-2"),
    )
    model_family = "integ-test-utility-package"

    def deploy_and_check(model_s3_location):
        """Register a new instance, deploy it, and return the endpoint description.

        Asserts that the endpoint status is ``InService`` once the deployment
        call has returned.
        """
        config = Config(
            endpoint_name=TEST_ENDPOINT_NAME,
            execution_role_arn=TEST_EXECUTION_ROLE,
            model_s3_location=model_s3_location,
            container_url="246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:0.90-1-cpu-py3",
            variant_name="Variant1",
            initial_instance_count=1,
            instance_type="ml.m5.xlarge",
            autoscaling_config=AutoScalingConfig(
                num_invocations_per_instance=100,
                max_instance_count=10,
            ),
        )
        # A timestamp suffix gives each deployment round its own instance id.
        new_instance_id = "test-instance-{}".format(time.strftime("%Y%m%dT%H%M%S"))
        _register_model(model_registry_client, model_family, new_instance_id)
        deploy_model(
            model_family=model_family,
            instance_id=new_instance_id,
            model_producer_identity=MODEL_PRODUCER_VX_CAMI,
            deploy_config=config,
            model_registry_client=model_registry_client,
            environment=TEST_ENV,
        )

        description = session.sagemaker_client.describe_endpoint(
            EndpointName=TEST_ENDPOINT_NAME
        )
        status = description["EndpointStatus"]
        assert (
            status == "InService"
        ), "endpoint should be in service but got {}".format(status)
        return description

    # First round creates the endpoint; second round updates it in place.
    deploy_and_check(TEST_MODEL_1)
    endpoint_description = deploy_and_check(TEST_MODEL_2)

    # Follow endpoint -> endpoint config -> model to find the artifact that
    # is actually being served after the second deployment.
    endpoint_config_description = session.sagemaker_client.describe_endpoint_config(
        EndpointConfigName=endpoint_description["EndpointConfigName"]
    )
    first_variant = endpoint_config_description["ProductionVariants"][0]
    model_description = session.describe_model(name=first_variant["ModelName"])
    deployed_model_url = model_description["PrimaryContainer"]["ModelDataUrl"]
    assert (
        deployed_model_url == TEST_MODEL_2
    ), "model s3 location should be {} but got {}".format(
        TEST_MODEL_2, deployed_model_url
    )


def _register_model(
    model_registry_client: TwitchVXModelRegistryLambdaClient,
    model_family: str,
    instance_id: str,
):
    """Register a model instance in the registry, then set its state to TRAINED.

    Args:
        model_registry_client: Client used to send both registry requests.
        model_family: Model family the instance is registered under.
        instance_id: Identifier of the model instance to register and update.
    """
    # Step 1: register the instance with a SageMaker deploy target and the
    # JSON-dictionary validation schema.
    family_config = ModelFamilyConfig(
        deploy_target=SAGEMAKER_DEPLOYMENT,
        validation_config=ModelValidationConfig(
            schema_name=MODEL_SCHEMA_JSON_DICTIONARY,
        ),
    )
    registration = RegisterModelRequest(
        model_family=model_family,
        instance_id=instance_id,
        model_family_config=family_config,
        model_producer_id=MODEL_PRODUCER_VX_CAMI,
    )
    model_registry_client.register_model(register_model_request=registration)

    # Step 2: update the freshly registered instance's state to TRAINED in
    # the test environment.
    trained_update = UpdateModelRequest(
        model_family=model_family,
        instance_id=instance_id,
        metadata=Model(model_state=TRAINED),
        model_producer_id=MODEL_PRODUCER_VX_CAMI,
        environment=TEST_ENV,
    )
    model_registry_client.update_model(update_model_request=trained_update)


@pytest.fixture
def clean_up_test_resources():
    """Delete the SageMaker resources created during the test.

    The start time is recorded before the test body runs.  After the test
    (pass or fail) the fixture:

    1. deletes the ``TEST_ENDPOINT_NAME`` endpoint, printing and ignoring a
       ``ClientError`` because the endpoint may never have been created;
    2. deletes every endpoint config whose name contains
       ``integ-test-utility-package`` and that was created after the start time;
    3. deletes every model matching the same name / creation-time filter.
    """
    # botocore treats naive datetimes as UTC when serializing request
    # parameters.  A naive local-time value would shift the CreationTimeAfter
    # filter by the machine's UTC offset: too far back (deleting other runs'
    # resources) or into the future (cleaning up nothing).  Use an aware UTC
    # timestamp so the filter means what it says.
    test_start_time = datetime.datetime.now(datetime.timezone.utc)
    yield
    session = Session(boto3.Session(region_name="us-west-2"))
    name_filter = "integ-test-utility-package"

    # clean up endpoint
    try:
        # the endpoint is not necessarily created if the test fails
        session.delete_endpoint(endpoint_name=TEST_ENDPOINT_NAME)
    except ClientError as err:
        print(err)

    # clean up endpoint configs created in the test
    endpoint_configs = session.sagemaker_client.list_endpoint_configs(
        NameContains=name_filter, CreationTimeAfter=test_start_time
    )
    for endpoint_config in endpoint_configs["EndpointConfigs"]:
        session.delete_endpoint_config(
            endpoint_config_name=endpoint_config["EndpointConfigName"]
        )

    # clean up models created in the test
    models = session.sagemaker_client.list_models(
        NameContains=name_filter, CreationTimeAfter=test_start_time
    )
    for model in models["Models"]:
        session.delete_model(model_name=model["ModelName"])
