package memory

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"code.justin.tv/gds/gds/golibs/dynamodb/lazy"
	"code.justin.tv/gds/gds/golibs/event"
	"code.justin.tv/gds/gds/golibs/uuid"
	"code.justin.tv/extensions/fulton-configuration/data/model"
	"code.justin.tv/extensions/fulton-configuration/data/model/shared"
	"code.justin.tv/extensions/fulton-configuration/protocol"
)

// store is the in-memory model.StoreWithTracker returned by New.
// The common, channels and blocks maps are read and written only while mutex
// is held; allowReset is accessed through sync/atomic instead.
type store struct {
	// uuid produces a fresh ConcurrencyUUID on every successful save.
	uuid uuid.Source
	// clone copies src into dest; New sets it to cloneByMarshal. Records are
	// cloned on the way in (save) and on the way out (load).
	clone func(src, dest interface{}) error
	// common is keyed by commonKey(environment, extensionID).
	common map[string]*model.Common
	// channels is keyed by channelKey(environment, extensionID, channelID).
	channels map[string]*model.Channel
	// blocks holds blocked channel IDs; the value is true while the channel's
	// deletion is still in progress and false once OnDeletionFinished ran.
	blocks map[string]bool
	// allowReset is non-zero once EnableDataReset has been called.
	allowReset int32
	mutex      sync.Mutex
}

// New builds an empty in-memory store. Concurrency tokens are drawn from src,
// and records are copied with cloneByMarshal whenever they enter or leave the
// store, so callers never alias the store's internal values.
func New(src uuid.Source) model.StoreWithTracker {
	return &store{
		uuid:     src,
		clone:    cloneByMarshal,
		common:   map[string]*model.Common{},
		channels: map[string]*model.Channel{},
		blocks:   map[string]bool{},
	}
}

// AsyncLoadCommon runs LoadCommon on a separate goroutine and returns at once
// with a promise that is later resolved with the loaded record and error.
func (s *store) AsyncLoadCommon(environment, extensionID string) model.CommonPromise {
	promise := newCommonPromise()
	load := func() {
		record, err := s.LoadCommon(environment, extensionID)
		promise.inner.Set(record, err)
	}
	go load()
	return shared.CommonPromise(promise)
}

// AsyncLoadChannel runs LoadChannel on a separate goroutine and returns at once
// with a promise that is later resolved with the loaded record and error.
func (s *store) AsyncLoadChannel(environment, extensionID, channelID string) model.ChannelPromise {
	promise := newChannelPromise()
	load := func() {
		record, err := s.LoadChannel(environment, extensionID, channelID)
		promise.inner.Set(record, err)
	}
	go load()
	return shared.ChannelPromise(promise)
}

// LoadCommon returns a copy of the common record stored for environment and
// extensionID, or (nil, nil) when no such record exists. The returned value is
// independent of the store's internal copy. Any error comes from the clone.
func (s *store) LoadCommon(environment, extensionID string) (*model.Common, error) {
	// Lock first, then defer the unlock, matching the rest of the store.
	s.mutex.Lock()
	defer s.mutex.Unlock()

	stored, ok := s.common[commonKey(environment, extensionID)]
	if !ok {
		return nil, nil
	}
	return s.copyCommon(stored)
}

// LoadChannel returns a copy of the channel record stored for environment,
// extensionID and channelID. It returns (nil, nil) when no such record exists
// or when channelID is blocked. The block check looks only at channelID, not
// at environment or extensionID. Any error comes from the clone.
func (s *store) LoadChannel(environment, extensionID, channelID string) (*model.Channel, error) {
	// Lock first, then defer the unlock, matching the rest of the store.
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if _, blocked := s.blocks[channelID]; blocked {
		return nil, nil
	}
	stored, ok := s.channels[channelKey(environment, extensionID, channelID)]
	if !ok {
		return nil, nil
	}
	return s.copyChannel(stored)
}

// SaveCommon stores a copy of c under optimistic concurrency control.
//
// When no record exists yet, c.ConcurrencyUUID must be empty. When a record
// exists, c.ConcurrencyUUID must equal the stored one. Otherwise
// protocol.ErrConcurrency is returned and nothing is written.
//
// On success, the stored copy receives a freshly generated ConcurrencyUUID,
// which is also written back into c. If the copy has messages but no
// UnpublishedTime, UnpublishedTime is set to the current time. Clone or UUID
// generation errors are returned without modifying the store.
func (s *store) SaveCommon(c *model.Common) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	key := commonKey(c.Environment, c.ExtensionID)

	// The token the caller must present: the stored one, or "" for a new record.
	expected := ""
	if existing, found := s.common[key]; found {
		expected = existing.ConcurrencyUUID
	}
	if c.ConcurrencyUUID != expected {
		return protocol.ErrConcurrency
	}

	stored := &model.Common{}
	if cloneErr := s.clone(c, stored); cloneErr != nil {
		return cloneErr
	}
	token, uuidErr := s.uuid.Next()
	if uuidErr != nil {
		return uuidErr
	}
	stored.ConcurrencyUUID = token

	if len(stored.Messages) > 0 && stored.UnpublishedTime == nil {
		pendingSince := time.Now()
		stored.UnpublishedTime = &pendingSince
	}

	s.common[key] = stored
	c.ConcurrencyUUID = stored.ConcurrencyUUID
	return nil
}

// SaveChannel stores a copy of c under optimistic concurrency control.
//
// If c.ChannelID is blocked, protocol.ErrForbiddenByBroadcaster is returned.
// When no record exists yet, c.ConcurrencyUUID must be empty. When a record
// exists, c.ConcurrencyUUID must equal the stored one. Otherwise
// protocol.ErrConcurrency is returned and nothing is written.
//
// On success, the stored copy receives a freshly generated ConcurrencyUUID,
// which is also written back into c. If the copy has messages but no
// UnpublishedTime, UnpublishedTime is set to the current time. Clone or UUID
// generation errors are returned without modifying the store.
func (s *store) SaveChannel(c *model.Channel) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if _, blocked := s.blocks[c.ChannelID]; blocked {
		return protocol.ErrForbiddenByBroadcaster
	}

	key := channelKey(c.Environment, c.ExtensionID, c.ChannelID)

	// The token the caller must present: the stored one, or "" for a new record.
	expected := ""
	if existing, found := s.channels[key]; found {
		expected = existing.ConcurrencyUUID
	}
	if c.ConcurrencyUUID != expected {
		return protocol.ErrConcurrency
	}

	stored := &model.Channel{}
	if cloneErr := s.clone(c, stored); cloneErr != nil {
		return cloneErr
	}
	token, uuidErr := s.uuid.Next()
	if uuidErr != nil {
		return uuidErr
	}
	stored.ConcurrencyUUID = token

	if len(stored.Messages) > 0 && stored.UnpublishedTime == nil {
		pendingSince := time.Now()
		stored.UnpublishedTime = &pendingSince
	}

	s.channels[key] = stored
	c.ConcurrencyUUID = stored.ConcurrencyUUID
	return nil
}

// DeleteChannel blocks channelID and flags its deletion as in progress. It
// does the same thing as Block.
// Stored channel records are not removed here. They stay hidden from
// LoadChannel, and SaveChannel rejects them, for as long as the block exists.
func (s *store) DeleteChannel(channelID string) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	err := s.block(channelID)
	return err
}

func (s *store) MarkCommonPublished(c *model.Common) error {
	defer s.mutex.Unlock()
	s.mutex.Lock()
	internal, ok := s.common[commonKey(c.Environment, c.ExtensionID)]
	if !ok {
		return protocol.ErrConcurrency
	}
	if internal.ConcurrencyUUID != c.ConcurrencyUUID || internal.UnpublishedTime == nil {
		return protocol.ErrConcurrency
	}
	internal.Messages = []event.Message{}
	internal.UnpublishedTime = nil
	return nil
}

// MarkChannelPublished clears the pending messages of the stored channel
// record identified by c's Environment, ExtensionID and ChannelID, and resets
// its UnpublishedTime to nil.
//
// It returns protocol.ErrConcurrency if no record exists, if
// c.ConcurrencyUUID differs from the stored token, or if the stored record has
// no UnpublishedTime (nothing to publish). The stored ConcurrencyUUID is left
// unchanged, and c itself is not modified.
func (s *store) MarkChannelPublished(c *model.Channel) error {
	// Lock first, then defer the unlock, matching the rest of the store.
	s.mutex.Lock()
	defer s.mutex.Unlock()

	// stored is nil when !ok; the || short-circuits before it is dereferenced.
	stored, ok := s.channels[channelKey(c.Environment, c.ExtensionID, c.ChannelID)]
	if !ok || stored.ConcurrencyUUID != c.ConcurrencyUUID || stored.UnpublishedTime == nil {
		return protocol.ErrConcurrency
	}
	stored.Messages = []event.Message{}
	stored.UnpublishedTime = nil
	return nil
}

// Block adds channelID to the block set with its deletion flagged as in
// progress, replacing any existing entry. Afterwards LoadChannel hides the
// channel and SaveChannel rejects writes to it.
func (s *store) Block(channelID string) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	err := s.block(channelID)
	return err
}

// Unblock removes channelID from the block set, whether or not its deletion
// has finished. Any channel records still held for it become visible to
// LoadChannel again. Unblocking an unknown channel is a no-op. It always
// returns nil.
func (s *store) Unblock(channelID string) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	delete(s.blocks, channelID)

	return nil
}

// OnDeletionFinished clears the deletion-in-progress flag of an existing
// block, so the channel stays blocked but no longer appears in
// DeletionInProgress. Channels that are not blocked are left untouched. It
// always returns nil.
func (s *store) OnDeletionFinished(channelID string) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if _, blocked := s.blocks[channelID]; !blocked {
		return nil
	}
	s.blocks[channelID] = false
	return nil
}

// IsBlocked returns an already-resolved promise reporting whether channelID is
// in the block set, regardless of its deletion flag. The promise never
// carries an error.
func (s *store) IsBlocked(channelID string) model.BlockPromise {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	_, blocked := s.blocks[channelID]
	promise := newBlockPromise(blocked, nil)
	return promise
}

// DeletionInProgress lists the blocked channel IDs whose deletion has not yet
// been reported finished via OnDeletionFinished. The result is never nil but
// may be empty; its order follows map iteration and is not defined. The
// error is always nil.
func (s *store) DeletionInProgress() ([]string, error) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	pending := []string{}
	for channelID, inProgress := range s.blocks {
		if !inProgress {
			continue
		}
		pending = append(pending, channelID)
	}
	return pending, nil
}

// The reset helpers below are intended for test environments only:
// ResetAllData refuses to run until EnableDataReset has been called.

// IsResetEnabled reports whether EnableDataReset has been called.
func (s *store) IsResetEnabled() bool {
	return atomic.LoadInt32(&s.allowReset) != 0
}

// EnableDataReset permits subsequent ResetAllData calls. Nothing in this file
// clears the flag again.
func (s *store) EnableDataReset() {
	atomic.StoreInt32(&s.allowReset, 1)
}

// ResetAllData discards every common record, channel record and block by
// swapping in fresh empty maps. It returns protocol.ErrUnavailable, and
// changes nothing, unless EnableDataReset was called first.
func (s *store) ResetAllData() error {
	if s.IsResetEnabled() {
		s.mutex.Lock()
		defer s.mutex.Unlock()
		s.common = make(map[string]*model.Common)
		s.channels = make(map[string]*model.Channel)
		s.blocks = make(map[string]bool)
		return nil
	}
	return protocol.ErrUnavailable
}

// cloneByMarshal copies src into dest by encoding src with lazy.Marshal and
// decoding the result into dest with lazy.Unmarshal. Callers in this file pass
// a freshly allocated pointer as dest. The first marshal or unmarshal error
// is returned.
func cloneByMarshal(src, dest interface{}) error {
	encoded, err := lazy.Marshal(src)
	if err != nil {
		return err
	}
	return lazy.Unmarshal(encoded, dest)
}

// copyChannel returns an independent copy of src made with s.clone, or nil and
// the clone error.
func (s *store) copyChannel(src *model.Channel) (*model.Channel, error) {
	var dup model.Channel
	if err := s.clone(src, &dup); err != nil {
		return nil, err
	}
	return &dup, nil
}

// block records channelID in the block set with its deletion-in-progress flag
// set, overwriting any existing entry. Every caller in this file holds
// s.mutex while calling it. It always returns nil.
func (s *store) block(channelID string) error {
	const deletionInProgress = true
	s.blocks[channelID] = deletionInProgress
	return nil
}

// copyCommon returns an independent copy of src made with s.clone, or nil and
// the clone error.
func (s *store) copyCommon(src *model.Common) (*model.Common, error) {
	var dup model.Common
	if err := s.clone(src, &dup); err != nil {
		return nil, err
	}
	return &dup, nil
}

// channelKey builds the map key for a channel record.
//
// Each component is Go-quoted (%q), which escapes embedded quotes and
// backslashes. This makes the encoding unambiguous even when an ID contains
// the ":" separator. With a plain "%s:%s:%s" join, ("a:b", "c", "d") and
// ("a", "b:c", "d") would collide.
func channelKey(env, extID, chID string) string {
	return fmt.Sprintf("%q:%q:%q", env, extID, chID)
}

// commonKey builds the map key for an extension's common record.
//
// Each component is Go-quoted (%q), which escapes embedded quotes and
// backslashes. This makes the encoding unambiguous even when an ID contains
// the ":" separator. With a plain "%s:%s" join, ("a:b", "c") and ("a", "b:c")
// would collide.
func commonKey(env, extID string) string {
	return fmt.Sprintf("%q:%q", env, extID)
}
