from intranet.yandex_directory.src import settings
from intranet.yandex_directory.src.yandex_directory.logbroker.consumer import AbstractConsumer
from intranet.yandex_directory.src.yandex_directory.directory_logging.logger import error_log, default_log
from intranet.yandex_directory.src.yandex_directory.common.db import (
    get_main_connection,
    get_meta_connection,
)
from intranet.yandex_directory.src.yandex_directory.core.tasks import SyncUserData
from intranet.yandex_directory.src.yandex_directory.sso.tasks import SyncSsoOrgTask
from intranet.yandex_directory.src.yandex_directory.core.models import UserMetaModel
from intranet.yandex_directory.src.yandex_directory.core.models.domain import DomainModel
from intranet.yandex_directory.src.yandex_directory.core.models.organization import OrganizationMetaModel
from intranet.yandex_directory.src.yandex_directory.core.task_queue.exceptions import DuplicatedTask
from intranet.yandex_directory.src.yandex_directory.core.utils import (
    get_user_data_from_blackbox_by_uid,
    only_attrs,
    is_domain_uid,
    is_outer_uid
)
from intranet.yandex_directory.src.yandex_directory.core.utils.organization import is_sso_turned_on


class PassportConsumer(AbstractConsumer):
    """
    Class for working with passport logbroker data
    """
    def __init__(self,
                 lb_topic=settings.PASSPORT_LOGBROKER_TOPIC,
                 lb_consumer=settings.PASSPORT_LOGBROKER_CONSUMER_NAME,
                 codec='gzip'):
        """
        Set up the consumer; every argument is forwarded unchanged to the base class

        :param lb_topic: topic to read, defaults to settings.PASSPORT_LOGBROKER_TOPIC
        :param lb_consumer: consumer name, defaults to settings.PASSPORT_LOGBROKER_CONSUMER_NAME
        :param codec: codec name handed to the base class, 'gzip' by default
        """
        super().__init__(lb_topic, lb_consumer, codec)

    @staticmethod
    def _check_user(event):
        """
        Classify the user mentioned in a parsed event as domain, portal or unknown

        :param event: mapping with bytes keys; only the b'uid' entry is read
        :return: dict with "type" ("domain", "portal" or "unknown") and "uid"
                 (int, or None when the event carries no numeric uid)
        """
        raw_uid = event.get(b'uid')
        # Events without a purely numeric uid are neither classified nor logged
        if not raw_uid or not raw_uid.isdigit():
            return {"type": "unknown", "uid": None}

        uid = int(raw_uid)
        if is_domain_uid(uid):
            user_type = "domain"
        elif is_outer_uid(uid):
            user_type = "portal"
        else:
            user_type = "unknown"
        default_log.info(event)
        return {"type": user_type, "uid": uid}

    @staticmethod
    def _get_users_portal_org_ids(uids: list):
        """
        Get portal users org ids and information about shards

        :param uids: user ids to look up in the meta database
        :return: tuple ``(org_ids, orgs_by_shards)`` where ``org_ids`` is a list of
                 ``{'id': uid, 'org_id': org_id}`` dicts (one per found uid, the first
                 returned org wins) and ``orgs_by_shards`` is whatever
                 ``OrganizationMetaModel.get_orgs_by_shards`` returns for those orgs
                 (an empty dict when no user was found)
        """
        with get_meta_connection() as meta:
            # "id" is requested explicitly because every row is read by its 'id' key below
            rows = UserMetaModel(meta).find(filter_data={"id": uids}, fields=["org_id", "id"])

            # Keep only the first org returned for each uid
            first_org_by_uid = {}
            for row in rows:
                first_org_by_uid.setdefault(row['id'], row['org_id'])

            org_ids = [{'id': uid, 'org_id': org_id} for uid, org_id in first_org_by_uid.items()]
            default_log.info(f"There is found org_ids {org_ids}")

            orgs_by_shards = {}
            if org_ids:
                orgs_by_shards = OrganizationMetaModel(meta).get_orgs_by_shards(*only_attrs(org_ids, "org_id"))
            return org_ids, orgs_by_shards

    @staticmethod
    def _get_users_domain_org_ids(uids: list):
        """
        Get domain users org ids with SSO enabled and information about shards

        Users found in the meta database contribute their org directly. For the rest,
        the login is fetched from blackbox and the org is looked up by the owned
        domain taken from the login. Orgs whose shard reports SSO as turned off are
        removed from the resulting mapping.

        :param uids: domain user ids to resolve
        :return: tuple ``(org_id_to_uid, shard_org_id)`` where ``org_id_to_uid`` maps
                 an org id to one of the given uids belonging to it (SSO-enabled orgs
                 only) and ``shard_org_id`` maps a shard to the org ids located on it
        """
        org_id_to_uid = {}
        shard_org_id = {}
        with get_meta_connection() as meta:
            # The mapping is keyed by org_id (that is how it is consumed below and by
            # the caller), so uids already known to meta are tracked separately.
            known_uids = set()
            for user in UserMetaModel(meta).find(filter_data={"id": uids}, fields=["org_id", "id"]):
                known_uids.add(user['id'])
                org_id_to_uid[user['org_id']] = user['id']

            # Users missing from meta: derive the domain from the blackbox login
            uids_domain = {}
            for uid in uids:
                if uid in known_uids:
                    continue
                user_data = get_user_data_from_blackbox_by_uid(uid)
                domain = ''
                if user_data and user_data.get('login'):
                    chunks = user_data["login"].split("@")
                    if len(chunks) == 2 and chunks[1]:
                        domain = chunks[1]
                        uids_domain[domain] = uid
                if not domain:
                    error_log.error(f"Cannot get domain for user with uid {uid}")

            # One lookup for all collected domains, done once after the loop
            if uids_domain:
                domain_org_ids = DomainModel(None).find_all(
                    filter_data={"name": list(uids_domain), "owned": True}, fields=["org_id", "name"]
                )
                for domain_info in domain_org_ids:
                    org_id_to_uid[domain_info["org_id"]] = uids_domain.get(domain_info["name"])

            # Check orgs for sso_enabled
            if org_id_to_uid:
                orgs_by_shards = OrganizationMetaModel(meta).get_orgs_by_shards(*org_id_to_uid)
                for shard, org_ids in orgs_by_shards.items():
                    org_ids = list(org_ids)
                    shard_org_id[shard] = org_ids
                    with get_main_connection(shard=shard) as main:
                        # Only orgs located on this shard are checked against its connection
                        for org_id in org_ids:
                            if org_id in org_id_to_uid and not is_sso_turned_on(main, org_id):
                                default_log.info(f"SSO is not enabled for org_id: {org_id}")
                                org_id_to_uid.pop(org_id)

        return org_id_to_uid, shard_org_id

    def _process_domain_users(self, uids: list):
        """
        Schedule an SSO organization sync for the orgs resolved from domain uids

        The mapping returned by ``_get_users_domain_org_ids`` is looked up by org id;
        orgs absent from it are skipped. An already queued task (DuplicatedTask)
        is silently ignored.

        :param uids: domain user ids collected from the incoming message
        """
        default_log.info(f"There is domain uids to check {uids}")
        uid_by_org, orgs_by_shard = self._get_users_domain_org_ids(uids)
        for shard, shard_orgs in orgs_by_shard.items():
            with get_main_connection(for_write=True, shard=shard) as main_connection:
                for org_id in shard_orgs:
                    if org_id not in uid_by_org:
                        continue
                    try:
                        uid = uid_by_org[org_id]
                        default_log.info(f"Going to sync SSO user {uid} with org_id {org_id}")
                        SyncSsoOrgTask(main_connection).delay(org_id=org_id)
                    except DuplicatedTask:
                        # The same sync is already queued, nothing to do
                        pass

    def _process_portal_users(self, uids: list):
        """
        Schedule a user data sync for the portal users found in the meta database

        For every shard returned by ``_get_users_portal_org_ids`` a SyncUserData task
        is queued for each found uid. An already queued task (DuplicatedTask) is
        silently ignored.

        :param uids: portal user ids collected from the incoming message
        """
        default_log.info(f"There is portal uids to check {uids}")
        found_users, shards = self._get_users_portal_org_ids(uids)
        # The per-shard org list is not used: every found uid is queued on every shard
        for shard, _shard_orgs in shards.items():
            with get_main_connection(for_write=True, shard=shard) as main_connection:
                unique_uids = set(only_attrs(found_users, "id"))
                for uid in unique_uids:
                    try:
                        default_log.info(f"Going to sync user {uid}")
                        SyncUserData(main_connection).delay(user_id=uid)
                    except DuplicatedTask:
                        # The same sync is already queued, nothing to do
                        pass

    def _handle_rows(self, message):
        """
        Entry point for a Logbroker message

        Splits the message into lines, converts each one with ``self._tskv_to_dict``
        (not defined in this class, presumably inherited), classifies the user and
        then schedules updates: domain users first, portal users second.

        :param message: bytes payload with one record per b'\\n'-separated line
        """
        uids_by_type = {"domain": [], "portal": []}
        for raw_line in message.split(b'\n'):
            # Convert each line tskv->dict format
            checked = self._check_user(self._tskv_to_dict(raw_line))
            bucket = uids_by_type.get(checked["type"])
            # Users of any other type are ignored
            if bucket is not None:
                bucket.append(checked["uid"])

        domain_uids = uids_by_type["domain"]
        portal_uids = uids_by_type["portal"]
        # Duplicates are dropped before the uids are handed over
        if domain_uids:
            self._process_domain_users(list(set(domain_uids)))
        if portal_uids:
            self._process_portal_users(list(set(portal_uids)))
