Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: incrementalOutput not working #17

Merged
merged 4 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ public DashScopeChatModel dashscopeChatModel(DashScopeConnectionProperties commo
retryTemplate);
}

@Bean
@ConditionalOnMissingBean
@ConditionalOnProperty(prefix = DashScopeEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
matchIfMissing = true)
public DashScopeApi dashscopeChatApi(DashScopeConnectionProperties commonProperties,
DashScopeChatProperties chatProperties, RestClient.Builder restClientBuilder,
WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public class DashScopeChatProperties extends DashScopeParentProperties {
/**
* Default DashScope Chat model.
*/
public static final String DEFAULT_DEPLOYMENT_NAME = Generation.Models.QWEN_TURBO;
public static final String DEFAULT_DEPLOYMENT_NAME = Generation.Models.QWEN_PLUS;

/**
* Default sampling temperature.
Expand All @@ -56,7 +56,6 @@ public class DashScopeChatProperties extends DashScopeParentProperties {
private DashScopeChatOptions options = DashScopeChatOptions.builder()
.withModel(DEFAULT_DEPLOYMENT_NAME)
.withTemperature(DEFAULT_TEMPERATURE)
.withEnableSearch(true)
.build();

public DashScopeChatProperties() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@
* @author Ken
*/
public class DashScopeAiStreamFunctionCallingHelper {
/**
 * Whether the stream being merged uses incremental output. Both constructors
 * store a concrete value, so this field is never {@code null}.
 */
private Boolean incrementalOutput = false;

/**
 * Creates a helper with incremental output disabled.
 */
public DashScopeAiStreamFunctionCallingHelper() {
}

/**
 * Creates a helper with the given incremental output setting.
 * @param incrementalOutput the incremental output flag; {@code null} is treated as
 * {@code false}
 */
public DashScopeAiStreamFunctionCallingHelper(Boolean incrementalOutput) {
	// The flag is auto-unboxed where it is read (e.g. `!incrementalOutput`), so a
	// stored null would surface later as a NullPointerException. Normalise it here.
	this.incrementalOutput = Boolean.TRUE.equals(incrementalOutput);
}

/**
* Merge the previous and current ChatCompletionChunk into a single one.
Expand All @@ -46,7 +54,6 @@ public class DashScopeAiStreamFunctionCallingHelper {
* @return the merged ChatCompletionChunk
*/
public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChunk current) {

if (previous == null) {
return current;
}
Expand All @@ -57,9 +64,18 @@ public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChu
Choice previousChoice0 = previous.output() == null ? null : previous.output().choices().get(0);
Choice currentChoice0 = current.output() == null ? null : current.output().choices().get(0);

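// Compatibility for incremental_output=false during a streaming function call:
// chunks are not merged here. Until the call finishes an empty choice is returned;
// once it finishes, the current chunk's choice is returned as-is.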
if (!incrementalOutput && isStreamingToolFunctionCall(current)) {
if (!isStreamingToolFunctionCallFinish(current)) {
return new ChatCompletionChunk(id, new ChatCompletionOutput(null, List.of(new Choice(null, null))), usage);
} else {
return new ChatCompletionChunk(id, new ChatCompletionOutput(null, List.of(currentChoice0)), usage);
}
}

Choice choice = merge(previousChoice0, currentChoice0);
List<Choice> chunkChoices = choice == null ? List.of() : List.of(choice);
return new ChatCompletionChunk(id, new ChatCompletionOutput(null, chunkChoices), usage);
return new ChatCompletionChunk(id, new ChatCompletionOutput(null, chunkChoices), usage);
}

private Choice merge(Choice previous, Choice current) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1237,8 +1237,6 @@ public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest
.toEntity(ChatCompletion.class);
}

private final DashScopeAiStreamFunctionCallingHelper chunkMerger = new DashScopeAiStreamFunctionCallingHelper();

/**
* Creates a streaming chat response for the given chat conversation.
* @param chatRequest The chat completion request. Must have the stream property set
Expand All @@ -1251,6 +1249,8 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true.");

AtomicBoolean isInsideTool = new AtomicBoolean(false);
boolean incrementalOutput = chatRequest.parameters() != null && chatRequest.parameters().incrementalOutput != null && chatRequest.parameters().incrementalOutput;
DashScopeAiStreamFunctionCallingHelper chunkMerger = new DashScopeAiStreamFunctionCallingHelper(incrementalOutput);

return this.webClient.post()
.uri("/api/v1/services/aigc/text-generation/generation")
Expand All @@ -1262,21 +1262,21 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
.filter(SSE_DONE_PREDICATE.negate())
.map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class))
.map(chunk -> {
if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) {
if (chunkMerger.isStreamingToolFunctionCall(chunk)) {
isInsideTool.set(true);
}
return chunk;
})
.windowUntil(chunk -> {
if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) {
if (isInsideTool.get() && chunkMerger.isStreamingToolFunctionCallFinish(chunk)) {
isInsideTool.set(false);
return true;
}
return !isInsideTool.get();
})
.concatMapIterable(window -> {
Mono<ChatCompletionChunk> monoChunk = window.reduce(new ChatCompletionChunk(null, null, null),
this.chunkMerger::merge);
chunkMerger::merge);
return List.of(monoChunk);
})
.flatMap(mono -> mono);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,7 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
options = ModelOptionsUtils.merge(options, this.defaultOptions, DashScopeChatOptions.class);

if (!CollectionUtils.isEmpty(enabledToolsToUse)) {
options = ModelOptionsUtils.merge(
DashScopeChatOptions.builder().withTools(this.getFunctionTools(enabledToolsToUse)).build(), options,
DashScopeChatOptions.class);
options.setTools(this.getFunctionTools(enabledToolsToUse));
}

List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(message -> {
Expand Down Expand Up @@ -338,7 +336,7 @@ private ChatCompletionRequestParameter toDashScopeRequestParameter(DashScopeChat
return new ChatCompletionRequestParameter();
}

Boolean incrementalOutput = stream || options.getIncrementalOutput();
Boolean incrementalOutput = options.getIncrementalOutput();
return new ChatCompletionRequestParameter("message", options.getSeed(), options.getMaxTokens(),
options.getTopP(), options.getTopK(), options.getRepetitionPenalty(), options.getPresencePenalty(),
options.getTemperature(), options.getStop(), options.getEnableSearch(), incrementalOutput,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public class DashScopeChatOptions implements FunctionCallingOptions, ChatOptions
/**
* 控制在流式输出模式下是否开启增量输出,即后续输出内容是否包含已输出的内容。设置为True时,将开启增量输出模式,后面输出不会包含已经输出的内容,您需要自行拼接整体输出;设置为False则会包含已输出的内容。
*/
private @JsonProperty("incremental_output") Boolean incrementalOutput = false;
private @JsonProperty("incremental_output") Boolean incrementalOutput = true;

/** 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。默认为1.1。 */
private @JsonProperty("repetition_penalty") Float repetitionPenalty;
Expand Down Expand Up @@ -341,6 +341,7 @@ public static DashScopeChatOptions fromOptions(DashScopeChatOptions fromOptions)
.withStop(fromOptions.getStop())
.withStream(fromOptions.getStream())
.withEnableSearch(fromOptions.enableSearch)
.withIncrementalOutput(fromOptions.getIncrementalOutput())
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
.withFunctions(fromOptions.getFunctions())
.withRepetitionPenalty(fromOptions.getRepetitionPenalty())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.ChatCompletionFinishReason;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
import com.alibaba.cloud.ai.dashscope.tool.DashScopeFunctionTestConfiguration;
import com.alibaba.cloud.ai.dashscope.chat.tool.MockOrderService;
import com.alibaba.cloud.ai.dashscope.chat.tool.MockWeatherService;
Expand Down Expand Up @@ -75,7 +76,7 @@ public class DashScopeChatClientIT {
private DashScopeChatModel dashscopeChatModel;

@Autowired
private DashScopeApi dashscopeApi;
private DashScopeApi dashscopeChatApi;

@Value("classpath:/prompts/rag/system-qa.st")
private Resource systemResource;
Expand All @@ -85,7 +86,7 @@ public class DashScopeChatClientIT {

@Test
void callTest() throws IOException {
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeApi,
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeChatApi,
DashScopeDocumentRetrieverOptions.builder().withIndexName("spring-ai知识库").build());

ChatClient chatClient = ChatClient.builder(dashscopeChatModel)
Expand All @@ -102,14 +103,20 @@ void callTest() throws IOException {

@Test
void streamTest() throws InterruptedException, IOException {
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeApi,
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeChatApi,
DashScopeDocumentRetrieverOptions.builder().withIndexName("spring-ai知识库").build());
ChatClient chatClient = ChatClient.builder(dashscopeChatModel)
.defaultAdvisors(
new DocumentRetrievalAdvisor(retriever, systemResource.getContentAsString(StandardCharsets.UTF_8)))
.build();

Flux<ChatResponse> response = chatClient.prompt().user("如何快速开始百炼?").stream().chatResponse();
Flux<ChatResponse> response = chatClient.prompt()
.user("如何快速开始百炼?")
.options(DashScopeChatOptions.builder()
.withIncrementalOutput(true)
.build())
.stream()
.chatResponse();

CountDownLatch cdl = new CountDownLatch(1);
response.subscribe(data -> {
Expand Down Expand Up @@ -159,7 +166,7 @@ void callWithFunctionBeanTest() {

@Test
void callWithFunctionAndRagTest() throws IOException {
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeApi,
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeChatApi,
DashScopeDocumentRetrieverOptions.builder().withIndexName("spring-ai知识库").build());

ChatClient chatClient = ChatClient.builder(dashscopeChatModel)
Expand All @@ -178,7 +185,7 @@ void callWithFunctionAndRagTest() throws IOException {

@Test
void streamCallWithFunctionAndRagTest() throws InterruptedException, IOException {
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeApi,
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeChatApi,
DashScopeDocumentRetrieverOptions.builder().withIndexName("spring-ai知识库").build());

ChatClient chatClient = ChatClient.builder(dashscopeChatModel)
Expand All @@ -187,7 +194,13 @@ void streamCallWithFunctionAndRagTest() throws InterruptedException, IOException
.defaultFunctions("weatherFunction")
.build();

Flux<ChatResponse> response = chatClient.prompt().user("上海今天的天气如何?").stream().chatResponse();
Flux<ChatResponse> response = chatClient.prompt()
.user("上海今天的天气如何?")
.options(DashScopeChatOptions.builder()
.withIncrementalOutput(true)
.build())
.stream()
.chatResponse();

CountDownLatch cdl = new CountDownLatch(1);
response.subscribe(data -> {
Expand All @@ -206,7 +219,7 @@ void streamCallWithFunctionAndRagTest() throws InterruptedException, IOException

@Test
void callWithReferencedRagTest() throws IOException {
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeApi,
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeChatApi,
DashScopeDocumentRetrieverOptions.builder().withIndexName("spring-ai知识库").build());

ChatClient chatClient = ChatClient.builder(dashscopeChatModel)
Expand All @@ -232,7 +245,7 @@ void callWithReferencedRagTest() throws IOException {

@Test
void streamCallWithReferencedRagTest() throws IOException, InterruptedException {
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeApi,
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeChatApi,
DashScopeDocumentRetrieverOptions.builder().withIndexName("spring-ai知识库").build());

ChatClient chatClient = ChatClient.builder(dashscopeChatModel)
Expand Down Expand Up @@ -272,7 +285,7 @@ void streamCallWithReferencedRagTest() throws IOException, InterruptedException

@Test
void callWithMemory() throws IOException {
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeApi,
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeChatApi,
DashScopeDocumentRetrieverOptions.builder().withIndexName("spring-ai知识库").build());

ChatClient chatClient = ChatClient.builder(dashscopeChatModel)
Expand Down Expand Up @@ -309,24 +322,24 @@ void callWithMemory() throws IOException {
@Test
void reader() {
String filePath = "/Users/nuocheng.lxm/Desktop/新能源产业有哪些-36氪.pdf";
DashScopeDocumentCloudReader reader = new DashScopeDocumentCloudReader(filePath, dashscopeApi, null);
DashScopeDocumentCloudReader reader = new DashScopeDocumentCloudReader(filePath, dashscopeChatApi, null);
List<Document> documentList = reader.get();
DashScopeDocumentTransformer transformer = new DashScopeDocumentTransformer(dashscopeApi);
DashScopeDocumentTransformer transformer = new DashScopeDocumentTransformer(dashscopeChatApi);
List<Document> transformerList = transformer.apply(documentList);
System.out.println(transformerList.size());
}

@Test
void embed() {
DashScopeEmbeddingModel embeddingModel = new DashScopeEmbeddingModel(dashscopeApi);
DashScopeEmbeddingModel embeddingModel = new DashScopeEmbeddingModel(dashscopeChatApi);
Document document = new Document("你好阿里云");
float[] vectorList = embeddingModel.embed(document);
System.out.println(vectorList.length);
}

@Test
void vectorStore() {
DashScopeCloudStore cloudStore = new DashScopeCloudStore(dashscopeApi, new DashScopeStoreOptions("诺成SpringAI"));
DashScopeCloudStore cloudStore = new DashScopeCloudStore(dashscopeChatApi, new DashScopeStoreOptions("诺成SpringAI"));
List<Document> documentList = Arrays.asList(
new Document("file_f0b6b18b14994ed8a0b45648ce5d0da5_10001", "abc", new HashMap<>()),
new Document("file_d3083d64026d4864b4558d18f9ca2a6d_10001", "abc", new HashMap<>()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,4 @@ import { Outlet } from 'react-router-dom';
(window as any).Vaadin ??= {};
(window as any).Vaadin.copilot ??= {};
(window as any).Vaadin.copilot._ref ??= {};
(window as any).Vaadin.copilot._ref.Outlet = Outlet;
(window as any).Vaadin.copilot._ref.Outlet = Outlet;
Loading