import pickle
from datetime import datetime, timezone

import pandas as pd
import yt.wrapper as yt

import feature_preparation

# Name of the pickled model: load_obj() uses it as the file stem and main()
# writes it to the 'model_name' column of the results.
MODEL_NAME = 'catboost_v0'
# Written by main() to the 'model_version' column of the results.
MODEL_VERSION = 0

# Path handed to feature_preparation.get_data() as the source of user features
# (presumably a YT table, judging by the '//home/...' form -- confirm there).
FEATURES_PATH = '//home/taxi-delivery/analytics/production/smb_model/users_features'
# YT table that main() appends the predicted probabilities to via save_to_yt().
RESULTS_PATH = '//home/taxi-delivery/analytics/production/smb_model/users_probas'
        
def get_dt():
    """Return the current UTC date formatted as ``YYYY-MM-DD``.

    A timezone-aware clock is used because ``datetime.utcnow()`` is
    deprecated since Python 3.12; the formatted string is unchanged.
    """
    return datetime.now(timezone.utc).strftime('%Y-%m-%d')

def get_dttm():
    """Return the current UTC timestamp formatted as ``YYYY-MM-DD HH:MM:SS``.

    A timezone-aware clock is used because ``datetime.utcnow()`` is
    deprecated since Python 3.12; the formatted string is unchanged.
    """
    return datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')

def load_obj(name):
    """Unpickle and return the object stored under ``name``.

    The file read is ``projects/projects/smb_model/catboost_v0/<name>.pkl``,
    resolved relative to the current working directory.

    NOTE: ``pickle.load`` executes code embedded in the file, so the file
    must come from a trusted source.
    """
    pickle_path = ''.join(
        ('projects/projects/smb_model/catboost_v0/', name, '.pkl')
    )
    with open(pickle_path, 'rb') as source:
        loaded = pickle.load(source)
    return loaded
    
def save_to_yt(path, table, schema):
    """Append the rows of ``table`` to the YT table at ``path``.

    The YT client proxy is set to ``hahn.yt.yandex.net``. When ``path`` does
    not exist yet, the table is created first with ``schema`` and
    ``optimize_for='scan'``; otherwise ``schema`` is not used.

    Args:
        path: YT path of the destination table.
        table: Object whose ``to_dict('records')`` yields the rows to write
            (``main`` passes a pandas DataFrame).
        schema: Column schema applied only when the table has to be created.
    """
    yt.config["proxy"]["url"] = "hahn.yt.yandex.net"

    # Rows are materialised before any YT call, as one dict per row.
    rows = table.to_dict('records')

    table_is_missing = not yt.exists(path)
    if table_is_missing:
        creation_attributes = {'schema': schema, 'optimize_for': 'scan'}
        yt.create('table', path, force=True, attributes=creation_attributes)

    destination = yt.TablePath(path, append=True)
    yt.write_table(destination, rows)
    
def main():
    """Score the feature table with the pickled model and append results to YT.

    Steps:
      1. Load the model named ``MODEL_NAME`` with ``load_obj``.
      2. Read features from ``FEATURES_PATH`` via
         ``feature_preparation.get_data``.
      3. Call ``predict_proba`` on the ``feature_preparation.FEATURES_LIST``
         columns.
      4. Append one result row per feature row (both class probabilities,
         ``user_phone_pd_id``, the UTC prediction date, model name and
         version) to ``RESULTS_PATH``.

    Progress is reported on stdout with UTC timestamps.
    """
    model = load_obj(MODEL_NAME)
    print(f'{get_dttm()} - Model loaded')

    df = feature_preparation.get_data(FEATURES_PATH)
    print(f'{get_dttm()} - Data loaded')

    X = df[feature_preparation.FEATURES_LIST]
    print(f'{get_dttm()} - Features ready')
    user_phone_pd_id = df['user_phone_pd_id']

    predicted_probas = model.predict_proba(X)
    prediction_dt = get_dt()
    print(f'{get_dttm()} - Prediction made')

    # The probability columns are named after model.classes_ in order:
    # label 0 is "not SMB", any other label is "SMB". Only two classes are
    # supported; a different number of probability columns makes the
    # DataFrame constructor fail.
    proba_columns = [
        'not_smb_proba' if class_ == 0 else 'smb_proba'
        for class_ in model.classes_[:2]
    ]
    new_df = pd.DataFrame(predicted_probas, columns=proba_columns)

    # Attach ids positionally. Row i of the probabilities belongs to row i of
    # df, but df's index labels need not be 0..n-1, so a label-aligned
    # pd.concat(axis=1) could misalign rows or introduce NaNs.
    new_df['user_phone_pd_id'] = user_phone_pd_id.values
    new_df['utc_prediction_dt'] = prediction_dt
    new_df['model_name'] = MODEL_NAME
    new_df['model_version'] = MODEL_VERSION

    # Used by save_to_yt only when the results table does not exist yet.
    schema = [
        {'name': 'model_name', 'type': 'string', 'required': True},
        {'name': 'model_version', 'type': 'int32', 'required': True},
        {'name': 'not_smb_proba', 'type': 'float', 'required': True},
        {'name': 'smb_proba', 'type': 'float', 'required': True},
        {'name': 'user_phone_pd_id', 'type': 'string', 'required': True},
        {'name': 'utc_prediction_dt', 'type': 'string', 'required': True},
    ]

    save_to_yt(RESULTS_PATH, new_df, schema)
    print(f'{get_dttm()} - Prediction saved')
    
# Run the scoring pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
