import os.path
import math
import cv2
import numpy as np


# Side length, in pixels, of the square patches that encode_data() folds into
# the channel axis (decode_data() and show_enc_data() unfold them again).
PATCH_SIZE = 4
# Number of consecutive frames grouped into one Moving MNIST sequence.
MOVING_MNIST_SEQ_LENGTH = 20

# Split archives, resolved relative to the folder passed to load_data().
MOVING_MNIST_TRAIN_PATH, MOVING_MNIST_VALID_PATH, MOVING_MNIST_TEST_PATH = (
    './moving-mnist-train.npz',
    './moving-mnist-valid.npz',
    './moving-mnist-test.npz',
)

def load_data(folder_path):
    """Load the train/valid/test Moving MNIST splits stored under ``folder_path``.

    Each split archive holds an ``'input_raw_data'`` array that this function
    interprets as (frames, channels, height, width). Frames are moved to
    channels-last and grouped into sequences of ``MOVING_MNIST_SEQ_LENGTH``.

    Args:
        folder_path: Directory containing the three split ``.npz`` archives.

    Returns:
        Tuple ``(train, valid, test)``; each array is shaped
        (num_sequences, MOVING_MNIST_SEQ_LENGTH, height, width, channels).

    Raises:
        ValueError: if the raw arrays are not 4-D or the splits do not share
            the same per-frame (channels, height, width) shape.
    """
    def _read_raw(relative_path):
        # Read the array exactly once and close the archive: np.load on an
        # .npz returns an NpzFile that keeps the file open until closed.
        path = os.path.normpath(os.path.join(folder_path, relative_path))
        with np.load(path) as archive:
            return archive['input_raw_data']

    train_raw = _read_raw(MOVING_MNIST_TRAIN_PATH)
    valid_raw = _read_raw(MOVING_MNIST_VALID_PATH)
    test_raw  = _read_raw(MOVING_MNIST_TEST_PATH)

    if train_raw.ndim != 4:
        raise ValueError(
            "expected 'input_raw_data' with 4 dimensions "
            "(frames, channels, height, width), got shape {}".format(train_raw.shape))
    # Every split must agree with train on the per-frame shape.
    for split_name, raw in (('valid', valid_raw), ('test', test_raw)):
        if raw.shape[1:] != train_raw.shape[1:]:
            raise ValueError(
                "{} frame shape {} does not match train frame shape {}".format(
                    split_name, raw.shape[1:], train_raw.shape[1:]))

    channels, height, width = train_raw.shape[1:]
    print("input data shape: ")
    print("  width  =   ", width)
    print("  height =   ", height)
    print("  channels = ", channels)

    def _to_sequences(raw):
        # Move channels last *before* grouping frames into sequences; a plain
        # reshape of (frames, C, H, W) would interleave channel planes with
        # pixel rows whenever C > 1 (for C == 1 both give the same result).
        frames = np.transpose(raw, (0, 2, 3, 1))
        return frames.reshape(-1, MOVING_MNIST_SEQ_LENGTH, height, width, channels)

    return _to_sequences(train_raw), _to_sequences(valid_raw), _to_sequences(test_raw)

def show_data(dataset):
    """Show each sequence of ``dataset`` as one horizontal film strip.

    Frame ``k`` of a sequence fills columns ``k*width`` .. ``(k+1)*width`` of a
    float64 canvas, which is shown in the OpenCV window named 'test'. The call
    blocks on ``cv2.waitKey()`` after every sequence.

    Args:
        dataset: array indexed as (sequence, frame, row, column, channel).
    """
    dims = dataset.shape
    n_seqs, n_frames = dims[0], dims[1]
    frame_h, frame_w, n_ch = dims[2], dims[3], dims[4]

    canvas = np.zeros((frame_h, frame_w * n_frames, n_ch))
    for seq_no in range(n_seqs):
        # (frame, row, col, ch) -> (row, frame, col, ch) -> (row, frame*col, ch)
        strip = np.transpose(dataset[seq_no], (1, 0, 2, 3))
        canvas[...] = strip.reshape(frame_h, n_frames * frame_w, n_ch)
        cv2.imshow('test', canvas)
        cv2.waitKey()

def show_enc_data(dataset, stack_channels=False):
    """Display patch-encoded sequences in the OpenCV window named 'test'.

    By default (tile mode) the channels of every encoded pixel are laid out
    row-major into a ``PATCH_SIZE x PATCH_SIZE`` tile, so channel ``c`` lands
    at tile offset ``(c // PATCH_SIZE, c % PATCH_SIZE)``; tile positions with
    no corresponding channel stay 0. With ``stack_channels=True`` each channel
    is instead shown as its own row of frames, channels stacked vertically.
    The call blocks on ``cv2.waitKey()`` after every sequence.

    Args:
        dataset: array shaped (seq_count, seq_length, height, width, channels).
        stack_channels: select the per-channel stacked view instead of tiles.

    Raises:
        ValueError: in tile mode, if ``channels > PATCH_SIZE ** 2`` (the
            channels cannot fit into one tile).
    """
    seq_count, seq_length, height, width, channels = dataset.shape
    patch_area = PATCH_SIZE * PATCH_SIZE
    if not stack_channels and channels > patch_area:
        raise ValueError(
            "{} channels do not fit into a {}x{} tile".format(
                channels, PATCH_SIZE, PATCH_SIZE))

    for bs_idx in range(seq_count):
        # float64 copy of one sequence: (seq_length, height, width, channels)
        frames = np.asarray(dataset[bs_idx], dtype=np.float64)
        if stack_channels:
            # (S, H, W, C) -> (C, H, S, W): row = ch*H + y, col = s*W + x
            view = frames.transpose(3, 1, 0, 2).reshape(
                channels * height, seq_length * width)
        else:
            if channels < patch_area:
                # Missing channels become black tile positions.
                pad = np.zeros(frames.shape[:3] + (patch_area - channels,))
                frames = np.concatenate((frames, pad), axis=3)
            tiles = frames.reshape(seq_length, height, width, PATCH_SIZE, PATCH_SIZE)
            # (S, H, W, tile_row, tile_col) -> (H, tile_row, S, W, tile_col):
            # row = y*P + tile_row, col = (s*W + x)*P + tile_col
            view = tiles.transpose(1, 3, 0, 2, 4).reshape(
                height * PATCH_SIZE, seq_length * width * PATCH_SIZE)
        cv2.imshow('test', view[:, :, np.newaxis])
        cv2.waitKey()

def save_data(dataset, folder_path):
    """Write every sequence of ``dataset`` to ``<folder_path>/batch<i>.png``.

    The frames of sequence ``i`` are placed side by side, left to right. Values
    are multiplied by 255, rounded and saturated to [0, 255] before being
    written as 8-bit pixels (so inputs are presumably expected in [0, 1]).

    Args:
        dataset: array shaped (seq_count, seq_length, height, width, channels).
        folder_path: existing directory that receives the PNG files.

    Raises:
        OSError: if OpenCV reports that an image could not be written.
    """
    seq_count, seq_length, height, width, channels = dataset.shape
    for bs_idx in range(seq_count):
        # float64 first so "* 255" cannot overflow integer input dtypes.
        frames = np.asarray(dataset[bs_idx], dtype=np.float64)
        # (frame, row, col, ch) -> (row, frame, col, ch) -> (row, frame*col, ch)
        strip = frames.transpose(1, 0, 2, 3).reshape(height, seq_length * width, channels)
        # Explicit round + saturate to uint8 rather than relying on OpenCV's
        # implicit float -> 8-bit conversion inside imwrite.
        image = np.clip(np.rint(strip * 255), 0, 255).astype(np.uint8)
        path = os.path.join(folder_path, "batch{}.png".format(bs_idx))
        # cv2.imwrite signals failure (e.g. missing directory) via its return value.
        if not cv2.imwrite(path, image):
            raise OSError("failed to write image {}".format(path))

def encode_data(data, patch_size=None):
    """Fold each frame's non-overlapping square patches into the channel axis.

    Every ``p x p`` patch of a frame becomes a single pixel with ``p * p``
    times as many channels; ``decode_data`` with the same patch size performs
    the inverse rearrangement.

    Args:
        data: array shaped (seq_count, seq_length, height, width, channels).
        patch_size: patch side length ``p``; ``None`` (default) uses the
            module-level ``PATCH_SIZE``.

    Returns:
        Array shaped (seq_count, seq_length, height // p, width // p,
        channels * p * p). Along the last axis values are ordered by
        (row within patch, column within patch, original channel).

    Raises:
        ValueError: if ``p < 1`` or height/width is not a multiple of ``p``.
    """
    if patch_size is None:
        patch_size = PATCH_SIZE
    if patch_size < 1:
        raise ValueError("patch size must be positive, got {}".format(patch_size))
    seq_count, seq_length, height, width, channels = data.shape
    if height % patch_size or width % patch_size:
        raise ValueError(
            "frame size {}x{} is not divisible by patch size {}".format(
                height, width, patch_size))
    tensor_height = height // patch_size
    tensor_width = width // patch_size

    # Split rows/cols into (patch index, offset within patch):
    #   (seq, frame, tensor_h, p_row, tensor_w, p_col, ch)
    data = data.reshape((seq_count, seq_length, tensor_height, patch_size,
                         tensor_width, patch_size, channels))
    # Bring both within-patch offsets next to the channel axis:
    #   (seq, frame, tensor_h, tensor_w, p_row, p_col, ch)
    data = np.transpose(data, (0, 1, 2, 4, 3, 5, 6))
    return data.reshape((seq_count, seq_length, tensor_height, tensor_width,
                         channels * patch_size * patch_size))

def decode_data(data, patch_size=None):
    """Unfold patch-encoded channels back into full-resolution frames.

    Inverse of ``encode_data``: each encoded pixel's ``channels`` values,
    ordered by (row within patch, column within patch, original channel),
    are spread back over a ``p x p`` patch.

    Args:
        data: array shaped (seq_count, seq_length, tensor_height, tensor_width,
            tensor_channels), with tensor_channels a multiple of ``p * p``.
        patch_size: patch side length ``p``; ``None`` (default) uses the
            module-level ``PATCH_SIZE``.

    Returns:
        Array shaped (seq_count, seq_length, tensor_height * p,
        tensor_width * p, tensor_channels // (p * p)).

    Raises:
        ValueError: if ``p < 1`` or tensor_channels is not a multiple of ``p * p``.
    """
    if patch_size is None:
        patch_size = PATCH_SIZE
    if patch_size < 1:
        raise ValueError("patch size must be positive, got {}".format(patch_size))
    seq_count, seq_length, tensor_height, tensor_width, tensor_ch = data.shape
    patch_area = patch_size * patch_size
    if tensor_ch % patch_area:
        raise ValueError(
            "{} channels are not a multiple of patch area {}".format(tensor_ch, patch_area))
    ch = tensor_ch // patch_area

    #   (seq, frame, tensor_h, tensor_w, p_row, p_col, ch)
    data = data.reshape((seq_count, seq_length, tensor_height, tensor_width,
                         patch_size, patch_size, ch))
    #   (seq, frame, tensor_h, p_row, tensor_w, p_col, ch)
    data = np.transpose(data, (0, 1, 2, 4, 3, 5, 6))
    return data.reshape((seq_count, seq_length, tensor_height * patch_size,
                         tensor_width * patch_size, ch))

def generate_data(mnist, shape=(64,64), seq_len=MOVING_MNIST_SEQ_LENGTH, seq_cnt=10000, obj_per_image=2,
                  random_state=None):
    """Synthesize Moving-MNIST-style sequences of bouncing images.

    Each sequence pastes ``obj_per_image`` randomly chosen images from
    ``mnist`` onto a blank canvas and moves each one with a constant random
    velocity. When a step would put an object on or past a canvas edge, that
    velocity component is negated and the step along that axis is undone.
    Overlapping images are summed and the frame is clipped at 1.0.

    Args:
        mnist: source images shaped (N, obj_height, obj_width, 1) or
            (N, obj_height, obj_width). Values are added as-is and clipped at
            1.0, so they presumably should already be scaled to [0, 1].
        shape: (height, width) of the output canvas.
        seq_len: frames per sequence.
        seq_cnt: number of sequences to generate.
        obj_per_image: images moving in each sequence.
        random_state: optional ``np.random.RandomState`` for reproducible
            output; ``None`` uses the global ``np.random`` functions (same
            draws, in the same order, as before this parameter existed).

    Returns:
        float32 array shaped (seq_cnt, seq_len, height, width, 1).

    Raises:
        ValueError: if a source image is larger than the canvas.
    """
    MIN_SPEED = 2
    MAX_SPEED = 5
    rng = np.random if random_state is None else random_state

    mnist = np.asarray(mnist)
    if mnist.ndim == 3:
        # Add a channel axis so images match the (h, w, 1) canvas slices
        # instead of broadcasting to (h, w, w).
        mnist = mnist[..., np.newaxis]

    height, width = shape
    obj_height, obj_width = mnist.shape[1:3]
    x_lim, y_lim = width - obj_width, height - obj_height
    if x_lim < 0 or y_lim < 0:
        raise ValueError(
            "object size {}x{} exceeds canvas size {}x{}".format(
                obj_height, obj_width, height, width))

    dataset = np.zeros((seq_cnt, seq_len, height, width, 1), dtype=np.float32)
    for seq_idx in range(seq_cnt):
        # Random direction in [-pi, pi) and speed per object -> velocity vector.
        direcs = np.pi * (rng.rand(obj_per_image) * 2 - 1)
        # NOTE(review): randint's upper bound is exclusive, so speeds fall in
        # [MIN_SPEED, MAX_SPEED - 1]; confirm MAX_SPEED is meant to be exclusive.
        speeds = rng.randint(MAX_SPEED - MIN_SPEED, size=obj_per_image) + MIN_SPEED
        veloc = [[v * math.cos(d), v * math.sin(d)] for d, v in zip(direcs, speeds)]

        mnist_images = [mnist[r] for r in rng.randint(0, mnist.shape[0], obj_per_image)]
        pos = [[rng.rand() * x_lim, rng.rand() * y_lim] for _ in range(obj_per_image)]
        for frame_idx in range(seq_len):
            frame = dataset[seq_idx, frame_idx]  # view into dataset
            for img, (px, py) in zip(mnist_images, pos):
                x = int(round(px))
                y = int(round(py))
                frame[y : y + obj_height, x : x + obj_width, :] += img
            np.minimum(frame, 1.0, out=frame)

            for i in range(obj_per_image):
                x = pos[i][0] + veloc[i][0]
                y = pos[i][1] + veloc[i][1]
                if x <= 0 or x >= x_lim:
                    # Reverse horizontally and undo this step's x movement.
                    veloc[i][0] = -veloc[i][0]
                    x = x + veloc[i][0]
                if y <= 0 or y >= y_lim:
                    # Reverse vertically and undo this step's y movement.
                    veloc[i][1] = -veloc[i][1]
                    y = y + veloc[i][1]
                pos[i] = [x, y]
    return dataset
