Dependency Injection impl automatically handles "circular" constructors/setters.

This commit is contained in:
JesseBrault0709 2024-05-27 07:53:46 +02:00
parent d116b2c555
commit 2bac0f55dc
4 changed files with 272 additions and 41 deletions

View File

@ -5,6 +5,8 @@ import org.jetbrains.annotations.Nullable;
import java.lang.reflect.*; import java.lang.reflect.*;
import java.util.*; import java.util.*;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import static groowt.util.di.ObjectFactoryUtil.toTypes; import static groowt.util.di.ObjectFactoryUtil.toTypes;
@ -19,6 +21,89 @@ public abstract class AbstractInjectingObjectFactory implements ObjectFactory {
Class<?>[] paramTypes Class<?>[] paramTypes
) {} ) {}
protected record Resolved(Class<?> type, Object object) {}
private static final class DeferredSetters {
private final Class<?> forType;
private final List<Consumer<Object>> actions = new ArrayList<>();
public DeferredSetters(Class<?> forType) {
this.forType = forType;
}
public Class<?> getForType() {
return this.forType;
}
public List<Consumer<Object>> getActions() {
return this.actions;
}
}
protected static class CreateContext {
private final Deque<Class<?>> constructionStack = new LinkedList<>();
private final Deque<DeferredSetters> deferredSettersStack = new LinkedList<>();
private final List<Resolved> allResolved = new ArrayList<>();
public CreateContext(Class<?> targetType) {
this.pushConstruction(targetType);
}
public void checkForCircularDependency(Class<?> typeToConstruct) {
if (constructionStack.contains(typeToConstruct)) {
throw new IllegalStateException(
"Detected a circular constructor dependency for " + typeToConstruct.getName()
+ " . Please use setter methods instead. Current construction stack: "
+ constructionStack.stream()
.map(Class::getName)
.collect(Collectors.joining(", "))
);
}
}
public void pushConstruction(Class<?> type) {
this.constructionStack.push(type);
this.deferredSettersStack.push(new DeferredSetters(type));
}
public void popConstruction() {
this.constructionStack.pop();
this.deferredSettersStack.pop();
}
public boolean containsConstruction(Class<?> type) {
return this.constructionStack.contains(type);
}
public void addDeferredSetterAction(Class<?> forType, Consumer<Object> action) {
boolean found = false;
for (final DeferredSetters deferredSetters : this.deferredSettersStack) {
if (deferredSetters.getForType().equals(forType)) {
found = true;
deferredSetters.getActions().add(action);
break;
}
}
if (!found) {
throw new IllegalArgumentException(
"There is no construction for type " + forType.getName() + " which should be deferred."
);
}
}
public List<Consumer<Object>> getDeferredSetterActions() {
return this.deferredSettersStack.getFirst().getActions();
}
public List<Resolved> getAllResolved() {
return this.allResolved;
}
}
private final Map<Class<?>, Constructor<?>[]> cachedAllConstructors = new HashMap<>(); private final Map<Class<?>, Constructor<?>[]> cachedAllConstructors = new HashMap<>();
private final Collection<CachedInjectConstructor<?>> cachedInjectConstructors = new ArrayList<>(); private final Collection<CachedInjectConstructor<?>> cachedInjectConstructors = new ArrayList<>();
private final Collection<CachedNonInjectConstructor<?>> cachedNonInjectConstructors = new ArrayList<>(); private final Collection<CachedNonInjectConstructor<?>> cachedNonInjectConstructors = new ArrayList<>();
@ -82,13 +167,14 @@ public abstract class AbstractInjectingObjectFactory implements ObjectFactory {
} }
/** /**
* @implNote If overridden, please cache any found non-inject constructors using {@link #putCachedNonInjectConstructor}. * @implNote If overridden, please cache any found non-inject constructors using
* {@link #putCachedNonInjectConstructor}.
* *
* @param clazz the {@link Class} in which to search for a constructor which does not have an <code>{@literal @}Inject</code> * @param clazz the {@link Class} in which to search for a constructor which does
* annotation * not have an <code>{@literal @}Inject</code> annotation.
* @param constructorArgs the given constructor args * @param constructorArgs the given constructor args
* @return the found non-inject constructor appropriate for the given constructor args, or {@code null} if no * @return the found non-inject constructor appropriate for the given constructor args,
* such constructor exists * or {@code null} if no such constructor exists
* @param <T> the type * @param <T> the type
*/ */
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@ -156,42 +242,104 @@ public abstract class AbstractInjectingObjectFactory implements ObjectFactory {
protected Parameter getCachedInjectParameter(Method setter) { protected Parameter getCachedInjectParameter(Method setter) {
return this.cachedSetterParameters.computeIfAbsent(setter, s -> { return this.cachedSetterParameters.computeIfAbsent(setter, s -> {
if (s.getParameterCount() != 1) { if (s.getParameterCount() != 1) {
throw new IllegalArgumentException("Setter " + s.getName() + " has a parameter count other than one (1)!"); throw new IllegalArgumentException(
"Setter " + s.getName() + " has a parameter count other than one (1)!"
);
} }
return s.getParameters()[0]; return s.getParameters()[0];
}); });
} }
protected void injectSetter(Object target, Method setter) { private @Nullable Object findInContext(CreateContext context, Class<?> type) {
try { for (final Resolved resolved : context.getAllResolved()) {
setter.invoke(target, this.getSetterInjectArg(target.getClass(), setter, this.getCachedInjectParameter(setter))); if (type.isAssignableFrom(resolved.type)) {
} catch (InvocationTargetException | IllegalAccessException e) { return resolved.object;
throw new RuntimeException(e); }
}
return null;
}
protected void injectSetter(CreateContext context, Object target, Method setter) {
final Parameter injectParam = this.getCachedInjectParameter(setter);
final Class<?> typeToInject = injectParam.getType();
final @Nullable Object fromContext = this.findInContext(context, typeToInject);
final Consumer<Object> setterAction = arg -> {
try {
setter.invoke(target, arg);
} catch (InvocationTargetException | IllegalAccessException e) {
throw new RuntimeException(e);
}
};
if (context.containsConstruction(typeToInject)) {
context.addDeferredSetterAction(typeToInject, setterAction);
} else {
final Object arg;
if (fromContext != null) {
arg = fromContext;
} else {
arg = this.getSetterInjectArg(context, target.getClass(), setter, injectParam);
context.getAllResolved().add(new Resolved(typeToInject, arg));
}
setterAction.accept(arg);
} }
} }
protected void injectSetters(Object target) { protected void injectSetters(CreateContext context, Object target) {
this.getCachedSettersFor(target).forEach(setter -> this.injectSetter(target, setter)); this.getCachedSettersFor(target).forEach(setter -> this.injectSetter(context, target, setter));
} }
/** /**
* {@inheritDoc} * {@inheritDoc}
*/ */
@Override @Override
public <T> T createInstance(Class<T> clazz, Object... constructorArgs) { public <T> T createInstance(Class<T> type, Object... constructorArgs) {
final Constructor<T> constructor = this.findConstructor(clazz, constructorArgs); final Constructor<T> constructor = this.findConstructor(type, constructorArgs);
final Object[] allArgs = this.createArgs(constructor, constructorArgs); final CreateContext context = new CreateContext(type);
final Object[] allArgs = this.createArgs(context, constructor, constructorArgs);
try { try {
final T instance = constructor.newInstance(allArgs); final T instance = constructor.newInstance(allArgs);
this.injectSetters(instance); context.getAllResolved().add(new Resolved(type, instance));
this.injectSetters(context, instance);
final List<Consumer<Object>> deferredSetterActions = context.getDeferredSetterActions();
context.popConstruction();
deferredSetterActions.forEach(setterAction -> setterAction.accept(instance));
return instance; return instance;
} catch (InvocationTargetException | IllegalAccessException | InstantiationException e) { } catch (InvocationTargetException | IllegalAccessException | InstantiationException e) {
throw new RuntimeException(e); // In the future, we might have an option to ignore exceptions throw new RuntimeException(e); // In the future, we might have an option to ignore exceptions
} }
} }
protected abstract Object[] createArgs(Constructor<?> constructor, Object[] constructorArgs); protected <T> T createInstance(CreateContext context, Class<T> type, Object... givenArgs) {
final Constructor<T> constructor = this.findConstructor(type, givenArgs);
context.checkForCircularDependency(type);
context.pushConstruction(type);
final Object[] allArgs = this.createArgs(context, constructor, givenArgs);
try {
final T instance = constructor.newInstance(allArgs);
context.getAllResolved().add(new Resolved(type, instance));
this.injectSetters(context, instance);
final List<Consumer<Object>> deferredSetterActions = context.getDeferredSetterActions();
context.popConstruction();
deferredSetterActions.forEach(setterAction -> setterAction.accept(instance));
return instance;
} catch (InvocationTargetException | IllegalAccessException | InstantiationException e) {
throw new RuntimeException("e");
}
}
protected abstract Object getSetterInjectArg(Class<?> targetType, Method setter, Parameter toInject); protected abstract Object[] createArgs(
CreateContext context,
Constructor<?> constructor,
Object[] constructorArgs
);
protected abstract Object getSetterInjectArg(
CreateContext context,
Class<?> targetType,
Method setter,
Parameter toInject
);
} }

View File

@ -14,7 +14,8 @@ import java.util.function.Supplier;
import static groowt.util.di.RegistryObjectFactoryUtil.orElseSupply; import static groowt.util.di.RegistryObjectFactoryUtil.orElseSupply;
public abstract class AbstractRegistryObjectFactory extends AbstractInjectingObjectFactory implements RegistryObjectFactory { public abstract class AbstractRegistryObjectFactory extends AbstractInjectingObjectFactory
implements RegistryObjectFactory {
public static abstract class AbstractBuilder<T extends DefaultRegistryObjectFactory> implements Builder<T> { public static abstract class AbstractBuilder<T extends DefaultRegistryObjectFactory> implements Builder<T> {

View File

@ -3,8 +3,6 @@ package groowt.util.di;
import groowt.util.di.filters.FilterHandler; import groowt.util.di.filters.FilterHandler;
import groowt.util.di.filters.IterableFilterHandler; import groowt.util.di.filters.IterableFilterHandler;
import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.lang.annotation.Annotation; import java.lang.annotation.Annotation;
import java.lang.reflect.Constructor; import java.lang.reflect.Constructor;
@ -20,7 +18,8 @@ import static groowt.util.di.RegistryObjectFactoryUtil.*;
public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory { public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory {
public static final class Builder extends AbstractRegistryObjectFactory.AbstractBuilder<DefaultRegistryObjectFactory> { public static final class Builder
extends AbstractRegistryObjectFactory.AbstractBuilder<DefaultRegistryObjectFactory> {
/** /**
* Creates a {@code Builder} initialized with a {@link DefaultRegistry}, which is in-turn configured with a * Creates a {@code Builder} initialized with a {@link DefaultRegistry}, which is in-turn configured with a
@ -74,7 +73,6 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory
} }
private static final Object[] EMPTY_OBJECT_ARRAY = new Object[0]; private static final Object[] EMPTY_OBJECT_ARRAY = new Object[0];
private static final Logger logger = LoggerFactory.getLogger(DefaultRegistryObjectFactory.class); // leave it for the future!
private final Collection<FilterHandler<?, ?>> filterHandlers; private final Collection<FilterHandler<?, ?>> filterHandlers;
private final Collection<IterableFilterHandler<?, ?>> iterableFilterHandlers; private final Collection<IterableFilterHandler<?, ?>> iterableFilterHandlers;
@ -107,7 +105,9 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
protected final @Nullable Object tryQualifiers(Parameter parameter) { protected final @Nullable Object tryQualifiers(Parameter parameter) {
final Class<?> paramType = parameter.getType(); final Class<?> paramType = parameter.getType();
final List<Annotation> qualifiers = RegistryObjectFactoryUtil.getQualifierAnnotations(parameter.getAnnotations()); final List<Annotation> qualifiers = RegistryObjectFactoryUtil.getQualifierAnnotations(
parameter.getAnnotations()
);
if (qualifiers.size() > 1) { if (qualifiers.size() > 1) {
throw new RuntimeException("Parameter " + parameter + " cannot have more than one Qualifier annotation."); throw new RuntimeException("Parameter " + parameter + " cannot have more than one Qualifier annotation.");
} else if (qualifiers.size() == 1) { } else if (qualifiers.size() == 1) {
@ -115,7 +115,9 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory
@SuppressWarnings("rawtypes") @SuppressWarnings("rawtypes")
final QualifierHandler handler = this.getInSelfOrParent( final QualifierHandler handler = this.getInSelfOrParent(
f -> f.findQualifierHandler(qualifier.annotationType()), f -> f.findQualifierHandler(qualifier.annotationType()),
() -> new RuntimeException("There is no configured QualifierHandler for " + qualifier.annotationType().getName()) () -> new RuntimeException("There is no configured QualifierHandler for "
+ qualifier.annotationType().getName()
)
); );
final Binding<?> binding = handler.handle(qualifier, paramType); final Binding<?> binding = handler.handle(qualifier, paramType);
if (binding != null) { if (binding != null) {
@ -165,21 +167,23 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory
} }
} }
protected final Object resolveInjectedArg(Parameter parameter) { protected final Object resolveInjectedArg(CreateContext context, Parameter parameter) {
final Object qualifierProvidedArg = this.tryQualifiers(parameter); final Object qualifierProvidedArg = this.tryQualifiers(parameter);
if (qualifierProvidedArg != null) { if (qualifierProvidedArg != null) {
this.checkFilters(parameter, qualifierProvidedArg); this.checkFilters(parameter, qualifierProvidedArg);
context.getAllResolved().add(new Resolved(parameter.getType(), qualifierProvidedArg));
return qualifierProvidedArg; return qualifierProvidedArg;
} else { } else {
final Object created = this.get(parameter.getType()); final Object created = this.get(context, parameter.getType());
this.checkFilters(parameter, created); this.checkFilters(parameter, created);
context.getAllResolved().add(new Resolved(parameter.getType(), created));
return created; return created;
} }
} }
protected final void resolveInjectedArgs(Object[] dest, Parameter[] params) { protected final void resolveInjectedArgs(CreateContext context, Object[] dest, Parameter[] params) {
for (int i = 0; i < params.length; i++) { for (int i = 0; i < params.length; i++) {
dest[i] = this.resolveInjectedArg(params[i]); dest[i] = this.resolveInjectedArg(context, params[i]);
} }
} }
@ -194,7 +198,7 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory
// TODO: when there is a null arg, we lose the type. Therefore this algorithm breaks. Fix this. // TODO: when there is a null arg, we lose the type. Therefore this algorithm breaks. Fix this.
@Override @Override
protected Object[] createArgs(Constructor<?> constructor, Object[] givenArgs) { protected Object[] createArgs(CreateContext context, Constructor<?> constructor, Object[] givenArgs) {
final Class<?>[] paramTypes = constructor.getParameterTypes(); final Class<?>[] paramTypes = constructor.getParameterTypes();
// check no arg // check no arg
@ -218,7 +222,7 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory
if (givenArgs.length == 0) { if (givenArgs.length == 0) {
// if no given args, then they are all injected // if no given args, then they are all injected
this.resolveInjectedArgs(resolvedArgs, allParams); this.resolveInjectedArgs(context, resolvedArgs, allParams);
} else if (givenArgs.length == paramTypes.length) { } else if (givenArgs.length == paramTypes.length) {
// all are given // all are given
this.resolveGivenArgs(resolvedArgs, allParams, givenArgs, 0); this.resolveGivenArgs(resolvedArgs, allParams, givenArgs, 0);
@ -234,17 +238,23 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory
final Parameter[] givenParams = new Parameter[allParams.length - firstGivenIndex]; final Parameter[] givenParams = new Parameter[allParams.length - firstGivenIndex];
System.arraycopy(allParams, 0, injectedParams, 0, injectedParams.length); System.arraycopy(allParams, 0, injectedParams, 0, injectedParams.length);
System.arraycopy(allParams, firstGivenIndex, givenParams, 0, allParams.length - firstGivenIndex); System.arraycopy(
allParams, firstGivenIndex, givenParams, 0, allParams.length - firstGivenIndex
);
this.resolveInjectedArgs(resolvedArgs, injectedParams); this.resolveInjectedArgs(context, resolvedArgs, injectedParams);
this.resolveGivenArgs(resolvedArgs, givenParams, givenArgs, firstGivenIndex); this.resolveGivenArgs(resolvedArgs, givenParams, givenArgs, firstGivenIndex);
} }
return resolvedArgs; return resolvedArgs;
} }
@SuppressWarnings("unchecked")
private <T> T handleBinding(Binding<T> binding, Object[] constructorArgs) { private <T> T handleBinding(Binding<T> binding, Object[] constructorArgs) {
return this.handleBinding(binding, null, constructorArgs);
}
@SuppressWarnings("unchecked")
private <T> T handleBinding(Binding<T> binding, @Nullable CreateContext context, Object[] constructorArgs) {
return switch (binding) { return switch (binding) {
case ClassBinding<T>(Class<T> ignored, Class<? extends T> to) -> { case ClassBinding<T>(Class<T> ignored, Class<? extends T> to) -> {
final Annotation scopeAnnotation = getScopeAnnotation(to); final Annotation scopeAnnotation = getScopeAnnotation(to);
@ -253,12 +263,20 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory
@SuppressWarnings("rawtypes") @SuppressWarnings("rawtypes")
final ScopeHandler scopeHandler = this.getInSelfOrParent( final ScopeHandler scopeHandler = this.getInSelfOrParent(
f -> f.findScopeHandler(scopeClass), f -> f.findScopeHandler(scopeClass),
() -> new RuntimeException("There is no configured ScopeHandler for " + scopeClass.getName()) () -> new RuntimeException(
"There is no configured ScopeHandler for " + scopeClass.getName()
)
);
final Binding<T> scopedBinding = scopeHandler.onScopedDependencyRequest(
scopeAnnotation, to, this
); );
final Binding<T> scopedBinding = scopeHandler.onScopedDependencyRequest(scopeAnnotation, to, this);
yield this.handleBinding(scopedBinding, constructorArgs); yield this.handleBinding(scopedBinding, constructorArgs);
} else { } else {
yield this.createInstance(to, constructorArgs); if (context != null) {
yield this.createInstance(context, to, constructorArgs);
} else {
yield this.createInstance(to, constructorArgs);
}
} }
} }
case ProviderBinding<T> providerBinding -> providerBinding.provider().get(); case ProviderBinding<T> providerBinding -> providerBinding.provider().get();
@ -276,8 +294,8 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory
} }
@Override @Override
protected Object getSetterInjectArg(Class<?> targetType, Method setter, Parameter toInject) { protected Object getSetterInjectArg(CreateContext context, Class<?> targetType, Method setter, Parameter toInject) {
return this.resolveInjectedArg(toInject); return this.resolveInjectedArg(context, toInject);
} }
/** /**
@ -293,7 +311,24 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory
if (parentResult != null) { if (parentResult != null) {
return parentResult; return parentResult;
} else { } else {
throw new RuntimeException("No bindings for " + clazz + " with args " + Arrays.toString(constructorArgs) + "."); throw new RuntimeException(
"No bindings for " + clazz + " with args " + Arrays.toString(constructorArgs) + "."
);
}
}
protected <T> T get(CreateContext context, Class<T> type) {
final Binding<T> binding = this.searchRegistry(type);
if (binding != null) {
return this.handleBinding(binding, context, EMPTY_OBJECT_ARRAY);
}
final T parentResult = this.tryParent(type, EMPTY_OBJECT_ARRAY);
if (parentResult != null) {
return parentResult;
} else {
throw new RuntimeException(
"No bindings for " + type + " with args " + Arrays.toString(EMPTY_OBJECT_ARRAY) + "."
);
} }
} }

View File

@ -166,4 +166,51 @@ public class DefaultRegistryObjectFactoryTests {
assertEquals("Hello, World!", greeter.greet()); assertEquals("Hello, World!", greeter.greet());
} }
public static final class GreeterDependency {
private GreeterDependencyUser greeter;
@Inject
public void setGreeter(GreeterDependencyUser greeter) {
this.greeter = greeter;
}
public String filterGreeting() {
return this.greeter.getGreeting().toUpperCase();
}
}
public static final class GreeterDependencyUser implements Greeter {
private final GreeterDependency greeterDependency;
@Inject
public GreeterDependencyUser(GreeterDependency greeterDependency) {
this.greeterDependency = greeterDependency;
}
@Override
public String greet() {
return this.greeterDependency.filterGreeting();
}
public String getGreeting() {
return "hello, world!";
}
}
@Test
public void injectedDeferred() {
final var b = DefaultRegistryObjectFactory.Builder.withDefaults();
b.configureRegistry(r -> {
r.bind(GreeterDependencyUser.class, toSelf());
r.bind(GreeterDependency.class, toSelf());
});
final var f = b.build();
final var g = f.get(GreeterDependencyUser.class);
assertEquals("HELLO, WORLD!", g.greet());
}
} }