diff --git a/model-providers/openai/azure-openai/deployment/src/main/java/io/quarkiverse/langchain4j/azure/openai/deployment/AzureOpenAiProcessor.java b/model-providers/openai/azure-openai/deployment/src/main/java/io/quarkiverse/langchain4j/azure/openai/deployment/AzureOpenAiProcessor.java index aaa3d9922..9f1ce303a 100644 --- a/model-providers/openai/azure-openai/deployment/src/main/java/io/quarkiverse/langchain4j/azure/openai/deployment/AzureOpenAiProcessor.java +++ b/model-providers/openai/azure-openai/deployment/src/main/java/io/quarkiverse/langchain4j/azure/openai/deployment/AzureOpenAiProcessor.java @@ -121,13 +121,17 @@ void generateBeans(AzureOpenAiRecorder recorder, for (var selected : selectedEmbedding) { if (PROVIDER.equals(selected.getProvider())) { String configName = selected.getConfigName(); + + var embeddingModel = recorder.embeddingModel(config, configName); var builder = SyntheticBeanBuildItem .configure(EMBEDDING_MODEL) .setRuntimeInit() .unremovable() .defaultBean() .scope(ApplicationScoped.class) - .supplier(recorder.embeddingModel(config, configName)); + .addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE, + new Type[] { ClassType.create(DotNames.MODEL_AUTH_PROVIDER) }, null)) + .createWith(embeddingModel); addQualifierIfNecessary(builder, configName); beanProducer.produce(builder.done()); } @@ -136,12 +140,16 @@ void generateBeans(AzureOpenAiRecorder recorder, for (var selected : selectedImage) { if (PROVIDER.equals(selected.getProvider())) { String configName = selected.getConfigName(); + + var imageModel = recorder.imageModel(config, configName); var builder = SyntheticBeanBuildItem .configure(IMAGE_MODEL) .setRuntimeInit() .defaultBean() .scope(ApplicationScoped.class) - .supplier(recorder.imageModel(config, configName)); + .addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE, + new Type[] { ClassType.create(DotNames.MODEL_AUTH_PROVIDER) }, null)) + .createWith(imageModel); addQualifierIfNecessary(builder, configName); beanProducer.produce(builder.done()); } diff --git a/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java b/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java index 342c2a60c..40cc8e252 100644 --- a/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java +++ b/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java @@ -153,16 +153,14 @@ public StreamingChatLanguageModel apply(SyntheticCreationalContext embeddingModel(LangChain4jAzureOpenAiConfig runtimeConfig, String configName) { + public Function, EmbeddingModel> embeddingModel( + LangChain4jAzureOpenAiConfig runtimeConfig, String configName) { LangChain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, configName); if (azureAiConfig.enableIntegration()) { EmbeddingModelConfig embeddingModelConfig = azureAiConfig.embeddingModel(); String apiKey = azureAiConfig.apiKey().orElse(null); String adToken = azureAiConfig.adToken().orElse(null); - if (apiKey == null && adToken == null) { - throw new ConfigValidationException(createKeyMisconfigurationProblem(configName)); - } var builder = AzureOpenAiEmbeddingModel.builder() .endpoint(getEndpoint(azureAiConfig, configName, EndpointType.EMBEDDING)) .apiKey(apiKey) @@ -174,29 +172,31 @@ public Supplier embeddingModel(LangChain4jAzureOpenAiConfig runt .logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), azureAiConfig.logRequests())) .logResponses(firstOrDefault(false, embeddingModelConfig.logResponses(), azureAiConfig.logResponses())); - return new Supplier<>() { + return new Function<>() { @Override - public EmbeddingModel get() { + public EmbeddingModel apply(SyntheticCreationalContext context) { + throwIfApiKeysNotConfigured(apiKey, adToken, isAuthProviderAvailable(context, configName), + configName); return builder.build(); } }; } else { - return new Supplier<>() { + return new Function<>() { @Override - public EmbeddingModel get() { + public EmbeddingModel apply(SyntheticCreationalContext context) { return new DisabledEmbeddingModel(); } }; } } - public Supplier imageModel(LangChain4jAzureOpenAiConfig runtimeConfig, String configName) { + public Function, ImageModel> imageModel(LangChain4jAzureOpenAiConfig runtimeConfig, + String configName) { LangChain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, configName); if (azureAiConfig.enableIntegration()) { var apiKey = azureAiConfig.apiKey().orElse(null); String adToken = azureAiConfig.adToken().orElse(null); - throwIfApiKeysNotConfigured(apiKey, adToken, false, configName); var imageModelConfig = azureAiConfig.imageModel(); var builder = AzureOpenAiImageModel.builder() @@ -236,16 +236,18 @@ public Optional get() { builder.persistDirectory(persistDirectory); - return new Supplier<>() { + return new Function<>() { @Override - public ImageModel get() { + public ImageModel apply(SyntheticCreationalContext context) { + throwIfApiKeysNotConfigured(apiKey, adToken, isAuthProviderAvailable(context, configName), + configName); return builder.build(); } }; } else { - return new Supplier<>() { + return new Function<>() { @Override - public ImageModel get() { + public ImageModel apply(SyntheticCreationalContext context) { return new DisabledImageModel(); } }; diff --git a/model-providers/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/DisabledModelsAzureOpenAiRecorderTest.java b/model-providers/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/DisabledModelsAzureOpenAiRecorderTest.java index 5b5e19ef7..6d599877d 100644 --- a/model-providers/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/DisabledModelsAzureOpenAiRecorderTest.java +++ b/model-providers/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/DisabledModelsAzureOpenAiRecorderTest.java @@ -45,14 +45,14 @@ void disabledStreamingChatModel() { @Test void disabledEmbeddingModel() { - assertThat(recorder.embeddingModel(config, NamedConfigUtil.DEFAULT_NAME).get()) + assertThat(recorder.embeddingModel(config, NamedConfigUtil.DEFAULT_NAME).apply(null)) .isNotNull() .isExactlyInstanceOf(DisabledEmbeddingModel.class); } @Test void disabledImageModel() { - assertThat(recorder.imageModel(config, NamedConfigUtil.DEFAULT_NAME).get()) + assertThat(recorder.imageModel(config, NamedConfigUtil.DEFAULT_NAME).apply(null)) .isNotNull() .isExactlyInstanceOf(DisabledImageModel.class); }