diff --git a/pom.xml b/pom.xml index a977b39d..ece01a46 100644 --- a/pom.xml +++ b/pom.xml @@ -7,7 +7,7 @@ org.springframework.data spring-data-r2dbc - 1.0.0.BUILD-SNAPSHOT + 1.0.0.gh-95-SNAPSHOT Spring Data R2DBC Spring Data module for R2DBC. diff --git a/src/main/java/org/springframework/data/r2dbc/config/AbstractR2dbcConfiguration.java b/src/main/java/org/springframework/data/r2dbc/config/AbstractR2dbcConfiguration.java index b700db91..2f197d62 100644 --- a/src/main/java/org/springframework/data/r2dbc/config/AbstractR2dbcConfiguration.java +++ b/src/main/java/org/springframework/data/r2dbc/config/AbstractR2dbcConfiguration.java @@ -20,6 +20,9 @@ import java.util.Collections; import java.util.Optional; +import org.springframework.beans.BeansException; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.convert.converter.Converter; @@ -37,6 +40,7 @@ import org.springframework.data.relational.core.conversion.BasicRelationalConverter; import org.springframework.data.relational.core.mapping.NamingStrategy; import org.springframework.data.relational.core.mapping.RelationalMappingContext; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -48,8 +52,21 @@ * @see DatabaseClient * @see org.springframework.data.r2dbc.repository.config.EnableR2dbcRepositories */ -@Configuration -public abstract class AbstractR2dbcConfiguration { +@Configuration(proxyBeanMethods = false) +public abstract class AbstractR2dbcConfiguration implements ApplicationContextAware { + + private static final String CONNECTION_FACTORY_BEAN_NAME = "connectionFactory"; + + private @Nullable ApplicationContext context; + + /* + * (non-Javadoc) + * @see org.springframework.context.ApplicationContextAware#setApplicationContext(org.springframework.context.ApplicationContext) + */ + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + this.context = applicationContext; + } /** * Return a R2DBC {@link ConnectionFactory}. Annotate with {@link Bean} in case you want to expose a @@ -91,7 +108,7 @@ public DatabaseClient databaseClient(ReactiveDataAccessStrategy dataAccessStrate Assert.notNull(exceptionTranslator, "ExceptionTranslator must not be null!"); return DatabaseClient.builder() // - .connectionFactory(connectionFactory()) // + .connectionFactory(lookupConnectionFactory()) // .dataAccessStrategy(dataAccessStrategy) // .exceptionTranslator(exceptionTranslator) // .build(); @@ -137,7 +154,7 @@ public ReactiveDataAccessStrategy reactiveDataAccessStrategy(RelationalMappingCo MappingR2dbcConverter converter = new MappingR2dbcConverter(mappingContext, r2dbcCustomConversions); - return new DefaultReactiveDataAccessStrategy(getDialect(connectionFactory()), converter); + return new DefaultReactiveDataAccessStrategy(getDialect(lookupConnectionFactory()), converter); } /** @@ -160,7 +177,7 @@ public R2dbcCustomConversions r2dbcCustomConversions() { */ protected StoreConversions getStoreConversions() { - Dialect dialect = getDialect(connectionFactory()); + Dialect dialect = getDialect(lookupConnectionFactory()); return StoreConversions.of(dialect.getSimpleTypeHolder(), R2dbcCustomConversions.STORE_CONVERTERS); } @@ -172,6 +189,23 @@ protected StoreConversions getStoreConversions() { */ @Bean public R2dbcExceptionTranslator exceptionTranslator() { - return new SqlErrorCodeR2dbcExceptionTranslator(connectionFactory()); + return new SqlErrorCodeR2dbcExceptionTranslator(lookupConnectionFactory()); + } + + ConnectionFactory lookupConnectionFactory() { + + ApplicationContext context = this.context; + Assert.notNull(context, "ApplicationContext is not yet initialized"); + + String[] beanNamesForType = context.getBeanNamesForType(ConnectionFactory.class); + + for (String beanName : beanNamesForType) { + + if (beanName.equals(CONNECTION_FACTORY_BEAN_NAME)) { + return context.getBean(CONNECTION_FACTORY_BEAN_NAME, ConnectionFactory.class); + } + } + + return connectionFactory(); } } diff --git a/src/test/java/org/springframework/data/r2dbc/config/R2dbcConfigurationIntegrationTests.java b/src/test/java/org/springframework/data/r2dbc/config/R2dbcConfigurationIntegrationTests.java new file mode 100644 index 00000000..2295cbd7 --- /dev/null +++ b/src/test/java/org/springframework/data/r2dbc/config/R2dbcConfigurationIntegrationTests.java @@ -0,0 +1,147 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.config; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.*; + +import io.r2dbc.h2.H2ConnectionConfiguration; +import io.r2dbc.h2.H2ConnectionFactory; +import io.r2dbc.spi.ConnectionFactory; + +import org.junit.Test; + +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.data.r2dbc.function.DatabaseClient; + +/** + * Tests for {@link AbstractR2dbcConfiguration}. + * + * @author Mark Paluch + */ +public class R2dbcConfigurationIntegrationTests { + + @Test // gh-95 + public void shouldLookupConnectionFactoryThroughLocalCall() { + + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext( + NonBeanConnectionFactoryConfiguration.class); + + context.getBean(DatabaseClient.class); + + NonBeanConnectionFactoryConfiguration bean = context.getBean(NonBeanConnectionFactoryConfiguration.class); + + assertThat(context.getBeanNamesForType(ConnectionFactory.class)).isEmpty(); + assertThat(bean.callCounter).isGreaterThan(2); + + context.stop(); + } + + @Test // gh-95 + public void shouldLookupConnectionFactoryThroughLocalCallForExistingCustomBeans() { + + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext( + CustomConnectionFactoryBeanNameConfiguration.class); + + context.getBean(DatabaseClient.class); + + CustomConnectionFactoryBeanNameConfiguration bean = context + .getBean(CustomConnectionFactoryBeanNameConfiguration.class); + + assertThat(context.getBeanNamesForType(ConnectionFactory.class)).hasSize(1).contains("myCustomBean"); + assertThat(bean.callCounter).isGreaterThan(2); + + ConnectionFactoryWrapper wrapper = context.getBean(ConnectionFactoryWrapper.class); + assertThat(wrapper.connectionFactory).isExactlyInstanceOf(H2ConnectionFactory.class); + + context.stop(); + } + + @Test // gh-95 + public void shouldRegisterConnectionFactory() { + + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext( + BeanConnectionFactoryConfiguration.class); + + context.getBean(DatabaseClient.class); + + BeanConnectionFactoryConfiguration bean = context.getBean(BeanConnectionFactoryConfiguration.class); + + assertThat(bean.callCounter).isEqualTo(1); + assertThat(context.getBeanNamesForType(ConnectionFactory.class)).hasSize(1); + + context.stop(); + } + + @Configuration(proxyBeanMethods = false) + static class NonBeanConnectionFactoryConfiguration extends AbstractR2dbcConfiguration { + + int callCounter; + + @Override + public ConnectionFactory connectionFactory() { + + callCounter++; + return new H2ConnectionFactory( + H2ConnectionConfiguration.builder().inMemory("foo").username("sa").password("").build()); + } + } + + @Configuration(proxyBeanMethods = false) + static class CustomConnectionFactoryBeanNameConfiguration extends AbstractR2dbcConfiguration { + + int callCounter; + + @Bean + public ConnectionFactory myCustomBean() { + return mock(ConnectionFactory.class); + } + + @Override + public ConnectionFactory connectionFactory() { + + callCounter++; + return new H2ConnectionFactory( + H2ConnectionConfiguration.builder().inMemory("foo").username("sa").password("").build()); + } + + @Bean + ConnectionFactoryWrapper wrapper() { + return new ConnectionFactoryWrapper(lookupConnectionFactory()); + } + } + + static class ConnectionFactoryWrapper { + ConnectionFactory connectionFactory; + + ConnectionFactoryWrapper(ConnectionFactory connectionFactory) { + this.connectionFactory = connectionFactory; + } + } + + @Configuration(proxyBeanMethods = false) + static class BeanConnectionFactoryConfiguration extends NonBeanConnectionFactoryConfiguration { + + @Override + @Bean + public ConnectionFactory connectionFactory() { + return super.connectionFactory(); + } + } + +}