#!/usr/bin/env python
# coding: utf8
# Python 2.7
"""Beatcop tries to ensure that a specified process runs on exactly one node in a cluster.
It does this by acquiring an expiring lock in Redis (whose master is located through
Redis Sentinel), which it then continually refreshes. If the node stops refreshing its
lock for any reason (e.g. sudden death), the lock expires and another node acquires it
and launches the specified process.

Beatcop is loosely based on the locking patterns described at http://redis.io/commands/set.

https://github.com/Luluvise/BeatCop/blob/master/beatcop.py
"""


import atexit
import logging
import os
import redis
import signal
import socket
import subprocess
import sys
import time


# Module-level logger. Records propagate to the root logger, whose handler and
# format are set up by logging.basicConfig() in main().
log = logging.getLogger(__name__)


class Lock(object):
    """Lock class using Redis expiry.

    The lock is a Redis key whose value identifies the owner as
    ``"<fqdn>-<pid>"``. With a ``timeout`` the key expires unless it is
    refreshed in time; without one (``None``/0) the key never expires.
    """

    # Extend the lock's expiry only if we still own it; GET and PEXPIRE run
    # together inside one EVAL. ARGV[2] is the expiry in milliseconds. A value
    # of 0 means "no expiry": only ownership is checked, because PEXPIRE with
    # a non-positive timeout would delete the key instead of extending it.
    lua_refresh = """
        if redis.call("get", KEYS[1]) == ARGV[1]
        then
            if tonumber(ARGV[2]) > 0
            then
                return redis.call("pexpire", KEYS[1], ARGV[2])
            end
            return 1
        else
            return 0
        end
    """

    def __init__(self, redis_, name, timeout=None, sleep=0.1):
        """
        :param redis_: Redis client providing ``set``, ``get`` and ``register_script``.
        :param name: Redis key holding the lock.
        :param timeout: lock expiry in seconds; ``None`` or 0 means the lock never expires.
        :param sleep: seconds to wait between attempts while blocking in :meth:`acquire`.
        """
        self.redis = redis_
        self.name = name
        self.timeout = timeout
        self._timeout = int((timeout or 0) * 1000)  # expiry in ms; 0 = no expiry
        self.sleep = sleep
        # Instead of putting any old rubbish into the Lock's value, use our FQDN and PID
        self.value = "%s-%d" % (socket.getfqdn(), os.getpid())
        self._refresh_script = self.redis.register_script(self.lua_refresh)

    def acquire(self, block=True):
        """Acquire the lock.

        :param block: if true, retry every ``self.sleep`` seconds until the
            lock is acquired; otherwise give up after one attempt.
        :returns: ``True`` once acquired, ``False`` if ``block`` is false and
            the lock is held by someone else.
        """
        # Only send PX when there is an expiry: Redis rejects a zero PX.
        px = self._timeout or None
        while True:
            # SET ... NX only succeeds when the key does not exist yet.
            if self.redis.set(self.name, self.value, px=px, nx=True):
                # It's ours until the timeout now
                return True
            if not block:
                return False
            time.sleep(self.sleep)

    def refresh(self):
        """Refresh an existing lock to prevent it from expiring.

        Uses a Lua (EVAL) script so that only a lock we own is touched.

        :returns: ``True`` if we still own the lock (and its expiry was
            extended when a timeout is set), ``False`` otherwise.
        """
        return self._refresh_script(keys=[self.name], args=[self.value, self._timeout]) == 1

    def who(self):
        """Return the owner (value) of the lock, or ``None`` if it is not held."""
        return self.redis.get(self.name)


class BeatCop(object):
    """Run a process on a single node by using a Redis lock.

    The Redis master is located through Redis Sentinel. Once the lock is
    held, the child process is spawned and supervised while the lock is
    refreshed every ``timeout / 3`` seconds. If the lock is lost, the child
    is stopped and the interpreter exits.
    """

    def __init__(
        self, command, sentinel_host, sentinel_port, sentinel_namespace, lockname=None, timeout=3, shell=False, db=0
    ):
        """Connect to Redis, check its version and install exit/signal handlers.

        :param command: sequence of program arguments (with ``shell=True`` it
            must contain a single shell command string).
        :param sentinel_host: Redis Sentinel host.
        :param sentinel_port: Redis Sentinel port.
        :param sentinel_namespace: name of the master monitored by Sentinel.
        :param lockname: Redis key used for the lock; defaults to
            ``"beatcop:<command joined by spaces>"``.
        :param timeout: lock expiry in seconds.
        :param shell: run the command through the shell.
        :param db: Redis database number.

        Exits with ``os.EX_NOHOST`` if Redis cannot be reached and with
        ``os.EX_PROTOCOL`` if the server is older than 2.6.12.
        """
        self.command = command
        self.command_s = ' '.join(command)
        self.shell = shell
        self.timeout = timeout
        self.sleep = timeout / 3.0  # make sure we refresh at least 3 times per timeout period
        self.process = None

        from redis.sentinel import Sentinel
        sentinel = Sentinel([(sentinel_host, sentinel_port)], socket_timeout=0.1)
        self.redis = sentinel.master_for(sentinel_namespace, socket_timeout=0.1, db=db)

        try:
            redis_info = self.redis.info()
        except redis.exceptions.ConnectionError as e:
            # Format the exception itself; BaseException.message is deprecated.
            log.error("Couldn't connect to Redis: %s", e)
            sys.exit(os.EX_NOHOST)

        # SET with the PX/NX options (used by Lock.acquire) needs Redis 2.6.12.
        # Compare component tuples so versions of any length order correctly.
        redis_version = redis_info['redis_version']
        if tuple(int(v) for v in redis_version.split('.')) < (2, 6, 12):
            log.error("Redis server is too old. You got %s, minimum requirement is %s",
                      redis_version, '2.6.12')
            sys.exit(os.EX_PROTOCOL)

        self.lockname = lockname or ("beatcop:%s" % (self.command_s))
        self.lock = Lock(self.redis, self.lockname, timeout=self.timeout, sleep=self.sleep)

        # Make sure the child gets stopped whichever way we exit.
        atexit.register(self.crash)
        signal.signal(signal.SIGINT, self.handle_signal)
        signal.signal(signal.SIGTERM, self.handle_signal)
        signal.signal(signal.SIGHUP, self.handle_signal)

    def run(self):
        """Run the process if nobody else is, otherwise wait until we're needed.

        Blocks in ``Lock.acquire`` until the lock is ours, then spawns the
        child and keeps refreshing the lock. Never returns normally: it exits
        with status 1 when the child dies and with ``os.EX_UNAVAILABLE`` when
        the lock is lost.
        """
        log.info("Waiting for lock, currently held by %s", self.lock.who())
        if self.lock.acquire():
            log.info("Lock acquired")
            # We got the lock, so we make sure the process is running
            # and keep refreshing the lock - if we ever stop for any
            # reason, for example because our host died, the lock will
            # soon expire.
            while True:
                if self.process is None:  # Process not spawned yet
                    self.process = self.spawn(self.command)
                    log.info("Spawned PID %d", self.process.pid)
                child_status = self.process.poll()
                if child_status is not None:
                    # Oops, process died on us.
                    log.error("Child died with exit code %d", child_status)
                    sys.exit(1)
                # Refresh lock and sleep
                if not self.lock.refresh():
                    who = self.lock.who()
                    if who is None:
                        # The lock expired but nobody took it: try to grab it back.
                        if self.lock.acquire(block=False):
                            log.warning("Lock refresh failed, but successfully re-acquired unclaimed lock")
                        else:
                            log.error(
                                ("Lock refresh and subsequent"
                                 " re-acquire failed, giving up (Lock"
                                 " now held by %s)"), self.lock.who())
                            self.cleanup()
                            sys.exit(os.EX_UNAVAILABLE)
                    else:
                        log.error("Lock refresh failed, %s stole it - bailing out", who)
                        self.cleanup()
                        sys.exit(os.EX_UNAVAILABLE)
                time.sleep(self.sleep)

    def spawn(self, command):
        """Spawn the child process and return its ``subprocess.Popen``.

        With ``shell=True`` the command must be a string or a one-element
        sequence holding the shell command; otherwise it must be iterable.
        """
        if self.shell:
            if not isinstance(command, (bytes, unicode)):
                if len(command) == 1:
                    command = command[0]
                else:
                    raise Exception("For shell=True, command must be a single string")
            args = command
        else:
            if not hasattr(command, '__iter__'):
                raise Exception("For shell=False, command must be a list")
            args = command
        return subprocess.Popen(args, shell=self.shell)

    def cleanup(self, grace_period=1.0):
        """Make sure the child process is stopped before we pack up and go home.

        Sends TERM, waits up to ``grace_period`` seconds for the child to
        exit, then sends KILL and reaps it.

        :param grace_period: seconds to wait between TERM and KILL.
        """
        if self.process is None:  # Process wasn't running yet, so nothing to worry about
            return
        if self.process.poll() is None:
            log.info("Sending TERM to %d", self.process.pid)
            self.process.terminate()
            # Wall-clock deadline: time.clock() measures CPU time on Unix, which
            # barely advances while we sleep, so KILL would never be reached.
            deadline = time.time() + grace_period
            while time.time() < deadline:
                time.sleep(0.05)
                if self.process.poll() is not None:
                    break
            else:
                log.info("Sending KILL to %d", self.process.pid)
                self.process.kill()
                # Reap the child; poll() right after kill() may still return None.
                self.process.wait()
        assert self.process.poll() is not None

    def handle_signal(self, sig, frame):
        """Stop the child and exit on SIGINT, SIGTERM or SIGHUP.

        Exits via ``sys.exit(-sig)``.
        """
        if sig == signal.SIGINT:
            log.warning("Ctrl-C pressed, shutting down...")
        elif sig == signal.SIGTERM:
            log.warning("SIGTERM received, shutting down...")
        elif sig == signal.SIGHUP:
            log.warning("SIGHUP received, shutting down...")
        self.cleanup()
        sys.exit(-sig)

    def crash(self):
        """atexit hook: stop the child on any exit, e.g. after a Redis connection failure."""
        self.cleanup()


def main():
    """Command-line entry point: ``beatcop <command> [args...]``.

    Configuration is read from the environment:

    * ``SENTINEL_HOST``, ``SENTINEL_PORT``, ``SENTINEL_NAMESPACE`` (required):
      Redis Sentinel address and the monitored master's name.
    * ``SINGLE_BEAT_IDENTIFIER`` (required): the lock key becomes
      ``SINGLE_BEAT_<identifier>``.
    * ``SINGLE_BEAT_LOCK_TIME`` (default 3): lock timeout in seconds.
    * ``REDIS_DB_BEATCOP`` (default 0): Redis database number.

    Exits with ``os.EX_USAGE`` when no command is given, and with
    ``os.EX_CONFIG`` when a required variable is missing or a numeric one
    cannot be parsed.
    """
    if len(sys.argv) < 2:
        print("Usage: %s <command>" % sys.argv[0])
        sys.exit(os.EX_USAGE)

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s BeatCop: %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S %Z')

    # Report missing configuration clearly instead of dying on a KeyError.
    required = ('SENTINEL_HOST', 'SENTINEL_PORT', 'SENTINEL_NAMESPACE', 'SINGLE_BEAT_IDENTIFIER')
    missing = [name for name in required if name not in os.environ]
    if missing:
        log.error("Missing required environment variable(s): %s", ', '.join(missing))
        sys.exit(os.EX_CONFIG)

    try:
        cfg = dict(
            command=sys.argv[1:],
            sentinel_host=os.environ['SENTINEL_HOST'],
            sentinel_port=int(os.environ['SENTINEL_PORT']),
            sentinel_namespace=os.environ['SENTINEL_NAMESPACE'],
            lockname='SINGLE_BEAT_%s' % (os.environ['SINGLE_BEAT_IDENTIFIER'],),
            timeout=float(os.environ.get('SINGLE_BEAT_LOCK_TIME', 3)),
            db=int(os.environ.get('REDIS_DB_BEATCOP', 0)),
        )
    except ValueError as e:
        log.error("Invalid numeric environment variable: %s", e)
        sys.exit(os.EX_CONFIG)

    beatcop = BeatCop(**cfg)

    log.info("BeatCop starting on %s using lock '%s'",
             beatcop.lock.value, beatcop.lockname)
    beatcop.run()


if __name__ == '__main__':
    main()
