diff --git a/basic_demo/glm_server.py b/basic_demo/glm_server.py
index 2ae8b22..fec07c3 100644
--- a/basic_demo/glm_server.py
+++ b/basic_demo/glm_server.py
@@ -19,7 +19,7 @@
 EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
 
-MAX_MODEL_LENGTH = 8192
+MAX_MODEL_LENGTH = 8192
 
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
@@ -444,23 +444,35 @@ async def predict_stream(model_id, gen_params):
     system_fingerprint = generate_id('fp_', 9)
     tools = {tool['function']['name'] for tool in gen_params['tools']} if gen_params['tools'] else {}
     delta_text = ""
+    delta_confirming_texts = []
+    confirm_tool_state = 'un_confirm' if tools else 'none'
+    # With tools present, the maximum number of characters needed to confirm a tool call = length of the longest tool name + up to 3 extra characters (a possible leading "\n" plus the trailing "\n{").
+    max_confirm_tool_length = len(max(tools, key=len)) + 3 if tools else 0
     async for new_response in generate_stream_glm4(gen_params):
         decoded_unicode = new_response["text"]
         delta_text += decoded_unicode[len(output):]
+        if confirm_tool_state == 'un_confirm':
+            delta_confirming_texts.append(decoded_unicode[len(output):])
+
         output = decoded_unicode
         lines = output.strip().split("\n")
 
         # Check whether the output is a tool call.
         # This is a simple tool-name comparison; it cannot guarantee intercepting every non-tool output, e.g. edge cases such as misaligned arguments.
         ## TODO: refine the logic here if you need more thorough handling.
-
-        if not is_function_call and len(lines) >= 2:
+        if confirm_tool_state == 'un_confirm' and len(lines) >= 2 and lines[1].startswith("{"):
             first_line = lines[0].strip()
             if first_line in tools:
                 is_function_call = True
                 function_name = first_line
                 delta_text = lines[1]
+                confirm_tool_state = 'confirmed'
+            else:
+                confirm_tool_state = 'none'
+        # Once tools are passed in, after a few rounds of model output we can already confirm that no tool call is coming.
+        if confirm_tool_state == 'un_confirm' and max_confirm_tool_length < len(delta_text):
+            confirm_tool_state = 'none'
 
         # Tool-call response
         if is_function_call:
             if not has_send_first_chunk:
@@ -524,7 +536,7 @@ async def predict_stream(model_id, gen_params):
                 yield chunk.model_dump_json(exclude_unset=True)
 
         # The user requested a Function Call, but the framework has not yet determined whether this is one.
-        elif (gen_params["tools"] and gen_params["tool_choice"] != "none") or is_function_call:
+        elif confirm_tool_state == 'un_confirm':
             continue
 
         # Regular response
@@ -552,6 +564,29 @@ async def predict_stream(model_id, gen_params):
                 yield chunk.model_dump_json(exclude_unset=True)
                 has_send_first_chunk = True
 
+            for text in delta_confirming_texts:
+                message = DeltaMessage(
+                    content=text,
+                    role="assistant",
+                    function_call=None,
+                )
+                choice_data = ChatCompletionResponseStreamChoice(
+                    index=0,
+                    delta=message,
+                    finish_reason=finish_reason
+                )
+                chunk = ChatCompletionResponse(
+                    model=model_id,
+                    id=response_id,
+                    choices=[choice_data],
+                    created=created_time,
+                    system_fingerprint=system_fingerprint,
+                    object="chat.completion.chunk"
+                )
+                yield chunk.model_dump_json(exclude_unset=True)
+            delta_confirming_texts = []
+            delta_text = ""
+
             message = DeltaMessage(
                 content=delta_text,
                 role="assistant",
@@ -613,7 +648,7 @@ async def predict_stream(model_id, gen_params):
                 object="chat.completion.chunk"
             )
             yield chunk.model_dump_json(exclude_unset=True)
-
+            finish_reason = 'stop'
             message = DeltaMessage(
                 content=delta_text,
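
Taken together, the patch replaces the old one-shot `is_function_call` check with a small state machine: `confirm_tool_state` starts at `'un_confirm'` whenever tools are supplied, moves to `'confirmed'` once the first output line matches a tool name and the second line opens a JSON object, and falls back to `'none'` as soon as more text has streamed than the longest tool-name header could occupy; deltas buffered in `delta_confirming_texts` are replayed if the output turns out to be a regular response. The sketch below is a minimal, self-contained illustration of that technique under hypothetical names (`detect_tool_call`, `stream_with_buffer`); it mirrors the patch's logic but is not the server code itself.

```python
import asyncio


def detect_tool_call(tools: set, streamed_text: str, state: str):
    """Return (new_state, tool_name) for the text streamed so far.

    States: 'un_confirm' (undecided), 'confirmed' (tool call), 'none' (plain text).
    """
    if state != 'un_confirm':
        return state, None
    # Longest tool name + up to 3 extra characters (leading "\n", trailing "\n{").
    max_confirm_tool_length = len(max(tools, key=len)) + 3 if tools else 0
    lines = streamed_text.strip().split("\n")
    if len(lines) >= 2 and lines[1].startswith("{"):
        first_line = lines[0].strip()
        if first_line in tools:
            return 'confirmed', first_line
        return 'none', None
    # More text than any tool-name header could occupy: no tool call is coming.
    if len(streamed_text) > max_confirm_tool_length:
        return 'none', None
    return 'un_confirm', None


async def stream_with_buffer(pieces, tools):
    """Hold deltas while a tool call is still possible, then either emit a
    single tool-call event or replay the buffered text as content chunks."""
    state = 'un_confirm' if tools else 'none'
    buffered, text = [], ""
    for piece in pieces:
        text += piece
        if state == 'un_confirm':
            buffered.append(piece)
            state, name = detect_tool_call(tools, text, state)
            if state == 'confirmed':
                yield {'function_call': {'name': name, 'arguments': text.split("\n", 1)[1]}}
                return
            if state == 'un_confirm':
                continue  # still undecided: keep buffering, emit nothing yet
            for held in buffered:  # resolved to plain text: flush the backlog
                yield {'content': held}
            buffered.clear()
        else:
            yield {'content': piece}


async def main():
    tools = {"get_weather", "search_web"}
    # Plain answer: early chunks are buffered, then flushed once the text is
    # longer than any possible tool-name header.
    async for chunk in stream_with_buffer(["Hello", ", this is", " a plain streamed answer."], tools):
        print(chunk)
    # Tool call: emitted as a single function_call event.
    async for chunk in stream_with_buffer(['get_weather\n{"city": "Beijing"}'], tools):
        print(chunk)

asyncio.run(main())
```

Note the trade-off the buffer introduces: when tools are supplied, the first few deltas of a regular answer reach the client only after the state resolves, which is bounded by `max_confirm_tool_length` characters of latency.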