#!/usr/bin/env python
# -*- coding: utf-8 -*-

import datetime
from collections import defaultdict
from functools import partial

import luigi
import yt.yson as yson
from yt.wrapper import with_context

from crypta.profile.lib import date_helpers

from crypta.profile.utils import utils
from crypta.profile.utils.socdem import (
    get_age_segment_from_age,
    get_age_segment_from_year_of_birth,
    get_age_segment_from_birth_date,
    calculate_segment_scores,
    socdem_storage_schema,
)
from crypta.profile.utils.config import config
from crypta.profile.utils.loggers import TimeTracker
from crypta.profile.utils.luigi_utils import (
    YtDailyRewritableTarget,
    BaseYtTask,
    ExternalInput,
)
from crypta.profile.tasks.monitoring.validation_by_sources.make_samples import MakeSocdemValidationSamples
from crypta.profile.tasks.features.get_crypta_ids import GetCryptaIds


# Weight given to a record whose source has no entry in SOURCE_WEIGHTS.
DEFAULT_WEIGHT = 1.0
# Per-source weight attached to every record by id_reducer; sources that are
# not listed here fall back to DEFAULT_WEIGHT.
# NOTE(review): the rationale for the concrete values is presumably in the
# ticket below - confirm before changing them.
# https://st.yandex-team.ru/CRYPTAUP-293#1500546848000
SOURCE_WEIGHTS = {
    'peoplesearch_ok': 2.0,
    'peoplesearch_vk': 2.0,
    'yamoney_cryptaup196': 2.0,

    'passport': 3.0,
    'germandb_cryptaup265': 3.0,
    'dit': 2.0,
    'socialdb': 2.0,

    'socialdb_ok': 2.0,
    'socialdb_vk': 2.0,
    'socialdb_fb': 2.0,
    'socialdb_gg': 2.0,
    'socialdb_mr': 2.0,
    'socialdb_mt': 2.0,

    'job_search_for_training': 2.0,
    'beeline_for_training': 2.0,
    'auto_ru_for_training': 2.0
}

# Half-life, in days, of the weight of records that carry a plain `age` value:
# id_reducer multiplies their source weight by
# 2 ** (-days_since_upload / HALFLIFE_DAYS_FOR_REGULARLY_UPDATED_SOURCES).
HALFLIFE_DAYS_FOR_REGULARLY_UPDATED_SOURCES = 180


# Columns shared by both merged socdem tables; every one of them is declared
# with the 'any' type.
merged_socdem_storage_base_schema = {
    column: 'any'
    for column in (
        'gender_sources',
        'age_sources',
        'income_sources',

        'gender_scores',
        'age_scores',
        'income_scores',
    )
}


# Each per-identifier schema is an independent copy of the base schema plus the
# identifier column itself; the base schema is left untouched.
merged_socdem_storage_yandexuid_schema = dict(
    merged_socdem_storage_base_schema,
    yandexuid='uint64',
)

merged_socdem_storage_crypta_id_schema = dict(
    merged_socdem_storage_base_schema,
    crypta_id='uint64',
)


def id_reducer(key, rows, date):
    """Merge all socdem records of one identifier into a single scored row.

    Every record is filed under the segment it votes for (gender, income
    segment, age segment) together with its source weight.  Age is taken from
    the first non-empty field among ``birth_date``, ``age`` and
    ``year_of_birth`` (in that order); only records based on a plain ``age``
    value have their weight decayed with the time passed since upload.

    Args:
        key: reduce key; copied as-is into the output row.
        rows: iterable of dict-like records.  Each record must provide
            ``source``, ``update_time``, ``id`` and ``id_type``; ``gender``,
            ``income_segment``, ``birth_date``, ``age`` and ``year_of_birth``
            are optional.
        date: reference date as a ``'%Y-%m-%d'`` string, used to compute the
            age of a record in days.

    Yields:
        At most one row: the key fields plus ``*_sources`` (``None`` when
        empty) and, for every non-empty category, ``*_scores`` computed by
        ``calculate_segment_scores``.  Nothing is yielded when no record
        carried any usable socdem value.
    """
    # The reference date is identical for every record, so parse it once
    # instead of once per row.
    reference_dt = datetime.datetime.strptime(date, '%Y-%m-%d')

    age_sources = defaultdict(list)
    gender_sources = defaultdict(list)
    income_sources = defaultdict(list)

    for row in rows:
        source = row['source']
        source_weight = SOURCE_WEIGHTS.get(source, DEFAULT_WEIGHT)
        # update_time is interpreted by datetime.fromtimestamp, i.e. as a
        # POSIX timestamp converted to naive local time.
        dt_since_upload = reference_dt - datetime.datetime.fromtimestamp(row['update_time'])
        # Records stamped after the reference date count as zero days old.
        days_since_upload = max(float(dt_since_upload.days), 0.0)

        record_info = {
            'source': source,
            'update_time': row['update_time'],
            'weight': source_weight,
            'id': row['id'],
            'id_type': row['id_type'],
        }

        if row.get('gender'):
            gender_sources[row['gender']].append(record_info)

        if row.get('income_segment'):
            income_sources[row['income_segment']].append(record_info)

        # age info: the first non-empty field wins, even if it does not map
        # to a segment (the remaining fields are then not consulted).
        if row.get('birth_date'):
            age_segment = get_age_segment_from_birth_date(row['birth_date'])
            if age_segment is not None:
                age_record_info = record_info.copy()
                age_record_info['birth_date'] = row['birth_date']
                age_sources[age_segment].append(age_record_info)
        elif row.get('age'):
            age_segment = get_age_segment_from_age(row['age'])
            if age_segment is not None:
                age_record_info = record_info.copy()
                age_record_info['age'] = row['age']
                # A plain age value goes stale, so its weight halves every
                # HALFLIFE_DAYS_FOR_REGULARLY_UPDATED_SOURCES days.
                age_record_info['weight'] = source_weight * 2 ** (-days_since_upload / HALFLIFE_DAYS_FOR_REGULARLY_UPDATED_SOURCES)
                age_sources[age_segment].append(age_record_info)
        elif row.get('year_of_birth'):
            age_segment = get_age_segment_from_year_of_birth(row['year_of_birth'])
            if age_segment is not None:
                age_record_info = record_info.copy()
                age_record_info['year_of_birth'] = row['year_of_birth']
                age_sources[age_segment].append(age_record_info)

    if gender_sources or age_sources or income_sources:
        result = dict(key)
        result.update({
            'gender_sources': gender_sources or None,
            'age_sources': age_sources or None,
            'income_sources': income_sources or None,
        })

        if gender_sources:
            result['gender_scores'] = calculate_segment_scores(gender_sources)

        if age_sources:
            result['age_scores'] = calculate_segment_scores(age_sources)

        if income_sources:
            result['income_scores'] = calculate_segment_scores(income_sources)

        yield result


@with_context
def join_with_yandexuid_reducer(key, rows, context):
    """Attach a ``yandexuid`` column to socdem storage records.

    Keys whose ``id_type`` is ``'yandexuid'`` need no join: every row is
    emitted with ``yandexuid`` set to ``YsonUint64(key['id'])``, provided
    ``utils.is_valid_uint64`` accepts the id.

    For any other ``id_type`` the rows of input table 0 are buffered, and each
    row of the other input table produces one output row per buffered record,
    carrying that row's ``yandexuid`` plus all fields of the buffered record.
    """
    if key['id_type'] == 'yandexuid':
        for record in rows:
            if not utils.is_valid_uint64(key['id']):
                continue
            record['yandexuid'] = yson.YsonUint64(key['id'])
            yield record
        return

    # NOTE(review): this join only works if, within one key, table-0 rows are
    # delivered before rows of the other table - confirm against YT reduce
    # ordering guarantees.
    storage_records = []
    for record in rows:
        if context.table_index == 0:
            storage_records.append(record)
            continue

        for storage_record in storage_records:
            joined = dict(yandexuid=record['yandexuid'])
            # Fields of the storage record are applied last, so they win on
            # any name clash.
            joined.update(storage_record)
            yield joined


@with_context
def join_with_crypta_id_reducer(key, rows, context):
    """Attach a ``crypta_id`` column to socdem storage records.

    Rows of input table 0 are buffered; each row of the other input table
    produces one output row per buffered record, carrying
    ``crypta_id = YsonUint64(row['cryptaId'])`` plus all fields of the
    buffered record.
    """
    # NOTE(review): relies on table-0 rows arriving before the other table's
    # rows within a key - confirm against YT reduce ordering guarantees.
    storage_records = []

    for record in rows:
        if context.table_index != 0:
            for storage_record in storage_records:
                joined = dict(crypta_id=yson.YsonUint64(record['cryptaId']))
                # Fields of the storage record are applied last, so they win
                # on any name clash.
                joined.update(storage_record)
                yield joined
        else:
            storage_records.append(record)


class MergeSocdemStorage(BaseYtTask):
    """Concatenate every socdem storage table into one table sorted by id.

    The list of input tables is collected when the task is constructed, by
    searching ``config.SOCDEM_STORAGE_YT_DIR`` for table nodes.
    """

    date = luigi.Parameter()
    task_group = 'export_profiles'

    def __init__(self, date):
        super(MergeSocdemStorage, self).__init__(date)

        # Snapshot of the tables present under the storage directory right now.
        self.input_tables = list(
            self.yt.search(config.SOCDEM_STORAGE_YT_DIR, node_type=['table'])
        )

    def requires(self):
        """Depend on all input tables; in production also on validation samples."""
        input_targets = [ExternalInput(table) for table in self.input_tables]
        if config.environment != 'production':
            return input_targets

        return {
            'input_tables': input_targets,
            'validation_tables': MakeSocdemValidationSamples(date_helpers.get_yesterday(self.date)),
        }

    def output(self):
        """Return the daily-rewritable target for the merged storage table."""
        return YtDailyRewritableTarget(table=config.SOCDEM_STORAGE_TABLE, date=self.date)

    def run(self):
        """Recreate the output table, fill it from all inputs, sort and stamp it."""
        with TimeTracker(monitoring_name=self.__class__.__name__), self.yt.Transaction():
            merged_table = self.output().table

            self.yt.create_empty_table(merged_table, schema=socdem_storage_schema)

            # utils.do_nothing is presumably an identity mapper, so this just
            # copies all input rows into the output table - verify in utils.
            self.yt.run_map(utils.do_nothing, self.input_tables, merged_table)

            self.yt.run_sort(merged_table, sort_by=['id', 'id_type'])

            self.yt.set_attribute(merged_table, 'generate_date', self.date)


class JoinSocdemStorage(BaseYtTask):
    """Build per-yandexuid and per-crypta_id socdem tables from the merged storage.

    For each identifier kind the merged storage is joined with the matching
    table, the joined records are collapsed by ``id_reducer``, and the result
    is sorted by the identifier and stamped with ``generate_date``.
    """

    date = luigi.Parameter()
    task_group = 'export_profiles'

    def requires(self):
        """Depend on the merged storage and on the crypta id matching task."""
        return dict(
            merged_socdem_storage=MergeSocdemStorage(self.date),
            indevice_matching=GetCryptaIds(self.date),
        )

    def output(self):
        """Return one daily-rewritable target per identifier kind."""
        target_tables = (
            ('yandexuid', config.YANDEXUID_SOCDEM_STORAGE_TABLE),
            ('crypta_id', config.CRYPTA_ID_SOCDEM_STORAGE_TABLE),
        )
        return {
            name: YtDailyRewritableTarget(table=table, date=self.date)
            for name, table in target_tables
        }

    def run(self):
        """Run the join / reduce / sort pipeline for yandexuid, then for crypta_id."""
        with TimeTracker(monitoring_name=self.__class__.__name__):
            with self.yt.Transaction(), \
                    self.yt.TempTable() as yandexuid_joined, \
                    self.yt.TempTable() as crypta_id_joined:
                # (identifier column, join reducer, matching table, output
                # schema, temporary table for the joined records); processed
                # strictly in this order.
                pipelines = (
                    (
                        'yandexuid',
                        join_with_yandexuid_reducer,
                        config.INDEVICE_YANDEXUID,
                        merged_socdem_storage_yandexuid_schema,
                        yandexuid_joined,
                    ),
                    (
                        'crypta_id',
                        join_with_crypta_id_reducer,
                        config.VERTICES_NO_MULTI_PROFILE,
                        merged_socdem_storage_crypta_id_schema,
                        crypta_id_joined,
                    ),
                )

                for id_column, join_reducer, matching_table, schema, joined_table in pipelines:
                    result_table = self.output()[id_column].table

                    self.yt.create_empty_table(result_table, schema=schema)

                    # Expand every storage record to the identifiers it maps to.
                    self.yt.run_reduce(
                        join_reducer,
                        [self.input()['merged_socdem_storage'].table, matching_table],
                        joined_table,
                        reduce_by=['id', 'id_type'],
                    )

                    # Collapse all records of one identifier into a single row.
                    self.yt.run_map_reduce(
                        None,
                        partial(id_reducer, date=self.date),
                        joined_table,
                        result_table,
                        reduce_by=id_column,
                    )

                    self.yt.run_sort(result_table, sort_by=id_column)

                    self.yt.set_attribute(result_table, 'generate_date', self.date)
