# -*- coding: utf-8 -*-
from collections import Counter
import datetime
import os

from nile.api.v1 import (
    aggregators as na,
    Record,
)
from passport.backend.profile import get_cluster
from passport.backend.profile.aggregators import (
    histogram_by_list,
    histogram_with_top,
)
from passport.backend.profile.utils.helpers import (
    cut_host,
    date_to_integer_unixtime,
    merge_records,
    to_date_str,
)
from passport.backend.profile.utils.yt import (
    check_table_attribute_exist,
    get_yt,
    set_table_attribute,
)
from qb2.api.v1 import (
    extractors as se,
    filters as sf,
)
from retrying import retry
import yenv
from yt.wrapper.errors import YtIncorrectResponse


# Limits passed as the second argument of histogram_with_top() for the
# corresponding fields (presumably the number of most frequent values kept
# per histogram -- confirm in passport.backend.profile.aggregators).
IP_COUNT = 50  # 'ip' / 'user_ip'
GEO_COUNT = 15  # 'country_id' and 'city_id'
OS_COUNT = 15  # 'os_family' and 'os_name'
BROWSER_COUNT = 15  # 'browser', 'browser_name' and 'browser_os'
YANDEXUID_COUNT = 50  # 'yandexuid'
RETPATH_REFERER_COUNT = 20  # 'retpath_host_*' and 'referer_host_*'
AM_VERSION_COUNT = 15  # 'am_version_truncated'
CLOUD_TOKEN_COUNT = 15  # 'cloud_token'
DEVICE_ID_COUNT = 15  # 'device_id'

# Attribute checked on the resulting profile table before the job starts and
# set to True after it finishes (see create_userprofiles).
DAILY_JOB_STATUS_ATTRIBUTE = 'profile_daily_job_finished'


class UserProfileJob(object):
    def __init__(self, config, date):
        """
        Prepare the job that builds user profiles for ``date``.

        :param config: configuration mapping; its ``'yt'`` section supplies
            ``parallel_operations_limit`` and the optional ``spec``.
        :param date: target date; stored on the instance and used in the job name.
        """
        self.config = config
        self.yt = get_yt(config=config)
        self.date = date

        environment = get_cluster(config).env(
            default_memory_limit=1024,  # in megabytes
            parallel_operations_limit=int(config['yt']['parallel_operations_limit']),
            yt_spec_defaults=config['yt'].get('spec', {}),
        )
        job_name = 'create-userprofiles-' + to_date_str(date)
        self.job = environment.job(name=job_name)

        # Intermediate tables queued by timedelta_factors() and merged by _join_tables().
        self.tmp_tables = []

    def _calculate_passport_auth_frequencies_table(self, date_start, date_end, suffix):
        """
        Describe the per-uid aggregation of Passport auth logs for one date window.

        Tables addressed as ``{date_start..date_end}`` under ``passport_dataset_dir`` are
        left-joined by ``uid`` with the glogout table for ``self.date``. Rows whose
        ``unixtime`` is earlier than a non-empty ``glogout_unixtime`` are filtered out,
        and the rest is grouped by ``uid`` and aggregated.

        :param date_start: first date string of the table range.
        :param date_end: last date string of the table range.
        :param suffix: substituted into every output column name (e.g. ``'1d'``).
        :return: the aggregated stream.
        """
        yt_config = self.config['yt']
        auth_log = self.job.table(
            os.path.join(yt_config['passport_dataset_dir'], '{%s..%s}' % (date_start, date_end)),
        )
        glogouts = self.job.table(
            os.path.join(yt_config['passport_glogouts_dataset_dir'], to_date_str(self.date)),
        )
        auths = auth_log.join(glogouts, type='left', by='uid', assume_unique_right=True)

        # (output column template, source field, limit) for histogram_with_top.
        top_histograms = (
            ('ip_freq_%s', 'ip', IP_COUNT),
            ('country_freq_%s', 'country_id', GEO_COUNT),
            ('city_freq_%s', 'city_id', GEO_COUNT),
            ('os_family_freq_%s', 'os_family', OS_COUNT),
            ('os_name_freq_%s', 'os_name', OS_COUNT),
            ('browser_freq_%s', 'browser', BROWSER_COUNT),
            ('browser_name_freq_%s', 'browser_name', BROWSER_COUNT),
            ('browser_os_freq_%s', 'browser_os', BROWSER_COUNT),
            ('yandexuid_freq_%s', 'yandexuid', YANDEXUID_COUNT),
            ('retpath_host_2_freq_%s', 'retpath_host_2', RETPATH_REFERER_COUNT),
            ('retpath_host_3_freq_%s', 'retpath_host_3', RETPATH_REFERER_COUNT),
            ('referer_host_2_freq_%s', 'referer_host_2', RETPATH_REFERER_COUNT),
            ('referer_host_3_freq_%s', 'referer_host_3', RETPATH_REFERER_COUNT),
        )
        # (output column template, source field) for na.histogram.
        plain_histograms = (
            ('day_part_freq_%s', 'day_part'),
            ('weekday_freq_%s', 'weekday'),
            ('is_mobile_freq_%s', 'is_mobile'),
        )

        aggregators = {
            'succ_auth_count_%s' % suffix: na.count(),
            'captcha_passed_%s' % suffix: na.sum('captcha_passed'),
            'as_list_freq_%s' % suffix: histogram_by_list('as_list'),
        }
        for column, field, limit in top_histograms:
            aggregators[column % suffix] = histogram_with_top(field, limit)
        for column, field in plain_histograms:
            aggregators[column % suffix] = na.histogram(field)

        fields = [
            se.log_fields(
                'ip',
                'as_list',
                'browser_name',
                'browser_version',
                'os_family',
                'os_name',
                'yandexuid',
                'retpath_host',
                'referer_host',
            ).allow_override(),
            se.integer_log_fields(
                'uid',
                'captcha_passed',
                'country_id',
                'city_id',
                'day_part',
                'weekday',
                'is_mobile',
            ).allow_override(),
            se.custom(
                'browser',
                lambda browser_name, browser_version: str(browser_name) + " " + str(browser_version),
                'browser_name',
                'browser_version',
            ),
            se.custom(
                'browser_os',
                lambda browser, os_name: str(browser) + " - " + str(os_name),
                'browser',
                'os_name',
            ),
            # Needed only by the filter below, hence hidden from the output.
            se.integer_log_fields('unixtime', 'glogout_unixtime').hide(),

            se.custom('retpath_host_2', lambda host: cut_host(host, 2), 'retpath_host'),
            se.custom('retpath_host_3', lambda host: cut_host(host, 3), 'retpath_host'),

            se.custom('referer_host_2', lambda host: cut_host(host, 2), 'referer_host'),
            se.custom('referer_host_3', lambda host: cut_host(host, 3), 'referer_host'),
        ]
        # Keep a row when there is no glogout or the event is not older than it.
        filters = [
            sf.custom(lambda unixtime, glogout_unixtime: not glogout_unixtime or unixtime >= glogout_unixtime),
        ]

        parsed = auths.qb2(
            log='passport-log',
            fields=fields,
            filters=filters,
            intensity='cpu',
        )
        return parsed.groupby('uid').aggregate(**aggregators)

    def _calculate_blackbox_ses_update_frequencies_table(self, date_start, date_end, suffix):
        """
        Describe the per-uid aggregation of Blackbox events for one date window.

        Tables addressed as ``{date_start..date_end}`` under ``blackbox_dataset_dir`` are
        left-joined by ``uid`` with the glogout table for ``self.date``. Rows whose
        ``unixtime`` is earlier than a non-empty ``glogout_unixtime`` are filtered out,
        and the rest is grouped by ``uid`` and aggregated into ``su_*`` columns.

        :param date_start: first date string of the table range.
        :param date_end: last date string of the table range.
        :param suffix: substituted into every output column name (e.g. ``'1d'``).
        :return: the aggregated stream.
        """
        yt_config = self.config['yt']
        event_log = self.job.table(
            os.path.join(yt_config['blackbox_dataset_dir'], '{%s..%s}' % (date_start, date_end)),
        )
        glogouts = self.job.table(
            os.path.join(yt_config['passport_glogouts_dataset_dir'], to_date_str(self.date)),
        )
        events = event_log.join(glogouts, type='left', by='uid', assume_unique_right=True)

        # (output column template, source field, limit) for histogram_with_top.
        top_histograms = (
            ('su_ip_freq_%s', 'ip', IP_COUNT),
            ('su_country_freq_%s', 'country_id', GEO_COUNT),
            ('su_city_freq_%s', 'city_id', GEO_COUNT),
            ('su_os_family_freq_%s', 'os_family', OS_COUNT),
            ('su_os_name_freq_%s', 'os_name', OS_COUNT),
            ('su_browser_freq_%s', 'browser', BROWSER_COUNT),
            ('su_browser_name_freq_%s', 'browser_name', BROWSER_COUNT),
            ('su_browser_os_freq_%s', 'browser_os', BROWSER_COUNT),
            ('su_referer_host_2_freq_%s', 'referer_host_2', RETPATH_REFERER_COUNT),
            ('su_referer_host_3_freq_%s', 'referer_host_3', RETPATH_REFERER_COUNT),
        )
        # (output column template, source field) for na.histogram.
        plain_histograms = (
            ('su_day_part_freq_%s', 'day_part'),
            ('su_weekday_freq_%s', 'weekday'),
            ('su_is_mobile_freq_%s', 'is_mobile'),
        )

        aggregators = {
            'su_as_list_freq_%s' % suffix: histogram_by_list('as_list'),
        }
        for column, field, limit in top_histograms:
            aggregators[column % suffix] = histogram_with_top(field, limit)
        for column, field in plain_histograms:
            aggregators[column % suffix] = na.histogram(field)

        fields = [
            # TODO: map(lambda f: f.rename(prefix + f.name), fields)
            se.integer_log_field('uid'),
            # ip
            se.log_field('ip').allow_override(),
            se.log_field('as_list'),
            # browser
            se.log_field('browser_name'),
            se.log_field('browser_version'),
            # os
            se.log_field('os_family'),
            se.log_field('os_name'),
            se.integer_log_field('is_mobile'),
            # referer/retpath
            se.log_field('retpath_host'),
            se.log_field('referer_host'),
            # geo
            se.integer_log_field('country_id'),
            se.integer_log_field('city_id'),
            # day/week
            se.integer_log_field('day_part'),
            se.integer_log_field('weekday'),

            # derived fields
            se.custom(
                'browser',
                lambda browser_name, browser_version: str(browser_name) + " " + str(browser_version),
                'browser_name',
                'browser_version',
            ),
            se.custom(
                'browser_os',
                lambda browser, os_name: str(browser) + " - " + str(os_name),
                'browser',
                'os_name',
            ),

            se.custom('retpath_host_2', lambda host: cut_host(host, 2), 'retpath_host'),
            se.custom('retpath_host_3', lambda host: cut_host(host, 3), 'retpath_host'),

            se.custom('referer_host_2', lambda host: cut_host(host, 2), 'referer_host'),
            se.custom('referer_host_3', lambda host: cut_host(host, 3), 'referer_host'),
            # Needed only by the filter below, hence hidden from the output.
            se.integer_log_fields('unixtime', 'glogout_unixtime').hide(),
        ]
        # Keep a row when there is no glogout or the event is not older than it.
        filters = [
            sf.custom(lambda unixtime, glogout_unixtime: not glogout_unixtime or unixtime >= glogout_unixtime),
        ]

        parsed = events.qb2(
            # hack: the blackbox dataset is parsed with the 'passport-log' log type
            log='passport-log',
            fields=fields,
            filters=filters,
            intensity='cpu',
        )
        return parsed.groupby('uid').aggregate(**aggregators)

    def _calculate_oauth_issue_token_frequencies_table(self, date_start, date_end, suffix):
        """
        Describe the per-uid aggregation of OAuth events for one date window.

        Tables addressed as ``{date_start..date_end}`` under ``oauth_dataset_dir`` are
        left-joined by ``uid`` with the glogout table for ``self.date``. Rows whose
        ``unixtime`` is earlier than a non-empty ``glogout_unixtime`` are filtered out,
        and the rest is grouped by ``uid`` and aggregated into ``it_*`` columns.

        :param date_start: first date string of the table range.
        :param date_end: last date string of the table range.
        :param suffix: substituted into every output column name (e.g. ``'1d'``).
        :return: the aggregated stream.
        """
        yt_config = self.config['yt']
        event_log = self.job.table(
            os.path.join(yt_config['oauth_dataset_dir'], '{%s..%s}' % (date_start, date_end)),
        )
        glogouts = self.job.table(
            os.path.join(yt_config['passport_glogouts_dataset_dir'], to_date_str(self.date)),
        )
        events = event_log.join(glogouts, type='left', by='uid', assume_unique_right=True)

        # (output column template, source field, limit) for histogram_with_top.
        top_histograms = (
            ('it_ip_freq_%s', 'user_ip', IP_COUNT),
            ('it_country_freq_%s', 'country_id', GEO_COUNT),
            ('it_city_freq_%s', 'city_id', GEO_COUNT),
            ('it_device_id_freq_%s', 'device_id', DEVICE_ID_COUNT),
            ('it_am_version_freq_%s', 'am_version_truncated', AM_VERSION_COUNT),
            ('it_cloud_token_freq_%s', 'cloud_token', CLOUD_TOKEN_COUNT),
        )
        # (output column template, source field) for na.histogram.
        plain_histograms = (
            ('it_day_part_freq_%s', 'day_part'),
            ('it_weekday_freq_%s', 'weekday'),
        )

        aggregators = {
            'it_as_list_freq_%s' % suffix: histogram_by_list('as_list'),
        }
        for column, field, limit in top_histograms:
            aggregators[column % suffix] = histogram_with_top(field, limit)
        for column, field in plain_histograms:
            aggregators[column % suffix] = na.histogram(field)

        fields = [
            # TODO: map(lambda f: f.rename(prefix + f.name), fields)
            se.integer_log_field('uid'),
            # ip
            se.log_field('user_ip').allow_override(),
            se.log_field('as_list'),
            # geo
            se.integer_log_field('country_id'),
            se.integer_log_field('city_id'),
            # day/week
            se.integer_log_field('day_part'),
            se.integer_log_field('weekday'),

            # device
            se.log_field('device_id'),
            se.log_field('am_version_truncated'),
            se.log_field('cloud_token'),
            # Needed only by the filter below, hence hidden from the output.
            se.integer_log_fields('unixtime', 'glogout_unixtime').hide(),
        ]
        # Keep a row when there is no glogout or the event is not older than it.
        filters = [
            sf.custom(lambda unixtime, glogout_unixtime: not glogout_unixtime or unixtime >= glogout_unixtime),
        ]

        parsed = events.qb2(
            log='generic-tskv-log',
            fields=fields,
            filters=filters,
            intensity='cpu',
        )
        return parsed.groupby('uid').aggregate(**aggregators)

    def timedelta_factors(self, timedelta, suffix, job):
        """
        Queue one aggregate table for the window ``[self.date + timedelta, self.date]``.

        :param timedelta: offset added to ``self.date`` to get the window start
            (callers in this file pass zero or negative deltas).
        :param suffix: column-name suffix forwarded to ``job``.
        :param job: callable accepting ``date_start``, ``date_end`` and ``suffix``
            keyword arguments and returning a table.
        :return: ``self``, so calls can be chained.
        """
        window_end = to_date_str(self.date)
        window_start = to_date_str(self.date + timedelta)
        table = job(date_start=window_start, date_end=window_end, suffix=suffix)
        self.tmp_tables.append(table)
        return self

    def timedelta_factors_passport(self, timedelta, suffix):
        """Queue the Passport auth aggregates for the given window; returns ``self``."""
        builder = self._calculate_passport_auth_frequencies_table
        return self.timedelta_factors(timedelta, suffix, builder)

    def timedelta_factors_blackbox(self, timedelta, suffix):
        """Queue the Blackbox aggregates for the given window; returns ``self``."""
        builder = self._calculate_blackbox_ses_update_frequencies_table
        return self.timedelta_factors(timedelta, suffix, builder)

    def timedelta_factors_oauth(self, timedelta, suffix):
        """Queue the OAuth aggregates for the given window; returns ``self``."""
        builder = self._calculate_oauth_issue_token_frequencies_table
        return self.timedelta_factors(timedelta, suffix, builder)

    @staticmethod
    def _uid_reducer(groups):
        """
        Reducer: emit one record per group, produced by ``merge_records``.

        :param groups: iterable of ``(key, records)`` pairs; the key is not used.
        """
        for _key, group_records in groups:
            merged = merge_records(group_records)
            yield merged

    def _join_tables(self):
        """
        Concatenate the queued ``tmp_tables`` and merge their rows per ``uid``.

        :return: the reduced table, or ``None`` when nothing has been queued.
        """
        if self.tmp_tables:
            combined = self.job.concat(*self.tmp_tables)
            return combined.groupby('uid').reduce(self._uid_reducer)
        return None

    @staticmethod
    def _append_column_mapper(key, value):
        """
        Маппер добавляет столбец key со значением value ко всем строкам.
        """
        def mapper(records):
            for record in records:
                fields = record.to_dict()
                fields[key] = value
                yield Record(**fields)
        return mapper

    @staticmethod
    def _aggregate_monthly_reduce(months, target_date_unixtime):
        """
        Редюсер строит агрегаты из строк столбцов *_freq_1m за последние месяцы, указанные в months.
        Если указано [2,3,6], то *_freq_1m сложатся в правильном порядке и дополнительно
        появятся столбцы *_freq_2m, *_freq_3m, *_freq_6m.

        Для правильного порядка строки в группе должны быть уже отсортированы хронологически по убыванию.
        """
        def reducer(groups):

            def get_glogout_month_ago(glogout_unixtime):
                if glogout_unixtime:
                    # перевести его в номер месяца для этого ключа (uid-а), с округлением в меньшую сторону.
                    return (target_date_unixtime - int(glogout_unixtime or 0)) / (30 * 24 * 60 * 60)
                return max(months) + 1

            for key, records in groups:
                glogout_month_ago = None
                records_by_months = {}
                for record in records:
                    glogout_month_ago = glogout_month_ago or get_glogout_month_ago(record.get('glogout_unixtime'))
                    # если глогаут был сделан в этом месяце, то растаскиваем его значения в ключи месяцами больше
                    # если в каком-то месяце ранее - округляется до 1 месяца ниже
                    if record.__month__ <= max([glogout_month_ago, 1]):
                        fields = {
                            field: Counter(dict(freq))
                            for field, freq in record.to_dict().iteritems() if field.endswith('_freq_1m')
                        }
                        records_by_months[record.__month__] = fields

                aggregate = {}
                # итерируемся по месяцам от текущего месяца назад в прошлое
                for month in xrange(1, max(months) + 1):
                    current_month = records_by_months.get(month, {})
                    for field, freq in current_month.iteritems():
                        if field in aggregate:
                            aggregate[field].update(freq)
                        else:
                            aggregate[field] = freq

                    if month in months:
                        final_aggregate = {k.replace('_1m', '_%dm' % month): v.items() for k, v in aggregate.iteritems()}
                        if final_aggregate:
                            yield Record(key, **final_aggregate)
        return reducer

    def aggregate_monthly(self, current_profile_table, months):
        """
        Combine the fresh profile with earlier snapshots into multi-month aggregates.

        The fresh table is tagged with ``__month__ = 1``. For every month number
        ``m`` in ``2..max(months)`` the profile table dated ``self.date - 30 * (m - 1)``
        days is tagged with ``__month__ = m`` when it exists in YT; missing snapshots
        are skipped. All tagged tables are concatenated, left-joined with the glogout
        table for ``self.date``, grouped by ``uid``, sorted by ``__month__`` and
        reduced with ``_aggregate_monthly_reduce``.

        :param current_profile_table: freshly computed profile stream.
        :param months: non-empty iterable of month counts to aggregate over.
        :return: stream with the ``*_freq_<N>m`` aggregates.
        """
        oldest_month = max(months)
        profile_dir = self.config['yt']['profile_dir']
        month_step = datetime.timedelta(days=30)

        # Tag the fresh calculation as month 1 so it is the newest snapshot.
        current_profile_table = current_profile_table.map(self._append_column_mapper('__month__', 1))

        old_profile_tables = []
        # Month m is the snapshot taken (m - 1) * 30 days before the target date.
        for month in range(2, oldest_month + 1):
            snapshot_date = self.date - month_step * (month - 1)
            table_path = os.path.join(profile_dir, to_date_str(snapshot_date))
            # A profile for some earlier day may be absent; skip it.
            if self.yt.exists(table_path):
                table = self.job.table(table_path).map(self._append_column_mapper('__month__', month))
                old_profile_tables.append(table)

        n_months_aggregate = self.job.concat(
            current_profile_table,
            *old_profile_tables
        ).join(
            self.job.table(
                os.path.join(self.config['yt']['passport_glogouts_dataset_dir'], to_date_str(self.date)),
            ),
            type='left',
            by='uid',
            assume_unique_right=True,
        ).groupby(
            'uid',
        ).sort(
            '__month__',
        ).reduce(
            self._aggregate_monthly_reduce(months, date_to_integer_unixtime(self.date)),
        )
        return n_months_aggregate

    def run(self):
        """
        Assemble the final profile table for ``self.date`` and run the job.

        The queued tables are merged per ``uid``, the multi-month aggregates from
        ``config['monthly_aggregate']`` are added, and the result is written, sorted
        by ``uid``, to ``profile_dir/<date>``. In production the profile is also
        left-joined with the cards table for the same date.

        :return: whatever ``self.job.run()`` returns.
        """
        result_table = self._join_tables()
        # Add aggregates of the distributions over several months
        # (3 and 6 months according to the original note).
        month_aggregates = self.aggregate_monthly(result_table, self.config['monthly_aggregate'])
        self.tmp_tables = [result_table, month_aggregates]

        is_production = yenv.type == 'production'
        profile = self._join_tables().project(
            se.integer_log_field('uid'),
            se.all(exclude=['uid']),
        )
        if is_production:
            # The analytics tool does not create the /cards/uids_has_cards/ tables
            # outside production, so the join is done only there.
            cards = self.job.table(
                os.path.join(self.config['yt']['cards_dir'], to_date_str(self.date)),
            ).project(
                se.integer_log_field('uid'),
                se.all(exclude=['uid']),
            )
            profile = profile.join(
                cards,
                type='left',
                by='uid',
                assume_unique=True,
            )

        profile.sort(
            'uid',
        ).put(
            os.path.join(self.config['yt']['profile_dir'], to_date_str(self.date)),
        )
        return self.job.run()


def _is_yt_incorrect_response(exception):
    """Retry predicate: True only for ``YtIncorrectResponse`` errors."""
    return isinstance(exception, YtIncorrectResponse)


# `retry_on_exception` is given as a predicate rather than a tuple of exception
# types: the callable form is the documented contract and works with every
# release of `retrying`.
@retry(stop_max_attempt_number=3, wait_fixed=5000, retry_on_exception=_is_yt_incorrect_response)
def create_userprofiles(config, target_date):
    """
    Build the user-profile table for ``target_date`` unless it is already done.

    The function returns early when the profile table already carries the
    ``DAILY_JOB_STATUS_ATTRIBUTE`` attribute. Otherwise it queues 1-day, 1-week and
    1-month aggregates for the Passport, Blackbox and OAuth datasets, runs the job
    and then sets the attribute to ``True``. The call is retried up to 3 attempts,
    5 seconds apart, on ``YtIncorrectResponse``.

    :param config: configuration mapping (its ``'yt'`` section is read here).
    :param target_date: date the profile is built for.
    """
    profile_dataset_path = os.path.join(config['yt']['profile_dir'], to_date_str(target_date))
    if check_table_attribute_exist(config, profile_dataset_path, DAILY_JOB_STATUS_ATTRIBUTE):
        return

    (
        UserProfileJob(
            config=config,
            date=target_date,
        )
        .timedelta_factors_passport(datetime.timedelta(days=0), '1d')
        .timedelta_factors_passport(datetime.timedelta(days=-7), '1w')
        .timedelta_factors_passport(datetime.timedelta(days=-30), '1m')
        .timedelta_factors_blackbox(datetime.timedelta(days=0), '1d')
        .timedelta_factors_blackbox(datetime.timedelta(days=-7), '1w')
        .timedelta_factors_blackbox(datetime.timedelta(days=-30), '1m')
        .timedelta_factors_oauth(datetime.timedelta(days=0), '1d')
        .timedelta_factors_oauth(datetime.timedelta(days=-7), '1w')
        .timedelta_factors_oauth(datetime.timedelta(days=-30), '1m')
        .run()
    )

    # Mark the table as finished so a repeated call becomes a no-op.
    set_table_attribute(config, profile_dataset_path, DAILY_JOB_STATUS_ATTRIBUTE, True)
