package bus

import (
	"context"
	"fmt"
	"log"
	"net"
	"os"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/backoff"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/keepalive"
	"google.golang.org/grpc/peer"

	"a.yandex-team.ru/solomon/tools/discovery/internal/bus/proto"
	"a.yandex-team.ru/solomon/tools/discovery/proto"
)

// ==========================================================================================

// communicatorServer is the server-side implementation of the Communicator
// gRPC service. It embeds the generated Unimplemented stub and delegates the
// actual work to the Bus it belongs to.
type communicatorServer struct {
	pbBus.UnimplementedCommunicatorServer
	// bus resolves the caller's name and runs the notification callbacks.
	bus *Bus
}

// Communicate handles an incoming notification request from a peer.
//
// The caller is identified by the remote address taken from the gRPC peer
// information: the ":port" suffix is stripped and the remainder is looked up
// via Bus.GetClientName. Requests without peer information or from an address
// that is not a known client are rejected with an error. For known clients the
// references carried by the request are passed to Bus.OnNotify and its result
// is returned in the reply.
func (c *communicatorServer) Communicate(ctx context.Context, in *pbBus.Request) (*pbBus.Reply, error) {
	p, ok := peer.FromContext(ctx)
	if !ok || p.Addr == nil {
		c.bus.log(0, nil, "incoming request without client address, %#v", ctx)
		return nil, fmt.Errorf("bad client")
	}

	// Drop the port but keep the rest untouched (IPv6 addresses stay in their
	// bracketed form). An address without a colon is kept whole so that it is
	// rejected as an unknown client instead of panicking on a negative index.
	ipStr := p.Addr.String()
	if i := strings.LastIndexByte(ipStr, ':'); i >= 0 {
		ipStr = ipStr[:i]
	}

	name, known := c.bus.GetClientName(ipStr)
	if !known {
		c.bus.log(0, nil, "incoming request from unknown client %s", ipStr)
		return nil, fmt.Errorf("unknown client")
	}

	return &pbBus.Reply{Ok: c.bus.OnNotify(in.GetRefs(), name)}, nil
}

// ==========================================================================================

// communicatorClient wraps the generated Communicator gRPC client and applies
// a fixed timeout to every call it makes.
type communicatorClient struct {
	// client is the generated gRPC stub used to issue requests.
	client  pbBus.CommunicatorClient
	// timeout bounds each individual request (see DoNotify).
	timeout time.Duration
}

// NewCommunicatorClient builds a communicatorClient that talks over conn and
// limits every request to the given timeout.
func NewCommunicatorClient(conn *grpc.ClientConn, timeout time.Duration) *communicatorClient {
	cc := new(communicatorClient)
	cc.client = pbBus.NewCommunicatorClient(conn)
	cc.timeout = timeout
	return cc
}

// SendMessage dispatches v according to its dynamic type. A slice of data
// references is sent as a notification and the boolean outcome of DoNotify is
// returned; any other type yields a nil result and an error.
func (c *communicatorClient) SendMessage(v interface{}) (interface{}, error) {
	if refs, isRefs := v.([]*pbData.DataRef); isRefs {
		return c.DoNotify(refs)
	}
	return nil, fmt.Errorf("unknown message type, %#v", v)
}

// DoNotify sends refs to the remote side in a single Communicate call bounded
// by the client's timeout. It returns the Ok flag of the reply together with
// the call error.
func (c *communicatorClient) DoNotify(refs []*pbData.DataRef) (bool, error) {
	callCtx, release := context.WithTimeout(context.Background(), c.timeout)
	defer release()

	request := &pbBus.Request{Refs: refs}
	reply, err := c.client.Communicate(callCtx, request)
	return reply.GetOk(), err
}

// ==========================================================================================

// InterCom is a single message handed to a client goroutine through
// Client.comm, together with the slots for its outcome.
type InterCom struct {
	// req is the payload to send; it is passed to SendMessage as is.
	req  interface{}
	// resp and err receive the result of SendMessage. resp stays nil when the
	// message is skipped because the client is shutting down.
	resp interface{}
	err  error
	// wg is incremented by the sender before queuing and released by the
	// client goroutine once the message has been processed or skipped.
	wg   sync.WaitGroup
}

// Client describes one outgoing connection to a peer and the queue feeding
// the goroutine that serves it.
type Client struct {
	// name is the key under which the client is stored in Bus.clients; it is
	// also used as the host part of the dial target.
	name  string
	// ipStr is the address actually dialed (IPv6 addresses are bracketed).
	ipStr string
	// stop is set to 1 (atomically) when the client is being stopped, so that
	// messages still queued are skipped instead of sent.
	stop  int32
	// conn is the gRPC connection; it is closed by the client goroutine once
	// comm has been closed and drained.
	conn  *grpc.ClientConn
	// comm is the buffered queue of messages for the client goroutine.
	comm  chan *InterCom
}

// Bus connects this process with its peers: it runs a gRPC server that
// receives notifications and keeps one outgoing client per peer to send them.
type Bus struct {
	// LogPrefix is prepended to the format string of every log message.
	LogPrefix             string
	// VerboseLevel is the highest message level that is still logged.
	VerboseLevel          int
	// ClientTimeout is the per-request timeout given to outgoing clients.
	ClientTimeout         time.Duration
	// onNotifyMutex guards onNotifyCallbacks.
	onNotifyMutex         sync.RWMutex
	// onNotifyCallbacks are invoked, in registration order, for every
	// incoming notification with the references and the sender's name.
	onNotifyCallbacks     []func([]*pbData.DataRef, string)
	// getCluster returns the current name -> IP mapping from which the list
	// of clients is built.
	getCluster            func() (map[string]net.IP, error)
	// serverListener and server are the listening socket and the gRPC server
	// created by startServer.
	serverListener        net.Listener
	server                *grpc.Server
	// port is the port part of the listen address; it is also used when
	// dialing peers.
	port                  string
	// clientsUpdateInterval is the period of the cron task that refreshes the
	// clients list.
	clientsUpdateInterval time.Duration
	// clientsChanLength is the buffer size of every Client.comm channel.
	clientsChanLength     int
	// clients holds the running clients keyed by name.
	clients               map[string]*Client
	// clientsIPToName maps a peer address string to the peer's name.
	clientsIPToName       map[string]string
	// clientsMutex guards clients and clientsIPToName.
	clientsMutex          sync.RWMutex
	// clientsWg tracks the running client goroutines.
	clientsWg             sync.WaitGroup
	// cronWg tracks the cron goroutine.
	cronWg                sync.WaitGroup
	// stopChan is closed by Shutdown to stop the cron goroutine.
	stopChan              chan struct{}
}

// NewBus creates a Bus, starts one outgoing client per peer reported by
// getCluster, starts the gRPC server listening on addr and launches the
// periodic task that refreshes the clients list.
//
// addr must be in "host:port" form; its port is also used to dial peers.
// getCluster must not be nil. verboseLevel is the highest log level that is
// printed.
//
// On any failure an error is returned and every client that has already been
// started is stopped again, so that no goroutines or connections are left
// behind.
func NewBus(
	addr string,
	getCluster func() (map[string]net.IP, error),
	verboseLevel int) (*Bus, error) {

	_, port, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, err
	}
	// A nil callback would otherwise panic inside updateClientsList.
	if getCluster == nil {
		return nil, fmt.Errorf("getCluster callback is nil")
	}

	b := &Bus{
		LogPrefix:             "[bus] ",
		VerboseLevel:          verboseLevel,
		ClientTimeout:         10 * time.Second,
		onNotifyCallbacks:     []func([]*pbData.DataRef, string){},
		clientsUpdateInterval: 30 * time.Second,
		getCluster:            getCluster,
		port:                  port,
		clientsChanLength:     10,
		clients:               make(map[string]*Client),
		clientsIPToName:       make(map[string]string),
		stopChan:              make(chan struct{}),
	}

	if err := b.updateClientsList(); err != nil {
		// Some clients may have been started before the failure.
		b.stopClients()
		return nil, err
	}
	if err := b.startServer(addr); err != nil {
		// The clients are already running at this point; do not leak them.
		b.stopClients()
		return nil, err
	}
	b.cronWg.Add(1)
	go b.cron()

	b.log(0, nil, "started")
	return b, nil
}

// log prints a formatted message with the bus prefix when lvl does not exceed
// the configured verbosity. When ts is not nil, the time elapsed since *ts is
// appended to the message.
func (b *Bus) log(lvl int, ts *time.Time, format string, v ...interface{}) {
	if lvl > b.VerboseLevel {
		return
	}

	layout := b.LogPrefix + format
	if ts != nil {
		layout += ", " + time.Since(*ts).String()
	}
	log.Printf(layout, v...)
}

// ==========================================================================================

// cron is the background loop of the bus: on every tick of the update
// interval it refreshes the clients list (logging any error), and it exits as
// soon as stopChan is closed. It marks cronWg as done on return.
func (b *Bus) cron() {
	defer b.cronWg.Done()

	refresh := time.NewTicker(b.clientsUpdateInterval)
	defer refresh.Stop()

	for {
		select {
		case <-b.stopChan:
			b.log(1, nil, "exiting cron task")
			return
		case <-refresh.C:
			err := b.updateClientsList()
			if err == nil {
				continue
			}
			b.log(0, nil, "%v", err)
		}
	}
}

// ==========================================================================================

// GetClientName looks up the client name registered for the address string
// ipStr. The second result reports whether the address is known.
func (b *Bus) GetClientName(ipStr string) (string, bool) {
	b.clientsMutex.RLock()
	clientName, known := b.clientsIPToName[ipStr]
	b.clientsMutex.RUnlock()

	return clientName, known
}

// ==========================================================================================
// Recv

// AddOnNotifyCallback registers cb to be called for every incoming
// notification. Callbacks run in the order in which they were added.
func (b *Bus) AddOnNotifyCallback(cb func([]*pbData.DataRef, string)) {
	b.onNotifyMutex.Lock()
	b.onNotifyCallbacks = append(b.onNotifyCallbacks, cb)
	b.onNotifyMutex.Unlock()
}

// OnNotify runs every registered notification callback, in registration
// order, with the received references and the sender's name. The callbacks
// are executed while the callbacks read lock is held. It always returns true.
func (b *Bus) OnNotify(refs []*pbData.DataRef, name string) bool {
	started := time.Now()
	b.log(1, nil, "got notification from %s (ready to run onNotify callbacks): %v", name, refs)

	b.onNotifyMutex.RLock()
	defer b.onNotifyMutex.RUnlock()

	b.log(1, &started, "notification callbacks started")
	callbacks := b.onNotifyCallbacks
	for i := 0; i < len(callbacks); i++ {
		callbacks[i](refs, name)
	}
	b.log(1, &started, "notification callbacks finished")

	return true
}

// ==========================================================================================
// Send

// DoNotifyClient queues refs for delivery to the client called name.
//
// With async set, it returns (true, nil) as soon as the message is queued.
// Otherwise it waits until the client goroutine has processed the message and
// returns the remote Ok flag together with the delivery error. A message that
// was not sent (for example because the client is shutting down) is reported
// as false with the corresponding error. An unknown name yields an error.
func (b *Bus) DoNotifyClient(name string, refs []*pbData.DataRef, async bool) (bool, error) {
	com := &InterCom{req: refs}
	com.wg.Add(1)

	b.clientsMutex.RLock()
	c, ok := b.clients[name]
	if !ok {
		b.clientsMutex.RUnlock()
		return false, fmt.Errorf("unknown client")
	}
	b.log(1, nil, "notifying %s (async=%v)", name, async)
	// The send has to happen under the read lock: the channel is closed under
	// the write lock, so it cannot be closed while we are sending.
	c.comm <- com
	// Do not keep the lock while waiting for the answer; the client goroutine
	// drains its queue even after the channel has been closed.
	b.clientsMutex.RUnlock()

	if async {
		return true, nil
	}
	com.wg.Wait()

	// resp is nil when the message was skipped, so a plain type assertion
	// would panic; treat a missing answer as "not delivered".
	delivered, _ := com.resp.(bool)
	return delivered, com.err
}

// DoNotifyClients queues refs for delivery to every known client.
//
// With async set, it returns (true, nil) once all messages are queued.
// Otherwise it waits for every client to process its message and returns true
// only if all of them answered Ok. On failure the returned error lists every
// failed client as "name: error; ".
func (b *Bus) DoNotifyClients(refs []*pbData.DataRef, async bool) (bool, error) {
	reqTime := time.Now()

	// Queue the messages under the read lock so that no channel can be closed
	// while we are sending to it.
	b.clientsMutex.RLock()
	b.log(1, &reqTime, "notifying all clients (async=%v)", async)
	coms := make(map[string]*InterCom, len(b.clients))
	for name, c := range b.clients {
		com := &InterCom{req: refs}
		coms[name] = com
		com.wg.Add(1)
		c.comm <- com
	}
	b.clientsMutex.RUnlock()

	if async {
		return true, nil
	}

	var failures strings.Builder
	for name, com := range coms {
		com.wg.Wait()
		// resp is nil when the message was skipped (client shutting down), so
		// use the two-value assertion and treat that as a failure.
		if delivered, _ := com.resp.(bool); !delivered {
			fmt.Fprintf(&failures, "%s: %v; ", name, com.err)
		}
	}
	ok := failures.Len() == 0
	b.log(1, &reqTime, "notified all clients, ok=%v", ok)

	if !ok {
		// The collected text is data, not a format string.
		return false, fmt.Errorf("%s", failures.String())
	}
	return true, nil
}

// ==========================================================================================

// updateClientsList synchronizes the set of running clients with the mapping
// returned by getCluster.
//
// The entry whose name equals the local hostname is skipped. For every other
// entry the address string is computed (IPv4 as is, anything else in
// brackets) and the address -> name table is rebuilt. Clients that are no
// longer present, or whose address has changed, are stopped; new clients and
// clients with a changed address are then started. Failures to stop a client
// are only logged; the first failure to start a client is returned.
func (b *Bus) updateClientsList() error {
	hostname, err := os.Hostname()
	if err != nil {
		return err
	}
	clients, err := b.getCluster()
	if err != nil {
		return err
	}

	clientsIPToName := make(map[string]string, len(clients))
	clientsToStart := map[string]string{}

	// The write lock is required here: besides reading b.clients this section
	// replaces b.clientsIPToName, which is read concurrently under RLock.
	b.clientsMutex.Lock()
	clientsToStop := make(map[string]struct{}, len(b.clients))
	for name := range b.clients {
		clientsToStop[name] = struct{}{}
	}
	for name, ip := range clients {
		if name == hostname {
			continue
		}

		var ipStr string
		if ip4 := ip.To4(); ip4 != nil {
			ipStr = ip4.String()
		} else {
			ipStr = "[" + ip.String() + "]"
		}
		clientsIPToName[ipStr] = name

		if client, ok := b.clients[name]; !ok || client.ipStr != ipStr {
			// New client, or a known one whose address changed. In the latter
			// case the name also stays in clientsToStop, so the old client is
			// stopped before the new one is started.
			clientsToStart[name] = ipStr
		} else {
			delete(clientsToStop, name)
		}
	}
	b.clientsIPToName = clientsIPToName
	b.clientsMutex.Unlock()

	for name := range clientsToStop {
		if err = b.stopClient(name); err != nil {
			b.log(0, nil, "failed to stop client %s: %v", name, err)
		}
	}
	for name, ipStr := range clientsToStart {
		if err = b.startClient(name, ipStr); err != nil {
			return err
		}
	}
	return nil
}

// ==========================================================================================

// startClient creates the gRPC connection for the client called name, starts
// the goroutine serving its message queue and registers it in b.clients.
//
// The dial target is name plus the bus port, but the custom dialer replaces
// the host with ipStr and only keeps the port of the target. The connection
// uses insecure transport credentials, client keepalives and an explicit
// reconnect backoff.
func (b *Bus) startClient(name, ipStr string) error {
	// dialPeer ignores the host of addr and connects to ipStr instead.
	dialPeer := func(ctx context.Context, addr string) (net.Conn, error) {
		_, port, err := net.SplitHostPort(addr)
		if err != nil {
			return nil, &net.OpError{Op: "dial", Net: "tcp", Err: err}
		}
		var dialer net.Dialer
		return dialer.DialContext(ctx, "tcp", ipStr+":"+port)
	}

	connectParams := grpc.ConnectParams{
		Backoff: backoff.Config{
			BaseDelay:  time.Second,
			Multiplier: 1.6,
			Jitter:     0.2,
			MaxDelay:   120 * time.Second,
		},
		MinConnectTimeout: 5 * time.Second,
	}
	keepaliveParams := keepalive.ClientParameters{
		Time:                10 * time.Second,
		Timeout:             10 * time.Second,
		PermitWithoutStream: true,
	}

	conn, err := grpc.Dial(
		net.JoinHostPort(name, b.port),
		grpc.WithConnectParams(connectParams),
		grpc.WithContextDialer(dialPeer),
		grpc.WithKeepaliveParams(keepaliveParams),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	)
	if err != nil {
		return fmt.Errorf("failed to start grpc client %s, %v", name, err)
	}

	worker := &Client{
		name:  name,
		ipStr: ipStr,
		conn:  conn,
		comm:  make(chan *InterCom, b.clientsChanLength),
	}

	b.clientsMutex.Lock()
	defer b.clientsMutex.Unlock()

	b.clientsWg.Add(1)
	go b.client(worker)
	b.clients[name] = worker

	return nil
}

// client is the goroutine serving one Client: it takes messages from c.comm
// one by one, sends each of them to the peer and stores the outcome in the
// message before releasing its WaitGroup.
//
// Once the stop flag is set, remaining messages are not sent; they are
// completed with a false response and an error instead. When the channel is
// closed and drained, the gRPC connection is closed and clientsWg is released.
func (b *Bus) client(c *Client) {
	defer b.clientsWg.Done()

	b.log(0, nil, "grpc client %s started", c.name)
	cc := NewCommunicatorClient(c.conn, b.ClientTimeout)

	for msg := range c.comm {
		reqTime := time.Now()
		if atomic.LoadInt32(&c.stop) != 0 {
			// Waiters read resp as a bool, so always give them one.
			msg.resp = false
			msg.err = fmt.Errorf("not sent, shutting down")
			b.log(1, &reqTime, "not sending message to %s: shutting down", c.name)
		} else {
			msg.resp, msg.err = cc.SendMessage(msg.req)
			b.log(1, &reqTime, "notification response from %s: %v, err=%v", c.name, msg.resp, msg.err)
		}
		msg.wg.Done()
	}

	if err := c.conn.Close(); err != nil {
		b.log(0, nil, "failed to stop grpc client %s, %v", c.name, err)
	} else {
		b.log(0, nil, "grpc client %s is stopped", c.name)
	}
}

// stopClient removes the client called name from b.clients, closes its
// message queue and raises its stop flag so that messages still queued are
// skipped. It returns an error when no such client exists.
func (b *Bus) stopClient(name string) error {
	b.clientsMutex.Lock()
	defer b.clientsMutex.Unlock()

	victim, found := b.clients[name]
	if !found {
		return fmt.Errorf("no such grpc client")
	}

	close(victim.comm)
	delete(b.clients, name)
	// Tell the client goroutine not to send what is still in the queue.
	atomic.StoreInt32(&victim.stop, 1)
	return nil
}

// stopClients stops every running client: each message queue is closed and
// each stop flag raised, the clients map is replaced by an empty one, and the
// call then blocks until all client goroutines have finished. The clients
// lock is held for the whole duration.
func (b *Bus) stopClients() {
	b.clientsMutex.Lock()
	defer b.clientsMutex.Unlock()

	for name := range b.clients {
		worker := b.clients[name]
		close(worker.comm)
		// Tell the client goroutine not to send what is still in the queue.
		atomic.StoreInt32(&worker.stop, 1)
	}
	b.clients = make(map[string]*Client)

	b.clientsWg.Wait()
}

// ==========================================================================================

// startServer opens a TCP listener on addr, creates the gRPC server with the
// bus keepalive settings, registers the Communicator service on it and starts
// serving in a separate goroutine. Only a failure to listen is returned;
// the outcome of serving is logged.
func (b *Bus) startServer(addr string) error {
	listener, err := net.Listen("tcp", addr)
	b.serverListener = listener
	if err != nil {
		return err
	}

	enforcement := keepalive.EnforcementPolicy{
		MinTime:             5 * time.Second,
		PermitWithoutStream: true,
	}
	serverParams := keepalive.ServerParameters{
		Time:    30 * time.Second,
		Timeout: 10 * time.Second,
	}
	b.server = grpc.NewServer(
		grpc.ConnectionTimeout(15*time.Second),
		grpc.KeepaliveEnforcementPolicy(enforcement),
		grpc.KeepaliveParams(serverParams),
	)
	pbBus.RegisterCommunicatorServer(b.server, &communicatorServer{bus: b})

	go func() {
		serveErr := b.server.Serve(b.serverListener)
		if serveErr == nil {
			b.log(0, nil, "grpc server is stopped")
			return
		}
		b.log(0, nil, "grpc server is failed to serve, %v", serveErr)
	}()

	return nil
}

// stopServer stops the gRPC server and then closes the listener it was
// serving on. The error of closing the listener is deliberately ignored.
func (b *Bus) stopServer() {
	srv, listener := b.server, b.serverListener

	srv.Stop()
	_ = listener.Close()
}

// ==========================================================================================

// Shutdown stops the bus: it signals the cron task to exit, stops the gRPC
// server, waits for the cron task to finish and finally stops all clients,
// blocking until their goroutines are done.
//
// Shutdown must be called at most once.
func (b *Bus) Shutdown() {
	b.log(0, nil, "begin shutdown")
	defer b.log(0, nil, "stopped")

	close(b.stopChan)
	b.stopServer()

	// Wait for the cron task before stopping the clients: a clients list
	// update that is still in flight could otherwise start new clients after
	// stopClients has returned, and nothing would ever stop them.
	b.cronWg.Wait()
	b.stopClients()
}
