import logging
from datetime import date, datetime

import yt.wrapper as yt
from flask import current_app as app
from nile.api.v1 import aggregators as na
from nile.api.v1 import extractors as ne, filters as nf
from qb2.api.v1 import filters as qf
from qb2.api.v1 import typing as qt

from jafar.datasets.base import BaseDatasetProcessor
from jafar.utils.io import get_cluster
from jafar_yt.update_datasets import AdvisorMongoInstallMapper
from jafar_yt.utils.helpers import day_before, place_launched, EventValueMapperSelective, unify_uuid
from jafar_yt.vanga_utils import SampleNegativesReducer, features_reducer, GeneralReducer

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Recency window (30 * 6 = 180) passed to day_before() by
# VangaDatasetProcessor.prepare_basket: only user records whose 'updated_at'
# date is later than day_before(USERS_UPDATED_AT_DAYS) are kept as
# "possibly active" users.
USERS_UPDATED_AT_DAYS = 30 * 6


class VangaDatasetProcessor(BaseDatasetProcessor):
    """
    Dataset processor for the 'vanga' source.

    Builds train and test tables on YT out of app-launch events: every launch
    is joined with the user's basket of installed items, non-launched items
    act as negatives, per-user and per-item features are attached, and the
    resulting tables are linked to a stable location.
    """
    source = 'vanga'

    def __init__(self):
        super(VangaDatasetProcessor, self).__init__()
        # Root YT path for this processor: used both as the checkpoints root
        # and as the directory of the timestamped output tables.
        self.yt_table_source = app.config['YT_VANGA_PATH']

    @staticmethod
    def prepare_basket(job, country, users_limit=None):
        """
        Build the basket stream of (user, item, install_time) records.

        :param job: nile job the stream is attached to
        :param country: passed as-is to AdvisorMongoInstallMapper
        :param users_limit: max number of user records to take; a falsy value
            falls back to app.config['VANGA_USERS'], and if that is falsy too
            no limit is applied
        :return: stream projected to 'user', 'item', 'install_time'
        """
        def recently_updated(stamp):
            # 'updated_at' is indexed by '$date'; the value is divided by 1e3
            # before date.fromtimestamp, so it is presumably in milliseconds.
            updated_on = date.fromtimestamp(stamp['$date'] / 1e3)
            return updated_on > day_before(USERS_UPDATED_AT_DAYS)

        all_users = job.table(app.config['YT_PATH_USERS_FULL'])
        # keep only possibly active users
        active_users = all_users.filter(nf.custom(recently_updated, 'updated_at'))

        limit = users_limit or app.config['VANGA_USERS']
        if limit:
            active_users = active_users.take(limit)

        installs = active_users.map(AdvisorMongoInstallMapper(country), intensity='cpu')
        return installs.project('user', 'item', 'install_time')

    @staticmethod
    def prepare_app_launches(job, date_from=None, date_to=None):
        """
        Extract 'app_launch' events from the daily metrika tables.

        :param job: nile job the stream is attached to
        :param date_from: first table name of the range; when falsy it is
            day_before(VANGA_DAYS_INTERVAL + VANGA_SEED_DAYS).isoformat()
        :param date_to: last table name of the range; when falsy it is
            day_before(1).isoformat()
        :return: stream with 'packageName', 'className', 'user', 'timestamp',
            'date', 'item', 'weekday', 'hour' and 'place' fields
        """
        if date_from:
            day_from = date_from
        else:
            lookback = app.config['VANGA_DAYS_INTERVAL'] + app.config['VANGA_SEED_DAYS']
            day_from = day_before(lookback).isoformat()
        day_to = date_to if date_to else day_before(1).isoformat()

        def is_run_action(action):
            # an absent/empty action is treated the same way as 'run'
            return not action or action == 'run'

        def launch_weekday(start_time):
            return datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S').weekday()

        def launch_hour(start_time):
            return datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S').hour

        tables = yt.ypath_join(app.config['YT_METRIKA_PATH_1_DAY'], '{%s..%s}' % (day_from, day_to))
        raw_launches = job.table(tables).filter(
            nf.equals('EventName', 'app_launch')
        ).project('DeviceID', 'StartTime', 'StartTimeZone', 'StartTimestamp', 'StartDate', 'EventValue')

        # fields picked out of 'EventValue' together with their types
        event_value_fields = {
            'packageName': qt.Optional[qt.Unicode],
            'className': qt.Optional[qt.Unicode],
            'action': qt.Optional[qt.Unicode],
            'place': qt.Yson,
        }
        parsed = raw_launches.map(EventValueMapperSelective(event_value_fields))

        usable = parsed.filter(
            nf.custom(is_run_action, 'action'),
            qf.defined('place', 'DeviceID', 'StartTime', 'StartTimestamp', 'packageName', 'StartDate'),
        )

        launches = usable.project(
            'packageName', 'className',
            user=ne.custom(unify_uuid, 'DeviceID').add_hints(type=qt.String),
            timestamp=ne.custom(int, 'StartTimestamp').add_hints(type=qt.Integer),
            date='StartDate',
            item='packageName',
            weekday=ne.custom(launch_weekday, 'StartTime').add_hints(type=qt.Integer),
            hour=ne.custom(launch_hour, 'StartTime').add_hints(type=qt.Integer),
            place=ne.custom(place_launched, 'place').add_hints(type=qt.SizedTuple[qt.String, qt.String]),
        )

        # drop records without a user and launches whose place is ('homescreens', 'dock')
        return launches.filter(
            qf.defined('user'),
            qf.not_(qf.equals('place', ('homescreens', 'dock'))),
        )

    @staticmethod
    def sample_negatives(stream, basket_stream, negatives_count=None):
        """
        Sample 'fake' app launches for apps that were installed but not launched.

        Each launch is joined with the basket of the same user; the row gets
        value=1 when the basket item is the launched one and value=0 otherwise.

        :param stream: launches stream
        :param basket_stream: basket stream joined by 'user'
        :param negatives_count: argument for SampleNegativesReducer; a falsy
            value falls back to app.config['VANGA_IMPLICIT_NEGATIVES'], and a
            value that is not > 0 skips the reducer so all joined rows are kept
        :return: labelled stream
        """
        sample_size = negatives_count or app.config['VANGA_IMPLICIT_NEGATIVES']

        # 'item' is moved to 'real_item' so it does not clash with the basket 'item'
        if not app.config['VANGA_USE_CLASSNAMES']:
            launches = stream.project(ne.all('item'), real_item='item')
        else:
            launches = stream.project(
                ne.all(['item', 'className']),
                real_item=ne.custom(u'{}/{}'.format, 'item', 'className').add_hints(type=qt.String),
            )

        def installed_before_launch(install_time, timestamp):
            return install_time < timestamp

        def launch_label(real_item, item):
            return int(real_item == item)

        joined = launches.join(basket_stream, by='user', type='inner')
        # an item counts only if it was installed before the launch happened
        candidates = joined.filter(nf.custom(installed_before_launch, 'install_time', 'timestamp'))
        labelled = candidates.project(
            ne.all(),
            value=ne.custom(launch_label, 'real_item', 'item').add_hints(type=qt.Integer),
        )

        if not sample_size > 0:
            return labelled

        return labelled.groupby('user', 'timestamp').reduce(SampleNegativesReducer(sample_size))

    @staticmethod
    def enrich_with_features(dataset_stream):
        """
        Run features_reducer over every user's records ordered by 'timestamp'.

        :param dataset_stream: stream with 'user' and 'timestamp' fields
        :return: output stream of features_reducer
        """
        per_user = dataset_stream.groupby('user')
        chronological = per_user.sort('timestamp')
        return chronological.reduce(features_reducer, intensity='ultra_cpu')

    @staticmethod
    def add_classnames_basket(train_launches, test_launches, basket):
        """
        Attach class names to basket items, producing 'item/className' items.

        Items the user launched get the className seen in the launches; items
        never launched by the user get the top className of that item.

        :param train_launches: train launches stream
        :param test_launches: test launches stream
        :param basket: basket stream ('user', 'item', 'install_time')
        :return: stream of 'user', 'install_time', 'item'
        """
        launches = train_launches.concat(test_launches)

        # most common className for each packageName: count (item, className)
        # pairs, then take the className that is last by 'value' per item
        class_counts = launches.groupby('item', 'className').aggregate(value=na.count())
        top_class_name = class_counts.groupby('item').aggregate(
            className=na.last('className', 'value')
        ).checkpoint('top_class_name')

        # basket entries that were launched by the user
        # NOTE(review): the right side is de-duplicated by (user, item, className)
        # but joined by (user, item) with assume_unique=True - confirm this hint
        # holds when a user launched several class names of one item.
        launched_pairs = launches.unique('user', 'item', 'className').project('user', 'item', 'className')
        launched_basket = basket.join(
            launched_pairs,
            by=['user', 'item'],
            type='inner',
            assume_unique=True,
            assume_defined=True,
        ).checkpoint('launched_basket')

        # basket entries that were never launched by the user
        never_launched = basket.join(
            launches,
            by=['user', 'item'],
            type='left_only',
            assume_unique_left=True,
            assume_defined=True,
        )
        non_launched_basket = never_launched.join(
            top_class_name,
            by='item',
            type='inner',
            assume_small_right=True,
            assume_unique=True,
            assume_defined=True,
        ).checkpoint('non_launched_basket')

        whole_basket = launched_basket.concat(non_launched_basket)
        return whole_basket.project(
            'user', 'install_time',
            item=ne.custom(u'{}/{}'.format, 'item', 'className').add_hints(type=qt.String),
        )

    def update_interactions(self, country):
        """
        Build the train and test datasets and link them to the result paths.

        Output tables are written under self.yt_table_source with a UTC
        timestamp prefix and then linked to '<self.yt_table_result>_train' and
        '<self.yt_table_result>_test' (yt_table_result is not set in this
        class; presumably it comes from BaseDatasetProcessor).

        :param country: passed to prepare_basket
        """
        cluster = get_cluster(backend='yql')
        job = cluster.job().env(templates={'checkpoints_root': self.yt_table_source})

        # 1. Basket: table of (user, item) records
        basket = self.prepare_basket(job, country).checkpoint('basket')

        # 2. app_launches from appmetrika logs; the test part starts
        #    VANGA_DAYS_TEST days back, the train part ends the day before that
        train_until = day_before(app.config['VANGA_DAYS_TEST'] + 1).isoformat()
        train_launches = self.prepare_app_launches(job, date_to=train_until).checkpoint('train_launches')

        test_since = day_before(app.config['VANGA_DAYS_TEST']).isoformat()
        test_launches = self.prepare_app_launches(job, date_from=test_since).checkpoint('test_launches')

        if app.config['VANGA_USE_CLASSNAMES']:
            train_launches = train_launches.filter(qf.defined('className'))
            test_launches = test_launches.filter(qf.defined('className'))
            basket = self.add_classnames_basket(train_launches, test_launches, basket)

        # 3. Negatives come from joining launches with the basket
        train_sampled = self.sample_negatives(train_launches, basket)
        train_dataset = train_sampled.project(ne.all(), is_test=ne.const(0))

        # the test part keeps the whole user basket (-1 disables sampling)
        test_sampled = self.sample_negatives(test_launches, basket, -1)
        test_dataset = test_sampled.project(ne.all(), is_test=ne.const(1))

        # 4. Continuous features
        interval = app.config['VANGA_DAYS_INTERVAL']
        common = train_dataset.concat(test_dataset)

        # per-item statistics computed over real launches only (value == 1)
        positives = common.filter(nf.equals('value', 1)).project('user', 'item', 'hour', 'weekday')
        item_stats = positives.groupby('item').reduce(GeneralReducer('item'))
        general = item_stats.project(
            'item',
            total='personal',
            total_hourly='hourly',
            total_weekly='weekly',
        )

        def after_seed_days(timestamp):
            return date.fromtimestamp(timestamp) > day_before(interval)

        def bucket_count(counts, bucket):
            # counts are looked up by the string form of the bucket; missing -> 0
            return counts.get(str(bucket), 0)

        enriched = common.call(self.enrich_with_features)
        # cut off launches of first 'seed' days for more realistic 'recent' feature
        recent = enriched.filter(nf.custom(after_seed_days, 'timestamp'))
        with_item_stats = recent.join(
            general,
            by='item',
            assume_unique_right=True,
            assume_small_right=True,
            assume_defined=True,
            allow_undefined_keys=False,
        )
        common = with_item_stats.project(
            ne.all(),
            total_hourly=ne.custom(bucket_count, 'total_hourly', 'hour').add_hints(type=qt.Integer),
            total_weekly=ne.custom(bucket_count, 'total_weekly', 'weekday').add_hints(type=qt.Integer),
        )

        now = datetime.utcnow().isoformat()
        train_path = yt.ypath_join(self.yt_table_source, '%s_%s' % (now, 'train'))
        test_path = yt.ypath_join(self.yt_table_source, '%s_%s' % (now, 'test'))

        # (result suffix, is_test flag, output table): train first, then test
        outputs = (
            ('train', 0, train_path),
            ('test', 1, test_path),
        )

        # split by the is_test flag and drop the flag from the written tables
        for _, flag, path in outputs:
            part = common.filter(nf.equals('is_test', flag))
            part.project(ne.all('is_test')).put(path)

        job.run()

        # link everything to central place
        for suffix, _, path in outputs:
            yt.link(path, '%s_%s' % (self.yt_table_result, suffix), force=True)
