import enum
import logging

import yt.yson as yson

import crypta.lib.python.bt.conf.conf as conf
from crypta.lab.proto.sample_pb2 import TSampleStats
from crypta.lib.proto.user_data.attribute_names_pb2 import TAttributeNames
from crypta.lib.proto.user_data.user_data_pb2 import (
    TIdentifiers,
    TUserData,
    TVectors,
)
from crypta.lib.proto.user_data.user_data_stats_pb2 import (
    TSegmentInfo,
    TStrata,
    TUserDataStats,
)
from crypta.lib.python.bt.yt import AbstractDynamicStorage
from crypta.lib.python.native_yt.proto import create_schema


logger = logging.getLogger(__name__)


class SampleStatsStorage(AbstractDynamicStorage):
    """Dynamic storage whose path is ``conf.paths.lab.sample_stats``.

    The table schema is derived from the ``TSampleStats`` proto with an
    additional leading ``Hash`` key column.
    """

    class Fields(enum.Enum):
        """Column names used by the storage."""

        HASH = 'Hash'
        SAMPLE_ID = 'SampleID'
        GROUP_ID = 'GroupID'
        STATS = 'Stats'

    @property
    def path(self):
        """Return the table path taken from the configuration."""
        return conf.paths.lab.sample_stats

    @property
    def schema(self):
        """Return the YSON schema: the ``Hash`` column followed by the proto columns.

        The ``Hash`` column is declared with the expression
        ``farm_hash(SampleID)`` and ascending sort order; the resulting
        schema carries the ``unique_keys=True`` attribute.
        """
        columns_from_proto = create_schema(TSampleStats, strong=True, dynamic=True)
        hash_key = {
            'name': 'Hash',
            'type': 'uint64',
            'expression': 'farm_hash(SampleID)',
            'sort_order': 'ascending',
        }
        # The hash column must come first, ahead of the proto-derived columns.
        all_columns = [hash_key] + columns_from_proto
        return yson.to_yson_type(all_columns, attributes={'unique_keys': True})


class Types(object):
    """Named string constants for column value types."""

    # Textual types.
    STRING = 'string'

    # Numeric types.
    INT64 = 'int64'
    UINT64 = 'uint64'
    DOUBLE = 'double'

    # Untyped value.
    ANY = 'any'


class RecordBase(dict):
    """A ``dict`` row with a helper for decoding serialized proto columns."""

    def _get_proto(self, proto_cls, key):
        """Build a ``proto_cls`` message from the serialized value stored at ``key``.

        A fresh message is always created. It is populated via
        ``ParseFromString`` only when the column holds a truthy value;
        a missing key or a falsy value (``None``, empty string) yields the
        untouched default-constructed message.
        """
        message = proto_cls()
        serialized = self.get(key, "")
        if serialized:
            message.ParseFromString(serialized)
        return message


class UserData(object):
    """Column names, attribute names and row wrapper for ``TUserData`` tables."""

    class Fields(object):
        """Column names of a user-data row."""

        YUID = 'yuid'
        CRYPTA_ID = 'CryptaID'
        GROUP_ID = 'GroupID'
        IDENTIFIERS = 'Identifiers'
        STRATA = 'Strata'
        ATTRIBUTES = 'Attributes'
        VECTORS = 'Vectors'
        WITHOUT_DATA = 'WithoutData'

    class Attributes(object):
        """Attribute names associated with user-data tables."""

        GLOBAL_STATS = 'global_stats'
        DATA_ID = 'data_id'
        LAST_UPDATE_DATE = '_last_update_date'
        # Value of the LastUpdateTimestamp field on a default TAttributeNames message.
        LAST_UPDATE_TIMESTAMP = TAttributeNames().LastUpdateTimestamp

    class Record(RecordBase):
        """A single user-data row with typed accessors for its columns.

        Proto-valued accessors return a default-constructed message when the
        column is missing or empty (see ``RecordBase._get_proto``).
        """

        def get_attributes(self):
            """Return the ``Attributes`` column parsed as ``TUserData.TAttributes``."""
            return self._get_proto(TUserData.TAttributes, UserData.Fields.ATTRIBUTES)

        def get_identifiers(self):
            """Return the ``Identifiers`` column parsed as ``TIdentifiers``."""
            return self._get_proto(TIdentifiers, UserData.Fields.IDENTIFIERS)

        def get_identifier(self, key_identifier, default=None):
            """Return the identifier stored under ``key_identifier``.

            ``default`` is returned when the key is absent from the
            ``Identifiers`` map of the parsed ``TIdentifiers`` message.
            """
            identifiers = self.get_identifiers()
            # The proto field is ``Identifiers`` (capitalised), the same name
            # convert_to_proto_identifiers writes to; attribute access on
            # proto messages is case-sensitive.
            return identifiers.Identifiers.get(key_identifier, default)

        def get_strata(self):
            """Return the ``Strata`` column parsed as ``TStrata``."""
            return self._get_proto(TStrata, UserData.Fields.STRATA)

        def get_vectors(self):
            """Return the ``Vectors`` column parsed as ``TVectors``."""
            return self._get_proto(TVectors, UserData.Fields.VECTORS)

        def get_group_id(self):
            """Return the ``GroupID`` column, or ``""`` when it is falsy.

            Raises ``KeyError`` if the column is absent from the row.
            """
            return self[UserData.Fields.GROUP_ID] or ""

        def get_yuid(self):
            """Return the ``yuid`` column, or ``""`` when it is falsy.

            Raises ``KeyError`` if the column is absent from the row.
            """
            return self[UserData.Fields.YUID] or ""

        def get_crypta_id(self):
            """Return the ``CryptaID`` column, or ``""`` when it is falsy.

            Raises ``KeyError`` if the column is absent from the row.
            """
            return self[UserData.Fields.CRYPTA_ID] or ""

        def has_data(self):
            """Return ``True`` unless the ``WithoutData`` column is set to a truthy value."""
            return not bool(self.get(UserData.Fields.WITHOUT_DATA))

        @staticmethod
        def convert_to_proto_identifiers(identifiers, not_unique=False):
            """Build a ``TIdentifiers`` message from a mapping of identifiers.

            Every value is stored as ``str(value)``; ``None`` values are
            replaced by ``UserData.StringNone`` first. ``not_unique`` is
            copied into the message's ``NotUnique`` field.
            """
            pb_identifiers = TIdentifiers()
            for key, value in identifiers.items():
                if value is None:
                    value = UserData.StringNone
                pb_identifiers.Identifiers[key] = str(value)
            pb_identifiers.NotUnique = not_unique
            return pb_identifiers

    @staticmethod
    def shift_enum_value(value):
        """Return ``int(value) - 1``, or ``None`` when that result is ``-1``."""
        value = int(value) - 1
        if value == -1:
            value = None
        return value

    @staticmethod
    def records(data):
        """Lazily wrap each item of ``data`` in a ``UserData.Record``."""
        for record in data:
            yield UserData.Record(record)

    # Placeholder stored in place of ``None`` identifier values.
    StringNone = "none"
    # Table schema derived from the TUserData proto.
    SCHEMA = create_schema(TUserData)


class UserDataStats(object):
    """Column names and row wrapper for ``TUserDataStats`` tables."""

    Proto = TUserDataStats

    class Fields(object):
        """Column names of a user-data-stats row."""

        ATTRIBUTES = 'Attributes'
        GROUP_ID = 'GroupID'
        IDENTIFIERS = 'Identifiers'
        STRATUM = 'Stratum'
        COUNTS = 'Counts'
        FILTER = 'Filter'
        DISTRIBUTIONS = 'Distributions'
        AFFINITIES = 'Affinities'
        SEGMENT_INFO = 'SegmentInfo'

    class Record(RecordBase):
        """A single stats row with typed accessors for its columns.

        Proto-valued accessors return a default-constructed message when the
        column is missing or empty (see ``RecordBase._get_proto``).
        """

        def get_attributes(self):
            """Return the ``Attributes`` column as ``TUserDataStats.TAttributesStats``."""
            message_type = TUserDataStats.TAttributesStats
            return self._get_proto(message_type, UserDataStats.Fields.ATTRIBUTES)

        def get_identifiers(self):
            """Return the ``Identifiers`` column as ``TIdentifiers``."""
            column = UserDataStats.Fields.IDENTIFIERS
            return self._get_proto(TIdentifiers, column)

        def get_identifier(self, key_identifier, default=None):
            """Return the identifier under ``key_identifier`` or ``default`` if absent."""
            return self.get_identifiers().Identifiers.get(key_identifier, default)

        def get_stratum(self):
            """Return the ``Stratum`` column as ``TUserDataStats.TStratumStats``."""
            message_type = TUserDataStats.TStratumStats
            return self._get_proto(message_type, UserDataStats.Fields.STRATUM)

        def get_distributions(self):
            """Return the ``Distributions`` column as ``TUserDataStats.TDistributions``."""
            message_type = TUserDataStats.TDistributions
            return self._get_proto(message_type, UserDataStats.Fields.DISTRIBUTIONS)

        def get_counts(self):
            """Return the ``Counts`` column as ``TUserDataStats.TCounts``."""
            message_type = TUserDataStats.TCounts
            return self._get_proto(message_type, UserDataStats.Fields.COUNTS)

        def get_filter(self):
            """Return the ``Filter`` column as ``TUserDataStats.TFilter``."""
            message_type = TUserDataStats.TFilter
            return self._get_proto(message_type, UserDataStats.Fields.FILTER)

        def get_affinities(self):
            """Return the ``Affinities`` column as ``TUserDataStats.TAffinitiveStats``."""
            message_type = TUserDataStats.TAffinitiveStats
            return self._get_proto(message_type, UserDataStats.Fields.AFFINITIES)

        def get_segment_info(self):
            """Return the ``SegmentInfo`` column as ``TSegmentInfo``."""
            column = UserDataStats.Fields.SEGMENT_INFO
            return self._get_proto(TSegmentInfo, column)

        def get_group_id(self):
            """Return the ``GroupID`` column, or ``""`` when it is missing or falsy."""
            group_id = self.get(UserDataStats.Fields.GROUP_ID)
            return group_id or ""

    @staticmethod
    def records(data):
        """Lazily wrap each item of ``data`` in a ``UserDataStats.Record``."""
        for row in data:
            yield UserDataStats.Record(row)

    # Table schema derived from the TUserDataStats proto.
    SCHEMA = create_schema(TUserDataStats)
