import nirvana.job_context as nv
import tarfile
import os
import re
import threading
import time


def get_file_list(dir_path, filename_mask):
    """Return paths of the entries in *dir_path* whose names match *filename_mask*.

    *filename_mask* is a regular expression applied with ``re.match``, i.e.
    anchored at the start of the entry name only (add ``$`` to anchor the end).
    Each returned path is ``os.path.join(dir_path, name)``.

    The result is sorted by name: ``os.listdir`` returns entries in arbitrary
    order, and callers compare successive listings to detect changes, so an
    unsorted list could report a change when the directory is unchanged.

    A missing *dir_path*, or one that is not a directory, yields ``[]``.
    """
    if not os.path.isdir(dir_path):
        return []
    pattern = re.compile(filename_mask)
    return [os.path.join(dir_path, filename)
            for filename in sorted(os.listdir(dir_path))
            if pattern.match(filename)]


def pack_tarfile(output_filename, filename_list):
    """Write every path in *filename_list* into the gzip tar *output_filename*.

    Members are stored under their base name only, so the archive is flat no
    matter where the source files live. Each source path is printed after it
    has been added.
    """
    archive = tarfile.open(output_filename, "w:gz")
    with archive:
        for path in filename_list:
            member_name = os.path.basename(path)
            archive.add(path, arcname=member_name)
            print(path)


def unpack_tarfile(input_filename, source_dir='.'):
    """Extract every member of the gzip tar *input_filename*.

    Despite its name, *source_dir* is the directory the members are extracted
    INTO; it defaults to the current working directory.

    NOTE(review): ``extractall`` trusts member paths as stored in the archive;
    an archive from an untrusted source could write outside *source_dir*.
    """
    destination = source_dir
    archive = tarfile.open(input_filename, "r:gz")
    with archive:
        archive.extractall(path=destination)


def has_input_snapshot(name):
    """Return whether the job context declares an input called *name*."""
    job_inputs = nv.context().get_inputs()
    return job_inputs.has(name)


def has_output_snapshot(name):
    """Return whether a file already exists at the path of output *name*."""
    output_path = nv.context().get_outputs().get(name)
    return os.path.exists(output_path)


def has_snapshot(name):
    """Return a truthy value if snapshot *name* is available as input or output.

    The input check runs first; the output check is only evaluated when the
    input check is falsy.
    """
    found_in_inputs = has_input_snapshot(name)
    return found_in_inputs or has_output_snapshot(name)


def get_input_snapshot_path(name):
    """Return the path the job context assigns to input *name*."""
    job_inputs = nv.context().get_inputs()
    return job_inputs.get(name)


def get_output_snapshot_path(name):
    """Return the path the job context assigns to output *name*."""
    job_outputs = nv.context().get_outputs()
    return job_outputs.get(name)


def dump_snapshot(name, filename_list):
    """Pack *filename_list* into a tar.gz and publish it as job output *name*.

    The archive is first written to ``tmp_snapshot_<name>.tar.gz`` in the
    current working directory and then moved onto the output path. When both
    are on the same filesystem the move is a rename, so the output path never
    holds a half-written archive. If they are on different filesystems, the
    move falls back to copy + delete instead of failing.

    If packing or moving fails, the temporary archive is removed and the
    error is re-raised.

    Raises:
        Exception: if the job context declares no output called *name*.
    """
    import shutil

    print("Start dump snapshot: {}".format(name))
    if not nv.context().get_outputs().has(name):
        # Report the output actually requested (previously always said 'state').
        raise Exception("output '{}' does not exist".format(name))
    tmp_path = 'tmp_snapshot_{}.tar.gz'.format(name)
    try:
        pack_tarfile(tmp_path, filename_list)
        # shutil.move tries os.rename first and only copies when the rename
        # fails (e.g. cross-device), where a bare os.rename would just raise.
        shutil.move(tmp_path, get_output_snapshot_path(name))
    except Exception:
        # Do not leave a stale partial archive behind in the CWD.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def load_snapshot(name, output_dir):
    """Unpack snapshot *name* into *output_dir*.

    The output copy (the archive dump_snapshot writes) is preferred over the
    input copy.

    Raises:
        Exception: if neither an output nor an input snapshot is available.
    """
    use_output = has_output_snapshot(name)
    if not use_output and not has_input_snapshot(name):
        raise Exception('no snapshot available')
    if use_output:
        archive_path = get_output_snapshot_path(name)
    else:
        archive_path = get_input_snapshot_path(name)
    unpack_tarfile(archive_path, output_dir)


def _dump_snapshot_logged(name, file_list):
    """Call dump_snapshot, printing (not raising) any failure."""
    try:
        dump_snapshot(name, file_list)
    except Exception as e:
        # Deliberately best-effort: a failed dump must not kill the worker,
        # but the cause is reported so it can be diagnosed.
        print('Could not dump snapshot {}: {}'.format(name, e))


def snapshot_worker(name, dir_path, time_interval, filename_mask, parent_thread):
    """Periodically snapshot matching files in *dir_path* while *parent_thread* runs.

    Every *time_interval* seconds the files in *dir_path* matching
    *filename_mask* are listed. If the set of files changed since the last
    dump, or *name* is "logs", they are dumped via dump_snapshot under output
    *name*. Once *parent_thread* has finished, one final dump is made.

    Dump failures are printed and do not stop the worker.
    """
    print('start snapshot worker')
    file_list = get_file_list(dir_path, filename_mask)
    while parent_thread.is_alive():
        new_file_list = get_file_list(dir_path, filename_mask)
        # Compare order-insensitively: directory listings have no guaranteed order.
        if sorted(file_list) != sorted(new_file_list) or name == "logs":
            file_list = new_file_list
            _dump_snapshot_logged(name, file_list)
        # join() with a timeout waits like sleep(), but returns as soon as the
        # parent finishes, so the final dump (and process exit, since this is
        # a non-daemon thread) is not delayed by up to time_interval seconds.
        parent_thread.join(time_interval)
    _dump_snapshot_logged(name, get_file_list(dir_path, filename_mask))


def snapshotter_start(snapshot_name, snapshot_dir, time_interval=300, filename_mask='.*'):
    """Start a background thread running snapshot_worker for *snapshot_dir*.

    The worker watches files in *snapshot_dir* whose names match
    *filename_mask*, dumps them to output *snapshot_name* on a
    *time_interval*-second cadence, and makes a final dump once the thread
    that called this function has finished.

    Returns:
        The started threading.Thread, so callers may join() it to wait for the
        final dump. Callers that ignore the return value are unaffected.
    """
    worker = threading.Thread(
        target=snapshot_worker,
        name='snapshotter-{}'.format(snapshot_name),
        kwargs={
            'name': snapshot_name,
            'dir_path': snapshot_dir,
            'time_interval': time_interval,
            'filename_mask': filename_mask,
            # The worker keeps running until this (the calling) thread ends.
            'parent_thread': threading.current_thread(),
        },
    )
    # Non-daemon: the interpreter waits for the worker's final dump at exit
    # instead of killing it.
    worker.daemon = False
    worker.start()
    return worker

def _next_checkpoint_step(snapshot_dir):
    """Return 1 + the largest <N> among 'model-<N>.index' files, or 0 if there are none."""
    steps = []
    for filename in os.listdir(snapshot_dir):
        match = re.match(r'model-(\d+)\.index$', filename)
        if match:
            steps.append(int(match.group(1)))
    return 1 + max(steps) if steps else 0


def init(data_type, params):
    """Restore the latest snapshot (if any) and start the background snapshotter.

    Args:
        data_type: "model" to snapshot model files in ``params["out_snapshot"]``
            under output "state", or "logs" to snapshot every file in
            ``params["out_logs"]`` under output "logs". The directory is
            created if missing.
        params: mapping providing the directory keys above.

    Returns:
        For "model": the step to resume from. This is 1 + the highest N among
        restored ``model-<N>.index`` files, or 0 when there was no snapshot or
        it contains no such file.
        For "logs": None.

    Raises:
        Exception: for an unknown *data_type*.
    """
    if data_type == "model":
        snapshot_dir = params["out_snapshot"]
        if not os.path.isdir(snapshot_dir):
            print("make snapshot folder")
            os.mkdir(snapshot_dir)
        # Check once, before the worker starts: the worker may write the output
        # snapshot, so re-checking afterwards could report a snapshot that was
        # never restored.
        restored = has_snapshot("state")
        if restored:
            load_snapshot("state", snapshot_dir)
        snapshotter_start("state", snapshot_dir, 300,
                          r'model-?[0-9]*\.(gdef|index|meta|data-00000-of-00001)$')
        return _next_checkpoint_step(snapshot_dir) if restored else 0
    elif data_type == "logs":
        logs_dir = params["out_logs"]
        if not os.path.isdir(logs_dir):
            os.mkdir(logs_dir)
        if has_snapshot("logs"):
            load_snapshot("logs", logs_dir)
        snapshotter_start("logs", logs_dir, 300)
    else:
        raise Exception("Unknown snapshot type: {}".format(data_type))
