package main

import (
	"bytes"
	"encoding/json"
	"flag"
	"fmt"
	"io/ioutil"
	"log"
	"strings"
	"time"

	_ "github.com/go-sql-driver/mysql"
	"github.com/jmoiron/sqlx"

	s "a.yandex-team.ru/direct/infra/dt-db-manager/pkg/support"
	y "a.yandex-team.ru/direct/infra/dt-db-manager/pkg/yttransfer"
)

// usage is printed by PrintUsage before the flag defaults. It describes the
// tool (in Russian), the JSON format of -instances-file and shows two example
// invocations.
var usage = "Программа для обновления states у LogbrokerWriter. На вход требуется файл с указанием инстансов и новых gtid позиций.\n" +
	" \tФормат файла:\n [{\"host\":\"ppctest-mysql02e.da.yandex.ru\",\"instance\":\"ppcdata15\",\"port\":3355," +
	" \"group-name\":\"testing\",\"replica-name\":\"\"},{\"host\":\"ppctest-mysql02e.da.yandex.ru\", ... }]\n" +
	" \tПримеры команд:\n" +
	" \tbinlogwriter-update -account direct  -yt-token /etc/direct-tokens/yt_robot-direct-yt-test" +
	" -mysql-token /etc/direct-tokens/mysql_direct-test -mysql-user direct-test -instances-file ./replicas-plan.testing" +
	" -broker-type logbrokerwriter-json\n" +
	" \tbinlogwriter-update -account direct  -yt-token /etc/direct-tokens/yt_robot-direct-yt-test" +
	" -mysql-token /etc/direct-tokens/mysql_direct-test -mysql-user direct-test -instances-file ./replicas-plan.testing" +
	// The leading space separates the flag from the previous fragment.
	" -broker-type logbrokerwriter\n"

// PrintUsage prints the tool description and examples held in usage to
// stdout, followed by the defaults of every registered flag. main installs
// it as flag.Usage.
func PrintUsage() {
	header := usage
	fmt.Println(header)
	flag.PrintDefaults()
}

// MysqlConnect records one connection attempt made by readStatesMysql:
// the embedded ReplicasFormat describes the MySQL instance, Number is the
// instance's index in the instances file, Conn is the *sqlx.DB returned by
// sqlx.Connect and Error is the error it returned (Conn is presumably nil
// when Error is non-nil — sqlx behavior, confirm).
//
// Keep the field order stable: positional composite literals of this type
// depend on it.
type MysqlConnect struct {
	s.ReplicasFormat
	Number int
	Conn   *sqlx.DB
	Error  error
}

// MysqlConnects is the list of connection attempts made by readStatesMysql.
type MysqlConnects []MysqlConnect

// MasterStatus is one row of the result of "SHOW MASTER STATUS"; the db tags
// map the result columns for sqlx. In this file only ExecutedGtidSet is used.
type MasterStatus struct {
	File            string `db:"File"`
	Position        int    `db:"Position"`
	BinlogDoDB      string `db:"Binlog_Do_DB"`
	BinlogIgnoreDB  string `db:"Binlog_Ignore_DB"`
	ExecutedGtidSet string `db:"Executed_Gtid_Set"`
}

// Counters is the state of one binlog writer as stored in the "counters"
// attribute of its "-state" node in YT (see the yson tags). The struct has
// no json tags, so JSON state files use the Go field names.
type Counters struct {
	SeqNo     uint64 `yson:"seq_no"`
	Set       string `yson:"set"`
	Timestamp uint64 `yson:"hwm_timestamp"`
}

// GetGtid returns Set with every newline character removed, so a GTID set
// spread over several lines becomes a single line.
func (c Counters) GetGtid() (result string) {
	result = strings.Replace(c.Set, "\n", "", -1)
	return
}

// WriterStates maps a binlog-writer state node name to its counters. It is
// filled from YT by readStatesYT or from a JSON file by readStatesFile.
type WriterStates map[string]*Counters

// MysqlStates maps each MySQL instance to counters whose Set holds the
// Executed_Gtid_Set read from that instance.
type MysqlStates map[s.ReplicasFormat]Counters

// CounterForInstance returns the counters of the state entry belonging to
// instance, or nil when there is none.
//
// A key matches when its part after the first ':' and before the next '-'
// equals instance.ShortName(); keys are presumably shaped like
// "<prefix>:<short name>-state..." (confirm against YT node names). Keys
// without ':' are skipped; previously they caused an index-out-of-range
// panic. If several keys match, map iteration order decides which one wins.
func (ws WriterStates) CounterForInstance(instance s.ReplicasFormat) (result *Counters) {
	want := instance.ShortName()
	for key, value := range ws {
		colon := strings.Index(key, ":")
		if colon < 0 {
			continue
		}
		short := key[colon+1:]
		if dash := strings.Index(short, "-"); dash >= 0 {
			short = short[:dash]
		}
		if short == want {
			return value
		}
	}
	return nil
}

// readStatesFile loads writer states from the JSON file at path (a JSON
// object mapping state node names to counters, as written by
// writeStatesFile).
//
// The returned map is always non-nil; on error it may be empty or partially
// filled. Errors wrap the underlying I/O or JSON error and name the file
// rather than dumping its content.
func readStatesFile(path string) (states WriterStates, err error) {
	states = make(WriterStates)
	data, err := ioutil.ReadFile(path)
	if err != nil {
		return states, fmt.Errorf("error read file %s: %w", path, err)
	}
	if err := json.Unmarshal(data, &states); err != nil {
		return states, fmt.Errorf("error unmarshal %s: %w", path, err)
	}
	return states, nil
}

// readStatesYT reads the "counters" attribute of every node in ytdir whose
// name contains "-state" and returns them keyed by node name.
//
// A failure to read one attribute does not stop the scan. Such failures are
// collected, one per line, and returned as a single error together with the
// states that were read successfully. A listing failure returns an empty map
// and the wrapped error.
func readStatesYT(client y.YtConnect, ytdir string) (states WriterStates, err error) {
	states = make(WriterStates)
	ytfiles, err := client.ListYtNode(ytdir)
	if err != nil {
		return states, fmt.Errorf("error listing nodes %s: %w", ytdir, err)
	}

	var errmsg bytes.Buffer
	for _, name := range ytfiles {
		if !strings.Contains(name, "-state") {
			continue
		}
		ytfile := fmt.Sprintf("%s/%s", ytdir, name)
		// A fresh variable per node, so every map entry gets its own pointer.
		var value Counters
		if err := client.GetAttribute(ytfile, "counters", &value); err != nil {
			fmt.Fprintf(&errmsg, "error get attribute counters for %s: %s\n", ytfile, err)
			continue
		}
		states[name] = &value
	}
	if errmsg.Len() != 0 {
		// Constant format: '%' characters inside the messages must not be
		// interpreted as formatting verbs.
		return states, fmt.Errorf("%s", errmsg.String())
	}
	return states, nil
}

// writeStatesYT copies the GTID set of every instance in mysqlStates into the
// matching entry of ytStates (found with WriterStates.CounterForInstance). It
// then writes every ytStates entry whose name contains "-state" to the
// "counters" attribute of ytdir/<name>.
//
// ytStates is modified in place, and only Set is replaced. Instances without a
// matching state entry are reported on stdout and skipped. Attribute write
// failures do not stop the loop; they are collected, one per line, into the
// returned error.
func writeStatesYT(client y.YtConnect, ytdir string, ytStates WriterStates, mysqlStates MysqlStates) (err error) {
	for instance, mysqlGtid := range mysqlStates {
		ytGtid := ytStates.CounterForInstance(instance)
		if ytGtid == nil {
			fmt.Printf("not found mysql %s in binlogwriter states\n", instance.Instance)
			continue
		}
		fmt.Printf("instance %s, mysql gtid: %s, yt gtid: %s\n", instance.Instance, mysqlGtid.GetGtid(), ytGtid.GetGtid())
		ytGtid.Set = mysqlGtid.GetGtid()
	}

	var errmsg bytes.Buffer
	for name, attr := range ytStates {
		if !strings.Contains(name, "-state") {
			continue
		}
		ytFile := fmt.Sprintf("%s/%s", ytdir, name)
		if err := client.SetAttribute(ytFile, "counters", &attr); err != nil {
			fmt.Fprintf(&errmsg, "error set attribute counters for %s: %s\n", ytFile, err)
		}
	}
	if errmsg.Len() != 0 {
		return fmt.Errorf("%s", errmsg.String())
	}
	return nil
}

// writeStatesFile saves states as JSON to
// /tmp/<writerType>-<group>-<unix seconds>.states, in the format that
// readStatesFile reads back, and logs the chosen path. The file is created
// with mode 0664.
func writeStatesFile(states WriterStates, writerType, group string) error {
	var (
		payload []byte
		err     error
	)
	if payload, err = json.Marshal(states); err != nil {
		return fmt.Errorf("error marshal %v: %s", states, err)
	}

	target := fmt.Sprintf("/tmp/%s-%s-%d.states", writerType, group, time.Now().Unix())
	log.Printf("save old state to %s", target)

	if err = ioutil.WriteFile(target, payload, 0664); err != nil {
		return fmt.Errorf("error write states to %s: %s", target, err)
	}
	return nil
}

// readStatesMysql reads the MySQL instances listed in mysqlListFile (JSON
// decoded into s.ReplicasFormats). It connects to each instance as user, with
// the password from tokenFile and database dbname, and returns the
// Executed_Gtid_Set reported by SHOW MASTER STATUS on every instance.
//
// Instances whose name contains "sandbox", "ppclog" or "monitor" are skipped.
// Per-instance connection and query failures are collected, one per line. If
// any occurred, the result map is nil and the error lists all of them.
func readStatesMysql(mysqlListFile, tokenFile, user, dbname string) (mstate MysqlStates, err error) {
	mydata, err := ioutil.ReadFile(mysqlListFile)
	if err != nil {
		err = fmt.Errorf("error read file %s: %w", mysqlListFile, err)
		return
	}
	var instances s.ReplicasFormats
	if err = json.Unmarshal(mydata, &instances); err != nil {
		// Name the file instead of dumping its whole content into the error.
		err = fmt.Errorf("error unmarshal %s: %w", mysqlListFile, err)
		return
	}

	rawPassword, err := ioutil.ReadFile(tokenFile)
	if err != nil {
		err = fmt.Errorf("error read %s: %w", tokenFile, err)
		return
	}
	// A trailing newline in the token file must not become part of the
	// password inside the DSN.
	password := strings.TrimRight(string(rawPassword), "\r\n")

	if len(instances) == 0 {
		err = fmt.Errorf("empty list instance into file %s", mysqlListFile)
		return
	}

	var connects MysqlConnects
	var errmsg bytes.Buffer
	// Connections must stay open until the SHOW MASTER STATUS loop below has
	// finished; close all of them when the function returns.
	defer func() {
		for _, mc := range connects {
			if mc.Conn != nil {
				_ = mc.Conn.Close()
			}
		}
	}()

	for num, instance := range instances {
		if strings.Contains(instance.Instance, "sandbox") ||
			strings.Contains(instance.Instance, "ppclog") ||
			strings.Contains(instance.Instance, "monitor") {
			continue
		}
		dbConnString := fmt.Sprintf("%s:%s@(%s:%d)/%s", user, password, instance.Host, instance.Port, dbname)
		conn, connErr := sqlx.Connect("mysql", dbConnString)
		connects = append(connects, MysqlConnect{
			ReplicasFormat: instance,
			Number:         num,
			Conn:           conn,
			Error:          connErr,
		})
		if connErr != nil {
			fmt.Println(connErr)
			fmt.Fprintf(&errmsg, "error connect %s:%d: %s\n", instance.Host, instance.Port, connErr)
		}
	}

	mstate = make(MysqlStates)
	for _, mc := range connects {
		if mc.Error != nil {
			continue
		}
		var out []MasterStatus
		if qErr := mc.Conn.Select(&out, "SHOW MASTER STATUS"); qErr != nil {
			fmt.Fprintf(&errmsg, "error show master for interface %s: %s\n", mc.Instance, qErr)
			continue
		}
		// An empty result (presumably binary logging disabled) used to panic
		// on out[0]; report it as an error instead.
		if len(out) == 0 {
			fmt.Fprintf(&errmsg, "empty SHOW MASTER STATUS for instance %s\n", mc.Instance)
			continue
		}
		mstate[mc.ReplicasFormat] = Counters{Set: out[0].ExecutedGtidSet}
	}

	if errmsg.Len() > 0 {
		return nil, fmt.Errorf("%s", errmsg.String())
	}
	return mstate, nil
}

// main works in one of two modes, selected by -only-update-states.
//
// Update mode (-only-update-states):
//   - load the binlog-writer states from YT (or from -states-file when given);
//   - read the current Executed_Gtid_Set of every MySQL instance in
//     -instances-file;
//   - save the loaded states to a file in /tmp;
//   - write the states, with their GTID sets replaced by the MySQL ones, back
//     to YT.
//
// Default mode: if <yt-directory>/<broker-type>/<group> exists and is not
// empty, move it to "<dir>.old", first removing an existing "<dir>.old".
func main() {
	ytDirectory := flag.String("yt-directory", "//home/direct/test/binlogbroker", "путь куда привезти бекап")
	brokerType := flag.String("broker-type", "logbrokerwriter", "")
	group := flag.String("group", "testing", "")
	ytCluster := flag.String("cluster", "freud", "кластер YT для восстановления бекапов")
	ytAccount := flag.String("account", "direct", "аккаунт в YT")
	ytToken := flag.String("yt-token", "/tmp/pass1", "токен для подключения к YT")

	statesFile := flag.String("states-file", "", "")
	mysqlToken := flag.String("mysql-token", "/tmp/pass2", "токен для подключения к mysql")
	mysqlUser := flag.String("mysql-user", "default", "user для подключения к mysql")
	mysqlDBName := flag.String("mysql-db", "mysql", "dbname для подключения к mysql")
	instancesFile := flag.String("instances-file", "", "файл с инстансами для подключения")
	onlyUpdateStates := flag.Bool("only-update-states", false, "обновить GTID стейты, не удалять ноды")

	flag.Usage = PrintUsage
	flag.Parse()

	// Fail fast, before touching YT, when update mode lacks its required input.
	if *onlyUpdateStates && len(*instancesFile) == 0 {
		log.Fatal("-instances-file is required with -only-update-states")
	}

	ytWorkdir := fmt.Sprintf("%s/%s/%s", *ytDirectory, *brokerType, *group)
	ytBackupdir := ytWorkdir + ".old"

	ytreplicator := s.YtReplicator{
		SourceCluster:      "",
		SourceDir:          "",
		DestinationCluster: *ytCluster,
		DestinationDir:     ytWorkdir,
		Account:            *ytAccount,
		YTTokenFile:        *ytToken,
		WriterYTDir:        "",
		WriterYTCluster:    "",
		WriterHosts:        s.Servers{},
	}

	ytconnector, err := s.NewGroupYtReplicator(ytreplicator)
	if err != nil {
		log.Fatal(err)
	}

	client := ytconnector.DestinationClient

	if *onlyUpdateStates {
		// Load the current states from YT, or from a previously saved file.
		var ytStates WriterStates
		if len(*statesFile) == 0 {
			ytStates, err = readStatesYT(client, ytWorkdir)
		} else {
			ytStates, err = readStatesFile(*statesFile)
		}
		if err != nil {
			log.Fatalf("%s", err)
		}

		// Load the current GTID sets from MySQL.
		mysqlStates, err := readStatesMysql(*instancesFile, *mysqlToken, *mysqlUser, *mysqlDBName)
		if err != nil {
			log.Fatalf("error readStatesMysql: %s", err)
		}

		// Save the states as loaded, before they are modified, to
		// /tmp/<broker-type>-<group>-<unix time>.states.
		if err := writeStatesFile(ytStates, *brokerType, *group); err != nil {
			log.Fatalf("error writeStatesFile: %s", err)
		}

		// Replace the GTID sets with the MySQL ones and upload to YT.
		if err := writeStatesYT(client, ytWorkdir, ytStates, mysqlStates); err != nil {
			log.Fatalf("error writeStatesYT: %s", err)
		}
		return
	}

	// Default mode: rotate the working directory into the .old backup.
	workdirExists, err := client.NodeYtExists(ytWorkdir)
	if err != nil {
		log.Fatalf("error check existence of %s: %s", ytWorkdir, err)
	}
	if !workdirExists {
		log.Printf("directory %s does not exist. Nothing to move", ytWorkdir)
		return
	}

	currentNodes, err := client.RecurseListYtNode(ytWorkdir)
	if err != nil {
		log.Fatalf("error recurse list current dir %s: %s", ytWorkdir, err)
	}
	if len(currentNodes) == 0 {
		log.Printf("current directory %s are empty. Skip move it!", ytWorkdir)
		return
	}

	backupExists, err := client.NodeYtExists(ytBackupdir)
	if err != nil {
		log.Fatalf("error check existence of %s: %s", ytBackupdir, err)
	}
	if backupExists {
		if err := client.RemoveYtNode(ytBackupdir, false, true); err != nil {
			log.Fatalf("error delete dir %s: %s", ytBackupdir, err)
		}
		log.Printf("success remove %s", ytBackupdir)
	}
	if _, err := client.MoveYtNode(ytWorkdir, ytBackupdir, false, true); err != nil {
		log.Fatalf("error move dir %s --> %s: %s", ytWorkdir, ytBackupdir, err)
	}
	log.Printf("success move %s --> %s", ytWorkdir, ytBackupdir)
}
