package extensionbillingmanagers

import (
	"context"
	"strings"

	"github.com/Masterminds/squirrel"
	"github.com/cactus/go-statsd-client/statsd"

	"code.justin.tv/devrel/dbx"
	"code.justin.tv/devrel/devsite-rbac/backend/common"
	"code.justin.tv/devrel/devsite-rbac/internal/errorutil"
	"code.justin.tv/devrel/devsite-rbac/rpc/rbacrpc"
)

// Table is the name of the database table that stores ExtensionBillingManager
// rows; every query in this package reads from or writes to it.
const Table = "extension_billing_manager"

// ExtensionBillingManagers is the storage interface for extension billing
// manager records. Records are keyed by extension client ID: Get and Set look
// a record up by that ID, and Delete matches on it.
//
// NOTE(review): the directives below presumably generate a counterfeiter fake
// and the ExtensionBillingManagersErrx timing wrapper that New returns —
// confirm against the generated files.
//
//go:generate counterfeiter . ExtensionBillingManagers
//go:generate errxer --timings ExtensionBillingManagers
type ExtensionBillingManagers interface {
	// GetExtensionBillingManager returns the record for extensionClientID.
	GetExtensionBillingManager(ctx context.Context, extensionClientID string) (ExtensionBillingManager, error)
	// ListExtensionBillingManagers returns the records matching params together
	// with a total match count.
	ListExtensionBillingManagers(ctx context.Context, params ListExtensionBillingManagersParams) ([]ExtensionBillingManager, int32, error)

	// IsAssignedBillingManager reports whether the Twitch user twitchID is the
	// assigned billing manager of at least one extension belonging to the
	// company companyID.
	IsAssignedBillingManager(ctx context.Context, companyID, twitchID string) (bool, error)

	// SetExtensionBillingManager inserts ebm, or updates the existing record
	// that has the same extension client ID.
	SetExtensionBillingManager(ctx context.Context, ebm *ExtensionBillingManager) error // insert or update
	// DeleteExtensionBillingManager deletes the record whose extension client
	// ID matches ebm.ExtensionClientID.
	DeleteExtensionBillingManager(ctx context.Context, ebm *ExtensionBillingManager) error
}

// ExtensionBillingManager is one row of the extension_billing_manager table.
// It links an extension (by client ID) to the Twitch user ID of its billing
// manager. CreatedAt and UpdatedAt hold the strings produced by
// common.TimeNowStr when this package writes the row.
//
// XXX_Total is not a stored column. It receives the "_total" count that list
// queries add to each result row, and it is excluded from Columns and from
// insert/update writes.
//
// The db tags must match the table's column names, because they are what dbx
// reads when building Columns and the write statements.
type ExtensionBillingManager struct {
	ExtensionClientID      string `db:"extension_client_id"`
	BillingManagerTwitchID string `db:"billing_manager_twitch_id"`
	CreatedAt              string `db:"created_at"`
	UpdatedAt              string `db:"updated_at"`
	XXX_Total              int32  `db:"_total"`
}

// Columns is the SELECT list for Table. dbx derives it from the
// ExtensionBillingManager fields, minus the synthetic "_total" count field.
var Columns = dbx.FieldsFrom(ExtensionBillingManager{}).Exclude("_total")

// DBXExtensionBillingManagers implements ExtensionBillingManagers on top of a
// common.DBXer. New returns it wrapped with timing instrumentation.
type DBXExtensionBillingManagers struct {
	db common.DBXer
}

// New builds the database-backed ExtensionBillingManagers store on db and
// wraps it in ExtensionBillingManagersErrx. The wrapper reports per-call
// timings to stats through common.TimingStats.
func New(db common.DBXer, stats statsd.Statter) ExtensionBillingManagers {
	return &ExtensionBillingManagersErrx{
		ExtensionBillingManagers: &DBXExtensionBillingManagers{db: db},
		TimingFunc:               common.TimingStats(stats),
	}
}

// GetExtensionBillingManager loads the single row whose extension_client_id
// equals extensionClientID. The DBXer's LoadOne error is returned unchanged.
// SetExtensionBillingManager treats an errorutil.IsErrNoRows error as "no
// record yet".
func (d *DBXExtensionBillingManagers) GetExtensionBillingManager(ctx context.Context, extensionClientID string) (ExtensionBillingManager, error) {
	var found ExtensionBillingManager
	query := common.PSQL.
		Select(Columns...).
		From(Table).
		Where("extension_client_id = ?", extensionClientID).
		Limit(1)
	if err := d.db.LoadOne(ctx, &found, query); err != nil {
		return found, err
	}
	return found, nil
}

// ListExtensionBillingManagersParams filters, orders and paginates a call to
// ListExtensionBillingManagers. Empty filter fields are ignored, and all
// non-empty filters must match at once. Limit and Offset are passed to
// common.Paginate (NOTE(review): confirm how a zero Limit is treated there).
type ListExtensionBillingManagersParams struct {
	ExtensionClientIDs     []string
	BillingManagerTwitchID string
	CompanyID              string // only extensions linked to this company via company_resources

	Limit      uint64
	Offset     uint64
	OrderBy    string // column to sort by; placed verbatim in ORDER BY, so never pass untrusted input
	OrderByDir string // "DESC" or "ASC"; ASC is the default
}

// ListExtensionBillingManagers returns the rows matching p together with the
// "_total" count that common.CountOverAs adds to every row. NOTE(review):
// presumably this is COUNT(*) OVER (), i.e. the match count before
// pagination; IsAssignedBillingManager relies on it being non-zero with
// Limit 1.
//
// Non-empty filters in p are ANDed together. When p.CompanyID is set, only
// extensions linked to that company in company_resources are returned.
//
// Ordering: p.OrderBy is placed into ORDER BY verbatim and must come from a
// trusted source, such as a fixed column name. p.OrderByDir is normalized to
// "DESC" (case-insensitive match) or "ASC", so it can never inject SQL.
//
// On error the returned slice is nil and the total is 0.
func (d *DBXExtensionBillingManagers) ListExtensionBillingManagers(ctx context.Context, p ListExtensionBillingManagersParams) ([]ExtensionBillingManager, int32, error) {
	cols := Columns.Add(common.CountOverAs("_total"))
	q := common.PSQL.Select(cols...).From(Table)

	if p.CompanyID != "" {
		q = q.Join("company_resources on resource_id = extension_client_id and resource_type = 'extension'")
		q = q.Where("company_id = ?", p.CompanyID)
	}
	if len(p.ExtensionClientIDs) > 0 {
		q = q.Where(squirrel.Eq{"extension_client_id": p.ExtensionClientIDs})
	}
	if p.BillingManagerTwitchID != "" {
		q = q.Where("billing_manager_twitch_id = ?", p.BillingManagerTwitchID)
	}

	q = common.Paginate(q, p.Limit, p.Offset)
	if p.OrderBy != "" {
		// Map the direction onto a fixed keyword instead of concatenating
		// caller input into the SQL text.
		dir := "ASC"
		if strings.EqualFold(strings.TrimSpace(p.OrderByDir), "DESC") {
			dir = "DESC"
		}
		q = q.OrderBy(p.OrderBy + " " + dir)
	}

	list := []ExtensionBillingManager{}
	if err := d.db.LoadAll(ctx, &list, q); err != nil {
		// Don't hand back a partially loaded page or a meaningless total.
		return nil, 0, err
	}
	return list, common.FirstRowInt32DBField(list, "_total"), nil
}

// IsAssignedBillingManager reports whether twitchID is the billing manager of
// at least one extension belonging to companyID. It requests a single row and
// decides from the total count returned with it. The list error is passed
// through alongside the result.
func (d *DBXExtensionBillingManagers) IsAssignedBillingManager(ctx context.Context, companyID, twitchID string) (bool, error) {
	params := ListExtensionBillingManagersParams{
		BillingManagerTwitchID: twitchID,
		CompanyID:              companyID,
		Limit:                  1,
	}
	_, total, err := d.ListExtensionBillingManagers(ctx, params)
	assigned := total > 0
	return assigned, err
}

// SetExtensionBillingManager upserts e, keyed by e.ExtensionClientID:
//   - no existing row (lookup fails with an ErrNoRows error): e is inserted;
//   - existing row: that row is updated from e;
//   - any other lookup error: the error is returned and nothing is written.
//
// The insert/update helpers stamp the timestamp fields of e in place.
//
// NOTE(review): the existence check and the write are separate statements
// with no visible transaction, so concurrent calls for the same extension can
// race.
func (d *DBXExtensionBillingManagers) SetExtensionBillingManager(ctx context.Context, e *ExtensionBillingManager) error {
	_, err := d.GetExtensionBillingManager(ctx, e.ExtensionClientID)
	switch {
	case err == nil:
		return d.updateExtensionBillingManager(ctx, e)
	case errorutil.IsErrNoRows(err):
		return d.insertExtensionBillingManager(ctx, e)
	default:
		// A real failure (e.g. connectivity) must surface rather than be
		// masked by falling through to an update.
		return err
	}
}

// insertExtensionBillingManager inserts e as a new row in Table, leaving out
// the synthetic "_total" field. It takes the current time once and uses it
// for both CreatedAt and UpdatedAt, so a freshly created row never reports
// two different timestamps.
func (d *DBXExtensionBillingManagers) insertExtensionBillingManager(ctx context.Context, e *ExtensionBillingManager) error {
	now := common.TimeNowStr()
	e.CreatedAt = now
	e.UpdatedAt = now
	return d.db.InsertOne(ctx, Table, e, dbx.Exclude("_total"))
}

// updateExtensionBillingManager refreshes e.UpdatedAt and writes e over the
// row with the same extension_client_id. created_at is never overwritten, and
// the synthetic "_total" field is skipped.
func (d *DBXExtensionBillingManagers) updateExtensionBillingManager(ctx context.Context, e *ExtensionBillingManager) error {
	e.UpdatedAt = common.TimeNowStr()
	lookup := dbx.FindBy("extension_client_id")
	skipped := dbx.Exclude("created_at", "_total")
	return d.db.UpdateOne(ctx, Table, e, lookup, skipped)
}

// DeleteExtensionBillingManager removes the row whose extension_client_id
// matches e.ExtensionClientID. The DBXer's DeleteOne error is returned
// unchanged.
func (d *DBXExtensionBillingManagers) DeleteExtensionBillingManager(ctx context.Context, e *ExtensionBillingManager) error {
	byClientID := dbx.FindBy("extension_client_id")
	return d.db.DeleteOne(ctx, Table, e, byClientID)
}

//
// Converters
//

// ToRPC converts e into its rbacrpc wire representation. It copies the client
// ID, the billing manager's Twitch ID and both timestamps; the internal
// XXX_Total count is not exposed.
func (e ExtensionBillingManager) ToRPC() *rbacrpc.ExtensionBillingManager {
	out := new(rbacrpc.ExtensionBillingManager)
	out.ExtensionClientId = e.ExtensionClientID
	out.BillingManagerTwitchId = e.BillingManagerTwitchID
	out.CreatedAt = e.CreatedAt
	out.UpdatedAt = e.UpdatedAt
	return out
}

// ListToRPC converts each element of list with ToRPC, preserving order. The
// result is always a non-nil slice, empty when list is empty or nil.
func ListToRPC(list []ExtensionBillingManager) []*rbacrpc.ExtensionBillingManager {
	converted := make([]*rbacrpc.ExtensionBillingManager, 0, len(list))
	for i := range list {
		converted = append(converted, list[i].ToRPC())
	}
	return converted
}
