package filter

import (
	"fmt"
	"strings"

	"a.yandex-team.ru/library/go/core/xerrors"
)

// ToMySQLQuery renders the logic operator filter as a MySQL WHERE-clause
// fragment by rendering every child filter in order and combining the results.
//
// Placeholder arguments are threaded through the children: each child receives
// the args accumulated so far and returns them extended with its own, so the
// final slice lines up with the '?' placeholders from left to right.
//
// A single child is returned as-is, without parentheses and without looking at
// LogicOp. Two or more children are joined with AND or OR and wrapped in
// parentheses so the group keeps its precedence inside an enclosing expression.
// An error is returned for an empty child list, for the first failing child,
// and for an unknown LogicOp when there are at least two children.
func (filter *LogicOpFilter) ToMySQLQuery(fields map[string]string, args []interface{}) (string, []interface{}, error) {
	if len(filter.Args) == 0 {
		return "", nil, xerrors.New("malformed filter: empty args for logic operator")
	}

	parts := make([]string, 0, len(filter.Args))
	for _, child := range filter.Args {
		part, grown, err := child.ToMySQLQuery(fields, args)
		if err != nil {
			return "", nil, err
		}
		parts = append(parts, part)
		args = grown
	}

	// With one operand the operator is irrelevant, so no grouping is needed.
	if len(parts) == 1 {
		return parts[0], args, nil
	}

	var sep string
	switch filter.LogicOp {
	case LogicOr:
		sep = " OR "
	case LogicAnd:
		sep = " AND "
	default:
		return "", nil, xerrors.Errorf("malformed filter: unknown logic operator '%v'", filter.LogicOp)
	}

	return fmt.Sprintf("(%s)", strings.Join(parts, sep)), args, nil
}

// ToMySQLQuery renders the field filter as a MySQL WHERE-clause predicate.
//
// fields maps the filter's public field names to SQL column expressions. Only
// fields present in the map may be filtered on, and the mapped expression is
// interpolated into the query verbatim, so it must come from trusted code.
// Filter values are never interpolated: each one is bound through a '?'
// placeholder and appended to args. The returned slice is args extended with
// the values this predicate consumes, in placeholder order.
//
// Equal/NotEqual bind the whole Values slice to a single "IN (?)" placeholder
// (NOTE(review): this presumably relies on the query builder expanding slice
// arguments, e.g. sqlx.In — confirm). Contains/StartsWith produce one LIKE per
// value, OR-ed together, with LIKE metacharacters in the values escaped.
// Less/More compare against exactly one value.
//
// An error is returned for an empty value list, an unknown field, a Less/More
// filter with more than one value, or an unknown compare operator.
func (filter *FieldFilter) ToMySQLQuery(fields map[string]string, args []interface{}) (string, []interface{}, error) {
	if len(filter.Values) == 0 {
		return "", nil, xerrors.New("malformed filter: empty values for field filter")
	}

	field, exists := fields[filter.Field]
	if !exists {
		return "", nil, xerrors.Errorf("unexpected field '%s'", filter.Field)
	}

	// likeAny ORs together one "field LIKE ?" per value, binding each value
	// escaped and wrapped as prefix + value + suffix.
	likeAny := func(prefix, suffix string) (string, []interface{}, error) {
		for _, value := range filter.Values {
			args = append(args, fmt.Sprint(prefix, escapeLike(value), suffix))
		}
		return repeatPredicate(fmt.Sprint(field, " LIKE ?"), len(filter.Values)), args, nil
	}

	// compareOne binds the single operand of a range comparison. Extra values
	// are rejected rather than silently dropped, because ignoring them would
	// quietly change the meaning of the filter.
	compareOne := func(op string) (string, []interface{}, error) {
		if len(filter.Values) != 1 {
			return "", nil, xerrors.Errorf("malformed filter: compare operator '%v' expects exactly one value, got %d", filter.CompareOp, len(filter.Values))
		}
		return fmt.Sprint(field, op), append(args, filter.Values[0]), nil
	}

	switch filter.CompareOp {
	case Equal:
		return fmt.Sprint(field, " IN (?)"), append(args, filter.Values), nil
	case NotEqual:
		return fmt.Sprint(field, " NOT IN (?)"), append(args, filter.Values), nil
	case Contains:
		return likeAny("%", "%")
	case StartsWith:
		return likeAny("", "%")
	case Less:
		return compareOne(" < ?")
	case More:
		return compareOne(" > ?")
	default:
		return "", nil, xerrors.Errorf("malformed filter: unknown compare operator type '%v'", filter.CompareOp)
	}
}

// repeatPredicate returns predicate unchanged when count is 1. Otherwise it
// ORs count copies of predicate together and wraps the disjunction in
// parentheses, e.g. ("x LIKE ?", 2) -> "(x LIKE ? OR x LIKE ?)".
//
// count must be positive: for count <= 0, strings.Repeat receives a negative
// count and panics.
func repeatPredicate(predicate string, count int) string {
	if count != 1 {
		leading := strings.Repeat(predicate+" OR ", count-1)
		return fmt.Sprintf("(%s%s)", leading, predicate)
	}
	return predicate
}

// escapeLikeMap marks the bytes that are special inside a LIKE pattern and
// must be prefixed with a backslash to be matched literally: the backslash
// escape character itself, '%' (any sequence) and '_' (any single character).
//
// The table is built with a keyed array literal instead of an init function so
// that it is already populated during package-level variable initialization.
// init functions run only after all package-level vars are initialized, so a
// package-level var computed with escapeLike would otherwise see an empty table
// and get an unescaped pattern.
var escapeLikeMap = [256]bool{
	'\\': true,
	'%':  true,
	'_':  true,
}

// escapeLike returns s with every LIKE metacharacter (backslash, '%', '_')
// prefixed by a backslash. The result then matches s literally when used as
// part of a LIKE pattern with backslash as the escape character, which is
// MySQL's default.
//
// s is processed byte by byte. This is safe for UTF-8 input because every byte
// of a multi-byte sequence is >= 0x80 and can never equal one of the ASCII
// metacharacters.
func escapeLike(s string) string {
	var builder strings.Builder
	// Worst case every byte is escaped, doubling the length.
	builder.Grow(2 * len(s))
	for i := 0; i < len(s); i++ {
		c := s[i]
		if escapeLikeMap[c] {
			builder.WriteByte('\\')
		}
		builder.WriteByte(c)
	}
	return builder.String()
}
