#!/usr/bin/env python
# -*- coding: utf-8 -*-
import luigi
from yt.wrapper import (
    create_table_switch,
    with_context,
)
from crypta.profile.lib import (
    date_helpers,
    vector_helpers,
)
from crypta.profile.user_vectors import lib as user_vectors
from crypta.profile.utils import utils
from crypta.profile.utils.config import config
from crypta.profile.utils.loggers import send_to_graphite
from crypta.profile.utils.luigi_utils import BaseYtTask, ExternalInput, YtDailyRewritableTarget

from crypta.profile.tasks.features.calculate_host_idf import CalculateHostIdf
from crypta.profile.tasks.features.flatten_hits import FlattenHitsBySite
from crypta.profile.tasks.features.get_app_metrica_data import UpdateMonthlyDevidStorage


# Output table indexes used by sum_vectors_mapper with create_table_switch.
# They must match the order of `output_tables` passed to run_map in
# GetCryptaIdVectors.run: index 0 is the 'vectors_with_yuid' table,
# index 1 is the 'all_vectors' table.
VECTORS_WITH_YUID_TABLE_IDX = 0
ALL_VECTORS_TABLE_IDX = 1


def sum_vectors_mapper(row):
    """Sum all vectors of one crypta_id row and route the result to output tables.

    Reads 'crypta_id', 'days_active', 'vectors' and 'yuids_num' from `row`.
    Every input row is emitted to the table with index ALL_VECTORS_TABLE_IDX;
    rows with `yuids_num` > 0 are additionally emitted to the table with index
    VECTORS_WITH_YUID_TABLE_IDX.

    NOTE(review): `row['vectors']` is expected to be non-empty; with an empty
    list `sum` returns the int 0, which has no `tobytes` method.
    """
    # Each serialized vector is decoded by vector_helpers.binary_to_numpy
    # (presumably into a numpy array -- confirm) and the decoded values are
    # added together with the builtin sum().
    summed_vector = sum(vector_helpers.binary_to_numpy(vector) for vector in row['vectors'])

    output = {
        'crypta_id': row['crypta_id'],
        'days_active': row['days_active'],
        # tobytes() returns the same bytes as the old tostring() alias, which
        # is deprecated and removed in NumPy 2.0.
        'vector': summed_vector.tobytes(),
    }

    yield create_table_switch(ALL_VECTORS_TABLE_IDX)
    yield output

    if row['yuids_num'] > 0:
        yield create_table_switch(VECTORS_WITH_YUID_TABLE_IDX)
        yield output


class GetDailyYandexuidVectors(BaseYtTask):
    """Luigi task that builds the daily yandexuid vector table.

    Inputs are the site2vec host vectors table, the flattened 'metrics' hits
    and the 'metrics' host idf table. The result is written to
    config.DAILY_YANDEXUID2VEC with columns `yandexuid` and `vector`.
    """
    date = luigi.Parameter()
    priority = 100
    task_group = 'export_profiles'

    def requires(self):
        """Return the upstream tasks keyed by the role they play in `run`."""
        metrics_params = {'data_source': 'metrics', 'id_type': 'yandexuid'}
        return {
            'host_vectors': ExternalInput(config.SITE2VEC_VECTORS_TABLE),
            'metrics_flattened_hits': FlattenHitsBySite(self.date, **metrics_params),
            'metrics_host_idf': CalculateHostIdf(self.date, **metrics_params),
        }

    def output(self):
        """Return the daily rewritable target for the yandexuid vectors."""
        return YtDailyRewritableTarget(config.DAILY_YANDEXUID2VEC, self.date)

    def run(self):
        """Build, aggregate and sort the table inside a single YT transaction."""
        with self.yt.Transaction():
            result_table = self.output().table
            inputs = self.input()

            result_schema = {
                'yandexuid': 'uint64',
                'vector': 'string',
            }
            self.yt.create_empty_table(result_table, schema=result_schema)

            # Hits are the primary table; host vectors and host idf are
            # attached as foreign tables joined by 'host'.
            join_sources = [
                self.yt.TablePath(inputs['host_vectors'].table, foreign=True),
                self.yt.TablePath(inputs['metrics_host_idf'].table, foreign=True),
                inputs['metrics_flattened_hits'].table,
            ]
            self.yt.run_join_reduce(
                user_vectors.calculate_user_host_vector_reducer,
                join_sources,
                result_table,
                join_by='host',
                spec={'title': 'Building yandexuid2vec'},
            )

            # Collapse the rows of each yandexuid in place; the same reducer
            # is used as the reduce combiner.
            self.yt.run_map_reduce(
                None,
                vector_helpers.sum_vectors_reducer,
                result_table,
                result_table,
                reduce_by='yandexuid',
                reduce_combiner=vector_helpers.sum_vectors_reducer,
                spec={'title': 'Summarizing vectors'},
            )

            self.yt.run_sort(result_table, sort_by='yandexuid')

            # Report the resulting number of rows.
            rows_total = self.yt.row_count(result_table)
            send_to_graphite('yandexuid2vec', rows_total)

            self.yt.set_attribute(result_table, 'generate_date', self.date)


@with_context
class MonthlyVectorsReducer(object):
    """Reducer that merges a daily record into the monthly storage record.

    Records coming from input table 0 are treated as daily records and
    records from input table 1 as monthly records. The yielded record carries
    a `days_active` list restricted to the retention window and, when a daily
    record is present, extended with `date`.
    """

    def __init__(self, date, number_of_days_to_keep_vectors):
        """Remember `date` and compute the oldest date that is still kept.

        The threshold is date_helpers.get_date_from_past(date, number_of_days_to_keep_vectors - 1).
        """
        self.date = date
        days_back = number_of_days_to_keep_vectors - 1
        self.last_date_to_keep = date_helpers.get_date_from_past(self.date, days_back)

    def remove_old_dates(self, dates):
        """Return the tail of `dates` starting at the first kept date.

        The tail starts at the first element that is >= `last_date_to_keep`;
        an empty list is returned when there is no such element.
        """
        first_kept = next(
            (position for position, day in enumerate(dates) if day >= self.last_date_to_keep),
            None,
        )
        if first_kept is None:
            return []
        return dates[first_kept:]

    def __call__(self, key, records, context):
        """Yield at most one merged record for the given key."""
        # Keep the last record seen for each input table index.
        last_record_by_table = {}
        for record in records:
            last_record_by_table[context.table_index] = record

        daily = last_record_by_table.get(0)
        monthly = last_record_by_table.get(1)

        if not monthly:
            # Only a daily record (or nothing at all) for this key.
            if daily:
                daily['days_active'] = [self.date]
                yield daily
            return

        monthly['days_active'] = self.remove_old_dates(monthly['days_active'])

        if not daily:
            # No activity today: keep the monthly record while it still has dates.
            if monthly['days_active']:
                yield monthly
            return

        active_days = monthly['days_active']
        daily['days_active'] = active_days
        # do not append same date multiple times
        if not active_days or active_days[-1] < self.date:
            active_days.append(self.date)
        yield daily


class GetMonthlyYandexuidVectors(BaseYtTask):
    """Luigi task that folds the daily yandexuid vectors into the monthly table.

    The monthly table (config.MONTHLY_YANDEXUID2VEC) is both an input and the
    output of the reduce: it is merged with the daily table by `yandexuid`
    using MonthlyVectorsReducer.
    """
    date = luigi.Parameter()
    priority = 100
    task_group = 'export_profiles'

    def requires(self):
        """Depend on the daily yandexuid vectors for the same date."""
        return GetDailyYandexuidVectors(self.date)

    def output(self):
        """Return the daily rewritable target for the monthly table."""
        return YtDailyRewritableTarget(config.MONTHLY_YANDEXUID2VEC, self.date)

    def run(self):
        """Update the monthly table inside a single YT transaction."""
        with self.yt.Transaction():
            monthly_table = self.output().table

            # Create the monthly table if it does not exist yet.
            if not self.yt.exists(monthly_table):
                self.yt.create_empty_table(
                    monthly_table,
                    schema=utils.monthly_yandexuid_vector_schema,
                )

            # The order matters: MonthlyVectorsReducer treats table 0 as
            # daily records and table 1 as monthly records.
            daily_table = self.input().table
            tables_to_merge = [daily_table, monthly_table]
            for merged_table in tables_to_merge:
                self.yt.sort_if_needed(merged_table, sort_by='yandexuid')

            # The reducer is given the day before the task date.
            reducer = MonthlyVectorsReducer(
                date_helpers.get_yesterday(self.date),
                config.STANDARD_AGGREGATION_PERIOD,
            )
            self.yt.run_reduce(
                reducer,
                tables_to_merge,
                monthly_table,
                reduce_by='yandexuid',
            )

            self.yt.run_sort(monthly_table, sort_by='yandexuid')
            self.yt.set_attribute(monthly_table, 'generate_date', self.date)


# YQL template used by GetCryptaIdVectors.run; it is filled in with str.format.
# Placeholders: yandexuid_vectors_table, devid_vectors_table,
# vertices_no_multi_profile_table, intermediate_table.
#
# The query unions yandexuid vectors (yandexuid cast to a string id with
# id_type 'yandexuid') with devid vectors, inner-joins them with the vertices
# table on (id, id_type) and groups by cryptaId. For every crypta_id it writes
# the list of vectors, the de-duplicated list of active days and the number of
# joined yandexuid rows (yuids_num) to the intermediate table, truncating it.
join_vectors_with_crypta_id_query = """
$vectors = (
    SELECT
        CAST(yandexuid AS String) AS id,
        'yandexuid' AS id_type,
        vector,
        Yson::ConvertToStringList(days_active) AS days_active,
    FROM `{yandexuid_vectors_table}`
UNION ALL
    SELECT
        id,
        id_type,
        vector,
        Yson::ConvertToStringList(days_active) AS days_active,
    FROM `{devid_vectors_table}`
);

$grouped_by_crypta_id = (
    SELECT
        SUM(CAST(vectors.id_type == 'yandexuid' AS Uint64)) AS yuids_num,
        AGGREGATE_LIST(vectors.vector) AS vectors,
        ListUniq(ListFlatMap(AGGREGATE_LIST(vectors.days_active), ($x) -> ($x))) AS days_active,
        CAST(vertices_no_multi_profile.cryptaId AS Uint64) AS crypta_id,
    FROM $vectors AS vectors
    INNER JOIN `{vertices_no_multi_profile_table}` AS vertices_no_multi_profile
    USING (id, id_type)
    GROUP BY vertices_no_multi_profile.cryptaId
);

INSERT INTO `{intermediate_table}` WITH TRUNCATE
SELECT
    crypta_id,
    days_active,
    vectors,
    yuids_num,
FROM $grouped_by_crypta_id;
"""


class GetCryptaIdVectors(BaseYtTask):
    """Luigi task that aggregates yandexuid and devid vectors per crypta_id.

    Produces two tables: config.MONTHLY_CRYPTAID2VEC (only crypta_ids that
    have at least one yandexuid) and config.ALL_MONTHLY_CRYPTAID2VEC (all
    crypta_ids).
    """
    date = luigi.Parameter()
    priority = 90
    task_group = 'export_profiles'

    def requires(self):
        """Depend on the monthly yandexuid vectors and the devid vector storage."""
        return {
            'yandexuid_vectors': GetMonthlyYandexuidVectors(self.date),
            'devid_vectors': UpdateMonthlyDevidStorage(self.date, 'vector'),
        }

    def output(self):
        """Return both result targets keyed by their role."""
        return {
            'vectors_with_yuid': YtDailyRewritableTarget(config.MONTHLY_CRYPTAID2VEC, date=self.date),
            'all_vectors': YtDailyRewritableTarget(config.ALL_MONTHLY_CRYPTAID2VEC, date=self.date),
        }

    def run(self):
        """Group vectors by crypta_id with YQL, then sum and split them with a map."""
        with self.yt.Transaction() as transaction:
            with self.yt.TempTable() as intermediate_table:
                inputs = self.input()
                query_text = join_vectors_with_crypta_id_query.format(
                    yandexuid_vectors_table=inputs['yandexuid_vectors'].table,
                    devid_vectors_table=inputs['devid_vectors'].table,
                    vertices_no_multi_profile_table=config.VERTICES_NO_MULTI_PROFILE,
                    intermediate_table=intermediate_table,
                )
                self.yql.query(query_text, transaction=transaction)

                # The order must match VECTORS_WITH_YUID_TABLE_IDX (0) and
                # ALL_VECTORS_TABLE_IDX (1) used by sum_vectors_mapper.
                outputs = self.output()
                output_tables = [
                    outputs['vectors_with_yuid'].table,
                    outputs['all_vectors'].table,
                ]

                for result_table in output_tables:
                    self.yt.create_empty_table(
                        result_table,
                        schema={
                            'crypta_id': 'uint64',
                            'days_active': 'any',
                            'vector': 'string',
                        },
                    )

                self.yt.run_map(sum_vectors_mapper, intermediate_table, output_tables)

                for result_table in output_tables:
                    self.yt.run_sort(result_table, sort_by='crypta_id')
                    self.yt.set_attribute(result_table, 'generate_date', self.date)
