package exports

import (
	"context"
	"database/sql"
	"fmt"
	"reflect"

	"a.yandex-team.ru/drive/analytics/gotasks"
	"a.yandex-team.ru/drive/library/go/gosql"
	"a.yandex-team.ru/yt/go/mapreduce"
	"a.yandex-team.ru/yt/go/mapreduce/spec"
	"a.yandex-team.ru/yt/go/schema"
	"a.yandex-team.ru/yt/go/ypath"
	"a.yandex-team.ru/yt/go/yt"
	"a.yandex-team.ru/yt/go/yterrors"
	"a.yandex-team.ru/zootopia/library/go/db"
	"a.yandex-team.ru/zootopia/library/go/db/events"
	"a.yandex-team.ru/zootopia/library/go/db/objects"
	"a.yandex-team.ru/zootopia/library/go/goyt"
)

// State describes export progress. It is stored in the stateAttr attribute
// of the exported YT table and serialized with the field names given in
// the tags below.
type State struct {
	// BeginEventID is the ID of the next event to export: writers in this
	// file set it to the last exported event ID plus one.
	BeginEventID int64 `json:"begin_event_id" yson:"begin_event_id"`
	// MaxEventTime is EventTime().Unix() of the last exported event.
	MaxEventTime int64 `json:"max_event_time" yson:"max_event_time"`
}

const (
	// stateAttr is the name of the YT table attribute that holds State.
	stateAttr = "_export_state"
	// batchSize is the default value of Exporter.BatchSize.
	batchSize = 200000
	// minBatchSize is the default value of Exporter.MinBatchSize.
	minBatchSize = 200
)

// Exporter incrementally copies events from a database table into a YT
// table and tracks progress in the stateAttr attribute of that table.
//
// Unexported fields:
//   - event:  event prototype; used for schema inference, as the reducer row
//     type and as the type of rows read back from YT.
//   - idCol:  name of the event ID column; the YT table is sorted by it.
//   - store:  event store built by events.NewStore for the database table.
//   - yt:     YT client used to start transactions and operations.
//   - table:  path of the target YT table.
//   - db:     database connection passed to the events consumer.
//   - daily:  YT directory whose children are used as input tables when the
//     export is initialized.
//   - mapper: map job applied to those input tables during initialization;
//     when it is nil (or daily is empty) an empty table is created instead.
type Exporter struct {
	event  events.Event
	idCol  string
	store  events.Store
	yt     yt.Client
	table  ypath.Path
	db     *gosql.DB
	daily  ypath.Path
	mapper mapreduce.Job
	// BatchSize contains default size of batch.
	BatchSize int
	// MinBatchSize contains minimal size of batch.
	MinBatchSize int
}

// NewExporter creates an Exporter that copies events of the given prototype
// from dbTable (read through dbConn) into ytTable.
//
// idCol names the event ID column. dailyDir and dailyMapper describe the
// source used to build the YT table when it has no export state yet.
// BatchSize and MinBatchSize are initialized with the package defaults and
// may be overridden on the returned value.
func NewExporter(
	event events.Event, idCol string,
	yc yt.Client, ytTable ypath.YPath,
	dbConn *gosql.DB, dbTable string,
	dailyDir ypath.YPath, dailyMapper mapreduce.Job,
) *Exporter {
	exporter := &Exporter{
		BatchSize:    batchSize,
		MinBatchSize: minBatchSize,
	}
	exporter.event, exporter.idCol = event, idCol
	exporter.store = events.NewStore(event, idCol, dbTable, db.Postgres)
	exporter.yt = yc
	exporter.table = ytTable.YPath()
	exporter.db = dbConn
	exporter.daily = dailyDir.YPath()
	exporter.mapper = dailyMapper
	return exporter
}

// Export runs one export iteration inside a single YT transaction.
//
// The transaction is committed when the export succeeds (the commit error,
// if any, becomes the result) and aborted otherwise.
func (e *Exporter) Export(ctx *gotasks.Context) (errMain error) {
	consumer, err := e.getConsumer()
	if err != nil {
		return err
	}
	tx, err := e.yt.BeginTx(ctx.Context, nil)
	if err != nil {
		return err
	}
	defer func() {
		if errMain != nil {
			// The export already failed; the abort error adds nothing.
			_ = tx.Abort()
			return
		}
		errMain = tx.Commit()
	}()
	// Buffer of consumed events that are not uploaded yet.
	pending := new([]events.Event)
	return e.exportTx(tx, consumer, pending)
}

// getConsumer returns an ordered consumer positioned at the BeginEventID
// stored in the export state of the YT table.
//
// When the state attribute cannot be resolved, the export is initialized
// first and the state is loaded again. Any other read error is returned.
func (e *Exporter) getConsumer() (events.Consumer, error) {
	statePath := e.table.YPath().Attr(stateAttr)
	var state State
	err := e.yt.GetNode(context.TODO(), statePath, &state, nil)
	switch {
	case err == nil:
		// State already exists, nothing to prepare.
	case !yterrors.ContainsErrorCode(err, yterrors.CodeResolveError):
		// Only a resolve error means "no state yet"; everything else is a
		// real failure.
		return nil, err
	default:
		if err := e.initExport(); err != nil {
			return nil, err
		}
		// Initialization writes the state, so it has to be reloaded.
		if err := e.yt.GetNode(context.TODO(), statePath, &state, nil); err != nil {
			return nil, err
		}
	}
	return events.NewOrderedConsumer(e.store, state.BeginEventID), nil
}

// uniqueReducer is a reduce job that keeps one row per key group: its
// reduceGroup method writes only the first row of every group.
type uniqueReducer struct {
	// Event is the row prototype reported by InputTypes and OutputTypes.
	Event events.Event
}

// InputTypes reports the row type of the single input: the Event prototype.
func (r uniqueReducer) InputTypes() []interface{} {
	rowType := r.Event
	return []interface{}{rowType}
}

// OutputTypes reports the row type of the single output: the Event
// prototype.
func (r uniqueReducer) OutputTypes() []interface{} {
	rowType := r.Event
	return []interface{}{rowType}
}

// Do splits the input into key groups and reduces each of them with
// reduceGroup.
func (r *uniqueReducer) Do(
	ctx mapreduce.JobContext, in mapreduce.Reader, out []mapreduce.Writer,
) error {
	perGroup := func(group mapreduce.Reader) error {
		return r.reduceGroup(group, out)
	}
	return mapreduce.GroupKeys(in, perGroup)
}

// reduceGroup writes the first row of the group to the first output and
// drops the remaining rows. An empty group is reported as an error.
func (r *uniqueReducer) reduceGroup(
	in mapreduce.Reader, out []mapreduce.Writer,
) error {
	if hasRow := in.Next(); !hasRow {
		return fmt.Errorf("empty input")
	}
	var first interface{}
	if err := in.Scan(&first); err != nil {
		return err
	}
	// Drain the rest of the group; only the first row is kept.
	for more := in.Next(); more; more = in.Next() {
	}
	return out[0].Write(first)
}

// initExport prepares the YT table and its export state when no state
// exists yet. All steps run in one YT transaction that is committed on
// success and aborted on any error.
//
// Without a mapper or a daily directory the table is created (if missing)
// with the inferred schema and an empty State, then sorted by idCol.
// Otherwise every child of the daily directory is fed through the mapper
// and uniqueReducer into the table, the table is sorted by idCol and the
// state is derived from its last row.
func (e *Exporter) initExport() (errMain error) {
	outSchema, err := schema.Infer(e.event)
	if err != nil {
		return err
	}
	tx, err := e.yt.BeginTx(context.TODO(), nil)
	if err != nil {
		return err
	}
	defer func() {
		// Commit only when everything succeeded; the commit error becomes
		// the function result.
		if errMain == nil {
			errMain = tx.Commit()
			return
		}
		// Abort is best-effort: the original error is what gets returned.
		_ = tx.Abort()
	}()
	mr := mapreduce.New(e.yt).WithTx(tx)
	// Sort reads from and writes to the same table, ordering it by idCol.
	sortSpec := spec.Spec{
		InputTablePaths: []ypath.YPath{e.table},
		OutputTablePath: e.table,
		SortBy:          []string{e.idCol},
	}
	// No source of historical data: start from an empty table.
	if e.mapper == nil || len(e.daily.String()) == 0 {
		if _, err := tx.CreateNode(
			context.TODO(), e.table.Rich().SetSchema(outSchema),
			yt.NodeTable, &yt.CreateNodeOptions{
				IgnoreExisting: true,
				Attributes: map[string]interface{}{
					"schema": outSchema, stateAttr: State{},
				},
			},
		); err != nil {
			return err
		}
		sortOp, err := mr.Sort(sortSpec.Sort())
		if err != nil {
			return err
		}
		return sortOp.Wait()
	}
	// Every child of the daily directory becomes an input table.
	nodes, err := goyt.ListDir(context.TODO(), tx, e.daily)
	if err != nil {
		return err
	}
	var tables []ypath.YPath
	for _, node := range nodes {
		tables = append(tables, e.daily.Child(node.Name))
	}
	// Rows are grouped by idCol so that uniqueReducer can keep one row per
	// event ID.
	mrSpec := spec.Spec{
		InputTablePaths: tables,
		OutputTablePaths: []ypath.YPath{
			e.table.Rich().SetSchema(outSchema),
		},
		SortBy:   []string{e.idCol},
		ReduceBy: []string{e.idCol},
		Pool:     "carsharing",
	}
	reducer := &uniqueReducer{Event: e.event}
	mrOp, err := mr.MapReduce(
		e.mapper, reducer, mrSpec.MapReduce(),
	)
	if err != nil {
		return err
	}
	if err := mrOp.Wait(); err != nil {
		return err
	}
	sortOp, err := mr.Sort(sortSpec.Sort())
	if err != nil {
		return err
	}
	if err := sortOp.Wait(); err != nil {
		return err
	}
	// The table is sorted by idCol, so its last row is used as the last
	// exported event.
	lastEvent, err := e.readLastEvent(tx)
	if err != nil {
		return err
	}
	// Empty table: start from the zero state.
	if lastEvent == nil {
		return e.writeState(tx, State{})
	}
	return e.writeState(tx, State{
		BeginEventID: lastEvent.EventID() + 1,
		MaxEventTime: lastEvent.EventTime().Unix(),
	})
}

// readLastEvent reads the last row of the YT table within transaction tx.
//
// It returns (nil, nil) when the table has no rows. If the row count is
// positive but the row cannot be read, the reader error is returned, or
// sql.ErrNoRows when the reader reports no error.
func (e *Exporter) readLastEvent(tx yt.Tx) (events.Event, error) {
	var lastRow int64
	if err := tx.GetNode(context.TODO(), e.table.Attr("row_count"), &lastRow, nil); err != nil {
		return nil, err
	}
	// If there are no events, we should return nil.
	if lastRow == 0 {
		return nil, nil
	}
	// Index starts from zero.
	lastRow--
	r, err := tx.ReadTable(
		context.Background(),
		e.table.Rich().AddRange(
			ypath.Range{Exact: &ypath.ReadLimit{RowIndex: &lastRow}},
		),
		nil,
	)
	if err != nil {
		return nil, err
	}
	// Register Close only after the error check: on failure the reader may
	// be nil and closing it would panic.
	defer func() {
		_ = r.Close()
	}()
	if !r.Next() {
		// Do not hide a read failure behind "no rows".
		if err := r.Err(); err != nil {
			return nil, err
		}
		return nil, sql.ErrNoRows
	}
	// Scan into a fresh value of the same dynamic type as the prototype.
	clone := reflect.New(reflect.TypeOf(e.event))
	if err := r.Scan(clone.Interface()); err != nil {
		return nil, err
	}
	return clone.Elem().Interface().(events.Event), nil
}

// exportTx exports pending events into the YT table within transaction tx.
//
// Steps:
//  1. Take an exclusive lock on the table.
//  2. Read the persisted state and verify that the consumer starts at the
//     same BeginEventID.
//  3. Consume events from the database, buffering them in batch and
//     uploading the buffer every time it reaches BatchSize.
//  4. Upload the remainder if it holds at least MinBatchSize events;
//     otherwise it stays in batch and is not exported.
//  5. Merge the table chunks in sorted mode.
//
// sql.ErrNoRows returned by the consumer is not treated as an error.
func (e *Exporter) exportTx(
	tx yt.Tx, consumer events.Consumer, batch *[]events.Event,
) error {
	if _, err := tx.LockNode(context.TODO(), e.table, yt.LockExclusive, nil); err != nil {
		return err
	}
	defer func() {
		// Best-effort unlock; the result of the export does not depend on it.
		_ = tx.UnlockNode(context.TODO(), e.table, nil)
	}()
	state, err := e.readState(tx)
	if err != nil {
		return err
	}
	// Check that consumer has consistent state.
	if consumer.BeginEventID() != state.BeginEventID {
		return fmt.Errorf("consumer has inconsistent state")
	}
	if err := consumer.ConsumeEvents(e.db, func(event events.Event) error {
		if len(*batch) >= e.BatchSize {
			if err := e.flushBatch(tx, &state, batch); err != nil {
				return err
			}
		}
		// We should append the event only after uploading the batch because
		// the consumer doesn't apply the event if an error is returned.
		*batch = append(*batch, event)
		return nil
	}); err != nil && err != sql.ErrNoRows {
		return err
	}
	if len(*batch) >= e.MinBatchSize {
		if err := e.flushBatch(tx, &state, batch); err != nil {
			return err
		}
	}
	mergeSpec := spec.Spec{
		InputTablePaths: []ypath.YPath{e.table},
		OutputTablePath: e.table,
		CombineChunks:   true,
		MergeMode:       "sorted",
		Pool:            "carsharing",
	}
	mr := mapreduce.New(e.yt).WithTx(tx)
	mergeOp, err := mr.Merge(mergeSpec.Merge())
	if err != nil {
		return err
	}
	return mergeOp.Wait()
}

// flushBatch appends the buffered events to the YT table, advances state to
// the last uploaded event, persists it and empties the buffer.
//
// An empty buffer is a no-op. This matters when BatchSize or MinBatchSize
// is zero or negative: the size checks in exportTx then pass with nothing
// buffered, and taking the last element would panic.
func (e *Exporter) flushBatch(
	tx yt.Tx, state *State, batch *[]events.Event,
) error {
	if len(*batch) == 0 {
		return nil
	}
	w, err := tx.WriteTable(
		context.Background(), e.table.YPath().Rich().SetAppend(), nil,
	)
	if err != nil {
		return err
	}
	for _, row := range *batch {
		if err := w.Write(row); err != nil {
			// The write error is what matters; rollback is best-effort.
			_ = w.Rollback()
			return err
		}
	}
	if err := w.Commit(); err != nil {
		return err
	}
	// batch is not empty, so we can get last event from it.
	lastEvent := (*batch)[len(*batch)-1]
	// Update state with new BeginEventID and MaxEventTime.
	state.BeginEventID = lastEvent.EventID() + 1
	state.MaxEventTime = lastEvent.EventTime().Unix()
	// Try to persist this state to YT table.
	if err := e.writeState(tx, *state); err != nil {
		return err
	}
	// Empty batch.
	*batch = nil
	return nil
}

// readState loads the export State from the stateAttr attribute of the YT
// table within transaction tx. The returned State is whatever was decoded
// so far, even when an error is returned.
func (e *Exporter) readState(tx yt.Tx) (State, error) {
	statePath := e.table.YPath().Attr(stateAttr)
	var loaded State
	err := tx.GetNode(context.Background(), statePath, &loaded, nil)
	return loaded, err
}

// writeState stores state in the stateAttr attribute of the YT table within
// transaction tx.
func (e *Exporter) writeState(tx yt.Tx, state State) error {
	statePath := e.table.YPath().Attr(stateAttr)
	return tx.SetNode(context.Background(), statePath, state, nil)
}

// safeHistoryGap is subtracted from the maximal history event ID to obtain
// the "safe" ID returned by getSafeHistoryEventID.
const safeHistoryGap = 5000

// getSafeHistoryEventID returns max(col) of table minus safeHistoryGap,
// clamped to zero from below. It returns 0 when the table has no rows.
//
// table and col are interpolated into the query with %q, so they must be
// trusted plain identifiers.
func getSafeHistoryEventID(db *sql.DB, table, col string) (int64, error) {
	row := db.QueryRow(fmt.Sprintf("SELECT max(%q) FROM %q", col, table))
	// max() over an empty table yields a single NULL row rather than no
	// rows, so the value has to be scanned into a nullable type.
	var maxID sql.NullInt64
	if err := row.Scan(&maxID); err != nil {
		if err == sql.ErrNoRows {
			return 0, nil
		}
		return 0, err
	}
	if !maxID.Valid {
		return 0, nil
	}
	id := maxID.Int64 - safeHistoryGap
	if id < 0 {
		id = 0
	}
	return id, nil
}

// ExportSimpleSnapshot writes all objects read from dbTable into ytTable
// within one YT transaction.
//
// The table is created with the schema inferred from object if it does not
// exist yet. The table writer is rolled back when reading from the
// database or writing a row fails.
func ExportSimpleSnapshot(
	object objects.Object, idCol string,
	ytConn *goyt.YT, ytTable ypath.Path,
	dbConn *sql.DB, dbTable string,
) error {
	outSchema, err := schema.Infer(object)
	if err != nil {
		return err
	}
	store := objects.NewStore(object, idCol, dbTable, db.Postgres)
	snapshot := func(tx *goyt.Tx) error {
		createOpts := &yt.CreateNodeOptions{
			IgnoreExisting: true,
			Attributes: map[string]interface{}{
				"schema": outSchema, stateAttr: State{},
			},
		}
		_, err := tx.Raw().CreateNode(
			context.TODO(), ytTable.Rich().SetSchema(outSchema),
			yt.NodeTable, createOpts,
		)
		if err != nil {
			return err
		}
		w, err := tx.Raw().WriteTable(context.Background(), ytTable, nil)
		if err != nil {
			return err
		}
		// discard rolls the writer back (best-effort) and passes the
		// original error through.
		discard := func(cause error) error {
			_ = w.Rollback()
			return cause
		}
		rows, err := store.ReadObjects(dbConn)
		if err != nil {
			return discard(err)
		}
		defer rows.Close()
		for rows.Next() {
			if err := w.Write(rows.Object()); err != nil {
				return discard(err)
			}
		}
		if err := rows.Err(); err != nil {
			return discard(err)
		}
		return w.Commit()
	}
	if txErr := ytConn.WithTx(snapshot); txErr != nil {
		return txErr
	}
	return nil
}

// ExportHistorySnapshot writes all objects read from dbTable into ytTable
// within one YT transaction and stores a safe history event ID in the
// "_safe_history_event_id" attribute of the table.
//
// The safe ID is taken from historyTable/historyIDCol before the objects
// are read. The table writer is rolled back when reading from the database
// or writing a row fails.
func ExportHistorySnapshot(
	object objects.Object, idCol string,
	ytConn *goyt.YT, ytTable ypath.Path,
	dbConn *sql.DB, dbTable string,
	historyTable string, historyIDCol string,
) error {
	outSchema, err := schema.Infer(object)
	if err != nil {
		return err
	}
	store := objects.NewStore(object, idCol, dbTable, db.Postgres)
	snapshot := func(tx *goyt.Tx) error {
		createOpts := &yt.CreateNodeOptions{
			IgnoreExisting: true,
			Attributes: map[string]interface{}{
				"schema": outSchema, stateAttr: State{},
			},
		}
		_, err := tx.Raw().CreateNode(
			context.TODO(), ytTable.Rich().SetSchema(outSchema),
			yt.NodeTable, createOpts,
		)
		if err != nil {
			return err
		}
		// Capture the history position before the snapshot itself is read.
		historyID, err := getSafeHistoryEventID(
			dbConn, historyTable, historyIDCol,
		)
		if err != nil {
			return err
		}
		w, err := tx.Raw().WriteTable(context.Background(), ytTable, nil)
		if err != nil {
			return err
		}
		// discard rolls the writer back (best-effort) and passes the
		// original error through.
		discard := func(cause error) error {
			_ = w.Rollback()
			return cause
		}
		rows, err := store.ReadObjects(dbConn)
		if err != nil {
			return discard(err)
		}
		defer rows.Close()
		for rows.Next() {
			if err := w.Write(rows.Object()); err != nil {
				return discard(err)
			}
		}
		if err := rows.Err(); err != nil {
			return discard(err)
		}
		if err := w.Commit(); err != nil {
			return err
		}
		safeIDPath := ytTable.YPath().Attr("_safe_history_event_id")
		return tx.Set(safeIDPath, historyID)
	}
	if txErr := ytConn.WithTx(snapshot); txErr != nil {
		return txErr
	}
	return nil
}
