package main

import (
	"errors"
	"fmt"
	"io"
	"log"

	"code.google.com/p/gopacket"
	"code.google.com/p/gopacket/layers"
	"github.com/PreetamJinka/sflow"
)

// portsFromTCPPacket returns the TCP source and destination ports of the
// decoded packet. Both ports are zero and err is non-nil when the packet has
// no TCP layer, or when the layer gopacket reports cannot be asserted to
// *layers.TCP.
func portsFromTCPPacket(packet *gopacket.Packet) (srcPort uint16, dstPort uint16, err error) {
	switch layer := (*packet).Layer(layers.LayerTypeTCP).(type) {
	case nil:
		return 0, 0, errors.New("no TCP layer found")
	case *layers.TCP:
		return uint16(layer.SrcPort), uint16(layer.DstPort), nil
	default:
		return 0, 0, fmt.Errorf("gopacket returned supposed tcp layer which could not be asserted?, %v", layer)
	}
}

// portsFromUDPPacket returns the UDP source and destination ports of the
// decoded packet. Both ports are zero and err is non-nil when the packet has
// no UDP layer, or when the layer gopacket reports cannot be asserted to
// *layers.UDP.
func portsFromUDPPacket(packet *gopacket.Packet) (srcPort uint16, dstPort uint16, err error) {
	switch layer := (*packet).Layer(layers.LayerTypeUDP).(type) {
	case nil:
		return 0, 0, errors.New("no UDP layer found")
	case *layers.UDP:
		return uint16(layer.SrcPort), uint16(layer.DstPort), nil
	default:
		return 0, 0, fmt.Errorf("gopacket returned supposed udp layer which could not be asserted?, %v", layer)
	}
}

// findPacketFromRawPacketFlowRecord decodes the sampled header of a raw packet
// flow record as an Ethernet frame. If the frame carries IPv4 traffic that
// passes filter, it returns a summary of that packet; otherwise it returns nil.
//
// Filter criteria are applied in this order:
//   - filter.Protocol must contain the IPv4 protocol number;
//   - if filter.IP is non-nil, the source or destination address must be
//     contained in one of its networks;
//   - if filter.Port is non-nil, the source or destination port must be
//     contained in one of its ranges.
//
// Ports are only read for TCP and UDP. For any other protocol both ports are
// 0, so a port filter is evaluated against port 0 for such packets. A TCP or
// UDP packet whose transport header cannot be read is logged and dropped.
//
// filter must be non-nil: the Protocol lookup dereferences it.
func findPacketFromRawPacketFlowRecord(rawFlowRecord sflow.RawPacketFlow, filter *Filter) *CapturedPacket {
	// NOTE(review): the header is always decoded as Ethernet, regardless of
	// any header-protocol information in the record — confirm that
	// non-Ethernet samples cannot reach this function.
	packet := gopacket.NewPacket(rawFlowRecord.Header, layers.LayerTypeEthernet, gopacket.Default)

	ipv4Layer := packet.Layer(layers.LayerTypeIPv4)
	if ipv4Layer == nil {
		return nil // no ipv4 layer found in this record
	}
	ipv4, ok := ipv4Layer.(*layers.IPv4)
	if !ok {
		log.Printf("gopacket returned supposed ipv4 layer which could not be asserted?, %v", ipv4Layer)
		return nil
	}

	if _, wanted := filter.Protocol[ipv4.Protocol]; !wanted {
		return nil // we don't want this sort of packet
	}

	// A nil IP list means "no address restriction".
	if filter.IP != nil {
		found := false
		for _, ipNet := range filter.IP {
			if ipNet.Contains(ipv4.SrcIP) || ipNet.Contains(ipv4.DstIP) {
				found = true
				break
			}
		}
		if !found {
			return nil
		}
	}

	// Ports stay zero for protocols other than TCP and UDP.
	var (
		srcPort, dstPort uint16
		err              error
	)
	switch ipv4.Protocol {
	case layers.IPProtocolTCP:
		if srcPort, dstPort, err = portsFromTCPPacket(&packet); err != nil {
			log.Printf("encountered error reading TCP Packet, %v", err)
			return nil
		}
	case layers.IPProtocolUDP:
		if srcPort, dstPort, err = portsFromUDPPacket(&packet); err != nil {
			log.Printf("encountered error reading UDP Packet, %v", err)
			return nil
		}
	}

	// A nil port list means "no port restriction".
	if filter.Port != nil {
		found := false
		for _, uint16Range := range filter.Port {
			if uint16Range.Contains(srcPort) || uint16Range.Contains(dstPort) {
				found = true
				break
			}
		}
		if !found {
			return nil
		}
	}

	return &CapturedPacket{
		Protocol:        ipv4.Protocol,
		SourceIP:        ipv4.SrcIP,
		SourcePort:      srcPort,
		DestinationIP:   ipv4.DstIP,
		DestinationPort: dstPort,
		Length:          ipv4.Length,
	}
}

// findPacketsFromFlowSample collects, in record order, the captured packets
// from every raw packet flow record in flowSample that passes filter. Records
// of any other type are skipped. The returned slice is never nil, though it
// may be empty.
func findPacketsFromFlowSample(flowSample *sflow.FlowSample, filter *Filter) []*CapturedPacket {
	matched := []*CapturedPacket{}
	for i := range flowSample.Records {
		raw, isRaw := flowSample.Records[i].(sflow.RawPacketFlow)
		if !isRaw {
			continue // not a raw packet record; nothing to inspect
		}
		captured := findPacketFromRawPacketFlowRecord(raw, filter)
		if captured == nil {
			continue // filtered out or undecodable
		}
		matched = append(matched, captured)
	}
	return matched
}

// FindPackets decodes one sFlow datagram from rs. It returns the captured
// packets from all of the datagram's flow samples that pass filter. Samples
// that are not flow samples (such as counter samples) are ignored.
//
// If decoding fails, the error is logged and nil is returned. Otherwise the
// result is a non-nil, possibly empty, slice.
func FindPackets(rs io.ReadSeeker, filter *Filter) []*CapturedPacket {
	dec := sflow.NewDecoder(rs)
	dgram, err := dec.Decode()
	if err != nil {
		log.Printf("sflow decoding error, %v", err)
		return nil
	}

	result := []*CapturedPacket{}
	for _, sample := range dgram.Samples {
		if fs, isFlow := sample.(*sflow.FlowSample); isFlow {
			result = append(result, findPacketsFromFlowSample(fs, filter)...)
		}
	}
	return result
}
