import os

import requests
import matplotlib.pyplot as plt
import numpy as np

from nirvana_api.highlevel_api import WorkflowInstance
from nirvana_api import NirvanaApi


def parse_wf_id(url):
    """Extract the workflow id from a workflow URL.

    The id is the fifth ``'/'``-separated segment, i.e. for a URL shaped like
    ``scheme://host/<section>/<wf_id>/...`` this returns ``<wf_id>``.
    Raises IndexError if the URL has fewer than five segments.
    """
    # Splitting at most 5 times is enough to isolate segment index 4.
    segments = url.split('/', 5)
    return segments[4]


def get_names_urls(wf_id, nirvana_token=None, status='running'):
    """Return ``(instanceComment, storagePath)`` pairs for a workflow's instances.

    Every instance of workflow *wf_id* whose execution-state ``status`` equals
    *status* is inspected. The ``stdout.log`` of its first block result is
    looked up, and if at least one log entry comes back, the instance comment
    and that first entry's ``storagePath`` are collected.

    Args:
        wf_id: Nirvana workflow id.
        nirvana_token: API token; when falsy, ``os.environ['NIRVANA_TOKEN']``
            is used instead (KeyError if that variable is unset).
        status: execution status an instance must have to be included.

    Returns:
        list of ``(comment, storage_path)`` tuples, in ``find_workflows`` order.
    """
    token = nirvana_token if nirvana_token else os.environ['NIRVANA_TOKEN']
    api = NirvanaApi(token)

    pairs = []
    for instance_params in api.find_workflows(wf_id):
        instance_id = instance_params['instanceId']
        instance = WorkflowInstance(api, instance_id)

        if instance.get_execution_state()['status'] != status:
            continue

        # NOTE(review): only the first block's log is read, and an instance with
        # no block results would raise IndexError here -- confirm that is intended.
        first_block = instance.get_block_results()[0]
        log_entries = api.get_block_logs(
            workflow_id=wf_id,
            workflow_instance_id=instance_id,
            log_names=['stdout.log'],
            block_patterns=[{'guid': first_block['blockGuid']}]
        )[0]['logs']

        if log_entries:
            pairs.append((instance_params['instanceComment'], log_entries[0]['storagePath']))

    return pairs


def _parse_torch_log(lines):
    """Extract training and validation metrics from an iterable of log lines.

    Recognised lines:

    * a training row with exactly 13 ``|``-separated fields: field 0 is the
      integer x position, field 1 an integer object count, and field 4 the
      training loss with one wrapping character on each side;
    * a line containing ``'Validation loss Loss: '`` followed by a float;
    * a line containing ``'Validation AUC '`` followed by a float.

    A validation value is placed at the x of the next training row. If it is
    reported after the last training row, it is placed at that row's x.

    Lines may be ``str`` or ``bytes``; bytes are decoded as UTF-8.

    Returns:
        ``(train_loss, val_loss, val_auc, train_loss_x, val_loss_x, val_auc_x,
        min_objs, max_objs)``. ``min_objs``/``max_objs`` are the smallest and
        largest object counts seen, or ``(0, 0)`` if there were no training rows.
    """
    train_loss, val_loss, val_auc = [], [], []
    train_loss_x, val_loss_x, val_auc_x = [], [], []
    min_objs = max_objs = None
    x = None  # x of the most recent training row; None until one is seen

    for line in lines:
        if isinstance(line, bytes):
            # requests' iter_lines() yields bytes on Python 3; str separators would fail on them.
            line = line.decode('utf-8', 'replace')

        fields = line.split('|')
        if len(fields) == 13:
            train_loss.append(float(fields[4][1:-1]))

            x = int(fields[0].strip())
            train_loss_x.append(x)
            # Attach any validation value reported since the previous row to this row.
            if len(val_loss) > len(val_loss_x):
                val_loss_x.append(x)
            if len(val_auc) > len(val_auc_x):
                val_auc_x.append(x)

            total_objs = int(fields[1].strip())
            if min_objs is None:
                # Seed from the first row; starting at 0 would pin the minimum to 0.
                min_objs = max_objs = total_objs
            else:
                min_objs = min(min_objs, total_objs)
                max_objs = max(max_objs, total_objs)

        parts = line.split('Validation loss Loss: ')
        if len(parts) == 2:
            val_loss.append(float(parts[-1]))

        parts = line.split('Validation AUC ')
        if len(parts) == 2:
            val_auc.append(float(parts[-1]))

    # Validation reported after the last training row: pin it to that row's x.
    # Without any training row there is no x to use (the original raised NameError).
    if x is not None:
        if len(val_loss) > len(val_loss_x):
            val_loss_x.append(x)
        if len(val_auc) > len(val_auc_x):
            val_auc_x.append(x)

    if min_objs is None:
        min_objs = max_objs = 0

    return (
        train_loss, val_loss, val_auc,
        train_loss_x, val_loss_x, val_auc_x,
        min_objs, max_objs
    )


def load_torch_data(url):
    """Download a training log from *url* and parse its learning-curve metrics.

    Returns:
        8-tuple ``(train_loss, val_loss, val_auc, train_loss_x, val_loss_x,
        val_auc_x, min_objs, max_objs)``; see ``_parse_torch_log`` for the
        recognised line formats.
    """
    # The context manager closes the streamed connection once parsing is done.
    with requests.get(url, stream=True) as response:
        return _parse_torch_log(response.iter_lines())


def _mark_extremum(ax, x, y, label, color, **text_kwargs):
    """Mark (x, y) on *ax* with an 'x', a bold right-aligned label and a dashed guide at y."""
    ax.plot(x, y, marker='x', markersize=5, color=color)
    ax.text(x, y, label, horizontalalignment='right', fontsize=14,
            weight='semibold', **text_kwargs)
    ax.axhline(y=y, color='grey', linestyle='--', alpha=0.7, linewidth=0.5)


def plot_curve(ax, name, data):
    """Draw one run's learning curves on a matplotlib Axes.

    Args:
        ax: matplotlib Axes to draw on.
        name: subplot title.
        data: 8-tuple ``(train_loss, val_loss, val_auc, train_loss_x,
            val_loss_x, val_auc_x, min_objs, max_objs)`` as returned by
            ``load_torch_data``.

    The three curves are drawn first, in train-loss / validation-loss / AUC
    order. That order is kept because ``plot_curves`` builds its legend from
    labels only.

    The best AUC (max) and the best validation loss (min) are each marked with
    an 'x', a value label and a dashed horizontal guide. The last training x
    is shown in the top-left corner as "<x> sec.".

    Empty series are tolerated: the related ticks or markers are skipped.
    """
    (
        train_loss, val_loss, val_auc,
        train_loss_x, val_loss_x, val_auc_x,
        min_objs, max_objs
    ) = data

    ax.set_title(name, fontweight='semibold')
    # Plain '-' format: a colour letter in fmt (previously 'b-') conflicts with
    # the color= keyword and makes matplotlib warn about a redundant colour.
    ax.plot(train_loss_x, train_loss, '-', linewidth=0.5, color='green', alpha=0.3)
    ax.plot(val_loss_x, val_loss, '-', linewidth=0.5, color='blue')
    ax.plot(val_auc_x, val_auc, '-', linewidth=0.5, color='red')

    if len(train_loss_x) > 0:
        # Only the first and last x get ticks, labelled with the object-count range.
        ax.set_xticks([train_loss_x[0], train_loss_x[-1]])
        ax.set_xticklabels([min_objs, max_objs])

    # np.argmax / np.argmin raise on empty input, so skip markers when a log
    # has no validation output.
    if len(val_auc) > 0:
        best = np.argmax(val_auc)
        max_auc = val_auc[best]
        _mark_extremum(ax, val_auc_x[best], max_auc, round(max_auc * 100, 2), 'red')

    if len(val_loss) > 0:
        best = np.argmin(val_loss)
        min_loss = val_loss[best]
        _mark_extremum(ax, val_loss_x[best], min_loss, round(min_loss, 4), 'blue',
                       verticalalignment='top')

    if len(train_loss_x) > 0:
        ax.text(0.05, 0.95, str(train_loss_x[-1]) + ' sec.', ha='left', va='top',
                transform=ax.transAxes, fontsize=14)


def plot_curves(url, width=3, figsize=None, dpi=200, ylim=(0.5, 0.8), title='Learning curves',
                load_data=load_torch_data, status='running', nirvana_token=None):
    """Plot a grid of learning curves, one subplot per matching workflow instance.

    Args:
        url: workflow URL; the workflow id is taken from it with ``parse_wf_id``.
        width: number of subplot columns.
        figsize: figure size; defaults to ``(6 * width, 5 * rows)``.
        dpi: figure resolution.
        ylim: y-axis limits applied to every subplot.
        title: figure title.
        load_data: callable mapping a log URL to the 8-tuple ``plot_curve``
            expects.
        status: execution status an instance must have; forwarded to
            ``get_names_urls``.
        nirvana_token: API token forwarded to ``get_names_urls``, which falls
            back to the ``NIRVANA_TOKEN`` environment variable when it is falsy.

    Raises:
        ValueError: if no instance with the requested status has a log to plot.
    """
    wf_id = parse_wf_id(url)
    names_urls = get_names_urls(wf_id, nirvana_token=nirvana_token, status=status)

    nplots = len(names_urls)
    if nplots == 0:
        raise ValueError(
            'No workflow instances with status {!r} and a stdout log found for workflow {}'.format(
                status, wf_id))
    height = -(-nplots // width)  # ceiling division: rows needed for nplots cells

    if figsize is None:
        figsize = (6 * width, 5 * height)

    # squeeze=False always returns a 2-D array of Axes. Otherwise a 1x1 grid
    # gives a bare Axes and an Nx1 grid gives a 1-D array, and indexing breaks.
    fig, axes = plt.subplots(height, width, figsize=figsize, dpi=dpi, squeeze=False)
    flat_axes = axes.ravel()  # row-major, same order as the original [i // width][i % width]
    plt.setp(flat_axes, ylim=ylim)
    fig.suptitle(title, fontsize=20, fontweight='bold')

    for ax, (name, log_url) in zip(flat_axes, names_urls):
        plot_curve(ax, name, load_data(log_url))

    # Drop the unused trailing cells of the last row.
    for ax in flat_axes[nplots:]:
        fig.delaxes(ax)

    fig.legend(['Train Loss', 'Validation Loss', 'AUC'])

    plt.show()
