import aiopg
import tenacity
import logging
from contextlib import asynccontextmanager
from dataclasses import dataclass

from mail.shiva.stages.api.settings.sharpei import SharpeiSettings

log = logging.getLogger(__name__)


@dataclass
class TaskParams:
    """Parameters describing one job of a shard task.

    Within this module ``shard_id``, ``task_name`` and ``job_no`` are bound to
    the ``shard_id``, ``task`` and ``job_no`` columns of
    ``shiva.shard_running_tasks``. The remaining fields are not read here.
    """

    # NOTE(review): the defaults below are None although the annotations are
    # not Optional; callers are expected to fill them in.
    sharpei: SharpeiSettings = None  # not used in this module; presumably sharpei client settings
    shard_id: int = None  # value for the `shard_id` query parameter
    task_name: str = None  # value for the `task` query parameter
    db_user: str = 'cron'  # not used in this module; presumably the DB role to connect as
    job_no: int = 0  # value for the `job_no` query parameter
    jobs_count: int = 1  # not used in this module; presumably the total number of jobs - TODO confirm


class HuskydbEngine:
    """Thin wrapper around an aiopg pool that hands out huskydb connections."""

    def __init__(self, pg: aiopg.Pool):
        """Store the pool that :meth:`connection` acquires connections from."""
        self.pg_pool = pg

    @asynccontextmanager
    async def connection(self):
        """Yield a connection acquired from the pool.

        If the ``async with`` body raises an ``Exception``, the error is
        logged, the connection is closed and the same exception is re-raised
        to the caller.
        """
        async with self.pg_pool.acquire() as conn:
            try:
                yield conn
            except Exception as exc:
                log.error('Exception during request to huskydb: %s', exc)
                # Presumably closed so that a connection left in an unknown
                # state is not handed out again by the pool.
                conn.close()
                # Bare raise keeps the original traceback untouched.
                raise


async def get_task(huskydb: HuskydbEngine, params: TaskParams):
    """Look up the running-task row for the given shard, task and job.

    Args:
        huskydb: engine used to obtain a database connection.
        params: supplies ``shard_id``, ``task_name`` and ``job_no`` for the
            lookup.

    Returns:
        The first matching row as ``(worker, started)``, where ``started`` is
        formatted as ``YYYY-MM-DD HH24:MI:SS``, or ``None`` when no row
        matches.
    """
    async with huskydb.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                # HH24 (00-23) is used instead of HH (01-12): without an AM/PM
                # marker the 12-hour form makes the start time ambiguous.
                '''
                SELECT worker, to_char(started, 'YYYY-MM-DD HH24:MI:SS')
                  FROM shiva.shard_running_tasks
                 WHERE shard_id = %(shard_id)s
                   AND task = %(task)s
                   AND job_no = %(job_no)s
                ''',
                dict(
                    shard_id=params.shard_id,
                    task=params.task_name,
                    job_no=params.job_no
                )
            )
            return await cur.fetchone()


async def plan_task(huskydb: HuskydbEngine, worker_name: str, params: TaskParams):
    """Try to register a running task for ``worker_name``.

    Inserts a row into ``shiva.shard_running_tasks`` between an explicit
    ``BEGIN`` and ``COMMIT``. Because of ``ON CONFLICT DO NOTHING`` a
    conflicting existing row makes the insert a no-op.

    Args:
        huskydb: engine used to obtain a database connection.
        worker_name: value stored in the ``worker`` column.
        params: supplies ``shard_id``, ``task_name`` and ``job_no``.

    Returns:
        The inserted row as ``(worker, started)`` with ``started`` formatted
        as ``YYYY-MM-DD HH24:MI:SS``, or ``None`` when nothing was inserted.

    Raises:
        Exception: whatever the database calls raised; it is logged and a
            ``ROLLBACK`` is attempted before it is re-raised.
    """
    async with huskydb.connection() as conn:
        async with conn.cursor() as cur:
            try:
                await cur.execute('BEGIN')
                await cur.execute(
                    # HH24 (00-23) instead of HH (01-12): without an AM/PM
                    # marker the 12-hour form is ambiguous.
                    '''
                    INSERT INTO shiva.shard_running_tasks (shard_id, task, job_no, worker)
                    VALUES (%(shard_id)s, %(task)s, %(job_no)s, %(worker)s)
                    ON CONFLICT DO NOTHING
                    RETURNING worker, to_char(started, 'YYYY-MM-DD HH24:MI:SS')
                    ''',
                    dict(
                        shard_id=params.shard_id,
                        task=params.task_name,
                        job_no=params.job_no,
                        worker=worker_name
                    )
                )
                res = await cur.fetchone()
                await cur.execute('COMMIT')
                return res
            except Exception as exc:
                # Log before rolling back so the root cause is always recorded,
                # even if the rollback itself fails.
                log.error('got exception during plan_task: %s', exc)
                try:
                    await cur.execute('ROLLBACK')
                except Exception as rollback_exc:
                    # A failed rollback (e.g. on a broken connection) must not
                    # replace the original error seen by the caller.
                    log.error('failed to rollback plan_task transaction: %s', rollback_exc)
                raise


@tenacity.retry(
    wait=(
        tenacity.wait_exponential(multiplier=0.5, min=1, max=300)
        + tenacity.wait_random(min=0, max=5)
    ),
)
async def finish_task(huskydb: HuskydbEngine, params: TaskParams, status: str = 'complete', notice: str = None):
    """Move a task from the running table to the task-info table.

    A single statement deletes the matching row from
    ``shiva.shard_running_tasks`` and inserts what was deleted into
    ``shiva.shard_tasks_info`` together with ``status`` and ``notice``.

    The call is wrapped in ``tenacity.retry`` with an exponential wait
    (bounded by 1 and 300) plus a random 0-5 addition; no stop condition is
    passed to the decorator.
    """
    bind_args = {
        'shard_id': params.shard_id,
        'task': params.task_name,
        'job_no': params.job_no,
        'status': status,
        'notice': notice,
    }
    async with huskydb.connection() as connection:
        async with connection.cursor() as cursor:
            await cursor.execute(
                '''
                WITH deleted_tasks AS (
                    DELETE FROM shiva.shard_running_tasks
                     WHERE shard_id = %(shard_id)s
                       AND task = %(task)s
                       AND job_no = %(job_no)s
                    RETURNING shard_id, task, job_no, started, worker
                )
                INSERT INTO shiva.shard_tasks_info
                    (shard_id, task, job_no, started, worker, status, notice)
                SELECT shard_id, task, job_no, started, worker, %(status)s, %(notice)s
                  FROM deleted_tasks
                ''',
                bind_args,
            )


async def clean_tasks(huskydb: HuskydbEngine, worker_name: str):
    """Delete every ``shiva.shard_running_tasks`` row whose worker is ``worker_name``."""
    bind_args = {'worker': worker_name}
    async with huskydb.connection() as connection:
        async with connection.cursor() as cursor:
            await cursor.execute(
                '''
                DELETE FROM shiva.shard_running_tasks
                 WHERE worker = %(worker)s
                ''',
                bind_args,
            )


async def get_running_task_stats(huskydb: HuskydbEngine, worker_name: str):
    """Count the running tasks of ``worker_name``, grouped by task.

    Returns:
        All result rows as ``(task, count)`` pairs, as produced by the
        cursor's ``fetchall()``.
    """
    bind_args = {'worker': worker_name}
    async with huskydb.connection() as connection:
        async with connection.cursor() as cursor:
            await cursor.execute(
                '''
                SELECT task, count(*)
                  FROM shiva.shard_running_tasks
                 WHERE worker = %(worker)s
                 GROUP BY task;
                ''',
                bind_args,
            )
            rows = await cursor.fetchall()
            return rows
