package ru.yandex.mail.cerberus.asyncdb.internal;

import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import lombok.SneakyThrows;
import lombok.Value;
import lombok.val;
import one.util.streamex.StreamEx;
import org.jdbi.v3.core.generic.GenericTypes;
import org.jdbi.v3.core.mapper.Nested;
import org.jdbi.v3.core.mapper.reflect.ColumnName;
import org.jdbi.v3.sqlobject.customizer.SqlStatementCustomizer;
import org.jdbi.v3.sqlobject.customizer.SqlStatementCustomizerFactory;
import ru.yandex.mail.cerberus.asyncdb.Alias;
import ru.yandex.mail.cerberus.asyncdb.annotations.ConfigureCrudRepository;
import ru.yandex.mail.cerberus.asyncdb.annotations.Id;
import ru.yandex.mail.cerberus.asyncdb.annotations.Serial;

import javax.annotation.Nullable;
import java.beans.BeanInfo;
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.lang.annotation.Annotation;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;

import static java.util.function.Predicate.not;
import static ru.yandex.mail.cerberus.asyncdb.internal.RepositoryInfoUtils.findEntityType;

/**
 * {@link SqlStatementCustomizerFactory} that derives CRUD SQL fragments from a repository's entity type
 * and exposes them to statement templates as defined attributes ({@link Alias#TABLE}, {@link Alias#ID},
 * {@link Alias#ID_BINDINGS}, {@link Alias#ENTITY_COLUMN_LIST}, {@link Alias#INSERTION_QUERY},
 * {@link Alias#UPDATING_QUERY}).
 *
 * <p>The annotation handed to {@link #createForType} is cast to {@link ConfigureCrudRepository}; its
 * {@code table()} names the target table. Entity properties are discovered via JavaBeans introspection,
 * and the {@link Id}, {@link Serial}, {@link Nested}, {@link ColumnName} and {@link Nullable} annotations
 * on each getter control how a property is mapped.
 *
 * <p>A customizer is built once per SQL object type and cached in a static {@link ConcurrentHashMap}.
 */
public class CrudRepositoryStatementCustomizer implements SqlStatementCustomizerFactory {
    private static final Map<Class<?>, SqlStatementCustomizer> CUSTOMIZER_CACHE = new ConcurrentHashMap<>();

    /** Converts camelCase property names to snake_case column names when no {@link ColumnName} is given. */
    private static final PropertyNamingStrategy.SnakeCaseStrategy SNAKE_CASE_CONVERTER =
        (PropertyNamingStrategy.SnakeCaseStrategy) PropertyNamingStrategy.SNAKE_CASE;

    /**
     * One flattened entity property.
     *
     * <p>{@code name} is used as the bind name ({@code :name}) in generated queries; properties of
     * {@link Nested} beans get a dotted path. {@code columnName} is the database column, {@code id} marks a
     * primary-key component, and {@code serial} marks a column that falls back to a sequence value when the
     * bound value is null.
     */
    @Value
    private static class EntityProperty {
        String name;
        String columnName;
        boolean id;
        boolean serial;
    }

    /** Introspects {@code type}, stopping at {@link Object} so that {@code getClass()} is not reported as a property. */
    @SneakyThrows
    private static BeanInfo getBeanInfo(Class<?> type) {
        return Introspector.getBeanInfo(type, Object.class);
    }

    private static boolean isIdProperty(PropertyDescriptor descriptor) {
        return descriptor.getReadMethod().isAnnotationPresent(Id.class);
    }

    private static boolean isNestedProperty(PropertyDescriptor descriptor) {
        return descriptor.getReadMethod().isAnnotationPresent(Nested.class);
    }

    private static boolean isSerialProperty(PropertyDescriptor descriptor) {
        return descriptor.getReadMethod().isAnnotationPresent(Serial.class);
    }

    /** Returns the {@link ColumnName} value of the getter, or the snake_case form of the property name. */
    private static String resolveColumnName(PropertyDescriptor descriptor) {
        val annotation = descriptor.getReadMethod().getAnnotation(ColumnName.class);
        if (annotation == null) {
            return SNAKE_CASE_CONVERTER.translate(descriptor.getName());
        } else {
            return annotation.value();
        }
    }

    private static boolean isOptional(PropertyDescriptor propertyDescriptor) {
        return propertyDescriptor.getReadMethod().isAnnotationPresent(Nullable.class);
    }

    /**
     * Flattens the bean described by {@code beanInfo} into a list of entity properties, recursing into
     * {@link Nested} properties.
     *
     * @throws IllegalStateException if a {@link Nested} getter is also annotated with {@link Id} or {@link Serial}
     */
    private static List<EntityProperty> resolveEntityProperties(BeanInfo beanInfo) {
        return StreamEx.of(beanInfo.getPropertyDescriptors())
            // Write-only properties have no getter to read annotations or values from;
            // skip them instead of failing with a NullPointerException below.
            .filter(descriptor -> descriptor.getReadMethod() != null)
            .flatMap(CrudRepositoryStatementCustomizer::resolveProperty)
            .toImmutableList();
    }

    /** Maps one getter-backed property to a single entity property, or to the flattened properties of a nested bean. */
    private static Stream<EntityProperty> resolveProperty(PropertyDescriptor descriptor) {
        val name = descriptor.getName();
        val isId = isIdProperty(descriptor);
        val isNested = isNestedProperty(descriptor);
        val isSerial = isSerialProperty(descriptor);

        if (!isNested) {
            return Stream.of(new EntityProperty(name, resolveColumnName(descriptor), isId, isSerial));
        }
        if (isId) {
            throw new IllegalStateException("@Nested @Id properties not supported");
        }
        if (isSerial) {
            throw new IllegalStateException("@Nested @Serial properties not supported");
        }

        val nestedPrefix = descriptor.getReadMethod().getDeclaredAnnotation(Nested.class).value();
        val columnPrefix = nestedPrefix.isEmpty() ? "" : nestedPrefix + '_';
        // A '?' suffix is appended for @Nullable nested getters; presumably this relies on JDBI's
        // null-safe property navigation when binding — confirm against the JDBI version in use.
        val namePrefix = isOptional(descriptor) ? name + '?' : name;
        // Properties reached through a nested bean are never treated as id components.
        return StreamEx.of(resolveEntityProperties(getBeanInfo(descriptor.getPropertyType())))
            .map(nested -> new EntityProperty(
                namePrefix + '.' + nested.getName(),
                columnPrefix + nested.getColumnName(),
                false,
                nested.isSerial()
            ));
    }

    private static List<EntityProperty> findIdProperties(List<EntityProperty> properties) {
        return StreamEx.of(properties)
            .filter(EntityProperty::isId)
            .toImmutableList();
    }

    /**
     * Builds {@code INSERT INTO table (cols) VALUES (binds)}.
     *
     * <p>Serial properties are bound as {@code COALESCE(:name, nextval('<table>_<column>_seq'))}, so a null
     * value falls back to the sequence. This assumes the sequence follows that naming scheme (the PostgreSQL
     * default for SERIAL columns).
     */
    private static String generateInsertionQuery(List<EntityProperty> properties, String tableName) {
        val columnNames = StreamEx.of(properties)
            .map(EntityProperty::getColumnName)
            .joining(", ");
        val bindingNames = StreamEx.of(properties)
            .map(property -> {
                val bindName = ':' + property.getName();
                if (property.isSerial()) {
                    val sequenceName = tableName + '_' + property.getColumnName() + "_seq";
                    return "COALESCE(" + bindName + ", nextval('" + sequenceName + "'))";
                } else {
                    return bindName;
                }
            })
            .joining(", ");
        return "INSERT INTO " + tableName + " (" + columnNames + ") VALUES (" + bindingNames + ")";
    }

    /**
     * Builds an UPDATE that locks the matching rows through a {@code SELECT ... FOR UPDATE} subquery aliased
     * {@code oldRecords}, assigns every non-id column from its bind parameter, and returns the columns of
     * {@code oldRecords} (intended to be the row as it was before the update).
     *
     * <p>The generated SQL is only valid when there is at least one id property (otherwise the filter is
     * {@code WHERE ()}) and at least one non-id property (otherwise the SET list is empty).
     */
    private static String generateUpdateQuery(List<EntityProperty> properties, String tableName, List<EntityProperty> idProperties,
                                              String columnsList) {
        val setExpressions = StreamEx.of(properties)
            .filter(not(EntityProperty::isId))
            .map(property -> property.getColumnName() + " = :" + property.getName())
            .joining(",\n");
        val resultColumns = StreamEx.of(properties)
            .map(property -> "oldRecords." + property.getColumnName())
            .joining(", ");

        val idComparison = StreamEx.of(idProperties)
            .map(property -> property.getColumnName() + " = :" + property.getName())
            .joining(" AND ");
        val idColumnComparison = StreamEx.of(idProperties)
            .map(EntityProperty::getColumnName)
            .map(name -> "(newRecords." + name + " = oldRecords." + name + ')')
            .joining(" AND ");

        return "UPDATE " + tableName + " newRecords\n"
             + "  SET " + setExpressions + '\n'
             + "FROM (SELECT " + columnsList + " FROM " + tableName + " WHERE (" + idComparison + ") FOR UPDATE) oldRecords\n"
             + "WHERE " + idColumnComparison + '\n'
             + "RETURNING " + resultColumns;
    }

    /**
     * Precomputes all template defines for the repository type {@code sqlObjectType}.
     *
     * <p>{@link Alias#ID} is the comma-separated, sorted list of id column names. {@link Alias#ID_BINDINGS}
     * is {@code :id} for a single-column id; otherwise it lists every id column as a bind parameter, in the
     * same sorted order ({@code :a,:b}).
     */
    private static SqlStatementCustomizer createCustomizer(Annotation annotation, Class<?> sqlObjectType) {
        val entityType = GenericTypes.getErasedType(findEntityType(sqlObjectType));
        val entityBeanInfo = getBeanInfo(entityType);
        val entityProperties = resolveEntityProperties(entityBeanInfo);

        val configureAnnotation = (ConfigureCrudRepository) annotation;
        val tableName = configureAnnotation.table();
        val idProperties = findIdProperties(entityProperties);
        val columnsList = StreamEx.of(entityProperties)
            .map(EntityProperty::getColumnName)
            .joining(", ");

        val insertQuery = generateInsertionQuery(entityProperties, tableName);
        val updateQuery = generateUpdateQuery(entityProperties, tableName, idProperties, columnsList);

        val idDefine = StreamEx.of(idProperties)
            .map(EntityProperty::getColumnName)
            .sorted()
            .joining(", ");
        // Every component of a composite id must be a bind parameter. The previous mapFirst(...)
        // only prefixed the first one and left the rest as bare column references.
        val idBindingsDefine = idProperties.size() == 1
            ? ":id"
            : StreamEx.of(idProperties)
                .map(EntityProperty::getColumnName)
                .sorted()
                .map(name -> ':' + name)
                .joining(",");

        return statement -> {
            statement.define(Alias.TABLE, tableName);
            statement.define(Alias.ID, idDefine);
            statement.define(Alias.ID_BINDINGS, idBindingsDefine);
            statement.define(Alias.ENTITY_COLUMN_LIST, columnsList);
            statement.define(Alias.INSERTION_QUERY, insertQuery);
            statement.define(Alias.UPDATING_QUERY, updateQuery);
        };
    }

    /**
     * Returns the cached customizer for {@code sqlObjectType}, building it on first use. The annotation
     * from the first call for a given type is the one used to build its customizer.
     */
    @Override
    public SqlStatementCustomizer createForType(Annotation annotation, Class<?> sqlObjectType) {
        return CUSTOMIZER_CACHE.computeIfAbsent(sqlObjectType, type -> createCustomizer(annotation, type));
    }
}
