Skip to content

Use customized ConnectionFactory lookup to avoid AbstractR2dbcConfiguration proxying #96

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

<groupId>org.springframework.data</groupId>
<artifactId>spring-data-r2dbc</artifactId>
<version>1.0.0.BUILD-SNAPSHOT</version>
<version>1.0.0.gh-95-SNAPSHOT</version>

<name>Spring Data R2DBC</name>
<description>Spring Data module for R2DBC.</description>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import java.util.Collections;
import java.util.Optional;

import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.convert.converter.Converter;
Expand All @@ -37,6 +40,7 @@
import org.springframework.data.relational.core.conversion.BasicRelationalConverter;
import org.springframework.data.relational.core.mapping.NamingStrategy;
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
Expand All @@ -48,8 +52,21 @@
* @see DatabaseClient
* @see org.springframework.data.r2dbc.repository.config.EnableR2dbcRepositories
*/
@Configuration
public abstract class AbstractR2dbcConfiguration {
@Configuration(proxyBeanMethods = false)
public abstract class AbstractR2dbcConfiguration implements ApplicationContextAware {

private static final String CONNECTION_FACTORY_BEAN_NAME = "connectionFactory";

private @Nullable ApplicationContext context;

/*
* (non-Javadoc)
* @see org.springframework.context.ApplicationContextAware#setApplicationContext(org.springframework.context.ApplicationContext)
*/
@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
this.context = applicationContext;
}

/**
* Return a R2DBC {@link ConnectionFactory}. Annotate with {@link Bean} in case you want to expose a
Expand Down Expand Up @@ -91,7 +108,7 @@ public DatabaseClient databaseClient(ReactiveDataAccessStrategy dataAccessStrate
Assert.notNull(exceptionTranslator, "ExceptionTranslator must not be null!");

return DatabaseClient.builder() //
.connectionFactory(connectionFactory()) //
.connectionFactory(lookupConnectionFactory()) //
.dataAccessStrategy(dataAccessStrategy) //
.exceptionTranslator(exceptionTranslator) //
.build();
Expand Down Expand Up @@ -137,7 +154,7 @@ public ReactiveDataAccessStrategy reactiveDataAccessStrategy(RelationalMappingCo

MappingR2dbcConverter converter = new MappingR2dbcConverter(mappingContext, r2dbcCustomConversions);

return new DefaultReactiveDataAccessStrategy(getDialect(connectionFactory()), converter);
return new DefaultReactiveDataAccessStrategy(getDialect(lookupConnectionFactory()), converter);
}

/**
Expand All @@ -160,7 +177,7 @@ public R2dbcCustomConversions r2dbcCustomConversions() {
*/
protected StoreConversions getStoreConversions() {

Dialect dialect = getDialect(connectionFactory());
Dialect dialect = getDialect(lookupConnectionFactory());
return StoreConversions.of(dialect.getSimpleTypeHolder(), R2dbcCustomConversions.STORE_CONVERTERS);
}

Expand All @@ -172,6 +189,23 @@ protected StoreConversions getStoreConversions() {
*/
@Bean
public R2dbcExceptionTranslator exceptionTranslator() {
return new SqlErrorCodeR2dbcExceptionTranslator(connectionFactory());
return new SqlErrorCodeR2dbcExceptionTranslator(lookupConnectionFactory());
}

ConnectionFactory lookupConnectionFactory() {

ApplicationContext context = this.context;
Assert.notNull(context, "ApplicationContext is not yet initialized");

String[] beanNamesForType = context.getBeanNamesForType(ConnectionFactory.class);

for (String beanName : beanNamesForType) {

if (beanName.equals(CONNECTION_FACTORY_BEAN_NAME)) {
return context.getBean(CONNECTION_FACTORY_BEAN_NAME, ConnectionFactory.class);
}
}

return connectionFactory();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/*
* Copyright 2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.r2dbc.config;

import static org.assertj.core.api.Assertions.*;
import static org.mockito.Mockito.*;

import io.r2dbc.h2.H2ConnectionConfiguration;
import io.r2dbc.h2.H2ConnectionFactory;
import io.r2dbc.spi.ConnectionFactory;

import org.junit.Test;

import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.r2dbc.function.DatabaseClient;

/**
* Tests for {@link AbstractR2dbcConfiguration}.
*
* @author Mark Paluch
*/
public class R2dbcConfigurationIntegrationTests {

@Test // gh-95
public void shouldLookupConnectionFactoryThroughLocalCall() {

AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(
NonBeanConnectionFactoryConfiguration.class);

context.getBean(DatabaseClient.class);

NonBeanConnectionFactoryConfiguration bean = context.getBean(NonBeanConnectionFactoryConfiguration.class);

assertThat(context.getBeanNamesForType(ConnectionFactory.class)).isEmpty();
assertThat(bean.callCounter).isGreaterThan(2);

context.stop();
}

@Test // gh-95
public void shouldLookupConnectionFactoryThroughLocalCallForExistingCustomBeans() {

AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(
CustomConnectionFactoryBeanNameConfiguration.class);

context.getBean(DatabaseClient.class);

CustomConnectionFactoryBeanNameConfiguration bean = context
.getBean(CustomConnectionFactoryBeanNameConfiguration.class);

assertThat(context.getBeanNamesForType(ConnectionFactory.class)).hasSize(1).contains("myCustomBean");
assertThat(bean.callCounter).isGreaterThan(2);

ConnectionFactoryWrapper wrapper = context.getBean(ConnectionFactoryWrapper.class);
assertThat(wrapper.connectionFactory).isExactlyInstanceOf(H2ConnectionFactory.class);

context.stop();
}

@Test // gh-95
public void shouldRegisterConnectionFactory() {

AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(
BeanConnectionFactoryConfiguration.class);

context.getBean(DatabaseClient.class);

BeanConnectionFactoryConfiguration bean = context.getBean(BeanConnectionFactoryConfiguration.class);

assertThat(bean.callCounter).isEqualTo(1);
assertThat(context.getBeanNamesForType(ConnectionFactory.class)).hasSize(1);

context.stop();
}

@Configuration(proxyBeanMethods = false)
static class NonBeanConnectionFactoryConfiguration extends AbstractR2dbcConfiguration {

int callCounter;

@Override
public ConnectionFactory connectionFactory() {

callCounter++;
return new H2ConnectionFactory(
H2ConnectionConfiguration.builder().inMemory("foo").username("sa").password("").build());
}
}

@Configuration(proxyBeanMethods = false)
static class CustomConnectionFactoryBeanNameConfiguration extends AbstractR2dbcConfiguration {

int callCounter;

@Bean
public ConnectionFactory myCustomBean() {
return mock(ConnectionFactory.class);
}

@Override
public ConnectionFactory connectionFactory() {

callCounter++;
return new H2ConnectionFactory(
H2ConnectionConfiguration.builder().inMemory("foo").username("sa").password("").build());
}

@Bean
ConnectionFactoryWrapper wrapper() {
return new ConnectionFactoryWrapper(lookupConnectionFactory());
}
}

static class ConnectionFactoryWrapper {
ConnectionFactory connectionFactory;

ConnectionFactoryWrapper(ConnectionFactory connectionFactory) {
this.connectionFactory = connectionFactory;
}
}

@Configuration(proxyBeanMethods = false)
static class BeanConnectionFactoryConfiguration extends NonBeanConnectionFactoryConfiguration {

@Override
@Bean
public ConnectionFactory connectionFactory() {
return super.connectionFactory();
}
}

}