#!/skynet/python/bin/python

from __future__ import print_function

import re
import os
import sys
import time
import errno
import socket
import logging
import argparse
import platform
import collections
import threading as th
import subprocess as sp
import distutils.version


from sandbox.common import format as common_format
from sandbox.common import console as common_console
from sandbox.common import context as common_context
from sandbox.common.joint import errors as joint_errors

from sandbox.agentr import types as atypes
from sandbox.agentr import client as aclient


# Root of the active Samogon installation: "/samogon/0" when that directory exists, otherwise "/samogon/1".
if os.path.isdir("/samogon/0"):
    SAMOGON_ROOT = "/samogon/0"
else:
    SAMOGON_ROOT = "/samogon/1"
# AgentR's internal directory under the Samogon root.
BASE = SAMOGON_ROOT + "/active/user/agentr/internal"


def mdstat(path="/proc/mdstat", exclude=("md1 ", "md2 ", "md3 ")):
    """Return the stripped md array lines of the kernel md status file.

    Only lines starting with "md" are kept. Lines starting with any prefix in *exclude* are dropped.
    By default these are the arrays md1, md2 and md3. NOTE(review): presumably the host's system
    arrays, which are never bucket members; confirm.

    :param path: md status file to read (``/proc/mdstat`` by default).
    :param exclude: iterable of line prefixes to skip. Keep the trailing space so that e.g. "md1 "
        does not also exclude "md10".
    :returns: list of matching lines with surrounding whitespace removed.
    """
    # str.startswith() accepts a tuple; an empty tuple never matches, so nothing is excluded.
    exclude = tuple(exclude)
    with open(path, "r") as fh:
        return [
            line.strip()
            for line in fh
            if line.startswith("md") and not line.startswith(exclude)
        ]


def fstab():
    """Parse storage bucket entries from /etc/fstab.

    A bucket entry is a ``# Storage bucket #N on DEV...`` comment line followed by a
    ``UUID=<uuid> /storage/N ...`` mount line. For every mount line, the bucket's member devices
    are re-read from ``btrfs fi sh /storage/N``. If that command fails, ``devs`` is ``None``; other
    functions in this file report such a bucket as not mounted.

    :returns: list of ``Line(no, lineno, uuid, devs)`` namedtuples sorted by bucket number.
        ``no`` is the bucket number as int. ``lineno`` is the 0-based index of the line just before
        the mount line (the comment line); this is the value mkbucket() takes as ``fstab_line``.
        ``uuid`` is taken from fstab. ``devs`` holds device names without the ``/dev/`` prefix,
        or ``None``.
    :raises AssertionError: if a mount line's bucket number differs from the preceding comment's,
        or if the UUID reported by btrfs differs from the one in fstab.
    """
    comre = re.compile(r"^# Storage bucket #(\d+) on (.+?)\s*$")
    devre = re.compile(r"^UUID=([^\s]+)\s+/storage/(\d+)")
    line_t = collections.namedtuple("Line", "no lineno uuid devs")
    res = []
    with open("/etc/fstab", "r") as fh:
        lineno = 0  # 0-based index of the current fstab line
        # Bucket number (as a string) and devices announced by the most recent bucket comment.
        bno, devs = None, []
        for l in fh:
            m = comre.match(l)
            if m:
                bno, devs = m.group(1), [_.strip().replace("/dev/", "") for _ in m.group(2).split()]
                m = None  # a comment line is never treated as a mount line
            else:
                m = devre.search(l)
            if m:
                # The mount line must belong to the bucket announced by the preceding comment.
                assert bno == m.group(2), "Comment and mount point number mismatch: '%s' vs '%s'" % (bno, m.group(2))
                try:
                    # Output parsing: the first line carries "uuid: <UUID>"; lines starting with
                    # "devid" carry " path /dev/<name>". NOTE(review): relies on Python 2 map()
                    # returning a list (ll[0] / ll[2:] indexing).
                    ll = map(str.strip, sp.check_output(["btrfs", "fi", "sh", "/storage/" + bno]).split("\n"))
                    uuid = ll[0].split("uuid: ", 2)[1]
                    assert uuid == m.group(1), "UUID mismatch: '%s' vs '%s'" % (uuid, m.group(1))
                    # Replace the devices from the comment with the ones btrfs actually reports.
                    devs = [_.split(" path /dev/", 2)[1] for _ in ll[2:] if _.startswith("devid")]
                except sp.CalledProcessError as ex:
                    # The btrfs query failed; the devices are reported as unknown (None).
                    print("Error checking bucket #%r: %s" % (bno, ex))
                    devs = None
                # lineno - 1: index of the line preceding the mount line, i.e. the bucket comment.
                res.append(line_t(int(bno), lineno - 1, m.group(1), devs))
                print("Bucket found: %r" % (res[-1],), file=sys.stderr)
            lineno += 1
    return sorted(res, key=lambda _: _.no)


def _is_md_member(dev, md_lines):
    """Return True if block device *dev*, or a partition of it, is a member listed in *md_lines*.

    Matches whole member tokens such as ``sdb1[0]`` / ``nvme0n1p2[1]`` (assumes the usual
    /proc/mdstat member format). A plain substring test would wrongly match e.g. ``sda``
    inside ``sdaa1[0]``.
    """
    member = re.compile(r"(?:^|\s)" + re.escape(dev) + r"(?:p?\d+)?\[")
    return any(member.search(line) for line in md_lines)


def blkls():
    """Collect information about disks sitting in enclosure (shelf) slots.

    :returns: sorted list of ``(shelf_dir, slot, block_device, serial, bucket_no, raid_member)``
        tuples. ``serial`` comes from the ``/dev/disk/by-id/{ata,scsi}-*`` link name and is
        ``"none"`` when there is no such link. ``bucket_no`` is the fstab() bucket number the device
        belongs to, or ``"none"``. ``raid_member`` is ``"RaidMember"`` when the device or one of its
        partitions appears in mdstat(), otherwise ``"none"``. An empty list is returned when
        /sys/class/enclosure/ cannot be listed.
    """
    # Device name (e.g. "sdb") -> fstab() bucket entry containing it.
    dev2bucket = {}
    for bucket in fstab():
        for dev in bucket.devs or []:
            dev2bucket[dev] = bucket

    # Block device name -> serial number (the last "_"-separated part of its by-id link name).
    sn_disks = {}
    devblock_path = "/dev/disk/by-id/"
    for link in os.listdir(devblock_path):
        if link.startswith(("ata-", "scsi-")) and "-part" not in link:
            disk = os.readlink(devblock_path + link).split("/")[-1]
            sn_disks[disk] = link.split("_")[-1]

    # Block devices are mapped to slots via their shelves' sysfs entries.
    encl_path = "/sys/class/enclosure/"
    hw_shelves = []
    try:
        shelf_dirs = os.listdir(encl_path)
    except OSError:
        return hw_shelves

    # Read the md status once rather than once per slot. An unreadable status file (e.g. no md
    # driver loaded) means "no RAID members"; previously it silently skipped every slot.
    try:
        md_lines = mdstat()
    except (IOError, OSError):
        md_lines = []

    for shelfdir in shelf_dirs:
        for slot in os.listdir(encl_path + shelfdir):
            if not (("Disk" in slot) or ("ArrayDevice" in slot) or ("Connector" in slot) or slot.isdigit()):
                continue
            try:
                block_device = os.listdir(encl_path + shelfdir + "/" + slot + "/device/block/")[0]
            except (IOError, OSError, IndexError):
                continue  # no block device visible for this slot (presumably an empty slot)
            raid_member = "RaidMember" if _is_md_member(block_device, md_lines) else "none"
            bno = dev2bucket[block_device].no if block_device in dev2bucket else "none"
            serial = sn_disks.get(block_device, "none")  # disks without an ata-/scsi- by-id link
            hw_shelves.append((shelfdir, slot, block_device, serial, bno, raid_member))
    return sorted(hw_shelves)


def _fstab_with_bucket(content, fstab_line, no, devs, uuid, path, opts):
    """Return fstab *content* (list of lines) with bucket *no*'s entry written at *fstab_line*.

    The entry consists of a blank separator, a ``# Storage bucket`` comment and the ``UUID=...``
    mount line. *fstab_line* is the index of the existing comment line (as stored by fstab()).
    The separator at ``fstab_line - 1`` and the comment/mount lines at ``fstab_line`` and
    ``fstab_line + 1`` are replaced. An index past the end appends. A non-blank line in the
    separator position is kept rather than dropped, so an unrelated mount line is never deleted.
    """
    sep = fstab_line - 1
    head = content[:max(sep, 0)]  # max(): a comment at index 0 must not turn into content[:-1]
    if 0 <= sep < len(content) and content[sep].strip():
        head.append(content[sep])
    entry = [
        "\n# Storage bucket #{} on {}\n".format(no, " ".join(devs)),
        "UUID={}\t{}\tbtrfs\t{}\t0\t2\n".format(uuid, path, opts),
    ]
    return head + entry + content[fstab_line + 2:]


def mkbucket(no, devs, fake=True, fstab_line=None, uuid=None, no_space_cache=False, mount=True):
    """Create (or re-register) BTRFS storage bucket *no* at ``/storage/<no>`` and record it in /etc/fstab.

    :param no: bucket number; also used in the filesystem label ``bucket<no>``.
    :param devs: device paths (e.g. ``/dev/sdb``) for a ``-d raid0`` BTRFS filesystem.
    :param fake: only print the ``mkfs.btrfs`` command that would run, then return without changes.
    :param fstab_line: ``lineno`` of the bucket's fstab() entry, whose entry is replaced in place;
        ``None`` appends a new entry.
    :param uuid: UUID of an existing filesystem. When given, nothing is wiped or created and only
        the fstab entry is (re)written.
    :param no_space_cache: write ``clear_cache,nospace_cache`` mount options instead of ``space_cache=v2``.
    :param mount: chown and mount the bucket after fstab is updated.
    :raises subprocess.CalledProcessError: if any external command fails.
    """
    path = "/storage/" + str(no)
    if not uuid:
        print("Creating %r" % (path,), file=sys.stderr)
    opts = "rw,noatime,commit=300,compress=lzo,autodefrag," + (
        "space_cache=v2" if not no_space_cache else "clear_cache,nospace_cache"
    )
    cmd = ["mkfs.btrfs", "-f", "-d", "raid0", "-L", "bucket" + str(no)] + list(devs)
    if fake:
        print("Will execute command %r" % " ".join(cmd), file=sys.stderr)
        return

    if not uuid:
        try:
            os.mkdir(path)
        except OSError as ex:
            if ex.errno != errno.EEXIST:
                raise

        for d in devs:
            # Zero the first sector of every device before building the filesystem.
            sp.check_call(["dd", "if=/dev/zero", "of=" + d, "bs=512", "count=1"])
            # Names ending in a digit are treated as partitions: no signature wipe or partition re-read.
            # NOTE(review): this also matches whole NVMe disks such as /dev/nvme0n1.
            if d[-1].isdigit():
                continue
            sp.check_call(["wipefs", "--all", d])
            sp.check_call(["partprobe", d])
        sp.check_call(cmd)
        ln = sp.check_output(["btrfs", "fi", "sh", devs[0]]).split("\n", 1)[0]
        uuid = ln.split("uuid: ", 2)[1].strip()

    with open("/etc/fstab", "r") as fh:
        fstab_content = fh.readlines()
    if fstab_line is None:
        fstab_line = len(fstab_content) + 1  # past the end: append
    new_content = _fstab_with_bucket(fstab_content, fstab_line, no, devs, uuid, path, opts)
    with open("/etc/fstab", "w") as fh:
        fh.writelines(new_content)
    if mount:
        # chown the mount point, then the root of the freshly mounted filesystem.
        sp.check_call(["chown", "zomb-sandbox.sandbox", path])
        sp.check_call(["mount", path])
        sp.check_call(["chown", "zomb-sandbox.sandbox", path])


def make(args):
    """Create BTRFS buckets from the disks found in enclosure shelves (see blkls()).

    Disks are grouped by shelf, and shelves are paired in sorted order. Each bucket is a RAID0 of
    the i-th disk of both shelves of a pair. When the two shelves differ in length, the last bucket
    also gets the longer shelf's spare disk. With fewer than two shelves, the disks of the only
    shelf plus ``args.extra`` are dealt alternately into the real shelf and a "virtual" one.

    Nothing is changed unless ``args.commit`` is set; otherwise mkbucket() only prints the commands.

    :param args: parsed CLI namespace with ``info_only``, ``commit`` and ``extra`` attributes.
    :returns: ``None`` when only information was requested, ``0`` otherwise.
    """
    shelves = {}
    info = blkls()
    hdr = ("pci_addr", "disk_name", "dev", "serial", "bucket", "raid")
    print(hdr, file=sys.stderr)
    for _ in info:
        print(_, file=sys.stderr)
        # Group block device names (3rd field) by shelf directory (1st field).
        shelves.setdefault(_[0], []).append(_[2])
    print("Shelves: %r, disks: %r" % (sorted(shelves), len(info)), file=sys.stderr)
    if args.info_only:
        return

    # Every shelf must hold the same number of disks; `l` becomes that per-shelf count.
    l = 0
    for shid, disks in shelves.iteritems():
        if not l:
            l = len(disks)
        assert l == len(disks), (shid, len(disks))

    shids = sorted(shelves.keys())
    if len(shids) < 2:
        print("Running in a single-shelf mode - adding a virtual shelf", file=sys.stderr)
        if shids:
            ak, bk = shids[0], "virtual"
            disks = shelves[ak] + list(args.extra or [])
        else:
            print("WARNING!!! NO SHELVES DETECTED AT ALL!!!", file=sys.stderr)
            ak, bk = "virtual1", "virtual2"
            disks = list(args.extra or [])
        # Deal the disks alternately into the two shelves (ak/bk are swapped after each disk).
        shelves[ak], shelves[bk] = [], []
        for dev in disks:
            shelves[ak].append(dev)
            ak, bk = bk, ak
        shids = sorted(shelves.keys())
        # Pair up to the shorter shelf's length; the longer one holds at most one extra disk.
        l = min(map(len, map(shelves.get, (ak, bk))))

    disks = sum(map(len, shelves.itervalues()))
    assert not len(shids) % 2, "Amount of shelves should be an even number"
    if disks % 2:
        print("WARNING: Amount of disks (%d) not an even number!" % (disks,), file=sys.stderr)
    else:
        print("Making buckets of %d disks totally" % (disks,), file=sys.stderr)
    buckets = l * len(shids) / 2  # only used in the message below (Python 2 integer division)
    print("Will create %d buckets (%d shelves of %d disks)" % (buckets, len(shids), l), file=sys.stderr)
    # Shelves are paired in sorted order: (shids[0], shids[1]), (shids[2], shids[3]), ...
    for p in xrange(len(shids) / 2):
        sh1, sh2 = shelves[shids[p * 2]], shelves[shids[p * 2 + 1]]
        if len(sh1) < len(sh2):
            sh1, sh2 = sh2, sh1
        # Bucket number p * l + i combines the i-th disk of each shelf in the pair.
        for i in xrange(l):
            devs = ["/dev/" + sh1[i], "/dev/" + sh2[i]]
            # Uneven pair (only possible after single-shelf dealing, since the assert above forces
            # equal shelves otherwise): the longer shelf's leftover disk joins the last bucket.
            if i == l - 1 and (len(sh1) != len(sh2)):
                devs.append("/dev/" + sh1[i + 1])
            # fake=True unless --commit: mkbucket() then only prints the mkfs command.
            mkbucket(p * l + i, devs, not args.commit)
    if args.commit:
        sp.check_call(["update-initramfs", "-u"])
    return 0


def __setup_config():
    """Point ``SANDBOX_CONFIG`` at AgentR's ``*.cfg`` file unless it is already set.

    The first existing directory among ``/samogon/0`` .. ``/samogon/3`` is used. The config is
    looked up in its ``active/user/agentr/internal`` subdirectory. If several ``.cfg`` files are
    present, the lexicographically last one is chosen, so the choice is deterministic.

    :raises OSError: if no samogon directory exists or no ``.cfg`` file is found.
    """
    if "SANDBOX_CONFIG" in os.environ:
        return

    for samogon_key in range(4):
        samogon_dir = "/samogon/{}".format(samogon_key)
        if os.path.exists(samogon_dir):
            break
    else:
        # The loop always assigns samogon_dir, so an "is None" test could never detect this case.
        raise OSError("Samogon directory not found")
    config_dir = samogon_dir + "/active/user/agentr/internal"

    configs = sorted(f for f in os.listdir(config_dir) if f.endswith(".cfg"))
    if not configs:
        # Previously None was stored in os.environ, which fails with an obscure TypeError.
        raise OSError("No '*.cfg' file found in {!r}".format(config_dir))
    os.environ["SANDBOX_CONFIG"] = os.path.join(config_dir, configs[-1])


def __confirm(msg="Continue"):
    """Ask the operator a yes/no question on stdin; only "y" or "Y" counts as consent.

    End of input (EOFError) is treated as "no".
    """
    prompt = msg + " (y/N)? "
    try:
        answer = raw_input(prompt)
    except EOFError:
        return False
    return answer.lower() == "y"


def __check_kernel_version(cz, required="4.9"):
    """Return True if the running kernel is at least *required*; otherwise print an error and return False.

    Only the leading dotted-numeric part of each version string is compared, e.g.
    "4.19.0-21-amd64" -> (4, 19, 0). A release string with no leading number is treated as too old.
    This replaces distutils.version.LooseVersion: distutils is deprecated (PEP 632) and removed
    in Python 3.12.

    :param cz: colorizer providing ``red()`` for the error message.
    :param required: minimal acceptable kernel version.
    """
    def numeric_prefix(version):
        match = re.match(r"\d+(?:\.\d+)*", version)
        return tuple(int(part) for part in match.group().split(".")) if match else ()

    cur = platform.uname()[2]
    if numeric_prefix(cur) < numeric_prefix(required):
        print(cz.red("Current kernel version %r is less than minumal required %r" % (cur, required)), file=sys.stderr)
        return False
    return True


def erase(args):
    """Erase AgentR's cache records of bucket ``args.no`` after interactive confirmation.

    With ``args.unmount`` the bucket's mount point ``/storage/<no>`` is unmounted afterwards.
    Returns 0 on success and 1 if the operator declined.
    """
    path = "/storage/" + str(args.no)
    cz = common_console.AnsiColorizer()
    service = aclient.Service(logging.getLogger("agentr"))
    cache_info = service.bucket_cache_info(args.no)
    summary = "About to reset bucket at %r with %d resources totally for %s" % (
        path, cache_info.amount, common_format.size2str(cache_info.size)
    )
    print(cz.yellow(summary), file=sys.stderr)
    if __confirm():
        with common_console.LongOperation("Erasing"):
            service.erase_bucket(args.no)
        if args.unmount:
            with common_console.LongOperation("Unmounting"):
                sp.check_call(["umount", path])
        return 0
    return 1


def remake(args):
    """Destroy bucket ``args.no`` and create a fresh BTRFS filesystem in its place.

    Returns 2 and changes nothing in any of these cases:
    - AgentR still has resources registered on the bucket;
    - fstab() has no entry for that number at the expected position;
    - devices given in ``args.dev`` differ from the bucket's detected devices;
    - the bucket is not mounted and no devices were given.

    Returns 1 if the kernel is too old or the operator declines. On success (returns 0) the
    bucket is unmounted (if mounted) and recreated in place of its fstab entry. Then initramfs is
    updated, the hw-watcher disk status is reset and the bucket is unbanned.
    """
    path = "/storage/" + str(args.no)
    cz = common_console.AnsiColorizer()
    c = aclient.Service(logging.getLogger("agentr"))
    ci = c.bucket_cache_info(args.no)
    if ci.amount:
        print(cz.red("There are %d resources totally for %s registered at bucket %r. Aborting." % (
            ci.amount, common_format.size2str(ci.size), path
        )), file=sys.stderr)
        return 2
    print(cz.yellow("About to re-make bucket at %r" % (path,)), file=sys.stderr)
    if not __check_kernel_version(cz) or not __confirm():
        return 1

    buckets = fstab()
    # fstab() is sorted by bucket number, so without gaps bucket N sits at index N. Guard the index
    # so that an unknown number is reported below instead of crashing with IndexError.
    bucket = buckets[args.no] if 0 <= args.no < len(buckets) else None
    if bucket is None or bucket.no != args.no:
        print(
            cz.red("Not all buckets are mounted correctly. Please check fstab and active mounts."),
            file=sys.stderr
        )
        return 2
    if args.dev and bucket.devs and set(_.replace("/dev/", "") for _ in args.dev) != set(bucket.devs):
        print(
            cz.red("Bucket devices provided %r while detected %r." % (args.dev, bucket.devs)),
            file=sys.stderr
        )
        return 2

    # Known devices mean `btrfs fi sh` succeeded on the mount point, so unmount it before re-creating.
    if bucket.devs:
        with common_console.LongOperation("Unmounting"):
            sp.check_call(["umount", path])

    devs = args.dev or bucket.devs
    if not devs:
        print(cz.red("The bucket %r is not mounted and no devices provided. Aborting." % (path,)), file=sys.stderr)
        return 2
    mkbucket(args.no, ["/dev/" + _.replace("/dev/", "") for _ in devs], False, bucket.lineno)
    with common_console.LongOperation("Updating initramfs"):
        sp.check_call(["update-initramfs", "-u"])
        sp.call(["sudo", "-u", "hw-watcher", "hw_watcher", "disk", "reset_status"])
    c.unban(args.no)
    return 0


def ban(args):
    """Ban every bucket number in ``args.no`` through AgentR and report each one's free space."""
    cz = common_console.AnsiColorizer()
    with common_console.LongOperation("Connecting"):
        service = aclient.Service(logging.getLogger("agentr"))
    template = "Bucket {} banned. Space left {} of {} ({}%)"
    for bucket_no in args.no:
        space = service.ban(bucket_no)
        message = template.format(
            cz.white("#" + str(bucket_no)),
            common_format.size2str(space.free),
            common_format.size2str(space.total),
            space.free * 100 / space.total,
        )
        print(message, file=sys.stderr)


def unban(args):
    """Unban every bucket number in ``args.no`` through AgentR and report each one's free space."""
    cz = common_console.AnsiColorizer()
    with common_console.LongOperation("Connecting"):
        service = aclient.Service(logging.getLogger("agentr"))
    template = "Bucket {} unbanned. Space left {} of {} ({}%)"
    for bucket_no in args.no:
        space = service.unban(bucket_no)
        message = template.format(
            cz.white("#" + str(bucket_no)),
            common_format.size2str(space.free),
            common_format.size2str(space.total),
            space.free * 100 / space.total,
        )
        print(message, file=sys.stderr)


def banned(args):
    """Print every banned bucket with its free/total space to stdout, ordered by bucket number."""
    service = aclient.Service(logging.getLogger("agentr"))
    space_by_bucket = service.banned()
    for bucket_no in sorted(space_by_bucket):
        space = space_by_bucket[bucket_no]
        free_str = common_format.size2str(space.free)
        total_str = common_format.size2str(space.total)
        print("#{}: {} of {} ({}%)".format(bucket_no, free_str, total_str, space.free * 100 / space.total))


def __mount(cz, lock, path, mount=True):
    """Mount (or, with ``mount=False``, unmount) *path*, retrying once per second until done.

    The operation counts as done on a zero exit status. It also counts as done when the command
    reports that the desired state already holds ("already mounted" / "not mounted"). The command
    runs with ``LC_ALL=C`` so those messages are not localized. *lock* is held while spawning the
    command and while printing. After every further 60 seconds of retrying, a progress warning is
    printed (red after more than 3 minutes). There is no overall timeout.
    """
    action = "Mounting" if mount else "Unmounting"
    done_marker = "already mounted" if mount else "not mounted"
    env = dict(os.environ, LC_ALL="C")  # the marker check above needs untranslated messages
    started = checked = time.time()
    while True:
        with lock:
            p = sp.Popen(["mount" if mount else "umount", path], stdout=sp.PIPE, stderr=sp.STDOUT, env=env)
        out, _ = p.communicate()
        # Stop on success right away, instead of retrying once more just to read the marker message.
        if p.returncode == 0 or done_marker in out:
            break
        time.sleep(1)
        now = time.time()
        if now - checked > 60:
            minutes = (now - started) / 60
            with lock:
                print((cz.red if minutes > 3 else cz.yellow)("%s took more than %d minutes for %s ...") % (
                    action, minutes, cz.white(path)
                ), file=sys.stderr)
            checked = now


def __remounter(lock, bid, bucket):
    """Worker body: switch bucket *bid* (an fstab() entry) to BTRFS space cache v2.

    Steps:
    1. Unmount the bucket.
    2. Rewrite its fstab entry with "clear_cache,nospace_cache" options.
    3. Mount and unmount it once with those options ("Resetting space cache").
    4. Rewrite the entry with "space_cache=v2" options and mount it again.

    *lock* is shared between workers and serialises console output and fstab rewrites. It is also
    taken inside __mount(). A bucket whose ``devs`` is empty is skipped with an error message.
    """
    path = "/storage/" + str(bid)
    cz = common_console.AnsiColorizer()
    if not bucket.devs:
        print(cz.red("Bucket #%d not mounted. Mount it first and re-run the command." % (bid,)), file=sys.stderr)
        return

    devs = ["/dev/" + name for name in bucket.devs]

    def say(text):
        # Print under the shared lock so messages from parallel workers do not interleave.
        with lock:
            print(text, file=sys.stderr)

    def rewrite_fstab(**options):
        # Rewrite only this bucket's fstab entry (existing uuid: no mkfs, no mount); fstab is shared.
        with lock:
            mkbucket(bid, devs, False, bucket.lineno, bucket.uuid, mount=False, **options)

    say(cz.blue("Recreating space cache for bucket ") + cz.white("#%d" % (bid,)))
    with common_context.Timer() as t:
        say(cz.black("Unmounting %s ...") % (cz.white(path),))
        __mount(cz, lock, path, False)
        rewrite_fstab(no_space_cache=True)
        say(cz.black("Resetting space cache 1/2 for %s ...") % (cz.white(path),))
        __mount(cz, lock, path)
        say(cz.black("Resetting space cache 2/2 for %s ...") % (cz.white(path),))
        __mount(cz, lock, path, False)

        say(cz.black("Creating space cache v2 for %s ...") % (cz.white(path),))
        rewrite_fstab()
        __mount(cz, lock, path)
        say(
            cz.green("Space cache v2 enabled for bucket ") + cz.white("#%d" % (bid,)) +
            cz.green(" in ") + cz.white(common_format.td2str(t.secs))
        )


def space_cache_v2(args):
    """Convert buckets to BTRFS space cache v2, with one worker thread per bucket.

    Works on the bucket numbers in ``args.no`` or, when none are given, on every bucket found by
    fstab(). Confirmation is requested when AgentR has active jobs, when some buckets are not
    banned, and once more before starting. After the workers finish, the buckets are unbanned,
    initramfs is updated and the hw-watcher disk status is reset.

    :returns: 0 on success, 1 if the operator declined, 2 if the kernel is too old or a requested
        bucket number has no fstab entry.
    """
    buckets = fstab()
    cz = common_console.AnsiColorizer()
    if not __check_kernel_version(cz):
        return 2

    # Look buckets up by number. Indexing the list by position would hand a worker another bucket's
    # devices, UUID and fstab line whenever the numbering has gaps (and IndexError past the end).
    by_no = {bucket.no: bucket for bucket in buckets}
    bids = args.no or sorted(by_no)
    unknown = [bid for bid in bids if bid not in by_no]
    if unknown:
        print(cz.red("Buckets not found in fstab: %r" % (sorted(unknown),)), file=sys.stderr)
        return 2

    c = aclient.Service(logging.getLogger("agentr"))
    jobs = c.jobs()
    if jobs:
        print(cz.red(
            "There are active jobs registered in AgentR: %r" % (sorted(jobs),)
        ), file=sys.stderr)
        if not __confirm():
            return 1

    banned = c.banned()
    not_banned = [bid for bid in bids if bid not in banned]
    if not_banned:
        print(cz.red(
            "Following buckets are not banned yet: %r" % (sorted(not_banned),)
        ), file=sys.stderr)
        if not __confirm():
            return 1

    print(cz.yellow("About to enable space cache v2 on buckets %r" % (bids,)), file=sys.stderr)
    if not __confirm():
        return 1

    lock = th.Lock()
    workers = [th.Thread(target=__remounter, args=(lock, bid, by_no[bid])) for bid in bids]

    with common_console.LongOperation("Waiting for workers") as op:
        op.intermediate("")
        # Explicit loops: map() is lazy on Python 3 and would silently do nothing.
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

    with common_console.LongOperation("Releasing buckets"):
        for bid in bids:
            c.unban(bid)

    with common_console.LongOperation("Updating initramfs"):
        sp.check_call(["update-initramfs", "-u"])
        sp.call(["sudo", "-u", "hw-watcher", "hw_watcher", "disk", "reset_status"])

    return 0


def reset(args):
    """Reset the AgentR client's state (task sessions) after interactive confirmation.

    Returns 0 on success and 1 if the operator declined.
    """
    cz = common_console.AnsiColorizer()
    print(cz.yellow("About to reset client's state"), file=sys.stderr)
    if __confirm():
        with common_console.LongOperation("Resetting"):
            aclient.Service(logging.getLogger("agentr")).reset()
        return 0
    return 1


def cleanup(args):
    """Run AgentR's self-cleanup and report how much disk space it freed.

    Confirmation is only requested when the free-space percentage is at or above
    ``args.threshold``. ``args.chunk`` is forwarded to the service.
    Returns 0 on success, 1 if the operator declined, and 3 when there was nothing to clean.
    """
    cz = common_console.AnsiColorizer()
    service = aclient.Service(logging.getLogger("agentr"))
    before = service.df()
    before_perc = before.free * 100 / before.total
    print(cz.yellow("About to self-cleanup. Currently %s free of %s (%d%%)" % (
        common_format.size2str(before.free), common_format.size2str(before.total), before_perc
    )), file=sys.stderr)
    needs_confirmation = before_perc >= args.threshold
    if needs_confirmation and not __confirm():
        return 1

    with common_console.LongOperation("Self-cleaning"):
        result = service.cleanup(args.chunk)
    after = service.df()
    after_perc = after.free * 100 / after.total
    after_free = common_format.size2str(after.free)
    after_total = common_format.size2str(after.total)
    if not result.kind:
        print(cz.red("Nothing to clean. Currently %s free of %s (%d%%)" % (
            after_free, after_total, after_perc
        )), file=sys.stderr)
        return 3

    print(cz.green(
        "Cleaning of %s resources finished. Removed %d resources of %d found. "
        "Currently %s free of %s (%d%%), freed up %s (%d%%)." % (
            result.kind, result.removed, result.removable,
            after_free, after_total, after_perc,
            common_format.size2str(after.free - before.free), after_perc - before_perc
        )
    ), file=sys.stderr)
    return 0


def maintain(args):
    """Run AgentR's host maintenance procedure and report how much disk space it freed.

    Confirmation is requested unless ``args.yes``. ``args.dry_run`` and the extra-item limits
    (``extra_local``, ``extra_actual_local``, ``extra_remote``) are forwarded to the service.
    Returns 0 on success and 1 if the operator declined.
    """
    cz = common_console.AnsiColorizer()
    service = aclient.Service(logging.getLogger("agentr"))
    before = service.df()
    before_perc = before.free * 100 / before.total
    print(cz.yellow("About to maintain. Currently %s free of %s (%d%%)" % (
        common_format.size2str(before.free), common_format.size2str(before.total), before_perc
    )), file=sys.stderr)
    if not (args.yes or __confirm()):
        return 1

    limits = atypes.MaintainLimits(args.extra_local, args.extra_actual_local, args.extra_remote)
    with common_console.LongOperation("Self-maintaining"):
        result = service.maintain(args.dry_run, **dict(limits))
    after = service.df()
    after_perc = after.free * 100 / after.total

    summary = (
        "Maintaining finished. Found %d extra local, %d extra actual and %d extra remote resources. "
        "Currently %s free of %s (%d%%), freed up %s (%d%%)." % (
            result.extra_local, result.extra_actual_local, result.extra_remote,
            common_format.size2str(after.free), common_format.size2str(after.total), after_perc,
            common_format.size2str(after.free - before.free), after_perc - before_perc
        )
    )
    print(cz.green(summary), file=sys.stderr)
    return 0


def restore_links(args):
    """Invoke AgentR's "restore_links" call to rebuild resources' symbolic links. Returns 0."""
    service = aclient.Service(logging.getLogger("agentr"))
    operation = common_console.LongOperation("Restoring resources' symbolic links")
    with operation:
        service("restore_links")
    return 0


def ping(args):
    """Ping AgentR with ``args.ping_value`` and check that the same value is returned.

    :raises AssertionError: if the returned value differs. It is raised explicitly rather than by an
        ``assert`` statement, so the check still runs when Python is started with ``-O``.
    """
    client = aclient.Service(logging.getLogger("agentr"))
    result = client.ping(args.ping_value)
    if not result == args.ping_value:
        raise AssertionError("Mismatch of return value of ping: got {}, expected {}".format(
            result, args.ping_value
        ))


def socket_pid(args):
    """Print the PID of the process serving the AgentR socket.

    Prints the peer PID (0 when it is empty/unknown), or -1 on a handshake timeout. Prints nothing
    when the connection fails with a socket error. Uses the client's private ``_srv``/``_sock``
    internals.
    """
    client = aclient.Service(logging.getLogger("agentr"))
    try:
        print(client._srv.connect()._sock.peerid.pid or 0)
    except socket.error:
        return  # cannot connect: print nothing
    except joint_errors.HandshakeTimeout:
        print(-1)


def main():
    """Command-line entry point: set up SANDBOX_CONFIG, parse arguments and run the chosen sub-command.

    :returns: the sub-command handler's return value, used as the process exit status.
    """
    __setup_config()
    parser = argparse.ArgumentParser(
        formatter_class=lambda *args, **kwargs: argparse.ArgumentDefaultsHelpFormatter(*args, width=120, **kwargs),
        description="Storage buckets management tool."
    )
    subparsers = parser.add_subparsers(help="operational mode")
    # Create 'make' sub-parser.
    subparser = subparsers.add_parser("make", help="Make BTRFS buckets")
    subparser.set_defaults(func=make)

    subparser.add_argument(
        "-c", "--commit",
        action="store_true", help="Do make real buckets and/or change 'fstab'"
    )
    subparser.add_argument(
        "-i", "--info_only",
        action="store_true", help="Show storage mounts information only"
    )
    subparser.add_argument(
        "-e", "--extra",
        type=str, nargs="*", help="Extra devices to be used to create buckets (only for single shelf mode)"
    )

    # Create 'erase' sub-parser.
    subparser = subparsers.add_parser("erase", help="Erase given bucket cache information")
    subparser.set_defaults(func=erase)
    subparser.add_argument(
        "-u", "--unmount",
        action="store_true", help="Unmount the bucket after erasing"
    )
    subparser.add_argument(
        "no", metavar="NUMBER", type=int,
        help="Bucket number to be erased",
    )

    # Create 'remake' sub-parser.
    subparser = subparsers.add_parser("remake", help="Destroy a bucket and create a new in place of it")
    subparser.set_defaults(func=remake)
    subparser.add_argument(
        "no", metavar="NUMBER", type=int,
        help="Bucket number to be re-made",
    )
    subparser.add_argument(
        "dev", metavar="DEVICE", type=str, nargs="*",
        help="Devices to be inserted into a bucket (without any check, for unmounted)",
    )

    # Create 'ban' sub-parser.
    subparser = subparsers.add_parser("ban", help="Ban a bucket")
    subparser.set_defaults(func=ban)
    subparser.add_argument(
        "no", metavar="NUMBER", type=int, nargs='*',
        help="Bucket number(s) to be banned",
    )

    # Create 'unban' sub-parser.
    subparser = subparsers.add_parser("unban", help="Unban a bucket")
    subparser.set_defaults(func=unban)
    subparser.add_argument(
        "no", metavar="NUMBER", type=int, nargs='*',
        help="Bucket number(s) to be unbanned",
    )

    # Create 'banned' sub-parser.
    subparser = subparsers.add_parser("banned", help="List banned buckets")
    subparser.set_defaults(func=banned)

    # Create 'space_cache_v2' sub-parser.
    subparser = subparsers.add_parser("space-cache-v2", help="Enable BTRFS space cache v2 on buckets")
    subparser.add_argument(
        "no", metavar="NUMBER", type=int, nargs='*',
        help="Bucket number(s) to operate. All in case of none provided",
    )
    subparser.set_defaults(func=space_cache_v2)

    # Create 'reset' sub-parser.
    subparser = subparsers.add_parser("reset", help="Reset client's state (task sessions)")
    subparser.set_defaults(func=reset)

    # Create 'cleanup' sub-parser.
    subparser = subparsers.add_parser("cleanup", help="Perform host self-cleanup procedure")
    subparser.add_argument("-c", "--chunk", type=int, default=None, help="Chunk size to operate with")
    subparser.add_argument("-t", "--threshold", type=int, default=25, help="Free space threshold in percents")
    subparser.set_defaults(func=cleanup)

    # Create 'maintain' sub-parser.
    subparser = subparsers.add_parser("maintain", help="Perform host maintain procedure")
    subparser.add_argument(
        "-y", "--yes",
        action="store_true", help="Assume 'yes' as answer to all prompts and run non-interactively"
    )
    subparser.add_argument(
        "-d", "--dry-run",
        action="store_true", help="No action; perform a scanning but do not actually remove any files"
    )
    subparser.add_argument(
        "-l", "--extra-local",
        type=int, default=None, help="Extra local cache items without actual data existence on the disk"
    )
    subparser.add_argument(
        "-a", "--extra-actual-local",
        type=int, default=None,
        help="Extra local cache items WITH actual data existence but without records on the server"
    )
    subparser.add_argument(
        "-r", "--extra-remote",
        type=int, default=None, help="Extra server records without actual data on the node"
    )
    subparser.set_defaults(func=maintain)

    # Create 'restore_links' sub-parser.
    subparser = subparsers.add_parser("restore_links", help="Restore symbolic links based on database content")
    subparser.set_defaults(func=restore_links)

    # Ping AgentR
    subparser = subparsers.add_parser("ping", help="Ping AgentR")
    subparser.add_argument("--ping-value", type=int, help="Ping value")
    subparser.set_defaults(func=ping)

    # Get PID of the process listening on the AgentR socket
    subparser = subparsers.add_parser("socket_pid", help="Get PID of the process listening on the AgentR socket")
    subparser.set_defaults(func=socket_pid)

    # Parse the args and call whatever function was selected.
    args = parser.parse_args()
    return args.func(args)


if __name__ == "__main__":
    # The selected handler's return value becomes the exit status (None maps to 0).
    exit_code = main()
    sys.exit(exit_code)
