diff --git a/util/di/src/main/java/groowt/util/di/AbstractInjectingObjectFactory.java b/util/di/src/main/java/groowt/util/di/AbstractInjectingObjectFactory.java index 679c53f..afae93a 100644 --- a/util/di/src/main/java/groowt/util/di/AbstractInjectingObjectFactory.java +++ b/util/di/src/main/java/groowt/util/di/AbstractInjectingObjectFactory.java @@ -5,6 +5,8 @@ import org.jetbrains.annotations.Nullable; import java.lang.reflect.*; import java.util.*; +import java.util.function.Consumer; +import java.util.stream.Collectors; import static groowt.util.di.ObjectFactoryUtil.toTypes; @@ -19,6 +21,89 @@ public abstract class AbstractInjectingObjectFactory implements ObjectFactory { Class[] paramTypes ) {} + protected record Resolved(Class type, Object object) {} + + private static final class DeferredSetters { + + private final Class forType; + private final List> actions = new ArrayList<>(); + + public DeferredSetters(Class forType) { + this.forType = forType; + } + + public Class getForType() { + return this.forType; + } + + public List> getActions() { + return this.actions; + } + + } + + protected static class CreateContext { + + private final Deque> constructionStack = new LinkedList<>(); + private final Deque deferredSettersStack = new LinkedList<>(); + private final List allResolved = new ArrayList<>(); + + public CreateContext(Class targetType) { + this.pushConstruction(targetType); + } + + public void checkForCircularDependency(Class typeToConstruct) { + if (constructionStack.contains(typeToConstruct)) { + throw new IllegalStateException( + "Detected a circular constructor dependency for " + typeToConstruct.getName() + + " . Please use setter methods instead. Current construction stack: " + + constructionStack.stream() + .map(Class::getName) + .collect(Collectors.joining(", ")) + ); + } + } + + public void pushConstruction(Class type) { + this.constructionStack.push(type); + this.deferredSettersStack.push(new DeferredSetters(type)); + } + + public void popConstruction() { + this.constructionStack.pop(); + this.deferredSettersStack.pop(); + } + + public boolean containsConstruction(Class type) { + return this.constructionStack.contains(type); + } + + public void addDeferredSetterAction(Class forType, Consumer action) { + boolean found = false; + for (final DeferredSetters deferredSetters : this.deferredSettersStack) { + if (deferredSetters.getForType().equals(forType)) { + found = true; + deferredSetters.getActions().add(action); + break; + } + } + if (!found) { + throw new IllegalArgumentException( + "There is no construction for type " + forType.getName() + " which should be deferred." + ); + } + } + + public List> getDeferredSetterActions() { + return this.deferredSettersStack.getFirst().getActions(); + } + + public List getAllResolved() { + return this.allResolved; + } + + } + private final Map, Constructor[]> cachedAllConstructors = new HashMap<>(); private final Collection> cachedInjectConstructors = new ArrayList<>(); private final Collection> cachedNonInjectConstructors = new ArrayList<>(); @@ -82,13 +167,14 @@ public abstract class AbstractInjectingObjectFactory implements ObjectFactory { } /** - * @implNote If overridden, please cache any found non-inject constructors using {@link #putCachedNonInjectConstructor}. + * @implNote If overridden, please cache any found non-inject constructors using + * {@link #putCachedNonInjectConstructor}. * - * @param clazz the {@link Class} in which to search for a constructor which does not have an {@literal @}Inject - * annotation + * @param clazz the {@link Class} in which to search for a constructor which does + * not have an {@literal @}Inject annotation. * @param constructorArgs the given constructor args - * @return the found non-inject constructor appropriate for the given constructor args, or {@code null} if no - * such constructor exists + * @return the found non-inject constructor appropriate for the given constructor args, + * or {@code null} if no such constructor exists * @param the type */ @SuppressWarnings("unchecked") @@ -156,42 +242,104 @@ public abstract class AbstractInjectingObjectFactory implements ObjectFactory { protected Parameter getCachedInjectParameter(Method setter) { return this.cachedSetterParameters.computeIfAbsent(setter, s -> { if (s.getParameterCount() != 1) { - throw new IllegalArgumentException("Setter " + s.getName() + " has a parameter count other than one (1)!"); + throw new IllegalArgumentException( + "Setter " + s.getName() + " has a parameter count other than one (1)!" + ); } return s.getParameters()[0]; }); } - protected void injectSetter(Object target, Method setter) { - try { - setter.invoke(target, this.getSetterInjectArg(target.getClass(), setter, this.getCachedInjectParameter(setter))); - } catch (InvocationTargetException | IllegalAccessException e) { - throw new RuntimeException(e); + private @Nullable Object findInContext(CreateContext context, Class type) { + for (final Resolved resolved : context.getAllResolved()) { + if (type.isAssignableFrom(resolved.type)) { + return resolved.object; + } + } + return null; + } + + protected void injectSetter(CreateContext context, Object target, Method setter) { + final Parameter injectParam = this.getCachedInjectParameter(setter); + final Class typeToInject = injectParam.getType(); + final @Nullable Object fromContext = this.findInContext(context, typeToInject); + + final Consumer setterAction = arg -> { + try { + setter.invoke(target, arg); + } catch (InvocationTargetException | IllegalAccessException e) { + throw new RuntimeException(e); + } + }; + + if (context.containsConstruction(typeToInject)) { + context.addDeferredSetterAction(typeToInject, setterAction); + } else { + final Object arg; + if (fromContext != null) { + arg = fromContext; + } else { + arg = this.getSetterInjectArg(context, target.getClass(), setter, injectParam); + context.getAllResolved().add(new Resolved(typeToInject, arg)); + } + setterAction.accept(arg); } } - protected void injectSetters(Object target) { - this.getCachedSettersFor(target).forEach(setter -> this.injectSetter(target, setter)); + protected void injectSetters(CreateContext context, Object target) { + this.getCachedSettersFor(target).forEach(setter -> this.injectSetter(context, target, setter)); } /** * {@inheritDoc} */ @Override - public T createInstance(Class clazz, Object... constructorArgs) { - final Constructor constructor = this.findConstructor(clazz, constructorArgs); - final Object[] allArgs = this.createArgs(constructor, constructorArgs); + public T createInstance(Class type, Object... constructorArgs) { + final Constructor constructor = this.findConstructor(type, constructorArgs); + final CreateContext context = new CreateContext(type); + final Object[] allArgs = this.createArgs(context, constructor, constructorArgs); try { final T instance = constructor.newInstance(allArgs); - this.injectSetters(instance); + context.getAllResolved().add(new Resolved(type, instance)); + this.injectSetters(context, instance); + final List> deferredSetterActions = context.getDeferredSetterActions(); + context.popConstruction(); + deferredSetterActions.forEach(setterAction -> setterAction.accept(instance)); return instance; } catch (InvocationTargetException | IllegalAccessException | InstantiationException e) { throw new RuntimeException(e); // In the future, we might have an option to ignore exceptions } } - protected abstract Object[] createArgs(Constructor constructor, Object[] constructorArgs); + protected T createInstance(CreateContext context, Class type, Object... givenArgs) { + final Constructor constructor = this.findConstructor(type, givenArgs); + context.checkForCircularDependency(type); + context.pushConstruction(type); + final Object[] allArgs = this.createArgs(context, constructor, givenArgs); + try { + final T instance = constructor.newInstance(allArgs); + context.getAllResolved().add(new Resolved(type, instance)); + this.injectSetters(context, instance); + final List> deferredSetterActions = context.getDeferredSetterActions(); + context.popConstruction(); + deferredSetterActions.forEach(setterAction -> setterAction.accept(instance)); + return instance; + } catch (InvocationTargetException | IllegalAccessException | InstantiationException e) { + throw new RuntimeException("e"); + } + } - protected abstract Object getSetterInjectArg(Class targetType, Method setter, Parameter toInject); + protected abstract Object[] createArgs( + CreateContext context, + Constructor constructor, + Object[] constructorArgs + ); + + protected abstract Object getSetterInjectArg( + CreateContext context, + Class targetType, + Method setter, + Parameter toInject + ); } diff --git a/util/di/src/main/java/groowt/util/di/AbstractRegistryObjectFactory.java b/util/di/src/main/java/groowt/util/di/AbstractRegistryObjectFactory.java index f6e530e..cd947fb 100644 --- a/util/di/src/main/java/groowt/util/di/AbstractRegistryObjectFactory.java +++ b/util/di/src/main/java/groowt/util/di/AbstractRegistryObjectFactory.java @@ -14,7 +14,8 @@ import java.util.function.Supplier; import static groowt.util.di.RegistryObjectFactoryUtil.orElseSupply; -public abstract class AbstractRegistryObjectFactory extends AbstractInjectingObjectFactory implements RegistryObjectFactory { +public abstract class AbstractRegistryObjectFactory extends AbstractInjectingObjectFactory + implements RegistryObjectFactory { public static abstract class AbstractBuilder implements Builder { diff --git a/util/di/src/main/java/groowt/util/di/DefaultRegistryObjectFactory.java b/util/di/src/main/java/groowt/util/di/DefaultRegistryObjectFactory.java index 422c146..f8e8784 100644 --- a/util/di/src/main/java/groowt/util/di/DefaultRegistryObjectFactory.java +++ b/util/di/src/main/java/groowt/util/di/DefaultRegistryObjectFactory.java @@ -3,8 +3,6 @@ package groowt.util.di; import groowt.util.di.filters.FilterHandler; import groowt.util.di.filters.IterableFilterHandler; import org.jetbrains.annotations.Nullable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.lang.annotation.Annotation; import java.lang.reflect.Constructor; @@ -20,7 +18,8 @@ import static groowt.util.di.RegistryObjectFactoryUtil.*; public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory { - public static final class Builder extends AbstractRegistryObjectFactory.AbstractBuilder { + public static final class Builder + extends AbstractRegistryObjectFactory.AbstractBuilder { /** * Creates a {@code Builder} initialized with a {@link DefaultRegistry}, which is in-turn configured with a @@ -74,7 +73,6 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory } private static final Object[] EMPTY_OBJECT_ARRAY = new Object[0]; - private static final Logger logger = LoggerFactory.getLogger(DefaultRegistryObjectFactory.class); // leave it for the future! private final Collection> filterHandlers; private final Collection> iterableFilterHandlers; @@ -107,7 +105,9 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory @SuppressWarnings("unchecked") protected final @Nullable Object tryQualifiers(Parameter parameter) { final Class paramType = parameter.getType(); - final List qualifiers = RegistryObjectFactoryUtil.getQualifierAnnotations(parameter.getAnnotations()); + final List qualifiers = RegistryObjectFactoryUtil.getQualifierAnnotations( + parameter.getAnnotations() + ); if (qualifiers.size() > 1) { throw new RuntimeException("Parameter " + parameter + " cannot have more than one Qualifier annotation."); } else if (qualifiers.size() == 1) { @@ -115,7 +115,9 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory @SuppressWarnings("rawtypes") final QualifierHandler handler = this.getInSelfOrParent( f -> f.findQualifierHandler(qualifier.annotationType()), - () -> new RuntimeException("There is no configured QualifierHandler for " + qualifier.annotationType().getName()) + () -> new RuntimeException("There is no configured QualifierHandler for " + + qualifier.annotationType().getName() + ) ); final Binding binding = handler.handle(qualifier, paramType); if (binding != null) { @@ -165,21 +167,23 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory } } - protected final Object resolveInjectedArg(Parameter parameter) { + protected final Object resolveInjectedArg(CreateContext context, Parameter parameter) { final Object qualifierProvidedArg = this.tryQualifiers(parameter); if (qualifierProvidedArg != null) { this.checkFilters(parameter, qualifierProvidedArg); + context.getAllResolved().add(new Resolved(parameter.getType(), qualifierProvidedArg)); return qualifierProvidedArg; } else { - final Object created = this.get(parameter.getType()); + final Object created = this.get(context, parameter.getType()); this.checkFilters(parameter, created); + context.getAllResolved().add(new Resolved(parameter.getType(), created)); return created; } } - protected final void resolveInjectedArgs(Object[] dest, Parameter[] params) { + protected final void resolveInjectedArgs(CreateContext context, Object[] dest, Parameter[] params) { for (int i = 0; i < params.length; i++) { - dest[i] = this.resolveInjectedArg(params[i]); + dest[i] = this.resolveInjectedArg(context, params[i]); } } @@ -194,7 +198,7 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory // TODO: when there is a null arg, we lose the type. Therefore this algorithm breaks. Fix this. @Override - protected Object[] createArgs(Constructor constructor, Object[] givenArgs) { + protected Object[] createArgs(CreateContext context, Constructor constructor, Object[] givenArgs) { final Class[] paramTypes = constructor.getParameterTypes(); // check no arg @@ -218,7 +222,7 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory if (givenArgs.length == 0) { // if no given args, then they are all injected - this.resolveInjectedArgs(resolvedArgs, allParams); + this.resolveInjectedArgs(context, resolvedArgs, allParams); } else if (givenArgs.length == paramTypes.length) { // all are given this.resolveGivenArgs(resolvedArgs, allParams, givenArgs, 0); @@ -234,17 +238,23 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory final Parameter[] givenParams = new Parameter[allParams.length - firstGivenIndex]; System.arraycopy(allParams, 0, injectedParams, 0, injectedParams.length); - System.arraycopy(allParams, firstGivenIndex, givenParams, 0, allParams.length - firstGivenIndex); + System.arraycopy( + allParams, firstGivenIndex, givenParams, 0, allParams.length - firstGivenIndex + ); - this.resolveInjectedArgs(resolvedArgs, injectedParams); + this.resolveInjectedArgs(context, resolvedArgs, injectedParams); this.resolveGivenArgs(resolvedArgs, givenParams, givenArgs, firstGivenIndex); } return resolvedArgs; } - @SuppressWarnings("unchecked") private T handleBinding(Binding binding, Object[] constructorArgs) { + return this.handleBinding(binding, null, constructorArgs); + } + + @SuppressWarnings("unchecked") + private T handleBinding(Binding binding, @Nullable CreateContext context, Object[] constructorArgs) { return switch (binding) { case ClassBinding(Class ignored, Class to) -> { final Annotation scopeAnnotation = getScopeAnnotation(to); @@ -253,12 +263,20 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory @SuppressWarnings("rawtypes") final ScopeHandler scopeHandler = this.getInSelfOrParent( f -> f.findScopeHandler(scopeClass), - () -> new RuntimeException("There is no configured ScopeHandler for " + scopeClass.getName()) + () -> new RuntimeException( + "There is no configured ScopeHandler for " + scopeClass.getName() + ) + ); + final Binding scopedBinding = scopeHandler.onScopedDependencyRequest( + scopeAnnotation, to, this ); - final Binding scopedBinding = scopeHandler.onScopedDependencyRequest(scopeAnnotation, to, this); yield this.handleBinding(scopedBinding, constructorArgs); } else { - yield this.createInstance(to, constructorArgs); + if (context != null) { + yield this.createInstance(context, to, constructorArgs); + } else { + yield this.createInstance(to, constructorArgs); + } } } case ProviderBinding providerBinding -> providerBinding.provider().get(); @@ -276,8 +294,8 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory } @Override - protected Object getSetterInjectArg(Class targetType, Method setter, Parameter toInject) { - return this.resolveInjectedArg(toInject); + protected Object getSetterInjectArg(CreateContext context, Class targetType, Method setter, Parameter toInject) { + return this.resolveInjectedArg(context, toInject); } /** @@ -293,7 +311,24 @@ public class DefaultRegistryObjectFactory extends AbstractRegistryObjectFactory if (parentResult != null) { return parentResult; } else { - throw new RuntimeException("No bindings for " + clazz + " with args " + Arrays.toString(constructorArgs) + "."); + throw new RuntimeException( + "No bindings for " + clazz + " with args " + Arrays.toString(constructorArgs) + "." + ); + } + } + + protected T get(CreateContext context, Class type) { + final Binding binding = this.searchRegistry(type); + if (binding != null) { + return this.handleBinding(binding, context, EMPTY_OBJECT_ARRAY); + } + final T parentResult = this.tryParent(type, EMPTY_OBJECT_ARRAY); + if (parentResult != null) { + return parentResult; + } else { + throw new RuntimeException( + "No bindings for " + type + " with args " + Arrays.toString(EMPTY_OBJECT_ARRAY) + "." + ); } } diff --git a/util/di/src/test/java/groowt/util/di/DefaultRegistryObjectFactoryTests.java b/util/di/src/test/java/groowt/util/di/DefaultRegistryObjectFactoryTests.java index 1ee62bf..1ecfc9e 100644 --- a/util/di/src/test/java/groowt/util/di/DefaultRegistryObjectFactoryTests.java +++ b/util/di/src/test/java/groowt/util/di/DefaultRegistryObjectFactoryTests.java @@ -166,4 +166,51 @@ public class DefaultRegistryObjectFactoryTests { assertEquals("Hello, World!", greeter.greet()); } + public static final class GreeterDependency { + + private GreeterDependencyUser greeter; + + @Inject + public void setGreeter(GreeterDependencyUser greeter) { + this.greeter = greeter; + } + + public String filterGreeting() { + return this.greeter.getGreeting().toUpperCase(); + } + + } + + public static final class GreeterDependencyUser implements Greeter { + + private final GreeterDependency greeterDependency; + + @Inject + public GreeterDependencyUser(GreeterDependency greeterDependency) { + this.greeterDependency = greeterDependency; + } + + @Override + public String greet() { + return this.greeterDependency.filterGreeting(); + } + + public String getGreeting() { + return "hello, world!"; + } + + } + + @Test + public void injectedDeferred() { + final var b = DefaultRegistryObjectFactory.Builder.withDefaults(); + b.configureRegistry(r -> { + r.bind(GreeterDependencyUser.class, toSelf()); + r.bind(GreeterDependency.class, toSelf()); + }); + final var f = b.build(); + final var g = f.get(GreeterDependencyUser.class); + assertEquals("HELLO, WORLD!", g.greet()); + } + }