from datetime import datetime

from crm.agency_cabinet.common.consts.service import Services
from .base import UpdateDataLoader


class PredictDataLoader(UpdateDataLoader):
    """Load prediction rows for a single service within a date period.

    A source row is kept only when its contract id passes
    ``_check_contract_id``, its ``discount_name`` equals the name mapped
    from the loader's service, and its ``dt`` date falls inside the
    half-open interval ``[period_from, period_to)``.

    Subclasses must provide ``_make_table_row``; the version here raises
    ``NotImplementedError``.
    """

    def __init__(self, period_from, period_to, *args, **kwargs):
        """Store the period and pick the discount name for the service.

        Args:
            period_from: Inclusive lower bound of the period. It is compared
                with the naive ``datetime`` parsed from each row's ``dt``.
            period_to: Exclusive upper bound of the period.
            *args: Forwarded unchanged to the base initializer.
            **kwargs: Forwarded unchanged to the base initializer.

        Raises:
            NotImplementedError: If ``self.service`` is not one of the
                direct, media or video services.
        """
        self.period_from = period_from
        self.period_to = period_to

        # NOTE(review): ``self.service`` is read before the base initializer
        # runs, so it is presumably a class-level attribute of the concrete
        # subclass -- confirm against the base class.
        if self.service == Services.direct.value:
            self.discount_name = 'Директ'
        elif self.service == Services.media.value:
            self.discount_name = 'Медийка'
        elif self.service == Services.video.value:
            self.discount_name = 'Видео'
        else:
            raise NotImplementedError(
                f'Unsupported service for predict data: {self.service!r}'
            )

        super().__init__(*args, **kwargs)

    @staticmethod
    def _make_table_row(row):
        """Convert a source row into a table entry; subclasses must override."""
        raise NotImplementedError()

    @staticmethod
    def _get_contract_id(yt_row) -> int:
        """Return the row's ``contract_id``, or ``orig_contract_id`` if absent.

        Raises:
            KeyError: If the row has neither key.
        """
        if 'contract_id' in yt_row:
            return yt_row['contract_id']
        return yt_row['orig_contract_id']

    @staticmethod
    def _get_period_from(row) -> datetime:
        """Parse the row's ``dt`` value (``YYYY-MM-DD``) into a ``datetime``."""
        return datetime.strptime(row['dt'], '%Y-%m-%d')

    def _filter_row(self, row) -> bool:
        """Tell whether ``row`` belongs to this loader's service and period.

        The checks run in order: contract id, discount name, then date, and
        the first failing check rejects the row.
        """
        # Helpers are resolved through ``self`` (like ``_make_table_row``)
        # so that subclass overrides are honoured.
        contract_id = self._get_contract_id(row)
        # ``_check_contract_id`` is not defined in this class; presumably it
        # is inherited from the base loader.
        if not self._check_contract_id(contract_id):
            return False

        if row['discount_name'] != self.discount_name:
            return False

        row_period_from = self._get_period_from(row)
        return self.period_from <= row_period_from < self.period_to

    def _read_table(self, **kwargs):
        """Read the source table and store the converted matching rows.

        ``self._table`` is reset to an empty list and then filled with
        ``_make_table_row(row)`` for every row accepted by ``_filter_row``.

        Args:
            **kwargs: Forwarded unchanged to ``self.client.read_table``.
        """
        table_path = self.table_path
        # Drop the trailing slash only when it is really there; slicing
        # unconditionally would cut the last character of a path that has
        # no trailing slash.
        if table_path.endswith('/'):
            table_path = table_path[:-1]

        self._table = []
        for row in self.client.read_table(table_path, **kwargs):
            if self._filter_row(row):
                self._table.append(self._make_table_row(row))
