package diskqueue

import (
	"bufio"
	"bytes"
	"context"
	"encoding/binary"
	"fmt"
	"io"
	"math/rand"
	"os"
	"path"
	"sync/atomic"
	"time"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/security/csp-report/internal/logbroker/message"
)

const (
	// defaultMaxMsgSize caps the marshalled size of a single message (10 KiB).
	defaultMaxMsgSize = 10 * 1024

	// defaultMaxBytesPerFile is the size after which writing rolls over to the
	// next data file (50 MiB). The check happens after a write, so a file may
	// end slightly larger than this.
	defaultMaxBytesPerFile = 50 * 1024 * 1024
)

// This DiskQueue implementation is heavily based on https://github.com/nsqio/go-diskqueue

// DiskQueue is a FIFO queue of *message.Message values. The messages are stored
// in numbered data files inside dataPath, and the read/write positions are kept
// in a separate metadata file. A background goroutine (ioLoop, started by New)
// performs all file I/O. Producers call Put, consumers receive from ReadChan,
// and Close shuts the goroutine down.
type DiskQueue struct {
	// internal files and buffers
	readFile    *os.File
	writeFile   *os.File
	reader      *bufio.Reader
	// msgReadBuf receives the payload of the message being read (readOne);
	// msgWriteBuf receives the marshalled message being written (writeOne);
	// writeBuf assembles the length prefix + payload so that writeOne issues a
	// single Write call per message.
	msgReadBuf  bytes.Buffer
	msgWriteBuf bytes.Buffer
	writeBuf    bytes.Buffer

	// run-time state (also persisted to disk)
	// NOTE(review): depth and inflight are accessed with sync/atomic 64-bit
	// functions; on 32-bit platforms those need 64-bit aligned fields, so
	// re-check alignment before reordering this struct.
	readPos      int64
	writePos     int64
	readFileNum  int64
	writeFileNum int64
	depth        int64

	// instantiation time metadata
	dataPath        string
	maxBytesPerFile int64 // currently this cannot change once created
	// inflight is a runtime counter (not configuration): incremented for every
	// written message and decremented for every delivered one, like depth, but
	// it is not reset by corruption recovery and is not written to the
	// metadata file. Reported by Size.
	inflight        int64
	// seqNo is a runtime counter: ioLoop increments it for each accepted write
	// and stamps the new value on the message. SetSeqNo sets its start value.
	seqNo           uint64
	// maxMsgSize limits the marshalled message size; 0 disables the check.
	maxMsgSize      uint32
	syncEvery       uint32        // number of writes per fsync
	syncTimeout     time.Duration // duration of time per fsync
	// needSync asks ioLoop to persist the metadata on its next iteration;
	// sync clears it on success.
	needSync        bool

	// keeps track of the position where we have read
	// (but not yet sent over readChan)
	nextReadPos     int64
	nextReadFileNum int64

	// exposed via ReadChan(); closed by Close
	readChan chan *message.Message

	// internal channels
	// writeChan carries messages from Put to ioLoop.
	writeChan    chan *message.Message
	// exitChan is closed by Close to stop ioLoop and to make Put return ErrClosed.
	exitChan     chan int
	// exitSyncChan receives a value from ioLoop when it exits; Close waits on it.
	exitSyncChan chan int

	logger log.Structured
}

// MustNew is like New but panics with the error message when the queue cannot
// be created. Intended for program initialization.
func MustNew(dataPath string, opts ...Option) *DiskQueue {
	q, err := New(dataPath, opts...)
	if err == nil {
		return q
	}
	panic(err.Error())
}

// New creates a DiskQueue that keeps its data and metadata files in dataPath.
// It applies opts, restores the queue positions from an existing metadata file
// (if any) and starts the background I/O goroutine.
//
// dataPath must be an existing directory. A missing metadata file is not an
// error: the queue simply starts empty. A metadata file that cannot be read or
// parsed, or a non-positive sync timeout, is reported as an error.
func New(dataPath string, opts ...Option) (*DiskQueue, error) {
	if !isDirExists(dataPath) {
		return nil, fmt.Errorf("dataPath must be a valid directory: %s", dataPath)
	}

	q := &DiskQueue{
		dataPath:        dataPath,
		maxBytesPerFile: defaultMaxBytesPerFile,
		maxMsgSize:      defaultMaxMsgSize,
		readChan:        make(chan *message.Message),
		writeChan:       make(chan *message.Message),
		exitChan:        make(chan int),
		exitSyncChan:    make(chan int),
		syncEvery:       1000000,
		syncTimeout:     500 * time.Millisecond,
		logger:          &nop.Logger{},
	}

	for _, opt := range opts {
		opt(q)
	}

	// ioLoop passes syncTimeout to time.NewTicker, which panics on a
	// non-positive duration. That panic would happen in a background goroutine
	// and crash the process, so reject the configuration here instead.
	if q.syncTimeout <= 0 {
		return nil, fmt.Errorf("syncTimeout must be positive, got %s", q.syncTimeout)
	}

	if err := q.retrieveMetaData(); err != nil && !os.IsNotExist(err) {
		return nil, fmt.Errorf("failed to retrieveMetaData: %w", err)
	}

	go q.ioLoop()
	return q, nil
}

// SetSeqNo sets the sequence-number counter. ioLoop increments the counter
// before stamping each written message, so the next message written after this
// call gets seqNo+1.
//
// NOTE(review): the field is written without synchronization, and ioLoop
// increments it on every write. Call this before the first Put; the channel
// send in Put then orders this write before ioLoop reads the field. Calling it
// concurrently with Put is a data race.
func (q *DiskQueue) SetSeqNo(seqNo uint64) {
	q.seqNo = seqNo
}

// Depth reports how many messages have been written but not yet delivered
// through ReadChan, as tracked by the queue. It is restored from the metadata
// file on start and reset to 0 when tail corruption is detected. It is safe to
// call from any goroutine.
func (q *DiskQueue) Depth() int64 {
	depth := atomic.LoadInt64(&q.depth)
	return depth
}

// Size reports the inflight counter. The counter is incremented for every
// message written and decremented for every message delivered through ReadChan.
// Unlike Depth, it is not reset when corruption recovery discards data. It is
// safe to call from any goroutine.
func (q *DiskQueue) Size() int64 {
	inflight := atomic.LoadInt64(&q.inflight)
	return inflight
}

// ReadChan returns the receive-only channel on which queued messages are
// delivered in the order they were written. Close closes this channel.
func (q *DiskQueue) ReadChan() <-chan *message.Message {
	var out <-chan *message.Message = q.readChan
	return out
}

// Put hands msg to the queue's I/O goroutine to be appended to disk.
//
// It returns nil once the goroutine has accepted the message. It returns
// ErrClosed if the queue has been closed. It returns ErrTimeout if ctx is done
// (by deadline or by cancellation) before the message is accepted. After a
// successful Put the queue owns msg: ioLoop calls message.ReleaseMsg on it once
// it has been written, so the caller must not reuse it.
func (q *DiskQueue) Put(ctx context.Context, msg *message.Message) error {
	// The order of cases does not matter: select picks randomly among the
	// ready ones.
	select {
	case q.writeChan <- msg:
		return nil
	case <-ctx.Done():
		return ErrTimeout
	case <-q.exitChan:
		return ErrClosed
	}
}

// Close shuts the queue down in this order:
//   - stops the I/O goroutine and waits for it to exit,
//   - closes the channel returned by ReadChan,
//   - closes any open data files,
//   - persists the current positions to the metadata file.
//
// It returns the metadata persistence error, if any.
//
// Calling Close again after it has returned yields ErrClosed instead of
// panicking on a second close of exitChan. Close itself must not be called
// concurrently from several goroutines.
func (q *DiskQueue) Close() error {
	select {
	case <-q.exitChan:
		// already closed by an earlier Close call
		return ErrClosed
	default:
	}

	close(q.exitChan)
	// Wait until ioLoop has exited so that no other goroutine touches the
	// files or positions below.
	<-q.exitSyncChan

	// ioLoop was the only sender on readChan, so closing it is now safe.
	close(q.readChan)
	if q.readFile != nil {
		silentClose(q.readFile)
		q.readFile = nil
	}

	if q.writeFile != nil {
		silentClose(q.writeFile)
		q.writeFile = nil
	}

	return q.sync()
}

// deleteAllFiles first calls skipToNextRWFile, which closes the open files,
// deletes the data files between the read and write cursors and resets the
// cursors. It then removes the metadata file. A metadata file that does not
// exist is not an error. Any other removal failure is logged and returned.
func (q *DiskQueue) deleteAllFiles() error {
	if err := q.skipToNextRWFile(); err != nil {
		return err
	}

	metaFn := q.metaDataFileName()
	switch err := os.Remove(metaFn); {
	case err == nil, os.IsNotExist(err):
		return nil
	default:
		q.logger.Error("failed to remove metadata file", log.String("file_name", metaFn), log.Error(err))
		return err
	}
}

// skipToNextRWFile discards everything currently stored in the queue:
//   - closes the open read and write files,
//   - deletes every data file numbered readFileNum through writeFileNum,
//   - points the read, next-read and write cursors at offset 0 of file number
//     writeFileNum+1,
//   - resets depth to 0.
//
// A removal failure other than "does not exist" is logged, and the last such
// error is returned. The cursors are reset regardless.
func (q *DiskQueue) skipToNextRWFile() error {
	if f := q.readFile; f != nil {
		q.readFile = nil
		silentClose(f)
	}
	if f := q.writeFile; f != nil {
		q.writeFile = nil
		silentClose(f)
	}

	var lastErr error
	for num := q.readFileNum; num <= q.writeFileNum; num++ {
		name := q.fileName(num)
		if err := os.Remove(name); err != nil && !os.IsNotExist(err) {
			q.logger.Error("failed to remove data file", log.String("file_name", name), log.Error(err))
			lastErr = err
		}
	}

	next := q.writeFileNum + 1
	q.writeFileNum, q.writePos = next, 0
	q.readFileNum, q.readPos = next, 0
	q.nextReadFileNum, q.nextReadPos = next, 0
	atomic.StoreInt64(&q.depth, 0)

	return lastErr
}

// readOne reads the next length-prefixed message (4-byte big-endian size, then
// the payload) from the current read file and unmarshals it. If no read file
// is open, it first opens the file and seeks to readPos.
//
// Only nextReadPos/nextReadFileNum are advanced here. readPos/readFileNum are
// committed later by moveForward, once the message has been delivered. If the
// message ends past maxBytesPerFile, the read file is closed and the next
// position becomes offset 0 of the following file.
//
// On any error after the file was opened, the read file is closed, so it is
// never left pointing into a file the caller is about to skip. ioLoop handles
// errors with handleReadError.
func (q *DiskQueue) readOne() (*message.Message, error) {
	// fail closes the current read file (if open) and returns err.
	fail := func(err error) (*message.Message, error) {
		if q.readFile != nil {
			silentClose(q.readFile)
			q.readFile = nil
		}
		return nil, err
	}

	if q.readFile == nil {
		f, err := os.OpenFile(q.fileName(q.readFileNum), os.O_RDONLY, 0600)
		if err != nil {
			return nil, err
		}
		q.readFile = f

		if q.readPos > 0 {
			if _, err := q.readFile.Seek(q.readPos, io.SeekStart); err != nil {
				return fail(err)
			}
		}

		q.reader = bufio.NewReader(q.readFile)
	}

	var msgSize uint32
	if err := binary.Read(q.reader, binary.BigEndian, &msgSize); err != nil {
		return fail(err)
	}

	if q.maxMsgSize != 0 && msgSize > q.maxMsgSize {
		// this file is corrupt and we have no reasonable guarantee on
		// where a new message should begin
		return fail(fmt.Errorf("invalid message read size (%d)", msgSize))
	}

	q.msgReadBuf.Reset()
	if _, err := io.CopyN(&q.msgReadBuf, q.reader, int64(msgSize)); err != nil {
		return fail(err)
	}

	// Widen before adding. With maxMsgSize == 0 the size is unchecked, and
	// 4+msgSize could wrap around in uint32 arithmetic.
	totalBytes := int64(msgSize) + 4

	// we only advance next* because we have not yet sent this to consumers
	// (where readFileNum, readPos will actually be advanced)
	q.nextReadPos = q.readPos + totalBytes
	q.nextReadFileNum = q.readFileNum

	// TODO: each data file should embed the maxBytesPerFile
	// as the first 8 bytes (at creation time) ensuring that
	// the value can change without affecting runtime
	if q.nextReadPos > q.maxBytesPerFile {
		if q.readFile != nil {
			silentClose(q.readFile)
			q.readFile = nil
		}

		q.nextReadFileNum++
		q.nextReadPos = 0
	}

	msg, err := message.Unmarshal(&q.msgReadBuf)
	if err != nil {
		return fail(err)
	}
	return msg, nil
}

// writeOne marshals msg and appends it to the current write file as a 4-byte
// big-endian length prefix followed by the payload. If no write file is open,
// it first opens it (creating it if needed) and seeks to writePos.
//
// On success it advances writePos and increments depth and inflight. Once
// writePos exceeds maxBytesPerFile, writing rolls over to the next file number
// and the metadata is persisted. If that persistence fails, the failure is
// logged and a retry is scheduled via needSync. It is not returned, because the
// message itself has already been written and counted.
func (q *DiskQueue) writeOne(msg *message.Message) error {
	if q.writeFile == nil {
		f, err := os.OpenFile(q.fileName(q.writeFileNum), os.O_RDWR|os.O_CREATE, 0600)
		if err != nil {
			return err
		}
		q.writeFile = f

		if q.writePos > 0 {
			if _, err := q.writeFile.Seek(q.writePos, io.SeekStart); err != nil {
				silentClose(q.writeFile)
				q.writeFile = nil
				return err
			}
		}
	}

	q.msgWriteBuf.Reset()
	if err := message.Marshal(msg, &q.msgWriteBuf); err != nil {
		return fmt.Errorf("invalid message: failed to marshal: %w", err)
	}

	// Validate the size before narrowing it to the uint32 on-disk prefix.
	// Anything larger than maxFrameSize cannot be framed at all, even when
	// maxMsgSize == 0 disables the configured limit.
	const maxFrameSize = int64(^uint32(0))
	size := int64(q.msgWriteBuf.Len())
	if size > maxFrameSize || (q.maxMsgSize != 0 && size > int64(q.maxMsgSize)) {
		return fmt.Errorf("invalid message write size (%d) maxMsgSize=%d", size, q.maxMsgSize)
	}
	dataLen := uint32(size)

	q.writeBuf.Reset()
	if err := binary.Write(&q.writeBuf, binary.BigEndian, dataLen); err != nil {
		return err
	}

	if _, err := q.writeBuf.Write(q.msgWriteBuf.Bytes()); err != nil {
		return err
	}

	// only write to the file once
	if _, err := q.writeFile.Write(q.writeBuf.Bytes()); err != nil {
		silentClose(q.writeFile)
		q.writeFile = nil
		return err
	}

	q.writePos += int64(dataLen) + 4
	atomic.AddInt64(&q.depth, 1)
	atomic.AddInt64(&q.inflight, 1)

	if q.writePos > q.maxBytesPerFile {
		q.writeFileNum++
		q.writePos = 0

		// sync every time we start writing to a new file
		if err := q.sync(); err != nil {
			q.logger.Error("sync failed", log.Error(err))
			// let ioLoop retry persisting the metadata on its next iteration
			q.needSync = true
		}

		if q.writeFile != nil {
			silentClose(q.writeFile)
			q.writeFile = nil
		}
	}

	return nil
}

// sync persists depth and the read/write positions via persistMetaData. On
// success it clears needSync. On failure it returns the error and leaves
// needSync as it was.
//
// NOTE: the data file itself is currently NOT fsynced. The fsync of writeFile
// is commented out below, so "sync" only refers to the metadata.
func (q *DiskQueue) sync() error {
	//if q.writeFile != nil {
	//	err := q.writeFile.Sync()
	//	if err != nil {
	//		silentClose(q.writeFile)
	//		q.writeFile = nil
	//		return err
	//	}
	//}

	if err := q.persistMetaData(); err != nil {
		return err
	}
	q.needSync = false
	return nil
}

// retrieveMetaData loads depth and the read/write positions from the metadata
// file written by persistMetaData, and starts the next-read cursor at the read
// position.
//
// The error from opening the file is returned unwrapped, so callers can detect
// a missing file with os.IsNotExist. Queue state is modified only after the
// whole file has parsed successfully.
func (q *DiskQueue) retrieveMetaData() error {
	f, err := os.OpenFile(q.metaDataFileName(), os.O_RDONLY, 0600)
	if err != nil {
		return err
	}
	defer silentClose(f)

	var depth, readFileNum, readPos, writeFileNum, writePos int64
	if _, err := fmt.Fscanf(f, "%d\n%d,%d\n%d,%d\n",
		&depth,
		&readFileNum, &readPos,
		&writeFileNum, &writePos); err != nil {
		return err
	}

	q.readFileNum, q.readPos = readFileNum, readPos
	q.writeFileNum, q.writePos = writeFileNum, writePos
	q.nextReadFileNum, q.nextReadPos = readFileNum, readPos
	atomic.StoreInt64(&q.depth, depth)
	// Every restored message is still undelivered, so seed inflight with the
	// restored depth. Otherwise moveForward drives inflight (and Size) negative
	// as those messages are consumed after a restart.
	atomic.StoreInt64(&q.inflight, depth)

	return nil
}

// persistMetaData atomically replaces the metadata file with the current depth
// and read/write positions. The state is written to a uniquely named temporary
// file next to the metadata file, fsynced, closed, and then renamed over the
// metadata file.
//
// If writing, fsync, close or rename fails, the error is returned and the
// temporary file is removed. This way a possibly incomplete file is never
// renamed into place and stale temp files do not pile up.
func (q *DiskQueue) persistMetaData() error {
	fileName := q.metaDataFileName()
	tmpFileName := fmt.Sprintf("%s.%d.tmp", fileName, rand.Int())

	// write to tmp file
	f, err := os.OpenFile(tmpFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return err
	}

	_, err = fmt.Fprintf(f, "%d\n%d,%d\n%d,%d\n",
		atomic.LoadInt64(&q.depth),
		q.readFileNum, q.readPos,
		q.writeFileNum, q.writePos)
	if err == nil {
		err = f.Sync()
	}
	// Always close; report the close error only if nothing failed earlier.
	if closeErr := f.Close(); err == nil {
		err = closeErr
	}
	if err == nil {
		// atomically rename
		err = os.Rename(tmpFileName, fileName)
	}

	if err != nil {
		// best-effort cleanup; the original error is the one worth reporting
		_ = os.Remove(tmpFileName)
		return err
	}
	return nil
}

// metaDataFileName returns the path of the file that stores the persisted
// depth and read/write positions.
func (q *DiskQueue) metaDataFileName() string {
	const metaBase = "disk_queue.meta"
	return path.Join(q.dataPath, metaBase)
}

// fileName returns the path of data file number fileNum. The number is
// zero-padded to at least six digits, e.g. "disk_queue.000042.dat".
func (q *DiskQueue) fileName(fileNum int64) string {
	base := fmt.Sprintf("disk_queue.%06d.dat", fileNum)
	return path.Join(q.dataPath, base)
}

// checkTailCorruption is called with the depth after each delivery. It only
// acts once the read cursor has caught up with the write cursor, at which
// point the queue should be empty:
//   - a non-zero depth is logged and forced back to 0;
//   - a read cursor that has overtaken the write cursor is logged, and both
//     cursors are moved to a fresh file via skipToNextRWFile.
//
// Either repair schedules a metadata sync.
func (q *DiskQueue) checkTailCorruption(depth int64) {
	caughtUp := q.readFileNum >= q.writeFileNum && q.readPos >= q.writePos
	if !caughtUp {
		return
	}

	// we've reached the end of the diskqueue; any depth left is bogus
	switch {
	case depth < 0:
		q.logger.Error(
			"negative depth at tail, possible metadata corruption",
			log.Int64("depth", depth),
		)
	case depth > 0:
		q.logger.Error(
			"positive depth at tail, possible data loss",
			log.Int64("depth", depth),
		)
	}
	if depth != 0 {
		// force set depth 0
		atomic.StoreInt64(&q.depth, 0)
		q.needSync = true
	}

	inSync := q.readFileNum == q.writeFileNum && q.readPos == q.writePos
	if inSync {
		return
	}

	if q.readFileNum > q.writeFileNum {
		q.logger.Error(
			"data corruption: readFileNum > writeFileNum",
			log.Int64("read_file_num", q.readFileNum),
			log.Int64("write_file_num", q.writeFileNum),
		)
	}
	if q.readPos > q.writePos {
		q.logger.Error(
			"data corruption: readPos > writePos",
			log.Int64("read_pos", q.readPos),
			log.Int64("write_pos", q.writePos),
		)
	}

	// removal failures are already logged inside skipToNextRWFile
	_ = q.skipToNextRWFile()
	q.needSync = true
}

// moveForward commits the position of the message that ioLoop has just
// delivered:
//   - readFileNum/readPos take the values of nextReadFileNum/nextReadPos;
//   - inflight and depth are decremented;
//   - if the read cursor moved to a new file, the previous data file is
//     deleted (failures are logged) and a metadata sync is scheduled.
//
// Finally the new depth is checked by checkTailCorruption.
func (q *DiskQueue) moveForward() {
	prevFileNum := q.readFileNum
	q.readFileNum, q.readPos = q.nextReadFileNum, q.nextReadPos

	atomic.AddInt64(&q.inflight, -1)
	// TODO(buglloc): inflight can go below zero after a restart unless it is
	// restored together with depth.
	newDepth := atomic.AddInt64(&q.depth, -1)

	if prevFileNum != q.readFileNum {
		// sync every time we start reading from a new file
		q.needSync = true

		oldFn := q.fileName(prevFileNum)
		if err := os.Remove(oldFn); err != nil {
			q.logger.Error("failed to remove data file", log.String("file_name", oldFn), log.Error(err))
		}
	}

	q.checkTailCorruption(newDepth)
}

// handleReadError is called by ioLoop after readOne fails. It treats the
// current read file as corrupt:
//   - the read handle is closed;
//   - the file is renamed to "<name>.bad" so it can be inspected later;
//   - the read and next-read cursors move to offset 0 of the next file.
//
// If the bad file is also the current write file, writing moves on to a new
// file as well. A metadata sync is scheduled.
func (q *DiskQueue) handleReadError() {
	// readOne keeps using an already-open q.readFile. A handle left open on
	// the bad file would keep being read while the cursors already point at
	// the next file, so drop it.
	if q.readFile != nil {
		silentClose(q.readFile)
		q.readFile = nil
	}

	// jump to the next read file and rename the current (bad) file
	if q.readFileNum == q.writeFileNum {
		// if the current write file cannot be read properly, it's safe to
		// assume something is broken and we should skip the current file too
		if q.writeFile != nil {
			silentClose(q.writeFile)
			q.writeFile = nil
		}
		q.writeFileNum++
		q.writePos = 0
	}

	badFn := q.fileName(q.readFileNum)
	badRenameFn := badFn + ".bad"

	q.logger.Error(
		"broken file, jump to next file",
		log.String("file_name", badFn),
		log.String("saved_file_name", badRenameFn),
	)
	if err := os.Rename(badFn, badRenameFn); err != nil {
		q.logger.Error(
			"failed to rename bad data file",
			log.String("file_name", badFn),
			log.String("saved_file_name", badRenameFn),
			log.Error(err),
		)
	}

	q.readFileNum++
	q.readPos = 0
	q.nextReadFileNum = q.readFileNum
	q.nextReadPos = 0

	// significant state change, schedule a sync on the next iteration
	q.needSync = true
}

// ioLoop provides the backend for exposing a go channel (via ReadChan()) in
// support of multiple concurrent queue consumers. It is the only goroutine
// that touches the data files and queue positions.
//
// Each iteration it persists metadata if needed, reads the next message ahead
// when data is available, and then blocks until one of these happens:
//   - the message is delivered on readChan;
//   - a message arrives on writeChan;
//   - the sync ticker fires;
//   - exitChan is closed.
//
// Metadata is persisted (via sync) whenever needSync is set. This happens after
// syncEvery reads+writes, on a syncTimeout tick that follows some activity, and
// whenever other code (moveForward, handleReadError, checkTailCorruption) asks
// for it. On exit it signals exitSyncChan so that Close can proceed.
func (q *DiskQueue) ioLoop() {
	var (
		msgRead    *message.Message
		err        error
		count      uint32
		r          chan *message.Message
		syncTicker = time.NewTicker(q.syncTimeout)
	)

loop:
	for {
		// dont sync all the time :)
		if count == q.syncEvery {
			q.needSync = true
		}

		if q.needSync {
			err = q.sync()
			if err != nil {
				q.logger.Error("failed to sync data", log.Error(err))
			}
			count = 0
		}

		if (q.readFileNum < q.writeFileNum) || (q.readPos < q.writePos) {
			// Read ahead only when no message is already pending delivery.
			// The file number must be compared too: a pending message that
			// rolled the next position to offset 0 of the following file would
			// otherwise look "not read yet" whenever readPos is also 0.
			if q.nextReadPos == q.readPos && q.nextReadFileNum == q.readFileNum {
				msgRead, err = q.readOne()
				if err != nil {
					q.logger.Error(
						"failed to read data",
						log.Int64("read_pos", q.readPos),
						log.String("file_name", q.fileName(q.readFileNum)),
						log.Error(err),
					)
					q.handleReadError()
					continue
				}
			}
			r = q.readChan
		} else {
			r = nil
		}

		select {
		// the Go channel spec dictates that nil channel operations (read or write)
		// in a select are skipped, we set r to q.readChan only when there is data to read
		case r <- msgRead:
			count++
			// moveForward sets needSync flag if a file is removed
			q.moveForward()
		case msgWrite := <-q.writeChan:
			count++
			q.seqNo++
			// TODO(buglloc): why this code here?
			if err = q.writeOne(msgWrite.WithSeqNo(q.seqNo)); err != nil {
				q.logger.Error("failed to save message to diskQ", log.Error(err))
			}
			message.ReleaseMsg(msgWrite)
		case <-syncTicker.C:
			// Periodically persist positions, but only if something happened
			// since the last sync. This bounds how much position metadata a
			// crash can lose.
			if count > 0 {
				q.needSync = true
			}
		case <-q.exitChan:
			break loop
		}
	}

	syncTicker.Stop()
	q.exitSyncChan <- 1
}

// silentClose closes c and discards any error. It is used on cleanup paths
// where the handle is being abandoned and a close failure cannot be acted upon.
func silentClose(c io.Closer) {
	// best-effort: a failed close is deliberately ignored
	err := c.Close()
	_ = err
}
