Dependency Injection impl automatically handles "circular" constructors/setters.

This commit is contained in:
JesseBrault0709 2024-05-27 07:53:46 +02:00
parent d116b2c555
commit 2bac0f55dc
4 changed files with 272 additions and 41 deletions

View File

@ -5,6 +5,8 @@ import org.jetbrains.annotations.Nullable;
import java.lang.reflect.*;
import java.util.*;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import static groowt.util.di.ObjectFactoryUtil.toTypes;
@ -19,6 +21,89 @@ public abstract class AbstractInjectingObjectFactory implements ObjectFactory {
Class<?>[] paramTypes
) {}
protected record Resolved(Class<?> type, Object object) {}
private static final class DeferredSetters {
private final Class<?> forType;
private final List<Consumer<Object>> actions = new ArrayList<>();
public DeferredSetters(Class<?> forType) {
this.forType = forType;
}
public Class<?> getForType() {
return this.forType;
}
public List<Consumer<Object>> getActions() {
return this.actions;
}
}
protected static class CreateContext {
private final Deque<Class<?>> constructionStack = new LinkedList<>();
private final Deque<DeferredSetters> deferredSettersStack = new LinkedList<>();
private final List<Resolved> allResolved = new ArrayList<>();
public CreateContext(Class<?> targetType) {
this.pushConstruction(targetType);
}
public void checkForCircularDependency(Class<?> typeToConstruct) {
if (constructionStack.contains(typeToConstruct)) {
throw new IllegalStateException(
"Detected a circular constructor dependency for " + typeToConstruct.getName()
+ " . Please use setter methods instead. Current construction stack: "
+ constructionStack.stream()
.map(Class::getName)
.collect(Collectors.joining(", "))
);
}
}
public void pushConstruction(Class<?> type) {
this.constructionStack.push(type);
this.deferredSettersStack.push(new DeferredSetters(type));
}
public void popConstruction() {
this.constructionStack.pop();
this.deferredSettersStack.pop();
}
public boolean containsConstruction(Class<?> type) {
return this.constructionStack.contains(type);
}
public void addDeferredSetterAction(Class<?> forType, Consumer<Object> action) {
boolean found = false;
for (final DeferredSetters deferredSetters : this.deferredSettersStack) {
if (deferredSetters.getForType().equals(forType)) {
found = true;
deferredSetters.getActions().add(action);
break;
}
}
if (!found) {
throw new IllegalArgumentException(
"There is no construction for type " + forType.getName() + " which should be deferred."
);
}
}
public List<Consumer<Object>> getDeferredSetterActions() {
return this.deferredSettersStack.getFirst().getActions();
}
public List<Resolved> getAllResolved() {
return this.allResolved;
}
}
private final Map<Class<?>, Constructor<?>[]> cachedAllConstructors = new HashMap<>();
private final Collection<CachedInjectConstructor<?>> cachedInjectConstructors = new ArrayList<>();
private final Collection<CachedNonInjectConstructor<?>> cachedNonInjectConstructors = new ArrayList<>();
@ -82,13 +167,14 @@ public abstract class AbstractInjectingObjectFactory implements ObjectFactory {
}
/**
* @implNote If overridden, please cache any found non-inject constructors using {@link #putCachedNonInjectConstructor}.
* @implNote If overridden, please cache any found non-inject constructors using
* {@link #putCachedNonInjectConstructor}.
*
* @param clazz the {@link Class} in which to search for a constructor which does not have an <code>{@literal @}Inject</code>
* annotation
* @param clazz the {@link Class} in which to search for a constructor which does
* not have an <code>{@literal @}Inject</code> annotation.
* @param constructorArgs the given constructor args
* @return the found non-inject constructor appropriate for the given constructor args, or {@code null} if no
* such constructor exists
* @return the found non-inject constructor appropriate for the given constructor args,
* or {@code null} if no such constructor exists
* @param <T> the type
*/
@SuppressWarnings("unchecked")
@ -156,42 +242,104 @@ public abstract class AbstractInjectingObjectFactory implements ObjectFactory {
protected Parameter getCachedInjectParameter(Method setter) {
return this.cachedSetterParameters.computeIfAbsent(setter, s -> {
if (s.getParameterCount() != 1) {
throw new IllegalArgumentException("Setter " + s.getName() + " has a parameter count other than one (1)!");
throw new IllegalArgumentException(
"Setter " + s.getName() + " has a parameter count other than one (1)!"
);
}
return s.getParameters()[0];
});
}
protected void injectSetter(Object target, Method setter) {
try {
setter.invoke(target, this.getSetterInjectArg(target.getClass(), setter, this.getCachedInjectParameter(setter)));
} catch (InvocationTargetException | IllegalAccessException e) {
throw new RuntimeException(e);
private @Nullable Object findInContext(CreateContext context, Class<?> type) {
for (final Resolved resolved : context.getAllResolved()) {
if (type.isAssignableFrom(resolved.type)) {
return resolved.object;
}
}
return null;
}
protected void injectSetter(CreateContext context, Object target, Method setter) {
final Parameter injectParam = this.getCachedInjectParameter(setter);
final Class<?> typeToInject = injectParam.getType();
final @Nullable Object fromContext = this.findInContext(context, typeToInject);
final Consumer<Object> setterAction = arg -> {
try {
setter.invoke(target, arg);
} catch (InvocationTargetException | IllegalAccessException e) {
throw new RuntimeException(e);
}
};
if (context.containsConstruction(typeToInject)) {
context.addDeferredSetterAction(typeToInject, setterAction);
} else {
final Object arg;
if (fromContext != null) {
arg = fromContext;
} else {
arg = this.getSetterInjectArg(context, target.getClass(), setter, injectParam);
context.getAllResolved().add(new Resolved(typeToInject, arg));
}
setterAction.accept(arg);
}
}
protected void injectSetters(Object target) {
this.getCachedSettersFor(target).forEach(setter -> this.injectSetter(target, setter));
protected void injectSetters(CreateContext context, Object target) {
this.getCachedSettersFor(target).forEach(setter -> this.injectSetter(context, target, setter));
}
/**
* {@inheritDoc}
*/
@Override
public <T> T createInstance(Class<T> clazz, Object... constructorArgs) {
final Constructor<T> constructor = this.findConstructor(clazz, constructorArgs);
final Object[] allArgs = this.createArgs(constructor, constructorArgs);
public <T> T createInstance(Class<T> type, Object... constructorArgs) {
final Constructor<T> constructor = this.findConstructor(type, constructorArgs);
final CreateContext context = new CreateContext(type);
final Object[] allArgs = this.createArgs(context, constructor, constructorArgs);
try {
final T instance = constructor.newInstance(allArgs);
this.injectSetters(instance);
context.getAllResolved().add(new Resolved(type, instance));
this.injectSetters(context, instance);
final List<Consumer<Object>> deferredSetterActions = context.getDeferredSetterActions();
context.popConstruction();
deferredSetterActions.forEach(setterAction -> setterAction.accept(instance));
return instance;
} catch (InvocationTargetException | IllegalAccessException | InstantiationException e) {
throw new RuntimeException(e); // In the future, we might have an option to ignore exceptions
}
}
protected abstract Object[] createArgs(Constructor<?> constructor, Object[] constructorArgs);
protected <T> T createInstance(CreateContext context, Class<T> type, Object... givenArgs) {
final Constructor<T> constructor = this.findConstructor(type, givenArgs);
context.checkForCircularDependency(type);
context.pushConstruction(type);
final Object[] allArgs = this.createArgs(context, constructor, givenArgs);
try {
final T instance = constructor.newInstance(allArgs);
context.getAllResolved().add(new Resolved(type, instance));
this.injectSetters(context, instance);
final List<Consumer<Object>> deferredSetterActions = context.getDeferredSetterActions();
context.popConstruction();
deferredSetterActions.forEach(setterAction -> setterAction.accept(instance));
return instance;
} catch (InvocationTargetException | IllegalAccessException | InstantiationException e) {
throw new RuntimeException("e");
}
}
protected abstract Object getSetterInjectArg(Class<?> targetType, Method setter, Parameter toInject);
protected abstract Object[] createArgs(
CreateContext context,
Constructor<?> constructor,
Object[] constructorArgs
);
protected abstract Object getSetterInjectArg(
CreateContext context,
Class<?> targetType,
Method setter,
Parameter toInject
);
}

View File

@ -14,7 +14,8 @@ import java.util.function.Supplier;
import static groowt.util.di.RegistryObjectFactoryUtil.orElseSupply;
public abstract class AbstractRegistryObjectFactory extends AbstractInjectingObjectFactory implements RegistryObjectFactory {
public abstract class AbstractRegistryObjectFactory extends AbstractInjectingObjectFactory
implements RegistryObjectFactory {
public static abstract class AbstractBuilder<T extends DefaultRegistryObjectFactory> implements Builder<T> {

View File

@ -3,8 +3,6 @@ package groowt.util.di;
import groowt.util.di.filters.FilterHandler;
import groowt.util.di.filters.IterableFilterHandler;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.lang.annotation.Annotation;
import java.lang.reflect.Constructor;
@ -20,7 +18,8 @@ import static groowt.util.di.RegistryObjectFactoryUtil.*;
public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory {
public static final class Builder extends AbstractRegistryObjectFactory.AbstractBuilder<DefaultRegistryObjectFactory> {
public static final class Builder
extends AbstractRegistryObjectFactory.AbstractBuilder<DefaultRegistryObjectFactory> {
/**
* Creates a {@code Builder} initialized with a {@link DefaultRegistry}, which is in-turn configured with a
@ -74,7 +73,6 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory
}
private static final Object[] EMPTY_OBJECT_ARRAY = new Object[0];
private static final Logger logger = LoggerFactory.getLogger(DefaultRegistryObjectFactory.class); // leave it for the future!
private final Collection<FilterHandler<?, ?>> filterHandlers;
private final Collection<IterableFilterHandler<?, ?>> iterableFilterHandlers;
@ -107,7 +105,9 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory
@SuppressWarnings("unchecked")
protected final @Nullable Object tryQualifiers(Parameter parameter) {
final Class<?> paramType = parameter.getType();
final List<Annotation> qualifiers = RegistryObjectFactoryUtil.getQualifierAnnotations(parameter.getAnnotations());
final List<Annotation> qualifiers = RegistryObjectFactoryUtil.getQualifierAnnotations(
parameter.getAnnotations()
);
if (qualifiers.size() > 1) {
throw new RuntimeException("Parameter " + parameter + " cannot have more than one Qualifier annotation.");
} else if (qualifiers.size() == 1) {
@ -115,7 +115,9 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory
@SuppressWarnings("rawtypes")
final QualifierHandler handler = this.getInSelfOrParent(
f -> f.findQualifierHandler(qualifier.annotationType()),
() -> new RuntimeException("There is no configured QualifierHandler for " + qualifier.annotationType().getName())
() -> new RuntimeException("There is no configured QualifierHandler for "
+ qualifier.annotationType().getName()
)
);
final Binding<?> binding = handler.handle(qualifier, paramType);
if (binding != null) {
@ -165,21 +167,23 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory
}
}
protected final Object resolveInjectedArg(Parameter parameter) {
protected final Object resolveInjectedArg(CreateContext context, Parameter parameter) {
final Object qualifierProvidedArg = this.tryQualifiers(parameter);
if (qualifierProvidedArg != null) {
this.checkFilters(parameter, qualifierProvidedArg);
context.getAllResolved().add(new Resolved(parameter.getType(), qualifierProvidedArg));
return qualifierProvidedArg;
} else {
final Object created = this.get(parameter.getType());
final Object created = this.get(context, parameter.getType());
this.checkFilters(parameter, created);
context.getAllResolved().add(new Resolved(parameter.getType(), created));
return created;
}
}
protected final void resolveInjectedArgs(Object[] dest, Parameter[] params) {
protected final void resolveInjectedArgs(CreateContext context, Object[] dest, Parameter[] params) {
for (int i = 0; i < params.length; i++) {
dest[i] = this.resolveInjectedArg(params[i]);
dest[i] = this.resolveInjectedArg(context, params[i]);
}
}
@ -194,7 +198,7 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory
// TODO: when there is a null arg, we lose the type. Therefore this algorithm breaks. Fix this.
@Override
protected Object[] createArgs(Constructor<?> constructor, Object[] givenArgs) {
protected Object[] createArgs(CreateContext context, Constructor<?> constructor, Object[] givenArgs) {
final Class<?>[] paramTypes = constructor.getParameterTypes();
// check no arg
@ -218,7 +222,7 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory
if (givenArgs.length == 0) {
// if no given args, then they are all injected
this.resolveInjectedArgs(resolvedArgs, allParams);
this.resolveInjectedArgs(context, resolvedArgs, allParams);
} else if (givenArgs.length == paramTypes.length) {
// all are given
this.resolveGivenArgs(resolvedArgs, allParams, givenArgs, 0);
@ -234,17 +238,23 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory
final Parameter[] givenParams = new Parameter[allParams.length - firstGivenIndex];
System.arraycopy(allParams, 0, injectedParams, 0, injectedParams.length);
System.arraycopy(allParams, firstGivenIndex, givenParams, 0, allParams.length - firstGivenIndex);
System.arraycopy(
allParams, firstGivenIndex, givenParams, 0, allParams.length - firstGivenIndex
);
this.resolveInjectedArgs(resolvedArgs, injectedParams);
this.resolveInjectedArgs(context, resolvedArgs, injectedParams);
this.resolveGivenArgs(resolvedArgs, givenParams, givenArgs, firstGivenIndex);
}
return resolvedArgs;
}
@SuppressWarnings("unchecked")
private <T> T handleBinding(Binding<T> binding, Object[] constructorArgs) {
return this.handleBinding(binding, null, constructorArgs);
}
@SuppressWarnings("unchecked")
private <T> T handleBinding(Binding<T> binding, @Nullable CreateContext context, Object[] constructorArgs) {
return switch (binding) {
case ClassBinding<T>(Class<T> ignored, Class<? extends T> to) -> {
final Annotation scopeAnnotation = getScopeAnnotation(to);
@ -253,12 +263,20 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory
@SuppressWarnings("rawtypes")
final ScopeHandler scopeHandler = this.getInSelfOrParent(
f -> f.findScopeHandler(scopeClass),
() -> new RuntimeException("There is no configured ScopeHandler for " + scopeClass.getName())
() -> new RuntimeException(
"There is no configured ScopeHandler for " + scopeClass.getName()
)
);
final Binding<T> scopedBinding = scopeHandler.onScopedDependencyRequest(
scopeAnnotation, to, this
);
final Binding<T> scopedBinding = scopeHandler.onScopedDependencyRequest(scopeAnnotation, to, this);
yield this.handleBinding(scopedBinding, constructorArgs);
} else {
yield this.createInstance(to, constructorArgs);
if (context != null) {
yield this.createInstance(context, to, constructorArgs);
} else {
yield this.createInstance(to, constructorArgs);
}
}
}
case ProviderBinding<T> providerBinding -> providerBinding.provider().get();
@ -276,8 +294,8 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory
}
@Override
protected Object getSetterInjectArg(Class<?> targetType, Method setter, Parameter toInject) {
return this.resolveInjectedArg(toInject);
protected Object getSetterInjectArg(CreateContext context, Class<?> targetType, Method setter, Parameter toInject) {
return this.resolveInjectedArg(context, toInject);
}
/**
@ -293,7 +311,24 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory
if (parentResult != null) {
return parentResult;
} else {
throw new RuntimeException("No bindings for " + clazz + " with args " + Arrays.toString(constructorArgs) + ".");
throw new RuntimeException(
"No bindings for " + clazz + " with args " + Arrays.toString(constructorArgs) + "."
);
}
}
protected <T> T get(CreateContext context, Class<T> type) {
final Binding<T> binding = this.searchRegistry(type);
if (binding != null) {
return this.handleBinding(binding, context, EMPTY_OBJECT_ARRAY);
}
final T parentResult = this.tryParent(type, EMPTY_OBJECT_ARRAY);
if (parentResult != null) {
return parentResult;
} else {
throw new RuntimeException(
"No bindings for " + type + " with args " + Arrays.toString(EMPTY_OBJECT_ARRAY) + "."
);
}
}

View File

@ -166,4 +166,51 @@ public class DefaultRegistryObjectFactoryTests {
assertEquals("Hello, World!", greeter.greet());
}
public static final class GreeterDependency {
private GreeterDependencyUser greeter;
@Inject
public void setGreeter(GreeterDependencyUser greeter) {
this.greeter = greeter;
}
public String filterGreeting() {
return this.greeter.getGreeting().toUpperCase();
}
}
public static final class GreeterDependencyUser implements Greeter {
private final GreeterDependency greeterDependency;
@Inject
public GreeterDependencyUser(GreeterDependency greeterDependency) {
this.greeterDependency = greeterDependency;
}
@Override
public String greet() {
return this.greeterDependency.filterGreeting();
}
public String getGreeting() {
return "hello, world!";
}
}
@Test
public void injectedDeferred() {
final var b = DefaultRegistryObjectFactory.Builder.withDefaults();
b.configureRegistry(r -> {
r.bind(GreeterDependencyUser.class, toSelf());
r.bind(GreeterDependency.class, toSelf());
});
final var f = b.build();
final var g = f.get(GreeterDependencyUser.class);
assertEquals("HELLO, WORLD!", g.greet());
}
}