import re
import numpy as np
from yt.wrapper import ypath_join
import datacloud.config.yt as yt_path_config
from datacloud.features.phone_range import extractor as phone_range_extractor
from datacloud.dev_utils.data import data_utils as du
from datacloud.dev_utils.yt import yt_utils
from datacloud.dev_utils.time.patterns import RE_DAILY_LOG_FORMAT


# Feature identifiers, passed as ``feature_name`` to ``Feature.__init__`` by the
# subclasses below.
DSSM_NAME = 'DSSM'
CLUSTER_NAME = 'CLUSTER'
COUNT_VECTORIZER_NAME = 'COUNT_VECTORIZER'
GEO_NAME = 'GEO'
# NOTE(review): unlike the other *_NAME constants, this value keeps the '_NAME'
# suffix. It may be relied on elsewhere (e.g. as a stored key) -- confirm before
# normalizing it.
LOCATIONS_NAME = 'LOCATIONS_NAME'
NORMED_S2V_NAME = 'NORMED_S2V'
CONTACT_ACTION_NAME = 'CONTACT_ACTIONS'
TIME_HIST_NAME = 'TIME_HIST'
PHONE_RANGE_NAME = 'PHONE_RANGE'

# Number of values in each feature vector, passed as ``count`` to
# ``Feature.__init__`` (``Feature.fill_with_default`` builds arrays of this length).
DSSM_COUNT = 400
CLUSTER_COUNT = 767  # base 763 + 4 from raiff
COUNT_VECTORIZER_COUNT = 1152
GEO_COUNT = 18
LOCATIONS_COUNT = 173  # base 63, 20 + 30 by country, 10 + 50 by region (63 + 20 + 30 + 10 + 50 = 173)
NORMED_S2V_COUNT = 512
CONTACT_ACTION_COUNT = 50
TIME_HIST_COUNT = 25
PHONE_RANGE_COUNT = 3


class Feature(object):
    """Base descriptor of a fixed-length feature vector.

    Subclasses fix ``feature_name`` and ``count``. Where the data lives in YT,
    they also override ``get_weekly_folder`` and ``get_last_weekly_table``; the
    base versions raise ``NotImplementedError``.
    """

    def __init__(self, feature_name, count, ext_id_key='cid', default=None, yt_table=None):
        """
        :param feature_name: identifier of the feature (one of the ``*_NAME`` constants).
        :param count: number of values in the feature vector.
        :param ext_id_key: stored as-is; presumably the record key holding the
            external id -- confirm against callers.
        :param default: value used by ``fill_with_default``; ``None`` means "no default".
        :param yt_table: used only in input_pipeline.
        """
        self.feature_name = feature_name
        self.ext_id_key = ext_id_key
        self.count = count
        self.default = default
        self.yt_table = yt_table

    def extract(self, rec):
        """Decode ``rec['features']`` with ``du.array_fromstring``."""
        return du.array_fromstring(rec['features'])

    def has_default(self):
        """Return True when a default value (including falsy ones such as 0) is set."""
        return self.default is not None

    def fill_with_default(self):
        """Return a numpy array holding ``count`` copies of ``default``.

        :raises ValueError: if no default value was configured.
        """
        # An explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently yield an object array of None.
        if self.default is None:
            raise ValueError('{} has no default value to fill with'.format(self))
        return np.array([self.default for _ in range(self.count)])

    def get_table_for_date(self, date_str):
        """Return the path of the ``date_str`` node inside the weekly folder."""
        return ypath_join(self.get_weekly_folder(), date_str)

    @classmethod
    def get_weekly_folder(cls):
        """Return the YT folder holding this feature's tables (subclass hook)."""
        raise NotImplementedError()

    @classmethod
    def get_last_weekly_table(cls):
        """Return the path of the latest table for this feature (subclass hook)."""
        raise NotImplementedError()

    @classmethod
    def get_ready_dates(cls, yt_client):
        """Return the names under the weekly folder that match RE_DAILY_LOG_FORMAT.

        Names are returned in the order ``yt_client.list`` yields them.
        """
        return [
            name for name in yt_client.list(cls.get_weekly_folder())
            if re.match(RE_DAILY_LOG_FORMAT, name)
        ]

    def __str__(self):
        return type(self).__name__

    def __repr__(self):
        return self.__str__()


class DSSMFeature(Feature):
    """DSSM feature vector (DSSM_COUNT values) stored under AGGREGATES_FOLDER/dssm/weekly."""

    def __init__(self, ext_id_key='cid', default=None, yt_table=None):
        super(DSSMFeature, self).__init__(
            feature_name=DSSM_NAME,
            count=DSSM_COUNT,
            ext_id_key=ext_id_key,
            default=default,
            yt_table=yt_table,
        )

    @classmethod
    def get_weekly_folder(cls):
        """Return the YT folder holding the weekly DSSM tables."""
        aggregates_root = yt_path_config.AGGREGATES_FOLDER
        return ypath_join(aggregates_root, 'dssm/weekly')

    @classmethod
    def get_last_weekly_table(cls):
        """Return the table ``yt_utils.get_last_table`` picks (presumably the latest) in the weekly folder."""
        weekly_folder = cls.get_weekly_folder()
        return yt_utils.get_last_table(weekly_folder)


class ClusterFeature(Feature):
    """Cluster feature vector (CLUSTER_COUNT values) stored under AGGREGATES_FOLDER/cluster/user2clust."""

    def __init__(self, ext_id_key='cid', default=None, yt_table=None):
        super(ClusterFeature, self).__init__(
            feature_name=CLUSTER_NAME,
            count=CLUSTER_COUNT,
            ext_id_key=ext_id_key,
            default=default,
            yt_table=yt_table,
        )

    @classmethod
    def get_weekly_folder(cls):
        """Return the YT folder holding the user-to-cluster tables."""
        aggregates_root = yt_path_config.AGGREGATES_FOLDER
        return ypath_join(aggregates_root, 'cluster/user2clust')

    @classmethod
    def get_last_weekly_table(cls):
        """Return the table ``yt_utils.get_last_table`` picks (presumably the latest) in the weekly folder."""
        weekly_folder = cls.get_weekly_folder()
        return yt_utils.get_last_table(weekly_folder)


class NormedS2VFeature(Feature):
    """Normed S2V feature vector (NORMED_S2V_COUNT values) stored under AGGREGATES_FOLDER/normed_s2v/weekly."""

    def __init__(self, ext_id_key='cid', default=None, yt_table=None):
        super(NormedS2VFeature, self).__init__(
            feature_name=NORMED_S2V_NAME,
            count=NORMED_S2V_COUNT,
            ext_id_key=ext_id_key,
            default=default,
            yt_table=yt_table,
        )

    @classmethod
    def get_weekly_folder(cls):
        """Return the YT folder holding the weekly normed S2V tables."""
        aggregates_root = yt_path_config.AGGREGATES_FOLDER
        return ypath_join(aggregates_root, 'normed_s2v/weekly')

    @classmethod
    def get_last_weekly_table(cls):
        """Return the table ``yt_utils.get_last_table`` picks (presumably the latest) in the weekly folder."""
        weekly_folder = cls.get_weekly_folder()
        return yt_utils.get_last_table(weekly_folder)


class ContactActionFeature(Feature):
    """Contact-action feature vector (CONTACT_ACTION_COUNT values) stored under AGGREGATES_FOLDER/contact_actions.

    Unlike the base class, ``rec['features']`` is converted directly to an
    int64 numpy array rather than decoded with ``du.array_fromstring``.
    """

    def __init__(self, ext_id_key='cid', default=None, yt_table=None):
        super(ContactActionFeature, self).__init__(
            feature_name=CONTACT_ACTION_NAME,
            count=CONTACT_ACTION_COUNT,
            ext_id_key=ext_id_key,
            default=default,
            yt_table=yt_table,
        )

    def extract(self, rec):
        """Return ``rec['features']`` as a new int64 numpy array."""
        raw_values = rec['features']
        return np.array(raw_values, dtype=np.int64)

    @classmethod
    def get_weekly_folder(cls):
        """Return the YT folder holding the contact-action tables."""
        aggregates_root = yt_path_config.AGGREGATES_FOLDER
        return ypath_join(aggregates_root, 'contact_actions')

    @classmethod
    def get_last_weekly_table(cls):
        """Return the table ``yt_utils.get_last_table`` picks (presumably the latest) in the weekly folder."""
        weekly_folder = cls.get_weekly_folder()
        return yt_utils.get_last_table(weekly_folder)


class TimeHistFeature(Feature):
    """Time-histogram feature vector (TIME_HIST_COUNT values) stored under AGGREGATES_FOLDER/time_hist.

    Table paths point at a 'features' child of the dated node rather than at
    the node itself.
    """

    def __init__(self, ext_id_key='cid', default=None, yt_table=None):
        super(TimeHistFeature, self).__init__(
            feature_name=TIME_HIST_NAME,
            count=TIME_HIST_COUNT,
            ext_id_key=ext_id_key,
            default=default,
            yt_table=yt_table,
        )

    @classmethod
    def get_weekly_folder(cls):
        """Return the YT folder holding the time-histogram nodes."""
        aggregates_root = yt_path_config.AGGREGATES_FOLDER
        return ypath_join(aggregates_root, 'time_hist')

    @classmethod
    def get_last_weekly_table(cls):
        """Return the 'features' child of the node ``yt_utils.get_last_table`` picks."""
        latest_node = yt_utils.get_last_table(cls.get_weekly_folder())
        return ypath_join(latest_node, 'features')

    def get_table_for_date(self, date_str):
        """Return the path of the 'features' table under the ``date_str`` node."""
        weekly_folder = self.get_weekly_folder()
        return ypath_join(weekly_folder, date_str, 'features')


class PhoneRangeFeature(Feature):
    """Phone-range feature vector (PHONE_RANGE_COUNT values) stored under AGGREGATES_FOLDER/phone_range.

    Values are computed from the record by ``phone_range_extractor`` instead of
    being decoded from ``rec['features']``. Table paths point at a 'features'
    child of the dated node.
    """

    def __init__(self, ext_id_key='cid', default=None, yt_table=None):
        super(PhoneRangeFeature, self).__init__(
            feature_name=PHONE_RANGE_NAME,
            count=PHONE_RANGE_COUNT,
            ext_id_key=ext_id_key,
            default=default,
            yt_table=yt_table,
        )

    def extract(self, rec):
        """Delegate to ``phone_range_extractor.extract_phone_range_features``."""
        extractor_fn = phone_range_extractor.extract_phone_range_features
        return extractor_fn(rec)

    @classmethod
    def get_weekly_folder(cls):
        """Return the YT folder holding the phone-range nodes."""
        aggregates_root = yt_path_config.AGGREGATES_FOLDER
        return ypath_join(aggregates_root, 'phone_range')

    @classmethod
    def get_last_weekly_table(cls):
        """Return the 'features' child of the node ``yt_utils.get_last_table`` picks."""
        latest_node = yt_utils.get_last_table(cls.get_weekly_folder())
        return ypath_join(latest_node, 'features')

    def get_table_for_date(self, date_str):
        """Return the path of the 'features' table under the ``date_str`` node."""
        weekly_folder = self.get_weekly_folder()
        return ypath_join(weekly_folder, date_str, 'features')


class GeoFeature(Feature):
    """Placeholder for the GEO feature vector (GEO_COUNT values).

    No YT location is defined here, so ``get_weekly_folder`` and
    ``get_last_weekly_table`` fall through to ``Feature`` and raise
    ``NotImplementedError``.
    """

    def __init__(self, ext_id_key='cid', default=None, yt_table=None):
        super(GeoFeature, self).__init__(
            feature_name=GEO_NAME,
            count=GEO_COUNT,
            ext_id_key=ext_id_key,
            default=default,
            yt_table=yt_table,
        )


class LocationsFeature(Feature):
    """Locations feature vector (LOCATIONS_COUNT values) stored under AGGREGATES_FOLDER/locations/weekly."""

    def __init__(self, ext_id_key='cid', default=None, yt_table=None):
        super(LocationsFeature, self).__init__(
            feature_name=LOCATIONS_NAME,
            count=LOCATIONS_COUNT,
            ext_id_key=ext_id_key,
            default=default,
            yt_table=yt_table,
        )

    @classmethod
    def get_weekly_folder(cls):
        """Return the YT folder holding the weekly locations tables."""
        aggregates_root = yt_path_config.AGGREGATES_FOLDER
        return ypath_join(aggregates_root, 'locations/weekly')

    @classmethod
    def get_last_weekly_table(cls):
        """Return the table ``yt_utils.get_last_table`` picks (presumably the latest) in the weekly folder."""
        weekly_folder = cls.get_weekly_folder()
        return yt_utils.get_last_table(weekly_folder)
