# -*- coding: utf-8 -*-
import os
from functools import partial

import luigi

from crypta.profile.lib import vector_helpers

from crypta.profile.utils.config import config
from crypta.profile.utils.luigi_utils import BaseYtTask, YtDailyRewritableTarget, ExternalInput, ExternalInputDate

# Source YT tables, read as external inputs by ScoreUsersForTelephony:
#   * NEW_TELEPHONY_USERS has scalar `email` / `phone` columns (used directly in the query);
#   * SPRAV_USERS has `permalink` plus list-valued `email` / `phone` columns (FLATTEN LIST BY).
NEW_TELEPHONY_USERS = '//home/telephony/mbi/active_user_yt_for_crypta_new_telephony'
SPRAV_USERS = '//home/telephony/mbi/active_user_sprav_for_crypta_v2'

# YQL script (filled in with str.format, so literal braces would need doubling) that:
#   1. validates and MD5-hashes emails/phones from both source tables via the Identifiers UDF;
#   2. maps the hashes to crypta_ids by joining {vertices_no_multi_profile_table} on (id, id_type);
#   3. writes (crypta_id, vector, permalink) rows, one per distinct (crypta_id, permalink) pair,
#      to {orgs_to_score_table};
#   4. writes (crypta_id, vector) rows, one per distinct telephony crypta_id, to {sample_table}.
# Vectors come from an inner join with {crypta_id_vectors_table} on crypta_id, so users without
# a vector are dropped from both outputs.
prepare_vectors_query = """
$normalized_sprav_emails = (
    SELECT permalink, Identifiers::HashMd5Email(email) AS id, 'email_md5' AS id_type
    FROM (
        SELECT permalink, email
        FROM `{sprav_users_table}`
        FLATTEN LIST BY email
    )
    WHERE Identifiers::IsValidEmail(email)
);

$normalized_sprav_phones = (
    SELECT permalink, Identifiers::HashMd5Phone(phone) AS id, 'phone_md5' AS id_type
    FROM (
        SELECT permalink, phone
        FROM `{sprav_users_table}`
        FLATTEN LIST BY phone
    )
    WHERE Identifiers::IsValidPhone(phone)
);

$normalized_telephony_emails = (
    SELECT Identifiers::HashMd5Email(email) AS id, 'email_md5' AS id_type
    FROM `{telephony_users_table}`
    WHERE Identifiers::IsValidEmail(email)
);

$normalized_telephony_phones = (
    SELECT Identifiers::HashMd5Phone(phone) AS id, 'phone_md5' AS id_type
    FROM `{telephony_users_table}`
    WHERE Identifiers::IsValidPhone(phone)
);

$active_sprav_users = (
    SELECT crypta_id, permalink
    FROM (
        SELECT t1.permalink AS permalink, CAST(t2.cryptaId AS Uint64) AS crypta_id
        FROM (
            SELECT *
            FROM $normalized_sprav_emails
            UNION ALL
            SELECT *
            FROM $normalized_sprav_phones
        ) AS t1
        INNER JOIN `{vertices_no_multi_profile_table}` AS t2
        USING (id, id_type)
    )
    GROUP BY crypta_id, permalink
);

$active_telephony_users = (
    SELECT DISTINCT crypta_id
    FROM (
        SELECT CAST(t2.cryptaId AS Uint64) AS crypta_id
        FROM (
            SELECT *
            FROM $normalized_telephony_emails
            UNION ALL
            SELECT *
            FROM $normalized_telephony_phones
        ) AS t1
        INNER JOIN `{vertices_no_multi_profile_table}` AS t2
        USING (id, id_type)
    )
);

INSERT INTO `{orgs_to_score_table}` WITH TRUNCATE
SELECT users.crypta_id AS crypta_id, vectors.vector AS vector, users.permalink AS permalink
FROM $active_sprav_users AS users
INNER JOIN `{crypta_id_vectors_table}` AS vectors
USING (crypta_id);

INSERT INTO `{sample_table}` WITH TRUNCATE
SELECT users.crypta_id AS crypta_id, vectors.vector AS vector
FROM $active_telephony_users AS users
INNER JOIN `{crypta_id_vectors_table}` AS vectors
USING (crypta_id)
"""


def calculate_score_mapper(row, sample_vector):
    """Emit a row's permalink with the dot product of its vector and ``sample_vector``.

    The row's features are decoded by ``vector_helpers.vector_row_to_features``.
    The result row has two fields: ``permalink`` (copied from the input row) and
    ``similarity`` (the dot product as a Python float).

    numpy is imported inside the function, presumably so that it is resolved on
    the YT worker that runs the mapper — TODO confirm.
    """
    import numpy as np

    features = vector_helpers.vector_row_to_features(row)
    permalink = row['permalink']
    score = np.dot(features, sample_vector)
    yield dict(permalink=permalink, similarity=float(score))


def max_reducer(key, rows):
    """Collapse all rows of one reduce key into a single row with the top similarity.

    The output row is a copy of ``key`` extended with ``similarity`` set to the
    largest ``similarity`` found among ``rows``. An empty ``rows`` raises
    ``ValueError`` (from ``max``).
    """
    similarities = (record['similarity'] for record in rows)
    yield dict(key, similarity=max(similarities))


class ScoreUsersForTelephony(BaseYtTask):
    """Score permalinks from the sprav users table by similarity to telephony users.

    Steps performed by ``run``:
      1. ``prepare_vectors_query`` fills two temporary tables with crypta_id
         vectors: one per (crypta_id, permalink) pair for sprav users, and one
         per telephony user (the "sample").
      2. The sample vectors are summed and normalized into a single vector.
      3. A map-reduce computes, for every permalink, the maximum dot product
         between that vector and the vectors of the permalink's crypta_ids.

    The result is written to ``telephony/user_scores`` under
    ``config.PROFILES_EXPORT_YT_DIRECTORY``, with schema
    (permalink: int64, similarity: double) and a ``generate_date`` attribute.
    """

    date = luigi.Parameter()
    task_group = 'coded_segments'

    def requires(self):
        return {
            'telephony_users': ExternalInput(NEW_TELEPHONY_USERS),
            'sprav_users': ExternalInput(SPRAV_USERS),
            'crypta_id_vectors': ExternalInputDate(config.MONTHLY_CRYPTAID2VEC, date=self.date),
        }

    def output(self):
        return YtDailyRewritableTarget(
            os.path.join(config.PROFILES_EXPORT_YT_DIRECTORY, 'telephony/user_scores'),
            self.date,
        )

    def run(self):
        output_table = self.output().table

        with self.yt.Transaction() as transaction, \
                self.yt.TempTable() as orgs_to_score_table, \
                self.yt.TempTable() as sample_table:
            # Fill both temporary tables with crypta_id vectors in one YQL run.
            self.yql.query(
                query_string=prepare_vectors_query.format(
                    sprav_users_table=self.input()['sprav_users'].table,
                    telephony_users_table=self.input()['telephony_users'].table,
                    vertices_no_multi_profile_table=config.VERTICES_NO_MULTI_PROFILE,
                    crypta_id_vectors_table=config.MONTHLY_CRYPTAID2VEC,
                    orgs_to_score_table=orgs_to_score_table,
                    sample_table=sample_table,
                ),
                transaction=transaction,
                udf_resource_dict={'libcrypta_identifier_udf.so': config.CRYPTA_IDENTIFIERS_UDF_RESOURCE},
            )

            normalized_sample_vector = vector_helpers.normalize(
                self._sum_table_vectors(sample_table),
            )

            self.yt.create_empty_table(
                output_table,
                schema={
                    'permalink': 'int64',
                    'similarity': 'double',
                },
            )

            # The mapper scores every (crypta_id, permalink) row; the reducer keeps
            # the best score per permalink.
            self.yt.run_map_reduce(
                partial(calculate_score_mapper, sample_vector=normalized_sample_vector),
                max_reducer,
                orgs_to_score_table,
                output_table,
                reduce_by='permalink',
            )

            # Set once. The attribute was previously written twice with the same value.
            self.yt.set_attribute(
                output_table,
                'generate_date',
                self.date,
            )

    def _sum_table_vectors(self, table):
        """Return the element-wise sum of the binary ``vector`` column of ``table``.

        Rows are decoded with ``vector_helpers.binary_to_numpy`` and accumulated
        while streaming, so the encoded vectors are never all held in memory.

        Raises:
            ValueError: if ``table`` has no rows, so no sample vector can be built.
        """
        total = None
        for row in self.yt.read_table(table):
            vector = vector_helpers.binary_to_numpy(row['vector'])
            # Out-of-place addition so the decoded array is never mutated; this
            # stays safe even if binary_to_numpy returns a read-only buffer view.
            total = vector if total is None else total + vector

        if total is None:
            raise ValueError(
                'Sample table {} is empty: no telephony users were matched to crypta_id '
                'vectors, so there is nothing to score against'.format(table)
            )
        return total
