package ru.yandex.direct.mysql.schema;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonGetter;
import com.fasterxml.jackson.annotation.JsonProperty;

import ru.yandex.direct.mysql.MySQLUtils;

/**
 * Snapshot of a MySQL database schema: its {@code CREATE DATABASE} statement, its tables and its stored
 * routines (functions and procedures).
 *
 * <p>A snapshot is taken from a live connection with {@link #dump(Connection, String)}, is (de)serializable
 * via Jackson, and can be replayed against a server with {@link #restore(Connection)}. The only mutator is
 * {@link #rename(String)}; note that it changes the values used by {@link #equals(Object)} and
 * {@link #hashCode()}.
 */
public class DatabaseSchema {
    /** Routine kinds collected by {@link #dump}, in the order they are dumped. Never exposed directly. */
    private static final RoutineType[] ROUTINE_TYPES = new RoutineType[]{RoutineType.FUNCTION, RoutineType.PROCEDURE};

    /**
     * Locates the database name inside a {@code CREATE DATABASE} statement.
     *
     * <p>Group 1 captures the {@code CREATE DATABASE} keywords plus any C-style version comments that precede
     * the name. The name is matched either as a complete backtick-quoted identifier (which may contain
     * whitespace and doubled backticks) or, failing that, as a run of non-whitespace characters.
     */
    private static final Pattern CREATE_DATABASE_NAME = Pattern.compile(
            "^(CREATE\\s+DATABASE\\s+(/\\*.*?\\*/\\s*)*)(?:`(?:[^`]|``)*`|[^\\s]+)\\s+",
            Pattern.CASE_INSENSITIVE
    );

    private String name;
    private String createSql;
    private final List<TableSchema> tables;
    private final List<RoutineSchema> routines;

    /**
     * @param name      database name
     * @param createSql {@code CREATE DATABASE} statement for the database
     * @param tables    table schemas (stored as given, not copied)
     * @param routines  routine schemas (stored as given, not copied)
     * @throws NullPointerException if any argument is {@code null}
     */
    @JsonCreator
    public DatabaseSchema(@JsonProperty("name") String name, @JsonProperty("create_sql") String createSql,
                          @JsonProperty("tables") List<TableSchema> tables, @JsonProperty("routines") List<RoutineSchema> routines) {
        this.name = Objects.requireNonNull(name, "name");
        this.createSql = Objects.requireNonNull(createSql, "createSql");
        this.tables = Objects.requireNonNull(tables, "tables");
        this.routines = Objects.requireNonNull(routines, "routines");
    }

    /**
     * Returns the routine kinds that {@link #dump} collects.
     *
     * @return a fresh copy on each call, so callers cannot alter the dump behaviour by mutating it
     */
    public static RoutineType[] getRoutineTypes() {
        return ROUTINE_TYPES.clone();
    }

    @JsonGetter("name")
    public String getName() {
        return name;
    }

    @JsonGetter("create_sql")
    public String getCreateSql() {
        return createSql;
    }

    /** Returns the live (mutable) list of table schemas held by this object. */
    @JsonGetter("tables")
    public List<TableSchema> getTables() {
        return tables;
    }

    /** Returns the live (mutable) list of routine schemas held by this object. */
    @JsonGetter("routines")
    public List<RoutineSchema> getRoutines() {
        return routines;
    }

    /**
     * Renames the database, rewriting the name inside {@link #getCreateSql() createSql} to match.
     * Does nothing if {@code name} equals the current name.
     *
     * <p>The update is all-or-nothing: if createSql cannot be patched, neither field is modified.
     *
     * @param name new database name
     * @throws NullPointerException  if {@code name} is {@code null}
     * @throws IllegalStateException if createSql does not start with a recognizable
     *                               {@code CREATE DATABASE <name>} prefix
     */
    public void rename(String name) {
        Objects.requireNonNull(name, "name");
        if (name.equals(this.name)) {
            return;
        }
        // quoteReplacement: '$' and '\' are legal inside MySQL identifiers but are special characters in
        // a replaceFirst replacement string (group reference / escape).
        String patched = CREATE_DATABASE_NAME
                .matcher(this.createSql)
                .replaceFirst("$1" + Matcher.quoteReplacement(MySQLUtils.quoteName(name)) + " ");
        if (patched.equals(this.createSql)) {
            throw new IllegalStateException("Failed to patch createSql: " + patched);
        }
        // Assign only after the patch succeeded, so a failure leaves the object consistent.
        this.name = name;
        this.createSql = patched;
    }

    /**
     * Reads the complete schema of {@code databaseName} from {@code conn}.
     *
     * <p>The connection's catalog is switched to {@code databaseName} while dumping; the {@code CatalogGuard}
     * presumably restores the previous catalog on close (TODO confirm in MySQLUtils).
     *
     * @param conn         connection to a MySQL server
     * @param databaseName database to dump
     * @return the dumped schema
     * @throws SQLException          on any database error
     * @throws IllegalStateException if {@code SHOW CREATE DATABASE} returns no row
     */
    public static DatabaseSchema dump(Connection conn, String databaseName) throws SQLException {
        try (MySQLUtils.CatalogGuard ignored = new MySQLUtils.CatalogGuard(conn)) {
            return dumpInternal(conn, databaseName);
        }
    }

    private static DatabaseSchema dumpInternal(Connection conn, String databaseName) throws SQLException {
        conn.setCatalog(databaseName);
        String createSql = readCreateDatabaseSql(conn, databaseName);

        // All names are fully read (and their result sets closed) before the per-object dumps run, so no
        // result set stays open while TableSchema/RoutineSchema issue their own queries.
        List<TableSchema> tables = new ArrayList<>();
        for (String tableName : queryStrings(conn, "SHOW TABLES", 1)) {
            tables.add(TableSchema.dump(conn, tableName));
        }

        List<RoutineSchema> routines = new ArrayList<>();
        for (RoutineType routineType : ROUTINE_TYPES) {
            // Column 2 of SHOW FUNCTION/PROCEDURE STATUS is the routine name.
            List<String> routineNames =
                    queryStrings(conn, "SHOW " + routineType + " STATUS WHERE Db = ?", 2, databaseName);
            for (String routineName : routineNames) {
                routines.add(RoutineSchema.dump(conn, routineName, routineType));
            }
        }
        return new DatabaseSchema(databaseName, createSql, tables, routines);
    }

    /** Returns the {@code CREATE DATABASE} statement for {@code databaseName} (column 2 of the first row). */
    private static String readCreateDatabaseSql(Connection conn, String databaseName) throws SQLException {
        // Identifiers cannot be bound as '?' parameters, so the name is quoted instead.
        String sql = "SHOW CREATE DATABASE IF NOT EXISTS " + MySQLUtils.quoteName(databaseName);
        String createSql = null;
        try (PreparedStatement stmt = conn.prepareStatement(sql);
             ResultSet rs = stmt.executeQuery()) {
            if (rs.next()) {
                createSql = rs.getString(2);
            }
        }
        if (createSql == null) {
            throw new IllegalStateException("Cannot get sql for CREATE DATABASE statement");
        }
        return createSql;
    }

    /** Runs {@code sql}, binding {@code params} in order, and collects column {@code columnIndex} of every row. */
    private static List<String> queryStrings(Connection conn, String sql, int columnIndex, String... params)
            throws SQLException {
        List<String> values = new ArrayList<>();
        try (PreparedStatement stmt = conn.prepareStatement(sql)) {
            for (int i = 0; i < params.length; i++) {
                stmt.setString(i + 1, params[i]);
            }
            try (ResultSet rs = stmt.executeQuery()) {
                while (rs.next()) {
                    values.add(rs.getString(columnIndex));
                }
            }
        }
        return values;
    }

    /**
     * Creates the database by executing createSql, then restores every table and every routine inside it.
     * The connection's catalog is switched to this database for the duration.
     *
     * @param conn connection to the target server
     * @throws SQLException on any database error
     */
    public void restore(Connection conn) throws SQLException {
        MySQLUtils.executeUpdate(conn, createSql);
        try (MySQLUtils.CatalogGuard ignored = new MySQLUtils.CatalogGuard(conn)) {
            conn.setCatalog(name);
            for (TableSchema table : tables) {
                table.restore(conn);
            }
            for (RoutineSchema routine : routines) {
                routine.restore(conn);
            }
        }
    }

    /**
     * Calls {@code restoreViews} on every table schema, with the connection's catalog switched to this database.
     *
     * @param conn connection to the target server
     * @throws SQLException on any database error
     */
    public void restoreViews(Connection conn) throws SQLException {
        try (MySQLUtils.CatalogGuard ignored = new MySQLUtils.CatalogGuard(conn)) {
            conn.setCatalog(name);
            for (TableSchema table : tables) {
                table.restoreViews(conn);
            }
        }
    }

    /**
     * Compares two schemas. Name, createSql and routines are compared with {@code equals}; tables are
     * compared pairwise, in order, with {@link TableSchema#schemaEquals}.
     *
     * @param that schema to compare with; {@code null} yields {@code false}
     * @return {@code true} if the schemas match
     */
    public boolean schemaEquals(DatabaseSchema that) {
        if (that == null) {
            return false;
        }
        if (!name.equals(that.name)) {
            return false;
        }
        if (!createSql.equals(that.createSql)) {
            return false;
        }
        if (tables.size() != that.tables.size()) {
            return false;
        }
        for (int i = 0; i < tables.size(); i++) {
            if (!tables.get(i).schemaEquals(that.tables.get(i))) {
                return false;
            }
        }
        return routines.equals(that.routines);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof DatabaseSchema)) {
            return false;
        }

        DatabaseSchema that = (DatabaseSchema) o;

        if (!name.equals(that.name)) {
            return false;
        }
        if (!createSql.equals(that.createSql)) {
            return false;
        }
        if (!tables.equals(that.tables)) {
            return false;
        }
        return routines.equals(that.routines);
    }

    /** NOTE: depends on name/createSql, which {@link #rename} mutates; avoid renaming while used as a hash key. */
    @Override
    public int hashCode() {
        int result = name.hashCode();
        result = 31 * result + createSql.hashCode();
        result = 31 * result + tables.hashCode();
        result = 31 * result + routines.hashCode();
        return result;
    }
}
