package ru.yandex.direct.tracing.real;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.direct.tracing.TraceInterceptor;
import ru.yandex.direct.tracing.TraceProfile;
import ru.yandex.direct.tracing.data.TraceDataProfile;
import ru.yandex.direct.tracing.util.TraceClockProvider;
import ru.yandex.direct.tracing.util.TraceUtil;

/**
 * Support for profiling functions
 */
/**
 * Support for profiling functions.
 *
 * <p>Every thread that uses the profiler gets its own {@code ThreadState} holding a stack of open
 * profiles and timings accumulated per (func, tags) pair. Besides explicit profiles, the profiler
 * tracks "rest" time: time during which a thread has been {@link #activate() activated} but is not
 * inside any open profile. {@link #snapshot()} drains the accumulated data of all threads into a
 * list of {@link TraceDataProfile}s.
 *
 * <p>A profile must be closed on the thread that opened it, and only the innermost open profile of
 * that thread can be closed.
 */
public class RealTraceProfiler {
    /** Total rest time (seconds, summed over all threads) that must be exceeded to report "rest". */
    private static final double MIN_REST_SECONDS = 0.005;
    private static final Logger logger = LoggerFactory.getLogger(RealTraceProfiler.class);

    /** Per-thread state; states of threads that have died are dropped by {@link #snapshot(long)}. */
    private final ConcurrentHashMap<Thread, ThreadState> threads = new ConcurrentHashMap<>();
    private final TraceClockProvider clock;
    /** Notified of every new profile unless interception is skipped; may be {@code null}. */
    private final TraceInterceptor traceInterceptor;
    /** Service/method names copied into each profile when it is created. */
    private volatile String service;
    private volatile String method;

    /** Aggregation key of a profile: function name plus tags, both non-null. */
    private static final class Key {
        private final String func;
        private final String tags;

        Key(String func, String tags) {
            if (func == null || tags == null) {
                throw new NullPointerException();
            }
            this.func = func;
            this.tags = tags;
        }

        @Override
        public int hashCode() {
            return 31 * func.hashCode() + tags.hashCode();
        }

        @Override
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            }
            if (other instanceof Key) {
                Key otherKey = (Key) other;
                return func.equals(otherKey.func) && tags.equals(otherKey.tags);
            }
            return false;
        }
    }

    /** Mutable accumulator for one key; elapsed times are in clock nanoseconds. */
    private static final class Entry {
        private long allEla;
        private long childrenEla;
        private long calls;
        private long objCount;
    }

    /** One open profiled section; nested sections form a stack through {@code parent}. */
    private final class Profile implements TraceProfile {
        private final Profile parent;
        private final Key key;
        private final long objCount;
        private final String service;
        private final String method;
        /** Start of the interval not yet moved into an {@link Entry}; advanced by each capture. */
        private long startTime;
        /** Time of children charged to this profile since its last capture. */
        private long childrenEla;
        /** Lazily created list of commands to run when the profile is closed. */
        private List<Runnable> onCloseCommands;

        Profile(Profile parent, Key key, String service, String method, long objCount, boolean skipInterception) {
            logger.trace("profile started for {}/{}", key.func, key.tags);
            this.parent = parent;
            this.key = key;
            this.objCount = objCount;
            this.service = service;
            this.method = method;
            startTime = clock.nanoTime();
            if (traceInterceptor != null && !skipInterception) {
                traceInterceptor.checkInterception(this);
            }
        }

        @Override
        public void addCommandOnClose(Runnable cmd) {
            if (onCloseCommands == null) {
                onCloseCommands = new ArrayList<>();
            }
            onCloseCommands.add(cmd);
        }

        /**
         * Moves the time elapsed since the last capture into {@code entry}, charging it to the
         * parent's children time as well. A non-partial capture (the profile is being closed) also
         * counts one call and this profile's object count.
         */
        void capture(Entry entry, long currentTime, boolean partial) {
            if (currentTime > startTime) {
                long ela = currentTime - startTime;
                if (parent != null) {
                    parent.childrenEla += ela;
                }
                entry.allEla += ela;
                entry.childrenEla += childrenEla;
                startTime = currentTime;
                childrenEla = 0;
            }
            if (!partial) {
                entry.calls += 1;
                entry.objCount += objCount;
            }
        }

        @Override
        public String getService() {
            return service;
        }

        @Override
        public String getMethod() {
            return method;
        }

        @Override
        public String getFunc() {
            return key.func;
        }

        @Override
        public String getTags() {
            return key.tags;
        }

        @Override
        public long getObjCount() {
            return objCount;
        }

        /**
         * Runs the registered on-close commands (their time counts toward this profile) and pops
         * this profile off the calling thread's stack.
         *
         * @throws UnsupportedOperationException if this profile is not the innermost open profile
         *     of the calling thread, e.g. it was already closed or is closed from another thread
         */
        @Override
        public void close() {
            // Validate first so that closing an inactive profile has no side effects.
            activeState();
            runOnCloseCommands();
            // A command may have opened a profile without closing it; never pop a foreign profile.
            ThreadState state = activeState();
            state.pop(clock.nanoTime());
            logger.trace("profile finished for {}/{}", key.func, key.tags);
        }

        /** Returns the calling thread's state, requiring this profile to be its innermost one. */
        private ThreadState activeState() {
            ThreadState state = threads.get(Thread.currentThread());
            if (state == null || state.current != this) {
                throw new UnsupportedOperationException("Trying to close profile that is not active");
            }
            return state;
        }

        /**
         * Runs each registered command at most once; failures are logged and do not stop the rest.
         * The list is detached before running, so commands registered by a command are run in a
         * following round instead of breaking the iteration.
         */
        private void runOnCloseCommands() {
            while (onCloseCommands != null) {
                List<Runnable> commands = onCloseCommands;
                onCloseCommands = null;
                for (Runnable cmd : commands) {
                    try {
                        cmd.run();
                    } catch (Exception e) {
                        logger.warn("Exception occurred during executing on close command", e);
                    }
                }
            }
        }
    }

    /** Aggregation target of one {@link #snapshot(long)} call. */
    private static final class FlushState {
        private final Map<Key, TraceDataProfile> profiles = new HashMap<>();
        private double restSeconds = 0;
    }

    /** Profiling state of one thread; its non-trivial operations are synchronized on the instance. */
    private final class ThreadState {
        /** Innermost open profile, or {@code null} when the thread is outside any profile. */
        private Profile current;
        /** Rest nanoseconds accumulated since the last flush. */
        private long restTime = 0;
        /** Start of the rest interval not yet added to {@code restTime}. */
        private long restStartTime;
        /** Whether rest time is tracked, i.e. between {@link #activate()} and {@link #deactivate()}. */
        private boolean restActive;
        /** Timings captured since the last flush. */
        private final Map<Key, Entry> accumulated = new HashMap<>();

        /** Captures {@code call} into its accumulator; called with this state's lock held. */
        void accumulate(Profile call, long currentTime, boolean partial) {
            Entry entry = accumulated.computeIfAbsent(call.key, k -> new Entry());
            call.capture(entry, currentTime, partial);
        }

        /** Starts tracking rest time; a repeated call keeps the already running rest interval. */
        synchronized void activate() {
            if (!restActive) {
                restStartTime = clock.nanoTime();
                restActive = true;
            }
        }

        /** Stops tracking rest time, first adding the current idle interval if one is running. */
        synchronized void deactivate() {
            // Without the restActive check, restStartTime may be stale (or never set), which would
            // add time to restTime that was never rest.
            if (restActive && current == null) {
                long currentTime = clock.nanoTime();
                restTime += currentTime - restStartTime;
                restStartTime = currentTime;
            }
            restActive = false;
        }

        /** Opens a nested profile; when the thread was idle, the idle interval is added as rest. */
        synchronized Profile push(Key key, long objCount, boolean skipInterception) {
            Profile call = new Profile(current, key, service, method, objCount, skipInterception);
            if (current == null && restActive) {
                restTime += call.startTime - restStartTime;
                restStartTime = call.startTime;
            }
            current = call;
            return call;
        }

        /** Closes the innermost profile; if the thread becomes idle, a new rest interval starts. */
        synchronized void pop(final long currentTime) {
            Profile call = current;
            current = call.parent;
            accumulate(call, currentTime, false);
            if (current == null && restActive) {
                restStartTime = currentTime;
            }
        }

        /**
         * Adds this thread's data into {@code flushState} and resets its accumulators. Still open
         * profiles and a running rest interval are included up to {@code currentTime}.
         */
        synchronized void flush(final FlushState flushState, final long currentTime) {
            Profile call = current;
            while (call != null) {
                accumulate(call, currentTime, true);
                call = call.parent;
            }
            for (Map.Entry<Key, Entry> kv : accumulated.entrySet()) {
                Key key = kv.getKey();
                TraceDataProfile profile =
                        flushState.profiles.computeIfAbsent(key, k -> new TraceDataProfile(k.func, k.tags));
                Entry entry = kv.getValue();
                profile.setAllEla(profile.getAllEla() + TraceUtil.secondsFromNanoseconds(entry.allEla));
                profile.setChildrenEla(profile.getChildrenEla() + TraceUtil.secondsFromNanoseconds(entry.childrenEla));
                profile.setCalls(profile.getCalls() + entry.calls);
                profile.setObjCount(profile.getObjCount() + entry.objCount);
            }
            accumulated.clear();
            if (current == null && restActive && currentTime > restStartTime) {
                flushState.restSeconds += TraceUtil.secondsFromNanoseconds(currentTime - restStartTime);
                restStartTime = currentTime;
            }
            flushState.restSeconds += TraceUtil.secondsFromNanoseconds(restTime);
            restTime = 0;
        }
    }

    public RealTraceProfiler() {
        this(TraceClockProvider.Default.instance());
    }

    public RealTraceProfiler(TraceClockProvider clock) {
        this(clock, null, null, null);
    }

    public RealTraceProfiler(TraceClockProvider clock, TraceInterceptor traceInterceptor, String service, String method) {
        this.clock = clock;
        this.traceInterceptor = traceInterceptor;
        this.service = service;
        this.method = method;
    }

    /** Returns the calling thread's state, creating it on first use. */
    private ThreadState currentState() {
        return threads.computeIfAbsent(Thread.currentThread(), k -> new ThreadState());
    }

    /** Starts counting idle time of the calling thread as "rest". */
    public void activate() {
        currentState().activate();
    }

    /** Stops counting idle time of the calling thread as "rest". */
    public void deactivate() {
        currentState().deactivate();
    }

    /** Opens a profile on the calling thread with interception enabled. */
    public TraceProfile profile(String func, String tags, long objCount) {
        return profile(func, tags, objCount, false);
    }

    /**
     * Opens a profile on the calling thread, nested in the thread's current profile if any.
     *
     * @throws NullPointerException if {@code func} or {@code tags} is null
     */
    public TraceProfile profile(String func, String tags, long objCount, boolean skipInterception) {
        Key key = new Key(func, tags);
        ThreadState state = currentState();
        return state.push(key, objCount, skipInterception);
    }

    /**
     * Drains the data accumulated by all threads up to {@code currentTime}. A "rest" entry is
     * appended only when the total rest time exceeds {@link #MIN_REST_SECONDS}.
     */
    public List<TraceDataProfile> snapshot(final long currentTime) {
        FlushState flushState = new FlushState();
        for (Map.Entry<Thread, ThreadState> kv : threads.entrySet()) {
            Thread thread = kv.getKey();
            ThreadState state = kv.getValue();
            // Checked before flushing: a thread already dead cannot add data afterwards, so its
            // state is fully drained by this flush and can be dropped without losing anything.
            boolean dead = !thread.isAlive();
            state.flush(flushState, currentTime);
            if (dead) {
                // Avoid retaining dead Thread objects and their states for the profiler's lifetime.
                threads.remove(thread, state);
            }
        }
        List<TraceDataProfile> profiles = new ArrayList<>(flushState.profiles.values());
        if (flushState.restSeconds > MIN_REST_SECONDS) {
            profiles.add(new TraceDataProfile("rest", "", flushState.restSeconds, 0, 1, 0));
        }
        return profiles;
    }

    /** Drains the accumulated data of all threads up to the current clock time. */
    public List<TraceDataProfile> snapshot() {
        return snapshot(clock.nanoTime());
    }

    /** Sets the service name; affects only profiles opened afterwards. */
    void setService(String service) {
        this.service = service;
    }

    /** Sets the method name; affects only profiles opened afterwards. */
    void setMethod(String method) {
        this.method = method;
    }
}
