package task

import (
	"errors"
	"fmt"
	"time"

	"google.golang.org/protobuf/proto"

	"a.yandex-team.ru/library/go/core/log/zap"
	"a.yandex-team.ru/library/go/core/metrics"
	"a.yandex-team.ru/travel/buses/backend/internal/common/connector"
	"a.yandex-team.ru/travel/buses/backend/internal/common/dict"
	workerLogging "a.yandex-team.ru/travel/buses/backend/internal/worker/logging"
	wpb "a.yandex-team.ru/travel/buses/backend/proto/worker"
	"a.yandex-team.ru/travel/library/go/logbroker"
	travelMetrics "a.yandex-team.ru/travel/library/go/metrics"
	pbTravel "a.yandex-team.ru/travel/proto"
)

// SearchTask executes queued search requests for a single supplier: it pops a request from the
// queue, asks the supplier connector for rides, publishes the result and records it in the
// communication log and metrics.
type SearchTask struct {
	// logger receives diagnostic messages produced while running the task.
	logger *zap.Logger
	// supplier is the supplier this task serves; its ID, Name, SearchRPS and Concurrency are used.
	supplier dict.Supplier
	// maxRPS is not assigned or read anywhere in the visible code; MaxRPS() reports
	// supplier.SearchRPS instead. NOTE(review): possibly a leftover field — confirm before removing.
	maxRPS float64
	// queue is the source of pending search requests, popped by supplier ID.
	queue *SearchTaskQueue
	// connector performs the actual search call (GetSearch).
	connector connector.Client
	// searchProducer is where every search result is written.
	searchProducer *logbroker.Producer
	// communicationLogger records the request/response pair of each search.
	communicationLogger *workerLogging.CommunicationLogWriter
	// metrics holds the counters updated after each search.
	metrics *travelMetrics.AppMetrics
}

// NewSearchTask builds a SearchTask that serves search requests of the given supplier.
//
// It creates a connector client from cfg for supplier.ID and an application metrics group
// registered under the "search" prefix and tagged with the supplier name.
//
// The constructor has no error return, so a failure to create the connector client is reported
// through logger instead of being silently discarded; the task is still returned with whatever
// client value connector.NewClient produced.
func NewSearchTask(
	queue *SearchTaskQueue,
	cfg *connector.Config,
	searchProducer *logbroker.Producer,
	supplier dict.Supplier,
	logger *zap.Logger,
	communicationLogger *workerLogging.CommunicationLogWriter,
	metricsRegistry metrics.Registry,
) *SearchTask {
	connectorClient, err := connector.NewClient(cfg, supplier.ID, logger)
	if err != nil {
		// NOTE(review): the client may be unusable after this error — confirm whether callers
		// should be prevented from scheduling the task at all.
		logger.Errorf("NewSearchTask: can not create connector client for sID=%d: %s",
			supplier.ID, err.Error())
	}
	searchMetrics := travelMetrics.NewAppMetrics(
		metricsRegistry.WithPrefix("search").WithTags(map[string]string{"supplier": supplier.Name}),
	)
	return &SearchTask{
		logger:              logger,
		supplier:            supplier,
		queue:               queue,
		connector:           connectorClient,
		searchProducer:      searchProducer,
		communicationLogger: communicationLogger,
		metrics:             searchMetrics,
	}
}

// MaxRPS reports the search requests-per-second value taken from the task's supplier (SearchRPS).
func (st *SearchTask) MaxRPS() float64 {
	rps := st.supplier.SearchRPS
	return rps
}

// MaxConcurrency reports the concurrency value taken from the task's supplier (Concurrency).
func (st *SearchTask) MaxConcurrency() uint32 {
	concurrency := st.supplier.Concurrency
	return concurrency
}

// Do processes one search request for the task's supplier.
//
// It pops a request from the queue by supplier ID, calls the connector's GetSearch with the
// request's From, To, Date and TryNoCache, writes the resulting TSearchResult to the search
// producer, and finally passes it to trackResult for communication logging and metrics.
//
// Error handling:
//   - ErrNoTasks from the queue ends the call silently; any other queue error is logged and the
//     call returns without producing a result.
//   - A connector error does not abort the call: a result with an EC_GENERAL_ERROR header and no
//     rides is produced instead.
//   - A producer write error is logged; the result is still tracked.
func (st *SearchTask) Do() {
	const logMessage = "SearchTask.Do"

	searchRequest, err := st.queue.Pop(st.supplier.ID)
	if err != nil {
		// errors.Is also recognises ErrNoTasks when it arrives wrapped with extra context,
		// so an empty queue is never reported as an error.
		if !errors.Is(err, ErrNoTasks) {
			st.logger.Errorf("%s: %s", logMessage, err.Error())
		}
		return
	}

	// The result carries its own copy of the request, independent of searchRequest.
	searchResult := &wpb.TSearchResult{
		Request: proto.Clone(searchRequest).(*wpb.TSearchRequest),
	}
	rides, explanation, err := st.connector.GetSearch(
		searchRequest.From, searchRequest.To, searchRequest.Date, searchRequest.TryNoCache)
	// The header timestamp is taken after the connector call has returned.
	curTime := time.Now().Unix()
	if err != nil {
		msg := fmt.Sprintf("%s: error getting rides from partner sID=%d: %s",
			logMessage, st.supplier.ID, err.Error())
		st.logger.Info(msg)
		searchResult.Header = &wpb.TResponseHeader{
			Code: pbTravel.EErrorCode_EC_GENERAL_ERROR,
			Error: &pbTravel.TError{
				Code:    pbTravel.EErrorCode_EC_GENERAL_ERROR,
				Message: msg,
			},
			Timestamp: curTime,
		}
	} else {
		searchResult.Header = &wpb.TResponseHeader{
			Code:      pbTravel.EErrorCode_EC_OK,
			Timestamp: curTime,
		}
		searchResult.Rides = rides
	}
	// The explanation returned by the connector is attached on both the success and error paths.
	searchResult.Header.Explanation = explanation

	err = st.searchProducer.Write(searchResult)
	if err != nil {
		st.logger.Errorf("%s: can not write to search producer: %s", logMessage, err.Error())
	}

	st.trackResult(searchRequest, searchResult)
}

// trackResult records a finished search in the communication log and updates the metrics.
//
// The log record is sent with the request header, the request, the response header and a copy of
// the result whose Header and Request fields are cleared; searchResult itself is left unmodified.
func (st *SearchTask) trackResult(searchRequest *wpb.TSearchRequest, searchResult *wpb.TSearchResult) {
	// Work on a clone so that stripping fields does not affect the caller's result.
	payload := proto.Clone(searchResult).(*wpb.TSearchResult)
	payload.Header, payload.Request = nil, nil

	requestHeader := searchRequest.GetHeader()
	responseHeader := searchResult.GetHeader()
	st.communicationLogger.SendLog(
		wpb.ECommunicationLogRecordType_LRT_WORKER_SEARCH,
		requestHeader,
		searchRequest,
		responseHeader,
		payload,
	)

	updateMetrics(st.metrics, searchResult)
}

// carrierGroup classifies a ride by the state of its Carrier field; the string value is used as
// the "carrier_group" tag of the rides counter.
type carrierGroup string

const (
	// matchedCarrierGroup: the ride has a carrier with a non-zero ID.
	matchedCarrierGroup carrierGroup = "matched"
	// missingCarrierGroup: the ride has no carrier (Carrier is nil).
	missingCarrierGroup carrierGroup = "missing"
	// suppliedCarrierGroup: the ride has a carrier whose ID is zero.
	suppliedCarrierGroup carrierGroup = "supplied"
)

// updateMetrics counts the rides of searchResult per carrier group and adds each non-empty total
// to the "result" counter tagged with carrier_group=<group>.
//
// A ride with a nil Carrier is counted as "missing", one whose carrier ID is zero as "supplied",
// and every other ride as "matched". Groups with no rides are not touched.
func updateMetrics(metrics *travelMetrics.AppMetrics, searchResult *wpb.TSearchResult) {
	perGroup := make(map[carrierGroup]int64)
	for _, ride := range searchResult.Rides {
		var group carrierGroup
		switch carrier := ride.Carrier; {
		case carrier == nil:
			group = missingCarrierGroup
		case carrier.ID == 0:
			group = suppliedCarrierGroup
		default:
			group = matchedCarrierGroup
		}
		perGroup[group]++
	}

	for group, total := range perGroup {
		tags := map[string]string{"carrier_group": string(group)}
		counter := metrics.GetOrCreateCounter("result", tags, "rides_count")
		counter.Add(total)
	}
}
