package sqlbuilder

import (
	"fmt"
	"strings"

	"a.yandex-team.ru/library/go/core/xerrors"
)

// SQLBuilder renders filter trees into SQL WHERE fragments with positional
// placeholders. A column may appear in a condition only if its entry in
// allowedColumns is true.
type SQLBuilder struct {
	allowedColumns map[string]bool
}

// New returns an SQLBuilder restricted to the columns marked true in
// allowedColumns. The map is stored as-is rather than copied, so later
// changes the caller makes to it are visible to the builder.
func New(allowedColumns map[string]bool) SQLBuilder {
	var builder SQLBuilder
	builder.allowedColumns = allowedColumns
	return builder
}

// BuildNotFormattedWhere translates a filter tree into an SQL WHERE fragment.
// Positional placeholders are left as "$%d" format verbs, so that BuildWhere
// can number them. It also returns the ordered values those placeholders
// stand for.
//
// whereMap must contain at most one key:
//   - "and" / "or": a []interface{} whose elements are themselves filter maps.
//     Each sub-condition is rendered, the results are joined with the keyword,
//     and the whole is wrapped in parentheses.
//   - "eq", "ne", "gt", "ge", "lt", "le": a map[string]interface{} holding
//     exactly one column -> value entry. The column must be marked true in
//     allowedColumns.
//
// An empty whereMap yields an empty fragment and no parameters. Malformed
// input is reported as an error instead of causing a panic or producing
// invalid SQL. Malformed input covers wrong value types, empty and/or lists,
// empty nested conditions, unknown operations, and zero or several columns
// in a comparison.
func (sqlBuilder *SQLBuilder) BuildNotFormattedWhere(whereMap map[string]interface{}) (string, []interface{}, error) {
	if len(whereMap) == 0 {
		return "", nil, nil
	}
	if len(whereMap) > 1 {
		return "", nil, xerrors.Errorf("Where map has more than 1 root argument")
	}

	// The map has exactly one entry; extract it.
	var key string
	var value interface{}
	for k, v := range whereMap {
		key, value = k, v
	}

	if key == "and" || key == "or" {
		return sqlBuilder.buildLogicalWhere(key, value)
	}
	return sqlBuilder.buildComparisonWhere(key, value)
}

// buildLogicalWhere renders an "and"/"or" node whose value must be a non-empty
// list of filter maps.
func (sqlBuilder *SQLBuilder) buildLogicalWhere(keyword string, value interface{}) (string, []interface{}, error) {
	subWheres, ok := value.([]interface{})
	if !ok {
		return "", nil, xerrors.Errorf("Operation %s expects a list of conditions, got %T", keyword, value)
	}
	if len(subWheres) == 0 {
		// Rendering nothing between the parentheses would produce invalid SQL.
		return "", nil, xerrors.Errorf("Operation %s has no conditions", keyword)
	}

	queryStrings := make([]string, 0, len(subWheres))
	params := make([]interface{}, 0, len(subWheres))
	for _, subValue := range subWheres {
		subWhere, ok := subValue.(map[string]interface{})
		if !ok {
			return "", nil, xerrors.Errorf("Operation %s expects condition maps, got %T", keyword, subValue)
		}
		subQuery, subParams, err := sqlBuilder.BuildNotFormattedWhere(subWhere)
		if err != nil {
			return "", nil, err
		}
		if subQuery == "" {
			// An empty nested map renders nothing and would leave a dangling keyword.
			return "", nil, xerrors.Errorf("Operation %s contains an empty condition", keyword)
		}
		queryStrings = append(queryStrings, subQuery)
		params = append(params, subParams...)
	}
	return " ( " + strings.Join(queryStrings, " "+keyword+" ") + " ) ", params, nil
}

// buildComparisonWhere renders a comparison node such as {"eq": {"col": v}}
// as " col = $%d", with v as its single parameter.
func (sqlBuilder *SQLBuilder) buildComparisonWhere(operation string, value interface{}) (string, []interface{}, error) {
	operator, ok := sqlComparisonOperator(operation)
	if !ok {
		return "", nil, xerrors.Errorf("Unknown operation: %s", operation)
	}
	columns, ok := value.(map[string]interface{})
	if !ok {
		return "", nil, xerrors.Errorf("Operation %s expects a column map, got %T", operation, value)
	}
	if len(columns) != 1 {
		// Zero entries would render nothing. Several entries would force an
		// arbitrary pick, because map iteration order is random.
		return "", nil, xerrors.Errorf("Operation %s expects exactly one column, got %d", operation, len(columns))
	}

	var column string
	var param interface{}
	for c, p := range columns {
		column, param = c, p
	}

	// A missing key and an explicit false are both treated as "not allowed".
	if !sqlBuilder.allowedColumns[column] {
		return "", nil, xerrors.Errorf("Column name %s not allowed", column)
	}
	// "$%%d" leaves a "$%d" verb for BuildWhere to fill with the position.
	return fmt.Sprintf(" %s %s $%%d", column, operator), []interface{}{param}, nil
}

// sqlComparisonOperator maps a filter operation name to its SQL operator.
// It reports false for unknown names.
func sqlComparisonOperator(operation string) (string, bool) {
	switch operation {
	case "eq":
		return "=", true
	case "ne":
		return "!=", true
	case "gt":
		return ">", true
	case "ge":
		return ">=", true
	case "lt":
		return "<", true
	case "le":
		return "<=", true
	}
	return "", false
}

// BuildWhere renders whereMap into a WHERE fragment with numbered placeholders
// ($1, $2, ...) and returns the values bound to them, in the same order.
// Parsing and validation are delegated to BuildNotFormattedWhere. Its "$%d"
// verbs are then filled with 1-based positions.
func (sqlBuilder *SQLBuilder) BuildWhere(whereMap map[string]interface{}) (string, []interface{}, error) {
	template, values, err := sqlBuilder.BuildNotFormattedWhere(whereMap)
	if err != nil {
		return "", nil, err
	}

	// Sprintf takes its arguments as ...interface{}, so box each position.
	positions := make([]interface{}, 0, len(values))
	for idx := range values {
		positions = append(positions, idx+1)
	}
	return fmt.Sprintf(template, positions...), values, nil
}
