import argparse
import nirvana_dl
import pandas as pd
import torch
from transformers import (
    AutoModel,
    AutoTokenizer,
)
import yt.wrapper as yt


# All YT client calls in this script go through this proxy.
yt.config["proxy"]["url"] = "hahn.yt.yandex.net"

# Job parameters as returned by nirvana_dl.params(): two table paths and the
# names of the text column and of the column that will hold the vector.
params = nirvana_dl.params()
segments_with_info_table = params['segments_with_info_table']
segments_vectors_with_info_table = params['segments_vectors_with_info_table']
description_field = params['description_field']
vector_field = params['vector_field']

# The model directory is passed on the command line; arguments this script
# does not know about are collected in `unparsed` instead of raising an error.
parser = argparse.ArgumentParser()
parser.add_argument('--model_folder', type=str, default='', help='Absolute path to model dir')
args, unparsed = parser.parse_known_args()

# Tokenizer and model are both loaded from the same folder.
model_dir = args.model_folder
labse_tokenizer = AutoTokenizer.from_pretrained(model_dir)
labse_model = AutoModel.from_pretrained(model_dir)

class GetTextEmbedding:
    """Callable that maps a text to a unit-length embedding vector.

    The vector is taken from position 0 of the model's ``last_hidden_state``
    (presumably the leading special token of the sequence) and is
    L2-normalised before being returned.
    """

    def __init__(self, tokenizer, model):
        """Remember the tokenizer and the model used on every call."""
        self.model = model
        self.tokenizer = tokenizer

    def __call__(self, text):
        """Embed ``text`` and return the vector as a plain list of floats.

        Only the first row of the resulting batch is returned, so the
        callable is meant to be fed one text at a time.
        """
        encoded = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            return_tensors='pt',
        )

        # Inference only: no autograd bookkeeping is needed.
        with torch.no_grad():
            # Move every tokenizer output onto the device the model lives on.
            inputs = {}
            for name, tensor in encoded.items():
                inputs[name] = tensor.to(self.model.device)
            output = self.model(**inputs)

        first_position = output.last_hidden_state[:, 0, :]
        unit_vectors = torch.nn.functional.normalize(first_position)
        first_row = unit_vectors[0]
        return first_row.cpu().numpy().tolist()

# Load every row of the input table into a DataFrame.
source_rows = yt.read_table(segments_with_info_table)
segments_with_info_df = pd.DataFrame(source_rows)

# Embed each description one by one and store the vector in a new column.
embed_text = GetTextEmbedding(labse_tokenizer, labse_model)
descriptions = segments_with_info_df[description_field]
segments_with_info_df[vector_field] = descriptions.apply(embed_text)

# Write the rows, now including the vector column, to the output table.
output_rows = segments_with_info_df.to_dict('records')
yt.write_table(segments_vectors_with_info_table, output_rows)
