#!/usr/bin/env python
# coding: utf8
import datetime as dt
import os.path
import re

import yt.wrapper as yt

from nile.api.v1 import aggregators, clusters, extractors, Record
from qb2.api.v1 import filters

from mpfs.config import settings
from mpfs.core.mrstat.stat_utils import set_yt_proxy
from mpfs.core.user_activity_info.dao import UserActivityInfoDAO
from mpfs.core.user_activity_info.utils import ErrorInfoContainer
from mpfs.core.promo_codes.logic.come_back_users_promo import provide_discount_if_needed

# strptime format of the dates handled by this module (YYYY-MM-DD), see date_from_str
DATE_FORMAT = '%Y-%m-%d'
# A single date, e.g. 2019-01-31 or {2019-01-31}; each curly brace is optional on its own.
SINGLE_DATE_RE = re.compile(r'^{?(?P<date>\d{4}-\d{2}-\d{2})\}?$')
# A range of dates, e.g. 2019-01-01..2019-01-31 or {2019-01-01..2019-01-31};
# each curly brace is optional on its own.
DATES_RANGE_RE = re.compile(r'^{?(?P<first_date>\d{4}-\d{2}-\d{2})\.\.(?P<last_date>\d{4}-\d{2}-\d{2})\}?$')

# YT token read from the mrstat section of the settings; passed to the cluster in load()
YT_TOKEN = settings.mrstat['yt_token']
# Base directory of the source tables; load() reads its desktop, mobile and web-disk children
YT_SOURCES_BASE_PATH = '//statbox/heavy-report/disk-audience-actions-v2'
# Directory that holds the aggregation results and both log tables below
YT_RESULTS_PATH = '//home/mpfs-stat/user_activity_info'
# Log table: write_yt_log appends one record per processed date range
YT_LOG_TABLE = os.path.join(YT_RESULTS_PATH, 'processed_dates')
# Table that get_first_unprocessed_date fills with the maximum last_date of the log table
YT_AGGREGATED_LOG_TABLE = os.path.join(YT_RESULTS_PATH, 'processed_dates_aggregated')

# Number of days added to today's date to get the last date of the default range in load()
AGGREGATION_DAYS_OFFSET = settings.user_activity_info['aggregation_days_offset']


class YTActivitiInfoSource(object):
    """One YT source of user activity records, reported under a single platform type.

    :param base_path: YT directory that contains the per-date source tables.
    :param as_platform_type: value written to the ``platform_type`` column of the result.
    :param filter_platform: optional platform name; when set, only records whose
        lower-cased ``platform`` equals it (case-insensitively) are aggregated.
    """

    def __init__(self, base_path, as_platform_type, filter_platform=None):
        self._base_path = base_path
        self._as_platform_type = as_platform_type
        self._filter_platform = filter_platform

    @property
    def filter_platform(self):
        """Lower-cased platform filter, or None when no filter is configured."""
        if not self._filter_platform:
            return None
        return self._filter_platform.lower()

    @property
    def platform_type(self):
        """Platform type the aggregated records are labelled with."""
        return self._as_platform_type

    def get_path(self, first_date, last_date):
        """Build the YT path template for the given date or range of dates.

        :param first_date: first date of the range; may be falsy for a single date.
        :param last_date: last (or the only) date.
        :return: ``<base_path>/{first..last}`` when both dates are given,
            ``<base_path>/{last}`` when only ``last_date`` is given.
        :raises ValueError: when ``last_date`` is not given.
        """
        if not last_date:
            raise ValueError('At least last_date should be defined')
        if first_date:
            table_name = '{%s..%s}' % (first_date, last_date)
        else:
            table_name = '{%s}' % last_date
        return os.path.join(self._base_path, table_name)

    def get_aggregated_table(self, job, first_date, last_date):
        """Describe the per-user aggregation of this source inside ``job``.

        The resulting table is projected to the columns ``uid`` (taken from ``puid``),
        ``first_activity`` (min ``date``), ``last_activity`` (max ``date``) and the
        constant ``platform_type``.

        :param job: job object the table operations are attached to.
        :param first_date: first date of the range, see get_path.
        :param last_date: last date of the range, see get_path.
        :return: table object produced by the chained job operations.
        :raises ValueError: when ``last_date`` is not given (from get_path).
        """
        wanted_platform = self.filter_platform

        table = job.table(self.get_path(first_date, last_date))
        # take only records with is_active = True and a valid (digits only) puid
        table = table.filter(
            filters.equals('is_active', True),
            filters.match('puid', r'^\d+$'),
        )

        if wanted_platform:
            table = table.filter(filters.defined('platform'))
        # keep only the columns we need
        table = table.project(
            'puid',
            'date',
            # Records without a platform are dropped above only when a platform filter is
            # configured; for unfiltered sources the value may presumably be None, so the
            # lower-casing must not fail on it.
            platform=extractors.custom(lambda p: p.lower() if p is not None else p, 'platform'),
        )
        if wanted_platform:
            table = table.filter(filters.equals('platform', wanted_platform))

        table = table.groupby('puid').aggregate(
            first_activity=aggregators.min('date'),
            last_activity=aggregators.max('date'),
        )
        return table.project(
            'first_activity',
            'last_activity',
            uid='puid',
            platform_type=extractors.const(self.platform_type),
        )


def get_first_unprocessed_date(cluster):
    """Return the day following the latest processed date, or None.

    Runs a job on ``cluster`` that aggregates the maximum ``last_date`` of YT_LOG_TABLE
    into YT_AGGREGATED_LOG_TABLE and then reads the first row of that table.

    :param cluster: cluster object used to create the job.
    :return: ``max(last_date) + 1 day`` as a date, or None when the aggregated
        table yields no rows.
    """
    job = cluster.job()
    log_table = job.table(YT_LOG_TABLE)
    log_table.aggregate(max_last_date=aggregators.max('last_date')).put(YT_AGGREGATED_LOG_TABLE)
    job.run()

    aggregated_path = str(job.remote_path(YT_AGGREGATED_LOG_TABLE))
    first_row = next(yt.read_table(aggregated_path, format='json'), None)
    if first_row is None:
        return None
    one_day = dt.timedelta(days=1)
    return date_from_str(first_row['max_last_date']) + one_day


def get_current_date(days_offset=0):
    """Return today's date shifted by ``days_offset`` days.

    :param days_offset: number of days to add (negative values go back in time);
        a falsy value means no shift.
    :return: ``datetime.date`` object.
    """
    today = dt.date.today()
    if not days_offset:
        return today
    return today + dt.timedelta(days=days_offset)


def date_from_str(date_str):
    """Parse a string in DATE_FORMAT into a ``datetime.date``.

    :raises ValueError: when ``date_str`` does not match DATE_FORMAT.
    """
    parsed = dt.datetime.strptime(date_str, DATE_FORMAT)
    return parsed.date()


def is_single_date(raw_date):
    """Tell whether ``raw_date`` matches SINGLE_DATE_RE (one date, braces optional)."""
    return SINGLE_DATE_RE.match(raw_date) is not None


def is_dates_range(raw_date):
    """Tell whether ``raw_date`` matches DATES_RANGE_RE (``first..last``, braces optional)."""
    return DATES_RANGE_RE.match(raw_date) is not None


def parse_dates(raw_date):
    """Parse a single date or a range of dates into a ``(first_date, last_date)`` pair.

    :param raw_date: string accepted by SINGLE_DATE_RE or DATES_RANGE_RE.
    :return: tuple of two ``datetime.date`` objects; for a single date both
        elements are that date.
    :raises ValueError: when ``raw_date`` matches neither pattern or holds a
        date that cannot be parsed with DATE_FORMAT.
    """
    single = SINGLE_DATE_RE.match(raw_date)
    if single:
        day = date_from_str(single.group('date'))
        return day, day

    dates_range = DATES_RANGE_RE.match(raw_date)
    if dates_range:
        range_start = date_from_str(dates_range.group('first_date'))
        range_end = date_from_str(dates_range.group('last_date'))
        return range_start, range_end

    raise ValueError('Date %s has incorrect format' % raw_date)


def write_yt_log(cluster, first_date, last_date, sharpei_errors_count, missing_uids_count):
    """Append one record about a processed range of dates to YT_LOG_TABLE.

    :param cluster: cluster object used for writing.
    :param first_date: first processed date, stored as ``str(first_date)``.
    :param last_date: last processed date, stored as ``str(last_date)``.
    :param sharpei_errors_count: stored as is in the ``sharpei_errors_count`` column.
    :param missing_uids_count: stored as is in the ``missing_uids_count`` column.
    """
    log_record = Record(
        first_date=str(first_date),
        last_date=str(last_date),
        sharpei_errors_count=sharpei_errors_count,
        missing_uids_count=missing_uids_count,
        timestamp=str(dt.datetime.now()),
    )
    cluster.write(YT_LOG_TABLE, [log_record], append=True)


def load(date, table, dry_run, provide_discounts=False):
    set_yt_proxy()
    cluster = clusters.Hahn(YT_TOKEN).env(templates={
        'tmp': '$tmp_root/$job_uuid',
    })

    if table:
        if not date:
            raise ValueError('Dates must be specified for loading from specified table for stat purposes')
        result_path = table
        first_date, last_date = parse_dates(date)
    else:
        if date:
            first_date, last_date = parse_dates(date)
        else:
            last_date = get_current_date(days_offset=AGGREGATION_DAYS_OFFSET)
            first_date = get_first_unprocessed_date(cluster)
            if first_date is None:
                first_date = last_date
            elif first_date > last_date:
                raise RuntimeError('First unprocessed date %s is greater than %s' % (first_date, last_date))
        result_path = os.path.join(YT_RESULTS_PATH,  '%s - %s' % (first_date, last_date))
        print 'Start YT aggregation, target: %s' % result_path

        job = cluster.job()

        """
        Веб Десктоп+Тач - таблица web-disk
        Windows - таблица desktop & platform=windows
        Mac - таблица desktop & platform=mac
        Android - таблица mobile & platform=android
        iOs - таблица mobile & platform=iOS
        ПП - сейчас попадает в web-disk
        """
        sources = (
            YTActivitiInfoSource(os.path.join(YT_SOURCES_BASE_PATH, 'desktop'), 'windows', filter_platform='windows'),
            YTActivitiInfoSource(os.path.join(YT_SOURCES_BASE_PATH, 'desktop'), 'mac', filter_platform='mac'),
            YTActivitiInfoSource(os.path.join(YT_SOURCES_BASE_PATH, 'mobile'), 'android', filter_platform='android'),
            YTActivitiInfoSource(os.path.join(YT_SOURCES_BASE_PATH, 'mobile'), 'ios', filter_platform='ios'),
            YTActivitiInfoSource(os.path.join(YT_SOURCES_BASE_PATH, 'web-disk'), 'web'),
        )

        table = job.concat(*(s.get_aggregated_table(job, first_date, last_date) for s in sources))

        table.put(result_path)
        job.run()

    iterator = yt.read_table(result_path, format='json')

    sharpei_errors_count = missing_uids_count = 0
    error_info_container = ErrorInfoContainer()
    if not dry_run:
        print 'Start loading to DB from %s' % result_path
        activity_info_update_results = UserActivityInfoDAO().bulk_update_activity_dates_and_fetch_closest_activity_dates(
            iterator, error_info_container)
        for doc in activity_info_update_results:
            if not provide_discounts:
                continue
            provide_discount_if_needed(doc['uid'], doc['activity_before_update'], doc['activity_after_update'])
        sharpei_errors_count = error_info_container.sharpei_errors_count
        missing_uids_count = error_info_container.missing_uids_count
        print 'write stat to YT'
        write_yt_log(cluster, first_date, last_date, sharpei_errors_count, missing_uids_count)
    return sharpei_errors_count, missing_uids_count
