package main

import (
	"fmt"
	"sort"

	"code.justin.tv/rhys/nursery/pgparse"
)

// statementTracker accumulates what a parsed SQL statement references while
// the dump* methods walk its AST. Each subquery encountered gets its own child
// tracker, stored in substatements.
type statementTracker struct {
	// err is set when the walk hits something it cannot handle; once it is
	// non-nil, every dump* method returns immediately.
	err error

	// tables maps the name used to refer to a table in the query (its alias,
	// or the table name itself when no alias is given) to the table name.
	tables map[string]string
	// columns maps a column qualifier, as written in the query, to the set of
	// column names seen with that qualifier. NOTE(review): presumably the key
	// is "" for unqualified columns — confirm against pgparse.ColName.
	columns map[string]map[string]struct{}
	// substatements holds one child tracker per subquery found while walking
	// expressions (including subqueries used as FROM entries).
	substatements []*statementTracker
}

// newStatementTracker returns an empty tracker with its lookup maps
// allocated, ready for a dump* walk. substatements starts out nil and grows
// as subqueries are found.
func newStatementTracker() *statementTracker {
	st := new(statementTracker)
	st.tables = map[string]string{}
	st.columns = map[string]map[string]struct{}{}
	return st
}

// Err returns the error recorded while walking statements, or nil if none
// occurred. After an error has been recorded, the dump* methods return
// immediately without doing further work.
func (tr *statementTracker) Err() (err error) {
	err = tr.err
	return err
}

// Tables returns the sorted, de-duplicated names of every table referenced by
// this tracker and, recursively, by all of its subquery trackers. Aliases are
// resolved: the returned strings are the table names, not the aliases.
func (tr *statementTracker) Tables() []string {
	seen := make(map[string]struct{})

	// Gather into one shared set rather than sorting at every nesting level.
	var collect func(node *statementTracker)
	collect = func(node *statementTracker) {
		for _, name := range node.tables {
			seen[name] = struct{}{}
		}
		for _, child := range node.substatements {
			collect(child)
		}
	}
	collect(tr)

	names := make([]string, 0, len(seen))
	for name := range seen {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}

// dumpStmt walks a parsed statement and records what its FROM and WHERE
// clauses reference. Only *pgparse.Select statements are examined; any other
// statement type is ignored without recording an error.
//
// For a SELECT:
//   - numeric literals appearing directly in the select list are passed
//     through pgparse.AsInterface, and a failure there is recorded as the
//     tracker's error;
//   - every FROM entry is walked with dumpTableExpr;
//   - the WHERE condition, if present, is walked with dumpExpr.
//
// Comments, GROUP BY and HAVING are not currently inspected.
func (tr *statementTracker) dumpStmt(stmt pgparse.Statement) {
	if tr.err != nil {
		return
	}

	sel, ok := stmt.(*pgparse.Select)
	if !ok {
		return
	}

	for _, selectExpr := range sel.SelectExprs {
		nonStar, ok := selectExpr.(*pgparse.NonStarExpr)
		if !ok {
			continue
		}
		num, ok := nonStar.Expr.(pgparse.NumVal)
		if !ok {
			continue
		}
		// Only the error matters here; the converted value is discarded.
		if _, err := pgparse.AsInterface(num); err != nil {
			tr.err = fmt.Errorf("AsInterface: %v", err)
			return
		}
	}

	for _, from := range sel.From {
		tr.dumpTableExpr(from)
	}

	if where := sel.Where; where != nil {
		tr.dumpExpr(where.Expr)
	}
}

// dumpTableExpr records the tables referenced by a single FROM-clause entry.
//
// A plain table name is stored in tr.tables, keyed by its alias, or by the
// table name itself when no alias is given. A derived table (subquery) is
// handed to dumpExpr, which walks it with a fresh child tracker. A join is
// walked recursively on both sides and then, when one is present, through its
// ON condition. Any other table-expression type records an error on the
// tracker.
func (tr *statementTracker) dumpTableExpr(table pgparse.TableExpr) {
	if tr.err != nil {
		return
	}

	switch table := table.(type) {
	case *pgparse.AliasedTableExpr:
		switch expr := table.Expr.(type) {
		case *pgparse.TableName:
			name := string(expr.Name)
			alias := string(table.As)
			if alias == "" {
				alias = name
			}
			tr.tables[alias] = name
		case *pgparse.Subquery:
			tr.dumpExpr(expr)
		default:
			tr.err = fmt.Errorf("dumpTableExpr1 unknown type: %T", expr)
		}
	case *pgparse.JoinTableExpr:
		tr.dumpTableExpr(table.LeftExpr)
		tr.dumpTableExpr(table.RightExpr)
		// A join written without an ON clause can leave On nil. NOTE(review):
		// which join forms do this depends on pgparse's grammar, presumably
		// e.g. CROSS JOIN. Passing nil to dumpExpr would hit its default case
		// and fail the whole statement with "dumpExpr unknown type: <nil>".
		if table.On != nil {
			tr.dumpExpr(table.On)
		}
	default:
		tr.err = fmt.Errorf("dumpTableExpr2 unknown type: %T", table)
	}
}

// dumpExpr walks an expression tree. Column references are recorded in
// tr.columns, keyed by their qualifier. Each subquery it encounters is walked
// with a new child tracker, which is appended to tr.substatements. Literal
// values are accepted and not inspected further. Any other expression type
// records an error on the tracker.
func (tr *statementTracker) dumpExpr(expr pgparse.Expr) {
	if tr.err != nil {
		return
	}

	switch expr := expr.(type) {
	case *pgparse.AndExpr:
		tr.dumpExpr(expr.Left)
		tr.dumpExpr(expr.Right)
	case *pgparse.OrExpr:
		tr.dumpExpr(expr.Left)
		tr.dumpExpr(expr.Right)
	case *pgparse.ComparisonExpr:
		tr.dumpExpr(expr.Left)
		tr.dumpExpr(expr.Right)
	case *pgparse.ParenBoolExpr:
		tr.dumpExpr(expr.Expr)
	case *pgparse.NotExpr:
		tr.dumpExpr(expr.Expr)
	case *pgparse.NullCheck:
		// NOTE(review): the NullCheck node is accepted but not descended into,
		// so any column it tests is not recorded.
	case *pgparse.FuncExpr:
		for _, arg := range expr.Exprs {
			tr.dumpSelectExpr(arg)
		}
	case *pgparse.RangeCond:
		tr.dumpExpr(expr.Left)
		tr.dumpExpr(expr.From)
		tr.dumpExpr(expr.To)
	case *pgparse.BinaryExpr:
		tr.dumpExpr(expr.Left)
		tr.dumpExpr(expr.Right)

	case *pgparse.Subquery:
		sub := newStatementTracker()
		sub.dumpStmt(expr.Select)
		tr.substatements = append(tr.substatements, sub)
		// The child records failures on itself. Surface the failure here so the
		// caller's Err() reflects errors anywhere in the statement instead of
		// reporting success for a partially walked tree.
		if sub.err != nil {
			tr.err = sub.err
			return
		}

	case *pgparse.ColName:
		qualifier := string(expr.Qualifier)
		cols, ok := tr.columns[qualifier]
		if !ok {
			cols = make(map[string]struct{})
			tr.columns[qualifier] = cols
		}
		cols[string(expr.Name)] = struct{}{}

	case pgparse.ValTuple, pgparse.StrVal, pgparse.NumVal, *pgparse.NullVal:
		// Value tuples and literals are not inspected further.

	default:
		tr.err = fmt.Errorf("dumpExpr unknown type: %T", expr)
	}
}

// dumpSelectExpr walks one select-list item (as found in function arguments).
// Only *pgparse.NonStarExpr is supported: its inner expression is handed to
// dumpExpr. Any other item, including a nil one, records an error on the
// tracker.
func (tr *statementTracker) dumpSelectExpr(expr pgparse.SelectExpr) {
	if tr.err != nil {
		return
	}

	nonStar, ok := expr.(*pgparse.NonStarExpr)
	if !ok {
		tr.err = fmt.Errorf("dumpSelectExpr unknown type: %T", expr)
		return
	}
	tr.dumpExpr(nonStar.Expr)
}
