package database

import (
	"context"
	"fmt"
	"os"
	"strings"

	"a.yandex-team.ru/library/go/core/log"
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
)

// MongoDBClient holds a MongoDB database handle together with the connection
// settings that NewClient reads from the environment.
type MongoDBClient struct {
	// l is the logger used for the client's info and debug messages.
	l log.Logger

	// DB is the database handle. It stays nil until InitDBConnection succeeds.
	DB *mongo.Database

	// Connection settings, filled from the MONGO_* environment variables.
	user, password, hosts, port, database, replicaSet string
}

// NewClient returns a MongoDBClient that logs through l. Its connection
// settings come from the MONGO_USER, MONGO_PASSWORD, MONGO_HOSTS, MONGO_PORT,
// MONGO_DATABASE and MONGO_REPLICASET environment variables. No connection is
// made here; call InitDBConnection to connect.
func NewClient(l log.Logger) *MongoDBClient {
	mc := &MongoDBClient{l: l}
	mc.user = os.Getenv("MONGO_USER")
	mc.password = os.Getenv("MONGO_PASSWORD")
	mc.hosts = os.Getenv("MONGO_HOSTS")
	mc.port = os.Getenv("MONGO_PORT")
	mc.database = os.Getenv("MONGO_DATABASE")
	mc.replicaSet = os.Getenv("MONGO_REPLICASET")
	return mc
}

// InitDBConnection connects to MongoDB using the settings captured by
// NewClient and checks the connection with a ping. On success it stores the
// handle for the configured database in mc.DB.
//
// A user and a password are required. hosts is a comma-separated list; each
// non-blank entry gets ":<port>" appended when a port is configured.
// Credentials are percent-encoded before going into the connection URI, so
// characters such as '@', ':' or '/' cannot corrupt it. If the ping fails,
// the client is disconnected so its connection pool is not leaked.
func (mc *MongoDBClient) InitDBConnection() error {
	// NOTE(review): ctx has no deadline, so Connect/Ping are bounded only by
	// the driver's own timeouts.
	ctx := context.Background()
	if len(mc.user) == 0 {
		return fmt.Errorf("mongo auth failed: user")
	}
	if len(mc.password) == 0 {
		return fmt.Errorf("mongo auth failed: password")
	}

	hosts := mongoHostList(mc.hosts, mc.port)
	if hosts == "" {
		return fmt.Errorf("mongo connection failed: no hosts configured")
	}

	uri := fmt.Sprintf("mongodb://%s:%s@%s/%s",
		escapeMongoUserInfo(mc.user),
		escapeMongoUserInfo(mc.password),
		hosts,
		mc.database,
	)
	if mc.replicaSet != "" {
		uri += "?replicaSet=" + mc.replicaSet
	}

	mc.l.Info("connecting to mongo...")
	client, err := mongo.Connect(ctx, options.Client().ApplyURI(uri).SetMaxPoolSize(50))
	if err != nil {
		return fmt.Errorf("error trying to connect to mongodb: %w", err)
	}

	// Connect does not wait for the deployment to be reachable (see driver
	// docs). Only report success once the ping has gone through.
	if err := client.Ping(ctx, nil); err != nil {
		// Release the pool and background monitors started by Connect. The
		// ping failure is the error worth surfacing, so the disconnect error
		// is deliberately ignored.
		_ = client.Disconnect(ctx)
		return fmt.Errorf("error trying to ping mongo: %w", err)
	}
	mc.l.Info("successfully connected to mongodb!")

	mc.DB = client.Database(mc.database)
	return nil
}

// escapeMongoUserInfo percent-encodes s for use as the username or password
// part of a mongodb:// URI. Every byte outside the RFC 3986 unreserved set
// (ALPHA / DIGIT / "-" / "." / "_" / "~") becomes %XX, so the result decodes
// correctly whether the parser uses path- or query-style unescaping.
func escapeMongoUserInfo(s string) string {
	var b strings.Builder
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z', '0' <= c && c <= '9',
			c == '-', c == '.', c == '_', c == '~':
			b.WriteByte(c)
		default:
			fmt.Fprintf(&b, "%%%02X", c)
		}
	}
	return b.String()
}

// mongoHostList turns a comma-separated host list into "host[:port],..."
// form. Entries are trimmed and blank ones are skipped. The ":port" suffix is
// left off when port is empty. It returns "" when no host remains.
func mongoHostList(hosts, port string) string {
	port = strings.TrimSpace(port)
	var res []string
	for _, host := range strings.Split(hosts, ",") {
		host = strings.TrimSpace(host)
		if host == "" {
			continue
		}
		if port != "" {
			host += ":" + port
		}
		res = append(res, host)
	}
	return strings.Join(res, ",")
}

// Get runs a paginated find on collection and decodes all matching documents
// into result. result is passed straight to the driver's Cursor.All, so it is
// presumably a pointer to a slice; confirm against callers.
//
// fields is used as the projection and sort as the sort specification.
// NOTE(review): sort is a Go map, whose iteration order is random, so a
// multi-key sort has no guaranteed key precedence. Only single-key sorts are
// reliable. Pass an ordered document if multi-key sorts are ever needed.
//
// limit caps the number of documents returned (0 means no limit, per MongoDB
// semantics). page is 1-based. Pages below 1 are treated as the first page so
// the computed skip is never negative, which the server would reject.
func (mc *MongoDBClient) Get(ctx context.Context, collection string, filter map[string]interface{}, fields, sort map[string]int, limit, page int, result interface{}) error {
	mc.l.Debug("mongo get request",
		log.String("collection", collection),
		log.Any("filter", filter),
		log.Any("fields", fields),
		log.Any("sort", sort),
		log.Int("limit", limit),
		log.Int("page", page),
	)

	// Skip ahead only for pages after the first when a positive page size is
	// given. Multiply in int64 so large page numbers cannot overflow int.
	var skip int64
	if page > 1 && limit > 0 {
		skip = int64(page-1) * int64(limit)
	}

	opts := options.Find().SetProjection(fields).SetSort(sort).SetLimit(int64(limit)).SetSkip(skip)
	cur, err := mc.DB.Collection(collection).Find(ctx, filter, opts)
	if err != nil {
		return err
	}

	// All drains the cursor, decodes into result and closes the cursor.
	return cur.All(ctx, result)
}

// GetOne looks up a single document in collection that matches filter and
// decodes it into result. fields is used as the projection and sort picks
// which match is returned. The error from the driver's SingleResult.Decode is
// returned unchanged, so callers can still compare it with driver sentinels
// such as mongo.ErrNoDocuments.
func (mc *MongoDBClient) GetOne(ctx context.Context, collection string, filter map[string]interface{}, fields, sort map[string]int, result interface{}) error {
	mc.l.Debug("mongo getone request",
		log.String("collection", collection),
		log.Any("filter", filter),
		log.Any("fields", fields),
		log.Any("sort", sort),
	)

	findOpts := options.FindOne()
	findOpts.SetProjection(fields)
	findOpts.SetSort(sort)

	coll := mc.DB.Collection(collection)
	return coll.FindOne(ctx, filter, findOpts).Decode(result)
}

// GetAll decodes every document in collection into result, using an empty
// filter. result is passed straight to the driver's Cursor.All, so it is
// presumably a pointer to a slice; confirm against callers.
func (mc *MongoDBClient) GetAll(ctx context.Context, collection string, result interface{}) error {
	mc.l.Debug("mongo getall request", log.String("collection", collection))

	coll := mc.DB.Collection(collection)
	cursor, findErr := coll.Find(ctx, bson.M{})
	if findErr != nil {
		return findErr
	}
	return cursor.All(ctx, result)
}

// Count returns how many documents in collection match filter. The driver's
// count is converted to int and returned alongside whatever error it
// reported.
func (mc *MongoDBClient) Count(ctx context.Context, collection string, filter map[string]interface{}) (int, error) {
	coll := mc.DB.Collection(collection)
	n, countErr := coll.CountDocuments(ctx, filter)
	return int(n), countErr
}

// AddItemToArray adds item to the array stored under field in a single
// document matching filter. It uses $addToSet, so the value is not
// duplicated if it is already present. The update runs with upsert enabled,
// so when nothing matches filter MongoDB inserts a new document instead.
func (mc *MongoDBClient) AddItemToArray(ctx context.Context, collection, field string, filter map[string]interface{}, item string) error {
	upsert := options.Update().SetUpsert(true)
	addToSet := bson.M{"$addToSet": bson.M{field: item}}

	if _, err := mc.DB.Collection(collection).UpdateOne(ctx, filter, addToSet, upsert); err != nil {
		return err
	}
	return nil
}

// PullItemFromArray removes item from the array stored under field in a
// single document matching filter, using $pull. No upsert is requested, so a
// filter that matches nothing leaves the collection unchanged.
func (mc *MongoDBClient) PullItemFromArray(ctx context.Context, collection, field string, filter map[string]interface{}, item string) error {
	pull := bson.M{"$pull": bson.M{field: item}}

	coll := mc.DB.Collection(collection)
	if _, err := coll.UpdateOne(ctx, filter, pull); err != nil {
		return err
	}
	return nil
}
