Skip to content

Commit ae706f3

Browse files
committed
Allow MethodReference to define a more flexible signature
This commit moves MethodReference to an interface with a default implementation that relies on a MethodSpec. Such an arrangement avoid the need of specifying attributes of the method such as whether it is static or not. The resolution of the invocation block now takes an ArgumentCodeGenerator rather than the raw arguments. Doing so gives the opportunity to create more flexible signatures. See gh-29005
1 parent 8a4a89b commit ae706f3

File tree

16 files changed

+469
-428
lines changed

16 files changed

+469
-428
lines changed

spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import org.springframework.aop.framework.AopInfrastructureBean;
2828
import org.springframework.aot.generate.MethodReference;
29+
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
2930
import org.springframework.aot.test.generate.TestGenerationContext;
3031
import org.springframework.aot.test.generate.compile.Compiled;
3132
import org.springframework.aot.test.generate.compile.TestCompiler;
@@ -139,11 +140,14 @@ private void compile(BiConsumer<DefaultListableBeanFactory, Compiled> result) {
139140
MethodReference methodReference = this.beanFactoryInitializationCode
140141
.getInitializers().get(0);
141142
this.beanFactoryInitializationCode.getTypeBuilder().set(type -> {
143+
CodeBlock methodInvocation = methodReference.toInvokeCodeBlock(
144+
ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"),
145+
this.beanFactoryInitializationCode.getClassName());
142146
type.addModifiers(Modifier.PUBLIC);
143147
type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class));
144148
type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC)
145149
.addParameter(DefaultListableBeanFactory.class, "beanFactory")
146-
.addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory")))
150+
.addStatement(methodInvocation)
147151
.build());
148152
});
149153
this.generationContext.writeGeneratedContent();

spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.springframework.aot.generate.GeneratedMethods;
2626
import org.springframework.aot.generate.GenerationContext;
2727
import org.springframework.aot.generate.MethodReference;
28+
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
2829
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
2930
import org.springframework.javapoet.ClassName;
3031
import org.springframework.javapoet.CodeBlock;
@@ -81,9 +82,11 @@ private void generateRegisterMethod(MethodSpec.Builder method,
8182
MethodReference beanDefinitionMethod = beanDefinitionMethodGenerator
8283
.generateBeanDefinitionMethod(generationContext,
8384
beanRegistrationsCode);
85+
CodeBlock methodInvocation = beanDefinitionMethod.toInvokeCodeBlock(
86+
ArgumentCodeGenerator.none(), beanRegistrationsCode.getClassName());
8487
code.addStatement("$L.registerBeanDefinition($S, $L)",
8588
BEAN_FACTORY_PARAMETER_NAME, beanName,
86-
beanDefinitionMethod.toInvokeCodeBlock());
89+
methodInvocation);
8790
});
8891
method.addCode(code.build());
8992
}

spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.springframework.aot.generate.AccessVisibility;
2525
import org.springframework.aot.generate.GenerationContext;
2626
import org.springframework.aot.generate.MethodReference;
27+
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
2728
import org.springframework.beans.factory.FactoryBean;
2829
import org.springframework.beans.factory.config.BeanDefinition;
2930
import org.springframework.beans.factory.config.BeanDefinitionHolder;
@@ -156,7 +157,7 @@ protected CodeBlock generateValueCode(GenerationContext generationContext,
156157
MethodReference generatedMethod = methodGenerator
157158
.generateBeanDefinitionMethod(generationContext,
158159
this.beanRegistrationsCode);
159-
return generatedMethod.toInvokeCodeBlock();
160+
return generatedMethod.toInvokeCodeBlock(ArgumentCodeGenerator.none());
160161
}
161162
return null;
162163
}

spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.springframework.aot.generate.GeneratedMethod;
2929
import org.springframework.aot.generate.GeneratedMethods;
3030
import org.springframework.aot.generate.GenerationContext;
31+
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
3132
import org.springframework.aot.hint.ExecutableMode;
3233
import org.springframework.beans.factory.support.InstanceSupplier;
3334
import org.springframework.beans.factory.support.RegisteredBean;
@@ -297,7 +298,8 @@ private CodeBlock generateNewInstanceCodeForMethod(boolean dependsOnBean,
297298
}
298299

299300
private CodeBlock generateReturnStatement(GeneratedMethod generatedMethod) {
300-
return generatedMethod.toMethodReference().toInvokeCodeBlock();
301+
return generatedMethod.toMethodReference().toInvokeCodeBlock(
302+
ArgumentCodeGenerator.none(), this.className);
301303
}
302304

303305
private CodeBlock generateWithGeneratorCode(boolean hasArguments, CodeBlock newInstance) {

spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.junit.jupiter.api.Test;
2525

2626
import org.springframework.aot.generate.MethodReference;
27+
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
2728
import org.springframework.aot.hint.predicate.RuntimeHintsPredicates;
2829
import org.springframework.aot.test.generate.TestGenerationContext;
2930
import org.springframework.aot.test.generate.compile.CompileWithTargetClassAccess;
@@ -161,13 +162,16 @@ private void compile(RegisteredBean registeredBean,
161162
Class<?> target = registeredBean.getBeanClass();
162163
MethodReference methodReference = this.beanRegistrationCode.getInstancePostProcessors().get(0);
163164
this.beanRegistrationCode.getTypeBuilder().set(type -> {
165+
CodeBlock methodInvocation = methodReference.toInvokeCodeBlock(
166+
ArgumentCodeGenerator.of(RegisteredBean.class, "registeredBean").and(target, "instance"),
167+
this.beanRegistrationCode.getClassName());
164168
type.addModifiers(Modifier.PUBLIC);
165169
type.addSuperinterface(ParameterizedTypeName.get(BiFunction.class, RegisteredBean.class, target, target));
166170
type.addMethod(MethodSpec.methodBuilder("apply")
167171
.addModifiers(Modifier.PUBLIC)
168172
.addParameter(RegisteredBean.class, "registeredBean")
169173
.addParameter(target, "instance").returns(target)
170-
.addStatement("return $L", methodReference.toInvokeCodeBlock(CodeBlock.of("registeredBean"), CodeBlock.of("instance")))
174+
.addStatement("return $L", methodInvocation)
171175
.build());
172176

173177
});

spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.springframework.aot.generate.GeneratedMethod;
3131
import org.springframework.aot.generate.GenerationContext;
3232
import org.springframework.aot.generate.MethodReference;
33+
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
3334
import org.springframework.aot.test.generate.TestGenerationContext;
3435
import org.springframework.aot.test.generate.compile.CompileWithTargetClassAccess;
3536
import org.springframework.aot.test.generate.compile.Compiled;
@@ -414,12 +415,14 @@ private RegisteredBean registerBean(RootBeanDefinition beanDefinition) {
414415
private void compile(MethodReference method,
415416
BiConsumer<RootBeanDefinition, Compiled> result) {
416417
this.beanRegistrationsCode.getTypeBuilder().set(type -> {
418+
CodeBlock methodInvocation = method.toInvokeCodeBlock(ArgumentCodeGenerator.none(),
419+
this.beanRegistrationsCode.getClassName());
417420
type.addModifiers(Modifier.PUBLIC);
418421
type.addSuperinterface(ParameterizedTypeName.get(Supplier.class, BeanDefinition.class));
419422
type.addMethod(MethodSpec.methodBuilder("get")
420423
.addModifiers(Modifier.PUBLIC)
421424
.returns(BeanDefinition.class)
422-
.addCode("return $L;", method.toInvokeCodeBlock()).build());
425+
.addCode("return $L;", methodInvocation).build());
423426
});
424427
this.generationContext.writeGeneratedContent();
425428
TestCompiler.forSystem().withFiles(this.generationContext.getGeneratedFiles()).compile(compiled ->

spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.springframework.aot.generate.ClassNameGenerator;
3232
import org.springframework.aot.generate.GenerationContext;
3333
import org.springframework.aot.generate.MethodReference;
34+
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
3435
import org.springframework.aot.test.generate.TestGenerationContext;
3536
import org.springframework.aot.test.generate.TestTarget;
3637
import org.springframework.aot.test.generate.compile.Compiled;
@@ -155,11 +156,14 @@ private void compile(
155156
MethodReference methodReference = this.beanFactoryInitializationCode
156157
.getInitializers().get(0);
157158
this.beanFactoryInitializationCode.getTypeBuilder().set(type -> {
159+
CodeBlock methodInvocation = methodReference.toInvokeCodeBlock(
160+
ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"),
161+
this.beanFactoryInitializationCode.getClassName());
158162
type.addModifiers(Modifier.PUBLIC);
159163
type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class));
160164
type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC)
161165
.addParameter(DefaultListableBeanFactory.class, "beanFactory")
162-
.addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory")))
166+
.addStatement(methodInvocation)
163167
.build());
164168
});
165169
this.generationContext.writeGeneratedContent();

spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.springframework.aot.generate.GenerationContext;
2626
import org.springframework.aot.generate.MethodReference;
2727
import org.springframework.beans.factory.aot.BeanFactoryInitializationCode;
28+
import org.springframework.javapoet.ClassName;
2829

2930
/**
3031
* Mock {@link BeanFactoryInitializationCode} implementation.
@@ -46,6 +47,9 @@ public MockBeanFactoryInitializationCode(GenerationContext generationContext) {
4647
.addForFeature("TestCode", this.typeBuilder);
4748
}
4849

50+
public ClassName getClassName() {
51+
return this.generatedClass.getName();
52+
}
4953

5054
public DeferredTypeBuilder getTypeBuilder() {
5155
return this.typeBuilder;

spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.springframework.aot.generate.GeneratedMethods;
2626
import org.springframework.aot.generate.GenerationContext;
2727
import org.springframework.aot.generate.MethodReference;
28+
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
2829
import org.springframework.beans.factory.aot.BeanFactoryInitializationCode;
2930
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
3031
import org.springframework.context.ApplicationContextInitializer;
@@ -88,12 +89,17 @@ private CodeBlock generateInitializeCode() {
8889
BEAN_FACTORY_VARIABLE, ContextAnnotationAutowireCandidateResolver.class);
8990
code.addStatement("$L.setDependencyComparator($T.INSTANCE)",
9091
BEAN_FACTORY_VARIABLE, AnnotationAwareOrderComparator.class);
92+
ArgumentCodeGenerator argCodeGenerator = createInitializerMethodsArgumentCodeGenerator();
9193
for (MethodReference initializer : this.initializers) {
92-
code.addStatement(initializer.toInvokeCodeBlock(CodeBlock.of(BEAN_FACTORY_VARIABLE)));
94+
code.addStatement(initializer.toInvokeCodeBlock(argCodeGenerator, this.generatedClass.getName()));
9395
}
9496
return code.build();
9597
}
9698

99+
private ArgumentCodeGenerator createInitializerMethodsArgumentCodeGenerator() {
100+
return ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, BEAN_FACTORY_VARIABLE);
101+
}
102+
97103
GeneratedClass getGeneratedClass() {
98104
return this.generatedClass;
99105
}

spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.junit.jupiter.api.Test;
2828

2929
import org.springframework.aot.generate.MethodReference;
30+
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
3031
import org.springframework.aot.hint.ResourcePatternHint;
3132
import org.springframework.aot.test.generate.TestGenerationContext;
3233
import org.springframework.aot.test.generate.compile.Compiled;
@@ -162,11 +163,14 @@ private void assertPostProcessorEntry(BeanPostProcessor postProcessor, Class<?>
162163
private void compile(BiConsumer<Consumer<DefaultListableBeanFactory>, Compiled> result) {
163164
MethodReference methodReference = this.beanFactoryInitializationCode.getInitializers().get(0);
164165
this.beanFactoryInitializationCode.getTypeBuilder().set(type -> {
166+
CodeBlock methodInvocation = methodReference.toInvokeCodeBlock(
167+
ArgumentCodeGenerator.of(DefaultListableBeanFactory.class, "beanFactory"),
168+
this.beanFactoryInitializationCode.getClassName());
165169
type.addModifiers(Modifier.PUBLIC);
166170
type.addSuperinterface(ParameterizedTypeName.get(Consumer.class, DefaultListableBeanFactory.class));
167171
type.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC)
168172
.addParameter(DefaultListableBeanFactory.class, "beanFactory")
169-
.addStatement(methodReference.toInvokeCodeBlock(CodeBlock.of("beanFactory")))
173+
.addStatement(methodInvocation)
170174
.build());
171175
});
172176
this.generationContext.writeGeneratedContent();
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/*
2+
* Copyright 2002-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.aot.generate;
18+
19+
import java.util.ArrayList;
20+
import java.util.List;
21+
22+
import javax.lang.model.element.Modifier;
23+
24+
import org.springframework.javapoet.ClassName;
25+
import org.springframework.javapoet.CodeBlock;
26+
import org.springframework.javapoet.MethodSpec;
27+
import org.springframework.javapoet.TypeName;
28+
import org.springframework.lang.Nullable;
29+
import org.springframework.util.Assert;
30+
31+
/**
32+
* Default {@link MethodReference} implementation based on a {@link MethodSpec}.
33+
*
34+
* @author Stephane Nicoll
35+
* @author Phillip Webb
36+
* @since 6.0
37+
*/
38+
public class DefaultMethodReference implements MethodReference {
39+
40+
private final MethodSpec method;
41+
42+
@Nullable
43+
private final ClassName declaringClass;
44+
45+
public DefaultMethodReference(MethodSpec method, @Nullable ClassName declaringClass) {
46+
this.method = method;
47+
this.declaringClass = declaringClass;
48+
}
49+
50+
@Override
51+
public CodeBlock toCodeBlock() {
52+
String methodName = this.method.name;
53+
if (isStatic()) {
54+
Assert.notNull(this.declaringClass, "static method reference must define a declaring class");
55+
return CodeBlock.of("$T::$L", this.declaringClass, methodName);
56+
}
57+
else {
58+
return CodeBlock.of("this::$L", methodName);
59+
}
60+
}
61+
62+
public CodeBlock toInvokeCodeBlock(ArgumentCodeGenerator argumentCodeGenerator,
63+
@Nullable ClassName targetClassName) {
64+
String methodName = this.method.name;
65+
CodeBlock.Builder code = CodeBlock.builder();
66+
if (isStatic()) {
67+
Assert.notNull(this.declaringClass, "static method reference must define a declaring class");
68+
if (isSameDeclaringClass(targetClassName)) {
69+
code.add("$L", methodName);
70+
}
71+
else {
72+
code.add("$T.$L", this.declaringClass, methodName);
73+
}
74+
}
75+
else {
76+
if (!isSameDeclaringClass(targetClassName)) {
77+
code.add(instantiateDeclaringClass(this.declaringClass));
78+
}
79+
code.add("$L", methodName);
80+
}
81+
code.add("(");
82+
addArguments(code, argumentCodeGenerator);
83+
code.add(")");
84+
return code.build();
85+
}
86+
87+
/**
88+
* Add the code for the method arguments using the specified
89+
* {@link ArgumentCodeGenerator} if necessary.
90+
* @param code the code builder to use to add method arguments
91+
* @param argumentCodeGenerator the code generator to use
92+
*/
93+
protected void addArguments(CodeBlock.Builder code, ArgumentCodeGenerator argumentCodeGenerator) {
94+
List<CodeBlock> arguments = new ArrayList<>();
95+
TypeName[] argumentTypes = this.method.parameters.stream()
96+
.map(parameter -> parameter.type).toArray(TypeName[]::new);
97+
for (int i = 0; i < argumentTypes.length; i++) {
98+
TypeName argumentType = argumentTypes[i];
99+
CodeBlock argumentCode = argumentCodeGenerator.generateCode(argumentType);
100+
if (argumentCode == null) {
101+
throw new IllegalArgumentException("Could not generate code for " + this
102+
+ ": parameter " + i + " of type " + argumentType + " is not supported");
103+
}
104+
arguments.add(argumentCode);
105+
}
106+
code.add(CodeBlock.join(arguments, ", "));
107+
}
108+
109+
protected CodeBlock instantiateDeclaringClass(ClassName declaringClass) {
110+
return CodeBlock.of("new $T().", declaringClass);
111+
}
112+
113+
private boolean isStatic() {
114+
return this.method.modifiers.contains(Modifier.STATIC);
115+
}
116+
117+
private boolean isSameDeclaringClass(ClassName declaringClass) {
118+
return this.declaringClass == null || this.declaringClass.equals(declaringClass);
119+
}
120+
121+
@Override
122+
public String toString() {
123+
String methodName = this.method.name;
124+
if (isStatic()) {
125+
return this.declaringClass + "::" + methodName;
126+
}
127+
else {
128+
return ((this.declaringClass != null)
129+
? "<" + this.declaringClass + ">" : "<instance>")
130+
+ "::" + methodName;
131+
}
132+
}
133+
134+
}

spring-core/src/main/java/org/springframework/aot/generate/GeneratedMethod.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818

1919
import java.util.function.Consumer;
2020

21-
import javax.lang.model.element.Modifier;
22-
2321
import org.springframework.javapoet.ClassName;
2422
import org.springframework.javapoet.MethodSpec;
2523
import org.springframework.util.Assert;
@@ -73,9 +71,7 @@ public String getName() {
7371
* @return a method reference
7472
*/
7573
public MethodReference toMethodReference() {
76-
return (this.methodSpec.modifiers.contains(Modifier.STATIC)
77-
? MethodReference.ofStatic(this.className, this.name)
78-
: MethodReference.of(this.className, this.name));
74+
return new DefaultMethodReference(this.methodSpec, this.className);
7975
}
8076

8177
/**

0 commit comments

Comments
 (0)