import json
from datetime import timedelta
from functools import wraps
from itertools import islice
from typing import Callable, Dict

from tractor.disk.db import Database
from tractor_disk.mds import MDSClient
from tractor.disk.models import TaskType
from tractor.models import Task, TaskWorkerStatus
from tractor_disk.workers.sync_stat import Stat
from tractor.util.common import exception_info
from tractor_disk.workers.env import Env
from tractor_disk.workers.list_worker import parse_listing_args
from tractor_disk.disk_error import exception_to_category


def acquire_task(env: Env) -> Task:
    """Ask the database for a SYNC task on behalf of this worker.

    Reads ``env["settings"]``, ``env["db"]`` and ``env["worker_id"]``. The
    expiry timeout handed to the database comes from
    ``settings.tasking.expiry_timeout_in_seconds``.

    Returns whatever ``db.acquire_task`` returns.
    """
    settings = env["settings"]
    database: Database = env["db"]
    # Connection and cursor are both released when the block exits.
    with database.make_connection() as connection, connection.cursor() as cursor:
        return database.acquire_task(
            type=TaskType.SYNC,
            expiry_timeout=timedelta(
                seconds=settings.tasking.expiry_timeout_in_seconds,
            ),
            worker_id=env["worker_id"],
            cur=cursor,
        )


def trycatch(method):
    """Decorate a sync-task handler so that a raised exception fails the task.

    The returned callable invokes ``method(task=task, env=env)``.

    On success the handler's return value is passed through unchanged.

    If the handler raises, the exception is logged through ``env["logger"]``
    and the task is marked as failed with ``db.fail_sync_task``, using
    ``str(exception)`` as the error text; the wrapper then returns ``None``.
    An exception raised while logging or recording that failure is itself
    caught and logged instead of being raised to the caller.
    """

    # wraps() keeps the handler's __name__/__doc__ on the returned callable.
    @wraps(method)
    def wrap(task: "Task", env: "Env"):
        try:
            return method(task=task, env=env)
        except Exception as e:
            try:
                db: Database = env["db"]
                env["logger"].exception(message="sync task failed", exception=e)
                with db.make_connection() as conn:
                    with conn.cursor() as cur:
                        db.fail_sync_task(error=str(e), task_id=task.task_id, cur=cur)
            except Exception as another_e:
                # Best effort only: the failure could not be recorded, so
                # report that instead of crashing the worker.
                env["logger"].exception(message="cannot fail sync task", exception=another_e)

    return wrap


@trycatch
def run_task(task: Task, env: Env):
    """Sync this worker's share of a listing and store the resulting stats.

    Reads from ``env``: ``logger``, ``db``, ``mds``, ``sync_op`` and the
    optional ``log_progress_every`` (defaults to 100).
    Reads from ``task.input``: ``stid``, ``worker_num`` and ``workers_count``.

    Flow:
      * a cancelled task gets ``db.set_error_for_cancelled_task`` and nothing
        else is done;
      * a task whose worker status is not ``PENDING`` is skipped;
      * otherwise the listing is downloaded from MDS by ``stid``, parsed, and
        every ``workers_count``-th row starting at ``worker_num`` is passed to
        ``sync_op``; per-row outcomes are accumulated in ``Stat`` and the
        final stats are saved as JSON with ``db.finish_sync_task``.

    An exception from a single row is counted and logged, and the loop goes
    on with the next row.
    """
    logger = env["logger"]
    logger.info("sync started")

    db: Database = env["db"]
    if task.canceled:
        with db.make_connection() as conn:
            with conn.cursor() as cur:
                db.set_error_for_cancelled_task(task_id=task.task_id, cur=cur)
                logger.info("task is cancelled")
                return

    if task.worker_status != TaskWorkerStatus.PENDING:
        logger.info("worker status is incorrect: {} != PENDING".format(task.worker_status))
        return

    mds: MDSClient = env["mds"]
    sync_op: Callable[[Dict], bool] = env["sync_op"]

    chunk = mds.download(filepath=task.input["stid"])

    # Only the first element of the parsed pair is used here.
    files_to_sync, _ = parse_listing_args(chunk)

    worker_num, workers_count = task.input["worker_num"], task.input["workers_count"]
    log_progress_every = env.get("log_progress_every", 100)
    # Exact number of rows the islice() below yields for this worker. Plain
    # floor division undercounts when the rows do not split evenly between
    # workers (e.g. 10 rows / 3 workers: worker 0 handles 4 rows, not 3).
    own_rows_count = len(range(worker_num, len(files_to_sync), workers_count))
    stat = Stat(total_count=own_rows_count)

    for row in islice(files_to_sync, worker_num, None, workers_count):
        exception, err_category = None, None
        try:
            synced = sync_op(row)
            if synced:
                stat.synced()
                logger.info("synced", **row)
            else:
                # A falsy result is counted as "exists" rather than "synced".
                stat.exists()
        except Exception as e:
            exception, err_category = exception_to_category(e)

        if exception:
            stat.error(category=err_category)
            logger.error(
                "sync error", err_category=err_category, **row, **exception_info(exception)
            )
        processed = stat.get_processed()
        # Report progress periodically and once the whole share is done.
        if processed >= stat.get_total() or (processed % log_progress_every) == 0:
            logger.info("sync status", **stat.get())

    # TODO: maybe save failed rows to mds
    with db.make_connection() as conn:
        with conn.cursor() as cur:
            out = json.dumps(stat.get())
            db.finish_sync_task(output=out, task_id=task.task_id, cur=cur)

    logger.info("sync finished")
