#! /usr/bin/env python
#
# This script is used to generate schema definitions for the tables passed in.
# For each table, sequences, create table, indexes, and constraints will be output.
# Sequences are given the last value plus one by default.
# If a relation is not found, a SQL comment is emitted.
#
# Sample usage for generating SQL for a zero-downtime migration from a standard Rails-generated
# schema (at least one table name is required):
#   ./gen-postgresql -U myuser -H myhost -p myport -d mydb --alter-restart 1000000000 --bigint-id mytable

def parse_args(argv=None):
    """Parse the command line and return the resulting argparse namespace.

    argv is the list of argument strings to parse.  When it is None (the
    default) argparse falls back to sys.argv[1:], which is what the script
    entry point relies on.  Passing an explicit list lets the parser be
    reused or exercised without touching the process arguments.

    The returned namespace carries: tables (list of one or more names),
    user, host, port, database, alter_restart (int or None), bigint_id
    (bool) and db_s3_glue (bool).
    """
    # argparse is imported here rather than at module level so the import
    # stays local to the one function that needs it.
    import argparse
    parser = argparse.ArgumentParser(description='Generate schema information about tables.',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('tables', metavar='table', nargs='+', help='Tables to generate schema.')
    parser.add_argument('--user', '-U', default='postgres', help='User to login')
    parser.add_argument('--host', '-H', default=None, help='Source database host')
    parser.add_argument('--port', '-p', default=None, help='Port to use')
    parser.add_argument('--database', '-d', help='Database to connect')
    parser.add_argument('--alter-restart', '-a', type=int, default=None,
                        help='Emit ALTER SEQUENCE statements starting at provided restart.')
    parser.add_argument('--bigint-id', '-b', action='store_true', default=False,
                        help='Make all integer id columns bigint.')
    parser.add_argument('--db-s3-glue', action='store_true', default=False,
                        help='Generate db-s3-glue compatible output.')
    return parser.parse_args(argv)

def connect(args):
    """Open and return a psycopg2 connection described by the parsed arguments.

    Reads args.database, args.user, args.host and args.port and hands them
    to psycopg2.connect as dbname, user, host and port respectively.
    """
    # Imported lazily so psycopg2 is only required once a connection is made.
    import psycopg2
    settings = {
        'dbname': args.database,
        'user': args.user,
        'host': args.host,
        'port': args.port,
    }
    return psycopg2.connect(**settings)

# Column metadata for the relation named %(table)s, one row per column in
# ordinal (declaration) order.  Consumers read the result positionally:
#   0 table_schema, 1 column_name, 2 type name (udt_name is substituted when
#   data_type is 'USER-DEFINED'), 3 character_maximum_length,
#   4 column_default, 5 is_nullable ("NO" marks a NOT NULL column).
# NOTE: the filter is on table_name only, so same-named tables living in
# different schemas all contribute rows to a single result set.
COLUMNS_SQL="""
select
  table_schema,
  column_name,
  case when data_type='USER-DEFINED' then udt_name else data_type end,
  character_maximum_length,
  column_default,
  is_nullable
from information_schema.columns
where table_name = %(table)s
order by ordinal_position
"""

# One row per index on the ordinary table (relkind 'r') named %(table)s,
# ordered by index name.  Consumers read the result positionally:
#   0 index name, 1 unique flag, 2 primary-key flag,
#   3 the indexed column names joined with ', '.
# NOTE: array_agg has no ORDER BY here, so the order of the names in column 3
# is not tied to the index key order; MULTI_COLUMN_INDEX_SQL is used to
# recover the real order for multi-column indexes.
INDEXES_SQL="""
select
    i.relname as index_name,
    ix.indisunique as uniqe,
    ix.indisprimary as is_primary,
    array_to_string(array_agg(a.attname), ', ') as column_names
from pg_class t, pg_class i, pg_index ix, pg_attribute a
where
    t.oid = ix.indrelid
    and i.oid = ix.indexrelid
    and a.attrelid = t.oid
    and a.attnum = ANY(ix.indkey)
    and t.relkind = 'r'
    and t.relname = %(table)s
group by
    t.relname,
    i.relname,
    ix.indisprimary,
    ix.indisunique
order by
    t.relname,
    i.relname
"""

# One row per table column that takes part in index %(index)s on table
# %(table)s.  Consumers read the result positionally:
#   0 column name (attname), 1 column number (attnum), 2 the index's indkey.
# The position of attnum inside indkey is what gives the column's place in
# the index key.
# XXX AGB: This can likely be simplified but my first attempt was incorrect.
MULTI_COLUMN_INDEX_SQL="""
select
    a.attname,
    a.attnum,
    ix.indkey
from pg_class t, pg_class i, pg_index ix, pg_attribute a
where
    t.oid = ix.indrelid
    and i.oid = ix.indexrelid
    and a.attrelid = t.oid
    and a.attnum = ANY(ix.indkey)
    and t.relkind = 'r'
    and t.relname = %(table)s
    and i.relname = %(index)s
"""

class Sequence(object):
    """A database sequence together with the value it should start from.

    Attributes:
      name        the sequence name exactly as given to the constructor.
      next_value  last_value + 1 as read from the sequence at construction.
    """
    def __init__(self, dbh, name):
        """Read the sequence's last value through the open connection dbh.

        Any error raised by the query (for example when the sequence does
        not exist) propagates to the caller; the cursor is closed either way.
        """
        self.name = name
        dbc = dbh.cursor()
        try:
            # The sequence name is an identifier, so it is formatted into the
            # statement rather than bound as a parameter (binding would quote
            # it as a string literal).  Callers must pass trusted names only.
            dbc.execute("select last_value + 1 from {}".format(name)) # XXX we want to not escape here
            self.next_value = dbc.fetchone()[0]
        finally:
            # Release the cursor even when the query fails.
            dbc.close()
    def create(self):
        "Return the statement to create the sequence."
        return "create sequence {} start with {};".format(self.name, self.next_value)
    def alter(self, restart):
        "Return an alter statement which resets the sequence to restart."
        return "alter sequence {} restart {};".format(self.name, restart)

class Column(object):
    """One column of a table: its name, type, length limit, default and nullability."""
    def __init__(self, name, type, max, default, is_nullable):
        # Parameter names are part of the constructor's interface and are
        # kept as they are even though two of them shadow builtins.
        self.name, self.type, self.max = name, type, max
        self.default, self.is_nullable = default, is_nullable
    def create(self, bigint_id):
        """Return the column definition fragment for a create table statement.

        When bigint_id is exactly True, a column named 'id' of type 'integer'
        is emitted as 'bigint'.  The fragment is indented by two spaces and
        always ends in either 'null' or 'not null'.
        """
        widen = bigint_id is True and self.name == 'id' and self.type == 'integer'
        sql_type = 'bigint' if widen else self.type
        # A length limit, when present, is glued directly onto the type name.
        length = ' ' if self.max is None else '({}) '.format(self.max)
        default = '' if self.default is None else 'default {} '.format(self.default)
        nullability = 'not null' if self.is_nullable is False else 'null'
        return '  {} {}'.format(self.name, sql_type) + length + default + nullability

class Index(object):
    """An index on a table, able to render the DDL that recreates it.

    Attributes:
      name          the index name.
      is_unique     truthy when the index enforces uniqueness.
      is_primary    truthy when the index backs the primary key.
      column_names  the indexed columns as one comma separated string.
    """
    def __init__(self, name, is_unique, is_primary, column_names):
        self.name = name
        self.is_unique = is_unique
        self.is_primary = is_primary
        self.column_names = column_names
    def create(self, table):
        """Return the create index statement for this index on table.

        The statement always uses the btree access method.
        """
        creates = ['create']
        if self.is_unique:
            creates.append('unique')
        creates.append('index {} on {} using btree({});'.format(self.name, table, self.column_names))
        return ' '.join(creates)
    def alter(self, table):
        """Return a statement promoting the index to a constraint, or None.

        A primary index becomes the constraint <table>_pkey.  Any other
        unique index becomes a unique constraint named after the index
        itself: a name derived from the table alone would be shared by every
        unique index on that table and the generated statements would
        collide.  Non-unique indexes need no constraint, so None is returned.
        """
        if self.is_primary:
            return "alter table {} add constraint {}_pkey primary key using index {};".format(table, table, self.name)
        if self.is_unique:
            return "alter table {} add constraint {} unique using index {};".format(table, self.name, self.name)
        return None

class Table(object):
    """A table in the database: its schema, columns and indexes.

    Attributes:
      schema   schema name reported for the table's columns, or None when no
               column rows were found for the name.
      name     the table name as passed in.
      columns  Column objects in ordinal order.
      indexes  Index objects, with the primary key index placed first.
    """
    def __init__(self, dbh, name):
        """Load column and index metadata for name through the connection dbh.

        Query errors propagate to the caller; the cursor is closed either way.
        """
        self.schema = None
        self.name = name
        self.columns = []
        self.indexes = []
        dbc = dbh.cursor()
        try:
            dbc.execute(COLUMNS_SQL, {'table': name})
            # NOTE: COLUMNS_SQL matches on the table name alone, so the schema
            # recorded is whichever one the last returned row reports.
            for schema, column, data_type, max_length, default, nullable in dbc.fetchall():
                self.schema = schema
                self.columns.append(Column(column, data_type, max_length, default, nullable != "NO"))
            dbc.execute(INDEXES_SQL, {'table': name})
            # fetchall() materialises the rows first: _index_order reuses the
            # same cursor for its own query while this loop is running.
            for index_name, is_unique, is_primary, joined_names in dbc.fetchall():
                # INDEXES_SQL joins the names with ', ', so split on exactly
                # that to avoid leaving commas attached to the names.
                column_names = joined_names.split(', ')
                if len(column_names) > 1:
                    column_names = Table._index_order(dbc, name, index_name, column_names)
                index = Index(index_name, is_unique, is_primary, ','.join(column_names))
                if index.is_primary:
                    self.indexes.insert(0, index)
                else:
                    self.indexes.append(index)
        finally:
            # Release the cursor even when one of the queries fails.
            dbc.close()
    @staticmethod
    def _index_order(dbc, table, index, column_names):
        """Reorder column_names to follow the key order of a multi-column index.

        Each row of MULTI_COLUMN_INDEX_SQL gives a column name, its attnum and
        the index's indkey (a space separated list of attnums); the position
        of the attnum within indkey is the column's position in the index.
        column_names is updated in place and also returned.  When the query
        returns no rows the list is returned untouched.

        Key positions for which no table column is returned (for instance
        expression keys) are left out instead of causing an IndexError, so
        the generated index will lack those keys.
        """
        dbc.execute(MULTI_COLUMN_INDEX_SQL, {'table': table, 'index': index})
        rows = dbc.fetchall()
        if not rows:
            return column_names
        # Size the slots from indkey itself, not from the number of names:
        # indkey can have more entries than there are matching columns.
        slots = [None] * len(rows[0][2].split())
        for attname, attnum, indkey in rows:
            slots[indkey.split().index(str(attnum))] = attname
        column_names[:] = [attname for attname in slots if attname is not None]
        return column_names
    def create(self, bigint_id):
        """Return the DDL for the table, its indexes and its constraints.

        bigint_id is forwarded to every column's create().  When no columns
        were found for the table a single SQL comment saying so is returned.
        Statements are separated by newlines: the create table statement
        first, then every create index, then every alter table.
        """
        if self.schema is None:
            return '-- unable to find {}'.format(self.name)
        creates = ['create table {}.{} ('.format(self.schema, self.name)]
        creates.append(',\n'.join([column.create(bigint_id) for column in self.columns]))
        creates.append(");")
        indexes = []
        alters = []
        for index in self.indexes:
            # NOTE: index statements use the bare table name while the create
            # table statement above is schema qualified.
            indexes.append(index.create(self.name))
            alter = index.alter(self.name)
            if alter is not None:
                alters.append(alter)
        return '\n'.join(creates + indexes + alters)
    def db_s3_glue(self):
        """Return the db-s3-glue description of the table as a heredoc block.

        Each column's database type is translated through a fixed mapping; a
        column whose type is not in the mapping raises KeyError.
        """
        mapping = {
            'integer': 'int',
            'bigint': 'bigint',
            'character varying': 'string',
            'boolean': 'boolean',
            'timestamp without time zone': 'timestamp',
            'text': 'string',
            'hstore': 'string',
            }
        db = ['    "{}" = <<EOF'.format(self.name), '{', '  "dpu_count": 1,', '  "version": 0,', '  "schema": [']
        cols = ['    {{"name": "{}", "type": "{}" }}'.format(column.name, mapping[column.type]) for column in self.columns]
        db.append(',\n'.join(cols))
        db.extend(['  ]', '}', 'EOF'])
        return '\n'.join(db)

def main():
    """Script entry point: load the requested tables and print their DDL.

    Parses the command line, reads every requested table plus the sequences
    referenced by column defaults, closes the connection, and then prints
    either the db-s3-glue description (when --db-s3-glue is given) or the
    create sequence / alter sequence / create table statements.
    """
    import re
    args = parse_args()
    dbh = connect(args)
    try:
        tables = [Table(dbh, table) for table in args.tables]
        # Raw string so the escaped parentheses reach the regex engine as
        # written.  Digits and a schema qualifier are accepted in the
        # sequence name; without them such sequences were silently skipped.
        matcher = re.compile(r"^nextval\('([a-zA-Z0-9_.]+)'::regclass\)$")
        sequence_names = set()
        for table in tables:
            for column in table.columns:
                if column.default is not None:
                    match = matcher.match(column.default)
                    if match is not None:
                        sequence_names.add(match.group(1))
        # Sorted so the generated script is identical from run to run.
        sequences = [Sequence(dbh, sequence_name) for sequence_name in sorted(sequence_names)]
    finally:
        # Everything needed has been read; release the connection even when
        # loading failed part way through.
        dbh.close()

    # db_s3_glue is fairly specialized so check for that and exit if
    # that's what the caller wants.
    if args.db_s3_glue:
        for table in tables:
            print(table.db_s3_glue())
        return

    # Print out the requested sequences
    print("-- CREATE SEQUENCES")
    print("--")
    print("-- Note that the sequences are set to (last_value+1). If you are doing a")
    print("-- live migration you will need to adjust the `start with` clause")
    print("-- to a number to account for the rate of growth of the relation.")
    print("")
    for sequence in sequences:
        print(sequence.create())
        print("")
    if args.alter_restart is not None:
        print("-- ALTER SEQUENCES")
        print("-- You can use these to do a zero downtime update by doing a reset")
        print("-- to a value outside a range that will be possible to reach prior")
        print("-- to changing the write endpoint.")
        for sequence in sequences:
            print(sequence.alter(args.alter_restart))
            print("")

    # Print out the relations
    print("-- CREATE RELATIONS")
    print("--")
    print("-- Check the indexes on relations to ensure there are no duplicates.")
    print("-- For example an index on (user_id) and one on (user_id) are exact")
    print("-- duplicates and one can be removed. An index on (user_id, created_on)")
    print("-- and one on (user_id) is also a duplicate and you can safely")
    print("-- delete the index on (user_id).")
    print("")
    for table in tables:
        print(table.create(args.bigint_id))
        print("")

# Run the generator only when this file is executed as a script, so that
# importing it has no side effects.
if __name__ == '__main__':
    main()
