from datetime import timedelta
import os

from crypta.affinitive_geo.services.org_embeddings.lib.utils import (
    config,
    utils,
)
from crypta.lib.python import templater
from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.lib.python.yt import yt_helpers
from crypta.profile.lib import date_helpers


# YQL template that scores every organization against the region it belongs to
# and writes the result to `{{orgs_weights_table}}`.
#
# Template variables:
#   normalize_vector_query, dot_product_query - YQL snippets pasted at the top of
#       the query; the query relies on `$normalize_vector` and `$dot_product`,
#       presumably defined by those snippets (they are not visible in this file).
#   orgs_dssm_vectors_table, regions_dssm_vectors_table - input tables read via
#       their `GroupID` and `segment_vector` columns.
#   orgs_info_table - organization metadata (name, rubric, coordinates, region),
#       joined to the organization vectors by permalink.
#   orgs_weights_table - output table; its contents are replaced
#       (`INSERT INTO ... WITH TRUNCATE`).
#
# Steps performed by the query:
#   1. `$org_city_vectors`: inner-join organization vectors with their metadata
#      and with the vector of the organization's region, so organizations lacking
#      metadata or a region vector are dropped.
#   2. `$org_distances`: distance = 1 - dot product of the two normalized vectors
#      (i.e. a cosine-style distance, assuming the helpers do what their names
#      say); the raw vectors are dropped from the row.
#   3. `$orgs_cnt`: number of organizations per `region_id`; `$org_distances` is
#      then redefined to carry that count on every row.
#   4. Final SELECT, windowed per region and ordered by distance:
#        weight = CAST((RANK - 1) * 100 AS Int64) / orgs_cnt + 1
#        rank   = orgs_cnt + 1 - RANK
#      so the organization with the smallest distance in a region gets
#      weight 1 and rank == orgs_cnt. Output is sorted by region_id, rank.
get_org_weights_query = u'''
{{normalize_vector_query}}

{{dot_product_query}}

$org_city_vectors = (
    SELECT
        CAST(orgs.GroupID AS Uint64) AS permalink,
        orgs.segment_vector AS org_vector,
        regions.segment_vector AS city_vector,
        orgs_info.name AS title,
        orgs_info.main_rubric_name_ru AS category,
        orgs_info.lat AS lat,
        orgs_info.lon AS lon,
        orgs_info.region_name AS region_name,
        orgs_info.region_id AS region_id,
        orgs_info.geo_id AS geo_id,
    FROM `{{orgs_dssm_vectors_table}}` AS orgs
    INNER JOIN `{{orgs_info_table}}` AS orgs_info
    ON CAST(orgs.GroupID AS Uint64) == orgs_info.permalink
    INNER JOIN `{{regions_dssm_vectors_table}}` AS regions
    ON orgs_info.region_id == CAST(regions.GroupID AS Int32)
);

$org_distances = (
    SELECT
        1. - $dot_product(
            $normalize_vector(org_vector), $normalize_vector(city_vector)) AS distance,
        data.* WITHOUT org_vector, city_vector,
    FROM $org_city_vectors AS data
);

$orgs_cnt = (
    SELECT
        region_id,
        COUNT(*) AS orgs_cnt,
    FROM $org_distances
    GROUP BY region_id
);

$org_distances = (
    SELECT
        data.*,
        counts.orgs_cnt AS orgs_cnt,
    FROM $org_distances AS data
    INNER JOIN $orgs_cnt AS counts
    USING(region_id)
);

INSERT INTO `{{orgs_weights_table}}`
WITH TRUNCATE

SELECT
    CAST((RANK() OVER w - 1.) * 100 AS Int64) / orgs_cnt + 1 AS weight,
    orgs_cnt + 1 - RANK() OVER w AS `rank`,
    data.* WITHOUT distance, orgs_cnt,
FROM $org_distances AS data
WINDOW w AS (
    PARTITION BY region_id
    ORDER BY distance
)
ORDER BY region_id, rank;
'''

# YQL template that records how stable organization ranks are between two
# weight tables.
#
# Template variables:
#   date - value written verbatim into the `date` column of the result.
#   orgs_weights_cur_table, orgs_weights_prev_table - the two weight tables
#       being compared; they are inner-joined by `permalink`, so only
#       organizations present in both contribute.
#   org_weight_rank_correlation_table - destination table; rows are added with a
#       plain `INSERT INTO` (no `WITH TRUNCATE`), so earlier rows are kept.
#
# The SELECT has no GROUP BY, so it aggregates the whole join into a single
# `CORRELATION(cur.rank, prev.rank)` value; a NULL correlation is stored as 0.
get_org_weight_rank_correlation_query = u'''
INSERT INTO `{{org_weight_rank_correlation_table}}`
SELECT
    '{{date}}' AS `date`,
    COALESCE(CORRELATION(cur.`rank`, prev.`rank`), 0.) AS rank_correlation,
FROM `{{orgs_weights_cur_table}}` AS cur
INNER JOIN `{{orgs_weights_prev_table}}` AS prev
USING(permalink)
ORDER BY `date`;
'''


def calculate(yt_client, yql_client, date):
    """Build the organization weights table for ``date`` and log rank stability.

    Everything runs inside a single ``NirvanaTransaction`` opened on
    ``yt_client``; both YQL queries are given that transaction's id.

    Steps:
      1. Render and execute ``get_org_weights_query``, writing to
         ``<config.ORGS_WEIGHTS_DIR>/<date>``.
      2. Set a TTL of ``config.DAYS_TO_STORE_DAILY_TABLES`` days on that table.
      3. Render and execute ``get_org_weight_rank_correlation_query``, comparing
         the new table with the one named by ``date_helpers.get_yesterday(date)``
         (presumably the previous day's table; confirm against ``date_helpers``).

    Args:
        yt_client: YT client used for the transaction and for setting the TTL.
        yql_client: client whose ``execute`` method runs the rendered queries.
        date: name of the daily table; also written to the correlation row.
    """
    current_weights_path = os.path.join(config.ORGS_WEIGHTS_DIR, date)

    with NirvanaTransaction(yt_client) as tx:
        # Step 1: compute per-region organization weights.
        weights_vars = {
            'normalize_vector_query': utils.normalize_vector_query,
            'dot_product_query': utils.dot_product_query,
            'orgs_dssm_vectors_table': config.ORGS_DSSM_VECTORS_TABLE,
            'orgs_info_table': config.ORGS_INFO_TABLE,
            'regions_dssm_vectors_table': config.REGIONS_DSSM_VECTORS_TABLE,
            'orgs_weights_table': current_weights_path,
        }
        weights_query = templater.render_template(get_org_weights_query, vars=weights_vars)
        yql_client.execute(
            query=weights_query,
            transaction=str(tx.transaction_id),
            title='YQL affinitive geo get organization weights',
        )

        # Step 2: daily tables are temporary, so give the new one an expiry.
        retention = timedelta(days=config.DAYS_TO_STORE_DAILY_TABLES)
        yt_helpers.set_ttl(
            table=current_weights_path,
            ttl_timedelta=retention,
            yt_client=yt_client,
        )

        # Step 3: correlate the new ranks with those from the preceding table.
        previous_weights_path = os.path.join(
            config.ORGS_WEIGHTS_DIR,
            date_helpers.get_yesterday(date),
        )
        correlation_vars = {
            'date': date,
            'orgs_weights_cur_table': current_weights_path,
            'orgs_weights_prev_table': previous_weights_path,
            'org_weight_rank_correlation_table': config.ORG_WEIGHT_RANK_CORRELATION_TABLE,
        }
        correlation_query = templater.render_template(
            get_org_weight_rank_correlation_query,
            vars=correlation_vars,
        )
        yql_client.execute(
            correlation_query,
            title='YQL affinitive geo org weight rank correlation',
            transaction=str(tx.transaction_id),
        )
