#!/usr/bin/env python

import os
import sys
import argparse

import msgpack

# Two directory levels above this script's absolute path (i.e. the parent of the
# directory containing this file).
SANDBOX_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # noqa
# Put "/skynet", the parent of SANDBOX_DIR and SANDBOX_DIR itself ahead of the
# existing search path so the `sandbox.*` imports below resolve from here.
sys.path = ["/skynet", os.path.dirname(SANDBOX_DIR), SANDBOX_DIR] + sys.path  # noqa

from sandbox import common
import sandbox.serviceq.types as qtypes
import sandbox.serviceq.state as qstate
from sandbox.yasandbox import controller


def load_snapshot(snapshot):
    """Load a Service Q persistent state from a msgpack snapshot file.

    The file is expected to contain exactly two msgpack objects. The first is
    ignored here (presumably a header/version marker -- TODO confirm). The second
    is decoded with ``qstate.PersistentState.decode``.

    :param snapshot: path to the snapshot file
    :return: the decoded ``qstate.PersistentState``
    :raises ValueError: if the file does not contain exactly two msgpack objects
    """
    # msgpack is a binary format: open in binary mode so no newline translation
    # or text decoding can alter the payload.
    with open(snapshot, "rb") as f:
        raw_data = f.read()
    unpacker = msgpack.Unpacker()
    unpacker.feed(raw_data)
    objects = list(unpacker)
    if len(objects) != 2:
        raise ValueError(
            "Unexpected snapshot format in {!r}: expected 2 msgpack objects, found {}".format(
                snapshot, len(objects)
            )
        )
    _, data = objects
    return qstate.PersistentState.decode(data)


def get_fields(sem):
    """Collect a semaphore's settings into a plain dict.

    :param sem: any object exposing ``owner``, ``capacity``, ``auto``,
        ``shared`` and ``public`` attributes
    :return: dict mapping each of those attribute names to its value
    """
    names = ("owner", "capacity", "auto", "shared", "public")
    return {name: getattr(sem, name) for name in names}


def main():
    """Restore non-auto semaphores from a Service Q state snapshot.

    Reads the snapshot path from the command line. For every semaphore in the
    snapshot whose ``auto`` flag is false:

    * if Service Q has no semaphore with the same name, it is created;
    * otherwise, if the owner/capacity/auto/shared/public fields differ, the
      existing semaphore is updated to the snapshot values.

    Exits with an error message (via ``sys.exit``) if the Service Q client does
    not report PRIMARY status.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Script for restoring semaphores from snapshot"
    )
    parser.add_argument("snapshot", help="Path to snapshot of the Service Q state")
    args = parser.parse_args()

    qclient = controller.TaskQueue.qclient
    # Explicit check rather than ``assert``: assertions are stripped under
    # ``python -O``, and this guard must run before any create/update below.
    status = qclient.status(secondary=False)
    if status != qtypes.Status.PRIMARY:
        sys.exit(
            "Service Q status is {!r}, expected PRIMARY; refusing to restore semaphores.".format(status)
        )

    cz = common.console.AnsiColorizer()

    print(cz.white("Loading snapshot..."))
    state = load_snapshot(args.snapshot)
    # Select with the same predicate the restore loop relies on, so the reported
    # count always matches what is actually processed.
    persistent = [snap_sem for snap_sem in state.semaphores.itervalues() if not snap_sem.auto]
    print(cz.white("Found {} persistent semaphore(s).".format(len(persistent))))

    print(cz.white("Loading semaphores from Q..."))
    live_by_sid = {}  # sid -> semaphore currently stored in Service Q
    sid_by_name = {}  # semaphore name -> sid in Service Q
    for sid, live_sem in qclient.semaphores(secondary=False):
        live_by_sid[sid] = live_sem
        sid_by_name[live_sem.name] = sid
    print(cz.white("Found {} semaphore(s).".format(len(live_by_sid))))

    print(cz.white("Comparing persistent semaphores..."))
    for snap_sem in persistent:
        fields = get_fields(snap_sem)
        sid = sid_by_name.get(snap_sem.name)
        if sid is None:
            new_sid, new_sem = qclient.create_semaphore(dict(fields, name=snap_sem.name))
            print(cz.green("Created new semaphore #{}: {}".format(new_sid, new_sem)))
            continue
        live_sem = live_by_sid[sid]
        live_fields = get_fields(live_sem)
        if live_fields != fields:
            qclient.update_semaphore(sid, fields)
            print(cz.yellow("Updated semaphore #{} ({}): {} -> {}".format(
                sid, live_sem.name, live_fields, fields
            )))
    print(cz.white("Done."))


if __name__ == "__main__":
    main()
