diff --git a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/config/xml/MqttMessageDrivenChannelAdapterParser.java b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/config/xml/MqttMessageDrivenChannelAdapterParser.java index c0f05904be1..d8b335af06a 100644 --- a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/config/xml/MqttMessageDrivenChannelAdapterParser.java +++ b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/config/xml/MqttMessageDrivenChannelAdapterParser.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -48,7 +48,6 @@ protected AbstractBeanDefinition doParse(Element element, ParserContext parserCo builder.addPropertyReference("outputChannel", channelName); IntegrationNamespaceUtils.setReferenceIfAttributeDefined(builder, element, "error-channel"); IntegrationNamespaceUtils.setValueIfAttributeDefined(builder, element, "qos"); - IntegrationNamespaceUtils.setValueIfAttributeDefined(builder, element, "recovery-interval"); IntegrationNamespaceUtils.setValueIfAttributeDefined(builder, element, "manual-acks"); return builder.getBeanDefinition(); diff --git a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/core/AbstractMqttClientManager.java b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/core/AbstractMqttClientManager.java new file mode 100644 index 00000000000..a22efb8ad74 --- /dev/null +++ b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/core/AbstractMqttClientManager.java @@ -0,0 +1,176 @@ +/* + * Copyright 2022-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.mqtt.core; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.context.ApplicationEventPublisherAware; +import org.springframework.context.SmartLifecycle; +import org.springframework.integration.mqtt.inbound.AbstractMqttMessageDrivenChannelAdapter; +import org.springframework.util.Assert; + +/** + * Abstract class for MQTT client managers which can be a base for any common v3/v5 client manager implementation. + * Contains some basic utility and implementation-agnostic fields and methods. + * + * @param MQTT client type + * @param MQTT connection options type (v5 or v3) + * @param

MQTT client persistence type (for v5 or v3) + * + * @author Artem Vozhdayenko + * + * @since 6.0 + */ +public abstract class AbstractMqttClientManager implements ClientManager, ApplicationEventPublisherAware { + + protected final Log logger = LogFactory.getLog(this.getClass()); // NOSONAR + + private static final int DEFAULT_MANAGER_PHASE = 0; + + private final Set connectCallbacks = Collections.synchronizedSet(new HashSet<>()); + + private final String clientId; + + private int phase = DEFAULT_MANAGER_PHASE; + + private boolean manualAcks; + + private ApplicationEventPublisher applicationEventPublisher; + + private P persistence; + + private String url; + + private String beanName; + + private volatile T client; + + AbstractMqttClientManager(String clientId) { + Assert.notNull(clientId, "'clientId' is required"); + this.clientId = clientId; + } + + protected void setManualAcks(boolean manualAcks) { + this.manualAcks = manualAcks; + } + + protected String getUrl() { + return this.url; + } + + protected void setUrl(String url) { + this.url = url; + } + + protected String getClientId() { + return this.clientId; + } + + protected ApplicationEventPublisher getApplicationEventPublisher() { + return this.applicationEventPublisher; + } + + protected synchronized void setClient(T client) { + this.client = client; + } + + protected P getPersistence() { + return this.persistence; + } + + /** + * Set client persistence if some specific impl is required for topics QoS. + * @param persistence persistence implementation to use for te client + */ + public void setPersistence(P persistence) { + this.persistence = persistence; + } + + protected Set getCallbacks() { + return this.connectCallbacks; + } + + @Override + public boolean isManualAcks() { + return this.manualAcks; + } + + @Override + public T getClient() { + return this.client; + } + + @Override + public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) { + Assert.notNull(applicationEventPublisher, "'applicationEventPublisher' cannot be null"); + this.applicationEventPublisher = applicationEventPublisher; + } + + @Override + public void setBeanName(String name) { + this.beanName = name; + } + + @Override + public String getBeanName() { + return this.beanName; + } + + /** + * The phase of component autostart in {@link SmartLifecycle}. + * If the custom one is required, note that for the correct behavior it should be less than phase of + * {@link AbstractMqttMessageDrivenChannelAdapter} implementations. + * The default phase is {@link #DEFAULT_MANAGER_PHASE}. + * @return {@link SmartLifecycle} autostart phase + * @see #setPhase + */ + @Override + public int getPhase() { + return this.phase; + } + + @Override + public void addCallback(ConnectCallback connectCallback) { + this.connectCallbacks.add(connectCallback); + } + + @Override + public boolean removeCallback(ConnectCallback connectCallback) { + return this.connectCallbacks.remove(connectCallback); + } + + public synchronized boolean isRunning() { + return this.client != null; + } + + /** + * Set the phase of component autostart in {@link SmartLifecycle}. + * If the custom one is required, note that for the correct behavior it should be less than phase of + * {@link AbstractMqttMessageDrivenChannelAdapter} implementations. + * @see #getPhase + */ + public void setPhase(int phase) { + this.phase = phase; + } + +} diff --git a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/core/ClientManager.java b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/core/ClientManager.java new file mode 100644 index 00000000000..bbb909c9615 --- /dev/null +++ b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/core/ClientManager.java @@ -0,0 +1,55 @@ +/* + * Copyright 2022-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.mqtt.core; + +import org.springframework.context.SmartLifecycle; + +/** + * A utility abstraction over MQTT client which can be used in any MQTT-related component + * without need to handle generic client callbacks, reconnects etc. + * Using this manager in multiple MQTT integrations will preserve a single connection. + * + * @param MQTT client type + * @param MQTT connection options type (v5 or v3) + * + * @author Artem Vozhdayenko + * + * @since 6.0 + */ +public interface ClientManager extends SmartLifecycle, MqttComponent { + + T getClient(); + + boolean isManualAcks(); + + void addCallback(ConnectCallback connectCallback); + + boolean removeCallback(ConnectCallback connectCallback); + + /** + * A contract for a custom callback if needed by a usage. + * + * @see org.eclipse.paho.mqttv5.client.MqttCallback#connectComplete + * @see org.eclipse.paho.client.mqttv3.MqttCallbackExtended#connectComplete + */ + interface ConnectCallback { + + void connectComplete(boolean isReconnect); + + } + +} diff --git a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/core/Mqttv3ClientManager.java b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/core/Mqttv3ClientManager.java new file mode 100644 index 00000000000..c0f17d87e2c --- /dev/null +++ b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/core/Mqttv3ClientManager.java @@ -0,0 +1,167 @@ +/* + * Copyright 2022-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.mqtt.core; + +import org.eclipse.paho.client.mqttv3.IMqttAsyncClient; +import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken; +import org.eclipse.paho.client.mqttv3.MqttAsyncClient; +import org.eclipse.paho.client.mqttv3.MqttCallbackExtended; +import org.eclipse.paho.client.mqttv3.MqttClientPersistence; +import org.eclipse.paho.client.mqttv3.MqttConnectOptions; +import org.eclipse.paho.client.mqttv3.MqttException; +import org.eclipse.paho.client.mqttv3.MqttMessage; + +import org.springframework.integration.mqtt.event.MqttConnectionFailedEvent; +import org.springframework.util.Assert; + +/** + * A client manager implementation for MQTT v3 protocol. Requires a client ID and server URI. + * If needed, the connection options may be overridden and passed as a {@link MqttConnectOptions} dependency. + * By default, automatic reconnect is used. If it is required to be turned off, one should listen for + * {@link MqttConnectionFailedEvent} and reconnect the MQTT client manually. + * + * @author Artem Vozhdayenko + * @since 6.0 + */ +public class Mqttv3ClientManager + extends AbstractMqttClientManager + implements MqttCallbackExtended { + + private final MqttConnectOptions connectionOptions; + + public Mqttv3ClientManager(String url, String clientId) { + this(buildDefaultConnectionOptions(url), clientId); + } + + public Mqttv3ClientManager(MqttConnectOptions connectionOptions, String clientId) { + super(clientId); + Assert.notNull(connectionOptions, "'connectionOptions' is required"); + this.connectionOptions = connectionOptions; + String[] serverURIs = connectionOptions.getServerURIs(); + Assert.notEmpty(serverURIs, "'serverURIs' must be provided in the 'MqttConnectionOptions'"); + setUrl(serverURIs[0]); + if (!connectionOptions.isAutomaticReconnect()) { + logger.info("If this `ClientManager` is used from message-driven channel adapters, " + + "it is recommended to set 'automaticReconnect' MQTT connection option. " + + "Otherwise connection check and reconnect should be done manually."); + } + } + + private static MqttConnectOptions buildDefaultConnectionOptions(String url) { + Assert.notNull(url, "'url' is required"); + MqttConnectOptions connectOptions = new MqttConnectOptions(); + connectOptions.setServerURIs(new String[]{ url }); + connectOptions.setAutomaticReconnect(true); + return connectOptions; + } + + @Override + public synchronized void start() { + if (getClient() == null) { + try { + setClient(createClient()); + } + catch (MqttException e) { + throw new IllegalStateException("could not start client manager", e); + } + } + try { + getClient().connect(this.connectionOptions) + .waitForCompletion(this.connectionOptions.getConnectionTimeout()); + } + catch (MqttException e) { + // See GH-3822 + if (getConnectionInfo().isAutomaticReconnect()) { + try { + getClient().reconnect(); + } + catch (MqttException re) { + logger.error("MQTT client failed to connect. Never happens.", re); + } + } + else { + var applicationEventPublisher = getApplicationEventPublisher(); + if (applicationEventPublisher != null) { + applicationEventPublisher.publishEvent(new MqttConnectionFailedEvent(this, e)); + } + else { + logger.error("could not start client manager, client_id=" + getClientId(), e); + } + } + } + } + + @Override + public synchronized void stop() { + var client = getClient(); + if (client == null) { + return; + } + try { + client.disconnectForcibly(this.connectionOptions.getConnectionTimeout()); + } + catch (MqttException e) { + logger.error("could not disconnect from the client", e); + } + finally { + try { + client.close(); + } + catch (MqttException e) { + logger.error("could not close the client", e); + } + setClient(null); + } + } + + @Override + public synchronized void connectionLost(Throwable cause) { + logger.error("connection lost, client_id=" + getClientId(), cause); + } + + @Override + public void connectComplete(boolean reconnect, String serverURI) { + getCallbacks().forEach(callback -> callback.connectComplete(reconnect)); + } + + @Override + public void messageArrived(String topic, MqttMessage message) { + // not this manager concern + } + + @Override + public void deliveryComplete(IMqttDeliveryToken token) { + // nor this manager concern + } + + @Override + public MqttConnectOptions getConnectionInfo() { + return this.connectionOptions; + } + + private IMqttAsyncClient createClient() throws MqttException { + var persistence = getPersistence(); + var url = getUrl(); + var clientId = getClientId(); + var client = persistence == null ? + new MqttAsyncClient(url, clientId) : + new MqttAsyncClient(url, clientId, persistence); + client.setManualAcks(isManualAcks()); + client.setCallback(this); + return client; + } +} diff --git a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/core/Mqttv5ClientManager.java b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/core/Mqttv5ClientManager.java new file mode 100644 index 00000000000..3caed4c3ae3 --- /dev/null +++ b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/core/Mqttv5ClientManager.java @@ -0,0 +1,180 @@ +/* + * Copyright 2022-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.mqtt.core; + +import org.eclipse.paho.mqttv5.client.IMqttAsyncClient; +import org.eclipse.paho.mqttv5.client.IMqttToken; +import org.eclipse.paho.mqttv5.client.MqttAsyncClient; +import org.eclipse.paho.mqttv5.client.MqttCallback; +import org.eclipse.paho.mqttv5.client.MqttClientPersistence; +import org.eclipse.paho.mqttv5.client.MqttConnectionOptions; +import org.eclipse.paho.mqttv5.client.MqttDisconnectResponse; +import org.eclipse.paho.mqttv5.common.MqttException; +import org.eclipse.paho.mqttv5.common.MqttMessage; +import org.eclipse.paho.mqttv5.common.packet.MqttProperties; + +import org.springframework.integration.mqtt.event.MqttConnectionFailedEvent; +import org.springframework.util.Assert; + +/** + * A client manager implementation for MQTT v5 protocol. Requires a client ID and server URI. + * If needed, the connection options may be overridden and passed as a {@link MqttConnectionOptions} dependency. + * By default, automatic reconnect is used. If it is required to be turned off, one should listen for + * {@link MqttConnectionFailedEvent} and reconnect the MQTT client manually. + * + * @author Artem Vozhdayenko + * @since 6.0 + */ +public class Mqttv5ClientManager + extends AbstractMqttClientManager + implements MqttCallback { + + private final MqttConnectionOptions connectionOptions; + + public Mqttv5ClientManager(String url, String clientId) { + this(buildDefaultConnectionOptions(url), clientId); + } + + public Mqttv5ClientManager(MqttConnectionOptions connectionOptions, String clientId) { + super(clientId); + Assert.notNull(connectionOptions, "'connectionOptions' is required"); + this.connectionOptions = connectionOptions; + if (!this.connectionOptions.isAutomaticReconnect()) { + logger.info("If this `ClientManager` is used from message-driven channel adapters, " + + "it is recommended to set 'automaticReconnect' MQTT connection option. " + + "Otherwise connection check and reconnect should be done manually."); + } + Assert.notEmpty(connectionOptions.getServerURIs(), "'serverURIs' must be provided in the 'MqttConnectionOptions'"); + setUrl(connectionOptions.getServerURIs()[0]); + } + + private static MqttConnectionOptions buildDefaultConnectionOptions(String url) { + Assert.notNull(url, "'url' is required"); + var connectionOptions = new MqttConnectionOptions(); + connectionOptions.setServerURIs(new String[]{ url }); + connectionOptions.setAutomaticReconnect(true); + return connectionOptions; + } + + @Override + public synchronized void start() { + if (getClient() == null) { + try { + setClient(createClient()); + } + catch (MqttException e) { + throw new IllegalStateException("could not start client manager", e); + } + } + try { + getClient().connect(this.connectionOptions) + .waitForCompletion(this.connectionOptions.getConnectionTimeout()); + } + catch (MqttException e) { + if (getConnectionInfo().isAutomaticReconnect()) { + try { + getClient().reconnect(); + } + catch (MqttException re) { + logger.error("MQTT client failed to connect. Never happens.", re); + } + } + else { + var applicationEventPublisher = getApplicationEventPublisher(); + if (applicationEventPublisher != null) { + applicationEventPublisher.publishEvent(new MqttConnectionFailedEvent(this, e)); + } + else { + logger.error("could not start client manager, client_id=" + getClientId(), e); + } + } + } + } + + @Override + public synchronized void stop() { + var client = getClient(); + if (client == null) { + return; + } + + try { + client.disconnectForcibly(this.connectionOptions.getConnectionTimeout()); + } + catch (MqttException e) { + logger.error("could not disconnect from the client", e); + } + finally { + try { + client.close(); + } + catch (MqttException e) { + logger.error("could not close the client", e); + } + setClient(null); + } + } + + @Override + public void messageArrived(String topic, MqttMessage message) { + // not this manager concern + } + + @Override + public void deliveryComplete(IMqttToken token) { + // not this manager concern + } + + @Override + public void connectComplete(boolean reconnect, String serverURI) { + getCallbacks().forEach(callback -> callback.connectComplete(reconnect)); + } + + @Override + public void authPacketArrived(int reasonCode, MqttProperties properties) { + // not this manager concern + } + + @Override + public void disconnected(MqttDisconnectResponse disconnectResponse) { + if (logger.isInfoEnabled()) { + logger.info("MQTT disconnected: " + disconnectResponse); + } + } + + @Override + public void mqttErrorOccurred(MqttException exception) { + logger.error("MQTT error occurred", exception); + } + + @Override + public MqttConnectionOptions getConnectionInfo() { + return this.connectionOptions; + } + + private MqttAsyncClient createClient() throws MqttException { + var persistence = getPersistence(); + var url = getUrl(); + var clientId = getClientId(); + var client = persistence == null ? + new MqttAsyncClient(url, clientId) + : new MqttAsyncClient(url, clientId, persistence); + client.setManualAcks(isManualAcks()); + client.setCallback(this); + return client; + } +} diff --git a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/AbstractMqttMessageDrivenChannelAdapter.java b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/AbstractMqttMessageDrivenChannelAdapter.java index da2d3f2c7ab..bf0bd3eda04 100644 --- a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/AbstractMqttMessageDrivenChannelAdapter.java +++ b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/AbstractMqttMessageDrivenChannelAdapter.java @@ -26,6 +26,7 @@ import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.core.log.LogMessage; import org.springframework.integration.endpoint.MessageProducerSupport; +import org.springframework.integration.mqtt.core.ClientManager; import org.springframework.integration.mqtt.support.MqttMessageConverter; import org.springframework.integration.support.management.IntegrationManagedResource; import org.springframework.jmx.export.annotation.ManagedAttribute; @@ -38,30 +39,38 @@ /** * Abstract class for MQTT Message-Driven Channel Adapters. * + * @param MQTT Client type + * @param MQTT connection options type (v5 or v3) + * * @author Gary Russell * @author Artem Bilan * @author Trung Pham * @author Mikhail Polivakha + * @author Artem Vozhdayenko * * @since 4.0 * */ @ManagedResource @IntegrationManagedResource -public abstract class AbstractMqttMessageDrivenChannelAdapter extends MessageProducerSupport - implements ApplicationEventPublisherAware { +public abstract class AbstractMqttMessageDrivenChannelAdapter extends MessageProducerSupport + implements ApplicationEventPublisherAware, ClientManager.ConnectCallback { /** * The default completion timeout in milliseconds. */ public static final long DEFAULT_COMPLETION_TIMEOUT = 30_000L; + protected final Lock topicLock = new ReentrantLock(); // NOSONAR + private final String url; private final String clientId; private final Set topics; + private final ClientManager clientManager; + private long completionTimeout = DEFAULT_COMPLETION_TIMEOUT; private boolean manualAcks; @@ -70,18 +79,31 @@ public abstract class AbstractMqttMessageDrivenChannelAdapter extends MessagePro private MqttMessageConverter converter; - protected final Lock topicLock = new ReentrantLock(); // NOSONAR - public AbstractMqttMessageDrivenChannelAdapter(@Nullable String url, String clientId, String... topic) { Assert.hasText(clientId, "'clientId' cannot be null or empty"); - Assert.notNull(topic, "'topics' cannot be null"); - Assert.noNullElements(topic, "'topics' cannot have null elements"); this.url = url; this.clientId = clientId; - this.topics = new LinkedHashSet<>(); + this.topics = initTopics(topic); + this.clientManager = null; + } + + public AbstractMqttMessageDrivenChannelAdapter(ClientManager clientManager, String... topic) { + Assert.notNull(clientManager, "'clientManager' cannot be null"); + this.clientManager = clientManager; + this.topics = initTopics(topic); + this.url = null; + this.clientId = null; + } + + private static Set initTopics(String[] topic) { + Assert.notNull(topic, "'topics' cannot be null"); + Assert.noNullElements(topic, "'topics' cannot have null elements"); + final Set initialTopics = new LinkedHashSet<>(); + int defaultQos = 1; for (String t : topic) { - this.topics.add(new Topic(t, 1)); + initialTopics.add(new Topic(t, defaultQos)); } + return initialTopics; } public void setConverter(MqttMessageConverter converter) { @@ -89,6 +111,11 @@ public void setConverter(MqttMessageConverter converter) { this.converter = converter; } + @Nullable + protected ClientManager getClientManager() { + return this.clientManager; + } + /** * Set the QoS for each topic; a single value will apply to all topics otherwise * the correct number of qos values must be provided. @@ -133,6 +160,7 @@ protected String getUrl() { return this.url; } + @Nullable protected String getClientId() { return this.clientId; } @@ -157,6 +185,22 @@ public String[] getTopic() { } } + @Override + protected void onInit() { + super.onInit(); + if (this.clientManager != null) { + this.clientManager.addCallback(this); + } + } + + @Override + public void destroy() { + super.destroy(); + if (this.clientManager != null) { + this.clientManager.removeCallback(this); + } + } + @Override public String getComponentType() { return "mqtt:inbound-channel-adapter"; @@ -181,7 +225,7 @@ public void setManualAcks(boolean manualAcks) { } protected boolean isManualAcks() { - return this.manualAcks; + return this.clientManager == null ? this.manualAcks : this.clientManager.isManualAcks(); } /** diff --git a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/MqttPahoMessageDrivenChannelAdapter.java b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/MqttPahoMessageDrivenChannelAdapter.java index 19c88b6da64..b38bc7aa700 100644 --- a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/MqttPahoMessageDrivenChannelAdapter.java +++ b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/MqttPahoMessageDrivenChannelAdapter.java @@ -16,14 +16,14 @@ package org.springframework.integration.mqtt.inbound; -import java.time.Instant; import java.util.Arrays; -import java.util.concurrent.ScheduledFuture; +import java.util.stream.Stream; -import org.eclipse.paho.client.mqttv3.IMqttClient; +import org.eclipse.paho.client.mqttv3.IMqttAsyncClient; import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken; -import org.eclipse.paho.client.mqttv3.MqttCallback; -import org.eclipse.paho.client.mqttv3.MqttClient; +import org.eclipse.paho.client.mqttv3.IMqttMessageListener; +import org.eclipse.paho.client.mqttv3.IMqttToken; +import org.eclipse.paho.client.mqttv3.MqttCallbackExtended; import org.eclipse.paho.client.mqttv3.MqttConnectOptions; import org.eclipse.paho.client.mqttv3.MqttException; import org.eclipse.paho.client.mqttv3.MqttMessage; @@ -31,6 +31,7 @@ import org.springframework.context.ApplicationEventPublisher; import org.springframework.integration.IntegrationMessageHeaderAccessor; import org.springframework.integration.acks.SimpleAcknowledgment; +import org.springframework.integration.mqtt.core.ClientManager; import org.springframework.integration.mqtt.core.ConsumerStopAction; import org.springframework.integration.mqtt.core.DefaultMqttPahoClientFactory; import org.springframework.integration.mqtt.core.MqttPahoClientFactory; @@ -55,36 +56,40 @@ * * @author Gary Russell * @author Artem Bilan + * @author Artem Vozhdayenko * * @since 4.0 * */ -public class MqttPahoMessageDrivenChannelAdapter extends AbstractMqttMessageDrivenChannelAdapter - implements MqttCallback, MqttPahoComponent { +public class MqttPahoMessageDrivenChannelAdapter + extends AbstractMqttMessageDrivenChannelAdapter + implements MqttCallbackExtended, MqttPahoComponent { /** * The default disconnect completion timeout in milliseconds. */ public static final long DISCONNECT_COMPLETION_TIMEOUT = 5_000L; - private static final int DEFAULT_RECOVERY_INTERVAL = 10_000; - private final MqttPahoClientFactory clientFactory; - private int recoveryInterval = DEFAULT_RECOVERY_INTERVAL; - private long disconnectCompletionTimeout = DISCONNECT_COMPLETION_TIMEOUT; - private volatile IMqttClient client; - - private volatile ScheduledFuture reconnectFuture; - - private volatile boolean connected; + private volatile IMqttAsyncClient client; private volatile boolean cleanSession; private volatile ConsumerStopAction consumerStopAction; + /** + * Use this constructor when you don't need additional {@link MqttConnectOptions}. + * @param url The URL. + * @param clientId The client id. + * @param topic The topic(s). + */ + public MqttPahoMessageDrivenChannelAdapter(String url, String clientId, String... topic) { + this(url, clientId, new DefaultMqttPahoClientFactory(), topic); + } + /** * Use this constructor for a single url (although it may be overridden if the server * URI(s) are provided by the {@link MqttConnectOptions#getServerURIs()} provided by @@ -117,15 +122,19 @@ public MqttPahoMessageDrivenChannelAdapter(String clientId, MqttPahoClientFactor this.clientFactory = clientFactory; } - /** - * Use this constructor when you don't need additional {@link MqttConnectOptions}. - * @param url The URL. - * @param clientId The client id. + * Use this constructor when you need to use a single {@link ClientManager} + * (for instance, to reuse an MQTT connection). + * @param clientManager The client manager. * @param topic The topic(s). + * @since 6.0 */ - public MqttPahoMessageDrivenChannelAdapter(String url, String clientId, String... topic) { - this(url, clientId, new DefaultMqttPahoClientFactory(), topic); + public MqttPahoMessageDrivenChannelAdapter(ClientManager clientManager, + String... topic) { + super(clientManager, topic); + var factory = new DefaultMqttPahoClientFactory(); + factory.setConnectionOptions(clientManager.getConnectionInfo()); + this.clientFactory = factory; } /** @@ -138,16 +147,6 @@ public synchronized void setDisconnectCompletionTimeout(long completionTimeout) this.disconnectCompletionTimeout = completionTimeout; } - /** - * The time (ms) to wait between reconnection attempts. - * Default {@value #DEFAULT_RECOVERY_INTERVAL}. - * @param recoveryInterval the interval. - * @since 4.2.2 - */ - public synchronized void setRecoveryInterval(int recoveryInterval) { - this.recoveryInterval = recoveryInterval; - } - @Override public MqttConnectOptions getConnectionInfo() { MqttConnectOptions options = this.clientFactory.getConnectionOptions(); @@ -168,54 +167,69 @@ protected void onInit() { DefaultPahoMessageConverter pahoMessageConverter = new DefaultPahoMessageConverter(); pahoMessageConverter.setBeanFactory(getBeanFactory()); setConverter(pahoMessageConverter); - } } @Override protected void doStart() { - Assert.state(getTaskScheduler() != null, "A 'taskScheduler' is required"); try { - connectAndSubscribe(); + connect(); } catch (Exception ex) { - logger.error(ex, "Exception while connecting and subscribing, retrying"); - scheduleReconnect(); + if (getConnectionInfo().isAutomaticReconnect()) { + try { + this.client.reconnect(); + } + catch (MqttException re) { + logger.error(re, "MQTT client failed to connect. Never happens."); + } + } + else { + logger.error(ex, "Exception while connecting"); + var applicationEventPublisher = getApplicationEventPublisher(); + if (applicationEventPublisher != null) { + applicationEventPublisher.publishEvent(new MqttConnectionFailedEvent(this, ex)); + } + } } } @Override protected synchronized void doStop() { - cancelReconnect(); - if (this.client != null) { - try { - if (this.consumerStopAction.equals(ConsumerStopAction.UNSUBSCRIBE_ALWAYS) - || (this.consumerStopAction.equals(ConsumerStopAction.UNSUBSCRIBE_CLEAN) - && this.cleanSession)) { + try { + if (this.consumerStopAction.equals(ConsumerStopAction.UNSUBSCRIBE_ALWAYS) + || (this.consumerStopAction.equals(ConsumerStopAction.UNSUBSCRIBE_CLEAN) + && this.cleanSession)) { - this.client.unsubscribe(getTopic()); - } - } - catch (MqttException ex) { - logger.error(ex, "Exception while unsubscribing"); - } - try { - this.client.disconnectForcibly(this.disconnectCompletionTimeout); - } - catch (MqttException ex) { - logger.error(ex, "Exception while disconnecting"); + this.client.unsubscribe(getTopic()); } + } + catch (MqttException ex1) { + logger.error(ex1, "Exception while unsubscribing"); + } + + if (getClientManager() != null) { + return; + } - this.client.setCallback(null); + try { + this.client.disconnectForcibly(this.disconnectCompletionTimeout); + } + catch (MqttException ex) { + logger.error(ex, "Exception while disconnecting"); + } + } + @Override + public void destroy() { + super.destroy(); + if (getClientManager() == null) { try { this.client.close(); } - catch (MqttException ex) { - logger.error(ex, "Exception while closing"); + catch (MqttException e) { + logger.error(e, "Could not close client"); } - this.connected = false; - this.client = null; } } @@ -225,7 +239,8 @@ public void addTopic(String topic, int qos) { try { super.addTopic(topic, qos); if (this.client != null && this.client.isConnected()) { - this.client.subscribe(topic, qos); + this.client.subscribe(topic, qos, this::messageArrived) + .waitForCompletion(getCompletionTimeout()); } } catch (MqttException e) { @@ -242,7 +257,7 @@ public void removeTopic(String... topic) { this.topicLock.lock(); try { if (this.client != null && this.client.isConnected()) { - this.client.unsubscribe(topic); + this.client.unsubscribe(topic).waitForCompletion(getCompletionTimeout()); } super.removeTopic(topic); } @@ -254,31 +269,45 @@ public void removeTopic(String... topic) { } } - private synchronized void connectAndSubscribe() throws MqttException { // NOSONAR + private synchronized void connect() throws MqttException { // NOSONAR MqttConnectOptions connectionOptions = this.clientFactory.getConnectionOptions(); this.cleanSession = connectionOptions.isCleanSession(); this.consumerStopAction = this.clientFactory.getConsumerStopAction(); if (this.consumerStopAction == null) { this.consumerStopAction = ConsumerStopAction.UNSUBSCRIBE_CLEAN; } - Assert.state(getUrl() != null || connectionOptions.getServerURIs() != null, - "If no 'url' provided, connectionOptions.getServerURIs() must not be null"); - this.client = this.clientFactory.getClientInstance(getUrl(), getClientId()); - this.client.setCallback(this); - if (this.client instanceof MqttClient) { - ((MqttClient) this.client).setTimeToWait(getCompletionTimeout()); + + var clientManager = getClientManager(); + if (clientManager == null) { + Assert.state(getUrl() != null || connectionOptions.getServerURIs() != null, + "If no 'url' provided, connectionOptions.getServerURIs() must not be null"); + this.client = this.clientFactory.getAsyncClientInstance(getUrl(), getClientId()); + this.client.setCallback(this); + this.client.connect(connectionOptions).waitForCompletion(getCompletionTimeout()); + this.client.setManualAcks(isManualAcks()); + } + else { + this.client = clientManager.getClient(); } + } + private void subscribe() { this.topicLock.lock(); String[] topics = getTopic(); ApplicationEventPublisher applicationEventPublisher = getApplicationEventPublisher(); try { - this.client.connect(connectionOptions); - this.client.setManualAcks(isManualAcks()); if (topics.length > 0) { int[] requestedQos = getQos(); - int[] grantedQos = Arrays.copyOf(requestedQos, requestedQos.length); - this.client.subscribe(topics, grantedQos); + IMqttMessageListener listener = this::messageArrived; + IMqttMessageListener[] listeners = Stream.of(topics) + .map(t -> listener) + .toArray(IMqttMessageListener[]::new); + IMqttToken subscribeToken = this.client.subscribe(topics, requestedQos, listeners); + subscribeToken.waitForCompletion(getCompletionTimeout()); + int[] grantedQos = subscribeToken.getGrantedQos(); + if (grantedQos.length == 1 && grantedQos[0] == 0x80) { // NOSONAR + throw new MqttException(MqttException.REASON_CODE_SUBSCRIBE_FAILED); + } warnInvalidQosForSubscription(topics, requestedQos, grantedQos); } } @@ -287,25 +316,12 @@ private synchronized void connectAndSubscribe() throws MqttException { // NOSONA if (applicationEventPublisher != null) { applicationEventPublisher.publishEvent(new MqttConnectionFailedEvent(this, ex)); } - logger.error(ex, () -> "Error connecting or subscribing to " + Arrays.toString(topics)); - if (this.client != null) { // Could be reset during event handling before - this.client.disconnectForcibly(this.disconnectCompletionTimeout); - try { - this.client.setCallback(null); - this.client.close(); - } - catch (MqttException e1) { - // NOSONAR - } - this.client = null; - } - throw ex; + logger.error(ex, () -> "Error subscribing to " + Arrays.toString(topics)); } finally { this.topicLock.unlock(); } if (this.client.isConnected()) { - this.connected = true; String message = "Connected and subscribed to " + Arrays.toString(topics); logger.debug(message); if (applicationEventPublisher != null) { @@ -325,56 +341,10 @@ private void warnInvalidQosForSubscription(String[] topics, int[] requestedQos, } } - private synchronized void cancelReconnect() { - if (this.reconnectFuture != null) { - this.reconnectFuture.cancel(false); - this.reconnectFuture = null; - } - } - - private synchronized void scheduleReconnect() { - cancelReconnect(); - if (isActive()) { - try { - this.reconnectFuture = getTaskScheduler() - .schedule(() -> { - try { - logger.debug("Attempting reconnect"); - synchronized (MqttPahoMessageDrivenChannelAdapter.this) { - if (!MqttPahoMessageDrivenChannelAdapter.this.connected) { - connectAndSubscribe(); - MqttPahoMessageDrivenChannelAdapter.this.reconnectFuture = null; - } - } - } - catch (MqttException ex) { - logger.error(ex, "Exception while connecting and subscribing"); - scheduleReconnect(); - } - }, Instant.now().plusMillis(this.recoveryInterval)); - } - catch (Exception ex) { - logger.error(ex, "Failed to schedule reconnect"); - } - } - } - @Override public synchronized void connectionLost(Throwable cause) { if (isRunning()) { - this.logger.error(() -> "Lost connection: " + cause.getMessage() + "; retrying..."); - this.connected = false; - if (this.client != null) { - try { - this.client.setCallback(null); - this.client.close(); - } - catch (MqttException e) { - // NOSONAR - } - } - this.client = null; - scheduleReconnect(); + this.logger.error(() -> "Lost connection: " + cause.getMessage()); ApplicationEventPublisher applicationEventPublisher = getApplicationEventPublisher(); if (applicationEventPublisher != null) { applicationEventPublisher.publishEvent(new MqttConnectionFailedEvent(this, cause)); @@ -436,6 +406,18 @@ private AbstractIntegrationMessageBuilder toMessageBuilder(String topic, Mqtt public void deliveryComplete(IMqttDeliveryToken token) { } + @Override + public void connectComplete(boolean isReconnect) { + connectComplete(isReconnect, getUrl()); + } + + @Override + public void connectComplete(boolean reconnect, String serverURI) { + if (!reconnect) { + subscribe(); + } + } + /** * Used to complete message arrival when {@link #isManualAcks()} is true. * @@ -447,7 +429,7 @@ private static class AcknowledgmentImpl implements SimpleAcknowledgment { private final int qos; - private final IMqttClient ackClient; + private final IMqttAsyncClient ackClient; /** * Construct an instance with the provided properties. @@ -455,7 +437,7 @@ private static class AcknowledgmentImpl implements SimpleAcknowledgment { * @param qos the message QOS. * @param client the client. */ - AcknowledgmentImpl(int id, int qos, IMqttClient client) { + AcknowledgmentImpl(int id, int qos, IMqttAsyncClient client) { this.id = id; this.qos = qos; this.ackClient = client; diff --git a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/Mqttv5PahoMessageDrivenChannelAdapter.java b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/Mqttv5PahoMessageDrivenChannelAdapter.java index aaab5792f0d..8245a5d7e5f 100644 --- a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/Mqttv5PahoMessageDrivenChannelAdapter.java +++ b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/Mqttv5PahoMessageDrivenChannelAdapter.java @@ -18,8 +18,10 @@ import java.util.Arrays; import java.util.Map; +import java.util.stream.IntStream; import org.eclipse.paho.mqttv5.client.IMqttAsyncClient; +import org.eclipse.paho.mqttv5.client.IMqttMessageListener; import org.eclipse.paho.mqttv5.client.IMqttToken; import org.eclipse.paho.mqttv5.client.MqttAsyncClient; import org.eclipse.paho.mqttv5.client.MqttCallback; @@ -28,6 +30,7 @@ import org.eclipse.paho.mqttv5.client.MqttDisconnectResponse; import org.eclipse.paho.mqttv5.common.MqttException; import org.eclipse.paho.mqttv5.common.MqttMessage; +import org.eclipse.paho.mqttv5.common.MqttSubscription; import org.eclipse.paho.mqttv5.common.packet.MqttProperties; import org.springframework.beans.factory.BeanCreationException; @@ -36,6 +39,7 @@ import org.springframework.integration.acks.SimpleAcknowledgment; import org.springframework.integration.context.IntegrationContextUtils; import org.springframework.integration.mapping.HeaderMapper; +import org.springframework.integration.mqtt.core.ClientManager; import org.springframework.integration.mqtt.core.MqttComponent; import org.springframework.integration.mqtt.event.MqttConnectionFailedEvent; import org.springframework.integration.mqtt.event.MqttProtocolErrorEvent; @@ -67,11 +71,13 @@ * @author Artem Bilan * @author Mikhail Polivakha * @author Lucas Bowler + * @author Artem Vozhdayenko * * @since 5.5.5 * */ -public class Mqttv5PahoMessageDrivenChannelAdapter extends AbstractMqttMessageDrivenChannelAdapter +public class Mqttv5PahoMessageDrivenChannelAdapter + extends AbstractMqttMessageDrivenChannelAdapter implements MqttCallback, MqttComponent { private final MqttConnectionOptions connectionOptions; @@ -89,6 +95,7 @@ public class Mqttv5PahoMessageDrivenChannelAdapter extends AbstractMqttMessageDr public Mqttv5PahoMessageDrivenChannelAdapter(String url, String clientId, String... topic) { super(url, clientId, topic); + Assert.hasText(url, "'url' cannot be null or empty"); this.connectionOptions = new MqttConnectionOptions(); this.connectionOptions.setServerURIs(new String[]{ url }); this.connectionOptions.setAutomaticReconnect(true); @@ -106,6 +113,19 @@ public Mqttv5PahoMessageDrivenChannelAdapter(MqttConnectionOptions connectionOpt } } + /** + * Use this constructor when you need to use a single {@link ClientManager} + * (for instance, to reuse an MQTT connection). + * @param clientManager The client manager. + * @param topic The topic(s). + * @since 6.0 + */ + public Mqttv5PahoMessageDrivenChannelAdapter(ClientManager clientManager, + String... topic) { + super(clientManager, topic); + this.connectionOptions = clientManager.getConnectionInfo(); + } + @Override public MqttConnectionOptions getConnectionInfo() { return this.connectionOptions; @@ -143,7 +163,7 @@ public void setHeaderMapper(HeaderMapper headerMapper) { @Override protected void onInit() { super.onInit(); - if (this.mqttClient == null) { + if (getClientManager() == null && this.mqttClient == null) { try { this.mqttClient = new MqttAsyncClient(getUrl(), getClientId(), this.persistence); this.mqttClient.setCallback(this); @@ -162,26 +182,32 @@ protected void onInit() { @Override protected void doStart() { - ApplicationEventPublisher applicationEventPublisher = getApplicationEventPublisher(); - try { - this.mqttClient.connect(this.connectionOptions).waitForCompletion(getCompletionTimeout()); - } - catch (MqttException ex) { - if (this.connectionOptions.isAutomaticReconnect()) { - try { - this.mqttClient.reconnect(); - } - catch (MqttException e) { - logger.error(ex, "MQTT client failed to connect. Never happens."); - } + var clientManager = getClientManager(); + if (clientManager == null) { + try { + this.mqttClient.connect(this.connectionOptions).waitForCompletion(getCompletionTimeout()); } - else { - if (applicationEventPublisher != null) { - applicationEventPublisher.publishEvent(new MqttConnectionFailedEvent(this, ex)); + catch (MqttException ex) { + if (getConnectionInfo().isAutomaticReconnect()) { + try { + this.mqttClient.reconnect(); + } + catch (MqttException re) { + logger.error(re, "MQTT client failed to connect. Never happens."); + } + } + else { + ApplicationEventPublisher applicationEventPublisher = getApplicationEventPublisher(); + if (applicationEventPublisher != null) { + applicationEventPublisher.publishEvent(new MqttConnectionFailedEvent(this, ex)); + } + logger.error(ex, "MQTT client failed to connect."); } - logger.error(ex, "MQTT client failed to connect."); } } + else { + this.mqttClient = clientManager.getClient(); + } } @Override @@ -191,7 +217,10 @@ protected void doStop() { try { if (this.mqttClient != null && this.mqttClient.isConnected()) { this.mqttClient.unsubscribe(topics).waitForCompletion(getCompletionTimeout()); - this.mqttClient.disconnect().waitForCompletion(getCompletionTimeout()); + + if (getClientManager() == null) { + this.mqttClient.disconnect().waitForCompletion(getCompletionTimeout()); + } } } catch (MqttException ex) { @@ -206,7 +235,7 @@ protected void doStop() { public void destroy() { super.destroy(); try { - if (this.mqttClient != null) { + if (getClientManager() == null && this.mqttClient != null) { this.mqttClient.close(true); } } @@ -221,7 +250,8 @@ public void addTopic(String topic, int qos) { try { super.addTopic(topic, qos); if (this.mqttClient != null && this.mqttClient.isConnected()) { - this.mqttClient.subscribe(topic, qos).waitForCompletion(getCompletionTimeout()); + this.mqttClient.subscribe(new MqttSubscription(topic, qos), this::messageArrived) + .waitForCompletion(getCompletionTimeout()); } } catch (MqttException ex) { @@ -259,8 +289,9 @@ public void messageArrived(String topic, MqttMessage mqttMessage) { headers.put(MqttHeaders.RECEIVED_TOPIC, topic); if (isManualAcks()) { + var client = this.mqttClient != null ? this.mqttClient : getClientManager().getClient(); headers.put(IntegrationMessageHeaderAccessor.ACKNOWLEDGMENT_CALLBACK, - new AcknowledgmentImpl(mqttMessage.getId(), mqttMessage.getQos(), this.mqttClient)); + new AcknowledgmentImpl(mqttMessage.getId(), mqttMessage.getQos(), client)); } Object payload = @@ -307,32 +338,53 @@ public void deliveryComplete(IMqttToken token) { } + @Override + public void connectComplete(boolean isReconnect) { + connectComplete(isReconnect, getUrl()); + } + @Override public void connectComplete(boolean reconnect, String serverURI) { - if (!reconnect) { - ApplicationEventPublisher applicationEventPublisher = getApplicationEventPublisher(); - String[] topics = getTopic(); - this.topicLock.lock(); - try { - if (topics.length > 0) { - int[] requestedQos = getQos(); - this.mqttClient.subscribe(topics, requestedQos).waitForCompletion(getCompletionTimeout()); - String message = "Connected and subscribed to " + Arrays.toString(topics); - logger.debug(message); - if (applicationEventPublisher != null) { - applicationEventPublisher.publishEvent(new MqttSubscribedEvent(this, message)); - } - } + if (reconnect) { + return; + } + var clientManager = getClientManager(); + if (clientManager != null && this.mqttClient == null) { + this.mqttClient = clientManager.getClient(); + } + + String[] topics = getTopic(); + ApplicationEventPublisher applicationEventPublisher = getApplicationEventPublisher(); + this.topicLock.lock(); + try { + if (topics.length == 0) { + return; } - catch (MqttException ex) { - if (applicationEventPublisher != null) { - applicationEventPublisher.publishEvent(new MqttConnectionFailedEvent(this, ex)); - } - logger.error(ex, () -> "Error subscribing to " + Arrays.toString(topics)); + + int[] requestedQos = getQos(); + MqttSubscription[] subscriptions = IntStream.range(0, topics.length) + .mapToObj(i -> new MqttSubscription(topics[i], requestedQos[i])) + .toArray(MqttSubscription[]::new); + IMqttMessageListener listener = this::messageArrived; + IMqttMessageListener[] listeners = IntStream.range(0, topics.length) + .mapToObj(t -> listener) + .toArray(IMqttMessageListener[]::new); + this.mqttClient.subscribe(subscriptions, null, null, listeners, null) + .waitForCompletion(getCompletionTimeout()); + String message = "Connected and subscribed to " + Arrays.toString(topics); + logger.debug(message); + if (applicationEventPublisher != null) { + applicationEventPublisher.publishEvent(new MqttSubscribedEvent(this, message)); } - finally { - this.topicLock.unlock(); + } + catch (MqttException ex) { + if (applicationEventPublisher != null) { + applicationEventPublisher.publishEvent(new MqttConnectionFailedEvent(this, ex)); } + logger.error(ex, () -> "Error subscribing to " + Arrays.toString(topics)); + } + finally { + this.topicLock.unlock(); } } diff --git a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/outbound/AbstractMqttMessageHandler.java b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/outbound/AbstractMqttMessageHandler.java index a87077cd6fb..bf5b6362772 100644 --- a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/outbound/AbstractMqttMessageHandler.java +++ b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/outbound/AbstractMqttMessageHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ import org.springframework.integration.handler.AbstractMessageHandler; import org.springframework.integration.handler.ExpressionEvaluatingMessageProcessor; import org.springframework.integration.handler.MessageProcessor; +import org.springframework.integration.mqtt.core.ClientManager; import org.springframework.integration.mqtt.support.MqttHeaders; import org.springframework.integration.mqtt.support.MqttMessageConverter; import org.springframework.integration.support.management.ManageableLifecycle; @@ -36,13 +37,17 @@ /** * Abstract class for MQTT outbound channel adapters. * + * @param MQTT Client type + * @param MQTT connection options type (v5 or v3) + * * @author Gary Russell * @author Artem Bilan + * @author Artem Vozhdayenko * * @since 4.0 * */ -public abstract class AbstractMqttMessageHandler extends AbstractMessageHandler +public abstract class AbstractMqttMessageHandler extends AbstractMessageHandler implements ManageableLifecycle, ApplicationEventPublisherAware { /** @@ -64,6 +69,8 @@ public abstract class AbstractMqttMessageHandler extends AbstractMessageHandler private final String clientId; + private final ClientManager clientManager; + private long completionTimeout = DEFAULT_COMPLETION_TIMEOUT; private long disconnectCompletionTimeout = DISCONNECT_COMPLETION_TIMEOUT; @@ -90,6 +97,15 @@ public AbstractMqttMessageHandler(@Nullable String url, String clientId) { Assert.hasText(clientId, "'clientId' cannot be null or empty"); this.url = url; this.clientId = clientId; + this.clientManager = null; + } + + public AbstractMqttMessageHandler(ClientManager clientManager) { + Assert.notNull(clientManager, "'clientManager' cannot be null or empty"); + this.clientManager = clientManager; + clientManager.getConnectionInfo(); + this.url = null; + this.clientId = null; } @Override @@ -242,6 +258,7 @@ protected String getUrl() { return this.url; } + @Nullable public String getClientId() { return this.clientId; } @@ -292,6 +309,11 @@ protected long getDisconnectCompletionTimeout() { return this.disconnectCompletionTimeout; } + @Nullable + protected ClientManager getClientManager() { + return this.clientManager; + } + @Override protected void onInit() { super.onInit(); diff --git a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/outbound/MqttPahoMessageHandler.java b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/outbound/MqttPahoMessageHandler.java index aa00aca70b4..4fa3802b996 100644 --- a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/outbound/MqttPahoMessageHandler.java +++ b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/outbound/MqttPahoMessageHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import org.eclipse.paho.client.mqttv3.MqttMessage; import org.springframework.context.ApplicationEventPublisher; +import org.springframework.integration.mqtt.core.ClientManager; import org.springframework.integration.mqtt.core.DefaultMqttPahoClientFactory; import org.springframework.integration.mqtt.core.MqttPahoClientFactory; import org.springframework.integration.mqtt.core.MqttPahoComponent; @@ -48,11 +49,13 @@ * * @author Gary Russell * @author Artem Bilan + * @author Artem Vozhdayenko * * @since 4.0 * */ -public class MqttPahoMessageHandler extends AbstractMqttMessageHandler implements MqttCallback, MqttPahoComponent { +public class MqttPahoMessageHandler extends AbstractMqttMessageHandler + implements MqttCallback, MqttPahoComponent { private final MqttPahoClientFactory clientFactory; @@ -62,6 +65,15 @@ public class MqttPahoMessageHandler extends AbstractMqttMessageHandler implement private volatile IMqttAsyncClient client; + /** + * Use this constructor when you don't need additional {@link MqttConnectOptions}. + * @param url The URL. + * @param clientId The client id. + */ + public MqttPahoMessageHandler(String url, String clientId) { + this(url, clientId, new DefaultMqttPahoClientFactory()); + } + /** * Use this constructor for a single url (although it may be overridden if the server * URI(s) are provided by the {@link MqttConnectOptions#getServerURIs()} provided by @@ -88,12 +100,16 @@ public MqttPahoMessageHandler(String clientId, MqttPahoClientFactory clientFacto } /** - * Use this constructor when you don't need additional {@link MqttConnectOptions}. - * @param url The URL. - * @param clientId The client id. + * Use this constructor when you need to use a single {@link ClientManager} + * (for instance, to reuse an MQTT connection). + * @param clientManager The client manager. + * @since 6.0 */ - public MqttPahoMessageHandler(String url, String clientId) { - this(url, clientId, new DefaultMqttPahoClientFactory()); + public MqttPahoMessageHandler(ClientManager clientManager) { + super(clientManager); + var factory = new DefaultMqttPahoClientFactory(); + factory.setConnectionOptions(clientManager.getConnectionInfo()); + this.clientFactory = factory; } /** @@ -169,6 +185,11 @@ protected void doStop() { } private synchronized IMqttAsyncClient checkConnection() throws MqttException { + var theClientManager = getClientManager(); + if (theClientManager != null) { + return theClientManager.getClient(); + } + if (this.client != null && !this.client.isConnected()) { this.client.setCallback(null); this.client.close(); diff --git a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/outbound/Mqttv5PahoMessageHandler.java b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/outbound/Mqttv5PahoMessageHandler.java index 46908724b60..9c165ed1440 100644 --- a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/outbound/Mqttv5PahoMessageHandler.java +++ b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/outbound/Mqttv5PahoMessageHandler.java @@ -34,6 +34,7 @@ import org.springframework.context.ApplicationEventPublisher; import org.springframework.integration.context.IntegrationContextUtils; import org.springframework.integration.mapping.HeaderMapper; +import org.springframework.integration.mqtt.core.ClientManager; import org.springframework.integration.mqtt.core.MqttComponent; import org.springframework.integration.mqtt.event.MqttConnectionFailedEvent; import org.springframework.integration.mqtt.event.MqttMessageDeliveredEvent; @@ -52,10 +53,11 @@ * * @author Artem Bilan * @author Lucas Bowler + * @author Artem Vozhdayenko * * @since 5.5.5 */ -public class Mqttv5PahoMessageHandler extends AbstractMqttMessageHandler +public class Mqttv5PahoMessageHandler extends AbstractMqttMessageHandler implements MqttCallback, MqttComponent { private final MqttConnectionOptions connectionOptions; @@ -73,6 +75,7 @@ public class Mqttv5PahoMessageHandler extends AbstractMqttMessageHandler public Mqttv5PahoMessageHandler(String url, String clientId) { super(url, clientId); + Assert.hasText(url, "'url' cannot be null or empty"); this.connectionOptions = new MqttConnectionOptions(); this.connectionOptions.setServerURIs(new String[]{ url }); this.connectionOptions.setAutomaticReconnect(true); @@ -83,6 +86,16 @@ public Mqttv5PahoMessageHandler(MqttConnectionOptions connectionOptions, String this.connectionOptions = connectionOptions; } + /** + * Use this constructor when you need to use a single {@link ClientManager} + * (for instance, to reuse an MQTT connection). + * @param clientManager The client manager. + * @since 6.0 + */ + public Mqttv5PahoMessageHandler(ClientManager clientManager) { + super(clientManager); + this.connectionOptions = clientManager.getConnectionInfo(); + } private static String obtainServerUrlFromOptions(MqttConnectionOptions connectionOptions) { Assert.notNull(connectionOptions, "'connectionOptions' must not be null"); @@ -131,9 +144,11 @@ public void setAsyncEvents(boolean asyncEvents) { protected void onInit() { super.onInit(); try { - this.mqttClient = new MqttAsyncClient(getUrl(), getClientId(), this.persistence); - this.mqttClient.setCallback(this); - incrementClientInstance(); + if (getClientManager() == null) { + this.mqttClient = new MqttAsyncClient(getUrl(), getClientId(), this.persistence); + this.mqttClient.setCallback(this); + incrementClientInstance(); + } } catch (MqttException ex) { throw new BeanCreationException("Cannot create 'MqttAsyncClient' for: " + getComponentName(), ex); @@ -152,17 +167,25 @@ protected void onInit() { @Override protected void doStart() { try { - this.mqttClient.connect(this.connectionOptions).waitForCompletion(getCompletionTimeout()); + var clientManager = getClientManager(); + if (clientManager != null) { + this.mqttClient = clientManager.getClient(); + } + else { + this.mqttClient.connect(this.connectionOptions).waitForCompletion(getCompletionTimeout()); + } } catch (MqttException ex) { - logger.error(ex, "MQTT client failed to connect."); + logger.error(ex, "MQTT client failed to connect."); } } @Override protected void doStop() { try { - this.mqttClient.disconnect().waitForCompletion(getDisconnectCompletionTimeout()); + if (getClientManager() == null) { + this.mqttClient.disconnect().waitForCompletion(getDisconnectCompletionTimeout()); + } } catch (MqttException ex) { logger.error(ex, "Failed to disconnect 'MqttAsyncClient'"); @@ -173,7 +196,9 @@ protected void doStop() { public void destroy() { super.destroy(); try { - this.mqttClient.close(true); + if (getClientManager() == null) { + this.mqttClient.close(true); + } } catch (MqttException ex) { logger.error(ex, "Failed to close 'MqttAsyncClient'"); diff --git a/spring-integration-mqtt/src/main/resources/org/springframework/integration/mqtt/config/spring-integration-mqtt.xsd b/spring-integration-mqtt/src/main/resources/org/springframework/integration/mqtt/config/spring-integration-mqtt.xsd index 2a31b8c9f07..3d18126d39f 100644 --- a/spring-integration-mqtt/src/main/resources/org/springframework/integration/mqtt/config/spring-integration-mqtt.xsd +++ b/spring-integration-mqtt/src/main/resources/org/springframework/integration/mqtt/config/spring-integration-mqtt.xsd @@ -66,15 +66,6 @@ - - - - - - diff --git a/spring-integration-mqtt/src/test/java/org/springframework/integration/mqtt/ClientManagerBackToBackTests.java b/spring-integration-mqtt/src/test/java/org/springframework/integration/mqtt/ClientManagerBackToBackTests.java new file mode 100644 index 00000000000..a92ba0d8771 --- /dev/null +++ b/spring-integration-mqtt/src/test/java/org/springframework/integration/mqtt/ClientManagerBackToBackTests.java @@ -0,0 +1,287 @@ +/* + * Copyright 2022-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.mqtt; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.nio.charset.StandardCharsets; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.eclipse.paho.client.mqttv3.MqttConnectOptions; +import org.eclipse.paho.client.mqttv3.MqttException; +import org.junit.jupiter.api.Test; + +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.event.EventListener; +import org.springframework.integration.config.EnableIntegration; +import org.springframework.integration.dsl.IntegrationFlow; +import org.springframework.integration.mqtt.core.Mqttv3ClientManager; +import org.springframework.integration.mqtt.core.Mqttv5ClientManager; +import org.springframework.integration.mqtt.event.MqttSubscribedEvent; +import org.springframework.integration.mqtt.inbound.MqttPahoMessageDrivenChannelAdapter; +import org.springframework.integration.mqtt.inbound.Mqttv5PahoMessageDrivenChannelAdapter; +import org.springframework.integration.mqtt.outbound.MqttPahoMessageHandler; +import org.springframework.integration.mqtt.outbound.Mqttv5PahoMessageHandler; +import org.springframework.integration.mqtt.support.MqttHeaders; +import org.springframework.integration.support.MessageBuilder; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.PollableChannel; + +/** + * @author Artem Vozhdayenko + * @since 6.0 + */ +class ClientManagerBackToBackTests implements MosquittoContainerTest { + + @Test + void testSameV3ClientIdWorksForPubAndSub() throws Exception { + testSubscribeAndPublish(Mqttv3Config.class, Mqttv3Config.TOPIC_NAME, Mqttv3Config.subscribedLatch); + } + + @Test + void testSameV5ClientIdWorksForPubAndSub() throws Exception { + testSubscribeAndPublish(Mqttv5Config.class, Mqttv5Config.TOPIC_NAME, Mqttv5Config.subscribedLatch); + } + + @Test + void testV3ClientManagerReconnect() throws Exception { + testSubscribeAndPublish(Mqttv3ConfigWithDisconnect.class, Mqttv3ConfigWithDisconnect.TOPIC_NAME, + Mqttv3ConfigWithDisconnect.subscribedLatch); + } + + @Test + void testV5ClientManagerReconnect() throws Exception { + testSubscribeAndPublish(Mqttv5ConfigWithDisconnect.class, Mqttv5ConfigWithDisconnect.TOPIC_NAME, + Mqttv5ConfigWithDisconnect.subscribedLatch); + } + + private void testSubscribeAndPublish(Class configClass, String topicName, CountDownLatch subscribedLatch) + throws Exception { + + try (var ctx = new AnnotationConfigApplicationContext(configClass)) { + // given + var input = ctx.getBean("mqttOutFlow.input", MessageChannel.class); + var output = ctx.getBean("fromMqttChannel", PollableChannel.class); + String testPayload = "foo"; + assertThat(subscribedLatch.await(20, TimeUnit.SECONDS)).isTrue(); + + // when + input.send(MessageBuilder.withPayload(testPayload).setHeader(MqttHeaders.TOPIC, topicName).build()); + Message receive = output.receive(20_000); + + // then + assertThat(receive).isNotNull(); + Object payload = receive.getPayload(); + if (payload instanceof String sp) { + assertThat(sp).isEqualTo(testPayload); + } + else { + assertThat(payload).isEqualTo(testPayload.getBytes(StandardCharsets.UTF_8)); + } + } + } + + @Configuration + @EnableIntegration + public static class Mqttv3Config { + + static final String TOPIC_NAME = "test-topic-v3"; + + static final CountDownLatch subscribedLatch = new CountDownLatch(1); + + @EventListener + public void onSubscribed(MqttSubscribedEvent e) { + subscribedLatch.countDown(); + } + + @Bean + public Mqttv3ClientManager mqttv3ClientManager() { + MqttConnectOptions connectionOptions = new MqttConnectOptions(); + connectionOptions.setServerURIs(new String[]{ MosquittoContainerTest.mqttUrl() }); + connectionOptions.setAutomaticReconnect(true); + return new Mqttv3ClientManager(connectionOptions, "client-manager-client-id-v3"); + } + + @Bean + public IntegrationFlow mqttOutFlow(Mqttv3ClientManager mqttv3ClientManager) { + return f -> f.handle(new MqttPahoMessageHandler(mqttv3ClientManager)); + } + + @Bean + public IntegrationFlow mqttInFlow(Mqttv3ClientManager mqttv3ClientManager) { + return IntegrationFlow.from(new MqttPahoMessageDrivenChannelAdapter(mqttv3ClientManager, TOPIC_NAME)) + .channel(c -> c.queue("fromMqttChannel")) + .get(); + } + + } + + @Configuration + @EnableIntegration + public static class Mqttv3ConfigWithDisconnect { + + static final String TOPIC_NAME = "test-topic-v3-reconnect"; + + static final CountDownLatch subscribedLatch = new CountDownLatch(1); + + @EventListener + public void onSubscribed(MqttSubscribedEvent e) { + subscribedLatch.countDown(); + } + + @Bean + public ClientV3Disconnector disconnector(Mqttv3ClientManager clientManager) { + return new ClientV3Disconnector(clientManager); + } + + @Bean + public Mqttv3ClientManager mqttv3ClientManager() { + MqttConnectOptions connectionOptions = new MqttConnectOptions(); + connectionOptions.setServerURIs(new String[]{ MosquittoContainerTest.mqttUrl() }); + connectionOptions.setAutomaticReconnect(true); + return new Mqttv3ClientManager(connectionOptions, "client-manager-client-id-v3-reconnect"); + } + + @Bean + public IntegrationFlow mqttOutFlow() { + return f -> f.handle(new MqttPahoMessageHandler(MosquittoContainerTest.mqttUrl(), "old-client-v3")); + } + + @Bean + public IntegrationFlow mqttInFlow(Mqttv3ClientManager mqttv3ClientManager) { + return IntegrationFlow.from(new MqttPahoMessageDrivenChannelAdapter(mqttv3ClientManager, TOPIC_NAME)) + .channel(c -> c.queue("fromMqttChannel")) + .get(); + } + + } + + @Configuration + @EnableIntegration + public static class Mqttv5Config { + + static final String TOPIC_NAME = "test-topic-v5"; + + static final CountDownLatch subscribedLatch = new CountDownLatch(1); + + @EventListener + public void onSubscribed(MqttSubscribedEvent e) { + subscribedLatch.countDown(); + } + + @Bean + public Mqttv5ClientManager mqttv5ClientManager() { + return new Mqttv5ClientManager(MosquittoContainerTest.mqttUrl(), "client-manager-client-id-v5"); + } + + @Bean + public IntegrationFlow mqttOutFlow(Mqttv5ClientManager mqttv5ClientManager) { + return f -> f.handle(new Mqttv5PahoMessageHandler(mqttv5ClientManager)); + } + + @Bean + public IntegrationFlow mqttInFlow(Mqttv5ClientManager mqttv5ClientManager) { + return IntegrationFlow.from(new Mqttv5PahoMessageDrivenChannelAdapter(mqttv5ClientManager, TOPIC_NAME)) + .channel(c -> c.queue("fromMqttChannel")) + .get(); + } + + } + + @Configuration + @EnableIntegration + public static class Mqttv5ConfigWithDisconnect { + + static final String TOPIC_NAME = "test-topic-v5-reconnect"; + + static final CountDownLatch subscribedLatch = new CountDownLatch(1); + + @EventListener + public void onSubscribed(MqttSubscribedEvent e) { + subscribedLatch.countDown(); + } + + @Bean + public ClientV5Disconnector clientV3Disconnector(Mqttv5ClientManager clientManager) { + return new ClientV5Disconnector(clientManager); + } + + @Bean + public Mqttv5ClientManager mqttv5ClientManager() { + return new Mqttv5ClientManager(MosquittoContainerTest.mqttUrl(), "client-manager-client-id-v5-reconnect"); + } + + @Bean + public IntegrationFlow mqttOutFlow(Mqttv5ClientManager mqttv5ClientManager) { + return f -> f.handle(new Mqttv5PahoMessageHandler(MosquittoContainerTest.mqttUrl(), "old-client-v5")); + } + + @Bean + public IntegrationFlow mqttInFlow(Mqttv5ClientManager mqttv5ClientManager) { + return IntegrationFlow.from(new Mqttv5PahoMessageDrivenChannelAdapter(mqttv5ClientManager, TOPIC_NAME)) + .channel(c -> c.queue("fromMqttChannel")) + .get(); + } + + } + + + public static class ClientV3Disconnector { + + private final Mqttv3ClientManager clientManager; + + ClientV3Disconnector(Mqttv3ClientManager clientManager) { + this.clientManager = clientManager; + } + + @EventListener + public void handleSubscribedEvent(MqttSubscribedEvent e) { + try { + this.clientManager.getClient().disconnectForcibly(); + } + catch (MqttException ex) { + throw new IllegalStateException("could not disconnect the client!"); + } + } + + } + + public static class ClientV5Disconnector { + + private final Mqttv5ClientManager clientManager; + + ClientV5Disconnector(Mqttv5ClientManager clientManager) { + this.clientManager = clientManager; + } + + @EventListener + public void handleSubscribedEvent(MqttSubscribedEvent e) { + try { + this.clientManager.getClient().disconnectForcibly(); + } + catch (org.eclipse.paho.mqttv5.common.MqttException ex) { + throw new IllegalStateException("could not disconnect the client!"); + } + } + + } + +} diff --git a/spring-integration-mqtt/src/test/java/org/springframework/integration/mqtt/MqttAdapterTests.java b/spring-integration-mqtt/src/test/java/org/springframework/integration/mqtt/MqttAdapterTests.java index a8c742017a7..138c81ddd0e 100644 --- a/spring-integration-mqtt/src/test/java/org/springframework/integration/mqtt/MqttAdapterTests.java +++ b/spring-integration-mqtt/src/test/java/org/springframework/integration/mqtt/MqttAdapterTests.java @@ -17,13 +17,11 @@ package org.springframework.integration.mqtt; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.willAnswer; import static org.mockito.BDDMockito.willReturn; @@ -34,30 +32,24 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; -import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; -import java.time.Instant; import java.util.Properties; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; import javax.net.SocketFactory; import org.aopalliance.intercept.MethodInterceptor; -import org.assertj.core.api.Condition; import org.eclipse.paho.client.mqttv3.IMqttAsyncClient; -import org.eclipse.paho.client.mqttv3.IMqttClient; import org.eclipse.paho.client.mqttv3.IMqttToken; import org.eclipse.paho.client.mqttv3.MqttAsyncClient; -import org.eclipse.paho.client.mqttv3.MqttCallback; -import org.eclipse.paho.client.mqttv3.MqttClient; +import org.eclipse.paho.client.mqttv3.MqttCallbackExtended; import org.eclipse.paho.client.mqttv3.MqttConnectOptions; import org.eclipse.paho.client.mqttv3.MqttDeliveryToken; import org.eclipse.paho.client.mqttv3.MqttException; @@ -65,9 +57,7 @@ import org.eclipse.paho.client.mqttv3.MqttToken; import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence; import org.junit.jupiter.api.Test; -import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; -import org.mockito.internal.stubbing.answers.CallsRealMethods; import org.springframework.aop.framework.ProxyFactoryBean; import org.springframework.beans.DirectFieldAccessor; @@ -83,6 +73,7 @@ import org.springframework.integration.handler.MessageProcessor; import org.springframework.integration.mqtt.core.ConsumerStopAction; import org.springframework.integration.mqtt.core.DefaultMqttPahoClientFactory; +import org.springframework.integration.mqtt.core.Mqttv3ClientManager; import org.springframework.integration.mqtt.event.MqttConnectionFailedEvent; import org.springframework.integration.mqtt.event.MqttIntegrationEvent; import org.springframework.integration.mqtt.event.MqttSubscribedEvent; @@ -98,13 +89,12 @@ import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.support.ErrorMessage; import org.springframework.messaging.support.GenericMessage; -import org.springframework.scheduling.TaskScheduler; -import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.util.ReflectionUtils; /** * @author Gary Russell * @author Artem Bilan + * @author Artem Vozhdayenko * * @since 4.0 * @@ -120,21 +110,11 @@ public class MqttAdapterTests { this.alwaysComplete = (IMqttToken) pfb.getObject(); } - @Test - public void testCloseOnBadConnectIn() throws Exception { - final IMqttClient client = mock(IMqttClient.class); - willThrow(new MqttException(0)).given(client).connect(any()); - MqttPahoMessageDrivenChannelAdapter adapter = buildAdapterIn(client, null, ConsumerStopAction.UNSUBSCRIBE_NEVER); - adapter.start(); - verify(client).close(); - adapter.stop(); - } - @Test public void testCloseOnBadConnectOut() throws Exception { final IMqttAsyncClient client = mock(IMqttAsyncClient.class); - willThrow(new MqttException(0)).given(client).connect(any()); MqttPahoMessageHandler adapter = buildAdapterOut(client); + willThrow(new MqttException(0)).given(client).connect(any()); adapter.start(); try { adapter.handleMessage(new GenericMessage<>("foo")); @@ -191,7 +171,7 @@ public void testOutboundOptionsApplied() throws Exception { connectCalled.set(true); return token; }).given(client).connect(any(MqttConnectOptions.class)); - willReturn(token).given(client).subscribe(any(String[].class), any(int[].class)); + willReturn(token).given(client).subscribe(any(String[].class), any(int[].class), any()); final MqttDeliveryToken deliveryToken = mock(MqttDeliveryToken.class); final AtomicBoolean publishCalled = new AtomicBoolean(); @@ -214,6 +194,66 @@ public void testOutboundOptionsApplied() throws Exception { handler.stop(); } + @Test + void testClientManagerIsNotConnectedAndClosedInHandler() throws Exception { + // given + var clientManager = mock(Mqttv3ClientManager.class); + when(clientManager.getConnectionInfo()).thenReturn(new MqttConnectOptions()); + var client = mock(MqttAsyncClient.class); + given(clientManager.getClient()).willReturn(client); + + var deliveryToken = mock(MqttDeliveryToken.class); + given(client.publish(anyString(), any(MqttMessage.class))).willReturn(deliveryToken); + + var handler = new MqttPahoMessageHandler(clientManager); + handler.setDefaultTopic("mqtt-foo"); + handler.setBeanFactory(mock(BeanFactory.class)); + handler.afterPropertiesSet(); + handler.start(); + + // when + handler.handleMessage(new GenericMessage<>("Hello, world!")); + handler.stop(); + + // then + verify(client, never()).connect(any(MqttConnectOptions.class)); + verify(client).publish(anyString(), any(MqttMessage.class)); + verify(client, never()).disconnect(); + verify(client, never()).disconnect(anyLong()); + verify(client, never()).close(); + } + + @Test + void testClientManagerIsNotConnectedAndClosedInAdapter() throws Exception { + // given + var clientManager = mock(Mqttv3ClientManager.class); + when(clientManager.getConnectionInfo()).thenReturn(new MqttConnectOptions()); + var client = mock(MqttAsyncClient.class); + given(clientManager.getClient()).willReturn(client); + + var subscribeToken = mock(MqttToken.class); + given(subscribeToken.getGrantedQos()).willReturn(new int[]{ 2 }); + given(client.subscribe(any(String[].class), any(int[].class), any())) + .willReturn(subscribeToken); + + var adapter = new MqttPahoMessageDrivenChannelAdapter(clientManager, "mqtt-foo"); + adapter.setBeanFactory(mock(BeanFactory.class)); + adapter.afterPropertiesSet(); + + // when + adapter.start(); + adapter.connectComplete(false, null); + adapter.stop(); + + // then + verify(client, never()).connect(any(MqttConnectOptions.class)); + verify(client).subscribe(eq(new String[]{ "mqtt-foo" }), any(int[].class), any()); + verify(client).unsubscribe(new String[]{ "mqtt-foo" }); + verify(client, never()).disconnect(); + verify(client, never()).disconnect(anyLong()); + verify(client, never()).close(); + } + @Test public void testInboundOptionsApplied() throws Exception { DefaultMqttPahoClientFactory factory = new DefaultMqttPahoClientFactory(); @@ -233,21 +273,12 @@ public void testInboundOptionsApplied() throws Exception { factory.setConnectionOptions(connectOptions); factory = spy(factory); - final IMqttClient client = mock(IMqttClient.class); - willAnswer(invocation -> client).given(factory).getClientInstance(anyString(), anyString()); + final IMqttAsyncClient client = mock(IMqttAsyncClient.class); + willReturn(client).given(factory).getAsyncClientInstance(anyString(), anyString()); final AtomicBoolean connectCalled = new AtomicBoolean(); - final AtomicBoolean failConnection = new AtomicBoolean(); - final CountDownLatch waitToFail = new CountDownLatch(1); - final CountDownLatch failInProcess = new CountDownLatch(1); - final CountDownLatch goodConnection = new CountDownLatch(2); - final MqttException reconnectException = new MqttException(MqttException.REASON_CODE_SERVER_CONNECT_ERROR); + IMqttToken token = mock(IMqttToken.class); willAnswer(invocation -> { - if (failConnection.get()) { - failInProcess.countDown(); - waitToFail.await(10, TimeUnit.SECONDS); - throw reconnectException; - } MqttConnectOptions options = invocation.getArgument(0); assertThat(options.getConnectionTimeout()).isEqualTo(23); assertThat(options.getKeepAliveInterval()).isEqualTo(45); @@ -259,15 +290,16 @@ public void testInboundOptionsApplied() throws Exception { assertThat(new String(options.getWillMessage().getPayload())).isEqualTo("bar"); assertThat(options.getWillMessage().getQos()).isEqualTo(2); connectCalled.set(true); - goodConnection.countDown(); - return null; + return token; }).given(client).connect(any(MqttConnectOptions.class)); + given(client.subscribe(any(String[].class), any(int[].class), any())).willReturn(token); + given(token.getGrantedQos()).willReturn(new int[]{ 2 }); - final AtomicReference callback = new AtomicReference<>(); + final AtomicReference callback = new AtomicReference<>(); willAnswer(invocation -> { callback.set(invocation.getArgument(0)); return null; - }).given(client).setCallback(any(MqttCallback.class)); + }).given(client).setCallback(any(MqttCallbackExtended.class)); given(client.isConnected()).willReturn(true); @@ -278,9 +310,6 @@ public void testInboundOptionsApplied() throws Exception { adapter.setOutputChannel(outputChannel); QueueChannel errorChannel = new QueueChannel(); adapter.setErrorChannel(errorChannel); - ThreadPoolTaskScheduler taskScheduler = new ThreadPoolTaskScheduler(); - taskScheduler.initialize(); - adapter.setTaskScheduler(taskScheduler); adapter.setBeanFactory(mock(BeanFactory.class)); ApplicationEventPublisher applicationEventPublisher = mock(ApplicationEventPublisher.class); final BlockingQueue events = new LinkedBlockingQueue<>(); @@ -289,9 +318,9 @@ public void testInboundOptionsApplied() throws Exception { return null; }).given(applicationEventPublisher).publishEvent(any(MqttIntegrationEvent.class)); adapter.setApplicationEventPublisher(applicationEventPublisher); - adapter.setRecoveryInterval(500); adapter.afterPropertiesSet(); adapter.start(); + adapter.connectComplete(false, null); verify(client, times(1)).connect(any(MqttConnectOptions.class)); assertThat(connectCalled.get()).isTrue(); @@ -339,78 +368,49 @@ public Message toMessage(Object payload, MessageHeaders headers) { IllegalStateException exception = (IllegalStateException) errorMessage.getPayload(); assertThat(exception).hasMessage("'MqttMessageConverter' returned 'null'"); assertThat(errorMessage.getOriginalMessage().getPayload()).isSameAs(message); - - // lose connection and make first reconnect fail - failConnection.set(true); - RuntimeException e = new RuntimeException("foo"); - adapter.connectionLost(e); - - event = events.poll(10, TimeUnit.SECONDS); - assertThat(event).isInstanceOf(MqttConnectionFailedEvent.class); - assertThat(e).isSameAs(event.getCause()); - - assertThat(failInProcess.await(10, TimeUnit.SECONDS)).isTrue(); - waitToFail.countDown(); - failConnection.set(false); - event = events.poll(10, TimeUnit.SECONDS); - assertThat(event).isInstanceOf(MqttConnectionFailedEvent.class); - assertThat(reconnectException).isSameAs(event.getCause()); - - // reconnect can now succeed; however, we might have other failures on a slow server (500ms retry). - assertThat(goodConnection.await(10, TimeUnit.SECONDS)).isTrue(); - int n = 0; - while (!(event instanceof MqttSubscribedEvent) && n++ < 20) { - event = events.poll(10, TimeUnit.SECONDS); - } - assertThat(event).isInstanceOf(MqttSubscribedEvent.class); - assertThat(((MqttSubscribedEvent) event).getMessage()).isEqualTo("Connected and subscribed to [baz, fix]"); - taskScheduler.destroy(); } @Test public void testStopActionDefault() throws Exception { - final IMqttClient client = mock(IMqttClient.class); + final IMqttAsyncClient client = mock(IMqttAsyncClient.class); MqttPahoMessageDrivenChannelAdapter adapter = buildAdapterIn(client, null, null); adapter.start(); + adapter.connectComplete(false, null); adapter.stop(); verifyUnsubscribe(client); } @Test public void testStopActionDefaultNotClean() throws Exception { - final IMqttClient client = mock(IMqttClient.class); + final IMqttAsyncClient client = mock(IMqttAsyncClient.class); MqttPahoMessageDrivenChannelAdapter adapter = buildAdapterIn(client, false, null); adapter.start(); + adapter.connectComplete(false, null); adapter.stop(); verifyNotUnsubscribe(client); } @Test public void testStopActionAlways() throws Exception { - final IMqttClient client = mock(IMqttClient.class); + final IMqttAsyncClient client = mock(IMqttAsyncClient.class); MqttPahoMessageDrivenChannelAdapter adapter = buildAdapterIn(client, false, ConsumerStopAction.UNSUBSCRIBE_ALWAYS); adapter.start(); + adapter.connectComplete(false, null); adapter.stop(); verifyUnsubscribe(client); - - adapter.connectionLost(new RuntimeException("Intentional")); - - TaskScheduler taskScheduler = TestUtils.getPropertyValue(adapter, "taskScheduler", TaskScheduler.class); - - verify(taskScheduler, never()) - .schedule(any(Runnable.class), any(Instant.class)); } @Test public void testStopActionNever() throws Exception { - final IMqttClient client = mock(IMqttClient.class); + final IMqttAsyncClient client = mock(IMqttAsyncClient.class); MqttPahoMessageDrivenChannelAdapter adapter = buildAdapterIn(client, null, ConsumerStopAction.UNSUBSCRIBE_NEVER); adapter.start(); + adapter.connectComplete(false, null); adapter.stop(); verifyNotUnsubscribe(client); } @@ -436,39 +436,6 @@ public void testCustomExpressions() { ctx.close(); } - @Test - public void testReconnect() throws Exception { - final IMqttClient client = mock(IMqttClient.class); - MqttPahoMessageDrivenChannelAdapter adapter = buildAdapterIn(client, null, ConsumerStopAction.UNSUBSCRIBE_NEVER); - adapter.setRecoveryInterval(10); - LogAccessor logger = spy(TestUtils.getPropertyValue(adapter, "logger", LogAccessor.class)); - new DirectFieldAccessor(adapter).setPropertyValue("logger", logger); - given(logger.isDebugEnabled()).willReturn(true); - final AtomicInteger attemptingReconnectCount = new AtomicInteger(); - willAnswer(i -> { - if (attemptingReconnectCount.getAndIncrement() == 0) { - adapter.connectionLost(new RuntimeException("while schedule running")); - } - i.callRealMethod(); - return null; - }).given(logger).debug("Attempting reconnect"); - ThreadPoolTaskScheduler taskScheduler = new ThreadPoolTaskScheduler(); - taskScheduler.initialize(); - adapter.setTaskScheduler(taskScheduler); - adapter.start(); - adapter.connectionLost(new RuntimeException("initial")); - verify(client).close(); - Thread.sleep(1000); - // the following assertion should be equalTo, but leq to protect against a slow CI server - assertThat(attemptingReconnectCount.get()).isLessThanOrEqualTo(2); - AtomicReference failed = new AtomicReference<>(); - adapter.setApplicationEventPublisher(failed::set); - adapter.connectionLost(new IllegalStateException()); - assertThat(failed.get()).isInstanceOf(MqttConnectionFailedEvent.class); - adapter.stop(); - taskScheduler.destroy(); - } - @Test public void testSubscribeFailure() throws Exception { DefaultMqttPahoClientFactory factory = new DefaultMqttPahoClientFactory(); @@ -487,19 +454,14 @@ public void testSubscribeFailure() throws Exception { connectOptions.setWill("foo", "bar".getBytes(), 2, true); factory = spy(factory); - MqttAsyncClient aClient = mock(MqttAsyncClient.class); - final MqttClient client = mock(MqttClient.class); - willAnswer(invocation -> client).given(factory).getClientInstance(anyString(), anyString()); + final MqttAsyncClient client = mock(MqttAsyncClient.class); + willReturn(client).given(factory).getAsyncClientInstance(anyString(), anyString()); given(client.isConnected()).willReturn(true); - new DirectFieldAccessor(client).setPropertyValue("aClient", aClient); - willAnswer(new CallsRealMethods()).given(client).connect(any(MqttConnectOptions.class)); - willAnswer(new CallsRealMethods()).given(client).subscribe(any(String[].class), any(int[].class)); - willAnswer(new CallsRealMethods()).given(client).subscribe(any(String[].class), any(int[].class), isNull()); - willReturn(alwaysComplete).given(aClient).connect(any(MqttConnectOptions.class), any(), any()); + willReturn(alwaysComplete).given(client).connect(any(MqttConnectOptions.class)); IMqttToken token = mock(IMqttToken.class); given(token.getGrantedQos()).willReturn(new int[]{ 0x80 }); - willReturn(token).given(aClient).subscribe(any(String[].class), any(int[].class), isNull(), isNull(), any()); + willReturn(token).given(client).subscribe(any(String[].class), any(int[].class), any()); MqttPahoMessageDrivenChannelAdapter adapter = new MqttPahoMessageDrivenChannelAdapter("foo", "bar", factory, "baz", "fix"); @@ -507,14 +469,18 @@ public void testSubscribeFailure() throws Exception { ReflectionUtils.doWithMethods(MqttPahoMessageDrivenChannelAdapter.class, m -> { m.setAccessible(true); method.set(m); - }, m -> m.getName().equals("connectAndSubscribe")); + }, m -> m.getName().equals("connect")); assertThat(method.get()).isNotNull(); - Condition subscribeFailed = new Condition<>(ex -> - ((MqttException) ex.getCause()).getReasonCode() == MqttException.REASON_CODE_SUBSCRIBE_FAILED, - "expected the reason code to be REASON_CODE_SUBSCRIBE_FAILED"); - assertThatExceptionOfType(InvocationTargetException.class).isThrownBy(() -> method.get().invoke(adapter)) - .withCauseInstanceOf(MqttException.class) - .is(subscribeFailed); + method.get().invoke(adapter); + ReflectionUtils.doWithMethods(MqttPahoMessageDrivenChannelAdapter.class, m -> { + m.setAccessible(true); + method.set(m); + }, m -> m.getName().equals("subscribe")); + assertThat(method.get()).isNotNull(); + ApplicationEventPublisher eventPublisher = mock(ApplicationEventPublisher.class); + adapter.setApplicationEventPublisher(eventPublisher); + method.get().invoke(adapter); + verify(eventPublisher).publishEvent(any(MqttConnectionFailedEvent.class)); } @Test @@ -535,19 +501,14 @@ public void testDifferentQos() throws Exception { connectOptions.setWill("foo", "bar".getBytes(), 2, true); factory = spy(factory); - MqttAsyncClient aClient = mock(MqttAsyncClient.class); - final MqttClient client = mock(MqttClient.class); - willAnswer(invocation -> client).given(factory).getClientInstance(anyString(), anyString()); + final MqttAsyncClient client = mock(MqttAsyncClient.class); + willReturn(client).given(factory).getAsyncClientInstance(anyString(), anyString()); given(client.isConnected()).willReturn(true); - new DirectFieldAccessor(client).setPropertyValue("aClient", aClient); - willAnswer(new CallsRealMethods()).given(client).connect(any(MqttConnectOptions.class)); - willAnswer(new CallsRealMethods()).given(client).subscribe(any(String[].class), any(int[].class)); - willAnswer(new CallsRealMethods()).given(client).subscribe(any(String[].class), any(int[].class), isNull()); - willReturn(alwaysComplete).given(aClient).connect(any(MqttConnectOptions.class), any(), any()); + willReturn(alwaysComplete).given(client).connect(any(MqttConnectOptions.class)); IMqttToken token = mock(IMqttToken.class); given(token.getGrantedQos()).willReturn(new int[]{ 2, 0 }); - willReturn(token).given(aClient).subscribe(any(String[].class), any(int[].class), isNull(), isNull(), any()); + willReturn(token).given(client).subscribe(any(String[].class), any(int[].class), any()); MqttPahoMessageDrivenChannelAdapter adapter = new MqttPahoMessageDrivenChannelAdapter("foo", "bar", factory, "baz", "fix"); @@ -555,7 +516,13 @@ public void testDifferentQos() throws Exception { ReflectionUtils.doWithMethods(MqttPahoMessageDrivenChannelAdapter.class, m -> { m.setAccessible(true); method.set(m); - }, m -> m.getName().equals("connectAndSubscribe")); + }, m -> m.getName().equals("connect")); + assertThat(method.get()).isNotNull(); + method.get().invoke(adapter); + ReflectionUtils.doWithMethods(MqttPahoMessageDrivenChannelAdapter.class, m -> { + m.setAccessible(true); + method.set(m); + }, m -> m.getName().equals("subscribe")); assertThat(method.get()).isNotNull(); LogAccessor logger = spy(TestUtils.getPropertyValue(adapter, "logger", LogAccessor.class)); new DirectFieldAccessor(adapter).setPropertyValue("logger", logger); @@ -566,62 +533,19 @@ public void testDifferentQos() throws Exception { logMessage.get() .equals("Granted QOS different to Requested QOS; topics: [baz, fix] " + "requested: [1, 1] granted: [2, 0]"))); - verify(client).setTimeToWait(30_000L); new DirectFieldAccessor(adapter).setPropertyValue("running", Boolean.TRUE); adapter.stop(); verify(client).disconnectForcibly(5_000L); } - @Test - public void testNoNPEOnReconnectAndStopRaceCondition() throws Exception { - final IMqttClient client = mock(IMqttClient.class); - MqttPahoMessageDrivenChannelAdapter adapter = buildAdapterIn(client, null, ConsumerStopAction.UNSUBSCRIBE_NEVER); - adapter.setRecoveryInterval(10); - - MqttException mqttException = new MqttException(MqttException.REASON_CODE_SUBSCRIBE_FAILED); - - willThrow(mqttException) - .given(client) - .subscribe(any(), ArgumentMatchers.any()); - - LogAccessor logger = spy(TestUtils.getPropertyValue(adapter, "logger", LogAccessor.class)); - new DirectFieldAccessor(adapter).setPropertyValue("logger", logger); - CountDownLatch exceptionLatch = new CountDownLatch(1); - ArgumentCaptor mqttExceptionArgumentCaptor = ArgumentCaptor.forClass(MqttException.class); - willAnswer(i -> { - exceptionLatch.countDown(); - return null; - }) - .given(logger) - .error(mqttExceptionArgumentCaptor.capture(), eq("Exception while connecting and subscribing")); - - ThreadPoolTaskScheduler taskScheduler = new ThreadPoolTaskScheduler(); - taskScheduler.initialize(); - adapter.setTaskScheduler(taskScheduler); - - adapter.setApplicationEventPublisher(event -> { - if (event instanceof MqttConnectionFailedEvent) { - adapter.destroy(); - } - }); - adapter.start(); - - assertThat(exceptionLatch.await(10, TimeUnit.SECONDS)).isTrue(); - assertThat(mqttExceptionArgumentCaptor.getValue()) - .isNotNull() - .isSameAs(mqttException); - - taskScheduler.destroy(); - } - - private MqttPahoMessageDrivenChannelAdapter buildAdapterIn(final IMqttClient client, Boolean cleanSession, - ConsumerStopAction action) { + private MqttPahoMessageDrivenChannelAdapter buildAdapterIn(final IMqttAsyncClient client, Boolean cleanSession, + ConsumerStopAction action) throws MqttException { DefaultMqttPahoClientFactory factory = new DefaultMqttPahoClientFactory() { @Override - public IMqttClient getClientInstance(String uri, String clientId) throws MqttException { + public IMqttAsyncClient getAsyncClientInstance(String uri, String clientId) { return client; } @@ -636,10 +560,13 @@ public IMqttClient getClientInstance(String uri, String clientId) throws MqttExc } factory.setConnectionOptions(connectOptions); given(client.isConnected()).willReturn(true); + IMqttToken token = mock(IMqttToken.class); + given(client.connect(any(MqttConnectOptions.class))).willReturn(token); + given(client.subscribe(any(String[].class), any(int[].class), any())).willReturn(token); + given(token.getGrantedQos()).willReturn(new int[]{ 2 }); MqttPahoMessageDrivenChannelAdapter adapter = new MqttPahoMessageDrivenChannelAdapter("client", factory, "foo"); adapter.setApplicationEventPublisher(mock(ApplicationEventPublisher.class)); adapter.setOutputChannel(new NullChannel()); - adapter.setTaskScheduler(mock(TaskScheduler.class)); adapter.afterPropertiesSet(); return adapter; } @@ -663,16 +590,16 @@ public IMqttAsyncClient getAsyncClientInstance(String uri, String clientId) { return adapter; } - private void verifyUnsubscribe(IMqttClient client) throws Exception { + private void verifyUnsubscribe(IMqttAsyncClient client) throws Exception { verify(client).connect(any(MqttConnectOptions.class)); - verify(client).subscribe(any(String[].class), any(int[].class)); + verify(client).subscribe(any(String[].class), any(int[].class), any()); verify(client).unsubscribe(any(String[].class)); verify(client).disconnectForcibly(anyLong()); } - private void verifyNotUnsubscribe(IMqttClient client) throws Exception { + private void verifyNotUnsubscribe(IMqttAsyncClient client) throws Exception { verify(client).connect(any(MqttConnectOptions.class)); - verify(client).subscribe(any(String[].class), any(int[].class)); + verify(client).subscribe(any(String[].class), any(int[].class), any()); verify(client, never()).unsubscribe(any(String[].class)); verify(client).disconnectForcibly(anyLong()); } diff --git a/spring-integration-mqtt/src/test/java/org/springframework/integration/mqtt/config/xml/MqttMessageDrivenChannelAdapterParserTests-context.xml b/spring-integration-mqtt/src/test/java/org/springframework/integration/mqtt/config/xml/MqttMessageDrivenChannelAdapterParserTests-context.xml index d8da6b510fb..e0e92a480e1 100644 --- a/spring-integration-mqtt/src/test/java/org/springframework/integration/mqtt/config/xml/MqttMessageDrivenChannelAdapterParserTests-context.xml +++ b/spring-integration-mqtt/src/test/java/org/springframework/integration/mqtt/config/xml/MqttMessageDrivenChannelAdapterParserTests-context.xml @@ -16,7 +16,6 @@ client-id="foo" url="tcp://localhost:1883" client-factory="clientFactory" - recovery-interval="5000" channel="out" />