package s3

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"runtime/pprof"
	"strconv"
	"strings"
	"time"

	awsS3 "github.com/aws/aws-sdk-go/service/s3"

	"a.yandex-team.ru/security/osquery/osquery-sender/metrics"
)

const (
	// requestTimeout bounds each individual S3 request attempt; every retry gets a fresh timeout of its own.
	requestTimeout = 30 * time.Second

	// defaultPartSize is the minimal multipart upload part size according to AWS docs (5 MiB).
	defaultPartSize = 5 << 20
)

// S3Manager is an alternative to s3manager from aws-sdk-go with a different interface: Download is sequential
// and does not require a WriterAt, while Upload is push-based and does not require a Reader. S3Manager also
// retries every request up to config.NumRetries times, regardless of which error the previous attempt returned.
type S3Manager struct {
	client S3Client
	config S3ManagerConfig

	// Worker pools shared by every reader (Download), writer (Upload) and GetInfos call of this manager.
	downloadPool *JobPool
	uploadPool   *JobPool
	getInfoPool  *JobPool
}

// S3ManagerConfig configures an S3Manager:
//   - EnableDebug logs additional per-request details (ranges, part uploads, timings).
//   - NumRetries is the total number of attempts made for each S3 request; with 0 no request is sent at all.
//   - NumDownloadWorkers, NumUploadWorkers and NumGetInfoWorkers size the corresponding worker pools; 0 means 1.
//   - MinUploadPartSize is the multipart upload part size in bytes (0 selects the 5 MiB default in Upload) and is
//     also used by Download as the size of each ranged GET.
type S3ManagerConfig struct {
	EnableDebug        bool
	NumRetries         int
	NumDownloadWorkers int
	NumUploadWorkers   int
	NumGetInfoWorkers  int
	MinUploadPartSize  int
}

// AbortWriteCloser is an io.WriteCloser that can also be abandoned. For the writer returned by S3Manager.Upload,
// Close stores the written data as an object, while Abort discards it (aborting the multipart upload if one was
// started).
type AbortWriteCloser interface {
	io.WriteCloser
	Abort()
}

// NewS3Manager creates an S3Manager that sends requests through client and starts its download, upload and
// get-info worker pools. Worker counts that are zero or negative are replaced with 1. *config is copied, so later
// modifications of it do not affect the manager.
func NewS3Manager(client S3Client, config *S3ManagerConfig) *S3Manager {
	// atLeastOne turns a configured worker count into a usable one; the original code only normalised 0 and
	// passed negative counts straight to NewThreadPool.
	atLeastOne := func(n int) int {
		if n <= 0 {
			return 1
		}
		return n
	}
	numDownloadWorkers := atLeastOne(config.NumDownloadWorkers)
	numUploadWorkers := atLeastOne(config.NumUploadWorkers)
	numGetInfoWorkers := atLeastOne(config.NumGetInfoWorkers)
	log.Printf("starting %d download workers %d upload workers, %d get info workers for s3\n",
		numDownloadWorkers, numUploadWorkers, numGetInfoWorkers)

	return &S3Manager{
		client: client,
		config: *config,
		// No buffering to reduce memory consumption: serialization waits until at least one upload worker is
		// available.
		downloadPool: NewThreadPool(numDownloadWorkers, 0, "s3-download-pool"),
		uploadPool:   NewThreadPool(numUploadWorkers, 0, "s3-upload-pool"),
		getInfoPool:  NewThreadPool(numGetInfoWorkers, 0, "s3-get-info-pool"),
	}
}

// Download returns a reader that streams bucket/key sequentially. The object is fetched with ranged GETs of
// MinUploadPartSize bytes (defaultPartSize if unset) on the download pool. The first range is requested right
// away; each following one is requested when the previous chunk is moved into the reader's buffer.
func (m *S3Manager) Download(bucket string, key string) io.Reader {
	bufSize := m.config.MinUploadPartSize
	if bufSize <= 0 {
		// Same default as Upload. A non-positive size would otherwise produce a malformed Range header such as
		// "bytes=0--1".
		bufSize = defaultPartSize
	}
	ret := &s3Reader{
		client:       m.client,
		bucket:       &bucket,
		key:          &key,
		enableDebug:  m.config.EnableDebug,
		numRetries:   m.config.NumRetries,
		bufSize:      bufSize,
		downloadPool: m.downloadPool,

		buf:    &bytes.Buffer{},
		offset: 0,
		// totalSize is unknown until the first background read reports it.
		totalSize: -1,
	}
	ret.startBackgroundRead()
	return ret
}

// Upload returns a writer that stores the written data as bucket/key with the given metadata. Data is buffered
// and uploaded in parts of MinUploadPartSize bytes (defaultPartSize if unset). An object that never fills a part
// is written with a single PutObject on Close. Calling Abort instead of Close discards the upload.
func (m *S3Manager) Upload(bucket string, key string, metadata map[string]*string) AbortWriteCloser {
	partSize := m.config.MinUploadPartSize
	if partSize <= 0 {
		// Also covers negative sizes, which would make buf.Grow below panic.
		partSize = defaultPartSize
	}
	ret := &s3Writer{
		client:      m.client,
		bucket:      &bucket,
		key:         &key,
		enableDebug: m.config.EnableDebug,
		numRetries:  m.config.NumRetries,
		partSize:    partSize,
		uploadPool:  m.uploadPool,

		buf:      &bytes.Buffer{},
		metadata: metadata,
	}
	// Reserve a bit more than one part so that the Write that crosses partSize usually does not reallocate.
	ret.buf.Grow(partSize + partSize/4)
	return ret
}

// GetInfos runs HeadObject for every key concurrently on the get-info pool and returns the results in the same
// order as keys. When a lookup fails, ObjectExists is consulted: if it reports the object as not existing, the
// entry is nil; otherwise the lookup error is returned and no results are produced.
func (m *S3Manager) GetInfos(bucket string, keys []string) ([]*awsS3.HeadObjectOutput, error) {
	pending := make([]*Future, len(keys))
	for idx := range keys {
		k := keys[idx]
		pending[idx] = m.getInfoPool.SubmitWithLabels(
			pprof.Labels("name", "s3-get-infos", "bucket", bucket, "key", k),
			func() (interface{}, error) {
				return m.getInfo(bucket, k)
			},
		)
	}

	var infos []*awsS3.HeadObjectOutput
	for idx, f := range pending {
		res, err := f.Get()
		if err == nil {
			infos = append(infos, res.(*awsS3.HeadObjectOutput))
			continue
		}
		// The lookup failed even after retries; distinguish "missing" from "broken" with one more check.
		if m.ObjectExists(bucket, keys[idx]) {
			return nil, fmt.Errorf("getting object info for %s/%s failed: %v", bucket, keys[idx], errToString(err))
		}
		infos = append(infos, nil)
	}
	return infos, nil
}

// ObjectExists synchronously checks whether bucket/key exists by issuing HeadObject, retrying up to NumRetries
// times. Any error on the final attempt (not only "not found") makes it return false, as does NumRetries == 0.
// It is used by GetInfos, so it must not be reimplemented on top of GetInfos.
func (m *S3Manager) ObjectExists(bucket string, key string) bool {
	for i := 0; i < m.config.NumRetries; i++ {
		ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
		_, err := m.client.HeadObjectWithContext(ctx, &awsS3.HeadObjectInput{
			Bucket: &bucket,
			Key:    &key,
		})
		// Release the context immediately: a deferred cancel inside the loop would keep every attempt's
		// context and timer alive until the function returns.
		cancel()
		if err == nil {
			return true
		}
		if i == m.config.NumRetries-1 {
			return false
		}
		log.Printf("ERROR: retrying get info for %s/%s, retry #%d, error: %v\n", bucket, key, i+1, errToString(err))
	}
	return false
}

// ListObjects lists all objects in bucket whose keys start with prefix (nil lists everything) and calls pageFn
// once per page of results, stopping at the first error returned by pageFn. ListObjectsV2Pages applies the same
// context to all subcalls, so pagination is done manually to give every page request its own timeout and retries.
func (m *S3Manager) ListObjects(bucket string, prefix *string, pageFn func([]*awsS3.Object) error) error {
	// listPage fetches the page that starts at token, retrying up to NumRetries times.
	listPage := func(token *string) (*awsS3.ListObjectsV2Output, error) {
		for i := 0; i < m.config.NumRetries; i++ {
			ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
			output, err := m.client.ListObjectsV2WithContext(ctx, &awsS3.ListObjectsV2Input{
				Bucket:            &bucket,
				ContinuationToken: token,
				Prefix:            prefix,
			})
			cancel()
			if err == nil {
				return output, nil
			}
			if i == m.config.NumRetries-1 {
				return nil, err
			}
			log.Printf("ERROR: retrying list objects in %s, retry #%d, error: %v\n", bucket, i+1, errToString(err))
		}
		// Only reached when NumRetries <= 0; previously this case dereferenced a nil output.
		return nil, errors.New("S3Manager.ListObjects: unreachable")
	}

	var lastToken *string
	for {
		output, err := listPage(lastToken)
		if err != nil {
			return err
		}

		if err := pageFn(output.Contents); err != nil {
			return err
		}

		lastToken = output.NextContinuationToken
		if lastToken == nil {
			return nil
		}
	}
}

// DeleteObjects deletes keys from bucket, sending them in batches of up to 100 keys per request. It stops at the
// first batch that fails and returns that batch's error; earlier batches stay deleted.
func (m *S3Manager) DeleteObjects(bucket string, keys []string) error {
	const batchObjects = 100
	for remaining := keys; len(remaining) > 0; {
		n := batchObjects
		if n > len(remaining) {
			n = len(remaining)
		}
		if err := m.deleteObjectsBatch(bucket, remaining[:n]); err != nil {
			return err
		}
		remaining = remaining[n:]
	}
	return nil
}

// CopyObject copies bucket/fromKey to bucket/toKey on the server side, retrying up to NumRetries times. The error
// of the final attempt is returned.
func (m *S3Manager) CopyObject(bucket string, fromKey string, toKey string) error {
	copySource := bucket + "/" + fromKey
	for i := 0; i < m.config.NumRetries; i++ {
		ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
		_, err := m.client.CopyObjectWithContext(ctx, &awsS3.CopyObjectInput{
			Bucket:     &bucket,
			CopySource: &copySource,
			Key:        &toKey,
		})
		// Cancel per attempt instead of deferring inside the loop, which would hold every context until return.
		cancel()
		if err == nil {
			return nil
		}
		if i == m.config.NumRetries-1 {
			return err
		}
		log.Printf("ERROR: retrying copy for %s/%s -> %s/%s, retry #%d, error: %v\n", bucket, fromKey,
			bucket, toKey, i+1, errToString(err))
	}
	return errors.New("S3Manager.CopyObject: unreachable")
}

// Stop closes all worker pools owned by the manager: upload, get-info and download. The download pool was
// previously left open, leaking its workers.
func (m *S3Manager) Stop() {
	m.uploadPool.Close()
	m.getInfoPool.Close()
	m.downloadPool.Close()
}

// getInfo runs HeadObject for bucket/key, retrying up to NumRetries times, and returns the final attempt's error
// if none succeeded.
func (m *S3Manager) getInfo(bucket string, key string) (*awsS3.HeadObjectOutput, error) {
	for i := 0; i < m.config.NumRetries; i++ {
		ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
		output, err := m.client.HeadObjectWithContext(ctx, &awsS3.HeadObjectInput{
			Bucket: &bucket,
			Key:    &key,
		})
		// HeadObject has no response body, so the context can be released right away rather than deferred
		// inside the loop.
		cancel()
		if err == nil {
			return output, nil
		}
		if i == m.config.NumRetries-1 {
			return nil, err
		}
		log.Printf("ERROR: retrying get info for %s/%s, retry #%d, error: %v\n", bucket, key, i+1, errToString(err))
	}
	return nil, errors.New("S3Manager.getInfo: unreachable")
}

// deleteObjectsBatch deletes keys from bucket with a single DeleteObjects request, retrying up to NumRetries times.
// DeleteObjects reports per-key failures in the response body (output.Errors) rather than as a request error, so
// a non-empty Errors list is treated as a failed attempt too; such failures were previously ignored silently.
func (m *S3Manager) deleteObjectsBatch(bucket string, keys []string) error {
	if len(keys) == 0 {
		// Nothing to delete; this also keeps keys[0] in the error message below in range.
		return nil
	}

	ids := make([]*awsS3.ObjectIdentifier, 0, len(keys))
	for _, key := range keys {
		key := key
		ids = append(ids, &awsS3.ObjectIdentifier{
			Key: &key,
		})
	}

	// str dereferences an optional string field of the SDK response.
	str := func(s *string) string {
		if s == nil {
			return ""
		}
		return *s
	}

	for i := 0; i < m.config.NumRetries; i++ {
		startTime := time.Now()
		ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
		output, err := m.client.DeleteObjectsWithContext(ctx, &awsS3.DeleteObjectsInput{
			Bucket: &bucket,
			Delete: &awsS3.Delete{
				Objects: ids,
			},
		})
		cancel()
		if err == nil && output != nil && len(output.Errors) > 0 {
			if first := output.Errors[0]; first != nil {
				err = fmt.Errorf("%d of %d keys not deleted, first %q: %s: %s",
					len(output.Errors), len(keys), str(first.Key), str(first.Code), str(first.Message))
			} else {
				err = fmt.Errorf("%d of %d keys not deleted", len(output.Errors), len(keys))
			}
		}
		if err != nil {
			if i != m.config.NumRetries-1 {
				log.Printf("ERROR: retrying delete objects in %s, retry #%d, error: %v\n", bucket, i+1, errToString(err))
				continue
			}
			return fmt.Errorf("failed to delete %d objects (%s, ...): %v", len(keys), keys[0], errToString(err))
		}

		if m.config.EnableDebug {
			log.Printf("deleted %d objects in %v\n", len(ids), time.Since(startTime))
		}
		return nil
	}
	return errors.New("S3Manager.deleteObjectsBatch: unreachable")
}

// s3Reader is the io.Reader returned by S3Manager.Download. It fetches the object in bufSize-byte ranges and keeps
// at most one ranged GET (backgroundRead) pending on downloadPool ahead of the consumer.
type s3Reader struct {
	// Immutable settings copied from the S3Manager when the reader is created.
	client       S3Client
	bucket       *string
	key          *string
	enableDebug  bool
	numRetries   int
	bufSize      int
	downloadPool *JobPool

	// Mutable state: buf holds downloaded bytes not yet returned by Read (set to nil after EOF), offset is the
	// start of the next range to request, totalSize is the object size from Content-Range (-1 while unknown),
	// and backgroundRead is the pending ranged GET (nil if none).
	buf            *bytes.Buffer
	offset         int64
	totalSize      int64
	backgroundRead *Future
}

// Read implements io.Reader. It serves data from the internal buffer and, once the buffer is drained, waits for
// the pending background read to refill it (which also schedules the next range if more data is expected).
func (r *s3Reader) Read(p []byte) (int, error) {
	// A nil buffer means EOF was already reported and the memory released. Keep returning EOF instead of
	// panicking on the nil buffer when a caller reads again after EOF.
	if r.buf == nil {
		return 0, io.EOF
	}
	// s3Reader works as follows: a background read job is submitted after creating s3Reader. When the job finishes,
	// its results are appended to the buffer and the new one is submitted, unless the current part is the last one.
	if r.buf.Len() == 0 {
		// The buffer is empty and there are no running background reads which could've filled it.
		if r.backgroundRead == nil {
			// Free the memory upon EOF.
			r.buf = nil
			return 0, io.EOF
		}
		// NOTE(review): after a refill error the pending read is dropped, so a caller that keeps reading will
		// see io.EOF next; callers must treat the first error as fatal.
		if err := r.refillFromBackgroundRead(); err != nil {
			return 0, err
		}
	}
	return r.buf.Read(p)
}

// readResult is the value produced by a background ranged GET: the bytes of the range and the total object size
// parsed from Content-Range (-1 if the header was missing or did not contain a size).
type readResult struct {
	data      []byte
	totalSize int64
}

// startBackgroundRead submits a ranged GET for the next bufSize bytes (starting at r.offset) to downloadPool and
// advances r.offset past that range. The job produces a readResult that refillFromBackgroundRead collects. Only
// one background read may be pending at a time.
func (r *s3Reader) startBackgroundRead() {
	if r.backgroundRead != nil {
		panic("starting background read when another one is in progress")
	}
	startTime := time.Now()

	offset := r.offset
	end := offset + int64(r.bufSize)
	rangeStr := fmt.Sprintf("bytes=%d-%d", offset, end-1)
	r.offset = end
	if r.enableDebug {
		log.Printf("start reading range %s of %s/%s\n", rangeStr, *r.bucket, *r.key)
	}

	// attempt makes the i-th try at fetching the range. It returns done=false with a nil error when the caller
	// should retry. Each try runs in its own function so that the deferred cancel and Body.Close fire at the end
	// of that try instead of piling up until the whole job returns.
	attempt := func(i int) (readResult, bool, error) {
		lastTry := i == r.numRetries-1
		ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
		defer cancel()
		output, err := r.client.GetObjectWithContext(ctx, &awsS3.GetObjectInput{
			Bucket: r.bucket,
			Key:    r.key,
			Range:  &rangeStr,
		})
		if err != nil {
			var coded interface{ Code() string }
			if offset == 0 && errors.As(err, &coded) && coded.Code() == "InvalidRange" {
				// S3 answers a range whose first byte lies past the end of the object with InvalidRange (HTTP
				// 416). For the very first range that can only mean the object is empty, which is not an error
				// and must not be retried.
				if r.enableDebug {
					log.Printf("%s/%s is empty\n", *r.bucket, *r.key)
				}
				return readResult{totalSize: 0}, true, nil
			}
			if !lastTry {
				log.Printf("ERROR: retrying downloading %s/%s, retry #%d, error (%T): %v\n", *r.bucket, *r.key, i+1, err, errToString(err))
				return readResult{}, false, nil
			}
			return readResult{}, false, fmt.Errorf("failed to get: %v", errToString(err))
		}
		defer func() {
			_ = output.Body.Close()
		}()

		var data []byte
		var n int
		if output.ContentLength != nil && *output.ContentLength >= 0 {
			data = make([]byte, int(*output.ContentLength))
			n, err = io.ReadFull(output.Body, data)
		} else {
			// Content-Length is missing: read until EOF instead of dereferencing a nil pointer.
			var body bytes.Buffer
			_, err = body.ReadFrom(output.Body)
			data = body.Bytes()
			n = len(data)
		}
		if err != nil {
			if !lastTry {
				log.Printf("ERROR: retrying reading %s/%s, retry #%d, read %d, error (%T): %v\n", *r.bucket, *r.key, i+1, n, err, errToString(err))
				return readResult{}, false, nil
			}
			return readResult{}, false, fmt.Errorf("failed to read: %v", errToString(err))
		}

		totalSize := int64(-1)
		if output.ContentRange != nil {
			totalSize = parseContentRange(*output.ContentRange)
		}
		if r.enableDebug {
			log.Printf("read chunk (size %d, total size %d) for range %s of %s/%s in %v\n",
				len(data), totalSize, rangeStr, *r.bucket, *r.key, time.Since(startTime))
		}
		return readResult{
			data:      data,
			totalSize: totalSize,
		}, true, nil
	}

	labels := pprof.Labels("name", "s3-get-object", "bucket", *r.bucket, "key", *r.key, "range", rangeStr)
	r.backgroundRead = r.downloadPool.SubmitWithLabels(labels, func() (interface{}, error) {
		for i := 0; i < r.numRetries; i++ {
			result, done, err := attempt(i)
			if err != nil {
				return nil, err
			}
			if done {
				return result, nil
			}
		}
		return nil, errors.New("s3Reader.startBackgroundRead: unreachable")
	})
}

// parseContentRange extracts the complete object size from a Content-Range header value such as
// "bytes 0-99/1234". It returns -1 when the size is "*" (unknown) or cannot be parsed as an integer.
// Unfortunately, there is not enough documentation on Content-Range header, but you can look at setTotalBytes in
// https://github.com/aws/aws-sdk-go/blob/ee7b112de35c14c3dfe525bcade2718347483dd2/service/s3/s3manager/download.go
func parseContentRange(s string) int64 {
	// Everything after the last '/' (or the whole string if there is none) is the total size.
	total := s[strings.LastIndexByte(s, '/')+1:]
	if total == "*" {
		return -1
	}
	size, err := strconv.ParseInt(total, 10, 64)
	if err != nil {
		return -1
	}
	return size
}

// refillFromBackgroundRead waits for the pending background read, appends its data to the buffer and records the
// reported total size. If more data is expected, it immediately schedules the next range. The pending read is
// cleared even when it failed.
func (r *s3Reader) refillFromBackgroundRead() error {
	result, err := r.backgroundRead.Get()
	r.backgroundRead = nil
	if err != nil {
		return err
	}

	chunk := result.(readResult)
	r.totalSize = chunk.totalSize
	if _, err := r.buf.Write(chunk.data); err != nil {
		return err
	}

	// With a known size, continue while the next range still starts inside the object. Without one, a full chunk
	// suggests there may be more data.
	sizeKnown := r.totalSize != -1
	moreExpected := (sizeKnown && r.offset < r.totalSize) || (!sizeKnown && len(chunk.data) == r.bufSize)
	if moreExpected {
		r.startBackgroundRead()
	}
	return nil
}

// s3Writer is the AbortWriteCloser returned by S3Manager.Upload. Written data is buffered until partSize bytes
// accumulate and then uploaded as one multipart-upload part on uploadPool.
type s3Writer struct {
	client      S3Client
	bucket      *string
	key         *string
	enableDebug bool
	numRetries  int
	partSize    int
	// All writers share the same pool of workers.
	uploadPool *JobPool

	// buf accumulates data not yet submitted as a part (set to nil by Close and Abort); metadata is attached to
	// the object through CreateMultipartUpload or PutObject.
	buf      *bytes.Buffer
	metadata map[string]*string

	// Multipart upload is created on-demand. If the contents fit inside the buf, PutObject is used instead.
	// futures holds one pending part upload per flush, in part-number order.
	upload  *awsS3.CreateMultipartUploadOutput
	futures []*Future
}

// Write appends p to the buffer and, once at least partSize bytes are buffered, submits them as the next part.
// It reports all of p as written; a non-nil error comes from starting the multipart upload.
func (w *s3Writer) Write(p []byte) (int, error) {
	written, err := w.buf.Write(p)
	if err != nil {
		// bytes.Buffer.Write always returns a nil error (it panics when growing fails), so this is defensive.
		return written, err
	}
	if w.buf.Len() < w.partSize {
		return written, nil
	}
	return written, w.flush()
}

// Close finishes the object. If no part was flushed yet, the buffered data is written with a single PutObject
// (run on the upload pool). Otherwise the remaining data is submitted as the last part, Close waits for every
// part upload and completes the multipart upload. On error the multipart upload is not aborted here. The buffer
// is released in all cases.
func (w *s3Writer) Close() error {
	defer func() {
		// Free storage after closing.
		w.buf = nil
	}()

	// If the object is small enough, do a single PutObject and return.
	if w.upload == nil {
		labels := pprof.Labels("name", "s3-put-object", "bucket", *w.bucket, "key", *w.key)
		future := w.uploadPool.SubmitWithLabels(labels, func() (interface{}, error) {
			return nil, w.putObject(w.buf.Bytes())
		})
		_, err := future.Get()
		return err
	}

	if w.buf.Len() > 0 {
		if err := w.flush(); err != nil {
			return err
		}
	}

	// Wait for every part even after a failure. Returning early would leave part uploads running that still read
	// w.upload, which a subsequent Abort sets to nil; parts finishing after the abort would also be orphaned.
	parts := make([]*awsS3.CompletedPart, 0, len(w.futures))
	var firstErr error
	for _, future := range w.futures {
		part, err := future.Get()
		if err != nil {
			if firstErr == nil {
				firstErr = err
			}
			continue
		}
		parts = append(parts, part.(*awsS3.CompletedPart))
	}
	if firstErr != nil {
		return firstErr
	}

	return w.completeUpload(parts)
}

// Abort releases the buffer and, if a multipart upload was started, aborts it with up to numRetries attempts.
// Abort is best effort: a final failure is only logged.
func (w *s3Writer) Abort() {
	// Free storage after aborting.
	w.buf = nil

	if w.upload == nil {
		return
	}

	for i := 0; i < w.numRetries; i++ {
		ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
		_, err := w.client.AbortMultipartUploadWithContext(ctx, &awsS3.AbortMultipartUploadInput{
			Bucket:   w.bucket,
			Key:      w.key,
			UploadId: w.upload.UploadId,
		})
		// Cancel per attempt instead of deferring inside the loop.
		cancel()
		if err != nil {
			if i != w.numRetries-1 {
				log.Printf("ERROR: retrying aborting %s/%s, retry #%d, error: %v\n", *w.bucket, *w.key, i+1, errToString(err))
				continue
			}
			log.Printf("ERROR: aborting multipart upload failed: %v\n", errToString(err))
		}
		w.upload = nil
		return
	}
}

// createUpload starts the multipart upload for the object with the given metadata, retrying up to numRetries
// times. w.upload is set only when the request succeeds.
func (w *s3Writer) createUpload(metadata map[string]*string) error {
	for i := 0; i < w.numRetries; i++ {
		ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
		upload, err := w.client.CreateMultipartUploadWithContext(ctx, &awsS3.CreateMultipartUploadInput{
			Bucket:   w.bucket,
			Key:      w.key,
			Metadata: metadata,
		})
		cancel()
		if err == nil {
			// Publish the upload only on success. The SDK may return a non-nil, empty output together with an
			// error; storing that would make later writes, Close and Abort use an upload with a nil UploadId.
			w.upload = upload
			return nil
		}
		if i == w.numRetries-1 {
			return err
		}
		log.Printf("ERROR: retrying creating %s/%s, retry #%d, error: %v\n", *w.bucket, *w.key, i+1, errToString(err))
	}
	return errors.New("s3Writer.createUpload: unreachable")
}

// flush submits the buffered data as the next part of the multipart upload, creating the upload first if needed,
// and empties the buffer. It does not wait for the part upload to finish.
func (w *s3Writer) flush() error {
	if w.upload == nil {
		if err := w.createUpload(w.metadata); err != nil {
			return err
		}
	}

	// Part numbers start at 1 and follow submission order.
	partNum := int64(len(w.futures)) + 1
	// The buffer is reset (and its storage reused) below while the upload job runs concurrently, so the job needs
	// its own copy of the data.
	data := copyBytes(w.buf.Bytes())
	partLabel := strconv.FormatInt(partNum, 10)
	labels := pprof.Labels("name", "s3-upload-part", "bucket", *w.bucket, "key", *w.key, "part", partLabel)
	w.futures = append(w.futures, w.uploadPool.SubmitWithLabels(labels, func() (interface{}, error) {
		return w.uploadPart(partNum, data)
	}))
	w.buf.Reset()
	return nil
}

// copyBytes returns a freshly allocated copy of b (never nil, even for empty input).
func copyBytes(b []byte) []byte {
	dup := make([]byte, len(b))
	copy(dup, b)
	return dup
}

// uploadPart uploads data as part partNum of the multipart upload, retrying up to numRetries times, and returns
// the CompletedPart (ETag and number) needed to complete the upload.
func (w *s3Writer) uploadPart(partNum int64, data []byte) (*awsS3.CompletedPart, error) {
	startTime := time.Now()
	if w.enableDebug {
		log.Printf("uploading part %d of %s", partNum, *w.key)
	}
	for i := 0; i < w.numRetries; i++ {
		ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
		// Use the requested bucket/key like the other writer methods do, rather than the response's optional
		// Bucket/Key fields.
		result, err := w.client.UploadPartWithContext(ctx, &awsS3.UploadPartInput{
			Body:       bytes.NewReader(data),
			Bucket:     w.bucket,
			Key:        w.key,
			UploadId:   w.upload.UploadId,
			PartNumber: &partNum,
		})
		// The body is fully sent once the call returns; cancel per attempt instead of deferring in the loop.
		cancel()
		if err != nil {
			if i != w.numRetries-1 {
				log.Printf("ERROR: retrying uploading part %d of %s/%s, retry #%d, error: %v\n", partNum,
					*w.bucket, *w.key, i+1, errToString(err))
				continue
			}
			return nil, err
		}

		if w.enableDebug {
			log.Printf("uploaded part %d of %s/%s in %v\n", partNum, *w.bucket, *w.key, time.Since(startTime))
		}
		completedPart := &awsS3.CompletedPart{
			ETag:       result.ETag,
			PartNumber: &partNum,
		}
		metrics.S3PartWrites.Inc()
		return completedPart, nil
	}
	return nil, errors.New("s3Writer.uploadPart: unreachable")
}

// completeUpload finishes the multipart upload with the given parts, retrying up to numRetries times. parts are
// expected in ascending part-number order, which is the order Close collects them in.
func (w *s3Writer) completeUpload(parts []*awsS3.CompletedPart) error {
	for i := 0; i < w.numRetries; i++ {
		ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
		_, err := w.client.CompleteMultipartUploadWithContext(ctx, &awsS3.CompleteMultipartUploadInput{
			Bucket:   w.bucket,
			Key:      w.key,
			UploadId: w.upload.UploadId,
			MultipartUpload: &awsS3.CompletedMultipartUpload{
				Parts: parts,
			},
		})
		// Cancel per attempt instead of deferring inside the loop.
		cancel()
		if err != nil {
			if i != w.numRetries-1 {
				log.Printf("ERROR: retrying closing %s/%s, retry #%d, error: %v\n", *w.bucket, *w.key, i+1, errToString(err))
				continue
			}
			return err
		}
		if w.enableDebug {
			log.Printf("completed object %s/%s\n", *w.bucket, *w.key)
		}
		return nil
	}
	return errors.New("s3Writer.completeUpload: unreachable")
}

// putObject stores data as the whole object with a single PutObject (attaching w.metadata), retrying up to
// numRetries times. It is used by Close when the data never filled a part.
func (w *s3Writer) putObject(data []byte) error {
	startTime := time.Now()
	if w.enableDebug {
		log.Printf("putting object %s/%s\n", *w.bucket, *w.key)
	}
	for i := 0; i < w.numRetries; i++ {
		ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
		_, err := w.client.PutObjectWithContext(ctx, &awsS3.PutObjectInput{
			Body:     bytes.NewReader(data),
			Bucket:   w.bucket,
			Key:      w.key,
			Metadata: w.metadata,
		})
		// Cancel per attempt instead of deferring inside the loop.
		cancel()
		if err != nil {
			if i != w.numRetries-1 {
				log.Printf("ERROR: retrying putting %s/%s, retry #%d, error: %v\n", *w.bucket, *w.key, i+1, errToString(err))
				continue
			}
			return err
		}
		if w.enableDebug {
			log.Printf("put whole object %s/%s in %v\n", *w.bucket, *w.key, time.Since(startTime))
		}
		return nil
	}
	return errors.New("s3Writer.putObject: unreachable")
}
