import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
from sklearn.neighbors import KNeighborsClassifier
from dataclasses import dataclass
from customer_service.ml.lib.data.utils import (
    load_data,
    save_data
)


@dataclass
class DocEmbedItem:
    """A single document together with its embedding.

    Instances are created by ``generate_doc_embed_data``, one per dataframe
    row, from that row's text, url and embedding columns.
    """
    # Document text, taken from the row's text column.
    text: str
    # Value of the row's url column (the default column name there is
    # 'knowledge_id', so this may be an identifier rather than a literal
    # URL -- TODO confirm).
    url: str
    # Embedding stored for the document (presumably a 1-D vector -- confirm
    # against whatever produces the embedding column).
    embed: np.ndarray


@dataclass
class DocEmbedData:
    """Document items together with their embeddings stacked into one array.

    Produced by ``generate_doc_embed_data``.
    """
    # Maps each target value to its DocEmbedItem.
    items: dict
    # Array built from the item embeddings, taken in key order
    # 0, 1, ..., len(items) - 1.
    know_emb_matrix: np.ndarray


def get_target_encoder(df, document_col):
    """Return a ``LabelEncoder`` fitted on the values of ``df[document_col]``.

    Args:
        df: Object indexable by ``document_col`` (a dataframe in this module).
        document_col: Name of the column whose values become the classes.

    Returns:
        The fitted ``LabelEncoder``.
    """
    documents = df[document_col]
    # ``fit`` returns the encoder itself, so the call can be chained.
    return LabelEncoder().fit(documents)


def generate_doc_embed_data(df, target_col='target', text_col='know',
                            url_col='knowledge_id', embedding_col='know_emb'):
    """Build a ``DocEmbedData`` lookup from the rows of ``df``.

    Args:
        df: Dataframe holding one document per row.
        target_col: Column whose value is used as the dictionary key.
        text_col: Column holding the document text.
        url_col: Column holding the document url / identifier.
        embedding_col: Column holding the document embedding.

    Returns:
        A ``DocEmbedData`` whose ``items`` maps each target value to a
        ``DocEmbedItem`` and whose ``know_emb_matrix`` holds the embeddings
        ordered by key.
    """
    items = {}
    # Rows sharing a target value overwrite each other; the last one wins.
    for _, record in tqdm(df.iterrows()):
        entry = DocEmbedItem(
            record[text_col],
            record[url_col],
            record[embedding_col],
        )
        items[record[target_col]] = entry

    # Row i of the matrix is the embedding stored under key i, so the keys
    # have to be exactly 0 .. len(items) - 1; a gap raises KeyError here.
    label_count = len(items)
    embeds_by_label = [items[label].embed for label in range(label_count)]
    know_emb_matrix = np.array(embeds_by_label)

    return DocEmbedData(items, know_emb_matrix)


def generate_query_embed_matrix(df, embedding_col='dialog_emb'):
    """Collect the values of ``df[embedding_col]`` into a single numpy array.

    Args:
        df: Object indexable by ``embedding_col`` (a dataframe in this module).
        embedding_col: Name of the column holding one embedding per row.

    Returns:
        ``np.ndarray`` built from the list of column values, in row order.
    """
    embeddings = df[embedding_col]
    # Materialise the column as a plain list first so numpy builds the array
    # from the individual row values.
    rows = list(embeddings)
    return np.array(rows)


def closest_document(Q, D):
    """Return, for each query, the index of the document with the largest
    inner product.

    Args:
        Q: Query embeddings; multiplied on the left of ``D.T``.
        D: Document embeddings; must expose ``.T``.

    Returns:
        The result of ``np.argmax`` over the last axis of ``Q @ D.T``.
    """
    scores = Q @ D.T
    return np.argmax(scores, axis=-1)


def get_classifier(n_neighbors=10):
    """Create the unfitted k-nearest-neighbours classifier used for queries.

    Args:
        n_neighbors: Number of neighbours consulted for each prediction.
            Defaults to 10, the value that used to be hard-coded, so existing
            callers that pass no argument get the same classifier as before.

    Returns:
        A new, unfitted ``KNeighborsClassifier``.
    """
    return KNeighborsClassifier(n_neighbors=n_neighbors)


def classify_queries_embeddings(
        train_input, val_input, test_input,
        train_output, val_output, test_output,
        doc_col='know',
        url_col='knowldge_id',
        query_col='dialog',
        doc_embedding_col='know_emb',
        query_embedding_col='dialog_emb',
        target_col='y_true',
        predict_col='y_pred'):
    """Fit a classifier on train query embeddings and label val/test queries.

    The three splits are loaded, documents are label-encoded over the union
    of the splits, the classifier from ``get_classifier`` is fitted on the
    train query embeddings, and the val/test frames receive the predicted
    label plus the text and url of the predicted document.

    Args:
        train_input, val_input, test_input: Sources passed to ``load_data``.
        train_output, val_output, test_output: Destinations passed to
            ``save_data``; a falsy value skips saving that split.
        doc_col: Column holding the document text; also the encoder input.
        url_col: Column holding the document url / identifier.
            NOTE(review): the default 'knowldge_id' differs from the
            'knowledge_id' default of ``generate_doc_embed_data`` and looks
            like a typo -- confirm against the real column name.
        query_col: Accepted but not used by this function.
        doc_embedding_col: Column holding the document embedding.
        query_embedding_col: Column holding the query embedding.
        target_col: Column that receives the encoded document label.
        predict_col: Column that receives the predicted label.

    Returns:
        None. Results are written through ``save_data``. The train frame is
        saved with the target column only; no predictions are added to it.
    """
    splits = [load_data(source) for source in (train_input, val_input, test_input)]
    df_train, df_val, df_test = splits
    combined = pd.concat(splits)

    # The encoder is fitted on the union so every split maps documents to the
    # same integer labels.
    encoder = get_target_encoder(combined, doc_col)
    for frame in (*splits, combined):
        frame[target_col] = encoder.transform(frame[doc_col])

    doc_embed_data = generate_doc_embed_data(
        combined, target_col, doc_col, url_col, doc_embedding_col
    )

    train_queries, val_queries, test_queries = [
        generate_query_embed_matrix(frame, embedding_col=query_embedding_col)
        for frame in splits
    ]

    clf = get_classifier()
    clf.fit(train_queries, df_train[target_col])

    evaluated = ((df_val, val_queries), (df_test, test_queries))
    for frame, queries in evaluated:
        frame[predict_col] = clf.predict(queries)

    # Translate each predicted label back into the document's text and url.
    lookup = doc_embed_data.items
    for out_col, attr in (('pred_know', 'text'), ('pred_url', 'url')):
        for frame, _ in evaluated:
            frame[out_col] = frame[predict_col].apply(
                lambda label: getattr(lookup[label], attr)
            )

    destinations = (train_output, val_output, test_output)
    for destination, frame in zip(destinations, splits):
        if not destination:
            continue
        save_data(frame, destination)
