import datetime
import json
import logging
import os
from time import time
import uuid

import sql
from yt.wrapper import (
    yson,
    create_table_switch,
)

from cached_property import cached_property
import crypta.lib.python.bt.conf.conf as conf
from crypta.audience.lib.tasks.lookalike.paths import WithPaths
from crypta.audience.lib.tasks.base import (
    YtTask,
)
from crypta.audience.lib.tasks.audience import (
    _compute_stats_per_goal,
    _compute_stats_per_segment,
    _prepare_stats,
)
from crypta.audience.lib.tasks.audience.sql import (
    write_segment_goal_relation,
    write_segment_properties,
)
from crypta.audience.lib.tasks.audience.tables import (
    Attributes,
    InputBatch,
    LookalikeSegmentStateStorage,
    LookalikeStatsStorage,
    Output,
    SegmentPropertiesStorage,
    SegmentsGoalsRelationStorage,
    StorageOutput,
    StatsStorage,
)
from crypta.audience.lib.communality import (
    _cosine_quantile_score,
    _raw_trace,
)
from crypta.audience.lib.tasks.lookalike.constants import (
    ALLOW_BIG_ROWS,
    LOOKALIKE_AUDIENCE_TYPE,
    MAPPER_HIGH_MEMORY_USAGE,
    REDUCER_HIGH_MEMORY_USAGE,
    VOLATILE_ID,
)
from crypta.audience.lib.tasks.lookalike.interaction import (
    _output_failed_segment
)
from crypta.audience.lib.tasks.math import (
    _similarity,
)
from crypta.audience.lib.tasks.native_operations import (
    TExtractStorageOutput,
    TKeepUniqueStorageOutput,
)
from crypta.audience.proto.states_pb2 import (
    TExportToStorageState,
)
from crypta.lab.lib.native_operations import (
    TPredictMapper,
    TPredictReducer,
)
from crypta.lab.lib.tables import (
    UserDataStats,
)
from crypta.lab.proto.lookalike_pb2 import (
    TLookalikeMapping,
    TLookalikeReducing,
    TLookalikeOptions,
    TLookalikeOutput,
)
from crypta.lib.proto.user_data.user_data_stats_pb2 import (
    TUserDataStatsOptions,
)
from crypta.lib.python.bt.workflow import (
    IndependentTask,
    Parameter,
)
from crypta.lookalike.lib.python.utils import utils as lal_utils
from crypta.lib.python.yt import schema_utils
from crypta.profile.lib import date_helpers

logger = logging.getLogger(__name__)


class SplitMapper(object):
    """Mapper that routes scored rows to per-segment output tables.

    Rows for which ``yson.YsonUint64`` rejects the YUID with ``ValueError``
    produce no output at all.
    """

    def __init__(self, indexed_segments, ts):
        # Lookup from volatile segment id to the output table index
        # that is passed to ``create_table_switch``.
        self.indexed_segments = indexed_segments
        # Kept on the instance; ``__call__`` does not read it.
        self.ts = ts

    def __call__(self, record):
        """Yield a table switch followed by the cleaned-up row.

        The emitted row is a copy of ``record`` without the volatile id and
        external id columns, with ID_VALUE set to the YUID and SEND set to 0.
        """
        try:
            yson.YsonUint64(record[Output.Fields.YUID])
        except ValueError:
            # Not a valid unsigned id: drop the row.
            return

        row = dict(record)
        segment_key = row.pop(VOLATILE_ID, None)
        row.pop(InputBatch.Fields.EXTERNAL_ID, None)
        row[Output.Fields.ID_VALUE] = row[Output.Fields.YUID]
        row[Output.Fields.SEND] = 0
        yield create_table_switch(self.indexed_segments[segment_key])
        yield row


def _exclude_covariance(userdata_stats):
    """Reset ``Distributions.Main.Covariance`` of the stats message in place.

    The sub-message is re-parsed from an empty payload; nothing is returned.
    """
    covariance = userdata_stats.Distributions.Main.Covariance
    covariance.ParseFromString("")


def _fill_uniq_id_value(stats_per_segment):
    """Set ``Counts.UniqIdValue`` to ``Counts.Total`` for every segment's stats.

    :param stats_per_segment: mapping of segment id to a stats object that
        exposes ``Counts.Total`` and ``Counts.UniqIdValue``; modified in place.
    """
    # id_value == yuid in this case so we hack it
    # Only the values are needed, so iterate them directly; ``values()`` also
    # works on any mapping regardless of the Python major version.
    for stats in stats_per_segment.values():
        stats.Counts.UniqIdValue = stats.Counts.Total


def extract_segments_options(options):
    """Deserialize per-segment lookalike options.

    :param options: mapping of segment id to a serialized ``TLookalikeOptions``.
    :return: dict of segment id to the parsed ``TLookalikeOptions`` message.
    """
    parsed_by_segment = {}
    for segment_id, serialized in options.items():
        parsed = TLookalikeOptions()
        parsed.ParseFromString(serialized)
        parsed_by_segment[segment_id] = parsed
    return parsed_by_segment


class PredictSegments(YtTask, IndependentTask, WithPaths):

    batch = Parameter()
    timeout = datetime.timedelta(hours=12)

    @property
    def batch_path(self):
        """Path object of this task's batch, resolved through ``self.paths``."""
        paths = self.paths
        return paths.batch(self.batch)

    @property
    def done_path(self):
        """The ``done`` attribute of the batch path."""
        batch_path = self.batch_path
        return batch_path.done

    @property
    def segments_meta_table(self):
        """Meta table located under the batch's ``waiting`` path."""
        waiting = self.batch_path.waiting
        return waiting.meta

    def extract_segments_meta(self):
        """Return the rows of the segments meta table keyed by their volatile id."""
        # NOTE(review): the table is read with ``columns=[]`` and then indexed
        # by VOLATILE_ID; whether that column is still present depends on how
        # ``read`` treats an empty column list -- confirm.
        return {
            row[VOLATILE_ID]: row
            for row in self.segments_meta_table.read(columns=[])
        }

    @cached_property
    def destinations(self):
        """Map every segment id to its output segment path (cached).

        The path is built from the segment's ``LookalikeType`` and its id.
        """
        return {
            segment_id: self.paths.output_segments.segment(
                segment_type=options.LookalikeType,
                segment_id=segment_id,
            )
            for segment_id, options in self.segments_options.iteritems()
        }

    @cached_property
    def stats_storage(self):
        """``StatsStorage`` on a fresh ``_init_yt()`` client with its table prepared (cached)."""
        storage = StatsStorage(self._init_yt())
        storage.prepare_table()
        return storage

    @cached_property
    def indexed_segments(self):
        """Map every segment id to its position in ``self.destinations`` (cached)."""
        # Iterating the dict yields keys in the same order as ``items()``.
        return {
            segment_id: position
            for position, segment_id in enumerate(self.destinations)
        }

    def move_replace(self, source, destination):
        """Move ``source`` to ``destination``, removing an existing destination first."""
        destination_exists = self.yt.exists(destination)
        if destination_exists:
            self.yt.remove(destination, recursive=True, force=True)

        self.yt.move(source, destination, recursive=True)

    def get_mapping_state(self):
        """Build the ``TLookalikeMapping`` state passed to the native predict mapper.

        For each row of the segments meta table the prepared user data stats
        (with covariance cleared) and the segment's options are stored under
        the stringified group id. ``EnforceRegion`` is switched off for
        segments whose stats contain only the unknown region. Global stats
        (covariance cleared as well) and the configured filter error rate are
        attached to the mapping.
        """
        meta_columns = [
            UserDataStats.Fields.GROUP_ID,
            UserDataStats.Fields.DISTRIBUTIONS,
            UserDataStats.Fields.ATTRIBUTES,
            UserDataStats.Fields.COUNTS,
            UserDataStats.Fields.AFFINITIES,
            UserDataStats.Fields.STRATUM,
        ]

        mapping = TLookalikeMapping()
        for row in self.segments_meta_table.read(columns=meta_columns):
            segment_stats = _prepare_stats(row)
            segment_id = str(segment_stats.GroupID)

            entry = mapping.Segments[segment_id]
            entry.UserDataStats.CopyFrom(segment_stats)
            _exclude_covariance(entry.UserDataStats)
            entry.Options.CopyFrom(self.segments_options[segment_id])

            if not entry.Options.EnforceRegion:
                continue
            if has_only_unknown_region(entry.UserDataStats):
                logger.info("Change EnforceRegion to False for segment '%s' because there is only unknown region", segment_id)
                entry.Options.EnforceRegion = False

        # Only the first row of the global stats table is used.
        first_global_row = next(self.paths.userdata_stats.read())
        global_stats = _prepare_stats(first_global_row)
        _exclude_covariance(global_stats)
        mapping.GlobalUserDataStats.CopyFrom(global_stats)

        mapping.MaxFilterErrorRate = float(conf.proto.Options.Lookalike.FilterErrorRate)
        return mapping

    def run_native_predict_map(self, scored):
        """Run the native ``TPredictMapper`` over the latest user embeddings into ``scored``.

        The serialized mapping state is passed as the operation state and the
        DSSM files returned by ``get_last_version_of_dssm_entities`` are attached.
        """
        serialized_state = self.get_mapping_state().SerializeToString()
        embeddings_table, dssm_files = lal_utils.get_last_version_of_dssm_entities(self.yt)

        self.native_map(
            TPredictMapper,
            embeddings_table,
            scored,
            serialized_state,
            spec=MAPPER_HIGH_MEMORY_USAGE,
            files=dssm_files,
        )

    def get_reducing_state(self):
        """Build the ``TLookalikeReducing`` state for the native predict reducer.

        Sampling options come from the stats options stored on the segments
        meta table (defaults when none are stored); per-segment options and
        the configured filter error rate are copied in as well.
        """
        serialized_stats_options = self.segments_meta_table.get_stats_options()
        stats_options = TUserDataStatsOptions()
        if serialized_stats_options:
            stats_options.ParseFromString(serialized_stats_options)

        reducing = TLookalikeReducing()
        reducing.SamplingOptions.CopyFrom(stats_options.SamplingOptions)
        # A skip rate below 1e-7 is replaced with 1.
        skip_rate = float(reducing.SamplingOptions.SkipRate)
        if skip_rate < 1e-7:
            reducing.SamplingOptions.SkipRate = 1.

        for segment_id, segment_options in self.segments_options.items():
            reducing.Segments[segment_id].CopyFrom(segment_options)

        reducing.MaxFilterErrorRate = float(conf.proto.Options.Lookalike.FilterErrorRate)
        return reducing

    @cached_property
    def segments_options(self):
        """Parsed ``TLookalikeOptions`` per segment id, read from the meta table (cached)."""
        serialized_options = self.segments_meta_table.get_options()
        return extract_segments_options(serialized_options)

    def run_native_predict_reduce(self, source, destination, sort_by, lookalike_stats_storage=None):
        """Run the native ``TPredictReducer`` and record its custom job statistics.

        :param source: input table(s) for the reduce.
        :param destination: output table path; the ``TLookalikeOutput`` schema is applied to it.
        :param sort_by: sort columns passed to the reduce operation.
        :param lookalike_stats_storage: optional storage; when given and the
            operation reported custom statistics, they are inserted under the
            batch path id.
        """
        reducing_state = self.get_reducing_state().SerializeToString()

        spec = dict(ALLOW_BIG_ROWS, **REDUCER_HIGH_MEMORY_USAGE)
        operation_id = self.native_reduce(
            TPredictReducer,
            source,
            self.yt.TablePath(destination, schema=schema_utils.get_schema_from_proto(TLookalikeOutput)),
            reduce_by=UserDataStats.Fields.GROUP_ID,
            sort_by=sort_by,
            state=reducing_state,
            spec=spec
        )
        op = self.yt.get_operation(operation_id)
        # The statistics are only used for logging/monitoring, so a missing
        # level in progress -> job_statistics -> custom must not fail the task
        # after the reduce itself has already completed.
        progress = op.get('progress') or {}
        job_statistics = progress.get('job_statistics') or {}
        stats = job_statistics.get('custom')
        logger.info("Filtered metrics %s", json.dumps(stats))
        if stats is not None and lookalike_stats_storage is not None:
            lookalike_stats_storage.insert_row(self.batch_path.id, stats)

    def get_segment_type(self, segment_id):
        """Return the segment type name; it is ``'lookalike'`` for every ``segment_id``."""
        segment_type = 'lookalike'
        return segment_type

    def extract_audience_scored(self, scored, audience_scored):
        """Copy into ``audience_scored`` the rows of ``scored`` that belong to audience-type segments.

        A row is kept when the ``LookalikeType`` of its volatile id equals
        ``LOOKALIKE_AUDIENCE_TYPE``.
        """
        lookalike_types = dict(
            (segment_id, segment_options.LookalikeType)
            for segment_id, segment_options in self.segments_options.items()
        )

        def mapper(record):
            record_type = lookalike_types[record[VOLATILE_ID]]
            if record_type == LOOKALIKE_AUDIENCE_TYPE:
                yield record

        self.map(mapper, scored, audience_scored)

    def write_to_storage(self, scored, ts):
        """Export audience-type scored rows into a new table in the storage queue.

        :param scored: table with scored rows.
        :param ts: timestamp stored in the export state given to the mapper.
        """
        with self.yt.TempTable() as audience_scored:
            self.extract_audience_scored(scored, audience_scored)

            # Each export goes to a uniquely named table under the queue directory.
            queue_table = os.path.join(conf.paths.storage.queue, str(uuid.uuid4()))
            self.yt.create(
                'table',
                queue_table,
                recursive=True,
                force=True,
                attributes={Attributes.SCHEMA: StorageOutput.SCHEMA},
            )

            export_state = TExportToStorageState()
            export_state.Timestamp = ts
            self.native_map_reduce(
                mapper_name=TExtractStorageOutput,
                reducer_name=TKeepUniqueStorageOutput,
                source=audience_scored,
                destination=queue_table,
                reduce_by=StorageOutput.REDUCE_BY,
                mapper_state=export_state.SerializeToString(),
            )

    def get_stats(self, stats_per_segment, lookalike_type):
        """Lazily build ``StatsStorage`` records for the segments that have stats.

        :param stats_per_segment: mapping of segment id to a stats message.
        :param lookalike_type: label stored with every record.
        :return: generator of ``StatsStorage.record(...)`` results, one per
            segment of ``self.segments_options`` present in ``stats_per_segment``.
        """
        # Segments without stats (e.g. ones reported as 'Empty output' by
        # ``run``) are skipped; indexing them would raise KeyError in the
        # middle of consumption and abort the rows for every other segment.
        return (
            StatsStorage.record(
                options.PermanentId,
                lookalike_type,
                stats_per_segment[segment_id].SerializeToString(),
            )
            for (segment_id, options) in self.segments_options.iteritems()
            if segment_id in stats_per_segment
        )

    def get_the_date_before_userdata_last_update(self):
        """Return ``date_helpers.get_yesterday`` of the userdata ``_last_update_date`` attribute."""
        last_update_date = self.yt.get_attribute(conf.paths.lab.data.userdata, '_last_update_date')
        return date_helpers.get_yesterday(last_update_date)

    def run(self, **kwargs):
        """Score the batch, split the result per segment and publish segment metadata.

        Steps, in order:

        1. Take an exclusive lock on the batch node.
        2. Recreate every destination table.
        3. Run the native predict map, sort its output, sort the filters taken
           from the segments meta table and run the native predict reduce.
        4. Compute per-segment and per-goal stats, export audience rows to the
           storage queue and split the scored rows into the destination tables.
        5. For every segment write its state, goal relations, table attributes
           and properties; a segment without computed stats is reported as
           failed with 'Empty output'.
        6. Write output/input stats (best effort) and remove the batch node.
        """
        self.yt.lock(self.batch, mode="exclusive")

        storage_client = self._init_yt()

        ts = int(time())

        userdata_stats_path = conf.paths.lab.data.crypta_id.userdata_stats

        split_mapper = SplitMapper(
            indexed_segments=self.indexed_segments,
            ts=ts,
        )

        # Only the first row of the global stats table is used.
        global_stats = _prepare_stats(
            next(self.yt.read_table(userdata_stats_path)))
        assert global_stats

        with self.yt.TempTable(prefix='scored_') as scored, self.yt.TempTable(prefix='filters_') as filters:

            destinations = self.destinations.values()
            for output_table in destinations:
                self.yt.remove(output_table, force=True)
                self.yt.create(
                    'table',
                    output_table,
                    recursive=True,
                    attributes={
                        Attributes.SCHEMA: Output.SCHEMA,
                        Attributes.OPTIMIZE_FOR: 'scan',
                    }
                )

            # Apparently map + sort + reduce is faster
            # as we get more jobs for map.
            self.run_native_predict_map(scored)
            sort_by = [
                'GroupID',
                'MinusRegionSize',
                'Region',
                'MinusDeviceProbability',
                'Device',
                'MinusScore',
            ]

            self.sort(
                scored,
                scored,
                sort_by=sort_by
            )

            # The filters table is recreated with an explicit schema that
            # contains every sort column plus the filter column.
            attrs = sort_by+[UserDataStats.Fields.FILTER]
            schema = [{'name': name, 'type': 'any'} for name in attrs]
            self.yt.remove(filters)
            self.yt.create('table', filters, recursive=True, force=True,
                           attributes={Attributes.SCHEMA: schema})
            self.sort(
                self.segments_meta_table+"{{{},{}}}"
                    .format(UserDataStats.Fields.GROUP_ID, UserDataStats.Fields.FILTER),
                filters,
                sort_by=sort_by,
                spec=ALLOW_BIG_ROWS
            )
            # NOTE(review): the *return value* of ``prepare_table()`` is passed
            # as the stats storage. If it returns None, the reduce statistics
            # are silently not stored -- confirm it returns the storage.
            self.run_native_predict_reduce(
                [filters, scored],
                scored,
                sort_by,
                LookalikeStatsStorage(storage_client).prepare_table()
            )

            all_related_goals = [
                goal
                for _, options in self.segments_options.items()
                for goal in options.RelatedGoals
            ]

            stats_per_segment, scales_per_segment = _compute_stats_per_segment(
                self, scored, group_by=VOLATILE_ID, id_value_column=Output.Fields.YUID,
            )
            stats_per_goal = _compute_stats_per_goal(self, all_related_goals)

            logger.debug('Stats per segment %s', stats_per_segment)

            self.write_to_storage(scored, ts)

            # ``split_mapper`` switches the output table by volatile id.
            self.sort(scored, scored, sort_by=VOLATILE_ID)
            self.map(
                split_mapper, scored, destinations
            )
        with SegmentsGoalsRelationStorage(storage_client, experiment_mode=True).batched_inserter as relation_inserter, \
             LookalikeSegmentStateStorage(storage_client, experiment_mode=True).batched_inserter as state_inserter,\
             SegmentPropertiesStorage(storage_client, experiment_mode=True).batched_inserter as properties_inserter:
            communality_quantiles = properties_inserter.storage.get_communality_quantiles()
            for (segment_id, options) in self.segments_options.iteritems():
                destination = self.destinations[segment_id]
                permanent_id = options.PermanentId
                segment_type = self.get_segment_type(segment_id)
                lookalike_type = options.LookalikeType
                sql.write_segment(
                    permanent_id,
                    input=False,
                    row_count=destination.size,
                    lookalike_type=lookalike_type,
                    state_inserter=state_inserter,
                )

                if segment_id not in stats_per_segment:
                    _output_failed_segment(
                        self,
                        segment_type,
                        segment_id,
                        permanent_id,
                        'Empty output'
                    )
                    continue

                this_segment_stats = stats_per_segment[segment_id]
                related_goals = list(options.RelatedGoals)
                # Goals without their own stats are compared against the global stats.
                raw_similarities = {
                    str(goal): _similarity(this_segment_stats,
                                           stats_per_goal.get(goal, global_stats))
                    for goal in related_goals
                }
                for goal_id, similarity in raw_similarities.iteritems():
                    write_segment_goal_relation(segment_id=permanent_id,
                                                goal_id=goal_id, cosine=similarity, inserter=relation_inserter)
                similarity_scores = {goal: _cosine_quantile_score(similarity) for
                                     (goal, similarity) in
                                     raw_similarities.iteritems()}

                destination.set_extended_stuff(
                    local_stats=this_segment_stats,
                    scale_factor=scales_per_segment.get(segment_id, 1.),
                    global_stats=global_stats,
                    similarity_scores=similarity_scores,
                    segment_type=segment_type,
                    quantiles=communality_quantiles
                )
                destination.set_status_ok()
                logger.info("Output %s", self.yt.get(destination, attributes=[
                    Output.Attributes.STATUS,
                    Output.Attributes.CRYPTA_RELATED_GOALS_SIMILARITY,
                    Output.Attributes.INTERESTS_AFFINITY,
                    Output.Attributes.SEGMENTS_AFFINITY,
                    Output.Attributes.COMMUNALITY,
                    Output.Attributes.COVARIANCE_LOG_DET,
                    Output.Attributes.OVERALL_STATS,
                ]))

                write_segment_properties(
                    segment_id=permanent_id,
                    communality=_raw_trace(this_segment_stats),
                    segment_type=segment_type,
                    inserter=properties_inserter
                )

            input_stats_per_segment = {
                segment[UserDataStats.Fields.GROUP_ID]: _prepare_stats(segment) for segment in
                self.yt.read_table(self.segments_meta_table)
            }
            try:
                self.stats_storage.insert_rows(self.get_stats(stats_per_segment, 'lookalike_output'))
                self.stats_storage.insert_rows(self.get_stats(input_stats_per_segment, 'lookalike_input'))
            except Exception as e:
                # TODO: remove me
                # Format the exception itself: ``e.message`` is deprecated and
                # empty for exceptions constructed with several arguments.
                logger.warning('StatsStorage is failed: %s', e)

            # NOTE(review): the batch is removed while the batched inserters
            # are still open; if they flush on exit and that fails, the batch
            # is already gone -- confirm this ordering is intended.
            self.yt.remove(self.batch, recursive=True, force=True)


class PredictWaitingSegments(YtTask, IndependentTask, WithPaths):
    """Task that yields one ``PredictSegments`` task per waiting batch without locks."""

    @property
    def waiting_dir(self):
        """Directory with waiting batches, taken from ``self.paths``."""
        paths = self.paths
        return paths.waiting_segments

    def run(self, **kwargs):
        """Yield a ``PredictSegments`` for every batch whose lock count is falsy."""
        for batch in self.waiting_dir:
            lock_count = batch.get_attribute(Attributes.LOCK_COUNT)
            if not lock_count:
                yield PredictSegments(batch=batch, priority=batch.priority)


def has_only_unknown_region(user_data_stats):
    """Return True when no entry of ``Attributes.Region`` has a truthy ``Region``.

    An empty ``Attributes.Region`` collection also yields True.
    """
    regions = user_data_stats.Attributes.Region
    return not any(entry.Region for entry in regions)
