# coding: utf-8
"""Interface to husky SQL primitives for huskydb"""
import functools
import json
import logging
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import Optional

import aiopg
from psycopg2 import IntegrityError
from psycopg2.errorcodes import UNIQUE_VIOLATION

from mail.husky.husky.types import UserData, Status, Errors
from mail.pypg.pypg.common import describe_cursor
from mail.pypg.pypg.query_conf import load_from_package
from ora2pg.tools.unicode_helpers import safe_unicode

# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Cluster name sent as the `default_cluster` argument of the queue queries,
# next to the instance's own `husky_cluster`.
DEFAULT_HUSKY_CLUSTER = 'husky-worker'


def retry_as_failed(func):
    """Decorate a coroutine function so a duplicate-task conflict is retried once.

    The wrapped coroutine is first awaited with ``mark_failed=False``.  If that
    attempt raises ``IntegrityError`` whose ``pgcode`` equals
    ``UNIQUE_VIOLATION``, the coroutine is awaited a second time with
    ``mark_failed=True`` and the result of that second call is returned.
    An ``IntegrityError`` with any other code is re-raised unchanged, and
    exceptions of other types propagate untouched.

    The wrapper always supplies ``mark_failed`` itself, so callers must not
    pass that keyword.

    :param func: coroutine function accepting a ``mark_failed`` keyword argument.
    :return: coroutine function with the same name and docstring as ``func``.
    """
    # Preserve __name__/__doc__/__wrapped__ of the decorated coroutine so that
    # introspection and log output refer to the real function.
    @functools.wraps(func)
    async def decorated(*args, **kwargs):
        try:
            log.info('First try')
            return await func(*args, mark_failed=False, **kwargs)
        except IntegrityError as e:
            log.info('Caught expection')
            if e.pgcode == UNIQUE_VIOLATION:
                log.warning('Task is duplicated, marking as failed')
                log.info('Second try')
                return await func(*args, mark_failed=True, **kwargs)
            # Not a uniqueness conflict: let the caller see the original error.
            raise

    return decorated


class HuskydbTransfer:
    """Async wrapper around the huskydb SQL queries for a single shard.

    The SQL statements are loaded once, at class creation, by
    ``load_from_package`` and are executed through cursors taken from the
    ``aiopg`` pool handed to the constructor.
    """

    queries = load_from_package(__package__, __file__)

    def __init__(self, app_args, pg: aiopg.Pool, shard_id: int, husky_cluster: str):
        """Copy the settings from ``app_args`` and remember the pool and shard.

        ``tries_delays`` maps a 1-based try number to the delay taken from
        ``app_args.fail_retry_wait_progression``; when that setting is empty
        the progression 1, 5, 15, 40, 60 minutes is used instead.
        """
        # TODO :: Move to common settings
        self.blackbox = app_args.blackbox
        self.tvm = app_args.tvm
        self.bb_tvm_id = app_args.bb_tvm_id
        self.mailhost = app_args.mailhost
        self.shard_id = shard_id
        self.husky_cluster = husky_cluster
        self.max_tries = app_args.max_tries

        progression = app_args.fail_retry_wait_progression
        if not progression:
            progression = [timedelta(minutes=minutes) for minutes in [1, 5, 15, 40, 60]]
        self.tries_delays = {
            try_number: wait for try_number, wait in enumerate(progression, 1)
        }

        self.pg_pool = pg
        log.info('Initialized husky_transfer shard_id=%d', self.shard_id)

    @staticmethod
    def _dump_task_output(task_output):
        """Return ``task_output`` as JSON text, passing ``None`` through as is."""
        if task_output is None:
            return None
        return json.dumps(task_output)

    @asynccontextmanager
    async def _huskydb_query(self, query, **query_args):
        """Execute ``query.query`` with ``query_args`` and yield the open cursor.

        The cursor comes from ``self.pg_pool`` and stays inside its ``with``
        block for as long as the caller's ``async with`` body runs.
        """
        pooled_cursor = await self.pg_pool.cursor()
        with pooled_cursor as cur:
            await cur.execute(query.query, query_args)
            yield cur

    async def _exec_huskydb_query(self, query, **query_args):
        """Execute ``query`` and discard whatever it returns."""
        async with self._huskydb_query(query, **query_args):
            # Nothing to read: running the statement is the whole point.
            pass

    @retry_as_failed
    async def return_in_progress_tasks(self, mark_failed=False):
        """
        Pick up users left in the "in_progress" status and handle them as
        if their transfer had failed.

        With ``mark_failed`` unset the query receives ``Status.Pending`` and no
        error; with it set the query receives ``Status.Error`` together with
        the ``Errors.WrongArgs`` name.
        """
        if mark_failed:
            new_status, error_name = Status.Error, Errors.WrongArgs.name
        else:
            new_status, error_name = Status.Pending, None

        async with self._huskydb_query(
            self.queries.process_in_progress,
            shard_id=self.shard_id,
            husky_cluster=self.husky_cluster,
            default_cluster=DEFAULT_HUSKY_CLUSTER,
            max_tries=json.dumps(self.max_tries),
            status=new_status,
            error=error_name,
        ) as cur:
            # TODO :: Move to DictCursor
            rows = await cur.fetchall()
            returned_tasks = [row[0] for row in rows if row]
            log.info(
                'Found tasks in "in_progress" status, returned them back: %r',
                returned_tasks
            )

    async def acquire_tasks(self, usercount):
        """
        Fetch up to ``usercount`` users from the db queue and return them as
        ``UserData`` objects built from the result columns.

        As originally documented, the query orders by "last_update" and
        "priority" and marks the selected users as "in_transfer"; the SQL
        itself is not part of this file.  A non-positive ``usercount`` returns
        an empty list without touching the database.
        """
        if usercount <= 0:
            return []

        async with self._huskydb_query(
            self.queries.acquire_tasks,
            shard_id=self.shard_id,
            husky_cluster=self.husky_cluster,
            default_cluster=DEFAULT_HUSKY_CLUSTER,
            usercount=usercount,
        ) as cur:
            col_names = describe_cursor(cur)
            rows = await cur.fetchall()
            return [UserData(**dict(zip(col_names, row))) for row in rows]

    async def _mark_task_as_completed(self, transfer_id, task_output):
        """Run the ``complete_task`` query and return its first result row."""
        async with self._huskydb_query(
            self.queries.complete_task,
            transfer_id=transfer_id,
            task_output=self._dump_task_output(task_output),
        ) as cur:
            row = await cur.fetchone()
            return row

    async def _mark_task_as_failed(self, transfer_id, error, error_message, task_output):
        """Run the ``fail_task`` query and return its first result row."""
        async with self._huskydb_query(
            self.queries.fail_task,
            transfer_id=transfer_id,
            error_type=error.name,
            error_message=safe_unicode(error_message, 'error message'),
            task_output=self._dump_task_output(task_output),
        ) as cur:
            row = await cur.fetchone()
            return row

    async def _mark_task_as_to_be_retried(self, transfer_id, delay, error_message, task_output):
        """Run the ``retry_task`` query and return its first result row."""
        async with self._huskydb_query(
            self.queries.retry_task,
            transfer_id=transfer_id,
            delay=delay,
            error_message=safe_unicode(error_message, 'error message'),
            task_output=self._dump_task_output(task_output),
        ) as cur:
            row = await cur.fetchone()
            return row

    async def _get_task_info(self, transfer_id):
        """Run the ``get_task_info`` query and return its first result row."""
        async with self._huskydb_query(
            self.queries.get_task_info,
            transfer_id=transfer_id,
        ) as cur:
            row = await cur.fetchone()
            return row

    async def update_user_status(self, result):
        """Store the outcome of one transfer attempt.

        A result without error is marked as completed.  A retriable error whose
        next try number is still below ``max_tries`` for the task is scheduled
        for a retry after the delay configured for that try (zero when none is
        configured).  Every other error marks the task as failed.  The error
        message and task output, if any, are passed on to the query.
        """
        error = result.error
        transfer_id = result.transfer_id

        if error == Errors.NoError:
            outcome = await self._mark_task_as_completed(
                transfer_id=transfer_id,
                task_output=result.task_output,
            )
        else:
            retry_delay = None
            if error.is_retriable:
                task, tries = await self._get_task_info(transfer_id=transfer_id)
                next_try = tries + 1
                if next_try < self.max_tries[task]:
                    retry_delay = self.tries_delays.get(next_try, timedelta(seconds=0))

            if retry_delay is not None:
                outcome = await self._mark_task_as_to_be_retried(
                    transfer_id=transfer_id,
                    delay=retry_delay,
                    error_message=result.error_message,
                    task_output=result.task_output,
                )
            else:
                outcome = await self._mark_task_as_failed(
                    transfer_id=transfer_id,
                    error=error,
                    error_message=result.error_message,
                    task_output=result.task_output,
                )

        actual_status, tries = outcome
        log.info('Actual status: %s ; tries: %d', actual_status, tries)

    async def get_cluster_id(self, shard_id):
        """Return the first column of the ``get_cluster_id_by_shard_id`` row.

        Raises ``ValueError`` when the query yields no row for ``shard_id``.
        """
        async with self._huskydb_query(
            self.queries.get_cluster_id_by_shard_id,
            shard_id=shard_id,
        ) as cur:
            row = await cur.fetchone()
            if row is not None:
                return row[0]
            raise ValueError("shard_id not found")

    async def get_max_workers_by_shard_id(self, shard_id: int) -> Optional[int]:
        """Return the first column of the ``get_max_workers_by_shard_id`` row,
        or ``None`` when the query yields no row."""
        async with self._huskydb_query(
            self.queries.get_max_workers_by_shard_id,
            shard_id=shard_id,
        ) as cur:
            row = await cur.fetchone()
            return None if row is None else row[0]
