# -*- coding: utf-8 -*-

import os

import luigi

from crypta.profile.lib import date_helpers
from crypta.profile.utils.config import config
from crypta.profile.utils.loggers import TimeTracker
from crypta.profile.utils.luigi_utils import ExternalInput, BaseYtTask, YtDailyRewritableTarget
from crypta.profile.utils.socdem import socdem_storage_schema


# Minimum predicted probability at which EmailTfGender labels a row as
# male ('m') or female ('f'). The male threshold is checked first; rows that
# reach neither threshold are not emitted.
MALE_THRESHOLD = 0.92
FEMALE_THRESHOLD = 0.88

# YQL query template that applies the email gender TensorFlow model.
# It is filled in with str.format using the keys `input_table`,
# `output_table` and `update_time`.
#
# Steps, as written in the query:
#   * attaches the model file as 'email_gender.pb' and initialises a
#     TensorFlow session from it;
#   * cuts `id` (NULL is treated as "") at the position of "@", using 20
#     characters when that position is -1 or greater than 20, splits the
#     result into single characters and maps each one to a float code through
#     $char_table (characters that are not in $chars are mapped to 0.0f);
#   * runs TensorFlow::RunBatch over the encoded rows, carrying the source row
#     along as PassThrough;
#   * truncates the output table and writes `gender` as a serialized Yson
#     dict with keys 'm' and 'f' (falling back to 0.5 when the network output
#     is missing), `source` = 'email_tf', `id`, `id_type` and `update_time`,
#     ordered by id, id_type.
#
# Literal YQL braces are doubled ({{ }}) because the template goes through
# str.format; backslashes are doubled because this is a non-raw Python string.
#
# NOTE(review): '(' is listed in $chars but has no entry in $char_table, while
# ':' has a $char_table entry but is absent from $chars - confirm that this
# matches the alphabet the model was trained with.
tf_model_application_query = """
PRAGMA yt.DataSizePerJob = "128M";
PRAGMA file(
    'email_gender.pb',
    'https://proxy.sandbox.yandex-team.ru/last/CRYPTA_EMAIL_GENDER_MODEL'
);

$session = TensorFlow::InitSession(FilePath("email_gender.pb"));

$char_table = AsDict(
    AsTuple('a', 48.0f),  AsTuple('b', 34.0f),  AsTuple('c', 20.0f),
    AsTuple('d', 7.0f),   AsTuple('e', 49.0f),  AsTuple('f', 35.0f),
    AsTuple('g', 21.0f),  AsTuple('h', 8.0f),   AsTuple('i', 50.0f),
    AsTuple('j', 36.0f),  AsTuple('k', 22.0f),  AsTuple('l', 9.0f),
    AsTuple('m', 51.0f),  AsTuple('n', 37.0f),  AsTuple('o', 23.0f),
    AsTuple('p', 10.0f),  AsTuple('q', 52.0f),  AsTuple('r', 38.0f),
    AsTuple('s', 24.0f),  AsTuple('t', 11.0f),  AsTuple('u', 53.0f),
    AsTuple('v', 39.0f),  AsTuple('w', 25.0f),  AsTuple('x', 12.0f),
    AsTuple('y', 54.0f),  AsTuple('z', 40.0f),  AsTuple('!', 41.0f),
    AsTuple('#', 13.0f),  AsTuple('$', 0.0f),   AsTuple('%', 42.0f),
    AsTuple('&', 26.0f),  AsTuple('\\'', 14.0f), AsTuple(')', 43.0f),
    AsTuple('*', 27.0f),  AsTuple('+', 15.0f),  AsTuple('-', 44.0f),
    AsTuple('.', 28.0f),  AsTuple('0', 1.0f),   AsTuple('1', 45.0f),
    AsTuple('2', 29.0f),  AsTuple('3', 16.0f),  AsTuple('4', 2.0f),
    AsTuple('5', 46.0f),  AsTuple('6', 30.0f),  AsTuple('7', 17.0f),
    AsTuple('8', 3.0f),   AsTuple('9', 47.0f),  AsTuple(':', 31.0f),
    AsTuple('<', 4.0f),   AsTuple('>', 32.0f),  AsTuple('?', 18.0f),
    AsTuple('\\\\', 5.0f), AsTuple('^', 33.0f),  AsTuple('_', 19.0f),
    AsTuple('`', 6.0f)
);

$chars = ToSet(AsList(
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j',
    'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't',
    'u', 'v', 'w', 'x', 'y', 'z',
    '!', '#', '$', '%', '^', '&', '*', '(', ')', '<',
    '>', '`', '-', '+', '_', '\\'', '?', '.', '\\\\'
));

$char_to_float = ($char) -> {{
    RETURN IF($char IN $chars, $char_table[$char], 0.0f);
}};

$splitter = Re2::FindAndConsume(@@(.)@@);

$login_to_float_vector = ($email) -> {{
    $cut = Find($email, "@");
    $cut = IF($cut == -1 OR $cut > 20, 20, CAST($cut AS Uint32) ?? 0u);
    $email = SUBSTRING($email, 0, $cut);
    $email_chars = $splitter($email);
    RETURN ListFlatMap($email_chars, $char_to_float);
}};

$input = (
    SELECT
        $login_to_float_vector(id ?? "") AS login_input_1,
        TableRow() AS PassThrough
    FROM `{input_table}`
);

$processed = (PROCESS $input
USING TensorFlow::RunBatch(
    $session,
    TableRows(),
    Struct<network_output:List<Float>>,
    ListCreate(String),
    128,
    20
));

INSERT INTO `{output_table}`
WITH TRUNCATE

SELECT
    Yson::Serialize(Yson::FromDoubleDict(AsDict(
        AsTuple('m', network_output[0] ?? 0.5f),
        AsTuple('f', 1. - network_output[0] ?? 0.5f),
    ))) AS gender,
    'email_tf' AS source,
    PassThrough.id AS id,
    PassThrough.id_type AS id_type,
    CAST({update_time} AS Uint64) AS update_time
FROM $processed
ORDER BY id, id_type;

"""


class GetEmailTFGenderPredictions(BaseYtTask):
    """Run the email gender TensorFlow model query for a given date.

    Reads the table behind ``config.EMAILS_TO_CLASSIFY`` and writes the
    query result to the ``email/email_tf`` target under
    ``config.SOCDEM_RAW_STORAGE_YT_DIR``.
    """

    # Date string; it is converted to the ``update_time`` timestamp in ``run``
    # and stored as the ``generate_date`` attribute of the output table.
    date = luigi.Parameter()
    juggler_host = config.CRYPTA_ML_JUGGLER_HOST
    task_group = 'import_socdem_data'

    def requires(self):
        """Depend on the externally produced table of emails to classify."""
        return ExternalInput(config.EMAILS_TO_CLASSIFY)

    def output(self):
        """Return the daily target that holds the raw model predictions."""
        target_path = os.path.join(config.SOCDEM_RAW_STORAGE_YT_DIR, 'email', 'email_tf')
        return YtDailyRewritableTarget(target_path, self.date)

    def run(self):
        """Format and execute the YQL query, then stamp the output table."""
        self.yt.config['spec_defaults']['pool'] = config.SEGMENTS_POOL

        with TimeTracker(monitoring_name=self.__class__.__name__):
            # The timestamp is resolved first, before the table paths are read.
            update_time = date_helpers.from_utc_date_string_to_timestamp(self.date)
            rendered_query = tf_model_application_query.format(
                input_table=self.input().table,
                output_table=self.output().table,
                update_time=update_time,
            )

            self.yql.query(query_string=rendered_query)

            self.yt.set_attribute(self.output().table, 'generate_date', self.date)


class EmailTfGender(BaseYtTask):
    """Turn raw gender probabilities into gender labels for a given date.

    Takes the output of ``GetEmailTFGenderPredictions`` and keeps only rows
    whose probability reaches ``MALE_THRESHOLD`` or ``FEMALE_THRESHOLD``.
    The result is written, sorted by ``id``, to the ``email/email_tf`` target
    under ``config.SOCDEM_STORAGE_YT_DIR``.
    """

    # Date string; it is passed to the upstream task and the output target and
    # stored as the ``generate_date`` attribute of the output table.
    date = luigi.Parameter()
    juggler_host = config.CRYPTA_ML_JUGGLER_HOST
    task_group = 'import_socdem_data'

    def requires(self):
        """Depend on the raw model predictions for the same date."""
        return GetEmailTFGenderPredictions(self.date)

    def output(self):
        """Return the daily target that holds the thresholded gender labels."""
        target_path = os.path.join(config.SOCDEM_STORAGE_YT_DIR, 'email', 'email_tf')
        return YtDailyRewritableTarget(target_path, self.date)

    def run(self):
        """Map, sort and stamp the output table inside one YT transaction."""
        self.yt.config['spec_defaults']['pool'] = config.SEGMENTS_POOL

        def gender_mapper(row):
            """Replace the probability dict with 'm' or 'f', or drop the row."""
            probabilities = row['gender']
            # The male threshold is checked first; the female probability is
            # only consulted when the male one falls short.
            if probabilities['m'] >= MALE_THRESHOLD:
                label = 'm'
            elif probabilities['f'] >= FEMALE_THRESHOLD:
                label = 'f'
            else:
                # Neither threshold reached: emit nothing for this row.
                return
            row['gender'] = label
            yield row

        with self.yt.Transaction():
            self.yt.create_empty_table(self.output().table, schema=socdem_storage_schema)
            self.yt.run_map(gender_mapper, self.input().table, self.output().table)
            self.yt.run_sort(self.output().table, sort_by='id')
            self.yt.set_attribute(self.output().table, 'generate_date', self.date)
