#!/usr/bin/env python3

from typing import Dict
from collections import defaultdict
import json
import argparse
import time
import sys
import os

from settings import settings
from cache import Cache
from provider import Provider
from tractor import TractorApi
from db import CollectorsDatabase, IPADatabase
from blackbox import Blackbox
from decryptor import Decryptor
from logger import log


class ReloadTokensOp:
    """Refresh delegated OAuth access tokens for IMAP collectors and save them to a file.

    Run by calling the instance. Steps:

    1. Load IMAP OAuth collectors from the collectors DB.
    2. Reuse tokens from the previous output file that ``Cache`` reports as
       usable (skipped when ``force``).
    3. Enrich the remaining collectors with ``domain``/``org_id`` from blackbox.
    4. Optionally validate domains.
    5. Group collectors by org id and provider.
    6. Load each org's service secret (tractor API first, IPA DB as fallback).
    7. Request a delegated token per login and overwrite ``file`` with all tokens.
    """

    def __init__(self, file, force=False):
        """Create the operation and its service clients.

        Args:
            file: path of the access tokens file. It is handed to ``Cache`` as
                the previous run's output and then overwritten by ``_save``.
            force: if True, ignore the previous file and refresh every token.
        """
        self.blackbox = Blackbox()
        self.collectors_db = CollectorsDatabase()
        self.tractor_api = TractorApi()
        self.ipa_db = IPADatabase()
        self.decryptor = Decryptor()
        self.force = force
        self.settings = settings()
        self.file = file

    def __call__(self):
        """Run the full refresh and write the resulting tokens to ``self.file``."""
        # TODO: process by chunk for slow start
        collectors = self.collectors_db.list_imap_oauth()
        log.info("loaded {} collectors from db".format(len(collectors)))
        cached, not_ready = self._load_cached(collectors)
        log.info("loaded from previous file {} active tokens".format(len(cached)))
        collectors = self._load_bb_info(not_ready)
        if self.settings.validate_domain:
            collectors = self._validate(collectors)
            log.info("collectors remaind after validation: {}".format(len(collectors)))
        orgid_collectors = self._group_by_orgid_provider(collectors)
        log.info("loaded unique org ids: {}".format(len(orgid_collectors)))
        orgid_to_secret = self._load_external_secrets(orgid_collectors)
        tokens = self._get_access_tokens(orgid_collectors, orgid_to_secret)
        # Cached tokens are merged last: for a login present in both, the
        # cached entry wins.
        tokens.update(cached)
        self._save(tokens)

    def _load_cached(self, collectors):
        """Split collectors into reusable cached tokens and collectors needing a refresh.

        Returns:
            ``(cached, not_ready)``. ``cached`` comes from ``Cache.get`` and is
            merged into the output written by ``_save``, so it must map
            ``login -> {"token": ..., "created_at": ...}``. ``not_ready`` is the
            list of collector dicts whose login ``Cache`` did not return as ready.
            With ``force`` nothing is reused and every collector is returned.
        """
        if self.force:
            return {}, collectors
        logins = [c["login"] for c in collectors]
        cache = Cache(self.file, self.settings.min_token_ttl)
        cached, not_ready_logins = cache.get(logins)
        # A set gives O(1) membership instead of scanning a list per collector.
        not_ready_logins = set(not_ready_logins)
        not_ready = [c for c in collectors if c["login"] in not_ready_logins]
        return cached, not_ready

    def _load_bb_info(self, collectors):
        """Add blackbox ``domain`` and ``org_id`` to each collector dict (in place).

        Collectors whose blackbox lookup raises are logged and dropped from the
        returned list. Either value may be None if blackbox did not return it.
        """
        new_collectors = []
        for collector in collectors:
            try:
                info = self.blackbox.userinfo(suid=collector["suid"])
                collector.update({"domain": info.get("domain"), "org_id": info.get("org_id")})
                new_collectors.append(collector)
            except Exception as e:
                log.error("failed to get bb info for login={}: {}".format(collector["login"], e))
        return new_collectors

    def _validate(self, collectors):
        """Keep collectors whose login domain matches their blackbox domain.

        The comparison is case-insensitive. The hard-coded test collectors
        accepted by ``_test_collector`` are also kept. Collectors with no
        blackbox domain (or no login) are rejected with a warning instead of
        crashing the whole run.
        """
        validated = []
        for collector in collectors:
            login = collector.get("login") or ""
            # Blackbox may have returned no domain (info.get("domain") -> None).
            dst_domain = (collector.get("domain") or "").lower()
            src_domain = _domain(login).lower()
            if dst_domain and (src_domain == dst_domain or _test_collector(login, dst_domain)):
                validated.append(collector)
            else:
                log.warn(
                    "domain validity check failed for login={} dst_domain={}".format(
                        login, dst_domain
                    )
                )
        return validated

    def _group_by_orgid_provider(self, collectors):
        """Group collectors as ``org_id -> Provider -> [collector, ...]``.

        ``org_id`` may be None for collectors blackbox returned no org id for.
        The provider is derived from the collector's IMAP ``server``.
        """
        orgid_collectors = defaultdict(lambda: defaultdict(list))
        for collector in collectors:
            org_id = collector.get("org_id")
            provider = Provider.from_imap_server(collector.get("server"))
            orgid_collectors[org_id][provider].append(collector)
        return orgid_collectors

    def _load_external_secrets(self, orgid_collectors):
        """Load the service secret for every (org_id, provider) pair.

        The tractor API is tried first and its value is parsed as JSON. On any
        failure, the last password from the IPA DB is decrypted instead. Pairs
        with a None org id are skipped. Pairs for which both sources fail are
        absent from the result; the errors are logged.

        Returns:
            ``org_id -> provider -> secret``.
        """
        orgid_to_secret = defaultdict(dict)
        for org_id, provider_collectors in orgid_collectors.items():
            if org_id is None:
                continue
            for provider, collectors in provider_collectors.items():
                # The domain is only used to make the log lines traceable.
                domain = collectors[0].get("domain") if collectors else ""
                try:
                    secret: str = self.tractor_api.retrieve_secret(org_id, provider)
                    orgid_to_secret[org_id][provider] = json.loads(secret)
                except Exception as e:
                    log.error(
                        "failed to load secret from tractor api for domain={} orgid={}: {}, will try ipa db".format(
                            domain, org_id, e
                        )
                    )
                else:
                    log.info(
                        "loaded secret from tractor api for domain={} orgid={}, will skip ipa db".format(
                            domain, org_id
                        )
                    )
                    continue
                try:
                    encrypted = self.ipa_db.get_last_password(org_id, provider)
                    # Reuse the decryptor built in __init__ instead of a new one per call.
                    orgid_to_secret[org_id][provider] = self.decryptor.decrypt(encrypted)
                except Exception as e:
                    log.error(
                        "failed to load secret from ipa db for domain={} orgid={}: {}".format(
                            domain, org_id, e
                        )
                    )
                else:
                    log.info(
                        "loaded secret from ipa db for domain={} orgid={}".format(domain, org_id)
                    )
        return orgid_to_secret

    def _get_access_tokens(self, orgid_collectors, orgid_to_secret):
        """Build provider service credentials per org and collect delegated tokens.

        Orgs without a loaded secret (including a None org id) fail the secret
        lookup. The failure is logged and those collectors are skipped.

        Returns:
            ``login -> {"token": ..., "created_at": ...}``.
        """
        tokens = {}
        for org_id, provider_collectors in orgid_collectors.items():
            for provider, collectors in provider_collectors.items():
                try:
                    service_credentials_cls = provider.service_credentials_cls()
                    service_credentials = service_credentials_cls(orgid_to_secret[org_id][provider])
                except Exception as e:
                    log.error("failed to get service credentials for orgid={}: {}".format(org_id, e))
                    continue
                tokens.update(self._delegate_tokens(service_credentials, collectors))
        return tokens

    def _delegate_tokens(self, service_credentials, collectors):
        """Request a delegated access token for each collector's login.

        Returns:
            ``login -> {"token": ..., "created_at": <unix seconds>}``. Logins
            whose token request raises are logged and skipped.
        """
        tokens = {}
        for collector in collectors:
            # Bind the login before the try, so the error log can never refer to
            # an unbound name or the previous iteration's login.
            login = collector["login"]
            try:
                token = service_credentials.get_token(login)
            except Exception as e:
                log.error(
                    "failed get delegated token for login={} orgid={}: {}".format(
                        login, collector.get("org_id"), e
                    )
                )
                continue
            tokens[login] = {"token": token, "created_at": _current_timestamp()}
        return tokens

    def _save(self, login_to_access_token):
        """Overwrite ``self.file`` with one ``"<login> <token> <created_at>"`` line per login.

        The content is written to ``<file>.tmp`` and then moved into place with
        ``os.replace``, so a crash mid-write cannot leave a truncated file (the
        file is handed to ``Cache`` on the next run). The permission bits of an
        existing file are carried over to the replacement.
        """
        directory = os.path.dirname(self.file)
        # dirname() is "" for a bare file name, and os.makedirs("") raises.
        if directory:
            os.makedirs(directory, exist_ok=True)
        # Format everything first, so a malformed entry fails before any I/O.
        lines = [
            "{} {} {}\n".format(login, token_ts["token"], token_ts["created_at"])
            for login, token_ts in login_to_access_token.items()
        ]
        tmp_path = self.file + ".tmp"
        with open(tmp_path, "w") as file:
            file.writelines(lines)
        try:
            os.chmod(tmp_path, os.stat(self.file).st_mode & 0o7777)
        except FileNotFoundError:
            pass  # first run: there is no previous file whose mode should be kept
        os.replace(tmp_path, self.file)


def _domain(email):
    return email[email.rfind("@") + 1 :]


def _current_timestamp():
    return int(time.time())


def _test_collector(src_email, domain):
    test_src_emails = ("devops@kontrtest.onmicrosoft.com", "devops@xn--90aru.com")
    test_dst_domains = ("test005.gan4test.ru", "omenname.auto.connect-tk.tk")
    return domain in test_dst_domains and src_email in test_src_emails


def parse_args():
    """Parse command-line flags: required ``--file`` and optional ``--force``."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--file", required=True, help="access tokens file path")
    cli.add_argument("--force", action="store_true", help="force token update")
    namespace = cli.parse_args()
    return namespace


if __name__ == "__main__":
    # Script entry point: parse CLI flags, then run one token refresh.
    parsed = parse_args()
    ReloadTokensOp(file=parsed.file, force=parsed.force)()
