#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow.keras import (
    backend,
    callbacks,
    layers,
    models,
)
from tensorflow.python.framework import graph_io, graph_util

from crypta.profile.utils import utils
from crypta.profile.utils.config import config
from crypta.profile.utils.yql_utils import Yql
from crypta.profile.utils.yt_utils import get_yt_client


# Module-level logger; it is also handed to the Yql helper and to the sandbox
# uploader used by EmailGenderTrainHelper below.
logger = logging.getLogger(__name__)


# YQL template that builds the training sample. It is rendered with str.format,
# so the placeholders are {input_table}, {output_table} and {sample_size}; the
# doubled braces in the UDF URL are str.format escapes for literal braces.
#
# What the query does:
#   * keeps rows with a non-NULL gender, id_type 'email' and a source other
#     than 'email_tf';
#   * normalizes the id with Identifiers::NormalizeEmail and drops rows where
#     normalization returned NULL;
#   * groups by normalized email and keeps only emails that have exactly one
#     distinct gender value;
#   * orders by a per-row random number and takes the first {sample_size} rows,
#     overwriting {output_table} (WITH TRUNCATE).
make_train_sample_query = """
PRAGMA File(
    'libcrypta_identifier_udf.so',
    'https://proxy.sandbox.yandex-team.ru/last/CRYPTA_IDENTIFIERS_UDF?attrs={{"released":"stable"}}&salt=1582281999'
);
PRAGMA Udf("libcrypta_identifier_udf.so");

$data = (
    SELECT
        gender,
        Identifiers::NormalizeEmail(id) AS normalized_email
    FROM `{input_table}`
    WHERE gender IS NOT NULL AND id_type == 'email' AND source != 'email_tf'
);

INSERT INTO `{output_table}`
WITH TRUNCATE

SELECT
    normalized_email,
    SOME(gender) AS gender,
    RANDOM(normalized_email) AS shuffling_number
FROM $data
WHERE normalized_email IS NOT NULL
GROUP BY normalized_email
HAVING LISTLENGTH(AGGREGATE_LIST_DISTINCT(gender)) == 1
ORDER BY shuffling_number
LIMIT {sample_size};

"""


class EmailGenderTrainHelper:
    """Train and export a character-level CNN that predicts gender from an email login.

    The helper prepares a training sample with YQL, streams it from YT in
    fixed-size batches, fits a small 1D-convolutional Keras model and exports
    the result as a frozen TensorFlow graph.
    """

    def __init__(self, embedding_size=32, max_seq_len=20, max_epochs=50, batch_size=500, possible_symbols=None):
        """Store hyper-parameters, build the symbol mapping and create YT/YQL clients.

        Args:
            embedding_size: dimensionality of the per-symbol embedding.
            max_seq_len: number of symbols of the login fed to the network;
                longer logins are truncated, shorter ones are zero-padded.
            max_epochs: upper bound on the number of training epochs.
            batch_size: number of rows in every generated batch.
            possible_symbols: string of symbols the network knows about; each
                symbol is mapped to its position in the string. A built-in
                alphabet is used when None.
        """
        self.embedding_size = embedding_size
        self.max_seq_len = max_seq_len
        self.max_epochs = max_epochs
        self.batch_size = batch_size

        if possible_symbols is None:
            possible_symbols = "$048<\\`dhlptx#'+37?_cgkosw&*.26:>^bfjnrvz!%)-159aeimquy"
        # NOTE(review): the first symbol gets index 0, which is also the value
        # used for padding in _encode_login, so the two are indistinguishable
        # to the network. Left as is because changing the mapping would change
        # the contract of already exported models.
        self.mapping = {symbol: idx for idx, symbol in enumerate(possible_symbols)}

        self.yt = get_yt_client()
        self.yql = Yql(logger=logger, yt=self.yt)

    def build_model(self):
        """Build the (uncompiled) Keras model.

        Architecture: embedding -> three blocks of (zero-pad 1, Conv1D with
        kernel 3, ReLU, max-pool) with 32/64/128 filters and pool sizes
        2/2/5 -> flatten -> single sigmoid unit.

        With the default max_seq_len=20 the temporal dimension shrinks
        20 -> 10 -> 5 -> 1 (assuming Keras' default 'valid' convolutions, which
        the explicit zero padding compensates for).

        Returns:
            keras Model with input 'login_input' of shape (max_seq_len,) and a
            single sigmoid output.
        """
        net_input = layers.Input(shape=(self.max_seq_len,), name='login_input')

        embedding = layers.Embedding(len(self.mapping), self.embedding_size, input_length=self.max_seq_len)(net_input)

        trigram = layers.ZeroPadding1D(padding=1)(embedding)
        trigram = layers.Conv1D(32, kernel_size=3)(trigram)
        trigram = layers.Activation('relu')(trigram)
        trigram = layers.MaxPooling1D(pool_size=2)(trigram)

        middle = layers.ZeroPadding1D(padding=1)(trigram)
        middle = layers.Conv1D(64, kernel_size=3)(middle)
        middle = layers.Activation('relu')(middle)
        middle = layers.MaxPooling1D(pool_size=2)(middle)

        top = layers.ZeroPadding1D(padding=1)(middle)
        top = layers.Conv1D(128, kernel_size=3)(top)
        top = layers.Activation('relu')(top)
        top = layers.MaxPooling1D(pool_size=5)(top)

        flat = layers.Flatten()(top)

        net_output = layers.Dense(1, activation='sigmoid')(flat)

        return models.Model(inputs=net_input, outputs=net_output)

    def _new_batch(self):
        """Allocate zero-filled (features, labels) arrays for one batch."""
        batch_x = np.zeros((self.batch_size, self.max_seq_len), dtype=np.float32)
        batch_y = np.zeros(self.batch_size, dtype=np.float32)
        return batch_x, batch_y

    def _encode_login(self, normalized_email):
        """Turn an email into a zero-padded vector of symbol indices.

        Only the part before '@' is used. Symbols missing from self.mapping
        are dropped, the rest is truncated to max_seq_len and right-padded
        with zeros.

        Returns:
            float32 array of length max_seq_len, or None when the email does
            not contain exactly one '@' or its login part is empty.
        """
        parts = normalized_email.split('@')
        if len(parts) != 2 or not parts[0]:
            return None
        idxs = np.asarray(
            [self.mapping[symbol] for symbol in parts[0] if symbol in self.mapping],
            dtype=np.float32,
        )[:self.max_seq_len]
        return np.pad(idxs, (0, self.max_seq_len - len(idxs)), 'constant')

    def batch_generator(self, table_path):
        """Endlessly yield (batch_x, batch_y) pairs read from a YT table.

        The table is re-read from the start every time it is exhausted. Rows
        whose email cannot be encoded are skipped. Every yielded batch holds
        exactly batch_size rows; an incomplete batch left at the end of a pass
        is discarded so that the batch shape stays fixed.

        A batch is yielded as soon as it is full, so a batch that ends exactly
        on the last table row is not lost and skipped rows can never produce
        an all-zero batch.

        Args:
            table_path: YT path (possibly with a row range) with
                'normalized_email' and 'gender' columns.

        Yields:
            tuple of float32 arrays: features of shape
            (batch_size, max_seq_len) and labels of shape (batch_size,), where
            the label is 1 for gender 'm' and 0 otherwise.

        Raises:
            ValueError: if a whole pass over the table did not produce a
                single full batch (otherwise the generator would re-read the
                table forever without yielding).
        """
        while True:
            produced_batch = False
            filled = 0
            batch_x, batch_y = self._new_batch()

            for row in self.yt.read_table(table_path, raw=False):
                encoded = self._encode_login(row['normalized_email'])
                if encoded is None:
                    continue
                batch_x[filled] = encoded
                batch_y[filled] = int(row['gender'] == 'm')
                filled += 1

                if filled == self.batch_size:
                    yield batch_x, batch_y
                    produced_batch = True
                    filled = 0
                    # Fresh arrays: the consumer may still hold the yielded ones.
                    batch_x, batch_y = self._new_batch()

            if not produced_batch:
                raise ValueError(
                    'Table {} has fewer than {} usable rows, cannot form a batch'.format(
                        table_path, self.batch_size,
                    )
                )

    def get_train_validation_generators(self, dataset_path, train_ratio=0.9):
        """Split the dataset by row index and return generators for both parts.

        The first int(train_ratio * row_count) rows form the train part, the
        remaining rows form the validation part.

        Args:
            dataset_path: YT path of the prepared sample table.
            train_ratio: share of rows that goes to the train part.

        Returns:
            tuple (train_generator, validation_generator) of batch generators.
        """
        row_count = self.yt.row_count(dataset_path)
        border = int(train_ratio * row_count)
        logger.info('Train rows number: {}\n Validation rows number: {}'.format(border, row_count - border))

        train_dataset_path = '{path}[#0:#{border}]'.format(path=dataset_path, border=border)
        logger.info('Train path: {}'.format(train_dataset_path))

        validation_dataset_path = '{path}[#{border}:]'.format(path=dataset_path, border=border)
        logger.info('Validation path: {}'.format(validation_dataset_path))

        return self.batch_generator(train_dataset_path), self.batch_generator(validation_dataset_path)

    def train_model(self, dataset_path):
        """Build, compile and fit the model; return the best checkpoint.

        Training monitors 'val_loss' with early stopping, learning-rate
        reduction on plateau and a best-only checkpoint written to a temporary
        file. The returned model is the one re-loaded from that checkpoint,
        not the in-memory model after the last epoch.

        Args:
            dataset_path: YT path of the prepared sample table.

        Returns:
            tuple (final_model, history) where history is the object returned
            by fit_generator.
        """
        model = self.build_model()
        model.summary(print_fn=logger.info)

        model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

        train_generator, val_generator = self.get_train_validation_generators(dataset_path)
        early_stopping = callbacks.EarlyStopping(
            monitor='val_loss',
            min_delta=1e-3,
            patience=5,
            verbose=2,
        )
        lr_reducer = callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.2,
            patience=5,
            verbose=2,
        )
        with tempfile.NamedTemporaryFile() as checkpoint:
            checkpointer = callbacks.ModelCheckpoint(
                checkpoint.name,
                monitor='val_loss',
                verbose=2,
                save_best_only=True,
                save_weights_only=False,
            )
            callback_list = [early_stopping, checkpointer, lr_reducer]
            history = model.fit_generator(
                train_generator,
                epochs=self.max_epochs,
                verbose=2,
                # Step counts are documented as integers; the generators are
                # endless, so they define the epoch length.
                steps_per_epoch=1000,
                validation_data=val_generator,
                validation_steps=1000,
                callbacks=callback_list,
            )

            # Must happen inside the `with`: the file is removed on exit.
            final_model = models.load_model(checkpoint.name)

        return final_model, history

    @staticmethod
    def extract_const_tf_graph(keras_model):
        """Freeze the model of the current Keras session into a constant graph.

        Adds an identity op named 'network_output' on top of the model output
        and converts the session graph variables to constants, keeping what is
        needed to compute that op. The name of the first node of the result is
        logged because the inference query needs it.

        Args:
            keras_model: model living in the current Keras backend session.

        Returns:
            GraphDef with variables replaced by constants.
        """
        backend.set_learning_phase(0)
        tf.identity(keras_model.output, name='network_output')
        session = backend.get_session()
        graph = session.graph
        const_graph = graph_util.convert_variables_to_constants(session, graph.as_graph_def(), ['network_output'])
        logger.info('First node name (for yql query): {}'.format(const_graph.node[0].name))

        return const_graph

    def get_train_sample(self, dataset_path, train_sample_size=1e7):
        """Run the sampling YQL query and write the result to dataset_path.

        Args:
            dataset_path: YT path of the output table (overwritten).
            train_sample_size: maximum number of rows in the sample; cast to
                int before being substituted into the query.
        """
        with self.yt.Transaction() as transaction:
            self.yql.query(
                query_string=make_train_sample_query.format(
                    input_table=config.SOCDEM_STORAGE_TABLE,
                    output_table=dataset_path,
                    sample_size=int(train_sample_size),
                ),
                transaction=transaction,
            )

    def nirvana_train_model(self, dataset_path, prepare_sample=False):
        """Full pipeline: optionally build the sample, train, freeze and upload.

        The frozen graph is written in binary form to a temporary file, which
        is then passed to utils.upload_to_sandbox as a
        CRYPTA_EMAIL_GENDER_MODEL resource.

        Args:
            dataset_path: YT path of the sample table.
            prepare_sample: when True, (re)build the sample table first.
        """
        if prepare_sample:
            self.get_train_sample(dataset_path)
        model, _ = self.train_model(dataset_path)
        const_graph = self.extract_const_tf_graph(model)
        with tempfile.NamedTemporaryFile() as graph_file:
            # graph_file.name is absolute, so the '.' logdir does not affect
            # where the graph is written.
            graph_io.write_graph(const_graph, '.', graph_file.name, as_text=False)
            utils.upload_to_sandbox(
                path=graph_file.name,
                resource_type='CRYPTA_EMAIL_GENDER_MODEL',
                description='tf model for gender by emails classification',
                owner='CRYPTA',
                attributes={'ttl': '365'},
                logger=logger,
            )
