import crypta.adhoc.cdp.config_pb2 as config
from crypta.lib.python.getoptpb import ParsePbOptions
from crypta.lib.python.yql.client import create_yql_client
from crypta.lib.python.logging import logging_helpers

import library.python.resource as rs
from yt.wrapper import YtClient, TablePath
from fbprophet import Prophet
import pandas as pd

import logging
import datetime


logger = logging.getLogger(__name__)


def parse_day(str_):
    """Parse a ``YYYY-MM-DD`` value into a ``datetime.date``.

    The argument is converted with ``str()`` first, so any object whose string
    form matches the pattern is accepted; otherwise ``ValueError`` is raised.
    """
    text = str(str_)
    moment = datetime.datetime.strptime(text, '%Y-%m-%d')
    return moment.date()


def day_back(date, count):
    """Return the date ``count`` days before ``date`` (after it when ``count`` is negative)."""
    offset = datetime.timedelta(days=count)
    return date - offset


def quote(str_):
    """Wrap ``str_`` in double quotes, e.g. to build a YQL string literal.

    The value is rendered with ``%s``. Embedded double quotes are NOT escaped,
    so only pass trusted values such as generated ``YYYY-MM-DD`` day strings.
    """
    # Use a real one-element tuple: ``'%s' % (x)`` is just ``'%s' % x`` and
    # would unpack a tuple argument instead of formatting it.
    return '"%s"' % (str_,)


def fit_predict(series, periods):
    """Fit a Prophet model on ``series`` and forecast ``periods`` steps past its end.

    Args:
        series: history frame passed straight to ``Prophet.fit`` (the caller in
            this file supplies ``ds``/``y`` columns).
        periods: number of future rows requested from ``make_future_dataframe``.

    Returns:
        DataFrame with ``trend``, ``yhat`` and their lower/upper bounds, plus a
        ``dt`` column holding each date formatted as ``%Y-%m-%d``. The original
        ``ds`` column is dropped.
    """
    prophet_params = dict(
        yearly_seasonality=False,
        weekly_seasonality=True,
        daily_seasonality=False,
        changepoint_range=0.9,
        changepoint_prior_scale=0.3,
        interval_width=0.75,
    )
    model = Prophet(**prophet_params)
    model.fit(series, iter=5000)
    horizon = model.make_future_dataframe(periods=periods)
    forecast = model.predict(horizon)

    kept = forecast[["ds", "trend", "trend_lower", "trend_upper", "yhat", "yhat_lower", "yhat_upper"]]
    # Append the formatted day as the last column, then discard the raw timestamp.
    formatted_days = kept["ds"].dt.strftime("%Y-%m-%d")
    return kept.assign(dt=formatted_days).drop(["ds"], axis=1)


def predict_and_store(yt, path, destination, value_column='Value', lookahead=180):
    """Forecast every (PlaceID, Type) series of ``path`` and write it to ``destination``.

    Args:
        yt: YT client used to read the source and write the result.
        path: source table; rows are expected to carry ``Day``, ``PlaceID``,
            ``Type`` and ``value_column``.
        destination: output table; it is removed and recreated with a fixed schema.
        value_column: name of the column to forecast.
        lookahead: number of future periods passed to ``fit_predict``.

    Forecast columns are clipped at 0. Series with fewer than two non-null
    observations are skipped with a warning because Prophet cannot fit them.
    If nothing can be forecast, an empty table with the schema is written.
    """
    forecast_columns = ["trend", "trend_lower", "trend_upper", "yhat", "yhat_lower", "yhat_upper"]
    series = pd.DataFrame(list(yt.read_table(path)))
    all_predictions = []

    # Iterate only over (PlaceID, Type) pairs that actually occur. The Cartesian
    # product of the unique values yields empty series, and Prophet raises on those.
    groups = series.groupby(['PlaceID', 'Type']) if not series.empty else ()
    for (each_placeid, each_type), subseries in groups:
        history = subseries[['Day', value_column]].rename(columns={'Day': 'ds', value_column: 'y'})
        if history['y'].notna().sum() < 2:
            logger.warning(
                'Skipping PlaceID=%s Type=%s: fewer than 2 observations to fit',
                each_placeid, each_type,
            )
            continue
        predictions = fit_predict(history, lookahead)
        # Clip only the numeric forecast columns. Clipping the whole frame would
        # compare the string 'dt' column with 0, which fails on Python 3.
        predictions[forecast_columns] = predictions[forecast_columns].clip(lower=0)
        predictions['Type'] = each_type
        predictions['PlaceID'] = each_placeid
        all_predictions.append(predictions.rename(columns={'dt': 'Day'}))

    schema = [
        {"name": "Type", "type": "string"},
        {"name": "PlaceID", "type": "int64"},
        {"name": "Day", "type": "string"},
        {"name": "trend", "type": "double"},
        {"name": "trend_lower", "type": "double"},
        {"name": "trend_upper", "type": "double"},
        {"name": "yhat", "type": "double"},
        {"name": "yhat_lower", "type": "double"},
        {"name": "yhat_upper", "type": "double"},
    ]
    # pd.concat raises on an empty list, so fall back to writing no rows.
    rows = pd.concat(all_predictions).to_dict('records') if all_predictions else []
    yt.remove(destination, force=True)
    yt.create('table', destination, attributes={"schema": schema})
    yt.write_table(destination, rows)


def get_target_days(args):
    """Return the day strings for ``args.DaysBack`` consecutive days ending at the origin.

    The origin is ``args.Day`` when it is set, otherwise yesterday. The origin
    itself is included (offset 0). Days are rendered with ``str(date)``,
    i.e. ``YYYY-MM-DD``.
    """
    if args.Day:
        origin = parse_day(args.Day)
    else:
        origin = day_back(datetime.date.today(), 1)
    return {str(day_back(origin, offset)) for offset in range(args.DaysBack)}


def get_computed_days(yt, tables, column='Day'):
    """Return the set of distinct ``column`` values found across ``tables``.

    Tables that do not exist are skipped. Only ``column`` is fetched from each
    table (a YT column projection), rather than entire rows.
    """
    days = set()
    for table in tables:
        if not yt.exists(table):
            continue
        # Read just the needed column instead of whole rows.
        projected = TablePath(table, columns=[column], client=yt)
        days.update(row[column] for row in yt.read_table(projected))
    return days


def join_values_and_forecast(
    yt, cdp_segments_usage, cdp_segments_usage_forecast, cdp_segments_usage_with_forecast
):
    """Left-join forecast rows with actual usage and write the combined table.

    Rows from ``cdp_segments_usage_forecast`` are matched to
    ``cdp_segments_usage`` on (Day, PlaceID, Type). Because the join is a left
    join, forecast days without actuals get a missing ``TotalCostRub``. The
    destination table is removed and recreated with a fixed schema. An empty
    forecast table produces an empty result instead of a merge error.
    """
    usage_df = pd.DataFrame(list(yt.read_table(cdp_segments_usage)))
    forecast_df = pd.DataFrame(list(yt.read_table(cdp_segments_usage_forecast)))
    merge_keys = ['Day', 'PlaceID', 'Type']

    if forecast_df.empty:
        # A frame built from zero rows has no columns, so merging on keys would fail.
        rows = []
    else:
        merged_df = pd.merge(forecast_df, usage_df, on=merge_keys, how='left')
        # 'records' yields one dict per row directly, without transposing the
        # frame (a transpose forces every column through object dtype).
        rows = merged_df.to_dict('records')

    schema = [
        {"name": "Day", "type": "string"},
        {"name": "PlaceID", "type": "int64"},
        {"name": "TotalCostRub", "type": "double"},
        {"name": "Type", "type": "string"},
        {"name": "trend", "type": "double"},
        {"name": "trend_lower", "type": "double"},
        {"name": "trend_upper", "type": "double"},
        {"name": "yhat", "type": "double"},
        {"name": "yhat_lower", "type": "double"},
        {"name": "yhat_upper", "type": "double"},
    ]
    yt.remove(cdp_segments_usage_with_forecast, force=True)
    yt.create('table', cdp_segments_usage_with_forecast, attributes={"schema": schema})
    yt.write_table(cdp_segments_usage_with_forecast, rows)


def main():
    """Entry point: compute missing metrics, then forecast and join them with actuals.

    Steps:
      1. Determine the target days (see ``get_target_days``) that are not yet
         present in ``CdpSegmentsUsage``.
      2. If any remain, run the bundled ``/query/metrics.yql`` query for them.
      3. Forecast ``TotalCostRub`` into ``CdpSegmentsUsageForecast``.
      4. Join forecast and actuals into ``CdpSegmentsUsageWithForecast``.
    """
    logging_helpers.configure_stderr_logger(logging.getLogger())
    args = ParsePbOptions(config.TConfig)
    root = TablePath(args.Paths.OutputDirectory)

    cdp_segments_usage = root.join('CdpSegmentsUsage')
    cdp_segments_usage_forecast = cdp_segments_usage + 'Forecast'
    cdp_segments_usage_with_forecast = cdp_segments_usage + 'WithForecast'

    yql = create_yql_client(args.Yt.Proxy, args.Yt.Token, tmp_folder=args.Paths.TmpDirectory)
    yt = YtClient(proxy=args.Yt.Proxy, token=args.Yt.Token)

    target_days = get_target_days(args)
    computed_days = get_computed_days(yt, [cdp_segments_usage])
    # Sort the days: set iteration order over strings varies between interpreter
    # runs, which made the generated query text and the log line nondeterministic.
    effective_days = sorted(target_days - computed_days)
    logger.info('Going to compute metrics for [%s]', ', '.join(effective_days))

    if effective_days:
        query = rs.find('/query/metrics.yql').decode('utf-8').format(
            days=','.join(map(quote, effective_days)),
            cdp_segments_usage=cdp_segments_usage,
        )
        yql.execute(query, syntax_version=1)

    predict_and_store(yt, cdp_segments_usage, cdp_segments_usage_forecast, value_column='TotalCostRub')
    join_values_and_forecast(
        yt, cdp_segments_usage, cdp_segments_usage_forecast, cdp_segments_usage_with_forecast,
    )
