from typing import Any, Dict
import json
import logging
from airflow.settings import conf
from zenyatta.aws.s3 import get_base_s3_path, upload_file_to_s3, check_s3_object_exist
from zenyatta.common import get_airflow_connection
from zenyatta.common.errors import ZenyattaError
from zenyatta.common.util import retrieve_http_response, get_aws_resource_variable


def check_table_partition(table: str, conn_id: str) -> Any:
    """
    Look up the stored partition spec for ``table``.

    Fetches ``get_aws_resource_variable('<conn_id>-pitr-table-partition', 'partitions')``,
    presumably a dict keyed by table name (TODO confirm), and returns the entry
    for ``table``. make_spark_table_config reads the 'col', 'min' and 'max'
    fields of a returned spec.

    :param table: table name
    :param conn_id: connection name
    :return: the spec stored for ``table``, or None when there is no entry
        (per the original note, such tables lack a qualified primary key)
    """
    variable_key = conn_id + '-pitr-table-partition'
    partitions = get_aws_resource_variable(variable_key, 'partitions')
    if table not in partitions:
        return None
    return partitions[table]


def make_spark_table_config(table: str, conn_id: str, rds_pitr_host: str,
                            ts_nodash: str, s3_meta: Dict) -> Dict:
    """
    Build the configuration used to fill the Spark script template for one table.

    :param table: table name to copy
    :param conn_id: Airflow connection id; its conn_type, port, schema, login
        and password describe the source database
    :param rds_pitr_host: host name of the RDS PITR instance to read from
    :param ts_nodash: timestamp string used in the S3 paths and the script name
    :param s3_meta: S3 settings; must contain 's3_key' and 'bucket'
    :return: dict with keys conn_type, jdbc_url, driver_class, user, password,
        table, object_prefix, s3_pq_path, pq_success_path, script_name,
        local_script_path, s3_script_path, ts_nodash and partitionFlag; when a
        partition spec exists for the table also partitionCol, lowerBound,
        upperBound and partitionNum
    """
    # Partition columns whose max value exceeds this get more JDBC read partitions.
    large_table_threshold = 1000000000

    connection = get_airflow_connection(conn_id)
    conn_type = connection.conn_type.lower()
    is_mysql = conn_type == 'mysql'
    # Any connection type other than mysql is treated as PostgreSQL.
    jdbc_prefix = 'mysql' if is_mysql else 'postgresql'

    # The base path depends only on the arguments, so compute it once and
    # derive every parquet location from the same prefix.
    base_path = get_base_s3_path(s3_meta['s3_key'], ts_nodash, conn_id)
    pq_prefix = '{}/pq/{}/'.format(base_path, table)
    script_name = 'spark-{}-{}-{}.py'.format(conn_id, table, ts_nodash)

    spark_app_config = {
        'conn_type': conn_type,
        'jdbc_url': 'jdbc:{}://{}:{}/{}'.format(jdbc_prefix, rds_pitr_host,
                                                connection.port, connection.schema),
        'driver_class': 'com.mysql.jdbc.Driver' if is_mysql else 'org.postgresql.Driver',
        'user': connection.login,
        'password': connection.get_password(),
        'table': table,
        'object_prefix': pq_prefix,
        's3_pq_path': 's3://{}/{}'.format(s3_meta['bucket'], pq_prefix),
        'pq_success_path': pq_prefix + '_SUCCESS',
        'script_name': script_name,
        'local_script_path': conf.get('spark', 'scripts_dir') + '/' + script_name,
        's3_script_path': 'zenyatta/spark/scripts/' + script_name,
        'ts_nodash': ts_nodash,
    }

    partition = check_table_partition(table, conn_id)
    spark_app_config['partitionFlag'] = partition is not None
    if partition is not None:
        spark_app_config['partitionCol'] = partition['col']
        spark_app_config['lowerBound'] = partition['min']
        spark_app_config['upperBound'] = partition['max']
        spark_app_config['partitionNum'] = 128 if partition['max'] > large_table_threshold else 32
    return spark_app_config


def _render_spark_script(app_conf: Dict) -> str:
    """Return PySpark source that reads app_conf['table'] over JDBC and saves it as parquet."""
    # Values that end up inside string literals of the generated code are
    # rendered with repr(). Quotes or backslashes (e.g. in a password) then can
    # neither break the script nor change the value Spark receives. str() comes
    # first so non-string values (such as a None password) print as before.
    fields = dict(app_conf)
    for key in ('jdbc_url', 'driver_class', 'user', 'password', 'table',
                'partitionCol', 's3_pq_path'):
        if key in fields:
            fields[key] = repr(str(fields[key]))
    fields['app_name'] = repr('PsqlToParquet_{}_{}'.format(app_conf['table'],
                                                          app_conf['ts_nodash']))

    header = '''
from __future__ import print_function
from pyspark import SparkContext
from pyspark.sql import DataFrameReader, DataFrameWriter, SQLContext, SparkSession

spark = SparkSession.builder.appName({app_name}).getOrCreate()
sc = spark.sparkContext
sqlContext = SQLContext(sc)

'''
    if app_conf['partitionFlag']:
        # Parallel read split on partitionCol between the stored min/max bounds.
        reader = '''
df = sqlContext.read.format("jdbc").options(url={jdbc_url}, driver={driver_class}, user={user},
                                            password={password}, dbtable={table},
                                            partitionColumn={partitionCol}, lowerBound={lowerBound},
                                            upperBound={upperBound}, numPartitions={partitionNum},
                                            fetchsize=1000).load()
'''
    else:
        reader = '''
df = sqlContext.read.format("jdbc").options(url={jdbc_url}, driver={driver_class}, user={user},
                                            password={password}, dbtable={table},
                                            fetchsize=1000).load()
'''
    footer = '''
df.write.mode("overwrite").format("parquet").save({s3_pq_path})

spark.stop()
'''
    return (header + reader + footer).format(**fields)


def generate_spark_script(app_conf: Dict, s3_meta: Dict) -> bool:
    """
    Render the Spark job script described by ``app_conf``, save it locally and upload it to S3.

    The script is written to ``app_conf['local_script_path']`` and uploaded to
    ``app_conf['s3_script_path']`` in bucket ``s3_meta['bucket']`` using
    ``s3_meta['role_arn']``.

    NOTE(review): the rendered script embeds the database password in plain
    text and is uploaded to S3 -- confirm that location is access-restricted.

    :param app_conf: spark configuration as built by make_spark_table_config
    :param s3_meta: S3 settings; must contain 'bucket' and 'role_arn'
    :return: the result of check_s3_object_exist for the uploaded script, or
        False if any step raised
    """
    try:
        script = _render_spark_script(app_conf)
        # ``with`` guarantees the file is closed even if write() fails.
        with open(app_conf['local_script_path'], 'w', encoding='utf-8') as output:
            output.write(script)
        upload_file_to_s3(s3_meta['bucket'], app_conf['local_script_path'],
                          app_conf['s3_script_path'], s3_meta['role_arn'])
        return check_s3_object_exist(s3_meta['bucket'], s3_meta['role_arn'],
                                     app_conf['s3_script_path'])
    except Exception:  # pylint: disable=broad-except
        # Best-effort step: still never raises, but log the traceback at error
        # level and honour the ``bool`` return contract instead of returning None.
        logging.exception("failed to generate or upload spark script")
        return False


def get_yarn_app_by_name(dns: str, app_name: str) -> Dict:
    """
    Return the most recently started YARN application named ``app_name``.

    Queries ``http://<dns>:8088/ws/v1/cluster/apps/``.

    :param dns: host name of the ResourceManager
    :param app_name: exact application name to match
    :return: the matching app dict with the largest 'startedTime', or None if
        the request returned nothing or no application has that name
    """
    url = "http://{0}:8088/ws/v1/cluster/apps/".format(dns)
    data = retrieve_http_response(url)
    if data is None:
        logging.info("unexpected error. could not retrieve app detail.")
        return None
    # The ResourceManager can answer {"apps": null} when it has no
    # applications; treat that as an empty list instead of indexing into None.
    app_list = (data.get('apps') or {}).get('app') or []
    matches = [app for app in app_list if app['name'] == app_name]
    if not matches:
        return None
    # max() returns the first of several equally recent apps, the same element
    # a stable descending sort would have placed first.
    return max(matches, key=lambda app: app['startedTime'])


def get_yarn_app_by_id(dns: str, app_id: str) -> Dict:
    """
    Fetch one YARN application's details from ``http://<dns>:8088/ws/v1/cluster/apps/<app_id>``.

    :param dns: host name of the ResourceManager
    :param app_id: YARN application id
    :return: the 'app' entry of the response, or None when retrieve_http_response returned None
    """
    response = retrieve_http_response(
        "http://{0}:8088/ws/v1/cluster/apps/{1}".format(dns, app_id))
    if response is None:
        return None
    return response['app']
