package db

import (
	"bytes"
	"encoding/json"
	"math/rand"
	"sort"
	"strconv"
	"strings"
	"time"

	"code.justin.tv/web/jax/common/config"
	"code.justin.tv/web/jax/common/log"
	"code.justin.tv/web/jax/common/stats"
	"code.justin.tv/web/jax/db/query"

	es "github.com/Kaidence/elastigo/lib"
	"github.com/cactus/go-statsd-client/statsd"
	cache "github.com/patrickmn/go-cache"
)

// ElasticSearchReader is a Jax-specific reader to search channels in ElasticSearch.
//
// Searches are issued against Conn and may additionally be mirrored to
// ReplicateConn according to ReplicateRate. Bulk channel lookups read from
// Cache when it is non-nil and fall back to ElasticSearch otherwise.
type ElasticSearchReader struct {
	// Conn is the primary ElasticSearch connection used by Search, Aggregate, Lease and Scan.
	Conn          *es.Conn
	// ReplicateConn receives mirrored copies of searches; its results are ignored.
	ReplicateConn *es.Conn
	// Cache holds ChannelResult values keyed by channelNamePrefix+name and
	// channelIDPrefix+id. NewElasticSearchReader leaves it nil when useCache is false.
	Cache         *cache.Cache
	// ReplicateRate is the expected number of mirrored searches per primary
	// search: the integer part is always replayed and the fractional part is the
	// probability of one extra replay. Zero disables replication.
	ReplicateRate float64
	// Index and Type identify the ElasticSearch index and document type searched.
	Index         string
	Type          string
	// Stats receives timing, counter and gauge metrics.
	Stats         statsd.Statter
}

// NewElasticSearchReader creates a new interface to query Jax streams in ElasticSearch.
//
// The primary connection targets conf.ESHosts on conf.ESPort. Searches are
// additionally mirrored to conf.ESReplicateHosts (on the same port) at
// conf.ESReplicateRate; replication is disabled when no replicate hosts are
// configured. When useCache is true, an in-memory channel cache is created and
// kept populated by a background goroutine that is never stopped.
func NewElasticSearchReader(conf *config.Config, useCache bool) *ElasticSearchReader {
	client := &ElasticSearchReader{
		Conn:  es.NewConn(),
		Index: conf.ESIndex,
		Type:  conf.ESType,
		Stats: stats.InitStatsd(conf),
	}
	client.Conn.SetHosts(conf.ESHosts)
	client.Conn.SetPort(strconv.Itoa(conf.ESPort))

	client.ReplicateRate = conf.ESReplicateRate

	// Always initialize ReplicateConn so replication never dereferences nil.
	client.ReplicateConn = es.NewConn()
	if len(conf.ESReplicateHosts) > 0 {
		client.ReplicateConn.SetHosts(conf.ESReplicateHosts)
		client.ReplicateConn.SetPort(strconv.Itoa(conf.ESPort))
	} else {
		// Checking the length (not just nil) also disables replication for an
		// empty host list, which would otherwise mirror searches to a
		// connection with no configured hosts.
		client.ReplicateRate = 0
	}

	if useCache {
		// Entries expire after 2 minutes unless refreshed by updateCache;
		// expired entries are purged every 30 seconds.
		client.Cache = cache.New(2*time.Minute, 30*time.Second)
		go client.updateCache()
	}

	return client
}

// updateCache keeps T.Cache populated with live channels. It consumes the
// continuous lease stream and stores every channel under its name key and,
// when the channel has a positive ID, under its ID key as well. It returns
// only if the lease channel is closed.
func (T *ElasticSearchReader) updateCache() {
	for batch := range T.Lease("cache", 1000) {
		T.Stats.Gauge("cache.count", int64(T.Cache.ItemCount()), 1)
		T.Stats.Inc("cache.update", int64(len(batch)), 0.5)

		for i := range batch {
			entry := batch[i]
			T.Cache.Set(channelNamePrefix+entry.Channel, entry, cache.DefaultExpiration)

			id := entry.GetID()
			if id <= 0 {
				continue
			}
			T.Cache.Set(channelIDPrefix+strconv.FormatInt(id, 10), entry, cache.DefaultExpiration)
		}
	}
}

// Get looks up a single channel by name in the in-memory cache, returning it
// only if it has every requested field and passes every filter.
func (T *ElasticSearchReader) Get(channel string, fields []string, filters []query.Filter) (*ResultSet, *JaxDbError) {
	key := channelNamePrefix + channel
	return T.get(key, fields, filters)
}

// GetByID looks up a single channel by its ID in the in-memory cache,
// returning it only if it has every requested field and passes every filter.
func (T *ElasticSearchReader) GetByID(channelID string, fields []string, filters []query.Filter) (*ResultSet, *JaxDbError) {
	key := channelIDPrefix + channelID
	return T.get(key, fields, filters)
}

// get retrieves a single channel from the in-memory cache.
//
// The cached channel is returned, restricted to fields when fields is
// non-empty, only if it carries every requested field and satisfies every
// filter. A cache miss, a non-matching channel, or a reader without a cache
// all yield an empty ResultSet rather than an error.
func (T *ElasticSearchReader) get(cacheKey string, fields []string, filters []query.Filter) (*ResultSet, *JaxDbError) {
	empty := &ResultSet{Total: 0, Hits: []ChannelResult{}}

	// Readers built with useCache=false have a nil Cache, and calling Get on a
	// nil *cache.Cache panics; treat that case as a miss.
	if T.Cache == nil {
		return empty, nil
	}

	res, found := T.Cache.Get(cacheKey)
	if !found {
		return empty, nil
	}
	ch, ok := res.(ChannelResult)
	if !ok {
		return empty, nil
	}

	flatProps := FlattenProperties(ch.Properties)
	if !matchFields(flatProps, fields) {
		return empty, nil
	}
	for _, f := range filters {
		if !f.Valid(flatProps) {
			return empty, nil
		}
	}

	return &ResultSet{
		Total: 1,
		Hits: []ChannelResult{
			{Channel: ch.Channel, Properties: jsonifyProperties(flatProps, fields)},
		},
	}, nil
}

// matchFields reports whether m contains every key listed in fields. An empty
// fields list always matches.
func matchFields(m map[string]interface{}, fields []string) bool {
	for i := range fields {
		if _, present := m[fields[i]]; !present {
			return false
		}
	}
	return true
}

// ByChannelCount implements sort.Interface, ordering channels by viewer count
// from highest to lowest.
type ByChannelCount []ChannelResult

func (a ByChannelCount) Len() int {
	return len(a)
}

func (a ByChannelCount) Swap(i, j int) {
	a[i], a[j] = a[j], a[i]
}

func (a ByChannelCount) Less(i, j int) bool {
	return a[j].getViewerCount() < a[i].getViewerCount()
}

// ByRecency implements sort.Interface, ordering channels by stream-up
// timestamp from newest to oldest.
type ByRecency []ChannelResult

func (a ByRecency) Len() int {
	return len(a)
}

func (a ByRecency) Swap(i, j int) {
	a[i], a[j] = a[j], a[i]
}

func (a ByRecency) Less(i, j int) bool {
	return a[j].getStreamUpTimestamp() < a[i].getStreamUpTimestamp()
}

// Cache key prefixes. updateCache stores each live channel under
// channelNamePrefix+<channel name> and, when the channel has a positive ID,
// also under channelIDPrefix+<decimal ID>.
const (
	// channelNamePrefix prefixes cache keys for lookups by channel name.
	channelNamePrefix = "channel_"
	// channelIDPrefix prefixes cache keys for lookups by channel ID.
	channelIDPrefix   = "id_"
)

// permittedIDField maps the ElasticSearch id fields that bulkGet may serve from
// the in-memory cache to the cache key prefix used for that field.
var permittedIDField = map[string]string{
	"rails.channel_id": channelIDPrefix,
	"rails.channel":    channelNamePrefix,
}

// bulkGet returns the channels whose idField value is one of ids, restricted to
// fields, filtered by filters, sorted, and paginated by limit/offset.
//
// When idField has a cache key prefix in permittedIDField and the reader has a
// cache, the lookup is served entirely from memory:
//   - Only channels present in the cache are considered.
//   - Channels missing a requested field or failing a filter are dropped.
//   - The rest are sorted, newest first by stream start for "recency" and by
//     viewer count descending otherwise.
//   - The page [offset, offset+limit) is returned, and Total counts all
//     matches before pagination.
//
// Otherwise the request is delegated to ElasticSearch via Search, and
// sortField is not applied.
func (T *ElasticSearchReader) bulkGet(idField string, ids []string, fields []string, sortField string, limit, offset int, filters ...query.Filter) (*ResultSet, *JaxDbError) {
	keyPrefix, ok := permittedIDField[idField]
	if !ok || T.Cache == nil {
		filtersWithChannels := append([]query.Filter{query.StringTermsFilter(idField, ids)}, filters...)
		return T.Search(nil,
			query.SearchQuery{
				Limit:   limit,
				Offset:  offset,
				Fields:  fields,
				Filters: filtersWithChannels,
			})
	}

	// Collect cached channels, in ids order, that carry every requested field
	// and satisfy every filter, projected down to the requested fields.
	filteredResults := make([]ChannelResult, 0, len(ids))
nextID:
	for _, id := range ids {
		cached, found := T.Cache.Get(keyPrefix + id)
		if !found {
			continue
		}
		c, isChannel := cached.(ChannelResult)
		if !isChannel {
			continue
		}

		flatProperties := FlattenProperties(c.Properties)
		if !matchFields(flatProperties, fields) {
			continue
		}
		for _, filter := range filters {
			if !filter.Valid(flatProperties) {
				continue nextID
			}
		}

		filteredResults = append(filteredResults, ChannelResult{
			Channel:    c.Channel,
			Properties: jsonifyProperties(flatProperties, fields),
		})
	}

	// sort.Stable keeps equally ranked channels in ids order. Repeated requests
	// over an unchanged cache then get the same ordering, so consecutive pages
	// neither repeat nor skip channels.
	if sortField == "recency" {
		sort.Stable(ByRecency(filteredResults))
	} else {
		sort.Stable(ByChannelCount(filteredResults))
	}

	// Slice out the requested page. A negative offset is treated as 0 and a
	// non-positive limit yields an empty page, so the slice expression can
	// never go out of range.
	if offset < 0 {
		offset = 0
	}
	paginatedResults := []ChannelResult{}
	if limit > 0 && offset < len(filteredResults) {
		end := len(filteredResults)
		if limit < end-offset {
			end = offset + limit
		}
		paginatedResults = filteredResults[offset:end]
	}

	return &ResultSet{
		Total: len(filteredResults),
		Hits:  paginatedResults,
	}, nil
}

// BulkGetByChannel looks up the named channels, keeping only those that have
// every requested field and pass every filter, sorted by sortField and
// paginated by limit/offset. The in-memory cache is used when available, with
// ElasticSearch as the fallback.
func (T *ElasticSearchReader) BulkGetByChannel(channels []string, fields []string, sortField string, limit, offset int, filters ...query.Filter) (*ResultSet, *JaxDbError) {
	const idField = "rails.channel"
	return T.bulkGet(idField, channels, fields, sortField, limit, offset, filters...)
}

// BulkGetByChannelID looks up channels by ID, keeping only those that have
// every requested field and pass every filter, sorted by sortField and
// paginated by limit/offset. The in-memory cache is used when available, with
// ElasticSearch as the fallback.
func (T *ElasticSearchReader) BulkGetByChannelID(channelIDs []string, fields []string, sortField string, limit, offset int, filters ...query.Filter) (*ResultSet, *JaxDbError) {
	const idField = "rails.channel_id"
	return T.bulkGet(idField, channelIDs, fields, sortField, limit, offset, filters...)
}

// nextPrefix joins prefix and key with "." to form a dotted property path. An
// empty prefix yields key unchanged.
func nextPrefix(prefix, key string) string {
	if prefix != "" {
		return prefix + "." + key
	}
	return key
}

// FlattenProperties converts a nested property map, such as one returned by
// ElasticSearch, into a single-level map whose keys are the nested key paths
// joined by ".". For example {"a": {"b": 1}} becomes {"a.b": 1}.
func FlattenProperties(props map[string]interface{}) map[string]interface{} {
	const rootPrefix = ""
	return flattenProperties(rootPrefix, props)
}

// flattenProperties flattens props into a single-level map, prefixing each key
// with prefix (joined by "."). Nested map[string]interface{} values are
// recursed into. Every other value, including slices, is copied as a leaf. A
// nested empty map contributes no keys.
func flattenProperties(prefix string, props map[string]interface{}) map[string]interface{} {
	flat := map[string]interface{}{}
	for key, value := range props {
		path := nextPrefix(prefix, key)
		nested, isMap := value.(map[string]interface{})
		if !isMap {
			flat[path] = value
			continue
		}
		for innerKey, innerValue := range flattenProperties(path, nested) {
			flat[innerKey] = innerValue
		}
	}
	return flat
}

// JsonifyProperties converts a flat, "."-delimited property map, such as the
// output of FlattenProperties, back into a nested structure, keeping every key.
func JsonifyProperties(props map[string]interface{}) map[string]interface{} {
	var allFields []string // an empty field list keeps every key
	return jsonifyProperties(props, allFields)
}

// jsonifyProperties takes a flat property map whose keys are "."-separated
// paths (e.g. "a.b.c") and converts it into a nested structure
// ({"a": {"b": {"c": v}}}). When field is non-empty, only keys listed in it
// are kept. An empty field list keeps every key.
//
// If a key's parent path holds a non-map value (e.g. both "a" and "a.b" are
// present), the shorter key's value is kept and the deeper key is dropped.
// This outcome does not depend on map iteration order.
func jsonifyProperties(props map[string]interface{}, field []string) map[string]interface{} {
	fieldMap := make(map[string]bool, len(field))
	for _, f := range field {
		fieldMap[f] = true
	}

	nestedProps := map[string]interface{}{}
nextProp:
	for k, v := range props {
		if len(fieldMap) > 0 && !fieldMap[k] {
			continue
		}

		parts := strings.Split(k, ".")
		parentProps := nestedProps
		// Walk, and create as needed, the intermediate maps for every path
		// segment except the last, which receives the value.
		for _, part := range parts[:len(parts)-1] {
			child, exists := parentProps[part]
			if !exists {
				child = map[string]interface{}{}
				parentProps[part] = child
			}
			childMap, ok := child.(map[string]interface{})
			if !ok {
				// A failed assertion would leave parentProps nil and the write
				// below would panic on the nil map; skip this key instead.
				log.Reportf("Failed to assign to parent properties, THIS SHOULD NEVER HAPPEN")
				continue nextProp
			}
			parentProps = childMap
		}
		parentProps[parts[len(parts)-1]] = v
	}
	return nestedProps
}

// Search runs q against the primary ElasticSearch connection and returns the
// matching channels. args are passed through to the underlying ElasticSearch
// search call. The search may also be mirrored to the replicate connection.
func (T *ElasticSearchReader) Search(args map[string]interface{}, q query.SearchQuery) (*ResultSet, *JaxDbError) {
	body := q.ToQuery()
	return T.searchConn(T.Conn, "search", args, body)
}

// searchConn runs the raw query body against conn using T.Index and T.Type,
// records its latency under "es."+statName, and converts the response into a
// ResultSet. When ReplicateRate is positive, the same request is also replayed
// asynchronously against ReplicateConn. Replica results and errors are
// ignored; only metrics under "es.replicate_"+statName are recorded.
func (T *ElasticSearchReader) searchConn(conn *es.Conn, statName string, args map[string]interface{}, body []byte) (*ResultSet, *JaxDbError) {
	if args == nil {
		args = map[string]interface{}{}
	}
	queryStr := string(body)

	t := time.Now()
	res, err := conn.Search(T.Index, T.Type, args, queryStr)
	T.Stats.TimingDuration("es."+statName, time.Since(t), 0.5)

	// Skip spawning a goroutine at all when replication is disabled.
	if T.ReplicateRate > 0 {
		go T.replicate(func() {
			T.Stats.Inc("es.replicate_"+statName, 1, 0.1)
			rt := time.Now()
			T.ReplicateConn.Search(T.Index, T.Type, args, queryStr)
			// Same "es.replicate_" prefix as the counter above.
			T.Stats.TimingDuration("es.replicate_"+statName, time.Since(rt), 0.5)
		})
	}

	if err != nil {
		return nil, formatError(err)
	}
	return formatResult(res)
}

// replicate invokes f, for a non-negative ReplicateRate, ReplicateRate times on
// average. It makes one call per whole unit of the rate, plus one extra call
// with probability equal to the fractional part. For example, a rate of 1.25
// always calls f once and calls it a second time 25% of the time.
func (T *ElasticSearchReader) replicate(f func()) {
	whole := int(T.ReplicateRate)
	fraction := T.ReplicateRate - float64(whole)
	for n := whole; n > 0; n-- {
		f()
	}
	if rand.Float64() < fraction {
		f()
	}
}

// Aggregate runs the aggregation query q against the primary connection and
// returns the decoded aggregate buckets. Latency is recorded under
// "es.search_aggregate". Unlike Search, aggregations are not mirrored to the
// replicate connection.
func (T *ElasticSearchReader) Aggregate(q query.AggregationQuery) ([]Aggregate, *JaxDbError) {
	began := time.Now()
	resp, err := T.Conn.Search(T.Index, T.Type, nil, q.ToQuery())
	T.Stats.TimingDuration("es.search_aggregate", time.Since(began), 0.5)

	if err == nil {
		return formatAggregate(q, resp)
	}
	return nil, formatError(err)
}

// Lease continuously scans ElasticSearch for channels matching scanFilters.
// It returns a channel that receives the matching ChannelResults in batches,
// with page size driven by bufferSize. After each full pass the scan pauses
// and starts over, so the stream does not end after a single pass. Metric
// names are derived from "lease."+statName.
func (T *ElasticSearchReader) Lease(statName string, bufferSize int, scanFilters ...query.Filter) chan []ChannelResult {
	const continuous = true
	return T.leaseConn(T.Conn, "lease."+statName, bufferSize, continuous, scanFilters...)
}

// Scan performs a single pass over the channels in ElasticSearch that match
// scanFilters. It returns a channel receiving them in batches, with page size
// driven by bufferSize, and closes it once the pass completes. Metric names
// are derived from "scan."+statName.
func (T *ElasticSearchReader) Scan(statName string, bufferSize int, scanFilters ...query.Filter) chan []ChannelResult {
	const continuous = false
	return T.leaseConn(T.Conn, "scan."+statName, bufferSize, continuous, scanFilters...)
}

// leaseConn streams channels matching scanFilters from conn.
//
// It opens an ElasticSearch scan/scroll (scroll window 30s, size bufferSize)
// and sends each scroll page's hits on the returned channel. A page with fewer
// than bufferSize hits, or a scroll error, ends the pass. When continuous is
// false the channel is then closed; otherwise the goroutine sleeps 20s and
// starts a new scan. Failed initial scan requests are logged and retried after
// 15s. The returned channel buffers a single page, so the producer blocks once
// the consumer falls behind by more than that.
func (T *ElasticSearchReader) leaseConn(conn *es.Conn, statName string, bufferSize int, continuous bool, scanFilters ...query.Filter) chan []ChannelResult {
	channels := make(chan []ChannelResult, 1)

	if conn == nil {
		// Nothing can ever be produced without a connection; close the channel
		// so consumers ranging over it finish instead of blocking forever.
		close(channels)
		return channels
	}

	go func() {
		start := time.Now()
		scanQuery := query.ScanQuery(scanFilters)

		scrollID := ""
		for {
			args := map[string]interface{}{"scroll": "30s", "size": bufferSize}

			// If we have a previous scroll id, we continue that query.
			if scrollID != "" {
				t := time.Now()
				res, err := conn.Scroll(args, scrollID)
				T.Stats.TimingDuration("es."+statName, time.Since(t), 0.5)

				// An error or a short page ends this pass; the next iteration
				// (after any sleep below) starts a fresh scan.
				passDone := err != nil || len(res.Hits.Hits) < bufferSize
				if err != nil {
					log.Reportf("failed to scroll elasticsearch for channels (%s): %v", statName, err)
				}
				if passDone {
					scrollID = ""
				} else {
					scrollID = res.ScrollId
				}

				// formatResult returns a nil ResultSet when a hit cannot be
				// decoded; send an empty page rather than dereferencing nil.
				hits := []ChannelResult{}
				if result, ferr := formatResult(res); ferr != nil {
					log.Reportf("failed to decode elasticsearch scroll results (%s): %s", statName, ferr.Message)
				} else {
					hits = result.Hits
				}

				T.Stats.Inc("es."+statName, 1, 0.1)
				T.Stats.Inc("es."+statName+".channels", int64(len(hits)), 0.1)
				channels <- hits

				// If we reached the end, we wait a bit so we don't spam updates.
				if passDone {
					T.Stats.Gauge("lease.duration", int64(time.Since(start).Seconds()), 1)
					if !continuous {
						close(channels)
						return
					}
					time.Sleep(20 * time.Second)
				}

				continue
			}

			start = time.Now()
			// Launch the initial query. With search_type=scan it returns no
			// hits; it only opens the scroll that the branch above continues.
			args["search_type"] = "scan"
			res, err := T.searchConn(conn, statName, args, scanQuery)
			if err != nil {
				// Sleep so we don't hammer the backend. This can happen on
				// index creation, among other situations.
				log.Reportf("failed to query elasticsearch for channels to lease (%s): %s", statName, err.Message)
				time.Sleep(15 * time.Second)
				continue
			}
			if res.ScrollID == "" {
				log.Reportf("did not receive a ScrollID when leasing channels")
				time.Sleep(15 * time.Second)
				continue
			}
			scrollID = res.ScrollID
		}
	}()

	return channels
}

// formatResult takes a result returned by the Elastigo library and formats it
// into a Jax-appropriate ResultSet.
//
// Each hit becomes a ChannelResult whose Channel is the document ID and whose
// Properties are decoded from the hit's _source or, when that is empty, its
// fields. Numbers are decoded as json.Number. A hit with neither has nil
// Properties. Any decode failure aborts the whole result with a JaxDbError.
func formatResult(resp es.SearchResult) (*ResultSet, *JaxDbError) {
	hits := make([]ChannelResult, 0, len(resp.Hits.Hits))
	for _, val := range resp.Hits.Hits {
		var m map[string]interface{}

		// The decoder is scoped to this hit. Reusing a previous hit's drained
		// decoder for a hit without _source/fields would fail with io.EOF.
		var dec *json.Decoder
		switch {
		case val.Source != nil && len(*val.Source) > 0:
			dec = json.NewDecoder(bytes.NewReader(*val.Source))
		case val.Fields != nil && len(*val.Fields) > 0:
			dec = json.NewDecoder(bytes.NewReader(*val.Fields))
		}

		if dec != nil {
			// Decode numbers as json.Number instead of float64 to prevent
			// client-side breakages.
			dec.UseNumber()
			if err := dec.Decode(&m); err != nil {
				return nil, NewJaxDbError(err.Error())
			}
		}
		hits = append(hits, ChannelResult{Channel: val.Id, Properties: m})
	}
	return &ResultSet{Total: resp.Hits.Total, Hits: hits, ScrollID: resp.ScrollId}, nil
}

// formatAggregate takes a result returned by the Elastigo library and formats it
// into a Jax-appropriate slice of aggregates. The response's Aggregations
// payload is decoded into an AggregateResponse, with numbers kept as
// json.Number, and its Results.Buckets are returned. q is currently unused.
func formatAggregate(q query.AggregationQuery, resp es.SearchResult) ([]Aggregate, *JaxDbError) {
	dec := json.NewDecoder(bytes.NewReader(resp.Aggregations))
	dec.UseNumber()

	var results AggregateResponse
	if err := dec.Decode(&results); err != nil {
		return nil, NewJaxDbError(err.Error())
	}

	buckets := results.Results.Buckets
	hits := make([]Aggregate, 0, len(buckets))
	for i := range buckets {
		hits = append(hits, buckets[i])
	}
	return hits, nil
}
