From 0aa3aba44f46f610999794e2cd142adaddd0f393 Mon Sep 17 00:00:00 2001 From: Peter-Josef Meisch Date: Thu, 6 May 2021 21:12:09 +0200 Subject: [PATCH] use scripted fields as ctor parameter, DocumentAdaptor improvements --- .../core/document/DocumentAdapters.java | 94 ++++++++++--------- .../mapping/KebabCaseFieldNamingStrategy.java | 28 ++++++ .../core/ElasticsearchTemplateTests.java | 85 ++++++++++++++++- 3 files changed, 163 insertions(+), 44 deletions(-) create mode 100644 src/main/java/org/springframework/data/elasticsearch/core/mapping/KebabCaseFieldNamingStrategy.java diff --git a/src/main/java/org/springframework/data/elasticsearch/core/document/DocumentAdapters.java b/src/main/java/org/springframework/data/elasticsearch/core/document/DocumentAdapters.java index a60be02a0..394836998 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/document/DocumentAdapters.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/document/DocumentAdapters.java @@ -23,7 +23,6 @@ import java.util.Collection; import java.util.HashMap; import java.util.LinkedHashMap; -import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Objects; @@ -63,36 +62,38 @@ * @author Matt Gilene * @since 4.0 */ -public class DocumentAdapters { +public final class DocumentAdapters { + + private DocumentAdapters() {} /** * Create a {@link Document} from {@link GetResponse}. *

- * Returns a {@link Document} using the source if available. + * Returns a {@link Document} using the getResponse if available. * - * @param source the source {@link GetResponse}. - * @return the adapted {@link Document}, null if source.isExists() returns false. + * @param getResponse the getResponse {@link GetResponse}. + * @return the adapted {@link Document}, null if getResponse.isExists() returns false. */ @Nullable - public static Document from(GetResponse source) { + public static Document from(GetResponse getResponse) { - Assert.notNull(source, "GetResponse must not be null"); + Assert.notNull(getResponse, "GetResponse must not be null"); - if (!source.isExists()) { + if (!getResponse.isExists()) { return null; } - if (source.isSourceEmpty()) { - return fromDocumentFields(source, source.getIndex(), source.getId(), source.getVersion(), source.getSeqNo(), - source.getPrimaryTerm()); + if (getResponse.isSourceEmpty()) { + return fromDocumentFields(getResponse, getResponse.getIndex(), getResponse.getId(), getResponse.getVersion(), + getResponse.getSeqNo(), getResponse.getPrimaryTerm()); } - Document document = Document.from(source.getSourceAsMap()); - document.setIndex(source.getIndex()); - document.setId(source.getId()); - document.setVersion(source.getVersion()); - document.setSeqNo(source.getSeqNo()); - document.setPrimaryTerm(source.getPrimaryTerm()); + Document document = Document.from(getResponse.getSourceAsMap()); + document.setIndex(getResponse.getIndex()); + document.setId(getResponse.getId()); + document.setVersion(getResponse.getVersion()); + document.setSeqNo(getResponse.getSeqNo()); + document.setPrimaryTerm(getResponse.getPrimaryTerm()); return document; } @@ -188,9 +189,10 @@ public static SearchDocument from(SearchHit source) { if (sourceRef == null || sourceRef.length() == 0) { return new SearchDocumentAdapter( - source.getScore(), source.getSortValues(), source.getFields(), highlightFields, fromDocumentFields(source, - source.getIndex(), source.getId(), source.getVersion(), source.getSeqNo(), source.getPrimaryTerm()), - innerHits, nestedMetaData, explanation, matchedQueries); + fromDocumentFields(source, source.getIndex(), source.getId(), source.getVersion(), source.getSeqNo(), + source.getPrimaryTerm()), + source.getScore(), source.getSortValues(), source.getFields(), highlightFields, innerHits, nestedMetaData, + explanation, matchedQueries); } Document document = Document.from(source.getSourceAsMap()); @@ -203,8 +205,8 @@ public static SearchDocument from(SearchHit source) { document.setSeqNo(source.getSeqNo()); document.setPrimaryTerm(source.getPrimaryTerm()); - return new SearchDocumentAdapter(source.getScore(), source.getSortValues(), source.getFields(), highlightFields, - document, innerHits, nestedMetaData, explanation, matchedQueries); + return new SearchDocumentAdapter(document, source.getScore(), source.getSortValues(), source.getFields(), + highlightFields, innerHits, nestedMetaData, explanation, matchedQueries); } @Nullable @@ -243,6 +245,10 @@ private static List from(@Nullable String[] matchedQueries) { * * @param documentFields the {@link DocumentField}s backing the {@link Document}. * @param index the index where the Document was found + * @param id the document id + * @param version the document version + * @param seqNo the seqNo if the document + * @param primaryTerm the primaryTerm of the document * @return the adapted {@link Document}. */ public static Document fromDocumentFields(Iterable documentFields, String index, String id, @@ -261,10 +267,13 @@ public static Document fromDocumentFields(Iterable documentFields return new DocumentFieldAdapter(fields, index, id, version, seqNo, primaryTerm); } - // TODO: Performance regarding keys/values/entry-set + /** + * Adapter for a collection of {@link DocumentField}s. + */ static class DocumentFieldAdapter implements Document { private final Collection documentFields; + private final Map documentFieldMap; private final String index; private final String id; private final long version; @@ -274,6 +283,8 @@ static class DocumentFieldAdapter implements Document { DocumentFieldAdapter(Collection documentFields, String index, String id, long version, long seqNo, long primaryTerm) { this.documentFields = documentFields; + this.documentFieldMap = new LinkedHashMap<>(documentFields.size()); + documentFields.forEach(documentField -> documentFieldMap.put(documentField.getName(), documentField)); this.index = index; this.id = id; this.version = version; @@ -353,14 +364,7 @@ public boolean isEmpty() { @Override public boolean containsKey(Object key) { - - for (DocumentField documentField : documentFields) { - if (documentField.getName().equals(key)) { - return true; - } - } - - return false; + return documentFieldMap.containsKey(key); } @Override @@ -380,11 +384,9 @@ public boolean containsValue(Object value) { @Override @Nullable public Object get(Object key) { - return documentFields.stream() // - .filter(documentField -> documentField.getName().equals(key)) // - .map(DocumentField::getValue).findFirst() // - .orElse(null); // + DocumentField documentField = documentFieldMap.get(key); + return documentField != null ? documentField.getValue() : null; } @Override @@ -409,17 +411,18 @@ public void clear() { @Override public Set keySet() { - return documentFields.stream().map(DocumentField::getName).collect(Collectors.toCollection(LinkedHashSet::new)); + return documentFieldMap.keySet(); } @Override public Collection values() { - return documentFields.stream().map(DocumentFieldAdapter::getValue).collect(Collectors.toList()); + return documentFieldMap.values().stream().map(DocumentFieldAdapter::getValue).collect(Collectors.toList()); } @Override public Set> entrySet() { - return documentFields.stream().collect(Collectors.toMap(DocumentField::getName, DocumentFieldAdapter::getValue)) + return documentFieldMap.entrySet().stream() + .collect(Collectors.toMap(Entry::getKey, entry -> DocumentFieldAdapter.getValue(entry.getValue()))) .entrySet(); } @@ -458,7 +461,6 @@ public String toJson() { } } - @Override public String toString() { return getClass().getSimpleName() + '@' + this.id + '#' + this.version + ' ' + toJson(); @@ -494,14 +496,14 @@ static class SearchDocumentAdapter implements SearchDocument { @Nullable private final Explanation explanation; @Nullable private final List matchedQueries; - SearchDocumentAdapter(float score, Object[] sortValues, Map fields, - Map> highlightFields, Document delegate, Map innerHits, + SearchDocumentAdapter(Document delegate, float score, Object[] sortValues, Map fields, + Map> highlightFields, Map innerHits, @Nullable NestedMetaData nestedMetaData, @Nullable Explanation explanation, @Nullable List matchedQueries) { + this.delegate = delegate; this.score = score; this.sortValues = sortValues; - this.delegate = delegate; fields.forEach((name, documentField) -> this.fields.put(name, documentField.getValues())); this.highlightFields.putAll(highlightFields); this.innerHits.putAll(innerHits); @@ -646,7 +648,13 @@ public boolean containsValue(Object value) { @Override public Object get(Object key) { - return delegate.get(key); + + if (delegate.containsKey(key)) { + return delegate.get(key); + } + + // fallback to fields + return fields.get(key); } @Override diff --git a/src/main/java/org/springframework/data/elasticsearch/core/mapping/KebabCaseFieldNamingStrategy.java b/src/main/java/org/springframework/data/elasticsearch/core/mapping/KebabCaseFieldNamingStrategy.java new file mode 100644 index 000000000..090972637 --- /dev/null +++ b/src/main/java/org/springframework/data/elasticsearch/core/mapping/KebabCaseFieldNamingStrategy.java @@ -0,0 +1,28 @@ +/* + * Copyright 2019-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.elasticsearch.core.mapping; + +import org.springframework.data.mapping.model.CamelCaseSplittingFieldNamingStrategy; + +/** + * @author Peter-Josef Meisch + * @since 4.3 + */ +public class KebabCaseFieldNamingStrategy extends CamelCaseSplittingFieldNamingStrategy { + public KebabCaseFieldNamingStrategy() { + super("-"); + } +} diff --git a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java index 64bab266a..fc59cee2e 100755 --- a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java @@ -19,6 +19,7 @@ import static org.assertj.core.api.Assertions.*; import static org.elasticsearch.index.query.QueryBuilders.*; import static org.springframework.data.elasticsearch.annotations.FieldType.*; +import static org.springframework.data.elasticsearch.annotations.FieldType.Integer; import static org.springframework.data.elasticsearch.core.document.Document.*; import static org.springframework.data.elasticsearch.utils.IdGenerator.*; import static org.springframework.data.elasticsearch.utils.IndexBuilder.*; @@ -3585,6 +3586,31 @@ void shouldWorkWithImmutableClasses() { assertThat(retrieved).isEqualTo(saved); } + @Test // #1488 + @DisplayName("should set scripted fields on immutable objects") + void shouldSetScriptedFieldsOnImmutableObjects() { + + ImmutableWithScriptedEntity entity = new ImmutableWithScriptedEntity("42", 42, null); + operations.save(entity); + + Map params = new HashMap<>(); + params.put("factor", 2); + NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery()) + .withSourceFilter(new FetchSourceFilter(new String[] { "*" }, new String[] {})) + .withScriptField( + new ScriptField("scriptedRate", new Script(ScriptType.INLINE, "expression", "doc['rate'] * factor", params))) + .build(); + + SearchHits searchHits = operations.search(searchQuery, + ImmutableWithScriptedEntity.class); + + assertThat(searchHits.getTotalHits()).isEqualTo(1); + ImmutableWithScriptedEntity foundEntity = searchHits.getSearchHit(0).getContent(); + assertThat(foundEntity.getId()).isEqualTo("42"); + assertThat(foundEntity.getRate()).isEqualTo(42); + assertThat(foundEntity.getScriptedRate()).isEqualTo(84.0); + } + // region entities @Document(indexName = INDEX_NAME_SAMPLE_ENTITY) @Setting(shards = 1, replicas = 0, refreshInterval = "-1") @@ -4366,7 +4392,7 @@ public void setText(@Nullable String text) { @Document(indexName = "immutable-class") private static final class ImmutableEntity { - @Id private final String id; + @Id @Nullable private final String id; @Field(type = FieldType.Text) private final String text; @Nullable private final SeqNoPrimaryTerm seqNoPrimaryTerm; @@ -4376,6 +4402,7 @@ public ImmutableEntity(@Nullable String id, String text, @Nullable SeqNoPrimaryT this.seqNoPrimaryTerm = seqNoPrimaryTerm; } + @Nullable public String getId() { return id; } @@ -4419,5 +4446,61 @@ public String toString() { + seqNoPrimaryTerm + '}'; } } + + @Document(indexName = "immutable-scripted") + public static final class ImmutableWithScriptedEntity { + @Id private final String id; + @Field(type = Integer) @Nullable private final int rate; + @Nullable @ScriptedField private final Double scriptedRate; + + public ImmutableWithScriptedEntity(String id, int rate, @Nullable java.lang.Double scriptedRate) { + this.id = id; + this.rate = rate; + this.scriptedRate = scriptedRate; + } + + public String getId() { + return id; + } + + public int getRate() { + return rate; + } + + @Nullable + public Double getScriptedRate() { + return scriptedRate; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + ImmutableWithScriptedEntity that = (ImmutableWithScriptedEntity) o; + + if (rate != that.rate) + return false; + if (!id.equals(that.id)) + return false; + return scriptedRate != null ? scriptedRate.equals(that.scriptedRate) : that.scriptedRate == null; + } + + @Override + public int hashCode() { + int result = id.hashCode(); + result = 31 * result + rate; + result = 31 * result + (scriptedRate != null ? scriptedRate.hashCode() : 0); + return result; + } + + @Override + public String toString() { + return "ImmutableWithScriptedEntity{" + "id='" + id + '\'' + ", rate=" + rate + ", scriptedRate=" + scriptedRate + + '}'; + } + } // endregion }