# -*- coding: utf-8 -*-
from sklearn.model_selection import StratifiedKFold, KFold
from datacloud.ml_utils.common.constants import RANDOM_SEED


def get_stratified_k_fold(random_state=RANDOM_SEED, shuffle=True, **kwargs):
    """Build a ``StratifiedKFold`` splitter seeded with the project default.

    :param random_state: Seed forwarded to ``StratifiedKFold`` when
        ``shuffle`` is true. When ``shuffle`` is false it is replaced by
        ``None``, because a seed has no effect without shuffling.
    :param shuffle: Whether the splitter shuffles the data before
        splitting into folds.
    :param kwargs: Extra keyword arguments passed straight through to
        ``StratifiedKFold`` (e.g. ``n_splits``).
    :return: A configured ``StratifiedKFold`` instance.
    """
    if not shuffle:
        # scikit-learn >= 0.24 raises ValueError when a non-None random_state
        # is combined with shuffle=False; since the default seed is always
        # set here, drop it instead of failing (it would be unused anyway).
        random_state = None
    return StratifiedKFold(
        random_state=random_state,
        shuffle=shuffle,
        **kwargs
    )


def get_k_fold(random_state=RANDOM_SEED, shuffle=True, **kwargs):
    """Build a ``KFold`` splitter seeded with the project default.

    :param random_state: Seed forwarded to ``KFold`` when ``shuffle`` is
        true. When ``shuffle`` is false it is replaced by ``None``, because
        a seed has no effect without shuffling.
    :param shuffle: Whether the splitter shuffles the data before
        splitting into folds.
    :param kwargs: Extra keyword arguments passed straight through to
        ``KFold`` (e.g. ``n_splits``).
    :return: A configured ``KFold`` instance.
    """
    if not shuffle:
        # scikit-learn >= 0.24 raises ValueError when a non-None random_state
        # is combined with shuffle=False; since the default seed is always
        # set here, drop it instead of failing (it would be unused anyway).
        random_state = None
    return KFold(
        random_state=random_state,
        shuffle=shuffle,
        **kwargs
    )
