package ru.yandex.calendar.logic.beans;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;

import org.joda.time.Instant;
import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.dao.IncorrectResultSizeDataAccessException;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.PreparedStatementCallback;
import org.springframework.jdbc.core.PreparedStatementCreator;
import org.springframework.jdbc.core.ResultSetExtractor;
import org.springframework.jdbc.core.SingleColumnRowMapper;
import org.springframework.jdbc.core.SqlProvider;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.CollectionF;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.Option;
import ru.yandex.bolts.collection.Tuple2;
import ru.yandex.bolts.collection.Tuple2List;
import ru.yandex.bolts.function.Function;
import ru.yandex.calendar.frontend.web.cmd.run.CommandRunException;
import ru.yandex.calendar.util.db.CalendarArgPreparedStatementSetter;
import ru.yandex.calendar.util.db.CalendarJdbcDaoSupport;
import ru.yandex.commune.mapObject.MapField;
import ru.yandex.commune.mapObject.MapObject;
import ru.yandex.commune.mapObject.MapObjectDescription;
import ru.yandex.commune.mapObject.db.MapObjectQueries;
import ru.yandex.misc.db.q.SqlCondition;
import ru.yandex.misc.db.q.SqlLimits;
import ru.yandex.misc.db.q.SqlOrder;
import ru.yandex.misc.db.q.SqlQuery;
import ru.yandex.misc.db.q.SqlQueryUtils;
import ru.yandex.misc.lang.StringUtils;
import ru.yandex.misc.lang.Validate;
import ru.yandex.misc.spring.jdbc.PreparedStatementUtils;
import ru.yandex.misc.spring.jdbc.RowMapperResultSetExtractor2;
import ru.yandex.misc.spring.jdbc.SqlArgsProvider;

/**
 * Generic JDBC data-access operations for {@link Bean}s: lookup, counting, insertion, update, upsert and
 * deletion, with table and column names taken from bean metadata ({@link BeanHelper},
 * {@link MapObjectDescription}, {@link MapField}).
 *
 * <p>Condition and field values are bound as JDBC parameters. The deprecated
 * {@code loadBeans(BeanHelper, String, Object...)} overload is the exception: it splices caller-supplied
 * SQL text verbatim. Most condition-based read/delete/update methods first consult the inherited
 * {@code skipQuery(...)} check. When it says so, they return an "empty" result (no bean, no rows, zero
 * count) without executing the statement.
 *
 * @author Stepan Koltsov
 */
public class GenericBeanDao extends CalendarJdbcDaoSupport {

    /** Maximum number of rows fetched by one query of the streaming {@code loadBeans} overload. */
    public static final int STREAM_BATCH_SIZE = 1000;

    /**
     * Finds a bean matching {@code condition}, without any LIMIT.
     *
     * @see #findBean(BeanHelper, SqlCondition, SqlLimits)
     */
    public <T extends Bean<TId>, TId> Option<T> findBean(BeanHelper<T, TId> helper, SqlCondition condition) {
        return findBean(helper, condition, SqlLimits.all());
    }

    /**
     * Finds a bean of {@code beanHelper}'s table matching {@code condition}.
     *
     * @param beanHelper metadata of the bean type and its table
     * @param condition WHERE condition; its args are bound as query parameters
     * @param limits appended to the query as rendered by {@link SqlLimits#toMysqlLimits()}
     * @return the matching bean; empty if no row matched or if {@code skipQuery} skipped the query.
     *         NOTE(review): the outcome for several matching rows is whatever {@code queryForOption}
     *         does (presumably it throws). Confirm.
     */
    public <T extends Bean<TId>, TId> Option<T> findBean(
            BeanHelper<T, TId> beanHelper, SqlCondition condition, SqlLimits limits)
    {
        String q = "SELECT * FROM " + beanHelper.getTableName() + " WHERE " + condition.sql() + " " + limits.toMysqlLimits();

        if (skipQuery(condition, q, condition.args())) return Option.empty();

        return getJdbcTemplate().queryForOption(q, beanHelper.beanRowMapper(), condition.args());
    }

    /**
     * Deletes the row of {@code bean}, identified by its id.
     *
     * @see #deleteBeanById(BeanHelper, Object)
     */
    public <TId> void deleteBean(Bean<TId> bean) {
        deleteBeanById(bean.getHelper(), bean.getId());
    }

    /**
     * Deletes all rows of {@code helper}'s table matching {@code condition}.
     *
     * @return number of deleted rows; 0 when {@code skipQuery} skipped the statement
     */
    public int deleteBeans(BeanHelper<? extends Bean<?>, ?> helper, SqlCondition condition) {
        String sql = "DELETE FROM " + helper.getTableName() + condition.whereSql();

        if (skipQuery(condition, sql, condition.args())) return 0;

        return getJdbcTemplate().update(sql, condition.args());
    }

    /**
     * Deletes exactly one row by id.
     *
     * <p>The DELETE runs before the row count is checked. On failure, the rows are already deleted unless
     * the surrounding transaction is rolled back.
     *
     * @throws EmptyResultDataAccessException if no row was deleted
     * @throws IncorrectResultSizeDataAccessException if more than one row was deleted
     */
    public <TId> void deleteBeanById(BeanHelper<? extends Bean<TId>, TId> helper, TId id) {
        int rows = deleteBeans(helper, helper.idEqCondition(id));

        if (rows == 0) throw new EmptyResultDataAccessException(1);
        if (rows > 1) throw new IncorrectResultSizeDataAccessException(1, rows);
    }

    /**
     * Loads the bean matching {@code c}, failing if there is none.
     *
     * @throws CommandRunException if no bean was found. The message names the table and includes the
     *         condition rendered by {@code SqlQueryUtils.populateQmsSafe} with the condition's arguments.
     */
    public <T extends Bean<TId>, TId> T loadBean(BeanHelper<T, TId> beanHelper, SqlCondition c) {
        Option<T> bean = findBean(beanHelper, c, SqlLimits.all());
        if (!bean.isPresent()) {
            String msg =
                    "Failed to find " + beanHelper.getTableName() + " bean by condition: " +
                            SqlQueryUtils.populateQmsSafe(c.whereSql(), PreparedStatementUtils.flatten(c.args()));
            throw new CommandRunException(msg);
        }
        return bean.get();
    }

    /**
     * Counts rows of {@code beanHelper}'s table matching {@code condition}.
     *
     * @return the row count; 0 when {@code skipQuery} skipped the query
     */
    public long countBeans(BeanHelper<?, ?> beanHelper, SqlCondition condition) {
        String q = "SELECT COUNT(1) FROM " + beanHelper.getTableName() + condition.whereSql();

        if (skipQuery(condition, q, condition.args())) return 0L;

        return getJdbcTemplate().queryForLong(q, condition.args());
    }

    /** Finds a bean by id; empty if there is no such row. */
    public <T extends Bean<TId>, TId> Option<T> findBeanById(BeanHelper<T, TId> beanHelper, TId id) {
        return findBean(beanHelper, beanHelper.idEqCondition(id), SqlLimits.all());
    }

    /**
     * Loads a bean by id.
     *
     * @throws CommandRunException if there is no such row
     */
    public <T extends Bean<TId>, TId> T loadBeanById(BeanHelper<T, TId> helper, TId id) {
        return loadBean(helper, helper.idEqCondition(id));
    }

    /**
     * Loads all beans whose id is in {@code ids}. Ids without a matching row are simply absent from the
     * result. No ordering is requested.
     */
    public <T extends Bean<TId>, TId> ListF<T> loadBeansById(BeanHelper<T, TId> beanHelper, CollectionF<TId> ids) {
        return loadBeans(beanHelper, beanHelper.idInSetCondition(ids));
    }

    /** Loads all beans whose {@code field} value is one of {@code values}. */
    public <T extends Bean<TId>, TId, U> ListF<T> loadBeansByField(
            BeanHelper<T, TId> beanHelper, MapField<U> field, CollectionF<U> values)
    {
        return loadBeans(beanHelper, field.column().inSet(values));
    }

    /** Loads all beans whose {@code field} value equals {@code value}. */
    public <T extends Bean<TId>, TId, U> ListF<T> loadBeansByField(
            BeanHelper<T, TId> beanHelper, MapField<U> field, U value)
    {
        return loadBeansByField(beanHelper, field, Cf.list(value));
    }

    /**
     * Loads beans using raw SQL for the WHERE clause.
     *
     * <p>{@code tailSql}, when non-empty, is appended verbatim after {@code " WHERE "}, and
     * {@code tailParams} are bound to its placeholders. Unlike the condition-based overloads, this does not
     * consult {@code skipQuery}.
     *
     * @deprecated splices caller-supplied SQL text into the query; use
     *             {@link #loadBeans(BeanHelper, SqlCondition)} or another {@link SqlCondition}-based
     *             overload instead.
     */
    @Deprecated
    public <T extends Bean<TId>, TId> ListF<T> loadBeans(
            BeanHelper<T, TId> beanHelper, String tailSql, Object... tailParams)
    {
        String q = "SELECT * FROM " + beanHelper.getTableName() +
            (StringUtils.isNotEmpty(tailSql) ? " WHERE " + tailSql : "");
        return getJdbcTemplate().query(q, beanHelper.beanRowMapper(), tailParams);
    }

    /** Loads beans matching {@code condition} within {@code limits}, without ordering. */
    public <T extends Bean<TId>, TId> ListF<T> loadBeans(
            BeanHelper<T, TId> beanHelper, SqlCondition condition, SqlLimits limits)
    {
        return loadBeans(beanHelper, condition, SqlOrder.unordered(), limits);
    }

    /** Loads all beans matching {@code condition} in the given {@code order}. */
    public <TId, T extends Bean<TId>> ListF<T> loadBeans(
            BeanHelper<T, TId> beanHelper, SqlCondition condition, SqlOrder order)
    {
        return loadBeans(beanHelper, condition, order, SqlLimits.all());
    }

    /**
     * Loads beans matching {@code condition}.
     *
     * <p>The ORDER BY / LIMIT clauses are only appended when {@code order} is not unordered and
     * {@code limits} is not "all", respectively.
     *
     * @return matching beans; an empty list when {@code skipQuery} skipped the query
     */
    public <T extends Bean<TId>, TId> ListF<T> loadBeans(
            BeanHelper<T, TId> beanHelper, SqlCondition condition, SqlOrder order, SqlLimits limits)
    {
        String q = "SELECT * FROM " + beanHelper.getTableName()
                + condition.whereSql()
                + (!order.isUnordered() ? " " + order.toSql() : "")
                + (!limits.isAll() ? " " + limits.toMysqlLimits() : "");

        if (skipQuery(condition, q, condition.args())) return Cf.list();

        return getJdbcTemplate().query(q, beanHelper.beanRowMapper(), condition.args());
    }

    /**
     * Streams all beans matching {@code condition} to {@code callback}. The table is read in chunks of at
     * most {@link #STREAM_BATCH_SIZE} rows, ordered by {@code timeMarkerField}.
     *
     * <p>Each chunk is handed to {@code callback} in pages of up to {@code rowsPerCallback} beans (split
     * via {@code paginate}). The callback returns a collection of ids, and beans of the page with those ids
     * are excluded from subsequent chunk queries. The next chunk resumes at the time marker of the last bean
     * read, inclusive, so rows sharing that marker are not skipped. Streaming stops after a chunk shorter
     * than {@link #STREAM_BATCH_SIZE}.
     *
     * <p>NOTE(review): beans whose ids are NOT returned by the callback are not excluded. If they still
     * match {@code condition} with a marker at or after the resume point, they are read again. Presumably
     * the callback changes such rows so they stop matching; confirm with callers. Unlike most query methods
     * here, this one does not consult {@code skipQuery}.
     *
     * @param beanHelper metadata of the bean type and its table
     * @param condition filter applied to every chunk query
     * @param timeMarkerField column used for ordering and for resuming between chunks
     * @param rowsPerCallback page size passed to {@code paginate} for each callback invocation
     * @param callback processes a page of beans and returns the ids to exclude from later chunks
     */
    public <T extends Bean<TId>, TId> void loadBeans(
            BeanHelper<T, TId> beanHelper, SqlCondition condition, MapField<Instant> timeMarkerField,
            int rowsPerCallback, Function<? super ListF<T>, ? extends CollectionF<TId>> callback)
    {
        // Resume condition for the next chunk; matches everything for the first chunk.
        SqlCondition timeCondition = SqlCondition.trueCondition();
        // Beans already passed to the callback whose ids the callback returned.
        ListF<Bean<TId>> beansToExclude = Cf.arrayList();

        for (;;) {
            SqlCondition cond = condition.and(timeCondition);
            String query = "SELECT * FROM " + beanHelper.getTableName() + cond.whereSql()
                    + " ORDER BY " + timeMarkerField.column().name()
                    + " LIMIT " + STREAM_BATCH_SIZE;

            ListF<T> beans = getJdbcTemplate().queryForList(query, beanHelper.beanClass().getClazz(), cond.args());

            // Feed the chunk page by page; keep the beans whose ids the callback returned.
            beans.iterator().paginate(rowsPerCallback)
                    .map(bs -> bs.filter(callback.apply(bs).unique().containsF().compose(Bean::getId)))
                    .forEachRemaining(beansToExclude::addAll);

            if (beans.size() < STREAM_BATCH_SIZE) {
                return;
            }
            Instant nextTs = beans.last().getFieldValue(timeMarkerField);

            // Beans with a marker before nextTs cannot match "marker >= nextTs" anyway, so they are dropped
            // to keep the exclusion list small.
            beansToExclude = Cf.toArrayList(
                    beansToExclude.filterNot(b -> b.getFieldValue(timeMarkerField).isBefore(nextTs)));

            timeCondition = timeMarkerField.ge(nextTs).and(BeanHelper.fieldsInSetCondition(
                    beanHelper.beanMapObjectDescription().getIdFields(),
                    beansToExclude.map(Bean::getId)).not());
        }
    }

    /** Loads all beans matching {@code condition}, without ordering or limits. */
    public <T extends Bean<TId>, TId> ListF<T> loadBeans(BeanHelper<T, TId> beanHelper, SqlCondition condition) {
        return loadBeans(beanHelper, condition, SqlLimits.all());
    }

    /** Finds a bean whose {@code field} equals {@code value}. */
    public <T extends Bean<TId>, TId, U> Option<T> findByField(BeanHelper<T, TId> helper, MapField<U> field, U value) {
        return findBean(helper, field.column().eq(value));
    }

    /**
     * Updates only the given {@code fields} of {@code bean}'s row.
     *
     * <p>The values are copied into a fresh bean from the bean's helper, the id is set, and the copy is
     * passed to {@link #updateBean(Bean)}. This keeps the UPDATE limited to those fields.
     */
    @SuppressWarnings("unchecked")
    public void updateBeanFields(Bean bean, MapField<?>... fields) {
        Bean temp = bean.getHelper().createBean();
        for (MapField<?> field : fields) {
            // Unchecked: the value was read from the same field, so its type matches.
            temp.setFieldValue((MapField<Object>) field, bean.getFieldValue(field));
        }
        temp.setId(bean.getId());
        updateBean(temp);
    }

    /**
     * Updates the row of {@code bean} using the statement built by {@code MapObjectQueries.updateSql}.
     * NOTE(review): the row-count check is whatever {@code updateRow} of the JDBC template does
     * (presumably it requires exactly one row). Confirm.
     */
    public void updateBean(Bean bean) {
        SqlQuery q = MapObjectQueries.updateSql(bean);
        getJdbcTemplate().updateRow(q.sql(), q.args());
    }

    /**
     * Sets the field values present in {@code data} on every row matching {@code condition}.
     *
     * @return number of updated rows; 0 when {@code skipQuery} skipped the statement
     */
    public int updateBeans(Bean data, SqlCondition condition) {
        SqlQuery q = SqlQueryUtils.updateQuery(
                data.getHelper().getTableName(), data.getFieldValues().map1(f -> f.column().name()), condition);

        if (skipQuery(condition, q.sql(), q.args())) return 0;

        return getJdbcTemplate().update(q.sql(), q.args());
    }

    /**
     * Updates exactly one row matching {@code condition} with the field values of {@code bean}.
     *
     * @throws EmptyResultDataAccessException if no row was updated
     * @throws IncorrectResultSizeDataAccessException if more than one row was updated
     */
    public void updateRow(Bean bean, SqlCondition condition) {
        int rows = updateBeans(bean, condition);
        if (rows == 0) {
            throw new EmptyResultDataAccessException(1);
        } else if (rows > 1) {
            throw new IncorrectResultSizeDataAccessException(1, rows);
        }
    }

    /**
     * Batch-updates {@code beans} by id. Updated columns are the non-id fields set on the first bean.
     *
     * @return per-statement update counts; an empty array for an empty list
     */
    public <T extends Bean> int[] updateBeans(ListF<T> beans) {
        if (beans.isNotEmpty()) {
            ListF<MapField<?>> fields = beans.first().getSetFields().filter(f -> !f.isPartOfId());

            return batchUpdateBeansFieldsByIds(beans, fields);
        }
        return new int[0];
    }

    /** Inserts {@code bean} and returns the generated {@code id}. */
    public long insertBeanGetGeneratedKey(Bean<Long> bean) {
        return insertBeansBatchGetGeneratedKeys(Cf.list(bean)).single();
    }

    /** Inserts {@code bean} using the statement built by {@code MapObjectQueries.insertSql}. */
    public void insertBean(Bean bean) {
        SqlQuery q = MapObjectQueries.insertSql(bean);
        getJdbcTemplate().update(q.sql(), q.args());
    }

    /**
     * Batch-inserts {@code beans}. All beans must have exactly the same set fields as the first one;
     * this is checked with {@code Validate.forAll}. Does nothing for an empty list.
     */
    public <T extends Bean> void insertBeans(ListF<T> beans) {
        if (beans.isNotEmpty()) {
            ListF<MapField<?>> fields = beans.first().getSetFields();
            Validate.forAll(beans, b -> b.getSetFields().equals(fields));

            batchInsertBeans(beans, fields);
        }
    }

    /** Batch-inserts {@code beans}, writing exactly the given {@code fields}. */
    public <T extends Bean> void insertBeans(final ListF<T> beans, final ListF<MapField<?>> fields) {
        batchInsertBeans(beans, fields);
    }

    /**
     * Batch-inserts {@code beans} with {@code ON CONFLICT DO NOTHING}, so conflicting rows are silently
     * skipped. Inserted columns are the fields set on the first bean. Does nothing for an empty list.
     */
    public <T extends Bean> void insertBeansIgnoreDuplicates(ListF<T> beans) {
        if (beans.isEmpty()) return;

        MapObjectDescription beanDescription = beans.first().getMapObjectDescription();
        ListF<MapField<?>> fields = beans.first().getSetFields();

        String q = SqlQueryUtils.insertQueryString(beanDescription.getTableName(), fields.map(f -> f.column().name()))
                + " ON CONFLICT DO NOTHING";

        executeBatch(beans, fields, beanDescription, q);
    }

    /**
     * Batch-upserts {@code beans}: inserts each row, or, on a conflict on the id columns, overwrites the
     * non-id columns with the new values.
     *
     * <p>Columns are the fields set on the first bean. If all of them are id columns, there is nothing to
     * overwrite and an existing row is left as is ({@code DO NOTHING}). Does nothing for an empty list.
     */
    public <T extends Bean> void upsertBeans(ListF<T> beans) {
        if (beans.isEmpty()) return;

        MapObjectDescription beanDescription = beans.first().getMapObjectDescription();

        ListF<MapField<?>> fields = beans.first().getSetFields();
        Function<MapField<?>, String> getName = f -> f.column().name();

        ListF<String> idColumns = fields.filter(MapField::isPartOfId).map(getName);
        ListF<String> dataColumns = fields.filterNot(MapField::isPartOfId).map(getName);

        // "DO UPDATE SET" with an empty assignment list is invalid SQL, so with no data columns to
        // overwrite, keep the existing row instead.
        String conflictAction = dataColumns.isEmpty()
                ? " DO NOTHING"
                : " DO UPDATE SET " + dataColumns.map(c -> c + " = excluded." + c).mkString(", ");

        String q = SqlQueryUtils.insertQueryString(beanDescription.getTableName(), fields.map(getName))
                + " ON CONFLICT " + idColumns.mkString("(", ",", ")")
                + conflictAction;

        executeBatch(beans, fields, beanDescription, q);
    }

    /**
     * Builds (does not execute) a multi-row {@code INSERT ... ON CONFLICT DO NOTHING RETURNING ...}
     * statement.
     *
     * <p>For {@code rows > 1}, {@code rows - 1} extra {@code (?, ...)} tuples are appended after the one
     * produced by {@code SqlQueryUtils.insertQueryString}. For {@code rows <= 1}, the statement keeps a
     * single tuple. The RETURNING list contains {@link MapField#getName()} of every field of
     * {@code beanDescription}.
     *
     * @param beanDescription table and fields of the bean type
     * @param fields fields whose columns are inserted, in placeholder order
     * @param rows number of value tuples in the statement
     * @return the SQL text with {@code ?} placeholders
     */
    public String getInsertOnConflictIgnoreStatement(
            MapObjectDescription beanDescription,
            ListF<MapField<?>> fields,
            int rows) {
        StringBuilder sb = new StringBuilder(SqlQueryUtils.insertQueryString(
                beanDescription.getTableName(),
                fields.map(f -> f.column().name())));

        if (rows > 1) {
            String valuesPlaces = "(" + SqlQueryUtils.qms(fields.size()) + ")";
            // Leading ", " joins the extra tuples onto the first VALUES tuple.
            sb.append(Cf.repeat(valuesPlaces, rows - 1).mkString(", ", ", ", ""));
        }

        sb.append(" ON CONFLICT DO NOTHING RETURNING ");
        sb.append(String.join(",", beanDescription.getFields().map(MapField::getName)));

        return sb.toString();
    }

    /**
     * Batch-inserts {@code beans}, writing the non-id fields set on the first bean.
     *
     * @return the generated ids, in the order the driver returned them; empty for an empty list
     */
    public <T extends Bean<Long>> ListF<Long> insertBeansBatchGetGeneratedKeys(ListF<T> beans) {
        if (beans.isEmpty()) return Cf.list();

        ListF<MapField<?>> fields = beans.first().getSetFields().filter(MapField.isPartOfIdF().notF());

        return insertBeansGetGeneratedKeys(beans, fields).get2();
    }

    /**
     * Batch-inserts {@code beans} writing {@code fields}, requesting the generated {@code id} column.
     *
     * @return each bean paired, by position, with a generated key read from
     *         {@link PreparedStatement#getGeneratedKeys()}; empty for an empty list
     */
    public <T extends Bean<Long>> Tuple2List<T, Long> insertBeansGetGeneratedKeys(
            ListF<T> beans, ListF<MapField<?>> fields)
    {
        if (beans.isEmpty()) {
            return Cf.Tuple2List.cons();
        }

        MapObjectDescription beanDescription = beans.first().getMapObjectDescription();
        Tuple2<String, Function<MapObject, ListF<Object>>> t = MapObjectQueries.batchInsertSql(beanDescription, fields);

        String sql = t._1;
        ListF<ArgsHolder> listOfArgs = beans.map(t._2.andThen(ArgsHolder::new));

        class StatementCreator implements PreparedStatementCreator, SqlProvider {
            public PreparedStatement createPreparedStatement(Connection con) throws SQLException {
                // Ask the driver to return the "id" column as a generated key.
                return con.prepareStatement(sql, new String[] {"id"});
            }

            /**
             * SQL reported through {@link SqlProvider}; it is not the prepared text. NOTE(review): the
             * " RETURNING id" suffix presumably mirrors what the JDBC driver adds for the requested
             * generated key. Confirm.
             */
            @Override
            public String getSql() {
                return sql + " RETURNING id";
            }
        }

        class StatementCallback implements PreparedStatementCallback<Tuple2List<T, Long>>, SqlArgsProvider {
            public Tuple2List<T, Long> doInPreparedStatement(PreparedStatement ps) throws SQLException {
                for (ArgsHolder args : listOfArgs) {
                    new CalendarArgPreparedStatementSetter(args.args).setValues(ps);
                    ps.addBatch();
                }
                ps.executeBatch();

                ResultSetExtractor<ListF<Long>> extractor = new RowMapperResultSetExtractor2<>(
                        new SingleColumnRowMapper<>(Long.class), listOfArgs.size());

                return beans.zip(extractor.extractData(ps.getGeneratedKeys()));
            }

            @Override
            public Option<ListF<Object>> sqlArgs() {
                return Option.of(listOfArgs.uncheckedCast());
            }
        }

        return getJdbcTemplate().execute(new StatementCreator(), new StatementCallback());
    }

    /**
     * Batch-inserts {@code beans} with a plain single-tuple {@code INSERT INTO table (cols) VALUES (?...)},
     * writing exactly {@code fields}. Does nothing for an empty list.
     */
    public <T extends Bean> void batchInsertBeans(
            final ListF<T> beans, final ListF<MapField<?>> fields)
    {
        if (beans.isEmpty()) return;

        MapObjectDescription beanDescription = beans.first().getMapObjectDescription();
        ListF<String> columnNames = fields.map(f -> f.column().name());

        String sql = "INSERT INTO "
                + beanDescription.getTableName() + columnNames.mkString(" (", ", ", ")")
                + " VALUES (" + SqlQueryUtils.qms(columnNames.size()) + ")";

        executeBatch(beans, fields, beanDescription, sql);
    }

    /**
     * Batch-updates {@code fields} of {@code beans}, matching rows by the id fields of the bean type.
     *
     * <p>Validates that all beans have the same class as the first one and that every id field and every
     * updated field is set on each bean.
     *
     * @return per-statement update counts; all zeros (one per bean) when {@code beans} or {@code fields}
     *         is empty
     */
    public <T extends Bean> int[] batchUpdateBeansFieldsByIds(final ListF<T> beans, final ListF<MapField<?>> fields) {
        if (beans.isEmpty() || fields.isEmpty()) return Cf.repeat(0, beans.size()).mapToIntArray(i -> i);

        MapObjectDescription beanDescription = beans.first().getMapObjectDescription();

        ListF<MapField<?>> idFields = beanDescription.getIdFields();
        Validate.allElementsOfType(beans, beans.first().getClass());

        Validate.forAll(beans, b -> idFields.forAll(b.isFieldSetF()));
        Validate.forAll(beans, b -> fields.forAll(b.isFieldSetF()));

        ListF<String> columnForUpdateNames = fields.map(f -> f.column().name());
        ListF<String> columnByNames = idFields.map(f -> f.column().name());

        String sql = "UPDATE " + beanDescription.getTableName()
                + " SET " + columnForUpdateNames.map(s -> s + " = ?").mkString(", ")
                + " WHERE " + columnByNames.map(s -> s + " = ?").mkString(" AND ");

        // MapObjectQueries.batchInsertSql() in executeBatch() also suits the batch update case,
        // but the id fields must be appended to the field list because the WHERE clause comes after SET.
        return executeBatch(beans, fields.plus(idFields), beanDescription, sql);
    }

    /**
     * Runs {@code sql} as a JDBC batch, one statement per bean. Arguments are extracted from each bean in
     * {@code fields} order using the extractor of {@code MapObjectQueries.batchInsertSql}; its SQL text is
     * ignored.
     */
    private <T extends Bean> int[] executeBatch(
            final ListF<T> beans, ListF<MapField<?>> fields,
            MapObjectDescription beanDescription, String sql)
    {
        final ListF<ListF<Object>> listOfArgs =
                beans.map(MapObjectQueries.batchInsertSql(beanDescription, fields).get2());

        return batchUpdate(sql, listOfArgs.uncheckedCast());
    }

    /**
     * Executes {@code sql} once per element of {@code listOfArgs} as a single JDBC batch.
     *
     * @return per-statement update counts, as returned by the JDBC template's {@code batchUpdate}
     */
    public int[] batchUpdate(String sql, ListF<ListF<?>> listOfArgs) {
        class BatchBeansPreparedStatementSetter implements BatchPreparedStatementSetter, SqlArgsProvider {
            @Override
            public void setValues(PreparedStatement ps, int index) throws SQLException {
                new CalendarArgPreparedStatementSetter(listOfArgs.get(index)).setValues(ps);
            }

            @Override
            public int getBatchSize() {
                return listOfArgs.size();
            }

            @Override
            public Option<ListF<Object>> sqlArgs() {
                return Option.of(listOfArgs.map(ArgsHolder::new));
            }
        }
        return getJdbcTemplate().batchUpdate(sql, new BatchBeansPreparedStatementSetter());
    }

    // TODO: this class is redundant. It exists only as a workaround for
    // ru.yandex.misc.spring.jdbc.intercept.QueryInfoBrokenTest.queryInfoForBatch() (see it in a recent iceberg version).
    // Remove it once that broken test is fixed.
    /** Holds one statement's argument list; {@link #toString()} delegates to the list. */
    private static class ArgsHolder {
        final ListF<?> args;

        public ArgsHolder(ListF<?> args) {
            this.args = args;
        }

        @Override
        public String toString() {
            return args.toString();
        }
    }

} //~
