import os
import tarfile

import numpy as np
import tensorflow as tf
import nirvana_dl
from ..data_preprocessors.faster_rcnn_preprocessor import FasterRCNNPreprocessor

from ...faster_rcnn_impl.model import factory


class SaveModelAndPrepareGDefs(tf.keras.callbacks.Callback):
    """Keras callback that exports the model every epoch and archives gdefs.

    After each epoch the attached model's ``export_for_serving`` is called
    with ``<output_dir>/<epoch>``. When training finishes, every entry found
    in ``<output_dir>/gdefs`` is added to ``<output_dir>/gdefs.tar.gz``.
    """

    def __init__(self, output_dir):
        """Store the directory that exports and the archive are written to."""
        super().__init__()
        self.output_dir = output_dir

    def on_epoch_end(self, epoch, logs=None):
        """Export the current model into a sub-directory named after ``epoch``."""
        export_path = os.path.join(self.output_dir, str(epoch))
        self.model.export_for_serving(export_path)

    def on_train_end(self, logs=None):
        """Pack the contents of the ``gdefs`` directory into ``gdefs.tar.gz``."""
        gdefs_dir = os.path.join(self.output_dir, "gdefs")
        archive_path = os.path.join(self.output_dir, "gdefs.tar.gz")
        with tarfile.open(archive_path, "w|gz") as archive:
            for entry in os.listdir(gdefs_dir):
                # arcname places each entry at the archive root instead of
                # under its full on-disk path.
                archive.add(os.path.join(gdefs_dir, entry), arcname=entry)


class NirvanaSaveState(tf.keras.callbacks.Callback):
    """Keras callback that dumps a Nirvana snapshot after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Call ``nirvana_dl.snapshot.dump_snapshot``; both arguments are unused."""
        snapshot_api = nirvana_dl.snapshot
        snapshot_api.dump_snapshot()


def lr_scheduler(epoch, lr, *, hold_epochs=7, decay_factor=0.5):
    """Return the learning rate to use for ``epoch``.

    The function is stateless: it returns ``lr`` unchanged while
    ``epoch < hold_epochs`` and ``lr * decay_factor`` otherwise, scaling
    whatever ``lr`` value it is handed.

    Args:
        epoch: Zero-based epoch index.
        lr: Learning rate supplied by the caller.
        hold_epochs: Number of initial epochs during which ``lr`` is returned
            untouched. Defaults to 7, the previously hard-coded value.
        decay_factor: Multiplier applied to ``lr`` once ``hold_epochs`` is
            reached. Defaults to 0.5, the previously hard-coded value.

    Returns:
        The (possibly scaled) learning rate.
    """
    if epoch < hold_epochs:
        return lr

    return lr * decay_factor


class FasterRCNN():
    """Training driver that builds, compiles and fits a Faster R-CNN model.

    Configuration is read from ``preprocessor.params`` and all filesystem
    locations come from ``preprocessor.io_wrapper``. The model is created
    lazily under a ``tf.distribute.MirroredStrategy`` scope.
    """

    # Maps the ``params["optimizer"]`` string to the Keras optimizer class.
    OPTIMIZERS = {
        'adam': tf.keras.optimizers.Adam,
        'rmsprop': tf.keras.optimizers.RMSprop,
        'sgd': tf.keras.optimizers.SGD
    }

    def __init__(
        self, preprocessor: FasterRCNNPreprocessor
    ) -> None:
        """Set up the distribution strategy, LR schedule and callbacks.

        Args:
            preprocessor: Supplies ``params`` (reads ``train_size``,
                ``batch_size``, ``lr_boundaries`` and ``lr_values`` here),
                the ``io_wrapper`` and the data generators used by ``train``.
        """
        self.preprocessor = preprocessor
        self.wrapper = self.preprocessor.io_wrapper
        self.model = None  # built on first call to prepare_model_for_training
        self.strategy = tf.distribute.MirroredStrategy()

        # Number of full batches per epoch; any remainder is dropped.
        self.train_steps = self.preprocessor.params["train_size"] // self.preprocessor.params["batch_size"]
        # Boundaries are scaled by steps-per-epoch before being handed to the
        # step-based schedule (presumably they are configured in epochs —
        # TODO confirm against the params schema).
        self.lr_scheduler = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            boundaries=list(np.asarray(self.preprocessor.params["lr_boundaries"]) * self.train_steps),
            values=self.preprocessor.params["lr_values"]
        )
        self.callbacks = [
            tf.keras.callbacks.TensorBoard(
                log_dir=self.wrapper.get_logs_dir()
            ),
            tf.keras.callbacks.experimental.BackupAndRestore(
                backup_dir=self.wrapper.get_checkpoint_dir()
            ),
            SaveModelAndPrepareGDefs(self.wrapper.get_output_dir())
        ]
        # Snapshot dumping is only added when running with the "nirvana" IO type.
        if self.wrapper.io_type == "nirvana":
            self.callbacks.append(NirvanaSaveState())

    def prepare_model_for_training(self) -> tf.keras.Model:
        """Build the model if needed, mark it trainable and return it.

        Returns:
            The model stored on ``self.model``.
        """
        if self.model is None:
            self.model = self.build_model()
        self.model.trainable = True
        # Return the model so the behaviour matches the declared return type;
        # previously the method implicitly returned None.
        return self.model

    def train(self) -> None:
        """Fit the model on the training data, optionally with validation.

        Validation is enabled only when ``params["has_valid"] is True``; in
        that case the number of validation steps is
        ``params["test_size"] // params["batch_size"]``.
        """
        self.prepare_model_for_training()
        params = self.preprocessor.params
        train_gen = self.strategy.experimental_distribute_dataset(self.preprocessor.get_train_generator())
        val_gen = None
        # None (the Keras default) rather than 0 when there is no validation
        # data, so no meaningless step count is passed to fit().
        valid_steps = None
        if params["has_valid"] is True:
            val_gen = self.strategy.experimental_distribute_dataset(self.preprocessor.get_valid_generator())
            valid_steps = params["test_size"] // params["batch_size"]

        self.model.fit(
            train_gen,
            epochs=params["epochs"],
            steps_per_epoch=self.train_steps,
            callbacks=self.callbacks,
            validation_data=val_gen,
            validation_steps=valid_steps,
            verbose=2
        )

    def build_model(self) -> tf.keras.Model:
        """Create and compile the model inside the distribution strategy scope.

        The factory receives ``preprocessor.params`` and the list of keys of
        ``preprocessor.str_to_ind``.

        Returns:
            The compiled model.
        """
        with self.strategy.scope():
            model = factory.build_model(self.preprocessor.params, list(self.preprocessor.str_to_ind.keys()))

            model.compile(
                loss=self.get_loss_fn(),
                optimizer=self.get_optimizer_fn(),
                metrics=[]
            )
        return model

    def get_loss_fn(self) -> None:
        """Return the loss passed to ``compile``; always ``None`` here.

        NOTE(review): presumably the model produced by the factory defines
        its own losses internally — confirm against the factory.
        """
        return None

    def get_optimizer_fn(self) -> tf.keras.optimizers.Optimizer:
        """Instantiate the optimizer named by ``params["optimizer"]``.

        The optimizer is created with the piecewise-constant schedule built
        in ``__init__`` as its learning rate.

        Raises:
            KeyError: If the configured name is not a key of ``OPTIMIZERS``.
        """
        optimizer_cls = self.OPTIMIZERS[self.preprocessor.params["optimizer"]]
        return optimizer_cls(learning_rate=self.lr_scheduler)
