import tensorflow as tf
import cv2
import numpy as np
import argparse
import random


def is_right_turn(triple):
    """Return True if the three points ``(p, q, r)`` make a right turn.

    ``triple`` is a sequence of exactly three distinct 2-D points; each point
    is indexed as ``point[0]`` (x) and ``point[1]`` (y).

    The orientation is the sign of the cross product (q - p) x (r - p):
    negative is a clockwise (right) turn in a y-up coordinate system, zero
    means the points are collinear, positive is a left turn.  Only a strictly
    negative value counts, so collinear points return False.
    """
    # Unpack in the body: tuple parameters in the signature are a
    # SyntaxError on Python 3 (PEP 3113).
    p, q, r = triple
    assert p != q and q != r and p != r

    # Expanded form of the 2-D cross product (q - p) x (r - p).
    sum1 = q[0]*r[1] + p[0]*q[1] + r[0]*p[1]
    sum2 = q[0]*p[1] + r[0]*q[1] + p[0]*r[1]
    return bool(sum1 - sum2 < 0)


def convex_hull(points):
    """Return the convex hull of 2-D points using Andrew's monotone chain.

    Args:
        points: iterable of ``(x, y)`` points.  Each point is converted to a
            tuple.  The caller's object is not modified, and one-shot
            iterators such as Python 3's ``zip`` are accepted.

    Returns:
        A list of hull vertices as tuples.  It starts at the lexicographically
        smallest point, follows the chain built from right turns to the
        largest point, then returns along the other chain.  Points lying on a
        hull edge are dropped.  Inputs with fewer than three distinct points
        are returned as their sorted unique points.
    """
    # Build a fresh, sorted, duplicate-free list.  Sorting in place would
    # mutate the caller's data, and duplicates would violate the
    # distinct-points assertion in is_right_turn.
    pts = sorted(set(map(tuple, points)))
    if len(pts) < 3:
        return pts

    def _half_hull(ordered):
        # Sweep through `ordered`, keeping only right turns.
        chain = [ordered[0], ordered[1]]
        for pt in ordered[2:]:
            chain.append(pt)
            while len(chain) > 2 and not is_right_turn(chain[-3:]):
                chain.pop(-2)
        return chain

    upper = _half_hull(pts)
    lower = _half_hull(pts[::-1])
    # Both chains contain the two extreme points; keep them only once.
    return upper + lower[1:-1]


def load_graph(path):
    """Read a binary-serialized TensorFlow ``GraphDef`` and import it.

    The nodes go into the current default graph with ``name=''``, so no
    prefix is added and tensors keep their serialized names.

    Args:
        path: path to the binary ``GraphDef`` protobuf file.

    Returns:
        The parsed ``tf.GraphDef``.  Callers are free to ignore it.
    """
    # tf.gfile.GFile replaces the deprecated tf.gfile.FastGFile.  Read the
    # bytes and close the file before parsing and importing.
    with tf.gfile.GFile(path, 'rb') as f:
        serialized = f.read()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)
    tf.import_graph_def(graph_def, name='')
    return graph_def


def equalize_hist(img):
    """Histogram-equalize the brightness of a BGR image.

    The image is converted to YUV.  Histogram equalization is applied to
    channel 0 (luma) only, and the result is converted back to BGR, so the
    chroma channels pass through unchanged.
    """
    yuv = cv2.cvtColor(img, cv2.COLOR_BGR2YUV)
    luma = yuv[:, :, 0]
    yuv[:, :, 0] = cv2.equalizeHist(luma)
    bgr = cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR)
    return bgr


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # --image and --gdef have no usable default: None would only fail later
    # inside cv2 / TensorFlow, so let argparse report them as missing.
    parser.add_argument('--image', required = True, help = 'Path to image')
    parser.add_argument('--gdef', required = True, help = 'Path to gdef')
    parser.add_argument('--border_thresh', type = int, default = 127, help = 'thresh value for borders')
    parser.add_argument('--output', default = 'polygon.txt', help = 'Path to file with found polygons')
    args = parser.parse_args()

    # Import the serialized network into the default TensorFlow graph.
    load_graph(args.gdef)

    img = cv2.imread(args.image)
    if img is None:
        raise IOError("Failed to open file '{}'".format(args.image))
    equalized_img = equalize_hist(img)
    # Prepend a batch axis of size 1 for the network input.
    net_input = np.expand_dims(equalized_img, axis = 0)

    with tf.Session() as sess:
        sigmoid_tensor = sess.graph.get_tensor_by_name('inference_sigmoid:0')
        sigmoid = sess.run(sigmoid_tensor, {'inference_input:0': net_input})
        # NOTE(review): index 5 along the first axis presumably selects the
        # fused output map (hence the name) -- confirm against the graph.
        fuse = sigmoid[5]

    # Scale to 0..255, drop singleton axes and invert, so strong network
    # responses become dark pixels and weak responses bright ones.
    fuse = np.array(fuse*255, dtype = np.uint8)
    fuse = np.squeeze(fuse)
    fuse = 255 - fuse

    _, binary_fuse = cv2.threshold(fuse, args.border_thresh, 255, cv2.THRESH_BINARY)

    # 3x3 structuring element shared by every morphological operation below.
    kernel = np.ones((3, 3), dtype = np.uint8)
    border = np.zeros_like(fuse)

    # For each threshold level, iterate to a fixed point:
    #   - erode the foreground;
    #   - force pixels with fuse > thresh + 1, and pixels already marked as
    #     border, back on;
    #   - mark as border every foreground pixel that the opening (erode then
    #     dilate) removes or whose fuse value is below thresh + 1.
    for thresh in range(args.border_thresh, 256):
        while True:
            eroded = cv2.erode(binary_fuse, kernel)
            dilated = cv2.dilate(eroded, kernel)
            eroded[fuse > thresh + 1] = 255
            dilated[fuse < thresh + 1] = 0

            border[np.logical_and(binary_fuse == 255, dilated == 0)] = 255
            eroded[border == 255] = 255
            changed = np.count_nonzero(binary_fuse != eroded)

            binary_fuse = eroded
            if changed == 0:
                break
        # `border` only ever holds 0 or 255, so for any non-negative
        # threshold these two lines leave binary_fuse equal to `border`.
        # NOTE(review): comparing `border` (not `fuse`) against the threshold
        # looks unusual -- confirm this is intended.
        binary_fuse[border < thresh+2] = 0
        binary_fuse[border == 255] = 255

    # Thicken the borders by one pixel with the 3x3 dilation, then invert so
    # the regions between borders become the nonzero pixels that get labelled.
    binary_fuse = cv2.dilate(binary_fuse, kernel)
    binary_fuse = 255 - binary_fuse

    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary_fuse, 8, cv2.CV_32S)

    # Label 0 is the zero-valued pixels.  The largest remaining component is
    # skipped too (presumably the background -- confirm).  The per-label area
    # comes from the stats table rather than a full rescan of the label image
    # for every label.  argmax resolves ties to the lowest label.
    max_component = -1
    if num_labels > 1:
        max_component = 1 + int(np.argmax(stats[1:, cv2.CC_STAT_AREA]))

    with open(args.output, 'w') as polygon_file:
        for i in range(1, num_labels):
            if i == max_component:
                continue
            y, x = np.where(labels == i)
            # list() so len() also works on Python 3, where zip is lazy.
            points = list(zip(x, y))
            if len(points) <= 2:
                continue
            hull = convex_hull(points)
            if len(hull) <= 2:
                continue
            # One line per polygon: "bld <vertex count> x0 y0 x1 y1 ...".
            coords = ''.join(' {x} {y}'.format(x = point[0], y = point[1]) for point in hull)
            polygon_file.write('bld {count}{coords}\n'.format(count = len(hull), coords = coords))
