package main

import (
	"bytes"
	"fmt"
	"text/template"
	"io/ioutil"
	"log"
	"os"
	"path/filepath"
	"sort"
	"strings"

	"github.com/golang/protobuf/proto"
	pbplugin "github.com/golang/protobuf/protoc-gen-go/plugin"
	"github.com/pkg/errors"

	"code.justin.tv/eventbus/schema/cmd/internal/recon"
	"code.justin.tv/eventbus/schema/cmd/internal/util"
)

// errLog writes timestamped diagnostics to stderr; stdout carries only the
// serialized CodeGeneratorResponse written by codegen.
var errLog = log.New(os.Stderr, "", log.LstdFlags)

// main is the plugin entry point. A failure from codegen is reported on
// stderr and converted into a non-zero exit status.
func main() {
	if err := codegen(); err != nil {
		errLog.Printf("Fatal error: %s\n", err.Error())
		os.Exit(1)
	}
}

// codegen performs one plugin round trip: it decodes a CodeGeneratorRequest
// from stdin, builds the generated files via buildFiles, and writes the
// encoded CodeGeneratorResponse to stdout. Each failure is wrapped with a
// description of the step that failed.
func codegen() error {
	raw, err := ioutil.ReadAll(os.Stdin)
	if err != nil {
		return errors.Wrap(err, "could not read input")
	}

	req := &pbplugin.CodeGeneratorRequest{}
	if err = proto.Unmarshal(raw, req); err != nil {
		return errors.Wrap(err, "could not parse input proto")
	}
	if len(req.FileToGenerate) == 0 {
		return errors.New("no files to generate")
	}

	resp, err := buildFiles(req)
	if err != nil {
		return errors.Wrap(err, "could not build generated code")
	}

	out, err := proto.Marshal(resp)
	if err != nil {
		return errors.Wrap(err, "could not marshal output proto")
	}
	if _, err = os.Stdout.Write(out); err != nil {
		return errors.Wrap(err, "could not write output proto")
	}
	return nil
}

// buildFiles renders one "<stem>.eventbus.go" output file for every proto
// file listed in req.FileToGenerate.
//
// For each file it requires exactly one message whose name equals the one
// util.ExpectedMessageName derives from the file name. It renders
// fileTemplate for that message together with the encrypted fields that
// recon.GetEncryptedFields reports for it. Other messages in the file are
// only collected into MessageNames and get no bindings.
//
// If any of the following occurs, an error is returned and no partial
// response is produced:
//   - a requested file is not described in the request;
//   - its name cannot be mapped to a message name;
//   - the expected message is absent;
//   - the template fails to render.
func buildFiles(req *pbplugin.CodeGeneratorRequest) (*pbplugin.CodeGeneratorResponse, error) {
	lookup := recon.Recon(req)
	var response pbplugin.CodeGeneratorResponse

	for _, filename := range req.FileToGenerate {
		// Without this check a missing entry would silently yield a zero
		// value and fail later with a confusing error (or a nil dereference).
		f, ok := lookup.Files[filename]
		if !ok {
			return nil, fmt.Errorf("file %s requested for generation but not described in the request", filename)
		}

		expectedName, expectedAction, err := util.ExpectedMessageName(filename)
		if err != nil {
			return nil, err
		}

		// Data consumed by fileTemplate.
		var input struct {
			EventType       string
			PackageName     string
			GenTypes        []msgWrap
			MessageNames    []string
			EncryptedFields []recon.EncryptedField
		}

		// Package name == containing dir name (golang convention), unless
		// go_package carries an explicit ";name" override.
		input.PackageName = getPackageName(f.GetOptions().GetGoPackage())
		input.EventType = expectedName

		var ignoredTypes []string
		for _, t := range f.MessageTypes {
			if t.GetName() == expectedName {
				input.GenTypes = append(input.GenTypes, msgWrap{
					MessageType: t,
					action:      expectedAction,
				})
			} else {
				ignoredTypes = append(ignoredTypes, t.GetName())
			}
			input.MessageNames = append(input.MessageNames, t.GetName())
		}
		sort.Strings(input.MessageNames)

		if len(input.GenTypes) != 1 {
			return nil, fmt.Errorf("Expected %s to contain a message type named %s. Only found these: %v", filename, expectedName, ignoredTypes)
		}

		input.EncryptedFields = recon.GetEncryptedFields(expectedName, f)

		var buf bytes.Buffer
		if err := fileTemplate.Execute(&buf, input); err != nil {
			// Return rather than log: main reports the error exactly once.
			return nil, fmt.Errorf("could not render generated code for %s: %v", filename, err)
		}

		// The output sits next to the proto file as "<dir>/<stem>.eventbus.go".
		// <stem> is the base name up to its first '.'.
		// plugin.proto requires '/'-separated names with no "." components.
		// Join drops the "." that Dir returns for a top-level file, and
		// ToSlash normalises Windows separators.
		protoName := f.GetName()
		stem := strings.Split(filepath.Base(protoName), ".")[0]
		outName := filepath.ToSlash(filepath.Join(filepath.Dir(protoName), stem+".eventbus.go"))

		response.File = append(response.File, &pbplugin.CodeGeneratorResponse_File{
			Name:    proto.String(outName),
			Content: proto.String(buf.String()),
		})
	}
	return &response, nil
}

// getPackageName derives the Go package name for generated code from a
// proto file's go_package option.
//
// Accepted forms:
//   - "import/path/name"    -> "name" (last path element)
//   - "import/path;pkgname" -> "pkgname" (explicit override after ';')
//
// Any other shape terminates the plugin via log.Fatalf, because none of them
// can produce a valid package clause. This covers:
//   - more than one ';';
//   - an option with no usable name, such as an empty/missing go_package
//     (which filepath.Base turns into ".") or a trailing ';'.
func getPackageName(goPkgString string) string {
	parts := strings.Split(goPkgString, ";")

	var name string
	switch len(parts) {
	case 1:
		name = filepath.Base(parts[0])
	case 2:
		name = parts[1]
	default:
		log.Fatalf("invalid 'go_package' option '%s'", goPkgString)
	}

	// filepath.Base maps "" to "." and a separator-only path to the
	// separator; neither, nor an empty override, is a package name.
	if name == "" || name == "." || name == string(filepath.Separator) {
		log.Fatalf("invalid 'go_package' option '%s': cannot derive a package name", goPkgString)
	}
	return name
}

// templateFuncs are the helpers exposed to fileTemplate. "Title" upper-cases
// the leading letter of a Go type name (e.g. "string" -> "String"). The
// template uses it to build Encrypt*/Decrypt* method names and
// authorization wrapper type names.
var templateFuncs = template.FuncMap{"Title": strings.Title}

// fileTemplate renders the eventbus companion file for one proto file.
// buildFiles executes it with a value exposing:
//   - PackageName: the package clause of the generated file.
//   - GenTypes: the message(s) to bind; each provides Name, EventTypeConst
//     and ActionType.
//   - EncryptedFields: fields that get SetEncrypted*/GetDecrypted* helpers;
//     each provides EventType, MessageName, FieldName, GoType and ZeroValue.
//
// The first output line is the standard "Code generated ... DO NOT EDIT."
// marker, so Go tooling recognises the output as generated.
var fileTemplate = template.Must(template.New("gen.go").Funcs(templateFuncs).Parse(`// Code generated by the eventbus protoc plugin. DO NOT EDIT.

package {{ .PackageName }}

import (
	"context"

	"github.com/golang/protobuf/proto"

	eventbus "code.justin.tv/eventbus/client"
{{- if .EncryptedFields}}

	"code.justin.tv/eventbus/schema/pkg/eventbus/authorization"
	"code.justin.tv/eventbus/client/encryption"
{{- end}}
)

const (
{{- range .GenTypes }}
	{{ .EventTypeConst }} = "{{ .Name }}"
{{- end }}
)

{{ range .GenTypes -}}
type {{ .ActionType }} = {{ .Name }}

type {{ .Name }}Handler func(context.Context, *eventbus.Header, *{{ .Name }}) error

func (h {{ .Name }}Handler) Handler() eventbus.Handler {
	return func(ctx context.Context, message eventbus.RawMessage) error {
		dst := &{{ .Name }}{}
		err := proto.Unmarshal(message.Payload, dst)
		if err != nil {
			return err
		}
		return h(ctx, message.Header, dst)
	}
}

func Register{{ .Name }}Handler(mux *eventbus.Mux, f {{ .Name }}Handler) {
	mux.RegisterHandler({{ .EventTypeConst }}, f.Handler())
}

func Register{{ .ActionType }}Handler(mux *eventbus.Mux, f {{ .Name }}Handler) {
	Register{{ .Name }}Handler(mux, f)
}

func (*{{ .Name }}) EventBusName() string {
	return {{ .EventTypeConst }}
}
{{ end -}}

{{ range .EncryptedFields -}}
func get{{ .EventType}}{{ .MessageName }}{{ .FieldName }}AuthContext(environment string) authorization.Context {
	return authorization.Context{
		authorization.EventType:   "{{ .EventType }}",
		authorization.Environment: environment,
		authorization.MessageName: "{{ .MessageName }}",
		authorization.FieldName:   "{{ .FieldName }}",
	}
}

func (m *{{ .MessageName}}) SetEncrypted{{ .FieldName }}(p encryption.Provider, plaintext {{.GoType}}) error {
	if m == nil {
		return authorization.ErrEncryptionNilReceiver
	}

	encCtx := get{{.EventType}}{{.MessageName}}{{.FieldName}}AuthContext(p.Environment())
	err := authorization.ValidateContext(encCtx)
	if err != nil {
		return err
	}

	enc := p.Encrypter()
	b, err := enc.Encrypt{{.GoType | Title}}(encCtx, plaintext)
	if err != nil {
		return err
	}

	m.{{.FieldName}} = &authorization.{{ .GoType | Title }}{
		OlePayload: b,
	}
	return nil
}

func (m *{{ .MessageName}}) GetDecrypted{{ .FieldName }}(dec encryption.Decrypter) ({{ .GoType }}, error) {
	if m == nil {
		return {{ .ZeroValue }}, authorization.ErrDecryptionNilReceiver
	}
	if m.Get{{.FieldName}}() == nil {
		return {{ .ZeroValue }}, nil
	}
	return dec.Decrypt{{ .GoType | Title }}(m.Get{{.FieldName}}().GetOlePayload())
}
{{ end -}}
`))

// LineWriter accumulates text as a slice of lines, applying a tab-based
// indentation prefix to lines written with P.
type LineWriter struct {
	lines  []string // lines written so far, in order
	indent string   // prefix added by P; Indent/Dedent grow/shrink it
}

// Indent adds one tab to the prefix used by subsequent P calls.
func (w *LineWriter) Indent() {
	w.indent += "\t"
}

// Dedent removes the last character of the indentation prefix. This is one
// tab when the prefix was built with Indent. At zero indentation it is a
// no-op, so an unbalanced call does not panic on an out-of-range slice.
func (w *LineWriter) Dedent() {
	if n := len(w.indent); n > 0 {
		w.indent = w.indent[:n-1]
	}
}

// P appends one line: the current indentation followed by the concatenated
// bits.
func (w *LineWriter) P(bits ...string) {
	w.lines = append(w.lines, w.indent+strings.Join(bits, ""))
}

// MultiRaw appends s split on "\n" as separate lines, verbatim, without
// applying indentation. A trailing newline in s yields a final empty line.
func (w *LineWriter) MultiRaw(s string) {
	w.lines = append(w.lines, strings.Split(s, "\n")...)
}

// msgWrap adapts a recon.MessageType for fileTemplate. It pairs the message
// with an action name from which the generated constant and type-alias names
// are built.
type msgWrap struct {
	*recon.MessageType
	action string // prefix of EventTypeConst; also returned by ActionType
}

// LowerName returns the message's ShortName with its first byte lower-cased
// (e.g. "UserCreate" -> "userCreate"). An empty name is returned unchanged
// rather than panicking on the s[:1] slice.
func (m msgWrap) LowerName() string {
	s := m.ShortName()
	if s == "" {
		return s
	}
	return strings.ToLower(s[:1]) + s[1:]
}

// EventTypeConst returns the identifier of the generated event-type
// constant: the action name followed by "EventType".
func (m msgWrap) EventTypeConst() string {
	return m.action + "EventType"
}

// ActionType returns the action name, used by fileTemplate as the name of a
// type alias for the message.
func (m msgWrap) ActionType() string {
	return m.action
}
