import io
import os
import subprocess
import tarfile
import tempfile
from contextlib import closing

from mail.devpack.lib.helpers import get_ubuntu_codename
from mail.devpack.lib.db import if_server_not_start, if_server_start
import psycopg2
from library.python import resource

from . import helpers

# Bundled PostgreSQL binary archives, keyed by (distribution id, release codename)
# as determined in Postgresql.extract_tar.
PG_RESOURCES = {
    ('ubuntu', release): 'postgresql-{}.tgz'.format(release)
    for release in ('precise', 'trusty', 'xenial', 'bionic', 'focal')
}
# Archive used when the host's (dist, codename) pair has no entry in PG_RESOURCES.
PG_DEFAULT_RESOURCE = PG_RESOURCES[('ubuntu', 'trusty')]
# Version directory inside the archive: usr/lib/postgresql/<PG_VERSION>/bin.
PG_VERSION = '10'

# File and directory names, all relative to the server root directory.
PGDATA_DIR = 'pgdata'
PG_SERVER_LOG = 'postgresql.log'
PG_WRAPPER_LOG = 'pg-wrapper.log'
PG_CONF = 'postgresql.conf'
PG_HBA_CONF = 'pg_hba.conf'
PG_RECOVERY_CONF = 'recovery.conf'


class Postgresql(object):
    """Self-contained PostgreSQL server used for local development and tests.

    PostgreSQL binaries are unpacked from a bundled resource archive into
    ``self.root`` (see :meth:`extract_tar`). The data directory
    (``PGDATA_DIR``), the server config and the server log live under that
    same root. Client tools run with ``LD_LIBRARY_PATH`` extended to the
    unpacked libraries (see :meth:`_make_env`).

    :param port: TCP port the server listens on; formatted with ``%d``, so
        expected to be an integer.
    :param dbname: database this wrapper creates, drops, dumps and queries.
    :param working_dir: optional base directory forwarded to
        ``helpers.create_root``.
    :param role: replication role label, ``'master'`` by default. This class
        only stores it and exposes it through :attr:`role`.
    """

    def __init__(self, port, dbname, working_dir=None, role='master'):
        self.port = port
        self.dbname = dbname
        self.server_name = "postgres_%s_%d" % (dbname, port)
        # Ask the helper to prefer a RAM disk for the root when one is available.
        self.root = helpers.create_root(self.server_name, working_dir, use_ram_disk_if_available=True)
        self.logger = helpers.create_logger(__name__, self, PG_WRAPPER_LOG)
        self.logger.info("pg root is %s", self.root)
        self._role = role

    @property
    def role(self):
        """Replication role label given to the constructor (read-only)."""
        return self._role

    def _get_bin_dir(self):
        """Return the directory holding the unpacked PostgreSQL executables."""
        return os.path.join(self.root, 'usr', 'lib', 'postgresql', PG_VERSION, 'bin')

    def _get_bin(self, bin_file):
        """Return the absolute path of the unpacked executable ``bin_file``."""
        return os.path.join(self._get_bin_dir(), bin_file)

    def extract_tar(self):
        """Unpack the PostgreSQL archive matching this host into ``self.root``.

        The archive is looked up in ``PG_RESOURCES`` by the lower-cased
        (distribution id, codename) pair. If there is no match,
        ``PG_DEFAULT_RESOURCE`` is used.

        :raises RuntimeError: if the chosen archive is not a bundled resource.
        """
        import distro
        dist = distro.id()
        codename = distro.codename()
        if not codename:
            # distro could not report a codename: assume Ubuntu and ask the project helper.
            dist, codename = 'ubuntu', get_ubuntu_codename()

        self.logger.info("Linux distribution info: dist=%s codename=%s", dist, codename)
        # Guard against empty/None values so an unknown host falls back to the default archive.
        lookup_key = ((dist or '').lower(), (codename or '').lower())
        pg_resource_name = PG_RESOURCES.get(lookup_key) or PG_DEFAULT_RESOURCE
        self.logger.info("start extracting %s to %s", pg_resource_name, self.root)
        pg_res = resource.find(pg_resource_name)
        if not pg_res:
            raise RuntimeError('No postgresql resource found: {}'.format(pg_resource_name))
        # Close both the buffer and the TarFile deterministically.
        with io.BytesIO(pg_res) as tar_obj, tarfile.open(fileobj=tar_obj) as tar:
            tar.extractall(path=self.root)
        self.logger.info("extracting done")

    def initdb(self):
        """Create a fresh cluster in ``PGDATA_DIR`` via ``pg_ctl initdb``.

        The ``-o`` options request UTF8 encoding and trust auth for local
        connections.
        """
        self.logger.info("starting initdb")
        cmd = [
            self._get_bin("pg_ctl"),
            '-D', PGDATA_DIR,
            'initdb',
            # NOTE(review): the embedded double quotes are kept on purpose; presumably
            # pg_ctl forwards -o text through a shell command line - confirm before changing.
            '-o', '"-E UTF8"',
            '-o', '"--auth-local=trust"']
        # Last argument is presumably the working directory, so PGDATA_DIR resolves under self.root.
        helpers.run_subprocess("initdb", self.logger, cmd, self._make_env(), self.root)

    def add_recovery_config(self, replicate_user, master_port, dbname):
        """Write ``recovery.conf`` into the data directory.

        The file is rendered from the bundled ``pg/recovery.conf`` template.

        :param replicate_user: user name substituted for ``{user}``.
        :param master_port: master port substituted for ``{port}``.
        :param dbname: database name substituted for ``{dbname}``.
        """
        self.logger.info("put recovery config to pgdata")
        recovery_config = resource.find('pg/recovery.conf').decode('utf-8').format(
            port=master_port,
            user=replicate_user,
            dbname=dbname)

        recovery_config_path = os.path.join(self.root, PGDATA_DIR, PG_RECOVERY_CONF)
        helpers.write2file(recovery_config, recovery_config_path)

    def make_backup(self, replicate_user, master_port, dst_dir=PGDATA_DIR):
        """Take a base backup of the server on ``localhost:master_port``.

        Runs ``pg_basebackup`` with WAL streaming (``-X stream``) into
        ``dst_dir``. A relative ``dst_dir`` is presumably resolved against
        ``self.root``, the subprocess working directory.
        """
        self.logger.info("starting backup")
        cmd = [
            self._get_bin("pg_basebackup"),
            '-h', 'localhost',
            '-U', replicate_user,
            '-D', dst_dir,
            '-p', str(master_port),
            '-X', 'stream']
        helpers.run_subprocess("backup", self.logger, cmd, self._make_env(), self.root)

    def dump(self, partition_tables=None, ignore_tables=None):
        """Dump the database schema (``pg_dump -s``).

        :param partition_tables: optional parent table names. Every partition
            that ``show_partitions()`` reports for them is excluded from the
            dump. That function must exist in the database (presumably
            pg_partman's).
        :param ignore_tables: optional table names to exclude (``-T``). The
            caller's sequence is not modified.
        :returns: whatever ``helpers.run_subprocess`` returns for the pg_dump
            run (presumably its captured output).
        """
        # Work on a private copy: extending the argument in place would leak
        # partition names into the caller's list across calls.
        ignore_tables = list(ignore_tables or [])
        if partition_tables:
            ignore_tables.extend(row[0] for row in self.query(
                '''
                SELECT partition_schemaname || '.' || partition_tablename
                  FROM unnest( %(partition_tables)s ) t, show_partitions(t) s
                ''',
                partition_tables=partition_tables
            ))
        # ['-T', 'my_table', '-T', 'another_table']
        ignore_tables_args = [arg for table_name in ignore_tables for arg in ('-T', table_name)]
        self.logger.info("dump db")
        cmd = [self._get_bin("pg_dump"), '-s'] + ignore_tables_args + [self.dsn()]
        return helpers.run_subprocess("dump", self.logger, cmd, self._make_env(), self.root)

    def _make_env(self):
        """Return a copy of ``os.environ`` adjusted for running the unpacked binaries.

        The unpacked ``usr/lib`` and ``usr/lib/x86_64-linux-gnu`` directories
        are appended to ``LD_LIBRARY_PATH``. ``LANG`` and ``LC_MESSAGES`` are
        forced to ``en_US.UTF-8``.
        """
        lib_dir = os.path.join(self.root, 'usr', 'lib')
        lib_linux_dir = os.path.join(lib_dir, 'x86_64-linux-gnu')
        paths = lib_dir + ":" + lib_linux_dir
        env = os.environ.copy()
        if 'LD_LIBRARY_PATH' in env:
            env['LD_LIBRARY_PATH'] += ':' + paths
        else:
            env['LD_LIBRARY_PATH'] = paths
        env['LANG'] = 'en_US.UTF-8'
        env['LC_MESSAGES'] = 'en_US.UTF-8'
        return env

    def _check_run_server(self):
        """Return True if ``postmaster.pid`` exists in the data directory."""
        path = os.path.join(self.root, PGDATA_DIR, 'postmaster.pid')
        return os.path.exists(path)

    @if_server_not_start
    def start(self):
        """Write the server configs and start PostgreSQL with ``pg_ctl start -w``.

        ``postgresql.conf`` is rendered from the ``pg/pg.conf`` template, using
        ``self.port`` and the system temp dir as the socket directory, and is
        written to ``self.root``. ``pg_hba.conf`` is copied into the data
        directory. The server log goes to ``PG_SERVER_LOG``, and the wait
        timeout is 600 seconds.

        :raises RuntimeError: if ``postmaster.pid`` is absent after
            ``pg_ctl start`` returns.
        """
        self.logger.info("starting postgresql server on port %d", self.port)
        config = resource.find('pg/pg.conf').decode('utf-8').format(
            port=self.port,
            socket_dir=tempfile.gettempdir()
        )
        self.logger.info("pg.conf:\n%s", config)

        config_path = os.path.join(self.root, PG_CONF)
        helpers.write2file(config, config_path)

        hba_config_path = os.path.join(self.root, PGDATA_DIR, PG_HBA_CONF)
        hba_config = resource.find('pg/pg_hba.conf').decode('utf-8')
        helpers.write2file(hba_config, hba_config_path)

        cmd = [
            self._get_bin("pg_ctl"),
            '-D', PGDATA_DIR,
            'start', '-w',
            '-l', PG_SERVER_LOG,
            '-o', '--config-file=%s' % PG_CONF,
            '-t', '600',
        ]
        helpers.run_subprocess("pg_ctl start", self.logger, cmd, self._make_env(), self.root)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not self._check_run_server():
            raise RuntimeError('postgresql server did not start: {} is missing'.format(
                os.path.join(self.root, PGDATA_DIR, 'postmaster.pid')))

    @if_server_start
    def dropdb(self):
        """Drop ``self.dbname`` if it exists, connecting to ``localhost:self.port``."""
        self.logger.info("starting dropdb %s", self.dbname)
        cmd = [
            self._get_bin("dropdb"), self.dbname, '--if-exists',
            '-p', str(self.port),
            '-h', 'localhost',
        ]
        helpers.run_subprocess("dropdb", self.logger, cmd, self._make_env())
        self.logger.info("db %s dropped", self.dbname)

    # NOTE(review): unlike dropdb/stop this is not wrapped in @if_server_start - confirm intent.
    def createdb(self):
        """Create ``self.dbname``, connecting to ``localhost:self.port``."""
        self.logger.info("starting createdb %s", self.dbname)
        cmd = [
            self._get_bin("createdb"), self.dbname,
            '-p', str(self.port),
            '-h', 'localhost']
        helpers.run_subprocess("createdb", self.logger, cmd, self._make_env())
        self.logger.info("db %s created", self.dbname)

    @if_server_start
    def stop(self):
        """Stop the server via ``pg_ctl stop -w`` with a 600-second timeout."""
        self.logger.info("stopping postgresql server on port %d", self.port)
        cmd = [
            self._get_bin("pg_ctl"),
            '-D', PGDATA_DIR,
            'stop', '-w',
            '-t', '600',
            ]
        helpers.run_subprocess("pg_ctl stop", self.logger, cmd, self._make_env(), self.root)
        #  TODO: wait till postgre is down

    def dsn(self):
        """Return the libpq connection string for ``self.dbname`` on localhost."""
        return "host=localhost dbname={dbname} port={port}".format(dbname=self.dbname, port=self.port)

    def _do_exec(self, query, vars=None, to_fetch=True):
        """Run ``query`` on a fresh autocommit connection.

        The connection and cursor are closed afterwards.

        :param vars: mapping of query parameters. An empty mapping is passed
            as None, presumably so psycopg2 skips ``%``-parameter substitution.
        :returns: ``cursor.fetchall()`` if ``to_fetch``, otherwise None.
        :raises: any error from psycopg2, after logging the DSN and query.
        """
        self.logger.info("will execute %s", query)
        try:
            with closing(psycopg2.connect(self.dsn())) as conn:
                conn.set_session(autocommit=True)
                with closing(conn.cursor()) as cur:
                    cur.execute(query, vars=vars or None)
                    return cur.fetchall() if to_fetch else None
        except Exception:
            self.logger.error('cannot query, conninfo="%s" query=%s', self.dsn(), query)
            raise

    def query(self, query, **kwargs):
        """Execute ``query`` with keyword parameters and return all fetched rows."""
        return self._do_exec(query, vars=kwargs, to_fetch=True)

    def execute(self, query, **kwargs):
        """Execute ``query`` with keyword parameters, discarding any result."""
        return self._do_exec(query, vars=kwargs, to_fetch=False)

    def info(self):
        """Return a dict describing this instance: dbname, port, dsn and root."""
        return {
            "dbname": self.dbname,
            "port": self.port,
            "dsn": self.dsn(),
            "root": self.root,
        }

    def communicate(self, **kwargs):
        """Run an interactive ``psql`` session against this database.

        Each keyword argument becomes a ``--variable=NAME=VALUE`` option.
        The call blocks until psql exits. Ctrl+C is swallowed here so it only
        interrupts the current psql prompt, which mimics psql's own behaviour.
        """
        cmd = [self._get_bin('psql')]
        cmd.extend('--variable={k}={v}'.format(k=k, v=v) for k, v in kwargs.items())
        cmd.append(self.dsn())
        # Pass an argv list rather than a shell string. Variable values with
        # spaces or shell metacharacters, and a root path with spaces, then reach
        # psql verbatim instead of being re-interpreted by /bin/sh.
        popen = subprocess.Popen(
            cmd,
            cwd=self.root,
            env=self._make_env()
        )
        while True:
            try:
                popen.communicate()
            except KeyboardInterrupt:
                # Mimic psql-behaviour to flush the prompt on CTRL+C rather than quit
                continue
            break
