Skip to content

Commit

Permalink
Updating property to enable AI features
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaeldelio authored and bsbodden committed Dec 30, 2024
1 parent e6b1c53 commit 0448262
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 59 deletions.
2 changes: 1 addition & 1 deletion demos/roms-vss/src/main/resources/application.properties
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
spring.mvc.hiddenmethod.filter.enabled=true
com.redis.om.vss.useLocalImages=false
com.redis.om.vss.maxLines=300
redis.om.spring.ai.djl.enabled=true
redis.om.spring.ai.enabled=true
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
import java.time.*;
import java.util.Map;

@ConditionalOnProperty(name = "redis.om.spring.ai.djl.enabled")
@ConditionalOnProperty(name = "redis.om.spring.ai.enabled")
@Configuration
@EnableConfigurationProperties({ RedisOMAiProperties.class })
public class RedisAiConfiguration {
Expand All @@ -71,10 +71,10 @@ public ImageFactory imageFactory() {

@Bean(name = "djlImageEmbeddingModelCriteria")
public Criteria<Image, byte[]> imageEmbeddingModelCriteria(RedisOMAiProperties properties) {
return properties.getDjl().isEnabled() ? Criteria.builder().setTypes(Image.class, byte[].class) //
return Criteria.builder().setTypes(Image.class, byte[].class) //
.optEngine(properties.getDjl().getImageEmbeddingModelEngine()) //
.optModelUrls(properties.getDjl().getImageEmbeddingModelModelUrls()) //
.build() : null;
.build();
}

@Bean(name = "djlFaceDetectionTranslator")
Expand All @@ -93,20 +93,19 @@ public Criteria<Image, DetectedObjects> faceDetectionModelCriteria( //
@Qualifier("djlFaceDetectionTranslator") Translator<Image, DetectedObjects> translator, //
RedisOMAiProperties properties) {

return properties.getDjl().isEnabled() ? Criteria.builder().setTypes(Image.class, DetectedObjects.class) //
return Criteria.builder().setTypes(Image.class, DetectedObjects.class) //
.optModelUrls(properties.getDjl().getFaceDetectionModelModelUrls()) //
.optModelName(properties.getDjl().getFaceDetectionModelName()) //
.optTranslator(translator) //
.optEngine(properties.getDjl().getFaceDetectionModelEngine()) //
.build() : null;
.build();
}

@Bean(name = "djlFaceDetectionModel")
public ZooModel<Image, DetectedObjects> faceDetectionModel(
@Nullable @Qualifier("djlFaceDetectionModelCriteria") Criteria<Image, DetectedObjects> criteria,
RedisOMAiProperties properties) {
@Nullable @Qualifier("djlFaceDetectionModelCriteria") Criteria<Image, DetectedObjects> criteria) {
try {
return properties.getDjl().isEnabled() && (criteria != null) ? ModelZoo.loadModel(criteria) : null;
return criteria != null ? ModelZoo.loadModel(criteria) : null;
} catch (IOException | ModelNotFoundException | MalformedModelException ex) {
logger.warn("Error retrieving default DJL face detection model", ex);
return null;
Expand All @@ -123,20 +122,19 @@ public Criteria<Image, float[]> faceEmbeddingModelCriteria( //
@Qualifier("djlFaceEmbeddingTranslator") Translator<Image, float[]> translator, //
RedisOMAiProperties properties) {

return properties.getDjl().isEnabled() ? Criteria.builder() //
return Criteria.builder() //
.setTypes(Image.class, float[].class).optModelUrls(properties.getDjl().getFaceEmbeddingModelModelUrls()) //
.optModelName(properties.getDjl().getFaceEmbeddingModelName()) //
.optTranslator(translator) //
.optEngine(properties.getDjl().getFaceEmbeddingModelEngine()) //
.build() : null;
.build();
}

@Bean(name = "djlFaceEmbeddingModel")
public ZooModel<Image, float[]> faceEmbeddingModel(
@Nullable @Qualifier("djlFaceEmbeddingModelCriteria") Criteria<Image, float[]> criteria, //
RedisOMAiProperties properties) {
@Nullable @Qualifier("djlFaceEmbeddingModelCriteria") Criteria<Image, float[]> criteria) {
try {
return properties.getDjl().isEnabled() && (criteria != null) ? ModelZoo.loadModel(criteria) : null;
return criteria != null ? ModelZoo.loadModel(criteria) : null;
} catch (Exception e) {
logger.warn("Error retrieving default DJL face embeddings model", e);
return null;
Expand All @@ -145,46 +143,39 @@ public ZooModel<Image, float[]> faceEmbeddingModel(

@Bean(name = "djlImageEmbeddingModel")
public ZooModel<Image, byte[]> imageModel(
@Nullable @Qualifier("djlImageEmbeddingModelCriteria") Criteria<Image, byte[]> criteria,
RedisOMAiProperties properties) throws MalformedModelException, ModelNotFoundException, IOException {
return properties.getDjl().isEnabled() && (criteria != null) ? ModelZoo.loadModel(criteria) : null;
@Nullable @Qualifier("djlImageEmbeddingModelCriteria") Criteria<Image, byte[]> criteria) throws MalformedModelException, ModelNotFoundException, IOException {
return criteria != null ? ModelZoo.loadModel(criteria) : null;
}

@Bean(name = "djlDefaultImagePipeline")
public Pipeline defaultImagePipeline(RedisOMAiProperties properties) {
if (properties.getDjl().isEnabled()) {
Pipeline pipeline = new Pipeline();
if (properties.getDjl().isDefaultImagePipelineCenterCrop()) {
pipeline.add(new CenterCrop());
}
return pipeline //
.add(new Resize( //
properties.getDjl().getDefaultImagePipelineResizeWidth(), //
properties.getDjl().getDefaultImagePipelineResizeHeight() //
)) //
.add(new ToTensor());
} else
return null;
Pipeline pipeline = new Pipeline();
if (properties.getDjl().isDefaultImagePipelineCenterCrop()) {
pipeline.add(new CenterCrop());
}
return pipeline //
.add(new Resize( //
properties.getDjl().getDefaultImagePipelineResizeWidth(), //
properties.getDjl().getDefaultImagePipelineResizeHeight() //
)) //
.add(new ToTensor());
}

@Bean(name = "djlSentenceTokenizer")
public HuggingFaceTokenizer sentenceTokenizer(RedisOMAiProperties properties) {
if (properties.getDjl().isEnabled()) {
Map<String, String> options = Map.of( //
"maxLength", properties.getDjl().getSentenceTokenizerMaxLength(), //
"modelMaxLength", properties.getDjl().getSentenceTokenizerModelMaxLength() //
);

try {
//noinspection ResultOfMethodCallIgnored
InetAddress.getByName("www.huggingface.co").isReachable(5000);
return HuggingFaceTokenizer.newInstance(properties.getDjl().getSentenceTokenizerModel(), options);
} catch (IOException ioe) {
logger.warn("Error retrieving default DJL sentence tokenizer");
return null;
}
} else
Map<String, String> options = Map.of( //
"maxLength", properties.getDjl().getSentenceTokenizerMaxLength(), //
"modelMaxLength", properties.getDjl().getSentenceTokenizerModelMaxLength() //
);

try {
//noinspection ResultOfMethodCallIgnored
InetAddress.getByName("www.huggingface.co").isReachable(5000);
return HuggingFaceTokenizer.newInstance(properties.getDjl().getSentenceTokenizerModel(), options);
} catch (IOException ioe) {
logger.warn("Error retrieving default DJL sentence tokenizer");
return null;
}
}

@ConditionalOnMissingBean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ public void registerReferenceSerializer(ContextRefreshedEvent cre) {
registrar.registerReferencesFor(RedisHash.class);
}

@ConditionalOnProperty(name = "redis.om.spring.ai.djl.enabled", havingValue = "false", matchIfMissing = true)
@ConditionalOnProperty(name = "redis.om.spring.ai.enabled", havingValue = "false", matchIfMissing = true)
@Bean(name = "featureExtractor")
public Embedder featureExtractor() {
return new NoopEmbedder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.ConfigurationProperties;

@ConditionalOnProperty(name = "redis.om.spring.ai.djl.enabled")
@ConditionalOnProperty(name = "redis.om.spring.ai.enabled")
@ConfigurationProperties(
prefix = "redis.om.spring.ai", ignoreInvalidFields = true
)
public class RedisOMAiProperties {
private boolean enabled = false;
private final Djl djl = new Djl();
private final OpenAi openAi = new OpenAi();
private final AzureOpenAi azureOpenAi = new AzureOpenAi();
Expand All @@ -18,6 +19,14 @@ public class RedisOMAiProperties {
private final BedrockTitan bedrockTitan = new BedrockTitan();
private final Ollama ollama = new Ollama();

public boolean isEnabled() {
return this.enabled;
}

public void setEnabled(boolean enabled) {
this.enabled = enabled;
}

public Djl getDjl() {
return djl;
}
Expand Down Expand Up @@ -49,7 +58,6 @@ public Ollama getOllama() {
// DJL properties
public static class Djl {
private static final String DEFAULT_ENGINE = "PyTorch";
private boolean enabled = false;
// image embedding settings
@NotNull
private String imageEmbeddingModelEngine = DEFAULT_ENGINE;
Expand Down Expand Up @@ -86,14 +94,6 @@ public static class Djl {
public Djl() {
}

public boolean isEnabled() {
return this.enabled;
}

public void setEnabled(boolean enabled) {
this.enabled = enabled;
}

public @NotNull String getImageEmbeddingModelEngine() {
return this.imageEmbeddingModelEngine;
}
Expand Down Expand Up @@ -207,7 +207,7 @@ public void setFaceEmbeddingModelModelUrls(@NotNull String faceEmbeddingModelMod
}

public String toString() {
return "RedisOMSpringProperties.Djl(enabled=" + this.isEnabled() + ", imageEmbeddingModelEngine=" + this.getImageEmbeddingModelEngine() + ", imageEmbeddingModelModelUrls=" + this.getImageEmbeddingModelModelUrls() + ", defaultImagePipelineResizeWidth=" + this.getDefaultImagePipelineResizeWidth() + ", defaultImagePipelineResizeHeight=" + this.getDefaultImagePipelineResizeHeight() + ", defaultImagePipelineCenterCrop=" + this.isDefaultImagePipelineCenterCrop() + ", sentenceTokenizerMaxLength=" + this.getSentenceTokenizerMaxLength() + ", sentenceTokenizerModelMaxLength=" + this.getSentenceTokenizerModelMaxLength() + ", sentenceTokenizerModel=" + this.getSentenceTokenizerModel() + ", faceDetectionModelEngine=" + this.getFaceDetectionModelEngine() + ", faceDetectionModelName=" + this.getFaceDetectionModelName() + ", faceDetectionModelModelUrls=" + this.getFaceDetectionModelModelUrls() + ", faceEmbeddingModelEngine=" + this.getFaceEmbeddingModelEngine() + ", faceEmbeddingModelName=" + this.getFaceEmbeddingModelName() + ", faceEmbeddingModelModelUrls=" + this.getFaceEmbeddingModelModelUrls() + ")";
return "RedisOMSpringProperties.Ai.Djl(imageEmbeddingModelEngine=" + this.getImageEmbeddingModelEngine() + ", imageEmbeddingModelModelUrls=" + this.getImageEmbeddingModelModelUrls() + ", defaultImagePipelineResizeWidth=" + this.getDefaultImagePipelineResizeWidth() + ", defaultImagePipelineResizeHeight=" + this.getDefaultImagePipelineResizeHeight() + ", defaultImagePipelineCenterCrop=" + this.isDefaultImagePipelineCenterCrop() + ", sentenceTokenizerMaxLength=" + this.getSentenceTokenizerMaxLength() + ", sentenceTokenizerModelMaxLength=" + this.getSentenceTokenizerModelMaxLength() + ", sentenceTokenizerModel=" + this.getSentenceTokenizerModel() + ", faceDetectionModelEngine=" + this.getFaceDetectionModelEngine() + ", faceDetectionModelName=" + this.getFaceDetectionModelName() + ", faceDetectionModelModelUrls=" + this.getFaceDetectionModelModelUrls() + ", faceEmbeddingModelEngine=" + this.getFaceEmbeddingModelEngine() + ", faceEmbeddingModelName=" + this.getFaceEmbeddingModelName() + ", faceEmbeddingModelModelUrls=" + this.getFaceEmbeddingModelModelUrls() + ")";
}
}

Expand Down
3 changes: 1 addition & 2 deletions redis-om-spring/src/test/resources/vss_on.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@ redis:
om:
spring:
ai:
djl:
enabled: true
\enabled: true

0 comments on commit 0448262

Please sign in to comment.