import pandas as pd
import numpy as np
import time
import pickle
import warnings
import yt.yson as yson
warnings.filterwarnings('ignore')

from business_models import hahn
from business_models.util.dataframes import convert_dtypes_to_yql
from concurrent.futures import ProcessPoolExecutor

def _special_case(group):
    specific = sum(group[group > 0])
    total = sum(np.abs(group))
    return specific/total


def foo(function):
    """Map a textual function name to the callable used by ``bootstrap``.

    The names are the exact strings stored in the job parameters built by
    ``foo_2``; they are passed as strings, presumably to keep the job
    arguments simple to pickle for the process pool.

    Aggregations use ``np.sum`` rather than the builtin ``sum``. The samples
    are int8 arrays, and the builtin ``sum`` keeps NumPy's narrow integer type
    under NEP 50 promotion and silently wraps around at 127. ``np.sum``
    accumulates small integer dtypes in the platform default integer. It is
    also vectorized, which matters inside the bootstrap loop.

    Args:
        function: one of ``'lambda x: sum(x)'``, ``'lambda x: sum(x)/len(x)'``,
            ``'lambda x, y: x-y'``, ``'lambda x, y: x/y'`` or
            ``'special_case'``.

    Returns:
        The matching callable.

    Raises:
        ValueError: if ``function`` is not a known name. Previously an
            unknown name silently returned ``None`` and failed later.
    """
    # Handled separately so the lookup table does not need `_special_case`
    # for the common names.
    if function == 'special_case':
        return _special_case
    functions = {
        'lambda x: sum(x)': lambda x: np.sum(x),
        'lambda x: sum(x)/len(x)': lambda x: np.sum(x) / len(x),
        'lambda x, y: x-y': lambda x, y: x - y,
        'lambda x, y: x/y': lambda x, y: x / y,
    }
    try:
        return functions[function]
    except KeyError:
        raise ValueError(f'Unknown bootstrap function: {function!r}') from None


# Quantiles reported for every bootstrapped comparison, in output order.
_BOOT_QUANTILES = (0.025, 0.05, 0.075, 0.1, 0.5, 0.975, 1.0)


def _bootstrap_skip_code(test, control):
    """Return the sentinel code for inputs that must not be bootstrapped, else None.

    The rules are evaluated in order and the first match wins:
    -3000 small samples, -1000 an empty group, -2000 all zeros in both groups,
    -5000 too few positive values.
    """
    # NOTE(review): the chained comparisons (`a < k > b` means
    # `a < k and k > b`) fire only when BOTH groups fall below the
    # threshold; confirm that "both" rather than "either" is intended.
    if len(test) < min_sample > len(control):
        return -3000
    if not test or not control:
        return -1000
    if not np.any(test) and not np.any(control):
        return -2000
    if np.count_nonzero(np.asarray(test) > 0) < 50 > np.count_nonzero(np.asarray(control) > 0):
        return -5000
    return None


def bootstrap(args):
    """Bootstrap a test-vs-control comparison and summarize it by quantiles.

    Args:
        args: tuple ``(test, control, func, func_2, ratio, n)``, packed into
            one argument so it can be fed through ``executor.map``.
            ``test``/``control`` are per-unit values.
            ``func`` names the per-sample statistic and ``func_2`` names the
            comparison of the two statistic vectors (see ``foo``).
            When ``ratio`` is true, each group is resampled at its own size;
            otherwise both are resampled at the larger group's size.
            ``n`` is the number of bootstrap iterations.

    Returns:
        - A list of 7 identical negative sentinel codes when the input is
          rejected (see ``_bootstrap_skip_code``).
        - Otherwise, a float64 array holding the ``_BOOT_QUANTILES`` of
          ``func_2(stat_test, stat_control)``, rounded to 5 decimals.
        - When ``func_2`` is the ratio, 7 more quantiles of the absolute
          difference ``stat_test - stat_control`` follow.
    """
    test, control, func, func_2, ratio, n = args
    test, control = list(test), list(control)

    skip_code = _bootstrap_skip_code(test, control)
    if skip_code is not None:
        return [skip_code] * len(_BOOT_QUANTILES)

    function = foo(func)
    function_2 = foo(func_2)

    # Convert once instead of letting `choice` re-convert the list on every
    # iteration. The dtype is the same one `choice` would infer.
    test_arr = np.asarray(test)
    control_arr = np.asarray(control)
    if ratio:
        test_size, control_size = len(test_arr), len(control_arr)
    else:
        test_size = control_size = max(len(test_arr), len(control_arr))

    # A local RandomState seeded with 42 yields the same stream as the former
    # global `np.random.seed(42)` without mutating the process-wide RNG.
    rng = np.random.RandomState(42)
    t_stats, c_stats = [], []
    for _ in range(n):
        # Draw test before control each iteration to keep the draw order fixed.
        test_sample = rng.choice(test_arr, size=test_size, replace=True)
        control_sample = rng.choice(control_arr, size=control_size, replace=True)
        t_stats.append(function(test_sample))
        c_stats.append(function(control_sample))
    t_stats = np.array(t_stats)
    c_stats = np.array(c_stats)

    result = function_2(t_stats, c_stats)
    ci = np.quantile(result, _BOOT_QUANTILES).round(5)
    if func_2 == 'lambda x, y: x/y':
        # For ratio comparisons also report the absolute difference. The
        # median now comes from the difference itself; it was previously
        # taken from the ratio by mistake.
        ci_2 = np.quantile(t_stats - c_stats, _BOOT_QUANTILES).round(5)
        return np.concatenate((ci, ci_2)).astype(np.float64)
    # `np.float` was removed in NumPy 1.24; float64 is what it aliased to.
    return ci.astype(np.float64)


# Metric suffixes compared as a share of positive values (`special_case`).
penetration_list = [
    'deliveries_with_postcard',
    'deliveries_with_thermobag',
    'deliveries_with_door_to_door',
    'postcard_pen',
    'thermobag_pen',
    'd2d_pen',
]

# Metric suffixes compared as a ratio of totals (test sum / control sum).
# NOTE(review): the spellings presumably mirror the upstream column names;
# do not "fix" them here without renaming the source columns too.
relation_list = [
    'express_couries_delivereies',
    'cargo_deliveries',
    'is_express_couries',
    'cargo_us',
    'd2d',
    'thermobag',
    'postcard',
    'd2d_us',
    'thermobag_us',
    'postcard_us',
]

# Metric suffixes compared as a difference of means.
ratio_list = ['is_express_couries', 'cargo_us']

# Every metric name, suffixed by its comparison kind, in list order.
tags = [
    f'{name}_{kind}'
    for kind, names in (
        ('penetration', penetration_list),
        ('relation', relation_list),
        ('ratio', ratio_list),
    )
    for name in names
]

# Number of bootstrap resamples per comparison.
boot_exp_amount = 10000
# Sample-size threshold used by `bootstrap`'s small-sample check.
min_sample = 100


def _row_values_as_int8(values):
    """Turn one list-like row cell into an int8 array; falsy cells become empty."""
    # A multi-element ndarray cannot be tested with `if values`, because its
    # truth value is ambiguous, so handle it explicitly.
    if isinstance(values, np.ndarray):
        return values.astype(np.int8)
    return np.array(values if values else [], dtype=np.int8)


def foo_2(row, *, executor=None):
    """Bootstrap every metric of one experiment row for both the 'aa' and 'ab' groups.

    For each group and each tag, the job reads the cells
    ``'{group}_test_{tag}'`` and ``'{group}_control_{tag}'``. These are
    presumably lists of per-unit values; confirm against the source table.
    The jobs are run through ``bootstrap`` in parallel.

    NOTE(review): values are stored as int8, so inputs outside [-128, 127]
    would not survive the conversion; confirm the per-unit counts stay small.

    Args:
        row: one DataFrame row (as passed by ``DataFrame.apply(axis=1)``).
        executor: optional ``concurrent.futures`` executor to reuse across
            rows. When None, a fresh ``ProcessPoolExecutor`` is created and
            shut down for this row, as before.

    Returns:
        A dict mapping ``'{group}_{tag}_{kind}'`` to that job's ``bootstrap``
        output, plus ``'experiment_id'`` and ``'group_name'`` copied from the
        row.
    """
    print(row['experiment_id'])

    # (tags, key suffix, per-sample statistic, comparison, ratio flag)
    specs = (
        (penetration_list, 'penetration', 'special_case', 'lambda x, y: x-y', True),
        (relation_list, 'relation', 'lambda x: sum(x)', 'lambda x, y: x/y', False),
        (ratio_list, 'ratio', 'lambda x: sum(x)/len(x)', 'lambda x, y: x-y', True),
    )

    params = {}
    for group in ('aa', 'ab'):
        for tag_list, kind, func, func_2, ratio in specs:
            for tag in tag_list:
                params[f'{group}_{tag}_{kind}'] = [
                    _row_values_as_int8(row[f'{group}_test_{tag}']),
                    _row_values_as_int8(row[f'{group}_control_{tag}']),
                    func,
                    func_2,
                    ratio,
                    boot_exp_amount,
                ]

    # `map` yields results in submission order, so they line up with `params`.
    if executor is None:
        with ProcessPoolExecutor() as pool:
            results = list(pool.map(bootstrap, params.values()))
    else:
        results = list(executor.map(bootstrap, params.values()))

    output = dict(zip(params, results))
    output['experiment_id'] = row['experiment_id']
    output['group_name'] = row['group_name']
    return output


if __name__ == '__main__':
    # Load one row per (experiment_id, group_name) carrying the per-group
    # 'aa_*' / 'ab_*' sample columns.
    source = hahn.read(full_path='//home/taxi-delivery/analytics/dev/marketing/agliukov/crm_exp/for_python_7_final')
    boot_rows = source.apply(foo_2, axis=1)

    # Keep only the columns that are not 'aa_'/'ab_' sample columns and join
    # them back onto the bootstrap output.
    passthrough = [
        name for name in source.columns
        if not ('ab_' in name or 'aa_' in name)
    ]
    merged = boot_rows.apply(pd.Series).merge(
        source[passthrough],
        on=['experiment_id', 'group_name'],
    )

    # Bootstrap outputs are lists of numbers; declare them explicitly to YQL.
    yql_types = convert_dtypes_to_yql(merged)
    for name in (c for c in merged.columns if 'ab_' in c or 'aa_' in c):
        yql_types[name] = 'List<Float>?'

    hahn.write(
        merged,
        table_name='after_bootstrap_7',
        types=yql_types,
        full_path='//home/taxi-delivery/analytics/dev/marketing/agliukov/crm_exp/after_bootstrap_7_final',
        if_exists='replace',
    )
