Skip to content
This repository was archived by the owner on Feb 23, 2023. It is now read-only.

Commit fe9b0df

Browse files
committed
Introduce AotMergedContextConfiguration for tests in AOT mode
Closes gh-1262
1 parent f6d09a3 commit fe9b0df

File tree

7 files changed

+381
-37
lines changed

7 files changed

+381
-37
lines changed

spring-aot-test/src/main/java/org/springframework/aot/test/context/bootstrap/generator/TestContextAotProcessor.java

+40-6
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import com.squareup.javapoet.MethodSpec;
3232
import com.squareup.javapoet.ParameterizedTypeName;
3333
import com.squareup.javapoet.TypeName;
34+
import com.squareup.javapoet.WildcardTypeName;
3435

3536
import org.springframework.aot.context.bootstrap.generator.ApplicationContextAotProcessor;
3637
import org.springframework.aot.context.bootstrap.generator.infrastructure.BootstrapClass;
@@ -42,9 +43,11 @@
4243
import org.springframework.test.context.SmartContextLoader;
4344

4445
/**
45-
* A decorator of {@link ApplicationContextAotProcessor} that handles test contexts.
46+
* A decorator of {@link ApplicationContextAotProcessor} that handles test
47+
* application contexts.
4648
*
4749
* @author Stephane Nicoll
50+
* @author Sam Brannen
4851
*/
4952
public class TestContextAotProcessor {
5053

@@ -81,22 +84,25 @@ public void generateTestContexts(Iterable<Class<?>> testClasses, BootstrapWriter
8184
.withMethods(MethodSpec.constructorBuilder().build());
8285
entries.put(className, descriptor);
8386
}
84-
generateContextLoadersMapping(writerContext, entries);
87+
generateMappingMethods(writerContext, entries);
8588

8689
this.testNativeConfigurationRegistrar.processTestConfigurations(nativeConfigurationRegistry,
8790
entries.values().stream().map(TestContextConfigurationDescriptor::getContextConfiguration).collect(Collectors.toList()));
8891
}
8992

90-
private void generateContextLoadersMapping(BootstrapWriterContext writerContext, Map<ClassName, TestContextConfigurationDescriptor> entries) {
93+
private void generateMappingMethods(BootstrapWriterContext writerContext, Map<ClassName, TestContextConfigurationDescriptor> entries) {
9194
BootstrapWriterContext mainWriterContext = writerContext.fork(
9295
TEST_BOOTSTRAP_CLASS_NAME, (packageName) -> {
9396
ClassName mainClassName = ClassName.get(packageName, TEST_BOOTSTRAP_CLASS_NAME);
9497
return BootstrapClass.of(mainClassName, (type) -> type.addModifiers(Modifier.PUBLIC));
9598
});
99+
96100
BootstrapClass boostrapClass = mainWriterContext.getMainBootstrapClass();
97-
MethodSpec method = boostrapClass.addMethod(contextLoadersMappingMethod(entries));
101+
MethodSpec getContextLoaders = boostrapClass.addMethod(getContextLoadersBuilder(entries));
102+
MethodSpec getContextInitializers = boostrapClass.addMethod(getContextInitializersBuilder(entries));
98103
writerContext.getNativeConfigurationRegistry().reflection()
99-
.forGeneratedType(boostrapClass.getClassName()).withMethods(method);
104+
.forGeneratedType(boostrapClass.getClassName())
105+
.withMethods(getContextLoaders, getContextInitializers);
100106
}
101107

102108
protected ClassName generateTestContext(BootstrapWriterContext writerContext, Supplier<ClassName> fallbackClassName,
@@ -110,7 +116,7 @@ protected ClassName generateTestContext(BootstrapWriterContext writerContext, Su
110116
return mainBootstrapClass.getClassName();
111117
}
112118

113-
private MethodSpec.Builder contextLoadersMappingMethod(Map<ClassName, TestContextConfigurationDescriptor> entries) {
119+
private MethodSpec.Builder getContextLoadersBuilder(Map<ClassName, TestContextConfigurationDescriptor> entries) {
114120
Builder code = CodeBlock.builder();
115121
TypeName mapType = ParameterizedTypeName.get(ClassName.get(Map.class),
116122
ClassName.get(String.class), ParameterizedTypeName.get(Supplier.class, SmartContextLoader.class));
@@ -126,6 +132,34 @@ private MethodSpec.Builder contextLoadersMappingMethod(Map<ClassName, TestContex
126132
.addModifiers(Modifier.PUBLIC, Modifier.STATIC).addCode(code.build());
127133
}
128134

135+
private MethodSpec.Builder getContextInitializersBuilder(Map<ClassName, TestContextConfigurationDescriptor> entries) {
136+
// We're generating a method that looks like the following.
137+
//
138+
// public static Map<String, Class<? extends ApplicationContextInitializer<?>>> getContextInitializers() {
139+
// Map<String, Class<? extends ApplicationContextInitializer<?>>> map = new HashMap<>();
140+
// map.put("org.example.Sample1Tests", Sample1TestsContextInitializer.class);
141+
// map.put("org.example.Sample2Tests", Sample2TestsContextInitializer.class);
142+
// return map;
143+
// }
144+
145+
ClassName stringTypeName = ClassName.get(String.class);
146+
TypeName initializerWildcard = WildcardTypeName.subtypeOf(Object.class);
147+
TypeName initializerTypeName = ParameterizedTypeName.get(ClassName.get(ApplicationContextInitializer.class), initializerWildcard);
148+
TypeName classWildcard = WildcardTypeName.subtypeOf(initializerTypeName);
149+
TypeName classTypeName = ParameterizedTypeName.get(ClassName.get(Class.class), classWildcard);
150+
TypeName mapTypeName = ParameterizedTypeName.get(ClassName.get(Map.class), stringTypeName, classTypeName);
151+
152+
Builder code = CodeBlock.builder();
153+
code.addStatement("$T map = new $T<>()", mapTypeName, HashMap.class);
154+
entries.forEach((className, descriptor) ->
155+
descriptor.getTestClasses().forEach((testClass) -> {
156+
code.addStatement("map.put($S, $T.class)", testClass.getName(), className);
157+
}));
158+
code.addStatement("return map");
159+
return MethodSpec.methodBuilder("getContextInitializers").returns(mapTypeName)
160+
.addModifiers(Modifier.PUBLIC, Modifier.STATIC).addCode(code.build());
161+
}
162+
129163
private CodeBlock getClassLevelJavadoc(List<Class<?>> testClasses) {
130164
Builder code = CodeBlock.builder();
131165
code.add("AOT generated context for ");

spring-aot-test/src/test/java/org/springframework/aot/test/context/bootstrap/generator/TestContextAotProcessorTests.java

+25-5
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
* Tests for {@link TestContextAotProcessor}.
4040
*
4141
* @author Stephane Nicoll
42+
* @author Sam Brannen
4243
*/
4344
class TestContextAotProcessorTests {
4445

@@ -97,22 +98,41 @@ void processWritesContextLoaderMappingAtStandardLocation() {
9798
"}");
9899
}
99100

101+
@Test
102+
void processWritesContextInitializersMapping() {
103+
ContextBootstrapStructure structure = this.tester.process(
104+
SampleApplicationTests.class, SampleApplicationAnotherTests.class, SimpleSpringTests.class);
105+
assertThat(structure).contextBootstrapInitializer("TestContextBootstrapInitializer")
106+
.removeIndent(1).lines().containsSubsequence(
107+
"public static Map<String, Class<? extends ApplicationContextInitializer<?>>> getContextInitializers(",
108+
" ) {",
109+
" Map<String, Class<? extends ApplicationContextInitializer<?>>> map = new HashMap<>();",
110+
" map.put(\"org.springframework.aot.test.samples.app.SampleApplicationTests\", TestContextBootstrapInitializer0.class);",
111+
" map.put(\"org.springframework.aot.test.samples.app.SampleApplicationAnotherTests\", TestContextBootstrapInitializer0.class);",
112+
" map.put(\"org.springframework.aot.test.samples.simple.SimpleSpringTests\", SimpleSpringTestsContextInitializer.class);",
113+
" return map;",
114+
"}");
115+
}
116+
100117
@Test
101118
void processInvokeTestNativeConfigurationRegistrar() {
102119
ContextBootstrapStructure structure = this.tester.process(SampleApplicationTests.class);
103120
assertThat(structure).hasResourcePattern("org/springframework/aot/test/samples/app/SampleApplication.class");
104121
}
105122

106123
@Test
107-
void processRegisterReflectionForContextLoadersMappingMethod() {
124+
void processReflectionRegistrationForMappingMethods() {
108125
ContextBootstrapStructure structure = this.tester.process(SampleApplicationTests.class);
109126
assertThat(structure).hasClassDescriptor("com.example.TestContextBootstrapInitializer", (descriptor) -> {
110-
assertThat(descriptor.getMethods()).singleElement().satisfies((methodDescriptor) -> {
111-
assertThat(methodDescriptor.getName()).isEqualTo("getContextLoaders");
112-
assertThat(methodDescriptor.getParameterTypes()).isEmpty();
113-
});
114127
assertThat(descriptor.getAccess()).isNull();
115128
assertThat(descriptor.getFields()).isNull();
129+
assertThat(descriptor.getMethods()).hasSize(2);
130+
assertThat(descriptor.getMethods()).allSatisfy((methodDescriptor) -> {
131+
assertThat(methodDescriptor.getName()).satisfies(name -> {
132+
assertThat(name.equals("getContextLoaders") || name.equals("getContextInitializers")).isTrue();
133+
});
134+
assertThat(methodDescriptor.getParameterTypes()).isEmpty();
135+
});
116136
});
117137
}
118138

spring-native/src/main/java/org/springframework/aot/test/AotCacheAwareContextLoaderDelegate.java

+53-4
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,17 @@
2020
import org.apache.commons.logging.LogFactory;
2121

2222
import org.springframework.context.ApplicationContext;
23+
import org.springframework.context.ApplicationContextInitializer;
24+
import org.springframework.test.annotation.DirtiesContext.HierarchyMode;
2325
import org.springframework.test.context.CacheAwareContextLoaderDelegate;
2426
import org.springframework.test.context.MergedContextConfiguration;
2527
import org.springframework.test.context.SmartContextLoader;
28+
import org.springframework.test.context.cache.ContextCache;
2629
import org.springframework.test.context.cache.DefaultCacheAwareContextLoaderDelegate;
2730

2831
/**
29-
* A {@link CacheAwareContextLoaderDelegate} that enables the use of generated application
30-
* contexts for supported test classes.
32+
* A {@link CacheAwareContextLoaderDelegate} that enables the use of AOT-generated
33+
* application contexts for supported test classes.
3134
*
3235
* @author Stephane Nicoll
3336
* @author Sam Brannen
@@ -38,12 +41,17 @@ public class AotCacheAwareContextLoaderDelegate extends DefaultCacheAwareContext
3841

3942
private final AotContextLoader aotContextLoader;
4043

44+
public AotCacheAwareContextLoaderDelegate() {
45+
this(new AotContextLoader());
46+
}
47+
4148
AotCacheAwareContextLoaderDelegate(AotContextLoader aotContextLoader) {
4249
this.aotContextLoader = aotContextLoader;
4350
}
4451

45-
public AotCacheAwareContextLoaderDelegate() {
46-
this(new AotContextLoader());
52+
AotCacheAwareContextLoaderDelegate(AotContextLoader aotContextLoader, ContextCache contextCache) {
53+
super(contextCache);
54+
this.aotContextLoader = aotContextLoader;
4755
}
4856

4957
@Override
@@ -56,4 +64,45 @@ protected ApplicationContext loadContextInternal(MergedContextConfiguration conf
5664
return super.loadContextInternal(config);
5765
}
5866

67+
@Override
68+
public boolean isContextLoaded(MergedContextConfiguration mergedContextConfiguration) {
69+
return super.isContextLoaded(replaceIfNecessary(mergedContextConfiguration));
70+
}
71+
72+
@Override
73+
public ApplicationContext loadContext(MergedContextConfiguration mergedContextConfiguration) {
74+
return super.loadContext(replaceIfNecessary(mergedContextConfiguration));
75+
}
76+
77+
@Override
78+
public void closeContext(MergedContextConfiguration mergedContextConfiguration, HierarchyMode hierarchyMode) {
79+
super.closeContext(replaceIfNecessary(mergedContextConfiguration), hierarchyMode);
80+
}
81+
82+
/**
83+
* If the test class associated with the supplied {@link MergedContextConfiguration}
84+
* has an AOT-generated {@link ApplicationContext}, this method will create an
85+
* {@link AotMergedContextConfiguration} to replace the provided {@code MergedContextConfiguration}.
86+
* <p>This allows for transparent {@link org.springframework.test.context.cache.ContextCache ContextCache}
87+
* support for AOT-generated application contexts, including support for context
88+
* hierarchies.
89+
* <p>Otherwise, this method simply returns the supplied {@code MergedContextConfiguration}
90+
* unmodified.
91+
*/
92+
private MergedContextConfiguration replaceIfNecessary(MergedContextConfiguration mergedContextConfiguration) {
93+
if (mergedContextConfiguration == null) {
94+
return null;
95+
}
96+
97+
Class<?> testClass = mergedContextConfiguration.getTestClass();
98+
Class<? extends ApplicationContextInitializer<?>> contextInitializerClass =
99+
this.aotContextLoader.getContextInitializerClass(testClass);
100+
101+
if (contextInitializerClass != null) {
102+
return new AotMergedContextConfiguration(testClass, contextInitializerClass, this,
103+
replaceIfNecessary(mergedContextConfiguration.getParent()));
104+
}
105+
return mergedContextConfiguration;
106+
}
107+
59108
}

spring-native/src/main/java/org/springframework/aot/test/AotContextLoader.java

+31-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.Map;
2121
import java.util.function.Supplier;
2222

23+
import org.springframework.context.ApplicationContextInitializer;
2324
import org.springframework.test.context.SmartContextLoader;
2425
import org.springframework.util.ClassUtils;
2526
import org.springframework.util.ReflectionUtils;
@@ -40,12 +41,18 @@ class AotContextLoader {
4041

4142
private final Map<String, Supplier<SmartContextLoader>> contextLoaders;
4243

43-
AotContextLoader(Map<String, Supplier<SmartContextLoader>> contextLoaders) {
44+
private final Map<String, Class<? extends ApplicationContextInitializer<?>>> contextInitializers;
45+
46+
47+
AotContextLoader(Map<String, Supplier<SmartContextLoader>> contextLoaders,
48+
Map<String, Class<? extends ApplicationContextInitializer<?>>> contextInitializers) {
49+
4450
this.contextLoaders = contextLoaders;
51+
this.contextInitializers = contextInitializers;
4552
}
4653

4754
AotContextLoader(String initializerClassName) {
48-
this(loadContextLoadersMapping(initializerClassName));
55+
this(loadContextLoadersMapping(initializerClassName), loadContextInitializersMapping(initializerClassName));
4956
}
5057

5158
AotContextLoader() {
@@ -70,6 +77,24 @@ private static Map<String, Supplier<SmartContextLoader>> loadContextLoadersMappi
7077
}
7178
}
7279

80+
@SuppressWarnings("unchecked")
81+
private static Map<String, Class<? extends ApplicationContextInitializer<?>>> loadContextInitializersMapping(String initializerClassName) {
82+
try {
83+
Class<?> type = ClassUtils.forName(initializerClassName, null);
84+
Method method = ReflectionUtils.findMethod(type, "getContextInitializers");
85+
if (method == null) {
86+
throw new IllegalStateException("No getContextInitializers() method found on " + type.getName());
87+
}
88+
return (Map<String, Class<? extends ApplicationContextInitializer<?>>>) ReflectionUtils.invokeMethod(method, null);
89+
}
90+
catch (IllegalStateException ex) {
91+
throw ex;
92+
}
93+
catch (Exception ex) {
94+
throw new IllegalStateException("Failed to load context initializers mapping", ex);
95+
}
96+
}
97+
7398
SmartContextLoader getContextLoader(Class<?> testClass) {
7499
Supplier<SmartContextLoader> supplier = this.contextLoaders.get(testClass.getName());
75100
return (supplier != null) ? supplier.get() : null;
@@ -79,4 +104,8 @@ boolean isSupportedTestClass(Class<?> testClass) {
79104
return this.contextLoaders.containsKey(testClass.getName());
80105
}
81106

107+
Class<? extends ApplicationContextInitializer<?>> getContextInitializerClass(Class<?> testClass) {
108+
return this.contextInitializers.get(testClass.getName());
109+
}
110+
82111
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Copyright 2019-2021 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.aot.test;
18+
19+
import java.util.Collections;
20+
21+
import org.springframework.context.ApplicationContextInitializer;
22+
import org.springframework.core.style.ToStringCreator;
23+
import org.springframework.lang.Nullable;
24+
import org.springframework.test.context.MergedContextConfiguration;
25+
26+
/**
27+
* {@link MergedContextConfiguration} implementation based on an AOT-generated
28+
* {@link ApplicationContextInitializer} that is used to load an AOT-generated
29+
* {@link org.springframework.context.ApplicationContext ApplicationContext}.
30+
*
31+
* <p>The {@link #getParent() parent} may optionally be set as well.
32+
*
33+
* @author Sam Brannen
34+
*/
35+
class AotMergedContextConfiguration extends MergedContextConfiguration {
36+
37+
private static final long serialVersionUID = 1963364911008547843L;
38+
39+
private final Class<? extends ApplicationContextInitializer<?>> contextInitializerClass;
40+
41+
AotMergedContextConfiguration(Class<?> testClass,
42+
Class<? extends ApplicationContextInitializer<?>> contextInitializerClass,
43+
AotCacheAwareContextLoaderDelegate cacheAwareContextLoaderDelegate,
44+
@Nullable MergedContextConfiguration parent) {
45+
46+
super(testClass, null, null, Collections.singleton(contextInitializerClass), null, null,
47+
cacheAwareContextLoaderDelegate, parent);
48+
this.contextInitializerClass = contextInitializerClass;
49+
}
50+
51+
@Override
52+
public boolean equals(@Nullable Object other) {
53+
if (this == other) {
54+
return true;
55+
}
56+
if (other == null || other.getClass() != getClass()) {
57+
return false;
58+
}
59+
60+
AotMergedContextConfiguration that = (AotMergedContextConfiguration) other;
61+
if (!this.contextInitializerClass.equals(that.contextInitializerClass)) {
62+
return false;
63+
}
64+
65+
if (getParent() == null) {
66+
if (that.getParent() != null) {
67+
return false;
68+
}
69+
}
70+
else if (!getParent().equals(that.getParent())) {
71+
return false;
72+
}
73+
74+
return true;
75+
}
76+
77+
@Override
78+
public int hashCode() {
79+
int result = this.contextInitializerClass.hashCode();
80+
result = 31 * result + (getParent() != null ? getParent().hashCode() : 0);
81+
return result;
82+
}
83+
84+
@Override
85+
public String toString() {
86+
return new ToStringCreator(this)
87+
.append("testClass", getTestClass().getName())
88+
.append("contextInitializerClass", this.contextInitializerClass.getName())
89+
.append("parent", getParent())
90+
.toString();
91+
}
92+
93+
}

0 commit comments

Comments
 (0)