#!/usr/bin/env python
# -*- coding: utf-8 -*-
import collections
import functools
import math
import os

import luigi
from yt.wrapper import aggregator, with_context

from crypta.profile.lib import (
    date_helpers,
    vector_helpers,
)
from crypta.profile.utils import utils
from crypta.profile.utils.config import config
from crypta.profile.utils.loggers import TimeTracker
from crypta.profile.utils.luigi_utils import (
    BaseYtTask,
    ExternalInput,
    OldNodesByNameCleaner,
    YtDailyRewritableTarget,
    YtTarget,
)


# YQL template used by ParseDailyAppMetricaData to build one daily table.
#
# Placeholders (filled with str.format, hence the doubled braces around YQL
# lambda bodies): {daily_input_table} is read, {output_table} is overwritten
# (INSERT ... WITH TRUNCATE).
#
# The $get_apps lambda lower-cases every app name, removes duplicates and keeps
# only names whose length is at most 300. Rows are kept when id_type is 'idfa'
# or 'gaid' and the cleaned app list has at most 300 entries; the result is
# ordered by (id, id_type).
prepare_daily_data_query = """
$get_apps = ($app_list) -> {{
    $lower_apps = ListFlatMap($app_list, ($app_name) -> {{ RETURN String::ToLower($app_name); }});

    RETURN ListFilter(
        ListUniq($lower_apps),
        ($app_name) -> {{ RETURN LENGTH($app_name) <= 300; }}
    );
}};

INSERT INTO `{output_table}` WITH TRUNCATE
SELECT
    id,
    id_type,
    $get_apps(apps) AS apps,
    ua_profile,
    model,
    manufacturer,
    region_ids
FROM `{daily_input_table}`
WHERE id_type IN ('idfa', 'gaid') AND ListLength($get_apps(apps)) <= 300
ORDER BY id, id_type
"""

# YQL template used by GetDailyAppByDevid to merge a range of daily tables.
#
# Placeholders (filled with str.format, hence the doubled braces):
#   {daily_app_metrica_directory} - directory holding one table per date;
#   {start_date}, {end_date}      - inclusive bounds passed to RANGE();
#   {app_by_devid_daily_table}    - destination, overwritten WITH TRUNCATE.
#
# Steps performed by the query:
#   1. $all_info reads every table in the range and tags rows with the table
#      name as `date`.
#   2. $active takes the (id, id_type) pairs of the {end_date} table, and
#      $info_about_active inner-joins them with $all_info, so only devices
#      present in the last table are kept.
#   3. Per (id, id_type): apps are united and de-duplicated; ua_profile, model
#      and manufacturer are taken from the row with the greatest `date`
#      (MAX_BY); region_ids dicts are summed key by key through the
#      $aggregate_sum_dicts UDAF.
#   4. main_region_obl is the non-NULL mapped region with the largest summed
#      hit count, where each region id is mapped by $process_region
#      (Geo::RoundRegionById(..., 'region'), kept only when the result's
#      type == 5).
#   5. Only devices with 3 to 300 apps (inclusive) are written, ordered by
#      (id, id_type).
#
# NOTE(review): $sum_dict_create and $sum_dict_add are defined but are not
# passed to AGGREGATION_FACTORY (which uses the $sum_dicts_* lambdas) - they
# look unused; confirm before removing.
merge_app_metrica_data_query = """
$all_info = (
    SELECT
        id,
        id_type,
        apps,
        ua_profile,
        model,
        manufacturer,
        region_ids,
        TableName() AS `date`
    FROM RANGE(`{daily_app_metrica_directory}`, `{start_date}`, `{end_date}`)
);

$active = (
    SELECT id, id_type
    FROM `{daily_app_metrica_directory}/{end_date}`
);

$info_about_active = (
    SELECT *
    FROM $all_info AS all_info
    INNER JOIN $active AS active
    USING (id, id_type)
);

$sum_dict_create = ($item, $parent) -> {{
    RETURN AsDict(($item, 1));
}};

$sum_dicts_create = ($item, $parent) -> {{
    RETURN $item;
}};

$sum_dict = ($dict1, $dict2) -> {{
    -- Merge dicts with sum duplicated key values
    return SetUnion($dict1, $dict2, ($k, $a, $b) -> {{
        return ($a ?? 0) + ($b ?? 0);
    }});
}};

$sum_dict_add = ($state, $item, $parent) -> {{ -- $state
    RETURN $sum_dict($state, $sum_dict_create($item, $parent));
}};

$sum_dicts_add = ($state, $item, $parent) -> {{ -- $state
    RETURN $sum_dict($state, $item);
}};

$sum_dicts_merge = ($state_l, $state_r) -> {{ -- $state
    RETURN $sum_dict($state_l, $state_r);
}};

$aggregate_sum_dicts = AGGREGATION_FACTORY(
    "UDAF",
    $sum_dicts_create,
    $sum_dicts_add,
    $sum_dicts_merge
);

$key_with_max_value = ($dict_items) -> {{
    -- MODE aggregation function equivalent
    -- sort by value desc and take first key, because ListMax doesn't support custom comparator
    return ListSortDesc(
        ListFilter(
            $dict_items,
            ($pair) -> {{RETURN $pair.0 IS NOT NULL;}}
        ),
        ($pair) -> {{RETURN $pair.1;}}
    )[0].0;
}};

$get_main_geo = ($geo_fun, $regions) -> {{
    -- map region with $geo_fun and sum corresponding hits
    $mapped_region_hits = ListMap(
        DictItems(
            ToMultiDict(ListMap(
                DictItems(Unwrap($regions)),
                ($pair) -> {{
                    RETURN ($geo_fun(CAST($pair.0 AS Int32)), $pair.1);
                }}
            ))
    ),
    ($pair) -> {{
            RETURN ($pair.0, ListSum($pair.1));
    }});
    RETURN IF(
        $regions IS NULL,
        NULL,
        $key_with_max_value($mapped_region_hits)
    );
}};

$process_region = ($id) -> {{
    $geo = Geo::RoundRegionById($id, 'region');
    RETURN IF($geo.type == 5, $geo.id);
}};

INSERT INTO `{app_by_devid_daily_table}` WITH TRUNCATE
SELECT
    id,
    id_type,
    apps,
    ua_profile,
    model,
    manufacturer,
    Yson::Serialize(Yson::FromInt64Dict(region_ids)) as region_ids,
    $get_main_geo($process_region, region_ids) AS main_region_obl,
FROM (
    SELECT
        id,
        id_type,
        ListUniq(ListFlatMap(
            AGGREGATE_LIST(apps),
            ($app_name) -> {{ RETURN $app_name; }}
        )) AS apps,
        MAX_BY(ua_profile, `date`) AS ua_profile,
        MAX_BY(model, `date`) AS model,
        MAX_BY(manufacturer, `date`) AS manufacturer,
        AGGREGATE_BY(Yson::ConvertToInt64Dict(region_ids), $aggregate_sum_dicts) as region_ids
    FROM $info_about_active
    GROUP BY id, id_type
)
WHERE ListLength(apps) >= 3 AND ListLength(apps) <= 300
ORDER BY id, id_type
"""


class ParseDailyAppMetricaData(BaseYtTask):
    """Build the daily app-metrica table for a single date.

    Runs ``prepare_daily_data_query`` over the external mobile device info
    table for ``date`` and stores the result under
    ``<APP_METRICA_DATA_DIRECTORY>/daily/<date>``.
    """

    date = luigi.Parameter()
    priority = 100
    task_group = 'export_profiles'

    def requires(self):
        """Return the external input table whose path is derived from ``date``."""
        source_path = os.path.join(config.MOBILE_DEV_INFO.format(self.date))
        return ExternalInput(source_path)

    def output(self):
        """Return the YT target of the per-date table inside the ``daily`` directory."""
        target_path = os.path.join(config.APP_METRICA_DATA_DIRECTORY, 'daily', self.date)
        return YtTarget(target_path)

    def run(self):
        """Execute the YQL query in a YT transaction and stamp the result with ``generate_date``."""
        with TimeTracker(monitoring_name=self.__class__.__name__):
            with self.yt.Transaction() as tx:
                query_text = prepare_daily_data_query.format(
                    daily_input_table=self.input().table,
                    output_table=self.output().table,
                )
                self.yql.query(query_text, transaction=tx)

                # Record which date the table was generated for.
                self.yt.set_attribute(self.output().table, 'generate_date', self.date)


def add_categories(record, app_to_category):
    """Attach per-category counts to a record and yield it.

    For every app in ``record['apps']`` that is a key of ``app_to_category``,
    the mapped value is fed to ``collections.Counter.update``. The resulting
    counter is stored in ``record['categories']`` (the record is modified in
    place) and the same record object is yielded exactly once.

    :param record: row dict with an iterable ``'apps'`` field.
    :param app_to_category: mapping from app name to its categories.
    """
    known_categories = (
        app_to_category[app_name]
        for app_name in record['apps']
        if app_name in app_to_category
    )

    category_counts = collections.Counter()
    for categories in known_categories:
        category_counts.update(categories)

    record['categories'] = category_counts
    yield record


class GetDailyAppByDevid(BaseYtTask):
    """Merge several daily app-metrica tables into one per-device table with categories.

    The merged window ends at ``end_date`` (obtained from
    ``date_helpers.get_date_from_past(date, 2)``) and spans
    ``config.N_DAYS_TO_AGGREGATE_APP_DATA`` back dates.
    """

    date = luigi.Parameter()
    priority = 100
    task_group = 'export_profiles'

    def __init__(self, date):
        super(GetDailyAppByDevid, self).__init__(date)
        # Window description shared by requires() and run().
        self.n_days = config.N_DAYS_TO_AGGREGATE_APP_DATA
        self.end_date = date_helpers.get_date_from_past(self.date, 2)
        self.date_range = date_helpers.generate_back_dates(self.end_date, self.n_days)
        self.start_date = min(self.date_range)

    def requires(self):
        """Return the daily parse tasks, the app categories table and the cleaner of old daily tables."""
        daily_parsers = [ParseDailyAppMetricaData(day) for day in self.date_range]
        categories_input = ExternalInput(config.MOBILE_APP_CATEGORIES)
        old_tables_cleaner = OldNodesByNameCleaner(
            date=self.end_date,
            folder=os.path.join(config.APP_METRICA_DATA_DIRECTORY, 'daily'),
            lifetime=self.n_days,
        )
        return {
            'parsed_app_metrica_data': daily_parsers,
            'categories_data': categories_input,
            'cleaner': old_tables_cleaner,
        }

    def output(self):
        """Return the rewritable target for ``config.APP_BY_DEVID_DAILY_TABLE`` at ``date``."""
        return YtDailyRewritableTarget(config.APP_BY_DEVID_DAILY_TABLE, self.date)

    def run(self):
        """Merge the daily tables with YQL, add categories with a map job, then sort and stamp the output."""
        with TimeTracker(monitoring_name=self.__class__.__name__):
            with self.yt.Transaction() as tx, self.yt.TempTable() as merged_table:
                # Step 1: aggregate the daily tables into a temporary table.
                merge_query = merge_app_metrica_data_query.format(
                    daily_app_metrica_directory=os.path.join(config.APP_METRICA_DATA_DIRECTORY, 'daily'),
                    start_date=self.start_date,
                    end_date=self.end_date,
                    app_by_devid_daily_table=merged_table,
                )
                self.yql.query(merge_query, transaction=tx)

                # Step 2: load the bundleId -> raw_categories mapping into memory.
                categories_table = self.input()['categories_data'].table
                app_categories = {
                    row['bundleId']: row['raw_categories']
                    for row in self.yt.read_table(categories_table)
                }

                # Step 3: write rows enriched with category counts into the output table.
                result_table = self.output().table
                self.yt.create_empty_table(result_table, schema=utils.daily_app_metrica_schema)

                categories_mapper = functools.partial(add_categories, app_to_category=app_categories)
                self.yt.run_map(
                    categories_mapper,
                    source_table=merged_table,
                    destination_table=result_table,
                    memory_limit=4 * 1024 ** 3,
                    spec={'title': 'Computing categories'},
                )

                self.yt.run_sort(result_table, sort_by=['id', 'id_type'])

                self.yt.set_attribute(result_table, 'generate_date', self.date)


def flatten_apps_mapper(row):
    """Yield one ``{'id', 'id_type', 'app'}`` row for every entry of ``row['apps']``.

    :param row: dict with ``'id'``, ``'id_type'`` and an iterable ``'apps'``.
    """
    for app_name in row['apps']:
        flat_row = dict(id=row['id'], id_type=row['id_type'], app=app_name)
        yield flat_row


@with_context
def feature_making_reducer(key, rows, context):
    """Attach the vector seen in input table 0 to every following row of input table 1.

    Rows are told apart by ``context.table_index``:

    * index 0 - the row's ``'vector'`` field is decoded with
      ``vector_helpers.binary_to_numpy`` and remembered;
    * index 1 - if a vector has been remembered, a row with ``'id'``,
      ``'id_type'`` and the vector serialized via ``tostring()`` is yielded.

    Rows from any other table, and table-1 rows that arrive before a vector was
    seen, produce no output.
    """
    app_vector = None

    for row in rows:
        source_index = context.table_index

        if source_index == 0:
            app_vector = vector_helpers.binary_to_numpy(row['vector'])
            continue

        if source_index != 1 or app_vector is None:
            continue

        yield dict(
            id=row['id'],
            id_type=row['id_type'],
            vector=app_vector.tostring(),
        )


class GetDailyDevidVectors(BaseYtTask):
    """Build a per-device vector table by summing the vectors of the device's apps."""

    date = luigi.Parameter()
    task_group = 'export_profiles'

    def requires(self):
        """Return the daily app-by-device task and the external app2vec table."""
        daily_apps_task = GetDailyAppByDevid(self.date)
        app_vectors_input = ExternalInput(config.APP2VEC_VECTORS_TABLE)
        return {
            'daily_app_by_devid': daily_apps_task,
            'app2vec': app_vectors_input,
        }

    def output(self):
        """Return the rewritable target for ``config.DAILY_DEVID2VEC`` at ``date``."""
        return YtDailyRewritableTarget(config.DAILY_DEVID2VEC, self.date)

    def run(self):
        """Flatten apps, join them with app vectors, sum the vectors per device, then sort and stamp."""
        with TimeTracker(monitoring_name=self.__class__.__name__):
            with self.yt.Transaction(), \
                    self.yt.TempTable() as flat_apps_table, \
                    self.yt.TempTable() as app_vectors_table:
                devid_vectors_table = self.output().table
                self.yt.create_empty_table(devid_vectors_table, schema=utils.daily_devid_vector_schema)

                # One row per (device, app), sorted by app for the join below.
                self.yt.run_map(
                    flatten_apps_mapper,
                    self.input()['daily_app_by_devid'].table,
                    flat_apps_table,
                )
                self.yt.run_sort(flat_apps_table, sort_by='app')

                # The app2vec table is passed as the foreign table and comes first (table index 0).
                join_sources = [
                    self.yt.TablePath(self.input()['app2vec'].table, foreign=True),
                    flat_apps_table,
                ]
                self.yt.run_join_reduce(
                    feature_making_reducer,
                    join_sources,
                    app_vectors_table,
                    join_by='app',
                    spec={'title': 'Building devid2vec'},
                )

                # The same reducer is used both as the combiner and as the final reducer.
                self.yt.run_map_reduce(
                    None,
                    vector_helpers.sum_vectors_reducer,
                    app_vectors_table,
                    devid_vectors_table,
                    reduce_by=['id', 'id_type'],
                    spec={'title': 'Summarizing devid vectors'},
                    reduce_combiner=vector_helpers.sum_vectors_reducer,
                )

                self.yt.run_sort(devid_vectors_table, sort_by=['id', 'id_type'])

                self.yt.set_attribute(devid_vectors_table, 'generate_date', self.date)


@with_context
class UpdateMonthlyStorageReducer(object):
    """Reducer that merges a fresh daily row into the accumulated storage row.

    For each key the reducer receives rows from two inputs, told apart by
    ``context.table_index``: index 0 is the latest daily table, any other
    index is the existing storage. The output row carries a ``'days_active'``
    list restricted to dates not earlier than ``start_date``; keys whose list
    becomes empty are dropped.
    """

    def __init__(self, start_date, end_date):
        """
        :param start_date: oldest date that may stay in ``days_active``.
        :param end_date: date appended to ``days_active`` for keys present in
            the daily table.
        """
        self.start_date = start_date
        self.end_date = end_date

    def __call__(self, key, rows, context):
        """Yield at most one updated row for ``key``.

        :param key: reduce key (not used directly).
        :param rows: iterable of rows of this key from both inputs.
        :param context: object exposing ``table_index`` of the current row.
        """
        last_day_active_row = None
        storage_row = None

        # If an input has several rows for the key, the last one wins.
        for row in rows:
            if context.table_index == 0:
                last_day_active_row = row
            else:
                storage_row = row

        if last_day_active_row:
            # The fresh row's fields take precedence; only the activity
            # history is carried over from the storage.
            updated_row = last_day_active_row
            if storage_row:
                updated_row['days_active'] = storage_row['days_active'] + [self.end_date]
            else:
                updated_row['days_active'] = [self.end_date]
        else:
            updated_row = storage_row

        if updated_row is None:
            # Empty group: nothing to emit.
            return

        # A list comprehension (rather than filter()) keeps the value a real
        # list on every Python version, so it can be measured and serialized.
        days_in_window = [
            cur_date
            for cur_date in updated_row['days_active']
            if cur_date >= self.start_date
        ]
        updated_row['days_active'] = days_in_window
        if days_in_window:
            yield updated_row


class UpdateMonthlyDevidStorage(BaseYtTask):
    """Merge the latest daily table into the accumulated storage table.

    ``data_source`` selects which pair of tables is processed: ``'vector'``
    (device vectors) or ``'app'`` (apps by device).
    """

    date = luigi.Parameter()
    data_source = luigi.Parameter()
    task_group = 'export_profiles'

    def requires(self):
        """Return the daily task matching ``data_source``."""
        daily_task_by_source = {
            'vector': GetDailyDevidVectors(self.date),
            'app': GetDailyAppByDevid(self.date),
        }
        return daily_task_by_source[self.data_source]

    def output(self):
        """Return the rewritable storage target matching ``data_source``."""
        target_by_source = {
            'vector': YtDailyRewritableTarget(config.MONTHLY_DEVID2VEC, self.date),
            'app': YtDailyRewritableTarget(config.APP_BY_DEVID_MONTHLY_TABLE, self.date),
        }
        return target_by_source[self.data_source]

    def run(self):
        """Reduce the daily table together with the storage table back into the storage table."""
        schema_by_data_source = {
            'vector': utils.monthly_devid_vector_schema,
            'app': utils.monthly_app_metrica_schema,
        }

        with TimeTracker(monitoring_name=self.__class__.__name__):
            with self.yt.Transaction():
                storage_table = self.output().table

                # First run: the storage does not exist yet, so start from an empty table.
                if not self.yt.exists(storage_table):
                    self.yt.create_empty_table(
                        storage_table,
                        schema=schema_by_data_source[self.data_source],
                    )

                self.yt.sort_if_needed(storage_table, sort_by=['id', 'id_type'])

                end_date = date_helpers.get_date_from_past(self.date, 2)
                start_date = date_helpers.get_date_from_past(end_date, config.N_DAYS_TO_AGGREGATE_APP_DATA)

                # The daily table goes first so the reducer sees it as table index 0.
                storage_reducer = UpdateMonthlyStorageReducer(start_date=start_date, end_date=end_date)
                self.yt.run_reduce(
                    storage_reducer,
                    [self.input().table, storage_table],
                    storage_table,
                    reduce_by=['id', 'id_type'],
                )

                self.yt.run_sort(storage_table, sort_by=['id', 'id_type'])

                self.yt.set_attribute(storage_table, 'generate_date', self.date)


# YQL template used by GetMonthlyDevidByApp.
#
# Placeholders (filled with str.format): {input_table} is read and
# {output_table} is overwritten (INSERT ... WITH TRUNCATE).
#
# The `apps` column is converted from Yson to a list of strings and flattened,
# producing one (id, id_type, app) row per list element, ordered by app.
flatten_query = """
INSERT INTO `{output_table}` WITH TRUNCATE
SELECT id, id_type, app
FROM (
    SELECT id, id_type, Yson::ConvertToStringList(apps) AS apps
    FROM `{input_table}`
)
FLATTEN BY apps AS app
ORDER BY app
"""


class GetMonthlyDevidByApp(BaseYtTask):
    """Flatten the accumulated apps-by-device table into one row per (device, app)."""

    date = luigi.Parameter()
    task_group = 'export_profiles'

    def requires(self):
        """Return the storage update task for the ``'app'`` data source."""
        return UpdateMonthlyDevidStorage(self.date, 'app')

    def output(self):
        """Return the rewritable target for ``config.DEVID_BY_APP_MONTHLY_TABLE`` at ``date``."""
        return YtDailyRewritableTarget(table=config.DEVID_BY_APP_MONTHLY_TABLE, date=self.date)

    def run(self):
        """Run ``flatten_query`` inside a YT transaction and stamp the result with ``generate_date``."""
        with TimeTracker(monitoring_name=self.__class__.__name__):
            with self.yt.Transaction() as tx:
                query_text = flatten_query.format(
                    input_table=self.input().table,
                    output_table=self.output().table,
                )
                self.yql.query(query_text, transaction=tx)

                self.yt.set_attribute(self.output().table, 'generate_date', self.date)


@aggregator
def apps_idf_mapper(rows):
    """Count, over all given rows, how many times each app occurs in ``row['apps']``.

    The function consumes the whole ``rows`` iterable before emitting
    anything, then yields one ``{'app': <name>, 'count': <occurrences>}`` row
    per distinct app.

    :param rows: iterable of row dicts with an iterable ``'apps'`` field.
    """
    counter = collections.Counter(
        app
        for row in rows
        for app in row['apps']
    )

    # items() exists on both Python 2 and Python 3 (iteritems() is Python 2 only).
    for app, app_cnt in counter.items():
        yield {
            'app': app,
            'count': app_cnt,
        }


def apps_idf_reducer(key, rows, total_devids):
    """Sum the partial counts of one app and compute its IDF.

    Yields a single row with the app name, the summed ``count`` and
    ``idf = log(total_devids / count)`` (natural logarithm).

    :param key: reduce key holding the ``'app'`` name.
    :param rows: iterable of rows with a numeric ``'count'`` field.
    :param total_devids: total number of devices used as the IDF numerator.
    """
    devids_with_app = 0
    for row in rows:
        devids_with_app += row['count']

    app_name = key['app']
    inverse_frequency = 1.0 * total_devids / devids_with_app

    yield dict(
        app=app_name,
        count=devids_with_app,
        idf=math.log(inverse_frequency),
    )


class GetAppIdf(BaseYtTask):
    """Compute per-app device counts and IDF over the accumulated apps-by-device table."""

    date = luigi.Parameter()
    task_group = 'export_profiles'

    def requires(self):
        """Return the storage update task for the ``'app'`` data source."""
        return UpdateMonthlyDevidStorage(self.date, 'app')

    def output(self):
        """Return the rewritable target for ``config.APP_IDF_TABLE`` at ``date``."""
        return YtDailyRewritableTarget(table=config.APP_IDF_TABLE, date=self.date)

    def run(self):
        """Create the output table, run the counting map-reduce, then sort and stamp the result."""
        with TimeTracker(monitoring_name=self.__class__.__name__):
            with self.yt.Transaction():
                idf_table = self.output().table
                idf_schema = {
                    'idf': 'double',
                    'count': 'uint64',
                    'app': 'string',
                }
                self.yt.create_empty_table(idf_table, schema=idf_schema)

                # The number of rows in the input table is the IDF numerator.
                devices_table = self.input().table
                total_devids = self.yt.row_count(devices_table)

                idf_reducer = functools.partial(apps_idf_reducer, total_devids=total_devids)
                self.yt.run_map_reduce(
                    apps_idf_mapper,
                    idf_reducer,
                    devices_table,
                    idf_table,
                    reduce_by='app',
                )

                self.yt.run_sort(idf_table, sort_by='app')

                self.yt.set_attribute(idf_table, 'generate_date', self.date)
