import tensorflow as tf
from math import sqrt, exp, floor
from rectangles import Rect

def get_angle(rect):
    """Reduce the rotation angle of a rectangle modulo 90 degrees.

    Args:
        rect (Rect): Named tuple Rect with rectangle parameters.
            Rect = namedtuple('Rect', ['x', 'y', 'width', 'height', 'angle'])

    Returns:
        float: ``rect.angle`` shifted by a whole number of 90-degree turns so
            that it lies in [0, 90); an exact multiple of 90 maps to 0.
    """
    # Number of complete 90-degree turns contained in the angle (negative for
    # negative angles, because floor rounds towards minus infinity).
    quarter_turns = floor(rect.angle / 90.0)
    return rect.angle - quarter_turns * 90


def set_angle(rect, new_angle):
    """Return a copy of a rectangle with a new angle of rotation.

        The stored angle is the value of the form new_angle + 90*k (k integer)
        that is closest to the old angle of the rectangle. On an exact tie the
        smaller value is chosen.

    Args:
        rect (Rect): Named tuple Rect with rectangle parameters.
            Rect = namedtuple('Rect', ['x', 'y', 'width', 'height', 'angle'])
        new_angle (float): New angle of rotation, normally in [0, 90).
            Both angles are expected to be finite numbers.

    Returns:
        Rect: Rectangle with the x, y, width and height of ``rect`` and the
            adjusted angle.
    """
    # The multiple of 90 nearest to the gap between the two angles is either
    # the floor of the gap or the next multiple above it. Searching only these
    # two candidates works for any magnitude of rect.angle (no fixed offset
    # window, so the result is always defined).
    lower_offset = int(floor((rect.angle - new_angle) / 90.0)) * 90
    # min() returns the first minimal item, so a tie resolves to the lower one.
    best_offset = min(
        (lower_offset, lower_offset + 90),
        key=lambda offset: abs(rect.angle - (new_angle + offset)),
    )
    return Rect(rect.x, rect.y, rect.width, rect.height, new_angle + best_offset)

def distance(rect1, rect2):
    """Measure how far apart the centers of two rectangles are.

    Args:
        rect1 (Rect): Named tuple Rect with rectangle parameters.
            Rect = namedtuple('Rect', ['x', 'y', 'width', 'height', 'angle'])
        rect2 (Rect): Named tuple Rect with rectangle parameters.

    Returns:
        float: Euclidean distance between the (x, y) points of the rectangles.
    """
    dx = rect1.x - rect2.x
    dy = rect1.y - rect2.y
    return sqrt(dx ** 2 + dy ** 2)


def pair_weight(rect1, rect2):
    """Compute the coupling weight of a pair of rectangles.

        The weight is 1 / (d + 1), where d is the distance between the centers
        of the rectangles: it equals 1 when the centers coincide and decreases
        as the rectangles move apart.

    Args:
        rect1 (Rect): Named tuple Rect with rectangle parameters.
            Rect = namedtuple('Rect', ['x', 'y', 'width', 'height', 'angle'])
        rect2 (Rect): Named tuple Rect with rectangle parameters.

    Returns:
        float: Weight between two rectangles.
    """
    separation = distance(rect1, rect2)
    # The +1.0 keeps the denominator non-zero for coincident centers.
    return 1 / (separation + 1.0)

def loss_func(rects, weight_threshold, angle_delta, angle_decay):
    """Generate loss function.
        Loss function is weighted sum of square differences between angles of
        rotation of rectangles, plus a penalty for moving an angle more than
        angle_delta away from its initial value.

    Args:
        rects (list): List of named tuples Rect with rectangles parameters.
            Rect = namedtuple('Rect', ['x', 'y', 'width', 'height', 'angle'])
        weight_threshold (float): Minimum weight between rectangles in loss function
        angle_delta (float): Maximum difference between angles of rotation of rectangles in loss function
        angle_decay (float): Weight of penalty for large change of angle.

    Returns:
        list: List of tf.Variables corresponding to rotation angles of rectangles.
        tf.Tensor: Loss function (a constant zero tensor when rects is empty).
    """
    # Reduce every angle modulo 90 once; these values seed the variables and
    # decide which pairs contribute to the loss.
    base_angles = [get_angle(rect) for rect in rects]
    angle_variables = [
        tf.Variable(angle, trainable=True, dtype=tf.float32)
        for angle in base_angles
    ]

    # Start from a tensor so that a tf.Tensor is returned even if no term is
    # ever added (e.g. for an empty list of rectangles).
    loss = tf.constant(0.0, dtype=tf.float32)
    for i in range(len(rects)):
        for j in range(i):
            weight = pair_weight(rects[i], rects[j])
            if weight < weight_threshold:
                # Rectangles are too far apart to influence each other.
                continue

            angle_diff = abs(base_angles[i] - base_angles[j])
            if angle_diff < angle_delta:
                # Nearly equal angles: pull them together.
                loss += weight * tf.square(angle_variables[i] - angle_variables[j])
            elif 90 - angle_diff < angle_delta:
                # Angles are close across the 0/90 wrap-around: pull the
                # difference (larger minus smaller) towards 90.
                if base_angles[i] > base_angles[j]:
                    larger, smaller = angle_variables[i], angle_variables[j]
                else:
                    larger, smaller = angle_variables[j], angle_variables[i]
                loss += weight * tf.square(90 - (larger - smaller))

    # Add penalty for large change of angles
    for base_angle, variable in zip(base_angles, angle_variables):
        loss += angle_decay * tf.nn.relu(tf.abs(base_angle - variable) - angle_delta)

    return angle_variables, loss

def align_angles(rects, weight_threshold, angle_delta, angle_decay, iterations):
    """Generate loss function and optimize it with gradient descent.

    Args:
        rects (list): List of named tuples Rect with rectangles parameters.
            Rect = namedtuple('Rect', ['x', 'y', 'width', 'height', 'angle'])
        weight_threshold (float): Minimum weight between rectangles in loss function
        angle_delta (float): Maximum difference between angles of rotation of rectangles in loss function
        angle_decay (float): Weight of penalty for large change of angle.
        iterations (int): Number of iterations.

    Returns:
        list: List of aligned rectangles, one per input rectangle, each as
            returned by set_angle. Empty list when rects is empty.
    """
    # Nothing to optimize: skip building a graph, which would otherwise fail
    # because there are no variables to minimize over.
    if not rects:
        return []

    # Build everything in a private graph so that repeated calls do not keep
    # adding variables and ops to the global default graph.
    with tf.Graph().as_default():
        print("Generate loss function")
        angle_variables, loss = loss_func(rects, weight_threshold, angle_delta, angle_decay)

        print("Optimize loss function")
        train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for i in range(iterations):
                # Loss is evaluated before the update, so the reported value
                # belongs to the state at the start of the iteration.
                loss_val = sess.run(loss)
                sess.run(train_step)
                print("Iteration {}. loss: {}".format(i+1, loss_val))
            angle_values = sess.run(angle_variables)

    print("Set new angles")
    return [set_angle(rect, angle) for rect, angle in zip(rects, angle_values)]
