package ru.yandex.ljinx;

import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;

import com.googlecode.concurrentlinkedhashmap.ConcurrentLinkedHashMap;
import com.googlecode.concurrentlinkedhashmap.EntryWeigher;
import org.apache.http.FormattedHeader;
import org.apache.http.Header;
import org.apache.http.HttpEntity;
import org.apache.http.concurrent.FutureCallback;

import ru.yandex.http.util.nio.BasicAsyncResponseProducerGenerator;
import ru.yandex.http.util.request.function.RequestFunctionValue;
import ru.yandex.http.util.server.ImmutableBaseServerConfig;
import ru.yandex.stater.StatsConsumer;
import ru.yandex.util.string.StringUtils;
import ru.yandex.util.timesource.TimeSource;

/**
 * In-process cache storage that keeps proxied responses in a bounded
 * {@link ConcurrentLinkedHashMap}.
 *
 * <p>The map's capacity is a weighted capacity: every entry is weighed by
 * {@link Weigher}, which estimates the memory footprint of the key, the
 * response generator, its headers and its entity body (string lengths are
 * counted at two bytes per char and bodies by their content length, so the
 * unit is presumably bytes). Each entry carries an absolute expiration
 * timestamp in milliseconds; expired entries are removed lazily when they are
 * looked up.
 */
public class MemoryCacheStorage extends AbstractCacheStorage {
    /** Estimated cost of a single reference field. */
    private static final int MEMBER_WEIGHT = 8;
    /** Flat estimated overhead charged per object. */
    private static final int OBJECT_WEIGHT = 48;
    /** Fixed overhead of a response generator plus its status line. */
    private static final int BASE_WEIGHT =
        //BasicAsyncResponseProducerGenerator
        OBJECT_WEIGHT + MEMBER_WEIGHT + MEMBER_WEIGHT + MEMBER_WEIGHT
        //BasicAsyncResponseProducerGenerator.StatusLine
        + OBJECT_WEIGHT + MEMBER_WEIGHT + MEMBER_WEIGHT + MEMBER_WEIGHT;

    private final ConcurrentLinkedHashMap<
        RequestFunctionValue, CacheEntry> cache;
    /** TTL in milliseconds used when a session does not specify one. */
    private final long defaultTTL;

    /**
     * Creates a memory cache sized by {@code config.capacity()} (weighted).
     *
     * @param storageName name used as a prefix for stats keys
     * @param serverConfig server configuration passed to the base storage
     * @param config memory cache configuration (capacity and default TTL)
     */
    public MemoryCacheStorage(
        final String storageName,
        final ImmutableBaseServerConfig serverConfig,
        final MemoryCacheStorageConfig config)
    {
        super(storageName, serverConfig);
        cache = new ConcurrentLinkedHashMap.Builder
            <RequestFunctionValue, CacheEntry>()
            .concurrencyLevel(Runtime.getRuntime().availableProcessors())
            .maximumWeightedCapacity(config.capacity())
            .weigher(new Weigher())
            .build();
        defaultTTL = config.defaultTTL();
    }

    /**
     * @return TTL in milliseconds applied to entries whose session has no
     *     explicit TTL
     */
    public long defaultTTL() {
        return defaultTTL;
    }

    /**
     * Drops all cached entries by shrinking the map's capacity to zero.
     */
    @Override
    public void close() throws IOException {
        // NOTE(review): capacity stays 0 afterwards, so presumably any later
        // put is evicted right away — confirm this is the intended shutdown.
        cache.setCapacity(0);
    }

    /**
     * Looks up a cached response, invalidating it if it has expired.
     *
     * @param session session used for logging
     * @param key cache key
     * @return the cached response, or {@code null} on a miss or if the entry
     *     had expired
     */
    public CacheResponse get(
        final ProxyPassSession session,
        final RequestFunctionValue key)
    {
        final CacheResponse value;
        final CacheEntry entry = cache.get(key);
        if (entry != null) {
            long now = TimeSource.INSTANCE.currentTimeMillis();
            if (entry.expireTimeStamp() < now) {
                // Conditional remove: don't drop a fresher entry that a
                // concurrent put may have stored under the same key.
                cache.remove(key, entry);
                session.logger().info(
                    "Cache entry for key <"
                    + key + "> is expired at "
                    + entry.expireTimeStamp() + ", invalidating");
                value = null;
            } else {
                value =
                    new CacheResponse(
                        entry.value(),
                        CacheResponse.CacheType.MEMORY,
                        entry.expireTimeStamp());
            }
        } else {
            value = null;
        }
        return value;
    }

    /**
     * Looks up the session's cache key and completes {@code callback}
     * synchronously with the result ({@code null} on a miss).
     */
    @Override
    public void get(
        final ProxyPassSession session,
        final FutureCallback<CacheResponse> callback)
    {
        callback.completed(get(session, session.cacheKey()));
    }

    /**
     * Stores {@code value} under {@code key} until {@code expireTimeStamp}.
     *
     * @param key cache key
     * @param value response generator to cache
     * @param expireTimeStamp absolute expiration time in epoch milliseconds
     */
    protected void put(
        final RequestFunctionValue key,
        final BasicAsyncResponseProducerGenerator value,
        final long expireTimeStamp)
    {
        cache.put(key, new CacheEntry(value, expireTimeStamp));
    }

    /**
     * Caches {@code value} under the session's cache key using the session
     * TTL, falling back to {@link #defaultTTL()}, and completes
     * {@code callback} synchronously.
     */
    @Override
    public void put(
        final ProxyPassSession session,
        final BasicAsyncResponseProducerGenerator value,
        final FutureCallback<Void> callback)
    {
        final Long sessionTTL = session.ttl();
        final long ttl;
        if (sessionTTL == null) {
            ttl = defaultTTL;
        } else {
            ttl = sessionTTL;
        }
        put(
            session.cacheKey(),
            value,
            saturatedAdd(TimeSource.INSTANCE.currentTimeMillis(), ttl));
        callback.completed(null);
    }

    /**
     * Removes the session's cache key and completes {@code callback}.
     */
    @Override
    public void remove(
        final ProxyPassSession session,
        final FutureCallback<Void> callback)
    {
        cache.remove(session.cacheKey());
        callback.completed(null);
    }

    @Override
    public Map<String, Object> status(final boolean verbose) {
        Map<String, Object> status = new LinkedHashMap<>();
        status.put("capacity", cache.capacity());
        status.put("element-count", cache.size());
        status.put("weighted-size", cache.weightedSize());
        return status;
    }

    @Override
    public <E extends Exception> void stats(
        final StatsConsumer<? extends E> statsConsumer)
        throws E
    {
        super.stats(statsConsumer);
        statsConsumer.stat(
            StringUtils.concat(storageName, "-cache-capacity_ammv"),
            cache.capacity());
        statsConsumer.stat(
            StringUtils.concat(storageName, "-cache-element-count_ammv"),
            cache.size());
        statsConsumer.stat(
            StringUtils.concat(storageName, "-cache-weighted-size_ammv"),
            cache.weightedSize());
    }

    /**
     * Returns {@code a + b}, saturating at {@link Long#MAX_VALUE} /
     * {@link Long#MIN_VALUE} instead of wrapping around. This keeps a very
     * large TTL from turning into an expiration time in the past.
     */
    private static long saturatedAdd(final long a, final long b) {
        final long sum = a + b;
        // Overflow happened iff both operands share a sign that the sum lacks
        if (((a ^ sum) & (b ^ sum)) < 0) {
            if (b > 0) {
                return Long.MAX_VALUE;
            } else {
                return Long.MIN_VALUE;
            }
        }
        return sum;
    }

    /**
     * Estimates the memory footprint of a cache entry. Weights are summed in
     * {@code long} and clamped to {@code [1, Integer.MAX_VALUE]}: the map
     * requires every weight to be at least 1 and stores it as an {@code int}.
     */
    private static class Weigher
        implements EntryWeigher<RequestFunctionValue, CacheEntry>
    {
        @Override
        public int weightOf(
            final RequestFunctionValue key,
            final CacheEntry value)
        {
            long size = OBJECT_WEIGHT;
            size += key.weight();
            if (value != null) {
                size += weight(value.value());
            }
            return clamp(size);
        }

        /** Narrows a non-negative estimate to the map's {@code int} range. */
        private static int clamp(final long weight) {
            if (weight < 1L) {
                return 1;
            }
            if (weight > Integer.MAX_VALUE) {
                return Integer.MAX_VALUE;
            }
            return (int) weight;
        }

        /** Estimates a response generator: status line, headers and body. */
        private long weight(final BasicAsyncResponseProducerGenerator value) {
            long weight = BASE_WEIGHT;
            if (value.headers() != null) {
                for (Header header : value.headers()) {
                    weight += weight(header);
                }
            }
            if (value.entityGenerator() != null) {
                HttpEntity entity = value.entityGenerator().get();
                weight += OBJECT_WEIGHT << 1;
                // getContentLength() is negative when the length is unknown;
                // count nothing rather than subtracting from the estimate.
                weight += Math.max(0L, entity.getContentLength());
                weight += weight(entity.getContentType());
                weight += weight(entity.getContentEncoding());
            }
            return weight;
        }

        /** Estimates a single header; a {@code null} header costs BASE_WEIGHT. */
        private long weight(final Header header) {
            // NOTE(review): every header (even a missing one) is charged
            // BASE_WEIGHT, the generator + status line overhead. This is kept
            // to preserve existing capacity accounting — confirm intent.
            long weight = BASE_WEIGHT;
            if (header == null) {
                return weight;
            }
            weight += OBJECT_WEIGHT + ((long) header.getName().length() << 1);
            if (header instanceof FormattedHeader) {
                weight += OBJECT_WEIGHT + OBJECT_WEIGHT
                    + ((long) ((FormattedHeader) header).getBuffer().capacity()
                        << 1);
            } else {
                weight += OBJECT_WEIGHT;
                // A plain header's value may be null
                final String headerValue = header.getValue();
                if (headerValue != null) {
                    weight += (long) headerValue.length() << 1;
                }
            }
            return weight;
        }
    }

    /**
     * Immutable cached value: a response generator plus the absolute time
     * (epoch milliseconds) after which it is considered expired.
     */
    public static class CacheEntry {
        private final BasicAsyncResponseProducerGenerator value;
        private final long expireTimeStamp;

        CacheEntry(
            final BasicAsyncResponseProducerGenerator value,
            final long expireTimeStamp)
        {
            this.value = value;
            this.expireTimeStamp = expireTimeStamp;
        }

        /** @return the cached response generator */
        public BasicAsyncResponseProducerGenerator value() {
            return value;
        }

        /** @return expiration time in epoch milliseconds */
        public long expireTimeStamp() {
            return expireTimeStamp;
        }
    }
}
