package main

import (
	"reflect"
	"sort"
	"testing"

	"code.justin.tv/rhys/nursery/pgparse"
)

// TestTableList checks that a statementTracker fed a parsed query reports
// every table the query references. The cases cover tables reached only
// through a subquery, a table aliased more than once, a trailing
// SQL comment, and double-quoted identifiers.
//
// Only the expected list is sorted before comparison. The test therefore
// requires tr.Tables() to return a sorted list. The aliased-users case
// additionally requires each table to appear only once.
func TestTableList(t *testing.T) {
	tests := [...]struct {
		name   string
		query  string
		tables []string
	}{
		{name: "subquery_in_where", query: `
-- TMI db.py
SELECT 1
FROM ticket_products
WHERE id in (1)
AND availability_state = 'active'
AND ((ticket_type = 'chansub' AND ticket_product_owner_id = 2)
    OR
    (ticket_type = 'teamsub' AND ticket_product_owner_id IN
        (SELECT team_users.team_id
        FROM team_users
        WHERE team_users.user_id = 3)))
LIMIT 1
`[1:], tables: []string{"team_users", "ticket_products"}},

		{name: "aliased_self_join", query: `
SELECT pubs.login, mods.login
FROM publisher_ops, users AS pubs, users AS mods
WHERE publisher_ops.publisher_id = pubs.id
  AND publisher_ops.op_id = mods.id
  AND publisher_ops.op_id IS NOT NULL
  AND pubs.login IN (1,2,3)
`[1:], tables: []string{"publisher_ops", "users"}},

		{name: "trailing_comment", query: `
SELECT archives.* FROM archives
WHERE archives.broadcast_id IS NULL
AND (archives.kind is null or archives.kind != 'highlight')
ORDER BY broadcast_part
/*application:Twitch,controller:channels,action:videos*/
`[1:], tables: []string{"archives"}},

		{name: "quoted_identifiers", query: `
SELECT "archives".* FROM "archives"
WHERE "archives"."broadcast_id" IS NULL
AND (archives.kind is null or archives.kind != 'highlight')
ORDER BY broadcast_part
/*application:Twitch,controller:channels,action:videos*/
`[1:], tables: []string{"archives"}},
	}

	for _, tt := range tests {
		// Subtests run synchronously here (no t.Parallel), so capturing the
		// loop variable is safe on every Go version.
		t.Run(tt.name, func(t *testing.T) {
			stmt, err := pgparse.Parse(tt.query)
			if err != nil {
				// Fatalf only stops this subtest; the remaining cases still run.
				t.Fatalf("pgparse.Parse(%q): %v", tt.query, err)
			}

			tr := newStatementTracker()
			tr.dumpStmt(stmt)

			// Sort a copy: tt.tables shares its backing array with the
			// fixture, so sorting it directly would mutate the test table.
			want := append([]string(nil), tt.tables...)
			sort.Strings(want)

			if got := tr.Tables(); !reflect.DeepEqual(got, want) {
				t.Errorf("tr.Tables() = %q, want %q", got, want)
			}
		})
	}
}

// BenchmarkParse times pgparse.Parse on one fixed multi-line SELECT. The
// query joins publisher_ops against users, which appears twice under the
// aliases pubs and mods. The first parse error aborts the benchmark.
func BenchmarkParse(b *testing.B) {
	// Each clause sits on its own line and the text ends with a newline.
	// This matches the raw-string-with-leading-newline-trimmed form.
	sql := "SELECT pubs.login, mods.login\n" +
		"FROM publisher_ops, users AS pubs, users AS mods\n" +
		"WHERE publisher_ops.publisher_id = pubs.id\n" +
		"  AND publisher_ops.op_id = mods.id\n" +
		"  AND publisher_ops.op_id IS NOT NULL\n" +
		"  AND pubs.login IN (1,2,3)\n"

	for remaining := b.N; remaining > 0; remaining-- {
		if _, parseErr := pgparse.Parse(sql); parseErr != nil {
			b.Fatalf("parse error: %v", parseErr)
		}
	}
}
