package cmd

import (
	"bytes"
	"context"
	"encoding/json"
	"io"
	"log"
	"net/url"
	"os"
	"os/signal"
	"runtime"
	"strconv"
	"sync"
	"syscall"
	"time"

	"github.com/spf13/cobra"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
	"golang.org/x/crypto/ssh/terminal"

	"a.yandex-team.ru/infra/hostctl/pkg/hostinfo"
	"a.yandex-team.ru/infra/rtc/packages/yandex-hbf-agent/hbf-metrics-pusher/internal/iptables"
	"a.yandex-team.ru/infra/rtc/packages/yandex-hbf-agent/hbf-metrics-pusher/internal/netwrap"
	"a.yandex-team.ru/infra/rtc/packages/yandex-hbf-agent/hbf-metrics-pusher/internal/porto"
	"a.yandex-team.ru/infra/rtc/packages/yandex-hbf-agent/hbf-metrics-pusher/internal/solomon"
	"a.yandex-team.ru/library/go/core/log/zap/logrotate"
	"a.yandex-team.ru/library/go/yandex/tvm"

	porto_api "a.yandex-team.ru/infra/porto/api_go"
)

// defLogPath is the default value of the --log-path flag of the serve command.
const defLogPath = "/var/log/hbf-metrics-pusher/hbf-metrics-pusher.log"

// General settings of the serve command, bound to flags in init.
var (
	debug          bool   // --debug: log at debug level
	logPath        string // --log-path: location of the log file
	updateInterval uint64 // --update-interval: seconds between iptables polls
)

// Solomon push endpoint settings, bound to flags in init.
var (
	solomonPushURL string
	solomonTimeout uint64 // seconds
	solomonCluster string
	solomonProject string
	solomonService string
)

// Settings used to look up the solomon OAuth token secret, bound to flags in init.
var (
	secretVaultID    string
	secretStorageKey string
	secretHostCert   string
)

// TVM settings, bound to flags in init.
var (
	tvmMyselfID  uint32
	tvmSolomonID uint32
	tvmCacheDir  string
)

// commonLabels are attached as the common labels of every solomon document
// built by formBody; init and serve fill it in before any push is made.
var commonLabels = map[string]string{}

// serveCmd defines the `serve` subcommand; its Run handler is serve.
var serveCmd = &cobra.Command{
	Run:   serve,
	Short: "Start `hbf-metrics-pusher` service",
	Use:   "serve",
}

// init registers serveCmd on rootCmd, declares the serve command's flags
// (most defaults come from the solomon package), and seeds commonLabels with
// the metric "name" label.
func init() {
	rootCmd.AddCommand(serveCmd)
	serveCmd.Flags().BoolVar(&debug, "debug", false, "set log level to debug")
	serveCmd.Flags().StringVar(&logPath, "log-path", defLogPath, "log path")

	serveCmd.Flags().Uint64Var(&updateInterval, "update-interval", 10, "metrics update interval in seconds")

	// serve converts this value with time.Second, so it is expressed in seconds.
	serveCmd.Flags().Uint64Var(&solomonTimeout, "solomon-timeout", solomon.ClientTimeout, "push to solomon timeout in seconds")
	serveCmd.Flags().StringVar(&solomonPushURL, "solomon-push-url", solomon.PushURL, "solomon push api url")
	serveCmd.Flags().StringVar(&solomonCluster, "solomon-cluster", solomon.Cluster, "solomon cluster")
	serveCmd.Flags().StringVar(&solomonProject, "solomon-project", solomon.Project, "solomon project")
	serveCmd.Flags().StringVar(&solomonService, "solomon-service", solomon.Service, "solomon service")

	serveCmd.Flags().StringVar(&secretVaultID, "secret-vault-id", solomon.VaultSecretID, "vault secret ID that stores a map with the solomon OAuth token")
	serveCmd.Flags().StringVar(&secretStorageKey, "secret-storage-key", solomon.SecretKey, "name of the key in the secret storage that holds the OAuth token value")
	serveCmd.Flags().StringVar(&secretHostCert, "secret-host-cert", solomon.HostCert, "host cert used to authenticate in nuvault to get the OAuth token for solomon")

	serveCmd.Flags().Uint32Var(&tvmMyselfID, "tvm-myself-id", solomon.TvmMyselfID, "tvm id of hbf-metrics-pusher")
	serveCmd.Flags().Uint32Var(&tvmSolomonID, "tvm-solomon-id", solomon.TvmSolomonID, "tvm id of solomon")
	serveCmd.Flags().StringVar(&tvmCacheDir, "tvm-cache-dir", solomon.TvmCacheDir, "cache dir for tvm")

	commonLabels["name"] = "hbf-drops-perservice"
}

// serve is the Run handler of serveCmd.
//
// It installs the global zap logger and optionally sets GOMAXPROCS. It creates
// the TVM cache directory and fills commonLabels with "host" and, when the
// server info file can be read, "geo". It then builds the iptables (IPv6),
// porto and solomon clients. Any setup failure is fatal.
//
// After setup it runs three goroutines sharing one context:
//
//   - a signal waiter that cancels the context on SIGTERM, SIGINT or SIGQUIT;
//   - a poller that lists IPv6 pre-drop rules every updateInterval seconds;
//   - a pusher that turns each listing into a payload via formBody and pushes
//     it to solomon.
//
// serve returns once all three goroutines have exited.
func serve(cmd *cobra.Command, args []string) {
	logger, err := makeLogger(debug)
	if err != nil {
		log.Fatalf("Could not create logger: %s", err)
	}
	zap.ReplaceGlobals(logger)
	defer logger.Sync()
	zap.L().Info("Starting...")
	if gomaxprocs >= 1 {
		// Arguments are evaluated left to right: the first call sets the new
		// value and returns the old one; the second only queries it.
		zap.L().Info("Set GOMAXPROCS", zap.Int("old value", runtime.GOMAXPROCS(gomaxprocs)), zap.Int("new value", runtime.GOMAXPROCS(0)))
	}

	// A directory needs the execute (search) bit to be entered. With 0o600 it
	// would be unusable for any process lacking CAP_DAC_OVERRIDE. 0o700 keeps
	// it owner-only.
	if err := os.MkdirAll(tvmCacheDir, 0o700); err != nil {
		zap.L().Fatal("could not create TVM cache dir", zap.Error(err))
	}

	hostname, err := os.Hostname()
	if err != nil {
		zap.L().Warn("failed to get hostname", zap.Error(err))
	}
	commonLabels["host"] = hostname

	if location, ok := readServerLocation(); ok {
		commonLabels["geo"] = location
	}

	var wg sync.WaitGroup

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	ipt6, err := iptables.New(ctx, iptables.IPv6, vanila)
	if err != nil {
		cancel()
		zap.L().Fatal("failed to init iptables", zap.Error(err))
	}

	portoClient, err := porto.CreatePortoClient()
	if err != nil {
		cancel()
		zap.L().Fatal("failed to create porto client", zap.Error(err))
	}
	defer portoClient.Close()

	solomonPusher, err := solomon.New(ctx,
		&solomon.Config{
			Secrets: solomon.SecretsCfg{
				VaultID:    secretVaultID,
				StorageKey: secretStorageKey,
				HostCert:   secretHostCert,
			},
			Solomon: solomon.SolomonCfg{
				URL:     solomonPushURL,
				Cluster: solomonCluster,
				Project: solomonProject,
				Service: solomonService,
				Timeout: time.Duration(solomonTimeout) * time.Second,
			},
			Tvm: solomon.TvmCfg{
				MyselfID:  tvm.ClientID(tvmMyselfID),
				SolomonID: tvm.ClientID(tvmSolomonID),
				CacheDir:  tvmCacheDir,
			},
		},
	)
	if err != nil {
		cancel()
		zap.L().Fatal("failed to create solomon pusher", zap.Error(err))
	}

	// signal.Notify never blocks when delivering a signal, so the channel must
	// be buffered. Otherwise a signal that arrives while nobody is receiving
	// is dropped.
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT)

	// Wait for a signal, then cancel ctx so the other goroutines stop.
	wg.Add(1)
	go func() {
		defer wg.Done()
		defer signal.Stop(sig)
		s := <-sig
		zap.L().Info("Got signal", zap.String("name", s.String()))
		cancel()
	}()

	// Pusher: format every iptables listing and push it to solomon.
	iptAnsw := make(chan []iptables.Answer, 1)
	wg.Add(1)
	go func(ctx context.Context) {
		defer wg.Done()
		for {
			select {
			case out := <-iptAnsw:
				body, err := formBody(out, portoClient)
				if err != nil {
					zap.L().Error("failed to form body", zap.Error(err))
				} else {
					code, err := solomonPusher.Push(ctx, body)
					if err != nil {
						zap.L().Error("failed to push to solomon", zap.Error(err))
					} else {
						zap.L().Info("Response", zap.Int("code", code))
					}
				}
			case <-ctx.Done():
				return
			}
		}
	}(ctx)

	// Poller: on every tick, list the IPv6 pre-drop rules into iptAnsw.
	ticker := time.NewTicker(time.Duration(updateInterval) * time.Second)
	wg.Add(1)
	go func(ctx context.Context, t *time.Ticker) {
		defer wg.Done()
		for {
			select {
			case <-t.C:
				ipt6.ListPreDropRules(iptAnsw)
			case <-ctx.Done():
				t.Stop()
				return
			}
		}
	}(ctx, ticker)

	wg.Wait()
}

// readServerLocation parses hostinfo.ServerInfoFile with
// hostinfo.FromStatReader and returns its Location field. Failures are logged
// as warnings and reported as ok == false. The file is closed before
// returning in every case.
func readServerLocation() (location string, ok bool) {
	fd, err := os.Open(hostinfo.ServerInfoFile)
	if err != nil {
		zap.L().Warn("failed to open file", zap.Error(err))
		return "", false
	}
	defer fd.Close()

	serverInfo, err := hostinfo.FromStatReader(fd)
	if err != nil {
		zap.L().Warn("failed to get serverinfo", zap.Error(err))
		return "", false
	}
	return serverInfo.Location, true
}

// formBody converts iptables answers into a JSON solomon payload.
//
// Each rule needs a parsable, non-zero "pkts" counter and a parsable address
// (chosen per direction by iptables.SelectAddrField). Such a rule becomes a
// counter metric with a "direction" label and a "service_name" label. The
// service name is the porto ServiceID for that address (with an extra
// "engine" label) when porto.IPLabelMap knows the address; otherwise it is
// the rule's "comment" field. Two igauge metrics with service_name="total"
// carry the per-direction sums.
//
// The document and all of its metrics share a single timestamp, and the
// package-level commonLabels become the document's common labels. An error is
// returned if the porto label map cannot be built or the payload cannot be
// marshalled.
func formBody(answs []iptables.Answer, portoClient porto_api.PortoAPI) (io.Reader, error) {
	ipLabelMap, err := porto.IPLabelMap(portoClient)
	if err != nil {
		return nil, err
	}
	metrics := []solomon.Metric{}

	// One timestamp for the whole payload, so the document and its metrics agree.
	now := time.Now().Unix()

	totalIn := solomon.NewIgauge()
	totalIn.TS = now
	totalIn.Labels["service_name"] = "total"
	totalIn.Labels["direction"] = iptables.DirectionIn.String()

	totalOut := solomon.NewIgauge()
	totalOut.TS = now
	totalOut.Labels["service_name"] = "total"
	totalOut.Labels["direction"] = iptables.DirectionOut.String()

	for _, answ := range answs {
		for _, rule := range answ.Rules {
			pkts, err := strconv.ParseUint(string(rule["pkts"]), 10, 64)
			if err != nil {
				zap.L().Debug("failed to parse uint", zap.ByteString("pkts", rule["pkts"]), zap.Error(err))
				continue
			}
			// skip rules with zero drops
			if pkts == 0 {
				zap.L().Debug("pkts is 0", zap.ByteString("comment", rule["comment"]))
				continue
			}

			rawAddr := iptables.SelectAddrField(answ.Direction, rule)
			addr, err := netwrap.TryParseIP(rawAddr)
			if err != nil {
				// Log the input rather than the result; the failed parse may
				// have replaced it with an empty value.
				zap.L().Debug("could not parse to net.IP/net.IPNet", zap.String("addr", rawAddr), zap.Error(err))
				continue
			}

			metric := solomon.NewCounter()
			if label, ok := ipLabelMap[addr]; ok {
				metric.Labels["service_name"] = label.ServiceID
				metric.Labels["engine"] = label.Engine
			} else {
				metric.Labels["service_name"] = string(rule["comment"])
			}
			metric.Labels["direction"] = answ.Direction.String()
			metric.Value = pkts
			metric.TS = now

			switch answ.Direction {
			case iptables.DirectionIn:
				totalIn.Value += pkts
			case iptables.DirectionOut:
				totalOut.Value += pkts
			}
			zap.L().Debug("adding to metrics", zap.Any("metric", metric))
			metrics = append(metrics, metric)
		}
	}
	metrics = append(metrics, totalIn, totalOut)

	d := solomon.JSONFormat{
		TS:           now,
		CommonLabels: commonLabels,
		Metrics:      metrics,
	}
	zap.L().Debug("Formed message", zap.Any("body", d))
	b, err := json.Marshal(d)
	if err != nil {
		return nil, err
	}
	return bytes.NewBuffer(b), nil
}

// makeLogger builds the service logger.
//
// Entries are console-encoded with ISO8601 timestamps and capitalized levels.
// They are written to a logrotate sink for the package-level logPath, created
// with syscall.SIGHUP (presumably the reopen signal; confirm against the
// logrotate package). When stdout is a terminal they are also written to
// stdout. The level is Debug when debug is true, Info otherwise. logPath must
// be accepted by url.ParseRequestURI; any parse or sink error is returned.
func makeLogger(debug bool) (*zap.Logger, error) {
	sinkURL, err := url.ParseRequestURI(logPath)
	if err != nil {
		return nil, err
	}
	fileSink, err := logrotate.NewLogrotateSink(sinkURL, syscall.SIGHUP)
	if err != nil {
		return nil, err
	}

	cfg := zap.NewProductionEncoderConfig()
	cfg.EncodeTime = zapcore.ISO8601TimeEncoder
	cfg.EncodeLevel = zapcore.CapitalLevelEncoder
	enc := zapcore.NewConsoleEncoder(cfg)

	minLevel := zapcore.InfoLevel
	if debug {
		minLevel = zapcore.DebugLevel
	}
	level := zap.NewAtomicLevelAt(minLevel)

	cores := []zapcore.Core{zapcore.NewCore(enc, fileSink, level)}
	// When run interactively, mirror everything to the console too.
	if terminal.IsTerminal(int(os.Stdout.Fd())) {
		cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), level))
	}
	// NewTee hands back a lone core unchanged, so the non-terminal case is
	// exactly the single file core.
	return zap.New(zapcore.NewTee(cores...)), nil
}
