import click
import logging
from nile.api.v1 import clusters
from nile.api.v1.record import Record
from customer_service.ml.lib.data.knowledges import (
    prepare_actual_knowledges,
    match_knowledges_to_df
)
from customer_service.ml.lib.data.utils import create_target


# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
# Configure the root logger at INFO so the progress messages emitted by main()
# are shown (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)


@click.command()
@click.option('--cluster', default='hahn', help='Cluster in YT')
@click.option('--yt_alias', default='customer-service-yt', help='Alias for YT in yandex vault')
@click.option('--input_data', required=True, help='Path to input dataset in YT')
@click.option('--output_data', required=True, help='Path to output dataset in YT')
@click.option('--product_tag', required=True, help='Name of the product')
@click.option('--knowledges_url',
              default='https://support-private.s3.mds.yandex.net/knowledge-base/snapshots/snapshot_2022-04-25.gz',
              help='URL to knowledges snapshot')
@click.option('--max_distance', type=float, default=0.6, show_default=True,
              help='Keep only rows whose knowledge-match distance is strictly below this value')
@click.option('--max_level', type=int, default=5, show_default=True,
              help='Build target_level_1 .. target_level_N columns up to this level')
def main(cluster: str,
         yt_alias: str,
         input_data: str,
         output_data: str,
         product_tag: str,
         knowledges_url: str,
         max_distance: float = 0.6,
         max_level: int = 5):
    """Match a YT dataset to knowledge-base entries and write the labelled rows back to YT.

    Steps:
      1. Prepare the knowledge snapshot from ``knowledges_url`` for ``product_tag``.
      2. Read ``input_data`` from the YT ``cluster`` into a pandas DataFrame.
      3. Match rows to knowledges with ``match_knowledges_to_df``; the code below
         reads its ``distance`` and ``matched_categories`` columns.
      4. Keep rows with ``distance < max_distance`` and derive a ``target`` column
         plus ``target_level_1`` .. ``target_level_{max_level}`` via ``create_target``.
      5. Write the resulting rows to ``output_data``.

    Args:
        cluster: YT cluster name passed to ``clusters.YT``.
        yt_alias: Accepted for CLI compatibility; not used by this function.
        input_data: YT path of the input table.
        output_data: YT path of the output table.
        product_tag: Product passed to ``prepare_actual_knowledges``.
        knowledges_url: URL of the knowledge snapshot.
        max_distance: Exclusive upper bound on ``distance`` for kept rows
            (default 0.6, the previously hard-coded value).
        max_level: Highest ``target_level_N`` column to build
            (default 5, the previously hard-coded value).
    """
    actual_knowledges = prepare_actual_knowledges(knowledges_url, product_tag)

    # Use a separate name for the client so the `cluster` CLI argument is not shadowed.
    yt_cluster = clusters.YT(cluster)
    df = yt_cluster.read(input_data).as_dataframe()

    logger.info('Matching knowledges')
    df = match_knowledges_to_df(df, actual_knowledges)
    logger.info('Knowledges matched')

    # Strict `<`: rows exactly at the threshold, and rows with NaN distance, are dropped.
    filtered = df[df['distance'] < max_distance].copy()
    logger.info('Kept %d of %d rows with distance < %s', len(filtered), len(df), max_distance)

    categories = filtered['matched_categories']
    filtered['target'] = categories.apply(create_target)
    for level in range(1, max_level + 1):
        # `args` forwards `level` as create_target's second positional argument.
        filtered[f'target_level_{level}'] = categories.apply(create_target, args=(level,))

    output_records = [Record(**row) for row in filtered.to_dict(orient='records')]
    yt_cluster.write(output_data, output_records)


# Script entry point: `main` is a click command, so calling it with no
# arguments makes click parse the command-line arguments itself.
if __name__ == '__main__':
    main()



