package services

import (
	"context"
	"crypto/hmac"
	"crypto/sha1"
	"encoding/base64"
	"errors"
	"fmt"
	"time"

	"github.com/golang/protobuf/ptypes"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/timestamppb"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/security/xray/internal/db"
	"a.yandex-team.ru/security/xray/internal/dbmodels"
	"a.yandex-team.ru/security/xray/internal/queue"
	"a.yandex-team.ru/security/xray/internal/servers/grpc/auth"
	"a.yandex-team.ru/security/xray/internal/servers/grpc/infra"
	"a.yandex-team.ru/security/xray/internal/stagehealth"
	"a.yandex-team.ru/security/xray/pkg/xrayerrors"
	"a.yandex-team.ru/security/xray/pkg/xrayrpc"
	"a.yandex-team.ru/yp/go/yp"
)

// Compile-time check that *StageService satisfies the generated gRPC server
// interface.
var _ xrayrpc.StageServiceServer = (*StageService)(nil)

// Sentinel errors shared by the StageService handlers and helpers.
var (
	errNoStage      = errors.New("no stage specified")
	errNoAuthInfo   = errors.New("no auth info provided")
	errAccessDenied = errors.New("access denied")
)

// StageService implements xrayrpc.StageServiceServer: it schedules stage
// analyses and reports their status and results. All dependencies used by
// its methods (DB, Queue, YP, S3Storage, Config, Logger) are reached through
// the embedded *infra.Infra.
type StageService struct {
	*infra.Infra
}

// Schedule enqueues an analysis of the exact stage revision described in
// req.Stage and returns the signed analysis ID ("<id>:<signature>").
//
// Unless req.Force is set, scheduling fails with
// xrayerrors.ErrCodeConflictSchedule when the same or a newer revision of the
// stage has already been analyzed. Authorization failures are reported as
// codes.PermissionDenied.
func (s *StageService) Schedule(
	ctx context.Context,
	req *xrayrpc.StageScheduleRequest,
) (*xrayrpc.StageScheduleReply, error) {
	if req.Stage == nil {
		return nil, errNoStage
	}

	if err := s.checkStageAuth(ctx, req.Stage.Id); err != nil {
		return nil, status.Error(codes.PermissionDenied, err.Error())
	}

	// scheduleStage already returns a signed ID. Signing it again would yield
	// "<id>:<sig>:<sig>", which matches neither the ID returned by
	// ScheduleLatest nor the one reported by GetResults.
	analyzeID, err := s.scheduleStage(ctx, req.Stage, req.Description, req.Force)
	if err != nil {
		return nil, err
	}

	return &xrayrpc.StageScheduleReply{
		AnalyzeId: analyzeID,
	}, nil
}

// ScheduleLatest looks up the current UUID and revision of req.StageId in YP
// and schedules an analysis of that revision. The reply echoes the resolved
// stage identity together with the signed analysis ID.
func (s *StageService) ScheduleLatest(
	ctx context.Context,
	req *xrayrpc.StageScheduleLatestRequest,
) (*xrayrpc.StageScheduleLatestReply, error) {
	stageID := req.StageId
	if stageID == "" {
		return nil, errNoStage
	}

	if authErr := s.checkStageAuth(ctx, stageID); authErr != nil {
		return nil, status.Error(codes.PermissionDenied, authErr.Error())
	}

	latest, err := s.fetchStage(ctx, stageID)
	if err != nil {
		return nil, err
	}

	signedID, err := s.scheduleStage(ctx, latest, req.Description, req.Force)
	if err != nil {
		return nil, err
	}

	reply := &xrayrpc.StageScheduleLatestReply{
		StageId:       latest.Id,
		StageUuid:     latest.Uuid,
		StageRevision: latest.Revision,
		AnalyzeId:     signedID,
	}
	return reply, nil
}

// scheduleStage records a new analysis of stage in the DB and enqueues the
// analysis request, returning the signed analysis ID ("<id>:<signature>").
//
// Unless force is set, it first rejects the request with
// xrayerrors.ErrCodeConflictSchedule when the DB already holds the same or a
// newer revision for the stage UUID. The same status is returned when the DB
// reports db.ErrAlreadyAnalyzed. Other failures are wrapped with context.
//
// The queue message is the protobuf-encoded InternalAnalyzeRequest in
// unpadded standard base64.
func (s *StageService) scheduleStage(
	ctx context.Context,
	stage *xrayrpc.Stage,
	description string,
	force bool,
) (string, error) {
	// Prevalidation: skip re-analysis of revisions that are not newer than
	// the latest one recorded. A missing record simply means "never analyzed".
	if !force {
		latestRevision, err := s.DB.LookupLatestStageRevision(ctx, stage.Uuid)
		switch {
		case err == nil:
			if latestRevision >= stage.Revision {
				return "", status.Error(xrayerrors.ErrCodeConflictSchedule, "stage already analyzed")
			}
		case !errors.Is(err, db.ErrNotFound):
			return "", fmt.Errorf("failed to get latest stage info: %w", err)
		}
	}

	message, err := proto.Marshal(&xrayrpc.InternalAnalyzeRequest{
		Stage: stage,
	})
	if err != nil {
		return "", fmt.Errorf("failed to schedule stage: %w", err)
	}
	msg := base64.RawStdEncoding.EncodeToString(message)

	// analyzeID is captured by the Scheduler callback below: the DB layer
	// invokes it to push the message to the requests queue, and the queue's
	// message ID becomes the analysis ID.
	// NOTE(review): presumably called inside the DB's scheduling step (e.g.
	// within its transaction); confirm in the db package.
	var analyzeID string
	err = s.DB.ScheduleStageAnalysis(ctx, dbmodels.ScheduleLatestStageData{
		StageID:       stage.Id,
		StageUUID:     stage.Uuid,
		StageRevision: stage.Revision,
		Description:   description,
		CreatedAt:     time.Now(),
		Status:        xrayrpc.AnalysisStatusKind_ASK_START,
		Force:         force,
		Scheduler: func() (string, error) {
			var sendErr error
			analyzeID, sendErr = s.Queue.SendMessage(ctx, &queue.SendOptions{
				QueueURL: s.Config.RequestsQueueURL(),
				Msg:      msg,
			})
			return analyzeID, sendErr
		},
	})
	if err != nil {
		if errors.Is(err, db.ErrAlreadyAnalyzed) {
			return "", status.Error(xrayerrors.ErrCodeConflictSchedule, "stage already analyzed")
		}
		return "", fmt.Errorf("failed to schedule stage: %w", err)
	}

	return s.signAnalyzeID(analyzeID), nil
}

// GetStatus reports the analysis status, computed stage health, issue
// counters and warnings for the exact stage revision given in req.Stage.
//
// It returns codes.NotFound when the DB has no status for that revision and
// codes.PermissionDenied when the caller is not allowed to see the stage.
func (s *StageService) GetStatus(
	ctx context.Context,
	req *xrayrpc.StageGetStatusRequest,
) (*xrayrpc.StageGetStatusReply, error) {
	if req.Stage == nil {
		return nil, errNoStage
	}

	if err := s.checkStageAuth(ctx, req.Stage.Id); err != nil {
		return nil, status.Error(codes.PermissionDenied, err.Error())
	}

	stageStatus, err := s.DB.LookupStageStatus(ctx, dbmodels.StageInfo{
		ID:       req.Stage.Id,
		UUID:     req.Stage.Uuid,
		Revision: req.Stage.Revision,
	})
	if err != nil {
		// errors.Is keeps matching the sentinel even if the DB layer wraps it.
		if errors.Is(err, db.ErrNotFound) {
			return nil, status.Error(codes.NotFound, "stage not found")
		}
		return nil, fmt.Errorf("failed to lookup stage status: %w", err)
	}

	return &xrayrpc.StageGetStatusReply{
		Status: &xrayrpc.StageStatus{
			UpdatedAt:      timestamppb.New(stageStatus.UpdatedAt),
			AnalysisStatus: stageStatus.Status,
			StageHealth:    stagehealth.Calculate(stageStatus.Overview),
			Issues: &xrayrpc.StageStatus_IssuesCounter{
				Unknown: stageStatus.Overview.Issues.Unknown,
				Info:    stageStatus.Overview.Issues.Info,
				Low:     stageStatus.Overview.Issues.Low,
				Medium:  stageStatus.Overview.Issues.Medium,
				High:    stageStatus.Overview.Issues.High,
			},
			Warnings: stageStatus.Overview.Warnings,
		},
	}, nil
}

// ListAll returns analysis statuses of stages across all users, limited to
// req.Limit entries (50 when unset or non-positive, at most 100).
//
// Only admins and readers may call it; anyone else, or a request without
// auth info, gets codes.PermissionDenied, consistent with the other handlers.
func (s *StageService) ListAll(
	ctx context.Context,
	req *xrayrpc.StageListAllRequest,
) (*xrayrpc.StageListAllReply, error) {
	authInfo := auth.ContextAuthInfo(ctx)
	if authInfo == nil {
		return nil, status.Error(codes.PermissionDenied, errNoAuthInfo.Error())
	}

	// Listing every stage bypasses per-stage YP permissions, so it is limited
	// to privileged callers.
	if !authInfo.IsAdmin && !authInfo.IsReader {
		return nil, status.Error(codes.PermissionDenied, errAccessDenied.Error())
	}

	var limit int32
	switch {
	case req.Limit > 100:
		return nil, status.Errorf(codes.InvalidArgument, "limit (%d) must not exceed 100", req.Limit)
	case req.Limit > 0:
		limit = req.Limit
	default:
		limit = 50
	}

	statuses, err := s.DB.SelectStagesStatus(ctx, limit)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to retrieve stages information: %s", err)
	}

	out := make([]*xrayrpc.StageStatus, len(statuses))
	for i, stageStatus := range statuses {
		// Conversion errors are deliberately ignored: the listing is best effort.
		updatedAt, _ := ptypes.TimestampProto(stageStatus.UpdatedAt)
		out[i] = &xrayrpc.StageStatus{
			Stage: &xrayrpc.Stage{
				Id:       stageStatus.ID,
				Uuid:     stageStatus.UUID,
				Revision: stageStatus.Revision,
			},
			UpdatedAt:      updatedAt,
			AnalysisStatus: stageStatus.Status,
			// Report health the same way List and GetStatus do.
			StageHealth: stagehealth.Calculate(stageStatus.Overview),
			Issues: &xrayrpc.StageStatus_IssuesCounter{
				Unknown: stageStatus.Overview.Issues.Unknown,
				Info:    stageStatus.Overview.Issues.Info,
				Low:     stageStatus.Overview.Issues.Low,
				Medium:  stageStatus.Overview.Issues.Medium,
				High:    stageStatus.Overview.Issues.High,
			},
			Warnings: stageStatus.Overview.Warnings,
		}
	}

	return &xrayrpc.StageListAllReply{Statuses: out}, nil
}

// List returns the latest analysis statuses of the stages the caller has
// write permission on in YP, limited to req.Limit entries (50 when unset or
// non-positive, at most 100). Callers without a user login get an empty list.
func (s *StageService) List(
	ctx context.Context,
	req *xrayrpc.StageListRequest,
) (*xrayrpc.StageListReply, error) {
	var limit int32
	switch {
	case req.Limit > 100:
		return nil, status.Errorf(codes.InvalidArgument, "limit (%d) must not exceed 100", req.Limit)
	case req.Limit > 0:
		limit = req.Limit
	default:
		limit = 50
	}

	stages, err := s.fetchUserStages(ctx, limit)
	if err != nil {
		// Missing auth info is an authorization failure, not a server error;
		// report it the way the checkStageAuth-guarded handlers do.
		if errors.Is(err, errNoAuthInfo) {
			return nil, status.Error(codes.PermissionDenied, err.Error())
		}
		return nil, status.Error(codes.Internal, err.Error())
	}

	if len(stages) == 0 {
		return &xrayrpc.StageListReply{}, nil
	}

	statuses, err := s.DB.SelectLatestStagesStatus(ctx, stages, limit)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to retrieve stages information: %s", err)
	}

	out := make([]*xrayrpc.StageStatus, len(statuses))
	for i, stageStatus := range statuses {
		// Conversion errors are deliberately ignored: the listing is best effort.
		updatedAt, _ := ptypes.TimestampProto(stageStatus.UpdatedAt)
		out[i] = &xrayrpc.StageStatus{
			Stage: &xrayrpc.Stage{
				Id:       stageStatus.ID,
				Uuid:     stageStatus.UUID,
				Revision: stageStatus.Revision,
			},
			UpdatedAt:      updatedAt,
			AnalysisStatus: stageStatus.Status,
			StageHealth:    stagehealth.Calculate(stageStatus.Overview),
			Issues: &xrayrpc.StageStatus_IssuesCounter{
				Unknown: stageStatus.Overview.Issues.Unknown,
				Info:    stageStatus.Overview.Issues.Info,
				Low:     stageStatus.Overview.Issues.Low,
				Medium:  stageStatus.Overview.Issues.Medium,
				High:    stageStatus.Overview.Issues.High,
			},
			Warnings: stageStatus.Overview.Warnings,
		}
	}

	return &xrayrpc.StageListReply{Statuses: out}, nil
}

// GetResults returns the stored analysis of the exact stage revision given
// in req.Stage. The reply contains:
//   - the signed analysis ID, status and computed stage health;
//   - a link to the analysis log, when a log path was stored;
//   - the decoded analysis results.
//
// Results is always non-nil. Failures to download or decode the results file
// are only logged, so Results may be empty or incomplete. codes.NotFound is
// returned when the DB has no analysis for that revision.
func (s *StageService) GetResults(
	ctx context.Context,
	req *xrayrpc.StageGetResultsRequest,
) (*xrayrpc.StageGetResultsReply, error) {
	if req.Stage == nil {
		return nil, errNoStage
	}

	if err := s.checkStageAuth(ctx, req.Stage.Id); err != nil {
		return nil, status.Error(codes.PermissionDenied, err.Error())
	}

	dbProject, err := s.DB.LookupStageAnalysis(ctx, dbmodels.StageInfo{
		ID:       req.Stage.Id,
		UUID:     req.Stage.Uuid,
		Revision: req.Stage.Revision,
	})
	if err != nil {
		if errors.Is(err, db.ErrNotFound) {
			return nil, status.Error(codes.NotFound, "stage not found")
		}
		return nil, fmt.Errorf("failed to lookup stage results: %w", err)
	}

	// A stored analysis in the UNSPECIFIED state is treated as a corrupted
	// record rather than returned to the client.
	if dbProject.Status == xrayrpc.AnalysisStatusKind_ASK_UNSPECIFIED {
		return nil, status.Error(xrayerrors.ErrCodeInvalidStageStatus, "stage analysis in 'UNSPECIFIED' state, something really shit happens")
	}

	// Conversion errors are deliberately ignored: the reply is best effort.
	updatedAt, _ := ptypes.TimestampProto(dbProject.UpdatedAt)
	result := &xrayrpc.StageGetResultsReply{
		AnalyzeId:         s.signAnalyzeID(dbProject.ID),
		UpdatedAt:         updatedAt,
		AnalysisStatus:    dbProject.Status,
		StageHealth:       stagehealth.Calculate(dbProject.Overview),
		StatusDescription: dbProject.StatusDescription,
		Results:           new(xrayrpc.AnalyzeResult),
	}

	if dbProject.LogPath != "" {
		result.LogUri = s.S3Storage.FileURI(dbProject.LogPath)
	}

	if dbProject.ResultPath != "" {
		// Best effort: a missing or unparsable results file is logged and the
		// reply is still returned.
		rawResult, err := s.S3Storage.DownloadFile(dbProject.ResultPath)
		if err != nil {
			s.Logger.Error("failed to download results", log.Error(err), log.String("path", dbProject.ResultPath))
		} else if err = proto.Unmarshal(rawResult, result.Results); err != nil {
			s.Logger.Error("failed to parse results", log.Error(err), log.String("path", dbProject.ResultPath))
		}
	}

	return result, nil
}

// fetchUserStages returns the UUIDs of up to limit stages on which the
// calling user has write permission, as reported by YP.
//
// It returns errNoAuthInfo when the context carries no auth info. It returns
// (nil, nil) when the caller has no user login or has access to no stages.
// On error the returned slice is always nil.
func (s *StageService) fetchUserStages(ctx context.Context, limit int32) ([]string, error) {
	authInfo := auth.ContextAuthInfo(ctx)
	if authInfo == nil {
		return nil, errNoAuthInfo
	}

	if authInfo.UserLogin == "" {
		return nil, nil
	}

	// Use the request context so that YP calls are cancelled together with
	// the incoming RPC instead of outliving it.
	userAccessRSP, err := s.YP.GetUserAccessAllowedTo(ctx, yp.UserAccessAllowedToRequest{
		Permissions: []yp.UserAccessAllowedToPermission{{
			UserID:     authInfo.UserLogin,
			Limit:      limit,
			ObjectType: yp.ObjectTypeStage,
			Permission: yp.AccessControlPermissionWrite,
		}},
	})
	if err != nil {
		return nil, fmt.Errorf("failed to get user %q stages: %w", authInfo.UserLogin, err)
	}

	// Exactly one permission was requested, so one result entry is expected.
	// Guard anyway rather than panicking on an unexpected empty response, and
	// skip the second YP round-trip when there is nothing to look up.
	if len(userAccessRSP.Objects) == 0 || len(userAccessRSP.Objects[0].ObjectIDs) == 0 {
		return nil, nil
	}

	selectStagesRSP, err := s.YP.GetObjects(ctx, yp.GetObjectsRequest{
		ObjectType: yp.ObjectTypeStage,
		ObjectIDs:  userAccessRSP.Objects[0].ObjectIDs,
		Format:     yp.PayloadFormatYson,
		Selectors:  []string{"/meta/uuid"},
	})
	if err != nil {
		return nil, fmt.Errorf("failed to get user %q stages info: %w", authInfo.UserLogin, err)
	}

	stageUUIDs := make([]string, 0, selectStagesRSP.Count())
	for selectStagesRSP.Next() {
		var stageUUID string
		// A malformed row must fail this request, not crash the server.
		if err := selectStagesRSP.Fill(&stageUUID); err != nil {
			return nil, fmt.Errorf("failed to parse user %q stages info: %w", authInfo.UserLogin, err)
		}
		stageUUIDs = append(stageUUIDs, stageUUID)
	}

	if err := selectStagesRSP.Error(); err != nil {
		return nil, fmt.Errorf("failed to read user %q stages info: %w", authInfo.UserLogin, err)
	}

	return stageUUIDs, nil
}

// checkStageAuth decides whether the caller may act on stageID.
//
// Admins and readers are always allowed. Everyone else must hold YP Write
// permission on the stage. It returns errNoAuthInfo when the context has no
// auth info and errAccessDenied when YP does not report an Allow action.
// A failed YP call is returned wrapped with the stage ID.
//
// NOTE(review): readers also pass this check, and it guards Schedule and
// ScheduleLatest as well as the read-only handlers. Confirm that readers are
// meant to be able to schedule analyses.
func (s *StageService) checkStageAuth(ctx context.Context, stageID string) error {
	info := auth.ContextAuthInfo(ctx)
	switch {
	case info == nil:
		return errNoAuthInfo
	case info.IsAdmin, info.IsReader:
		// Privileged callers skip the per-stage YP check.
		return nil
	}

	// TODO(buglloc): check permissions only by stage_id is incorrect!
	check := yp.CheckObjectPermission{
		ObjectID:   stageID,
		ObjectType: yp.ObjectTypeStage,
		Permission: yp.AccessControlPermissionWrite,
		SubjectID:  info.UserLogin,
	}
	rsp, err := s.YP.CheckObjectPermissions(ctx, yp.CheckObjectPermissionsRequest{
		Permissions: []yp.CheckObjectPermission{check},
	})
	if err != nil {
		return fmt.Errorf("failed to check user permissions for stage %q: %w", stageID, err)
	}

	if perms := rsp.Permissions; len(perms) > 0 && perms[0].Action == yp.AccessControlActionAllow {
		return nil
	}
	return errAccessDenied
}

// fetchStage reads the current /meta/uuid and /spec/revision of stageID from
// YP and returns them as an xrayrpc.Stage. Errors from the YP request are
// returned as-is. Failures to decode the payload are wrapped.
func (s *StageService) fetchStage(ctx context.Context, stageID string) (*xrayrpc.Stage, error) {
	query := yp.GetStageRequest{
		Format:    yp.PayloadFormatYson,
		ID:        stageID,
		Selectors: []string{"/meta/uuid", "/spec/revision"},
	}

	rsp, err := s.YP.GetStage(ctx, query)
	if err != nil {
		return nil, err
	}

	stage := &xrayrpc.Stage{Id: stageID}
	// Values are presumably filled in Selectors order (uuid, revision);
	// keep the two lists in sync.
	if fillErr := rsp.Fill(&stage.Uuid, &stage.Revision); fillErr != nil {
		return nil, fmt.Errorf("failed to get stage info: %w", fillErr)
	}

	return stage, nil
}

// signAnalyzeID returns id followed by ':' and its HMAC-SHA1 signature,
// keyed with Config.SignKey and encoded as padded standard base64.
func (s *StageService) signAnalyzeID(id string) string {
	signer := hmac.New(sha1.New, []byte(s.Config.SignKey))
	// hash.Hash.Write never returns an error.
	_, _ = signer.Write([]byte(id))
	return id + ":" + base64.StdEncoding.EncodeToString(signer.Sum(nil))
}
