package ru.yandex.reminders.util.task;

import com.mongodb.client.model.Filters;
import lombok.val;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.joda.time.LocalDateTime;
import org.springframework.beans.factory.annotation.Autowired;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.Option;
import ru.yandex.commune.bazinga.impl.JobStatus;
import ru.yandex.commune.mongo3.MongoUtils;
import ru.yandex.misc.ExceptionUtils;
import ru.yandex.misc.db.q.SqlOrder;
import ru.yandex.misc.lang.Check;
import ru.yandex.misc.time.MoscowTime;
import ru.yandex.reminders.logic.event.Event;
import ru.yandex.reminders.logic.event.EventData;
import ru.yandex.reminders.logic.event.EventMdao;
import ru.yandex.reminders.logic.event.SpecialClientIds;
import ru.yandex.reminders.logic.flight.FlightEventMeta;
import ru.yandex.reminders.logic.flight.shift.FlightShift;
import ru.yandex.reminders.logic.flight.shift.FlightShiftMdao;
import ru.yandex.reminders.logic.reminder.Reminder;
import ru.yandex.reminders.mongodb.BazingaOnetimeJobMdao;
import ru.yandex.reminders.util.DateTimeUtils;

import java.util.concurrent.atomic.AtomicInteger;

/**
 * Manual one-off task that rewrites stored flight data using the {@link DateTimeUtils}
 * time zone migration helpers.
 * <p>
 * It runs in two stages:
 * <ol>
 *   <li>flight events ({@link SpecialClientIds#FLIGHT}) with a reminder sent at or after {@link #FROM};</li>
 *   <li>flight shifts planned at or after {@link #FROM}.</li>
 * </ol>
 * Documents whose {@code migrate} field already equals {@link #MIGRATION_FLAG} are not selected.
 * For every selected document the timestamps are recomputed. Matching READY bazinga jobs are
 * rescheduled to the new times. The stored document is then replaced, guarded by the same
 * "not migrated" condition.
 */
public class TimeZoneMigrationTask extends ManualTaskSupport {
    /** Value compared against the {@code migrate} fields; documents already holding it are skipped. */
    public static final boolean MIGRATION_FLAG = true;
    /** Lower time bound used in the selection filters (reminder send time / shift planned time). */
    private static final Instant FROM = MoscowTime.instant(2016, 7, 24, 12, 0);
    /** Number of processed documents between two progress status updates. */
    private static final int PROGRESS_STEP = 1000;

    @Autowired
    private EventMdao eventMdao;
    @Autowired
    private FlightShiftMdao flightShiftMdao;
    @Autowired
    private BazingaOnetimeJobMdao bazingaOnetimeJobMdao;

    /**
     * Runs both migration stages in order and records the outcome as the task status.
     * A failure at any point is reported through a failed status (with the pretty-printed
     * exception) and is not rethrown.
     */
    @Override
    protected void doExecute() {
        try {
            migrateFlightEvents();
            setStatus(TaskStatus.running(getStarted(), "flight events done"));

            migrateFlightShifts();
            setStatus(TaskStatus.success(getStarted(), "done"));
        } catch (Exception failure) {
            setStatus(TaskStatus.fail(getStarted(), ExceptionUtils.prettyPrint(failure)));
        }
    }

    /**
     * Migrates every not-yet-migrated flight event that has a reminder sent at or after
     * {@link #FROM}: its flight meta and all of its reminders' send times.
     */
    private void migrateFlightEvents() {
        val pending = Filters.ne("flight.migrate", MIGRATION_FLAG);
        val selection = Filters.and(
                Filters.eq("cid", SpecialClientIds.FLIGHT),
                Filters.gte("reminders.sendTs", MongoUtils.toMongoValue(FROM)),
                pending);

        long total = eventMdao.getCollectionX().count(selection);
        AtomicInteger done = new AtomicInteger(0);

        eventMdao.getCollectionX().findStream(selection, SqlOrder.unordered(), stored -> {
            FlightEventMeta oldMeta = stored.getFlightMeta().get();
            // Sanity check: the selection filter should already have excluded migrated events.
            Check.isFalse(oldMeta.isMigrated());

            FlightEventMeta newMeta = migrateFlightMeta(oldMeta);
            ListF<Reminder> newReminders = stored.getReminders().map(reminder -> reminder.withSendTs(
                    DateTimeUtils.tzMigrateInstant(reminder.getSendTs(), reminder.getSendTz())));

            Event migrated = new Event(
                    stored.getId(),
                    new EventData(stored.getSource(), stored.getName(), stored.getDescription(),
                            stored.getData(), newReminders, Option.of(newMeta)),
                    stored.getSenderName(), stored.getUpdatedTs(), stored.getUpdatedReqId());

            // Jobs are moved before the document is replaced; the replace itself only applies
            // while the stored event is still not migrated.
            newReminders.forEach(reminder -> bazingaOnetimeJobMdao.rescheduleByActiveUniqueIdAndStatus(
                    BazingaOnetimeJobMdao.reminderIdToActiveUniqueId(reminder.getId()),
                    JobStatus.READY, reminder.getSendTs()));

            eventMdao.getCollectionX().replaceOneWithoutId(
                    Filters.and(EventMdao.toQuery(stored.getId()), pending), migrated);

            reportProgress(done, total, "flight events");
        });
    }

    /**
     * Builds a copy of {@code meta} with recomputed departure, planned-departure and arrival times.
     * All other fields are carried over unchanged.
     * <p>
     * The arrival time is handled in one of two ways:
     * <ul>
     *   <li>segmented flights: migrated in the arrival time zone;</li>
     *   <li>other flights: shifted by the same amount as the departure time.</li>
     * </ul>
     */
    private static FlightEventMeta migrateFlightMeta(FlightEventMeta meta) {
        val departureTz = meta.getDepartureCityTz();
        val departureTs = meta.getDepartureTs();
        val arrivalTz = meta.getArrivalTz();

        val migratedArrivalTs = meta.isSegmented()
                ? meta.getArrivalTs().map(ts -> DateTimeUtils.tzMigrateInstant(ts, arrivalTz.get()))
                : meta.getArrivalTs().map(ts -> ts.plus(
                        new Duration(departureTs, DateTimeUtils.tzMigrateInstant(departureTs, departureTz))));
        val migratedArrival = migratedArrivalTs.map(ts -> new LocalDateTime(ts, arrivalTz.get()));

        return new FlightEventMeta(
                meta.getMid(), meta.getFlightNumber(), meta.getAirline(),
                meta.getPlannedDepartureTs().map(
                        ts -> DateTimeUtils.tzMigratedInstantLocalDateTime(ts, departureTz)),
                meta.getDepartureCity(), meta.getDepartureAirport(),
                DateTimeUtils.tzMigratedInstantLocalDateTime(departureTs, departureTz), departureTz,
                meta.getArrivalCity(), meta.getArrivalAirport(), migratedArrival, arrivalTz,
                meta.getSource(), meta.getCheckInLink(), meta.getAeroexpressLink(),
                meta.getLastSegmentFlightNumber(),
                meta.getLastSegmentDepartureDateTime(), meta.getLastSegmentSource(),
                meta.getDirection(), meta.getLang(), meta.getYaDomain());
    }

    /**
     * Migrates every not-yet-migrated flight shift planned at or after {@link #FROM}.
     * It also reschedules the shift's READY bazinga job when the shift has a send time.
     */
    private void migrateFlightShifts() {
        val pending = Filters.ne("migrate", MIGRATION_FLAG);
        val selection = Filters.and(Filters.gte("plannedTs", MongoUtils.toMongoValue(FROM)), pending);

        long total = flightShiftMdao.getCollectionX().count(selection);
        AtomicInteger done = new AtomicInteger(0);

        flightShiftMdao.getCollectionX().findStream(selection, SqlOrder.unordered(), stored -> {
            // Sanity check: the selection filter should already have excluded migrated shifts.
            Check.isFalse(stored.isMigrated());

            FlightShift migrated = migrateShift(stored);

            migrated.getSendTs().forEach(sendTs -> bazingaOnetimeJobMdao.rescheduleByActiveUniqueIdAndStatus(
                    BazingaOnetimeJobMdao.flightShiftIdToActiveUniqueId(stored.getId()), JobStatus.READY, sendTs));

            flightShiftMdao.getCollectionX().replaceOneWithoutId(
                    Filters.and(Filters.eq(migrated.getId()), pending), migrated);

            reportProgress(done, total, "flight shifts");
        });
    }

    /**
     * Builds a copy of {@code shift} whose planned, actual and (optional) send timestamps are
     * migrated in the shift's own time zone.
     */
    private static FlightShift migrateShift(FlightShift shift) {
        val tz = shift.getTz();
        return new FlightShift(
                shift.getId(), shift.getFlightNum(), shift.getGeoId(), tz,
                DateTimeUtils.tzMigrateInstant(shift.getPlannedTs(), tz),
                DateTimeUtils.tzMigrateInstant(shift.getActualTs(), tz),
                shift.getSendTs().map(ts -> DateTimeUtils.tzMigrateInstant(ts, tz)),
                shift.isLatest());
    }

    /**
     * Counts one more processed document.
     * Every {@link #PROGRESS_STEP} documents it publishes a running status of the form
     * "N of TOTAL SUBJECT done".
     */
    private void reportProgress(AtomicInteger counter, long total, String subject) {
        if (counter.incrementAndGet() % PROGRESS_STEP != 0) {
            return;
        }
        setStatus(TaskStatus.running(getStarted(), counter.get() + " of " + total + " " + subject + " done"));
    }
}
