import time
import logging
import os
from sandbox import sdk2
from sandbox.sdk2.helpers import subprocess
import sandbox.common.types.task as ctt
import sandbox.common.types.resource as ctr
from sandbox.common.errors import TaskFailure
from sandbox.projects.answers.resources import AnswersPostgresql


class AwaitPgDatabaseReady(sdk2.Task):
    """Wait until a Postgres database starts accepting connections.

    Repeatedly runs ``psql <conninfo> -c 'select 1'`` (psql is taken from the
    ``postgres`` resource) until it exits with status 0, failing the task if
    that does not happen within ``timeout`` minutes.
    """

    class Context(sdk2.Context):
        # psql command line used by every probe; built in on_execute.
        cmd = []

    class Parameters(sdk2.Task.Parameters):
        kill_timeout = 30*60
        owner = "MAIL"
        description = "Await while Postgres database is not ready"
        priority = ctt.Priority(ctt.Priority.Class.USER, ctt.Priority.Subclass.NORMAL)
        postgres = sdk2.parameters.Resource("Resource with PSQL", resource_type=AnswersPostgresql)
        poll_interval = sdk2.parameters.Integer("Database status poll interval (seconds)", default=10)
        timeout = sdk2.parameters.Integer("Wait timeout (minutes)", default=10)
        with sdk2.parameters.Group("Database parameters") as database_block:
            host = sdk2.parameters.String("Database host")
            port = sdk2.parameters.Integer("Database port", default=12000)
            name = sdk2.parameters.String("Database name")
            user = sdk2.parameters.String("Database user")
            password = sdk2.parameters.String("Vault item name containing password")

    @staticmethod
    def _quote_conninfo_value(value):
        """Return *value* single-quoted for a libpq ``key=value`` conninfo string.

        Quoting keeps empty values and values containing spaces intact;
        backslashes and single quotes inside the value are backslash-escaped,
        as libpq requires.
        """
        text = "{}".format(value)
        return "'{}'".format(text.replace("\\", "\\\\").replace("'", "\\'"))

    def _build_connection_string(self):
        """Build the libpq conninfo string from the "Database parameters" group.

        Parameters that are ``None`` are omitted so that libpq falls back to its
        own defaults instead of receiving the literal text ``None``.
        """
        fields = (
            ('host', self.Parameters.host),
            ('port', self.Parameters.port),
            ('dbname', self.Parameters.name),
            ('user', self.Parameters.user),
        )
        return " ".join(
            "{}={}".format(key, self._quote_conninfo_value(value))
            for key, value in fields
            if value is not None
        )

    @staticmethod
    def _psql_env(password, time_left):
        """Build the environment for a single psql probe.

        :param password: database password, passed to psql via PGPASSWORD.
        :param time_left: seconds until the wait deadline. It bounds the
            connection attempt through PGCONNECT_TIMEOUT, so an unreachable
            host cannot block a probe much beyond the deadline.
        :return: dict to pass as the subprocess environment.
        """
        # libpq's minimum connect timeout is 2 s; 0 would mean "wait forever".
        connect_timeout = max(2, int(time_left))
        return {
            'PGPASSWORD': password,
            'PGCONNECT_TIMEOUT': str(connect_timeout),
        }

    def _is_db_started(self, pl, env=None):
        """Run the probe command (``Context.cmd``) once.

        :param pl: ProcessLog whose stdout/stderr receive psql's output.
        :param env: environment for psql; when omitted, only PGPASSWORD is set
            (read from the Vault item named by ``Parameters.password``).
        :return: True if psql exited with status 0.
        """
        if env is None:
            env = {'PGPASSWORD': sdk2.Vault.data(self.Parameters.password)}
        p = subprocess.Popen(
            self.Context.cmd,
            stdin=subprocess.PIPE,
            # Log output instead of piping it into an unread pipe: a PIPE
            # that is never drained can block psql before it exits.
            stdout=pl.stdout,
            stderr=pl.stderr,
            env=env
        )
        # psql -c does not need stdin; close it so any read sees EOF at once.
        p.stdin.close()
        return p.wait() == 0

    def on_create(self):
        """Set ``postgres`` to a READY linux AnswersPostgresql resource.

        Uses the first resource returned by the search. Leaves the parameter
        untouched when no such resource exists.
        """
        resource = AnswersPostgresql.find(
            status=ctr.State.READY,
            attrs=dict(arch='linux')
        ).first()
        if resource is not None:
            self.Parameters.postgres = resource.id

    def on_execute(self):
        """Poll the database with psql until it responds or the timeout expires.

        :raises TaskFailure: if no probe has succeeded after
            ``Parameters.timeout`` minutes.
        """
        postgres_root = str(sdk2.ResourceData(sdk2.Resource[self.Parameters.postgres]).path)
        psql_path = os.path.join(postgres_root, 'bin', 'psql')
        # -w: never prompt for a password; a prompt would hang the probe.
        self.Context.cmd = [psql_path, '-w', self._build_connection_string(), '-c', 'select 1']

        # Fetch the secret once rather than on every poll iteration.
        password = sdk2.Vault.data(self.Parameters.password)
        deadline = time.time() + self.Parameters.timeout * 60

        with sdk2.helpers.ProcessLog(self, logging.getLogger('psql')) as pl:
            while True:
                env = self._psql_env(password, deadline - time.time())
                if self._is_db_started(pl, env):
                    return
                remaining = deadline - time.time()
                if remaining <= 0:
                    raise TaskFailure("Wait timeout exceeded")
                # Do not sleep past the deadline.
                time.sleep(min(self.Parameters.poll_interval, remaining))
