package parser

import (
	"bufio"
	"bytes"
	"compress/zlib"
	"fmt"
	"io"
	"net"
	"strconv"
	"strings"
	"time"
	"unsafe"

	"github.com/vmihailenco/msgpack"
)

// Field opcodes used as map keys in fw-injector messages. The values must stay
// in sync with the sender's definitions:
// https://bb.yandex-team.ru/projects/NOC/repos/fw-filter/browse/fw-injector.h#43-53

// Record metadata opcodes.
const (
	opcodeCmdType    = 0x00 // command type (uint8); must be the first map entry
	opcodeTimestamp  = 0x01 // timestamp, decoded as int64
	opcodeIpAddr     = 0x02 // raw IP address bytes (4 or 16 long)
	opcodeUsername   = 0x04 // username string
	opcodeEntryPoint = 0x05 // entrypoint (uint8), numbered from 0 by the sender
	opcodeHostname   = 0x06 // hostname string
	opcodeMACAddr    = 0x07 // raw MAC address bytes
)

// Payload opcodes used by binary-payload commands (ruleset, publishers).
const (
	opcodeUncomprSize   = 0x10 // size of the payload after decompression
	opcodeUncomprMD5    = 0x11 // 16-byte MD5 of the uncompressed payload
	opcodeCompBinData   = 0x20 // zlib-compressed payload bytes
	opcodeUncompBinData = 0x21 // uncompressed payload; not handled by the decoders visible here
)

// Command types carried as the value of opcodeCmdType. Numbering starts at 1
// and must match the sender's definitions.
const (
	commandAdd     = iota + 1 // 0x01: single rule, decoded into an AddRecord
	commandDel                // 0x02: single rule, decoded into a DeleteRecord
	commandPubs               // 0x03: publisher hostname list
	commandReplace            // 0x04: ruleset, decoded into a RulesetRecord
)

// msgDecoder returns a msgpack map decoder that turns one injector message
// into a typed record.
//
// Every message is a map whose first entry must be opcodeCmdType -> command.
// The remaining entries are decoded according to the command:
//   - commandAdd     -> *AddRecord
//   - commandDel     -> *DeleteRecord
//   - commandPubs    -> *UpdatePublishersRecord (only when trusted is true)
//   - commandReplace -> *RulesetRecord attributed to hostname
//
// Ignored messages (empty maps, publisher updates when not trusted, and
// unsupported commands) yield (nil, nil). Their remaining map entries are
// consumed, so the decoder is positioned at the next value afterwards.
// On error, the returned value is always an untyped nil.
func msgDecoder(trusted bool, hostname string) func(*msgpack.Decoder) (interface{}, error) {
	return func(d *msgpack.Decoder) (interface{}, error) {
		size, err := d.DecodeMapLen()
		if err != nil {
			return nil, err
		}

		// A nil map may be reported with a negative length; neither nil nor
		// empty maps carry a command.
		if size <= 0 {
			return nil, nil
		}

		op, err := d.DecodeUint8()
		if err != nil {
			return nil, err
		}

		if op != opcodeCmdType {
			return nil, fmt.Errorf("unexpected opcode (expected CmdType): %d", op)
		}

		cmd, err := d.DecodeUint8()
		if err != nil {
			return nil, err
		}

		// The cmd entry has been consumed; size now counts the entries left.
		size--

		// skipRest discards the remaining key/value pairs of this map. It must
		// be used whenever a message is ignored; otherwise the next decode
		// would start in the middle of this map.
		skipRest := func() error {
			for i := 0; i < size; i++ {
				if err := d.Skip(); err != nil {
					return fmt.Errorf("unable to skip key of field %d: %w", i, err)
				}
				if err := d.Skip(); err != nil {
					return fmt.Errorf("unable to skip value of field %d: %w", i, err)
				}
			}
			return nil
		}

		// Collect into an interface-typed variable so that a failed decode
		// returns an untyped nil rather than a typed nil pointer.
		var rec interface{}
		switch cmd {
		case commandAdd:
			rec, err = decodeCommandAdd(d, size)
		case commandDel:
			rec, err = decodeCommandDel(d, size)
		case commandPubs:
			if !trusted {
				err = skipRest()
				break
			}

			rec, err = decodePubs(d, size)
		case commandReplace:
			rec, err = decodeRuleset(d, hostname, size)
		default:
			// Unsupported commands are deliberately ignored.
			fmt.Println("skip", cmd)
			err = skipRest()
		}

		if err != nil {
			return nil, err
		}
		return rec, nil
	}
}

// decodeCommandAdd decodes the fieldsCount opcode/value pairs that follow the
// command entry of an "add" message into an AddRecord.
//
// The record addresses exactly one target. That target is either an IP
// address (4 or 16 raw bytes, selecting RuleKindIPv4 or RuleKindIPv6) or a MAC
// address (RuleKindMAC). A second address field is rejected regardless of the
// order in which the fields arrive. Unknown opcodes abort decoding with an
// error.
func decodeCommandAdd(d *msgpack.Decoder, fieldsCount int) (*AddRecord, error) {
	var out AddRecord
	for i := 0; i < fieldsCount; i++ {
		op, err := d.DecodeUint8()
		if err != nil {
			return nil, fmt.Errorf("unable to parse opcode of field %d: %w", i, err)
		}

		switch op {
		case opcodeTimestamp:
			out.Timestamp, err = d.DecodeInt64()
		case opcodeIpAddr:
			if out.RuleKind != RuleKindNone {
				err = fmt.Errorf("can't set ipaddr to already materialized record of type %s", out.RuleKind)
				break
			}

			out.IP, err = d.DecodeBytes()
			if err != nil {
				break
			}

			switch len(out.IP) {
			case net.IPv4len:
				out.RuleKind = RuleKindIPv4
			case net.IPv6len:
				out.RuleKind = RuleKindIPv6
			default:
				err = fmt.Errorf("unexpected ip len: %d", len(out.IP))
			}

		case opcodeMACAddr:
			// Same guard as for ipaddr, so that an IP+MAC pair is rejected
			// whichever field comes first.
			if out.RuleKind != RuleKindNone {
				err = fmt.Errorf("can't set macaddr to already materialized record of type %s", out.RuleKind)
				break
			}

			out.MAC, err = d.DecodeBytes()
			if err == nil {
				out.RuleKind = RuleKindMAC
			}
		case opcodeUsername:
			out.Username, err = d.DecodeString()
		case opcodeHostname:
			out.Hostname, err = d.DecodeString()
		case opcodeEntryPoint:
			out.Entrypoint, err = d.DecodeUint8()
			if err == nil {
				// entrypoint starts from 0 on the NOC side %)
				out.Entrypoint++
			}
		default:
			err = fmt.Errorf("unexpected opcode: %d", op)
		}

		if err != nil {
			return nil, fmt.Errorf("unable to parse field %d of opcode %d: %w", i, op, err)
		}
	}

	return &out, nil
}

// decodeCommandDel decodes the fieldsCount opcode/value pairs that follow the
// command entry of a "delete" message into a DeleteRecord. The record is
// stamped with the local time of decoding, because no timestamp opcode is
// accepted here.
//
// As in decodeCommandAdd, the record addresses exactly one target: an IP
// address (4 or 16 raw bytes) or a MAC address. A second address field is
// rejected regardless of order. Unknown opcodes abort decoding with an error.
func decodeCommandDel(d *msgpack.Decoder, fieldsCount int) (*DeleteRecord, error) {
	out := DeleteRecord{
		Timestamp: time.Now().Unix(),
	}

	for i := 0; i < fieldsCount; i++ {
		op, err := d.DecodeUint8()
		if err != nil {
			return nil, fmt.Errorf("unable to parse opcode of field %d: %w", i, err)
		}

		switch op {
		case opcodeIpAddr:
			if out.RuleKind != RuleKindNone {
				err = fmt.Errorf("can't set ipaddr to already materialized record of type %s", out.RuleKind)
				break
			}

			out.IP, err = d.DecodeBytes()
			if err != nil {
				break
			}

			switch len(out.IP) {
			case net.IPv4len:
				out.RuleKind = RuleKindIPv4
			case net.IPv6len:
				out.RuleKind = RuleKindIPv6
			default:
				err = fmt.Errorf("unexpected ip len: %d", len(out.IP))
			}

		case opcodeMACAddr:
			// Same guard as for ipaddr, so that an IP+MAC pair is rejected
			// whichever field comes first.
			if out.RuleKind != RuleKindNone {
				err = fmt.Errorf("can't set macaddr to already materialized record of type %s", out.RuleKind)
				break
			}

			out.MAC, err = d.DecodeBytes()
			if err == nil {
				out.RuleKind = RuleKindMAC
			}
		case opcodeHostname:
			out.Hostname, err = d.DecodeString()
		default:
			err = fmt.Errorf("unexpected opcode: %d", op)
		}

		if err != nil {
			return nil, fmt.Errorf("unable to parse field %d of opcode %d: %w", i, op, err)
		}
	}

	return &out, nil
}

// decodeRuleset decodes a "replace" message. The fieldsCount remaining
// opcode/value pairs describe a zlib-compressed ruleset. It is inflated and
// parsed into a RulesetRecord attributed to hostname.
func decodeRuleset(d *msgpack.Decoder, hostname string, fieldsCount int) (*RulesetRecord, error) {
	rawRuleset, err := decodeBinaryData(d, fieldsCount)
	if err != nil {
		return nil, err
	}

	r, err := zlib.NewReader(bytes.NewReader(rawRuleset.Data))
	if err != nil {
		return nil, fmt.Errorf("create zlib reader: %w", err)
	}
	// zlib.NewReader requires the caller to Close the reader. The Close error
	// is intentionally ignored: stream and checksum errors are reported by
	// Read and therefore surface through parseRuleset's scanner error.
	defer r.Close()

	// TODO(buglloc): check md5 && len
	return parseRuleset(rawRuleset.Timestamp, hostname, r)
}

// decodePubs decodes a "publishers" message. The fieldsCount remaining
// opcode/value pairs describe a zlib-compressed hostname list. It is inflated
// and parsed into an UpdatePublishersRecord.
func decodePubs(d *msgpack.Decoder, fieldsCount int) (*UpdatePublishersRecord, error) {
	rawPub, err := decodeBinaryData(d, fieldsCount)
	if err != nil {
		return nil, err
	}

	r, err := zlib.NewReader(bytes.NewReader(rawPub.Data))
	if err != nil {
		return nil, fmt.Errorf("create zlib reader: %w", err)
	}
	// zlib.NewReader requires the caller to Close the reader. The Close error
	// is intentionally ignored: stream and checksum errors are reported by
	// Read and therefore surface through parsePubs' scanner error.
	defer r.Close()

	// TODO(buglloc): check md5 && len
	return parsePubs(rawPub.Timestamp, r)
}

// decodeBinaryData reads the fieldsCount opcode/value pairs of a
// binary-payload message into a binaryData. Exactly 4 or 5 pairs are
// accepted. The recognized fields are:
//   - timestamp
//   - uncompressed size
//   - compressed payload bytes
//   - uncompressed MD5, which must be 16 bytes
//   - hostname, which is read and discarded
//
// Any other opcode is an error.
func decodeBinaryData(d *msgpack.Decoder, fieldsCount int) (*binaryData, error) {
	if !(fieldsCount >= 4 && fieldsCount <= 5) {
		return nil, fmt.Errorf("wrong number of opcodes: should be 4 or 5, got %d", fieldsCount)
	}

	out := &binaryData{}
	for field := 0; field < fieldsCount; field++ {
		op, err := d.DecodeUint8()
		if err != nil {
			return nil, fmt.Errorf("unable to parse opcode of field %d: %w", field, err)
		}

		switch op {
		case opcodeTimestamp:
			out.Timestamp, err = d.DecodeInt64()
		case opcodeUncomprSize:
			out.UncompressedSize, err = d.DecodeInt()
		case opcodeCompBinData:
			out.Data, err = d.DecodeBytes()
		case opcodeUncomprMD5:
			var sum []byte
			if sum, err = d.DecodeBytes(); err != nil {
				break
			}
			if len(sum) != 16 {
				err = fmt.Errorf("unexpected md5 len: %d", len(sum))
				break
			}
			copy(out.Md5Sum[:], sum)
		case opcodeHostname:
			// The sender's hostname is not needed for binary payloads.
			_, err = d.DecodeString()
		default:
			err = fmt.Errorf("unexpected opcode: %d", op)
		}

		if err != nil {
			return nil, fmt.Errorf("unable to parse field %d of opcode %d: %w", field, op, err)
		}
	}

	return out, nil
}

// parsePubs reads a newline-separated hostname list from r into an
// UpdatePublishersRecord stamped with ts. Each line is trimmed of surrounding
// whitespace. Blank lines and lines starting with '#' are skipped. The record
// is returned even when reading fails, together with the scanner's error.
func parsePubs(ts int64, r io.Reader) (*UpdatePublishersRecord, error) {
	rec := &UpdatePublishersRecord{Timestamp: ts}

	lines := bufio.NewScanner(r)
	for lines.Scan() {
		host := strings.TrimSpace(lines.Text())
		if host == "" || strings.HasPrefix(host, "#") {
			continue
		}
		rec.Hostnames = append(rec.Hostnames, host)
	}

	return rec, lines.Err()
}

// parseRuleset parses a ruleset dump read from r, one rule per line, into a
// RulesetRecord stamped with ts and hostname.
//
// Each non-blank line has the form
//
//	<kind><addr>;<unix timestamp>;<username>;<entrypoint>
//
// where <kind> is 'i' (IPv4 or IPv6 address) or 'm' (MAC address), e.g.:
//
//	i172.28.122.143;1645803989;asinamati;wireless
//	m88:e9:fe:60:d8:a9;1645807160;forx;wireless
//	i2a02:6b8:0:402:5861:3b8:cb1e:9a31;1645810672;alisasnooki;wireless
//
// Lines are trimmed of surrounding whitespace. The first malformed line,
// including an unparseable IP or MAC address, aborts parsing with an error
// that carries its 1-based line number. Unknown entrypoint names map to 0
// (see parseEntryPoint).
func parseRuleset(ts int64, hostname string, r io.Reader) (*RulesetRecord, error) {
	out := RulesetRecord{
		Timestamp: ts,
		Hostname:  hostname,
	}

	parseLine := func(in []byte) error {
		if len(in) == 0 {
			return nil
		}

		data, rest := nextStringElem(in, ';')
		if data == "" {
			return io.ErrUnexpectedEOF
		}

		var rule Rule
		var err error
		switch data[0] {
		case 'm':
			rule.MAC, err = net.ParseMAC(data[1:])
			rule.Kind = RuleKindMAC
		case 'i':
			rule.IP = net.ParseIP(data[1:])
			if rule.IP == nil {
				// net.ParseIP reports failure with nil. Without this check, a
				// garbage address would be accepted as an empty IPv6 rule.
				err = fmt.Errorf("invalid ip address: %q", data[1:])
				break
			}

			rule.Kind = RuleKindIPv6
			if rule.IP.To4() != nil {
				rule.Kind = RuleKindIPv4
			}

		default:
			err = fmt.Errorf("unexpected data marker: %x", data[0])
		}

		if err != nil {
			return fmt.Errorf("unable to parse ip/mac: %w", err)
		}

		data, rest = nextStringElem(rest, ';')
		if data == "" {
			return io.ErrUnexpectedEOF
		}
		rule.Timestamp, err = strconv.ParseInt(data, 10, 64)
		if err != nil {
			return fmt.Errorf("unable to parse timestamp: %w", err)
		}

		data, rest = nextStringElem(rest, ';')
		if data == "" {
			return io.ErrUnexpectedEOF
		}
		rule.Username = data

		// The entrypoint is the last field and may or may not be followed by
		// a ';'. When no further ';' exists, the remaining bytes are the
		// entrypoint.
		data, rest = nextStringElem(rest, ';')
		if data == "" {
			// Zero-copy string view of rest. This is safe only because
			// parseEntryPoint just compares it with constants and does not
			// retain it; rest aliases the scanner's buffer, which is
			// overwritten by the next Scan.
			rule.Entrypoint = parseEntryPoint(*(*string)(unsafe.Pointer(&rest)))
		} else {
			rule.Entrypoint = parseEntryPoint(data)
		}

		out.Rules = append(out.Rules, rule)
		return nil
	}

	curLine := 0
	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		curLine++
		line := bytes.TrimSpace(scanner.Bytes())
		if len(line) == 0 {
			continue
		}

		if err := parseLine(line); err != nil {
			return nil, fmt.Errorf("unable to parse line %d: %w", curLine, err)
		}
	}

	return &out, scanner.Err()
}

// nextStringElem splits in at the first occurrence of tok. It returns the
// bytes before tok, copied into a new string, and the bytes after tok, which
// still alias in. When tok is absent, it returns "" and in unchanged. Callers
// therefore cannot tell "no tok" apart from "empty element before tok".
func nextStringElem(in []byte, tok byte) (string, []byte) {
	idx := bytes.IndexByte(in, tok)
	if idx < 0 {
		return "", in
	}

	head, tail := in[:idx], in[idx+1:]
	return string(head), tail
}

// parseEntryPoint maps an entrypoint name from a ruleset line to its numeric
// code: wired=1, wireless=2, vpn=3, mobile=4, nac=5. Matching is exact and
// case-sensitive. Any other value, including "", yields 0.
func parseEntryPoint(in string) uint8 {
	names := [...]string{"wired", "wireless", "vpn", "mobile", "nac"}
	for i, name := range names {
		if in == name {
			return uint8(i + 1)
		}
	}
	return 0
}
