// Package bulk parses YAML bulk-operation definitions and executes them against a database.
package bulk

import (
	"code.justin.tv/d8a/iceman/lib/queries"
	"database/sql"
	"fmt"
	yaml "gopkg.in/yaml.v2"
	"strings"
	"time"
)

// Bulk describes one bulk operation as decoded from its YAML definition.
// ApplyBulk walks the id range of Relation in chunks, selecting each chunk
// according to Read and applying Write to it.
type Bulk struct {
	// Relation is the table whose id column is walked.
	Relation string `yaml:"relation"`

	// Read optionally customizes how each chunk of rows is selected.
	Read *Read `yaml:"read"`

	// Write holds the statement applied to the rows of each chunk.
	Write *Write `yaml:"write"`

	// Batch controls the chunk width and whether a chunk is written with a
	// single statement.
	Batch *Batch `yaml:"batch"`

	// Sleep is the pause between chunks, in milliseconds.
	Sleep int `yaml:"sleep"`

	// Verbose enables query and progress output on stdout.
	Verbose bool `yaml:"verbose"`
}

// Read configures how each chunk of ids is selected from the relation.
type Read struct {
	// Query is an optional custom SELECT. When empty, the driver's default
	// id-range select is used (see buildReadQuery).
	Query string `yaml:"query"`

	// StartId is the first id to process when it is larger than both the
	// relation's minimum id and any progress recorded by an earlier run.
	StartId int `yaml:"start_id"`
}

// Write configures the statement applied to each chunk.
type Write struct {
	// Query is the write statement. In batch mode its first "?" is expanded
	// into the driver's placeholder list for the chunk's ids; otherwise it is
	// run once per row with that row's column values as arguments.
	Query string `yaml:"query"`

	// MaxWrites stops processing once this many rows have been affected;
	// 0 or less means no limit. The limit is checked between chunks, so it
	// can be overshot by up to one chunk.
	MaxWrites int `yaml:"max_writes"`
}

// Batch configures how the id range is chunked and written.
type Batch struct {
	// Size is the width of each id chunk (a chunk spans start..start+Size).
	Size int `yaml:"size"`

	// Batch selects writing a whole chunk with one statement (true) instead
	// of one statement per row (false).
	Batch bool `yaml:"batch"`
}

// parse decodes the YAML document in data into b using yaml.Unmarshal and
// returns whatever error the decoder reports.
func (b *Bulk) parse(data []byte) error {
	err := yaml.Unmarshal(data, b)
	return err
}

// extractBulk decodes the YAML stored in br.Content into a Bulk.
//
// On error the returned *Bulk is nil: a partially decoded definition must not
// be mistaken for a usable one. A nil br is reported as an error instead of
// panicking.
func extractBulk(br *BulkRecord) (*Bulk, error) {
	if br == nil {
		return nil, fmt.Errorf("bulk: nil bulk record")
	}

	bulk := new(Bulk)
	if err := bulk.parse(br.Content); err != nil {
		return nil, err
	}
	return bulk, nil
}

// ApplyBulk executes the bulk operation described by br on the given
// database, returning an error if it occurs.
//
// The relation's id range is walked in chunks of batch.size ids, starting at
// the largest of the relation's minimum id, read.start_id and br.NextRow.
// Each chunk is selected with the read query and written either with a
// single statement (batch.batch) or row by row. Processing stops when the id
// range is exhausted, when write.max_writes rows have been affected (0 or
// less means unlimited), or on the first error.
//
// In every case br.NextRow and br.Complete are updated and persisted with
// updateBulkTable before returning, so a failed run resumes at the first
// unprocessed chunk. A processing error is then returned to the caller.
//
// A missing read section falls back to defaults. A missing write query or
// a non-positive batch size is rejected up front, before the database is
// touched.
func ApplyBulk(db *sql.DB, driverQueries queries.DriverQueries, br *BulkRecord) error {
	bulk, err := extractBulk(br)
	if err != nil {
		return err
	}

	relation := bulk.Relation
	read := bulk.Read
	if read == nil {
		// No read section: default read query, no explicit start id.
		read = &Read{}
	}
	write := bulk.Write
	if write == nil || write.Query == "" {
		return fmt.Errorf("bulk %v: write.query is required", br.Filename)
	}
	batch := bulk.Batch
	if batch == nil || batch.Size <= 0 {
		// A non-positive size would never advance the id cursor, looping forever.
		return fmt.Errorf("bulk %v: batch.size must be a positive integer", br.Filename)
	}
	sleep := time.Duration(bulk.Sleep) * time.Millisecond
	verbose := bulk.Verbose

	minID, maxID, err := findRange(db, relation)
	if err != nil {
		return err
	}

	// Resume from the furthest of: table minimum, configured start id, and
	// progress recorded by a previous run.
	start := minID
	if read.StartId > start {
		start = read.StartId
	}
	if br.NextRow > start {
		start = br.NextRow
	}

	// Non-positive limits are normalized to 0, which isUnderWriteLimit
	// treats as "no maximum".
	maxWrites := 0
	if write.MaxWrites > 0 {
		maxWrites = write.MaxWrites
	}

	readQuery := buildReadQuery(read.Query, relation, driverQueries)

	if verbose {
		fmt.Println(readQuery)
		fmt.Println(write.Query)
		fmt.Printf("Processing as many as %v rows (until id=%v) in %v.\n", maxID-start+1, maxID, relation)
	}

	var rowsProcessed int
	var loopErr error
	for start <= maxID && isUnderWriteLimit(rowsProcessed, maxWrites) {
		end := start + batch.Size
		rows, err := db.Query(readQuery, start, end)
		if err != nil {
			loopErr = err
			break
		}

		var n int
		if batch.Batch {
			n, err = processBatchRows(db, rows, driverQueries, write.Query, verbose)
		} else {
			n, err = processEveryRow(db, rows, write.Query, verbose)
		}
		if err != nil {
			// start is left at this chunk so the next run retries it.
			loopErr = err
			break
		}
		rowsProcessed += n

		if verbose {
			fmt.Printf("next id: %v\n", end)
		}
		time.Sleep(sleep)
		start = end
	}

	if start > maxID {
		br.Complete = true
		br.NextRow = maxID + 1
	} else {
		br.Complete = false
		br.NextRow = start
	}

	// Persist progress even when the loop failed, then report the failure.
	if err := updateBulkTable(db, driverQueries, br); err != nil {
		if loopErr != nil {
			return fmt.Errorf("%v (additionally, recording progress failed: %v)", loopErr, err)
		}
		return err
	}
	return loopErr
}

// findRange returns the smallest and largest id in relation.
//
// relation is interpolated into the SQL text because identifiers cannot be
// bound as parameters, so it must come from trusted configuration.
// An error is returned if the query fails or if the relation has no rows
// (min/max are NULL).
func findRange(db *sql.DB, relation string) (int, int, error) {
	var lo, hi sql.NullInt64

	query := fmt.Sprintf("SELECT min(id), max(id) FROM %v;", relation)
	// QueryRow+Scan surfaces both query and row-fetch errors. A Query/Next
	// loop without a rows.Err() check silently returned a bogus 0..0 range.
	if err := db.QueryRow(query).Scan(&lo, &hi); err != nil {
		return 0, 0, err
	}
	if !lo.Valid || !hi.Valid {
		return 0, 0, fmt.Errorf("relation %v has no rows", relation)
	}
	return int(lo.Int64), int(hi.Int64), nil
}

// buildReadQuery returns the SQL used to select one chunk of rows.
//
// With no custom read query, it is the driver's SelectIdBetween template
// filled with relation. Otherwise the custom query is trimmed of surrounding
// whitespace and one trailing ";", then extended with the driver's
// AndIdBetween clause (filled with relation twice).
// NOTE(review): ApplyBulk binds the chunk's start and end ids as the two
// arguments; confirm both templates expect exactly those.
func buildReadQuery(read string, relation string, driverQueries queries.DriverQueries) string {
	if read != "" {
		clause := fmt.Sprintf(" "+driverQueries.AndIdBetween(), relation, relation)
		base := strings.TrimSuffix(strings.TrimSpace(read), ";")
		return base + clause
	}
	return fmt.Sprintf(driverQueries.SelectIdBetween(), relation)
}

// isUnderWriteLimit reports whether another chunk may be processed.
// A max_writes of 0 means there is no limit. Otherwise processing continues
// only while rows_processed is strictly below max_writes.
func isUnderWriteLimit(rows_processed int, max_writes int) bool {
	unlimited := max_writes == 0
	if unlimited {
		return true
	}
	return rows_processed < max_writes
}

// extractIdsFromRows is a helper for processBatchRows. It scans the single id
// column of every row in rows and returns the ids boxed as interface{}, ready
// to be passed to db.Exec as arguments. rows is handed to queries.TryClose on
// return.
//
// On a scan error, or an error reported by the row iteration itself, the
// returned slice is nil.
func extractIdsFromRows(rows *sql.Rows) ([]interface{}, error) {
	defer queries.TryClose(rows)

	var ids []interface{}
	for rows.Next() {
		var id int
		if err := rows.Scan(&id); err != nil {
			return nil, err
		}
		ids = append(ids, id)
	}
	// rows.Next returns false on a fetch error as well as at the end. Without
	// this check a failed fetch looks like a short, successful chunk, and the
	// caller would advance past ids it never processed.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return ids, nil
}

// buildBatchQuery is a helper for processBatchRows. It replaces the first "?"
// in query with the driver's placeholder list for len(ids) values. Any later
// "?" is left as-is, and a query containing no "?" is returned unchanged.
// NOTE(review): a "?" inside a string literal that precedes the intended
// marker would be replaced instead.
func buildBatchQuery(driverQueries queries.DriverQueries, ids []interface{}, query string) string {
	placeholders := driverQueries.CreatePlaceholders(len(ids))
	idx := strings.Index(query, "?")
	if idx < 0 {
		return query
	}
	return query[:idx] + placeholders + query[idx+1:]
}

// processBatchRows writes one chunk with a single statement (used when
// batch.batch is true). It collects the ids from rows, expands query's first
// "?" via buildBatchQuery, and executes it with the ids (as formatted by
// driverQueries.FormatArray) as arguments. It returns the rows-affected count
// reported by the driver.
//
// A chunk that matched no rows is skipped and reports (0, nil), so the
// caller simply moves on to the next chunk. Errors are returned, not
// printed; reporting them is left to the caller.
func processBatchRows(db *sql.DB, rows *sql.Rows, driverQueries queries.DriverQueries, query string, verbose bool) (int, error) {
	ids, err := extractIdsFromRows(rows)
	if err != nil {
		return 0, err
	}
	if len(ids) == 0 {
		// Gaps in the id range or a filtering read query can legitimately
		// leave a chunk empty; that is not an error.
		if verbose {
			fmt.Println("Batch contains 0 Rows ... Skipping.")
		}
		return 0, nil
	}

	query = buildBatchQuery(driverQueries, ids, query)
	if verbose {
		fmt.Printf("Here is the query: %v\n", query)
		fmt.Println("IDS: ", ids)
	}

	result, err := db.Exec(query, driverQueries.FormatArray(ids)...)
	if err != nil {
		return 0, err
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return 0, err
	}
	if verbose {
		fmt.Printf("Rows processed in batch: %v\n", affected)
	}
	return int(affected), nil
}

// extractParamsFromRows is a helper for processEveryRow. It reads every row
// of rows into a slice of raw column values, one slice per row, suitable for
// use as db.Exec arguments. rows is handed to queries.TryClose on return.
//
// On a Columns/Scan error, or an error reported by the row iteration itself,
// the returned slice is nil.
func extractParamsFromRows(rows *sql.Rows) ([][]interface{}, error) {
	defer queries.TryClose(rows)

	// The column set is fixed for a result set, so look it up once.
	cols, err := rows.Columns()
	if err != nil {
		return nil, err
	}

	var parameters [][]interface{}
	for rows.Next() {
		params := make([]interface{}, len(cols))
		pointers := make([]interface{}, len(cols))
		for i := range pointers {
			pointers[i] = &params[i]
		}
		if err := rows.Scan(pointers...); err != nil {
			return nil, err
		}
		parameters = append(parameters, params)
	}
	// rows.Next returns false on a fetch error as well as at the end. Without
	// this check a failed fetch would silently drop the remaining rows.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return parameters, nil
}

// processEveryRow writes one chunk row by row (used when batch.batch is
// false). It runs query once per row of rows, passing that row's column
// values as the arguments, and returns the summed rows-affected count.
// The first failing Exec or RowsAffected stops processing and yields
// (0, err).
func processEveryRow(db *sql.DB, rows *sql.Rows, query string, verbose bool) (int, error) {
	parameters, err := extractParamsFromRows(rows)
	if err != nil {
		return 0, err
	}

	total := 0
	for i := 0; i < len(parameters); i++ {
		args := parameters[i]
		if verbose {
			fmt.Print("HERE ARE THE PARAMS: ")
			fmt.Println(args...)
		}
		res, execErr := db.Exec(query, args...)
		if execErr != nil {
			return 0, execErr
		}
		n, raErr := res.RowsAffected()
		if raErr != nil {
			return 0, raErr
		}
		total += int(n)
	}
	return total, nil
}

// updateBulkTable records br's progress in the bulk bookkeeping table.
//
// It first runs the driver's UpdateBulk statement with (timestamp,
// br.Complete, br.NextRow, br.Filename). If that reports no affected rows, it
// runs InsertBulk with (br.Filename, br.Name, timestamp, br.Complete,
// br.NextRow) to create the record. Both statements use the same current
// UTC timestamp.
func updateBulkTable(db *sql.DB, driverQueries queries.DriverQueries, br *BulkRecord) error {
	// time.UTC is always available; LoadLocation("UTC") could never fail, so
	// its error branch was dead code.
	now := time.Now().UTC()

	result, err := db.Exec(driverQueries.UpdateBulk(), now, br.Complete, br.NextRow, br.Filename)
	if err != nil {
		return err
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if affected > 0 {
		return nil
	}

	// NOTE(review): some drivers (e.g. MySQL without CLIENT_FOUND_ROWS) report
	// 0 affected rows when an UPDATE matches a row but changes nothing. That
	// would fall through to this INSERT for an existing record; confirm the
	// driver's semantics.
	_, err = db.Exec(driverQueries.InsertBulk(), br.Filename, br.Name, now, br.Complete, br.NextRow)
	return err
}
