Skip to content

Fix basic update overriding values. #4921

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.5.0-SNAPSHOT</version>
<version>4.5.x-GH-4918-SNAPSHOT</version>
<packaging>pom</packaging>

<name>Spring Data MongoDB</name>
Expand Down
2 changes: 1 addition & 1 deletion spring-data-mongodb-distribution/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.5.0-SNAPSHOT</version>
<version>4.5.x-GH-4918-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
2 changes: 1 addition & 1 deletion spring-data-mongodb/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.5.0-SNAPSHOT</version>
<version>4.5.x-GH-4918-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,21 @@
*/
package org.springframework.data.mongodb.core.query;

import java.util.Arrays;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;

import org.bson.Document;

import org.springframework.lang.Nullable;
import org.springframework.util.ClassUtils;

/**
* {@link Document}-based {@link Update} variant.
*
* @author Thomas Risberg
* @author John Brisbin
* @author Oliver Gierke
Expand All @@ -33,74 +41,114 @@ public class BasicUpdate extends Update {
private final Document updateObject;

public BasicUpdate(String updateString) {
super();
this.updateObject = Document.parse(updateString);
this(Document.parse(updateString));
}

public BasicUpdate(Document updateObject) {
super();
this.updateObject = updateObject;
}

@Override
public Update set(String key, @Nullable Object value) {
updateObject.put("$set", Collections.singletonMap(key, value));
setOperationValue("$set", key, value);
return this;
}

@Override
public Update unset(String key) {
updateObject.put("$unset", Collections.singletonMap(key, 1));
setOperationValue("$unset", key, 1);
return this;
}

@Override
public Update inc(String key, Number inc) {
updateObject.put("$inc", Collections.singletonMap(key, inc));
setOperationValue("$inc", key, inc);
return this;
}

@Override
public Update push(String key, @Nullable Object value) {
updateObject.put("$push", Collections.singletonMap(key, value));
setOperationValue("$push", key, value);
return this;
}

@Override
public Update addToSet(String key, @Nullable Object value) {
updateObject.put("$addToSet", Collections.singletonMap(key, value));
setOperationValue("$addToSet", key, value);
return this;
}

@Override
public Update pop(String key, Position pos) {
updateObject.put("$pop", Collections.singletonMap(key, (pos == Position.FIRST ? -1 : 1)));
setOperationValue("$pop", key, (pos == Position.FIRST ? -1 : 1));
return this;
}

@Override
public Update pull(String key, @Nullable Object value) {
updateObject.put("$pull", Collections.singletonMap(key, value));
setOperationValue("$pull", key, value);
return this;
}

@Override
public Update pullAll(String key, Object[] values) {
Document keyValue = new Document();
keyValue.put(key, Arrays.copyOf(values, values.length));
updateObject.put("$pullAll", keyValue);
setOperationValue("$pullAll", key, List.of(values), (o, o2) -> {

if (o instanceof List<?> prev && o2 instanceof List<?> currentValue) {
List<Object> merged = new ArrayList<>(prev.size() + currentValue.size());
merged.addAll(prev);
merged.addAll(currentValue);
return merged;
}

return o2;
});
return this;
}

@Override
public Update rename(String oldName, String newName) {
updateObject.put("$rename", Collections.singletonMap(oldName, newName));
setOperationValue("$rename", oldName, newName);
return this;
}

@Override
public boolean modifies(String key) {
return super.modifies(key) || Update.fromDocument(getUpdateObject()).modifies(key);
}

@Override
public Document getUpdateObject() {
return updateObject;
}

void setOperationValue(String operator, String key, @Nullable Object value) {
setOperationValue(operator, key, value, (o, o2) -> o2);
}

void setOperationValue(String operator, String key, @Nullable Object value,
BiFunction<Object, Object, Object> mergeFunction) {

if (!updateObject.containsKey(operator)) {
updateObject.put(operator, Collections.singletonMap(key, value));
} else {
Object o = updateObject.get(operator);
if (o instanceof Map<?, ?> existing) {
Map<Object, Object> target = new LinkedHashMap<>(existing);

if (target.containsKey(key)) {
target.put(key, mergeFunction.apply(target.get(key), value));
} else {
target.put(key, value);
}
updateObject.put(operator, target);
} else {
throw new IllegalStateException(
"Cannot add ['%s' : { '%s' : ... }]. Operator already exists with value of type [%s] which is not suitable for appending"
.formatted(operator, key,
o != null ? ClassUtils.getShortName(o.getClass()) : "null"));
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -447,13 +447,11 @@ protected void addMultiFieldOperation(String operator, String key, @Nullable Obj
if (existingValue == null) {
keyValueMap = new Document();
this.modifierOps.put(operator, keyValueMap);
} else if (existingValue instanceof Document document) {
keyValueMap = document;
} else {
if (existingValue instanceof Document document) {
keyValueMap = document;
} else {
throw new InvalidDataAccessApiUsageException(
"Modifier Operations should be a LinkedHashMap but was " + existingValue.getClass());
}
throw new InvalidDataAccessApiUsageException(
"Modifier Operations should be a LinkedHashMap but was " + existingValue.getClass());
}

keyValueMap.put(key, value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.springframework.data.mongodb.core.aggregation.SetOperation;
import org.springframework.data.mongodb.core.mapping.Document;
import org.springframework.data.mongodb.core.mapping.Field;
import org.springframework.data.mongodb.core.query.BasicUpdate;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.Update;
Expand Down Expand Up @@ -326,6 +327,20 @@ void updateFirstWithSort(Class<?> domainType, Sort sort, UpdateDefinition update
"Science is real!");
}

@Test // GH-4918
void updateShouldHonorVersionProvided() {

Versioned source = template.insert(Versioned.class).one(new Versioned("id-1", "value-0"));

Update update = new BasicUpdate("{ '$set' : { 'value' : 'changed' }, '$inc' : { 'version' : 10 } }");
template.update(Versioned.class).matching(Query.query(Criteria.where("id").is(source.id))).apply(update).first();

assertThat(
collection(Versioned.class).find(new org.bson.Document("_id", source.id)).limit(1).into(new ArrayList<>()))
.containsExactly(new org.bson.Document("_id", source.id).append("version", 10L).append("value", "changed")
.append("_class", "org.springframework.data.mongodb.core.MongoTemplateUpdateTests$Versioned"));
}

private List<org.bson.Document> all(Class<?> type) {
return collection(type).find(new org.bson.Document()).into(new ArrayList<>());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.core.query;

import static org.assertj.core.api.Assertions.*;
import static org.springframework.data.mongodb.test.util.Assertions.*;
import static org.springframework.data.mongodb.test.util.Assertions.assertThat;

import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Stream;

import org.bson.Document;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.MethodSource;

import org.springframework.data.mongodb.core.query.Update.Position;

/**
* Unit tests for {@link BasicUpdate}.
*
* @author Christoph Strobl
* @author Mark Paluch
*/
class BasicUpdateUnitTests {

@Test // GH-4918
void setOperationValueShouldAppendsOpsCorrectly() {

BasicUpdate basicUpdate = new BasicUpdate("{}");
basicUpdate.setOperationValue("$set", "key1", "alt");
basicUpdate.setOperationValue("$set", "key2", "nps");
basicUpdate.setOperationValue("$unset", "key3", "x");

assertThat(basicUpdate.getUpdateObject())
.isEqualTo("{ '$set' : { 'key1' : 'alt', 'key2' : 'nps' }, '$unset' : { 'key3' : 'x' } }");
}

@Test // GH-4918
void setOperationErrorsOnNonMapType() {

BasicUpdate basicUpdate = new BasicUpdate("{ '$set' : 1 }");
assertThatExceptionOfType(IllegalStateException.class)
.isThrownBy(() -> basicUpdate.setOperationValue("$set", "k", "v"));
}

@ParameterizedTest // GH-4918
@CsvSource({ //
"{ }, k1, false", //
"{ '$set' : { 'k1' : 'v1' } }, k1, true", //
"{ '$set' : { 'k1' : 'v1' } }, k2, false", //
"{ '$set' : { 'k1.k2' : 'v1' } }, k1, false", //
"{ '$set' : { 'k1.k2' : 'v1' } }, k1.k2, true", //
"{ '$set' : { 'k1' : 'v1' } }, '', false", //
"{ '$inc' : { 'k1' : 1 } }, k1, true" })
void modifiesLooksUpKeyCorrectly(String source, String key, boolean modified) {

BasicUpdate basicUpdate = new BasicUpdate(source);
assertThat(basicUpdate.modifies(key)).isEqualTo(modified);
}

@ParameterizedTest // GH-4918
@MethodSource("updateOpArgs")
void updateOpsShouldNotOverrideExistingValues(String operator, Function<BasicUpdate, Update> updateFunction) {

Document source = Document.parse("{ '%s' : { 'key-1' : 'value-1' } }".formatted(operator));
Update update = updateFunction.apply(new BasicUpdate(source));

assertThat(update.getUpdateObject()).containsEntry("%s.key-1".formatted(operator), "value-1")
.containsKey("%s.key-2".formatted(operator));
}

@Test // GH-4918
void shouldNotOverridePullAll() {

Document source = Document.parse("{ '$pullAll' : { 'key-1' : ['value-1'] } }");
Update update = new BasicUpdate(source).pullAll("key-1", new String[] { "value-2" }).pullAll("key-2",
new String[] { "value-3" });

assertThat(update.getUpdateObject()).containsEntry("$pullAll.key-1", Arrays.asList("value-1", "value-2"))
.containsEntry("$pullAll.key-2", List.of("value-3"));
}

static Stream<Arguments> updateOpArgs() {
return Stream.of( //
Arguments.of("$set", (Function<BasicUpdate, Update>) update -> update.set("key-2", "value-2")),
Arguments.of("$unset", (Function<BasicUpdate, Update>) update -> update.unset("key-2")),
Arguments.of("$inc", (Function<BasicUpdate, Update>) update -> update.inc("key-2", 1)),
Arguments.of("$push", (Function<BasicUpdate, Update>) update -> update.push("key-2", "value-2")),
Arguments.of("$addToSet", (Function<BasicUpdate, Update>) update -> update.addToSet("key-2", "value-2")),
Arguments.of("$pop", (Function<BasicUpdate, Update>) update -> update.pop("key-2", Position.FIRST)),
Arguments.of("$pull", (Function<BasicUpdate, Update>) update -> update.pull("key-2", "value-2")),
Arguments.of("$pullAll",
(Function<BasicUpdate, Update>) update -> update.pullAll("key-2", new String[] { "value-2" })),
Arguments.of("$rename", (Function<BasicUpdate, Update>) update -> update.rename("key-2", "value-2")));
};
}
Loading