import logging
import time
from contextlib import contextmanager
from threading import Lock
from typing import ContextManager
from urllib.parse import quote

import travel.avia.subscriptions.app.settings.app as app_settings
import sqlalchemy.orm.session
from sqlalchemy import create_engine
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from travel.library.python.avia_mdb_replica_info.avia_mdb_replica_info import MdbAPI, POSTGRES_API_BASE_URL

# Connection-string template; filled with user, password, host(s), port and database name.
DSN = 'postgresql://{user}:{password}@{host}:{port}/{database}'
# Host template for a cluster's read-write endpoint, keyed by cluster id.
WRITEABLE_DB_HOST_TEMPLATE = "c-{cluster}.rw.db.yandex.net"

log = logging.getLogger(name=__name__)


class Database:
    """
    Holds a SQLAlchemy engine and session factory behind a lock so that both can be
    replaced at runtime via ``recreate_engine``.

    ``session``/``session_and_check`` and ``_connect`` recreate the engine automatically
    when they hit an ``OperationalError``.
    """

    # How many checked attempts session_and_check / _connect make before the final,
    # unchecked attempt, and how long (seconds) to pause after each failed attempt.
    MAX_RETRIES = 3
    RETRY_DELAY = .1

    def __init__(self, settings: app_settings.AppConfig = None, DSN: str = None, echo=True):
        """
        :param settings: application config used to build the DSN when ``DSN`` is not given.
        :param DSN: explicit connection string; takes precedence over ``settings``.
        :param echo: forwarded to ``create_engine`` (SQL statement logging).
        :raises ValueError: if neither ``DSN`` nor ``settings`` is provided.
        """
        if DSN:
            self.DSN = DSN
        elif settings is not None:
            self.DSN = self.format_dsn(settings)
        else:
            # Previously this surfaced as an opaque AttributeError from deep inside format_dsn.
            raise ValueError('Database requires either settings or DSN')
        self._lock = Lock()
        self._echo = echo
        self._engine = None
        self._Session = None
        self.recreate_engine()

    @property
    def engine(self) -> Engine:
        """The current engine; may be swapped out by ``recreate_engine``."""
        with self._lock:
            return self._engine

    @property
    def Session(self) -> sessionmaker:
        """The current session factory, bound to the current engine."""
        with self._lock:
            return self._Session

    def recreate_engine(self):
        """
        Build a new engine and session factory, swap them in under the lock, then
        dispose of the previous engine.

        Disposing closes the idle connections pooled by the old engine (for instance,
        connections to a host that is no longer the primary) instead of leaving them
        open until the old engine is garbage-collected. Per the SQLAlchemy docs,
        connections that are still checked out are not closed by ``Engine.dispose``.
        """
        log.info('(Re)Creating engine')
        new_engine = create_engine(
            self.DSN,
            pool_pre_ping=True,
            echo=self._echo,
            # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#psycopg2-connect-arguments
            server_side_cursors=False,
            connect_args={
                # https://www.postgresql.org/docs/10/libpq-connect.html
                # If this parameter is set to read-write, only a connection in which read-write transactions
                # are accepted by default is considered acceptable. The query SHOW transaction_read_only will be sent
                # upon any successful connection; if it returns on, the connection will be closed.
                # If multiple hosts were specified in the connection string, any remaining servers will be tried
                # just as if the connection attempt had failed. The default value of this parameter, any,
                # regards all connections as acceptable.
                'target_session_attrs': 'read-write',
            },
        )
        new_session = sessionmaker(new_engine)
        with self._lock:
            old_engine = self._engine
            self._engine = new_engine
            self._Session = new_session
        if old_engine is not None:
            try:
                old_engine.dispose()
            except Exception:
                # Best effort: the new engine is already in place.
                log.warning('recreate_engine: failed to dispose old engine', exc_info=True)

    @classmethod
    def format_dsn(cls, settings: app_settings.AppConfig) -> str:
        """
        Build the connection string from ``settings``.

        Hosts come from ``resolve_db_host``; if that yields an empty string, the
        cluster's read-write host from ``get_writeable_db_host`` is used instead.

        The password is percent-encoded because SQLAlchemy URL-decodes the password
        component of the DSN; otherwise a literal ``%`` in the password would be
        misread as an escape sequence.
        """
        db_host = cls.resolve_db_host(settings) or cls.get_writeable_db_host(settings.db.cluster)
        return DSN.format(
            user=settings.db.user,
            # str() mirrors what str.format did with the raw value before encoding.
            password=quote(str(settings.db.password), safe=''),
            host=db_host,
            port=settings.db.port,
            database=settings.db.database,
        )

    @staticmethod
    def resolve_db_host(settings: app_settings.AppConfig) -> str:
        """
        Return the cluster's instance hostnames, comma-separated, as reported by the
        MDB API, or ``''`` if the lookup raises.
        """
        mdb_api = MdbAPI(
            api_base_url=POSTGRES_API_BASE_URL,
            oauth_token=settings.mdb_token,
        )
        try:
            cluster_info = mdb_api.get_cluster_info(cluster_id=settings.db.cluster)
        except Exception:
            # Still best-effort (callers may fall back to another host), but no longer silent.
            log.warning(
                'resolve_db_host: failed to get cluster info for cluster %s',
                settings.db.cluster,
                exc_info=True,
            )
            return ''
        return ','.join(instance.hostname for instance in cluster_info.instances)

    @staticmethod
    def get_writeable_db_host(cluster: str) -> str:
        """Return the read-write host name for ``cluster`` built from WRITEABLE_DB_HOST_TEMPLATE."""
        return WRITEABLE_DB_HOST_TEMPLATE.format(cluster=cluster)

    @contextmanager
    def session(self, *args, **kwargs) -> ContextManager[sqlalchemy.orm.session.Session]:
        """
        Context manager that yields a session obtained from ``session_and_check`` and
        commits it when the block exits normally.

        On ``OperationalError`` the engine is recreated and the session rolled back; on
        any other exception the session is rolled back. The original exception is always
        re-raised; a failure of the rollback itself is logged instead of masking it.
        The session is not closed on exit.

        ``args``/``kwargs`` are forwarded to the session factory.
        """
        sess: sqlalchemy.orm.session.Session = None
        try:
            log.info('session: getting checked session %s %s', args, kwargs)
            sess = self.session_and_check(*args, **kwargs)
            log.info('session: yielding session')
            yield sess
            log.info('session: committing session')
            sess.commit()
        except OperationalError:
            log.info('session: operational error in session, recreating engine')
            self.recreate_engine()
            self._rollback_quietly(sess)
            raise
        except Exception:
            log.info('session: rolling back session')
            self._rollback_quietly(sess)
            raise

    @staticmethod
    def _rollback_quietly(sess) -> None:
        """Roll back ``sess`` (if any), logging rather than raising a rollback failure."""
        if sess is None:
            return
        try:
            sess.rollback()
        except Exception:
            log.warning('session: rollback failed', exc_info=True)

    @staticmethod
    def _close_quietly(sess) -> None:
        """Close ``sess`` (if any), logging rather than raising a close failure."""
        if sess is None:
            return
        try:
            sess.close()
        except Exception:
            log.warning('session_and_check: failed to close broken session', exc_info=True)

    @contextmanager
    def connection(self) -> ContextManager[Connection]:
        """
        Context manager that yields a connection from ``_connect`` and closes it on exit.
        """
        conn = None
        try:
            log.info('connection: getting connection')
            conn = self._connect()
            log.info('connection: yielding connection')
            yield conn
        finally:
            if conn is not None:
                log.info('connection: closing connection')
                conn.close()

    def session_and_check(self, *args, **kwargs) -> sqlalchemy.orm.session.Session:
        """
        Return a new session that answered ``SELECT 1``.

        Makes up to ``MAX_RETRIES`` checked attempts. After each ``OperationalError`` the
        failed session is closed, the engine is recreated and we sleep ``RETRY_DELAY``
        seconds. If every checked attempt fails, an unchecked session is returned.

        ``args``/``kwargs`` are forwarded to the session factory.
        """
        for _ in range(self.MAX_RETRIES):
            sess: sqlalchemy.orm.session.Session = None
            try:
                log.info('session_and_check: creating session')
                sess = self.Session(*args, **kwargs)
                log.info('session_and_check: checking session with bare select')
                sess.execute(sqlalchemy.text('SELECT 1'))
                return sess
            except OperationalError:
                log.info('session_and_check: operational error, recreating engine')
                # Release the failed session before its engine is replaced and disposed.
                self._close_quietly(sess)
                self.recreate_engine()
                time.sleep(self.RETRY_DELAY)

        log.info('session_and_check: last try, returning connection that might fail')
        return self.Session(*args, **kwargs)

    def _connect(self) -> Connection:
        """
        Return a connection from the current engine.

        Makes up to ``MAX_RETRIES`` attempts, recreating the engine and sleeping
        ``RETRY_DELAY`` seconds after each ``OperationalError``; one final attempt is made
        after that, and its error (if any) propagates.
        """
        for _ in range(self.MAX_RETRIES):
            try:
                log.info('_connect: getting engine connection')
                return self.engine.connect()
            except OperationalError:
                log.info('_connect: connection failed, recreating engine')
                self.recreate_engine()
                time.sleep(self.RETRY_DELAY)

        log.info('_connect: last try, getting connection or failing')
        return self.engine.connect()

    def ping(self) -> float:
        """
        Run ``SELECT 1`` on a fresh connection and return how long it took, in seconds.

        Uses ``time.perf_counter`` so wall-clock adjustments cannot skew the measurement.
        """
        with self.connection() as conn:
            started = time.perf_counter()
            conn.execute(sqlalchemy.text('SELECT 1')).close()
            elapsed = time.perf_counter() - started
        return elapsed
