from typing import Any, Optional, Dict, Callable
from functools import wraps
from os import getenv, uname, getpid
import json
from tractor.logger import get_logger
from tractor.versioned_keys import load as _load_versioned_keys
from tractor.disk.db import Database
from tractor.disk.models import TaskType
from tractor.models import ExternalProvider, Task
from tractor.exception import YandexUserNotFound
from tractor.error import YANDEX_USER_NOT_FOUND
from tractor.settings import Settings
from tractor.crypto.fernet import Fernet
from tractor_disk.disk_pair import DiskPair
from tractor_disk.mds import create_mds_client_from_settings
from tractor_disk.settings import settings
from tractor_disk.source_drive import (
    SourceDrive,
    get_create_drive_op_with_explicit_secret,
    get_file_by_source,
    get_path_mapping_op_by_source,
)
from tractor_disk.sync.sync_file import SyncFileOp
from tractor_disk.yandex_disk import create_yandex_disk


def _source_drive_from_provider(provider: ExternalProvider) -> Optional[SourceDrive]:
    """Map an external provider to its source drive.

    Args:
        provider: the external provider configured for the task.

    Returns:
        ``SourceDrive.GOOGLE`` for ``ExternalProvider.GOOGLE``,
        ``SourceDrive.MICROSOFT`` for ``ExternalProvider.MICROSOFT``, and
        ``None`` for any other provider. No error is raised here, so callers
        must be prepared for ``None``.
    """
    if provider == ExternalProvider.GOOGLE:
        return SourceDrive.GOOGLE
    if provider == ExternalProvider.MICROSOFT:
        return SourceDrive.MICROSOFT
    # Unsupported provider: signalled with None (hence the Optional return type).
    return None


def _retrieve_secret(
    settings: Settings, db: Database, org_id, provider: ExternalProvider
) -> object:
    """Fetch, decrypt and JSON-decode an organization's external-provider secret.

    The encrypted blob is read with ``db.get_external_secret`` over a
    short-lived connection/cursor. It is decrypted with a ``Fernet`` built from
    ``_load_versioned_keys(settings)`` and then parsed with ``json.loads``.
    """
    with db.make_connection() as connection, connection.cursor() as cursor:
        ciphertext: bytes = db.get_external_secret(org_id, provider, cursor)
    # Decryption and parsing happen after the connection has been released.
    cipher = Fernet(_load_versioned_keys(settings))
    plaintext: str = cipher.decrypt_text(ciphertext)
    return json.loads(plaintext)


# Environment mapping passed to task-env factories: string keys (e.g. "db",
# "mds", "logger", "settings") to arbitrary values.
Env = Dict[str, Any]


def trycatch(method: Callable[..., Env]) -> Callable[[Task, Env], Optional[Env]]:
    """Decorate a task-env factory so that its failures fail the task instead of raising.

    The wrapped factory is called as ``method(task=task, env=env)``.

    On success, its result is returned unchanged.

    If it raises an ``Exception``, the wrapper does the following:
      * logs the error via ``env["logger"]``;
      * marks the task failed in ``env["db"]``, using ``fail_listing_task`` for
        ``TaskType.LIST`` and ``fail_sync_task`` for ``TaskType.SYNC``;
      * returns ``None``.

    Any error raised while failing the task (including an unknown task type) is
    logged as "cannot fail task" and swallowed. The wrapper therefore never
    propagates ``Exception`` subclasses.
    """

    @wraps(method)
    def wrap(task: Task, env: Env) -> Optional[Env]:
        try:
            return method(task=task, env=env)
        except Exception as e:
            try:
                db: Database = env["db"]
                env["logger"].exception(
                    message="make task env failed",
                    exception=e,
                    task_type=task.type,
                    task_id=task.task_id,
                    org_id=task.org_id,
                )
                # Resolve the fail routine before connecting, so an invalid task
                # type does not open (and roll back) a pointless DB connection.
                if task.type == TaskType.LIST:
                    fail_task = db.fail_listing_task
                elif task.type == TaskType.SYNC:
                    fail_task = db.fail_sync_task
                else:
                    raise ValueError("invalid task type: {}".format(task.type))
                msg = _error_message_from_exception(e)
                with db.make_connection() as conn:
                    with conn.cursor() as cur:
                        fail_task(error=msg, task_id=task.task_id, cur=cur)
            except Exception as another_e:
                # Best effort: failing the task must not crash the worker.
                env["logger"].exception(
                    message="cannot fail task",
                    exception=another_e,
                    task_type=task.type,
                    task_id=task.task_id,
                    org_id=task.org_id,
                )
        # Only reached when `method` raised.
        return None

    return wrap


def _error_message_from_exception(e):
    """Return the failure text to record for a task that failed with ``e``.

    A ``YandexUserNotFound`` error maps to the ``YANDEX_USER_NOT_FOUND``
    constant. Any other exception is rendered with ``str(e)``.
    """
    return YANDEX_USER_NOT_FOUND if isinstance(e, YandexUserNotFound) else str(e)


def make_env() -> Env:
    """Build the worker-wide environment shared by every task this process handles.

    Returns:
        A dict with these keys:
          * ``worker_id``: ``"<host>/<pid>"``, where host is
            ``$DEPLOY_POD_TRANSIENT_FQDN`` or, if unset, ``uname().nodename``;
          * ``settings``: the object returned by ``settings()``;
          * ``db``: a ``Database`` built from that object's ``tractor_disk_db``;
          * ``mds``: a client from ``create_mds_client_from_settings()``;
          * ``logger``: a logger from ``get_logger()``.
    """
    # Call settings() once so the stored settings and the DB config are the
    # same object (the original called it twice).
    cfg = settings()
    host = getenv("DEPLOY_POD_TRANSIENT_FQDN", uname().nodename)
    # Dict literal entries are evaluated in order, matching the original
    # sequence of side effects (DB, then MDS client, then logger).
    return {
        "worker_id": "{}/{}".format(host, getpid()),
        "settings": cfg,
        "db": Database(cfg.tractor_disk_db),
        "mds": create_mds_client_from_settings(),
        "logger": get_logger(),
    }


def common_task_env(task: Task, env: Env) -> Env:
    """Build the per-task environment pieces shared by sync and list tasks.

    Args:
        task: the task being prepared. ``task.input`` must contain ``"user"``
            and ``"provider"``. The user mapping must contain ``"uid"``,
            ``"login"`` and ``"email"``.
        env: the worker environment. Reads ``db``, ``mds`` and ``settings``.

    Returns:
        A new dict with these keys:
          * ``db`` and ``mds``: taken from ``env``;
          * ``logger``: a task-scoped logger;
          * ``disk_pair``: a ``DiskPair`` whose source disk is created with
            the decrypted provider secret and whose destination is
            ``create_yandex_disk``.

    Raises:
        KeyError: a required input or user key is missing.
        YandexUserNotFound: ``user["uid"]`` is ``None``.

    Note:
        ``org_id`` and ``domain`` are written into ``task.input["user"]`` in
        place. The task's own input dict is therefore modified.
    """
    task_env = {
        "db": env["db"],
        "mds": env["mds"],
        "logger": get_logger(task_id=task.task_id, org_id=task.org_id),
    }

    payload = task.input
    absent = [name for name in ("user", "provider") if name not in payload]
    if absent:
        raise KeyError("{} key is required".format(absent[0]))

    user = payload["user"]
    user["org_id"] = task.org_id
    user["domain"] = task.domain

    absent_user_fields = [name for name in ("uid", "login", "email") if name not in user]
    if absent_user_fields:
        raise KeyError("user.{} key is required".format(absent_user_fields[0]))

    if user["uid"] is None:
        raise YandexUserNotFound()

    provider = ExternalProvider(payload["provider"])
    drive = _source_drive_from_provider(provider)
    secret: object = _retrieve_secret(env["settings"], task_env["db"], task.org_id, provider)
    task_env["disk_pair"] = DiskPair(
        user=user,
        create_src_disk=get_create_drive_op_with_explicit_secret(secret, drive),
        create_dst_disk=create_yandex_disk,
    )
    return task_env


@trycatch
def make_sync_task_env(task: Task, env: Env) -> Env:
    """Build the environment for a sync task.

    Starts from ``common_task_env`` and adds ``sync_op``, a ``SyncFileOp``
    bound to the task's ``disk_pair``. Because of ``trycatch``, a failure
    here fails the task and yields ``None``.
    """
    task_env = common_task_env(task, env)
    pair = task_env["disk_pair"]
    task_env["sync_op"] = SyncFileOp(pair)
    return task_env


@trycatch
def make_list_task_env(task: Task, env: Env) -> Env:
    """Build the environment for a listing task.

    Starts from ``common_task_env`` and adds two keys:
      * ``file_cls``: the file class for the task's source drive, from
        ``get_file_by_source``;
      * ``mapping``: a one-argument callable that wraps ``files`` in the
        path-mapping op class from ``get_path_mapping_op_by_source``.

    Because of ``trycatch``, a failure here fails the task and yields ``None``.
    """
    task_env = common_task_env(task, env)

    drive = _source_drive_from_provider(ExternalProvider(task.input["provider"]))
    task_env["file_cls"] = get_file_by_source(drive)

    mapping_op_cls = get_path_mapping_op_by_source(drive)

    def build_mapping(files):
        return mapping_op_cls(files)

    task_env["mapping"] = build_mapping
    return task_env
