import argparse
import sys

from library.python.vault_client.errors import ClientNoKeysInSSHAgent
from travel.rasp.bus.spark_api import model_serializer
from travel.rasp.bus.spark_api.client import ENVIRONMENT_TO_HOST, SparkClient, SparkConfig
from travel.library.python.rasp_vault.api import get_secret
from travel.library.python.rasp_vault.common import SecretNotFoundError


def create_spark_config(environment='production'):
    """Build a ``SparkConfig`` for the given Spark environment.

    Credentials are taken from the ``spark-<environment>-trains`` secret
    (fetched with ``get_secret``).  The host comes from
    ``ENVIRONMENT_TO_HOST[environment]``.

    :param environment: environment name; used both to build the secret name
        and as the key into ``ENVIRONMENT_TO_HOST``.
    :return: a ``SparkConfig`` with ``host``, ``login`` and ``password`` set.
    """
    secret = get_secret(f'spark-{environment}-trains')
    spark_login = secret['login']
    spark_password = secret['password']
    spark_host = ENVIRONMENT_TO_HOST[environment]
    return SparkConfig(host=spark_host, login=spark_login, password=spark_password)


def parse_args(argv=None):
    """Parse the command line of the Spark client CLI.

    Global options (must come before the sub-command):
      --env     Spark environment name (default: ``production``).
      --format  output format handed to the serializers (default: ``csv``).

    Sub-commands:
      regions       show all regions;
      company       find companies by ``--name``, optionally filtered by one or
                    more ``--region`` and by ``--okopf``;
      ip            find entrepreneurs by ``--name``, optionally filtered by one
                    or more ``--region``;
      company-info  look up a company by at least one of
                    ``--id`` / ``--inn`` / ``--ogrn``;
      ip-info       look up an entrepreneur by at least one of
                    ``--inn`` / ``--ogrnip``.

    :param argv: list of arguments to parse; ``None`` (the default) makes
        argparse read ``sys.argv[1:]``, as before.
    :return: the parsed ``argparse.Namespace``; the sub-command is in ``cmd``.
    :raises SystemExit: on invalid or missing arguments (argparse convention),
        including an ``*-info`` sub-command given without any identifier.
    """
    parser = argparse.ArgumentParser(description='Spark Client')
    parser.add_argument(
        '--env',
        type=str,
        default='production',
        required=False,
        help='Spark environment to use (default: %(default)s)',
    )
    parser.add_argument(
        '--format',
        type=str,
        default='csv',
        required=False,
        help='output format passed to the serializers (default: %(default)s)',
    )
    subparsers = parser.add_subparsers(dest='cmd')
    # Sub-commands are optional in Python 3 unless explicitly marked required.
    subparsers.required = True
    subparsers.add_parser('regions', help='Show all regions')

    find_company_cmd = subparsers.add_parser('company', help='Find company by name')
    find_company_cmd.add_argument(
        '--name',
        type=str,
        required=True,
        help='company name to search for',
    )
    find_company_cmd.add_argument(
        '--region',
        type=str,
        action='append',
        default=None,
        required=False,
        help='region filter; repeat the option to pass several regions',
    )
    find_company_cmd.add_argument(
        '--okopf',
        type=str,
        default=None,
        required=False,
        help='OKOPF filter',
    )

    find_entrepreneur_cmd = subparsers.add_parser('ip', help='Find entrepreneur by fio')
    find_entrepreneur_cmd.add_argument(
        '--name',
        type=str,
        required=True,
        help='entrepreneur full name (FIO) to search for',
    )
    find_entrepreneur_cmd.add_argument(
        '--region',
        type=str,
        action='append',
        default=None,
        required=False,
        help='region filter; repeat the option to pass several regions',
    )

    # Identifier options of the "*-info" sub-commands. Each one is optional on
    # its own, but without any of them main() would request the report with
    # every identifier set to None, so at least one is enforced below.
    identifier_options = {
        'company-info': ('id', 'inn', 'ogrn'),
        'ip-info': ('inn', 'ogrnip'),
    }
    info_parsers = {
        'company-info': subparsers.add_parser('company-info', help='Find company info'),
        'ip-info': subparsers.add_parser('ip-info', help='Find entrepreneur info'),
    }
    for cmd, options in identifier_options.items():
        for option in options:
            info_parsers[cmd].add_argument(
                f'--{option}',
                type=str,
                required=False,
                help=f'look up by {option}',
            )

    args = parser.parse_args(argv)

    options = identifier_options.get(args.cmd)
    if options is not None and all(getattr(args, option) is None for option in options):
        flags = ', '.join(f'--{option}' for option in options)
        # Report through the sub-parser so the usage line shown is the
        # sub-command's own; error() exits with status 2.
        info_parsers[args.cmd].error(f'at least one of {flags} is required')
    return args


def main():
    """Run the Spark client command-line tool.

    Parses the command line, loads Spark credentials for ``--env``, runs the
    selected sub-command through ``SparkClient`` and prints the serialized
    result (in ``--format``) to stdout. Configuration problems are reported
    on stderr.

    :return: ``0`` on success; ``1`` when the environment is unknown, the
        credentials cannot be obtained, or the sub-command is not handled.
        The value is suitable for ``sys.exit`` / ``raise SystemExit``.
    """
    args = parse_args()

    # Reject unknown environments up front. create_spark_config fetches the
    # secret before looking up the host, so a mistyped name would presumably
    # fail on the secret and be misreported as an access problem below.
    if args.env not in ENVIRONMENT_TO_HOST:
        known = ', '.join(sorted(str(env) for env in ENVIRONMENT_TO_HOST))
        print(f'Unknown environment {args.env!r}. Known environments: {known}', file=sys.stderr)
        return 1

    try:
        config = create_spark_config(args.env)
    except ClientNoKeysInSSHAgent as e:
        # Python 3 exceptions have no built-in ``.message`` attribute; fall back
        # to the string form rather than failing with AttributeError.
        if getattr(e, 'message', str(e)) != 'No keys in SSH Agent':
            raise
        print('Your SSH agent has no keys loaded. Add a key (e.g. with `ssh-add`) and try again.', file=sys.stderr)
        return 1
    except SecretNotFoundError:
        print('You do not have access to spark api secrets. You can ask the access here:', file=sys.stderr)
        print('Production:  https://yav.yandex-team.ru/secret/sec-01d540tg262va57y9s91t9zam4/explore/versions', file=sys.stderr)
        print('Testing:     https://yav.yandex-team.ru/secret/sec-01d544h0pdj9a0mst1vbxpcr7x/explore/versions', file=sys.stderr)
        return 1

    with SparkClient.create(config) as client:
        if args.cmd == 'regions':
            print(model_serializer.RegionSerializer.serialize(client.list_regions(), format=args.format))
        elif args.cmd == 'company':
            print(model_serializer.CompanySerializer.serialize(
                client.find_companies_by_name(
                    args.name, args.region, args.okopf
                ),
                format=args.format,
            ))
        elif args.cmd == 'ip':
            print(model_serializer.EntrepreneurSerializer.serialize(
                client.find_entrepreneurs_by_name(args.name, args.region),
                format=args.format,
            ))
        elif args.cmd == 'company-info':
            # Only the OKVED list of the report is printed.
            print(model_serializer.OkvedSerializer.serialize(
                client.get_company_report(args.id, args.inn, args.ogrn).okveds,
                format=args.format,
            ))
        elif args.cmd == 'ip-info':
            print(model_serializer.OkvedSerializer.serialize(
                client.get_entrepreneur_report(args.inn, args.ogrnip).okveds,
                format=args.format,
            ))
        else:
            # Defensive: argparse only accepts the sub-commands handled above.
            print('Unsupported command', file=sys.stderr)
            return 1
    return 0


if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status
    # (None exits with 0, an int exits with that status).
    raise SystemExit(main())
