import collections
import datetime
import time

from sepelib.core.constants import MINUTE_SECONDS
from walle.host_shard import HostShard, HostShardType
from walle.util.misc import get_expert_shards_count, get_existing_tiers
from walle.util.cron import CronJobKnownError

# Maximum age (in seconds, i.e. 15 minutes) of a shard's last ``processing_time``, measured
# against ``time.time()``; shards processed longer ago than this are reported as outdated.
DMC_SHARD_PROCESSING_BORDER = 15 * MINUTE_SECONDS


class ShardProcessedError(CronJobKnownError):
    """Raised when the DMC shard check finds shards that are missing or not processed recently."""


def _get_outdated_shard_info(shard):
    processing_time = datetime.datetime.fromtimestamp(shard["processing_time"]).strftime("%Y.%m.%d %H:%M:%S")
    return f"Shard(id={shard['shard_id']}, processed={processing_time})"


def _create_error_message(shard_type, tier, missed_shards_ids, outdated_shards):
    error = f"DMC {str(shard_type)} tier={tier} problems:\n"
    if missed_shards_ids:
        error += f"  some shards are missing: {', '.join(str(s) for s in list(missed_shards_ids)[:3])}\n"
    if outdated_shards:
        outdated_shards_info = (_get_outdated_shard_info(s) for s in list(missed_shards_ids)[:3])
        error += f"  some shards are outdated: {', '.join(outdated_shards_info)}\n"
    return error


def _start_dmc_shards_processing():
    """Check that every expected DMC shard has a DB record and was processed recently.

    For each tier from ``get_existing_tiers()`` and each ``HostShardType``, shard ids
    ``0 .. get_expert_shards_count(tier) - 1`` are expected to have a ``HostShard`` record.
    A record whose ``processing_time`` is more than ``DMC_SHARD_PROCESSING_BORDER`` seconds
    older than the check time is reported as outdated.

    :raises ShardProcessedError: if any shard is missing or outdated; the message joins the
        reports for all affected (tier, shard type) pairs, separated by blank lines.
    """
    # Group DB records as db_shards[tier][type] -> list of shard records.
    db_shards = collections.defaultdict(lambda: collections.defaultdict(list))
    for db_shard in HostShard.objects():
        db_shards[db_shard["tier"]][db_shard["type"]].append(db_shard)

    # One reference instant for the whole check, so every shard is judged against the same time.
    now = time.time()
    errors = []
    for tier in get_existing_tiers():
        # Records are looked up by str(tier); presumably tiers are stored as strings in the DB.
        tier_shards = db_shards[str(tier)]
        # The expected shard count depends only on the tier, so compute it once per tier.
        expected_shard_ids = range(get_expert_shards_count(tier))
        for shard_type in HostShardType:
            missed_shard_ids = set(expected_shard_ids)
            outdated_shards = []
            for db_shard in tier_shards[shard_type]:
                if now - db_shard["processing_time"] > DMC_SHARD_PROCESSING_BORDER:
                    outdated_shards.append(db_shard)
                missed_shard_ids.discard(db_shard["shard_id"])
            if missed_shard_ids or outdated_shards:
                errors.append(_create_error_message(shard_type, tier, missed_shard_ids, outdated_shards))
    if errors:
        raise ShardProcessedError("\n".join(errors))
