|
5 | 5 | import com.devoxx.genie.model.LanguageModel; |
6 | 6 | import com.devoxx.genie.model.enumarations.ModelProvider; |
7 | 7 | import com.devoxx.genie.ui.settings.DevoxxGenieStateService; |
8 | | -import dev.langchain4j.model.bedrock.*; |
| 8 | +import dev.langchain4j.model.bedrock.BedrockChatModel; |
| 9 | +import dev.langchain4j.model.chat.request.ChatRequestParameters; |
| 10 | +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; |
9 | 11 | import dev.langchain4j.model.chat.ChatLanguageModel; |
10 | 12 | import org.apache.commons.lang3.NotImplementedException; |
11 | 13 | import org.jetbrains.annotations.NotNull; |
@@ -65,68 +67,86 @@ public ChatLanguageModel createChatModel(ChatModel chatModel) { |
65 | 67 | * @return An instance of {@link ChatLanguageModel} configured for Anthropic models. |
66 | 68 | */ |
67 | 69 | private ChatLanguageModel createAnthropicChatModel(@NotNull ChatModel chatModel) { |
68 | | - // TODO Refactor the deprecated class and use the new one |
69 | | - return BedrockAnthropicMessageChatModel.builder() |
70 | | - .model(chatModel.getModelName()) |
71 | | - .temperature(chatModel.getTemperature()) |
72 | | - .maxTokens(chatModel.getMaxTokens()) |
73 | | - .credentialsProvider(getCredentialsProvider()) |
74 | | - .region(getRegion()) |
| 70 | + return BedrockChatModel.builder() |
| 71 | + .modelId(chatModel.getModelName()) |
| 72 | + .client(BedrockRuntimeClient.builder() |
| 73 | + .region(getRegion()) |
| 74 | + .credentialsProvider(getCredentialsProvider()) |
| 75 | + .build()) |
| 76 | + .defaultRequestParameters(ChatRequestParameters.builder() |
| 77 | + .temperature(chatModel.getTemperature()) |
| 78 | + .maxOutputTokens(chatModel.getMaxTokens()) |
| 79 | + .build()) |
75 | 80 | .build(); |
76 | 81 | } |
77 | 82 |
|
78 | 83 | private ChatLanguageModel createMistralChatModel(@NotNull ChatModel chatModel) { |
79 | | - // TODO Refactor the deprecated class and use the new one |
80 | | - return BedrockMistralAiChatModel.builder() |
81 | | - .model(chatModel.getModelName()) |
82 | | - .temperature(chatModel.getTemperature()) |
83 | | - .maxTokens(chatModel.getMaxTokens()) |
84 | | - .credentialsProvider(getCredentialsProvider()) |
85 | | - .region(getRegion()) |
| 84 | + return BedrockChatModel.builder() |
| 85 | + .modelId(chatModel.getModelName()) |
| 86 | + .client(BedrockRuntimeClient.builder() |
| 87 | + .region(getRegion()) |
| 88 | + .credentialsProvider(getCredentialsProvider()) |
| 89 | + .build()) |
| 90 | + .defaultRequestParameters(ChatRequestParameters.builder() |
| 91 | + .temperature(chatModel.getTemperature()) |
| 92 | + .maxOutputTokens(chatModel.getMaxTokens()) |
| 93 | + .build()) |
86 | 94 | .build(); |
87 | 95 | } |
88 | 96 |
|
89 | 97 | private ChatLanguageModel createCohereChatModel(@NotNull ChatModel chatModel) { |
90 | | - // TODO Refactor the deprecated class and use the new one |
91 | | - return BedrockCohereChatModel.builder() |
92 | | - .model(chatModel.getModelName()) |
93 | | - .temperature(chatModel.getTemperature()) |
94 | | - .maxTokens(chatModel.getMaxTokens()) |
95 | | - .credentialsProvider(getCredentialsProvider()) |
96 | | - .region(getRegion()) |
| 98 | + return BedrockChatModel.builder() |
| 99 | + .modelId(chatModel.getModelName()) |
| 100 | + .client(BedrockRuntimeClient.builder() |
| 101 | + .region(getRegion()) |
| 102 | + .credentialsProvider(getCredentialsProvider()) |
| 103 | + .build()) |
| 104 | + .defaultRequestParameters(ChatRequestParameters.builder() |
| 105 | + .temperature(chatModel.getTemperature()) |
| 106 | + .maxOutputTokens(chatModel.getMaxTokens()) |
| 107 | + .build()) |
97 | 108 | .build(); |
98 | 109 | } |
99 | 110 |
|
100 | 111 | private ChatLanguageModel createLamaChatModel(@NotNull ChatModel chatModel) { |
101 | | - // TODO Refactor the deprecated class and use the new one |
102 | | - return BedrockLlamaChatModel.builder() |
103 | | - .model(chatModel.getModelName()) |
104 | | - .temperature(chatModel.getTemperature()) |
105 | | - .maxTokens(chatModel.getMaxTokens()) |
106 | | - .credentialsProvider(getCredentialsProvider()) |
107 | | - .region(getRegion()) |
| 112 | + return BedrockChatModel.builder() |
| 113 | + .modelId(chatModel.getModelName()) |
| 114 | + .client(BedrockRuntimeClient.builder() |
| 115 | + .region(getRegion()) |
| 116 | + .credentialsProvider(getCredentialsProvider()) |
| 117 | + .build()) |
| 118 | + .defaultRequestParameters(ChatRequestParameters.builder() |
| 119 | + .temperature(chatModel.getTemperature()) |
| 120 | + .maxOutputTokens(chatModel.getMaxTokens()) |
| 121 | + .build()) |
108 | 122 | .build(); |
109 | 123 | } |
110 | 124 |
|
111 | 125 | private ChatLanguageModel createAI21ChatModel(@NotNull ChatModel chatModel) { |
112 | | - // TODO Refactor the deprecated class and use the new one |
113 | | - return BedrockAI21LabsChatModel.builder() |
114 | | - .model(chatModel.getModelName()) |
115 | | - .temperature(chatModel.getTemperature()) |
116 | | - .maxTokens(chatModel.getMaxTokens()) |
117 | | - .credentialsProvider(getCredentialsProvider()) |
118 | | - .region(getRegion()) |
| 126 | + return BedrockChatModel.builder() |
| 127 | + .modelId(chatModel.getModelName()) |
| 128 | + .client(BedrockRuntimeClient.builder() |
| 129 | + .region(getRegion()) |
| 130 | + .credentialsProvider(getCredentialsProvider()) |
| 131 | + .build()) |
| 132 | + .defaultRequestParameters(ChatRequestParameters.builder() |
| 133 | + .temperature(chatModel.getTemperature()) |
| 134 | + .maxOutputTokens(chatModel.getMaxTokens()) |
| 135 | + .build()) |
119 | 136 | .build(); |
120 | 137 | } |
121 | 138 |
|
122 | 139 | private ChatLanguageModel createStabilityChatModel(@NotNull ChatModel chatModel) { |
123 | | - // TODO Refactor the deprecated class and use the new one |
124 | | - return BedrockStabilityAIChatModel.builder() |
125 | | - .model(chatModel.getModelName()) |
126 | | - .temperature(chatModel.getTemperature()) |
127 | | - .maxTokens(chatModel.getMaxTokens()) |
128 | | - .credentialsProvider(getCredentialsProvider()) |
129 | | - .region(getRegion()) |
| 140 | + return BedrockChatModel.builder() |
| 141 | + .modelId(chatModel.getModelName()) |
| 142 | + .client(BedrockRuntimeClient.builder() |
| 143 | + .region(getRegion()) |
| 144 | + .credentialsProvider(getCredentialsProvider()) |
| 145 | + .build()) |
| 146 | + .defaultRequestParameters(ChatRequestParameters.builder() |
| 147 | + .temperature(chatModel.getTemperature()) |
| 148 | + .maxOutputTokens(chatModel.getMaxTokens()) |
| 149 | + .build()) |
130 | 150 | .build(); |
131 | 151 | } |
132 | 152 |
|
|
0 commit comments