from datetime import timedelta
import os

from crypta.affinitive_geo.services.org_embeddings.lib.utils import (
    config,
    utils,
)
from crypta.lib.python import templater
from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.lib.python.yt import yt_helpers
from crypta.profile.lib import date_helpers


# YQL template that rebuilds two tables, both written WITH TRUNCATE:
#
#   * {{org_embedding_base_table}} - one row per permalink, sorted by permalink:
#       - embedding: `$normalize_vector` applied to the element-wise sum of
#         `rt_sadovaya_vector` over the banner groups that are keys of the
#         organization's `top_banners` dict and are present in
#         {{caesar_info_table}} (one vector per groupbannerid is picked with SOME);
#       - region_id: some `geo_id` from {{orgs_weights_table}}, 0 when NULL;
#       - weight: some `weight` from {{orgs_weights_table}}, 1 when NULL.
#   * {{geohash_org_base_table}} - for every geohash of {{geohash_permalink_table}},
#     the distinct permalinks that also appear in `$org_banners`, i.e. that have a
#     row in both {{org_affinitive_banners_table}} and {{orgs_weights_table}};
#     sorted by geohash.
#
# Template variables: normalize_vector_query, org_affinitive_banners_table,
# orgs_weights_table, caesar_info_table, geohash_permalink_table,
# org_embedding_base_table, geohash_org_base_table.
#
# The snippet substituted for {{normalize_vector_query}} has to define
# `$normalize_vector`, because the query calls it without defining it itself.
make_bases_query = u'''
PRAGMA yt.UseNativeYtTypes;

$org_banners = (
    SELECT
        affinities.permalink AS permalink,
        DictKeys(affinities.top_banners) AS top_banner_groups,
        org_weights.geo_id AS region_id,
        org_weights.weight AS weight,
    FROM `{{org_affinitive_banners_table}}` AS affinities
    INNER JOIN `{{orgs_weights_table}}` AS org_weights
    USING(permalink)
);

$org_banners_flatten = (
    SELECT
        permalink,
        CAST(groupbannerid AS Uint64) AS groupbannerid,
        region_id,
        weight,
    FROM $org_banners
    FLATTEN LIST BY top_banner_groups AS groupbannerid
);

$caesar_by_groups = (
    SELECT
        groupbannerid,
        SOME(rt_sadovaya_vector) AS rt_sadovaya_vector,
    FROM `{{caesar_info_table}}`
    GROUP BY groupbannerid
);

$list_by_component_sum = AggregationFactory(
    "UDAF",
    ($x, $_) -> ($x),
    ($x, $y, $_) -> (ListMap(ListZip($x, $y), ($elem) -> { RETURN $elem.0 + $elem.1; })),
    ($x, $y) -> (ListMap(ListZip($x, $y), ($elem) -> { RETURN $elem.0 + $elem.1; }))
);

{{normalize_vector_query}}

INSERT INTO `{{org_embedding_base_table}}`
WITH TRUNCATE

SELECT
    org.permalink AS permalink,
    $normalize_vector(AGGREGATE_BY(rt_sadovaya_vector, $list_by_component_sum)) AS embedding,
    COALESCE(SOME(region_id), 0) AS region_id,
    COALESCE(SOME(weight), 1) AS weight,
FROM $org_banners_flatten AS org
INNER JOIN $caesar_by_groups AS caesar
USING(groupbannerid)
GROUP BY org.permalink
ORDER BY permalink;

INSERT INTO `{{geohash_org_base_table}}`
WITH TRUNCATE

SELECT
    geohash,
    AGGREGATE_LIST_DISTINCT(geohash_permalink.permalink) AS permalinks,
FROM `{{geohash_permalink_table}}` AS geohash_permalink
INNER JOIN $org_banners AS selected_orgs
USING(permalink)
GROUP BY geohash_permalink.geohash AS geohash
ORDER BY geohash;
'''

# YQL template that adds one row to {{org_avg_distances_table}}. The INSERT has
# no WITH TRUNCATE, so existing rows are kept. The row holds:
#   - date: the literal {{date}} string;
#   - avg_distance: the average of `1 - $dot_product(cur.embedding, prev.embedding)`
#     over the permalinks present in both {{org_embedding_base_cur_table}} and
#     {{org_embedding_base_prev_table}}; 1.0 is used when the average is NULL.
#
# Template variables: dot_product_query, date, org_embedding_base_cur_table,
# org_embedding_base_prev_table, org_avg_distances_table.
#
# The snippet substituted for {{dot_product_query}} has to define `$dot_product`,
# because the query calls it without defining it itself.
get_org_avg_distances_query = u'''
{{dot_product_query}}

INSERT INTO `{{org_avg_distances_table}}`
SELECT
    '{{date}}' AS `date`,
    COALESCE(AVG(1. - $dot_product(
        cur.embedding,
        prev.embedding,
    )), 1.) AS avg_distance,
FROM `{{org_embedding_base_cur_table}}` AS cur
INNER JOIN `{{org_embedding_base_prev_table}}` AS prev
USING(permalink)
ORDER BY `date`;
'''

# YQL template that overwrites (WITH TRUNCATE) {{yandex_top_banners_table}} with the
# {{top_banners_for_dash_cnt}} banner groups closest to a reference organization.
#
#   - The reference vector is the `embedding` stored in {{org_embedding_base_table}}
#     for the permalink {{yandex_permalink}}.
#   - {{caesar_info_table}} is collapsed to one row per groupbannerid (banner_body
#     and rt_sadovaya_vector are picked with SOME).
#   - distance = 1 - $dot_product($normalize_vector(rt_sadovaya_vector), reference);
#     rows are ordered by ascending distance before the LIMIT is applied.
#   - Every output row carries the same constant `org_name` literal.
#
# Template variables: normalize_vector_query, dot_product_query, caesar_info_table,
# org_embedding_base_table, yandex_permalink, yandex_top_banners_table,
# top_banners_for_dash_cnt.
#
# The substituted snippets have to define `$normalize_vector` and `$dot_product`,
# because the query calls both without defining them itself.
get_yandex_top_banners_query = u'''
{{normalize_vector_query}}

{{dot_product_query}}

$by_groupbannerid = (
    SELECT
        groupbannerid,
        SOME(banner_body) AS banner_body,
        SOME(rt_sadovaya_vector) AS rt_sadovaya_vector,
    FROM `{{caesar_info_table}}`
    GROUP BY groupbannerid
);

$ref_vector = (
    SELECT SOME(embedding)
    FROM `{{org_embedding_base_table}}`
    WHERE permalink == {{yandex_permalink}}
);

INSERT INTO `{{yandex_top_banners_table}}`
WITH TRUNCATE

SELECT
    'Яндекс' AS org_name,
    groupbannerid,
    banner_body,
    1. - $dot_product($normalize_vector(rt_sadovaya_vector), $ref_vector) AS distance,
FROM $by_groupbannerid
ORDER BY distance
LIMIT {{top_banners_for_dash_cnt}};
'''


def make(yt_client, yql_client, date):
    """Rebuild the organization embedding bases for ``date`` and the tables derived from them.

    Everything runs inside one ``NirvanaTransaction`` opened on ``yt_client``,
    and every YQL query is executed with that transaction's id. Steps, in order:

    1. Run ``make_bases_query`` to overwrite ``config.ORG_EMBEDDING_BASE_TABLE``
       and ``config.GEOHASH_ORG_BASE_TABLE``. The geohash source is the maximum
       path returned by listing ``config.GEOHASH_PERMALINK_DIR``.
    2. Copy the embedding base to ``<config.ORG_EMBEDDING_BASE_DAILY_DIR>/<date>``
       (``force=True``) and set a TTL of ``config.DAYS_TO_STORE_DAILY_TABLES``
       days on the copy.
    3. Run ``get_org_avg_distances_query`` comparing that daily copy with the
       one for ``date_helpers.get_yesterday(date)``.
    4. Run ``get_yandex_top_banners_query`` for ``config.YANDEX_PERMALINK``.

    Args:
        yt_client: YT client used for the transaction, listing, copy and TTL.
        yql_client: YQL client whose ``execute`` runs the rendered queries.
        date: Name of the daily tables to read and write (joined onto the
            configured directories as a path component).
    """
    daily_snapshot_table = os.path.join(config.ORG_EMBEDDING_BASE_DAILY_DIR, date)

    with NirvanaTransaction(yt_client) as transaction:
        transaction_id = str(transaction.transaction_id)

        # Step 1: rebuild the embedding base and the geohash -> organizations base.
        caesar_info_table = os.path.join(config.CAESAR_INFO_DIR, date)
        bases_vars = {
            'normalize_vector_query': utils.normalize_vector_query,
            'org_affinitive_banners_table': os.path.join(config.ORG_AFFINITIVE_BANNERS_DIR, date),
            'orgs_weights_table': os.path.join(config.ORGS_WEIGHTS_DIR, date),
            'caesar_info_table': caesar_info_table,
            'geohash_permalink_table': max(yt_client.list(config.GEOHASH_PERMALINK_DIR, absolute=True)),
            'org_embedding_base_table': config.ORG_EMBEDDING_BASE_TABLE,
            'geohash_org_base_table': config.GEOHASH_ORG_BASE_TABLE,
        }
        bases_sql = templater.render_template(make_bases_query, vars=bases_vars)
        yql_client.execute(
            query=bases_sql,
            transaction=transaction_id,
            title='YQL affinitive geo make bases',
        )

        # Step 2: keep a dated copy of the fresh base, with a TTL set on it.
        yt_client.copy(config.ORG_EMBEDDING_BASE_TABLE, daily_snapshot_table, force=True)
        retention = timedelta(days=config.DAYS_TO_STORE_DAILY_TABLES)
        yt_helpers.set_ttl(
            table=daily_snapshot_table,
            ttl_timedelta=retention,
            yt_client=yt_client,
        )

        # Step 3: average distance between today's and yesterday's embeddings.
        previous_snapshot_table = os.path.join(
            config.ORG_EMBEDDING_BASE_DAILY_DIR,
            date_helpers.get_yesterday(date),
        )
        distances_vars = {
            'dot_product_query': utils.dot_product_query,
            'date': date,
            'org_embedding_base_cur_table': daily_snapshot_table,
            'org_embedding_base_prev_table': previous_snapshot_table,
            'org_avg_distances_table': config.ORG_AVG_DISTANCES_TABLE,
        }
        distances_sql = templater.render_template(get_org_avg_distances_query, vars=distances_vars)
        yql_client.execute(
            distances_sql,
            title='YQL affinitive geo org avg distances',
            transaction=transaction_id,
        )

        # Step 4: banners closest to the reference organization's embedding.
        top_banners_vars = {
            'normalize_vector_query': utils.normalize_vector_query,
            'dot_product_query': utils.dot_product_query,
            'yandex_permalink': config.YANDEX_PERMALINK,
            'top_banners_for_dash_cnt': config.TOP_BANNERS_FOR_DASH_CNT,
            'caesar_info_table': caesar_info_table,
            'org_embedding_base_table': config.ORG_EMBEDDING_BASE_TABLE,
            'yandex_top_banners_table': config.YANDEX_TOP_BANNERS_TABLE,
        }
        top_banners_sql = templater.render_template(get_yandex_top_banners_query, vars=top_banners_vars)
        yql_client.execute(
            top_banners_sql,
            title='YQL affinitive geo yandex top banners',
            transaction=transaction_id,
        )
