import json, datetime

import tqdm.notebook
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression, Lasso, Ridge


# Minimum number of observations a rolling window needs to yield a value;
# passed as ``min_periods`` to ``pandas.Series.rolling`` when smoothing the
# curves in ``SyntheticDiDEstimator.plot_ab``.
MIN_WINDOW_PERIODS = 1


def load_experiment_json(path):
    """Load experiment metadata from a JSON file.

    Top-level list values are converted to sets (so callers can use set
    algebra on them); all other values are kept unchanged. Set-valued keys
    come first in the returned dict, followed by the remaining keys, each
    group in file order.

    Args:
        path: path to a UTF-8 JSON file whose top level is an object.

    Returns:
        dict: the metadata with list values replaced by sets.
    """
    # Read-only mode: the file is never written, so write permission must
    # not be required. JSON is specified as UTF-8, so don't rely on the
    # platform default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        raw_data = json.load(f)

    experiment_meta = {
        k: set(v) for k, v in raw_data.items() if isinstance(v, list)
    }
    experiment_meta.update(
        {k: v for k, v in raw_data.items() if not isinstance(v, list)},
    )
    return experiment_meta


def get_additional_experiment_meta(
        experiment_meta, days_before_treatment, days_after_treatment,
):
    """Add AB and AA analysis-window dates to ``experiment_meta`` in place.

    All dates are ``'YYYY-MM-DD'`` strings. With S = ``ab_exp_start_dt``,
    B = ``days_before_treatment`` and A = ``days_after_treatment``:

    * ``ab_analysis_start_dt`` = S - B days
    * ``ab_analysis_end_dt``   = S + (A - 1) days
    * ``aa_analysis_start_dt``, ``aa_analysis_end_dt``, ``aa_exp_start_dt``
      are ``ab_analysis_start_dt``, ``ab_analysis_end_dt`` and
      ``ab_exp_start_dt`` shifted back by the AB post-period length
      (``ab_analysis_end_dt - S`` + 1 days).

    Returns:
        The same dict object, after mutation.
    """
    date_fmt = '%Y-%m-%d'

    def to_dt(date_str):
        return datetime.datetime.strptime(date_str, date_fmt)

    def shift(date_str, n_days):
        # Arithmetic is done on datetimes; only the calendar date is kept.
        return str((to_dt(date_str) + datetime.timedelta(days=n_days)).date())

    exp_start = experiment_meta['ab_exp_start_dt']
    experiment_meta['ab_analysis_start_dt'] = shift(
        exp_start, -days_before_treatment,
    )
    experiment_meta['ab_analysis_end_dt'] = shift(
        exp_start, days_after_treatment - 1,
    )

    post_len = (
        to_dt(experiment_meta['ab_analysis_end_dt']) - to_dt(exp_start)
    ).days + 1

    # The AA window mirrors the AB window, moved back by the post-period length.
    for aa_key, ab_key in (
        ('aa_analysis_start_dt', 'ab_analysis_start_dt'),
        ('aa_analysis_end_dt', 'ab_analysis_end_dt'),
        ('aa_exp_start_dt', 'ab_exp_start_dt'),
    ):
        experiment_meta[aa_key] = shift(experiment_meta[ab_key], -post_len)

    return experiment_meta


def get_experiment_meta(
        path, days_before_treatment=21, days_after_treatment=21,
):
    """Load experiment metadata from ``path`` and add derived window dates.

    Combines ``load_experiment_json`` (lists become sets) with
    ``get_additional_experiment_meta`` (AB/AA analysis start/end dates).

    Args:
        path: JSON file with the experiment description; must contain
            ``ab_exp_start_dt`` as a ``'YYYY-MM-DD'`` string.
        days_before_treatment: length of the AB pre-treatment window.
        days_after_treatment: length of the AB post-treatment window.

    Returns:
        dict: the enriched experiment metadata.
    """
    return get_additional_experiment_meta(
        load_experiment_json(path),
        days_before_treatment,
        days_after_treatment,
    )


class SyntheticDiDDataProcessor(object):
    """Prepare a long-format (object, time) panel for synthetic DiD.

    Fills in missing (object, time) rows and cuts the panel to the AB and AA
    analysis windows stored in ``experiment_meta``.
    """

    def __init__(self, outcome_var, time_var, object_var):
        # Column names in the DataFrames handled by this processor.
        self.outcome_var = outcome_var
        self.time_var = time_var
        self.object_var = object_var

    def inpute_missing_values(self, df, fillna_dict):
        """Add rows for (object, time) pairs that are absent from ``df``.

        Each object gets a row for every time value that appears anywhere in
        ``df``. Non-key columns of added rows take their values from
        ``fillna_dict``.

        Args:
            df: DataFrame whose columns are exactly the keys of
                ``fillna_dict`` plus ``time_var`` and ``object_var``.
            fillna_dict: column name -> value used for added rows.

        Returns:
            A new DataFrame sorted by (object_var, time_var). If nothing is
            missing, it is ``df`` sorted.

        Raises:
            AssertionError: if the columns of ``df`` do not match.
        """
        assert set(df.columns.values.tolist()) == set(
            list(fillna_dict.keys()) + [self.time_var, self.object_var],
        ), 'Columns in fillna_dict must match columns in  df.'

        # Computed once instead of once per object.
        all_times = set(df[self.time_var].unique())

        missing_rows = []
        for obj, data_obj in df.groupby(self.object_var):
            for t in all_times.difference(data_obj[self.time_var].unique()):
                row = {self.object_var: obj, self.time_var: t}
                row.update(fillna_dict)
                missing_rows.append(row)

        sort_cols = [self.object_var, self.time_var]
        if not missing_rows:
            # pd.concat([]) raises ValueError; a complete panel just needs sorting.
            return df.sort_values(sort_cols)
        return pd.concat([df, pd.DataFrame(missing_rows)]).sort_values(
            sort_cols,
        )

    def truncate_df(self, df, experiment_meta, mode='ab'):
        """Cut ``df`` to the ``mode`` analysis window and aggregate rows.

        Keeps rows whose ``time_var`` lies between
        ``experiment_meta['<mode>_analysis_start_dt']`` and
        ``experiment_meta['<mode>_analysis_end_dt']`` (both inclusive), then
        sums every remaining column per (object, time) pair.

        Raises:
            AssertionError: if ``mode`` is not ``'ab'`` or ``'aa'``.
        """
        assert mode in {
            'aa',
            'ab',
        }, '"mode" argument must be either "ab" or "aa".'

        start = experiment_meta['{}_analysis_start_dt'.format(mode)]
        end = experiment_meta['{}_analysis_end_dt'.format(mode)]
        # NOTE(review): the bounds are 'YYYY-MM-DD' strings when produced by
        # get_additional_experiment_meta; this assumes time_var holds values
        # comparable with them (ISO date strings or datetimes).
        in_window = (df[self.time_var] >= start) & (df[self.time_var] <= end)
        return (
            df[in_window]
            .groupby([self.object_var, self.time_var])
            .sum()
            .reset_index()
        )

    def process_data(self, df, experiment_meta, fillna_dict):
        """Fill missing rows, then return the ``(df_ab, df_aa)`` windows."""
        df = self.inpute_missing_values(df, fillna_dict)
        df_ab = self.truncate_df(df, experiment_meta, mode='ab')
        df_aa = self.truncate_df(df, experiment_meta, mode='aa')
        return df_ab, df_aa


class SyntheticDiDEstimator:
    """Estimate a treatment effect with a synthetic-control DiD model.

    The mean outcome of the treated objects is regressed on the outcomes of
    the control objects over the pre-treatment periods. The fitted model
    supplies the counterfactual ("control") series.

    Call order: ``process_data`` -> ``get_predictions`` ->
    ``calculate_diff_in_diff`` / ``calculate_growth_rate`` / ``plot_ab``.
    Later steps read ``time_list`` (set by ``process_data``) and
    ``pre_exp_periods`` (set by ``get_predictions``).
    """

    def __init__(
            self,
            df,
            treatment_start_date,
            outcome_var,
            time_var,
            object_var,
            treated_objects_set,
            skip_control_objects_set,
    ):
        # Long-format panel; presumably one row per (object, time) pair, since
        # process_data stacks each object's outcomes into equal-length columns.
        self.df = df
        # First treated period; periods with time < this value are "pre".
        self.treatment_start_date = treatment_start_date
        self.outcome_var = outcome_var
        self.time_var = time_var
        self.object_var = object_var
        # Objects whose mean outcome forms the "experiment" series.
        self.treated_objects_set = treated_objects_set
        # Objects excluded from the control (donor) pool.
        self.skip_control_objects_set = skip_control_objects_set

    def process_data(self):
        """Return control outcomes as a ``(n_periods, n_controls)`` matrix.

        Treated and skipped objects are excluded. Rows follow
        ``self.time_list`` (sorted time values) and columns follow
        ``self.object_list`` (sorted control objects); both lists are stored
        on the instance.

        Raises:
            ValueError: (from ``np.stack``) if control objects do not all have
                the same number of rows.
        """
        is_control = ~(
            self.df[self.object_var].isin(self.treated_objects_set)
            | self.df[self.object_var].isin(self.skip_control_objects_set)
        )
        data = self.df.loc[
            is_control, [self.object_var, self.time_var, self.outcome_var],
        ]
        # Order by time (stable) so each object's column lines up with
        # time_list even if the input frame is not pre-sorted.
        data = data.sort_values(self.time_var, kind='mergesort')
        self.time_list = sorted(data[self.time_var].unique())
        self.object_list = sorted(data[self.object_var].unique())

        # groupby sorts by object, matching object_list.
        columns = [
            data_obj[self.outcome_var].values
            for _, data_obj in data.groupby(self.object_var)
        ]
        return np.stack(columns, axis=1)

    def get_predictions(self, data, model_type='OLS'):
        """Fit the control -> treated regression on the pre-period.

        Args:
            data: control matrix from ``process_data`` (rows in
                ``self.time_list`` order).
            model_type: ``'OLS'``, ``'Ridge'`` or ``'Lasso'`` (the last two
                with ``alpha=100.0``).

        Returns:
            tuple: ``(pre_pred, post_pred, pre_true, post_true)``, each of
            shape ``(n_periods, 1)``. "pre" covers periods before
            ``treatment_start_date``; "post" covers the rest.

        Raises:
            AssertionError: for an unknown ``model_type``.
        """
        assert model_type in {
            'OLS',
            'Ridge',
            'Lasso',
        }, '"model_type" argument must be one of ["OLS", "Ridge", "Lasso"].'

        y_treatment = (
            self.df[self.df[self.object_var].isin(self.treated_objects_set)]
            .groupby(self.time_var)
            .agg({self.outcome_var: 'mean'})
            # Align with the rows of `data`: a period missing for the treated
            # group becomes NaN instead of silently shifting the series.
            .reindex(self.time_list)
            .values
        )

        self.pre_exp_periods = (
            np.array(self.time_list) < self.treatment_start_date
        ).sum()

        model_factories = {
            'OLS': lambda: LinearRegression(fit_intercept=True),
            'Ridge': lambda: Ridge(alpha=100.0, fit_intercept=True),
            'Lasso': lambda: Lasso(alpha=100.0, fit_intercept=True),
        }
        self.lr = model_factories[model_type]()

        pre_x = data[: self.pre_exp_periods]
        post_x = data[self.pre_exp_periods :]
        self.lr.fit(pre_x, y_treatment[: self.pre_exp_periods])

        # Lasso (coordinate descent) returns 1-D predictions for a one-column
        # target, while LinearRegression/Ridge return 2-D. Reshape so the
        # predictions always match y_treatment's (n, 1) shape.
        pre_treatment_y_pred = np.asarray(self.lr.predict(pre_x)).reshape(-1, 1)
        post_treatment_y_pred = np.asarray(self.lr.predict(post_x)).reshape(-1, 1)

        pre_treatment_y_true = y_treatment[: self.pre_exp_periods]
        post_treatment_y_true = y_treatment[self.pre_exp_periods :]

        return (
            pre_treatment_y_pred,
            post_treatment_y_pred,
            pre_treatment_y_true,
            post_treatment_y_true,
        )

    def calculate_diff_in_diff(
            self,
            pre_treatment_y_pred,
            post_treatment_y_pred,
            pre_treatment_y_true,
            post_treatment_y_true,
    ):
        """Compute the bias-corrected post-period effect.

        The mean pre-period residual (true - predicted) is added to the
        post-period predictions. The effect is the summed post-period gap
        between the true series and these corrected predictions.

        Returns:
            dict: ``effect_abs`` (summed gap) and ``effect_rel`` (gap divided
            by the summed true post-period outcome).
        """
        # Flatten so (n, 1) and (n,) inputs never broadcast to (n, n).
        pre_true = np.ravel(pre_treatment_y_true)
        pre_pred = np.ravel(pre_treatment_y_pred)
        post_true = np.ravel(post_treatment_y_true)
        post_pred = np.ravel(post_treatment_y_pred)

        diff_pre = (pre_true - pre_pred).mean()
        post_treatment_preds_add = post_pred + diff_pre

        diff_post_add = (post_true - post_treatment_preds_add).sum()
        diff_post_add_pp = diff_post_add / post_true.sum()

        return {'effect_abs': diff_post_add, 'effect_rel': diff_post_add_pp}

    def calculate_growth_rate(
            self,
            pre_treatment_y_pred,
            post_treatment_y_pred,
            pre_treatment_y_true,
            post_treatment_y_true,
            intervals=(7, 14, 21),
    ):
        """Compare the first N post periods with the last N pre periods.

        For each N in ``intervals``, the growth is
        ``(sum(post[:N]) - sum(pre[-N:])) / sum(pre[-N:])``, rounded to
        3 decimals.

        Returns:
            tuple: ``(wow_exp, wow_con)`` dicts mapping N to the growth of
            the true (experiment) and predicted (control) series.
        """
        def growth(pre, post, n_days):
            before = np.sum(pre[-n_days:])
            return round((np.sum(post[:n_days]) - before) / before, 3)

        wow_exp = {
            n: growth(pre_treatment_y_true, post_treatment_y_true, n)
            for n in intervals
        }
        wow_con = {
            n: growth(pre_treatment_y_pred, post_treatment_y_pred, n)
            for n in intervals
        }
        return wow_exp, wow_con

    def plot_ab(
            self,
            pre_treatment_y_pred,
            post_treatment_y_pred,
            pre_treatment_y_true,
            post_treatment_y_true,
            rolling_window=7,
            fig_save_path=None,
    ):
        """Plot rolling means of the experiment series and its counterfactual.

        Draws the true ("experiment", red) and predicted ("control", blue)
        series, each smoothed with a rolling mean over ``rolling_window``
        periods. Vertical lines mark the treatment start and the last
        pre-treatment period. Must be called after ``get_predictions``.

        Args:
            rolling_window: rolling-mean window size in periods.
            fig_save_path: if truthy, the figure is saved to this path.
        """
        fig, ax = plt.subplots(1, 1, figsize=(18, 9))
        # Sorted so the title is stable (set iteration order is not).
        title_str = 'experiment: {}'.format(
            ', '.join(sorted(map(str, self.treated_objects_set))),
        )
        ax.set_title(title_str, fontsize=15)
        ax.set_ylabel(self.outcome_var, labelpad=10, fontsize=15)

        def smooth(pre, post):
            series = pd.Series(np.concatenate([np.ravel(pre), np.ravel(post)]))
            return series.rolling(
                window=rolling_window, min_periods=MIN_WINDOW_PERIODS,
            ).mean()

        x = np.array(self.time_list)
        ax.xaxis.set_major_locator(plt.MaxNLocator(10))

        ax.plot(
            x, smooth(pre_treatment_y_true, post_treatment_y_true),
            c='r', label='experiment',
        )
        ax.plot(
            x, smooth(pre_treatment_y_pred, post_treatment_y_pred),
            c='b', label='control',
        )
        ax.axvline(
            x=self.treatment_start_date,
            label='experiment start date',
            linestyle='--',
            color='dimgray',
        )
        if self.pre_exp_periods > 0:
            # Without this guard, index -1 would mark the LAST period.
            ax.axvline(
                x=self.time_list[self.pre_exp_periods - 1],
                label='experiment start date - 1 day',
                linestyle='--',
                color='darkgray',
            )
        ax.legend()
        ax.grid()
        ax.set_ylim(bottom=0)
        if fig_save_path:
            fig.savefig(fig_save_path)


def run_test(
        df,
        treatment_start_date,
        outcome_var,
        time_var,
        object_var,
        treated_objects_set,
        skip_control_objects_set,
        plot_flg=True,
        fig_save_path=None,
        model_type='OLS',
        rolling_window=7,
):
    """Run one synthetic-DiD estimate and optionally plot it.

    Args:
        df: long-format panel with ``object_var``, ``time_var`` and
            ``outcome_var`` columns.
        treatment_start_date: first treated period.
        treated_objects_set: objects forming the experiment group.
        skip_control_objects_set: objects excluded from the control pool.
        plot_flg: whether to draw the AB plot.
        fig_save_path: where to save the plot (only used when plotting).
        model_type: regression model, one of ``'OLS'``, ``'Ridge'`` or
            ``'Lasso'``; defaults to ``'OLS'``.
        rolling_window: smoothing window (periods) for the plot.

    Returns:
        tuple: ``(effect, growth_rate)``. ``effect`` is the dict from
        ``calculate_diff_in_diff``; ``growth_rate`` is the
        ``(wow_exp, wow_con)`` pair from ``calculate_growth_rate``.
    """
    sdid = SyntheticDiDEstimator(
        df=df,
        treatment_start_date=treatment_start_date,
        outcome_var=outcome_var,
        time_var=time_var,
        object_var=object_var,
        treated_objects_set=treated_objects_set,
        skip_control_objects_set=skip_control_objects_set,
    )
    data = sdid.process_data()
    predictions = sdid.get_predictions(data, model_type=model_type)
    effect = sdid.calculate_diff_in_diff(*predictions)
    growth_rate = sdid.calculate_growth_rate(*predictions)

    if plot_flg:
        sdid.plot_ab(
            *predictions,
            rolling_window=rolling_window,
            fig_save_path=fig_save_path,
        )
    return effect, growth_rate


def get_p_value(
        df,
        effect,
        growth_rate,
        experiment_meta,
        treatment_start_date,
        outcome_var,
        time_var,
        object_var,
        n_iter=100,
):
    """Estimate placebo p-values for the relative effect and growth-rate gap.

    Each iteration draws, without replacement, as many placebo objects as
    there are real treated objects. They are drawn from objects that are
    neither treated nor excluded. ``run_test`` is then rerun with those
    placebos as the treatment group, and the real treated and excluded
    objects are removed from the control pool.

    Args:
        df: long-format panel passed to ``run_test``.
        effect: observed relative effect, compared with the placebo
            ``effect_rel`` values.
        growth_rate: observed experiment-minus-control growth rate over the
            21-period interval, compared with the placebo equivalents.
        experiment_meta: dict with set-valued ``'exp_cities'`` and
            ``'exclude_cities'``.
        n_iter: number of placebo draws; draw ``i`` uses seed ``i``.

    Returns:
        tuple: ``(p_value_rel, p_value_gr, effect_rel_list,
        growth_rate_list)``. Each p-value is the share of placebo draws whose
        absolute value strictly exceeds the absolute observed value.
    """
    treated_objects_set = (
        experiment_meta['exp_cities'] - experiment_meta['exclude_cities']
    )
    skip_control_objects_set = experiment_meta['exclude_cities']
    excluded_objects_set = treated_objects_set.union(skip_control_objects_set)

    # Sorted: set iteration order of strings changes between interpreter runs
    # (hash randomization), which would make the seeded draws irreproducible.
    placebo_pool = sorted(
        set(df[object_var].unique())
        - treated_objects_set
        - skip_control_objects_set,
        key=str,
    )
    n_treated = len(treated_objects_set)

    effect_rel_list = []
    growth_rate_list = []
    for seed in tqdm.notebook.tqdm(range(n_iter)):
        # A private RandomState gives the same draws as np.random.seed(seed)
        # followed by np.random.choice, without resetting the global RNG.
        rng = np.random.RandomState(seed)
        placebo_test_objects_set = rng.choice(
            placebo_pool, n_treated, replace=False,
        )

        effect_iter, (wow_exp, wow_con) = run_test(
            df,
            treatment_start_date=treatment_start_date,
            outcome_var=outcome_var,
            time_var=time_var,
            object_var=object_var,
            treated_objects_set=placebo_test_objects_set,
            skip_control_objects_set=excluded_objects_set,
            plot_flg=False,
            fig_save_path=None,
        )
        growth_rate_list.append(wow_exp[21] - wow_con[21])
        effect_rel_list.append(effect_iter['effect_rel'])

    p_value_rel = (
        np.abs(np.array(effect_rel_list)) > abs(effect)
    ).sum() / len(effect_rel_list)
    p_value_gr = (
        np.abs(np.array(growth_rate_list)) > abs(growth_rate)
    ).sum() / len(growth_rate_list)
    return p_value_rel, p_value_gr, effect_rel_list, growth_rate_list
