package ru.yandex.travel.orders.services.finances.billing;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.time.Duration;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.base.Preconditions;
import com.google.common.collect.Sets;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.javamoney.moneta.Money;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.inside.yt.kosher.Yt;
import ru.yandex.inside.yt.kosher.common.GUID;
import ru.yandex.inside.yt.kosher.cypress.LockMode;
import ru.yandex.inside.yt.kosher.cypress.YPath;
import ru.yandex.inside.yt.kosher.impl.ytree.builder.YTree;
import ru.yandex.inside.yt.kosher.tables.YTableEntryTypes;
import ru.yandex.inside.yt.kosher.ytree.YTreeNode;
import ru.yandex.travel.commons.logging.NestedMdc;
import ru.yandex.travel.orders.entities.Order;
import ru.yandex.travel.orders.entities.finances.BillingTransaction;
import ru.yandex.travel.orders.entities.finances.BillingTransactionKind;
import ru.yandex.travel.orders.entities.finances.BillingTransactionPaymentSystemType;
import ru.yandex.travel.orders.entities.finances.BillingTransactionPaymentType;
import ru.yandex.travel.orders.entities.finances.BillingTransactionType;
import ru.yandex.yt.ytclient.proxy.request.CreateNode;
import ru.yandex.yt.ytclient.proxy.request.ObjectType;
import ru.yandex.yt.ytclient.proxy.request.TransactionalOptions;

import static java.util.stream.Collectors.toSet;
import static ru.yandex.travel.orders.services.finances.billing.BillingHelper.BILLING_DATE_TIME_FORMAT;
import static ru.yandex.travel.orders.services.finances.billing.BillingHelper.BILLING_TIME_ZONE;

/**
 * Exports billing transactions to daily YT tables.
 * <p>
 * For every {@link BillingTransactionKind} a separate table named after the transactions date is used
 * (PAYMENT and INCOME tables live in different directories). Each table stores the id of the last exported
 * transaction in the {@code travel_last_exported_tx_id} attribute; transactions whose YT id is not greater
 * than that value are skipped on subsequent exports.
 */
@RequiredArgsConstructor
@Slf4j
public class BillingTransactionYtTableClient {
    private static final ObjectMapper JSON_MAPPER = new ObjectMapper();

    private static final boolean FLAG_REQUIRED = false;
    private static final boolean FLAG_OPTIONAL = true;
    private static final boolean FLAG_DO_NOT_PING_ANCESTOR_TRANSACTION = false;
    private static final boolean FLAG_NOT_RECURSIVE = false;
    private static final boolean FLAG_DO_NOT_IGNORE_EXISTING = false;
    private static final String ATTRIBUTE_LAST_EXPORTED_TX_ID = "travel_last_exported_tx_id";

    private final BillingTransactionYtTableClientProperties properties;
    private final Yt yt;

    /**
     * Exports the transactions inside a new YT transaction that is committed on success and aborted on failure.
     *
     * @param transactionsDate the payout date all the transactions must belong to
     * @param transactions     a non-empty set of transactions to export
     */
    public void exportTransactions(LocalDate transactionsDate, Collection<BillingTransaction> transactions) {
        doInTx(txId -> exportTransactions(txId, transactionsDate, transactions));
    }

    /**
     * Exports the transactions within the given YT transaction.
     * <p>
     * For every transaction kind: ensures the destination table exists, takes an exclusive lock on it,
     * appends the transactions of that kind which are newer than the last exported id, and stores the new
     * last exported id in the table attribute.
     *
     * @param txId             the YT transaction to run all operations in
     * @param transactionsDate the payout date all the transactions must belong to
     * @param transactions     a non-empty set of transactions to export
     * @throws IllegalArgumentException if the set is empty, YT ids are missing/duplicated
     *                                  or a payout date does not match {@code transactionsDate}
     */
    public void exportTransactions(GUID txId, LocalDate transactionsDate, Collection<BillingTransaction> transactions) {
        log.info("Starting exporting {} billing transactions to YT", transactions.size());
        Preconditions.checkArgument(!transactions.isEmpty(),
                "A not empty transaction set is expected; date %s", transactionsDate);

        checkIds(transactions);
        checkTransactionsDate(transactionsDate, transactions);
        for (BillingTransactionKind kind : BillingTransactionKind.values()) {
            YPath table = getDestinationTablePath(transactionsDate, kind);

            ensureTransactionsTableExists(txId, table, kind);
            obtainExclusiveTableLock(txId, table);
            Long lastExportedTransactionId = readLastTransactionId(txId, table);
            Collection<BillingTransaction> newTransactions =
                    filterTransactions(transactions, lastExportedTransactionId, kind);
            if (newTransactions.isEmpty()) {
                log.warn("All {} transactions have been filtered out by the last exported transaction id - {}",
                        kind, lastExportedTransactionId);
                // nothing to do for this kind, but the remaining kinds still have to be exported
                continue;
            }
            // The table for the release day should have already been created with the service_id column;
            // all the previous tables are to be migrated on the next day.
            // (keep the commented out code as a template for future migrations)
            //boolean withServiceId = checkColumnExists(txId, table, "service_id");
            boolean withServiceId = true;
            writeTransactions(txId, table, newTransactions, withServiceId);
            long newLastTxId = newTransactions.stream().mapToLong(BillingTransaction::getYtId).max().orElseThrow();
            writeLastTransactionId(txId, table, newLastTxId);
        }
    }

    /**
     * Validates that every transaction (and its original transaction, if any) has a YT id
     * and that the YT ids are unique within the set.
     *
     * @throws IllegalArgumentException on the first violation found
     */
    void checkIds(Collection<BillingTransaction> transactions) {
        Set<Long> ids = new HashSet<>();
        for (BillingTransaction transaction : transactions) {
            if (transaction.getYtId() == null) {
                throw new IllegalArgumentException("Transaction without YT id: " + transaction);
            }
            BillingTransaction origTx = transaction.getOriginalTransaction();
            if (origTx != null && origTx.getYtId() == null) {
                throw new IllegalArgumentException("Original transaction without YT id: " + transaction);
            }
            if (!ids.add(transaction.getYtId())) {
                throw new IllegalArgumentException("Duplicate transaction YT id: " + transaction.getYtId());
            }
        }
    }

    /**
     * Validates that the billing-date of every transaction's payout moment equals the expected date.
     *
     * @throws IllegalArgumentException on the first mismatching transaction
     */
    void checkTransactionsDate(LocalDate expectedTransactionDate, Collection<BillingTransaction> transactions) {
        for (BillingTransaction transaction : transactions) {
            LocalDate payoutDate = BillingHelper.toBillingDate(transaction.getPayoutAt()).toLocalDate();
            if (!expectedTransactionDate.equals(payoutDate)) {
                throw new IllegalArgumentException(String.format(
                        "Payout date mismatch for %s: expected date: %s, actual: %s",
                        transaction.getDescription(), expectedTransactionDate, payoutDate
                ));
            }
        }
    }

    /**
     * Returns the path of the daily table for the given kind: {@code <kind directory>/<yyyy-MM-dd>}.
     * INCOME tables are stored in the income tables directory, PAYMENT tables in the regular one.
     */
    YPath getDestinationTablePath(LocalDate transactionsDate, BillingTransactionKind kind) {
        // NOTE(review): any other kind falls back to an empty directory, as before; confirm whether
        // such kinds can ever reach this method or whether an explicit failure would be preferable
        String tablesDirectory = "";
        if (kind == BillingTransactionKind.INCOME) {
            tablesDirectory = properties.getIncomeTablesDirectory();
        } else if (kind == BillingTransactionKind.PAYMENT) {
            tablesDirectory = properties.getTablesDirectory();
        }
        return YPath.simple(tablesDirectory).child(transactionsDate.toString());
    }

    /**
     * Starts a YT transaction with the configured duration, runs the action in it and commits it.
     * If anything fails, an abort is attempted and the original exception is rethrown
     * (an abort failure is attached to it as a suppressed exception).
     */
    void doInTx(Consumer<GUID> action) {
        Duration txDuration = properties.getTransactionDuration();
        GUID txId = yt.transactions().start(txDuration);
        try {
            log.info("Running yt operations in the [{}] transaction", txId);
            action.accept(txId);
            yt.transactions().commit(txId);
        } catch (Exception e) {
            try {
                yt.transactions().abort(txId);
            } catch (Exception e2) {
                log.warn("Failed to abort the current transaction", e2);
                e.addSuppressed(e2);
            }
            throw e;
        }
    }

    /**
     * Creates the table with the kind-specific schema within the transaction unless it already exists.
     * Parent directories are not created (the creation is not recursive).
     */
    void ensureTransactionsTableExists(GUID txId, YPath path, BillingTransactionKind kind) {
        if (yt.cypress().exists(Optional.of(txId), FLAG_DO_NOT_PING_ANCESTOR_TRANSACTION, path)) {
            log.info("Table [{}] exists", path);
            return;
        }
        log.info("Creating a new transactions table: {}", path);
        yt.cypress().create(new CreateNode(path, ObjectType.Table)
                .setAttributes(Map.of("schema", getTransactionsTableSchema(kind)))
                .setTransactionalOptions(new TransactionalOptions(txId))
                .setRecursive(FLAG_NOT_RECURSIVE)
                .setIgnoreExisting(FLAG_DO_NOT_IGNORE_EXISTING)
                .setForce(false));
    }

    /**
     * Takes an exclusive lock on the table within the given transaction.
     */
    private void obtainExclusiveTableLock(GUID txId, YPath tablePath) {
        log.info("Trying to obtain a lock for {} with txId {}", tablePath, txId);
        yt.cypress().lock(txId, tablePath, LockMode.EXCLUSIVE);
    }

    /**
     * Returns the names of the columns declared in the table schema of the given kind.
     */
    Set<String> getTransactionsTableColumns(BillingTransactionKind kind) {
        return getTransactionsTableSchema(kind).asList().stream()
                .map(cd -> cd.asMap().get("name").stringValue())
                .collect(toSet());
    }

    /**
     * Builds the YT table schema for the given kind; an empty schema is returned for any other kind.
     */
    private YTreeNode getTransactionsTableSchema(BillingTransactionKind kind) {
        if (kind == BillingTransactionKind.PAYMENT) {
            return YTree.listBuilder()
                    .value(ytColumnDef("service_id", "int64"))
                    .value(ytColumnDef("transaction_id", "int64"))
                    .value(ytColumnDef("orig_transaction_id", "int64", FLAG_OPTIONAL))
                    .value(ytColumnDef("transaction_type", "string"))
                    .value(ytColumnDef("payment_type", "string"))
                    .value(ytColumnDef("paysys_type_cc", "string"))
                    .value(ytColumnDef("partner_id", "int64"))
                    .value(ytColumnDef("price", "string"))
                    .value(ytColumnDef("currency", "string"))
                    .value(ytColumnDef("trust_payment_id", "string"))
                    .value(ytColumnDef("client_id", "int64"))
                    .value(ytColumnDef("service_order_id", "string"))
                    .value(ytColumnDef("dt", "string"))
                    .value(ytColumnDef("update_dt", "string"))
                    .buildList();
        }
        if (kind == BillingTransactionKind.INCOME) {
            return YTree.listBuilder()
                    .value(ytColumnDef("service_id", "int64"))
                    .value(ytColumnDef("transaction_id", "int64"))
                    .value(ytColumnDef("orig_transaction_id", "int64", FLAG_OPTIONAL))
                    .value(ytColumnDef("transaction_type", "string"))
                    .value(ytColumnDef("service_order_id", "string"))
                    .value(ytColumnDef("dt", "string"))
                    .value(ytColumnDef("amount", "string"))
                    .value(ytColumnDef("currency", "string"))
                    .value(ytColumnDef("client_id", "int64"))
                    .buildList();
        }
        return YTree.listBuilder().buildList();
    }

    /**
     * Reads the last exported transaction id stored in the table attribute.
     *
     * @return the stored id, or {@code null} if the attribute does not exist yet
     */
    Long readLastTransactionId(GUID txId, YPath yPath) {
        YPath lastIdPath = yPath.attribute(ATTRIBUTE_LAST_EXPORTED_TX_ID);
        if (!yt.cypress().exists(Optional.of(txId), FLAG_DO_NOT_PING_ANCESTOR_TRANSACTION, lastIdPath)) {
            return null;
        }
        return yt.cypress().get(Optional.of(txId), FLAG_DO_NOT_PING_ANCESTOR_TRANSACTION, lastIdPath, Cf.set()).longValue();
    }

    /**
     * Stores the new last exported transaction id in the table attribute.
     */
    void writeLastTransactionId(GUID txId, YPath yPath, long newLastExportedTxId) {
        log.info("Storing a new value of the '{}' attribute: {}", ATTRIBUTE_LAST_EXPORTED_TX_ID, newLastExportedTxId);
        YPath lastIdPath = yPath.attribute(ATTRIBUTE_LAST_EXPORTED_TX_ID);
        yt.cypress().set(txId, FLAG_DO_NOT_PING_ANCESTOR_TRANSACTION, lastIdPath, newLastExportedTxId);
    }

    // for dev & testing purposes
    Map<Long, BillingTransaction> readTransactions(GUID txId, YPath yPath) {
        return readTransactions(txId, yPath, true, BillingTransactionKind.PAYMENT);
    }

    Map<Long, BillingTransaction> readTransactions(GUID txId, YPath yPath, BillingTransactionKind kind) {
        return readTransactions(txId, yPath, true, kind);
    }

    Map<Long, BillingTransaction> readTransactions(GUID txId, YPath yPath, boolean validateSchema) {
        return readTransactions(txId, yPath, validateSchema, BillingTransactionKind.PAYMENT);
    }

    /**
     * Reads all the rows of the table back into transactions, keyed by YT transaction id in table order.
     *
     * @param validateSchema when set, every row must have exactly the columns of the schema for {@code kind}
     * @param kind           the kind of the table; determines which columns are parsed
     * @throws IllegalStateException on an unexpected row schema or a duplicate transaction id
     */
    Map<Long, BillingTransaction> readTransactions(GUID txId, YPath yPath, boolean validateSchema,
                                                   BillingTransactionKind kind) {
        Map<Long, BillingTransaction> transactions = new LinkedHashMap<>();
        Set<String> expectedColumns = getTransactionsTableColumns(kind);
        yt.tables().read(Optional.of(txId), FLAG_DO_NOT_PING_ANCESTOR_TRANSACTION, yPath,
                YTableEntryTypes.JACKSON).forEachRemaining(row -> {
            if (validateSchema) {
                Set<String> columns = Sets.newHashSet(row.fieldNames());
                Preconditions.checkState(columns.equals(expectedColumns), "Unexpected table schema: %s", columns);
            }
            long id = row.get("transaction_id").asLong();
            // Guava Preconditions only substitutes %s placeholders
            Preconditions.checkState(!transactions.containsKey(id), "Duplicate transaction id: %s", id);
            BillingTransaction origTx = row.hasNonNull("orig_transaction_id") ?
                    BillingTransaction.builder().ytId(row.get("orig_transaction_id").asLong()).build() : null;
            // (keep the commented out code as a template for future migrations)
            //Long serviceId = row.hasNonNull("service_id") ?
            //        Long.parseLong(row.get("service_id").asText()) : null;
            BillingTransaction transaction = kind == BillingTransactionKind.INCOME
                    ? parseIncomeRow(row, id, origTx)
                    : parsePaymentRow(row, id, origTx, kind);
            transactions.put(id, transaction);
        });
        return transactions;
    }

    /**
     * Builds a transaction from a row of a PAYMENT-schema table.
     */
    private BillingTransaction parsePaymentRow(JsonNode row, long id, BillingTransaction origTx,
                                               BillingTransactionKind kind) {
        return BillingTransaction.builder()
                .serviceId(row.get("service_id").asLong())
                .ytId(id)
                .originalTransaction(origTx)
                .transactionType(BillingTransactionType.forValue(row.get("transaction_type").asText()))
                .paymentType(BillingTransactionPaymentType.forValue(row.get("payment_type").asText()))
                .paymentSystemType(BillingTransactionPaymentSystemType.forValue(row.get("paysys_type_cc").asText()))
                .partnerId(row.get("partner_id").asLong())
                .serviceOrderId(row.get("service_order_id").asText())
                .createdAt(parseTxDt(row.get("update_dt").asText()))
                .payoutAt(parseTxDt(row.get("dt").asText()))
                .value(Money.of(new BigDecimal(row.get("price").asText()), row.get("currency").asText()))
                .trustPaymentId(row.get("trust_payment_id").asText())
                .clientId(row.get("client_id").asLong())
                .kind(kind)
                .build();
    }

    /**
     * Builds a transaction from a row of an INCOME-schema table (the value is stored in the "amount" column;
     * there are no payment type, payment system, partner, update time or trust payment columns).
     */
    private BillingTransaction parseIncomeRow(JsonNode row, long id, BillingTransaction origTx) {
        return BillingTransaction.builder()
                .serviceId(row.get("service_id").asLong())
                .ytId(id)
                .originalTransaction(origTx)
                .transactionType(BillingTransactionType.forValue(row.get("transaction_type").asText()))
                .serviceOrderId(row.get("service_order_id").asText())
                .payoutAt(parseTxDt(row.get("dt").asText()))
                .value(Money.of(new BigDecimal(row.get("amount").asText()), row.get("currency").asText()))
                .clientId(row.get("client_id").asLong())
                .kind(BillingTransactionKind.INCOME)
                .build();
    }

    /**
     * Returns whether the table schema contains a column with the given name.
     */
    @SuppressWarnings("SameParameterValue")
    boolean checkColumnExists(GUID txId, YPath path, String columnName) {
        List<YTreeNode> columns = yt.cypress()
                .get(Optional.of(txId), FLAG_DO_NOT_PING_ANCESTOR_TRANSACTION, path, Cf.set("schema"))
                .getAttributeOrThrow("schema").asList();
        for (YTreeNode column : columns) {
            String name = column.asMap().get("name").stringValue();
            if (columnName.equals(name)) {
                return true;
            }
        }
        return false;
    }

    void writeTransactions(GUID txId, YPath path, Collection<BillingTransaction> transactions) {
        writeTransactions(txId, path, transactions, true);
    }

    /**
     * Appends the transactions to the table within the given transaction.
     * PAYMENT and INCOME transactions are serialized with their own column sets; transactions of other
     * kinds produce no rows. Money values are written with exactly 2 decimal places.
     *
     * @param withServiceId when unset, the service_id column is dropped from every record
     * @throws ArithmeticException if a money value has more than 2 significant decimal places
     */
    void writeTransactions(GUID txId, YPath path, Collection<BillingTransaction> transactions, boolean withServiceId) {
        if (transactions.isEmpty()) {
            log.warn("Nothing to export");
            return;
        }
        List<JsonNode> records = new ArrayList<>();
        for (BillingTransaction tx : transactions) {
            if (tx.getKind() == BillingTransactionKind.PAYMENT) {
                records.add(JSON_MAPPER.createObjectNode()
                        .put("service_id", tx.getServiceId())
                        .put("transaction_id", tx.getYtId())
                        .put("transaction_type", tx.getTransactionType().getValue())
                        .put("payment_type", tx.getPaymentType().getValue())
                        .put("paysys_type_cc", tx.getPaymentSystemType().getValue())
                        .put("partner_id", tx.getPartnerId())
                        .put("service_order_id", tx.getServiceOrderId())
                        .put("dt", formatTxDt(tx.getPayoutAt()))
                        .put("update_dt", formatTxDt(tx.getCreatedAt()))
                        .put("price", tx.getValue().getNumberStripped().setScale(2, RoundingMode.UNNECESSARY).toString())
                        .put("currency", tx.getValue().getCurrency().getCurrencyCode())
                        .put("trust_payment_id", tx.getTrustPaymentId())
                        .put("client_id", tx.getClientId())
                        .put("orig_transaction_id", tx.getOriginalTransaction() != null ?
                                tx.getOriginalTransaction().getYtId() : null));
            }
            if (tx.getKind() == BillingTransactionKind.INCOME) {
                records.add(JSON_MAPPER.createObjectNode()
                        .put("service_id", tx.getServiceId())
                        .put("transaction_id", tx.getYtId())
                        .put("orig_transaction_id", tx.getOriginalTransaction() != null ?
                                tx.getOriginalTransaction().getYtId() : null)
                        .put("transaction_type", tx.getTransactionType().getValue())
                        .put("service_order_id", tx.getServiceOrderId())
                        .put("dt", formatTxDt(tx.getPayoutAt()))
                        .put("amount", tx.getValue().getNumberStripped().setScale(2, RoundingMode.UNNECESSARY).toString())
                        .put("currency", tx.getValue().getCurrency().getCurrencyCode())
                        .put("client_id", tx.getClientId())
                );
            }
        }

        if (!withServiceId) {
            // migration code for the release day, on the next day all tables should come with the new column
            log.warn("No service_id column yet, skipping its values");
            records.forEach(r -> ((ObjectNode) r).remove("service_id"));
        }

        YPath appendablePath = path.append(true);
        yt.tables().write(Optional.of(txId), FLAG_DO_NOT_PING_ANCESTOR_TRANSACTION, appendablePath,
                YTableEntryTypes.JACKSON, Cf.wrap(records).iterator());

        // report the number of rows actually written (transactions of unsupported kinds produce no rows)
        log.info("Successfully exported {} transactions", records.size());
    }

    /**
     * Selects the transactions of the given kind whose YT id is greater than the last exported one
     * (all of them when nothing has been exported yet); the skipped ones are logged.
     */
    Collection<BillingTransaction> filterTransactions(Collection<BillingTransaction> transactions,
                                                      Long lastExportedTxId, BillingTransactionKind kind) {
        log.info("Checking if any of the {} new transactions were already exported; kind {}, last exported tx id {}",
                transactions.size(), kind, lastExportedTxId);
        Collection<BillingTransaction> newTransactions = new ArrayList<>();
        for (BillingTransaction transaction : transactions) {
            if (transaction.getKind() != kind) {
                continue;
            }
            if (lastExportedTxId == null || transaction.getYtId() > lastExportedTxId) {
                newTransactions.add(transaction);
            } else {
                // the order is only used to enrich the log context, a missing source event must not break the export
                Order sourceOrder = transaction.getSourceFinancialEvent() == null ? null :
                        transaction.getSourceFinancialEvent().getOrder();
                try (NestedMdc ignored = NestedMdc.forOptionalEntityId(sourceOrder == null ? null : sourceOrder.getId())) {
                    log.warn("{} has already been exported, skipping it", transaction.getDescription());
                }
            }
        }
        return newTransactions;
    }

    /**
     * Parses a billing date-time string (billing format, billing time zone) into an instant.
     */
    Instant parseTxDt(String value) {
        return LocalDateTime.from(BILLING_DATE_TIME_FORMAT.parse(value)).atZone(BILLING_TIME_ZONE).toInstant();
    }

    /**
     * Formats an instant as a billing date-time string (billing format, billing time zone).
     */
    String formatTxDt(Instant value) {
        return value.atZone(BILLING_TIME_ZONE).format(BILLING_DATE_TIME_FORMAT);
    }

    private YTreeNode ytColumnDef(String name, String type) {
        return ytColumnDef(name, type, FLAG_REQUIRED);
    }

    /**
     * Builds a YT schema column definition: {name, type, required = !optional}.
     */
    private YTreeNode ytColumnDef(String name, String type, boolean optional) {
        return YTree.mapBuilder().key("name").value(name).key("type").value(type).key("required").value(!optional).buildMap();
    }
}
