import logging

from cached_property import cached_property
from crypta.lab.lib.common import WithApi
from crypta.lab.lib.native_operations import (
    TLookalikeJoiner,
    TLookalikeMapper,
    TLookalikeReducer,
    TPredictMapper,
    TPredictReducer,
)
from crypta.lab.lib.specs import YtSpecs
from crypta.lab.lib.tables import (
    SampleStatsStorage,
    UserDataStats,
)
from crypta.lab.proto.lookalike_pb2 import (
    TLookalikeMapping,
    TLookalikeReducing,
    TLookalikeOutputView,
)
from crypta.lab.proto.sample_pb2 import TSampleStats
from crypta.lab.proto.view_pb2 import (
    ESampleViewState,
    TSampleView,
)
from crypta.lib.proto.user_data.user_data_stats_pb2 import TUserDataStats
import crypta.lib.python.bt.conf.conf as conf
from crypta.lib.python.bt.tasks import YtTask
from crypta.lib.python.bt.workflow import (
    IndependentTask,
    Parameter,
)
from crypta.lib.python.bt.workflow.targets.table import HasAttribute
from crypta.lib.python.swagger import _to_proto
from crypta.lib.python.yt import schema_utils
from crypta.lookalike.lib.python.utils import utils as lal_utils
from crypta.lookalike.proto.user_embedding_pb2 import TUserEmbedding
from crypta.lookalike.services.user_dssm_applier.py.native_operations import TApplyUserDssmMapper
from crypta.siberia.bin.common.yt_describer.proto.grouped_id_pb2 import TGroupedId
from crypta.siberia.bin.common.yt_describer.proto.group_stats_pb2 import TGroupStats
from crypta.siberia.bin.common.yt_describer.proto.yt_describer_config_pb2 import TYtDescriberConfig
from crypta.siberia.bin.common.yt_describer.py import describe
from paths import (
    Profile,
    WithPaths,
)

# Module-level logger, named after this module, used by the tasks below.
logger = logging.getLogger(__name__)


class Lookalike(YtTask, IndependentTask, WithPaths, WithApi):
    """Compute the lookalike destination view of a sample from its source view.

    Parameters:
        sample_id: id of the sample whose views are processed.
        src_view: id of the source view; it must be in the READY state.
        dst_view: id of the destination view; its table is (re)created and its
            state is switched to PROCESSING and, on success, to READY.
    """

    sample_id = Parameter()
    src_view = Parameter()
    dst_view = Parameter()

    SCORE = 'score'
    SEGMENT_ID = 'segment_id'
    STATS = 'Stats'
    YUID = 'yuid'

    # Filter error rate written to both the mapping and the reducing state.
    MAX_FILTER_ERROR_RATE = 1e-3

    @cached_property
    def dst_schema(self):
        """YT schema of the destination table, derived from ``TLookalikeOutputView``."""
        return schema_utils.get_schema_from_proto(TLookalikeOutputView)

    @cached_property
    def segment_stats_schema(self):
        """YT schema of the describer output table, derived from ``TSampleStats``."""
        return schema_utils.get_schema_from_proto(TSampleStats)

    @cached_property
    def describe_input_schema(self):
        """YT schema of the describer input table, derived from ``TGroupedId``."""
        return schema_utils.get_schema_from_proto(TGroupedId)

    @cached_property
    def sample_stats_storage(self):
        """``SampleStatsStorage`` over ``self._init_yt()`` with ``prepare_table()`` already called."""
        storage = SampleStatsStorage(self._init_yt())
        storage.prepare_table()
        return storage

    @property
    def sort_output_by(self):
        """Columns used to sort the scored and the filter tables before reducing."""
        return ['GroupID', 'MinusRegionSize', 'Region', 'MinusDeviceProbability', 'Device', 'MinusScore']

    def update_sample_view_state(self, id, view_id, state):
        """Set the state of a sample view through the Lab API and return the call result."""
        return self.api.lab.updateSampleViewState(id=id, view_id=view_id, state=state).result()

    def get_view(self, view_id):
        """Fetch view ``view_id`` of this sample as a ``TSampleView`` proto.

        Returns ``None`` when the raised error carries ``status_code == 404``;
        every other error (including the ``assert`` on an empty response) is
        re-raised unchanged.
        """
        try:
            view = self.api.lab.getSampleView(
                id=self.sample_id,
                view_id=view_id,
            ).result()
            assert view
            return _to_proto(TSampleView, view)
        except Exception as e:
            if getattr(e, 'status_code', None) == 404:
                return None
            # Bare ``raise`` re-raises the active exception with its original traceback.
            raise

    def prepare_stats(self, _stats):
        """Assemble a ``TUserDataStats`` proto from a row of serialized sub-messages.

        ``_stats`` is a mapping keyed by ``UserDataStats.Fields`` names; a missing
        or empty value is parsed as an empty message / empty group id.
        """
        fields = UserDataStats.Fields
        # (sub-message attribute of TUserDataStats, row key holding its serialized form)
        serialized_parts = (
            ('Affinities', fields.AFFINITIES),
            ('Attributes', fields.ATTRIBUTES),
            ('Identifiers', fields.IDENTIFIERS),
            ('Stratum', fields.STRATUM),
            ('Distributions', fields.DISTRIBUTIONS),
            ('Counts', fields.COUNTS),
            ('SegmentInfo', fields.SEGMENT_INFO),
        )

        proto_stats = TUserDataStats()
        for message_name, row_key in serialized_parts:
            getattr(proto_stats, message_name).ParseFromString(_stats.get(row_key) or "")

        proto_stats.GroupID = _stats.get(fields.GROUP_ID) or ""

        return proto_stats

    def src_for_describe_mapper(self, row):
        """Map a source row to a describer input row: its yuid as a ``yandexuid`` id in ``group_id``."""
        yield {
            'IdValue': row[self.YUID],
            'IdType': 'yandexuid',
            UserDataStats.Fields.GROUP_ID: self.group_id,
        }

    @property
    def group_id(self):
        """Single group id under which the whole sample is described and scored."""
        return 'lookalikeGroup'

    def run(self, **kwargs):
        """Build the destination view table and mark the view READY.

        Steps:
          1. Fetch both views and require the source view to be READY.
          2. Mark the destination view PROCESSING and recreate its table.
          3. Obtain the serialized stats of the sample: from the sample stats
             storage when ``UseDates`` is set, otherwise by running ``describe``
             over the source yuids.
          4. Score user embeddings with ``TPredictMapper`` and select the result
             with ``TPredictReducer``.
          5. Sort the ``yuid``/``score`` columns of the result into the
             destination table and mark the view READY.

        Raises:
            AssertionError: the source view is not READY, no stored description
                exists (``UseDates``), or ``describe`` did not produce exactly one row.
        """
        # NOTE(review): get_view returns None on a 404; the attribute accesses
        # below would then fail with AttributeError.
        src_view = self.get_view(self.src_view)
        dst_view = self.get_view(self.dst_view)

        assert src_view.State == ESampleViewState.Value('READY'), 'Source view is not ready yet'

        src_path = src_view.Path
        dst_path = dst_view.Path
        output_count = dst_view.Options.Lookalike.Counts.Output
        use_dates = dst_view.Options.Lookalike.UseDates
        current_user_data = conf.paths.lab.data.userdata
        global_user_data_stats = conf.paths.lab.data.userdata_stats

        if use_dates:
            global_user_data_stats = conf.paths.lab.data.crypta_id.userdata_stats

        logger.info('Lookalike %s to %s for sample %s, use dates: %s',
                    src_view, dst_view, self.sample_id, use_dates)

        logger.info(self.update_sample_view_state(id=self.sample_id, view_id=self.dst_view, state='PROCESSING'))

        self.yt.create('table', dst_path, attributes={'schema': self.dst_schema}, force=True)

        segment_stats_attr = {'schema': self.segment_stats_schema, 'optimize_for': 'scan'}
        describe_src_attr = {'schema': self.describe_input_schema, 'optimize_for': 'scan'}
        embeddings_attrs = {'schema': schema_utils.get_schema_from_proto(TUserEmbedding), 'optimize_for': 'scan'}

        with_filters_schema = [
            {'name': name, 'type': 'any'}
            for name in self.sort_output_by + [UserDataStats.Fields.FILTER]
        ]

        with self.yt.TempTable(prefix='srcForDescribe', attributes=describe_src_attr) as src_for_describe, \
             self.yt.TempTable(prefix='errorTable') as error_table, \
             self.yt.TempTable(prefix='userEmbeddingsTable', attributes=embeddings_attrs) as user_embeddings_table, \
             self.yt.TempTable(prefix='segmentStats', attributes=segment_stats_attr) as segment_stats, \
             self.yt.TempTable(prefix='withScore') as with_score, \
             self.yt.TempTable(prefix='withFilters', attributes={'schema': with_filters_schema}) as with_filters, \
             self.yt.TempTable(prefix='withResultAndOthers') as with_result_and_others:

            if use_dates:
                # Reuse the description stored for this sample instead of describing it again.
                rows = list(self.sample_stats_storage.select_rows('*', 'SampleID=\'{}\''.format(self.sample_id)))

                assert len(rows) > 0, 'Past description is not ready yet'

                userdata_stats = rows[0][self.STATS]
                logger.info('Past description for sample is retrieved')
            else:
                # Describe the source yuids as one group; the describer must emit a single row.
                self.map(self.src_for_describe_mapper, src_path, src_for_describe,)

                description_config = TYtDescriberConfig(
                    CryptaIdUserDataTable=conf.paths.user_data_stats_by_cryptaid,
                    TmpDir='//tmp',
                    InputTable=src_for_describe,
                    OutputTable=segment_stats,
                )
                describe(self.yt, self, description_config)

                assert self.yt.get_attribute(segment_stats, 'row_count') == 1

                stats_row = next(self.yt.read_table(self.yt.TablePath(segment_stats, columns=[self.STATS])))
                userdata_stats = stats_row[self.STATS]

            mapping = TLookalikeMapping()
            group_stats = TGroupStats()
            parsed_userdata_stats = group_stats.Stats
            parsed_userdata_stats.ParseFromString(userdata_stats)

            segment_meta = mapping.Segments[self.group_id]
            segment_meta.UserDataStats.CopyFrom(parsed_userdata_stats)
            counts = segment_meta.Options.Counts
            counts.Input = self.yt.get_attribute(src_path, "row_count")
            counts.Output = int(output_count)

            # Only the first row of the global stats table is used.
            global_stats = self.prepare_stats(next(self.yt.read_table(global_user_data_stats)))
            mapping.GlobalUserDataStats.CopyFrom(global_stats)
            mapping.MaxFilterErrorRate = self.MAX_FILTER_ERROR_RATE

            logger.info(mapping)

            state = mapping.SerializeToString()

            if use_dates:
                # Apply the DSSM model selected by the earliest used date to the current user data.
                min_used_dates = self.yt.get_attribute(src_path, 'min_used_dates')
                dssm_files = lal_utils.get_old_dssm_model(self.yt, last_date=min(min_used_dates))

                self.native_map(
                    TApplyUserDssmMapper,
                    current_user_data,
                    [user_embeddings_table, error_table],
                    files=dssm_files
                )
            else:
                # Rebinds user_embeddings_table to the table returned by the helper;
                # the temporary embeddings table is left unused on this path.
                user_embeddings_table, dssm_files = lal_utils.get_last_version_of_dssm_entities(self.yt)

            self.native_map(
                TPredictMapper,
                user_embeddings_table,
                with_score,
                state,
                spec=YtSpecs.MAPPER_HIGH_MEMORY_USAGE,
                files=dssm_files,
            )

            self.sort(with_score, with_score, sort_by=self.sort_output_by)

            # A single row carrying the serialized filter of the group is reduced
            # together with the scored rows.
            filter_data = parsed_userdata_stats.Filter
            self.yt.write_table(
                with_filters,
                [{
                    UserDataStats.Fields.GROUP_ID: self.group_id,
                    UserDataStats.Fields.FILTER: filter_data.SerializeToString(),
                }],
            )
            self.sort(with_filters, with_filters, sort_by=self.sort_output_by, spec=YtSpecs.ALLOW_BIG_ROWS)

            reducing = TLookalikeReducing()
            reducing.MaxFilterErrorRate = self.MAX_FILTER_ERROR_RATE
            options = reducing.Segments[self.group_id]
            options.Counts.Output = int(output_count)

            # Use the module logger, consistently with the rest of this task.
            logger.info(reducing)

            reducing_state = reducing.SerializeToString()
            spec = dict(YtSpecs.ALLOW_BIG_ROWS, **YtSpecs.HIGH_MEMORY_USAGE)

            self.native_reduce(
                TPredictReducer,
                [with_filters, with_score],
                with_result_and_others,
                reduce_by=UserDataStats.Fields.GROUP_ID,
                sort_by=self.sort_output_by,
                state=reducing_state,
                spec=spec,
            )

            # "{yuid,score}" appended to the path selects only these two columns.
            self.sort(
                with_result_and_others + "{{{},{}}}".format(self.YUID, self.SCORE),
                dst_path,
                sort_by=self.YUID
            )
        logger.info(self.update_sample_view_state(id=self.sample_id, view_id=self.dst_view, state='READY'))


class ComputeLookalikeDaily(YtTask, IndependentTask, WithPaths):
    """Fan out one ``ComputeLookalike`` task per entry of ``self.paths.input``."""

    def run(self, **kwargs):
        """Yield a ``ComputeLookalike`` task for every input path, passed as a string."""
        for segments_path in self.paths.input:
            segments = str(segments_path)
            yield ComputeLookalike(segments=segments)


class ComputeLookalike(YtTask, IndependentTask, WithPaths):
    """Join the outputs of all batches of one segments input into its destination table."""

    segments = Parameter()

    @property
    def input_segments(self):
        """Input segments object resolved by ``self.paths`` from the ``segments`` parameter."""
        return self.paths.get_input_segments(self.segments)

    def targets(self):
        """Yield the target: the destination carries the last processed profiles date."""
        target = HasAttribute(
            self.yt,
            self.destination,
            Profile.Attributes.DATE,
            self.paths.profiles.last_processed_date,
        )
        yield target

    @cached_property
    def batches(self):
        """Batches created from the input segments, materialized once as a list."""
        return list(self.input_segments.create_batches(self.paths))

    @cached_property
    def destination(self):
        """Table object of the output path.

        When the output path already exists, its ``@path`` attribute is used
        instead (presumably the target of a link -- confirm).
        """
        output_path = self.input_segments.output_path(self.paths)
        if not self.yt.exists(output_path):
            return self.paths.table(output_path)
        resolved_path = self.yt.get("{}/@path".format(output_path))
        return self.paths.table(resolved_path)

    @property
    def linked_path(self):
        """Linked path that ``self.paths`` associates with the destination (may be falsy)."""
        return self.paths.get_linked_path(self.destination)

    def run(self, **kwargs):
        """Reduce batch outputs into the destination, clean up and refresh the link.

        The batch outputs are reduced with ``TLookalikeJoiner``, sorted by
        ``id_type``/``id`` and merged into the destination; then the batches and
        done directories are removed, the destination is stamped with the last
        processed date and the linked path (if any) is pointed at it.
        """
        batch_outputs = [batch.output(self.paths) for batch in self.batches]
        target_table = self.destination

        columns = (
            ("probabilistic_segments", "any"),
            ("marketing_segments", "any"),
            ("id", "string"),
            ("id_type", "string"),
        )
        schema = [dict(name=column, type=yt_type) for column, yt_type in columns]

        with self.yt.TempTable(attributes={"schema": schema}) as joined:
            self.native_reduce(
                TLookalikeJoiner,
                batch_outputs,
                joined,
                reduce_by=Profile.Fields.YANDEXUID,
            )
            self.sort(joined, joined, sort_by=["id_type", "id"])

            self.merge(joined, target_table, spec={'schema_inference_mode': 'from_input'})

        # Intermediate directories are no longer needed once the merge is done.
        self.input_segments.batches_dir(self.paths).remove(force=True)
        self.input_segments.done_dir(self.paths).remove(force=True)
        self.destination.set_attribute(
            Profile.Attributes.DATE,
            self.paths.profiles.last_processed_date,
        )

        if not self.linked_path:
            return

        # Drop an existing link only when it points somewhere else.
        if self.yt.exists(self.linked_path):
            current_target = str(self.yt.get_attribute(self.linked_path, "path"))
            if current_target != str(self.destination):
                self.yt.remove(self.linked_path)
        self.yt.link(self.destination, self.linked_path, recursive=True, ignore_existing=True)

    def requires(self):
        """Yield one ``ComputeLookalikeBatch`` dependency per batch."""
        for pending_batch in self.batches:
            yield ComputeLookalikeBatch(batch=str(pending_batch))


class ComputeLookalikeBatch(YtTask, WithPaths):
    """Reduce one batch of scored segments into its output table."""

    batch = Parameter()

    @property
    def batched_segments(self):
        """Batched segments object resolved by ``self.paths`` from the ``batch`` parameter."""
        return self.paths.get_batched_segments(self.batch)

    @property
    def destination(self):
        """Output location of the batch as reported by the batched segments object."""
        segments_batch = self.batched_segments
        return segments_batch.output(self.paths)

    def complete(self):
        """The task is complete once the destination exists in YT."""
        return self.yt.exists(self.destination)

    def run(self, **kwargs):
        """Reduce the batch by segment id, sort the output by yandexuid and remove the batch."""
        segments_batch = self.batched_segments
        # Only the state is used here; the first element of the pair is ignored.
        _, reducing_state = segments_batch.get_reducing_state()

        self.native_reduce(
            TLookalikeReducer,
            segments_batch,
            self.destination,
            state=reducing_state,
            reduce_by=segments_batch.SEGMENT_ID,
        )
        self.sort(self.destination, self.destination, sort_by=Profile.Fields.YANDEXUID)
        segments_batch.remove()

    def requires(self):
        """Depend on ``ComputeTopScores`` for the same batch."""
        dependency = ComputeTopScores(batch=self.batch)
        yield dependency


class ComputeTopScores(YtTask, IndependentTask, WithPaths):
    """Map profiles into a batch table with ``TLookalikeMapper`` and sort it for reducing."""

    batch = Parameter()

    def targets(self):
        """Yield two targets: the batch table exists and it is not empty."""
        exists_target = self.paths.table(self.batch).be.exists()
        yield exists_target
        not_empty_target = self.paths.table(self.batch).be.not_empty()
        yield not_empty_target

    def run(self, **kwargs):
        """Fill the batch table from the profiles and sort it by segment id and minus score."""
        segments_batch = self.paths.get_batched_segments(self.batch)
        source = segments_batch.get_source_path()
        profiles = self.paths.profiles

        # The mapping state is built from the profiles' generate date and the daily profiles size.
        mapping_state = source.get_mapping_state(
            self.yt.get_attribute(profiles, 'generate_date'),
            self.paths.daily_profiles.size,
        )

        self.native_map(TLookalikeMapper, profiles, segments_batch, mapping_state)

        sort_keys = [segments_batch.SEGMENT_ID, segments_batch.MINUS_SCORE]
        self.sort(
            source=segments_batch,
            destination=segments_batch,
            sort_by=sort_keys,
        )
