package internal

import (
	"context"
	"fmt"
	"math/rand"
	"regexp"
	"strconv"
	"strings"
	"time"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/travel/avia/shared_flights/lib/go/dtutil"
	"a.yandex-team.ru/travel/library/go/containers"
	"a.yandex-team.ru/yt/go/ypath"
	"a.yandex-team.ru/yt/go/yt"
	"a.yandex-team.ru/yt/go/yt/ythttp"
)

// dateFormat is the Go reference-time layout for dashed dates (YYYY-MM-DD).
// It is used to render today's date before looking it up in the date cache.
const dateFormat = "2006-01-02"

// yqlQuery is the YQL script that aggregates the avia-flight-schedule-log
// tables between $start_date and $end_date. Rows are grouped by marketing
// carrier id, marketing flight number and segment number; for every group the
// distinct values of the remaining columns are collected into lists and the
// result is inserted into $intermediate_table, capped at $limit rows.
//
// The four declared parameters are supplied by getParamsMap. Note that the
// aggregated operating flight numbers keep the singular column name
// "operating_flight_number"; DataRow relies on that exact name.
var yqlQuery = `
	PRAGMA yt.InferSchema = '1';
	Use hahn;

	DECLARE $start_date AS String;
	DECLARE $end_date AS String;
	DECLARE $intermediate_table AS String;
	DECLARE $limit AS UInt64;

	INSERT INTO $intermediate_table
		SELECT
			marketing_carrier_id,
			marketing_flight_number,
			segment_number,
			AGGREGATE_LIST_DISTINCT(flight_date) as flight_dates,
			AGGREGATE_LIST_DISTINCT(is_codeshare) as is_codeshare_flag,
			AGGREGATE_LIST_DISTINCT(segment_departure_date) as segment_departure_dates,
			AGGREGATE_LIST_DISTINCT(airport_from_id) as airport_from_ids,
			AGGREGATE_LIST_DISTINCT(airport_to_id) as airport_to_ids,
			AGGREGATE_LIST_DISTINCT(marketing_carrier_code) as marketing_carrier_codes,
			AGGREGATE_LIST_DISTINCT(operating_carrier_code) as operating_carrier_codes,
			AGGREGATE_LIST_DISTINCT(operating_flight_number) as operating_flight_number,
			AGGREGATE_LIST_DISTINCT(operating_carrier_id) as operating_carrier_ids,
			AGGREGATE_LIST_DISTINCT(operating_segment_number) as operating_segment_numbers,
			AGGREGATE_LIST_DISTINCT(segment_count) as segment_counts,
		FROM RANGE('home/avia/logs/avia-flight-schedule-log', $start_date, $end_date)
		GROUP BY marketing_carrier_id, marketing_flight_number, segment_number
		LIMIT $limit;
	`

// Flight is a single regular flight segment. It is the row type written to the
// output YT table by saveToYT and is built from an aggregated DataRow by
// newFlight. The struct is comparable and is used as a map key.
type Flight struct {
	AirportFrom           int64  `yson:"airport_from"`
	AirportTo             int64  `yson:"airport_to"`
	MarketingCarrierCode  string `yson:"marketing_carrier_code"`
	MarketingCarrierID    int32  `yson:"marketing_carrier_id"`
	MarketingFlightNumber string `yson:"marketing_flight_number"`
	OperatingCarrierCode  string `yson:"operating_carrier_code"`
	OperatingCarrierID    int32  `yson:"operating_carrier_id"`
	OperatingFlightNumber string `yson:"operating_flight_number"`
	// SegmentNumber is serialized under the name "marketing_segment_number".
	SegmentNumber         int32  `yson:"marketing_segment_number"`
}

// FlightTitle identifies a marketing flight segment by carrier id, carrier
// code, flight number and segment number. It is derived from a Flight by
// flightTitle and serves as the map key when counting weeks per flight.
type FlightTitle struct {
	CarrierID    int32
	CarrierCode  string
	FlightNumber string
	Segment      int32
}

// DataRow mirrors one row of the intermediate table produced by yqlQuery: the
// three grouping keys followed by the lists of distinct values aggregated for
// that group.
type DataRow struct {
	// Grouping keys of the YQL query.
	MarketingCarrierID    int32  `yson:"marketing_carrier_id"`
	MarketingFlightNumber string `yson:"marketing_flight_number"`
	SegmentNumber         int32  `yson:"segment_number"`

	// Distinct values collected per group by AGGREGATE_LIST_DISTINCT.
	AirportFromIds          []int64  `yson:"airport_from_ids"`
	AirportToIds            []int64  `yson:"airport_to_ids"`
	FlightDates             []string `yson:"flight_dates"`
	IsCodeshareFlag         []bool   `yson:"is_codeshare_flag"`
	MarketingCarrierCodes   []string `yson:"marketing_carrier_codes"`
	OperatingCarrierCodes   []string `yson:"operating_carrier_codes"`
	OperatingCarrierIds     []int32  `yson:"operating_carrier_ids"`
	// The column name is singular because the query aliases it that way.
	OperatingFlightNumbers  []string `yson:"operating_flight_number"`
	OperatingSegmentNumbers []int32  `yson:"operating_segment_numbers"`
	SegmentCounts           []int32  `yson:"segment_counts"`
	SegmentDepartureDates   []string `yson:"segment_departure_dates"`
}

// RegularFlightsProcessor holds the configuration and the intermediate state
// of one regular-flights calculation. Fill in the exported fields and call
// GenerateRegularFlights.
type RegularFlightsProcessor struct {
	// internally calculated fields
	// minDayNum and maxDayNum are the date-cache indexes of StartDate and EndDate.
	minDayNum         int
	maxDayNum         int
	// minWeeksRequired is the number of distinct weeks a flight must appear in
	// to be reported as regular; it is at least 1.
	minWeeksRequired  int
	// intermediateTable is the path of the temporary table the YQL query writes to.
	intermediateTable string

	// calculation parameters
	// Token is passed both to the YQL operation and to the YT client.
	Token                     string
	// StartDate and EndDate are either a date string or, when shorter than six
	// characters, an integer day offset relative to today. Both are replaced
	// with the resolved values by GenerateRegularFlights.
	StartDate                 string
	EndDate                   string
	YTProxy                   string
	// YTTablePath is the folder of the output table; the table itself is named
	// "<minWeeksRequired>-weeks".
	YTTablePath               string
	YTIntermediateTableFolder string
	// Limit caps the number of rows the query produces; values <= 0 select the
	// built-in default of one billion.
	Limit                     int
	Logger                    log.Logger
}

// GenerateRegularFlights runs the whole job: it resolves StartDate and EndDate,
// runs the YQL aggregation into a uniquely named intermediate table, reads that
// table back, and writes the flights that qualify as regular to the output YT
// table (nothing is written when there are none).
//
// StartDate and EndDate are overwritten with their resolved values. Removal of
// the intermediate table is attempted on return, whether or not the job
// succeeded. Every failure is logged and returned.
func (p *RegularFlightsProcessor) GenerateRegularFlights() error {
	p.Logger.Info("Start task")

	var ok bool
	p.StartDate, p.minDayNum, ok = p.parseDate(p.StartDate)
	if !ok {
		p.Logger.Error("Invalid start date", log.String("startDate", p.StartDate))
		return xerrors.Errorf("Invalid start date: %s", p.StartDate)
	}
	p.EndDate, p.maxDayNum, ok = p.parseDate(p.EndDate)
	if !ok {
		p.Logger.Error("Invalid end date", log.String("endDate", p.EndDate))
		return xerrors.Errorf("Invalid end date: %s", p.EndDate)
	}

	// Compare the day indexes directly. Integer division truncates toward zero,
	// so testing the week quotient for negativity would let through an end date
	// that is up to a week earlier than the start date.
	if p.maxDayNum < p.minDayNum {
		return xerrors.Errorf("end date is earlier than start date")
	}
	// The range is inclusive on both ends; ranges shorter than a full week
	// still require the flight to be seen in one week.
	p.minWeeksRequired = (p.maxDayNum + 1 - p.minDayNum) / 7
	if p.minWeeksRequired == 0 {
		p.minWeeksRequired = 1
	}

	// Timestamp plus a random suffix keeps concurrent runs from sharing a table;
	// the spaces of time.StampMicro are replaced to keep the path free of them.
	intermediateTable := fmt.Sprintf(
		"%s/regular-flights-tmp-%s-%d", p.YTIntermediateTableFolder, time.Now().Format(time.StampMicro), rand.Intn(1000000000))
	p.intermediateTable = strings.ReplaceAll(intermediateTable, " ", "_")

	defer p.deleteIntermediateTable()

	yql, err := NewYqlOperation(yqlQuery, p.getParamsMap(), p.Token, p.Logger)
	if err != nil {
		p.Logger.Error("Error while constructing YQL", log.Error(err))
		return err
	}
	if err = yql.Start(); err != nil {
		p.Logger.Error("Error while staring YQL", log.Error(err))
		return err
	}

	p.Logger.Info("Started YQL request")

	if err = yql.Wait(); err != nil {
		p.Logger.Error("Error while waiting for YQL results", log.Error(err))
		return err
	}

	p.Logger.Info("Got results")

	regularFlights, err := p.readIntermediateTable()
	if err != nil {
		p.Logger.Error("Error while reading YQL results", log.Error(err))
		return err
	}
	if len(regularFlights) == 0 {
		return nil
	}
	if err := p.saveToYT(regularFlights); err != nil {
		p.Logger.Error("Error while saving YQL results", log.Error(err))
		return err
	}
	return nil
}

// parseDate resolves a date parameter into its string form and its index in
// the shared date cache.
//
// An input of six or more characters is looked up in the cache as is and
// returned unchanged. A shorter input is parsed as an integer number of days
// relative to today (negative values point into the past); the returned string
// is then the dashed form of the shifted date.
//
// The boolean result is false when the input cannot be resolved; in that case
// the original string is returned together with a zero index.
func (p *RegularFlightsProcessor) parseDate(strDate string) (string, int, bool) {
	// Long enough to be a literal date: no arithmetic needed.
	if len(strDate) >= 6 {
		index, found := dtutil.DateCache.IndexOfStringDate(dtutil.StringDate(strDate))
		return strDate, index, found
	}

	// Otherwise the value is a day offset from today.
	shift, convErr := strconv.Atoi(strDate)
	if convErr != nil {
		return strDate, 0, false
	}
	today := dtutil.StringDate(time.Now().Format(dateFormat))
	todayIndex, found := dtutil.DateCache.IndexOfStringDate(today)
	if !found {
		return strDate, 0, false
	}
	shifted := dtutil.DateCache.Date(todayIndex).AddDays(shift)
	// A zero date is treated as "could not compute the shifted date".
	if shifted == 0 {
		return strDate, 0, false
	}
	return string(shifted.StringDateDashed()), todayIndex + shift, true
}

// getParamsMap builds the parameter set for yqlQuery, keyed by the declared
// YQL parameter names, and logs it. A non-positive Limit selects the default
// row cap.
func (p *RegularFlightsProcessor) getParamsMap() map[string]interface{} {
	rowLimit := p.Limit
	if rowLimit <= 0 {
		// 1 billion is big enough for the number of regular flights to never reach it
		// (and even if it does, we'd go out of YT quota to store it anyway)
		rowLimit = 1000 * 1000 * 1000
	}

	params := make(map[string]interface{}, 4)
	params["$start_date"] = p.StartDate
	params["$end_date"] = p.EndDate
	params["$limit"] = rowLimit
	params["$intermediate_table"] = p.intermediateTable

	p.Logger.Info("YQL parameters", log.Reflect("params", params))
	return params
}

// replaceParams substitutes every occurrence of each key of params in yqlQuery
// with the corresponding value and returns the resulting text.
//
// The substitution is done in a single left-to-right pass. When several keys
// match at the same position the longest one wins, so "$start_date" is never
// mistaken for "$start". Inserted values are not scanned again, so a value
// that happens to contain another key is left as is. Together these rules make
// the result independent of Go's randomized map iteration order.
//
// Empty keys are ignored. A nil or empty params map returns yqlQuery unchanged.
func replaceParams(yqlQuery string, params map[string]string) string {
	if len(params) == 0 {
		return yqlQuery
	}

	var out strings.Builder
	out.Grow(len(yqlQuery))
	for pos := 0; pos < len(yqlQuery); {
		rest := yqlQuery[pos:]

		// Pick the longest key that starts at the current position. Keys are
		// unique, so the winner does not depend on the iteration order.
		match := ""
		for key := range params {
			if len(key) > len(match) && strings.HasPrefix(rest, key) {
				match = key
			}
		}

		if match == "" {
			// Copying byte by byte keeps multi-byte UTF-8 sequences intact.
			out.WriteByte(yqlQuery[pos])
			pos++
			continue
		}
		out.WriteString(params[match])
		pos += len(match)
	}
	return out.String()
}

// readIntermediateTable reads the rows the YQL query stored in the
// intermediate table and returns the set of flights that qualify as regular.
//
// For every flight (keyed by FlightTitle) it collects the distinct week numbers,
// counted from minDayNum, in which the flight appears. Segment departure dates
// are preferred; flight dates are used only when a row has no non-empty segment
// departure date. A flight is regular when it appears in at least
// minWeeksRequired distinct weeks.
//
// An error is returned when the YT client cannot be created, the table cannot
// be opened, a row cannot be decoded, or reading stops prematurely.
func (p *RegularFlightsProcessor) readIntermediateTable() (map[Flight]bool, error) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	ytConfig := yt.Config{
		Proxy: p.YTProxy,
		Token: p.Token,
	}
	ytClient, err := ythttp.NewClient(&ytConfig)
	if err != nil {
		return nil, xerrors.Errorf("unable to create yt client: %w", err)
	}
	ytReader, err := ytClient.ReadTable(ctx, ypath.Path(p.intermediateTable), nil)
	if err != nil {
		return nil, err
	}
	defer func() {
		// log.Error only builds a field; it has to be handed to the logger to
		// actually be reported.
		if closeErr := ytReader.Close(); closeErr != nil {
			p.Logger.Error("Error while closing the intermediate table reader", log.Error(closeErr))
		}
	}()

	flightsData := make(map[FlightTitle]Flight)
	weeksCounter := make(map[FlightTitle]containers.Set[int])
	rowCount := 0
	for ytReader.Next() {
		var row DataRow
		if err := ytReader.Scan(&row); err != nil {
			p.Logger.Errorf("Error while scanning the intermediate table: %+v", err)
			return nil, err
		}
		if rowCount < 5 {
			// Data structure sample, just to make it easier to specify data types in the code
			p.Logger.Info("Results", log.Reflect("row", row))
		}
		rowCount++

		flight := newFlight(row)
		title := flightTitle(flight)
		flightsData[title] = flight
		weeks, seen := weeksCounter[title]
		if !seen {
			weeks = containers.SetOf[int]()
		}

		// Prefer "segment_departure_dates", but if that field is not filled in, use "flight_dates" as a last hope
		foundSegmentDate := false
		for _, date := range row.SegmentDepartureDates {
			if date == "" {
				continue
			}
			weeks.Add(getWeekNumber(date, p.minDayNum))
			foundSegmentDate = true
		}
		if !foundSegmentDate {
			for _, date := range row.FlightDates {
				// Skip blanks here as well, the same way the primary loop does.
				if date == "" {
					continue
				}
				weeks.Add(getWeekNumber(date, p.minDayNum))
			}
		}
		weeksCounter[title] = weeks
	}
	// Next also returns false when reading fails; without this check a broken
	// read would be indistinguishable from the end of the table and a truncated
	// result would be reported as complete.
	if err := ytReader.Err(); err != nil {
		return nil, xerrors.Errorf("reading the intermediate table: %w", err)
	}
	p.Logger.Info("Row count", log.Int("rows count", rowCount))

	regularFlights := map[Flight]bool{}
	skippedFlightsCount := 0
	for title, weeks := range weeksCounter {
		if len(weeks) >= p.minWeeksRequired {
			regularFlights[flightsData[title]] = true
		} else {
			skippedFlightsCount++
		}
	}
	p.Logger.Info("regular flights", log.Int("count", len(regularFlights)))
	p.Logger.Info("non-regular flights", log.Int("count", skippedFlightsCount))
	return regularFlights, nil
}

// deleteIntermediateTable removes the intermediate YT table if it exists. The
// existence check and the removal happen inside one YT transaction.
//
// This is best-effort cleanup that runs deferred after the main work, so
// failures are only logged. They are logged at Error level rather than Fatal:
// a Fatal log may terminate the process, depending on the logger backend, and
// a failed cleanup must not discard the outcome of a job that has finished.
func (p *RegularFlightsProcessor) deleteIntermediateTable() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	ytConfig := yt.Config{
		Proxy: p.YTProxy,
		Token: p.Token,
	}
	ytClient, err := ythttp.NewClient(&ytConfig)
	if err != nil {
		p.Logger.Error("Unable to create yt client", log.Error(err))
		return
	}
	tx, err := ytClient.BeginTx(ctx, nil)
	if err != nil {
		p.Logger.Error("Unable to access yt", log.Error(err))
		return
	}
	// Release the transaction on every early return instead of leaving it
	// open until it expires on the server side.
	commitAttempted := false
	defer func() {
		if commitAttempted {
			return
		}
		if abortErr := tx.Abort(); abortErr != nil {
			p.Logger.Error("Unable to abort yt transaction", log.Error(abortErr))
		}
	}()

	path := ypath.Path(p.intermediateTable)
	tableExists, err := tx.NodeExists(ctx, path, nil)
	if err != nil {
		p.Logger.Error("Unable to test if yt table exists", log.Error(err))
		return
	}
	if tableExists {
		if err := tx.RemoveNode(ctx, path, nil); err != nil {
			p.Logger.Error("Unable to remove existing yt table", log.Error(err))
			return
		}
	}

	commitAttempted = true
	if err := tx.Commit(); err != nil {
		p.Logger.Error("Unable to commit yt transaction", log.Error(err))
	}
}

// saveToYT writes the given flights to the YT table
// "<YTTablePath>/<minWeeksRequired>-weeks" inside a single transaction. The
// map values are ignored; only the keys are written, in map iteration order.
//
// Every failure is logged and returned. Failures are logged at Error level
// rather than Fatal: a Fatal log may terminate the process, depending on the
// logger backend, which would prevent both the error return and the deferred
// cleanup in the caller. On failure the table writer is rolled back and the
// transaction is aborted so that no resources stay allocated.
func (p *RegularFlightsProcessor) saveToYT(regularFlights map[Flight]bool) error {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	ytConfig := yt.Config{
		Proxy: p.YTProxy,
		Token: p.Token,
	}
	ytClient, err := ythttp.NewClient(&ytConfig)
	if err != nil {
		p.Logger.Error("Unable to create yt client", log.Error(err))
		return err
	}
	path := ypath.Path(fmt.Sprintf("%s/%d-weeks", p.YTTablePath, p.minWeeksRequired))

	p.Logger.Info(
		"Dumping regular flights into yt table",
		log.Int("count", len(regularFlights)),
		log.String("yt-output-path", path.String()),
	)

	tx, err := ytClient.BeginTx(ctx, nil)
	if err != nil {
		p.Logger.Error("Unable to access yt", log.Error(err))
		return err
	}
	// Abort the transaction on every early return.
	commitAttempted := false
	defer func() {
		if commitAttempted {
			return
		}
		if abortErr := tx.Abort(); abortErr != nil {
			p.Logger.Error("Unable to abort yt transaction", log.Error(abortErr))
		}
	}()

	tw, err := tx.WriteTable(ctx, path, nil)
	if err != nil {
		p.Logger.Error("Unable to create yt table", log.Error(err))
		return err
	}

	for flight := range regularFlights {
		if err := tw.Write(flight); err != nil {
			p.Logger.Error("Unable to write data to yt table", log.Error(err))
			// Free the writer explicitly; it was never committed.
			if rollbackErr := tw.Rollback(); rollbackErr != nil {
				p.Logger.Error("Unable to roll back yt table writer", log.Error(rollbackErr))
			}
			return err
		}
	}

	if err := tw.Commit(); err != nil {
		p.Logger.Error("Unable to flush data to yt table", log.Error(err))
		return err
	}

	commitAttempted = true
	if err := tx.Commit(); err != nil {
		p.Logger.Error("Unable to commit yt transaction", log.Error(err))
		return err
	}

	p.Logger.Info("Done writing into yt table")
	return nil
}

// getWeekNumber returns the number of whole weeks between the day with index
// minDayNum and the given date: the difference of the two date-cache indexes
// divided by seven, truncated toward zero.
//
// NOTE(review): how IndexOfStringDateP reacts to a date it cannot resolve is
// not visible here — confirm before passing unvalidated input.
func getWeekNumber(date string, minDayNum int) int {
	const daysPerWeek = 7
	dayOffset := dtutil.DateCache.IndexOfStringDateP(dtutil.StringDate(date)) - minDayNum
	return dayOffset / daysPerWeek
}

// firstOrZero returns the first element of values, or the zero value of T when
// values is empty.
func firstOrZero[T any](values []T) T {
	var zero T
	if len(values) == 0 {
		return zero
	}
	return values[0]
}

// newFlight collapses an aggregated DataRow into a single Flight.
//
// The grouping keys are copied as is. For the list-valued columns the first
// element is taken, except for the carrier codes, which are chosen by
// latinString. An empty list yields the zero value of the field instead of a
// panic, so a row with a missing column does not abort the whole run.
func newFlight(row DataRow) Flight {
	return Flight{
		AirportFrom:           firstOrZero(row.AirportFromIds),
		AirportTo:             firstOrZero(row.AirportToIds),
		MarketingCarrierCode:  latinString(row.MarketingCarrierCodes),
		MarketingCarrierID:    row.MarketingCarrierID,
		MarketingFlightNumber: row.MarketingFlightNumber,
		OperatingCarrierCode:  latinString(row.OperatingCarrierCodes),
		OperatingCarrierID:    firstOrZero(row.OperatingCarrierIds),
		OperatingFlightNumber: firstOrZero(row.OperatingFlightNumbers),
		SegmentNumber:         row.SegmentNumber,
	}
}

// flightTitle extracts the marketing identity of a flight segment (carrier id,
// carrier code, flight number and segment number) for use as a map key.
func flightTitle(flight Flight) FlightTitle {
	var title FlightTitle
	title.CarrierID = flight.MarketingCarrierID
	title.CarrierCode = flight.MarketingCarrierCode
	title.FlightNumber = flight.MarketingFlightNumber
	title.Segment = flight.SegmentNumber
	return title
}

// latinStringPattern matches any string that contains at least one ASCII
// upper-case letter or digit. The pattern is not anchored.
//
// NOTE(review): if "consists only of Latin characters" was intended, the
// pattern needs anchors — confirm against real carrier codes.
var latinStringPattern = regexp.MustCompile("[A-Z0-9]")

// latinString picks one representative out of values.
//
// The lexicographically smallest value matching latinStringPattern is
// returned. If no value matches, the lexicographically smallest of the other
// values is returned. Empty strings are ignored, so the result does not depend
// on the order of values. An empty string is returned when values holds no
// non-empty string.
func latinString(values []string) string {
	bestLatin, bestOther := "", ""
	for _, value := range values {
		if value == "" {
			continue
		}
		// Empty values were skipped above, so "" safely means "nothing seen yet".
		if latinStringPattern.MatchString(value) {
			if bestLatin == "" || value < bestLatin {
				bestLatin = value
			}
		} else if bestOther == "" || value < bestOther {
			bestOther = value
		}
	}
	if bestLatin != "" {
		return bestLatin
	}
	return bestOther
}
