Skip to content

Commit 1cd392e

Browse files
committed
Use prepare(String) for Prepared Statement preparation to prevent bind values from being cached.
We now also apply query options from the initial statement to the bound statement by copying these if query options are set. Closes #1213
1 parent 1d1ae65 commit 1cd392e

File tree

5 files changed

+184
-12
lines changed

5 files changed

+184
-12
lines changed

spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/AsyncCassandraTemplate.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -1081,7 +1081,7 @@ public ListenableFuture<PreparedStatement> createPreparedStatement(CqlSession se
10811081
* @return
10821082
*/
10831083
protected CompletionStage<PreparedStatement> doPrepare(CqlSession session) {
1084-
return session.prepareAsync(statement);
1084+
return session.prepareAsync(statement.getQuery());
10851085
}
10861086

10871087
/*

spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraTemplate.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -1087,7 +1087,7 @@ public PreparedStatementHandler(Statement<?> statement) {
10871087
*/
10881088
@Override
10891089
public PreparedStatement createPreparedStatement(CqlSession session) throws DriverException {
1090-
return session.prepare(statement);
1090+
return session.prepare(statement.getQuery());
10911091
}
10921092

10931093
/*

spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/PreparedStatementDelegate.java

+180-9
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,18 @@
1616
package org.springframework.data.cassandra.core;
1717

1818
import java.util.Map;
19+
import java.util.Objects;
20+
import java.util.function.Consumer;
21+
import java.util.function.Predicate;
22+
import java.util.function.Supplier;
1923

2024
import org.apache.commons.logging.Log;
2125

2226
import org.springframework.data.cassandra.core.cql.QueryExtractorDelegate;
27+
import org.springframework.lang.Nullable;
28+
import org.springframework.util.Assert;
2329
import org.springframework.util.StringUtils;
30+
import org.springframework.util.function.SingletonSupplier;
2431

2532
import com.datastax.oss.driver.api.core.CqlIdentifier;
2633
import com.datastax.oss.driver.api.core.cql.BoundStatement;
@@ -30,6 +37,7 @@
3037
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
3138
import com.datastax.oss.driver.api.core.cql.Statement;
3239
import com.datastax.oss.driver.api.core.type.DataType;
40+
import com.datastax.oss.driver.api.core.type.codec.registry.CodecRegistry;
3341

3442
/**
3543
* Support class for Cassandra Template API implementation classes that want to make use of prepared statements.
@@ -40,30 +48,51 @@
4048
class PreparedStatementDelegate {
4149

4250
/**
43-
* Bind values held in {@link SimpleStatement} to the {@link PreparedStatement}.
51+
* Bind values held in {@link SimpleStatement} to the {@link PreparedStatement} and apply query options that are set
52+
* or do not match the default value.
4453
*
45-
* @param statement
54+
* @param source
4655
* @param ps
4756
* @return the bound statement.
4857
*/
49-
static BoundStatement bind(SimpleStatement statement, PreparedStatement ps) {
58+
static BoundStatement bind(SimpleStatement source, PreparedStatement ps) {
59+
60+
BoundStatementBuilder builder = ps.boundStatementBuilder(source.getPositionalValues().toArray());
61+
62+
Mapper mapper = Mapper.INSTANCE;
5063

51-
BoundStatementBuilder boundStatementBuilder = ps.boundStatementBuilder(statement.getPositionalValues().toArray());
52-
Map<CqlIdentifier, Object> namedValues = statement.getNamedValues();
64+
mapper.from(source.getExecutionProfileName()).whenHasText().to(builder::setExecutionProfileName);
65+
mapper.from(source.getExecutionProfile()).whenNonNull().to(builder::setExecutionProfile);
66+
mapper.from(source.getRoutingKeyspace()).whenNonNull().to(builder::setRoutingKeyspace);
67+
mapper.from(source.getRoutingKey()).whenNonNull().to(builder::setRoutingKey);
68+
mapper.from(source.getRoutingToken()).whenNonNull().to(builder::setRoutingToken);
69+
mapper.from(source.isIdempotent()).whenNonNull().to(builder::setIdempotence);
70+
mapper.from(source.isTracing()).whenNonNull().to(builder::setTracing);
71+
mapper.from(source.getQueryTimestamp()).whenNot(it -> it == Statement.NO_DEFAULT_TIMESTAMP)
72+
.to(builder::setQueryTimestamp);
73+
mapper.from(source.getPagingState()).whenNonNull().to(builder::setPagingState);
74+
mapper.from(source.getPageSize()).whenNot(it -> it == 0L).to(builder::setPageSize);
75+
mapper.from(source.getConsistencyLevel()).whenNonNull().to(builder::setConsistencyLevel);
76+
mapper.from(source.getSerialConsistencyLevel()).whenNonNull().to(builder::setSerialConsistencyLevel);
77+
mapper.from(source.getTimeout()).whenNonNull().to(builder::setTimeout);
78+
mapper.from(source.getNode()).whenNonNull().to(builder::setNode);
79+
mapper.from(source.getNowInSeconds()).whenNot(it -> it == Statement.NO_NOW_IN_SECONDS).to(builder::setNowInSeconds);
80+
81+
Map<CqlIdentifier, Object> namedValues = source.getNamedValues();
5382

5483
ColumnDefinitions variableDefinitions = ps.getVariableDefinitions();
84+
CodecRegistry codecRegistry = builder.codecRegistry();
5585
for (Map.Entry<CqlIdentifier, Object> entry : namedValues.entrySet()) {
5686

5787
if (entry.getValue() == null) {
58-
boundStatementBuilder = boundStatementBuilder.setToNull(entry.getKey());
88+
builder = builder.setToNull(entry.getKey());
5989
} else {
6090
DataType type = variableDefinitions.get(entry.getKey()).getType();
61-
boundStatementBuilder = boundStatementBuilder.set(entry.getKey(), entry.getValue(),
62-
boundStatementBuilder.codecRegistry().codecFor(type));
91+
builder = builder.set(entry.getKey(), entry.getValue(), codecRegistry.codecFor(type));
6392
}
6493
}
6594

66-
return ps.bind(statement.getPositionalValues().toArray());
95+
return builder.build();
6796
}
6897

6998
/**
@@ -117,4 +146,146 @@ private static String getMessage(Statement<?> statement) {
117146
return String.format("Cannot prepare statement %s. Statement must be a SimpleStatement.", statement);
118147
}
119148

149+
enum Mapper {
150+
151+
INSTANCE;
152+
153+
/**
154+
* Return a new {@link Source} from the specified value supplier that can be used to perform the mapping.
155+
*
156+
* @param <T> the source type
157+
* @param supplier the value supplier
158+
* @return a {@link Source} that can be used to complete the mapping
159+
* @see #from(Object)
160+
*/
161+
public <T> Source<T> from(Supplier<T> supplier) {
162+
163+
Assert.notNull(supplier, "Supplier must not be null");
164+
return getSource(supplier);
165+
}
166+
167+
/**
168+
* Return a new {@link Source} from the specified value that can be used to perform the mapping.
169+
*
170+
* @param <T> the source type
171+
* @param value the value
172+
* @return a {@link Source} that can be used to complete the mapping
173+
*/
174+
public <T> Source<T> from(@Nullable T value) {
175+
return from(() -> value);
176+
}
177+
178+
private <T> Source<T> getSource(Supplier<T> supplier) {
179+
return new Source<>(SingletonSupplier.of(supplier), t -> true);
180+
}
181+
}
182+
183+
/**
184+
* A source value/supplier that is in the process of being mapped.
185+
*
186+
* @param <T> the source type
187+
*/
188+
static class Source<T> {
189+
190+
private final Supplier<T> supplier;
191+
192+
private final Predicate<T> predicate;
193+
194+
private Source(Supplier<T> supplier, Predicate<T> predicate) {
195+
196+
Assert.notNull(predicate, "Predicate must not be null");
197+
198+
this.supplier = supplier;
199+
this.predicate = predicate;
200+
}
201+
202+
/**
203+
* Return a filtered version of the source that won't map non-null values or suppliers that throw a
204+
* {@link NullPointerException}.
205+
*
206+
* @return a new filtered source instance
207+
*/
208+
public Source<T> whenNonNull() {
209+
return new Source<>(this.supplier, Objects::nonNull);
210+
}
211+
212+
/**
213+
* Return a filtered version of the source that will only map values that are {@code true}.
214+
*
215+
* @return a new filtered source instance
216+
*/
217+
public Source<T> whenTrue() {
218+
return when(Boolean.TRUE::equals);
219+
}
220+
221+
/**
222+
* Return a filtered version of the source that will only map values that are {@code false}.
223+
*
224+
* @return a new filtered source instance
225+
*/
226+
public Source<T> whenFalse() {
227+
return when(Boolean.FALSE::equals);
228+
}
229+
230+
/**
231+
* Return a filtered version of the source that will only map values that have a {@code toString()} containing
232+
* actual text.
233+
*
234+
* @return a new filtered source instance
235+
*/
236+
public Source<T> whenHasText() {
237+
return when((value) -> StringUtils.hasText(Objects.toString(value, null)));
238+
}
239+
240+
/**
241+
* Return a filtered version of the source that will only map values equal to the specified {@code object}.
242+
*
243+
* @param object the object to match
244+
* @return a new filtered source instance
245+
*/
246+
public Source<T> whenEqualTo(Object object) {
247+
return when(object::equals);
248+
}
249+
250+
/**
251+
* Return a filtered version of the source that won't map values that match the given predicate.
252+
*
253+
* @param predicate the predicate used to filter values
254+
* @return a new filtered source instance
255+
*/
256+
public Source<T> whenNot(Predicate<T> predicate) {
257+
258+
Assert.notNull(predicate, "Predicate must not be null");
259+
return when(predicate.negate());
260+
}
261+
262+
/**
263+
* Return a filtered version of the source that won't map values that don't match the given predicate.
264+
*
265+
* @param predicate the predicate used to filter values
266+
* @return a new filtered source instance
267+
*/
268+
public Source<T> when(Predicate<T> predicate) {
269+
270+
Assert.notNull(predicate, "Predicate must not be null");
271+
return new Source<>(this.supplier, (this.predicate != null) ? this.predicate.and(predicate) : predicate);
272+
}
273+
274+
/**
275+
* Complete the mapping by passing any non-filtered value to the specified consumer.
276+
*
277+
* @param consumer the consumer that should accept the value if it's not been filtered
278+
*/
279+
public void to(Consumer<T> consumer) {
280+
281+
Assert.notNull(consumer, "Consumer must not be null");
282+
283+
T value = this.supplier.get();
284+
if (this.predicate.test(value)) {
285+
consumer.accept(value);
286+
}
287+
}
288+
289+
}
290+
120291
}

spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ReactiveCassandraTemplate.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -1060,7 +1060,7 @@ public PreparedStatementHandler(Statement<?> statement) {
10601060
*/
10611061
@Override
10621062
public Mono<PreparedStatement> createPreparedStatement(ReactiveSession session) throws DriverException {
1063-
return session.prepare(statement);
1063+
return session.prepare(statement.getQuery());
10641064
}
10651065

10661066
/*

spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/cql/QueryOptionsUtil.java

+1
Original file line numberDiff line numberDiff line change
@@ -181,4 +181,5 @@ private static int getTtlSeconds(Duration ttl) {
181181
private static boolean hasTtl(Duration ttl) {
182182
return !ttl.isZero() && !ttl.isNegative();
183183
}
184+
184185
}

0 commit comments

Comments
 (0)