package main

import (
	"log"
	"sort"

	"code.justin.tv/rhys/nursery/pgparse"
)

// main parses one of the sample queries below with pgparse, walks the
// resulting statement with a statementTracker, and logs the statement's
// concrete type, the tracker's full state, and the sorted table names it
// found. A parse failure aborts the program.
func main() {
	log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile)

	// Sample queries to experiment with. Each literal begins with a newline
	// (so the SQL can start at column 0 in the source) that [1:] strips off.
	samples := [...]string{
		`
-- TMI db.py
SELECT 1
FROM ticket_products
WHERE id in (1)
AND availability_state = 'active'
AND ((ticket_type = 'chansub' AND ticket_product_owner_id = 2)
    OR
    (ticket_type = 'teamsub' AND ticket_product_owner_id IN
        (SELECT team_users.team_id
        FROM team_users
        WHERE team_users.user_id = 3)))
LIMIT 1
`[1:],
		`
SELECT pubs.login, mods.login
FROM publisher_ops, users AS pubs, users AS mods
WHERE publisher_ops.publisher_id = pubs.id
  AND publisher_ops.op_id = mods.id
  AND publisher_ops.op_id IS NOT NULL
  AND pubs.login IN (1,2,3)
`[1:],
		`
SELECT archives.* FROM archives
WHERE archives.broadcast_id IS NULL
AND (archives.kind is null or archives.kind != 'highlight')
ORDER BY broadcast_part
/*application:Twitch,controller:channels,action:videos*/
`[1:],
		`
SELECT "archives".* FROM "archives"
WHERE "archives"."broadcast_id" IS NULL
AND (archives.kind is null or archives.kind != 'highlight')
ORDER BY broadcast_part
/*application:Twitch,controller:channels,action:videos*/
`[1:],
	}

	// which selects the sample to parse; change it to try another query.
	const which = 2
	query := samples[which]

	stmt, err := pgparse.Parse(query)
	if err != nil {
		log.Fatalf("pgparse: %v", err)
	}
	log.Printf("stmt: %T", stmt)

	tracker := newStatementTracker()
	tracker.dumpStmt(stmt)
	log.Printf("tr: %#v", tracker)
	log.Printf("tables: %q", tracker.Tables())
}

// statementTracker accumulates what a parsed statement references: the
// tables named in its FROM clause, the columns that appear in the
// expressions it walks, and one nested tracker per subquery encountered.
// Create it with newStatementTracker so the maps are ready for writes.
type statementTracker struct {
	// tables maps the name a table is referred to by in the query (its
	// alias, or the table name itself when no alias is given) to the
	// underlying table name.
	tables        map[string]string
	// columns maps a column qualifier (the part before the dot; presumably
	// "" for unqualified columns, TODO confirm) to the set of column names
	// seen with that qualifier.
	columns       map[string]map[string]struct{}
	// substatements holds one tracker per subquery found while walking
	// expressions; each is filled independently of its parent.
	substatements []*statementTracker
}

// newStatementTracker returns an empty tracker with allocated (non-nil)
// tables and columns maps, so callers can record entries immediately.
// It starts with no substatements.
func newStatementTracker() *statementTracker {
	tr := new(statementTracker)
	tr.tables = map[string]string{}
	tr.columns = map[string]map[string]struct{}{}
	return tr
}

// Tables returns the distinct underlying table names referenced by this
// statement and, recursively, by all of its subqueries, sorted lexically.
// Aliases are resolved to the table they stand for, so a table that appears
// under several aliases is listed once. The result is never nil.
func (tr *statementTracker) Tables() []string {
	seen := make(map[string]struct{})
	tr.collectTables(seen)

	names := make([]string, 0, len(seen))
	for name := range seen {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}

// collectTables adds every underlying table name recorded by tr and by all
// of its substatements (depth-first) to seen.
func (tr *statementTracker) collectTables(seen map[string]struct{}) {
	for _, name := range tr.tables {
		seen[name] = struct{}{}
	}
	for _, sub := range tr.substatements {
		sub.collectTables(seen)
	}
}

// dumpStmt logs the parts of stmt it understands and records what the
// statement references in tr. Only *pgparse.Select is walked; any other
// statement type is logged as unhandled and otherwise ignored.
//
// For a SELECT it logs the comments, DISTINCT marker, WHERE/HAVING types and
// GROUP BY entry types. It records FROM tables in tr.tables and column
// references from the select list, WHERE and GROUP BY in tr.columns.
// Subqueries in the select list or WHERE become entries in tr.substatements.
// Unsupported FROM entries and unrecognized WHERE expressions abort the
// program via log.Fatalf.
func (tr *statementTracker) dumpStmt(stmt pgparse.Statement) {
	sel, ok := stmt.(*pgparse.Select)
	if !ok {
		// Don't drop other statement kinds silently: when this is a
		// subquery, its tables would otherwise vanish from Tables() unnoticed.
		log.Printf("unhandled statement type: %T", stmt)
		return
	}

	for _, comment := range sel.Comments {
		log.Printf("comment: %q", comment)
	}

	log.Printf("distinct: %q", sel.Distinct)

	tr.dumpSelectExprs(sel)
	tr.dumpFrom(sel)

	if sel.Where != nil {
		log.Printf("where: %q", sel.Where.Type)
		tr.dumpExpr(sel.Where.Expr)
	}

	for _, gb := range sel.GroupBy {
		log.Printf("group by: %T", gb)
		// Record plain column references only; other grouping expressions
		// are just logged, because dumpExpr aborts on kinds it doesn't know.
		if col, ok := gb.(*pgparse.ColName); ok {
			tr.dumpExpr(col)
		}
	}

	if sel.Having != nil {
		log.Printf("having: %q", sel.Having.Type)
	}
}

// dumpSelectExprs logs the non-star entries of sel's select list. Numeric
// literals are decoded and logged. Column references and subqueries are
// handed to dumpExpr so they are recorded in tr.
func (tr *statementTracker) dumpSelectExprs(sel *pgparse.Select) {
	for _, expr := range sel.SelectExprs {
		nonStar, ok := expr.(*pgparse.NonStarExpr)
		if !ok {
			continue
		}
		log.Printf("expr.As: %q", nonStar.As)
		log.Printf("expr.Expr: %T", nonStar.Expr)
		switch val := nonStar.Expr.(type) {
		case pgparse.NumVal:
			decoded, err := pgparse.AsInterface(val)
			if err != nil {
				log.Fatalf("AsInterface: %v", err)
			}
			log.Printf("val: %T", decoded)
			// %v, not %s: the decoded value is not necessarily a string.
			log.Printf("val: %v", decoded)
		case *pgparse.ColName:
			tr.dumpExpr(val)
		case *pgparse.Subquery:
			tr.dumpExpr(val)
		}
	}
}

// dumpFrom records every table in sel's FROM clause in tr.tables, keyed by
// its alias, or by the table name itself when no alias is given. Any FROM
// entry other than an aliased plain table name aborts via log.Fatalf.
func (tr *statementTracker) dumpFrom(sel *pgparse.Select) {
	for _, from := range sel.From {
		switch from := from.(type) {
		case *pgparse.AliasedTableExpr:
			switch expr := from.Expr.(type) {
			case *pgparse.TableName:
				table := string(expr.Name)
				as := string(from.As)
				if as == "" {
					as = table
				}
				tr.tables[as] = table
			default:
				log.Fatalf("unknown type: %T", expr)
			}
		default:
			log.Fatalf("unknown type: %T", from)
		}
	}
}

// dumpExpr walks an expression tree, such as a WHERE clause. It records
// every column reference in tr.columns (keyed by qualifier) and every
// subquery as a new tracker appended to tr.substatements. String and numeric
// literals are ignored. Any other expression type aborts the program via
// log.Fatalf, so unsupported syntax is noticed rather than silently skipped.
func (tr *statementTracker) dumpExpr(expr pgparse.Expr) {
	switch expr := expr.(type) {
	case *pgparse.AndExpr:
		tr.dumpExpr(expr.Left)
		tr.dumpExpr(expr.Right)
	case *pgparse.OrExpr:
		tr.dumpExpr(expr.Left)
		tr.dumpExpr(expr.Right)
	case *pgparse.ComparisonExpr:
		tr.dumpExpr(expr.Left)
		tr.dumpExpr(expr.Right)
	case *pgparse.ParenBoolExpr:
		tr.dumpExpr(expr.Expr)
	case *pgparse.NullCheck:
		// Walk the operand of IS [NOT] NULL so the column being tested is
		// recorded like any other reference. It was previously skipped.
		tr.dumpExpr(expr.Expr)
	case *pgparse.Subquery:
		sub := newStatementTracker()
		sub.dumpStmt(expr.Select)
		tr.substatements = append(tr.substatements, sub)
	case *pgparse.ColName:
		column := string(expr.Name)
		table := string(expr.Qualifier)
		cols, ok := tr.columns[table]
		if !ok {
			cols = make(map[string]struct{})
			tr.columns[table] = cols
		}
		cols[column] = struct{}{}
	case pgparse.ValTuple:
		// Walk each element instead of only converting the tuple with
		// pgparse.AsInterface, so that column references and subqueries
		// inside IN (...) lists are recorded too.
		for _, elem := range expr {
			log.Printf("val: %T", elem)
			tr.dumpExpr(elem)
		}
	case pgparse.StrVal, pgparse.NumVal:
		// Literals reference no tables or columns.
	default:
		log.Fatalf("unknown type: %T", expr)
	}
}
