package main

import (
	"flag"
	"log"
	"net"
	"os"
	"path/filepath"
	"strings"
	"sync"

	"github.com/golang/protobuf/proto"

	"code.justin.tv/ids/nfconvert/parse"
	"code.justin.tv/ids/nfconvert/pbflow"
)

// convertUint16 widens an optional 16-bit value to an optional 32-bit value.
// A nil input yields a nil result; otherwise the result points at a fresh
// uint32 holding the same number, so it never aliases the input.
func convertUint16(in *uint16) (out *uint32) {
	if in != nil {
		widened := uint32(*in)
		out = &widened
	}
	return out
}

// convertProtocol turns an optional protocol number into an optional
// pbflow.Protocol by a plain numeric conversion. A nil input yields nil;
// otherwise the result points at a freshly allocated value.
func convertProtocol(in *uint8) (out *pbflow.Protocol) {
	if in != nil {
		proto := pbflow.Protocol(*in)
		out = &proto
	}
	return out
}

// convertDirection turns an optional direction code into an optional
// pbflow.FlowRecord_Direction by a plain numeric conversion. A nil input
// yields nil; otherwise the result points at a freshly allocated value.
func convertDirection(in *uint8) (out *pbflow.FlowRecord_Direction) {
	if in != nil {
		dir := pbflow.FlowRecord_Direction(*in)
		out = &dir
	}
	return out
}

// convertTime folds a whole-seconds part and a milliseconds part into one
// nanosecond count and returns a pointer to it. The result is never nil.
// Both parts are widened to uint64 before scaling, so the sum cannot overflow.
func convertTime(sec uint32, msec uint16) (out *uint64) {
	const (
		nanosPerSecond      uint64 = 1e9
		nanosPerMillisecond uint64 = 1e6
	)
	out = new(uint64)
	*out = uint64(sec)*nanosPerSecond + uint64(msec)*nanosPerMillisecond
	return out
}

// convertMAC copies an optional MAC record into a pbflow.MACAddress, storing
// each element of the record as one uint32 entry in Data, in order.
// A nil input yields a nil result.
func convertMAC(in *parse.MACRecord) (out *pbflow.MACAddress) {
	if in == nil {
		return nil
	}
	octets := make([]uint32, 0, len(in))
	for _, b := range in {
		octets = append(octets, uint32(b))
	}
	return &pbflow.MACAddress{Data: octets}
}

// convertIP copies an IP address into a pbflow.IPAddress, storing each byte
// as one uint32 entry in Data and tagging the message with an IP version
// chosen from the slice length (4 bytes -> IPV4, 16 bytes -> IPV6).
//
// The result is never nil: a nil or empty input produces a message with empty
// Data, and any length other than 4 or 16 leaves Version at the zero value of
// pbflow.IPAddress_Version.
// NOTE(review): a 16-byte IPv4-mapped net.IP is tagged IPV6 here — confirm
// that the parser only hands out 4-byte slices for IPv4 addresses.
func convertIP(in net.IP) (out *pbflow.IPAddress) {
	var version pbflow.IPAddress_Version
	if n := len(in); n == net.IPv4len {
		version = pbflow.IPAddress_IPV4
	} else if n == net.IPv6len {
		version = pbflow.IPAddress_IPV6
	}

	octets := make([]uint32, len(in))
	for i := range in {
		octets[i] = uint32(in[i])
	}

	out = &pbflow.IPAddress{
		Version: &version,
		Data:    octets,
	}
	return out
}

// convertUint8 widens an optional 8-bit value to an optional 32-bit value.
// A nil input yields a nil result; otherwise the result points at a fresh
// uint32 holding the same number, so it never aliases the input.
func convertUint8(in *uint8) (out *uint32) {
	if in != nil {
		widened := uint32(*in)
		out = &widened
	}
	return out
}

// convertRecord builds the protobuf representation of one parsed flow record.
//
// Fields are filled in the same order as they are declared below. Values that
// need a wider or different type go through the convert* helpers, which
// allocate fresh storage. PacketsIn, BytesIn and AggrFlows instead point
// directly into rec, so the returned message is only valid while rec is left
// unmodified.
func convertRecord(rec *parse.Record) *pbflow.FlowRecord {
	pbf := new(pbflow.FlowRecord)

	// Endpoints and protocol.
	pbf.SrcIp = convertIP(rec.Base.SrcIP)
	pbf.DstIp = convertIP(rec.Base.DstIP)
	pbf.SrcPort = convertUint16(&rec.Base.SrcPort)
	pbf.DstPort = convertUint16(&rec.Base.DstPort)
	pbf.Protocol = convertProtocol(&rec.Base.Prot)

	// Flow timing: seconds plus milliseconds folded into nanoseconds.
	pbf.FlowStart = convertTime(rec.Base.First, rec.Base.MsecFirst)
	pbf.FlowEnd = convertTime(rec.Base.Last, rec.Base.MsecLast)

	// Counters; these alias the record's own storage.
	pbf.PacketsIn = &rec.Base.PacketsIn
	pbf.BytesIn = &rec.Base.BytesIn

	// Netmasks and MAC addresses.
	pbf.SrcMask = convertUint8(rec.SrcMask)
	pbf.DstMask = convertUint8(rec.DstMask)
	pbf.InSrcMac = convertMAC(rec.InSrcMAC)
	pbf.InDstMac = convertMAC(rec.InDstMAC)
	pbf.OutSrcMac = convertMAC(rec.OutSrcMAC)
	pbf.OutDstMac = convertMAC(rec.OutDstMAC)

	// Routing information.
	pbf.SrcAs = rec.GetSrcAS()
	pbf.DstAs = rec.GetDstAS()
	pbf.NextHop = convertIP(rec.NextHop)
	pbf.NextHopBgp = convertIP(rec.BGPNextHop)

	pbf.AggrFlows = &rec.AggrFlows

	// Type-of-service bytes, TCP flags and direction.
	pbf.InTosByte = convertUint8(&rec.Base.Tos)
	pbf.OutTosByte = convertUint8(rec.DstTypeOfService)
	pbf.TcpFlags = convertUint8(&rec.Base.TcpFlags)
	pbf.Direction = convertDirection(rec.Direction)

	// Interfaces and VLANs.
	pbf.InputSnmp = rec.InputInterface
	pbf.OutputSnmp = rec.OutputInterface
	pbf.SrcVlan = convertUint16(rec.GetSrcVLAN())
	pbf.DstVlan = convertUint16(rec.GetDstVLAN())

	return pbf
}

// convert reads the nfcapd file at filename 'in' and writes every record it
// contains, marshalled as a pbflow.FlowRecord, to a newly created file at
// 'out' (truncating any existing file).
//
// It returns the first error encountered while opening either file,
// initialising the scanner, marshalling a record, writing to the output, or
// closing the output. On error the output file may be left partially written.
//
// NOTE(review): the marshalled messages are written back to back with no
// length prefix or delimiter; the protobuf wire format is not self-delimiting,
// so confirm that the consumer of these files can split them into records.
func convert(in string, out string) error {
	log.Println("converting", in, "->", out)
	infile, err := os.Open(in)
	if err != nil {
		return err
	}
	defer infile.Close()

	outfile, err := os.Create(out)
	if err != nil {
		return err
	}
	// Safety net for early returns; the success path closes explicitly below
	// and the second Close is then a harmless no-op error.
	defer outfile.Close()

	var scanner parse.RecordScanner
	if err = scanner.Init(infile, in); err != nil {
		return err
	}

	// Remember the first callback failure ourselves so that it reaches the
	// caller regardless of what Map does with the callback's return value.
	// NOTE(review): assumes Map invokes the callback sequentially — the
	// unsynchronised writes to outfile already rely on that.
	var firstErr error
	scanner.Map(func(r *parse.Record) error {
		b, err := proto.Marshal(convertRecord(r))
		if err == nil {
			_, err = outfile.Write(b)
		}
		if err != nil && firstErr == nil {
			firstErr = err
		}
		return err
	})
	if firstErr != nil {
		return firstErr
	}

	// Close can report write errors that were deferred by the OS, so surface it.
	return outfile.Close()
}

// commonParent returns the deepest directory which is shared by all paths.
// For example:
//    in:  /etc/service/foo /etc/service/bar /etc/service/baz/bang
//    out: /etc/service
//
// The last element of every path is treated as a file name and dropped; the
// remaining directory is made absolute before comparison, so the result is an
// absolute, cleaned path. Directories are compared component by component, so
// /a/b is not considered a parent of /a/bc. An empty input yields "".
func commonParent(paths []string) string {
	if len(paths) == 0 {
		return ""
	}

	sep := string(filepath.Separator)
	// within reports whether dir is parent itself or lies below it.
	within := func(dir, parent string) bool {
		if dir == parent {
			return true
		}
		if !strings.HasSuffix(parent, sep) {
			parent += sep
		}
		return strings.HasPrefix(dir, parent)
	}

	common := ""
	for i, p := range paths {
		dir, _ := filepath.Split(p)
		abs, err := filepath.Abs(dir)
		if err != nil {
			// Abs only fails when the working directory cannot be determined;
			// fall back to the cleaned path rather than comparing against "".
			abs = filepath.Clean(dir)
		}
		if i == 0 {
			common = abs
			continue
		}
		// Walk the candidate upwards until it contains this directory.
		for !within(abs, common) {
			parent := filepath.Dir(common)
			if parent == common {
				// Reached the root (or "."): nothing further to strip.
				break
			}
			common = parent
		}
	}
	return common
}

// stripCommonParents removes the deepest common path from the front of every
// entry, for example:
//    in:  /etc/service/foo /etc/service/bar /etc/service/baz/bang
//    out: /foo /bar /baz/bang
//
// Only the text returned by commonParent is trimmed, so a separator that
// followed it stays at the front of each result. An entry that does not start
// with that text (for instance a relative path) is returned unchanged.
// The result has the same length and order as the input.
func stripCommonParents(paths []string) []string {
	prefix := commonParent(paths)
	stripped := make([]string, 0, len(paths))
	for _, full := range paths {
		stripped = append(stripped, strings.TrimPrefix(full, prefix))
	}
	return stripped
}

// main expands the -input glob, maps every matching file to an output path
// under -output-dir (input path minus the common parent directory, plus
// -suffix), and converts all files concurrently, one goroutine per file.
//
// The process exits with a non-zero status if the glob is malformed, matches
// nothing, or any conversion fails. A failing conversion no longer aborts the
// others: every file is attempted and the failures are reported at the end.
func main() {
	var input string
	flag.StringVar(&input, "input", "", "filename (or glob of files) to read in")
	var outputDir string
	flag.StringVar(&outputDir, "output-dir", "", "directory to write to")
	var suffix string
	flag.StringVar(&suffix, "suffix", ".pb", "suffix to append to files when writing to disk")
	flag.Parse()

	inputFiles, err := filepath.Glob(input)
	if err != nil {
		log.Fatal(err)
	}
	// Glob reports "no matches" as an empty slice, not an error; stop here
	// with a clear message instead of handing an empty list downstream.
	if len(inputFiles) == 0 {
		log.Fatalf("no input files match %q", input)
	}

	outputFiles := stripCommonParents(inputFiles)
	for i, o := range outputFiles {
		outputFiles[i] = filepath.Join(outputDir, o+suffix)
	}
	log.Println("Reading from", inputFiles)
	log.Println("Writing to", outputFiles)

	var (
		wg       sync.WaitGroup
		mu       sync.Mutex // guards failures
		failures int
	)
	for i, in := range inputFiles {
		wg.Add(1)
		log.Println("doing file no.", i)
		go func(in, out string, i int) {
			defer wg.Done()
			// Stripped paths may keep sub-directories (e.g. baz/bang), which
			// must exist before the output file can be created.
			err := os.MkdirAll(filepath.Dir(out), 0755)
			if err == nil {
				err = convert(in, out)
			}
			if err != nil {
				// Do not log.Fatal here: exiting from a worker would kill the
				// process while other conversions are still writing.
				log.Println("failed file no.", i, in, "->", out, ":", err)
				mu.Lock()
				failures++
				mu.Unlock()
				return
			}
			log.Println("finished file no.", i)
		}(in, outputFiles[i], i)
	}
	wg.Wait()

	if failures > 0 {
		log.Fatalf("%d of %d files failed to convert", failures, len(inputFiles))
	}
}
