package ru.yandex.partner.core.entity.queue.repository;

import java.net.InetAddress;
import java.net.UnknownHostException;
import java.time.Duration;
import java.time.LocalDateTime;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import javax.annotation.Nonnull;

import org.jooq.Condition;
import org.jooq.DSLContext;
import org.jooq.Field;
import org.jooq.Record2;
import org.jooq.SortOrder;
import org.jooq.Table;
import org.jooq.UpdateSetFirstStep;
import org.jooq.UpdateSetMoreStep;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.dao.CannotAcquireLockException;
import org.springframework.dao.DeadlockLoserDataAccessException;
import org.springframework.stereotype.Repository;

import ru.yandex.partner.core.entity.queue.multistate.TaskMultistateGraph;
import ru.yandex.partner.core.entity.queue.service.TaskFacade;
import ru.yandex.partner.core.multistate.AbstractMultistate;
import ru.yandex.partner.core.multistate.Multistate;
import ru.yandex.partner.core.multistate.queue.TaskMultistate;
import ru.yandex.partner.core.multistate.queue.TaskStateFlag;
import ru.yandex.partner.core.queue.Task;
import ru.yandex.partner.core.queue.TaskData;
import ru.yandex.partner.core.queue.TaskFactory;
import ru.yandex.partner.dbschema.partner.tables.records.QueueRecord;
import ru.yandex.partner.defaultconfiguration.PartnerLocalDateTime;
import ru.yandex.partner.libs.utils.JooqUtils;

import static org.jooq.impl.DSL.min;
import static ru.yandex.partner.dbschema.partner.Tables.QUEUE;
import static ru.yandex.partner.libs.multistate.MultistatePredicates.empty;
import static ru.yandex.partner.libs.multistate.MultistatePredicates.has;

/**
 * Repository for queue tasks stored in the {@code QUEUE} table.
 *
 * <p>Converts between {@link TaskData} and {@link QueueRecord} and implements the operations used by
 * queue workers: inserting tasks, picking a free task under a row lock, grabbing ("locking") it for a
 * period of time and writing its state back.
 */
@Repository
public class TaskRepository {

    private static final Logger LOGGER = LoggerFactory.getLogger(TaskRepository.class);

    private final DSLContext dslContext;
    private final TaskMultistateGraph taskMultistateGraph;
    private final TaskFacade taskFacade;

    /**
     * Lazily computed multistate values regarded as "unfinished" (see {@link #getUnfinishedMultistateValues()}).
     * {@code volatile} so the computed set is safely published: this repository is presumably shared by
     * several worker threads (the worker id written to grabbed_by embeds the thread id).
     */
    private volatile Set<Long> unfinishedMultistateValues;

    public TaskRepository(DSLContext dslContext, TaskMultistateGraph taskMultistateGraph,
                          TaskFacade taskFacade) {
        this.dslContext = dslContext;
        this.taskMultistateGraph = taskMultistateGraph;
        this.taskFacade = taskFacade;
    }

    /**
     * Inserts a single task, stamping its {@code add_dt} with the current time.
     *
     * @param taskData task to insert
     * @return the task built from the row returned by {@code INSERT ... RETURNING}
     * @throws IllegalStateException if jOOQ returns no row for the insert
     */
    public TaskData insertTask(TaskData taskData) {
        QueueRecord queueRecord = toNewRecord(taskData);

        Field<?>[] fields = queueRecord.fields();
        List<?> values = Arrays.stream(fields).map(queueRecord::get).collect(Collectors.toList());
        QueueRecord inserted = dslContext.insertInto(QUEUE)
                .columns(fields)
                .values(values)
                .returning(fields)
                .fetchOne();

        if (inserted == null) {
            throw new IllegalStateException("Jooq returned null when inserting TaskData. " + taskData);
        }

        return recordToTaskData(inserted);
    }

    /**
     * Batch-inserts tasks, stamping each record's {@code add_dt} with the current time.
     * Does nothing when the collection is empty.
     *
     * @param taskDataCollection tasks to insert
     */
    public void insertTasks(Collection<TaskData> taskDataCollection) {
        if (taskDataCollection.isEmpty()) {
            return;
        }

        List<QueueRecord> records = taskDataCollection.stream()
                .map(this::toNewRecord)
                .collect(Collectors.toList());

        JooqUtils.insertRecords(dslContext, QUEUE, records);
    }

    /**
     * Writes every column of {@code taskData} except {@code TIME_ETA} to the row with the same id.
     *
     * @param taskData task state to persist
     */
    public void updateTask(TaskData taskData) {
        QueueRecord queueRecord = taskDataToRecord(taskData);
        updateRecord(queueRecord);
    }

    /**
     * Selects and row-locks ({@code SELECT ... FOR UPDATE}) the lowest-id task of the given class that is
     * not grabbed by anyone and whose multistate value is 0.
     *
     * <p>For task types that do not allow concurrent execution, only the lowest-id "unfinished" task of each
     * group is a candidate (see {@link #getConditionForTaskClass(Class)}).
     *
     * <p>Losing a deadlock or failing to acquire the lock is deliberately not an error: it is logged and
     * reported as "no free task", so the caller can simply retry later.
     *
     * @param clazz task class whose factory determines the method type
     * @return the picked task, or empty if none is available or the lock could not be acquired
     */
    public <T extends Task<?, ?>> Optional<TaskData> pickFreeTask(Class<T> clazz) {

        Condition condition = getConditionForTaskClass(clazz);

        QueueRecord queueRecord = null;
        try {
            queueRecord = dslContext.selectFrom(QUEUE)
                    .where(condition)
                    .and(QUEUE.GRABBED_BY.isNull())
                    .and(QUEUE.MULTISTATE.eq(0L))
                    .orderBy(QUEUE.ID.sort(SortOrder.ASC))
                    .limit(1)
                    .forUpdate()
                    .fetchOne();
        } catch (DeadlockLoserDataAccessException | CannotAcquireLockException e) {
            LOGGER.warn("Got killed as deadlock loser or couldn't acquire lock while picking free task {}", clazz, e);
        }

        return Optional.ofNullable(queueRecord).map(this::recordToTaskData);
    }

    /**
     * Marks the task as grabbed by the current worker from now until {@code now + duration}.
     * {@code taskData} is mutated in place and then written back in full.
     *
     * @param taskData task to grab; its grabbedAt/grabbedUntil/grabbedBy are overwritten
     * @param duration how long the grab lasts
     */
    public void lockTask(TaskData taskData, Duration duration) {
        LocalDateTime now = PartnerLocalDateTime.now();
        LocalDateTime until = now.plus(duration);

        taskData.setGrabbedAt(now);
        taskData.setGrabbedUntil(until);
        taskData.setGrabbedBy(getWorkerId());

        updateRecord(taskDataToRecord(taskData));
    }

    /**
     * Loads a task by id.
     *
     * @param id task id
     * @return the task
     * @throws NullPointerException if no task with this id exists
     */
    public TaskData get(Long id) {
        QueueRecord queueRecord = dslContext.selectFrom(QUEUE)
                .where(QUEUE.ID.eq(id))
                .fetchOne();

        return recordToTaskData(Objects.requireNonNull(queueRecord, () -> "No task found in queue with id " + id));
    }

    /**
     * Converts {@code taskData} to a record ready for insertion, with {@code add_dt} set to the current time.
     */
    private QueueRecord toNewRecord(TaskData taskData) {
        QueueRecord queueRecord = taskDataToRecord(taskData);
        queueRecord.setAddDt(PartnerLocalDateTime.now());
        return queueRecord;
    }

    /**
     * Writes every column of {@code queueRecord} except {@code TIME_ETA} to the row with the same id.
     */
    private void updateRecord(QueueRecord queueRecord) {
        // LinkedHashMap keeps the SET clause in the record's column order.
        Map<Field<?>, Object> assignments = new LinkedHashMap<>();
        for (Field<?> field : queueRecord.fields()) {
            // TIME_ETA is deliberately never written by this repository.
            if (!field.getName().equals(QUEUE.TIME_ETA.getName())) {
                assignments.put(field, queueRecord.get(field));
            }
        }
        if (assignments.isEmpty()) {
            throw new IllegalStateException("Something wrong. QueueRecord has no fields");
        }

        UpdateSetFirstStep<QueueRecord> update = dslContext.update(QUEUE);
        UpdateSetMoreStep<QueueRecord> withValues = update.set(assignments);
        withValues.where(QUEUE.ID.eq(queueRecord.getId())).execute();
    }

    /**
     * Builds the value stored in {@code grabbed_by}: {@code host:pid:threadId}.
     */
    @Nonnull
    private String getWorkerId() {
        String hostName = getHostName();
        long pid = ProcessHandle.current().pid();
        long threadId = Thread.currentThread().getId();
        return String.format("%s:%d:%d", hostName, pid, threadId);
    }

    /**
     * Returns the local host name, or {@code "unknown_host"} if it cannot be resolved.
     */
    private String getHostName() {
        try {
            return InetAddress.getLocalHost().getHostName();
        } catch (UnknownHostException e) {
            LOGGER.error("Could not get host name for queue.grabbed_by", e);
            return "unknown_host";
        }
    }

    /**
     * Returns the multistate values of tasks regarded as "unfinished": those with the WORKING,
     * FINISHED_WITH_ERROR or NEED_RESTART flag, or with an empty multistate.
     * Computed once from the multistate graph and cached.
     */
    @Nonnull
    private Set<Long> getUnfinishedMultistateValues() {
        // Read the volatile field once; a concurrent first call may compute the same set twice, which is harmless.
        Set<Long> values = unfinishedMultistateValues;
        if (values == null) {
            values = taskMultistateGraph.getMultistatesForPredicate(
                            has(TaskStateFlag.WORKING)
                                    .or(has(TaskStateFlag.FINISHED_WITH_ERROR))
                                    .or(has(TaskStateFlag.NEED_RESTART))
                                    .or(empty()))
                    .stream()
                    .map(Multistate::toMultistateValue)
                    .collect(Collectors.toSet());
            unfinishedMultistateValues = values;
        }
        return values;
    }

    /**
     * Builds the WHERE condition that selects candidate tasks of the given task class.
     *
     * <p>If the task type allows concurrent execution, every task of its method type qualifies.
     * Otherwise only the lowest-id task of each {@code group_id} among the unfinished tasks of that type
     * qualifies, so tasks within one group are executed one at a time in id order.
     */
    private <T extends Task<?, ?>> Condition getConditionForTaskClass(Class<T> clazz) {

        TaskFactory<?> taskFactory = taskFacade.getFactoryForTaskClass(clazz);
        long methodType = (long) taskFactory.getTypeId();

        if (taskFactory.isAllowConcurrentExecution()) {
            return QUEUE.METHOD_TYPE.eq(methodType);
        }

        Field<Long> minId = min(QUEUE.ID).as("min_id");
        Table<Record2<Long, Long>> q2 = dslContext.select(QUEUE.GROUP_ID, minId)
                .from(QUEUE)
                .where(QUEUE.METHOD_TYPE.eq(methodType))
                .and(QUEUE.MULTISTATE.in(getUnfinishedMultistateValues()))
                .groupBy(QUEUE.GROUP_ID)
                .asTable("q2");

        // Use the IN (SELECT ...) overload. Passing select(...).asField() would render a scalar subquery
        // inside an IN value list, which is only valid when it returns a single row (one group).
        return QUEUE.ID.in(dslContext.select(q2.field(minId)).from(q2));
    }

    /**
     * Maps {@link TaskData} to a {@link QueueRecord}. A null multistate or tries value becomes a NULL column.
     */
    private QueueRecord taskDataToRecord(TaskData taskData) {
        return new QueueRecord(
                taskData.getId(),
                taskData.getAddDt(),
                taskData.getStartDt(),
                taskData.getEndDt(),
                (long) taskData.getTypeId(),
                taskData.getGroupId(),
                taskData.getParams(),
                taskData.getUserId(),
                Optional.ofNullable(taskData.getMultistate()).map(AbstractMultistate::toMultistateValue).orElse(null),
                // Not populated from TaskData (NOTE(review): presumably TIME_ETA, which updateRecord skips — confirm).
                null,
                taskData.getLog(),
                taskData.getErrorData(),
                taskData.getResult(),
                // Null-safe: recordToTaskData yields null tries for a NULL column, so a plain unboxing cast would NPE.
                Optional.ofNullable(taskData.getTries()).map(Integer::longValue).orElse(null),
                taskData.getGrabbedBy(),
                taskData.getGrabbedAt(),
                taskData.getGrabbedUntil()
        );
    }

    /**
     * Maps a {@link QueueRecord} to {@link TaskData}. A NULL multistate becomes an empty {@link TaskMultistate};
     * a NULL tries column becomes null.
     */
    private TaskData recordToTaskData(QueueRecord record) {
        return TaskData.newBuilder()
                .withId(record.getId())
                .withMultistate(Optional.ofNullable(record.getMultistate())
                        .map(TaskMultistate::new).orElseGet(TaskMultistate::new))
                .withAddDt(record.getAddDt())
                .withStartDt(record.getStartDt())
                .withEndDt(record.getEndDt())
                .withTypeId(record.getMethodType().intValue())
                .withGroupId(record.getGroupId())
                .withParams(record.getParams())
                .withUserId(record.getUserId())
                .withLog(record.getLog())
                .withErrorData(record.getErrorData())
                .withResult(record.getResult())
                .withTries(Optional.ofNullable(record.getTries()).map(Long::intValue).orElse(null))
                .withGrabbedBy(record.getGrabbedBy())
                .withGrabbedAt(record.getGrabbedAt())
                .withGrabbedUntil(record.getGrabbedUntil())
                .build();
    }
}
