#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
from yt.wrapper import with_context

from crypta.lib.python import templater
from crypta.profile.lib import (
    date_helpers,
    vector_helpers,
)


# YQL snippet that defines `$dict_sum`, a UDAF merging dicts by summing values
# per key. The initial state is the first dict itself. Both the "add" and the
# "merge" lambdas combine two dicts with SetUnion, adding the values of keys
# present on both sides (a missing side counts as 0).
# It is spliced into the templates below through the `{{dict_sum}}` placeholder.
dict_sum_yql = """
$dict_sum = AggregationFactory(
    "UDAF",
    ($x) -> ($x),
    ($x, $y) -> (SetUnion($x, $y, ($_, $v1, $v2) -> (($v1 ?? 0) + ($v2 ?? 0)))),
    ($x, $y) -> (SetUnion($x, $y, ($_, $v1, $v2) -> (($v1 ?? 0) + ($v2 ?? 0))))
);
"""

# Template (rendered with templater.render_template, `{{...}}` placeholders)
# that merges daily per-yandexuid hit tables into one table.
#   * Reads every table in `hits_dir` whose name lies in [begin_date, end_date];
#     table names are treated as dates (TableName() is CAST to date).
#   * Drops a user-day whose total hits are below `minimum_hits_per_day`.
#   * Keeps the raw per-site counts (`raw_site_weights`) and a time-decayed copy
#     (`site_weights`), where each day is scaled by 2 ** (-age_days / halflife_days).
#   * Sums both dicts per yandexuid with `$dict_sum`.
#   * Keeps users active on or after `last_active_date` whose number of distinct
#     sites is within [minimum_sites, maximum_sites].
# NOTE(review): `-DateTime::ToDays(...) / {{halflife_days}}` is an integer division
# in YQL when halflife_days renders as an integer; pass a float to get a smooth decay.
# NOTE(review): `site_weights` values are multiplied by a (presumably Double)
# weight, yet they are serialized with `Yson::FromUint64Dict`. Downstream queries
# read them with `ConvertToDoubleDict`, so confirm this conversion is intended.
hits_merger_yql = """
$calculate_weight = ($date) -> {
    RETURN Math::Pow(2., -DateTime::ToDays(CAST('{{end_date}}' as date) - CAST($date as date)) / {{halflife_days}});
};
$mul_dict_values = ($dict, $multiplier) -> {
    RETURN ToDict(ListMap(DictItems($dict), ($item) -> {
        RETURN AsTuple($item.0, $item.1 * $multiplier); }));
};
{{dict_sum}}

$hits = (
    SELECT
        yandexuid,
        Yson::ConvertToUint64Dict(site_weights) AS raw_site_weights,
        $mul_dict_values(
            Yson::ConvertToUint64Dict(site_weights), $calculate_weight(TableName())) AS site_weights,
        TableName() AS `date`,
    FROM RANGE(
        `{{hits_dir}}`,
        `{{begin_date}}`,
        `{{end_date}}`
    )
    WHERE ListSum(DictPayloads(Yson::ConvertToUint64Dict(site_weights))) >= {{minimum_hits_per_day}}
);

INSERT INTO `{{merged_hits_table}}`
WITH TRUNCATE

SELECT
    yandexuid,
    Yson::Serialize(Yson::FromUint64Dict(AGGREGATE_BY(raw_site_weights, $dict_sum))) AS raw_site_weights,
    Yson::Serialize(Yson::FromUint64Dict(AGGREGATE_BY(site_weights, $dict_sum))) AS site_weights,
FROM $hits
GROUP BY yandexuid
HAVING MAX(`date`) >= '{{last_active_date}}'
    AND DictLength(AGGREGATE_BY(raw_site_weights, $dict_sum)) >= {{minimum_sites}}
    AND DictLength(AGGREGATE_BY(raw_site_weights, $dict_sum)) <= {{maximum_sites}}
ORDER BY yandexuid;
"""

# Template (`{{...}}` placeholders) that re-aggregates the per-yandexuid merged
# hits to crypta_id level. It inner-joins the merged table with the
# yandexuid -> crypta_id mapping table and sums `raw_site_weights` (uint64) and
# `site_weights` (double) per crypta_id with `$dict_sum`. Output is sorted by crypta_id.
# Not referenced in this part of the file; presumably rendered elsewhere.
# NOTE(review): confirm `Yson::FromDouble64Dict` is a valid Yson UDF name
# (the documented family is FromUint64Dict / FromDoubleDict / ...).
get_cryptaid_merged_hits_yql = """
{{dict_sum}}

INSERT INTO `{{merged_hits_by_cryptaid_table}}`
WITH TRUNCATE

SELECT
    crypta_id AS crypta_id,
    Yson::Serialize(Yson::FromUint64Dict(AGGREGATE_BY(Yson::ConvertToUint64Dict(
        raw_site_weights), $dict_sum))) AS raw_site_weights,
    Yson::Serialize(Yson::FromDouble64Dict(AGGREGATE_BY(Yson::ConvertToDoubleDict(
        site_weights), $dict_sum))) AS site_weights,
FROM `{{merged_hits_by_yandexuid_table}}` AS merged_hits
INNER JOIN `{{yandexuid_cryptaid_table}}` AS yandexuid_cryptaid
ON merged_hits.yandexuid == yandexuid_cryptaid.yandexuid
GROUP BY yandexuid_cryptaid.crypta_id AS crypta_id
ORDER BY crypta_id;
"""

# str.format template (single-brace placeholders: output_table, input_table,
# id_type). It explodes each id's `site_weights` dict into one
# (id, host, weight) row per entry and overwrites the output, sorted by host.
# The host ordering presumably lets the output be a sorted input of the
# join-reduce by `host` in get_yandexuid_vectors.
# Because this is rendered with str.format, any literal braces added to the
# YQL text must be doubled ({{ / }}).
flatten_hits_query = """
INSERT INTO `{output_table}` WITH TRUNCATE
SELECT
    {id_type},
    site_weight.0 AS host,
    site_weight.1 AS weight,
FROM (
    SELECT
        {id_type},
        Yson::ConvertToDoubleDict(site_weights) AS site_weights,
    FROM `{input_table}`
)
FLATTEN DICT BY site_weights AS site_weight
ORDER BY host;
"""


# str.format template (placeholders: merged_hits_by_id_table,
# flattened_hits_table, output_table) computing per-host inverse document
# frequency.
#   N          = number of rows in the merged-hits table (one per id)
#   count      = number of flattened (id, host) rows for the host
#   idf        = ln(N / count)
#   weight     = N / count truncated to Uint64
# The output is overwritten and sorted by host. The memory-limit pragma raises
# the default per-job memory to 8G.
calculate_idf_query = """
PRAGMA yt.DefaultMemoryLimit = '8G';

$total_ids_count = (
    SELECT COUNT(*)
    FROM `{merged_hits_by_id_table}`
);

INSERT INTO `{output_table}` WITH TRUNCATE
SELECT
    host,
    COUNT(*) AS `count`,
    Math::Log(CAST($total_ids_count AS Double) / COUNT(*)) AS idf,
    CAST(CAST($total_ids_count AS Double) / COUNT(*) AS Uint64) AS weight,
FROM `{flattened_hits_table}`
GROUP BY host
ORDER BY host;
"""


@with_context
def calculate_user_host_vector_reducer(key, rows, context):
    """Yield an IDF-weighted copy of a host's vector for each user hit on it.

    Intended for a join-reduce keyed by ``host`` with three inputs, in order:
      0 - host vectors (``vector`` column, decoded with
          ``vector_helpers.binary_to_numpy``),
      1 - host IDF (``idf`` column),
      2 - per-user hits (``yandexuid`` and ``weight`` columns).
    Rows from tables 0 and 1 must arrive before the table-2 rows of the same
    key (foreign inputs in a YT join-reduce). A table-2 row seen while the
    vector is missing, or the idf is missing or zero, is skipped.

    :param key: current join key (unused).
    :param rows: iterator over the rows of that key from all inputs.
    :param context: YT context exposing ``table_index`` of the current row.
    :yields: ``{'yandexuid': ..., 'vector': <raw bytes of vector * weight * idf>}``.
    """
    idf = None
    vector = None
    for row in rows:
        table_index = context.table_index
        if table_index == 0:
            vector = vector_helpers.binary_to_numpy(row['vector'])
        elif table_index == 1:
            idf = row['idf']
        elif table_index == 2 and idf and isinstance(vector, np.ndarray):
            result = vector * row['weight'] * idf
            yield {
                'yandexuid': row['yandexuid'],
                # tobytes() returns exactly what the deprecated (and, in newer
                # NumPy, removed) tostring() returned.
                'vector': result.tobytes(),
            }


def merge_hits_by_yandexuid(
    yt_client,
    yql_client,
    date,
    daily_hits_directory,
    merged_hits_by_yandexuid_table,
    number_of_days_to_merge,
    halflife,
    min_hosts,
    max_hosts,
    min_hits_per_day,
    last_active_date,
):
    """Merge the last ``number_of_days_to_merge`` daily hit tables per yandexuid.

    Renders ``hits_merger_yql`` and runs it inside one YT transaction that also
    drops the previous output and stamps the run parameters as attributes on
    the new table.

    :param yt_client: YT client used for the transaction, cleanup and attributes.
    :param yql_client: client used to execute the rendered YQL in that transaction.
    :param date: last day to merge (inclusive); also the decay reference date.
    :param daily_hits_directory: directory holding one hits table per day.
    :param merged_hits_by_yandexuid_table: output table (overwritten).
    :param number_of_days_to_merge: window length in days, ending at ``date``.
    :param halflife: days after which a day's hits weigh half as much.
    :param min_hosts: minimum number of distinct sites a kept user must have.
    :param max_hosts: maximum number of distinct sites a kept user may have.
    :param min_hits_per_day: user-days with fewer total hits are ignored.
    :param last_active_date: users whose latest day is earlier are dropped.
    """
    begin_date = date_helpers.get_date_from_past(date, days=number_of_days_to_merge - 1)
    query = templater.render_template(
        hits_merger_yql,
        vars={
            'dict_sum': dict_sum_yql,
            # Rendered as a float so YQL divides in floating point. An integer
            # value makes `-days / halflife` an integer division, which turns
            # the exponential decay into a step function.
            'halflife_days': float(halflife),
            'minimum_hits_per_day': min_hits_per_day,
            'minimum_sites': min_hosts,
            'maximum_sites': max_hosts,
            'last_active_date': last_active_date,
            'hits_dir': daily_hits_directory,
            'begin_date': begin_date,
            'end_date': date,
            'merged_hits_table': merged_hits_by_yandexuid_table,
        },
    )

    with yt_client.Transaction() as transaction:
        if yt_client.exists(merged_hits_by_yandexuid_table):
            yt_client.remove(merged_hits_by_yandexuid_table)

        # Run the query in the same transaction that removed the table (as the
        # sibling flatten_hits / calculate_host_idf do). Otherwise the write
        # happens outside the transaction that holds the lock on this path.
        yql_client.execute(
            query,
            transaction=transaction.transaction_id,
            title='YQL Merge hits by yandexuid',
        )

        # Record the parameters this table was built with (original order kept).
        run_attributes = [
            ('generate_date', date),
            ('begin_date', begin_date),
            ('end_date', date),
            ('last_active_date', last_active_date),
            ('halflife', halflife),
            ('min_hits', min_hits_per_day),
            ('min_hosts', min_hosts),
            ('max_hosts', max_hosts),
        ]
        for attribute_name, attribute_value in run_attributes:
            yt_client.set_attribute(merged_hits_by_yandexuid_table, attribute_name, attribute_value)


def flatten_hits(yt_client, yql_client, date, merged_hits_by_yandexuid_table, flattened_hits_table):
    """Explode per-yandexuid ``site_weights`` dicts into (yandexuid, host, weight) rows.

    The YQL query and the ``generate_date`` attribute update run in one YT
    transaction.

    :param yt_client: YT client that opens the transaction and sets the attribute.
    :param yql_client: client that executes the query inside that transaction.
    :param date: value written to the output's ``generate_date`` attribute.
    :param merged_hits_by_yandexuid_table: input with ``yandexuid`` and ``site_weights``.
    :param flattened_hits_table: output table, overwritten and sorted by host.
    """
    query_text = flatten_hits_query.format(
        input_table=merged_hits_by_yandexuid_table,
        output_table=flattened_hits_table,
        id_type='yandexuid',
    )
    with yt_client.Transaction() as tx:
        yql_client.execute(query_text, transaction=tx.transaction_id)
        yt_client.set_attribute(flattened_hits_table, 'generate_date', date)


def calculate_host_idf(
    yt_client,
    yql_client,
    date,
    merged_hits_by_yandexuid_table,
    flattened_hits_table,
    idf_table,
):
    """Compute per-host IDF from the merged and flattened hits.

    Runs ``calculate_idf_query`` and stamps ``generate_date`` on the result,
    both in the same YT transaction.

    :param yt_client: YT client that opens the transaction and sets the attribute.
    :param yql_client: client that executes the query inside that transaction.
    :param date: value written to the output's ``generate_date`` attribute.
    :param merged_hits_by_yandexuid_table: table whose row count is the total id count.
    :param flattened_hits_table: (yandexuid, host, weight) rows to count per host.
    :param idf_table: output (host, count, idf, weight), overwritten and sorted by host.
    """
    query_text = calculate_idf_query.format(
        merged_hits_by_id_table=merged_hits_by_yandexuid_table,
        flattened_hits_table=flattened_hits_table,
        output_table=idf_table,
    )
    with yt_client.Transaction() as tx:
        yql_client.execute(
            query_text,
            transaction=tx.transaction_id,
            # NOTE(review): codecs are passed explicitly as None; presumably the
            # client then uses its defaults — confirm against the yql_client wrapper.
            erasure_codec=None,
            compression_codec=None,
        )
        yt_client.set_attribute(idf_table, 'generate_date', date)


def get_yandexuid_vectors(
    yt_client,
    date,
    site_vectors_table,
    flattened_hits_table,
    idf_table,
    yandexuid_vectors_table,
):
    """Build one aggregated, IDF-weighted vector per yandexuid.

    All steps run inside a single YT transaction:
      1. recreate ``yandexuid_vectors_table`` with a
         (yandexuid: uint64, vector: string) schema;
      2. join-reduce by ``host`` with ``calculate_user_host_vector_reducer``.
         Site vectors (index 0) and IDF (index 1) are foreign inputs, and the
         flattened hits (index 2) are primary. The reducer depends on this order;
      3. map-reduce by ``yandexuid`` in place with
         ``vector_helpers.sum_vectors_reducer`` as reducer and combiner
         (presumably summing the per-host vectors);
      4. sort by ``yandexuid`` and stamp ``generate_date``.

    :param yt_client: YT client used for every operation.
    :param date: value written to the output's ``generate_date`` attribute.
    :param site_vectors_table: per-host vectors, joinable by ``host``.
    :param flattened_hits_table: (yandexuid, host, weight) rows, joinable by ``host``.
    :param idf_table: per-host IDF, joinable by ``host``.
    :param yandexuid_vectors_table: output table (recreated).
    """
    output_schema = [
        {'name': 'yandexuid', 'type': 'uint64'},
        {'name': 'vector', 'type': 'string'},
    ]
    table_attributes = {
        'schema': output_schema,
        'optimize_for': 'scan',
    }

    with yt_client.Transaction():
        if yt_client.exists(yandexuid_vectors_table):
            yt_client.remove(yandexuid_vectors_table)
        yt_client.create('table', yandexuid_vectors_table, recursive=True, attributes=table_attributes)

        join_inputs = [
            yt_client.TablePath(site_vectors_table, foreign=True),
            yt_client.TablePath(idf_table, foreign=True),
            flattened_hits_table,
        ]
        yt_client.run_join_reduce(
            calculate_user_host_vector_reducer,
            join_inputs,
            yandexuid_vectors_table,
            join_by='host',
        )

        summing_reducer = vector_helpers.sum_vectors_reducer
        yt_client.run_map_reduce(
            None,
            summing_reducer,
            yandexuid_vectors_table,
            yandexuid_vectors_table,
            reduce_by='yandexuid',
            reduce_combiner=summing_reducer,
        )

        yt_client.run_sort(yandexuid_vectors_table, sort_by='yandexuid')
        yt_client.set_attribute(yandexuid_vectors_table, 'generate_date', date)
