import logging
import sys

import boto3
from awsglue.context import GlueContext
from awsglue.utils import getResolvedOptions
from pyspark.context import SparkContext

from octarine.clients.TwitchFeatureIngestionClient import (
    ENTITY_CHANNEL,
    Config,
    FeatureEntity,
    FeatureKey,
    FeatureRecord,
    TwitchFeatureIngestionClient,
)


def _convert_feature_value(row, column, value_type, shape):
    """Convert ``row[column]`` according to a feature's configured type/shape.

    ``"float"`` values are parsed with ``float``; ``"string"`` values are kept
    as text. When ``shape == "list"`` the raw value is split on ``";"`` first.
    Any other ``value_type`` yields ``""`` without reading the column at all.
    """
    if value_type == "float":
        raw = row[column]
        if shape == "list":
            return [float(item) for item in raw.split(";")]
        return float(raw)
    if value_type == "string":
        raw = row[column]
        return raw.split(";") if shape == "list" else raw
    # Unsupported types are written as an empty string; the column is not read.
    return ""


def batch_write_features_to_ofs(
    data_config,
    *,
    namespace="test",
    region="us-west-2",
    environment="beta",
    ttl=5,
):
    """Build a ``foreachPartition`` callback that batch-writes rows to OFS.

    Each non-blank entry of ``data_config`` is a comma-separated line
    ``<feature_name>@<version>,<type>,<shape>``. The first field is also the
    name of the row column holding that feature's raw value. ``type`` is
    ``"float"`` or ``"string"`` (anything else writes ``""``). ``shape`` is
    ``"list"`` for ``;``-separated values, anything else for a scalar.

    Args:
        data_config: Iterable of feature configuration lines.
        namespace: First argument to ``FeatureKey`` for every feature.
        region: ``AWS_REGION`` passed to the ingestion client ``Config``.
        environment: ``ENVIRONMENT`` passed to the ingestion client ``Config``.
        ttl: ``ttl`` passed to ``batch_write_features``.

    Returns:
        A function taking an iterable of rows (one partition). It builds one
        ``FeatureRecord`` per row, keyed by ``row["ofs_id"]`` as a channel
        entity. It then writes all records of the partition in a single
        ``batch_write_features`` call.

    Raises:
        ValueError: if a config line does not have exactly three fields or
            its feature field does not contain exactly one ``@``.
    """
    logger = logging.getLogger("executor.OFS.write")
    data_config = list(data_config)

    # Parse the configuration once, here. The resulting
    # (FeatureKey, source column, type, shape) tuples are captured by the
    # closure, so rows are not re-split for every partition.
    parsed_configs = []
    for conf in data_config:
        line = conf.strip()
        if not line:
            # Tolerate blank lines in the config file instead of failing to unpack.
            continue
        column, value_type, shape = line.split(",")
        feature_name, version = column.strip().split("@")
        feature_key = FeatureKey(namespace, feature_name, version)
        parsed_configs.append((feature_key, column, value_type.strip(), shape.strip()))
    feature_keys = [parsed[0] for parsed in parsed_configs]

    config = Config(AWS_REGION=region, ENVIRONMENT=environment)
    # Use a %s placeholder. An extra argument with no placeholder makes logging
    # report a formatting error instead of emitting the message.
    logger.info("feat configs: %s", data_config)

    def write_features(rows):
        """Convert every row of one partition and write them in one batch."""
        fic = TwitchFeatureIngestionClient(config, feature_keys)

        feature_records = []
        for row in rows:
            ofs_id = row["ofs_id"]
            feature_values = {
                key: _convert_feature_value(row, column, value_type, shape)
                for key, column, value_type, shape in parsed_configs
            }
            feature_records.append(
                FeatureRecord(
                    [FeatureEntity(ENTITY_CHANNEL, ofs_id)],
                    feature_values,
                )
            )

        fic.batch_write_features(feature_records, ttl=ttl)

    return write_features


def write_to_ofs(partition):
    """Placeholder per-partition writer used when the job method is not "batch".

    The ``partition`` iterator is left untouched; this only prints a notice
    and returns ``None``.
    """
    notice = "Not implemented!"
    print(notice)


def main():
    """Glue job entry point: read the feature config from S3 and write to OFS.

    Job arguments (via ``getResolvedOptions``): ``JOB_NAME`` and ``method``.
    With ``method == "batch"`` each partition is written through
    ``batch_write_features_to_ofs``; any other value uses ``write_to_ofs``.

    The object ``s3://fic-load-test-data/feature_config.csv`` is read as UTF-8
    text. Its first line is a comma-separated list of input S3 paths, read as
    CSV with a header. Every remaining line is a feature configuration line
    handed to ``batch_write_features_to_ofs``.

    Raises:
        ValueError: if the config object is empty or lists no input S3 paths.
    """
    args = getResolvedOptions(sys.argv, ["JOB_NAME", "method"])
    method = args["method"]

    s3 = boto3.client("s3", region_name="us-west-2")
    response = s3.get_object(Bucket="fic-load-test-data", Key="feature_config.csv")
    config_data = response["Body"].read().decode("utf-8").splitlines(True)
    if not config_data:
        raise ValueError("feature_config.csv is empty")

    # Strip every entry and drop empty ones. Inputs like "a, b", "a,,b" or
    # "a,b,," would otherwise pass blank or space-prefixed paths to Glue.
    input_s3_paths = [
        path.strip() for path in config_data[0].split(",") if path.strip()
    ]
    if not input_s3_paths:
        raise ValueError("feature_config.csv lists no input S3 paths")

    print("dataset: ", input_s3_paths)
    sc = SparkContext()
    glue_context = GlueContext(sc)
    dyf = glue_context.create_dynamic_frame.from_options(
        connection_type="s3",
        format="csv",
        connection_options={
            "paths": input_s3_paths,
            "recurse": True,
        },
        format_options={"withHeader": True, "separator": ","},
    )

    df = dyf.toDF()
    df.printSchema()
    print("num partitions: ", str(df.rdd.getNumPartitions()))
    if method == "batch":
        df.foreachPartition(batch_write_features_to_ofs(config_data[1:]))
    else:
        df.foreachPartition(write_to_ofs)


if __name__ == "__main__":
    # Run the job only when this file is executed as a script (e.g. as the
    # Glue job script), not when it is imported.
    main()
