import yt.wrapper as yt_wrapper
from yt.wrapper import ypath_join, ypath_split
from sklearn.metrics import roc_auc_score

from datacloud.dev_utils.yt import yt_utils
from datacloud.dev_utils.solomon import solomon_utils
from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.dev_utils.id_value.id_value_lib import normalize_phone, normalize_email, count_md5
from datacloud.dev_utils.time.utils import now_str
from datacloud.dev_utils.crypta import crypta_snapshot
from datacloud.model_applyer.tables.models_config_table import ApiModelsConfigTable
from datacloud.stability.batch_monitoring_log_table import BatchMonitoringLogTable


# Module-level logger named after this module.
logger = get_basic_logger(__name__)

# Crypta type identifiers, re-exported from the batch monitoring log table.
PURE_CRYPTA_TYPE = BatchMonitoringLogTable.PURE_CRYPTA_TYPE
BLENDED_CRYPTA_TYPE = BatchMonitoringLogTable.BLENDED_CRYPTA_TYPE

# YT folder layout for batch monitoring data.
BATCH_MONITORING_FOLDER = '//home/x-products/production/datacloud/stability/batch_monitoring'
# Input sample tables, one per batch name (see get_calculated_by_cid_table).
TEST_SAMPLES_FOLDER = ypath_join(BATCH_MONITORING_FOLDER, 'monitoring_samples')
# Output of the id_value -> cid replacement step, one table per batch name.
BY_CID_SAMPLES_FOLDER = ypath_join(BATCH_MONITORING_FOLDER, 'by_cid_samples')
# NOTE(review): the two paths below are not referenced in the visible part of
# this module; presumably they belong to the blended crypta flow - confirm.
BY_CID_BLENDED_SAMPLES_FOLDER = ypath_join(BATCH_MONITORING_FOLDER, 'by_cid_samples_blended')
ID_VALUE_BLENDS_PATH = '//home/x-products/production/crypta_v2/id_value_blends'

# NOTE(review): the workflow below takes its id_value -> cid tables from a
# crypta snapshot; these "last" paths are not referenced in the visible part
# of this module - confirm whether they are still needed.
CRYPTA_DB_LAST = '//home/x-products/production/crypta_v2/crypta_db_last'
HASH_PHONE_TO_CID = ypath_join(CRYPTA_DB_LAST, 'phone_id_value_to_cid')
HASH_EMAIL_TO_CID = ypath_join(CRYPTA_DB_LAST, 'email_id_value_to_cid')

# Prefix used in the titles of YT operations started by this module.
TAG = 'BATCH MONITORING'
# Folder in which temporary YT tables are created.
TMP_FOLDER = '//tmp'


def map_apply_md5_hash_email_phone(rec):
    """Mapper: normalize phone/email in place and add their md5 hashes.

    The record is mutated and yielded back. 'hash_phone' and 'hash_email'
    are always set; each is None when the corresponding source field is
    absent or None. 'hash_email' is also None when email normalization
    returns None.

    NOTE(review): the normalized phone is hashed without a None check,
    unlike the email branch - confirm normalize_phone never returns None.
    """
    if rec.get('phone') is None:
        rec['hash_phone'] = None
    else:
        # The phone is stringified before normalization.
        rec['phone'] = normalize_phone(str(rec['phone']))
        rec['hash_phone'] = count_md5(rec['phone'])

    email_hash = None
    if rec.get('email') is not None:
        rec['email'] = normalize_email(rec['email'])
        if rec['email'] is not None:
            email_hash = count_md5(rec['email'])
    rec['hash_email'] = email_hash

    yield rec


@yt_wrapper.with_context
class ReduceTrainByCid:
    """Join-reducer that replaces a sample's id_value with matching cids.

    Rows coming from inputs with index below ``num_crypta_tables - 1`` are
    treated as id_value -> cid lookup rows; rows from the remaining input
    are samples. As used in get_calculated_by_cid_table, the constructor
    receives the total number of reduce inputs (lookup tables plus the
    sample table), so only the last input is the sample table.
    """

    def __init__(self, num_crypta_tables):
        self.num_crypta_tables = num_crypta_tables

    def __call__(self, key, recs, context):
        """Yield one output row per (sample, cid) pair for one id_value.

        A sample with no collected cid is still emitted once, with
        'cid' set to None. This relies on lookup rows being seen before
        sample rows within the group.
        """
        collected_cids = []
        for row in recs:
            if context.table_index < self.num_crypta_tables - 1:
                # Lookup row: remember its cid for the samples that follow.
                collected_cids.append(row['cid'])
                continue

            # Sample row: fan out over known cids, or emit a single
            # cid-less row when nothing matched.
            for matched_cid in (collected_cids or [None]):
                yield {
                    'external_id': row['external_id'],
                    'id_value': key['id_value'],
                    'cid': matched_cid,
                    'target': row['target']
                }


@yt_wrapper.with_context
def join_scores(key, recs, context):
    """Join-reducer attaching scores from input 0 to rows of other inputs.

    Rows from table index 0 contribute their 'score'. Every other row is
    emitted once per collected score with a 'score' field added, or
    unchanged when no score was collected for the group.
    """
    collected_scores = []
    for row in recs:
        if context.table_index == 0:
            collected_scores.append(row['score'])
            continue

        if not collected_scores:
            # No score for this key: pass the row through untouched.
            yield row
            continue

        for value in collected_scores:
            yield dict(
                score=value,
                **row
            )


def external_id2score(key, recs):
    """Reducer: collapse all rows of one external_id into a single row.

    Args:
        key: reduce key; must contain 'external_id'.
        recs: iterable of rows for that key; each row must contain
            'target' and may contain 'score'.

    Yields:
        One dict with 'external_id', 'target' (taken from the last row of
        the group) and, when at least one row has a non-None score,
        'score' set to the maximum of those scores. Nothing is yielded
        for an empty group.
    """
    scores = []
    last_rec = None
    for rec in recs:
        # Track the last row explicitly instead of relying on the loop
        # variable leaking out of the loop (which fails on empty input).
        last_rec = rec
        score = rec.get('score')
        if score is not None:
            scores.append(score)

    if last_rec is None:
        # Empty group: there is no target to report.
        return

    rec2yield = {
        'external_id': key['external_id'],
        'target': last_rec['target']
    }
    if scores:
        rec2yield['score'] = max(scores)
    yield rec2yield


def calc_auc_hit(table, yt_client):
    """Compute ROC AUC and hit rate over a table of scored samples.

    Rows whose target is None, or whose stringified target is one of
    'nan', '-1.0', '-1', are ignored. Among the remaining rows, a row
    with a non-None 'score' counts as a hit and contributes to the AUC;
    a row without a score counts as a miss.

    Args:
        table: table path passed to ``yt_client.read_table``.
        yt_client: object providing ``read_table(table)`` that iterates
            over dict-like rows containing 'target' and optionally 'score'.

    Returns:
        Tuple ``(auc, hit_rate)``. ``auc`` is 0.0 when the hit rows carry
        fewer than two distinct target values. ``hit_rate`` is
        hits / (hits + misses), or 0.0 when no row has a usable target.
    """
    ignored_target_strings = ('nan', '-1.0', '-1')
    scores = []
    targets = []
    no_hit = 0
    for rec in yt_client.read_table(table):
        target = rec['target']
        # None has to be checked on the raw value: str(None) is 'None',
        # so it can never match a None placed in a list of strings.
        if target is None or str(target) in ignored_target_strings:
            continue

        score = rec.get('score')
        if score is None:
            no_hit += 1
            continue
        scores.append(score)
        targets.append(target)

    if len(set(targets)) > 1:
        auc = roc_auc_score(targets, scores)
    else:
        # AUC is undefined with a single class; report 0.0 by convention.
        auc = 0.0

    hit = len(scores)
    total = hit + no_hit
    # Guard against tables with no usable targets (avoids ZeroDivisionError).
    hit_rate = float(hit) / total if total else 0.0
    return auc, hit_rate


def make_sensors_batch_monitoring(partner_id, score_name, batch_name, timestamp, value, value_type='AUC'):
    """Build one sensor dict for a batch monitoring value.

    The result has 'labels' (partner_id, score_name, batch_name and
    'type' set to value_type), 'ts' set to timestamp and 'value'.
    """
    labels = dict(
        partner_id=partner_id,
        score_name=score_name,
        batch_name=batch_name,
        type=value_type,
    )
    return dict(labels=labels, ts=timestamp, value=value)


def load_batch_results_to_solomon(row):
    """Post the 'AUC' and 'hit' values of a monitoring row to Solomon.

    ``row`` must contain 'partner_id', 'score_name', 'batch_name',
    'score_date', 'AUC' and 'hit'. The sensor timestamp is derived from
    'score_date' via solomon_utils.str2ts.
    """
    # Each metric name doubles as the row key and the sensor type label.
    sensors = [
        make_sensors_batch_monitoring(
            row['partner_id'],
            row['score_name'],
            row['batch_name'],
            solomon_utils.str2ts(row['score_date']),
            row[metric],
            metric,
        )
        for metric in ('AUC', 'hit')
    ]
    solomon_utils.post_sensors_to_solomon(
        'datacloud', 'score', 'batch-monitoring', sensors, wait=0)


def get_calculated_by_cid_table(yt_client, batch_name, crypta_type, crypta_id_tables):
    """Build the by-cid version of a monitoring sample and return its path.

    Join-reduces the sample table TEST_SAMPLES_FOLDER/<batch_name> with the
    given id_value -> cid tables by 'id_value', writes the result to
    BY_CID_SAMPLES_FOLDER/<batch_name> and sorts it by 'cid'.

    ``crypta_type`` is accepted but not used by this function.

    NOTE: an earlier shortcut that reused an existing by-cid table when it
    was newer than the crypta tables is disabled; the table is always
    recalculated.
    """
    output_path = ypath_join(BY_CID_SAMPLES_FOLDER, batch_name)
    sample_path = ypath_join(TEST_SAMPLES_FOLDER, batch_name)

    # Lookup tables go first as foreign inputs; the sample table is the
    # single primary input and comes last.
    inputs = [
        yt_client.TablePath(lookup_table, attributes={'foreign': True})
        for lookup_table in crypta_id_tables
    ]
    inputs.append(yt_client.TablePath(sample_path, attributes={'primary': True}))

    # The reducer receives the total input count (lookups + sample table).
    reducer = ReduceTrainByCid(len(inputs))
    yt_client.run_reduce(
        reducer,
        inputs,
        output_path,
        reduce_by='id_value',
        join_by='id_value',
        spec={'title': '[{}] replace id_value by cid'.format(TAG)}
    )
    yt_client.run_sort(
        output_path,
        sort_by='cid',
        spec={'title': '[{}] sort by cid'.format(TAG)}
    )
    return output_path


def calc_auc_hit_workflow(yt_client, batch_name, score_table, partner_id, score_name,
                          score_date, score_type, crypta_type, ext2score=None):
    """Run the full batch monitoring calculation for one batch.

    Steps, executed inside one YT transaction:
      1. replace id_value by cid in the batch sample;
      2. join scores from ``score_table`` by cid and sort by external_id;
      3. reduce to one row per external_id;
      4. compute AUC and hit rate.
    The result is then logged to BatchMonitoringLogTable and, when
    ``score_date`` passes solomon_utils.is_date, posted to Solomon.

    Only PURE_CRYPTA_TYPE is accepted for ``crypta_type``. When
    ``ext2score`` is None, the step 3 result overwrites the temporary
    table; otherwise it is written to ``ext2score``.

    Returns:
        Tuple ``(auc, hit)``.

    Raises:
        AssertionError: unsupported crypta_type or missing crypta snapshot.
    """
    assert crypta_type == PURE_CRYPTA_TYPE, 'Only PURE_CRYPTA_TYPE is supported'
    snapshot = crypta_snapshot.get_snapshot(yt_client, score_date)
    assert snapshot, 'Crypta snapshot for {} doesnt exist :/'.format(score_date)

    id_value_to_cid_tables = [
        snapshot.phone_id_value_to_cid,
        snapshot.email_id_value_to_cid,
    ]
    with yt_client.Transaction(), yt_client.TempTable(TMP_FOLDER) as joined_table:
        logger.info('Step 1: Replace id_value by cid')
        by_cid_path = get_calculated_by_cid_table(
            yt_client, batch_name, crypta_type, id_value_to_cid_tables)
        logger.info('Step 1: Done')

        logger.info('Step 2: Join score by cid')
        # Only the columns needed for the join are read from the score table.
        scores_input = yt_client.TablePath(
            score_table, attributes={'columns': ['cid', 'score'], 'foreign': True})
        samples_input = yt_client.TablePath(by_cid_path, attributes={'primary': True})
        yt_client.run_reduce(
            join_scores,
            [scores_input, samples_input],
            joined_table,
            reduce_by='cid',
            join_by='cid',
            spec={'title': '[{}] join score by cid'.format(TAG)}
        )
        yt_client.run_sort(
            joined_table,
            sort_by='external_id',
            spec={'title': '[{}] sort by external_id'.format(TAG)}
        )
        logger.info('Step 2: Done')

        logger.info('Step 3: Make score by external_id')
        # Without an explicit destination the temp table is reused as output.
        result_table = joined_table if ext2score is None else ext2score
        yt_client.run_reduce(
            external_id2score,
            joined_table,
            result_table,
            reduce_by='external_id',
            spec={'title': '[{}] make final external_id to score table'.format(TAG)}
        )
        logger.info('Step 3: Done')

        logger.info('Step 4: Calc AUC and hit')
        auc, hit = calc_auc_hit(result_table, yt_client)
        logger.info('Step 4: Done')

    logger.info('AUC: {}, HIT: {}'.format(auc, hit))
    log_table = BatchMonitoringLogTable()
    log_row = dict(
        test_batch=ypath_join(TEST_SAMPLES_FOLDER, batch_name),
        score_table=score_table,
        hit=hit,
        AUC=auc,
        timestamp=now_str('%Y-%m-%dT%H:%M:%SZ'),
        partner_id=partner_id,
        score_name=score_name,
        score_date=score_date,
        score_type=score_type,
        crypta_type=crypta_type,
        batch_name=batch_name,
    )
    log_table.add_record(**log_row)
    if solomon_utils.is_date(log_row['score_date']):
        load_batch_results_to_solomon(log_row)
    return auc, hit


def run_calc_auc_hit_workflow(yt_client, score_table, partner_id, score_name, score_date,
                              score_type):
    """Run batch monitoring for every batch configured for a model.

    Batch names come from the 'batch_monitoring' list in the 'additional'
    section of the model config found via ApiModelsConfigTable. For each
    batch the pure-crypta workflow is run; when the config has
    ApiModelsConfigTable.CRYPTA_BLEND_ON_KEY set, the blended-crypta
    workflow is run as well.

    NOTE(review): calc_auc_hit_workflow asserts that crypta_type equals
    PURE_CRYPTA_TYPE, so the blended call raises AssertionError as the
    code stands - confirm which side is intended.

    Raises:
        AssertionError: no config found for (partner_id, score_name).
    """
    model_config = ApiModelsConfigTable().get_model(partner_id, score_name)
    assert model_config is not None, 'Not found config for {partner_id}, {score_name}'.format(
        partner_id=partner_id,
        score_name=score_name
    )
    additional = model_config['additional']
    monitored_batches = [
        entry['batch_name'] for entry in additional.get('batch_monitoring', [])
    ]

    # Arguments shared by every workflow invocation below.
    shared_kwargs = dict(
        yt_client=yt_client,
        score_table=score_table,
        partner_id=partner_id,
        score_name=score_name,
        score_date=score_date,
        score_type=score_type,
    )
    for batch_name in monitored_batches:
        logger.info('Calculating batch monitoring for %s', batch_name)
        calc_auc_hit_workflow(
            batch_name=batch_name, crypta_type=PURE_CRYPTA_TYPE, **shared_kwargs)

        if not additional.get(ApiModelsConfigTable.CRYPTA_BLEND_ON_KEY, False):
            continue
        logger.info('Calculating batch for blended crypta')
        calc_auc_hit_workflow(
            batch_name=batch_name, crypta_type=BLENDED_CRYPTA_TYPE, **shared_kwargs)
