import os
from yt.wrapper import ypath_join
from datacloud.config.token_names import AUDIENCE_TOKEN_NAME
from datacloud.config.robots import AUDIENCE_ROBOT_NAME
from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.dev_utils.time import utils as time_utils
from datacloud.dev_utils.time import patterns
from datacloud.dev_utils.yt.take_part import Interval
from datacloud.dev_utils.yt import yt_utils
from datacloud.audience.lib import utils
from datacloud.audience.lib.score_info import Score
from datacloud.audience.lib.api import AudienceApi
from datacloud.audience.lib.segment_table import SegmentTable


logger = get_basic_logger(__name__)
# YT folder where prepared audience tables are written before being split into
# segment parts and uploaded (upload_audience creates it on demand).
DEFAULT_READY_TO_UPLOAD_FOLDER = '//projects/scoring/tmp/prod-tmp/audience'
# Upper bound on the number of records uploaded per segment, to keep segments
# from growing too large: every downloaded segment is truncated to this many
# records, and any records beyond the limit are dropped without logging.
# NOTE(review): the original trailing note mentions 16500000, presumably an
# earlier value of this limit — confirm before changing.
RECORDS_TO_TAKE = 15000000


def upload_audience(yt_client, audience, date_str):
    """Prepare the audience table for ``date_str`` and upload it as segments.

    The prepared table is written into ``DEFAULT_READY_TO_UPLOAD_FOLDER``
    (created if it does not exist). It is split into ``audience['n_segments']``
    parts, and every part is uploaded via ``upload_parts``.

    Args:
        yt_client: YT client used for all table operations.
        audience: mapping with at least the keys ``partner_id``,
            ``score_name``, ``interval_start``, ``interval_end`` and
            ``n_segments``; ``upload_parts`` additionally reads
            ``audience_name``.
        date_str: date string passed to the table-building helpers.

    Returns:
        The result of ``upload_parts``: True when every segment was created
        or updated successfully, False otherwise.

    Raises:
        AssertionError: if the table was split into a different number of
            parts than ``audience['n_segments']``.
    """
    score = Score(audience['partner_id'], audience['score_name'])
    interval = Interval(audience['interval_start'], audience['interval_end'])

    if not yt_client.exists(DEFAULT_READY_TO_UPLOAD_FOLDER):
        yt_utils.create_folders([DEFAULT_READY_TO_UPLOAD_FOLDER], yt_client)

    audience_table = ypath_join(
        DEFAULT_READY_TO_UPLOAD_FOLDER,
        utils.build_audience_table_name(score, date_str, interval),
    )
    # prepare segment for upload
    utils.prepare_audience_table(yt_client, score, date_str, interval, audience_table)
    n_segments = audience['n_segments']
    parts = utils.split_table(yt_client, audience_table, n_segments)

    # Explicit check instead of ``assert``: an ``assert`` statement is removed
    # under ``python -O``, which would let a mis-split table be uploaded.
    # AssertionError is kept so callers still see the same exception type.
    if len(parts) != n_segments:
        raise AssertionError(
            'wrong number of table parts: expected {}, got {}'.format(
                n_segments, len(parts)))
    return upload_parts(yt_client, audience, parts)


def upload_parts(yt_client, audience, table_parts, force=False):
    """Create or replace one Audience segment per table part.

    Parts are numbered from 1 in iteration order. If ``SegmentTable`` already
    has a record for ``(audience_name, number)``, that segment's contents are
    replaced and its ``update_date`` is refreshed. Otherwise a new segment is
    created and registered. A failure on one part is logged and does not stop
    the remaining parts from being processed.

    Args:
        yt_client: YT client, shared with ``AudienceApi`` and ``SegmentTable``.
        audience: mapping with at least an ``audience_name`` key.
        table_parts: iterable of YT tables, one per segment.
        force: currently unused; kept for backward compatibility.

    Returns:
        True if every part was uploaded successfully, False otherwise.
    """
    # Token comes from the environment; it is None when the variable is unset.
    token = os.environ.get(AUDIENCE_TOKEN_NAME)
    api = AudienceApi(yt_client, AUDIENCE_ROBOT_NAME, token)
    audience_name = audience['audience_name']
    segment_table = SegmentTable(yt_client)
    current_date = time_utils.now_str(patterns.FMT_DATE)

    is_all_ok = True
    for segment_number, table in enumerate(table_parts, start=1):
        status = _upload_segment_part(
            yt_client, api, segment_table, audience_name, segment_number,
            table, current_date)
        # Keep going after a failure so the other segments still get uploaded.
        is_all_ok = is_all_ok and status
    return is_all_ok


def _upload_segment_part(yt_client, api, segment_table, audience_name,
                         segment_number, table, current_date):
    """Upload one table part as segment ``segment_number``; return success flag."""
    record = segment_table.get_segment(audience_name, segment_number)
    # Truncate to RECORDS_TO_TAKE to keep the segment from growing too large.
    data = utils.download_segment(yt_client, table)[:RECORDS_TO_TAKE]
    if record:
        status, resp = update_segment(
            api, record['segment_id'], data, audience_name, segment_number)
        if status:
            record['update_date'] = current_date
            segment_table.insert_records([record])
        else:
            _log_failed_response('Failed to update segment', resp)
    else:
        status, segment_id, resp = create_segment(
            api, data, audience_name, segment_number)
        if status:
            segment_table.add_segment(audience_name, segment_number, segment_id,
                                      current_date, current_date)
        else:
            _log_failed_response('Failed to create segment', resp)
    return status


def _log_failed_response(message, resp):
    """Log a failed API response at WARNING level, with its JSON body if it has one."""
    logger.warning('%s: %s', message, resp)
    # NOTE(review): assumes resp behaves like requests.Response, whose .json()
    # raises a ValueError subclass when the body is not valid JSON.
    try:
        body = resp.json()
    except ValueError:
        # Error responses are not guaranteed to carry JSON; the logging itself
        # must not abort the upload of the remaining segments.
        logger.warning('%s: response body is not valid JSON', message)
    else:
        logger.warning('%s: response body: %s', message, body)


def update_segment(api, segment_id, data, audience_name, segment_number):
    """Replace the contents of an existing Audience segment with ``data``.

    Returns:
        A ``(ok, resp)`` pair. ``resp`` is the response of
        ``api.replace_id_values``, and ``ok`` is True only when its status
        code is 200.
    """
    logger.info('Update segment {}/{}'.format(audience_name, segment_number))
    response = api.replace_id_values(segment_id, data)
    succeeded = response.status_code == 200
    return succeeded, response


def create_segment(api, data, audience_name, segment_number):
    """Upload ``data`` as a new Audience segment and confirm it.

    The segment is confirmed under the name ``<audience_name>-<segment_number>``.

    Returns:
        A ``(ok, segment_id, resp)`` triple.
        On success, ``ok`` is True, ``segment_id`` is the id extracted from
        the upload response, and ``resp`` is that upload response.
        On failure, ``ok`` is False and ``segment_id`` is None. ``resp`` is
        the response of the step that actually failed (upload or confirm),
        so callers log the real error.
    """
    logger.info('Create segment {}/{}'.format(audience_name, segment_number))
    upload_resp = api.upload_id_values(data)
    # A failed upload has no segment to confirm; report it directly instead of
    # extracting an id from an error body.
    if not 200 <= upload_resp.status_code < 300:
        return False, None, upload_resp
    segment_id = api.get_segment_id_from_response(upload_resp)
    confirm_resp = api.confirm(segment_id, audience_name + '-' + str(segment_number))
    if confirm_resp.status_code != 200:
        # Return the confirm response: the upload response describes a
        # successful step and would hide the actual error.
        return False, None, confirm_resp
    return True, segment_id, upload_resp
