#!/usr/bin/env python2

import datetime
import logging
import os
import subprocess
import sys
import time
import types


DEBUG = False
DEBUG_OUT = True
DELIMITER = '=' * 65
DELIMITER_LITE = '-' * 65

SSH_TIMEOUT_SEC = 3
SSH_CONNECT_ERROR_HEADER = "ssh: connect to host "


log = logging.getLogger()


class SSHError(RuntimeError):
    """Raised when command output contains error lines emitted by ssh ('ssh: ...')."""


class ParsingOutputError(RuntimeError):
    """Raised when the expected data cannot be extracted from command output."""


def error(msg, must_exit=False):
    """
    Write ``msg`` followed by a newline to stderr and flush it.

    msg       -- message to report (formatted with '%s')
    must_exit -- when true, terminate the process with exit status 1
                 after the message has been written (raises SystemExit)

    Uses the module-level ``sys`` import.
    """
    sys.stderr.write('%s\n' % (msg, ))
    sys.stderr.flush()
    if must_exit:
        sys.exit(1)


def die(msg):
    """Report ``msg`` through error() and terminate the process (exit status 1)."""
    error(msg, True)


def out(cmds, fail=True, env=None, shell=False):
    """
    Run a command and return its captured output.

    stderr is redirected into stdout, so the returned string holds both
    streams interleaved.

    cmds  -- argument list, or a single command string; passed to
             subprocess.Popen unchanged
    fail  -- when true, a non-zero exit status raises Exception carrying
             the status and the captured output
    env   -- environment for the child; None means "copy os.environ",
             while an empty dict is honoured as an (almost) empty
             environment. LANG is always forced to en_US.UTF-8.
    shell -- forwarded to subprocess.Popen

    The command line and, on success, its output are written to the
    module logger at INFO level.
    """
    def out_debug(msg):
        if DEBUG_OUT:
            logging.debug('out: ' + msg)

    if isinstance(cmds, types.StringTypes):
        cmd_str = cmds
    else:
        cmd_str = ' '.join(cmds)

    log.info(DELIMITER)
    log.info('RUN: %s', cmd_str)

    if env:
        out_debug('env: %r' % (env, ))

    if env is None:  # not "if env:" because there can be empty dict: env={}
        local_env = dict(os.environ)
    else:
        local_env = dict(env)

    # NOTE(review): presumably forces a fixed locale so tool messages
    # (e.g. the 'ssh: ' error lines) are predictable -- confirm.
    local_env['LANG'] = 'en_US.UTF-8'

    p = subprocess.Popen(cmds, close_fds=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                         env=local_env, shell=shell)
    # communicate() reads the pipe to EOF, closes it and waits for the
    # child, so the pipe is not left open after the call.
    s, _ = p.communicate()
    ret = p.returncode

    if fail and ret != 0:
        raise Exception("Process ret=%d: %s" % (ret, s))

    log.info('%s\n%s\n%s\n', DELIMITER_LITE, s, DELIMITER_LITE)
    return s


def get_host_logs(host, host_dir, target_dir):
    """Placeholder, not implemented: ignores its arguments and returns None."""
    return None


def get_instance_ps_filter(prod_flag):
    """
    Build a predicate that picks the wanted java process out of ps output.

    The returned function takes one line, splits it on whitespace and
    returns True only when both heap flags are present as separate
    tokens: -Xmx8g/-Xms8g when ``prod_flag`` is true, -Xmx3g/-Xms3g
    otherwise.
    """
    if prod_flag:
        required = ("-Xmx8g", "-Xms8g")
    else:
        required = ("-Xmx3g", "-Xms3g")

    def line_filter(line):
        tokens = line.split()
        return all(flag in tokens for flag in required)

    return line_filter


def get_instance_path(command_line):
    """
    Extract the instance directory from a java command line.

    Looks for exactly one whitespace-separated token of the form
    -XX:HeapDumpPath=/db/iss3/instances/job-service_dev_iss_job-service_dev_IhwS1KIYwsF
    and returns the part after the first '='.

    Raises ParsingOutputError when no such token, or more than one, is
    found.
    """
    prefix = '-XX:HeapDumpPath=/db/iss3/instances/'

    matches = []
    for token in command_line.split():
        if token.startswith(prefix):
            matches.append(token)

    if len(matches) == 1:
        _option, _sep, path = matches[0].partition('=')
        return path

    raise ParsingOutputError("cant get HeapDumpPath from [%s]" % (command_line, ))


def get_ssh_errors(response):
    """
    Collect the lines of ``response`` that start with 'ssh: '.

    Returns them joined with '; ', or an empty string when there are
    none.
    """
    return '; '.join(
        line for line in response.splitlines() if line.startswith('ssh: ')
    )


def has_ssh_error(response):
    """
    Tell whether ``response`` contains any 'ssh: ' error line.

    Returns a real bool (True/False) rather than the length of the joined
    error text, so the result is usable both in conditions and in
    comparisons against True/False.
    """
    return bool(get_ssh_errors(response))


def assert_ssh_ok(response):
    """
    Raise SSHError if ``response`` contains 'ssh: ' error lines.

    The exception message is the '; '-joined list of those lines;
    a response without them passes silently.
    """
    problems = get_ssh_errors(response)
    if not problems:
        return
    raise SSHError(problems)


def get_log_path(host, prod_flag):
    """
    Find the instance directory of the java service running on ``host``.

    Runs a ps/grep pipeline on the host over ssh, keeps the lines accepted
    by get_instance_ps_filter(prod_flag) and extracts the HeapDumpPath
    directory from the single remaining command line.

    Raises SSHError when ssh reported an error, and ParsingOutputError
    when the number of matching lines is not exactly one (or the path
    cannot be extracted).
    """
    remote_cmd = "ps ax | grep java | grep Xms | grep Xmx | grep -v grep"
    ssh_cmd = ["ssh", "-o", "ConnectTimeout=%s" % SSH_TIMEOUT_SEC, host, remote_cmd]

    # fail=False: ssh problems are detected from the output instead of
    # the exit status.
    ps_result = out(ssh_cmd, fail=False)
    assert_ssh_ok(ps_result)

    is_wanted = get_instance_ps_filter(prod_flag)
    candidates = [ps_line for ps_line in ps_result.splitlines() if is_wanted(ps_line)]

    if len(candidates) == 1:
        return get_instance_path(candidates[0])

    raise ParsingOutputError("cant get instance command line from [%s]" % (ps_result, ))


def add_slash(s):
    """Return ``s`` with a trailing '/' appended unless it already ends with one."""
    return s if s.endswith('/') else s + '/'


def removedir(dirname):
    """
    Remove ``dirname`` if it is an existing, empty directory.

    Does nothing when the path is not a directory. A directory that still
    has content is left in place: os.removedirs() would raise OSError on
    it, which used to abort the caller's error handling.

    Removal is done with os.removedirs(), so parent directories that
    become empty are pruned as well.
    """
    if not os.path.isdir(dirname):
        return

    if os.listdir(dirname):
        # Not empty (e.g. partially downloaded files): keep it.
        return

    os.removedirs(dirname)

def pretty_size(size):
    """
    Format an integer byte count with a binary unit suffix.

    The value is divided by 1024 (right shift by 10, so it is truncated,
    not rounded) at most three times, giving '', 'K', 'M' or 'G';
    anything larger stays expressed in G.
    """
    unit = ''
    for bigger_unit in ('K', 'M', 'G'):
        if size < 1024:
            break
        size >>= 10
        unit = bigger_unit
    return "%d%s" % (size, unit)


def report_dir_size(dirname):
    """
    Walk ``dirname`` recursively and measure the files in it.

    Returns a tuple (total_bytes, empty_files), where ``empty_files`` is
    the list of full paths of zero-length files, in os.walk order.
    """
    total_bytes = 0
    zero_length = []

    for parent, _subdirs, names in os.walk(dirname):
        for name in names:
            path = os.path.join(parent, name)
            nbytes = os.path.getsize(path)
            if nbytes:
                total_bytes += nbytes
            else:
                zero_length.append(path)

    return total_bytes, zero_length


def remove_files(files):
    """
    Delete every path in ``files``, best effort.

    A failure to delete one file is reported on stdout as
    'ERROR delete: <path> <exception>' and does not stop the loop.
    """
    for path in files:
        try:
            os.unlink(path)
        except Exception as exc:
            # Single-argument print: same output as the statement form.
            print('ERROR delete: %s %s' % (path, exc))


def download_logs(hostname, remote_dir, local_dir):
    """
    Copy ``remote_dir``/*.log from ``hostname`` into ``local_dir`` with scp.

    ``local_dir`` must already exist (checked with an assert). The
    '*.log' pattern is handed to scp as-is; no local shell is involved.

    After the copy the function prints the total size of ``local_dir``
    and the elapsed time in whole minutes, deletes zero-length files
    from ``local_dir`` and prints the scp output.

    Raises SSHError when the scp output contains 'ssh: ' error lines; in
    that case nothing is printed or deleted.
    """
    assert os.path.isdir(local_dir)

    source = "%s:%s" % (hostname, add_slash(remote_dir) + '*.log')
    scp_cmd = [
        "scp",
        "-C",
        "-o", "ConnectTimeout=%s" % SSH_TIMEOUT_SEC,
        source,
        add_slash(local_dir),
    ]

    started_at = time.time()
    response = out(scp_cmd, fail=False)
    elapsed = int(time.time() - started_at)

    assert_ssh_ok(response)

    dir_size, empty_files = report_dir_size(local_dir)

    # Two spaces before "elapsed" are intentional: output is unchanged.
    print('size: %s  elapsed time: %s min' % (pretty_size(dir_size), elapsed // 60))

    remove_files(empty_files)

    print(response)


#job-service-1.haze.yandex.net

# Hosts queried for each environment.
# NOTE(review): both lists are currently identical -- confirm this is
# intended.
JOBAPI_HOSTS_DEV = [
    'job-service-sas-1.haze.yandex.net',
    'job-service-2.i.fog.yandex.net',
    'job-service-3.haze.yandex.net',
]

JOBAPI_HOSTS_PROD = [
    'job-service-sas-1.haze.yandex.net',
    'job-service-2.i.fog.yandex.net',
    'job-service-3.haze.yandex.net',
]


def get_hosts(prod_flag):
    """Return JOBAPI_HOSTS_PROD when ``prod_flag`` is true, else JOBAPI_HOSTS_DEV."""
    return JOBAPI_HOSTS_PROD if prod_flag else JOBAPI_HOSTS_DEV


def get_logs(prod_flag):
    """
    Download JobAPI logs from every host of the chosen environment.

    Creates a directory named '<YYYY-mm-dd--HHMMSS>--prod' (or '--dev')
    under the current working directory, then, for each host, locates the
    remote instance directory and copies its logs into
    '<root>/<hostname>'.

    SSHError and ParsingOutputError for a host are reported on stdout,
    the host's local directory is removed via removedir(), and
    processing continues with the next host.
    """
    suffix = 'prod' if prod_flag else 'dev'

    print("get JobAPI logs for: %s" % suffix)

    # strftime() without a time tuple formats the current local time.
    root_dir = "%s--%s" % (time.strftime('%Y-%m-%d--%H%M%S'), suffix)

    os.makedirs(root_dir)

    for hostname in get_hosts(prod_flag):
        local_dir = os.path.join(root_dir, hostname)
        try:
            print("query %s..." % hostname)
            remote_dir = get_log_path(hostname, prod_flag)
            print("dir: %s" % remote_dir)

            # Created only once the remote directory is known.
            if not os.path.exists(local_dir):
                os.makedirs(local_dir)

            download_logs(hostname, remote_dir, local_dir)

        except SSHError as e:
            print("ERROR: %s" % (e.args[0], ))
            removedir(local_dir)
        except ParsingOutputError as e:
            print("ERROR parsing: %s" % (e.args[0], ))
            removedir(local_dir)

        # Blank line between hosts.
        print('')


if __name__ == '__main__':
    # Script entry point: always fetches logs with prod_flag=True; there
    # is no command-line switch for the dev variant.
    get_logs(prod_flag=True)
