// Functions for sharing across Glue jobs
// This is rendered into the script code to make it easier to follow the script in the Glue
// console. It must be rendered near the top of the scripts since it does the module importing.
data "template_file" "common_code" {
  template = <<END_OF_STRING
from __future__ import division
import copy
import datetime
import json
import logging
import os
import sys
import time

from awsglue.utils import getResolvedOptions
import boto3
from botocore import credentials
from botocore import session as botocore_session
from botocore import exceptions as botocore_exceptions

logger = logging.getLogger()
# Setting a level first is a dummy call that "warms up" the Glue Python shell logging configuration.
logger.setLevel(logging.INFO)
# Remove every pre-existing handler. Iterate over a copy of the list: removeHandler()
# mutates logger.handlers, and removing while iterating the live list skips every
# other handler, which would leave stale handlers attached.
for handler in list(logger.handlers):
    logger.removeHandler(handler)
# Reset the configuration and install a single stdout handler to get the correct logging output.
formatter = logging.Formatter('%(asctime)s - %(levelname)s: %(message)s')
logging_handler_out = logging.StreamHandler(sys.stdout)
logging_handler_out.setLevel(logging.INFO)
logging_handler_out.setFormatter(formatter)
logger.addHandler(logging_handler_out)

# Only let warnings and errors through from the AWS/HTTP client libraries.
for library in ('boto', 'botocore', 'boto3', 'requests'):
    logging.getLogger(library).setLevel(logging.WARNING)

def get_json_schema_validator():
    # encode the json file due to indentation + boolean mismatch
    # make sure to update the draft version if using a different draft
    # import is done in this function because json schema library is only linked for watcher job
    from jsonschema import Draft7Validator

    validation_schema_json = ${jsonencode(var.database_type == "dynamodb" ?
file("${path.module}/resources/dynamodb_table_config_schema.json") : file("${path.module}/resources/rds_table_config_schema.json"))}
    return Draft7Validator(json.loads(validation_schema_json))

def get_s3_size(s3_resource, input_path):
    """Sum the sizes of all objects stored under an S3 prefix.

    Args:
        s3_resource: S3 resource exposing Bucket(name).objects.filter(Prefix=...).
        input_path: Location to measure. Its first five characters are dropped as the
            leading "s3://"; the rest is split at the first "/" into bucket and prefix.

    Returns:
        The total of the size attribute over every matching object (0 if there are none).
    """
    scheme_length = len('s3://')
    bucket_name, _, prefix = input_path[scheme_length:].partition('/')
    matching_objects = s3_resource.Bucket(bucket_name).objects.filter(Prefix=prefix)
    return sum(summary.size for summary in matching_objects)

def get_raw_config():
    """Returns an encoded JSON representation of the config provided in the Terraform template."""
    return ${jsonencode(var.table_config)}


def load_table_config(raw_config, table_name):
    """Parse a single table's configuration out of the raw config mapping.

    Args:
        raw_config: Mapping of table name to that table's configuration text.
        table_name: Key of the table to load.

    Returns:
        The decoded configuration. Every single quote in the text is turned into a
        double quote before it is parsed as JSON.

    Raises:
        KeyError: If table_name is not present in raw_config.
    """
    quoted_config = raw_config[table_name]
    json_text = quoted_config.replace("'", '"')
    return json.loads(json_text)


def glue_job_output_prefix(table_name, ts):
    """Return the S3 key prefix under which a table's export for ts is written.

    The prefix is the Terraform-rendered output key followed by the table name, the
    first six characters of ts (the year and month for the timestamp format used by
    get_ts) and the full ts, each terminated by a slash.
    """
    key_template = '${local.rendered_s3_output_key}{}/{}/{}/'
    return key_template.format(table_name, ts[:6], ts)


def get_ts(ts):
    """Resolve the run timestamp as both a datetime and its formatted string.

    Args:
        ts: Either a timestamp formatted as YYYYmmddTHHMMSS, or the string '0' meaning
            that no timestamp was supplied.

    Returns:
        A (datetime, string) pair. For an explicit ts the string is ts itself; for '0'
        it is the current UTC time truncated to the start of the hour.
    """
    ts_format = '%Y%m%dT%H%M%S'
    if ts == '0':
        current_hour = datetime.datetime.utcnow().replace(minute=0, second=0, microsecond=0)
        return current_hour, current_hour.strftime(ts_format)
    return datetime.datetime.strptime(ts, ts_format), ts

def validate_table_config_schema(table_name, table_config_dict, json_schema_validator):
    """Validate one table's configuration against the predefined JSON schema.

    The validator is given a dict keyed by the table name, so the configuration is
    wrapped as {table_name: table_config_dict} before validation. Whatever the
    validator's validate() raises is propagated to the caller.
    """
    json_schema_validator.validate({table_name: table_config_dict})

#######################################
#### BEGIN: Tahoe API related code ####
#######################################

# Path prefix of the Tahoe API service; generate_tahoe_request appends "/<method>" to it
# to build the request path.
TAHOE_SERVICE = '/twirp/twitch.fulton.example.twitchtahoeapiservice.TwitchTahoeAPIService'

# Maps a column's configured type (its 'output_type' when present, otherwise its 'type')
# to the type name sent in Tahoe column definitions. Looking up a type that is not
# listed here raises KeyError.
TYPE_MAPPING = {
    'bigint': 'LONG',
    'string': 'STRING',
    'multi': 'STRING',
    'binary': 'STRING',
    'float': 'DOUBLE',
    'double': 'DOUBLE',
    'boolean': 'BOOL',
    'timestamp': 'TIMESTAMP',
    'date': 'DATE',
    'int': 'INT',
}


def attempt_another(table_name, state, jobs_to_run, running_jobs, failed_tables):
    """Re-queue a table's export job, or mark the table failed when attempts ran out.

    Args:
        table_name: Table whose job just ended in state.
        state: State of the finished job run; only used in the log message.
        jobs_to_run: Dict of pending jobs; receives the job config when retrying.
        running_jobs: Dict of running jobs; must contain table_name. It is not modified
            here, removing the entry is left to the caller.
        failed_tables: List that receives table_name when no attempts remain.
    """
    job_config = running_jobs[table_name]
    attempts_left = job_config['attempts']
    logging.warning('%s %s. attempts remaining: %d', table_name, state, attempts_left)
    if attempts_left <= 0:
        failed_tables.append(table_name)
        logging.error('out of attempts for %s', table_name)
        return
    jobs_to_run[table_name] = job_config
    logging.info('retrying %s', table_name)

def process_all_tables(tahoe_import_id, tahoe_session, producer_api_key, skip_db_export, glue_arguments):
    """Run the export job for every configured table and drive the follow-up Tahoe imports.

    The function first sorts the tables into "export jobs to run" and "tables to import",
    then polls in a loop until no export job and no Tahoe import is pending or running.

    NOTE(review): this function reads the names session, ts, glue and subjob_name, which
    are not defined in this block; presumably they are globals of the script this
    template is rendered into - confirm.

    Args:
        tahoe_import_id: Import id handed to each TableToImport.
        tahoe_session: Session used for Tahoe API calls and for the Tahoe input bucket.
        producer_api_key: API key handed to each TableToImport.
        skip_db_export: When truthy, tables that should be imported skip the export job
            and their existing output for ts is migrated directly.
        glue_arguments: Optional dict merged into the arguments of every started job run.

    Raises:
        RuntimeError: If migrating existing output reports a data problem while
            skip_db_export is set, or if any table ended up in the failed list.
    """
    raw_config = get_raw_config()
    # Export jobs not started yet: table name -> worker settings and remaining attempts.
    jobs_to_run = {}
    # Export jobs currently running: same config plus the 'job_run_id' of the run.
    running_jobs = {}
    failed_tables = []
    # Tables whose data is ready for a Tahoe import: table name -> TableToImport.
    tables_to_import = {}
    running_imports = {}

    for table_name in raw_config:
        table_config = load_table_config(raw_config, table_name)

        if skip_db_export and should_import_table(table_config):
            tahoe_table = TableToImport(table_name, table_config, tahoe_import_id, producer_api_key)

            # We assume the ts provided as argument represents a past export run
            error = migrate_glue_output(session, tahoe_session, tahoe_table, ts)
            if error:
                logging.error('Problem with s3 data, not attempting tahoe import for %s', table_name)
                raise RuntimeError('Job(s) for {} did not SUCCEED'.format([table_name]))

            tahoe_table.attempts = table_config.get('max_attempts', 3)
            tables_to_import[table_name] = tahoe_table
        else:
            # 'worker_count' wins over the 'dpu_count' setting; 10 workers if neither is set.
            jobs_to_run[table_name] = {
                'worker_type': table_config.get('worker_type', 'Standard'),
                'number_of_workers': table_config.get('worker_count', table_config.get('dpu_count', 10)),
                'attempts': table_config.get('max_attempts', 3),
            }
        # NOTE(review): these two log lines sit inside the loop, so they are emitted once per table.
        logging.info('Tables to import: %s', ', '.join(tables_to_import.keys()))
        logging.info('Jobs to run: %s', ', '.join(jobs_to_run.keys()))

    while jobs_to_run or running_jobs or tables_to_import or running_imports:
        # Step 1: check the running export jobs and handle the ones that reached a final state.
        logging.info('Checking %d table export job statuses', len(running_jobs))
        # Iterate over a snapshot because entries are deleted from running_jobs in the loop.
        for table_name, config in list(running_jobs.items()):
            response = glue.get_job_run(JobName=subjob_name, RunId=config['job_run_id'])
            time.sleep(1)  # Avoid hammering Glue API
            state = response['JobRun']['JobRunState']
            # Any other state is treated as "still in progress"; check again next round.
            if state not in ['STOPPED', 'SUCCEEDED', 'FAILED', 'TIMEOUT']:
                continue
            if state == 'SUCCEEDED':
                logging.info('%s SUCCEEDED', table_name)

                # We're now ready to trigger a Tahoe import for this table if possible
                table_config = load_table_config(raw_config, table_name)
                if should_import_table(table_config):
                    logging.info("Migrating glue output for table %s into Tahoe input bucket", table_name)
                    tahoe_table = TableToImport(table_name, table_config, tahoe_import_id, producer_api_key)
                    error = migrate_glue_output(session, tahoe_session, tahoe_table, ts)
                    if error:
                        # The job succeeded but its output looks wrong: treat it like a failed run.
                        logging.error('Problem with s3 data, retrying %s', table_name)
                        attempt_another(table_name, state, jobs_to_run, running_jobs, failed_tables)
                        del running_jobs[table_name]
                        continue

                    # The Tahoe import gets the attempts the export job has left.
                    tahoe_table.attempts = config['attempts']
                    tables_to_import[table_name] = tahoe_table
                else:
                    logging.warning('Skipping Tahoe import of %s', table_name)
            else:
                attempt_another(table_name, state, jobs_to_run, running_jobs, failed_tables)
            del running_jobs[table_name]

        # Step 2: start pending export jobs, those with the most workers first.
        logging.info('%d unstarted table export(s)', len(jobs_to_run))
        for table_name, config in sorted(jobs_to_run.items(), key=lambda x: -x[1]['number_of_workers']):
            # Don't try to run more jobs than allowed.
            if len(running_jobs) >= ${var.max_concurrent_runs}:
                logging.info('We have enough table export jobs running already')
                break
            try:
                logging.info('Running %s', table_name)
                arguments = {
                        '--table_name': table_name,
                        '--ts': ts,
                }
                if glue_arguments is not None:
                    arguments.update(glue_arguments)
                response = glue.start_job_run(
                    JobName=subjob_name,
                    WorkerType=config['worker_type'],
                    NumberOfWorkers=config['number_of_workers'],
                    Arguments=arguments)
            except glue.exceptions.ConcurrentRunsExceededException:
                # The job stays in jobs_to_run and is tried again in a later round.
                logging.warning('Tried to run a job "%s" but failed due to concurrency limits. Will retry.', table_name)
                break
            logging.info('Started job to export "%s": %s', table_name, response['JobRunId'])
            # Move the job from pending to running; a started run uses up one attempt.
            running_jobs[table_name] = jobs_to_run[table_name]
            running_jobs[table_name]['job_run_id'] = response['JobRunId']
            running_jobs[table_name]['attempts'] -= 1
            del jobs_to_run[table_name]

        # Step 3: only when a Tahoe producer name is configured, advance the Tahoe imports.
        if '${var.tahoe_producer_name}':
            # Monitoring Tahoe imports can be done in the same way for RDS and Dynamo exports
            new_state = get_state_of_tahoe_imports(tahoe_session, {
                'running_imports': running_imports,
                'tables_to_import': tables_to_import,
                'failed_tables': failed_tables,
            })
            running_imports = new_state['running_imports']
            tables_to_import = new_state['tables_to_import']
            failed_tables = new_state['failed_tables']

        if running_jobs:
            logging.info('Running export jobs: %s', ', '.join(running_jobs.keys()))
        if jobs_to_run:
            logging.info('Unstarted export jobs: %s', ', '.join(jobs_to_run.keys()))

        # Wait before the next polling round.
        time.sleep(30)

    if failed_tables:
        raise RuntimeError('Job(s) for {} did not SUCCEED'.format(failed_tables))

class TableToImport:
    """Describe one exported table and where it goes when imported into Tahoe.

    The constructor derives all names from the table configuration and from values
    rendered in by Terraform (producer name and producer role ARN). Other code in this
    file also assigns an attempts attribute on instances after construction.
    """

    def __init__(self, name, table_config, import_id, producer_api_key):
        """Collect names, columns and bucket locations for the table.

        Args:
            name: Source table name; stored lowercased.
            table_config: Parsed table configuration. Requires 'version' and 'schema';
                'custom_view_sql_def', 'tahoe_view_name' and 'output_fields' are optional.
                The column dicts it holds are modified in place (names are lowercased).
            import_id: Identifier of this import; must be unique, e.g. timestamp or upload ID.
            producer_api_key: API key sent along with Tahoe requests for this table.
        """
        self.name = name.lower()
        self.version = table_config['version']
        self.custom_view_sql_def = table_config.get('custom_view_sql_def')
        self.import_id = import_id
        self.producer_api_key = producer_api_key
        # Without an explicit view name, fall back to the table name with dashes replaced.
        default_view_name = name.replace('-', '_')
        self.tahoe_view_name = table_config.get('tahoe_view_name', default_view_name).lower()

        # Keep the whole schema unless the config restricts the output fields.
        selected_columns = table_config['schema']
        if 'output_fields' in table_config:
            output_fields = table_config['output_fields']
            selected_columns = [column for column in selected_columns if column['name'] in output_fields]
        self.columns = selected_columns

        # Column names are always lowercased.
        for column in self.columns:
            column['name'] = column['name'].lower()

        self.tahoe_schema = '${var.tahoe_producer_name}'
        # The account id is the second-to-last colon-separated field of the role ARN.
        producer_role_arn = '${var.tahoe_producer_role_arn}'
        tahoe_account_id = producer_role_arn.rsplit(':', 2)[1]
        self.tahoe_input_bucket = 'tahoe-input-{}'.format(tahoe_account_id)
        self.tahoe_output_bucket = '${var.tahoe_producer_name}-{}'.format(tahoe_account_id)

    @property
    def tahoe_view_schema(self):
        """Schema that holds the view; carries a "qa_" prefix for QA jobs."""
        view_schema = '${local.is_qa_job ? "qa_" : ""}dbsnapshots'
        return view_schema

    @property
    def tahoe_view_sql_def(self):
        """SELECT statement that defines the Tahoe view for this table.

        A custom definition from the table config always wins; it is formatted with the
        tahoe_schema and tahoe_versioned_name placeholders. Otherwise the view selects
        every column of the table, each name wrapped in double quotes.
        """
        if self.custom_view_sql_def:
            return self.custom_view_sql_def.format(
                tahoe_schema=self.tahoe_schema,
                tahoe_versioned_name=self.tahoe_versioned_name)
        quoted_columns = ', '.join('"{}"'.format(column['name']) for column in self.columns)
        default_definition = """
            SELECT {table_columns}
            FROM {tahoe_schema}.{tahoe_versioned_name}
        """
        return default_definition.format(
            table_columns=quoted_columns,
            tahoe_schema=self.tahoe_schema,
            tahoe_versioned_name=self.tahoe_versioned_name)

    @property
    def tahoe_versioned_name(self):
        """Table name used in Tahoe: dashes replaced, version appended, "_qa" suffix for QA jobs."""
        safe_name = self.name.replace('-', '_')
        return '{}_v{}${local.is_qa_job ? "_qa" : ""}'.format(safe_name, self.version)

    @property
    def tahoe_application(self):
        """Application name used in S3 prefixes; identical to tahoe_schema."""
        return self.tahoe_schema

    @property
    def tahoe_column_definitions(self):
        """Column definitions (name, mapped type, upper-cased sensitivity) for Tahoe.

        A column's 'output_type' takes precedence over its 'type'; sensitivity defaults
        to 'NONE'.
        """
        definitions = []
        for column in self.columns:
            definitions.append({
                'name': column['name'],
                'type': TYPE_MAPPING[column.get('output_type', column['type'])],
                'sensitivity': column.get('sensitivity', 'NONE').upper(),
            })
        return definitions

    @property
    def tahoe_base_prefix(self):
        """S3 key prefix shared by all imports of this table: application/versioned name/."""
        return '{application}/{table}/'.format(
            application=self.tahoe_application,
            table=self.tahoe_versioned_name)

    @property
    def tahoe_import_prefix(self):
        """S3 key prefix of this particular import: the base prefix plus the import id."""
        return '{base}{import_id}/'.format(base=self.tahoe_base_prefix, import_id=self.import_id)

    @property
    def tahoe_input_path(self):
        """Full s3:// path of this import inside the Tahoe input bucket."""
        return 's3://{bucket}/{key}'.format(bucket=self.tahoe_input_bucket, key=self.tahoe_import_prefix)

    @property
    def tahoe_output_path(self):
        """Full s3:// path of this import inside the Tahoe output bucket."""
        return 's3://{bucket}/{key}'.format(bucket=self.tahoe_output_bucket, key=self.tahoe_import_prefix)


def generate_tahoe_request(method, body):
    """Wrap a request body into the payload expected by the Tahoe API Lambda.

    Args:
        method: Name of the Tahoe API method; appended to the service path.
        body: Dict with the request content. It is not modified.

    Returns:
        A dict with the HTTP method, path, headers and JSON-encoded body, plus a
        'log_body' entry: the same JSON with a top-level 'api_key' replaced by
        '<REDACTED>' so the request can be logged safely.
    """
    redacted_body = dict(body)
    if 'api_key' in redacted_body:
        redacted_body['api_key'] = '<REDACTED>'
    request_path = '{}/{}'.format(TAHOE_SERVICE, method)
    return {
        'httpMethod': 'POST',
        'path': request_path,
        'headers': {'Content-Type': 'application/json'},
        'body': json.dumps(body),
        'log_body': json.dumps(redacted_body),
    }

def call_tahoe(session, payload):
    """Invoke the Tahoe API Lambda and return the parsed JSON response.

    Args:
        session: Session whose client('lambda') is used to invoke the function.
        payload: Request dict as built by generate_tahoe_request. It must contain
            'path' and 'log_body' (the redacted body that is safe to log). The dict is
            left unmodified, so the same payload can be sent more than once.

    Returns:
        The JSON-decoded 'body' of the Lambda's response payload.

    Raises:
        KeyError: If payload has no 'log_body' entry.
        RuntimeError: If the invocation status code is 300 or above, the response
            carries a FunctionError, or the response's own statusCode is 300 or above.
    """
    # Build the wire request and its loggable twin as new dicts rather than deleting
    # 'log_body' from the caller's payload; mutating it made a second call with the
    # same payload fail with KeyError.
    request = {key: value for key, value in payload.items() if key != 'log_body'}
    log_payload = dict(request, body=payload['log_body'])
    logging.info('request payload: %s', log_payload)
    lambda_client = session.client('lambda')
    tahoe_lambda = 'ProducerLambdaFunction'
    resp = lambda_client.invoke(FunctionName=tahoe_lambda, Payload=json.dumps(request))
    r_payload = resp['Payload'].read()
    if resp['StatusCode'] >= 300:
        raise RuntimeError(
            'Response code {} from {}: {}'.format(resp["StatusCode"], payload["path"], r_payload))
    r_payload = json.loads(r_payload)
    if resp.get('FunctionError'):
        raise RuntimeError(
            'FunctionError ({}) from {}: {}'.format(resp["FunctionError"], payload["path"], r_payload))
    if r_payload['statusCode'] >= 300:
        raise RuntimeError(
            'Response code {} from {}: {}'.format(r_payload["statusCode"], payload["path"], r_payload))
    logging.info("payload: %s", r_payload)
    return json.loads(r_payload['body'])


def parquet_table_request(session, table):
    """Build the Tahoe request that imports a full table from its Parquet files.

    Args:
        session: Session used to list the table's input prefix in S3.
        table: TableToImport describing the table, its columns and its S3 paths.

    Returns:
        The request payload for the 'RunDag' method running the 'parquet_table' DAG.
    """
    # NOTE(review): the computed size is never used below. The call still lists the
    # input prefix in S3, so it is kept to preserve behavior - confirm whether it can go.
    size = get_s3_size(session.resource('s3'), table.tahoe_input_path)

    api_key = table.producer_api_key
    column_definitions = table.tahoe_column_definitions
    source_path = table.tahoe_input_path
    destination_path = table.tahoe_output_path
    string_parameters = [
        ('S3_SOURCE_PATH', source_path),
        ('S3_DESTINATION_PATH', destination_path),
        # The table path equals the destination: we're doing a full table replacement.
        ('S3_TABLE_PATH', destination_path),
        ('TABLE_NAME', table.tahoe_versioned_name),
        ('TABLE_SCHEMA', table.tahoe_schema),
    ]

    parameters = [{
        'name': 'COLUMN_DEFINITIONS',
        'value_column_definitions': column_definitions,
    }]
    for parameter_name, parameter_value in string_parameters:
        parameters.append({'name': parameter_name, 'value_string': parameter_value})

    return generate_tahoe_request('RunDag', {
        'api_key': api_key,
        'dag_name': 'parquet_table',
        'parameters': parameters,
    })


def get_tahoe_session(session, tahoe_import_id):
    """Create the AWS session used for all later Tahoe API interactions.

    The session assumes the Tahoe producer role rendered in by Terraform. The role
    session name is built from the producer name and the last four characters of
    tahoe_import_id.
    """
    import_suffix = tahoe_import_id[-4:]
    role_session_name = 'producer-${var.tahoe_producer_name}-{}'.format(import_suffix)
    return assume_role(session, '${var.tahoe_producer_role_arn}', role_session_name)


class AssumeRoleProvider(object):
    """Credential provider whose credentials come from assuming a role."""

    # Method name reported on the credentials this provider creates.
    METHOD = 'assume-role'

    def __init__(self, fetcher):
        """Remember the fetcher; its fetch_credentials callable supplies the credentials."""
        self._fetcher = fetcher

    def load(self):
        """Return botocore DeferredRefreshableCredentials refreshed through the fetcher."""
        refresh_using = self._fetcher.fetch_credentials
        return credentials.DeferredRefreshableCredentials(refresh_using, self.METHOD)


def create_client_factory(session, region_name):
    """Return a create_client function bound to session.

    The returned function forwards to session.create_client, except that STS clients
    are always pointed at the regional STS endpoint of region_name (overriding any
    endpoint_url passed by the caller).
    """
    sts_endpoint = 'https://sts.{}.amazonaws.com'.format(region_name)

    def _create_client(service_name, **kwargs):
        """Create a client for service_name, forcing the regional endpoint for STS."""
        if service_name == 'sts':
            kwargs['endpoint_url'] = sts_endpoint
        return session.create_client(service_name, **kwargs)

    return _create_client


def assume_role(session, role_arn, role_session_name):
    """Return a boto3 session whose credentials come from assuming role_arn.

    Args:
        session: boto3 session whose credentials are used to assume the role.
        role_arn: ARN of the role to assume.
        role_session_name: Value passed as RoleSessionName when assuming the role.

    Returns:
        A new boto3 Session in us-west-2 whose only credential provider is an
        AssumeRoleProvider built on a botocore AssumeRoleCredentialFetcher.
    """
    region = 'us-west-2'
    source_session = session._session
    client_creator = create_client_factory(source_session, region)
    source_credentials = source_session.get_credentials()
    fetcher = credentials.AssumeRoleCredentialFetcher(
        client_creator,
        source_credentials,
        role_arn,
        extra_args={'RoleSessionName': role_session_name},
    )

    # Replace the default credential chain of a fresh botocore session with the
    # assume-role provider only.
    role_session = botocore_session.Session()
    resolver = credentials.CredentialResolver([AssumeRoleProvider(fetcher)])
    role_session.register_component('credential_provider', resolver)
    return boto3.Session(botocore_session=role_session, region_name=region)


def migrate_glue_output(glue_session, tahoe_session, tahoe_table, ts):
    """Copy a table's Glue output files into the Tahoe input bucket and sanity-check them.

    Every object under the table's Glue output prefix is copied (the source objects
    are left in place) to the table's Tahoe import prefix. Each copy is retried with
    exponential backoff on botocore ClientError. The copies happen before the checks,
    so files are copied even when a problem is reported.

    Args:
        glue_session: Session used to read the Glue output bucket.
        tahoe_session: Session used to write to the Tahoe input bucket.
        tahoe_table: Table description providing name, tahoe_input_bucket and
            tahoe_import_prefix.
        ts: Export timestamp string used to locate the Glue output prefix.

    Returns:
        True if there is a data problem: no files were found, the number of files does
        not equal the highest part number plus one, the files do not share a single
        uuid, or the total size is zero. False otherwise.

    Raises:
        botocore.exceptions.ClientError: If a copy still fails after all retries.
        ValueError: If a file name has no numeric part number at characters 5-9.
    """
    source_bucket_name = '${local.computed_s3_output_bucket}'
    glue_s3 = glue_session.resource('s3')
    glue_s3_bucket = glue_s3.Bucket(source_bucket_name)
    tahoe_s3 = tahoe_session.resource('s3')
    tahoe_input_bucket = tahoe_s3.Bucket(tahoe_table.tahoe_input_bucket)
    output_prefix = glue_job_output_prefix(tahoe_table.name, ts)
    moved_names = []
    size = 0
    max_retries_per_copy = 4
    for object_summary in glue_s3_bucket.objects.filter(Prefix=output_prefix):
        copy_source = {
            'Bucket': source_bucket_name,
            'Key': object_summary.key,
        }
        parquet_filename = os.path.basename(object_summary.key)
        moved_names.append(parquet_filename)
        size += object_summary.size
        target_path = '{}{}'.format(tahoe_table.tahoe_import_prefix, parquet_filename)
        logging.info('Moving s3://%s/%s into s3://%s/%s', copy_source['Bucket'],
                     copy_source['Key'], tahoe_table.tahoe_input_bucket, target_path)

        # Retry the copy up to max_retries_per_copy times, sleeping 2, 4, 8, ... seconds.
        retry = 0
        while True:
            try:
                tahoe_input_bucket.copy(copy_source, target_path)
                break
            except botocore_exceptions.ClientError:
                retry += 1
                if retry > max_retries_per_copy:
                    raise
                time.sleep(2 ** retry)

    if not moved_names:
        # max() below raises ValueError on an empty list. An export that produced no
        # files is a data problem, so report it instead of crashing the caller.
        logging.error('No files found under s3://%s/%s', source_bucket_name, output_prefix)
        return True

    # Checking that the generated files are sequential, same uuid, and non-zero size to attempt to
    # catch issues with Spark. Example name: part-00000-29885a4d-4be5-484d-b0f3-43b08ce13eb9-c000.snappy.parquet
    highest_part_number = max(int(n[5:10]) for n in moved_names)
    distinct_uuids = {n[11:47] for n in moved_names}
    return not (len(moved_names) == highest_part_number + 1 and
                len(distinct_uuids) == 1 and size > 0)


def trigger_tahoe_import(session, tahoe_table):
    """Send a Parquet import request to Tahoe and return what is needed to track it.

    Returns:
        A dict with the Tahoe job id ('id'), the request that was sent
        ('tahoe_request') and the table being imported ('table').
    """
    request = parquet_table_request(session, tahoe_table)
    job_id = call_tahoe(session, request)['job_id']
    return {
        'id': job_id,
        'tahoe_request': request,
        'table': tahoe_table,
    }


def generate_view_update_payload(tahoe_table):
    """Build the request body that sets the view definition for a Tahoe table.

    The view is identified by the table's view name and view schema, depends on the
    table's versioned relation, and uses the table's view SQL as its definition.
    """
    api_key = tahoe_table.producer_api_key
    view_target = {
        'relation': tahoe_table.tahoe_view_name,
        'schema': tahoe_table.tahoe_view_schema,
    }
    table_dependency = {
        'relation': tahoe_table.tahoe_versioned_name,
        'schema': tahoe_table.tahoe_schema,
    }
    return {
        'api_key': api_key,
        'view': view_target,
        'dependencies': [table_dependency],
        'definition': tahoe_table.tahoe_view_sql_def,
    }


def validate_config(tahoe_session, producer_api_key):
    """Check every table configuration, locally and against the existing Tahoe tables.

    Each table config is validated against the JSON schema. When a Tahoe session is
    given, the column definitions derived from the config are also compared with the
    definitions Tahoe already holds for the same versioned table, so that a schema
    change without a version bump is caught.

    Args:
        tahoe_session: Session for Tahoe API calls, or None to run only the JSON
            schema validation.
        producer_api_key: API key handed to the TableToImport built for each table.

    Raises:
        ValueError: If the number of columns or any column definition differs from
            what Tahoe currently has for the table.
    """
    # The validator makes sure all table configs are set correctly.
    validator = get_json_schema_validator()
    raw_config = get_raw_config()

    for table_name in raw_config:
        table_config = load_table_config(raw_config, table_name)
        validate_table_config_schema(table_name, table_config, validator)
        if tahoe_session is None:
            continue

        table = TableToImport(table_name, table_config, 'validation', producer_api_key)
        job_defs = table.tahoe_column_definitions
        relation_query = generate_tahoe_request('QueryRelation', {
            'target_type': 'TABLE',
            'target': {
                'relation': table.tahoe_versioned_name,
                'schema': table.tahoe_schema,
            }
        })
        response = call_tahoe(tahoe_session, relation_query)
        if not response:
            # If the relation doesn't exist yet, we can define it as whatever.
            logger.info('No table schema yet for %s', table.tahoe_versioned_name)
            continue

        tahoe_defs = response['relations'][0]['column_definitions']
        if len(job_defs) != len(tahoe_defs):
            raise ValueError('Different number of columns between current Tahoe definition and '
                             f'job schema for {table_name}. Please increment the table\'s version.')
        for job_def, tahoe_def in zip(job_defs, tahoe_defs):
            # Fill in the defaults for fields missing from the Tahoe response before comparing.
            tahoe_def.setdefault('type', 'STRING')
            tahoe_def.setdefault('sensitivity', 'NONE')
            if tahoe_def != job_def:
                raise ValueError('Job definition of column does not match Tahoe definition. '
                                 'Please revert the job or increment the table\'s version. '
                                 f'Job: {job_def}. Tahoe: {tahoe_def}')
    logger.info('Existing table schemas look correct')

def should_update_tahoe_view_def(tahoe_session, update_payload):
    """Tell whether the view in Tahoe differs from the local definition.

    The current view is fetched from the Tahoe API. An update is needed when the view
    does not exist yet (empty response), or when its definition (compared stripped and
    case-insensitively), its relation or its dependencies differ from update_payload.

    Returns:
        True if the view definition should be sent to Tahoe, False otherwise.
    """
    target_view = update_payload['view']
    view_query = generate_tahoe_request('QueryRelation', {
        'target_type': 'VIEW',
        'target': {
            'relation': target_view['relation'],
            'schema': target_view['schema'],
        }
    })
    response = call_tahoe(tahoe_session, view_query)
    if not response:
        return True

    current_view = response['relations'][0]
    current_sql = current_view['definition'].strip().lower()
    wanted_sql = update_payload['definition'].strip().lower()
    if current_sql != wanted_sql:
        return True
    if current_view['relation'] != update_payload['view']:
        return True
    return current_view['dependencies'] != update_payload['dependencies']


def get_state_of_tahoe_imports(tahoe_session, current_state):
    """Triggers new Tahoe imports and checks the status of the ones currently running.

    Args:
        tahoe_session: Session used for all Tahoe API calls.
        current_state: Dict with the keys 'running_imports' (table name -> tracking dict
            as returned by trigger_tahoe_import), 'tables_to_import' (table name ->
            TableToImport) and 'failed_tables' (list of table names).

    Returns:
        A deep copy of current_state updated with the results of this round; the
        current_state passed in is not modified.

    Raises:
        RuntimeError: If Tahoe does not return exactly one job for a queried job id.
    """
    # Work on a deep copy so the caller's state stays untouched.
    new_state = copy.deepcopy(current_state)
    running_imports = new_state['running_imports']
    tables_to_import = new_state['tables_to_import']
    failed_tables = new_state['failed_tables']

    logging.info('Checking %d Tahoe import job statuses', len(running_imports))
    # Iterate over a snapshot because entries are deleted from running_imports in the loop.
    for table_name, job in list(running_imports.items()):
        logging.info('Requesting status of Tahoe job with id %s', job['id'])
        job_request = generate_tahoe_request('QueryJobs', {'job_id': job['id']})
        resp_jobs = call_tahoe(tahoe_session, job_request)['jobs']
        if len(resp_jobs) != 1:
            raise RuntimeError('No jobs returned from Tahoe for id {}'.format(job['id']))
        resp_job = resp_jobs[0]
        # Any other status is treated as "still in progress"; check again next round.
        if resp_job['status'] not in ['CANCELLED', 'FAILED', 'SUCCESS']:
            continue
        if resp_job['status'] == 'SUCCESS':
            logging.info('Tahoe import with id %s succeeded!', job['id'])

            # A successful import means the table is ready to be used as a dependency for the view
            update_payload = generate_view_update_payload(job['table'])
            if should_update_tahoe_view_def(tahoe_session, update_payload):
                logging.info('Updating view definition of %s', table_name)

                # On success the response body will be empty so no need to capture it
                call_tahoe(tahoe_session, generate_tahoe_request('PutView', update_payload))
            else:
                logging.info('View definition of %s requires no updating. Skipping.', table_name)
        else:
            # CANCELLED or FAILED: queue the table again while it has attempts left.
            if job['table'].attempts > 0:
                job['table'].attempts -= 1
                tables_to_import[table_name] = job['table']
                logging.info('Retrying Tahoe import for table %s', table_name)
            else:
                failed_tables.append(table_name)
                logging.error('Out of Tahoe import attempts for %s', table_name)
        del running_imports[table_name]

    logging.info('%d unstarted Tahoe import jobs', len(tables_to_import))
    # Iterate over a snapshot because entries are deleted from tables_to_import in the loop.
    for table_name, config in list(tables_to_import.items()):
        # Don't try to run more Tahoe imports than allowed.
        if len(running_imports) >= ${var.tahoe_max_concurrent_imports}:
            logging.info('We have enough Tahoe jobs running already')
            break

        logging.info("Triggering Tahoe import request for table %s. Attempts left: %s",
                     table_name, config.attempts)
        running_imports[table_name] = trigger_tahoe_import(tahoe_session, config)
        del tables_to_import[table_name]

    return new_state


def should_import_table(table_config):
    """Tell whether a table's export should be imported into Tahoe.

    Returns:
        True if a Tahoe producer name is configured and the table's 'output_format'
        (default 'parquet') is 'parquet'; False for any other format. Without a
        producer name the empty name itself is returned, which is falsy.
    """
    producer_name = '${var.tahoe_producer_name}'
    if not producer_name:
        return producer_name
    return table_config.get('output_format', 'parquet') == 'parquet'


def patched_glue_client(session):
    """Return a Glue client whose service model knows the most recent StartJobRun API.

    The Glue service model is loaded through the session's data loader and its
    operations and shapes are updated with the patches rendered in by Terraform before
    the client is created.
    NOTE(review): this presumably relies on the loader handing the same (patched) model
    to the client created afterwards - confirm.
    """
    data_loader = session._session.get_component('data_loader')
    service_model = data_loader.load_service_model('glue', 'service-2')
    service_model['operations'].update(
        json.loads("""${data.template_file.patched_glue_operations.rendered}"""))
    service_model['shapes'].update(
        json.loads("""${data.template_file.patched_glue_shapes.rendered}"""))
    return session.client('glue')


#####################################
#### END: Tahoe API related code ####
#####################################

END_OF_STRING
}
