#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os

import luigi

from crypta.profile.utils.config import config
from crypta.profile.utils.luigi_utils import (
    BaseYtTask,
    YtDailyRewritableTarget,
)
from crypta.profile.tasks.monitoring.validation_by_sources.validate_gender_age_profiles import (
    ValidateGenderAgeProfilesBySources,
)
from crypta.profile.tasks.monitoring.validation_by_sources.validate_income_profiles import (
    ValidateIncomeProfilesBySources,
)


# YQL template that counts how many rows fall into each
# (predicted label, validation label) pair of a sample table -- the cells of
# a confusion matrix.
#
# Format placeholders:
#   input_table       - YT path of the table holding both label columns
#   output_table      - YT path of the destination table; its previous
#                       contents are replaced (WITH TRUNCATE)
#   prediction_column - name of the column with the predicted label
#   validation_column - name of the column with the validation label
#
# Rows where either label is NULL are skipped. The result has the columns
# predicted_label, validation_label and cnt, sorted by both labels.
compute_matrix_values_query = """
INSERT INTO `{output_table}`
WITH TRUNCATE

SELECT
    {prediction_column} AS predicted_label,
    {validation_column} AS validation_label,
    COUNT(*) AS cnt
FROM `{input_table}`
WHERE {prediction_column} IS NOT NULL AND {validation_column} IS NOT NULL
GROUP BY {prediction_column}, {validation_column}
ORDER BY predicted_label, validation_label;
"""


class ComputeConfusionMatrix(BaseYtTask):
    """Compute per-source confusion-matrix counts for socdem validation samples.

    For every source listed by the upstream validation tasks, this task:

    * runs ``compute_matrix_values_query`` over that source's
      ``<SOCDEM_VALIDATION_DIR>/<source>/<date>/russia/sample`` table;
    * writes the (predicted label, validation label, count) rows to
      ``<output_dir>/<source>``.

    Gender/age samples are compared on the age columns. Income samples are
    compared on the income columns. After the queries finish, every output
    table gets a ``generate_date`` attribute equal to ``date``.
    """

    date = luigi.Parameter()
    juggler_host = config.CRYPTA_ML_JUGGLER_HOST
    task_group = 'validate_socdem_profiles'
    output_dir = os.path.join(config.SOCDEM_VALIDATION_DIR, 'confusion_matrix')

    # One entry per upstream requirement:
    # (requires() key, prediction column, validation column).
    # Sources are processed in this order: gender/age first, then income.
    _confusion_matrix_specs = (
        ('gender_age_samples', 'predicted_age', 'validation_age'),
        ('income_samples', 'predicted_income', 'validation_income'),
    )

    def requires(self):
        """Return the upstream validation tasks, keyed by sample kind."""
        gender_age_task = ValidateGenderAgeProfilesBySources(self.date)
        income_task = ValidateIncomeProfilesBySources(self.date)
        return {
            'gender_age_samples': gender_age_task,
            'income_samples': income_task,
        }

    def _iter_confusion_matrix_specs(self):
        """Yield ``(source, prediction_column, validation_column)`` per input source.

        Sources are obtained by iterating each upstream task's output. The
        upstream tasks are visited in the order of ``_confusion_matrix_specs``.
        """
        upstream = self.input()
        for requirement, prediction_column, validation_column in self._confusion_matrix_specs:
            for source in upstream[requirement]:
                yield source, prediction_column, validation_column

    def output(self):
        """Map each source name to the daily-rewritable target for its matrix table.

        If the same source name came from both upstream tasks, the later
        (income) entry replaces the earlier one under that key.
        """
        return {
            source: YtDailyRewritableTarget(os.path.join(self.output_dir, source), self.date)
            for source, _, _ in self._iter_confusion_matrix_specs()
        }

    def run(self):
        """Run one confusion-matrix query per source inside a single YT transaction."""
        with self.yt.Transaction() as transaction:
            for source, prediction_column, validation_column in self._iter_confusion_matrix_specs():
                sample_table = os.path.join(
                    config.SOCDEM_VALIDATION_DIR, source, self.date, 'russia', 'sample',
                )
                matrix_table = os.path.join(self.output_dir, source)
                query = compute_matrix_values_query.format(
                    input_table=sample_table,
                    output_table=matrix_table,
                    prediction_column=prediction_column,
                    validation_column=validation_column,
                )
                self.yql.query(query_string=query, transaction=transaction)

        # The output tables are stamped with the processed date only after
        # the transaction block has exited.
        for target in self.output().values():
            self.yt.set_attribute(target.table, 'generate_date', self.date)
