package lblight

import (
	"context"
	"errors"
	"fmt"
	"io"
	"os"
	"sync"
	"time"

	"github.com/cenkalti/backoff/v4"
	"github.com/gofrs/uuid"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"a.yandex-team.ru/kikimr/public/sdk/go/ydb"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/security/libs/go/lblight/internal/lbapi"
)

const (
	// DefaultEndpoint is the endpoint NewWriter starts with before any
	// WriterOption is applied.
	DefaultEndpoint = "logbroker.yandex.net"
	// DefaultPort is not referenced in this file; presumably the port used
	// when connecting to the endpoint — TODO confirm against the session code.
	DefaultPort     = 2135
	// Version is the client version string sent in the write-session init
	// request.
	Version         = "Light LB (v0.0.0)"
)

// Writer writes messages to a single topic through a write-session stream
// and lets the caller read the server feedback for those writes.
type Writer struct {
	// endpoint is handed to the underlying session; NewWriter initializes it
	// to DefaultEndpoint.
	endpoint       string
	// topic is sent in the init request of every new write session.
	topic          string
	// sourceID is sent as SourceId in the init request; NewWriter initializes
	// it with defaultSourceID().
	sourceID       []byte
	// partitionGroup is sent as PartitionGroup in the init request.
	partitionGroup uint32
	// retries is the maximum number of retries for the backoff loops used by
	// Dial, Write/WriteBatch and FetchFeedback.
	retries        uint64
	// credentials supplies the token attached to requests; nil means requests
	// are sent without credentials.
	credentials    ydb.Credentials

	// codec is attached to every written message; NewWriter initializes it to
	// CodecRaw.
	codec       Codec
	// log receives diagnostics; NewWriter initializes it to a no-op logger.
	log         log.Logger
	// session is created in NewWriter, dialed in dial and closed in Close.
	session     *session
	// sessionInfo is filled from the init response in dial; NextSeqNo is
	// advanced as auto-numbered messages are written.
	sessionInfo SessionInfo
	// stream is the current write-session stream; nil until dial opens one.
	stream      lbapi.PersQueueService_WriteSessionClient
	// streamMu guards stream and sessionInfo: dial and waitStream take the
	// write lock, the send and receive paths take the read lock.
	streamMu    sync.RWMutex
	// streamCond is bound to streamMu; dial broadcasts on it after a
	// successful init and waitStream waits on it.
	streamCond  *sync.Cond
	// ctx lives as long as the writer and is cancelled through closeFn.
	ctx         context.Context
	// closeFn cancels ctx; Close calls it.
	closeFn     context.CancelFunc
}

// SessionInfo describes the write session established by the last successful
// init exchange.
type SessionInfo struct {
	// NextSeqNo is the sequence number assigned to the next message written
	// without an explicit SeqNo. It starts at the init reply's MaxSeqNo + 1.
	NextSeqNo uint64
	// SessionID is the session identifier from the init reply.
	SessionID string
	// Topic is the topic name from the init reply.
	Topic     string
	// Partition is the partition number from the init reply.
	Partition uint32
}

// NewWriter builds a Writer for topic.
//
// The writer starts with DefaultEndpoint, a no-op logger, the default source
// ID, two retries and the raw codec; opts are then applied in order and may
// override any of these. The underlying session is created with the resulting
// logger and endpoint but is not dialed — call Dial before writing.
//
// An error from NewSession is returned as is, together with a nil writer.
func NewWriter(topic string, opts ...WriterOption) (*Writer, error) {
	ctx, cancelCtx := context.WithCancel(context.Background())
	w := Writer{
		topic:    topic,
		endpoint: DefaultEndpoint,
		log:      &nop.Logger{},
		sourceID: defaultSourceID(),
		retries:  2,
		ctx:      ctx,
		closeFn:  cancelCtx,
		codec:    CodecRaw,
	}

	w.streamCond = sync.NewCond(&w.streamMu)
	for _, opt := range opts {
		opt(&w)
	}

	sess, err := NewSession(SessionWithLogger(w.log), SessionWithEndpoint(w.endpoint))
	if err != nil {
		// The writer is never handed to the caller on this path, so Close can
		// not be called on it; release the context here instead of leaking it.
		cancelCtx()
		return nil, err
	}
	w.session = sess

	return &w, nil
}

// Dial establishes the write session, retrying with exponential backoff up to
// w.retries times or until ctx is done.
//
// A failure that unwraps to *Error and is not Retryable stops the retries
// immediately; every other failure is retried. Each retry is logged together
// with the upcoming sleep duration.
func (w *Writer) Dial(ctx context.Context) error {
	attempt := func() error {
		dialErr := w.dial(ctx)

		var lbErr *Error
		if errors.As(dialErr, &lbErr) && !lbErr.Retryable() {
			return backoff.Permanent(dialErr)
		}

		// Either success (nil), a non-logbroker error, or a retryable one.
		return dialErr
	}

	onRetry := func(err error, sleep time.Duration) {
		w.log.Error("can't initialize connection", log.Duration("sleep", sleep), log.Error(err))
	}

	policy := backoff.WithContext(backoff.WithMaxRetries(backoff.NewExponentialBackOff(), w.retries), ctx)
	return backoff.RetryNotify(attempt, policy, onRetry)
}

// SessionInfo returns a copy of the information about the current write
// session.
//
// NOTE(review): the field is read without taking streamMu, while dial and the
// write paths modify it — confirm callers do not need a synchronized view.
func (w *Writer) SessionInfo() SessionInfo {
	info := w.sessionInfo
	return info
}

// Write sends a single message over the current write stream.
//
// When msg.SeqNo is zero the next sequence number from the session info is
// used and then advanced. A failed send triggers a reconnect and a retry (see
// doWrite). ErrNoStream is reported when no stream has been opened yet.
//
// NOTE(review): NextSeqNo is advanced while only the read lock is held, so
// concurrent Write/WriteBatch calls are not serialized against each other —
// confirm callers write from a single goroutine.
func (w *Writer) Write(ctx context.Context, msg WriteMsg) error {
	return w.doWrite(ctx, func() error {
		w.streamMu.RLock()
		defer w.streamMu.RUnlock()

		stream := w.stream
		if stream == nil {
			return ErrNoStream
		}

		// A stream whose context is already done can't carry the message.
		if streamErr := stream.Context().Err(); streamErr != nil {
			return streamErr
		}

		creds, err := w.Credentials(ctx)
		if err != nil {
			return err
		}

		// Assign the sequence number ourselves unless the caller supplied one.
		seqNo := msg.SeqNo
		if seqNo == 0 {
			seqNo = w.sessionInfo.NextSeqNo
			w.sessionInfo.NextSeqNo++
		}

		payload := &lbapi.WriteRequest_Data{
			SeqNo:        seqNo,
			CreateTimeMs: msg.PqTimestamp(),
			Data:         msg.Data,
			Codec:        w.codec,
		}

		return stream.SendMsg(&lbapi.WriteRequest{
			Credentials: creds,
			Request:     &lbapi.WriteRequest_Data_{Data: payload},
		})
	})
}

// WriteBatch sends all msgs in a single batch request over the current write
// stream.
//
// Every message whose SeqNo is zero gets the next sequence number from the
// session info, which is advanced for each such message. A failed send
// triggers a reconnect and a retry (see doWrite). ErrNoStream is reported
// when no stream has been opened yet.
//
// NOTE(review): NextSeqNo is advanced while only the read lock is held, so
// concurrent Write/WriteBatch calls are not serialized against each other —
// confirm callers write from a single goroutine.
func (w *Writer) WriteBatch(ctx context.Context, msgs ...WriteMsg) error {
	return w.doWrite(ctx, func() error {
		w.streamMu.RLock()
		defer w.streamMu.RUnlock()

		stream := w.stream
		if stream == nil {
			return ErrNoStream
		}

		// A stream whose context is already done can't carry the batch.
		if streamErr := stream.Context().Err(); streamErr != nil {
			return streamErr
		}

		creds, err := w.Credentials(ctx)
		if err != nil {
			return err
		}

		batch := make([]*lbapi.WriteRequest_Data, 0, len(msgs))
		for _, m := range msgs {
			// Assign the sequence number ourselves unless the caller supplied one.
			seqNo := m.SeqNo
			if seqNo == 0 {
				seqNo = w.sessionInfo.NextSeqNo
				w.sessionInfo.NextSeqNo++
			}

			batch = append(batch, &lbapi.WriteRequest_Data{
				SeqNo:        seqNo,
				CreateTimeMs: m.PqTimestamp(),
				Data:         m.Data,
				Codec:        w.codec,
			})
		}

		return stream.SendMsg(&lbapi.WriteRequest{
			Credentials: creds,
			Request: &lbapi.WriteRequest_DataBatch_{
				DataBatch: &lbapi.WriteRequest_DataBatch{Data: batch},
			},
		})
	})
}

// doWrite runs send with exponential backoff, up to w.retries retries or
// until ctx is done.
//
// When send fails the writer re-dials and asks for another attempt by
// returning ErrNextTry; a failed re-dial ends the loop immediately with the
// dial error. The re-dial uses the writer's own context rather than ctx, so
// the new stream is not tied to a single write call.
func (w *Writer) doWrite(ctx context.Context, send func() error) error {
	attempt := func() error {
		sendErr := send()
		if sendErr == nil {
			return nil
		}

		w.log.Error("can't send message, reconnecting", log.Error(sendErr))
		if dialErr := w.Dial(w.ctx); dialErr != nil {
			return backoff.Permanent(dialErr)
		}

		return ErrNextTry
	}

	onRetry := func(err error, sleep time.Duration) {
		if err != ErrNextTry {
			w.log.Error("can't write message", log.Duration("sleep", sleep), log.Error(err))
			return
		}

		w.log.Info("next try to write message")
	}

	policy := backoff.WithContext(backoff.WithMaxRetries(backoff.NewExponentialBackOff(), w.retries), ctx)
	return backoff.RetryNotify(attempt, policy, onRetry)
}

// FetchFeedback receives the next feedback message from the write stream.
//
// When LB is unavailable (the stream reports io.EOF or a gRPC Unavailable
// status) it does not reconnect on its own; it waits for the reconnection
// that happens when a message is written. This should minimize the number of
// sessions and the load on LB (at least while the per-host message rate is
// low).
//
// Other errors are retried with exponential backoff, up to w.retries times.
// Once the writer's context is done, ErrClosed is returned regardless of the
// outcome of the last attempt.
func (w *Writer) FetchFeedback() (FeedbackMsg, error) {
	var (
		feedback  FeedbackMsg
		sessionID string
	)

	receive := func() error {
		var recvErr error
		sessionID, feedback, recvErr = w.fetchFeedback()
		if recvErr == nil {
			return nil
		}

		// Both a finished stream and an unavailable backend mean the stream
		// is gone and a new one has to be awaited.
		streamGone := recvErr == io.EOF
		if !streamGone {
			st, ok := status.FromError(recvErr)
			streamGone = ok && st.Code() == codes.Unavailable
		}
		if !streamGone {
			return recvErr
		}

		w.log.Warn("can't receive feedback, wait reconnection", log.String("session_id", sessionID), log.Error(recvErr))
		return w.waitStream(sessionID)
	}

	onRetry := func(err error, sleep time.Duration) {
		if err != ErrNextTry {
			w.log.Error("failed to receive feedback", log.Duration("sleep", sleep), log.Error(err))
			return
		}

		// A plain "try again" is not worth an error-level record.
		w.log.Info("next try to receive feedback")
	}

	policy := backoff.WithContext(backoff.WithMaxRetries(backoff.NewExponentialBackOff(), w.retries), w.ctx)
	err := backoff.RetryNotify(receive, policy, onRetry)

	if w.ctx.Err() != nil {
		return nil, ErrClosed
	}
	return feedback, err
}

// Close closes the underlying session, if there is one, and cancels the
// writer's context afterwards. The context argument is ignored. The error of
// the session close, if any, is returned.
func (w *Writer) Close(_ context.Context) error {
	defer w.closeFn()

	if w.session == nil {
		return nil
	}
	return w.session.Close()
}

// Credentials builds the credentials attached to outgoing requests.
//
// It returns (nil, nil) when the writer has no credentials provider;
// otherwise it asks the provider for a token and wraps it as an OAuth token.
// An error from the provider is returned unchanged.
func (w *Writer) Credentials(ctx context.Context) (*lbapi.Credentials, error) {
	provider := w.credentials
	if provider == nil {
		return nil, nil
	}

	token, err := provider.Token(ctx)
	if err != nil {
		return nil, err
	}

	// TODO: there is one typeless token field in new V1 protocol.
	oauth := &lbapi.Credentials_OauthToken{OauthToken: []byte(token)}
	return &lbapi.Credentials{Credentials: oauth}, nil
}

// dial performs one attempt to establish a write session: it dials the
// underlying session, opens a write-session stream, sends the init request
// and reads the init response.
//
// The stream write lock is held for the whole exchange. On success the
// session info is replaced with the values from the init reply and every
// goroutine waiting on streamCond is woken up.
//
// It returns a *Error when the server answers with an error response, and a
// plain error for transport failures or malformed/unexpected responses.
func (w *Writer) dial(ctx context.Context) error {
	w.streamMu.Lock()
	defer w.streamMu.Unlock()

	creds, err := w.Credentials(ctx)
	if err != nil {
		return err
	}

	err = w.session.Dial(ctx)
	if err != nil {
		return fmt.Errorf("can't dial new session: %w", err)
	}

	w.stream, err = w.session.Client.WriteSession(ctx)
	if err != nil {
		// Closing is best effort here: the stream error is the one reported.
		if closeErr := w.session.Close(); closeErr != nil {
			w.log.Warn("can't close session", log.Error(closeErr))
		}

		return err
	}

	err = w.stream.SendMsg(&lbapi.WriteRequest{
		Request: &lbapi.WriteRequest_Init_{
			Init: &lbapi.WriteRequest_Init{
				Topic:          w.topic,
				SourceId:       w.sourceID,
				ProxyCookie:    w.session.Cookie,
				PartitionGroup: w.partitionGroup,
				Version:        Version,
			},
		},
		Credentials: creds,
	})

	if err != nil {
		return err
	}

	rsp := lbapi.WriteResponse{}
	if err = w.stream.RecvMsg(&rsp); err != nil {
		return err
	}

	switch msg := rsp.Response.(type) {
	case *lbapi.WriteResponse_Error:
		if msg.Error == nil {
			return errors.New("received error response w/o error")
		}

		return &Error{Code: msg.Error.Code, Description: msg.Error.Description}
	case *lbapi.WriteResponse_Init_:
		if msg.Init == nil {
			// The payload missing here is the init reply, not an error.
			return errors.New("received init response w/o init")
		}

		reply := msg.Init

		w.sessionInfo = SessionInfo{
			// The server reports the last sequence number it has seen.
			NextSeqNo: reply.MaxSeqNo + 1,
			SessionID: reply.SessionId,
			Topic:     reply.Topic,
			Partition: reply.Partition,
		}

		w.log.Info("stream initialized", log.Any("session_info", w.sessionInfo))
		// Wake up feedback readers parked in waitStream.
		w.streamCond.Broadcast()
		return nil
	case nil:
		return errors.New("received nil response")
	default:
		return fmt.Errorf("persqueue: received unexpected response type %T", msg)
	}
}

// fetchFeedback reads one response from the current stream and converts it
// into a FeedbackMsg.
//
// It returns the ID of the session the read was made on, so the caller can
// tell whether the writer has re-established the session in the meantime.
// The ID is empty when there is no stream yet (ErrNoStream).
//
// Ack and ack-batch responses become *FeedbackMsgAck / *FeedbackMsgAckBatch;
// a server error response becomes *FeedbackMsgError with a nil error result.
// A response without payload yields ErrNoFeedback, an error response without
// an error yields ErrNilErrFeedback, and receive errors are returned as is.
func (w *Writer) fetchFeedback() (string, FeedbackMsg, error) {
	var (
		msg       lbapi.WriteResponse
		sessionID string
	)

	w.streamMu.RLock()
	if w.stream == nil {
		// Release the read lock on this early exit as well, otherwise the
		// next dial blocks forever waiting for the write lock.
		w.streamMu.RUnlock()
		return sessionID, nil, ErrNoStream
	}

	sessionID = w.sessionInfo.SessionID
	err := w.stream.RecvMsg(&msg)
	w.streamMu.RUnlock()

	if err != nil {
		return sessionID, nil, err
	}

	if msg.Response == nil {
		return sessionID, nil, ErrNoFeedback
	}

	switch rsp := msg.Response.(type) {
	case *lbapi.WriteResponse_AckBatch_:
		out := &FeedbackMsgAckBatch{
			ACKs: make([]FeedbackMsgAck, len(rsp.AckBatch.Ack)),
		}
		for i, a := range rsp.AckBatch.Ack {
			out.ACKs[i] = FeedbackMsgAck{
				SeqNo:          a.SeqNo,
				Offset:         a.Offset,
				AlreadyWritten: a.AlreadyWritten,
			}
		}
		return sessionID, out, nil
	case *lbapi.WriteResponse_Ack_:
		return sessionID, &FeedbackMsgAck{
			SeqNo:          rsp.Ack.SeqNo,
			Offset:         rsp.Ack.Offset,
			AlreadyWritten: rsp.Ack.AlreadyWritten,
		}, nil
	case *lbapi.WriteResponse_Error:
		if rsp.Error == nil {
			return sessionID, nil, ErrNilErrFeedback
		}

		// A server-side error is delivered as feedback, not as a Go error.
		return sessionID, &FeedbackMsgError{
			Error: &Error{
				rsp.Error.Code,
				rsp.Error.Description,
			},
		}, nil
	default:
		return sessionID, nil, fmt.Errorf("received unexpected feedback message: %T", rsp)
	}
}

// waitStream blocks until a new stream has been initialized or the writer's
// context is done.
//
// curSessionID is the session the caller last read from. If the writer is
// already on a different session, there is nothing to wait for and ErrNextTry
// is returned right away. After a wake-up caused by a new stream ErrNextTry
// is returned as well, so the caller retries the read. When the writer's
// context is done, a permanent backoff error wrapping the context error is
// returned so the caller's retry loop stops.
func (w *Writer) waitStream(curSessionID string) error {
	w.streamMu.Lock()
	defer w.streamMu.Unlock()

	if curSessionID != w.sessionInfo.SessionID {
		w.log.Info(
			"skip stream cond waiting, due to session_id mismatch",
			log.String("waiter_session_id", curSessionID),
			log.String("current_session_id", w.sessionInfo.SessionID),
			log.Any("session_info", w.sessionInfo),
		)
		// probably writer already create the new session
		return ErrNextTry
	}

	// Nobody re-dials once the writer is closed, so waiting would never end.
	if err := w.ctx.Err(); err != nil {
		return backoff.Permanent(err)
	}

	// sync.Cond has no notion of cancellation: turn the context's cancellation
	// into a broadcast. The helper takes the lock before broadcasting, and the
	// lock is only released once Wait has registered this goroutine, so the
	// wake-up can't be missed.
	woken := make(chan struct{})
	defer close(woken)
	go func() {
		select {
		case <-w.ctx.Done():
			w.streamMu.Lock()
			w.streamCond.Broadcast()
			w.streamMu.Unlock()
		case <-woken:
		}
	}()

	w.streamCond.Wait()

	if err := w.ctx.Err(); err != nil {
		return backoff.Permanent(err)
	}
	return ErrNextTry
}

// defaultSourceID returns the host name as the source ID, falling back to a
// freshly generated UUID v4 string when the host name can't be determined or
// is empty.
func defaultSourceID() []byte {
	if host, err := os.Hostname(); err == nil && host != "" {
		return []byte(host)
	}

	fallback := uuid.Must(uuid.NewV4())
	return []byte(fallback.String())
}
