import crypta.lib.python.bt.conf.conf as conf
import datetime
import os.path
import uuid

from yt.common import date_string_to_timestamp

from crypta.audience.lib.tasks.lookalike.constants import (
    SEGMENT_TYPE, DEFAULT_SEGMENT_TYPE,
    PRIORITY
)
from crypta.lib.python.bt.paths import (
    WithYt, Directory, Table, File, YtObject, WithID,
)

from crypta.audience.lib.tasks.audience import (
    _output_stats,
)
from crypta.audience.lib.affinity.affinity import (
    _compute_affinities,
)
from crypta.audience.lib.communality import (
    _compute_communality,
    _raw_trace,
)
from crypta.audience.lib.tasks.audience.tables import (
    Output,
    Attributes,
)
from crypta.lab.lib.tables import (
    UserDataStats,
)

import logging

import yt.yson as yson


# Module-level logger. NOTE(review): not referenced anywhere in the visible part of this file.
logger = logging.getLogger(__name__)


def get_time_priority(batch):
    """Derive a fallback priority from a batch's creation time.

    Returns 0 when the creation-time attribute is absent. Otherwise returns the
    negated creation timestamp, so a batch created earlier gets a larger value.
    """
    created_at = batch.get_attribute(Attributes.CREATION_TIME, use_cached_attr=True)
    if created_at is not None:
        return -int(date_string_to_timestamp(created_at))
    return 0


def get_batch_priority(batch):
    """Return the effective priority of ``batch``, or None if it has none.

    Returns None when the PRIORITY attribute is missing or is a YSON entity.
    A positive integer priority is returned as-is. A zero or negative value
    falls back to :func:`get_time_priority`.
    """
    raw = batch.get_attribute(PRIORITY, use_cached_attr=True)
    missing = raw is None or isinstance(raw, yson.yson_types.YsonEntity)
    if missing:
        return None

    explicit = int(raw)
    if explicit > 0:
        return explicit
    return get_time_priority(batch)


class WithPaths(object):
    """Mixin exposing a :class:`Paths` helper built from ``self.yt``."""

    @property
    def paths(self):
        """A new :class:`Paths` instance bound to this object's YT client."""
        client = self.yt
        return Paths(yt=client)


class Batch(Directory, WithID):
    """A batch directory containing ``Meta`` and ``Users`` tables."""

    @property
    def meta(self):
        """The ``Meta`` table inside this batch directory."""
        meta_path = os.path.join(self, 'Meta')
        return Meta(self.yt, meta_path)

    @property
    def users(self):
        """The ``Users`` table inside this batch directory."""
        users_path = os.path.join(self, 'Users')
        return Table(self.yt, users_path)

    @property
    def priority(self):
        """Effective batch priority (see :func:`get_batch_priority`)."""
        return self.get_priority()

    def get_priority(self):
        """Compute the effective priority via :func:`get_batch_priority`."""
        return get_batch_priority(self)

    def set_priority(self, priority):
        """Store ``priority`` in the PRIORITY attribute of this directory."""
        self.set_attribute(PRIORITY, priority)


class TotalBatch(Directory, WithID, WithPaths):
    """Views of one batch id across the input, waiting and done locations."""

    def _path_in(self, root):
        """Path of an entry named after this batch's id inside ``root``."""
        return os.path.join(root, self.id)

    @property
    def input(self):
        """This batch's entry under the batched-inputs directory."""
        return InputBatch(self.yt, self._path_in(self.paths.batched_inputs))

    @property
    def waiting(self):
        """This batch's directory under the waiting-segments location."""
        return Batch(self.yt, self._path_in(self.paths.waiting_segments))

    @property
    def done(self):
        """This batch's directory under the done-segments location."""
        return Batch(self.yt, self._path_in(self.paths.done_segments))


class Segment(Table, WithID, WithPaths):
    """A lookalike segment table and the attributes stored on it.

    Attribute names are declared as class constants. The methods below are
    typed getters/setters over ``get_attribute`` / ``set_attribute``.
    ``get_attribute`` is overridden so that, by default, ``None`` values are
    replaced with the supplied default.
    """

    STATUS = 'crypta_status'
    ERROR = 'crypta_error'
    LOCKS = 'locks'
    SEGMENT_ID = 'segment_id'

    RELATED_GOALS = 'crypta_related_goals'
    INTERESTS_AFFINITY = Output.Attributes.INTERESTS_AFFINITY
    SEGMENTS_AFFINITY = Output.Attributes.SEGMENTS_AFFINITY
    RELATED_GOALS_SIMILARITY = Output.Attributes.CRYPTA_RELATED_GOALS_SIMILARITY
    COMMUNALITY = Output.Attributes.COMMUNALITY
    OVERALL_STATS = Output.Attributes.OVERALL_STATS
    SEGMENT_INFO = 'crypta_segment_info'

    # Values stored in the STATUS attribute.
    STATUS_NEW = Output.Statuses.NEW
    STATUS_DONE = Output.Statuses.DONE
    STATUS_FAILED = Output.Statuses.FAILED

    # Per-segment options read by the getters below.
    MAINTAIN_DEVICE_DISTRIBUTION = 'crypta_maintain_device_distribution'
    MAINTAIN_GEO_DISTRIBUTION = 'crypta_maintain_geo_distribution'
    NUM_OUTPUT_BUCKETS = 'crypta_lookalike_precision'
    MAX_COVERAGE = '_max_coverage'
    PRIORITY = 'segment_priority'
    INCLUDE_INPUT = '_include_input'

    @staticmethod
    def attribute_keys():
        """Return the option and bookkeeping attribute names of a segment."""
        return [
            Segment.MAINTAIN_DEVICE_DISTRIBUTION,
            Segment.MAINTAIN_GEO_DISTRIBUTION,
            Segment.NUM_OUTPUT_BUCKETS,
            Segment.MAX_COVERAGE,
            Segment.RELATED_GOALS,
            Segment.SEGMENT_ID,
            Segment.ROW_COUNT,
            Segment.LOCKS,
            SEGMENT_TYPE,
            Segment.INCLUDE_INPUT,
            Segment.SEGMENT_INFO,
        ]

    @property
    def priority(self):
        """Segment priority as a float (see :meth:`get_segment_priority`)."""
        return self.get_segment_priority()

    def get_attribute(self, attr_name, use_cached_attr=True, **kwargs):
        """Read ``attr_name``, by default substituting the default for ``None``.

        ``replace_none_to_default`` defaults to True. A caller may pass it
        explicitly, and that value is honored.
        """
        # setdefault instead of a literal keyword argument: passing the same
        # keyword both explicitly and via **kwargs raises TypeError, which used
        # to break get_segment_type() and everything that calls it.
        kwargs.setdefault('replace_none_to_default', True)
        return super(Segment, self).get_attribute(
            attr_name, use_cached_attr=use_cached_attr, **kwargs)

    def is_ready(self):
        """True when status is NEW, there are no locks, and a segment id is set."""
        is_status_new = self.get_status(default=None) == self.STATUS_NEW
        is_not_locked = not self.get_locks()
        has_segment_id = self.get_permanent_id(default=None) is not None
        return is_status_new and is_not_locked and has_segment_id

    def set_status_ok(self):
        """Mark the segment DONE and reset the error attribute to None."""
        self.set_attribute(self.STATUS, self.STATUS_DONE)
        self.set_attribute(self.ERROR, None)

    def set_status_failed(self, error_message=None):
        """Mark the segment FAILED and store stats of an empty ``UserDataStats``.

        The error attribute is written only when ``error_message`` is truthy.
        """
        self.set_attribute(self.STATUS, self.STATUS_FAILED)
        self.set_attribute(self.OVERALL_STATS, _output_stats(UserDataStats.Proto()))
        if error_message:
            self.set_attribute(self.ERROR, error_message)

    def set_extended_stuff(self, local_stats, scale_factor, global_stats, similarity_scores, segment_type, quantiles):
        """Compute and store communality, covariance trace, affinities and stats."""
        self.set_attribute(self.COMMUNALITY, _compute_communality(local_stats, segment_type, quantiles))
        self.set_attribute(Output.Attributes.COVARIANCE_TRACE, _raw_trace(local_stats))
        self.set_attribute(self.INTERESTS_AFFINITY, _compute_affinities(Output.CRYPTA_INTERESTS, local_stats, global_stats))
        self.set_attribute(self.SEGMENTS_AFFINITY, _compute_affinities(Output.CRYPTA_SEGMENTS, local_stats, global_stats))
        self.set_attribute(self.RELATED_GOALS_SIMILARITY, similarity_scores)
        self.set_attribute(self.OVERALL_STATS, _output_stats(local_stats, scale_factor))

    def get_status(self, **kwargs):
        """Return the STATUS attribute."""
        return self.get_attribute(self.STATUS, **kwargs)

    def set_segment_type(self, segment_type):
        """Store the segment type attribute."""
        self.set_attribute(SEGMENT_TYPE, segment_type)

    def get_segment_type(self, default=DEFAULT_SEGMENT_TYPE, **kwargs):
        """Return the segment type, falling back to ``default`` when missing/None."""
        # None -> default replacement is applied by get_attribute itself.
        return self.get_attribute(SEGMENT_TYPE, default=default, **kwargs)

    def get_segment_priority(self, default='0', **kwargs):
        """Return the segment priority attribute converted to float."""
        return float(self.get_attribute(self.PRIORITY, default=default, **kwargs))

    def get_permanent_id(self, **kwargs):
        """Return the ``segment_id`` attribute."""
        return self.get_attribute(self.SEGMENT_ID, **kwargs)

    def get_segment_info(self, **kwargs):
        """Return the segment info attribute, or an empty dict when it is None."""
        info = self.get_attribute(self.SEGMENT_INFO, **kwargs)
        if info is None:
            info = {}
        return info

    def get_locks(self, **kwargs):
        """Return the ``locks`` attribute."""
        return self.get_attribute(self.LOCKS, **kwargs)

    def get_output_segment(self):
        """Segment with the same type and id under the output segments root."""
        return self.paths.output_segments.segment(self.get_segment_type(),
                                                  self.id)

    def get_input_segment(self):
        """Segment with the same type and id under the input segments root."""
        return self.paths.input_segments.segment(self.get_segment_type(),
                                                 self.id)

    def _get_flag(self, attr_name):
        """Read a boolean option that defaults to True.

        A string value is true only when it equals ``'true'``. Any other value
        is interpreted with ``bool``.
        """
        option = self.get_attribute(attr_name, default=True)
        if isinstance(option, str):
            return option == 'true'
        return bool(option)

    def get_enforce_device_and_platform(self):
        """Whether the device/platform distribution should be maintained."""
        return self._get_flag(self.MAINTAIN_DEVICE_DISTRIBUTION)

    def get_enforce_region(self):
        """Whether the geo distribution should be maintained."""
        return self._get_flag(self.MAINTAIN_GEO_DISTRIBUTION)

    def get_related_goals(self):
        """Return the related-goals attribute, defaulting to an empty list."""
        return self.get_attribute(self.RELATED_GOALS, default=[])

    def get_num_output_buckets(self):
        """Requested number of output buckets, clamped to [0, PrecisionBuckets]."""
        max_n_buckets = conf.proto.Options.Lookalike.PrecisionBuckets
        n_buckets = self.get_attribute(
            self.NUM_OUTPUT_BUCKETS,
            default=max_n_buckets
        )
        return max(0, min(int(n_buckets), max_n_buckets))

    def get_max_coverage(self):
        """Max coverage attribute, defaulting to the configured MaxCoverage."""
        return self.get_attribute(
            self.MAX_COVERAGE,
            default=conf.proto.Options.Lookalike.MaxCoverage,
        )

    def get_input_including_mode(self):
        """Return the include-input option, defaulting to False."""
        return self.get_attribute(
            self.INCLUDE_INPUT,
            default=False
        )


class WithSegments(YtObject):
    """Mixin storing a segments value in the ``Segments`` attribute."""

    SEGMENTS = 'Segments'

    @property
    def segments(self):
        """Stored segments, or an empty tuple when unset/falsy."""
        return self.get_segments()

    def set_segments(self, segments, **kwargs):
        """Write ``segments`` to the ``Segments`` attribute."""
        self.set_attribute(self.SEGMENTS, segments, **kwargs)

    def get_segments(self, **kwargs):
        """Read the ``Segments`` attribute; falsy values become ``()``."""
        stored = self.get_attribute(self.SEGMENTS, **kwargs)
        return stored if stored else ()


class Sample(Table, WithSegments):
    """A table that also carries a ``Segments`` attribute (via WithSegments)."""


class Model(File, WithSegments):
    """A file carrying ``Segments`` and ``Metrics`` attributes."""

    METRICS = 'Metrics'

    def set_metrics(self, segments, **kwargs):
        """Write the metrics value to the ``Metrics`` attribute.

        NOTE(review): the parameter is named ``segments`` but holds the
        metrics; the name is kept for keyword-argument compatibility.
        """
        self.set_attribute(self.METRICS, segments, **kwargs)

    def get_metrics(self, **kwargs):
        """Read the ``Metrics`` attribute; falsy values become ``{}``."""
        metrics = self.get_attribute(self.METRICS, **kwargs)
        if not metrics:
            return {}
        return metrics


class Meta(Table):
    """Batch metadata table keeping options and stats in its attributes."""

    OPTIONS = 'options'
    STATS = 'crypta_stats'
    STATS_OPTIONS = 'stats_options'

    # --- setters: each returns whatever set_attribute returns ---

    def set_options(self, value):
        """Write the ``options`` attribute."""
        return self.set_attribute(Meta.OPTIONS, value)

    def set_stats(self, value):
        """Write the ``crypta_stats`` attribute."""
        return self.set_attribute(Meta.STATS, value)

    def set_stats_options(self, value):
        """Write the ``stats_options`` attribute."""
        return self.set_attribute(Meta.STATS_OPTIONS, value)

    # --- getters ---

    def get_options(self, **kwargs):
        """Read the ``options`` attribute."""
        return self.get_attribute(Meta.OPTIONS, **kwargs)

    def get_stats(self, **kwargs):
        """Read the ``crypta_stats`` attribute."""
        return self.get_attribute(Meta.STATS, **kwargs)

    def get_stats_options(self, default="", **kwargs):
        """Read the ``stats_options`` attribute, defaulting to an empty string."""
        return self.get_attribute(Meta.STATS_OPTIONS, default=default, **kwargs)


class InputBatch(Table, WithID):
    """A batched-input table whose attributes describe the contained segments."""

    SEGMENTS_META = 'segments_meta'
    BATCH_SIZE = 'batch_size'

    @property
    def priority(self):
        """Effective priority (see :func:`get_batch_priority`)."""
        return get_batch_priority(self)

    @property
    def segments_meta(self):
        """The ``segments_meta`` attribute."""
        return self.get_attribute(self.SEGMENTS_META)

    @property
    def batch_size(self):
        """Stored ``batch_size``, or ``len(segments_meta)`` when not stored."""
        stored = self.get_attribute(self.BATCH_SIZE)
        return int(len(self.segments_meta) if stored is None else stored)

    def set_segments_meta(self, segments_meta):
        """Store ``segments_meta`` and its length as ``batch_size``.

        The size is written first, then the metadata itself.
        """
        self.set_attribute(self.BATCH_SIZE, len(segments_meta))
        self.set_attribute(self.SEGMENTS_META, segments_meta)


class InputBatches(Directory):
    """Directory of :class:`InputBatch` tables."""

    def __iter__(self):
        """Yield an InputBatch per child; yields nothing if the directory is absent."""
        if self.exists():
            listed = self.list(absolute=True, attributes=[InputBatch.BATCH_SIZE])
            for path in listed:
                yield InputBatch(self.yt, path)

    def batch(self, batch_id):
        """InputBatch for ``batch_id`` inside this directory."""
        batch_path = os.path.join(self, batch_id)
        return InputBatch(self.yt, batch_path)


class InputFolder(Directory):
    """A per-type folder under a segments root; its basename is the type.

    The folder's ``priority`` attribute is exposed via :attr:`priority`.
    """

    PRIORITY = 'priority'

    @staticmethod
    def get_default_priority(lookalike_type):
        """Priority used when the folder has no ``priority`` attribute.

        Always 0 here. ``lookalike_type`` is unused; presumably it is kept as a
        hook for per-type defaults.
        """
        return 0

    @property
    def priority(self):
        """The folder's ``priority`` attribute, or the default priority."""
        lookalike_type = os.path.basename(self)
        # Pass the default by keyword, as every other get_attribute call in this
        # module does. The second positional parameter of get_attribute is not
        # guaranteed to be ``default`` (Segment's override has
        # ``use_cached_attr`` in that position).
        return self.get_attribute(
            self.PRIORITY,
            default=self.get_default_priority(lookalike_type),
        )


class IOSegments(Directory):
    """A segments root laid out as ``<root>/<segment_type>/<segment_id>``."""

    def __iter__(self):
        """Yield ``((segment_type, folder_priority), Segment)`` for every segment.

        Each listed segment's ``attributes`` dict is rebuilt to hold exactly the
        requested keys; keys that were not returned are set to None.
        """
        base_keys = [
            Segment.ROW_COUNT,
            Segment.PRIORITY,
            Segment.MODIFICATION_TIME,
            Segment.SEGMENT_ID,
            Segment.LOCKS,
            Segment.STATUS,
        ]
        attributes = list(set(base_keys + Segment.attribute_keys()))

        for folder_path in self.list(absolute=True, attributes=[InputFolder.PRIORITY]):
            folder = InputFolder(self.yt, folder_path)
            group_key = (os.path.basename(folder), folder.priority)

            listing = self.yt.list(folder, max_size=None, absolute=True,
                                   attributes=attributes)
            for node in listing:
                node.attributes = dict(
                    (name, node.attributes.get(name)) for name in attributes
                )
                yield group_key, Segment(self.yt, node)

    @property
    def size(self):
        """Sum of child counts over every map_node found under this root.

        NOTE(review): if ``yt.search`` also yields the root itself, the root's
        own children (the per-type folders) are counted too. Confirm this is
        intended.
        """
        found = self.yt.search(root=self, node_type=["map_node"], follow_links=True)
        return sum(Directory(self.yt, node).count for node in found)

    def segment(self, segment_type, segment_id):
        """Segment at ``<root>/<segment_type>/<segment_id>``."""
        segment_path = os.path.join(self, segment_type, segment_id)
        return Segment(self.yt, segment_path)


class Batches(Directory):
    """Directory of :class:`Batch` directories."""

    def __iter__(self):
        """Yield a Batch per child, prefetching creation time, lock count and priority.

        Each child's ``attributes`` dict is rebuilt to hold exactly those keys,
        with missing ones set to None. Nothing is yielded if the directory is
        absent.
        """
        if self.exists():
            wanted = [Attributes.CREATION_TIME, Attributes.LOCK_COUNT, PRIORITY]
            for node in self.list(absolute=True, attributes=wanted):
                node.attributes = dict((name, node.attributes.get(name)) for name in wanted)
                yield Batch(self.yt, node)


class NewTable(Table):
    """A table tagged with a random ``data_id`` attribute."""

    DATA_ID = 'data_id'

    def init_data_id(self):
        """Assign a fresh random UUID4 string as ``data_id``."""
        fresh_id = str(uuid.uuid4())
        self.set_attribute(self.DATA_ID, fresh_id)

    @property
    def data_id(self):
        """The ``data_id`` attribute, or None when not set."""
        return self.get_attribute(self.DATA_ID, default=None)


class Paths(WithYt):
    """Accessors for lookalike/lab/audience YT locations taken from ``conf.paths``."""

    @property
    def waiting_segments(self):
        """Batches waiting to be processed."""
        segments_conf = conf.paths.lookalike.segments
        return Batches(self.yt, segments_conf.waiting)

    @property
    def done_segments(self):
        """Batches that have been processed."""
        segments_conf = conf.paths.lookalike.segments
        return Batches(self.yt, segments_conf.done)

    @property
    def archive(self):
        """Archive location, wrapped as a single :class:`Batch`."""
        segments_conf = conf.paths.lookalike.segments
        return Batch(self.yt, segments_conf.archive)

    @property
    def batched_inputs(self):
        """Directory of batched input tables."""
        segments_conf = conf.paths.lookalike.segments
        return InputBatches(self.yt, segments_conf.batches)

    @property
    def input_segments(self):
        """Root of input segments."""
        lookalike_conf = conf.paths.lookalike
        return IOSegments(self.yt, lookalike_conf.input)

    @property
    def output_segments(self):
        """Root of output segments."""
        lookalike_conf = conf.paths.lookalike
        return IOSegments(self.yt, lookalike_conf.output)

    @property
    def userdata(self):
        """Lab userdata table."""
        return Table(self.yt, conf.paths.lab.data.userdata)

    @property
    def userdata_stats(self):
        """Lab userdata stats table (crypta_id variant)."""
        return Table(self.yt, conf.paths.lab.data.crypta_id.userdata_stats)

    @property
    def audience_by_id_value(self):
        """Audience matching table keyed by id value."""
        return Table(self.yt, conf.paths.audience.matching.by_id_value)

    def batch(self, batch):
        """Wrap ``batch`` as a :class:`TotalBatch`."""
        return TotalBatch(self.yt, batch)

    def new_model(self):
        """Create the models directory and return a Model named by the local time.

        NOTE(review): ``models`` is not defined on this class; this raises
        AttributeError unless ``WithYt`` or a subclass provides it.
        """
        name = datetime.datetime.now().isoformat()
        self.models.create()
        return Model(self.yt, self.models.child_file(name))

    @staticmethod
    def generate_new_batch_id():
        """Return ``'batch-'`` followed by a random UUID4."""
        return 'batch-' + str(uuid.uuid4())

    def table(self, path):
        """Wrap ``path`` as a :class:`Table`."""
        return Table(self.yt, path)

    def directory(self, path):
        """Wrap ``path`` as a :class:`Directory`."""
        return Directory(self.yt, path)


# Only ``WithPaths`` is exported by ``from <module> import *``; every other name
# must be imported explicitly. NOTE(review): confirm this narrow export list is intended.
__all__ = ['WithPaths']
