package pgadapter

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"errors"
	"strconv"

	"github.com/jackc/pgx/v4"
	"github.com/jackc/pgx/v4/stdlib"
	"github.com/jmoiron/sqlx"
	"golang.yandex/hasql/checkers"
	hasql "golang.yandex/hasql/sqlx"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/passport/backend/federal_config_api/internal/core/interfaces"
	"a.yandex-team.ru/passport/backend/federal_config_api/internal/core/models"
	"a.yandex-team.ru/passport/backend/federal_config_api/internal/tracer"
)

// postgresqlAdapter implements interfaces.FederalConfigAdapter on top of a
// PostgreSQL cluster accessed through hasql.
type postgresqlAdapter struct {
	// cluster is the hasql cluster built by setupCluster from the configured
	// connections.
	cluster *hasql.Cluster
	// logger is the logger passed to NewPostgresqlAdapter.
	logger  log.Logger
}

// Compile-time check that postgresqlAdapter correctly implements the interface.
var _ interfaces.FederalConfigAdapter = (*postgresqlAdapter)(nil)

// openSqlxViaPgxConnPool parses connectionString with pgx, forces the simple
// query protocol, and returns the resulting database/sql pool wrapped in an
// *sqlx.DB tagged with the "pgx" driver name.
func openSqlxViaPgxConnPool(connectionString string) (*sqlx.DB, error) {
	cfg, parseErr := pgx.ParseConfig(connectionString)
	if parseErr != nil {
		return nil, parseErr
	}

	// Required for Odyssey in MDB: it does not work without the simple
	// protocol (presumably because the pooler cannot handle server-side
	// prepared statements — confirm).
	cfg.PreferSimpleProtocol = true

	return sqlx.NewDb(stdlib.OpenDB(*cfg), "pgx"), nil
}

// setupCluster opens one database/sql pool per entry of connections (node
// address -> connection string) and joins them into an hasql cluster that
// uses the PostgreSQL node checker and the project tracer.
//
// If opening any pool or creating the cluster fails, every pool opened so far
// is closed before the error is returned, so a failed setup does not leak
// pools.
func setupCluster(connections map[string]string, logger log.Logger) (*hasql.Cluster, error) {
	nodes := make([]hasql.Node, 0, len(connections))
	opened := make([]*sqlx.DB, 0, len(connections))

	// closeOpened releases the pools created before a failure. sql.OpenDB
	// starts background work for every pool, so abandoning them would leak it.
	closeOpened := func() {
		for _, db := range opened {
			_ = db.Close() // best effort: the original error is what gets reported
		}
	}

	for address, connectionString := range connections {
		sqlConn, err := openSqlxViaPgxConnPool(connectionString)
		if err != nil {
			closeOpened()
			return nil, xerrors.Errorf("opening connection pool for node %q: %w", address, err)
		}
		opened = append(opened, sqlConn)
		nodes = append(nodes, hasql.NewNode(address, sqlConn))
	}

	tr := tracer.NewTracer(logger)
	cluster, err := hasql.NewCluster(
		nodes,
		checkers.PostgreSQL,
		hasql.WithTracer(tr),
	)
	if err != nil {
		closeOpened()
		return nil, xerrors.Errorf("creating hasql cluster: %w", err)
	}
	return cluster, nil
}

func NewPostgresqlAdapter(ctx context.Context, connections map[string]string, logger log.Logger) (*postgresqlAdapter, error) {
	cluster, err := setupCluster(connections, logger)
	if err != nil {
		return nil, err
	}
	return &postgresqlAdapter{
		cluster: cluster,
		logger:  logger,
	}, nil
}

// wrapPgErrors translates low-level driver errors into adapter-level errors
// from the interfaces package. driver.ErrBadConn becomes
// interfaces.ErrBadConn, including when it is wrapped somewhere in the chain.
// Any other error, including nil, is returned unchanged. ctx is currently
// unused.
func (pga *postgresqlAdapter) wrapPgErrors(ctx context.Context, err error) error {
	// errors.Is also matches driver.ErrBadConn wrapped by intermediate layers,
	// which a plain == comparison would miss.
	if errors.Is(err, driver.ErrBadConn) {
		return interfaces.ErrBadConn
	}
	return err
}

// GetByConfigID loads the federation config with the given configID in a
// read-only transaction.
//
// It returns interfaces.ErrNamespaceMismatch when the stored config belongs
// to a namespace other than namespace. On any error the returned config is
// the zero value.
func (pga *postgresqlAdapter) GetByConfigID(ctx context.Context, namespace string, configID uint64) (models.FederationConfig, error) {
	var config models.FederationConfig

	err := pga.txWrap(ctx, sql.TxOptions{ReadOnly: true}, func(tx *sql.Tx) error {
		// Distinct names keep the namespace/configID parameters un-shadowed.
		foundID, foundEntityID, foundNamespace, foundDomainIDs, foundBody, err := pga.txGetConfigByID(ctx, tx, GetByConfigID, strconv.FormatUint(configID, 10))
		if err != nil {
			return err
		}
		config = models.FederationConfig{
			ConfigID:   foundID,
			EntityID:   foundEntityID,
			Namespace:  foundNamespace,
			DomainIDs:  foundDomainIDs,
			ConfigBody: foundBody,
		}
		return nil
	})
	if err != nil {
		// The transaction may fail after the callback filled config; do not
		// hand out a result that was not committed cleanly.
		return models.FederationConfig{}, err
	}
	if config.Namespace != namespace {
		return models.FederationConfig{}, interfaces.ErrNamespaceMismatch
	}
	return config, nil
}

// GetByEntityID loads the federation config mapped to entityID in a read-only
// transaction.
//
// It returns interfaces.ErrNamespaceMismatch when the stored config belongs
// to a namespace other than namespace. On any error, including the mismatch,
// the returned config is the zero value, so another namespace's config is
// never exposed.
func (pga *postgresqlAdapter) GetByEntityID(ctx context.Context, namespace string, entityID string) (models.FederationConfig, error) {
	var config models.FederationConfig

	err := pga.txWrap(ctx, sql.TxOptions{ReadOnly: true}, func(tx *sql.Tx) error {
		// Distinct names keep the namespace/entityID parameters un-shadowed.
		foundID, foundEntityID, foundNamespace, foundDomainIDs, foundBody, err := pga.txGetConfigByID(ctx, tx, GetByEntityID, entityID)
		if err != nil {
			return err
		}
		config = models.FederationConfig{
			ConfigID:   foundID,
			EntityID:   foundEntityID,
			Namespace:  foundNamespace,
			DomainIDs:  foundDomainIDs,
			ConfigBody: foundBody,
		}
		return nil
	})
	if err != nil {
		return models.FederationConfig{}, err
	}
	if config.Namespace != namespace {
		return models.FederationConfig{}, interfaces.ErrNamespaceMismatch
	}
	return config, nil
}

// GetByDomainID loads the federation config mapped to domainID in a read-only
// transaction.
//
// It returns interfaces.ErrNamespaceMismatch when the stored config belongs
// to a namespace other than namespace. On any error, including the mismatch,
// the returned config is the zero value, so another namespace's config is
// never exposed.
func (pga *postgresqlAdapter) GetByDomainID(ctx context.Context, namespace string, domainID uint64) (models.FederationConfig, error) {
	var config models.FederationConfig

	err := pga.txWrap(ctx, sql.TxOptions{ReadOnly: true}, func(tx *sql.Tx) error {
		// Distinct names keep the namespace parameter un-shadowed.
		foundID, foundEntityID, foundNamespace, foundDomainIDs, foundBody, err := pga.txGetConfigByID(ctx, tx, GetByDomainID, strconv.FormatUint(domainID, 10))
		if err != nil {
			return err
		}
		config = models.FederationConfig{
			ConfigID:   foundID,
			EntityID:   foundEntityID,
			Namespace:  foundNamespace,
			DomainIDs:  foundDomainIDs,
			ConfigBody: foundBody,
		}
		return nil
	})
	if err != nil {
		return models.FederationConfig{}, err
	}
	if config.Namespace != namespace {
		return models.FederationConfig{}, interfaces.ErrNamespaceMismatch
	}
	return config, nil
}

// Create stores a new federation config in a single read-write transaction.
// It inserts the config attributes, then the entity ID mapping (skipped when
// config.EntityID is empty), the domain ID mappings and the namespace
// mapping.
//
// It returns the ID assigned to the new config. On error the returned ID is
// 0, because the insert may have been rolled back.
func (pga *postgresqlAdapter) Create(ctx context.Context, config models.FederationConfig) (uint64, error) {
	var attrs Attributes
	attrs.SAMLConfig = config.SAMLConfig
	attrs.OAuthConfig = config.OAuthConfig
	attrs.Enabled = config.Enabled

	var createdConfigID uint64
	err := pga.txWrap(ctx, sql.TxOptions{}, func(tx *sql.Tx) error {
		// All per-attempt state lives inside the callback, so nothing
		// accumulates across invocations.
		configID, err := pga.txAddConfig(ctx, tx, attrs)
		if err != nil {
			return err
		}
		createdConfigID = configID

		// An empty entity ID means the config gets no entity mapping.
		if config.EntityID != "" {
			entityMapping := EntityIDToConfigID{
				EntityID: config.EntityID,
				ConfigID: configID,
			}
			if err := pga.txAddEntityIDToConfigID(ctx, tx, entityMapping); err != nil {
				return err
			}
		}
		for _, domainID := range config.DomainIDs {
			domainMapping := DomainIDToConfigID{
				DomainID: domainID,
				ConfigID: configID,
			}
			if err := pga.txAddDomainIDToConfigID(ctx, tx, domainMapping); err != nil {
				return err
			}
		}
		if err := pga.txAddNamespaceToConfigID(ctx, tx, NamespaceToConfigID{config.Namespace, configID}); err != nil {
			return err
		}
		return nil
	})
	if err != nil {
		return 0, err
	}
	return createdConfigID, nil
}

// List returns, in a read-only transaction, the federation configs whose IDs
// txListConfigIDs selects for namespace, startConfigID and limit (presumably
// a page of at most limit configs starting at startConfigID — confirm against
// txListConfigIDs). Each config is then loaded individually.
//
// When no config matches, the result is nil. On error the result is nil;
// partially collected configs are never returned.
func (pga *postgresqlAdapter) List(ctx context.Context, namespace string, startConfigID uint64, limit uint64) ([]models.FederationConfig, error) {
	var federationConfigs []models.FederationConfig

	err := pga.txWrap(ctx, sql.TxOptions{ReadOnly: true}, func(tx *sql.Tx) error {
		configIDs, err := pga.txListConfigIDs(ctx, tx, namespace, startConfigID, limit)
		if err != nil {
			return err
		}

		// Collect into a local slice and publish it only once every config
		// has loaded, so a failure never leaves a partial result behind.
		var page []models.FederationConfig
		for _, configID := range configIDs {
			cID, entityID, cNamespace, domainIDs, configBody, err := pga.txGetConfigByID(
				ctx,
				tx,
				GetByConfigID,
				strconv.FormatUint(configID, 10),
			)
			if err != nil {
				return xerrors.Errorf("error during listing configs: %w", err)
			}
			page = append(page, models.FederationConfig{
				ConfigID:   cID,
				EntityID:   entityID,
				Namespace:  cNamespace,
				DomainIDs:  domainIDs,
				ConfigBody: configBody,
			})
		}
		federationConfigs = page
		return nil
	})
	if err != nil {
		return nil, err
	}
	return federationConfigs, nil
}

// Update modifies the config identified by configID in a single read-write
// transaction. It first checks that the config belongs to namespace and
// returns interfaces.ErrNamespaceMismatch otherwise.
//
// The SAML/OAuth settings and the Enabled flag from configBody are always
// written. The pointer arguments are applied only when non-nil:
//   - entityID: "" removes the current entity ID mapping; any other value
//     replaces the current mapping with it;
//   - domainIDs: the stored domain mappings are reconciled with the list
//     (IDs reported as ToRemove by utilDiffDomainIDs are deleted, ToAdd are
//     inserted).
func (pga *postgresqlAdapter) Update(ctx context.Context, namespace string, entityID *string, domainIDs *[]uint64, configID uint64, configBody models.ConfigBody) error {
	var attrs Attributes
	attrs.SAMLConfig = configBody.SAMLConfig
	attrs.OAuthConfig = configBody.OAuthConfig
	attrs.Enabled = configBody.Enabled

	return pga.txWrap(ctx, sql.TxOptions{}, func(tx *sql.Tx) error {
		if err := pga.assertNamespace(ctx, tx, namespace, configID); err != nil {
			return err
		}
		if err := pga.txEditConfig(ctx, tx, configID, attrs); err != nil {
			return err
		}

		if entityID != nil {
			// Look up the current mapping only when it is going to change.
			existingEntityID, err := pga.txGetEntityIDMapping(ctx, tx, configID)
			if err != nil {
				return err
			}
			if *entityID == "" {
				// Delete the mapping that exists now. The requested (empty)
				// entity ID would not identify the stored row.
				if err := pga.txDeleteEntityIDMapping(ctx, tx, existingEntityID, configID); err != nil {
					return err
				}
			} else {
				// NOTE(review): Create skips the entity mapping for an empty
				// EntityID, so a config may have none. This branch relies on
				// txGetEntityIDMapping/txUpdateEntityIDMapping handling that
				// case — confirm.
				if err := pga.txUpdateEntityIDMapping(ctx, tx, existingEntityID, *entityID, configID); err != nil {
					return err
				}
			}
		}

		if domainIDs != nil {
			existingDomainIDs, err := pga.txGetDomainIDMapping(ctx, tx, configID)
			if err != nil {
				return err
			}
			domainsDiff := utilDiffDomainIDs(existingDomainIDs, *domainIDs)
			// Skip the delete call entirely when there is nothing to remove.
			if len(domainsDiff.ToRemove) > 0 {
				if err := pga.txDeleteDomainIDsMapping(ctx, tx, domainsDiff.ToRemove, configID); err != nil {
					return err
				}
			}
			for _, domainID := range domainsDiff.ToAdd {
				mapping := DomainIDToConfigID{
					DomainID: domainID,
					ConfigID: configID,
				}
				if err := pga.txAddDomainIDToConfigID(ctx, tx, mapping); err != nil {
					return err
				}
			}
		}
		return nil
	})
}

// Delete removes the config identified by configID in a read-write
// transaction. The config must belong to namespace; otherwise
// interfaces.ErrNamespaceMismatch is returned and nothing is deleted.
func (pga *postgresqlAdapter) Delete(ctx context.Context, namespace string, configID uint64) error {
	deleteInTx := func(tx *sql.Tx) error {
		if nsErr := pga.assertNamespace(ctx, tx, namespace, configID); nsErr != nil {
			return nsErr
		}
		if delErr := pga.txDeleteConfig(ctx, tx, configID); delErr != nil {
			return delErr
		}
		return nil
	}
	return pga.txWrap(ctx, sql.TxOptions{}, deleteInTx)
}

// assertNamespace checks, inside tx, that the config identified by configID
// belongs to expectedNamespace.
//
// A lookup failure is passed through wrapPgErrors. A namespace difference
// yields interfaces.ErrNamespaceMismatch. Otherwise it returns nil.
func (pga *postgresqlAdapter) assertNamespace(ctx context.Context, tx *sql.Tx, expectedNamespace string, configID uint64) error {
	switch stored, lookupErr := pga.txGetNamespaceByConfigID(ctx, tx, configID); {
	case lookupErr != nil:
		return pga.wrapPgErrors(ctx, lookupErr)
	case stored != expectedNamespace:
		return interfaces.ErrNamespaceMismatch
	default:
		return nil
	}
}
