#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging

import numpy as np
import pandas as pd
from sklearn.cluster import KMeans

from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.lib.python.yt import yt_helpers
from crypta.lookalike.lib.python.utils import (
    fields,
    mobile_utils,
    utils,
    yt_schemas,
)
from crypta.lookalike.lib.python.utils.mobile_config import config as mobile_config

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def calculate_centroid_distance(row):
    """Return the distance between an app vector and the centroid vector of its cluster.

    Both vectors are read from ``row`` (``fields.centroid_vector`` and ``fields.app_vector``),
    passed through ``utils.normalize`` and the result is one minus their dot product.
    NOTE(review): this is a cosine distance provided ``utils.normalize`` scales to unit L2 norm - confirm.
    """
    unit_centroid = utils.normalize(row[fields.centroid_vector])
    unit_app = utils.normalize(row[fields.app_vector])
    similarity = np.dot(unit_centroid, unit_app)
    return 1.0 - similarity


def get_clusters_info(new_clustering_df, store_id, old_centroids_bundle_ids):
    """Describe every cluster of a fresh clustering and measure how far apps moved.

    For each distinct label in ``new_clustering_df[fields.cluster_id]`` the function:
      * normalizes the app vectors of the cluster and computes their normalized mean vector;
      * keeps the ids of up to ``mobile_config.CENTROIDS_TO_KEEP`` apps closest to that mean;
      * takes the closest app vector as the centroid vector of the cluster;
      * counts how many apps of the cluster are listed in ``old_centroids_bundle_ids[cluster]``
        and divides that count by ``mobile_config.CENTROIDS_TO_KEEP``.

    It also adds a ``fields.old_cluster_distance`` column to ``new_clustering_df`` (in place):
    the distance between the new centroid of the cluster whose label equals the app's
    ``fields.old_cluster_id`` and the new centroid of the app's current cluster. The value is
    NaN when the app has no old cluster id or that id is not among the new labels.

    Args:
        new_clustering_df: DataFrame with ``fields.cluster_id``, ``fields.app_id``,
            ``fields.app_vector`` and ``fields.old_cluster_id`` columns.
        store_id: value stored in the ``fields.store_id`` column of every cluster record.
        old_centroids_bundle_ids: sequence indexed by cluster label, each item being an
            iterable of app ids.
            NOTE(review): presumably these are the previous centroid apps of the cluster - confirm.

    Returns:
        Tuple ``(clusters_info_df, new_clustering_df)``: one row per cluster with its
        description and mean old-cluster distance, and the input frame with the added column.
    """
    clustering_idxs = new_clustering_df[fields.cluster_id].values
    clusters = np.unique(clustering_idxs)
    # Labels may have gaps (a cluster can end up without apps), so pairwise distances are
    # looked up through an explicit label -> position mapping instead of using labels as positions.
    cluster_positions = {int(cluster): position for position, cluster in enumerate(clusters)}

    clusters_info = []
    new_centroids_vectors = []
    for cluster in clusters:
        cluster_idxs = clustering_idxs == cluster
        apps_ids = new_clustering_df[fields.app_id][cluster_idxs].values

        # A list comprehension rather than a bare map(): on Python 3 np.array(map(...)) would
        # wrap the iterator into a 0-d object array instead of building a matrix.
        vectors = np.array([
            utils.normalize(vector) for vector in new_clustering_df[fields.app_vector][cluster_idxs]
        ])
        mean_vector = utils.normalize(vectors.mean(axis=0))
        distances = 1.0 - np.dot(vectors, mean_vector)
        centroid_ids = apps_ids[distances.argsort()[:mobile_config.CENTROIDS_TO_KEEP]]

        old_centroids_ratio = len(set(apps_ids) & set(old_centroids_bundle_ids[cluster])) \
            / float(mobile_config.CENTROIDS_TO_KEEP)

        centroid_vector = vectors[np.argmin(distances)]
        clusters_info.append({
            fields.cluster_id: cluster,
            fields.mean_vector: mean_vector,
            fields.centroid_vector: centroid_vector,
            fields.centroids_bundle_ids: centroid_ids,
            fields.store_id: store_id,
            fields.old_centroids_ratio: old_centroids_ratio,
        })
        new_centroids_vectors.append(centroid_vector)

    # Pairwise distances between the new centroids; row/column order follows ``clusters``.
    centroids_matrix = np.array(new_centroids_vectors)
    clusters_distances = 1.0 - np.dot(centroids_matrix, centroids_matrix.T)

    def get_old_cluster_distance(row):
        old_cluster_id = row[fields.old_cluster_id]
        # pd.isnull covers both NaN (left-merge miss) and None, unlike np.isnan.
        if pd.isnull(old_cluster_id):
            return np.nan
        old_position = cluster_positions.get(int(old_cluster_id))
        if old_position is None:
            return np.nan
        return clusters_distances[old_position, cluster_positions[int(row[fields.cluster_id])]]

    new_clustering_df[fields.old_cluster_distance] = new_clustering_df.apply(
        get_old_cluster_distance,
        axis=1,
    ).values

    clusters_info_df = pd.DataFrame.from_records(clusters_info)
    # NaN never compares equal to anything (``x != np.nan`` is always True), so missing
    # distances have to be filtered with notna().
    known_distances_df = new_clustering_df[new_clustering_df[fields.old_cluster_distance].notna()]
    mean_old_distances = known_distances_df.groupby(fields.cluster_id)[fields.old_cluster_distance].mean()
    # The means are indexed by cluster label while clusters_info_df has a positional index;
    # map through the label column so every cluster gets its own value (NaN if it has none).
    clusters_info_df[fields.old_cluster_distance] = clusters_info_df[fields.cluster_id].map(mean_old_distances)

    return clusters_info_df, new_clustering_df


def update(nv_params):
    """Recompute the apps clustering for every store and publish the results to YT.

    Inside a single ``NirvanaTransaction`` the function:
      1. reads the current ``mobile_config.APPS_CLUSTERING`` table, then recreates it and
         ``mobile_config.CLUSTERS_INFO`` as empty tables;
      2. reads ``mobile_config.APP_DSSM_VECTORS`` and attaches each app's previous cluster id;
      3. for every store, runs KMeans initialized with the vectors stored in
         ``mobile_config.CLUSTER_CENTROIDS_VECTORS`` and appends the cluster descriptions and
         per-app rows to the recreated tables;
      4. sorts the clustering table and writes the collected metrics to
         ``mobile_config.DATALENS_MOBILE_LAL_METRICS_TABLE``.

    A store with no app vectors or no stored centroids is logged and skipped, so nothing is
    written for it.

    Args:
        nv_params: Nirvana parameters, passed to ``utils.get_yt_client`` and
            ``mobile_utils.get_date_from_nv_parameters``.
    """
    yt_client = utils.get_yt_client(nv_params=nv_params)
    today = mobile_utils.get_date_from_nv_parameters(nv_params)
    metrics = {
        'date': today,
    }

    with NirvanaTransaction(yt_client):
        # The previous clustering must be read before its table is recreated below.
        old_clustering_df = pd.DataFrame(yt_client.read_table(mobile_config.APPS_CLUSTERING))

        yt_helpers.create_empty_table(
            yt_client=yt_client,
            path=mobile_config.APPS_CLUSTERING,
            schema=yt_schemas.get_apps_clustering_schema(),
            additional_attributes={'optimize_for': 'scan'},
            force=True,
        )

        yt_helpers.create_empty_table(
            yt_client=yt_client,
            path=mobile_config.CLUSTERS_INFO,
            schema=yt_schemas.get_clusters_info_schema(),
            additional_attributes={'optimize_for': 'scan'},
            force=True,
        )

        app_dssm_vectors_df = pd.DataFrame(yt_client.read_table(mobile_config.APP_DSSM_VECTORS))
        app_dssm_vectors_df[fields.bundle_id] = app_dssm_vectors_df[fields.app_id]
        app_dssm_vectors_df[fields.store_id] = app_dssm_vectors_df[fields.id_type].apply(
            lambda id_type: fields.id_type_to_store_id[id_type],
        )

        # Left merge: apps absent from the previous clustering get NaN as their old cluster id.
        app_dssm_vectors_df = app_dssm_vectors_df.merge(
            old_clustering_df[[fields.store_id, fields.bundle_id, fields.cluster_id]],
            how='left',
            on=[fields.store_id, fields.bundle_id],
        )
        # Keep the merged-in label aside; fields.cluster_id is overwritten with new labels below.
        app_dssm_vectors_df[fields.old_cluster_id] = app_dssm_vectors_df[fields.cluster_id]

        old_centroids_df = pd.DataFrame(yt_client.read_table(mobile_config.CLUSTER_CENTROIDS_VECTORS))

        for store_type in fields.id_type_to_store_id.values():
            store_df = app_dssm_vectors_df[app_dssm_vectors_df[fields.store_id] == store_type].reset_index()
            centroids_df = old_centroids_df[old_centroids_df[fields.store_id] == store_type]

            if not len(store_df):
                logger.info('Empty apps vectors table for {} store'.format(store_type))
                continue

            centroids_vectors = centroids_df[fields.centroid_vector].to_list()
            old_centroids_bundle_ids = centroids_df[fields.old_bundle_ids].to_list()

            if not len(centroids_vectors):
                logger.info('No stored centroids for {} store'.format(store_type))
                continue

            # Materialize the normalized vectors: a bare map() is a lazy iterator on Python 3
            # and cannot be consumed by KMeans.fit as a sample matrix.
            normalized_app_vectors = np.array([
                utils.normalize(vector) for vector in store_df[fields.app_vector]
            ])
            k_means = KMeans(
                n_clusters=len(centroids_vectors),
                init=np.array(centroids_vectors),
                # Initial centers are given explicitly, so a single initialization is all that can run.
                n_init=1,
            ).fit(normalized_app_vectors)
            store_df[fields.cluster_id] = k_means.labels_

            store_clusters_info, store_df = get_clusters_info(store_df, store_type, old_centroids_bundle_ids)

            yt_client.write_table(
                yt_client.TablePath(mobile_config.CLUSTERS_INFO, append=True),
                store_clusters_info.to_dict('records'),
            )

            metrics.update({
                store_type + '_mean_' + fields.old_centroids_ratio:
                    np.mean(store_clusters_info[fields.old_centroids_ratio]),
                store_type + '_mean_' + fields.old_cluster_distance:
                    np.mean(store_clusters_info[fields.old_cluster_distance]),
            })

            # Attach each cluster's centroid vector to its apps and compute the per-app distance;
            # the left merge keeps the row order of store_df, so .values lines up with it.
            store_df[fields.centroid_distance] = store_df.merge(
                store_clusters_info[[fields.cluster_id, fields.centroid_vector]],
                how='left',
                on=fields.cluster_id,
            ).apply(calculate_centroid_distance, axis=1).values

            yt_client.write_table(
                yt_client.TablePath(mobile_config.APPS_CLUSTERING, append=True),
                store_df[[
                    fields.bundle_id,
                    fields.cluster_id,
                    fields.MD5Hash,
                    fields.store_id,
                    fields.centroid_distance,
                    fields.old_cluster_distance,
                ]].to_dict('records'),
            )

            logger.info('Clusters are computed for {} apps'.format(store_type))

        yt_client.run_sort(
            mobile_config.APPS_CLUSTERING,
            sort_by=[
                fields.store_id,
                fields.cluster_id,
                fields.centroid_distance,
            ],
        )

        yt_helpers.write_stats_to_yt(
            yt_client=yt_client,
            table_path=mobile_config.DATALENS_MOBILE_LAL_METRICS_TABLE,
            data_to_write=metrics,
            schema={
                'date': 'string',
                'google_play_mean_old_centroids_ratio': 'double',
                'google_play_mean_old_cluster_distance': 'double',
                'itunes_mean_old_centroids_ratio': 'double',
                'itunes_mean_old_cluster_distance': 'double',
            },
            fielddate='date',
        )

    logger.info('Full apps clustering is performed')
