# -*- coding: utf-8 -*-

from __future__ import print_function, absolute_import, division

import itertools
import logging

from nile.api.v1 import (
    Record,
    extractors as ne,
    aggregators as na
)


class SplitByTimeMapper(object):
    """Route each record to up to three outputs by the age of its yandexuid.

    The last 10 characters of ``record['yandexuid']`` are parsed as an integer
    timestamp and compared with ``record['unixtime']``:

    * every record is written to ``all_age``;
    * it is also written to ``hour_age`` when the absolute difference is
      strictly greater than one hour, and to ``day_age`` when it is strictly
      greater than one day.

    Records whose age cannot be computed (unparseable or missing yandexuid,
    missing/non-numeric unixtime) are treated as old and go to every output.
    """

    HOUR = 60 * 60
    DAY = 24 * 60 * 60

    def __init__(self):
        super(SplitByTimeMapper, self).__init__()

    def __call__(self, records, all_age, hour_age, day_age):
        """Dispatch ``records`` to the ``all_age`` / ``hour_age`` / ``day_age`` outputs."""
        for record in records:
            all_age(record)
            age = self._uid_age(record['yandexuid'], record['unixtime'])
            # An undeterminable age (None) counts as "old" for both buckets,
            # matching the previous best-effort fallback.
            if age is None or age > self.HOUR:
                hour_age(record)
            if age is None or age > self.DAY:
                day_age(record)

    @staticmethod
    def _uid_age(yandexuid, unixtime):
        """Return ``|unixtime - <last 10 chars of yandexuid as int>|``, or None.

        None is returned only for the failures bad input produces
        (``ValueError`` from ``int()``, ``TypeError`` from slicing/subtracting
        unexpected types); any other exception is a real bug and propagates.
        """
        try:
            timestamp = int((yandexuid or '')[-10:])
            return abs(unixtime - timestamp)
        except (TypeError, ValueError):
            return None


class ImpressionsHitsRatioMapper(object):
    """Emit per-user shows/hits ratios for every (hits key, distr_obj) pair.

    Each input record must carry:

    * ``hits`` -- dict keyed by ``'<platform>/<service>'`` with hit counts;
    * ``shows`` -- dict keyed by ``distr_obj`` with show counts;
    * ``cleanness`` and ``yandexuid``.

    Both dicts are extended (on copies) with a synthetic total entry,
    ``'_total_/_total_'`` and ``'_total_'`` respectively, before every hits key
    is paired with every distr_obj.  A row is emitted only when the ratio
    ``shows / hits`` (rounded to 2 decimals, 0 when hits is 0) is at most 1,
    once per hits group returned by ``__make_hits_groups``.

    Args:
        config: mapping whose ``'hits_threshold'`` entry is an iterable of
            numbers; a user whose total hits reach a threshold ``t`` is also
            reported under the ``'<t>hits+'`` group.
    """

    _TOTAL_HITS_KEY = '_total_/_total_'
    _TOTAL_SHOWS_KEY = '_total_'

    def __init__(self, config):
        super(ImpressionsHitsRatioMapper, self).__init__()
        self.config = config

    def __call__(self, records):
        for record in records:
            for row in self._iter_report_rows(record):
                yield Record(**row)

    def _iter_report_rows(self, record):
        """Yield the output rows for one record as plain dicts.

        Works on copies of ``record['hits']`` / ``record['shows']`` so the
        caller's record is left unmodified.
        """
        hits_by_key = dict(record['hits'])
        total_hits = sum(hits_by_key.values())
        hits_by_key[self._TOTAL_HITS_KEY] = total_hits
        shows_by_distr = dict(record['shows'])
        shows_by_distr[self._TOTAL_SHOWS_KEY] = sum(shows_by_distr.values())
        hits_groups = self.__make_hits_groups(total_hits)
        for platform_service, distr_obj in itertools.product(
            sorted(hits_by_key),
            sorted(shows_by_distr)
        ):
            # Split at the first '/' only, so a service containing '/' no
            # longer raises; assumes platform names never contain '/'
            # (TODO confirm against the producer of 'platform').
            platform, _, service = platform_service.partition('/')
            hits = hits_by_key[platform_service]
            shows = shows_by_distr[distr_obj]
            ratio = round(shows / float(hits), 2) if hits != 0 else 0.
            if ratio <= 1.:
                for group in hits_groups:
                    yield dict(
                        age_filter=record['cleanness'],
                        yandexuid=record['yandexuid'],
                        plat=platform,
                        distr_obj=distr_obj,
                        service=service,
                        ratio=ratio,
                        shows=shows,
                        hits=hits,
                        hits_filter=group
                    )

    def __make_hits_groups(self, total_hits):
        """Return ``['raw']`` plus ``'<t>hits+'`` for each threshold ``t <= total_hits``."""
        hits_group = ['raw']
        for threshold in self.config['hits_threshold']:
            if total_hits >= threshold:
                hits_group.append('{}hits+'.format(threshold))
        return hits_group


class DeterrenceStatsMapper(object):
    """Turn one aggregated group row into a final report record.

    Input rows must carry the grouping keys (``age_filter``, ``hits_filter``,
    ``distr_obj``, ``plat``, ``service``), the sums ``shows`` / ``hits``, the
    distinct-user count ``users`` and ``ratio_quantiles`` as an iterable of
    ``(quantile, value)`` pairs.
    """

    def __init__(self):
        super(DeterrenceStatsMapper, self).__init__()

    def __call__(self, records):
        for record in records:
            yield Record(**self._build_report(record))

    @staticmethod
    def _build_report(record):
        """Return the report fields for ``record`` as a plain dict.

        * ``shows_per_user`` -- ``shows / users`` rounded to 2 decimals;
        * ``mean_shows`` -- ``shows / hits`` rounded to 2 decimals and capped
          at ``1.0`` (always a float, so the output column has one type);
        * ``q<NN>`` -- one key per quantile, ``NN`` being the quantile in
          percent (e.g. 0.3 -> ``q30``).
        """
        report = {
            key: record[key]
            for key in ('age_filter', 'hits_filter', 'distr_obj', 'plat', 'service')
        }
        shows = float(record['shows'])
        report['shows_per_user'] = round(shows / record['users'], 2)
        # Cap with a float literal: an int cap would make this column mix
        # int and float values.
        report['mean_shows'] = min(round(shows / record['hits'], 2), 1.)
        for quantile, value in record['ratio_quantiles']:
            # Round after scaling: int(q * 100) would truncate values like
            # 0.57 * 100 == 56.99999999999999 down to 56.
            report['q{}'.format(int(round(quantile * 100)))] = value
        return report


class DeterrenceDistributionMetricAggregator(object):
    """Build the nile graph for the deterrence distribution metric.

    ``aggregate`` reads the ``atom_cube`` and ``bs_watch_log`` streams from
    ``extracts`` and returns a dict of named streams: the per-age impression
    and hit streams, the concatenated join ``hits_impressions_joined`` and the
    final ``report``.

    Args:
        config: passed through to ``ImpressionsHitsRatioMapper`` (it reads the
            ``'hits_threshold'`` entry).
    """

    # Map each age bucket to the value written into the ``cleanness`` column.
    _CLEANNESS_BY_AGE = {
        'all': 'raw',
        'hour': 'hour_clean',
        'day': 'day_clean'
    }
    # Deciles 0.1 .. 0.9 requested from na.quantile for the ratio column.
    _QUANTILES = tuple(0.1 * i for i in range(1, 10))

    def __init__(self, config):
        super(DeterrenceDistributionMetricAggregator, self).__init__()
        self.config = config
        self.logger = logging.getLogger(__name__)

    def aggregate(self, extracts):
        """Return every intermediate and final stream, keyed by name."""
        streams = dict(self.__impressions_by_age(extracts['atom_cube']))
        streams.update(self.__hits_by_age(extracts['bs_watch_log']))
        # Each later stage reads the streams produced by the earlier ones.
        for stage in (self.__joined_hits_and_impressions, self.__report):
            streams.update(stage(streams))
        return streams

    def __impressions_by_age(self, atom_cube):
        """Per-user ``shows`` histogram over ``distr_obj``, one stream per age bucket."""
        split_streams = atom_cube.map(SplitByTimeMapper())
        names = ('impressions_all', 'impressions_hour', 'impressions_day')
        result = {}
        for name, stream in zip(names, split_streams):
            grouped = stream.groupby('yandexuid').aggregate(
                shows=na.histogram('distr_obj')
            )
            result[name] = grouped.project(
                'yandexuid',
                shows=ne.custom(lambda shows: dict(shows), 'shows')
            )
        return result

    def __hits_by_age(self, bs_watch_log):
        """Per-user ``hits`` histogram over ``'platform/service'``, one stream per age bucket."""
        split_streams = bs_watch_log.map(SplitByTimeMapper())
        names = ('hits_all', 'hits_hour', 'hits_day')
        result = {}
        for name, stream in zip(names, split_streams):
            keyed = stream.project(
                'yandexuid',
                platform_service=ne.custom(
                    lambda platform, service: '{}/{}'.format(platform, service),
                    'platform', 'service'
                )
            )
            grouped = keyed.groupby('yandexuid').aggregate(
                hits=na.histogram('platform_service')
            )
            result[name] = grouped.project(
                'yandexuid',
                hits=ne.custom(lambda hits: dict(hits), 'hits')
            )
        return result

    def __joined_hits_and_impressions(self, streams):
        """Concatenate the per-age inner joins into ``hits_impressions_joined``."""
        first, second, third = [
            self.__join_for_age(streams, age) for age in ('all', 'hour', 'day')
        ]
        joined = first.concat(second, third)
        # Pass-through project: per the original author's note it is needed
        # so this concatenated stream can be put to the cluster.
        return {'hits_impressions_joined': joined.project(ne.all())}

    def __join_for_age(self, streams, age):
        """Tag the ``age`` hits stream with its cleanness and inner-join impressions on yandexuid."""
        tagged_hits = streams['hits_{}'.format(age)].project(
            ne.all(), cleanness=ne.const(self._CLEANNESS_BY_AGE[age])
        )
        impressions = streams['impressions_{}'.format(age)]
        return tagged_hits.join(impressions, by='yandexuid', type='inner')

    def __report(self, streams):
        """Compute per-group sums, distinct users and ratio deciles, then format the report."""
        ratios = streams['hits_impressions_joined'].map(
            ImpressionsHitsRatioMapper(self.config),
            intensity='cpu'
        )
        grouped = ratios.groupby(
            'age_filter', 'hits_filter', 'distr_obj', 'plat', 'service'
        )
        stats = grouped.aggregate(
            shows=na.sum('shows'),
            users=na.count_distinct('yandexuid'),
            hits=na.sum('hits'),
            ratio_quantiles=na.quantile('ratio', quantiles=self._QUANTILES),
            intensity='cpu'
        )
        return {'report': stats.map(DeterrenceStatsMapper())}
