
resource "aws_s3_bucket_object" "table_job_script" {
  bucket  = local.computed_s3_script_bucket
  key     = "db_exports/${var.job_name}/table.py"
  content = <<EOF

${data.template_file.common_code.rendered}

from awsglue.context import GlueContext
from awsglue.dynamicframe import DynamicFrame
from awsglue.job import Job
from awsglue.transforms import *
from pyspark.context import SparkContext
from pyspark.sql import functions as F
from pyspark.sql.functions import udf

## @params: [JOB_NAME, table_name, ts, fail_on_error]
args = getResolvedOptions(sys.argv, ['JOB_NAME', 'table_name', 'ts', 'fail_on_error'])
logger.info('args: %s', args)

# fail_on_error arrives as a string job argument ("0"/"1"); convert it to a real boolean.
fail_on_error = bool(int(args['fail_on_error']))
sc = SparkContext()
glueContext = GlueContext(sc)
spark = glueContext.spark_session
job = Job(glueContext)
job.init(args['JOB_NAME'], args)

table_name = args['table_name']
# table name is lowercased for s3 output to avoid conflicts when migrating output to Tahoe API
proper_table_name = table_name.lower()
logger.info('Exporting table %s', table_name)

raw_config = get_raw_config()
table_config = load_table_config(raw_config, table_name)
output_prefix = glue_job_output_prefix(proper_table_name, args['ts'])
session = boto3.Session(region_name='us-west-2')
bucket = session.resource('s3').Bucket('${local.computed_s3_output_bucket}')

# Remove anything already stored under the output prefix so a re-run does not mix old and new files.
# The collection delete() issues one request per batch of keys and returns one response per request,
# so every response has to be inspected, not only the first one.
resp = bucket.objects.filter(Prefix=output_prefix).delete()
if resp:
    deleted_count = sum(len(batch.get('Deleted', [])) for batch in resp)
    logger.warning('Deleted %d existing s3 objects', deleted_count)
    failed_count = sum(len(batch.get('Errors', [])) for batch in resp)
    if failed_count:
        # Surface partial failures: leftover objects would end up alongside the new export.
        logger.error('Failed to delete %d existing s3 objects under %s', failed_count, output_prefix)

job_source_type = '${var.database_type}'
RDS_INPUT_SOURCE = 'rds'
REDSHIFT_INPUT_SOURCE = 'redshift'
DYNAMODB_INPUT_SOURCE = 'dynamodb'

if job_source_type == RDS_INPUT_SOURCE:
    ssm = session.client('ssm', region_name='us-west-2')
    ssm_response = ssm.get_parameter(Name='${var.db_password_parameter_name}', WithDecryption=True)

    rds = session.client('rds', region_name='us-west-2')
    if ${var.aurora}:
        instance = rds.describe_db_clusters(DBClusterIdentifier='${var.cluster_name}')['DBClusters'][0]
        endpoint = (instance['ReaderEndpoint'], instance['Port'])
        db_name = instance['DatabaseName']
    else:
        if ${var.skip_snapshot}:
            instance = rds.describe_db_instances(DBInstanceIdentifier='${var.cluster_name}')['DBInstances'][0]
        else:
            instance = rds.describe_db_instances(DBInstanceIdentifier='${var.job_name}-export-snapshot')['DBInstances'][0]
        endpoint = (instance['Endpoint']['Address'], instance['Endpoint']['Port'])
        db_name = instance['DBName']

    connection_options = {
        'url': 'jdbc:postgresql://{}:{}/{}?ssl=true'.format(
            endpoint[0], endpoint[1],
            table_config.get('database', db_name)),
        'dbtable': table_name,
        'user': '${var.cluster_username}',
        'password': ssm_response['Parameter']['Value'],
    }
    connection_type = 'postgresql'
elif job_source_type == DYNAMODB_INPUT_SOURCE:
    connection_options = {
        'dynamodb.input.tableName': table_name,
        'dynamodb.throughput.read.percent': str(table_config['read_ratio']),
        'dynamodb.splits': str(table_config.get('dynamodb_splits_count', ${var.dynamodb_splits_count})),
    }
    connection_type = 'dynamodb'
elif job_source_type == REDSHIFT_INPUT_SOURCE:
    ssm = session.client('ssm', region_name='us-west-2')
    ssm_response = ssm.get_parameter(Name='${var.db_password_parameter_name}', WithDecryption=True)

    redshift = session.client('redshift', region_name='us-west-2')
    cluster = redshift.describe_clusters(ClusterIdentifier='${var.cluster_name}')['Clusters'][0]
    leader = [n for n in cluster['ClusterNodes'] if n['NodeRole'] == 'LEADER'][0]
    endpoint = (leader['PrivateIPAddress'], cluster['Endpoint']['Port'])
    db_name = cluster['DBName']
    namespace = table_config.get('namespace', '')
    if namespace:
        namespace += '.'

    connection_options = {
        'url': 'jdbc:redshift://{}:{}/{}'.format(
            endpoint[0], endpoint[1],
            table_config.get('database', db_name)),
        'dbtable': namespace + table_name,
        'user': '${var.cluster_username}',
        'password': ssm_response['Parameter']['Value'],
        'redshiftTmpDir': 's3://${local.computed_s3_output_bucket}/_scratch/',
    }
    connection_type = 'redshift'

if job_source_type not in (DYNAMODB_INPUT_SOURCE, REDSHIFT_INPUT_SOURCE):
    for field in ('hashexpression', 'hashfield', 'hashpartitions'):
        if table_config.get(field):
            connection_options[field] = table_config.get(field)

datasource = glueContext.create_dynamic_frame_from_options(connection_type=connection_type, connection_options=connection_options)
full_schema = table_config['schema']
non_derived_schema = [c for c in full_schema if not c.get('is_derived', False)]

def experimental_spark_changes():
    # Turn column definition into (field, type, field, type), since input and output are the same.
    def get_input_mapping_type(source_input_type):
        if source_input_type == 'timestamp' and job_source_type == DYNAMODB_INPUT_SOURCE:
            return 'string'
        else:
            return source_input_type

    mappings = [(c['name'], get_input_mapping_type(c['type']),
                c['name'], c['type']) for c in non_derived_schema]
    applymapping = ApplyMapping.apply(frame=datasource, mappings=mappings)

    if fail_on_error and applymapping.errorsCount():
        raise RuntimeError('There are %d row errors after ApplyMapping', applymapping.errorsCount())

    def clean(frame):
        ${replace(var.cleaning_code, "\n", "\n        ")}

    cleaned = clean(applymapping)
    if fail_on_error and cleaned.errorsCount():
        raise RuntimeError('There are %d row errors after Custom Cleaning', cleaned.errorsCount())

    string_fields = [c['name'] for c in non_derived_schema if c['type'] == 'string']

    resolvechoice = ResolveChoice.apply(frame=cleaned, choice='make_struct')
    if fail_on_error and resolvechoice.errorsCount():
        raise RuntimeError('There are %d row errors after ResolveChoice', resolvechoice.errorsCount())

    resolved_choice_df = resolvechoice.toDF()

    def spark_clean(df):
        ${replace(var.spark_cleaning_code, "\n", "\n        ")}

    spark_cleaned_df = spark_clean(resolved_choice_df)

    for str_col in string_fields:
        # regex in pyspark uses java regex under the hood
        # the following unicode character categories (in the regex) matches on:
        # line, space, and paragraph separators. \s is equivalent to Python's string.whitespace
        convert_blank_strings_to_null_func = (F.when((F.col(str_col).isNull()) | (F.col(str_col).rlike(r'^[\s\p{Zs}\p{Zl}\p{Zp}]*$')), None)
            .otherwise(F.col(str_col)))
        spark_cleaned_df = spark_cleaned_df.withColumn(str_col, convert_blank_strings_to_null_func)

    output_fields = table_config.get('output_fields', [c['name'] for c in full_schema])
    selected_fields_df = spark_cleaned_df.select(output_fields)

    lowercased_df = selected_fields_df.toDF(*[col.lower() for col in selected_fields_df.columns])

    output_path = 's3://${local.computed_s3_output_bucket}/{}'.format(output_prefix)
    logger.info('Exporting output to %s', output_path)
    logger.info('Table config: %s', table_config)
    lowercased_df.write.mode('overwrite').parquet(output_path)
    logger.info('Total number of row errors after ApplyMapping: %d', applymapping.errorsCount())
    logger.info('Total number of row errors after custom cleaning: %d', cleaned.errorsCount())

def stable_transformation():
    # Turn column definition into (field, type, field, type), since input and output are the same.
    # Turn "timestamp" into "string" to work around a bug in Glue with microseconds on timestamps.
    # Turn "multi" into "string" to handle multi types as strings.
    mappings = [(c['name'], 'string' if c['type'] == 'multi' else c['type'],
                c['name'], 'string' if c['type'] in ('timestamp', 'multi') else c['type']) for c in non_derived_schema]

    applymapping = ApplyMapping.apply(frame=datasource, mappings=mappings)
    if fail_on_error and applymapping.errorsCount():
        raise RuntimeError('There are %d row errors after ApplyMapping', applymapping.errorsCount())

    def clean(frame):
        ${replace(var.cleaning_code, "\n", "\n        ")}

    cleaned = clean(applymapping)
    if fail_on_error and cleaned.errorsCount():
        raise RuntimeError('There are %d row errors after Custom Cleaning', cleaned.errorsCount())

    string_fields = [c['name'] for c in non_derived_schema if c['type'] == 'string']
    multi_fields = [c['name'] for c in non_derived_schema if c['type'] == 'multi']
    timestamp_fields = [c['name'] for c in non_derived_schema if c['type'] == 'timestamp']

    def null_blank_empty(rec):
        # Change blank or empty strings to null.
        # Convert timestamps from strings to timestamps to work around a Glue bug that removes leading
        #  "0"s on microseconds.
        for f in string_fields:
            if f not in rec or rec[f] is None:
                continue
            # Convert arrays to strings.
            if not hasattr(rec[f], 'strip'):
                rec[f] = json.dumps(rec[f])
            elif rec[f].strip() == '':
                rec[f] = None
        # Handle multi types, which have type information encoded...
        for f in multi_fields:
            if f not in rec or rec[f] is None or not hasattr(rec[f], 'items'):
                continue
            items = list(rec[f].items())
            if len(items) != 1:
                continue
            k, v = items[0]
            if k == 'array':
                rec[f] = json.dumps(v)
            else:
                rec[f] = str(v)
        for f in timestamp_fields:
            if f in rec and rec[f] is not None:
                try:
                    rec[f] = datetime.datetime.strptime(rec[f], '%Y-%m-%d %H:%M:%S.%f')
                except ValueError:
                    rec[f] = datetime.datetime.strptime(rec[f], '%Y-%m-%d %H:%M:%S')
        return rec

    resolvechoice = ResolveChoice.apply(frame=cleaned, choice='make_struct')
    if fail_on_error and resolvechoice.errorsCount():
        raise RuntimeError('There are %d row errors after ResolveChoice', resolvechoice.errorsCount())

    nullstrings = Map.apply(frame=resolvechoice, f=null_blank_empty)
    if fail_on_error and nullstrings.errorsCount():
        raise RuntimeError('There are %d row errors after null_blank_empty Map', nullstrings.errorsCount())

    # Select out the columns we want (and undo scrambling from Map).
    output_fields = table_config.get('output_fields', [c['name'] for c in full_schema])

    ordered = SelectFields.apply(frame=nullstrings, paths=output_fields)
    if fail_on_error and ordered.errorsCount():
        raise RuntimeError('There are %d row errors after SelectFields', ordered.errorsCount())

    dropnullfields = DropNullFields.apply(frame=ordered)
    if fail_on_error and dropnullfields.errorsCount():
        raise RuntimeError('There are %d row errors after DropNullFields', dropnullfields.errorsCount())

    # dynamic frame needs to be converted to a dataframe in order to change all the field names to lowercase in one go
    logging.info("Transforming DynamicFrame to Dataframe and lowercasing all column names")

    transformed_frame = dropnullfields.toDF()
    lowercased_df = transformed_frame.toDF(*[col.lower() for col in transformed_frame.columns])
    output_dynamic_frame = DynamicFrame.fromDF(lowercased_df, glueContext, table_name)

    output_path = 's3://${local.computed_s3_output_bucket}/{}'.format(output_prefix)
    logger.info('Exporting output to %s', output_path)
    logger.info('Table config: %s', table_config)
    datasink = glueContext.write_dynamic_frame.from_options(
        frame=output_dynamic_frame, connection_type='s3',
        connection_options={'path': output_path},
        format=table_config.get('output_format', 'parquet'))
    logger.info('Total number of row errors after ApplyMapping: %d', applymapping.errorsCount())
    logger.info('Total number of row errors after custom cleaning: %d', cleaned.errorsCount())
    logger.info('Total number of row errors after ResolveChoice: %d', resolvechoice.errorsCount())
    logger.info('Total number of row errors after null_blank_empty Map: %d', nullstrings.errorsCount())
    logger.info('Total number of row errors after SelectFields: %d', ordered.errorsCount())
    logger.info('Total number of row errors after DropNullFields: %d', dropnullfields.errorsCount())

# Per-table switch: a truthy 'spark_optimization' entry in the table config selects the
# experimental Spark pipeline; otherwise the stable DynamicFrame pipeline runs.
experimental_spark_optimization = table_config.get('spark_optimization', False)

if not experimental_spark_optimization:
    logger.info('Loading table without experimental spark changes')
    stable_transformation()
else:
    logger.info('Loading table with experimental spark changes')
    experimental_spark_changes()

# Commit the Glue job once the selected pipeline has finished.
job.commit()

EOF
}
