import json
import random
from collections import defaultdict
from datetime import datetime, timedelta
from uuid import UUID

from nile.api.v1 import Record
from nile.api.v1 import with_hints, extended_schema, multischema
from qb2.api.v1 import typing as qt

from jafar_yt.utils.helpers import mapper_wrapper


def app_launches_mapper(records):
    """
    Extract app launch events from Metrika log records.

    Rows are skipped when ``DeviceID`` is missing or not a valid UUID, or when
    ``EventValue`` is missing, not valid JSON, or not a JSON object.

    Yields one Record per kept row with:
      user      -- DeviceID normalized through ``uuid.UUID``
      timestamp -- int(StartTimestamp)
      date      -- StartDate, unchanged
      item      -- the ``packageName`` of EventValue (None when absent)
      weekday   -- weekday (Monday == 0) of StartTime shifted by StartTimeZone seconds
      hour      -- hour of that same shifted time
    """
    for record in records:
        try:
            user = str(UUID(record['DeviceID']))
            event_value = json.loads(record['EventValue'])
        except (TypeError, ValueError):
            # TypeError: null column (UUID(None) / json.loads(None));
            # ValueError: malformed UUID string or malformed JSON.
            continue
        if not isinstance(event_value, dict):
            # e.g. EventValue == 'null' or a JSON array: no packageName to read.
            continue

        # StartTimeZone is applied as a seconds offset; presumably the UTC
        # offset of the device, turning StartTime into local time -- confirm.
        local_time = datetime.strptime(record['StartTime'], '%Y-%m-%d %H:%M:%S')
        local_time += timedelta(seconds=int(record['StartTimeZone'] or 0))
        yield Record(user=user,
                     timestamp=int(record['StartTimestamp']),
                     date=record['StartDate'],
                     item=event_value.get('packageName'),
                     weekday=local_time.weekday(),
                     hour=local_time.hour)


class TargetMapper(object):
    """
    Expand each launch record into one positive and several negative examples.

    For every input record, the launched ``item`` yields a row with value=1 and
    each other app in the record's ``basket`` yields a row with value=0. When
    ``max_negative_count`` is a non-negative number smaller than the number of
    other apps, a random subset of that size is used as negatives instead.
    A negative cap (or the default infinity) disables sampling.

    Assumes ``record['basket']`` maps item -> per-item stats dict holding
    ``install_time`` and optionally ``personal``/``weekly``/``hourly`` -- confirm
    against the producer.
    """
    def __init__(self, max_negative_count=float('inf')):
        self.max_negative_count = max_negative_count

    @staticmethod
    def get_record(basket, item, value, tmp, record):
        """Build one training row for ``item`` with target ``value``.

        ``tmp`` carries the common fields (user/item/timestamp); ``record``
        provides the weekday/hour used to look up the basket's
        weekly/hourly stats (missing stats default to 0).
        """
        return Record(value=value,
                      personal=basket[item].get('personal') or 0,
                      weekly=(basket[item].get('weekly') or {}).get(str(record['weekday']), 0),
                      hourly=(basket[item].get('hourly') or {}).get(str(record['hour']), 0),
                      time_from_install=record['timestamp'] - basket[item]['install_time'],
                      **tmp)

    def __call__(self, records):
        for record in records:
            # A null/missing basket is treated as empty, so the record is skipped.
            basket = record.get('basket') or {}
            item = record['item']

            if item not in basket:  # skip items that are not in basket
                continue

            # Every other installed app is a negative candidate (launched=0).
            # Build a real list: dict.keys() is a view without remove() on Python 3.
            negatives = [other for other in basket if other != item]

            # TODO: maybe we should choose items with launches>0
            if 0 <= self.max_negative_count < len(negatives):
                # int(): random.sample rejects float sizes on Python 3.
                negatives = random.sample(negatives, int(self.max_negative_count))

            common = {i: record[i] for i in ['user', 'item', 'timestamp']}

            yield self.get_record(basket, item, 1, common, record)
            for negative in negatives:
                # Record(**tmp) copies the kwargs, so reusing `common` is safe.
                common['item'] = negative
                yield self.get_record(basket, negative, 0, common, record)


@mapper_wrapper
class SampleNegativesReducer(object):
    """
    Per group, emit the positive record followed by (optionally sampled) negatives.

    A record is positive when its ``value`` is truthy; if a group contains
    several positives, only the last one seen is kept. Groups without a
    positive are dropped entirely. When ``max_negative_count`` is a
    non-negative number below the number of negatives, a random subset of
    that size is emitted; otherwise all negatives are emitted.
    """
    schema = extended_schema()

    def __init__(self, max_negative_count=float('inf')):
        self.max_negative_count = max_negative_count

    def __call__(self, groups):
        for key, records in groups:
            positive = None
            pool = []
            for record in records:
                if not record['value']:
                    pool.append(record)
                    continue
                positive = record

            if not positive:
                continue

            limit = self.max_negative_count
            if 0 <= limit < len(pool):
                pool = random.sample(pool, limit)

            yield positive
            for record in pool:
                yield record


def basket_reducer(groups):
    """
    Collapse each group into a single record carrying a ``basket`` list.

    The basket holds one ``to_dict()`` conversion per row of the group,
    restricted to the ``item``, ``install_time`` and ``app_preinstalled``
    fields (via the Nile ``fields`` projection -- presumably column selection).
    """
    for key, rows in groups:
        basket = []
        for row in rows.fields('item', 'install_time', 'app_preinstalled'):
            basket.append(row.to_dict())
        yield Record(key, basket=basket)


def recent_reducer(groups):
    """
    Annotate every record with ``recent``.

    ``recent`` is the number of launches of other items seen since the
    previous launch of the same item within the group, or -1 when the item
    has not been launched earlier in the group. Records are processed in the
    order the group yields them; a record counts as a launch when
    ``value > 0``, and it only affects the records after it.
    """
    for key, records in groups:
        last_launch_index = {}  # item -> ordinal of its latest launch
        launch_counter = 0
        for r in records:
            item = r['item']
            previous = last_launch_index.get(item)
            if previous is None:
                recent = -1
            else:
                recent = launch_counter - previous - 1
            yield Record(r, recent=recent)

            if r['value'] > 0:
                last_launch_index[item] = launch_counter
                launch_counter += 1


def _get_launcher(basket):
    for item in basket:
        if item['item'] == 'com.yandex.launcher':
            return item


def launcher_preinstall(basket):
    """
    Return the launcher's ``app_preinstalled`` flag.

    Returns the (falsy) lookup result, i.e. None, when the basket has no
    launcher entry.
    """
    launcher = _get_launcher(basket)
    if not launcher:
        return launcher
    return launcher['app_preinstalled']


def launcher_time_elapsed(basket, timestamp):
    """
    Return ``timestamp`` minus the launcher's ``install_time``.

    Special cases:
      * no launcher in the basket -> the falsy lookup result (None);
      * launcher was preinstalled  -> 0.
    """
    launcher = _get_launcher(basket)
    if not launcher:
        return launcher
    if launcher['app_preinstalled']:
        return 0
    return timestamp - launcher['install_time']


def retrieve_app_features(basket):
    """
    Rank apps by install time, newest first, storing the rank IN PLACE.

    Each app dict gets ``install_order_reversed`` (0 = most recently
    installed; ties keep their basket order). The same ``basket`` list,
    in its original order, is returned.
    """
    def _newest_first(app):
        return -app['install_time']

    ranked = sorted(basket, key=_newest_first)
    for rank, app in enumerate(ranked):
        app['install_order_reversed'] = rank

    return basket


def install_time_elapsed(item, basket, timestamp):
    """
    Return ``timestamp`` minus the ``install_time`` of ``item`` in ``basket``.

    The first basket entry whose ``item`` matches is used; the scan stops
    there instead of building the full list of matches.

    Raises:
        IndexError: if ``item`` is not in the basket (same exception type as
        before, now with a message naming the missing item).
    """
    for app in basket:
        if app['item'] == item:
            return timestamp - app['install_time']
    raise IndexError('item %r not found in basket' % (item,))


class DictReducer(object):
    """
    Fold one column of each group into a dict keyed by another column.

    For every group, ``str(record[context])`` becomes a key and
    ``record[resulting_feature]`` its value. The dict is emitted under the
    ``resulting_feature`` name together with the group key. When several
    records share a context value, the last one wins.

    ({'user':1, 'item': 2},
    [{'weekly': 100, 'weekday': 2}, {'weekly': 10, 'weekday': 3}]) => {..., weekly: {'2': 100, '3': 10}}
    """
    def __init__(self, context, resulting_feature):
        self.context = context
        self.resulting_feature = resulting_feature

    def __call__(self, groups):
        context_field = self.context
        feature_field = self.resulting_feature
        for key, records in groups:
            by_context = {}
            for r in records:
                # Stringify keys so any context value (even an unhashable one)
                # can serve as a dict key.
                by_context[str(r[context_field])] = r[feature_field]
            yield Record(key, **{feature_field: by_context})


@with_hints(output_schema=extended_schema(personal=qt.Integer,
                                          hourly=qt.Integer,
                                          weekly=qt.Integer,
                                          launch_time_elapsed=qt.Integer,
                                          recent=qt.Integer,
                                          from_launch_history_start_elapsed=qt.Integer,
                                          time_from_install=qt.Integer))
def features_reducer(groups):
    """
    Attach history-based features to every record of each group.

    Records are processed in the order the group yields them (presumably
    timestamp-sorted upstream -- confirm). A record counts as a launch when
    ``value == 1``. The launch counters only include launches seen *before*
    the current record, so a record never contributes to its own features.

    Added fields:
      personal        -- earlier launches of this item
      hourly / weekly -- earlier launches of this item in the same hour / weekday
      launch_time_elapsed -- timestamp delta since this item's last launch, -1 if none
      recent          -- launches of other items since this item's last launch, -1 if none
      from_launch_history_start_elapsed -- timestamp delta since the group's first record
      time_from_install -- timestamp minus the record's install_time
    """
    for key, records in groups:
        history_start = None
        launch_total = 0
        launches_by_item = defaultdict(list)  # item -> [(launch ordinal, timestamp), ...]
        per_hour = defaultdict(lambda: defaultdict(int))  # item -> {"hour": launches}
        per_weekday = defaultdict(lambda: defaultdict(int))  # item -> {"weekday": launches}

        for record in records:
            if history_start is None:
                history_start = record['timestamp']

            item = record['item']
            hour_key = str(record['hour'])
            weekday_key = str(record['weekday'])
            past_launches = launches_by_item[item]

            if past_launches:
                last_ordinal, last_timestamp = past_launches[-1]
                since_last = record['timestamp'] - last_timestamp
                others_since_last = launch_total - last_ordinal - 1
            else:
                since_last = -1
                others_since_last = -1

            yield Record(record,
                         personal=len(past_launches),
                         hourly=per_hour[item][hour_key],
                         weekly=per_weekday[item][weekday_key],
                         launch_time_elapsed=since_last,
                         recent=others_since_last,
                         from_launch_history_start_elapsed=record['timestamp'] - history_start,
                         time_from_install=record['timestamp'] - record['install_time'])

            # Update the history only after emitting, so the next records see this launch.
            if record['value'] == 1:
                past_launches.append((launch_total, record['timestamp']))
                launch_total += 1
                per_hour[item][hour_key] += 1
                per_weekday[item][weekday_key] += 1


def _normalize_counts(counts, scale):
    """Divide each count by ``scale``, round half up, and drop entries that round to 0."""
    normalized = {}
    for k, count in counts.items():
        value = int(count / scale + 0.5)
        if value:
            normalized[k] = value
    return normalized


@mapper_wrapper
class GeneralReducer(object):
    """
    Aggregate launch statistics per group, normalized per 10 distinct users.

    For every group it emits the group key plus:
      hourly   -- {"<hour>": launches per 10 users}, zero entries dropped
      weekly   -- {"<weekday>": launches per 10 users}, zero entries dropped
      personal -- sum of the normalized hourly values
      users    -- number of distinct users in the group
    ``keys`` are the grouping columns; they are declared as strings in
    ``self.schema`` (presumably consumed by ``mapper_wrapper`` -- confirm).
    """
    def __init__(self, *keys):
        self.schema = dict(personal=qt.Integer,
                           hourly=qt.Dict[qt.String, qt.Integer],
                           weekly=qt.Dict[qt.String, qt.Integer],
                           users=qt.Integer)
        for key in keys:
            self.schema[key] = qt.String

    def __call__(self, groups):
        for key, records in groups:
            hourly_counts = defaultdict(int)  # {"0": 1, "12": 2}: raw launches per hour of day
            weekly_counts = defaultdict(int)  # raw launches per weekday
            users = set()
            for record in records:
                hourly_counts[str(record['hour'])] += 1
                weekly_counts[str(record['weekday'])] += 1
                users.add(record['user'])

            # Express counts per 10 distinct users. No division happens when the
            # group is empty, because the count dicts are empty too.
            scale = len(users) * 0.1
            # Build fresh dicts: popping while iterating and re-reading the
            # defaultdict used to re-insert zero-valued hourly keys.
            hourly = _normalize_counts(hourly_counts, scale)
            weekly = _normalize_counts(weekly_counts, scale)

            yield Record(key,
                         personal=sum(hourly.values()),
                         hourly=hourly,
                         weekly=weekly,
                         users=len(users))


@mapper_wrapper
class SplitMapper(object):
    """
    Deterministically route records to ``train`` or ``test`` by user UUID.

    The user's UUID is mapped to [0, 1] by dividing its 128-bit integer value
    by the maximum UUID value. Records below ``ratio`` go to ``train``, the
    rest to ``test``, so all records of one user land in the same output.
    """
    def __init__(self, ratio):
        self.ratio = ratio
        self.schema = multischema(extended_schema(), extended_schema())

    def __call__(self, records, train, test):
        max_uuid = UUID('f' * 32).int
        ratio = self.ratio
        for record in records:
            position = float(UUID(record['user']).int) / max_uuid
            sink = train if position < ratio else test
            sink(record)
