from abc import ABC, abstractmethod

import numpy as np

from ..io_wrappers.io_wrapper import IOWrapper


class DataPreprocessor(ABC):
    """Abstract base for preprocessors that load their data at construction.

    The constructor calls :meth:`extract_data` for the training split and,
    when ``params["has_valid"]`` is ``True``, for the validation split.  It
    then stores ``len(self.str_to_ind)`` under ``params["num_classes"]``.

    Subclasses must implement :meth:`extract_data` and must make a
    ``str_to_ind`` attribute available (as a class attribute or by setting
    it inside ``extract_data``); this base class only reads its length.
    """

    def __init__(
        self, params: "dict | None" = None, io_wrapper: "IOWrapper | None" = None
    ) -> None:
        """Store the configuration and eagerly load the data splits.

        Args:
            params: Configuration dictionary.  It is kept by reference and
                mutated: ``"num_classes"`` is written into it.  ``None`` is
                treated as an empty configuration.
            io_wrapper: Object stored as ``self.io_wrapper`` for subclasses
                to use; this base class never calls it.
        """
        # Keep the caller's dict object (callers may rely on seeing
        # "num_classes" appear in it); only substitute a dict for None so
        # the lookups below cannot fail with a TypeError.
        self.params = {} if params is None else params

        self.io_wrapper = io_wrapper

        self.train_features, self.train_targets = self.extract_data("train")

        # Strict identity check kept on purpose: only a literal ``True``
        # enables the validation split.  A missing key means "no validation".
        if self.params.get("has_valid") is True:
            # NOTE(review): the validation split is read from the "test"
            # data type -- confirm this is intended.
            self.valid_features, self.valid_targets = self.extract_data("test")

        self.params["num_classes"] = len(self.str_to_ind)

    @abstractmethod
    def extract_data(self, data_type: str) -> tuple:
        """Load one data split.

        Args:
            data_type: Split identifier; this class passes ``"train"`` and
                ``"test"``.

        Returns:
            A ``(features, targets)`` pair.  ``features`` must be a sequence
            that ``np.stack`` accepts (equally shaped array-likes).
        """

    def get_train_generator(self) -> tuple:
        """Return ``(stacked_features, targets)`` for the training split.

        Despite the name, this returns a tuple, not a generator: the
        features are stacked along a new leading axis into one array and
        the targets are returned as stored.
        """
        return (np.stack(self.train_features, axis=0), self.train_targets)

    def get_valid_generator(self) -> tuple:
        """Return ``(stacked_features, targets)`` for the validation split.

        Raises:
            AttributeError: If no validation split was loaded, i.e.
                ``params["has_valid"]`` was not ``True`` at construction.
        """
        if not hasattr(self, "valid_features"):
            # Same exception type as the implicit failure, but with a
            # message that points at the actual cause.
            raise AttributeError(
                "validation data was not loaded; "
                "set params['has_valid'] to True before construction"
            )
        return (np.stack(self.valid_features, axis=0), self.valid_targets)
