import logging
from typing import Optional

import click

from customer_service.ml.chats.zeliboba.lib.classifier import classify_queries_embeddings

# Configure the root logger to emit INFO and above. basicConfig is a no-op
# if the root logger already has handlers installed.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


@click.command()
@click.option('--input_data_train', required=True, help='Path to YT train dataset with embeddings')
@click.option('--input_data_val', required=True, help='Path to YT val dataset with embeddings')
@click.option('--input_data_test', required=True, help='Path to YT test dataset with embeddings')
# The three output options are optional and have no default, so click passes
# None for any that is omitted.
@click.option('--output_data_train', required=False, help='Path to YT train dataset with classification results')
@click.option('--output_data_val', required=False, help='Path to YT val dataset with classification results')
@click.option('--output_data_test', required=False, help='Path to YT test dataset with classification results')
@click.option('--doc_col', default='know', help='Name of document column')
@click.option('--url_col', default='knowledge_id', help='Name of knowledge url column')
@click.option('--query_col', default='dialog', help='Name of query column')
@click.option('--doc_embedding_col', default='know_emb', help='Name of document embedding column')
@click.option('--query_embedding_col', default='dialog_emb', help='Name of query embedding column')
@click.option('--target_col', default='y_true', help='Name of column with ground truth values')
@click.option('--predict_col', default='y_pred', help='Name of column with predicted values')
def main(input_data_train: str,
         input_data_val: str,
         input_data_test: str,
         output_data_train: Optional[str],
         output_data_val: Optional[str],
         output_data_test: Optional[str],
         doc_col: str,
         url_col: str,
         query_col: str,
         doc_embedding_col: str,
         query_embedding_col: str,
         target_col: str,
         predict_col: str) -> None:
    """Classify query embeddings of the train, val and test YT datasets.

    Thin command-line wrapper: every option is forwarded unchanged to
    classify_queries_embeddings, which performs the actual work.
    """
    # NOTE: click shows the docstring above as the command description in
    # --help; each parameter is documented by the help= text of its option.
    # NOTE(review): how classify_queries_embeddings treats an output path of
    # None is not visible from this file -- confirm against its implementation.
    logger.info('Classification of embeddings')

    # Arguments are passed positionally, so their order must match the
    # parameter order of classify_queries_embeddings exactly.
    classify_queries_embeddings(
        input_data_train, input_data_val, input_data_test,
        output_data_train, output_data_val, output_data_test,
        doc_col, url_col, query_col, doc_embedding_col, query_embedding_col,
        target_col, predict_col,
    )


# Run the command only when executed as a script. main is a click command,
# so it is called without arguments and click parses them from the command line.
if __name__ == '__main__':
    main()
