package cache

import (
	"context"
	"fmt"
	"sync"
	"time"

	"github.com/karlseguin/ccache/v2"

	"a.yandex-team.ru/library/go/yandex/unistat"
	"a.yandex-team.ru/library/go/yandex/unistat/aggr"

	"a.yandex-team.ru/infra/dist/repo-daemon/internal/cacus"
	"a.yandex-team.ru/infra/dist/repo-daemon/internal/mds"
	"a.yandex-team.ru/infra/dist/repo-daemon/pkg/logger"
)

const (
	// BundleFormat is the fmt layout for index bundle cache keys; callers in
	// this file fill it with repo, env and arch, in that order.
	BundleFormat         string        = "%s/%s/%s"
	// RequestBitmap flag values. DataFetched (0x1) and SignatureFetched (0x2)
	// are single bits; BundleFetched (0x3) is numerically both of them set.
	// NOTE(review): confirm that BundleFetched is meant to be that combination
	// rather than a distinct bit.
	DataFetched          RequestBitmap = 0x00000001
	SignatureFetched     RequestBitmap = 0x00000002
	BundleFetched        RequestBitmap = 0x00000003
	ReleaseFetched       RequestBitmap = 0x00000004
	ReleaseGPGFetched    RequestBitmap = 0x00000008
	InReleaseFetched     RequestBitmap = 0x00000010
	PlainFetched         RequestBitmap = 0x00000020
	GzippedFetched       RequestBitmap = 0x00000040
	BzippedFetched       RequestBitmap = 0x00000080
	// TorrentsCacheMaxSize is the MaxSize given to the torrent-id cache.
	TorrentsCacheMaxSize int64         = 50000
	// TorrentCacheTTL is a number of seconds, not a ready-to-use duration:
	// GetTorrentID multiplies it by time.Second before handing it to the cache.
	TorrentCacheTTL      time.Duration = 3600
	// SourcesCacheMaxSize is the MaxSize given to the source storage-key cache.
	SourcesCacheMaxSize  int64         = 30000
	// SourcesCacheTTL is a number of seconds; GetSourceStorageKey multiplies it
	// by time.Second before use.
	SourcesCacheTTL      time.Duration = 3600
)

// DistCache groups the in-memory caches of this package with the database
// and MDS clients used to fill them.
type DistCache struct {
	// dataCache maps a bundle key (BundleFormat) to a *DataItem. It is created
	// with Track() and GetSessionItem takes tracked references to its entries.
	dataCache    *ccache.Cache
	// cacheStat holds the hit/miss counters updated on dataCache lookups.
	cacheStat    *CacheStats
	// sessionCache is keyed by bundle key (primary) and client (secondary) and
	// stores *SessionItem values; onDeleteSession runs when an entry is deleted.
	sessionCache *ccache.LayeredCache
	// cleanupCache uses the same keys as sessionCache and stores *CleanupItem.
	cleanupCache *ccache.LayeredCache
	// torrentCache maps a storage key to its rbtorrent id string.
	torrentCache *ccache.Cache
	// sourcesCache maps "repo/env/file" to the storage key of a source file.
	sourcesCache *ccache.Cache
	// byHashCache maps "repo/env/arch/sha256" to a *ByHashItem.
	byHashCache  *ccache.Cache
	// sessionTTL, dataTTL and byHashTTL hold a number of seconds; every use in
	// this file multiplies them by time.Second.
	sessionTTL   time.Duration
	dataTTL      time.Duration
	byHashTTL    time.Duration
	// indexLocks hands out one mutex per bundle key; it is held while
	// dataCache is read and refreshed for that key.
	indexLocks   IndexLocks
	// byHashLocks hands out one mutex per by-hash cache key (see FetchByHash).
	byHashLocks  IndexLocks
	// db is the database client used for index, package and source lookups.
	db           *cacus.DBClient
	// mds is the client used to download by-hash files by storage key.
	mds          *mds.Client
}

// NewDistCache creates a DistCache backed by the given database and MDS
// clients.
//
// The data and by-hash caches are both sized from param.Data (only the data
// cache is tracked); the session and cleanup caches are sized from
// param.Session, with the cleanup cache allowed twice the session MaxSize.
// The torrent and sources caches use the package constants and 128 buckets.
// The "data_hits" and "data_misses" counters are created and registered in
// unistat before any cache is constructed.
func NewDistCache(param Param, db *cacus.DBClient, mdsClient *mds.Client) *DistCache {
	dataConfig := ccache.Configure().
		Buckets(param.Data.Buckets).
		MaxSize(param.Data.MaxSize).
		ItemsToPrune(param.Data.Prune).
		Track()
	byHashConfig := ccache.Configure().
		Buckets(param.Data.Buckets).
		MaxSize(param.Data.MaxSize).
		ItemsToPrune(param.Data.Prune)
	sessionConfig := ccache.Configure().
		Buckets(param.Session.Buckets).
		MaxSize(param.Session.MaxSize).
		ItemsToPrune(param.Session.Prune).
		OnDelete(onDeleteSession)
	cleanupConfig := ccache.Configure().
		Buckets(param.Session.Buckets).
		MaxSize(param.Session.MaxSize * 2).
		ItemsToPrune(param.Session.Prune)
	torrentConfig := ccache.Configure().MaxSize(TorrentsCacheMaxSize).Buckets(128)
	sourcesConfig := ccache.Configure().MaxSize(SourcesCacheMaxSize).Buckets(128)

	hits := unistat.NewNumeric("data_hits", 1, aggr.Counter(), unistat.Sum)
	misses := unistat.NewNumeric("data_misses", 1, aggr.Counter(), unistat.Sum)
	unistat.Register(hits)
	unistat.Register(misses)

	return &DistCache{
		dataCache:    ccache.New(dataConfig),
		byHashCache:  ccache.New(byHashConfig),
		sessionCache: ccache.Layered(sessionConfig),
		cleanupCache: ccache.Layered(cleanupConfig),
		torrentCache: ccache.New(torrentConfig),
		sourcesCache: ccache.New(sourcesConfig),
		cacheStat:    &CacheStats{Hits: hits, Misses: misses},
		sessionTTL:   time.Duration(param.Session.TTL),
		dataTTL:      time.Duration(param.Data.TTL),
		byHashTTL:    time.Duration(param.Data.ByHashTTL),
		indexLocks:   IndexLocks{bundle: make(map[string]*sync.Mutex)},
		byHashLocks:  IndexLocks{bundle: make(map[string]*sync.Mutex)},
		db:           db,
		mds:          mdsClient,
	}
}

// CacheStats holds the data-cache counters. NewDistCache creates them under
// the names "data_hits" and "data_misses" and registers them in unistat.
type CacheStats struct {
	// Hits is incremented when a cached bundle is served, including the case
	// where an expired entry is re-validated without refetching the bundle.
	Hits   *unistat.Numeric
	// Misses is incremented when a bundle is absent or had to be refetched.
	Misses *unistat.Numeric
}

// IndexLocks is a registry of per-key mutexes; see its Get method.
type IndexLocks struct {
	// bundle maps a key to the mutex created for it.
	bundle map[string]*sync.Mutex
	// mux guards access to bundle.
	mux    sync.Mutex
}

// DataItem is the value stored in the data cache for one bundle key.
type DataItem struct {
	// Bundle is the index bundle returned by FetchIndexBundle for Index.
	Bundle    *RepoIndexBundle
	// Index is the index document read from the database; fetchDataItem
	// compares it with a freshly read one to decide whether to refetch.
	Index     *cacus.Document
	// UpdatedAt is set from the index document's Lastupdated field.
	UpdatedAt time.Time
}

// SessionItem is the value stored in the session cache for one client of one
// bundle.
type SessionItem struct {
	// DataItem is the data-cache entry obtained through TrackingGet in
	// GetSessionItem; onDeleteSession releases it when the session is deleted.
	DataItem      *ccache.Item
	// RequestBitmap starts at 0 for a new session and is changed through
	// SetBitmap.
	RequestBitmap RequestBitmap
}

// Data returns the *DataItem held by the session's data-cache entry.
// The type assertion is unchecked, so it panics if the entry's value is not a
// *DataItem.
func (s *SessionItem) Data() *DataItem {
	return s.DataItem.Value().(*DataItem)
}

// Bitmap returns the session's current RequestBitmap.
func (s *SessionItem) Bitmap() RequestBitmap {
	return s.RequestBitmap
}

// SetBitmap replaces the session's RequestBitmap with bitmap. The write is a
// plain assignment with no locking inside this method.
func (s *SessionItem) SetBitmap(bitmap RequestBitmap) {
	s.RequestBitmap = bitmap
}

// ByHashItem is the value stored in the by-hash cache.
// fetchDataItem builds it with a positional literal, so the field order must
// stay Data, UpdatedAt, ValidBefore.
type ByHashItem struct {
	// Data is the raw content of the index file.
	Data        []byte
	// UpdatedAt is taken from the index document (Lastupdated) or from the
	// history entry, depending on which function created the item.
	UpdatedAt   time.Time
	// ValidBefore is copied from the history entry by FetchByHash and left as
	// the zero time by fetchDataItem.
	ValidBefore time.Time
}

// CleanupItem carries the two channels cleanupSession waits on for a session.
type CleanupItem struct {
	// TimeCh is the channel of a timer started by SetSessionCleanup with the
	// session TTL.
	TimeCh  <-chan time.Time
	// ReadyCh is unbuffered; a receive on it makes cleanupSession delete the
	// session. Whoever signals it lives outside this file.
	ReadyCh chan bool
}

// RequestBitmap is the type of the *Fetched flag constants (DataFetched,
// SignatureFetched, ...) and of SessionItem.RequestBitmap.
type RequestBitmap uint32

// onDeleteSession is installed as the session cache's OnDelete callback in
// NewDistCache. It releases the data-cache entry referenced by the deleted
// session. The value of item must be a *SessionItem, otherwise the assertion
// panics.
func onDeleteSession(item *ccache.Item) {
	session := item.Value().(*SessionItem)
	session.DataItem.Release()
}

// cleanupSession blocks until either cleanupItem.ReadyCh yields a value or
// the timer behind cleanupItem.TimeCh fires, then removes the (primary,
// client) entry from both the session cache and the cleanup cache. Only the
// timer path is logged.
func (c *DistCache) cleanupSession(primary, client string, cleanupItem *CleanupItem) {
	select {
	case <-cleanupItem.ReadyCh:
		// Signalled explicitly: nothing to report, just fall through to the deletes.
	case expiredAt := <-cleanupItem.TimeCh:
		logger.Debugf("Session for [%s]:%s expired at %s before all data fetched. Session deleted!", primary, client, expiredAt)
	}
	c.sessionCache.Delete(primary, client)
	c.cleanupCache.Delete(primary, client)
}

// SetSessionCleanup returns the ready channel of the cleanup entry for the
// given client and repo/env/arch bundle.
//
// If the cleanup cache already holds an entry for that pair, its channel is
// returned as is. Otherwise a new CleanupItem is created with a timer of
// sessionTTL seconds, stored in the cleanup cache with the same TTL, and a
// cleanupSession goroutine is started for it.
func (c *DistCache) SetSessionCleanup(repo, env, arch, client string) chan bool {
	primary := fmt.Sprintf(BundleFormat, repo, env, arch)
	if existing := c.cleanupCache.Get(primary, client); existing != nil {
		return existing.Value().(*CleanupItem).ReadyCh
	}

	// sessionTTL stores a number of seconds, hence the multiplication.
	ttl := time.Second * c.sessionTTL
	ready := make(chan bool)
	entry := &CleanupItem{TimeCh: time.NewTimer(ttl).C, ReadyCh: ready}
	c.cleanupCache.Set(primary, client, entry, ttl)
	go c.cleanupSession(primary, client, entry)
	return ready
}

// fetchDataItem reads the current index document for repo/env/arch from the
// database and, when needed, fetches the matching index bundle.
//
// key is used for log messages and as the prefix of by-hash cache keys.
// oldData is the previously cached item, or nil to force a fetch.
//
// It returns:
//   - (oldData, false, nil) when oldData is non-nil and cacus.IndexDiffers
//     reports no difference between its index and the fresh one;
//   - (newItem, true, nil) when a bundle was fetched; if the bundle has ByHash
//     set, its plain, gzip and bzip2 variants are also stored in the by-hash
//     cache under "<key>/<sha256>";
//   - (nil, false, err) when the index lookup or the bundle fetch fails. The
//     bundle error is wrapped with %w so callers can inspect the cause with
//     errors.Is / errors.As.
func (c *DistCache) fetchDataItem(ctx context.Context, key, repo, env, arch string, oldData *DataItem) (*DataItem, bool, error) {
	logger.Debugf("%s: refreshing cache", key)
	if oldData != nil {
		logger.Debugf("%s: old data present", key)
	} else {
		logger.Debugf("%s: old data missing", key)
	}

	index, err := c.db.GetIndexWithTimeout(ctx, repo, env, arch, cacus.DBTimeout)
	if err != nil {
		return nil, false, err
	}

	if oldData != nil && !cacus.IndexDiffers(oldData.Index, index) {
		logger.Debugf("%s: old and new index data equals. No update required", key)
		return oldData, false, nil
	}

	logger.Debugf("%s: old and new index data differs. Updating it", key)
	logger.Infof("Fetching index bundle for %s", key)
	fetchStartTime := time.Now()
	// Logged on every exit below this point, including a failed fetch.
	defer func() {
		logger.Infof("Finished fetching bundle for %s in %.3fs", key, time.Since(fetchStartTime).Seconds())
	}()

	bundle, err := c.FetchIndexBundle(ctx, index)
	if err != nil {
		return nil, false, fmt.Errorf("cannot fetch bundle(%s/%s/%s) for %s: %w", repo, env, arch, key, err)
	}

	if bundle.ByHash {
		// byHashTTL stores a number of seconds. ValidBefore is left at its zero value.
		ttl := time.Second * c.byHashTTL
		c.byHashCache.Set(fmt.Sprintf("%s/%s", key, index.PlainSHA256), &ByHashItem{Data: bundle.Plain, UpdatedAt: index.Lastupdated}, ttl)
		c.byHashCache.Set(fmt.Sprintf("%s/%s", key, index.GzippedSHA256), &ByHashItem{Data: bundle.GZIPed, UpdatedAt: index.Lastupdated}, ttl)
		c.byHashCache.Set(fmt.Sprintf("%s/%s", key, index.BzippedSHA256), &ByHashItem{Data: bundle.BZIP2ed, UpdatedAt: index.Lastupdated}, ttl)
	}
	return &DataItem{Bundle: bundle, Index: index, UpdatedAt: index.Lastupdated}, true, nil
}

// GetSessionItem returns the session cache entry for client on the
// repo/env/arch bundle; the entry's value is a *SessionItem.
//
// The entry is obtained with the session cache's Fetch. When Fetch runs the
// loader below, the loader, under the per-bundle lock:
//   - fetches the bundle if the data cache has no entry (counted as a miss);
//   - re-validates an expired entry through fetchDataItem, counting a miss if
//     the bundle was refetched and a hit otherwise, and stores the result
//     again with a fresh TTL;
//   - counts a hit for a live entry;
//
// and then takes a tracked reference to the data-cache entry for the new
// session. An error is returned if fetching fails or if the data-cache entry
// is no longer present when the tracked reference is taken.
func (c *DistCache) GetSessionItem(ctx context.Context, repo, env, arch, client string) (*ccache.Item, error) {
	primary := fmt.Sprintf(BundleFormat, repo, env, arch)
	session, err := c.sessionCache.Fetch(primary, client, time.Second*c.sessionTTL, func() (interface{}, error) {
		// Serialise lookups and refreshes of the same bundle.
		bundleLock := c.indexLocks.Get(primary)
		bundleLock.Lock()
		defer bundleLock.Unlock()

		cached := c.dataCache.Get(primary)
		switch {
		case cached == nil:
			// The miss is counted before the fetch, so failed fetches count too.
			c.cacheStat.Misses.Update(1)
			dataItem, _, err := c.fetchDataItem(ctx, primary, repo, env, arch, nil)
			if err != nil {
				return nil, err
			}
			c.dataCache.Set(primary, dataItem, time.Second*c.dataTTL)
		case cached.Expired():
			dataItem, updated, err := c.fetchDataItem(ctx, primary, repo, env, arch, cached.Value().(*DataItem))
			if err != nil {
				return nil, err
			}
			if updated {
				c.cacheStat.Misses.Update(1)
			} else {
				c.cacheStat.Hits.Update(1)
			}
			c.dataCache.Set(primary, dataItem, time.Second*c.dataTTL)
		default:
			c.cacheStat.Hits.Update(1)
		}

		// TrackingGet yields a placeholder that is not a *ccache.Item when the
		// key is absent (e.g. the entry was pruned right after being stored).
		// Report that as an error instead of panicking on the assertion.
		tracked, ok := c.dataCache.TrackingGet(primary).(*ccache.Item)
		if !ok {
			return nil, fmt.Errorf("%s: index bundle is no longer present in data cache", primary)
		}
		return &SessionItem{DataItem: tracked, RequestBitmap: 0}, nil
	})
	if err != nil {
		return nil, err
	}
	return session, nil
}

// GetDataBundle returns the cached DataItem for the repo/env/arch bundle,
// fetching or refreshing it under the per-bundle lock when required.
//
//   - No cache entry: a miss is counted, the bundle is fetched and stored.
//   - Live entry and force is false: a hit is counted and the cached value is
//     returned.
//   - Expired entry or force is true: fetchDataItem is called (with no old
//     data when force is set, so the bundle is always refetched); a miss is
//     counted if the bundle was refetched, a hit otherwise, and the result is
//     stored again with a fresh TTL.
//
// Any error from fetchDataItem is returned unchanged.
func (c *DistCache) GetDataBundle(ctx context.Context, repo, env, arch string, force bool) (*DataItem, error) {
	key := fmt.Sprintf(BundleFormat, repo, env, arch)
	lock := c.indexLocks.Get(key)
	lock.Lock()
	defer lock.Unlock()

	cached := c.dataCache.Get(key)
	if cached == nil {
		// Counted before fetching, so a failed fetch is still a miss.
		c.cacheStat.Misses.Update(1)
		fresh, _, err := c.fetchDataItem(ctx, key, repo, env, arch, nil)
		if err != nil {
			return nil, err
		}
		c.dataCache.Set(key, fresh, time.Second*c.dataTTL)
		return fresh, nil
	}

	if !cached.Expired() && !force {
		c.cacheStat.Hits.Update(1)
		return cached.Value().(*DataItem), nil
	}

	// A nil previous item makes fetchDataItem skip the index comparison.
	var previous *DataItem
	if !force {
		previous = cached.Value().(*DataItem)
	}
	fresh, updated, err := c.fetchDataItem(ctx, key, repo, env, arch, previous)
	if err != nil {
		return nil, err
	}
	if updated {
		c.cacheStat.Misses.Update(1)
	} else {
		c.cacheStat.Hits.Update(1)
	}
	c.dataCache.Set(key, fresh, time.Second*c.dataTTL)
	return fresh, nil
}

// GetTorrentID returns the rbtorrent id of the deb identified by storageKey
// in repo, using the torrent cache (TTL of TorrentCacheTTL seconds).
//
// On a cache miss the package is looked up in the database by storage key and
// its Debs are scanned for the one with a matching StorageKey. An error is
// returned when the lookup fails, when no deb matches, or when the matching
// deb has an empty RBTorrentID.
func (c *DistCache) GetTorrentID(ctx context.Context, repo, storageKey string) (string, error) {
	lookup := func() (interface{}, error) {
		logger.Debugf("looking for rbtorrent_id for %s:%s in database", repo, storageKey)
		pkg, err := c.db.FindPackageByStorageKey(ctx, repo, storageKey)
		if err != nil {
			logger.Errorf("package with torrent id for: %s not found", storageKey)
			return "", err
		}
		for _, deb := range pkg.Debs {
			if deb.StorageKey != storageKey {
				continue
			}
			// Only the first matching deb is considered.
			if deb.RBTorrentID == "" {
				return deb.RBTorrentID, fmt.Errorf("no rbtorrent_id found for: %s", storageKey)
			}
			return deb.RBTorrentID, nil
		}
		return "", fmt.Errorf("no rbtorrent_id found for: %s", storageKey)
	}

	entry, err := c.torrentCache.Fetch(storageKey, time.Second*TorrentCacheTTL, lookup)
	if err != nil {
		return "", err
	}
	return entry.Value().(string), nil
}

// GetSourceStorageKey returns the storage key of the source file `file` in
// repo/env, using the sources cache under the key "repo/env/file" (TTL of
// SourcesCacheTTL seconds). On a cache miss the key is looked up in the
// database; a lookup error is returned unchanged.
func (c *DistCache) GetSourceStorageKey(ctx context.Context, repo, env, file string) (string, error) {
	loadFromDB := func() (interface{}, error) {
		logger.Debugf("looking for storage_key for source file: %s/%s/%s", repo, env, file)
		found, lookupErr := c.db.FindSourceKeyBySourceFile(ctx, repo, env, file)
		if lookupErr != nil {
			return "", lookupErr
		}
		return found, nil
	}

	cacheKey := fmt.Sprintf("%s/%s/%s", repo, env, file)
	entry, err := c.sourcesCache.Fetch(cacheKey, SourcesCacheTTL*time.Second, loadFromDB)
	if err != nil {
		return "", err
	}
	return entry.Value().(string), nil
}

// Size returns the size reported by the item's Bundle.
// NOTE(review): presumably this is what the data cache uses to weigh entries
// against its MaxSize — confirm against the ccache documentation.
func (d *DataItem) Size() int64 {
	return d.Bundle.Size()
}

// Size returns the length of Data in bytes.
// NOTE(review): presumably this is what the by-hash cache uses to weigh
// entries against its MaxSize — confirm against the ccache documentation.
func (b *ByHashItem) Size() int64 {
	return int64(len(b.Data))
}

// GetByHash looks up "repo/env/arch/hash" in the by-hash cache and returns
// the stored *ByHashItem, or nil when the cache has no entry for that key.
// It never loads anything itself.
func (c *DistCache) GetByHash(repo, env, arch, hash string) *ByHashItem {
	cacheKey := fmt.Sprintf("%s/%s/%s/%s", repo, env, arch, hash)
	cached := c.byHashCache.Get(cacheKey)
	if cached == nil {
		return nil
	}
	return cached.Value().(*ByHashItem)
}

// FetchByHash returns the by-hash file described by entry, keyed in the
// by-hash cache as "repo/env/arch/sha256" with a TTL of byHashTTL seconds.
//
// The per-key lock from byHashLocks is taken and released inside this call.
// When the cache's Fetch needs a value, the file is downloaded from MDS by
// entry.StorageKey (bounded by mds.RequestTimeout) and stored together with
// the entry's UpdatedAt and ValidBefore. A download error is returned
// unchanged.
func (c *DistCache) FetchByHash(ctx context.Context, repo string, entry *cacus.PackageIndexHistoryEntry) (*ByHashItem, error) {
	cacheKey := fmt.Sprintf("%s/%s/%s/%s", repo, entry.Env, entry.Arch, entry.SHA256)

	keyLock := c.byHashLocks.Get(cacheKey)
	keyLock.Lock()
	defer keyLock.Unlock()

	download := func() (interface{}, error) {
		payload, err := c.mds.GetFileByKeyWithTimeout(ctx, entry.StorageKey, mds.RequestTimeout)
		if err != nil {
			return nil, err
		}
		return &ByHashItem{
			Data:        payload,
			UpdatedAt:   entry.UpdatedAt,
			ValidBefore: entry.ValidBefore,
		}, nil
	}

	cached, err := c.byHashCache.Fetch(cacheKey, c.byHashTTL*time.Second, download)
	if err != nil {
		return nil, err
	}
	return cached.Value().(*ByHashItem), nil
}

// Get returns the mutex registered for key, creating and registering a new
// one on first use. Repeated calls with the same key return the same mutex.
//
// The registry map is created lazily, so a zero-value IndexLocks is usable.
// The registry's own lock is released with defer, so it is not left held if
// the body panics. This method never removes entries from the map.
func (l *IndexLocks) Get(key string) *sync.Mutex {
	l.mux.Lock()
	defer l.mux.Unlock()

	if l.bundle == nil {
		l.bundle = make(map[string]*sync.Mutex)
	}
	keyLock, ok := l.bundle[key]
	if !ok {
		keyLock = &sync.Mutex{}
		l.bundle[key] = keyLock
	}
	return keyLock
}
