#!/usr/bin/env python
# coding: utf-8

# In[1]:

from prophet import Prophet
from prophet.plot import add_changepoints_to_plot

import pandas as pd
import numpy as np

from business_models import greenplum
from sklearn.linear_model import LinearRegression

# In[2]:


# Daily facts per (date, tariff, country, tier), pivoted so that each metric
# (deliveries, gmv, net_inflow) becomes its own column. Rows on the most recent
# date that has facts are excluded by the `date < max(date)` filter
# (presumably because that day is still incomplete -- TODO confirm).
GROWTH_BY_TIER_QUERY = """

        select date as ds,
            tariff,
            country,
            tier,
            sum(fact) filter (where metric='deliveries') as deliveries,
            sum(fact) filter (where metric='gmv') as gmv,
            sum(fact) filter (where metric='net_inflow') as net_inflow
        from snb_delivery.data_growth_by_tier
        where
            date is not null
            and date < (select max(date) from snb_delivery.data_growth_by_tier where fact is not null)
            and fact is not null
        group by
            1, 2, 3, 4

"""

df = greenplum(GROWTH_BY_TIER_QUERY)

# Prophet and the date arithmetic below need real timestamps in `ds`.
df['ds'] = pd.to_datetime(df['ds'])


# In[3]:


# One row per (tariff, country, tier) combination that appears in the
# deliveries pivot; only the key columns are kept, the per-date values are
# discarded.
group_keys = ['tariff','country','tier']
deliveries_by_day = df.pivot_table(index=group_keys, columns='ds', values='deliveries', fill_value=0)
df_pivot = deliveries_by_day.reset_index()[group_keys]
df_pivot


# In[4]:


# Country names as they appear in the `country` column, mapped to the country
# codes passed to Prophet's `add_country_holidays`.
COUNTRIES = {
    'Белоруссия': 'BY',
    'Великобритания': 'GB',
    'Грузия': 'GE',
    'Израиль': 'IL',
    'Латвия': 'LV',
    'Россия': 'RU',
    'Финляндия': 'FI',
    'Чили': 'CL'
}

FIRST_FORECAST_DAY = df['ds'].max() + pd.Timedelta("1 day")
DAYS_FORECAST = 150
THRESHOLD = 200
AVG_WINDOW_DAYS = 28
METRICS = ['deliveries','gmv','net_inflow']

# The forecast horizon is the same for every slice, so build it once up front.
# (Previously it was created in the middle of the loop body, after the
# short-history fallback had already tried to read it.)
date_range = pd.date_range(FIRST_FORECAST_DAY, periods=DAYS_FORECAST, freq="D")
fcst_range = np.array(date_range, dtype='float').reshape(-1,1)
future_dates = pd.DataFrame({'ds': date_range})

# Per-slice forecast frames are collected here and concatenated once at the
# end; DataFrame.append was removed in pandas 2.0 and re-copied the whole
# frame on every iteration.
forecast_parts = []
# NOTE(review): only the short-history fallback records a placeholder here;
# fitted models and their predictions are not stored. Confirm whether these
# two lists are still needed.
prophet_models = []
prophet_forecasts = []

for metric in METRICS:

    for i, r in df_pivot.iterrows():
        tariff, country, tier = list(r)
        in_slice = (df['tariff']==tariff)&(df['country']==country)&(df['tier']==tier)
        df_slice = (
            df.loc[in_slice, ['ds', metric]]
            .sort_values(by='ds')
            .fillna(0)
            .rename(columns={metric: 'y'})
        )
        if df_slice.empty:
            # Nothing to average or regress on for this combination.
            continue

        has_enough_history = df_slice.shape[0] > AVG_WINDOW_DAYS

        # Average over the last AVG_WINDOW_DAYS observations (or over the
        # whole history when it is shorter than the window).
        recent = df_slice.tail(AVG_WINDOW_DAYS)
        avg_window = round(recent['y'].mean())
        print("{}, {}. Last average: {:.1f}".format(i, list(r), avg_window))

        # First model: Prophet, fitted only when there is more history than
        # the averaging window; otherwise fall back to the flat average.
        if has_enough_history:
            m = Prophet(changepoint_range=0.95)
            if country in COUNTRIES:
                m.add_country_holidays(country_name=COUNTRIES[country])
            if df_slice.shape[0] > 400:
                m.add_seasonality(name='yearly', period=365, fourier_order=26)
            m.fit(df_slice, verbose=False)
            # Predict on the shared horizon so the values line up with
            # `date_range` even when this slice's history ends before the
            # global last day.
            values_ml = m.predict(future_dates)['yhat'].to_numpy()
        else:
            values_ml = np.full(DAYS_FORECAST, avg_window)
            prophet_models.append(0)
            prophet_forecasts.append(0)

        # Second model: a straight line fitted to the last AVG_WINDOW_DAYS
        # observations and extrapolated over the forecast horizon.
        reg = LinearRegression().fit(
            np.array(recent['ds'], dtype='float').reshape(-1,1),
            recent['y'],
        )
        values_trend = reg.predict(fcst_range)

        forecast_parts.append(pd.DataFrame(
            {'metric': metric,
             'country': country,
             'tariff': tariff,
             'tier': tier,
             'date': date_range,
             'value_ML': values_ml,
             'value_trend': values_trend,
             # Midpoint between the linear trend and the flat recent average.
             'value_halftrend': (values_trend + avg_window) / 2,
             'value_avg': avg_window}
        ))

# Each part keeps its own 0..DAYS_FORECAST-1 index, as repeated append did.
forecast = pd.concat(forecast_parts) if forecast_parts else pd.DataFrame()

# forecast.to_csv('forecast_tier.csv')


# # In[11]:


# forecast = pd.read_csv('forecast_tier.csv')
# forecast['date'] = pd.to_datetime(forecast['date'])
# forecast


# In[13]:


def positive_int(x):
    """Round every element of *x* to the nearest integer and floor it at zero.

    Rounding uses the built-in ``round`` element by element; the result is a
    NumPy array in which negative values have been replaced by 0.
    """
    rounded = [round(value) for value in x]
    return np.clip(rounded, 0, None)

# Split the forecast by metric. The deliveries and gmv forecasts are rounded
# and floored at zero via positive_int; net_inflow is left untouched.
value_columns = ['value_ML','value_trend','value_halftrend']
metric_of_row = forecast['metric']

df_deliveries = forecast[metric_of_row == 'deliveries'].copy()
df_gmv = forecast[metric_of_row == 'gmv'].copy()
df_other = forecast[metric_of_row.isin(['net_inflow'])].copy()

for frame in (df_deliveries, df_gmv):
    frame.loc[:, value_columns] = frame.loc[:, value_columns].apply(positive_int)


# In[14]:


# Stack the per-metric frames back together, then move `metric` from the rows
# into the columns so each (country, tariff, tier, date) becomes a single row.
df_output = (
    pd.concat([df_deliveries, df_gmv, df_other], axis=0)
    .set_index(['country','tariff','tier','date','metric'])
    .unstack('metric')
)
# Flatten the two-level (value column, metric) header into "<value>_<metric>".
df_output.columns = [f"{value_name}_{metric_name}" for value_name, metric_name in df_output.columns]
df_output = df_output.reset_index()


# In[16]:


# Replace the published forecast table: drop the previous version if it
# exists, then write the new frame and grant select on it to public.
OUTPUT_TABLE = 'snb_delivery.data_growth_by_tier_ml'

greenplum(f'drop table if exists {OUTPUT_TABLE}')
greenplum.write_table(
    OUTPUT_TABLE,
    df_output,
    with_grant=True,
    operator='select',
    to='public',
)
