import datetime
import json
import os
import sys

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile

# Maximum number of files read per class directory; None means "no limit".
# On Nirvana (YTF_NIRVANA set) the whole dataset is used and the job-context
# module is imported; on a local machine the size is capped via
# YTF_DATASET_LIMIT (default 300).
if "YTF_NIRVANA" not in os.environ:
    DATASET_LIMIT = int(os.getenv('YTF_DATASET_LIMIT', '300'))
else:
    DATASET_LIMIT = None
    import nirvana.job_context as nv

import tf_utils

# Name of the tf.identity node that wraps the inference logits: it is the
# output node kept when freezing the graph and the tensor looked up by name
# at evaluation time.
GRAPH_NAME = "inference_logits"
# File-name template for the frozen graph: formatted with "" for a single
# training run or "_cv<N>" for cross-validation run N.
GRAPH_FILENAME = "road_classifier{}.gdef"


def hms_string(sec_elapsed):
    """Format a duration in seconds as "H:MM:SS.ss".

    The value is rounded to hundredths of a second *before* being split into
    hours, minutes and seconds, so e.g. 59.999 renders as "0:01:00.00"
    instead of the invalid "0:00:60.00". Hours are not zero-padded and are
    not wrapped at 24. Intended for non-negative durations.
    """
    total = round(sec_elapsed, 2)
    total_minutes, seconds = divmod(total, 60)
    hours, minutes = divmod(int(total_minutes), 60)
    return "{}:{:02d}:{:05.2f}".format(hours, minutes, seconds)


class Dataset:
    """Binary (A / notA) image dataset with a rotating train/test split.

    Images from the two class directories are decoded with OpenCV, resized to
    224x224 and stored as uint8 arrays in ``self.dataset["A"]`` and
    ``self.dataset["notA"]``. For each class, ``self.test`` holds
    ``round(n * (1 - ds))`` consecutive samples (cyclically) and ``self.train``
    holds all the others, where ``ds`` is ``params["data_split"]`` or, when
    ``params["cross_validation"]`` is set,
    ``1 - 1 / params["cross_validation_folds"]``. ``next_cv_split`` moves the
    held-out window to the next fold.

    If ``shuffle`` is False the images are used in sorted file-name order.
    """

    def __init__(self, a_class_path, nota_class_path, params, shuffle=True):
        if params["cross_validation"]:
            self.ds = 1 - (1 / params["cross_validation_folds"])
        else:
            self.ds = params["data_split"]

        self.split_shift = 0  # index of the current cross-validation fold
        self.batch_shift = 0  # start offset of the next training batch

        # Two classes for binary classification
        self.dataset = {
            "A": self._load_class(a_class_path),
            "notA": self._load_class(nota_class_path)}
        if shuffle:
            np.random.shuffle(self.dataset["A"])
            np.random.shuffle(self.dataset["notA"])

        # One-hot labels: A -> [1, 0], notA -> [0, 1].
        self.a_label = np.array([1, 0])
        self.nota_label = np.array([0, 1])

        # Placeholders so the attributes exist (also for static analysis);
        # they are replaced immediately by _apply_cv_shift.
        self.train = {"A": np.empty([0, 224, 224, 3], dtype=np.uint8),
                      "notA": np.empty([0, 224, 224, 3], dtype=np.uint8)}
        self.test = {"A": np.empty([0, 224, 224, 3], dtype=np.uint8),
                     "notA": np.empty([0, 224, 224, 3], dtype=np.uint8)}
        self._apply_cv_shift()

    def get_batch(self, batch_size):
        """Return a shuffled, class-balanced batch from the training split.

        ``batch_size // 2`` consecutive samples are taken from each class,
        wrapping around the end of that class, so an odd ``batch_size`` yields
        ``batch_size - 1`` samples. Returns ``(images, labels)`` with shapes
        ``(2 * (batch_size // 2), 224, 224, 3)`` and ``(2 * (batch_size // 2), 2)``.
        """
        half_batch = batch_size // 2
        indices = np.random.permutation(half_batch * 2)
        batch_labels = np.concatenate((np.repeat([self.a_label], half_batch, axis=0),
                                       np.repeat([self.nota_label], half_batch, axis=0)))
        window = range(self.batch_shift, self.batch_shift + half_batch)
        batch = np.concatenate(
            (np.take(self.train["A"], window, mode='wrap', axis=0),
             np.take(self.train["notA"], window, mode='wrap', axis=0)))
        self.batch_shift += half_batch
        # np.take(mode='wrap') wraps each class modulo its own length, so the
        # offset only has to stay below the common period of both classes
        # (reducing modulo one class alone would starve the other).
        period = int(np.lcm(len(self.train["A"]), len(self.train["notA"])))
        if period:
            self.batch_shift %= period
        return batch[indices, ...], batch_labels[indices, ...]

    def get_validation_batches(self, batch_size):
        """Return two generators yielding the test split of A and notA in chunks of batch_size."""
        def generate_validation_batches(data):
            return (data[i * batch_size:(i + 1) * batch_size, ...] for i in
                    range(int(np.ceil(data.shape[0] / batch_size))))

        return generate_validation_batches(self.test["A"]), generate_validation_batches(self.test["notA"])

    def next_cv_split(self):
        """Advance to the next cross-validation fold and restart batching."""
        self.split_shift += 1
        self.batch_shift = 0
        self._apply_cv_shift()

    def _apply_cv_shift(self):
        """Recompute self.train / self.test for fold self.split_shift.

        Per class, the test window is ``test_size`` consecutive samples
        starting at ``split_shift * test_size`` (cyclically) and train is the
        remainder, so the two are disjoint and together cover the class.
        """
        fold = self.split_shift
        for class_key, data in self.dataset.items():
            n = data.shape[0]
            # Derive the train size from the test size: rounding n*ds and
            # n*(1-ds) independently can overlap (n=7, ds=0.5 -> 4 + 4) or
            # drop samples (n=5, ds=0.5 -> 2 + 2).
            test_size = min(max(int(round(n * (1 - self.ds))), 0), n)
            train_size = n - test_size
            test_start = fold * test_size
            modulus = max(n, 1)  # avoid "% 0" for an empty class
            test_indices = (np.arange(test_size) + test_start) % modulus
            train_indices = (np.arange(train_size) + test_start + test_size) % modulus
            self.train[class_key] = np.take(data, train_indices, axis=0)
            self.test[class_key] = np.take(data, test_indices, axis=0)

    @staticmethod
    def _list_files(class_path):
        """Return the sorted file names in class_path, capped at DATASET_LIMIT."""
        # Sorting makes the selected subset and the unshuffled order
        # independent of the filesystem's arbitrary os.listdir() order.
        files = sorted(os.listdir(class_path))
        if DATASET_LIMIT is not None:
            files = files[:DATASET_LIMIT]
        return files

    @classmethod
    def _load_class(cls, class_path):
        """Load the images of one class directory into a (k, 224, 224, 3) uint8 array."""
        files = cls._list_files(class_path)
        images = np.empty([len(files), 224, 224, 3], dtype=np.uint8)
        count = 0
        for img in cls._get_images(class_path, files):
            images[count] = img
            count += 1
        # Files cv2 cannot decode are skipped; drop the uninitialised rows
        # they would otherwise leave at the end of the array.
        return images[:count]

    @staticmethod
    def _get_images(class_path, files=None):
        """Yield each decodable image in class_path, resized to 224x224."""
        if files is None:
            files = Dataset._list_files(class_path)
        for name in files:
            img = cv2.imread(os.path.join(class_path, name))
            if img is None:
                continue  # not a readable image
            if img.shape[:2] != (224, 224):
                img = cv2.resize(img, (224, 224))
            yield img

    @staticmethod
    def save_images(path, images, road=True):
        """Write each image as a JPEG under path/road or path/not_road.

        Names are the first 10 hex digits of a running SHA-1 over the current
        timestamps, so names differ within one call.
        """
        import hashlib
        class_dir = os.path.join(path, "road" if road else "not_road")
        os.makedirs(class_dir, exist_ok=True)
        hasher = hashlib.sha1()
        for img in images:
            hasher.update(str(datetime.datetime.now()).encode("utf8"))
            filename = os.path.join(class_dir, "{}.jpg".format(hasher.hexdigest()[:10]))
            cv2.imwrite(filename, img)


class Progressbar:
    """Console progress bar with elapsed time and a linear time-left estimate.

    Args:
        total: number of work units corresponding to 100%.
        bar_width: number of characters in the bar.
        one_line_mode: if True, intermediate prints overwrite the current
            console line (carriage return, no newline); the final report
            printed by ``update`` always ends with a newline.
    """

    def __init__(self, total, bar_width=15, one_line_mode=False):
        self.start = datetime.datetime.now()
        self.progress = 0
        self.total = total
        self.width = bar_width
        self.one_line_mode = one_line_mode

    def update(self, progress_step):
        """Advance by progress_step units; print a final line when total is reached."""
        was_complete = self.progress >= self.total
        self.progress += progress_step
        # ">=" (not "==") so a step that jumps past total still triggers the
        # final report, but only on the update that crosses it.
        if not was_complete and self.progress >= self.total:
            self.one_line_mode = False
            self.print_progress()

    def elapsed_seconds(self):
        """Seconds since the bar was created."""
        return (datetime.datetime.now() - self.start).total_seconds()

    def estimated_time_left(self):
        """Linear extrapolation of the remaining seconds; never negative."""
        if self.progress == 0:
            return 0
        return max(0., self.elapsed_seconds() / self.progress * (self.total - self.progress))

    def _bar(self, fraction):
        """Return the bar body for fraction, clamped to [0, 1] so it never exceeds the width."""
        fraction = min(max(fraction, 0.), 1.)
        return ("#" * int(fraction * self.width)).ljust(self.width, " ")

    def print_progress(self):
        """Print percent done, the bar, elapsed time and estimated time left."""
        if self.total == 0:
            return
        percent = self.progress / self.total
        line_begin = "\r" if self.one_line_mode else ""
        line_end = "" if self.one_line_mode else "\n"
        print(line_begin + "{: >#04.1f}% [{}] ({}<{})".format(100 * percent, self._bar(percent),
                                                              hms_string(self.elapsed_seconds()),
                                                              hms_string(self.estimated_time_left())),
              end=line_end)


def alexnet(X, params, is_training):
    """Build an AlexNet-style two-class network and return its logits.

    Args:
        X: input image tensor (main() feeds [None, 224, 224, 3] images scaled
            to [-1, 1]).
        params: dict with "keep_prob" (dropout keep probability), "bn_params"
            (None or a dict of batch-norm settings forwarded to tf_utils) and
            an optional "weight_decay" (L2 coefficient applied to weights and
            biases).
        is_training: boolean tensor forwarded to dropout and batch norm.

    Returns:
        Logits tensor with 2 units (last layer has no non-linearity).

    The caller's ``params["bn_params"]`` dict is not modified.
    """
    regularizer = None
    if "weight_decay" in params:
        regularizer = tf.contrib.layers.l2_regularizer(params["weight_decay"])
    bn_params = None
    if params.get('bn_params'):
        # Copy instead of mutating: the same params dict is used to build
        # both the training and the inference network.
        bn_params = dict(params['bn_params'], is_training=is_training)

    def conv(inputs, name, *args, use_bias):
        # *args: kernel size, filters and (optionally) stride, passed through
        # positionally exactly as tf_utils.conv2d expects them.
        return tf_utils.conv2d(inputs, name, *args,
                               bn_params=bn_params, use_bias=use_bias,
                               weight_regularizer=regularizer,
                               bias_regularizer=regularizer)

    def fc(inputs, name, units, use_bias=True, **kwargs):
        return tf_utils.fc(inputs, name, units,
                           bn_params=bn_params, use_bias=use_bias,
                           weight_regularizer=regularizer,
                           bias_regularizer=regularizer, **kwargs)

    def max_pool(inputs, name):
        # 3x3 window, stride 2, no padding.
        return tf.nn.max_pool(inputs, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding="VALID", name=name)

    ####################
    conv1 = conv(X, "conv1", 11, 32, 4, use_bias=False)
    pool1 = max_pool(conv1, "pool1")
    print(conv1)
    print(pool1)
    ####################
    conv2 = conv(pool1, "conv2", 5, 48, use_bias=True)
    pool2 = max_pool(conv2, "pool2")
    print(conv2)
    print(pool2)
    ####################
    conv3 = conv(pool2, "conv3", 3, 64, use_bias=False)
    conv4 = conv(conv3, "conv4", 3, 64, use_bias=True)
    conv5 = conv(conv4, "conv5", 3, 48, use_bias=True)
    pool5 = max_pool(conv5, "pool5")
    print(conv5)
    print(pool5)
    ####################
    reshape = tf_utils.flatten(pool5)
    fc1 = fc(reshape, 'fc1', 256)
    fc1_do = tf_utils.dropout(fc1, "dropout_fc1", params["keep_prob"], is_training)
    fc2 = fc(fc1_do, 'fc2', 256)
    fc2_do = tf_utils.dropout(fc2, "dropout_fc2", params["keep_prob"], is_training)
    y = fc(fc2_do, 'y', 2, use_bias=False, non_linear_func=None)
    print(fc1)
    print(fc1_do)
    print(fc2)
    print(fc2_do)
    print(y)
    return y


def _evaluate_class(sess, inf_logits, dataset, batches, label, total, title, road, log_step, silent):
    """Classify every batch of one class; return (correct, incorrect) counts.

    A prediction is correct when the arg-max of its logits equals the
    arg-max of ``label``. Outside Nirvana, misclassified images are written
    to the "errors" directory via ``dataset.save_images``.
    """
    pbr = None
    if not silent:
        pbr = Progressbar(total)
        print(title)
    expected_class = int(np.argmax(label))
    correct_total = 0
    incorrect_total = 0
    for i, batch in enumerate(batches):
        if i % log_step == 0 and pbr:
            pbr.print_progress()
        pred = sess.run(inf_logits, {'inference_input:0': batch})
        # Compare in NumPy: creating tf.one_hot/tf.argmax ops here would add
        # new nodes to the graph on every batch and grow it without bound.
        mask = np.argmax(pred, axis=1) == expected_class
        if "YTF_NIRVANA" not in os.environ:  # Save incorrectly predicted images on local machine
            dataset.save_images("errors", batch[np.logical_not(mask)], road=road)
        correct = int(np.sum(mask))
        correct_total += correct
        incorrect_total += batch.shape[0] - correct
        if pbr:
            pbr.update(batch.shape[0])
    return correct_total, incorrect_total


def model_evaluation(dataset, params, graph_path=None, existing_sess=None, silent=False):
    """Evaluate the inference graph on dataset's test split.

    Args:
        dataset: Dataset whose ``test["A"]`` (positives) and ``test["notA"]``
            (negatives) are classified.
        params: dict providing "log_step" and "validation_batch_size".
        graph_path: if given, the default graph is reset and this frozen
            graph file is loaded first.
        existing_sess: session to run in; when None a new session is created
            and closed before returning (also on error).
        silent: suppress progress output.

    Returns:
        (TP, FN, TN, FP) counts.
    """
    log_step = params["log_step"]
    validation_batch_size = params["validation_batch_size"]

    if graph_path:
        tf.reset_default_graph()
        load_graph(graph_path)
    sess = tf.Session() if existing_sess is None else existing_sess
    try:
        inf_logits = sess.graph.get_tensor_by_name('{}:0'.format(GRAPH_NAME))
        class_a_batches, class_nota_batches = dataset.get_validation_batches(validation_batch_size)
        TP, FN = _evaluate_class(sess, inf_logits, dataset, class_a_batches, dataset.a_label,
                                 len(dataset.test["A"]), "Positive examples evaluation...",
                                 True, log_step, silent)
        TN, FP = _evaluate_class(sess, inf_logits, dataset, class_nota_batches, dataset.nota_label,
                                 len(dataset.test["notA"]), "Negative examples evaluation...",
                                 False, log_step, silent)
    finally:
        if existing_sess is None:
            sess.close()
    return TP, FN, TN, FP


def print_metrics(TP, FN, TN, FP):
    """Print precision, recall and accuracy from confusion-matrix counts.

    A metric whose denominator is zero is skipped instead of raising.
    """
    report = (
        ("Precision: {}", TP, TP + FP),
        ("Recall: {}", TP, TP + FN),
        ("Accuracy: {}", TP + TN, TP + FP + TN + FN),
    )
    for template, numerator, denominator in report:
        if denominator == 0:
            continue
        print(template.format(numerator / denominator))


def load_graph(graph_path):
    """Import a serialized GraphDef file into the current default graph.

    Every node's device placement is cleared. Ops that presumably stem from
    batch-norm moving-average updates are rewritten so the frozen graph can
    be imported: RefSwitch becomes Switch with inputs containing "moving_"
    redirected to their "/read" tensors, and AssignSub/AssignAdd become
    Sub/Add without the "use_locking" attribute. Nodes are imported with no
    name prefix.
    """
    graph_def = graph_pb2.GraphDef()
    with open(graph_path, "rb") as graph_file:
        graph_def.ParseFromString(graph_file.read())

    assign_replacements = {'AssignSub': 'Sub', 'AssignAdd': 'Add'}
    for node in graph_def.node:
        node.device = ""
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for idx in range(len(node.input)):
                input_name = node.input[idx]
                if 'moving_' in input_name:
                    node.input[idx] = input_name + '/read'
        elif node.op in assign_replacements:
            node.op = assign_replacements[node.op]
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    tf.import_graph_def(graph_def, name='')


def info():
    """Print the version string of the running Python interpreter."""
    version = sys.version
    print("Python version: {}".format(version))


def get_nv_param(name):
    """Look up Nirvana job parameter `name` via nv.context().get_parameters().get()."""
    job_parameters = nv.context().get_parameters()
    return job_parameters.get(name)


def get_nv_input(name):
    """Look up Nirvana job input `name` via nv.context().get_inputs().get()."""
    job_inputs = nv.context().get_inputs()
    return job_inputs.get(name)


def get_nv_output(name):
    """Look up Nirvana job output `name` via nv.context().get_outputs().get()."""
    job_outputs = nv.context().get_outputs()
    return job_outputs.get(name)


def init_nirvana_params():
    """Read all training hyper-parameters from the Nirvana job context.

    Each key of the returned dict is the Nirvana parameter name it was read
    from (via get_nv_param). The dict is echoed as indented JSON.
    """
    print("Initializing params on Nirvana")
    parameter_names = (
        'lr_base', 'lr_decay_rate', 'lr_decay_times', 'lr_decay_staircase',
        'optimizer', 'adam_beta1', 'adam_beta2', 'adam_epsilon', 'momentum',
        'iter_cnt', 'log_step', 'validate_on_n_log',
        'batch_size', 'validation_batch_size',
        'cross_validation', 'cross_validation_folds', 'data_split',
        'use_batch_norm', 'bn_decay', 'bn_epsilon',
        'keep_prob', 'random_state', 'shuffle_data',
    )
    params = {name: get_nv_param(name) for name in parameter_names}
    print(json.dumps(params, indent=4))
    return params


def init_local_params():
    """Build the hyper-parameter dict from YTF_* environment variables.

    Every value has a default, so the script runs with no variables set.
    Boolean variables treat "true", "t" or "1" (case-insensitive) as True.
    YTF_RANDOM_STATE is optional; when unset, "random_state" is None.
    The dict is echoed as indented JSON and returned.
    """
    print("Initializing params on local machine")

    def env_flag(name, default):
        # Case-insensitive "true" / "t" / "1" -> True, anything else -> False.
        return os.getenv(name, default).lower() in ('true', 't', '1')

    params = dict()
    params["lr_base"] = float(os.getenv('YTF_LR_BASE', '0.001'))
    params["lr_decay_rate"] = float(os.getenv('YTF_DECAY_RATE', '0.96'))
    params["lr_decay_times"] = float(os.getenv('YTF_DECAY_TIMES', '4'))
    params["lr_decay_staircase"] = env_flag('YTF_DECAY_STAIRCASE', 'false')
    params["optimizer"] = os.getenv('YTF_OPTIMIZER', 'Momentum')
    params["momentum"] = float(os.getenv('YTF_MOMENTUM', '0.9'))
    params["adam_beta1"] = float(os.getenv('YTF_ADAM_BETA1', '0.9'))
    params["adam_beta2"] = float(os.getenv('YTF_ADAM_BETA2', '0.999'))
    # Dedicated variable; must not be coupled to YTF_MOMENTUM.
    params["adam_epsilon"] = float(os.getenv('YTF_ADAM_EPSILON', '0.1'))
    params["iter_cnt"] = int(os.getenv('YTF_ITER_CNT', '2000000'))
    params["log_step"] = int(os.getenv('YTF_LOG_STEP', '1000'))
    params["validate_on_n_log"] = int(os.getenv('YTF_VALIDATE_ON_N_LOG', '10'))
    params["batch_size"] = int(os.getenv('YTF_BATCH_SIZE', '16'))
    # Dedicated variable; setting the training batch size must not change it.
    params["validation_batch_size"] = int(os.getenv('YTF_VALIDATION_BATCH_SIZE', '100'))
    params["cross_validation"] = env_flag('YTF_CROSS_VALIDATION', 'false')
    params["cross_validation_folds"] = int(os.getenv('YTF_CROSS_VALIDATION_FOLDS', '5'))
    params["data_split"] = float(os.getenv('YTF_DATA_SPLIT', '0.7'))
    params["use_batch_norm"] = env_flag('YTF_USE_BATCH_NORM', 'false')
    params["bn_decay"] = float(os.getenv('YTF_BN_DECAY', '0.9'))
    params["bn_epsilon"] = float(os.getenv('YTF_BN_EPSILON', '0.001'))
    params["keep_prob"] = float(os.getenv('YTF_KEEP_PROB', '0.5'))
    # optional param
    random_state = os.getenv('YTF_RANDOM_STATE')
    params["random_state"] = None if random_state is None else int(random_state)
    params["shuffle_data"] = env_flag('YTF_SHUFFLE_DATA', 'false')

    print(json.dumps(params, indent=4))
    return params


def _graph_filename(params, run):
    """Return the frozen-graph file name for 0-based training run `run`.

    Cross-validation runs get a "_cv<run + 1>" suffix; otherwise the
    un-suffixed name is used.
    """
    if params["cross_validation"]:
        return GRAPH_FILENAME.format("_cv{}".format(run + 1))
    return GRAPH_FILENAME.format("")


def _select_cv_fold(dataset, fold):
    """Recompute dataset's train/test split for 0-based cross-validation fold `fold`."""
    dataset.split_shift = fold
    dataset.batch_shift = 0
    dataset._apply_cv_shift()


def _export_graph(sess, filename):
    """Convert the variables feeding GRAPH_NAME to constants and write the GraphDef to filename."""
    graph_def = tf.get_default_graph().as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, [GRAPH_NAME])
    with gfile.GFile(filename, "wb") as f:
        f.write(output_graph_def.SerializeToString())


def main():
    """Train the road classifier, export frozen graph(s) and evaluate them.

    Hyper-parameters come from Nirvana when YTF_NIRVANA is set, otherwise
    from YTF_* environment variables. Training data is read from
    ./data/road and ./data/not_road. If ./testing exists, ./testing/road and
    ./testing/not_road form an extra set used for periodic accuracy reports
    during training and for a final evaluation. With cross-validation one
    model per fold is trained and exported with a "_cv<N>" suffix.
    """
    info()

    # Initialization
    if "YTF_NIRVANA" in os.environ:
        params = init_nirvana_params()
    else:
        params = init_local_params()
    iter_cnt = params["iter_cnt"]
    log_step = params["log_step"]
    batch_size = params["batch_size"]
    validate_n_log = params["validate_on_n_log"]

    np.random.seed(params["random_state"])
    # if there is a folder for testing it should be used to test a model
    testing_dataset = None
    if os.path.exists("./testing"):
        testing_dataset = Dataset("./testing/road", "./testing/not_road", {"cross_validation": False, "data_split": 0},
                                  shuffle=params["shuffle_data"])
    dataset = Dataset("./data/road", "./data/not_road", params, shuffle=params["shuffle_data"])

    runs = params["cross_validation_folds"] if params["cross_validation"] else 1
    for run in range(runs):
        if params["cross_validation"]:
            print("Cross-Validation: {} run...".format(run + 1))
            _select_cv_fold(dataset, run)
        graph_filename = _graph_filename(params, run)
        print("Building tensorflow graph...")
        tf.reset_default_graph()
        images_ph = tf.placeholder(tf.uint8, shape=[None, 224, 224, 3])
        # Scale uint8 pixels from [0, 255] to [-1, 1].
        images_phf = 2.0 * ((tf.to_float(images_ph) / 255.) - 0.5)
        bool_false = tf.constant(False, tf.bool)
        is_training = tf.placeholder_with_default(bool_false, shape=[], name="is_training")
        global_step = tf.Variable(0, trainable=False)
        lr = tf.train.exponential_decay(params["lr_base"],
                                        global_step,
                                        params["iter_cnt"] // params["lr_decay_times"],
                                        params["lr_decay_rate"],
                                        staircase=params["lr_decay_staircase"])

        optimizer_name = params["optimizer"].lower()
        if "momentum" in optimizer_name:
            optimizer = tf.train.MomentumOptimizer(lr, momentum=params["momentum"])
        elif "adam" in optimizer_name:
            optimizer = tf.train.AdamOptimizer(lr, params["adam_beta1"], params["adam_beta2"], params["adam_epsilon"])
        else:
            raise ValueError("Unknown optimizer type")

        bn_params = None
        if params["use_batch_norm"]:
            bn_params = {'decay': params['bn_decay'], 'epsilon': params["bn_epsilon"]}
        model_params = {"weight_decay": 0.0001,
                        "keep_prob": params["keep_prob"],
                        'bn_params': bn_params,
                        }
        logits = alexnet(images_phf, model_params, is_training)
        labels_ph = tf.placeholder(tf.float32, shape=[None, 2])

        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels_ph))
        loss_reg = loss
        if tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES):
            loss_reg += tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        train_step = optimizer.minimize(loss_reg, global_step=global_step)

        # Training
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # Inference tower fed through "inference_input" and exported as
            # GRAPH_NAME. NOTE(review): it is built after the initializer ran,
            # so it relies on tf_utils reusing the trained variables by name
            # — confirm in tf_utils.
            inference_input_ph = tf.placeholder(tf.uint8, shape=[None, 224, 224, 3], name="inference_input")
            inference_input_phf = 2.0 * ((tf.to_float(inference_input_ph) / 255.) - 0.5)
            inference_logits = alexnet(inference_input_phf, model_params, is_training)
            tf.identity(inference_logits, name=GRAPH_NAME)  # logits has shape of [None, 2]
            pbr = Progressbar(iter_cnt)
            print("Training started")
            for i in range(iter_cnt):
                images, labels = dataset.get_batch(batch_size)
                feed_dict = {images_ph: images, labels_ph: labels, is_training: True}
                _, batch_loss, batch_loss_reg = sess.run([train_step, loss, loss_reg], feed_dict=feed_dict)
                if i % log_step == 0:
                    print("Iteration: {}, loss: {}, loss reg: {}".format(i + 1, batch_loss, batch_loss_reg), end='')
                    if testing_dataset and validate_n_log and i % (log_step * validate_n_log) == 0:
                        TP, FN, TN, FP = model_evaluation(testing_dataset, params, existing_sess=sess, silent=True)
                        print(", acc: {}".format((TP + TN) / (TP + FN + TN + FP)))
                    else:
                        print()
                    # Periodic checkpoint of the frozen inference graph.
                    _export_graph(sess, graph_filename)
                    pbr.print_progress()
                pbr.update(1)
            # Export once more so the final weights are saved even when the
            # last iteration was not a logging step.
            _export_graph(sess, graph_filename)
            print("Training completed")

    # Evaluation
    print("Model evaluation on testing split started")
    for run in range(runs):
        if params["cross_validation"]:
            print("Cross-Validation: {} run...".format(run + 1))
            # Select the fold this model was trained without, so its test
            # split is really held out (training left the last fold active).
            _select_cv_fold(dataset, run)
        TP, FN, TN, FP = model_evaluation(dataset, params, graph_path=_graph_filename(params, run))
        print_metrics(TP, FN, TN, FP)
    del dataset

    if testing_dataset:
        print("Using testing dataset for model evaluation...")
        # With cross-validation only the per-fold "_cv<N>" graphs are written,
        # so evaluate each of them instead of a non-existent un-suffixed file.
        for run in range(runs):
            if params["cross_validation"]:
                print("Cross-Validation: {} run...".format(run + 1))
            TP, FN, TN, FP = model_evaluation(testing_dataset, params, graph_path=_graph_filename(params, run))
            print_metrics(TP, FN, TN, FP)


if __name__ == "__main__":
    main()
