#!/usr/bin/env python
#
# $Id$
#

import logging
import os

from glob import glob
from subprocess import Popen
from urllib2 import urlopen
from functools import partial, wraps
from itertools import repeat, islice
from time import sleep
from string import digits
from filecmp import cmp as fcmp

# Directory layout.  main() builds <root_dir>/<conf_dir>/<CA_DIR>/ and uses
# BACKUP_DIR, DB_DIR and TMP_DIR below it.
ROOT_DIR = "/"
CONF_DIR = "etc"
HOME_DIR = "home"
CA_DIR = "cauth"
BACKUP_DIR = "bak"
DB_DIR = "db"
TMP_DIR = "tmp"
# Basenames of the generated / installed files.
MASTER = "master.passwd"
SHADOW = "shadow"
PASSWD = "passwd"
GROUP  = "group"
# Glob patterns / names of local override files inside DB_DIR.
MASTER_LOCAL = "master.passwd.*"
GROUP_LOCAL  = "group.*"
KEYS_LOCAL   = "keys.local"

# Number of backup copies kept for each installed file.
BACKUPS = 3
# open(2) flags used when (re)writing generated files and backups:
# truncate any existing file and refuse to follow a symlink.
FLAGS = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | os.O_NOFOLLOW

# Timeout, in seconds, passed to every urlopen() call.
TIMEOUT = 10
URL_GROUP = "https://ldap-dev.yandex.net:4443/group/"
URL_PASSWD = "https://ldap-dev.yandex.net:4443/passwd/serverusers/"
URL_KEYS = "https://ldap-dev.yandex.net:4443/userkeys/"

FORMAT = "%(levelname)s:%(asctime)s: %(message)s"

# Sink for the stdout/stderr of the external validation/install tools.  It
# must be opened for writing: with a read-only descriptor every write the
# child makes to stdout/stderr fails with EBADF.
NULL = open(os.devnull, "w")
OSNAME = os.uname()[0]

#
# Subroutines
#

def retry_on_exception(retries=5, base=2.0, backoff='exponential'):
    """
    Decorator factory: call the wrapped function up to `retries` times.

    Every Exception raised by an attempt is logged (INFO, with traceback).
    Between attempts the wrapper sleeps; the k-th pause (k counted from 0) is
      * 'exponential': base ** k
      * 'linear':      base * k
      * 'fixed':       base
    No pause follows the final attempt.  If every attempt fails,
    RuntimeError is raised.

    :raises ValueError: immediately, if `backoff` is not one of the above.
    """
    schedules = {
            'exponential': lambda: (base ** x for x in range(retries)),
            'linear': lambda: (base * x for x in range(retries)),
            'fixed': lambda: islice(repeat(base), retries),
        }
    if backoff not in schedules:
        raise ValueError("Unknown backoff: {0}".format(backoff))
    # Materialize once into an immutable tuple.  The previous version shared
    # one generator between all calls, so after `retries` iterations in total
    # the wrapper raised without ever calling the function again.
    waits = tuple(schedules[backoff]())

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt, wait in enumerate(waits, 1):
                try:
                    return func(*args, **kwargs)
                except Exception:
                    logging.info("Function run failed", exc_info=True)
                if attempt < len(waits):
                    logging.debug("Sleeping for [ {0} ] seconds".format(wait))
                    sleep(wait)
            raise RuntimeError("Function {0} failed after {1} retries".format(
                    func.__name__, retries))
        return wrapper
    return decorator

def _chown(_file, uid, gid):
    """
    Simple safe+log wrapper around os.chown
    """
    try:
        logging.debug("Chown file: {0} using {1}:{2}".format(_file, uid, gid))
        os.chown(_file, uid, gid)
    except Exception:
        logging.error("Can't chown file: {0}".format(_file), exc_info=True)

def chown(top, uid_map, gid_map):
    """
    Walk the tree under `top` and remap file ownership.

    Every directory and file below `top` (os.walk does not yield `top`
    itself) whose uid is a key of `uid_map` or whose gid is a key of
    `gid_map` gets the mapped owner/group; unmapped ids are kept as they are.

    Symbolic links are skipped and never followed: the trees being walked
    (e.g. home directories) may contain links planted by users, and
    following one would hand ownership of an arbitrary file to that user.

    Either map may be empty or None; if both are, a warning is logged and
    nothing is done.  Per-file errors are logged and the walk continues.
    """
    from stat import S_ISLNK

    if not (uid_map or gid_map):
        logging.warning('No uid_map nor gid_map is given to chown()')
        return
    uid_map = uid_map or {}
    gid_map = gid_map or {}
    for root, dirs, files in os.walk(top):
        for name in dirs + files:
            filename = os.path.join(root, name)
            try:
                st = os.lstat(filename)
            except EnvironmentError:
                logging.error("Can't process file: {0}".format(filename), exc_info=True)
                continue
            if S_ISLNK(st.st_mode):
                continue
            uid, gid = st.st_uid, st.st_gid
            if uid not in uid_map and gid not in gid_map:
                continue
            new_uid = uid_map.get(uid, uid)
            new_gid = gid_map.get(gid, gid)
            logging.debug("Chown file: {0} using {1}:{2}".format(filename, new_uid, new_gid))
            try:
                # lchown: if the entry is swapped for a symlink after the
                # lstat() above, only the link itself is touched.
                os.lchown(filename, new_uid, new_gid)
            except EnvironmentError:
                logging.error("Can't chown file: {0}".format(filename), exc_info=True)

#
# Classes
#

class Entry(object):
    """
    One record of a passwd-style database: master.passwd, shadow, passwd or
    group.  Fields are stored as plain attributes named as in
    `supported_types`; fields the input format lacks take their value from
    `defaults`.

    Serialize with the dynamic ``to_<type>(**kwargs)`` methods (resolved by
    __getattr__), e.g. ``entry.to_passwd(pwd_shadow_char='x')``.
    """
    # Fallback values for fields not present in the parsed line.
    defaults = dict(
            # master.passwd
            id = -1,
            gid = -1,
            login_class = "",
            change = 0,
            expire = 0,
            gecos = "",
            home_dir = "",
            shell  = "",
            # shadow
            last_change = 14999,
            min_age = 0,
            max_age = 99999,
            warning = 7,
            inactive = "",
            expiration = "",
            reserved = "",
            # group
            members = "",
        )
    # On-disk field order of each supported format.  The field counts
    # (10, 9, 7, 4) are all different, which is what lets __init__ detect
    # the format from the number of fields alone.
    supported_types = dict(
            master = ['name', 'password', 'id', 'gid', 'login_class','change',
                'expire', 'gecos', 'home_dir', 'shell'],
            shadow = ['name', 'password', 'last_change', 'min_age', 'max_age',
                'warning', 'inactive', 'expiration', 'reserved'],
            passwd = ['name', 'password', 'id', 'gid', 'gecos', 'home_dir',
                'shell'],
            group = ['name', 'password', 'id', 'members'],
        )

    def __init__(self, fields):
        """
        Initialize from the ':'-separated fields of one database line.

        The format is guessed from len(fields).  All-digit fields are
        converted to int (see intify) before being stored.

        NOTE: shadow lines carry no id, so they keep the -1 default and are
        rejected by the negative-id check below.

        :raises TypeError: if no supported format has len(fields) fields.
        :raises ValueError: if the name is empty or the id is negative.
        """
        # Start from the defaults so every attribute exists for every format.
        self.__dict__.update(self.defaults)
        # Detect input format
        for keys in self.supported_types.values():
            if len(fields) == len(keys):
                # Override defaults
                self.__dict__.update(zip(keys, self.intify(fields)))
                break
        else:
            raise TypeError("Unknown format: {0}".format(fields))
        if not self.name:
            raise ValueError("Empty name")
        if self.id < 0:
            raise ValueError("Negative id")

    def to_(self, _type, attrs, **kwargs):
        """
        Serialize the attributes named in `attrs` as one ':'-joined line.

        If a ``post_<_type>`` method exists, it receives the lazy sequence of
        (attr, value) pairs plus `kwargs` and may rewrite them first.  Values
        are converted with str().
        """
        attr_list = ((attr, getattr(self, attr)) for attr in attrs)
        postprocess = "post_{0}".format(_type)
        if hasattr(self, postprocess):
            attr_list = getattr(self, postprocess)(attr_list, **kwargs)
        return ':'.join(map(str, (value for attr, value in attr_list)))

    @staticmethod
    def shadow_password(attr_list, pwd_shadow_char=None, **kwargs):
        """
        Postprocessing that replaces password attr with shadow char

        The replacement is used verbatim.  Since to_() applies str(), passing
        None would write the literal text "None", so callers should supply a
        real character.

        >>> attr_list = [('nothing', 'to test'), ('password', 'secret')]
        >>> list(Entry.shadow_password(attr_list, pwd_shadow_char='!'))
        [('nothing', 'to test'), ('password', '!')]
        """
        for attr, value in attr_list:
            if attr == 'password':
                yield attr, pwd_shadow_char
            else:
                yield attr, value

    @staticmethod
    def replace_shell(attr_list, shell_replacement_map=None, **kwargs):
        """
        Replace shells according to shell_replacement_map

        A None or empty map leaves every value unchanged.  (CAuth's default
        shell_replacement_map is None, which previously raised TypeError
        here.)

        >>> attr_list = [('shell', '/bin/bash'), ('test', '/bin/bash')]
        >>> shell_replacement_map = {'/bin/bash': '/bin/false'}
        >>> list(Entry.replace_shell(attr_list, shell_replacement_map=shell_replacement_map))
        [('shell', '/bin/false'), ('test', '/bin/bash')]
        >>> list(Entry.replace_shell([('shell', '/bin/sh')], shell_replacement_map=None))
        [('shell', '/bin/sh')]
        """
        replacements = shell_replacement_map or {}
        for attr, value in attr_list:
            if attr == 'shell' and value in replacements:
                yield attr, replacements[value]
            else:
                yield attr, value

    def post_passwd(self, attr_list, **kwargs):
        """Postprocessing for passwd file: shadow the password, map the shell."""
        attr_list = self.shadow_password(attr_list, **kwargs)
        attr_list = self.replace_shell(attr_list, **kwargs)
        return attr_list

    def post_master(self, attr_list, **kwargs):
        """Postprocessing for master.passwd file: replace shell path"""
        return self.replace_shell(attr_list, **kwargs)

    def post_group(self, attr_list, **kwargs):
        """Postprocessing for group file: replace password with shadow char"""
        return self.shadow_password(attr_list, **kwargs)

    def __getattr__(self, name):
        """
        Resolve ``to_<type>`` for every type in supported_types into a
        partial of to_() bound to that type's field list.
        """
        if name.startswith('to_'):
            _type = name[3:]
            attrs = self.supported_types.get(_type)
            if attrs:
                return partial(self.to_, _type, attrs)
        raise AttributeError("No such attribute: {0}".format(name))

    @staticmethod
    def intify(lst):
        """
        Convert all-digit values in list of strings to corresponding ints

        >>> list(Entry.intify('0 123 L12 23L L321L letters L'.split()))
        [0, 123, 'L12', '23L', 'L321L', 'letters', 'L']

        >>> list(Entry.intify(["", "1", "L"]))
        ['', 1, 'L']
        """
        for item in lst:
            if item and all(char in digits for char in item):
                yield int(item)
            else:
                yield item

class Container(object):
    """
    Index of entries by name and by numeric id.

    The first entry added for a given name is kept and later ones with the
    same name are ignored.  Several names may share one id.
    """
    def __init__(self):
        # name -> entry
        self.names = {}
        # id -> [entry, ...] in insertion order
        self.ids = {}

    def add_entry(self, entry):
        """Store `entry` unless an entry with the same name already exists."""
        name, key = entry.name, entry.id
        if name in self.names:
            return
        self.ids.setdefault(key, []).append(entry)
        self.names[name] = entry

    def get_entries(self):
        """Yield all stored entries ordered by id, then by insertion order."""
        for key in sorted(self.ids):
            bucket = self.ids[key]
            for entry in bucket:
                yield entry

class SSHKeys(object):
    """
    Collects one public key per user and installs it as
    <root_dir>/<home_dir>/.ssh/authorized_keys.
    """
    def __init__(self, root_dir=None):
        # Tree that home directories are resolved against; None means "/".
        self.root_dir = root_dir
        # user name -> key text; a later add_key() for the same user wins.
        self.keys = {}

    def add_key(self, user, key):
        """Remember `key` for `user`, replacing any key added earlier."""
        self.keys[user] = key

    def save(self, passwd):
        """
        Install every collected key.

        :param passwd: object whose ``names`` dict maps a user name to an
            entry with ``id``, ``gid`` and ``home_dir`` attributes (a
            Container of Entry objects in this file).

        Users missing from `passwd`, or without a home directory, are skipped
        with a warning.  Filesystem errors are logged per user and do not
        stop the remaining users.
        """
        root_dir = self.root_dir or '/'
        for user, key in self.keys.items():
            user_entry = passwd.names.get(user)
            if user_entry is None:
                logging.warning("No passwd entry for user: {0}".format(user))
                continue
            if not user_entry.home_dir:
                logging.warning("No home directory for user: {0}".format(user))
                continue
            # home_dir is normally absolute.  Strip the leading '/' so that
            # os.path.join() keeps root_dir instead of silently discarding it.
            home_dir = str(user_entry.home_dir).lstrip('/')
            user_ssh_dir = os.path.join(root_dir, home_dir, '.ssh')
            logging.debug("Saving key for user [ {0} ] to: {1}".format(user, user_ssh_dir))
            try:
                self._install_key(key, user_ssh_dir, user_entry.id, user_entry.gid)
            except EnvironmentError:
                logging.error("Can't save key for user: {0}".format(user), exc_info=True)

    @staticmethod
    def _install_key(key, ssh_dir, uid, gid):
        """
        Atomically replace ssh_dir/authorized_keys with `key`, owned by
        uid:gid with mode 0600.  Does nothing if the file already holds
        exactly `key`.  Raises EnvironmentError on failure.

        ssh_dir lives in a user-writable home directory, so nothing here
        writes or chowns through a path component the user may have replaced
        with a symlink.
        """
        key_file = os.path.join(ssh_dir, 'authorized_keys')
        key_file_new = key_file + '_new'
        if os.path.islink(ssh_dir):
            raise OSError("Refusing to write through symlink: {0}".format(ssh_dir))
        if not os.path.isdir(ssh_dir):
            os.makedirs(ssh_dir)
            os.chown(ssh_dir, uid, gid)
            os.chmod(ssh_dir, 0o700)
        try:
            with open(key_file) as f:
                current = f.read()
        except EnvironmentError:
            current = None  # missing or unreadable: (re)install
        if current == key:
            logging.info("{0} is up to date. Skipping.".format(key_file))
            return
        # Drop any stale (or planted) temp file, then create a fresh one.
        # O_EXCL|O_NOFOLLOW guarantees we never write through a symlink.
        try:
            os.unlink(key_file_new)
        except OSError:
            pass
        fd = os.open(key_file_new,
                os.O_WRONLY | os.O_CREAT | os.O_EXCL | os.O_NOFOLLOW, 0o600)
        try:
            with os.fdopen(fd, 'w') as f:
                # fchown/fchmod act on the open descriptor, not on a path.
                os.fchown(f.fileno(), uid, gid)
                os.fchmod(f.fileno(), 0o600)
                f.write(key)
            os.rename(key_file_new, key_file)
        finally:
            # The temp file only remains if something above failed.
            try:
                os.unlink(key_file_new)
            except OSError:
                pass

class CAuth(object):
    """
    Assembles passwd/group databases (and optionally SSH keys) from local
    files and HTTPS sources, writes the OS-specific files, validates them
    with the system tools and installs them with backups.
    """
    def __init__(self, pwd_shadow_char=None, shell_replacement_map=None,
            root_dir=None, conf_dir=None):
        """
        :param pwd_shadow_char: value written in place of the password field
            of generated passwd and group files.
        :param shell_replacement_map: {old shell: new shell} applied to
            generated passwd and master.passwd files.
        :param root_dir: root that SSH key home directories are resolved
            against.
        :param conf_dir: accepted for interface compatibility; unused here.
        """
        self.pwd_shadow_char = pwd_shadow_char
        self.shell_replacement_map = shell_replacement_map
        self.passwd = Container()
        self.group = Container()
        self.ssh_keys = SSHKeys(root_dir=root_dir)

    @retry_on_exception(5)
    def geturl(self, url):
        """Open `url` with a TIMEOUT-second timeout; retried on any error."""
        return urlopen(url, timeout=TIMEOUT)

    def load_group(self, sources):
        """Add group entries from every source in `sources`."""
        self._load(sources, self.add_group)

    def load_passwd(self, sources):
        """Add passwd/master.passwd entries from every source in `sources`."""
        self._load(sources, self.add_passwd)

    def load_keys(self, sources):
        """Add "user : key" lines from every source in `sources`."""
        self._load(sources, self.add_keys)

    def _load(self, sources, method):
        """
        Feed each source to `method` as an iterable of lines.

        Sources starting with "https://" are fetched with geturl(); anything
        else is opened as a local file.  Either way the handle is closed once
        `method` returns or raises.
        """
        for src in sources:
            if src.startswith("https://"):
                response = self.geturl(src)
                try:
                    if response:
                        method(response)
                finally:
                    if response is not None:
                        response.close()
            else:
                with open(src) as it:
                    method(it)

    def add_group(self, it):
        """Parse group lines from `it` into self.group."""
        for fields in self._parser(it):
            # XXX: Hack for broken group files with missing ":" at the end
            if len(fields) == 3:
                fields.append('')
            self.group.add_entry(Entry(fields))
            logging.debug("Got group entry: {0}".format(fields))

    def add_passwd(self, it):
        """Parse passwd/master.passwd lines from `it` into self.passwd."""
        for fields in self._parser(it):
            self.passwd.add_entry(Entry(fields))
            logging.debug("Got passwd entry: {0}".format(fields))

    def add_keys(self, it):
        """
        Parse "user : key" lines from `it` into self.ssh_keys.  Only the
        first " : " separates the two parts; lines without one are logged
        and skipped instead of aborting the whole load.
        """
        for fields in self._parser(it, sep=' : ', maxsplit=1):
            if len(fields) != 2:
                logging.warning("Malformed key line: {0}".format(fields))
                continue
            user, key = fields
            self.ssh_keys.add_key(user, key)
            logging.debug("Got key for user: {0}".format(user))

    @staticmethod
    def _parser(it, sep=':', maxsplit=-1):
        """
        Yield each non-blank, non-comment line of `it` split on `sep`.

        Lines are stripped of surrounding whitespace first, and lines that
        start with '#' are skipped.  `maxsplit` is passed to str.split
        (-1: unlimited).
        """
        for line in it:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            yield line.split(sep, maxsplit)

    def _name_id_map(self, container):
        """Returns name->id mapping for container ("passwd" or "group")."""
        return dict((name, entry.id)
                for name, entry in getattr(self, container).names.items())

    @staticmethod
    def _join_maps_by_name(m1, m2):
        """
        Join m1 and m2 by name and return uid mapping.
        Useful for constructing remapping table between old and new UIDs.

        >>> old = {'root':0, 'rbtz':4042, 'test': 123}
        >>> new = {'root':0, 'rbtz':777777, 'test2': 123}
        >>> CAuth._join_maps_by_name(old, new)
        {4042: 777777}
        """
        uid_map = {}
        for name, uid in m1.items():
            if name in m2 and m2[name] != uid:
                uid_map[uid] = m2[name]
        return uid_map

    def get_group(self):
        """Iterate group entries ordered by gid."""
        return self.group.get_entries()

    def get_passwd(self):
        """Iterate passwd entries ordered by uid."""
        return self.passwd.get_entries()

    def _generate(self, out, mode, get_method, to_method, **kwargs):
        """
        Write one line per entry from get_method() into `out`, serialized
        with the entry's `to_method` (e.g. "to_passwd").

        The file is opened with FLAGS (truncate, no symlink following) and
        its permission bits are forced to `mode`.  os.open() alone applies
        `mode` only when it creates the file, and masks it with the umask.
        """
        with os.fdopen(os.open(out, FLAGS, mode), 'w') as f:
            os.fchmod(f.fileno(), mode)
            for e in get_method():
                f.write(getattr(e, to_method)(**kwargs) + "\n")

    def generate(self, files):
        """
        Generate every file in `files`; the type is chosen by basename
        (MASTER, SHADOW, PASSWD or GROUP).  The directory of files[0] is
        created with mode 0750 if missing.

        :raises ValueError: on reaching a file with an unrecognized basename.
        """
        if not files:
            return
        try:
            os.mkdir(os.path.dirname(files[0]), 0o750)
        except OSError:
            pass  # usually already exists; real problems surface in open()
        kwargs = dict(
                shell_replacement_map=self.shell_replacement_map or {},
                pwd_shadow_char=self.pwd_shadow_char,
            )
        # basename -> (permission bits, entry source, serializer name)
        layouts = {
                MASTER: (0o600, self.get_passwd, "to_master"),
                SHADOW: (0o600, self.get_passwd, "to_shadow"),
                PASSWD: (0o644, self.get_passwd, "to_passwd"),
                GROUP: (0o644, self.get_group, "to_group"),
            }
        for f in files:
            layout = layouts.get(os.path.basename(f))
            if layout is None:
                raise ValueError("Unknown file type to generate")
            mode, get_method, to_method = layout
            self._generate(f, mode, get_method, to_method, **kwargs)

    @staticmethod
    def check_group(group):
        """On FreeBSD validate `group` with chkgrp; no-op elsewhere."""
        if OSNAME == "FreeBSD":
            p = Popen(["/usr/sbin/chkgrp", group], stdout=NULL, stderr=NULL)
            if p.wait():
                raise RuntimeError("%s file is not valid" % GROUP)

    @staticmethod
    def check_master(master):
        """On FreeBSD validate `master` with `pwd_mkdb -C`; no-op elsewhere."""
        if OSNAME == "FreeBSD":
            p = Popen(["/usr/sbin/pwd_mkdb", "-C", master], stdout=NULL, stderr=NULL)
            if p.wait():
                raise RuntimeError("%s file is not valid" % MASTER)

    @staticmethod
    def check_passwd(passwd):
        """
        On Linux validate `passwd` together with the shadow file in the same
        directory using `pwck -q -r`; no-op elsewhere.
        """
        if OSNAME == "Linux":
            shadow = os.path.join(os.path.dirname(passwd), SHADOW)
            p = Popen(["/usr/sbin/pwck", "-q", "-r", passwd, shadow], stdout=NULL, stderr=NULL)
            if p.wait():
                raise RuntimeError("passwd/shadow files are not valid")

    def check(self, files):
        """
        Validate every generated file by basename.  Shadow is validated
        together with passwd.

        :raises RuntimeError: if a validator rejects a file.
        :raises ValueError: for an unrecognized basename.
        """
        for f in files:
            name = os.path.basename(f)
            if name == MASTER:
                self.check_master(f)
            elif name == GROUP:
                self.check_group(f)
            elif name == PASSWD:
                self.check_passwd(f)
            elif name == SHADOW:
                pass
            else:
                raise ValueError("Unknown file type to check: {0}".format(name))

    @staticmethod
    def backup(src, dst, backup_dir):
        """
        Copy the installed `dst` to backup_dir/<basename(src)>.<mtime of dst>,
        keeping its mode, then delete all but the newest BACKUPS copies.
        Does nothing if `dst` does not exist yet.
        """
        try:
            os.mkdir(backup_dir, 0o750)
        except OSError:
            pass  # usually already exists
        if not os.path.exists(dst):
            logging.info("{0} does not exist, nothing to back up".format(dst))
            return
        src_name = os.path.basename(src)
        with open(dst) as f:
            stat = os.fstat(f.fileno())
            mtime = str(int(stat.st_mtime))
            backup_name = os.path.join(backup_dir, src_name) + "." + mtime
            logging.debug("Backing up {0} to {1}".format(dst, backup_name))
            with os.fdopen(os.open(backup_name, FLAGS, stat.st_mode), 'w') as to:
                for line in f:
                    to.write(line)
        # Names end in a fixed-width epoch, so lexical order is age order.
        for f in sorted(glob(os.path.join(backup_dir, src_name) + "*"))[:-BACKUPS]:
            logging.info("Removing old backup: {0}".format(f))
            os.unlink(f)

    @staticmethod
    def install(sources, dstdir, backup_dir):
        """
        Move each generated file in `sources` into `dstdir`, backing up the
        file it replaces first.  Files whose content is unchanged are
        skipped.  MASTER is installed by running `pwd_mkdb -p -d dstdir`;
        everything else is renamed into place.

        :raises RuntimeError: if pwd_mkdb fails.
        """
        for src in sources:
            dst = os.path.join(dstdir, os.path.basename(src))
            # shallow=False: compare contents, not just the size+mtime
            # signature, which can collide for files written in one second.
            if os.path.exists(dst) and fcmp(src, dst, shallow=False):
                logging.info("{0} and {1} are the same. Skipping.".format(src, dst))
                continue
            CAuth.backup(src, dst, backup_dir)
            if os.path.basename(dst) == MASTER:
                p = Popen(["/usr/sbin/pwd_mkdb", "-p", "-d", dstdir, src], stdout=NULL, stderr=NULL)
                res = p.wait()
                if res:
                    raise RuntimeError("Failed to install %s" % MASTER)
            else:
                os.rename(src, dst)

def main(chown_dirs=None, dry_run=False, root_dir=None, conf_dir=None, ssh_keys=False):
    """
    Regenerate and install the system account databases.

    With ca_dir = <root_dir>/<conf_dir>/cauth:
      * entries are read from <ca_dir>/db/group.* and
        <ca_dir>/db/master.passwd.* (sorted), then from URL_GROUP and
        URL_PASSWD.  The first entry seen for a name wins, so local files
        override the remote ones;
      * files are generated in <ca_dir>/tmp and checked with the OS tools;
      * unless `dry_run`, they are installed into <root_dir>/<conf_dir>, with
        backups in <ca_dir>/bak.

    :param chown_dirs: trees whose file ownership is remapped for users and
        groups whose numeric id changed.
    :param root_dir: defaults to ROOT_DIR when None.
    :param conf_dir: defaults to CONF_DIR when None; relative to root_dir.
    :param ssh_keys: also load <ca_dir>/db/keys.local (if present) and
        URL_KEYS and install the keys.
    :raises OSError: on an unsupported OS.
    """
    if root_dir is None:
        root_dir = ROOT_DIR
    if conf_dir is None:
        conf_dir = CONF_DIR
    conf_dir  = os.path.join(root_dir, conf_dir)
    ca_dir    = os.path.join(conf_dir, CA_DIR)
    backup_dir = os.path.join(ca_dir, BACKUP_DIR)
    tmp_dir   = os.path.join(ca_dir, TMP_DIR)
    db_dir    = os.path.join(ca_dir, DB_DIR)
    ca_master = os.path.join(tmp_dir, MASTER)
    ca_shadow = os.path.join(tmp_dir, SHADOW)
    ca_passwd = os.path.join(tmp_dir, PASSWD)
    ca_group  = os.path.join(tmp_dir, GROUP)
    ca_group_local = sorted(glob(os.path.join(db_dir, GROUP_LOCAL)))
    ca_group_local.append(URL_GROUP)
    ca_master_local = sorted(glob(os.path.join(db_dir, MASTER_LOCAL)))
    ca_master_local.append(URL_PASSWD)
    # keys.local is optional, like the group.* / master.passwd.* overrides.
    ca_keys_local = os.path.join(db_dir, KEYS_LOCAL)
    key_sources = [ca_keys_local] if os.path.exists(ca_keys_local) else []
    key_sources.append(URL_KEYS)

    if OSNAME == 'FreeBSD':
        pwd_shadow_char = '*'
        shell_replacement_map = {
                '/bin/bash': '/usr/local/bin/bash',
            }
        files = [ca_group, ca_master]
    elif OSNAME == 'Linux':
        pwd_shadow_char = 'x'
        shell_replacement_map = {
                '/usr/local/bin/bash': '/bin/bash',
            }
        files = [ca_group, ca_shadow, ca_passwd]
    else:
        raise OSError("Unsupported OS: {0}".format(OSNAME))

    # Snapshot the currently installed name->id maps so ownership can be
    # remapped after the new databases are in place.
    old_user_map = old_group_map = None
    if chown_dirs:
        old_state = CAuth()
        old_state.load_group([os.path.join(conf_dir, GROUP)])
        old_state.load_passwd([os.path.join(conf_dir, PASSWD)])
        old_user_map = old_state._name_id_map(PASSWD)
        old_group_map = old_state._name_id_map(GROUP)

    ca = CAuth(
            pwd_shadow_char=pwd_shadow_char,
            shell_replacement_map=shell_replacement_map,
            root_dir=root_dir,
            conf_dir=conf_dir
        )

    # Load
    ca.load_group(ca_group_local)
    ca.load_passwd(ca_master_local)
    if ssh_keys:
        ca.load_keys(key_sources)

    # Generate
    ca.generate(files)

    # Check
    ca.check(files)

    if dry_run:
        return

    # Install
    ca.install(files, conf_dir, backup_dir)
    if ssh_keys:
        ca.ssh_keys.save(ca.passwd)

    if chown_dirs:
        new_user_map = ca._name_id_map(PASSWD)
        new_group_map = ca._name_id_map(GROUP)

        uid_map = CAuth._join_maps_by_name(old_user_map, new_user_map)
        gid_map = CAuth._join_maps_by_name(old_group_map, new_group_map)

        logging.debug("uid_map: {0}".format(uid_map))
        logging.debug("gid_map: {0}".format(gid_map))

        for chown_dir in chown_dirs:
            chown(chown_dir, uid_map, gid_map)

if __name__ == '__main__':
    # Command-line entry point: parse options, configure logging, run main().
    from optparse import OptionParser

    cli = OptionParser()
    cli.add_option("-c", '--chown', type="string", action="append",
            dest="chown_dirs",
            help="Top dir for chowning files, can be specified more than once")
    cli.add_option("-n", '--dry-run', action="store_true", dest="dry_run",
            help="Don't take any real action")
    cli.add_option("-r", '--root-dir', type="string", dest="root_dir",
            default=ROOT_DIR, help="Root directory (default: %default)")
    cli.add_option("-C", '--conf-dir', type="string", dest="conf_dir",
            default=CONF_DIR, help="Config directory (default: %default)")
    cli.add_option("-s", '--ssh-keys', action="store_true", dest="ssh_keys",
            help="Install ssh keys (experimental)")
    cli.add_option("-v", '--verbose', action="store_true", dest="verbose",
            help="Raise log level to DEBUG")
    opts, _positional = cli.parse_args()

    logging.basicConfig(
            level=logging.DEBUG if opts.verbose else logging.WARNING,
            format=FORMAT)
    main(chown_dirs=opts.chown_dirs, dry_run=opts.dry_run,
            root_dir=opts.root_dir, conf_dir=opts.conf_dir,
            ssh_keys=opts.ssh_keys)

