#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging

from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.lib.python.yt import schema_utils
from crypta.lookalike.lib.python.utils import (
    fields,
    utils,
)
from crypta.lookalike.lib.python.utils.config import (
    config,
    environment,
)


logger = logging.getLogger(__name__)

# YQL query template that (re)builds the negatives table.
#
# Inner-joins the ranked user data with the flattened segment ranks on
# `row_rank`. Every user whose `row_rank` also appears in the flattened
# segment-ranks table is emitted as a (yandexuid, GroupID) pair.
# `WITH TRUNCATE` overwrites any previous contents of the output table.
#
# Placeholders (filled via str.format):
#   negatives_table                              -- output table path
#   ranked_user_data_yandexuids_table            -- source read for `yandexuid` and `row_rank`
#   segments_ranks_for_negatives_flattened_table -- source read for `GroupID` and `row_rank`
make_negatives_query_template = """
INSERT INTO `{negatives_table}`
WITH TRUNCATE

SELECT
    user_ranks.yandexuid AS yandexuid,
    segments_ranks.GroupID AS GroupID
FROM `{ranked_user_data_yandexuids_table}` AS user_ranks
INNER JOIN `{segments_ranks_for_negatives_flattened_table}` AS segments_ranks
USING(row_rank);
"""


def flatten_segments_ranks_mapper(row):
    """Expand one segment's rank interval into one output row per rank.

    The input row holds ``fields.min_rank`` (inclusive) and ``fields.max_rank``
    (exclusive). For every rank in that half-open interval, a fresh dict is
    yielded with the segment's ``fields.group_id`` and the rank stored under
    ``fields.row_rank``. Nothing is yielded when the interval is empty.
    """
    group_key = fields.group_id
    rank_key = fields.row_rank
    lower_bound = row[fields.min_rank]
    upper_bound = row[fields.max_rank]
    for rank in range(lower_bound, upper_bound):
        flattened = {group_key: row[group_key]}
        flattened[rank_key] = rank
        yield flattened


def _build_segments_ranks(rows):
    """Lay segments out back to back on a shared rank axis.

    The first row gets ``[0, ids_cnt_0)``, the next
    ``[ids_cnt_0, ids_cnt_0 + ids_cnt_1)``, and so on, in iteration order.

    :param rows: iterable of rows carrying ``fields.group_id`` and ``'ids_cnt'``.
    :return: list of dicts with ``fields.group_id``, ``fields.min_rank``
        (inclusive) and ``fields.max_rank`` (exclusive).
    """
    segments_ranks = []
    offset = 0
    for row in rows:
        ids_cnt = int(row['ids_cnt'])
        segments_ranks.append({
            fields.group_id: row[fields.group_id],
            fields.min_rank: offset,
            fields.max_rank: offset + ids_cnt,
        })
        offset += ids_cnt
    return segments_ranks


def make(nv_params):
    """Build the LaL negatives table.

    The function runs in four steps:

    1. Each row of ``config.TRAIN_SEGMENTS_WITH_COUNTS_TABLE`` is assigned a
       contiguous rank interval of length ``ids_cnt``.
    2. The intervals are flattened by a YT map into one row per
       (group_id, row_rank).
    3. A YQL query joins the flattened ranks with
       ``config.RANKED_USER_DATA_YANDEXUIDS_TABLE`` on ``row_rank``.
    4. The query writes (yandexuid, GroupID) pairs into
       ``config.NEGATIVES_TABLE``, truncating it first.

    :param nv_params: Nirvana parameters used to build the YT and YQL clients.
    :raises AssertionError: in the production environment, if the negatives
        table ends up with more rows than the ranked user data table.
    """
    yt_client = utils.get_yt_client(nv_params=nv_params)
    yql_client = utils.get_yql_client(nv_params=nv_params)

    segments_ranks = _build_segments_ranks(
        yt_client.read_table(config.TRAIN_SEGMENTS_WITH_COUNTS_TABLE)
    )

    segment_ranks_schema = schema_utils.yt_schema_from_dict({
        fields.group_id: 'string',
        fields.min_rank: 'uint64',
        fields.max_rank: 'uint64',
    })
    ranks_flattened_schema = schema_utils.yt_schema_from_dict({
        fields.group_id: 'string',
        fields.row_rank: 'uint64',
    })

    # Context managers are entered left to right and exited in reverse order,
    # so both temp tables are created after, and released before, the
    # transaction context is left.
    with NirvanaTransaction(yt_client) as transaction, \
            yt_client.TempTable(attributes={'schema': segment_ranks_schema}) as segments_ranks_table, \
            yt_client.TempTable(attributes={'schema': ranks_flattened_schema}) as segments_ranks_flattened_table:

        yt_client.write_table(segments_ranks_table, segments_ranks)

        yt_client.run_map(
            flatten_segments_ranks_mapper,
            segments_ranks_table,
            segments_ranks_flattened_table,
            spec={
                'title': 'Flatten segments ranks for LaL negatives',
                # Each input row expands into (max_rank - min_rank) output rows,
                # so a small input slice per job presumably keeps the fan-out
                # spread across many jobs.
                'max_data_size_per_job': 128 * 1024,
            },
        )

        yql_client.execute(
            query=make_negatives_query_template.format(
                ranked_user_data_yandexuids_table=config.RANKED_USER_DATA_YANDEXUIDS_TABLE,
                segments_ranks_for_negatives_flattened_table=segments_ranks_flattened_table,
                negatives_table=config.NEGATIVES_TABLE,
            ),
            transaction=str(transaction.transaction_id),
            title='YQL LaL make negatives',
        )

    if environment.environment == 'production':
        user_data_cnt = yt_client.row_count(config.RANKED_USER_DATA_YANDEXUIDS_TABLE)
        negatives_cnt = yt_client.row_count(config.NEGATIVES_TABLE)
        # Explicit check instead of a bare `assert`: asserts are stripped under
        # `python -O`, which would silently disable this production safeguard.
        if negatives_cnt > user_data_cnt:
            raise AssertionError('Negatives should be less or equal size as UserData')
    logger.info('Successfully made negatives')
