#!/usr/bin/env python
# -*- coding: utf-8 -*-
from collections import Counter
import os

import luigi
from yt.wrapper import with_context

from crypta.profile.lib import date_helpers

from crypta.profile.utils.config import config
from crypta.profile.utils.luigi_utils import BaseYtTask, ExternalInput, YtTableAttributeTarget
from crypta.profile.utils.socdem import socdem_storage_schema
from crypta.lib.python.custom_ml.tools.training_utils import get_item_with_max_value

# Minimum per-row bio_male probability for the row to count as a male vote.
MALE_THRESHOLD = 0.9
# Minimum per-row bio_female probability for the row to count as a female vote
# (only checked when the row did not already count as a male vote).
FEMALE_THRESHOLD = 0.9
# A row casts an age vote only if its most probable age column strictly
# exceeds this probability.
AGE_THRESHOLD = 0.6
# The winning age segment must hold strictly more than this share of all age
# votes collected for a uuid before it is written out.
AGE_RATIO_THRESHOLD = 0.7

# Input probability column name -> value written to the output 'age_segment'.
# Iteration order matters: VoiceTablesMerger keeps the first column on ties.
age_segment_names = {
    'bio_child': '0_17',
    'bio_under35': '18_34',
    'bio_over35': '35_99',
}


@with_context
class VoiceTablesMerger(object):
    """YT reducer that merges voice socdem predictions for a single uuid.

    Rows for one ``uuid`` may come from several input tables. Each row casts
    at most one gender vote and at most one age vote. The votes are then
    aggregated into at most one socdem-storage row.
    """

    def __init__(self, table_indexes_to_timestamps):
        """
        :param table_indexes_to_timestamps: mapping from input table index
            (as reported by ``context.table_index``) to the timestamp that
            should be reported as ``update_time`` for rows of that table.
        """
        self.table_indexes_to_timestamps = table_indexes_to_timestamps

    def __call__(self, key, rows, context):
        """Aggregate all rows of one reduce key and yield at most one row.

        :param key: reduce key; only ``key['uuid']`` is read.
        :param rows: input rows sharing that key; each must provide
            ``bio_male`` and ``bio_female``, and optionally the age columns
            listed in ``age_segment_names``.
        :param context: YT reducer context; ``context.table_index`` is read
            for every row.
        :yields: a dict with ``gender`` and/or ``age_segment`` plus
            ``update_time``, ``id``, ``source`` and ``id_type``. Nothing is
            yielded when no confident prediction was made or the uuid is empty.
        """
        gender_counter = Counter()
        age_counter = Counter()
        update_time = None
        for row in rows:
            # table_index is read per row because rows of one key can come
            # from different input tables.
            table_timestamp = self.table_indexes_to_timestamps[context.table_index]
            if update_time is None or table_timestamp > update_time:
                update_time = table_timestamp

            if row['bio_male'] >= MALE_THRESHOLD:
                gender_counter['m'] += 1
            elif row['bio_female'] >= FEMALE_THRESHOLD:
                gender_counter['f'] += 1

            if 'bio_child' in row:
                # Pick the most probable age column of this row; with the
                # strict '>' the first column wins on equal probabilities.
                # .items() (not the Python-2-only iteritems) keeps this
                # runnable on both Python 2 and 3 with the same order.
                max_age_segment, max_age_segment_probability = None, 0
                for field_name, segment_name in age_segment_names.items():
                    if row[field_name] > max_age_segment_probability:
                        max_age_segment = segment_name
                        max_age_segment_probability = row[field_name]

                if max_age_segment_probability > AGE_THRESHOLD:
                    age_counter[max_age_segment] += 1

        output_row = {}
        # A gender is emitted only with a strict 2:1 majority of votes.
        if gender_counter['m'] > (gender_counter['f'] * 2):
            output_row['gender'] = 'm'
        elif gender_counter['f'] > (gender_counter['m'] * 2):
            output_row['gender'] = 'f'

        if age_counter:
            max_age_segment, max_age_segment_count = get_item_with_max_value(age_counter)
            total_count = sum(age_counter.values())
            # float() keeps this a true division on Python 2.
            if float(max_age_segment_count) / total_count > AGE_RATIO_THRESHOLD:
                output_row['age_segment'] = max_age_segment

        if output_row and key['uuid']:
            output_row['update_time'] = update_time
            output_row['id'] = key['uuid']
            output_row['source'] = 'voice'
            output_row['id_type'] = 'uuid'
            yield output_row


class GetVoiceSocdem(BaseYtTask):
    """Build the uuid-keyed 'voice' table in the socdem storage.

    Reduces the daily tables found under ``config.VOICE_SOCDEM_SOURCE`` with
    ``VoiceTablesMerger``. Completion is tracked through the
    ``last_processed_date`` attribute of the output table.
    """
    date = luigi.Parameter()
    juggler_host = config.CRYPTA_ML_JUGGLER_HOST
    task_group = 'import_socdem_data'

    def __init__(self, date):
        super(GetVoiceSocdem, self).__init__(date)
        # Lexicographically greatest node name in the source directory.
        # This assumes node names are sortable date strings; an empty
        # directory raises IndexError here.
        self.last_processed_date = sorted(self.yt.list(config.VOICE_SOCDEM_SOURCE))[-1]

    def requires(self):
        """Return an ExternalInput for every existing source table among the
        dates produced by ``generate_back_dates(last_processed_date, 14)``,
        in the order that helper yields them."""
        candidate_paths = (
            os.path.join(config.VOICE_SOCDEM_SOURCE, day)
            for day in date_helpers.generate_back_dates(self.last_processed_date, 14)
        )
        return [ExternalInput(path) for path in candidate_paths if self.yt.exists(path)]

    def output(self):
        # The task counts as complete once the output table's
        # 'last_processed_date' attribute equals self.last_processed_date.
        return YtTableAttributeTarget(
            os.path.join(config.SOCDEM_STORAGE_YT_DIR, 'uuid', 'voice'),
            'last_processed_date',
            self.last_processed_date,
        )

    def run(self):
        """Within one YT transaction, do the following:
        - recreate the output table with the socdem storage schema;
        - reduce the inputs by uuid;
        - sort the result by id;
        - set the generate_date and last_processed_date attributes.
        """
        input_tables = [task_input.table for task_input in self.input()]
        # The reducer only sees a table index per row. Map each index to the
        # timestamp parsed from that table's base name, which is a date string.
        table_indexes_to_timestamps = {
            index: date_helpers.from_date_string_to_timestamp(os.path.basename(input_table))
            for index, input_table in enumerate(input_tables)
        }

        with self.yt.Transaction():
            # Build the target once instead of re-creating it for every call.
            output_table = self.output().table

            self.yt.create_empty_table(
                output_table,
                schema=socdem_storage_schema,
            )

            # NOTE(review): reduce assumes every source table is sorted by uuid.
            self.yt.run_reduce(
                VoiceTablesMerger(table_indexes_to_timestamps=table_indexes_to_timestamps),
                input_tables,
                output_table,
                reduce_by='uuid',
            )

            self.yt.run_sort(
                output_table,
                sort_by='id',
            )

            self.yt.set_attribute(output_table, 'generate_date', self.date)
            # Written last so that output() reports completion only after
            # the table has been fully rebuilt.
            self.yt.set_attribute(output_table, 'last_processed_date', self.last_processed_date)
