diff --git a/spring-aot/src/main/java/org/springframework/data/RepositoryFactoryBeanPostProcessor.java b/spring-aot/src/main/java/org/springframework/data/RepositoryFactoryBeanPostProcessor.java index 011de903e..ed8c12430 100644 --- a/spring-aot/src/main/java/org/springframework/data/RepositoryFactoryBeanPostProcessor.java +++ b/spring-aot/src/main/java/org/springframework/data/RepositoryFactoryBeanPostProcessor.java @@ -16,6 +16,10 @@ package org.springframework.data; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + import org.springframework.aot.context.bootstrap.generator.bean.BeanRegistrationWriter; import org.springframework.beans.BeansException; import org.springframework.beans.factory.BeanFactory; @@ -34,6 +38,7 @@ * repository with the resolved generics of the repository. * * @author Stephane Nicoll + * @author Christoph Strobl */ class RepositoryFactoryBeanPostProcessor implements BeanDefinitionPostProcessor, BeanFactoryAware { @@ -61,10 +66,10 @@ private void resolveRepositoryFactoryBeanTypeIfNecessary(RootBeanDefinition bean Class repositoryType = loadRepositoryType(valueHolder); if (repositoryType != null) { ResolvableType resolvableType = ResolvableType.forClass(repositoryType).as(Repository.class); - ResolvableType entityType = resolvableType.getGenerics()[0]; - ResolvableType idType = resolvableType.getGenerics()[1]; + List typeArgs = new ArrayList<>(Arrays.asList(resolvableType.getGenerics())); + typeArgs.add(0, ResolvableType.forClass(repositoryType)); ResolvableType resolvedRepositoryType = ResolvableType.forClassWithGenerics( - beanDefinition.getBeanClass(), ResolvableType.forClass(repositoryType), entityType, idType); + beanDefinition.getBeanClass(), typeArgs.toArray(new ResolvableType[0])); beanDefinition.setTargetType(resolvedRepositoryType); beanDefinition.setAttribute(BeanRegistrationWriter.PRESERVE_TARGET_TYPE, true); } diff --git a/spring-aot/src/test/java/org/springframework/data/RepositoryFactoryBeanPostProcessorTests.java b/spring-aot/src/test/java/org/springframework/data/RepositoryFactoryBeanPostProcessorTests.java index dc4fcdb9e..a2aef4071 100644 --- a/spring-aot/src/test/java/org/springframework/data/RepositoryFactoryBeanPostProcessorTests.java +++ b/spring-aot/src/test/java/org/springframework/data/RepositoryFactoryBeanPostProcessorTests.java @@ -46,6 +46,16 @@ void resolveRepositoryTypeWithTypeAsString() { assertThat(beanDefinition.getAttribute(BeanRegistrationWriter.PRESERVE_TARGET_TYPE)).isEqualTo(true); } + @Test + void resolveRepositoryTypeIfNotDirectSubOfRepository() { + RootBeanDefinition beanDefinition = (RootBeanDefinition) BeanDefinitionBuilder.rootBeanDefinition(JpaRepositoryFactoryBean.class) + .addConstructorArgValue("org.springframework.data.RepositoryFactoryBeanPostProcessorTests.RockstarRepository").getBeanDefinition(); + assertThat(beanDefinition.getResolvableType().hasUnresolvableGenerics()).isTrue(); + postProcess(beanDefinition); + assertFactoryBeanForRockstarRepository(beanDefinition.getResolvableType()); + assertThat(beanDefinition.getAttribute(BeanRegistrationWriter.PRESERVE_TARGET_TYPE)).isEqualTo(true); + } + @Test void resolveRepositoryTypeWithTypeAsStringThatDoesNotExist() { RootBeanDefinition beanDefinition = (RootBeanDefinition) BeanDefinitionBuilder.rootBeanDefinition(JpaRepositoryFactoryBean.class) @@ -109,6 +119,13 @@ private void assertFactoryBeanForSpeakerRepository(ResolvableType resolvedType) assertThat(resolvedType.getGenerics()[2].resolve()).isEqualTo(Integer.class); } + private void assertFactoryBeanForRockstarRepository(ResolvableType resolvedType) { + assertThat(resolvedType.hasUnresolvableGenerics()).isFalse(); + assertThat(resolvedType.getGenerics()[0].resolve()).isEqualTo(RockstarRepository.class); + assertThat(resolvedType.getGenerics()[1].resolve()).isEqualTo(Speaker.class); + assertThat(resolvedType.getGenerics()[2].resolve()).isEqualTo(Integer.class); + } + private void postProcess(RootBeanDefinition beanDefinition) { DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); RepositoryFactoryBeanPostProcessor processor = new RepositoryFactoryBeanPostProcessor(); @@ -120,6 +137,10 @@ interface SpeakerRepository extends CrudRepository { } + interface RockstarRepository extends SpeakerRepository { + + } + static class Speaker { }