package iprep

import (
	"database/sql"
	"fmt"
	"hash/fnv"
	"math"
	"math/rand"
	"strconv"
	"strings"
	"sync"
	"time"

	"code.justin.tv/abuse/shodan/dbopt"
	"code.justin.tv/abuse/shodan/whitelist"

	_ "github.com/mattn/go-sqlite3" // Cause we need that driver

	geo "github.com/abh/geoip"
)

// Source is a source of IP reputation: it decides how much a view that
// originates from a given IP address should count.
type Source interface {
	// Reload re-reads all external data (databases, lookup files) from disk.
	Reload() error

	// FracViewsForIP returns the fractional number of views that a token
	// coming from ip is worth. IPs that are not considered trustworthy
	// usually count for 0 views.
	//
	// When ipRepInfo is non-nil, an implementation may store diagnostic
	// details about the lookup in it.
	FracViewsForIP(ip string, ipRepInfo *interface{}) (float64, error)
}

// opts is the option-map interface the reputation source reads its tuning
// knobs (gvc.iprep.*) from; NewSource falls back to dbopt.DefaultMap when
// none is supplied. Every getter takes an option name and a fallback value
// def (presumably returned when the option is unset — confirm against dbopt).
type opts interface {
	// Integer-valued options.
	GetInt(name string, def int) int
	GetInt32(name string, def int32) int32
	GetInt64(name string, def int64) int64

	// Floating-point options.
	GetFloat32(name string, def float32) float32
	GetFloat64(name string, def float64) float64
}

// stdSource is the standard Source implementation. It resolves an IP to its
// ASN with a MaxMind database and scores that ASN from a SQLite database.
type stdSource struct {
	// optMap supplies the gvc.iprep.* tuning options.
	optMap opts

	// prevCacheBust is the gvc.iprep.cache_bust value repCache was built
	// under; when that option changes, the cache is discarded.
	prevCacheBust int32

	// asnLoc and dbLoc are the files (re)opened by Reload.
	asnLoc, dbLoc string

	whitelist *whitelist.Whitelist // ASN IDs whose views always count fully
	asn       *geo.GeoIP           // resolves IPs to ASN names
	db        *sql.DB              // SQLite database holding vba_score
	repCache  map[string]float64   // ASN ID -> cached fractional-view result

	queryStmt *sql.Stmt // vba_score lookup by version and ASN ID
	isDummy   bool      // set when asnLoc or dbLoc is empty; lookups return 1

	// RWMutex guards the handles that Reload swaps in.
	sync.RWMutex
}

const (
	// defMinRequestsForScore is the default for gvc.iprep.score_minreq. When
	// our database holds fewer requests than this for an ASN, that ASN's
	// score is treated as 0, for two reasons:
	//
	// 1) With very little data, the predictions made by the ASN reputation
	//    queries are not trustworthy.
	//
	// 2) Exposing scores for small ASNs with little data would make our
	//    logic relatively easy to reverse engineer.
	defMinRequestsForScore = 57809

	// defVBadThreshold is the default for gvc.iprep.vbad_thresh. An ASN whose
	// score is >= this threshold is considered very bad; its views are scaled
	// down to the gvc.iprep.nerf fraction, which defaults to 0 (not counted).
	defVBadThreshold = 60
)

// NewSource constructs a reputation source and loads its data via Reload.
//
// asnLoc is the path to the MaxMind ASN database and dbLoc the path to the
// SQLite database holding the ASN score data. If either path is empty the
// returned source is a dummy that counts every view fully. A nil optMap
// falls back to dbopt.DefaultMap("gvc.iprep.*").
//
// The source is returned even when Reload fails, so callers must check the
// error before using it.
func NewSource(asnLoc string, dbLoc string, whitelist *whitelist.Whitelist, optMap opts) (Source, error) {
	src := &stdSource{
		optMap:        optMap,
		prevCacheBust: -1,
		asnLoc:        asnLoc,
		dbLoc:         dbLoc,
		whitelist:     whitelist,
		repCache:      make(map[string]float64),
		isDummy:       asnLoc == "" || dbLoc == "",
	}
	if src.optMap == nil {
		src.optMap = dbopt.DefaultMap("gvc.iprep.*")
	}

	return src, src.Reload()
}

// Reload opens the MaxMind ASN database at s.asnLoc and the SQLite score
// database at s.dbLoc, prepares the score query, and swaps them into s. It
// is a no-op for dummy sources.
//
// The new handles are opened before the lock is taken, so lookups are not
// blocked while files are read. If anything fails, the previously loaded
// data stays in place and the database handle opened by this call is
// closed. On success the cached per-ASN results (computed from the old data)
// are discarded, and the previous statement and database handle are closed.
// Closing them is safe because FracViewsForIP holds s's lock for its whole
// duration, so no lookup can still be using them once the swap is done.
func (s *stdSource) Reload() error {
	if s.isDummy {
		return nil
	}

	asn, err := geo.Open(s.asnLoc)
	if err != nil {
		return err
	}

	db, err := sql.Open("sqlite3", s.dbLoc)
	if err != nil {
		return err
	}

	queryStmt, err := db.Prepare("select dscore, count from vba_score where version = ? and client_asn_id = ?")
	if err != nil {
		// Best effort: the Prepare error is the one worth reporting.
		db.Close()
		return err
	}

	s.Lock()
	oldDB, oldStmt := s.db, s.queryStmt
	s.asn = asn
	s.db = db
	s.queryStmt = queryStmt
	// Results cached from the previous database may no longer be valid.
	s.repCache = make(map[string]float64)
	s.Unlock()

	// Release the replaced handles; previously they leaked on every reload.
	if oldStmt != nil {
		oldStmt.Close()
	}
	if oldDB != nil {
		oldDB.Close()
	}

	return nil
}

// stdSourceInfo carries diagnostic details about a single FracViewsForIP
// lookup. It is stored into the caller's ipRepInfo when one is supplied.
type stdSourceInfo struct {
	ASNName       string // ASN name reported by the MaxMind database
	ASNID         string // numeric part of ASNName
	IsWhitelisted bool   // the ASN ID was empty or on the whitelist

	// Per-version scores and the time-blended score derived from them.
	Score1, Score2, BlendedScore int

	Jitter           float64 // per-ASN offset added to Progress
	Progress         float64 // position between gvc.iprep.t1 and t2, in [0, 1]
	JitteredProgress float64 // Progress + Jitter, clamped to [0, 1]
}

// nullInfo is the scratch stdSourceInfo passed to fracViewsForASNID for
// regular (non-debug) lookups. Only its address matters: fracViewsForASNID
// reads from and writes to the score cache only when handed &nullInfo. Its
// fields may be overwritten by any lookup and carry no meaning.
var nullInfo stdSourceInfo

// FracViewsForIP returns the fraction of a view that a token from ip is
// worth; see Source.
//
// Dummy sources (built with an empty asnLoc or dbLoc) always return 1. IPs
// with no known ASN count fully. If the ASN name reported for ip cannot be
// parsed, 1 is returned together with the parse error. When info is non-nil
// it is set to a stdSourceInfo describing the lookup; such debug lookups
// bypass the score cache.
func (s *stdSource) FracViewsForIP(ip string, info *interface{}) (float64, error) {
	if s.isDummy {
		return 1, nil
	}

	// An exclusive lock is required, not a read lock: fracViewsForASNID
	// mutates s.repCache, s.prevCacheBust and the shared nullInfo. Under RLock
	// concurrent callers raced on the map, which is a fatal runtime error.
	s.Lock()
	defer s.Unlock()

	// Only the ASN name is needed from the GeoIP lookup.
	asnName, _ := s.asn.GetName(ip)
	asnID, err := asnNameToID(asnName)
	if err != nil {
		return 1, err
	}

	if info == nil {
		return s.fracViewsForASNID(asnID, &nullInfo)
	}

	myInfo := stdSourceInfo{
		ASNName: asnName,
		ASNID:   asnID,
	}
	frac, err := s.fracViewsForASNID(asnID, &myInfo)
	*info = myInfo
	return frac, err
}

// scoreForASNID returns the reputation score used for asnID.
//
// Two data versions, gvc.iprep.v1 and gvc.iprep.v2, are looked up in the
// vba_score table. A version with no row, or whose row is backed by fewer
// than gvc.iprep.score_minreq requests, scores 0. When gvc.iprep.t1 equals
// gvc.iprep.t2 the v2 score is returned as is. Otherwise the score moves
// linearly from the v1 score at Unix time t1 to the v2 score at t2. Each
// ASN's progress along that ramp is offset by a jitter in
// [-gvc.iprep.asn_jitter, +gvc.iprep.asn_jitter] derived from an FNV-1a hash
// of asnID, so different ASNs switch over at slightly different times.
//
// The intermediate values are recorded in info, including on the t1 == t2
// path. Query failures other than sql.ErrNoRows are returned wrapped with
// the ASN and version.
func (s *stdSource) scoreForASNID(asnID string, info *stdSourceInfo) (int, error) {
	v1 := s.optMap.GetInt("gvc.iprep.v1", -1)
	v2 := s.optMap.GetInt("gvc.iprep.v2", -1)

	t1 := s.optMap.GetInt64("gvc.iprep.t1", 0)
	t2 := s.optMap.GetInt64("gvc.iprep.t2", 0)

	// lookup reads the score and request count for one data version. A
	// missing row is not an error: it leaves both values at 0.
	lookup := func(version int) (score, count int, err error) {
		err = s.queryStmt.QueryRow(version, asnID).Scan(&score, &count)
		switch {
		case err == sql.ErrNoRows:
			return 0, 0, nil
		case err != nil:
			return 0, 0, fmt.Errorf("Error getting score for AS%s (version %d): %w", asnID, version, err)
		}
		return score, count, nil
	}

	score1, count1, err := lookup(v1)
	if err != nil {
		return 0, err
	}
	score2, count2, err := lookup(v2)
	if err != nil {
		return 0, err
	}

	// Scores backed by too little data are untrustworthy and would make our
	// logic easy to reverse engineer.
	minReq := s.optMap.GetInt("gvc.iprep.score_minreq", defMinRequestsForScore)
	if count1 < minReq {
		score1 = 0
	}
	if count2 < minReq {
		score2 = 0
	}

	info.Score1 = score1
	info.Score2 = score2

	if t1 == t2 {
		// No transition window; this also avoids dividing by t2-t1 == 0.
		info.BlendedScore = score2
		return score2, nil
	}

	// Deterministic per-ASN jitter. fnv's Write never returns an error.
	fnvHash := fnv.New32a()
	fnvHash.Write([]byte(asnID))
	unit := float64(fnvHash.Sum32()) / float64(math.MaxUint32)
	centered := (unit - 0.5) * 2
	jitter := s.optMap.GetFloat64("gvc.iprep.asn_jitter", 0.2) * centered

	now := time.Now().Unix()
	progress := clampFloat64(float64(now-t1)/float64(t2-t1), 0, 1)
	jProgress := clampFloat64(progress+jitter, 0, 1)
	blended := int(float64(score2-score1)*jProgress) + score1

	info.Jitter = jitter
	info.Progress = progress
	info.JitteredProgress = jProgress
	info.BlendedScore = blended

	return blended, nil
}

// clampFloat64 limits val to the closed interval [min, max]: values below
// min become min, values above max become max, and anything else (including
// NaN, which compares false both ways) is returned unchanged.
func clampFloat64(val float64, min float64, max float64) float64 {
	if val < min {
		return min
	}
	if val > max {
		return max
	}
	return val
}

// fracViewsForASNID returns the fraction of views that tokens from asnID
// are worth, consulting and maintaining s.repCache.
//
// info receives diagnostic details. Passing &nullInfo marks a regular
// (non-debug) lookup: only those are answered from, and written to, the
// cache; debug lookups always recompute so that info gets filled in.
//
// The whole cache is dropped whenever gvc.iprep.cache_bust changes, and a
// regular lookup skips the cache with probability gvc.iprep.cache_decay so
// that entries get refreshed over time. An empty or whitelisted ASN ID
// counts fully. An ASN whose score reaches gvc.iprep.vbad_thresh keeps only
// the gvc.iprep.nerf fraction (clamped to [0, 1]) of its views. On a score
// lookup error, 0 is returned along with the error.
func (s *stdSource) fracViewsForASNID(asnID string, info *stdSourceInfo) (float64, error) {
	regular := info == &nullInfo

	bust := s.optMap.GetInt32("gvc.iprep.cache_bust", -1)
	decay := s.optMap.GetFloat64("gvc.iprep.cache_decay", 0.001)
	stale := bust != s.prevCacheBust

	// Fast path: serve regular lookups from the cache unless it is stale or
	// this lookup was randomly picked for a refresh.
	if regular && !stale {
		refresh := rand.Float64() < decay
		if !refresh {
			if v, ok := s.repCache[asnID]; ok {
				return v, nil
			}
		}
	}

	if stale {
		s.prevCacheBust = bust
		s.repCache = make(map[string]float64)
	}

	if asnID == "" || s.isWhitelisted(asnID) {
		info.IsWhitelisted = true
		return 1, nil
	}

	score, err := s.scoreForASNID(asnID, info)
	if err != nil {
		return 0, err
	}

	frac := 1.0
	if score >= s.optMap.GetInt("gvc.iprep.vbad_thresh", defVBadThreshold) {
		nerf := clampFloat64(s.optMap.GetFloat64("gvc.iprep.nerf", 0.0), 0.0, 1.0)
		frac -= 1 - nerf
	}

	// Debug lookups must not populate the cache.
	if regular {
		s.repCache[asnID] = frac
	}

	return frac, nil
}

// isWhitelisted reports whether asnID is on the source's ASN whitelist.
func (s *stdSource) isWhitelisted(asnID string) bool {
	wl := s.whitelist
	return wl.ContainsASNID(asnID)
}

// asnNameToID extracts the numeric ASN ID from a MaxMind ASN name such as
// "AS15169 Google Inc.", returning "15169". An empty name yields an empty
// ID and no error. Otherwise the name must start with "AS", followed by an
// unsigned decimal number (fitting in 64 bits) that runs up to the first
// space or the end of the string.
func asnNameToID(asnName string) (string, error) {
	if len(asnName) == 0 {
		return "", nil
	}

	rest := strings.TrimPrefix(asnName, "AS")
	if len(rest) == len(asnName) {
		// No "AS" prefix was removed.
		return "", fmt.Errorf("Failed to parse asn id from %v", asnName)
	}

	id := rest
	if sp := strings.IndexByte(rest, ' '); sp >= 0 {
		id = rest[:sp]
	}

	if _, err := strconv.ParseUint(id, 10, 64); err != nil {
		return "", fmt.Errorf("Failed to parse asn id from %v", asnName)
	}
	return id, nil
}
