package splunk

import (
	"crypto/tls"
	"encoding/csv"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"net/url"
	"regexp"
	"strings"
	"time"

	"a.yandex-team.ru/security/osquery/osquery-consistency/internal/config"
	"a.yandex-team.ru/security/osquery/osquery-consistency/internal/misc"
	"a.yandex-team.ru/security/osquery/osquery-consistency/internal/vault"
)

// connection bundles everything needed to talk to one Splunk instance.
type connection struct {
	// username and password are the credentials posted to
	// /services/auth/login to obtain a session key.
	username string
	password string
	// hostname is the prefix every endpoint path is appended to.
	hostname string
	// jobsPath is the endpoint path (relative to hostname) that searches are
	// POSTed to.
	jobsPath string
	// client is the HTTP client used for search requests.
	client *http.Client
}

// matchTag returns the text enclosed by the first <tag>...</tag> pair in
// responseBody. It is a lightweight way to pull a single value (such as the
// session key) out of an XML response without decoding the whole document.
//
// The tag name is escaped before being embedded in the pattern, so it is
// matched literally. The match is non-greedy and stops at the first closing
// tag. `.` does not match newlines, so the value must sit on one line. If the
// tag is not present, matchTag returns "" instead of panicking.
func matchTag(tag string, responseBody []byte) (matched string) {
	quoted := regexp.QuoteMeta(tag)
	re := regexp.MustCompile("<" + quoted + ">(.*?)</" + quoted + ">")
	groups := re.FindSubmatch(responseBody)
	if groups == nil {
		return ""
	}
	return string(groups[1])
}

// getSessionKey logs in to Splunk by POSTing conn's username and password to
// {hostname}/services/auth/login, and returns the value of the <sessionKey>
// element in the response.
//
// The process exits via log.Fatal if the login returns a non-2xx status or the
// response carries no session key, because no later request can succeed
// without one.
func getSessionKey(conn connection) (sessionkey string) {
	log.Println("trying to obtain sessionkey...")
	// NOTE(review): this disables TLS verification process-wide by mutating
	// http.DefaultTransport. It is kept because conn.client may be built on the
	// default transport (misc.CreateHTTPClient) — confirm, then move this into
	// that client's own tls.Config.
	http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
	// Use the same client sendQuery uses for this host instead of
	// http.DefaultClient, which has no timeout.
	resp, err := conn.client.PostForm(conn.hostname+"/services/auth/login", url.Values{"username": {conn.username}, "password": {conn.password}})
	misc.ErrorCheck(err)
	defer func() { misc.ErrorCheck(resp.Body.Close()) }()
	body, err := ioutil.ReadAll(resp.Body)
	misc.ErrorCheck(err)
	if resp.StatusCode/100 != 2 {
		log.Fatalf("splunk login failed: %s: %s", resp.Status, body)
	}
	sessionkey = matchTag("sessionKey", body)
	if sessionkey == "" {
		log.Fatal("splunk login response contains no sessionKey")
	}
	// The key is a bearer credential; record only that it was obtained.
	log.Println("got sessionkey")
	return sessionkey
}

// sendQuery submits query as a Splunk search to conn.hostname+conn.jobsPath.
// It authenticates with sessionkey, requests output_mode=csv, and returns the
// raw response body.
//
// A non-2xx response is retried one second later, up to maxAttempts times in
// total; after that the process exits via log.Fatalf. A 2xx body containing a
// Splunk ERROR or FATAL message also terminates the process.
func sendQuery(conn connection, query, sessionkey string) (result string) {
	// maxAttempts bounds the retry loop so a persistent failure (bad query,
	// expired session) cannot hang the run forever.
	const maxAttempts = 10
	const retryDelay = 1 * time.Second

	endpoint := conn.hostname + conn.jobsPath
	data := url.Values{}
	data.Set("search", query)
	data.Set("output_mode", "csv")
	form := data.Encode()

	// attempt performs one POST and returns its status code and body. A fresh
	// request is built each time because a request body is consumed by the
	// first Do and cannot be re-sent. Doing the read and close inside this
	// closure releases each response before the next attempt starts.
	attempt := func() (int, []byte) {
		req, err := http.NewRequest("POST", endpoint, strings.NewReader(form))
		misc.ErrorCheck(err)
		req.Header.Add("Authorization", "Splunk "+sessionkey)
		req.Header.Add("Content-Type", "application/x-www-form-urlencoded")

		resp, err := conn.client.Do(req)
		misc.ErrorCheck(err)
		defer func() { misc.ErrorCheck(resp.Body.Close()) }()
		misc.QuerySentLog(endpoint, resp.StatusCode)
		body, err := ioutil.ReadAll(resp.Body)
		misc.ErrorCheck(err)
		return resp.StatusCode, body
	}

	var status int
	var body []byte
	for n := 1; ; n++ {
		status, body = attempt()
		if status/100 == 2 {
			break
		}
		if n == maxAttempts {
			log.Fatalf("splunk search failed after %d attempts: HTTP %d: %s", maxAttempts, status, body)
		}
		time.Sleep(retryDelay)
	}

	text := string(body)
	if strings.Contains(text, `<msg type="ERROR">`) ||
		strings.Contains(text, `<msg type="FATAL">`) {
		log.Fatal("ERROR" + text)
	}
	return text
}

// splitResult parses CSV search output into a map keyed by each record's first
// column, with its second column as the value. Further columns are ignored, and
// a later record with the same key overwrites an earlier one. Empty input
// yields an empty, non-nil map.
//
// NOTE(review): every record is stored, including a CSV header row if the
// search output has one — confirm whether callers expect it to be filtered out.
func splitResult(searchResult string) (hosts map[string]string) {
	resultReader := csv.NewReader(strings.NewReader(searchResult))
	hosts = map[string]string{}

	for {
		record, err := resultReader.Read()
		if err == io.EOF {
			break
		}
		misc.ErrorCheck(err)
		// Fail with a descriptive message instead of an index-out-of-range
		// panic when the search returned fewer than two columns.
		if len(record) < 2 {
			log.Fatalf("splunk: expected at least 2 CSV columns, got %d: %q", len(record), record)
		}
		hosts[record[0]] = record[1]
	}
	return hosts
}

// SearchDryRun is a stand-in for Search that contacts nothing and reports no
// hosts; query is ignored. It returns an empty, non-nil map, the same shape
// Search returns for an empty result. Callers can therefore use (and write to)
// either result without a nil check.
func SearchDryRun(query string) (hosts map[string]string) {
	return map[string]string{}
}

func Search(query string) (hosts map[string]string) {
	log.Println("splunk start")
	searchResult := ""
	conn := connection{
		username: vault.GetSplunkCreds().Username,
		password: vault.GetSplunkCreds().Password,
		hostname: config.GetSplunkConfig().Hostname,
		jobsPath: config.GetSplunkConfig().JobsPath,
		client:   misc.CreateHTTPClient(),
	}

	log.Printf("splunk url: %s%s", conn.hostname, conn.jobsPath)
	sessionkey := getSessionKey(conn)
	searchResult = sendQuery(conn, query, sessionkey)
	hosts = splitResult(searchResult)

	log.Printf("splunk done\n\n")
	return
}
