package ru.yandex.chemodan.app.dataapi.core.mdssnapshot;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;

import org.joda.time.Duration;
import org.joda.time.Instant;
import org.springframework.transaction.TransactionStatus;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.MapF;
import ru.yandex.bolts.collection.Option;
import ru.yandex.bolts.collection.SetF;
import ru.yandex.bolts.function.Function0;
import ru.yandex.chemodan.app.dataapi.api.data.field.DataField;
import ru.yandex.chemodan.app.dataapi.api.data.filter.RecordsFilter;
import ru.yandex.chemodan.app.dataapi.api.data.protobuf.ProtobufDataUtils;
import ru.yandex.chemodan.app.dataapi.api.data.record.DataRecord;
import ru.yandex.chemodan.app.dataapi.api.data.record.DataRecordId;
import ru.yandex.chemodan.app.dataapi.api.data.snapshot.Snapshot;
import ru.yandex.chemodan.app.dataapi.api.db.Database;
import ru.yandex.chemodan.app.dataapi.api.db.handle.DatabaseHandle;
import ru.yandex.chemodan.app.dataapi.api.db.ref.SpecialDatabases;
import ru.yandex.chemodan.app.dataapi.api.user.DataApiUserId;
import ru.yandex.chemodan.app.dataapi.core.dao.support.ShardedTransactionManager;
import ru.yandex.chemodan.app.dataapi.utils.elliptics.EllipticsHelper;
import ru.yandex.commune.db.shard2.Shard2;
import ru.yandex.commune.dynproperties.DynamicProperty;
import ru.yandex.inside.elliptics.EllipticsFileNotFoundException;
import ru.yandex.inside.elliptics.EllipticsUploadState;
import ru.yandex.misc.ExceptionUtils;
import ru.yandex.misc.db.masterSlave.MasterSlaveContextHolder;
import ru.yandex.misc.db.masterSlave.MasterSlavePolicy;
import ru.yandex.misc.env.EnvironmentType;
import ru.yandex.misc.lang.StringUtils;
import ru.yandex.misc.lang.Validate;
import ru.yandex.misc.log.mlf.Logger;
import ru.yandex.misc.log.mlf.LoggerFactory;
import ru.yandex.misc.random.Random2;

/**
 * Manages references to database snapshots stored in MDS (accessed through {@link EllipticsHelper}).
 *
 * <p>A {@link MdsSnapshotReference} row links a database handle and revision to an optional MDS key. The snapshot
 * itself is stored either as one serialized blob, or (when {@link #partitioningEnabled} is on and the records exceed
 * {@code partitionSize}) as several partition blobs grouped by collection plus an index blob. The key of a
 * partitioned snapshot is {@code "partitioned:" + indexBlobKey}; the index blob is a snapshot with a single record
 * whose field names are partition MDS keys and whose values are the lists of collection ids stored in each partition.
 *
 * @author Denis Bakharev
 */
public class MdsSnapshotReferenceManager {
    private static final String PARTITIONED_SNAPSHOT_PREFIX = "partitioned:";

    /** Whether large snapshots are split into several MDS blobs; defaults to enabled outside production. */
    public final DynamicProperty<Boolean> partitioningEnabled =
            new DynamicProperty<>("mds-snaphot-partitioning-enabled",
                    EnvironmentType.getActive() != EnvironmentType.PRODUCTION);

    private final Logger logger = LoggerFactory.getLogger(getClass());
    private final Duration snapshotReferenceDeletionInterval;
    private final EllipticsHelper ellipticsHelper;
    private final ShardedTransactionManager transactionManager;
    private final MdsSnapshotReferenceJdbcDao mdsSnapshotReferenceJdbcDao;
    /** Single background thread used to drop references whose MDS data turned out to be missing. */
    private final ExecutorService deleteExecutorService;

    /**
     * Soft limit on records per partition blob: whole collections are never split, so a partition holding a single
     * large collection may exceed it.
     */
    private final int partitionSize;

    /**
     * @param mdsSnapshotReferenceJdbcDao        DAO for snapshot reference rows
     * @param snapshotReferenceDeletionInterval  age after which a reference is considered old and may be deleted
     * @param ellipticsHelper                    client used to upload/download/delete MDS blobs
     * @param transactionManager                 per-user sharded transaction manager
     * @param partitionSize                      soft limit on records per partition blob
     */
    public MdsSnapshotReferenceManager(
            MdsSnapshotReferenceJdbcDao mdsSnapshotReferenceJdbcDao,
            Duration snapshotReferenceDeletionInterval,
            EllipticsHelper ellipticsHelper,
            ShardedTransactionManager transactionManager,
            int partitionSize)
    {
        this.mdsSnapshotReferenceJdbcDao = mdsSnapshotReferenceJdbcDao;
        this.snapshotReferenceDeletionInterval = snapshotReferenceDeletionInterval;
        this.ellipticsHelper = ellipticsHelper;
        this.transactionManager = transactionManager;
        this.partitionSize = partitionSize;
        this.deleteExecutorService = Executors.newSingleThreadExecutor();
    }

    /**
     * Deletes (sequentially, on the calling thread) every reference created before {@code now - deletionInterval}.
     */
    public void deleteOldSnapshotReferences() {
        Instant timeThreshold = getTimeThresholdForSnapshotReference(Option.empty());
        ListF<MdsSnapshotReference> oldSnapshotReferences =
                mdsSnapshotReferenceJdbcDao.findAllWithCreationTimeLessThan(timeThreshold);
        oldSnapshotReferences.forEach(this::deleteReference);
    }

    /**
     * Same as {@link #deleteOldSnapshotReferences()} but only for references stored in the given shard.
     */
    public void deleteOldSnapshotReferencesByShard(Shard2 shard) {
        Instant timeThreshold = getTimeThresholdForSnapshotReference(Option.empty());
        ListF<MdsSnapshotReference> oldSnapshotReferences =
                mdsSnapshotReferenceJdbcDao.findAllWithCreationTimeLessThan(timeThreshold, shard);
        oldSnapshotReferences.forEach(this::deleteReference);
    }

    /**
     * Deletes old references of the given shard on {@code executor}, with at most {@code semaphoreCount} deletions
     * in flight at once. Returns once every deletion has been submitted, not when they complete.
     *
     * @param shard           shard to scan
     * @param executor        executor running the deletions
     * @param semaphoreCount  maximum number of concurrently running deletions; must be positive
     * @param now             "current" time used for the age threshold; {@code Instant.now()} when empty
     * @throws IllegalArgumentException if {@code semaphoreCount} is not positive (it would block forever)
     */
    public void deleteOldSnapshotReferencesByShardParallel(Shard2 shard, ExecutorService executor, int semaphoreCount,
                                                           Option<Instant> now) {
        if (semaphoreCount <= 0) {
            throw new IllegalArgumentException("semaphoreCount must be positive: " + semaphoreCount);
        }
        Semaphore semaphore = new Semaphore(semaphoreCount);
        Instant timeThreshold = getTimeThresholdForSnapshotReference(now);
        ListF<MdsSnapshotReference> oldSnapshotReferences =
                mdsSnapshotReferenceJdbcDao.findAllWithCreationTimeLessThan(timeThreshold, shard);
        oldSnapshotReferences.forEach(reference -> deleteOldSnapshotReferenceByParallelExecutor(reference, semaphore, executor));
    }

    /**
     * Acquires a permit from {@code semaphore} and submits deletion of {@code reference} to {@code executor}; the
     * permit is released when the task finishes. If the calling thread is interrupted while waiting, the reference
     * is skipped and the interrupt flag is restored.
     *
     * @throws RuntimeException whatever {@code executor.submit} throws when it rejects the task (permit is returned)
     */
    public void deleteOldSnapshotReferenceByParallelExecutor(MdsSnapshotReference reference, Semaphore semaphore,
                                                             ExecutorService executor) {
        try {
            semaphore.acquire();
        } catch (InterruptedException e) {
            // Restore the flag so the caller observes the interruption; otherwise every following acquire()
            // would block again instead of failing fast.
            Thread.currentThread().interrupt();
            logger.debug(e);
            return;
        }
        try {
            executor.submit(() -> {
                try {
                    deleteReference(reference);
                } finally {
                    semaphore.release();
                }
            });
        } catch (RuntimeException e) {
            // The task was not accepted, so its finally-block will never run: give the permit back here.
            semaphore.release();
            throw e;
        }
    }

    /**
     * Bumps the last request time of the reference for {@code database}, inserting a new reference (without an MDS
     * key) if none was updated.
     *
     * @throws RuntimeException from {@code Validate.isTrue} if the database is not configured for MDS snapshots
     */
    public void createSnapshotReferenceOrUpdateLastRequestTime(Database database) {
        Validate.isTrue(SpecialDatabases.isMdsSnapshotable(database.dbRef()),
                "Not configured to store mds snapshot for database: ", database.dbRef());

        int updatedRowsCount = mdsSnapshotReferenceJdbcDao.updateLastRequestTime(Instant.now(), database);
        if (updatedRowsCount == 0) {
            MdsSnapshotReference snapshotReference = new MdsSnapshotReference(
                    database.handleValue(), database.rev, Instant.now(), Option.empty(), database.uid);
            mdsSnapshotReferenceJdbcDao.insert(snapshotReference);
        }
    }

    /**
     * @param now reference time; {@code Instant.now()} when empty
     * @return {@code now - snapshotReferenceDeletionInterval}; references older than this are considered old
     */
    public Instant getTimeThresholdForSnapshotReference(Option<Instant> now) {
        return now.getOrElse(Instant::now).minus(snapshotReferenceDeletionInterval);
    }

    /**
     * Deletes the reference row and its MDS data in one transaction on the master, unless the reference was
     * requested recently (its last request time is not before the threshold), in which case it is skipped.
     * Recoverable failures are rolled back and logged rather than propagated.
     */
    public void deleteReference(MdsSnapshotReference reference) {
        if (!isOldSnapshotReference(reference)) {
            logger.info("Skipping the reference {}", reference);
            return;
        }
        MasterSlaveContextHolder.withPolicy(MasterSlavePolicy.RW_M, () -> {
            TransactionStatus transaction = transactionManager.getTransaction(reference.uid);
            try {
                logger.info("Deleting snapshot reference from db: {}", reference);
                mdsSnapshotReferenceJdbcDao.delete(reference);

                logger.info("Deleting binary data from MDS for snapshot reference: {}", reference);
                if (reference.mdsKey.isPresent()) {
                    deleteStoredSnapshot(reference.mdsKey.get());
                }

                transactionManager.commit(reference.uid, transaction);
                logger.info("All data for snapshot reference {} were deleted successfully", reference);
            } catch (Throwable e) {
                ExceptionUtils.throwIfUnrecoverable(e);
                transactionManager.rollback(reference.uid, transaction);

                // Old references will probably be removed on the next run, so just log the failure: being unable to
                // delete one reference must not block deletion of the others.
                logger.error("Something went wrong when we try delete data for snapshot reference: {}", reference, e);
            }
        });
    }

    private boolean isOldSnapshotReference(MdsSnapshotReference mdsSnapshotReference) {
        return mdsSnapshotReference.lastRequestTime.isBefore(getTimeThresholdForSnapshotReference(Option.empty()));
    }

    /**
     * Loads the snapshot of revision {@code rev} from MDS and applies {@code filter} to it.
     *
     * @return empty if there is no reference, the reference has no MDS key yet, or an MDS file is missing; in the
     *         last case the stale reference is deleted asynchronously
     */
    public Option<Snapshot> getRevisionSnapshotFromMds(
            DataApiUserId uid, long rev, DatabaseHandle handle, RecordsFilter filter)
    {
        Option<MdsSnapshotReference> mdsSnapshotReferenceO = mdsSnapshotReferenceJdbcDao.find(uid, handle.handle, rev);
        Option<String> keyO = mdsSnapshotReferenceO.flatMapO(ref -> ref.mdsKey);

        try {
            return keyO.map(key -> getSnapshotFromMds(key, uid, handle, filter));
        } catch (EllipticsFileNotFoundException e) {
            // No file in MDS means no snapshot: drop the reference from the database and return empty.
            deleteExecutorService.submit(() -> MasterSlaveContextHolder.withPolicy(MasterSlavePolicy.RW_M,
                    () -> mdsSnapshotReferenceO.forEach(mdsSnapshotReferenceJdbcDao::delete)));
            return Option.empty();
        }
    }

    /**
     * Deletes the MDS blob(s) behind {@code key}. For a partitioned snapshot every partition blob listed in the index
     * and then the index blob itself are deleted. Missing files are logged and skipped individually, so one missing
     * partition does not leave the remaining blobs orphaned.
     */
    void deleteStoredSnapshot(String key) {
        if (!key.startsWith(PARTITIONED_SNAPSHOT_PREFIX)) {
            deleteFromMdsIgnoringMissing(key);
            return;
        }

        String indexKey = StringUtils.removeStart(key, PARTITIONED_SNAPSHOT_PREFIX);
        Snapshot indexSnapshot;
        try {
            // The index only lists partition keys, so uid/handle are irrelevant placeholders here.
            indexSnapshot = ProtobufDataUtils.deserializeSnapshot(ellipticsHelper.download(indexKey),
                    DataApiUserId.parse("1"), new DatabaseHandle("hack", "hack", "hack"));
        } catch (EllipticsFileNotFoundException e) {
            logger.warn("Try to clean but file not found. key={} : {}", indexKey, e);
            return;
        }

        for (String partitionKey : indexSnapshot.getRecord(getIndexRecordId(indexSnapshot.database)).data().data.keys()) {
            deleteFromMdsIgnoringMissing(partitionKey);
        }
        deleteFromMdsIgnoringMissing(indexKey);
    }

    /** Deletes one MDS blob, logging (instead of throwing) if it does not exist. */
    private void deleteFromMdsIgnoringMissing(String key) {
        try {
            ellipticsHelper.delete(key);
        } catch (EllipticsFileNotFoundException e) {
            logger.warn("Try to clean but file not found. key={} : {}", key, e);
        }
    }

    /**
     * Downloads the snapshot stored under {@code key}. For a partitioned snapshot only the partitions containing the
     * collections requested by {@code filter} are downloaded (all of them when the filter has no collection ids).
     */
    private Snapshot getSnapshotFromMds(String key, DataApiUserId uid, DatabaseHandle handle, RecordsFilter filter) {
        Snapshot rawSnapshot;

        if (key.startsWith(PARTITIONED_SNAPSHOT_PREFIX)) {
            key = StringUtils.removeStart(key, PARTITIONED_SNAPSHOT_PREFIX);
            Snapshot indexSnapshot =
                    ProtobufDataUtils.deserializeSnapshot(ellipticsHelper.download(key), uid, handle);

            MapF<String, ListF<String>> collectionIdsByMdsKey =
                    indexSnapshot.getRecord(getIndexRecordId(indexSnapshot.database)).data().data
                            .mapValues(DataField::listValue)
                            .mapValues(l -> l.map(DataField::stringValue));

            SetF<String> mdsKeysToDownload = Cf.hashSet();

            if (filter.getCollectionIdCond().isWithIds()) {
                MapF<String, String> mdsKeysByCollection = Cf.hashMap();
                for (String mdsKey : collectionIdsByMdsKey.keys()) {
                    for (String collectionId : collectionIdsByMdsKey.getTs(mdsKey)) {
                        mdsKeysByCollection.put(collectionId, mdsKey);
                    }
                }

                for (String collectionId : filter.getCollectionIdCond().idsO.get()) {
                    mdsKeysToDownload.addAll(mdsKeysByCollection.getO(collectionId));
                }
            } else {
                mdsKeysToDownload = collectionIdsByMdsKey.keySet();
            }

            ListF<DataRecord> records = Cf.arrayList();

            for (String mdsKey : mdsKeysToDownload) {
                records.addAll(ProtobufDataUtils.deserializeSnapshot(ellipticsHelper.download(mdsKey), uid, handle)
                        .records());
            }

            rawSnapshot = new Snapshot(indexSnapshot.database, records);
        } else {
            rawSnapshot = ProtobufDataUtils.deserializeSnapshot(ellipticsHelper.download(key), uid, handle);
        }
        return MdsSnapshotProcessor.applyFilteringAndLimits(rawSnapshot, filter);
    }

    /**
     * Uploads the snapshot of {@code db} to MDS and stores its key in the reference, but only if the database is
     * MDS-snapshotable and a reference for this revision exists, has no key yet, and is not old.
     *
     * <p>When partitioning is enabled, collections (in sorted order) are packed into partitions of roughly
     * {@code partitionSize} records and an index blob is uploaded on top. If anything fails, every blob uploaded so
     * far (including the index) is deleted on a best-effort basis and the original exception is rethrown.
     *
     * @param getSnapshotFromDbF lazily loads the snapshot; only invoked when an upload is actually needed
     */
    public void saveSnapshotToMdsIfNecessary(DataApiUserId uid, Database db, Function0<Snapshot> getSnapshotFromDbF) {
        if (!SpecialDatabases.isMdsSnapshotable(db.dbRef())) {
            return;
        }

        Option<MdsSnapshotReference> snapshotRefO = mdsSnapshotReferenceJdbcDao.find(uid, db.handleValue(), db.rev);
        if (!snapshotRefO.isPresent()) {
            return;
        }

        MdsSnapshotReference snapshotReference = snapshotRefO.get();
        if (snapshotReference.mdsKey.isPresent() || isOldSnapshotReference(snapshotReference)) {
            return;
        }

        Snapshot snapshot = getSnapshotFromDbF.apply();

        MapF<String, ListF<DataRecord>> recordsByCollection =
                snapshot.records.records().groupBy(DataRecord::getCollectionId);

        ListF<DataRecord> recordsPart = Cf.arrayList();
        ListF<String> currentCollections = Cf.arrayList();
        MapF<String, ListF<String>> mdsKeys = Cf.hashMap();
        // Every blob uploaded so far (partitions and index); removed again if a later step fails.
        ListF<String> uploadedKeys = Cf.arrayList();

        try {
            for (String collectionId : recordsByCollection.keys().sorted()) {
                ListF<DataRecord> collectionRecords = recordsByCollection.getTs(collectionId);
                if (partitioningEnabled.get()
                        && recordsPart.isNotEmpty()
                        && recordsPart.size() + collectionRecords.size() > partitionSize)
                {
                    String partitionKey = saveSnapshotToMds(db, new Snapshot(db, recordsPart));
                    uploadedKeys.add(partitionKey);
                    mdsKeys.put(partitionKey, currentCollections);
                    recordsPart = Cf.arrayList();
                    currentCollections = Cf.arrayList();
                }
                recordsPart.addAll(collectionRecords);
                currentCollections.add(collectionId);
            }
            // Always upload at least one blob, even for an empty snapshot.
            if (mdsKeys.isEmpty() || recordsPart.isNotEmpty()) {
                String partitionKey = saveSnapshotToMds(db, new Snapshot(db, recordsPart));
                uploadedKeys.add(partitionKey);
                mdsKeys.put(partitionKey, currentCollections);
            }

            if (mdsKeys.size() == 1) {
                mdsSnapshotReferenceJdbcDao.updateMdsKey(mdsKeys.keys().first(), db);
            } else {
                MapF<String, DataField> data =
                        mdsKeys.mapValues(l -> DataField.list(l.map(DataField::string)));
                DataRecord indexRecord = new DataRecord(db.uid, getIndexRecordId(db), db.rev, data);

                byte[] indexData = ProtobufDataUtils.serialize(new Snapshot(db, Cf.list(indexRecord)));

                String indexKey = ellipticsHelper.upload(genMdsFilename(db), indexData).getKey();
                uploadedKeys.add(indexKey);

                mdsSnapshotReferenceJdbcDao.updateMdsKey(PARTITIONED_SNAPSHOT_PREFIX + indexKey, db);
            }
        } catch (RuntimeException e) {
            for (String mdsKey : uploadedKeys) {
                try {
                    ellipticsHelper.delete(mdsKey);
                } catch (Throwable e2) {
                    ExceptionUtils.throwIfUnrecoverable(e2);
                    logger.warn("Failed to delete garbage snapshot {}: {}", mdsKey, e2);
                }
            }

            throw e;
        }
    }

    /** Id of the single record inside an index snapshot that maps partition MDS keys to collection ids. */
    private DataRecordId getIndexRecordId(Database db) {
        return new DataRecordId(db.dbHandle, "index", "index_record");
    }

    /** Serializes {@code snapshot} and uploads it to MDS under a fresh random file name; returns the MDS key. */
    private String saveSnapshotToMds(Database db, Snapshot snapshot) {
        byte[] snapshotData = ProtobufDataUtils.serialize(snapshot);
        String filename = genMdsFilename(db);
        EllipticsUploadState uploadState = ellipticsHelper.upload(filename, snapshotData);
        return uploadState.getKey();
    }

    /** Builds a unique MDS file name from the handle, revision and a random 10-char alphanumeric suffix. */
    private String genMdsFilename(Database db) {
        return StringUtils.format(
                "handle_{}.rev_{}.{}.shapshot", db.handleValue(), db.rev, Random2.R.nextAlnum(10));
    }
}
