from collections import (
    Counter,
    defaultdict,
)
import json
import logging
import os
import sys

import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans

from crypta.lib.python import templater
from crypta.lib.python.logging.logging_helpers import get_function_logger
from crypta.lib.python.yql import client as yql_helpers
from crypta.lib.python.yt import yt_helpers
from crypta.lookalike.lib.python.utils import utils as lal_utils
from crypta.lookalike.lib.python.utils.config import config as lal_config
from crypta.lookalike.proto import yt_node_names_pb2
from crypta.profile.lib import date_helpers
from crypta.siberia.bin.common.yt_describer.proto.yt_describer_config_pb2 import TYtDescriberConfig
from crypta.siberia.bin.common.yt_describer.py import describe
from crypta.siberia.bin.custom_audience.lib.python.clustering import (
    metrics,
    yt_schemas,
)


# Module-level logger. `function_logger` is the decorator produced from it by
# get_function_logger and is applied to the pipeline steps defined below.
logger = logging.getLogger(__name__)
function_logger = get_function_logger(logger)


# YQL template rendered with templater.render_template.
# It flattens the centroids table by the `neighbors` list column, inner-joins
# every neighbor with the segments-vectors table (additionally matching on
# `id_type` when one is passed) and overwrites the output table: the
# `old_space_columns` come from the centroid row, the `new_space_columns`
# come from the matched segment row.
get_flatten_centroids_in_new_space_template = """
$flatten_centroids_table = (
    SELECT *
    FROM `{{clusters_centroids_table}}`
    FLATTEN LIST BY {{neighbors}}
);

INSERT INTO `{{flatten_centroids_in_new_space_table}}`
WITH TRUNCATE

SELECT
    {% for column in old_space_columns %}
    flatten_centroids_table.{{column}} AS {{column}},
    {% endfor %}
    {% for column in new_space_columns %}
    segments_vectors_with_info_table.{{column}} AS {{column}},
    {% endfor %}
FROM $flatten_centroids_table AS flatten_centroids_table
INNER JOIN `{{segments_vectors_with_info_table}}` AS segments_vectors_with_info_table
ON flatten_centroids_table.{{neighbors}} == segments_vectors_with_info_table.{{one_neighbor_field}}
{% if id_type %}
    AND flatten_centroids_table.{{id_type}} == segments_vectors_with_info_table.{{id_type}}
{% endif %}
"""


# YQL template rendered with templater.render_template.
# It counts the `id` values per cluster in the user-to-cluster table and
# rewrites the centroids table in place, ordered by cluster id. Because of the
# COALESCE argument order an existing users count is kept as is; the freshly
# counted value is used only where the stored one is NULL.
update_centroids_users_count_template = """
$clusterid2userid_count_table = (
    SELECT
        {{fields.cluster_id}},
        COUNT(id) AS {{fields.users_count}},
    FROM `{{userid_clusterid_table}}`
    GROUP BY {{fields.cluster_id}}
);

INSERT INTO `{{centroids_table}}`
WITH TRUNCATE

SELECT
{% for column in centroids_columns %}
{% if column != users_count_column %}
    centroids_table.{{column}} AS {{column}},
{% endif %}
{% endfor %}
    COALESCE(
        centroids_table.{{fields.users_count}},
        clusterid2userid_count_table.{{fields.users_count}},
    ) AS {{fields.users_count}},
FROM `{{centroids_table}}` AS centroids_table
LEFT JOIN $clusterid2userid_count_table AS clusterid2userid_count_table
USING ({{fields.cluster_id}})
ORDER BY {{fields.cluster_id}}
"""


def get_yt_client(config):
    """Build a YT client for the proxy and pool named in ``config``.

    The token is taken from the ``YT_TOKEN`` environment variable; a missing
    variable results in ``KeyError``.
    """
    client_kwargs = dict(
        yt_proxy=config.YT_PROXY,
        yt_pool=config.YT_POOL,
        yt_token=os.environ['YT_TOKEN'],
    )
    return yt_helpers.get_yt_client(**client_kwargs)


def get_yql_client(config):
    """Build a YQL client for the proxy and pool named in ``config``.

    The token is taken from the ``YQL_TOKEN`` environment variable; a missing
    variable results in ``KeyError``.
    """
    client_kwargs = dict(
        yt_proxy=config.YT_PROXY,
        pool=config.YT_POOL,
        token=os.environ["YQL_TOKEN"],
    )
    return yql_helpers.create_yql_client(**client_kwargs)


@function_logger
def get_last_lal_model(yt_client, output_dir):
    """Copy the files of the latest DSSM entities version into ``output_dir``.

    The first result file is stored as ``dssm_model.applier`` and the second
    one as ``segments_dict.json``; existing nodes are overwritten.
    """
    latest_entities = lal_utils.get_last_version_of_dssm_entities(yt_client)
    _, result_files = latest_entities

    target_names = ('dssm_model.applier', 'segments_dict.json')
    for position, target_name in enumerate(target_names):
        yt_client.copy(
            result_files[position],
            os.path.join(output_dir, target_name),
            recursive=True,
            force=True,
        )


@function_logger
def get_segments_stats(yt_client, transaction, segments_to_describe_table, segments_stats_table, tmp_dir='//tmp'):
    """Describe the segments of ``segments_to_describe_table`` into ``segments_stats_table``.

    The output table is (re)created empty with the describe schema and then
    filled by ``describe`` running inside ``transaction``.
    """
    describer_options = dict(
        CryptaIdUserDataTable=lal_config.FOR_DESCRIPTION_BY_CRYPTAID_TABLE,
        TmpDir=tmp_dir,
        InputTable=segments_to_describe_table,
        OutputTable=segments_stats_table,
    )
    describer_config = TYtDescriberConfig(**describer_options)

    # force=True: a table left from a previous run is replaced.
    yt_helpers.create_empty_table(
        yt_client=yt_client,
        path=segments_stats_table,
        schema=yt_schemas.get_describe_schema(),
        additional_attributes={'optimize_for': 'scan'},
        force=True,
    )

    describe(yt_client, transaction, describer_config)


@function_logger
def get_segments_features(yt_client, segments_stats_table, segments_features_table, lal_model_segment_dict_file):
    """Turn segment stats into DSSM features with a map operation.

    The segment dictionary is read as JSON from the first chunk of
    ``lal_model_segment_dict_file``; its keys are converted to bytes before
    being handed to the mapper. The output table is recreated empty first.
    """
    first_chunk = next(yt_client.read_file(lal_model_segment_dict_file))
    features_mapping = {}
    for key, value in json.loads(first_chunk).items():
        features_mapping[str(key).encode()] = value

    yt_helpers.create_empty_table(
        yt_client=yt_client,
        path=segments_features_table,
        schema=yt_schemas.get_dssm_features_schema(),
        additional_attributes={'optimize_for': 'scan'},
        force=True,
    )

    features_mapper = lal_utils.MakeDssmSegmentFeaturesMapper(features_mapping=features_mapping)
    yt_client.run_map(
        features_mapper,
        segments_stats_table,
        segments_features_table,
    )


@function_logger
def get_segments_vectors(yt_client, transaction, yql_client, segments_features_table, segments_vectors_table, lal_model_applier_file):
    """Compute segment embeddings with a YQL query run inside ``transaction``.

    The model path is read from the attribute of ``lal_model_applier_file``
    whose name is given by ``TYtNodeNames().DssmSandboxLinkAttr``.
    """
    node_names = yt_node_names_pb2.TYtNodeNames()
    model_path = yt_client.get_attribute(
        lal_model_applier_file,
        node_names.DssmSandboxLinkAttr,
    )

    query_vars = {
        'model_path': model_path,
        'segments_dssm_features_table': segments_features_table,
        'segments_dssm_vectors_table': segments_vectors_table,
    }
    rendered_query = templater.render_template(
        lal_utils.get_segments_embeddings_query_template,
        vars=query_vars,
    )

    yql_client.execute(
        query=rendered_query,
        transaction=str(transaction.transaction_id),
        title='YQL get_segments_vectors',
    )


def get_segments_with_info_table_path(stages_dir):
    """Return the path of the ``segments_with_info`` table inside ``stages_dir``."""
    table_name = 'segments_with_info'
    return os.path.join(stages_dir, table_name)


def get_segments_stats_table_path(stages_dir):
    """Return the path of the ``segments_stats`` table inside ``stages_dir``."""
    table_name = 'segments_stats'
    return os.path.join(stages_dir, table_name)


def get_segments_vectors_with_info_table_path(stages_dir):
    """Return the path of the ``segments_vectors_with_info`` table inside ``stages_dir``."""
    table_name = 'segments_vectors_with_info'
    return os.path.join(stages_dir, table_name)


def get_flatten_centroids_in_new_space_table_path(stages_dir):
    """Return the path of the ``flatten_centroids_in_new_space`` table inside ``stages_dir``."""
    table_name = 'flatten_centroids_in_new_space'
    return os.path.join(stages_dir, table_name)


def update_stages_tables(
        yt_client,
        yql_client,
        prepare_for_describe,
        get_segments_vectors_with_info,
        stages_dir,
        config,
        only_new_segments=False,
        update_model=True
):
    """Refresh the stage tables in ``stages_dir`` at most once per day.

    Two stages are handled one after another, each inside its own transaction
    and each skipped when the ``generate_date`` attribute of its output table
    is already today's date:

    1. segments stats (``prepare_for_describe`` callback, then describing);
    2. segments vectors with info (optional model refresh, features, vectors,
       then the ``get_segments_vectors_with_info`` callback).

    A table without the attribute is treated as generated on 1970-01-01.
    """
    def read_generate_date(table_path):
        return yt_client.get_attribute(
            path=table_path,
            attribute='generate_date',
            default='1970-01-01',
        )

    def mark_generated_today(table_path):
        yt_client.set_attribute(
            path=table_path,
            attribute='generate_date',
            value=today,
        )

    today = date_helpers.get_today_date_string()

    stats_table = get_segments_stats_table_path(stages_dir)
    with_info_table = get_segments_with_info_table_path(stages_dir)

    # Stage 1: segments stats.
    stats_generate_date = read_generate_date(stats_table)
    if today > stats_generate_date:
        with yt_client.Transaction() as transaction, \
                yt_client.TempTable() as to_describe_table:
            prepare_for_describe(
                transaction=transaction,
                yql_client=yql_client,
                segments_to_describe_table=to_describe_table,
                segments_with_info_table=with_info_table,
                only_new_segments=only_new_segments,
            )
            get_segments_stats(
                yt_client=yt_client,
                transaction=transaction,
                segments_to_describe_table=to_describe_table,
                segments_stats_table=stats_table,
            )
            mark_generated_today(stats_table)

    # Stage 2: segments vectors joined with their info.
    vectors_with_info_table = get_segments_vectors_with_info_table_path(stages_dir)
    vectors_generate_date = read_generate_date(vectors_with_info_table)
    if today > vectors_generate_date:
        with yt_client.Transaction() as transaction, \
                yt_client.TempTable() as features_table, \
                yt_client.TempTable() as vectors_table:
            if update_model:
                lal_utils.copy_last_lal_model(
                    yt_client=yt_client,
                    output_dir=config.DSSM_LAL_MODEL_DIR,
                )
            get_segments_features(
                yt_client=yt_client,
                segments_stats_table=stats_table,
                segments_features_table=features_table,
                lal_model_segment_dict_file=config.DSSM_LAL_MODEL_SEGMENT_DICT_FILE,
            )
            get_segments_vectors(
                yt_client=yt_client,
                transaction=transaction,
                yql_client=yql_client,
                segments_features_table=features_table,
                lal_model_applier_file=config.DSSM_LAL_MODEL_APPLIER_FILE,
                segments_vectors_table=vectors_table,
            )
            get_segments_vectors_with_info(
                transaction=transaction,
                yql_client=yql_client,
                segments_with_info_table=with_info_table,
                segments_vectors_table=vectors_table,
                segments_vectors_with_info_table=vectors_with_info_table,
            )
            mark_generated_today(vectors_with_info_table)


@function_logger
def read_centroids_table(yt_client, centroids_table, fields, groupby):
    """Read the centroids table into a DataFrame with one row per ``groupby`` key.

    The first row of every group is kept. The simlink column is rebuilt so
    that rows with a value hold a Python ``int`` and all other rows hold ``None``.
    """
    table_rows = yt_client.read_table(centroids_table)
    grouped_rows = pd.DataFrame(table_rows).groupby(groupby, sort=True)
    centroids_df = grouped_rows.first().reset_index()

    has_simlink = centroids_df[fields.simlink].notna()
    int_simlinks = [int(value) for value in centroids_df.loc[has_simlink, fields.simlink].tolist()]

    # Reset the column first so that missing values become None.
    centroids_df[fields.simlink] = None
    centroids_df.loc[has_simlink, fields.simlink] = int_simlinks
    return centroids_df


@function_logger
def get_new_centroids_df(old_centroids_df, segments_df, kmeans_centroids, config, fields):
    """Build the new centroids frame from the segments nearest to the k-means centres.

    Every k-means centre is represented by the segment with the smallest
    cosine distance to it; the names of its ``config.NEIGHBOURS_NUMBER``
    nearest segments are stored in the neighbors column. Only the columns
    listed in ``fields.centroids_fields`` are kept. Cluster ids and users
    counts are taken, in order, from the old centroids that have no simlink.
    """
    segment_vectors = np.array(segments_df[fields.vector].tolist())
    distances = cdist(kmeans_centroids, segment_vectors, 'cosine')
    nearest_order = np.argsort(distances)

    centroids_df = segments_df.loc[nearest_order[:, 0]].reset_index(drop=True)

    segment_names = np.array(segments_df[fields.name].tolist())
    centroids_df[fields.neighbors] = segment_names[nearest_order[:, :config.NEIGHBOURS_NUMBER]].tolist()

    kept_columns = list(fields.centroids_fields.values())
    extra_columns = [column for column in centroids_df.columns.tolist() if column not in kept_columns]
    centroids_df.drop(columns=extra_columns, inplace=True)

    without_simlink = old_centroids_df[fields.simlink].isna()
    for inherited_column in (fields.cluster_id, fields.users_count):
        centroids_df[inherited_column] = old_centroids_df.loc[without_simlink, inherited_column].tolist()
    return centroids_df


def getsizeof(obj, counted_ids=None):
    """Return the size of ``obj`` in bytes including the objects it refers to.

    Dict keys and values, the instance ``__dict__`` and the items of other
    iterables (except ``str``, ``bytes`` and ``bytearray``) are followed
    recursively. ``counted_ids`` holds the ``id`` of every object already
    accounted for, so a shared object contributes only once (0 afterwards).
    """
    if counted_ids is None:
        counted_ids = set()

    obj_id = id(obj)
    if obj_id in counted_ids:
        return 0
    counted_ids.add(obj_id)

    total = sys.getsizeof(obj)
    if isinstance(obj, dict):
        total += sum([getsizeof(key, counted_ids) for key in obj])
        total += sum([getsizeof(value, counted_ids) for value in obj.values()])
    elif hasattr(obj, '__dict__'):
        total += getsizeof(obj.__dict__, counted_ids)
    elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
        total += sum([getsizeof(item, counted_ids) for item in obj])
    return total


class GetIndexOfNearestVectorMapper():
    """Row mapper that finds the nearest reference vector by cosine distance.

    For every input row (optionally only rows whose id type equals
    ``id_type``) it yields the row id together with the position, inside
    ``vectors``, of the vector closest to the row's vector.
    """

    def __init__(self, vectors, fields, id_type=None):
        """Store the reference vectors, the field names and the optional id type filter."""
        self.vectors = vectors
        self.fields = fields
        self.id_type = id_type
        # Filled on the first processed row, see _get_vectors_array.
        self._vectors_array = None

    def _get_vectors_array(self):
        """Return ``self.vectors`` as an ndarray, converting it only once per instance."""
        # getattr keeps instances restored without the attribute working.
        vectors_array = getattr(self, '_vectors_array', None)
        if vectors_array is None:
            vectors_array = np.array(self.vectors)
            self._vectors_array = vectors_array
        return vectors_array

    def __call__(self, row):
        """Yield ``{id, nearest_vector_index}`` for ``row``; yield nothing for a filtered-out row."""
        if self.id_type is not None and row[self.fields.id_type] != self.id_type:
            return

        distances = cdist(
            np.array([row[self.fields.vector]]),
            self._get_vectors_array(),
            'cosine',
        )
        nearest_vector_index = np.argmin(distances, axis=1)[0]
        yield {
            self.fields.id: row[self.fields.id],
            'nearest_vector_index': int(nearest_vector_index),
        }


@function_logger
def get_index_of_nearest_vector(yt_client, segments_vectors_table, segments_df, vectors, fields, id_type=None):
    """Return, for the rows of ``segments_df``, the index of the nearest vector in ``vectors``.

    The search runs as a map operation over ``segments_vectors_table`` whose
    result is written to a temporary table, read back and merged with
    ``segments_df`` on the id field. The mapper memory limit is 2 GiB plus
    the deep size of ``vectors``.
    """
    nearest_vector_mapper = GetIndexOfNearestVectorMapper(vectors, fields, id_type)
    operation_spec = {
        'data_size_per_job': 2 * 1024 * 1024,
        'mapper': {
            'memory_limit': 2 * 1024 * 1024 * 1024 + getsizeof(vectors),
            'memory_reserve_factor': 1,
        }
    }

    with yt_client.TempTable() as nearest_index_table:
        yt_client.run_map(
            nearest_vector_mapper,
            segments_vectors_table,
            nearest_index_table,
            spec=operation_spec,
        )
        nearest_index_df = pd.DataFrame(yt_client.read_table(nearest_index_table))

    merged_df = segments_df.merge(nearest_index_df, on=fields.id)
    return np.array(merged_df['nearest_vector_index'].tolist())


@function_logger
def join_small_clusters(centroids_vectors, labels, min_cluster_size):
    """Merge every cluster smaller than ``min_cluster_size`` into its nearest cluster.

    Clusters are visited in index order. A small cluster is merged into the
    nearest (cosine distance between centroids) cluster that has not itself
    been merged away; the size of the receiving cluster grows accordingly, so
    it may pass the threshold by the time it is visited.

    Args:
        centroids_vectors: 2-D array of centroid vectors, one row per cluster.
        labels: ndarray with the cluster index of every element; it is
            modified in place.
        min_cluster_size: minimal number of elements a cluster must have to
            be kept.

    Returns:
        tuple: ``(labels, map_of_moves)`` where ``map_of_moves`` maps the index
        of every merged cluster to the index of the cluster that received it.
        A small cluster with no remaining cluster to merge into is left as is.
    """
    map_of_moves = dict()
    cluster_ids = np.arange(len(centroids_vectors))
    is_moved = np.zeros(len(centroids_vectors), dtype=bool)

    nearest_centroids = np.argsort(
        cdist(
            centroids_vectors,
            centroids_vectors,
            'cosine',
        ),
        axis=1,
    )

    clusterid2size = defaultdict(int, Counter(labels))
    for cluster_id in cluster_ids:
        if clusterid2size[cluster_id] >= min_cluster_size:
            continue

        # Candidate targets in order of increasing distance: any cluster other
        # than the current one that has not been merged away.
        candidates = nearest_centroids[cluster_id]
        candidates = candidates[(candidates != cluster_id) & ~is_moved[candidates]]
        if candidates.size == 0:
            # Nothing left to merge into (e.g. a single cluster, or all other
            # clusters were already merged): keep the cluster instead of failing.
            continue

        is_moved[cluster_id] = True
        new_cluster_id = candidates[0]

        labels[labels == cluster_id] = new_cluster_id

        clusterid2size[new_cluster_id] += clusterid2size[cluster_id]
        map_of_moves[cluster_id] = new_cluster_id

    return labels, map_of_moves


@function_logger
def join_clusters_with_monsters_elements(centroids_df, clustering_df, required_share, required_distance, fields):
    """Dilute clusters dominated by one oversized ("monster") element.

    For every centroid without a simlink whose largest element holds more than
    ``required_share`` of the cluster's users, the nearest still-unused
    cluster (cosine distance between centroid vectors) is merged into it,
    wholly or partially, until the share is acceptable, the nearest candidate
    is farther than ``required_distance``, or no candidate remains.

    Args:
        centroids_df: centroids frame with cluster id, vector, users count and
            simlink columns. It is not modified; a merged copy is used.
        clustering_df: elements frame with cluster id, vector and users count
            columns. Its cluster id column is modified in place.
        required_share: maximal allowed share of the largest element.
        required_distance: maximal allowed cosine distance between centroids
            of the joined clusters.
        fields: object holding the column names.

    Returns:
        tuple: ``(clustering_df, map_of_moves)`` where ``map_of_moves`` maps
        the row position (in ``centroids_df``) of every cluster that lost all
        its elements to the id of the cluster that absorbed it.
    """
    def get_required_users_count(max_users_count, sum_users_count, required_share):
        # Number of users to add so that max / (sum + added) == required_share.
        return (max_users_count - required_share * sum_users_count) / required_share

    def get_users_count_info(clustering_df, cluster_id):
        # Recompute (max, sum, max / sum) of users counts for one cluster;
        # max and share stay 0 for an empty cluster.
        max_users_count = largest_share = 0
        users_count = clustering_df.loc[clustering_df[fields.cluster_id] == cluster_id, fields.users_count]
        sum_users_count = users_count.sum()
        if sum_users_count > 0:
            max_users_count = users_count.max()
            largest_share = max_users_count / sum_users_count
        return max_users_count, sum_users_count, largest_share

    clustering_max_users_count_df = clustering_df[[fields.cluster_id, fields.users_count]].groupby(by=[fields.cluster_id]).max()
    clustering_max_users_count_df.rename(columns={fields.users_count: 'max_users_count'}, inplace=True)
    centroids_df = centroids_df.merge(
        clustering_max_users_count_df,
        how='left',
        on=fields.cluster_id,
    )

    # The users count of a centroid is replaced by the sum over its elements.
    clustering_users_count_df = clustering_df[[fields.cluster_id, fields.users_count]].groupby(by=[fields.cluster_id]).sum()
    centroids_df = centroids_df.drop(columns=[fields.users_count]).merge(
        clustering_users_count_df,
        how='left',
        on=fields.cluster_id,
    )
    centroids_df['largest_share'] = centroids_df['max_users_count'] / centroids_df[fields.users_count]

    # Centroids without elements got NaN from the left merges.
    centroids_df.loc[centroids_df.max_users_count.isna(), ['max_users_count', fields.users_count, 'largest_share']] = 0

    centroids_distances = cdist(
        centroids_df[fields.vector].tolist(),
        centroids_df[fields.vector].tolist(),
        'cosine',
    )
    nearest_centroids = np.argsort(
        centroids_distances,
        axis=1,
    )

    cluster_ids = np.arange(centroids_df.shape[0])
    clusterid2index = dict(zip(centroids_df[fields.cluster_id], cluster_ids))
    index2clusterid = dict(zip(clusterid2index.values(), clusterid2index.keys()))

    users_counts = dict(zip(centroids_df[fields.cluster_id], centroids_df[fields.users_count]))
    max_users_counts = dict(zip(centroids_df[fields.cluster_id], centroids_df.max_users_count))
    largest_shares = dict(zip(centroids_df[fields.cluster_id], centroids_df.largest_share))

    # Centroids that already point to another cluster never take part.
    is_used_centroids = np.array(~centroids_df[fields.simlink].isna())

    map_of_moves = dict()
    for cluster_id in centroids_df.loc[centroids_df[fields.simlink].isna(), fields.cluster_id].tolist():
        own_index = clusterid2index[cluster_id]
        while not is_used_centroids[own_index] and largest_shares[cluster_id] > required_share:

            required_users_count = get_required_users_count(
                max_users_counts[cluster_id],
                users_counts[cluster_id],
                required_share,
            )

            # Temporarily mark the cluster itself as used so that it is not
            # picked as its own nearest neighbour.
            is_used_centroids[own_index] = True
            candidate_indexes = nearest_centroids[own_index][
                np.in1d(
                    nearest_centroids[own_index],
                    cluster_ids[~is_used_centroids],
                )
            ]
            is_used_centroids[own_index] = False

            if candidate_indexes.size == 0:
                # Every other cluster is already used: nothing to join.
                break
            cluster_id_to_join = index2clusterid[candidate_indexes[0]]

            if centroids_distances[own_index][clusterid2index[cluster_id_to_join]] > required_distance:
                break

            # NOTE(review): this compares against the users count computed before
            # the loop, not the running value in `users_counts` — confirm intended.
            if centroids_df.loc[centroids_df[fields.cluster_id] == cluster_id_to_join, fields.users_count].values[0] <= required_users_count:
                # The whole neighbour is needed: move all its elements.
                clustering_df.loc[clustering_df[fields.cluster_id] == cluster_id_to_join, fields.cluster_id] = cluster_id
            else:
                # Move only the closest elements of the neighbour until the
                # required number of users is reached.
                cluster_to_join_df = clustering_df[clustering_df[fields.cluster_id] == cluster_id_to_join].copy()
                cluster_to_join_df['distance'] = cdist(
                    centroids_df.loc[centroids_df[fields.cluster_id] == cluster_id, fields.vector].tolist(),
                    clustering_df.loc[clustering_df[fields.cluster_id] == cluster_id_to_join, fields.vector].tolist(),
                    'cosine',
                )[0]
                cluster_to_join_df.sort_values(by=['distance'], inplace=True)

                sites_to_join_number = np.sum(
                    np.cumsum(cluster_to_join_df[fields.users_count].to_numpy()) < required_users_count
                ) + 1

                clustering_df.loc[cluster_to_join_df.iloc[:sites_to_join_number].index, fields.cluster_id] = cluster_id

            max_users_counts[cluster_id_to_join], users_counts[cluster_id_to_join], largest_shares[cluster_id_to_join] = \
                get_users_count_info(clustering_df, cluster_id_to_join)

            # If the remainder of the neighbour became a monster cluster itself,
            # move the rest of it as well.
            if largest_shares[cluster_id_to_join] > required_share:
                clustering_df.loc[clustering_df[fields.cluster_id] == cluster_id_to_join, fields.cluster_id] = cluster_id

            if clustering_df[clustering_df[fields.cluster_id] == cluster_id_to_join].shape[0] == 0:
                # The neighbour is empty now: retire it and remember where it went.
                max_users_counts[cluster_id_to_join] = users_counts[cluster_id_to_join] = largest_shares[cluster_id_to_join] = 0
                is_used_centroids[clusterid2index[cluster_id_to_join]] = True
                map_of_moves[clusterid2index[cluster_id_to_join]] = cluster_id

            max_users_counts[cluster_id], users_counts[cluster_id], largest_shares[cluster_id] = get_users_count_info(clustering_df, cluster_id)

    return clustering_df, map_of_moves


@function_logger
def get_clusterid_and_simlink(yt_client, segments_vectors_table, new_centroids_df, new_segments_df, min_cluster_size, required_monster_share, upper_bound_distance, fields, id_type=None):
    """Assign a cluster id to every segment and simlinks to merged-away centroids.

    Segments are attached to their nearest centroid, clusters smaller than
    ``min_cluster_size`` are merged into neighbours, and clusters dominated by
    one element are joined with close neighbours. Both frames are modified in
    place and returned as ``(new_centroids_df, new_segments_df)``.
    """
    centroid_vectors = new_centroids_df[fields.vector].tolist()

    nearest_centroid_indexes = get_index_of_nearest_vector(
        yt_client=yt_client,
        segments_vectors_table=segments_vectors_table,
        segments_df=new_segments_df,
        vectors=centroid_vectors,
        fields=fields,
        id_type=id_type,
    )
    cluster_id_indexes, small_cluster_moves = join_small_clusters(
        centroids_vectors=np.array(centroid_vectors),
        labels=nearest_centroid_indexes,
        min_cluster_size=min_cluster_size,
    )

    # A merged-away centroid points to the cluster id of its receiver.
    new_centroids_df[fields.simlink] = None
    receiver_cluster_ids = new_centroids_df.loc[small_cluster_moves.values(), fields.cluster_id].tolist()
    new_centroids_df.loc[small_cluster_moves.keys(), fields.simlink] = receiver_cluster_ids

    segment_cluster_ids = new_centroids_df.loc[cluster_id_indexes, fields.cluster_id].tolist()
    new_segments_df[fields.cluster_id] = segment_cluster_ids

    new_segments_df, monster_moves = join_clusters_with_monsters_elements(
        centroids_df=new_centroids_df,
        clustering_df=new_segments_df,
        required_share=required_monster_share,
        required_distance=upper_bound_distance,
        fields=fields,
    )

    new_centroids_df.loc[monster_moves.keys(), fields.simlink] = list(monster_moves.values())

    return new_centroids_df, new_segments_df


@function_logger
def restore_missing_centroids(yt_client, old_centroids_table, new_centroids_df, fields, id_type=None):
    """Append to ``new_centroids_df`` the old centroids whose cluster id is absent from it.

    Args:
        yt_client: client used to read ``old_centroids_table``.
        old_centroids_table: path of the table with the previous centroids.
        new_centroids_df: freshly computed centroids.
        fields: object holding the column names.
        id_type: when given, only old centroids of this id type are considered.

    Returns:
        DataFrame with the new and the restored centroids, sorted by cluster
        id and re-indexed from zero.
    """
    if id_type is None:
        groupby = [fields.cluster_id]
    else:
        groupby = [fields.id_type, fields.cluster_id]
    old_centroids_df = read_centroids_table(
        yt_client=yt_client,
        centroids_table=old_centroids_table,
        fields=fields,
        groupby=groupby,
    )
    if id_type is not None:
        # reset_index returns a new frame, so its result has to be kept.
        old_centroids_df = old_centroids_df[old_centroids_df[fields.id_type] == id_type].reset_index(drop=True)

    # Row-wise membership test: unlike a mask built from np.unique(...), it
    # does not rely on the old cluster ids being sorted and unique.
    is_missing = ~old_centroids_df[fields.cluster_id].isin(new_centroids_df[fields.cluster_id])
    missed_centroids = old_centroids_df[is_missing]

    return pd.concat([new_centroids_df, missed_centroids]).sort_values(by=fields.cluster_id).reset_index(drop=True)


@function_logger
def update_simlinks(centroids_df, fields):
    """Resolve simlink chains and copy the target centroid's data into linked rows.

    A chain such as ``a -> b -> c`` is collapsed so that both ``a`` and ``b``
    point to ``c``. Every column of a linked row except the cluster id and the
    simlink is then overwritten with the values of its final target.

    Args:
        centroids_df: centroids frame; it is modified in place.
        fields: object holding the column names.

    Returns:
        The same ``centroids_df``.

    Raises:
        ValueError: if the simlinks form a cycle (which would otherwise never
            terminate).
    """
    has_simlink = ~centroids_df[fields.simlink].isna()
    simlink_mapper = dict(zip(
        centroids_df.loc[has_simlink, fields.cluster_id].values,
        centroids_df.loc[has_simlink, fields.simlink].values
    ))

    new_simlink_mapper = dict()
    for cluster_id, simlink in simlink_mapper.items():
        visited = {cluster_id}
        while simlink in simlink_mapper:
            if simlink in visited:
                raise ValueError('Cycle in simlinks detected for cluster {}'.format(cluster_id))
            visited.add(simlink)
            simlink = simlink_mapper[simlink]
        new_simlink_mapper[cluster_id] = simlink

    columns_to_change = [
        column for column in centroids_df.columns.values.tolist()
        if column not in [fields.cluster_id, fields.simlink]
    ]
    # Map cluster id -> row label, because the rows are addressed with .loc below.
    clusterid2index = dict(zip(centroids_df[fields.cluster_id].tolist(), centroids_df.index))

    for cluster_id, simlink in new_simlink_mapper.items():
        centroids_df.loc[clusterid2index[cluster_id], columns_to_change] = centroids_df.loc[clusterid2index[simlink], columns_to_change]
        centroids_df.loc[clusterid2index[cluster_id], fields.simlink] = simlink
    return centroids_df


def get_last_update_planned_date_by_date(date=None, day_of_update_clusters=15):
    """Return the most recent planned cluster-update date not later than ``date``.

    The update is planned for day ``day_of_update_clusters`` of every month
    (clamped to the last day of shorter months). If that day has not come yet
    in the month of ``date``, the planned day of the previous month is used.

    Args:
        date: date string in ``%Y-%m-%d`` format; today's date string (taken
            at call time) when omitted.
        day_of_update_clusters: day of month on which clusters are updated.

    Returns:
        The planned date converted with ``date_helpers.to_date_string``.
    """
    import calendar

    if date is None:
        # Resolved per call: a default evaluated in the signature would be
        # frozen at import time.
        date = date_helpers.get_today_date_string()

    def planned_day(year, month):
        # The requested day may not exist in a short month (e.g. 31 in February).
        return min(day_of_update_clusters, calendar.monthrange(year, month)[1])

    current_date = date_helpers.from_date_string_to_datetime(date, '%Y-%m-%d')
    last_update_planned_date = current_date.replace(day=planned_day(current_date.year, current_date.month))
    if last_update_planned_date > current_date:
        # Step back one month, rolling over to December of the previous year.
        if current_date.month == 1:
            year, month = current_date.year - 1, 12
        else:
            year, month = current_date.year, current_date.month - 1
        last_update_planned_date = last_update_planned_date.replace(
            year=year,
            month=month,
            day=planned_day(year, month),
        )
    return date_helpers.to_date_string(last_update_planned_date)


def copy_table_with_limit(yt_client, input_table, output_dir, limit):
    """Copy ``input_table`` into ``output_dir`` under today's date and trim the directory.

    An existing node with today's name is overwritten. If the directory then
    holds more than ``limit`` nodes, the one with the smallest path is removed.
    """
    todays_copy_path = os.path.join(output_dir, date_helpers.get_today_date_string())
    yt_client.copy(
        input_table,
        todays_copy_path,
        recursive=True,
        force=True,
    )

    stored_paths = yt_client.list(output_dir, absolute=True)
    if len(stored_paths) <= limit:
        return
    yt_client.remove(min(stored_paths))


def write_df_to_table_by_batches(yt_client, table_path, df, batch_size=500000):
    """Append the rows of ``df`` to ``table_path`` in chunks of ``batch_size`` rows.

    Every chunk is written with a separate ``write_table`` call in append
    mode; nothing is written for an empty frame.
    """
    rows_total = df.shape[0]
    # Ceiling division: a trailing partial batch counts as a batch.
    batches_total = -(-rows_total // batch_size)

    for batch_number in range(batches_total):
        first_row = batch_number * batch_size
        last_row = min(rows_total, first_row + batch_size)
        batch_df = df.iloc[first_row:last_row]
        yt_client.write_table(
            yt_client.TablePath(
                table_path,
                append=True,
            ),
            batch_df.to_dict('records'),
        )


def fit_clustering(
    yt_client,
    old_centroids_df,
    new_segments_df,
    stages_dir,
    config,
    fields,
    id_type,
):
    """Re-fit the clusters on the new segments, starting from the old centroids.

    K-means is initialised with the old centroids that have no simlink and is
    fitted on the last ``config.NUMBER_OF_TOP_ELEMENTS_FOR_UPDATE_CLUSTERS``
    rows of ``new_segments_df``. The resulting centroids replace the
    simlink-free rows of ``old_centroids_df`` (which is modified in place),
    centroids missing from the result are restored from
    ``config.CENTROIDS_TABLE`` and simlink chains are resolved.

    Returns:
        tuple: ``(centroids_df, segments_df)``.
    """
    top_segments_df = new_segments_df.iloc[-config.NUMBER_OF_TOP_ELEMENTS_FOR_UPDATE_CLUSTERS:].reset_index(drop=True)

    # Only centroids that do not point to another cluster take part in fitting.
    is_active = old_centroids_df[fields.simlink].isna()
    active_old_centroids_df = old_centroids_df[is_active]

    kmeans = KMeans(
        n_clusters=active_old_centroids_df.shape[0],
        init=np.array(active_old_centroids_df[fields.vector].tolist()),
        n_init=1,
    )
    kmeans.fit(top_segments_df[fields.vector].tolist())

    fitted_centroids_df = get_new_centroids_df(
        old_centroids_df=old_centroids_df,
        segments_df=top_segments_df,
        kmeans_centroids=kmeans.cluster_centers_,
        config=config,
        fields=fields,
    )

    fitted_centroids_df, new_segments_df = get_clusterid_and_simlink(
        yt_client=yt_client,
        segments_vectors_table=get_segments_vectors_with_info_table_path(stages_dir),
        new_centroids_df=fitted_centroids_df,
        new_segments_df=new_segments_df,
        min_cluster_size=config.MIN_CLUSTER_SIZE,
        required_monster_share=config.ALLOWED_LARGEST_SEGMENT_SHARE_IN_CLUSTER,
        upper_bound_distance=config.UPPER_BOUND_DISTANCE_FOR_JOIN_MONSTERS,
        fields=fields,
        id_type=id_type,
    )

    # Put the fitted rows back at the positions of the active old centroids.
    fitted_centroids_df = fitted_centroids_df.set_index(active_old_centroids_df.index)
    old_centroids_df.loc[is_active] = fitted_centroids_df

    all_centroids_df = restore_missing_centroids(
        yt_client=yt_client,
        old_centroids_table=config.CENTROIDS_TABLE,
        new_centroids_df=old_centroids_df,
        fields=fields,
        id_type=id_type,
    )
    all_centroids_df = update_simlinks(
        centroids_df=all_centroids_df,
        fields=fields,
    )

    return all_centroids_df, new_segments_df


def get_flatten_centroids_and_fit_clustering(
    yt_client,
    yql_client,
    stages_dir,
    config,
    yt_schemas,
    fields,
    one_neighbor_field,
    id_types=None,
):
    """Re-fit the clustering and publish the new centroids and clustering tables.

    Everything runs inside one ``yt_client.Transaction()``:

    1. a YQL query maps the stored centroids onto the current segment vectors
       through their neighbors;
    2. clustering is fitted either once (``id_types is None``) or separately
       for every id type, the per-type results being concatenated;
    3. ``config.CENTROIDS_TABLE`` and ``config.CLUSTERING_TABLE`` are recreated
       and filled in batches;
    4. the clustering table is copied to the monthly and daily directories;
    5. metrics are updated.

    Note that the ``yt_schemas`` parameter shadows the module-level
    ``yt_schemas`` import inside this function. The function returns ``None``.
    """
    with yt_client.Transaction() as transaction:
        today = date_helpers.get_today_date_string()

        segments_vectors_with_info_table = get_segments_vectors_with_info_table_path(stages_dir)
        flatten_centroids_in_new_space_table = get_flatten_centroids_in_new_space_table_path(stages_dir)
        # "Old space" columns are taken from the stored centroid rows, all the
        # remaining centroid columns from the matched segment rows.
        old_space_columns = [fields.cluster_id, fields.simlink, fields.neighbors, fields.users_count]
        new_space_columns = [column for column in fields.centroids_fields.values() if column not in old_space_columns]
        yql_client.execute(
            query=templater.render_template(
                get_flatten_centroids_in_new_space_template,
                vars={
                    'clusters_centroids_table': config.CENTROIDS_TABLE,
                    'segments_vectors_with_info_table': segments_vectors_with_info_table,
                    'old_space_columns': old_space_columns,
                    'new_space_columns': new_space_columns,
                    'neighbors': fields.neighbors,
                    'one_neighbor_field': one_neighbor_field,
                    'id_type': fields.id_type if id_types is not None else None,
                    'flatten_centroids_in_new_space_table': flatten_centroids_in_new_space_table,
                },
            ),
            transaction=str(transaction.transaction_id),
            title='YQL get_flatten_centroids_in_new_space_query',
        )

        # One row per cluster (per id type and cluster when id types are used).
        old_centroids_df = read_centroids_table(
            yt_client=yt_client,
            centroids_table=flatten_centroids_in_new_space_table,
            fields=fields,
            groupby=[fields.cluster_id] if id_types is None else [fields.id_type, fields.cluster_id],
        )
        new_segments_df = pd.DataFrame(yt_client.read_table(segments_vectors_with_info_table))

        if id_types is None:
            new_centroids_df, new_segments_df = fit_clustering(
                yt_client=yt_client,
                old_centroids_df=old_centroids_df,
                new_segments_df=new_segments_df,
                stages_dir=stages_dir,
                config=config,
                fields=fields,
                id_type=None,
            )
        else:
            # Fit every id type on its own slice of centroids and segments.
            new_centroids_store_dfs, new_segments_store_dfs = list(), list()
            for id_type in id_types:
                new_centroids_store_df, new_segments_store_df = fit_clustering(
                    yt_client=yt_client,
                    old_centroids_df=old_centroids_df[old_centroids_df[fields.id_type] == id_type].reset_index(drop=True).copy(),
                    new_segments_df=new_segments_df[new_segments_df[fields.id_type] == id_type].reset_index(drop=True).copy(),
                    stages_dir=stages_dir,
                    config=config,
                    fields=fields,
                    id_type=id_type,
                )
                new_centroids_store_dfs.append(new_centroids_store_df)
                new_segments_store_dfs.append(new_segments_store_df)
            new_centroids_df = pd.concat(new_centroids_store_dfs)
            new_segments_df = pd.concat(new_segments_store_dfs)

        # Both output tables are recreated (force=True) before being filled.
        last_update_planned_date = get_last_update_planned_date_by_date(today, config.DAY_OF_UPDATE_CLUSTERS)
        yt_helpers.create_empty_table(
            yt_client=yt_client,
            path=config.CENTROIDS_TABLE,
            schema=yt_schemas.get_centroids_schema(),
            additional_attributes={
                'generate_date': today,
                'last_update_planned_date': last_update_planned_date,
                'optimize_for': 'scan',
            },
            force=True,
        )
        yt_helpers.create_empty_table(
            yt_client=yt_client,
            path=config.CLUSTERING_TABLE,
            schema=yt_schemas.get_clustering_schema(),
            additional_attributes={
                'generate_date': today,
                'optimize_for': 'scan',
            },
            force=True,
        )

        for df, output_table in [
            (new_centroids_df, config.CENTROIDS_TABLE),
            (new_segments_df, config.CLUSTERING_TABLE),
        ]:
            write_df_to_table_by_batches(
                yt_client=yt_client,
                table_path=output_table,
                df=df,
            )

        # Keep dated copies of the clustering table, trimming each directory
        # to its configured limit.
        for output_dir, limit in [
            (config.MONTHLY_CLUSTERING_DIR, config.MONTHLY_CLUSTERING_TO_KEEP),
            (config.DAILY_CLUSTERING_DIR, config.DAILY_CLUSTERING_TO_KEEP),
        ]:
            copy_table_with_limit(
                yt_client=yt_client,
                input_table=config.CLUSTERING_TABLE,
                output_dir=output_dir,
                limit=limit,
            )

        metrics.update(
            yt_client=yt_client,
            transaction=transaction,
            yql_client=yql_client,
            centroids_table=config.CENTROIDS_TABLE,
            monthly_clustering_dir=config.MONTHLY_CLUSTERING_DIR,
            metrics_table=config.DATALENS_METRICS_TABLE,
            service_name=config.SERVICE_NAME,
            clustering_fields=fields,
            id_types=id_types,
            lower_bounds=config.METRICS_LOWER_BOUNDS,
            upper_bounds=config.METRICS_UPPER_BOUNDS,
        )


def apply_clustering(
    yt_client,
    stages_dir,
    config,
    fields,
    id_types=None,
):
    """Assign every current segment to its nearest existing centroid and store the result.

    Only centroids without a simlink are candidates. With ``id_types`` given,
    the assignment is done separately per id type (types without segments are
    skipped) and the results are concatenated. The rows are appended to
    ``config.CLUSTERING_TABLE``, its ``generate_date`` attribute is set to
    today and a dated copy is stored in ``config.DAILY_CLUSTERING_DIR``.
    All of it runs inside one transaction.

    NOTE(review): rows are written in append mode and the table is not
    recreated here — confirm that the caller prepares an empty table.
    """
    with yt_client.Transaction():
        segments_vectors_table = get_segments_vectors_with_info_table_path(stages_dir)
        centroids_df = pd.DataFrame(yt_client.read_table(config.CENTROIDS_TABLE))
        new_segments_df = pd.DataFrame(yt_client.read_table(segments_vectors_table))

        if id_types is None:
            # The returned indexes are positions in the filtered vector list,
            # so the filtered frame must be re-indexed from zero before they
            # are used as labels (same as in the per-id-type branch below).
            active_centroids_df = centroids_df[centroids_df[fields.simlink].isna()].reset_index(drop=True)
            nearest_vectors_indexes = get_index_of_nearest_vector(
                yt_client=yt_client,
                segments_vectors_table=segments_vectors_table,
                segments_df=new_segments_df,
                vectors=active_centroids_df[fields.vector].tolist(),
                fields=fields,
            )
            new_segments_df[fields.cluster_id] = active_centroids_df.loc[nearest_vectors_indexes, fields.cluster_id].tolist()
        else:
            new_segments_store_dfs = list()
            for id_type in id_types:
                new_segments_store_df = new_segments_df[new_segments_df[fields.id_type] == id_type].reset_index(drop=True).copy()
                if new_segments_store_df.shape[0] == 0:
                    continue
                new_segments_store_dfs.append(new_segments_store_df)
                centroids_store_df = centroids_df[
                    (centroids_df[fields.id_type] == id_type) & (centroids_df[fields.simlink].isna())
                ].reset_index(drop=True).copy()
                nearest_vectors_indexes = get_index_of_nearest_vector(
                    yt_client=yt_client,
                    segments_vectors_table=segments_vectors_table,
                    segments_df=new_segments_store_df,
                    vectors=centroids_store_df[fields.vector].tolist(),
                    fields=fields,
                    id_type=id_type,
                )
                # The frame appended above is this same object, so the column
                # added here is part of the concatenated result.
                new_segments_store_df[fields.cluster_id] = centroids_store_df.loc[nearest_vectors_indexes, fields.cluster_id].tolist()
            new_segments_df = pd.concat(new_segments_store_dfs)

        write_df_to_table_by_batches(
            yt_client=yt_client,
            table_path=config.CLUSTERING_TABLE,
            df=new_segments_df,
            batch_size=500000,
        )
        yt_client.set_attribute(
            config.CLUSTERING_TABLE,
            'generate_date',
            date_helpers.get_today_date_string(),
        )
        copy_table_with_limit(
            yt_client=yt_client,
            input_table=config.CLUSTERING_TABLE,
            output_dir=config.DAILY_CLUSTERING_DIR,
            limit=config.DAILY_CLUSTERING_TO_KEEP,
        )


def update_centroids_users_count(transaction, yql_client, userid_clusterid_table, config, fields):
    """Run the YQL query that fills users counts in ``config.CENTROIDS_TABLE``.

    The query is rendered from ``update_centroids_users_count_template`` and
    executed inside ``transaction``.
    """
    template_vars = {
        'userid_clusterid_table': userid_clusterid_table,
        'centroids_table': config.CENTROIDS_TABLE,
        'centroids_columns': list(fields.centroids_fields.values()),
        'users_count_column': fields.users_count,
        'fields': fields,
    }
    rendered_query = templater.render_template(
        update_centroids_users_count_template,
        vars=template_vars,
    )
    yql_client.execute(
        query=rendered_query,
        transaction=str(transaction.transaction_id),
        title='YQL update_centroids_users_count',
    )
