package skyboned

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"time"

	pb "a.yandex-team.ru/infra/skyboned/go/skyboned_rpc"
	"a.yandex-team.ru/infra/skyboned/go/src/api"
	"a.yandex-team.ru/infra/skyboned/go/src/auth"
	"a.yandex-team.ru/infra/skyboned/go/src/database"
	"a.yandex-team.ru/infra/skyboned/go/src/util"

	"a.yandex-team.ru/library/go/yandex/tvm/tvmauth"
	"github.com/jackc/pgx/v4"
	"github.com/vmihailenco/msgpack/v4"
	"go.uber.org/zap"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/emptypb"
	"google.golang.org/protobuf/types/known/structpb"
	"google.golang.org/protobuf/types/known/timestamppb"
)

// SkybonedOwnerID is the owner id that Announce and Delete act as when TVM
// authentication is disabled (auth.TVMDisable).
const SkybonedOwnerID = 2

type (
	// TID is an empty marker type. Its use is not visible in this file
	// (presumably a context key) — TODO confirm.
	TID struct{}
)

// SkybonedServer is the skyboned gRPC server; it serves the Ping, Announce,
// Delete and Get handlers defined on it.
type SkybonedServer struct {
	// db is the database handle used by the handlers and metric gatherers.
	db  *database.DB
	// tvm validates incoming TVM service tickets; it is left nil by
	// SetupSkybonedServer when auth.TVMDisable is set.
	tvm *tvmauth.Client

	// metricGatherers maps a gatherer name to a function that queries db and
	// updates gauges; Metrics invokes every entry on each tick.
	metricGatherers map[string]func(*database.DB, context.Context) error
}

// SetupSkybonedServer connects to the database, initialises the TVM client
// (unless auth.TVMDisable is set), registers metrics and starts the
// background Metrics loop.
//
// Database and TVM setup failures terminate the process via zap's Fatalf, so
// the returned error is currently always nil; it is kept for API stability.
func SetupSkybonedServer() (*SkybonedServer, error) {
	db, err := database.DBSetup(&database.DBConfGlobal)
	if err != nil {
		zap.S().Fatalf("database setup failed: %v", err)
	}
	zap.S().Info("init dbpool")
	var tvm *tvmauth.Client
	if !auth.TVMDisable {
		tvm, err = auth.TVMSetup(zap.L().Core())
		if err != nil {
			zap.S().Fatalf("TVM client setup failed: %v", err)
		}
		zap.S().Info("init tvm auth")
	} else {
		zap.S().Info("disabled tvm auth")
	}

	gatherers := map[string]func(*database.DB, context.Context) error{
		// announce_overdue: per tracker, the number of announce rows whose
		// schedule_ts is more than 60 seconds in the past.
		"announce_overdue": func(db *database.DB, ctx context.Context) error {
			rows, err := db.Query(ctx, "SELECT at.address, COUNT(a.resource_id) FROM announce a JOIN announce_tracker at ON at.id = a.tracker_id WHERE a.schedule_ts < $1 GROUP BY at.id",
				time.Now().Add(-time.Second*60).Unix())
			if err != nil {
				return err
			}
			// Release the pooled connection even if Scan fails mid-iteration.
			defer rows.Close()
			for rows.Next() {
				var (
					cnt            int
					trackerAddress string
				)
				if err := rows.Scan(&trackerAddress, &cnt); err != nil {
					return err
				}
				// Label by the address up to the first ':' (presumably dropping the port).
				host := strings.Split(trackerAddress, ":")[0]
				Gauges["announce_overdue"].With(map[string]string{"tracker_address": host}).Set(float64(cnt))
			}
			// Report errors that ended iteration early (e.g. a lost connection).
			return rows.Err()
		},
		// resource_count: number of resources per owner name.
		"resource_count": func(db *database.DB, ctx context.Context) error {
			rows, err := db.Query(ctx, "SELECT count, name FROM (SELECT owner, COUNT(*) FROM resource GROUP BY owner) t1 INNER JOIN (SELECT id, name FROM owner) t2 ON (t1.owner = t2.id)")
			if err != nil {
				return err
			}
			defer rows.Close()
			for rows.Next() {
				var (
					cnt       int
					ownerName string
				)
				if err := rows.Scan(&cnt, &ownerName); err != nil {
					return err
				}
				Gauges["resource_count"].With(map[string]string{"owner": ownerName}).Set(float64(cnt))
			}
			return rows.Err()
		},
	}
	registry := SetupMetrics()
	if !auth.TVMDisable {
		SetupPusher(tvm, registry)
	}

	skbnd := SkybonedServer{
		db:              db,
		tvm:             tvm,
		metricGatherers: gatherers,
	}
	go skbnd.Metrics()
	return &skbnd, nil
}

// authTVM authenticates the caller from the "x-ya-service-ticket" gRPC
// metadata header and maps the ticket's source TVM id to an owner id.
//
// It returns auth.ErrTVMAuthFailed when the TVM client is not configured, or
// when the header is missing or present more than once. It returns an error
// wrapping auth.ErrTVMAuthFailed when the TVM id is not registered in the
// owner table. Ticket-validation and other database errors are returned
// unchanged.
func (ss *SkybonedServer) authTVM(ctx context.Context) (OwnerID int, err error) {
	// ss.tvm is left nil when auth.TVMDisable is set; fail closed instead of
	// calling into a nil client.
	if ss.tvm == nil {
		return OwnerID, auth.ErrTVMAuthFailed
	}
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return OwnerID, auth.ErrTVMAuthFailed
	}
	tvmTicket, ok := md["x-ya-service-ticket"]
	if !ok || len(tvmTicket) != 1 {
		return OwnerID, auth.ErrTVMAuthFailed
	}

	// Propagate the request context rather than context.Background() so the
	// check is bound to the lifetime of the call.
	chticket, err := ss.tvm.CheckServiceTicket(ctx, tvmTicket[0])
	if err != nil {
		return OwnerID, err
	}

	err = ss.db.QueryRow(ctx, "SELECT id FROM owner WHERE tvm_id = $1", chticket.SrcID).Scan(&OwnerID) //TODO: replace with cached version
	if errors.Is(err, pgx.ErrNoRows) {
		err = fmt.Errorf("%w: your tvm id %d not found in registered - who are you?", auth.ErrTVMAuthFailed, chticket.SrcID)
	}
	return OwnerID, err
}

// announceWait blocks until at least one announce row for resid has a
// non-zero schedule_ts, polling the database once per second.
//
// It returns:
//   - nil once such a row exists;
//   - ErrSkybonedTimeout after 60 seconds;
//   - ctx.Err() if ctx is done first (e.g. the client went away);
//   - the error from any failed query.
func (ss *SkybonedServer) announceWait(ctx context.Context, resid string) error {
	// Both clocks start before the first query, so the 60s budget includes
	// it. They are stopped on every return path.
	ticker := time.NewTicker(time.Second * 1)
	defer ticker.Stop()
	deadline := time.NewTimer(time.Second * 60)
	defer deadline.Stop()

	// scheduled reports whether any announce for resid has been scheduled.
	scheduled := func() (bool, error) {
		var cnt int
		err := ss.db.QueryRow(ctx, "SELECT COUNT(*) FROM announce WHERE resource_id = $1 AND schedule_ts != 0", resid).Scan(&cnt)
		return cnt != 0, err
	}

	if done, err := scheduled(); err != nil || done {
		return err
	}
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-deadline.C:
			return ErrSkybonedTimeout
		case <-ticker.C:
			if done, err := scheduled(); err != nil || done {
				return err
			}
		}
	}
}

// Metrics runs forever: every 5 seconds it invokes each registered metric
// gatherer against the server's database. A gatherer's error is logged under
// the gatherer's name and does not stop the others.
func (ss *SkybonedServer) Metrics() {
	gatherCtx := context.WithValue(context.Background(), auth.RequestID{}, "metrics")
	tick := time.NewTicker(time.Second * 5)

	for {
		<-tick.C
		for name, gather := range ss.metricGatherers {
			err := gather(ss.db, gatherCtx)
			if err == nil {
				continue
			}
			zap.S().Errorw("metrics", name, err)
		}
	}
}

// Ping is a liveness probe; it always answers with a fixed pong message.
func (ss *SkybonedServer) Ping(ctx context.Context, e *emptypb.Empty) (*pb.Pong, error) {
	pong := &pb.Pong{Response: "PONG! now in grpc"}
	return pong, nil
}

// Announce handles /add_resource (the route name used in logs and metrics).
//
// If the resource already exists, the incoming links (tagged with
// ar.SourceId) are merged into the stored info and written back.
//
// Otherwise:
//   - the torrent head and links are packed with msgpack;
//   - the resource row and one announce row per tracker are inserted in a
//     single transaction;
//   - NOTIFY announce_new is issued;
//   - unless NoWait is set, the call waits (see announceWait) until an
//     announce row has a non-zero schedule_ts.
//
// Failures are reported through ErrorWrapper with an HTTP status (400, 403,
// 500 or 503). The deferred hook records timing and status metrics for every
// outcome.
func (ss *SkybonedServer) Announce(ctx context.Context, ar *pb.AnnounceRequest) (resp *pb.Status, err error) {
	requestID := util.GenerateRequestID()
	ctx = context.WithValue(ctx, auth.RequestID{}, requestID)

	t := time.Now()
	defer func() {
		if resp == nil {
			resp = &pb.Status{Code: 500, Message: ""}
			err = errors.New("resp uninitialized")
		}
		Timers["add_resource_time"].RecordDuration(time.Since(t))
		Counters["add_resource_status"].With(map[string]string{"status": strconv.FormatInt(int64(resp.Code), 10)}).Inc()
		zap.S().Infow(
			requestID,
			"rbtorrent", ar.Uid,
			"/add_resource", resp.Code)
		if err != nil {
			zap.S().Errorw(
				requestID,
				"err", err)
		}
	}()

	ar.Uid = strings.TrimPrefix(ar.Uid, "rbtorrent:")

	ownerID := SkybonedOwnerID
	if !auth.TVMDisable {
		ownerID, err = ss.authTVM(ctx)
		if err != nil {
			return ErrorWrapper(http.StatusForbidden, err)
		}
	}

	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		zap.S().Errorw(
			requestID,
			"rbtorrent", ar.Uid,
			"err", "could not retrieve metadata")
	} else {
		// x-forwarded-for is optional. Indexing md["x-forwarded-for"][0]
		// directly panics when it is absent, which crashes the process
		// unless a recovery interceptor is installed.
		var addr string
		if xff := md.Get("x-forwarded-for"); len(xff) > 0 {
			addr = xff[0]
		}
		zap.S().Infow(
			requestID,
			"rbtorrent", ar.Uid,
			"addr", addr,
			"handle", "/add_resource",
		)
	}

	// Check if resource exists
	var packedInfo []byte
	err = ss.db.QueryRow(ctx, "SELECT info FROM resource WHERE id = $1", ar.Uid).Scan(&packedInfo)
	if err != nil {
		if !errors.Is(err, pgx.ErrNoRows) {
			return ErrorWrapper(http.StatusInternalServerError, err)
		}
	} else {
		// Existing resource: merge the incoming links into the stored info.
		var info interface{}
		resp = &pb.Status{Code: 200, Message: "OK"}
		err = msgpack.Unmarshal(packedInfo, &info)
		if err != nil {
			return ErrorWrapper(http.StatusInternalServerError, err)
		}
		currentLinks, err := api.ParseInfo(info)
		if err != nil {
			return ErrorWrapper(http.StatusBadRequest, err)
		}
		incomingLinks, err := api.ParseInfo(ar.Info.AsMap())
		if err != nil {
			return ErrorWrapper(http.StatusBadRequest, err)
		}
		generalOpts := map[string]string{}
		if ar.SourceId == "" {
			// Only V1 links may be re-announced without a source id; nothing is updated.
			if incomingLinks.Type != api.LinkV1 {
				return ErrorWrapper(http.StatusBadRequest, fmt.Errorf("no source_id supplied for request"))
			}
			return resp, nil
		}
		generalOpts["SourceId"] = ar.SourceId
		err = incomingLinks.SetOpts(generalOpts)
		if err != nil {
			return ErrorWrapper(http.StatusBadRequest, err)
		}
		err = currentLinks.Update(incomingLinks)
		if err != nil {
			// These are logged but still answered with 200 and leave the row untouched.
			if errors.Is(err, api.ErrExceededLinksLimit) || errors.Is(err, api.ErrSkipUpdate) {
				zap.S().Errorw(
					requestID,
					"rbtorrent", ar.Uid,
					"err", err)
				return resp, nil
			}
			return ErrorWrapper(http.StatusBadRequest, err)
		}
		packedInfo, err = currentLinks.Pack()
		if err != nil {
			return ErrorWrapper(http.StatusInternalServerError, err)
		}
		_, err = ss.db.Exec(ctx, "UPDATE resource SET info = $1, ts_used = 2 WHERE id = $2", packedInfo, ar.Uid)
		if err != nil {
			return ErrorWrapper(http.StatusInternalServerError, err)
		}
		return resp, nil
	}

	// Head goes bencode -> map -> msgpack
	head, err := api.DecodeHead(ar.Head)
	if err != nil {
		return ErrorWrapper(http.StatusBadRequest, err, "skyboned: head decode failed")
	}
	api.FixHead(head)
	packedhead, err := msgpack.Marshal(head)
	if err != nil {
		return ErrorWrapper(http.StatusInternalServerError, err, "skyboned: head package to msgpack failed")
	}

	// Info (links) goes protobuf.Struct -> map -> msgpack
	links, err := api.ParseInfo(ar.Info.AsMap())
	if err != nil {
		// Malformed client-supplied info is a bad request, matching the
		// existing-resource branch above.
		return ErrorWrapper(http.StatusBadRequest, err)
	}

	if ar.SourceId != "" {
		generalOpts := map[string]string{
			"SourceId": ar.SourceId,
		}
		err = links.SetOpts(generalOpts)
		if err != nil {
			return ErrorWrapper(http.StatusBadRequest, err)
		}
	} else {
		if links.Type != api.LinkV1 {
			return ErrorWrapper(http.StatusBadRequest, fmt.Errorf("skyboned: missing source_id in V2 request"))
		}
	}

	packedinfo, err := links.Pack()
	if err != nil {
		return ErrorWrapper(http.StatusInternalServerError, err, "skyboned: info package to msgpack failed")
	}

	tx, err := ss.db.TxStart(ctx)
	if err != nil {
		return ErrorWrapper(http.StatusInternalServerError, err)
	}
	defer tx.Rollback(ctx)

	_, err = ss.db.TxExec(ctx, tx, "DELETE FROM announce_op WHERE resource = $1", ar.Uid)
	if err != nil {
		return ErrorWrapper(http.StatusInternalServerError, err)
	}

	_, err = ss.db.TxExec(ctx, tx, "INSERT INTO resource (id, type, data, info, mode, owner, ts_added, ts_used) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT DO NOTHING",
		ar.Uid,
		"rbtorrent1",
		packedhead,
		packedinfo,
		ar.Mode,
		ownerID,
		time.Now().Unix(),
		1, // thus we distinguish resources added from new api, this field was unused in db at the time
	)
	if err != nil {
		return ErrorWrapper(http.StatusInternalServerError, err)
	}

	_, err = ss.db.TxExec(ctx, tx, "INSERT INTO announce (resource_id, schedule_ts, tracker_id) SELECT $1, 0, id as tracker_id FROM announce_tracker ON CONFLICT DO NOTHING", ar.Uid)
	if err != nil {
		return ErrorWrapper(http.StatusInternalServerError, err)
	}
	err = tx.Commit(ctx)
	if err != nil {
		return ErrorWrapper(http.StatusInternalServerError, err)
	}
	_, err = ss.db.Exec(ctx, "NOTIFY announce_new")
	if err != nil {
		return ErrorWrapper(http.StatusInternalServerError, err)
	}

	if !NoWait {
		err = ss.announceWait(ctx, ar.Uid)
		if err != nil {
			if errors.Is(err, ErrSkybonedTimeout) {
				return ErrorWrapper(http.StatusServiceUnavailable, err)
			}
			return ErrorWrapper(http.StatusInternalServerError, err)
		}
	}

	resp = &pb.Status{Code: 200, Message: "OK"}
	return resp, nil
}

// Delete handles /remove_resource (the route name used in logs and metrics).
//
// A resource that does not exist is reported as success. Without
// dr.SourceId, the caller must pass auth.CheckTVMID against the resource
// owner. With a SourceId, that source is removed from the stored links via
// links.Delete.
//
// When links.Delete reports that the resource should be obliterated
// (presumably no links remain), the resource row is deleted and a "remove"
// announce_op with a deadline 48 hours from now is queued for every tracker.
// Otherwise the reduced info is written back. All writes happen in one
// transaction.
func (ss *SkybonedServer) Delete(ctx context.Context, dr *pb.DeleteRequest) (resp *pb.Status, err error) {
	requestID := util.GenerateRequestID()
	ctx = context.WithValue(ctx, auth.RequestID{}, requestID)

	t := time.Now()
	defer func() {
		if resp == nil {
			resp = &pb.Status{Code: 500, Message: ""}
			err = errors.New("resp uninitialized")
		}
		Timers["remove_resource_time"].RecordDuration(time.Since(t))
		Counters["remove_resource_status"].With(map[string]string{"status": strconv.FormatInt(int64(resp.Code), 10)}).Inc()
		zap.S().Infow(
			requestID,
			"rbtorrent", dr.Uid,
			"/remove_resource", resp.Code)
		if err != nil {
			zap.S().Errorw(
				requestID,
				"err", err)
		}
	}()

	dr.Uid = strings.TrimPrefix(dr.Uid, "rbtorrent:")

	ownerID := SkybonedOwnerID
	if !auth.TVMDisable {
		ownerID, err = ss.authTVM(ctx)
		if err != nil {
			return ErrorWrapper(http.StatusForbidden, err)
		}
	}

	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		zap.S().Errorw(
			requestID,
			"rbtorrent", dr.Uid,
			"err", "could not retrieve metadata")
	} else {
		// x-forwarded-for is optional. Indexing md["x-forwarded-for"][0]
		// directly panics when it is absent, which crashes the process
		// unless a recovery interceptor is installed.
		var addr string
		if xff := md.Get("x-forwarded-for"); len(xff) > 0 {
			addr = xff[0]
		}
		zap.S().Infow(
			requestID,
			"rbtorrent", dr.Uid,
			"addr", addr,
			"handle", "/remove_resource",
		)
	}

	// Check if resource exists
	var packedInfo []byte
	var resourceOwnerID int
	err = ss.db.QueryRow(ctx, "SELECT info, owner FROM resource WHERE id = $1", dr.Uid).Scan(&packedInfo, &resourceOwnerID)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			// Deleting a resource that does not exist is treated as success.
			resp = &pb.Status{Code: 200, Message: "OK"}
			return resp, nil
		}
		return ErrorWrapper(http.StatusInternalServerError, err)
	}

	// Whole-resource removal (no SourceId) requires the caller to pass
	// auth.CheckTVMID against the resource owner.
	if !auth.CheckTVMID(resourceOwnerID, ownerID) && dr.SourceId == "" {
		return ErrorWrapper(http.StatusForbidden, ErrSkybonedForbidden)
	}

	// NOTE(review): Announce msgpack-decodes stored info before calling
	// api.ParseInfo, whereas the raw bytes are passed here. Confirm that
	// ParseInfo accepts []byte.
	links, err := api.ParseInfo(packedInfo)
	if err != nil {
		return ErrorWrapper(http.StatusInternalServerError, err)
	}

	obliterate := links.Delete(dr.SourceId)

	tx, err := ss.db.Master.Begin(ctx)
	if err != nil {
		return ErrorWrapper(http.StatusInternalServerError, err)
	}
	defer tx.Rollback(ctx)

	if obliterate {
		// NOTE(review): this DELETE filters on the caller's ownerID. A
		// non-owner (let through above because it supplied a SourceId)
		// therefore deletes no rows, while the "remove" announce_op below is
		// still queued. Confirm intent.
		_, err = ss.db.TxExec(ctx, tx, "DELETE FROM resource WHERE id = $1 AND owner = $2", dr.Uid, ownerID)
		if err != nil {
			return ErrorWrapper(http.StatusInternalServerError, err)
		}

		_, err = ss.db.TxExec(
			ctx,
			tx,
			// The fragments are concatenated verbatim, so each one must end
			// with a space. Otherwise the SQL contains "$3FROM", which newer
			// PostgreSQL versions reject as trailing junk after a parameter.
			"INSERT INTO announce_op (resource, tracker_id, deadline, op) "+
				"SELECT $1, id as tracker_id, $2, $3 "+
				"FROM announce_tracker "+
				"ON CONFLICT DO NOTHING",
			dr.Uid, time.Now().Unix()+int64(time.Hour.Seconds())*48, "remove",
		)
		if err != nil {
			return ErrorWrapper(http.StatusInternalServerError, err)
		}
	} else {
		packedInfo, err = links.Pack()
		if err != nil {
			return ErrorWrapper(http.StatusInternalServerError, err)
		}
		_, err = ss.db.TxExec(ctx, tx, "UPDATE resource SET info = $1, ts_used = 2 WHERE id = $2", packedInfo, dr.Uid)
		if err != nil {
			// Previously ignored; Commit would then fail with a misleading error.
			return ErrorWrapper(http.StatusInternalServerError, err)
		}
	}

	err = tx.Commit(ctx)
	if err != nil {
		return ErrorWrapper(http.StatusInternalServerError, err)
	}

	resp = &pb.Status{Code: 200, Message: "OK"}
	return resp, nil
}

// Get returns the stored resource identified by gr.Uid (an optional
// "rbtorrent:" prefix is stripped). No TVM check is performed.
//
// A missing resource yields the gRPC code mapped from HTTP 404 via
// errorHTTPtoCode. Other failures yield the code mapped from HTTP 500.
func (ss *SkybonedServer) Get(ctx context.Context, gr *pb.GetRequest) (res *pb.Resource, err error) {
	requestID := util.GenerateRequestID()
	ctx = context.WithValue(ctx, auth.RequestID{}, requestID)
	gr.Uid = strings.TrimPrefix(gr.Uid, "rbtorrent:")

	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		zap.S().Errorw(
			requestID,
			"rbtorrent", gr.Uid,
			"err", "could not retrieve metadata")
	} else {
		// x-forwarded-for is optional. Indexing md["x-forwarded-for"][0]
		// directly panics when it is absent, which crashes the process
		// unless a recovery interceptor is installed.
		var addr string
		if xff := md.Get("x-forwarded-for"); len(xff) > 0 {
			addr = xff[0]
		}
		zap.S().Infow(
			requestID,
			"rbtorrent", gr.Uid,
			"addr", addr,
			"handle", "/get_resource",
		)
	}

	pgrow := ss.db.QueryRow(ctx, "SELECT resource.id, type, data, info, mode, owner.name, ts_added, ts_used FROM resource JOIN owner on resource.owner = owner.id WHERE resource.id = $1", gr.Uid)
	resource, err := api.NewResourceFromRow(pgrow)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return &pb.Resource{}, status.Error(errorHTTPtoCode[http.StatusNotFound], fmt.Sprintf("rbtorrent:%s not found!", gr.Uid))
		}
		zap.S().Errorw(
			requestID,
			"rbtorrent", gr.Uid,
			"err", err)
		return res, status.Error(errorHTTPtoCode[http.StatusInternalServerError], fmt.Sprintf("%s: %s", gr.Uid, err.Error()))
	}
	res = &pb.Resource{}
	res.Uid = resource.UID
	res.Type = resource.Type
	res.Head, err = structpb.NewStruct(resource.Head)
	if err != nil {
		zap.S().Errorw(
			requestID,
			"rbtorrent", gr.Uid,
			"err", err)
		return res, status.Error(errorHTTPtoCode[http.StatusInternalServerError], fmt.Sprintf("%s: %s", gr.Uid, err.Error()))
	}
	// Copy the hash -> link -> options mapping into protobuf messages.
	res.Info = map[string]*pb.Links{}
	for hash, links := range resource.Info {
		res.Info[hash] = &pb.Links{Links: map[string]*pb.LinkOpts{}}
		for link, linkopts := range links {
			res.Info[hash].Links[link] = &pb.LinkOpts{Linkopts: linkopts}
		}
	}
	res.Mode = "plain"
	res.Owner = resource.Owner
	res.TsAdded = timestamppb.New(time.Unix(int64(resource.TScreated), 0))
	res.TsUsed = resource.TSused
	return res, nil
}

// ErrSkybonedForbidden (codes.PermissionDenied) is returned by Delete when the
// caller may not remove the whole resource.
var ErrSkybonedForbidden = status.Error(codes.PermissionDenied, "skyboned: forbidden")

// ErrSkybonedTimeout (codes.Unavailable) is returned by announceWait when no
// announce was scheduled within its waiting window.
var ErrSkybonedTimeout = status.Error(codes.Unavailable, "skyboned: timeout")
