package data

import (
	"fmt"
	"strings"
	"sync"
	"time"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/zap"
	"a.yandex-team.ru/mail/logconsumers/actdb_consumer/metrics"
	"a.yandex-team.ru/mail/logconsumers/actdb_consumer/storage"
)

// Nested map types used by Buffer to group activity as date -> UID -> modules.
type (
	// ModulesMap is an alias for a map keyed by Module. Buffer.Add only ever
	// stores true in it, so it is effectively used as a set of modules.
	ModulesMap = map[Module]bool

	// UIDModulesMap holds the modules recorded for each UID.
	UIDModulesMap map[UID]ModulesMap

	// DateUIDModulesMap holds the per-UID modules recorded for each Date.
	DateUIDModulesMap map[Date]UIDModulesMap
)

// Buffer accumulates module activity grouped by date, then UID, then module,
// until Flush writes it to storage. Add creates the nested maps on demand, so
// a zero-value Buffer can be used directly.
type Buffer struct {
	buffer DateUIDModulesMap // date -> UID -> modules; created lazily by Add
	mu     sync.Mutex        // locked by Add and Flush while they access buffer
}

// Add records that item.UID used item.Module on item.Date.
//
// Items with an empty date, a zero UID or an empty module are ignored and
// leave the buffer untouched. Missing levels of the date -> UID -> module
// nesting are created on demand, so Add works on a zero-value Buffer. The
// mutex is held for the whole call.
func (b *Buffer) Add(item *ModuleActivity) {
	b.mu.Lock()
	defer b.mu.Unlock()

	if item.Date == "" || item.UID == 0 || item.Module == "" {
		return
	}

	if b.buffer == nil {
		b.buffer = make(DateUIDModulesMap)
	}

	// A nil entry is treated the same as a missing one at every level.
	byUID := b.buffer[item.Date]
	if byUID == nil {
		byUID = make(UIDModulesMap)
		b.buffer[item.Date] = byUID
	}

	seen := byUID[item.UID]
	if seen == nil {
		seen = make(ModulesMap)
		byUID[item.UID] = seen
	}

	seen[item.Module] = true
}

// Length returns the number of (date, UID) pairs currently buffered, i.e. the
// sum of the per-date UID counts. Modules are not counted individually.
//
// Length does not lock the mutex itself: Flush calls it while already holding
// the lock, and sync.Mutex is not reentrant.
func (b *Buffer) Length() int {
	var total int
	for _, byUID := range b.buffer {
		total += len(byUID)
	}
	return total
}

// sqlTpl is the fmt template for one upsert statement. Its three verbs are, in
// order: the comma-separated VALUES rows, the date assigned to last_dt on
// conflict, and the same date again for the "only update when different" guard.
const sqlTpl string = "INSERT INTO history.user_activity VALUES %s " +
	"ON CONFLICT (uid, module) DO UPDATE SET last_dt = '%s' " +
	"WHERE history.user_activity.last_dt != '%s'"

// Flush writes the buffered activity to storage, one statement per date.
//
// Each date is handled independently: when its statement succeeds the date is
// removed from the buffer, and when it fails the date stays buffered so that a
// later Flush can retry it. (Previously the outcome of the last date alone
// decided whether the whole buffer was dropped or kept, which could lose
// unsent dates or re-send ones already written.)
//
// The item count that is logged and reported to yasm ("buffer_flushes" on
// success, "buffer_errors" on failure) is the number of UIDs of the date being
// flushed, so the counts over all dates add up to Length().
//
// The mutex is held for the whole call, including the 250ms pause after a
// failed statement, so concurrent Add calls block until Flush returns.
//
// Flush always returns true, also when some statements failed; the return
// value is kept as is for existing callers.
func (b *Buffer) Flush(storage storage.IStorage, logger *zap.Logger, yasm *metrics.Yasm) bool {
	b.mu.Lock()
	defer b.mu.Unlock()

	if b.Length() == 0 {
		return true
	}

	// Deleting entries from a map while ranging over it is permitted in Go.
	for date, byUID := range b.buffer {
		count := len(byUID)
		if count == 0 {
			// Nothing to send; an INSERT without rows would be invalid SQL.
			delete(b.buffer, date)
			continue
		}

		ok, err := storage.Run(b.makeSQL(date), yasm)
		if !ok || err != nil {
			logger.Warn(fmt.Sprintf("Flushing %d items failed", count), log.Error(err))
			yasm.Update("buffer_errors", float64(count))
			time.Sleep(250 * time.Millisecond)
			continue
		}

		logger.Infof("Flushed %d items", count)
		yasm.Update("buffer_flushes", float64(count))
		delete(b.buffer, date)
	}

	if len(b.buffer) == 0 {
		// Everything was written: start over with a fresh map, as before.
		b.reset()
	}

	return true
}

// makeSQL renders the upsert statement for everything buffered under date:
// one "(uid, 'module', 'date')" row per buffered (UID, module) pair,
// substituted into sqlTpl together with the date.
//
// Module and date values are embedded as quoted SQL string literals, so any
// single quote inside them is doubled; values without quotes render exactly as
// before. Rows follow map iteration order, which is not deterministic. The
// caller must hold b.mu; makeSQL does not lock.
//
// NOTE(review): quote doubling assumes the server reads '...' literals with
// standard-conforming string rules (backslash not special) - confirm for the
// target database.
func (b *Buffer) makeSQL(date Date) string {
	// quote formats v the way %s would and escapes it for use inside '...'.
	quote := func(v interface{}) string {
		return strings.ReplaceAll(fmt.Sprint(v), "'", "''")
	}

	byUID := b.buffer[date]
	dateLit := quote(date)

	// At least one row per UID; append grows the slice if a UID has more modules.
	rows := make([]string, 0, len(byUID))
	for uid, modules := range byUID {
		for module := range modules {
			rows = append(rows, fmt.Sprintf("(%d, '%s', '%s')", uid, quote(module), dateLit))
		}
	}

	return fmt.Sprintf(sqlTpl, strings.Join(rows, ", "), dateLit, dateLit)
}

// reset discards everything buffered by replacing the map with a new, empty,
// non-nil one. It does not lock; Flush calls it with b.mu already held.
func (b *Buffer) reset() {
	b.buffer = DateUIDModulesMap{}
}
