#! /usr/bin/env python2.7

import logging
import logging.handlers
import sys
import random
import signal
import time
import json
import os.path
import os
import re

import hub
import xtable
import web

from multiprocessing import Pool

class Profiler:
    """Context manager that logs when a named step starts and how long it took.

    `logger` only needs an `info(fmt, *args)` method; `step_name` is the
    label used in both log lines.
    """

    def __init__(self, logger, step_name):
        self.step_name = step_name
        self.logger = logger
        # Set on __enter__; None until the context is entered.
        self.start_time = None

    def __enter__(self):
        """Remember the start time, log the start, and return the profiler."""
        self.start_time = time.time()
        self.logger.info('%s started', self.step_name)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Log the elapsed wall-clock time; never suppress an exception."""
        elapsed = time.time() - self.start_time
        self.logger.info('%s done in %.6f sec', self.step_name, elapsed)
        # A false return value lets any exception from the body propagate.
        return False

def shards_are_same(conninfo1, conninfo2):
    """Tell whether two conninfo strings refer to the same shard.

    When both strings contain a database name of the form
    xiva_xtable_<letters>_<two digits>, those names are compared.
    Otherwise the conninfo strings themselves are compared verbatim.
    """
    # TODO: parse conninfo and compare by host/port/dbname/ triple
    dbname_pattern = 'xiva_xtable_[a-z]+_[0-9][0-9]'
    found = [re.search(dbname_pattern, conninfo)
             for conninfo in (conninfo1, conninfo2)]
    if all(found):
        return found[0].group(0) == found[1].group(0)
    # At least one side has no recognizable dbname: fall back to raw equality.
    return conninfo1 == conninfo2

def reshard_step(gid, cluster, step_number, reaction, interval_before_select, logger):
    """Run one migration attempt for `gid` and return a status string.

    Possible return values: 'no migration necessary', 'already migrated',
    'cannot migrate, shard_to contains subs for the gid' and 'finished'.
    Errors from `cluster` or `xtable` calls are not caught here.

    `step_number` is the zero-based attempt index, `reaction` is passed to
    cluster.start_migration(), and `interval_before_select` is the number of
    seconds to sleep before reading subscriptions from the source shard.
    """
    with Profiler(logger, 'query status'):
        (state, shard_from, shard_to) = cluster.resharding_status(gid)
    logger.info('in state %s lives in %s moves to %s', state, shard_from, shard_to)

    if shards_are_same(shard_from, shard_to):
        return 'no migration necessary'
    if state == 'finished':
        return 'already migrated'

    # Subscriptions may already exist in the new shard for the gid, e.g. after
    # a failed migration attempt (moved all subscriptions, but failed to
    # finalize migration on hubs). On the first attempt we refuse to touch
    # them; on later attempts we delete them before copying again.
    first_attempt = (step_number == 0)
    if first_attempt:
        with Profiler(logger, 'xtable count'):
            existing_subs = xtable.count(gid, shard_to)
        if existing_subs > 0:
            return 'cannot migrate, shard_to contains subs for the gid'
    else:
        with Profiler(logger, 'xtable delete'):
            deleted_count = len(xtable.delete(gid, shard_to))
        logger.info('deleted %s subscriptions', deleted_count)

    # Anything other than a pending migration is aborted before starting over.
    if state != 'pending':
        with Profiler(logger, 'abort'):
            cluster.abort_migration(gid)

    with Profiler(logger, 'prepare'):
        cluster.prepare_migration(gid)
    with Profiler(logger, 'start'):
        cluster.start_migration(gid, reaction)

    # This sleep helps us make sure (in most cases) that xtable queries that
    # are already running have completed before we start to read
    # subscriptions for the gid. Tracking all queries on all hubs to know
    # with 100% certainty that none are in progress would be too costly.
    time.sleep(interval_before_select)

    with Profiler(logger, 'xtable read'):
        subscriptions = xtable.read(gid, shard_from)
    with Profiler(logger, 'xtable insert'):
        xtable.insert(subscriptions, shard_to)
    with Profiler(logger, 'finalize'):
        cluster.finalize_migration(gid)

    return 'finished'


class GidAdapter(logging.LoggerAdapter):
    """Logger adapter that prefixes every message with the gid from `extra`."""

    def process(self, msg, kwargs):
        """Return (prefixed message, kwargs); the gid is left-aligned to 5 chars."""
        prefixed = '%-5s | %s' % (self.extra['gid'], msg)
        return prefixed, kwargs

def reshard(gid, cluster, reactions, retry_interval, interval_before_select):
    """Migrate `gid`, making one attempt per entry of `reactions`.

    Returns the status string of the first attempt that completes without
    raising, or 'migration failed' when every attempt raised. After a failed
    attempt the migration is aborted on a best-effort basis and the function
    sleeps for `retry_interval` seconds.
    """
    logger = GidAdapter(logging.getLogger(), {'gid': gid})
    with Profiler(logger, 'reshard'):
        for attempt_index, reaction in enumerate(reactions):
            try:
                with Profiler(logger, 'step with reaction ' + reaction):
                    return reshard_step(gid, cluster, attempt_index, reaction,
                                        interval_before_select, logger)
            except Exception as err:
                logger.error('migration failed: %s: %s, retrying in %.3f',
                    type(err).__name__, str(err), retry_interval)
                # If we fail here, the next step will abort if necessary
                # before resharding.
                with Profiler(logger, 'try abort'):
                    cluster.try_abort_migration(gid)
                time.sleep(retry_interval)

        return 'migration failed'

def reshard_unpack(args):
    """Call reshard() with a packed argument tuple and return its result.

    Used as the task function for the worker pool, which passes each task
    as a single tuple.
    """
    outcome = reshard(*args)
    return outcome

class ReshardInProgressError(Exception):
    """Raised when a reshard is requested while a previous one is still running."""

def range_text(ranges):
    """Render [first, last] pairs as text, e.g. '1-3, 100, 101-300'.

    A pair whose two ends are equal is rendered as a single number.
    An empty input yields an empty string.
    """
    def render(pair):
        first, last = pair[0], pair[1]
        if first == last:
            return str(first)
        return '{0}-{1}'.format(first, last)

    return ', '.join(render(pair) for pair in ranges)

class Resharder:
    """Schedules per-gid reshard tasks on a worker pool and reports progress.

    `self.results` maps each gid either to the async result returned by the
    pool or, once a task has been aborted, to the string 'aborted'.
    """

    def __init__(self, config):
        """Build clusters from config['clusters'] and start the worker pool."""
        self.config = config
        self.clusters = [hub.Cluster(c) for c in config['clusters']]
        self.worker_pool = self.setup_pool()
        self.results = {}

    def reshard(self, gid_start, gid_end):
        """Queue a reshard task for every gid in [gid_start, gid_end].

        Only gids contained in some cluster's gid_range are scheduled.
        Raises ReshardInProgressError if a previous run has unfinished tasks.
        Returns 'OK' once the tasks are queued; previous results are dropped.
        """
        if self.reshard_in_progress():
            raise ReshardInProgressError

        logging.info('resharding gids from %s to %s', gid_start, gid_end)

        reactions = self.config['reactions']
        interval = self.config['retry_interval']
        interval_before_select = self.config['interval_before_select']
        tasks = [(gid, cluster, reactions, interval, interval_before_select)
                 for cluster in self.clusters
                 for gid in range(gid_start, gid_end + 1)
                 if gid in cluster.gid_range]

        # Map results by gid for status reporting.
        self.results = {task[0]: self.worker_pool.apply_async(reshard_unpack, (task,))
                        for task in tasks}
        return 'OK'

    def abort(self):
        """Kill the worker pool, mark unfinished gids as aborted, return status().

        A fresh pool is created afterwards so new reshards can be queued.
        """
        self.worker_pool.terminate()
        self.worker_pool.join()

        # Iterate over a snapshot because values are replaced in the loop.
        # Entries already marked 'aborted' by an earlier abort() are plain
        # strings without ready(), so they must be skipped.
        for gid, result in list(self.results.items()):
            if result != 'aborted' and not result.ready():
                self.results[gid] = 'aborted'

        self.worker_pool = self.setup_pool()
        return self.status()

    # {
    #     "failed": "1-3, 100, 101-300",
    #     "finished": "3, 4, 5"
    # }
    def status(self):
        """Return a dict mapping each result text to its gid ranges as text."""
        return {result_text: range_text(ranges)
                for result_text, ranges in self._collect_ranges().items()}

    def _collect_ranges(self):
        """Group gids by result text into sorted lists of [first, last] ranges.

        Unfinished tasks are reported as 'migrating', aborted ones as
        'aborted', finished ones under whatever text the task returned.
        """
        ranges_by_result = {}
        # Dict order is not gid order; merging consecutive gids into ranges
        # only works when gids are visited in ascending order.
        for gid in sorted(self.results):
            result = self.results[gid]
            if result == 'aborted':
                result_text = 'aborted'
            elif result.ready():
                result_text = result.get()
            else:
                result_text = 'migrating'

            ranges = ranges_by_result.setdefault(result_text, [])
            if ranges and gid - ranges[-1][1] == 1:
                # gid directly follows the last range: extend it.
                ranges[-1][1] = gid
            else:
                ranges.append([gid, gid])
        return ranges_by_result

    def reshard_in_progress(self):
        """Return True if any non-aborted task has not finished yet."""
        return any(result != 'aborted' and not result.ready()
                   for result in self.results.values())

    def setup_pool(self):
        """Create a worker pool whose workers are configured from self.config."""
        def init_worker(config):
            # Workers ignore SIGINT.
            signal.signal(signal.SIGINT, signal.SIG_IGN)

            xtable.insert_page_size = config['insert_page_size']
            xtable.request_timeout = config['timeouts']['xtable']

            hub.connect_timeout = config['timeouts']['hub']['connect']
            hub.request_timeout = config['timeouts']['hub']['request']

        return Pool(self.config['workers'], init_worker, (self.config,))

def qloud_env_from_file(path='/etc/qloud/meta.json'):
    """Read the 'qloud_environment' value from a JSON metadata file.

    `path` defaults to /etc/qloud/meta.json. Returns None when the file
    cannot be read or parsed, or when the key is missing; this lookup is
    best-effort and never raises.
    """
    try:
        # The with-block closes the file even when parsing fails.
        with open(path) as meta_file:
            meta_json = json.load(meta_file)
        return meta_json['qloud_environment']
    except Exception:
        return None

def get_env_name():
    """Determine the environment name.

    Uses the QLOUD_ENVIRONMENT variable when it is set, otherwise the value
    from qloud_env_from_file(), otherwise 'local'.
    """
    from_environ = os.environ.get('QLOUD_ENVIRONMENT')
    if from_environ is None:
        logging.info('failed to get env from QLOUD_ENVIRONMENT, trying file')
        from_file = qloud_env_from_file()
        if from_file is None:
            logging.info('failed to get env from file, assuming local')
            return 'local'
        return from_file
    return from_environ

def main():
    """Configure logging, load the config, and serve the resharder HTTP API.

    In the 'local' environment the config path is taken from the first
    command-line argument; otherwise it is /etc/yxiva-resharder/<env>.json.
    """
    LOGFORMAT = '%(asctime)s | %(levelname)s | %(message)s'
    formatter = logging.Formatter(LOGFORMAT)
    logging.basicConfig(format=LOGFORMAT,
                        level=logging.INFO,
                        stream=sys.stdout)
    file_handler = logging.handlers.RotatingFileHandler(
        '/var/log/yxiva-resharder/yxiva-resharder.log',
        maxBytes=100*1024*1024,
        backupCount=10)
    file_handler.setFormatter(formatter)
    file_handler.setLevel(logging.INFO)
    logging.getLogger().addHandler(file_handler)

    # Only warnings and above from these loggers.
    logging.getLogger("requests").setLevel(logging.WARNING)
    logging.getLogger("urllib3").setLevel(logging.WARNING)

    env_name = get_env_name()
    logging.info('starting in environment %s', env_name)

    if env_name == 'local':
        # Fail with a usage hint instead of a bare IndexError.
        if len(sys.argv) < 2:
            sys.exit('usage: %s <config.json>' % sys.argv[0])
        config_path = sys.argv[1]
    else:
        config_path = '/etc/yxiva-resharder/%s.json' % env_name
    # Close the config file as soon as it has been parsed.
    with open(config_path) as config_file:
        config = json.load(config_file)
    resharder = Resharder(config)

    webserver = web.server(config['web'])

    webserver.bind('/reshard',
        lambda gid_start, gid_end: resharder.reshard(gid_start, gid_end),
        gid_start=int, gid_end=int)
    webserver.bind('/abort', lambda: resharder.abort())
    webserver.bind('/status', lambda: resharder.status())
    webserver.bind('/ping', lambda: 'pong')

    try:
        webserver.serve_forever()
    except KeyboardInterrupt:
        logging.info('shutting down')
        webserver.shutdown()

# Start the service only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
