import argparse
import os
import sys

from sqlalchemy import inspect, text
from sqlalchemy.exc import SQLAlchemyError

from zenyatta.db.sql import PostgresSQL, MySQL


PARENT = '..'

# Absolute, symlink-resolved path of this script: __file__ may be relative to
# the cwd or start with "~", hence expanduser -> join(cwd) -> realpath.
_script_path = os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__)))
SCRIPT_DIR = os.path.dirname(_script_path)

# Append the directory one level above the script to the import path.
# NOTE(review): this executes after `from zenyatta.db.sql import ...` at the
# top of the file, so it cannot help that import resolve; move it above the
# import if that was the intent.
_parent_dir = os.path.normpath(os.path.join(SCRIPT_DIR, PARENT))
sys.path.append(_parent_dir)

def _non_negative_int(value):
    """Parse a command-line value as an int >= 0 (an argparse ``type=`` callable).

    Raises:
        ValueError: if *value* is not an integer literal (argparse reports it).
        argparse.ArgumentTypeError: if the integer is negative.
    """
    number = int(value)
    if number < 0:
        raise argparse.ArgumentTypeError("must be a non-negative integer, got {!r}".format(value))
    return number


parser = argparse.ArgumentParser(description="prune rds db to number of rows")

parser.add_argument('-f', '--host', required=True, dest='host',
                    action='store', help='db fqdn')

parser.add_argument('-u', '--user', required=True, dest='user',
                    action='store', help='user with access to db')

# NOTE(review): a password given as an argument is typically visible in the
# process list and shell history; consider an env var or prompt instead.
parser.add_argument('-p', '--password', required=True, dest='password',
                    action='store', help='user password')

parser.add_argument('-s', '--schema', required=True, dest='schema',
                    action='store', help='db schema')

parser.add_argument('-t', '--port', required=True, dest='port',
                    action='store', help='db port')

# Parsed to a validated int: the value is used as a SQL OFFSET, so rejecting
# non-numeric / negative input here keeps it from reaching the database.
parser.add_argument('-r', '--rows', required=True, dest='offset', action='store',
                    type=_non_negative_int,
                    help='how many rows to keep')

args = parser.parse_args()

# Connection id handed to the zenyatta db wrappers: the first dot-separated
# label of the host FQDN.
conn_id = args.host.split('.')[0]

# Candidate database types, tried in order; the first one whose engine can be
# inspected is used for the rest of the script.
dbs = [PostgresSQL(conn_id), MySQL(conn_id)]

engine = None
gadget = None

for db in dbs:
    try:
        candidate = db.create_sql_engine(login=args.user, host=args.host, port=args.port,
                                         password=args.password, schema=args.schema,
                                         stream_results=False)
        # Building an Inspector from an Engine typically opens a connection, so
        # a wrong dialect/driver or bad credentials surface here.
        gadget = inspect(candidate)
    except (TypeError, SQLAlchemyError) as err:
        # TypeError is kept from the original handling; SQLAlchemyError covers
        # connection failures so the next database type is actually tried.
        sys.stderr.write("{} could not connect: {}\n".format(type(db).__name__, err))
        continue
    engine = candidate
    break

if engine is None:
    # Fail here with a clear message instead of an AttributeError on None later.
    sys.exit("could not connect to {} with any supported database type".format(args.host))


# Tables with at most this many rows are left untouched.
PRUNE_THRESHOLD = 1000


def _pick_id_column(columns):
    """Return the name of the column to order and prune by, or None if there is none.

    Args:
        columns: column descriptions as returned by ``Inspector.get_columns``
            (each a dict with a ``'name'`` key).

    Prefers a column named exactly ``id``; otherwise the first column whose
    underscore-separated name has an ``id`` part (``user_id``, ``id_user``).
    Matching on whole parts avoids false hits like ``valid_from``/``paid_at``.
    """
    names = [column['name'] for column in columns]
    if 'id' in names:
        return 'id'
    return next((name for name in names if 'id' in name.split('_')), None)


preparer = engine.dialect.identifier_preparer
# Coerced to int so it can be embedded in SQL safely even if passed as a string.
offset = int(args.offset)

for table in gadget.get_table_names():
    id_field = _pick_id_column(gadget.get_columns(table))
    if id_field is None:
        # Previously next() raised StopIteration here and aborted the whole run.
        print("skipping {} because it has no id-like column".format(table))
        continue

    # Quote identifiers so reserved words / mixed-case names work on both dialects.
    quoted_table = preparer.quote(table)
    quoted_id = preparer.quote(id_field)
    count_sql = text("SELECT count(*) FROM {}".format(quoted_table))

    with engine.begin() as db:
        row_count = db.execute(count_sql).scalar()
        if row_count > PRUNE_THRESHOLD:
            print(table + " has this many rows: " + str(row_count))
            # Keep the `offset` rows with the lowest id values: find the value
            # at position `offset` and delete everything >= it. This removes the
            # same rows as `IN (SELECT ... OFFSET n)` but, unlike that form, is
            # also accepted by MySQL. A NULL/missing cutoff means nothing to delete.
            cutoff = db.execute(text(
                "SELECT {col} FROM {tbl} ORDER BY {col} ASC LIMIT 1 OFFSET {off}"
                .format(col=quoted_id, tbl=quoted_table, off=offset))).scalar()
            if cutoff is not None:
                db.execute(text("DELETE FROM {tbl} WHERE {col} >= :cutoff"
                                .format(tbl=quoted_table, col=quoted_id)),
                           {"cutoff": cutoff})
            row_count = db.execute(count_sql).scalar()
            print(table + " and now has this many rows: " + str(row_count))
        else:
            print("skipping {table} because it has {row_count} rows".format(
                table=table, row_count=row_count))
