# coding: utf-8

import logging
import os
import random
import subprocess
import time as os_time

from django.db import connection, DatabaseError, transaction, connections
from django.utils.encoding import smart_unicode

from common.data_api.file_wrapper.config import get_wrapper_creator
from common.data_api.file_wrapper.registry import FileType
from common.db.switcher import get_connection_by_role
from common.utils.dump import load_dump_to_database, purge_database

log = logging.getLogger(__name__)


def copy_table(cursor, src_table, dst_table, commit=False):
    """
    Create ``dst_table`` with the structure of ``src_table`` and copy all its rows into it.

    Table names are interpolated verbatim into the SQL, so they must be trusted identifiers
    (quoted by the caller if needed).

    :param cursor: DB-API cursor used to run the statements.
    :param src_table: source table name.
    :param dst_table: name of the table to create.
    :param commit: when True, an explicit ``COMMIT`` is executed after the copy.
    """
    statements = [
        'CREATE TABLE %s LIKE %s' % (dst_table, src_table),
        "INSERT INTO %s SELECT * FROM %s" % (dst_table, src_table),
    ]
    if commit:
        statements.append("COMMIT")

    for statement in statements:
        cursor.execute(statement)


def analyze_all_tables(db_alias):
    """
    Run ``ANALYZE TABLE`` for every table reported by ``SHOW TABLES``.

    :param db_alias: Django database alias whose connection is used.
    """
    db_connection = connections[db_alias]
    cursor = db_connection.cursor()
    cursor.execute('SHOW TABLES')
    table_rows = cursor.fetchall()
    for (table_name,) in table_rows:
        quoted_name = db_connection.ops.quote_name(table_name)
        cursor.execute('ANALYZE TABLE {}'.format(quoted_name))


class MysqlFileWriter(object):
    """
    Write rows to ``stream`` in a format suitable for MySQL ``LOAD DATA LOCAL INFILE``.

    Values are encoded with ``LoadInFileHelper.transform_value`` (double-quoted and escaped,
    or the bare word NULL), fields are separated by a tab and rows by ``'\\n'``.
    Load the result with the following SQL::

        LOAD DATA LOCAL INFILE 'file_name'
        INTO TABLE tbl_name
        CHARACTER SET utf8
        FIELDS ENCLOSED BY '"'
        LINES TERMINATED BY '\\n'
        [(col_name_or_user_var,...)]
        [SET col_name = expr,...]

    The file must not end with an empty line, so the row separator is written only
    between rows, never after the last one.
    """

    def __init__(self, stream, fields):
        self.stream = stream
        # Column order of every written row.
        self.fields = fields
        # False until the first row is written; drives the between-rows separator logic.
        self.has_rows = False

    def writedict(self, rowdict):
        """Write one row taking values from ``rowdict`` in ``self.fields`` order (KeyError if a field is missing)."""
        row = [rowdict[name] for name in self.fields]

        self.writerow(row)

    def writerow(self, row):
        """
        Write one row given as a sequence of values in ``self.fields`` order.

        :raises ValueError: if ``row`` does not have exactly ``len(self.fields)`` values.
        """
        if len(row) != len(self.fields):
            raise ValueError('Length of row must be %s' % len(self.fields))

        output = '\t'.join(LoadInFileHelper.transform_value(value) for value in row)
        if self.has_rows:
            self.stream.write('\n')
        else:
            self.has_rows = True

        self.stream.write(output)

        # FIXME: this flush can probably be removed; it is most likely over-caution.
        self.stream.flush()

    def close(self):
        """Close the underlying stream."""
        self.stream.close()


class MysqlFileReader(object):
    """
    Read back rows written by ``MysqlFileWriter``.

    Iterating yields one dict per line mapping ``fields`` to the decoded values
    (``None`` for NULL). Reading stops at end of stream or at the first blank line.
    """

    def __init__(self, stream, fields):
        self.stream = stream
        # Column names, in the order the values appear on each line.
        self.fields = fields

    def __iter__(self):
        return self

    def next(self):
        """Return the next row as a dict; raise StopIteration when there are no more rows."""
        rowdict = self.readdict()
        if rowdict is None:
            raise StopIteration()

        return rowdict

    # Python 3 iterator protocol name; Python 2 uses ``next``.
    __next__ = next

    def readdict(self):
        """Return the next row as a ``{field: value}`` dict, or None when there are no more rows."""
        row = self.readrow()
        if row is None:
            return None

        return dict(zip(self.fields, row))

    def readrow(self):
        """
        Return the next row as a list of decoded values, or None at end of input.

        :raises ValueError: if a value is malformed or the number of values differs from ``fields``.
        """
        line = self.stream.readline().strip()
        if not line:
            return None

        # Tabs and newlines inside values are escaped by the writer, so a raw tab always
        # separates fields.
        row = [LoadInFileHelper.simple_restore_value(v) for v in line.split('\t')]

        if len(row) != len(self.fields):
            raise ValueError('Length of row is not equal to length of fields')

        return row


class LoadInFileHelper(object):
    """
    Encode and decode single field values for MySQL ``LOAD DATA INFILE`` files.

    An encoded value is either the bare word ``NULL`` or a double-quoted, backslash-escaped
    UTF-8 string.
    """

    @classmethod
    def quote(cls, value):
        """Escape ``value`` and wrap it in double quotes."""
        return '"' + cls.escape(value) + '"'

    # (raw character, escape sequence) pairs. The backslash must be processed first so the
    # backslashes introduced by the later replacements are not escaped a second time.
    replace_map = (
        ('\\', r'\\'),
        ('\x00', r'\0'),
        ('\b', r'\b'),
        ('\n', r'\n'),
        ('\r', r'\r'),
        ('\t', r'\t'),
        ('\x1a', r'\Z'),  # \Z ASCII 26 (Control+Z)
        ('"', r'\"')
    )

    # Character following a backslash -> raw character it stands for.
    reverse_char_map_dict = {
        to[1:]: from_ for from_, to in replace_map
    }

    @classmethod
    def escape(cls, value):
        """Return ``value`` with every character from ``replace_map`` replaced by its escape sequence."""
        out_value = value

        for target, replacement in cls.replace_map:
            out_value = out_value.replace(target, replacement)

        return out_value

    @classmethod
    def transform_value(cls, value):
        """
        Encode a Python value as one field of a LOAD DATA file.

        ``None`` becomes the bare word NULL. Byte strings are decoded as UTF-8, booleans are
        written as 0/1, and anything else goes through ``unicode()``. The text is then
        UTF-8 encoded, escaped and double-quoted.
        """
        if value is None:
            return 'NULL'

        if isinstance(value, str):
            value = unicode(value, encoding='utf8')
        elif isinstance(value, bool):
            # bool is checked separately so True/False are written as 1/0, not 'True'/'False'.
            value = unicode(int(value))
        elif not isinstance(value, unicode):
            value = unicode(value)

        return cls.quote(value.encode('utf8'))

    @classmethod
    def simple_restore_value(cls, value):
        """
        Decode one field produced by ``transform_value``.

        :returns: None for NULL, otherwise the unescaped value decoded from UTF-8.
        :raises ValueError: if the value is neither NULL nor a double-quoted string, or
            contains a dangling escape.
        """
        if value == 'NULL':
            return None

        # len check: a lone '"' both starts and ends with a quote but is not a quoted string.
        if len(value) < 2 or not (value.startswith('"') and value.endswith('"')):
            raise ValueError('Value must be doublequoted')

        return cls.unescape(value[1:-1]).decode('utf-8')

    @classmethod
    def unescape(cls, value):
        """
        Reverse ``escape``.

        An unknown escape such as ``\\q`` yields the character itself (``q``).

        :raises ValueError: if ``value`` ends with a lone backslash.
        """
        if '\\' not in value:
            return value

        parts = []
        index = 0
        while index < len(value):
            back_slash_index = value.find('\\', index)

            if back_slash_index == -1:
                parts.append(value[index:])
                break

            if back_slash_index > index:
                parts.append(value[index:back_slash_index])

            try:
                next_char = value[back_slash_index + 1]
            except IndexError:
                raise ValueError("Can't unescape %r" % value)

            out_char = cls.reverse_char_map_dict.get(next_char, next_char)
            parts.append(out_char)
            index = back_slash_index + 2

        return ''.join(parts)


class ExecuteAndCheckMixIn(object):
    """
    Helpers that execute SQL and then validate the status message MySQL reports for it.
    """

    @classmethod
    def check_message(cls, connection, cursor, ignore_warning_filter=None, ignore_skipped=False):
        """
        Check the status message MySQL returned for the last statement.

        Typical messages::

            Records: 0  Deleted: 0  Skipped: 0  Warnings: 0
            Rows matched: 10000  Changed: 10000  Warnings: 0

        :param connection: Django connection; its raw ``connection.connection`` must provide
            ``info()`` (presumably a MySQLdb connection — confirm).
        :param cursor: cursor used to run ``show warnings`` when warnings are reported.
        :param ignore_warning_filter: optional callable receiving one warning line (the
            ``show warnings`` columns joined with tabs); returning True means the warning is
            only logged instead of failing.
        :param ignore_skipped: when True, skipped rows are only logged instead of raising.
        :raises DatabaseError: if rows were skipped (and not ignored) or a warning was not
            accepted by ``ignore_warning_filter``.
        """

        mysql_message = connection.connection.info()

        if not mysql_message:
            log.debug(u'Пустое сообщение от mysql')
            return

        log.debug(mysql_message)

        # Splitting "Name: N  Next: M" on ':' yields the first name, then pieces of the form
        # "N  Next" (value of the current name followed by the next name), and finally "M".
        status = {}
        parts = mysql_message.strip().split(':')

        name = parts[0].strip().lower()
        for part in parts[1:]:
            part = part.strip()

            if not part.isdigit():
                value, next_name = part.split(' ', 1)
                value = int(value.strip())

                status[name] = value
                name = next_name.strip().lower()
            else:
                status[name] = int(part)

        if status.get('skipped', 0) > 0:
            message = u'При загрузке некоторые записи были пропущены'
            if ignore_skipped:
                log.warning(message)
            else:
                log.error(message)
                raise DatabaseError('Some rows were skipped')

        if status.get('warnings', 0) > 0:
            log.warning(u'При загрузке поймали warnings')
            cursor.execute('show warnings')
            bad_warnings = []
            for row in cursor.fetchall():
                warning_text = u'\t'.join(smart_unicode(c) for c in row)
                if ignore_warning_filter and ignore_warning_filter(warning_text):
                    log.warning(warning_text)
                else:
                    log.error(warning_text)
                    bad_warnings.append(warning_text)

            if bad_warnings:
                raise DatabaseError('Some bad warnings were gotten:\n{}'.format(
                    '\n'.join(bad_warnings)
                ))

    @classmethod
    def execute_and_check(cls, connection, cursor, sql, params=None, ignore_warning_filter=None, ignore_skipped=False):
        """
        Execute ``sql`` (with optional ``params``) on ``cursor`` and validate MySQL's status message.

        The remaining keyword arguments are forwarded to ``check_message``.
        """
        mysqlconn = connection.connection
        if params:
            # literal() rendering is only needed for the debug line; skip it when DEBUG is off.
            if log.isEnabledFor(logging.DEBUG):
                log.debug(sql, *[smart_unicode(mysqlconn.literal(param)) for param in params])

            cursor.execute(sql, params)
        else:
            log.debug(sql)
            cursor.execute(sql)

        cls.check_message(connection, cursor, ignore_warning_filter, ignore_skipped)


class MysqlModelUpdater(ExecuteAndCheckMixIn):
    """
    Bulk-update existing rows of a model through ``LOAD DATA LOCAL INFILE``.

    Objects added inside a ``with`` block are written to a data file. ``load()`` loads the
    file into a temporary table and copies the values into the model table with
    ``UPDATE ... JOIN`` on the primary key. Rows whose key is missing from the model table
    are therefore not touched.

    :param model_class: Django model whose table is updated.
    :param working_dir: directory under which intermediate data files are created.
    :param fields: names of model fields to update (all local fields by default); the
        primary key is always included.
    :param load_on_context_exit: call ``load()`` automatically on a clean exit from ``with``.
    :param tmp_attrs: extra object attributes written next to the columns; available while
        processing the file and removed before loading.
    """

    # Temporary-table column types for auto-increment primary keys, so the copy is created
    # without auto-increment.
    _AUTO_PK_COLUMN_TYPES = {
        'AutoField': 'int',
        'BigAutoField': 'bigint',
    }

    def __init__(self, model_class, working_dir, fields=None, load_on_context_exit=True,
                 tmp_attrs=None):
        self.model_class = model_class
        self.load_on_context_exit = load_on_context_exit
        self.opts = self.model_class._meta
        self.working_dir = working_dir
        self.tmp_attrs = tmp_attrs or []

        if fields is None:
            fields = list(self.model_class._meta.local_fields)
        else:
            fields = [self.opts.get_field(fname) for fname in fields]

        # The primary key is the join key of the UPDATE, so it is always written.
        if self.opts.pk not in fields:
            fields.insert(0, self.opts.pk)

        self.row_length = len(fields) + len(self.tmp_attrs)

        self.fields = fields

        self.column_names = [f.column for f in fields]
        self.all_column_names = self.column_names + list(self.tmp_attrs)

    def get_filepath(self, postfix='original'):
        """Create a fresh unique directory under ``working_dir`` and return a data file path inside it."""
        filepath = os.path.join(
            self.working_dir,
            '{}'.format(self.__class__.__name__.lower()),
            str(os_time.time()),
            str(random.random()),
            '{}_{}.mysql'.format(self.model_class.__name__.lower(), postfix)
        )

        os.makedirs(os.path.dirname(filepath))

        return filepath

    def init(self):
        """Open a new data file for writing."""
        self.filepath = self.get_filepath()

        self.mysql_writer_stream = open(self.filepath, 'wb')
        self.mysql_writer = MysqlFileWriter(self.mysql_writer_stream, fields=self.all_column_names)

    def add(self, obj):
        """Write the configured fields (and ``tmp_attrs``) of model instance ``obj``."""
        rowdict = {
            f.column: getattr(obj, f.attname)
            for f in self.fields
        }

        for attr in self.tmp_attrs:
            rowdict[attr] = getattr(obj, attr)

        self.mysql_writer.writedict(rowdict)

    def add_row(self, **kwargs):
        """Write one row given as keyword arguments keyed by column / tmp attribute name."""
        self.mysql_writer.writedict(kwargs)

    def add_dict(self, objdict):
        """Write one row given as a dict keyed by column / tmp attribute name."""
        self.mysql_writer.writedict(objdict)

    def __enter__(self):
        self.init()

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always close the data file so it is not leaked when the with-body failed.
        self.finalize()

        if exc_type is None and self.load_on_context_exit:
            self.load()

    def process(self):
        """
        Prepare the data file for loading; called by ``load()`` before anything is loaded.

        When ``tmp_attrs`` are used, the file is rewritten to a new one containing only
        ``column_names`` and ``self.filepath`` is switched to it. Otherwise the file is used
        as is. Subclasses may override this to transform rows.
        """
        if self.tmp_attrs:
            processed_filepath = self.get_filepath('processed')
            with open(processed_filepath, 'wb') as processed_file, open(self.filepath) as f:
                writer = MysqlFileWriter(processed_file, self.column_names)
                reader = MysqlFileReader(f, self.all_column_names)

                for rowdict in reader:
                    writer.writedict(rowdict)

            self.filepath = processed_filepath

    def finalize(self):
        """Close the data file."""
        self.mysql_writer_stream.close()

    def _get_create_tmp_table_sql(self, quoted_table_name):
        """Return CREATE TEMPORARY TABLE SQL with one column per field; the pk has no auto-increment."""
        qn = connection.ops.quote_name

        column_defs = []
        for f in self.fields:
            col_type = f.db_type(connection=connection)

            if f.primary_key:
                pk_type = self._AUTO_PK_COLUMN_TYPES.get(f.get_internal_type()) or col_type
                column_defs.append(' '.join([qn(f.column), pk_type, 'NOT NULL PRIMARY KEY']))
                continue

            field_output = [qn(f.column), col_type]
            if not f.null:
                field_output.append('NOT NULL')
            column_defs.append(' '.join(field_output))

        return u'\n'.join([
            'CREATE TEMPORARY TABLE ' + quoted_table_name + ' (',
            ',\n'.join('    ' + column_def for column_def in column_defs),
            ');',
        ])

    @transaction.atomic
    def load(self):
        """
        Apply the collected rows to the model table.

        Steps: ``process()``, create a temporary table, LOAD DATA into it, UPDATE the model
        table joined on the primary key, and drop the temporary table.

        :raises DatabaseError: via ``execute_and_check`` on skipped rows or bad warnings.
        """
        cursor = connection.cursor()

        self.process()

        qn = connection.ops.quote_name

        update_fields = [f for f in self.fields if f != self.opts.pk]
        if not update_fields:
            # Only the pk was requested: an UPDATE with an empty SET list would be invalid SQL.
            log.warning('MysqlModelUpdater: nothing to update in %s besides the primary key',
                        self.opts.db_table)
            return

        table_name = '_'.join([
            self.opts.db_table,
            'tmp',
            str(os_time.time()).replace('.', '_').replace(',', '_'),
            str(random.randint(1, 1000))
        ])
        quoted_table_name = qn(table_name)

        self.execute_and_check(connection, cursor, self._get_create_tmp_table_sql(quoted_table_name))

        load_stmt = '''
LOAD DATA LOCAL INFILE %s
    INTO TABLE {table_name}
    CHARACTER SET utf8
    FIELDS ENCLOSED BY '"'
    LINES TERMINATED BY '\n'
    ({columns})
'''
        load_stmt = load_stmt.format(
            table_name=quoted_table_name,
            columns=u' ,'.join(map(qn, self.column_names))
        )

        self.execute_and_check(connection, cursor, load_stmt, [self.filepath])

        update_stmnt = '''
UPDATE {base_table} bt JOIN {table_name} t ON t.{pk_name} = bt.{pk_name} SET {set_list}
'''

        update_stmnt = update_stmnt.format(
            table_name=quoted_table_name,
            base_table=qn(self.opts.db_table),
            pk_name=qn(self.opts.pk.column),
            set_list=', '.join('bt.{name} = t.{name}'.format(name=qn(f.column))
                               for f in update_fields)
        )

        self.execute_and_check(connection, cursor, update_stmnt)

        # Temporary tables live until the session ends. Drop it so long-lived connections do not
        # accumulate them. DROP TEMPORARY TABLE does not implicitly commit the transaction.
        cursor.execute('DROP TEMPORARY TABLE {}'.format(quoted_table_name))


class MysqlModelLoader(ExecuteAndCheckMixIn):
    """
    Bulk-insert model rows through ``LOAD DATA LOCAL INFILE``.

    Objects added inside a ``with`` block are written to a data file, which ``load()`` then
    loads directly into the model table.

    :param model_class: Django model whose table receives the rows.
    :param working_dir: directory under which the data file is created.
    :param load_on_context_exit: call ``load()`` automatically on a clean exit from ``with``.
    :param on_duplicate: ``'replace'`` adds REPLACE and ``'ignore'`` adds IGNORE to the
        LOAD DATA statement; any other value adds no modifier.
    :param ignore_warning_filter: forwarded to ``check_message``.
    :param ignore_skipped: forwarded to ``check_message``.
    """

    def __init__(self, model_class, working_dir, load_on_context_exit=True, on_duplicate=None,
                 ignore_warning_filter=None, ignore_skipped=False):
        self.model_class = model_class
        self.load_on_context_exit = load_on_context_exit
        self.opts = self.model_class._meta
        self.working_dir = working_dir
        self.on_duplicate = on_duplicate
        self.ignore_warning_filter = ignore_warning_filter
        self.ignore_skipped = ignore_skipped

        fields = list(self.model_class._meta.local_fields)

        self.fields = fields

        self.column_names = [f.column for f in fields]

    def init(self):
        """Create a fresh unique directory under ``working_dir`` and open the data file in it."""
        self.filepath = os.path.join(
            self.working_dir,
            '{}'.format(self.__class__.__name__.lower()),
            str(os_time.time()),
            str(random.random()),
            '{}.mysql'.format(self.model_class.__name__.lower())
        )

        os.makedirs(os.path.dirname(self.filepath))

        self.mysql_writer_stream = open(self.filepath, 'wb')
        self.mysql_writer = MysqlFileWriter(self.mysql_writer_stream, fields=self.column_names)

    def add(self, obj):
        """Write all local fields of model instance ``obj``, using each field's ``pre_save(obj, add=True)`` value."""
        self.mysql_writer.writedict({
            f.column: f.pre_save(obj, add=True)
            for f in self.fields
        })

    def add_row(self, **kwargs):
        """Write one row given as keyword arguments keyed by column name."""
        self.mysql_writer.writedict(kwargs)

    def add_dict(self, objdict):
        """Write one row given as a dict keyed by column name."""
        self.mysql_writer.writedict(objdict)

    def __enter__(self):
        self.init()

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always close the data file so it is not leaked when the with-body failed.
        self.finalize()

        if exc_type is None and self.load_on_context_exit:
            self.load()

    def finalize(self):
        """Close the data file."""
        self.mysql_writer_stream.close()

    @transaction.atomic
    def load(self):
        """
        LOAD DATA the written file into the model table.

        :raises DatabaseError: via ``execute_and_check`` on skipped rows or bad warnings.
        """
        cursor = connection.cursor()

        qn = connection.ops.quote_name

        table_name = self.opts.db_table
        if self.on_duplicate == 'replace':
            options = 'REPLACE'
        elif self.on_duplicate == 'ignore':
            options = 'IGNORE'
        else:
            options = ''

        load_stmt = '''
LOAD DATA LOCAL INFILE %s
    {options}
    INTO TABLE {table_name}
    CHARACTER SET utf8
    FIELDS ENCLOSED BY '"'
    LINES TERMINATED BY '\n'
    ({columns})
'''
        load_stmt = load_stmt.format(
            table_name=qn(table_name),
            columns=u' ,'.join(map(qn, self.column_names)),
            options=options
        )

        self.execute_and_check(connection, cursor, load_stmt, [self.filepath],
                               ignore_warning_filter=self.ignore_warning_filter,
                               ignore_skipped=self.ignore_skipped)


def kill_all_connections_to_db(db_name):
    """
    Kill every MySQL session using ``db_name`` except the current one, then close all
    Django connections.

    Sessions are taken from ``SHOW PROCESSLIST``; column 0 is the Id and column 3 the db::

        mysql> SHOW PROCESSLIST;
        +--------+------------+-----------+-----------+---------+------+-------+------------------+
        | Id     | User       | Host      | db        | Command | Time | State | Info             |
        +--------+------------+-----------+-----------+---------+------+-------+------------------+
        | 145572 | schuprakov | localhost | schup_dev | Query   |    0 | NULL  | show processlist |
        +--------+------------+-----------+-----------+---------+------+-------+------------------+

    Killing is best-effort: a failed KILL is logged and the remaining sessions are still processed.

    :param db_name: Database Name
    """
    cursor = connection.cursor()
    cursor.execute('SELECT CONNECTION_ID()')
    current_connection_id = cursor.fetchone()[0]
    cursor.execute('SHOW PROCESSLIST')
    connection_ids = [
        row[0] for row in cursor.fetchall()
        if row[3] == db_name and row[0] != current_connection_id
    ]

    for cid in connection_ids:
        try:
            cursor.execute('KILL %s', [cid])
        except DatabaseError as exc:
            # The session may already have ended since SHOW PROCESSLIST; keep going.
            log.warning('Failed to kill MySQL connection %s: %r', cid, exc)

    for c in connections.all():
        c.close()


def get_mysql_conn_string(conn_params, options=""):
    """
    Build the mysql/mysqldump command-line argument string from connection parameters.

    ``host``/``port``/``user``/``passwd`` become ``--host``/``--port``/``--user``/``--password``
    when present. ``options`` is appended verbatim (it may hold several flags), and the ``db``
    value comes last. Parameter values are shell-quoted because the result is executed
    through a shell; ordinary values produce exactly the same arguments as before.

    NOTE(review): the password ends up on the command line, where other local users may see
    it in the process list; an option file (``--defaults-extra-file``) would avoid that.

    :param conn_params: dict with at least ``db``.
    :param options: extra command-line flags, inserted unquoted before the database name.
    :returns: space-separated argument string.
    """
    try:
        from shlex import quote as shell_quote
    except ImportError:  # Python 2
        from pipes import quote as shell_quote

    def _as_arg(value):
        # Non-string values (e.g. an int port) are converted; strings are passed as is.
        if not isinstance(value, (str, type(u''))):
            value = str(value)
        return shell_quote(value)

    conn_string_parts = []
    for param, cmd_arg in [('host', '--host'), ('port', '--port'), ('user', '--user'), ('passwd', '--password')]:
        if param in conn_params:
            conn_string_parts.append("{}={}".format(cmd_arg, _as_arg(conn_params[param])))

    if options:
        conn_string_parts.append(options)

    conn_string_parts.append(_as_arg(conn_params['db']))

    return ' '.join(conn_string_parts)


def can_set_gtid_purged_off():
    """
    Tell whether the installed mysqldump supports ``--set-gtid-purged``.

    The check greps ``mysqldump --help`` output; command output is discarded.

    :returns: True if the flag appears in the help text, False otherwise.
    """
    with open('/dev/null', 'w') as sink:
        exit_code = subprocess.call(
            'mysqldump --help | grep set-gtid-purged', shell=True, stdout=sink, stderr=sink
        )

    if exit_code == 0:
        return True

    log.info(u'Flag --set-gtid-purged is not supported')
    return False


def dump_db_by_alias(role, output_file, schema_only=False):
    """
    Dump the MySQL database of connection role ``role`` with mysqldump into gzip file ``output_file``.

    :param role: role passed to ``get_connection_by_role`` to select the source database.
    :param output_file: path of the gzip-compressed dump; an existing file is overwritten.
    :param schema_only: dump table definitions only, except ``django_migrations``, which is
        dumped together with its rows.
    :raises Exception: if any command of the pipeline exits non-zero (``set -o pipefail``).
        The original ``CalledProcessError`` is deliberately not propagated because its command
        line contains the MySQL connection string.
    """
    try:
        from shlex import quote as shell_quote
    except ImportError:  # Python 2
        from pipes import quote as shell_quote

    conn_params = get_connection_by_role(role).conn_params

    options = "--extended-insert --create-options --max-allowed-packet=128M --net-buffer-length=8M"
    if can_set_gtid_purged_off():
        options += " --set-gtid-purged=OFF"

    if schema_only:
        schema_options = options + ' --no-data --ignore-table={}.{}'.format(conn_params['db'], 'django_migrations')
        # Both dumps are grouped in "{ ...; }" so their combined output goes through one gzip.
        dump_calls = "{ mysqldump " + get_mysql_conn_string(conn_params, schema_options) + ";\n"

        conn_str = get_mysql_conn_string(conn_params, options)
        dump_calls += "mysqldump " + conn_str + " {table}; }}".format(table='django_migrations')
    else:
        conn_str = get_mysql_conn_string(conn_params, options)
        dump_calls = "mysqldump " + conn_str

    command = "(set -o pipefail; {dump_calls} | gzip -c > {output_file})".format(
        output_file=shell_quote('{}'.format(output_file)), dump_calls=dump_calls)

    error_message = None
    try:
        subprocess.check_output(
            command,
            shell=True,
            stderr=subprocess.STDOUT,
            executable='/bin/bash'
        )
    except subprocess.CalledProcessError as ex:
        # Do not let the exception (its cmd holds the MySQL connection string) reach the logs.
        # errors='replace' keeps a non-UTF-8 mysqldump message from masking the failure.
        error_message = u'mysqldump returned non-zero exit status {}. {}'.format(
            ex.returncode, smart_unicode(ex.output, errors='replace'))

    # Raised outside the except block so Python 3 does not attach the CalledProcessError as context.
    if error_message is not None:
        raise Exception(error_message)


def load_dump(dump_filename):
    """
    Replace the contents of the default Django database with the MySQL dump ``dump_filename``.

    Other sessions on the database are killed, its tables are dropped with
    ``purge_database``, the dump is streamed in, and finally the default connection is reopened.

    :param dump_filename: path passed to the MYSQL_DUMP file wrapper.
    """
    target_db = connection.get_db_name()

    log.info(u'Closing all connections to dst database')
    kill_all_connections_to_db(target_db)

    log.info(u'Dropping tables from %s', target_db)
    purge_database(target_db)

    log.info(u'Loading %s to %s', dump_filename, target_db)

    wrapper_creator = get_wrapper_creator(FileType.MYSQL_DUMP)
    dump_wrapper = wrapper_creator.get_file_wrapper(dump_filename)
    with dump_wrapper.open() as dump_file:
        load_dump_to_database(dump_file, target_db)

    # After dropping and reloading, the connection can be left in the
    # 'No database selected' state, so reopen it explicitly.
    connection.close()
    connection.ensure_connection()
