Skip to content

Add basic support for keyset based pagination and scrolling. #2692

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ public final class Constants {
public static final String NAME_OF_INTERNAL_ID = "__internalNeo4jId__";
public static final String NAME_OF_ELEMENT_ID = "__elementId__";

public static final String NAME_OF_ADDITIONAL_SORT = "__stable_uniq_sort__";

/**
* Indicates the list of dynamic labels.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import org.springframework.data.support.PageableExecutionUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
* Base class for {@link RepositoryQuery} implementations for Neo4j.
Expand Down Expand Up @@ -80,7 +79,7 @@ public QueryMethod getQueryMethod() {
@Override
public final Object execute(Object[] parameters) {

boolean incrementLimit = queryMethod.isSliceQuery() && !queryMethod.getQueryAnnotation().map(q -> q.countQuery()).filter(StringUtils::hasText).isPresent();
boolean incrementLimit = queryMethod.incrementLimit();
Neo4jParameterAccessor parameterAccessor = new Neo4jParameterAccessor(
(Neo4jQueryMethod.Neo4jParameters) this.queryMethod.getParameters(),
parameters);
Expand All @@ -91,8 +90,7 @@ public final Object execute(Object[] parameters) {
PropertyFilterSupport.getInputProperties(resultProcessor, factory, mappingContext), parameterAccessor,
null, getMappingFunction(resultProcessor), incrementLimit ? l -> l + 1 : UnaryOperator.identity());

Object rawResult = new Neo4jQueryExecution.DefaultQueryExecution(neo4jOperations).execute(preparedQuery,
queryMethod.isCollectionLikeQuery() || queryMethod.isPageQuery() || queryMethod.isSliceQuery());
Object rawResult = new Neo4jQueryExecution.DefaultQueryExecution(neo4jOperations).execute(preparedQuery, queryMethod.asCollectionQuery());

Converter<Object, Object> preparingConverter = OptionalUnwrappingConverter.INSTANCE;
if (returnedType.isProjecting()) {
Expand All @@ -107,6 +105,8 @@ public final Object execute(Object[] parameters) {
rawResult = createPage(parameterAccessor, (List<?>) rawResult);
} else if (queryMethod.isSliceQuery()) {
rawResult = createSlice(incrementLimit, parameterAccessor, (List<?>) rawResult);
} else if (queryMethod.isScrollQuery()) {
rawResult = createWindow(resultProcessor, incrementLimit, parameterAccessor, (List<?>) rawResult, preparedQuery.getQueryFragmentsAndParameters());
}
return resultProcessor.processResult(rawResult, preparingConverter);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.Collection;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import java.util.function.UnaryOperator;

import org.neo4j.driver.types.MapAccessor;
import org.neo4j.driver.types.TypeSystem;
Expand All @@ -37,6 +38,8 @@
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

import reactor.core.publisher.Flux;

/**
* Base class for {@link RepositoryQuery} implementations for Neo4j.
*
Expand Down Expand Up @@ -67,16 +70,17 @@ public QueryMethod getQueryMethod() {
@Override
public final Object execute(Object[] parameters) {

boolean incrementLimit = queryMethod.incrementLimit();
Neo4jParameterAccessor parameterAccessor = new Neo4jParameterAccessor((Neo4jQueryMethod.Neo4jParameters) this.queryMethod.getParameters(), parameters);
ResultProcessor resultProcessor = queryMethod.getResultProcessor().withDynamicProjection(parameterAccessor);

ReturnedType returnedType = resultProcessor.getReturnedType();
PreparedQuery<?> preparedQuery = prepareQuery(returnedType.getReturnedType(),
PropertyFilterSupport.getInputProperties(resultProcessor, factory, mappingContext), parameterAccessor,
null, getMappingFunction(resultProcessor));
null, getMappingFunction(resultProcessor), incrementLimit ? l -> l + 1 : UnaryOperator.identity());

Object rawResult = new Neo4jQueryExecution.ReactiveQueryExecution(neo4jOperations).execute(preparedQuery,
queryMethod.isCollectionLikeQuery());
queryMethod.asCollectionQuery());

Converter<Object, Object> preparingConverter = OptionalUnwrappingConverter.INSTANCE;
if (returnedType.isProjecting()) {
Expand All @@ -87,10 +91,16 @@ public final Object execute(Object[] parameters) {
(EntityInstanceWithSource) OptionalUnwrappingConverter.INSTANCE.convert(source));
}

if (queryMethod.isScrollQuery()) {
rawResult = ((Flux<?>) rawResult).collectList().map(rawResultList ->
createWindow(resultProcessor, incrementLimit, parameterAccessor, rawResultList, preparedQuery.getQueryFragmentsAndParameters()));
}

return resultProcessor.processResult(rawResult, preparingConverter);
}

protected abstract <T extends Object> PreparedQuery<T> prepareQuery(Class<T> returnedType,
Collection<PropertyFilter.ProjectedPath> includedProperties, Neo4jParameterAccessor parameterAccessor,
@Nullable Neo4jQueryType queryType, @Nullable Supplier<BiFunction<TypeSystem, MapAccessor, ?>> mappingFunction);
@Nullable Neo4jQueryType queryType, @Nullable Supplier<BiFunction<TypeSystem, MapAccessor, ?>> mappingFunction,
@Nullable UnaryOperator<Integer> limitModifier);
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,29 @@
import static org.neo4j.cypherdsl.core.Cypher.property;

import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.apiguardian.api.API;
import org.neo4j.cypherdsl.core.Condition;
import org.neo4j.cypherdsl.core.Conditions;
import org.neo4j.cypherdsl.core.Cypher;
import org.neo4j.cypherdsl.core.Expression;
import org.neo4j.cypherdsl.core.Functions;
import org.neo4j.cypherdsl.core.SortItem;
import org.neo4j.cypherdsl.core.StatementBuilder;
import org.neo4j.cypherdsl.core.SymbolicName;
import org.neo4j.driver.Value;
import org.springframework.data.domain.KeysetScrollPosition;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.data.neo4j.core.mapping.Constants;
import org.springframework.data.neo4j.core.mapping.GraphPropertyDescription;
import org.springframework.data.neo4j.core.mapping.Neo4jPersistentEntity;
import org.springframework.data.neo4j.core.mapping.Neo4jPersistentProperty;
import org.springframework.data.neo4j.core.mapping.NodeDescription;

/**
Expand All @@ -51,6 +60,7 @@ public final class CypherAdapterUtils {
*/
public static Function<Sort.Order, SortItem> sortAdapterFor(NodeDescription<?> nodeDescription) {
return order -> {

String domainProperty = order.getProperty();
boolean propertyIsQualified = domainProperty.contains(".");
SymbolicName root;
Expand All @@ -61,12 +71,21 @@ public static Function<Sort.Order, SortItem> sortAdapterFor(NodeDescription<?> n
root = Cypher.name(domainProperty.substring(0, indexOfSeparator));
domainProperty = domainProperty.substring(indexOfSeparator + 1);
}
String graphProperty = nodeDescription.getGraphProperty(domainProperty)
.map(GraphPropertyDescription::getPropertyName).orElseThrow(() -> new IllegalStateException(
String.format("Cannot order by the unknown graph property: '%s'", order.getProperty())));
Expression expression = property(root, graphProperty);
if (order.isIgnoreCase()) {
expression = Functions.toLower(expression);

var optionalGraphProperty = nodeDescription.getGraphProperty(domainProperty);
if (optionalGraphProperty.isEmpty()) {
throw new IllegalStateException(String.format("Cannot order by the unknown graph property: '%s'", order.getProperty()));
}
var graphProperty = optionalGraphProperty.get();
Expression expression;
if (graphProperty.isInternalIdProperty()) {
// Not using the id expression here, as the root will be referring to the constructed map being returned.
expression = property(root, Constants.NAME_OF_INTERNAL_ID);
} else {
expression = property(root, graphProperty.getPropertyName());
if (order.isIgnoreCase()) {
expression = Functions.toLower(expression);
}
}
SortItem sortItem = Cypher.sort(expression);

Expand All @@ -78,6 +97,72 @@ public static Function<Sort.Order, SortItem> sortAdapterFor(NodeDescription<?> n
};
}

public static Condition combineKeysetIntoCondition(Neo4jPersistentEntity<?> entity, KeysetScrollPosition scrollPosition, Sort sort) {

var incomingKeys = scrollPosition.getKeys();
var orderedKeys = new LinkedHashMap<String, Object>();

record PropertyAndOrder(Neo4jPersistentProperty property, Sort.Order order) {
}
var propertyAndDirection = new HashMap<String, PropertyAndOrder>();

sort.forEach(order -> {
var property = entity.getRequiredPersistentProperty(order.getProperty());
var propertyName = property.getPropertyName();
propertyAndDirection.put(propertyName, new PropertyAndOrder(property, order));

if (incomingKeys.containsKey(propertyName)) {
orderedKeys.put(propertyName, incomingKeys.get(propertyName));
}
});
if (incomingKeys.containsKey(Constants.NAME_OF_ADDITIONAL_SORT)) {
orderedKeys.put(Constants.NAME_OF_ADDITIONAL_SORT, incomingKeys.get(Constants.NAME_OF_ADDITIONAL_SORT));
}

var root = Constants.NAME_OF_TYPED_ROOT_NODE.apply(entity);

var resultingCondition = Conditions.noCondition();
// This is the next equality pair if previous sort key was equal
var nextEquals = Conditions.noCondition();
// This is the condition for when all the sort orderedKeys are equal, and we must filter via id
var allEqualsWithArtificialSort = Conditions.noCondition();

for (Map.Entry<String, Object> entry : orderedKeys.entrySet()) {

var k = entry.getKey();
var v = entry.getValue();
if (v == null || (v instanceof Value value && value.isNull())) {
throw new IllegalStateException("Cannot resume from KeysetScrollPosition. Offending key: '%s' is 'null'".formatted(k));
}
var parameter = Cypher.anonParameter(v);

Expression expression;

var scrollDirection = scrollPosition.getDirection();
if (Constants.NAME_OF_ADDITIONAL_SORT.equals(k)) {
expression = entity.getIdExpression();
var comparatorFunction = getComparatorFunction(scrollDirection == KeysetScrollPosition.Direction.Forward ? Sort.Direction.ASC : Sort.Direction.DESC, scrollDirection);
allEqualsWithArtificialSort = allEqualsWithArtificialSort.and(comparatorFunction.apply(expression, parameter));
} else {
var p = propertyAndDirection.get(k);
expression = p.property.isIdProperty() ? entity.getIdExpression() : root.property(k);

var comparatorFunction = getComparatorFunction(p.order.getDirection(), scrollDirection);
resultingCondition = resultingCondition.or(nextEquals.and(comparatorFunction.apply(expression, parameter)));
nextEquals = expression.eq(parameter);
allEqualsWithArtificialSort = allEqualsWithArtificialSort.and(nextEquals);
}
}
return resultingCondition.or(allEqualsWithArtificialSort);
}

private static BiFunction<Expression, Expression, Condition> getComparatorFunction(Sort.Direction sortDirection, KeysetScrollPosition.Direction scrollDirection) {
if (scrollDirection == KeysetScrollPosition.Direction.Backward) {
return sortDirection.isAscending() ? Expression::lte : Expression::gte;
}
return sortDirection.isAscending() ? Expression::gt : Expression::lt;
}

/**
* Converts a Spring Data sort to an equivalent list of {@link SortItem sort items}.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,11 @@
import org.neo4j.cypherdsl.core.RelationshipPattern;
import org.neo4j.cypherdsl.core.SortItem;
import org.neo4j.driver.types.Point;
import org.springframework.data.domain.KeysetScrollPosition;
import org.springframework.data.domain.OffsetScrollPosition;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Range;
import org.springframework.data.domain.ScrollPosition;
import org.springframework.data.domain.Sort;
import org.springframework.data.geo.Box;
import org.springframework.data.geo.Circle;
Expand All @@ -64,6 +67,7 @@
import org.springframework.data.neo4j.core.mapping.PropertyFilter;
import org.springframework.data.neo4j.core.mapping.RelationshipDescription;
import org.springframework.data.neo4j.core.schema.TargetNode;
import org.springframework.data.repository.query.QueryMethod;
import org.springframework.data.repository.query.parser.AbstractQueryCreator;
import org.springframework.data.repository.query.parser.Part;
import org.springframework.data.repository.query.parser.PartTree;
Expand All @@ -82,6 +86,7 @@
final class CypherQueryCreator extends AbstractQueryCreator<QueryFragmentsAndParameters, Condition> {

private final Neo4jMappingContext mappingContext;
private final QueryMethod queryMethod;

private final Class<?> domainType;
private final NodeDescription<?> nodeDescription;
Expand All @@ -99,6 +104,8 @@ final class CypherQueryCreator extends AbstractQueryCreator<QueryFragmentsAndPar

private final Pageable pagingParameter;

private final ScrollPosition scrollPosition;

/**
* Stores the number of max results, if the {@link PartTree tree} is limiting.
*/
Expand All @@ -113,18 +120,21 @@ final class CypherQueryCreator extends AbstractQueryCreator<QueryFragmentsAndPar

private final List<PropertyPathWrapper> propertyPathWrappers;

private final boolean keysetRequiresSort;

/**
* Can be used to modify the limit of a paged or sliced query.
*/
private final UnaryOperator<Integer> limitModifier;

CypherQueryCreator(Neo4jMappingContext mappingContext, Class<?> domainType, Neo4jQueryType queryType, PartTree tree,
CypherQueryCreator(Neo4jMappingContext mappingContext, QueryMethod queryMethod, Class<?> domainType, Neo4jQueryType queryType, PartTree tree,
Neo4jParameterAccessor actualParameters, Collection<PropertyFilter.ProjectedPath> includedProperties,
BiFunction<Object, Neo4jPersistentPropertyConverter<?>, Object> parameterConversion,
UnaryOperator<Integer> limitModifier) {

super(tree, actualParameters);
this.mappingContext = mappingContext;
this.queryMethod = queryMethod;

this.domainType = domainType;
this.nodeDescription = this.mappingContext.getRequiredNodeDescription(this.domainType);
Expand All @@ -139,6 +149,7 @@ final class CypherQueryCreator extends AbstractQueryCreator<QueryFragmentsAndPar
this.parameterConversion = parameterConversion;

this.pagingParameter = actualParameters.getPageable();
this.scrollPosition = actualParameters.getScrollPosition();
this.limitModifier = limitModifier;

AtomicInteger symbolicNameIndex = new AtomicInteger();
Expand All @@ -148,6 +159,7 @@ final class CypherQueryCreator extends AbstractQueryCreator<QueryFragmentsAndPar
mappingContext.getPersistentPropertyPath(part.getProperty())))
.collect(Collectors.toList());

this.keysetRequiresSort = queryMethod.isScrollQuery() && actualParameters.getScrollPosition() instanceof KeysetScrollPosition;
}

private class PropertyPathWrapper {
Expand Down Expand Up @@ -260,7 +272,12 @@ protected QueryFragmentsAndParameters complete(@Nullable Condition condition, So
.collect(Collectors.toMap(p -> p.nameOrIndex, p -> parameterConversion.apply(p.value, p.conversionOverride)));

QueryFragments queryFragments = createQueryFragments(condition, sort);
return new QueryFragmentsAndParameters(nodeDescription, queryFragments, convertedParameters);

var theSort = pagingParameter.getSort().and(sort);
if (keysetRequiresSort && theSort.isUnsorted()) {
throw new UnsupportedOperationException("Unsorted keyset based scrolling is not supported.");
}
return new QueryFragmentsAndParameters(nodeDescription, queryFragments, convertedParameters, theSort);
}

@NonNull
Expand All @@ -280,15 +297,12 @@ private QueryFragments createQueryFragments(@Nullable Condition condition, Sort
}
}

// closing action: add the condition and path match
queryFragments.setCondition(conditionFragment);

if (!relationshipChain.isEmpty()) {
queryFragments.setMatchOn(relationshipChain);
} else {
queryFragments.addMatchOn(startNode);
}
/// end of initial filter query creation
// end of initial filter query creation

if (queryType == Neo4jQueryType.COUNT) {
queryFragments.setReturnExpression(Functions.count(Cypher.asterisk()), true);
Expand All @@ -298,20 +312,38 @@ private QueryFragments createQueryFragments(@Nullable Condition condition, Sort
queryFragments.setDeleteExpression(Constants.NAME_OF_TYPED_ROOT_NODE.apply(nodeDescription));
queryFragments.setReturnExpression(Functions.count(Constants.NAME_OF_TYPED_ROOT_NODE.apply(nodeDescription)), true);
} else {

var theSort = pagingParameter.getSort().and(sort);

if (pagingParameter.isUnpaged() && scrollPosition == null && maxResults != null) {
queryFragments.setLimit(limitModifier.apply(maxResults.intValue()));
} else if (scrollPosition instanceof KeysetScrollPosition keysetScrollPosition) {

Neo4jPersistentEntity<?> entity = (Neo4jPersistentEntity<?>) nodeDescription;
// Enforce sorting by something that is hopefully stable comparable (looking at Neo4j's id() with tears in my eyes).
theSort = theSort.and(Sort.by(entity.getRequiredIdProperty().getName()).ascending());

queryFragments.setLimit(limitModifier.apply(maxResults.intValue()));
if (!keysetScrollPosition.isInitial()) {
conditionFragment = conditionFragment.and(CypherAdapterUtils.combineKeysetIntoCondition(entity, keysetScrollPosition, theSort));
}

queryFragments.setRequiresReverseSort(keysetScrollPosition.getDirection() == KeysetScrollPosition.Direction.Backward);
} else if (scrollPosition instanceof OffsetScrollPosition offsetScrollPosition) {
queryFragments.setSkip(offsetScrollPosition.getOffset());
queryFragments.setLimit(limitModifier.apply(pagingParameter.isUnpaged() ? maxResults.intValue() : pagingParameter.getPageSize()));
}

queryFragments.setReturnBasedOn(nodeDescription, includedProperties, isDistinct);
queryFragments.setOrderBy(Stream
.concat(sortItems.stream(),
pagingParameter.getSort().and(sort).stream().map(CypherAdapterUtils.sortAdapterFor(nodeDescription)))
theSort.stream().map(CypherAdapterUtils.sortAdapterFor(nodeDescription)))
.collect(Collectors.toList()));
if (pagingParameter.isUnpaged()) {
queryFragments.setLimit(maxResults);
} else {
long skip = pagingParameter.getOffset();
int pageSize = pagingParameter.getPageSize();
queryFragments.setSkip(skip);
queryFragments.setLimit(limitModifier.apply(pageSize));
}
}

// closing action: add the condition and path match
queryFragments.setCondition(conditionFragment);

return queryFragments;
}

Expand Down
Loading