package screening

import (
	"context"
	"fmt"
	"os"
	"time"

	"github.com/serialx/hashring"

	"a.yandex-team.ru/infra/walle/server/go/internal/utilities"
	lockrepo "a.yandex-team.ru/infra/walle/server/go/internal/utilities/repository"
	"a.yandex-team.ru/library/go/core/log"
)

// instanceLockIDPrefixTemplate is the fmt template for the common prefix of
// instance lock IDs; its %d verb is filled with the tier. getInstanceLockID
// appends the instance name to the formatted prefix, and findAllInstances
// passes the formatted prefix as a lock filter.
const instanceLockIDPrefixTemplate = "dmc-screening-%d/instance-"

// shardManager decides which shard IDs the current instance handles. It
// builds a hash ring from the lock owners found in the lock repository and
// keeps the shards that the ring maps to this instance.
type shardManager struct {
	ring         *hashring.HashRing    // rebuilt by updateHashring; nil until its first call
	instance     string                // os.Hostname() result; compared against ring nodes in getShards
	tier         uint64                // substituted into the instance lock ID prefix
	shardsNum    int                   // getShards iterates shard IDs 0..shardsNum-1
	lockRepo     lockrepo.LockRepo     // queried for lock owners; the instance lock is deleted through it
	locker       utilities.Locker      // takes and releases the instance lock
	instanceLock *utilities.LockObject // acquired in newShardManager, released in clean
	logger       log.Logger            // passed to locker.Lock; reports lock deletion failures in clean
}

// newShardManager creates a shardManager for the given tier that distributes
// shardNum shards.
//
// The instance name is taken from os.Hostname. Before returning, the function
// acquires the instance lock (see getInstanceLockID) through a Locker built on
// lockRepo; the lock is kept in the manager and released by clean. The hash
// ring is not built here: it stays nil until updateHashring runs, which
// getShards does on every call.
//
// An error is returned when the hostname cannot be resolved or the instance
// lock cannot be acquired. The underlying error is wrapped with %w, so
// errors.Is and errors.As keep working for callers.
func newShardManager(tier uint64, shardNum int, lockRepo lockrepo.LockRepo, logger log.Logger) (*shardManager, error) {
	instance, err := os.Hostname()
	if err != nil {
		return nil, fmt.Errorf("resolving hostname: %w", err)
	}

	locker := utilities.NewLocker(lockRepo)
	lockID := getInstanceLockID(tier, instance)
	lock, err := locker.Lock(logger, lockID, instance)
	if err != nil {
		return nil, fmt.Errorf("acquiring instance lock %q: %w", lockID, err)
	}

	return &shardManager{
		shardsNum:    shardNum,
		instance:     instance,
		lockRepo:     lockRepo,
		locker:       locker,
		instanceLock: lock,
		tier:         tier,
		logger:       logger,
	}, nil
}

// getShards refreshes the hash ring from the lock repository and returns the
// shards that the ring maps to this instance.
//
// Shard IDs run from 0 to shardsNum-1 and are looked up on the ring by their
// decimal string form. The result is nil when no shard maps to this instance.
// An error is returned only when the ring cannot be refreshed; the underlying
// error is wrapped with %w.
func (sm *shardManager) getShards(ctx context.Context) ([]*shard, error) {
	if err := sm.updateHashring(ctx); err != nil {
		return nil, fmt.Errorf("updating hashring: %w", err)
	}

	var res []*shard
	for i := 0; i < sm.shardsNum; i++ {
		node, ok := sm.ring.GetNode(fmt.Sprintf("%d", i))
		if !ok {
			// The ring could not resolve this key, so the shard has no owner
			// on this pass. Skip it explicitly instead of comparing the
			// zero-value node name against the instance name.
			continue
		}
		if node == sm.instance {
			res = append(res, &shard{id: i})
		}
	}
	return res, nil
}

// updateHashring replaces sm.ring with a new ring built from the owners
// returned by findAllInstances. Every owner gets the same weight. If the
// lookup fails, the error is returned as is and sm.ring is left untouched.
func (sm *shardManager) updateHashring(ctx context.Context) error {
	// Weight assigned to every node on the ring.
	const nodeWeight = 2048

	owners, err := sm.findAllInstances(ctx)
	if err != nil {
		return err
	}

	nodeWeights := make(map[string]int, len(owners))
	for idx := range owners {
		nodeWeights[owners[idx]] = nodeWeight
	}
	sm.ring = hashring.NewWithWeights(nodeWeights)

	return nil
}

// findAllInstances asks the lock repository for the owners of locks whose ID
// starts with this tier's instance lock prefix. The filter's LockedUntil is
// set to the current time minus mainIterationInterval.
func (sm *shardManager) findAllInstances(ctx context.Context) ([]string, error) {
	filter := &lockrepo.LockFilter{
		Prefix:      fmt.Sprintf(instanceLockIDPrefixTemplate, sm.tier),
		LockedUntil: time.Now().Add(-mainIterationInterval),
	}
	return sm.lockRepo.FindOwners(ctx, filter)
}

// clean releases the instance lock through the locker and then deletes the
// lock record from the repository. The deletion runs under its own 10-second
// timeout; a failure is logged and not returned.
func (sm *shardManager) clean() {
	sm.locker.Unlock(sm.instanceLock)

	deleteCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	lockID := getInstanceLockID(sm.tier, sm.instance)
	err := sm.lockRepo.Delete(deleteCtx, lockID)
	if err == nil {
		return
	}
	sm.logger.Errorf("Failed to delete an instance lock: %v", err)
}

// shard identifies a single shard by its numeric ID. getShards creates shards
// with IDs in the range [0, shardsNum).
type shard struct {
	id int
}

// getInstanceLockID returns the lock ID for an instance in a tier: the
// tier-specific prefix produced from instanceLockIDPrefixTemplate, followed by
// the instance name.
func getInstanceLockID(tier uint64, instance string) string {
	prefix := fmt.Sprintf(instanceLockIDPrefixTemplate, tier)
	return prefix + instance
}
