#!/usr/bin/env python

import psycopg2
import requests
import sys
import time
import argparse
import socket

# 1. Hostname is missing in any capacity -- critical
# 2. Shard has an abnormal number of masters -- critical
# 3. Shard is missing replicas -- critical
# 4. Replica is not in the live state for longer than the specified period -- warning
# 5. Master is not in the live state for longer than the specified period -- critical

# Default number of times every check is run.
ATTEMPTS = 5
# Default thresholds: number of reported events needed for critical / warning.
CRITICAL = 3
WARNING = 2
# Pause between attempts, in seconds (passed to time.sleep).
INTERVAL = 1
# Names of the SharpeiCheck check methods run by default, in this order.
ACTIONS = [
    'missing_in_stat',
    'master_cnt',
    'no_replicas',
    'replicas_dead',
    'masters_dead',
    'phantoms',
]

class Status(object):
    """Accumulate a Juggler check status (code + text fragments) and print it.

    The caller uses code 0 for ok, 1 for warning and 2 for critical; the code
    only ever increases (see set_code).
    """
    def __init__(self):
        self.code = 0
        self.text = []

    def _shorten(self, message):
        """Compact a message for output.

        Underscores become spaces, newlines are removed and well-known
        domain suffixes are abbreviated.
        """
        replacements = [
            ['_', ' '],
            ['\n', ''],
            ['mail.yandex.net', 'm'],
            ['disk.yandex.net', 'd'],
            ['db.yandex.net', 'db'],
        ]
        for patterns in replacements:
            message = message.replace(*patterns)
        return message

    def set_code(self, new_code):
        """Set the code if it is greater than the current one."""
        if new_code > self.code:
            self.code = new_code

    def append(self, new_text):
        """Accumulate a status text fragment."""
        self.text.append(new_text)

    def report(self, retcode=0, message=None):
        """Print the status as ``<code>;<message>`` and exit the process.

        :param retcode: extra code merged into the current one (can only raise it).
        :param message: text to print instead of the accumulated fragments.

        The process always exits with status 0: the check result is carried
        by the printed code, not by the exit status.
        """
        # Merge retcode before choosing the default text, so an explicit
        # non-zero retcode is never reported together with 'ok'.
        self.set_code(retcode)
        # Concatenate all received fragments.
        if message is None:
            message = '. '.join(self.text)
        if not message and self.code == 0:
            message = 'ok'
        print(
            '{code};{msg}'.format(
                code=self.code,
                msg=self._shorten(message)
            )
        )
        sys.exit(0)


class SharpeiCheck(object):
    """Check how sharpei's /stat output corresponds to the shard DB contents.

    ``stat`` is the decoded /stat JSON keyed by shard id. Each value is read
    here through ``name`` and a ``databases`` list whose entries carry
    ``address.host``, ``role`` ('master' / 'replica') and ``status``.
    ``shards`` maps shard name -> {'hosts': [...], 'id': '<shard id>'}, as
    built by get_shards().

    Every check method returns a list of ``(object, subject)`` tuples
    describing anomalies; an empty list means the check passed.
    """
    # Class-level placeholders; __init__ always sets per-instance values.
    stat = []
    shards = []
    names = None

    def __init__(self, stat=None, shards=None):
        """
        Use the given data, or retrieve it from the PostgreSQL backend and
        the /stat handle when not supplied.
        """
        self.shards = shards
        self.stat = stat
        if stat is None:
            self.stat = self.get_stat()
        if shards is None:
            self.shards = self.get_shards()

    def _stat_hosts(self, shard):
        """
        Return a list of hosts in a given shard based on /stat output.

        :raises KeyError: if the shard id is absent from /stat.
        """
        hosts = [
            x['address']['host'] for x in self.stat[shard]['databases']
        ]
        return hosts

    @staticmethod
    def _with_role(props, role):
        """Return the /stat database entries of one shard having the given role."""
        return [x for x in props['databases'] if x['role'] == role]

    def _not_alive(self, role):
        """List (shard name, space-separated hosts) for `role` entries not 'alive'."""
        found = []
        for props in self.stat.values():
            hosts = [
                x['address']['host']
                for x in self._with_role(props, role)
                if x['status'] != 'alive'
            ]
            if hosts:
                found.append((props['name'], ' '.join(hosts)))
        return found

    def get_stat(self, timeout=10):
        """
        Call sharpei's /stat handle on localhost and decode its JSON.

        :param timeout: seconds to wait for connect/read; without it the
            check could hang indefinitely and requests.Timeout never fires.
        :raises Exception: '/stat: unable to connect' or '/stat: timed out'.
        :raises requests.HTTPError: on a non-2xx response.
        """
        try:
            result = requests.get(
                'http://localhost:9999/stat?force=1', timeout=timeout
            )
            result.raise_for_status()
            return result.json()
        except requests.ConnectionError:
            raise Exception('/stat: unable to connect')
        except requests.Timeout:
            raise Exception('/stat: timed out')

    @staticmethod
    def _rows_to_state(rows):
        """Group (shard_id, host, shard_name) rows into {name: {'hosts', 'id'}}."""
        state = {}
        for num, host, name in rows:
            # setdefault keeps the first host of each shard too (the old
            # KeyError-based version created the entry but dropped the host).
            shard = state.setdefault(name, {'hosts': [], 'id': str(num)})
            shard['hosts'].append(host)
        return state

    def get_shards(self):
        """
        Read shard membership from the sharddb PostgreSQL database.

        :returns: {shard_name: {'hosts': [host, ...], 'id': str(shard_id)}}
        """
        connection = psycopg2.connect('dbname=sharddb user=monitor connect_timeout=1')
        try:
            cur = connection.cursor()
            cur.execute(
                'SELECT i.shard_id, host, name '
                'FROM shards.instances i '
                'JOIN shards.shards s USING (shard_id)'
            )
            return self._rows_to_state(cur.fetchall())
        finally:
            # Always release the connection, even if the query fails.
            connection.close()

    def missing_in_stat(self):
        """Checks if a DB-registered hostname is missing from /stat output."""
        missing = []
        for shard, props in self.shards.items():
            shard_id = props['id']
            try:
                stat_hosts = set(self._stat_hosts(shard_id))
            except KeyError:
                # Shard (or a host entry) unusable in /stat: every host counts as missing.
                stat_hosts = set()
            missing.extend(
                (shard, host) for host in props['hosts'] if host not in stat_hosts
            )
        return missing

    def master_cnt(self):
        """Report shards that do not have exactly one master, with the count."""
        anomalies = []
        for props in self.stat.values():
            masters = self._with_role(props, 'master')
            if len(masters) != 1:
                anomalies.append((props['name'], len(masters)))
        return anomalies

    def no_replicas(self):
        """Report shards with no replica entries in /stat."""
        anomalies = []
        for props in self.stat.values():
            if not self._with_role(props, 'replica'):
                anomalies.append((props['name'], ''))
        return anomalies

    def replicas_dead(self):
        """Check if there are replicas not in 'alive' state."""
        return self._not_alive('replica')

    def masters_dead(self):
        """Check if there are masters not in 'alive' state."""
        return self._not_alive('master')

    def phantoms(self):
        """Report shard ids present in /stat but not in the DB."""
        known_ids = set(x['id'] for x in self.shards.values())
        return [(shard_id, '') for shard_id in self.stat if shard_id not in known_ids]

if __name__ == '__main__':

    arg = argparse.ArgumentParser(description="""
            Sharpei DB - /stat checker.
            """
            )
    arg.add_argument('-c', '--critical', type=int, required=False, metavar='<integer>', default=CRITICAL,
            help='critical threshold. Default: %s' % CRITICAL)
    arg.add_argument('-w', '--warning', type=int, required=False, metavar='<integer>', default=WARNING,
            help='warning threshold. Default: %s' % WARNING)
    arg.add_argument('-t', '--attempts', type=int, required=False, metavar='<integer>', default=ATTEMPTS,
            help='number of trials to run every check. Default: %s' % ATTEMPTS)
    arg.add_argument('-a', '--actions', nargs='+', metavar='<str>', default=ACTIONS,
            help='perform these checks. Default: {a}'.format(a=ACTIONS))

    settings = vars(arg.parse_args())
    attempts = settings.get('attempts')
    warning = settings.get('warning')
    critical = settings.get('critical')

    # Testing harness.
    # import stat_test

    # Init result container: one {action: [(object, subject), ...]} dict per run.
    results = []
    # Init Juggler status
    status = Status()

    # Perform actual checks
    for attempt_no in xrange(attempts):
        try:
            # Testing harness.
            # check = SharpeiCheck(stat=stat_test.stat, shards=stat_test.shards)
            check = SharpeiCheck()
            for action in settings.get('actions'):
                try:
                    func = getattr(check, action)
                    # func returns a list of tuples -- [('object', 'subject'), (...)]
                    # where object is usually a shard, and subject is a hostname.
                    results.append({action: func()})
                except Exception as e:
                    results.append({action: [('exc', unicode(e))]})
        except Exception as e:
            # Failure to build the check (e.g. /stat or DB unreachable).
            results.append({'err': [('exc', unicode(e))]})
        # Pause between attempts only; sleeping after the last one is wasted time.
        if attempt_no + 1 < attempts:
            time.sleep(INTERVAL)

    # Testing harness.
    # results = stat_test.results

    # Analyse the results.
    # 'err' has possible runtime exceptions.
    for action in ['err'] + settings.get('actions'):
        events = []
        for attempt in results:
            events += attempt.get(action) or []

        # NOTE: thresholds are compared against the total number of reported
        # items across all attempts, not against the number of failing attempts.
        # Build the text as soon as either threshold is reached, so a critical
        # status is never emitted without a message (e.g. when -w > -c).
        if len(events) >= min(warning, critical):
            # A set makes repeating errors unique; sorting keeps output stable.
            message = set()
            for obj, subj in events:
                if subj:
                    message.add(' '.join([str(obj), str(subj)]))
                else:
                    message.add(str(obj))
            status.append('%s: %s' % (action, ','.join(sorted(message))))

        if len(events) >= warning:
            status.set_code(1)

        if len(events) >= critical:
            status.set_code(2)

    # Output the result.
    status.report()
