package legdb

import (
	"context"
	"strings"
	"time"

	pgx "github.com/jackc/pgx/v4"

	"a.yandex-team.ru/library/go/core/xerrors"
	dir "a.yandex-team.ru/travel/avia/shared_flights/lib/go/direction"
	"a.yandex-team.ru/travel/avia/shared_flights/lib/go/dtutil"
	appMetrics "a.yandex-team.ru/travel/avia/shared_flights/status_importer/internal/metrics"
	"a.yandex-team.ru/travel/avia/shared_flights/status_importer/internal/objects/model"
	"a.yandex-team.ru/travel/library/go/errutil"
	"a.yandex-team.ru/travel/library/go/metrics"
)

// Scannable is the minimal row-iterator surface needed to read query results:
// advance with Next, then copy the current row's columns into dest with Scan.
// pgx.Rows satisfies it; the narrow interface lets row matching be exercised
// without a database connection.
type Scannable interface {
	// Scan copies the columns of the current row into dest, positionally.
	Scan(dest ...interface{}) error
	// Next advances to the next row and reports whether one is available.
	Next() bool
}

// FlightLeg looks up the leg sequence number of a marketing flight
// (carrierID + flightNumber) that touches stationID on flightDay.
//
// flightDay must be formatted as "2006-01-02". Its meaning depends on
// direction: for dir.ARRIVAL the pattern's operating period is shifted by
// fp.arrival_day_shift before being compared with flightDay; otherwise the
// stored period is used as is. stationID is matched against the
// "<direction>_station" column of the flight base tables.
//
// Candidate rows from the flight_pattern, sirena_flight_pattern and
// apm_flight_pattern tables are handed to findMatch, whose leg and departure
// date are returned. The query runs on a read-only connection bounded by
// l.legTimeout, and its duration is recorded in the
// appMetrics.FlightTimings histogram. Any returned error is wrapped with the
// call arguments.
func (l *FlightLeg) FlightLeg(carrierID int64, flightNumber string, flightDay dtutil.StringDate, direction dir.Direction, stationID model.StationID) (leg int16, departureDate dtutil.StringDate, err error) {
	defer errutil.Wrap(&err, "db: flight-leg(%d, %s, %v, %s, %v)", carrierID, flightNumber, flightDay, direction, stationID)
	defer func(startTime time.Time) {
		metrics.GlobalAppMetrics().
			GetOrCreateHistogram(
				appMetrics.DB,
				nil,
				appMetrics.FlightTimings,
				appMetrics.FlightLegReadTimingsBuckets,
			).RecordDuration(time.Since(startTime))
	}(time.Now())

	flightDayParsed, err := time.Parse("2006-01-02", string(flightDay))
	if err != nil {
		return 0, departureDate, xerrors.Errorf("date-parsing error: %w", err)
	}

	sql := prepareSQL(direction, operatingFromForDirection(direction), operatingUntilForDirection(direction))
	pool := l.cluster.RO()
	ctx, cancel := context.WithTimeout(context.Background(), l.legTimeout)
	defer cancel()
	conn, err := pool.Acquire(ctx)
	if err != nil {
		return 0, departureDate, xerrors.Errorf("db-flightLeg: %w", err)
	}
	defer conn.Release()

	var rows pgx.Rows
	if rows, err = conn.Query(ctx, sql, carrierID, flightNumber, flightDayParsed, stationID); err != nil {
		// Kept defensively: an empty result is normally reported as zero
		// rows rather than as pgx.ErrNoRows.
		if err == pgx.ErrNoRows {
			return 0, departureDate, nil
		}
		return 0, departureDate, xerrors.Errorf("db-flightLeg: %w", err)
	}
	defer rows.Close()

	matchedLeg, matchedDate, err := findMatch(flightDay, rows, direction)
	if err != nil {
		return matchedLeg, matchedDate, err
	}
	// Errors raised while streaming results (timeouts, connection loss,
	// server-side failures) are not returned by Query; Next just reports
	// false and the error surfaces only through rows.Err(). Without this
	// check such failures would be indistinguishable from "no match".
	if err = rows.Err(); err != nil {
		return 0, departureDate, xerrors.Errorf("db-flightLeg: %w", err)
	}
	return matchedLeg, matchedDate, nil
}

// findMatch scans candidate rows of (leg_seq_number, operating_on_days,
// arrival_day_shift), in that column order. It returns the leg number and
// dashed departure date of the first row whose operating days include the
// flight's departure weekday.
//
// flightDay is interpreted according to direction. For dir.ARRIVAL it is the
// arrival date, and the departure day is obtained by subtracting the row's
// arrival_day_shift. Otherwise it is already the departure date.
//
// If no row matches, findMatch returns leg 0, a zero departure date and a nil
// error. On failure the leg and date are zero as well. findMatch does not
// inspect the iterator's terminal error (Scannable has no Err method), so
// callers holding a pgx.Rows must check rows.Err() themselves.
func findMatch(flightDay dtutil.StringDate, rows Scannable, direction dir.Direction,
) (leg int16, departureDate dtutil.StringDate, err error) {
	// leg and departureDate (the named results) are never assigned below, so
	// returning them yields zero values.

	// flightDateIndex and flightDay might be either arrival or departure at this point
	flightDateIndex, ok := dtutil.DateCache.IndexOfStringDate(flightDay)
	if !ok {
		return 0, departureDate, xerrors.Errorf("cannot load date from cache: %v", flightDay)
	}

	for rows.Next() {
		var (
			rowLeg          int16
			operatingDays   dtutil.OperatingDays
			arrivalDayShift int
		)
		if scanErr := rows.Scan(&rowLeg, &operatingDays, &arrivalDayShift); scanErr != nil {
			return 0, departureDate, xerrors.Errorf("error scanning for leg number in database: %w", scanErr)
		}
		weekday := flightDepartureWeekdayForDirection(direction, flightDateIndex, arrivalDayShift)
		if operatingDays.OperatesOn(weekday) {
			return rowLeg, flightDepartureDateForDirection(direction, flightDateIndex, arrivalDayShift), nil
		}
	}
	// No pattern operates on the requested day: report "not found" rather
	// than leaking the last scanned (non-matching) row.
	return 0, departureDate, nil
}

// flightDepartureDateForDirection returns the dashed departure date for a
// flight whose reference date is at position index in dtutil.DateCache.
// For dir.ARRIVAL the reference date is the arrival date, so shift (the
// arrival day shift) is subtracted first. For any other direction, index
// already points at the departure date.
func flightDepartureDateForDirection(direction dir.Direction, index int, shift int) dtutil.StringDate {
	departureIndex := index
	if direction == dir.ARRIVAL {
		departureIndex -= shift
	}
	return dtutil.DateCache.Date(departureIndex).StringDateDashed()
}

// flightDepartureWeekdayForDirection returns the weekday of the flight's
// departure day, given the dtutil.DateCache index of its reference date.
// For dir.ARRIVAL the reference date is the arrival date, so arrivalDayShift
// is subtracted before the lookup. Otherwise the index is used unchanged.
func flightDepartureWeekdayForDirection(direction dir.Direction, flightDateIndex int, arrivalDayShift int) time.Weekday {
	departureIndex := flightDateIndex
	if direction == dir.ARRIVAL {
		departureIndex -= arrivalDayShift
	}
	return dtutil.DateCache.WeekDay(departureIndex)
}

// operatingUntilForDirection returns the SQL expression, over the flight
// pattern alias "fp", for the end of a pattern's operating period. For
// dir.ARRIVAL the bound is moved by fp.arrival_day_shift so it can be
// compared with an arrival date. For every other direction the stored
// fp.operating_until column is used unchanged.
func operatingUntilForDirection(direction dir.Direction) string {
	switch direction {
	case dir.ARRIVAL:
		return "fp.operating_until + fp.arrival_day_shift"
	default:
		return "fp.operating_until"
	}
}

// operatingFromForDirection returns the SQL expression, over the flight
// pattern alias "fp", for the start of a pattern's operating period. For
// dir.ARRIVAL the bound is moved by fp.arrival_day_shift so it can be
// compared with an arrival date. For every other direction the stored
// fp.operating_from column is used unchanged.
func operatingFromForDirection(direction dir.Direction) string {
	switch direction {
	case dir.ARRIVAL:
		return "fp.operating_from + fp.arrival_day_shift"
	default:
		return "fp.operating_from"
	}
}

// prepareSQL renders the leg-lookup query for the given direction.
//
// The query UNIONs matching rows from three pattern/base table pairs:
// flight_pattern/flight_base, sirena_flight_pattern/sirena_flight_base and
// apm_flight_pattern/apm_flight_base. Each part selects leg_seq_number,
// operating_on_days and arrival_day_shift, in that order.
//
// Bind parameters: $1 marketing carrier, $2 marketing flight number, $3 the
// flight date (compared against both operating bounds), $4 station id.
//
// The template placeholders are spliced in verbatim as raw SQL:
//   - {{direction}} becomes direction.String() and selects the
//     "<direction>_station" column;
//   - {{operating_from}} and {{operating_until}} become the supplied
//     expressions.
//
// They must therefore come from trusted code, never from user input.
func prepareSQL(direction dir.Direction, operatingFromForDirection string, operatingUntilForDirection string) string {
	const legQueryTemplate = `SELECT
	fb.leg_seq_number as leg_seq_number,
	fp.operating_on_days as operating_on_days,
	fp.arrival_day_shift as arrival_day_shift
FROM flight_pattern fp
	INNER JOIN flight_base fb on fp.flight_base_id = fb.id
WHERE fp.marketing_carrier = $1
	AND fp.marketing_flight_number = $2
	AND {{operating_from}} <= $3
	AND {{operating_until}} >= $3
	AND fb.{{direction}}_station = $4

UNION SELECT
	fb.leg_seq_number as leg_seq_number,
	fp.operating_on_days as operating_on_days,
	fp.arrival_day_shift as arrival_day_shift

FROM sirena_flight_pattern fp
	INNER JOIN sirena_flight_base fb on fp.flight_base_id = fb.id
WHERE fp.marketing_carrier = $1
	AND fp.marketing_flight_number = $2
	AND {{operating_from}} <= $3
	AND {{operating_until}} >= $3
	AND fb.{{direction}}_station = $4

UNION SELECT
	fb.leg_seq_number as leg_seq_number,
	fp.operating_on_days as operating_on_days,
	fp.arrival_day_shift as arrival_day_shift

FROM apm_flight_pattern fp
	INNER JOIN apm_flight_base fb on fp.flight_base_id = fb.id
WHERE fp.marketing_carrier = $1
	AND fp.marketing_flight_number = $2
	AND {{operating_from}} <= $3
	AND {{operating_until}} >= $3
	AND fb.{{direction}}_station = $4`

	return strings.NewReplacer(
		"{{direction}}", direction.String(),
		"{{operating_from}}", operatingFromForDirection,
		"{{operating_until}}", operatingUntilForDirection,
	).Replace(legQueryTemplate)
}
