Skip to content

Commit aaf3cab

Browse files
committed
Pass-thru custom Redis commands using Lettuce.
We now accept unknown custom Redis commands when using the Lettuce driver. Previously, custom commands were required to exist in Lettuce's CommandType enumeration and unknown commands (such as modules) failed to run. Closes #1979
1 parent 9179967 commit aaf3cab

File tree

2 files changed

+88
-11
lines changed

2 files changed

+88
-11
lines changed

Diff for: src/main/java/org/springframework/data/redis/connection/lettuce/LettuceConnection.java

+66-10
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@
3737
import io.lettuce.core.protocol.Command;
3838
import io.lettuce.core.protocol.CommandArgs;
3939
import io.lettuce.core.protocol.CommandType;
40+
import io.lettuce.core.protocol.ProtocolKeyword;
4041
import io.lettuce.core.pubsub.StatefulRedisPubSubConnection;
4142
import io.lettuce.core.sentinel.api.StatefulRedisSentinelConnection;
4243

4344
import java.lang.reflect.Constructor;
45+
import java.nio.charset.StandardCharsets;
4446
import java.util.ArrayList;
4547
import java.util.Collections;
4648
import java.util.HashMap;
@@ -404,7 +406,7 @@ public Object execute(String command, @Nullable CommandOutput commandOutputTypeH
404406
try {
405407

406408
String name = command.trim().toUpperCase();
407-
CommandType commandType = CommandType.valueOf(name);
409+
ProtocolKeyword commandType = getCommandType(name);
408410

409411
validateCommandIfRunningInTransactionMode(commandType, args);
410412

@@ -1045,14 +1047,14 @@ io.lettuce.core.ScanCursor getScanCursor(long cursorId) {
10451047
return io.lettuce.core.ScanCursor.of(Long.toString(cursorId));
10461048
}
10471049

1048-
private void validateCommandIfRunningInTransactionMode(CommandType cmd, byte[]... args) {
1050+
private void validateCommandIfRunningInTransactionMode(ProtocolKeyword cmd, byte[]... args) {
10491051

10501052
if (this.isQueueing()) {
10511053
validateCommand(cmd, args);
10521054
}
10531055
}
10541056

1055-
private void validateCommand(CommandType cmd, @Nullable byte[]... args) {
1057+
private void validateCommand(ProtocolKeyword cmd, @Nullable byte[]... args) {
10561058

10571059
RedisCommand redisCommand = RedisCommand.failsafeCommandLookup(cmd.name());
10581060
if (!RedisCommand.UNKNOWN.equals(redisCommand) && redisCommand.requiresArguments()) {
@@ -1105,6 +1107,15 @@ LettuceConnectionProvider getConnectionProvider() {
11051107
return connectionProvider;
11061108
}
11071109

1110+
private static ProtocolKeyword getCommandType(String name) {
1111+
1112+
try {
1113+
return CommandType.valueOf(name);
1114+
} catch (IllegalArgumentException e) {
1115+
return new CustomCommandType(name);
1116+
}
1117+
}
1118+
11081119
/**
11091120
* {@link TypeHints} provide {@link CommandOutput} information for a given {@link CommandType}.
11101121
*
@@ -1113,7 +1124,7 @@ LettuceConnectionProvider getConnectionProvider() {
11131124
static class TypeHints {
11141125

11151126
@SuppressWarnings("rawtypes") //
1116-
private static final Map<CommandType, Class<? extends CommandOutput>> COMMAND_OUTPUT_TYPE_MAPPING = new HashMap<>();
1127+
private static final Map<ProtocolKeyword, Class<? extends CommandOutput>> COMMAND_OUTPUT_TYPE_MAPPING = new HashMap<>();
11171128

11181129
@SuppressWarnings("rawtypes") //
11191130
private static final Map<Class<?>, Constructor<CommandOutput>> CONSTRUCTORS = new ConcurrentHashMap<>();
@@ -1275,7 +1286,7 @@ static class TypeHints {
12751286
* @return {@link ByteArrayOutput} as default when no matching {@link CommandOutput} available.
12761287
*/
12771288
@SuppressWarnings("rawtypes")
1278-
public CommandOutput getTypeHint(CommandType type) {
1289+
public CommandOutput getTypeHint(ProtocolKeyword type) {
12791290
return getTypeHint(type, new ByteArrayOutput<>(CODEC));
12801291
}
12811292

@@ -1286,7 +1297,7 @@ public CommandOutput getTypeHint(CommandType type) {
12861297
* @return
12871298
*/
12881299
@SuppressWarnings("rawtypes")
1289-
public CommandOutput getTypeHint(CommandType type, CommandOutput defaultType) {
1300+
public CommandOutput getTypeHint(ProtocolKeyword type, CommandOutput defaultType) {
12901301

12911302
if (type == null || !COMMAND_OUTPUT_TYPE_MAPPING.containsKey(type)) {
12921303
return defaultType;
@@ -1407,7 +1418,7 @@ static PipeliningFlushPolicy buffered(int bufferSize) {
14071418

14081419
/**
14091420
* State object associated with flushing of the currently ongoing pipeline.
1410-
*
1421+
*
14111422
* @author Mark Paluch
14121423
* @since 2.3
14131424
*/
@@ -1440,7 +1451,7 @@ public interface PipeliningFlushState {
14401451

14411452
/**
14421453
* Implementation to flush on each command.
1443-
*
1454+
*
14441455
* @author Mark Paluch
14451456
* @since 2.3
14461457
*/
@@ -1465,7 +1476,7 @@ public void onClose(StatefulConnection<?, ?> connection) {}
14651476

14661477
/**
14671478
* Implementation to flush on closing the pipeline.
1468-
*
1479+
*
14691480
* @author Mark Paluch
14701481
* @since 2.3
14711482
*/
@@ -1497,7 +1508,7 @@ public void onClose(StatefulConnection<?, ?> connection) {
14971508

14981509
/**
14991510
* Pipeline state for buffered flushing.
1500-
*
1511+
*
15011512
* @author Mark Paluch
15021513
* @since 2.3
15031514
*/
@@ -1529,4 +1540,49 @@ public void onClose(StatefulConnection<?, ?> connection) {
15291540
connection.setAutoFlushCommands(true);
15301541
}
15311542
}
1543+
1544+
/**
1545+
* @since 2.3.8
1546+
*/
1547+
static class CustomCommandType implements ProtocolKeyword {
1548+
1549+
private final String name;
1550+
1551+
CustomCommandType(String name) {
1552+
this.name = name;
1553+
}
1554+
1555+
@Override
1556+
public byte[] getBytes() {
1557+
return name.getBytes(StandardCharsets.US_ASCII);
1558+
}
1559+
1560+
@Override
1561+
public String name() {
1562+
return name;
1563+
}
1564+
1565+
@Override
1566+
public boolean equals(Object o) {
1567+
1568+
if (this == o) {
1569+
return true;
1570+
}
1571+
if (!(o instanceof CustomCommandType)) {
1572+
return false;
1573+
}
1574+
CustomCommandType that = (CustomCommandType) o;
1575+
return ObjectUtils.nullSafeEquals(name, that.name);
1576+
}
1577+
1578+
@Override
1579+
public int hashCode() {
1580+
return ObjectUtils.nullSafeHashCode(name);
1581+
}
1582+
1583+
@Override
1584+
public String toString() {
1585+
return name;
1586+
}
1587+
}
15321588
}

Diff for: src/test/java/org/springframework/data/redis/connection/lettuce/LettuceConnectionUnitTests.java

+22-1
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,18 @@
1919
import static org.mockito.Mockito.*;
2020

2121
import io.lettuce.core.RedisClient;
22+
import io.lettuce.core.RedisFuture;
2223
import io.lettuce.core.XAddArgs;
2324
import io.lettuce.core.XClaimArgs;
2425
import io.lettuce.core.api.StatefulRedisConnection;
2526
import io.lettuce.core.api.async.RedisAsyncCommands;
2627
import io.lettuce.core.api.sync.RedisCommands;
28+
import io.lettuce.core.codec.ByteArrayCodec;
2729
import io.lettuce.core.codec.RedisCodec;
30+
import io.lettuce.core.output.StatusOutput;
31+
import io.lettuce.core.protocol.AsyncCommand;
32+
import io.lettuce.core.protocol.Command;
33+
import io.lettuce.core.protocol.CommandArgs;
2834

2935
import java.lang.reflect.InvocationTargetException;
3036
import java.time.Duration;
@@ -33,6 +39,7 @@
3339
import org.junit.jupiter.api.BeforeEach;
3440
import org.junit.jupiter.api.Test;
3541
import org.mockito.ArgumentCaptor;
42+
3643
import org.springframework.dao.InvalidDataAccessResourceUsageException;
3744
import org.springframework.data.redis.connection.AbstractConnectionUnitTestBase;
3845
import org.springframework.data.redis.connection.RedisServerCommands.ShutdownOption;
@@ -198,7 +205,6 @@ void xClaimShouldNotAddJustIdFlagToArgs() {
198205
}
199206

200207
assertThat(ReflectionTestUtils.getField(args.getValue(), "justid")).isEqualTo(false);
201-
202208
}
203209

204210
@Test // DATAREDIS-1226
@@ -216,6 +222,21 @@ void xClaimJustIdShouldAddJustIdFlagToArgs() {
216222

217223
assertThat(ReflectionTestUtils.getField(args.getValue(), "justid")).isEqualTo(true);
218224
}
225+
226+
@Test // GH-1979
227+
void executeShouldPassThruCustomCommands() {
228+
229+
Command<byte[], byte[], String> command = new Command<>(new LettuceConnection.CustomCommandType("FOO.BAR"),
230+
new StatusOutput<>(ByteArrayCodec.INSTANCE));
231+
AsyncCommand<byte[], byte[], String> future = new AsyncCommand<>(command);
232+
future.complete();
233+
234+
when(asyncCommandsMock.dispatch(any(), any(), any())).thenReturn((RedisFuture) future);
235+
236+
connection.execute("foo.bar", command.getOutput());
237+
238+
verify(asyncCommandsMock).dispatch(eq(command.getType()), eq(command.getOutput()), any(CommandArgs.class));
239+
}
219240
}
220241

221242
public static class LettucePipelineConnectionUnitTests extends BasicUnitTests {

0 commit comments

Comments
 (0)