import tensorflow as tf


def get_class_names():
    """Return the class names, indexed by integer label (0 -> "no", 1 -> "yes").

    A fresh list is built on every call, so callers may mutate it freely.
    """
    names = ["no", "yes"]
    return names


def get_elements_number(tfrecord_path):
    """Return how many records the TFRecord file at ``tfrecord_path`` holds.

    The whole file is scanned with ``tf.python_io.tf_record_iterator``, so
    the cost is linear in the file size.
    """
    return sum(1 for _ in tf.python_io.tf_record_iterator(tfrecord_path))


def resize_image(image, size):
    """Resize a single (unbatched) image tensor to ``size``.

    ``tf.image.resize_images`` is called with a temporary leading batch axis
    of length 1, which is removed again before returning.

    Args:
        image: image tensor without a batch dimension (presumably HWC).
        size: ``[new_height, new_width]``.
    """
    batched = tf.expand_dims(image, axis=0)
    return tf.squeeze(tf.image.resize_images(batched, size), axis=[0])


def preprocess(inputs):
    """Linearly map pixel values from [0, 255] onto [-1, 1].

    Works on anything supporting ``/``, ``-`` and ``*`` with floats
    (Python numbers, NumPy arrays, TF tensors).
    """
    unit_range = inputs / 255.
    centered = unit_range - 0.5
    return 2.0 * centered


def parse_example_proto(example_proto, train_config, augment):
    """Decode one serialized Example into a ``(image, label)`` pair.

    The record holds a JPEG image under "image/encoded", a PNG mask under
    "mask/encoded" and an integer "label". The image (3 channels) and the
    mask (1 channel) are each resized to
    ``[train_config["image"]["height"], train_config["image"]["width"]]``.
    They are then concatenated along axis 2 into a 4-channel tensor and
    rescaled with ``preprocess``.

    Args:
        example_proto: scalar string tensor holding one serialized Example.
        train_config: dict providing ``["image"]["height"]`` and
            ``["image"]["width"]``.
        augment: when truthy, apply a random left/right flip and a random
            rotation by a multiple of 90 degrees to image and mask jointly.

    Returns:
        Tuple ``(preprocessed_image, label)`` with ``label`` cast to int32.
    """
    features = {
        "image/encoded": tf.FixedLenFeature((), tf.string, default_value=''),
        "mask/encoded":  tf.FixedLenFeature((), tf.string, default_value=''),
        "label":         tf.FixedLenFeature((), tf.int64, default_value=0),
    }
    parsed_features = tf.parse_single_example(example_proto, features)

    height = train_config["image"]["height"]
    width = train_config["image"]["width"]
    image_size = [height, width]
    image = tf.image.decode_jpeg(parsed_features["image/encoded"], channels=3)
    image = resize_image(image, image_size)
    # NOTE(review): the mask goes through the same resize as the image, so its
    # values get interpolated at edges - confirm this is intended.
    mask = tf.image.decode_png(parsed_features["mask/encoded"], channels=1)
    mask = resize_image(mask, image_size)
    # Image and mask are stacked so the augmentations below move them together.
    concat_image = tf.concat([image, mask], axis=2)
    preprocessed_image = preprocess(concat_image)
    if augment:
        preprocessed_image = tf.image.random_flip_left_right(preprocessed_image)
        if height == width:
            quarter_turns = tf.random_uniform([], 0, 4, tf.int32)
        else:
            # A 90/270 degree turn swaps height and width on a non-square
            # target size, producing mixed shapes that cannot be batched;
            # restrict to 0 or 180 degrees in that case.
            quarter_turns = 2 * tf.random_uniform([], 0, 2, tf.int32)
        preprocessed_image = tf.image.rot90(preprocessed_image, quarter_turns)
    label = tf.to_int32(parsed_features["label"])
    return preprocessed_image, label


def get_parse_function(train_config, augment):
    """Bind ``train_config`` and ``augment`` into a one-argument parser.

    The returned callable takes a serialized example and forwards it to
    ``parse_example_proto``, which makes it usable with ``Dataset.map``.
    """
    def _parse(serialized):
        return parse_example_proto(serialized, train_config, augment)

    return _parse


def create_train_dataset(train_config):
    """Build the augmented, batched training dataset.

    Reads the TFRecord source(s) named by ``train_config["dataset"]["train"]``.
    Each record is parsed with augmentation enabled, using
    ``train_config["batch_size"]`` parallel parse calls, and the results are
    grouped into batches of ``train_config["batch_size"]``.
    """
    # NOTE(review): no shuffle/repeat here - records come in file order and the
    # caller presumably re-initializes the iterator each epoch; confirm.
    batch_size = train_config["batch_size"]
    parse_fn = get_parse_function(train_config, augment=True)
    return (tf.data.TFRecordDataset(train_config["dataset"]["train"])
            .map(parse_fn, num_parallel_calls=batch_size)
            .batch(batch_size)
            # Prefetch as the last stage so whole batches are prepared while
            # the model trains; 2 batches == the 2 * batch_size elements the
            # previous element-level buffer held.
            .prefetch(2))


def create_valid_dataset(train_config):
    """Build the validation dataset: no augmentation, batches of one element.

    Reads the TFRecord source(s) named by ``train_config["dataset"]["valid"]``
    and parses records sequentially (one parse call at a time).
    """
    parse_fn = get_parse_function(train_config, augment=False)
    return (tf.data.TFRecordDataset(train_config["dataset"]["valid"])
            .map(parse_fn, num_parallel_calls=1)
            .batch(1)
            # Prefetch ready batches last; 10 batches of 1 matches the
            # previous 10-element buffer.
            .prefetch(10))


def get_data(dataset):
    """Create an initializable iterator over ``dataset``.

    Returns:
        ``(images, labels, init_op)``: the two components of the iterator's
        next element, plus the op that must be run before reading from it.
    """
    it = dataset.make_initializable_iterator()
    next_element = it.get_next()
    images, labels = next_element
    init_op = it.initializer
    return images, labels, init_op
