import asyncio

from logbroker.public.api.protos.config_manager_pb2 import (
    SingleModifyRequest,
    CreateRemoteMirrorRuleRequest,
    RemoveRemoteMirrorRuleRequest,
)

from library.python.vault_client.instances import Production as VaultClient

from saas.library.python.deploy_manager_api import SaasService
from saas.library.python.logbroker import SaaSConfigurationManager, ConfigurationClient, LogbrokerTopic
from saas.library.python.token_store import PersistentTokenStore


def get_logbroker_token() -> str:
    """Return the ``logbroker_token`` entry of the saas-robot yav secret.

    The yav OAuth token used to authorize the vault client is obtained via
    ``PersistentTokenStore.get_token_from_store_env_or_file('yav')``.
    """
    oauth = PersistentTokenStore.get_token_from_store_env_or_file('yav')
    vault = VaultClient(authorization=f'OAuth {oauth}', decode_files=True)

    # Secret id of saas-robot.
    secret = vault.get_version('sec-01dmzfn3qhxes8y4yyfar0xyhs')
    return secret['value']['logbroker_token']


def _token_edges(token: str) -> tuple:
    """Return the first four and last four characters of ``token`` as a pair."""
    return token[:4], token[-4:]


async def main():
    """Find remote mirror rules with an outdated OAuth token and re-create them.

    Workflow:

    1. Load the current Logbroker token from yav.
    2. For every service of every namespace that has ``logbroker_mirror``
       enabled, describe the per-shard mirror topics and collect the remote
       mirror rules whose token does not match the current one.
    3. Print the affected topic paths and ask for an explicit ``yes``.
    4. For each collected rule, remove it and create it again with the
       current token.

    Tokens are compared only by their first and last four characters
    (presumably because ``describe()`` returns a masked token -- TODO confirm).
    """
    logbroker_token = get_logbroker_token()
    expected_edges = _token_edges(logbroker_token)

    print('Loaded "logbroker_token" from yav')

    client = ConfigurationClient()

    manager = SaaSConfigurationManager()
    ns_names = await manager.get_namespace_names()

    print('Starting reading configs...')

    invalid_rules = []
    broken_topic_paths = []

    # We are inside a coroutine, so the running loop is the one to use;
    # fetch it once instead of on every service iteration.
    loop = asyncio.get_running_loop()

    for ns_name in ns_names:
        ns = manager.get_namespace(ns_name)

        print(f'Loading namespace: {ns.name}')

        async for service in ns.get_services():
            config = await service.get_config()
            if not config.logbroker_mirror:
                continue

            # The SaasService lookup is pushed to the default executor
            # (presumably it performs blocking I/O -- TODO confirm).
            # The lambda is awaited right away, so capturing the loop
            # variable ``service`` is safe here.
            slots_by_interval = await loop.run_in_executor(
                None,
                lambda: SaasService(service.ctype, service.name).slots_by_interval
            )
            if not slots_by_interval:
                print(f'Unable to get shard & slot info for service {service.name} {service.ctype}, skipping...')
                continue

            for interval in slots_by_interval:
                shard = interval['id']

                topic_path = f'{config.mirror_topics_path}/shard-{shard}'
                topic = LogbrokerTopic(client, topic_path)
                result = await topic.describe()

                # A rule is outdated when either edge of its token differs
                # from the corresponding edge of the current token.
                outdated_rules = [
                    rule for rule in result.remote_mirror_rules
                    if _token_edges(rule.properties.credentials.oauth_token) != expected_edges
                ]
                if outdated_rules:
                    invalid_rules.extend(outdated_rules)
                    broken_topic_paths.append(topic_path)

    # sep='' keeps the first listed path from getting a stray leading space.
    print('Broken topic paths:\n', '\n'.join(broken_topic_paths) or '-', sep='')

    if not broken_topic_paths:
        print('No fixes needed')
        return

    ans = input(
        'To start fixing mirror rules for topics above, type "yes"\n\n'
        'WARNING!\n'
        'You should not stop the process, otherwise mirror rules can be partly removed\n'
        'If it happened, you have to MANUALLY check that everything is correct and restore missing rules if needed\n'
        'To get info about mirror rules for a topic, you can use:\n'
        'ya tool logbroker -s <cluster> schema describe <topic>\n\n'
        'Continue? — '
    )
    if ans != 'yes':
        print('Cancelled')
        return

    print('Trying to fix invalid mirror rules...')

    for rule in invalid_rules:
        # The rule is replaced by removing it and creating it again with
        # the same identifier and properties, except for the token.
        # NOTE(review): removal and creation are two separate requests, so a
        # failure in between leaves the rule missing.
        await client.execute_modify_commands([
            SingleModifyRequest(
                remove_remote_mirror_rule=RemoveRemoteMirrorRuleRequest(
                    remote_mirror_rule=rule.remote_mirror_rule,
                )
            )
        ])

        props = rule.properties
        props.credentials.oauth_token = logbroker_token

        try:
            await client.execute_modify_commands([
                SingleModifyRequest(
                    create_remote_mirror_rule=CreateRemoteMirrorRuleRequest(
                        remote_mirror_rule=rule.remote_mirror_rule,
                        properties=props,
                    )
                )
            ])
        except Exception:
            # The rule is already removed at this point: tell the operator
            # which one has to be restored manually, then fail as before.
            # Only the rule identifier is printed, never the properties,
            # since those carry the token.
            print(f'Failed to re-create a removed mirror rule, restore it manually:\n{rule.remote_mirror_rule}')
            raise

    print('We did it!')


# Script entry point: run the async workflow only when the file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    asyncio.run(main())
