import pandas as pd
import click
import os
import yaml
from customer_service.forecasts.lib.utils import (
    load_actual,
    upload_df_to_yt,
    preprocess,
    get_yt_client
)
from customer_service.forecasts.lib.prophet import forecast


@click.command()
@click.option('--config', required=True, help='Path to config')
def main(config: str):
    """Forecast workload (emails/calls/chats) per period and write it to a YT table.

    The forecast step (``forecast_freq``) and the number of steps (``horizon``)
    are read from the ``forecast_params`` section of the YAML config.

    Result:
        * https://yt.yandex-team.ru/hahn/navigation?path=//home/sp/mmokhin/crew/forecast/workload_forecast_long
    """
    from pandas.tseries.frequencies import to_offset

    # Parse into a separate name so the ``config`` path argument is not shadowed.
    with open(config, 'r', encoding='utf-8') as stream:
        cfg = yaml.safe_load(stream)

    # 1. Prepare actual data
    client = get_yt_client(cfg['yt_proxy'], os.getenv("YT_TOKEN"))

    actual_df = load_actual(cfg['input_params']['table_path'], client)
    actual_df = preprocess(actual_df, cfg['input_params'])

    # Stop forecasting if there is too little history or no volume at all.
    if actual_df.shape[0] <= 5 or actual_df["y"].sum() < 1:
        print("No raw data to forecast")
        return None

    # 2. Prepare forecast
    forecast_params = cfg['forecast_params']
    # Convert the alias to a DateOffset once. Unlike pd.Timedelta, this accepts
    # bare aliases ("D") and calendar frequencies ("W", "MS"), so the start step
    # agrees with the frequency pd.date_range uses below.
    freq = to_offset(forecast_params["forecast_freq"])

    # Unseen period: ``horizon`` steps starting one step after the last actual ds.
    unseen_future_ts_df = pd.DataFrame(
        {
            "ds": pd.date_range(
                actual_df["ds"].max() + freq,
                periods=forecast_params["horizon"],
                freq=freq,
            )
        }
    )

    # Calculate forecasts for every group.
    forecast_df = actual_df.groupby("group").apply(
        forecast, forecast_params, unseen_future_ts_df
    )
    # Nothing came back for any group: exit explicitly instead of failing later
    # (the previous pd.concat of an empty list raised ValueError here).
    if forecast_df.empty:
        print("No forecasts produced")
        return None

    # groupby.apply yields a (group, per-group row) index. Flatten it and drop
    # the unnamed inner level that reset_index exposes as "level_1".
    forecast_df = forecast_df.reset_index().drop(columns=["level_1"])

    # "min" is the minute alias; "T" is deprecated since pandas 2.2.
    forecast_df["report_ts"] = pd.Timestamp.now().round("min")

    # 3. Write result to YT (only strictly positive predictions, fixed column order)
    output_columns = [
        "report_ts",
        "group",
        "model_id",
        "ds",
        "train_start",
        "train_end",
        "yhat",
    ]
    upload_df_to_yt(
        client,
        forecast_df.query("yhat > 0").loc[:, output_columns],
        cfg['output_params'],
        is_append=False,
    )


if __name__ == "__main__":
    # Script entry point: the click command reads its options from the command line.
    main()
