package synchronizer

import (
	"context"
	"errors"
	"fmt"
	"time"

	"go.uber.org/zap"

	"a.yandex-team.ru/security/fisthop-collector/internal/parser"
	"a.yandex-team.ru/security/fisthop-collector/internal/stampdb"
	"a.yandex-team.ru/security/fisthop-collector/internal/zmq"
)

// Synchronizer applies records received from cfg.trustedServers (through a
// zmq gang) to a local stampdb. Create it with NewSynchronizer, run it with
// Start and stop it with Shutdown.
type Synchronizer struct {
	// cfg holds the output directory, logger, flush period and server list
	// used by Start.
	cfg *Config

	// ctx is cancelled through cancelFn by Shutdown to stop the Start loop.
	ctx      context.Context
	cancelFn context.CancelFunc

	// done is closed when Start returns.
	done chan struct{}
}

// NewSynchronizer returns a Synchronizer bound to cfg. The instance owns a
// cancellable context derived from context.Background that Shutdown cancels
// to stop a running Start. The error result is currently always nil.
func NewSynchronizer(cfg *Config) (*Synchronizer, error) {
	s := &Synchronizer{
		cfg:  cfg,
		done: make(chan struct{}),
	}
	s.ctx, s.cancelFn = context.WithCancel(context.Background())

	return s, nil
}

// Start opens the stampdb writer in cfg.outDir and restores its previous
// state. It then connects a zmq gang to cfg.trustedServers and applies every
// received record to the database until Shutdown is called or the record
// stream ends.
//
// Add and Delete records are dropped until the first Ruleset record arrives
// ("wait state resync"). A Ruleset record inserts every rule it carries and
// then calls db.Cleanup for that record's hostname.
//
// Start returns nil on a regular stop (Shutdown, or the gang reporting
// zmq.ErrClosed). It returns a wrapped error when the stampdb or the gang
// cannot be set up, or when reading from the gang fails. In every case the
// deferred gang and stampdb shutdowns run before Start returns, and s.done is
// closed last so Shutdown can observe completion. Start must be called at
// most once: a second call would close s.done twice and panic.
func (s *Synchronizer) Start() error {
	defer close(s.done)

	db, err := stampdb.NewWriter(s.cfg.outDir,
		stampdb.WithWriterLogger(s.cfg.log),
		stampdb.WithWriterFlushPeriod(s.cfg.flushPeriod),
		// https://bb.yandex-team.ru/projects/NOC/repos/fw-filter/browse/fw-injector.h#33
		stampdb.WithWriterMaxAssoc(time.Hour*24*2),
	)
	if err != nil {
		return fmt.Errorf("unable to create stampdb: %w", err)
	}

	if err := db.Restore(); err != nil {
		return fmt.Errorf("unable to restore stampdb: %w", err)
	}

	defer func() {
		if err := db.Shutdown(context.Background()); err != nil {
			s.cfg.log.Error("failed to shutdown stampdb", zap.Error(err))
		}
	}()

	gang, err := zmq.NewGang(s.cfg.trustedServers, zmq.WithLogger(s.cfg.log))
	if err != nil {
		// Return instead of log.Fatal: Fatal exits the process and would skip
		// the deferred db.Shutdown (and close(s.done)).
		return fmt.Errorf("unable to create zmq gang: %w", err)
	}

	defer func() {
		if err := gang.Shutdown(context.Background()); err != nil {
			s.cfg.log.Error("failed to shutdown firsthop gang", zap.Error(err))
		}
	}()

	// readerCtx is cancelled whenever Start returns. Deferred calls run LIFO,
	// so this happens before gang.Shutdown. A reader blocked on a send that
	// nobody will receive can therefore exit instead of leaking.
	readerCtx, stopReader := context.WithCancel(s.ctx)
	defer stopReader()

	haveState := false
	s.cfg.log.Info("wait state resync")

	// The reader closes records when it exits. Before closing it, the reader
	// stores any read error other than zmq.ErrClosed in readErr. readErr has a
	// buffer of 1 and at most one send, so that send never blocks.
	records := make(chan parser.Message)
	readErr := make(chan error, 1)
	go func() {
		defer close(records)
		for {
			r, err := gang.NextRecord()
			if err != nil {
				if !errors.Is(err, zmq.ErrClosed) {
					readErr <- err
				}
				return
			}

			select {
			case records <- r:
			case <-readerCtx.Done():
				return
			}
		}
	}()

	for {
		var msg parser.Message
		select {
		case <-s.ctx.Done():
			return nil
		case m, ok := <-records:
			if !ok {
				// The reader sends any error before close(records), so this
				// non-blocking receive is guaranteed to observe it.
				select {
				case err := <-readErr:
					return fmt.Errorf("gang read fail: %w", err)
				default:
					return nil
				}
			}
			msg = m
		}

		switch rr := msg.(type) {
		case *parser.AddRecord:
			if !haveState {
				// Ignore incremental updates until a full ruleset was applied.
				continue
			}

			switch rr.RuleKind {
			case parser.RuleKindIPv4, parser.RuleKindIPv6:
				ip := rr.IP.String()
				db.AddEntry(ip, stampdb.Entry{
					Epoch:    time.Now().UnixMilli(),
					AssocTS:  rr.Timestamp,
					IP:       ip,
					Username: rr.Username,
					Hostname: rr.Hostname,
					Via:      rr.Entrypoint,
				})
			case parser.RuleKindMAC:
				mac := rr.MAC.String()
				db.AddEntry(mac, stampdb.Entry{
					Epoch:    time.Now().UnixMilli(),
					AssocTS:  rr.Timestamp,
					MAC:      mac,
					Username: rr.Username,
					Hostname: rr.Hostname,
					Via:      rr.Entrypoint,
				})
			default:
				s.cfg.log.Error("unsupported entry", zap.Uint8("kind", uint8(rr.RuleKind)))
				continue
			}

		case *parser.DeleteRecord:
			if !haveState {
				// Ignore incremental updates until a full ruleset was applied.
				continue
			}

			switch rr.RuleKind {
			case parser.RuleKindIPv4, parser.RuleKindIPv6:
				db.DelEntry(rr.Timestamp, rr.Hostname, rr.IP.String())
			case parser.RuleKindMAC:
				db.DelEntry(rr.Timestamp, rr.Hostname, rr.MAC.String())
			default:
				s.cfg.log.Error("unsupported entry", zap.Uint8("kind", uint8(rr.RuleKind)))
				continue
			}

		case *parser.RulesetRecord:
			haveState = true
			now := time.Now()
			epoch := now.UnixMilli()
			ts := now.Unix()
			for _, rule := range rr.Rules {
				switch rule.Kind {
				case parser.RuleKindIPv4, parser.RuleKindIPv6:
					ip := rule.IP.String()
					db.AddEntry(ip, stampdb.Entry{
						Epoch:    epoch,
						AssocTS:  rule.Timestamp,
						IP:       ip,
						Username: rule.Username,
						Hostname: rr.Hostname,
						Via:      rule.Entrypoint,
					})
				case parser.RuleKindMAC:
					mac := rule.MAC.String()
					db.AddEntry(mac, stampdb.Entry{
						Epoch:    epoch,
						AssocTS:  rule.Timestamp,
						MAC:      mac,
						Username: rule.Username,
						Hostname: rr.Hostname,
						Via:      rule.Entrypoint,
					})
				default:
					s.cfg.log.Error("unsupported entry", zap.Uint8("kind", uint8(rule.Kind)))
					continue
				}
			}

			// Every rule of this snapshot was stamped with epoch above.
			// Cleanup receives that epoch plus the snapshot time in seconds for
			// this hostname. Presumably it drops entries not refreshed by this
			// snapshot — verify against stampdb.
			db.Cleanup(epoch, ts, rr.Hostname)
			s.cfg.log.Info("state synced")
		}
	}
}

// Shutdown asks a running Start to stop by cancelling the synchronizer's
// context. It then waits until Start has returned (s.done is closed) or ctx
// ends, whichever comes first. It returns nil in the first case and
// ctx.Err() in the second.
//
// s.done is closed only by Start, so if Start was never called, Shutdown
// blocks until ctx is done.
func (s *Synchronizer) Shutdown(ctx context.Context) error {
	s.cancelFn()

	select {
	case <-s.done:
	case <-ctx.Done():
		return ctx.Err()
	}

	return nil
}
