# -*- coding: utf-8 -*-
import numpy as np
from scipy.stats import spearmanr
from iv import get_iv


def full_split(X, y, n_splits=20):
    """Score each feature by information value, then group features into splits.

    :param X: 2-D array whose columns are the candidate features.
    :param y: target passed to ``get_iv`` alongside each column.
    :param n_splits: number of splits handed to ``make_splits``.
    :return: tuple ``(splits, IV)``. ``IV[j]`` is ``get_iv(y, X[:, j])``, and
        ``splits`` is ``make_splits(X, IV, n_splits)``.
    """
    n_features = X.shape[1]
    iv_scores = np.zeros(n_features)
    for col in range(n_features):
        iv_scores[col] = get_iv(y, X[:, col])
    splits = make_splits(X, iv_scores, n_splits)
    return splits, iv_scores


def full_split_by_baseline(X, y, baseline=0.7, n_splits=20):
    """Score each feature by information value, then split by correlation baseline.

    :param X: 2-D array whose columns are the candidate features.
    :param y: target passed to ``get_iv`` alongside each column.
    :param baseline: absolute Spearman correlation threshold. A feature joins the
        current anchor's split when its correlation exceeds this value.
    :param n_splits: forwarded to ``make_splits``. In baseline mode it only sets
        the stopping threshold (splitting stops once fewer than
        ``2 * (n_features // n_splits)`` features remain).
    :return: tuple ``(splits, IV)``, where ``IV[j]`` is ``get_iv(y, X[:, j])``.
    """
    IV = np.zeros(X.shape[1])
    for i in range(len(IV)):
        IV[i] = get_iv(y, X[:, i])
    # Forward the caller's baseline; a hard-coded 0.7 used to be passed here,
    # silently ignoring the argument.
    splits = make_splits(X, IV, n_splits=n_splits, by_baseline=True,
                         baseline=baseline)
    return splits, IV


def make_a_split(IV, X, ids, step):
    """Group the highest-IV remaining feature with its most correlated peers.

    The feature in ``ids`` with the largest information value is the anchor.
    Every feature in ``ids`` is ranked by the absolute Spearman correlation
    between its column and the anchor's column. The ``step`` top-ranked
    features form the split.

    :param IV: per-feature information values, indexed by column of ``X``.
    :param X: 2-D array whose columns are the features.
    :param ids: column indices still available for splitting.
    :param step: number of features to put in the split.
    :return: list of column indices taken from ``ids``, anchor first, then by
        decreasing absolute correlation with the anchor.
    :raises ValueError: if ``ids`` is empty (from ``np.argmax``).
    """
    ids = list(ids)
    cur_max_index = int(np.argmax(np.asarray(IV)[ids]))
    anchor = ids[cur_max_index]
    # Report the real column index, not the position inside ``ids``.
    print('Current max IV feature: {}'.format(anchor))
    corr_s = np.zeros(len(ids))
    for i, feature in enumerate(ids):
        corr_s[i] = np.abs(spearmanr(X[:, anchor], X[:, feature])[0])
    # spearmanr yields nan for constant columns. argsort puts nan last, which
    # after reversal would rank them as the *most* correlated; treat them as
    # uncorrelated instead.
    corr_s[np.isnan(corr_s)] = 0.0
    # The anchor always heads its own split, even if its column is constant or
    # ties with exact duplicates at |rho| == 1.
    corr_s[cur_max_index] = np.inf
    order = np.argsort(corr_s)[::-1]
    if len(order) > 1:
        # order[0] is the anchor itself; order[1] is its closest neighbour.
        print('The closest feature: {}'.format(ids[order[1]]))
    return [ids[j] for j in order[:step]]


def make_splits(X, IV, n_splits, by_baseline=False, baseline=0.7):
    """Partition the columns of ``X`` into groups of mutually correlated features.

    Splits are peeled off one at a time, each anchored on the highest-IV
    feature still unassigned, until fewer than ``2 * step`` features remain.
    Those leftovers form the final split.

    :param X: 2-D array whose columns are the features.
    :param IV: per-feature information values, indexed by column of ``X``.
    :param n_splits: target number of splits. ``step = n_features // n_splits``
        (at least 1) is the split size in top-k mode and the stopping threshold
        in both modes.
    :param by_baseline: if true, build each split with
        ``make_a_split_by_baseline``; otherwise take the ``step`` features most
        correlated with the anchor via ``make_a_split``.
    :param baseline: correlation threshold used when ``by_baseline`` is true.
    :return: list of splits, each a list of column indices. The last split holds
        the remaining features.
    :raises ValueError: if ``n_splits`` is less than 1.
    """
    if n_splits < 1:
        raise ValueError('n_splits must be at least 1, got {}'.format(n_splits))
    n_features = X.shape[1]
    # int(n / k) is 0 when there are fewer features than splits. The loop
    # condition ``len(ids) >= 0`` was then always true and never terminated.
    step = max(1, int(n_features / float(n_splits)))
    ids = list(range(n_features))
    splits = []
    while len(ids) >= 2 * step:
        print('Current length of ids: {}'.format(len(ids)))
        if by_baseline:
            step_ids = make_a_split_by_baseline(IV, X, ids, baseline)
        else:
            step_ids = make_a_split(IV, X, ids, step)
        taken = set(step_ids)
        remaining = [i for i in ids if i not in taken]
        if len(remaining) == len(ids):
            # Nothing was assigned; stop instead of appending empty splits
            # forever. The leftovers become the final split below.
            break
        splits.append(step_ids)
        ids = remaining
    splits.append(ids)
    return splits


def make_a_split_by_baseline(IV, X, ids, baseline):
    """Group the highest-IV remaining feature with every feature above a baseline.

    The feature in ``ids`` with the largest information value is the anchor.
    A feature joins the split when its absolute Spearman correlation with the
    anchor exceeds ``baseline``. The anchor itself is always included, so each
    call assigns at least one feature.

    :param IV: per-feature information values, indexed by column of ``X``.
    :param X: 2-D array whose columns are the features.
    :param ids: column indices still available for splitting.
    :param baseline: absolute correlation threshold (strict ``>``).
    :return: list of column indices taken from ``ids``, in ``ids`` order.
    :raises ValueError: if ``ids`` is empty (from ``np.argmax``).
    """
    ids = list(ids)
    cur_max_index = int(np.argmax(np.asarray(IV)[ids]))
    anchor = ids[cur_max_index]
    # Report the real column index, not the position inside ``ids``.
    print('Current max IV feature: {}'.format(anchor))
    corr_s = np.zeros(len(ids))
    for i, feature in enumerate(ids):
        corr_s[i] = np.abs(spearmanr(X[:, anchor], X[:, feature])[0])
    # spearmanr yields nan for constant columns; treat them as uncorrelated so
    # they neither pass the baseline nor rank as "closest" after argsort.
    corr_s[np.isnan(corr_s)] = 0.0
    corr_s[cur_max_index] = np.inf
    order = np.argsort(corr_s)[::-1]
    if len(order) > 1:
        # order[0] is the anchor itself; order[1] is its closest neighbour.
        print('The closest feature: {}'.format(ids[order[1]]))
    # Always keep the anchor. Otherwise a constant anchor or baseline >= 1 can
    # produce an empty split, and make_splits would never make progress.
    return [feature for i, feature in enumerate(ids)
            if i == cur_max_index or corr_s[i] > baseline]
