package s3

import (
	"bytes"
	"fmt"
	"sort"
	"strconv"

	"github.com/aws/aws-sdk-go/aws"
	s3pkg "github.com/aws/aws-sdk-go/service/s3"
)

// UploadPartInput describes a single part handed to UploadPart.
type UploadPartInput struct {
	// Key is the object key, and UploadId the multipart upload identifier
	// (as returned by InitiateMultipartUpload) that the part belongs to.
	Key, UploadId string
	// PartNumber is sent to S3 as the part number of this part.
	PartNumber int
	// Body supplies the bytes of the part.
	Body *bytes.Reader
	// Length is sent as the request ContentLength; presumably it must match
	// the number of bytes in Body — TODO confirm callers keep them in sync.
	Length int
}

// InitiateMultipartUpload starts a multipart upload for key in s.Bucket and
// returns the upload ID to pass to UploadPart and CompleteMultipartUpload.
// The attempt is logged whether or not it succeeds. Errors from the S3 client
// are returned unwrapped; a successful response that carries no upload ID is
// reported as an error instead of returning an empty ID.
func (s *s3) InitiateMultipartUpload(key string) (string, error) {
	input := &s3pkg.CreateMultipartUploadInput{
		Bucket: aws.String(s.Bucket),
		Key:    aws.String(key),
	}
	resp, err := s.client.CreateMultipartUpload(input)
	var uploadId string
	if err == nil {
		// aws.StringValue avoids a nil-pointer panic if UploadId is absent.
		uploadId = aws.StringValue(resp.UploadId)
		if uploadId == "" {
			err = fmt.Errorf("CreateMultipartUpload returned no upload ID; key=%s", key)
		}
	}
	s.logger.Log("Initiated multipart upload; key=", key, ", uploadId=", uploadId, ", err=", err)
	return uploadId, err
}

// UploadPart uploads one part of an in-progress multipart upload and returns
// the ETag S3 reports for it. The ETag and the part number are what
// CompleteMultipartUpload needs later.
//
// Upload IDs are long opaque strings, so only their first 12 characters are
// included in log and error messages. Client errors are wrapped with %w so
// callers can still inspect them with errors.Is / errors.As.
func (s *s3) UploadPart(input *UploadPartInput) (string, error) {
	req := &s3pkg.UploadPartInput{
		Bucket:        aws.String(s.Bucket),
		Key:           aws.String(input.Key),
		UploadId:      aws.String(input.UploadId),
		PartNumber:    aws.Int64(int64(input.PartNumber)),
		ContentLength: aws.Int64(int64(input.Length)),
		Body:          input.Body,
	}
	resp, err := s.client.UploadPart(req)

	// Shorten the ID for readability, and only mark it truncated when it was.
	shortUploadId := input.UploadId
	if len(shortUploadId) > 12 {
		shortUploadId = shortUploadId[:12] + "...<truncated>"
	}

	if err != nil {
		return "", fmt.Errorf("upload part to S3 failed; key=%s, uploadId=%s, part=%d, length=%d: %w",
			input.Key, shortUploadId, input.PartNumber, input.Length, err)
	}

	// aws.StringValue avoids a nil-pointer panic if the response lacks an ETag;
	// treat a missing ETag as an error rather than handing callers "".
	etag := aws.StringValue(resp.ETag)
	if etag == "" {
		return "", fmt.Errorf("upload part to S3 returned no ETag; key=%s, uploadId=%s, part=%d",
			input.Key, shortUploadId, input.PartNumber)
	}

	s.logger.Log("Uploaded part to S3; key=", input.Key, ", uploadId=", shortUploadId,
		", part=", input.PartNumber, ", length=", input.Length, ", etag=", etag)
	return etag, nil
}

// CompleteMultipartUpload finishes the multipart upload uploadId for key.
// parts maps each part number, written as a decimal string, to the ETag
// returned when that part was uploaded. A malformed parts map is reported
// without contacting S3. Otherwise the completion attempt is logged and the
// client's error, if any, is returned as-is.
func (s *s3) CompleteMultipartUpload(key, uploadId string, parts map[string]string) error {
	completed, err := s.mapToCompletedParts(parts)
	if err != nil {
		return err
	}

	_, err = s.client.CompleteMultipartUpload(&s3pkg.CompleteMultipartUploadInput{
		Bucket:   aws.String(s.Bucket),
		Key:      aws.String(key),
		UploadId: aws.String(uploadId),
		MultipartUpload: &s3pkg.CompletedMultipartUpload{
			Parts: completed,
		},
	})
	s.logger.Log("Completed multipart upload; key=", key, ", uploadId=", uploadId, ", err=", err)
	return err
}

// mapToCompletedParts converts a map from part-number strings to ETags into
// the slice form used by CompleteMultipartUpload. The result is sorted by
// ascending part number, because S3 requires the parts to be listed in
// ascending order. Part numbers need not be contiguous.
//
// It returns the strconv error if a key is not a decimal integer. It returns
// an error if a part number is less than 1, or if two keys (e.g. "1" and
// "01") name the same part.
func (*s3) mapToCompletedParts(parts map[string]string) ([]*s3pkg.CompletedPart, error) {
	completedParts := make([]*s3pkg.CompletedPart, 0, len(parts))
	for key, etag := range parts {
		partNumber, err := strconv.Atoi(key)
		if err != nil {
			return nil, err
		}
		if partNumber < 1 {
			return nil, fmt.Errorf("invalid part number %q: must be >= 1", key)
		}
		completedParts = append(completedParts, &s3pkg.CompletedPart{
			ETag:       aws.String(etag),
			PartNumber: aws.Int64(int64(partNumber)),
		})
	}

	// Map iteration order is random; put the parts into ascending order.
	sort.Slice(completedParts, func(a, b int) bool {
		return *completedParts[a].PartNumber < *completedParts[b].PartNumber
	})
	for i := 1; i < len(completedParts); i++ {
		if *completedParts[i].PartNumber == *completedParts[i-1].PartNumber {
			return nil, fmt.Errorf("duplicate part number %d", *completedParts[i].PartNumber)
		}
	}
	return completedParts, nil
}
