from yt.wrapper import ypath_join

from datacloud.dev_utils.yt.yt_utils import get_yt_client, create_folders
from datacloud.dev_utils.time.utils import assert_date_str
from datacloud.dev_utils.time.utils import now_str
from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.dev_utils.transfer.yt_to_ydb.ydb_constants import YDB_LOCK_NAME
from datacloud.dev_utils.status_db.db import lock_manager, LockError
from datacloud.dev_utils.status_db.task import Task, Status
from datacloud.dev_utils.crypta import crypta_snapshot
from datacloud.dev_utils.status_db.db import StatusDB
from datacloud.model_applyer.tables.models_config_table import ApiModelsConfigTable
from datacloud.model_applyer.lib.model_config import ModelConfig
from datacloud.model_applyer.utils.score_transfer import (
    HashedIdReducer, transfer_score_to_ydb
)
from datacloud.model_applyer.utils.helpers import PURE_SCORE_TYPE, ALL_SCORE_TYPES
from datacloud.stability.score_acceptance.assert_stability import (
    StabilityNotReadyException, assert_psi_score, assert_psi_features,
    assert_batch_monitoring, DEFAULT_PSI_THRESH
)
from datacloud.stability.score_acceptance.score_acc_log_table import ScoreAcceptanceLogTable
from datacloud.stability.score_acceptance.score_acc_to_solomon import post_results

logger = get_basic_logger(__name__)

# Separator used to build status-DB task keys: "<partner_id>#<score_name>#<date_str>".
COMMAND_SEPRARATOR = '#'
# Correctly spelled alias. The misspelled name above is kept for backward compatibility,
# since other modules may import it.
COMMAND_SEPARATOR = COMMAND_SEPRARATOR
# Root YT folder for intermediate per-model "hashed_id -> score" tables.
TMP_FOLDER = '//projects/scoring/tmp/prod-tmp/models'
# Name of the node under a model's score_dir that is re-pointed at the freshly transferred
# score table. NOTE(review): "stabiity" is misspelled, but this string is a YT node name;
# changing it here would break the existence check in transfer_score unless the node is
# renamed as well.
FEATURES_FOR_STAB = 'features-for-stabiity-computation'
# Message recorded in the acceptance log / Solomon when every stability check passes.
CHECK_PASSED_MESSAGE = 'OK'


def check_to_yt_and_solomon(yt_client, partner_id, score_name, score_date, check_passed, message):
    """Persist one score-acceptance result to YT, then publish it to Solomon.

    The same positional record (partner_id, score_name, score_date, check_passed, message)
    is first appended to ScoreAcceptanceLogTable and then passed to post_results.
    """
    record = (partner_id, score_name, score_date, check_passed, message)
    ScoreAcceptanceLogTable(yt_client=yt_client).add_record(*record)
    post_results(*record)


def assert_stability(task, PSI_threshold=DEFAULT_PSI_THRESH):
    """Run the stability acceptance checks for a freshly computed score.

    Reads ``partner_id``, ``score_name`` and ``date_str`` from ``task.data``, then runs the
    PSI score check, the PSI features check and the batch-monitoring check.

    Args:
        task: status-DB task whose ``data`` holds ``partner_id``, ``score_name``, ``date_str``.
        PSI_threshold: threshold forwarded to the PSI score and PSI feature checks.

    Returns:
        A list of tasks for the status DB:
          * ``[]`` if one of the stability inputs is not ready yet (task is left as is);
          * ``[done]`` if a check failed (the failure is logged to YT and Solomon);
          * ``[done, transfer_score READY task]`` if every check passed.

    Raises:
        ValueError: if the model's ``score_type`` is unknown or is not the pure score type.
    """
    partner_id = task.data['partner_id']
    score_name = task.data['score_name']
    date_str = task.data['date_str']
    assert_date_str(date_str)

    logger.info('Started stability check in %s %s %s', partner_id, score_name, date_str)

    yt_client = get_yt_client()

    current_time = now_str()
    done_result = task.make_done()
    # Follow-up task that triggers the transfer of this score once the checks pass.
    transfer_key = COMMAND_SEPRARATOR.join([partner_id, score_name, date_str])
    transfer_ready_result = Task(
        'transfer_score', transfer_key, Status.READY,
        {'date_str': date_str, 'partner_id': partner_id, 'score_name': score_name},
        current_time, current_time)

    am_config_table = ApiModelsConfigTable()
    # NOTE(review): transfer_score uses get_model_or_raise; if get_model can return None
    # for a missing model, the subscript below fails with TypeError — confirm.
    model = am_config_table.get_model(partner_id=partner_id, score_name=score_name)
    score_type = model['additional'].get('score_type', PURE_SCORE_TYPE)
    if score_type not in ALL_SCORE_TYPES:
        raise ValueError('Unknown score type {}'.format(score_type))
    if score_type != PURE_SCORE_TYPE:
        raise ValueError('Score blending is not supported yet!')

    try:
        # NOTE(review): failures are signalled via AssertionError; if these helpers use bare
        # `assert` statements, they are stripped under `python -O` and every check would pass.
        assert_psi_score(yt_client, score_name, date_str, score_type, PSI_threshold)
        assert_psi_features(yt_client, score_name, date_str, PSI_threshold)
        assert_batch_monitoring(yt_client, score_name, date_str, score_type, partner_id)
    except StabilityNotReadyException as snre:
        # Inputs not ready: return nothing so the task is not marked done.
        logger.error('One of stabilities isn\'t ready yet!')
        logger.error(snre)
        return []
    except AssertionError as ae:
        logger.error('One of stability checks failed!')
        logger.error(ae)
        check_to_yt_and_solomon(yt_client, partner_id, score_name, date_str, False, str(ae))
        return [done_result]

    logger.info('All checks passed for %s %s %s', partner_id, score_name, date_str)
    check_to_yt_and_solomon(yt_client, partner_id, score_name, date_str, True,
                            CHECK_PASSED_MESSAGE)
    return [done_result, transfer_ready_result]


def check_newer_task_exists(task, table_path='//home/x-products/production/new-status-db'):
    """Return True if a newer READY task exists for the same partner and score.

    Scans READY tasks of ``task.program`` in the status DB. A task counts as newer when it has
    the same ``partner_id`` and ``score_name`` and a strictly greater ``date_str``.

    Args:
        task: status-DB task whose ``data`` holds ``partner_id``, ``score_name``, ``date_str``.
        table_path: YT path of the status DB table (defaults to the production table).

    Note:
        ``date_str`` values are compared as strings. This orders correctly only for
        zero-padded ISO-like dates — presumably what ``assert_date_str`` enforces; confirm.
    """
    status_db = StatusDB(table_path=table_path)
    partner_id = task.data['partner_id']
    score_name = task.data['score_name']
    date_str = task.data['date_str']
    return any(
        other.data['partner_id'] == partner_id
        and other.data['score_name'] == score_name
        and other.data['date_str'] > date_str
        for other in status_db.get_tasks_with_status(task.program, Status.READY)
    )


def _hashed_id_table_path(yt_client, partner_id, score_name, date_str):
    """Create TMP_FOLDER/<partner_id>/<score_name> and return the typed TablePath for date_str."""
    model_folder = ypath_join(TMP_FOLDER, partner_id, score_name)
    create_folders(model_folder, yt_client=yt_client)
    return yt_client.TablePath(
        ypath_join(model_folder, date_str),
        attributes={
            'schema': [
                {'name': 'hashed_id', 'type': 'uint64'},
                {'name': 'score', 'type': 'double'},
            ],
            'optimize_for': 'scan',
            'compression_codec': 'brotli_3'
        }
    )


def _prepare_model_by_hashed_id(yt_client, config_rec, score_path, cid_to_all_table,
                                model_by_hashed_id, score_name, date_str):
    """Reduce the score table with cid_to_all into model_by_hashed_id unless it is non-empty."""
    if yt_client.exists(model_by_hashed_id) and yt_client.row_count(model_by_hashed_id):
        logger.info('Model by hashed_id already calculated, launching transfer...')
        return

    hash_cid = config_rec['additional'].get('cookie_sync_on', False)
    # The reduce runs inside a transaction, so the output table is written atomically.
    with yt_client.Transaction():
        yt_client.run_reduce(
            HashedIdReducer(hash_cid),
            [
                score_path,
                cid_to_all_table
            ],
            model_by_hashed_id,
            reduce_by=['cid'],
            spec={'title': '[SCORE TRANSFER] Prepare {} {}'.format(
                score_name,
                date_str
            )}
        )


def transfer_score(task):
    """Transfer a model's score for one date from YT to YDB.

    Steps:
      1. Skip the task if a newer READY task for the same partner/score exists.
      2. Stop unless ``transfer_on`` is set in the model config or ``task.data['force']`` is true.
      3. Build (or reuse, if non-empty) a ``hashed_id -> score`` table under TMP_FOLDER.
      4. Under the YDB lock, upload it to YDB. If the FEATURES_FOR_STAB link exists, re-point
         it at this date's score table.

    Args:
        task: status-DB task whose ``data`` holds ``partner_id``, ``score_name``, ``date_str``
            and optionally ``force``.

    Returns:
        A list of tasks for the status DB:
          * ``[skipped]`` if a newer score task exists;
          * ``[done]`` if transfer is disabled in the config and not forced;
          * ``[done, mini_batch_monitoring READY task]`` after a successful transfer;
          * ``[]`` if the YDB lock is held (LockError), leaving the task unchanged.
    """
    partner_id = task.data['partner_id']
    score_name = task.data['score_name']
    date_str = task.data['date_str']
    assert_date_str(date_str)

    if check_newer_task_exists(task):
        # `warn` is a deprecated alias of `warning`; lazy %-args avoid eager formatting.
        logger.warning('Task %s is too old, new score exists. Skipped.', task)
        return [task.make_done(new_status=Status.SKIPPED)]

    logger.info('Started transfer of %s %s %s', partner_id, score_name, date_str)
    config_rec = ApiModelsConfigTable().get_model_or_raise(partner_id, score_name)

    if task.data.get('force', False):
        logger.info('Transfer is forced. No transfer_on check in config!')
    elif not config_rec['additional'].get('transfer_on', False):
        logger.info('Transfer isn\'t turned on. Aborting')
        return [task.make_done()]

    yt_client = get_yt_client()

    model_by_hashed_id = _hashed_id_table_path(yt_client, partner_id, score_name, date_str)
    model_config = ModelConfig.from_json(config_rec)

    cid_to_all_table = crypta_snapshot.get_snapshot(yt_client, date_str).cid_to_all
    score_path = ypath_join(model_config.score_dir, date_str)

    _prepare_model_by_hashed_id(yt_client, config_rec, score_path, cid_to_all_table,
                                model_by_hashed_id, score_name, date_str)

    try:
        with lock_manager(YDB_LOCK_NAME):
            transfer_score_to_ydb(yt_client=yt_client, model_by_hashed_id=model_by_hashed_id,
                                  partner_id=partner_id, score_name=score_name,
                                  date_str=date_str)
            # Re-point the stability link only if it already exists; it is never created here.
            stability_link_path = ypath_join(model_config.score_dir, FEATURES_FOR_STAB)
            if yt_client.exists(stability_link_path):
                yt_client.link(score_path, stability_link_path, force=True)

            current_time = now_str()
            mini_batch_key = COMMAND_SEPRARATOR.join([partner_id, score_name, date_str])
            return [
                task.make_done(),
                Task(
                    'mini_batch_monitoring', mini_batch_key, Status.READY,
                    {'date_str': date_str, 'partner_id': partner_id, 'score_name': score_name},
                    current_time, current_time)
            ]
    except LockError:
        logger.error('Transfer is locked!')
    return []
