from flask import current_app as app
from flask_script import Command, Option

import datetime
import yt.wrapper as yt
import logging
import yenv
import numpy as np

from jafar import jafar_mongo
from jafar.utils.io import get_cluster
from jafar_yt import update_trend_data as yt_operations

from nile.api.v1 import filters

# Module-level logger used by the trending data update command.
logger = logging.getLogger(__name__)

# YT directory under which per-day event-count tables are stored; each table
# is named after the ISO-formatted date it covers (see format_table_path).
PATH_TREND_EVENT_COUNTS = '//home/advisor/trending/event_counts'


def format_table_path(date):
    """
    build the YT path of the event-count table for the given day
    :param date: object exposing ``isoformat()`` (e.g. ``datetime.date``)
    :return: ``PATH_TREND_EVENT_COUNTS`` joined with the ISO-formatted date
    """
    table_name = date.isoformat()
    return yt.ypath_join(PATH_TREND_EVENT_COUNTS, table_name)


class TrendDataUpdaterCommand(Command):
    """
    Management command that refreshes trending statistics.

    The command works in two stages:

    1. ``run_yt`` builds missing per-day event-count tables on YT from the
       Metrika logs and removes tables outside of the tracked date window.
    2. ``run_local`` aggregates the per-day tables, fits a linear trend for
       every item and stores the result in Mongo.
    """

    option_list = (
        Option('--force', '-f', dest='force', default=False, action='store_true'),
    )

    def run_yt(self):
        """
        create missing per-day event-count tables on YT and remove stale ones
        :return:
        """
        logger.info('Parsing events for trending data from Metrika logs')
        # The window is one day longer than the longest trending interval.
        days_count = max(app.config['TRENDING_INTERVALS_IN_DAYS']) + 1
        # Read "today" once so the whole window is computed from the same day
        # even if the command happens to run across midnight.
        today = datetime.date.today()
        all_dates = [today - datetime.timedelta(days=offset + 1) for offset in range(days_count)]
        missed_dates = [date for date in all_dates if not yt.exists(format_table_path(date))]

        logger.info('Loading events for {} missed dates'.format(len(missed_dates)))
        for date in missed_dates:
            cluster = get_cluster()
            job = cluster.job()
            job.table(
                yt.ypath_join(app.config['YT_METRIKA_PATH_1_DAY'], date.isoformat())
            ).filter(
                filters.equals('APIKey', app.config['YT_LAUNCHER_API_KEY'])
            ).map(
                yt_operations.mapper
            ).groupby(
                'item'
            ).reduce(
                yt_operations.CountReducer(date.isoformat())
            ).put(
                format_table_path(date)
            )
            job.run()

        logger.info('Cleaning up old tables')
        # Anything in the directory whose name is not a date from the current
        # window is considered outdated and removed.
        all_tables_set = {date.isoformat() for date in all_dates}
        old_tables = set(yt.list(PATH_TREND_EVENT_COUNTS)) - all_tables_set
        for table in old_tables:
            table_path = yt.ypath_join(PATH_TREND_EVENT_COUNTS, table)
            logger.info('Removing table {}'.format(table_path))
            yt.remove(table_path)

    @staticmethod
    def get_trend_slope(counts):
        """
        calculates trend slope for normalized event counts

        The counts are divided by their sum and a first-degree polynomial is
        fitted against the positions ``0 .. len(counts) - 1``; the slope of
        that line is returned.

        :param counts: events counts by days; must not sum to zero, otherwise
            the normalization yields NaN values
        :return: slope (first coefficient) of the fitted line
        """
        counts = np.array(counts, dtype=np.float32)
        counts = counts / np.sum(counts)
        return np.polyfit(np.arange(len(counts)), counts, deg=1)[0]

    def save_trending_stats(self, tables, count_stat):
        """
        save trending statistics from YT tables to mongo

        The rows are first written to a temporary collection which is then
        renamed over the target one, so the target collection is replaced as
        a whole.

        :param tables: list of YT table paths to aggregate
        :param count_stat: name of count statistics ('installs' or 'counts')
        :return:
        :raises ValueError: if an aggregated row has a non-positive total count
        """

        logger.info("Loading '{}' statistics from '{}' YT tables".format(count_stat, tables))
        cluster = get_cluster()
        with yt.TempTable() as temp_table:
            job = cluster.job()
            job.table(
                '{' + ','.join(tables) + '}'
            ).groupby(
                'item'
            ).reduce(
                yt_operations.TrendReducer(len(tables), count_stat)
            ).put(
                temp_table
            )
            job.run()

            items = []
            for row in yt.read_table(temp_table, format='json'):
                package_name = row['item']
                counts = row['counts']

                total_counts = sum(counts)
                # Explicit check instead of ``assert``: asserts are stripped
                # under ``python -O`` and a zero total would silently produce
                # NaN slopes in get_trend_slope.
                if total_counts <= 0:
                    raise ValueError(
                        "Non-positive total '{}' count ({}) for item '{}'".format(
                            count_stat, total_counts, package_name
                        )
                    )
                item = {'item': package_name}
                for days in app.config['TRENDING_INTERVALS_IN_DAYS']:
                    trend_field = app.config['TRENDING_MONGO_FIELD'].format(days)
                    # Trend is computed over the last ``days`` entries of the row.
                    item[trend_field] = self.get_trend_slope(counts[-days:])
                item['average_daily_counts'] = float(total_counts) / len(counts)  # average event counts a day
                items.append(item)

        logger.info("Loaded statistics and calculated trends for {} apps".format(len(items)))
        collection_name = app.config['TRENDING_MONGO_COLLECTION'].format(count_stat)
        collection_name_tmp = '{}_tmp'.format(collection_name)
        # Start from a clean temporary collection, index it, fill it and only
        # then swap it in place of the target collection.
        jafar_mongo.db[collection_name_tmp].drop()
        for days in app.config['TRENDING_INTERVALS_IN_DAYS']:
            trend_field = app.config['TRENDING_MONGO_FIELD'].format(days)
            jafar_mongo.db[collection_name_tmp].create_index([(trend_field, -1)], background=True)

        jafar_mongo.db[collection_name_tmp].insert_many(items)
        jafar_mongo.db[collection_name_tmp].rename(collection_name, dropTarget=True)

    def run_local(self):
        """
        compute trending statistics from the per-day YT tables and store them
        :return:
        """
        tables = sorted(yt.list(PATH_TREND_EVENT_COUNTS, absolute=True))
        # NOTE(review): with names sorted ascending this keeps the FIRST
        # (presumably oldest, if names are ISO dates) tables and drops the
        # rest - confirm that skipping the most recent table is intended.
        tables = tables[:max(app.config['TRENDING_INTERVALS_IN_DAYS'])]
        for count_stat in app.config['TRENDING_COUNT_STATS']:
            self.save_trending_stats(tables, count_stat)

    def run(self, force):
        """
        command entry point
        :param force: run the YT collection stage even outside of production
        :return:
        """
        yt.update_config(app.config['YT_CONFIG'])
        if yenv.type != 'production' and not force:
            logger.warning(
                'Not running YT collection task in {} environment by default. '
                'Set --force flag if you really want it.'.format(yenv.type)
            )
        else:
            self.run_yt()

        self.run_local()
