diff --git a/pom.xml b/pom.xml
index a977b39d..ece01a46 100644
--- a/pom.xml
+++ b/pom.xml
@@ -7,7 +7,7 @@
org.springframework.data
spring-data-r2dbc
- 1.0.0.BUILD-SNAPSHOT
+ 1.0.0.gh-95-SNAPSHOT
Spring Data R2DBC
Spring Data module for R2DBC.
diff --git a/src/main/java/org/springframework/data/r2dbc/config/AbstractR2dbcConfiguration.java b/src/main/java/org/springframework/data/r2dbc/config/AbstractR2dbcConfiguration.java
index b700db91..2f197d62 100644
--- a/src/main/java/org/springframework/data/r2dbc/config/AbstractR2dbcConfiguration.java
+++ b/src/main/java/org/springframework/data/r2dbc/config/AbstractR2dbcConfiguration.java
@@ -20,6 +20,9 @@
import java.util.Collections;
import java.util.Optional;
+import org.springframework.beans.BeansException;
+import org.springframework.context.ApplicationContext;
+import org.springframework.context.ApplicationContextAware;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.convert.converter.Converter;
@@ -37,6 +40,7 @@
import org.springframework.data.relational.core.conversion.BasicRelationalConverter;
import org.springframework.data.relational.core.mapping.NamingStrategy;
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
+import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/**
@@ -48,8 +52,21 @@
* @see DatabaseClient
* @see org.springframework.data.r2dbc.repository.config.EnableR2dbcRepositories
*/
-@Configuration
-public abstract class AbstractR2dbcConfiguration {
+@Configuration(proxyBeanMethods = false)
+public abstract class AbstractR2dbcConfiguration implements ApplicationContextAware {
+
+ private static final String CONNECTION_FACTORY_BEAN_NAME = "connectionFactory";
+
+ private @Nullable ApplicationContext context;
+
+ /*
+ * (non-Javadoc)
+ * @see org.springframework.context.ApplicationContextAware#setApplicationContext(org.springframework.context.ApplicationContext)
+ */
+ @Override
+ public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
+ this.context = applicationContext;
+ }
/**
* Return a R2DBC {@link ConnectionFactory}. Annotate with {@link Bean} in case you want to expose a
@@ -91,7 +108,7 @@ public DatabaseClient databaseClient(ReactiveDataAccessStrategy dataAccessStrate
Assert.notNull(exceptionTranslator, "ExceptionTranslator must not be null!");
return DatabaseClient.builder() //
- .connectionFactory(connectionFactory()) //
+ .connectionFactory(lookupConnectionFactory()) //
.dataAccessStrategy(dataAccessStrategy) //
.exceptionTranslator(exceptionTranslator) //
.build();
@@ -137,7 +154,7 @@ public ReactiveDataAccessStrategy reactiveDataAccessStrategy(RelationalMappingCo
MappingR2dbcConverter converter = new MappingR2dbcConverter(mappingContext, r2dbcCustomConversions);
- return new DefaultReactiveDataAccessStrategy(getDialect(connectionFactory()), converter);
+ return new DefaultReactiveDataAccessStrategy(getDialect(lookupConnectionFactory()), converter);
}
/**
@@ -160,7 +177,7 @@ public R2dbcCustomConversions r2dbcCustomConversions() {
*/
protected StoreConversions getStoreConversions() {
- Dialect dialect = getDialect(connectionFactory());
+ Dialect dialect = getDialect(lookupConnectionFactory());
return StoreConversions.of(dialect.getSimpleTypeHolder(), R2dbcCustomConversions.STORE_CONVERTERS);
}
@@ -172,6 +189,23 @@ protected StoreConversions getStoreConversions() {
*/
@Bean
public R2dbcExceptionTranslator exceptionTranslator() {
- return new SqlErrorCodeR2dbcExceptionTranslator(connectionFactory());
+ return new SqlErrorCodeR2dbcExceptionTranslator(lookupConnectionFactory());
+ }
+
+ ConnectionFactory lookupConnectionFactory() {
+
+ ApplicationContext context = this.context;
+ Assert.notNull(context, "ApplicationContext is not yet initialized");
+
+ String[] beanNamesForType = context.getBeanNamesForType(ConnectionFactory.class);
+
+ for (String beanName : beanNamesForType) {
+
+ if (beanName.equals(CONNECTION_FACTORY_BEAN_NAME)) {
+ return context.getBean(CONNECTION_FACTORY_BEAN_NAME, ConnectionFactory.class);
+ }
+ }
+
+ return connectionFactory();
}
}
diff --git a/src/test/java/org/springframework/data/r2dbc/config/R2dbcConfigurationIntegrationTests.java b/src/test/java/org/springframework/data/r2dbc/config/R2dbcConfigurationIntegrationTests.java
new file mode 100644
index 00000000..2295cbd7
--- /dev/null
+++ b/src/test/java/org/springframework/data/r2dbc/config/R2dbcConfigurationIntegrationTests.java
@@ -0,0 +1,147 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.data.r2dbc.config;
+
+import static org.assertj.core.api.Assertions.*;
+import static org.mockito.Mockito.*;
+
+import io.r2dbc.h2.H2ConnectionConfiguration;
+import io.r2dbc.h2.H2ConnectionFactory;
+import io.r2dbc.spi.ConnectionFactory;
+
+import org.junit.Test;
+
+import org.springframework.context.annotation.AnnotationConfigApplicationContext;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.data.r2dbc.function.DatabaseClient;
+
+/**
+ * Tests for {@link AbstractR2dbcConfiguration}.
+ *
+ * @author Mark Paluch
+ */
+public class R2dbcConfigurationIntegrationTests {
+
+ @Test // gh-95
+ public void shouldLookupConnectionFactoryThroughLocalCall() {
+
+ AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(
+ NonBeanConnectionFactoryConfiguration.class);
+
+ context.getBean(DatabaseClient.class);
+
+ NonBeanConnectionFactoryConfiguration bean = context.getBean(NonBeanConnectionFactoryConfiguration.class);
+
+ assertThat(context.getBeanNamesForType(ConnectionFactory.class)).isEmpty();
+ assertThat(bean.callCounter).isGreaterThan(2);
+
+ context.stop();
+ }
+
+ @Test // gh-95
+ public void shouldLookupConnectionFactoryThroughLocalCallForExistingCustomBeans() {
+
+ AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(
+ CustomConnectionFactoryBeanNameConfiguration.class);
+
+ context.getBean(DatabaseClient.class);
+
+ CustomConnectionFactoryBeanNameConfiguration bean = context
+ .getBean(CustomConnectionFactoryBeanNameConfiguration.class);
+
+ assertThat(context.getBeanNamesForType(ConnectionFactory.class)).hasSize(1).contains("myCustomBean");
+ assertThat(bean.callCounter).isGreaterThan(2);
+
+ ConnectionFactoryWrapper wrapper = context.getBean(ConnectionFactoryWrapper.class);
+ assertThat(wrapper.connectionFactory).isExactlyInstanceOf(H2ConnectionFactory.class);
+
+ context.stop();
+ }
+
+ @Test // gh-95
+ public void shouldRegisterConnectionFactory() {
+
+ AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(
+ BeanConnectionFactoryConfiguration.class);
+
+ context.getBean(DatabaseClient.class);
+
+ BeanConnectionFactoryConfiguration bean = context.getBean(BeanConnectionFactoryConfiguration.class);
+
+ assertThat(bean.callCounter).isEqualTo(1);
+ assertThat(context.getBeanNamesForType(ConnectionFactory.class)).hasSize(1);
+
+ context.stop();
+ }
+
+ @Configuration(proxyBeanMethods = false)
+ static class NonBeanConnectionFactoryConfiguration extends AbstractR2dbcConfiguration {
+
+ int callCounter;
+
+ @Override
+ public ConnectionFactory connectionFactory() {
+
+ callCounter++;
+ return new H2ConnectionFactory(
+ H2ConnectionConfiguration.builder().inMemory("foo").username("sa").password("").build());
+ }
+ }
+
+ @Configuration(proxyBeanMethods = false)
+ static class CustomConnectionFactoryBeanNameConfiguration extends AbstractR2dbcConfiguration {
+
+ int callCounter;
+
+ @Bean
+ public ConnectionFactory myCustomBean() {
+ return mock(ConnectionFactory.class);
+ }
+
+ @Override
+ public ConnectionFactory connectionFactory() {
+
+ callCounter++;
+ return new H2ConnectionFactory(
+ H2ConnectionConfiguration.builder().inMemory("foo").username("sa").password("").build());
+ }
+
+ @Bean
+ ConnectionFactoryWrapper wrapper() {
+ return new ConnectionFactoryWrapper(lookupConnectionFactory());
+ }
+ }
+
+ static class ConnectionFactoryWrapper {
+ ConnectionFactory connectionFactory;
+
+ ConnectionFactoryWrapper(ConnectionFactory connectionFactory) {
+ this.connectionFactory = connectionFactory;
+ }
+ }
+
+ @Configuration(proxyBeanMethods = false)
+ static class BeanConnectionFactoryConfiguration extends NonBeanConnectionFactoryConfiguration {
+
+ @Override
+ @Bean
+ public ConnectionFactory connectionFactory() {
+ return super.connectionFactory();
+ }
+ }
+
+}