#!/skynet/python/bin/python

from subprocess import Popen, PIPE
import sys
import os
import os.path

class Result(object):
    """Status snapshot of one host, filled in by ``CheckStatusRunner.run``.

    All defaults live on the class, so a freshly created instance reads as
    "nothing found" until individual attributes are overridden.
    """

    # Raw text of the manifest file, or None when the file was not found.
    manifest = None
    # List of (pid, full command line) tuples for the matching processes;
    # None until the process table has been inspected.
    processes = None
    # Process count read from the 'num_processes' file; 0 when it is absent.
    expected_num_processes = 0
    # True when the executable file exists inside the checked directory.
    exec_exists = False
    # True when the checked directory itself exists.
    dir_exists = False

class CheckStatusRunner(object):
    """Inspect one deployed executable: its directory, files and processes.

    In this script an instance is handed to ``api.kqueue.Client.run`` together
    with a host list (presumably so that ``run`` is executed on every host --
    confirm against the kqueue API).

    Attributes:
        exec_dir: directory the executable is deployed to.
        exec_name: file name of the executable; also the short process name
            looked for in the ``ps`` output.
        manifest_name: file name of the manifest inside ``exec_dir``.
    """

    def __init__(self, exec_dir, exec_name, manifest_name):
        self.exec_dir = exec_dir
        self.exec_name = exec_name
        self.manifest_name = manifest_name

    def __str__(self):
        return '<CheckStatusRunner(%s, %s, %s)>' % (self.exec_dir, self.exec_name, self.manifest_name)

    @staticmethod
    def _parse_ps_output(lines, exec_name):
        """Pick the processes named ``exec_name`` out of ``ps`` output lines.

        Each line is expected to hold three whitespace-separated columns:
        pid, short command name and full command line (which may itself
        contain spaces). Lines with fewer columns are skipped instead of
        aborting the whole status check.

        Returns:
            A list of ``(pid, full_command_line)`` tuples, in input order.
        """
        processes = []
        for line in lines:
            fields = line.strip().split(None, 2)
            if len(fields) != 3:
                continue
            pid, cmd, fullcmd = fields
            if cmd == exec_name:
                processes.append((int(pid), fullcmd))
        return processes

    def run(self):
        """Collect the status of the executable and return it as a ``Result``.

        Raises:
            ValueError: if the 'num_processes' file does not contain an
                integer.
        """
        result = Result()
        num_proc_path = os.path.join(self.exec_dir, 'num_processes')
        if os.path.isfile(num_proc_path):
            with open(num_proc_path, 'r') as f:
                result.expected_num_processes = int(f.read().strip())
        if os.path.isdir(self.exec_dir):
            result.dir_exists = True
        if os.path.isfile(os.path.join(self.exec_dir, self.exec_name)):
            result.exec_exists = True
        manifest_path = os.path.join(self.exec_dir, self.manifest_name)
        if os.path.isfile(manifest_path):
            with open(manifest_path, 'r') as f:
                result.manifest = f.read()
        # The empty "name=" column titles suppress the ps header line.
        ps = Popen('ps -x -w -w -o pid= -o comm= -o command=', shell=True, stdout=PIPE, stderr=PIPE, stdin=PIPE)
        # communicate() closes stdin, drains stdout AND stderr and reaps the
        # child, so no pipe is leaked and ps can never block on a full,
        # unread stderr pipe.
        output, _ = ps.communicate()
        result.processes = self._parse_ps_output(output.splitlines(), self.exec_name)

        return result

def run():
    """Command-line entry point: print the beta deployment status by host group.

    Hosts are taken from the instance registry (``-i`` / ``-s`` options)
    and/or from host definitions given as positional arguments. A
    ``CheckStatusRunner`` is run for them through ``api.kqueue``; hosts that
    report an equivalent status are grouped together and one summary is
    printed per group. Hosts that failed are grouped by error text.

    Exits with status 1 when some host definitions cannot be resolved.
    """
    # Project imports are kept local to this function, as in the original
    # layout, so importing the module itself only needs the standard library.
    from library.sky.hosts import resolveHosts
    from api.cms import Registry
    from optparse import OptionParser
    import time
    import hashlib
    from library.tasks.command import CommandRunner  # unused here; kept in case the import has side effects
    import api.kqueue
    import json
    import itertools

    def list_instances(filters):
        """Return the registry search instances matching ``filters`` as a list."""
        return list(Registry.listSearchInstances(**filters))

    def list_basesearches(instance, shard_tags):
        """List ``(host, port, shard path)`` for an instance tag, sorted by host and port.

        When ``shard_tags`` is non-empty the registry is queried once per
        shard tag and the results are concatenated.
        """
        filters = {
            'conf': 'HEAD',
            'instanceTagName': instance,
        }
        if shard_tags:
            found = []
            for shard_tag in shard_tags:
                found += list_instances(dict(filters, shardTagName=shard_tag))
        else:
            found = list_instances(filters)

        found.sort(key=lambda r: (r.host, r.port))
        return [(r.host, r.port, '/db/BASE/' + r.shard) for r in found]

    def decode_error(err):
        """Turn an error value into a byte string usable as a grouping key."""
        if isinstance(err, unicode):
            # unicode.decode() would first encode implicitly with ASCII and
            # fail on any non-ASCII message; encoding is what is needed here.
            return err.encode('utf-8')
        elif isinstance(err, str):
            return err
        else:
            return repr(err)

    def strip_domain(hostname):
        """Return the host name up to (not including) the first dot."""
        return hostname.split('.', 1)[0]

    def compact_hosts(host_group):
        """Render ``host_group`` as a short comma-separated string.

        Walks the sorted ``hosts`` list of the enclosing scope and collapses
        every run of neighbouring hosts that belong to ``host_group``:
        one host is printed as is, two hosts are printed separately, longer
        runs become ``first..last [count]``.
        """
        members = set(host_group)
        pieces = []
        for found, hostnames in itertools.groupby(hosts, lambda h: h in members):
            if not found:
                continue
            hostnames = list(hostnames)
            start = strip_domain(hostnames[0])
            end = strip_domain(hostnames[-1])
            if start == end:
                pieces.append(start)
            elif len(hostnames) == 2:
                pieces.append(start)
                pieces.append(end)
            else:
                pieces.append('%s..%s [%d]' % (start, end, len(hostnames)))
        return ', '.join(pieces)

    def calc_hash(result):
        """Return an MD5 digest of the parts of ``result`` that define a group.

        Covers: directory/executable presence, the manifest text, whether any
        process is running and, when an expected process count is known,
        whether the running count matches it.
        """
        md5 = hashlib.md5()
        md5.update(str(result.dir_exists))
        md5.update(',' + str(result.exec_exists))
        md5.update(',' + (result.manifest or '{}'))
        md5.update(',' + str(len(result.processes) > 0))
        if result.expected_num_processes:
            # Compare counts; comparing the int with the process list itself
            # is always False and would merge "Full" and "Partial" hosts.
            md5.update(',' + str(result.expected_num_processes == len(result.processes)))
        return md5.digest()

    def add_item(groups, key, item):
        """Append ``item`` to the list stored under ``key``, creating it if needed."""
        groups.setdefault(key, []).append(item)

    def print_nonempty(title, value, color=None):
        """Print ``title: value`` unless ``str(value)`` is empty.

        ``color`` is an offset added to the ANSI foreground code 30
        (1 = red, 2 = green, 3 = yellow); it is applied only when stdout is
        a terminal.
        """
        value = str(value)
        if value:
            output = "% -20s %s" % (title+':', value)
            if os.isatty(1) and color is not None:
                output = "\033[%dm%s\033[m" % (30 + color, output)
            print(output)

    DEFAULT_INSTANCE = 'testws-production-replica'

    parser = OptionParser()
    parser.add_option("-i", "--instance", dest="instance", help="instance tag name (for example, testws-production-replica)", default=None)
    parser.add_option("-s", "--shard", dest="shard_tags", help="shard tag name (for example, RusTier0)", default=[], action="append")
    parser.add_option("-b", "--beta", dest="beta_name", help="beta name (eg. 'mark4')", default="mark4")
    parser.add_option("-p", "--program", dest="program_name", help="program base name (eg. 'httpsearch')", default="httpsearch")
    (options, hostdefs) = parser.parse_args()

    beta_name = options.beta_name
    beta_path = '/var/tmp/sky_snippets_%s' % beta_name
    beta_program = '%s_%s' % (options.program_name, beta_name)
    beta_manifest = '%s.manifest' % beta_program

    # This logic allows the script to be invoked in several ways:
    # select all base searches: ./check_status -i ...
    # select specified hosts: ./check_status +scrooge.yandex.ru
    # select base searches and some other hosts: ./check_status -i ... +scrooge.yandex.ru
    aggregate_hostdefs = []
    instance = options.instance
    if not hostdefs and not instance:
        instance = DEFAULT_INSTANCE
    if instance:
        basesearches = list_basesearches(instance, options.shard_tags)
        aggregate_hostdefs = ['+' + h[0] for h in basesearches]
    aggregate_hostdefs += hostdefs

    hosts, unresolvedHosts = resolveHosts(aggregate_hostdefs)
    if unresolvedHosts:
        sys.stderr.write('Unresolved hostdefs:\n')
        for h in unresolvedHosts:
            sys.stderr.write('%s\n' % (h,))
        sys.exit(1)

    # compact_hosts() relies on this list being sorted.
    hosts = sorted(hosts)

    hash_to_result = {}
    hash_to_hosts = {}
    error_to_hosts = {}

    runner = CheckStatusRunner(beta_path, beta_program, beta_manifest)
    client = api.kqueue.Client()
    for (host, result, err) in client.run(hosts, runner).wait():
        if err:
            add_item(error_to_hosts, decode_error(err), host)
        else:
            digest = calc_hash(result)
            add_item(hash_to_hosts, digest, host)
            # Any result of the group may represent it; the last one wins.
            hash_to_result[digest] = result

    for err_id, host_group in error_to_hosts.items():
        print_nonempty('error', err_id, 1)
        print_nonempty('host count', len(host_group))
        print_nonempty('hosts', compact_hosts(host_group))
        print('')

    for (digest, result) in hash_to_result.items():
        host_group = hash_to_hosts[digest]
        num_processes = len(result.processes)
        run_color = 1
        if not num_processes:
            run_state = 'False'
        elif result.expected_num_processes:
            if num_processes == result.expected_num_processes:
                run_state = 'Full'
                run_color = 2
            else:
                run_state = 'Partial'
                run_color = 3
        else:
            run_state = 'True'
            run_color = 2
        print_nonempty('running', run_state, run_color)
        print_nonempty('executable exists', result.exec_exists)
        print_nonempty('dir exists', result.dir_exists)

        # A broken manifest on one host group must not abort the whole report.
        try:
            manifest = json.loads(result.manifest or '{}')
        except ValueError:
            manifest = None
        if not isinstance(manifest, dict):
            print_nonempty('manifest', 'unparseable', 1)
            manifest = {}
        print_nonempty('user', manifest.get('user', ''))
        print_nonempty('message', manifest.get('message', ''))
        if 'time' in manifest:
            print_nonempty('created', time.strftime('%c', time.localtime(manifest['time'])))
        print_nonempty('shard tags', ', '.join(manifest.get('shards', [])))
        print_nonempty('instance', manifest.get('instance', ''))

        print_nonempty('host count', len(host_group))
        print_nonempty('hosts', compact_hosts(host_group))
        print('')
    
# Produce the report only when executed as a script, never on import.
if __name__ == '__main__':
    run()
