package main

import (
	"bytes"
	"context"
	"database/sql"
	"flag"
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"strings"
	"time"

	_ "github.com/go-sql-driver/mysql"

	msql "a.yandex-team.ru/cloud/bitbucket/public-api/yandex/cloud/mdb/mysql/v1"
	"a.yandex-team.ru/direct/infra/go-libs/pkg/dbconfig"
	logger "a.yandex-team.ru/direct/infra/go-libs/pkg/logformat"
	mdb "a.yandex-team.ru/direct/infra/go-libs/pkg/mdbgrpc"
	"a.yandex-team.ru/direct/infra/go-libs/pkg/zklib"
)

// Package-level configuration and state, populated from command-line flags,
// the MDB API and getDatabaseConfigPath.
var (
	// DefaultToken is the default value of the -token flag: the key file
	// handed to mdb.LoadCloudKey.
	DefaultToken = "/etc/direct-tokens/mdb_robot-direct-admin-manager"

	// logLevel is switched to logger.DebugLvl when -debug is given.
	logLevel = logger.InfoLvl

	// Values bound to command-line flags, plus the master host discovered at runtime.
	newMasterHost     string
	currentMasterHost string
	clusterName       string
	mdbToken          string
	debugMode         bool
	killedMode        bool

	// ZooKeeper connection settings, chosen per environment.
	zkServers   string
	zkTokenFile string
)

// WrapFatal logs msg, a printf-style format expanded with args, at critical
// level and then terminates the process with exit status 1. Deferred
// functions of the caller do not run.
func WrapFatal(msg string, args ...interface{}) {
	const failureExitCode = 1
	logger.Crit(msg, args...)
	os.Exit(failureExitCode)
}

// getDatabaseConfigPath maps an MDB cluster name to the ZooKeeper path of its
// db-config. The environment is picked by the first substring marker found in
// mdbShardName, checked in order: "-production", "-testing", "-dev7",
// "-devtest".
//
// On a match it also points the package-level zkServers at the ppc-zk
// ensemble and, for production only, sets zkTokenFile. A name matching no
// marker yields "" and leaves both globals untouched.
func getDatabaseConfigPath(mdbShardName string) (path string) {
	const ppcZkServers = "ppc-zk-1.da.yandex.ru:2181,ppc-zk-2.da.yandex.ru:2181,ppc-zk-3.da.yandex.ru:2181"

	environments := [...]struct {
		marker    string
		config    string
		tokenFile string
	}{
		{marker: "-production", config: "/etc/yandex-direct/db-config.json", tokenFile: "/etc/direct-tokens/zookeeper_direct-prod-rw"},
		{marker: "-testing", config: "/direct/np/db-config/db-config.test.json"},
		{marker: "-dev7", config: "/direct/np/db-config/db-config.dev7.json"},
		{marker: "-devtest", config: "/direct/np/db-config/db-config.devtest.json"},
	}

	for _, env := range environments {
		if !strings.Contains(mdbShardName, env.marker) {
			continue
		}
		// Only production needs a token to read its ZooKeeper node.
		if env.tokenFile != "" {
			zkTokenFile = env.tokenFile
		}
		zkServers = ppcZkServers
		return env.config
	}
	return ""
}

// MysqlProcessList holds one row of SHOW PROCESSLIST output. The fields are
// scanned positionally, so their order mirrors the statement's column order
// (Id, User, Host, db, Command, Time, State, Info, then the rows-sent and
// rows-examined counters).
type MysqlProcessList struct {
	ID                      int
	User, Host, DB, Command string
	Time                    int
	State                   string
	Info                    []byte // raw Info column; a []byte so a NULL value scans as nil
	RowsSent, RowsExamined  int
}

// ConnectionID holds the value returned by MySQL's CONNECTION_ID().
type ConnectionID struct{ ID int }

// KillOldQuery kills the MySQL connections on mysqlHost that belong to the
// database user configured for clusterName in db-config.
//
// db-config is read from ZooKeeper at the path chosen by
// getDatabaseConfigPath(clusterName). The db-config instance name is the part
// of clusterName before the first "-". That instance's "user", "port" and
// "pass" params are used to connect. "pass" is either the password itself or
// an object whose "file" key names a file holding it; one trailing newline is
// stripped from that file.
//
// Every SHOW PROCESSLIST entry owned by that user (compared
// case-insensitively) receives KILL <id>, except the connection this function
// itself runs on. Failures of individual KILLs are only logged.
//
// It returns (true, nil) once the KILLs have been attempted. It returns
// (false, err) if the config cannot be loaded or the process list cannot be
// read.
func KillOldQuery(mysqlHost string, clusterName string) (bool, error) {
	if mysqlHost == "" {
		return false, fmt.Errorf("empty mysql host for cluster %s", clusterName)
	}
	zkPath := getDatabaseConfigPath(clusterName)
	if zkPath == "" {
		// getDatabaseConfigPath only knows a fixed set of environment markers;
		// without one there is no db-config to read.
		return false, fmt.Errorf("unknown environment of cluster %s, no db-config path", clusterName)
	}
	zkNode, err := ZkLoadDatabaseConfig(zkPath)
	if err != nil {
		return false, err
	}
	if zkNode.Data == nil {
		return false, fmt.Errorf("empty zookeeper node %s", zkPath)
	}
	dbcnf := dbconfig.NewDBConfig()
	instance := strings.Split(clusterName, "-")[0]
	if err := dbcnf.LoadDBConfig(*zkNode.Data); err != nil {
		return false, err
	}

	mysqlUser, ok := dbcnf.GetParamForInstance(instance, "user")
	if !ok {
		return false, fmt.Errorf("not found user for instance %s", instance)
	}
	mysqlPort, ok := dbcnf.GetParamForInstance(instance, "port")
	if !ok {
		return false, fmt.Errorf("not found port for instance %s", instance)
	}

	var mysqlPasswd string
	if p, ok := dbcnf.GetParamForInstance(instance, "pass"); ok {
		switch v := p.(type) {
		case map[string]interface{}:
			passPath, ok := v["file"]
			if !ok {
				return false, fmt.Errorf("not found pass/file for instance %s", instance)
			}
			raw, err := ioutil.ReadFile(fmt.Sprintf("%s", passPath))
			if err != nil {
				return false, fmt.Errorf("error read %s, error %s", passPath, err)
			}
			mysqlPasswd = string(bytes.TrimSuffix(raw, []byte("\n")))
		default:
			mysqlPasswd = fmt.Sprintf("%s", v)
		}
	}

	user := fmt.Sprint(mysqlUser)
	// %v prints a JSON-decoded float64 port such as 3306 as "3306", and also
	// works if the port arrives as an integer.
	addr := fmt.Sprintf("%s:%v", mysqlHost, mysqlPort)
	dsn := fmt.Sprintf("%s:%s@tcp(%s)/mysql", user, mysqlPasswd, addr)
	dsnFmtLog := fmt.Sprintf("%s:***@tcp(%s)/mysql", user, addr)

	db, err := sql.Open("mysql", dsn)
	if err != nil {
		return false, fmt.Errorf("error connect %s: %s", dsnFmtLog, err)
	}
	defer func() { _ = db.Close() }()

	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	// Pin one physical connection. CONNECTION_ID() identifies only the
	// connection it runs on, so excluding it is meaningful only if the process
	// list and the KILLs go through that same connection, not an arbitrary one
	// from the pool.
	conn, err := db.Conn(ctx)
	if err != nil {
		return false, fmt.Errorf("error connect %s: %s", dsnFmtLog, err)
	}
	defer func() { _ = conn.Close() }()

	var connid ConnectionID
	if err := conn.QueryRowContext(ctx, "SELECT CONNECTION_ID() as ID").Scan(&connid.ID); err != nil {
		// Without our own id we could KILL this very connection mid-run.
		return false, fmt.Errorf("error get CONNECTION_ID() at %s: %s", dsnFmtLog, err)
	}
	logger.Debug("current connection_id() = %d", connid.ID)

	rows, err := conn.QueryContext(ctx, "SHOW PROCESSLIST")
	if err != nil {
		return false, fmt.Errorf("error execute %s at %s, %s", "SHOW PROCESSLIST", dsnFmtLog, err)
	}
	defer func() { _ = rows.Close() }()

	var victims []MysqlProcessList
	for rows.Next() {
		var pl MysqlProcessList
		// The db column (and possibly State) can be NULL, e.g. for a session
		// with no default schema. Scanning NULL into a plain string fails and
		// would let such connections escape the KILL.
		var dbName, state sql.NullString
		if err := rows.Scan(&pl.ID, &pl.User, &pl.Host, &dbName, &pl.Command, &pl.Time,
			&state, &pl.Info, &pl.RowsSent, &pl.RowsExamined); err != nil {
			logger.Warn("error scan row at %s: %s", addr, err)
			continue
		}
		pl.DB, pl.State = dbName.String, state.String

		switch {
		case pl.ID == connid.ID:
			logger.Debug("skip process %d, because it is the current connection", pl.ID)
		case !strings.EqualFold(pl.User, user):
			logger.Debug("skip process %d, because user %s != %s", pl.ID, pl.User, user)
		default:
			victims = append(victims, pl)
		}
	}
	if err := rows.Err(); err != nil {
		return false, fmt.Errorf("error read %s at %s, %s", "SHOW PROCESSLIST", dsnFmtLog, err)
	}
	// Release the result set so the pinned connection is free for the KILLs.
	_ = rows.Close()

	logger.Info("found %d queries user %s for killed", len(victims), user)
	if len(victims) == 0 {
		return true, nil
	}
	for _, pl := range victims {
		command := fmt.Sprintf("KILL %d", pl.ID)
		func() {
			killCtx, killCancel := context.WithTimeout(context.Background(), 60*time.Second)
			defer killCancel()
			if _, err := conn.ExecContext(killCtx, command); err != nil {
				logger.Warn("error %s, %s", command, err)
			}
		}()
	}
	logger.Info("queries killed")
	return true, nil
}

// ZkLoadDatabaseConfig loads the ZooKeeper node at zkPath.
//
// The servers come from the package-level zkServers, a comma-separated list.
// When zkTokenFile is non-empty, its contents with one trailing newline
// removed are used as the ZooKeeper token. A token file that cannot be read is
// reported as an error rather than silently falling back to the tokenless
// address.
//
// On error the returned node is the zero ZkNode.
func ZkLoadDatabaseConfig(zkPath string) (zkNode zklib.ZkNode, err error) {
	if zkServers == "" {
		// strings.Split("", ",") would yield a single empty server address.
		return zklib.ZkNode{}, fmt.Errorf("no zookeeper servers configured to load %s", zkPath)
	}
	servers := strings.Split(zkServers, ",")
	zkAddress := zklib.NewZkAddress("", "", servers...)

	if len(zkTokenFile) > 0 {
		rawToken, readErr := ioutil.ReadFile(zkTokenFile)
		if readErr != nil {
			return zklib.ZkNode{}, fmt.Errorf("error read %s: %s", zkTokenFile, readErr)
		}
		zkToken := string(bytes.TrimSuffix(rawToken, []byte("\n")))
		zkAddress = zklib.NewZkAddressWithToken(zkToken, servers...)
	}

	zkconn, err := zklib.NewZkConnect(zkAddress)
	if err != nil {
		return zklib.ZkNode{}, fmt.Errorf("error connect to %s, error %s", zkServers, err)
	}
	defer zkconn.Close()

	if zkNode, err = zkconn.LoadNode(zklib.NewZkNode(zkPath)); err != nil {
		return zklib.ZkNode{}, err
	}
	return zkNode, nil
}

// main switches the master of the MDB MySQL cluster named by -cluster-name.
//
// It gets an IAM token from the key given by -token, then looks up the cluster
// and its hosts through the MDB API. It switches the master to -new-master,
// or to a host picked interactively by number on stdin. It polls the switch
// operation every 5s for up to 120s. With -kill, it then kills the configured
// user's queries on the previous master.
func main() {
	flag.StringVar(&newMasterHost, "new-master", "", "new master host")
	flag.StringVar(&clusterName, "cluster-name", "", "mdb cluster name")
	flag.StringVar(&mdbToken, "token", DefaultToken, "token for connect to mdb-api")
	flag.BoolVar(&debugMode, "debug", false, "debug mode")
	flag.BoolVar(&killedMode, "kill", false, "killed old queries")
	flag.Parse()

	if debugMode {
		logLevel = logger.DebugLvl
	}

	logDirectFormat := logger.DirectMessageFormat("direct.back", "mysql-change-master")

	newlogger := log.New(os.Stdout, "", 0)
	logger.NewLoggerFormat(
		logger.WithLogger(newlogger),
		logger.WithLogLevel(logLevel),
		logger.WithLogFormat(logDirectFormat),
	)

	if len(clusterName) == 0 {
		WrapFatal("empty cluster name, use -cluster-name <mdb_mysql_name>")
	}

	conf, err := mdb.LoadCloudKey(mdbToken)
	if err != nil {
		WrapFatal("error load %s: %s", mdbToken, err)
	}

	tokenGenerator, err := mdb.NewIamTokenGenerator(conf)
	if err != nil {
		logger.Warn("error NewIamTokenGenerator: %s", err)
	}
	// Nothing below can run without an AimTokenGenerator. Fail loudly instead
	// of exiting successfully having done nothing.
	gen, ok := tokenGenerator.(mdb.AimTokenGenerator)
	if !ok {
		WrapFatal("unsupported IAM token generator %T", tokenGenerator)
	}

	// Check the error before deferring cancel: cancel may not be usable when
	// NewDealer fails.
	conn, cancel, err := mdb.NewDealer("", mdb.AIM{})
	if err != nil {
		WrapFatal("%s", err)
	}
	defer cancel()

	req, reqErr := gen.IAMTokenRequest()
	logger.Debug("jreq: %+v error: %v\n", req, reqErr)
	token, err := gen.NewIamToken(conn)
	if err != nil {
		WrapFatal("%s", err)
	}
	logger.Debug("expire: %s, string: %s\n", token.GetExpiresAt(), token.String())

	// Reconnect with the IAM token, keeping its cancel func so it is released too.
	authConn, authCancel, err := mdb.NewDealer("", mdb.NewAIM(token))
	if err != nil {
		WrapFatal("%s", err)
	}
	defer authCancel()
	conn = authConn

	clusters, err := mdb.ListClusters(conn, "")
	if err != nil {
		WrapFatal("error ListClusters: %s", err)
	}

	cluster, err := clusters.FindByClusterName(clusterName)
	if err != nil {
		WrapFatal("%s", err)
	}

	// Hosts are always listed, even when -new-master is given, because the
	// current master is needed for logging and for -kill.
	hosts, err := mdb.ListHosts(conn, cluster.Id)
	if err != nil {
		WrapFatal("%s", err)
	}
	clusterHosts := make(map[int]*msql.Host)
	var hostCount int
	for _, host := range hosts.Hosts {
		if host == nil {
			continue
		}
		hostCount++
		clusterHosts[hostCount] = host
		if strings.EqualFold(host.GetRole().String(), "MASTER") {
			currentMasterHost = host.GetName()
		}
	}
	if len(clusterHosts) == 0 {
		WrapFatal("empty list hosts for cluster %s", cluster.Id)
	}

	if len(newMasterHost) == 0 {
		fmt.Println("choose number host to master switch")
		for i := 1; i <= len(clusterHosts); i++ {
			host := clusterHosts[i]
			fmt.Printf("%d: %s (role: %s, alive: %s)\n", i, host.GetName(), host.GetRole(), host.GetHealth())
		}
		// Read into a fresh variable. Reusing the host counter would make
		// unparsable input silently select the last host.
		var choice int
		if _, err := fmt.Fscan(os.Stdin, &choice); err != nil {
			WrapFatal("error read host number: %s", err)
		}
		chosen, found := clusterHosts[choice]
		if !found {
			WrapFatal("wrong number host")
		}
		newMasterHost = chosen.GetName()
		if strings.EqualFold(chosen.GetRole().String(), "MASTER") {
			logger.Info("current master are %s, skip switch", newMasterHost)
			os.Exit(0)
		}
		if status := chosen.GetHealth().String(); !strings.EqualFold(status, "ALIVE") {
			logger.Warn("new master %s not alive: %s, skip switch", newMasterHost, status)
			os.Exit(0)
		}
	} else if strings.EqualFold(newMasterHost, currentMasterHost) {
		logger.Info("current master are %s, skip switch", newMasterHost)
		os.Exit(0)
	}

	logger.Info("start switch master to %s -> %s", currentMasterHost, newMasterHost)

	operation, err := mdb.SwitchMaster(conn, cluster.Id, newMasterHost)
	if err != nil {
		WrapFatal("%s", err)
	}

	const (
		pollInterval  = 5 * time.Second
		switchTimeout = 120 * time.Second
	)
	ticker := time.NewTicker(pollInterval)
	defer ticker.Stop()
	deadline := time.After(switchTimeout)
	// Wait for the deadline or the next tick in one select. Stopping the ticker
	// does not close its channel, so ranging over ticker.C after Stop would
	// block forever.
waitLoop:
	for {
		select {
		case <-deadline:
			WrapFatal("operation %s not finished in %s", operation.GetId(), switchTimeout)
		case <-ticker.C:
		}
		done, opErr := mdb.GetOperationStatus(conn, operation)
		switch {
		case done:
			logger.Info("operation %s done", operation.GetId())
			break waitLoop
		case opErr == nil:
			logger.Info("operation %s progress", operation.GetId())
		default:
			logger.Warn("operation %s failed, %s", operation.GetId(), opErr)
			os.Exit(1)
		}
	}

	if killedMode {
		if currentMasterHost == "" {
			logger.Warn("current master of %s unknown, skip killing old queries", clusterName)
		} else if ok, err := KillOldQuery(currentMasterHost, clusterName); !ok && err != nil {
			logger.Crit("error KillOldQuery, %s", err)
		}
	}
}
