#!/usr/bin/env python
# coding: utf-8

# In[1]:


from prophet import Prophet
from prophet.plot import add_changepoints_to_plot

import pandas as pd
import numpy as np

from business_models import greenplum

from sklearn.linear_model import LinearRegression

import logging
_LOG_FORMAT = '%(asctime)s %(levelname)s %(name)s %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
logger = logging.getLogger(__name__)
# Only let Prophet's errors through; its lower-level messages are suppressed.
logging.getLogger('prophet').setLevel(logging.ERROR)


# In[3]:


# Daily aggregates of every financial metric per (tariff, country, city_group,
# city, business) slice. `date` is aliased to `ds`, the date column name that
# the Prophet fit further down consumes.
_FINANCE_SOURCE_SQL = """

        select
            date as ds,
            tariff,
            country,
            city_group,
            city,
            business,
            sum(deliveries) as deliveries,
            sum(gmv) as gmv,
            sum(gmv_cards) as gmv_cards,
            sum(commissions) as commissions,
            sum(subsidies) as subsidies,
            sum(decoupling) as decoupling,
            sum(net_inflow) as net_inflow
        from snb_delivery.dash_finance_estimate_source
        where date is not null
        group by 1,2,3,4,5,6

"""
df = greenplum(_FINANCE_SOURCE_SQL)

# Parse the date column into datetime64 so it can be compared and regressed on.
df['ds'] = pd.to_datetime(df['ds'])


# In[4]:


# Columns that identify one forecast slice; their order matters, since rows of
# df_pivot are unpacked positionally in this order.
PARAMETERS = [
    'tariff',
    'country',
    'city_group',
    'city',
    'business',
]


# In[5]:


# Distinct slice combinations (one row per PARAMETERS tuple) to forecast.
# The pivot is only used for its row index: the per-date columns are discarded
# by the final [PARAMETERS] selection.
# NOTE(review): pivot_table's default dropna=True appears to also drop
# combinations with NaN keys and those whose `deliveries` is NaN on every date,
# so such slices are skipped for *all* metrics — confirm this is intended.
# Rows come out in pivot_table's group order (sorted by default).
df_pivot = df.pivot_table(index=PARAMETERS, columns='ds', values='deliveries', fill_value=0).reset_index()[PARAMETERS]

# In[6]:


# Country names as they appear in the `country` column (Russian) mapped to the
# country codes passed to Prophet's add_country_holidays.
COUNTRIES = {
    'Белоруссия':     'BY',  # Belarus
    'Великобритания': 'GB',  # United Kingdom
    'Грузия':         'GE',  # Georgia
    'Израиль':        'IL',  # Israel
    'Латвия':         'LV',  # Latvia
    'Россия':         'RU',  # Russia
    'Финляндия':      'FI',  # Finland
    'Чили':           'CL',  # Chile
}

# The forecast horizon starts the day after the latest date in the source data.
FIRST_FORECAST_DAY = df['ds'].max() + pd.Timedelta(days=1)
DAYS_FORECAST = 40    # length of the forecast horizon, in days
AVG_WINDOW_DAYS = 28  # trailing window for the average and the linear-trend fit
THRESHOLD = 200       # NOTE(review): not referenced anywhere in this script's visible code
# Metric columns to forecast (net_inflow is queried but not forecast).
METRICS = ['deliveries', 'gmv', 'gmv_cards', 'commissions', 'subsidies', 'decoupling']


# In[7]:


# Shared forecast horizon: the same calendar dates for every slice and metric.
date_range = pd.date_range(FIRST_FORECAST_DAY, periods=DAYS_FORECAST, freq="D")


def _prophet_forecast(history, country, dates, avg_window):
    """Return Prophet's ``yhat`` for ``dates`` as a Series aligned positionally with them.

    ``history`` must carry Prophet's ``ds``/``y`` columns. Histories of at most
    AVG_WINDOW_DAYS rows are not fitted; a flat line at ``avg_window`` is
    returned instead.
    """
    if history.shape[0] <= AVG_WINDOW_DAYS:
        return pd.Series([avg_window] * len(dates))
    model = Prophet(changepoint_range=0.95)
    if country in COUNTRIES:
        model.add_country_holidays(country_name=COUNTRIES[country])
    if history.shape[0] > 400:
        # Long histories get a custom, higher-order yearly seasonality.
        model.add_seasonality(name='yearly', period=365, fourier_order=26)
    model.fit(history, verbose=False)
    # Predict on the shared forecast dates directly. make_future_dataframe()
    # would start the day after *this* slice's last observation, so a slice
    # ending before the global max date got ML values shifted against 'date'.
    predicted = model.predict(pd.DataFrame({'ds': dates}))
    return predicted['yhat'].reset_index(drop=True)


def _linear_trend(history, dates):
    """Fit OLS to the last AVG_WINDOW_DAYS points of ``history`` and evaluate it on ``dates``."""
    tail = history.iloc[-AVG_WINDOW_DAYS:]
    # Dates are regressed as their float epoch offsets (ns for datetime64[ns]).
    x_train = np.array(tail['ds'], dtype='float').reshape(-1, 1)
    x_future = np.array(dates, dtype='float').reshape(-1, 1)
    reg = LinearRegression().fit(x_train, tail['y'])
    return pd.Series(reg.predict(x_future))


# Group once instead of re-scanning the whole frame with five boolean masks
# for every (metric, slice) pair.
_slices = df.groupby(PARAMETERS, sort=False)

forecast_parts = []
for metric in METRICS:
    for i, key in enumerate(df_pivot.itertuples(index=False, name=None)):
        tariff, country, city_group, city, business = key
        df_slice = (
            _slices.get_group(key)[['ds', metric]]
            .sort_values(by='ds')
            .fillna(0)
            .rename(columns={metric: 'y'})
        )

        # Mean over the last AVG_WINDOW_DAYS days (whole history if shorter).
        if df_slice.shape[0] > AVG_WINDOW_DAYS:
            avg_window = round(df_slice['y'].rolling(window=AVG_WINDOW_DAYS).mean().iloc[-1])
        else:
            avg_window = round(df_slice['y'].mean())
        logger.info("%s, %s, %s. Last average: %.1f", metric, i, list(key), avg_window)

        # First Prophet, then a linear trend over the trailing window.
        ml_values = _prophet_forecast(df_slice, country, date_range, avg_window)
        trend_values = _linear_trend(df_slice, date_range)

        forecast_parts.append(pd.DataFrame({
            'metric': metric,
            'tariff': tariff,
            'country': country,
            'city_group': city_group,
            'city': city,
            'business': business,
            'date': pd.Series(date_range),
            'ML': ml_values,
            'trend': trend_values,
            'halftrend': (trend_values + avg_window) / 2,
            'avg': avg_window,
        }))

# DataFrame.append was removed in pandas 2.0; concatenate once at the end
# (also avoids quadratic re-copying). Per-slice 0..DAYS_FORECAST-1 indexes are
# kept, matching the previous CSV layout.
forecast = pd.concat(forecast_parts) if forecast_parts else pd.DataFrame()
forecast.to_csv('forecast.csv')


# In[6]:


# Reload the forecast written above (presumably so the later cells can be
# re-run without re-fitting). The first CSV column is the unnamed index that
# to_csv wrote; discard it.
forecast = pd.read_csv('forecast.csv', index_col=False)
forecast = forecast.drop(columns='Unnamed: 0')


# In[7]:


# The CSV round-trip leaves dates as strings; parse them back to datetime64.
forecast = forecast.assign(date=pd.to_datetime(forecast['date']))


# In[8]:


def positive_int(x):
    """Round every element of ``x`` with built-in ``round`` and clamp negatives to 0.

    Built-in ``round`` rounds halves to even (``round(2.5) == 2``).
    Returns a NumPy array.
    """
    as_ints = list(map(round, x))
    return np.clip(as_ints, a_min=0, a_max=None)

_VALUE_COLUMNS = ['ML', 'trend', 'halftrend', 'avg']

df_deliveries = forecast.loc[forecast['metric'] == 'deliveries'].copy()
df_gmv = forecast.loc[forecast['metric'].isin(['gmv', 'gmv_cards'])].copy()
df_other = forecast.loc[forecast['metric'].isin(['commissions', 'subsidies', 'decoupling'])].copy()

# Deliveries and GMV forecasts are rounded to integers and clamped at zero;
# commissions/subsidies/decoupling keep their raw float values.
for _frame in (df_deliveries, df_gmv):
    _frame.loc[:, _VALUE_COLUMNS] = _frame.loc[:, _VALUE_COLUMNS].apply(positive_int)


# In[9]:


# Long -> wide: one row per slice and date, one column per (value kind, metric).
df_output = (
    pd.concat([df_deliveries, df_gmv, df_other], axis=0)
    .set_index(PARAMETERS + ['date', 'metric'])
    .unstack('metric')
)
# Flatten column pairs such as ('ML', 'gmv') into 'ML_gmv'.
df_output.columns = ['_'.join(pair) for pair in df_output.columns]
df_output = df_output.reset_index()


# In[10]:


def mix_(a):
    """Move the first '_'-separated token to the end: ``'trend_gmv'`` -> ``'gmv_trend'``.

    Only the first underscore splits, so ``'ML_gmv_cards'`` -> ``'gmv_cards_ML'``.
    Raises IndexError if ``a`` contains no underscore.
    """
    parts = a.split('_', maxsplit=1)
    head, tail = parts[0], parts[1]
    return f'{tail}_{head}'
# Rename '<kind>_<metric>' columns to '<metric>_<kind>'; slice keys and 'date'
# do not end with a metric name and are left unchanged.
_metric_suffixes = tuple(METRICS)
df_output.columns = [
    mix_(col) if col.endswith(_metric_suffixes) else col
    for col in df_output.columns
]


# In[12]:


# Move 'date' to the front, keeping every other column in its current order.
# Selecting by name (rather than the former hard-coded cols[:5] + cols[6:],
# which assumed exactly five PARAMETERS before 'date') stays correct if
# PARAMETERS changes length.
cols = df_output.columns.tolist()
cols = ['date'] + [col for col in cols if col != 'date']
df_output = df_output[cols]


# In[13]:


# Replace the target table with df_output. Judging by the with_grant/operator/to
# arguments, SELECT is presumably granted to public — confirm against greenplum().
# NOTE(review): not atomic — if write_table fails after the drop, the table is gone.
_TARGET_TABLE = 'snb_delivery.data_ml_estimate'
greenplum('drop table if exists ' + _TARGET_TABLE)
greenplum.write_table(_TARGET_TABLE, df_output, with_grant=True, operator='select', to='public')
