package lblight

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/golang/protobuf/proto"
	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Discovery"
	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Operations"
	"google.golang.org/grpc"
	"google.golang.org/grpc/keepalive"
	"google.golang.org/grpc/metadata"

	"a.yandex-team.ru/kikimr/public/sdk/go/ydb"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/security/libs/go/lblight/internal/lbapi"
)

const (
	magicCookie    = uint64(123456789)
	maxMessageSize = 50 * 1024 * 1024
	rootDatabase   = "/Root"
	pqService      = "pq"
	metaDatabase   = "x-ydb-database"
	metaAuthTicket = "x-ydb-auth-ticket"
)

// sessionOptions holds the connection settings of a session. NewSession
// fills in defaults for endpoint, port, database and grpcCreds; SessionOption
// functions may override any field.
type sessionOptions struct {
	// endpoint and port form the address returned by Endpoint, which
	// callYDBOperation dials for endpoint discovery.
	endpoint    string
	port        int
	// credentials, when non-nil, supplies the token sent under the
	// x-ydb-auth-ticket metadata key.
	credentials ydb.Credentials
	// proxy, when non-empty, is dialed directly and discovery is skipped.
	proxy       string
	// database is sent under the x-ydb-database metadata key and is the
	// database whose endpoints are listed during discovery.
	database    string
	// grpcCreds is the transport dial option used for every connection the
	// session opens (grpc.WithInsecure() by default).
	grpcCreds   grpc.DialOption
}

// Endpoint joins the configured endpoint host and port into a "host:port"
// string. callYDBOperation dials this address for discovery requests.
func (s *sessionOptions) Endpoint() string {
	// Sprint inserts no separator next to string operands, so this yields
	// exactly host + ":" + decimal port.
	return fmt.Sprint(s.endpoint, ":", s.port)
}

// session is a connection to an lb proxy that serves the PersQueue API. It
// keeps the settings it was built with, the proxy address chosen by Dial, and
// the gRPC connection plus generated client created on top of it.
type session struct {
	// settings is populated from NewSession defaults and SessionOption funcs.
	settings sessionOptions

	// proxyName is the address dialed by createClient: either the explicitly
	// configured proxy or the first endpoint returned by discovery.
	proxyName string
	// sessionID is not read or written in this part of the file
	// (NOTE(review): presumably managed elsewhere in the package — confirm).
	sessionID string

	log  log.Logger
	// conn is the live proxy connection; nil before Dial and after Close.
	conn *grpc.ClientConn

	// Cookie is set to magicCookie when the proxy is found via discovery; it
	// is left unchanged when an explicit proxy is configured.
	Cookie uint64
	// Client is the PersQueue client bound to conn, created by Dial.
	Client lbapi.PersQueueServiceClient
}

// NewSession creates a session preconfigured with DefaultEndpoint,
// DefaultPort, the /Root database, an insecure transport and a no-op logger,
// and then applies opts in order. The first option that returns an error
// aborts construction and that error is returned with a nil session.
// NewSession performs no network I/O; call Dial to connect.
func NewSession(opts ...SessionOption) (*session, error) {
	s := &session{}
	s.settings = sessionOptions{
		endpoint:  DefaultEndpoint,
		port:      DefaultPort,
		database:  rootDatabase,
		grpcCreds: grpc.WithInsecure(),
	}
	s.log = &nop.Logger{}

	for i := range opts {
		if err := opts[i](s); err != nil {
			return nil, err
		}
	}
	return s, nil
}

// Dial (re)connects the session.
//
// Any connection left from a previous Dial is closed first; a failure to
// close it is only logged. Dial then selects the proxy (the configured one,
// or the first endpoint returned by discovery) and opens a gRPC connection
// and PersQueue client to it.
//
// On error the session has no live connection. The previous one, if any, has
// already been closed and forgotten, so a subsequent Close is a no-op rather
// than a second close of a dead connection. s.Client is not cleared and still
// refers to the previous (closed) connection until a Dial succeeds.
func (s *session) Dial(ctx context.Context) error {
	if s.conn != nil {
		if err := s.conn.Close(); err != nil {
			s.log.Warn("can't close previous session", log.Error(err))
		}
		// Drop the closed connection now so that a failure below does not
		// leave an already-closed conn behind for Close to close again.
		s.conn = nil
	}

	if err := s.discoverProxy(ctx); err != nil {
		return fmt.Errorf("can't discover lb proxy: %w", err)
	}
	s.log.Info("proxy selected", log.Any("proxyName", s.proxyName))

	if err := s.createClient(ctx); err != nil {
		return fmt.Errorf("can't create lb proxy connection: %w", err)
	}
	s.log.Info("session connection created", log.Any("proxyName", s.proxyName))

	return nil
}

// Close releases the connection opened by Dial, if there is one, and forgets
// it. Calling Close again, or before any Dial, does nothing and returns nil.
// The returned error is whatever the connection's own Close reported.
// s.Client is left untouched.
func (s *session) Close() error {
	if s.conn == nil {
		return nil
	}

	err := s.conn.Close()
	s.conn = nil
	return err
}

// discoverProxy sets s.proxyName to the address the session should dial.
//
// If an explicit proxy is configured, it is used as-is and Cookie is not
// touched. Otherwise the discovery service at s.settings.Endpoint() is asked
// for the pqService endpoints of the configured database. The first endpoint
// returned is chosen and Cookie is set to magicCookie.
//
// Any error from the discovery call is returned, even if some endpoints were
// decoded: the only way to have both an error and endpoints is a partially
// failed unmarshal, and such a result cannot be trusted.
func (s *session) discoverProxy(ctx context.Context) error {
	if s.settings.proxy != "" {
		s.proxyName = s.settings.proxy
		return nil
	}

	req := Ydb_Discovery.ListEndpointsRequest{
		Database: s.settings.database,
		Service:  []string{pqService},
	}
	var res Ydb_Discovery.ListEndpointsResult
	if err := s.callYDBOperation(ctx, "/Ydb.Discovery.V1.DiscoveryService/ListEndpoints", &req, &res); err != nil {
		return err
	}

	if len(res.Endpoints) == 0 {
		return errors.New("endpoint not found")
	}

	endpoint := res.Endpoints[0]
	s.proxyName = fmt.Sprintf("%s:%d", endpoint.Address, endpoint.Port)
	s.Cookie = magicCookie
	return nil
}

// createClient dials s.proxyName and, on success, stores the connection in
// s.conn and a PersQueue client bound to it in s.Client.
//
// The connection uses the configured transport option, client keepalive
// pings every 90s with a 1s ack timeout (also while no stream is open), and
// maxMessageSize limits for both directions. If fetching the credentials
// token fails, no dial is attempted.
//
// NOTE(review): the database/auth metadata is attached only to the dial
// context. grpc-go does not forward dial-context metadata to RPCs made later
// on the connection, so callers presumably add per-call metadata themselves —
// confirm.
func (s *session) createClient(ctx context.Context) error {
	dialCtx, err := s.appendOutgoingContext(ctx)
	if err != nil {
		return err
	}

	dialOpts := []grpc.DialOption{
		s.settings.grpcCreds,
		grpc.WithKeepaliveParams(keepalive.ClientParameters{
			Time:                90 * time.Second,
			Timeout:             time.Second,
			PermitWithoutStream: true,
		}),
		grpc.WithDefaultCallOptions(
			grpc.MaxCallRecvMsgSize(maxMessageSize),
			grpc.MaxCallSendMsgSize(maxMessageSize),
		),
	}

	conn, err := grpc.DialContext(dialCtx, s.proxyName, dialOpts...)
	if err != nil {
		return err
	}

	s.conn, s.Client = conn, lbapi.NewPersQueueServiceClient(conn)
	return nil
}

// appendOutgoingContext returns a child of ctx whose outgoing gRPC metadata
// carries the configured database under metaDatabase. When credentials are
// configured, it also carries their token under metaAuthTicket. If the token
// cannot be obtained, the error is returned together with a nil context.
func (s *session) appendOutgoingContext(ctx context.Context) (context.Context, error) {
	creds := s.settings.credentials
	if creds == nil {
		return metadata.AppendToOutgoingContext(ctx, metaDatabase, s.settings.database), nil
	}

	token, err := creds.Token(ctx)
	if err != nil {
		return nil, err
	}

	return metadata.AppendToOutgoingContext(ctx,
		metaDatabase, s.settings.database,
		metaAuthTicket, token,
	), nil
}

// callYDBOperation performs a unary YDB "operation"-style RPC named method
// against s.settings.Endpoint() and decodes the operation result into res.
//
// A fresh connection is dialed for this single call and closed afterwards.
// The request carries the metadata added by appendOutgoingContext. The reply
// is decoded into a GetOperationResponse envelope.
// NOTE(review): this presumably relies on the method's real response message
// holding its Operation in the same field as GetOperationResponse — confirm
// for every method passed here.
//
// An error is returned if the call fails, the reply has no operation, the
// operation is not ready, or its status is not SUCCESS. A successful operation
// without a result payload decodes to an empty res instead of panicking.
func (s *session) callYDBOperation(ctx context.Context, method string, req, res proto.Message) error {
	ctx, err := s.appendOutgoingContext(ctx)
	if err != nil {
		return err
	}

	conn, err := grpc.DialContext(ctx, s.settings.Endpoint(), s.settings.grpcCreds)
	if err != nil {
		return err
	}
	// The connection only serves this call; a close failure cannot affect the
	// result already received, so it is deliberately ignored.
	defer func() { _ = conn.Close() }()

	var resp Ydb_Operations.GetOperationResponse
	if err := conn.Invoke(ctx, method, req, &resp); err != nil {
		return err
	}

	// Use the generated nil-safe getters: a malformed reply must become an
	// error or an empty result, not a nil-pointer panic.
	op := resp.GetOperation()
	if op == nil {
		return errors.New("ydb operation response has no operation")
	}
	if !op.GetReady() {
		return errors.New("endpoint discovery operations returned not ready status; this should never happen")
	}
	if status := op.GetStatus(); status != Ydb.StatusIds_SUCCESS {
		return fmt.Errorf("endpoint discovery failed: %s", status.String())
	}

	return proto.Unmarshal(op.GetResult().GetValue(), res)
}
