import datetime
import logging
import os
from time import time

import numpy as np
from yt.wrapper import (
    OperationsTracker,
)

import crypta.lib.python.bt.conf.conf as conf
from crypta.lib.python.yt import yt_helpers
from crypta.audience.lib.tasks.audience.audience import (
    _prepare_stats,
    _prepare_goals_users,
    _foreign,
    _collect_output_stats,
    _get_list_batches_with_priorities,
)
from crypta.audience.lib.tasks.audience.sql import (
    write_segment_on_output,
    write_segment_properties,
    write_segment_goal_relation,
)
from crypta.audience.lib.tasks.audience.tables import (
    Attributes,
    Input,
    InputBatch,
    Output,
    OutputBatch,
    StorageOutput,
    StatsStorage,
    RegularSegmentStateStorage,
    SegmentsGoalsRelationStorage,
    SegmentPropertiesStorage,
    GeoBatch,
)
from crypta.audience.lib.tasks.base import (
    YtTask,
)
from crypta.audience.lib.communality import (
    _cosine_quantile_score,
    _raw_trace,
)
from crypta.audience.lib.tasks.math import (
    _similarity,
)
from crypta.audience.lib.tasks.native_operations import (
    TMergeUserDataStats,
    TPruneUserData,
    TCountMapper,
    TCountReducer,
    TTransformCountsToStats,
    TMergeStatsWithCounts,
    TMatchWithUserData,
    TSplitUserDataToSegmentStats,
    TEmptyMapper,
    TExtractGeoStorageOutput,
)
from crypta.audience.proto.audience_geo_pb2 import (
    TProjectionMat,
    TPrunedOptions,
    TExportGeoToStorageState,
)
from crypta.lab.lib.specs import (
    YtSpecs,
    _spec,
)
from crypta.lab.lib.tables import (
    UserData,
    UserDataStats,
)
from crypta.lib.proto.user_data.math_pb2 import (
    TVectorType,
)
from crypta.lib.proto.user_data import (
    user_data_pb2,
    user_data_stats_pb2,
)
from crypta.lib.python.bt.workflow import (
    IndependentTask,
    Parameter,
)
from crypta.lib.python.bt.workflow.targets.table import (
    HasAttribute,
)


# Module-level logger used by the tasks and helpers below.
logger = logging.getLogger(__name__)


class CreatePrunedUserData(YtTask, IndependentTask):
    """Build the pruned userdata table and the stats computed over it.

    ``run`` maps lab userdata through ``TPruneUserData`` (configured with the
    used strata segments and a random projection matrix), sorts the result by
    user id, computes stats over it, and stamps both output tables with the
    source userdata ``LAST_UPDATE_DATE`` attribute. ``targets`` is satisfied
    once the pruned table carries the same date as the source table.
    """

    @property
    def userdata(self):
        """Source lab userdata table."""
        return conf.paths.lab.data.userdata

    @property
    def pruned_userdata(self):
        """Destination pruned userdata table."""
        return conf.paths.audience.pruned_userdata

    @property
    def pruned_userdata_stats(self):
        """Destination table for stats computed over the pruned userdata."""
        return conf.paths.audience.pruned_userdata_stats

    def targets(self):
        yield HasAttribute(
            self.yt, self.pruned_userdata, UserData.Attributes.LAST_UPDATE_DATE, self.get_attr_value()
        )

    def get_attr_value(self):
        """Return the source userdata ``LAST_UPDATE_DATE`` attribute, "" if unset."""
        return self.yt.get_attribute(
            self.userdata,
            UserData.Attributes.LAST_UPDATE_DATE,
            ""
        )

    def get_random_projection_mat(self, m, n, seed):
        """Return an ``(m, n)`` sparse random projection matrix.

        Entries are ``sqrt(3 / m) * {-1, 0, +1}``, drawn with probabilities
        ``{1/6, 2/3, 1/6}``. The constants come from
        https://users.soe.ucsc.edu/~optas/papers/jl.pdf

        :param m: number of output components (rows).
        :param n: input dimensionality (columns).
        :param seed: seed making the matrix reproducible.
        """
        # A dedicated RandomState gives the same reproducible draw for ``seed``
        # without reseeding numpy's process-global generator.
        rng = np.random.RandomState(seed)
        # ``p`` must be passed by keyword: the third positional parameter of
        # ``choice`` is ``replace``, so positional probabilities were silently
        # ignored and entries were drawn uniformly.
        # NOTE(review): this changes the generated matrix (and therefore the
        # projected vectors and stats derived from them) compared to earlier
        # runs; confirm whether stored stats/quantiles need a refresh.
        entries = rng.choice([-1., 0., 1.], size=n * m, p=[1. / 6, 2. / 3, 1. / 6])
        mat = (3. / m) ** (1. / 2) * entries
        return mat.reshape((m, n))

    def get_projection_mat(self, m, n, seed):
        """Wrap ``get_random_projection_mat(m, n, seed)`` into a ``TProjectionMat`` proto."""
        mat = TProjectionMat()

        projection_mat = self.get_random_projection_mat(m, n, seed=seed)
        logger.info("Projection mat with %s components, %s seed: %s", m, seed, projection_mat)
        rows = []
        for values in projection_mat:
            row = TVectorType()
            row.Data.extend(values)
            rows.append(row)
        mat.Mat.Rows.extend(rows)
        return mat

    def run(self, **kwargs):
        userdata = self.userdata
        pruned_userdata = self.pruned_userdata
        pruned_userdata_stats = self.pruned_userdata_stats

        # Options for the native prune mapper: which strata segments to keep
        # and the 8x512 projection matrix (fixed seed for reproducibility).
        options = TPrunedOptions()
        strata_segments = options.UsedStrataSegments
        strata_segments.Segment.extend(
            Output.CRYPTA_INTERESTS + Output.CRYPTA_SEGMENTS)
        options.ProjectionMat.CopyFrom(self.get_projection_mat(8, 512, seed=42))
        state = options.SerializeToString()

        self.yt.remove(pruned_userdata, force=True)
        self.native_map(
            TPruneUserData,
            userdata,
            pruned_userdata,
            spec=_spec(
                YtSpecs.MAPPER_BIG_ROWS,
                YtSpecs.MAPPER_HIGH_MEMORY_USAGE,
            ),
            state=state
        )
        self.sort(pruned_userdata, pruned_userdata, sort_by=GeoBatch.Fields.USER_ID)
        _compute_stats_per_segment(self, pruned_userdata,
                                   stats_path=pruned_userdata_stats)

        yt_helpers.set_yql_proto_field(pruned_userdata, 'UserData', user_data_pb2.TUserData, self.yt)
        yt_helpers.set_yql_proto_fields(pruned_userdata_stats, user_data_stats_pb2.TUserDataStats, self.yt)

        # Stamp outputs with the source date so ``targets`` sees them as fresh.
        date_attr = self.get_attr_value()
        self.yt.set_attribute(pruned_userdata, UserData.Attributes.LAST_UPDATE_DATE, date_attr)
        self.yt.set_attribute(pruned_userdata_stats, UserData.Attributes.LAST_UPDATE_DATE, date_attr)


def _compute_stats_per_goal(task, goals, stats_path=None):
    """Compute user-data stats for each goal in ``goals``.

    Users of the goals are collected with ``_prepare_goals_users`` and their
    stats are computed by ``_compute_stats_per_segment`` without counts.

    :param goals: iterable of goal ids.
    :param stats_path: optional table to keep the stats in; a temporary table
        is used when it is falsy.
    :return: dict mapping ``str(GroupID)`` to the prepared stats, or ``{}``
        when ``_prepare_goals_users`` reports nothing was prepared.
    """
    with task.yt.TempTable(prefix='goals_users_') as goals_users, \
            task.yt.TempTable(prefix='stats_per_goal_') as scratch_stats:
        goal_stats_table = stats_path if stats_path else scratch_stats

        users_prepared = _prepare_goals_users(task, goals, goals_users, string_format=False)
        if not users_prepared:
            return {}

        _compute_stats_per_segment(
            task, goals_users, stats_path=goal_stats_table, compute_counts=False
        )

        stats_by_goal = {}
        for stats_record in UserDataStats.records(task.yt.read_table(goal_stats_table)):
            stats_by_goal[str(stats_record[UserDataStats.Fields.GROUP_ID])] = _prepare_stats(stats_record)
        return stats_by_goal


def compute_segment_counts(task, users, counts_path, sync=True):
    """Count users per segment id into ``counts_path`` with a map-reduce.

    Only the segment id column of ``users`` is read. ``TCountMapper`` is used
    when ``users`` is sorted by segment id first, ``TEmptyMapper`` otherwise;
    ``TCountReducer`` acts as both combiner and reducer.

    :param sync: forwarded to the operation; when False it runs asynchronously.
    :return: whatever ``native_map_reduce_with_combiner`` returns (callers use
        it as an operation id).
    """
    sort_columns = task.yt.get_attribute(users, "sorted_by", [""])
    sorted_by_segment = str(sort_columns[0]) == GeoBatch.Fields.SEGMENT_ID
    chosen_mapper = TCountMapper if sorted_by_segment else TEmptyMapper

    return task.native_map_reduce_with_combiner(
        mapper_name=chosen_mapper,
        reducer_name=TCountReducer,
        combiner_name=TCountReducer,
        source=users + "{{{}}}".format(GeoBatch.Fields.SEGMENT_ID),
        destination=counts_path,
        reduce_by=[GeoBatch.Fields.SEGMENT_ID],
        sync=sync
    )


def join_stats_with_counts(task, stats_path, counts_path, batch_meta=None):
    """Merge per-segment counts into the stats table at ``stats_path`` in place.

    Rows of ``counts_path`` (plus the segment id column of ``batch_meta`` when
    given) are converted by ``TTransformCountsToStats`` into a temporary
    table, which is then reduced together with ``stats_path`` by group id via
    ``TMergeStatsWithCounts``; the result overwrites ``stats_path``.

    :param batch_meta: optional batch meta table whose segment ids are fed to
        the transform alongside the counts.
    """
    with task.yt.TempTable(prefix='stats_counts_') as stats_counts:
        source = [counts_path]
        if batch_meta:
            # Previously this list was built but never passed to the map
            # (``source=counts_path`` was used), so ``batch_meta`` had no effect.
            # NOTE(review): presumably this makes segments without counted users
            # still get a stats row; confirm TTransformCountsToStats accepts
            # meta rows that carry only the segment id column.
            source.append(batch_meta + "{{{}}}".format(GeoBatch.Fields.SEGMENT_ID))

        task.native_map(
            mapper_name=TTransformCountsToStats,
            source=source,
            destination=stats_counts
        )

        task.native_map_reduce(
            mapper_name=TEmptyMapper,
            reducer_name=TMergeStatsWithCounts,
            source=[stats_path, stats_counts],
            destination=stats_path,
            reduce_by=[UserDataStats.Fields.GROUP_ID],
        )


def _compute_stats_per_segment(task, users, stats_path, compute_counts=True, batch_meta=None):
    """Compute aggregated user-data stats per segment of ``users`` into ``stats_path``.

    ``users`` is sorted by user id, join-reduced with the pruned userdata
    (``TMatchWithUserData``), and the matches are split and merged into
    per-group stats (``TSplitUserDataToSegmentStats`` / ``TMergeUserDataStats``).

    :param compute_counts: when True, per-segment counts are computed by a
        map-reduce launched in parallel with the join and then merged into
        ``stats_path`` by ``join_stats_with_counts``.
    :param batch_meta: forwarded to ``join_stats_with_counts``.
    """
    userdata_path = conf.paths.audience.pruned_userdata

    with task.yt.TempTable(prefix='sorted_users_') as sorted_users, \
            task.yt.TempTable(prefix='matched_users_') as matched_users, \
            task.yt.TempTable(prefix='counts_') as counts:

        counts_tracker = None
        if compute_counts:
            # Started asynchronously so it overlaps with the join below.
            counts_tracker = OperationsTracker()
            counts_tracker.add_by_id(
                compute_segment_counts(task, users, counts, sync=False))

        try:
            task.sort(
                source=users,
                destination=sorted_users,
                sort_by=GeoBatch.Fields.USER_ID,
            )

            # The larger of the two tables is passed as the foreign input.
            source = [userdata_path, _foreign(sorted_users)]
            if task.yt.row_count(userdata_path) > task.yt.row_count(sorted_users):
                source = [_foreign(userdata_path), sorted_users]

            task.native_join_reduce(
                TMatchWithUserData,
                source=source,
                destination=matched_users,
                join_by=GeoBatch.Fields.USER_ID,
                spec=_spec(
                    YtSpecs.JOIN_REDUCE_HEAVY_JOBS,
                    YtSpecs.REDUCER_BIG_ROWS
                ),
            )

            options = user_data_stats_pb2.TUserDataStatsOptions(
                Flags=user_data_stats_pb2.TUserDataStatsOptions.TFlags(
                    IgnoreAffinities=True,
                    IgnoreUnusedStrataSegments=True
                )
            )
            strata_segments = options.UsedStrataSegments
            strata_segments.Segment.extend(
                Output.CRYPTA_INTERESTS + Output.CRYPTA_SEGMENTS)
            mapper_state = options.SerializeToString()
            task.native_map_reduce_with_combiner(
                mapper_name=TSplitUserDataToSegmentStats,
                combiner_name=TMergeUserDataStats,
                reducer_name=TMergeUserDataStats,
                source=matched_users,
                destination=stats_path,
                reduce_by=UserData.Fields.GROUP_ID,
                spec=_spec(
                    YtSpecs.SMALL_DATA_SIZE_PER_SORT_JOB,
                    YtSpecs.VERY_SMALL_DATA_SIZE_PER_MAP_JOB,
                    YtSpecs.REDUCER_HIGH_MEMORY_USAGE,
                    YtSpecs.MAPPER_HIGH_MEMORY_USAGE,
                    YtSpecs.REDUCER_BIG_ROWS,
                    YtSpecs.MAPPER_BIG_ROWS
                ),
                mapper_state=mapper_state
            )
        except BaseException:
            # Do not leave the asynchronous counts operation running against
            # the ``counts`` temp table that is removed when this block exits.
            if counts_tracker is not None:
                logger.warning("Segment stats computation failed, aborting segment counts operation")
                counts_tracker.abort_all()
            raise

        if counts_tracker is not None:
            counts_tracker.wait_all()
            join_stats_with_counts(task, stats_path, counts, batch_meta=batch_meta)


class ProcessBigBatch(YtTask, IndependentTask):
    """Process a single big input batch directory.

    Exports the batch's qualifying geo segments to the storage queue, computes
    per-segment stats and goal similarities, writes the output meta table and
    finally sets ``InputBatch.Attributes.IN_PROCESSING`` on the batch.
    """

    batch = Parameter()

    def get_segments_info_mapper(self):
        """Return a YT mapper yielding the segment-info dict of each meta row."""
        def mapper(record):
            yield record["meta"].get(Input.Attributes.CRYPTA_SEGMENT_INFO, {})
        return mapper

    def send_geo_segments_to_storage(self, users_path, meta_path, sync=True):
        """Export geo segments of the batch to the storage queue.

        A segment is selected when its segment info has ``segment_type == 'geo'``
        and either ``geo_segment_type == 'condition'`` or one of the non-Direct
        retargeting flags equal to ``"1"``. ``TExtractGeoStorageOutput`` is run
        over ``users_path`` with the selected ids in its state, writing to
        ``conf.paths.storage.queue/<batch name>``; the segment infos of all meta
        rows go to ``conf.paths.storage.queue_segments_info/<batch name>``.

        :param sync: forwarded to ``native_map``.
        :return: the ``native_map`` result (used by ``run`` as an operation id).
        """
        non_direct_retargeting_flags = (
            'adfox_retargeting', 'banana_retargeting',
            'display_retargeting', 'geoadv_retargeting', 'zen_retargeting')

        storage_state = TExportGeoToStorageState()
        storage_state.Timestamp = int(time())
        output_segments = []
        for record in self.yt.read_table(meta_path):
            segment_id = int(record[GeoBatch.Fields.SEGMENT_ID])
            segment_info = record['meta'].get(Input.Attributes.CRYPTA_SEGMENT_INFO, {})
            is_geo = segment_info.get('segment_type') == 'geo'
            is_geo_condition = segment_info.get('geo_segment_type') == 'condition'
            has_non_direct_retargeting = any(
                segment_info.get(flag) == "1" for flag in non_direct_retargeting_flags)
            if is_geo and (is_geo_condition or has_non_direct_retargeting):
                output_segments.append(segment_id)
        storage_state.OutputSegments.SegmentID.extend(output_segments)

        storage_output = os.path.join(conf.paths.storage.queue, os.path.basename(self.batch))
        storage_meta_output = os.path.join(
            conf.paths.storage.queue_segments_info,
            os.path.basename(self.batch))
        self.yt.create('table', storage_output, recursive=True,
                       attributes={Attributes.SCHEMA: StorageOutput.SCHEMA})

        operation_id = self.native_map(
            TExtractGeoStorageOutput,
            users_path,
            storage_output,
            state=storage_state.SerializeToString(),
            sync=sync
        )

        self.yt.create('table', storage_meta_output, recursive=True)
        self.map(self.get_segments_info_mapper(), meta_path, storage_meta_output)

        return operation_id

    def prepare_output_meta_table(self):
        """Ensure the output batch node exists, drop any old meta table, return its path."""
        output_batch = self.ypath_join(conf.paths.audience.output_batch,
                                       os.path.basename(self.batch))
        output_meta = OutputBatch.meta(output_batch)
        self.yt.create('map_node', output_batch, ignore_existing=True,
                       recursive=True)
        if self.yt.exists(output_meta):
            self.yt.remove(output_meta, recursive=True)

        return output_meta

    def compute_stats_per_goal(self, meta_path, stats_path=None):
        """Collect related goals of the batch segments and compute stats per goal.

        :return: ``(list_goals_by_segment, stats_per_goal)`` where the first maps
            ``int`` segment id to its related goal ids (segments without goals
            are omitted) and the second is ``_compute_stats_per_goal``'s dict
            keyed by ``str`` goal id.
        """
        all_related_goals = set()
        list_goals_by_segment = dict()
        for record in self.yt.read_table(meta_path):
            meta = record['meta']
            related_goals = list(meta.get(Input.Attributes.CRYPTA_RELATED_GOALS, []))
            all_related_goals.update(related_goals)
            if related_goals:
                # Keyed by int so that ``collect_stats`` (which uses
                # ``int(group_id)``) finds it regardless of the stored id type.
                list_goals_by_segment[int(record[GeoBatch.Fields.SEGMENT_ID])] = related_goals

        logger.info('Goals count %s', len(all_related_goals))
        stats_per_goal = _compute_stats_per_goal(self, all_related_goals, stats_path=stats_path)

        return list_goals_by_segment, stats_per_goal

    def collect_stats(self, stats_path, goals):
        """Yield one ``OutputBatch`` record per stats row of ``stats_path``.

        As a side effect, writes segment properties, segment-goal relations,
        segment state and raw stats through experiment-mode storage inserters.

        :param goals: the ``(list_goals_by_segment, stats_per_goal)`` pair from
            ``compute_stats_per_goal``.
        """
        list_goals_by_segment, stats_per_goal = goals
        storage_client = self._init_yt()
        with SegmentPropertiesStorage(storage_client, experiment_mode=True).batched_inserter as properties_inserter, \
             SegmentsGoalsRelationStorage(storage_client, experiment_mode=True).batched_inserter as relation_inserter, \
             RegularSegmentStateStorage(storage_client, experiment_mode=True).batched_inserter as state_inserter, \
             StatsStorage(storage_client, experiment_mode=True).batched_inserter as stats_storage:

            ts = int(time())
            output_ts = datetime.datetime.now()

            # The first row of the pruned userdata stats serves as global stats.
            userdata_stats_path = conf.paths.audience.pruned_userdata_stats
            global_stats = _prepare_stats(next(self.yt.read_table(userdata_stats_path)))
            communality_quantiles = properties_inserter.storage.get_communality_quantiles()

            segment_type = 'audience_big'

            for record in self.yt.read_table(stats_path):
                this_segment_stats = _prepare_stats(record)
                group_id = record.get(UserDataStats.Fields.GROUP_ID, "")
                if not group_id:
                    logger.info("Segment without GroupID")
                    continue
                segment_id = int(group_id)
                row_count = this_segment_stats.Counts.UniqIdValue

                if not this_segment_stats.Counts.WithData:
                    logger.warning('Segment %s has no stats', segment_id)
                    output_stats = {Output.Attributes.ERROR: "Failed to compute stats"}
                    output_stats_record = OutputBatch.record(segment_id, output_stats)
                else:
                    trace = _raw_trace(this_segment_stats)
                    audience_id = segment_id

                    write_segment_properties(
                        segment_id=segment_id,
                        communality=trace,
                        segment_type=segment_type,
                        inserter=properties_inserter,
                    )

                    # ``stats_per_goal`` is keyed by ``str`` goal id, so look it
                    # up with the same string form used for the result keys;
                    # a raw (e.g. int) goal id would always fall back to
                    # global stats.
                    raw_similarities = {}
                    for goal_id in list_goals_by_segment.get(segment_id, []):
                        goal_key = str(goal_id)
                        raw_similarities[goal_key] = _similarity(
                            this_segment_stats, stats_per_goal.get(goal_key, global_stats))
                    for goal_id, similarity in raw_similarities.items():
                        write_segment_goal_relation(
                            segment_id=audience_id,
                            goal_id=goal_id,
                            cosine=similarity,
                            inserter=relation_inserter
                        )
                    similarity_scores = {
                        goal: _cosine_quantile_score(similarity)
                        for (goal, similarity) in raw_similarities.items()
                    }

                    output_stats = _collect_output_stats(
                        audience_id,
                        segment_type,
                        this_segment_stats,
                        1.0,
                        global_stats,
                        similarity_scores,
                        communality_quantiles,
                        trace=trace
                    )
                    output_stats_record = OutputBatch.record(segment_id, output_stats)

                stats_storage.insert_row(
                    segment_id,
                    segment_type,
                    this_segment_stats.SerializeToString(),
                    ts
                )

                write_segment_on_output(
                    segment_id=segment_id,
                    row_count=row_count,
                    time=output_ts,
                    state_inserter=state_inserter,
                    segment_type=segment_type
                )

                yield output_stats_record

    def write_stats(self, stats_path, stats_per_goal):
        """Write ``collect_stats`` output into a freshly prepared output meta table.

        :param stats_per_goal: the ``(list_goals_by_segment, stats_per_goal)``
            pair from ``compute_stats_per_goal``.
        """
        self.yt.write_table(
            self.prepare_output_meta_table(),
            self.collect_stats(stats_path, stats_per_goal)
        )

    def run(self, **kwargs):
        batch = self.batch
        # Check existence before locking: locking a missing node fails in YT,
        # so a check placed after the lock could never report a vanished batch.
        if not self.yt.exists(batch):
            logger.warning('Batch %s disappeared', batch)
            return
        self.yt.lock(batch, mode="exclusive")
        if self.yt.get_attribute(batch, InputBatch.Attributes.IN_PROCESSING, False):
            logger.warning('Batch %s is computed', batch)
            return

        batch_meta, batch_users = InputBatch.meta_users(batch)

        # The storage export runs asynchronously while stats are computed.
        send_to_storage_op_tracker = OperationsTracker()
        op = self.send_geo_segments_to_storage(batch_users, batch_meta, sync=False)
        send_to_storage_op_tracker.add_by_id(op)

        with self.yt.TempTable(prefix='stats_') as stats_path:
            _compute_stats_per_segment(
                self,
                batch_users, stats_path, batch_meta=batch_meta
            )
            goals = self.compute_stats_per_goal(batch_meta)
            self.write_stats(stats_path, goals)

        send_to_storage_op_tracker.wait_all()

        self.yt.set_attribute(batch, InputBatch.Attributes.IN_PROCESSING, True)


class EnqueueBigBatches(YtTask, IndependentTask):
    """Yield a ``ProcessBigBatch`` task per batch listed under the input batch path.

    Batches whose ``IN_PROCESSING`` attribute is truthy are passed to the
    lister as ``skip_filter`` (presumably excluded from the listing).
    """

    def run(self, **kwargs):
        input_root = conf.paths.audience.input_batch

        if not self.yt.exists(input_root):
            logger.warning('Path %s does not exist', input_root)
            return

        def is_marked_processed(batch_path):
            return self.yt.get_attribute(batch_path, InputBatch.Attributes.IN_PROCESSING, False)

        prioritized = _get_list_batches_with_priorities(
            self,
            input_root,
            dir_only=True,
            skip_filter=is_marked_processed,
            add_ts_priority_to_out=True,
        )

        # Each item is (priority, ts_priority, batch); ts_priority is not used here.
        for batch_priority, _, batch_path in prioritized:
            yield ProcessBigBatch(batch=batch_path, priority=batch_priority)
