import logging
from datetime import datetime
from math import log

import numpy as np
import yt.wrapper as yt
from flask import current_app as app
from flask_script import Command, Option
from nile.api.v1 import aggregators as na, files as nfi
from nile.api.v1 import extractors as ne, filters as nf
from nile.processing.record import Record
from qb2.api.v1 import filters as qf
from qb2.api.v1 import typing as qt

from jafar.utils.io import get_cluster
from jafar_yt.utils.helpers import day_before, place_launched, EventValueMapperSelective, unify_uuid, mapper_wrapper

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class ProperMetricsStatsBase(Command):
    """Command that computes nDCG over app-launch logs and stores it in YT.

    ``run`` picks one of two modes:

    * baseline mode (no ``--dataset-source``): a single item ranking, ordered
      by launch count or by a random key, is scored against every selected
      user's own launch counts;
    * test mode (``--dataset-source`` given): the (user, item, score) records
      of that table are scored against the launch counts of the same users.

    In both modes the result is written to a fresh ``ndcg_<utc timestamp>``
    table under ``YT_PROPER_PATH``.
    """

    option_list = (
        Option(
            '--dataset-source',
            dest='dataset_source',
            action='store',
            type=str,
            required=False,
            default=None,
            help='path to yt table with records (user, item, score)',
        ),
        Option(
            '--baseline-mode',
            dest='baseline_mode',
            action='store',
            type=str,
            choices=['general', 'random'],
            required=False,
            default='general',
            help='mode for baseline: general or random',
        ),
    )

    def run(self, dataset_source, baseline_mode):
        """Dispatch to the test run or to the baseline run.

        :param dataset_source: YT table path to evaluate, or None to count a baseline.
        :param baseline_mode: baseline kind; only used when no table is given.
        """
        if dataset_source is None:
            logger.info('Started counting %s baseline', baseline_mode)
            self.run_general(baseline_mode)
            return

        logger.info('Started testing table: %s', dataset_source)
        self.run_test(dataset_source)

    @staticmethod
    def prepare_launches(job):
        """Build the stream of (user, item, place) launch records.

        Reads the daily tables under ``YT_METRIKA_PATH_1_DAY`` for the range
        ``day_before(PROPER_DAYS_INTERVAL)`` .. ``day_before(1)``.

        :param job: nile job the source table is attached to.
        :return: stream with fields ``user``, ``item`` and ``place``.
        """
        first_day = day_before(app.config['PROPER_DAYS_INTERVAL']).isoformat()
        last_day = day_before(1).isoformat()
        source_path = yt.ypath_join(app.config['YT_METRIKA_PATH_1_DAY'], '{%s..%s}' % (first_day, last_day))

        # Only app_launch events of the given API key / clid are of interest.
        launch_events = job.table(source_path).filter(
            nf.equals('APIKey', '37460'),
            nf.equals('EventName', 'app_launch'),
            qf.contains('Clids_Values', '2247990'),
        )

        # Pull the selected fields out of the event value.
        with_values = launch_events.map(
            EventValueMapperSelective(dict(packageName=qt.Optional[qt.Unicode],
                                           action=qt.Optional[qt.Unicode],
                                           place=qt.Yson))
        )

        # A missing/empty action is treated the same as an explicit 'run'.
        runs_only = with_values.filter(
            nf.custom(lambda x: (x or 'run') == 'run', 'action'),
            qf.defined('place', 'DeviceID', 'packageName'),
        )

        launches = runs_only.project(
            user=ne.custom(unify_uuid, 'DeviceID').add_hints(type=qt.String),
            item='packageName',
            place=ne.custom(place_launched, 'place').add_hints(type=qt.SizedTuple[qt.String, qt.String]),
        )

        # Drop launches whose place equals ('homescreens', 'dock') and rows without a user.
        return launches.filter(
            qf.not_(qf.equals('place', ('homescreens', 'dock'))),
            qf.defined('user'),
        )

    @staticmethod
    def prepare_general_list(stream, mode):
        """Build the item list shared by all users, sorted by an ``order`` key.

        :param stream: launch records with an ``item`` field.
        :param mode: 'general' sorts by descending launch count; any other
            value sorts by a random key.
        :return: stream of (item, count, order) sorted by ``order``.
        """
        item_counts = stream.groupby('item').aggregate(count=na.count())

        if mode != 'general':
            order = lambda x: np.random.random_sample()
        else:
            # Negated count: ascending sort puts the most launched items first.
            order = lambda x: -x

        keyed = item_counts.project(
            'item', 'count',
            order=ne.custom(order, 'count').add_hints(type=qt.Float),
        )
        return keyed.sort('order')

    @staticmethod
    def prepare_user_popularities(stream):
        """Count launches per (user, item) for a limited set of active users.

        Users need more than ``PROPER_MIN_LAUNCHES`` launches to qualify, and
        ``PROPER_USERS`` of them are taken.

        :param stream: launch records with ``user`` and ``item`` fields.
        :return: stream of (user, item, count).
        """
        users_limit = app.config['PROPER_USERS']
        users_min_launches = app.config['PROPER_MIN_LAUNCHES']

        launches_per_user = stream.project('user').groupby('user').aggregate(length=na.count())
        selected_users = launches_per_user.filter(
            nf.custom(lambda x: x > users_min_launches, 'length')
        ).take(users_limit)

        # Keep only the launches of the selected users, then count per item.
        selected_launches = stream.join(selected_users, by='user', type='inner')
        return selected_launches.groupby('user', 'item').aggregate(count=na.count())

    def run_general(self, baseline_mode):
        """Score the baseline ranking against per-user launch counts.

        :param baseline_mode: passed to ``prepare_general_list``; also used as
            the checkpoint name of the baseline list.
        """
        proper_root = app.config['YT_PROPER_PATH']
        cluster = get_cluster(backend='yql')
        job = cluster.job().env(templates=dict(checkpoints_root=proper_root))
        result_path = yt.ypath_join(proper_root, 'ndcg_%s' % datetime.utcnow().isoformat())

        # Step 1: launch logs as (user, item) records.
        launches = self.prepare_launches(job).checkpoint('launches_counts')

        # Step 2: the shared item list, by popularity or in random order.
        general = self.prepare_general_list(launches, baseline_mode).checkpoint(baseline_mode)

        # Step 3: (user, item, count) for the selected users.
        personal = self.prepare_user_popularities(launches).checkpoint('user_popularities')

        # Step 4: nDCG per user; the shared list is shipped to the reducer as a file stream.
        top_k = app.config['PROPER_TOP']
        per_user = personal.groupby('user')
        per_user.reduce(
            MetricReducerGeneral(top_k),
            files=[nfi.StreamFile(general, 'general')],
        ).put(result_path)

        job.run()
        logger.info('results: %s', result_path)

    @staticmethod
    def join_user_popularities_test(launches_counts, scores):
        """Pair every user's launch counts with the scores predicted for them.

        :param launches_counts: launch records with ``user`` and ``item`` fields.
        :param scores: records with ``user``, ``item`` and ``score`` fields.
        :return: inner join by ``user`` carrying ``items_list`` (from launch
            counts) and ``items_list_test`` (from scores).
        """
        pair_counts = launches_counts.groupby('user', 'item').aggregate(count=na.count())

        # One row per user holding the list built by ReduceToList from 'count'.
        actual_lists = pair_counts.groupby('user').reduce(
            ReduceToList(dict(user=qt.String, items_list=qt.Yson), 'count')
        )

        # Same for the scores; the list is renamed so both survive the join.
        predicted_lists = scores.groupby('user').reduce(
            ReduceToList(dict(user=qt.String, items_list=qt.Yson), 'score')
        ).project(
            'user',
            items_list_test='items_list',
        )

        return actual_lists.join(
            predicted_lists,
            by=['user'],
            type='inner',
            assume_unique=True,
        )

    def run_test(self, dataset_source):
        """Score the ranking stored in ``dataset_source`` against launch logs.

        :param dataset_source: path to a YT table of (user, item, score) records.
        """
        proper_root = app.config['YT_PROPER_PATH']
        cluster = get_cluster(backend='yql')
        job = cluster.job().env(templates=dict(checkpoints_root=proper_root))
        result_path = yt.ypath_join(proper_root, 'ndcg_%s' % datetime.utcnow().isoformat())

        scores = job.table(dataset_source)

        # Step 1: launch logs as (user, item) records.
        launches = self.prepare_launches(job).checkpoint('launches_counts')

        # Step 2: one row per user with both the actual and the predicted list.
        users = self.join_user_popularities_test(launches, scores).checkpoint('users')

        # Step 3: nDCG per joined row.
        top_k = app.config['PROPER_TOP']
        users.map(MetricMapperTest(top_k)).put(result_path)

        job.run()
        logger.info('results: %s', result_path)


@mapper_wrapper
class MetricReducerGeneral(object):
    """Reducer scoring one shared item ranking against each group's counts.

    For every group it emits the group key plus ``ndcg``: the DCG of the
    first ``top_k`` items of the ``general`` file stream that occur in the
    group, divided by the DCG of the group's own ``top_k`` highest counts.
    """
    schema = dict(ndcg=qt.Float)

    def __init__(self, top_k):
        # Cut-off applied to both the evaluated ranking and the ideal one.
        self.top_k = top_k

    def __call__(self, groups, **options):
        """Yield ``Record(key, ndcg=...)`` for every ``(key, records)`` group.

        Records are read through their ``item`` and ``count`` fields, and
        ``options['file_streams']['general']`` supplies the shared ranking.
        Raises ZeroDivisionError when the ideal DCG is 0 (e.g. ``top_k == 0``).
        """
        exponent_cap = 700  # cap the exponent to prevent overflow in float

        def gain(count):
            # 2 ** count - 1, with the exponent clamped at exponent_cap.
            if count < exponent_cap:
                return 2 ** count - 1
            return 2 ** exponent_cap - 1

        def discount(position):
            # Natural-log discount for a 0-based position.
            return float(1. / log(position + 2))

        for key, records in groups:
            by_count = sorted(
                ((record['item'], record['count']) for record in records),
                key=lambda pair: pair[1],
                reverse=True,
            )
            user_counts = dict(by_count)

            # Walk the shared ranking, keeping the items this group contains.
            # NOTE(review): the stream is fetched and iterated once per group;
            # confirm it can be iterated repeatedly, otherwise groups after the
            # first would see an exhausted stream.
            ranked = []
            for entry in options['file_streams']['general']:
                candidate = entry.item
                if candidate in user_counts:
                    ranked.append(candidate)
                if len(ranked) >= self.top_k:
                    break

            dcg = sum(
                gain(user_counts[item]) * discount(position)
                for position, item in enumerate(ranked)
            )
            ideal_dcg = sum(
                gain(count) * discount(position)
                for position, (_, count) in enumerate(by_count[:self.top_k])
            )
            yield Record(key, ndcg=dcg * 1. / ideal_dcg)


@mapper_wrapper
class MetricMapperTest(object):
    """Mapper computing nDCG of a predicted item list against actual counts.

    Each input record carries ``items_list`` ((item, count) pairs) and
    ``items_list_test`` ((item, score) pairs); one ``ndcg`` record is emitted
    per input record.
    """
    schema = dict(ndcg=qt.Float)

    def __init__(self, top_k):
        # Number of leading entries of each list that take part in the metric.
        self.top_k = top_k

    def __call__(self, records):
        """Yield ``Record(ndcg=...)`` for every record in ``records``.

        Predicted items absent from ``items_list`` contribute a count of 0.
        Raises ZeroDivisionError when the ideal DCG is 0 (e.g. an empty
        ``items_list`` or ``top_k == 0``).
        """
        exponent_cap = 700  # cap the exponent to prevent overflow in float

        def gain(count):
            # 2 ** count - 1, with the exponent clamped at exponent_cap.
            if count < exponent_cap:
                return 2 ** count - 1
            return 2 ** exponent_cap - 1

        def discount(position):
            # Natural-log discount for a 0-based position.
            return float(1. / log(position + 2))

        for record in records:
            actual = record['items_list']
            counts = dict(actual)
            predicted = record['items_list_test'][:self.top_k]

            dcg = sum(
                gain(counts.get(item, 0)) * discount(position)
                for position, (item, _) in enumerate(predicted)
            )
            # The ideal DCG uses the first top_k entries of items_list as they are stored.
            ideal_dcg = sum(
                gain(count) * discount(position)
                for position, (_, count) in enumerate(actual[:self.top_k])
            )
            yield Record(ndcg=1. * dcg / ideal_dcg)


@mapper_wrapper
class ReduceToList(object):
    """Reducer collapsing each group into one record with a sorted pair list.

    The emitted ``items_list`` holds ``(item, value)`` pairs, where ``value``
    is read from the configured field, ordered by descending value.
    """

    def __init__(self, schema, field):
        # Output schema, stored on the instance as ``schema``.
        self.schema = schema
        # Name of the record field paired with ``item``.
        self.field = field

    def __call__(self, groups):
        """Yield ``Record(key, items_list=...)`` for every ``(key, records)`` group."""
        for key, records in groups:
            pairs = sorted(
                ((record['item'], record[self.field]) for record in records),
                # Negated value: ascending sort yields descending values.
                key=lambda pair: -pair[1],
            )
            yield Record(key, items_list=pairs)
