package main

import (
	"crypto/tls"
	"fmt"
	"log"
	"strings"
	"time"

	"gopkg.in/ldap.v3"
)

// LDAPTypeComputer is the objectCategory value placed into the LDAP search
// filters in this file to restrict results to computer objects.
const LDAPTypeComputer = "computer"

// LDAPCacheRequestType selects the operation LDAPCache performs for a
// LDAPCacheRequest (see LDAPCacheGet and LDAPCacheSet).
type LDAPCacheRequestType uint

// LDAPCacheResponse is the reply LDAPCache sends on a request's Out channel.
type LDAPCacheResponse struct {
	// found reports whether the hostname was present for a get request; it is
	// always true after a set request and false for an unknown request type.
	found bool
	// dn is the cached (or just stored) distinguished name; empty when nothing
	// was found or the request type was unknown.
	dn string
	// err is non-nil only when the request type is not recognised.
	err error
}

// LDAPCacheRequest is a single get/set message sent to the LDAPCache goroutine.
type LDAPCacheRequest struct {
	// Type selects the cache operation.
	Type LDAPCacheRequestType
	// Hostname is the cache key.
	Hostname string
	// DN is the value stored by a LDAPCacheSet request; LDAPCache does not read
	// it for LDAPCacheGet.
	DN string
	// Out is the channel on which LDAPCache sends exactly one response per
	// request.
	Out chan LDAPCacheResponse
}

// Operations understood by LDAPCache.
const (
	// LDAPCacheGet looks up the DN stored for LDAPCacheRequest.Hostname.
	LDAPCacheGet LDAPCacheRequestType = iota
	// LDAPCacheSet stores LDAPCacheRequest.DN under LDAPCacheRequest.Hostname.
	LDAPCacheSet
)

// ldapConfig bundles the connection settings and channels shared by the LDAP
// helpers in this file.
type ldapConfig struct {
	// Server is the URL handed to ldap.DialURL.
	Server string
	// UserDN and UserPassword are the credentials used for the LDAP bind.
	UserDN       string
	UserPassword string
	// BaseDN is the search base for directory queries.
	BaseDN string
	// workers is not read by any code visible in this file; presumably the
	// number of ldapWorker goroutines to start — TODO confirm against the caller.
	workers int
	// In is the channel from which ldapWorker receives requests.
	In chan Request
	// Cache is the channel used to talk to the LDAPCache goroutine.
	Cache chan LDAPCacheRequest
}

// ldapWorker serves requests arriving on config.In forever. For every request
// it performs the LDAP operation selected by the request type, sends exactly
// one Response on the request's Out channel and then marks the request's wait
// group as done.
//
// Supported request types:
//   - LDAPAddComputerToGroup: Data must be a string DN; on success the DN is
//     echoed back in Response.Data.
//   - LDAPFindComputerDN: Data must be a string hostname; on success the
//     resolved DN is returned in Response.Data.
//   - LDAPGetBroGroupMembers: Data is ignored; Response carries whatever
//     LDAPGetGroupMembers returned.
//
// Any other request type, or a payload of the wrong dynamic type, yields a
// Response with a nil Data and a descriptive error.
func ldapWorker(config ldapConfig) {
	for {
		request := <-config.In

		// Start from a zero-valued reply so that nothing leaks from the
		// previous request; each branch fills in only what it needs.
		var reply Response

		switch request.Type {
		case LDAPAddComputerToGroup:
			if memberDN, isString := request.Data.(string); !isString {
				reply.err = fmt.Errorf("ldapWorker(): bad request LDAPAddComputerToGroup, get data type %T", request.Data)
			} else if modifyErr := LDAPModifyGroupMembers(config, memberDN, true); modifyErr != nil {
				reply.err = fmt.Errorf("ldapWorker(): %q:%w", memberDN, modifyErr)
			} else {
				reply.Data = memberDN
			}

		case LDAPFindComputerDN:
			if host, isString := request.Data.(string); !isString {
				reply.err = fmt.Errorf("ldapWorker(): bad request LDAPFindComputerDN, get data type %T", request.Data)
			} else if hostDN, findErr := FindHostDN(host, config); findErr != nil {
				reply.err = fmt.Errorf("ldapWorker(): %q:%w", host, findErr)
			} else {
				reply.Data = hostDN
			}

		case LDAPGetBroGroupMembers:
			reply.Data, reply.err = LDAPGetGroupMembers(config)

		default:
			reply.err = fmt.Errorf("ldapWorker(): bad request type %d", request.Type)
		}

		// Reply first, then release the waiter — same order for every branch.
		request.Out <- reply
		request.wg.Done()
	}
}

// LDAPCache owns an in-memory hostname → DN map and serves it over ch. It
// loops forever, answering every request with exactly one LDAPCacheResponse
// on the request's Out channel:
//
//   - LDAPCacheGet: found/dn reflect the map lookup for Hostname.
//   - LDAPCacheSet: DN is stored under Hostname; the reply echoes it with
//     found set to true.
//   - anything else: found is false, dn is empty and err describes the
//     unknown request type.
//
// Because the map is touched only by the goroutine running this function, no
// locking is used here.
func LDAPCache(ch chan LDAPCacheRequest) {
	hostDNs := map[string]string{}

	for {
		request := <-ch

		var reply LDAPCacheResponse
		switch request.Type {
		case LDAPCacheSet:
			hostDNs[request.Hostname] = request.DN
			reply = LDAPCacheResponse{found: true, dn: request.DN}
		case LDAPCacheGet:
			storedDN, present := hostDNs[request.Hostname]
			reply = LDAPCacheResponse{found: present, dn: storedDN}
		default:
			reply.err = fmt.Errorf("LDAPCache(): unknown request type %d", request.Type)
		}

		request.Out <- reply
	}
}

// findHostDN performs a single, uncached directory lookup for hostname: it
// dials config.Server, upgrades the connection with StartTLS, binds as
// config.UserDN and searches below config.BaseDN for a computer object whose
// CN equals hostname.
//
// It returns the object's DN, or an empty string with a nil error when
// findDNByCN reports no match. Every failure is wrapped with the step that
// produced it. The connection is closed before returning.
//
// NOTE(review): the TLS configuration sets InsecureSkipVerify, so the server
// certificate is not verified.
func findHostDN(hostname string, config ldapConfig) (string, error) {
	conn, dialErr := ldap.DialURL(config.Server)
	if dialErr != nil {
		return "", fmt.Errorf("findHostDN(): dial URL: %w", dialErr)
	}
	defer conn.Close()

	if tlsErr := conn.StartTLS(&tls.Config{InsecureSkipVerify: true}); tlsErr != nil {
		return "", fmt.Errorf("findHostDN(): start TLS: %w", tlsErr)
	}

	if bindErr := conn.Bind(config.UserDN, config.UserPassword); bindErr != nil {
		return "", fmt.Errorf("findHostDN(): bind: %w", bindErr)
	}

	hostDN, searchErr := findDNByCN(conn, config.BaseDN, hostname, LDAPTypeComputer)
	if searchErr != nil {
		return "", fmt.Errorf("findHostDN(): %w", searchErr)
	}

	return hostDN, nil
}

// FindHostDN resolves hostname to the DN of its computer object.
//
// The LDAPCache goroutine behind config.Cache is consulted first; on a hit the
// cached DN is returned without touching the directory. On a miss the
// directory is queried through findHostDN, retrying up to five times with a
// growing pause (1s, 2s, ...) after each failed attempt. A successful lookup
// that produced a non-empty DN is then stored in the cache.
//
// Returns:
//   - the DN and a nil error on success (the DN is empty when the host does
//     not exist in the directory; such misses are not cached);
//   - an empty DN and the last lookup error when every attempt failed;
//   - a wrapped cache error when the cache rejected the get or set request.
func FindHostDN(hostname string, config ldapConfig) (dn string, err error) {
	cacheReq := LDAPCacheRequest{
		Type:     LDAPCacheGet,
		Hostname: hostname,
		Out:      make(chan LDAPCacheResponse),
	}
	config.Cache <- cacheReq
	resp := <-cacheReq.Out
	if resp.err != nil {
		return "", fmt.Errorf("FindHostDN(): find in cache:%w", resp.err)
	}
	if resp.found {
		return resp.dn, nil
	}

	for attempt := 0; attempt < 5; attempt++ {
		dn, err = findHostDN(hostname, config)
		if err == nil {
			break
		}

		log.Printf("FindHostDN(): attempt %d: %s", attempt, err.Error())
		time.Sleep(time.Duration(1*attempt+1) * time.Second)
	}

	// Cache only successful, non-empty results. The previous condition tested
	// err != nil, which can never hold together with a non-empty dn (findHostDN
	// returns "" on error), so the cache was never populated.
	if err == nil && dn != "" {
		cacheReq.Type = LDAPCacheSet
		cacheReq.DN = dn
		config.Cache <- cacheReq

		// The same Out channel is reused for the set request's reply.
		if resp = <-cacheReq.Out; resp.err != nil {
			err = fmt.Errorf("FindHostDN(): store in cache:%w", resp.err)
		}
	}

	return dn, err
}

// findDNByCN searches the whole subtree below baseDN on the already bound
// connection l for an object whose objectCategory equals objectType and whose
// CN equals objectName.
//
// Both values are escaped with ldap.EscapeFilter before being placed into the
// search filter, so characters with special meaning in LDAP filters (such as
// '*', '(', ')' and '\') in a caller-supplied name are matched literally and
// cannot alter the structure of the query.
//
// Returns:
//   - the DN of the single matching entry;
//   - "" with a nil error when nothing matches;
//   - an error when the search fails or more than one entry matches.
func findDNByCN(l *ldap.Conn, baseDN string, objectName string, objectType string) (string, error) {
	filter := fmt.Sprintf(
		"(&(objectCategory=%s)(CN=%s))",
		ldap.EscapeFilter(objectType),
		ldap.EscapeFilter(objectName),
	)
	searchRequest := ldap.NewSearchRequest(
		baseDN,
		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
		filter,
		[]string{"dn", "cn"},
		nil,
	)

	sr, err := l.Search(searchRequest)
	if err != nil {
		return "", fmt.Errorf("findDNByCN(): search: %w", err)
	}

	switch len(sr.Entries) {
	case 0:
		// Not found is not an error; callers see an empty DN.
		return "", nil
	case 1:
		return sr.Entries[0].DN, nil
	default:
		return "", fmt.Errorf("findDNByCN(): more than one %s %q found", objectType, objectName)
	}
}

// LDAPGetGroupMembers returns the group members reported by
// ldapGetGroupMembers, retrying failed calls. Up to five attempts are made;
// each failure is logged and followed by a pause that grows by one second per
// attempt (1s, 2s, ...). The result of the first successful attempt is
// returned; if every attempt fails, the values from the last attempt are
// returned.
func LDAPGetGroupMembers(config ldapConfig) (members []string, err error) {
	const maxAttempts = 5

	for try := 0; try < maxAttempts; try++ {
		members, err = ldapGetGroupMembers(config)
		if err == nil {
			break
		}

		log.Printf("LDAPGetGroupMembers(): attempt %d: %s", try, err.Error())
		time.Sleep(time.Duration(1*try+1) * time.Second)
	}

	return members, err
}

// ldapGetGroupMembers performs a single directory query for the computer
// objects that are members of the hard-coded group groupDN. It dials
// config.Server, upgrades the connection with StartTLS, binds as
// config.UserDN and searches the whole subtree below config.BaseDN.
//
// It returns the lower-cased DNs of all matching entries (nil when there are
// none). Every failure is wrapped with the step that produced it. The
// connection is closed before returning.
//
// NOTE(review): the TLS configuration sets InsecureSkipVerify, so the server
// certificate is not verified.
func ldapGetGroupMembers(config ldapConfig) (members []string, err error) {
	l, err := ldap.DialURL(config.Server)
	if err != nil {
		return nil, fmt.Errorf("getGroupMembers(): dial URL: %w", err)
	}
	defer l.Close()

	if err = l.StartTLS(&tls.Config{InsecureSkipVerify: true}); err != nil {
		return nil, fmt.Errorf("getGroupMembers(): start TLS: %w", err)
	}

	if err = l.Bind(config.UserDN, config.UserPassword); err != nil {
		// This message used to carry the "findHostDN()" prefix, which pointed
		// at the wrong function when diagnosing bind failures.
		return nil, fmt.Errorf("getGroupMembers(): bind: %w", err)
	}

	const groupDN = "CN=INFRAWIN-576,OU=AVS,OU=Services,DC=ld,DC=yandex,DC=ru"
	searchRequest := ldap.NewSearchRequest(
		config.BaseDN,
		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
		fmt.Sprintf("(&(objectCategory=%s)(memberOf=%s))", LDAPTypeComputer, groupDN),
		[]string{"dn", "cn"},
		nil,
	)

	sr, err := l.Search(searchRequest)
	if err != nil {
		return nil, fmt.Errorf("getGroupMembers(): search: %w", err)
	}

	// DNs are normalised to lower case before being returned.
	for _, entry := range sr.Entries {
		members = append(members, strings.ToLower(entry.DN))
	}

	return members, nil
}

// LDAPModifyGroupMembers adds dn to (add == true) or removes dn from
// (add == false) the group handled by ldapModifyGroupMembers, retrying failed
// calls. Up to five attempts are made; each failure is logged and followed by
// a pause that grows by one second per attempt (1s, 2s, ...). It returns nil
// as soon as an attempt succeeds, otherwise the error of the last attempt.
func LDAPModifyGroupMembers(config ldapConfig, dn string, add bool) (err error) {
	const maxAttempts = 5

	for try := 0; try < maxAttempts; try++ {
		if err = ldapModifyGroupMembers(config, dn, add); err == nil {
			return nil
		}

		log.Printf("LDAPModifyGroupMembers(): attempt %d: %s", try, err.Error())
		time.Sleep(time.Duration(1*try+1) * time.Second)
	}

	return err
}

// ldapModifyGroupMembers performs a single modify operation on the hard-coded
// group groupDN: it dials config.Server, upgrades the connection with
// StartTLS, binds as config.UserDN and then adds dn to the group's "member"
// attribute when add is true, or deletes it from that attribute otherwise.
//
// It returns nil on success; every failure is wrapped with the step that
// produced it. The connection is closed before returning.
//
// NOTE(review): the TLS configuration sets InsecureSkipVerify, so the server
// certificate is not verified.
func ldapModifyGroupMembers(config ldapConfig, dn string, add bool) (err error) {
	const groupDN = "CN=INFRAWIN-576,OU=AVS,OU=Services,DC=ld,DC=yandex,DC=ru"

	conn, dialErr := ldap.DialURL(config.Server)
	if dialErr != nil {
		return fmt.Errorf("ldapModifyGroupMembers(): dial URL: %w", dialErr)
	}
	defer conn.Close()

	if tlsErr := conn.StartTLS(&tls.Config{InsecureSkipVerify: true}); tlsErr != nil {
		return fmt.Errorf("ldapModifyGroupMembers(): start TLS: %w", tlsErr)
	}

	if bindErr := conn.Bind(config.UserDN, config.UserPassword); bindErr != nil {
		return fmt.Errorf("ldapModifyGroupMembers(): bind: %w", bindErr)
	}

	change := ldap.NewModifyRequest(groupDN, nil)
	memberValues := []string{dn}
	if add {
		change.Add("member", memberValues)
	} else {
		change.Delete("member", memberValues)
	}

	if modifyErr := conn.Modify(change); modifyErr != nil {
		return fmt.Errorf("ldapModifyGroupMembers(): %w", modifyErr)
	}

	return nil
}
