package ru.yandex.crypta.lab.yt;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.time.Instant;
import java.util.List;
import java.util.stream.Collectors;

import javax.inject.Inject;

import com.amazonaws.AmazonServiceException;
import com.amazonaws.services.s3.model.S3ObjectInputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.crypta.clients.pgaas.PostgresClient;
import ru.yandex.crypta.common.exception.NotFoundException;
import ru.yandex.crypta.common.ws.EntityId;
import ru.yandex.crypta.lab.ModelService;
import ru.yandex.crypta.lab.base.BaseService;
import ru.yandex.crypta.lab.proto.Model;
import ru.yandex.crypta.lab.proto.ModelSegmentRelation;
import ru.yandex.crypta.lab.proto.Segment;
import ru.yandex.crypta.lab.proto.Thresholds;
import ru.yandex.crypta.lab.proto.Timestamps;
import ru.yandex.crypta.lab.tables.Tables;
import ru.yandex.crypta.lab.utils.ModelSegmentRelations;
import ru.yandex.crypta.lib.mds.MdsService;
import ru.yandex.crypta.lib.proto.EEnvironment;
import ru.yandex.misc.io.InputStreamSourceUtils;

/**
 * {@link ModelService} implementation that keeps model metadata in SQL tables
 * (via {@link Tables}) and model binaries in S3 MDS (via {@link MdsService}).
 */
public class DefaultModelService extends BaseService<ModelService> implements ModelService {

    // Log under this class's own category (previously attributed to DefaultLabService).
    private static final Logger LOG = LoggerFactory.getLogger(DefaultModelService.class);
    /** Folder name passed to {@link MdsService} for every model file. */
    private static final String S3_LAB_FOLDER = "lab";
    private final MdsService mds;

    @Inject
    public DefaultModelService(EEnvironment environment, PostgresClient sql, MdsService mds) {
        super(environment, sql);
        this.mds = mds;
    }

    /** Returns every row of the models table mapped to {@link Model}. */
    @Override
    public List<Model> getAll() {
        return tables()
                .models()
                .selectQuery()
                .fetchInto(Model.class);
    }

    /**
     * Looks a model up by its id.
     *
     * @throws NotFoundException if the query yields no row
     */
    @Override
    public Model get(String id) throws NotFoundException {
        return fetchModel(id, tables());
    }

    /**
     * Looks a model up by its tag.
     *
     * @throws NotFoundException if the query yields no row
     */
    @Override
    public Model getByTag(String tag) throws NotFoundException {
        return tables()
                .models()
                .selectByTagQuery(tag)
                .fetchOptionalInto(Model.class)
                .orElseThrow(NotFoundException::new);
    }

    /**
     * Fetches the models matching the given ids. Ids without a matching row are
     * not reported here; the result simply contains whatever the query returned.
     */
    @Override
    public List<Model> get(List<String> ids) throws NotFoundException {
        return withSqlTransaction(tables -> tables.models().selectByIdsQuery(ids).fetchInto(Model.class));
    }

    /**
     * Removes a model: clears its segment relations, deletes its row and then
     * deletes its file from S3 MDS.
     *
     * @return the model as it was before deletion
     * @throws NotFoundException if no model with this id exists
     */
    @Override
    public Model delete(String id) throws NotFoundException, AmazonServiceException {
        return withSqlTransaction(tables -> {
            Model model = fetchModel(id, tables);
            LOG.info("Deleting model from database {}", model);

            modelSegmentRelations(tables).clear(model);
            tables.models().deleteByIdQuery(id).execute();

            // NOTE(review): the file removal is issued from inside the transaction
            // callback, after the SQL statements; presumably a failure here is meant
            // to abort the database changes as well — confirm against withSqlTransaction.
            LOG.info("Deleting model file from S3 MDS storage {}", id);
            mds.deleteFile(S3_LAB_FOLDER, id);

            return model;
        });
    }

    /** Selects a single model by id from the given tables or throws {@link NotFoundException}. */
    private Model fetchModel(String id, Tables tables) {
        return tables
                .models()
                .selectByIdQuery(id)
                .fetchOptionalInto(Model.class)
                .orElseThrow(NotFoundException::new);
    }

    /**
     * Returns the models related to the given segment whose state is
     * {@code Model.State.ENABLED}.
     *
     * @param id segment id used to select the model-segment relations
     */
    @Override
    public List<Model> getActive(String id) {
        List<String> modelIds = tables()
                .modelSegmentRelations()
                .selectBySegmentIdQuery(id)
                .fetchInto(ModelSegmentRelation.class)
                .stream()
                .map(ModelSegmentRelation::getModelId)
                .collect(Collectors.toList());
        return get(modelIds)
                .stream()
                // Constant-first comparison stays safe even if a state were ever null.
                .filter(model -> Model.State.ENABLED.equals(model.getState()))
                .collect(Collectors.toList());
    }

    /**
     * Inserts a relation between the model and the segment, stamped with the
     * current epoch second as both creation and modification time.
     *
     * @param thresholds optional; stored on the relation only when non-null
     * @return the model re-read inside the same transaction
     */
    @Override
    public Model link(Model model, Segment segment, Thresholds thresholds) {
        final long now = Instant.now().getEpochSecond();

        return withSqlTransaction(tables -> {
            ModelSegmentRelation.Builder relationPrototype = ModelSegmentRelation
                    .newBuilder()
                    .setModelId(model.getId())
                    .setSegmentId(segment.getId())
                    .setTimestamps(Timestamps
                            .newBuilder()
                            .setCreated(now)
                            .setModified(now)
                            .build()
                    );

            if (thresholds != null) {
                relationPrototype.setThresholds(thresholds);
            }

            tables.modelSegmentRelations()
                    .insertQuery(relationPrototype.build())
                    .execute();

            return fetchModel(model.getId(), tables);
        });
    }

    /**
     * Deletes the relation between the model and the segment.
     *
     * @return the model re-read inside the same transaction
     */
    @Override
    public Model unlink(Model model, Segment segment) {
        return withSqlTransaction(tables -> {
            tables.modelSegmentRelations()
                    .deleteByModelAndSegmentIdQuery(model.getId(), segment.getId())
                    .execute();
            return fetchModel(model.getId(), tables);
        });
    }

    /**
     * Uploads the model binary to S3 MDS under a freshly generated id, then
     * inserts the model row and stores its segment relations.
     *
     * @param data      model file contents
     * @param prototype builder carrying the caller-supplied fields; id, uri,
     *                  S3 uri and timestamps are overwritten here
     * @return the model re-read inside the same transaction
     */
    @Override
    public Model put(byte[] data, Model.Builder prototype) {
        return withSqlTransaction(tables -> {
            String id = new EntityId("model").toString();

            LOG.info("Uploading model {}", id);
            String s3Uri = mds.putFile(S3_LAB_FOLDER, id, data);

            long timestamp = Instant.now().getEpochSecond();
            Model model = prototype
                    .setId(id)
                    .setUri("")
                    .setS3Uri(s3Uri)
                    .setTimestamps(Timestamps.newBuilder().setCreated(timestamp).setModified(timestamp))
                    .build();

            LOG.info("Creating model {}", model);
            tables.models().insertQuery(model).execute();
            modelSegmentRelations(tables).store(model);
            return fetchModel(id, tables);
        });
    }

    /**
     * Downloads the file of the model with the given tag.
     *
     * @throws NotFoundException if no model has this tag
     */
    @Override
    public byte[] getFileByTag(String tag) throws NotFoundException {
        return readModelFile(getByTag(tag));
    }

    /**
     * Downloads the file of the model with the given id.
     *
     * @throws NotFoundException if no model has this id
     */
    @Override
    public byte[] getFile(String id) throws NotFoundException {
        return readModelFile(get(id));
    }

    /**
     * Executes the update query for the given model.
     *
     * @return the exact message that was passed to the update query
     */
    @Override
    public Model update(Model.Builder model) {
        return withSqlTransaction(tables -> {
            // Build once so the returned value is the same snapshot that was persisted.
            Model updated = model.build();
            tables.models().updateQuery(updated).execute();

            return updated;
        });
    }

    private ModelSegmentRelations modelSegmentRelations(Tables tables) {
        return new ModelSegmentRelations(tables);
    }

    /** Reads the whole S3 MDS file of the given model into memory. */
    private byte[] readModelFile(Model model) {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        readFileFromS3Mds(model, buffer);
        return buffer.toByteArray();
    }

    /**
     * Copies the model's S3 MDS object content into {@code os}.
     *
     * @throws UncheckedIOException if closing the S3 content stream fails
     */
    private void readFileFromS3Mds(Model model, OutputStream os) {
        String s3Uri = model.getS3Uri();
        // Strip only a leading slash. The previous replaceFirst("/", "") removed the
        // first slash wherever it occurred, which would corrupt a URI that has no
        // leading slash by dropping an interior separator.
        String key = s3Uri.startsWith("/") ? s3Uri.substring(1) : s3Uri;

        // The S3 content stream is backed by an HTTP connection; close it explicitly
        // so the connection is released even when the copy fails.
        try (S3ObjectInputStream content = mds.getFile(key).getObjectContent()) {
            InputStreamSourceUtils.wrap(content).readTo(os);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to close S3 MDS stream for model " + model.getId(), e);
        }
    }

    /** Creates a new service instance sharing this one's environment, SQL client and MDS client. */
    @Override
    public DefaultModelService clone() {
        return new DefaultModelService(environment(), sql(), mds);
    }
}
