diff --git a/src/main/java/org/springframework/data/elasticsearch/core/AbstractElasticsearchTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/AbstractElasticsearchTemplate.java index b88caee93..844190f08 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/AbstractElasticsearchTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/AbstractElasticsearchTemplate.java @@ -22,6 +22,7 @@ import org.springframework.data.elasticsearch.core.convert.ElasticsearchConverter; import org.springframework.data.elasticsearch.core.convert.MappingElasticsearchConverter; import org.springframework.data.elasticsearch.core.document.SearchDocumentResponse; +import org.springframework.data.elasticsearch.core.event.AfterSaveCallback; import org.springframework.data.elasticsearch.core.event.BeforeConvertCallback; import org.springframework.data.elasticsearch.core.mapping.ElasticsearchPersistentEntity; import org.springframework.data.elasticsearch.core.mapping.ElasticsearchPersistentProperty; @@ -45,6 +46,7 @@ * * @author Sascha Woo * @author Peter-Josef Meisch + * @author Roman Puchkovskiy */ public abstract class AbstractElasticsearchTemplate implements ElasticsearchOperations, ApplicationContextAware { @@ -117,8 +119,13 @@ public T save(T entity, IndexCoordinates index) { Assert.notNull(entity, "entity must not be null"); Assert.notNull(index, "index must not be null"); - index(getIndexQuery(entity), index); - return entity; + IndexQuery query = getIndexQuery(entity); + index(query, index); + + // suppressing because it's either entity itself or something of a correct type returned by an entity callback + @SuppressWarnings("unchecked") + T castResult = (T) query.getObject(); + return castResult; } @Override @@ -151,7 +158,10 @@ public Iterable save(Iterable entities, IndexCoordinates index) { }); } - return entities; + return indexQueries.stream() + .map(IndexQuery::getObject) + .map(entity -> (T) entity) + .collect(Collectors.toList()); } @Override @@ -455,11 +465,39 @@ protected void maybeCallbackBeforeConvertWithQuery(Object query) { } // this can be called with either a List or a List; these query classes - // don't have a common bas class, therefore the List argument + // don't have a common base class, therefore the List argument protected void maybeCallbackBeforeConvertWithQueries(List queries) { queries.forEach(this::maybeCallbackBeforeConvertWithQuery); } + protected T maybeCallbackAfterSave(T entity) { + + if (entityCallbacks != null) { + return entityCallbacks.callback(AfterSaveCallback.class, entity); + } + + return entity; + } + + protected void maybeCallbackAfterSaveWithQuery(Object query) { + + if (query instanceof IndexQuery) { + IndexQuery indexQuery = (IndexQuery) query; + Object queryObject = indexQuery.getObject(); + + if (queryObject != null) { + queryObject = maybeCallbackAfterSave(queryObject); + indexQuery.setObject(queryObject); + } + } + } + + // this can be called with either a List or a List; these query classes + // don't have a common base class, therefore the List argument + protected void maybeCallbackAfterSaveWithQueries(List queries) { + queries.forEach(this::maybeCallbackAfterSaveWithQuery); + } + // endregion } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java index aa9bfd710..114f1f99a 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java @@ -145,6 +145,9 @@ public String index(IndexQuery query, IndexCoordinates index) { if (queryObject != null) { setPersistentEntityId(queryObject, documentId); } + + maybeCallbackAfterSaveWithQuery(query); + return documentId; } @@ -226,7 +229,10 @@ public UpdateResponse update(UpdateQuery query, IndexCoordinates index) { private List doBulkOperation(List queries, BulkOptions bulkOptions, IndexCoordinates index) { maybeCallbackBeforeConvertWithQueries(queries); BulkRequest bulkRequest = requestFactory.bulkRequest(queries, bulkOptions, index); - return checkForBulkOperationFailure(execute(client -> client.bulk(bulkRequest, RequestOptions.DEFAULT))); + List ids = checkForBulkOperationFailure(execute( + client -> client.bulk(bulkRequest, RequestOptions.DEFAULT))); + maybeCallbackAfterSaveWithQueries(queries); + return ids; } // endregion diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java index 4a13fa7bf..17a117d46 100755 --- a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java @@ -31,14 +31,12 @@ import org.elasticsearch.action.search.MultiSearchResponse; import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.action.search.SearchScrollRequestBuilder; import org.elasticsearch.action.update.UpdateRequestBuilder; import org.elasticsearch.client.Client; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.search.suggest.SuggestBuilder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.data.domain.Pageable; import org.springframework.data.elasticsearch.core.convert.ElasticsearchConverter; import org.springframework.data.elasticsearch.core.document.DocumentAdapters; import org.springframework.data.elasticsearch.core.document.SearchDocumentResponse; @@ -79,6 +77,7 @@ * @author Martin Choraine * @author Farid Azaza * @author Gyula Attila Csorogi + * @author Roman Puchkovskiy * @deprecated as of 4.0 */ @Deprecated @@ -153,6 +152,8 @@ public String index(IndexQuery query, IndexCoordinates index) { setPersistentEntityId(queryObject, documentId); } + maybeCallbackAfterSaveWithQuery(query); + return documentId; } @@ -188,7 +189,11 @@ public List bulkIndex(List queries, BulkOptions bulkOptions, Assert.notNull(queries, "List of IndexQuery must not be null"); Assert.notNull(bulkOptions, "BulkOptions must not be null"); - return doBulkOperation(queries, bulkOptions, index); + List ids = doBulkOperation(queries, bulkOptions, index); + + maybeCallbackAfterSaveWithQueries(queries); + + return ids; } @Override diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ReactiveElasticsearchTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ReactiveElasticsearchTemplate.java index 7588f29df..ad9d9335f 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/ReactiveElasticsearchTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ReactiveElasticsearchTemplate.java @@ -70,6 +70,7 @@ import org.springframework.data.elasticsearch.core.document.Document; import org.springframework.data.elasticsearch.core.document.DocumentAdapters; import org.springframework.data.elasticsearch.core.document.SearchDocument; +import org.springframework.data.elasticsearch.core.event.ReactiveAfterSaveCallback; import org.springframework.data.elasticsearch.core.event.ReactiveBeforeConvertCallback; import org.springframework.data.elasticsearch.core.mapping.ElasticsearchPersistentEntity; import org.springframework.data.elasticsearch.core.mapping.ElasticsearchPersistentProperty; @@ -97,6 +98,7 @@ * @author Peter-Josef Meisch * @author Mathias Teier * @author Aleksei Arsenev + * @author Roman Puchkovskiy * @since 3.2 */ public class ReactiveElasticsearchTemplate implements ReactiveElasticsearchOperations, ApplicationContextAware { @@ -185,7 +187,8 @@ public Mono save(T entity, IndexCoordinates index) { return doIndex(entity, adaptableEntity, index) // .map(it -> { return adaptableEntity.populateIdIfNecessary(it.getId()); - }); + }) + .flatMap(this::maybeCallAfterSave); } @Override @@ -213,11 +216,11 @@ public Flux saveAll(Mono> entities, Ind .map(e -> getIndexQuery(e.getBean(), e)) // .collect(Collectors.toList()); return doBulkOperation(indexRequests, BulkOptions.defaultOptions(), index) // - .map(bulkItemResponse -> { + .flatMap(bulkItemResponse -> { AdaptibleEntity mappedEntity = iterator.next(); mappedEntity.populateIdIfNecessary(bulkItemResponse.getResponse().getId()); - return mappedEntity.getBean(); + return maybeCallAfterSave(mappedEntity.getBean()); }); }); } @@ -882,5 +885,14 @@ protected Mono maybeCallBeforeConvert(T entity) { return Mono.just(entity); } + + protected Mono maybeCallAfterSave(T entity) { + + if (null != entityCallbacks) { + return entityCallbacks.callback(ReactiveAfterSaveCallback.class, entity); + } + + return Mono.just(entity); + } // endregion } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/event/AfterSaveCallback.java b/src/main/java/org/springframework/data/elasticsearch/core/event/AfterSaveCallback.java new file mode 100644 index 000000000..0ad8e1c3b --- /dev/null +++ b/src/main/java/org/springframework/data/elasticsearch/core/event/AfterSaveCallback.java @@ -0,0 +1,39 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.elasticsearch.core.event; + +import org.springframework.data.mapping.callback.EntityCallback; +import org.springframework.data.mapping.callback.EntityCallbacks; + +/** + * Entity callback triggered after save of an entity. + * + * @author Roman Puchkovskiy + * @since 4.0 + * @see EntityCallbacks + */ +@FunctionalInterface +public interface AfterSaveCallback extends EntityCallback { + + /** + * Entity callback method invoked after a domain object is saved. Can return either the same or a modified + * instance of the domain object. + * + * @param entity the domain object that was saved. + * @return the domain object that was persisted. + */ + T onAfterSave(T entity); +} diff --git a/src/main/java/org/springframework/data/elasticsearch/core/event/ReactiveAfterSaveCallback.java b/src/main/java/org/springframework/data/elasticsearch/core/event/ReactiveAfterSaveCallback.java new file mode 100644 index 000000000..561de46fd --- /dev/null +++ b/src/main/java/org/springframework/data/elasticsearch/core/event/ReactiveAfterSaveCallback.java @@ -0,0 +1,40 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.elasticsearch.core.event; + +import org.reactivestreams.Publisher; +import org.springframework.data.mapping.callback.EntityCallback; +import org.springframework.data.mapping.callback.ReactiveEntityCallbacks; + +/** + * Entity callback triggered after save of an entity. + * + * @author Roman Puchkovskiy + * @since 4.0 + * @see ReactiveEntityCallbacks + */ +@FunctionalInterface +public interface ReactiveAfterSaveCallback extends EntityCallback { + + /** + * Entity callback method invoked after a domain object is saved. Can return either the same or a modified + * instance of the domain object. + * + * @param entity the domain object that was saved. + * @return a {@link Publisher} emitting the domain object to be returned to the caller. + */ + Publisher onAfterSave(T entity); +} diff --git a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplateCallbackTests.java b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplateCallbackTests.java new file mode 100644 index 000000000..0edb8ad51 --- /dev/null +++ b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplateCallbackTests.java @@ -0,0 +1,276 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.elasticsearch.core; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.*; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.index.IndexResponse; +import org.elasticsearch.client.RequestOptions; +import org.elasticsearch.client.RestHighLevelClient; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; +import org.springframework.data.annotation.Id; +import org.springframework.data.elasticsearch.core.event.AfterSaveCallback; +import org.springframework.data.elasticsearch.core.mapping.IndexCoordinates; +import org.springframework.data.elasticsearch.core.query.BulkOptions; +import org.springframework.data.elasticsearch.core.query.IndexQuery; +import org.springframework.data.mapping.callback.EntityCallbacks; +import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; + +/** + * @author Roman Puchkovskiy + */ +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +public class ElasticsearchRestTemplateCallbackTests { + + private ElasticsearchRestTemplate template; + + @Mock + private RestHighLevelClient client; + + @Mock + private IndexResponse indexResponse; + @Mock + private BulkResponse bulkResponse; + @Mock + private BulkItemResponse bulkItemResponse; + + @BeforeEach + public void setUp() throws Exception { + template = new ElasticsearchRestTemplate(client); + + doReturn(indexResponse).when(client).index(any(IndexRequest.class), any(RequestOptions.class)); + doReturn("response-id").when(indexResponse).getId(); + + doReturn(bulkResponse).when(client).bulk(any(BulkRequest.class), any(RequestOptions.class)); + doReturn(new BulkItemResponse[] {bulkItemResponse, bulkItemResponse}).when(bulkResponse).getItems(); + doReturn("response-id").when(bulkItemResponse).getId(); + } + + @Test // DATAES-771 + void saveOneShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(EntityCallbacks.create(afterSaveCallback)); + + Person entity = new Person("init", "luke"); + + Person saved = template.save(entity); + + verify(afterSaveCallback).onAfterSave(eq(entity)); + assertThat(saved.id).isEqualTo("after-save"); + } + + @Test // DATAES-771 + void saveWithIndexCoordinatesShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(EntityCallbacks.create(afterSaveCallback)); + + Person entity = new Person("init", "luke"); + + Person saved = template.save(entity, IndexCoordinates.of("index")); + + verify(afterSaveCallback).onAfterSave(eq(entity)); + assertThat(saved.id).isEqualTo("after-save"); + } + + @Test // DATAES-771 + void saveArrayShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(EntityCallbacks.create(afterSaveCallback)); + + Person entity1 = new Person("init1", "luke1"); + Person entity2 = new Person("init2", "luke2"); + + Iterable saved = template.save(entity1, entity2); + + verify(afterSaveCallback, times(2)).onAfterSave(any()); + Iterator savedIterator = saved.iterator(); + assertThat(savedIterator.next().getId()).isEqualTo("after-save"); + assertThat(savedIterator.next().getId()).isEqualTo("after-save"); + } + + @Test // DATAES-771 + void saveIterableShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(EntityCallbacks.create(afterSaveCallback)); + + Person entity1 = new Person("init1", "luke1"); + Person entity2 = new Person("init2", "luke2"); + + Iterable saved = template.save(Arrays.asList(entity1, entity2)); + + verify(afterSaveCallback, times(2)).onAfterSave(any()); + Iterator savedIterator = saved.iterator(); + assertThat(savedIterator.next().getId()).isEqualTo("after-save"); + assertThat(savedIterator.next().getId()).isEqualTo("after-save"); + } + + @Test // DATAES-771 + void saveIterableWithIndexCoordinatesShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(EntityCallbacks.create(afterSaveCallback)); + + Person entity1 = new Person("init1", "luke1"); + Person entity2 = new Person("init2", "luke2"); + + Iterable saved = template.save(Arrays.asList(entity1, entity2), IndexCoordinates.of("index")); + + verify(afterSaveCallback, times(2)).onAfterSave(any()); + Iterator savedIterator = saved.iterator(); + assertThat(savedIterator.next().getId()).isEqualTo("after-save"); + assertThat(savedIterator.next().getId()).isEqualTo("after-save"); + } + + @Test // DATAES-771 + void indexShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(EntityCallbacks.create(afterSaveCallback)); + + Person entity = new Person("init", "luke"); + + IndexQuery indexQuery = indexQueryForEntity(entity); + template.index(indexQuery, IndexCoordinates.of("index")); + + verify(afterSaveCallback).onAfterSave(eq(entity)); + Person newPerson = (Person) indexQuery.getObject(); + assertThat(newPerson.id).isEqualTo("after-save"); + } + + private IndexQuery indexQueryForEntity(Person entity) { + IndexQuery indexQuery = new IndexQuery(); + indexQuery.setObject(entity); + return indexQuery; + } + + @Test // DATAES-771 + void bulkIndexShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(EntityCallbacks.create(afterSaveCallback)); + + Person entity1 = new Person("init1", "luke1"); + Person entity2 = new Person("init2", "luke2"); + + IndexQuery query1 = indexQueryForEntity(entity1); + IndexQuery query2 = indexQueryForEntity(entity2); + template.bulkIndex(Arrays.asList(query1, query2), IndexCoordinates.of("index")); + + verify(afterSaveCallback, times(2)).onAfterSave(any()); + Person savedPerson1 = (Person) query1.getObject(); + Person savedPerson2 = (Person) query2.getObject(); + assertThat(savedPerson1.getId()).isEqualTo("after-save"); + assertThat(savedPerson2.getId()).isEqualTo("after-save"); + } + + @Test // DATAES-771 + void bulkIndexWithOptionsShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(EntityCallbacks.create(afterSaveCallback)); + + Person entity1 = new Person("init1", "luke1"); + Person entity2 = new Person("init2", "luke2"); + + IndexQuery query1 = indexQueryForEntity(entity1); + IndexQuery query2 = indexQueryForEntity(entity2); + template.bulkIndex(Arrays.asList(query1, query2), BulkOptions.defaultOptions(), IndexCoordinates.of("index")); + + verify(afterSaveCallback, times(2)).onAfterSave(any()); + Person savedPerson1 = (Person) query1.getObject(); + Person savedPerson2 = (Person) query2.getObject(); + assertThat(savedPerson1.getId()).isEqualTo("after-save"); + assertThat(savedPerson2.getId()).isEqualTo("after-save"); + } + + @Data + @AllArgsConstructor + @NoArgsConstructor + static class Person { + + @Id String id; + String firstname; + } + + static class ValueCapturingEntityCallback { + + private final List values = new ArrayList<>(1); + + protected void capture(T value) { + values.add(value); + } + + public List getValues() { + return values; + } + + @Nullable + public T getValue() { + return CollectionUtils.lastElement(values); + } + + } + + static class ValueCapturingAfterSaveCallback extends ValueCapturingEntityCallback + implements AfterSaveCallback { + + @Override + public Person onAfterSave(Person entity) { + + capture(entity); + return new Person() { + { + id = "after-save"; + firstname = entity.firstname; + } + }; + } + } +} diff --git a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTransportTemplateCallbackTests.java b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTransportTemplateCallbackTests.java new file mode 100644 index 000000000..b04c1e7cf --- /dev/null +++ b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTransportTemplateCallbackTests.java @@ -0,0 +1,288 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.elasticsearch.core; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.*; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.action.index.IndexResponse; +import org.elasticsearch.client.Client; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; +import org.springframework.data.annotation.Id; +import org.springframework.data.elasticsearch.core.event.AfterSaveCallback; +import org.springframework.data.elasticsearch.core.mapping.IndexCoordinates; +import org.springframework.data.elasticsearch.core.query.BulkOptions; +import org.springframework.data.elasticsearch.core.query.IndexQuery; +import org.springframework.data.mapping.callback.EntityCallbacks; +import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; + +/** + * @author Roman Puchkovskiy + */ +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +public class ElasticsearchTransportTemplateCallbackTests { + + private ElasticsearchTemplate template; + + @Mock + private Client client; + + @Mock + private IndexRequestBuilder indexRequestBuilder; + @Mock + private ActionFuture indexResponseActionFuture; + @Mock + private IndexResponse indexResponse; + @Mock + private BulkRequestBuilder bulkRequestBuilder; + @Mock + private ActionFuture bulkResponseActionFuture; + @Mock + private BulkResponse bulkResponse; + @Mock + private BulkItemResponse bulkItemResponse; + + @BeforeEach + public void setUp() { + template = new ElasticsearchTemplate(client); + + when(client.prepareIndex(anyString(), anyString(), anyString())).thenReturn(indexRequestBuilder); + doReturn(indexResponseActionFuture).when(indexRequestBuilder).execute(); + when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse); + doReturn("response-id").when(indexResponse).getId(); + + when(client.prepareBulk()).thenReturn(bulkRequestBuilder); + doReturn(bulkResponseActionFuture).when(bulkRequestBuilder).execute(); + when(bulkResponseActionFuture.actionGet()).thenReturn(bulkResponse); + doReturn(new BulkItemResponse[] {bulkItemResponse, bulkItemResponse}).when(bulkResponse).getItems(); + doReturn("response-id").when(bulkItemResponse).getId(); + } + + @Test // DATAES-771 + void saveOneShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(EntityCallbacks.create(afterSaveCallback)); + + Person entity = new Person("init", "luke"); + + Person saved = template.save(entity); + + verify(afterSaveCallback).onAfterSave(eq(entity)); + assertThat(saved.id).isEqualTo("after-save"); + } + + @Test // DATAES-771 + void saveWithIndexCoordinatesShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(EntityCallbacks.create(afterSaveCallback)); + + Person entity = new Person("init", "luke"); + + Person saved = template.save(entity, IndexCoordinates.of("index")); + + verify(afterSaveCallback).onAfterSave(eq(entity)); + assertThat(saved.id).isEqualTo("after-save"); + } + + @Test // DATAES-771 + void saveArrayShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(EntityCallbacks.create(afterSaveCallback)); + + Person entity1 = new Person("init1", "luke1"); + Person entity2 = new Person("init2", "luke2"); + + Iterable saved = template.save(entity1, entity2); + + verify(afterSaveCallback, times(2)).onAfterSave(any()); + Iterator savedIterator = saved.iterator(); + assertThat(savedIterator.next().getId()).isEqualTo("after-save"); + assertThat(savedIterator.next().getId()).isEqualTo("after-save"); + } + + @Test // DATAES-771 + void saveIterableShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(EntityCallbacks.create(afterSaveCallback)); + + Person entity1 = new Person("init1", "luke1"); + Person entity2 = new Person("init2", "luke2"); + + Iterable saved = template.save(Arrays.asList(entity1, entity2)); + + verify(afterSaveCallback, times(2)).onAfterSave(any()); + Iterator savedIterator = saved.iterator(); + assertThat(savedIterator.next().getId()).isEqualTo("after-save"); + assertThat(savedIterator.next().getId()).isEqualTo("after-save"); + } + + @Test // DATAES-771 + void saveIterableWithIndexCoordinatesShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(EntityCallbacks.create(afterSaveCallback)); + + Person entity1 = new Person("init1", "luke1"); + Person entity2 = new Person("init2", "luke2"); + + Iterable saved = template.save(Arrays.asList(entity1, entity2), IndexCoordinates.of("index")); + + verify(afterSaveCallback, times(2)).onAfterSave(any()); + Iterator savedIterator = saved.iterator(); + assertThat(savedIterator.next().getId()).isEqualTo("after-save"); + assertThat(savedIterator.next().getId()).isEqualTo("after-save"); + } + + @Test // DATAES-771 + void indexShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(EntityCallbacks.create(afterSaveCallback)); + + Person entity = new Person("init", "luke"); + + IndexQuery indexQuery = indexQueryForEntity(entity); + template.index(indexQuery, IndexCoordinates.of("index")); + + verify(afterSaveCallback).onAfterSave(eq(entity)); + Person savedPerson = (Person) indexQuery.getObject(); + assertThat(savedPerson.id).isEqualTo("after-save"); + } + + private IndexQuery indexQueryForEntity(Person entity) { + IndexQuery indexQuery = new IndexQuery(); + indexQuery.setObject(entity); + return indexQuery; + } + + @Test // DATAES-771 + void bulkIndexShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(EntityCallbacks.create(afterSaveCallback)); + + Person entity1 = new Person("init1", "luke1"); + Person entity2 = new Person("init2", "luke2"); + + IndexQuery query1 = indexQueryForEntity(entity1); + IndexQuery query2 = indexQueryForEntity(entity2); + template.bulkIndex(Arrays.asList(query1, query2), IndexCoordinates.of("index")); + + verify(afterSaveCallback, times(2)).onAfterSave(any()); + Person savedPerson1 = (Person) query1.getObject(); + Person savedPerson2 = (Person) query2.getObject(); + assertThat(savedPerson1.getId()).isEqualTo("after-save"); + assertThat(savedPerson2.getId()).isEqualTo("after-save"); + } + + @Test // DATAES-771 + void bulkIndexWithOptionsShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(EntityCallbacks.create(afterSaveCallback)); + + Person entity1 = new Person("init1", "luke1"); + Person entity2 = new Person("init2", "luke2"); + + IndexQuery query1 = indexQueryForEntity(entity1); + IndexQuery query2 = indexQueryForEntity(entity2); + template.bulkIndex(Arrays.asList(query1, query2), BulkOptions.defaultOptions(), IndexCoordinates.of("index")); + + verify(afterSaveCallback, times(2)).onAfterSave(any()); + Person savedPerson1 = (Person) query1.getObject(); + Person savedPerson2 = (Person) query2.getObject(); + assertThat(savedPerson1.getId()).isEqualTo("after-save"); + assertThat(savedPerson2.getId()).isEqualTo("after-save"); + } + + @Data + @AllArgsConstructor + @NoArgsConstructor + static class Person { + + @Id String id; + String firstname; + } + + static class ValueCapturingEntityCallback { + + private final List values = new ArrayList<>(1); + + protected void capture(T value) { + values.add(value); + } + + public List getValues() { + return values; + } + + @Nullable + public T getValue() { + return CollectionUtils.lastElement(values); + } + + } + + static class ValueCapturingAfterSaveCallback extends ValueCapturingEntityCallback + implements AfterSaveCallback { + + @Override + public Person onAfterSave(Person entity) { + + capture(entity); + return new Person() { + { + id = "after-save"; + firstname = entity.firstname; + } + }; + } + } +} diff --git a/src/test/java/org/springframework/data/elasticsearch/core/ReactiveElasticsearchTemplateCallbackTests.java b/src/test/java/org/springframework/data/elasticsearch/core/ReactiveElasticsearchTemplateCallbackTests.java new file mode 100644 index 000000000..8271ee709 --- /dev/null +++ b/src/test/java/org/springframework/data/elasticsearch/core/ReactiveElasticsearchTemplateCallbackTests.java @@ -0,0 +1,229 @@ +/* + * Copyright 2018-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.elasticsearch.core; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.*; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import reactor.core.publisher.Mono; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.index.IndexResponse; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; +import org.springframework.data.annotation.Id; +import org.springframework.data.elasticsearch.client.reactive.ReactiveElasticsearchClient; +import org.springframework.data.elasticsearch.core.event.ReactiveAfterSaveCallback; +import org.springframework.data.elasticsearch.core.mapping.IndexCoordinates; +import org.springframework.data.mapping.callback.ReactiveEntityCallbacks; +import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; + +/** + * @author Roman Puchkovskiy + */ +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +public class ReactiveElasticsearchTemplateCallbackTests { + + private ReactiveElasticsearchTemplate template; + + @Mock + private ReactiveElasticsearchClient client; + + @Mock + private IndexResponse indexResponse; + @Mock + private BulkResponse bulkResponse; + @Mock + private BulkItemResponse bulkItemResponse; + @Mock + private DocWriteResponse docWriteResponse; + + @BeforeEach + public void setUp() { + template = new ReactiveElasticsearchTemplate(client); + + when(client.index(any(IndexRequest.class))).thenReturn(Mono.just(indexResponse)); + doReturn("response-id").when(indexResponse).getId(); + + when(client.bulk(any(BulkRequest.class))).thenReturn(Mono.just(bulkResponse)); + doReturn(new BulkItemResponse[] {bulkItemResponse, bulkItemResponse}).when(bulkResponse).getItems(); + doReturn(docWriteResponse).when(bulkItemResponse).getResponse(); + doReturn("response-id").when(docWriteResponse).getId(); + } + + @Test // DATAES-771 + void saveOneShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(ReactiveEntityCallbacks.create(afterSaveCallback)); + + Person entity = new Person("init", "luke"); + + Person saved = template.save(entity).block(Duration.ofSeconds(1)); + + verify(afterSaveCallback).onAfterSave(eq(entity)); + assertThat(saved.id).isEqualTo("after-save"); + } + + @Test // DATAES-771 + void saveOneFromPublisherShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(ReactiveEntityCallbacks.create(afterSaveCallback)); + + Person entity = new Person("init", "luke"); + + Person saved = template.save(Mono.just(entity)).block(Duration.ofSeconds(1)); + + verify(afterSaveCallback).onAfterSave(eq(entity)); + assertThat(saved.id).isEqualTo("after-save"); + } + + @Test // DATAES-771 + void saveWithIndexCoordinatesShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(ReactiveEntityCallbacks.create(afterSaveCallback)); + + Person entity = new Person("init", "luke"); + + Person saved = template.save(entity, IndexCoordinates.of("index")).block(Duration.ofSeconds(1)); + + verify(afterSaveCallback).onAfterSave(eq(entity)); + assertThat(saved.id).isEqualTo("after-save"); + } + + @Test // DATAES-771 + void saveFromPublisherWithIndexCoordinatesShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(ReactiveEntityCallbacks.create(afterSaveCallback)); + + Person entity = new Person("init", "luke"); + + Person saved = template.save(Mono.just(entity), IndexCoordinates.of("index")).block(Duration.ofSeconds(1)); + + verify(afterSaveCallback).onAfterSave(eq(entity)); + assertThat(saved.id).isEqualTo("after-save"); + } + + @Test // DATAES-771 + void saveAllShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(ReactiveEntityCallbacks.create(afterSaveCallback)); + + Person entity1 = new Person("init1", "luke1"); + Person entity2 = new Person("init2", "luke2"); + + List saved = template.saveAll(Arrays.asList(entity1, entity2), IndexCoordinates.of("index")) + .toStream().collect(Collectors.toList()); + + verify(afterSaveCallback, times(2)).onAfterSave(any()); + assertThat(saved.get(0).getId()).isEqualTo("after-save"); + assertThat(saved.get(1).getId()).isEqualTo("after-save"); + } + + @Test // DATAES-771 + void saveFromMonoAllShouldInvokeAfterSaveCallbacks() { + + ValueCapturingAfterSaveCallback afterSaveCallback = spy(new ValueCapturingAfterSaveCallback()); + + template.setEntityCallbacks(ReactiveEntityCallbacks.create(afterSaveCallback)); + + Person entity1 = new Person("init1", "luke1"); + Person entity2 = new Person("init2", "luke2"); + + List saved = template.saveAll(Mono.just(Arrays.asList(entity1, entity2)), IndexCoordinates.of("index")) + .toStream().collect(Collectors.toList()); + + verify(afterSaveCallback, times(2)).onAfterSave(any()); + assertThat(saved.get(0).getId()).isEqualTo("after-save"); + assertThat(saved.get(1).getId()).isEqualTo("after-save"); + } + + @Data + @AllArgsConstructor + @NoArgsConstructor + static class Person { + + @Id String id; + String firstname; + } + + static class ValueCapturingEntityCallback { + + private final List values = new ArrayList<>(1); + + protected void capture(T value) { + values.add(value); + } + + public List getValues() { + return values; + } + + @Nullable + public T getValue() { + return CollectionUtils.lastElement(values); + } + + } + + static class ValueCapturingAfterSaveCallback extends ValueCapturingEntityCallback + implements ReactiveAfterSaveCallback { + + @Override + public Mono onAfterSave(Person entity) { + + return Mono.defer(() -> { + capture(entity); + Person newPerson = new Person() { + { + id = "after-save"; + firstname = entity.firstname; + } + }; + return Mono.just(newPerson); + }); + } + } +}