diff --git a/Sources/Conduit/ChatSession.swift b/Sources/Conduit/ChatSession.swift index 31c2ee2..5436137 100644 --- a/Sources/Conduit/ChatSession.swift +++ b/Sources/Conduit/ChatSession.swift @@ -273,11 +273,17 @@ public final class ChatSession: @unchecked /// is `.none`, preserving single-attempt behavior. public var toolCallRetryPolicy: ToolExecutor.RetryPolicy = .none - /// Maximum number of tool-call rounds allowed in a single `send(_:)` request. + /// Maximum number of tool-call rounds allowed in a single `send(_:)` or `stream(_:)` request. /// /// A "round" is one model response containing at least one tool call followed by /// executing those calls. This bounds continuation loops and prevents runaway cycles. /// + /// The limit is checked **before** executing each round: + /// - `maxToolCallRounds = 0`: no tool calls are executed; the first tool-call response + /// throws `AIError.invalidInput` immediately without running any tools. + /// - `maxToolCallRounds = N` (N > 0): exactly **N** rounds are permitted. The (N+1)th + /// tool-call response throws `AIError.invalidInput`. + /// /// Values less than zero are treated as zero during execution. 
public var maxToolCallRounds: Int = 8 @@ -658,17 +664,20 @@ public final class ChatSession: @unchecked let userMessage = Message.user(content) // Prepare state and capture messages under lock - let currentMessages: [Message] = withLock { + let capturedState: (messages: [Message], toolExecutor: ToolExecutor?, toolCallRetryPolicy: ToolExecutor.RetryPolicy, maxToolCallRounds: Int) = withLock { lastError = nil isGenerating = true cancellationRequested = false messages.append(userMessage) - return messages + return (messages, toolExecutor, toolCallRetryPolicy, max(0, maxToolCallRounds)) } // Capture model and config for the async operation let currentModel = model let currentConfig = config + let currentToolExecutor = capturedState.toolExecutor + let currentToolCallRetryPolicy = capturedState.toolCallRetryPolicy + let currentMaxToolCallRounds = capturedState.maxToolCallRounds return AsyncThrowingStream { continuation in let task = Task { [weak self] in @@ -677,27 +686,90 @@ public final class ChatSession: @unchecked return } - var fullText = "" var streamError: Error? 
do { - // Get the stream from provider using streamWithMetadata - // which accepts messages array - let providerStream = self.provider.streamWithMetadata( - messages: currentMessages, - model: currentModel, - config: currentConfig - ) + var loopMessages = capturedState.messages + var turnMessages: [Message] = [] + var toolRoundCount = 0 - // Iterate and yield tokens - for try await chunk in providerStream { - // Check for cancellation + while true { try Task.checkCancellation() + try self.throwIfCancelled() + + var roundText = "" + var completedToolCalls: [Transcript.ToolCall] = [] + + let providerStream = self.provider.streamWithMetadata( + messages: loopMessages, + model: currentModel, + config: currentConfig + ) + + for try await chunk in providerStream { + try Task.checkCancellation() + try self.throwIfCancelled() - // Yield the token (text is non-optional in GenerationChunk) - let tokenText = chunk.text - continuation.yield(tokenText) - fullText += tokenText + if !chunk.text.isEmpty { + continuation.yield(chunk.text) + roundText += chunk.text + } + + if let toolCalls = chunk.completedToolCalls, !toolCalls.isEmpty { + completedToolCalls = toolCalls + } + } + + let assistantMessage = Message( + role: .assistant, + content: .text(roundText), + metadata: MessageMetadata( + model: currentModel.rawValue, + toolCalls: completedToolCalls.isEmpty ? nil : completedToolCalls + ) + ) + + turnMessages.append(assistantMessage) + loopMessages.append(assistantMessage) + + guard !completedToolCalls.isEmpty else { + break + } + + guard toolRoundCount < currentMaxToolCallRounds else { + throw AIError.invalidInput( + "Tool-call loop exceeded maxToolCallRounds (\(currentMaxToolCallRounds))." + ) + } + + guard let currentToolExecutor else { + throw AIError.invalidInput( + "Tool calls were requested but ChatSession.toolExecutor is nil." 
+ ) + } + + let toolOutputs = try await currentToolExecutor.execute( + toolCalls: completedToolCalls, + retryPolicy: currentToolCallRetryPolicy + ) + + try Task.checkCancellation() + try self.throwIfCancelled() + + for output in toolOutputs { + let toolMessage = Message.toolOutput(output) + turnMessages.append(toolMessage) + loopMessages.append(toolMessage) + } + + toolRoundCount += 1 + } + + // Finalize state under lock on success + self.withLock { + self.messages.append(contentsOf: turnMessages) + self.isGenerating = false + self.cancellationRequested = false } } catch is CancellationError { @@ -707,24 +779,16 @@ public final class ChatSession: @unchecked streamError = error } - // Finalize state under lock - self.withLock { - if let error = streamError { - // Remove user message on error + // Finalize error state under lock + if let error = streamError { + self.withLock { if let index = self.messages.lastIndex(where: { $0.id == userMessage.id }) { self.messages.remove(at: index) } self.lastError = error - } else { - // Add assistant message on success - let assistantMessage = Message( - role: .assistant, - content: .text(fullText), - metadata: MessageMetadata(model: currentModel.rawValue) - ) - self.messages.append(assistantMessage) + self.isGenerating = false + self.cancellationRequested = false } - self.isGenerating = false } // Finish the stream diff --git a/Sources/Conduit/Core/Streaming/GenerationChunk.swift b/Sources/Conduit/Core/Streaming/GenerationChunk.swift index ea0dc2e..ba7edde 100644 --- a/Sources/Conduit/Core/Streaming/GenerationChunk.swift +++ b/Sources/Conduit/Core/Streaming/GenerationChunk.swift @@ -106,7 +106,7 @@ public struct PartialToolCall: Sendable, Hashable { /// /// - Precondition: `id` must not be empty. /// - Precondition: `toolName` must not be empty. - /// - Precondition: `index` must be in range `0...maxToolCallIndex` (0...100). + /// - Precondition: `index` must be in range `0...maxToolCallIndex`. 
public init(id: String, toolName: String, index: Int, argumentsFragment: String) { precondition(!id.isEmpty, "PartialToolCall id must not be empty") precondition(!toolName.isEmpty, "PartialToolCall toolName must not be empty") @@ -114,12 +114,44 @@ public struct PartialToolCall: Sendable, Hashable { (0...maxToolCallIndex).contains(index), "PartialToolCall index must be in range 0...\(maxToolCallIndex), got \(index)" ) - self.id = id self.toolName = toolName self.index = index self.argumentsFragment = argumentsFragment } + + /// Creates a validated partial tool call, returning an error for invalid inputs. + /// + /// Use this factory in streaming code where server responses may contain malformed + /// data such as empty IDs, missing tool names, or out-of-range indices. Unlike the + /// standard `init`, this factory throws instead of trapping on invalid input. + /// + /// - Parameters: + /// - id: Unique identifier for this tool call. Must not be empty. + /// - toolName: Name of the tool being called. Must not be empty. + /// - index: Index of this tool call in the response. Must be in range `0...maxToolCallIndex`. + /// - argumentsFragment: Current accumulated arguments JSON fragment. + /// + /// - Throws: `AIError.invalidInput` when inputs are invalid. 
+ public static func validated( + id: String, + toolName: String, + index: Int, + argumentsFragment: String + ) throws -> PartialToolCall { + guard !id.isEmpty else { + throw AIError.invalidInput("PartialToolCall id must not be empty") + } + guard !toolName.isEmpty else { + throw AIError.invalidInput("PartialToolCall toolName must not be empty") + } + guard (0...maxToolCallIndex).contains(index) else { + throw AIError.invalidInput( + "PartialToolCall index must be in range 0...\(maxToolCallIndex), got \(index)" + ) + } + return PartialToolCall(id: id, toolName: toolName, index: index, argumentsFragment: argumentsFragment) + } } // MARK: - GenerationChunk diff --git a/Sources/Conduit/ModelManagement/ModelCache.swift b/Sources/Conduit/ModelManagement/ModelCache.swift index 78d5551..224fc76 100644 --- a/Sources/Conduit/ModelManagement/ModelCache.swift +++ b/Sources/Conduit/ModelManagement/ModelCache.swift @@ -186,13 +186,18 @@ public actor ModelCache { /// } /// ``` public func allCachedModels() -> [CachedModelInfo] { - cache.values.sorted { $0.lastAccessedAt > $1.lastAccessedAt } + // Take a snapshot before validation to avoid mutating `cache` while iterating its values. + let infos = Array(cache.values) + return infos + .filter { validateCachedEntry($0) } + .sorted { $0.lastAccessedAt > $1.lastAccessedAt } } /// Checks if a model is cached. /// /// This is a fast lookup that only checks the in-memory cache. - /// It does not verify that the files still exist on disk. + /// It validates that the files still exist on disk and prunes + /// stale entries if the OS has evicted the cache. /// /// - Parameter model: The model identifier to check. /// - Returns: `true` if the model is in the cache, `false` otherwise. 
@@ -206,7 +211,10 @@ public actor ModelCache { /// } /// ``` public func isCached(_ model: ModelIdentifier) -> Bool { - cache[model] != nil + guard let info = cache[model] else { + return false + } + return validateCachedEntry(info) } /// Gets the cached model info for a model. @@ -226,7 +234,10 @@ public actor ModelCache { /// } /// ``` public func info(for model: ModelIdentifier) -> CachedModelInfo? { - cache[model] + guard let info = cache[model] else { + return nil + } + return validateCachedEntry(info) ? info : nil } /// Gets the local file path for a cached model. @@ -245,7 +256,23 @@ public actor ModelCache { /// } /// ``` public func localPath(for model: ModelIdentifier) -> URL? { - cache[model]?.path + guard let info = cache[model] else { + return nil + } + return validateCachedEntry(info) ? info.path : nil + } + + /// Validates a cached entry exists on disk and prunes stale metadata. + /// + /// Returns `true` when the cache entry still exists on disk. + private func validateCachedEntry(_ info: CachedModelInfo) -> Bool { + let fileManager = FileManager.default + guard fileManager.fileExists(atPath: info.path.path) else { + cache.removeValue(forKey: info.identifier) + try? saveMetadata() + return false + } + return true } /// Returns the total size of all cached models. @@ -265,7 +292,10 @@ public actor ModelCache { /// } /// ``` public func totalSize() -> ByteCount { - let totalBytes = cache.values.reduce(0) { $0 + $1.size.bytes } + let infos = Array(cache.values) + let totalBytes = infos + .filter { validateCachedEntry($0) } + .reduce(0) { $0 + $1.size.bytes } return ByteCount(totalBytes) } @@ -534,15 +564,19 @@ extension ModelCache { /// Returns the number of cached models. /// - /// - Returns: The count of models in the cache. + /// Stale entries (files removed from disk) are pruned and excluded from the count. + /// + /// - Returns: The count of valid models in the cache. 
public var count: Int { - cache.count + Array(cache.values).filter { validateCachedEntry($0) }.count } /// Whether the cache is empty. /// + /// Returns `true` only when no valid (on-disk) entries remain. + /// /// - Returns: `true` if no models are cached, `false` otherwise. public var isEmpty: Bool { - cache.isEmpty + count == 0 } } diff --git a/Sources/Conduit/ModelManagement/ModelManager.swift b/Sources/Conduit/ModelManagement/ModelManager.swift index 7018d62..c4fa8d5 100644 --- a/Sources/Conduit/ModelManagement/ModelManager.swift +++ b/Sources/Conduit/ModelManagement/ModelManager.swift @@ -692,40 +692,53 @@ extension ModelManager { // Create speed calculator for this download let speedCalculator = SpeedCalculator() + var speedTask: Task<Void, Never>? + var speedContinuation: AsyncStream<DownloadProgress>.Continuation? - // Wrap progress callback with size and speed enrichment - let enrichedProgress: (@Sendable (DownloadProgress) -> Void)? = progress.map { callback in - { @Sendable (downloadProgress: DownloadProgress) in - var enriched = downloadProgress - - // Set total bytes from estimation if not provided - if enriched.totalBytes == nil { - enriched.totalBytes = estimatedSize?.bytes - } - - // Call callback immediately with basic progress - callback(enriched) - - // Asynchronously update speed in background (non-blocking) - Task { - await speedCalculator.addSample(bytes: downloadProgress.bytesDownloaded) + if let progress { + let speedStream = AsyncStream<DownloadProgress> { continuation in + speedContinuation = continuation + } + speedTask = Task { + for await update in speedStream { + await speedCalculator.addSample(bytes: update.bytesDownloaded) + var updated = update if let speed = await speedCalculator.averageSpeed() { - var updated = enriched updated.bytesPerSecond = speed - - // Calculate ETA if let total = updated.totalBytes, speed > 0 { let remaining = total - updated.bytesDownloaded updated.estimatedTimeRemaining = TimeInterval(remaining) / speed } - - // Send updated progress with speed info - 
callback(updated) } + // Always fire the callback once per tick (with or without speed info) + progress(updated) } } } + // Capture continuation for use in the @Sendable enrichedProgress closure below + let capturedContinuation = speedContinuation + + defer { + capturedContinuation?.finish() + speedTask?.cancel() + } + + // Wrap progress callback with estimated size enrichment before feeding the speed stream + let enrichedProgress: (@Sendable (DownloadProgress) -> Void)? = progress.map { _ in + { @Sendable (downloadProgress: DownloadProgress) in + var enriched = downloadProgress + + // Set total bytes from estimation if not provided + if enriched.totalBytes == nil { + enriched.totalBytes = estimatedSize?.bytes + } + + // Feed the speed stream, which calls the user callback exactly once per event + capturedContinuation?.yield(enriched) + } + } + return try await download(model, progress: enrichedProgress) } diff --git a/Sources/Conduit/Providers/OpenAI/OpenAIProvider+Streaming.swift b/Sources/Conduit/Providers/OpenAI/OpenAIProvider+Streaming.swift index dfd8d50..f2ec442 100644 --- a/Sources/Conduit/Providers/OpenAI/OpenAIProvider+Streaming.swift +++ b/Sources/Conduit/Providers/OpenAI/OpenAIProvider+Streaming.swift @@ -347,12 +347,18 @@ extension OpenAIProvider { // Create partial tool call for streaming updates if let acc = toolCallAccumulators[index] { - partialToolCall = PartialToolCall( - id: acc.id, - toolName: acc.name, - index: index, - argumentsFragment: acc.argumentsBuffer - ) + do { + partialToolCall = try PartialToolCall.validated( + id: acc.id, + toolName: acc.name, + index: index, + argumentsFragment: acc.argumentsBuffer + ) + } catch { + logger.error( + "Skipping partial tool call '\(acc.name)' at index \(index): \(error.localizedDescription)" + ) + } } } } @@ -732,18 +738,25 @@ extension OpenAIProvider { toolAccumulatorsByID[callID] = accumulator - continuation.yield(GenerationChunk( - text: "", - tokenCount: 0, - isComplete: false, - partialToolCall: 
PartialToolCall( + do { + let partialToolCall = try PartialToolCall.validated( id: accumulator.id, toolName: accumulator.name, index: accumulator.index, argumentsFragment: accumulator.argumentsBuffer - ), - reasoningDetails: currentReasoningDetails() - )) + ) + continuation.yield(GenerationChunk( + text: "", + tokenCount: 0, + isComplete: false, + partialToolCall: partialToolCall, + reasoningDetails: currentReasoningDetails() + )) + } catch { + logger.error( + "Skipping partial tool call '\(accumulator.name)' at index \(accumulator.index): \(error.localizedDescription)" + ) + } case .completed: let completedToolCalls = finalizeToolCalls() diff --git a/Tests/ConduitTests/ChatSessionTests.swift b/Tests/ConduitTests/ChatSessionTests.swift index fe22b51..e14b169 100644 --- a/Tests/ConduitTests/ChatSessionTests.swift +++ b/Tests/ConduitTests/ChatSessionTests.swift @@ -38,6 +38,12 @@ actor MockTextProvider: AIProvider, @preconcurrency TextGenerator { /// Optional artificial delay per generate call for cancellation tests. private var _generationDelayNanos: UInt64 = 0 + /// Queue of pre-built chunk arrays to return from successive stream calls. + /// + /// Each inner array is one complete stream response. When non-empty, the + /// next `stream(messages:model:config:)` call dequeues from the front. 
+ private var _queuedStreamChunkSets: [[GenerationChunk]] = [] + // MARK: - Accessors for Test Assertions var responseToReturn: String { @@ -73,6 +79,10 @@ actor MockTextProvider: AIProvider, @preconcurrency TextGenerator { _generationDelayNanos = nanoseconds } + func setQueuedStreamChunkSets(_ chunkSets: [[GenerationChunk]]) { + _queuedStreamChunkSets = chunkSets + } + // MARK: - AIProvider var isAvailable: Bool { true } @@ -116,16 +126,30 @@ actor MockTextProvider: AIProvider, @preconcurrency TextGenerator { model: ModelID, config: GenerateConfig ) -> AsyncThrowingStream<GenerationChunk, Error> { + _generateCallCount += 1 _lastReceivedMessages = messages - let responseText = _responseToReturn - let throwError = _shouldThrowError + _receivedMessagesByGenerateCall.append(messages) - return AsyncThrowingStream { continuation in - if throwError { + if _shouldThrowError { + return AsyncThrowingStream { continuation in continuation.finish(throwing: MockError.simulatedFailure) - return } + } + + if !_queuedStreamChunkSets.isEmpty { + let chunks = _queuedStreamChunkSets.removeFirst() + return AsyncThrowingStream { continuation in + Task { + for chunk in chunks { + continuation.yield(chunk) + } + continuation.finish() + } + } + } + let responseText = _responseToReturn + return AsyncThrowingStream { continuation in let words = responseText.split(separator: " ") Task { for (index, word) in words.enumerated() { @@ -208,6 +232,7 @@ actor MockTextProvider: AIProvider, @preconcurrency TextGenerator { _generateCallCount = 0 _generationDelayNanos = 0 _queuedGenerationResults = [] + _queuedStreamChunkSets = [] _receivedMessagesByGenerateCall = [] } } @@ -1043,4 +1068,186 @@ struct ChatSessionTests { let callCount = await provider.generateCallCount #expect(callCount == 0) } + + // MARK: - Streaming Tool-Call Loop Tests + + @Test("stream executes tool calls then yields final answer") + func streamExecutesToolCallsThenYieldsFinalAnswer() async throws { + let provider = MockTextProvider() + let session = try 
await ChatSession(provider: provider, model: .llama3_2_1b) + + let toolCall = try Transcript.ToolCall( + id: "stream_tool_1", + toolName: "session_echo_tool", + argumentsJSON: #"{"input":"Paris"}"# + ) + + // Round 1: assistant requests a tool call + // Round 2: assistant responds with the final answer + await provider.setQueuedStreamChunkSets([ + [ + GenerationChunk(text: "Checking weather", isComplete: false), + GenerationChunk( + text: "", + tokenCount: 0, + isComplete: true, + finishReason: .toolCalls, + completedToolCalls: [toolCall] + ) + ], + [ + GenerationChunk(text: "Weather is Echo: Paris", isComplete: false), + GenerationChunk(text: "", tokenCount: 0, isComplete: true, finishReason: .stop) + ] + ]) + + session.toolExecutor = ToolExecutor(tools: [SessionEchoTool()]) + + var tokens: [String] = [] + for try await token in session.stream("What's the weather?") { + tokens.append(token) + } + + // Tokens from both rounds should have been yielded + #expect(tokens.contains("Checking weather")) + #expect(tokens.contains("Weather is Echo: Paris")) + + // Final message history: user, assistant (with tool call), tool output, final assistant + #expect(session.messages.count == 4) + #expect(session.messages[0].role == .user) + #expect(session.messages[1].role == .assistant) + #expect(session.messages[1].metadata?.toolCalls?.count == 1) + #expect(session.messages[2].role == .tool) + #expect(session.messages[2].content.textValue == "Echo: Paris") + #expect(session.messages[3].role == .assistant) + #expect(session.messages[3].content.textValue == "Weather is Echo: Paris") + + let callCount = await provider.generateCallCount + #expect(callCount == 2) + + let receivedByCall = await provider.receivedMessagesByGenerateCall + #expect(receivedByCall.count == 2) + // Second stream call must have received the tool output message + #expect( + receivedByCall[1].contains(where: { $0.role == .tool && $0.content.textValue == "Echo: Paris" }) + ) + } + + @Test("stream throws when tool 
loop exceeds maxToolCallRounds") + func streamThrowsWhenToolLoopExceedsMaxRounds() async throws { + let provider = MockTextProvider() + let session = try await ChatSession(provider: provider, model: .llama3_2_1b) + + let toolCall1 = try Transcript.ToolCall( + id: "stream_loop_1", + toolName: "session_echo_tool", + argumentsJSON: #"{"input":"one"}"# + ) + let toolCall2 = try Transcript.ToolCall( + id: "stream_loop_2", + toolName: "session_echo_tool", + argumentsJSON: #"{"input":"two"}"# + ) + + // Both stream results request tool calls; with maxToolCallRounds = 1, + // the second tool-call response should trigger the overflow error. + await provider.setQueuedStreamChunkSets([ + [GenerationChunk( + text: "", tokenCount: 0, isComplete: true, + finishReason: .toolCalls, completedToolCalls: [toolCall1] + )], + [GenerationChunk( + text: "", tokenCount: 0, isComplete: true, + finishReason: .toolCalls, completedToolCalls: [toolCall2] + )] + ]) + + session.toolExecutor = ToolExecutor(tools: [SessionEchoTool()]) + session.maxToolCallRounds = 1 + + await #expect(throws: AIError.self) { + for try await _ in session.stream("Trigger loop") {} + } + + // User message must be rolled back on error + #expect(session.messages.isEmpty) + #expect(session.isGenerating == false) + #expect(session.lastError != nil) + + guard let aiError = session.lastError as? 
AIError else { + Issue.record("Expected AIError for loop limit failure") + return + } + guard case .invalidInput(let message) = aiError else { + Issue.record("Expected AIError.invalidInput for loop limit failure") + return + } + #expect(message.contains("maxToolCallRounds")) + } + + @Test("stream throws when toolExecutor is nil and tool calls are returned") + func streamThrowsWhenToolExecutorNilAndToolCallsReturned() async throws { + let provider = MockTextProvider() + let session = try await ChatSession(provider: provider, model: .llama3_2_1b) + + let toolCall = try Transcript.ToolCall( + id: "no_executor_tool", + toolName: "session_echo_tool", + argumentsJSON: #"{"input":"test"}"# + ) + + await provider.setQueuedStreamChunkSets([ + [GenerationChunk( + text: "", tokenCount: 0, isComplete: true, + finishReason: .toolCalls, completedToolCalls: [toolCall] + )] + ]) + + // Intentionally no toolExecutor set + + await #expect(throws: AIError.self) { + for try await _ in session.stream("Request that returns tool calls") {} + } + + // User message must be rolled back on error + #expect(session.messages.isEmpty) + #expect(session.isGenerating == false) + #expect(session.lastError != nil) + } + + @Test("stream with maxToolCallRounds = 0 throws without executing any tool calls") + func streamMaxToolCallRoundsZeroThrowsImmediately() async throws { + let provider = MockTextProvider() + let session = try await ChatSession(provider: provider, model: .llama3_2_1b) + + let toolCall = try Transcript.ToolCall( + id: "zero_rounds_tool", + toolName: "session_echo_tool", + argumentsJSON: #"{"input":"test"}"# + ) + + await provider.setQueuedStreamChunkSets([ + [GenerationChunk( + text: "", tokenCount: 0, isComplete: true, + finishReason: .toolCalls, completedToolCalls: [toolCall] + )] + ]) + + session.toolExecutor = ToolExecutor(tools: [SessionEchoTool()]) + session.maxToolCallRounds = 0 + + await #expect(throws: AIError.self) { + for try await _ in session.stream("Trigger zero-rounds") 
{} + } + + // No tool execution should have happened; messages rolled back + #expect(session.messages.isEmpty) + + guard let aiError = session.lastError as? AIError, + case .invalidInput(let message) = aiError else { + Issue.record("Expected AIError.invalidInput for zero-rounds failure") + return + } + #expect(message.contains("maxToolCallRounds")) + } }