import functools
import json
from contextlib import closing
from datetime import timedelta
from typing import Dict, Callable, List, Tuple

from tractor.disk.db import Database
from tractor.error import NOT_ENOUGH_QUOTA, EXTERNAL_USER_NOT_FOUND
from tractor_disk.disk_error import ExternalUserNotFound
from tractor.logger import DeployLogger
from tractor_disk.mds import MDSClient
from tractor.disk.models import TaskType
from tractor.models import Task, TaskWorkerStatus
from tractor_disk.common import NULL_STR
from tractor_disk.settings import Settings
from tractor_disk.workers.env import Env
from tractor_disk.workers.env import Env


def acquire_task(env: Env) -> Task:
    """Ask the database to hand this worker the next LIST task.

    The task type is fixed to ``TaskType.LIST``; the expiry timeout comes from
    ``settings.tasking.expiry_timeout_in_seconds`` and the worker is identified
    by ``env["worker_id"]``. Whatever ``db.acquire_task`` returns is returned.
    """
    cfg: Settings = env["settings"]
    database: Database = env["db"]
    with closing(database.make_connection()) as connection, connection, connection.cursor() as cursor:
        expiry = timedelta(seconds=cfg.tasking.expiry_timeout_in_seconds)
        return database.acquire_task(
            type=TaskType.LIST,
            expiry_timeout=expiry,
            worker_id=env["worker_id"],
            cur=cursor,
        )


def trycatch(method):
    """Decorate a task runner so that any failure marks the listing task failed.

    The decorated callable is invoked as ``method(task=task, env=env)`` and its
    return value is passed through. If it raises an ``Exception``, the error is
    logged, converted to a message via ``_error_message_from_exception`` and
    recorded with ``db.fail_listing_task``; the exception is NOT re-raised and
    the wrapper returns ``None``. If recording the failure itself raises, that
    secondary error is logged as well.
    """

    # functools.wraps keeps the wrapped function's name/docstring for logs and
    # introspection.
    @functools.wraps(method)
    def wrap(task: Task, env: Env):
        try:
            return method(task=task, env=env)
        except Exception as e:
            try:
                db: Database = env["db"]
                env["logger"].exception(message="list task failed", exception=e)
                msg = _error_message_from_exception(e)
                with closing(db.make_connection()) as conn, conn, conn.cursor() as cur:
                    db.fail_listing_task(error=msg, task_id=task.task_id, cur=cur)
            except Exception as another_e:
                # Best effort: never let failure-reporting crash the worker.
                env["logger"].exception(message="cannot fail list task", exception=another_e)

    return wrap


def _error_message_from_exception(e):
    """Return the error text to store on a failed listing task.

    ``ExternalUserNotFound`` maps to the ``EXTERNAL_USER_NOT_FOUND`` constant;
    any other exception is stored as ``str(e)``.
    """
    known_user_error = isinstance(e, ExternalUserNotFound)
    return EXTERNAL_USER_NOT_FOUND if known_user_error else str(e)


def dump_listing_args(files, multiple_parent_files, skipped_files) -> bytes:
    """Serialise the listing result to JSON and return it as encoded bytes.

    Keys, in order: ``files``, ``user_multiple_parent_files``, ``skipped_files``.
    """
    payload = dict(
        files=files,
        user_multiple_parent_files=multiple_parent_files,
        skipped_files=skipped_files,
    )
    serialized = json.dumps(payload)
    return serialized.encode()


def parse_listing_args(args: str) -> Tuple[List, Dict]:
    """Parse a listing payload and return ``(files, user_multiple_parent_files)``.

    ``files`` is the list of file-info dicts written by ``dump_listing_args``
    (hence ``List``, not ``Dict``). ``json.loads`` also accepts ``bytes``, so
    the encoded output of ``dump_listing_args`` can be passed directly. The
    ``skipped_files`` entry, if present, is ignored.

    Raises ``KeyError`` if either required key is missing.
    """
    parsed = json.loads(args)
    return parsed["files"], parsed["user_multiple_parent_files"]


def get_quota(ya_disk) -> int:
    """Return the ``total_space`` value reported by ``ya_disk.quote()``.

    NOTE(review): this is total space, not free space; ``run_task`` compares
    the listed size against it — confirm that is the intended check.
    """
    return ya_disk.quote()["total_space"]


def make_file_info(file) -> Dict:
    """Build the per-file summary dict stored in the listing payload.

    Each value comes from calling the matching accessor on ``file``; keys use
    the camelCase names expected in the listing.
    """
    key_to_accessor = (
        ("id", "id"),
        ("size", "size"),
        ("path", "path"),
        ("mimeType", "mime_type"),
        ("ownedByMe", "owned_by_me"),
    )
    return {key: getattr(file, accessor)() for key, accessor in key_to_accessor}


def _partition_listed_files(files, singleparent, src_disk, file_cls) -> Tuple[List[Dict], List[Dict], int]:
    """Split listed files into downloadable and skipped summaries.

    Mutates each raw file dict in place, setting ``path`` from ``singleparent``.
    Files whose wrapper reports ``owned_by_me()`` as false are dropped. Returns
    ``(downloadable, skipped, downloadable_size)``; sizes equal to ``NULL_STR``
    are not counted toward the total.
    """
    downloadable: List[Dict] = []
    skipped: List[Dict] = []
    total_size = 0
    for file in files:
        file["path"] = singleparent[file["id"]]
        wrapped = file_cls(file)
        if not wrapped.owned_by_me():
            continue
        if src_disk.is_downloadable(file):
            target = downloadable
            size = wrapped.size()
            if size != NULL_STR:
                total_size += int(size)
        else:
            target = skipped
        target.append(make_file_info(wrapped))
    return downloadable, skipped, total_size


def _fail_not_enough_quota(db, logger, task, stid, files_count, files_size, quota):
    """Record a NOT_ENOUGH_QUOTA failure along with the listing summary.

    Any exception while recording is logged and swallowed (not re-raised).
    """
    try:
        logger.error(message=NOT_ENOUGH_QUOTA, src_files_sz=files_size)
        with closing(db.make_connection()) as conn, conn, conn.cursor() as cur:
            db.fail_listing_task(
                error=NOT_ENOUGH_QUOTA,
                task_id=task.task_id,
                cur=cur,
                stid=stid,
                files_count=files_count,
                files_size=files_size,
                quota=quota,
            )
    except Exception as another_e:
        logger.exception(message="cannot fail list task", exception=another_e)


@trycatch
def run_task(task: Task, env: Env):
    """Execute one LIST task.

    Steps:
      * a canceled task gets its cancellation error recorded, then returns;
      * a task not in ``PENDING`` worker status is left untouched;
      * otherwise the source disk's files are listed, mapped to single/multiple
        parents, partitioned into downloadable/skipped entries, uploaded to MDS
        as a JSON listing, and their total size is compared with the
        destination quota — failing the task if it exceeds it, finishing it
        with a JSON summary otherwise.
    Exceptions propagate to the ``trycatch`` decorator.
    """
    db: Database = env["db"]
    logger: DeployLogger = env["logger"]
    if task.canceled:
        with closing(db.make_connection()) as conn, conn, conn.cursor() as cur:
            db.set_error_for_cancelled_task(task_id=task.task_id, cur=cur)
        # Logged after the transaction has committed.
        logger.info(message="listing task canceled")
        return

    if task.worker_status != TaskWorkerStatus.PENDING:
        return

    src_disk = env["disk_pair"].src_disk
    dst_disk = env["disk_pair"].dst_disk
    mds: MDSClient = env["mds"]
    # Called with a keyword argument and returns a zero-arg callable.
    mapping_op: Callable[..., Callable] = env["mapping"]
    file_cls = env["file_cls"]

    files = src_disk.get_files()
    if files:
        # Loop-invariant: query the source disk once instead of once per file.
        root_folder_id = src_disk.root_folder_id()
        for file in files:
            file["root_folder_id"] = root_folder_id

    logger.info(message="listed files", count=str(len(files)))

    op = mapping_op(files=files)
    singleparent, multiparent = op()

    ret, skipped, src_files_sz = _partition_listed_files(files, singleparent, src_disk, file_cls)

    data = dump_listing_args(files=ret, multiple_parent_files=multiparent, skipped_files=skipped)
    stid = mds.upload(filename="listing", data=data)
    logger.info(message="listing uploaded to mds", stid=stid)
    quota = get_quota(dst_disk)

    if src_files_sz > quota:
        _fail_not_enough_quota(
            db,
            logger,
            task,
            stid=stid,
            files_count=len(ret),
            files_size=src_files_sz,
            quota=quota,
        )
        return

    output = json.dumps(
        {"stid": stid, "files_count": len(ret), "files_size": src_files_sz, "quota": quota}
    )
    with closing(db.make_connection()) as conn, conn, conn.cursor() as cur:
        db.finish_task(output=output, task_id=task.task_id, cur=cur)
