import datetime
import logging
from collections import namedtuple
from dataclasses import dataclass
from typing import Any, Dict, List

import boto3
from botocore.client import Config as BotocoreConfig
from botocore.exceptions import ClientError

from .feature_schema_pb2 import FeatureMetadata
from .metrics import MetricsSender, MetricUnit
from .roles import get_session
from .utils import check_entity_ordering, check_feature_type, float_to_decimal

# Integer codes for entity types, numbered consecutively from 0.
# NOTE(review): presumably these are the values carried in FeatureEntity.entity_type and
# must agree with the feature registry's entity enum — confirm before renumbering.
(
    ENTITY_UNSPECIFIED,
    ENTITY_OTHER,
    ENTITY_CHANNEL,
    ENTITY_DEVICE,
    ENTITY_QUERY,
    ENTITY_USER,
    ENTITY_CATEGORY,
) = range(7)

# Settings for the ingestion client. ENVIRONMENT == "prod" selects the production
# DynamoDB source; any other value selects the beta source.
Config = namedtuple("Config", "AWS_REGION ENVIRONMENT")
# Identifies one feature in the registry (namespace/feature_id/version).
FeatureKey = namedtuple("FeatureKey", "namespace feature_id version")
# One entity (type code and value) that a feature value is keyed by.
FeatureEntity = namedtuple("FeatureEntity", "entity_type entity_value")

# Default DynamoDB item TTL: 30 days, expressed in seconds.
DEFAULT_TTL = 30 * 24 * 60 * 60
# Default botocore client configuration applied to the DynamoDB resources.
DEFAULT_DDB_CLIENT_CONFIG = BotocoreConfig()
# Regional STS endpoint template; formatted with the AWS region.
STS_ENDPOINT_URL = "https://sts.{}.amazonaws.com"

# Name the logger after the module (not the file path) so it sits in the package's
# logger hierarchy and picks up configuration applied to the package logger.
logger = logging.getLogger(__name__)


@dataclass()
class FeatureRecord:
    """A group of feature values that all belong to the same list of entities.

    This is the unit of input accepted by ``batch_write_features``.
    """

    # Entities identifying the item; their values joined with "#" form the ofs_id.
    entity_list: List[FeatureEntity]
    # Feature key -> value to write for those entities.
    feature_vals: Dict[FeatureKey, Any]


def _get_ofs_id_from_entities(entities: List[FeatureEntity]) -> str:
    """
    Build the OFS item key by joining the entity values with ``#``.

    Entity order is preserved, so the same entities in a different order give a
    different id.
    """
    return "#".join(str(entity.entity_value) for entity in entities)


def _get_entity_types_from_entities(entities: List[FeatureEntity]) -> str:
    """
    Join the entity type codes with ``#``, in entity order.

    The result is the string that the registry's entity ordering is checked against.
    """
    return "#".join(str(entity.entity_type) for entity in entities)


def _feature_key_path(feature_key: FeatureKey) -> str:
    """Return ``namespace/feature_id/version`` for use in error messages."""
    return f"{feature_key.namespace}/{feature_key.feature_id}/{feature_key.version}"


def _feature_error_checks(
    feature_key: FeatureKey,
    metadata: Dict[FeatureKey, Dict[str, Any]],
    feature_val: Any,
    entity_types: str,
) -> None:
    """
    Validate one feature value against the feature registry metadata.

    :param feature_key: key of the feature being written
    :param metadata: map of feature key -> registry entry for the configured features
    :param feature_val: value to be written for ``feature_key``
    :param entity_types: ``#``-joined entity type codes of the entities being written
    :raises KeyError: if ``feature_key`` is not present in ``metadata``
    :raises AssertionError: if ``check_feature_type`` rejects the value, or
        ``check_entity_ordering`` rejects the entity types
    """
    # Check if feature key exists in metadata
    if feature_key not in metadata:
        raise KeyError(
            f"Feature key: {_feature_key_path(feature_key)} not found in feature registry"
        )

    registry_entry = metadata[feature_key]

    # Check if feature type/shape matches the one specified in metadata
    if not check_feature_type(registry_entry, feature_val):
        raise AssertionError(
            f"Feature key: {_feature_key_path(feature_key)} data type/shape does not match "
            "the type from Feature Registry or is unspecified"
        )

    # Check if entity ordering matches the one in feature registry
    if not check_entity_ordering(registry_entry, entity_types):
        raise AssertionError(
            f"Order of entity types for feature instance: {entity_types} does not match "
            "the order from Feature Registry"
        )


class TwitchFeatureIngestionClient:
    """
    Validates feature values against the feature registry and writes them to DynamoDB.

    Registry metadata for the configured feature keys is read from the
    ``mlfs-feature-metadata-<ENVIRONMENT>`` S3 bucket once, at construction. Each
    feature key's DynamoDB resource is built from the write role in that metadata.
    """

    def __init__(
        self,
        config: Config,
        feature_key_configs: List[FeatureKey],
        ddb_client_config=DEFAULT_DDB_CLIENT_CONFIG,
    ):
        """
        :param config: AWS region and environment; ENVIRONMENT == "prod" selects the
            prod_online DynamoDB source, any other value the beta_online source
        :param feature_key_configs: feature keys this client can write; any other key is
            rejected by validation with KeyError
        :param ddb_client_config: botocore ``Config`` applied to every DynamoDB resource
        :raises botocore.exceptions.ClientError: if a metadata object cannot be fetched
        """
        self.s3 = boto3.client("s3", region_name=config.AWS_REGION)
        # Talk to the regional STS endpoint for this client's region.
        self.sts = boto3.client(
            "sts",
            region_name=config.AWS_REGION,
            endpoint_url=STS_ENDPOINT_URL.format(config.AWS_REGION),
        )
        self.metrics = MetricsSender(
            region_name=config.AWS_REGION,
            namespace="feature_ingestion_{env}".format(env=config.ENVIRONMENT),
            base_dimensions={"Stage": config.ENVIRONMENT},
        )
        self.config = config
        self.ddb_client_config = ddb_client_config
        self.metadata_bucket = "mlfs-feature-metadata-{}".format(config.ENVIRONMENT)
        # Order matters: _init_ddb_clients reads the write roles stored in self.metadata.
        self.metadata = self._get_feature_metadata(feature_key_configs)
        self.ddb_clients = self._init_ddb_clients()

    def _init_ddb_clients(self):
        """
        Build a DynamoDB resource for every configured feature key from its write role.

        This is done once, when the client is constructed, because creating sessions is
        expensive and should not add latency to the ingestion API. Feature keys that share
        a write role share one session and resource, so each distinct role gets a single
        session.

        :return: map of feature key -> DynamoDB service resource
        """
        resources_by_role = {}
        clients = {}
        for feature_key, entry in self.metadata.items():
            write_role = entry["write_role"]
            if write_role not in resources_by_role:
                session = get_session(
                    role_arn=write_role,
                    aws_region=self.config.AWS_REGION,
                    sts_client=self.sts,
                )
                resources_by_role[write_role] = session.resource(
                    "dynamodb",
                    region_name=self.config.AWS_REGION,
                    config=self.ddb_client_config,
                )
            clients[feature_key] = resources_by_role[write_role]

        return clients

    def _get_feature_metadata(self, feature_key_configs: List[FeatureKey]):
        """
        Retrieve feature metadata from the feature registry.

        For each key, reads ``<namespace>/<feature_id>/<version>/metadata`` from the
        metadata bucket and parses it as a ``FeatureMetadata`` protobuf. The DynamoDB
        source is ``prod_online`` when ENVIRONMENT is "prod" and ``beta_online`` otherwise.
        If that source has no explicit write role,
        ``arn:aws:iam::<account_id>:role/write_<table>`` is used.

        :return: map of feature key -> {"feature_metadata", "write_role", "table_name"}
        :raises botocore.exceptions.ClientError: if an S3 metadata object cannot be fetched
        """
        metadata = {}
        for feature_key in feature_key_configs:
            object_key = (
                "/".join(
                    (feature_key.namespace, feature_key.feature_id, feature_key.version)
                )
                + "/metadata"
            )
            # A ClientError (e.g. missing metadata object) propagates to the caller.
            response = self.s3.get_object(Bucket=self.metadata_bucket, Key=object_key)

            feature_metadata = FeatureMetadata()
            feature_metadata.ParseFromString(response["Body"].read())

            if self.config.ENVIRONMENT == "prod":
                ddb_source = feature_metadata.source.prod_online.dynamodb
            else:
                ddb_source = feature_metadata.source.beta_online.dynamodb

            table_name = ddb_source.table
            write_role = ddb_source.write_role
            if not write_role:
                # No explicit role in the registry: use the per-table naming convention.
                write_role = (
                    "arn:aws:iam::" + ddb_source.account_id + ":role/write_" + table_name
                )

            metadata[feature_key] = {
                "feature_metadata": feature_metadata,
                "write_role": write_role,
                "table_name": table_name,
            }

        return metadata

    @staticmethod
    def _expire_epoch_seconds(ttl) -> int:
        """
        Return the Unix epoch second that lies ``ttl`` seconds from now.

        An aware UTC datetime is used deliberately: ``timestamp()`` on a naive
        ``utcnow()`` value is interpreted as local time, which skews the result by the
        host's UTC offset.
        """
        expire_at = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
            seconds=ttl
        )
        return int(expire_at.timestamp())

    def put_features(
        self,
        entities: List[FeatureEntity],
        features: Dict[FeatureKey, Any],
        ttl=DEFAULT_TTL,
        disable_metrics=False,
    ):
        """
        Put feature values into OFS.

        All features are validated before any is written, so a validation error writes
        nothing. Each feature is then written with its own ``update_item`` call. That call
        sets only the feature's attribute and ``ofs_expire_time`` on the item keyed by the
        entities' ofs_id; other attributes on the item are left as they are.

        :param entities: List of FeatureEntity identifying the item
        :param features: Map of FeatureKey to feature values
        :param ttl: TTL for DDB item in seconds
        :param disable_metrics: Disable sending metrics from this client
        :raises KeyError: if a feature key is not in the registry metadata
        :raises AssertionError: if a value's type/shape or the entity ordering does not
            match the registry
        """
        entity_types = _get_entity_types_from_entities(entities)
        ofs_id = _get_ofs_id_from_entities(entities)

        # Validate everything first so that a bad feature later in the mapping does not
        # leave the earlier ones already written.
        for feature_key, feature_val in features.items():
            _feature_error_checks(feature_key, self.metadata, feature_val, entity_types)

        # One expiry for the whole call; this is loop-invariant.
        expire_time = self._expire_epoch_seconds(ttl)

        for feature_key, feature_val in features.items():
            entry = self.metadata[feature_key]
            feature_metadata = entry["feature_metadata"]
            ddb_table = self.ddb_clients[feature_key].Table(entry["table_name"])

            ddb_table.update_item(
                Key={"ofs_id": ofs_id},
                UpdateExpression="SET #f1 = :val1, #t = :val2",
                ExpressionAttributeNames={
                    "#f1": str(feature_metadata.feature_id)
                    + "@"
                    + str(feature_metadata.version),
                    "#t": "ofs_expire_time",
                },
                ExpressionAttributeValues={
                    ":val1": float_to_decimal(feature_val),
                    ":val2": expire_time,
                },
            )

            if not disable_metrics:
                self.metrics.send(
                    "put.success",
                    1,
                    MetricUnit.Count,
                    dimensions={
                        "namespace": str(feature_metadata.namespace),
                        "feature_id": str(feature_metadata.feature_id),
                        "version": str(feature_metadata.version),
                    },
                )

    def batch_write_features(
        self, feature_records: List[FeatureRecord], ttl=DEFAULT_TTL
    ):
        """
        Batch write feature values into OFS.

        Every record is validated before anything is written. Feature values are grouped
        per DynamoDB table and per ofs_id, so each table receives one item per entity
        combination. That item holds every feature value for the combination plus
        ``ofs_expire_time``.

        Items are written with ``put_item`` through the table's ``batch_writer``. A put
        replaces any existing item with the same ofs_id, so, unlike ``put_features``,
        attributes that are not in the batch are not preserved.

        :param feature_records: List of FeatureRecord to be written to ddb
        :param ttl: TTL for ddb item in seconds
        :return: list of items (dicts) that may not have been written. If a table's batch
            write raises ClientError, every item for that table is returned. The writer
            buffers and flushes items in groups, so the failure cannot be pinned to a
            single item. Re-putting an item that was in fact written only overwrites it
            with the same values.
        :raises KeyError: if a feature key is not in the registry metadata
        :raises AssertionError: if a value's type/shape or the entity ordering does not
            match the registry
        """
        expire_time = self._expire_epoch_seconds(ttl)

        # table name -> Table resource, and table name -> ofs_id -> item for that table.
        # Features that share a table and an ofs_id are merged into one item, because
        # DynamoDB rejects a batch request that contains two puts for the same key.
        tables = {}
        items_by_table = {}
        for feature_record in feature_records:
            entity_types = _get_entity_types_from_entities(feature_record.entity_list)
            ofs_id = _get_ofs_id_from_entities(feature_record.entity_list)

            for feature_key, feature_val in feature_record.feature_vals.items():
                _feature_error_checks(
                    feature_key, self.metadata, feature_val, entity_types
                )

                table_name = self.metadata[feature_key]["table_name"]
                if table_name not in tables:
                    tables[table_name] = self.ddb_clients[feature_key].Table(table_name)
                    items_by_table[table_name] = {}

                item = items_by_table[table_name].setdefault(
                    ofs_id, {"ofs_id": ofs_id, "ofs_expire_time": expire_time}
                )
                item[feature_key.feature_id + "@" + feature_key.version] = (
                    float_to_decimal(feature_val)
                )

        failed_items = []
        for table_name, items in items_by_table.items():
            pending = list(items.values())
            try:
                with tables[table_name].batch_writer() as batch_writer:
                    for item in pending:
                        batch_writer.put_item(Item=item)
            except ClientError:
                logger.exception(
                    "Unable to batch write %d item(s) to table %s",
                    len(pending),
                    table_name,
                )
                failed_items.extend(pending)

        return failed_items
