import sys
import os.path
import tensorflow as tf
import tarfile
import nirvana.job_context as nv

import calc_stat

# Marker that opens each per-checkpoint section of the statistics text;
# the checkpoint index follows it on the same line (e.g. "Steps: 1000").
STAT_SUBHEADER_PREFIX = 'Steps: '

def _load_statistic(f):
    return f.read()

def _parse_statistic(text):
    """Extract the checkpoint indices recorded in the statistics text.

    Every line that begins with STAT_SUBHEADER_PREFIX contributes the integer
    written after the prefix. Indices are returned in order of appearance.

    Raises:
        ValueError: if the text after a prefix is not an integer.
    """
    prefix_len = len(STAT_SUBHEADER_PREFIX)
    return [
        int(row[prefix_len:])
        for row in text.split('\n')
        if row.startswith(STAT_SUBHEADER_PREFIX)
    ]

def _dump_statistic(state, f):
    f.write(state)

def _checkpoint_index(member_name):
    """Return the checkpoint index encoded in an archive member name.

    The index is the part of the file's base name before the first '.', so
    both '1000.pb' and './graphs/1000.pb' yield 1000.

    Raises:
        ValueError: if that part is not an integer.
    """
    return int(os.path.basename(member_name).split('.')[0])

def main():
    """Evaluate every not-yet-tested graph from the `gdef_tgz` input.

    Statistics accumulated by earlier runs are restored from the job state
    'statistic'. Checkpoints already listed there (by their
    STAT_SUBHEADER_PREFIX subheader) are skipped, so a restarted job resumes
    where it stopped. Each remaining graph goes through four steps:
    extraction into the working directory, evaluation with
    calc_stat.calc_stat, appending to the statistics text and saving to the
    job state, and deletion.
    """
    ctx     = nv.context()
    inputs  = ctx.get_inputs()
    outputs = ctx.get_outputs()
    params  = ctx.get_parameters()

    full_statistics = ctx.load_state('statistic', _load_statistic, default_value='', read_mode='r')
    ckpt_tested = _parse_statistic(full_statistics)
    print(ckpt_tested)
    # Set for O(1) membership tests. It is also updated after each evaluation,
    # so an index shared by several archive members is only evaluated once.
    tested = set(ckpt_tested)

    test_data_paths  = inputs.get_list('test_data')
    base_thr         = params.get('base_thr')
    digit_seq_length = params.get('digit_seq_length')

    gdef_tgz_path = inputs.get('gdef_tgz')
    print(gdef_tgz_path)
    current_dir = os.path.abspath('.')
    # `with` guarantees the archive is closed even if evaluation raises.
    # NOTE(review): members are extracted without path sanitisation; an
    # archive from an untrusted source could write outside current_dir.
    with tarfile.open(gdef_tgz_path, 'r') as gdef_tgz:
        for member in gdef_tgz:
            # Directory entries (and other non-regular members) carry no graph
            # and their names cannot be parsed as checkpoint indices.
            if not member.isfile():
                continue
            ckpt_idx = _checkpoint_index(member.name)
            if ckpt_idx in tested:
                continue
            gdef_path = os.path.abspath(os.path.join(current_dir, member.name))
            # Pass the TarInfo itself so exactly this member is extracted even
            # if the archive contains duplicate names.
            gdef_tgz.extract(member, current_dir)

            statistic = calc_stat.calc_stat(gdef_path, test_data_paths, base_thr, digit_seq_length)
            full_statistics += '{}{}\n{}\n\n'.format(STAT_SUBHEADER_PREFIX, ckpt_idx, statistic)
            # Persist after every checkpoint so progress survives a restart.
            ctx.dump_state('statistic', full_statistics, _dump_statistic, write_mode='wt')
            tested.add(ckpt_idx)
            tf.gfile.Remove(gdef_path)
            # Drop the graph loaded for this checkpoint before the next one.
            tf.reset_default_graph()

# Run the job only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
