package storage

import (
	"a.yandex-team.ru/library/go/core/log"
	"fmt"
	"github.com/golang/protobuf/proto"
	"io/fs"
	"io/ioutil"
	"os"

	"a.yandex-team.ru/infra/hostctl/internal/fileutil"
	pb "a.yandex-team.ru/infra/hostctl/proto"
)

// New returns a file-backed storage for the state file at path.
// An empty path selects the default StateFile location. The file is not
// touched here; it is opened on each Load through the stored opener.
func New(path string) *storage {
	p := path
	if p == "" {
		p = StateFile
	}
	opener := func() (fs.File, error) {
		return os.Open(p)
	}
	// orig is left at its zero value (nil) until a snapshot is recorded.
	return &storage{path: p, open: opener}
}

// NewReadonly returns a Storage that loads state from path (with the same
// empty-path defaulting as New) but whose Save and SaveIfModified do nothing.
func NewReadonly(path string) Storage {
	backing := New(path)
	return &readonly{real: backing}
}

// Storage loads and persists the hostctl state.
// Read-only implementations in this package ignore both save methods.
type Storage interface {
	// Load reads the persisted state.
	Load() (*pb.HostctlState, error)
	// Save persists state unconditionally.
	Save(state *pb.HostctlState) error
	// SaveIfModified persists state only when it differs from the last
	// known snapshot, reporting its decision through l.
	SaveIfModified(l log.Logger, state *pb.HostctlState) error
}

// storage is the file-backed Storage implementation.
//
// Field order may be relied on by unkeyed composite literals; keep it stable.
type storage struct {
	// path is the state file location; Save writes to it and Load uses it in
	// error messages.
	path string
	// open returns a handle from which Load reads the serialized state.
	open func() (fs.File, error)
	// orig is the reference snapshot SaveIfModified compares against to avoid
	// redundant writes; Load sets it on success. nil means no snapshot has
	// been recorded yet.
	orig *pb.HostctlState
}

// Load reads the serialized state via s.open and unmarshals it.
//
// On success it records the loaded message in s.orig, which SaveIfModified
// uses to skip redundant writes. It then returns an independent deep copy, so
// callers may mutate the result without disturbing that snapshot.
//
// On any failure it returns a non-nil, empty *pb.HostctlState together with
// the error, so callers can continue with a blank state. s.orig is left
// unchanged in that case.
func (s *storage) Load() (*pb.HostctlState, error) {
	// fail always pairs the error with a usable (empty) state.
	fail := func(err error) (*pb.HostctlState, error) {
		return &pb.HostctlState{}, err
	}
	f, err := s.open()
	if err != nil {
		return fail(fmt.Errorf("failed open: %w", err))
	}
	defer f.Close()
	stat, err := f.Stat()
	if err != nil {
		return fail(fmt.Errorf("failed fstat: %w", err))
	}
	// Reject directories explicitly so the error names the offending path.
	if stat.IsDir() {
		return fail(fmt.Errorf("can not read state: %s is dir", s.path))
	}
	content, err := ioutil.ReadAll(f)
	if err != nil {
		return fail(fmt.Errorf("can not read state %s: %w", s.path, err))
	}
	orig := &pb.HostctlState{}
	if err := proto.Unmarshal(content, orig); err != nil {
		return fail(fmt.Errorf("failed to unmarshal %s: %w", s.path, err))
	}
	// Save loaded state to be able to compare and avoid extra writes.
	s.orig = orig
	// Hand out a deep copy so caller mutations do not leak into s.orig.
	return proto.Clone(orig).(*pb.HostctlState), nil
}

// Save marshals state and writes it to s.path via fileutil.AtomicWrite,
// passing permission bits 0644.
//
// After a successful write, s.orig is replaced with a deep copy of state.
// That keeps the snapshot SaveIfModified compares against in sync with what
// was last written, not what Load originally read. Otherwise a state that is
// saved and later reverted to the loaded value would be wrongly treated as
// unchanged and never written back.
func (s *storage) Save(state *pb.HostctlState) error {
	data, err := proto.Marshal(state)
	if err != nil {
		return fmt.Errorf("failed marshal state: %w", err)
	}
	// NOTE(review): the meaning of the trailing AtomicWrite arguments
	// (true, false, 0, 0) is defined outside this file — confirm before changing.
	if err := fileutil.AtomicWrite(s.path, data, 0644, true, false, 0, 0); err != nil {
		return err
	}
	// Copy rather than alias: callers may keep mutating state after Save.
	s.orig = proto.Clone(state).(*pb.HostctlState)
	return nil
}

// SaveIfModified calls Save only when state differs (per proto.Equal) from
// the snapshot held in s.orig. It logs which path was taken and returns
// Save's error, if any.
func (s *storage) SaveIfModified(l log.Logger, state *pb.HostctlState) error {
	if !proto.Equal(state, s.orig) {
		l.Info("Saving state to fs...")
		return s.Save(state)
	}
	l.Info("No changes - not overwriting state...")
	return nil
}

// readonly wraps a storage so that state can be loaded but never written:
// Load delegates to real, while Save and SaveIfModified return nil without
// doing anything.
type readonly struct {
	// real is the underlying file-backed storage used for reads.
	real *storage
}

// Load reads state through the wrapped storage; reading is permitted in
// read-only mode.
func (s *readonly) Load() (*pb.HostctlState, error) {
	inner := s.real
	return inner.Load()
}

// Save is a deliberate no-op: a read-only storage never persists state.
func (s *readonly) Save(*pb.HostctlState) error { return nil }

// SaveIfModified is a deliberate no-op in read-only mode: it neither logs nor
// writes, whether or not state changed.
func (s *readonly) SaveIfModified(_ log.Logger, _ *pb.HostctlState) error {
	return nil
}
