import logging

import sqlalchemy as sa

from .. import db

from .log import Logger

from maps.wikimap.mapspro.libs.python import geolocks


def get_detached_task(cls, id):
    """Fetch the ``cls`` row identified by ``id`` and return it detached.

    The row is looked up through a write session on ``db.CORE_DB`` and then
    handed to ``sa.orm.make_transient`` so that the returned object is no
    longer associated with that session.
    """
    with db.get_write_session(db.CORE_DB) as session:
        query = session.query(cls)
        detached = query.get(id)
        # Drop the session association before the session context closes.
        sa.orm.make_transient(detached)
        return detached


class StageGuard:
    """Context manager that logs task stage progress and releases geolocks.

    Inside the ``with`` block call ``next_stage(name)`` to mark the start of
    each stage; the previous stage is logged as done.  On exit the current
    stage is logged as done or failed and, depending on ``unlock_on_success``
    / ``unlock_on_error`` (both ``True`` by default, callers may flip them),
    the locks loaded for the task are released via ``geolocks.unlock_all``.

    Exceptions raised inside the ``with`` block are never suppressed.
    """

    def __init__(self, task, pgpool):
        """Load the geolocks referenced by ``task``.

        ``task`` must expose ``id`` and either ``lock_ids`` (iterable of lock
        ids) or ``lock_id`` (single id, where ``0`` means "no lock").
        ``pgpool`` is passed through to the ``geolocks`` calls.
        """
        self.task_id = task.id
        self.db_logger = Logger(task.id)
        self.logger = logging.getLogger('StageGuard')

        if hasattr(task, 'lock_ids'):
            self.locks = [geolocks.GeoLock.load(pgpool, lock_id)
                          for lock_id in task.lock_ids]
        elif task.lock_id != 0:
            self.locks = [geolocks.GeoLock.load(pgpool, task.lock_id)]
        else:
            # No lock is attached to the task: nothing to release on exit.
            self.locks = []
            self.db_logger.warning('using dummy lock')

        self.pgpool = pgpool

        self.unlock_on_success = True
        self.unlock_on_error = True
        self._stage = None

    def __enter__(self):
        """Return the guard itself so callers can invoke ``next_stage``."""
        return self

    def next_stage(self, stage):
        """Log the current stage (if any) as done and switch to ``stage``."""
        if self._stage is not None:
            self.db_logger.info(self._stage + ' done')
        self._stage = stage

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Log the outcome of the last stage and release locks if configured.

        Returns ``None`` in every case, so an in-flight exception propagates
        to the caller.
        """
        stage_str = self._stage or 'task'
        if exc_val is None:
            self.db_logger.info(stage_str + ' done')
            if self.unlock_on_success:
                # Not guarded: an unlock failure on the success path is
                # raised to the caller.
                geolocks.unlock_all(self.pgpool, self.locks)
                self.db_logger.info('unlocked')
            return

        # exception occurred
        log_msg = stage_str + ' failed'
        self.db_logger.error(log_msg)

        if self.unlock_on_error:
            self._try_unlock()

    def _try_unlock(self):
        """Best-effort unlock used on the error path.

        A ``RuntimeError`` from ``geolocks.unlock_all`` is logged and
        swallowed so that it does not replace the exception that is already
        propagating out of the ``with`` block.
        """
        try:
            # ``Logger.warning`` instead of the deprecated ``Logger.warn``
            # alias: this runs during error cleanup and must not emit a
            # DeprecationWarning (which can be promoted to an exception).
            self.logger.warning(
                'task id: %s: trying to unlock lock ids: %s',
                self.task_id, ','.join(str(lock.id) for lock in self.locks))
            geolocks.unlock_all(self.pgpool, self.locks)
        except RuntimeError:
            self.logger.exception('task id: %s: failed to unlock', self.task_id)
