#!/usr/bin/python -tt
# -*- coding: utf-8 -*-

import base64
from os.path import join
from collections import defaultdict

from ads.bsyeti.libs.log_protos import crypta_profile_pb2
from library.python.protobuf.json import proto2json
import luigi
import yt.yson as yson
import yt.wrapper as yt

from crypta.profile.lib import date_helpers

from crypta.profile.utils.api import segments
from crypta.profile.utils.config import config
from crypta.profile.utils.loggers import TimeTracker
from crypta.profile.utils.luigi_utils import YtTarget, BaseYtTask, ExternalInput, BaseTimestampYtTask, OldNodesByNameCleaner

from crypta.profile.runners.export_profiles.lib.export.add_vectors_to_daily_export import AddVectorsToDailyExport
from crypta.profile.runners.export_profiles.lib.export.get_daily_export_and_process_bb_storage import GetDailyExportAndProcessBbStorage


# Maps a profile record field name to a (keyword id, value type) pair.
# The keyword id is written into the exported item; the value type string
# selects the encoding branch in add_item().
# Note: 'probabilistic_segments' and 'interests_composite' share keyword 217;
# GenerateTablesForExport merges them into a single item.
field_name_to_bb_keyword_id = {
    'age_segments': (175, 'weighted_uint_values'),
    'top_common_site_ids': (198, 'uint_values'),
    'heuristic_segments': (216, 'pair_values'),
    'probabilistic_segments': (217, 'weighted_pair_values'),
    'interests_composite': (217, 'weighted_pair_values'),
    'yandex_loyalty': (220, 'uint_values'),
    'marketing_segments': (281, 'weighted_uint_values'),
    'lal_common': (544, 'weighted_uint_values'),
    'lal_private': (545, 'weighted_uint_values'),
    'lal_internal': (546, 'weighted_uint_values'),
    'heuristic_common': (547, 'uint_values'),
    'heuristic_private': (548, 'uint_values'),
    'heuristic_internal': (549, 'uint_values'),
    'exact_socdem': (569, 'pair_values'),
    'affinitive_site_ids': (595, 'pair_values'),
    'packed_vector': (596, 'base64_value'),
    'longterm_interests': (601, 'uint_values'),
    'shortterm_interests': (602, 'uint_values'),
    'gender': (877, 'weighted_uint_values'),
    'user_age_6s': (878, 'weighted_uint_values'),
    'income_segments': (879, 'weighted_uint_values'),
    'income_5_segments': (880, 'weighted_uint_values'),
    'exact_gender': (885, 'uint_values'),
    'exact_age_segment': (886, 'uint_values'),
    'exact_income_3_segment': (887, 'uint_values'),
    'exact_income_5_segment': (888, 'uint_values'),
}


# For keywords whose values are symbolic names in the profile record, maps
# keyword id -> {symbolic name: numeric code written to the export}.
# add_item() uses it to translate 'weighted_uint_values' segment ids, and
# GenerateTablesForExport uses it to translate 'exact_socdem' values.
named_values = {
    877: {'m': 0, 'f': 1},
    878: {'0_17': 0, '18_24': 1, '25_34': 2, '35_44': 3, '45_54': 4, '55_99': 5},
    879: {'A': 0, 'B': 1, 'C': 2},
    880: {'A': 0, 'B1': 1, 'B2': 2, 'C1': 3, 'C2': 4},

    # Keyword 175 has five age buckets (the last one is '45_99'), unlike 878/886.
    175: {'0_17': 0, '18_24': 1, '25_34': 2, '35_44': 3, '45_99': 4},

    # Exact (non-weighted) counterparts of 877-880.
    885: {'m': 0, 'f': 1},
    886: {'0_17': 0, '18_24': 1, '25_34': 2, '35_44': 3, '45_54': 4, '55_99': 5},
    887: {'A': 0, 'B': 1, 'C': 2},
    888: {'A': 0, 'B1': 1, 'B2': 2, 'C1': 3, 'C2': 4},
}

# Maps a key inside record['exact_socdem'] to the field name under which the
# corresponding keyword is registered in field_name_to_bb_keyword_id.
exact_socdem_to_keyword_name_dict = {
    'gender': 'exact_gender',
    'age_segment': 'exact_age_segment',
    'income_segment': 'exact_income_3_segment',
    'income_5_segment': 'exact_income_5_segment',
}

# NOTE(review): SANDBOX_RESOURCE is not referenced anywhere in the visible part
# of this file -- confirm whether it is still needed by importers.
SANDBOX_RESOURCE = 'https://proxy.sandbox.yandex-team.ru/last/BIGB_AB_EXPERIMENTS_PRODUCTION_CONFIG'
# Keyword ids that are kept in records written to the logbroker output table
# (see GenerateTablesForExport.output); all other items are dropped there.
LOGBROKER_ITEMS = {885, 886, 887, 888, 601}


def is_exported_segment(keyword_id, segment_id, not_exported_segments):
    """Tell whether ``segment_id`` of ``keyword_id`` may be exported.

    CRYPTA-2248: filter not exported segments for keywords
    544, 545, 546, 547, 548, 549, 557, 601, 602.

    :param keyword_id: keyword the segment belongs to.
    :param segment_id: segment id; anything accepted by ``int()``.
    :param not_exported_segments: mapping ``keyword_id -> container of int
        segment ids`` that must not be exported.
    :return: ``True`` if the segment is not listed for the keyword.
    """
    # Use .get() instead of indexing: a plain dict without the keyword no
    # longer raises KeyError, and a defaultdict is not silently populated
    # with empty entries on every lookup.
    return int(segment_id) not in not_exported_segments.get(keyword_id, ())


def add_item(proto, value_type, keyword_id, field_value, segments_not_for_export=None):
    """Append one item for ``keyword_id`` to ``proto.items`` and fill it from ``field_value``.

    The encoding is chosen by ``value_type``:

    * ``'weighted_uint_values'`` -- a dict or a list of ``(segment_id, weight)``
      pairs is written as (id, weight * 1e6) entries; symbolic ids are
      translated through ``named_values`` when the keyword is listed there.
      Any other value is treated as a single number ``p`` and written as two
      entries: ``0 -> 1 - p`` and ``1 -> p``.
    * ``'pair_values'`` -- a dict written as (first, second) pairs; for
      keyword 595 the second component is scaled by 1e6 and nothing is
      filtered, otherwise not exported segments are skipped.
    * ``'uint_values'`` -- ``None`` sets ``DeleteFlag``; a float is scaled by
      1e6; an integer becomes a one-element list; a list is used as is.
      Not exported segments are removed.
    * ``'weighted_pair_values'`` -- a dict of dicts written as
      (supersegment, subsegment, weight * 1e6) triples.
    * ``'base64_value'`` -- the value is base64-encoded.

    If the item ends up with nothing but ``keyword_id`` set, ``DeleteFlag``
    is set on it.

    :param proto: message with a repeated ``items`` field to append to.
    :param value_type: one of the encoding names listed above.
    :param keyword_id: keyword id of the new item.
    :param field_value: raw value from the profile record.
    :param segments_not_for_export: optional mapping
        ``keyword_id -> set of segment ids`` to filter out; a falsy value
        means nothing is filtered.
    """
    item = proto.items.add(keyword_id=keyword_id)

    not_exported_segments = segments_not_for_export or defaultdict(set)

    if value_type == 'weighted_uint_values':
        if isinstance(field_value, dict):
            field_value = list(field_value.items())
        if isinstance(field_value, list):
            for segment_id, segment_value in field_value:
                if keyword_id in named_values:
                    # Translate a symbolic name (e.g. 'm', '18_24') to its numeric code.
                    segment_id = named_values[keyword_id][segment_id]

                if is_exported_segment(keyword_id, segment_id, not_exported_segments):
                    weighted_uint_values = item.weighted_uint_values.add()
                    weighted_uint_values.first = int(segment_id)
                    weighted_uint_values.weight = int(segment_value * 1000000)
        else:
            # Scalar value: emit the complementary pair 0 -> (1 - p), 1 -> p.
            residual_value = 1.0 - field_value

            weighted_uint_values = item.weighted_uint_values.add()
            weighted_uint_values.first = 0
            weighted_uint_values.weight = int(residual_value * 1000000)

            weighted_uint_values = item.weighted_uint_values.add()
            weighted_uint_values.first = 1
            weighted_uint_values.weight = int(field_value * 1000000)

    elif value_type == 'pair_values':
        if keyword_id == 595:
            for site_id, affinity in field_value.iteritems():
                pair = item.pair_values.add()
                pair.first = int(site_id)
                pair.second = int(affinity * 1000000)
        else:
            for segment_id, segment_value in field_value.iteritems():
                if is_exported_segment(keyword_id, segment_id, not_exported_segments):
                    pair = item.pair_values.add()
                    pair.first = int(segment_id)
                    pair.second = int(segment_value)

    elif value_type == 'uint_values':
        if field_value is None:
            item.DeleteFlag = True
        else:
            if isinstance(field_value, float):
                field_value = [int(field_value * 1000000)]
            elif isinstance(field_value, (int, yson.YsonUint64, yson.YsonInt64)):
                # Always normalise a scalar to a list. Previously a scalar that
                # was NOT exported stayed a bare integer and the code below
                # failed with TypeError when iterating it / taking len() of it.
                if is_exported_segment(keyword_id, field_value, not_exported_segments):
                    field_value = [field_value]
                else:
                    field_value = []

            if keyword_id in not_exported_segments:
                field_value = [x for x in field_value if x not in not_exported_segments[keyword_id]]

            if len(field_value) > 0:
                item.uint_values[:] = field_value

    elif value_type == 'weighted_pair_values':
        for supersegment_id, segment_value in field_value.iteritems():
            for subsegment_id, subsegment_value in segment_value.iteritems():
                value = item.weighted_pair_values.add()
                value.first = int(supersegment_id)
                value.second = int(subsegment_id)
                value.weight = int(subsegment_value * 1000000)

    elif value_type == 'base64_value':
        item.base64_value = base64.b64encode(field_value)

    # Nothing but keyword_id was set (e.g. every segment was filtered out):
    # mark the item for deletion instead of exporting an empty value.
    item_fields = [descriptor.name for descriptor, _ in item.ListFields()]
    if len(item_fields) == 1 and item_fields[0] == 'keyword_id':
        item.DeleteFlag = True


def is_export_value_valid(bb_export_entry):
    """Return ``True`` when the export entry carries at least one item."""
    item_count = len(bb_export_entry.items)
    return item_count != 0


class GenerateTablesForExport(object):
    """Mapper that converts profile records into serialized TCryptaLog rows.

    It is passed to ``yt.run_map`` in create_export(). Every produced record
    is written to output table 0 as a binary protobuf; when
    ``output_to_logbroker`` is set, a reduced JSON copy is also written to
    output table 1.
    """

    def __init__(self, timestamp, segments_not_for_export, trainable_segments_ids, trainable_segments_priority, output_to_logbroker):
        """
        :param timestamp: timestamp for produced records; if falsy, the
            record's own ``update_time`` is used instead.
        :param segments_not_for_export: mapping keyword id -> segment ids
            that must be filtered out.
        :param trainable_segments_ids: container of segment ids that are
            additionally exported under keyword 1084.
        :param trainable_segments_priority: mapping segment id -> priority
            used to order keyword 1084 values.
        :param output_to_logbroker: whether the second (JSON) output is produced.
        """
        self.timestamp = timestamp
        self.not_exported_segments = segments_not_for_export
        self.trainable_segments_ids = trainable_segments_ids
        self.trainable_segments_priority = trainable_segments_priority
        self.output_to_logbroker = output_to_logbroker

    def start(self):
        # The converter is created here rather than in __init__.
        # NOTE(review): presumably this is a hook invoked by the YT job
        # runner before rows are processed -- confirm.
        self.converter = proto2json.Proto2JsonConverter(crypta_profile_pb2.TCryptaLog)

    def __call__(self, record):
        """Build export records for one input row and return a generator of output rows."""
        timestamp = self.timestamp or record['update_time']
        # Rows are keyed either by 'yandexuid' or by 'crypta_id'.
        user_id = record['yandexuid'] if 'yandexuid' in record else record['crypta_id']
        bb_export_entry = crypta_profile_pb2.TCryptaLog(yandex_id=user_id, icookie=record.get('icookie'), timestamp=timestamp)
        output_records = []

        # Several fields map to keyword 217; their values are accumulated
        # here and written as a single item after the loop.
        keyword_217 = {}
        for field_name, field_value in record.iteritems():
            if field_name not in field_name_to_bb_keyword_id:
                continue

            keyword_id, value_type = field_name_to_bb_keyword_id[field_name]

            if record.get('fields_to_delete') and field_name in record['fields_to_delete']:
                if keyword_id == 569:
                    # Deleting 'exact_socdem' means deleting each of the
                    # exact_* keywords it is expanded into.
                    for exact_socdem_field_name, keyword_name in exact_socdem_to_keyword_name_dict.iteritems():
                        keyword_id, value_type = field_name_to_bb_keyword_id[keyword_name]
                        bb_export_entry.items.add(
                            keyword_id=keyword_id,
                            DeleteFlag=True,
                        )
                else:
                    bb_export_entry.items.add(
                        keyword_id=keyword_id,
                        DeleteFlag=True,
                    )

            elif field_value is not None:
                if keyword_id == 217:
                    keyword_217.update(field_value)
                elif keyword_id == 602:
                    if isinstance(field_value, list):
                        # List form: one item per interest inside the main entry.
                        for shortterm_interest_id in field_value:
                            if is_exported_segment(keyword_id, shortterm_interest_id, self.not_exported_segments):
                                bb_export_entry.items.add(
                                    keyword_id=602,
                                    uint_values=[shortterm_interest_id],
                                )
                    elif isinstance(field_value, dict):
                        # Dict form (interest id -> timestamp): each interest
                        # becomes a separate record with its own timestamp.
                        for shortterm_interest_id, shortterm_interest_timestamp in field_value.iteritems():
                            if is_exported_segment(keyword_id, shortterm_interest_id, self.not_exported_segments):
                                output_record = crypta_profile_pb2.TCryptaLog(
                                    yandex_id=user_id,
                                    timestamp=shortterm_interest_timestamp,
                                    items=[crypta_profile_pb2.TCryptaLog.TItem(keyword_id=602, uint_values=[int(shortterm_interest_id)])],
                                )
                                output_records.append(output_record)

                elif keyword_id == 569:
                    # 'exact_socdem' is not exported as keyword 569 itself: it
                    # is expanded into the exact_* keywords, with symbolic
                    # values translated through named_values. A missing or
                    # unknown value yields None, which add_item() turns into
                    # a DeleteFlag item.
                    for exact_socdem_field_name, keyword_name in exact_socdem_to_keyword_name_dict.iteritems():
                        field_value = record['exact_socdem'].get(exact_socdem_field_name)
                        keyword_id, value_type = field_name_to_bb_keyword_id[keyword_name]
                        field_value = named_values[keyword_id].get(field_value)
                        add_item(bb_export_entry, value_type, keyword_id, field_value)

                elif keyword_id == 546:
                    add_item(bb_export_entry, value_type, keyword_id, field_value, self.not_exported_segments)
                    # Additionally export the trainable subset of these
                    # segments under keyword 1084, highest priority first.
                    filtered_segments = {}
                    for segment_id, probability in field_value.items():
                        if segment_id in self.trainable_segments_ids:
                            filtered_segments[segment_id] = probability
                    if len(filtered_segments) > 0:
                        filtered_segments = list(sorted(
                            filtered_segments.items(),
                            key=lambda segment_pair: self.trainable_segments_priority.get(segment_pair[0], 0),
                            reverse=True,
                        ))
                        add_item(bb_export_entry, value_type, 1084, filtered_segments, self.not_exported_segments)

                else:
                    add_item(bb_export_entry, value_type, keyword_id, field_value, self.not_exported_segments)

        if keyword_217:
            add_item(bb_export_entry, 'weighted_pair_values', 217, keyword_217)

        # The main entry is emitted only if it has at least one item.
        if is_export_value_valid(bb_export_entry):
            output_records.append(bb_export_entry)

        return self.output(output_records)

    def output(self, output_records):
        """Yield output rows (preceded by table switches) for the given records."""
        for record in output_records:
            yield yt.create_table_switch(0)
            yield {'value': record.SerializeToString()}

            if self.output_to_logbroker:
                # The record has already been serialized for table 0, so it
                # is safe to strip it in place down to LOGBROKER_ITEMS.
                items = [item for item in record.items if item.keyword_id in LOGBROKER_ITEMS]
                del record.items[:]
                record.items.extend(items)

                if is_export_value_valid(record):
                    yield yt.create_table_switch(1)
                    yield {'value': self.converter.convert(record)}


def create_empty_table(yt, path):
    """Create an empty table at ``path`` with a single string column ``value``.

    :param yt: client object exposing ``create_empty_table(path, schema=...)``.
    :param path: destination table path.
    """
    single_value_schema = {'value': 'string'}
    yt.create_empty_table(path, schema=single_value_schema)


def get_trainable_segments_priorities(yt_client):
    """Read the trainable segments priorities table and rank its rows.

    Rows are ordered ascending by ``(priority, custom_priority,
    partners_count)``; the zero-based position of a row in that order becomes
    the value for its ``segment_id``.

    :param yt_client: client exposing ``read_table(path)``.
    :return: dict ``segment_id -> rank``.
    """
    def ordering_key(row):
        return row['priority'], row['custom_priority'], row['partners_count']

    ordered_rows = sorted(
        yt_client.read_table(config.TRAINABLE_SEGMENTS_PRIORITIES),
        key=ordering_key,
    )
    # If a segment_id occurs more than once, the later (higher) rank wins.
    return {row['segment_id']: rank for rank, row in enumerate(ordered_rows)}


def create_export(yt, source_table, collector_table, logbroker_table=None, timestamp=None):
    """
        Separate function for emergency_bb_update.py

        Recreates the destination tables and runs the GenerateTablesForExport
        mapper over ``source_table`` inside a single transaction.

        :param yt: YT client.
        :param source_table: input table with profile records.
        :param collector_table: output table 0 (serialized protobufs).
        :param logbroker_table: optional output table 1 (JSON records); when
            None the logbroker output is disabled.
        :param timestamp: timestamp passed to the mapper.
    """
    output_to_logbroker = logbroker_table is not None
    if output_to_logbroker:
        destination_tables = [collector_table, logbroker_table]
    else:
        destination_tables = [collector_table]

    # 4 GiB
    mapper_memory_limit = 4 * 1024 ** 3

    with yt.Transaction():
        for destination_table in destination_tables:
            create_empty_table(yt, destination_table)

        mapper = GenerateTablesForExport(
            timestamp=timestamp,
            segments_not_for_export=segments.get_not_exported_segments(),
            trainable_segments_ids=segments.get_trainable_segments(),
            trainable_segments_priority=get_trainable_segments_priorities(yt),
            output_to_logbroker=output_to_logbroker,
        )

        yt.run_map(
            mapper,
            source_table=source_table,
            destination_table=destination_tables,
            spec={'mapper': {'memory_limit': mapper_memory_limit}},
        )


class GetExportTables(BaseYtTask):
    """Builds the yandexuid collector and logbroker export tables for ``date``."""

    date = luigi.Parameter()
    priority = 100
    task_group = 'export_profiles'

    @property
    def collector_path(self):
        """Path of the collector table for this date."""
        table_name = 'yandexuid_{}'.format(self.date)
        return join(config.PROFILES_COLLECTOR_EXPORT_YT_DIRECTORY, table_name)

    def output(self):
        """The logbroker export table is the task's target."""
        logbroker_path = join(config.YANDEXUID_LOGBROKER_EXPORT_YT_DIRECTORY, self.date)
        return YtTarget(logbroker_path)

    def requires(self):
        """Depends on the daily export with vectors and on cleanup of old export tables."""
        vectors_task = AddVectorsToDailyExport(self.date)
        cleaner_task = OldNodesByNameCleaner(
            self.date,
            folder=config.YANDEXUID_LOGBROKER_EXPORT_YT_DIRECTORY,
            lifetime=config.NUMBER_OF_INTERMEDIATE_PROFILES_TABLES_TO_KEEP,
        )
        return {'Vectors': vectors_task, 'Cleaner': cleaner_task}

    def run(self):
        self.logger.info('Preparing export table to upload to BigB through LogBroker')
        tracker_name = self.__class__.__name__
        with TimeTracker(tracker_name):
            vectors_table = self.input()['Vectors'].table
            logbroker_table = self.output().table
            create_export(
                self.yt,
                source_table=vectors_table,
                collector_table=self.collector_path,
                logbroker_table=logbroker_table,
            )


class GetCryptaIdExportTables(BaseYtTask):
    """Builds the crypta_id collector and logbroker export tables for ``date``."""

    date = luigi.Parameter()
    priority = 90
    task_group = 'export_profiles'

    @property
    def collector_path(self):
        """Path of the collector table for this date."""
        table_name = 'crypta_id_{}'.format(self.date)
        return join(config.PROFILES_COLLECTOR_EXPORT_YT_DIRECTORY, table_name)

    def output(self):
        """The logbroker export table is the task's target."""
        logbroker_path = join(config.CRYPTA_ID_LOGBROKER_EXPORT_YT_DIRECTORY, str(self.date))
        return YtTarget(logbroker_path)

    def requires(self):
        """Depends on the crypta_id daily export and on cleanup of old export tables."""
        daily_export_task = GetDailyExportAndProcessBbStorage(self.date, 'crypta_id')
        cleaner_task = OldNodesByNameCleaner(
            self.date,
            folder=config.CRYPTA_ID_LOGBROKER_EXPORT_YT_DIRECTORY,
            lifetime=config.NUMBER_OF_INTERMEDIATE_PROFILES_TABLES_TO_KEEP,
        )
        return {'DailyExport': daily_export_task, 'Cleaner': cleaner_task}

    def run(self):
        self.logger.info('Preparing crypta_id export table to upload to BigB through LogBroker')
        tracker_name = self.__class__.__name__
        with TimeTracker(tracker_name):
            daily_export_table = self.input()['DailyExport']['daily_export'].table
            logbroker_table = self.output().table
            collector_table = self.collector_path
            # Records are stamped with the noon timestamp derived from the task date.
            noon_timestamp = date_helpers.from_utc_date_string_to_noon_timestamp(self.date)
            create_export(
                self.yt,
                source_table=daily_export_table,
                collector_table=collector_table,
                logbroker_table=logbroker_table,
                timestamp=noon_timestamp,
            )


class GetShorttermInterestsExportTables(BaseTimestampYtTask):
    """Exports the shortterm interests input table for ``timestamp`` and removes that input."""

    timestamp = luigi.Parameter()
    task_group = 'shortterm_interests_upload'

    def output(self):
        """Collector export table named after the timestamp."""
        collector_path = join(config.SHORTTERM_INTERESTS_COLLECTOR_EXPORT_YT_DIRECTORY, str(self.timestamp))
        return YtTarget(collector_path)

    def requires(self):
        """The input table is produced externally."""
        input_path = join(config.SHORTTERM_INTERESTS_INPUT_YT_DIRECTORY, self.timestamp)
        return ExternalInput(input_path)

    def run(self):
        tracker_name = self.__class__.__name__
        with TimeTracker(tracker_name):
            # The export and the removal of the input share one transaction.
            with self.yt.Transaction():
                input_table = self.input().table
                collector_table = self.output().table
                create_export(
                    self.yt,
                    source_table=input_table,
                    collector_table=collector_table,
                    timestamp=int(self.timestamp),
                )
                self.yt.remove(self.input().table)
