package mysql

import (
	"context"
	"database/sql"
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/jmoiron/sqlx"

	"a.yandex-team.ru/passport/infra/daemons/yasms_internal/internal/model"
	"a.yandex-team.ru/passport/shared/golibs/logger"
)

// queryMetaData carries what the audit log records about one statement of a
// queryBatch:
//   - TableName: the data table the statement modifies (stored as
//     audit_row.table_name and used to look up the row key column);
//   - Type: the kind of change; its String() form is stored as audit_row.type;
//   - Payload: free-form text stored as audit_row.payload;
//   - ID: the affected entity IDs - one audit_row is written per ID when an
//     audit log is requested. For Insert statements with an InsertedID it is
//     overwritten with the generated key after execution.
type queryMetaData struct {
	TableName string
	Type      SetType
	Payload   string
	ID        []model.EntityID
}

// queryHolder is a single prepared statement of a queryBatch plus the
// bookkeeping needed to execute and verify it:
//   - Query/Args: the statement text and its positional arguments;
//   - AffectedRows: the exact RowsAffected() the statement must report -
//     any other value aborts (rolls back) the whole transaction;
//   - InsertedID: when non-nil, receives LastInsertId() formatted as a
//     decimal string after the statement runs;
//   - QueryMetaData: what to write to the audit log for this statement.
type queryHolder struct {
	Query         string
	Args          []interface{}
	AffectedRows  int64
	InsertedID    *model.EntityID
	QueryMetaData queryMetaData
}

// auditLogRowParams holds the named parameters bound into insertAuditRow.
// One instance is built per audited entity ID; BulkID links the row to the
// audit_bulk row of the surrounding batch and Type is the SetType's String().
type auditLogRowParams struct {
	BulkID   uint64         `db:"bulk_id"`
	Table    string         `db:"table_name"`
	Type     string         `db:"type"`
	Payload  string         `db:"payload"`
	EntityID model.EntityID `db:"entity_id"`
}

// updateWithEventID holds the named parameters for the event-date UPDATE
// templates: :id selects the data row by its key column and :event_id is the
// audit_row ID written into event_modify (and event_create for inserts).
type updateWithEventID struct {
	EntityID model.EntityID `db:"id"`
	EventID  uint64         `db:"event_id"`
}

// queryBatch is an ordered list of statements that are executed, in order,
// inside a single database transaction.
type queryBatch []*queryHolder

// SQL used to maintain the audit log. Parameters are bound by name
// (":param") through prepareNamedQuery.
const (
	// auditLogBulkQuery records one audit_bulk row for a whole audited batch.
	auditLogBulkQuery = `
	INSERT INTO audit_bulk (timestamp, author, issue, comment) VALUES (:timestamp, :author, :issue, :comment);
`

	// insertAuditRow records one audit_row per affected entity of a bulk.
	insertAuditRow = `
	INSERT INTO audit_row (bulk_id, table_name, type, payload, entity_id) VALUES (:bulk_id, :table_name, :type, :payload, :entity_id);
`

	// updateRowWithEventModifyTemplate stamps event_modify on a data row.
	// Sprintf verbs: the table name, then that table's key column.
	updateRowWithEventModifyTemplate = `
	UPDATE %s SET event_modify=:event_id where %s=:id;
`

	// updateRowWithEventCreateTemplate stamps both event_modify and
	// event_create; it is used for Insert statements. Same Sprintf verbs.
	updateRowWithEventCreateTemplate = `
	UPDATE %s SET event_modify=:event_id, event_create=:event_id where %s=:id;
`
)

// String renders the statement on one line (newlines replaced by spaces),
// followed by its arguments and the expected affected-row count. It is used
// to give context in error messages.
func (holder *queryHolder) String() string {
	singleLine := strings.ReplaceAll(holder.Query, "\n", " ")
	const layout = "query: %s\targs: %v\tlen: %d"
	return fmt.Sprintf(layout, singleLine, holder.Args, holder.AffectedRows)
}

// append adds query to the end of the batch. A nil query is ignored, so
// constructors that return (nil, nil) for "nothing to do" (for example
// newDeleteQuery with an empty ID list) can be appended unconditionally.
func (batch *queryBatch) append(query *queryHolder) {
	if query == nil {
		return
	}
	*batch = append(*batch, query)
}

// execSingleTransaction runs every statement of the batch inside one
// transaction without writing anything to the audit log.
func (batch queryBatch) execSingleTransaction(db *sqlx.DB, ctx context.Context) error {
	var withoutAuditLog *model.AuditLogBulkParams // nil disables auditing
	return batch.execSingleTransactionWithAuditLog(db, ctx, withoutAuditLog)
}

// auditLogBulkDBParams is the flattened form of model.AuditLogBulkParams that
// is bound into auditLogBulkQuery. MySQL has no array type, so Issue holds
// the issue list joined with commas. Timestamp is a Unix time in seconds.
type auditLogBulkDBParams struct {
	Author    string `db:"author"`
	Issue     string `db:"issue"`
	Comment   string `db:"comment"`
	Timestamp int64  `db:"timestamp"`
}

// execAuditLogBulkQuery inserts the single audit_bulk row describing the
// whole batch and returns its auto-increment ID. The caller uses that ID as
// bulk_id for every audit_row written in the same transaction.
//
// A nil auditLogBulkParams means auditing is disabled: nothing is written
// and (0, nil) is returned. On error the returned ID is always 0. Errors
// wrap the underlying cause (%w) and include the SQL template for context.
func (batch queryBatch) execAuditLogBulkQuery(ctx context.Context, tx *sql.Tx, auditLogBulkParams *model.AuditLogBulkParams) (int64, error) {
	if auditLogBulkParams == nil {
		return 0, nil
	}

	// A flattened copy is needed because MySQL has no array type: the issue
	// list is stored as a comma-joined string.
	dbParams := auditLogBulkDBParams{
		Author:    auditLogBulkParams.Author,
		Issue:     strings.Join(auditLogBulkParams.Issue, ","),
		Comment:   auditLogBulkParams.Comment,
		Timestamp: time.Now().Unix(),
	}

	bulkQuery, bulkArgs, err := prepareNamedQuery(auditLogBulkQuery, dbParams)
	if err != nil {
		return 0, fmt.Errorf("failed to prepare query: %w; %s", err, auditLogBulkQuery)
	}

	auditBulkResult, err := tx.ExecContext(ctx, bulkQuery, bulkArgs...)
	if err != nil {
		return 0, fmt.Errorf("failed to exec query: %w; %s", err, auditLogBulkQuery)
	}

	bulkID, err := auditBulkResult.LastInsertId()
	if err != nil {
		// Never hand back a partial ID alongside an error.
		return 0, fmt.Errorf("failed to get last inserted bulk ID: %w; %s", err, auditLogBulkQuery)
	}
	return bulkID, nil
}

var indexFieldByTableName = map[string]string{
	routesTableName:        "ruleid",
	gatesTableName:         "gateid",
	fallbacksTableName:     "id",
	blockedphonesTableName: "blockid",
	regionsTableName:       "id",
	templatesTableName:     "id",
}

// updateDataRowWithEventDates stamps the data row identified by entityID in
// tableName with the audit event eventID (an audit_row ID):
//   - Update: sets event_modify;
//   - Insert: sets both event_create and event_modify.
//
// Any other setType is rejected with an error. tableName must be a key of
// indexFieldByTableName, which also supplies the row key column. This keeps
// the Sprintf-built SQL limited to known identifiers. Errors wrap the
// underlying cause (%w).
func (batch queryBatch) updateDataRowWithEventDates(ctx context.Context, tx *sql.Tx, eventID int64, entityID model.EntityID, setType SetType, tableName string) error {
	var updateRowWithEventIDTemplate string
	switch setType {
	case Update:
		updateRowWithEventIDTemplate = updateRowWithEventModifyTemplate
	case Insert:
		updateRowWithEventIDTemplate = updateRowWithEventCreateTemplate
	default:
		// Without this, Sprintf on an empty template yields a garbage
		// "%!(EXTRA ...)" statement that would be sent to MySQL.
		return fmt.Errorf("failed to prepare query. unsupported set type: %s", setType.String())
	}

	indexField, ok := indexFieldByTableName[tableName]
	if !ok {
		return fmt.Errorf("failed to prepare query. unknown table: %s", tableName)
	}

	updateRowWithEventID := fmt.Sprintf(updateRowWithEventIDTemplate, tableName, indexField)

	updateRowQuery, updateRowArgs, err := prepareNamedQuery(
		updateRowWithEventID,
		&updateWithEventID{
			EntityID: entityID,
			EventID:  uint64(eventID),
		},
	)
	if err != nil {
		// Report the statement we tried to prepare: the prepared query
		// result is presumably empty on failure.
		return fmt.Errorf("failed to prepare query: %w; %s", err, updateRowWithEventID)
	}

	if _, err := tx.ExecContext(ctx, updateRowQuery, updateRowArgs...); err != nil {
		return fmt.Errorf("failed to exec query: %w; %s", err, updateRowQuery)
	}
	return nil
}

// execAuditLogRowQuery inserts one audit_row for entityID under the bulk
// bulkID. It copies the table name, change type (its String() form) and
// payload from metaData.
//
// The returned sql.Result is passed through unchanged. The caller reads its
// LastInsertId as the audit event ID. Errors wrap the underlying cause (%w).
func (batch queryBatch) execAuditLogRowQuery(ctx context.Context, tx *sql.Tx, bulkID uint64, entityID model.EntityID, metaData *queryMetaData) (sql.Result, error) {
	auditLogQuery, auditLogArgs, err := prepareNamedQuery(
		insertAuditRow,
		&auditLogRowParams{
			BulkID:   bulkID,
			Table:    metaData.TableName,
			Type:     metaData.Type.String(),
			Payload:  metaData.Payload,
			EntityID: entityID,
		},
	)
	if err != nil {
		// Report the template: the prepared query result is presumably
		// empty when preparation fails.
		return nil, fmt.Errorf("failed to prepare query: %w; %s", err, insertAuditRow)
	}

	auditRowResult, err := tx.ExecContext(ctx, auditLogQuery, auditLogArgs...)
	if err != nil {
		return nil, fmt.Errorf("failed to exec query: %w; %s", err, auditLogQuery)
	}
	return auditRowResult, nil
}

// execSingleTransactionWithAuditLog executes every statement of the batch,
// in order, inside one database transaction. The transaction is committed
// only if every statement succeeds and reports exactly query.AffectedRows
// affected rows. Otherwise it is rolled back and the first error is returned
// (wrapping the cause with %w).
//
// For statements with a non-nil InsertedID, the generated key is written back
// through that pointer. For Insert statements it also becomes the single
// audited entity ID.
//
// When auditLogBulkParams is non-nil, the batch is audited in the same
// transaction:
//   - one audit_bulk row for the whole batch;
//   - one audit_row per entity ID of each statement;
//   - for types whose NeedUpdateData() is true, the data rows are stamped
//     with the resulting audit event ID.
func (batch queryBatch) execSingleTransactionWithAuditLog(db *sqlx.DB, ctx context.Context, auditLogBulkParams *model.AuditLogBulkParams) error {
	var affectedRowsTotal int64

	committed := false
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to start transaction: %w", err)
	}
	defer func() {
		if committed {
			return
		}
		// Tx.Rollback returns the sql.ErrTxDone sentinel (unwrapped) when the
		// transaction is already finished. That happens after a failed Commit,
		// or when database/sql rolled back because ctx was canceled. There is
		// nothing left to undo, so do not report it as a rollback failure.
		if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
			logger.Log().Errorf("failed to rollback changes: %s, rows affected: %d", err.Error(), affectedRowsTotal)
		}
	}()

	writeToAuditLog := auditLogBulkParams != nil

	bulkID, err := batch.execAuditLogBulkQuery(ctx, tx, auditLogBulkParams)
	if err != nil {
		return err
	}

	for _, query := range batch {
		// exec main query with data
		result, err := tx.ExecContext(ctx, query.Query, query.Args...)
		if err != nil {
			return fmt.Errorf("failed to exec query: %w; %s", err, query)
		}

		affectedRows, err := result.RowsAffected()
		if err != nil {
			return fmt.Errorf("failed to get the number of affected rows: %w; %s", err, query)
		}
		affectedRowsTotal += affectedRows

		// Any deviation from the planned row count aborts the whole batch.
		if affectedRows != query.AffectedRows {
			return fmt.Errorf(
				"expected to affect %d rows, actually affected: %d; %s",
				query.AffectedRows, affectedRows, query)
		}

		if query.InsertedID != nil {
			id, err := result.LastInsertId()
			if err != nil {
				return fmt.Errorf("failed to get last inserted ID: %w; %s", err, query)
			}
			insertResult := strconv.FormatInt(id, 10)
			*query.InsertedID = insertResult
			// For inserts, the audited entity is the newly generated key.
			if query.QueryMetaData.Type == Insert {
				query.QueryMetaData.ID = []model.EntityID{insertResult}
			}
		}

		if writeToAuditLog {
			if err := batch.writeAuditRowsForQuery(ctx, tx, bulkID, query); err != nil {
				return err
			}
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}
	committed = true

	return nil
}

// writeAuditRowsForQuery writes one audit_row per entity ID of query under
// bulkID. When the query's type needs it, it also stamps each data row with
// the new audit event ID.
func (batch queryBatch) writeAuditRowsForQuery(ctx context.Context, tx *sql.Tx, bulkID int64, query *queryHolder) error {
	metaData := &query.QueryMetaData
	for _, id := range metaData.ID {
		auditRowResult, err := batch.execAuditLogRowQuery(ctx, tx, uint64(bulkID), id, metaData)
		if err != nil {
			return err
		}
		// update data rows with event_create and event_modify dates
		if !metaData.Type.NeedUpdateData() {
			continue
		}
		eventID, err := auditRowResult.LastInsertId()
		if err != nil {
			return fmt.Errorf("failed to get audit event id: %w", err)
		}
		if err := batch.updateDataRowWithEventDates(ctx, tx, eventID, id, metaData.Type, metaData.TableName); err != nil {
			return err
		}
	}
	return nil
}

// deleteQueryParams holds the IDs bound to the :delete named parameter of a
// delete statement template (see newDeleteQuery).
type deleteQueryParams struct {
	Delete []model.EntityID `db:"delete"`
}

// newDeleteQuery binds toDelete to the :delete named parameter of
// queryTemplate. The returned holder must affect exactly one row per ID.
// When there is nothing to delete it returns (nil, nil), which
// queryBatch.append ignores.
func newDeleteQuery(queryTemplate string, toDelete []model.EntityID) (*queryHolder, error) {
	if len(toDelete) == 0 {
		return nil, nil
	}

	params := &deleteQueryParams{Delete: toDelete}
	boundQuery, boundArgs, err := prepareNamedQuery(queryTemplate, params)
	if err != nil {
		return nil, err
	}

	holder := &queryHolder{
		Query:        boundQuery,
		Args:         boundArgs,
		AffectedRows: int64(len(toDelete)),
	}
	return holder, nil
}

// prepareDeleteQuery builds a delete statement for toDelete from
// queryTemplate and appends it to batch. The statement is tagged with audit
// metadata: tableName, setType, the deleted IDs and an empty payload. It does
// nothing when toDelete is empty.
func prepareDeleteQuery(batch *queryBatch, queryTemplate string, tableName string, toDelete []model.EntityID, setType SetType) error {
	if len(toDelete) == 0 {
		return nil
	}

	holder, err := newDeleteQuery(queryTemplate, toDelete)
	if err != nil {
		return err
	}

	// The holder from newDeleteQuery has zero-valued metadata; fill it in.
	holder.QueryMetaData.ID = toDelete
	holder.QueryMetaData.TableName = tableName
	holder.QueryMetaData.Type = setType
	holder.QueryMetaData.Payload = ""

	batch.append(holder)
	return nil
}
