from airflow.models import DagRun, TaskInstance

from airflow.settings import Session

from time import sleep

import unittest


# Block at import time until the Airflow metadata database contains at least
# one DagRun whose state is 'success', so the tests below have a finished run
# to inspect. Polls up to ``attempts`` times, sleeping SLEEP_TIME seconds
# between polls, and raises ValueError if no successful run ever appears.
session = Session()
attempts = 30
attempt = 0
SLEEP_TIME = 60
found_success = False

while attempt < attempts:
    attempt += 1
    print("attempting to check database: {attempt}/{attempts}".format(
        attempt=attempt, attempts=attempts))
    try:
        runs = session.query(DagRun).all()
        # '==' rather than substring matching: a run whose state is still
        # None must not raise and mask a successful run later in the list.
        found_success = any(run.state == 'success' for run in runs)
    except Exception as exc:
        # The database may not be reachable/initialised yet; retry later.
        print("database not ready: {0!r}".format(exc))
        found_success = False

    if found_success:
        print("running tests")
        break

    # TODO: check here for failed tasks and raise an error if true.

    # End the current transaction before the next poll. rollback() recovers
    # from a failed query and expires every loaded DagRun, so the next
    # query sees fresh state instead of stale identity-map values.
    session.rollback()
    if attempt < attempts:
        # No point sleeping after the final attempt.
        sleep(SLEEP_TIME)

# Decide on the flag, not the counter: a success found on the very last
# attempt leaves attempt == attempts but must not be reported as a failure.
if not found_success:
    raise ValueError("never got a successful DagRun from the database")


class DagCompletionTests(unittest.TestCase):
    """Verify DagRun / TaskInstance rows in the Airflow metadata database."""

    @staticmethod
    def _first_successful(runs):
        """Return the first DagRun in *runs* whose state is 'success', or None.

        Uses ``==`` so a run whose state is None is skipped instead of
        raising TypeError.
        """
        return next((run for run in runs if run.state == 'success'), None)

    def test_dag_finished(self):
        """At least one DagRun exists and at least one finished successfully."""
        session = Session()
        try:
            runs = session.query(DagRun).all()
            self.assertGreater(len(runs), 0, "no DagRuns found in the database")
            run = self._first_successful(runs)
            self.assertIsNotNone(run, "no DagRun finished with state 'success'")
        finally:
            session.close()

    def test_tasks(self):
        """Every TaskInstance of a successful DagRun has state 'success'."""
        session = Session()
        try:
            runs = session.query(DagRun).all()
            run = self._first_successful(runs)
            # Fail with a clear message rather than StopIteration/AttributeError.
            self.assertIsNotNone(run, "no DagRun finished with state 'success'")

            tasks = (
                session.query(TaskInstance)
                .filter(TaskInstance.dag_id == run.dag_id)
                .filter(TaskInstance.execution_date == run.execution_date)
                .all()
            )
            self.assertGreater(len(tasks), 0, "successful DagRun has no task instances")

            # Collect offenders so a failure reports which tasks were not
            # successful; a None state is reported, not raised on.
            not_successful = [
                (task.task_id, task.state) for task in tasks
                if task.state != 'success'
            ]
            self.assertEqual(not_successful, [],
                             "tasks not in 'success' state: {0}".format(not_successful))
        finally:
            session.close()
