package ru.yandex.crypta.lab.yt;

import java.io.IOException;
import java.io.InputStream;
import java.time.Instant;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Function;

import javax.inject.Inject;

import NCrypta.NSiberia.DescribingExperiment.TDescribingExperiment;
import com.google.common.base.MoreObjects;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;
import okhttp3.MediaType;
import okhttp3.RequestBody;
import org.quartz.JobBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.crypta.clients.pgaas.PostgresClient;
import ru.yandex.crypta.common.exception.Exceptions;
import ru.yandex.crypta.common.exception.NotFoundException;
import ru.yandex.crypta.common.ws.EntityId;
import ru.yandex.crypta.lab.TrainingSampleService;
import ru.yandex.crypta.lab.base.BaseYtService;
import ru.yandex.crypta.lab.job.ModelTrainingJob;
import ru.yandex.crypta.lab.proto.AccessLevel;
import ru.yandex.crypta.lab.proto.ETrainableSegmentOriginType;
import ru.yandex.crypta.lab.proto.ETrainableSegmentTargetType;
import ru.yandex.crypta.lab.proto.TTrainableSegment;
import ru.yandex.crypta.lab.proto.TTrainingSample;
import ru.yandex.crypta.lab.proto.Timestamps;
import ru.yandex.crypta.lab.proto.TrainableSegments;
import ru.yandex.crypta.lab.siberia.RawResponse;
import ru.yandex.crypta.lab.siberia.SiberiaClient;
import ru.yandex.crypta.lab.siberia.SiberiaHelpers;
import ru.yandex.crypta.lab.tables.Tables;
import ru.yandex.crypta.lab.tables.TrainableSegmentsTable;
import ru.yandex.crypta.lab.tables.TrainingSamplesTable;
import ru.yandex.crypta.lab.utils.Acl;
import ru.yandex.crypta.lab.utils.Paths;
import ru.yandex.crypta.lib.proto.EEnvironment;
import ru.yandex.crypta.lib.proto.identifiers.TId;
import ru.yandex.crypta.lib.proto.identifiers.TIds;
import ru.yandex.crypta.lib.python.custom_ml.proto.ETrainingSampleState;
import ru.yandex.crypta.lib.python.custom_ml.proto.TIndustry;
import ru.yandex.crypta.lib.python.custom_ml.proto.TMetrics;
import ru.yandex.crypta.lib.python.custom_ml.proto.TTrainingError;
import ru.yandex.crypta.lib.schedulers.QuartzScheduler;
import ru.yandex.crypta.lib.schedulers.Schedulers;
import ru.yandex.crypta.lib.yt.YtService;
import ru.yandex.crypta.siberia.proto.TDescribeIdsResponse;
import ru.yandex.inside.yt.kosher.cypress.YPath;
import ru.yandex.inside.yt.kosher.tables.YTableEntryTypes;
import ru.yandex.inside.yt.kosher.ytree.YTreeMapNode;
import ru.yandex.inside.yt.kosher.ytree.YTreeNode;

/**
 * Default {@link TrainingSampleService} implementation.
 *
 * <p>Sample metadata, trainable segments, metrics and industries are stored in SQL tables (see {@link Tables});
 * sample data lives under the lab YT directory of each sample; model training is delegated to a
 * {@link ModelTrainingJob} scheduled through Quartz; segment ids are registered in Siberia as user sets.
 */
public class DefaultTrainingSampleService extends BaseYtService<TrainingSampleService> implements TrainingSampleService {
    private static final String ID_PREFIX = "training-sample";
    /** Name of the YT table that holds a sample copied from a user-supplied table. */
    public static final String RAW_SAMPLE = "raw_sample";
    /** Name of the YT file that holds a sample uploaded as a stream. */
    public static final String RAW_SAMPLE_FILE = "raw_sample_file";
    /** Identifier type sent to Siberia; also the column read from segment tables. */
    public static final String ID_TYPE = "yandexuid";
    /** Maximum number of segment-table rows read and sent to Siberia for description. */
    public static final Long ROWS_NUM = 100_000L;

    private static final Logger LOG = LoggerFactory.getLogger(DefaultTrainingSampleService.class);
    public static final MediaType APPLICATION_JSON = MediaType.get("application/json");

    private final Paths labYtPath;
    private final Schedulers schedulers;
    private final SiberiaClient siberiaClient;

    @Inject
    public DefaultTrainingSampleService(
            EEnvironment environment,
            PostgresClient sql,
            YtService yt,
            Schedulers schedulers,
            SiberiaClient siberiaClient
    ) {
        super(environment, sql, yt);
        this.labYtPath = DefaultLabService.getPaths(environment);
        this.schedulers = schedulers;
        this.siberiaClient = siberiaClient;
    }

    @Override
    public DefaultTrainingSampleService clone() {
        return new DefaultTrainingSampleService(environment(), sql(), ytService(), schedulers, siberiaClient);
    }

    /** Returns the name of the principal of the current security context. */
    private String getLogin() {
        return securityContext().getUserPrincipal().getName();
    }

    /** Returns the current time in epoch seconds. */
    private long getNowTimestamp() {
        return Instant.now().getEpochSecond();
    }

    /** Returns timestamps with both "created" and "modified" set to now. */
    private Timestamps getInitialTimestamps() {
        long now = getNowTimestamp();
        return Timestamps.newBuilder().setCreated(now).setModified(now).build();
    }

    private <T> T withQuartz(Function<QuartzScheduler, T> callback) {
        return callback.apply(schedulers.getQuartz());
    }

    /** Returns all samples selected by {@code selectAccessible()}. */
    @Override
    public List<TTrainingSample> getAll() {
        return tables()
                .trainingSamples()
                .selectAccessible()
                .fetchInto(TTrainingSample.class);
    }

    @Override
    public TTrainingSample getSample(String id) throws NotFoundException {
        return fetchSample(tables(), id);
    }

    /** Fetches a sample by id using the given tables handle; throws not-found if absent. */
    private TTrainingSample fetchSample(Tables tables, String id) {
        return tables
                .trainingSamples()
                .selectByIdQuery(id)
                .fetchOptionalInto(TTrainingSample.class)
                .orElseThrow(Exceptions::notFound);
    }

    /**
     * Creates a training sample and schedules a model training job for it.
     *
     * <p>The sample source is chosen in priority order. If {@code path} is given, that YT table is copied.
     * Otherwise, if {@code stream} is given, it is written as a file. Otherwise, if {@code audienceId} is
     * given, only the sample directory is created. Null name, TTL, access level and segment sizes fall back
     * to defaults. The sample row is inserted as CREATED and switched to WAITING after the job is scheduled.
     *
     * @return the created sample as read back from the database (state before the WAITING update)
     */
    @Override
    public TTrainingSample createSample(String path, InputStream stream, String audienceId,
                                        String name, AccessLevel accessLevel, Long ttl, Long positiveSegmentSize,
                                        Long negativeSegmentSize,
                                        String modelName, String partner, String loginsToShare) {
        YPath tablePath = (path != null) ? YPath.simple(path) : null;
        // A table source takes precedence over a stream, so only a stream-without-table sample is "from file";
        // otherwise the job would be pointed at a raw_sample_file that was never written.
        boolean fromFile = tablePath == null && stream != null;
        Long finalTtl = Optional.ofNullable(ttl).orElse(TrainingSamplesTable.DEFAULT_TTL);
        String finalName = Optional.ofNullable(name).orElse("Training Sample");
        // Use one effective access level for both the YT directory ACL and the stored sample record.
        AccessLevel finalAccessLevel = MoreObjects.firstNonNull(accessLevel, AccessLevel.PRIVATE);
        Long finalPositiveSegmentSize =
                Optional.ofNullable(positiveSegmentSize).orElse(TrainableSegmentsTable.DEFAULT_POSITIVE_SEGMENT_SIZE);
        Long finalNegativeSegmentSize =
                Optional.ofNullable(negativeSegmentSize).orElse(TrainableSegmentsTable.DEFAULT_NEGATIVE_SEGMENT_SIZE);

        return withSqlTransaction(tables -> {
            TTrainingSample resultSample = withYtTransaction(ytTransaction -> {
                String sampleId = new EntityId(ID_PREFIX).toString();
                YPath destinationRoot = labYtPath.training_sample(sampleId);

                if (tablePath != null) {
                    // Derive the directory ACL from the source table's ACL.
                    List<YTreeNode> acl = Acl.get(cypress(), optionalId(ytTransaction), tablePath);
                    createDirectory(ytTransaction, destinationRoot, Acl.getDirectoryAcl(acl, finalAccessLevel));
                    copyTable(ytTransaction, tablePath, destinationRoot.child(RAW_SAMPLE));
                } else if (stream != null) {
                    createDirectory(ytTransaction, destinationRoot, Acl.loginAcl(getLogin()));
                    writeFile(ytTransaction, destinationRoot.child(RAW_SAMPLE_FILE), stream);
                } else if (audienceId != null) {
                    createDirectory(ytTransaction, destinationRoot, Acl.loginAcl(getLogin()));
                }

                TTrainingSample.Builder sample = TTrainingSample.newBuilder()
                        .setId(sampleId)
                        .setName(finalName)
                        .setAuthor(getLogin())
                        .setTtl(finalTtl)
                        .setAccessLevel(finalAccessLevel)
                        .setStatus(ETrainingSampleState.CREATED);
                if (partner != null) {
                    sample.setPartner(partner);
                    // Protobuf setters reject null, so only set the model name when one was supplied.
                    if (modelName != null) {
                        sample.setModelName(modelName);
                    }
                }
                sample.getTimestampsBuilder().mergeFrom(getInitialTimestamps());

                tables.trainingSamples().insertQuery(sample.build()).execute();

                return fetchSample(tables, sampleId);
            });

            scheduleTrainingJob(resultSample, fromFile, audienceId, finalPositiveSegmentSize,
                    finalNegativeSegmentSize, finalTtl, loginsToShare);

            tables.trainingSamples().updateState(resultSample.getId(), ETrainingSampleState.WAITING).execute();

            return resultSample;
        });
    }

    /**
     * Schedules a {@link ModelTrainingJob} keyed by the sample id to start immediately.
     * The sample's YT paths and training parameters are passed through the job data map.
     */
    private void scheduleTrainingJob(TTrainingSample sample, boolean fromFile, String audienceId,
                                     Long positiveSegmentSize, Long negativeSegmentSize, Long ttl, String loginsToShare) {
        withQuartz(quartz -> {
            YPath outputDirPath = labYtPath.training_sample(sample.getId());
            YPath rawSampleTablePath = outputDirPath.child(RAW_SAMPLE);
            String rawSampleFilePath = fromFile ? outputDirPath.child(RAW_SAMPLE_FILE).toString() : null;

            JobBuilder jobDetail = quartz.job(sample.getId(), ModelTrainingJob.class)
                    .usingJobData(ModelTrainingJob.SAMPLE_ID, sample.getId())
                    .usingJobData(ModelTrainingJob.MODEL_NAME, sample.getModelName())
                    .usingJobData(ModelTrainingJob.PARTNER, sample.getPartner())
                    .usingJobData(ModelTrainingJob.LOGIN, sample.getAuthor())
                    .usingJobData(ModelTrainingJob.TABLE_PATH, rawSampleTablePath.toString())
                    .usingJobData(ModelTrainingJob.FILE_PATH, rawSampleFilePath)
                    .usingJobData(ModelTrainingJob.OUTPUT_DIR_PATH, outputDirPath.toString())
                    .usingJobData(ModelTrainingJob.AUDIENCE_ID, audienceId)
                    .usingJobData(ModelTrainingJob.POSITIVE_SEGMENT_SIZE, positiveSegmentSize.toString())
                    .usingJobData(ModelTrainingJob.NEGATIVE_SEGMENT_SIZE, negativeSegmentSize.toString())
                    .usingJobData(ModelTrainingJob.TTL, ttl.toString())
                    .usingJobData(ModelTrainingJob.LOGINS_TO_SHARE, loginsToShare)
                    .usingJobData(ModelTrainingJob.ENVIRONMENT, getSandboxEnvironment());

            quartz.schedule(jobDetail, quartz.startNow(sample.getId()));
            // Log the sample id: JobBuilder itself has no informative toString().
            LOG.info("Scheduled Custom Model Training Job for sample {}", sample.getId());
            return jobDetail;
        });
    }

    /**
     * Reads the first row of the sample's {@code model_metrics} YT table and parses its {@code Metrics}
     * column as {@link TMetrics}, with the sample id forced to {@code sampleId}.
     *
     * @throws RuntimeException not-found if the table has no rows; illegal-argument style error if the
     *                          stored bytes cannot be parsed
     */
    @Override
    public TMetrics getSampleMetricsFromYt(String sampleId) {
        YPath metricsTablePath = labYtPath.training_sample(sampleId).child("model_metrics");
        List<YTreeMapNode> records = new ArrayList<>(1);
        yt().tables().read(metricsTablePath, YTableEntryTypes.YSON, (Consumer<YTreeMapNode>) records::add);
        if (records.isEmpty()) {
            // Report a missing result instead of an IndexOutOfBoundsException.
            throw Exceptions.notFound();
        }
        try {
            TMetrics ytMetrics = TMetrics.parseFrom(records.get(0).getBytes("Metrics"));
            // Set the id after copying the stored message so a stored non-empty id cannot override it
            // (mergeFrom would overwrite a previously set sample id).
            return ytMetrics.toBuilder().setSampleId(sampleId).build();
        } catch (InvalidProtocolBufferException e) {
            throw Exceptions.illegal(e.getMessage());
        }
    }

    /**
     * Reads the sample's metrics from YT and inserts them into the metrics table unless metrics for the
     * sample are already present.
     *
     * @return the metrics read from YT
     */
    @Override
    public TMetrics addSampleMetricsToDB(String sampleId) {
        // Read YT before opening the SQL transaction to keep the transaction short.
        TMetrics metrics = getSampleMetricsFromYt(sampleId);
        return withSqlTransaction(tables -> {
            // Check and insert through the same transactional handle.
            if (!tables.metrics().isMetricsReadyForSample(sampleId)) {
                tables.metrics().insertQuery(metrics).execute();
            }
            return metrics;
        });
    }

    @Override
    public TMetrics getMetrics(String sampleId) {
        return tables()
                .metrics()
                .selectByIdQuery(sampleId)
                .fetchOptionalInto(TMetrics.class)
                .orElseThrow(NotFoundException::new);
    }

    @Override
    public boolean IsMetricsReadyForSample(String sampleId) {
        return tables().metrics().isMetricsReadyForSample(sampleId);
    }

    /** Fetches all trainable segments of a sample using the given tables handle. */
    private List<TTrainableSegment> fetchSegments(Tables tables, String sampleId) {
        return tables
                .trainableSegments()
                .selectBySampleIdQuery(sampleId)
                .fetchInto(TTrainableSegment.class);
    }

    @Override
    public List<TTrainableSegment> getSegmentsDetails(String sampleId) {
        return fetchSegments(tables(), sampleId);
    }

    /** Returns a mapping from segment table name to Siberia user set id for the sample's segments. */
    @Override
    public TrainableSegments getSegments(String sampleId) {
        Map<String, String> segmentNameToUserSetId = new HashMap<>();
        for (TTrainableSegment segment : fetchSegments(tables(), sampleId)) {
            String segmentName = getSegmentTableName(segment.getTargetType(), segment.getOriginType());
            segmentNameToUserSetId.put(segmentName, segment.getSiberiaUserSetId());
        }
        return TrainableSegments.newBuilder()
                .setSampleId(sampleId)
                .putAllSegmentNameToUserSetId(segmentNameToUserSetId)
                .build();
    }

    /**
     * Deletes a sample owned by the current user: its YT directory, the Siberia user sets of its
     * segments, and its metrics, segment and sample rows.
     *
     * @return the deleted sample
     * @throws RuntimeException illegal-argument style error if the sample belongs to someone else
     */
    @Override
    public TTrainingSample deleteSample(String sampleId) {
        return withSqlTransaction(tables -> withYtTransaction(ytTransaction -> {
            TTrainingSample sample = fetchSample(tables, sampleId);

            // Authorize before doing any further work.
            if (!Objects.equals(sample.getAuthor(), getLogin())) {
                throw Exceptions.illegal("Can't delete sample of someone else");
            }

            List<TTrainableSegment> segments = fetchSegments(tables, sample.getId());

            YPath sampleFolder = labYtPath.training_sample(sampleId);
            deleteDirectory(ytTransaction, sampleFolder);

            segments.forEach(segment -> SiberiaHelpers.removeSiberiaUserSet(siberiaClient,
                    segment.getSiberiaUserSetId()));

            tables.metrics().deleteByIdQuery(sample.getId()).execute();
            tables.trainableSegments().deleteBySampleIdQuery(sample.getId()).execute();
            tables.trainingSamples().deleteByIdQuery(sample.getId()).execute();

            return sample;
        }));
    }

    /**
     * Deletes the YT directories of the samples selected by {@code selectOutdatedSamplesQuery(now)} and
     * marks them DELETED; their SQL rows are kept.
     *
     * @return the samples that were processed
     */
    @Override
    public List<TTrainingSample> deleteOutdatedSamples() {
        return withLongSqlTransaction(tables -> withLongYtTransaction(ytTransaction -> {
            long now = getNowTimestamp();

            List<TTrainingSample> outdatedSamples = tables.trainingSamples()
                    .selectOutdatedSamplesQuery(now)
                    .fetchInto(TTrainingSample.class);

            outdatedSamples.forEach(each -> {
                YPath sampleFolder = labYtPath.training_sample(each.getId());
                deleteDirectory(ytTransaction, sampleFolder);
                // Update state in the enclosing transaction rather than a separate one, so it commits or
                // rolls back together with the directory deletion.
                tables.trainingSamples().updateState(each.getId(), ETrainingSampleState.DELETED).execute();
            });

            LOG.info("Outdated samples deleted: {}", outdatedSamples.size());
            return outdatedSamples;
        }));
    }

    @Override
    public void writeTrainingSampleMetrics(TMetrics metrics) {
        withSqlTransaction(tables -> tables.metrics().insertQuery(metrics).execute());
    }

    /** Records a training error and moves the sample to ERROR state atomically. */
    @Override
    public void writeTrainingSampleError(TTrainingError trainingError) {
        withSqlTransaction(tables -> {
            tables.trainingSamples().addError(trainingError).execute();
            return tables.trainingSamples()
                    .updateState(trainingError.getSampleId(), ETrainingSampleState.ERROR)
                    .execute();
        });
    }

    /** Stores an industry with the current user as its author. */
    @Override
    public void addIndustry(TIndustry industry) {
        TIndustry finalIndustry = industry.toBuilder().setAuthor(getLogin()).build();
        withSqlTransaction(tables -> tables.trainingSamplesIndustries().insertQuery(finalIndustry).execute());
    }

    @Override
    public List<TIndustry> getAllIndustries() {
        return tables()
                .trainingSamplesIndustries()
                .selectQuery()
                .fetchInto(TIndustry.class);
    }

    /** Extracts the user set id from a Siberia describe_ids response; fails on non-200 or unparsable body. */
    private String parseResponse(RawResponse response) {
        var responseCode = response.getHttpCode();
        LOG.info("Response code: {}", responseCode);
        if (responseCode != 200) {
            throw Exceptions.internal(
                    "user_sets/describe_ids Siberia handle return code = '" + responseCode
                            + "' and message = '" + response.getText() + "'");
        }

        try {
            var builder = TDescribeIdsResponse.newBuilder();
            JsonFormat.parser().merge(response.getText(), builder);
            return builder.build().getUserSetId();
        } catch (IOException e) {
            throw Exceptions.internal("Error while parsing response from Siberia: " + e.getMessage());
        }
    }

    /** Returns {@code "<origin>_<target>"} in lower case; Locale.ROOT keeps it independent of the JVM locale. */
    private String getSegmentTableName(ETrainableSegmentTargetType targetType, ETrainableSegmentOriginType originType) {
        return (originType.name() + "_" + targetType.name()).toLowerCase(Locale.ROOT);
    }

    /** Returns the YT path of a segment table of the sample. */
    private YPath getSegmentTablePath(String sampleId, String segmentName) {
        return labYtPath.training_sample(sampleId).child("segments").child(segmentName);
    }

    /** Reads up to {@link #ROWS_NUM} ids from the {@link #ID_TYPE} column of a segment table. */
    private List<TId> readSegmentIds(YPath segmentTablePath) {
        ListF<YTreeMapNode> rows = Cf.arrayList();
        yt().tables().read(
                segmentTablePath.withRange(0, ROWS_NUM),
                YTableEntryTypes.YSON,
                (Consumer<YTreeMapNode>) rows::add);
        return rows.map(row -> TId.newBuilder()
                .setType(ID_TYPE)
                .setValue(row.getString(ID_TYPE))
                .build());
    }

    /** Sends ids to Siberia's describeIds handle and returns the resulting user set id. */
    private String describeIdsInSiberia(List<TId> ids) {
        TIds finalIds = TIds.newBuilder().addAllIds(ids).build();
        TDescribingExperiment experiment = TDescribingExperiment.newBuilder()
                .setCryptaIdUserDataVersion("by_crypta_id")
                .build();
        var response = siberiaClient.tryCall(siberia -> {
            var body = RequestBody.create(APPLICATION_JSON, JsonFormat.printer().print(finalIds));
            return siberia.describeIds(body, "fast", JsonFormat.printer().print(experiment));
        });
        return parseResponse(response);
    }

    /**
     * Registers a segment table of the sample: sends its first ids to Siberia for description, stores the
     * segment (with the table's row count and the returned user set id) and returns it.
     */
    @Override
    public TTrainableSegment addSegmentDescriptionFromTable(String sampleId, ETrainableSegmentTargetType targetType,
                                                            ETrainableSegmentOriginType originType) {
        return withSqlTransaction(tables -> {
            String segmentName = getSegmentTableName(targetType, originType);
            YPath segmentTablePath = getSegmentTablePath(sampleId, segmentName);

            TTrainableSegment segment = withYtTransaction(ytTransaction -> {
                List<TId> ids = readSegmentIds(segmentTablePath);
                LOG.info("Ids count: {}", ids.size());

                String userSetId = describeIdsInSiberia(ids);

                YTreeNode segmentTable = adoptableTable(segmentTablePath, ytTransaction);
                int rowsCount = segmentTable.getAttributeOrThrow(ROW_COUNT).intValue();
                return TTrainableSegment.newBuilder()
                        .setSampleId(sampleId)
                        .setTargetType(targetType)
                        .setOriginType(originType)
                        .setSegmentSize(rowsCount)
                        .setName(segmentName)
                        .setPath(segmentTablePath.toString())
                        .setSiberiaUserSetId(userSetId)
                        .build();
            });

            tables.trainableSegments().insertQuery(segment).execute();

            return segment;
        });
    }

    /** Stores a segment description whose size and Siberia user set id are supplied by the caller. */
    @Override
    public TTrainableSegment addSegmentDescription(String sampleId, ETrainableSegmentTargetType targetType,
                                                   ETrainableSegmentOriginType originType, Integer rowsCount,
                                                   String userSetId) {
        // Fail with a descriptive message instead of an anonymous unboxing NPE.
        Objects.requireNonNull(rowsCount, "rowsCount");
        return withSqlTransaction(tables -> {
            String segmentName = getSegmentTableName(targetType, originType);

            TTrainableSegment segment = TTrainableSegment.newBuilder()
                    .setSampleId(sampleId)
                    .setTargetType(targetType)
                    .setOriginType(originType)
                    .setSegmentSize(rowsCount)
                    .setName(segmentName)
                    .setPath(getSegmentTablePath(sampleId, segmentName).toString())
                    .setSiberiaUserSetId(userSetId)
                    .build();

            tables.trainableSegments().insertQuery(segment).execute();

            return segment;
        });
    }

    @Override
    public void updateSampleState(String id, ETrainingSampleState state) {
        withSqlTransaction(tables -> tables.trainingSamples().updateState(id, state).execute());
    }
}
