#!/usr/bin/env python2
from pkg_resources import resource_filename, safe_name, to_filename
from distutils.spawn import find_executable
import subprocess
import os.path
import click
import glob
import paramiko
import select
import sys
import termios
import tty
import socket

# Names of the environment variables that click consults as fallbacks for the
# --keyname, --sshkey and --dir options respectively.
KEYPAIR_ENVVAR = "SPARK_KEYPAIR"
SSHKEY_ENVVAR = "SPARK_SSH_KEY"
STATEFILE_DIR_ENVVAR = "SPARK_STATEFILES"

# Terraform configuration shipped as package data. Every terraform subprocess
# is run with TERRAFORM_DIR as its working directory.
TERRAFORM_FILEPATH = resource_filename(__name__, "terraform/sparkcluster.tf")
TERRAFORM_DIR = os.path.dirname(TERRAFORM_FILEPATH)
PROVISION_SCRIPT_DIR = resource_filename(__name__, "terraform/provision")

# Directory holding the per-cluster .tfstate files; the top-level command
# overwrites this with the value of --dir.
statefile_dir = os.getcwd()

# Directory on the master node that job jars are uploaded to over SFTP.
REMOTE_JOBS_DIR = "/var/opt/spark/jobs"

# When true, debug() prints its messages (in DEBUG_COLOR); the top-level
# command sets this from the --quiet flag.
debug_mode = False
DEBUG_COLOR = "yellow"


def validate_env(ctx):
    """Fail the invocation unless the AWS credential variables are set.

    AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are looked up, in that order,
    in the process environment; ``ctx.fail`` is called for each one that is
    absent. A variable that is present but empty counts as set.
    """
    for required in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"):
        if required not in os.environ:
            ctx.fail("the %s env var must be set" % required)

def debug(msg):
    """ write msg to stderr in the debug colour, but only in debug mode """
    if not debug_mode:
        return
    click.secho(msg, err=True, fg=DEBUG_COLOR)


def log(msg):
    """ write msg to stderr in green, whether or not debug mode is on """
    style = {"fg": "green", "err": True}
    click.secho(msg, **style)


def terraform_exec(command, capture_output=False):
    """Execute a terraform command from the terraform config directory.

    command is the argv sequence to run (e.g. ['terraform', 'apply', ...]).
    The subprocess is started with TERRAFORM_DIR as its working directory.

    If capture_output is True, the command's stdout is returned and
    subprocess.CalledProcessError is raised on a non-zero exit status.
    Otherwise the exit code of the subprocess is returned; its stdout is
    shown only when debug_mode is enabled and discarded if not.
    """
    debug("moving to %s" % TERRAFORM_DIR)
    # one-element tuple so a tuple-valued command is formatted, not unpacked
    debug("calling command %s" % (command,))
    if capture_output:
        return subprocess.check_output(command, cwd=TERRAFORM_DIR)
    if debug_mode:
        # stdout=None lets the child write straight to our stdout
        return subprocess.call(command, cwd=TERRAFORM_DIR, stdout=None)
    # close the devnull handle once the child has finished instead of
    # leaking one file descriptor per invocation
    with open(os.devnull, 'wb') as devnull:
        return subprocess.call(command, cwd=TERRAFORM_DIR, stdout=devnull)


def statefile_path(name):
    """ builds the statefile path for a cluster: the name is escaped into a
    filename-safe form and given a .tfstate suffix inside statefile_dir """
    filename = '%s.tfstate' % to_filename(safe_name(name))
    return os.path.join(statefile_dir, filename)


def cluster_exists(name):
    """ tells whether the statefile for the given cluster name is on disk """
    path = statefile_path(name)
    return os.path.exists(path)


def assert_exists(ctx, name):
    """ calls ctx.fail, naming the path that was checked, when the given
    cluster has no statefile; does nothing otherwise """
    if cluster_exists(name):
        return
    checked = click.format_filename(statefile_path(name))
    ctx.fail("unable to find a terraform statefile describing the cluster '%s' (checked %s)" %
             (name, checked))


@click.group()
@click.option(
    "--dir",
    type=click.Path(exists=True, file_okay=False, dir_okay=True, writable=True),
    default=statefile_dir,
    envvar=STATEFILE_DIR_ENVVAR,
    help="directory to store state files (default is current working directory)",
)
@click.option("-q", "--quiet", is_flag=True, default=False, help="silence debug output")
@click.pass_context
def cli(ctx, dir, quiet):
    # Root command group: records the global settings (statefile directory and
    # debug flag) in module state, then checks the AWS credentials.
    # Deliberately no docstring: click would show it as the group's help text.
    global statefile_dir, debug_mode
    debug_mode = not quiet
    if dir is not None:
        statefile_dir = dir
    validate_env(ctx)


@cli.group()
@click.argument("name")
@click.option("-i", "--sshkey",
              type=click.Path(exists=True, dir_okay=False),
              envvar=SSHKEY_ENVVAR, required=True,
              help="SSH PEM key to use to connect to instances for provisioning. "
              "Will check the %s environment variable if left unset" % SSHKEY_ENVVAR)
@click.pass_context
def slaves(ctx, name, sshkey):
    # Sub-group for per-slave commands: verifies the cluster has a statefile
    # and stashes the cluster name and key path on ctx.obj for subcommands.
    # Deliberately no docstring: click would show it as the group's help text.
    if not cluster_exists(name):
        ctx.fail("%s does not exist!" % name)
    ctx.obj = {'name': name, 'sshkey': sshkey}

@slaves.command(help="check disk usage on workspace mounts for all slaves")
@click.pass_context
def disk(ctx):
    # Runs `df -h` on the scratch mounts of every slave listed in the
    # cluster's statefile, one SSH connection per slave.
    name = ctx.obj['name']
    key = ctx.obj['sshkey']

    # read_from_statefile returns None when terraform fails and may return an
    # empty string; neither may reach .split() / connect().
    raw_ips = read_from_statefile(statefile_path(name), 'slave-ips')
    slave_ips = [ip.strip() for ip in (raw_ips or "").split('\n') if ip.strip()]
    if not slave_ips:
        ctx.fail("%s has no active slaves" % name)

    for ip in slave_ips:
        click.echo("Disk usage for %s:" % ip)
        client = connect(ip, key)
        try:
            # a pty is requested (term=True), as in the original call
            remote_exec(client, "df -h /var/opt/spark/scratch/mnt*", True)
        finally:
            # always release the connection, even if the remote command fails
            client.close()
    
@cli.command(help="launch a new spark cluster")
@click.argument("name")
@click.option("-n", "--num-slaves", type=click.IntRange(0, 200), default=2,
              help="number of slave nodes to spin up (between 0 and 200)")
@click.option("-k", "--keyname", type=str, envvar=KEYPAIR_ENVVAR, required=True,
              help="name of the AWS keypair to use when starting instances - "
              "will check the %s environment variable if left unset" % KEYPAIR_ENVVAR)
@click.option("-i", "--sshkey", type=click.Path(exists=True, dir_okay=False),
              envvar=SSHKEY_ENVVAR, required=True,
              help="SSH PEM key to use to connect to instances for provisioning. "
              "Will check the %s environment variable if left unset" % SSHKEY_ENVVAR)
@click.pass_context
def create(ctx, name, num_slaves, keyname, sshkey):
    # Refuses to run when a statefile for this name already exists; otherwise
    # runs `terraform apply` and exits with terraform's exit code.
    if cluster_exists(name):
        ctx.fail("the cluster %s already exists!" % name)
    log("starting cluster %s" % name)

    state = statefile_path(name)
    variables = [
        'config.script_dir=%s' % PROVISION_SCRIPT_DIR,
        'config.cluster_name=%s' % name,
        'config.num_slaves=%d' % num_slaves,
        'config.keypair_name=%s' % keyname,
        'config.sshkey=%s' % sshkey,
    ]
    # the same file is read as input state and written as output state
    command = ['terraform', 'apply', '-state', state, '-state-out', state]
    for assignment in variables:
        command.extend(['-var', assignment])
    ctx.exit(terraform_exec(command))


@cli.command(help="print terraform's plan for the given inputs")
@click.argument("name")
@click.option("-n", "--num-slaves", default=2, type=click.IntRange(0, 200),
              help="number of slave nodes to spin up (between 0 and 200)")
@click.pass_context
def plan(ctx, name, num_slaves):
    # Prints the output of `terraform plan` for the given inputs. The
    # cluster's statefile is passed along only when it already exists, so
    # plans for not-yet-created clusters work too.
    command = [
        'terraform', 'plan',
        '-var', 'config.script_dir=%s' % PROVISION_SCRIPT_DIR,
        '-var', 'config.cluster_name=%s' % name,
        '-var', 'config.num_slaves=%d' % num_slaves,
    ]
    if cluster_exists(name):
        command += ['-state', statefile_path(name)]
    try:
        output = terraform_exec(command, capture_output=True)
    except subprocess.CalledProcessError as err:
        # terraform exited non-zero: show whatever it wrote to stdout and
        # exit with its status instead of dying with a traceback
        if err.output:
            click.echo(err.output, err=True)
        ctx.exit(err.returncode)
    click.echo(output)
    ctx.exit(0)


@cli.command(help="resize an existing spark cluster")
@click.argument("name")
@click.option("-n", "--num-slaves", type=click.IntRange(0, 200), default=2,
              help="number of slave nodes to spin up (between 0 and 200)")
@click.option("-k", "--keyname", type=str, envvar=KEYPAIR_ENVVAR, required=True,
              help="name of the AWS keypair to use when starting instances - "
              "will check the %s environment variable if left unset" % KEYPAIR_ENVVAR)
@click.option("-i", "--sshkey", type=click.Path(exists=True, dir_okay=False),
              envvar=SSHKEY_ENVVAR, required=True,
              help="SSH PEM key to use to connect to instances for provisioning. "
              "Will check the %s environment variable if left unset" % SSHKEY_ENVVAR)
@click.pass_context
def resize(ctx, name, num_slaves, keyname, sshkey):
    # Re-applies the terraform configuration for an existing cluster with a
    # new slave count, then exits with terraform's exit code.
    assert_exists(ctx, name)
    log("resizing cluster %s" % name)

    state_path = statefile_path(name)
    state_args = ['-state', state_path, '-state-out', state_path]
    var_args = [
        '-var', 'config.script_dir=%s' % PROVISION_SCRIPT_DIR,
        '-var', 'config.cluster_name=%s' % name,
        '-var', 'config.num_slaves=%d' % num_slaves,
        '-var', 'config.keypair_name=%s' % keyname,
        '-var', 'config.sshkey=%s' % sshkey,
    ]
    exit_code = terraform_exec(['terraform', 'apply'] + state_args + var_args)
    ctx.exit(exit_code)


@cli.command(help="get the URI of the spark master for a cluster and print it to stdout")
@click.argument("name")
@click.pass_context
def whereis(ctx, name):
    # get_master_uri goes through get_master_ip, which already fails the
    # command when the cluster has no statefile, so no separate check here.
    # click.echo replaces the Python-2-only print statement and matches how
    # the other commands write to stdout.
    click.echo(get_master_uri(ctx, name))
    ctx.exit(0)


@cli.command(help="open the web UI administration page for this cluster in the default web browser")
@click.argument("name")
@click.pass_context
def webui(ctx, name):
    # Launches port 8080 of the master in a browser and exits with the
    # value returned by click.launch.
    master_ip = get_master_ip(ctx, name)
    ctx.exit(click.launch("http://%s:8080" % master_ip))


@cli.command(help="destroy a spark cluster")
@click.argument("name")
@click.option("-d", "--delete", is_flag=True, default=False,
              help="delete the associated tfstate file after bringing down the cluster")
@click.pass_context
def destroy(ctx, name, delete):
    # Runs `terraform destroy` against the cluster's statefile and exits with
    # terraform's exit code. The statefile is removed only when --delete was
    # given and terraform succeeded.
    log("destroying cluster %s" % name)
    assert_exists(ctx, name)

    state = statefile_path(name)
    code = terraform_exec(['terraform', 'destroy', '-state', state])
    if code == 0 and delete:
        log("deleting statefile %s" % state)
        os.remove(state)
    ctx.exit(code)


def get_master_ip(ctx, name):
    """ get IP address of master of given cluster

    Fails the command via ctx.fail when the cluster has no statefile, or
    when the 'master-ip' output is empty or could not be read.
    """
    assert_exists(ctx, name)
    ip = read_from_statefile(statefile_path(name), 'master-ip')
    # read_from_statefile returns None when `terraform output` fails; treat
    # that like an empty value rather than handing None back to callers that
    # would format it into a URL such as "spark://None:7077"
    if not ip:
        ctx.fail("%s has no active master" % name)
    return ip


def get_master_uri(ctx, name):
    """ build the spark:// URI (port 7077) for the master of the given cluster """
    return "spark://%s:7077" % get_master_ip(ctx, name)


def spark_shell_exists():
    """ returns true if find_executable can locate a spark-shell binary """
    located = find_executable("spark-shell")
    return located is not None


# SSH wrappers

class IgnoreHostKeys(paramiko.client.MissingHostKeyPolicy):
    """Missing-host-key policy whose handler does nothing.

    NOTE(review): the handler neither records nor rejects the unknown key,
    so hosts are not authenticated by their keys - confirm this is
    acceptable for the networks this tool is used on.
    """

    def missing_host_key(self, client, hostname, key):
        # intentionally a no-op
        return None


def posix_shell(chan):
    """Pump bytes between the local terminal and the channel chan.

    Data received from chan is written to sys.stdout; data read from
    sys.stdin is sent to chan. The loop ends when either side yields an
    empty read. The terminal attributes of sys.stdin are saved on entry and
    restored on exit, including when an exception is raised.

    NOTE(review): termios.tcgetattr is called on sys.stdin unconditionally,
    so this presumably requires stdin to be a terminal - confirm before
    using it with redirected input.
    """
    # remember the current terminal settings so they can be put back
    oldtty = termios.tcgetattr(sys.stdin)
    try:
        # switch stdin to raw, then cbreak, mode for the session
        tty.setraw(sys.stdin.fileno())
        tty.setcbreak(sys.stdin.fileno())
        # zero timeout: recv below must never block, select does the waiting
        chan.settimeout(0.0)

        while True:
            # block until the channel or stdin has something to read
            r, w, e = select.select([chan, sys.stdin], [], [])
            if chan in r:
                try:
                    x = chan.recv(1024)
                    if len(x) == 0:
                        break # eof
                    sys.stdout.write(x)
                    sys.stdout.flush()
                except socket.timeout:
                    # nothing was available after all; go back to select
                    pass
            if sys.stdin in r:
                # forward local input one character at a time
                x = sys.stdin.read(1)
                if len(x) == 0:
                    break
                chan.send(x)
    finally:
        # restore the saved terminal settings
        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, oldtty)

def connect(ip, sshkey):
    """ open a compressed SSH connection to ip as user 'ubuntu', authenticating
    with the key file sshkey, and return the connected paramiko client;
    unknown host keys are handled by the IgnoreHostKeys policy """
    ssh_client = paramiko.SSHClient()
    ssh_client.set_missing_host_key_policy(IgnoreHostKeys())
    ssh_client.connect(hostname=ip, username='ubuntu',
                       key_filename=sshkey, compress=True)
    return ssh_client
    
        
def connect_to_master(ctx, name, sshkey):
    """ look up the master IP of the given cluster and return an SSH client
    connected to it """
    return connect(get_master_ip(ctx, name), sshkey)

@cli.command(help="open an ssh session on the master node")
@click.argument("name")
@click.option("-i", "--sshkey", type=click.Path(exists=True, dir_okay=False),
              required=True, envvar=SSHKEY_ENVVAR,
              help="SSH PEM key to use to connect to instances for provisioning. "
              "Will check the %s environment variable if left unset" % SSHKEY_ENVVAR)
@click.pass_context
def ssh(ctx, name, sshkey):
    # Opens an interactive shell on the master and bridges it to the local
    # terminal until either side closes. The nested try/finally blocks make
    # sure the channel and the connection are released even when the
    # interactive loop raises.
    client = connect_to_master(ctx, name, sshkey)
    try:
        channel = client.invoke_shell()
        try:
            posix_shell(channel)
        finally:
            channel.close()
    finally:
        client.close()


@cli.command(help="execute a command on the master node")
@click.argument("name")
@click.argument("command", nargs=-1, type=click.UNPROCESSED)
@click.option("-i", "--sshkey", type=click.Path(exists=True, dir_okay=False),
              required=True, envvar=SSHKEY_ENVVAR,
              help="SSH PEM key to use to connect to instances for provisioning. "
              "Will check the %s environment variable if left unset" % SSHKEY_ENVVAR)
@click.option("-t", "--term", help="allocate a pseudoterminal for the session",
              is_flag=True, default=False)
@click.pass_context
def sshexec(ctx, name, command, sshkey, term):
    # Runs a single command on the master. The words of `command` are joined
    # with spaces into the remote command line.
    client = connect_to_master(ctx, name, sshkey)
    try:
        code = remote_exec(client, " ".join(command), term)
    finally:
        # release the connection even when the remote execution raises
        client.close()
    # propagate the remote command's exit status, as `submit` does, instead
    # of always exiting 0
    ctx.exit(code)


def remote_exec(sshclient, command, term):
    """ run command in a new session on sshclient, bridging it to the local
    terminal through posix_shell

    When term is true a pty is requested first, using the local TERM
    setting ('vt100' if unset). Returns the command's exit status if one is
    available once the session loop ends, and 1 otherwise.
    """
    session = sshclient.get_transport().open_session()
    if term:
        term_type = os.environ.get("TERM", "vt100")
        debug("getting pty with term %s" % term_type)
        session.get_pty(term_type)
    debug("executing remote command %s" % command)
    session.exec_command(command)
    posix_shell(session)
    exit_code = session.recv_exit_status() if session.exit_status_ready() else 1
    session.close()
    return exit_code


@cli.command(help="submit a job to a cluster")
@click.argument("name")
@click.argument("jar", type=click.Path(exists=True, dir_okay=False))
@click.argument("classname")
@click.option('--include', multiple=True, type=click.Path(exists=True, dir_okay=False), help="extra jars to include")
@click.option("-i", "--sshkey", type=click.Path(exists=True, dir_okay=False),
              required=True, envvar=SSHKEY_ENVVAR,
              help="SSH PEM key to use to connect to instances for provisioning. "
              "Will check the %s environment variable if left unset" % SSHKEY_ENVVAR)
@click.argument("job-args", nargs=-1, type=click.UNPROCESSED)
@click.pass_context
def submit(ctx, name, jar, classname, include, sshkey, job_args):
    # We need to SSH onto the master and submit the job from
    # there. The job-submitter is, in Spark terms, the 'driver' of the
    # Spark job, and the driver needs to be addressable from the
    # workers, so we need to be on the actual cluster for this to
    # work.
    jar = os.path.expanduser(jar)
    client = connect_to_master(ctx, name, sshkey)

    try:
        # upload the job jar and every --include jar before launching
        upload_jar(client, jar)
        for j in include:
            upload_dependency(client, os.path.expanduser(j))
        master_uri = get_master_uri(ctx, name)
        code = run_job(client, master_uri, jar, classname, include, *job_args)
    finally:
        client.close()
    # exit with the job's status; the former trailing ctx.exit(0) could never
    # run because ctx.exit inside the try block always raises
    ctx.exit(code)


def upload_dependency(sshclient, jarpath):
    """ copy the local dependency jar at jarpath into REMOTE_JOBS_DIR over
    SFTP, keeping its base filename; the SFTP session is always closed """
    debug("uploading dependency jar from %s" % jarpath)
    session = sshclient.open_sftp()
    try:
        destination = os.path.join(REMOTE_JOBS_DIR, os.path.basename(jarpath))
        debug("uploading to %s" % destination)
        session.put(jarpath, destination)
    finally:
        session.close()

def upload_jar(sshclient, jarpath):
    """ copy the local job jar at jarpath into REMOTE_JOBS_DIR over SFTP,
    keeping its base filename; the SFTP session is always closed """
    debug("uploading jar from %s" % jarpath)
    session = sshclient.open_sftp()
    try:
        filename = os.path.basename(jarpath)
        destination = os.path.join(REMOTE_JOBS_DIR, filename)
        debug("uploading to %s" % destination)
        session.put(jarpath, destination)
    finally:
        session.close()

def run_job(sshclient, master_uri, jar, classname, include, *extra_args):
    """ run spark-submit on the remote host for an already-uploaded jar

    sshclient  -- connected SSH client to execute the command on
    master_uri -- value passed to --master
    jar        -- local path of the job jar; only its base filename is used,
                  resolved inside REMOTE_JOBS_DIR
    classname  -- value passed to --class
    include    -- local paths of extra jars, passed to --jars by their
                  remote locations
    extra_args -- arguments handed to the job, each wrapped in double quotes

    Returns the exit code reported by remote_exec.
    """
    remote_jar = os.path.join(REMOTE_JOBS_DIR, os.path.basename(jar))

    # Dependencies are uploaded under their base filename, so the remote path
    # must be built from the basename too. Joining the local path as-is
    # pointed at a non-existent sub-directory for relative paths, and for
    # absolute paths os.path.join dropped REMOTE_JOBS_DIR altogether.
    jars = ",".join(
        os.path.join(REMOTE_JOBS_DIR, os.path.basename(j)) for j in include
    )

    # wrap each arg in double quotes so whitespace survives the remote shell
    # NOTE(review): quotes embedded in an argument are not escaped
    args = " ".join('"{0}"'.format(x) for x in extra_args)

    cmd = '/opt/spark/bin/spark-submit --master "{master}" --class {classname} --executor-memory 8G '.format(
        master=master_uri,
        classname=classname,
    )
    if jars:
        cmd += ' --jars {jars}'.format(jars=jars)
    cmd += ' {remote_jar} {args}'.format(remote_jar=remote_jar, args=args)

    debug("executing remote command: {}".format(cmd))
    return remote_exec(sshclient, cmd, term=True)


@cli.command(help="connect to a spark shell")
@click.argument("name")
@click.argument("spark-shell-args", nargs=-1, type=click.UNPROCESSED)
@click.pass_context
def shell(ctx, name, spark_shell_args):
    # Starts a local spark-shell pointed at the cluster's master, forwarding
    # any extra arguments, and exits with the shell's exit code.
    if not spark_shell_exists():
        ctx.fail("spark-shell binary couldn't be found. Are you sure spark installed?")
    master_uri = get_master_uri(ctx, name)
    command = ["spark-shell", "--master", master_uri] + list(spark_shell_args)
    ctx.exit(subprocess.call(command))


@cli.command(name="list", help="list clusters with statefiles to stdout")
@click.pass_context
def list_clusters(ctx):
    # Prints one tab-separated row per statefile found in statefile_dir,
    # green for clusters that are up and red otherwise.
    # Debug output is switched off so the table is not interleaved with it.
    global debug_mode
    debug_mode = False

    row_format = "{state}\t{name}\t{master_ip}\t{num_slaves}\t{keypair}"
    pattern = os.path.join(statefile_dir, "*.tfstate")
    found = glob.glob(pattern)
    click.secho("state\tname\tmaster-ip\tnum-slaves\tkeypair", fg="green")
    for statefile in found:
        info = describe_cluster(statefile)
        if info["state"] == "UP":
            row_color = "green"
        else:
            row_color = "red"
        click.secho(row_format.format(**info), fg=row_color)
    ctx.exit(0)
    

def read_from_statefile(f, val):
    """ fetch the terraform output named val from the statefile f

    Returns the captured output with surrounding whitespace stripped, or
    None when the terraform command exits with an error.
    """
    command = ['terraform', 'output', '-state', f, val]
    try:
        raw = terraform_exec(command, capture_output=True)
    except subprocess.CalledProcessError:
        return None
    return raw.strip()


def describe_cluster(statefile):
    """ collect the terraform outputs of a statefile into a dict

    The result has the keys name, master_ip, num_slaves, keypair and slaves
    (each the corresponding terraform output, or None if it could not be
    read) plus state, which is "UP" when a master IP is present and "DOWN"
    otherwise.
    """
    # (dict key, terraform output name), read in this order
    outputs = (
        ("name", "name"),
        ("master_ip", "master-ip"),
        ("num_slaves", "num-slaves"),
        ("keypair", "keypair"),
        ("slaves", "slave-ips"),
    )
    data = dict(
        (key, read_from_statefile(statefile, output)) for key, output in outputs
    )
    # read_from_statefile yields None when terraform fails; a missing master
    # IP must count as DOWN just like an empty one
    data["state"] = "UP" if data["master_ip"] else "DOWN"
    return data
