Skip to content

Commit

Permalink
Merge pull request #1162 from sberyozkin/azure_openai_embedded_with_m…
Browse files Browse the repository at this point in the history
…odel_provider

Update Azure OpenAI to check ModelAuthProvider during Embedding and Image model creation
  • Loading branch information
geoand authored Dec 16, 2024
2 parents 6ec8ea5 + 616d02e commit 22fdfce
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,17 @@ void generateBeans(AzureOpenAiRecorder recorder,
for (var selected : selectedEmbedding) {
if (PROVIDER.equals(selected.getProvider())) {
String configName = selected.getConfigName();

var embeddingModel = recorder.embeddingModel(config, configName);
var builder = SyntheticBeanBuildItem
.configure(EMBEDDING_MODEL)
.setRuntimeInit()
.unremovable()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(recorder.embeddingModel(config, configName));
.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
new Type[] { ClassType.create(DotNames.MODEL_AUTH_PROVIDER) }, null))
.createWith(embeddingModel);
addQualifierIfNecessary(builder, configName);
beanProducer.produce(builder.done());
}
Expand All @@ -136,12 +140,16 @@ void generateBeans(AzureOpenAiRecorder recorder,
for (var selected : selectedImage) {
if (PROVIDER.equals(selected.getProvider())) {
String configName = selected.getConfigName();

var imageModel = recorder.imageModel(config, configName);
var builder = SyntheticBeanBuildItem
.configure(IMAGE_MODEL)
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(recorder.imageModel(config, configName));
.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
new Type[] { ClassType.create(DotNames.MODEL_AUTH_PROVIDER) }, null))
.createWith(imageModel);
addQualifierIfNecessary(builder, configName);
beanProducer.produce(builder.done());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,14 @@ public StreamingChatLanguageModel apply(SyntheticCreationalContext<StreamingChat
}
}

public Supplier<EmbeddingModel> embeddingModel(LangChain4jAzureOpenAiConfig runtimeConfig, String configName) {
public Function<SyntheticCreationalContext<EmbeddingModel>, EmbeddingModel> embeddingModel(
LangChain4jAzureOpenAiConfig runtimeConfig, String configName) {
LangChain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, configName);

if (azureAiConfig.enableIntegration()) {
EmbeddingModelConfig embeddingModelConfig = azureAiConfig.embeddingModel();
String apiKey = azureAiConfig.apiKey().orElse(null);
String adToken = azureAiConfig.adToken().orElse(null);
if (apiKey == null && adToken == null) {
throw new ConfigValidationException(createKeyMisconfigurationProblem(configName));
}
var builder = AzureOpenAiEmbeddingModel.builder()
.endpoint(getEndpoint(azureAiConfig, configName, EndpointType.EMBEDDING))
.apiKey(apiKey)
Expand All @@ -174,29 +172,31 @@ public Supplier<EmbeddingModel> embeddingModel(LangChain4jAzureOpenAiConfig runt
.logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), azureAiConfig.logRequests()))
.logResponses(firstOrDefault(false, embeddingModelConfig.logResponses(), azureAiConfig.logResponses()));

return new Supplier<>() {
return new Function<>() {
@Override
public EmbeddingModel get() {
public EmbeddingModel apply(SyntheticCreationalContext<EmbeddingModel> context) {
throwIfApiKeysNotConfigured(apiKey, adToken, isAuthProviderAvailable(context, configName),
configName);
return builder.build();
}
};
} else {
return new Supplier<>() {
return new Function<>() {
@Override
public EmbeddingModel get() {
public EmbeddingModel apply(SyntheticCreationalContext<EmbeddingModel> context) {
return new DisabledEmbeddingModel();
}
};
}
}

public Supplier<ImageModel> imageModel(LangChain4jAzureOpenAiConfig runtimeConfig, String configName) {
public Function<SyntheticCreationalContext<ImageModel>, ImageModel> imageModel(LangChain4jAzureOpenAiConfig runtimeConfig,
String configName) {
LangChain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, configName);

if (azureAiConfig.enableIntegration()) {
var apiKey = azureAiConfig.apiKey().orElse(null);
String adToken = azureAiConfig.adToken().orElse(null);
throwIfApiKeysNotConfigured(apiKey, adToken, false, configName);

var imageModelConfig = azureAiConfig.imageModel();
var builder = AzureOpenAiImageModel.builder()
Expand Down Expand Up @@ -236,16 +236,18 @@ public Optional<? extends Path> get() {

builder.persistDirectory(persistDirectory);

return new Supplier<>() {
return new Function<>() {
@Override
public ImageModel get() {
public ImageModel apply(SyntheticCreationalContext<ImageModel> context) {
throwIfApiKeysNotConfigured(apiKey, adToken, isAuthProviderAvailable(context, configName),
configName);
return builder.build();
}
};
} else {
return new Supplier<>() {
return new Function<>() {
@Override
public ImageModel get() {
public ImageModel apply(SyntheticCreationalContext<ImageModel> context) {
return new DisabledImageModel();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ void disabledStreamingChatModel() {

@Test
void disabledEmbeddingModel() {
assertThat(recorder.embeddingModel(config, NamedConfigUtil.DEFAULT_NAME).get())
assertThat(recorder.embeddingModel(config, NamedConfigUtil.DEFAULT_NAME).apply(null))
.isNotNull()
.isExactlyInstanceOf(DisabledEmbeddingModel.class);
}

@Test
void disabledImageModel() {
assertThat(recorder.imageModel(config, NamedConfigUtil.DEFAULT_NAME).get())
assertThat(recorder.imageModel(config, NamedConfigUtil.DEFAULT_NAME).apply(null))
.isNotNull()
.isExactlyInstanceOf(DisabledImageModel.class);
}
Expand Down

0 comments on commit 22fdfce

Please sign in to comment.