// backfill_talk_a_little
//
// This script completes the 'Talk a Little' achievement for any channels that
// have already completed the 'Talk to Me' achievement, level 1.
// 'Talk to Me' requires 25 unique chatters, which is higher than the requirement
// for every level of 'Talk a Little'. Because of this, any channel that has
// achieved 'Talk to Me' should also have achieved all levels of 'Talk a Little'.
//
// Usage:
// go run main.go -host host -port port -db dbname -user dbuser -password dbpassword

package main

import (
	"context"
	"database/sql"
	"flag"
	"fmt"
	"time"

	_ "github.com/lib/pq"
	"github.com/pkg/errors"
	"golang.org/x/sync/semaphore"
)

// Command-line flags describing the Postgres connection. Every flag defaults
// to the empty string; the values are combined into the connection URL used to
// open the database.
var (
	dbHost     = flag.String("host", "", "database server hostname or IP address")
	dbPort     = flag.String("port", "", "database server port")
	dbName     = flag.String("db", "", "name of the database to connect to")
	dbUser     = flag.String("user", "", "database user")
	dbPassword = flag.String("password", "", "password for the database user")
)

// workerCount is the size of the worker pool: the maximum number of
// progression inserts allowed to run concurrently.
const workerCount = 100

// init parses the command-line flags during package initialization, so the
// flag variables are populated before main runs.
//
// NOTE(review): calling flag.Parse from init also runs it inside `go test`
// binaries, before the testing package has registered its own flags
// (Go 1.13+), which can make tests fail on unknown -test.* flags. Consider
// moving the call to the top of main.
func init() {
	flag.Parse()
}

// progression is one completed 'Talk to Me' row read from the progressions
// table. It supplies the channel and timestamp used when inserting the
// matching 'Talk a Little' rows.
type progression struct {
	// ChannelID is scanned from the channel_id column.
	ChannelID   string
	// CompletedAt is scanned from the completed_at_utc column.
	CompletedAt time.Time
}

// achievement is one row of the achievements table, as selected for the
// 'talk_a_little' key.
type achievement struct {
	// AchievementID is scanned from the id column.
	AchievementID string
	// ProgressCap is scanned from the progress_cap column; it is written as
	// the progress value of inserted progressions.
	ProgressCap   int
}

// main runs the backfill end to end: it opens the database, loads the
// 'Talk a Little' achievement rows and every completed 'Talk to Me'
// progression, then inserts the missing 'Talk a Little' progressions.
// A failure in any setup step is printed and aborts the run.
func main() {
	db, err := newDB()
	if err != nil {
		fmt.Println("failed to open db connection:", err)
		return
	}
	// Close the pool on every exit path so server-side connections are shut
	// down cleanly instead of being dropped when the process exits.
	// completeTalkALittle waits for all of its workers, so nothing is still
	// using the pool when this runs.
	defer func() {
		if err := db.Close(); err != nil {
			fmt.Println("failed to close db connection:", err)
		}
	}()

	achievements, err := getTalkALittleAch(db)
	if err != nil {
		fmt.Println("failed to get achievements: ", err)
		return
	}

	progressions, err := getCompletedTalkToMe(db)
	if err != nil {
		fmt.Println("failed to get completed talk to me channel ids: ", err)
		return
	}

	completeTalkALittle(db, achievements, progressions)
}

// newDB opens a Postgres connection pool from the command-line flags and
// verifies it with a ping.
//
// It returns the ready-to-use pool, or the error from sql.Open or Ping. The
// pool is closed before returning when the ping fails, so a failed call leaves
// nothing open.
func newDB() (*sql.DB, error) {
	const urlFormat = "postgres://%s:%s@%s:%s/%s"
	connection := fmt.Sprintf(urlFormat, *dbUser, *dbPassword, *dbHost, *dbPort, *dbName)
	// Same URL with the password masked: credentials must never be written to
	// stdout, where they end up in terminal scrollback and job logs.
	redacted := fmt.Sprintf(urlFormat, *dbUser, "*****", *dbHost, *dbPort, *dbName)

	fmt.Println("opening database connection at:", redacted)

	session, err := sql.Open("postgres", connection)
	if err != nil {
		return nil, err
	}

	// sql.Open only validates its arguments; Ping is what actually connects.
	if err := session.Ping(); err != nil {
		// The ping failure is the error worth reporting; a secondary error
		// from closing the unusable pool would only obscure it.
		_ = session.Close()
		return nil, err
	}

	fmt.Println("opened database connection at:", redacted)

	return session, nil
}

// getTalkALittleAch loads the id and progress cap of every achievements row
// whose key is 'talk_a_little'.
//
// It returns a nil slice and a nil error when no rows match. A wrapped error
// is returned if the query, a row scan, or the row iteration fails; no partial
// result is returned in that case.
func getTalkALittleAch(db *sql.DB) ([]achievement, error) {
	statement := `
		SELECT id, progress_cap
		FROM achievements
		WHERE key = 'talk_a_little'
	`

	// QueryContext reports an empty result as zero rows rather than
	// sql.ErrNoRows, so no special case is needed for "nothing found".
	rows, err := db.QueryContext(context.Background(), statement)
	if err != nil {
		return nil, errors.Wrap(err, "failed to query for talk a little achievements")
	}

	defer func() {
		// Use a separate variable so a close failure cannot overwrite the
		// function's own error.
		if closeErr := rows.Close(); closeErr != nil {
			fmt.Println("db: failed to close pg rows")
		}
	}()

	var achievements []achievement
	for rows.Next() {
		ach := achievement{}
		if err := rows.Scan(&ach.AchievementID, &ach.ProgressCap); err != nil {
			return nil, errors.Wrap(err, "failed to read achievements row")
		}
		achievements = append(achievements, ach)
	}

	// rows.Next returns false both at the end of the result set and on error;
	// rows.Err tells the two apart so a broken read is not taken for success.
	if err := rows.Err(); err != nil {
		return nil, errors.Wrap(err, "failed to iterate achievements rows")
	}

	return achievements, nil
}

// getCompletedTalkToMe loads the channel id and completion time of every
// progression that is completed (completed_at_utc IS NOT NULL) for the
// achievement with key 'n_unique_chatter_broadcast' and progress cap 25.
//
// It prints a start message and, on success, the row count and elapsed time.
// It returns a nil slice and a nil error when no rows match. A wrapped error
// is returned if the query, a row scan, or the row iteration fails; no partial
// result is returned in that case.
func getCompletedTalkToMe(db *sql.DB) ([]progression, error) {
	statement := `
        SELECT channel_id, completed_at_utc
        FROM progressions AS p
        JOIN achievements AS a
        ON   p.achievement_id = a.id
        WHERE a.key = 'n_unique_chatter_broadcast'
        AND a.progress_cap = 25
        AND p.completed_at_utc IS NOT NULL
    `
	fmt.Println("Selecting all channel id's with completed Talk to Me achievement...")
	start := time.Now()

	// QueryContext reports an empty result as zero rows rather than
	// sql.ErrNoRows, so no special case is needed for "nothing found".
	rows, err := db.QueryContext(context.Background(), statement)
	if err != nil {
		return nil, errors.Wrap(err, "failed to query for talk to me progressions")
	}

	defer func() {
		// Use a separate variable so a close failure cannot overwrite the
		// function's own error.
		if closeErr := rows.Close(); closeErr != nil {
			fmt.Println("db: failed to close pg rows")
		}
	}()

	var progressions []progression
	for rows.Next() {
		progress := progression{}
		if err := rows.Scan(&progress.ChannelID, &progress.CompletedAt); err != nil {
			return nil, errors.Wrap(err, "failed to read progressions row")
		}
		progressions = append(progressions, progress)
	}

	// rows.Next returns false both at the end of the result set and on error;
	// rows.Err tells the two apart so a truncated read is not backfilled as
	// if it were the complete list.
	if err := rows.Err(); err != nil {
		return nil, errors.Wrap(err, "failed to iterate progressions rows")
	}

	elapsed := time.Since(start)
	fmt.Printf("Finished selecting %d rows in %s\n", len(progressions), elapsed)
	return progressions, nil
}

// completeTalkALittle inserts a completed progression for every combination of
// the given progressions (channels) and achievements, running at most
// workerCount inserts at a time.
//
// It blocks until every started insert has finished, then prints the channel
// id of each failed insert followed by the number of progressions processed.
// A channel id appears once per failed insert, so it may be listed more than
// once.
func completeTalkALittle(db *sql.DB, achievements []achievement, progressions []progression) {
	workerPool := semaphore.NewWeighted(int64(workerCount))
	// One insert runs per (progression, achievement) pair and each may report
	// a failure, so the buffer must hold one entry per pair. With a smaller
	// buffer, failing workers would block on the send while still holding
	// their semaphore slot, and the wait below would never complete.
	failures := make(chan string, len(progressions)*len(achievements))

	for _, progress := range progressions {
		fmt.Println("Completing Talk a Little for channel: ", progress.ChannelID)

		for _, ach := range achievements {
			// Blocks until a worker slot is free. With context.Background()
			// Acquire waits rather than fails, so the error path is purely
			// defensive.
			if err := workerPool.Acquire(context.Background(), 1); err != nil {
				fmt.Println("failed to acquire worker from progress worker pool: ", err)
				continue
			}

			go insertProgression(db, progress, ach.ProgressCap, ach.AchievementID, workerPool, failures)
		}
	}

	// Taking every slot at once succeeds only after all workers have released
	// theirs, i.e. after every insert has finished.
	if err := workerPool.Acquire(context.Background(), int64(workerCount)); err != nil {
		fmt.Println("failed to acquire workers: ", err)
	}

	// All senders are done, so the channel can be closed and drained.
	close(failures)

	fmt.Println("=====================================================")
	for channelID := range failures {
		fmt.Println("failed progression insertion for: ", channelID)
	}

	fmt.Println("=====================================================")
	fmt.Println("Total completed Talk to Me: ", len(progressions))
}

// insertProgression inserts a completed progression row for the given channel
// and achievement, unless a row for that (channel_id, achievement_id) pair
// already exists.
//
// The row's progress is set to progressCap, and both created_at_utc and
// completed_at_utc are set to progress.CompletedAt. If the insert fails, the
// error is printed and the channel id is sent on failChan. One slot is
// released on workerPool when the function returns; the caller must have
// acquired it beforehand.
func insertProgression(db *sql.DB, progress progression, progressCap int, achievementID string, workerPool *semaphore.Weighted, failChan chan<- string) {
	// Deferred so the slot is handed back on every exit path.
	defer workerPool.Release(1)

	statement := `
        INSERT INTO progressions (
            channel_id,
            achievement_id,
            progress,
            created_at_utc,
            completed_at_utc
        )
        SELECT $1::VARCHAR, $2::UUID, $3, $4, $5
        WHERE NOT EXISTS (
        	SELECT channel_id, achievement_id
        	FROM progressions
        	WHERE channel_id = $1
        	AND achievement_id = $2
        )
    `

	_, err := db.ExecContext(context.Background(),
		statement,
		progress.ChannelID,
		achievementID,
		progressCap,
		progress.CompletedAt,
		progress.CompletedAt)

	if err != nil {
		// Print the cause here: the failure channel carries only the channel
		// id, so the reason and the affected achievement would otherwise be
		// lost.
		fmt.Printf("failed to insert progression (channel %s, achievement %s): %v\n",
			progress.ChannelID, achievementID, err)
		failChan <- progress.ChannelID
	}
}
