import numpy as np

class Dataset(object):
    """In-memory image / segmentation-mask dataset with a train/validate split.

    Images and masks are read from two raw ``uint8`` files and reshaped to
    ``(N, height, width, chns)`` and ``(N, height, width, 1)`` respectively.
    After construction the samples are available through ``train_data`` and
    ``validate_data``, each a dict with the keys ``"image"`` and ``"segm"``.

    Keys read from ``params``:
        width, height, chns: dimensions of a single image.
        images_filepath, segm_filepath: paths of the raw ``uint8`` files.
        remove_empty_cell: if truthy, drop samples whose mask is all zeros.
        validate_batchs, batch_size: their product is the number of samples
            reserved for validation.
        data_split_seed: seed of the permutation used for the split.
        data_augmentation: if truthy, expand both parts with flips/rotations.
    """

    def __init__(self, params):
        """Load the data, optionally filter it, split it and augment it.

        Raises:
            ValueError: if the image and mask files hold a different number
                of samples, or if the requested validation size does not fit
                into the loaded data (see ``split_data``).
        """
        self.width = params["width"]
        self.height = params["height"]
        self.chns = params["chns"]

        self.data_source = {"image": params["images_filepath"],
                            "segm": params["segm_filepath"]}

        raw_images = np.fromfile(self.data_source["image"], dtype=np.uint8)
        raw_masks = np.fromfile(self.data_source["segm"], dtype=np.uint8)
        self.data = {
            "image": raw_images.reshape([-1, self.height, self.width, self.chns]),
            "segm": raw_masks.reshape([-1, self.height, self.width, 1]),
        }

        # Explicit exception instead of ``assert`` so the check survives -O.
        image_cnt = self.data["image"].shape[0]
        segm_cnt = self.data["segm"].shape[0]
        if image_cnt != segm_cnt:
            raise ValueError(
                "Sample count mismatch: {} images vs {} segmentation masks".format(
                    image_cnt, segm_cnt))

        if params["remove_empty_cell"]:
            # Keep only the samples whose mask has at least one non-zero pixel.
            not_empty_cells = np.any(self.data["segm"], axis=(1, 2, 3))
            self.data["image"] = self.data["image"][not_empty_cells]
            self.data["segm"] = self.data["segm"][not_empty_cells]

        # Split into train and validate parts.
        validate_cnt = params["validate_batchs"] * params["batch_size"]
        self.split_data(validate_cnt, params["data_split_seed"])

        # Data augmentation.
        if params["data_augmentation"]:
            self.augment_data()

    def split_data(self, validate_cnt, random_state):
        """Split ``self.data`` into ``train_data`` and ``validate_data``.

        The samples are permuted with a ``RandomState`` seeded by
        ``random_state``; the last ``validate_cnt`` samples of the permutation
        become the validation part and the rest the training part.
        ``self.data`` is deleted afterwards.

        Raises:
            ValueError: if ``validate_cnt`` is negative or larger than the
                number of available samples.
        """
        total_cnt = self.data["image"].shape[0]
        # Without this guard a too large validate_cnt makes train_cnt negative
        # and the slices below silently produce a wrongly sized split.
        if validate_cnt < 0 or validate_cnt > total_cnt:
            raise ValueError(
                "validate_cnt must be between 0 and {}, got {}".format(
                    total_cnt, validate_cnt))

        train_cnt = total_cnt - validate_cnt
        order = np.random.RandomState(seed=random_state).permutation(total_cnt)
        train_idx = order[:train_cnt]
        validate_idx = order[train_cnt:]

        self.train_data = {"image": self.data["image"][train_idx],
                           "segm": self.data["segm"][train_idx]}
        self.validate_data = {"image": self.data["image"][validate_idx],
                              "segm": self.data["segm"][validate_idx]}

        del self.data

    def expand_data(self, data):
        """Return ``data`` extended with flipped and rotated copies.

        ``data`` is indexed as ``(N, axis1, axis2, channels)``. The result is
        the concatenation along axis 0 of the originals, their mirror along
        axis 2, and rotations of both in the (axis1, axis2) plane.

        Square samples get the 90, 180 and 270 degree rotations (8x the input
        size). For non-square samples a 90/270 degree rotation would swap the
        two spatial dimensions and could not be stacked with the originals, so
        only the 180 degree rotation is used (4x the input size).
        """
        mirrored = np.concatenate([data, data[:, :, ::-1, :]], axis=0)
        if data.shape[1] == data.shape[2]:
            quarter_turns = (1, 2, 3)
        else:
            quarter_turns = (2,)
        variants = [mirrored]
        for k in quarter_turns:
            variants.append(np.rot90(mirrored, k=k, axes=(1, 2)))
        return np.concatenate(variants, axis=0)

    def augment_data(self):
        """Apply ``expand_data`` to the images and masks of both parts."""
        for part in (self.train_data, self.validate_data):
            for key in ("image", "segm"):
                part[key] = self.expand_data(part[key])

    def shuffle_train_data(self):
        """Shuffle the training samples, keeping images and masks aligned."""
        order = np.random.permutation(self.train_data["image"].shape[0])
        self.train_data["image"] = self.train_data["image"][order]
        self.train_data["segm"] = self.train_data["segm"][order]

    def train_cnt(self):
        """Return the number of training samples."""
        return self.train_data["image"].shape[0]

    def validate_cnt(self):
        """Return the number of validation samples."""
        return self.validate_data["image"].shape[0]

    def get_batch(self, index, batch_size, is_training):
        """Return batch number ``index`` as a dict with "image" and "segm".

        The batch covers samples ``index * batch_size`` up to (excluding)
        ``(index + 1) * batch_size`` of the training part when ``is_training``
        is truthy, otherwise of the validation part. A range that runs past
        the end of the data yields a shorter (possibly empty) batch.
        """
        source = self.train_data if is_training else self.validate_data
        batch_start = index * batch_size
        batch_end = batch_start + batch_size
        return {"image": source["image"][batch_start:batch_end, ...],
                "segm": source["segm"][batch_start:batch_end, ...]}

    def print_info(self):
        """Print the data sources, image size and per-part statistics."""
        print("Dataset type: segmentation")
        print("Source:")
        print("    Images: {}".format(self.data_source["image"]))
        print("    Segmentation: {}".format(self.data_source["segm"]))

        print("Image width: {}".format(self.width))
        print("Image height: {}".format(self.height))

        # The pixel counts are the numbers of non-zero mask entries.
        print("Train samples: {}".format(self.train_cnt()))
        print("    Building pixels count: {}".format(np.sum(self.train_data["segm"] > 0)))
        print("Validate samples: {}".format(self.validate_cnt()))
        print("    Building pixels count: {}".format(np.sum(self.validate_data["segm"] > 0)))

