package resolver

import (
	"a.yandex-team.ru/solomon/libs/go/cache"
	"a.yandex-team.ru/solomon/libs/go/workerpool"

	"encoding/json"
	"fmt"
	"log"
	"net"
	"reflect"
	"strconv"
	"strings"
	"time"
)

// ==========================================================================================

// RRType selects which DNS lookup a Request performs. Each value also fixes
// the Go type stored in Record.RDATA (see the per-constant comments and
// typeTypes). In JSON it is encoded by name, e.g. "MX".
type RRType int

const (
	NONE  RRType = iota // Auto-detect: resolved as PTR when the name parses as an IP address, as A otherwise; RDATA follows the chosen type
	PTR                 // Record RDATA type: []string
	A                   // Record RDATA type: []net.IP
	CNAME               // Record RDATA type: string
	MX                  // Record RDATA type: []*net.MX
	NS                  // Record RDATA type: []*net.NS
	TXT                 // Record RDATA type: []string
)

// typeNames maps an RRType to its canonical name (used by String).
var typeNames = []string{
	NONE:  "NONE",
	PTR:   "PTR",
	A:     "A",
	CNAME: "CNAME",
	MX:    "MX",
	NS:    "NS",
	TXT:   "TXT",
}

// typeTypes maps an RRType to the dynamic type of Record.RDATA for that type;
// Record.UnmarshalJSON uses it to rebuild RDATA. NONE has no entry (its slot
// is nil) because NONE requests are rewritten to PTR or A by poolWorker
// before they are resolved.
var typeTypes = []reflect.Type{
	PTR:   reflect.TypeOf((*[]string)(nil)).Elem(),
	A:     reflect.TypeOf((*[]net.IP)(nil)).Elem(),
	CNAME: reflect.TypeOf((*string)(nil)).Elem(),
	MX:    reflect.TypeOf((*[]*net.MX)(nil)).Elem(),
	NS:    reflect.TypeOf((*[]*net.NS)(nil)).Elem(),
	TXT:   reflect.TypeOf((*[]string)(nil)).Elem(),
}

// nameTypes is the inverse of typeNames (used by UnmarshalJSON).
var nameTypes = map[string]RRType{
	"NONE":  NONE,
	"PTR":   PTR,
	"A":     A,
	"CNAME": CNAME,
	"MX":    MX,
	"NS":    NS,
	"TXT":   TXT,
}

// String returns the name of r, or "TYPE<n>" for a value without a name.
// Negative values take the "TYPE<n>" path instead of indexing out of range.
func (r RRType) String() string {
	if r >= 0 && int(r) < len(typeNames) {
		return typeNames[r]
	}
	return "TYPE" + strconv.Itoa(int(r))
}

// MarshalJSON encodes r as the JSON string of its String form.
func (r RRType) MarshalJSON() ([]byte, error) {
	return json.Marshal(r.String())
}

// UnmarshalJSON decodes a JSON string holding one of the names in nameTypes.
// On failure *r is left unchanged.
func (r *RRType) UnmarshalJSON(b []byte) error {
	var s string
	if err := json.Unmarshal(b, &s); err != nil {
		return err
	}
	t, ok := nameTypes[s]
	if !ok {
		return fmt.Errorf("no such RRType %s", s)
	}
	*r = t
	return nil
}

// ==========================================================================================

// poolJob carries one request through the worker pool: poolWorker reads
// Request and MinLastUpdate and stores the answer in Record.
type poolJob struct {
	Request       *Request
	Record        *Record
	MinLastUpdate *time.Time
}

// Request asks for the records of Type for Name. With Type NONE the resolver
// picks PTR or A from the shape of Name.
type Request struct {
	Name string `json:"n"`
	Type RRType `json:"t"`
}

// Record is the answer to a Request. RDATA holds the lookup result with the
// Go type listed for Type in the RRType constants (see typeTypes); Error holds
// the lookup or cache error, if any.
//
// NOTE(review): Error is an interface encoded with encoding/json's default
// rules, so a non-nil error is typically written as a JSON object, which
// UnmarshalJSON cannot decode back into an error value (only null works).
// Confirm whether failed records ever reach Dump/Restore.
type Record struct {
	Name  string      `json:"n"`
	Type  RRType      `json:"t"`
	RDATA interface{} `json:"r"`
	Error error       `json:"e"`
}

// UnmarshalJSON decodes a Record and restores RDATA with the concrete Go type
// that typeTypes lists for the record's type, so assertions such as
// RDATA.([]net.IP) keep working after a JSON round trip.
//
// A missing or null "r" (e.g. a nil slice, which encoding/json writes as
// null) yields the typed zero value, such as []net.IP(nil). Types without an
// RDATA type (NONE, or values outside the table) are rejected with an error.
func (r *Record) UnmarshalJSON(b []byte) error {
	var t struct {
		Name  string           `json:"n"`
		Type  RRType           `json:"t"`
		Raw   *json.RawMessage `json:"r"`
		Error error            `json:"e"`
	}
	if err := json.Unmarshal(b, &t); err != nil {
		return err
	}
	// typeTypes[NONE] is nil; reflect.New(nil) would panic.
	if t.Type < 0 || int(t.Type) >= len(typeTypes) || typeTypes[t.Type] == nil {
		return fmt.Errorf("invalid %T %v for record RDATA", t.Type, t.Type)
	}
	data := reflect.New(typeTypes[t.Type])
	// A JSON null (or an absent field) leaves t.Raw nil; keep the zero value.
	if t.Raw != nil {
		if err := json.Unmarshal(*t.Raw, data.Interface()); err != nil {
			return err
		}
	}
	r.Name = t.Name
	r.Type = t.Type
	r.RDATA = data.Elem().Interface()
	r.Error = t.Error
	return nil
}

// ==========================================================================================

// Resolver is a caching DNS resolver. Lookups go through a cache that is able
// to serve stale items and prefetch, and requests are dispatched through a
// worker pool. Create one with NewResolver. Its public methods are:
//
//	Resolv([]*Request, *time.Time) []*Record
//	ResolvSimple([]string) map[string][]string
//	Purge()
//	Destroy()
//	Dump(onlyFresh bool) ([]byte, error)
//	Restore([]byte) error
type Resolver struct {
	// LogPrefix is prepended to every message written by log.
	LogPrefix    string
	// FixDots: for non-PTR requests, trailing dots are trimmed from the
	// request name and a single dot is appended before querying; trailing
	// dots are trimmed from host names returned by PTR/CNAME/MX/NS lookups.
	FixDots      bool
	// VerboseLevel: log emits messages whose level is <= VerboseLevel.
	// NewResolver also hands the initial value to the cache and worker pool;
	// changing this field later does not update those.
	VerboseLevel int
	// resolver, when non-nil, is called by cacheResolver instead of the
	// net.Lookup* functions. Nothing in this file sets it (presumably a test
	// hook — confirm).
	resolver     func(req Request) (interface{}, error)
	// cache maps Request keys to *Record values; cacheResolver is its loader.
	cache        *cache.Cache
	// workerpool runs poolWorker for every job submitted by Resolv.
	workerpool   *workerpool.WorkerPool
}

// NewResolver builds a Resolver with LogPrefix "[resolver] ".
//
// goodCacheTime, badCacheTime, prefetchTime, cleanUpInterval, serveStale,
// verboseLevel and cacheSize are passed unchanged to cache.NewCache (a cache
// named "resolver" whose loader is cacheResolver). workers sizes the worker
// pool named "resolver" that runs poolWorker; the pool is made verbose when
// verboseLevel > 1. fixDots and verboseLevel also initialise the FixDots and
// VerboseLevel fields.
func NewResolver(goodCacheTime, badCacheTime, prefetchTime, cleanUpInterval time.Duration,
	cacheSize, workers int,
	serveStale bool,
	fixDots bool,
	verboseLevel int) *Resolver {
	res := &Resolver{
		LogPrefix:    "[resolver] ",
		FixDots:      fixDots,
		VerboseLevel: verboseLevel,
	}

	// Both callbacks close over res, so they see the fully wired resolver.
	load := func(key interface{}) (interface{}, error) {
		return res.cacheResolver(key.(Request))
	}
	work := func(job interface{}) {
		res.poolWorker(job.(*poolJob))
	}

	res.cache = cache.NewCache("resolver", load,
		goodCacheTime, badCacheTime, prefetchTime, cleanUpInterval,
		serveStale, verboseLevel, cacheSize)
	res.workerpool = workerpool.NewWorkerPool("resolver", workers, work, verboseLevel > 1)
	return res
}

// log writes a message through the standard logger when r.VerboseLevel is at
// least lvl. The message is LogPrefix followed by format expanded with v; when
// ts is non-nil, ", <time elapsed since *ts>" is appended.
func (r *Resolver) log(lvl int, ts *time.Time, format string, v ...interface{}) {
	if r.VerboseLevel < lvl {
		return
	}
	// Expand only the caller's format. LogPrefix is an exported, user-settable
	// field and must not be interpreted as printf verbs; neither should the
	// elapsed-time suffix.
	msg := r.LogPrefix + fmt.Sprintf(format, v...)
	if ts != nil {
		msg += ", " + time.Since(*ts).String()
	}
	log.Print(msg)
}

// cacheResolver is the cache loader installed by NewResolver. It performs the
// DNS lookup for req and returns a *Record holding the answer together with
// the lookup error, which is also stored in Record.Error.
//
// When r.resolver is set, its result is returned instead and no lookup is made.
// With FixDots, a dot is appended to non-PTR names before querying, and
// trailing dots are trimmed from host names in PTR/CNAME/MX/NS answers.
// A request type with no lookup implementation yields an error instead of a
// silently empty record.
func (r *Resolver) cacheResolver(req Request) (interface{}, error) {
	if r.resolver != nil {
		return r.resolver(req)
	}
	if r.VerboseLevel >= 2 {
		// Guarded here as well as inside log so the arguments are not boxed
		// when verbose logging is off.
		r.log(2, nil, "resolving %v for %v", req.Name, req.Type)
	}

	name := req.Name
	if req.Type != PTR && r.FixDots {
		// Callers going through poolWorker have already trimmed trailing dots,
		// so this produces a single trailing dot (a fully-qualified name).
		name += "."
	}
	trim := func(host string) string {
		if r.FixDots {
			return strings.TrimRight(host, ".")
		}
		return host
	}

	rec := &Record{Name: req.Name, Type: req.Type}
	var err error
	switch req.Type {
	case PTR:
		var names []string
		names, err = net.LookupAddr(name)
		for i := range names {
			names[i] = trim(names[i])
		}
		rec.RDATA = names
	case A:
		var ips []net.IP
		ips, err = net.LookupIP(name)
		rec.RDATA = ips
	case CNAME:
		var cname string
		cname, err = net.LookupCNAME(name)
		rec.RDATA = trim(cname)
	case MX:
		var mxs []*net.MX
		mxs, err = net.LookupMX(name)
		for _, mx := range mxs {
			mx.Host = trim(mx.Host)
		}
		rec.RDATA = mxs
	case NS:
		var nss []*net.NS
		nss, err = net.LookupNS(name)
		for _, ns := range nss {
			ns.Host = trim(ns.Host)
		}
		rec.RDATA = nss
	case TXT:
		var txts []string
		txts, err = net.LookupTXT(name)
		rec.RDATA = txts
	default:
		// Previously this fell through and was cached as a successful lookup
		// with nil RDATA.
		err = fmt.Errorf("unsupported record type %v", req.Type)
	}
	rec.Error = err
	if err != nil {
		r.log(1, nil, "resolution failure %s %s: %v", name, rec.Type, err)
	}
	return rec, err
}

// poolWorker is the worker-pool callback for one poolJob.
//
// It first normalizes job.Request in place, so the caller's *Request is
// modified. A NONE type becomes PTR when the name parses as an IP address and
// A otherwise. With FixDots, trailing dots are trimmed from non-PTR names.
//
// It then looks the request up in the cache. GetForceFresh is used when
// job.MinLastUpdate is set, Get otherwise.
//
// The result stored in job.Record is a shallow copy of the cached record, with
// Error set to the error the cache reported for this lookup. RDATA still
// shares its backing data with the cached record and must not be mutated.
func (r *Resolver) poolWorker(job *poolJob) {
	req := job.Request
	if req.Type == NONE {
		if net.ParseIP(req.Name) != nil {
			req.Type = PTR
		} else {
			req.Type = A
		}
	}
	if req.Type != PTR && r.FixDots {
		req.Name = strings.TrimRight(req.Name, ".")
	}

	var (
		cached   interface{}
		cacheErr error
	)
	if job.MinLastUpdate != nil {
		cached, cacheErr = r.cache.GetForceFresh(*req, job.MinLastUpdate)
	} else {
		cached, cacheErr = r.cache.Get(*req)
	}

	// The cache presumably hands the same *Record to every lookup of a key.
	// Writing Error into it directly would race with other workers and with
	// readers such as Dump, so build a private copy instead.
	var out Record
	if rec, ok := cached.(*Record); ok && rec != nil {
		out = *rec
	} else {
		// Defensive: no record came back. Report the request instead of
		// panicking inside the pool goroutine.
		out = Record{Name: req.Name, Type: req.Type}
		if cacheErr == nil {
			cacheErr = fmt.Errorf("no record cached for %s %v", req.Name, req.Type)
		}
	}
	out.Error = cacheErr
	job.Record = &out
}

// Resolv submits one poolJob per request to the worker pool and returns the
// resulting records in the same order as reqs. minLastUpdate is copied into
// every job; when non-nil, poolWorker uses the cache's GetForceFresh with it.
// workerpool.Do is expected to return only after every job has run.
//
// Requests are normalized in place by poolWorker (NONE becomes PTR or A; with
// FixDots, trailing dots are trimmed from non-PTR names).
func (r *Resolver) Resolv(reqs []*Request, minLastUpdate *time.Time) []*Record {
	started := time.Now()

	tasks := make([]interface{}, 0, len(reqs))
	for _, q := range reqs {
		tasks = append(tasks, &poolJob{Request: q, MinLastUpdate: minLastUpdate})
	}
	r.workerpool.Do(tasks)

	out := make([]*Record, 0, len(tasks))
	for _, task := range tasks {
		out = append(out, task.(*poolJob).Record)
	}
	r.log(2, &started, "resolve %d records", len(out))
	return out
}

// ResolvSimple resolves each name in rs with an auto-detected type (PTR for IP
// addresses, A otherwise) and returns the answers keyed by record name. A
// records map to their IPs in textual form, and PTR records to their host
// names. Any other record type maps to an empty slice.
//
// Keys are the normalized request names, so with FixDots an input's trailing
// dot is not part of its key. Lookup errors are not reported here; use Resolv
// to inspect Record.Error.
func (r *Resolver) ResolvSimple(rs []string) map[string][]string {
	reqs := make([]*Request, len(rs))
	for i, name := range rs {
		reqs[i] = &Request{Name: name, Type: NONE}
	}
	recs := r.Resolv(reqs, nil)

	result := make(map[string][]string, len(recs))
	for _, rec := range recs {
		if rec == nil {
			continue
		}
		switch rec.Type {
		case A:
			// Comma-ok: a missing or mistyped RDATA yields an empty answer
			// instead of panicking the caller.
			netIPs, _ := rec.RDATA.([]net.IP)
			ips := make([]string, len(netIPs))
			for i, ip := range netIPs {
				ips[i] = ip.String()
			}
			result[rec.Name] = ips
		case PTR:
			names, _ := rec.RDATA.([]string)
			result[rec.Name] = names
		default:
			result[rec.Name] = []string{}
		}
	}
	return result
}

// Purge logs at level 1 and empties the resolver cache by delegating to
// cache.Purge.
func (r *Resolver) Purge() {
	c := r.cache
	r.log(1, nil, "purging")
	c.Purge()
}

// Destroy shuts the resolver down. It first stops the worker pool by calling
// Stop(true), then destroys the cache.
func (r *Resolver) Destroy() {
	pool, c := r.workerpool, r.cache
	r.log(1, nil, "destroying")
	pool.Stop(true)
	c.Destroy()
}

// Dump serializes the cache contents via cache.Dump, forwarding onlyFresh
// unchanged (presumably it restricts the dump to non-stale entries; confirm
// in the cache package). The logged count is the total cache size, not the
// number of records actually dumped.
func (r *Resolver) Dump(onlyFresh bool) ([]byte, error) {
	c := r.cache
	r.log(1, nil, "dumping %d records (only fresh = %v)", c.Len(), onlyFresh)
	return c.Dump(onlyFresh)
}

// Restore loads data produced by Dump back into the cache via cache.Restore.
// Request{} and &Record{} are passed presumably as the key and value
// prototypes for decoding. bumpEOL is always false; its meaning is defined by
// cache.Restore. On success the resulting cache size is logged.
func (r *Resolver) Restore(inData []byte) error {
	r.log(1, nil, "restoring")
	const bumpEOL = false
	if err := r.cache.Restore(inData, bumpEOL, Request{}, &Record{}); err != nil {
		return err
	}
	r.log(1, nil, "restored %d records", r.cache.Len())
	return nil
}
