from datetime import timedelta
import os

from crypta.affinitive_geo.services.org_embeddings.lib.utils import (
    config,
    utils,
)
from crypta.lib.python import templater
from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.lib.python.yt import yt_helpers
from crypta.profile.lib import date_helpers


# YQL template (rendered by templater.render_template in `get`) that rebuilds
# the daily banner table `{{caesar_info_table}}` (overwritten: WITH TRUNCATE).
#
# For every banner of `{{caesar_latest_dump}}` whose BannerID is also present in
# `{{active_banners_table}}`, the banner profile is parsed with the Bigb UDF and
# the query keeps:
#   * bannerid, groupbannerid and banner_body (Resources.Body);
#   * rt_sadovaya_vector: the TSAR vector whose VectorID equals
#     {{rt_sadovaya_vector_id}}. Its CompressedVector bytes are unpacked as
#     uint8 values by a Python UDF (np.frombuffer, dtype=np.uint8), and each
#     value q is dequantized as
#         {{compress_float_min}} + q * ({{compress_float_max}} - {{compress_float_min}}) / {{bin_count}}.
# Rows without a GroupBannerID or without the requested vector are dropped.
# The output is ordered by groupbannerid.
#
# NOTE(review): Bigb::ParseBannerProfile(TableRow()) is written out several
# times per row; parsing once in a subquery may be cheaper. Check the YQL plan
# before changing it.
# NOTE(review): the Python UDF can return None, but it is declared as returning
# a non-optional List<Uint64>. This is only safe because the WHERE clause
# removes rows whose vector is NULL before the UDF runs in the projection.
# NOTE(review): if COMPRESS_FLOAT_MIN/MAX and BIN_COUNT are all integers in
# config, the dequantization expression would use integer division. Confirm
# they are floats.
get_banner_rt_sadovaya_vector_query = u'''
PRAGMA yt.UseNativeYtTypes;
PRAGMA File('bigb.so', '{{bigb_udf_url}}');
PRAGMA udf('bigb.so');

$get_caesar_vector = ($tsar_vectors, $vector_id) -> {
    $vector = ListFilter(
        $tsar_vectors,
        ($vectors) -> {RETURN $vectors.VectorID == $vector_id;}
    );

    RETURN $vector[0].CompressedVector;
};

$uint8_bytes_to_list_script = @@
def bytes_to_list(bytes_list):
    import numpy as np

    if bytes_list is None:
        return None
    else:
        return [int(elem) for elem in np.frombuffer(bytes_list, dtype=np.uint8)]
@@;
$uint8_bytes_to_list_udf = Python3::bytes_to_list(Callable<(String?)->List<Uint64>>,
    $uint8_bytes_to_list_script);

$dequantize_vector = ($vector) -> {
    RETURN ListMap(
        $vector,
        ($elem) -> {RETURN {{compress_float_min}} + $elem * ({{compress_float_max}} - {{compress_float_min}}) / {{bin_count}};}
    )
};

INSERT INTO `{{caesar_info_table}}`
WITH TRUNCATE

SELECT
    BannerID AS bannerid,
    Bigb::ParseBannerProfile(TableRow()).Resources.GroupBannerID as groupbannerid,
    Bigb::ParseBannerProfile(TableRow()).Resources.Body AS banner_body,
    $dequantize_vector($uint8_bytes_to_list_udf($get_caesar_vector(
            Bigb::ParseBannerProfile(TableRow()).TsarVectors.Vectors, {{rt_sadovaya_vector_id}}
    ))) AS rt_sadovaya_vector,
FROM `{{caesar_latest_dump}}` AS caesar
INNER JOIN `{{active_banners_table}}` AS active_banners
ON caesar.BannerID == active_banners.bannerid
WHERE Bigb::ParseBannerProfile(TableRow()).Resources.GroupBannerID is not Null
    AND $get_caesar_vector(
        Bigb::ParseBannerProfile(TableRow()).TsarVectors.Vectors, {{rt_sadovaya_vector_id}}) is not Null
ORDER BY groupbannerid;
'''

# YQL template that appends to `{{banner_avg_distances_table}}` (plain
# INSERT INTO, no WITH TRUNCATE). Because it selects only an aggregate and has
# no GROUP BY, each run appends exactly one row:
#   * date: the rendered '{{date}}' string;
#   * avg_distance: the mean, over banners present in both
#     `{{caesar_info_cur_table}}` and `{{caesar_info_prev_table}}` (joined on
#     bannerid), of 1 - dot_product(normalize(cur), normalize(prev)) applied to
#     their rt_sadovaya_vector. This is a cosine distance, assuming
#     $normalize_vector yields unit-length vectors; confirm in utils.
# If no banner matches, AVG returns NULL and COALESCE falls back to 1.0.
# $normalize_vector and $dot_product come from the YQL snippets substituted for
# {{normalize_vector_query}} and {{dot_product_query}}.
# NOTE(review): since rows are appended, re-running for a date that was already
# processed adds a second row for that date.
get_banner_avg_distances_query = u'''
{{normalize_vector_query}}

{{dot_product_query}}

INSERT INTO `{{banner_avg_distances_table}}`
SELECT
    '{{date}}' AS `date`,
    COALESCE(AVG(1. - $dot_product(
        $normalize_vector(cur.rt_sadovaya_vector),
        $normalize_vector(prev.rt_sadovaya_vector),
    )), 1.) AS avg_distance,
FROM `{{caesar_info_cur_table}}` AS cur
INNER JOIN `{{caesar_info_prev_table}}` AS prev
USING(bannerid)
ORDER BY `date`;
'''


def get(yt_client, yql_client, date):
    """Build the daily banner vector table and record its drift from the day before.

    Both steps run inside one ``NirvanaTransaction``. Its id is passed to every
    YQL execution.

    1. Render ``get_banner_rt_sadovaya_vector_query`` and execute it to
       overwrite ``<config.CAESAR_INFO_DIR>/<date>`` with the active banners'
       dequantized RT Sadovaya vectors. Then set a TTL of
       ``config.DAYS_TO_STORE_DAILY_TABLES`` days on that table.
    2. Render ``get_banner_avg_distances_query`` and execute it to append one
       ``(date, avg_distance)`` row to ``config.BANNER_AVG_DISTANCES_TABLE``.
       The row compares today's table with the one for
       ``date_helpers.get_yesterday(date)``.

    NOTE(review): step 2 appends. Calling this twice for the same ``date``
    yields two rows for that date.

    :param yt_client: YT client used to open the transaction and to set the TTL.
    :param yql_client: client whose ``execute`` runs the rendered YQL.
    :param date: date string. It is both the table name under
        ``config.CAESAR_INFO_DIR`` and the value written to the ``date`` column.
    """
    daily_table = os.path.join(config.CAESAR_INFO_DIR, date)

    with NirvanaTransaction(yt_client) as tx:
        tx_id = str(tx.transaction_id)

        vectors_params = dict(
            bigb_udf_url=config.BIGB_UDF_URL,
            compress_float_min=config.COMPRESS_FLOAT_MIN,
            compress_float_max=config.COMPRESS_FLOAT_MAX,
            bin_count=config.BIN_COUNT,
            rt_sadovaya_vector_id=config.RT_SADOVAYA_VECTOR_ID,
            caesar_latest_dump=config.CAESAR_LATEST_DUMP,
            active_banners_table=config.ACTIVE_BANNERS_TABLE,
            caesar_info_table=daily_table,
        )
        vectors_query = templater.render_template(
            get_banner_rt_sadovaya_vector_query,
            vars=vectors_params,
        )
        yql_client.execute(
            vectors_query,
            title='YQL affinitive geo get banner RT Sadovaya vectors',
            transaction=tx_id,
        )

        # Daily tables are kept only for a limited time. The distance step
        # below reads the previous day's table from the same directory, so
        # this TTL bounds how much history stays available to it.
        ttl = timedelta(days=config.DAYS_TO_STORE_DAILY_TABLES)
        yt_helpers.set_ttl(
            table=daily_table,
            ttl_timedelta=ttl,
            yt_client=yt_client,
        )

        previous_table = os.path.join(config.CAESAR_INFO_DIR, date_helpers.get_yesterday(date))
        distances_params = dict(
            normalize_vector_query=utils.normalize_vector_query,
            dot_product_query=utils.dot_product_query,
            date=date,
            caesar_info_cur_table=daily_table,
            caesar_info_prev_table=previous_table,
            banner_avg_distances_table=config.BANNER_AVG_DISTANCES_TABLE,
        )
        distances_query = templater.render_template(
            get_banner_avg_distances_query,
            vars=distances_params,
        )
        yql_client.execute(
            distances_query,
            title='YQL affinitive geo banner avg distances',
            transaction=tx_id,
        )
