# coding: utf-8

import itertools

from analytics.plotter_lib.plotter import Plot, require
from analytics.plotter_lib.utils import get_dts_delta
from nile.api.v1 import (
    with_hints,
    extended_schema,
    extractors as ne,
    aggregators as na,
    filters as nf,
    Record
)

# Retention windows, in days. For each window, a visit day is scored as
# "returned" when the user's next distinct visit day comes at most that many
# days later (see retention_reducer).
PERIODS = [7, 14, 30]


def convert_ui(ui):
    """Map a raw ``ui`` value onto the bucket used in the report.

    Falsy values (None, empty string) become 'undefined', 'yandexApp' is
    reported as 'mobile' and 'tablet' is folded into 'desktop'. Any other
    value is returned unchanged.
    """
    if not ui:
        return 'undefined'
    # Plain equality checks (no dict lookup), so unhashable inputs behave
    # exactly as before.
    for raw_value, bucket in (('yandexApp', 'mobile'), ('tablet', 'desktop')):
        if ui == raw_value:
            return bucket
    return ui


def _slice_records(fielddate, ui, page_name, days_to_return):
    """Yield the output rows that score one (user, visit day) pair.

    For every period in PERIODS, one Record is produced for each of the four
    (ui, page_name) slices: the real values and their '_total_' roll-ups.
    The downstream aggregation can then build both per-slice and total rows.

    :param fielddate: the visit day being scored.
    :param ui: ui value of that visit.
    :param page_name: page_name value of that visit.
    :param days_to_return: ``.days`` of the gap to the user's next distinct
        visit day, or None when the group has no later visit day.
    """
    for period in PERIODS:
        # The flag depends only on the period, not on the slice, so compute
        # it once per period.
        returned = int(days_to_return is not None and days_to_return <= period)
        for slice_ui, slice_page_name in itertools.product(
            (ui, '_total_'),
            (page_name, '_total_'),
        ):
            yield Record(
                fielddate=fielddate,
                ui=slice_ui,
                page_name=slice_page_name,
                period=period,
                users=1,
                returned=returned,
            )


@with_hints(output_schema=extended_schema(period=int, users=int, returned=int))
def retention_reducer(groups):
    """Score, per user, whether each visit day was followed by a return.

    Expects each group's records in fielddate order. The pipeline in this
    file groups by yandexuid and sorts by fielddate before reducing.

    Consecutive records sharing a fielddate are collapsed into one visit day.
    Its ui/page_name come from the last such record in iteration order.
    NOTE(review): the within-day order is not fixed by the sort — confirm
    which session should win.

    For every visit day except the last, rows are emitted with users=1. They
    carry returned=1 for each period in PERIODS whose length is at least the
    gap (in days) to the next visit day, and returned=0 otherwise. The final
    visit day always gets returned=0. Each row is emitted for the four
    combinations of the real ui/page_name values with '_total_'.
    """
    for _, recs in groups:
        last_fielddate = None
        last_page_name = None
        last_ui = None
        seen_any = False

        for rec in recs:
            seen_any = True
            if last_fielddate and last_fielddate != rec.fielddate:
                # NOTE(review): presumably get_dts_delta returns the timedelta
                # rec.fielddate - last_fielddate; confirm argument order.
                days_to_return = get_dts_delta(rec.fielddate, last_fielddate).days
                for record in _slice_records(
                    last_fielddate, last_ui, last_page_name, days_to_return
                ):
                    yield record
            last_fielddate = rec.fielddate
            last_page_name = rec.page_name
            last_ui = rec.ui

        if not seen_any:
            # Empty group: there is no visit to score, so avoid emitting
            # rows with fielddate=None.
            continue

        # Nothing follows the user's last visit day inside the group.
        for record in _slice_records(last_fielddate, last_ui, last_page_name, None):
            yield record


class EntryPointsRetention(Plot):
    """Plot that publishes the Collections entry-points retention report.

    It keeps session-start events from the redirect log and scores return
    visits per user with ``retention_reducer``. It then publishes summed
    ``users`` / ``returned`` counts and their ratio per
    (fielddate, ui, page_name, period).
    """

    @require('CollectionsRedirLog.full_with_additional_days')
    def entry_points_retention_publish(self, streams):
        """Build the retention stream and publish it to its Statface report."""
        redir_log = streams['CollectionsRedirLog.full_with_additional_days']

        # Session starts only, with the slicing dimensions normalized.
        session_starts = redir_log.filter(
            nf.equals('path', 'start.session')
        ).project(
            'fielddate',
            'yandexuid',
            page_name=ne.custom(
                lambda raw_page: raw_page or 'undefined', 'page_name'
            ).with_type(str),
            ui=ne.custom(convert_ui, 'ui').with_type(str),
        )

        # One reducer pass per user over that user's visits in date order.
        scored_visits = session_starts.groupby('yandexuid').sort(
            'fielddate'
        ).reduce(retention_reducer)

        totals = scored_visits.groupby(
            'fielddate', 'ui', 'page_name', 'period'
        ).aggregate(
            users=na.sum('users'),
            returned=na.sum('returned'),
        )

        with_ratio = totals.project(
            ne.all(),
            retention=ne.custom(
                lambda returned, users: float(returned) / users,
                'returned',
                'users',
            ).with_type(float),
        )

        return with_ratio.publish(
            self.get_statface_report('Collections/Metrics/Retention/EntryPointsRetention'),
            allow_change_job=True,
        )
