from dataclasses import dataclass
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    f1_score,
    precision_score,
    recall_score
)
from .preprocessing import load_data


@dataclass
class Metrics:
    """Container for the classification scores computed by ``calculate_metrics``.

    Fields are positional in the order declared below, so constructing the
    object positionally must follow the same order.
    """

    # Share of predictions that exactly match the true label.
    accuracy: float
    # Mean of the per-class recalls (scikit-learn's balanced accuracy).
    balanced_accuracy: float
    # Precision, averaged across classes per the caller's averaging strategy.
    precision: float
    # Recall, averaged across classes per the caller's averaging strategy.
    recall: float
    # F1 score, averaged across classes per the caller's averaging strategy.
    f1: float


def calculate_metrics(y_true, y_pred, average_type='weighted', n_round=4,
                      zero_division=0) -> Metrics:
    """Compute standard classification metrics for a set of predictions.

    Args:
        y_true: Ground-truth labels, in any array-like form accepted by
            scikit-learn's metric functions.
        y_pred: Predicted labels, aligned element-wise with ``y_true``.
        average_type: Passed as ``average`` to ``precision_score``,
            ``recall_score`` and ``f1_score`` (e.g. ``'weighted'``,
            ``'macro'``, ``'micro'``, ``'binary'``). Accuracy and balanced
            accuracy do not use it.
        n_round: Number of decimal places each metric is rounded to.
        zero_division: Passed as ``zero_division`` to the precision, recall
            and F1 scorers; the value they report when a score is undefined.
            Defaults to ``0``, the value previously hard-coded here.

    Returns:
        A ``Metrics`` instance whose fields are built-in ``float`` values
        rounded to ``n_round`` decimal places.
    """
    averaged = {'average': average_type, 'zero_division': zero_division}

    raw_scores = {
        'accuracy': accuracy_score(y_true, y_pred),
        'balanced_accuracy': balanced_accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, **averaged),
        'recall': recall_score(y_true, y_pred, **averaged),
        'f1': f1_score(y_true, y_pred, **averaged),
    }

    # scikit-learn returns NumPy scalars; convert to built-in float so the
    # values match the ``float`` annotations on ``Metrics`` (and display as
    # plain numbers instead of ``np.float64(...)`` under NumPy 2).
    # Keyword construction keeps this correct even if the fields of
    # ``Metrics`` are reordered.
    return Metrics(**{
        name: round(float(score), n_round)
        for name, score in raw_scores.items()
    })


def get_metrics(dataset, true_col='y_true', pred_col='y_pred', *,
                average_type='weighted', n_round=4) -> Metrics:
    """Load a dataset and compute classification metrics from two of its columns.

    Args:
        dataset: Passed unchanged to ``load_data``. Presumably a path or
            dataset identifier; confirm against ``preprocessing.load_data``.
        true_col: Name of the column holding the ground-truth labels.
        pred_col: Name of the column holding the predicted labels.
        average_type: Averaging strategy forwarded to ``calculate_metrics``
            for precision, recall and F1.
        n_round: Number of decimal places forwarded to ``calculate_metrics``.

    Returns:
        The ``Metrics`` produced by ``calculate_metrics``.
    """
    # NOTE(review): assumes load_data returns a column-indexable frame
    # (e.g. a pandas DataFrame); a missing column raises from that lookup.
    df = load_data(dataset)
    return calculate_metrics(
        df[true_col],
        df[pred_col],
        average_type=average_type,
        n_round=n_round,
    )
