package main

import (
	"encoding/binary"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"os"

	"github.com/davecgh/go-spew/spew"
	"github.com/golang/protobuf/proto"

	"a.yandex-team.ru/junk/buglloc/pocs/z2/z2-agent/z2rpc"
	"a.yandex-team.ru/library/go/core/xerrors"
)

// connWrapper wraps a stream connection and carries the framed read/write
// helpers used for both the z2 upstream and the local ssh-agent client.
type connWrapper struct{ cc net.Conn }

// writeMsg marshals msg, wraps it in a z2rpc.Message envelope tagged with
// msgType, and sends the envelope as one frame: a 4-byte big-endian length
// followed by the marshaled envelope bytes. The outgoing envelope is dumped
// to stdout for debugging before it is sent.
func (c *connWrapper) writeMsg(msgType z2rpc.MessageType, msg proto.Message) error {
	inner, err := proto.Marshal(msg)
	if err != nil {
		return xerrors.Errorf("ms marshaling: %w", err)
	}

	envelope := &z2rpc.Message{
		Type:            msgType.Enum(),
		SpecificMessage: inner,
	}
	fmt.Println("out: ")
	spew.Dump(envelope)

	body, err := proto.Marshal(envelope)
	if err != nil {
		return xerrors.Errorf("outMsg marshaling: %w", err)
	}

	// Build header and body in one buffer so the frame goes out in a single Write.
	frame := make([]byte, 4, 4+len(body))
	binary.BigEndian.PutUint32(frame, uint32(len(body)))
	frame = append(frame, body...)

	if _, err := c.cc.Write(frame); err != nil {
		return xerrors.Errorf("z2 writeMsg: %w", err)
	}
	return nil
}

// readMsg reads one frame via readZ2, decodes it as a z2rpc.Message
// envelope, dumps the decoded envelope to stdout for debugging, and
// returns it. Frame-read errors are returned unchanged; decode errors are
// wrapped.
func (c *connWrapper) readMsg() (*z2rpc.Message, error) {
	raw, err := c.readZ2()
	if err != nil {
		return nil, err
	}

	envelope := &z2rpc.Message{}
	if err = proto.Unmarshal(raw, envelope); err != nil {
		return nil, xerrors.Errorf("unmarshal msg: %w", err)
	}

	fmt.Println("in: ")
	spew.Dump(envelope)
	return envelope, nil
}

// readZ2 reads one length-prefixed frame from the connection: a 4-byte
// big-endian length followed by that many payload bytes. Only the payload
// is returned; the length header is consumed and dropped.
func (c *connWrapper) readZ2() ([]byte, error) {
	var hdr [4]byte
	if _, err := io.ReadFull(c.cc, hdr[:]); err != nil {
		return nil, xerrors.Errorf("z2 readMsg header: %w", err)
	}

	// NOTE(review): the declared length is trusted as-is, so a peer can make
	// this allocate up to 4 GiB. Consider a cap once z2's real maximum
	// frame size is known.
	payload := make([]byte, binary.BigEndian.Uint32(hdr[:]))
	if _, err := io.ReadFull(c.cc, payload); err != nil {
		return nil, xerrors.Errorf("z2 readMsg msg: %w", err)
	}
	return payload, nil
}

// readSSH reads one ssh-agent protocol message from the connection. On the
// wire a message is a 4-byte big-endian length followed by that many bytes.
// The returned slice includes the 4-byte length header as well as the body,
// so it can be forwarded verbatim.
//
// A message whose declared length exceeds maxSSHMsgLen is rejected before
// any buffer is allocated. This keeps a misbehaving client from crashing the
// proxy or forcing a huge allocation.
func (c *connWrapper) readSSH() ([]byte, error) {
	// 256 KiB is the same message ceiling OpenSSH's ssh-agent enforces.
	// Checking it first and then sizing the buffer with int arithmetic also
	// prevents uint32 wraparound of 4+msgLen. Without that, lengths of
	// 0xFFFFFFFC or more would yield a 0–3 byte buffer and a slice-bounds
	// panic on msg[4:].
	const maxSSHMsgLen = 256 * 1024

	msgHeader := make([]byte, 4)
	if _, err := io.ReadFull(c.cc, msgHeader); err != nil {
		return nil, xerrors.Errorf("ssh readMsg header: %w", err)
	}

	msgLen := binary.BigEndian.Uint32(msgHeader)
	if msgLen > maxSSHMsgLen {
		return nil, xerrors.Errorf("ssh readMsg msg: length %d exceeds limit %d", msgLen, maxSSHMsgLen)
	}

	msg := make([]byte, 4+int(msgLen))
	copy(msg, msgHeader)
	if _, err := io.ReadFull(c.cc, msg[4:]); err != nil {
		return nil, xerrors.Errorf("ssh readMsg msg: %w", err)
	}
	return msg, nil
}

// main listens on a unix socket and proxies ssh-agent traffic to z2. Each
// accepted client is served by proxyAgentConn in its own goroutine.
//
// Flags:
//
//	-ssh-key  key id sent to z2 as both Id and Key (default "robot-search-secdist")
//	-addr     path of the unix socket to listen on (default "/tmp/z2-agent")
//	-z2-addr  TCP address of the z2 service (default "z2.yandex-team.ru:15151")
func main() {
	var (
		keyID     string
		agentAddr string
		z2Addr    string
	)
	flag.StringVar(&keyID, "ssh-key", "robot-search-secdist", "ssh-key id")
	flag.StringVar(&agentAddr, "addr", "/tmp/z2-agent", "ssh agent uds")
	flag.StringVar(&z2Addr, "z2-addr", "z2.yandex-team.ru:15151", "z2 service address")
	flag.Parse()

	// Remove a stale socket left by a previous run. Use os.Remove rather than
	// os.RemoveAll so that a mistyped -addr pointing at a populated directory
	// fails loudly instead of being deleted recursively.
	if err := os.Remove(agentAddr); err != nil && !os.IsNotExist(err) {
		log.Fatal(err)
	}

	l, err := net.Listen("unix", agentAddr)
	if err != nil {
		log.Fatal("listen error:", err)
	}
	defer func() { _ = l.Close() }()

	for {
		in, err := l.Accept()
		if err != nil {
			log.Fatal("accept error:", err)
		}

		go proxyAgentConn(connWrapper{cc: in}, z2Addr, keyID)
	}
}

// proxyAgentConn serves one local ssh-agent client, using a dedicated z2
// connection.
//
// It dials z2Addr and waits for the first message (the HELLO). Then, for each
// ssh-agent request read from agentCC, it sends an SSH_AGENT_FORWARD_REQUEST
// carrying keyID and the raw request bytes. It reads one reply, decodes it as
// SshAgentForwardResponse, and writes the reply's Content back to the client.
// The first error on either side ends the session, and both connections are
// closed on return.
func proxyAgentConn(agentCC connWrapper, z2Addr, keyID string) {
	defer func() { _ = agentCC.cc.Close() }()

	upstream, err := net.Dial("tcp", z2Addr)
	if err != nil {
		log.Printf("connection failed: %v\n", err)
		// Must stop here: continuing would call readMsg on a nil net.Conn.
		return
	}
	defer func() { _ = upstream.Close() }()

	c := connWrapper{cc: upstream}

	// wait HELLO msg
	if _, err := c.readMsg(); err != nil {
		log.Println(err.Error())
		return
	}

	for {
		req, err := agentCC.readSSH()
		if err != nil {
			log.Println(err.Error())
			return
		}

		fmt.Println("ssh request:")
		spew.Dump(req)

		// send ssh-agent request
		err = c.writeMsg(z2rpc.MessageType_SSH_AGENT_FORWARD_REQUEST, &z2rpc.SshAgentForwardRequest{
			Id:      &keyID,
			Key:     &keyID,
			Content: req,
		})
		if err != nil {
			log.Println(err.Error())
			return
		}

		// read answer
		msg, err := c.readMsg()
		if err != nil {
			log.Println(err.Error())
			return
		}

		var agentRsp z2rpc.SshAgentForwardResponse
		if err := proto.Unmarshal(msg.SpecificMessage, &agentRsp); err != nil {
			log.Printf("unmarshal agent response: %v\n", err)
			return
		}

		if _, err := agentCC.cc.Write(agentRsp.Content); err != nil {
			log.Printf("write rsp: %v\n", err)
			return
		}
	}
}
