import sys
import os.path
import tensorflow as tf
from tensorflow.core.protobuf import saver_pb2

class SaverWithEvents(tf.train.Saver):
    """A ``tf.train.Saver`` that calls optional hooks around every ``save``.

    Args:
        before_save_event: Optional callable taking no arguments; invoked
            right before each checkpoint is written.
        after_save_event: Optional callable taking one argument; invoked after
            the base ``save`` returns, with the path prefix it returned.
        var_list, reshape, sharded, max_to_keep,
        keep_checkpoint_every_n_hours, name, restore_sequentially, saver_def,
        builder, defer_build, allow_empty, write_version, pad_step_number,
        save_relative_paths, filename: Forwarded unchanged to
            ``tf.train.Saver.__init__``.
        **kwargs: Any further keyword arguments supported by the installed
            TensorFlow's ``Saver`` constructor; forwarded unchanged so this
            subclass does not reject options the base class accepts.
    """

    def __init__(self,
                 before_save_event=None,
                 after_save_event=None,
                 var_list=None,
                 reshape=False,
                 sharded=False,
                 max_to_keep=5,
                 keep_checkpoint_every_n_hours=10000.0,
                 name=None,
                 restore_sequentially=False,
                 saver_def=None,
                 builder=None,
                 defer_build=False,
                 allow_empty=False,
                 write_version=saver_pb2.SaverDef.V2,
                 pad_step_number=False,
                 save_relative_paths=False,
                 filename=None,
                 **kwargs):
        self._before_save_event = before_save_event
        self._after_save_event = after_save_event
        super(SaverWithEvents, self).__init__(
            var_list=var_list,
            reshape=reshape,
            sharded=sharded,
            max_to_keep=max_to_keep,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            name=name,
            restore_sequentially=restore_sequentially,
            saver_def=saver_def,
            builder=builder,
            defer_build=defer_build,
            allow_empty=allow_empty,
            write_version=write_version,
            pad_step_number=pad_step_number,
            save_relative_paths=save_relative_paths,
            filename=filename,
            **kwargs)

    def save(self, sess, save_path, global_step=None, latest_filename=None,
             meta_graph_suffix='meta', write_meta_graph=True,
             write_state=True, **kwargs):
        """Write a checkpoint, firing the before/after hooks around it.

        Parameters mirror ``tf.train.Saver.save``. Extra keyword arguments
        (for example ``strip_default_attrs`` in newer TensorFlow 1.x
        releases) are forwarded to the base implementation rather than
        raising ``TypeError``.

        Returns:
            Whatever ``tf.train.Saver.save`` returned (the checkpoint path
            prefix), unchanged. The same value is passed to
            ``after_save_event``.
        """
        if self._before_save_event is not None:
            self._before_save_event()
        # Keyword arguments so nothing is misrouted if the base signature's
        # positional order differs between TensorFlow versions.
        path_prefix = super(SaverWithEvents, self).save(
            sess,
            save_path,
            global_step=global_step,
            latest_filename=latest_filename,
            meta_graph_suffix=meta_graph_suffix,
            write_meta_graph=write_meta_graph,
            write_state=write_state,
            **kwargs)
        # Only reached when the base save did not raise, so the after-hook
        # never fires for a failed save (the before-hook already has).
        if self._after_save_event is not None:
            self._after_save_event(path_prefix)
        return path_prefix