# coding: utf-8

from contextlib import closing
from keyword import iskeyword
import inspect
import importlib
import re

from psycopg2 import connect

from mail.pypg.pypg.types import DBEnum


def get_enum_values(dsn, type_name):
    """Return the labels of the PostgreSQL enum type *type_name*.

    A fresh connection to *dsn* is opened and closed again on exit. The query
    is ``SELECT enum_range(NULL::<type_name>)::text[]``; the ``text[]`` cast
    lets psycopg2 hand the labels back as a plain list of strings, ordered as
    ``enum_range`` returns them. The first column of the single result row
    is returned.

    NOTE(review): *type_name* is pasted into the SQL text verbatim, so it must
    come from trusted code (here ``DBEnum.name_in_db()``), never user input.
    """
    with closing(connect(dsn)) as connection:
        cursor = connection.cursor()
        query = 'SELECT enum_range(NULL::%s)::text[]' % type_name
        cursor.execute(query)
        row = cursor.fetchone()
        return row[0]


def make_py_name(name):
    """Turn a database enum label into a usable Python attribute name.

    - Every character that is not a word character (``-``, spaces, dots, ...)
      becomes ``_``.
    - A label that is empty or starts with a digit gets a leading ``_``.
    - A Python keyword gets a trailing ``_``.

    Labels made only of word characters and ``-`` map exactly as before
    (``'foo-bar'`` -> ``'foo_bar'``, ``'class'`` -> ``'class_'``).
    """
    # Replace every non-identifier character, not just '-'. Labels such as
    # 'some value' or 'a.b' otherwise produce a syntax error in the generated
    # module.
    name = re.sub(r'\W', '_', name)
    # Identifiers cannot be empty or start with a digit (e.g. a '2fa' label).
    if not name or name[0].isdigit():
        name = '_' + name
    if iskeyword(name):
        return name + '_'
    return name


def make_enum_lines(db_enum_values):
    """Yield one class-body source line per database enum label.

    Each line has the form ``    <py_name> = <literal>\\n``. ``<py_name>``
    comes from make_py_name(). ``<literal>`` is the label rendered with
    repr(), which yields ``'label'`` for ordinary labels. A label containing
    a quote or backslash therefore still becomes a valid string literal with
    the same value.
    """
    for db_name in db_enum_values:
        # %r instead of '%s': hand-quoting broke on labels containing ' or \.
        yield "    %s = %r\n" % (make_py_name(db_name), db_name)


# Marker searched for by inject_enum_lines: source lines up to and including
# the first line containing it are kept. Everything after that line is
# replaced by the generated enum members.
START_TAG = 'ENUM-DECLARATION-START'


def inject_enum_lines(sorce_lines, enum_lines):
    """Yield *sorce_lines* up to the START_TAG line, then *enum_lines*.

    The source line containing START_TAG is yielded too. Every source line
    after it is dropped, so the generated members replace whatever followed
    the tag.

    Raises RuntimeError if no source line contains START_TAG. Because this is
    a generator, the error only surfaces after every source line has been
    yielded.
    """
    for line in sorce_lines:
        yield line
        if START_TAG in line:
            # The tag may sit on a final line with no trailing newline.
            # Without this, the first generated member would be glued onto
            # the tag line (commenting it out when the tag is in a comment).
            if not line.endswith('\n'):
                yield '\n'
            break
    else:
        raise RuntimeError("Can't find tag %r in source lines" % START_TAG)
    for line in enum_lines:
        yield line


def update_module(postge_dsn, filename, type_name):
    """Rewrite *filename* so its enum members match database type *type_name*.

    The whole new file content is built in memory while the original file is
    open for reading:
    - source lines up to the START_TAG line;
    - members generated from the labels fetched via *postge_dsn*.
    Only afterwards is the file reopened for writing. If the tag is missing
    or the database query fails, the exception escapes before the file is
    truncated.

    NOTE(review): the rewrite is not atomic. A failure during the final write
    can leave a partially written file.
    """
    with open(filename) as source:
        db_values = get_enum_values(postge_dsn, type_name)
        member_lines = make_enum_lines(db_values)
        new_lines = list(inject_enum_lines(source, member_lines))
    content = ''.join(new_lines)
    with open(filename, 'w') as target:
        target.write(content)


def locate_db_enums(module_name='pymdb.types.db_enums'):
    """Yield ``(source_file, db_type_name)`` for each DBEnum class in a module.

    Every class attribute of *module_name* that is a strict subclass of
    DBEnum (DBEnum itself is skipped) is considered. The path of the module
    that defines the class is yielded together with ``name_in_db()``.

    :param module_name: dotted module that re-exports the enums. It defaults
        to the previously hard-coded ``'pymdb.types.db_enums'``.
    """
    db_enums = importlib.import_module(module_name)
    for _, obj in inspect.getmembers(db_enums, inspect.isclass):
        if not issubclass(obj, DBEnum) or obj is DBEnum:
            continue
        obj_module = importlib.import_module(obj.__module__)
        module_file = obj_module.__file__
        # __file__ may point at compiled bytecode; the generator must edit
        # the .py source instead. Strip the exact suffix character rather
        # than rstrip('co'), which removes any trailing run of c/o
        # characters.
        if module_file.endswith(('.pyc', '.pyo')):
            module_file = module_file[:-1]
        yield module_file, obj.name_in_db()


def main():
    """Command-line entry point.

    Takes a single positional ``dsn`` argument. Every module found by
    locate_db_enums() is then rewritten with the labels of its database enum
    type.
    """
    from argparse import ArgumentParser

    cli = ArgumentParser()
    cli.add_argument('dsn', help='Postgre connection string')
    options = cli.parse_args()

    dsn = options.dsn
    for module_file, db_type in locate_db_enums():
        update_module(dsn, module_file, db_type)

# Script entry: regenerate the enum modules from the database named on the
# command line.
if __name__ == '__main__':
    main()
