import collections
import datetime
import json
import multiprocessing
import os

import luigi

from crypta.lib.python.juggler.juggler_helpers import report_event_to_juggler
from crypta.profile.lib import (
    bb_helpers,
    date_helpers,
)
from crypta.profile.utils import (
    api,
    loggers,
    luigi_utils,
    utils,
    yt_utils,
)
from crypta.profile.utils.config import (
    config,
    secrets,
)


# Sampling parameters passed as sample_chunk_count / sample_chunk_size to
# utils.get_random_table_ranges_from_yt when picking rows to verify.
SAMPLE_CHUNK_COUNT = 100
SAMPLE_CHUNK_SIZE = 100

# Number of Worker processes the sampled table ranges are partitioned between.
NUMBER_OF_PROCESSES = 10


def convert_to_str_list(src):
    """Return a new list holding ``str(item)`` for every item of iterable *src*."""
    return list(map(str, src))


def convert_to_str_dict(src):
    """Return a new dict with the keys of mapping *src* converted via ``str``.

    Values are kept unchanged.  If two keys stringify to the same text
    (e.g. ``1`` and ``'1'``), the one iterated last wins.
    """
    return dict((str(key), value) for key, value in src.items())


def get_not_exported_segments():
    """Fetch segments that are not exported to BigB, grouped by keyword name.

    Calls ``lab.getNotExportedToBigB`` on the Crypta API and groups the
    returned segment ids (converted to ``str``) by the keyword name that
    ``bb_helpers.get_keyword_name`` derives from each export's keyword id and
    segment id.

    :return: ``collections.defaultdict(set)`` mapping keyword name -> set of
        segment id strings.
    """
    segments_by_keyword_name = collections.defaultdict(set)
    pending_exports = api.get_api().lab.getNotExportedToBigB().result()

    for item in pending_exports:
        name = bb_helpers.get_keyword_name(item.keywordId, item.segmentId)
        segments_by_keyword_name[name].add(str(item.segmentId))

    return segments_by_keyword_name


def is_equal(field_1, field_2, not_exported_ids_set=None):
    """Compare an exported field value with the value read back from BigB.

    ``field_1`` is the value from the export table (callers never pass None)
    and ``field_2`` is the corresponding value from the BigB profile.  IDs in
    ``not_exported_ids_set`` are removed from the list items / dict keys of
    ``field_1`` before comparing, since they are not expected in BigB.  IDs
    and keys are compared as strings, so ``1`` and ``'1'`` match.

    Both values must have exactly the same type.  The comparison rule is then
    chosen by the type of ``field_1``:

    * float: equal if the absolute difference is below 0.01;
    * list: equal as sets of stringified items;
    * dict: same stringified key sets; float values within 0.01, int values
      exactly equal, nested dict values with the same key set and float inner
      values within 0.01 (other value types are not compared);
    * anything else: ``==``.

    :param field_1: exported value.
    :param field_2: value from BigB, may be None.
    :param not_exported_ids_set: optional iterable of IDs to ignore in ``field_1``.
    :return: ``(equal, reason)``; ``reason`` is '' when equal, otherwise a
        human-readable description of the first mismatch found.
    """
    # Normalise to strings so int and str IDs are treated the same in both branches below.
    not_exported_ids = {str(x) for x in (not_exported_ids_set or [])}

    # field_1 can't be None
    if field_2 is None:
        return False, 'Field2 is None: {} {}'.format(field_1, field_2)

    if type(field_1) != type(field_2):
        return False, 'Types do not match: {} {}'.format(type(field_1), type(field_2))

    if isinstance(field_1, float):
        if abs(field_1 - field_2) < 0.01:
            return True, ''
        return False, 'Diff value: {} != {}'.format(field_1, field_2)

    if isinstance(field_1, list):
        list_1_values = {str(x) for x in field_1} - not_exported_ids
        list_2_values = {str(x) for x in field_2}

        if list_1_values == list_2_values:
            return True, ''
        return False, 'Diff values in lists: only in 1 = {} only in 2 ={}'.format(
            list_1_values - list_2_values,
            list_2_values - list_1_values,
        )

    if isinstance(field_1, dict):
        # Filter on the stringified key: the not-exported IDs are strings, so
        # comparing the raw (possibly int) key would never exclude anything.
        field_1_exported = {
            str(k): v for k, v in field_1.items() if str(k) not in not_exported_ids
        }
        field_2_str = {str(k): v for k, v in field_2.items()}

        field_1_keys = set(field_1_exported)
        field_2_keys = set(field_2_str)

        if field_1_keys != field_2_keys:
            return False, 'Diff keys: only in 1 = {} only in 2 ={}'.format(
                field_1_keys - field_2_keys,
                field_2_keys - field_1_keys,
            )

        for key, value in field_1_exported.items():
            other = field_2_str[key]
            if isinstance(value, float) and abs(value - other) >= 0.01:
                return False, 'Diff {} value: {} != {}'.format(key, value, other)
            elif isinstance(value, int) and value != other:
                return False, 'Diff {} value: {} != {}'.format(key, value, other)
            elif isinstance(value, dict):
                # Compare key *sets*: Py2 ``keys()`` lists are order-sensitive, and a
                # non-dict BigB value must be reported rather than crash on ``.keys()``.
                if not isinstance(other, dict) or set(value) != set(other):
                    return False, 'Diff inner keys {} value: {} != {}'.format(key, value, other)
                for inner_key, inner_value in value.items():
                    if isinstance(inner_value, float) and abs(inner_value - other[inner_key]) >= 0.01:
                        return False, 'Diff inner values {} value: {} != {}'.format(
                            inner_key, inner_value, other[inner_key],
                        )
        return True, ''

    if field_1 == field_2:
        return True, ''
    return False, 'Diff value: {} != {}'.format(field_1, field_2)


# Export-row field names that Worker compares with BigB: every field named in
# utils.bb_keyword_id_to_field_name except audience_segments and
# shortterm_interests, which are excluded from this check.
bb_exported_fields = set(utils.bb_keyword_id_to_field_name.values()).difference([
    'shortterm_interests',
    'audience_segments',
])

# BigB keyword ids whose update time check_update_times always compares with
# the export row's update_time.
keywords_to_check_update_time = {
    175, 198, 216, 217, 220, 281,
    544, 545, 546, 547, 548, 549,
    595, 601,
    877, 878, 879, 880, 885, 886, 887, 888,
}

# BigB keyword id -> key in the exported exact_socdem mapping.  The update time
# of these keywords is checked only when exact_socdem has a value for that key.
keyword_id_to_exact_socdem_field_name = {
    174: 'gender',
    543: 'age_segment',
    176: 'income_segment',
    614: 'income_5_segment',
}

# Export-row field name -> key in the exported exact_socdem mapping.  Worker
# also compares the plain (non-offline) BigB field when that key is set.
field_name_to_exact_socdem_field_name_dict = dict(
    gender='gender',
    user_age_6s='age_segment',
    income_segments='income_segment',
    income_5_segments='income_5_segment',
)


def check_update_times(update_times, exported_update_time, exact_socdem):
    """Check that BigB keyword update times agree with the export time.

    A keyword is checked if it is in ``keywords_to_check_update_time``, or if
    it is listed in ``keyword_id_to_exact_socdem_field_name`` and the mapped
    key has a truthy value in ``exact_socdem``.  A checked keyword is reported
    when it has more than one update time, or when its single update time
    differs from ``exported_update_time`` by more than 5.

    :param update_times: mapping keyword_id -> collection of update times; its
        length and first key are inspected and it is printed in the message.
        Presumably as returned by ``BigbClient.get_parsed_bb_profile`` — verify.
    :param exported_update_time: ``update_time`` of the exported row.
    :param exact_socdem: exported exact_socdem mapping (may be empty).
    :return: list of error messages, empty when everything matches.
    """
    template = 'Wrong update time for keyword_id {}: exported={}, actual={}'
    messages = []

    for keyword_id, times in update_times.iteritems():
        socdem_field = keyword_id_to_exact_socdem_field_name.get(keyword_id)
        should_check = keyword_id in keywords_to_check_update_time or (
            socdem_field is not None and exact_socdem.get(socdem_field)
        )
        if not should_check:
            continue

        mismatched = len(times) > 1 or abs(list(times)[0] - exported_update_time) > 5
        if mismatched:
            messages.append(template.format(keyword_id, exported_update_time, times))

    return messages


def check_value_correctness(yandexuid, field_name, field_value, bb_field_value, not_exported_segments=None):
    """Compare one exported value with its BigB value via ``is_equal``.

    :param yandexuid: yandexuid of the checked row (used only in the message).
    :param field_name: field name to print in the message.
    :param field_value: exported value.
    :param bb_field_value: value taken from the BigB profile.
    :param not_exported_segments: optional IDs to ignore, forwarded to ``is_equal``.
    :return: ``[]`` when the values match, otherwise ``[reason, details]`` where
        ``details`` contains both values dumped as sorted-key JSON.
    """
    equal, reason = is_equal(field_value, bb_field_value, not_exported_segments)
    if equal:
        return []

    details = '{} {} exported: {} actual: {}. Not exported segments: {}'.format(
        yandexuid,
        field_name,
        json.dumps(field_value, sort_keys=True),
        json.dumps(bb_field_value, sort_keys=True),
        not_exported_segments,
    )
    return [reason, details]


class Worker(object):
    """Callable used as a ``multiprocessing.Process`` target to verify BigB profiles.

    For every row of the given YT table ranges it fetches the yandexuid's
    profile from BigB, compares the exported fields with it and increments
    one of two shared counters depending on whether any difference was found.
    Rows whose yandexuid has no BigB profile are logged and not counted.
    """

    def __init__(self, logger, not_exported_segments_by_keyword_name):
        """
        :param logger: logger for progress and mismatch reports.
        :param not_exported_segments_by_keyword_name: mapping keyword name ->
            collection of segment ids that should be ignored when comparing.
        """
        # NOTE(review): ``tmv_id`` spelling kept as-is; presumably it matches the
        # BigbClient signature — verify before "fixing" it.
        self.yabs_client = bb_helpers.BigbClient(
            tmv_id=config.CRYPTA_PROFILE_TVM_ID,
            tvm_secret=secrets.get_secrets().get_secret('CRYPTA_PROFILE_TVM_SECRET'),
            logger=logger,
            n_retries=10,
        )
        self.logger = logger
        self.not_exported_segments_by_keyword_name = not_exported_segments_by_keyword_name

    def __call__(self, yt_table_ranges, yandexuids_without_difference_counter, yandexuids_with_difference_counter):
        """Check every row in ``yt_table_ranges`` and update the shared counters.

        :param yt_table_ranges: iterable of YT table paths/ranges to read.
        :param yandexuids_without_difference_counter: shared counter of consistent rows.
        :param yandexuids_with_difference_counter: shared counter of inconsistent rows.

        Both counters must be ``multiprocessing.Value`` objects created with the
        default lock (as ``BbConsistencyMonitoring.run`` does), because their
        ``get_lock()`` is used here.
        """
        self.yt = yt_utils.get_yt_client()

        for yt_table_range in yt_table_ranges:
            for row in self.yt.read_table(yt_table_range):
                error_messages = self._check_row(row)
                if error_messages is None:
                    continue

                if error_messages:
                    # ``get_lock()`` is shared by every process, unlike a freshly
                    # created ``multiprocessing.Lock()``: it makes ``+=`` (a read
                    # followed by a write) atomic and keeps this yandexuid's error
                    # lines together in the log.
                    with yandexuids_with_difference_counter.get_lock():
                        self.logger.error('yandexuid = {yandexuid}'.format(yandexuid=row['yandexuid']))
                        yandexuids_with_difference_counter.value += 1
                        for error_message in error_messages:
                            self.logger.error(error_message)
                else:
                    with yandexuids_without_difference_counter.get_lock():
                        yandexuids_without_difference_counter.value += 1

    def _check_row(self, row):
        """Return mismatch messages for ``row``, or None if BigB has no profile for it."""
        yandexuid = row['yandexuid']
        bb_profile, update_times = self.yabs_client.get_parsed_bb_profile(yandexuid, 'yandexuid', glue=False)

        if not bb_profile:
            self.logger.info("Profile not found for %s", yandexuid)
            return None

        exported_exact_socdem = row['exact_socdem'] or {}

        error_messages = list(check_update_times(
            update_times,
            row['update_time'],
            exported_exact_socdem,
        ))

        for field_name, field_value in row.iteritems():
            if field_name not in bb_exported_fields or field_value is None:
                continue
            error_messages.extend(
                self._check_field(yandexuid, field_name, field_value, bb_profile, exported_exact_socdem)
            )

        return error_messages

    def _check_field(self, yandexuid, field_name, field_value, bb_profile, exported_exact_socdem):
        """Compare one exported field with its BigB counterpart(s); return mismatch messages."""
        if field_name == 'exact_socdem':
            errors = check_value_correctness(
                yandexuid,
                'offline_exact_socdem',
                field_value,
                bb_profile.get('offline_exact_socdem'),
            )
            # Compare only the BigB exact_socdem keys that are present in the export.
            raw_bb_exact_socdem = bb_profile.get('exact_socdem') or {}
            filtered_bb_exact_socdem = {
                key: value for key, value in raw_bb_exact_socdem.iteritems() if key in field_value
            }
            errors.extend(check_value_correctness(
                yandexuid,
                'exact_socdem',
                field_value,
                filtered_bb_exact_socdem,
            ))
            return errors

        if field_name == 'lal_internal':
            # The exported value is compared with the union of two BigB fields.
            # ``update`` (unlike ``dict(a, **b)``) accepts non-string keys, and
            # ``or {}`` tolerates a field stored as None.
            bb_lal_internal = dict(bb_profile.get('trainable_segments') or {})
            bb_lal_internal.update(bb_profile.get('lal_internal') or {})
            return check_value_correctness(yandexuid, field_name, field_value, bb_lal_internal)

        exact_socdem_field_name = field_name_to_exact_socdem_field_name_dict.get(field_name)
        if exact_socdem_field_name is not None:
            offline_field_name = 'offline_{}'.format(field_name)
            errors = check_value_correctness(
                yandexuid,
                offline_field_name,
                field_value,
                bb_profile.get(offline_field_name),
            )
            # Also check the plain (non-offline) keyword when the export has a
            # value for this field in exact_socdem.
            if exported_exact_socdem.get(exact_socdem_field_name):
                errors.extend(check_value_correctness(
                    yandexuid,
                    field_name,
                    field_value,
                    bb_profile.get(field_name),
                ))
            return errors

        return check_value_correctness(
            yandexuid,
            field_name,
            field_value,
            bb_profile.get(field_name),
            self.not_exported_segments_by_keyword_name.get(field_name),
        )


class BbConsistencyMonitoring(luigi_utils.BaseYtTask):
    """Check a random sample of the yandexuid export against BigB profiles.

    The sample from ``config.YANDEXUID_EXPORT_PROFILES_14_DAYS_TABLE`` is split
    between ``NUMBER_OF_PROCESSES`` ``Worker`` processes.  The percentage of
    yandexuids whose BigB profile differs from the export is sent to graphite
    and reported to juggler (WARN above ``percent_threshold``, OK otherwise).
    Completion is recorded as a node attribute named after this class on
    yesterday's daily export directory.
    """
    date = luigi.Parameter()
    task_group = 'consistency_monitoring'
    # Inconsistency percentage above which a WARN event is sent to juggler.
    percent_threshold = 10

    def requires(self):
        return luigi_utils.ExternalInput(config.YANDEXUID_EXPORT_PROFILES_14_DAYS_TABLE)

    def output(self):
        return luigi_utils.YtNodeAttributeTarget(
            path=os.path.join(
                config.YANDEXUID_DAILY_EXPORT_DIRECTORY,
                date_helpers.get_yesterday(self.date),
            ),
            attribute_name=self.__class__.__name__,
            attribute_value=True,
        )

    def prepare_table_ranges(self):
        """Pick random ranges of the input table to verify."""
        return utils.get_random_table_ranges_from_yt(
            self.yt,
            self.input().table,
            sample_chunk_count=SAMPLE_CHUNK_COUNT,
            sample_chunk_size=SAMPLE_CHUNK_SIZE,
        )

    def run(self):
        """Verify the sample, publish the inconsistency percentage and alert.

        :raises RuntimeError: if no yandexuid was checked at all (empty sample,
            no BigB profiles found or all workers failed), since no percentage
            can be computed in that case.
        """
        good_yandexuid_count, bad_yandexuid_count = self._count_yandexuids()
        checked_yandexuid_count = good_yandexuid_count + bad_yandexuid_count

        if checked_yandexuid_count == 0:
            # Previously this surfaced as a bare ZeroDivisionError.
            raise RuntimeError(
                'No yandexuids were checked against BB (empty sample, no profiles found '
                'or all workers failed); cannot compute the inconsistency percentage'
            )

        bb_inconsistent_keywords_percentage = (bad_yandexuid_count * 100.0) / checked_yandexuid_count
        loggers.send_to_graphite(
            name='bb_inconsistent_keywords_percentage',
            value=bb_inconsistent_keywords_percentage,
        )

        self.logger.info(
            'BB yandexuids: consistent = {}, inconsistent = {}, bad% = {}'.format(
                good_yandexuid_count,
                bad_yandexuid_count,
                bb_inconsistent_keywords_percentage,
            ),
        )
        self.yt.set_attribute(self.output().path, self.__class__.__name__, True)

        self._report_to_juggler(bb_inconsistent_keywords_percentage)

    def _count_yandexuids(self):
        """Run the workers over a random sample; return (consistent, inconsistent) counts."""
        random_table_ranges = self.prepare_table_ranges()
        tasks = utils.partition(random_table_ranges, number_of_chunks=NUMBER_OF_PROCESSES)
        # Value(..., lock=True by default) provides get_lock(), which Worker relies on.
        yandexuids_without_difference_counter = multiprocessing.Value('i', 0)
        yandexuids_with_difference_counter = multiprocessing.Value('i', 0)
        not_exported_segments_by_keyword_name = get_not_exported_segments()

        self.logger.info("Not exported segments = {}".format(not_exported_segments_by_keyword_name))

        jobs = []
        for process_index in range(NUMBER_OF_PROCESSES):
            process = multiprocessing.Process(
                target=Worker(self.logger, not_exported_segments_by_keyword_name),
                args=(
                    tasks[process_index],
                    yandexuids_without_difference_counter,
                    yandexuids_with_difference_counter,
                ),
            )
            jobs.append(process)
            process.start()

        for job in jobs:
            job.join()
            if job.exitcode != 0:
                # A crashed worker silently drops its share of the sample; make it visible.
                self.logger.error('Worker %s exited with code %s; its rows are not counted', job.name, job.exitcode)

        return yandexuids_without_difference_counter.value, yandexuids_with_difference_counter.value

    def _report_to_juggler(self, bb_inconsistent_keywords_percentage):
        """Send WARN when the percentage exceeds ``percent_threshold``, OK otherwise."""
        if bb_inconsistent_keywords_percentage > self.percent_threshold:
            message = 'Percent of inconsistent BB entries={bad} is above the threshold={threshold}'.format(
                bad=bb_inconsistent_keywords_percentage,
                threshold=self.percent_threshold,
            )

            report_event_to_juggler(
                status='WARN',
                service='offline_classification_bb_monitoring',
                host=config.CRYPTA_PROFILE_JUGGLER_HOST,
                description=message,
                logger=loggers.get_stderr_logger(),
            )
        else:
            report_event_to_juggler(
                status='OK',
                service='offline_classification_bb_monitoring',
                host=config.CRYPTA_PROFILE_JUGGLER_HOST,
                logger=loggers.get_stderr_logger(),
            )


if __name__ == '__main__':
    # Run the monitoring for today's date through the scheduler configured in
    # config.LUIGI_SCHEDULER_URL.
    cli_args = [
        '--scheduler-url', config.LUIGI_SCHEDULER_URL,
        '--date', str(datetime.date.today()),
    ]
    luigi.run(cli_args, main_task_cls=BbConsistencyMonitoring)
