package main

import (
	"strings"

	"github.com/golang/protobuf/proto"
	"github.com/golang/protobuf/protoc-gen-go/descriptor"
	"github.com/pkg/errors"

	"code.justin.tv/eventbus/schema/cmd/internal/util"
	"code.justin.tv/eventbus/schema/pkg/eventbus/authorization"
)

// DataClassification is the sensitivity level declared on a field through the
// authorization.E_DataClassification field option. The numeric values are
// stable: Missing=0, Customer=1, Restricted=2, Internal=3, Public=4, Invalid=5.
type DataClassification int

// idSuffix is the field-name suffix that marks an identifier field; the
// linter requires such fields to be string-typed.
const idSuffix = "_id"

const (
	// DataClassificationMissing means the field has no usable data
	// classification annotation (value 0).
	DataClassificationMissing DataClassification = iota
	// DataClassificationCustomer means the field is annotated "customer" (value 1).
	DataClassificationCustomer
	// DataClassificationRestricted means the field is annotated "restricted" (value 2).
	DataClassificationRestricted
	// DataClassificationInternal means the field is annotated "internal" (value 3).
	DataClassificationInternal
	// DataClassificationPublic means the field is annotated "public" (value 4).
	DataClassificationPublic
	// DataClassificationInvalid means an annotation is present but is not one
	// of the recognized values (value 5).
	DataClassificationInvalid
)

// forbiddenFieldNameText lists substrings that may not appear anywhere in a
// field name.
var forbiddenFieldNameText = []string{"channel_id", "broadcaster_id"}

// lintFields applies the field-level lint rules to every field returned by
// fields(file) and returns the first violation found, or nil if none.
//
// Rules, in the order they are checked for each field:
//   - the name must not contain any entry of forbiddenFieldNameText;
//   - a field named "id" or ending in idSuffix must be string-like (see fieldIsString);
//   - outside the "clock" package, a google.protobuf.Timestamp field must end in "_at";
//   - the field's data classification annotation must be consistent with
//     util.IsAuthorizedField(field) (see validateDataClassification).
//
// Every error names the offending field so schema authors can locate it.
//
// Note that error messages returned should not have a newline in them
// because of the way prototool parses the output of the plugin.
func lintFields(file *descriptor.FileDescriptorProto) error {
	// The "clock" package is exempt from the timestamp naming rule; this does
	// not depend on the field, so decide it once.
	checkTimestampNames := file.GetPackage() != "clock"

	for _, field := range fields(file) {
		name := field.GetName()

		if isForbiddenFieldName(name) {
			return errors.Errorf("forbidden field name %q", name)
		}

		if hasIDSuffix(name) && !fieldIsString(field) {
			return errors.Errorf("field %q: %s field names must be of type string", name, idSuffix)
		}

		if checkTimestampNames && fieldIsTimestamp(field) && !strings.HasSuffix(name, "_at") {
			return errors.Errorf("%s is a timestamp field which should end in '_at'", name)
		}

		if err := validateDataClassification(fieldDataClassification(field), util.IsAuthorizedField(field)); err != nil {
			return errors.Wrapf(err, "invalid data classification for field %q", name)
		}
	}

	return nil
}

// fields returns every field declared in file's messages, including the
// fields of nested message types at any depth, in depth-first order (a
// message's own fields come before those of its nested types).
//
// Map-entry messages are skipped. These are the key/value messages that protoc
// synthesizes for map<K, V> fields. The map field itself is still returned,
// but the generated "key"/"value" fields are not. Their names are chosen by
// protoc rather than by the schema author, so they would trip naming rules
// such as the "_at" timestamp suffix.
func fields(file *descriptor.FileDescriptorProto) []*descriptor.FieldDescriptorProto {
	var all []*descriptor.FieldDescriptorProto

	var collect func(messages []*descriptor.DescriptorProto)
	collect = func(messages []*descriptor.DescriptorProto) {
		for _, message := range messages {
			if message.GetOptions().GetMapEntry() {
				continue
			}
			all = append(all, message.GetField()...)
			collect(message.GetNestedType())
		}
	}
	collect(file.GetMessageType())

	return all
}

// isForbiddenFieldName reports whether name contains any entry of
// forbiddenFieldNameText as a substring anywhere in it, not just as an exact
// match. For example, "from_channel_id" is forbidden because it contains
// "channel_id".
func isForbiddenFieldName(name string) bool {
	forbidden := false
	for i := 0; i < len(forbiddenFieldNameText) && !forbidden; i++ {
		forbidden = strings.Contains(name, forbiddenFieldNameText[i])
	}
	return forbidden
}

// fieldIsTimestamp reports whether field is a message-typed field whose fully
// qualified type name ends in ".google.protobuf.Timestamp".
func fieldIsTimestamp(field *descriptor.FieldDescriptorProto) bool {
	if field.GetType() != descriptor.FieldDescriptorProto_TYPE_MESSAGE {
		return false
	}
	return strings.HasSuffix(field.GetTypeName(), ".google.protobuf.Timestamp")
}

// hasIDSuffix reports whether name identifies an ID field: either the bare
// name "id" or any name ending in idSuffix.
func hasIDSuffix(name string) bool {
	if name == "id" {
		return true
	}
	return strings.HasSuffix(name, idSuffix)
}

// fieldIsString reports whether field holds a string, either directly (proto
// type string) or through one of the message types accepted in its place:
//   - eventbus change.StringChange and change.StringValueChange
//   - eventbus authorization.String
//   - google.protobuf.StringValue
//
// Every suffix starts with '.' so that it matches whole name components only.
// This is consistent with the ".google.protobuf.Timestamp" check in
// fieldIsTimestamp. Previously "google.protobuf.StringValue" had no leading
// dot and could also match a type such as ".mygoogle.protobuf.StringValue".
func fieldIsString(field *descriptor.FieldDescriptorProto) bool {
	switch field.GetType() {
	case descriptor.FieldDescriptorProto_TYPE_STRING:
		return true
	case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
		typeName := field.GetTypeName()
		for _, suffix := range []string{
			".change.StringChange",
			".change.StringValueChange",
			".authorization.String",
			".google.protobuf.StringValue",
		} {
			if strings.HasSuffix(typeName, suffix) {
				return true
			}
		}
	}
	return false
}

// fieldDataClassification reads the authorization.E_DataClassification option
// from field and maps its string value to a DataClassification.
//
// It returns DataClassificationMissing in three cases:
//   - the field has no options;
//   - proto.GetExtension returns an error;
//   - the extension value is not a non-nil *string.
//
// The strings "customer", "restricted", "internal" and "public" map to their
// constants (case-sensitive). Any other string yields DataClassificationInvalid.
func fieldDataClassification(field *descriptor.FieldDescriptorProto) DataClassification {
	opts := field.GetOptions()
	if opts == nil {
		return DataClassificationMissing
	}

	ext, err := proto.GetExtension(opts, authorization.E_DataClassification)
	if err != nil {
		return DataClassificationMissing
	}

	value, isString := ext.(*string)
	if !isString || value == nil {
		return DataClassificationMissing
	}

	byName := map[string]DataClassification{
		"customer":   DataClassificationCustomer,
		"restricted": DataClassificationRestricted,
		"internal":   DataClassificationInternal,
		"public":     DataClassificationPublic,
	}
	if c, known := byName[*value]; known {
		return c
	}
	return DataClassificationInvalid
}

// validateDataClassification checks that a field's data classification is
// consistent with isAuthorizedField, which callers pass as the result of
// util.IsAuthorizedField (presumably true for eventbus.authorization.<Type>
// fields, per the messages below; confirm against util):
//   - customer/restricted are accepted only when isAuthorizedField;
//   - internal/public are always accepted;
//   - a missing classification is accepted only when !isAuthorizedField;
//   - an unrecognized classification string is always rejected.
//
// Any other DataClassification value yields an "unexpected" error.
func validateDataClassification(classification DataClassification, isAuthorizedField bool) error {
	switch classification {
	case DataClassificationInternal, DataClassificationPublic:
		return nil
	case DataClassificationCustomer, DataClassificationRestricted:
		if isAuthorizedField {
			return nil
		}
		return errors.New("fields marked 'restricted' or 'customer' must be of type eventbus.authorization.<Type>")
	case DataClassificationMissing:
		if !isAuthorizedField {
			return nil
		}
		return errors.New("fields of type eventbus.authorization.<Type> must have a data classification annotation")
	case DataClassificationInvalid:
		return errors.New("invalid data classification, must be one of [public, internal, restricted, customer]")
	}
	// Not reachable for values produced by fieldDataClassification, which only
	// returns the five constants handled above.
	return errors.New("unexpected data classification") // shouldn't happen
}
