from crypta.lab.proto.lookalike_pb2 import (
    Mapping,
    Reducing
)
from crypta.lib.python.bt.paths import (
    Directory,
    Table,
    WithID,
    BasePaths,
)
import crypta.lib.python.bt.commons.dates as dates
import crypta.lib.python.bt.conf.conf as conf

import os.path
import time

import logging
logger = logging.getLogger(__name__)


class WithPaths(object):
    """Mixin adding a ``paths`` property to classes that have a ``yt`` attribute."""

    @property
    def paths(self):
        """Return a fresh ``Paths`` object constructed with ``yt=self.yt``."""
        yt = self.yt
        return Paths(yt=yt)


class Profile(Table):
    """Profile table that exposes its ``generate_date`` attribute."""

    class Fields(object):
        # Field name constants for profile rows.
        YANDEXUID = 'yandexuid'

    class Attributes(object):
        # Attribute name constants read from the table.
        DATE = 'generate_date'

    @property
    def last_processed_date(self):
        """Return the value of the ``generate_date`` attribute of this table."""
        attribute_name = self.Attributes.DATE
        return self.get_attribute(attribute_name)


class Input(Table, WithID):
    """Input segments table with helpers for mapper state and batch splitting."""

    class Fields(object):
        # Keys looked up in every record read from this table.
        OUTPUT_SIZE = 'yandexuid_count'
        DATA_TYPE = 'data_type'
        SEGMENT = 'segment'
        VECTOR = 'vector'

    class Record(dict):
        """Row dict with named accessors for the ``Input.Fields`` keys."""

        @property
        def segment_id(self):
            """Return ``"<data_type>:<segment>"`` built from this row."""
            return "{}:{}".format(self[Input.Fields.DATA_TYPE], self[Input.Fields.SEGMENT])

        @property
        def output_size(self):
            """Return the row's ``yandexuid_count`` value."""
            return self[Input.Fields.OUTPUT_SIZE]

        @property
        def vector(self):
            """Return the row's ``vector`` value."""
            return self[Input.Fields.VECTOR]

    @staticmethod
    def record(record):
        """Wrap a raw row in ``Input.Record``."""
        return Input.Record(record)

    def read(self):
        """Yield every row of the parent ``read()`` wrapped as ``Input.Record``."""
        for record in super(Input, self).read():
            yield self.record(record)

    def get_mapping_state(self, date, profile_size):
        """Build the ``Mapping`` message for this table and return it serialized.

        For every record, the entry ``mapping.segments[segment_id]`` receives
        the record's vector, and its ``parameters.probability`` is set to
        ``output_size / profile_size`` (float division).

        :param date: a numeric timestamp, or a ``str`` day that is parsed with
            ``dates.parse_day`` and converted with ``time.mktime``.
        :param profile_size: divisor used for the per-segment probability.
        :return: result of ``Mapping.SerializeToString()``.
        """
        mapping = Mapping()
        for meta in self.read():
            segment_id = meta.segment_id
            segment_meta = mapping.segments[segment_id]
            segment_meta.vector = meta.vector
            parameters = segment_meta.parameters
            parameters.probability = float(meta.output_size) / profile_size
        if isinstance(date, str):
            date = time.mktime(dates.parse_day(date).timetuple())
        # Oldest accepted timestamp: ``Options.Days`` whole days before ``date``.
        mapping.OldestTimestamp = int(date - 24*60*60*int(conf.proto.Options.Days))
        return mapping.SerializeToString()

    def create_batches(self, paths):
        """Split this table into consecutive ranges and yield a ``Batch`` per range.

        Each range spans at most ``conf.proto.Options.LookalikeBatchSize``
        positions of ``self.size``. For every range a ``Batch`` table named
        ``"<start>-<end>"`` is created under ``self.batches_dir(paths)`` and
        its source path is set to ``"<self>[#<start>:#<end>]"``.
        Nothing is yielded when ``self.size`` is 0.

        :param paths: object providing ``batches_segments`` (see ``batches_dir``).
        :return: generator of created ``Batch`` objects.
        """
        destination_dir = self.batches_dir(paths)
        batch_size = int(conf.proto.Options.LookalikeBatchSize)
        size = self.size
        # Stepping by batch_size keeps all arithmetic in integers; the former
        # ``(size + batch_size) / batch_size`` count is a float under Python 3
        # and cannot be passed to ``range``.
        for start_position in range(0, size, batch_size):
            end_position = min(size, start_position + batch_size)
            path = "{}[#{}:#{}]".format(self, start_position, end_position)
            batch_name = "{}-{}".format(start_position, end_position)
            batch = Batch(self.yt, destination_dir.child_table(batch_name))
            batch.create()
            batch.set_source_path(path)
            yield batch

    def batches_dir(self, paths):
        """Return the child directory of ``paths.batches_segments`` named by ``self.id``."""
        return paths.batches_segments.child_directory(self.id)

    def done_dir(self, paths):
        """Return the child directory of ``paths.done_segments`` named by ``self.id``."""
        return paths.done_segments.child_directory(self.id)

    def output_path(self, paths):
        """Return the child directory of ``paths.output`` named by ``self.id``."""
        return paths.output.child_directory(self.id)


class Batch(Table):
    """Batch table that records, as an attribute, the source path it was cut from."""

    SEGMENT_ID = 'volatile_id'
    MINUS_SCORE = '_score'

    class Attributes(object):
        # Name of the table attribute holding the source path.
        SOURCE_PATH = 'source_path'

    def set_source_path(self, source_path):
        """Store ``source_path`` in this table's ``source_path`` attribute."""
        attribute_name = self.Attributes.SOURCE_PATH
        self.set_attribute(attribute_name, source_path)

    def get_source_path(self):
        """Return an ``Input`` built from the stored ``source_path`` attribute."""
        stored_path = self.get_attribute(self.Attributes.SOURCE_PATH)
        return Input(self.yt, stored_path)

    def get_reducing_state(self):
        """Build the ``Reducing`` message for the source records of this batch.

        Each distinct ``segment_id`` read from the source gets an entry whose
        ``counts.output`` is the record's ``output_size`` (the last record wins
        on duplicates) and whose ``table_index`` is its position in iteration.

        :return: ``(segment_ids, serialized_reducing)`` where ``segment_ids``
            is ordered by ``table_index``.
        """
        output_sizes = {
            record.segment_id: record.output_size
            for record in self.get_source_path().read()
        }
        state = Reducing()
        ordered_ids = []
        for table_index, segment_id in enumerate(output_sizes):
            entry = state.segments[segment_id]
            entry.counts.output = output_sizes[segment_id]
            entry.table_index = table_index
            ordered_ids.append(segment_id)
        return ordered_ids, state.SerializeToString()

    def output(self, paths):
        """Return the ``Batch`` at ``<done_segments>/<parent dir name>/<own name>``."""
        parent, name = os.path.split(str(self))
        target = os.path.join(paths.done_segments, os.path.basename(parent), name)
        return Batch(self.yt, target)

    def prepare_output(self, paths):
        """Create the output batch and copy this batch's source path onto it."""
        result = self.output(paths)
        result.create()
        source = self.get_source_path()
        result.set_source_path(source)
        return result

    def get_output_separate_paths(self, paths, segments):
        """Return one path per segment under ``<output>/<parent dir name>``.

        Logs ``segments`` and the computed root at INFO level.
        """
        logger.info(segments)
        parent = os.path.dirname(str(self))
        root = os.path.join(paths.output, os.path.basename(parent))
        logger.info(root)
        return [os.path.join(root, segment) for segment in segments]


class InputTables(Directory):
    """Directory whose entries are iterated as ``Input`` tables."""

    def __iter__(self):
        """Yield an ``Input`` for every entry returned by ``self.list``.

        Entries are requested with ``absolute=True`` and ``max_size=None``.
        """
        for segments in self.list(absolute=True, max_size=None):
            # Tables in this module are constructed as (yt, path); the
            # original call omitted the yt argument.
            yield Input(self.yt, segments)


class Paths(BasePaths):
    """Accessors for the tables and directories configured under ``conf.paths``."""

    @property
    def profiles(self):
        """``Profile`` at ``conf.paths.vectors.monthly``."""
        location = conf.paths.vectors.monthly
        return Profile(self.yt, location)

    @property
    def daily_profiles(self):
        """``Profile`` at ``conf.paths.vectors.daily``."""
        location = conf.paths.vectors.daily
        return Profile(self.yt, location)

    @property
    def input(self):
        """``InputTables`` at ``conf.paths.lab.lookalike.input``."""
        location = conf.paths.lab.lookalike.input
        return InputTables(self.yt, location)

    @property
    def output(self):
        """``Directory`` at ``conf.paths.lab.lookalike.output``."""
        location = conf.paths.lab.lookalike.output
        return Directory(self.yt, location)

    @property
    def batches_segments(self):
        """``Directory`` at ``conf.paths.lab.lookalike.segments.batches``."""
        location = conf.paths.lab.lookalike.segments.batches
        return Directory(self.yt, location)

    @property
    def done_segments(self):
        """``Directory`` at ``conf.paths.lab.lookalike.segments.done``."""
        location = conf.paths.lab.lookalike.segments.done
        return Directory(self.yt, location)

    def get_input_segments(self, path):
        """Wrap ``path`` as an ``Input`` table."""
        return Input(self.yt, path)

    def get_batched_segments(self, path):
        """Wrap ``path`` as a ``Batch`` table."""
        return Batch(self.yt, path)

    def get_linked_path(self, destination):
        """Return the configured link path for ``destination``, or ``None``.

        Only a destination whose basename is ``segment_vectors_for_audience``
        has a link path.
        """
        name = os.path.basename(destination)
        if name != 'segment_vectors_for_audience':
            return None
        return conf.paths.lab.lookalike.link_segments_by_audience
