import datetime
import itertools
import logging
from operator import itemgetter
import random

import dateutil
import pytz
import tzlocal
import yt.wrapper as yt

import crypta.lib.python.bt.conf.conf as conf
from constants import (
    VOLATILE_ID,
    PERMANENT_ID,
    SEGMENT_TYPE,
    FILTER_CAPACITY,
    YUID, YANDEXUID,
    TIMESTAMP,
    ID_VALUE,
    OPTIONS,
    ENFORCE_DEVICE_AND_PLATFORM,
    ENFORCE_REGION,
    RELATED_GOALS,
    NUM_OUTPUT_BUCKETS,
    WEIGHT,
    MAX_COVERAGE,
    INCLUDE_INPUT,
    PRIORITY,
)
from crypta.audience.lib.tasks.audience import (
    _compute_stats_per_segment,
    _output_stats,
)
from crypta.audience.lib.tasks.audience.tables import (
    Attributes,
    LookalikeSegmentStateStorage,
)
from crypta.audience.lib.tasks.base import (
    YtTask,
)
from crypta.audience.lib.tasks.lookalike import (
    sql,
)
from crypta.lab.lib.tables import (
    UserDataStats,
)
from crypta.lab.proto.lookalike_pb2 import (
    TLookalikeOptions,
)
from crypta.lib.proto.user_data.user_data_stats_pb2 import (
    TSamplingOptions,
    TUserDataStatsOptions,
)
import crypta.lib.python.bt.workflow as workflow
from crypta.lib.python import time_utils
from crypta.lib.python.yt import schema_utils
from paths import (
    Segment,
    WithPaths,
)


logger = logging.getLogger(__name__)


def _output_failed_segment(task, segment_type, volatile_id, permanent_id, message):
    """Publish a failed output segment and record it in the segment state storage.

    The output segment addressed by ``segment_type`` and ``str(volatile_id)`` is
    created (recursively) and its status is set to failed with ``message``.
    Afterwards a non-input row with ``row_count=0`` and the fixed description
    ``'empty output'`` is written for ``permanent_id`` via ``sql.write_segment``.

    :param task: task object providing ``paths`` and ``_init_yt()``.
    :param segment_type: lookalike type of the segment.
    :param volatile_id: volatile segment id (stringified for the output path).
    :param permanent_id: permanent segment id used for the state row.
    :param message: failure reason; logged and stored as the failed status.
    """
    failed_output = task.paths.output_segments.segment(segment_type, str(volatile_id))
    logger.error("Segment %s (%s) is invalid: %s", volatile_id, permanent_id, message)
    failed_output.create(recursive=True)
    failed_output.set_status_failed(message)

    state_storage = LookalikeSegmentStateStorage(task._init_yt(), experiment_mode=True)
    with state_storage.batched_inserter as inserter:
        sql.write_segment(
            permanent_id,
            input=False,
            row_count=0,
            description='empty output',
            lookalike_type=segment_type,
            state_inserter=inserter,
        )


def batch_schema():
    """Return the YT schema used for batch tables.

    The schema is produced by ``schema_utils.yt_schema_from_dict`` from a
    column-name to type-name mapping: user id columns and the volatile id are
    strings, the weight is a double, the permanent id and timestamp are int64.
    """
    column_types = {
        YANDEXUID: "string",
        YUID: "string",
        WEIGHT: "double",
        VOLATILE_ID: "string",
        PERMANENT_ID: "int64",
        TIMESTAMP: "int64",
    }
    return schema_utils.yt_schema_from_dict(column_types)


class EnqueueAllSegments(workflow.IndependentTask):
    """Fan-out task that spawns one ``EnqueueAllSegmentsByShard`` per input shard."""

    def run(self, **kwargs):
        """Yield a per-shard enqueue task for every shard index.

        The number of shards is taken from ``Lookalike.InputShards`` in the
        configuration; each child task inherits this task's priority.
        """
        shard_count = conf.proto.Options.Lookalike.InputShards
        for shard_index in range(shard_count):
            yield EnqueueAllSegmentsByShard(
                shard=shard_index,
                total_shards=shard_count,
                priority=self.priority,
            )


class EnqueueAllSegmentsByShard(YtTask, workflow.IndependentTask, WithPaths):
    shard = workflow.Parameter(parse=int, default=0)
    total_shards = workflow.Parameter(parse=int, default=1)

    @property
    def max_batch_count(self):
        """Configured ``Lookalike.InputBatchCount`` converted to ``int``."""
        lookalike_options = conf.proto.Options.Lookalike
        return int(lookalike_options.InputBatchCount)

    def batches(self, segments, rate=3.):
        """Yield ``(batch_id, batch)`` pairs built from this shard's input segments.

        ``segments`` is an iterable of ``((segment_type, type_priority), segment)``
        pairs. Only segments whose permanent id modulo ``total_shards`` equals
        ``shard`` are considered. They are ordered by descending type priority,
        segment type, descending segment priority, modification time and size,
        and then batched separately for every ``(segment_type, type_priority)``
        key.

        A batch is emitted as soon as the next segment would make it reach
        ``Lookalike.InputBatchRows`` rows, or it already holds
        ``Lookalike.InputBatch`` tables. The trailing (incomplete) batch of a
        group is emitted only when its oldest member is older than
        ``Lookalike.InputIncompleteBatchMinAgeSec``. Segments for which
        ``is_ready()`` is false are never added to a batch and do not influence
        the batch age.

        :param segments: iterable of ``((segment_type, type_priority), segment)``.
        :param rate: generation stops once more than ``max_batch_count * rate``
            full batches have been yielded.
        :return: generator of ``(batch_id, [(segment_type, segment), ...])``.
        """
        # ``import dateutil`` alone does not guarantee that the ``parser``
        # submodule is loaded, so import it explicitly before using it.
        import dateutil.parser

        lookalike_options = conf.proto.Options.Lookalike
        max_batch_tables = lookalike_options.InputBatch
        max_batch_rows = lookalike_options.InputBatchRows

        def sort_key(item):
            # Explicit unpacking instead of a tuple-parameter lambda, which is
            # not valid syntax on Python 3.
            (segment_type, type_priority), segment = item
            return (
                -type_priority,
                # Keep segments of one type contiguous even if several types
                # share a priority: groupby() below only merges adjacent items.
                segment_type,
                -segment.priority,
                segment.modification_time,
                segment.size,
                segment,
            )

        sorted_segments = sorted(
            (
                (key, segment)
                for key, segment in segments
                if segment.get_permanent_id() % self.total_shards == self.shard
            ),
            key=sort_key,
        )

        n_batches = 0
        for (segment_type, _), group in itertools.groupby(sorted_segments, itemgetter(0)):
            current_batch, current_size, batch_oldest_ts = [], 0, datetime.datetime.max

            for _, segment in group:
                batch_tables = len(current_batch)
                exceeds_rows_limit = (current_size + segment.size >= max_batch_rows)
                exceeds_tables_limit = (batch_tables >= max_batch_tables)

                if (exceeds_tables_limit or exceeds_rows_limit) and batch_tables > 0:
                    yield self.paths.generate_new_batch_id(), current_batch
                    n_batches += 1
                    if n_batches > self.max_batch_count * rate:
                        return
                    current_batch, current_size, batch_oldest_ts = [], 0, datetime.datetime.max

                if segment.is_ready():
                    current_batch.append((segment_type, segment))
                    current_size += segment.size
                    # Track the age only for segments that actually joined the
                    # batch, and only after a possible flush, so the segment
                    # that opens a new batch is not lost from its age.
                    batch_oldest_ts = min(
                        dateutil.parser.parse(segment.attributes.get(Segment.MODIFICATION_TIME), ignoretz=True),
                        batch_oldest_ts,
                    )

            if current_batch:
                now = tzlocal.get_localzone().localize(datetime.datetime.now())
                min_age = datetime.timedelta(seconds=lookalike_options.InputIncompleteBatchMinAgeSec)
                # Modification times are parsed without tzinfo and treated as UTC.
                if (now - pytz.utc.localize(batch_oldest_ts)) > min_age:
                    yield self.paths.generate_new_batch_id(), current_batch

    @staticmethod
    def users_with_meta(segments_meta):
        """Build a YT mapper that tags every user row with its segment's meta.

        ``segments_meta`` is indexed by input table index. For each record the
        mapper emits one row where both ``YANDEXUID`` and ``YUID`` hold
        ``record[ID_VALUE]``, ``WEIGHT`` is taken from the record (``1.`` when
        absent), and ``VOLATILE_ID``, ``PERMANENT_ID`` and ``TIMESTAMP`` are
        copied from the meta of the table the record came from. A falsy
        ``context.table_index`` is treated as table ``0``.
        """
        copied_meta_keys = (VOLATILE_ID, PERMANENT_ID, TIMESTAMP)

        @yt.with_context
        def mapper_(record, context):
            table_meta = segments_meta[context.table_index or 0]
            row = {
                YANDEXUID: record[ID_VALUE],
                YUID: record[ID_VALUE],
                WEIGHT: record.get(WEIGHT, 1.),
            }
            row.update((key, table_meta[key]) for key in copied_meta_keys)
            yield row

        return mapper_

    def get_segment_meta(self, segment_type, segment, now):
        """Collect the per-segment metadata stored alongside a batch.

        :param segment_type: lookalike type the segment belongs to.
        :param segment: input segment object the values are read from.
        :param now: timestamp stored under ``TIMESTAMP``.
        :return: dict with the segment ids, priority, type, timestamp, the
            nested ``OPTIONS`` dict, related goals, number of output buckets,
            filter capacity (the segment size) and the segment info.
        """
        meta = {}
        meta[VOLATILE_ID] = segment.id
        meta[PERMANENT_ID] = segment.get_permanent_id(default=None)
        meta[PRIORITY] = segment.priority
        meta[SEGMENT_TYPE] = segment_type
        meta[TIMESTAMP] = now
        meta[OPTIONS] = {
            ENFORCE_DEVICE_AND_PLATFORM: segment.get_enforce_device_and_platform(),
            ENFORCE_REGION: segment.get_enforce_region(),
            MAX_COVERAGE: segment.get_max_coverage(),
            INCLUDE_INPUT: segment.get_input_including_mode(),
        }
        meta[RELATED_GOALS] = segment.get_related_goals()
        meta[NUM_OUTPUT_BUCKETS] = segment.get_num_output_buckets()
        meta[FILTER_CAPACITY] = segment.size
        meta[Segment.SEGMENT_INFO] = segment.get_segment_info()
        return meta

    def create_batch(self, segments_with_types, batch_id):
        """Materialize one batch table from the given segments.

        ``segments_with_types`` is a non-empty sequence of
        ``(segment_type, segment, segment_meta)`` triples. The segments are
        mapped into the batch table (using ``batch_schema()``) with each row
        tagged by the meta of its source table. The batch then gets the
        segments meta keyed by volatile id and a ``PRIORITY`` attribute equal
        to the sum of the segments' priorities (missing priority counts as 0).

        :return: the batch path object.
        """
        batch_table = self.paths.batched_inputs.batch(batch_id=batch_id)
        _, input_tables, metas = zip(*segments_with_types)

        # The mapper resolves meta by the index of the input table.
        meta_by_table_index = dict(enumerate(metas))
        mapper = self.users_with_meta(meta_by_table_index)
        output_table = self.yt.TablePath(batch_table, schema=batch_schema())
        self.map(mapper, input_tables, output_table)

        meta_by_volatile_id = {}
        for meta in metas:
            meta_by_volatile_id[meta[VOLATILE_ID]] = meta
        batch_table.set_segments_meta(meta_by_volatile_id)

        total_priority = sum(meta.get(PRIORITY, 0) for meta in metas)
        batch_table.set_attribute(PRIORITY, total_priority)
        return batch_table

    def run(self, **kwargs):
        """Turn this shard's input segments into batch tables.

        Candidate batches are generated, shuffled, and at most
        ``max_batch_count`` of them are processed. Every segment of a processed
        batch is recorded as an input in the segment state storage. Empty
        segments and segments with zero output buckets or zero max coverage are
        removed and reported as failed after the state inserter is closed; the
        remaining segments are written into a batch table and then removed.
        """
        timestamp = time_utils.get_current_time()
        self.paths.batched_inputs.create(ignore_existing=True)
        candidate_batches = list(self.batches(self.paths.input_segments))
        random.shuffle(candidate_batches)

        rejected = []

        with LookalikeSegmentStateStorage(self._init_yt(), experiment_mode=True).batched_inserter as state_inserter:
            for batch_id, batch_segments in itertools.islice(candidate_batches, 0, self.max_batch_count):
                accepted = []
                for segment_type, segment in batch_segments:
                    sql.write_segment(
                        segment.get_permanent_id(),
                        input=True,
                        row_count=segment.size,
                        lookalike_type=segment_type,
                        timestamp=timestamp,
                        state_inserter=state_inserter,
                    )
                    meta = self.get_segment_meta(segment_type, segment, timestamp)

                    if self.yt.is_empty(segment):
                        failure = "Empty segment"
                    elif meta[NUM_OUTPUT_BUCKETS] == 0 or meta[OPTIONS][MAX_COVERAGE] == 0:
                        failure = "Empty output"
                    else:
                        accepted.append((segment_type, segment, meta))
                        continue

                    # Capture the ids before the segment table is removed.
                    rejected.append((segment_type, segment.id, segment.get_permanent_id(), failure))
                    segment.remove()

                if accepted:
                    self.create_batch(accepted, batch_id)
                    for _, segment, _ in accepted:
                        segment.remove()

        for failure_args in rejected:
            _output_failed_segment(self, *failure_args)


class PrepareSegmentsBatch(YtTask, workflow.IndependentTask, WithPaths):

    batch = workflow.Parameter()
    timeout = datetime.timedelta(hours=12)

    def complete(self):
        """Return True once the batch path no longer exists in YT."""
        batch_exists = self.yt.exists(self.batch)
        return not batch_exists

    @property
    def batch_path(self):
        """Path helper for this task's batch, resolved through ``self.paths``."""
        resolved = self.paths.batch(self.batch)
        return resolved

    @property
    def destination(self):
        """Users table of the batch's ``waiting`` location."""
        waiting = self.batch_path.waiting
        return waiting.users

    @property
    def destination_meta(self):
        """Meta table of the batch's ``waiting`` location."""
        waiting = self.batch_path.waiting
        return waiting.meta

    def collect_datastats_options(self, segments_meta):
        """Build ``TUserDataStatsOptions`` for all segments of the batch.

        The sampling skip rate comes from ``Lookalike.TestRate``. For every
        segment its info entries are copied as strings into the per-segment
        options. Segments that do not include their input additionally get
        filter options: the capacity is the segment's ``FILTER_CAPACITY`` capped
        at ``64 * 2**20`` and the error rate is ``Lookalike.FilterErrorRate``
        rounded to 9 digits.

        :param segments_meta: mapping of segment id to its meta dict.
        :return: the populated ``TUserDataStatsOptions`` message (also logged).
        """
        max_filter_capacity = 64 * 2**20
        lookalike_options = conf.proto.Options.Lookalike
        filter_error_rate = round(float(lookalike_options.FilterErrorRate), 9)

        options = TUserDataStatsOptions(
            SamplingOptions=TSamplingOptions(
                SkipRate=float(lookalike_options.TestRate)
            )
        )
        for segment_id, meta in segments_meta.items():
            segment_options = options.Segments[segment_id]
            info = segment_options.Info.Info
            # The key may be present with a None value, in which case
            # ``dict.get(key, {})`` returns None; treat it as "no info".
            segment_info = meta.get(Segment.SEGMENT_INFO) or {}
            for key, value in segment_info.items():
                info[str(key)] = str(value)

            if not meta[OPTIONS][INCLUDE_INPUT]:
                filter_options = segment_options.FilterOptions
                filter_options.Capacity = min(meta[FILTER_CAPACITY], max_filter_capacity)
                filter_options.ErrorRate = filter_error_rate

        logger.info(options)
        return options

    def get_options(self, segment_id, segments_meta):
        """Build the ``TLookalikeOptions`` message for one segment.

        Flags, type, related goals, timestamp and ids are copied from the
        segment's meta. ``Counts.Input`` is the filter capacity. For the
        ``'audience'`` type ``Counts.Output`` is
        ``BucketSize * 2 ** (buckets - 1)`` where ``buckets`` is the segment's
        number of output buckets (``Lookalike.PrecisionBuckets`` when the key is
        absent); for other types it equals the max coverage.

        :param segment_id: key of the segment in ``segments_meta``.
        :param segments_meta: mapping of segment id to its meta dict.
        :return: populated ``TLookalikeOptions``.
        """
        lookalike_conf = conf.proto.Options.Lookalike
        default_buckets = lookalike_conf.PrecisionBuckets
        bucket_size = lookalike_conf.BucketSize

        meta = segments_meta[segment_id]
        flags = meta[OPTIONS]

        options = TLookalikeOptions()
        options.EnforceDeviceAndPlatform = flags[ENFORCE_DEVICE_AND_PLATFORM]
        options.EnforceRegion = flags[ENFORCE_REGION]
        options.IncludeInput = flags[INCLUDE_INPUT]
        options.LookalikeType = meta[SEGMENT_TYPE]
        options.RelatedGoals.extend(meta[RELATED_GOALS])
        options.Timestamp.Value = meta[TIMESTAMP]

        counts = options.Counts
        counts.Input = meta[FILTER_CAPACITY]
        if options.LookalikeType == 'audience':
            num_buckets = meta.get(NUM_OUTPUT_BUCKETS, default_buckets)
            counts.Output = int(bucket_size * 2 ** (num_buckets - 1))
        else:
            counts.Output = flags[MAX_COVERAGE]
        counts.MaxCoverage = flags[MAX_COVERAGE]

        options.VolatileId = meta[VOLATILE_ID]
        options.PermanentId = meta[PERMANENT_ID]
        return options

    def get_all_serialized_options(self, segments_meta):
        """Return ``{segment_id: serialized TLookalikeOptions}`` for every segment in the meta."""
        serialized_by_segment = {}
        for segment_id in segments_meta:
            segment_options = self.get_options(segment_id, segments_meta)
            serialized_by_segment[segment_id] = segment_options.SerializeToString()
        return serialized_by_segment

    def _write_stats(self, stats):
        """Write per-group user data stats into the destination meta table.

        Each ``(group_id, proto)`` pair of ``stats`` becomes one row holding the
        group id and, for every known stats field, either the serialized
        sub-message or ``None`` when the field is not set on the proto.

        :param stats: mapping of group id to a user data stats proto.
        """
        field_names = (
            UserDataStats.Fields.AFFINITIES,
            UserDataStats.Fields.ATTRIBUTES,
            UserDataStats.Fields.COUNTS,
            UserDataStats.Fields.DISTRIBUTIONS,
            UserDataStats.Fields.FILTER,
            UserDataStats.Fields.IDENTIFIERS,
            UserDataStats.Fields.SEGMENT_INFO,
            UserDataStats.Fields.STRATUM,
        )

        def serialize(group_id, proto):
            serialized = {UserDataStats.Fields.GROUP_ID: group_id}
            for name in field_names:
                serialized[name] = getattr(proto, name).SerializeToString() if proto.HasField(name) else None
            return serialized

        # ``items()`` (not the Python 2-only ``iteritems()``) keeps this
        # consistent with the other dict iterations in this class.
        serialized_stats = (
            serialize(group_id, stat)
            for group_id, stat in stats.items()
        )
        # Serialized stats rows can be large, hence the raised row weight limit.
        self.yt.write_table(self.destination_meta, serialized_stats, table_writer={"max_row_weight": 128 * 2**20})

    def _check_segments_matching(self, stats, segments_meta):
        """Fail and drop segments that have no matched data in ``stats``.

        A segment is considered unmatched when its volatile id is missing from
        ``stats`` or its stats have a falsy ``Counts.WithData``. Such segments
        are reported via ``_output_failed_segment`` with the message
        ``'No match'`` and removed from both ``stats`` and ``segments_meta``
        (both dicts are modified in place).
        """
        # Snapshot the ids first: segments_meta is mutated inside the loop.
        volatile_ids = set()
        for meta in segments_meta.values():
            volatile_ids.add(meta[VOLATILE_ID])

        for volatile_id in volatile_ids:
            has_data = volatile_id in stats and stats[volatile_id].Counts.WithData
            if has_data:
                continue

            _output_failed_segment(
                self,
                segments_meta[volatile_id][SEGMENT_TYPE],
                volatile_id,
                segments_meta[volatile_id][PERMANENT_ID],
                'No match',
            )
            stats.pop(volatile_id, None)
            segments_meta.pop(volatile_id, None)

    def run(self, **kwargs):
        """Move the batch to its waiting location and prepare its stats and options.

        Steps: lock the batch, read its segments meta, move the batch table to
        ``self.destination``, compute per-segment user data stats, drop
        segments without matched data, then store the stats, the serialized
        per-segment options, the stats options and the priority on the waiting
        location.
        """
        # Take an exclusive lock on the batch node before touching it.
        self.yt.lock(self.batch, mode="exclusive")

        # NOTE(review): the meta is read before the move — presumably it is
        # stored on the batch node being moved; confirm against `paths`.
        segments_meta = self.batch_path.input.segments_meta
        logger.info(segments_meta)
        self.yt.move(self.batch, self.destination, recursive=True, force=True)

        stats_options = self.collect_datastats_options(segments_meta)

        stats, scales = _compute_stats_per_segment(
            self,
            self.destination,
            group_by=VOLATILE_ID,
            id_value_column=YUID,
            user_data_stats_options=stats_options,
            experiment="by_crypta_id",
        )

        # TODO(CRYPTA-13107) check if this part is not repeated in prediction
        # Mutates both dicts in place: unmatched segments are reported as
        # failed and removed, so everything below sees only matched segments.
        self._check_segments_matching(stats, segments_meta)

        self._write_stats(stats)
        self.destination_meta.set_options(self.get_all_serialized_options(segments_meta))

        segment_stats = {segment_id: _output_stats(_stats, scales[segment_id]) for segment_id, _stats in stats.items()}
        self.destination_meta.set_stats(segment_stats)
        self.destination_meta.set_stats_options(stats_options.SerializeToString())
        # Carry the PRIORITY attribute of the moved table over to the waiting location.
        self.batch_path.waiting.set_priority(self.destination.get_attribute(PRIORITY))


class PrepareEnqueuedSegments(YtTask, workflow.IndependentTask, WithPaths):
    """Spawns a ``PrepareSegmentsBatch`` task for every unlocked batched input."""

    def run(self, **kwargs):
        """Yield one preparation task per batch that has a falsy lock count.

        The child task receives the batch's priority, with ``None`` replaced
        by ``0``.
        """
        for batch in self.paths.batched_inputs:
            lock_count = batch.get_attribute(Attributes.LOCK_COUNT)
            if not lock_count:
                # TODO(CRYPTA-14711) find out why priority is None
                batch_priority = batch.priority
                yield PrepareSegmentsBatch(
                    batch=batch,
                    priority=0 if batch_priority is None else batch_priority,
                )
