package app

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io/ioutil"
	"net"
	"strings"
	"time"

	grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/security/gideon/nuvault/internal/config"
	"a.yandex-team.ru/security/gideon/nuvault/internal/yaver"
	"a.yandex-team.ru/security/gideon/nuvault/pkg/nuvrpc"
)

// App wires together the service configuration, the yav client used to sync
// secrets, and the TLS gRPC server. App itself is registered as the
// implementation of both the NuVault and Common services (see Start).
type App struct {
	// cfg is the service configuration; the Yav and API sections are read.
	cfg *config.Config
	// log is the shared logger; NewApp also hands it to the yav client.
	log log.Logger

	// yavc is synced and started in Start and shut down in Shutdown.
	yavc *yaver.Yaver
	// grpc is the TLS gRPC server built in NewApp.
	grpc *grpc.Server
}

// NewApp constructs the application from cfg.
//
// It loads the API TLS material (client CA bundle, server certificate and
// key) and creates the yav client. It then builds a gRPC server whose unary
// interceptor chain is errHandler followed by mutualTLSInterceptor. The gRPC
// services themselves are registered later, in Start.
func NewApp(cfg *config.Config, l log.Logger) (*App, error) {
	// Load the local TLS material first so a bad CA bundle or key pair is
	// reported before any yav client is created.
	tlsConfig, err := loadServerTLSConfig(cfg.API.ClientCa, cfg.API.SvrCrt, cfg.API.SvrKey)
	if err != nil {
		return nil, err
	}

	yavc, err := yaver.NewYaver(cfg.Yav, l)
	if err != nil {
		return nil, fmt.Errorf("failed to create yav client: %w", err)
	}

	app := &App{
		cfg:  cfg,
		yavc: yavc,
		log:  l,
	}

	app.grpc = grpc.NewServer(
		grpc.Creds(credentials.NewTLS(tlsConfig)),
		grpcMiddleware.WithUnaryServerChain(
			app.errHandler,
			app.mutualTLSInterceptor,
		),
	)

	return app, nil
}

// loadServerTLSConfig builds the gRPC server's TLS configuration.
//
// caFile must contain at least one PEM-encoded certificate; those certificates
// form the pool used to verify client certificates. certFile and keyFile hold
// the server's own key pair.
//
// Clients are not required to present a certificate (VerifyClientCertIfGiven).
// Any certificate they do present must chain to the CA pool. Per-method
// authentication is left to mutualTLSInterceptor.
func loadServerTLSConfig(caFile, certFile, keyFile string) (*tls.Config, error) {
	caCert, err := ioutil.ReadFile(caFile)
	if err != nil {
		return nil, fmt.Errorf("failed to read client CAs: %w", err)
	}

	caCertPool := x509.NewCertPool()
	// AppendCertsFromPEM silently skips anything it cannot parse and reports
	// only whether at least one certificate was added. An empty pool would
	// make every presented client certificate fail verification, so treat it
	// as a configuration error instead of starting a broken server.
	if !caCertPool.AppendCertsFromPEM(caCert) {
		return nil, fmt.Errorf("failed to parse client CAs: no PEM certificates found in %q", caFile)
	}

	serverCert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return nil, fmt.Errorf("failed to read server certificate: %w", err)
	}

	// MinVersion makes the TLS 1.2 floor explicit. Newer Go releases already
	// default servers to it, but older toolchains do not.
	return &tls.Config{
		Certificates: []tls.Certificate{serverCert},
		ClientCAs:    caCertPool,
		ClientAuth:   tls.VerifyClientCertIfGiven,
		MinVersion:   tls.VersionTLS12,
	}, nil
}

// Start runs the service.
//
// The steps, in order:
//   - sync secrets via the yav client (Sync(true)) and bind a.cfg.API.Addr;
//   - launch the yav client's Start in its own goroutine;
//   - register the NuVault and Common services on a.grpc;
//   - hand the listener to a.grpc.Serve and return its result.
//
// Presumably meant to be called once, since it registers the services on
// a.grpc on every call.
func (a *App) Start() error {
	syncErr := a.yavc.Sync(true)
	if syncErr != nil {
		return fmt.Errorf("can't sync secrets: %w", syncErr)
	}

	lis, listenErr := net.Listen("tcp", a.cfg.API.Addr)
	if listenErr != nil {
		return fmt.Errorf("failed to listen: %w", listenErr)
	}
	// Best-effort close on exit; the Close error is deliberately discarded.
	defer func() { _ = lis.Close() }()

	// NOTE(review): this goroutine's lifetime is owned by yavc; presumably it
	// ends when Shutdown calls a.yavc.Shutdown() — confirm.
	go a.yavc.Start()

	srv := a.grpc
	nuvrpc.RegisterNuVaultServiceServer(srv, a)
	nuvrpc.RegisterCommonServiceServer(srv, a)

	addr := lis.Addr().String()
	a.log.Infof("starting gRPC at %s", addr)
	return srv.Serve(lis)
}

// Shutdown shuts down the yav client and then gracefully stops the gRPC
// server. It waits until both are done or ctx is done, whichever happens first.
//
// It returns nil when shutdown finished in time, or ctx.Err() otherwise. In
// the latter case the shutdown work keeps running in the background goroutine.
func (a *App) Shutdown(ctx context.Context) error {
	done := make(chan struct{})
	go func() {
		a.yavc.Shutdown()
		a.grpc.GracefulStop()
		// close never blocks. The original unbuffered send leaked this
		// goroutine forever whenever ctx expired first and nobody was left
		// to receive.
		close(done)
	}()

	select {
	case <-done:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// errHandler is a unary server interceptor (NewApp installs it first in the
// chain). It invokes the rest of the chain and logs any returned error,
// tagged with the RPC's full method name. The response and error are passed
// through unchanged.
func (a *App) errHandler(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	out, handlerErr := handler(ctx, req)
	if handlerErr == nil {
		return out, nil
	}

	a.log.Error("request error", log.String("method", info.FullMethod), log.Error(handlerErr))
	return out, handlerErr
}

// mutualTLSInterceptor authenticates callers and writes an access log line
// for every RPC except those of nuvrpc.CommonService, which pass straight
// through.
//
// For the other methods it works in three steps:
//   - call authenticate(ctx), defined elsewhere in the package, to obtain the
//     caller's CN — presumably from the verified client certificate;
//   - reject the call if that fails;
//   - otherwise run the handler and log the outcome with the method, CN and
//     elapsed time.
//
// NOTE(review): the auth failure is returned as a plain error, so clients
// see codes.Unknown; codes.Unauthenticated via grpc/status would be more
// precise.
func (a *App) mutualTLSInterceptor(
	ctx context.Context,
	req interface{},
	info *grpc.UnaryServerInfo,
	handler grpc.UnaryHandler,
) (interface{}, error) {
	method := info.FullMethod

	// CommonService RPCs are served without authentication or access logging.
	if strings.HasPrefix(method, "/nuvrpc.CommonService/") {
		return handler(ctx, req)
	}

	startedAt := time.Now()

	cn, authErr := authenticate(ctx)
	if authErr != nil {
		return nil, fmt.Errorf("auth fail: %w", authErr)
	}

	resp, handlerErr := handler(ctx, req)
	elapsed := time.Since(startedAt).String()

	if handlerErr != nil {
		a.log.Error(
			fmt.Sprintf("fail req: %s", method),
			log.String("method", method),
			log.String("cn", cn),
			log.String("elapsed", elapsed),
			log.Error(handlerErr),
		)
		return resp, handlerErr
	}

	a.log.Info(
		fmt.Sprintf("ok req: %s", method),
		log.String("method", method),
		log.String("cn", cn),
		log.String("elapsed", elapsed),
	)
	return resp, nil
}
