import logging
import six

from yt import yson

from crypta.lab.proto.lookalike_pb2 import TLookalikeMapping
from crypta.lib.proto.user_data.user_data_stats_pb2 import TUserDataStats
from crypta.lib.python import (
    templater,
    native_yt,
)
from crypta.lib.python.native_operations import lab
from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.lib.python.yt import yt_helpers
from crypta.lookalike.lib.python.utils import utils as lal_utils
from crypta.lookalike.lib.python.utils.config import config
from crypta.lookalike.services.custom_lookalike.lib import sample
from crypta.siberia.bin.common.yt_describer.proto import group_stats_pb2


logger = logging.getLogger(__name__)

# YQL template (rendered by `filter_top` via `templater.render_template`) that
# trims `lookalike_table` in place. It overwrites the table (WITH TRUNCATE) with
# a filtered selection from itself.
#   * Rows are ranked within each GroupID by Score, highest first (ROW_NUMBER).
#   * Each row is joined to the `sizes_table` row with the same GroupID to get
#     its Size. With `LEFT JOIN ANY`, only one sizes row is used per match.
#   * A row is kept if its rank is <= Size, or if its group has no sizes row at
#     all (output_size IS NULL), in which case the whole group is kept.
# Template variables: `lookalike_table`, `sizes_table`.
# NOTE(review): ROW_NUMBER() breaks Score ties arbitrarily, so which of several
# equally-scored rows survive at the cut-off is not deterministic.
filter_top_query = '''
INSERT INTO `{{ lookalike_table }}` WITH TRUNCATE
SELECT
    GroupID,
    Score,
    Yandexuid,
FROM (
    SELECT
        input.*,
        ROW_NUMBER() OVER w AS rank_in_group,
        sizes.Size AS output_size
    FROM `{{ lookalike_table }}` AS input
    LEFT JOIN ANY `{{ sizes_table }}` AS sizes
    ON input.GroupID == sizes.GroupID
    WINDOW w AS (
        PARTITION BY input.GroupID
        ORDER BY input.Score DESC
    )
)
WHERE output_size IS NULL
    OR rank_in_group <= output_size;
'''


def _get_bytes_or_empty(row, key):
    """Return ``row[key]`` as raw bytes, or empty bytes when the key is missing/empty.

    The empty-bytes fallback is applied *before* ``yson.get_bytes`` so that a
    missing column (``None``) is never passed to it.
    """
    return yson.get_bytes(row.get(key) or six.ensure_binary(''))


def prepare_stats(_stats):
    """Build a ``TUserDataStats`` proto from one row of a user data stats table.

    :param _stats: dict-like row. The columns 'Affinities', 'Attributes',
        'Identifiers', 'Stratum', 'Distributions', 'Counts' and 'SegmentInfo'
        are parsed into the sub-message of the same name; 'GroupID' is stored
        as bytes. A missing or empty column yields an empty sub-message
        (or empty GroupID) instead of an error.
    :return: the populated ``TUserDataStats`` message.
    """
    proto_stats = TUserDataStats()

    for field_name in (
        'Affinities',
        'Attributes',
        'Identifiers',
        'Stratum',
        'Distributions',
        'Counts',
        'SegmentInfo',
    ):
        # Column names match the sub-message field names one-to-one.
        getattr(proto_stats, field_name).ParseFromString(_get_bytes_or_empty(_stats, field_name))

    proto_stats.GroupID = _get_bytes_or_empty(_stats, 'GroupID')

    return proto_stats


def get_mapping_with_global_stats(yt_client, max_filter_error_rate=1e-3):
    """Create a ``TLookalikeMapping`` pre-filled with the global user data stats.

    Only the first row of ``config.USER_DATA_STATS_TABLE`` is read; it is
    converted with ``prepare_stats`` and copied into ``GlobalUserDataStats``.

    :param yt_client: YT client used to read the stats table.
    :param max_filter_error_rate: value written to ``MaxFilterErrorRate``
        (defaults to the previously hard-coded 1e-3).
    :return: a new ``TLookalikeMapping`` with the global stats and the
        filter error rate set; no segments are added here.
    :raises ValueError: if the stats table contains no rows.
    """
    stats_table = config.USER_DATA_STATS_TABLE
    first_row = next(iter(yt_client.read_table(stats_table)), None)
    if first_row is None:
        # Previously an empty table leaked out as a bare, message-less StopIteration.
        raise ValueError('Global user data stats table is empty: {}'.format(stats_table))

    mapping = TLookalikeMapping()
    mapping.GlobalUserDataStats.CopyFrom(prepare_stats(first_row))
    mapping.MaxFilterErrorRate = max_filter_error_rate

    return mapping


def add_segment_stats(mapping, group_id, stats, output_size):
    """Register one segment in ``mapping`` with its user data stats and sizes.

    :param mapping: ``TLookalikeMapping`` to update; the entry
        ``mapping.Segments[group_id]`` is filled in.
    :param group_id: key of the segment inside ``mapping.Segments``.
    :param stats: serialized ``TGroupStats.Stats`` message (anything accepted
        by ``yson.get_bytes``).
    :param output_size: value written to ``Options.Counts.Output``;
        ``Options.Counts.Input`` is always set to 0.
    """
    wrapper = group_stats_pb2.TGroupStats()
    wrapper.Stats.ParseFromString(yson.get_bytes(stats))

    segment = mapping.Segments[group_id]
    segment.UserDataStats.CopyFrom(wrapper.Stats)

    counts = segment.Options.Counts
    counts.Input = 0
    counts.Output = output_size


def transform_predictor_output(row):
    """YT mapper: convert one predictor output row into a lookalike table row.

    Emits exactly one row with 'GroupID' and 'Yandexuid' copied through and
    'Score' set to the negation of the input 'MinusScore'. Other input columns
    are dropped.
    """
    # NOTE(review): the predictor presumably stores negated scores so that an
    # ascending sort puts the best users first; confirm against TPredictMapper.
    output_row = dict(
        GroupID=row['GroupID'],
        Score=-row['MinusScore'],
        Yandexuid=row['Yandexuid'],
    )
    yield output_row


def get_scores_by_description(
    yt_client,
    description_table,
    lookalike_table,
    model_dir,
    transaction,
):
    """Score user embeddings against every described segment into ``lookalike_table``.

    Steps:
      1. Resolve the user embeddings table and DSSM model files from
         ``model_dir``, or via ``get_last_version_of_dssm_entities`` when
         ``model_dir`` is None.
      2. Build a ``TLookalikeMapping`` with the global stats plus one segment
         per row of ``description_table``. Each segment's output size is set
         to the row count of the embeddings table.
      3. Run the native ``lab.TPredictMapper`` over the embeddings into a
         temporary table, inside ``transaction``, with an 8 GiB mapper
         memory limit.
      4. Create ``lookalike_table`` (GroupID/Score/Yandexuid schema) with
         ``yt_helpers.create_empty_table`` and fill it by mapping the predictor
         output through ``transform_predictor_output``.

    :param yt_client: YT client; its 'token', 'pool' and 'proxy'/'url' config
        values are forwarded to the native map.
    :param description_table: table whose rows carry a group id column
        (``lal_utils.fields.group_id``) and a serialized 'Stats' column.
    :param lookalike_table: destination table path.
    :param model_dir: directory with DSSM entities, or None for the latest version.
    :param transaction: object exposing ``transaction_id`` for the native map.
    """
    logger.info('Getting lookalike by description')

    dssm_entities = (
        lal_utils.get_last_version_of_dssm_entities(yt_client)
        if model_dir is None
        else lal_utils.get_dssm_entities_from_dir(model_dir)
    )
    user_embeddings_table, dssm_files = dssm_entities

    embeddings_count = yt_client.row_count(user_embeddings_table)

    mapping = get_mapping_with_global_stats(yt_client)
    for description_row in yt_client.read_table(description_table):
        add_segment_stats(
            mapping=mapping,
            group_id=description_row[lal_utils.fields.group_id],
            stats=description_row['Stats'],
            output_size=embeddings_count,
        )
    serialized_mapping = mapping.SerializeToString()

    mapper_memory_limit = 8 * 1024 ** 3  # 8 GiB
    lookalike_schema = {
        'GroupID': 'string',
        'Score': 'double',
        'Yandexuid': 'string',
    }

    with yt_client.TempTable() as prediction_output:
        native_yt.run_native_map(
            mapper_name=lab.TPredictMapper,
            source=user_embeddings_table,
            destination=prediction_output,
            state=serialized_mapping,
            mapper_files=dssm_files,
            token=yt_client.config['token'],
            pool=yt_client.config['pool'],
            proxy=yt_client.config['proxy']['url'],
            title='Custom lookalike: TPredictMapper',
            transaction=transaction.transaction_id,
            spec={'mapper': {'memory_limit': mapper_memory_limit}},
        )

        yt_helpers.create_empty_table(
            yt_client=yt_client,
            path=lookalike_table,
            schema=lookalike_schema,
        )

        yt_client.run_map(transform_predictor_output, prediction_output, lookalike_table)


def filter_top(yql_client, lookalike_table, sizes_table, transaction):
    """Trim ``lookalike_table`` in place to the top-scored rows of each group.

    Renders ``filter_top_query`` with the two table paths and executes it
    through ``yql_client`` inside ``transaction``. Per the query, each GroupID
    keeps at most the ``Size`` found for it in ``sizes_table``; groups missing
    from ``sizes_table`` are left whole.

    :param yql_client: client with an ``execute(query, transaction, title)`` method.
    :param lookalike_table: table path that is both read and overwritten.
    :param sizes_table: table path with GroupID -> Size rows.
    :param transaction: object exposing ``transaction_id``; passed as a string.
    """
    query = templater.render_template(
        filter_top_query,
        vars=dict(lookalike_table=lookalike_table, sizes_table=sizes_table),
    )
    yql_client.execute(
        query=query,
        transaction=str(transaction.transaction_id),
        title='YQL Custom lookalike: Filter top',
    )


def get_by_sample(
    yt_client,
    yql_client,
    sample_table,
    lookalike_table,
    sizes_table=None,
    model_dir=None,
):
    """Run the custom lookalike pipeline for a sample of users.

    Inside one ``NirvanaTransaction``:
      1. ``sample.describe`` writes a description of ``sample_table`` into a
         temporary table.
      2. ``get_scores_by_description`` fills ``lookalike_table`` from that
         description.
      3. If ``sizes_table`` is given, ``filter_top`` trims every group of
         ``lookalike_table`` to its requested size.

    :param model_dir: optional DSSM model directory; None means latest version.
    """
    # NOTE(review): whether the temp table is created inside the Nirvana
    # transaction depends on how NirvanaTransaction configures yt_client.
    with NirvanaTransaction(yt_client) as transaction:
        with yt_client.TempTable() as description_table:
            sample.describe(yt_client, sample_table, description_table, transaction)

            get_scores_by_description(
                yt_client=yt_client,
                description_table=description_table,
                lookalike_table=lookalike_table,
                model_dir=model_dir,
                transaction=transaction,
            )

            if sizes_table is None:
                return

            filter_top(yql_client, lookalike_table, sizes_table, transaction)
