import asyncio
from abc import ABC, abstractmethod
from asyncio import Future
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import Manager, Queue
from typing import Any, Dict, Iterable, List, Optional, Union

from yql.api.v1.client import YqlClient
from yql.client.parameter_value_builder import YqlParameterValueBuilder
from yt.wrapper import TablePath, YtClient
from yt.wrapper.default_config import retries_config


class BaseYtTask(ABC):
    """Base class for awaitable tasks that operate on a single YT cluster.

    Subclasses implement the asynchronous ``__call__``; ``await task`` is a
    shortcut for ``await task()``.
    """

    # Chunk size used by subclasses when moving rows between processes.
    ITER_SIZE: int = 10000
    # Poll interval (seconds) used by subclasses while waiting on a queue.
    QUEUE_SLEEP_TIME: float = 0.3
    # Base timeout in seconds; converted to milliseconds for the YT client.
    TIMEOUT: int = 60

    def __init__(self, cluster: str, yt_token: str):
        self._cluster = cluster
        self._yt_token = yt_token

    def __await__(self):
        # ``await task`` delegates to the coroutine produced by ``task()``.
        call_coroutine = self()
        return call_coroutine.__await__()

    @abstractmethod
    async def __call__(self, *args, **kwargs):
        raise NotImplementedError

    def make_yt_client(self) -> YtClient:
        """Build a ``YtClient`` for ``self._cluster`` with retry/timeout settings.

        All proxy timeouts and the retry back-off ceiling are derived from
        ``TIMEOUT``; requests are retried up to 6 times with exponential
        back-off starting at 30 seconds.
        """
        timeout_ms = self.TIMEOUT * 1000

        exponential_backoff = {
            "policy": "exponential",
            "exponential_policy": {
                "start_timeout": 30000,  # ms
                "base": 2,
                "max_timeout": timeout_ms,
                "decay_factor_bound": 0.3,
            },
        }
        proxy_settings = {
            "retries": retries_config(
                count=6,
                enable=True,
                total_timeout=None,
                backoff=exponential_backoff,
            ),
            "connect_timeout": timeout_ms,
            "request_timeout": timeout_ms,
            "heavy_request_timeout": timeout_ms,
        }

        return YtClient(
            proxy=self._cluster,
            token=self._yt_token,
            config={"proxy": proxy_settings},
        )


class BaseYtExportTask(BaseYtTask):
    """Export data produced in this process into a YT table.

    Chunks come from ``data_producer.<CHUNKED_ITERATOR_METHOD_NAME>``, an async
    iterator factory called with ``ITER_SIZE``. They are handed over one at a
    time through a manager queue to a single worker process, which writes them
    to ``table`` inside one YT transaction (see ``write_to_yt``).
    """

    # Schema attribute applied when the table is (re)created.
    TABLE_SCHEMA: List[Dict[str, Union[str, bool]]] = []

    ITER_SIZE: int = 10000
    # Seconds between polls while the hand-over queue is full. The name is kept
    # as-is because subclasses may override it.
    QUERY_FULL_SLEEP_TIME: float = 0.3
    # Name of the data producer method that yields chunks of records.
    CHUNKED_ITERATOR_METHOD_NAME: str = ""
    # Remove and recreate the table on every run instead of appending to it.
    RECREATE_TABLE_EVERY_RUN: bool = True
    # When set, the table is sorted by these columns after all chunks are written.
    SORTING_PARAMS: Optional[List[str]] = None

    TIMEOUT: int = 60

    def __init__(self, cluster: str, token: str, table: str, data_producer: object):
        super().__init__(cluster, token)
        self._table = table
        self._data_iterator = getattr(data_producer, self.CHUNKED_ITERATOR_METHOD_NAME)

    def __reduce__(self):
        """Pickle the task for the worker process without the producer method."""
        reduced = super().__reduce__()

        state_dict = reduced[2].copy()
        # This property is not picklable, but not needed in subprocess
        del state_dict["_data_iterator"]

        return reduced[0], reduced[1], state_dict

    async def __call__(self, *args, **kwargs):
        """Stream every producer chunk to the worker process, then wait for it.

        Raises:
            Whatever the data producer raises, the worker's exception (surfaced
            either while queueing or when awaiting the worker), or
            ``CancelledError`` if this task is cancelled.
        """
        with ProcessPoolExecutor(max_workers=1) as executor:
            with Manager() as manager:
                # maxsize=1: at most one chunk is in flight at any time.
                queue = manager.Queue(maxsize=1)
                fut = asyncio.get_event_loop().run_in_executor(
                    executor, self.write_to_yt, queue
                )

                try:
                    async for records in self._data_iterator(self.ITER_SIZE):
                        await self._put_in_queue(queue, records, fut)

                    # End of data signal
                    await self._put_in_queue(queue, None, fut)
                except BaseException:
                    # BaseException (not Exception) so that task cancellation is
                    # handled too. Without the end-of-data signal the worker never
                    # reaches its commit, so do NOT await it here. Its queue.get()
                    # eventually fails (TIMEOUT), or presumably sooner once the
                    # manager is shut down on leaving the block above.
                    fut.cancel()
                    raise

                # Wait for the worker to finish; re-raises its exception, if any.
                await fut

    @classmethod
    async def _put_in_queue(cls, queue: Queue, data: Any, fut: Future):
        """Put ``data`` into ``queue``, polling while it is full.

        Raises:
            The worker's exception, or a generic ``Exception``, if the worker
            finished while the queue was still full.
        """
        while queue.full():
            if fut.done():
                exc = fut.exception()
                raise exc if exc else Exception("Subprocess finished unexpectedly")
            await asyncio.sleep(cls.QUERY_FULL_SLEEP_TIME)

        # The worker only consumes, so this put does not block for long.
        queue.put(data)

    def write_to_yt(self, queue: Queue):
        """Consume chunks from ``queue`` and append them to ``self._table``.

        Runs in the worker process. Everything happens inside one YT
        transaction; ``None`` on the queue marks the end of data, after which
        the table is optionally sorted in place by ``SORTING_PARAMS``.

        Raises:
            Empty: presumably, if no item arrives within ``TIMEOUT`` seconds,
                since ``get`` is called with ``timeout=self.TIMEOUT``.
        """
        yt_client = self.make_yt_client()
        table_path = TablePath(self._table, append=True)

        with yt_client.Transaction(timeout=self.TIMEOUT * 1000):
            table_exists = yt_client.exists(table_path)
            if table_exists and self.RECREATE_TABLE_EVERY_RUN:
                yt_client.remove(table_path)
            if self.RECREATE_TABLE_EVERY_RUN or not table_exists:
                yt_client.create(
                    "table", table_path, attributes={"schema": self.TABLE_SCHEMA}
                )

            while True:
                records = queue.get(block=True, timeout=self.TIMEOUT)
                if records is None:
                    if self.SORTING_PARAMS:
                        yt_client.run_sort(
                            source_table=self._table, sort_by=self.SORTING_PARAMS
                        )
                    break
                yt_client.write_table(table_path, records)


class BaseYtImportDirTask(BaseYtTask):
    """Import every not-yet-processed table of a YT directory into a consumer.

    For each table, rows are read in a worker process, decoded, grouped into
    chunks of ``ITER_SIZE`` and streamed back through a manager queue to
    ``data_consumer.<CHUNKED_WRITER_METHOD_NAME>``. After a table has been
    fully consumed, it is tagged with ``PROCESSED_ATTR`` and skipped by later
    runs.
    """

    # Name of the data consumer method that receives the async chunk iterator.
    CHUNKED_WRITER_METHOD_NAME: str = ""
    # Attribute set to True on a table once it has been imported.
    PROCESSED_ATTR = "_geosmb_processed"

    def __init__(self, cluster: str, yt_token: str, yt_dir: str, data_consumer: object):
        super().__init__(cluster, yt_token)
        self._yt_dir = yt_dir
        self._data_generator_writer = getattr(
            data_consumer, self.CHUNKED_WRITER_METHOD_NAME
        )

    def __reduce__(self):
        """Pickle the task for the worker process without the consumer method."""
        reduced = super().__reduce__()

        state_dict = reduced[2].copy()
        # This property is not picklable, but not needed in subprocess
        del state_dict["_data_generator_writer"]

        return reduced[0], reduced[1], state_dict

    async def __call__(self):
        """Import the unprocessed tables of ``yt_dir`` one by one, in sorted order."""
        yt_client = self.make_yt_client()
        with ProcessPoolExecutor(max_workers=1) as executor:
            with Manager() as manager:
                for table in yt_client.list(
                    self._yt_dir,
                    attributes=["type", self.PROCESSED_ATTR],
                    absolute=True,
                    sort=True,
                ):
                    # Skip non-table nodes and tables imported by a previous run.
                    if table.attributes["type"] != "table" or table.attributes.get(
                        self.PROCESSED_ATTR, False
                    ):
                        continue

                    queue = manager.Queue(maxsize=1)
                    # NOTE(review): yt_client is shipped to the worker process,
                    # so it has to be picklable.
                    fut = asyncio.get_event_loop().run_in_executor(
                        executor, self.read_from_yt, queue, yt_client, str(table)
                    )

                    queue_iterator = QueueIterator(
                        queue=queue, fut=fut, queue_sleep_time=self.QUEUE_SLEEP_TIME
                    )
                    data_generator_writer = DataGeneratorWriter(
                        data_generator_writer=self._data_generator_writer,
                        queue_iterator=queue_iterator,
                        fut=fut,
                    )

                    await data_generator_writer()

                    # Only reached if the await above did not raise.
                    yt_client.set_attribute(str(table), self.PROCESSED_ATTR, True)

    def read_from_yt(self, queue: Queue, yt_client: YtClient, table_name: str):
        """Read ``table_name`` and push decoded rows to ``queue`` in chunks.

        Runs in the worker process. Falsy decoded rows are skipped. Chunks hold
        ``ITER_SIZE`` rows (the last may be shorter), and ``None`` follows the
        last chunk. Returns only after every queued item has been acknowledged
        with ``task_done``.
        """
        data_chunk = []

        read_table_iter = self.make_read_table_iter(
            yt_client=yt_client, table_name=table_name
        )
        for row in read_table_iter:
            decoded_row = self._yt_row_decode(row)
            if decoded_row:
                data_chunk.append(decoded_row)

            if len(data_chunk) == self.ITER_SIZE:
                queue.put(data_chunk, block=True, timeout=self.TIMEOUT)
                data_chunk = []

        if data_chunk:
            queue.put(data_chunk, block=True, timeout=self.TIMEOUT)

        # End of data signal
        queue.put(None, block=True, timeout=self.TIMEOUT)
        # Make sure signal received, meaning last chunk was successfully processed
        queue.join()

    def make_read_table_iter(self, yt_client: YtClient, table_name: str) -> Iterable:
        """Return the raw row iterator for ``table_name``.

        This is an instance method (subclasses read instance state); by default
        it delegates to ``yt_client.read_table``.
        """
        return yt_client.read_table(table_name)

    @classmethod
    def _yt_row_decode(cls, row: Any) -> Any:
        """Convert one raw row; a falsy result skips the row. Identity by default."""
        return row


class YqlOperationError(Exception):
    """Raised when YQL query results come back with status ``"ERROR"``.

    The exception message is the ``text`` of those results.
    """


class BaseYtImportDirWithYqlTask(BaseYtImportDirTask):
    """Directory import whose rows come from a YQL query per table.

    ``YQL`` is a ``str.format`` template that receives ``cluster_name`` and
    ``table_name``; the query's first result table replaces a raw table read.
    """

    YQL: str = ""

    def __init__(
        self,
        cluster: str,
        yt_token: str,
        yql_token: str,
        yt_dir: str,
        data_consumer: object,
    ):
        super().__init__(
            cluster=cluster,
            yt_token=yt_token,
            yt_dir=yt_dir,
            data_consumer=data_consumer,
        )
        self._yql_token = yql_token

    def make_read_table_iter(
        self, yt_client: YtClient, table_name: str
    ) -> Iterable[tuple]:
        """Run ``YQL`` against ``table_name`` and return its row iterator.

        ``yt_client`` is part of the overridden signature but is not used here.

        Raises:
            YqlOperationError: if the query results have status ``"ERROR"``.
        """
        with YqlClient(token=self._yql_token) as yql_client:
            query_text = self.YQL.format(
                cluster_name=self._cluster, table_name=table_name
            )
            request = yql_client.query(query_text, syntax_version=1)
            request.run()

            results = request.get_results()
            if results.status == "ERROR":
                raise YqlOperationError(results.text)

            return results.table.get_iterator()


class BaseExecuteYqlTask:
    """Awaitable task that runs one YQL query in a worker process.

    ``YQL`` is a ``str.format`` template filled from ``yql_format_kwargs``.
    """

    YQL: str = ""

    def __init__(self, yql_token: str, yql_format_kwargs: Optional[dict] = None):
        self._yql_token = yql_token
        # Falsy input (None or an empty dict) becomes a fresh empty dict.
        self.yql_format_kwargs = yql_format_kwargs or dict()

    def __await__(self):
        # ``await task`` delegates to the coroutine produced by ``task()``.
        call_coroutine = self()
        return call_coroutine.__await__()

    async def __call__(self) -> None:
        """Execute the query in a separate process and wait for it."""
        with ProcessPoolExecutor(max_workers=1) as executor:
            loop = asyncio.get_event_loop()
            await loop.run_in_executor(executor, self.execute_yql)

    def execute_yql(self) -> None:
        """Run the formatted query synchronously.

        Raises:
            YqlOperationError: if the query results have status ``"ERROR"``.
        """
        with YqlClient(token=self._yql_token) as yql_client:
            query_text = self.YQL.format(**self.yql_format_kwargs)
            request = yql_client.query(query_text, syntax_version=1)
            request.run()

            results = request.get_results()
            if results.status == "ERROR":
                raise YqlOperationError(results.text)


class BaseImportWithYqlTask:
    """Import the rows of a parameterized YQL query into a data consumer.

    The query runs in a worker process. Its rows are decoded, grouped into
    chunks of ``ITER_SIZE`` and streamed back through a manager queue to
    ``data_consumer.<CHUNKED_WRITER_METHOD_NAME>``.
    """

    YQL: str = ""
    # Name of the data consumer method that receives the async chunk iterator.
    CHUNKED_WRITER_METHOD_NAME: str = ""
    ITER_SIZE: int = 10000
    QUEUE_SLEEP_TIME: float = 0.3
    TIMEOUT: int = 60

    def __init__(
        self,
        yql_token: str,
        data_consumer: object,
        yql_format_kwargs: Optional[dict] = None,
    ):
        self._data_generator_writer = getattr(
            data_consumer, self.CHUNKED_WRITER_METHOD_NAME
        )
        self._yql_token = yql_token
        # Falsy input (None or an empty dict) becomes a fresh empty dict.
        self._yql_format_kwargs = yql_format_kwargs or dict()

    def __await__(self):
        # ``await task`` delegates to the coroutine produced by ``task()``.
        call_coroutine = self()
        return call_coroutine.__await__()

    def __reduce__(self):
        """Pickle the task for the worker process without the consumer method."""
        reduced = super().__reduce__()

        picklable_state = dict(reduced[2])
        # Bound consumer method: not picklable, and unused in the worker.
        picklable_state.pop("_data_generator_writer")

        return reduced[0], reduced[1], picklable_state

    async def __call__(self):
        """Run the query in a worker process and feed its chunks to the consumer."""
        with ProcessPoolExecutor(max_workers=1) as executor:
            with Manager() as manager:
                queue = manager.Queue(maxsize=1)
                loop = asyncio.get_event_loop()
                query_params = await self.fetch_yql_query_params()
                fut = loop.run_in_executor(
                    executor, self.read_from_yt, queue, query_params
                )

                writer = DataGeneratorWriter(
                    data_generator_writer=self._data_generator_writer,
                    queue_iterator=QueueIterator(
                        queue=queue, fut=fut, queue_sleep_time=self.QUEUE_SLEEP_TIME
                    ),
                    fut=fut,
                )
                await writer()

    def read_from_yt(self, queue: Queue, yql_query_params: dict):
        """Execute the query and push decoded rows to ``queue`` in chunks.

        Runs in the worker process. Falsy decoded rows are skipped. Chunks hold
        ``ITER_SIZE`` rows (the last may be shorter), and ``None`` follows the
        last chunk. Returns only after every queued item has been acknowledged
        with ``task_done``.
        """

        def send(item):
            queue.put(item, block=True, timeout=self.TIMEOUT)

        pending = []
        for raw_row in self.execute_yql(yql_query_params):
            row = self._yt_row_decode(raw_row)
            if row:
                pending.append(row)

            if len(pending) == self.ITER_SIZE:
                send(pending)
                pending = []

        if pending:
            send(pending)

        # End of data signal
        send(None)
        # Make sure signal received, meaning last chunk was successfully processed
        queue.join()

    def execute_yql(self, yql_query_params: dict) -> Iterable[tuple]:
        """Run ``YQL`` with ``yql_query_params`` bound and return its row iterator.

        Raises:
            YqlOperationError: if the query results have status ``"ERROR"``.
        """
        with YqlClient(token=self._yql_token) as yql_client:
            query_text = self.YQL.format(**self._yql_format_kwargs)
            request = yql_client.query(query_text, syntax_version=1)

            encoded_params = YqlParameterValueBuilder.build_json_map(yql_query_params)
            request.run(parameters=encoded_params)

            results = request.get_results()
            if results.status == "ERROR":
                raise YqlOperationError(results.text)

            return results.table.get_iterator()

    async def fetch_yql_query_params(self) -> dict:
        """Return the query parameters; override to supply them. Empty by default."""
        return {}

    @classmethod
    def _yt_row_decode(cls, row: tuple) -> Any:
        """Convert one raw row; a falsy result skips the row. Identity by default."""
        return row


class BaseReplicateYtTableTask(BaseYtTask):
    """Copy ``src_table`` from ``src_cluster`` to ``target_table`` on the target cluster.

    The copy is launched with ``run_remote_copy`` from a worker process, using
    a client for the target cluster.
    """

    def __init__(
        self,
        src_cluster: str,
        target_cluster: str,
        token: str,
        src_table: str,
        target_table: str,
    ):
        # The inherited make_yt_client() talks to the target cluster.
        super().__init__(target_cluster, token)
        self._src_cluster = src_cluster
        self._src_table = src_table
        self._target_table = target_table

    async def __call__(self) -> None:
        """Run :meth:`replicate` in a worker process and wait for it.

        Raises:
            Any exception raised by :meth:`replicate` in the worker process.
        """
        with ProcessPoolExecutor(max_workers=1) as executor:
            # The future must be awaited. Otherwise worker errors are never
            # retrieved, and leaving the ``with`` block (shutdown(wait=True))
            # blocks the event loop synchronously until the copy finishes.
            await asyncio.get_event_loop().run_in_executor(executor, self.replicate)

    def replicate(self) -> None:
        """Start the remote copy from the source cluster (runs in the worker)."""
        yt_client = self.make_yt_client()
        yt_client.run_remote_copy(
            source_table=self._src_table,
            destination_table=self._target_table,
            cluster_name=self._src_cluster,
        )


class QueueIterator:
    """Async iterator over chunks that a worker process puts into a queue.

    ``None`` in the queue marks the end of data. Every item taken from the
    queue is acknowledged with ``task_done`` (data chunks only once the
    consumer has resumed the generator), so a producer blocked in
    ``queue.join()`` is released only after the last chunk was handled.
    """

    def __init__(self, *, queue: Queue, fut: Future, queue_sleep_time: float):
        self.queue = queue
        self.fut = fut
        self.queue_sleep_time = queue_sleep_time

    async def _wait_for_data(self) -> None:
        """Poll until the queue is non-empty; fail if the producer has finished."""
        while self.queue.empty():
            if self.fut.done():
                failure = self.fut.exception()
                raise failure or Exception("Subprocess finished unexpectedly")
            await asyncio.sleep(self.queue_sleep_time)

    async def __call__(
        self,
    ) -> Any:
        while True:
            await self._wait_for_data()

            chunk = self.queue.get(block=False)
            reached_end = chunk is None
            if not reached_end:
                yield chunk
            self.queue.task_done()
            if reached_end:
                return


class DataGeneratorWriter:
    """Feed a ``QueueIterator`` into a consumer coroutine, then join the worker.

    ``fut`` is the executor future of the worker process that fills the queue
    behind ``queue_iterator``. Its outcome is checked after the consumer
    returns or fails.
    """

    def __init__(
        self,
        *,
        data_generator_writer: callable,
        queue_iterator: QueueIterator,
        fut: Future,
    ):
        # Coroutine function called with the async generator from queue_iterator().
        self._data_generator_writer = data_generator_writer
        self._queue_iterator = queue_iterator
        self._fut = fut

    async def __call__(self):
        """Run the consumer to completion and wait for the worker.

        Raises:
            The consumer's exception, or the worker's exception (which takes
            precedence when both failed, since it is raised from ``finally``).
        """
        try:
            await self._data_generator_writer(self._queue_iterator())
        except BaseException:
            # BaseException, not Exception: on task cancellation we must not
            # fall through to ``await self._fut`` below. The worker may be stuck
            # in queue.put()/queue.join() (e.g. read_from_yt) waiting for a
            # consumer that is gone, which would hang or mask the cancellation.
            self._fut.cancel()
            raise
        finally:
            if not self._fut.done():
                await self._fut

            if not self._fut.cancelled() and self._fut.exception():
                raise self._fut.exception()
