#! /usr/bin/python
"""
Copies MySQL databases between several locations
Has configurable settings to copy just several tables/databases
"""

import logging
import subprocess
import argparse
import yaconfig
import MySQLdb


class MySQL:
    """ Minimal MySQLdb wrapper used to list databases and tables """
    def __init__(self, host, user, passwd):
        self.log = logging.getLogger('MySQL')

        self.log.debug('Connecting to %s', host)
        self.conn = MySQLdb.connect(host, user, passwd)

    def run(self, query):
        """ Runs the query and returns all result rows

        The cursor is closed afterwards, even when the query fails.
        """
        cur = self.conn.cursor()
        try:
            cur.execute(query)
            return cur.fetchall()
        finally:
            cur.close()

    @staticmethod
    def _quote_identifier(name):
        """ Return ``name`` as a backtick-quoted MySQL identifier """
        # Inside a backtick-quoted identifier a literal backtick is doubled.
        return '`%s`' % name.replace('`', '``')

    def databases(self):
        """ Return the database list (first column of SHOW DATABASES) """
        return [x[0] for x in self.run("SHOW DATABASES")]

    def tables(self, dbname):
        """ Return the table list of database ``dbname`` """
        query = "SHOW TABLES IN %s" % self._quote_identifier(dbname)
        return [x[0] for x in self.run(query)]


class Copier:
    """ Copies the configured databases/tables from source to destination

    The command line takes one positional argument, the config file, which is
    loaded with ``yaconfig.load_config``.  Config attributes read here:
    ``source`` and ``destination`` (each with ``host``, ``user``, ``passwd``),
    ``include``, ``exclude``, ``retries`` and ``post_action``.
    """

    # Filter placed between mysqldump and mysql: removes the text
    # "ROW_FORMAT=COMPACT" from the dump before it is loaded.
    row_format_filter = ['sed', 's/ROW_FORMAT=COMPACT//']

    def __init__(self):
        parser = argparse.ArgumentParser("mysql-dbcopier")
        parser.add_argument('config', help='The config file to use')
        args = parser.parse_args()

        self.conf = yaconfig.load_config(args.config)
        self.log = logging.getLogger('Copier')

    def determine_sources(self):
        """ Compare config with real data and generate the to-be-copied list

        Returns a dict mapping database name to the tables to copy.
        ``include`` is either '*' (every database on the source) or a mapping
        of database name to '*' (every table) or an explicit table list.
        ``exclude`` maps database name to '*' (skip that database entirely)
        or a list of tables to skip.
        """
        src = MySQL(self.conf.source.host,
                    self.conf.source.user,
                    self.conf.source.passwd)
        try:
            sources = {}

            # process includes
            if self.conf.include == '*':
                for dbname in src.databases():
                    sources[dbname] = src.tables(dbname)
            else:
                for dbname in self.conf.include:
                    if self.conf.include[dbname] == '*':
                        sources[dbname] = src.tables(dbname)
                    else:
                        sources[dbname] = self.conf.include[dbname]
        finally:
            # The connection is only needed for listing; the copy itself is
            # done by external mysqldump/mysql processes.
            src.conn.close()

        return self._apply_excludes(sources)

    def _apply_excludes(self, sources):
        """ Remove ``conf.exclude`` entries from ``sources`` in place; return it """
        for dbname in self.conf.exclude or {}:
            if dbname not in sources:
                # Nothing to exclude from a database that is not being
                # copied; warn so that typos in the config are visible.
                self.log.warning('Exclude for %s ignored: database not selected',
                                 dbname)
                continue
            excluded = self.conf.exclude[dbname]
            if excluded == '*':
                del sources[dbname]
            else:
                sources[dbname] = sorted(set(sources[dbname]) - set(excluded))
        return sources

    def _pipeline_commands(self, dbname, table):
        """ Return ``(dump_cmd, load_cmd)`` argument lists for one table

        No shell is involved, so quotes or other shell metacharacters in
        names and passwords are passed to the programs verbatim.
        """
        def as_args(*parts):
            # Config values (e.g. a numeric password) may not be strings;
            # convert them the way string formatting would.
            return ['%s' % (part,) for part in parts]

        source = self.conf.source
        dest = self.conf.destination
        dump_cmd = as_args('mysqldump', '--default-character-set=binary',
                           '--skip-lock-tables', '--no-create-db',
                           '-u', source.user, '-p%s' % (source.passwd,),
                           '-h', source.host, dbname, table)
        load_cmd = as_args('mysql',
                           '-u', dest.user, '-p%s' % (dest.passwd,),
                           '-h', dest.host, dbname)
        return dump_cmd, load_cmd

    def copy_table(self, dbname, table):
        """ Copies one table from source to destination

        Runs ``mysqldump | sed | mysql`` as a pipeline of separate processes.
        Raises subprocess.CalledProcessError if ANY stage exits non-zero (a
        plain shell pipeline reports only mysql's status, so a failed dump
        would look like success), and OSError if a program cannot be started.
        """
        dump_cmd, load_cmd = self._pipeline_commands(dbname, table)
        stages = []
        try:
            dump = subprocess.Popen(dump_cmd, stdout=subprocess.PIPE)
            stages.append((dump, dump_cmd[0]))
            fix = subprocess.Popen(self.row_format_filter, stdin=dump.stdout,
                                   stdout=subprocess.PIPE)
            stages.append((fix, self.row_format_filter[0]))
            # Close the parent's copy of each pipe so an early exit of a
            # later stage delivers SIGPIPE upstream instead of hanging it.
            dump.stdout.close()
            load = subprocess.Popen(load_cmd, stdin=fix.stdout)
            stages.append((load, load_cmd[0]))
            fix.stdout.close()
        except BaseException:
            # Don't leave already-started stages running if a later one
            # could not be launched.
            for proc, _ in stages:
                proc.kill()
                proc.wait()
            raise

        for proc, _ in stages:
            proc.wait()
        # Check downstream first: when mysql fails, upstream stages usually
        # die of SIGPIPE as a consequence, so mysql's status is the cause.
        for proc, program in reversed(stages):
            if proc.returncode != 0:
                # Only the program name goes into the exception so passwords
                # on the command line never reach logs or tracebacks.
                raise subprocess.CalledProcessError(proc.returncode, program)

    def copy_db(self, dbname, tables):
        """ Copies one database from source to destination

        Tables are copied in sorted order.  Each table is attempted up to
        ``conf.retries + 1`` times; after the last failure the error is
        re-raised.
        """
        tables_count = len(tables)

        for position, table in enumerate(sorted(tables), 1):
            self.log.info('Copying %s: %s (%s/%s)', dbname, table, position, tables_count)

            retries = self.conf.retries
            while True:
                try:
                    self.copy_table(dbname, table)
                    break
                except Exception as err:           # pylint: disable=broad-except
                    # KeyboardInterrupt / SystemExit are not Exception
                    # subclasses, so they propagate without retrying.
                    if retries > 0:
                        retries -= 1
                        self.log.warning('Error running command, retrying: %s', err)
                    else:
                        self.log.fatal("Could not copy %s.%s with %s tries",
                                       dbname,
                                       table,
                                       self.conf.retries + 1)
                        raise

    def run(self):
        """ Run the copier logic

        Copies every selected database, then runs ``conf.post_action`` if set.
        """
        # determine possible sources
        sources = self.determine_sources()
        for dbname, tables in sources.items():
            self.copy_db(dbname, tables)

        if self.conf.post_action:
            # post_action comes from the config file and is run through the
            # shell, so it may use pipes or redirection.
            subprocess.check_call(self.conf.post_action, shell=True)


if __name__ == '__main__':
    # Without a configured handler the root logger drops INFO records, so the
    # per-table "Copying ..." progress messages would never be shown.
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(name)s %(levelname)s %(message)s')
    Copier().run()
