package layerstorage

import (
	"context"
	"fmt"
	"io/ioutil"
	"os"
	"path"
	"path/filepath"
	"sort"
	"strings"

	"golang.org/x/sync/singleflight"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/security/xray/internal/fsutil"
	"a.yandex-team.ru/security/xray/internal/storage/fetcher"
	"a.yandex-team.ru/security/xray/internal/storage/rescache"
)

// DefaultMaxSize is the cache size limit passed to rescache.WithMaxSize when
// no WithMaxSize-style option overrides it: 30 GiB, assuming the limit is
// expressed in bytes (layer sizes are stored in Resource.Bytes).
const DefaultMaxSize = 30 << 30

// Storage is an on-disk store of downloaded layer files rooted at
// storagePath. Each layer lives in a file named "layer-<id>" and is tracked
// by ID in a size-limited rescache.Cache.
type Storage struct {
	// storagePath is the directory that holds the "layer-<id>" files.
	storagePath string
	// fetcher downloads layers and parses layer URIs.
	fetcher *fetcher.Fetcher
	// cache indexes stored layers by ID; filled by Download and Promote.
	cache *rescache.Cache
	// NOTE(review): fetchGroup is not referenced anywhere in this file;
	// possibly unused — confirm before removing.
	fetchGroup singleflight.Group
	// log receives diagnostic messages; defaults to a no-op logger.
	log log.Logger
}

// NewStorage creates storagePath if needed and returns a Storage rooted there.
//
// By default the cache size limit is DefaultMaxSize, the fetcher runs in
// single-file mode and logging is disabled. The recognized options are the
// max-size and logger options; any other Option value is ignored.
//
// Layers already present in storagePath are registered in the cache via
// Promote. If that fails, the partially built Storage is released and a nil
// Storage is returned together with the error.
func NewStorage(storagePath string, opts ...Option) (*Storage, error) {
	if err := fsutil.CreateDir(storagePath); err != nil {
		return nil, fmt.Errorf("failed to create storage dir: %w", err)
	}

	storage := &Storage{
		storagePath: storagePath,
		log:         &nop.Logger{},
	}

	// The default goes first; a user-supplied max size is appended after it,
	// presumably so rescache applies it last and it wins — TODO confirm.
	cacheOpts := []rescache.Option{
		rescache.WithMaxSize(DefaultMaxSize),
	}
	fetcherOpts := []fetcher.Option{
		fetcher.WithSingleFileMode(),
	}

	for _, opt := range opts {
		switch o := opt.(type) {
		case maxSizeOption:
			cacheOpts = append(cacheOpts, rescache.WithMaxSize(o.maxSize))
		case loggerOption:
			storage.log = o.log
			cacheOpts = append(cacheOpts, rescache.WithLogger(o.log))
			fetcherOpts = append(fetcherOpts, fetcher.WithLogger(o.log))
		}
	}

	storage.fetcher = fetcher.NewFetcher(fetcherOpts...)
	storage.cache = rescache.NewCache(cacheOpts...)

	if err := storage.Promote(); err != nil {
		// Callers never get a usable Storage on error (and so never call
		// Close), so release the fetcher here instead of leaking it.
		storage.fetcher.Close()
		return nil, err
	}
	return storage, nil
}

// Download returns the cached resource for uri.ID. When the cache invokes the
// loader, the layer is downloaded by the fetcher into the file produced by
// makePath, and its ID, path and reported size are handed back to the cache.
// Download errors are returned unchanged.
func (s *Storage) Download(ctx context.Context, uri fetcher.URI) (*rescache.Resource, error) {
	id := uri.ID

	loadLayer := func() (*rescache.Resource, error) {
		dst := s.makePath(id)

		n, err := s.fetcher.Download(ctx, uri, dst)
		if err != nil {
			return nil, err
		}

		res := &rescache.Resource{ID: id, Path: dst, Bytes: n}
		return res, nil
	}

	return s.cache.Fetch(id, loadLayer)
}

// Promote scans storagePath and registers every existing "layer-<id>" entry
// in the cache, in order of increasing modification time. Entries whose name
// contains ".tmp-" are leftover temporary downloads and are removed instead.
// Entries without the "layer-" prefix are ignored.
//
// Only a failure to read the storage directory is returned. Failures to
// remove a temporary entry or to register a layer are logged and skipped,
// so a single bad entry does not stop the rest from being promoted.
func (s *Storage) Promote() error {
	const layerPrefix = "layer-"

	entries, err := ioutil.ReadDir(s.storagePath)
	if err != nil {
		return fmt.Errorf("failed to promote storage: %w", err)
	}

	// Oldest first. ReadDir returns entries sorted by name, so a stable sort
	// keeps a deterministic by-name order among equal mtimes.
	sort.SliceStable(entries, func(i, j int) bool {
		return entries[i].ModTime().Before(entries[j].ModTime())
	})

	for _, entry := range entries {
		name := entry.Name()
		if !strings.HasPrefix(name, layerPrefix) {
			continue
		}

		layerPath := path.Join(s.storagePath, name)
		if strings.Contains(name, ".tmp-") {
			s.log.Info("remove temporary layer", log.String("name", name))
			if err := os.RemoveAll(layerPath); err != nil {
				s.log.Warn("failed to remove temporary layer", log.String("name", name), log.Error(err))
			}
			continue
		}

		_, err := s.cache.Store(&rescache.Resource{
			ID:    strings.TrimPrefix(name, layerPrefix),
			Path:  layerPath,
			Bytes: entry.Size(),
		})
		if err != nil {
			// Best effort: keep promoting the remaining layers.
			s.log.Warn("failed to promote layer", log.String("name", name), log.Error(err))
		}
	}
	return nil
}

// Close shuts down the underlying fetcher. Files on disk and cache entries
// are left untouched.
func (s *Storage) Close() {
	f := s.fetcher
	f.Close()
}

// makePath returns the on-disk location of the layer with the given ID:
// "<storagePath>/layer-<id>".
func (s *Storage) makePath(id string) string {
	name := "layer-" + id
	return filepath.Join(s.storagePath, name)
}

// ParseURI parses uri using the storage's fetcher. On failure it returns the
// zero fetcher.URI together with the fetcher's error.
func (s *Storage) ParseURI(ctx context.Context, uri string) (fetcher.URI, error) {
	// TODO(anton-k): DRY
	parsed, err := s.fetcher.ParseURI(ctx, uri)
	if err == nil {
		return parsed, nil
	}
	return fetcher.URI{}, err
}
