#!/usr/bin/env python

import psycopg2
import requests
import sys
import time
import argparse
import logging

# 1. Count the number of shards in each registration scope
#    (WARN if fewer than 2, CRIT if there are none).
# 2. For each registration scope, check the state of its shards:
# 2.1. Sharpei is aware of every shard and has info about it in /stat.
# 2.2. Each scope has at least 2 alive masters
#      (1 alive master is WARN, none is CRIT; >1 alive master in a shard is CRIT).
# 3. Checks for replica state and master count are also possible, but they
#    mostly duplicate the db_stat check semantics, so they are not used here.

# Names of the SharpeiRegCheck check methods run by default, in this order.
ACTIONS = [
    'scope_masters_up',
    'unknown_shard',
    'reg_shard_cnt',
]

# Module logger; its effective level comes from logging.basicConfig when the
# file is run as a script.
log = logging.getLogger(__name__)

class Status(object):
    """ Class for holding Juggler status.

    Collects a numeric code and text fragments, then prints them as a single
    ``<code>;<message>`` line. The code only ever increases (0 OK, 1 WARN,
    2 CRIT), so a CRIT is never downgraded by a later WARN.
    """

    # Literal substring replacements applied, in order, to the final message
    # to keep the status line short.
    _REPLACEMENTS = (
        ('_', ' '),
        ('\n', ''),
        ('mail.yandex.net', 'm'),
        ('disk.yandex.net', 'd'),
        ('db.yandex.net', 'db'),
    )

    def __init__(self):
        self.code = 0   # highest code seen so far
        self.text = []  # message fragments, joined with '. ' on report

    def _shorten(self, message):
        """ Return *message* compacted for the status line.

        Underscores become spaces, newlines are removed and well-known
        domain suffixes are abbreviated.
        """
        for old, new in self._REPLACEMENTS:
            message = message.replace(old, new)
        return message

    def set_code(self, new_code):
        """ Set the code if it is greater than the current (never lower it). """
        if new_code > self.code:
            self.code = new_code

    def append(self, new_text):
        """ Accumulate one fragment of the status text. """
        self.text.append(new_text)

    def report(self, retcode=0, message=None):
        """ Print ``<code>;<message>`` and exit the process.

        :param retcode: code merged in via set_code() before printing.
        :param message: text to print instead of the accumulated fragments.

        The process always exits with status 0; the check result is carried
        by the printed line, not by the exit status.
        """
        # Merge the code first so the 'ok' fallback below sees the final
        # value; otherwise report(retcode=2) with no text printed "2;ok".
        self.set_code(retcode)
        if message is None:
            message = '. '.join(self.text)
        if not message and self.code == 0:
            message = 'ok'
        print(
            '{code};{msg}'.format(
                code=self.code,
                msg=self._shorten(message)
            )
        )
        sys.exit(0)


class SharpeiRegCheck(object):
    """Check if there are enough shards for registration and they are alive.

    Two data sources are used:

    * ``scopes`` -- the shard layout read from the local sharddb, shaped as
      ``{scope_name: {shard_id: {'hosts': [host, ...], 'name': shard_name}}}``
      (built by get_scopes);
    * ``stat`` -- the JSON returned by sharpei's ``/stat`` handle, looked up
      by ``str(shard_id)``.

    Every public check returns a ``(code, message)`` tuple, where code is
    0 (OK), 1 (WARN) or 2 (CRIT) and message is a comma-separated list of
    problems (empty when code is 0).
    """

    def __init__(self, stat=None, scopes=None):
        """
        Retrieve the data from the PostgreSQL backend and the /stat handle.

        :param stat: pre-fetched /stat document; fetched via get_stat() if None.
        :param scopes: pre-built scope layout; fetched via get_scopes() if None.
        """
        self.scopes = scopes
        self.stat = stat
        if stat is None:
            self.stat = self.get_stat()
        if scopes is None:
            self.scopes = self.get_scopes()

    def __psql(self):
        """
        Open a connection to the local sharddb and return a cursor on it.

        The caller owns the connection and must close it (reachable as
        ``cursor.connection``).
        """
        connection = psycopg2.connect('dbname=sharddb user=monitor connect_timeout=1')
        return connection.cursor()

    def get_scopes(self):
        """
        Retrieve the shard layout of every scope from sharddb.

        :returns: ``{scope: {shard_id: {'hosts': [host, ...], 'name': shard}}}``.
            A scope with no shards (a row that exists only because of the
            RIGHT JOIN, with NULL shard columns) maps to an empty dict, so
            reg_shard_cnt still sees it.
        """
        cur = self.__psql()
        connection = cur.connection
        statement = """
SELECT sc.name AS scope, sh.shard_id, sh.name AS shard, host
FROM shards.shards sh
JOIN shards.instances i USING (shard_id)
JOIN shards.scopes_by_shards scsh USING (shard_id)
RIGHT JOIN shards.scopes sc USING (scope_id)
ORDER BY sc.name DESC
"""
        try:
            cur.execute(statement)
            rows = cur.fetchall()
        finally:
            # The connection was previously left open; close it even when
            # the query fails.
            connection.close()

        scopes = {}
        for scope, shard_id, shard, host in rows:
            # Register the scope even if this row carries no host, but never
            # reset shards already collected for it.
            shards = scopes.setdefault(scope, {})
            if not host:
                continue
            shard_info = shards.setdefault(shard_id, {'hosts': [], 'name': shard})
            shard_info['hosts'].append(host)
        return scopes

    def get_stat(self, timeout=5):
        """
        Call sharpei's /stat handle and return its decoded JSON.

        :param timeout: seconds for connect/read. requests waits forever
            without one, which would hang the check; 5s is a chosen bound,
            adjust if a forced refresh is slower.
        :raises Exception: with a short message on connection failure or
            timeout. HTTP errors from raise_for_status() and JSON decoding
            errors propagate unchanged.
        """
        try:
            result = requests.get('http://localhost:9999/stat?force=1', timeout=timeout)
            result.raise_for_status()
            return result.json()
        except requests.ConnectionError:
            raise Exception('/stat: unable to connect')
        except requests.Timeout:
            raise Exception('/stat: timed out')

    def reg_shard_cnt(self):
        """Check if there are enough shards in each scope: WARN below 2, CRIT at 0."""
        status = []
        code = 0
        log.debug('reg_shard_cnt: scopes: {s}'.format(s=self.scopes))
        for scope, shards in self.scopes.items():
            shards_num = len(shards)
            log.debug('reg_shard_cnt: scope: "{sc}", shard num: {n}'.format(sc=scope, n=shards_num))
            if shards_num < 2:
                code = max(code, 1)  # warning
                status.append('"{sc}": {n}'.format(sc=scope, n=shards_num))
            if shards_num < 1:
                code = 2
        return (code, ','.join(status))

    def unknown_shard(self):
        """
        Check for presence of every shard_id in 'stat'.

        Any shard missing from /stat is CRIT and is removed from self.scopes,
        so checks run afterwards only see shards sharpei knows about.
        """
        status = []
        code = 0
        for scope, shards in self.scopes.items():
            # Iterate over a snapshot: entries are deleted inside the loop.
            for shard_id in list(shards):
                key = str(shard_id)
                if key in self.stat:
                    continue
                code = 2
                status.append("'{s}' ({n})".format(s=key, n=shards[shard_id]['name']))
                del shards[shard_id]

        return (code, ','.join(status))

    def scope_masters_up(self):
        """
        Checks how many masters are alive for each registration scope:
        Should be exactly one per shard, and no less than two per scope.

        A shard with more than one alive master is CRIT. A scope with fewer
        than two shards having exactly one alive master is WARN; with none
        it is CRIT.
        """
        status = []
        code = 0
        for scope, shards in self.scopes.items():
            scope_masters_available = 0
            for shard_id in shards:
                log.debug('master_cnt: scope: {s}, shard_id: {id}'.format(s=scope, id=shard_id))
                shard_stat = self.stat.get(str(shard_id))
                if shard_stat is None:
                    # Unknown to sharpei (reported by unknown_shard). It cannot
                    # supply an alive master, and raising KeyError here would
                    # abort this check for every other scope too.
                    continue
                masters = [
                    db for db in shard_stat['databases']
                    if db['role'] == 'master' and db['status'] == 'alive'
                ]
                if len(masters) == 1:
                    scope_masters_available += 1
                elif len(masters) > 1:
                    code = 2
                    status.append(
                        '{sc}: >1 mastr@{s}'.format(
                            sc=scope,
                            s=shard_stat['name']
                        )
                    )
            if scope_masters_available < 2:
                code = max(code, 1)
                status.append('{sc}: {n}'.format(
                    sc=scope,
                    n=scope_masters_available
                ))
            if scope_masters_available < 1:
                code = 2
        return (code, ','.join(status))

if __name__ == '__main__':

    arg = argparse.ArgumentParser(description="""
            Sharpei DB - Registration shard checker.
            """
            )
    arg.add_argument('-d', '--debug', action='store_true', default=False,
            help='print debug info. Default: no debug')
    arg.add_argument('-a', '--actions', nargs='+', metavar='<str>', default=ACTIONS,
            help='perform these checks. Default: {a}'.format(a=ACTIONS))

    settings = vars(arg.parse_args())
    logging.basicConfig(
        level=logging.DEBUG if settings.get('debug', False) else logging.ERROR
    )

    # Juggler status accumulator: the highest code wins, texts are joined.
    status = Status()

    try:
        # Testing harness: to run against canned data, `import stat_test` and
        # use SharpeiRegCheck(stat=stat_test.stat, scopes=stat_test.scopes).
        check = SharpeiRegCheck()
        for action in settings.get('actions', []):
            # Only declared checks may run: getattr on an arbitrary name could
            # invoke an unrelated method (e.g. get_scopes) and fail obscurely.
            if action not in ACTIONS:
                status.set_code(2)
                status.append('%s: unknown check' % action)
                continue
            try:
                func = getattr(check, action)
                # Each check returns a (code, message) tuple: code is
                # 0 (OK), 1 (WARN) or 2 (CRIT); message lists the problems.
                code, message = func()
                log.debug('check: {ch}, code: {c}, status: "{s}"'.format(ch=action, c=code, s=message))
                if code > 0:
                    status.set_code(code)  # if it was a CRIT, it will never degrade back to WARN.
                    status.append('%s: %s' % (action, message))
            except Exception as exc:
                # A failing check is reported as CRIT but does not stop the others.
                status.set_code(2)
                status.append('%s: %s' % (action, unicode(exc)))
                log.debug('check: {ch}, exception: {e}'.format(ch=action, e=exc), exc_info=1)
    except Exception as exc:
        # Data retrieval (sharddb or /stat) failed: no check could run.
        status.set_code(2)
        status.append('%s' % unicode(exc))
        log.debug('general error: exception: {e}'.format(e=exc), exc_info=1)

    status.report()
