#!/usr/bin/env python
import ConfigParser
import logging
import os
import re
import shutil
import signal
import subprocess
import sys
import time
from argparse import ArgumentParser
from datetime import datetime, timedelta

import boto3


class MdsS3:
    """Small wrapper around a boto3 S3 client that is bound to one bucket.

    The client talks to an explicit endpoint with ``use_ssl=False`` and
    ``verify=False``.
    """

    def __init__(self, url, bucket_name, access_key=None, secret_key=None):
        """Build the S3 client and remember the bucket.

        :param url: endpoint passed to boto3 as ``endpoint_url``.
        :param bucket_name: bucket used by every later call.
        :param access_key: passed through as ``aws_access_key_id``.
        :param secret_key: passed through as ``aws_secret_access_key``.
        """
        self.url = url
        self.bucket_name = bucket_name
        client_options = {
            'endpoint_url': url,
            'aws_access_key_id': access_key,
            'aws_secret_access_key': secret_key,
            'use_ssl': False,
            'verify': False,
        }
        self.client = boto3.client('s3', **client_options)

    def list(self, prefix=''):
        """Return the raw ``list_objects`` response for keys starting with ``prefix``."""
        request = dict(Bucket=self.bucket_name, Prefix=prefix)
        return self.client.list_objects(**request)

    def download(self, key, f_name):
        """Copy object ``key`` of the bound bucket into local file ``f_name``."""
        bucket = self.bucket_name
        self.client.download_file(bucket, key, f_name)


class CorruptedDataException(Exception):
    """Raised by main() when a restored table fails the last_update freshness check."""


class BackupNotFound(Exception):
    """Raised by get_backup_name() when the arguments do not select an existing backup."""


def get_run_parameters():
    """Parse the restore tool's command-line arguments from ``sys.argv``.

    The backup to restore is chosen with exactly one of ``--last``,
    ``--name``, or ``--backup-number`` (see get_backup_name()).
    ``--backup-number`` is therefore optional; previously it was required,
    which made ``--name`` unusable. ``--backup-date`` stays required because
    main() always uses it for the freshness check.

    :returns: the parsed ``argparse.Namespace``.
    """
    parser = ArgumentParser(description='Mysqldump restore tool')

    parser.add_argument('--url', dest='url', help="s3 server url", required=True)
    parser.add_argument('--bucket-name', dest='bucket_name', help="s3 bucket name", required=True)
    parser.add_argument('--last', dest='last', action='store_true', help='Take any yesterday backup')
    parser.add_argument('--name', dest='name', help="full name of mds file like '%%Y%%m%%d....tar.xz'")
    parser.add_argument('--backup-date', dest='backup_date',
                        help="backup date for check in format like '%%Y%%m%%d'", required=True)
    # Optional: --last and --name are alternative selectors. type=int turns a
    # non-numeric value into a clean argparse error instead of a later ValueError.
    parser.add_argument('--backup-number', dest='backup_number', type=int,
                        help="number of the backup")
    parser.add_argument('--keep-files', dest='keep_files', action='store_true',
                        help='Don\'t delete backup files after restore', default=False)

    # '--tmp-dir' is an alias matching the other dashed options; '--tmp_dir' keeps working.
    parser.add_argument('--tmp_dir', '--tmp-dir', dest='tmp_dir',
                        help='tmp dir for backup download and apply log', default='/tmp')
    parser.add_argument('--cnf-file', dest='cnf_file', help='location of mysql config', default='/etc/mysql/my.cnf')
    return parser.parse_args()


def get_backup_name(args):
    """Resolve the S3 object key of the backup to restore.

    Selection precedence:
      1. ``--last``: the first key listed under yesterday's date prefix.
      2. ``--backup-number`` with ``--backup-date``: the N-th (1-based) key
         listed under that date prefix.
      3. ``--name``: ``mysql-mdb/<name>``, used verbatim.

    :param args: namespace produced by get_run_parameters().
    :returns: the object key inside ``args.bucket_name``.
    :raises BackupNotFound: if no selector is given, if the date prefix lists
        no objects, or if the backup number is outside ``1..len(listing)``.
    """
    if not args.name and not args.last and not (args.backup_date and args.backup_number):
        raise BackupNotFound

    if args.last:
        date = (datetime.today() - timedelta(days=1)).strftime('%Y%m%d')
        keys = _list_backup_keys(args, date)
        if not keys:
            raise BackupNotFound('no backups under mysql-mdb/{}'.format(date))
        return keys[0]

    if args.backup_number and args.backup_date:
        keys = _list_backup_keys(args, args.backup_date)
        backup_number = int(args.backup_number)
        # Lower bound matters: 0 or a negative number would otherwise index
        # from the end of the listing and silently pick the wrong backup.
        if 1 <= backup_number <= len(keys):
            return keys[backup_number - 1]
        raise BackupNotFound('backup #{} not found under mysql-mdb/{} ({} available)'.format(
            backup_number, args.backup_date, len(keys)))

    return 'mysql-mdb/{}'.format(args.name)


def _list_backup_keys(args, date):
    """Return the object keys listed under ``mysql-mdb/<date>`` (empty list if none)."""
    mds = MdsS3(url=args.url, bucket_name=args.bucket_name)
    listing = mds.list('mysql-mdb/{}'.format(date))
    # list_objects may omit 'Contents' entirely when nothing matches the prefix.
    return [obj['Key'] for obj in listing.get('Contents', [])]


def load_table_sql(dump_path, table_name, sql_path):
    """Extract one table's section from a gzip'd mysqldump into ``sql_path``.

    zcat streams the dump into sed. sed prints from the
    "Table structure for table `<table>`" header onward and quits on the
    table's "ALTER TABLE `<table>` ENABLE KEYS" line. In mysqldump output that
    line presumably precedes the table's own ``UNLOCK TABLES``, which sed
    therefore never prints, so the statement is appended afterwards.
    Quitting early also avoids decompressing the rest of the dump.

    All paths are shell-quoted, so directories containing spaces work.

    :param dump_path: compressed dump file (read with zcat).
    :param table_name: table to extract; it is embedded in sed regexes.
    :param sql_path: output file; truncated, then written.
    """
    try:
        from shlex import quote
    except ImportError:  # Python 2
        from pipes import quote

    sed_script = ("/Table structure for table `{0}`/,/UNLOCK TABLES/p; "
                  "/ALTER TABLE `{0}` ENABLE KEYS/q").format(table_name)
    command = 'zcat {} | sed -n {} > {}'.format(
        quote(dump_path), quote(sed_script), quote(sql_path))
    subprocess.call(
        command,
        shell=True,
        # Reset SIGPIPE to its default action in the child shell. Presumably
        # this lets zcat simply terminate when sed quits early, rather than
        # reporting a broken-pipe error.
        preexec_fn=lambda: signal.signal(signal.SIGPIPE, signal.SIG_DFL)
    )
    with open(sql_path, 'a') as sql_file:
        sql_file.write(' UNLOCK TABLES;\n')


def main():
    """Restore selected tables of a mysqldump backup from S3 and sanity-check them.

    Steps:
      1. Download the selected backup into a timestamped directory under
         ``--tmp_dir``.
      2. Stop and start the local ``mysql`` service, using the datadir read
         from ``--cnf-file``.
      3. Load the time zone tables.
      4. Extract the ``banner``, ``campaign`` and ``zone_template`` tables from
         the dump into a freshly recreated ``adfox`` database.
      5. Check that ``banner`` and ``campaign`` have ``max(last_update)`` on or
         after ``--backup-date``.

    Unless ``--keep-files`` is given, the temporary directory is removed
    afterwards, even on failure.

    :raises BackupNotFound: if the arguments do not select an existing backup.
    :raises CorruptedDataException: if a validated table is empty or stale.
    :raises subprocess.CalledProcessError: if a validation query fails.
    """
    _setup_logging()
    logging.info(' Starting mysqldump restore...')
    args = get_run_parameters()
    tmp_dir = None
    try:
        name = get_backup_name(args)
        logging.info('downloading backup {}'.format(name))
        tmp_dir = os.path.join(args.tmp_dir, datetime.now().strftime("%Y%m%d%H%M%S"))
        if not os.path.isdir(tmp_dir):
            os.makedirs(tmp_dir)

        tmp_name = '{}/{}'.format(tmp_dir, re.sub('^mysql-mdb/', '', name))
        mds = MdsS3(url=args.url, bucket_name=args.bucket_name)
        mds.download(name, tmp_name)

        config = ConfigParser.ConfigParser(allow_no_value=True)
        config.read(args.cnf_file)
        datadir = config.get('mysqld', 'datadir')
        if not os.path.isdir(datadir):
            os.makedirs(datadir)

        subprocess.call('service mysql stop', shell=True)
        subprocess.call('chown -R mysql:mysql {}/'.format(datadir), shell=True)
        if not os.path.isdir('/var/run/mysqld'):
            os.makedirs('/var/run/mysqld')

        subprocess.call('chown mysql:mysql -R /var/run/mysqld', shell=True)
        subprocess.call('chmod 755 /var/run/mysqld -R', shell=True)

        logging.info('starting mysql...')
        subprocess.call('service mysql start', shell=True)
        # Fixed grace period for the server to come up; there is no readiness probe.
        time.sleep(40)

        logging.info('inserting timezones')
        subprocess.call('mysql_tzinfo_to_sql /usr/share/zoneinfo | mysql mysql', shell=True)

        logging.info('getting data from dump')
        # (table name in the dump, label used in log messages, extracted SQL file name)
        tables = [
            ('banner', 'banners', 'banners.sql'),
            ('campaign', 'campaigns', 'campaigns.sql'),
            ('zone_template', 'zone_templates', 'zone_template.sql'),
        ]
        sql_paths = {}
        for table, label, file_name in tables:
            sql_paths[table] = '{}/{}'.format(tmp_dir, file_name)
            logging.info('dumping {}'.format(label))
            load_table_sql(tmp_name, table, sql_paths[table])

        subprocess.call('mysql -e "drop database if exists adfox; create database adfox;"', shell=True)

        for table, label, _ in tables:
            logging.info('inserting {}'.format(label))
            # Feed the file directly on stdin; no shell or `cat` needed, and
            # paths containing spaces are safe.
            with open(sql_paths[table]) as sql_file:
                subprocess.call(['mysql', 'adfox'], stdin=sql_file)

        expected_min_date = datetime.strptime(args.backup_date, '%Y%m%d').strftime('%Y-%m-%d')
        logging.info('Expected min last update: {}'.format(expected_min_date))

        logging.info('validating banners')
        _check_last_update('banner', 'Banners', expected_min_date)

        logging.info('validating campaigns')
        _check_last_update('campaign', 'Campaigns', expected_min_date)

        logging.info('mysqldump restore finished')
    finally:
        if not args.keep_files and tmp_dir is not None and os.path.isdir(tmp_dir):
            # The download (tmp_name) and the extracted SQL files live inside
            # tmp_dir, so removing the directory cleans up everything. Errors
            # are ignored, like the previous `rm -rf`, so they cannot mask an
            # exception raised in the try block.
            shutil.rmtree(tmp_dir, ignore_errors=True)


def _setup_logging():
    """Send DEBUG-and-above root log records to stdout and silence AWS/HTTP client loggers."""
    root = logging.getLogger()
    root.setLevel(logging.DEBUG)
    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    root.addHandler(handler)
    for noisy_logger in ('boto3', 'botocore', 's3transfer', 'urllib3'):
        logging.getLogger(noisy_logger).setLevel(logging.CRITICAL)


def _check_last_update(table, label, expected_min_date):
    """Raise CorruptedDataException unless adfox.<table> has max(last_update) >= expected_min_date.

    The check is a text comparison. ``expected_min_date`` is ``YYYY-MM-DD``
    and the query result is presumably a ``YYYY-MM-DD ...`` datetime string,
    so lexicographic order matches chronological order. An empty result, or
    the client's ``NULL`` for a table without rows, counts as corrupted.
    Previously ``'NULL'`` sorted after any date and passed the check.
    """
    output = subprocess.check_output(
        'mysql -B --skip-column-names -e "select max(last_update) from {};" adfox'.format(table),
        shell=True
    )
    if not isinstance(output, str):
        # Python 3 returns bytes; Python 2 already returns str.
        output = output.decode('utf-8')
    last_update = output.strip()
    logging.info('{} last update: {}'.format(label, last_update))
    if last_update in ('', 'NULL') or last_update < expected_min_date:
        raise CorruptedDataException('{}: last update {!r} is older than {}'.format(
            table, last_update, expected_min_date))


if __name__ == "__main__":
    # Run the restore only when executed as a script, not when imported.
    main()
