package ru.yandex.crypta.lab.yt;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Scanner;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import javax.inject.Inject;

import NCrypta.NSiberia.NYtDescriber.GroupStats.TGroupStats;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.Maps;
import javafx.util.Pair;
import org.glassfish.jersey.internal.inject.InjectionManager;
import org.jooq.Record4;
import org.jooq.Result;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.crypta.audience.proto.TUserDataStats;
import ru.yandex.crypta.clients.geobase.GeobaseClient;
import ru.yandex.crypta.clients.pgaas.PostgresClient;
import ru.yandex.crypta.clients.utils.Caching;
import ru.yandex.crypta.common.exception.Exceptions;
import ru.yandex.crypta.lab.AudienceService;
import ru.yandex.crypta.lab.ConstructorService;
import ru.yandex.crypta.lab.Keyword;
import ru.yandex.crypta.lab.LabService;
import ru.yandex.crypta.lab.ModelService;
import ru.yandex.crypta.lab.SampleService;
import ru.yandex.crypta.lab.SegmentExportService;
import ru.yandex.crypta.lab.SegmentService;
import ru.yandex.crypta.lab.Site2VecService;
import ru.yandex.crypta.lab.TrainingSampleService;
import ru.yandex.crypta.lab.WithOptions;
import ru.yandex.crypta.lab.base.BaseYtService;
import ru.yandex.crypta.lab.custom_audience.CustomAudienceClient;
import ru.yandex.crypta.lab.proto.Sample;
import ru.yandex.crypta.lab.proto.Segment;
import ru.yandex.crypta.lab.proto.SegmentGroup;
import ru.yandex.crypta.lab.proto.SocdemThreshold;
import ru.yandex.crypta.lab.proto.TSampleGroupID;
import ru.yandex.crypta.lab.proto.TSimpleSampleStats;
import ru.yandex.crypta.lab.proto.TSimpleSampleStatsWithInfo;
import ru.yandex.crypta.lab.proto.WholeView;
import ru.yandex.crypta.lab.tables.KeywordsTable;
import ru.yandex.crypta.lab.tables.SegmentExportsTable;
import ru.yandex.crypta.lab.tables.SegmentExportsTagsTable;
import ru.yandex.crypta.lab.tables.SegmentsTable;
import ru.yandex.crypta.lab.tables.SocdemThresholdsTable;
import ru.yandex.crypta.lab.utils.Acl;
import ru.yandex.crypta.lab.utils.Affinities;
import ru.yandex.crypta.lab.utils.Geo;
import ru.yandex.crypta.lab.utils.Paths;
import ru.yandex.crypta.lab.utils.Proto;
import ru.yandex.crypta.lab.utils.SegmentYtField;
import ru.yandex.crypta.lib.proto.EEnvironment;
import ru.yandex.crypta.lib.yt.PathUtils;
import ru.yandex.crypta.lib.yt.YtService;
import ru.yandex.crypta.siberia.bin.custom_audience.common.proto.TCaRule;
import ru.yandex.crypta.siberia.bin.custom_audience.common.proto.TExtendedCaRule;
import ru.yandex.crypta.siberia.proto.TStats;
import ru.yandex.inside.yt.kosher.cypress.RangeLimit;
import ru.yandex.inside.yt.kosher.cypress.YPath;
import ru.yandex.inside.yt.kosher.impl.ytree.builder.YTree;
import ru.yandex.inside.yt.kosher.tables.YTableEntryTypes;

/**
 * Default {@link LabService} implementation.
 *
 * <p>Lab metadata (keywords, socdem thresholds, segment exports) is accessed through the SQL helpers inherited
 * from {@link BaseYtService}. Sample statistics are read from YT, or obtained from Siberia / the custom audience
 * client, and compared against a base sample (by default the cached global statistics) to produce
 * {@link TSimpleSampleStats}. Domain sub-services are created on every accessor call through the
 * {@link InjectionManager} and configured with this service's language and security context.
 */
public class DefaultLabService extends BaseYtService<LabService> implements LabService {

    /**
     * Global user-data statistics keyed by custom stats version
     * ({@link #DEFAULT_GLOBAL_STATS_CACHE_KEY} for the default table).
     *
     * <p>The field is static, so it is shared by every instance, including those produced by {@link #clone()}.
     * Entries expire one hour after being written.
     *
     * <p>NOTE(review): the key does not include the environment. Instances created for different environments
     * in one JVM would therefore share entries; confirm a process only ever serves one environment.
     */
    private static final Cache<String, TUserDataStats> GLOBAL_STATS_CACHE = CacheBuilder.newBuilder()
            .expireAfterWrite(1, TimeUnit.HOURS)
            .build();
    /** Cache key used for the default (non-versioned) global statistics. */
    private static final String DEFAULT_GLOBAL_STATS_CACHE_KEY = "";
    /** Classpath location of the static part of the Direct taxonomy. */
    private static final String DIRECT_TAXONOMY_RESOURCE = "/direct_taxonomy.json";
    /** Parent-id prefix that {@link #getDirectTaxonomy()} emits as JSON {@code null}. */
    private static final String GROUP_PARENT_PREFIX = "group-";
    /** Shared JSON parser; {@link ObjectMapper} is safe to share once configured. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    private final Paths paths;
    private final Geo geo;
    private final InjectionManager injectionManager;
    private final CustomAudienceClient customAudienceClient;
    /** Affinity calculator; the meaning of the argument (10) is defined by {@link Affinities}. */
    private final Affinities affinities = new Affinities(10);

    @Inject
    public DefaultLabService(
            InjectionManager injectionManager,
            EEnvironment environment,
            YtService yt,
            PostgresClient sql,
            GeobaseClient geobaseClient,
            CustomAudienceClient customAudienceClient
    )
    {
        super(environment, sql, yt);
        this.injectionManager = injectionManager;
        this.paths = getPaths(environment);
        this.geo = new Geo(geobaseClient);
        this.customAudienceClient = customAudienceClient;
    }

    /**
     * Builds the lab path layout rooted at {@code //home/crypta/<environment>/lab}.
     *
     * @param environment environment whose path segment is inserted under {@code //home/crypta}
     * @return paths helper for that environment
     */
    public static Paths getPaths(EEnvironment environment) {
        return new Paths(YPath.simple("//home/crypta").child(PathUtils.toPath(environment)).child("lab"));
    }

    /** Eagerly loads the default global statistics into the shared cache. */
    @Override
    public void initialize() {
        GLOBAL_STATS_CACHE.put(DEFAULT_GLOBAL_STATS_CACHE_KEY, getGlobalStats());
    }

    /**
     * Returns the segment mapping stored for a keyword.
     *
     * @param id keyword id
     * @return the keyword's segments, or a new empty mutable map when the keyword is unknown
     */
    @Override
    public Map<String, String> getKeywordInfo(String id) {
        return tables()
                .keywords()
                .selectByIdQuery(id)
                .fetchOptionalInto(Keyword.class)
                .map(Keyword::getSegments)
                .orElse(new HashMap<>());
    }

    /**
     * Inserts or updates a keyword inside one SQL transaction.
     *
     * @param keyword keyword to store
     * @return the same {@code keyword} instance that was passed in
     */
    @Override
    public Keyword updateKeywordInfo(Keyword keyword) {
        return withSqlTransaction(tables -> {
            KeywordsTable table = tables.keywords();
            Optional<Keyword> optionalKeyword = table
                    .selectByIdQuery(keyword.getId())
                    .fetchOptionalInto(Keyword.class);
            if (optionalKeyword.isPresent()) {
                table.updateById(keyword).execute();
            } else {
                table.insertById(keyword).execute();
            }
            return keyword;
        });
    }

    /** Returns all models and all segments in a single view. */
    @Override
    public WholeView getAll() {
        return WholeView.newBuilder()
                .addAllModels(models().getAll())
                .addAllSegments(segments().getAll())
                .build();
    }

    /**
     * Builds the Direct taxonomy.
     *
     * <p>The static JSON resource is read and its {@code "interests"} field is set to an object built from
     * segments. A segment is included when all of the following hold:
     * <ul>
     *   <li>its type is {@code INTEREST};</li>
     *   <li>both tanker keys are non-empty;</li>
     *   <li>it has at least one LONGTERM or SHORTTERM export.</li>
     * </ul>
     *
     * @return the taxonomy JSON
     * @throws NullPointerException if the classpath resource is missing
     */
    @Override
    public JsonNode getDirectTaxonomy() {
        ObjectNode taxonomy = (ObjectNode) readDirectTaxonomyResource();
        JsonNodeFactory factory = JsonNodeFactory.instance;
        ObjectNode interests = factory.objectNode();

        for (Segment segment : segments().getAll()) {
            if (!hasTankerKeys(segment) || segment.getType() != Segment.Type.INTEREST) {
                continue;
            }
            ObjectNode types = buildExportTypes(segment);
            if (types.size() == 0) {
                continue;
            }

            ObjectNode interest = factory.objectNode();
            interest.put("tanker_name_key", segment.getTankerNameKey());
            interest.put("tanker_description_key", segment.getTankerDescriptionKey());
            String parentId = segment.getParentId();
            // Parent ids carrying the "group-" prefix are emitted as JSON null.
            interest.put("parent_id", parentId.startsWith(GROUP_PARENT_PREFIX) ? null : parentId);
            interest.set("types", types);
            interests.set(segment.getId(), interest);
        }

        taxonomy.set("interests", interests);
        return taxonomy;
    }

    /** Reads and parses the classpath taxonomy resource, closing the stream afterwards. */
    private JsonNode readDirectTaxonomyResource() {
        InputStream inputStream = Objects.requireNonNull(
                getClass().getResourceAsStream(DIRECT_TAXONOMY_RESOURCE),
                "Classpath resource " + DIRECT_TAXONOMY_RESOURCE + " not found");
        // The "\\A" delimiter makes the scanner return the whole stream as a single token.
        // Closing the scanner also closes the underlying resource stream.
        try (Scanner scanner = new Scanner(inputStream, StandardCharsets.UTF_8.name()).useDelimiter("\\A")) {
            return OBJECT_MAPPER.readTree(scanner.hasNext() ? scanner.next() : "");
        } catch (IOException e) {
            throw Exceptions.unchecked(e);
        }
    }

    /** Maps a segment's LONGTERM/SHORTTERM exports to {"long"|"short": "keywordId:segmentId"}. */
    private static ObjectNode buildExportTypes(Segment segment) {
        ObjectNode types = JsonNodeFactory.instance.objectNode();
        for (Segment.Export export : segment.getExports().getExportsList()) {
            String typeValue = String.format("%d:%d", export.getKeywordId(), export.getSegmentId());
            if (export.getType() == Segment.Export.Type.LONGTERM) {
                types.put("long", typeValue);
            } else if (export.getType() == Segment.Export.Type.SHORTTERM) {
                types.put("short", typeValue);
            }
        }
        return types;
    }

    /** True when both tanker keys of the segment are non-null and non-empty. */
    private static boolean hasTankerKeys(Segment segment) {
        return !isNullOrEmpty(segment.getTankerNameKey()) && !isNullOrEmpty(segment.getTankerDescriptionKey());
    }

    /** True when {@code value} is null or the empty string. */
    private static boolean isNullOrEmpty(String value) {
        return value == null || value.isEmpty();
    }

    /**
     * Returns all socdem thresholds grouped as socdem -> (segment -> threshold).
     *
     * @return a new mutable nested map
     */
    @Override
    public Map<String, Map<String, Double>> getSocdemThresholds() {
        List<SocdemThreshold> socdemThresholds = tables()
                .socdemThresholds()
                .selectQuery()
                .fetchInto(SocdemThreshold.class);
        Map<String, Map<String, Double>> result = Maps.newHashMap();
        for (SocdemThreshold each : socdemThresholds) {
            result.computeIfAbsent(each.getSocdem(), k -> Maps.newHashMap())
                    .put(each.getSegment(), each.getThreshold());
        }
        return result;
    }

    /**
     * Inserts or updates a socdem threshold inside one SQL transaction.
     *
     * @param socdemThreshold threshold to store
     * @return the row as re-read from the table after the write
     */
    @Override
    public SocdemThreshold updateSocdemThreshold(SocdemThreshold socdemThreshold) {
        return withSqlTransaction(tables -> {
            SocdemThresholdsTable table = tables.socdemThresholds();
            Optional<SocdemThreshold> optionalSocdemSegment = table
                    .selectByIdQuery(socdemThreshold)
                    .fetchOptionalInto(SocdemThreshold.class);
            if (optionalSocdemSegment.isPresent()) {
                table.updateQuery(socdemThreshold).execute();
            } else {
                table.insertQuery(socdemThreshold).execute();
            }

            return table.selectByIdQuery(socdemThreshold).fetchOneInto(SocdemThreshold.class);
        });
    }

    @Override
    public AudienceService audiences() {
        return prepare(injectionManager.createAndInitialize(DefaultAudienceService.class));
    }

    @Override
    public ConstructorService constructor() {
        return prepare(injectionManager.createAndInitialize(DefaultConstructorService.class));
    }

    @Override
    public SampleService samples() {
        return prepare(injectionManager.createAndInitialize(DefaultSampleService.class));
    }

    @Override
    public TrainingSampleService trainingSamples()  {
        return prepare(injectionManager.createAndInitialize(DefaultTrainingSampleService.class));
    }

    /** Unlike the other accessors, the returned service is not passed through {@link #prepare}. */
    @Override
    public Site2VecService site2vec() {
        return injectionManager.createAndInitialize(DefaultSite2VecService.class);
    }

    @Override
    public SegmentService segments() {
        return prepare(injectionManager.createAndInitialize(DefaultSegmentService.class));
    }

    @Override
    public SegmentExportService segmentExports() {
        return prepare(injectionManager.createAndInitialize(DefaultSegmentExportService.class));
    }

    @Override
    public ModelService models() {
        return prepare(injectionManager.createAndInitialize(DefaultModelService.class));
    }

    /** Propagates this service's language and security context to a freshly created sub-service. */
    private <T extends WithOptions<T>> T prepare(T service) {
        return service
                .withLanguage(language())
                .withSecurityContext(securityContext());
    }

    /**
     * Compares target statistics with base statistics.
     *
     * <p>The result contains counts for the target, plus segment, socdem, host, word and app affinities.
     */
    private TSimpleSampleStats getSimpleSampleStats(TUserDataStats baseUserStats, TUserDataStats targetUserStats) {
        TSimpleSampleStats.Builder result = TSimpleSampleStats.newBuilder();

        // One sub-service instance for both lookups; segments() builds a new one on every call.
        SegmentService segmentService = segments();
        Map<Long, Map<Long, Segment>> exportsMapping = segmentService.getExportsToSimpleSegments();
        Map<String, List<SegmentGroup>> parents = segmentService.getParentsPerSegment();

        computeCounts(result, targetUserStats);

        affinities.computeForSegments(result, targetUserStats, baseUserStats, exportsMapping, parents);
        affinities.computeForSocdem(result, targetUserStats, baseUserStats);
        affinities.computeForTokens(() -> result.getAffinitiesBuilder().addByHostBuilder(),
                targetUserStats.getAffinities().getHosts(), baseUserStats.getAffinities().getHosts());
        affinities.computeForTokens(() -> result.getAffinitiesBuilder().addByWordBuilder(),
                targetUserStats.getAffinities().getWords(), baseUserStats.getAffinities().getWords());
        affinities.computeForTokens(() -> result.getAffinitiesBuilder().addByAppBuilder(),
                targetUserStats.getAffinities().getApps(), baseUserStats.getAffinities().getApps());

        return result.build();
    }

    /** Stats of a Siberia user set compared against the default global statistics. */
    @Override
    public Optional<TSimpleSampleStatsWithInfo> getStatsFromSiberia(String userSetId) {
        return processStatsFromSiberia(samples().getStatsFromSiberia(userSetId));
    }

    /**
     * Stats of the Siberia user set {@code targetBaseIds.getKey()}, compared against the user set
     * {@code targetBaseIds.getValue()}.
     */
    @Override
    public Optional<TSimpleSampleStatsWithInfo> getStatsFromSiberia(Pair<String, String> targetBaseIds) {
        var targetStats = samples().getStatsFromSiberia(targetBaseIds.getKey());
        var baseStats = samples().getStatsFromSiberia(targetBaseIds.getValue());

        return processStatsFromSiberiaWithBaseSample(targetStats, baseStats);
    }

    /** Stats of a Siberia user set compared against the given (or default) global statistics version. */
    @Override
    public Optional<TSimpleSampleStatsWithInfo> getStatsFromSiberia(String userSetId, Optional<String> customGlobalStatsVersion) {
        return processStatsFromSiberia(samples().getStatsFromSiberia(userSetId), customGlobalStatsVersion);
    }

    /** Stats of a Siberia user set fetched on behalf of {@code login}, compared against global statistics. */
    @Override
    public Optional<TSimpleSampleStatsWithInfo> getStatsFromSiberia(String userSetId, String login) {
        return processStatsFromSiberia(samples().getStatsFromSiberia(userSetId, login));
    }

    /** Resolves the Siberia user set id stored for the segment pair {@code <segmentId>__<exportId>}. */
    @Override
    public String getUserSetIdByExportId(String exportId) {
        SegmentService segmentService = segments();
        String segmentId = segmentService.getByExportId(exportId).getId();
        return segmentService.getSegmentPairBySegmentLabId(segmentId + "__" + exportId).getUserSetId();
    }

    /** Custom-audience stats for an extended rule vs. global stats; empty if the client returns nothing. */
    @Override
    public Optional<TSimpleSampleStats> getCustomAudienceStats(TExtendedCaRule rule) {
        var targetStats = customAudienceClient.getStatsByExtendedRule(rule);
        if (!targetStats.isPresent()) {
            return Optional.empty();
        }

        TUserDataStats baseStats = getGlobalStatsUsingCache();
        return Optional.of(getSimpleSampleStats(baseStats, targetStats.get()));
    }

    /** Custom-audience stats for a rule vs. global stats; empty if the client returns nothing. */
    @Override
    public Optional<TSimpleSampleStats> getCustomAudienceStats(TCaRule rule) {
        var targetStats = customAudienceClient.getStats(rule);
        if (!targetStats.isPresent()) {
            return Optional.empty();
        }

        TUserDataStats baseStats = getGlobalStatsUsingCache();
        return Optional.of(getSimpleSampleStats(baseStats, targetStats.get()));
    }

    private Optional<TSimpleSampleStatsWithInfo> processStatsFromSiberia(Optional<TStats> siberiaStatsOptional) {
        return processStatsFromSiberia(siberiaStatsOptional, Optional.empty());
    }

    /**
     * Compares target Siberia stats against base Siberia stats.
     *
     * <p>The result is empty when either input is absent. Info is taken from the target.
     */
    private Optional<TSimpleSampleStatsWithInfo> processStatsFromSiberiaWithBaseSample(Optional<TStats> targetStatsOptional, Optional<TStats> baseStatsOptional) {
        if (!targetStatsOptional.isPresent() || !baseStatsOptional.isPresent()) {
            return Optional.empty();
        }

        var targetStats = targetStatsOptional.get();
        var targetUserDataStats = targetStats.getUserDataStats();

        var baseUserDataStats = baseStatsOptional.get().getUserDataStats();

        checkCanViewStats(targetUserDataStats);
        checkCanViewStats(baseUserDataStats, "Base sample is too small");

        var simpleSampleStatsWithInfo = TSimpleSampleStatsWithInfo.newBuilder()
                .setStats(getSimpleSampleStats(baseUserDataStats, targetUserDataStats))
                .setInfo(targetStats.getInfo())
                .build();

        return Optional.of(simpleSampleStatsWithInfo);
    }

    /** Compares Siberia stats against the (optionally versioned) cached global statistics. */
    private Optional<TSimpleSampleStatsWithInfo> processStatsFromSiberia(Optional<TStats> siberiaStatsOptional, Optional<String> customGlobalStatsVersion) {
        if (!siberiaStatsOptional.isPresent()) {
            return Optional.empty();
        }
        var siberiaStats = siberiaStatsOptional.get();
        var targetStats = siberiaStats.getUserDataStats();

        checkCanViewStats(targetStats);

        TUserDataStats baseStats = getGlobalStatsUsingCache(customGlobalStatsVersion);

        var simpleSampleStatsWithInfo = TSimpleSampleStatsWithInfo.newBuilder()
                .setStats(getSimpleSampleStats(baseStats, targetStats))
                .setInfo(siberiaStats.getInfo())
                .build();

        return Optional.of(simpleSampleStatsWithInfo);
    }

    /** Loads the user-data stats of a sample group; throws "not found" if absent and checks view permission. */
    private TUserDataStats getUserDataStats(TSampleGroupID target) {
        TUserDataStats targetStats = samples()
                .getStats(target.getSampleID(), target.getGroupID())
                .orElseThrow(Exceptions::notFound);
        checkCanViewStats(targetStats);
        return targetStats;
    }

    /** Sample group stats compared against the default global statistics. */
    @Override
    public Optional<TSimpleSampleStats> getStats(TSampleGroupID target) {
        TUserDataStats baseStats = getGlobalStatsUsingCache();
        return Optional.of(getSimpleSampleStats(baseStats, getUserDataStats(target)));
    }

    /**
     * Sample group stats compared against the Siberia stats of the base sample's user set.
     *
     * <p>NOTE(review): only {@code base.getSampleID()} is used; {@code base.getGroupID()} is ignored.
     * Confirm this is intended.
     */
    @Override
    public Optional<TSimpleSampleStats> getStats(TSampleGroupID target, TSampleGroupID base) {
        Sample baseSample = samples().getSample(base.getSampleID());
        TStats baseStats = samples().getStatsFromSiberia(baseSample.getSiberiaUserSetId()).orElseThrow(Exceptions::notFound);
        return Optional.of(getSimpleSampleStats(baseStats.getUserDataStats(), getUserDataStats(target)));
    }

    /** Sample group stats compared against the Siberia user set of the given segment export. */
    @Override
    public Optional<TSimpleSampleStats> getStatsByExportId(TSampleGroupID target, String baseExportId) {
        String userSetId = getUserSetIdByExportId(baseExportId);
        var baseStats = samples().getStatsFromSiberia(userSetId).orElseThrow(Exceptions::notFound);
        return Optional.of(getSimpleSampleStats(baseStats.getUserDataStats(), getUserDataStats(target)));
    }

    /** Copies per-attribute counts into {@code result}, resolving region names in the current language. */
    private void computeCounts(TSimpleSampleStats.Builder result, TUserDataStats stats) {
        TUserDataStats.TAttributesStats attributes = stats.getAttributes();
        TSimpleSampleStats.TCounts.Builder counts = result.getCountsBuilder();
        attributes.getAgeList().forEach(counts::addByAge);
        attributes.getGenderList().forEach(counts::addByGender);
        attributes.getIncomeList().forEach(counts::addByIncome);
        attributes.getDeviceList().forEach(counts::addByDevice);
        attributes.getGenderAgeIncomeList().forEach(counts::addByGenderAgeIncome);
        attributes.getRegionList().forEach(each -> counts
                .addByRegionBuilder()
                .setName(geo.getLinguistics(each.getRegion(), language()).getNominativeCase())
                .setID(each.getRegion())
                .setCount(each.getCount())
        );
    }

    /** Reads the default global statistics directly from YT (bypasses the cache). */
    @Override
    public TUserDataStats getGlobalStats() {
        return getGlobalStats(Optional.empty());
    }

    /**
     * Reads global statistics from YT.
     *
     * <p>The custom user-data stats table for {@code customVersion} is used if one is given; otherwise the
     * default global sample stats table is read. Only the first row of the table is returned.
     */
    private TUserDataStats getGlobalStats(Optional<String> customVersion) {
        var path = customVersion.map(paths::customUserDataStats).orElseGet(paths::globalSampleStats);
        // TODO use OOB protobuf reading
        ListF<TUserDataStats> stats = ytService().readTableYson(path, entry -> {
            TUserDataStats.Builder builder = TUserDataStats.newBuilder();
            TUserDataStats.getDescriptor()
                    .getFields()
                    .forEach(field -> Proto.readField(field, entry, builder));
            return builder.build();
        });
        return stats.first();
    }

    /** Compares {@code userDataStats} against the cached default global statistics. */
    @Override
    public TSimpleSampleStats getSimpleSampleStats(TUserDataStats userDataStats) {
        return getSimpleSampleStats(getGlobalStatsUsingCache(), userDataStats);
    }

    /** Selects the name, description and export ids of segments whose exports carry the given tag. */
    private Result<Record4<String, String, Long, Long>> getSegmentsByTag(String tag) {
        return sqlDsl().select(SegmentsTable.NAME_RU, SegmentsTable.DESCRIPTION_RU, SegmentExportsTable.EXPORT_SEGMENT_ID, SegmentExportsTable.EXPORT_KEYWORD_ID)
                .from(SegmentsTable.TABLE
                        .rightJoin(SegmentExportsTable.TABLE
                                .rightJoin(SegmentExportsTagsTable.TABLE)
                                .on(SegmentExportsTable.ID.eq(SegmentExportsTagsTable.SEGMENT_EXPORT_ID))
                        ).on(SegmentsTable.ID.eq(SegmentExportsTable.SEGMENT_ID))
                ).where(SegmentExportsTagsTable.TAG.eq(tag)).fetch();
    }

    /**
     * Returns metadata of the segments whose exports are tagged {@code "aletcrm"}.
     *
     * @return one map per row with the keys {@code name}, {@code description}, {@code yt_field} and
     *         {@code segment_id}
     */
    @Override
    public List<Map<String, String>> getSegmentsMetaForMediaservices() {
        Result<Record4<String, String, Long, Long>> segments = getSegmentsByTag("aletcrm");
        return segments.stream().map(segment -> {
            Map<String, String> meta = new HashMap<>();
            meta.put("name", segment.get(SegmentsTable.NAME_RU));
            meta.put("description", segment.get(SegmentsTable.DESCRIPTION_RU));
            // Narrow the Long keyword id directly instead of round-tripping through its string form.
            meta.put("yt_field", SegmentYtField.get(Math.toIntExact(segment.get(SegmentExportsTable.EXPORT_KEYWORD_ID))));
            meta.put("segment_id", segment.get(SegmentExportsTable.EXPORT_SEGMENT_ID).toString());
            return meta;
        }).collect(Collectors.toList());
    }

    private TUserDataStats getGlobalStatsUsingCache() {
        return getGlobalStatsUsingCache(Optional.empty());
    }

    /** Fetches global statistics for the given version via the shared cache, loading from YT on a miss. */
    private TUserDataStats getGlobalStatsUsingCache(Optional<String> customVersion) {
        Callable<TUserDataStats> loader = () -> getGlobalStats(customVersion);
        String cacheKey = customVersion.orElse(DEFAULT_GLOBAL_STATS_CACHE_KEY);
        return Caching.fetch(GLOBAL_STATS_CACHE, cacheKey, loader);
    }

    /** Creates a new instance with the same dependencies; the static global stats cache remains shared. */
    @Override
    public LabService clone() {
        return new DefaultLabService(injectionManager, environment(), ytService(), sql(), geo.getGeobase(), customAudienceClient);
    }

    /**
     * Reads the {@link TGroupStats} row for {@code groupId} from a YT table.
     *
     * <p>The row's stats are compared against the (optionally versioned) global statistics.
     *
     * @return empty when no row matches
     * @throws RuntimeException (from {@code Exceptions.illegal}) when more than one row matches, or when the
     *         caller may not view the stats
     */
    @Override
    public Optional<TSimpleSampleStats> getStatsFromYt(YPath path, String groupId, Optional<String> customUserDataStatsVersion) {
        List<TGroupStats> stats = new ArrayList<>();

        // Exact key-range read on groupId (presumably the table is keyed by group id -- TODO confirm).
        ytService().getHahn().tables().read(
                path.withExact(
                        new RangeLimit(Cf.list(YTree.stringNode(groupId)), -1, -1)
                ),
                YTableEntryTypes.nativeProto(TGroupStats.newBuilder()),
                (Consumer<TGroupStats>) stats::add
        );

        if (stats.isEmpty()) {
            return Optional.empty();
        }

        if (stats.size() > 1) {
            // MessageFormat placeholders must be indexed ({0}). A bare "{}" makes MessageFormat itself throw
            // IllegalArgumentException, which would hide the intended error.
            throw Exceptions.illegal(MessageFormat.format("More than one result for given group id {0}", groupId));
        }

        var targetStats = stats.get(0).getStats();

        checkCanViewStats(targetStats);

        var baseStats = getGlobalStatsUsingCache(customUserDataStatsVersion);

        return Optional.of(getSimpleSampleStats(baseStats, targetStats));
    }

    private void checkCanViewStats(TUserDataStats stats) {
        checkCanViewStats(stats, "Sample is too small");
    }

    /** Throws {@code Exceptions.illegal(msg)} unless the ACL allows the current security context to view the stats. */
    private void checkCanViewStats(TUserDataStats stats, String msg) {
        if (!Acl.canViewStats(stats, securityContext())) {
            throw Exceptions.illegal(msg);
        }
    }
}
