#!/usr/bin/env python3

import argparse
import json
import os
import re
import psycopg2
import sys
from base64 import b64decode

from tractor_disk.source_drive import SourceDrive
from tractor.crypto import Fernet, VersionedKeyStorage

from tractor.logger import get_logger


# PostgreSQL connection strings (passed verbatim to psycopg2.connect) for the IPA
# database, keyed by environment name. main() selects one using --env or the
# environment parsed from DEPLOY_STAGE_ID ("testing" or "production").
# NOTE(review): no password appears here — presumably supplied via libpq
# environment variables or a .pgpass file; confirm against the deployment.
CONNINFO = {
    "testing": "dbname='ipa_db_test' user='ipa' host='sas-3i7yq4idm26t7cns.db.yandex.net' port='6432'",
    # Lists several comma-separated hosts in a single conninfo string.
    "production": """
        dbname='ipa_db'
        user='ipa'
        host='man-aylbm9aix6k44t7m.db.yandex.net,sas-si83i3fji95kk23d.db.yandex.net,vla-lpmhheb3szhk6e9s.db.yandex.net'
        port='6432'
    """,
}
# Default output path for the decrypted secret (overridable with --path).
DEFAULT_FILE_PATH = "/app/data/.google_secret.json"
# Default path of the JSON key-versions file used to build the decryptor (--keys).
DEFAULT_KEY_VERSIONS_PATH = "/app/data/.ipa_key_versions.json"


class IPADatabase:
    """Read access to import-task parameters stored in the IPA database."""

    def __init__(self, conninfo):
        # Connection string handed verbatim to psycopg2.connect() on each query.
        self.conninfo = conninfo

    def get_last_password(self, org_id, src_login):
        """Return the ``password`` of the first user of the newest matching import task.

        Args:
            org_id: organization id, compared against ``tasks.entity_id``.
            src_login: source login, compared against ``params.users[0].src_login``.

        Returns:
            The ``password`` value stored for that user (an encrypted blob here).

        Raises:
            RuntimeError: if no ``init_import`` task matches ``org_id``/``src_login``.
        """
        row = self._get_last_import_params(org_id, src_login)
        if row is None:
            raise RuntimeError(
                "No secrets in db for org_id={} src_login={}".format(org_id, src_login)
            )
        # fetchone() yields a 1-tuple holding the decoded ``params`` JSON column.
        params = row[0]
        return params["users"][0]["password"]

    def _get_last_import_params(self, org_id, src_login):
        """Fetch the ``params`` row of the newest ``init_import`` task for org/login.

        Returns:
            The row tuple from ``fetchone()``, or ``None`` when nothing matches.
        """
        # psycopg2's ``with connection:`` only ends the transaction; it does NOT
        # close the connection, so close it explicitly to avoid leaking it.
        conn = psycopg2.connect(self.conninfo)
        try:
            with conn, conn.cursor() as cur:
                cur.execute(
                    """
                SELECT params
                FROM tasks
                WHERE entity_id = %s AND task_type = 'init_import' AND params#>>'{users,0,src_login}' = %s
                ORDER BY CREATED DESC
                LIMIT 1
                """,
                    (org_id, src_login),
                )
                return cur.fetchone()
        finally:
            conn.close()


class Decryptor:
    """Turn a stored, encrypted password blob back into its JSON secret.

    The blob is unwrapped in three steps: base64 -> Fernet token, Fernet
    decryption -> UTF-8 base64 text, base64 -> JSON document.
    """

    def __init__(self, key_versions):
        # key_versions: parsed key-versions data used to build the key storage.
        self.fernet = Fernet(VersionedKeyStorage(key_versions))

    def decrypt(self, password):
        """Decrypt ``password`` and return the parsed JSON object it wraps."""
        token = b64decode(password)
        plaintext = self.fernet.decrypt(token)
        encoded_json = plaintext.decode("utf8")
        raw_json = b64decode(encoded_json)
        return json.loads(raw_json)


def main():
    """Load, decrypt and save the secret for one organization.

    The environment and organization id come from the command line, falling
    back to DEPLOY_STAGE_ID / DEPLOY_UNIT_ID. The encrypted secret is read as the
    ``password`` of the newest ``init_import`` task whose first user's
    ``src_login`` equals the source drive's fake login. It is decrypted with the
    key versions from ``--keys`` and written as JSON to ``--path``.
    """
    args = parse_args()
    env = args.env or env_from_stage()
    org_id = args.org_id or org_id_from_deploy_unit()
    conninfo = CONNINFO[env]
    # Close the keys file deterministically instead of leaking the handle.
    with open(args.keys) as keys_file:
        key_versions = json.loads(keys_file.read().strip())
    fake_login = args.source.fake_src_login()
    get_logger().info(
        "load secret",
        org_id=org_id,
        params={"env": env, "fake_login": fake_login, "output_path": args.path},
    )
    db = IPADatabase(conninfo)
    password = db.get_last_password(org_id, fake_login)
    decryptor = Decryptor(key_versions)
    secret = decryptor.decrypt(password)
    save(secret, args.path)


def parse_args():
    """Parse the command-line arguments.

    ``--env`` is restricted to the environments defined in CONNINFO, so a typo
    fails at parse time with a usage message instead of a KeyError in main().
    ``--source`` is required and converted with ``SourceDrive.from_string``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--env",
        choices=sorted(CONNINFO),
        help="environment; extracted from DEPLOY_STAGE_ID if not defined",
    )
    parser.add_argument(
        "--org-id",
        dest="org_id",
        help="organization id; extracted from DEPLOY_UNIT_ID if not defined",
    )
    parser.add_argument("--keys", help="key versions filepath", default=DEFAULT_KEY_VERSIONS_PATH)
    parser.add_argument("--path", help="output file path", default=DEFAULT_FILE_PATH)
    parser.add_argument(
        "--source",
        help="source drive",
        choices=list(SourceDrive),
        type=SourceDrive.from_string,
        required=True,
    )
    return parser.parse_args()


def org_id_from_deploy_unit():
    """Extract the organization id from the DEPLOY_UNIT_ID environment variable.

    The value must look like ``org-<digits>``; the digits are returned as a string.

    Raises:
        KeyError: if DEPLOY_UNIT_ID is not set.
        RuntimeError: if the value does not match ``org-<digits>``.
    """
    deploy_unit = os.environ["DEPLOY_UNIT_ID"]
    # Raw string: "\d" in a plain literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning since Python 3.12).
    matched = re.match(r"^org-(?P<org_id>\d+)$", deploy_unit)
    if not matched:
        raise RuntimeError("Failed to parse deploy unit id: {}".format(deploy_unit))
    return matched.group("org_id")


def save(secret, path):
    """Write ``secret`` as JSON to ``path``, truncating any existing file.

    NOTE(review): the file is created with the process umask, so the decrypted
    secret may be readable by other users; consider restricting permissions
    once it is confirmed which user consumes this file.
    """
    with open(path, mode="w") as handle:
        json.dump(obj=secret, fp=handle)


def env_from_stage():
    """Derive the environment name from the DEPLOY_STAGE_ID environment variable.

    Accepts only ``mail_tractor_disk_production`` or ``mail_tractor_disk_testing``
    and returns ``"production"`` or ``"testing"`` respectively.

    Raises:
        KeyError: if DEPLOY_STAGE_ID is not set.
        RuntimeError: if the stage id has any other form.
    """
    stage_id = os.environ["DEPLOY_STAGE_ID"]
    found = re.match("^mail_tractor_disk_(?P<env>production|testing)$", stage_id)
    if found is None:
        raise RuntimeError("Failed to parse deploy stage id: {}".format(stage_id))
    return found["env"]


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # Log the failure, then exit non-zero so whatever invoked this script
        # can tell that loading the secret failed (previously it exited 0).
        get_logger().error("load secret error", reason=str(e))
        sys.exit(1)
