package porto

import (
	"encoding/binary"
	"fmt"
	"io"
	"math"
	"net"
	"syscall"
	"time"

	"google.golang.org/protobuf/proto"

	portopb "a.yandex-team.ru/infra/porto/proto"
	"a.yandex-team.ru/security/libs/go/porto/internal/netpool"
)

const (
	// portoSocket is the unix socket path dialed to reach the Porto daemon.
	portoSocket = "/run/portod.socket"

	// connTimeout is the net.Dialer timeout used when dialing portoSocket.
	// NOTE(review): 500 seconds is unusually long for a local unix-socket
	// dial; confirm that 500 * time.Millisecond was not intended.
	connTimeout = 500 * time.Second
)

type (
	// AvailableProperty is one entry of a property listing: a property name
	// and its description. Returned by ListProperties and
	// ListVolumeProperties.
	AvailableProperty struct {
		Name        string
		Description string
	}

	// Data has the same Name/Description shape as AvailableProperty.
	// NOTE(review): not referenced in this file — confirm it is still used by
	// callers before relying on it.
	Data struct {
		Name        string
		Description string
	}

	// VolumeDescription describes a volume as returned by CreateVolume,
	// GetVolume and ListVolumes: its path, its properties and the containers
	// reported for it. GetVolume fills only Path and Containers.
	VolumeDescription struct {
		Path       string
		Properties map[string]string
		Containers []string
	}

	// StorageDescription is one entry returned by ListStorage, copied field by
	// field from the daemon's storage listing.
	StorageDescription struct {
		Name         string
		OwnerUser    string
		OwnerGroup   string
		LastUsage    uint64
		PrivateValue string
	}

	// LayerDescription is one entry returned by ListLayers2, copied field by
	// field from the daemon's layer listing.
	LayerDescription struct {
		Name         string
		OwnerUser    string
		OwnerGroup   string
		LastUsage    uint64
		PrivateValue string
	}

	// ImportLayerOpts holds the arguments of ImportLayer. Layer, Tarball,
	// Merge and PrivateValue are always sent; Place is sent only when
	// non-empty.
	ImportLayerOpts struct {
		Layer        string
		Tarball      string
		Merge        bool
		Place        string
		PrivateValue string
	}

	// ListLayersOpts groups layer-listing filters.
	// NOTE(review): ListLayers2 takes place/mask as positional arguments and
	// does not use this type — confirm whether it is still needed.
	ListLayersOpts struct {
		Place string
		Mask  string
	}

	// ListVolumesOpts groups volume-listing filters.
	// NOTE(review): ListVolumes takes path/container as positional arguments
	// and does not use this type — confirm whether it is still needed.
	ListVolumesOpts struct {
		Path      string
		Container string
	}

	// GetResponse is the per-container, per-variable result of
	// GetProperties/GetPropertiesWithBlock: the value plus the per-value
	// error code (converted to int) and error message from the daemon.
	GetResponse struct {
		Value    string
		Error    int
		ErrorMsg string
	}

	// Error is returned when the daemon replies with a non-success code.
	// Errno is the raw code, ErrName its symbolic name looked up in
	// portopb.EError_name, and Message the daemon-supplied error text.
	Error struct {
		Errno   portopb.EError
		ErrName string
		Message string
	}

	// ConnectionOpts configures NewPortoConnection.
	ConnectionOpts struct {
		// Maximum connections in pool (default=5)
		MaxConnections int
	}

	// Connection is a client for the Porto daemon backed by a pool of
	// unix-socket connections. err and msg hold the error code and message of
	// the most recent request; they are reset at the start of every request.
	// NOTE(review): err/msg are written by every request without locking, so
	// concurrent use of one Connection races on them — confirm intended usage.
	Connection struct {
		connPool *netpool.NetPool
		err      portopb.EError
		msg      string
	}
)

// Error implements the error interface, rendering the daemon failure as
// "[<numeric code>] <code name>: <message>".
func (e *Error) Error() string {
	rendered := fmt.Sprintf("[%d] %s: %s", e.Errno, e.ErrName, e.Message)
	return rendered
}

//Connect establishes connection to a Porto daemon via unix socket.
//Close must be called when the API is not needed anymore.
func NewPortoConnection(opts ConnectionOpts) (*Connection, error) {
	dialer := net.Dialer{
		Timeout: connTimeout,
	}

	pool, err := netpool.New(
		func() (net.Conn, error) {
			return dialer.Dial("unix", portoSocket)
		},
		opts.MaxConnections,
	)

	if err != nil {
		return nil, err
	}

	ret := Connection{
		connPool: pool,
	}
	return &ret, nil
}

// Close shuts down the connection pool backing this client.
func (conn *Connection) Close() error {
	pool := conn.connPool
	return pool.Close()
}

// GetLastError returns the error code recorded by the most recent request.
// It is zero until a reply has been decoded.
func (conn *Connection) GetLastError() portopb.EError {
	code := conn.err
	return code
}

// GetLastErrorMessage returns the error message recorded by the most recent
// request. It is empty until a reply has been decoded.
func (conn *Connection) GetLastErrorMessage() string {
	text := conn.msg
	return text
}

// GetVersion queries the daemon and returns its version tag and revision.
func (conn *Connection) GetVersion() (string, string, error) {
	resp, err := conn.performRequest(&portopb.TPortoRequest{
		Version: &portopb.TVersionRequest{},
	})
	if err != nil {
		return "", "", err
	}

	version := resp.GetVersion()
	return version.GetTag(), version.GetRevision(), nil
}

// ContainerAPI

// Create asks the daemon to create a container called name.
func (conn *Connection) Create(name string) (err error) {
	createReq := &portopb.TCreateRequest{Name: &name}
	_, err = conn.performRequest(&portopb.TPortoRequest{Create: createReq})
	return err
}

// CreateWeak sends a create request for name in the request's CreateWeak slot.
func (conn *Connection) CreateWeak(name string) (err error) {
	createReq := &portopb.TCreateRequest{Name: &name}
	_, err = conn.performRequest(&portopb.TPortoRequest{CreateWeak: createReq})
	return err
}

// Destroy asks the daemon to destroy the container called name.
func (conn *Connection) Destroy(name string) (err error) {
	destroyReq := &portopb.TDestroyRequest{Name: &name}
	_, err = conn.performRequest(&portopb.TPortoRequest{Destroy: destroyReq})
	return err
}

// Start asks the daemon to start the container called name.
func (conn *Connection) Start(name string) (err error) {
	startReq := &portopb.TStartRequest{Name: &name}
	_, err = conn.performRequest(&portopb.TPortoRequest{Start: startReq})
	return err
}

func (conn *Connection) Stop(name string) error {
	return conn.StopWithTimeout(name, -1)
}

// StopWithTimeout asks the daemon to stop the named container.
//
// A non-negative timeout is sent as TimeoutMs in whole milliseconds, and any
// sub-millisecond remainder is truncated. Values that do not fit in a uint32
// of milliseconds are rejected before anything is sent. A negative timeout
// omits TimeoutMs from the request entirely.
func (conn *Connection) StopWithTimeout(name string, timeout time.Duration) (err error) {
	req := &portopb.TPortoRequest{
		Stop: &portopb.TStopRequest{
			Name: &name,
		},
	}

	if timeout >= 0 {
		ms := timeout / time.Millisecond
		if ms > math.MaxUint32 {
			// uint64(...) stops the untyped constant from defaulting to int
			// inside the variadic interface argument, which overflows and
			// fails to compile on 32-bit platforms.
			return fmt.Errorf("timeout must be less than %d ms", uint64(math.MaxUint32))
		}

		timeoutMs := uint32(ms)
		req.Stop.TimeoutMs = &timeoutMs
	}

	_, err = conn.performRequest(req)
	return err
}

// Kill asks the daemon to send signal sig to the named container. The signal
// is transmitted as its int32 number.
func (conn *Connection) Kill(name string, sig syscall.Signal) (err error) {
	signo := int32(sig)
	killReq := &portopb.TKillRequest{Name: &name, Sig: &signo}
	_, err = conn.performRequest(&portopb.TPortoRequest{Kill: killReq})
	return err
}

// Pause asks the daemon to pause the named container.
func (conn *Connection) Pause(name string) (err error) {
	pauseReq := &portopb.TPauseRequest{Name: &name}
	_, err = conn.performRequest(&portopb.TPortoRequest{Pause: pauseReq})
	return err
}

// Resume asks the daemon to resume the named container.
func (conn *Connection) Resume(name string) (err error) {
	resumeReq := &portopb.TResumeRequest{Name: &name}
	_, err = conn.performRequest(&portopb.TPortoRequest{Resume: resumeReq})
	return err
}

// Wait is WaitMulti for a single container.
func (conn *Connection) Wait(container string, timeout time.Duration) (string, error) {
	names := []string{container}
	return conn.WaitMulti(names, timeout)
}

// WaitMulti sends a wait request for the given containers and returns the
// Name field of the daemon's reply. Presumably this is the container whose
// wait completed; confirm against the Porto API documentation.
//
// A non-negative timeout is sent as TimeoutMs in whole milliseconds, and any
// sub-millisecond remainder is truncated. Values that do not fit in a uint32
// of milliseconds are rejected before anything is sent. A negative timeout
// omits TimeoutMs from the request.
func (conn *Connection) WaitMulti(containers []string, timeout time.Duration) (string, error) {
	req := &portopb.TPortoRequest{
		Wait: &portopb.TWaitRequest{
			Name: containers,
		},
	}

	if timeout >= 0 {
		ms := timeout / time.Millisecond
		if ms > math.MaxUint32 {
			// uint64(...) stops the untyped constant from defaulting to int
			// inside the variadic interface argument, which overflows and
			// fails to compile on 32-bit platforms.
			return "", fmt.Errorf("timeout must be less than %d ms", uint64(math.MaxUint32))
		}

		timeoutMs := uint32(ms)
		req.Wait.TimeoutMs = &timeoutMs
	}

	resp, err := conn.performRequest(req)
	if err != nil {
		return "", err
	}

	return resp.GetWait().GetName(), nil
}

// List returns the names of all containers. It is ListByMask without a mask.
func (conn *Connection) List() ([]string, error) {
	const noMask = ""
	return conn.ListByMask(noMask)
}

// ListByMask returns the container names reported by the daemon. A non-empty
// mask is forwarded as the request's Mask; an empty one is omitted.
func (conn *Connection) ListByMask(mask string) ([]string, error) {
	listReq := &portopb.TListRequest{}
	if len(mask) > 0 {
		listReq.Mask = &mask
	}

	resp, err := conn.performRequest(&portopb.TPortoRequest{List: listReq})
	if err != nil {
		return nil, err
	}
	return resp.GetList().GetName(), nil
}

// ListProperties returns every container property the daemon advertises,
// paired with its description. The result is nil when the list is empty or
// the request fails.
func (conn *Connection) ListProperties() (ret []AvailableProperty, err error) {
	resp, err := conn.performRequest(&portopb.TPortoRequest{
		ListProperties: new(portopb.TListPropertiesRequest),
	})
	if err != nil {
		return nil, err
	}

	for _, prop := range resp.GetListProperties().GetList() {
		ret = append(ret, AvailableProperty{
			Name:        prop.GetName(),
			Description: prop.GetDesc(),
		})
	}
	return ret, nil
}

// GetProperties is GetPropertiesWithBlock with nonblock set to false.
func (conn *Connection) GetProperties(containers []string, variables []string) (ret map[string]map[string]GetResponse, err error) {
	const nonblock = false
	return conn.GetPropertiesWithBlock(containers, variables, nonblock)
}

// GetPropertiesWithBlock fetches the requested variables for the requested
// containers in one round trip. nonblock is forwarded as the request's
// Nonblock flag.
//
// The result is indexed by container name and then by variable name. A
// container appears only if the reply carried at least one value for it. On
// failure the map is nil.
func (conn *Connection) GetPropertiesWithBlock(containers []string, variables []string, nonblock bool) (
	ret map[string]map[string]GetResponse, err error) {

	getReq := &portopb.TGetRequest{
		Name:     containers,
		Variable: variables,
		Nonblock: &nonblock,
	}
	resp, err := conn.performRequest(&portopb.TPortoRequest{Get: getReq})
	if err != nil {
		return nil, err
	}

	ret = make(map[string]map[string]GetResponse)
	for _, item := range resp.GetGet().GetList() {
		container := item.GetName()
		for _, kv := range item.GetKeyval() {
			byVariable, seen := ret[container]
			if !seen {
				byVariable = make(map[string]GetResponse)
				ret[container] = byVariable
			}
			byVariable[kv.GetVariable()] = GetResponse{
				Value:    kv.GetValue(),
				Error:    int(kv.GetError()),
				ErrorMsg: kv.GetErrorMsg(),
			}
		}
	}
	return ret, nil
}

// GetProperty returns the value of one property of the named container.
func (conn *Connection) GetProperty(name string, property string) (string, error) {
	getReq := &portopb.TGetPropertyRequest{Name: &name, Property: &property}
	resp, err := conn.performRequest(&portopb.TPortoRequest{GetProperty: getReq})
	if err != nil {
		return "", err
	}
	return resp.GetGetProperty().GetValue(), nil
}

// SetProperty assigns value to one property of the named container.
func (conn *Connection) SetProperty(name string, property string, value string) error {
	setReq := &portopb.TSetPropertyRequest{
		Name:     &name,
		Property: &property,
		Value:    &value,
	}
	_, err := conn.performRequest(&portopb.TPortoRequest{SetProperty: setReq})
	return err
}

// VolumeAPI

// ListVolumeProperties returns every volume property the daemon advertises,
// paired with its description. The result is nil when the list is empty or
// the request fails.
func (conn *Connection) ListVolumeProperties() (ret []AvailableProperty, err error) {
	resp, err := conn.performRequest(&portopb.TPortoRequest{
		ListVolumeProperties: &portopb.TListVolumePropertiesRequest{},
	})
	if err != nil {
		return nil, err
	}

	for _, prop := range resp.GetListVolumeProperties().GetList() {
		ret = append(ret, AvailableProperty{
			Name:        prop.GetName(),
			Description: prop.GetDesc(),
		})
	}
	return ret, nil
}

// CreateVolume asks the daemon to create a volume with the given properties.
// A non-empty path is forwarded as the requested Path; an empty one is
// omitted. On success the daemon's description of the new volume is
// returned. Containers is a private copy of the reply's list, and Properties
// is the reply's map itself.
func (conn *Connection) CreateVolume(path string, config map[string]string) (desc VolumeDescription, err error) {
	createReq := &portopb.TCreateVolumeRequest{Properties: config}
	if path != "" {
		createReq.Path = &path
	}

	resp, err := conn.performRequest(&portopb.TPortoRequest{CreateVolume: createReq})
	if err != nil {
		return VolumeDescription{}, err
	}

	created := resp.GetCreateVolume()
	return VolumeDescription{
		Path:       created.GetPath(),
		Properties: created.GetProperties(),
		Containers: append([]string(nil), created.GetContainers()...),
	}, nil
}

// TuneVolume sends new property values for the volume at path.
func (conn *Connection) TuneVolume(path string, config map[string]string) error {
	tuneReq := &portopb.TTuneVolumeRequest{Path: &path, Properties: config}
	_, err := conn.performRequest(&portopb.TPortoRequest{TuneVolume: tuneReq})
	return err
}

// LinkVolume asks the daemon to link the volume at path to container.
func (conn *Connection) LinkVolume(path string, container string) error {
	linkReq := &portopb.TLinkVolumeRequest{Path: &path, Container: &container}
	_, err := conn.performRequest(&portopb.TPortoRequest{LinkVolume: linkReq})
	return err
}

// UnlinkVolume is UnlinkVolumeWithStrict with strict set to false.
func (conn *Connection) UnlinkVolume(path string, container string) error {
	const strict = false
	return conn.UnlinkVolumeWithStrict(path, container, strict)
}

// UnlinkVolumeWithStrict asks the daemon to unlink the volume at path. The
// strict flag is always sent. An empty container is left out of the request;
// a non-empty one is forwarded.
func (conn *Connection) UnlinkVolumeWithStrict(path string, container string, strict bool) error {
	unlinkReq := &portopb.TUnlinkVolumeRequest{Path: &path, Strict: &strict}
	if container != "" {
		unlinkReq.Container = &container
	}

	_, err := conn.performRequest(&portopb.TPortoRequest{UnlinkVolume: unlinkReq})
	return err
}

// GetVolume fetches the description of the volume at path. A non-empty
// container is set as the request's Container field.
//
// Only Path and Containers (a single entry: the reply's Container value) are
// filled in. If the daemon's reply contains no volume, an error is returned
// instead of panicking.
func (conn *Connection) GetVolume(path string, container string) (desc VolumeDescription, resultErr error) {
	req := &portopb.TPortoRequest{
		GetVolume: &portopb.TGetVolumeRequest{
			Path: []string{path},
		},
	}

	if container != "" {
		req.GetVolume.Container = &container
	}

	resp, err := conn.performRequest(req)
	if err != nil {
		return VolumeDescription{}, err
	}

	// The nil-safe getters matter here. Direct field access panics when the
	// reply has no GetVolume section (nil dereference) or an empty volume list
	// (index out of range).
	volumes := resp.GetGetVolume().GetVolume()
	if len(volumes) == 0 {
		return VolumeDescription{}, fmt.Errorf("porto: no volume returned for path %q", path)
	}

	volume := volumes[0]
	desc.Path = volume.GetPath()
	desc.Containers = append(desc.Containers, volume.GetContainer())
	return desc, nil
}

// ListVolumes returns the volumes reported by the daemon. A non-empty path or
// container is forwarded as a filter; empty values are omitted. Containers
// and Properties of each entry are taken directly from the reply.
func (conn *Connection) ListVolumes(path string, container string) (ret []VolumeDescription, err error) {
	listReq := &portopb.TListVolumesRequest{}
	if path != "" {
		listReq.Path = &path
	}
	if container != "" {
		listReq.Container = &container
	}

	resp, err := conn.performRequest(&portopb.TPortoRequest{ListVolumes: listReq})
	if err != nil {
		return nil, err
	}

	for _, vol := range resp.GetListVolumes().GetVolumes() {
		entry := VolumeDescription{
			Path:       vol.GetPath(),
			Containers: vol.GetContainers(),
			Properties: vol.GetProperties(),
		}
		ret = append(ret, entry)
	}
	return ret, nil
}

// LayerAPI

// ImportLayer asks the daemon to import a layer from a tarball. Layer,
// Tarball, Merge and PrivateValue are always sent; Place only when non-empty.
func (conn *Connection) ImportLayer(opts ImportLayerOpts) error {
	importReq := &portopb.TImportLayerRequest{
		Layer:        &opts.Layer,
		Tarball:      &opts.Tarball,
		Merge:        &opts.Merge,
		PrivateValue: &opts.PrivateValue,
	}
	if len(opts.Place) != 0 {
		importReq.Place = &opts.Place
	}

	_, err := conn.performRequest(&portopb.TPortoRequest{ImportLayer: importReq})
	return err
}

// ExportLayer asks the daemon to export volume into the tarball path.
func (conn *Connection) ExportLayer(volume string, tarball string) error {
	exportReq := &portopb.TExportLayerRequest{Volume: &volume, Tarball: &tarball}
	_, err := conn.performRequest(&portopb.TPortoRequest{ExportLayer: exportReq})
	return err
}

// RemoveLayer is RemoveLayerWithPlace without a place.
func (conn *Connection) RemoveLayer(layer string) error {
	const noPlace = ""
	return conn.RemoveLayerWithPlace(layer, noPlace)
}

// RemoveLayerWithPlace asks the daemon to remove layer. A non-empty place is
// forwarded; an empty one is omitted.
func (conn *Connection) RemoveLayerWithPlace(layer string, place string) error {
	removeReq := &portopb.TRemoveLayerRequest{Layer: &layer}
	if len(place) > 0 {
		removeReq.Place = &place
	}

	_, err := conn.performRequest(&portopb.TPortoRequest{RemoveLayer: removeReq})
	return err
}

// ListLayers returns the Layer list from the daemon's layer-listing reply.
func (conn *Connection) ListLayers() ([]string, error) {
	resp, err := conn.performRequest(&portopb.TPortoRequest{
		ListLayers: &portopb.TListLayersRequest{},
	})
	if err != nil {
		return nil, err
	}
	return resp.GetListLayers().GetLayer(), nil
}

// ListLayers2 returns detailed layer descriptions. A non-empty place or mask
// is forwarded as a filter; empty values are omitted.
func (conn *Connection) ListLayers2(place string, mask string) (ret []LayerDescription, err error) {
	listReq := &portopb.TListLayersRequest{}
	if place != "" {
		listReq.Place = &place
	}
	if mask != "" {
		listReq.Mask = &mask
	}

	resp, err := conn.performRequest(&portopb.TPortoRequest{ListLayers: listReq})
	if err != nil {
		return nil, err
	}

	for _, l := range resp.GetListLayers().GetLayers() {
		ret = append(ret, LayerDescription{
			Name:         l.GetName(),
			OwnerUser:    l.GetOwnerUser(),
			OwnerGroup:   l.GetOwnerGroup(),
			LastUsage:    l.GetLastUsage(),
			PrivateValue: l.GetPrivateValue(),
		})
	}
	return ret, nil
}

// GetLayerPrivate returns the private value stored for layer. A non-empty
// place is forwarded; an empty one is omitted.
func (conn *Connection) GetLayerPrivate(layer string, place string) (string, error) {
	getReq := &portopb.TGetLayerPrivateRequest{Layer: &layer}
	if len(place) > 0 {
		getReq.Place = &place
	}

	resp, err := conn.performRequest(&portopb.TPortoRequest{GetLayerPrivate: getReq})
	if err != nil {
		return "", err
	}
	return resp.GetGetLayerPrivate().GetPrivateValue(), nil
}

// SetLayerPrivate stores privateValue for layer. A non-empty place is
// forwarded; an empty one is omitted.
func (conn *Connection) SetLayerPrivate(layer string, place string,
	privateValue string) error {
	setReq := &portopb.TSetLayerPrivateRequest{
		Layer:        &layer,
		PrivateValue: &privateValue,
	}
	if len(place) > 0 {
		setReq.Place = &place
	}

	_, err := conn.performRequest(&portopb.TPortoRequest{SetLayerPrivate: setReq})
	return err
}

// ListStorage returns the storages reported by the daemon. A non-empty place
// or mask is forwarded as a filter; empty values are omitted.
func (conn *Connection) ListStorage(place string, mask string) (ret []StorageDescription, err error) {
	listReq := &portopb.TListStoragesRequest{}
	if place != "" {
		listReq.Place = &place
	}
	if mask != "" {
		listReq.Mask = &mask
	}

	resp, err := conn.performRequest(&portopb.TPortoRequest{ListStorages: listReq})
	if err != nil {
		return nil, err
	}

	for _, s := range resp.GetListStorages().GetStorages() {
		ret = append(ret, StorageDescription{
			Name:         s.GetName(),
			OwnerUser:    s.GetOwnerUser(),
			OwnerGroup:   s.GetOwnerGroup(),
			LastUsage:    s.GetLastUsage(),
			PrivateValue: s.GetPrivateValue(),
		})
	}
	return ret, nil
}

// RemoveStorage asks the daemon to remove the named storage. A non-empty
// place is forwarded; an empty one is omitted.
func (conn *Connection) RemoveStorage(name string, place string) error {
	removeReq := &portopb.TRemoveStorageRequest{Name: &name}
	if len(place) > 0 {
		removeReq.Place = &place
	}

	_, err := conn.performRequest(&portopb.TPortoRequest{RemoveStorage: removeReq})
	return err
}

// ConvertPath asks the daemon to translate path from the src container's view
// to the dest container's view, and returns the converted path.
func (conn *Connection) ConvertPath(path string, src string, dest string) (string, error) {
	convertReq := &portopb.TConvertPathRequest{
		Path:        &path,
		Source:      &src,
		Destination: &dest,
	}

	resp, err := conn.performRequest(&portopb.TPortoRequest{ConvertPath: convertReq})
	if err != nil {
		return "", err
	}
	return resp.GetConvertPath().GetPath(), nil
}

// AttachProcess sends an attach request for pid (with the given comm) to the
// container called name.
func (conn *Connection) AttachProcess(name string, pid uint32, comm string) error {
	attachReq := &portopb.TAttachProcessRequest{Name: &name, Pid: &pid, Comm: &comm}
	_, err := conn.performRequest(&portopb.TPortoRequest{AttachProcess: attachReq})
	return err
}

// AttachThread sends the same attach payload as AttachProcess, but in the
// request's AttachThread slot.
func (conn *Connection) AttachThread(name string, pid uint32, comm string) error {
	attachReq := &portopb.TAttachProcessRequest{Name: &name, Pid: &pid, Comm: &comm}
	_, err := conn.performRequest(&portopb.TPortoRequest{AttachThread: attachReq})
	return err
}

// LocateProcess asks the daemon which container pid (with the given comm)
// belongs to, and returns the Name field of the reply.
func (conn *Connection) LocateProcess(pid uint32, comm string) (string, error) {
	locateReq := &portopb.TLocateProcessRequest{Pid: &pid, Comm: &comm}
	resp, err := conn.performRequest(&portopb.TPortoRequest{LocateProcess: locateReq})
	if err != nil {
		return "", err
	}
	return resp.GetLocateProcess().GetName(), nil
}

// performRequest marshals req, sends it over a pooled connection and decodes
// the daemon's reply.
//
// The last-error code and message are reset up front and then updated from
// the decoded reply. A reply carrying a non-success code is returned as
// *Error.
//
// Sending is retried on another pooled connection, up to the pool size and at
// least once, because a pooled connection may have gone stale. Any connection
// that fails a send or receive is discarded rather than released back to the
// pool.
func (conn *Connection) performRequest(req *portopb.TPortoRequest) (*portopb.TPortoResponse, error) {
	conn.err = 0
	conn.msg = ""

	data, err := proto.Marshal(req)
	if err != nil {
		return nil, err
	}

	// Always make at least one attempt. With a zero pool size the loop would
	// otherwise never run, leaving netConn nil for recvData.
	attempts := conn.connPool.PoolSize()
	if attempts < 1 {
		attempts = 1
	}

	var netConn net.Conn
	for i := 0; i < attempts; i++ {
		netConn, err = conn.connPool.Acquire()
		if err != nil {
			// Could not obtain a connection at all - give up immediately.
			return nil, err
		}

		if err = sendData(netConn, data); err == nil {
			break
		}

		// Probably a broken connection: drop it and try another one.
		conn.connPool.Discard(netConn)
	}
	if err != nil {
		// Every attempt failed to send; report the last error.
		return nil, err
	}

	data, err = recvData(netConn)
	if err != nil {
		// The stream state is unknown after a failed read; never reuse it.
		conn.connPool.Discard(netConn)
		return nil, err
	}
	conn.connPool.Release(netConn)

	var resp portopb.TPortoResponse
	if err := proto.Unmarshal(data, &resp); err != nil {
		return nil, err
	}

	conn.err = resp.GetError()
	conn.msg = resp.GetErrorMsg()

	if conn.err != portopb.EError_Success {
		return nil, &Error{
			Errno:   conn.err,
			ErrName: portopb.EError_name[int32(conn.err)],
			Message: conn.msg,
		}
	}

	return &resp, nil
}

func sendData(conn net.Conn, data []byte) error {
	// First we have to send actual data size,
	// then the data itself
	buf := make([]byte, 64)
	wroteLen := binary.PutUvarint(buf, uint64(len(data)))
	_, err := conn.Write(buf[:wroteLen])
	if err != nil {
		return err
	}
	_, err = conn.Write(data)
	return err
}

// recvData reads one framed message from conn: an unsigned-varint payload
// length followed by that many payload bytes. It returns the payload.
//
// A single Read may legally return fewer bytes than the length prefix
// occupies, so the prefix is accumulated across reads until it decodes. The
// remaining payload is then read with io.ReadFull. A stream that ends early
// yields io.EOF or io.ErrUnexpectedEOF. A malformed or unrepresentable length
// is reported as an error instead of panicking.
func recvData(conn net.Conn) ([]byte, error) {
	// Large enough to usually capture the prefix and a small reply in one
	// read, without the 1 MiB per-call allocation the old code made.
	head := make([]byte, 4096)

	var (
		filled int
		size   uint64
		shift  int
	)
	for {
		n, err := conn.Read(head[filled:])
		filled += n

		size, shift = binary.Uvarint(head[:filled])
		if shift > 0 {
			// Prefix complete. A read error delivered alongside the data
			// resurfaces from io.ReadFull below if more bytes are needed.
			break
		}
		if shift < 0 {
			return nil, fmt.Errorf("porto: malformed response length prefix (%d bytes)", filled)
		}
		if err != nil {
			if err == io.EOF && filled > 0 {
				err = io.ErrUnexpectedEOF
			}
			return nil, err
		}
	}

	const maxInt = int(^uint(0) >> 1)
	if size > uint64(maxInt) {
		return nil, fmt.Errorf("porto: response length %d is too large", size)
	}

	ret := make([]byte, int(size))
	// Payload bytes that arrived together with the prefix.
	pos := copy(ret, head[shift:filled])
	if _, err := io.ReadFull(conn, ret[pos:]); err != nil {
		return nil, err
	}

	return ret, nil
}
