import os
import logging

import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist

from crypta.lib.python import templater
from crypta.profile.lib import date_helpers
from crypta.profile.utils.config import config

logger = logging.getLogger(__name__)


# Jinja template (rendered through templater.render_template) of the YQL query
# that prepares the pool of vectors to sample from:
#   * joins `segment2vec_table` with `clustering_table` on the segment name
#     (restricted to `id_type` on both sides when `id_type` is truthy);
#   * adds `cluster_id` and a zero-based `index` that numbers the rows inside
#     every cluster (ROW_NUMBER() has no ORDER BY, so the numbering order is
#     whatever the engine produces);
#   * stores the result in `segment2vec_with_index_and_clusterid_table` and the
#     number of rows per cluster in `clusterid2size_table`.
# The window is partitioned by the configured cluster id column
# (`clustering_fields.cluster_id`), the same column that is selected as
# `cluster_id`, so the query does not depend on that column being literally
# named `cluster_id`.
prepare_template = """
$segment2vec_with_index_and_clusterid_table = (
    SELECT
        segment2vec_table.*,
        clustering_table.{{clustering_fields.cluster_id}} AS cluster_id,
        -1 + ROW_NUMBER() OVER w AS index,
    FROM `{{segment2vec_table}}` AS segment2vec_table
    INNER JOIN `{{clustering_table}}` AS clustering_table
    ON segment2vec_table.{{segment2vec_fields.name}} == clustering_table.{{clustering_fields.name}}
    {% if id_type %}
    WHERE
        segment2vec_table.{{segment2vec_fields.id_type}} == '{{id_type}}' AND
        clustering_table.{{clustering_fields.id_type}} == '{{id_type}}'
    {% endif %}
    WINDOW w AS (PARTITION BY clustering_table.{{clustering_fields.cluster_id}})
);

INSERT INTO `{{segment2vec_with_index_and_clusterid_table}}`
WITH TRUNCATE

SELECT *
FROM $segment2vec_with_index_and_clusterid_table;

INSERT INTO `{{clusterid2size_table}}`
WITH TRUNCATE

SELECT
    cluster_id,
    count(index) AS size,
FROM $segment2vec_with_index_and_clusterid_table
GROUP BY cluster_id
"""


# Jinja template (rendered through templater.render_template) of the YQL query
# that refreshes vectors of segments whose cluster changed noticeably:
#   * $clustering_distance_table - for every segment present in both the old
#     and the new clustering, the cosine distance between the centroid of its
#     old cluster and the centroid of its new cluster;
#   * $distance_threshold - the `percentile` percentile of those distances;
#   * $segments_to_update_table - segments whose distance is at or above the
#     threshold, or that have no row in `segment2vec_table`;
#   * $clusterid2size_table - cluster sizes from the prepared
#     `clusterid2size_table`, plus one row per entry of
#     `missed_clusterid_to_nearest_clusterid` redirecting a missed cluster id
#     to its nearest cluster id (and that cluster's size);
#   * $segments_to_update_table_with_index - a random `index` per segment,
#     drawn with np.random.choice(size) for the target cluster;
#   * $updated_segments_table - the vector of the prepared row with that
#     (cluster_id, index), read through the configured vector column
#     (`segment2vec_fields.vector`), because the prepared table carries the
#     segment2vec columns under their own names;
#   * finally `segment2vec_table` is rewritten: rows that were not updated are
#     kept, updated rows are appended, everything is ordered by name
#     (and id type when `id_type` is truthy).
update_segment2vec_table_template = """
$python_scripts = @@
import numpy as np


def cosine_distance(vector1, vector2):
    vector1, vector2 = np.array(list(vector1)), np.array(list(vector2))
    return 1 - (np.dot(vector1, vector2)) / (np.linalg.norm(vector1, ord=2) * np.linalg.norm(vector2, ord=2))


def get_random_index(number):
    return np.random.choice(number)
@@;

$cosine_distance = Python3::cosine_distance(Callable<(Yson?, Yson?)->Double?>, $python_scripts);
$get_random_index = Python3::get_random_index(Callable<(Uint64?)->Uint64?>, $python_scripts);


$clustering_distance_table = (
    SELECT
        new_clustering_table.{{clustering_fields.name}} AS segment_name,
        new_clustering_table.{{clustering_fields.cluster_id}} AS cluster_id,
        $cosine_distance(
            old_centroids_table.{{clustering_fields.vector}},
            new_centroids_table.{{clustering_fields.vector}},
        ) AS distance,
    FROM `{{old_clustering_table}}` AS old_clustering_table
    INNER JOIN `{{new_clustering_table}}` AS new_clustering_table
    ON old_clustering_table.{{clustering_fields.name}} == new_clustering_table.{{clustering_fields.name}}
    INNER JOIN `{{centroids_table}}` AS old_centroids_table
    ON old_clustering_table.{{clustering_fields.cluster_id}} == old_centroids_table.{{clustering_fields.cluster_id}}
    INNER JOIN `{{centroids_table}}` AS new_centroids_table
    ON new_clustering_table.{{clustering_fields.cluster_id}} == new_centroids_table.{{clustering_fields.cluster_id}}
    {% if id_type %}
    WHERE
        old_clustering_table.{{clustering_fields.id_type}} == '{{id_type}}' AND
        new_clustering_table.{{clustering_fields.id_type}} == '{{id_type}}' AND
        old_centroids_table.{{clustering_fields.id_type}} == '{{id_type}}' AND
        new_centroids_table.{{clustering_fields.id_type}} == '{{id_type}}'
    {% endif %}
);

$distance_threshold = (
    SELECT
        PERCENTILE(distance, {{percentile}}),
    FROM $clustering_distance_table
);

$segment2vec_table = (
    SELECT *
    FROM `{{segment2vec_table}}`
    {% if id_type %}
    WHERE `{{segment2vec_fields.id_type}}` == '{{id_type}}'
    {% endif %}
);

$segments_to_update_table = (
    SELECT
        clustering_distance_table.segment_name AS segment_name,
        clustering_distance_table.cluster_id AS cluster_id
    FROM $clustering_distance_table AS clustering_distance_table
    LEFT JOIN $segment2vec_table AS segment2vec_table
    ON clustering_distance_table.segment_name == segment2vec_table.{{segment2vec_fields.name}}
    WHERE clustering_distance_table.distance >= $distance_threshold OR
        segment2vec_table.{{segment2vec_fields.name}} IS NULL
);

$clusterid2size_table = (
    SELECT
        cluster_id AS cluster_id,
        cluster_id AS new_cluster_id,
        size,
    FROM `{{clusterid2size_table}}`
    {% for cluster_id, nearest_cluster_id in missed_clusterid_to_nearest_clusterid.items() %}
UNION ALL
    SELECT
        {{cluster_id}} AS cluster_id,
        {{nearest_cluster_id}} AS new_cluster_id,
        size,
    FROM `{{clusterid2size_table}}`
    WHERE cluster_id == {{nearest_cluster_id}}
    {% endfor %}
);

$segments_to_update_table_with_index = (
    SELECT
        segments_to_update_table.segment_name AS segment_name,
        clusterid2size_table.new_cluster_id AS cluster_id,
        $get_random_index(clusterid2size_table.size) AS index,
    FROM $segments_to_update_table AS segments_to_update_table
    INNER JOIN $clusterid2size_table AS clusterid2size_table
    USING (cluster_id)
);

$updated_segments_table = (
    SELECT
        segments_to_update_table_with_index.segment_name AS {{segment2vec_fields.name}},
        {% if id_type %}
        '{{id_type}}' AS {{segment2vec_fields.id_type}},
        {% endif %}
        segment2vec_with_index_and_clusterid_table.{{segment2vec_fields.vector}} AS {{segment2vec_fields.vector}},
    FROM $segments_to_update_table_with_index AS segments_to_update_table_with_index
    INNER JOIN `{{segment2vec_with_index_and_clusterid_table}}` AS segment2vec_with_index_and_clusterid_table
    USING (cluster_id, index)
);

INSERT INTO `{{segment2vec_table}}`
WITH TRUNCATE

SELECT *
FROM (
    SELECT
        segment2vec_table.{{segment2vec_fields.name}} AS {{segment2vec_fields.name}},
        {% if id_type %}
        segment2vec_table.{{segment2vec_fields.id_type}} AS {{segment2vec_fields.id_type}},
        {% endif %}
        segment2vec_table.{{segment2vec_fields.vector}} AS {{segment2vec_fields.vector}},
    FROM `{{segment2vec_table}}` AS segment2vec_table
    LEFT ONLY JOIN $updated_segments_table AS updated_segments_table
    USING ({{segment2vec_fields.name}} {% if id_type %}, {{segment2vec_fields.id_type}} {% endif %})

UNION ALL
    SELECT *
    FROM $updated_segments_table
)
ORDER BY {{segment2vec_fields.name}} {% if id_type %}, {{segment2vec_fields.id_type}} {% endif %};
"""


# YQL query, filled in with str.format (single-brace placeholders, not Jinja),
# that rebuilds the combined host/app vectors table:
#   * rows of `site2vec_table` become (host_app=host, id_type=Null, vector);
#   * rows of `app2vec_table` become (host_app=app, id_type, vector);
#   * the union replaces the contents of `site2vec_app2vec_table`
#     (INSERT ... WITH TRUNCATE), ordered by host_app.
update_site2vec_app2vec_table_query = """
$site2vec_app2vec_table = (
        SELECT
            host AS host_app,
            Null AS id_type,
            vector AS vector,
        FROM `{site2vec_table}`
    UNION ALL
        SELECT
            app AS host_app,
            id_type AS id_type,
            vector AS vector,
        FROM `{app2vec_table}`
);

INSERT INTO `{site2vec_app2vec_table}`
WITH TRUNCATE

SELECT *
FROM $site2vec_app2vec_table
ORDER BY host_app
"""


def update_site2vec_app2vec_table(yt_client, transaction, yql_client):
    """Rebuild the combined site2vec/app2vec table and copy it to a dated path.

    Runs ``update_site2vec_app2vec_table_query`` over the tables named in
    ``config`` inside the given transaction, then copies the resulting table
    into ``config.SITE2VEC_APP2VEC_VECTORS_FOLDER`` under the name returned by
    ``date_helpers.get_today_date_string()``, overwriting an existing node.

    Args:
        yt_client: client used for the copy.
        transaction: object whose ``transaction_id`` is passed to the query.
        yql_client: client used to execute the YQL query.
    """
    merge_query = update_site2vec_app2vec_table_query.format(
        site2vec_table=config.SITE2VEC_VECTORS_TABLE,
        app2vec_table=config.APP2VEC_VECTORS_TABLE,
        site2vec_app2vec_table=config.SITE2VEC_APP2VEC_VECTORS_TABLE,
    )
    yql_client.execute(
        query=merge_query,
        transaction=str(transaction.transaction_id),
        title='YQL update_site2vec_app2vec_table',
    )

    # The copy destination is computed only after the query has finished.
    dated_copy_path = os.path.join(
        config.SITE2VEC_APP2VEC_VECTORS_FOLDER,
        date_helpers.get_today_date_string(),
    )
    yt_client.copy(
        config.SITE2VEC_APP2VEC_VECTORS_TABLE,
        dated_copy_path,
        recursive=True,
        force=True,
    )


def _find_nearest_present_clusters(centroid_ids, centroid_vectors, missing_ids, present_ids):
    """Map every missing cluster id to the id of its closest present centroid.

    Rows are matched to clusters through ``centroid_ids`` explicitly, so the
    result does not depend on a cluster id being equal to the row position of
    its centroid.

    Args:
        centroid_ids: cluster id of every centroid row.
        centroid_vectors: centroid vectors, aligned with ``centroid_ids``.
        missing_ids: cluster ids that need a replacement cluster.
        present_ids: cluster ids that may be used as a replacement.

    Returns:
        dict ``{missing cluster id: nearest present cluster id}`` by cosine
        distance between centroids; empty when there is nothing to map or
        nothing to map to.
    """
    centroid_ids = np.asarray(centroid_ids)
    missing_positions = np.flatnonzero(np.isin(centroid_ids, missing_ids))
    candidate_positions = np.flatnonzero(np.isin(centroid_ids, present_ids))
    if missing_positions.size == 0 or candidate_positions.size == 0:
        return {}

    vectors = np.asarray(centroid_vectors, dtype=float)
    # Only missing-vs-candidate distances are needed, not the full square matrix.
    distances = cdist(vectors[missing_positions], vectors[candidate_positions], 'cosine')
    # Undefined distances (e.g. zero vectors) must never win the argmin.
    distances[np.isnan(distances)] = np.inf
    nearest_positions = candidate_positions[np.argmin(distances, axis=1)]

    # .tolist() yields plain Python scalars, which render cleanly in the query template.
    return dict(zip(
        centroid_ids[missing_positions].tolist(),
        centroid_ids[nearest_positions].tolist(),
    ))


def update_vectors(
        yt_client,
        transaction,
        yql_client,
        segment2vec_table,
        clustering_config,
        clustering_fields,
        segment2vec_fields,
        id_types,
):
    """Refresh segment vectors after the segment clustering has been updated.

    Does nothing when the ``generate_date`` attribute of ``segment2vec_table``
    is not older than the ``last_update_planned_date`` attribute of the
    centroids table, or when fewer than two monthly clustering tables exist.
    Otherwise, for every id type:

    1. ``prepare_template`` builds the indexed pool of vectors of the newest
       clustering and the per-cluster sizes in temporary tables;
    2. clusters that have a centroid without a simlink but no rows in the pool
       are mapped to the nearest cluster that has rows;
    3. ``update_segment2vec_table_template`` rewrites ``segment2vec_table``.

    Finally the ``generate_date`` attribute is set and the combined
    site2vec/app2vec table is rebuilt.

    Args:
        yt_client: YT client used for attributes, listing, reads and temp tables.
        transaction: object whose ``transaction_id`` is passed to the queries.
        yql_client: client used to execute the YQL queries.
        segment2vec_table: path of the vectors table to update.
        clustering_config: provides ``CENTROIDS_TABLE`` and ``MONTHLY_CLUSTERING_DIR``.
        clustering_fields: column names of the clustering/centroids tables
            (``name``, ``cluster_id``, ``vector``, ``id_type``, ``simlink``).
        segment2vec_fields: column names of ``segment2vec_table``.
        id_types: iterable of id types to process one by one, or ``None`` to
            run a single pass without id type filtering.
    """
    centroids_last_update_planned_date = yt_client.get_attribute(
        path=clustering_config.CENTROIDS_TABLE,
        attribute='last_update_planned_date',
        default='1970-01-01',
    )
    segment2vec_generate_date = yt_client.get_attribute(
        path=segment2vec_table,
        attribute='generate_date',
        default='1970-01-01',
    )
    # The dates are compared as strings, as stored in the attributes.
    if segment2vec_generate_date >= centroids_last_update_planned_date:
        logger.info(
            'segment clusters are not changed, clusters update: %s, vectors update: %s',
            centroids_last_update_planned_date,
            segment2vec_generate_date,
        )
        return

    monthly_clustering_tables = yt_client.list(clustering_config.MONTHLY_CLUSTERING_DIR, absolute=True, sort=True)
    if len(monthly_clustering_tables) < 2:
        logger.info('No clustering tables to recalculate vectors')
        return

    old_clustering_table, new_clustering_table = monthly_clustering_tables[-2:]

    # Nothing below writes to the centroids table, so it is read once for all id types.
    all_centroids_df = pd.DataFrame(yt_client.read_table(clustering_config.CENTROIDS_TABLE))

    # Temporary tables are allocated only once it is known that there is work to do.
    with yt_client.TempTable() as segment2vec_with_index_and_clusterid_table, \
            yt_client.TempTable() as clusterid2size_table:

        for id_type in [None] if id_types is None else id_types:
            yql_client.execute(
                query=templater.render_template(
                    prepare_template,
                    vars={
                        'segment2vec_table': segment2vec_table,
                        'clustering_table': new_clustering_table,
                        'segment2vec_with_index_and_clusterid_table': segment2vec_with_index_and_clusterid_table,
                        'clusterid2size_table': clusterid2size_table,
                        'clustering_fields': clustering_fields,
                        'segment2vec_fields': segment2vec_fields,
                        'id_type': id_type,
                    },
                ),
                transaction=str(transaction.transaction_id),
                title='YQL prepare',
            )

            centroids_df = all_centroids_df
            if id_type is not None:
                centroids_df = all_centroids_df[all_centroids_df[clustering_fields.id_type] == id_type]

            new_clusterid_df = pd.DataFrame(
                yt_client.read_table(yt_client.TablePath(clusterid2size_table, columns=['cluster_id'])),
            )
            # A table without rows produces a frame without columns.
            if 'cluster_id' in new_clusterid_df:
                present_cluster_ids = new_clusterid_df['cluster_id'].tolist()
            else:
                present_cluster_ids = []

            # Only centroids without a simlink are considered, both as clusters
            # that may be missing and as replacement candidates.
            own_cluster_ids = centroids_df[centroids_df[clustering_fields.simlink].isna()][clustering_fields.cluster_id].to_numpy()
            is_present = np.isin(own_cluster_ids, present_cluster_ids)

            missed_clusterid_to_nearest_clusterid = _find_nearest_present_clusters(
                centroid_ids=centroids_df[clustering_fields.cluster_id].to_numpy(),
                centroid_vectors=centroids_df[clustering_fields.vector].tolist(),
                missing_ids=own_cluster_ids[~is_present],
                present_ids=own_cluster_ids[is_present],
            )

            yql_client.execute(
                query=templater.render_template(
                    update_segment2vec_table_template,
                    vars={
                        'old_clustering_table': old_clustering_table,
                        'new_clustering_table': new_clustering_table,
                        'centroids_table': clustering_config.CENTROIDS_TABLE,
                        'segment2vec_table': segment2vec_table,
                        'clusterid2size_table': clusterid2size_table,
                        'segment2vec_with_index_and_clusterid_table': segment2vec_with_index_and_clusterid_table,
                        'percentile': config.PERCENTILE_OF_VECTORS_UPDATE,
                        'missed_clusterid_to_nearest_clusterid': missed_clusterid_to_nearest_clusterid,
                        'clustering_fields': clustering_fields,
                        'segment2vec_fields': segment2vec_fields,
                        'id_type': id_type,
                    },
                ),
                transaction=str(transaction.transaction_id),
                title='YQL update_segment2vec_table',
            )

        yt_client.set_attribute(
            path=segment2vec_table,
            attribute='generate_date',
            value=date_helpers.get_today_date_string(),
        )

        update_site2vec_app2vec_table(
            yt_client=yt_client,
            transaction=transaction,
            yql_client=yql_client,
        )
