import logging
import os
import shutil

from mail.devpack.lib.components.base import AbstractComponent
from mail.devpack.lib.components.postgres import apply_resources_sql
from mail.devpack.lib.pg import Postgresql
from mail.devpack.lib.db import DbComponentMixin
from mail.devpack.lib.state import read_state

log = logging.getLogger(__name__)


class PostgresShard(object):
    """One shard: a set of Postgresql instances built from a shard description.

    Every entry of ``shard_info['dbs']`` becomes a ``Postgresql`` object.
    Statements and database-level operations are routed to the single
    instance whose ``role`` is ``'master'``; instances whose ``role`` is
    ``'replica'`` are backed up from it and started in :meth:`createdb`.
    """

    def __init__(self, shard_info, dbname, users, ddl_prefixes, root, replicate_user='replica'):
        """Remember the shard settings and build its Postgresql instances.

        :param shard_info: mapping with the keys ``'id'``, ``'name'`` and
            ``'dbs'``; each item of ``'dbs'`` provides ``'port'`` and ``'type'``.
        :param dbname: database name handed to every instance.
        :param users: user names created by :meth:`_create_users`.
        :param ddl_prefixes: prefixes passed one by one to
            ``apply_resources_sql`` in :meth:`_apply_migrations`.
        :param root: forwarded unchanged to ``Postgresql``
            (presumably the devpack root directory -- TODO confirm).
        :param replicate_user: name of the user created with the
            ``REPLICATION`` attribute and handed to replicas.
        """
        self._dbname = dbname
        self._users = users
        self._ddl_prefixes = ddl_prefixes
        self._replicate_user = replicate_user
        self._shard_id = shard_info['id']
        self._shard_name = shard_info['name']
        # NOTE(review): db_info['type'] is presumably what Postgresql exposes
        # as ``role`` ('master' / 'replica') -- confirm against Postgresql.
        self._dbs = [
            Postgresql(db_info['port'], self._dbname, root, db_info['type'])
            for db_info in shard_info['dbs']
        ]

    @property
    def master(self):
        """Return the only instance whose role is ``'master'``.

        An ``AssertionError`` is raised when the shard does not contain
        exactly one such instance.
        """
        masters = [db for db in self._dbs if db.role == 'master']
        assert len(masters) == 1, 'Postgres shard must have one master'
        return masters[0]

    @property
    def shard_id(self):
        """The ``'id'`` value of the shard description."""
        return self._shard_id

    @property
    def shard_name(self):
        """The ``'name'`` value of the shard description."""
        return self._shard_name

    def dropdb(self):
        """Ask the master instance to drop the database."""
        self.master.dropdb()

    def createdb(self):
        """Create the database on the master and bring up every replica.

        After the master created the database and the replication user
        exists, each replica makes a backup from the master's port, gets
        its recovery config and is started.
        """
        self.master.createdb()
        self._create_replica_user()
        for instance in self._dbs:
            if instance.role != 'replica':
                continue
            master_port = self.master.port
            instance.make_backup(self._replicate_user, master_port)
            instance.add_recovery_config(self._replicate_user, master_port, self._dbname)
            instance.start()

    def init_root(self):
        """Call ``extract_tar`` on every instance, then ``initdb`` on the master."""
        for instance in self._dbs:
            instance.extract_tar()
        self.master.initdb()

    def start(self):
        """Start the master instance (replicas are started in :meth:`createdb`)."""
        self.master.start()

    def stop(self):
        """Stop every instance of the shard, in the order they were configured."""
        for instance in self._dbs:
            instance.stop()

    def execute(self, query, **kwargs):
        """Forward ``query`` and ``kwargs`` to the master's ``execute``; return its result."""
        master = self.master
        return master.execute(query, **kwargs)

    def query(self, query, **kwargs):
        """Forward ``query`` and ``kwargs`` to the master's ``query``; return its result."""
        master = self.master
        return master.query(query, **kwargs)

    def purge(self):
        """Delete the root directory of every instance that has one on disk.

        Removal is best effort: errors raised while deleting are ignored.
        """
        for instance in self._dbs:
            data_root = instance.root
            if not os.path.exists(data_root):
                continue
            log.info("%s %s: deleting data ...", self._dbname, instance.role)
            shutil.rmtree(data_root, ignore_errors=True)

    def _create_replica_user(self):
        """Create the replication user on the master."""
        statement = 'CREATE USER %s REPLICATION LOGIN' % self._replicate_user
        self.execute(statement)

    def _create_users(self):
        """Create the ``root`` superuser, then drop and re-create each configured user."""
        self.execute('create user root superuser createdb inherit login;')
        for user in self._users:
            for template in ('drop user if exists %s', 'create user %s'):
                self.execute(template % user)

    def _apply_migrations(self):
        """Run ``apply_resources_sql`` against the master for every DDL prefix, in order."""
        for prefix in self._ddl_prefixes:
            apply_resources_sql(self.master, prefix)


class PostgresCluster(AbstractComponent, DbComponentMixin):
    """Component made of one or more :class:`PostgresShard` objects.

    Shards are built from ``config[self.NAME]['shards']``.  Lifecycle
    methods (``init_root``, ``start``, ``stop``, ``purge``,
    ``prepare_data``) act on every shard; the single-database helpers
    (``port``, ``dsn``, ``query``, ``execute``, ``communicate``) go to the
    master of the first shard.
    """

    DEPS = []

    @classmethod
    def gen_config(cls, port_generator, config=None):
        """Placeholder that always raises; concrete components must override it."""
        raise Exception("Please override this method for your component")

    def __init__(self, config, dbname, users, ddl_prefixes):
        """Load the component state and build the shards.

        :param config: configuration object; ``config[self.NAME]['shards']``
            lists the shard descriptions and ``config.root`` is handed to
            every shard.
        :param dbname: database name; also the value of the ``NAME`` property.
        :param users: user names created in each shard by ``prepare_data``.
        :param ddl_prefixes: DDL prefixes handed to every shard.
        """
        # Must be assigned first: the NAME property returns self._dbname and
        # NAME is read right below (read_state call and config lookup).
        self._dbname = dbname  # type: str
        self.__config = config
        self.__state = read_state(config, self.NAME)
        self.__shards = []  # type: list[PostgresShard]
        for shard_info in config[self.NAME]['shards']:
            self.__shards.append(PostgresShard(shard_info, self._dbname, users, ddl_prefixes, config.root))
        self.users = users

    @property
    def config(self):
        """The configuration object given to the constructor."""
        return self.__config

    @property
    def NAME(self):
        """Component name; equal to the database name given to the constructor."""
        return self._dbname

    @property
    def state(self):
        """The value returned by ``read_state`` at construction time."""
        return self.__state

    @property
    def shard(self):
        """The first configured shard (``IndexError`` if there are none)."""
        return self.__shards[0]

    @property
    def shards(self):
        """The list of all shards, in configuration order."""
        return self.__shards

    def shard_by_id(self, shard_id):
        """Return the first shard whose ``shard_id`` equals ``shard_id``.

        Raises ``StopIteration`` when no shard matches.
        """
        return next(sh for sh in self.__shards if sh.shard_id == shard_id)

    def shard_by_name(self, shard_name):
        """Return the first shard whose ``shard_name`` equals ``shard_name``.

        Raises ``StopIteration`` when no shard matches.
        """
        return next(sh for sh in self.__shards if sh.shard_name == shard_name)

    def port(self):
        """Return the port of the first shard's master."""
        return self.shard.master.port

    def query(self, query, **kwargs):
        """Forward ``query`` and ``kwargs`` to the first shard's master ``query``."""
        return self.shard.master.query(query, **kwargs)

    def execute(self, query, **kwargs):
        """Forward ``query`` and ``kwargs`` to the first shard's master ``execute``."""
        return self.shard.master.execute(query, **kwargs)

    def init_root(self):
        """Call ``init_root`` on every shard."""
        for shard in self.__shards:
            shard.init_root()

    def communicate(self):
        """Return the result of ``communicate()`` on the first shard's master."""
        return self.shard.master.communicate()

    def start(self):
        """Call ``start`` on every shard."""
        for shard in self.__shards:
            shard.start()

    def dsn(self):
        """Return the result of ``dsn()`` on the first shard's master."""
        return self.shard.master.dsn()

    def stop(self):
        """Call ``stop`` on every shard."""
        for shard in self.__shards:
            shard.stop()

    def purge(self):
        """Call ``purge`` on every shard."""
        for shard in self.__shards:
            shard.purge()

    def is_multiroot(self):
        """Always ``True`` for this component."""
        return True

    def prepare_data(self):
        """Recreate the database in every shard.

        Per shard, in order: drop the database, create it again, create
        the users, then apply the migrations.
        """
        for shard in self.__shards:
            shard.dropdb()
            shard.createdb()
            shard._create_users()
            shard._apply_migrations()
