Add basic AI search.
This commit is contained in:
parent
120e6d90e1
commit
9f54c63c53
14
build.gradle
14
build.gradle
@ -43,11 +43,19 @@ configurations {
|
||||
}
|
||||
}
|
||||
|
||||
ext {
|
||||
set('springAiVersion', "1.1.2")
|
||||
}
|
||||
|
||||
dependencies {
|
||||
// From Spring Initalizr
|
||||
implementation 'org.springframework.boot:spring-boot-starter-data-jpa'
|
||||
implementation 'org.springframework.boot:spring-boot-starter-security'
|
||||
implementation 'org.springframework.boot:spring-boot-starter-web'
|
||||
implementation 'org.springframework.ai:spring-ai-advisors-vector-store'
|
||||
implementation 'org.springframework.ai:spring-ai-starter-model-ollama'
|
||||
implementation 'org.springframework.ai:spring-ai-starter-vector-store-pgvector'
|
||||
implementation 'org.hibernate.orm:hibernate-vector:6.6.39.Final'
|
||||
runtimeOnly 'org.postgresql:postgresql'
|
||||
testImplementation 'org.springframework.boot:spring-boot-starter-test'
|
||||
testImplementation 'org.springframework.security:spring-security-test'
|
||||
@ -88,6 +96,12 @@ dependencies {
|
||||
testFixturesImplementation 'org.hamcrest:hamcrest:3.0'
|
||||
}
|
||||
|
||||
dependencyManagement {
|
||||
imports {
|
||||
mavenBom "org.springframework.ai:spring-ai-bom:${springAiVersion}"
|
||||
}
|
||||
}
|
||||
|
||||
tasks.register('integrationTest', Test) {
|
||||
description = 'Run integration tests.'
|
||||
group = 'verification'
|
||||
|
||||
@ -0,0 +1,58 @@
|
||||
package app.mealsmadeeasy.api;
|
||||
|
||||
import app.mealsmadeeasy.api.recipe.RecipeEmbeddingEntity;
|
||||
import app.mealsmadeeasy.api.recipe.RecipeEntity;
|
||||
import app.mealsmadeeasy.api.recipe.RecipeRepository;
|
||||
import app.mealsmadeeasy.api.recipe.RecipeService;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.boot.ApplicationArguments;
|
||||
import org.springframework.boot.ApplicationRunner;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.time.OffsetDateTime;
|
||||
import java.util.List;
|
||||
|
||||
@Component
|
||||
@ConditionalOnProperty(name = "backfill.recipe-embeddings.enabled", havingValue = "true")
|
||||
public class BackfillRecipeEmbeddings implements ApplicationRunner {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(BackfillRecipeEmbeddings.class);
|
||||
|
||||
private final RecipeRepository recipeRepository;
|
||||
private final RecipeService recipeService;
|
||||
private final EmbeddingModel embeddingModel;
|
||||
|
||||
public BackfillRecipeEmbeddings(
|
||||
RecipeRepository recipeRepository,
|
||||
RecipeService recipeService,
|
||||
EmbeddingModel embeddingModel
|
||||
) {
|
||||
this.recipeRepository = recipeRepository;
|
||||
this.recipeService = recipeService;
|
||||
this.embeddingModel = embeddingModel;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run(ApplicationArguments args) {
|
||||
final List<RecipeEntity> recipeEntities = this.recipeRepository.findAllByEmbeddingIsNull();
|
||||
for (final RecipeEntity recipeEntity : recipeEntities) {
|
||||
logger.info("Calculating embedding for {}", recipeEntity);
|
||||
final String renderedMarkdown = this.recipeService.getRenderedMarkdown(recipeEntity);
|
||||
final String toEmbed = "<h1>" + recipeEntity.getTitle() + "</h1>" + renderedMarkdown;
|
||||
final float[] embedding = this.embeddingModel.embed(toEmbed);
|
||||
|
||||
final RecipeEmbeddingEntity recipeEmbedding = new RecipeEmbeddingEntity();
|
||||
recipeEmbedding.setRecipe(recipeEntity);
|
||||
recipeEmbedding.setEmbedding(embedding);
|
||||
recipeEmbedding.setTimestamp(OffsetDateTime.now());
|
||||
recipeEntity.setEmbedding(recipeEmbedding);
|
||||
|
||||
this.recipeRepository.save(recipeEntity);
|
||||
}
|
||||
this.recipeRepository.flush();
|
||||
}
|
||||
|
||||
}
|
||||
@ -1,10 +1,12 @@
|
||||
package app.mealsmadeeasy.api.recipe;
|
||||
|
||||
import app.mealsmadeeasy.api.image.ImageException;
|
||||
import app.mealsmadeeasy.api.recipe.body.RecipeSearchBody;
|
||||
import app.mealsmadeeasy.api.recipe.comment.RecipeComment;
|
||||
import app.mealsmadeeasy.api.recipe.comment.RecipeCommentCreateBody;
|
||||
import app.mealsmadeeasy.api.recipe.comment.RecipeCommentService;
|
||||
import app.mealsmadeeasy.api.recipe.comment.RecipeCommentView;
|
||||
import app.mealsmadeeasy.api.recipe.spec.RecipeAiSearchSpec;
|
||||
import app.mealsmadeeasy.api.recipe.spec.RecipeUpdateSpec;
|
||||
import app.mealsmadeeasy.api.recipe.star.RecipeStar;
|
||||
import app.mealsmadeeasy.api.recipe.star.RecipeStarService;
|
||||
@ -13,6 +15,7 @@ import app.mealsmadeeasy.api.recipe.view.RecipeExceptionView;
|
||||
import app.mealsmadeeasy.api.recipe.view.RecipeInfoView;
|
||||
import app.mealsmadeeasy.api.sliceview.SliceViewService;
|
||||
import app.mealsmadeeasy.api.user.User;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
import org.springframework.data.domain.Pageable;
|
||||
import org.springframework.data.domain.Slice;
|
||||
@ -23,6 +26,7 @@ import org.springframework.security.core.annotation.AuthenticationPrincipal;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@RestController
|
||||
@ -33,17 +37,20 @@ public class RecipeController {
|
||||
private final RecipeStarService recipeStarService;
|
||||
private final RecipeCommentService recipeCommentService;
|
||||
private final SliceViewService sliceViewService;
|
||||
private final ObjectMapper objectMapper;
|
||||
|
||||
public RecipeController(
|
||||
RecipeService recipeService,
|
||||
RecipeStarService recipeStarService,
|
||||
RecipeCommentService recipeCommentService,
|
||||
SliceViewService sliceViewService
|
||||
SliceViewService sliceViewService,
|
||||
ObjectMapper objectMapper
|
||||
) {
|
||||
this.recipeService = recipeService;
|
||||
this.recipeStarService = recipeStarService;
|
||||
this.recipeCommentService = recipeCommentService;
|
||||
this.sliceViewService = sliceViewService;
|
||||
this.objectMapper = objectMapper;
|
||||
}
|
||||
|
||||
@ExceptionHandler(RecipeException.class)
|
||||
@ -98,10 +105,23 @@ public class RecipeController {
|
||||
@GetMapping
|
||||
public ResponseEntity<Map<String, Object>> getRecipeInfoViews(
|
||||
Pageable pageable,
|
||||
@RequestBody(required = false) RecipeSearchBody recipeSearchBody,
|
||||
@AuthenticationPrincipal User user
|
||||
) {
|
||||
final Slice<RecipeInfoView> slice = this.recipeService.getInfoViewsViewableBy(pageable, user);
|
||||
return ResponseEntity.ok(this.sliceViewService.getSliceView(slice));
|
||||
if (recipeSearchBody != null) {
|
||||
if (recipeSearchBody.getType() == RecipeSearchBody.Type.AI_PROMPT) {
|
||||
final RecipeAiSearchSpec spec = this.objectMapper.convertValue(
|
||||
recipeSearchBody.getData(), RecipeAiSearchSpec.class
|
||||
);
|
||||
final List<RecipeInfoView> results = this.recipeService.aiSearch(spec, user);
|
||||
return ResponseEntity.ok(Map.of("results", results));
|
||||
} else {
|
||||
throw new IllegalArgumentException("Invalid recipeSearchBody type: " + recipeSearchBody.getType());
|
||||
}
|
||||
} else {
|
||||
final Slice<RecipeInfoView> slice = this.recipeService.getInfoViewsViewableBy(pageable, user);
|
||||
return ResponseEntity.ok(this.sliceViewService.getSliceView(slice));
|
||||
}
|
||||
}
|
||||
|
||||
@PostMapping("/{username}/{slug}/star")
|
||||
|
||||
@ -0,0 +1,63 @@
|
||||
package app.mealsmadeeasy.api.recipe;
|
||||
|
||||
import jakarta.persistence.*;
|
||||
import org.hibernate.annotations.Array;
|
||||
import org.hibernate.annotations.JdbcTypeCode;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
import java.time.OffsetDateTime;
|
||||
|
||||
@Entity
|
||||
@Table(name = "recipe_embedding")
|
||||
public class RecipeEmbeddingEntity {
|
||||
|
||||
@Id
|
||||
private Integer id;
|
||||
|
||||
@OneToOne(fetch = FetchType.LAZY, optional = false)
|
||||
@MapsId
|
||||
@JoinColumn(name = "recipe_id")
|
||||
private RecipeEntity recipe;
|
||||
|
||||
@JdbcTypeCode(SqlTypes.VECTOR)
|
||||
@Array(length = 1024)
|
||||
@Nullable
|
||||
private float[] embedding;
|
||||
|
||||
@Column(nullable = false)
|
||||
private OffsetDateTime timestamp;
|
||||
|
||||
public Integer getId() {
|
||||
return this.id;
|
||||
}
|
||||
|
||||
public void setId(Integer id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public RecipeEntity getRecipe() {
|
||||
return this.recipe;
|
||||
}
|
||||
|
||||
public void setRecipe(RecipeEntity recipe) {
|
||||
this.recipe = recipe;
|
||||
}
|
||||
|
||||
public float[] getEmbedding() {
|
||||
return this.embedding;
|
||||
}
|
||||
|
||||
public void setEmbedding(float[] embedding) {
|
||||
this.embedding = embedding;
|
||||
}
|
||||
|
||||
public OffsetDateTime getTimestamp() {
|
||||
return this.timestamp;
|
||||
}
|
||||
|
||||
public void setTimestamp(OffsetDateTime timestamp) {
|
||||
this.timestamp = timestamp;
|
||||
}
|
||||
|
||||
}
|
||||
@ -78,6 +78,9 @@ public final class RecipeEntity implements Recipe {
|
||||
@JoinColumn(name = "main_image_id")
|
||||
private S3ImageEntity mainImage;
|
||||
|
||||
@OneToOne(mappedBy = "recipe", fetch = FetchType.LAZY, cascade = CascadeType.ALL, orphanRemoval = true)
|
||||
private RecipeEmbeddingEntity embedding;
|
||||
|
||||
@Override
|
||||
public Integer getId() {
|
||||
return this.id;
|
||||
@ -238,4 +241,12 @@ public final class RecipeEntity implements Recipe {
|
||||
this.mainImage = image;
|
||||
}
|
||||
|
||||
public RecipeEmbeddingEntity getEmbedding() {
|
||||
return this.embedding;
|
||||
}
|
||||
|
||||
public void setEmbedding(RecipeEmbeddingEntity embedding) {
|
||||
this.embedding = embedding;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -41,4 +41,34 @@ public interface RecipeRepository extends JpaRepository<RecipeEntity, Long> {
|
||||
@Query("SELECT r FROM Recipe r WHERE r.isPublic OR r.owner = ?1 OR ?1 MEMBER OF r.viewers")
|
||||
Slice<RecipeEntity> findAllViewableBy(UserEntity viewer, Pageable pageable);
|
||||
|
||||
List<RecipeEntity> findAllByEmbeddingIsNull();
|
||||
|
||||
@Query(
|
||||
nativeQuery = true,
|
||||
value = """
|
||||
WITH distances AS (SELECT recipe_id, embedding <=> cast(?1 AS vector) AS distance FROM recipe_embedding)
|
||||
SELECT r.* FROM distances d
|
||||
INNER JOIN recipe r ON r.id = d.recipe_id
|
||||
WHERE d.distance < ?2 AND (
|
||||
r.is_public = TRUE
|
||||
OR r.owner_id = ?3
|
||||
OR exists(SELECT 1 FROM recipe_viewer v WHERE v.recipe_id = r.id AND v.viewer_id = ?3)
|
||||
)
|
||||
ORDER BY d.distance;
|
||||
"""
|
||||
)
|
||||
List<RecipeEntity> searchByEmbeddingAndViewableBy(float[] queryEmbedding, float similarity, Integer viewerId);
|
||||
|
||||
@Query(
|
||||
nativeQuery = true,
|
||||
value = """
|
||||
WITH distances AS (SELECT recipe_id, embedding <=> cast(?1 AS vector) AS distance FROM recipe_embedding)
|
||||
SELECT r.* FROM distances d
|
||||
INNER JOIN recipe r ON r.id = d.recipe_id
|
||||
WHERE d.distance < ?2 AND r.is_public = TRUE
|
||||
ORDER BY d.distance;
|
||||
"""
|
||||
)
|
||||
List<RecipeEntity> searchByEmbeddingAndIsPublic(float[] queryEmbedding, float similarity);
|
||||
|
||||
}
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
package app.mealsmadeeasy.api.recipe;
|
||||
|
||||
import app.mealsmadeeasy.api.image.ImageException;
|
||||
import app.mealsmadeeasy.api.recipe.spec.RecipeAiSearchSpec;
|
||||
import app.mealsmadeeasy.api.recipe.spec.RecipeCreateSpec;
|
||||
import app.mealsmadeeasy.api.recipe.spec.RecipeUpdateSpec;
|
||||
import app.mealsmadeeasy.api.recipe.view.FullRecipeView;
|
||||
import app.mealsmadeeasy.api.recipe.view.RecipeInfoView;
|
||||
import app.mealsmadeeasy.api.user.User;
|
||||
import org.jetbrains.annotations.ApiStatus;
|
||||
import org.jetbrains.annotations.Contract;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
import org.springframework.data.domain.Pageable;
|
||||
@ -35,6 +37,8 @@ public interface RecipeService {
|
||||
List<Recipe> getRecipesViewableBy(User viewer);
|
||||
List<Recipe> getRecipesOwnedBy(User owner);
|
||||
|
||||
List<RecipeInfoView> aiSearch(RecipeAiSearchSpec searchSpec, @Nullable User viewer);
|
||||
|
||||
Recipe update(String username, String slug, RecipeUpdateSpec spec, User modifier)
|
||||
throws RecipeException, ImageException;
|
||||
|
||||
@ -53,4 +57,7 @@ public interface RecipeService {
|
||||
@Contract("_, _, null -> null")
|
||||
@Nullable Boolean isOwner(String username, String slug, @Nullable User viewer);
|
||||
|
||||
@ApiStatus.Internal
|
||||
String getRenderedMarkdown(RecipeEntity entity);
|
||||
|
||||
}
|
||||
|
||||
@ -6,6 +6,7 @@ import app.mealsmadeeasy.api.image.ImageService;
|
||||
import app.mealsmadeeasy.api.image.S3ImageEntity;
|
||||
import app.mealsmadeeasy.api.image.view.ImageView;
|
||||
import app.mealsmadeeasy.api.markdown.MarkdownService;
|
||||
import app.mealsmadeeasy.api.recipe.spec.RecipeAiSearchSpec;
|
||||
import app.mealsmadeeasy.api.recipe.spec.RecipeCreateSpec;
|
||||
import app.mealsmadeeasy.api.recipe.spec.RecipeUpdateSpec;
|
||||
import app.mealsmadeeasy.api.recipe.star.RecipeStarRepository;
|
||||
@ -13,8 +14,10 @@ import app.mealsmadeeasy.api.recipe.view.FullRecipeView;
|
||||
import app.mealsmadeeasy.api.recipe.view.RecipeInfoView;
|
||||
import app.mealsmadeeasy.api.user.User;
|
||||
import app.mealsmadeeasy.api.user.UserEntity;
|
||||
import org.jetbrains.annotations.ApiStatus;
|
||||
import org.jetbrains.annotations.Contract;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.data.domain.Pageable;
|
||||
import org.springframework.data.domain.Slice;
|
||||
import org.springframework.security.access.AccessDeniedException;
|
||||
@ -34,17 +37,20 @@ public class RecipeServiceImpl implements RecipeService {
|
||||
private final RecipeStarRepository recipeStarRepository;
|
||||
private final ImageService imageService;
|
||||
private final MarkdownService markdownService;
|
||||
private final EmbeddingModel embeddingModel;
|
||||
|
||||
public RecipeServiceImpl(
|
||||
RecipeRepository recipeRepository,
|
||||
RecipeStarRepository recipeStarRepository,
|
||||
ImageService imageService,
|
||||
MarkdownService markdownService
|
||||
MarkdownService markdownService,
|
||||
EmbeddingModel embeddingModel
|
||||
) {
|
||||
this.recipeRepository = recipeRepository;
|
||||
this.recipeStarRepository = recipeStarRepository;
|
||||
this.imageService = imageService;
|
||||
this.markdownService = markdownService;
|
||||
this.embeddingModel = embeddingModel;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -93,7 +99,9 @@ public class RecipeServiceImpl implements RecipeService {
|
||||
));
|
||||
}
|
||||
|
||||
private String getRenderedMarkdown(RecipeEntity entity) {
|
||||
@Override
|
||||
@ApiStatus.Internal
|
||||
public String getRenderedMarkdown(RecipeEntity entity) {
|
||||
if (entity.getCachedRenderedText() == null) {
|
||||
entity.setCachedRenderedText(this.markdownService.renderAndCleanMarkdown(entity.getRawText()));
|
||||
entity = this.recipeRepository.save(entity);
|
||||
@ -191,6 +199,20 @@ public class RecipeServiceImpl implements RecipeService {
|
||||
return List.copyOf(this.recipeRepository.findAllByOwner((UserEntity) owner));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<RecipeInfoView> aiSearch(RecipeAiSearchSpec searchSpec, @Nullable User viewer) {
|
||||
final float[] queryEmbedding = this.embeddingModel.embed(searchSpec.getPrompt());
|
||||
final List<RecipeEntity> results;
|
||||
if (viewer == null) {
|
||||
results = this.recipeRepository.searchByEmbeddingAndIsPublic(queryEmbedding, 0.5f);
|
||||
} else {
|
||||
results = this.recipeRepository.searchByEmbeddingAndViewableBy(queryEmbedding, 0.5f, viewer.getId());
|
||||
}
|
||||
return results.stream()
|
||||
.map(recipeEntity -> this.getInfoView(recipeEntity, viewer))
|
||||
.toList();
|
||||
}
|
||||
|
||||
@Override
|
||||
@PreAuthorize("@recipeSecurity.isOwner(#username, #slug, #modifier)")
|
||||
public Recipe update(String username, String slug, RecipeUpdateSpec spec, User modifier)
|
||||
|
||||
@ -0,0 +1,30 @@
|
||||
package app.mealsmadeeasy.api.recipe.body;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public class RecipeSearchBody {
|
||||
|
||||
public enum Type {
|
||||
AI_PROMPT
|
||||
}
|
||||
|
||||
private Type type;
|
||||
private Map<String, Object> data;
|
||||
|
||||
public Type getType() {
|
||||
return this.type;
|
||||
}
|
||||
|
||||
public void setType(Type type) {
|
||||
this.type = type;
|
||||
}
|
||||
|
||||
public Map<String, Object> getData() {
|
||||
return this.data;
|
||||
}
|
||||
|
||||
public void setData(Map<String, Object> data) {
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
}
|
||||
@ -0,0 +1,15 @@
|
||||
package app.mealsmadeeasy.api.recipe.spec;
|
||||
|
||||
public class RecipeAiSearchSpec {
|
||||
|
||||
private String prompt;
|
||||
|
||||
public String getPrompt() {
|
||||
return this.prompt;
|
||||
}
|
||||
|
||||
public void setPrompt(String prompt) {
|
||||
this.prompt = prompt;
|
||||
}
|
||||
|
||||
}
|
||||
@ -16,3 +16,6 @@ app.mealsmadeeasy.api.minio.endpoint=http://${MINIO_HOST:localhost}:${MINIO_PORT
|
||||
app.mealsmadeeasy.api.minio.accessKey=${MINIO_ROOT_USER}
|
||||
app.mealsmadeeasy.api.minio.secretKey=${MINIO_ROOT_PASSWORD}
|
||||
app.mealsmadeeasy.api.images.bucketName=images
|
||||
|
||||
# AI
|
||||
spring.ai.vectorstore.pgvector.dimensions=1024
|
||||
|
||||
@ -0,0 +1 @@
|
||||
CREATE EXTENSION vector;
|
||||
@ -0,0 +1,7 @@
|
||||
CREATE TABLE recipe_embedding (
|
||||
recipe_id INT NOT NULL,
|
||||
timestamp TIMESTAMPTZ(6) NOT NULL,
|
||||
embedding VECTOR(1024),
|
||||
PRIMARY KEY (recipe_id),
|
||||
FOREIGN KEY (recipe_id) REFERENCES recipe ON DELETE CASCADE
|
||||
);
|
||||
Loading…
Reference in New Issue
Block a user