import numpy as np
import pandas as pd
from fbprophet import Prophet
from .utils import (
    delete_outliers,
    generate_forecast_model_id,
    seasonal_naive,
    wmape
)


# TODO: Refactor this spaghetti code
def forecast(actual_df: pd.DataFrame, forecast_params: dict, unseen_future_ts_df: pd.DataFrame):
    """Forecast one time series, picking the best of three candidate models.

    A hold-out window is cut from the end of ``actual_df``. Two Prophet models
    (one fitted on all remaining history, one fitted only on the most recent
    ``test_period`` steps of it) and a seasonal-naive baseline are scored on
    that window with WMAPE. The winner is then refitted on the full history
    and used to predict the timestamps in ``unseen_future_ts_df``. The naive
    baseline wins only when its WMAPE times 1.1 does not exceed the best
    Prophet WMAPE.

    The input frames are not modified; the function works on private copies.

    Args:
        actual_df: Observed history with a datetime column ``ds`` and a
            numeric column ``y``.
        forecast_params: Settings dictionary. Keys read here:
            ``forecast_freq`` (string accepted by both ``pd.Timedelta`` and
            ``pd.date_range``), ``test_period`` (number of hold-out steps),
            ``is_set_shift_boundaries``, ``is_delete_outliers``, ``horizon``,
            ``outlier_z_value``, ``forecast_cap`` (multiplier applied to the
            recent maximum of ``y``), ``is_log_y``, ``prophet`` (keyword
            arguments for ``Prophet``) and ``naive_season``.
        unseen_future_ts_df: Frame with a ``ds`` column holding the
            timestamps to forecast.

    Returns:
        A DataFrame with columns ``ds``, ``yhat``, ``train_start``,
        ``train_end`` and ``model_id``; or ``None`` when there is too little
        data, the computed cap is zero, or Prophet cannot be fitted.
    """
    freq = forecast_params["forecast_freq"]
    test_period = forecast_params["test_period"]
    test_window = pd.Timedelta(freq) * test_period

    # 0.1 Define the testing period (depends on the last actual timestamp).
    test_start_ts = actual_df["ds"].max() - test_window
    test_ts_df = pd.DataFrame(
        {"ds": pd.date_range(test_start_ts, periods=test_period, freq=freq)}
    )

    # 0.2 Keep only timestamps whose hour lies inside the observed shift,
    #     i.e. between the earliest and latest hour with a positive ``y``.
    if forecast_params["is_set_shift_boundaries"]:
        active_hours = actual_df.loc[actual_df["y"] > 0, "ds"].dt.hour
        shift_start, shift_end = active_hours.min(), active_hours.max()

        actual_df = actual_df.loc[
            actual_df["ds"].dt.hour.between(shift_start, shift_end)
        ]
        test_ts_df = test_ts_df.loc[
            test_ts_df["ds"].dt.hour.between(shift_start, shift_end)
        ]
        unseen_future_ts_df = unseen_future_ts_df.loc[
            unseen_future_ts_df["ds"].dt.hour.between(shift_start, shift_end)
        ]

    # 1. Delete outliers
    if forecast_params["is_delete_outliers"]:
        actual_df = delete_outliers(
            actual_df, forecast_params["horizon"], forecast_params["outlier_z_value"],
        )

    if actual_df.shape[0] < 5 or actual_df["y"].sum() == 0:
        print("Not enough data in that group")
        return None

    # 2. Calc limits for forecast level (cap/floor), used by logistic growth
    #    and for clipping the final forecast.
    recent_y = actual_df.tail(forecast_params["horizon"])["y"]
    forecast_cap = recent_y.max() * forecast_params["forecast_cap"]

    if forecast_cap == 0:
        print("forecast_cap = 0")
        return None

    forecast_floor = recent_y.min() / forecast_params["forecast_cap"]

    # Everything below writes columns into the frames. Work on private copies
    # so the caller's data is never altered (and so that writes never land on
    # a ``.loc`` slice of it).
    actual_df = actual_df.copy()
    unseen_future_ts_df = unseen_future_ts_df.copy()
    test_ts_df = test_ts_df.copy()

    # 3. Transform
    is_log_y = forecast_params["is_log_y"]
    if is_log_y:
        actual_df["y"] = np.log1p(actual_df["y"])
        forecast_cap = np.log1p(forecast_cap)
        forecast_floor = np.log1p(forecast_floor)

    # 4. Set limits on every frame Prophet will see
    for frame in (actual_df, unseen_future_ts_df, test_ts_df):
        frame["cap"] = forecast_cap
        frame["floor"] = forecast_floor

    # 5. Generate model_id (deferred until a forecast is actually attempted)
    model_id = generate_forecast_model_id(forecast_params)

    # 6.0 Prepare train_df: everything strictly before the test window
    train_df = actual_df.loc[actual_df["ds"] < test_ts_df["ds"].min()]

    # 6.1 Create model for long train set
    model_long = _fit_prophet_with_fallback(
        forecast_params["prophet"], train_df, "model_long"
    )
    if model_long is None:
        return None

    # 6.2 Create model for short train set (last ``test_period`` steps only)
    short_model_train_start = train_df["ds"].max() - test_window
    model_short = _fit_prophet_with_fallback(
        forecast_params["prophet"],
        train_df.loc[train_df["ds"] >= short_model_train_start],
        "model_short",
        algorithm="Newton",
    )
    if model_short is None:
        # The short window could not be fitted; score the long model twice.
        model_short = model_long

    # 7.1 / 7.2 Forecast on the test period with both Prophet models
    forecast_long = model_long.predict(test_ts_df)[["ds", "yhat"]]
    forecast_short = model_short.predict(test_ts_df)[["ds", "yhat"]]

    # 7.3 Naive forecast on a gap-free version of the training series
    naive_forecast = seasonal_naive(
        _fill_missing_time_slots(train_df, freq)["y"],
        forecast_params["naive_season"],
        test_period,
    )
    # NOTE: test_ts_df deliberately keeps its original index here; the
    # DataFrame constructor aligns the two inputs on index.
    naive_test_df = pd.DataFrame({"ds": test_ts_df["ds"], "yhat_naive": naive_forecast})

    # 8. Combine forecasts and actuals to compare them by WMAPE
    actual_all_time_slots_df = _fill_missing_time_slots(actual_df, freq)

    compare_models_df = (
        forecast_long.merge(forecast_short, on="ds", suffixes=("_long", "_short"))
        .merge(naive_test_df, on="ds")
        .merge(actual_all_time_slots_df.loc[:, ["ds", "y"]], on="ds")
        .set_index(["ds"])
    )

    # 9. Undo transformation so that errors are measured on the original scale
    if is_log_y:
        compare_models_df = np.expm1(compare_models_df)
        forecast_cap = np.expm1(forecast_cap)
        forecast_floor = np.expm1(forecast_floor)
        actual_all_time_slots_df["y"] = np.expm1(actual_all_time_slots_df["y"])

    # 10. Calculate WMAPE
    wmape_long_model = wmape(compare_models_df.y, compare_models_df.yhat_long)
    wmape_short_model = wmape(compare_models_df.y, compare_models_df.yhat_short)
    wmape_naive_model = wmape(compare_models_df.y, compare_models_df.yhat_naive)

    best_wmape = min(wmape_long_model, wmape_short_model)

    # 11. Make forecast on unseen period with the better model
    if wmape_naive_model * 1.1 <= best_wmape:
        # Naive forecast (already on the original scale)
        forecast_df = seasonal_naive(
            actual_all_time_slots_df["y"],
            forecast_params["naive_season"],
            forecast_params["horizon"],
        ).to_frame(name="yhat")

        # Add dates.
        # NOTE(review): this is a contiguous range starting at the first
        # unseen timestamp; it is not restricted to shift hours - confirm
        # that this is intended.
        forecast_df["ds"] = pd.date_range(
            start=unseen_future_ts_df.ds.min(),
            periods=forecast_df.shape[0],
            freq=freq,
        )

        train_start = actual_df["ds"].min()
        model_suffix = "_naive"
    else:
        if wmape_long_model <= best_wmape:
            # Forecast with long train period
            final_train_df = actual_df
            train_start = actual_df["ds"].min()
            model_suffix = "_long"
        else:
            # Forecast with short train period
            final_train_df = actual_df.loc[actual_df["ds"] >= short_model_train_start]
            train_start = short_model_train_start
            model_suffix = "_short"

        # Same fallback policy as for the evaluation models above, so a
        # failed optimisation degrades gracefully instead of raising.
        model = _fit_prophet_with_fallback(
            forecast_params["prophet"], final_train_df, f"final_model{model_suffix}"
        )
        if model is None:
            return None

        forecast_df = model.predict(unseen_future_ts_df)[["ds", "yhat"]]

        # Undo transformations
        if is_log_y:
            forecast_df["yhat"] = np.expm1(forecast_df["yhat"])

    # Add train start, end columns and model id
    forecast_df["train_start"] = train_start
    forecast_df["train_end"] = actual_df["ds"].max()
    forecast_df["model_id"] = model_id + model_suffix

    # 12. Replace negative values (if any) in forecast with zeros
    forecast_df["yhat"] = forecast_df["yhat"].clip(lower=0).round(1)

    # 13. Replace too high forecasts with forecast cap
    forecast_df["yhat"] = forecast_df["yhat"].clip(upper=forecast_cap).round(1)

    return forecast_df


def _fit_prophet_with_fallback(prophet_params, train_df, label, **primary_fit_kwargs):
    """Fit Prophet with ``prophet_params``; on failure retry with flat growth.

    ``primary_fit_kwargs`` are forwarded to the first ``fit`` call only. The
    fallback always uses ``growth="flat"`` with ``algorithm="Newton"``.
    Returns the fitted model, or ``None`` (after printing the error) when
    both attempts raise ``RuntimeError`` or ``ValueError``.
    """
    try:
        return Prophet(**prophet_params).fit(train_df, **primary_fit_kwargs)
    except (RuntimeError, ValueError):
        try:
            return Prophet(growth="flat").fit(train_df, algorithm="Newton")
        except (RuntimeError, ValueError) as e:
            print(f"Issues with fitting Prophet for {label}: {e}")
            return None


def _fill_missing_time_slots(df, freq):
    """Reindex ``df`` on ``ds`` to a regular ``freq`` grid, filling gaps with 0."""
    return (
        df.set_index("ds")
        .asfreq(freq=freq, fill_value=0)
        .reset_index()
        .fillna(value=0)
    )
