#!/usr/bin/env python
# -*- coding: utf-8 -*-

import functools
import logging

import numpy as np

from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.lib.python.yt import yt_helpers
from crypta.lookalike.lib.python.utils import (
    fields,
    utils,
)
from crypta.lookalike.lib.python.utils.config import config


logger = logging.getLogger(__name__)


def get_segments(yt_client, segments_vectors_table_path):
    """Load every segment vector from a YT table.

    :param yt_client: YT client used to read the table.
    :param segments_vectors_table_path: path of a table whose rows carry
        ``fields.group_id`` and a raw ``segment_vector`` column.
    :return: list of dicts, one per row, in table order, each holding the
        row's group id under ``fields.group_id`` and the
        ``utils.normalize``-d ``segment_vector`` under ``fields.vector``.
    """
    return [
        {
            fields.group_id: record[fields.group_id],
            fields.vector: utils.normalize(record['segment_vector']),
        }
        for record in yt_client.read_table(segments_vectors_table_path)
    ]


def calculate_dssm_lal_scores_mapper(row, segments):
    """Yield one distance record for each (user row, segment) pair.

    Used as a YT mapper: ``calculate`` binds ``segments`` with
    ``functools.partial``.

    :param row: input row with ``user_vector``, ``segments`` (the group ids the
        user belongs to; may be empty or None) and ``fields.yandexuid``.
    :param segments: list of dicts as built by ``get_segments``: a
        ``fields.group_id`` and an already-normalized ``fields.vector``.
    :yields: dicts with the user's yandexuid, the segment's group id,
        ``distance = 1 - dot(segment_vector, normalized user_vector)`` and
        ``label`` = 1 when the segment's group id is in the row's
        ``segments``, else 0.
    """
    user_vector = utils.normalize(row['user_vector'])
    yandexuid = row[fields.yandexuid]
    # Build the membership set once per row. A membership test against the
    # raw ``segments`` value (e.g. a list) inside the loop would cost
    # O(len(segments) * len(target segments)) per row. ``or ()`` keeps the
    # original behavior for a None/empty column: every label is 0.
    target_group_ids = frozenset(row['segments'] or ())

    for segment in segments:
        group_id = segment[fields.group_id]
        yield {
            fields.yandexuid: yandexuid,
            fields.group_id: group_id,
            fields.distance: 1.0 - np.dot(segment[fields.vector], user_vector),
            fields.label: 1 if group_id in target_group_ids else 0,
        }


def calculate(nv_params):
    """Compute DSSM lookalike distances for the test users and segments.

    Reads segment vectors from ``config.TEST_SEGMENTS_DSSM_VECTORS``. Then,
    inside a ``NirvanaTransaction``, it creates ``config.TEST_DSSM_LAL_DISTANCES``
    (``force=True``, scan-optimized) and fills it by running
    ``calculate_dssm_lal_scores_mapper`` over ``config.TEST_USERS_DSSM_VECTORS``.

    :param nv_params: Nirvana parameters used to build the YT client.
    """
    client = utils.get_yt_client(nv_params=nv_params)

    # Segment vectors are loaded outside the transaction and shipped to the
    # map jobs through the partial below.
    segment_vectors = get_segments(
        yt_client=client,
        segments_vectors_table_path=config.TEST_SEGMENTS_DSSM_VECTORS,
    )

    output_table = config.TEST_DSSM_LAL_DISTANCES
    output_schema = {
        fields.yandexuid: 'uint64',
        fields.group_id: 'string',
        fields.distance: 'double',
        fields.label: 'uint64',
    }
    mapper = functools.partial(calculate_dssm_lal_scores_mapper, segments=segment_vectors)

    with NirvanaTransaction(client):
        yt_helpers.create_empty_table(
            yt_client=client,
            path=output_table,
            schema=output_schema,
            additional_attributes={'optimize_for': 'scan'},
            force=True,
        )
        client.run_map(mapper, config.TEST_USERS_DSSM_VECTORS, output_table)
