#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import datetime
from functools import partial

from yt.wrapper import with_context

from crypta.profile.lib import date_helpers

from crypta.profile.utils.config import config
from crypta.profile.utils.luigi_utils import ExternalInput
from crypta.profile.utils.segment_utils.builders import RegularSegmentBuilder
from crypta.profile.utils.segment_utils.processors import DayProcessor, LogProcessor


# Inclusive bounds on the rounded per-day hit rate: activity_reducer only
# emits yandexuids whose rate lies within [MIN_HIT_RATE, MAX_HIT_RATE].
MIN_HIT_RATE = 1.0
MAX_HIT_RATE = 300.0
# Minimum number of distinct dates with hits that activity_reducer requires
# before a yandexuid is considered at all.
MIN_ACTIVE_DAYS = 3


# YQL template rendered with str.format; placeholders are {input_table} and
# {output_table}. For every input row with a non-zero yandexuid it converts
# site_weights to a Uint64 dict, flattens it, takes each entry's value as a
# hit count and sums those counts per (yandexuid, date) into n_hits. `date`
# comes from TableName() of the input table. The output table is overwritten
# (INSERT ... WITH TRUNCATE).
day_processor_query = """
INSERT INTO `{output_table}` WITH TRUNCATE
SELECT yandexuid, `date`, SUM(`count`) AS n_hits
FROM (
    SELECT
        yandexuid,
        `date`,
        site_weight.1 AS `count`
    FROM (
        SELECT yandexuid, TableName() AS `date`, Yson::ConvertToUint64Dict(site_weights) AS site_weights
        FROM `{input_table}`
        WHERE yandexuid != 0
    )
    FLATTEN DICT BY site_weights AS site_weight
)
GROUP BY yandexuid, `date`
"""


class ProcessedMetricsHitsForDigitalViewers(DayProcessor):
    """Day processor that aggregates one day of metrics hits per yandexuid."""

    def requires(self):
        """Return the external metrics hits table for ``self.date``."""
        daily_hits_path = os.path.join(config.METRICS_HITS_DIRECTORY, self.date)
        return ExternalInput(daily_hits_path)

    def process_day(self, inputs, output_path):
        """Render ``day_processor_query`` and run it in the task transaction.

        ``inputs.table`` is substituted as the input table and
        ``output_path`` as the output table of the query.
        """
        rendered_query = day_processor_query.format(
            input_table=inputs.table,
            output_table=output_path,
        )
        self.yql.query(rendered_query, transaction=self.transaction)


def get_date_of_yuid_creation(yandexuid):
    """Derive a creation date string from the tail of a yandexuid.

    The last ten characters of ``str(yandexuid)`` are parsed as an integer
    Unix timestamp and formatted with ``date_helpers.to_date_string``.
    Returns ``None`` when the string form is shorter than ten characters.
    """
    uid_text = str(yandexuid)
    if len(uid_text) >= 10:
        created_at = datetime.datetime.fromtimestamp(int(uid_text[-10:]))
        return date_helpers.to_date_string(created_at)
    return None


def activity_reducer(key, rows, today_date):
    """Reduce per-day hit rows of one yandexuid into a single hit-rate row.

    Sums ``n_hits`` over all rows and collects the distinct ``date`` values.
    A row ``{'yandexuid', 'hit_rate'}`` is yielded only when:

    * at least ``MIN_ACTIVE_DAYS`` distinct dates were seen,
    * a creation date could be derived from the yandexuid and it is not
      later than ``today_date``,
    * the rounded hit rate lies within ``[MIN_HIT_RATE, MAX_HIT_RATE]``.

    The hit rate is total hits divided by the number of days since creation
    (counting both the creation day and ``today_date``), capped at 35 days.
    """
    yandexuid = key['yandexuid']
    creation_date = get_date_of_yuid_creation(yandexuid)

    hits_sum = 0
    seen_dates = set()
    for record in rows:
        hits_sum += record['n_hits']
        seen_dates.add(record['date'])

    if len(seen_dates) < MIN_ACTIVE_DAYS:
        return
    if creation_date is None or creation_date > today_date:
        return

    today = date_helpers.from_date_string_to_datetime(today_date, date_helpers.DATE_FORMAT)
    created = date_helpers.from_date_string_to_datetime(creation_date, date_helpers.DATE_FORMAT)
    # The extra day makes the interval inclusive, so it is never zero.
    lifetime = today - created + datetime.timedelta(days=1)
    lifetime_days = min(35, lifetime.days)
    hit_rate = round(float(hits_sum) / lifetime_days, 1)

    if MIN_HIT_RATE <= hit_rate <= MAX_HIT_RATE:
        yield {
            'yandexuid': yandexuid,
            'hit_rate': hit_rate,
        }


@with_context
def filter_mobile_reducer(key, rows, context):
    """Pass through table-1 rows of yandexuids whose profile says ``phone``.

    Rows from input table 0 are inspected: when ``ua_profile`` is non-empty
    and its second ``|``-separated field equals ``'phone'``, the current key
    is marked as a phone. Rows from input table 1 are yielded unchanged once
    the key has been marked.

    A non-empty ``ua_profile`` without a ``|`` separator is treated as
    "not a phone" instead of raising IndexError and failing the whole job.

    NOTE(review): relies on table-0 rows arriving before table-1 rows within
    one key -- confirm against the YT reduce ordering guarantees.
    """
    is_phone = False
    for row in rows:
        if context.table_index == 0 and row['ua_profile']:
            profile_fields = row['ua_profile'].split('|')
            # Guard the index: a profile with no separator has a single field.
            if len(profile_fields) > 1 and profile_fields[1] == 'phone':
                is_phone = True

        elif context.table_index == 1 and is_phone:
            yield row


class DigitalViewers(RegularSegmentBuilder):
    """Builds light/heavy digital viewer segments from per-yandexuid hit rates.

    Hit rates are computed by ``activity_reducer`` over the processed metrics
    hits tables. Thresholds are then taken at roughly one third and two
    thirds of the hit-rate distribution, once for all yandexuids and once
    for those that ``filter_mobile_reducer`` keeps as phones.
    """

    # Segment name -> pair of ids.
    # NOTE(review): presumably (keyword id, segment id) -- confirm against
    # RegularSegmentBuilder.
    name_segment_dict = {
        'heavy_digital_viewers': (547, 1049),
        'light_digital_viewers': (547, 1048),
        'phone_heavy_digital_viewers': (547, 1047),
        'phone_light_digital_viewers': (547, 1046),
    }

    # Number of days passed to LogProcessor in requires().
    number_of_days = 35

    def requires(self):
        """Depend on processed metrics hits for ``number_of_days`` days."""
        return {
            'ProcessedMetricsHits': LogProcessor(
                ProcessedMetricsHitsForDigitalViewers,
                self.date,
                self.number_of_days,
            ),
        }

    def build_segment(self, inputs, output_path):
        """Build all four segments and merge them into ``output_path``.

        The input table is taken from ``self.input()``; the ``inputs``
        argument itself is not used.
        """

        with self.yt.TempTable(prefix='activity_table_') as activity_table, \
                self.yt.TempTable(prefix='digital_viewers_table_') as digital_viewers_table, \
                self.yt.TempTable(prefix='phone_digital_viewers_table_') as phone_digital_viewers_table, \
                self.yt.TempTable(prefix='phone_activity_table_') as phone_activity_table:

            # Count hit rate of yandexuid
            self.yt.run_map_reduce(
                None,
                partial(activity_reducer, today_date=self.date),
                self.input()['ProcessedMetricsHits'].table,
                activity_table,
                reduce_by='yandexuid',
            )

            # Build digital viewers segment
            ldv_threshold, hdv_threshold = self._find_thresholds(activity_table)
            self._add_segment_names(
                activity_table,
                digital_viewers_table,
                ldv_threshold,
                hdv_threshold,
                ldv_segment_name='light_digital_viewers',
                hdv_segment_name='heavy_digital_viewers',
            )

            # Build phone digital viewers segment
            self._filter_phone_yandexuids(
                source_table=activity_table,
                destination_table=phone_activity_table,
            )

            # Thresholds are recomputed on the phone-only distribution.
            ldv_threshold, hdv_threshold = self._find_thresholds(phone_activity_table)
            self._add_segment_names(
                phone_activity_table,
                phone_digital_viewers_table,
                ldv_threshold,
                hdv_threshold,
                ldv_segment_name='phone_light_digital_viewers',
                hdv_segment_name='phone_heavy_digital_viewers',
            )

            self.yt.run_merge(
                source_table=[
                    digital_viewers_table,
                    phone_digital_viewers_table,
                ],
                destination_table=output_path,
            )

    def _filter_phone_yandexuids(self, source_table, destination_table):
        """Write to ``destination_table`` the rows kept by ``filter_mobile_reducer``.

        ``source_table`` is sorted by ``yandexuid`` first (no destination is
        given to the sort call) and then reduced together with
        ``config.YUID_WITH_ALL_BY_YANDEXUID_TABLE``.
        """
        self.yt.run_sort(
            source_table,
            sort_by='yandexuid',
        )

        # Table order matters: the reducer reads profiles from table 0 and
        # yields rows from table 1.
        self.yt.run_reduce(
            filter_mobile_reducer,
            source_table=[config.YUID_WITH_ALL_BY_YANDEXUID_TABLE, source_table],
            destination_table=destination_table,
            reduce_by='yandexuid',
        )

    def _find_thresholds(self, activity_table):
        """Return ``(ldv_threshold, hdv_threshold)`` hit rates for a table.

        A histogram over ``hit_rate`` is built and walked in ascending order
        while accumulating ``count``. The light threshold is the hit rate of
        the bucket in which the cumulative count reaches one third of the
        rows; the heavy threshold is the one where it reaches twice that.
        A threshold stays 0 if its cut point is never reached (for example
        on an empty table).

        NOTE(review): presumably ``self.yt.unique_count`` writes one row per
        distinct ``hit_rate`` with a ``count`` column -- confirm.
        """
        with self.yt.TempTable() as hist_table:
            self.yt.unique_count(
                source_table=activity_table,
                destination_table=hist_table,
                unique_by=['hit_rate'],
            )

            self.yt.run_sort(
                hist_table,
                sort_by='hit_rate',
            )

            total_row_count = self.yt.row_count(activity_table)
            ldv_count = total_row_count / 3
            hdv_count = 2 * ldv_count

            ldv_threshold, hdv_threshold = 0, 0

            prev_count = 0
            cur_count = 0
            for row in self.yt.read_table(hist_table, raw=False):
                cur_count += row['count']
                if prev_count < ldv_count <= cur_count:
                    ldv_threshold = row['hit_rate']
                # Checked independently of the light cut point: both cut
                # points can fall into the same bucket, and an ``elif`` here
                # would then leave hdv_threshold at its initial 0.
                if prev_count < hdv_count <= cur_count:
                    hdv_threshold = row['hit_rate']
                prev_count = cur_count

            return ldv_threshold, hdv_threshold

    def _add_segment_names(self, activity_table, output_table,
                           ldv_threshold, hdv_threshold,
                           ldv_segment_name, hdv_segment_name):
        """Map hit-rate rows to segment rows and write them to ``output_table``.

        Rows with ``hit_rate <= ldv_threshold`` get ``ldv_segment_name``;
        otherwise rows with ``hit_rate >= hdv_threshold`` get
        ``hdv_segment_name``. Rows strictly between the thresholds produce
        no output.
        """

        def add_segment_names_mapper(row):
            # The light check comes first, so a row matching both conditions
            # is labelled light.
            if row['hit_rate'] <= ldv_threshold:
                yield {
                    'id': str(row['yandexuid']),
                    'id_type': 'yandexuid',
                    'segment_name': ldv_segment_name,
                }
            elif row['hit_rate'] >= hdv_threshold:
                yield {
                    'id': str(row['yandexuid']),
                    'id_type': 'yandexuid',
                    'segment_name': hdv_segment_name,
                }

        self.yt.run_map(
            add_segment_names_mapper,
            source_table=activity_table,
            destination_table=output_table,
        )
