Skip to content

Commit d42f950

Browse files
schnapsterrstoyanchev
authored andcommitted
Pass headers to STOMP receipt callbacks
See gh-28715
1 parent 4eabe29 commit d42f950

File tree

4 files changed

+53
-18
lines changed

4 files changed

+53
-18
lines changed

spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -27,6 +27,7 @@
2727
import java.util.concurrent.ScheduledFuture;
2828
import java.util.concurrent.TimeUnit;
2929
import java.util.concurrent.atomic.AtomicInteger;
30+
import java.util.function.Consumer;
3031

3132
import org.apache.commons.logging.Log;
3233

@@ -441,7 +442,7 @@ else if (logger.isDebugEnabled()) {
441442
String receiptId = headers.getReceiptId();
442443
ReceiptHandler handler = this.receiptHandlers.get(receiptId);
443444
if (handler != null) {
444-
handler.handleReceiptReceived();
445+
handler.handleReceiptReceived(headers);
445446
}
446447
else if (logger.isDebugEnabled()) {
447448
logger.debug("No matching receipt: " + accessor.getDetailedLogMessage(message.getPayload()));
@@ -546,16 +547,19 @@ private class ReceiptHandler implements Receiptable {
546547
@Nullable
547548
private final String receiptId;
548549

549-
private final List<Runnable> receiptCallbacks = new ArrayList<>(2);
550+
private final List<Consumer<StompHeaders>> receiptCallbacks = new ArrayList<>(2);
550551

551-
private final List<Runnable> receiptLostCallbacks = new ArrayList<>(2);
552+
private final List<Consumer<StompHeaders>> receiptLostCallbacks = new ArrayList<>(2);
552553

553554
@Nullable
554555
private ScheduledFuture<?> future;
555556

556557
@Nullable
557558
private Boolean result;
558559

560+
@Nullable
561+
private StompHeaders receiptHeaders;
562+
559563
public ReceiptHandler(@Nullable String receiptId) {
560564
this.receiptId = receiptId;
561565
if (receiptId != null) {
@@ -578,15 +582,20 @@ public String getReceiptId() {
578582

579583
@Override
580584
public void addReceiptTask(Runnable task) {
585+
addTask(h -> task.run(), true);
586+
}
587+
588+
@Override
589+
public void addReceiptTask(Consumer<StompHeaders> task) {
581590
addTask(task, true);
582591
}
583592

584593
@Override
585594
public void addReceiptLostTask(Runnable task) {
586-
addTask(task, false);
595+
addTask(h -> task.run(), false);
587596
}
588597

589-
private void addTask(Runnable task, boolean successTask) {
598+
private void addTask(Consumer<StompHeaders> task, boolean successTask) {
590599
Assert.notNull(this.receiptId,
591600
"To track receipts, set autoReceiptEnabled=true or add 'receiptId' header");
592601
synchronized (this) {
@@ -604,31 +613,32 @@ private void addTask(Runnable task, boolean successTask) {
604613
}
605614
}
606615

607-
private void invoke(List<Runnable> callbacks) {
608-
for (Runnable runnable : callbacks) {
616+
private void invoke(List<Consumer<StompHeaders>> callbacks) {
617+
for (Consumer<StompHeaders> consumer : callbacks) {
609618
try {
610-
runnable.run();
619+
consumer.accept(this.receiptHeaders);
611620
}
612621
catch (Throwable ex) {
613622
// ignore
614623
}
615624
}
616625
}
617626

618-
public void handleReceiptReceived() {
619-
handleInternal(true);
627+
public void handleReceiptReceived(StompHeaders receiptHeaders) {
628+
handleInternal(true, receiptHeaders);
620629
}
621630

622631
public void handleReceiptNotReceived() {
623-
handleInternal(false);
632+
handleInternal(false, null);
624633
}
625634

626-
private void handleInternal(boolean result) {
635+
private void handleInternal(boolean result, @Nullable StompHeaders receiptHeaders) {
627636
synchronized (this) {
628637
if (this.result != null) {
629638
return;
630639
}
631640
this.result = result;
641+
this.receiptHeaders = receiptHeaders;
632642
invoke(result ? this.receiptCallbacks : this.receiptLostCallbacks);
633643
DefaultStompSession.this.receiptHandlers.remove(this.receiptId);
634644
if (this.future != null) {

spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompSession.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@
1616

1717
package org.springframework.messaging.simp.stomp;
1818

19+
import java.util.function.Consumer;
20+
1921
import org.springframework.lang.Nullable;
2022

2123
/**
@@ -143,6 +145,13 @@ interface Receiptable {
143145
*/
144146
void addReceiptTask(Runnable runnable);
145147

148+
/**
149+
* Consumer to invoke when a receipt is received. Accepts the headers of the received RECEIPT frame.
150+
* @throws java.lang.IllegalArgumentException if the receiptId is {@code null}
151+
* @since TBD
152+
*/
153+
void addReceiptTask(Consumer<StompHeaders> task);
154+
146155
/**
147156
* Task to invoke when a receipt is not received in the configured time.
148157
* @throws java.lang.IllegalArgumentException if the receiptId is {@code null}

spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/DefaultStompSessionTests.java

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -576,22 +576,30 @@ public void receiptReceived() {
576576
this.session.setTaskScheduler(mock(TaskScheduler.class));
577577

578578
AtomicReference<Boolean> received = new AtomicReference<>();
579+
AtomicReference<StompHeaders> receivedHeaders = new AtomicReference<>();
579580

580581
StompHeaders headers = new StompHeaders();
581582
headers.setDestination("/topic/foo");
582583
headers.setReceipt("my-receipt");
583584
Subscription subscription = this.session.subscribe(headers, mock(StompFrameHandler.class));
584-
subscription.addReceiptTask(() -> received.set(true));
585+
subscription.addReceiptTask(receiptHeaders -> {
586+
received.set(true);
587+
receivedHeaders.set(receiptHeaders);
588+
});
585589

586590
assertThat((Object) received.get()).isNull();
587591

588592
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.RECEIPT);
589593
accessor.setReceiptId("my-receipt");
594+
accessor.setNativeHeader("foo", "bar");
590595
accessor.setLeaveMutable(true);
591596
this.session.handleMessage(MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders()));
592597

593598
assertThat(received.get()).isNotNull();
594599
assertThat(received.get()).isTrue();
600+
assertThat(receivedHeaders.get()).isNotNull();
601+
assertThat(receivedHeaders.get().get("foo").size()).isEqualTo(1);
602+
assertThat(receivedHeaders.get().get("foo").get(0)).isEqualTo("bar");
595603
}
596604

597605
@Test
@@ -600,6 +608,7 @@ public void receiptReceivedBeforeTaskAdded() {
600608
this.session.setTaskScheduler(mock(TaskScheduler.class));
601609

602610
AtomicReference<Boolean> received = new AtomicReference<>();
611+
AtomicReference<StompHeaders> receivedHeaders = new AtomicReference<>();
603612

604613
StompHeaders headers = new StompHeaders();
605614
headers.setDestination("/topic/foo");
@@ -608,13 +617,20 @@ public void receiptReceivedBeforeTaskAdded() {
608617

609618
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.RECEIPT);
610619
accessor.setReceiptId("my-receipt");
620+
accessor.setNativeHeader("foo", "bar");
611621
accessor.setLeaveMutable(true);
612622
this.session.handleMessage(MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders()));
613623

614-
subscription.addReceiptTask(() -> received.set(true));
624+
subscription.addReceiptTask(receiptHeaders -> {
625+
received.set(true);
626+
receivedHeaders.set(receiptHeaders);
627+
});
615628

616629
assertThat(received.get()).isNotNull();
617630
assertThat(received.get()).isTrue();
631+
assertThat(receivedHeaders.get()).isNotNull();
632+
assertThat(receivedHeaders.get().get("foo").size()).isEqualTo(1);
633+
assertThat(receivedHeaders.get().get("foo").get(0)).isEqualTo("bar");
618634
}
619635

620636
@Test

src/docs/asciidoc/web/websocket.adoc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1347,7 +1347,7 @@ receipt if the server supports it (simple broker does not). For example, with th
13471347
headers.setDestination("/topic/...");
13481348
headers.setReceipt("r1");
13491349
FrameHandler handler = ...;
1350-
stompSession.subscribe(headers, handler).addReceiptTask(() -> {
1350+
stompSession.subscribe(headers, handler).addReceiptTask(receiptHeaders -> {
13511351
// Subscription ready...
13521352
});
13531353
----

0 commit comments

Comments
 (0)