# coding=utf-8
import csv
import logging
import os

import wiki

# Wiki markup for the per-table summary page built in
# GenerateDatabaseDoc.generate_pages. The section headings are Russian for
# "Columns", "Indexes" and "Foreign keys". The template is rendered with
# str.format: the three `{}` placeholders receive, in order, the column grid
# path, the index grid path and the foreign key grid path of the table, while
# the quadrupled braces are escapes that come out as literal `{{` / `}}`.
PAGE_TEMPLATE = u'''
====Колонки====

{{{{grid page="{}" width="100%" }}}}

====Индексы====

{{{{grid page="{}" width="100%" }}}}

====Внешние ключи====

{{{{grid page="{}" width="100%" }}}}
'''


class ColumnDefinition(object):
    """Plain record describing a single table column.

    Attributes:
        column_name: name of the column.
        is_nullable: whether the column accepts NULL.
        data_type: textual type of the column.
        character_maximum_length: maximum length of the column, or None.
        comment: column comment, or None.
        default: always initialised to an empty unicode string.
    """

    def __init__(self, column_name, is_nullable, data_type, character_maximum_length, comment):
        # The default is not a constructor argument; every instance starts
        # with an empty one.
        self.default = u''
        self.comment = comment
        self.character_maximum_length = character_maximum_length
        self.data_type = data_type
        self.is_nullable = is_nullable
        self.column_name = column_name

    def __str__(self):
        """Return a debug representation (comment and default are omitted)."""
        described = (
            self.column_name,
            self.is_nullable,
            self.data_type,
            self.character_maximum_length,
        )
        return 'ColumnDefinition{ column_name: %s, is_nullable: %s, data_type: %s, character_maximum_length: %s}' % described


class IndexDefinition(object):
    """Plain record describing a single index.

    Attributes:
        table_name: table the index belongs to.
        index_name: name of the index.
        index_type: type of the index.
        index_cols: the indexed columns, stored exactly as given.
    """

    def __init__(self, table_name, index_name, index_type, index_cols):
        self.index_cols = index_cols
        self.index_type = index_type
        self.index_name = index_name
        self.table_name = table_name

    def __str__(self):
        """Return a debug representation listing every field."""
        described = (self.table_name, self.index_name, self.index_type, self.index_cols)
        return 'IndexDefinition{ table_name: %s, index_name: %s, index_type: %s, index_cols: %s }' % described


class ForeignKeyDefinition(object):
    """Plain record describing one column of a foreign key constraint.

    Attributes:
        table_name: table that owns the constraint.
        constraint_name: name of the constraint.
        column_name: referencing column in ``table_name``.
        foreign_table_name: referenced table.
        foreign_column_name: referenced column.
    """

    def __init__(self, table_name, constraint_name, column_name, foreign_table_name, foreign_column_name):
        self.foreign_column_name = foreign_column_name
        self.foreign_table_name = foreign_table_name
        self.column_name = column_name
        self.constraint_name = constraint_name
        self.table_name = table_name

    def __str__(self):
        """Return a debug representation listing every field."""
        # Doubled braces are str.format escapes for literal '{' and '}'.
        layout = 'ForeignKeyDefinition{{ table_name: {}, constraint_name: {}, column_name: {}, foreign_table_name: {}, foreign_column_name: {}}}'
        described = (
            self.table_name,
            self.constraint_name,
            self.column_name,
            self.foreign_table_name,
            self.foreign_column_name,
        )
        return layout.format(*described)


class GenerateDatabaseDoc(object):
    """Export PostgreSQL schema metadata to CSV files and publish it to a wiki.

    Columns, indexes and foreign keys of the ``public`` schema are read from
    the database, written to one CSV file per table (in the ``csv``,
    ``indexes_csv`` and ``foreign_keys_csv`` directories by default) and, when
    ``do_wiki_gen`` is true, uploaded through ``wiki_api``.
    """

    def __init__(self,
                 host,
                 port,
                 dbname,
                 user,
                 password,
                 root_path,
                 wiki_api,
                 grid_template,
                 indexes_root_path,
                 indexes_grid_template,
                 page_path,
                 foreign_key_path,
                 foreign_key_grid_template,
                 do_wiki_gen):
        """Store the connection settings and the wiki locations.

        Args:
            host, port, dbname, user, password: forwarded unchanged to
                ``psycopg2.connect``.
            root_path: wiki path prefix of the per-table column grids.
            wiki_api: object providing ``create_wiki``, ``rename``,
                ``upload_csv``, ``import_csv`` and ``create_page``.
            grid_template: passed to ``create_wiki``/``rename`` for column
                grids.
            indexes_root_path: wiki path prefix of the per-table index grids.
            indexes_grid_template: passed to ``create_wiki``/``rename`` for
                index grids.
            page_path: wiki path prefix of the per-table summary pages.
            foreign_key_path: wiki path prefix of the per-table foreign key
                grids.
            foreign_key_grid_template: passed to ``create_wiki``/``rename``
                for foreign key grids.
            do_wiki_gen: when false no ``wiki_api`` call is made; the CSV
                files are still written.
        """
        self.host = host
        self.port = port
        self.dbname = dbname
        self.user = user
        self.password = password
        self.root_path = root_path
        self.wiki_api = wiki_api
        self.grid_template = grid_template
        self.indexes_root_path = indexes_root_path
        self.indexes_grid_template = indexes_grid_template
        self.page_path = page_path
        self.foreign_key_path = foreign_key_path
        self.foreign_key_grid_template = foreign_key_grid_template
        self.do_wiki_gen = do_wiki_gen

    def get_entity_path(self, entity):
        """Return the wiki path of the column grid for ``entity``."""
        return self.root_path + '/' + entity

    def get_indexes_path(self, table_name):
        """Return the wiki path of the index grid for ``table_name``."""
        return self.indexes_root_path + '/' + table_name

    def get_entity_page_path(self, table_name):
        """Return the wiki path of the summary page for ``table_name``."""
        return self.page_path + '/' + table_name

    def get_foreign_key_path(self, table_name):
        """Return the wiki path of the foreign key grid for ``table_name``."""
        return self.foreign_key_path + '/' + table_name

    def __connect(self):
        """Open and return a new psycopg2 connection (SSL required)."""
        # Imported lazily so the module can be loaded without psycopg2.
        import psycopg2

        logging.info('Establishing connection to the database...')
        conn = psycopg2.connect(host=self.host, port=self.port, dbname=self.dbname, user=self.user,
                                password=self.password, sslmode='require')
        logging.info('Connection to the database has been established')
        return conn

    def __with_cursor(self, action):
        """Run ``action(cursor)`` on a fresh connection and return its result.

        Leaving psycopg2's ``with connection`` block only commits or rolls
        back the transaction; it does not close the connection. The
        connection is therefore closed explicitly, even when ``action``
        raises.
        """
        conn = self.__connect()
        try:
            with conn:
                with conn.cursor() as cur:
                    return action(cur)
        finally:
            conn.close()

    def __publish_grid(self, title, grid_path, grid_template, file_name, file_path, column_count):
        """Create the wiki grid at ``grid_path`` and import a CSV file into it.

        CSV column ``i`` (for ``i`` in ``0 .. column_count - 1``) is mapped to
        the target ``100 + i`` via the ``icolumn_<i>_to`` import option.
        """
        self.wiki_api.create_wiki(grid_path, grid_template)
        self.wiki_api.rename(title, grid_path, grid_template)
        cache_key = self.wiki_api.upload_csv(file_name, file_path)
        column_mapping = dict(('icolumn_%d_to' % i, 100 + i) for i in range(column_count))
        self.wiki_api.import_csv(grid_path, cache_key, column_mapping)

    def generate(self):
        """Write a column CSV for every table and optionally publish it."""
        logging.info('Generating')

        def load_columns(cur):
            # All reads share one cursor so a single connection is used.
            tables = self.__get_tables(cur)
            all_columns = {}
            for table in tables:
                logging.info('Getting columns for table: %s', table)
                all_columns[table] = self.__get_columns(table, cur)
            return tables, all_columns

        tables, all_columns = self.__with_cursor(load_columns)

        for table in tables:
            logging.info('Processing table: %s', table)
            file_name, file_path = self.__write_csv(table, all_columns[table])

            if self.do_wiki_gen:
                self.__publish_grid(table, self.get_entity_path(table), self.grid_template,
                                    file_name, file_path, 5)

    def generate_indexes(self):
        """Write an index CSV for every indexed table and optionally publish it."""
        logging.info('Generating indexes')
        indexes = self.__with_cursor(self.__get_indexes)

        # Group the flat index list by owning table.
        index_map = {}
        for index in indexes:
            logging.info('Processing index: %s', index.index_name)
            index_map.setdefault(index.table_name, []).append(index)

        # items() (rather than iteritems()) works on both Python 2 and 3.
        for table_name, table_indexes in index_map.items():
            logging.info('Processing table: %s', table_name)
            file_name, file_path = self.__write_index_csv(table_name, table_indexes)

            if self.do_wiki_gen:
                self.__publish_grid(table_name, self.get_indexes_path(table_name), self.indexes_grid_template,
                                    file_name, file_path, 4)

    def generate_pages(self):
        """Create a summary wiki page per table embedding its three grids."""
        logging.info('Generating pages')
        tables = self.__with_cursor(self.__get_tables)

        for table in tables:
            logging.info('Processing table: %s', table)
            body = PAGE_TEMPLATE.format(self.get_entity_path(table), self.get_indexes_path(table),
                                        self.get_foreign_key_path(table))
            if self.do_wiki_gen:
                self.wiki_api.create_page(self.get_entity_page_path(table), table, body, False)

    def generate_foreign_keys(self):
        """Write a foreign key CSV per table that has any and optionally publish it."""
        logging.info('Generating foreign keys')
        foreign_keys = self.__with_cursor(self.__get_foreign_keys)

        logging.info("Loaded %d foreign_keys", len(foreign_keys))
        # Group the flat foreign key list by owning table.
        foreign_keys_map = {}
        for fk in foreign_keys:
            foreign_keys_map.setdefault(fk.table_name, []).append(fk)

        for table_name, table_foreign_keys in foreign_keys_map.items():
            logging.info('Processing table: %s', table_name)
            file_name, file_path = self.__write_foreign_key_csv(table_name, table_foreign_keys)

            if self.do_wiki_gen:
                self.__publish_grid(table_name, self.get_foreign_key_path(table_name),
                                    self.foreign_key_grid_template, file_name, file_path, 4)

    @staticmethod
    def __get_tables(cur):
        """Return the names of all tables in the ``public`` schema."""
        cur.execute(
            '''
            SELECT table_name
                from information_schema.tables
                where table_schema = %s
            ''',
            ('public',)
        )
        return [tr[0] for tr in cur.fetchall()]

    @staticmethod
    def __get_indexes(cur):
        """Return an IndexDefinition for every index in the ``public`` schema.

        Rows are ordered by table and index name; the index columns arrive as
        one comma-joined string.
        """
        cur.execute(
            '''
            SELECT idx.indrelid::regclass::varchar,
                i.relname as indname,
                am.amname as indam,
                array_to_string(ARRAY(
                    SELECT pg_get_indexdef(idx.indexrelid, k + 1, true)
                    FROM generate_subscripts(idx.indkey, 1) as k
                    ORDER BY k
                ), ',') as indkey_names
            FROM   pg_index as idx
                    JOIN   pg_class as i ON     i.oid = idx.indexrelid
                    JOIN   pg_am as am ON     i.relam = am.oid
                    join pg_namespace ns on i.relnamespace = ns.oid
            where ns.nspname = %s
            order by idx.indrelid::regclass::varchar, indname
            ''',
            ('public',)
        )

        # The select list matches IndexDefinition's constructor order:
        # table name, index name, index type, index columns.
        return [IndexDefinition(*tr) for tr in cur.fetchall()]

    @staticmethod
    def __get_foreign_keys(cur):
        """Return a ForeignKeyDefinition per foreign key row of the ``public`` schema."""
        cur.execute(
            '''
            SELECT
                tc.table_name,
                tc.constraint_name,
                kcu.column_name,
                ccu.table_name AS foreign_table_name,
                ccu.column_name AS foreign_column_name
            FROM information_schema.table_constraints AS tc
                JOIN information_schema.key_column_usage AS kcu ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema
                JOIN information_schema.constraint_column_usage AS ccu ON ccu.constraint_name = tc.constraint_name AND ccu.table_schema = tc.table_schema
                where tc.table_schema = %s and tc.constraint_type = %s
                order by tc.table_name, tc.constraint_name
            ''',
            ('public', 'FOREIGN KEY',)
        )

        return [ForeignKeyDefinition(*tr) for tr in cur.fetchall()]

    @staticmethod
    def __write_rows(directory, table, rows):
        """Write ``rows`` to ``<directory>/<table>.csv``, creating the directory.

        Returns:
            A ``(file_name, file_path)`` tuple for the written file.
        """
        if not os.path.exists(directory):
            os.makedirs(directory)

        file_name = table + '.csv'
        file_path = os.path.join(directory, file_name)
        with open(file_path, 'w') as f:
            csv.writer(f, delimiter=',').writerows(rows)

        return file_name, file_path

    @staticmethod
    def __write_csv(table, columns, directory='csv'):
        """Write the column definitions of ``table`` to a CSV file.

        Each row holds: column name, data type, ``'true'`` for NOT NULL
        columns (empty otherwise), comment (empty when missing) and default.

        Returns:
            A ``(file_name, file_path)`` tuple for the written file.
        """
        rows = (
            [
                column.column_name,
                column.data_type.encode('utf-8'),
                'true' if not column.is_nullable else '',
                column.comment.encode('utf-8') if column.comment is not None else '',
                column.default.encode('utf-8'),
            ]
            for column in columns
        )
        return GenerateDatabaseDoc.__write_rows(directory, table, rows)

    @staticmethod
    def __write_index_csv(table, indexes, directory='indexes_csv'):
        """Write the index definitions of ``table`` to a CSV file.

        Each row holds: table name, index name, index type and index columns.

        Returns:
            A ``(file_name, file_path)`` tuple for the written file.
        """
        rows = (
            [index.table_name, index.index_name, index.index_type, index.index_cols]
            for index in indexes
        )
        return GenerateDatabaseDoc.__write_rows(directory, table, rows)

    @staticmethod
    def __write_foreign_key_csv(table, foreign_keys, directory='foreign_keys_csv'):
        """Write the foreign key definitions of ``table`` to a CSV file.

        Each row holds: column name, constraint name, referenced table and
        referenced column.

        Returns:
            A ``(file_name, file_path)`` tuple for the written file.
        """
        rows = (
            [
                foreign_key.column_name,
                foreign_key.constraint_name,
                foreign_key.foreign_table_name,
                foreign_key.foreign_column_name,
            ]
            for foreign_key in foreign_keys
        )
        return GenerateDatabaseDoc.__write_rows(directory, table, rows)

    @staticmethod
    def __get_columns(table, cur):
        """Return the ColumnDefinition list of ``table``, in ordinal order.

        The maximum character length, when present, is appended to the data
        type as ``type(length)``; the comment is decoded from UTF-8.
        """
        cur.execute(
            '''
            select cols.column_name, cols.is_nullable, cols.data_type, cols.character_maximum_length, pg_catalog.col_description(c.oid, cols.ordinal_position::int)
                from information_schema.columns cols
                    join pg_catalog.pg_class c on c.oid = cols.table_name::regclass::oid and c.relname = cols.table_name
                where cols.table_schema = %s and cols.table_name = %s
                order by ordinal_position
            ''',
            ('public', table,)
        )

        columns = []
        for column_name, is_nullable, data_type, character_maximum_length, comment in cur.fetchall():
            data_type = unicode(data_type)
            if character_maximum_length is not None:
                data_type += u'(' + unicode(character_maximum_length) + u')'

            # information_schema reports nullability as the text 'YES' / 'NO'.
            is_nullable = is_nullable == 'YES'

            if comment is not None:
                comment = unicode(comment, 'utf-8')

            logging.info(u'%s %s %s %s %s', column_name, is_nullable, data_type, character_maximum_length, comment)

            columns.append(ColumnDefinition(column_name, is_nullable, data_type, character_maximum_length, comment))

        return columns


if __name__ == '__main__':
    # Script entry point: document the test database on the wiki.
    logging.basicConfig(level=logging.INFO)

    # Secrets are taken from the environment; a missing variable raises KeyError.
    wiki_token = os.environ['WIKI_OAUTH_TOKEN']
    db_password = os.environ['DATABASE_PASSWORD']

    doc_generator = GenerateDatabaseDoc(
        host='market-checkouter-test01f.db.yandex.net,market-checkouter-test01i.db.yandex.net,market-checkouter-test01h.db.yandex.net',
        port='6432,6432,6432',
        dbname='market_checkouter_test',
        user='market_checkouter',
        password=db_password,
        root_path='/users/timursha/database-auto',
        wiki_api=wiki.WikiApi('https://wiki-api.yandex-team.ru/_api', wiki_token),
        grid_template='/users/timursha/autogen/table/',
        indexes_root_path='/users/timursha/indexes-auto',
        indexes_grid_template='/users/timursha/autogen/index',
        page_path='/users/timursha/database-page-auto',
        foreign_key_path='/users/timursha/foreign-keys-auto',
        foreign_key_grid_template='/users/timursha/autogen/foreign_key',
        do_wiki_gen=True,
    )

    # Run the four steps in order: columns, indexes, summary pages, foreign keys.
    doc_generator.generate()
    doc_generator.generate_indexes()
    doc_generator.generate_pages()
    doc_generator.generate_foreign_keys()
