package leviathan

import (
	"fmt"
	"reflect"

	"github.com/Masterminds/squirrel"
	"github.com/pkg/errors"

	"code.justin.tv/safety/datastore/models"
)

// sortDescTemplate turns a column name into a descending ORDER BY term,
// e.g. "created_at" -> "created_at desc".
const sortDescTemplate = "%s desc"

// errUnsupportedSortType is the sentinel wrapped by sortBy when a
// *models.Sort carries a Type it has no ORDER BY mapping for.
var errUnsupportedSortType = fmt.Errorf("Unsupported sort type")

// errUnsupportedFilterType is the sentinel wrapped by filterBy when a
// *models.Filter carries a Type it has no WHERE mapping for.
var errUnsupportedFilterType = fmt.Errorf("Unsupported filter type")

// sortBy appends an ORDER BY term to sb for every non-nil field of sorts.
//
// sorts is expected to be a struct (presumably passed as a pointer) whose
// fields are all of type *models.Sort, tagged with the column to sort on:
//
//	type MySort struct {
//		Field1 *models.Sort `db:"column1"`
//		Field2 *models.Sort `db:"column2"`
//	}
//
// Column names and field values come from toColumnsAndValues, and terms are
// appended in the order it returns them. Nil fields are skipped.
//
// If sorts is nil (an untyped nil, or a nil pointer/map/slice/etc.), sb is
// returned unchanged. If a field is not a *models.Sort, or a sort has a Type
// other than SortAscending/SortDescending, an error is returned together with
// the original, unmodified sb.
func sortBy(sb squirrel.SelectBuilder, sorts interface{}) (squirrel.SelectBuilder, error) {
	if sorts == nil {
		return sb, nil
	}
	// A typed nil inside the interface is not == nil, so inspect it via
	// reflection. reflect.Value.IsNil panics for non-nillable kinds (e.g. a
	// struct passed by value), so only ask when the kind can hold nil.
	switch v := reflect.ValueOf(sorts); v.Kind() {
	case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice, reflect.Chan, reflect.Func:
		if v.IsNil() {
			return sb, nil
		}
	}

	names, values := toColumnsAndValues(sorts)

	// Build into a copy so that callers get the untouched input builder back
	// on error rather than a half-applied one.
	query := sb
	for i, raw := range values {
		sort, ok := raw.(*models.Sort)
		if !ok {
			return sb, errors.Errorf("Non Sort field found in %T", sorts)
		}
		if sort == nil {
			continue // field not set: no ordering on this column
		}

		// names[i] is interpolated into SQL directly; it is assumed to come
		// from trusted struct tags, not user input.
		switch sort.Type {
		case models.SortAscending:
			query = query.OrderBy(names[i])
		case models.SortDescending:
			query = query.OrderBy(fmt.Sprintf(sortDescTemplate, names[i]))
		default:
			return sb, errors.Wrap(errUnsupportedSortType, fmt.Sprintf("adding sort %+v", sort))
		}
	}
	return query, nil
}

// filterBy appends a WHERE condition to sb for every non-nil field of filters.
// Successive Where calls on a squirrel SelectBuilder are ANDed together.
//
// filters is expected to be a struct (presumably passed as a pointer) whose
// fields are all of type *models.Filter, tagged with the column to filter on:
//
//	type MyFilter struct {
//		Field1 *models.Filter `db:"column1"`
//		Field2 *models.Filter `db:"column2"`
//	}
//
// Column names and field values come from toColumnsAndValues. Nil fields are
// skipped. Supported filter types and the condition each produces:
//
//	FilterEqual          squirrel.Eq     (col = v)
//	FilterNullOrNotEqual (col IS NULL OR col <> v)
//	FilterNotEqual       squirrel.NotEq  (col <> v)
//	FilterGreaterThan    squirrel.Gt     (col > v)
//	FilterLessThan       squirrel.Lt     (col < v)
//	FilterIncludes       squirrel.Like   (col LIKE '%v%')
//	FilterBetween        inclusive range; v must be []interface{}{a, b},
//	                     and a and b may be given in either order
//
// If filters is nil (an untyped nil, or a nil pointer/map/slice/etc.), sb is
// returned unchanged. On any error (non-Filter field, unsupported type, or
// malformed FilterBetween value) the original, unmodified sb is returned.
func filterBy(sb squirrel.SelectBuilder, filters interface{}) (squirrel.SelectBuilder, error) {
	if filters == nil {
		return sb, nil
	}
	// A typed nil inside the interface is not == nil, so inspect it via
	// reflection. reflect.Value.IsNil panics for non-nillable kinds (e.g. a
	// struct passed by value), so only ask when the kind can hold nil.
	switch v := reflect.ValueOf(filters); v.Kind() {
	case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice, reflect.Chan, reflect.Func:
		if v.IsNil() {
			return sb, nil
		}
	}

	names, values := toColumnsAndValues(filters)

	// Build into a copy so that callers get the untouched input builder back
	// on error rather than a half-applied one.
	query := sb
	for i, raw := range values {
		filter, ok := raw.(*models.Filter)
		if !ok {
			return sb, errors.Errorf("Non Filter field found in %T", filters)
		}
		if filter == nil {
			continue // field not set: no condition on this column
		}

		name := names[i]
		value := filter.Value
		switch filter.Type {
		case models.FilterEqual:
			query = query.Where(squirrel.Eq{name: value})
		case models.FilterNullOrNotEqual:
			// SQL `col <> ?` is never true for NULL columns, so NULLs must be
			// matched explicitly. name is assumed to be a trusted struct-tag
			// column, since it is interpolated into the SQL text.
			query = query.Where(fmt.Sprintf("(%s IS NULL OR %s <> ?)", name, name), value)
		case models.FilterNotEqual:
			query = query.Where(squirrel.NotEq{name: value})
		case models.FilterGreaterThan:
			query = query.Where(squirrel.Gt{name: value})
		case models.FilterLessThan:
			query = query.Where(squirrel.Lt{name: value})
		case models.FilterIncludes:
			// Render the needle with default formatting so non-string values
			// (e.g. ints) don't become "%!s(int=5)". []byte stays textual.
			// NOTE(review): '%' and '_' inside the value are not escaped and
			// act as LIKE wildcards; escaping is dialect-dependent, confirm
			// whether callers rely on this.
			var needle string
			if b, isBytes := value.([]byte); isBytes {
				needle = string(b)
			} else {
				needle = fmt.Sprint(value)
			}
			query = query.Where(squirrel.Like{name: "%" + needle + "%"})
		case models.FilterBetween:
			bounds, isSlice := value.([]interface{})
			if !isSlice {
				return sb, errors.Errorf("FilterBetween value was not a slice of interfaces, but %T", value)
			}
			if len(bounds) != 2 {
				return sb, errors.Errorf("FilterBetween had %d values", len(bounds))
			}
			// The bounds may arrive in either order, so accept values lying
			// in [a, b] or in [b, a] (both inclusive).
			query = query.Where(squirrel.Or{
				squirrel.And{squirrel.GtOrEq{name: bounds[0]}, squirrel.LtOrEq{name: bounds[1]}},
				squirrel.And{squirrel.GtOrEq{name: bounds[1]}, squirrel.LtOrEq{name: bounds[0]}},
			})
		default:
			return sb, errors.Wrap(errUnsupportedFilterType, fmt.Sprintf("adding filter %+v", filter))
		}
	}
	return query, nil
}
