import yt.wrapper as yt
try:
    import tensorflow as tf
    from tensorflow.contrib.keras import callbacks
    from tensorflow.contrib.keras import models
    from tensorflow.contrib.keras import backend as K
    from tensorflow.python.framework import graph_util
    from tensorflow.python.framework import graph_io
except ImportError as e:
    pass
import sys
import logging
try:
    from crypta.lib.nirvana.email_gender.generator import (
        generator_wrapper,
        MAX_LEN,
        DICT_SIZE,
        DOMAIN_DICT_SIZE,
    )
except ImportError as e:
    sys.path.append('generator')
    from crypta.lib.nirvana.email_organization.generator import (
        generator_wrapper,
        DICT_SIZE,
        MAX_LEN,
        DOMAIN_DICT_SIZE,
    )

from crypta.lib.nirvana.email_organization.app.conv_model import build_model

# Number of rows per batch handed to generator_wrapper; also used to turn
# row counts into steps per epoch.
BATCH_SIZE = 25

# Module-level logger so the functions in this file can log even when it is
# imported rather than executed as a script.
logger = logging.getLogger(__name__)


def get_row_counts(path, train_fraction=0.7):
    """Return a YT table's row count and the train/validation border row.

    ``get_generators`` uses rows before ``border`` for training and rows from
    ``border`` onward for validation.

    Args:
        path: YT path of the table; its ``row_count`` attribute is read.
        train_fraction: share of rows assigned to training, strictly between
            0 and 1. Defaults to 0.7, the previously hard-coded split.

    Returns:
        ``(row_count, border)`` with ``border = int(train_fraction * row_count)``.

    Raises:
        ValueError: if ``train_fraction`` is not strictly between 0 and 1.
    """
    if not 0 < train_fraction < 1:
        raise ValueError(
            'train_fraction must be strictly between 0 and 1, got %r'
            % (train_fraction,))
    row_count = int(yt.get_attribute(path, 'row_count'))
    border = int(train_fraction * row_count)
    logger.info('Found %d rows in %s. Border: %d', row_count, path, border)
    return row_count, border


def get_generators(dataset_path, border):
    """Create the training and validation batch generators for a YT table.

    The table is split at row ``border`` using YT row-index ranges: the
    training range covers rows ``#0`` up to ``#border`` and the validation
    range starts at ``#border``. Both ranged paths are logged before any
    generator is created.

    Returns:
        A ``(train_generator, validation_generator)`` tuple, each produced by
        ``generator_wrapper`` with ``BATCH_SIZE`` rows per batch.
    """
    train_path = '{path}[#0:#{border}]'.format(path=dataset_path, border=border)
    validation_path = '{path}[#{border}:]'.format(path=dataset_path, border=border)

    for message, ranged_path in (("Train path: %s", train_path),
                                 ("Validation path: %s", validation_path)):
        logger.info(message, ranged_path)

    return tuple(generator_wrapper(ranged_path, BATCH_SIZE)
                 for ranged_path in (train_path, validation_path))


def _build_callbacks(checkpoint_path):
    """Return the training callbacks; all of them monitor ``val_loss``.

    * EarlyStopping: stop after 5 epochs without a ``val_loss`` improvement
      of at least 1e-3.
    * ModelCheckpoint: save the full model (architecture and weights) to
      ``checkpoint_path`` whenever ``val_loss`` reaches a new best.
    * ReduceLROnPlateau: multiply the learning rate by 0.2 after 3 epochs
      without improvement.
    """
    early_stopping = callbacks.EarlyStopping(monitor='val_loss',
                                             min_delta=1e-3,
                                             patience=5,
                                             verbose=1)
    lr_reducer = callbacks.ReduceLROnPlateau(monitor='val_loss',
                                             factor=0.2,
                                             patience=3,
                                             verbose=1)
    checkpointer = callbacks.ModelCheckpoint(checkpoint_path,
                                             monitor='val_loss',
                                             verbose=1,
                                             save_best_only=True,
                                             save_weights_only=False)
    return [early_stopping, checkpointer, lr_reducer]


def learn_model(dataset_path, max_epochs=50):
    """Train the model from ``build_model`` on a YT table; return the best one.

    The table is split by ``get_row_counts``/``get_generators`` into a
    training prefix and a validation suffix. Training runs for at most
    ``max_epochs`` epochs. The model with the lowest validation loss is
    checkpointed to ``email-organizations.hdf5`` in the current working
    directory and reloaded from there.

    Args:
        dataset_path: YT path of the training table.
        max_epochs: upper bound on the number of training epochs (default 50).

    Returns:
        ``(model, history)``: the reloaded best model and the object returned
        by ``fit_generator``.

    Raises:
        ValueError: if the training or validation split holds fewer than
            ``BATCH_SIZE`` rows. Such a split would run zero steps per epoch,
            so no checkpoint would ever be written.
    """
    model = build_model(DICT_SIZE, MAX_LEN, DOMAIN_DICT_SIZE)
    model.compile(loss='binary_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])

    row_count, border = get_row_counts(dataset_path)
    # Only whole batches count as steps; a trailing partial batch is dropped.
    epoch_steps = border // BATCH_SIZE
    val_steps = (row_count - border) // BATCH_SIZE
    if epoch_steps == 0 or val_steps == 0:
        raise ValueError(
            'Dataset {path} is too small: {train} training / {val} validation '
            'rows, need at least {batch} rows in each split'.format(
                path=dataset_path, train=border, val=row_count - border,
                batch=BATCH_SIZE))

    train_generator, val_generator = get_generators(dataset_path, border)
    checkpoint_path = 'email-organizations.hdf5'
    hist = model.fit_generator(train_generator,
                               epochs=max_epochs,
                               verbose=1,
                               steps_per_epoch=epoch_steps,
                               validation_data=val_generator,
                               validation_steps=val_steps,
                               callbacks=_build_callbacks(checkpoint_path))

    # Return the best checkpoint rather than the last-epoch weights.
    return models.load_model(checkpoint_path), hist


def save_keras_model(keras_model, out_path):
    """Persist a Keras model as a YAML architecture file plus HDF5 weights.

    Writes ``<out_path>/email_model.yaml`` (the text from ``to_yaml()``) and
    ``<out_path>/model_weights.hdf5`` (via ``save_weights``).
    """
    architecture_file = '{base}/email_model.yaml'.format(base=out_path)
    weights_file = '{base}/model_weights.hdf5'.format(base=out_path)

    with open(architecture_file, 'w') as yaml_out:
        yaml_out.write(keras_model.to_yaml())

    keras_model.save_weights(weights_file)


def extract_const_tf_graph(keras_model):
    """Freeze the model's TensorFlow graph into a constant-only GraphDef.

    Switches Keras to learning phase 0 and adds identity ops named
    ``login_input``, ``domain_input`` and ``network_output``. These are placed
    on the model's two inputs and on its output. It then converts the
    variables of the global Keras session into constants, with
    ``network_output`` as the only output node.

    Args:
        keras_model: a Keras model whose ``input`` unpacks into exactly two
            tensors (login, domain) and whose ``output`` is a single tensor.

    Returns:
        The frozen ``GraphDef`` returned by
        ``graph_util.convert_variables_to_constants``.
    """
    # Learning phase 0 is Keras' inference (test) mode.
    # NOTE(review): this runs after the model was already built/loaded. Confirm
    # that any train/test-dependent layers in build_model (dropout, batch norm,
    # if present) are really frozen in inference mode. They must not leave a
    # learning-phase placeholder in the exported graph.
    K.set_learning_phase(0)
    login_input, domain_input = keras_model.input
    # Identity ops give the exported tensors stable, well-known names.
    # NOTE(review): convert_variables_to_constants keeps only the subgraph
    # needed to compute the listed output nodes. The two input identities are
    # not ancestors of 'network_output', so they are presumably pruned from the
    # frozen graph. Even if they were kept, feeding them would not reach the
    # model, which consumes the original input tensors. Confirm how consumers
    # of the exported graph locate the inputs.
    tf.identity(login_input, name='login_input')
    tf.identity(domain_input, name='domain_input')
    tf.identity(keras_model.output, name='network_output')
    # Freeze using the Keras session, which holds the trained variable values.
    sess = K.get_session()
    graph = sess.graph
    const_graph = \
        graph_util.convert_variables_to_constants(sess,
                                                  graph.as_graph_def(),
                                                  ['network_output'])
    return const_graph


def set_logHandler(logger, out_path):
    """Send ``logger`` output to ``<out_path>/debug.log`` and enable DEBUG.

    A ``logging.FileHandler`` for that file is attached to ``logger``. The
    logger's own level is then set to ``logging.DEBUG`` so that debug records
    are not filtered out before reaching the handler.
    """
    log_file = '{base}/debug.log'.format(base=out_path)
    file_handler = logging.FileHandler(log_file)
    logger.addHandler(file_handler)
    logger.setLevel(logging.DEBUG)


def main(dataset_path, out_path):
    """Train the model, log its history and export it under ``out_path``.

    Files written under ``out_path``:

    * ``email_model.yaml`` and ``model_weights.hdf5``, from ``save_keras_model``;
    * ``email.pb``, the frozen graph in binary GraphDef format.
    """
    logger.info('Tensorflow version: %s', tf.__version__)
    model, hist = learn_model(dataset_path)
    # Log the history as soon as training finishes, so it is not lost if one
    # of the export steps below fails.
    # str() is kept on purpose: passing a lone dict as the logging argument
    # would make logging treat it as a %-format mapping.
    logger.info('History : %s\n', str(hist.history))
    save_keras_model(model, out_path)
    const_graph = extract_const_tf_graph(model)
    # write_graph joins logdir and name itself, so pass them separately.
    graph_io.write_graph(const_graph, out_path, 'email.pb', as_text=False)


if __name__ == '__main__':
    # Usage: <script> DATASET_YT_PATH OUT_DIR
    # Exit with a usage message instead of an IndexError when arguments are
    # missing. Extra arguments are still ignored, as before.
    if len(sys.argv) < 3:
        sys.exit('usage: {} DATASET_YT_PATH OUT_DIR'.format(sys.argv[0]))
    # At module level this binds the global name directly, so no `global`
    # statement is needed.
    logger = logging.getLogger(__name__)
    set_logHandler(logger, sys.argv[2])
    main(sys.argv[1], sys.argv[2])
