import os.path
os.environ['TF_CPP_MIN_LOG_LEVEL']='3'
import sys
import numpy as np
import json
import tensorflow as tf
from object_detection.utils import dataset_util
import cv2
import random
import urllib.request
import json_ds_utils as jutils
import tfconfig_utils
import progress

# Default probability with which an image is routed to the training record;
# every other image goes to the test record.
TRAIN_PART = 0.9

# Names of the generated files. Each one starts with '/' because it is
# appended directly to the output folder path once that path has had its
# trailing slashes stripped.
LABEL_MAP_FILENAME = '/label_map.pbtxt'
TRAIN_TFRECORD_FILENAME = '/train.tfrecord'
TEST_TFRECORD_FILENAME = '/test.tfrecord'
CONFIG_FILENAME = '/pipeline.config'
BAD_URLS_FILENAME = '/bad_urls.txt'

# Parsing schema handed to tf.parse_single_example when a record file is read
# back for statistics. Per-image scalars are FixedLenFeature entries with a
# default value, so a record that omits them still parses; per-object
# attributes are VarLenFeature entries and therefore come back as sparse
# tensors with one element per annotated object.
features = {
    # Per-image fields.
    'image/format' :    tf.FixedLenFeature((), tf.string, default_value='jpeg'),
    'image/filename':   tf.FixedLenFeature((), tf.string, default_value=''),
    'image/key/sha256': tf.FixedLenFeature((), tf.string, default_value=''),
    'image/source_id':  tf.FixedLenFeature((), tf.string, default_value=''),
    'image/height':     tf.FixedLenFeature((), tf.int64, default_value=1),
    'image/width':      tf.FixedLenFeature((), tf.int64, default_value=1),
    'image/encoded':    tf.FixedLenFeature((), tf.string, default_value=''),
    # Per-object fields (variable length).
    'image/object/bbox/xmin':   tf.VarLenFeature(tf.float32),
    'image/object/bbox/xmax':   tf.VarLenFeature(tf.float32),
    'image/object/bbox/ymin':   tf.VarLenFeature(tf.float32),
    'image/object/bbox/ymax':   tf.VarLenFeature(tf.float32),
    'image/object/class/label': tf.VarLenFeature(tf.int64),
    'image/object/class/text':  tf.VarLenFeature(tf.string),
    'image/object/area':        tf.VarLenFeature(tf.float32),
    'image/object/is_crowd':    tf.VarLenFeature(tf.int64),
    'image/object/difficult':   tf.VarLenFeature(tf.int64),
    'image/object/group_of':    tf.VarLenFeature(tf.int64),
    'image/object/weight':      tf.VarLenFeature(tf.float32),
    'image/object/mask': tf.VarLenFeature(tf.string),
    'image/labels':    tf.FixedLenFeature((), tf.string, default_value=''),
}

def create_tf_example(feature_id, image, bboxes, cl_names, classes, img_horz_flip = False):
    """Build a tf.train.Example for one image and its labelled boxes.

    Args:
        feature_id: Identifier stored as both 'image/filename' and
            'image/source_id'.
        image: Decoded image array; shape[0] is used as the height and
            shape[1] as the width.
        bboxes: Boxes given as two corner points [[x0, y0], [x1, y1]] in
            pixels. The corners may be listed in either order.
        cl_names: Class name of each box, parallel to bboxes.
        classes: Ordered list of known class names. The stored label id is
            the position in this list plus one.
        img_horz_flip: Unused; kept so existing callers keep working.

    Returns:
        The tf.train.Example, or None when no box has a class that is
        present in classes.

    Raises:
        ValueError: If the image cannot be re-encoded as JPEG.
    """
    height = image.shape[0]
    width = image.shape[1]

    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes_lbl = []

    for cl, bbox in zip(cl_names, bboxes):
        try:
            cl_idx = classes.index(cl)
        except ValueError:
            # Boxes of classes that are not being trained are dropped.
            continue

        x0, x1 = float(bbox[0][0]), float(bbox[1][0])
        y0, y1 = float(bbox[0][1]), float(bbox[1][1])

        # Normalize to the image size and keep the coordinates strictly
        # inside the image: the lower edge is raised to 0.001 and the upper
        # edge is lowered to 0.999.
        xmins.append(max(min(x0, x1) / float(width), 0.001))
        xmaxs.append(min(max(x0, x1) / float(width), 0.999))
        ymins.append(max(min(y0, y1) / float(height), 0.001))
        ymaxs.append(min(max(y0, y1) / float(height), 0.999))
        classes_text.append(cl.encode('utf8'))
        # Label ids are 1-based.
        classes_lbl.append(cl_idx + 1)

    if not classes_text:
        return None

    # Re-encode the image, because every now and then a file turns up that
    # is encoded in some odd way and TF refuses to read it. This is done only
    # after the box check above so that images which produce no example are
    # not encoded for nothing.
    encoded_ok, encoded_image = cv2.imencode('.jpg', image)
    if not encoded_ok:
        raise ValueError('Cannot encode image {} as JPEG'.format(feature_id))
    encoded_image_data = encoded_image.tobytes()

    image_format = 'jpeg'.encode('utf8')
    feature_id_bytes = '{}'.format(feature_id).encode('utf8')

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(feature_id_bytes),
        'image/source_id': dataset_util.bytes_feature(feature_id_bytes),
        'image/encoded': dataset_util.bytes_feature(encoded_image_data),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes_lbl),
    }))
    return tf_example

def _read_color_image(image_path):
    """Load an image as a 3-channel array, failing loudly if it cannot be decoded."""
    image = cv2.imread(image_path, 1)
    if image is None:
        raise ValueError('Cannot decode image {}'.format(image_path))
    return image


def _flip_annotations(bboxes, cl_names, horz_flip_map, img_w):
    """Return (boxes, class names) that match a horizontally mirrored image.

    Only signs whose class is a key of horz_flip_map are kept, and of those
    the ones mapped to 'undefined' are dropped. New box lists are built, so
    the boxes passed in are never modified.
    """
    flipped_bboxes = []
    flipped_names = []
    for cl, bbox in zip(cl_names, bboxes):
        if cl not in horz_flip_map:
            continue
        flipped_cl = horz_flip_map[cl]
        if flipped_cl == 'undefined':
            continue
        flipped_bboxes.append([
            [img_w - 1 - bbox[0][0], bbox[0][1]],
            [img_w - 1 - bbox[1][0], bbox[1][1]],
        ])
        flipped_names.append(flipped_cl)
    return flipped_bboxes, flipped_names


def make_ds_from_json(json_data, classes, image_folder, train_tfrecord_path, test_tfrecord_path, horz_flip_map, min_width, min_height, train_part = TRAIN_PART):
    """Write the train and test TFRecord files from the annotation JSON.

    Each entry of json_data['classified_signs'] is assigned at random to the
    train part (with probability train_part) or the test part. Only signs
    whose 'answer' is 'ok' are used. For train images a sign is additionally
    dropped when its box is both narrower than min_width and shorter than
    min_height. After the first pass every train image is written once more,
    mirrored horizontally, with classes remapped through horz_flip_map.

    Args:
        json_data: Dict with a 'classified_signs' list; each item has a
            'source' URL and a 'signs' list of dicts with 'answer', 'bbox'
            and 'sign_id'.
        classes: Ordered list of class names to include.
        image_folder: Folder that holds the already downloaded images.
        train_tfrecord_path: Output path of the training record file.
        test_tfrecord_path: Output path of the test record file.
        horz_flip_map: Maps a class name to its name after a horizontal
            flip, or to 'undefined' when the flipped sign must not be used.
        min_width: Minimal box width in pixels (train part only).
        min_height: Minimal box height in pixels (train part only).
        train_part: Probability of assigning an image to the train part.

    Raises:
        AssertionError: If an image file is missing on disk.
        ValueError: If an image file cannot be decoded.
    """
    csr = json_data['classified_signs']

    # The writers are context managers, so both files are closed even when
    # processing fails half way.
    with tf.python_io.TFRecordWriter(train_tfrecord_path) as train_writer, \
            tf.python_io.TFRecordWriter(test_tfrecord_path) as test_writer:
        train_data_augm = []
        cnt = len(csr)
        for idx, item in enumerate(csr):
            signs = item['signs']
            url = item['source']

            feature_id, image_path = jutils.image_url_to_local_path(url, image_folder)
            assert os.path.isfile(image_path), 'Image not found {}'.format(image_path)

            progress.printProgressBar(idx+1, cnt, prefix = 'Make tfrecord: ', suffix = feature_id)

            in_train_part = (random.random() < train_part)
            bboxes = []
            cl_names = []
            for sign in signs:
                if sign['answer'] != 'ok':
                    continue
                bbox = sign['bbox']
                if in_train_part:
                    too_narrow = min_width > abs(bbox[1][0] - bbox[0][0])
                    too_short = min_height > abs(bbox[1][1] - bbox[0][1])
                    if too_narrow and too_short:
                        continue
                bboxes.append(bbox)
                cl_names.append(sign['sign_id'])

            if not bboxes:
                # Nothing to label, so do not spend time decoding the image.
                continue

            image = _read_color_image(image_path)
            tf_example = create_tf_example(feature_id, image, bboxes, cl_names, classes)
            if tf_example is None:
                continue
            if in_train_part:
                train_writer.write(tf_example.SerializeToString())
                train_data_augm.append((feature_id, image_path, bboxes, cl_names))
            else:
                test_writer.write(tf_example.SerializeToString())

        random.shuffle(train_data_augm)
        cnt = len(train_data_augm)
        for idx, item in enumerate(train_data_augm):
            feature_id, image_path, bboxes, cl_names = item
            progress.printProgressBar(idx+1, cnt, prefix = 'Add augmentation data: ', suffix = feature_id)

            image = cv2.flip(_read_color_image(image_path), 1)
            img_w = image.shape[1]

            bboxes_flip, cl_names_flip = _flip_annotations(bboxes, cl_names, horz_flip_map, img_w)
            if not cl_names_flip:
                continue
            tf_example = create_tf_example(feature_id, image, bboxes_flip, cl_names_flip, classes)
            if tf_example is not None:
                train_writer.write(tf_example.SerializeToString())

def ds_stat(tfrecord_path):
    """Print how many objects of each class a TFRecord file contains.

    The file is read once. Two tables are printed: object counts per class
    name with their share of all objects, followed by object counts per
    numeric label id.

    Args:
        tfrecord_path: Path of the TFRecord file to inspect.
    """
    cl_names = {}
    cl_lbls = {}

    # Build the input pipeline in a graph of its own. In the default graph
    # the queue runners of an earlier call would still be registered and
    # would be started again by every later call.
    with tf.Graph().as_default():
        filename_queue = tf.train.string_input_producer([tfrecord_path], num_epochs=1)

        reader = tf.TFRecordReader()
        _, example = reader.read(filename_queue)

        example_features = tf.parse_single_example(example, features=features)
        class_names = tf.sparse_tensor_to_dense(example_features['image/object/class/text'], default_value='')
        class_lbls = tf.sparse_tensor_to_dense(example_features['image/object/class/label'])

        # num_epochs is tracked in a local variable, so local variables have
        # to be initialized as well.
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        with tf.Session() as sess:
            sess.run(init_op)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            try:
                # The loop ends when the single epoch is exhausted and
                # OutOfRangeError is raised.
                while True:
                    class_names_, class_lbls_ = sess.run([class_names, class_lbls])
                    for cl_name in class_names_:
                        cl_names[cl_name] = cl_names.get(cl_name, 0) + 1
                    for cl_lbl in class_lbls_:
                        cl_lbls[cl_lbl] = cl_lbls.get(cl_lbl, 0) + 1
            except tf.errors.OutOfRangeError:
                print('Epoch ended')
            finally:
                coord.request_stop()
            coord.join(threads)

    cl_all = sum(cl_names.values())
    for name, count in cl_names.items():
        print('{:<20}: {:>5} ({:.2f}%)'.format(name.decode('utf-8'), count, float(count * 100)/float(cl_all)))
    print('{:<20}: {:>5}'.format('All', cl_all))
    print()
    for lbl, count in cl_lbls.items():
        print('{}: {}'.format(lbl, count))

def save_label_map(classes, label_map_path):
    """Write the label map file for the given classes.

    One 'item' entry is written per class, in order, with ids starting at 1.
    The file is written as UTF-8 regardless of the system locale, so class
    names that are not plain ASCII are stored the same way everywhere.

    Args:
        classes: Ordered list of class names.
        label_map_path: Path of the file to create or overwrite.
    """
    entries = [
        "item {{\n  id: {}\n  name: '{}'\n}}\n\n".format(item_id, name)
        for item_id, name in enumerate(classes, start=1)
    ]
    with open(label_map_path, 'w', encoding='utf-8') as f:
        f.write(''.join(entries))

def make_ds(settings_path, image_folder, tfrecord_folder, config_template, train_part = TRAIN_PART):
    """Build the whole dataset folder and print statistics for both parts.

    Produces the train and test record files, the label map and the pipeline
    config inside tfrecord_folder, then prints class statistics for the two
    record files. Returns without writing the dataset when prepare_json
    yields no data or no classes.

    Args:
        settings_path: Path of the JSON settings file. It must contain
            'horz_flip_map', 'min_bbox_width' and 'min_bbox_height'.
        image_folder: Folder that holds the source images.
        tfrecord_folder: Output folder. Trailing slashes are ignored and an
            empty string means the current directory.
        config_template: Pipeline config template passed to
            tfconfig_utils.update_template.
        train_part: Probability of assigning an image to the train part.
    """
    # rstrip also copes with '' and '/', which crashed the former
    # character-by-character loop. For '/' the stripped value is '' and the
    # file names' own leading slash then yields paths in the root folder.
    out_folder = tfrecord_folder.rstrip('/') if tfrecord_folder else '.'

    train_tfrecord_path = out_folder + TRAIN_TFRECORD_FILENAME
    test_tfrecord_path  = out_folder + TEST_TFRECORD_FILENAME
    label_map_path      = out_folder + LABEL_MAP_FILENAME
    config_out_path     = out_folder + CONFIG_FILENAME
    bad_urls_out_path   = out_folder + BAD_URLS_FILENAME

    # Create the output folder before prepare_json runs: it is given a path
    # inside that folder (bad_urls_out_path) and presumably writes to it.
    if out_folder and not os.path.isdir(out_folder):
        os.makedirs(out_folder)

    json_data, classes = jutils.prepare_json(settings_path, image_folder, bad_urls_out_path)
    if json_data is None or classes is None:
        return

    with open(settings_path) as f:
        settings = json.load(f)

    make_ds_from_json(json_data, classes, image_folder,
                      train_tfrecord_path, test_tfrecord_path,
                      settings["horz_flip_map"],
                      settings["min_bbox_width"], settings["min_bbox_height"],
                      train_part)
    save_label_map(classes, label_map_path)

    tfconfig_utils.update_template(config_template, len(classes), train_tfrecord_path, label_map_path, config_out_path)

    # Report on exactly the files that were written above.
    print('Train part: ')
    ds_stat(train_tfrecord_path)

    print('Test part: ')
    ds_stat(test_tfrecord_path)

if __name__ == '__main__':
    # Two modes:
    #   stat <tfrecord_path>  - print statistics of an existing record file
    #   <settings_json> <image_folder> <tfrecord_folder> <config_template>
    #                         - build the dataset
    # Extra trailing arguments are ignored, as before; missing ones now give
    # a usage message instead of an IndexError.
    cli_args = sys.argv[1:]
    usage = ('Usage:\n'
             '  {0} stat <tfrecord_path>\n'
             '  {0} <settings_json> <image_folder> <tfrecord_folder> <config_template>'
             ).format(sys.argv[0])
    if cli_args[:1] == ['stat']:
        if len(cli_args) < 2:
            sys.exit(usage)
        ds_stat(cli_args[1])
    elif len(cli_args) >= 4:
        make_ds(cli_args[0], cli_args[1], cli_args[2], cli_args[3])
    else:
        sys.exit(usage)
