package stampdb

import (
	"bufio"
	"bytes"
	"context"
	"fmt"
	"os"
	"path/filepath"
	"sync"
	"time"

	"github.com/klauspost/compress/zstd"
	"github.com/mailru/easyjson/jwriter"
	"go.uber.org/zap"
)

// inflightCap is the initial capacity hint for the in-flight entries map
// (128 Ki entries).
const inflightCap = 128 * 1024

// Writer tracks in-flight entries in memory and, every flushPeriod, dumps them
// together with recently closed records into a zstd-compressed,
// newline-delimited JSON snapshot in outDir. All mutable state is guarded by mu.
type Writer struct {
	// outDir holds the timestamped snapshots and the current.json.zst symlink.
	outDir      string
	// deletesTTL is how long closed records keep being copied into new
	// snapshots; divided by flushPeriod it yields maxDeletes.
	deletesTTL  time.Duration
	// flushPeriod is the interval between dumps; dump times are aligned to
	// multiples of it.
	flushPeriod time.Duration
	// maxAssoc, when non-zero, makes AddEntry ignore entries whose AssocTS is
	// at least this old.
	maxAssoc    time.Duration
	// maxDeletes is the number of delete batches (one per dump) retained.
	maxDeletes  int
	// mu guards inFlight, deletes and writer.
	mu          sync.Mutex
	log         *zap.Logger
	// inFlight maps a key to its currently open entry.
	inFlight    map[string]Entry
	// deletes holds serialized batches of closed records, oldest first.
	deletes     [][]byte
	// writer accumulates records closed since the last dump (via store);
	// dump also uses it to serialize the in-flight entries.
	writer      jwriter.Writer
	// done is closed when the background loop exits.
	done        chan struct{}
	// ctx and cancelFn stop the background loop; Shutdown calls cancelFn.
	ctx         context.Context
	cancelFn    context.CancelFunc
}

// NewWriter creates a Writer that stores snapshots in outDir and starts its
// background dump loop. Options are applied on top of the defaults (deletesTTL
// of one hour, flushPeriod of one minute, no-op logger). Call Shutdown to stop
// the loop.
//
// An error is returned when the options leave a non-positive flush period:
// such a period would otherwise cause a division by zero (or a negative
// retention window) below and a busy dump loop.
func NewWriter(outDir string, opts ...WriterOption) (*Writer, error) {
	ctx, cancel := context.WithCancel(context.Background())
	db := &Writer{
		done:        make(chan struct{}),
		ctx:         ctx,
		cancelFn:    cancel,
		deletesTTL:  time.Hour,
		flushPeriod: time.Minute,
		outDir:      outDir,
		log:         zap.NewNop(),
		inFlight:    make(map[string]Entry, inflightCap),
		writer: jwriter.Writer{
			Flags: jwriter.NilSliceAsEmpty | jwriter.NilMapAsEmpty,
		},
	}

	for _, opt := range opts {
		opt(db)
	}

	if db.flushPeriod <= 0 {
		cancel()
		return nil, fmt.Errorf("flush period must be positive, got %s", db.flushPeriod)
	}

	// One delete batch is kept per flush period covered by deletesTTL. Keep at
	// least one, otherwise the batch collected during the current period is
	// evicted in the same call that creates it and never reaches a snapshot.
	db.maxDeletes = int(db.deletesTTL / db.flushPeriod)
	if db.maxDeletes < 1 {
		db.maxDeletes = 1
	}

	go db.loop()
	return db, nil
}

// AddEntry records info as the in-flight entry for key.
//
// The entry is dropped (and logged) when isOldRecord rejects its AssocTS, or
// when the entry already stored for key has a strictly newer AssocTS. An equal
// AssocTS replaces the stored entry.
func (w *Writer) AddEntry(key string, info Entry) {
	w.mu.Lock()
	defer w.mu.Unlock()

	if w.isOldRecord(info.AssocTS) {
		w.log.Info("ignore key: too old",
			zap.String("key", key),
			zap.Any("info", info),
		)
		return
	}

	current, exists := w.inFlight[key]
	if exists && current.AssocTS > info.AssocTS {
		w.log.Info("ignore key: we have newest record",
			zap.String("key", key),
			zap.Any("our", current),
			zap.Any("their", info),
		)
		return
	}

	w.inFlight[key] = info
}

// DelEntry closes the in-flight entry for key at ts: the entry is serialized
// with ts as its end timestamp and removed from the in-flight set.
//
// A delete for an unknown key, or one coming from a hostname other than the
// entry's own, is logged and ignored. A serialization error is logged, and the
// entry is removed regardless.
func (w *Writer) DelEntry(ts int64, hostname string, key string) {
	w.mu.Lock()
	defer w.mu.Unlock()

	entry, found := w.inFlight[key]
	switch {
	case !found:
		w.log.Info("missing key delete", zap.Int64("ts", ts), zap.String("key", key))
		return
	case entry.Hostname != hostname:
		w.log.Info("delete from another radius server is ignored",
			zap.Int64("ts", ts),
			zap.String("key", key),
			zap.String("our_hostname", entry.Hostname),
			zap.String("del_hostname", hostname),
		)
		return
	}

	err := w.store(ts, entry)
	if err != nil {
		w.log.Error("unable to store", zap.String("key", key), zap.Error(err))
	}
	delete(w.inFlight, key)
}

// Cleanup closes, at ts, every in-flight entry owned by hostname whose Epoch is
// older than epoch. Each such entry is serialized and removed from the
// in-flight set; serialization errors are logged.
func (w *Writer) Cleanup(epoch, ts int64, hostname string) {
	w.mu.Lock()
	defer w.mu.Unlock()

	for key, entry := range w.inFlight {
		stale := entry.Hostname == hostname && entry.Epoch < epoch
		if !stale {
			continue
		}

		w.log.Info("remove stale record",
			zap.String("user", entry.Username),
			zap.String("key", key),
		)
		if err := w.store(ts, entry); err != nil {
			w.log.Error("unable to store", zap.String("key", key), zap.Error(err))
		}
		// Deleting the current key while ranging over a map is allowed in Go.
		delete(w.inFlight, key)
	}
}

// Restore replaces the in-memory state with the contents of the
// current.json.zst snapshot in outDir.
//
// Lines containing `"to_ts"` are closed records. They are kept verbatim as a
// single delete batch. Every other line is decoded as an in-flight entry keyed
// by its IP, or by its MAC when the IP is empty; lines with neither are logged
// and skipped. Restored entries get the current time in Unix milliseconds as
// their Epoch.
//
// A missing snapshot is not an error and leaves the Writer empty. On any error
// the Writer is also left empty, never half-restored.
func (w *Writer) Restore() error {
	w.mu.Lock()
	defer w.mu.Unlock()

	w.inFlight = make(map[string]Entry, inflightCap)
	w.deletes = w.deletes[:0]
	dbFile, err := os.Open(filepath.Join(w.outDir, "current.json.zst"))
	if err != nil {
		if os.IsNotExist(err) {
			return nil
		}

		return fmt.Errorf("unable to open snapshot: %w", err)
	}
	defer func() { _ = dbFile.Close() }()

	zstdR, err := zstd.NewReader(dbFile)
	if err != nil {
		return fmt.Errorf("unable to create zstd reader: %w", err)
	}
	defer func() { zstdR.Close() }()

	deleteMarker := []byte(`"to_ts"`)
	epoch := time.Now().UnixMilli()
	// Entries are collected separately and published only once the whole
	// snapshot has been read successfully.
	restored := make(map[string]Entry, inflightCap)
	scanner := bufio.NewScanner(zstdR)
	var deletes bytes.Buffer
	var deletesCount, inFlightCount int
	for scanner.Scan() {
		line := scanner.Bytes()
		if bytes.Contains(line, deleteMarker) {
			_, _ = deletes.Write(line)
			_ = deletes.WriteByte('\n')
			deletesCount++
			continue
		}

		var entry ExportedEntry
		if err := entry.UnmarshalJSON(line); err != nil {
			return fmt.Errorf("unable to unmarshal line %q: %w", line, err)
		}

		var key string
		switch {
		case entry.IP != "":
			key = entry.IP
		case entry.MAC != "":
			key = entry.MAC
		default:
			w.log.Error("trying to restore unexpected entry", zap.Any("entry", entry))
			continue
		}

		inFlightCount++
		restored[key] = Entry{
			Epoch:    epoch,
			AssocTS:  entry.FromTS,
			IP:       entry.IP,
			MAC:      entry.MAC,
			Username: entry.Username,
			Hostname: entry.Host,
			Via:      entry.Via.Uint8(),
		}
	}

	if err := scanner.Err(); err != nil {
		return fmt.Errorf("read failed: %w", err)
	}

	w.inFlight = restored
	w.deletes = append(w.deletes, deletes.Bytes())
	w.log.Info("restored", zap.Int("deletes", deletesCount), zap.Int("in_flights", inFlightCount))
	return nil
}

// Shutdown cancels the background dump loop. It then waits until the loop has
// exited or ctx is done, whichever happens first, and returns ctx.Err() in the
// latter case. No final dump is performed.
func (w *Writer) Shutdown(ctx context.Context) error {
	w.cancelFn()

	select {
	case <-w.done:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// store serializes entry as a record ending at ts and appends it, newline
// terminated, to w.writer. dump passes 0 as ts for in-flight entries. Every
// caller in this file holds w.mu. The returned error is always nil.
func (w *Writer) store(ts int64, entry Entry) error {
	record := ExportedEntry{
		FromTS:   entry.AssocTS,
		ToTS:     ts,
		Host:     entry.Hostname,
		IP:       entry.IP,
		MAC:      entry.MAC,
		Username: entry.Username,
		Via:      Via(entry.Via),
	}
	record.MarshalEasyJSON(&w.writer)
	w.writer.RawByte('\n')

	return nil
}

// loop is started by NewWriter in its own goroutine. It calls dump at every
// multiple of flushPeriod until w.ctx is cancelled, and closes w.done on exit.
func (w *Writer) loop() {
	defer close(w.done)

	for {
		nextDump := time.Now().Add(w.flushPeriod).Truncate(w.flushPeriod)
		tick := time.After(time.Until(nextDump))

		select {
		case <-w.ctx.Done():
			return
		case <-tick:
		}

		w.log.Info("starts DB dump")
		dumpPath, err := w.dump()
		if err != nil {
			w.log.Warn("dump failed", zap.Error(err))
			continue
		}
		w.log.Info("DB dumped", zap.String("db_path", dumpPath))
	}
}

// dump writes a snapshot of the Writer's state to a new file in outDir and
// points the current.json.zst symlink at it. It returns the snapshot's path.
//
// The snapshot contains, in order:
//   - the retained delete batches, including records closed since the
//     previous dump;
//   - one record per in-flight entry, serialized with a zero end timestamp.
//
// The file is written under a ".tmp" name, synced and closed, and only then
// renamed into place. dump holds w.mu for its whole duration.
func (w *Writer) dump() (string, error) {
	w.mu.Lock()
	defer w.mu.Unlock()

	name := dbName(time.Now())
	dbPath := filepath.Join(w.outDir, name)
	tmpPath := dbPath + ".tmp"
	file, err := os.Create(tmpPath)
	if err != nil {
		return "", fmt.Errorf("unable to create temporary file: %w", err)
	}
	defer func() {
		// Best effort: on success the file is already closed and renamed away,
		// so both calls fail harmlessly.
		_ = file.Close()
		_ = os.Remove(tmpPath)
	}()

	// Moves the records closed since the previous dump out of w.writer into a
	// new delete batch.
	if err := w.lockedProcessDeletes(); err != nil {
		return "", fmt.Errorf("unable to process deletes: %w", err)
	}

	// From here on w.writer only holds this dump's in-flight records. Always
	// drain it before returning. Otherwise, on a failure below, the next
	// lockedProcessDeletes call would take those records for closed ones and
	// keep re-emitting them as a delete batch in later snapshots.
	defer func() {
		w.writer.Error = nil
		_ = w.writer.Buffer.BuildBytes() // discard whatever is left
	}()

	for key, entry := range w.inFlight {
		if err := w.store(0, entry); err != nil {
			return "", fmt.Errorf("unable to store key %q: %w", key, err)
		}
	}

	enc, err := zstd.NewWriter(file)
	if err != nil {
		return "", fmt.Errorf("unable to create zstd encoder: %w", err)
	}

	for _, d := range w.deletes {
		if _, err := enc.Write(d); err != nil {
			_ = enc.Close()
			return "", fmt.Errorf("unable to dump deletes: %w", err)
		}
	}

	if _, err := w.writer.DumpTo(enc); err != nil {
		_ = enc.Close()
		return "", fmt.Errorf("unable to dump db: %w", err)
	}

	if err := enc.Close(); err != nil {
		return "", fmt.Errorf("unable to close encoder: %w", err)
	}

	// Flush to stable storage and surface write-back errors before the
	// snapshot becomes visible under its final name.
	if err := file.Sync(); err != nil {
		return "", fmt.Errorf("unable to sync db: %w", err)
	}
	if err := file.Close(); err != nil {
		return "", fmt.Errorf("unable to close db: %w", err)
	}

	if err := os.Rename(tmpPath, dbPath); err != nil {
		return "", fmt.Errorf("unable to rename db: %w", err)
	}

	currentDBPath := filepath.Join(w.outDir, "current.json.zst")
	tmpLinkPath := currentDBPath + ".tmp"
	// A link left behind by an earlier failed dump would make Symlink fail
	// with "file exists" on every subsequent dump.
	if err := os.Remove(tmpLinkPath); err != nil && !os.IsNotExist(err) {
		return "", fmt.Errorf("unable to remove stale symlink: %w", err)
	}
	// A relative link target is resolved against the link's own directory
	// (outDir), so link by the name relative to outDir rather than by dbPath.
	// dbPath would break whenever outDir itself is relative.
	if err := os.Symlink(name, tmpLinkPath); err != nil {
		return "", fmt.Errorf("unable to create symlink: %w", err)
	}
	if err := os.Rename(tmpLinkPath, currentDBPath); err != nil {
		return "", fmt.Errorf("unable to switch current symlink: %w", err)
	}

	return dbPath, nil
}

// lockedProcessDeletes moves the records accumulated in w.writer since the
// previous call into a new batch appended to w.deletes. It then evicts the
// oldest batches so that at most maxDeletes remain (none when maxDeletes is not
// positive). The caller must hold w.mu.
func (w *Writer) lockedProcessDeletes() error {
	batch, err := w.writer.BuildBytes()
	if err != nil {
		return err
	}

	w.deletes = append(w.deletes, batch)

	keep := w.maxDeletes
	if keep < 0 {
		keep = 0
	}
	// Enforce the cap even if the ring is more than one batch over it.
	if excess := len(w.deletes) - keep; excess > 0 {
		// Drop references to evicted batches. Otherwise the shared backing
		// array keeps them alive until append next reallocates it.
		for i := 0; i < excess; i++ {
			w.deletes[i] = nil
		}
		w.deletes = w.deletes[excess:]
	}
	return nil
}

// isOldRecord reports whether ts, taken as Unix seconds, lies at least
// maxAssoc in the past. A zero maxAssoc disables the check.
func (w *Writer) isOldRecord(ts int64) bool {
	if w.maxAssoc == 0 {
		return false
	}
	return time.Since(time.Unix(ts, 0)) >= w.maxAssoc
}
