#!/skynet/python/bin/python

import os
import sys

import time
import timeit
import gc  # for Timer
import gevent
import logging
import logging.handlers
import random
from optparse import OptionParser
import socket
import traceback

from library.sky.hostresolver.resolver import Resolver as BlinovResolver
from api.cqueue import Client as CqueueClient
#from ya.skynet.services.cqueue.auth.key import CryptoKey
#from types import InstanceType


class Timer(object):
    """
    Context manager that times the body of a 'with' statement.

    On exit the elapsed time (in the units of ``timer``) is stored in
    ``interval`` and, when ``verbose`` is set and a logger is available,
    logged at INFO level as ``TIMING: <title>: <seconds>``.

    Based on:
    http://code.activestate.com/recipes/577896-benchmark-code-with-the-with-statement/
    """

    def __init__(self, title, log=None, timer=None, disable_gc=False, verbose=True):
        self.title = title
        self.log = (logging.getLogger(__name__ + '.' + self.__class__.__name__)
                    if log is None else log)
        self.timer = timeit.default_timer if timer is None else timer
        self.disable_gc = disable_gc
        self.verbose = verbose
        # Filled in by __enter__/__exit__.
        self.start = None
        self.end = None
        self.interval = None

    def __enter__(self):
        if self.disable_gc:
            # Remember the GC state on entry so __exit__ re-enables it only if
            # it was enabled before we switched it off.
            self.gc_state = gc.isenabled()
            gc.disable()
        self.start = self.timer()
        return self

    def __exit__(self, *exc_info):
        finished = self.timer()
        self.end = finished
        restore_gc = self.disable_gc and self.gc_state
        if restore_gc:
            gc.enable()
        self.interval = finished - self.start
        if self.verbose and self.log:
            self.log.info('TIMING: %s: %f', self.title, self.interval)


class Instance(object):
    """Parsed ``host:port@configuration`` instance string that belongs to ``shard``."""

    def __init__(self, shard, str_instance):
        self.full_name = str_instance
        name, configuration = str_instance.split('@')
        host, port_text = name.split(':')
        self.name = name
        self.configuration = configuration
        self.host = host
        self.port = int(port_text)
        self.shard = shard


class BsconfigGetInstanceState(object):
    """
    Probe that reports the on-disk state of bsconfig instances on one host.

    ``__call__`` is a two-step generator: it first yields ``None`` and
    expects to be sent a list of ``(shard, instance_name)`` pairs, then
    yields a dict mapping each instance name to its state.
    """
    # Class-level default, overridden per object by the ``user`` argument.
    # NOTE(review): presumably consumed by cqueue to pick the remote OS user.
    osUser = 'loadbase'

    def __init__(self, check_port=False, user=None):
        """
        :param check_port: also require the instance port to accept a TCP
            connection on 127.0.0.1 before reporting it as 'started'.
        :param user: overrides ``osUser`` for this object when not None.
        """
        self.check_port = check_port
        if user is not None:
            self.osUser = user

    def __call__(self):
        """
        Yield ``None`` to receive the instance list, then yield the results.

        Every result entry has ``shard_name`` and ``error`` (the exception
        raised while probing that instance, or None). On success it also has
        ``shard`` and ``resources`` (booleans from the state files) and
        ``instance``: ``'incomplete'`` if either state file is not OK,
        otherwise ``'started'`` or ``'startable'``.
        """
        instances = yield None

        result = {}
        for shard, iname in instances:
            instance = Instance(shard, iname)
            entry = result[iname] = {
                'shard_name': shard,
                'error': None,
            }
            try:
                entry['shard'] = self.get_shard_state(instance.shard)
                entry['resources'] = self.get_instance_state(instance)

                if not entry['shard'] or not entry['resources']:
                    entry['instance'] = 'incomplete'
                elif self.instance_running(instance):
                    entry['instance'] = 'started'
                else:
                    entry['instance'] = 'startable'
            except Exception as e:
                # Record per-instance failures instead of aborting the whole host.
                entry['error'] = e
        yield result

    def get_shard_state(self, shard):
        """Return True if the shard's state file reports ``install`` OK.

        The pseudo-shard ``'none'`` is always treated as installed.
        """
        if shard == 'none':
            return True

        shard_state_path = os.path.join('/db/BASE', shard, 'shard.state')
        return self.check_state_file(shard_state_path, 'install')

    def get_instance_state(self, instance):
        """Return True if the instance's state file reports ``resources`` OK."""
        instance_state_path = os.path.join(
            '/db/bsconfig/configinstall',
            instance.configuration,
            instance.name,
            'instance.state',
        )
        return self.check_state_file(instance_state_path, 'resources')

    def check_state_file(self, path, key):
        """
        Return True if the first line of ``path`` starting with ``key`` reads ``<key>?OK``.

        ``?`` stands for any single separator character. Only the first line
        that starts with ``key`` is examined. A missing or unreadable file,
        or the absence of such a line, yields False.
        """
        try:
            with open(path) as state_file:
                lines = state_file.readlines()
        except EnvironmentError:
            # Best effort: a missing or unreadable state file counts as "not OK".
            return False

        for line in lines:
            line = line.strip()
            if line.startswith(key):
                return line[len(key) + 1:] == 'OK'
        return False

    def instance_running(self, instance):
        """
        Return True if the instance's ``run.flag`` exists.

        When ``check_port`` is set, the instance port must additionally accept
        a TCP connection on 127.0.0.1 within 2 seconds.
        """
        run_flag_path = os.path.join(
            '/db/bsconfig/configinstall',
            instance.configuration,
            instance.name,
            'run.flag',
        )

        try:
            flag_present = os.path.exists(run_flag_path)
        except (TypeError, ValueError):
            # Malformed paths (e.g. embedded NUL bytes) raise instead of
            # returning False on some Python versions.
            flag_present = False
        if not flag_present:
            return False

        if not self.check_port:
            return True

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.settimeout(2)
            # connect_ex returns an error code instead of raising; 0 means connected.
            return sock.connect_ex(('127.0.0.1', instance.port)) == 0
        finally:
            # Close the socket even if settimeout/connect_ex raises.
            sock.close()


class StateAggregator(object):
    """
    Aggregate per-host instance reports and calculate state statistics.

    ``instances`` maps host -> list of ``(shard, instance_name)`` pairs, where
    ``instance_name`` has the form ``host:port@configuration``. Host reports
    (instance name -> state dict) are added with ``update``; host failures
    are added with ``update_host_errors``.
    """

    # Keys of every counter dict: 'total' plus each reportable instance state.
    _COUNTER_KEYS = ('total', 'started', 'startable', 'incomplete', 'error')
    # States meaning "a report (possibly an error) was received for the instance".
    _REPORTED_STATES = ('started', 'startable', 'incomplete', 'error')

    def __init__(self, options, instances):
        """
        :param options: parsed options; reads ``check_total_percent``,
            ``check_started_percent``, ``check_in_shard_percent`` and
            ``short_report``.
        :param instances: host -> list of ``(shard, instance_name)`` pairs.
        """
        self.log = logging.getLogger(__name__ + '.' + self.__class__.__name__)
        self.__setup_colors()
        self.options = options

        self.instances = instances

        # host -> exception for hosts whose data collection failed.
        self.host_errors = {}

        # Cluster-wide counters; 'total' is the number of resolved instances.
        self.instance_state = dict.fromkeys(self._COUNTER_KEYS, 0)
        self.instance_state['total'] = sum(
            len(pairs) for pairs in self.instances.values())

        # Per-shard counters; 'total' is the number of resolved instances in the shard.
        self.shard_state = {}
        for pairs in self.instances.values():
            for shard, _instance in pairs:
                if shard not in self.shard_state:
                    self.shard_state[shard] = dict.fromkeys(self._COUNTER_KEYS, 0)
                self.shard_state[shard]['total'] += 1

        # host -> report dict received from that host.
        self.host_reports = {}

    def __setup_colors(self):
        """Fill ``self.colors`` with ANSI escapes, or empty strings when stdout is not a TTY."""
        names = ['BLACK', 'RED', 'GREEN', 'YELLOW', 'BLUE', 'MAGENTA', 'CYAN', 'WHITE']
        if sys.stdout.isatty():
            colors = dict(
                (name, "\033[1;%dm" % (30 + index)) for index, name in enumerate(names))
            colors['RESET'] = "\033[0m"
            colors['BOLD'] = "\033[1m"
        else:
            colors = dict.fromkeys(names + ['RESET', 'BOLD'], "")
        self.colors = colors

    def update_host_errors(self, host, error):
        """Record ``error`` for ``host``, or clear its error when ``error`` is None."""
        if error is None:
            self.host_errors.pop(host, None)
        else:
            self.host_errors[host] = error

    def update(self, host, report):
        """
        Record ``report`` (instance name -> state dict) received from ``host``.

        Each instance is counted under 'error' if its ``error`` is set,
        otherwise under its ``instance`` state. Asserts that the host has
        not reported before.
        """
        assert host not in self.host_reports, \
            'Duplicate host report for host %s' % host

        self.host_reports[host] = report

        for entry in report.values():
            if entry['error'] is not None:
                state = 'error'
            else:
                state = entry['instance']
            self.instance_state[state] += 1
            self.shard_state[entry['shard_name']][state] += 1

    def check(self):
        """Return True if every configured threshold check passes (stops at the first failure)."""
        checks = (
            (self.options.check_in_shard_percent, self.check_in_shard_percent),
            (self.options.check_started_percent, self.check_started_percent),
            (self.options.check_total_percent, self.check_total_percent),
        )
        for threshold, check in checks:
            if threshold is not None and not check():
                return False
        return True

    def calculate_shard_state(self):
        """
        Return shard -> {``host:port`` -> report entry, or None if not reported}.

        Every resolved instance is listed. Duplicate ``host:port`` entries
        within one shard are logged as errors.
        """
        shard_state = {}
        for pairs in self.instances.values():
            for shard, instance in pairs:
                address = instance.split('@')[0]
                members = shard_state.setdefault(shard, {})
                if address in members:
                    self.log.error('Duplicate instance %s for shard %s.', address, shard)
                members[address] = None

        for report in self.host_reports.values():
            for instance, entry in report.items():
                address = instance.split('@')[0]
                shard_state[entry['shard_name']][address] = entry

        return shard_state

    def print_set(self, set_):
        """Return the sorted items joined by spaces, eliding the middle of sets of 20+ items."""
        items = sorted(set_)
        if len(items) < 20:
            return ' '.join(items)
        return ' '.join(items[:8]) + ' ... ' + ' '.join(items[-8:])

    def print_report(self):
        """Print host and instance statistics, plus per-shard details unless ``short_report``."""
        # Single-argument print(...) behaves identically on Python 2 and 3.
        all_hosts = set(self.instances)
        reported_hosts = set(self.host_reports)
        error_hosts = set(self.host_errors)
        # Hosts that neither reported nor failed explicitly.
        nodata_hosts = all_hosts - reported_hosts - error_hosts
        ok_hosts = reported_hosts - error_hosts

        print('')
        print('==== Report ====')
        print('Host state:')
        print('    Total: %d\n    OK:    %d\n    Error: %d' % (
            len(all_hosts), len(ok_hosts), len(error_hosts)))
        if nodata_hosts:
            print('    No data:  %d' % len(nodata_hosts))

        # Group failing hosts by the repr of their error so identical failures collapse.
        errors = {}
        for host, error in self.host_errors.items():
            errors.setdefault(repr(error), set()).add(host)

        if errors or nodata_hosts:
            print('')
            print('Host errors:')
            for err in sorted(errors):
                print('    %s :  %s' % (err, self.print_set(errors[err])))
            if nodata_hosts:
                print('    No data:  %s' % self.print_set(nodata_hosts))

        args = dict(self.instance_state)
        args.update({
            'c_started': self.colors['GREEN'],
            'c_startable': self.colors['CYAN'],
            'c_incomplete': self.colors['YELLOW'],
            'c_error': self.colors['RED'],
            'c_reset': self.colors['RESET'],
        })
        # Instances with no report at all; errored instances did report, so
        # they are subtracted too.
        args['nodata'] = self.instance_state['total'] - sum(
            self.instance_state[state] for state in self._REPORTED_STATES)

        print('')
        print("""Instance state statistics:
    Total:      %(total)d
    %(c_started)sStarted%(c_reset)s:    %(started)d
    %(c_startable)sStartable%(c_reset)s:  %(startable)d
    %(c_incomplete)sIncomplete%(c_reset)s: %(incomplete)d
    %(c_error)sErrors%(c_reset)s:     %(error)d
    No data:    %(nodata)d
""" % args)

        if not self.options.short_report:
            print("Per shard stats:")
            self.print_shard_state()

    def print_shard_state(self):
        """Print one line per shard listing its instances colored by state.

        A red ``(*)`` marks instances whose shard state file is not OK.
        """
        shard_state = self.calculate_shard_state()
        color_mapping = {
            'started': self.colors['GREEN'],
            'startable': self.colors['CYAN'],
            'incomplete': self.colors['YELLOW'],
            'error': self.colors['RED'],
            'none': self.colors['RESET'],
        }
        for shard in sorted(shard_state):
            line = self.colors['BLUE'] + shard + self.colors['RESET'] + ':'
            for instance in sorted(shard_state[shard]):
                entry = shard_state[shard][instance]
                if entry is None:
                    istate, sstate = 'none', True
                elif entry['error'] is not None:
                    istate, sstate = 'error', True
                else:
                    istate, sstate = entry['instance'], entry['shard']

                color = color_mapping.get(istate, self.colors['RESET'])
                host, port = instance.split(':')
                line += ' ' + host + ':' + color + port + self.colors['RESET']
                if not sstate:
                    line += '(' + self.colors['RED'] + '*' + self.colors['RESET'] + ')'

            print(line)

    @staticmethod
    def _fraction(part, total):
        """Return ``part / total`` as a float, or 0.0 when ``total`` is zero (nothing resolved)."""
        if not total:
            return 0.0
        return float(part) / total

    @staticmethod
    def _meets(fraction, threshold):
        """
        Return True if ``fraction`` reaches ``threshold``.

        An unset (None) threshold counts as met. Python 2 orders None below
        every number, so this matches the Python 2 comparison.
        """
        if threshold is None:
            return True
        return fraction >= threshold

    def check_total_percent(self):
        "Return True if given fraction (ex. 0.9) of total instances are started or startable."
        state = self.instance_state
        fraction = self._fraction(state['startable'] + state['started'], state['total'])
        return self._meets(fraction, self.options.check_total_percent)

    def check_started_percent(self):
        "Return True if given fraction (ex. 0.9) of total instances are started."
        state = self.instance_state
        fraction = self._fraction(state['started'], state['total'])
        return self._meets(fraction, self.options.check_started_percent)

    def check_in_shard_percent(self):
        "Return True if given fraction (ex. 0.9) of instances in each shard are started or startable."
        threshold = self.options.check_in_shard_percent
        for counters in self.shard_state.values():
            complete = counters['started'] + counters['startable']
            if not self._meets(self._fraction(complete, counters['total']), threshold):
                return False
        return True


class SkyStateCollector(object):
    """Collect instance state from cluster via skynet cqueue

    Resolves the blinov filter to host -> [(shard, instance), ...], sends a
    BsconfigGetInstanceState probe to the hosts in batches (one greenlet per
    batch of up to 500 hosts) and feeds each host report into a
    StateAggregator, which ``collect`` returns.
    """
    def __init__(self, options, instance_filter):
        """
        :param options: parsed command-line options.
        :param instance_filter: blinov filter string handed to the resolver.
        """
        self.log = logging.getLogger(__name__ + '.' + self.__class__.__name__)
        self.options = options
        
        self.client = CqueueClient("cqueue")
        if self.options.key_file:
            # Imported lazily: only needed when a key file is given.
            from ya.skynet.services.cqueue.auth.key import CryptoKey
            # Uses only the first item produced by CryptoKey.load.
            key = CryptoKey.load(self.options.key_file).next()
            # Reaches into the client's name-mangled private slave to register the key.
            self.client._CQueueClientProxy__slave.signer.addKey(key)

        with Timer('Initial instance resolving', verbose=options.verbose):
            self.instances = BlinovResolver().resolveInstances(instance_filter)
            
        self.log.debug("Hosts: %s", ' '.join(sorted(self.instances.keys())))
        self.state_aggregator = StateAggregator(options, self.instances)

        # Host bookkeeping for collect(): hosts still to query are
        # total - error - done - inprogress.
        self.host_state = {
            'total': set(self.instances.keys()),
            'done': set(),
            'error': set(),
            'inprogress': set(),
        }
        # NOTE(review): not referenced elsewhere in this class; host errors are
        # recorded via self.state_aggregator.update_host_errors instead.
        self.host_errors = {}

        # Set when collection should stop early: batch-mode checks passed
        # (collect) or a worker caught KeyboardInterrupt (collect_host_data).
        self.stop_data_collection = False

    def collect(self):
        """
        Query all resolved hosts and return the populated StateAggregator.

        Loops until no host is pending or in progress, or the batch-mode
        checks pass, or no host report has arrived for 2 * cqueue_timeout
        while hosts are still in progress. gevent.Timeout and
        KeyboardInterrupt end collection quietly, and the data gathered so
        far is returned.
        """
        self.log.info('Collecting data')
        # FIXME: make several tries for failed hosts
        # NOTE(review): only `import gevent` is visible in this file; this
        # assumes gevent.pool has been imported somewhere so the attribute
        # exists — confirm.
        thread_group = gevent.pool.Group()
        
        try:
            # Updated by collect_host_data on every successful host report;
            # used below for stall detection.
            self.last_data_update = time.time()
            while True:
                
                remaining_hosts = list(self.host_state['total'] - \
                    self.host_state['error'] - self.host_state['done'] - \
                    self.host_state['inprogress'])
                
                self.log.info('Done: %d, error: %d, inprogress: %d, pending: %d',
                    len(self.host_state['done']),
                    len(self.host_state['error']),
                    len(self.host_state['inprogress']),
                    len(remaining_hosts),
                )
                # Progress logging: evaluates all three checks when any one is configured.
                if (self.options.check_total_percent is not None or
                        self.options.check_started_percent is not None or
                        self.options.check_in_shard_percent is not None):
                    self.log.info(
                        'Check total percent: %s, check started percent: %s, '
                        'check in shard percent: %s',
                        self.state_aggregator.check_total_percent(),
                        self.state_aggregator.check_started_percent(),
                        self.state_aggregator.check_in_shard_percent(),
                    )
                
                if not remaining_hosts and not self.host_state['inprogress']:
                    self.log.info('No more hosts to process')
                    break
                
                self.log.info('B: %s, C: %s', self.options.batch_mode, self.state_aggregator.check())
                if self.options.batch_mode and self.state_aggregator.check():
                    self.log.info('Stop data processing -- enough data collected.')
                    self.stop_data_collection = True
                    break

                # Everything is dispatched: wait for workers, but give up if
                # no report arrived for twice the cqueue timeout.
                if not remaining_hosts and self.host_state['inprogress']:
                    if time.time() - self.last_data_update > self.options.cqueue_timeout * 2:
                        self.log.error('No progress in timeout time.')
                        self.log.error('Pending hosts: %r', self.host_state['inprogress'])
                        break
                    gevent.sleep(1)
                    continue
                    
                random.shuffle(remaining_hosts)
                
                # Dispatch remaining hosts in batches, one greenlet per batch.
                low = 0
                total_hosts = len(remaining_hosts)
                while low < total_hosts:
                    # FIXME: magic constant
                    high = min(low + 500, total_hosts)
                    target_hosts = remaining_hosts[low:high]
                    self.host_state['inprogress'] |= set(target_hosts)
                    thread = gevent.spawn(self.collect_host_data, target_hosts)
                    thread_group.add(thread)
                    self.log.info('Spawned %d-hosts thread', high - low)
                    gevent.sleep(0.1)
                    low = high
                    
                gevent.sleep(1)
            
            # Raise StopIteration inside each worker; collect_host_data catches
            # it and its finally-block releases hosts it had not finished.
            if self.stop_data_collection:
                thread_group.kill(StopIteration)
            
            with gevent.Timeout(self.options.cqueue_timeout):
                thread_group.join()
                
        except (gevent.Timeout, KeyboardInterrupt):
            pass
        
        return self.state_aggregator

    def collect_host_data(self, hosts):
        """
        Run the state probe on ``hosts`` through cqueue and record the results.

        Each item from the session is ``(host, res, err)``:
          * ``err`` is a StopIteration: the remote generator finished;
          * any other non-None ``err``: the host is marked 'error';
          * ``res is None`` for a host still in ``send_left``: the probe's
            initial ``yield None``, so the host's instance list is sent back;
          * otherwise ``res`` is the host's final report.
        Every wait on the session is bounded by ``cqueue_timeout``. Hosts
        left in ``send_left``/``wait_result`` when this returns are removed
        from 'inprogress' without being marked 'error', so collect() will
        see them as remaining hosts again.
        """
        if self.options.user is not None:
            state_collector = BsconfigGetInstanceState(
                check_port=self.options.check_port,
                user=self.options.user,
            )
        else:
            state_collector = BsconfigGetInstanceState(
                check_port=self.options.check_port,
            )
        cq_session = self.client.iterFull(hosts, state_collector)
        # Hosts that have not yet been sent their instance list.
        send_left = set(hosts)
        # Hosts that were sent their instance list and whose report is pending.
        wait_result = set()

        try:
            with gevent.Timeout(self.options.cqueue_timeout):
                it = cq_session.wait()
                host, res, err = it.next()
            
            while True:
                self.log.debug('Processing host=%r, res=%r, err=%r', host, res, err)
                if isinstance(err, StopIteration):
                    send_left.discard(host)
                    wait_result.discard(host)
                    with gevent.Timeout(self.options.cqueue_timeout):
                        host, res, err = it.next()
                    continue
                if err is not None:
                    # Skip traceback logging for these two exact types
                    # (subclasses are still logged).
                    if type(err) not in (EnvironmentError, RuntimeError):
                        try:
                            self.log.error('%r\n%s', err, ''.join(err._traceback))
                        except:
                            # err may lack _traceback; logging is best effort.
                            pass
                    
                    self.host_state['error'].add(host)
                    self.host_state['inprogress'].discard(host)
                    self.state_aggregator.update_host_errors(host, err)
                    
                    send_left.discard(host)
                    wait_result.discard(host)
                    with gevent.Timeout(self.options.cqueue_timeout):
                        host, res, err = it.next()
                    continue

                # Remote probe asked for its input: send this host's instances.
                if res is None and host in send_left:
                    send_left.discard(host)
                    wait_result.add(host)
                    with gevent.Timeout(self.options.cqueue_timeout):
                        host, res, err = it.send(self.instances[host])
                    continue
                
                assert res is not None and host in wait_result, \
                    'Unexpected state. res = %r' % res

                # A final report arrived: reset the stall timer used by collect().
                self.last_data_update = time.time()
                                
                self.host_state['done'].add(host)
                self.host_state['inprogress'].discard(host)
                
                send_left.discard(host)
                wait_result.discard(host)

                self.state_aggregator.update_host_errors(host, None)
                self.state_aggregator.update(host, res)
                
                with gevent.Timeout(self.options.cqueue_timeout):
                    host, res, err = it.next()
                    
        except gevent.Timeout:
            self.log.error(
                'Got Timeout exception while processing %d hosts. '
                'Send left hosts: %d, wait result hosts: %d.',
                len(hosts),
                len(send_left),
                len(wait_result),
            )
        except StopIteration:
            # Session exhausted, or killed by collect() with StopIteration.
            pass
        except KeyboardInterrupt:
            self.log.error(
                'Got KeyboardInterrupt exception while processing %d hosts. '
                'Send left hosts: %d, wait result hosts: %d.',
                len(hosts),
                len(send_left),
                len(wait_result),
            )
            self.stop_data_collection = True
        finally:
            # Unfinished hosts are released from 'inprogress' (not marked
            # 'error'), so collect() picks them up again as remaining hosts.
            #self.host_state['error'] |= send_left | wait_result
            self.host_state['inprogress'] -= send_left | wait_result
            for host in send_left | wait_result:
                self.state_aggregator.update_host_errors(host, Exception('Interrupted'))

if __name__ == '__main__':
    usage = "usage: %prog [options] <skynet blinov filter>"
    parser = OptionParser(usage=usage)
    parser.add_option(
        "--check-total-percent",
        metavar="<float>",
        type="float",
        dest="check_total_percent",
        help="Check if given fraction (ex. 0.9) of total instances are startable.",
        default=None,
    )
    parser.add_option(
        "--check-started-percent",
        metavar="<float>",
        type="float",
        dest="check_started_percent",
        help="Check if given fraction (ex. 0.9) of total instances are started.",
        default=None,
    )
    parser.add_option(
        "--check-in-shard-percent",
        metavar="<float>",
        type="float",
        dest="check_in_shard_percent",
        help="Check if given fraction (ex. 0.9) of instances in each shard are startable.",
        default=None,
    )

    parser.add_option(
        "-t",
        "--cqueue-timeout",
        metavar="<seconds>",
        type="int",
        dest="cqueue_timeout",
        help="Set cqueue data timeout. (default: %default)",
        default=60,
    )

    parser.add_option(
        "-v",
        "--verbose",
        dest="verbose",
        help="Enable verbose mode.",
        default=False,
        action='store_true',
    )
    parser.add_option(
        "-d",
        "--debug",
        dest="debug",
        help="Enable debug mode.",
        default=False,
        action='store_true',
    )
    parser.add_option(
        "--batch-mode",
        dest="batch_mode",
        help="Stop data collection if enough data gathered",
        default=False,
        action='store_true',
    )

    parser.add_option(
        "-s",
        "--short-report",
        dest="short_report",
        help="Skip per shard statistics printing.",
        default=False,
        action='store_true',
    )

    parser.add_option(
        "-k",
        "--key-file",
        dest="key_file",
        help="Use key for cqueue auth",
        default=None,
    )

    parser.add_option(
        "-u",
        "--user",
        dest="user",
        help="Use user for cqueue auth",
        default=None,
    )

    parser.add_option(
        "-p",
        "--check-port",
        dest="check_port",
        help="Check port connectivity for running instances",
        default=False,
        action='store_true',
    )

    parser.add_option(
        "--skip-report",
        dest="skip_report",
        help="Do not print report",
        default=False,
        action='store_true',
    )
    options, args = parser.parse_args()

    # The filter is mandatory (see usage); without it the resolver would be
    # given an empty string.
    if not args:
        parser.error('missing <skynet blinov filter>')

    # Root logger level: --debug wins over --verbose; default is WARNING.
    log = logging.getLogger()
    if options.debug:
        log.setLevel(logging.DEBUG)
    elif options.verbose:
        log.setLevel(logging.INFO)
    else:
        log.setLevel(logging.WARNING)

    formatter = logging.Formatter(
            '%(asctime)s - %(name)s:%(lineno)d - '
            '%(process)d - %(levelname)s - %(message)s')

    logfile_handler = logging.StreamHandler()
    logfile_handler.setFormatter(formatter)
    log.addHandler(logfile_handler)

    log.info('Starting as %s', ' '.join(sys.argv))

    collector = SkyStateCollector(options, ' '.join(args))
    statistics = collector.collect()
    if not options.skip_report:
        statistics.print_report()

    # Exit status 1 signals that a configured threshold check failed.
    if not statistics.check():
        sys.exit(1)
