# -*- coding: utf-8 -*-
# isort:skip_file
import os
import sys

from sqlalchemy.dialects import mysql
from sqlalchemy.schema import (
    CreateIndex,
    CreateTable,
)

os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'passport.backend.oauth.api.settings.default_settings')

from passport.backend.oauth.core.db.eav.schemas import (
    central_metadata,
    shard_metadata,
)

# Import the project schemas only after passport.backend.oauth.core.db.eav.schemas has been imported
import passport.backend.oauth.core.db.schemas  # noqa
import passport.backend.oauth.tvm_api.tvm_api.db.schemas  # noqa


# (search, replacement) substring pairs applied to every generated
# CREATE TABLE statement: widens INTEGER auto-increment `id` columns to BIGINT.
REPLACEMENTS = (
    (
        'id INTEGER NOT NULL AUTO_INCREMENT',
        'id BIGINT NOT NULL AUTO_INCREMENT',
    ),
)

# Table names whose CREATE TABLE statement is emitted without applying
# REPLACEMENTS.
IGNORE_REPLACEMENTS_FOR = (
    'auto_id_client',
    'auto_id_tvm_client',
    'auto_id_tvm_secret_key',
)

# NOTE(review): not referenced in the visible part of this file; presumably
# the database user name consumed elsewhere — confirm before removing.
USER = 'oauth'

# Suffix appended verbatim to every CREATE TABLE statement.
CREATE_TABLE_POSTFIX = ' ENGINE=InnoDB DEFAULT CHARSET=latin1'


def create_db(db_name):
    """Yield statements that drop, recreate and select database *db_name*.

    Any existing database with that name is dropped first. The new database
    uses the latin1 character set. *db_name* is interpolated verbatim (no
    quoting), so it is expected to be a plain identifier.
    """
    templates = (
        'DROP DATABASE IF EXISTS %s',
        'CREATE DATABASE %s CHARACTER SET latin1',
        'USE %s',
    )
    for template in templates:
        yield template % db_name


def query_to_str(query):
    """Render a SQLAlchemy DDL construct as MySQL SQL text.

    The construct is compiled with the MySQL dialect. A single space directly
    before each newline is removed, and leading/trailing whitespace of the
    whole result is stripped. The returned text has no trailing semicolon.
    """
    compiled = query.compile(dialect=mysql.dialect())
    sql_text = str(compiled)
    sql_text = sql_text.replace(' \n', '\n')
    return sql_text.strip()


def metadata_to_sql(metadata, db_name):
    """Yield every SQL statement needed to build *metadata* inside *db_name*.

    The sequence starts with the statements from ``create_db(db_name)``. Then,
    for each table in order of its key in ``metadata.tables``, it yields:

    * the CREATE TABLE statement, with REPLACEMENTS applied unless the table
      name is listed in IGNORE_REPLACEMENTS_FOR, and CREATE_TABLE_POSTFIX
      appended;
    * one CREATE INDEX statement per index of that table, ordered by index
      name.

    Statements are yielded without a terminating semicolon.
    """
    for statement in create_db(db_name):
        yield statement

    tables = metadata.tables
    # Keys of ``metadata.tables`` are unique, so ordering by key alone gives
    # the same order as ordering the (key, table) pairs.
    for table_key in sorted(tables):
        table = tables[table_key]

        ddl = query_to_str(CreateTable(table))
        apply_replacements = table.name not in IGNORE_REPLACEMENTS_FOR
        if apply_replacements:
            for needle, substitute in REPLACEMENTS:
                ddl = ddl.replace(needle, substitute)
        yield ddl + CREATE_TABLE_POSTFIX

        indexes_by_name = sorted(table.indexes, key=lambda idx: idx.name)
        for index in indexes_by_name:
            yield query_to_str(CreateIndex(index))


def schemas_to_sql(work_dir):
    """Write ``schema_central.sql`` and ``schema_shard.sql`` into *work_dir*.

    The central file is generated from ``central_metadata`` for database
    ``oauthdbcentral``. The shard file is generated from ``shard_metadata``
    for ``oauthdbshard1``. Each statement is terminated with ``;`` and
    followed by a blank line. Existing files are overwritten. *work_dir* is
    not created here, so it must already exist.
    """
    targets = (
        ('schema_central.sql', central_metadata, 'oauthdbcentral'),
        ('schema_shard.sql', shard_metadata, 'oauthdbshard1'),
    )
    for file_name, metadata, db_name in targets:
        path = os.path.join(work_dir, file_name)
        with open(path, 'w') as out:
            for statement in metadata_to_sql(metadata, db_name):
                out.write('%s;\n\n' % statement)


def main():
    """Command-line entry point: generate the schema SQL files.

    The output directory is the first command-line argument. Without one, it
    defaults to ``'../expected/'``, which is resolved relative to the current
    working directory rather than this file's location.
    """
    has_dir_argument = len(sys.argv) > 1
    work_dir = sys.argv[1] if has_dir_argument else '../expected/'
    schemas_to_sql(work_dir)
