diff --git a/src/main/java/org/springframework/data/neo4j/config/Neo4jCdiConfigurationSupport.java b/src/main/java/org/springframework/data/neo4j/config/Neo4jCdiConfigurationSupport.java index 26abe0d20e..7bec1faf0f 100644 --- a/src/main/java/org/springframework/data/neo4j/config/Neo4jCdiConfigurationSupport.java +++ b/src/main/java/org/springframework/data/neo4j/config/Neo4jCdiConfigurationSupport.java @@ -75,10 +75,12 @@ public Configuration cypherDslConfiguration() { public Neo4jOperations neo4jOperations( @Any Instance neo4jClient, @Any Instance mappingContext, - @Any Instance cypherDslConfiguration + @Any Instance cypherDslConfiguration, + @Any Instance transactionManager ) { Neo4jTemplate neo4jTemplate = new Neo4jTemplate(resolve(neo4jClient), resolve(mappingContext)); neo4jTemplate.setCypherRenderer(Renderer.getRenderer(resolve(cypherDslConfiguration))); + neo4jTemplate.setTransactionManager(resolve(transactionManager)); return neo4jTemplate; } diff --git a/src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java b/src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java index adfd836f27..ee51dd4b3b 100644 --- a/src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java +++ b/src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java @@ -99,6 +99,9 @@ import org.springframework.data.util.TypeInformation; import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; +import org.springframework.transaction.PlatformTransactionManager; +import org.springframework.transaction.TransactionDefinition; +import org.springframework.transaction.support.TransactionTemplate; import org.springframework.util.Assert; /** @@ -110,6 +113,7 @@ * @since 6.0 */ @API(status = API.Status.STABLE, since = "6.0") +@SuppressWarnings("DataFlowIssue") public final class Neo4jTemplate implements Neo4jOperations, FluentNeo4jOperations, BeanClassLoaderAware, BeanFactoryAware { @@ -118,6 +122,13 @@ public final class Neo4jTemplate implements private static final String OPTIMISTIC_LOCKING_ERROR_MESSAGE = "An entity with the required version does not exist."; + private static final TransactionDefinition readOnlyTransactionDefinition = new TransactionDefinition() { + @Override + public boolean isReadOnly() { + return true; + } + }; + private final Neo4jClient neo4jClient; private final Neo4jMappingContext neo4jMappingContext; @@ -128,12 +139,16 @@ public final class Neo4jTemplate implements private EventSupport eventSupport; - private ProjectionFactory projectionFactoryf; + private ProjectionFactory projectionFactory; private Renderer renderer; private Function elementIdOrIdFunction; + private TransactionTemplate transactionTemplate; + + private TransactionTemplate transactionTemplateReadOnly; + public Neo4jTemplate(Neo4jClient neo4jClient) { this(neo4jClient, new Neo4jMappingContext()); } @@ -157,7 +172,7 @@ public Neo4jTemplate(Neo4jClient neo4jClient, Neo4jMappingContext neo4jMappingCo } ProjectionFactory getProjectionFactory() { - return Objects.requireNonNull(this.projectionFactoryf, "Projection support for the Neo4j template is only available when the template is a proper and fully initialized Spring bean."); + return Objects.requireNonNull(this.projectionFactory, "Projection support for the Neo4j template is only available when the template is a proper and fully initialized Spring bean."); } @Override @@ -188,10 +203,11 @@ public long count(String cypherQuery) { @Override public long count(String cypherQuery, Map parameters) { - - PreparedQuery preparedQuery = PreparedQuery.queryFor(Long.class).withCypherQuery(cypherQuery) - .withParameters(parameters).build(); - return toExecutableQuery(preparedQuery).getRequiredSingleResult(); + return transactionTemplateReadOnly.execute(tx -> { + PreparedQuery preparedQuery = PreparedQuery.queryFor(Long.class).withCypherQuery(cypherQuery) + .withParameters(parameters).build(); + return toExecutableQuery(preparedQuery).getRequiredSingleResult(); + }); } @Override @@ -201,84 +217,96 @@ public List findAll(Class domainType) { } private List doFindAll(Class domainType, @Nullable Class resultType) { - Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(domainType); - return createExecutableQuery(domainType, resultType, QueryFragmentsAndParameters.forFindAll(entityMetaData)) - .getResults(); + return transactionTemplateReadOnly + .execute(tx -> { + Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(domainType); + return createExecutableQuery(domainType, resultType, QueryFragmentsAndParameters.forFindAll(entityMetaData)) + .getResults(); + }); } @Override public List findAll(Statement statement, Class domainType) { - return createExecutableQuery(domainType, statement).getResults(); + return transactionTemplateReadOnly + .execute(tx -> createExecutableQuery(domainType, statement).getResults()); } @Override public List findAll(Statement statement, Map parameters, Class domainType) { - return createExecutableQuery(domainType, null, statement, parameters).getResults(); + return transactionTemplateReadOnly + .execute(tx -> createExecutableQuery(domainType, null, statement, parameters).getResults()); } @Override public Optional findOne(Statement statement, Map parameters, Class domainType) { - return createExecutableQuery(domainType, null, statement, parameters).getSingleResult(); + return transactionTemplateReadOnly + .execute(tx -> createExecutableQuery(domainType, null, statement, parameters).getSingleResult()); } @Override public List findAll(String cypherQuery, Class domainType) { - return createExecutableQuery(domainType, cypherQuery).getResults(); + return transactionTemplateReadOnly + .execute(tx -> createExecutableQuery(domainType, cypherQuery).getResults()); } @Override public List findAll(String cypherQuery, Map parameters, Class domainType) { - return createExecutableQuery(domainType, null, cypherQuery, parameters).getResults(); + return transactionTemplateReadOnly + .execute(tx -> createExecutableQuery(domainType, null, cypherQuery, parameters).getResults()); } @Override public Optional findOne(String cypherQuery, Map parameters, Class domainType) { - return createExecutableQuery(domainType, null, cypherQuery, parameters).getSingleResult(); + return transactionTemplateReadOnly + .execute(tx -> createExecutableQuery(domainType, null, cypherQuery, parameters).getSingleResult()); } @Override public ExecutableFind find(Class domainType) { - return new FluentOperationSupport(this).find(domainType); + return transactionTemplateReadOnly + .execute(tx -> new FluentOperationSupport(this).find(domainType)); } @SuppressWarnings("unchecked") List doFind(@Nullable String cypherQuery, @Nullable Map parameters, Class domainType, Class resultType, TemplateSupport.FetchType fetchType, @Nullable QueryFragmentsAndParameters queryFragmentsAndParameters) { - List intermediaResults = Collections.emptyList(); - if (cypherQuery == null && queryFragmentsAndParameters == null && fetchType == TemplateSupport.FetchType.ALL) { - intermediaResults = doFindAll(domainType, resultType); - } else { - ExecutableQuery executableQuery; - if (queryFragmentsAndParameters == null) { - executableQuery = createExecutableQuery(domainType, resultType, cypherQuery, - parameters == null ? Collections.emptyMap() : parameters); + return transactionTemplateReadOnly.execute(tx -> { + List intermediaResults = Collections.emptyList(); + if (cypherQuery == null && queryFragmentsAndParameters == null && fetchType == TemplateSupport.FetchType.ALL) { + intermediaResults = doFindAll(domainType, resultType); } else { - executableQuery = createExecutableQuery(domainType, resultType, queryFragmentsAndParameters); + ExecutableQuery executableQuery; + if (queryFragmentsAndParameters == null) { + executableQuery = createExecutableQuery(domainType, resultType, cypherQuery, + parameters == null ? Collections.emptyMap() : parameters); + } else { + executableQuery = createExecutableQuery(domainType, resultType, queryFragmentsAndParameters); + } + intermediaResults = switch (fetchType) { + case ALL -> executableQuery.getResults(); + case ONE -> executableQuery.getSingleResult().map(Collections::singletonList) + .orElseGet(Collections::emptyList); + }; } - intermediaResults = switch (fetchType) { - case ALL -> executableQuery.getResults(); - case ONE -> executableQuery.getSingleResult().map(Collections::singletonList) - .orElseGet(Collections::emptyList); - }; - } - if (resultType.isAssignableFrom(domainType)) { - return (List) intermediaResults; - } + if (resultType.isAssignableFrom(domainType)) { + return (List) intermediaResults; + } + + if (resultType.isInterface()) { + return intermediaResults.stream() + .map(instance -> getProjectionFactory().createProjection(resultType, instance)) + .collect(Collectors.toList()); + } - if (resultType.isInterface()) { + DtoInstantiatingConverter converter = new DtoInstantiatingConverter(resultType, neo4jMappingContext); return intermediaResults.stream() - .map(instance -> getProjectionFactory().createProjection(resultType, instance)) + .map(EntityInstanceWithSource.class::cast) + .map(converter::convert) + .map(v -> (R) v) + .filter(Objects::nonNull) .collect(Collectors.toList()); - } - - DtoInstantiatingConverter converter = new DtoInstantiatingConverter(resultType, neo4jMappingContext); - return intermediaResults.stream() - .map(EntityInstanceWithSource.class::cast) - .map(converter::convert) - .map(v -> (R) v) - .filter(Objects::nonNull) - .collect(Collectors.toList()); + }); } @Override @@ -297,22 +325,28 @@ public boolean existsById(Object id, Class domainType) { @Override public Optional findById(Object id, Class domainType) { - Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(domainType); - - return createExecutableQuery(domainType, null, - QueryFragmentsAndParameters.forFindById(entityMetaData, - convertIdValues(entityMetaData.getRequiredIdProperty(), id))) - .getSingleResult(); + return transactionTemplateReadOnly + .execute(tx -> { + Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(domainType); + + return createExecutableQuery(domainType, null, + QueryFragmentsAndParameters.forFindById(entityMetaData, + convertIdValues(entityMetaData.getRequiredIdProperty(), id))) + .getSingleResult(); + }); } @Override public List findAllById(Iterable ids, Class domainType) { - Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(domainType); - - return createExecutableQuery(domainType, null, - QueryFragmentsAndParameters.forFindByAllId( - entityMetaData, convertIdValues(entityMetaData.getRequiredIdProperty(), ids))) - .getResults(); + return transactionTemplateReadOnly + .execute(tx -> { + Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(domainType); + + return createExecutableQuery(domainType, null, + QueryFragmentsAndParameters.forFindByAllId( + entityMetaData, convertIdValues(entityMetaData.getRequiredIdProperty(), ids))) + .getResults(); + }); } private Object convertIdValues(@Nullable Neo4jPersistentProperty idProperty, @Nullable Object idValues) { @@ -331,11 +365,12 @@ private Object convertIdValues(@Nullable Neo4jPersistentProperty idProperty, @Nu } } - @Override public T save(T instance) { - return saveImpl(instance, Collections.emptySet(), null); + return transactionTemplate + .execute(tx -> saveImpl(instance, Collections.emptySet(), null)); + } @Override @@ -344,42 +379,45 @@ public T saveAs(T instance, BiPredicate saveImpl(instance, TemplateSupport.computeIncludedPropertiesFromPredicate(this.neo4jMappingContext, instance.getClass(), includeProperty), null)); } @Override public R saveAs(T instance, Class resultType) { - Assert.notNull(resultType, "ResultType must not be null"); + return transactionTemplate.execute(tx -> { - if (instance == null) { - return null; - } + Assert.notNull(resultType, "ResultType must not be null"); - if (resultType.equals(instance.getClass())) { - return resultType.cast(save(instance)); - } + if (instance == null) { + return null; + } - ProjectionFactory localProjectionFactory = getProjectionFactory(); - ProjectionInformation projectionInformation = localProjectionFactory.getProjectionInformation(resultType); - Collection pps = PropertyFilterSupport.addPropertiesFrom(instance.getClass(), resultType, - localProjectionFactory, neo4jMappingContext); + if (resultType.equals(instance.getClass())) { + return resultType.cast(save(instance)); + } - T savedInstance = saveImpl(instance, pps, null); - if (!resultType.isInterface()) { - @SuppressWarnings("unchecked") R result = (R) new DtoInstantiatingConverter(resultType, neo4jMappingContext).convertDirectly(savedInstance); - return result; - } - if (projectionInformation.isClosed()) { - return localProjectionFactory.createProjection(resultType, savedInstance); - } + ProjectionFactory localProjectionFactory = getProjectionFactory(); + ProjectionInformation projectionInformation = localProjectionFactory.getProjectionInformation(resultType); + Collection pps = PropertyFilterSupport.addPropertiesFrom(instance.getClass(), resultType, + localProjectionFactory, neo4jMappingContext); - Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(savedInstance.getClass()); - Neo4jPersistentProperty idProperty = entityMetaData.getIdProperty(); - PersistentPropertyAccessor propertyAccessor = entityMetaData.getPropertyAccessor(savedInstance); - return localProjectionFactory.createProjection(resultType, - this.findById(propertyAccessor.getProperty(idProperty), savedInstance.getClass()).get()); + T savedInstance = saveImpl(instance, pps, null); + if (!resultType.isInterface()) { + @SuppressWarnings("unchecked") R result = (R) new DtoInstantiatingConverter(resultType, neo4jMappingContext).convertDirectly(savedInstance); + return result; + } + if (projectionInformation.isClosed()) { + return localProjectionFactory.createProjection(resultType, savedInstance); + } + + Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(savedInstance.getClass()); + Neo4jPersistentProperty idProperty = entityMetaData.getIdProperty(); + PersistentPropertyAccessor propertyAccessor = entityMetaData.getPropertyAccessor(savedInstance); + return localProjectionFactory.createProjection(resultType, + this.findById(propertyAccessor.getProperty(idProperty), savedInstance.getClass()).get()); + }); } private T saveImpl(T instance, @Nullable Collection includedProperties, @Nullable NestedRelationshipProcessingStateMachine stateMachine) { @@ -464,7 +502,8 @@ private DynamicLabels determineDynamicLabels(T entityToBeSaved, Neo4jPersist @Override public List saveAll(Iterable instances) { - return saveAllImpl(instances, Collections.emptySet(), null); + return transactionTemplate + .execute(tx -> saveAllImpl(instances, Collections.emptySet(), null)); } private boolean requiresSingleStatements(boolean heterogeneousCollection, Neo4jPersistentEntity entityMetaData) { @@ -548,129 +587,149 @@ class Tuple3 { @Override public List saveAllAs(Iterable instances, BiPredicate includeProperty) { - return saveAllImpl(instances, null, includeProperty); + return transactionTemplate + .execute(tx -> saveAllImpl(instances, null, includeProperty)); } @Override public List saveAllAs(Iterable instances, Class resultType) { - Assert.notNull(resultType, "ResultType must not be null"); + return transactionTemplate + .execute(tx -> { - Class commonElementType = TemplateSupport.findCommonElementType(instances); + Assert.notNull(resultType, "ResultType must not be null"); - if (commonElementType == null) { - throw new IllegalArgumentException("Could not determine a common element of an heterogeneous collection"); - } + Class commonElementType = TemplateSupport.findCommonElementType(instances); - if (commonElementType == TemplateSupport.EmptyIterable.class) { - return Collections.emptyList(); - } + if (commonElementType == null) { + throw new IllegalArgumentException("Could not determine a common element of an heterogeneous collection"); + } - if (resultType.isAssignableFrom(commonElementType)) { - @SuppressWarnings("unchecked") // Nicer to live with this than streaming, mapping and collecting to avoid the cast. It's easier on the reactive side. - List saveElements = (List) saveAll(instances); - return saveElements; - } + if (commonElementType == TemplateSupport.EmptyIterable.class) { + return Collections.emptyList(); + } - ProjectionFactory localProjectionFactory = getProjectionFactory(); - ProjectionInformation projectionInformation = localProjectionFactory.getProjectionInformation(resultType); + if (resultType.isAssignableFrom(commonElementType)) { + @SuppressWarnings("unchecked") // Nicer to live with this than streaming, mapping and collecting to avoid the cast. It's easier on the reactive side. + List saveElements = (List) saveAll(instances); + return saveElements; + } - Collection pps = PropertyFilterSupport.addPropertiesFrom(commonElementType, resultType, - localProjectionFactory, neo4jMappingContext); + ProjectionFactory localProjectionFactory = getProjectionFactory(); + ProjectionInformation projectionInformation = localProjectionFactory.getProjectionInformation(resultType); - List savedInstances = saveAllImpl(instances, pps, null); + Collection pps = PropertyFilterSupport.addPropertiesFrom(commonElementType, resultType, + localProjectionFactory, neo4jMappingContext); - if (projectionInformation.isClosed()) { - return savedInstances.stream().map(instance -> localProjectionFactory.createProjection(resultType, instance)) - .collect(Collectors.toList()); - } + List savedInstances = saveAllImpl(instances, pps, null); - Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(commonElementType); - Neo4jPersistentProperty idProperty = entityMetaData.getIdProperty(); + if (projectionInformation.isClosed()) { + return savedInstances.stream().map(instance -> localProjectionFactory.createProjection(resultType, instance)) + .collect(Collectors.toList()); + } - List ids = savedInstances.stream().map(savedInstance -> { - PersistentPropertyAccessor propertyAccessor = entityMetaData.getPropertyAccessor(savedInstance); - return propertyAccessor.getProperty(idProperty); - }).collect(Collectors.toList()); + Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(commonElementType); + Neo4jPersistentProperty idProperty = entityMetaData.getIdProperty(); - return findAllById(ids, commonElementType) - .stream().map(instance -> localProjectionFactory.createProjection(resultType, instance)) - .collect(Collectors.toList()); + List ids = savedInstances.stream().map(savedInstance -> { + PersistentPropertyAccessor propertyAccessor = entityMetaData.getPropertyAccessor(savedInstance); + return propertyAccessor.getProperty(idProperty); + }).collect(Collectors.toList()); + + return findAllById(ids, commonElementType) + .stream().map(instance -> localProjectionFactory.createProjection(resultType, instance)) + .collect(Collectors.toList()); + }); } @Override public void deleteById(Object id, Class domainType) { - Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(domainType); - String nameOfParameter = "id"; - Condition condition = entityMetaData.getIdExpression().isEqualTo(parameter(nameOfParameter)); + transactionTemplate + .executeWithoutResult(tx -> { + + Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(domainType); + String nameOfParameter = "id"; + Condition condition = entityMetaData.getIdExpression().isEqualTo(parameter(nameOfParameter)); - log.debug(() -> String.format("Deleting entity with id %s ", id)); + log.debug(() -> String.format("Deleting entity with id %s ", id)); - Statement statement = cypherGenerator.prepareDeleteOf(entityMetaData, condition); - ResultSummary summary = this.neo4jClient.query(renderer.render(statement)) - .bind(convertIdValues(entityMetaData.getRequiredIdProperty(), id)) - .to(nameOfParameter).run(); + Statement statement = cypherGenerator.prepareDeleteOf(entityMetaData, condition); + ResultSummary summary = this.neo4jClient.query(renderer.render(statement)) + .bind(convertIdValues(entityMetaData.getRequiredIdProperty(), id)) + .to(nameOfParameter).run(); - log.debug(() -> String.format("Deleted %d nodes and %d relationships.", summary.counters().nodesDeleted(), - summary.counters().relationshipsDeleted())); + log.debug(() -> String.format("Deleted %d nodes and %d relationships.", summary.counters().nodesDeleted(), + summary.counters().relationshipsDeleted())); + }); } @Override public void deleteByIdWithVersion(Object id, Class domainType, Neo4jPersistentProperty versionProperty, Object versionValue) { - Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(domainType); + transactionTemplate + .executeWithoutResult(tx -> { + Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(domainType); - String nameOfParameter = "id"; - Condition condition = entityMetaData.getIdExpression().isEqualTo(parameter(nameOfParameter)) - .and(Cypher.property(Constants.NAME_OF_TYPED_ROOT_NODE.apply(entityMetaData), versionProperty.getPropertyName()) - .isEqualTo(parameter(Constants.NAME_OF_VERSION_PARAM)) - .or(Cypher.property(Constants.NAME_OF_TYPED_ROOT_NODE.apply(entityMetaData), versionProperty.getPropertyName()).isNull())); + String nameOfParameter = "id"; + Condition condition = entityMetaData.getIdExpression().isEqualTo(parameter(nameOfParameter)) + .and(Cypher.property(Constants.NAME_OF_TYPED_ROOT_NODE.apply(entityMetaData), versionProperty.getPropertyName()) + .isEqualTo(parameter(Constants.NAME_OF_VERSION_PARAM)) + .or(Cypher.property(Constants.NAME_OF_TYPED_ROOT_NODE.apply(entityMetaData), versionProperty.getPropertyName()).isNull())); - Statement statement = cypherGenerator.prepareMatchOf(entityMetaData, condition) - .returning(Constants.NAME_OF_TYPED_ROOT_NODE.apply(entityMetaData)).build(); + Statement statement = cypherGenerator.prepareMatchOf(entityMetaData, condition) + .returning(Constants.NAME_OF_TYPED_ROOT_NODE.apply(entityMetaData)).build(); - Map parameters = new HashMap<>(); - parameters.put(nameOfParameter, convertIdValues(entityMetaData.getRequiredIdProperty(), id)); - parameters.put(Constants.NAME_OF_VERSION_PARAM, versionValue); + Map parameters = new HashMap<>(); + parameters.put(nameOfParameter, convertIdValues(entityMetaData.getRequiredIdProperty(), id)); + parameters.put(Constants.NAME_OF_VERSION_PARAM, versionValue); - createExecutableQuery(domainType, null, statement, parameters).getSingleResult().orElseThrow( - () -> new OptimisticLockingFailureException(OPTIMISTIC_LOCKING_ERROR_MESSAGE) - ); + createExecutableQuery(domainType, null, statement, parameters).getSingleResult().orElseThrow( + () -> new OptimisticLockingFailureException(OPTIMISTIC_LOCKING_ERROR_MESSAGE) + ); - deleteById(id, domainType); + deleteById(id, domainType); + }); } @Override public void deleteAllById(Iterable ids, Class domainType) { - Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(domainType); - String nameOfParameter = "ids"; - Condition condition = entityMetaData.getIdExpression().in(parameter(nameOfParameter)); + transactionTemplate + .executeWithoutResult(tx -> { + + Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(domainType); + String nameOfParameter = "ids"; + Condition condition = entityMetaData.getIdExpression().in(parameter(nameOfParameter)); - log.debug(() -> String.format("Deleting all entities with the following ids: %s ", ids)); + log.debug(() -> String.format("Deleting all entities with the following ids: %s ", ids)); - Statement statement = cypherGenerator.prepareDeleteOf(entityMetaData, condition); - ResultSummary summary = this.neo4jClient.query(renderer.render(statement)) - .bind(convertIdValues(entityMetaData.getRequiredIdProperty(), ids)) - .to(nameOfParameter).run(); + Statement statement = cypherGenerator.prepareDeleteOf(entityMetaData, condition); + ResultSummary summary = this.neo4jClient.query(renderer.render(statement)) + .bind(convertIdValues(entityMetaData.getRequiredIdProperty(), ids)) + .to(nameOfParameter).run(); - log.debug(() -> String.format("Deleted %d nodes and %d relationships.", summary.counters().nodesDeleted(), - summary.counters().relationshipsDeleted())); + log.debug(() -> String.format("Deleted %d nodes and %d relationships.", summary.counters().nodesDeleted(), + summary.counters().relationshipsDeleted())); + }); } @Override public void deleteAll(Class domainType) { - Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(domainType); - log.debug(() -> String.format("Deleting all nodes with primary label %s", entityMetaData.getPrimaryLabel())); + transactionTemplate + .executeWithoutResult(tx -> { + + Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(domainType); + log.debug(() -> String.format("Deleting all nodes with primary label %s", entityMetaData.getPrimaryLabel())); - Statement statement = cypherGenerator.prepareDeleteOf(entityMetaData); - ResultSummary summary = this.neo4jClient.query(renderer.render(statement)).run(); + Statement statement = cypherGenerator.prepareDeleteOf(entityMetaData); + ResultSummary summary = this.neo4jClient.query(renderer.render(statement)).run(); - log.debug(() -> String.format("Deleted %d nodes and %d relationships.", summary.counters().nodesDeleted(), - summary.counters().relationshipsDeleted())); + log.debug(() -> String.format("Deleted %d nodes and %d relationships.", summary.counters().nodesDeleted(), + summary.counters().relationshipsDeleted())); + }); } private ExecutableQuery createExecutableQuery(Class domainType, Statement statement) { @@ -1031,7 +1090,7 @@ public void setBeanFactory(BeanFactory beanFactory) throws BeansException { SpelAwareProxyProjectionFactory spelAwareProxyProjectionFactory = new SpelAwareProxyProjectionFactory(); spelAwareProxyProjectionFactory.setBeanClassLoader(beanClassLoader); spelAwareProxyProjectionFactory.setBeanFactory(beanFactory); - this.projectionFactoryf = spelAwareProxyProjectionFactory; + this.projectionFactory = spelAwareProxyProjectionFactory; Configuration cypherDslConfiguration = beanFactory .getBeanProvider(Configuration.class) @@ -1039,6 +1098,8 @@ public void setBeanFactory(BeanFactory beanFactory) throws BeansException { this.renderer = Renderer.getRenderer(cypherDslConfiguration); this.elementIdOrIdFunction = SpringDataCypherDsl.elementIdOrIdFunction.apply(cypherDslConfiguration.getDialect()); this.cypherGenerator.setElementIdOrIdFunction(elementIdOrIdFunction); + this.transactionTemplate = new TransactionTemplate(beanFactory.getBean(PlatformTransactionManager.class)); + this.transactionTemplateReadOnly = new TransactionTemplate(beanFactory.getBean(PlatformTransactionManager.class), readOnlyTransactionDefinition); } // only used for the CDI configuration @@ -1046,6 +1107,12 @@ public void setCypherRenderer(Renderer rendererFromCdiConfiguration) { this.renderer = rendererFromCdiConfiguration; } + // only used for the CDI configuration + public void setTransactionManager(PlatformTransactionManager platformTransactionManager) { + this.transactionTemplate = new TransactionTemplate(platformTransactionManager); + this.transactionTemplateReadOnly = new TransactionTemplate(platformTransactionManager, readOnlyTransactionDefinition); + } + @Override public ExecutableQuery toExecutableQuery(Class domainType, QueryFragmentsAndParameters queryFragmentsAndParameters) { @@ -1079,29 +1146,32 @@ public ExecutableSave save(Class domainType) { } List doSave(Iterable instances, Class domainType) { - // empty check - if (!instances.iterator().hasNext()) { - return Collections.emptyList(); - } + return transactionTemplate + .execute(tx -> { + // empty check + if (!instances.iterator().hasNext()) { + return Collections.emptyList(); + } - Class resultType = TemplateSupport.findCommonElementType(instances); + Class resultType = TemplateSupport.findCommonElementType(instances); - Collection pps = PropertyFilterSupport.addPropertiesFrom(domainType, resultType, - getProjectionFactory(), neo4jMappingContext); + Collection pps = PropertyFilterSupport.addPropertiesFrom(domainType, resultType, + getProjectionFactory(), neo4jMappingContext); - NestedRelationshipProcessingStateMachine stateMachine = new NestedRelationshipProcessingStateMachine(neo4jMappingContext); - List results = new ArrayList<>(); - EntityFromDtoInstantiatingConverter converter = new EntityFromDtoInstantiatingConverter<>(domainType, neo4jMappingContext); - for (R instance : instances) { - T domainObject = converter.convert(instance); + NestedRelationshipProcessingStateMachine stateMachine = new NestedRelationshipProcessingStateMachine(neo4jMappingContext); + List results = new ArrayList<>(); + EntityFromDtoInstantiatingConverter converter = new EntityFromDtoInstantiatingConverter<>(domainType, neo4jMappingContext); + for (R instance : instances) { + T domainObject = converter.convert(instance); - T savedEntity = saveImpl(domainObject, pps, stateMachine); + T savedEntity = saveImpl(domainObject, pps, stateMachine); - @SuppressWarnings("unchecked") - R convertedBack = (R) new DtoInstantiatingConverter(resultType, neo4jMappingContext).convertDirectly(savedEntity); - results.add(convertedBack); - } - return results; + @SuppressWarnings("unchecked") + R convertedBack = (R) new DtoInstantiatingConverter(resultType, neo4jMappingContext).convertDirectly(savedEntity); + results.add(convertedBack); + } + return results; + }); } final class DefaultExecutableQuery implements ExecutableQuery { @@ -1114,35 +1184,42 @@ final class DefaultExecutableQuery implements ExecutableQuery { @SuppressWarnings("unchecked") public List getResults() { - Collection all = createFetchSpec().map(Neo4jClient.RecordFetchSpec::all).orElse(Collections.emptyList()); - if (preparedQuery.resultsHaveBeenAggregated()) { - return all.stream().flatMap(nested -> ((Collection) nested).stream()).distinct().collect(Collectors.toList()); - } - return all.stream().collect(Collectors.toList()); + return transactionTemplate + .execute(tx -> { + Collection all = createFetchSpec().map(Neo4jClient.RecordFetchSpec::all).orElse(Collections.emptyList()); + if (preparedQuery.resultsHaveBeenAggregated()) { + return all.stream().flatMap(nested -> ((Collection) nested).stream()).distinct().collect(Collectors.toList()); + } + return all.stream().collect(Collectors.toList()); + }); } @SuppressWarnings("unchecked") public Optional getSingleResult() { - try { - Optional one = createFetchSpec().flatMap(Neo4jClient.RecordFetchSpec::one); - if (preparedQuery.resultsHaveBeenAggregated()) { - return one.map(aggregatedResults -> ((LinkedHashSet) aggregatedResults).iterator().next()); + return transactionTemplate.execute(tx -> { + try { + Optional one = createFetchSpec().flatMap(Neo4jClient.RecordFetchSpec::one); + if (preparedQuery.resultsHaveBeenAggregated()) { + return one.map(aggregatedResults -> ((LinkedHashSet) aggregatedResults).iterator().next()); + } + return one; + } catch (NoSuchRecordException e) { + // This exception is thrown by the driver in both cases when there are 0 or 1+n records + // So there has been an incorrect result size, but not too few results but too many. + throw new IncorrectResultSizeDataAccessException(e.getMessage(), 1); } - return one; - } catch (NoSuchRecordException e) { - // This exception is thrown by the driver in both cases when there are 0 or 1+n records - // So there has been an incorrect result size, but not too few results but too many. - throw new IncorrectResultSizeDataAccessException(e.getMessage(), 1); - } + }); } @SuppressWarnings("unchecked") public T getRequiredSingleResult() { - Optional one = createFetchSpec().flatMap(Neo4jClient.RecordFetchSpec::one); - if (preparedQuery.resultsHaveBeenAggregated()) { - one = one.map(aggregatedResults -> ((LinkedHashSet) aggregatedResults).iterator().next()); - } - return one.orElseThrow(() -> new NoResultException(1, preparedQuery.getQueryFragmentsAndParameters().getCypherQuery())); + return transactionTemplate.execute(tx -> { + Optional one = createFetchSpec().flatMap(Neo4jClient.RecordFetchSpec::one); + if (preparedQuery.resultsHaveBeenAggregated()) { + one = one.map(aggregatedResults -> ((LinkedHashSet) aggregatedResults).iterator().next()); + } + return one.orElseThrow(() -> new NoResultException(1, preparedQuery.getQueryFragmentsAndParameters().getCypherQuery())); + }); } private Optional> createFetchSpec() { diff --git a/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java b/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java index 195d1ad487..3a497d96a4 100644 --- a/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java +++ b/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java @@ -15,50 +15,19 @@ */ package org.springframework.data.neo4j.core; -import static org.neo4j.cypherdsl.core.Cypher.anyNode; -import static org.neo4j.cypherdsl.core.Cypher.asterisk; -import static org.neo4j.cypherdsl.core.Cypher.parameter; - -import org.neo4j.cypherdsl.core.FunctionInvocation; -import org.neo4j.cypherdsl.core.Named; -import org.neo4j.driver.Values; -import org.springframework.data.neo4j.core.mapping.IdDescription; -import org.springframework.data.neo4j.core.mapping.SpringDataCypherDsl; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.util.function.Tuple2; -import reactor.util.function.Tuple3; -import reactor.util.function.Tuples; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiFunction; -import java.util.function.BiPredicate; -import java.util.function.Function; -import java.util.function.Supplier; -import java.util.stream.Collectors; - import org.apache.commons.logging.LogFactory; import org.apiguardian.api.API; import org.neo4j.cypherdsl.core.Condition; import org.neo4j.cypherdsl.core.Cypher; +import org.neo4j.cypherdsl.core.FunctionInvocation; import org.neo4j.cypherdsl.core.Functions; +import org.neo4j.cypherdsl.core.Named; import org.neo4j.cypherdsl.core.Node; import org.neo4j.cypherdsl.core.Statement; import org.neo4j.cypherdsl.core.renderer.Configuration; import org.neo4j.cypherdsl.core.renderer.Renderer; import org.neo4j.driver.Value; +import org.neo4j.driver.Values; import org.neo4j.driver.types.Entity; import org.neo4j.driver.types.MapAccessor; import org.neo4j.driver.types.TypeSystem; @@ -83,6 +52,7 @@ import org.springframework.data.neo4j.core.mapping.DtoInstantiatingConverter; import org.springframework.data.neo4j.core.mapping.EntityFromDtoInstantiatingConverter; import org.springframework.data.neo4j.core.mapping.EntityInstanceWithSource; +import org.springframework.data.neo4j.core.mapping.IdDescription; import org.springframework.data.neo4j.core.mapping.IdentitySupport; import org.springframework.data.neo4j.core.mapping.MappingSupport; import org.springframework.data.neo4j.core.mapping.Neo4jMappingContext; @@ -94,6 +64,7 @@ import org.springframework.data.neo4j.core.mapping.NodeDescription; import org.springframework.data.neo4j.core.mapping.PropertyFilter; import org.springframework.data.neo4j.core.mapping.RelationshipDescription; +import org.springframework.data.neo4j.core.mapping.SpringDataCypherDsl; import org.springframework.data.neo4j.core.mapping.callback.ReactiveEventSupport; import org.springframework.data.neo4j.core.schema.TargetNode; import org.springframework.data.neo4j.repository.query.QueryFragments; @@ -104,7 +75,38 @@ import org.springframework.data.util.TypeInformation; import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; +import org.springframework.transaction.ReactiveTransactionManager; +import org.springframework.transaction.TransactionDefinition; +import org.springframework.transaction.reactive.TransactionalOperator; import org.springframework.util.Assert; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuple3; +import reactor.util.function.Tuples; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.function.BiPredicate; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static org.neo4j.cypherdsl.core.Cypher.anyNode; +import static org.neo4j.cypherdsl.core.Cypher.asterisk; +import static org.neo4j.cypherdsl.core.Cypher.parameter; /** * @author Michael J. Simons @@ -129,6 +131,17 @@ public final class ReactiveNeo4jTemplate implements private final CypherGenerator cypherGenerator; + private static final TransactionDefinition readOnlyTransactionDefinition = new TransactionDefinition() { + @Override + public boolean isReadOnly() { + return true; + } + }; + + private TransactionalOperator transactionalOperatorReadOnly; + + private TransactionalOperator transactionalOperator; + private ClassLoader beanClassLoader; private ReactiveEventSupport eventSupport; @@ -183,13 +196,13 @@ public Mono count(String cypherQuery) { public Mono count(String cypherQuery, Map parameters) { PreparedQuery preparedQuery = PreparedQuery.queryFor(Long.class).withCypherQuery(cypherQuery) .withParameters(parameters).build(); - return this.toExecutableQuery(preparedQuery).flatMap(ExecutableQuery::getSingleResult); + return transactionalOperatorReadOnly.transactional(this.toExecutableQuery(preparedQuery).flatMap(ExecutableQuery::getSingleResult)); } @Override public Flux findAll(Class domainType) { - return doFindAll(domainType, null); + return transactionalOperatorReadOnly.transactional(doFindAll(domainType, null)); } private Flux doFindAll(Class domainType, @Nullable Class resultType) { @@ -202,34 +215,34 @@ private Flux doFindAll(Class domainType, @Nullable Class resultType @Override public Flux findAll(Statement statement, Class domainType) { - return createExecutableQuery(domainType, statement).flatMapMany(ExecutableQuery::getResults); + return transactionalOperatorReadOnly.transactional(createExecutableQuery(domainType, statement).flatMapMany(ExecutableQuery::getResults)); } @Override public Flux findAll(Statement statement, Map parameters, Class domainType) { - return createExecutableQuery(domainType, null, statement, parameters).flatMapMany(ExecutableQuery::getResults); + return transactionalOperatorReadOnly.transactional(createExecutableQuery(domainType, null, statement, parameters).flatMapMany(ExecutableQuery::getResults)); } @Override public Mono findOne(Statement statement, Map parameters, Class domainType) { - return createExecutableQuery(domainType, null, statement, parameters).flatMap(ExecutableQuery::getSingleResult); + return transactionalOperatorReadOnly.transactional(createExecutableQuery(domainType, null, statement, parameters).flatMap(ExecutableQuery::getSingleResult)); } @Override public Flux findAll(String cypherQuery, Class domainType) { - return createExecutableQuery(domainType, cypherQuery).flatMapMany(ExecutableQuery::getResults); + return transactionalOperatorReadOnly.transactional(createExecutableQuery(domainType, cypherQuery).flatMapMany(ExecutableQuery::getResults)); } @Override public Flux findAll(String cypherQuery, Map parameters, Class domainType) { - return createExecutableQuery(domainType, null, cypherQuery, parameters).flatMapMany(ExecutableQuery::getResults); + return transactionalOperatorReadOnly.transactional(createExecutableQuery(domainType, null, cypherQuery, parameters).flatMapMany(ExecutableQuery::getResults)); } @Override public Mono findOne(String cypherQuery, Map parameters, Class domainType) { - return createExecutableQuery(domainType, null, cypherQuery, parameters).flatMap(ExecutableQuery::getSingleResult); + return transactionalOperatorReadOnly.transactional(createExecutableQuery(domainType, null, cypherQuery, parameters).flatMap(ExecutableQuery::getSingleResult)); } @Override @@ -253,8 +266,8 @@ Flux doFind(@Nullable String cypherQuery, @Nullable Map executableQuery.flatMapMany(ExecutableQuery::getResults); - case ONE -> executableQuery.flatMap(ExecutableQuery::getSingleResult).flux(); + case ALL -> transactionalOperatorReadOnly.transactional(executableQuery.flatMapMany(ExecutableQuery::getResults)); + case ONE -> transactionalOperatorReadOnly.transactional(executableQuery.flatMap(ExecutableQuery::getSingleResult).flux()); }; } @@ -290,10 +303,10 @@ public Mono findById(Object id, Class domainType) { Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(domainType); - return createExecutableQuery(domainType, null, + return transactionalOperatorReadOnly.transactional(createExecutableQuery(domainType, null, QueryFragmentsAndParameters.forFindById(entityMetaData, convertIdValues(entityMetaData.getRequiredIdProperty(), id))) - .flatMap(ExecutableQuery::getSingleResult); + .flatMap(ExecutableQuery::getSingleResult)); } @Override @@ -301,10 +314,10 @@ public Flux findAllById(Iterable ids, Class domainType) { Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(domainType); - return createExecutableQuery(domainType, null, + return transactionalOperatorReadOnly.transactional(createExecutableQuery(domainType, null, QueryFragmentsAndParameters.forFindByAllId(entityMetaData, convertIdValues(entityMetaData.getRequiredIdProperty(), ids))) - .flatMapMany(ExecutableQuery::getResults); + .flatMapMany(ExecutableQuery::getResults)); } @Override @@ -333,7 +346,7 @@ private Object convertIdValues(@Nullable Neo4jPersistentProperty idProperty, @Nu @Override public Mono save(T instance) { - return saveImpl(instance, Collections.emptySet(), null); + return transactionalOperator.transactional(saveImpl(instance, Collections.emptySet(), null)); } @Override @@ -343,7 +356,7 @@ public Mono saveAs(T instance, BiPredicate Mono saveAs(T instance, Class resultType) { Collection pps = PropertyFilterSupport.addPropertiesFrom(instance.getClass(), resultType, localProjectionFactory, neo4jMappingContext); - Mono savingPublisher = saveImpl(instance, pps, null); + Mono savingPublisher = transactionalOperator.transactional(saveImpl(instance, pps, null)); if (!resultType.isInterface()) { return savingPublisher.map(savedInstance -> { @@ -382,8 +395,8 @@ public Mono saveAs(T instance, Class resultType) { Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(savedInstance.getClass()); Neo4jPersistentProperty idProperty = entityMetaData.getIdProperty(); PersistentPropertyAccessor propertyAccessor = entityMetaData.getPropertyAccessor(savedInstance); - return this.findById(propertyAccessor.getProperty(idProperty), savedInstance.getClass()) - .map(loadedValue -> localProjectionFactory.createProjection(resultType, loadedValue)); + return transactionalOperatorReadOnly.transactional(this.findById(propertyAccessor.getProperty(idProperty), savedInstance.getClass()) + .map(loadedValue -> localProjectionFactory.createProjection(resultType, loadedValue))); }); } @@ -401,14 +414,14 @@ Flux doSave(Iterable instances, Class domainType) { NestedRelationshipProcessingStateMachine stateMachine = new NestedRelationshipProcessingStateMachine(neo4jMappingContext); EntityFromDtoInstantiatingConverter converter = new EntityFromDtoInstantiatingConverter<>(domainType, neo4jMappingContext); return Flux.fromIterable(instances) - .concatMap(instance -> { - T domainObject = converter.convert(instance); + .concatMap(instance -> { + T domainObject = converter.convert(instance); - @SuppressWarnings("unchecked") - Mono result = saveImpl(domainObject, pps, stateMachine) - .map(savedEntity -> (R) new DtoInstantiatingConverter(resultType, neo4jMappingContext).convertDirectly(savedEntity)); - return result; - }); + @SuppressWarnings("unchecked") + Mono result = transactionalOperator.transactional(saveImpl(domainObject, pps, stateMachine) + .map(savedEntity -> (R) new DtoInstantiatingConverter(resultType, neo4jMappingContext).convertDirectly(savedEntity))); + return result; + }); } private Mono saveImpl(T instance, @Nullable Collection includedProperties, @Nullable NestedRelationshipProcessingStateMachine stateMachine) { @@ -493,13 +506,13 @@ private Mono> determineDynamicLabels(T entityToBeSa @Override public Flux saveAll(Iterable instances) { - return saveAllImpl(instances, Collections.emptySet(), null); + return transactionalOperator.transactional(saveAllImpl(instances, Collections.emptySet(), null)); } @Override public Flux saveAllAs(Iterable instances, BiPredicate includeProperty) { - return saveAllImpl(instances, null, includeProperty); + return transactionalOperator.transactional(saveAllImpl(instances, null, includeProperty)); } @Override @@ -527,7 +540,7 @@ public Flux saveAllAs(Iterable instances, Class resultType) { Collection pps = PropertyFilterSupport.addPropertiesFrom(commonElementType, resultType, localProjectionFactory, neo4jMappingContext); - Flux savedInstances = saveAllImpl(instances, pps, null); + Flux savedInstances = transactionalOperator.transactional(saveAllImpl(instances, pps, null)); if (projectionInformation.isClosed()) { return savedInstances.map(instance -> localProjectionFactory.createProjection(resultType, instance)); } @@ -537,7 +550,7 @@ public Flux saveAllAs(Iterable instances, Class resultType) { return savedInstances.concatMap(savedInstance -> { PersistentPropertyAccessor propertyAccessor = entityMetaData.getPropertyAccessor(savedInstance); - return findById(propertyAccessor.getProperty(idProperty), commonElementType); + return transactionalOperatorReadOnly.transactional(findById(propertyAccessor.getProperty(idProperty), commonElementType)); }).map(instance -> localProjectionFactory.createProjection(resultType, instance)); } @@ -615,9 +628,10 @@ public Mono deleteAllById(Iterable ids, Class domainType) { Condition condition = entityMetaData.getIdExpression().in(parameter(nameOfParameter)); Statement statement = cypherGenerator.prepareDeleteOf(entityMetaData, condition); - return Mono.defer(() -> this.neo4jClient.query(() -> renderer.render(statement)) - .bind(convertIdValues(entityMetaData.getRequiredIdProperty(), ids)) - .to(nameOfParameter).run().then()); + return transactionalOperator.transactional(Mono.defer(() -> + this.neo4jClient.query(() -> renderer.render(statement)) + .bind(convertIdValues(entityMetaData.getRequiredIdProperty(), ids)) + .to(nameOfParameter).run().then())); } @Override @@ -630,9 +644,10 @@ public Mono deleteById(Object id, Class domainType) { Condition condition = entityMetaData.getIdExpression().isEqualTo(parameter(nameOfParameter)); Statement statement = cypherGenerator.prepareDeleteOf(entityMetaData, condition); - return Mono.defer(() -> this.neo4jClient.query(() -> renderer.render(statement)) - .bind(convertIdValues(entityMetaData.getRequiredIdProperty(), id)) - .to(nameOfParameter).run().then()); + return transactionalOperator.transactional(Mono.defer(() -> + this.neo4jClient.query(() -> renderer.render(statement)) + .bind(convertIdValues(entityMetaData.getRequiredIdProperty(), id)) + .to(nameOfParameter).run().then())); } @Override @@ -653,15 +668,16 @@ public Mono deleteByIdWithVersion(Object id, Class domainType, Neo4 parameters.put(nameOfParameter, convertIdValues(entityMetaData.getRequiredIdProperty(), id)); parameters.put(Constants.NAME_OF_VERSION_PARAM, versionValue); - return Mono.defer(() -> this.neo4jClient.query(() -> renderer.render(statement)) - .bindAll(parameters) - .fetch().one().switchIfEmpty(Mono.defer(() -> { - if (entityMetaData.hasVersionProperty()) { - return Mono.error(() -> new OptimisticLockingFailureException(OPTIMISTIC_LOCKING_ERROR_MESSAGE)); - } - return Mono.empty(); - }))) - .then(deleteById(id, domainType)); + return transactionalOperator.transactional(Mono.defer(() -> + this.neo4jClient.query(() -> renderer.render(statement)) + .bindAll(parameters) + .fetch().one().switchIfEmpty(Mono.defer(() -> { + if (entityMetaData.hasVersionProperty()) { + return Mono.error(() -> new OptimisticLockingFailureException(OPTIMISTIC_LOCKING_ERROR_MESSAGE)); + } + return Mono.empty(); + }))) + .then(deleteById(id, domainType))); } @Override @@ -669,7 +685,7 @@ public Mono deleteAll(Class domainType) { Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getRequiredPersistentEntity(domainType); Statement statement = cypherGenerator.prepareDeleteOf(entityMetaData); - return Mono.defer(() -> this.neo4jClient.query(() -> renderer.render(statement)).run().then()); + return transactionalOperator.transactional(Mono.defer(() -> this.neo4jClient.query(() -> renderer.render(statement)).run().then())); } private Mono> createExecutableQuery(Class domainType, Statement statement) { @@ -1165,6 +1181,8 @@ public void setBeanFactory(BeanFactory beanFactory) throws BeansException { this.renderer = Renderer.getRenderer(cypherDslConfiguration); this.elementIdOrIdFunction = SpringDataCypherDsl.elementIdOrIdFunction.apply(cypherDslConfiguration.getDialect()); this.cypherGenerator.setElementIdOrIdFunction(elementIdOrIdFunction); + this.transactionalOperatorReadOnly = TransactionalOperator.create(beanFactory.getBean(ReactiveTransactionManager.class), readOnlyTransactionDefinition); + this.transactionalOperator = TransactionalOperator.create(beanFactory.getBean(ReactiveTransactionManager.class)); } @Override @@ -1177,7 +1195,7 @@ public ExecutableSave save(Class domainType) { return new ReactiveFluentOperationSupport(this).save(domainType); } - static final class DefaultReactiveExecutableQuery implements ExecutableQuery { + final class DefaultReactiveExecutableQuery implements ExecutableQuery { private final PreparedQuery preparedQuery; private final ReactiveNeo4jClient.RecordFetchSpec fetchSpec; @@ -1193,12 +1211,12 @@ static final class DefaultReactiveExecutableQuery implements ExecutableQuery< @SuppressWarnings("unchecked") public Flux getResults() { - return fetchSpec.all().switchOnFirst((signal, f) -> { + return transactionalOperator.transactional(fetchSpec.all().switchOnFirst((signal, f) -> { if (signal.hasValue() && preparedQuery.resultsHaveBeenAggregated()) { return f.concatMap(nested -> Flux.fromIterable((Collection) nested).distinct()).distinct(); } return f; - }); + })); } /** @@ -1206,14 +1224,14 @@ public Flux getResults() { * @throws IncorrectResultSizeDataAccessException if there is no or more than one result */ public Mono getSingleResult() { - return fetchSpec.one().map(t -> { + return transactionalOperator.transactional(fetchSpec.one().map(t -> { if (t instanceof LinkedHashSet) { @SuppressWarnings("unchecked") T firstItem = (T) ((LinkedHashSet) t).iterator().next(); return firstItem; } return t; - }).onErrorMap(IndexOutOfBoundsException.class, e -> new IncorrectResultSizeDataAccessException(e.getMessage(), 1)); + }).onErrorMap(IndexOutOfBoundsException.class, e -> new IncorrectResultSizeDataAccessException(e.getMessage(), 1))); } } } diff --git a/src/main/java/org/springframework/data/neo4j/core/transaction/Neo4jTransactionManager.java b/src/main/java/org/springframework/data/neo4j/core/transaction/Neo4jTransactionManager.java index d1f20d7145..3ea3ea4a25 100644 --- a/src/main/java/org/springframework/data/neo4j/core/transaction/Neo4jTransactionManager.java +++ b/src/main/java/org/springframework/data/neo4j/core/transaction/Neo4jTransactionManager.java @@ -246,7 +246,7 @@ public void setApplicationContext(ApplicationContext applicationContext) throws // Otherwise we open a session and synchronize it. Session session = driver.session(Neo4jTransactionUtils.defaultSessionConfig(targetDatabase, asUser)); - Transaction transaction = session.beginTransaction(TransactionConfig.empty()); + Transaction transaction = session.beginTransaction(Neo4jTransactionUtils.createTransactionConfigFrom(TransactionDefinition.withDefaults(), -1)); // Manually create a new synchronization connectionHolder = new Neo4jTransactionHolder(new Neo4jTransactionContext(targetDatabase, asUser), session, transaction); connectionHolder.setSynchronizedWithTransaction(true); diff --git a/src/main/java/org/springframework/data/neo4j/core/transaction/ReactiveNeo4jTransactionManager.java b/src/main/java/org/springframework/data/neo4j/core/transaction/ReactiveNeo4jTransactionManager.java index 1ee02b62e9..f8c3223f0e 100644 --- a/src/main/java/org/springframework/data/neo4j/core/transaction/ReactiveNeo4jTransactionManager.java +++ b/src/main/java/org/springframework/data/neo4j/core/transaction/ReactiveNeo4jTransactionManager.java @@ -215,7 +215,7 @@ public static Mono retrieveReactiveTransaction( return Mono.defer(() -> { ReactiveSession session = driver.session(ReactiveSession.class, Neo4jTransactionUtils.defaultSessionConfig(targetDatabase, asUser)); - return Mono.fromDirect(session.beginTransaction(TransactionConfig.empty())).map(tx -> { + return Mono.fromDirect(session.beginTransaction(Neo4jTransactionUtils.createTransactionConfigFrom(TransactionDefinition.withDefaults(), -1))).map(tx -> { ReactiveNeo4jTransactionHolder newConnectionHolder = new ReactiveNeo4jTransactionHolder( new Neo4jTransactionContext(targetDatabase, asUser), session, tx); diff --git a/src/test/java/org/springframework/data/neo4j/integration/imperative/Neo4jTemplateIT.java b/src/test/java/org/springframework/data/neo4j/integration/imperative/Neo4jTemplateIT.java index b23fa2cd4f..abd670f832 100644 --- a/src/test/java/org/springframework/data/neo4j/integration/imperative/Neo4jTemplateIT.java +++ b/src/test/java/org/springframework/data/neo4j/integration/imperative/Neo4jTemplateIT.java @@ -49,6 +49,7 @@ import org.springframework.data.neo4j.test.Neo4jIntegrationTest; import org.springframework.transaction.PlatformTransactionManager; import org.springframework.transaction.annotation.EnableTransactionManagement; +import org.springframework.transaction.support.TransactionTemplate; import java.util.ArrayList; import java.util.Arrays; @@ -880,20 +881,22 @@ void saveWeirdHierarchy() { } @Test - void updatingFindShouldWork() { + void updatingFindShouldWork(@Autowired PlatformTransactionManager transactionManager) { Map params = new HashMap<>(); params.put("wrongName", "Siemons"); params.put("correctName", "Simons"); - Optional optionalResult = neo4jTemplate - .findOne("MERGE (p:Person {lastName: $wrongName}) ON MATCH set p.lastName = $correctName RETURN p", - params, Person.class); - - assertThat(optionalResult).hasValueSatisfying( - updatedPerson -> { - assertThat(updatedPerson.getLastName()).isEqualTo("Simons"); - assertThat(updatedPerson.getAddress()).isNull(); // We didn't fetch it - } - ); + new TransactionTemplate(transactionManager).executeWithoutResult(tx -> { + Optional optionalResult = neo4jTemplate + .findOne("MERGE (p:Person {lastName: $wrongName}) ON MATCH set p.lastName = $correctName RETURN p", + params, Person.class); + + assertThat(optionalResult).hasValueSatisfying( + updatedPerson -> { + assertThat(updatedPerson.getLastName()).isEqualTo("Simons"); + assertThat(updatedPerson.getAddress()).isNull(); // We didn't fetch it + } + ); + }); } @Test diff --git a/src/test/java/org/springframework/data/neo4j/integration/imperative/RepositoryIT.java b/src/test/java/org/springframework/data/neo4j/integration/imperative/RepositoryIT.java index 77ca706c55..4b9982d535 100644 --- a/src/test/java/org/springframework/data/neo4j/integration/imperative/RepositoryIT.java +++ b/src/test/java/org/springframework/data/neo4j/integration/imperative/RepositoryIT.java @@ -499,10 +499,12 @@ void aggregateThroughQueryIntoCustomObjectDTOShouldWork(@Autowired PersonReposit } @Test // DATAGRAPH-1429 - void queryAggregatesShouldWorkWithTheTemplate(@Autowired Neo4jTemplate template) { + void queryAggregatesShouldWorkWithTheTemplate(@Autowired Neo4jTemplate template, @Autowired PlatformTransactionManager transactionManager) { + new TransactionTemplate(transactionManager).executeWithoutResult(tx -> { - List people = template.findAll("unwind range(1,5) as i with i create (p:Person {firstName: toString(i)}) return p", Person.class); - assertThat(people).extracting(Person::getFirstName).containsExactly("1", "2", "3", "4", "5"); + List people = template.findAll("unwind range(1,5) as i with i create (p:Person {firstName: toString(i)}) return p", Person.class); + assertThat(people).extracting(Person::getFirstName).containsExactly("1", "2", "3", "4", "5"); + }); } @Test diff --git a/src/test/java/org/springframework/data/neo4j/integration/reactive/ReactiveNeo4jTemplateIT.java b/src/test/java/org/springframework/data/neo4j/integration/reactive/ReactiveNeo4jTemplateIT.java index d13c83c57b..63bafad7d2 100644 --- a/src/test/java/org/springframework/data/neo4j/integration/reactive/ReactiveNeo4jTemplateIT.java +++ b/src/test/java/org/springframework/data/neo4j/integration/reactive/ReactiveNeo4jTemplateIT.java @@ -57,6 +57,7 @@ import org.springframework.data.neo4j.test.Neo4jReactiveTestConfiguration; import org.springframework.transaction.ReactiveTransactionManager; import org.springframework.transaction.annotation.EnableTransactionManagement; +import org.springframework.transaction.reactive.TransactionalOperator; import reactor.core.publisher.Flux; import reactor.test.StepVerifier; @@ -533,9 +534,10 @@ void saveAllProjectionShouldWork(@Autowired ReactiveNeo4jTemplate template) { } @Test - void saveAllAsWithOpenProjectionShouldWork(@Autowired ReactiveNeo4jTemplate template) { + void saveAllAsWithOpenProjectionShouldWork(@Autowired ReactiveNeo4jTemplate template, @Autowired ReactiveTransactionManager transactionManager) { // Using a query on purpose so that the address is null + TransactionalOperator.create(transactionManager).transactional( template.findOne("MATCH (p:Person {lastName: $lastName}) RETURN p", Collections.singletonMap("lastName", "Siemons"), Person.class) .zipWith(template.findOne("MATCH (p:Person {lastName: $lastName}) RETURN p", @@ -550,7 +552,7 @@ void saveAllAsWithOpenProjectionShouldWork(@Autowired ReactiveNeo4jTemplate temp p2.setFirstName("Helga"); p2.setLastName("Schneider"); return template.saveAllAs(Arrays.asList(p1, p2), OpenProjection.class); - }) + })) .map(OpenProjection::getFullName) .sort() .as(StepVerifier::create) @@ -832,13 +834,15 @@ void saveAllAsWithClosedProjectionShouldWork(@Autowired ReactiveNeo4jTemplate te } @Test - void updatingFindShouldWork() { + void updatingFindShouldWork(@Autowired ReactiveTransactionManager transactionManager) { Map params = new HashMap<>(); params.put("wrongName", "Siemons"); params.put("correctName", "Simons"); - neo4jTemplate - .findOne("MERGE (p:Person {lastName: $wrongName}) ON MATCH set p.lastName = $correctName RETURN p", - params, Person.class) + TransactionalOperator.create(transactionManager) + .transactional( + neo4jTemplate + .findOne("MERGE (p:Person {lastName: $wrongName}) ON MATCH set p.lastName = $correctName RETURN p", + params, Person.class)) .as(StepVerifier::create) .consumeNextWith(updatedPerson -> { diff --git a/src/test/java/org/springframework/data/neo4j/integration/reactive/ReactiveRepositoryIT.java b/src/test/java/org/springframework/data/neo4j/integration/reactive/ReactiveRepositoryIT.java index 4d29e1175d..c4952a480c 100644 --- a/src/test/java/org/springframework/data/neo4j/integration/reactive/ReactiveRepositoryIT.java +++ b/src/test/java/org/springframework/data/neo4j/integration/reactive/ReactiveRepositoryIT.java @@ -702,9 +702,9 @@ void aggregateThroughQueryIntoListShouldWork(@Autowired ReactivePersonRepository } @Test // DATAGRAPH-1429 - void queryAggregatesShouldWorkWithTheTemplate(@Autowired ReactiveNeo4jTemplate template) { + void queryAggregatesShouldWorkWithTheTemplate(@Autowired ReactiveNeo4jTemplate template, @Autowired ReactiveTransactionManager reactiveTransactionManager) { - Flux people = template.findAll("unwind range(1,5) as i with i create (p:Person {firstName: toString(i)}) return p", Person.class); + Flux people = TransactionalOperator.create(reactiveTransactionManager).transactional(template.findAll("unwind range(1,5) as i with i create (p:Person {firstName: toString(i)}) return p", Person.class)); StepVerifier.create(people.map(Person::getFirstName)) .expectNext("1", "2", "3", "4", "5")