import datetime

import yt.wrapper as yt

from sqlalchemy import create_engine


def row_mapper_exact(row):
    """Identity row mapper: hand the row back exactly as it was received.

    Used as the default ``row_mapper`` of ``Dumper.dump`` so that rows are
    written without any transformation.

    :param row: A single row of the query result.
    :return: The very same ``row`` object, untouched.
    """
    unchanged = row
    return unchanged


class Dumper:
    """
    Dumps the result of a SQL query, executed through a SQLAlchemy engine,
    into a YT table.
    """

    def __init__(self, db_connection_string, yt_client, logger):
        """
        :param db_connection_string: Connection string passed to SQLAlchemy's
            `create_engine`.
        :param yt_client: YT client used for every YT operation.
        :param logger: Logger receiving progress (info) and query (debug)
            messages.
        """
        self.__db_engine = create_engine(db_connection_string)
        self.__yt_client = yt_client
        self.__logger = logger

    def dump(
        self,
        query,
        yt_table_path,
        yt_table_schema,
        ttl,
        yt_link_path=None,
        row_mapper=row_mapper_exact,
        batch=None
    ):
        """
        Dumps data retrieved from PostgreSQL DB by means of `query` to a
        YT-table with name `yt_table_path`.

        The table creation, the optional link creation and the data upload
        are all performed inside a single YT transaction.

        :param query: SQL query text to execute.
        :param yt_table_path: Path of the YT table to create (if it does not
            exist yet) and to write to.
        :param yt_table_schema: Schema set as the "schema" attribute when the
            table is created.
        :param datetime.timedelta ttl: Table Time-To-Live. The expiration
            time is computed as current UTC time plus `ttl` and is set on
            both the table and the link.
        :param yt_link_path: If given (non-empty), a link with this path
            pointing to the table is created (replacing an existing one).
        :param row_mapper: Query execution result (where rows are represented
            as dicts) is transformed by applying this function to each row.

        :param batch: If a table must be dumped by batches then this parameter
            MUST be a dict with keys "column", "size" and "start_value".

            The 'query' is extended by means of this parameter with the
            following statement:

                AND column > value ORDER BY column ASC LIMIT size

            where `value` is equal to the "start_value" when the query is
            executed for the first time and is substituted afterwards by the
            last value got from the column "column".

            Therefore the query MUST NOT have `ORDER` and `LIMIT` statements
            and the last statement MUST be a `WHERE` statement. Moreover, the
            column "column" MUST be present in the `SELECT` statement.

            Notes:
              * "start_value" is exclusive (the comparison is a strict `>`).
              * Rows sharing the last value of a batch with rows that did not
                fit into that batch are skipped by the strict comparison, so
                the column should hold unique values.
              * `value` is inserted into the SQL as plain text, without
                quoting, so it has to render as a valid SQL literal
                (e.g. a number).
              * Every batch is appended to the table.

        :raises KeyError: If `batch` is given but lacks a required key.
        """

        if batch is not None:
            # Fail before touching YT instead of in the middle of the dump.
            missing = [key for key in ("column", "size", "start_value") if key not in batch]
            if missing:
                raise KeyError("batch is missing required key(s): %s" % ", ".join(missing))

        # Naive UTC timestamp rendered with the default `str` format
        # (same text as the former `format(...)` call).
        expiration_time = str(datetime.datetime.utcnow() + ttl)

        with yt.Transaction(client=self.__yt_client):
            self.__create_table(yt_table_path, yt_table_schema, expiration_time)
            self.__create_link(yt_table_path, yt_link_path, expiration_time)
            self.__dump(query, yt_table_path, row_mapper, batch)

    @staticmethod
    def __row_to_dict(row):
        """Converts a DB result row to a plain dict (column name -> value)."""
        return dict(row.items())

    def __create_table(self, yt_table_path, yt_table_schema, expiration_time):
        """
        Creates the YT table (with missing parent nodes) carrying the given
        schema and expiration time. An already existing table is tolerated
        (`ignore_existing=True`).
        """
        self.__logger.info("Create YT table '%s'.", yt_table_path)
        self.__yt_client.create(
            "table",
            yt_table_path,
            ignore_existing=True,
            recursive=True,
            attributes={
                "schema": yt_table_schema,
                "expiration_time": expiration_time
            }
        )

    def __create_link(self, yt_table_path, yt_link_path, expiration_time):
        """
        Creates (with `force=True`) a link `yt_link_path` pointing to the
        table, with the same expiration time. Does nothing when
        `yt_link_path` is empty or None.
        """
        if not yt_link_path:
            return

        self.__logger.info("Create link '%s'.", yt_link_path)
        self.__yt_client.link(
            yt_table_path,
            yt_link_path,
            force=True,
            attributes={"expiration_time": expiration_time})

    def __dump(self, query, yt_table_path, row_mapper, batch):
        """
        Executes the query and writes the mapped rows to the YT table: with
        one `write_table` call when `batch` is None, otherwise with one
        appending `write_table` call per batch.
        """
        self.__logger.info("Dump data to '%s'.", yt_table_path)

        if batch is None:
            self.__yt_client.write_table(yt_table_path, map(row_mapper, self.__get_data(query)))
        else:
            for records in self.__get_data_by_batches(query, batch):
                self.__yt_client.write_table(yt.TablePath(yt_table_path, append=True), map(row_mapper, records))

        self.__logger.info("Data has been dumped successfully.")

    def __get_data(self, query):
        """Executes `query` as is and returns its rows converted to dicts."""
        self.__logger.info("Get data from DB.")
        self.__logger.debug("Query:\n%s", query)

        rows = self.__db_engine.execute(query)
        self.__logger.info("Got %d rows.", rows.rowcount)

        return map(self.__row_to_dict, rows)

    def __get_data_by_batches(self, query, batch):
        """
        Generator executing `query` batch by batch (see `dump` for the
        `batch` contract) and yielding the rows of each non-empty batch
        converted to dicts. Stops at the first empty batch.
        """
        self.__logger.info("Get data from DB.")

        column = batch["column"]
        last_value = batch["start_value"]
        while True:
            # Only the appended clause is interpolated: the caller's SQL is
            # inserted verbatim, so literal '{', '}' or '%' characters in it
            # cannot break the rendering. The boundary value is rendered
            # exactly as `str.format` renders it.
            batch_query = "%s AND %s > %s ORDER BY %s ASC LIMIT %d" % (
                query, column, "{}".format(last_value), column, batch["size"])
            self.__logger.debug("Query:\n%s", batch_query)

            rows = self.__db_engine.execute(batch_query).fetchall()
            self.__logger.info("Got %d rows.", len(rows))

            if not rows:
                return

            # The next batch starts right after the last row of this one.
            last_value = rows[-1][column]

            yield map(self.__row_to_dict, rows)
