import io
import pytest
import unittest
from unittest import mock

from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from moto import mock_s3

from twitch_airflow_components.operators.sagemaker_manifest_file import SageMakerManifestFileOperator


class TestGenerateManifestOperator(unittest.TestCase):
    """Tests for ``SageMakerManifestFileOperator`` against a moto-mocked S3.

    Every test follows the same procedure:

    * create a source and a destination bucket;
    * upload one object under ``some/prefix/`` in the source bucket;
    * run the operator;
    * compare the manifest written to the destination bucket with the
      expected JSON string.

    The expected strings show the manifest layout: a JSON list whose first
    element is ``{"prefix": "s3://<source bucket>/<source prefix>"}``,
    followed by the keys with the source prefix stripped.
    """

    def setUp(self):
        self.source_bucket = "test_sourcebucket"
        self.dest_bucket = "test_destbucket"
        self.region_name = "us-west-2"
        self.source_prefix = "some/prefix/"
        self.source_prefix_without_slash = "some/prefix"
        self.source_prefix_not_in_keys = "random"
        # The single object uploaded to the source bucket by every test.
        self.source_object_key = "some/prefix/data.csv"
        self.manifest_file_key = "path/to/manifest_file.json"
        self.source_keys = ["some/prefix/path/to/data1.csv", "some/prefix/another/path/to/data2.csv"]

    def _create_buckets_with_source_object(self):
        """Create both buckets and upload the source object.

        Must be called from a test that is decorated with ``@mock_s3`` and
        the ``S3Hook.get_connection`` patch. Otherwise it would try to reach
        real AWS.

        :return: The ``S3Hook`` used for setup; reuse it to read results.
        """
        hook = S3Hook()
        for bucket_name in (self.source_bucket, self.dest_bucket):
            hook.create_bucket(bucket_name=bucket_name, region_name=self.region_name)
        hook.load_file_obj(
            bucket_name=self.source_bucket,
            key=self.source_object_key,
            file_obj=io.BytesIO(b"input"),
            replace=True,
        )
        return hook

    def _make_operator(self, source_prefix, **extra_kwargs):
        """Build the operator under test with the shared bucket/manifest settings.

        ``extra_kwargs`` (e.g. ``source_keys``) are passed only when given, so
        tests that omit them exercise the operator's own defaults.
        """
        return SageMakerManifestFileOperator(
            task_id="generate_manifest_file",
            source_bucket_name=self.source_bucket,
            dest_bucket_name=self.dest_bucket,
            source_prefix=source_prefix,
            manifest_file_key=self.manifest_file_key,
            **extra_kwargs,
        )

    def _read_manifest(self, hook):
        """Return the manifest contents from the destination bucket as a string."""
        manifest = hook.read_key(bucket_name=self.dest_bucket, key=self.manifest_file_key)
        assert manifest is not None
        return manifest

    # NOTE: ``mock_get_connection`` is unused in the test bodies, but
    # ``mock.patch.object`` injects it positionally, so the parameter must stay.

    @mock_s3
    @mock.patch.object(
        S3Hook, "get_connection", return_value=Connection(schema="test_bucket")
    )
    def test_execute(self, mock_get_connection):
        """Without ``source_keys``, keys found under the prefix are listed."""
        hook = self._create_buckets_with_source_object()

        self._make_operator(self.source_prefix).execute(None)

        expected_file_content = (
            '[{"prefix": "s3://test_sourcebucket/some/prefix/"}, "data.csv"]'
        )
        assert self._read_manifest(hook) == expected_file_content

    @mock_s3
    @mock.patch.object(
        S3Hook, "get_connection", return_value=Connection(schema="test_bucket")
    )
    def test_execute_with_source_keys(self, mock_get_connection):
        """Explicit ``source_keys`` are listed (relative to the prefix) in order."""
        hook = self._create_buckets_with_source_object()

        self._make_operator(self.source_prefix, source_keys=self.source_keys).execute(None)

        expected_file_content = (
            '[{"prefix": "s3://test_sourcebucket/some/prefix/"}, "path/to/data1.csv", "another/path/to/data2.csv"]'
        )
        assert self._read_manifest(hook) == expected_file_content

    @mock_s3
    @mock.patch.object(
        S3Hook, "get_connection", return_value=Connection(schema="test_bucket")
    )
    def test_execute_prefix_without_slash(self, mock_get_connection):
        """A prefix without a trailing slash leaves a leading '/' on relative keys."""
        hook = self._create_buckets_with_source_object()

        self._make_operator(self.source_prefix_without_slash).execute(None)

        expected_file_content = '[{"prefix": "s3://test_sourcebucket/some/prefix"}, "/data.csv"]'
        assert self._read_manifest(hook) == expected_file_content

    @mock_s3
    @mock.patch.object(
        S3Hook, "get_connection", return_value=Connection(schema="test_bucket")
    )
    def test_execute_with_source_keys_invalid_prefix(self, mock_get_connection):
        """``source_keys`` that do not start with the prefix raise ``AirflowException``."""
        self._create_buckets_with_source_object()

        op = self._make_operator(self.source_prefix_not_in_keys, source_keys=self.source_keys)

        with pytest.raises(AirflowException):
            op.execute(None)
