import contextlib
import os

import cachetools
import ydb

from crypta.lib.python import (
    templater,
    time_utils,
)
from crypta.ltp.viewer.lib.ydb import schema
from crypta.ltp.viewer.lib.structs import status
from crypta.ltp.viewer.lib.structs.filter import Filter
from crypta.ltp.viewer.lib.structs.page import Page


class Client:
    def __init__(self, endpoint, database, auth_token):
        driver_config = ydb.DriverConfig(
            endpoint,
            database,
            auth_token=auth_token,
        )
        self.driver = ydb.Driver(driver_config)
        self.driver.wait(timeout=30)

        self.session_pool = ydb.SessionPool(self.driver)
        self.root = database

        self.size_cache = cachetools.TTLCache(maxsize=1000, ttl=300)

        self.create_tables()

    def path(self, *path):
        return os.path.join(self.root, *path)

    def create_tables(self):
        with self.get_session() as session:
            session.create_table(self.path("id-to-history-id"), schema.ID_TO_HISTORY_ID_SCHEMA)
            session.create_table(self.path("user-queries"), schema.USER_QUERIES_SCHEMA)

    def prepare_query(self, session, query_file, vars=None):
        query = templater.render_resource(os.path.join("/queries", query_file), strict=True, vars=vars)
        return session.prepare("""
            PRAGMA TablePathPrefix("{}");
            {}
        """.format(self.root, query))

    def prepare_filtered_history_query(self, session, query_file, history_id):
        filter_template = templater.render_resource("/queries/filter_log.sql", strict=True, vars={
            "history_id": history_id
        })
        return self.prepare_query(session, query_file, {"filter_template": filter_template})

    @contextlib.contextmanager
    def get_session(self):
        with self.session_pool.checkout() as session:
            yield session

    @contextlib.contextmanager
    def get_transaction(self, session):
        tx = session.transaction(ydb.SerializableReadWrite())
        tx.begin()
        try:
            yield tx
            tx.commit()
        finally:
            tx.rollback()

    def add_graph(self, ids, history_id, tasks, schedule_limit):
        timestamp = int(time_utils.get_current_time())

        with self.get_session() as session:
            session.create_table(self.path(history_id, "log"), schema.LOG_SCHEMA)
            session.create_table(self.path(history_id, "progress"), schema.PROGRESS_SCHEMA)

            with self.get_transaction(session) as tx:
                tx.execute(
                    self.prepare_query(session, "add_history_id.sql"),
                    {
                        "$ids": ids,
                        "$history_id": history_id,
                        "$timestamp": timestamp,
                    },
                )

                tx.execute(
                    self.prepare_query(session, "init_progress.sql", {"history_id": history_id}),
                    {
                        "$tasks": tasks,
                        "$timestamp": timestamp,
                    },
                )

            with self.get_transaction(session) as tx:
                return self.schedule_tasks(session, tx, history_id, schedule_limit)

    def insert_chunk(self, date, log, records, id):
        with self.get_session() as session:
            with self.get_transaction(session) as tx:
                history_id = self.get_history_id(session, tx, id)

                tx.execute(
                    self.prepare_query(session, "add_records_to_log.sql", {"history_id": history_id}),
                    {
                        "$records": records,
                        "$log": log,
                    } | id.get_params(),
                )

                self.update_progress(session, tx, date, log, id, history_id, status.COMPLETED)

            with self.get_transaction(session) as tx:
                return self.schedule_tasks(session, tx, history_id, limit=1)

    def update_progress(self, session, tx, date, log, id, history_id, status):
        tx.execute(
            self.prepare_query(session, "update_progress.sql", {"history_id": history_id}),
            {
                "$date": date,
                "$log": log,
                "$complete_time": int(time_utils.get_current_time()),
                "$status": status,
            } | id.get_params(),
        )

    def get_history_id(self, session, tx, id):
        result_sets = tx.execute(self.prepare_query(session, "get_history_id.sql"), id.get_params())
        if result_sets[0].rows:
            return result_sets[0].rows[0].history_id
        else:
            return None

    def get_history(self, id, page, filter=None):
        filter = filter or Filter()
        with self.get_session() as session, self.get_transaction(session) as tx:
            history_id = self.get_history_id(session, tx, id)

            rows = []
            subpage = Page(offset=page.offset, limit=min(page.limit, 1000))
            while subpage.offset < page.limit + page.offset:
                new_rows = tx.execute(
                    self.prepare_filtered_history_query(session, "get_history.sql", history_id),
                    filter.get_params() | subpage.get_params()
                )[0].rows
                if not new_rows:
                    break
                rows += new_rows
                subpage = Page(
                    offset=subpage.offset + subpage.limit,
                    limit=min(subpage.limit, page.limit - (subpage.offset - page.offset)),
                )

            total = self.get_history_size(session, tx, history_id, filter)

        return rows, total

    def get_history_size(self, session, tx, history_id, filter):
        key = history_id, filter

        if key not in self.size_cache:
            query = self.prepare_filtered_history_query(session, "get_history_size.sql", history_id)

            result_sets = tx.execute(
                query,
                filter.get_params(),
            )
            self.size_cache[key] = result_sets[0].rows[0].total

        return self.size_cache[key]

    def save_query(self, owner, id_, from_date, to_date):
        timestamp = int(time_utils.get_current_time())

        with self.get_session() as session, self.get_transaction(session) as tx:
            tx.execute(
                self.prepare_query(session, "save_user_query.sql"),
                {"$owner": owner, "$timestamp": timestamp, "$from_date": from_date, "$to_date": to_date} | id_.get_params(),
            )

    def get_user_queries(self, owner):
        with self.get_session() as session, self.get_transaction(session) as tx:
            result_sets = tx.execute(
                self.prepare_query(session, "get_user_queries.sql"),
                {"$owner": owner},
            )

        return result_sets[0].rows

    def get_progress(self, id_, from_date=None, to_date=None):
        with self.get_session() as session, self.get_transaction(session) as tx:
            history_id = self.get_history_id(session, tx, id_)
            if not history_id:
                return None

            rows = []
            page = Page(offset=0, limit=1000)
            while True:
                new_rows = tx.execute(
                    self.prepare_query(session, "get_progress.sql", {"history_id": history_id}),
                    {"$from_date": from_date, "$to_date": to_date} | page.get_params(),
                )[0].rows
                if not new_rows:
                    break
                rows += new_rows
                page = Page(offset=page.offset + page.limit, limit=page.limit)

        return rows

    def fail_chunk(self, date, log, id):
        with self.get_session() as session, self.get_transaction(session) as tx:
            history_id = self.get_history_id(session, tx, id)
            self.update_progress(session, tx, date, log, id, history_id, status.FAILED)

    def schedule_tasks(self, session, tx, history_id, limit):
        return tx.execute(
            self.prepare_query(session, "get_next_task.sql", {"history_id": history_id}),
            {"$limit": limit},
        )[0].rows

    def get_expired_history_ids(self, ttl):
        with self.get_session() as session, self.get_transaction(session) as tx:
            result_sets = tx.execute(
                self.prepare_query(session, "get_expired_history_ids.sql"),
                {"$deadline": int(time_utils.get_current_time() - ttl.total_seconds())},
            )

        return [row.history_id for row in result_sets[0].rows]

    def expire_queries(self, ttl):
        with self.get_session() as session, self.get_transaction(session) as tx:
            tx.execute(
                self.prepare_query(session, "expire_queries.sql"),
                {"$deadline": int(time_utils.get_current_time() - ttl.total_seconds())},
            )

    def drop_history(self, history_id):
        with self.get_session() as session, self.get_transaction(session) as tx:
            session.drop_table(self.path(history_id, "log"))
            session.drop_table(self.path(history_id, "progress"))
            self.driver.scheme_client.remove_directory(self.path(history_id))

            tx.execute(
                self.prepare_query(session, "drop_history_id.sql"),
                {"$history_id": history_id},
            )
