package whitelist

import (
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"net"
	"net/http"
	"net/url"
	"sync"
	"time"

	"code.justin.tv/common/envoy"
)

type Whitelist struct {
	cidrs []net.IPNet
	asns  map[string]bool
	mu    sync.RWMutex
	url   *url.URL
}

const (
	// updateInterval is the delay between successive whitelist fetches in Run.
	updateInterval = 30 * time.Second

	// clientTimeout caps a whole HTTP request made by the fetch client,
	// from dialing through reading the response body.
	clientTimeout = 15 * time.Second

	// dialerTimeout caps establishing the TCP connection to the whitelist server.
	dialerTimeout = 5 * time.Second
)

var (
	// whitelistsURLPath is the URL path, on the configured host, that serves
	// the whitelists JSON document.
	whitelistsURLPath = flag.String("whitelists-path", "whitelists.json", "Path to get the whitelists")

	// client performs the whitelist fetches in getWhitelists; Run assigns it.
	client *http.Client

	// whitelistFetchStatus reports the outcome of the latest fetch; Run assigns it.
	whitelistFetchStatus *envoy.State
)

// makeURL builds the plain-HTTP URL of the whitelists document served at
// path on host:port. net.JoinHostPort brackets IPv6 literals, so
// ("::1", "80") yields the host "[::1]:80".
func makeURL(host string, port string, path string) *url.URL {
	return &url.URL{
		Scheme: "http",
		Host:   net.JoinHostPort(host, port),
		// Use the caller-supplied path rather than reading the global flag,
		// so the parameter actually takes effect.
		Path: path,
	}
}

// New returns a Whitelist that fetches from host:port, at the path given by
// the -whitelists-path flag. The flag is dereferenced here, so a non-default
// path only takes effect if flags were parsed before New is called. The
// returned Whitelist holds no entries until Run completes a successful fetch.
func New(host string, port string) *Whitelist {
	wl := new(Whitelist)
	wl.url = makeURL(host, port, *whitelistsURLPath)
	return wl
}

// getWhitelists fetches and decodes the whitelists JSON document at address
// using the package-level client. It returns the parsed CIDR ranges and the
// ASN ID strings exactly as listed in the document.
//
// Any non-200 response is reported as an error, so that an error page is
// never decoded as (and swapped in as) a whitelist. CIDR entries that fail to
// parse are logged and skipped rather than failing the whole fetch.
func getWhitelists(address string) ([]net.IPNet, []string, error) {
	response, err := client.Get(address)
	if err != nil {
		return nil, nil, fmt.Errorf("could not get current whitelist: %v", err)
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		return nil, nil, fmt.Errorf("could not get current whitelist: unexpected status %q", response.Status)
	}

	// Untagged fields: encoding/json matches JSON keys to these names
	// case-insensitively.
	var doc struct {
		ASNWhitelist  []string
		CIDRWhitelist []string
	}
	if err := json.NewDecoder(response.Body).Decode(&doc); err != nil {
		return nil, nil, fmt.Errorf("could not decode whitelists: %v", err)
	}

	res := make([]net.IPNet, 0, len(doc.CIDRWhitelist))
	for _, entry := range doc.CIDRWhitelist {
		_, cidr, err := net.ParseCIDR(entry)
		if err != nil {
			log.Printf("could not parse cidr: %v", err)
			continue
		}
		res = append(res, *cidr)
	}

	return res, doc.ASNWhitelist, nil
}

// Run initializes the package-level HTTP client and fetch-status state, then
// keeps the whitelists up to date. It fetches once immediately, so lookups
// are not empty for a full updateInterval after startup, and again on every
// tick. Run never returns, so callers typically start it in its own goroutine.
//
// NOTE(review): Run assigns the package-level client and
// whitelistFetchStatus without synchronization, so running several
// Whitelists concurrently races on them.
func (W *Whitelist) Run() {
	client = &http.Client{
		Transport: &http.Transport{
			Dial: (&net.Dialer{
				Timeout: dialerTimeout,
			}).Dial,
		},
		Timeout: clientTimeout,
	}

	whitelistFetchStatus = envoy.NewState("envoy_fetch", "Whitelist fetching", envoy.ConditionUnknown)

	ticker := time.NewTicker(updateInterval)
	defer ticker.Stop()

	for {
		W.refresh()
		<-ticker.C
	}
}

// refresh performs a single fetch. On success it swaps in the new whitelists
// under the write lock; on failure it keeps the previous data and marks the
// fetch status critical.
func (W *Whitelist) refresh() {
	addr := W.url.String()
	cidrs, asns, err := getWhitelists(addr)
	if err != nil {
		whitelistFetchStatus.Critical(fmt.Sprintf("Unable to fetch whitelists from %v: %v", addr, err))
		log.Printf("unable to get whitelists: %v", err)
		return
	}
	whitelistFetchStatus.Normal("Successfully fetched whitelists")

	// Build the lookup map before taking the lock to keep the critical
	// section to a pointer swap.
	asnMap := make(map[string]bool, len(asns))
	for _, asn := range asns {
		asnMap[asn] = true
	}

	W.mu.Lock()
	W.cidrs = cidrs
	W.asns = asnMap
	W.mu.Unlock()
}

// ContainsIP reports whether sip, a textual IPv4 or IPv6 address, falls
// inside any of the currently loaded CIDR ranges. Input that does not parse
// as an IP is logged and reported as not whitelisted.
func (W *Whitelist) ContainsIP(sip string) bool {
	// Parse and log outside the lock. They touch only the argument, and
	// holding the read lock across logging I/O needlessly delays Run's
	// writer (and, with RWMutex, every reader queued behind it).
	// Parsing on every call is not currently a significant cost.
	ip := net.ParseIP(sip)
	if ip == nil {
		log.Printf("could not parse ip: %v", sip)
		return false
	}

	W.mu.RLock()
	defer W.mu.RUnlock()

	for i := range W.cidrs {
		if W.cidrs[i].Contains(ip) {
			return true
		}
	}

	return false
}

// ContainsASNID reports whether asnID appears in the currently loaded ASN
// whitelist. asnID is the numeric part only: for "AS12345 Foo" pass "12345".
// Before the first successful fetch the whitelist is empty and this returns
// false.
func (W *Whitelist) ContainsASNID(asnID string) bool {
	W.mu.RLock()
	_, found := W.asns[asnID]
	W.mu.RUnlock()
	return found
}
