import os
import pickle
from datetime import datetime, timezone

import numpy as np
import pandas as pd
import yt.wrapper as yt
from catboost import CatBoostClassifier, Pool
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import train_test_split

import feature_preparation

# Seed shared by the train/test split and the CatBoost model for reproducibility.
RANDOM_STATE = 33
# Base name used for the pickled model file and for the quality table below.
MODEL_NAME = 'catboost_v0'
# NOTE(review): not referenced anywhere in the visible part of this file.
MODEL_VERSION = 0
# YT table to which precision/recall/threshold rows are appended by main().
QUALITY_PATH = f'//home/taxi-delivery/analytics/production/smb_model/{MODEL_NAME}_quality'

def save_obj(obj, name, directory='projects/projects/smb_model/catboost_v0/'):
    """Pickle *obj* into ``<directory>/<name>.pkl``.

    Args:
        obj: Any picklable object.
        name: File name without the ``.pkl`` extension.
        directory: Directory to write into. Defaults to the location that
            used to be hard-coded here; a relative path is resolved against
            the current working directory. The directory must already exist.

    Raises:
        FileNotFoundError: If ``directory`` does not exist.
    """
    target_path = os.path.join(directory, name + '.pkl')
    with open(target_path, 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
          
def get_dt():
    """Return the current UTC date formatted as ``YYYY-MM-DD``."""
    # datetime.utcnow() is deprecated since Python 3.12; an aware UTC
    # timestamp formats to exactly the same string.
    return datetime.now(timezone.utc).strftime('%Y-%m-%d')

def get_dttm():
    """Return the current UTC timestamp formatted as ``YYYY-MM-DD HH:MM:SS``."""
    # datetime.utcnow() is deprecated since Python 3.12; an aware UTC
    # timestamp formats to exactly the same string.
    return datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')

def save_to_yt(path, table, schema):
    """Append the rows of *table* to the YT table at *path*.

    The table is created with *schema* first when it does not exist yet.

    Args:
        path: YT path of the destination table.
        table: Object exposing ``to_dict('records')`` (``main`` passes a
            pandas DataFrame).
        schema: YT schema applied only when the table has to be created.
    """
    # Sets the proxy on the module-level yt config, so it affects every
    # later yt call in this process as well.
    yt.config["proxy"]["url"] = "hahn.yt.yandex.net"

    rows = table.to_dict('records')

    table_missing = not yt.exists(path)
    if table_missing:
        table_attributes = {'schema': schema, 'optimize_for': 'scan'}
        yt.create('table', path, force=True, attributes=table_attributes)

    destination = yt.TablePath(path, append=True)
    yt.write_table(destination, rows)

        
def main():
    """Train the CatBoost classifier, pickle it and publish its PR curve.

    Steps:
      1. Load the labelled feature table via ``feature_preparation.get_data``.
      2. Split it 75/25 into train and test parts with a fixed seed.
      3. Fit a ``CatBoostClassifier``, treating every column whose dtype is
         not float64 as categorical.
      4. Pickle the fitted model with ``save_obj`` under ``MODEL_NAME``.
      5. Compute precision/recall/threshold on the test part and append the
         rows to the YT table at ``QUALITY_PATH``.
    """
    FEATURES_W_LABELS_PATH = (
        '//home/taxi-delivery/analytics/dev/griganton/LOGDATA-617/smb_features_year_new_2020-08-02_2021-07-31'
    )

    # Presumably returns a pandas DataFrame (it is indexed by column lists
    # below) -- confirm against feature_preparation.
    df = feature_preparation.get_data(FEATURES_W_LABELS_PATH)

    features_list = feature_preparation.FEATURES_LIST

    X = df[features_list]
    y = df['is_smb']
    # `np.float` was only an alias of the builtin float (i.e. float64) and
    # was removed in NumPy 1.24; np.float64 is the exact equivalent.
    categorical_features_indices = np.where(X.dtypes != np.float64)[0]

    X_train, X_test, y_train, y_test = train_test_split(
        X, y,
        test_size=0.25,
        random_state=RANDOM_STATE,
    )

    model = CatBoostClassifier(
        iterations=1000,
        depth=7,
        l2_leaf_reg=500.0,
        random_strength=0.9,
        eval_metric='AUC',
        random_seed=RANDOM_STATE,
        use_best_model=True,
    )

    # NOTE(review): the test split doubles as the eval set used by
    # use_best_model to pick the best iteration, so the quality numbers
    # computed below come from data that influenced the final model.
    # Consider a separate validation split.
    model.fit(
        X_train, y_train,
        cat_features=categorical_features_indices,
        eval_set=(X_test, y_test),
        verbose=False,
    )
    print(f'{get_dttm()} - Model is fitted: {model.is_fitted()}')

    y_score = model.predict_proba(X_test)
    print(f'{get_dttm()} - Score calculated')

    save_obj(model, MODEL_NAME)
    print(f'{get_dttm()} - Model saved with name: {MODEL_NAME}')

    # Column 1 of predict_proba is the positive-class probability.
    precision, recall, threshold = precision_recall_curve(y_test, y_score[:, 1])
    # precision/recall have one more element than threshold (the final
    # point has no threshold), so drop it to align the columns.
    quality = pd.DataFrame({
        'precision': precision[:-1],
        'recall': recall[:-1],
        'threshold': threshold,
    })

    schema = [
        {'name': 'precision', 'type': 'float', 'required': True},
        {'name': 'recall', 'type': 'float', 'required': True},
        {'name': 'threshold', 'type': 'float', 'required': True},
    ]

    save_to_yt(QUALITY_PATH, quality, schema)
    print(f'{get_dttm()} - Quality saved')

    
# Run the training pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
