package pgpcapcollector

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"sync/atomic"

	"time"

	"pgtools/pgquery"

	"code.google.com/p/gopacket"
	"code.google.com/p/gopacket/layers"
	"code.google.com/p/gopacket/pcap"
	"code.google.com/p/gopacket/tcpassembly"
)

// Package-level collector state. The counters are only touched through
// sync/atomic; active_sessions and completed_queries are allocated in init.
var (
	query_count           int64 // client 'Q' (simple query) messages seen
	complete_query_count  int64 // queries finished by PGSession.endQuery
	client_message_count  int64 // messages handled from client-side streams
	invalid_message_count int64 // bad length fields and unrecognized server types
	server_message_count  int64 // messages handled from server-side streams
	missed_count          int64 // reassembly gaps where bytes were lost

	// active_sessions maps a client "ip:port" key to its in-flight queries.
	active_sessions map[string]*PGSession

	// completed_queries carries finished queries to the consumer. Its element
	// type is *pgquery.PGQuery to match init, GetQueryChannel and the server
	// message handler, which all use pointers.
	completed_queries chan *pgquery.PGQuery
)

// BUFFER_SIZE is not referenced anywhere in this file.
// NOTE(review): 1500 presumably mirrors a typical Ethernet MTU — confirm
// before relying on it.
const BUFFER_SIZE = 1500

type PGMessage struct {
	Type     byte
	TotalLen int
	Data     []byte
	Time     time.Time
}

// PGSession holds the queries one client connection has sent that have not
// yet been finished by endQuery.
type PGSession struct {
	client         string             // client "ip:port" key
	active_queries []*pgquery.PGQuery // in arrival order, oldest first
}

// NewPGSession returns an empty session for the given client key.
func NewPGSession(client string) *PGSession {
	session := new(PGSession)
	session.client = client
	session.active_queries = make([]*pgquery.PGQuery, 0)
	return session
}

// addQuery appends q to the session's in-flight queries, keeping them in
// arrival order (oldest first).
func (s *PGSession) addQuery(q *pgquery.PGQuery) {
	s.active_queries = append(s.active_queries, q)
}

// queryFor returns the index and value of the oldest in-flight query whose
// StartTime is strictly before t, or (-1, nil) if there is none. Server
// responses observed at t are attributed to that query.
func (s *PGSession) queryFor(t time.Time) (int, *pgquery.PGQuery) {
	for i, query := range s.active_queries {
		if query.StartTime.Before(t) {
			return i, query
		}
	}
	return -1, nil
}

// addRow counts one data row seen at time t against the matching query. The
// time of a query's first row is recorded in FirstDataTime.
func (s *PGSession) addRow(t time.Time) {
	_, query := s.queryFor(t)
	if query == nil {
		return
	}
	if query.RowsReturned == 0 {
		// First row for this query.
		query.FirstDataTime = t
	}
	query.RowsReturned++
}

// endQuery marks the matching query finished at time t, removes it from the
// session and returns it. It returns nil if no in-flight query started
// before t. A query that returned no rows gets FirstDataTime = t as well.
func (s *PGSession) endQuery(t time.Time) *pgquery.PGQuery {
	i, query := s.queryFor(t)
	if query == nil {
		return nil
	}
	if query.RowsReturned == 0 {
		// No rows seen: treat completion as the first (and only) response.
		query.FirstDataTime = t
	}
	query.DoneTime = t
	atomic.AddInt64(&complete_query_count, 1)

	// Remove element i while preserving order, and clear the vacated tail
	// slot so the session's backing array does not keep a pointer to it.
	last := len(s.active_queries) - 1
	copy(s.active_queries[i:], s.active_queries[i+1:])
	s.active_queries[last] = nil
	s.active_queries = s.active_queries[:last]
	return query
}

type (
	// PGStreamFactory creates a PGStream for every new TCP stream the
	// tcpassembly stream pool hands it.
	PGStreamFactory struct{}

	// PGStream parses Postgres messages from one direction of a TCP
	// connection and passes each complete message to its handler.
	PGStream struct {
		net, transport gopacket.Flow
		client         string           // client "ip:port" key used for session lookup
		isClient       bool             // true when this direction carries client messages
		curMessage     PGMessage        // message whose body spans chunks, if any
		leftover       []byte           // header fragment saved from the previous chunk
		handler        PGMessageHandler // client- or server-side message handler
	}

	// PGMessageHandler processes one complete message parsed by a PGStream.
	PGMessageHandler interface {
		handleMessage(p *PGStream, message *PGMessage)
	}
)

// New implements the tcpassembly stream factory. A stream whose destination
// port is 6543 or 5432 carries client messages and is keyed by its source
// ip:port. Every other stream is treated as server traffic and keyed by its
// destination ip:port, so both directions of a connection share one key.
func (p *PGStreamFactory) New(net, transport gopacket.Flow) tcpassembly.Stream {
	stream := &PGStream{
		net:        net,
		transport:  transport,
		curMessage: PGMessage{},
		leftover:   make([]byte, 0),
	}

	switch transport.Dst().String() {
	case "6543", "5432":
		stream.isClient = true
		stream.client = net.Src().String() + ":" + transport.Src().String()
		stream.handler = &PGClientMessageHandler{}
	default:
		stream.isClient = false
		stream.client = net.Dst().String() + ":" + transport.Dst().String()
		stream.handler = &PGServerMessageHandler{}
	}

	return stream
}

// addQuery records q as in flight for client c, creating the client's
// session on first use.
func addQuery(c string, q *pgquery.PGQuery) {
	session, found := active_sessions[c]
	if !found || session == nil {
		session = NewPGSession(c)
		active_sessions[c] = session
	}
	session.addQuery(q)
}

// addRow counts a data row seen at time t for client c. A row for a client
// with no session is ignored.
func addRow(c string, t time.Time) {
	if session := active_sessions[c]; session != nil {
		session.addRow(t)
	}
}

// endQuery finishes client c's matching in-flight query at time t and returns
// it. It returns nil if the client has no session or no matching query.
func endQuery(c string, t time.Time) *pgquery.PGQuery {
	if session := active_sessions[c]; session != nil {
		return session.endQuery(t)
	}
	return nil
}

// clearSession drops every in-flight query for client c. It is meant for use
// after a lost packet, but nothing in this file calls it.
func clearSession(c string) {
	delete(active_sessions, c)
}

// validClientMessageType reports whether msg_type is one of the client
// message type codes this collector recognizes.
func validClientMessageType(msg_type byte) bool {
	const known = "QpXPFEHDcdCBS"
	for i := 0; i < len(known); i++ {
		if known[i] == msg_type {
			return true
		}
	}
	return false
}

// clientMessageNames maps client message type codes to protocol names.
var clientMessageNames = map[byte]string{
	'Q': "Query",
	'p': "PasswordMessage",
	'X': "Terminate",
	'P': "Parse",
	'F': "FunctionCall",
	'E': "Execute",
	'H': "Flush",
	'D': "Describe",
	'c': "CopyDone",
	'd': "CopyData",
	'C': "Close",
	'B': "Bind",
	'S': "Sync",
}

// clientMessageName returns the protocol name of a client message type, or
// "Unknown:<decimal code>" for unrecognized codes. Unknown codes might come
// from a CancelRequest, SSLRequest or StartupMessage.
func clientMessageName(msg_type byte) string {
	if name, ok := clientMessageNames[msg_type]; ok {
		return name
	}
	return fmt.Sprintf("Unknown:%d", msg_type)
}

// validServerMessageType reports whether msg_type is one of the server
// message type codes this collector recognizes.
func validServerMessageType(msg_type byte) bool {
	for _, known := range []byte("CTDIZERK123cdGHWVnNAtSs") {
		if known == msg_type {
			return true
		}
	}
	return false
}

// serverMessageNames maps server message type codes to protocol names.
var serverMessageNames = map[byte]string{
	'C': "CommandComplete",
	'T': "RowDescription",
	'D': "DataRow",
	'I': "EmptyQueryResponse",
	'Z': "ReadyForQuery",
	'E': "ErrorResponse",
	'R': "Authentication",
	'K': "BackendKeyData",
	'1': "ParseComplete",
	'2': "BindComplete",
	'3': "CloseComplete",
	'c': "CopyDone",
	'd': "CopyData",
	'G': "CopyInResponse",
	'H': "CopyOutResponse",
	'W': "CopyBothResponse",
	'V': "FunctionCallResponse",
	'n': "NoData",
	'N': "NoticeResponse",
	'A': "NotificationResponse",
	't': "ParameterDescription",
	'S': "ParameterStatus",
	's': "PortalSuspended",
}

// serverMessageName returns the protocol name of a server message type, or
// "Unknown:<decimal code>" for unrecognized codes.
func serverMessageName(msg_type byte) string {
	if name, ok := serverMessageNames[msg_type]; ok {
		return name
	}
	return fmt.Sprintf("Unknown:%d", msg_type)
}

// ReadMessages parses Postgres messages out of one chunk of stream data
// captured at time t and passes each complete message to p.handler.
//
// Parsing state survives between calls:
//   - A message whose body runs past the end of the chunk stays in
//     p.curMessage and is completed by later chunks.
//   - A header fragment shorter than the 5-byte type+length header is
//     saved in p.leftover for the caller to prepend to the next chunk.
//
// A length field below 4 is counted as invalid and the rest of the chunk is
// dropped.
func (p *PGStream) ReadMessages(packet_bytes []byte, t time.Time) {
	buf := bytes.NewReader(packet_bytes)
	// Every read below asks for at most buf.Len() bytes from an in-memory
	// reader, so it cannot fail; the errors are deliberately ignored.
	for {
		var curMessage PGMessage
		if p.curMessage.TotalLen != 0 {
			// Continue the body of a message started in an earlier chunk.
			remaining_len := p.curMessage.TotalLen - len(p.curMessage.Data)
			if remaining_len > buf.Len() {
				// Still not done yet: take all unread data and wait for more.
				rest := make([]byte, buf.Len())
				_ = binary.Read(buf, binary.BigEndian, rest)
				p.curMessage.Data = append(p.curMessage.Data, rest...)
				break
			}
			// Pull off the rest of the data for this message and process.
			remainingData := make([]byte, remaining_len)
			_ = binary.Read(buf, binary.BigEndian, remainingData)
			p.curMessage.Data = append(p.curMessage.Data, remainingData...)
			p.curMessage.Time = t

			curMessage = p.curMessage
			p.curMessage = PGMessage{}
		} else {
			// Expecting the start of a new message here.
			if buf.Len() == 0 {
				// Must be out of data
				break
			}
			if buf.Len() < 5 {
				// The header got split across chunks; save it for the next one.
				leftover := make([]byte, buf.Len())
				_ = binary.Read(buf, binary.BigEndian, leftover)
				p.leftover = leftover
				break
			}

			var id uint8
			var message_len int32
			_ = binary.Read(buf, binary.BigEndian, &id)
			_ = binary.Read(buf, binary.BigEndian, &message_len)
			if message_len < 4 {
				// The length counts its own 4 bytes, so this cannot be a real
				// header; give up on the rest of this chunk.
				atomic.AddInt64(&invalid_message_count, 1)
				break
			}

			p.curMessage.Type = id
			p.curMessage.TotalLen = int(message_len - 4) // Minus length field
			p.curMessage.Time = t
			if p.curMessage.TotalLen > buf.Len() {
				// Body continues in a later chunk: keep what we have.
				p.curMessage.Data = make([]byte, buf.Len())
				_ = binary.Read(buf, binary.BigEndian, p.curMessage.Data)
				break
			}
			p.curMessage.Data = make([]byte, p.curMessage.TotalLen)
			_ = binary.Read(buf, binary.BigEndian, p.curMessage.Data)

			curMessage = p.curMessage
			p.curMessage = PGMessage{}
		}

		p.handler.handleMessage(p, &curMessage)
	}
}

// Reassembled receives reassembled stream data from the assembler. Each
// chunk, prefixed by any header fragment saved from the previous chunk, is
// passed to ReadMessages.
//
// Gaps are handled before anything else, so parser state is always discarded
// across a discontinuity. This holds even when the reassembly reporting the
// gap carries no payload.
//   - Skip == -1 (unknown gap, probably a capture starting mid-stream):
//     reset the parser and continue.
//   - Positive Skip (bytes were lost): count it in missed_count and drop
//     the rest of this batch.
func (p *PGStream) Reassembled(reassemblies []tcpassembly.Reassembly) {
	for _, reassembly := range reassemblies {
		if reassembly.Skip != 0 {
			// Neither a partial message nor a saved header fragment from before
			// the gap can be joined with bytes after it.
			p.curMessage = PGMessage{}
			p.leftover = make([]byte, 0)
			if reassembly.Skip != -1 {
				// We missed a packet
				atomic.AddInt64(&missed_count, 1)
				return
			}
		}

		packet_bytes := reassembly.Bytes
		if len(p.leftover) > 0 {
			packet_bytes = append(p.leftover, reassembly.Bytes...)
			p.leftover = make([]byte, 0)
		}
		if len(packet_bytes) < 1 {
			continue
		}
		p.ReadMessages(packet_bytes, reassembly.Seen)
	}
}

// ReassemblyComplete is called by the assembler when it considers this stream
// finished. It does nothing, so the client's entry in active_sessions is left
// in place.
// NOTE(review): nothing in this file ever removes entries from
// active_sessions; consider whether per-connection cleanup belongs here.
func (*PGStream) ReassemblyComplete() {
}

// PGClientMessageHandler handles messages parsed from client-side streams.
// It carries no state.
type PGClientMessageHandler struct{}

// handleMessage counts every client message. A simple-protocol Query ('Q')
// additionally starts a tracked query for this stream's client, using the
// message body as the query text. All other types, recognized or not, are
// only counted.
func (mh *PGClientMessageHandler) handleMessage(p *PGStream, message *PGMessage) {
	atomic.AddInt64(&client_message_count, 1)

	// 'Q' is a recognized client type, so this single check also rules out
	// every unrecognized type.
	if message.Type != 'Q' {
		return
	}

	atomic.AddInt64(&query_count, 1)
	addQuery(p.client, pgquery.NewPGQuery(string(message.Data), message.Time))
}

// PGServerMessageHandler handles messages parsed from server-side streams.
// It carries no state.
type PGServerMessageHandler struct{}

// handleMessage counts every server message.
//   - DataRow ('D') is credited to the client's current query.
//   - ReadyForQuery ('Z') finishes that query and sends it on
//     completed_queries.
//   - Unrecognized types are printed and counted as invalid.
func (mh *PGServerMessageHandler) handleMessage(p *PGStream, message *PGMessage) {
	atomic.AddInt64(&server_message_count, 1)

	if validServerMessageType(message.Type) {
		switch message.Type {
		case 'D':
			addRow(p.client, message.Time)
		case 'Z':
			if query := endQuery(p.client, message.Time); query != nil {
				// This send blocks while the channel's buffer is full.
				completed_queries <- query
			}
		}
		return
	}

	fmt.Println(p.net)
	fmt.Println(p.transport)
	fmt.Println(p.client, "Invalid server message:", serverMessageName(message.Type))
	atomic.AddInt64(&invalid_message_count, 1)
}

// PgPcapCollector reads packets from a pcap capture file and extracts
// Postgres queries from them.
type PgPcapCollector struct {
	// Infile is the input file; it can be /dev/stdin if you want to pipe
	// live data. (This was previously written as a struct tag, which is not
	// in key:"value" form and is flagged by go vet.)
	Infile string
}

// NewPgPcapCollector returns a collector that reads packets from infile.
func NewPgPcapCollector(infile string) *PgPcapCollector {
	return &PgPcapCollector{Infile: infile}
}

// GetQueryChannel returns the channel that finished queries are sent on.
// The channel is a single package-level value, so every collector returns
// the same one.
func (*PgPcapCollector) GetQueryChannel() chan *pgquery.PGQuery {
	return completed_queries
}

// Run reads pc.Infile once, start to finish. It reassembles the TCP streams
// and feeds them through the Postgres message parsers; finished queries are
// sent on the channel returned by GetQueryChannel.
//
// Roughly every FLUSH_FREQUENCY seconds of capture time it prints running
// counters. It also flushes connections that have been idle for two minutes
// of capture time. Run panics if the input cannot be opened and returns when
// the packet source is exhausted.
func (pc *PgPcapCollector) Run() {
	handle, err := pcap.OpenOffline(pc.Infile)
	if err != nil {
		panic(err)
	}
	// Release the capture handle once the input has been consumed.
	defer handle.Close()

	streamFactory := &PGStreamFactory{}
	streamPool := tcpassembly.NewStreamPool(streamFactory)
	assembler := tcpassembly.NewAssembler(streamPool)
	assembler.MaxBufferedPagesTotal = 1

	packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
	packets := packetSource.Packets()

	FLUSH_FREQUENCY := int64(1)

	last_flush_time := time.Unix(0, 0)
	last_query_count := int64(0)
	for {
		// A nil packet means the channel was closed: input exhausted.
		packet := <-packets
		if packet == nil {
			fmt.Println("No packet, done!")
			assembler.FlushAll()
			return
		}
		if packet.NetworkLayer() == nil || packet.TransportLayer() == nil || packet.TransportLayer().LayerType() != layers.LayerTypeTCP {
			fmt.Println("Unusable packet")
			continue
		}
		cur_time := packet.Metadata().Timestamp
		tcp := packet.TransportLayer().(*layers.TCP)
		assembler.AssembleWithTimestamp(packet.NetworkLayer().NetworkFlow(), tcp, cur_time)

		// Whole seconds of capture time since the last report. A report fires
		// only once this exceeds FLUSH_FREQUENCY, so the interval is at least
		// FLUSH_FREQUENCY+1 seconds and never zero.
		elapsed := cur_time.Unix() - last_flush_time.Unix()
		if elapsed <= FLUSH_FREQUENCY {
			continue
		}

		cur_query_count := atomic.LoadInt64(&query_count)
		if last_flush_time.Unix() != 0 {
			delta := cur_query_count - last_query_count
			fmt.Println("Flushing:", cur_time)
			fmt.Println(" Client Message count:", atomic.LoadInt64(&client_message_count))
			fmt.Println(" Server Message count:", atomic.LoadInt64(&server_message_count))
			fmt.Println(" Invalid Message count:", atomic.LoadInt64(&invalid_message_count))
			fmt.Println(" Query count:", cur_query_count)
			fmt.Println(" Complete count:", atomic.LoadInt64(&complete_query_count))
			fmt.Println(" Delta:", delta)
			fmt.Println(" Missed packets:", atomic.LoadInt64(&missed_count))
			// Rate over the actual reporting interval, not FLUSH_FREQUENCY.
			fmt.Println(" Queries/second", delta/elapsed)
			flushed, closed := assembler.FlushOlderThan(cur_time.Add(time.Minute * -2))
			fmt.Println("Flushed Connections:", flushed)
			fmt.Println("Closed Connections:", closed)
		} else {
			fmt.Println("Start:", cur_time)
		}
		last_flush_time = cur_time
		// Snapshot taken before FlushOlderThan, so any queries counted while
		// it runs show up in the next interval's delta instead of vanishing.
		last_query_count = cur_query_count
	}
}

func init() {
	active_sessions = make(map[string]*PGSession)
	completed_queries = make(chan *pgquery.PGQuery, 10000)
}
