from collections import Counter, defaultdict
from copy import deepcopy
from functools import partial
import os

from yt.wrapper import (
    create_table_switch,
    OperationsTracker,
    with_context,
)

from crypta.lib.python.custom_ml.tools.training_utils import normalize_probabilities
from crypta.lib.python.yt import yt_helpers
from crypta.profile.lib import date_helpers
from crypta.profile.lib.socdem_helpers import socdem_config
from crypta.profile.utils import utils
from crypta.profile.utils.config import config


# Columns shared by all corrected socdem output tables. setup_tables builds each
# table schema from a copy of this dict plus the table-specific id columns.
# The 'any' columns hold dicts (probability distributions / exact socdem).
corrected_socdem_base_schema = dict(
    gender='any',
    user_age_6s='any',
    income_5_segments='any',
    age_segments='any',
    income_segments='any',
    exact_socdem='any',
    update_time='uint64',
)

# YQL that builds the per-yandexuid input for socdem correction.
# It LEFT JOINs the 14-day yandexuid raw classification with:
#   - yandexuid_socdem_storage (the ``*_scores`` columns),
#   - the 14-day profiles table (the previously known ``exact_socdem``),
#   - yandexuid -> crypta_id matching.
# ``is_active`` is true for rows with update_time >= {last_update}.
# Rows with a crypta_id are written to {yandexuid_socdem_with_crypta_id}, sorted by
# crypta_id so the table can feed the reduce by crypta_id. Rows without a crypta_id
# go to {yandexuid_socdem_without_crypta_id}.
# Placeholders are filled with str.format in run_voting.
collect_socdem_data_for_yandexuid_query = """
$socdem_data_for_yandexuid = (
    SELECT
        yandexuid_14days_raw_classification.yandexuid AS yandexuid,
        yandexuid_14days_raw_classification.update_time AS update_time,
        yandexuid_14days_raw_classification.update_time >= {last_update} AS is_active,
        yandexuid_14days_raw_classification.gender AS gender,
        yandexuid_14days_raw_classification.user_age_6s AS user_age_6s,
        yandexuid_14days_raw_classification.income_5_segments AS income_5_segments,

        yandexuid_socdem_storage.gender_scores AS gender_scores,
        yandexuid_socdem_storage.age_scores AS age_scores,
        yandexuid_socdem_storage.income_scores AS income_scores,

        profiles_for_14days.exact_socdem AS exact_socdem,

        yandexuid_crypta_id_matching.crypta_id AS crypta_id,

    FROM `{yandexuid_14days_raw_classification}` AS yandexuid_14days_raw_classification
    LEFT JOIN `{yandexuid_socdem_storage}` AS yandexuid_socdem_storage
    ON yandexuid_14days_raw_classification.yandexuid == yandexuid_socdem_storage.yandexuid
    LEFT JOIN `{profiles_for_14days}` AS profiles_for_14days
    ON yandexuid_14days_raw_classification.yandexuid == profiles_for_14days.yandexuid
    LEFT JOIN `{yandexuid_crypta_id_matching}` AS yandexuid_crypta_id_matching
    ON yandexuid_14days_raw_classification.yandexuid == yandexuid_crypta_id_matching.yandexuid
);

INSERT INTO `{yandexuid_socdem_with_crypta_id}`
WITH TRUNCATE

SELECT *
FROM $socdem_data_for_yandexuid
WHERE crypta_id IS NOT NULL
ORDER BY crypta_id;

INSERT INTO `{yandexuid_socdem_without_crypta_id}`
WITH TRUNCATE

SELECT *
FROM $socdem_data_for_yandexuid
WHERE crypta_id IS NULL;
"""

# YQL that INNER JOINs the 35-day devid raw classification with devid -> crypta_id
# matching on (id, id_type), so devids without a crypta_id are dropped.
# The result is written to {devid_socdem_with_crypta_id}, sorted by crypta_id so it
# can feed the reduce by crypta_id.
# Placeholders are filled with str.format in run_voting.
collect_socdem_data_for_devid_query = """
INSERT INTO `{devid_socdem_with_crypta_id}`
WITH TRUNCATE

SELECT
    devid_35days_raw_classification_table.id AS id,
    devid_35days_raw_classification_table.id_type AS id_type,
    update_time,
    gender,
    user_age_6s,
    income_5_segments,
    crypta_id
FROM `{devid_35days_raw_classification_table}` AS devid_35days_raw_classification_table
INNER JOIN `{devid_crypta_id_matching}` AS devid_crypta_id_matching
ON devid_35days_raw_classification_table.id == devid_crypta_id_matching.id AND
    devid_35days_raw_classification_table.id_type == devid_crypta_id_matching.id_type
ORDER BY crypta_id;
"""


def convert_user_age_6s_to_age_segments(user_age_6s):
    """
    Collapse the 6-bucket age distribution into age segments.

    The two oldest buckets, '45_54' and '55_99', are summed into a single '45_99'
    segment. Every other bucket is carried over with its own probability.

    Returns a ``defaultdict(float)`` mapping segment name -> probability.
    """
    oldest_buckets = ('45_54', '55_99')
    merged = defaultdict(float)

    for bucket, value in user_age_6s.items():
        target_bucket = '45_99' if bucket in oldest_buckets else bucket
        merged[target_bucket] += value

    return merged


def convert_income_5_segments_to_income_segments(income_5_segments):
    """
    Collapse the 5-bucket income distribution into the 3 income segments.

    'A' is kept as is, 'B' = 'B1' + 'B2', 'C' = 'C1' + 'C2'.
    All five input keys are required (a missing one raises KeyError).
    """
    segment_a = income_5_segments['A']
    segment_b = income_5_segments['B1'] + income_5_segments['B2']
    segment_c = income_5_segments['C1'] + income_5_segments['C2']

    return {'A': segment_a, 'B': segment_b, 'C': segment_c}


def calculate_yandexuid_profile(row, thresholds):
    """
    Build a yandexuid socdem profile from the row's classification and its storage scores.

    For every socdem type, the classification probabilities (``row[<segments name>]``)
    and the yandexuid_socdem_storage scores (``row['<type>_scores']``) are summed as
    Counters and normalized. Counter addition keeps only positive totals, and a
    missing (None) scores value contributes nothing. ``exact_socdem`` is then derived
    from the normalized distributions using ``thresholds``.

    Used when the crypta_id socdem probabilities do not exist or are not reliable.
    """
    socdem_pairs = zip(socdem_config.SOCDEM_TYPES, socdem_config.SOCDEM_SEGMENT_TYPE_NAMES)

    profile = {
        segments_name: normalize_probabilities(
            Counter(row[segments_name]) + Counter(row['{}_scores'.format(socdem_type)])
        )
        for socdem_type, segments_name in socdem_pairs
    }
    profile['exact_socdem'] = utils.get_exact_socdem_dict(profile, thresholds)

    return profile


def compute_other_columns(corrected_profile, thresholds):
    """
    Add the derived columns to a socdem profile and return it.

    Sets ``age_segments`` (from ``user_age_6s``) and ``income_segments`` (from
    ``income_5_segments``). It then sets ``exact_socdem`` via
    ``utils.get_exact_socdem_dict`` with ``thresholds``.

    The dict is modified in place, and the same object is returned.
    """
    age_segments = convert_user_age_6s_to_age_segments(corrected_profile['user_age_6s'])
    corrected_profile['age_segments'] = age_segments

    income_segments = convert_income_5_segments_to_income_segments(corrected_profile['income_5_segments'])
    corrected_profile['income_segments'] = income_segments

    # Computed last, so it sees the profile with the columns above already filled.
    exact_socdem = utils.get_exact_socdem_dict(corrected_profile, thresholds)
    corrected_profile['exact_socdem'] = exact_socdem

    return corrected_profile


def check_if_exact_socdem_changed(new_exact_socdem, old_exact_socdem):
    """
    Tell whether a newly computed exact socdem differs from the previously known one.

    If nothing was known before (``old_exact_socdem`` is None), any non-empty new
    value counts as a change. Otherwise only the ``socdem_config.EXACT_SOCDEM_FIELDS``
    keys are compared, and a key absent on either side is treated as None.
    """
    if old_exact_socdem is None:
        return len(new_exact_socdem) > 0

    return any(
        new_exact_socdem.get(field) != old_exact_socdem.get(field)
        for field in socdem_config.EXACT_SOCDEM_FIELDS
    )


def correct_yandexuid_profiles_without_crypta_id(row, thresholds):
    """
    Map function: correct the socdem of a yandexuid that has no crypta_id.

    For each socdem type, the yandexuid profile (classification combined with
    yandexuid_socdem_storage scores) replaces the row's own segments when it has an
    exact value for that type. Otherwise the row's segments are kept.

    Output:
        table 0 - active rows, always;
        table 1 - not active rows, only when the resulting exact_socdem differs
                  from the previously known one (``row['exact_socdem']``).
    """
    corrected_profile = dict(yandexuid=row['yandexuid'], update_time=row['update_time'])

    yandexuid_profile = calculate_yandexuid_profile(row, thresholds)
    profile_exact_socdem = yandexuid_profile['exact_socdem']

    for segments_name, exact_field in zip(
        socdem_config.SOCDEM_SEGMENT_TYPE_NAMES,
        socdem_config.EXACT_SOCDEM_FIELDS,
    ):
        has_exact_value = profile_exact_socdem.get(exact_field) is not None
        source = yandexuid_profile if has_exact_value else row
        corrected_profile[segments_name] = source[segments_name]

    corrected_profile = compute_other_columns(corrected_profile, thresholds)

    if row['is_active']:
        output_table_index = 0
    elif check_if_exact_socdem_changed(corrected_profile['exact_socdem'], row['exact_socdem']):
        output_table_index = 1
    else:
        # not active and nothing changed: nothing to write
        return

    yield create_table_switch(output_table_index)
    yield corrected_profile


@with_context
class CalculateCryptaIdSocdem:
    """
    Reducer (by crypta_id) that builds a crypta_id socdem profile and corrects linked profiles.

    Input tables, by ``context.table_index``:
        0 - yandexuid socdem rows that have a crypta_id,
        1 - devid socdem rows that have a crypta_id,
        2 - crypta_id socdem storage (``*_scores`` columns).

    For each crypta_id with at least one yandexuid or devid row, it yields:
        - the crypta_id profile: classification probabilities of all linked rows plus
          the storage scores, summed per socdem type and normalized;
        - every devid row, corrected by the crypta_id profile;
        - every yandexuid row, corrected by the crypta_id profile or, failing that, by
          the yandexuid profile combined with yandexuid_socdem_storage. Active rows are
          always written; not active ones only if their exact_socdem changed.
    """

    def __init__(self, thresholds):
        self.thresholds = thresholds

        # Output table indices. They follow the order of the output tables passed
        # to run_reduce in run_voting.
        self.crypta_id_socdem_table_index = 0
        self.corrected_devid_socdem_table_index = 1
        self.corrected_active_yandexuid_socdem_table_index = 2
        self.corrected_not_active_yandexuid_socdem_table_index = 3

    def fill_corrected_profile(self, row, corrected_profile, crypta_id_profile, check_storage=False):
        """
        Fill the socdem segment columns of ``corrected_profile`` and recompute derived columns.

        For each socdem type, the row's own segments are used by default. If
        ``crypta_id_profile`` has an exact value for the type, its segments are copied
        instead. Otherwise, when ``check_storage`` is True, the segments of the yandexuid
        profile (row classification combined with yandexuid_socdem_storage scores) are
        copied if that profile has an exact value for the type.

        ``corrected_profile`` is mutated in place and returned.
        """
        # Depends only on `row` and the thresholds, so it is computed lazily and at most
        # once per row, instead of once per socdem type.
        yandexuid_profile = None

        for socdem_segments_name, exact_socdem_field in zip(
            socdem_config.SOCDEM_SEGMENT_TYPE_NAMES,
            socdem_config.EXACT_SOCDEM_FIELDS,
        ):
            corrected_profile[socdem_segments_name] = row[socdem_segments_name]

            # correct by crypta_id
            if exact_socdem_field in crypta_id_profile['exact_socdem']:
                corrected_profile[socdem_segments_name] = deepcopy(crypta_id_profile[socdem_segments_name])
            # correct by yandexuid_socdem_storage
            elif check_storage:
                if yandexuid_profile is None:
                    yandexuid_profile = calculate_yandexuid_profile(row, self.thresholds)
                if exact_socdem_field in yandexuid_profile['exact_socdem']:
                    corrected_profile[socdem_segments_name] = deepcopy(yandexuid_profile[socdem_segments_name])

        return compute_other_columns(corrected_profile, self.thresholds)

    def __call__(self, key, rows, context):
        """Process all rows of one crypta_id (``key['crypta_id']``); see the class docstring."""
        has_devid_classification = False
        has_yandexuid_classification = False
        rows_to_correct = {'yandexuid': [], 'devid': []}

        scores = {}
        for socdem_type in socdem_config.SOCDEM_TYPES:
            scores[socdem_type] = Counter()

        for row in rows:
            # yandexuid or devid classification
            if context.table_index == 0 or context.table_index == 1:
                for socdem_type, socdem_segments_name in zip(socdem_config.SOCDEM_TYPES,
                                                             socdem_config.SOCDEM_SEGMENT_TYPE_NAMES):
                    scores[socdem_type] += Counter(row[socdem_segments_name])

                # Rows are copied because they are kept until the whole key is read.
                if context.table_index == 0:
                    has_yandexuid_classification = True
                    rows_to_correct['yandexuid'].append(deepcopy(row))
                elif context.table_index == 1:
                    has_devid_classification = True
                    rows_to_correct['devid'].append(deepcopy(row))

            # crypta_id storage
            elif context.table_index == 2:
                for socdem_type in socdem_config.SOCDEM_TYPES:
                    scores[socdem_type] += Counter(row['{}_scores'.format(socdem_type)])

        # A crypta_id known only from the storage has nothing to correct.
        if not (has_devid_classification or has_yandexuid_classification):
            return

        crypta_id_profile = {
            'crypta_id': key['crypta_id'],
        }
        for socdem_type, socdem_segments_name in zip(socdem_config.SOCDEM_TYPES,
                                                     socdem_config.SOCDEM_SEGMENT_TYPE_NAMES):
            crypta_id_profile['{}_scores'.format(socdem_type)] = scores[socdem_type]
            crypta_id_profile[socdem_segments_name] = normalize_probabilities(scores[socdem_type])

        crypta_id_profile = compute_other_columns(crypta_id_profile, self.thresholds)

        yield create_table_switch(self.crypta_id_socdem_table_index)
        yield crypta_id_profile

        # correct devid profiles
        for row in rows_to_correct['devid']:
            devid_corrected_profile = {
                'id': row['id'],
                'id_type': row['id_type'],
                'update_time': row['update_time'],
            }
            devid_corrected_profile = self.fill_corrected_profile(
                row,
                devid_corrected_profile,
                crypta_id_profile,
            )

            yield create_table_switch(self.corrected_devid_socdem_table_index)
            yield devid_corrected_profile

        # correct yandexuid profiles
        for row in rows_to_correct['yandexuid']:
            yandexuid_corrected_profile = {
                'yandexuid': row['yandexuid'],
                'update_time': row['update_time'],
            }
            yandexuid_corrected_profile = self.fill_corrected_profile(
                row,
                yandexuid_corrected_profile,
                crypta_id_profile,
                check_storage=True,
            )

            if row['is_active']:
                yield create_table_switch(self.corrected_active_yandexuid_socdem_table_index)
                yield yandexuid_corrected_profile
            else:
                # Not active profiles are written only when their exact socdem changed.
                new_exact_socdem = yandexuid_corrected_profile['exact_socdem']
                old_exact_socdem = row['exact_socdem']

                if check_if_exact_socdem_changed(new_exact_socdem, old_exact_socdem):
                    yield create_table_switch(self.corrected_not_active_yandexuid_socdem_table_index)
                    yield yandexuid_corrected_profile


def setup_tables(
    yt_client,
    results_directory,
    date,
    corrected_active_yandexuid_without_crypta_id,
    corrected_not_active_yandexuid_without_crypta_id,
):
    """
    Resolve the output table paths and create the empty output tables with their schemas.

    Each output path is ``<results_directory>/<table name>`` when ``results_directory``
    is given, and ``<default profile directory from config>/<date>`` otherwise.
    The two "without crypta_id" tables passed in get the same yandexuid schema as the
    corrected yandexuid outputs.

    Returns a dict: output table name -> table path.
    """
    default_directories = (
        ('corrected_active_yandexuid', config.CORRECTED_YANDEXUID_PROFILES_YT_DIRECTORY),
        ('corrected_not_active_yandexuid', config.CORRECTED_NOT_ACTIVE_YANDEXUID_PROFILES_YT_DIRECTORY),
        ('crypta_id_profiles_table', config.CRYPTA_ID_PROFILES_YT_DIRECTORY),
        ('corrected_devid_profiles_table', config.CORRECTED_DEVID_PROFILES_YT_DIRECTORY),
    )

    if results_directory is None:
        output_tables_paths = {
            name: os.path.join(directory, date)
            for name, directory in default_directories
        }
    else:
        output_tables_paths = {
            name: os.path.join(results_directory, name)
            for name, _ in default_directories
        }

    yandexuid_tables = (
        output_tables_paths['corrected_active_yandexuid'],
        output_tables_paths['corrected_not_active_yandexuid'],
        corrected_active_yandexuid_without_crypta_id,
        corrected_not_active_yandexuid_without_crypta_id,
    )
    for yandexuid_table in yandexuid_tables:
        # a fresh schema dict for every table
        yt_helpers.create_empty_table(
            yt_client,
            yandexuid_table,
            schema=dict(corrected_socdem_base_schema, yandexuid='uint64'),
        )

    yt_helpers.create_empty_table(
        yt_client,
        output_tables_paths['corrected_devid_profiles_table'],
        schema=dict(corrected_socdem_base_schema, id='string', id_type='string'),
    )

    yt_helpers.create_empty_table(
        yt_client,
        output_tables_paths['crypta_id_profiles_table'],
        schema=dict(
            corrected_socdem_base_schema,
            crypta_id='uint64',
            gender_scores='any',
            age_scores='any',
            income_scores='any',
        ),
    )

    return output_tables_paths


def run_voting(
    yt_client,
    yql_client,
    yandexuid_14days_raw_classification=config.MERGED_RAW_YANDEXUID_PROFILES,
    devid_35days_raw_classification_table=config.MERGED_RAW_DEVID_PROFILES,
    thresholds=None,
    date=None,
    results_directory=None,
):
    """
    Function to calculate crypta_id socdem, corrected yandexuid and devid socdem profiles.

    Parameters
    ----------
    yt_client
        crypta/lib/python/yt/client
    yql_client
        crypta/lib/python/yql/client
    yandexuid_14days_raw_classification: str, optional
        Path to the table with columns [yandexuid, update_time, gender, user_age_6s, income_5_segments],
        table consists of the raw probabilities for the classes obtained by applying the socdem models.
    devid_35days_raw_classification_table: str, optional
        Path to the table with columns [id, id_type, update_time, gender, user_age_6s, income_5_segments],
        table consists of the raw probabilities for the classes obtained by applying the mobile socdem models.
    thresholds: dict, optional
        Dictionary that has dictionary of thresholds for each socdem type.
        If it is not defined, thresholds are fetched with utils.get_socdem_thresholds_from_api().
    date: str, optional
        Date that will be used to select active users,
        if it is not defined, then generate_date of yandexuid_14days_raw_classification will be used.
    results_directory: str, optional
        Path that will be used to save output tables,
        if it is not defined paths from profile config will be used.
        Note: it must be defined if function is used outside regular process.

    Returns
    -------
    None
        Creates tables (the paths are defined by input parameters):
            - crypta_id socdem
            - corrected devid (with crypta_id) socdem
            - corrected active yandexuid socdem
            - corrected not active yandexuid socdem, if calculated socdem is different from already known one
    """

    with yt_client.Transaction() as transaction, \
            yt_client.TempTable() as yandexuid_socdem_with_crypta_id_table, \
            yt_client.TempTable() as yandexuid_socdem_without_crypta_id_table, \
            yt_client.TempTable() as devid_socdem_with_crypta_id_table, \
            yt_client.TempTable() as corrected_active_yandexuid_without_crypta_id_table, \
            yt_client.TempTable() as corrected_not_active_yandexuid_without_crypta_id_table:

        if date is None:
            # Use the classification table actually being processed (as documented),
            # not the config default, so custom input tables get their own date.
            date = yt_client.get_attribute(yandexuid_14days_raw_classification, 'generate_date')
        # yandexuid rows updated at or after this timestamp are considered active
        last_update = date_helpers.from_utc_date_string_to_noon_timestamp(date)

        # Split yandexuid rows by crypta_id presence and collect devid rows with crypta_id.
        yql_client.execute(
            query='{yandexuid_query}\n{devid_query}'.format(
                yandexuid_query=collect_socdem_data_for_yandexuid_query.format(
                    last_update=last_update,
                    yandexuid_14days_raw_classification=yandexuid_14days_raw_classification,
                    yandexuid_socdem_storage=config.YANDEXUID_SOCDEM_STORAGE_TABLE,
                    profiles_for_14days=config.YANDEXUID_EXPORT_PROFILES_14_DAYS_TABLE,
                    yandexuid_crypta_id_matching=config.YANDEXUID_CRYPTAID_TABLE,
                    yandexuid_socdem_with_crypta_id=yandexuid_socdem_with_crypta_id_table,
                    yandexuid_socdem_without_crypta_id=yandexuid_socdem_without_crypta_id_table,
                ),
                devid_query=collect_socdem_data_for_devid_query.format(
                    devid_35days_raw_classification_table=devid_35days_raw_classification_table,
                    devid_crypta_id_matching=config.DEVID_CRYPTAID_TABLE,
                    devid_socdem_with_crypta_id=devid_socdem_with_crypta_id_table,
                ),
            ),
            title='YQL collect socdem data for yandexuid and devid',
            transaction=str(transaction.transaction_id),
        )

        if thresholds is None:
            thresholds = utils.get_socdem_thresholds_from_api()

        output_tables_paths = setup_tables(
            yt_client,
            results_directory,
            date,
            corrected_active_yandexuid_without_crypta_id_table,
            corrected_not_active_yandexuid_without_crypta_id_table,
        )

        # The output table order must match the table indices used by CalculateCryptaIdSocdem.
        correct_by_crypta_id = yt_client.run_reduce(
            CalculateCryptaIdSocdem(thresholds),
            [
                yandexuid_socdem_with_crypta_id_table,
                devid_socdem_with_crypta_id_table,
                yt_client.TablePath(
                    config.CRYPTA_ID_SOCDEM_STORAGE_TABLE,
                    columns=('crypta_id', 'age_scores', 'gender_scores', 'income_scores'),
                )
            ],
            [
                output_tables_paths['crypta_id_profiles_table'],
                output_tables_paths['corrected_devid_profiles_table'],
                output_tables_paths['corrected_active_yandexuid'],
                output_tables_paths['corrected_not_active_yandexuid'],
            ],
            spec={
                'title': 'Calculate socdem for crypta_id, correct socdem for yandexuid and devid with crypta_id.',
            },
            reduce_by=['crypta_id'],
            sync=False,
        )

        correct_without_crypta_id = yt_client.run_map(
            partial(correct_yandexuid_profiles_without_crypta_id, thresholds=thresholds),
            yandexuid_socdem_without_crypta_id_table,
            [
                corrected_active_yandexuid_without_crypta_id_table,
                corrected_not_active_yandexuid_without_crypta_id_table,
            ],
            spec={
                'title': 'Correct socdem for yandexuid without crypta_id.',
            },
            sync=False,
        )

        # Both operations run concurrently; the tracker waits for them on exit.
        with OperationsTracker() as tracker:
            tracker.add(correct_by_crypta_id)
            tracker.add(correct_without_crypta_id)

        sort_crypta_id_profiles = yt_client.run_sort(
            output_tables_paths['crypta_id_profiles_table'],
            sort_by='crypta_id',
            sync=False,
        )
        sort_devid_profiles = yt_client.run_sort(
            output_tables_paths['corrected_devid_profiles_table'],
            sort_by=['id', 'id_type'],
            sync=False,
        )
        # Merge yandexuid profiles corrected with and without crypta_id into one sorted table.
        collect_active_yandexuid_profiles = yt_client.run_sort(
            [
                corrected_active_yandexuid_without_crypta_id_table,
                output_tables_paths['corrected_active_yandexuid'],
            ],
            output_tables_paths['corrected_active_yandexuid'],
            sort_by='yandexuid',
            sync=False,
        )
        collect_not_active_yandexuid_profiles = yt_client.run_sort(
            [
                corrected_not_active_yandexuid_without_crypta_id_table,
                output_tables_paths['corrected_not_active_yandexuid'],
            ],
            output_tables_paths['corrected_not_active_yandexuid'],
            sort_by='yandexuid',
            sync=False,
        )

        with OperationsTracker() as tracker:
            tracker.add(sort_crypta_id_profiles)
            tracker.add(sort_devid_profiles)
            tracker.add(collect_active_yandexuid_profiles)
            tracker.add(collect_not_active_yandexuid_profiles)

        for table_path in output_tables_paths.values():
            yt_client.set_attribute(table_path, 'generate_date', date)
