package filereport

import (
	"context"
	"errors"
	"time"

	"github.com/karlseguin/ccache/v2"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/security/vt-proxy/internal/logger"
	"a.yandex-team.ru/security/vt-proxy/internal/models"
	"a.yandex-team.ru/security/vt-proxy/internal/virustotal"
)

// Report freshness windows.
const (
	// refreshTTL bounds how old (by CreatedAt) a report may be and still be
	// re-fetched from upstream; see isReportExpired.
	refreshTTL = 31 * 24 * time.Hour
	// cacheTTL is how long after its last update a report is considered fresh,
	// and also the lifetime of in-memory cache entries.
	cacheTTL = 24 * time.Hour
)

// In-memory cache tuning, passed verbatim to ccache.Configure in NewClient.
const (
	cacheMaxSize        = 100 * 1024
	cacheBuckets        = 16
	cacheDeleteBuffer   = 1024
	cacheGetsPerPromote = 3
	cacheItemsToPrune   = 500
	cachePromoteBuffer  = 1024
)

// ReportClient serves VirusTotal file reports, consulting an in-memory cache
// and a database before calling the VirusTotal API.
type ReportClient struct {
	vt    *virustotal.Client
	cache *ccache.Cache
	db    *models.DB
}

// Config holds the credentials and endpoints NewClient needs: the Ydb* fields
// are forwarded to models.DBConfig, VtAuthToken to virustotal.ClientConfig.
type Config struct {
	YdbAuthToken string
	YdbDatabase  string
	YdbEndpoint  string
	VtAuthToken  string
}

// ErrFetchFail is returned by GetReport when the upstream request failed and
// no stored report was available to fall back on.
var ErrFetchFail = xerrors.NewSentinel("failed to get report from upstream")

// NewClient opens the database described by cfg and returns a ReportClient
// wired to it, to a VirusTotal client authenticated with cfg.VtAuthToken, and
// to a fresh in-memory ccache cache built from the package cache constants.
// ctx is only used for opening the database. If the database cannot be
// opened, its error is returned and client is nil.
func NewClient(ctx context.Context, cfg Config) (client *ReportClient, err error) {
	db, dbErr := models.NewDB(ctx, models.DBConfig{
		AuthToken: cfg.YdbAuthToken,
		Database:  cfg.YdbDatabase,
		Endpoint:  cfg.YdbEndpoint,
	})
	if dbErr != nil {
		return nil, dbErr
	}

	vtClient := virustotal.NewClient(virustotal.ClientConfig{
		AuthToken: cfg.VtAuthToken,
	})

	cacheConfig := ccache.Configure().
		Buckets(cacheBuckets).
		DeleteBuffer(cacheDeleteBuffer).
		GetsPerPromote(cacheGetsPerPromote).
		ItemsToPrune(cacheItemsToPrune).
		MaxSize(cacheMaxSize).
		PromoteBuffer(cachePromoteBuffer)

	return &ReportClient{
		vt:    vtClient,
		cache: ccache.New(cacheConfig),
		db:    db,
	}, nil
}

// Close stops the in-memory report cache. The client should not be used
// afterwards.
//
// NOTE(review): only the cache is stopped; r.db is left open here. Confirm
// whether models.DB holds resources that need releasing.
func (r *ReportClient) Close() {
	r.cache.Stop()
}

// GetReport returns the report for the file whose sha256 is hash.
//
// Sources are tried in order:
//  1. the in-memory cache (skipped when forceFetch is set), if the entry is
//     live and isReportExpired does not flag it;
//  2. the database;
//  3. VirusTotal, when the database has no record, the lookup failed,
//     forceFetch is set, or isReportExpired flags the stored record.
//
// A report fetched from VirusTotal is saved to the database on a best-effort
// basis; save failures are only logged. In that case oldReport is the database
// record it replaces, if there was one. If VirusTotal fails but a database
// record exists, that record is returned with a nil error. ErrFetchFail is
// returned only when no report is available at all.
//
// cacheStatus tells where the report came from:
//   - HotCache: the in-memory cache.
//   - Cache: a database record that is still fresh.
//   - VtNew: no record was found in the database.
//   - VtRefresh: the record was refreshed, or a refresh was attempted (also
//     kept when the stale record is served after an upstream failure).
//
// Any returned report is also (re)stored in the in-memory cache.
func (r *ReportClient) GetReport(hash string, forceFetch bool) (oldReport *models.Report, report *models.Report, cacheStatus CacheStatus, err error) {
	if !forceFetch {
		if item := r.cache.Get(hash); item != nil && !item.Expired() {
			// Comma-ok plus nil check: a foreign or nil value must not panic here.
			if cached, ok := item.Value().(*models.Report); ok && cached != nil && !isReportExpired(cached) {
				return nil, cached, CacheStatus{CacheStatusHotCache}, nil
			}
		}
	}

	fetch := true
	stored, dbErr := r.db.LookupRecord(hash)
	switch {
	case errors.Is(dbErr, models.ErrNotFound):
		// errors.Is also matches a wrapped sentinel (e.g. one with a frame attached).
		cacheStatus.Status = CacheStatusVtNew
	case dbErr != nil:
		// The value returned alongside an error is not trusted; report stays nil.
		logger.Error("failed to fetch report", log.String("hash", hash), log.Error(dbErr))
		cacheStatus.Status = CacheStatusVtRefresh
	case stored == nil:
		// Defensive: treat "no error, no record" as a missing record instead of
		// dereferencing nil in isReportExpired.
		cacheStatus.Status = CacheStatusVtNew
	default:
		report = stored
		fetch = forceFetch || isReportExpired(report)
		if fetch {
			cacheStatus.Status = CacheStatusVtRefresh
		} else {
			cacheStatus.Status = CacheStatusCache
		}
	}

	if fetch {
		upstreamReport, upstreamErr := r.vt.FileReport(hash)
		if upstreamErr != nil {
			logger.Error("failed to fetch report from upstream, err", log.String("hash", hash), log.Error(upstreamErr))
			if report == nil {
				return nil, nil, cacheStatus, ErrFetchFail.WithFrame()
			}
			// Otherwise fall through and serve the stale database record.
		} else {
			now := time.Now()
			// Keep the original creation time when replacing an existing record.
			// NOTE(review): after a failed lookup (report == nil) a record that does
			// exist in the database gets its CreatedAt reset to now.
			createdAt := now
			if report != nil {
				createdAt = report.CreatedAt
				oldReport = report
			}

			report = &models.Report{
				Found: upstreamReport.Found,
				Md5:   upstreamReport.Md5,
				Sha1:  upstreamReport.Sha1,
				// Save sha256 hash from request for case if VT not found our file
				Sha256:    hash,
				Positives: upstreamReport.Positives,
				Total:     upstreamReport.Total,
				UpdatedAt: now,
				CreatedAt: createdAt,
				Scans:     upstreamReport.Scans,
			}

			if saveErr := r.db.InsertRecord(report); saveErr != nil {
				logger.Error("failed to save report", log.String("hash", hash), log.Error(saveErr))
			}
		}
	}

	if report != nil {
		// Let the cache entry expire when the report becomes stale
		// (UpdatedAt + cacheTTL). A report already past that point is cached for
		// a full cacheTTL from now instead.
		cacheUntil := report.UpdatedAt
		if time.Since(cacheUntil) > cacheTTL {
			cacheUntil = time.Now()
		}

		r.cache.Set(hash, report, time.Until(cacheUntil.Add(cacheTTL)))
	}
	return oldReport, report, cacheStatus, nil
}

// UpdatedToday asks the database how many reports were updated after the most
// recent midnight in the local time zone. The cutoff is passed to
// db.CountUpdatedAfter as a Unix timestamp in seconds.
func (r *ReportClient) UpdatedToday() (count uint64, err error) {
	now := time.Now()
	y, m, d := now.Date()
	startOfDay := time.Date(y, m, d, 0, 0, 0, 0, now.Location())
	return r.db.CountUpdatedAfter(startOfDay.Unix())
}

// isReportExpired reports whether a stored report should be re-fetched from
// upstream. Both conditions must hold:
//   - it was created no more than refreshTTL ago;
//   - it was last updated at least cacheTTL ago.
//
// Reports created longer ago than refreshTTL are never flagged. report must be
// non-nil.
func isReportExpired(report *models.Report) bool {
	if time.Since(report.CreatedAt) > refreshTTL {
		return false
	}
	return time.Since(report.UpdatedAt) >= cacheTTL
}
