package ru.yandex.chemodan.app.docviewer.copy.downloader;

import java.net.URI;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import lombok.EqualsAndHashCode;

import ru.yandex.chemodan.app.docviewer.copy.StorageResourceInfo;
import ru.yandex.chemodan.app.docviewer.copy.TempFileInfo;
import ru.yandex.chemodan.app.docviewer.states.MaxFileSizeChecker;
import ru.yandex.chemodan.app.docviewer.storages.FileLink;
import ru.yandex.misc.log.mlf.Logger;
import ru.yandex.misc.log.mlf.LoggerFactory;

/**
 * {@link FileDownloader} decorator that serializes concurrent downloads of the same file.
 *
 * <p>Serialization is applied only when the delegate is a {@link CachingFileDownloader}, the file
 * needs caching, and it is not cached yet (see {@link #needSynchronization}). In that case callers
 * downloading the same {@link FileLink} run the delegate one at a time. Later callers can then
 * presumably be served from the cache filled by the first one. That depends on the delegate's
 * caching behaviour, which is not visible here.
 *
 * <p>Per-link locks are tracked together with the number of threads holding or waiting for them.
 * An entry is removed as soon as that number drops to zero.
 */
public class SynchronizedFileDownloader<T extends FileToDownload> implements FileDownloader<T> {
    private static final Logger logger = LoggerFactory.getLogger(SynchronizedFileDownloader.class);
    private final FileDownloader<T> delegate;

    /** Per-link lock snapshots; an entry exists only while at least one thread is registered on it. */
    private final ConcurrentMap<FileLink, Mutex> downloadingFiles = new ConcurrentHashMap<>();

    /**
     * @param delegate downloader that performs the actual work
     */
    public SynchronizedFileDownloader(FileDownloader<T> delegate) {
        this.delegate = delegate;
    }

    @Override
    public T buildFileToDownload(URI uri, StorageResourceInfo copierResponse) {
        return delegate.buildFileToDownload(uri, copierResponse);
    }

    /**
     * Downloads the file through the delegate.
     *
     * <p>When {@link #needSynchronization} returns {@code true}, the download is serialized with
     * other downloads of the same {@link FileLink}.
     */
    @Override
    public TempFileInfo download(T fileToDownload, MaxFileSizeChecker sizeChecker) {
        FileLink fileLink = fileToDownload.toFileLink();
        if (!needSynchronization(fileToDownload, fileLink)) {
            return delegate.download(fileToDownload, sizeChecker);
        }

        return downloadWithSynchronization(fileToDownload, fileLink, sizeChecker);
    }

    /**
     * Runs the delegate download while holding the per-link lock for {@code fileLink}.
     *
     * <p>The thread registers itself before locking and always unregisters in {@code finally},
     * even if the delegate throws.
     */
    private TempFileInfo downloadWithSynchronization(
            T fileToDownload, FileLink fileLink, MaxFileSizeChecker sizeChecker)
    {
        // Atomically reuse the link's lock (or create one) and bump the number of registered threads.
        final Mutex mutex = downloadingFiles.compute(fileLink,
                (fLink, existing) -> Mutex.increment(existing == null ? new Mutex() : existing));

        // mutex.counter is the count at registration time; > 1 means other threads registered first.
        if (mutex.counter > 1) {
            logger.info("Counter = {}, starting synchronization for {}", mutex.counter, fileLink.getSerializedPath());
        }

        try {
            synchronized (mutex.lock) {
                if (mutex.counter > 1) {
                    logger.info("Counter = {}, start downloading {}", mutex.counter, fileLink.getSerializedPath());
                }
                return delegate.download(fileToDownload, sizeChecker);
            }
        } finally {
            // Atomically unregister. Returning null from the remapping function removes the entry
            // once no thread holds or waits for the lock. A later caller then gets a fresh lock.
            downloadingFiles.computeIfPresent(fileLink, (lnk, lck) -> {
                Mutex newValue = Mutex.decrement(lck);
                if (newValue.counter > 0) {
                    logger.info("Some threads waiting for {}, counter = {}. Mutex will not be deleted from map",
                            lnk.getSerializedPath(), newValue.counter);
                }

                return newValue.counter == 0 ? null : newValue;
            });
        }
    }

    /**
     * Decides whether this download has to be serialized with other downloads of the same link.
     *
     * @return {@code true} only if the delegate is a {@link CachingFileDownloader}, the file needs
     *         caching, and {@code fileLink} is not cached yet
     */
    private boolean needSynchronization(T fileToDownload, FileLink fileLink) {
        if (!(delegate instanceof CachingFileDownloader)) {
            // Without caching, a one-by-one download gains nothing: every caller downloads anyway.
            return false;
        }
        // The cast is unchecked because of erasure. delegate is a FileDownloader<T>, so its
        // CachingFileDownloader view presumably shares the same T.
        // Assumes CachingFileDownloader<X> implements FileDownloader<X> — confirm.
        @SuppressWarnings("unchecked")
        CachingFileDownloader<T> cachingDownloader = (CachingFileDownloader<T>) delegate;

        // Synchronize only cacheable downloads that have no cached value yet.
        return cachingDownloader.isNeedCaching(fileToDownload) && !cachingDownloader.isCached(fileLink);
    }

    /**
     * Immutable snapshot of a per-link lock: the shared monitor plus the number of registered threads.
     *
     * <p>The snapshot is immutable, so {@code counter} can be read safely outside the map's atomic
     * {@code compute} calls. {@link #increment} and {@link #decrement} return a new snapshot that
     * shares the same monitor.
     */
    @EqualsAndHashCode
    private static class Mutex {
        /** Threads holding or waiting for {@link #lock} when this snapshot was created. */
        private final int counter;
        /** Monitor shared by every snapshot created for the same link. */
        private final Object lock;

        /** Creates a snapshot with a fresh monitor and no registered threads. */
        Mutex() {
            this(new Object(), 0);
        }

        private Mutex(Object lock, int counter) {
            this.lock = lock;
            this.counter = counter;
        }

        /** Returns a snapshot with the same monitor and one more registered thread. */
        static Mutex increment(Mutex other) {
            return new Mutex(other.lock, other.counter + 1);
        }

        /** Returns a snapshot with the same monitor and one fewer registered thread. */
        static Mutex decrement(Mutex other) {
            return new Mutex(other.lock, other.counter - 1);
        }
    }
}
