package main

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"net/url"
	"os"
	"regexp"
	"strings"
	"sync"
	"time"

	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/library/go/yandex/tvm"
	"a.yandex-team.ru/library/go/yandex/tvm/tvmauth"
	"context"
)

// ==========================================================================================

// BlackBoxHostname is the blackbox host; requests are sent to https://<host>/blackbox.
var BlackBoxHostname = "blackbox.yandex-team.ru"
// BlackBoxTvmID is the TVM client ID of blackbox: it is the only destination
// registered with the tvmauth client and the one service tickets are requested for.
var BlackBoxTvmID tvm.ClientID = 223

// Config holds the process-wide configuration. The values below are built-in
// defaults; ReadConfig decodes the JSON config file on top of them, so keys
// absent from the file keep these values. Durations are Duration wrappers and
// accept seconds or Go duration strings in the file.
var Config = ConfigStruct{
	AuthAddr:  "127.0.0.1:9090",
	AdminAddr: "127.0.0.1:9091",
	Tvm: TvmDataStruct{
		CacheDir:    "/dev/shm/bbproxy",
		InitRetries: 5,
	},
	CacheMaxTime:    Duration{10 * time.Minute},
	CacheBadTime:    Duration{1 * time.Second},
	CacheMaxBadTime: Duration{10 * time.Second},
	ClientTimeout:   Duration{1 * time.Second},
	PeerTimeout:     Duration{400 * time.Millisecond},
	Verbose:         false,
}

// ==========================================================================================

// DebugMutex guards Debug, which the admin endpoint toggles at runtime.
var DebugMutex sync.RWMutex
// Debug, when true, makes formatSecret return its input unmasked.
var Debug = false
// LCache caches blackbox outcomes keyed by user IP, OAuth token and Session_id.
var LCache = NewLoginCache()
// TClient fetches blackbox service tickets. It is only initialized when a
// non-zero TVM client ID is configured; otherwise it keeps its zero value.
var TClient TvmClient

// ==========================================================================================

// ConfigStruct is the JSON-decoded service configuration (see Config for defaults).
type ConfigStruct struct {
	// AdminAddr is the listen address of the admin (debug toggle) HTTP server.
	AdminAddr       string
	// AuthAddr is the listen address of the auth HTTP server.
	AuthAddr        string
	// Tvm holds TVM settings; ID and Secret are replaced by main from Tvm.File.
	Tvm             TvmDataStruct
	// CacheMaxTime caps how long a successful login is served from cache and
	// is also how long such an entry is retained.
	CacheMaxTime    Duration
	// CacheBadTime is how long a first failure is served from cache.
	CacheBadTime    Duration
	// CacheMaxBadTime is the retention of a first failure and the serving time
	// of a repeated one (which is retained for twice this value).
	CacheMaxBadTime Duration
	// ClientTimeout is the read/write timeout of both HTTP servers.
	ClientTimeout   Duration
	// PeerTimeout is the timeout of each blackbox request.
	PeerTimeout     Duration
	// Verbose logs every incoming connection and every cache clean.
	Verbose         bool
}

// TvmDataStruct describes TVM settings. It is used both as Config.Tvm and as
// the schema of the separate TVM file named by File; main copies ID and Secret
// from that file into Config.Tvm.
type TvmDataStruct struct {
	// ID is this service's TVM client ID; 0 disables TVM for blackbox calls.
	ID          tvm.ClientID
	// Secret is the TVM client secret passed to tvmauth.
	Secret      string
	// File is the path of the JSON file holding ID and Secret ("" = none).
	File        string
	// CacheDir is the tvmauth disk cache directory (created with mode 0700).
	CacheDir    string
	// InitRetries is the number of attempts to create the tvmauth client.
	InitRetries int
}

// BlackBoxAnswerStruct is the subset of the blackbox JSON answer this service
// reads. Untagged fields rely on encoding/json's case-insensitive name matching
// ("status", "auth", "dbfields").
type BlackBoxAnswerStruct struct {
	// Status.Value is the verdict; VALID and NEED_RESET are treated as success.
	Status struct {
		Value string
	}
	// Auth.Secure must be true for cookie (Session_id) authentication.
	Auth struct {
		Secure bool
	}
	// Dbfields holds the requested dbfields, e.g. "accounts.login.uid".
	Dbfields map[string]string
	// ExpiresIn is the credential lifetime in seconds, used to bound caching.
	// The explicit tag is required: encoding/json matches untagged names
	// case-insensitively but never maps underscores, so an "expires_in" key
	// would otherwise leave this field at 0.
	// NOTE(review): assumes blackbox names the key "expires_in" — confirm.
	ExpiresIn int32 `json:"expires_in"`
}

// ==========================================================================================

// Duration wraps time.Duration so that it can be read from and written to JSON.
type Duration struct {
	time.Duration
}

// MarshalJSON encodes the duration as its String form, e.g. "1m30s".
func (d Duration) MarshalJSON() ([]byte, error) {
	text := d.Duration.String()
	return json.Marshal(text)
}

// UnmarshalJSON accepts either a JSON number, interpreted as (possibly
// fractional) seconds, or a string in time.ParseDuration syntax such as
// "400ms". Any other JSON value yields an "invalid duration" error and leaves
// d untouched. A string that fails to parse resets d to zero and returns the
// parse error.
func (d *Duration) UnmarshalJSON(b []byte) error {
	var raw interface{}
	if err := json.Unmarshal(b, &raw); err != nil {
		return err
	}
	if seconds, isNumber := raw.(float64); isNumber {
		d.Duration = time.Duration(seconds * 1e9)
		return nil
	}
	text, isString := raw.(string)
	if !isString {
		return errors.New("invalid duration")
	}
	parsed, err := time.ParseDuration(text)
	d.Duration = parsed
	return err
}

// ==========================================================================================

// ReadConfig decodes the JSON file at path into the global Config. Keys that
// the file does not mention keep their current (default) values. Any read or
// parse error terminates the process.
func ReadConfig(path string) {
	raw, readErr := ioutil.ReadFile(path)
	if readErr != nil {
		log.Fatalf("Bad config file path %s: %s", path, readErr.Error())
	}
	parseErr := json.Unmarshal(raw, &Config)
	if parseErr != nil {
		log.Fatalf("Bad json in config file %s: %s", path, parseErr.Error())
	}
}

// ReadTvmFile decodes the JSON TVM file at path into tvmData. An empty path
// means "no TVM file" and leaves tvmData untouched. A file that cannot be read
// or parsed is fatal. Previously an unreadable file was only logged and then
// reported as "Bad json" once decoding the empty data failed.
func ReadTvmFile(path string, tvmData *TvmDataStruct) {
	if path == "" {
		return
	}
	data, err := ioutil.ReadFile(path)
	if err != nil {
		log.Fatalf("Bad tvm file path %s: %s", path, err.Error())
	}
	if err := json.Unmarshal(data, tvmData); err != nil {
		log.Fatalf("Bad json in tvm file %s: %s", path, err.Error())
	}
}

// getURL performs a GET request to urlString with the given headers and a
// whole-request timeout.
//
// On success it returns the complete response body and true. The HTTP status
// code is not inspected: the body is returned as-is for the caller to
// interpret. On failure (bad URL, transport error, or an error while reading
// the body) it returns a human-readable description as bytes and false.
// closeConnection sets Request.Close, asking the transport not to reuse the
// connection afterwards.
func getURL(urlString string,
	headers map[string]string,
	timeout time.Duration,
	closeConnection bool) ([]byte, bool) {

	client := &http.Client{Timeout: timeout}
	urlType, err := url.Parse(urlString)
	if err != nil {
		return []byte(fmt.Sprintf("Failed to encode URL %s: %s", urlString, err)), false
	}
	h := make(http.Header, len(headers))
	for k, v := range headers {
		h.Set(k, v)
	}
	request := http.Request{Method: http.MethodGet, URL: urlType, Header: h, Close: closeConnection}
	response, err := client.Do(&request)
	if err != nil {
		return []byte(fmt.Sprintf("Failed to get answer from %s: %s", urlString, err)), false
	}
	// The body must always be closed, otherwise the connection and its
	// file descriptor leak on every call.
	defer response.Body.Close()

	var data bytes.Buffer
	// A read error (e.g. the timeout firing mid-body) means the data is
	// truncated; don't hand a partial answer to the caller as a success.
	if _, err := data.ReadFrom(response.Body); err != nil {
		return []byte(fmt.Sprintf("Failed to read answer from %s: %s", urlString, err)), false
	}
	return data.Bytes(), true
}

// ==========================================================================================

// LoginCacheKey identifies a cached authentication outcome. Either OauthToken
// or SessionID may be empty, depending on which credential the request used.
type LoginCacheKey struct {
	UserIP     string
	OauthToken string
	SessionID  string
}

// LoginCacheValue is a cached authentication outcome.
type LoginCacheValue struct {
	// Login is the resolved user login; "" marks a cached failure.
	Login     string
	// LifeTime is the instant until which Authorize answers from this entry
	// without asking blackbox.
	LifeTime  time.Time
	// CacheTime is the instant after which Clean removes the entry.
	CacheTime time.Time
}

// LoginCache maps credentials to authentication outcomes. Mutex guards Map;
// use NewLoginCache so that Map is non-nil.
type LoginCache struct {
	Mutex sync.RWMutex
	Map   map[LoginCacheKey]LoginCacheValue
}

// NewLoginCache returns an empty, ready-to-use LoginCache.
func NewLoginCache() *LoginCache {
	cache := new(LoginCache)
	cache.Map = map[LoginCacheKey]LoginCacheValue{}
	return cache
}

// Get returns the entry stored for *key and whether one exists. Only the read
// lock is taken, so lookups may run concurrently with each other.
func (c *LoginCache) Get(key *LoginCacheKey) (LoginCacheValue, bool) {
	c.Mutex.RLock()
	defer c.Mutex.RUnlock()
	value, found := c.Map[*key]
	return value, found
}

// Update records the outcome of a blackbox check for *key.
//
// A non-empty login is stored as a positive entry. It is served for
// expirationTime, capped at Config.CacheMaxTime (0 also means "use the cap"),
// and retained for Config.CacheMaxTime.
//
// An empty login records a failure. A first failure is served for
// Config.CacheBadTime and retained for Config.CacheMaxBadTime. If an earlier
// *failure* for the same key is still retained (Authorize only re-checks once
// its LifeTime has passed), the penalty escalates: the entry is served for
// CacheMaxBadTime and retained for twice that. A still-retained successful
// entry does not count as an earlier failure. For example, a session that was
// valid and then got revoked starts at the short first-failure penalty.
func (c *LoginCache) Update(key *LoginCacheKey,
	login string,
	expirationTime time.Duration) {

	c.Mutex.Lock()
	defer c.Mutex.Unlock()
	now := time.Now()

	if login == "" {
		lifeTime := Config.CacheBadTime.Duration
		keepTime := Config.CacheMaxBadTime.Duration
		if prev, ok := c.Map[*key]; ok && prev.Login == "" {
			lifeTime = Config.CacheMaxBadTime.Duration
			keepTime = 2 * Config.CacheMaxBadTime.Duration
		}
		c.Map[*key] = LoginCacheValue{LifeTime: now.Add(lifeTime), CacheTime: now.Add(keepTime)}
		return
	}

	if expirationTime == 0 || expirationTime > Config.CacheMaxTime.Duration {
		expirationTime = Config.CacheMaxTime.Duration
	}
	c.Map[*key] = LoginCacheValue{Login: login, LifeTime: now.Add(expirationTime), CacheTime: now.Add(Config.CacheMaxTime.Duration)}
}

// Clean evicts every entry whose CacheTime has passed. The record count before
// and after is logged when anything was evicted, or always in verbose mode;
// logging happens after the lock is released.
func (c *LoginCache) Clean() {
	c.Mutex.Lock()
	now := time.Now()
	before := len(c.Map)
	for key, entry := range c.Map {
		if entry.CacheTime.Before(now) {
			delete(c.Map, key)
		}
	}
	after := len(c.Map)
	c.Mutex.Unlock()

	if changed := before != after; Config.Verbose || changed {
		log.Printf("Cleaning Cache, records: %d -> %d", before, after)
	}
}

// ==========================================================================================

// AuthReply carries the state of one auth request and writes its answer.
type AuthReply struct {
	// Created is when the request arrived; used for latency logs and cache age.
	Created    time.Time
	// UserIP comes from the X-REAL-IP request header.
	UserIP     string
	// Headers are the incoming request headers.
	Headers    http.Header
	// AuthAll turns every decline into an accept as user "None".
	AuthAll    bool
	// SessionID is the value of the Session_id cookie, if any.
	SessionID  string
	// OauthToken is extracted from an "Authorization: OAuth <token>" header.
	OauthToken string
	// Writer receives the response.
	Writer     http.ResponseWriter
	// UseTvm attaches a TVM service ticket to blackbox requests.
	UseTvm     bool
}

// formatSecret masks text for logging unless debug mode is on. Every character
// except the first and the last is replaced with '*'. Text of up to two
// characters is returned unchanged.
//
// Masking works on runes, so a multi-byte UTF-8 character is replaced whole
// instead of being split into an invalid byte sequence in the log.
func (a *AuthReply) formatSecret(text string) string {
	DebugMutex.RLock()
	debug := Debug
	DebugMutex.RUnlock()
	if debug {
		return text
	}
	runes := []rune(text)
	for i := 1; i < len(runes)-1; i++ {
		runes[i] = '*'
	}
	return string(runes)
}

// AcceptRequest answers 200 OK for login. It logs the masked login, client IP,
// elapsed time and the note str. It returns the login in the
// Burne-Yandex-Login header and echoes the request's Cookie header back.
func (a *AuthReply) AcceptRequest(login string, str string) {
	masked := a.formatSecret(login)
	elapsed := time.Since(a.Created)
	log.Printf("Accepted request 200 (%s@%s): %v, %s", masked, a.UserIP, elapsed, str)

	out := a.Writer.Header()
	out.Set("Content-type", "text/plain")
	out.Set("Burne-Yandex-Login", login)
	out.Set("COOKIE", a.Headers.Get("COOKIE"))
	out.Set("Connection", "Close")
	a.Writer.WriteHeader(http.StatusOK)
}

// DeclineRequest rejects the request with the given HTTP status and writes str
// as the plain-text body. When AuthAll is set, the request is instead accepted
// as user "None", with str as the log note.
//
// The request headers are logged for diagnostics. Cookie and Authorization
// values (session cookies, OAuth tokens) are masked with formatSecret first,
// so credentials only reach the log in debug mode.
func (a *AuthReply) DeclineRequest(status int, str string) {
	if a.AuthAll {
		a.AcceptRequest("None", str)
		return
	}
	log.Printf("Declined request %d: %v, %s, %s", status, time.Since(a.Created), str, a.maskedHeaders())
	a.Writer.Header().Set("Content-type", "text/plain")
	a.Writer.Header().Set("Connection", "Close")
	a.Writer.WriteHeader(status)
	// Best effort: the client may already be gone and there is nothing to do about it.
	_, _ = fmt.Fprintf(a.Writer, "%s\n", str)
}

// maskedHeaders returns a copy of a.Headers that is safe to log. Values of the
// credential-carrying headers go through formatSecret; all others are copied as-is.
func (a *AuthReply) maskedHeaders() http.Header {
	masked := make(http.Header, len(a.Headers))
	for name, values := range a.Headers {
		switch http.CanonicalHeaderKey(name) {
		case "Cookie", "Authorization":
			hidden := make([]string, len(values))
			for i, v := range values {
				hidden[i] = a.formatSecret(v)
			}
			masked[name] = hidden
		default:
			masked[name] = values
		}
	}
	return masked
}

// ReOAuth extracts the token from an "Authorization: OAuth <token>" header
// (submatch 1). The pattern is not anchored, so "OAuth " may appear anywhere in
// the header value.
var ReOAuth = regexp.MustCompile("OAuth (.+)")

// Authorize answers the request using these sources, in order:
//  1. The credentials: an OAuth token from the Authorization header, else the
//     Session_id cookie. It declines with 401 when neither is present.
//  2. A still-live LCache entry for (UserIP, token, session), served directly.
//  3. A blackbox lookup via blackBoxAuthorize.
func (a *AuthReply) Authorize() {
	if header := a.Headers.Get("AUTHORIZATION"); header != "" {
		if groups := ReOAuth.FindStringSubmatch(header); groups != nil {
			a.OauthToken = groups[1]
		}
	}
	if a.OauthToken == "" && a.SessionID == "" {
		a.DeclineRequest(http.StatusUnauthorized, "No SessionID in cookies")
		return
	}

	key := LoginCacheKey{a.UserIP, a.OauthToken, a.SessionID}
	if cached, hit := LCache.Get(&key); hit {
		// Remaining lifetime is measured from the request's arrival time.
		if lifeTime := cached.LifeTime.Sub(a.Created); lifeTime > 0 {
			if cached.Login != "" {
				a.AcceptRequest(cached.Login, fmt.Sprintf("OK from cache (%v)", lifeTime))
				return
			}
			cacheTime := cached.CacheTime.Sub(a.Created)
			a.DeclineRequest(http.StatusUnauthorized, fmt.Sprintf("BAD from cache, %v/%v", lifeTime, cacheTime))
			return
		}
	}
	a.blackBoxAuthorize(key)
}

// blackBoxAuthorize validates the request's credentials with blackbox and
// answers accordingly. An OAuth token is checked with method=oauth; otherwise
// the Session_id cookie is checked with method=sessionid. A TVM service ticket
// is attached when UseTvm is set.
//
// Outcomes:
//   - A missing TVM ticket gets 401.
//   - Transport or parse problems get 403; the cache is not touched.
//   - A status other than VALID/NEED_RESET, or cookie auth whose answer has
//     auth.secure == false, is cached as a failure and gets 401.
//   - Otherwise the accounts.login.uid dbfield is cached for ExpiresIn seconds
//     and accepted. A success without a login gets 403.
func (a *AuthReply) blackBoxAuthorize(key LoginCacheKey) {
	query := url.Values{}
	query.Set("userip", a.UserIP)
	query.Set("format", "json")
	query.Set("dbfields", "accounts.login.uid")
	switch {
	case a.OauthToken != "":
		query.Set("method", "oauth")
		query.Set("oauth_token", a.OauthToken)
	default:
		query.Set("method", "sessionid")
		query.Set("sessionid", a.SessionID)
		query.Set("host", "yandex-team.ru")
	}

	headers := map[string]string{}
	if a.UseTvm {
		ticket, ok := TClient.GetTicket()
		if !ok {
			a.DeclineRequest(http.StatusUnauthorized, "Failed to get tvm ticket for blackbox request")
			return
		}
		headers["X-Ya-Service-Ticket"] = ticket
	}

	endpoint := fmt.Sprintf("https://%s/blackbox?%s", BlackBoxHostname, query.Encode())
	body, ok := getURL(endpoint, headers, Config.PeerTimeout.Duration, false)
	trimmed := strings.Trim(string(body), "\n\r")
	if !ok {
		a.DeclineRequest(http.StatusForbidden, trimmed)
		return
	}

	var answer BlackBoxAnswerStruct
	if err := json.Unmarshal(body, &answer); err != nil {
		a.DeclineRequest(http.StatusForbidden, fmt.Sprintf("Failed to parse blackbox answer (%s)", trimmed))
		return
	}

	// NEED_RESET means the cookies are old but still valid.
	status := answer.Status.Value
	switch {
	case status != "VALID" && status != "NEED_RESET":
		LCache.Update(&key, "", 0)
		a.DeclineRequest(http.StatusUnauthorized, fmt.Sprintf("Failed to validate cookies or token via blackbox (%s)", trimmed))
		return
	case a.OauthToken == "" && !answer.Auth.Secure:
		LCache.Update(&key, "", 0)
		a.DeclineRequest(http.StatusUnauthorized, fmt.Sprintf("Failed to validate secure cookie via blackbox (%s)", trimmed))
		return
	}

	if login := answer.Dbfields["accounts.login.uid"]; login != "" {
		LCache.Update(&key, login, time.Duration(answer.ExpiresIn)*time.Second)
		a.AcceptRequest(login, "OK")
		return
	}
	a.DeclineRequest(http.StatusForbidden, fmt.Sprintf("Should not happen (%s)", body))
}

// authHandle serves the auth endpoint. The request is described entirely by
// its headers:
//   - X-REAL-IP is the end user's address and is required.
//   - X-AUTH-ALL, when non-empty, turns declines into accepts.
//   - Cookie supplies Session_id.
//   - Authorization is parsed later by Authorize.
//
// NOTE(review): presumably these headers are set by a trusted fronting proxy —
// X-REAL-IP is taken at face value.
func authHandle(w http.ResponseWriter, r *http.Request) {
	a := AuthReply{
		Created: time.Now(),
		Writer:  w,
		AuthAll: (r.Header.Get("X-AUTH-ALL") != ""),
		UserIP:  r.Header.Get("X-REAL-IP"),
		Headers: r.Header,
		UseTvm:  (Config.Tvm.ID != 0),
	}
	if a.UserIP == "" {
		a.DeclineRequest(http.StatusForbidden, "No real user IP")
		return
	}
	if Config.Verbose {
		log.Printf("Connection from %s", a.UserIP)
	}
	a.SessionID = sessionIDFromCookies(r.Header.Get("COOKIE"))
	a.Authorize()
}

// sessionIDFromCookies returns the value of the last "Session_id" cookie in a
// raw Cookie header, or "" if there is none. Spaces are stripped first. An
// entry without "=" is skipped rather than indexed out of range, which used to
// panic on client-controlled input.
func sessionIDFromCookies(header string) string {
	sessionID := ""
	for _, cookie := range strings.Split(strings.ReplaceAll(header, " ", ""), ";") {
		parts := strings.SplitN(cookie, "=", 2)
		if len(parts) == 2 && parts[0] == "Session_id" {
			sessionID = parts[1]
		}
	}
	return sessionID
}

// authHTTP runs the auth server on Config.AuthAddr with ClientTimeout for
// reads and writes. It blocks, and terminates the process if the server stops.
func authHTTP() {
	timeout := Config.ClientTimeout.Duration
	srv := &http.Server{
		Addr:           Config.AuthAddr,
		Handler:        http.HandlerFunc(authHandle),
		ReadTimeout:    timeout,
		WriteTimeout:   timeout,
		MaxHeaderBytes: 1 << 20,
	}
	err := srv.ListenAndServe()
	log.Fatal(err)
}

// ==========================================================================================

// adminHandle serves the admin endpoint. "debug=on" and "debug=off" toggle
// Debug; any other request is reported as "Unknown request". A request whose
// form cannot be parsed gets 400 instead of an empty 200.
//
// NOTE(review): there is no authentication here; it relies on AdminAddr being
// bound to a trusted interface (the default is 127.0.0.1).
func adminHandle(w http.ResponseWriter, r *http.Request) {
	if err := r.ParseForm(); err != nil {
		log.Printf("[Admin] parsing request from %s error: %s", r.RemoteAddr, err)
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	var logString, reportString string
	switch mode := r.Form.Get("debug"); mode {
	case "on", "off":
		DebugMutex.Lock()
		Debug = mode == "on"
		DebugMutex.Unlock()
		logString = "debug=" + mode
		reportString = "Setting debug to " + mode
	default:
		logString = fmt.Sprintf("unknown parameters %s", r.Form)
		reportString = "Unknown request"
	}
	log.Printf("[Admin] request from %s: %s", r.RemoteAddr, logString)
	_, _ = fmt.Fprintf(w, "%s\n", reportString)
}

// adminHTTP runs the admin server on Config.AdminAddr with ClientTimeout for
// reads and writes. It blocks, and terminates the process if the server stops.
func adminHTTP() {
	timeout := Config.ClientTimeout.Duration
	srv := &http.Server{
		Addr:           Config.AdminAddr,
		Handler:        http.HandlerFunc(adminHandle),
		ReadTimeout:    timeout,
		WriteTimeout:   timeout,
		MaxHeaderBytes: 1 << 20,
	}
	err := srv.ListenAndServe()
	log.Fatal(err)
}

// ==========================================================================================

// TvmClient wraps the tvmauth client used to fetch blackbox service tickets.
// Its zero value has a nil client; TClient keeps that zero value when TVM is
// not configured.
type TvmClient struct {
	client *tvmauth.Client
}

// NewTvmClient creates the TVM client used to obtain blackbox service tickets.
//
// It first ensures Config.Tvm.CacheDir exists with mode 0700. The explicit
// Chmod is needed because MkdirAll leaves an existing directory's mode alone.
// It then makes up to Config.Tvm.InitRetries attempts to build a tvmauth
// client for Config.Tvm.ID/Secret, with blackbox (BlackBoxTvmID) as the only
// destination. Retriable *tvm.Error failures are retried with doubling
// back-off capped at 15s. Any other error, or running out of attempts,
// terminates the process via log.Fatalf.
func NewTvmClient() TvmClient {

	if err := os.MkdirAll(Config.Tvm.CacheDir, 0700); err != nil {
		log.Fatalf("Failed to create cache directory: %v", err)
	}
	if err := os.Chmod(Config.Tvm.CacheDir, 0700); err != nil {
		log.Fatalf("Failed to chmod cache directory: %v", err)
	}
	sleep := 1 * time.Second
	maxSleep := 15 * time.Second
	for i := 0; i < Config.Tvm.InitRetries; i++ {
		// Settings are rebuilt per attempt. NOTE(review): they could be hoisted
		// only if NewAPIClient neither retains nor mutates them — unverified.
		env := tvm.BlackboxProdYateam
		options := tvmauth.NewIDsOptions(
			Config.Tvm.Secret,
			[]tvm.ClientID{BlackBoxTvmID})
		settings := tvmauth.TvmAPISettings{
			SelfID:                      Config.Tvm.ID,
			EnableServiceTicketChecking: false,
			BlackboxEnv:                 &env,
			ServiceTicketOptions:        options,
			DiskCacheDir:                Config.Tvm.CacheDir,
		}
		c, err := tvmauth.NewAPIClient(settings, &nop.Logger{})
		if err == nil {
			log.Printf("Successfully initialized tvm client")
			return TvmClient{c}
		}
		// A plain err.(*tvm.Error) assertion would panic on any other error
		// type and hide the real cause.
		e, isTvmErr := err.(*tvm.Error)
		if !isTvmErr {
			log.Fatalf("Failed to initialize tvm client: %v", err)
		}
		if !e.Retriable {
			log.Fatalf("Failed to initialize tvm client (%s): %s", e.Code, e.Msg)
		}
		log.Printf("Cannot initialize tvm client: %s", e.Msg)
		if i == Config.Tvm.InitRetries-1 {
			// No attempts left: fail now instead of sleeping for nothing.
			break
		}
		sleep *= 2
		if sleep > maxSleep {
			sleep = maxSleep
		}
		log.Printf("Sleeping %v before next retry", sleep)
		time.Sleep(sleep)
	}
	log.Fatalf("Failed to initialize tvm client in %d retries", Config.Tvm.InitRetries)
	return TvmClient{}
}

// GetTicket returns a TVM service ticket for blackbox (BlackBoxTvmID) and
// whether one was obtained. It makes up to 3 attempts, pausing 100ms between
// them after a retriable *tvm.Error. Non-retriable or unexpected errors end
// the attempts immediately. Every failure is logged. A zero-value TvmClient
// (no underlying client) reports failure instead of dereferencing nil.
func (c TvmClient) GetTicket() (string, bool) {
	const attempts = 3
	sleep := 100 * time.Millisecond
	if c.client == nil {
		log.Printf("Cannot get tvm ticket: tvm client is not initialized")
		return "", false
	}
	for i := 0; i < attempts; i++ {
		ticketStr, err := c.client.GetServiceTicketForID(context.Background(), BlackBoxTvmID)
		if err == nil {
			return ticketStr, true
		}
		e, isTvmErr := err.(*tvm.Error)
		if !isTvmErr {
			log.Printf("Cannot get tvm ticket: %v", err)
			return "", false
		}
		log.Printf("Cannot get tvm ticket (%s): %s", e.Code, e.Msg)
		if !e.Retriable {
			return "", false
		}
		if i < attempts-1 {
			time.Sleep(sleep)
		}
	}
	return "", false
}

// ==========================================================================================

// main loads the configuration, initializes TVM when a client ID is
// configured, starts the admin and auth HTTP servers, and then evicts stale
// login-cache entries every 5 seconds.
//
// Usage: <binary> <config_file>. TVM ID and Secret are taken only from the
// file named by Config.Tvm.File; values under "Tvm" in the main config for
// those two fields are overwritten.
func main() {
	log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)
	if len(os.Args) != 2 {
		log.Fatalf("Usage: %s <config_file>", os.Args[0])
	}
	configFileName := os.Args[1]

	var tvmConfig TvmDataStruct
	ReadConfig(configFileName)
	ReadTvmFile(Config.Tvm.File, &tvmConfig)

	Config.Tvm.ID = tvmConfig.ID
	// The real secret is installed only after the config has been logged.
	// Also clear any secret that came from the main config file. It is
	// overwritten below anyway, but json.Marshal would print it in clear text.
	Config.Tvm.Secret = ""
	if cfg, err := json.Marshal(Config); err != nil {
		log.Printf("Failed to serialize config for logging: %v", err)
	} else {
		log.Printf("Using config (tvm secret masked): %s", string(cfg))
	}
	Config.Tvm.Secret = tvmConfig.Secret

	if Config.Tvm.ID != 0 {
		TClient = NewTvmClient()
	}
	go adminHTTP()
	go authHTTP()

	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	for range ticker.C {
		LCache.Clean()
	}
}
