package ipset

import (
	"a.yandex-team.ru/balancer/production/x/iptables_daemon/daemonmetrics"
	"a.yandex-team.ru/balancer/production/x/iptables_daemon/types"
	"a.yandex-team.ru/library/go/core/log"
	"bytes"
	"fmt"
	"github.com/golang/protobuf/proto"
	"io"
	"math/rand"
	"net"
	"os/exec"
	"strconv"
	"strings"
	"time"
)

// IIPSet is the set of operations the daemon performs on its ipsets and the
// iptables/ip6tables rules that reference them. *IPSet implements it.
type IIPSet interface {
	// CheckInstalledModules verifies that the iptables and ipset tools exist.
	CheckInstalledModules(logger log.Logger)
	// Setup prepares the sets and rules, creating them if needed.
	Setup(logger log.Logger)
	// RemoveSetsAndPurgeRules deletes the managed rules and sets.
	RemoveSetsAndPurgeRules(logger log.Logger)

	// ProcessCBBResponse reads a CBB response from in and writes ipset
	// commands to out, returning a timestamp from the response (0 on failure
	// in the *IPSet implementation).
	ProcessCBBResponse(in io.ReadCloser, out io.WriteCloser, metrics daemonmetrics.IMetrics, logger log.Logger) uint64

	// UpdateIpsetMetrics reports the current size of each managed set.
	UpdateIpsetMetrics(metrics daemonmetrics.IMetrics, logger log.Logger)
	// UpdateIptablesDropMetrics reports packet counters of the DROP rules.
	UpdateIptablesDropMetrics(metrics daemonmetrics.IMetrics, logger log.Logger)

	// GetCmdExecutor returns the executor used to run external commands.
	GetCmdExecutor() ICmdExecutor
}

// IPSet manages one IPv4 and one IPv6 ipset together with the INPUT DROP
// rules in iptables/ip6tables that match them. Every external command it
// issues goes through cmdExecutor.
type IPSet struct {
	// Names of the "family inet" and "family inet6" sets, in that order.
	ipv4SetName, ipv6SetName string
	// cmdExecutor runs all shell commands issued by IPSet's methods.
	cmdExecutor ICmdExecutor
}

// RemoveSetsAndPurgeRules removes the rules and sets that InitSetsAndRules
// creates. For the IPv4 set (via iptables) and the IPv6 set (via ip6tables)
// it:
//   - repeatedly deletes the INPUT rule "-m set --match-set <set> src -j DROP"
//     until the command fails with "Bad rule" or "Set <set> doesn't exist";
//   - destroys the ipset, treating "set does not exist" as success.
//
// Any other command failure is reported through logger.Fatalf.
func (s *IPSet) RemoveSetsAndPurgeRules(logger log.Logger) {
	targets := []struct {
		setName string
		ipType  string // command prefix: "ip" -> iptables, "ip6" -> ip6tables
	}{
		{setName: s.ipv4SetName, ipType: "ip"},
		{setName: s.ipv6SetName, ipType: "ip6"},
	}

	for _, target := range targets {
		setName, ipType := target.setName, target.ipType

		// Clear dangling iptables rules. Each -D call deletes a single matching
		// rule, so keep calling until it reports that nothing matches.
		for {
			purgeRule := fmt.Sprintf("%stables -D INPUT -m set --match-set %s src -j DROP", ipType, setName)
			res := s.GetCmdExecutor().ExecCmd(purgeRule)

			if res.Err == nil {
				logger.Infof("IPSet::RemoveSetsAndPurgeRules: Purged dangling iptables rule for ipset \"%s\"", setName)
				continue
			}
			if !strings.Contains(res.StdErr, "Bad rule") && !strings.Contains(res.StdErr, fmt.Sprintf("Set %s doesn't exist", setName)) {
				logger.Fatalf("IPSet::RemoveSetsAndPurgeRules: Looks like iptables cannot be used: cmd \"%s\", err \"%v\", stdout \"%s\", stderr \"%s\"", purgeRule, res.Err, res.StdOut, res.StdErr)
			}
			logger.Infof("IPSet::RemoveSetsAndPurgeRules: no dangling iptables rule for ipset \"%s\"", setName)
			break
		}

		// Clear the dangling ipset itself.
		destroySet := fmt.Sprintf("ipset destroy %s", setName)
		res := s.GetCmdExecutor().ExecCmd(destroySet)

		if res.Err == nil {
			logger.Infof("IPSet::RemoveSetsAndPurgeRules: destroyed dangling ipset \"%s\"", setName)
			continue
		}
		if !strings.Contains(res.StdErr, "The set with the given name does not exist") {
			logger.Fatalf("IPSet::RemoveSetsAndPurgeRules: Looks like ipset cannot be used: cmd \"%s\", err \"%v\", stdout \"%s\", stderr \"%s\"", destroySet, res.Err, res.StdOut, res.StdErr)
		}
		logger.Infof("IPSet::RemoveSetsAndPurgeRules: no dangling ipset \"%s\"", setName)
	}
}

// CheckInstalledModules runs `iptables --version` and then `ipset --version`.
// For each one it logs the reported version, or calls logger.Fatalf with the
// command's output if the command fails.
func (s *IPSet) CheckInstalledModules(logger log.Logger) {
	probe := func(versionCmd, missingFormat, foundFormat string) {
		result := s.GetCmdExecutor().ExecCmd(versionCmd)
		if result.Err != nil {
			logger.Fatalf(missingFormat, versionCmd, result.Err, result.StdOut, result.StdErr)
		}
		logger.Infof(foundFormat, result.StdOut)
	}

	probe("iptables --version",
		"IPSet::CheckInstalledModules: Looks like iptables is not installed: cmd \"%s\", err \"%v\", stdout \"%s\", stderr \"%s\"",
		"IPSet::CheckInstalledModules: found iptables: %s")
	probe("ipset --version",
		"IPSet::CheckInstalledModules: Looks like ipset is not installed: cmd \"%s\", err \"%v\", stdout \"%s\", stderr \"%s\"",
		"IPSet::CheckInstalledModules: found ipset: %s")
}

// InitSetsAndRules creates the IPv4 ("family inet") and IPv6 ("family inet6")
// hash:ip sets. It then appends an INPUT rule to iptables/ip6tables that drops
// packets whose source address is in the matching set. If a command fails,
// RemoveSetsAndPurgeRules is called to undo partial work, and the failure is
// then reported through logger.Fatalf.
func (s *IPSet) InitSetsAndRules(logger log.Logger) {
	specs := [...]struct {
		set    string // ipset name
		tables string // command prefix: "ip" -> iptables, "ip6" -> ip6tables
		family string // ipset address family
	}{
		{s.ipv4SetName, "ip", "inet"},
		{s.ipv6SetName, "ip6", "inet6"},
	}

	for _, spec := range specs {
		createSet := fmt.Sprintf("ipset create %s hash:ip family %s maxelem 1000000 hashsize 1048576 timeout 600", spec.set, spec.family)
		if res := s.GetCmdExecutor().ExecCmd(createSet); res.Err == nil {
			logger.Infof("IPSet::InitSetsAndRules: created ipset \"%s\"", spec.set)
		} else {
			s.RemoveSetsAndPurgeRules(logger)
			logger.Fatalf("IPSet::InitSetsAndRules: Looks like ipset cannot be used: cmd \"%s\", err \"%v\", stdout \"%s\", stderr \"%s\"", createSet, res.Err, res.StdOut, res.StdErr)
		}

		createRule := fmt.Sprintf("%stables -A INPUT -m set --match-set %s src -j DROP", spec.tables, spec.set)
		if res := s.GetCmdExecutor().ExecCmd(createRule); res.Err == nil {
			logger.Infof("IPSet::InitSetsAndRules: created iptables rule for ipset \"%s\"", spec.set)
		} else {
			s.RemoveSetsAndPurgeRules(logger)
			logger.Fatalf("IPSet::InitSetsAndRules: Looks like iptables cannot be used: cmd \"%s\", err \"%v\", stdout \"%s\", stderr \"%s\"", createRule, res.Err, res.StdOut, res.StdErr)
		}
	}
}

// IsSetupRequired reports whether the sets or their rules must be created.
// For the IPv4 set and then the IPv6 set it checks two things:
//  1. `ipset list` knows the set. If it does not, the function returns true.
//     Any other ipset error is fatal.
//  2. `<ip|ip6>tables -nvL INPUT | grep <set>` finds a line. A failing grep
//     with empty stderr means "no rule" and the function returns true. Any
//     stderr output (e.g. from iptables) is fatal.
//
// It returns false only when both sets and both rules are found.
//
// NOTE(review): grep matches the set name as a substring, so a rule for a
// different set whose name contains this one would also count as present.
func (s *IPSet) IsSetupRequired(logger log.Logger) bool {
	runner := s.GetCmdExecutor()

	for _, setName := range []string{s.ipv4SetName, s.ipv6SetName} {
		listCmd := fmt.Sprintf("ipset list %s -name -terse", setName)
		if listed := runner.ExecCmd(listCmd); listed.Err != nil {
			if !strings.Contains(listed.StdErr, "The set with the given name does not exist") {
				logger.Fatalf("IsSetupRequired: Looks like ipset cannot be used: cmd \"%s\", err \"%v\", stdout \"%s\", stderr \"%s\"", listCmd, listed.Err, listed.StdOut, listed.StdErr)
			}
			return true
		}

		tablePrefix := "ip"
		if setName == s.ipv6SetName {
			tablePrefix = "ip6"
		}
		ruleCmd := fmt.Sprintf("%stables -nvL INPUT  | grep  \"%s\"", tablePrefix, setName)
		if found := runner.ExecCmd(ruleCmd); found.Err != nil {
			if found.StdErr != "" {
				logger.Fatalf("IsSetupRequired: Looks like iptables cannot be used: cmd \"%s\", err \"%v\", stdout \"%s\", stderr \"%s\"", ruleCmd, found.Err, found.StdOut, found.StdErr)
			}
			return true
		}
	}
	return false
}

// Setup checks that iptables and ipset are installed. It then rebuilds the
// sets and rules from scratch if IsSetupRequired says any of them is missing.
// If everything is already in place (for example after a restart), the
// existing state is left untouched.
func (s *IPSet) Setup(logger log.Logger) {
	s.CheckInstalledModules(logger)

	if !s.IsSetupRequired(logger) {
		logger.Infof("IPSet::Setup: setup is not required, probably I was restarted")
		return
	}

	logger.Infof("IPSet::Setup: setup is required")
	s.RemoveSetsAndPurgeRules(logger)
	s.InitSetsAndRules(logger)
}

// UpdateIpsetMetrics reports the entry count of the IPv4 set and then the
// IPv6 set. If reading one size fails, the error is logged and the other
// family is still reported.
func (s *IPSet) UpdateIpsetMetrics(metrics daemonmetrics.IMetrics, logger log.Logger) {
	if entries, err := s.GetIpsetSize(s.ipv4SetName); err == nil {
		metrics.ReportIPv4SetSize(entries)
	} else {
		logger.Errorf("%v", err)
	}

	if entries, err := s.GetIpsetSize(s.ipv6SetName); err == nil {
		metrics.ReportIPv6SetSize(entries)
	} else {
		logger.Errorf("%v", err)
	}
}

// GetCmdExecutor returns the ICmdExecutor this IPSet was constructed with.
// Every shell command issued by IPSet goes through it.
func (s *IPSet) GetCmdExecutor() ICmdExecutor {
	return s.cmdExecutor
}

// GetIpsetSize returns the "Number of entries" value printed by
// `ipset list <setName> -terse`. It returns an error if the command fails or
// its output is not exactly one "Number of entries: <int>" line.
func (s *IPSet) GetIpsetSize(setName string) (int64, error) {
	cmd := fmt.Sprintf("ipset list %s -terse | grep -E \"Number of entries\"", setName)
	result := s.GetCmdExecutor().ExecCmd(cmd)
	if result.Err != nil {
		return 0, fmt.Errorf("UpdateIpsetMetrics: running cmd \"%s\" failed with err \"%v\", stdOut \"%s\", stdErr \"%s\"", cmd, result.Err, result.StdOut, result.StdErr)
	}

	if fields := strings.Split(result.StdOut, ": "); len(fields) == 2 && fields[0] == "Number of entries" {
		if entries, err := strconv.ParseInt(fields[1], 10, 64); err == nil {
			return entries, nil
		}
	}
	// A malformed line and a non-numeric count are reported the same way.
	return 0, fmt.Errorf("UpdateIpsetMetrics: got bad ipset output, cmd \"%s\", output: \"%s\"", cmd, result.StdOut)
}

// TODO change protobuf format to send only binary data

// ProcessCBBResponse reads a CBB response from in and converts it into ipset
// commands written to out.
//
// Wire format, as read below: every message is a single length byte N
// followed by N bytes of protobuf. The first message is a
// TCbbIpResponceHeader. Its NumIps field gives the number of
// TCbbIpResponceItem messages that follow. For each item whose ExpiredTs is
// still in the future, the line "add <set> <ip> timeout <seconds> -exist" is
// written to out. The IPv4 or IPv6 set is chosen by the address family.
// "quit" is always written to out when the function returns.
//
// The function returns header.LastCreatedTs after all items are processed. On
// any read, parse, validation or write error it returns 0. Commands already
// written before the error are not retracted.
func (s *IPSet) ProcessCBBResponse(in io.ReadCloser, out io.WriteCloser, metrics daemonmetrics.IMetrics, logger log.Logger) uint64 {
	defer func(start time.Time, name string) {
		elapsed := time.Since(start).Milliseconds()
		logger.Infof("Execution of %s took %d ms", name, elapsed)
		metrics.ReportExecTimeMs(fmt.Sprintf("exec%sTimeMs", name), float64(elapsed))
	}(time.Now(), "ProcessCBBResponse")
	defer func() {
		// Ends the command stream written to out; presumably consumed by an
		// interactive/restore ipset process started by the caller. Confirm.
		if _, err := fmt.Fprintf(out, "quit"); err != nil {
			logger.Errorf("ProcessCBBResponse: unable to write in ipset stdin pipe, error \"%v\"", err)
		}
	}()

	// A one-byte length prefix bounds every frame at 255 bytes.
	var buffer [256]byte

	frame, err := readCbbFrame(in, buffer[:])
	if err != nil {
		metrics.ReportCbbHTTPError()
		logger.Warnf("IPSet::ProcessCBBResponse: error while reading http response: \"%v\"", err)
		return 0
	}

	header := types.TCbbIpResponceHeader{}
	if err := proto.Unmarshal(frame, &header); err != nil {
		metrics.ReportCbbBadResponse()
		logger.Errorf("IPSet::ProcessCBBResponse: bad cbb response: \"%v\", error while parsing header proto: \"%v\"", frame, err)
		return 0
	}

	for i := uint32(0); i < header.NumIps; i++ {
		frame, err = readCbbFrame(in, buffer[:])
		if err != nil {
			metrics.ReportCbbHTTPError()
			logger.Warnf("IPSet::ProcessCBBResponse: error while reading http response: \"%v\"", err)
			return 0
		}

		item := types.TCbbIpResponceItem{}
		if err := proto.Unmarshal(frame, &item); err != nil {
			metrics.ReportCbbBadResponse()
			logger.Errorf("IPSet::ProcessCBBResponse: bad cbb response: \"%v\", error while parsing item proto: \"%v\"", frame, err)
			return 0
		}

		parsedIP := net.ParseIP(item.Ip)
		if parsedIP == nil {
			metrics.ReportCbbBadResponse()
			logger.Errorf("IPSet::ProcessCBBResponse: trying to add invalid ip: \"%s\"", item.Ip)
			return 0
		}
		if item.ExpiredTs <= 0 {
			metrics.ReportCbbBadResponse()
			logger.Errorf("IPSet::ProcessCBBResponse: trying to add invalid deadline \"%d\"", item.ExpiredTs)
			return 0
		}

		timeout := int64(item.ExpiredTs) - time.Now().Unix()
		if timeout <= 0 {
			logger.Debugf("IPSet::ProcessCBBResponse: skipping addition of \"%s\" because deadline \"%d\" has expired", item.Ip, item.ExpiredTs)
			continue
		}
		// TODO sanitize timeout
		// Add up to +50% random jitter. rand.Int63n panics for n <= 0, which
		// timeout/2 is when only one second remains, so skip jitter then.
		if maxJitter := timeout / 2; maxJitter > 0 {
			timeout += rand.Int63n(maxJitter)
		}

		setName := s.ipv4SetName
		if parsedIP.To4() == nil {
			setName = s.ipv6SetName
		}
		logger.Debugf("Executing Ipset command: add %s %s timeout %d -exist", setName, item.Ip, timeout)
		if _, err := fmt.Fprintf(out, "add %s %s timeout %d -exist\n", setName, item.Ip, timeout); err != nil {
			logger.Errorf("ProcessCBBResponse: unable to write in ipset stdin pipe, error \"%v\"", err)
			return 0
		}
		metrics.ReportCbbIPAdded()
	}
	return header.LastCreatedTs
}

// readCbbFrame reads one length-prefixed frame from in: a single length byte
// N followed by exactly N payload bytes. The payload is stored in buf, which
// must hold at least 256 bytes, and returned as a sub-slice of buf. It is
// therefore only valid until buf is reused.
func readCbbFrame(in io.Reader, buf []byte) ([]byte, error) {
	if _, err := io.ReadFull(in, buf[:1]); err != nil {
		return nil, err
	}
	size := int(buf[0])
	if _, err := io.ReadFull(in, buf[:size]); err != nil {
		return nil, err
	}
	return buf[:size], nil
}

// GetIptablesDroppedPackages returns the packet counter of the INPUT rule(s)
// in <ipType>tables ("ip" or "ip6") whose listing mentions setName. If
// several rules match (e.g. duplicated DROP rules), their counters are summed.
//
// The listing uses -x so counters are printed as exact integers. Without it,
// iptables abbreviates large counters (e.g. "12K", "3M"), which ParseInt
// rejects.
//
// If the command itself reports an error, the result is (0, nil). bash uses
// the exit status of the last pipeline stage (awk), so a missing rule usually
// shows up as empty output instead. Empty output is reported as an error.
func (s *IPSet) GetIptablesDroppedPackages(setName string, ipType string) (int64, error) {
	cmd := fmt.Sprintf("%stables -nvxL INPUT  | grep  \"%s\" | awk '{print $1}'", ipType, setName)

	cmdResult := s.GetCmdExecutor().ExecCmd(cmd)
	if cmdResult.Err != nil {
		return 0, nil
	}

	var droppedPackets int64
	matched := false
	for _, line := range strings.Split(cmdResult.StdOut, "\n") {
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}
		count, err := strconv.ParseInt(line, 10, 64)
		if err != nil {
			return 0, fmt.Errorf("UpdateIpsetMetrics: bad output while executing cmd \"%s\": \"%s\": %w", cmd, cmdResult.StdOut, err)
		}
		droppedPackets += count
		matched = true
	}
	if !matched {
		return 0, fmt.Errorf("UpdateIpsetMetrics: bad output while executing cmd \"%s\": \"%s\"", cmd, cmdResult.StdOut)
	}
	return droppedPackets, nil
}

// UpdateIptablesDropMetrics reports the packet counters of the iptables rule
// for the IPv4 set, then the ip6tables rule for the IPv6 set. If one read
// fails, the error is logged and the other family is still reported.
func (s *IPSet) UpdateIptablesDropMetrics(metrics daemonmetrics.IMetrics, logger log.Logger) {
	if dropped, err := s.GetIptablesDroppedPackages(s.ipv4SetName, "ip"); err == nil {
		metrics.ReportIP4TablesDrops(dropped)
	} else {
		logger.Errorf("%v", err)
	}

	if dropped, err := s.GetIptablesDroppedPackages(s.ipv6SetName, "ip6"); err == nil {
		metrics.ReportIP6TablesDrops(dropped)
	} else {
		logger.Errorf("%v", err)
	}
}

// New returns an IPSet that manages the IPv4 set namev4 and the IPv6 set
// namev6, running every external command through cmdExecutor.
func New(namev4 string, namev6 string, cmdExecutor ICmdExecutor) *IPSet {
	return &IPSet{
		ipv4SetName: namev4,
		ipv6SetName: namev6,
		cmdExecutor: cmdExecutor,
	}
}

// CmdRunResult holds the outcome of an external command.
type CmdRunResult struct {
	// Err is the error returned by os/exec. It is non-nil for a non-zero exit
	// status (*exec.ExitError) or when the command could not be run at all.
	Err error
	// StdOut is the captured standard output.
	StdOut string
	// StdErr is the captured standard error.
	StdErr string
}

// ICmdExecutor abstracts how external commands are run.
type ICmdExecutor interface {
	// ExecCmd runs a shell command line to completion.
	ExecCmd(string) CmdRunResult
	// NewCmdWithPipe starts a program whose stdin is fed through a pipe.
	NewCmdWithPipe(string, ...string) (ICmdWithPipe, error)
}

// CmdExecutor is the os/exec-backed ICmdExecutor. Use *CmdExecutor: only the
// pointer type has NewCmdWithPipe in its method set.
type CmdExecutor struct{}

// Compile-time checks that the concrete types implement their interfaces.
var (
	_ ICmdExecutor = (*CmdExecutor)(nil)
	_ ICmdWithPipe = (*CmdWithPipe)(nil)
)

// ExecCmd runs line with `bash -c` and blocks until it exits. Pipes and
// redirections in line are therefore interpreted by the shell, so line must
// not contain untrusted text. StdOut and StdErr are returned with leading and
// trailing '\n' characters trimmed.
func (e CmdExecutor) ExecCmd(line string) CmdRunResult {
	cmd := exec.Command("bash", "-c", line)
	var stdout, stderr bytes.Buffer
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr
	err := cmd.Run()
	return CmdRunResult{
		Err:    err,
		StdOut: strings.Trim(stdout.String(), "\n"),
		StdErr: strings.Trim(stderr.String(), "\n"),
	}
}

// ICmdWithPipe is a running command that the caller feeds through stdin.
type ICmdWithPipe interface {
	// Wait closes stdin, waits for the command to exit and returns its output.
	Wait() CmdRunResult
	// GetWritePipe returns the writer connected to the command's stdin.
	GetWritePipe() io.WriteCloser
}

// CmdWithPipe is a started *exec.Cmd. Its stdin is an OS pipe written by the
// caller via GetWritePipe, and its stdout/stderr are buffered in memory.
type CmdWithPipe struct {
	cmdErr    bytes.Buffer
	cmdOut    bytes.Buffer
	cmd       *exec.Cmd
	pipeWrite io.WriteCloser // cmd's stdin, from (*exec.Cmd).StdinPipe
}

// GetWritePipe returns the writer connected to the command's stdin. Because
// it is an OS pipe, writes fail with an error (EPIPE) instead of blocking once
// the child has exited and closed its end.
func (c *CmdWithPipe) GetWritePipe() io.WriteCloser {
	return c.pipeWrite
}

// Wait closes stdin so the child sees EOF, waits for the process to exit and
// returns its untrimmed stdout/stderr. It relies on (*exec.Cmd).Wait, which
// can only be called once.
func (c *CmdWithPipe) Wait() CmdRunResult {
	// The pipe from StdinPipe tolerates repeated Close calls, and exec closes
	// it again after Wait. The error is ignored because a caller may already
	// have closed the pipe.
	_ = c.pipeWrite.Close()
	err := c.cmd.Wait()
	return CmdRunResult{
		Err:    err,
		StdOut: c.cmdOut.String(),
		StdErr: c.cmdErr.String(),
	}
}

// NewCmdWithPipe starts name with arg and returns a handle whose write pipe
// feeds the process's stdin.
//
// Stdin is an OS pipe from (*exec.Cmd).StdinPipe rather than an io.Pipe
// relay. With a relay, once the child exits early nothing drains the
// in-memory pipe any more, and every later write blocks forever. With an OS
// pipe, those writes return an error instead.
func (e *CmdExecutor) NewCmdWithPipe(name string, arg ...string) (ICmdWithPipe, error) {
	result := &CmdWithPipe{cmd: exec.Command(name, arg...)}
	result.cmd.Stdout = &result.cmdOut
	result.cmd.Stderr = &result.cmdErr

	stdin, err := result.cmd.StdinPipe()
	if err != nil {
		return nil, err
	}
	result.pipeWrite = stdin

	if err := result.cmd.Start(); err != nil {
		// Best-effort cleanup; the Start error is the one worth reporting.
		_ = stdin.Close()
		return nil, err
	}
	return result, nil
}
