#!/usr/bin/env python
# -*- coding: UTF-8 -*-

import os
from collections import defaultdict

import yt.wrapper as yt
from nile.api.v1 import clusters, aggregators as na, extractors as ne, Record
from nile.files import LocalFile
from qb2.api.v1 import extractors as qe, filters as qf


def get_region(region, type=5, attr='name'):
    """
    Read an attribute from the ancestor of a region that has the given type.

    The region's ``path`` is scanned completely; when several entries have the
    requested type, the last one is used. When none match, the attribute is
    read from ``region`` itself.

    :param region: region object exposing ``path`` (iterable of region objects), or None
    :param type: region type to look for (see https://nda.ya.ru/3UWtFU). Defaults to 5 - federation entity.
    :param attr: name of the attribute to read (e.g., 'name', 'id')
    :return: value of the attribute, or None when ``region`` is None
    """
    if region is None:
        return None
    matching = [node for node in region.path if node.type == type]
    # Fall back to the region itself when no entry of the path has the requested type.
    target = matching[-1] if matching else region
    return getattr(target, attr)

def dau_reducer(groups):
    """
    Collapse all records of a group into a single record.

    For every group one Record is yielded with the group's ``fielddate`` and a
    ``regions_devices`` mapping: region_city_name -> dict with region ids,
    country name, the list of phone ids and the device count. When two records
    share a region_city_name, the later one replaces the earlier one.

    :param groups: iterable of (key, records) pairs; ``key`` must provide 'fielddate'
    :return: generator of Record(fielddate, regions_devices)
    """
    for key, records in groups:
        by_city_name = {
            row['region_city_name']: {
                'region_city_id': row['region_city_id'],
                'region_country_id': row['region_country_id'],
                'region_country_name': row['region_country_name'],
                'phone_ids': row['phone_ids'],
                'devices': row['devices'],
            }
            for row in records
        }
        yield Record(fielddate=key['fielddate'], regions_devices=by_city_name)

def dau_merge_reducer(groups):
    """
    Merge several ``regions_devices`` mappings of one group into a single record.

    For every region name the phone id lists of all records are united, the
    remaining fields are taken from the record processed last, and ``devices``
    is recomputed as the number of distinct phone ids.

    :param groups: iterable of (key, records) pairs; ``key`` must provide
        'fielddate' and every record must provide 'regions_devices'
    :return: generator of Record(fielddate, regions_devices)
    """
    for key, records in groups:
        regions_devices = defaultdict(dict)
        for record in records:
            for region_city_name, data in record['regions_devices'].items():
                merged = regions_devices[region_city_name]
                # Unite phone ids seen so far with the incoming ones; entries
                # without 'phone_ids' contribute nothing.
                phone_ids = list(set(merged.get('phone_ids', [])) |
                                 set(data.get('phone_ids', [])))
                merged.update(data)
                # update() overwrote these two fields with the incoming values,
                # so restore the united list and recount devices from it.
                merged['phone_ids'] = phone_ids
                merged['devices'] = len(phone_ids)
        # Take the date from the group key (as dau_reducer does) instead of a
        # variable bound inside the records loop, which would be undefined or
        # stale from a previous group if a group yielded no records.
        yield Record(fielddate=key['fielddate'], regions_devices=regions_devices)

def remove_phone_ids_mapper(records):
    """
    Strip the 'phone_ids' list from every region entry of each record.

    The nested dicts in ``record['regions_devices']`` are modified in place and
    the same record object is yielded.

    :param records: iterable of records providing a 'regions_devices' mapping
        of region name -> dict
    :return: generator of the records without 'phone_ids' in region entries
    """
    for record in records:
        for region_data in record['regions_devices'].values():
            # Entries without 'phone_ids' are tolerated, matching the
            # .get('phone_ids', []) handling in dau_merge_reducer.
            region_data.pop('phone_ids', None)
        yield record

if __name__ == '__main__':
    # Script entry point: describes a Nile job on the Hahn YT cluster that
    # counts phones per day and region, appends the counts to the output
    # table and maintains two per-day "plot" tables, then runs the job.

    # Destination tables. The base path may be overridden with YT_OUTPUT_TABLE;
    # the plot tables are derived from it by suffix.
    output_table = os.environ.get('YT_OUTPUT_TABLE', '//home/advisor/reports/phone/dau_aggregat_geo/dau_aggregat_geo')
    output_table_plot = output_table + '_plot'
    output_table_plot_compact = output_table_plot + '_compact'
    # Required settings: a missing variable raises KeyError here.
    input_table = os.environ['YT_INPUT_TABLE']
    yt_token = os.environ['YT_TOKEN']
    # common.py located next to this script is handed to the cluster
    # environment as a package.
    common_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'common.py')
    cluster = clusters.yt.Hahn(pool='', token=yt_token).env(packages=[LocalFile(common_path)])
    # Separate YT client, used below only to check whether the plot table exists.
    yt_client = yt.YtClient(proxy='hahn', token=yt_token)

    root_dir = os.path.dirname(output_table)  # Directory containing the output table; used as checkpoints root
    job = cluster.job().env(templates=dict(checkpoints_root=root_dir))

    # Phone collection rows that have a device_id, reduced to two columns.
    # NOTE(review): presumably qe.dictitem takes the '$uuid' item out of the
    # 'device_id' field - confirm against the qb2 documentation.
    phone_ids = job.table(
        '//home/advisor/phone/collection'
    ).filter(
        qf.defined('device_id')
    ).project(
        phone_id='_id',
        device_id=qe.dictitem('$uuid', 'device_id')
    ).checkpoint('phone_ids')

    # Input events joined with phones by device_id. The join is 'left', but
    # rows without a phone_id are filtered out right after the projection.
    # NOTE(review): presumably the qb2 step resolves the 'region' field from
    # RegionID using the 'metrika-mobile-log' description - confirm.
    # The region object is then replaced by ids/names extracted with
    # get_region for types 5 and 3.
    # NOTE(review): the type-5 columns are named region_city_*, while the
    # get_region docstring describes type 5 as federation entity - confirm
    # the intended naming.
    # After that, top(1, 'event_timestamp') is applied per (phone_id,
    # fielddate) - presumably keeping the single latest event of a phone per
    # day, confirm ordering - and the result is aggregated per (fielddate,
    # region_city_id).
    records = job.table(
        input_table,
    ).join(
        phone_ids,
        by='device_id',
        type='left'
    ).project(
        'phone_id',
        'event_timestamp',
        fielddate='event_date',
        RegionID='geo_id'
    ).filter(
        qf.defined('phone_id')
    ).qb2(
        log='metrika-mobile-log',
        fields=['region',
                qe.log_field('phone_id'),
                qe.log_field('event_timestamp'),
                qe.log_field('fielddate')]
    ).project(
        ne.all(exclude=['region']),
        region_city_id=ne.custom(lambda r: get_region(r, 5, 'id'), 'region'),
        region_city_name=ne.custom(lambda r: get_region(r, 5, 'name'), 'region'),
        region_country_id=ne.custom(lambda r: get_region(r, 3, 'id'), 'region'),
        region_country_name=ne.custom(lambda r: get_region(r, 3, 'name'), 'region')
    ).groupby(
        'phone_id',
        'fielddate'
    ).top(
        1,
        'event_timestamp'
    ).groupby(
        'fielddate',
        'region_city_id'
    ).aggregate(
        devices=na.count(),
        phone_ids=na.distinct('phone_id'),
        region_city_name=na.any('region_city_name'),
        region_country_id=na.any('region_country_id'),
        region_country_name=na.any('region_country_name'),
    )

    # Per-region rows are appended to the output table.
    records.put(output_table, append=True)

    # One record per fielddate holding a region name -> data mapping.
    dau_plot_new = records.groupby(
        'fielddate'
    ).reduce(
        dau_reducer
    )

    # This check is a plain Python call made while the job is being described,
    # i.e. before job.run(). When the plot table already exists, its rows are
    # concatenated with the new ones and merged per fielddate.
    if yt_client.exists(output_table_plot):
        dau_plot = job.table(
            output_table_plot
        )

        dau_plot_new = job.concat(
            dau_plot,
            dau_plot_new
        ).groupby(
            'fielddate'
        ).reduce(
            dau_merge_reducer
        )

    # The full result (with phone_ids) goes to the plot table; the same stream
    # with phone_ids removed goes to the compact table.
    dau_plot_new.put(
        output_table_plot
    ).map(
        remove_phone_ids_mapper
    ).put(
        output_table_plot_compact
    )

    job.run()
