Fix bug where @Singleton classes bound toSelf() caused stack overflow.

This commit is contained in:
Jesse Brault 2025-01-25 00:39:37 -06:00
parent 04866b4d3a
commit 525932668f
4 changed files with 84 additions and 30 deletions

View File

@ -24,9 +24,9 @@ For example:
- [ ] Get rid of wvc compiler dependency on di
## 0.1.3
- [ ] refactor tools/gradle start scripts to use dist instead of custom bin script
- [ ] have custom bin/* scripts which point to dist(s) for convenience
- [ ] di bug: @Singleton toSelf() causes stack overflow
- [ ] ~~refactor tools/gradle start scripts to use dist instead of custom bin script~~
- [ ] ~~have custom bin/* scripts which point to dist(s) for convenience~~
- [x] di bug: @Singleton toSelf() causes stack overflow
- [ ] `OutletContainer` trait or interface for components which can contain an `<Outlet />` child.
- [ ] `Context` should have methods for simply finding an ancestor of a certain type without the need for a predicate.

View File

@ -3,31 +3,58 @@ package groowt.util.di;
import org.jetbrains.annotations.Nullable;
import java.lang.annotation.Annotation;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.*;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
public class DefaultRegistry implements Registry {
protected record ClassKeyBinding<T>(Class<T> key, Binding<T> binding) {}
protected static class BindingContainer {
protected final Collection<ClassKeyBinding<?>> classBindings = new ArrayList<>();
private final Map<Class<?>, Binding<?>> bindings = new HashMap<>();
@SuppressWarnings("unchecked")
public <T> @Nullable Binding<T> get(Class<T> key) {
for (final var entry : bindings.entrySet()) {
if (entry.getKey().isAssignableFrom(key)) {
return (Binding<T>) entry.getValue();
}
}
return null;
}
public <T> void put(Class<T> key, Binding<T> binding) {
this.bindings.put(key, binding);
}
public void remove(Class<?> key) {
this.bindings.remove(key);
}
public <T> void removeIf(Class<T> key, Predicate<? super Binding<T>> filter) {
if (filter.test(this.get(key))) {
this.bindings.remove(key);
}
}
public void clear() {
this.bindings.clear();
}
}
protected final BindingContainer bindingContainer = new BindingContainer();
protected final Collection<RegistryExtension> extensions = new ArrayList<>();
@Override
public void removeBinding(Class<?> key) {
this.classBindings.removeIf(classKeyBinding -> classKeyBinding.key().equals(key));
this.bindingContainer.remove(key);
}
@SuppressWarnings("unchecked")
@Override
public <T> void removeBindingIf(Class<T> key, Predicate<Binding<T>> filter) {
this.classBindings.removeIf(classKeyBinding ->
classKeyBinding.key().equals(key) && filter.test((Binding<T>) classKeyBinding.binding())
);
this.bindingContainer.removeIf(key, filter);
}
private <E extends RegistryExtension> List<E> getAllRegistryExtensions(Class<E> extensionType) {
@ -117,18 +144,12 @@ public class DefaultRegistry implements Registry {
public <T> void bind(Class<T> key, Consumer<? super BindingConfigurator<T>> configure) {
final var configurator = new SimpleBindingConfigurator<>(key);
configure.accept(configurator);
this.classBindings.add(new ClassKeyBinding<>(key, configurator.getBinding()));
this.bindingContainer.put(key, configurator.getBinding());
}
@SuppressWarnings("unchecked")
@Override
public @Nullable <T> Binding<T> getBinding(Class<T> key) {
for (final var classKeyBinding : this.classBindings) {
if (key.isAssignableFrom(classKeyBinding.key())) {
return (Binding<T>) classKeyBinding.binding();
}
}
return null;
public <T> @Nullable Binding<T> getBinding(Class<T> key) {
return this.bindingContainer.get(key);
}
private KeyBinder<?> findKeyBinder(Class<?> keyClass) {
@ -189,7 +210,7 @@ public class DefaultRegistry implements Registry {
@Override
public void clearAllBindings() {
this.classBindings.clear();
this.bindingContainer.clear();
for (final var extension : this.extensions) {
if (extension instanceof KeyBinder<?> keyBinder) {
keyBinder.clearAllBindings();

View File

@ -1,8 +1,9 @@
package groowt.util.di;
import jakarta.inject.Provider;
import jakarta.inject.Singleton;
import static groowt.util.di.BindingUtil.toSingleton;
import static groowt.util.di.BindingUtil.toLazySingleton;
public final class SingletonScopeHandler implements ScopeHandler<Singleton> {
@ -19,12 +20,22 @@ public final class SingletonScopeHandler implements ScopeHandler<Singleton> {
RegistryObjectFactory objectFactory
) {
final Binding<T> potentialBinding = this.owner.getBinding(dependencyClass);
if (potentialBinding != null) {
return potentialBinding;
} else {
this.owner.bind(dependencyClass, toSingleton(objectFactory.createInstance(dependencyClass)));
return this.owner.getBinding(dependencyClass);
return switch (potentialBinding) {
case ClassBinding<T>(Class<T> from, Class<? extends T> to) -> {
this.owner.bind(from, toLazySingleton(() -> objectFactory.createInstance(to)));
yield this.owner.getBinding(from);
}
case ProviderBinding<T>(Class<T> from, Provider<? extends T> provider) -> {
this.owner.bind(from, toLazySingleton(provider::get));
yield this.owner.getBinding(from);
}
case SingletonBinding<T> singletonBinding -> singletonBinding;
case LazySingletonBinding<T> lazySingletonBinding -> lazySingletonBinding;
case null -> {
this.owner.bind(dependencyClass, toLazySingleton(() -> objectFactory.createInstance(dependencyClass)));
yield this.owner.getBinding(dependencyClass);
}
};
}
@Override

View File

@ -2,6 +2,7 @@ package groowt.util.di;
import jakarta.inject.Inject;
import jakarta.inject.Named;
import jakarta.inject.Singleton;
import org.junit.jupiter.api.Test;
import static groowt.util.di.BindingUtil.*;
@ -250,4 +251,25 @@ public class DefaultRegistryObjectFactoryTests {
assertEquals("Given Greeting", g.greet());
}
@Singleton
public static final class SingletonGreeter implements Greeter {
@Override
public String greet() {
return "Hello, World!";
}
}
@Test
public void singletonDoesNotOverflow() {
final var b = DefaultRegistryObjectFactory.Builder.withDefaults();
b.configureRegistry(r -> {
r.bind(SingletonGreeter.class, toSelf());
});
final var f = b.build();
final var g = f.get(SingletonGreeter.class);
assertEquals("Hello, World!", g.greet());
}
}