import _jsonnet
import datetime
import json
import random
import string
from collections import namedtuple

import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder

from jafar.utils.structarrays import DataFrame

# Pair of lookup dictionaries built by `get_mapping`:
# `user_map` maps each distinct user value to a 0-based index,
# `item_map` does the same for items.
Mapping = namedtuple('Mapping', ['user_map', 'item_map'])

# Characters accepted by `is_hex`; `string.hexdigits` contains the decimal digits
# plus both the lowercase and the uppercase letters a-f.
hexdigits = frozenset(string.hexdigits)


def get_random_string(n=16):
    """
    Build a random string of ASCII letters and digits.

    The `random` module is not cryptographically secure, so the result is only
    suitable for non-sensitive identifiers (temporary names and the like).

    :param n: length of the string to generate
    :return: string of exactly `n` characters drawn from [a-zA-Z0-9]
    """
    # `string.ascii_letters` is locale independent and exists on both Python 2
    # and Python 3, whereas `string.letters` is Python 2 only and locale dependent.
    alphabet = string.ascii_letters + string.digits
    return ''.join(random.choice(alphabet) for _ in range(n))


def get_shape(data):
    """
    Count the distinct users and the distinct items in `data`.

    :param data: object exposing `user` and `item` attributes that support `.unique()`
    :return: tuple (number of distinct users, number of distinct items)
    """
    distinct_users = data.user.unique()
    distinct_items = data.item.unique()
    return len(distinct_users), len(distinct_items)


def get_mapping(data):
    """
    Assign consecutive 0-based indices to the distinct users and items of `data`.

    Indices follow the sorted order of the distinct values.

    :param data: object exposing `user` and `item` attributes that support `.unique()`
    :return: `Mapping` whose `user_map` / `item_map` are value -> index dictionaries
    """
    ordered_users = sorted(data.user.unique())
    ordered_items = sorted(data.item.unique())
    user_map = {user: position for position, user in enumerate(ordered_users)}
    item_map = {item: position for position, item in enumerate(ordered_items)}
    return Mapping(user_map=user_map, item_map=item_map)


def bincount_relative(arr, length):
    """
    Compute the relative frequency of every value in [0, length) within `arr`.

    :param arr: sequence of non-negative integers, each strictly below `length`
    :param length: size of the returned array
    :return: float32 array of `length` entries that sum to 1
    """
    assert np.max(arr) < length, 'Required length should be greater than max(arr)+1'
    histogram = np.bincount(arr, minlength=length)
    frequencies = histogram.astype(np.float32)
    # normalise the raw counts so that they add up to one
    frequencies /= frequencies.sum()
    return frequencies


def get_index_pairs(arr):
    """
    Locate the runs of equal consecutive values in a numeric array.

    :param arr: 1-d numeric array
    :return: 2-column array, one row `[start, end)` per run of identical neighbouring values
    """
    # a non-zero difference marks the first element of a new run; the forced
    # leading and trailing ones add the boundaries 0 and len(arr)
    steps = np.ediff1d(arr, to_begin=1, to_end=1)
    boundaries = np.nonzero(steps)[0]
    run_starts = boundaries[:-1]
    run_ends = boundaries[1:]
    return np.column_stack((run_starts, run_ends))


def get_index_pairs_for_strings(arr):
    """
    Same as `get_index_pairs`, but for arrays of labels (e.g. strings).

    The labels are first encoded as integers with `LabelEncoder` so that
    differences between neighbouring elements can be computed.

    :param arr: 1-d array of labels
    :return: 2-column array of `[start, end)` rows, as produced by `get_index_pairs`
    """
    encoder = LabelEncoder()
    codes = encoder.fit_transform(arr)
    return get_index_pairs(codes)


def get_index_pairs_with_keys(array):
    """
    Pair every distinct value of `array` with the `[start, end)` slice it occupies.

    Boundaries are derived from the first occurrence of each value, so the
    result is only meaningful when equal values are stored contiguously.

    :param array: 1-d numpy array
    :return: iterable of (value, [start, end]) in order of first appearance
    """
    keys, first_seen = np.unique(array, return_index=True)
    # np.unique returns values sorted; restore the order of first appearance
    appearance_order = np.argsort(first_seen)
    ordered_keys = keys[appearance_order]
    # the array length closes the last slice
    boundaries = np.concatenate((first_seen[appearance_order], [array.shape[0]]))
    slices = np.column_stack((boundaries[:-1], boundaries[1:]))
    return zip(ordered_keys, slices)


def check_frame(frame, required_columns):
    """
    Verify that `frame` contains every column listed in `required_columns`.

    :param frame: object exposing an iterable `columns` attribute
    :param required_columns: iterable of column names that must be present
    :raises ValueError: when at least one required column is absent
    """
    required = set(required_columns)
    present = set(frame.columns)
    missing_columns = required - present
    if not missing_columns:
        return
    raise ValueError("Dataframe is missing the following required columns: {}".format(missing_columns))


def sort_by_users(X, return_index=False):
    """
    Order `X` by its 'user' field and locate each user's contiguous chunk.

    :param X: pandas DataFrame, or any other container that supports X['user']
        and fancy indexing with an index array
    :param return_index: when true, also return the permutation used for sorting
    :return: (sorted X, index pairs) or (sorted X, index pairs, permutation)
    """
    # pandas needs positional indexing through `.iloc`; everything else is
    # indexed directly with the permutation
    if isinstance(X, pd.DataFrame):
        order = X['user'].argsort()
        ordered = X.iloc[order]
    else:
        order = np.argsort(X['user'])
        ordered = X[order]

    # one [start, end) row per run of identical users
    user_chunks = get_index_pairs(ordered['user'])

    if return_index:
        return ordered, user_chunks, order
    return ordered, user_chunks


def add_implicit_negatives(X, n_items, factor=1, max_attempts=100):
    """
    Extend `X` with sampled (user, item) pairs that do not occur in `X`, marked with value 0.

    For every row of `X`, `factor` negative items are drawn for that row's user.
    Items are drawn with probability proportional to how often they occur in `X`,
    so items that never occur in `X` are never sampled.

    :param X: pandas DataFrame or another container with 'user', 'item' and 'value' fields
        (presumably a numpy structured array for the non-pandas branch - confirm against callers)
    :param n_items: total number of items; item ids are expected to lie in [0, n_items)
    :param factor: number of negatives sampled per row of `X`
    :param max_attempts: number of sampling rounds without any progress that are tolerated
    :return: `X` sorted by user, followed by the sampled negatives
    :raises ValueError: if some negatives still could not be sampled when the attempts ran out
    """
    X_sorted, index_pairs = sort_by_users(X)
    n_samples = X_sorted.shape[0] * factor

    # make a user-by-items dictionary
    # `values` has one row per interaction: column 0 is the user, column 1 the item
    values = np.vstack([X_sorted['user'], X_sorted['item']]).T
    data_dict = {}
    for i, j in index_pairs:
        # rows i..j-1 belong to the same user, because X_sorted is sorted by user
        data_dict[values[i, 0]] = set(values[i: j, 1])

    sampled_users = np.repeat(values[:, 0], factor)
    # despite the name, this holds positions into `sampled_users` that still need
    # a negative item, not user ids
    left_out_users = np.arange(n_samples)
    # NOTE(review): np.zeros defaults to float64, so the pandas branch below returns
    # float item ids for the negatives, while the other branch casts them to int32
    sampled_negative_items = np.zeros(n_samples)

    # weight items according to popularity
    weights = bincount_relative(np.int32(values[:, 1]), length=n_items)

    # sample negative items by sampling all items at once
    # and then retrying for items that are in `data_dict`
    attempts = 0
    while len(left_out_users) > 0 and attempts < max_attempts:
        failed = []
        neg_items = np.random.choice(n_items, p=weights, size=len(left_out_users))
        for item_idx, user_idx in enumerate(left_out_users):
            user = sampled_users[user_idx]
            item = neg_items[item_idx]
            if item in data_dict[user]:
                # the user has already interacted with this item: retry in the next round
                failed.append(user_idx)
            else:
                sampled_negative_items[user_idx] = item
        # only rounds in which nothing could be resolved count as an attempt
        if len(left_out_users) == len(failed):
            attempts += 1
        left_out_users = failed

    if len(left_out_users) != 0:
        raise ValueError(
            "Couldn't sample enough implicit negatives: {} users left out".format(
                len(left_out_users)
            )
        )

    # uncomment to check that sampling works correctly
    # check = dict(zip(zip(X_sorted[:, 0], X_sorted[:, 1]), X_sorted[:, 2]))
    # for i in negatives:
    #     if (i[0], i[1]) in check:
    #         print 'wrong sample', (i[0], i[1]), check[(i[0], i[1])]
    #         raise Exception('wrong sample')

    # the result keeps the container type of the input
    if isinstance(X, pd.DataFrame):
        negatives = pd.DataFrame(dict(
            user=sampled_users,
            item=sampled_negative_items,
            value=np.zeros(n_samples),
        ))
        return pd.concat([X_sorted, negatives])
    else:
        negatives = DataFrame.from_dict(dict(
            user=sampled_users,
            item=np.array(sampled_negative_items, dtype=np.int32),
            value=np.zeros(n_samples, dtype=np.float32),
        ))

        return DataFrame.concatenate([
            DataFrame.from_structarray(X_sorted[['user', 'item', 'value']]),
            negatives[['user', 'item', 'value']]
        ])


def get_all_subclasses(cls):
    """
    Lazily yield every direct and indirect subclass of `cls`.

    Classes come out depth-first: a subclass is immediately followed by its own
    descendants. A class reachable through several parents is yielded once per path.

    :param cls: class whose descendants are requested
    :return: generator of classes
    """
    # explicit stack; children are pushed reversed so they pop in declaration order
    pending = list(reversed(cls.__subclasses__()))
    while pending:
        subclass = pending.pop()
        yield subclass
        pending.extend(reversed(subclass.__subclasses__()))


def date_to_datetime(date, max_time=False):
    """
    Turn a date into a datetime on the same day.

    :param date: `datetime.date` to convert
    :param max_time: when true use the last representable moment of the day
        (23:59:59.999999), otherwise midnight
    :return: naive `datetime.datetime`
    """
    time_of_day = datetime.time.max if max_time else datetime.time.min
    return datetime.datetime.combine(date, time_of_day)


def date_range(start_date, end_date):
    """
    Yield consecutive days from `start_date` (inclusive) up to `end_date` (exclusive).

    The number of yielded values is the whole number of days between the two
    arguments; nothing is yielded when `end_date` is not after `start_date`.

    :param start_date: first date (or datetime) to yield
    :param end_date: date (or datetime) at which to stop, not yielded itself
    :return: generator of values of the same type as `start_date`
    """
    # `range` instead of `xrange`: the latter does not exist on Python 3
    for n_days in range((end_date - start_date).days):
        yield start_date + datetime.timedelta(days=n_days)


def memory_usage():
    """Memory usage of the current process in megabytes. https://stackoverflow.com/a/898406 """
    usage = {'peak': 0, 'rss': 0}
    # This will only work on systems with a /proc file system
    # (like Linux).
    with open('/proc/self/status') as status_file:
        for raw_line in status_file:
            fields = raw_line.split()
            if fields:
                # "VmPeak:" -> "peak", "VmRSS:" -> "rss"
                name = fields[0][2:-1].lower()
                if name in usage:
                    # the kernel reports kB
                    usage[name] = int(fields[1]) / 1024
    return usage


def jsonp_to_dict(jsonp):
    """
    Parse the object wrapped inside a JSONP response.

    The first two and the last two characters of `jsonp` are treated as padding
    and dropped; the remainder is evaluated with jsonnet and the resulting JSON
    text is decoded.

    :param jsonp: JSONP text
    :return: decoded python object
    """
    payload = jsonp[2:-2]  # Stripping JSONP padding
    json_text = _jsonnet.evaluate_snippet('snippet', payload.encode('utf-8'))
    return json.loads(json_text)


def fields_extractor(*fields):
    """
    Build a function that projects a mapping onto the given keys.

    :param fields: keys to keep
    :return: callable taking a document and returning a dict with exactly `fields`
        as keys; a key missing from the document propagates the lookup error
    """
    def extractor(doc):
        picked = {}
        for field in fields:
            picked[field] = doc[field]
        return picked

    return extractor


def is_hex(s):
    """
    Tell whether `s` consists solely of hexadecimal digits (either case).

    :param s: iterable of characters, normally a string
    :return: True when every symbol is a hex digit; an empty `s` gives True
    """
    for symbol in s:
        if symbol not in hexdigits:
            return False
    return True


def safe_get(data_container, *keys):
    """
    Safely access nested fields of objects like mongoengine.Document, mongoengine.EmbeddedDocument, etc.

    The basic idea is that there are some nested objects:
    main_container {
      obj1: {
        obj11: "value11",
        obj12: "value12",
        ...
      },
      obj2: {
        obj21: "value21",
        obj22: "value22",
        ...
      },
      ...
    }

    This function allows to safely extract, for example, the value of obj12 with
    `safe_get(main_container, 'obj1', 'obj12')`.

    Every key is first looked up with `container[key]`; if that yields None, the
    attribute of the same name is tried instead. If None is met on some key, or
    the lookup raises any regular exception, None is returned.

    :param data_container: outermost object to start from
    :param keys: sequence of keys leading to the requested value
    :return: the value obtained from the sequence of keys if it exists, otherwise None
    """
    try:
        result = data_container
        for key in keys:
            value = result[key]
            if value is None:
                # item access gave nothing, fall back to attribute access
                value = getattr(result, key, None)
                if value is None:
                    return None
            result = value
        return result
    except Exception:
        # `Exception` rather than a bare `except`, so that KeyboardInterrupt and
        # SystemExit are not silently swallowed
        return None


def to_ascii(unicode_str):
    """
    Encode a unicode string as ascii, silently dropping everything else

    Characters outside of the ascii range are ignored (lost), the remaining
    ones keep their relative order.

    A string made of ascii characters only:
      <ASCII CHAR 1><ASCII CHAR 2><ASCII CHAR 3>
    comes back with the same content:
      <ASCII CHAR 1><ASCII CHAR 2><ASCII CHAR 3>

    A string that contains non-ascii characters:
      <ASCII CHAR 1><ASCII CHAR 2><NON - ASCII CHAR 3>
    comes back with the ascii characters only:
      <ASCII CHAR 1><ASCII CHAR 2>

    :param unicode_str: unicode string or None
    :return: result of encoding `unicode_str` to ascii, or None when None was given
    """
    return None if unicode_str is None else unicode_str.encode('ascii', 'ignore')
