import os
import argparse
import tensorflow as tf
import cv2
import numpy as np
from object_detection.utils import dataset_util


def load_label(path):
    """Parse building polygons from a whitespace-separated label file.

    Only lines whose first token is ``bld`` are used; they are read as
    ``bld <count> x0 y0 x1 y1 ...``. Blank lines and lines starting with any
    other token are skipped. Each coordinate is parsed as a float and then
    truncated to an int.

    Args:
        path: Path of the label text file.

    Returns:
        A list of polygons, each polygon being a list of ``[x, y]`` int pairs
        in the order they appear in the file.

    Raises:
        IndexError: If a ``bld`` line has fewer coordinates than ``<count>``
            announces (or no count at all).
        ValueError: If the count or a coordinate is not numeric.
    """
    label = []
    # The context manager guarantees the handle is closed even when a
    # malformed line raises halfway through the file.
    with open(path, "r") as label_file:
        for line in label_file:
            items = line.split()
            # An empty token list means a blank line; indexing it would fail.
            if not items or items[0] != "bld":
                continue
            cnt = int(items[1])
            coords = items[2:]
            points = [
                [int(float(coords[i * 2])), int(float(coords[i * 2 + 1]))]
                for i in range(cnt)
            ]
            label.append(points)
    return label


def filter_label(label, xmin, xmax, ymin, ymax):
    """Select the polygons that belong to one cell and shift them into it.

    A polygon is considered only if at least one of its vertices lies inside
    the half-open cell ``[xmin, xmax) x [ymin, ymax)``. It is then translated
    so that ``(xmin, ymin)`` becomes the origin, rasterized into a cell-sized
    mask, and kept only when more than ``MIN_BLD_PIXELS_CNT`` mask pixels are
    covered.

    Args:
        label: List of polygons, each a list of ``[x, y]`` points.
        xmin: Left edge of the cell (inclusive).
        xmax: Right edge of the cell (exclusive).
        ymin: Top edge of the cell (inclusive).
        ymax: Bottom edge of the cell (exclusive).

    Returns:
        A list of the kept polygons in cell-local coordinates. Vertices are
        not clipped, so they may fall outside the cell.
    """
    MIN_BLD_PIXELS_CNT = 300
    result = []
    for obj in label:
        # Both axes use the same half-open interval: a vertex sitting on row
        # ymax (or column xmax) belongs to the neighbouring cell.
        in_cell = any(
            xmin <= point[0] < xmax and ymin <= point[1] < ymax
            for point in obj
        )
        if not in_cell:
            continue
        shifted_obj = [[point[0] - xmin, point[1] - ymin] for point in obj]
        mask = np.zeros([ymax - ymin, xmax - xmin], dtype=np.uint8)
        # cv2.fillPoly only accepts 32-bit integer points; NumPy's default
        # integer type is 64-bit on most platforms.
        cv2.fillPoly(mask, [np.array(shifted_obj, dtype=np.int32)], 1)
        # The mask is filled with 1s, so its sum is the covered pixel count.
        if np.sum(mask) > MIN_BLD_PIXELS_CNT:
            result.append(shifted_obj)
    return result


def create_tf_example(filename, img, masks, xmins, xmaxs, ymins, ymaxs):
    """Assemble a ``tf.train.Example`` for one image and its instances.

    The image is JPEG-encoded and each mask is PNG-encoded. Every box gets
    class id 1 and class text "bld". The bounding-box lists are stored as
    given, without any rescaling.

    Args:
        filename: Name stored as both the image filename and the source id.
        img: Image array; its first two dimensions are stored as height and
            width.
        masks: Per-instance mask arrays, indexed in parallel with ``xmins``.
        xmins: Left box coordinates, one per instance.
        xmaxs: Right box coordinates, one per instance.
        ymins: Top box coordinates, one per instance.
        ymaxs: Bottom box coordinates, one per instance.

    Returns:
        The populated ``tf.train.Example``.
    """
    CLASS_NAME = "bld"
    _, jpeg_buffer = cv2.imencode(".jpg", img)

    # The number of instances is defined by the number of boxes.
    box_count = len(xmins)

    # One PNG-encoded mask per bounding box.
    png_masks = [
        cv2.imencode(".png", masks[idx])[1].tobytes()
        for idx in range(box_count)
    ]
    # Single-class dataset: every box carries the same id and name.
    class_ids = [1] * box_count
    class_names = [CLASS_NAME.encode('utf8')] * box_count

    feature = {
        'image/height': dataset_util.int64_feature(img.shape[0]),
        'image/width': dataset_util.int64_feature(img.shape[1]),
        'image/filename': dataset_util.bytes_feature(filename.encode('utf8')),
        'image/source_id': dataset_util.bytes_feature(filename.encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(jpeg_buffer.tobytes()),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/object/mask': dataset_util.bytes_list_feature(png_masks),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(class_names),
        'image/object/class/label': dataset_util.int64_list_feature(class_ids),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


def write_image(writer, image_path, label_path, cell_size):
    """Tile one image into square cells and write a record per labelled cell.

    The image is cut into non-overlapping ``cell_size`` x ``cell_size`` cells;
    partial cells at the right and bottom edges are dropped. For each cell
    that still has polygons after ``filter_label``, a full-cell 0/255 mask and
    a bounding box normalized by the cell size are built per polygon, and one
    serialized example is written.

    Args:
        writer: Object with a ``write(bytes)`` method receiving the
            serialized examples.
        image_path: Path of the image to tile.
        label_path: Path of the matching label file.
        cell_size: Edge length of a cell in pixels.

    Returns:
        The number of examples written for this image.

    Raises:
        IOError: If the image cannot be read.
    """
    print("image: {}, lable: {}".format(image_path, label_path))
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with the path instead of an AttributeError further down.
        raise IOError("Cannot read image: {}".format(image_path))
    label = load_label(label_path)
    cnt = 0
    vcell_cnt = image.shape[0] // cell_size
    hcell_cnt = image.shape[1] // cell_size
    scale = float(cell_size)

    for row in range(0, vcell_cnt * cell_size, cell_size):
        for col in range(0, hcell_cnt * cell_size, cell_size):
            cell_image = image[row:row+cell_size, col:col+cell_size]
            cell_obj = filter_label(label, col, col+cell_size, row, row+cell_size)
            if not cell_obj:
                continue
            masks = []
            xmins = []
            xmaxs = []
            ymins = []
            ymaxs = []

            for obj in cell_obj:
                mask = np.zeros([cell_size, cell_size], dtype=np.uint8)
                # cv2.fillPoly only accepts 32-bit integer points; NumPy's
                # default integer type is 64-bit on most platforms.
                cv2.fillPoly(mask, [np.array(obj, dtype=np.int32)], 255)
                masks.append(mask)

                xs = [point[0] for point in obj]
                ys = [point[1] for point in obj]
                # Vertices may lie outside the cell, so clamp the box to the
                # cell before normalizing by the cell size.
                xmins.append(max(0, min(xs)) / scale)
                xmaxs.append(min(max(xs), cell_size - 1) / scale)
                ymins.append(max(0, min(ys)) / scale)
                ymaxs.append(min(max(ys), cell_size - 1) / scale)

            filename = "{}-{}-{}".format(image_path, row, col)
            tf_example = create_tf_example(filename, cell_image, masks, xmins, xmaxs, ymins, ymaxs)
            writer.write(tf_example.SerializeToString())
            cnt += 1
    return cnt


def main(input_folder, tfrecord_path, cell_size):
    """Convert every labelled image under ``input_folder`` into one TFRecord.

    Label files are taken from ``<input_folder>/labels`` and each is paired
    with ``<input_folder>/images/<base name>.jpg``. Files are processed in
    sorted name order, and the total number of written samples is printed.

    Args:
        input_folder: Folder containing the ``labels`` and ``images``
            sub-folders.
        tfrecord_path: Path of the TFRecord file to create.
        cell_size: Edge length in pixels of the cells each image is cut into.
    """
    label_dir = os.path.join(input_folder, "labels")
    # Keep only label text files, and sort them because os.listdir returns
    # entries in arbitrary order, which would make the output irreproducible.
    label_names = sorted(
        name for name in os.listdir(label_dir)
        if os.path.splitext(name)[1].lower() == ".txt"
    )

    total_cnt = 0
    writer = tf.python_io.TFRecordWriter(tfrecord_path)
    try:
        for name in label_names:
            # splitext strips only the final extension, so base names that
            # themselves contain dots are preserved.
            item = os.path.splitext(name)[0]
            image_fullpath = os.path.join(input_folder, "images", "{}.jpg".format(item))
            label_fullpath = os.path.join(label_dir, name)
            total_cnt += write_image(writer, image_fullpath, label_fullpath, cell_size)
    finally:
        # Close the writer even when an image fails part-way through.
        writer.close()
    print("Sample count: {}".format(total_cnt))


if __name__ == "__main__":
    # Command-line entry point: both paths are mandatory, while the cell size
    # is optional and defaults to 512.
    cli = argparse.ArgumentParser(
        description="Tool for create Mask R-CNN dataset",
    )
    cli.add_argument("--input_folder", required=True)
    cli.add_argument("--tfrecord", required=True)
    cli.add_argument("--size", type=int, default=512)
    options = cli.parse_args()
    main(
        options.input_folder,
        options.tfrecord,
        options.size,
    )
