import os

from crypta.lib.python.custom_ml import training_config
from crypta.lib.python.custom_ml.tools.utils import get_industry_dir_from_industry_name
from crypta.lib.python.yt import yt_helpers
from crypta.profile.lib.date_helpers import (
    get_date_from_past,
    get_today_date_string,
)


# YQL query template (filled in with str.format) that rewrites `{model_folder}/glued`
# (INSERT ... WITH TRUNCATE) with its current rows plus the rows of
# `{new_sample_by_puid_table}`. The new rows carry id, id_type, segment_name and
# retro_date from the new table and are tagged with constant `partner`, `login`
# and `use_for_training` values.
#
# Placeholders: model_folder, new_sample_by_puid_table, partner, login,
# add_to_training, additional_fields.
#
# `{additional_fields}` is spliced in as raw SELECT-list text right after a comma.
# The caller in this file passes either an expression that ends with its own comma
# or an empty string, so the SELECT list always ends with a trailing comma before
# FROM — presumably accepted by YQL; confirm before changing the template.
adding_new_sample_to_glued = """
INSERT INTO `{model_folder}/glued`
WITH TRUNCATE

SELECT *
FROM `{model_folder}/glued`
    UNION ALL
SELECT
    id,
    id_type,
    segment_name,
    retro_date,
    '{partner}' AS partner,
    '{login}' AS `login`,
    {add_to_training} AS use_for_training,
    {additional_fields}
FROM `{new_sample_by_puid_table}`;
"""

# YQL query template (filled in with str.format) that merges a stored sample with a
# new one and writes the result to `{output_table}` (INSERT ... WITH TRUNCATE).
#
# Stages:
#   1. $union_samples: UNION ALL of (id, id_type, segment_name, retro_date) from
#      `{storage_sample_path}` and `{sample_to_add_path}`.
#   2. $union_samples (redefined on top of stage 1): one row per (id, id_type),
#      keeping the largest retro_date and the segment_name taken from the row
#      with that largest retro_date (MAX_BY).
#   3. $union_samples_with_ranks: rows are numbered inside each segment_name,
#      newest retro_date first.
#   4. Only rows with row_rank <= {max_sample_size} are kept, i.e. at most
#      max_sample_size rows per segment_name; output is ordered by id, id_type.
#
# Placeholders: storage_sample_path, sample_to_add_path, output_table,
# max_sample_size.
union_storage_and_new_samples_query_template = """
$union_samples = (
    SELECT
        id,
        id_type,
        segment_name,
        retro_date
    FROM `{storage_sample_path}`
UNION ALL
    SELECT
        id,
        id_type,
        segment_name,
        retro_date
    FROM `{sample_to_add_path}`
);
$union_samples = (
    SELECT
        id,
        id_type,
        MAX_BY(segment_name, retro_date) AS segment_name,
        MAX(retro_date) AS retro_date
    FROM $union_samples
    GROUP BY id, id_type
);
$union_samples_with_ranks = (
    SELECT
        id,
        id_type,
        segment_name,
        retro_date,
        ROW_NUMBER() OVER w AS row_rank
    FROM $union_samples
    WINDOW w AS (
        PARTITION BY segment_name
        ORDER BY retro_date DESC
    )
);
INSERT INTO `{output_table}`
WITH TRUNCATE
SELECT
    id,
    id_type,
    segment_name,
    retro_date
FROM $union_samples_with_ranks
WHERE row_rank <= {max_sample_size}
ORDER BY id, id_type;
"""


def merge_training_samples(
    yql_client,
    storage_sample_by_puid_path,
    new_sample_by_puid_path,
    combined_samples_path,
    transaction=None,
):
    """
    Merge the stored training sample with a new sample using one YQL query.

    The query is built from `union_storage_and_new_samples_query_template` and
    writes its result to `combined_samples_path`; the per-segment size limit is
    taken from `training_config.MAX_SAMPLE_SIZE`.

    :param yql_client: client whose `execute` method runs the query
    :param storage_sample_by_puid_path: path of the already stored sample table
    :param new_sample_by_puid_path: path of the sample table to add
    :param combined_samples_path: path of the output table
    :param transaction: optional object with a `transaction_id` attribute; when
        given, its id is passed to the query as a string, otherwise None is passed
    """
    merge_query = union_storage_and_new_samples_query_template.format(
        storage_sample_path=storage_sample_by_puid_path,
        sample_to_add_path=new_sample_by_puid_path,
        output_table=combined_samples_path,
        max_sample_size=training_config.MAX_SAMPLE_SIZE,
    )

    if transaction is None:
        transaction_id = None
    else:
        transaction_id = str(transaction.transaction_id)

    yql_client.execute(
        query=merge_query,
        title='YQL unite storage sample with new sample',
        transaction=transaction_id,
    )


def add_new_sample_for_existing_industry(
    yt_client,
    yql_client,
    industry_model_name,
    new_sample_by_puid_table,
    partner,
    login,
    logger,
    add_to_training=False,
    retrain_model=False,
    use_addition_date=False,
):
    """
    Add a new sample to the industry's glued table and, if requested, to its train sample.

    Steps (everything after the schema check happens inside a
    `yt_client.Transaction()` context, and both YQL queries receive its id):
      1. check that the new sample table has every column listed in
         `training_config.train_sample_columns`;
      2. rewrite `<industry dir>/glued` with its rows plus the new sample rows;
      3. if `add_to_training` is truthy, rebuild `<industry dir>/sample_by_puid`
         by merging it with the new sample (see `merge_training_samples`);
      4. if `retrain_model` is also truthy, overwrite the `model_training_date`
         attribute of `<industry dir>/sample_by_yuid` with
         `get_date_from_past(<current value, or today's date if missing>, days=7)`
         (presumably to make the model look stale so that it gets retrained).

    :param yt_client: YT client used for the schema lookup, transaction, row counts and attributes
    :param yql_client: client whose `execute` method runs the YQL queries
    :param industry_model_name: industry name resolved to a YT directory
        via `get_industry_dir_from_industry_name`
    :param new_sample_by_puid_table: path of the table with the new sample
    :param partner: value written to the `partner` column of the added glued rows
    :param login: value written to the `login` column of the added glued rows
    :param logger: logger used for progress messages
    :param add_to_training: value written to `use_for_training`; when falsy the
        function returns right after updating the glued table
    :param retrain_model: whether to rewrite `model_training_date` (only reached
        when `add_to_training` is truthy)
    :param use_addition_date: when truthy, a `date` column filled with
        `CurrentUtcDatetime()` is added to the new glued rows
    :raises AssertionError: if a required column is missing from the new sample table
    """
    industry_yt_dir = get_industry_dir_from_industry_name(industry_model_name)
    industry_training_table = os.path.join(industry_yt_dir, 'sample_by_puid')
    logger.info('Industry folder path: {}'.format(industry_yt_dir))

    new_sample_schema = yt_helpers.get_yt_schema_dict_from_table(
        yt_client=yt_client,
        table=new_sample_by_puid_table,
    )
    for column in training_config.train_sample_columns:
        # Explicit raise instead of `assert`: assert statements are removed under
        # `python -O`, which would silently disable this validation. The exception
        # type and message are kept as before for backward compatibility.
        if column not in new_sample_schema:
            raise AssertionError('Column {} must be in the new sample table'.format(column))

    with yt_client.Transaction() as transaction:
        yql_client.execute(
            query=adding_new_sample_to_glued.format(
                model_folder=industry_yt_dir,
                new_sample_by_puid_table=new_sample_by_puid_table,
                partner=partner,
                login=login,
                add_to_training=add_to_training,
                # The template already has a comma before this placeholder, so the
                # extra field carries its own trailing comma.
                additional_fields='CurrentUtcDatetime() AS `date`,' if use_addition_date else '',
            ),
            transaction=str(transaction.transaction_id),
        )

        if not add_to_training:
            logger.info('Training sample is not updated.')
            return

        initial_train_sample_size = yt_client.row_count(industry_training_table)
        # The training table is both an input and the output of the merge.
        merge_training_samples(
            yql_client=yql_client,
            storage_sample_by_puid_path=industry_training_table,
            new_sample_by_puid_path=new_sample_by_puid_table,
            combined_samples_path=industry_training_table,
            transaction=transaction,
        )
        new_train_sample_size = yt_client.row_count(industry_training_table)
        logger.info('Initial train sample size: {}'.format(initial_train_sample_size))
        logger.info('New train_sample_size: {}'.format(new_train_sample_size))

        if retrain_model:
            sample_by_yuid_table = os.path.join(industry_yt_dir, 'sample_by_yuid')

            generate_date = yt_client.get_attribute(
                sample_by_yuid_table,
                'model_training_date',
                default=get_today_date_string(),
            )
            generate_date = get_date_from_past(generate_date, days=7)
            yt_client.set_attribute(
                sample_by_yuid_table,
                'model_training_date',
                generate_date,
            )

            logger.info('generate_date for sample_by_yuid is updated.')
