package sessionstorage

import (
	"bytes"
	"context"
	"fmt"
	"net/http"
	"net/http/httputil"
	"path"
	"strconv"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3manager"
	"github.com/klauspost/compress/zstd"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/security/gideon/viewer/internal/models"
	"a.yandex-team.ru/security/libs/go/yahttp"
)

// region is the AWS region name passed to the S3 client configuration.
const region = "yandex"

// Storage stores zstd-compressed SSH session payloads in an S3-compatible bucket.
// Instances are created with NewStorage and released with Close.
type Storage struct {
	cfg           Options                // connection settings the storage was built from
	s3Client      *s3.S3                 // low-level S3 API client shared by the helpers below
	uploader      *s3manager.Uploader    // used by upload
	downloader    *s3manager.Downloader  // used by download
	downloadProxy *httputil.ReverseProxy // proxy towards the "<bucket>.<endpoint>" host
	log           log.Logger
	dec           *zstd.Decoder // decompresses payloads returned by download
	enc           *zstd.Encoder // compresses payloads before upload
}

// Options holds the S3 endpoint, target bucket and static credentials used by Storage.
type Options struct {
	Endpoint        string
	Bucket          string
	AccessKeyID     string
	SecretAccessKey string
}

// NewStorage creates a Storage for the bucket and endpoint described by opts.
//
// It builds the zstd encoder/decoder used for session payloads, initializes
// the S3 client, makes sure the bucket exists (creating it when HeadBucket
// fails), and then wires the uploader, downloader and download proxy.
// If any step fails, the zstd encoder/decoder created so far are released
// before the error is returned, so a failed construction leaks nothing.
func NewStorage(opts Options, logger log.Logger) (*Storage, error) {
	enc, err := zstd.NewWriter(nil, zstd.WithEncoderConcurrency(5), zstd.WithEncoderLevel(zstd.SpeedDefault))
	if err != nil {
		return nil, fmt.Errorf("can't create zstd writer: %w", err)
	}

	dec, err := zstd.NewReader(nil, zstd.WithDecoderConcurrency(10))
	if err != nil {
		// Best-effort release: the construction error is what gets reported.
		_ = enc.Close()
		return nil, fmt.Errorf("can't create zstd reader: %w", err)
	}

	// releaseCodecs frees the zstd resources when construction is aborted
	// after both codecs were created.
	releaseCodecs := func() {
		dec.Close()
		_ = enc.Close()
	}

	storage := &Storage{
		cfg: opts,
		log: logger,
		enc: enc,
		dec: dec,
	}

	if err := storage.initS3Client(); err != nil {
		releaseCodecs()
		return nil, fmt.Errorf("failed to initialize S3 client: %w", err)
	}

	if err := storage.initBucket(); err != nil {
		releaseCodecs()
		return nil, fmt.Errorf("failed to initialize S3 bucket: %w", err)
	}

	storage.initUploader()
	storage.initDownloader()
	storage.initDownloadProxy()

	return storage, nil
}

// Close releases the zstd decoder and encoder held by the storage and
// returns the encoder's Close error. The Storage must not be used for
// upload/download after Close.
func (s *Storage) Close() error {
	// zstd.Decoder.Close has no return value; the decoder must be closed to
	// release its resources and cannot be reused afterwards.
	s.dec.Close()
	return s.enc.Close()
}

// initDownloadProxy prepares a reverse proxy that forwards requests over
// plain HTTP to the "<bucket>.<endpoint>" host, reusing the transport of the
// S3 client's HTTP client.
//
// NOTE(review): assumes cfg.Endpoint is a bare host[:port] without a scheme.
func (s *Storage) initDownloadProxy() {
	host := fmt.Sprintf("%s.%s", s.cfg.Bucket, s.cfg.Endpoint)

	director := func(req *http.Request) {
		req.URL.Scheme = "http"
		req.URL.Host = host
		// httputil.ReverseProxy forwards the inbound Host header unless the
		// Director overrides req.Host; the upstream must see the bucket host,
		// not the host the client used to reach this service.
		req.Host = host
	}

	s.downloadProxy = &httputil.ReverseProxy{
		Director:  director,
		Transport: s.s3Client.Config.HTTPClient.Transport,
	}
}

// initBucket makes sure the configured bucket is usable. When HeadBucket
// succeeds nothing else happens; on any HeadBucket error the bucket is
// created, and the CreateBucket error (if any) is returned.
func (s *Storage) initBucket() error {
	headInput := &s3.HeadBucketInput{Bucket: aws.String(s.cfg.Bucket)}
	if _, headErr := s.s3Client.HeadBucket(headInput); headErr == nil {
		return nil
	}

	createInput := &s3.CreateBucketInput{Bucket: aws.String(s.cfg.Bucket)}
	if _, createErr := s.s3Client.CreateBucket(createInput); createErr != nil {
		return createErr
	}
	return nil
}

// initS3Client validates the static credentials from the options and builds
// the S3 client for the configured endpoint. The client uses a yahttp HTTP
// client that does not follow redirects, with a 5 minute request timeout and
// a 1 second dial timeout.
func (s *Storage) initS3Client() error {
	creds := credentials.NewStaticCredentials(s.cfg.AccessKeyID, s.cfg.SecretAccessKey, "")
	if _, err := creds.Get(); err != nil {
		return fmt.Errorf("bad credentials: %w", err)
	}

	httpClient := yahttp.NewClient(yahttp.Config{
		RedirectPolicy: yahttp.RedirectNoFollow,
		Timeout:        5 * time.Minute,
		DialTimeout:    time.Second,
	})

	awsCfg := aws.NewConfig().
		WithRegion(region).
		WithEndpoint(s.cfg.Endpoint).
		WithCredentials(creds).
		WithHTTPClient(httpClient)

	sess, err := session.NewSession()
	if err != nil {
		return fmt.Errorf("failed to create session: %w", err)
	}

	s.s3Client = s3.New(sess, awsCfg)
	return nil
}

// initUploader builds the multipart-capable uploader on top of the S3 client.
func (s *Storage) initUploader() {
	up := s3manager.NewUploaderWithClient(s.s3Client)
	s.uploader = up
}

// initDownloader builds the s3manager downloader on top of the S3 client.
func (s *Storage) initDownloader() {
	down := s3manager.NewDownloaderWithClient(s.s3Client)
	s.downloader = down
}

// NewSession compresses data and stores it under the "new/" key for sessionInfo.
func (s *Storage) NewSession(ctx context.Context, sessionInfo models.SSHSessionInfo, data []byte) error {
	key := newSessionPath(sessionInfo)
	return s.upload(ctx, key, data)
}

// IsSessionExist reports whether the dated (archived) object for sessionInfo
// can be found in the bucket.
func (s *Storage) IsSessionExist(ctx context.Context, sessionInfo models.SSHSessionInfo) bool {
	key := sessionPath(sessionInfo)
	return s.sessionExist(ctx, key)
}

// IsNewSessionExist reports whether the "new/" object for sessionInfo can be
// found in the bucket.
func (s *Storage) IsNewSessionExist(ctx context.Context, sessionInfo models.SSHSessionInfo) bool {
	key := newSessionPath(sessionInfo)
	return s.sessionExist(ctx, key)
}

// sessionExist issues a HeadObject for key and reports whether it succeeded.
//
// NOTE(review): every HeadObject error, including transient network or auth
// failures, is reported as "does not exist".
func (s *Storage) sessionExist(ctx context.Context, key string) bool {
	input := &s3.HeadObjectInput{
		Bucket: aws.String(s.cfg.Bucket),
		Key:    aws.String(key),
	}

	// TODO(buglloc): check me, plz
	if _, err := s.s3Client.HeadObjectWithContext(ctx, input); err != nil {
		return false
	}
	return true
}

// UploadSession compresses data and stores it under the dated key for sessionInfo.
func (s *Storage) UploadSession(ctx context.Context, sessionInfo models.SSHSessionInfo, data []byte) error {
	key := sessionPath(sessionInfo)
	return s.upload(ctx, key, data)
}

// upload zstd-compresses data in memory and uploads the result to key in the
// configured bucket, returning the uploader's error.
func (s *Storage) upload(ctx context.Context, key string, data []byte) error {
	compressed := s.enc.EncodeAll(data, nil)

	input := &s3manager.UploadInput{
		Bucket: aws.String(s.cfg.Bucket),
		Key:    aws.String(key),
		Body:   bytes.NewReader(compressed),
	}

	if _, err := s.uploader.UploadWithContext(ctx, input); err != nil {
		return err
	}
	return nil
}

// DownloadSession fetches and decompresses the dated object for sessInfo.
func (s *Storage) DownloadSession(ctx context.Context, sessInfo models.SSHSessionInfo) ([]byte, error) {
	key := sessionPath(sessInfo)
	return s.download(ctx, key)
}

// DownloadNewSession fetches and decompresses the "new/" object for sessInfo.
func (s *Storage) DownloadNewSession(ctx context.Context, sessInfo models.SSHSessionInfo) ([]byte, error) {
	key := newSessionPath(sessInfo)
	return s.download(ctx, key)
}

// download fetches key from the configured bucket into memory and returns
// the zstd-decompressed payload.
func (s *Storage) download(ctx context.Context, key string) ([]byte, error) {
	buf := new(aws.WriteAtBuffer)
	input := &s3.GetObjectInput{
		Bucket: aws.String(s.cfg.Bucket),
		Key:    aws.String(key),
	}

	if _, err := s.downloader.DownloadWithContext(ctx, buf, input); err != nil {
		return nil, err
	}

	compressed := buf.Bytes()
	return s.dec.DecodeAll(compressed, nil)
}

// DeleteNewSession removes the "new/" object for sessInfo from the bucket.
func (s *Storage) DeleteNewSession(ctx context.Context, sessInfo models.SSHSessionInfo) error {
	input := &s3.DeleteObjectInput{
		Bucket: aws.String(s.cfg.Bucket),
		Key:    aws.String(newSessionPath(sessInfo)),
	}

	if _, err := s.s3Client.DeleteObjectWithContext(ctx, input); err != nil {
		return err
	}
	return nil
}

// DeleteFolder batch-deletes every object whose key starts with folder + "/".
// Leading and trailing slashes in folder are ignored.
func (s *Storage) DeleteFolder(ctx context.Context, folder string) error {
	prefix := strings.Trim(folder, "/") + "/"

	objects := s3manager.NewDeleteListIterator(s.s3Client, &s3.ListObjectsInput{
		Bucket: aws.String(s.cfg.Bucket),
		Prefix: aws.String(prefix),
	})

	batcher := s3manager.NewBatchDeleteWithClient(s.s3Client)
	err := batcher.Delete(ctx, objects)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to delete objects: %w", err)
}

// ListNewSessions lists the sessions stored under the "new/" prefix.
func (s *Storage) ListNewSessions(ctx context.Context) ([]models.SSHSessionInfo, error) {
	const newDir = "new"
	return s.listSessions(ctx, newDir)
}

// ListSessions lists the sessions stored under the date directory
// ("2006-01-02", local time) that lies daysAgo*24h before now.
func (s *Storage) ListSessions(ctx context.Context, daysAgo int) ([]models.SSHSessionInfo, error) {
	offset := time.Duration(daysAgo) * 24 * time.Hour
	day := time.Now().Add(-offset)
	return s.listSessions(ctx, day.Format("2006-01-02"))
}

// listSessions lists every object under dir + "/" and parses each object's
// base name into an SSHSessionInfo. Objects with a nil key or a name that
// parseSessionFilename rejects are logged and skipped.
//
// All result pages are walked: a single ListObjects call returns at most one
// page, which silently truncated large directories.
//
// On success the returned slice is non-nil (possibly empty).
func (s *Storage) listSessions(ctx context.Context, dir string) ([]models.SSHSessionInfo, error) {
	out := make([]models.SSHSessionInfo, 0)

	err := s.s3Client.ListObjectsPagesWithContext(
		ctx,
		&s3.ListObjectsInput{
			Bucket: aws.String(s.cfg.Bucket),
			Prefix: aws.String(dir + "/"),
		},
		func(page *s3.ListObjectsOutput, _ bool) bool {
			for _, obj := range page.Contents {
				if obj.Key == nil {
					s.log.Error("nil key",
						log.String("dir", dir),
						log.String("item", obj.String()))
					continue
				}

				parsed, err := parseSessionFilename(path.Base(*obj.Key))
				if err != nil {
					s.log.Error("can't parse session filename",
						log.String("key", *obj.Key),
						log.Error(err))
					continue
				}

				out = append(out, parsed)
			}
			// Keep paginating until the SDK reports the last page.
			return true
		},
	)
	if err != nil {
		return nil, err
	}

	return out, nil
}

// sessionFileName renders the object base name for a session:
// "<TS>_<Host>_<PodID>_<SessionID>.json.zstd".
func sessionFileName(sessInfo models.SSHSessionInfo) string {
	const layout = "%d_%s_%s_%d.json.zstd"
	return fmt.Sprintf(layout, sessInfo.TS, sessInfo.Host, sessInfo.PodID, sessInfo.SessionID)
}

// sessionPath returns the dated object key "<YYYY-MM-DD>/<file name>" for a
// session. The date is derived from TS interpreted as Unix nanoseconds,
// formatted in local time.
func sessionPath(sessInfo models.SSHSessionInfo) string {
	day := time.Unix(0, int64(sessInfo.TS)).Format("2006-01-02")
	return path.Join(day, sessionFileName(sessInfo))
}

// newSessionPath returns the object key "new/<file name>" for a session that
// has not yet been moved to its dated directory.
func newSessionPath(sessInfo models.SSHSessionInfo) string {
	name := sessionFileName(sessInfo)
	return path.Join("new", name)
}

// parseSessionFilename is the inverse of sessionFileName. It parses
// "<TS>_<Host>_<PodID>_<SessionID>.json.zstd" into an SSHSessionInfo.
//
// It returns an error when the suffix is missing, when the name does not
// split into exactly 4 "_"-separated parts, when TS is not a base-10 uint64,
// or when SessionID is not a base-10 uint32.
func parseSessionFilename(filename string) (models.SSHSessionInfo, error) {
	if !strings.HasSuffix(filename, ".json.zstd") {
		return models.SSHSessionInfo{}, fmt.Errorf("unexpected name (%s): unknown suffix", filename)
	}

	parts := strings.Split(strings.TrimSuffix(filename, ".json.zstd"), "_")
	if len(parts) != 4 {
		return models.SSHSessionInfo{}, fmt.Errorf("unexpected name (%s):must be 4 parts, not %d", filename, len(parts))
	}

	ts, err := strconv.ParseUint(parts[0], 10, 64)
	if err != nil {
		return models.SSHSessionInfo{}, fmt.Errorf("invalid session ts (%s): %w", parts[0], err)
	}

	sessID, err := strconv.ParseUint(parts[3], 10, 32)
	if err != nil {
		// parts has exactly 4 elements here; the session id is parts[3]
		// (indexing parts[4] would panic instead of returning this error).
		return models.SSHSessionInfo{}, fmt.Errorf("invalid session id (%s): %w", parts[3], err)
	}

	return models.SSHSessionInfo{
		TS:        ts,
		Host:      parts[1],
		PodID:     parts[2],
		SessionID: uint32(sessID),
	}, nil
}
