# -*- coding: utf-8 -*-

from itertools import ifilter, imap

from django.conf import settings
from django.db import models, connection, transaction
from django.db.backends.util import truncate_name
from django.db.models.fields.related import ForeignKey
from django.core.management.base import BaseCommand


# Shorthand for the database backend's ``quote_name``; used below to quote
# table, column and constraint names in the generated ALTER TABLE statements.
qn = connection.ops.quote_name


def rollback_on_exception(func):
    """
    Decorate a method so that it runs inside a manually managed transaction.

    The wrapped method is called under ``transaction.commit_manually``.
    If it raises, the transaction is rolled back and the exception is
    re-raised unchanged; otherwise the transaction is committed and the
    method's return value is handed back to the caller.
    """
    from functools import wraps

    @transaction.commit_manually
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            result = func(self, *args, **kwargs)
        except:
            # Deliberately bare: whatever interrupts the call, the
            # transaction must be rolled back before the error propagates.
            transaction.rollback()
            raise
        else:
            transaction.commit()
            return result

    return wrapper


class Command(BaseCommand):
    """
    Management command that runs ``migrate_app`` for every application
    returned by ``applications()``, inside a single transaction.
    """

    @rollback_on_exception
    def handle(self, *args, **kwargs):
        """
        Disable foreign key checks, migrate every application, then
        re-enable the checks.

        The checks are re-enabled in a ``finally`` clause so that the
        connection is not left with them switched off when a migration
        step raises.
        """
        c = connection.cursor()
        c.execute('SET FOREIGN_KEY_CHECKS = 0')
        try:
            for app in applications():
                migrate_app(app, c)
        finally:
            c.execute('SET FOREIGN_KEY_CHECKS = 1')


def applications():
    """
    Lazily produce the short names of the installed applications.

    The short name is the last dotted component of each entry of
    ``settings.INSTALLED_APPS`` (an empty list is assumed when the setting
    is missing). 'django_intranet_stuff' itself is left out.
    """
    package_names = getattr(settings, 'INSTALLED_APPS', [])
    short_names = (package.split('.')[-1] for package in package_names)
    return (name for name in short_names if name != 'django_intranet_stuff')


def migrate_app(app_name, c):
    """
    Run ``migrate_model`` on each model that ``models_of`` yields for
    *app_name*, passing along the cursor *c*.
    """
    app_models = models_of(app_name)
    for app_model in app_models:
        migrate_model(app_model, c)


def migrate_model(model, c):
    fields = filter(relates_dis, model._meta._fields())
    for field in fields:
        for obj in model.objects.all():
            if _detect_new_fk(obj.__class__._meta.db_table,
                              field.column,
                              field.rel.to._meta.db_table,
                              field.rel.field_name,
                              c):
                continue
            migrate_field(obj, field, c)
            try:
                obj.save()
            except Exception:
                print 'Removing object', obj.__class__, obj.id, 'because field', field.column, 'cannot be null'
                obj.delete()

        drop_old_fk(model, field, c)
        create_new_fk(model, field, c)


# Query selecting the names of the FOREIGN KEY constraints defined on one
# column.  Parameters, in order: schema name, table name, column name.
sql_fk = ' '.join((
    'SELECT tc.constraint_name FROM',
    'information_schema.table_constraints tc,',
    'information_schema.key_column_usage kcu',
    'WHERE tc.table_name=kcu.table_name',
    'AND tc.table_schema=kcu.table_schema',
    'AND tc.constraint_name=kcu.constraint_name',
    "AND tc.constraint_type='FOREIGN KEY'",
    'AND tc.table_schema=%s',
    'AND tc.table_name=%s',
    'AND kcu.column_name=%s',
))


def drop_old_fk(model, field, c):
    """
    Drop a foreign key constraint currently defined on *field*'s column
    of *model*'s table, if there is one.

    The constraint name is looked up with ``sql_fk`` in the schema named by
    ``settings.DATABASE_NAME``; the first name returned is dropped.  Nothing
    happens when the lookup returns no rows.
    """
    schema = settings.DATABASE_NAME
    table_name = model._meta.db_table
    field_name = field.column

    c.execute(sql_fk, (schema, table_name, field_name))
    # Decide from the fetched row instead of the value returned by
    # execute(), which DB-API 2.0 leaves undefined.
    row = c.fetchone()
    if row is None:
        return

    fk_name = row[0]
    print('Deleting old fk %s...' % fk_name)
    c.execute('ALTER TABLE %s DROP FOREIGN KEY %s' % (qn(table_name), qn(fk_name)))


def create_new_fk(model, field, c):
    """
    Add the foreign key constraint from *field*'s column to the related
    model's table, unless ``_detect_new_fk`` reports that it already exists.

    FK name generation has been taken from
    django.db.backends.creation.BaseDatabaseCreation.sql_for_pending_references
    """
    src_table = model._meta.db_table
    src_column = field.column
    dst_table = field.rel.to._meta.db_table
    dst_column = field.rel.field_name

    if _detect_new_fk(src_table, src_column, dst_table, dst_column, c):
        return

    fk_name = _gen_fk_name(src_table, src_column, dst_table, dst_column)
    print('Creating new fk %s...' % fk_name)

    # The constraint is created under the name cut down to the backend's
    # maximum identifier length.
    constraint = truncate_name(fk_name, connection.ops.max_name_length())
    statement = 'ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)' % (
        qn(src_table),
        qn(constraint),
        qn(src_column),
        qn(dst_table),
        qn(dst_column),
    )
    c.execute(statement)


def _detect_new_fk(table_name, field_name, ext_table_name, ext_field_name, c):
    """
    Tell whether the new-style foreign key constraint already exists on
    column *field_name* of *table_name*.

    The constraint name comes from ``_gen_fk_name`` and is truncated to the
    backend's maximum identifier length, i.e. the same name under which
    ``create_new_fk`` creates the constraint.  Returns True when a matching
    constraint is found in the schema ``settings.DATABASE_NAME``, False
    otherwise.

    Query has been taken from south.db.mysql.DatabaseOperations.delete_column
    """
    fk_name = truncate_name(
        _gen_fk_name(table_name, field_name, ext_table_name, ext_field_name),
        connection.ops.max_name_length(),
    )
    c.execute(sql_fk + ' AND tc.constraint_name=%s', (
        settings.DATABASE_NAME,
        table_name,
        field_name,
        fk_name,
    ))
    # Look at the fetched row rather than at execute()'s return value,
    # which DB-API 2.0 leaves undefined.
    return c.fetchone() is not None


def _gen_fk_name(table_name, field_name, ext_table_name, ext_field_name):
    """
    Alg has been taken from django.db.backends.creation.BaseDatabaseCreation
    """
    return '%s_refs_%s_%x' % (field_name, ext_field_name, abs(hash((table_name, ext_table_name)) % 4294967296L))


def migrate_field(obj, field, c):
    users_staff_id = getattr(obj, field.column)
    if users_staff_id:
        try:
            if field.rel.to.__name__ == 'Staff':
                login = get_login_by_users_staff_id(users_staff_id, c)
                print 'Login:', login
                intranet_id = field.rel.to.objects.get(login=login).id
            else:
                print 'Column name:', field.column
                print 'users_staff_id:', users_staff_id
                from_staff_id = get_from_staff_id(users_staff_id, field, c)
                print 'from_staff_id:', from_staff_id
                intranet_id = (field.rel.to.objects
                               .get(from_staff_id=from_staff_id).id)
        except field.rel.to.DoesNotExist:
            intranet_id = None

        if intranet_id != users_staff_id:
            setattr(obj, field.column, intranet_id)


def get_from_staff_id(users_staff_id, field, c):
    """
    Return ``from_staff_id`` of the row with id *users_staff_id* in the
    'users_...' table that corresponds to *field*'s related model.

    The table name is the related model's table name with everything up to
    the first underscore replaced by 'users'.  *users_staff_id* must be an
    integer (it is formatted with ``:d``).
    """
    related_table = field.rel.to._meta.db_table
    # Same as dropping the first '_'-separated component and re-joining.
    suffix = related_table.partition('_')[2]
    c.execute(
        'SELECT from_staff_id FROM {table} WHERE id={id:d}'.format(
            table='users_' + suffix,
            id=users_staff_id,
        )
    )
    [from_staff_id] = c.fetchone()
    return from_staff_id


def get_login_by_users_staff_id(users_staff_id, c):
    """
    Return the ``login`` of the ``users_staff`` row whose id is
    *users_staff_id*, read through cursor *c*.

    *users_staff_id* must be an integer (it is formatted with ``:d``).
    """
    c.execute(
        'SELECT login FROM users_staff WHERE id={id:d}'.format(id=users_staff_id)
    )
    [login] = c.fetchone()
    return login


def relates_dis(field):
    """
    Tell whether *field* is a ForeignKey whose target is one of the models
    of the 'django_intranet_stuff' application.
    """
    if not isinstance(field, ForeignKey):
        return False
    return field.rel.to in models_of('django_intranet_stuff')


def models_of(app_name):
    """
    Yield the model classes found in the module '<app_name>.models'.

    An attribute of that module is yielded when it is truthy, has a truthy
    ``_meta``, is a subclass of ``models.Model`` and its string form
    mentions '<app_name>.models'.  Nothing is yielded when the import
    fails with ImportError.
    """
    module_path = app_name + '.models'
    try:
        module = __import__(module_path, {}, {}, ['models'])
        for attr_name in dir(module):
            candidate = getattr(module, attr_name, None)
            # Quite a check...
            if not candidate:
                continue
            if not getattr(candidate, '_meta', None):
                continue
            if not issubclass(candidate, models.Model):
                continue
            if str(candidate).find(module_path) > 0:
                yield candidate
    except ImportError:
        return
