import base64
import threading
import time
from contextlib import contextmanager
from functools import lru_cache
from typing import Iterator
from urllib.parse import quote

from clickhouse_driver import Client
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session

from load.projects.cloud.cloud_helper import kms, metadata_compute as compute
from load.projects.cloud.cloud_helper.metadata_compute import SaToken
from load.projects.cloud.loadtesting.config import ENV_CONFIG, PG_DB_PASS_ATTR, CH_PASS_ATTR, DEFAULT_CH_DB_PASS, \
    DEFAULT_PG_DB_PASS

# Declarative base class shared by the SQLAlchemy ORM models; model classes subclass it.
# NOTE(review): `sqlalchemy.ext.declarative.declarative_base` is deprecated since SQLAlchemy 1.4
# in favour of `sqlalchemy.orm.declarative_base` — consider switching the import.
BASE = declarative_base()


@lru_cache()
def get_db_password(db):
    """Return the KMS-decrypted password for ``db``; memoized per database name.

    The encrypted password is read from the instance attributes, where it is
    stored base64-encoded. If that attribute is missing or empty, the built-in
    default ciphertext for the database is decrypted instead.

    :param db: ``'clickhouse'`` or ``'postgresql'``.
    :return: whatever ``kms.decrypt`` returns for the selected ciphertext.
    :raises ValueError: if ``db`` is not a supported database name.
    """
    # Instance metadata is fetched before ``db`` is validated (same order as before).
    instance_attrs = compute.get_instance_attributes() or {}

    # db name -> (instance attribute holding the base64 ciphertext, fallback ciphertext)
    secret_sources = {
        'clickhouse': (CH_PASS_ATTR, DEFAULT_CH_DB_PASS),
        'postgresql': (PG_DB_PASS_ATTR, DEFAULT_PG_DB_PASS),
    }
    source = secret_sources.get(db)
    if source is None:
        raise ValueError(f'Unknown database: {db}')
    attr_name, fallback_ciphertext = source

    encoded_ciphertext = instance_attrs.get(attr_name)
    ciphertext = base64.b64decode(encoded_ciphertext) if encoded_ciphertext else fallback_ciphertext
    return kms.decrypt(ciphertext, SaToken.get(), ENV_CONFIG.KMS_AAD, ENV_CONFIG.KMS_KEY_ID)


class _ClickhouseClient(Client):
    """ClickHouse client whose ``with`` block does not close the connection.

    Instances are memoized by ``_get_clickhouse_client`` and handed out repeatedly,
    so leaving a ``with`` block must keep the client usable. The connection is only
    released when the object itself is garbage-collected.
    """

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Intentionally a no-op: the client is shared through a cache, so it must stay
        # connected for later callers. Returning None lets exceptions propagate.
        pass

    def __del__(self):
        # Best-effort cleanup. __del__ may run on a partially initialised instance
        # (if Client.__init__ raised) or during interpreter shutdown, when
        # disconnect() can fail. Nobody can catch an exception raised here, so swallow
        # it instead of emitting "Exception ignored in __del__" warnings.
        try:
            self.disconnect()
        except Exception:
            pass


def get_clickhouse_client():
    """Return the current thread's ClickHouse client, rebuilt every 600 seconds.

    The ClickHouse client is not thread-safe, so the calling thread's id is part of
    the cache key. The second key component changes every 600 s, which makes the
    cache build a new client once that interval has passed.
    """
    thread_key = threading.get_ident()
    ttl_bucket = int(time.time() / 600)
    return _get_clickhouse_client(thread_key, ttl_bucket)


@lru_cache()
def _get_clickhouse_client(thread_id, ttl_cache_dropper):
    """Create a TLS-enabled ClickHouse client, memoized on its two arguments.

    Neither argument is used inside the function. They exist only to form the
    ``lru_cache`` key, giving one client per (thread, TTL bucket) pair.

    NOTE(review): ``lru_cache()`` keeps up to 128 entries, so clients from expired
    TTL buckets remain referenced (and presumably connected) until evicted —
    confirm this is acceptable.
    """
    connection_options = dict(
        host=ENV_CONFIG.CLICKHOUSE_HOST,
        port=ENV_CONFIG.CLICKHOUSE_PORT,
        database=ENV_CONFIG.CLICKHOUSE_DBNAME,
        user=ENV_CONFIG.CLICKHOUSE_USER,
        password=get_db_password(db='clickhouse'),
        secure=True,
        ca_certs=ENV_CONFIG.DB_CERT_FILE,
        connect_timeout=1,
        settings={'joined_subquery_requires_alias': 0},
    )
    return _ClickhouseClient(**connection_options)


def postgres_connect(user, password, server, db, port):
    """Create a SQLAlchemy engine for PostgreSQL using the psycopg2 driver.

    The connection is verified against ``ENV_CONFIG.DB_CERT_FILE``
    (``sslmode=verify-full``), requires a read-write session and enables TCP
    keepalives.

    :param user: database user name.
    :param password: database password; may contain URL-reserved characters.
    :param server: database host, inserted into the URL unchanged.
    :param db: database name.
    :param port: database port.
    :return: a new SQLAlchemy engine.
    """
    # Credentials are percent-encoded. An unescaped '@', ':', '/', '%', '#' or '?'
    # (common in generated secrets) would otherwise break URL parsing. SQLAlchemy
    # percent-decodes these fields when parsing, so the driver still gets the raw values.
    # safe='' also escapes '/', and quote (unlike quote_plus) keeps spaces round-trippable.
    quoted_user = quote(user, safe='')
    quoted_password = quote(password, safe='')
    db_connect_string = f'postgresql+psycopg2://{quoted_user}:{quoted_password}@{server}:{port}/{db}'
    args = {
        'sslmode': 'verify-full',
        'sslrootcert': ENV_CONFIG.DB_CERT_FILE,
        'target_session_attrs': 'read-write',
        "keepalives": 1,
        "keepalives_idle": 10,
        "keepalives_interval": 5,
        "keepalives_count": 5,
    }
    postgres_engine = create_engine(db_connect_string, connect_args=args, logging_name='postgres')
    return postgres_engine


@lru_cache()
def _session_maker():
    """Build the PostgreSQL ``sessionmaker`` once and return the memoized instance.

    The engine uses the decrypted PostgreSQL password and the connection settings
    from ``ENV_CONFIG``. Sessions it produces are configured with ``autocommit=True``.
    """
    password = get_db_password(db='postgresql')
    pg_engine = postgres_connect(
        user=ENV_CONFIG.POSTGRESQL_USER,
        password=password,
        port=ENV_CONFIG.POSTGRESQL_PORT,
        db=ENV_CONFIG.POSTGRESQL_DBNAME,
        server=ENV_CONFIG.POSTGRESQL_HOST,
    )
    return sessionmaker(bind=pg_engine, autocommit=True)


@contextmanager
def get_pg_session() -> Iterator[Session]:
    """Yield a PostgreSQL ORM session that is closed when the ``with`` block exits.

    Sessions come from the memoized ``sessionmaker`` built by ``_session_maker``.

    Usage::

        with get_pg_session() as session:
            ...
    """
    # The generator yields a Session, so the undecorated function's return type is
    # Iterator[Session]; @contextmanager turns it into a context manager.
    with _session_maker()() as session:
        # NOTE(review): this resets a private ``_model_changes`` attribute on the session.
        # It looks like a compatibility shim for code expecting Flask-SQLAlchemy-style
        # change tracking — confirm with callers before removing.
        session._model_changes = {}
        yield session
