import os
from itertools import chain
import random

from django.db import transaction


class StaffHostManager:
    """Order a comma-separated list of DB hosts by data-center proximity.

    The data center (DC) of a host is the first three characters of its
    name, upper-cased. Hosts in the preferred DC come first and the rest
    follow. Each of the two groups is shuffled independently.

    The preferred DC is chosen as follows:

    * if the ``DEPLOY_NODE_DC`` environment variable names a DC that has at
      least one host, that DC is preferred;
    * otherwise ``common_dc`` hosts are put first with probability
      ``common_dc_weight``, and the remaining hosts are put first otherwise.
    """

    def __init__(self, hosts: str, common_dc: str, common_dc_weight: float):
        """Build the randomized host order.

        :param hosts: comma-separated host names; surrounding whitespace is
            ignored and empty entries are skipped.
        :param common_dc: fallback DC used when the current DC (from
            ``DEPLOY_NODE_DC``) is unset or has no hosts; case-insensitive.
        :param common_dc_weight: probability (0..1) of putting ``common_dc``
            hosts first in the fallback case.
        """
        # An unset variable used to crash on None.upper(); treat it as
        # "no current DC" so we fall back to common_dc instead.
        current_dc = os.getenv('DEPLOY_NODE_DC', '').strip().upper()
        self._db_mapping = {}

        for raw_host in hosts.split(','):
            host = raw_host.strip()
            # Skip empty entries (e.g. trailing commas) so they don't create
            # a bogus DC / empty host in the output.
            if host:
                self._db_mapping.setdefault(self._get_dc_by_host(host), []).append(host)

        if current_dc in self._db_mapping:
            priority_hosts, other_hosts = self.prioritize_db_hosts(current_dc)
        else:
            priority_hosts, other_hosts = self.prioritize_db_hosts(common_dc)
            # random.random() is in [0, 1), so common_dc stays first with
            # probability common_dc_weight.
            if random.random() > common_dc_weight:
                priority_hosts, other_hosts = other_hosts, priority_hosts

        self._sorted_hosts = self._shuffle(priority_hosts) + self._shuffle(other_hosts)

    def prioritize_db_hosts(self, target_dc):
        """Split the known hosts into those in ``target_dc`` and all others.

        :param target_dc: DC code, compared case-insensitively.
        :return: ``(priority_hosts, other_hosts)`` as new lists.
            ``priority_hosts`` is empty when no host belongs to
            ``target_dc``. ``other_hosts`` follows the order in which their
            DCs first appeared in the input.
        """
        # DC keys are upper-cased by _get_dc_by_host; normalize to match.
        target_dc = target_dc.upper()
        # Copy so callers cannot mutate the internal mapping; .get avoids a
        # KeyError when the requested DC has no hosts at all.
        priority_hosts = list(self._db_mapping.get(target_dc, ()))
        other_hosts = list(chain.from_iterable(
            dc_hosts for dc, dc_hosts in self._db_mapping.items() if dc != target_dc
        ))
        return priority_hosts, other_hosts

    @staticmethod
    def _shuffle(same_dc_hosts):
        """Return a new, randomly ordered list of ``same_dc_hosts``."""
        # random.sample with k == len does not mutate its input.
        return random.sample(same_dc_hosts, len(same_dc_hosts))

    @staticmethod
    def _get_dc_by_host(host: str) -> str:
        """Return the DC code of ``host``: its first three characters, upper-cased."""
        return host[:3].upper()

    @property
    def host_string(self):
        """Comma-joined hosts in priority order (computed once in ``__init__``)."""
        return ','.join(self._sorted_hosts)


def atomic(using=None, savepoint=False):
    """Return Django's ``transaction.atomic`` with savepoints off by default.

    Thin wrapper around ``django.db.transaction.atomic``. Django defaults to
    ``savepoint=True``; this helper defaults to ``False``. ``using`` is
    forwarded unchanged, so it may be a DB alias, ``None``, or (bare-decorator
    usage) the decorated function itself.
    """
    return transaction.atomic(using=using, savepoint=savepoint)
