from load.projects.cloud.loadtesting.db.tables import AgentVersionTable, AgentVersionStatus
from typing import Optional


class AgentVersionQueries:
    """Lookup and status-transition helpers for ``AgentVersionTable`` rows.

    All methods operate through the session given to the constructor
    (presumably a SQLAlchemy ``Session`` -- it is only used via ``query()``
    and ``add()`` here). No method calls ``commit()`` or ``flush()``
    explicitly: status changes are attribute assignments on loaded rows, so
    persisting them is left to whoever owns the session.
    """

    def __init__(self, session):
        """Remember ``session``; every query in this class goes through it."""
        self._session = session

    def get(self, image_id: str) -> Optional[AgentVersionTable]:
        """Return the first version whose ``image_id`` matches, or ``None``."""
        return self._session.query(AgentVersionTable).filter(AgentVersionTable.image_id == image_id).first()

    def get_target(self) -> Optional[AgentVersionTable]:
        """Return the first version with TARGET status, or ``None`` if there is none."""
        return self._session.query(AgentVersionTable).filter(
            AgentVersionTable.status == AgentVersionStatus.TARGET.value
        ).first()

    def switch_target(self, new_target):
        """Make ``new_target`` the TARGET version.

        Every version currently in TARGET status is moved to ACTUAL first, then
        ``new_target`` is marked TARGET. Returns ``new_target``.
        """
        for old_target in self._session.query(AgentVersionTable).filter(
                AgentVersionTable.status == AgentVersionStatus.TARGET.value
        ).all():
            old_target.status = AgentVersionStatus.ACTUAL.value
        # Assigned last so that new_target ends up TARGET even if it was one of
        # the rows demoted by the loop above.
        new_target.status = AgentVersionStatus.TARGET.value
        return new_target

    def add(self, version: AgentVersionTable):
        """Add ``version`` to the session."""
        self._session.add(version)

    def change_status(self, version: AgentVersionTable, new_status: AgentVersionStatus):
        """Set ``version.status`` to the value of ``new_status``."""
        version.status = new_status.value

    def deprecate_up_to(self, version_id):
        """Mark ACTUAL versions with revision <= the border version's as DEPRECATED.

        ``version_id`` is looked up with :meth:`get`, i.e. it is matched against
        ``image_id``.

        Returns the list of versions whose status was changed.

        Raises:
            AssertionError: if the border version does not exist or has no
                (truthy) revision, if there is no target version, or if the
                target's revision is not greater than the border revision.
        """
        return self._restatus_up_to(
            version_id,
            status_filter=AgentVersionTable.status == AgentVersionStatus.ACTUAL.value,
            new_status=AgentVersionStatus.DEPRECATED,
            not_found_message="Version not found",
            target_message='Target version can not be deprecated.',
        )

    def outdate_up_to(self, version_id):
        """Mark ACTUAL/DEPRECATED versions with revision <= the border's as OUTDATED.

        ``version_id`` is looked up with :meth:`get`, i.e. it is matched against
        ``image_id``.

        Returns the list of versions whose status was changed.

        Raises:
            AssertionError: if the border version does not exist or has no
                (truthy) revision, if there is no target version, or if the
                target's revision is not greater than the border revision.
        """
        return self._restatus_up_to(
            version_id,
            status_filter=AgentVersionTable.status.in_(
                [AgentVersionStatus.ACTUAL.value, AgentVersionStatus.DEPRECATED.value]
            ),
            new_status=AgentVersionStatus.OUTDATED,
            not_found_message="Version not found.",
            target_message='Target version can not be outdated.',
        )

    def _restatus_up_to(self, version_id, *, status_filter, new_status, not_found_message, target_message):
        """Shared body of :meth:`deprecate_up_to` and :meth:`outdate_up_to`.

        Validates the border version against the current target, then assigns
        ``new_status`` to every version matching ``status_filter`` whose
        revision is <= the border revision, and returns those versions.

        The checks raise ``AssertionError`` explicitly instead of using the
        ``assert`` statement: ``assert`` is removed under ``python -O``, which
        would silently drop the guard protecting the target version.
        """
        border_version = self.get(version_id)
        if border_version is None:
            raise AssertionError(not_found_message)
        if not border_version.revision:
            raise AssertionError("Version with revision required.")

        current_target = self.get_target()
        if current_target is None:
            # Without a target there is nothing to compare the border against,
            # so refuse rather than fail on ``None.revision``.
            raise AssertionError('Target version not found.')
        if not (current_target.revision > border_version.revision):
            raise AssertionError(target_message)

        affected = self._session.query(AgentVersionTable).filter(
            AgentVersionTable.revision <= border_version.revision,
            status_filter,
        ).all()
        for version in affected:
            version.status = new_status.value
        return affected
