package ru.yandex.qloud.kikimr.transport;

import NKikimrClient.TGRpcServerGrpc;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import ru.yandex.kikimr.proto.Kqp;
import ru.yandex.kikimr.proto.Minikql;
import ru.yandex.kikimr.proto.Msgbus;

import java.util.Collections;
import java.util.List;
import java.util.function.Function;

import static java.util.stream.Collectors.toList;

/**
 * Client for running YQL queries over the KiKiMR gRPC {@code yqlRequest} endpoint obtained from {@link KikimrRpc}.
 *
 * <p>Every request is sent as {@code REQUEST_TYPE_PROCESS_QUERY} with {@code QUERY_TYPE_SQL} and
 * {@code keepSession = false}. A KQP error in the response is rethrown as {@link KQPException}; a non-success
 * KQP status is rethrown as {@link YQLException} when query issues are present, otherwise as {@link KQPException}.
 */
@Component
public class YQL {
    private static final Logger LOG = LoggerFactory.getLogger(YQL.class);

    /** Name of the member of the first result set whose value holds the list of rows. */
    private static final String DATA_MEMBER_NAME = "Data";

    private final KikimrRpc kikimrRpc;

    public YQL(@Autowired KikimrRpc kikimrRpc) {
        this.kikimrRpc = kikimrRpc;
    }

    /**
     * Executes {@code query} using the token returned by {@link KikimrRpc#getRootUserToken()}.
     * The response is discarded; only failures are propagated.
     *
     * @param query YQL text to execute
     */
    public void executeQueryWithRootAccess(String query) {
        executeQueryWithToken(query, kikimrRpc.getRootUserToken());
    }

    private void executeQueryWithToken(String query, String token) {
        LOG.debug("going to execute query: {}", query);
        long startTime = System.currentTimeMillis();
        doYqlRequest(query, token);
        LOG.debug("execute yql query: time = {}; query = {}", System.currentTimeMillis() - startTime, query);
    }

    /**
     * Runs {@code queries} in order until at least {@code maxResultsCount} rows have been collected.
     *
     * <p>All queries share one stub obtained from {@link KikimrRpc#getRpcWithDeadline()}. Before each query,
     * a {@code "LIMIT N;"} clause (if one can be parsed) is rewritten to the number of rows still missing.
     * A query failing with {@link YQLException} or {@link KQPException} whose message mentions a missing
     * table is logged and skipped; any other such failure is rethrown.
     *
     * @param queries         queries to run, in order
     * @param maxResultsCount row budget; iteration stops once this many rows have been collected
     * @return rows of all executed queries merged into one {@link Result}
     * @throws IllegalStateException if the executed queries return different field names
     */
    public Result queries(List<String> queries, int maxResultsCount) {
        TGRpcServerGrpc.TGRpcServerBlockingStub rpcWithDeadline = kikimrRpc.getRpcWithDeadline();

        List<Result> results = Lists.newArrayListWithCapacity(queries.size());
        int resultsFound = 0;
        for (String query : queries) {
            try {
                Result result = query(withReplacedLimit(query, maxResultsCount - resultsFound), rpcWithDeadline);
                results.add(result);
                resultsFound += result.getRows().size();
                if (resultsFound >= maxResultsCount) {
                    break;
                }
            } catch (YQLException | KQPException e) {
                if (!isMissingTableError(e)) {
                    throw e;
                }
                LOG.warn(String.format("Error for query %s; search will be continued", query), e);
            }
        }

        return mergeResults(results);
    }

    /** Returns true when the failure message indicates that the queried table does not exist. */
    private static boolean isMissingTableError(Throwable e) {
        String message = Strings.nullToEmpty(e.getMessage());
        return message.contains("Table not found") || message.contains("Cannot find table");
    }

    /**
     * Rewrites the first {@code "LIMIT <n>;"} found in {@code query} so that the limit becomes {@code newLimit}.
     * Returns {@code query} unchanged when no parseable limit is present.
     */
    private String withReplacedLimit(String query, int newLimit) {
        String limitText = StringUtils.substringBetween(query, "LIMIT ", ";");
        if (limitText == null) {
            // Expected for queries without a LIMIT clause; no need for a stack trace.
            LOG.warn("no parsed limit in query {}", query);
            return query;
        }

        int oldLimit;
        try {
            oldLimit = Integer.parseInt(limitText);
        } catch (NumberFormatException e) {
            LOG.warn(String.format("no parsed limit in query %s", query), e);
            return query;
        }

        if (newLimit == oldLimit) {
            return query;
        }
        return StringUtils.replace(query, "LIMIT " + oldLimit, "LIMIT " + newLimit);
    }

    /**
     * Runs {@code query} with the token returned by {@link KikimrRpc#getUserSecurityToken()}.
     *
     * @param query YQL text to execute
     * @return field names and rows taken from the {@code "Data"} member of the first result set
     */
    public Result query(String query) {
        return query(query, kikimrRpc.getRpcWithDeadline());
    }

    private Result query(String query, TGRpcServerGrpc.TGRpcServerBlockingStub rpcWithDeadLine) {
        long time0 = System.currentTimeMillis();

        Kqp.TQueryResponse queryResponse =
                doYqlRequestWithDeadLinedRpc(query, rpcWithDeadLine, kikimrRpc.getUserSecurityToken())
                        .getQueryResponse();
        if (queryResponse.getResultsCount() == 0) {
            throw new IllegalStateException("YQL response contains no result sets for query: " + query);
        }

        Minikql.TResult result = queryResponse.getResults(0);
        long time1 = System.currentTimeMillis();

        // Struct values are positional: value member i corresponds to type member i (the same alignment
        // Row relies on). Resolve the "Data" member's index once and use it for both type and value,
        // instead of assuming it is the first member.
        List<Minikql.TMember> resultMembers = result.getType().getStruct().getMemberList();
        int dataIndex = indexOfMember(resultMembers, DATA_MEMBER_NAME);
        if (dataIndex < 0) {
            throw new IllegalStateException(String.format(
                    "YQL result has no '%s' member for query: %s", DATA_MEMBER_NAME, query));
        }

        List<String> fields = resultMembers.get(dataIndex)
                .getType().getList().getItem().getStruct().getMemberList().stream()
                .map(Minikql.TMember::getName)
                .collect(toList());

        List<Minikql.TValue> values = result.getValue().getStruct(dataIndex).getListList();
        List<Row> rows = values.stream()
                .map(v -> new Row(v, fields.size()))
                .collect(toList());

        long time2 = System.currentTimeMillis();

        LOG.debug("Running YQL query: {} (query {}, parsing {})", query, (time1 - time0), (time2 - time1));

        return new Result(fields, rows);
    }

    /** Returns the position of the first member named {@code name}, or -1 when there is none. */
    private static int indexOfMember(List<Minikql.TMember> members, String name) {
        for (int i = 0; i < members.size(); i++) {
            if (name.equals(members.get(i).getName())) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Sends {@code query} through {@code rpcWithDeadLine} and validates the response status.
     *
     * @param securityToken token to attach; omitted from the request when {@code null}
     * @throws KQPException if the response carries a KQP error, or a non-success status without query issues
     * @throws YQLException if the response carries a non-success status together with query issues
     */
    private Msgbus.TYqlResponse doYqlRequestWithDeadLinedRpc(
            String query,
            TGRpcServerGrpc.TGRpcServerBlockingStub rpcWithDeadLine,
            String securityToken
    ) {
        Msgbus.TYqlRequest.Builder requestBuilder = Msgbus.TYqlRequest.newBuilder();
        if (securityToken != null) {
            requestBuilder.setSecurityToken(securityToken);
        }

        requestBuilder.setRequestType(Kqp.ERequestType.REQUEST_TYPE_PROCESS_QUERY)
                .setQueryRequest(
                        Kqp.TQueryRequest.newBuilder()
                                .setQuery(query)
                                .setType(Kqp.EQueryType.QUERY_TYPE_SQL)
                                .setKeepSession(false)
                );

        Msgbus.TYqlResponse response = rpcWithDeadLine.yqlRequest(requestBuilder.build());
        if (response.hasKqpError()) {
            throw new KQPException(response.getKqpError());
        }

        long queryErrorsCount = response.getQueryResponse().getQueryIssuesCount();
        if (queryErrorsCount > 0) {
            // Issues are logged even when the status is success (they may be warnings).
            LOG.warn("Query: {}, issues: {}", query, response.getQueryResponse().getQueryIssuesList());
        }
        if (response.hasKqpStatus() && response.getKqpStatus() != Kqp.EStatus.STATUS_SUCCESS) {
            if (queryErrorsCount > 0) {
                throw new YQLException(response.getQueryResponse().getQueryIssuesList());
            } else {
                throw new KQPException(response.getKqpStatus());
            }
        }

        return response;
    }

    private Msgbus.TYqlResponse doYqlRequest(String query, String securityToken) {
        return doYqlRequestWithDeadLinedRpc(query, kikimrRpc.getRpcWithDeadline(), securityToken);
    }

    /**
     * Concatenates the rows of {@code results}; a single result is returned as-is, none yields an empty result.
     *
     * @throws IllegalStateException if any result's field names differ from the first result's
     */
    private Result mergeResults(List<Result> results) {
        if (results.isEmpty()) {
            return new Result(Collections.emptyList(), Collections.emptyList());
        }
        if (results.size() == 1) {
            return results.get(0);
        }

        List<String> fieldNames = results.get(0).getFieldsNames();

        ImmutableList.Builder<Row> mergedRowsBuilder = ImmutableList.builder();
        for (Result result : results) {
            if (!result.getFieldsNames().equals(fieldNames)) {
                throw new IllegalStateException(String.format(
                        "results have different fields: actual = %s; expected =  %s",
                        result.getFieldsNames(), fieldNames
                ));
            }
            mergedRowsBuilder.addAll(result.getRows());
        }

        return new Result(fieldNames, mergedRowsBuilder.build());
    }

    /** Field names plus the rows whose cells are ordered like those names. */
    public static class Result {
        private final List<String> fieldsNames;
        private final List<Row> rows;

        public Result(List<String> fieldsNames, List<Row> rows) {
            this.fieldsNames = fieldsNames;
            this.rows = rows;
        }

        public List<String> getFieldsNames() {
            return fieldsNames;
        }

        public List<Row> getRows() {
            return rows;
        }
    }

    /**
     * One result row with its cells converted to Java values (Boolean, String, Integer, Long, Float, Double),
     * depending on which Minikql payload is set.
     */
    public static class Row {
        /** Tried in order; the first converter whose predicate accepts the value wins. */
        private static final List<MinikqlValueConverter> MINIKQL_VALUE_CONVERTERS = ImmutableList.of(
                createConverter(Minikql.TValue::hasBool, Minikql.TValue::getBool),
                createConverter(Minikql.TValue::hasBytes, value -> value.getBytes().toStringUtf8()),
                createConverter(Minikql.TValue::hasText, Minikql.TValue::getText),
                createConverter(Minikql.TValue::hasInt32, Minikql.TValue::getInt32),
                createConverter(Minikql.TValue::hasInt64, Minikql.TValue::getInt64),
                createConverter(Minikql.TValue::hasUint32, Minikql.TValue::getUint32),
                createConverter(Minikql.TValue::hasUint64, Minikql.TValue::getUint64),
                createConverter(Minikql.TValue::hasFloat, Minikql.TValue::getFloat),
                createConverter(Minikql.TValue::hasDouble, Minikql.TValue::getDouble)
        );

        private final Object[] data;

        /**
         * @param rowValue    struct value whose i-th member is the i-th cell
         * @param fieldsCount expected number of cells
         * @throws RuntimeException if the cell count differs from {@code fieldsCount} or a cell has no supported payload
         */
        public Row(Minikql.TValue rowValue, int fieldsCount) {
            if (rowValue.getStructCount() != fieldsCount) {
                throw new RuntimeException("Row struct count != fields count");
            }

            data = new Object[fieldsCount];
            for (int i = 0; i < fieldsCount; i++) {
                data[i] = convertCell(unwrapOptional(rowValue.getStruct(i)));
            }
        }

        /** Returns the inner value of an optional cell, or the cell itself when it is not optional. */
        private static Minikql.TValue unwrapOptional(Minikql.TValue cell) {
            return cell.hasOptional() ? cell.getOptional() : cell;
        }

        // NOTE(review): if an absent optional (SQL NULL) arrives as a TValue with no payload set, no converter
        // accepts it and this throws "Unknown type" — confirm and decide whether NULL should map to null.
        private static Object convertCell(Minikql.TValue value) {
            for (MinikqlValueConverter converter : MINIKQL_VALUE_CONVERTERS) {
                if (converter.canConvert(value)) {
                    return converter.convert(value);
                }
            }
            throw new RuntimeException("Unknown type: " + value);
        }

        /** Returns the converted cells (the internal array, not a copy). */
        public Object[] getData() {
            return data;
        }
    }

    private interface MinikqlValueConverter {
        boolean canConvert(Minikql.TValue value);

        Object convert(Minikql.TValue value);
    }

    /** Builds a converter from a payload-presence check and a payload extractor. */
    private static MinikqlValueConverter createConverter(
            Function<Minikql.TValue, Boolean> canConvertFunction,
            Function<Minikql.TValue, Object> convertFunction
    ) {
        return new MinikqlValueConverter() {
            @Override
            public boolean canConvert(Minikql.TValue value) {
                return canConvertFunction.apply(value);
            }

            @Override
            public Object convert(Minikql.TValue value) {
                return convertFunction.apply(value);
            }
        };
    }
}
