package ru.yandex.qe.dispenser.domain.dao.entity;

import java.io.IOException;
import java.sql.Timestamp;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.annotation.Nonnull;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableTable;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import com.google.common.collect.Table;
import com.google.common.collect.TreeMultimap;
import com.healthmarketscience.sqlbuilder.BinaryCondition;
import com.healthmarketscience.sqlbuilder.CustomSql;
import com.healthmarketscience.sqlbuilder.InCondition;
import com.healthmarketscience.sqlbuilder.NotCondition;
import com.healthmarketscience.sqlbuilder.SelectQuery;
import com.healthmarketscience.sqlbuilder.SqlObject;
import com.healthmarketscience.sqlbuilder.UnaryCondition;
import com.healthmarketscience.sqlbuilder.custom.postgresql.PgLimitClause;
import com.healthmarketscience.sqlbuilder.custom.postgresql.PgOffsetClause;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.postgresql.util.PGobject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.SqlParameterSource;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;

import ru.yandex.qe.dispenser.api.util.JsonSerializerBase;
import ru.yandex.qe.dispenser.domain.Entity;
import ru.yandex.qe.dispenser.domain.EntitySpec;
import ru.yandex.qe.dispenser.domain.Project;
import ru.yandex.qe.dispenser.domain.Segment;
import ru.yandex.qe.dispenser.domain.dao.SqlDaoBase;
import ru.yandex.qe.dispenser.domain.dao.SqlUtils;
import ru.yandex.qe.dispenser.domain.dao.quota.QuotaDao;
import ru.yandex.qe.dispenser.domain.dao.segment.SegmentUtils;
import ru.yandex.qe.dispenser.domain.hierarchy.Hierarchy;
import ru.yandex.qe.dispenser.domain.support.EntityOperation;
import ru.yandex.qe.dispenser.domain.support.EntityUsageDiff;
import ru.yandex.qe.dispenser.domain.util.CollectionUtils;
import ru.yandex.qe.dispenser.domain.util.FailMap;
import ru.yandex.qe.dispenser.domain.util.Page;
import ru.yandex.qe.dispenser.domain.util.StreamUtils;

import static ru.yandex.qe.dispenser.domain.dao.entity.SqlEntityDaoUtils.formatEntityTable;
import static ru.yandex.qe.dispenser.domain.dao.entity.SqlEntityDaoUtils.formatEntityUsageTable;
import static ru.yandex.qe.dispenser.domain.dao.entity.SqlEntityDaoUtils.toEntityTableName;
import static ru.yandex.qe.dispenser.domain.dao.entity.SqlEntityDaoUtils.toEntityUsageTableName;
import static ru.yandex.qe.dispenser.domain.util.CollectionUtils.ids;
import static ru.yandex.qe.dispenser.domain.util.ValidationUtils.validateEntityKey;


/**
 * PostgreSQL-backed implementation of {@link IntegratedEntityDao}.
 * <p>
 * Entity rows and entity usage rows are stored in tables whose names are derived from the entity's
 * {@link EntitySpec} (see {@code SqlEntityDaoUtils}). Every bulk operation below therefore groups its
 * input by spec first, then issues one statement (or one JDBC batch) per spec table.
 */
public class SqlEntityDao extends SqlDaoBase implements IntegratedEntityDao {
    private static final Logger LOG = LoggerFactory.getLogger(SqlEntityDao.class);

    // Entity table queries; the %s placeholder is replaced with the spec-specific table name.
    private static final String SELECT_BY_KEY_QUERY_FORMAT = "SELECT * FROM %s WHERE key = :key";
    private static final String SELECT_BY_KEYS_QUERY = "SELECT * FROM %s WHERE key IN (:keys)";
    // Locks rows in key order (presumably so concurrent lockers acquire row locks in the same order).
    private static final String SELECT_BY_KEYS_FOR_UPDATE_QUERY = SELECT_BY_KEYS_QUERY + " ORDER BY key FOR UPDATE";

    private static final String CREATE_ENTITY_QUERY = "INSERT INTO %s (key, dimensions) VALUES (:key, :dimensions)";
    private static final String CREATE_EXPIRABLE_ENTITY_QUERY = "INSERT INTO %s (key, dimensions, expiration_time) VALUES (:key, :dimensions, :expirationTime)";

    private static final String REMOVE_QUERY = "DELETE FROM %s WHERE id IN (:ids)";

    private static final String UPDATE_ENTITY_QUERY = "UPDATE %s SET expiration_time = :expirationTime WHERE id = :id";
    // Usage table queries. UPDATE_USAGES_QUERY is an upsert: it inserts a new (entity, project) usage row,
    // or adds :diff to the existing row's usages. Both %s placeholders receive the usage table name.
    private static final String UPDATE_USAGES_QUERY = "INSERT INTO %s (entity_id, project_id, usages) VALUES (:entityId, :projectId, :diff) ON CONFLICT (entity_id, project_id) DO UPDATE SET usages = %s.usages + :diff";
    private static final String SELECT_ENTITY_USAGES_QUERY = "SELECT * FROM %s WHERE entity_id IN (:entityIds)";
    private static final String SELECT_ENTITY_USAGES_FOR_UPDATE_QUERY = "SELECT * FROM %s WHERE entity_id IN (:entityIds) ORDER BY entity_id, project_id FOR UPDATE";
    private static final String CLEAN_USAGES_QUERY = "DELETE FROM %s WHERE entity_id IN (:entityIds) AND usages <= 0";

    // Column references used when building queries with sqlbuilder.
    private static final SqlObject ID_COLUMN = new CustomSql("id");
    private static final SqlObject ENTITY_ID_COLUMN = new CustomSql("entity_id");
    private static final SqlObject CREATION_TIME_COLUMN = new CustomSql("creation_time");
    private static final SqlObject EXPIRATION_TIME_COLUMN = new CustomSql("expiration_time");

    // Row count column; read back by its alias "c" in filterPage.
    private static final SqlObject COUNT_COLUMN = new CustomSql("count(*) as c");

    @Autowired
    @Qualifier("quotaDao")
    private QuotaDao quotaDao;

    @NotNull
    @Override
    public QuotaDao getQuotaDao() {
        return quotaDao;
    }

    /**
     * Applies the operations via the {@link IntegratedEntityDao} default implementation.
     * Joins the current transaction, or starts one if none exists.
     *
     * @param operations operations to apply
     */
    @Override
    @Transactional(propagation = Propagation.REQUIRED)
    public void doChanges(final @NotNull Collection<EntityOperation> operations) {
        IntegratedEntityDao.super.doChanges(operations);
    }

    /**
     * Reads the current usages of the given entities, broken down by project.
     * Must be called inside an existing transaction.
     *
     * @param entities entities whose usages are read; they may belong to different specs
     * @return entity × project → usages for every stored usage row of the given entities;
     * an immutable empty table when {@code entities} is empty
     */
    @NotNull
    @Override
    @Transactional(propagation = Propagation.MANDATORY)
    public Table<Entity, Project, Integer> getUsages(@NotNull final Collection<Entity> entities) {
        if (entities.isEmpty()) {
            return ImmutableTable.of();
        }
        final Multimap<EntitySpec, Entity> spec2entities = CollectionUtils.toMultimap(entities, Entity::getSpec);
        final Table<Entity, Project, Integer> allUsages = HashBasedTable.create();
        spec2entities.asMap().forEach((spec, specEntities) -> {
            final Map<Long, Entity> id2entity = CollectionUtils.index(specEntities);
            jdbcTemplate.query(formatEntityUsageTable(SELECT_ENTITY_USAGES_QUERY, spec), ImmutableMap.of("entityIds", id2entity.keySet()), rs -> {
                final Entity entity = id2entity.get(rs.getLong("entity_id"));
                final Project project = Hierarchy.get().getProjectReader().read(rs.getLong("project_id"));
                final int usages = rs.getInt("usages");
                allUsages.put(entity, project, usages);
            });
        });
        return allUsages;
    }

    /**
     * Adds each diff to the stored usages of its (entity, project) pair, creating the usage row if it
     * does not exist yet. Diffs are sent as one JDBC batch per spec. Must be called inside an existing
     * transaction.
     *
     * @param usageDiffs usage deltas to apply
     * @return {@code false} if there was nothing to apply, {@code true} otherwise
     */
    @Override
    @Transactional(propagation = Propagation.MANDATORY)
    public boolean changeUsages(@NotNull final List<EntityUsageDiff> usageDiffs) {
        if (usageDiffs.isEmpty()) {
            return false;
        }
        final Multimap<EntitySpec, EntityUsageDiff> spec2diffs = CollectionUtils.toMultimap(usageDiffs, ud -> ud.getEntity().getSpec());
        spec2diffs.asMap().forEach((spec, diffs) -> {
            final List<Map<String, ?>> columnParams = diffs.stream().map(this::toColumnParams).collect(Collectors.toList());
            final String tableName = toEntityUsageTableName(spec);
            final String sql = String.format(UPDATE_USAGES_QUERY, tableName, tableName);
            jdbcTemplate.batchUpdate(sql, columnParams);
        });
        return true;
    }

    /**
     * Deletes the usage rows of the given entities whose usages are zero or negative.
     * Must be called inside an existing transaction.
     *
     * @param entities entities whose usage rows are cleaned
     * @return always {@code true}
     */
    @Override
    @Transactional(propagation = Propagation.MANDATORY)
    public boolean cleanIfNeeded(@NotNull final Collection<Entity> entities) {
        final Multimap<EntitySpec, Entity> spec2entities = CollectionUtils.toMultimap(entities, Entity::getSpec);
        spec2entities.asMap().forEach((spec, specEntities) -> {
            jdbcTemplate.update(formatEntityUsageTable(CLEAN_USAGES_QUERY, spec), ImmutableMap.of("entityIds", ids(specEntities)));
        });
        return true;
    }

    /**
     * Filters the entities of every given spec and returns the union of the results.
     * Offset and limit, if set, are applied to each spec independently.
     *
     * @param specs  specs whose tables are queried
     * @param params filtering parameters
     * @return matching entities; empty when {@code specs} is empty
     */
    @NotNull
    @Override
    @Transactional(propagation = Propagation.REQUIRED)
    public Set<Entity> filter(@NotNull final Collection<EntitySpec> specs, @NotNull final EntityFilteringParams params) {
        if (specs.isEmpty()) {
            return Collections.emptySet();
        }
        return specs.stream()
                .map(spec -> filter(spec, params))
                .flatMap(Set::stream)
                .collect(Collectors.toSet());
    }

    /**
     * Returns the entities of one spec that match the filtering parameters.
     *
     * @param spec            spec whose table is queried
     * @param filteringParams filtering and pagination parameters
     * @return matching entities
     */
    @NotNull
    public Set<Entity> filter(@NotNull final EntitySpec spec, @NotNull final EntityFilteringParams filteringParams) {
        // Select only ids first and then whole rows with that ids (to improve pagination performance)
        // https://stackoverflow.com/a/6618428
        final SelectQuery resultQuery = getSelectQuery(spec, filteringParams);

        return jdbcTemplate.queryForSet(resultQuery.toString(), getEntityMapper(spec));
    }

    /**
     * Builds {@code SELECT * FROM <entity table> WHERE id IN (<id query>)}. The id query comes from
     * {@link #getIdQuery} and is additionally paginated with the offset/limit from {@code filteringParams}.
     *
     * @param spec            spec whose table is queried
     * @param filteringParams filtering and pagination parameters
     * @return the full-row select query
     */
    public SelectQuery getSelectQuery(final @NotNull EntitySpec spec, final @NotNull EntityFilteringParams filteringParams) {
        final SelectQuery idQuery = getIdQuery(spec, filteringParams);

        final boolean paginated = filteringParams.getOffset() != null || filteringParams.getLimit() != null;
        if (paginated) {
            // Without ORDER BY, PostgreSQL returns an unpredictable subset of rows for LIMIT/OFFSET,
            // so consecutive pages could overlap or skip entities. Order by id to make pages stable.
            idQuery.addCustomOrderings(ID_COLUMN);
        }
        if (filteringParams.getOffset() != null) {
            idQuery.addCustomization(new PgOffsetClause(filteringParams.getOffset()));
        }
        if (filteringParams.getLimit() != null) {
            idQuery.addCustomization(new PgLimitClause(filteringParams.getLimit()));
        }

        return new SelectQuery()
                .addAllColumns()
                .addCustomFromTable(toEntityTableName(spec))
                .addCondition(new InCondition(ID_COLUMN).addObject(idQuery));
    }

    /**
     * Builds an unpaginated {@code SELECT id FROM <entity table> WHERE ...} query from the filtering
     * parameters:
     * <ul>
     *     <li>trash only: entities with no row in the spec's usage table;</li>
     *     <li>created/expired from/to: inclusive bounds on {@code creation_time} / {@code expiration_time},
     *     given as milliseconds since the epoch (as accepted by {@link Timestamp#Timestamp(long)}).</li>
     * </ul>
     *
     * @param spec            spec whose table is queried
     * @param filteringParams filtering parameters (offset and limit are ignored here)
     * @return the id select query
     */
    public SelectQuery getIdQuery(final @NotNull EntitySpec spec, final @NotNull EntityFilteringParams filteringParams) {
        final SelectQuery idQuery = new SelectQuery()
                .addCustomColumns(ID_COLUMN)
                .addCustomFromTable(toEntityTableName(spec));
        if (filteringParams.trashOnly()) {
            // NOTE(review): "id" is unqualified in this correlated subquery; it resolves to the outer entity
            // table only if the usage table has no "id" column of its own - confirm against the schema.
            final SelectQuery entitiesWithUsages = new SelectQuery()
                    .addCustomColumns(ENTITY_ID_COLUMN)
                    .addCustomFromTable(toEntityUsageTableName(spec))
                    .addCondition(new BinaryCondition(BinaryCondition.Op.EQUAL_TO, ENTITY_ID_COLUMN, ID_COLUMN));
            idQuery.addCondition(new NotCondition(new UnaryCondition(UnaryCondition.Op.EXISTS, entitiesWithUsages)));
        }

        if (filteringParams.getCreatedFrom() != null) {
            idQuery.addCondition(BinaryCondition.greaterThanOrEq(CREATION_TIME_COLUMN, new Timestamp(filteringParams.getCreatedFrom())));
        }
        if (filteringParams.getCreatedTo() != null) {
            idQuery.addCondition(BinaryCondition.lessThanOrEq(CREATION_TIME_COLUMN, new Timestamp(filteringParams.getCreatedTo())));
        }
        if (filteringParams.getExpiredFrom() != null) {
            idQuery.addCondition(BinaryCondition.greaterThanOrEq(EXPIRATION_TIME_COLUMN, new Timestamp(filteringParams.getExpiredFrom())));
        }
        if (filteringParams.getExpiredTo() != null) {
            idQuery.addCondition(BinaryCondition.lessThanOrEq(EXPIRATION_TIME_COLUMN, new Timestamp(filteringParams.getExpiredTo())));
        }
        return idQuery;
    }

    /**
     * Returns one page of matching entities, together with the total number of entities that match
     * the filter. The total count ignores offset and limit.
     *
     * @param spec            spec whose table is queried
     * @param filteringParams filtering and pagination parameters
     * @return the page of entities and the total count
     */
    @Override
    @NotNull
    public Page<Entity> filterPage(@NotNull final EntitySpec spec, @NotNull final EntityFilteringParams filteringParams) {
        final SelectQuery resultQuery = getSelectQuery(spec, filteringParams);

        // A fresh, unpaginated id query provides the WHERE clause for the total count.
        final SelectQuery idQuery = getIdQuery(spec, filteringParams);

        final SelectQuery countQuery = new SelectQuery();

        countQuery.addCustomColumns(COUNT_COLUMN);
        countQuery.addCustomFromTable(toEntityTableName(spec));
        countQuery.addCondition(idQuery.getWhereClause());

        final Long count = jdbcTemplate.queryForObject(countQuery.toString(), Collections.emptyMap(), (a, b) -> a.getLong("c"));

        final Set<Entity> entities = jdbcTemplate.queryForSet(resultQuery.toString(), getEntityMapper(spec));
        return Page.of(entities, count);
    }

    /**
     * Reads a single entity by its key.
     *
     * @param key entity key; its spec selects the table
     * @return the entity
     * @throws EmptyResultDataAccessException if no entity with this key exists
     */
    @NotNull
    @Override
    @Transactional(propagation = Propagation.REQUIRED)
    public Entity read(@NotNull final Entity.Key key) throws EmptyResultDataAccessException {
        final EntitySpec spec = key.getSpec();
        final Map<String, ?> params = ImmutableMap.of("key", key.getPublicKey());
        try {
            return jdbcTemplate.queryForObject(formatEntityTable(SELECT_BY_KEY_QUERY_FORMAT, spec), params, getEntityMapper(spec));
        } catch (EmptyResultDataAccessException e) {
            throw new EmptyResultDataAccessException("No entity with key " + key, 1, e);
        }
    }

    /**
     * Reads all entities with the given keys.
     *
     * @throws EmptyResultDataAccessException if any of the keys is not found
     */
    @NotNull
    @Override
    @Transactional(propagation = Propagation.MANDATORY)
    public Map<Entity.Key, Entity> readAll(@NotNull final Collection<Entity.Key> keys) throws EmptyResultDataAccessException {
        return readAll(SELECT_BY_KEYS_QUERY, keys, true);
    }

    /**
     * Reads and row-locks ({@code FOR UPDATE}) all entities with the given keys.
     *
     * @throws EmptyResultDataAccessException if any of the keys is not found
     */
    @NotNull
    @Override
    @Transactional(propagation = Propagation.MANDATORY)
    public Map<Entity.Key, Entity> readAllForUpdate(@NotNull final Collection<Entity.Key> keys) throws EmptyResultDataAccessException {
        return readAll(SELECT_BY_KEYS_FOR_UPDATE_QUERY, keys, true);
    }

    /**
     * Reads the entities with the given keys; missing keys are silently absent from the result.
     */
    @NotNull
    @Override
    @Transactional(propagation = Propagation.MANDATORY)
    public Map<Entity.Key, Entity> readPresent(@NotNull final Collection<Entity.Key> keys) {
        return readAll(SELECT_BY_KEYS_QUERY, keys, false);
    }

    /**
     * Reads and row-locks ({@code FOR UPDATE}) the entities with the given keys; missing keys are
     * silently absent from the result.
     */
    @NotNull
    @Override
    @Transactional(propagation = Propagation.MANDATORY)
    public Map<Entity.Key, Entity> readPresentForUpdate(@NotNull final Collection<Entity.Key> keys) {
        return readAll(SELECT_BY_KEYS_FOR_UPDATE_QUERY, keys, false);
    }

    /**
     * Shared implementation of the key-based reads. Keys are de-duplicated and grouped by spec; specs
     * are visited in {@link TreeMultimap} (sorted) order, and {@code sql} is run once per spec.
     *
     * @param sql             query format with a {@code %s} table placeholder and a {@code :keys} parameter
     * @param keys            keys to read
     * @param checkAbsentKeys whether a missing key is an error
     * @return key → entity for every key that was found
     * @throws EmptyResultDataAccessException if {@code checkAbsentKeys} is set and some keys are missing
     */
    @NotNull
    @Transactional(propagation = Propagation.MANDATORY)
    private Map<Entity.Key, Entity> readAll(@NotNull final String sql,
                                            @NotNull final Collection<Entity.Key> keys,
                                            final boolean checkAbsentKeys) throws EmptyResultDataAccessException {
        if (keys.isEmpty()) {
            return Collections.emptyMap();
        }

        final Set<Entity.Key> uniqKeys = new HashSet<>(keys);
        final Multimap<EntitySpec, Entity.Key> spec2keys = TreeMultimap.create();
        uniqKeys.forEach(k -> spec2keys.put(k.getSpec(), k));

        final Set<Entity> entities = new HashSet<>();
        spec2keys.asMap().forEach((spec, resourceKeys) -> {
            final Set<String> publicKeys = resourceKeys.stream().map(Entity.Key::getPublicKey).collect(Collectors.toSet());
            final Map<String, Object> params = ImmutableMap.of("keys", publicKeys);
            entities.addAll(jdbcTemplate.query(formatEntityTable(sql, spec), params, getEntityMapper(spec)));
        });

        final Map<Entity.Key, Entity> key2entity = CollectionUtils.toMap(entities, Entity::getKey);
        if (checkAbsentKeys) {
            final Set<Entity.Key> absentKeys = Sets.difference(uniqKeys, key2entity.keySet());
            if (!absentKeys.isEmpty()) {
                throw new EmptyResultDataAccessException("No entities with keys " + absentKeys, uniqKeys.size());
            }
        }
        return key2entity;
    }

    /**
     * Inserts the entity and stores the generated id on it.
     *
     * @param entity entity to insert; its public key is validated first
     * @return the same entity instance, with its id set
     */
    @NotNull
    @Override
    @Transactional(propagation = Propagation.MANDATORY)
    public Entity create(@NotNull final Entity entity) {
        LOG.debug("Creating entity with key '{}'", entity.getKey().getPublicKey());
        entity.setId(jdbcTemplate.insert(computeCreateEntityQuery(entity.getSpec()), toParams(entity)));
        return entity;
    }

    /**
     * Inserts all given entities.
     * <p>
     * Up to two entities are inserted one by one via {@link #create}, which assigns their ids, and a
     * key → entity map is returned. Larger collections are inserted with one JDBC batch per spec. In that
     * case the generated ids are NOT assigned to the passed entities, and a {@code FailMap} is returned
     * instead of a populated map. NOTE(review): callers presumably must not use the result in that case - confirm.
     *
     * @param entities entities to insert
     * @return key → created entity for small inputs; a {@code FailMap} for batched inputs
     */
    @NotNull
    @Override
    @Transactional(propagation = Propagation.REQUIRED)
    public Map<Entity.Key, Entity> createAll(@NotNull final Collection<Entity> entities) {
        if (entities.size() <= 2) {
            return StreamUtils.toMap(entities.stream().map(this::create), Entity::getKey);
        }
        final Multimap<EntitySpec, Entity> spec2entities = CollectionUtils.toMultimap(entities, Entity::getSpec);
        spec2entities.asMap().forEach((spec, specEntities) -> {
            final SqlParameterSource[] insertions = specEntities.stream()
                    .map(entity -> new MapSqlParameterSource(toParams(entity)))
                    .toArray(SqlParameterSource[]::new);
            jdbcTemplate.batchUpdate(computeCreateEntityQuery(spec), insertions);
        });
        return new FailMap<>();
    }

    /**
     * Picks the insert statement for the spec: expirable specs also write {@code expiration_time}.
     */
    @Nonnull
    private String computeCreateEntityQuery(@Nonnull final EntitySpec spec) {
        final String query = spec.isExpirable() ? CREATE_EXPIRABLE_ENTITY_QUERY : CREATE_ENTITY_QUERY;
        return formatEntityTable(query, spec);
    }

    /**
     * Not supported: an id alone does not identify the spec table to read from.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    @Transactional(propagation = Propagation.REQUIRED)
    public @NotNull Entity read(@NotNull final Long id) throws EmptyResultDataAccessException {
        throw new UnsupportedOperationException();
    }

    /**
     * Updates the entity's expiration time (the only mutable column written here), looking the row up by id.
     *
     * @return {@code true} if a row was updated
     */
    @Override
    public boolean update(@NotNull final Entity entity) {
        final Map<String, Object> params = new HashMap<>();
        params.put("id", entity.getId());
        params.put("expirationTime", SqlUtils.toTimestamp(entity.getExpirationTime()));
        return jdbcTemplate.update(formatEntityTable(UPDATE_ENTITY_QUERY, entity.getSpec()), params) > 0;
    }

    /**
     * Deletes a single entity by id.
     */
    @Override
    @Transactional(propagation = Propagation.REQUIRED)
    public boolean delete(final @NotNull Entity entity) {
        return deleteAll(Collections.singleton(entity));
    }

    /**
     * Deletes the given entities by id, issuing one delete per spec table.
     *
     * @param entities entities to delete; they may belong to different specs
     * @return {@code false} if there was nothing to delete, {@code true} otherwise
     */
    @Override
    @Transactional(propagation = Propagation.REQUIRED)
    public boolean deleteAll(@NotNull final Collection<Entity> entities) {
        if (entities.isEmpty()) {
            return false;
        }
        final Multimap<EntitySpec, Entity> spec2entities = CollectionUtils.toMultimap(entities, e -> e.getKey().getSpec());
        spec2entities.asMap().forEach((spec, specEntities) -> {
            // Only the ids of this spec's entities belong in this spec's table. Ids of entities of other
            // specs refer to rows of other tables and would delete unrelated rows here.
            jdbcTemplate.update(formatEntityTable(REMOVE_QUERY, spec), ImmutableMap.of("ids", ids(specEntities)));
        });
        return true;
    }

    /**
     * No-op: this DAO does not support clearing its storage.
     *
     * @return always {@code false}
     */
    @Override
    @Transactional(propagation = Propagation.MANDATORY)
    public boolean clear() {
        return false;
    }

    /**
     * Maps an entity row to an {@link Entity}. The {@code dimensions} jsonb column is decoded into
     * {@link Dimension}s and matched against the spec's resource keys (by resource id and
     * non-aggregation segment keys). Resource keys with no stored dimension get size {@code 0}.
     * {@code expiration_time} is read only for expirable specs.
     */
    @NotNull
    private RowMapper<Entity> getEntityMapper(@NotNull final EntitySpec spec) {
        return (rs, i) -> {
            final Dimension[] dimensions = SqlUtils.fromJsonb((PGobject) rs.getObject("dimensions"), Dimension[].class);
            final Table<Long, Set<String>, Long> resource2dimension = HashBasedTable.create();
            for (final Dimension dimension : dimensions) {
                resource2dimension.put(dimension.getResourceId(), dimension.getSegments(), dimension.getSize());
            }

            final Entity.Builder builder = Entity.builder(rs.getString("key"))
                    .id(rs.getLong("id"))
                    .spec(spec)
                    .creationTime(rs.getTimestamp("creation_time").getTime());
            if (spec.isExpirable()) {
                Optional.ofNullable(rs.getTimestamp("expiration_time"))
                        .map(Timestamp::getTime)
                        .ifPresent(builder::expirationTime);
            }
            Entity.ResourceKey.getKeysForEntitySpec(spec)
                    .forEach(r -> {
                        final Long size = resource2dimension.get(r.getResource().getId(), SegmentUtils.getNonAggregationSegmentKeys(r.getSegments()));
                        builder.dimension(r, size == null ? 0L : size);
                    });
            return builder.build();
        };
    }

    /**
     * Returns the single spec whose resources include all the given resource ids.
     * Fails if more than one spec matches.
     */
    @NotNull
    @Deprecated
    private EntitySpec findMatchingSpec(@NotNull final Collection<EntitySpec> specs, @NotNull final Set<Long> resourceIds) {
        final Stream<EntitySpec> matchingSpecs = specs.stream().filter(spec -> ids(spec.getResources()).containsAll(resourceIds));
        return StreamUtils.requireSingle(matchingSpecs, "There are more than one matching entity spec for resources " + resourceIds + "!");
    }

    /**
     * Builds the insert parameters for an entity: {@code key}, {@code dimensions} as jsonb, and, for
     * expirable specs, {@code expirationTime}. The public key is validated first.
     */
    @NotNull
    private Map<String, Object> toParams(@NotNull final Entity entity) {
        validateEntityKey(entity.getKey().getPublicKey());
        final Dimension[] dimensions = Entity.ResourceKey.getKeysForEntitySpec(entity.getSpec())
                .map(r -> new Dimension(r.getResource().getId(), entity.getSize(r),
                        r.getSegments().stream().map(Segment::getPublicKey).collect(Collectors.toSet())))
                .toArray(Dimension[]::new);
        final Map<String, Object> params = new HashMap<>();
        params.put("key", entity.getKey().getPublicKey());
        params.put("dimensions", SqlUtils.toJsonb(dimensions));
        if (entity.getSpec().isExpirable()) {
            params.put("expirationTime", SqlUtils.toTimestamp(entity.getExpirationTime()));
        }
        return params;
    }

    /**
     * Builds the named parameters ({@code entityId}, {@code projectId}, {@code diff}) of one usage upsert.
     */
    @NotNull
    private Map<String, ?> toColumnParams(@NotNull final EntityUsageDiff usageDiff) {
        final Map<String, Object> params = new TreeMap<>();
        params.put("entityId", usageDiff.getEntity().getId());
        params.put("projectId", usageDiff.getProject().getId());
        params.put("diff", usageDiff.getUsages());
        return params;
    }

    /**
     * One element of the {@code dimensions} jsonb array: the size of a resource for a set of segment keys.
     * A missing/null {@code segments} property is read as an empty set.
     */
    @JsonSerialize(using = Serializer.class)
    private static final class Dimension {
        private final long resourceId;
        private final long size;
        private final Set<String> segmentKeys;

        private Dimension(@JsonProperty("resourceId") final long resourceId,
                          @JsonProperty("size") final long size,
                          @Nullable @JsonProperty("segments") final Set<String> segmentKeys) {
            this.resourceId = resourceId;
            this.size = size;
            this.segmentKeys = segmentKeys == null ? Collections.emptySet() : segmentKeys;
        }

        public long getResourceId() {
            return resourceId;
        }

        public long getSize() {
            return size;
        }

        public Set<String> getSegments() {
            return segmentKeys;
        }
    }

    /**
     * Writes a {@link Dimension} as {@code {"resourceId": ..., "size": ..., "segments": [...]}}.
     * The {@code segments} field is omitted when the set is empty.
     */
    static final class Serializer extends JsonSerializerBase<Dimension> {
        @Override
        public void serialize(@NotNull final Dimension dimension,
                              @NotNull final JsonGenerator jg,
                              @NotNull final SerializerProvider sp) throws IOException {
            jg.writeStartObject();
            jg.writeObjectField("resourceId", dimension.getResourceId());
            jg.writeObjectField("size", dimension.getSize());
            final Set<String> segments = dimension.getSegments();
            if (!segments.isEmpty()) {
                jg.writeObjectField("segments", segments);
            }
            jg.writeEndObject();
        }
    }
}
