package awacs

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"sort"
	"strings"
	"time"

	"a.yandex-team.ru/infra/awacs/clients/go/awacs"
	"a.yandex-team.ru/infra/temporal/clients/staff"
	"go.temporal.io/sdk/activity"
	"google.golang.org/protobuf/types/known/fieldmaskpb"
	"k8s.io/apimachinery/pkg/util/wait"

	awacspb "a.yandex-team.ru/infra/awacs/proto"
	"a.yandex-team.ru/library/go/slices"

	"a.yandex-team.ru/infra/temporal/workflows/startreker/processor"
)

// pollingInterval is the pause between consecutive awacs polls in the
// polling WaitUntil* activities of Activities.
const pollingInterval = time.Second * 10

// Timing parameters of the long-running WatchingActivities.
const (
	// WatchHeartbeatInterval is the base pause between two heartbeats of a
	// watching activity.
	WatchHeartbeatInterval = 3 * time.Minute
	// WatchHeartbeatIntervalJitter scales the random extra pause added on top
	// of WatchHeartbeatInterval.
	WatchHeartbeatIntervalJitter = 3 * time.Minute
	// PollIntervalInHeartbeats makes watching activities query awacs only on
	// every N-th heartbeat.
	PollIntervalInHeartbeats = 4
	// WatchHeartbeatTimeout is not referenced in this file.
	// NOTE(review): presumably the heartbeat timeout configured for watching
	// activities; the pause between heartbeats must stay below it — confirm.
	WatchHeartbeatTimeout = 8 * time.Minute
)

// clients bundles the external API clients shared by Activities and
// WatchingActivities; both embed it. Built by newClients.
type clients struct {
	// awacs is the awacs API client.
	awacs *client
	// staff is the Staff API client, used to resolve responsible logins.
	staff *staff.Client
}

// Activities holds the Temporal activities that read and modify awacs objects
// (balancers, certificates, backends) and resolve namespace responsibles.
// It embeds clients, so methods reach the API clients as a.awacs and a.staff.
type Activities struct {
	clients
}

// WatchingActivities holds the long-running watch activities that heartbeat
// on the Watch* schedule and only occasionally poll awacs. It embeds the same
// clients as Activities.
// NOTE(review): presumably a separate type so it can be registered with
// different activity options — confirm at the registration site.
type WatchingActivities struct {
	clients
}

// newClients builds the awacs and Staff clients shared by Activities and
// WatchingActivities. The awacs client gets a permanent-error check that
// classifies API errors with code 400 or 404 as permanent.
func newClients(awacsURL, awacsToken string, staffClientConfig *staff.ClientConfig) clients {
	permanentCodes := func(apiErr *awacs.APIError) bool {
		switch apiErr.Code {
		case 400, 404:
			return true
		default:
			return false
		}
	}
	return clients{
		awacs: NewClient(awacsURL, awacsToken, WithPermanentErrorCheck(permanentCodes)),
		staff: staff.NewClient(staffClientConfig),
	}
}

// NewActivities creates Activities whose awacs client is configured with
// awacsURL and awacsToken and whose Staff client is built from
// staffClientConfig.
func NewActivities(awacsURL, awacsToken string, staffClientConfig *staff.ClientConfig) *Activities {
	shared := newClients(awacsURL, awacsToken, staffClientConfig)
	return &Activities{clients: shared}
}

// NewWatchingActivities creates WatchingActivities whose awacs client is
// configured with awacsURL and awacsToken and whose Staff client is built from
// staffClientConfig.
func NewWatchingActivities(awacsURL, awacsToken string, staffClientConfig *staff.ClientConfig) *WatchingActivities {
	shared := newClients(awacsURL, awacsToken, staffClientConfig)
	return &WatchingActivities{clients: shared}
}

// isNotFound reports whether err is, or wraps, an *awacs.APIError with code 404.
func isNotFound(err error) bool {
	var apiErr *awacs.APIError
	if !errors.As(err, &apiErr) {
		return false
	}
	return apiErr.Code == http.StatusNotFound
}

// ListBalancersByNamespaceID returns the balancers of namespaceID as listed by
// awacs, without a field mask (nil is passed). On error the result is nil.
func (a *Activities) ListBalancersByNamespaceID(ctx context.Context, namespaceID string) ([]*awacspb.Balancer, error) {
	found, err := a.awacs.ListBalancersByNamespaceID(ctx, namespaceID, nil)
	if err == nil {
		return found, nil
	}
	return nil, err
}

// WaitUntilNamespaceIsUnpausedAndSettled performs one readiness check of the
// balancers in namespaceID. It does not loop itself; it returns an error
// describing the first reason the namespace is not settled:
//   - a balancer is transport-paused;
//   - a balancer has an in-progress revision;
//   - a status of some balancer state entry has a Validated condition other
//     than "True" or "False" (including a missing one).
//
// It returns nil when all checks pass, and any awacs error as is.
// NOTE(review): presumably the caller's retry policy turns the returned errors
// into a wait, as the name suggests — confirm.
func (a *Activities) WaitUntilNamespaceIsUnpausedAndSettled(ctx context.Context, namespaceID string) error {
	var messageType *awacspb.Balancer
	// Only meta and status are read in the first pass. The paths are fixed;
	// should fieldmaskpb.New ever reject them, m is nil — the same value
	// ListBalancersByNamespaceID passes when it wants no mask.
	m, _ := fieldmaskpb.New(messageType, "meta", "status")
	balancers, err := a.awacs.ListBalancersByNamespaceID(ctx, namespaceID, m)
	if err != nil {
		return err
	}

	// Cheap checks first: pause flag and revision statuses come with the listing.
	// Proto getters are used so a missing meta/status section cannot panic.
	for _, balancer := range balancers {
		balancerID := balancer.GetMeta().GetId()
		if balancer.GetMeta().GetTransportPaused().GetValue() {
			return fmt.Errorf("balancer %s/%s is paused", namespaceID, balancerID)
		}
		if len(findInProgressBalancerRevisionStatusMany(balancer.GetStatus().GetRevisions())) > 0 {
			return fmt.Errorf("balancer %s/%s is in progress", namespaceID, balancerID)
		}
	}

	// Then one state request per balancer: every status must be validated.
	for _, balancer := range balancers {
		state, err := a.awacs.GetBalancerState(ctx, namespaceID, balancer.GetMeta().GetId())
		if err != nil {
			return err
		}
		for _, entry := range listBalancerStateEntries(state) {
			for _, status := range entry.Statuses {
				// GetStatus is nil-safe: a status without a Validated
				// condition yet yields "" and counts as not validated.
				if s := status.GetValidated().GetStatus(); s != "True" && s != "False" {
					return fmt.Errorf("%s %s is not validated yet",
						BalancerStateEntryTypeNames[entry.Type], entry.FlatID)
				}
			}
		}
	}
	return nil
}

// WaitUntilBalancerActiveRevCtimeIsGTE polls balancer namespaceID/balancerID
// every pollingInterval until its active revision has a creation time (in
// seconds) >= timestamp, and returns that revision's ID.
//
// Errors classified by isPermanent are returned. Other errors are retried
// after the regular wait, so the loop never spins without pausing. The loop
// heartbeats after every wait with the iteration counter. If ctx is done
// first, it returns ("", nil).
func (a *Activities) WaitUntilBalancerActiveRevCtimeIsGTE(ctx context.Context, namespaceID, balancerID string, timestamp int64) (string, error) {
	for i := 0; ; i++ {
		balancer, err := a.awacs.GetBalancer(ctx, namespaceID, balancerID)
		switch {
		case err == nil:
			activeRev := findActiveBalancerRevisionStatus(balancer.GetStatus().GetRevisions())
			if activeRev != nil && activeRev.Ctime.GetSeconds() >= timestamp {
				return activeRev.Id, nil
			}
		case isPermanent(err):
			return "", err
		default:
			// Transient error: retry after the wait below. Retrying
			// immediately would hammer awacs without pause. It would also
			// never reach the ctx check, e.g. if every call fails after
			// cancellation.
		}

		select {
		case <-time.After(pollingInterval):
			activity.RecordHeartbeat(ctx, i)
		case <-ctx.Done():
			return "", nil
		}
	}
}

// GetCertificate returns certificate certID of namespaceID as awacs reports it.
func (a *Activities) GetCertificate(ctx context.Context, namespaceID, certID string) (*awacspb.Certificate, error) {
	return a.awacs.GetCertificate(ctx, namespaceID, certID)
}

// GetCertificateRenewal returns the renewal object of certificate certID in
// namespaceID as awacs reports it.
func (a *Activities) GetCertificateRenewal(ctx context.Context, namespaceID, certID string) (*awacspb.CertificateRenewal, error) {
	return a.awacs.GetCertificateRenewal(ctx, namespaceID, certID)
}

// ListCertificateRenewals lists the certificate renewals matching query. It
// requests ascending order by the certificate validity "not after" time.
func (a *Activities) ListCertificateRenewals(ctx context.Context, query *awacspb.ListCertificateRenewalsRequest_Query) ([]*awacspb.CertificateRenewal, error) {
	return a.awacs.ListCertificateRenewals(ctx, &awacspb.ListCertificateRenewalsRequest{
		SortTarget: awacspb.ListCertificateRenewalsRequest_TARGET_CERT_VALIDITY_NOT_AFTER,
		SortOrder:  awacspb.SortOrder_ASCEND,
		Query:      query,
	})
}

// UnpauseCertificateRenewal sets the target discoverability of the renewal of
// certificate certID to targetDiscoverability. It passes the renewal's
// current meta version, which it reads first.
func (a *Activities) UnpauseCertificateRenewal(ctx context.Context, namespaceID, certID string, targetDiscoverability *awacspb.DiscoverabilityCondition) error {
	renewal, err := a.awacs.GetCertificateRenewal(ctx, namespaceID, certID)
	if err != nil {
		return err
	}
	version := renewal.GetMeta().Version
	return a.awacs.SetCertificateRenewalTargetDiscoverability(ctx, namespaceID, certID, version, targetDiscoverability)
}

// WaitUntilCertificateIsRenewed polls every pollingInterval until certificate
// certID and its renewal carry the same non-empty serial number. It then
// returns the certificate's meta version.
//
// Errors classified by isPermanent are returned. Other errors are retried
// after the regular wait instead of immediately. If ctx is done first, it
// returns ("", nil).
func (a *Activities) WaitUntilCertificateIsRenewed(ctx context.Context, namespaceID, certID string) (string, error) {
	// check performs one round of requests. It returns (version, true, nil)
	// once renewed, and ("", false, nil) when not yet renewed or on a
	// transient error.
	check := func() (string, bool, error) {
		cert, err := a.awacs.GetCertificate(ctx, namespaceID, certID)
		if err != nil {
			if isPermanent(err) {
				return "", false, err
			}
			return "", false, nil
		}
		certRenewal, err := a.awacs.GetCertificateRenewal(ctx, namespaceID, certID)
		if err != nil {
			if isPermanent(err) {
				return "", false, err
			}
			return "", false, nil
		}
		// An empty serial presumably means the fields are not populated yet.
		// Two empty serials must not count as "renewed".
		serial := cert.GetSpec().GetFields().GetSerialNumber()
		if serial == "" || serial != certRenewal.GetSpec().GetFields().GetSerialNumber() {
			return "", false, nil
		}
		return cert.GetMeta().GetVersion(), true, nil
	}

	for i := 0; ; i++ {
		version, renewed, err := check()
		if err != nil {
			return "", err
		}
		if renewed {
			return version, nil
		}

		select {
		case <-time.After(pollingInterval):
			activity.RecordHeartbeat(ctx, i)
		case <-ctx.Done():
			return "", nil
		}
	}
}

// ExtendCertificateDiscoverability makes certificate certID discoverable in
// location and stores the updated meta via SetCertificateMeta.
//
// If the discoverability has no per-location condition yet, it is replaced by
// one that enables only location. Otherwise location is added to (or
// overwritten in) the existing per-location map. It returns an error if the
// certificate has no discoverability at all.
func (a *Activities) ExtendCertificateDiscoverability(ctx context.Context, namespaceID, certID, location string) error {
	cert, err := a.awacs.GetCertificate(ctx, namespaceID, certID)
	if err != nil {
		return err
	}
	meta := cert.GetMeta()
	discoverability := meta.GetDiscoverability()
	if discoverability == nil {
		return fmt.Errorf("certificate %s/%s does not have discoverability set", namespaceID, certID)
	}

	if perLocation := discoverability.GetPerLocation(); perLocation != nil {
		// A per-location condition may exist with an empty (nil) map; writing
		// into a nil map would panic, so allocate it first.
		if perLocation.Values == nil {
			perLocation.Values = make(map[string]*awacspb.BoolCondition)
		}
		perLocation.Values[location] = &awacspb.BoolCondition{Value: true}
	} else {
		discoverability.Kind = &awacspb.DiscoverabilityCondition_PerLocation_{
			PerLocation: &awacspb.DiscoverabilityCondition_PerLocation{
				Values: map[string]*awacspb.BoolCondition{
					location: {Value: true},
				},
			},
		}
	}
	return a.awacs.SetCertificateMeta(ctx, namespaceID, certID, meta)
}

// WaitUntilCertificateIsActiveInBalancers polls every pollingInterval until
// revision revID of certificate certID has an Active condition "True" in every
// balancer listed in balancerIds.
//
// Errors classified by isPermanent are returned. Transient errors and a
// revision that awacs does not report yet are retried after the regular wait;
// neither is retried without pausing. If ctx is done first, it returns nil.
func (a *Activities) WaitUntilCertificateIsActiveInBalancers(ctx context.Context, namespaceID, certID, revID string, balancerIds []string) error {
	// activeEverywhere performs one check. It returns false, nil while the
	// rollout is incomplete or on a transient error.
	activeEverywhere := func() (bool, error) {
		cert, err := a.awacs.GetCertificate(ctx, namespaceID, certID)
		if err != nil {
			if isPermanent(err) {
				return false, err
			}
			return false, nil
		}

		perBalancerRevStatus := findCertificateRevisionStatusPerBalancerByID(cert.GetStatuses(), revID)
		if perBalancerRevStatus == nil {
			// Probably just not discovered yet, keep polling.
			return false, nil
		}

		activatedBalancerIds := make([]string, 0)
		for flatBalancerID, activeStatus := range perBalancerRevStatus.GetActive() {
			if activeStatus.GetStatus() != "True" {
				continue
			}
			// The balancer ID is the second ':'-separated part of the flat ID.
			// IDs without a separator are skipped instead of panicking.
			if parts := strings.Split(flatBalancerID, ":"); len(parts) > 1 {
				activatedBalancerIds = append(activatedBalancerIds, parts[1])
			}
		}
		return slices.ContainsAllStrings(activatedBalancerIds, balancerIds), nil
	}

	for i := 0; ; i++ {
		done, err := activeEverywhere()
		if err != nil {
			return err
		}
		if done {
			return nil
		}

		select {
		case <-time.After(pollingInterval):
			activity.RecordHeartbeat(ctx, i)
		case <-ctx.Done():
			return nil
		}
	}
}

// MakeCertificateDiscoverableByDefault sets the default discoverability of
// certificate certID to true and stores the meta via SetCertificateMeta. It
// returns an error if the certificate has no discoverability at all.
func (a *Activities) MakeCertificateDiscoverableByDefault(ctx context.Context, namespaceID, certID string) error {
	cert, err := a.awacs.GetCertificate(ctx, namespaceID, certID)
	if err != nil {
		return err
	}
	if cert.GetMeta().GetDiscoverability() == nil {
		return fmt.Errorf("certificate %s/%s does not have discoverability set", namespaceID, certID)
	}
	cert.GetMeta().GetDiscoverability().Default = &awacspb.BoolCondition{Value: true}
	return a.awacs.SetCertificateMeta(ctx, namespaceID, certID, cert.GetMeta())
}

// PauseBalancer asks awacs to pause balancer namespaceID/balancerID.
func (a *Activities) PauseBalancer(ctx context.Context, namespaceID, balancerID string) error {
	if err := a.awacs.PauseBalancer(ctx, namespaceID, balancerID); err != nil {
		return err
	}
	return nil
}

// UnpauseBalancer asks awacs to unpause balancer namespaceID/balancerID.
func (a *Activities) UnpauseBalancer(ctx context.Context, namespaceID, balancerID string) error {
	if err := a.awacs.UnpauseBalancer(ctx, namespaceID, balancerID); err != nil {
		return err
	}
	return nil
}

// GetBalancer returns balancer namespaceID/balancerID. On error the result is nil.
func (a *Activities) GetBalancer(ctx context.Context, namespaceID, balancerID string) (*awacspb.Balancer, error) {
	balancer, err := a.awacs.GetBalancer(ctx, namespaceID, balancerID)
	if err == nil {
		return balancer, nil
	}
	return nil, err
}

// ListBalancers returns all balancers of namespaceID. A 404 from awacs is
// turned into a permanent error ("namespace does not exist"). Any other
// error is wrapped with the namespace ID.
func (a *Activities) ListBalancers(ctx context.Context, namespaceID string) ([]*awacspb.Balancer, error) {
	balancers, err := a.awacs.ListBalancersByNamespaceID(ctx, namespaceID, nil)
	switch {
	case err == nil:
		return balancers, nil
	case isNotFound(err):
		return nil, NewPermanentError(fmt.Sprintf("namespace %s does not exist", namespaceID), err)
	default:
		return nil, fmt.Errorf("failed to list balancers in namespace %s: %w", namespaceID, err)
	}
}

// ListBackends returns all backends of namespaceID. On error the result is nil.
func (a *Activities) ListBackends(ctx context.Context, namespaceID string) ([]*awacspb.Backend, error) {
	list, err := a.awacs.ListBackends(ctx, namespaceID)
	if err == nil {
		return list, nil
	}
	return nil, err
}

// ListBackendRevisions returns one page of revisions of backend backendID in
// namespaceID. skip and limit are passed through to the awacs client for
// paging. On error the result is nil.
func (a *Activities) ListBackendRevisions(
	ctx context.Context,
	namespaceID string,
	backendID string,
	skip int32,
	limit int32) (*awacspb.ListBackendRevisionsResponse, error) {

	page, err := a.awacs.ListBackendRevisions(ctx, namespaceID, backendID, skip, limit)
	if err == nil {
		return page, nil
	}
	return nil, err
}

// InclusionGraph is a JSON-decodable list of balancer entries. Each entry
// names a balancer (JSON "id") in a namespace and lists the IDs of the
// backends, domains and upstreams it includes. The include lists are omitted
// from encoded JSON when empty.
// NOTE(review): not used in this part of the file; presumably mirrors an
// awacs inclusion-graph JSON payload — confirm the source.
type InclusionGraph []struct {
	BalancerID          string   `json:"id"`
	NamespaceID         string   `json:"namespace_id"`
	Type                string   `json:"type"`
	IncludedBackendIds  []string `json:"included_backend_ids,omitempty"`
	IncludedDomainIds   []string `json:"included_domain_ids,omitempty"`
	IncludedUpstreamIds []string `json:"included_upstream_ids,omitempty"`
}

// GetNamespaceAspectsSet returns the awacs aspects set of namespaceID.
// On error the result is nil.
func (a *Activities) GetNamespaceAspectsSet(ctx context.Context, namespaceID string) (*awacspb.GetNamespaceAspectsSetResponse, error) {
	aspects, err := a.awacs.GetNamespaceAspectsSet(ctx, namespaceID)
	if err == nil {
		return aspects, nil
	}
	return nil, err
}

// GetLoadStatisticsEntry returns the awacs load-statistics entry requested for
// instant at. On error the result is nil.
func (a *Activities) GetLoadStatisticsEntry(ctx context.Context, at time.Time) (*awacspb.GetLoadStatisticsEntryResponse, error) {
	// The parameter is named at, not time, so it no longer shadows the time package.
	rsp, err := a.awacs.GetLoadStatisticsEntry(ctx, at)
	if err != nil {
		return nil, err
	}
	return rsp, nil
}

// NamespaceRps pairs a namespace ID with an RPS value.
// GetYesterdayMaxRpsStatsByNamespace fills Rps with the per-namespace Max
// from the awacs load statistics.
type NamespaceRps struct {
	NamespaceID string
	Rps         float32
}

// GetYesterdayMaxRpsStatsByNamespace returns the Max RPS of every namespace
// from the load-statistics entry requested for 00:00 UTC of the current day,
// sorted by namespace ID. An entry without per-namespace data yields an empty
// slice.
// NOTE(review): presumably the entry keyed at today's midnight holds
// yesterday's statistics, as the name suggests — confirm.
func (a *Activities) GetYesterdayMaxRpsStatsByNamespace(ctx context.Context) ([]NamespaceRps, error) {
	// Truncating a UTC time to 24h rounds down to midnight UTC.
	dayStart := time.Now().UTC().Truncate(24 * time.Hour)
	rsp, err := a.awacs.GetLoadStatisticsEntry(ctx, dayStart)
	if err != nil {
		return nil, err
	}

	// Getters keep a response with a missing entry/date statistics from panicking.
	byNamespace := rsp.GetEntry().GetDateStatistics().GetByNamespace()
	// Sized by the map actually iterated (previously sized by ByBalancer).
	res := make([]NamespaceRps, 0, len(byNamespace))
	for namespaceID, entryPb := range byNamespace {
		res = append(res, NamespaceRps{NamespaceID: namespaceID, Rps: entryPb.GetMax()})
	}

	sort.Slice(res, func(i, j int) bool { return res[i].NamespaceID < res[j].NamespaceID })
	return res, nil
}

// UpdateBackend sends specPb as the new spec of backend backendID in
// namespaceID. version and comment are passed through to the awacs client.
// NOTE(review): presumably the expected current version and a change
// comment — confirm.
func (a *Activities) UpdateBackend(
	ctx context.Context, specPb *awacspb.BackendSpec,
	namespaceID, backendID, version, comment string) (*awacspb.UpdateBackendResponse, error) {

	return a.awacs.UpdateBackend(ctx, specPb, namespaceID, backendID, version, comment)
}

// DoesEndpointsetExist asks the service-discovery resolver at sd.yandex.net
// whether endpoint set endpointSetID exists in cluster. The cluster name is
// sent lower-cased.
//
// It reports true iff the JSON response's "resolve_status" equals 2.
// NOTE(review): presumably the resolver's "OK" status — confirm against the SD
// API.
//
// It returns an error on transport failures, a non-2xx HTTP status, a body
// that is not a JSON object, or a response without "resolve_status".
func (a *Activities) DoesEndpointsetExist(ctx context.Context, cluster, endpointSetID string) (bool, error) {
	const url = "http://sd.yandex.net:8080/resolve_endpoints/json"

	body, err := json.Marshal(map[string]string{
		"endpoint_set_id": endpointSetID,
		"client_name":     "awacssdtool-temporal",
		"cluster_name":    strings.ToLower(cluster),
	})
	if err != nil {
		return false, err
	}

	// ctx bounds the request; http.DefaultClient itself has no timeout.
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return false, err
	}
	req.Header.Set("Content-Type", "application/json")

	rsp, err := http.DefaultClient.Do(req)
	if err != nil {
		return false, err
	}
	defer func() {
		if err := rsp.Body.Close(); err != nil {
			log.Println(err)
		}
	}()

	b, err := ioutil.ReadAll(rsp.Body)
	if err != nil {
		return false, err
	}
	// Reject error responses explicitly instead of trying to parse an error page.
	if rsp.StatusCode < 200 || rsp.StatusCode > 299 {
		return false, fmt.Errorf("resolve endpoint set %s in cluster %s: unexpected HTTP status %s",
			endpointSetID, cluster, rsp.Status)
	}

	// A pointer distinguishes a missing (or null) resolve_status from 0.
	var parsed struct {
		ResolveStatus *int `json:"resolve_status"`
	}
	if err := json.Unmarshal(b, &parsed); err != nil {
		return false, err
	}
	if parsed.ResolveStatus == nil {
		return false, fmt.Errorf("resolve endpoint set %s in cluster %s: response has no resolve_status",
			endpointSetID, cluster)
	}
	return *parsed.ResolveStatus == 2, nil
}

// WaitUntilBackendsRevisionsStatusesActive polls the backends of namespaceID
// every 10 seconds. It stops once, for every version in backendRevisions, the
// backend whose current version equals it reports an Active condition "True"
// in at least unpausedBalancers balancers.
//
// It returns an error when awacs fails with an error classified by
// isPermanent, or when a version in backendRevisions matches no backend.
// Transient awacs errors are logged and retried after the regular wait, not
// retried without pausing. If ctx is done first, it returns nil.
func (a *Activities) WaitUntilBackendsRevisionsStatusesActive(
	ctx context.Context, namespaceID string, backendRevisions []string, unpausedBalancers uint) error {

	const heartbeatInterval = 10 * time.Second

	logger := activity.GetLogger(ctx)

	// allActive checks one snapshot of backends. It returns (false, nil) while
	// some revision is still rolling out, and an error if a revision matches
	// no backend.
	allActive := func(backends []*awacspb.Backend) (bool, error) {
		for _, backendRev := range backendRevisions {
			backendFound := false
			for _, backend := range backends {
				if backendRev != backend.Meta.Version {
					continue
				}
				backendFound = true

				// NOTE(review): a matching backend with no status entry for
				// backendRev passes this check — confirm that is intended.
				for _, status := range backend.Statuses {
					if backendRev != status.Id {
						continue
					}
					if status.Active == nil {
						logger.Info(fmt.Sprintf("backend %s, revision %s status.Active == nil",
							backend.Meta.Id, status.Id))
						return false, nil
					}

					cnt := uint(0)
					for _, condition := range status.Active {
						if condition.Status == "True" {
							cnt++
						}
					}
					if cnt < unpausedBalancers {
						logger.Info(fmt.Sprintf("backend %v, revision %v, balancers with Active status %v < %v",
							backend.Meta.Id, backendRev, cnt, unpausedBalancers))
						return false, nil
					}
				}
			}
			if !backendFound {
				return false, fmt.Errorf("backend revision %s not found", backendRev)
			}
		}
		return true, nil
	}

	for {
		backends, err := a.awacs.ListBackends(ctx, namespaceID)
		if err != nil {
			if isPermanent(err) {
				return err
			}
			// Transient error: wait and retry below instead of re-issuing the
			// call immediately, which would skip both the pause and the ctx check.
			logger.Warn(fmt.Sprintf("failed to list backends in namespace %s, will retry: %v", namespaceID, err))
		} else {
			ok, err := allActive(backends)
			if err != nil {
				return err
			}
			if ok {
				return nil
			}
		}

		select {
		case <-time.After(heartbeatInterval):
			activity.RecordHeartbeat(ctx)
		case <-ctx.Done():
			return nil
		}
	}
}

// ListNamespaceIds returns the IDs of all namespaces reported by the awacs
// summaries listing, in listing order.
func (a *Activities) ListNamespaceIds(ctx context.Context) ([]string, error) {
	summaries, err := a.awacs.ListSummaries(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to list namespace summaries: %w", err)
	}
	ids := make([]string, 0, len(summaries))
	for _, s := range summaries {
		ids = append(ids, s.Id)
	}
	return ids, nil
}

// ExtractResponsible determines who is responsible for namespaceID.
//
// If the namespace has enabled observers of the on-duty kind, it returns the
// ABC schedule ID. If the observers are staff persons, their actual logins
// (per staff.FilterActualPersonsLogins) are used. If that yields no logins,
// the namespace owners from extractNamespaceOwners are used instead.
func (a *Activities) ExtractResponsible(ctx context.Context, namespaceID string) (*processor.Responsible, error) {
	namespace, err := a.awacs.GetNamespace(ctx, namespaceID)
	if err != nil {
		return nil, fmt.Errorf("failed to get namespace %s: %w", namespaceID, err)
	}

	result := &processor.Responsible{}
	if observers := namespace.Meta.GetObservers(); observers != nil && observers.Enabled {
		switch subject := observers.Subject.(type) {
		case *awacspb.Observers_OnDuty_:
			result.Kind = processor.AbcSchedule
			result.AbcScheduleID = int(subject.OnDuty.AbcScheduleId)
			return result, nil
		case *awacspb.Observers_Persons_:
			// Copy so the staff client never sees the proto's own slice.
			if logins := append([]string(nil), subject.Persons.StaffLogins...); len(logins) > 0 {
				persons, err := a.staff.ListPersons(logins)
				if err != nil {
					return nil, fmt.Errorf("failed to list persons: %w", err)
				}
				result.Logins = staff.FilterActualPersonsLogins(persons)
			}
		}
	}

	result.Kind = processor.Logins
	if len(result.Logins) > 0 {
		return result, nil
	}
	owners, err := a.extractNamespaceOwners(namespace)
	if err != nil {
		return nil, err
	}
	result.Logins = owners
	return result, nil
}

// extractNamespaceOwners returns the actual logins (per
// staff.FilterActualPersonsLogins) of the owners of namespace.
//
// Owner logins are tried first, then members of the owner groups. If neither
// yields anyone, the hard-coded fallback login "ferenets" is returned. Missing
// meta/auth/staff/owners sections are treated as empty instead of panicking.
// Staff API errors are wrapped and returned.
func (a *Activities) extractNamespaceOwners(namespace *awacspb.Namespace) ([]string, error) {
	owners := namespace.GetMeta().GetAuth().GetStaff().GetOwners()

	// Copies keep the staff client from touching the proto's own slices.
	if logins := append([]string(nil), owners.GetLogins()...); len(logins) > 0 {
		persons, err := a.staff.ListPersons(logins)
		if err != nil {
			return nil, fmt.Errorf("failed to list persons: %w", err)
		}
		if actual := staff.FilterActualPersonsLogins(persons); len(actual) > 0 {
			return actual, nil
		}
	}

	// There is no real login in owners, let's check groups.
	var ownersLogins []string
	if groupIDs := append([]string(nil), owners.GetGroupIds()...); len(groupIDs) > 0 {
		persons, err := a.staff.ListGroupsMembers(groupIDs)
		if err != nil {
			return nil, fmt.Errorf("failed to list group members: %w", err)
		}
		ownersLogins = staff.FilterActualPersonsLogins(persons)
	}

	if len(ownersLogins) == 0 {
		// !!! No real owners for namespace: fall back to a fixed login.
		return []string{"ferenets"}, nil
	}
	return ownersLogins, nil
}

// ExtractResponsibleLogins fetches namespaceID and returns its owner logins as
// computed by extractNamespaceOwners.
func (a *Activities) ExtractResponsibleLogins(ctx context.Context, namespaceID string) ([]string, error) {
	ns, err := a.awacs.GetNamespace(ctx, namespaceID)
	if err == nil {
		return a.extractNamespaceOwners(ns)
	}
	return nil, fmt.Errorf("failed to get namespace %s: %w", namespaceID, err)
}

// WaitUntilStuckBalancerIDsChange blocks until the set of stuck balancers in
// namespaceID differs (ignoring order) from currStuckBalancerIDs. It then
// returns the new set, sorted.
//
// It heartbeats after pauses of WatchHeartbeatInterval plus a random extra of
// up to WatchHeartbeatIntervalJitter. It queries awacs on every
// PollIntervalInHeartbeats-th heartbeat, starting with the first. Balancers
// already in currStuckBalancerIDs are passed to isBalancerStuck with a nil
// state, because their state is not fetched. Any awacs error is returned as
// is. If ctx is done, it returns (nil, nil).
func (a *WatchingActivities) WaitUntilStuckBalancerIDsChange(ctx context.Context, namespaceID string, currStuckBalancerIDs []string) ([]string, error) {
	// wait.Jitter(d, f) returns a value in [d, d+f*d), so it already includes d.
	// This yields [WatchHeartbeatInterval, WatchHeartbeatInterval+WatchHeartbeatIntervalJitter).
	// Adding WatchHeartbeatInterval on top of wait.Jitter(WatchHeartbeatIntervalJitter, 1.0)
	// would give 6-9 minutes with the current constants, which can exceed
	// WatchHeartbeatTimeout (8m).
	jitterFactor := float64(WatchHeartbeatIntervalJitter) / float64(WatchHeartbeatInterval)
	for i := 0; ; i++ {
		select {
		case <-ctx.Done():
			return nil, nil
		case <-time.After(wait.Jitter(WatchHeartbeatInterval, jitterFactor)):
			activity.RecordHeartbeat(ctx)
			if i%PollIntervalInHeartbeats != 0 {
				continue
			}
		}
		balancers, err := a.awacs.ListBalancersByNamespaceID(ctx, namespaceID, nil)
		if err != nil {
			return nil, err
		}

		// Find stuck balancers
		stuckBalancerIds := make([]string, 0)
		for _, balancer := range balancers {
			hasBeenStuck := slices.ContainsString(currStuckBalancerIDs, balancer.Meta.Id)
			var state *awacspb.BalancerState
			if !hasBeenStuck {
				state, err = a.awacs.GetBalancerState(ctx, namespaceID, balancer.Meta.Id)
				if err != nil {
					return nil, err
				}
			}
			if isBalancerStuck(balancer, hasBeenStuck, state) {
				stuckBalancerIds = append(stuckBalancerIds, balancer.Meta.Id)
			}
		}

		if !slices.EqualAnyOrderStrings(currStuckBalancerIDs, stuckBalancerIds) {
			sort.Strings(stuckBalancerIds)
			return stuckBalancerIds, nil
		}
	}
}

// WaitUntilStuckL3BalancerIDsChange blocks until the set of stuck L3 balancers
// in namespaceID (per isL3BalancerStuck) differs, ignoring order, from
// currStuckL3BalancerIDs. It then returns the new set, sorted.
//
// It heartbeats after pauses of WatchHeartbeatInterval plus a random extra of
// up to WatchHeartbeatIntervalJitter. It queries awacs on every
// PollIntervalInHeartbeats-th heartbeat, starting with the first. Any awacs
// error is returned as is. If ctx is done, it returns (nil, nil).
func (a *WatchingActivities) WaitUntilStuckL3BalancerIDsChange(ctx context.Context, namespaceID string, currStuckL3BalancerIDs []string) ([]string, error) {
	// wait.Jitter(d, f) returns a value in [d, d+f*d), so it already includes d.
	// This yields [WatchHeartbeatInterval, WatchHeartbeatInterval+WatchHeartbeatIntervalJitter).
	// Adding WatchHeartbeatInterval on top of wait.Jitter(WatchHeartbeatIntervalJitter, 1.0)
	// would give 6-9 minutes with the current constants, which can exceed
	// WatchHeartbeatTimeout (8m).
	jitterFactor := float64(WatchHeartbeatIntervalJitter) / float64(WatchHeartbeatInterval)
	for i := 0; ; i++ {
		select {
		case <-ctx.Done():
			return nil, nil
		case <-time.After(wait.Jitter(WatchHeartbeatInterval, jitterFactor)):
			activity.RecordHeartbeat(ctx)
			if i%PollIntervalInHeartbeats != 0 {
				continue
			}
		}
		l3Balancers, err := a.awacs.ListL3BalancersByNamespaceID(ctx, namespaceID, nil)
		if err != nil {
			return nil, err
		}

		// Find stuck balancers
		stuckL3BalancerIds := make([]string, 0)
		for _, l3Balancer := range l3Balancers {
			if isL3BalancerStuck(l3Balancer) {
				stuckL3BalancerIds = append(stuckL3BalancerIds, l3Balancer.Meta.Id)
			}
		}

		if !slices.EqualAnyOrderStrings(currStuckL3BalancerIDs, stuckL3BalancerIds) {
			sort.Strings(stuckL3BalancerIds)
			return stuckL3BalancerIds, nil
		}
	}
}

// WaitUntilNamespaceRemoved blocks until awacs reports namespaceID as not
// found (HTTP 404), and then returns true.
//
// It heartbeats after pauses of WatchHeartbeatInterval plus a random extra of
// up to WatchHeartbeatIntervalJitter. It queries awacs on every
// PollIntervalInHeartbeats-th heartbeat, starting with the first. Any other
// awacs error is returned. If ctx is done, it returns (false, nil).
func (a *WatchingActivities) WaitUntilNamespaceRemoved(ctx context.Context, namespaceID string) (bool, error) {
	// wait.Jitter(d, f) returns a value in [d, d+f*d), so it already includes d.
	// This yields [WatchHeartbeatInterval, WatchHeartbeatInterval+WatchHeartbeatIntervalJitter).
	// Adding WatchHeartbeatInterval on top of wait.Jitter(WatchHeartbeatIntervalJitter, 1.0)
	// would give 6-9 minutes with the current constants, which can exceed
	// WatchHeartbeatTimeout (8m).
	jitterFactor := float64(WatchHeartbeatIntervalJitter) / float64(WatchHeartbeatInterval)
	for i := 0; ; i++ {
		select {
		case <-ctx.Done():
			return false, nil
		case <-time.After(wait.Jitter(WatchHeartbeatInterval, jitterFactor)):
			activity.RecordHeartbeat(ctx)
			if i%PollIntervalInHeartbeats != 0 {
				continue
			}
		}
		if _, err := a.awacs.GetNamespace(ctx, namespaceID); err != nil {
			if isNotFound(err) {
				return true, nil
			}
			return false, err
		}
	}
}
