import datetime
import logging
from itertools import imap
from uuid import UUID

import numpy as np
from flask import current_app as app

from jafar import advisor_mongo, request_cache, clickhouse
from jafar.datasets import get_dataset_processor
from jafar.pipelines import ids
from jafar.pipelines.blocks import SingleContextBlock
from jafar.utils import date_to_datetime
from jafar.utils.structarrays import DataFrame
from jafar_yt.profile_dataframe_converter import UserProfileConverter
from jafar_yt.usage_stats import STAT_FIELDS

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Maps usage counter names to event names.
# The mapping is not referenced in the visible part of this file.
# NOTE(review): 'App_install' is capitalised unlike the other two values;
# confirm that it matches the real event name.
EVENT_NAME_BY_COUNTER = {
    'rec_view_count': 'rec_view',
    'app_install_count': 'App_install',
    'app_launch_count': 'app_launch'
}


def validate_users(users):
    """
    Convert user ids to UUID objects, skipping every id that is not a valid UUID.

    :param users: iterable of user ids (expected to be UUID strings)
    :return: list of uuid.UUID objects for the ids that could be parsed,
             in the original order
    """
    result = []
    for user in users:
        try:
            result.append(UUID(user))
        except (TypeError, ValueError, AttributeError):
            # UUID() raises TypeError for None, ValueError for a badly formed
            # hex string and AttributeError for non-string values (it calls
            # str methods on its argument); all of them mean "not a valid id".
            continue
    return result


class OfflineReadDataBlock(SingleContextBlock):
    """
    Loads the requested frames from the offline datasets into the context
    and optionally keeps only a random subset of users in one of the frames.
    """

    def __init__(self, output_data, user_features=None, item_features=None, sample=None, sampling_frame=None):
        """
        :param output_data: keys of the frames to read; every key must be present in ids.data_frame_keys
        :param user_features: columns requested when a user features frame is read
        :param item_features: columns requested when an item features frame is read
        :param sample: share of users to keep; sampling happens only when it is strictly between 0 and 1
        :param sampling_frame: key of the frame whose users are sampled
        """
        super(OfflineReadDataBlock, self).__init__(
            input_data=None, output_data=output_data, destroyed_data=None
        )
        supported_keys = ids.data_frame_keys.keys()
        assert output_data, 'output_data must be non-empty'
        assert set(output_data) <= set(supported_keys), \
            'Only the following output frames are supported: {}'.format(supported_keys)
        self.sampling_frame = sampling_frame
        self.sample = sample
        self.item_features = item_features
        self.user_features = user_features
        self.output_data = output_data

    def apply(self, context, train):
        """
        Load every output frame that is not in the context yet, then
        subsample users of `sampling_frame` if sampling is configured.

        :param context: pipeline context; frames are stored in context.data
        :param train: not used by this block
        :return: the same context object
        """
        for key in self.output_data:
            if key in context.data:
                # already loaded, do not read it again
                continue
            context.data[key] = self.load_data(context, key)

        sampling_enabled = self.sampling_frame and self.sample and (0 < self.sample < 1.)
        if not sampling_enabled:
            return context

        full_frame = context.data[self.sampling_frame]
        unique_users = np.unique(full_frame['user'])
        keep_count = int(len(unique_users) * self.sample)
        kept_users = np.random.choice(unique_users, keep_count, replace=False)
        keep_mask = np.in1d(full_frame['user'], kept_users)
        context.data[self.sampling_frame] = full_frame[keep_mask]
        logger.debug(
            "Sampled %s%% users from %s: %s rows left (out of %s)",
            self.sample * 100, self.sampling_frame, len(context.data[self.sampling_frame]), len(full_frame)
        )
        return context

    def load_data(self, context, frame_key):
        """
        Read a single frame for the context country via its dataset processor.

        User feature keys and item feature keys are read with the configured
        feature lists, every other key is read with get_data().

        :param context: pipeline context, only context.country is used
        :param frame_key: key from ids.data_frame_keys
        :return: the loaded frame
        """
        country = context.country
        processor = get_dataset_processor(ids.data_frame_keys[frame_key])
        logger.info('Reading {} for {}'.format(frame_key, country))
        if frame_key in ids.user_feature_keys:
            frame = processor.get_user_features(country=country, features=self.user_features)
        elif frame_key in ids.item_feature_keys:
            frame = processor.get_item_features(country=country, features=self.item_features)
        else:
            frame = processor.get_data(country=country)
        logger.info('Read {} {} for {}'.format(len(frame), frame_key, country))
        return frame


class OnlineReadDataBlock(SingleContextBlock):
    """
    Reads the data of a single user from advisor's mongo profile collection.

    The user id is taken from the FRAME_KEY_TARGET frame. The block fills the
    installs and user features frames (declared as output_data) and also
    writes the removed items and disliked items frames and context.country.
    """

    def __init__(self, installs_frame=ids.FRAME_KEY_ADVISOR_MONGO_INSTALLS,
                 user_features_frame=ids.FRAME_KEY_USER_FEATURES,
                 removed_items_frame=ids.FRAME_KEY_REMOVED_ITEMS,
                 disliked_items_frame=ids.FRAME_KEY_DISLIKED_ITEMS,
                 user_apps_only=False):
        """
        :param installs_frame: key for installs, must be FRAME_KEY_ADVISOR_MONGO_INSTALLS
        :param user_features_frame: key for the user features frame
        :param removed_items_frame: key for the removed items frame
        :param disliked_items_frame: key for the disliked items frame
        :param user_apps_only: do not put system apps to FRAME_KEY_ADVISOR_MONGO_INSTALLS frame
        """
        assert installs_frame in (ids.FRAME_KEY_ADVISOR_MONGO_INSTALLS,)
        self.installs_frame = installs_frame
        self.user_features_frame = user_features_frame
        self.removed_items_frame = removed_items_frame
        self.disliked_items_frame = disliked_items_frame
        self.user_apps_only = user_apps_only
        self.input_frame = ids.FRAME_KEY_TARGET
        self.output_data = [installs_frame, user_features_frame]
        super(OnlineReadDataBlock, self).__init__(
            input_data=[self.input_frame], output_data=self.output_data, destroyed_data=None
        )

    @staticmethod
    def get_recommender_country(user_features):
        """
        Pick the recommender country from the user's 'lbs_country' value.

        * empty frame -> DEFAULT_COUNTRY
        * country listed in COUNTRIES -> that country
        * otherwise RECOMMENDER_COUNTRY_MAP lookup with DEFAULT_COUNTRY as fallback

        :param user_features: frame with an 'lbs_country' column
        :return: country of the recommender model
        """
        config = app.config
        fallback = config['DEFAULT_COUNTRY']
        if not len(user_features):
            return fallback
        lbs_country = user_features['lbs_country'][0]
        if lbs_country in config['COUNTRIES']:
            return lbs_country
        return config['RECOMMENDER_COUNTRY_MAP'].get(lbs_country, fallback)

    def apply(self, context, train):
        """
        Load the profile of the single target user and store its frames.

        Nothing is done when all frames from output_data are already in the
        context.

        :param context: pipeline context; frames are stored in context.data
        :param train: not used by this block
        :return: the same context object
        :raises Exception: when the target frame does not hold exactly one valid user id
        """
        missing_frames = [key for key in self.output_data if key not in context.data]
        if not missing_frames:
            return context

        valid_users = validate_users(context.data[self.input_frame]['user'].unique())
        if len(valid_users) != 1:
            raise Exception('OnlineReadDataBlock can manage only with 1 user')
        device_id = valid_users[0]

        @request_cache.memoize()
        def get_mongo_users_data(user_id, user_apps_only):
            logger.debug("Loading user data from mongodb")
            user_profile = advisor_mongo.db.profile.find_one({'_id': user_id})

            # Start from empty frames; they stay in place for every part
            # that cannot be read from the profile.
            installs = DataFrame.empty(UserProfileConverter.get_installs_dtype())
            user_features = DataFrame.empty(UserProfileConverter.get_user_features_dtype())
            removed_items = DataFrame.empty(UserProfileConverter.get_removed_items_dtype())
            disliked_items = DataFrame.empty(UserProfileConverter.get_disliked_items_dtype())

            if not user_profile:
                logger.warning('User with DeviceID=%s not found', user_id)
                return installs, user_features, removed_items, disliked_items

            converter = UserProfileConverter(user_profile)
            try:
                installs = converter.get_installs(user_apps_only)
                user_features = converter.get_user_features()
                removed_items = converter.get_removed_items()
                disliked_items = converter.get_disliked_items()
            except UserProfileConverter.IncompleteProfile:
                # frames converted before the failure are kept
                logger.warning('User with DeviceID=%s has incomplete profile', user_id)
            return installs, user_features, removed_items, disliked_items

        loaded = get_mongo_users_data(device_id, self.user_apps_only)
        installs, user_features, removed_items, disliked_items = loaded

        context.data[self.installs_frame] = installs
        context.data[self.user_features_frame] = user_features
        context.data[self.removed_items_frame] = removed_items
        context.data[self.disliked_items_frame] = disliked_items
        context.country = self.get_recommender_country(user_features)
        return context


class DatarameFilteringBlock(SingleContextBlock):
    """
    Aligns two frames on a common field.

    Only the rows whose `field` value occurs in both frames are kept,
    e.g. to leave only the items that appear in both data sources.
    """

    def __init__(self, input_data1_key, input_data2_key, field,
                 filter_data1=True, filter_data2=True):
        """
        :param input_data1_key: key of the first frame in context.data
        :param input_data2_key: key of the second frame in context.data
        :param field: column whose values are intersected
        :param filter_data1: filter the first frame by the common values
        :param filter_data2: filter the second frame by the common values
        """
        frame_keys = (input_data1_key, input_data2_key)
        super(DatarameFilteringBlock, self).__init__(
            input_data=frame_keys,
            output_data=frame_keys,
            destroyed_data=None
        )
        assert filter_data1 or filter_data2, 'At least one frame should be filtered'
        self.field = field
        self.input_data_key1 = input_data1_key
        self.input_data_key2 = input_data2_key
        self.filter_data1 = filter_data1
        self.filter_data2 = filter_data2

    def apply(self, context, train):
        """
        Filter the configured frames by the values of `field` common to both.

        The block is skipped (with a debug message) when at least one of the
        two frames is missing from the context.

        :param context: pipeline context; frames are stored in context.data
        :param train: not used by this block
        :return: the same context object
        """
        first_key, second_key = self.input_data_key1, self.input_data_key2

        absent_keys = [key for key in (first_key, second_key) if key not in context.data]
        if absent_keys:
            logger.debug(
                'Keys %s are not in context, skipping DatarameFilteringBlock',
                absent_keys
            )
            return context

        first_frame = context.data[first_key]
        second_frame = context.data[second_key]
        # the intersection is computed once, on the unfiltered frames
        shared_values = np.intersect1d(first_frame[self.field], second_frame[self.field])

        if self.filter_data1:
            first_mask = first_frame[self.field].is_in(shared_values)
            context.data[first_key] = first_frame[first_mask]

        if self.filter_data2:
            second_mask = second_frame[self.field].is_in(shared_values)
            context.data[second_key] = second_frame[second_mask]

        logger.info('Field %s was intersected between %s and %s, rows left %d and %d',
                    self.field, first_key, second_key, len(context.data[first_key]),
                    len(context.data[second_key]))
        return context


class OnlineUsageStatsData(SingleContextBlock):
    """
    Reads usage counters summed per (user, item) from ClickHouse for the
    users of the FRAME_KEY_TARGET frame and stores them in the output frame.
    """

    def __init__(self, counters, output_frame=ids.FRAME_KEY_USAGE_STATS, dates=None, prefix=None):
        """
        :param counters: non-empty collection of counter names, each must be in STAT_FIELDS
        :param output_frame: key under which the result is stored in context.data
        :param dates: optional list of date / datetime objects, stored as midnight datetimes
        :param prefix: optional prefix of the result counter columns ('<prefix>_<counter>')
        """

        assert len(counters) > 0, 'Please specify at least one counter'
        for counter in counters:
            assert counter in STAT_FIELDS, 'counter %s not in usage_stats fields' % counter
        self.counters = counters
        self.output_frame = output_frame
        self.dates = self.check_dates_format(dates) if dates is not None else None

        self.prefix = prefix
        super(OnlineUsageStatsData, self).__init__(
            input_data=[ids.FRAME_KEY_TARGET],
            output_data=[output_frame],
        )

    @staticmethod
    def check_dates_format(dates):
        """
        Ensure that the dates list is stored as datetime objects since we need it for mongo queries.

        :param dates: iterable of datetime.date or datetime.datetime objects
        :return: list of datetime.datetime objects; datetimes are truncated to midnight
        :raises ValueError: when an element is neither a date nor a datetime
        """

        result = []
        for date in dates:
            # datetime must be tested first: it is a subclass of date
            if isinstance(date, datetime.datetime):
                result.append(datetime.datetime(date.year, date.month, date.day))
            elif isinstance(date, datetime.date):
                result.append(date_to_datetime(date))
            else:
                raise ValueError('Dates must be in date or datetime format')
        return result

    def apply(self, context, train):
        """
        Query usage counters for the valid users of the target frame.

        :param context: pipeline context; frames are stored in context.data
        :param train: not used by this block
        :return: the same context object with the output frame filled
        """
        target_frame = context.data[ids.FRAME_KEY_TARGET]
        users = list(np.unique(target_frame['user']))
        users = validate_users(users)

        @request_cache.memoize()
        def get_users_usage_stats_from_clickhouse(users, dates, counters, prefix):
            # NOTE(review): `dates` is accepted here but is not referenced in
            # the query below; confirm whether a date filter is intended.
            logger.debug('Loading usage stats data from clickhouse')

            counters_query = ', '.join("sum({0}) as {0}".format(counter) for counter in counters)
            query = """
                SELECT
                    item,
                    user,
                    {counters}
                FROM {db}.{usage_counters}
                WHERE user IN %(users)s
                GROUP BY user, item
            """.format(db=app.config['CLICKHOUSE_DATABASE'],
                       counters=counters_query,
                       usage_counters=app.config['CLICKHOUSE_USAGE_COUNTERS_TABLE'])

            usage_stats = clickhouse.execute(query, {'users': tuple(users)})

            result = []
            for row in usage_stats:
                # skip rows with all counters = 0 (counters start at index 2)
                if not any(row[2:]):
                    continue
                # user column UUID -> str
                row = list(row)
                row[1] = str(row[1])
                result.append(tuple(row))

            # builtin `object` instead of the `np.object` alias, which was
            # deprecated and later removed from NumPy
            dtype = [('item', object), ('user', object)]
            for counter in counters:
                column = counter if prefix is None else '{}_{}'.format(prefix, counter)
                dtype.append((column, np.int32))
            return DataFrame.from_structarray(np.array(result, dtype=dtype))

        usage_stats = get_users_usage_stats_from_clickhouse(users, self.dates, self.counters, self.prefix)
        logger.debug('Found %d stat records for %d users', len(usage_stats), len(users))
        context.data[self.output_frame] = usage_stats
        return context
