package taxidwh

import (
	"database/sql"
	"encoding/json"
	"fmt"
	"math"
	"time"

	"github.com/spf13/cobra"

	"a.yandex-team.ru/drive/analytics/gotasks"
	"a.yandex-team.ru/drive/library/go/clients/taxidwh"
	"a.yandex-team.ru/drive/library/go/gosql"
	"a.yandex-team.ru/drive/library/go/solomon"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/yt/go/ypath"
	"a.yandex-team.ru/yt/go/yt"
	"a.yandex-team.ru/zootopia/library/go/db"
	"a.yandex-team.ru/zootopia/library/go/db/events"
	"a.yandex-team.ru/zootopia/library/go/db/objects"
	"a.yandex-team.ru/zootopia/library/go/goyt"
)

// Commands provided by this package.
var (
	// TaxiDWHCmd is the "taxi-dwh" command.
	TaxiDWHCmd = cobra.Command{Use: "taxi-dwh"}

	// YTTaxiDWHCmd is the "taxi-dwh yt" subcommand.
	YTTaxiDWHCmd = cobra.Command{Use: "yt"}
)

// init declares the persistent flags of the "taxi-dwh" command and of its
// "yt" subcommand, then attaches "taxi-dwh" to gotasks.RootCmd and "yt" to
// "taxi-dwh".
func init() {
	// Flags of "taxi-dwh".
	dwhFlags := TaxiDWHCmd.PersistentFlags()
	dwhFlags.String("state-db", "analytics", "Name of database connection")
	dwhFlags.String("backend-db", "backend", "Name of database connection")
	dwhFlags.String("endpoint", "production", "URL of Taxi DWH endpoint")
	dwhFlags.String("source", "analytics", "TVM source")
	dwhFlags.String("target", "taxi-dwh", "TVM target")
	dwhFlags.Int("batch-size", 1000, "Maximal size of batch")
	dwhFlags.Int("min-batch-size", 50, "Minimal size of batch")
	dwhFlags.String("rule", "", "Rule name replacement")
	gotasks.RootCmd.AddCommand(&TaxiDWHCmd)

	// Flags of "taxi-dwh yt". Note that "rule" is declared here as well.
	ytFlags := YTTaxiDWHCmd.PersistentFlags()
	ytFlags.String("yt-proxy", "hahn", "YT proxy")
	ytFlags.String("yt-path", "", "Path to YT table")
	ytFlags.Int64(
		"yt-from-id", -1, "Specified ID from which we should upload rows",
	)
	ytFlags.Int64(
		"yt-to-id", -1, "Specified ID to which we should upload rows",
	)
	ytFlags.String("rule", "", "Rule name replacement")
	TaxiDWHCmd.AddCommand(&YTTaxiDWHCmd)
}

// getDialect maps the driver of conn to the matching db.Dialect.
//
// It panics when the driver is neither gosql.PostgresDriver nor
// gosql.SQLiteDriver.
func getDialect(conn *gosql.DB) db.Dialect {
	driver := conn.Driver
	if driver == gosql.PostgresDriver {
		return db.Postgres
	}
	if driver == gosql.SQLiteDriver {
		return db.SQLite
	}
	panic(fmt.Errorf("unsupported %q driver", conn.Driver))
}

// getBeginEventID reads begin_event_id from the "taxi_dwh_state" row whose
// "table" column equals table.
//
// The error of the row scan is returned unchanged.
func getBeginEventID(db *gosql.DB, table string) (int64, error) {
	selectState := db.Builder.Select("taxi_dwh_state").
		Names("begin_event_id").
		Where(gosql.Column("table").Equal(table))
	query, values := selectState.Build()
	var eventID int64
	err := db.QueryRow(query, values...).Scan(&eventID)
	if err != nil {
		return 0, err
	}
	return eventID, nil
}

// setBeginEventID stores eventID as begin_event_id in the "taxi_dwh_state"
// row whose "table" column equals table.
//
// An error is returned when the statement fails or when the number of
// affected rows is not exactly one.
func setBeginEventID(db *gosql.DB, table string, eventID int64) error {
	updateState := db.Builder.Update("taxi_dwh_state").
		Where(gosql.Column("table").Equal(table)).
		Names("begin_event_id").Values(eventID)
	query, values := updateState.Build()
	res, err := db.Exec(query, values...)
	if err != nil {
		return err
	}
	switch affected, err := res.RowsAffected(); {
	case err != nil:
		return err
	case affected != 1:
		return fmt.Errorf("unable to update event id for %q", table)
	}
	return nil
}

// deliveryImpl describes an event source that can be delivered to Taxi DWH
// either from the backend database or from a YT table.
type deliveryImpl interface {
	// RuleName returns the rule name that is used when the "rule" flag
	// is empty.
	RuleName() string
	// EventStore returns the store from which events are consumed for
	// the given database connection.
	EventStore(db *gosql.DB) events.ROStore
	// EventDocument converts an event into the document to upload.
	EventDocument(event events.Event) (taxidwh.Document, error)
	// ScanEvent reads an event from the YT table reader; the YT delivery
	// calls it once after each successful r.Next().
	ScanEvent(r goyt.TableReader) (events.Event, error)
}

// objectDeliveryImpl describes an object source that can be delivered to
// Taxi DWH from the backend database.
type objectDeliveryImpl interface {
	// RuleName returns the rule name that is used when the "rule" flag
	// is empty.
	RuleName() string
	// ObjectStore returns the store from which objects are read for the
	// given database connection.
	ObjectStore(db *gosql.DB) objects.ROStore
	// ObjectDocument converts an object into the document to upload.
	ObjectDocument(object objects.Object) (taxidwh.Document, error)
}

// deliveryMain returns a cobra run function that executes the event
// delivery task built from impl.
func deliveryMain(impl deliveryImpl) func(cmd *cobra.Command, args []string) {
	task := wrapDeliveryImpl(impl)
	return gotasks.WrapMain(task)
}

// objectDeliveryMain returns a cobra run function that executes the object
// delivery task built from impl.
func objectDeliveryMain(impl objectDeliveryImpl) func(cmd *cobra.Command, args []string) {
	task := wrapObjectDeliveryImpl(impl)
	return gotasks.WrapMain(task)
}

// ytDeliveryMain returns a cobra run function that executes the YT table
// delivery task built from impl.
func ytDeliveryMain(impl deliveryImpl) func(cmd *cobra.Command, args []string) {
	task := wrapYtDeliveryImpl(impl)
	return gotasks.WrapMain(task)
}

// wrapDeliveryImpl builds a task that uploads events from the backend
// database to Taxi DWH in batches.
//
// The task is configured by the command flags "state-db", "backend-db",
// "endpoint", "source", "target", "batch-size", "min-batch-size" and "rule"
// (impl.RuleName() is used when "rule" is empty). Progress is kept in the
// "taxi_dwh_state" table of the state database: events are consumed starting
// from the stored begin_event_id, and the stored value is advanced after
// every successfully uploaded batch.
//
// A trailing batch smaller than "min-batch-size" is not uploaded and the
// state is not advanced for it.
func wrapDeliveryImpl(impl deliveryImpl) func(*gotasks.Context) error {
	return func(ctx *gotasks.Context) error {
		stateDBName, err := ctx.Cmd.Flags().GetString("state-db")
		if err != nil {
			return err
		}
		backendDBName, err := ctx.Cmd.Flags().GetString("backend-db")
		if err != nil {
			return err
		}
		endpoint, err := ctx.Cmd.Flags().GetString("endpoint")
		if err != nil {
			return err
		}
		source, err := ctx.Cmd.Flags().GetString("source")
		if err != nil {
			return err
		}
		target, err := ctx.Cmd.Flags().GetString("target")
		if err != nil {
			return err
		}
		batchSize, err := ctx.Cmd.Flags().GetInt("batch-size")
		if err != nil {
			return err
		}
		minBatchSize, err := ctx.Cmd.Flags().GetInt("min-batch-size")
		if err != nil {
			return err
		}
		rule, err := ctx.Cmd.Flags().GetString("rule")
		if err != nil {
			return err
		}
		if rule == "" {
			rule = impl.RuleName()
		}
		stateDB, ok := ctx.DBs[stateDBName]
		if !ok {
			return fmt.Errorf("db %q does not exists", stateDBName)
		}
		backendDB, ok := ctx.DBs[backendDBName]
		if !ok {
			return fmt.Errorf("db %q does not exists", backendDBName)
		}
		beginEventID, err := getBeginEventID(stateDB, rule)
		if err != nil {
			return err
		}
		consumer := events.NewOrderedConsumer(
			impl.EventStore(backendDB), beginEventID,
		)
		tvm, err := ctx.GetServiceTicket(source, target)
		if err != nil {
			return err
		}
		client := taxidwh.NewClient(endpoint, tvm)
		// lastEventID is the ID of the last event appended to batch.
		var lastEventID int64
		var batch []taxidwh.Document
		// tryUploadBatch performs a single upload attempt of the current
		// batch and, on success, stores the new begin_event_id.
		tryUploadBatch := func() error {
			if len(batch) == 0 {
				// Nothing to upload. This also protects batch[0] below,
				// which would panic when "batch-size" or
				// "min-batch-size" is not positive.
				return nil
			}
			ctx.Logger.Debug(
				"Uploading documents",
				log.Int("batch_len", len(batch)),
				log.String("batch_first_id", batch[0].ID),
			)
			resp, err := client.AddDocuments(rule, batchFix(batch))
			if err != nil {
				ctx.Logger.Error(
					"Unable to add documents",
					log.Error(err),
				)
				ctx.SignalV(
					"taxi_dwh_error", 1,
					solomon.Label("type", "upload_error"),
					solomon.Label("rule", rule),
				)
				return err
			}
			for _, doc := range resp {
				if doc.Status == "warn" {
					// Warnings are reported but do not fail the batch.
					ctx.Logger.Error(
						"Add documents warn",
						log.Any("response", doc),
					)
					ctx.SignalV(
						"taxi_dwh_error", 1,
						solomon.Label("type", "upload_warn"),
						solomon.Label("rule", rule),
					)
				} else if doc.Status != "ok" {
					ctx.Logger.Error(
						"Add documents error",
						log.Any("response", doc),
					)
					ctx.SignalV(
						"taxi_dwh_error", 1,
						solomon.Label("type", "upload_error"),
						solomon.Label("rule", rule),
					)
					return fmt.Errorf("wrong status %q", doc.Status)
				}
			}
			beginEventID = lastEventID + 1
			// Update state with new beginEventID.
			if err := setBeginEventID(
				stateDB, rule, beginEventID,
			); err != nil {
				ctx.Logger.Error(
					"Unable to set begin_event_id",
					log.String("rule", rule),
					log.Error(err),
				)
				return err
			}
			ctx.SignalV(
				"taxi_dwh", len(batch),
				solomon.Label("type", "uploaded"),
				solomon.Label("rule", rule),
			)
			return nil
		}
		// uploadBatch uploads the current batch, retrying a failed attempt
		// up to three more times with delays of 0, 5 and 10 seconds. It
		// stops early with the last error when the context is done.
		uploadBatch := func() error {
			err := tryUploadBatch()
			for i := 0; i < 3 && err != nil; i++ {
				select {
				case <-time.After(5 * time.Duration(i) * time.Second):
				case <-ctx.Context.Done():
					return err
				}
				err = tryUploadBatch()
			}
			return err
		}
		if err := gosql.WithTxContext(ctx.Context, backendDB, nil, func(
			tx *sql.Tx,
		) error {
			if err := consumer.ConsumeEvents(
				tx, func(event events.Event) error {
					if len(batch) >= batchSize {
						if err := uploadBatch(); err != nil {
							return err
						}
						batch = nil
					}
					doc, err := impl.EventDocument(event)
					if err != nil {
						ctx.Logger.Error(
							"Unable to convert event",
							log.Any("event", event), log.Error(err),
						)
						ctx.SignalV(
							"taxi_dwh_error", 1,
							solomon.Label("type", "convert_error"),
							solomon.Label("rule", rule),
						)
						return err
					}
					// The event must be appended only after the previous
					// batch has been uploaded, because the consumer does
					// not apply the event when an error is returned.
					lastEventID = event.EventID()
					batch = append(batch, doc)
					return nil
				},
			); err != nil {
				return err
			}
			// The remainder is uploaded only when it is large enough.
			if len(batch) >= minBatchSize {
				if err := uploadBatch(); err != nil {
					return err
				}
				batch = nil
			}
			return nil
		}); err != nil && err != sql.ErrNoRows {
			return err
		}
		return nil
	}
}

// wrapObjectDeliveryImpl builds a task that uploads all objects read from
// the backend database to Taxi DWH in batches.
//
// The task is configured by the command flags "backend-db", "endpoint",
// "source", "target", "batch-size" and "rule" (impl.RuleName() is used when
// "rule" is empty). Unlike the event delivery, no state is stored: every
// run reads the objects returned by the object store again and any
// non-empty remainder is uploaded.
func wrapObjectDeliveryImpl(impl objectDeliveryImpl) func(*gotasks.Context) error {
	return func(ctx *gotasks.Context) error {
		backendDBName, err := ctx.Cmd.Flags().GetString("backend-db")
		if err != nil {
			return err
		}
		endpoint, err := ctx.Cmd.Flags().GetString("endpoint")
		if err != nil {
			return err
		}
		source, err := ctx.Cmd.Flags().GetString("source")
		if err != nil {
			return err
		}
		target, err := ctx.Cmd.Flags().GetString("target")
		if err != nil {
			return err
		}
		batchSize, err := ctx.Cmd.Flags().GetInt("batch-size")
		if err != nil {
			return err
		}
		rule, err := ctx.Cmd.Flags().GetString("rule")
		if err != nil {
			return err
		}
		if rule == "" {
			rule = impl.RuleName()
		}
		backendDB, ok := ctx.DBs[backendDBName]
		if !ok {
			return fmt.Errorf("db %q does not exists", backendDBName)
		}
		tvm, err := ctx.GetServiceTicket(source, target)
		if err != nil {
			return err
		}
		client := taxidwh.NewClient(endpoint, tvm)
		var batch []taxidwh.Document
		// tryUploadBatch performs a single upload attempt of the current
		// batch.
		tryUploadBatch := func() error {
			if len(batch) == 0 {
				// Nothing to upload. This also protects batch[0] below,
				// which would panic when "batch-size" is not positive.
				return nil
			}
			ctx.Logger.Debug(
				"Uploading documents",
				log.Int("batch_len", len(batch)),
				log.String("batch_first_id", batch[0].ID),
			)
			resp, err := client.AddDocuments(rule, batchFix(batch))
			if err != nil {
				ctx.Logger.Error(
					"Unable to add documents",
					log.Error(err),
				)
				return err
			}
			for _, doc := range resp {
				if doc.Status != "ok" {
					ctx.Logger.Error(
						"Add documents error",
						log.Any("response", doc),
					)
					return fmt.Errorf("wrong status %q", doc.Status)
				}
			}
			ctx.SignalV(
				"taxi_dwh", len(batch),
				solomon.Label("type", "uploaded"),
				solomon.Label("rule", rule),
			)
			return nil
		}
		// uploadBatch uploads the current batch, retrying a failed attempt
		// up to three more times with delays of 0, 5 and 10 seconds. It
		// stops early with the last error when the context is done.
		uploadBatch := func() error {
			err := tryUploadBatch()
			for i := 0; i < 3 && err != nil; i++ {
				select {
				case <-time.After(5 * time.Duration(i) * time.Second):
				case <-ctx.Context.Done():
					return err
				}
				err = tryUploadBatch()
			}
			return err
		}
		if err := gosql.WithTxContext(ctx.Context, backendDB, nil, func(
			tx *sql.Tx,
		) error {
			rows, err := impl.ObjectStore(backendDB).ReadObjects(tx)
			if err != nil {
				return err
			}
			defer func() {
				_ = rows.Close()
			}()
			for rows.Next() {
				object := rows.Object()
				// Flush the accumulated batch before adding the next
				// document once the size limit is reached.
				if len(batch) >= batchSize {
					if err := uploadBatch(); err != nil {
						return err
					}
					batch = nil
				}
				doc, err := impl.ObjectDocument(object)
				if err != nil {
					ctx.Logger.Error(
						"Unable to convert object",
						log.Any("object", object), log.Error(err),
					)
					ctx.SignalV(
						"taxi_dwh_error", 1,
						solomon.Label("type", "convert_error"),
						solomon.Label("rule", rule),
					)
					return err
				}
				batch = append(batch, doc)
			}
			if len(batch) > 0 {
				if err := uploadBatch(); err != nil {
					return err
				}
				batch = nil
			}
			return rows.Err()
		}); err != nil && err != sql.ErrNoRows {
			// Any failure of the transaction body is reported as an
			// upload error.
			ctx.SignalV(
				"taxi_dwh_error", 1,
				solomon.Label("type", "upload_error"),
				solomon.Label("rule", rule),
			)
			return err
		}
		return nil
	}
}

// wrapYtDeliveryImpl builds a task that uploads events read from a YT table
// to Taxi DWH in batches.
//
// The task is configured by the command flags "endpoint", "source",
// "target", "batch-size", "min-batch-size", "yt-proxy", "yt-path",
// "yt-from-id", "yt-to-id" and "rule" (impl.RuleName() is used when "rule"
// is empty). Events whose ID is below "yt-from-id" or above "yt-to-id" are
// skipped; a bound equal to -1 is disabled. No state is stored by this task.
//
// A trailing batch smaller than "min-batch-size" is not uploaded.
func wrapYtDeliveryImpl(impl deliveryImpl) func(*gotasks.Context) error {
	return func(ctx *gotasks.Context) error {
		endpoint, err := ctx.Cmd.Flags().GetString("endpoint")
		if err != nil {
			return err
		}
		source, err := ctx.Cmd.Flags().GetString("source")
		if err != nil {
			return err
		}
		target, err := ctx.Cmd.Flags().GetString("target")
		if err != nil {
			return err
		}
		batchSize, err := ctx.Cmd.Flags().GetInt("batch-size")
		if err != nil {
			return err
		}
		minBatchSize, err := ctx.Cmd.Flags().GetInt("min-batch-size")
		if err != nil {
			return err
		}
		ytProxy, err := ctx.Cmd.Flags().GetString("yt-proxy")
		if err != nil {
			return err
		}
		ytPath, err := ctx.Cmd.Flags().GetString("yt-path")
		if err != nil {
			return err
		}
		ytFromID, err := ctx.Cmd.Flags().GetInt64("yt-from-id")
		if err != nil {
			return err
		}
		ytToID, err := ctx.Cmd.Flags().GetInt64("yt-to-id")
		if err != nil {
			return err
		}
		rule, err := ctx.Cmd.Flags().GetString("rule")
		if err != nil {
			return err
		}
		if rule == "" {
			rule = impl.RuleName()
		}
		yc, ok := ctx.YTs[ytProxy]
		if !ok {
			return fmt.Errorf("yt %q does not exists", ytProxy)
		}
		gyt := goyt.New(yc)
		reader, err := gyt.Raw().ReadTable(
			ctx.Context, ypath.Path(ytPath),
			&yt.ReadTableOptions{Unordered: false},
		)
		if err != nil {
			return err
		}
		defer func() {
			_ = reader.Close()
		}()
		tvm, err := ctx.GetServiceTicket(source, target)
		if err != nil {
			return err
		}
		client := taxidwh.NewClient(endpoint, tvm)
		var batch []taxidwh.Document
		// tryUploadBatch performs a single upload attempt of the current
		// batch.
		tryUploadBatch := func() error {
			if len(batch) == 0 {
				// Nothing to upload. This also protects batch[0] below,
				// which would panic when "batch-size" or
				// "min-batch-size" is not positive.
				return nil
			}
			ctx.Logger.Info(
				"Uploading documents",
				log.Int("batch_len", len(batch)),
				log.String("batch_first_id", batch[0].ID),
			)
			resp, err := client.AddDocuments(rule, batchFix(batch))
			if err != nil {
				ctx.Logger.Error(
					"Unable to add documents",
					log.Error(err),
				)
				return err
			}
			for _, doc := range resp {
				if doc.Status != "ok" {
					ctx.Logger.Error(
						"Add documents error",
						log.Any("response", doc),
					)
					return fmt.Errorf("wrong status %q", doc.Status)
				}
			}
			return nil
		}
		// uploadBatch uploads the current batch, retrying a failed attempt
		// up to three more times with delays of 0, 5 and 10 seconds. It
		// stops early with the last error when the context is done.
		uploadBatch := func() error {
			err := tryUploadBatch()
			for i := 0; i < 3 && err != nil; i++ {
				select {
				case <-time.After(5 * time.Duration(i) * time.Second):
				case <-ctx.Context.Done():
					return err
				}
				err = tryUploadBatch()
			}
			return err
		}
		for reader.Next() {
			// Flush the accumulated batch before adding the next
			// document once the size limit is reached.
			if len(batch) >= batchSize {
				if err := uploadBatch(); err != nil {
					return err
				}
				batch = nil
			}
			event, err := impl.ScanEvent(reader)
			if err != nil {
				return err
			}
			// Skip events outside of the requested ID range.
			if ytFromID != -1 && event.EventID() < ytFromID {
				continue
			}
			if ytToID != -1 && event.EventID() > ytToID {
				continue
			}
			doc, err := impl.EventDocument(event)
			if err != nil {
				ctx.Logger.Error(
					"Unable to convert event",
					log.Any("event", event), log.Error(err),
				)
				return err
			}
			batch = append(batch, doc)
		}
		// The remainder is uploaded only when it is large enough.
		if len(batch) >= minBatchSize {
			if err := uploadBatch(); err != nil {
				return err
			}
			batch = nil
		}
		return reader.Err()
	}
}

// batchFix removes documents with duplicated IDs from batch.
//
// For every ID only its last occurrence in batch is kept; the relative
// order of the kept documents is preserved. The result is nil when batch
// is empty.
func batchFix(batch []taxidwh.Document) []taxidwh.Document {
	// lastIndex maps a document ID to the position of its last occurrence.
	lastIndex := make(map[string]int)
	for pos := range batch {
		lastIndex[batch[pos].ID] = pos
	}
	var unique []taxidwh.Document
	for pos := range batch {
		if lastIndex[batch[pos].ID] != pos {
			// A later document with the same ID exists.
			continue
		}
		unique = append(unique, batch[pos])
	}
	return unique
}

// mongoFix converts v to plain JSON-like values and replaces integers that
// do not fit into int64 with their decimal string representation.
//
// v is round-tripped through encoding/json, so the result consists only of
// nil, bool, string, float64, []interface{} and map[string]interface{}.
// A number whose float64 value is at least 2^63 is returned as a string.
// For plain integer literals within the uint64 range the exact digits are
// kept, so large uint64 values are not rounded through float64 first.
//
// mongoFix panics when v cannot be marshaled or when a number does not fit
// into float64.
func mongoFix(v interface{}) interface{} {
	data, err := json.Marshal(v)
	if err != nil {
		panic(err)
	}
	return mongoFixRaw(data)
}

// mongoFixRaw decodes a single compact JSON value and applies the number
// fix to every number found in it.
//
// Nested values are decoded level by level through json.RawMessage, which
// keeps the original number literals available without decoding them into
// float64.
func mongoFixRaw(raw json.RawMessage) interface{} {
	if len(raw) == 0 {
		return nil
	}
	// raw has no leading whitespace (it is produced by json.Marshal or cut
	// out by json.Unmarshal), so the first byte identifies the value kind.
	switch raw[0] {
	case '{':
		var fields map[string]json.RawMessage
		if err := json.Unmarshal(raw, &fields); err != nil {
			panic(err)
		}
		object := make(map[string]interface{}, len(fields))
		for key, field := range fields {
			object[key] = mongoFixRaw(field)
		}
		return object
	case '[':
		var items []json.RawMessage
		if err := json.Unmarshal(raw, &items); err != nil {
			panic(err)
		}
		array := make([]interface{}, len(items))
		for i, item := range items {
			array[i] = mongoFixRaw(item)
		}
		return array
	case '"', 't', 'f', 'n':
		// Strings, booleans and null need no fixing.
		var value interface{}
		if err := json.Unmarshal(raw, &value); err != nil {
			panic(err)
		}
		return value
	default:
		var num json.Number
		if err := json.Unmarshal(raw, &num); err != nil {
			panic(err)
		}
		return mongoFixNumber(num)
	}
}

// mongoFixNumber returns num as float64, or as a decimal string when its
// value is at least 2^63 and is an integer that fits into uint64.
func mongoFixNumber(num json.Number) interface{} {
	f, err := num.Float64()
	if err != nil {
		panic(err)
	}
	if f < float64(math.MaxInt64) {
		return f
	}
	// Prefer the exact literal: float64 cannot represent every integer of
	// this magnitude.
	if lit := num.String(); mongoFixIsUint64(lit) {
		return lit
	}
	// Literals such as "1e19" are still integers. The upper bound keeps
	// the conversion to uint64 in range, because the result of an
	// out-of-range conversion is implementation-dependent.
	if f < float64(math.MaxUint64) {
		if uv := uint64(f); float64(uv) == f {
			return fmt.Sprint(uv)
		}
	}
	return f
}

// mongoFixIsUint64 reports whether lit consists only of decimal digits and
// its value does not exceed the maximal uint64 value.
func mongoFixIsUint64(lit string) bool {
	const maxUint64 = "18446744073709551615"
	if lit == "" || len(lit) > len(maxUint64) {
		return false
	}
	for i := 0; i < len(lit); i++ {
		if lit[i] < '0' || lit[i] > '9' {
			return false
		}
	}
	// JSON integers have no leading zeros, so literals of equal length
	// compare numerically when compared as strings.
	return len(lit) < len(maxUint64) || lit <= maxUint64
}
