import json
import os
import string
from decimal import Decimal
from unittest import TestCase

from google.protobuf.json_format import Parse
from moto import mock_cloudwatch, mock_dynamodb2, mock_s3, mock_sts
from moto.s3.responses import DEFAULT_REGION_NAME

from octarine.clients.TwitchFeatureIngestionClient import *


def resolve_test_path(path):
    """Return *path* resolved against the current working directory.

    A leading ``/`` is treated as relative to the cwd, so
    ``"/tests/test_data.json"`` becomes ``<cwd>/tests/test_data.json``. This
    matches how the tests in this module pass fixture paths. Paths without a
    leading slash resolve to the same location instead of being glued directly
    onto the cwd name.
    """
    return os.path.join(os.getcwd(), path.lstrip("/"))


def get_test_data(path):
    """Load a JSON ``FeatureMetadata`` fixture and return it as protobuf bytes.

    Args:
        path: Fixture path relative to the current working directory, e.g.
            ``"/tests/test_data.json"``.

    Returns:
        The fixture parsed into a ``FeatureMetadata`` message and serialized
        with ``SerializeToString()``, suitable for uploading as an S3 object
        body.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(resolve_test_path(path), encoding="utf-8") as f:
        test_data = json.load(f)

    proto_message = Parse(json.dumps(test_data), FeatureMetadata())
    return proto_message.SerializeToString()


def generate_feature_records(feature_configs, limit=10):
    """Build synthetic ``FeatureRecord`` objects for the batch-write tests.

    Record 0 is a "null" record for channel ``"0"``. It holds ``0.0`` and an
    empty list. Records ``1 .. limit-1`` are for channels ``"1" ..
    str(limit - 1)``. Each holds ``0.5`` and a one-element list whose string
    repeats a lowercase letter ``i`` times, e.g. channel ``"3"`` gets
    ``["ddd"]``. Letters wrap around after ``z``, so any ``limit`` is accepted.

    Args:
        feature_configs: Sequence of at least two feature keys. The first
            receives the float value and the second the string-list value.
        limit: Total number of records to build. The null record is always
            emitted, so a ``limit`` below 1 still yields one record.

    Returns:
        A list of ``FeatureRecord`` objects.
    """
    letters = string.ascii_lowercase
    float_key, list_key = feature_configs[0], feature_configs[1]

    def make_record(entity_id, float_value, list_value):
        """Build one channel-entity record holding both feature values."""
        entity_list = [FeatureEntity(ENTITY_CHANNEL, str(entity_id))]
        feature_val = {float_key: float_value, list_key: list_value}
        return FeatureRecord(entity_list, feature_val)

    # Null case: zero float and an empty list.
    feature_records = [make_record(0, 0.0, [])]

    for i in range(1, limit):
        # Wrap around the alphabet so limit > 26 no longer raises IndexError.
        # For i < 26 this is exactly letters[i].
        letter = letters[i % len(letters)]
        feature_records.append(make_record(i, 0.5, [letter * i]))

    return feature_records


@mock_s3
@mock_sts
class FeatureIngestionInitClientTest(TestCase):
    """Verify the client loads feature metadata from the (mocked) S3 bucket."""

    def setUp(self) -> None:
        """Create the mocked metadata bucket and upload the ``fic/test/1`` fixture."""
        metadata_bucket = "mlfs-feature-metadata-beta"
        metadata_object_key = "fic/test/1/metadata"

        s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME)
        s3_resource.create_bucket(Bucket=metadata_bucket)
        payload = get_test_data("/tests/test_data.json")
        s3_resource.Bucket(metadata_bucket).put_object(
            Body=payload, Key=metadata_object_key
        )

    def test_init_feature_ingestion_client(self):
        """Metadata loaded at construction reports namespace, id and version."""
        config = Config(AWS_REGION=DEFAULT_REGION_NAME, ENVIRONMENT="beta")
        requested_keys = [FeatureKey("fic", "test", "1")]
        client = TwitchFeatureIngestionClient(config, requested_keys)

        # Look the loaded metadata up once and reuse it for every assertion.
        loaded = client.metadata[requested_keys[0]]["feature_metadata"]

        self.assertEqual(loaded.namespace, "fic", msg="namespace != fic")
        self.assertEqual(loaded.feature_id, "test", msg="feature_id != test")
        self.assertEqual(loaded.version, 1, msg="version != 1")


@mock_s3
@mock_dynamodb2
@mock_sts
@mock_cloudwatch
class FeatureIngestionPutFeaturesTest(TestCase):
    """Verify ``put_features`` writes one entity's value to the mocked DynamoDB table."""

    def setUp(self) -> None:
        """Upload the ``fic/test/1`` metadata fixture and create ``test-table``."""
        bucket_name = "mlfs-feature-metadata-beta"
        test_key = "fic/test/1/metadata"

        s3 = boto3.resource("s3", region_name=DEFAULT_REGION_NAME)
        s3.create_bucket(Bucket=bucket_name)
        s3.Bucket(bucket_name).put_object(
            Body=get_test_data("/tests/test_data.json"), Key=test_key
        )

        self.dynamodb = boto3.resource("dynamodb", region_name=DEFAULT_REGION_NAME)
        self.dynamodb.create_table(
            TableName="test-table",
            KeySchema=[{"AttributeName": "ofs_id", "KeyType": "HASH"}],
            AttributeDefinitions=[{"AttributeName": "ofs_id", "AttributeType": "S"}],
            ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1},
        )

    def test_put_features(self):
        """A single put lands under ``<feature_id>@<version>`` with an expiry."""
        conf = Config(AWS_REGION=DEFAULT_REGION_NAME, ENVIRONMENT="beta")
        feature_keys_config = [FeatureKey("fic", "test", "1")]
        fic = TwitchFeatureIngestionClient(conf, feature_keys_config)

        feature_entities = [FeatureEntity(ENTITY_CHANNEL, "4823")]
        feature_key = feature_keys_config[0]
        feature_values = {feature_key: 5}
        fic.put_features(feature_entities, feature_values)

        feature_metadata = fic.metadata[feature_key]["feature_metadata"]
        table = self.dynamodb.Table(feature_metadata.source.beta_online.dynamodb.table)
        response = table.get_item(Key={"ofs_id": "4823"})

        # GetItem omits "Item" entirely when no row matches. Indexing would
        # raise KeyError instead of producing a clear assertion failure.
        self.assertIn("Item", response, msg="no item written for ofs_id 4823")
        item = response["Item"]
        self.assertIn("test@1", item)
        self.assertEqual(item["test@1"], Decimal(5))
        self.assertIn("ofs_expire_time", item)


@mock_s3
@mock_dynamodb2
@mock_sts
@mock_cloudwatch
class FeatureIngestionBatchWriteFeaturesTest(TestCase):
    """Verify ``batch_write_features`` writes each feature to its own DynamoDB table."""

    def setUp(self) -> None:
        """Upload metadata for both test features and create their mocked tables."""
        bucket_name = "mlfs-feature-metadata-beta"
        test_float_key = "fic/test-float/1/metadata"
        test_string_list_key = "fic/test-string-list/0/metadata"

        s3 = boto3.resource("s3", region_name=DEFAULT_REGION_NAME)
        s3.create_bucket(Bucket=bucket_name)
        # NOTE(review): the float key is loaded from *_batch_2.json and the
        # string-list key from *_batch_1.json. Presumably intentional; verify
        # that the fixtures match these keys.
        s3.Bucket(bucket_name).put_object(
            Body=get_test_data("/tests/test_data_batch_2.json"), Key=test_float_key
        )
        s3.Bucket(bucket_name).put_object(
            Body=get_test_data("/tests/test_data_batch_1.json"),
            Key=test_string_list_key,
        )

        self.dynamodb = boto3.resource("dynamodb", region_name=DEFAULT_REGION_NAME)
        # Both tables use the same single hash key on "ofs_id".
        for table_name in ("batch-write-test-table-1", "batch-write-test-table-2"):
            self.dynamodb.create_table(
                TableName=table_name,
                KeySchema=[{"AttributeName": "ofs_id", "KeyType": "HASH"}],
                AttributeDefinitions=[
                    {"AttributeName": "ofs_id", "AttributeType": "S"}
                ],
                ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1},
            )

    def _get_item(self, fic, feature_key, ofs_id):
        """Fetch the row for *ofs_id* from *feature_key*'s table, asserting it exists."""
        feature_metadata = fic.metadata[feature_key]["feature_metadata"]
        table = self.dynamodb.Table(feature_metadata.source.beta_online.dynamodb.table)
        response = table.get_item(Key={"ofs_id": ofs_id})
        # GetItem omits "Item" entirely when no row matches. Indexing would
        # raise KeyError instead of producing a clear assertion failure.
        self.assertIn("Item", response, msg=f"no item written for ofs_id {ofs_id}")
        return response["Item"]

    def test_batch_write_features(self):
        """Float and string-list values for channel "3" land in their tables."""
        conf = Config(AWS_REGION=DEFAULT_REGION_NAME, ENVIRONMENT="beta")
        feature_keys_config = [
            FeatureKey("fic", "test-float", "1"),
            FeatureKey("fic", "test-string-list", "0"),
        ]
        fic = TwitchFeatureIngestionClient(conf, feature_keys_config)

        feature_records = generate_feature_records(feature_keys_config)
        fic.batch_write_features(feature_records)

        float_item = self._get_item(fic, feature_keys_config[0], "3")
        self.assertIn("test-float@1", float_item)
        self.assertEqual(float_item["test-float@1"], Decimal("0.5"))
        self.assertIn("ofs_expire_time", float_item)

        # generate_feature_records gives channel "3" the list ["ddd"].
        list_item = self._get_item(fic, feature_keys_config[1], "3")
        self.assertIn("test-string-list@0", list_item)
        self.assertEqual(list_item["test-string-list@0"], ["ddd"])
        self.assertIn("ofs_expire_time", list_item)
