#!/usr/bin/env python
# -*- coding: utf-8 -*-

from functools import partial
import json
import logging

import numpy as np

from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.lib.python.yt import yt_helpers
from crypta.lookalike.lib.python.utils import (
    fields,
    mobile_utils,
)
from crypta.lookalike.lib.python.utils.mobile_config import config as mobile_config
from crypta.lookalike.lib.python.utils.utils import get_feature_name

from yt.wrapper import create_table_switch

logger = logging.getLogger(__name__)

# Output-table indexes passed to create_table_switch() inside make_user_features().
# They must match the order of the output tables given to run_reduce() in get():
# the train table is listed first, the validation table second.
TRAIN_TABLE_INDEX = 0
VALIDATION_TABLE_INDEX = 1


def get_sampling_ratios_for_apps(yt, path, max_rows_count):
    """Compute a per-app sampling ratio for the train set.

    The base ratio is ``max_rows_count / total``, where ``total`` is the sum of
    ``fields.devids_cnt`` over the table. An app keeps the base ratio when that
    still leaves it at least ``mobile_config.MIN_ROWS_FOR_APP`` rows; otherwise
    its ratio is raised to ``MIN_ROWS_FOR_APP / devids_cnt``, capped at 1.

    Args:
        yt: client object exposing ``read_table(path)``.
        path: table whose rows contain ``fields.app_id``, ``fields.id_type``
            and ``fields.devids_cnt``.
        max_rows_count: desired total number of sampled rows.

    Returns:
        dict mapping ``get_feature_name(app_id, id_type)`` to a float ratio.
        The base ratio is not capped, so values may exceed 1 when
        ``max_rows_count`` is larger than the raw total. The dict is empty
        when the table has no rows.
    """
    # Read the table once and keep only what the ratio computation needs.
    apps = [
        (get_feature_name(row[fields.app_id], row[fields.id_type]), row[fields.devids_cnt])
        for row in yt.read_table(path)
    ]
    rows_raw_count = sum(devids_cnt for _, devids_cnt in apps)

    if rows_raw_count > 0:
        # float() keeps true division even for int operands under Python 2.
        sampling_ratio = float(max_rows_count) / rows_raw_count
    else:
        # Nothing to downsample; avoid ZeroDivisionError on an empty table.
        logger.warning('Table {} has no rows to sample from'.format(path))
        sampling_ratio = 1.

    min_rows_for_app = float(mobile_config.MIN_ROWS_FOR_APP)

    sampling_dict = {}
    rows_count = 0
    for key, devids_cnt in apps:
        if devids_cnt <= 0:
            # No devices for this app: the ratio is irrelevant, keep everything.
            ratio = 1.
        elif sampling_ratio * devids_cnt >= min_rows_for_app:
            ratio = sampling_ratio
        else:
            # Small app: boost the ratio so it keeps MIN_ROWS_FOR_APP rows,
            # or all of its rows if it has fewer than that.
            ratio = min(min_rows_for_app / devids_cnt, 1.)
        sampling_dict[key] = ratio
        rows_count += devids_cnt * ratio

    logger.info('Raw number of rows is {}'.format(rows_raw_count))
    logger.info('Sampled number of rows is {}'.format(rows_count))

    return sampling_dict


def make_user_features(key, rows, train_apps_sampling, feature_to_idx, category_to_vector):
    """Reducer producing train and validation feature rows for one user.

    Rows for which ``mobile_utils.get_app2vec`` returns ``None`` are ignored.
    For the remaining rows the store features and app2vec vectors are stacked,
    with an all-zero placeholder at index 0 that stands for "no target app".

    Output, each record preceded by a table switch:
      * one validation record (target ``'no_target_app'``) listing all of the
        user's apps;
      * one train record per app whose feature name is in
        ``train_apps_sampling``, with that app removed from
        ``fields.installed_apps``. Apps whose ratio is below 1 are kept with
        probability equal to the ratio.

    Args:
        key: reduce key; ``fields.device_id`` and ``fields.id_type`` are read.
        rows: iterable of rows sharing ``key``.
        train_apps_sampling: dict of feature name -> sampling ratio.
        feature_to_idx: mapping passed through to
            ``mobile_utils.aggregate_features``; its length is the width of
            the store-features matrix.
        category_to_vector: mapping passed to ``mobile_utils.get_app2vec``.

    Yields:
        Table-switch objects and output record dicts.
    """
    # Collect per-app rows in lists and concatenate once after the loop:
    # np.append copies the whole accumulated array on every call.
    features_rows = [np.zeros((1, len(feature_to_idx)))]
    app2vec_rows = [np.zeros((1, mobile_config.EMBEDDING_FEATURES_SIZE))]
    user_apps = set()
    # (app, row index into the stacked arrays, sampling ratio or None)
    target_apps = [('no_target_app', 0, None)]
    categories = [None]

    for row in rows:
        app2vec = mobile_utils.get_app2vec(row, category_to_vector)
        if app2vec is None:
            continue

        # Index this app will occupy in the stacked arrays and in `categories`.
        app_idx = len(categories)

        app2vec_rows.append(app2vec.reshape(1, -1))
        features_rows.append(np.array(row[fields.app_features_from_stores]).reshape(1, -1))
        user_apps.add(row[fields.app_id])
        categories.append(row[fields.category])

        key_name = get_feature_name(row[fields.app_id], key[fields.id_type])
        if key_name in train_apps_sampling:
            target_apps.append((row[fields.app_id], app_idx, train_apps_sampling[key_name]))

    apps_features = np.concatenate(features_rows, axis=0)
    apps_app2vec = np.concatenate(app2vec_rows, axis=0)
    sum_features = np.sum(apps_features, axis=0)
    sum_app2vec = np.sum(apps_app2vec, axis=0)

    for app, app_idx, sampling_ratio in target_apps:
        is_train_row = app_idx != 0

        if is_train_row:
            # Downsample only when the ratio is below 1.
            if sampling_ratio < 1 and np.random.random_sample() > sampling_ratio:
                continue
            number_of_apps = max(len(user_apps) - 1, 1)
            # Hold the target app out of the installed-apps list for this record.
            user_apps.remove(app)
            yield create_table_switch(TRAIN_TABLE_INDEX)
        else:
            number_of_apps = max(len(user_apps), 1)
            yield create_table_switch(VALIDATION_TABLE_INDEX)

        # NOTE(review): aggregate_features presumably excludes the row at
        # app_idx from the sums and normalises by number_of_apps - confirm.
        user_apps_features, user_apps_vector_features = mobile_utils.aggregate_features(
            feature_to_idx, app_idx, apps_features, sum_features, apps_app2vec, sum_app2vec, number_of_apps,
        )

        yield {
            fields.device_id: key[fields.device_id],
            fields.id_type: key[fields.id_type],
            fields.target_app: app,
            fields.category: categories[app_idx],
            fields.installed_apps: ' '.join(map(str, user_apps)),
            fields.user_apps_features_from_stores: ','.join(map(str, user_apps_features)),
            fields.user_apps_vector_features: ','.join(map(str, user_apps_vector_features)),
        }

        if is_train_row:
            # Restore the held-out app for the next target.
            user_apps.add(app)


def get(nv_params, output):
    """Build the train and validation user-feature tables and report their size.

    Steps, all inside one Nirvana transaction with a temporary YT table:
      1. run the YQL query that fills the temporary (user, app, app features) table;
      2. compute per-app sampling ratios and load the category vectors and the
         categorical features dictionary;
      3. recreate the two output tables and run the ``make_user_features`` reduce;
      4. write the ``train_rows_cnt`` counter to the stats table;
      5. dump ``{'ParamsSize': ...}`` as JSON to ``output``.

    Args:
        nv_params: parameters used to create the YT/YQL clients and to get the date.
        output: path of the local file that receives the JSON with the params size.
    """
    yt_client = mobile_utils.get_yt_client(nv_params=nv_params)
    yql_client = mobile_utils.get_yql_client(nv_params=nv_params)

    with NirvanaTransaction(yt_client) as transaction, \
            yt_client.TempTable() as user_app_table:
        user_app_query = mobile_utils.user_app_query.format(
            devid_by_app_table=mobile_config.DEVID_BY_APP_WITH_CRYPTA_ID,
            app2vec=mobile_config.APP2VEC_TABLE,
            app_features_table=mobile_config.APPS_FEATURES_FROM_STORES,
            merged_stores=mobile_config.MERGED_STORES,
            output_table=user_app_table,
        )
        yql_client.execute(
            query=user_app_query,
            transaction=str(transaction.transaction_id),
            title='YQL get (user, app, app_features) table',
        )

        sampling_dict = get_sampling_ratios_for_apps(
            yt_client, mobile_config.TRAIN_APPS_TABLE, mobile_config.MAX_TRAIN_ROWS,
        )
        logger.info('Sampling ratios are calculated.')

        category_to_vector = mobile_utils.get_category_to_vector_dict(
            yt_client,
            mobile_config.CATEGORY2VEC_TABLE,
            fields.category,
            fields.vector,
        )
        feature_to_id = mobile_utils.get_features_dict(yt_client, mobile_config.CATEGORICAL_FEATURES)

        # Order matters: positions correspond to TRAIN_TABLE_INDEX / VALIDATION_TABLE_INDEX.
        output_tables = [
            mobile_config.USERS_TRAIN_FEATURES_MOBILE,
            mobile_config.USERS_VALIDATION_FEATURES_MOBILE,
        ]

        # Both output tables share the same all-string schema and are recreated from scratch.
        for output_table in output_tables:
            yt_helpers.create_empty_table(
                yt_client=yt_client,
                path=output_table,
                schema={
                    fields.device_id: 'string',
                    fields.id_type: 'string',
                    fields.target_app: 'string',
                    fields.category: 'string',
                    fields.installed_apps: 'string',
                    fields.user_apps_features_from_stores: 'string',
                    fields.user_apps_vector_features: 'string',
                },
                force=True,
            )

        reducer = partial(
            make_user_features,
            train_apps_sampling=sampling_dict,
            feature_to_idx=feature_to_id,
            category_to_vector=category_to_vector,
        )
        yt_client.run_reduce(
            reducer,
            user_app_table,
            output_tables,
            reduce_by=['id', 'id_type'],
            spec={
                'mapper': {'memory_limit': int(1e9)},
                'reducer': {'memory_limit': int(1e9)},
            },
        )

        yt_helpers.write_stats_to_yt(
            yt_client=yt_client,
            table_path=mobile_config.DATALENS_MOBILE_LAL_COUNTERS_TABLE,
            data_to_write={
                'counter_name': 'train_rows_cnt',
                'count': yt_client.row_count(mobile_config.USERS_TRAIN_FEATURES_MOBILE),
            },
            schema={
                'fielddate': 'string',
                'counter_name': 'string',
                'count': 'uint64',
            },
            date=mobile_utils.get_date_from_nv_parameters(nv_params=nv_params),
        )

        with open(output, 'w') as output_file:
            params_size = len(feature_to_id) + mobile_config.ADDITIONAL_USER_FEAURES_CNT
            json.dump({'ParamsSize': params_size}, output_file)

        logger.info('Successfully dumped features size')
