import json
from typing import Any, Dict, Optional

from airflow.models.connection import Connection
from twitch_airflow_components.operators.postgres import PostgresOperator

from conductor.operators.operator_iface import IConfiguredOperator, TaskWrapper


class RedshiftOutput:
    """Output handle produced by a Redshift task.

    Holds the S3 URL the task unloaded its results to. Tasks that ran
    without an UNLOAD carry ``None``; reading ``s3_url`` on them raises.
    """

    # Destination of the UNLOAD, or None when the task did not unload.
    _s3_url: Optional[str]

    def __init__(self, s3_url: Optional[str]):
        self._s3_url = s3_url

    @property
    def s3_url(self) -> str:
        """Return the UNLOAD destination URL.

        Raises:
            AttributeError: If the task was created without an unload.
        """
        url = self._s3_url
        if url is not None:
            return url
        raise AttributeError(
            "Tried to access output s3_url on a RedshiftOperator instance that did not have an unload."
        )


class ConfiguredRedshiftOperator(IConfiguredOperator):
    """Builds PostgresOperator tasks that run SQL against the project's Redshift.

    Optionally wraps the query in a Redshift ``UNLOAD`` so its result set is
    written to an S3 location derived from the DAG resources.
    """

    def generate_tasks(
        self,
        task_id: str,
        query: str,
        unload: bool = False,
        unload_prefix: str = "",
        parallel_unload: bool = True,
        unload_format: str = "PARQUET",
        overwrite: bool = True,
    ) -> TaskWrapper[PostgresOperator, RedshiftOutput]:
        """Create a task that executes ``query`` on Redshift.

        Args:
            task_id: Airflow task id; also the first S3 path component when
                unloading.
            query: SQL to run. When ``unload`` is True it is embedded inside a
                ``$$``-quoted UNLOAD statement and therefore must not itself
                contain ``$$``.
            unload: If True, wrap ``query`` in an UNLOAD to S3.
            unload_prefix: Optional extra path component appended after
                ``task_id`` in the S3 destination; ignored when empty.
            parallel_unload: Render ``PARALLEL ON`` when True, else ``OFF``.
            unload_format: Value placed in the UNLOAD ``FORMAT`` clause.
            overwrite: Render ``ALLOWOVERWRITE`` when True.

        Returns:
            A TaskWrapper pairing the PostgresOperator with a RedshiftOutput
            whose ``s3_url`` is set only when ``unload`` is True.

        Raises:
            ValueError: If no Redshift config is set, or if ``unload`` is True
                and ``query`` contains ``$$``.
        """
        config = self.project_resources.env.redshift
        if config is None:
            raise ValueError(
                "Cannot use RedshiftOperator when no Redshift config is set."
            )
        connection = self._build_connection(config)

        output_s3_url: Optional[str] = None
        sql = query
        if unload:
            # Reject before touching S3/DAG resources: a "$$" in the query
            # would close the dollar-quoted literal early and corrupt the SQL.
            if "$$" in query:
                raise ValueError(
                    "Cannot UNLOAD a query containing '$$': it would terminate "
                    "the dollar-quoted string that wraps the query."
                )
            # Nest under task_id, plus the prefix when one is given.
            path_parts = [task_id, unload_prefix] if unload_prefix else [task_id]
            output_s3_url = self.dag_resources.s3_url_for_path(path_parts)
            sql = self._build_unload_query(
                query=query,
                output_s3_url=output_s3_url,
                unload_format=unload_format,
                unload_role=self.project_resources.unload_role_chain(),
                parallel_unload=parallel_unload,
                overwrite=overwrite,
            )

        return TaskWrapper[PostgresOperator, RedshiftOutput](
            PostgresOperator(
                task_id=task_id, sql=sql, connection=connection, dag=self.dag
            ),
            RedshiftOutput(output_s3_url),
        )

    @staticmethod
    def _build_connection(config: Any) -> Connection:
        """Build an IAM-authenticated Airflow Connection from the Redshift config."""
        extra: Dict[str, Any] = {"iam": True, "aws_conn_id": None, "redshift": True}
        if config.cluster_identifier is not None:
            extra["cluster-identifier"] = config.cluster_identifier
        connection = Connection(
            host=config.host,
            login=config.db_user,
            schema=config.db_name,
            extra=json.dumps(extra),
        )
        if config.port is not None:
            connection.port = int(config.port)
        return connection

    @staticmethod
    def _build_unload_query(
        query: str,
        output_s3_url: str,
        unload_format: str,
        unload_role: str,
        parallel_unload: bool,
        overwrite: bool,
    ) -> str:
        """Render the UNLOAD statement that writes ``query``'s results to S3."""
        # The trailing "/" makes Redshift treat the URL as a key prefix
        # (directory) for the unloaded files.
        return (
            "UNLOAD($$ {query} $$)\n"
            "TO '{output_s3_url}/'\n"
            "FORMAT {format}\n"
            "MAXFILESIZE 50 MB\n"
            "IAM_ROLE '{unload_role}'\n"
            "PARALLEL {parallel}\n"
            "{overwrite}"
        ).format(
            query=query,
            output_s3_url=output_s3_url,
            format=unload_format,
            unload_role=unload_role,
            parallel="ON" if parallel_unload else "OFF",
            overwrite="ALLOWOVERWRITE" if overwrite else "",
        )
