Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 95 additions & 31 deletions Sources/Conduit/ChatSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,17 @@ public final class ChatSession<Provider: AIProvider & TextGenerator>: @unchecked
/// is `.none`, preserving single-attempt behavior.
public var toolCallRetryPolicy: ToolExecutor.RetryPolicy = .none

/// Maximum number of tool-call rounds allowed in a single `send(_:)` or `stream(_:)` request.
///
/// A "round" is one model response containing at least one tool call followed by
/// executing those calls. This bounds continuation loops and prevents runaway cycles.
///
/// The limit is checked **before** executing each round:
/// - `maxToolCallRounds = 0`: no tool calls are executed; the first tool-call response
/// throws `AIError.invalidInput` immediately without running any tools.
/// - `maxToolCallRounds = N` (N > 0): exactly **N** rounds are permitted. The (N+1)th
/// tool-call response throws `AIError.invalidInput`.
///
/// Values less than zero are treated as zero during execution.
public var maxToolCallRounds: Int = 8

Expand Down Expand Up @@ -658,17 +664,20 @@ public final class ChatSession<Provider: AIProvider & TextGenerator>: @unchecked
let userMessage = Message.user(content)

// Prepare state and capture messages under lock
let currentMessages: [Message] = withLock {
let capturedState: (messages: [Message], toolExecutor: ToolExecutor?, toolCallRetryPolicy: ToolExecutor.RetryPolicy, maxToolCallRounds: Int) = withLock {
lastError = nil
isGenerating = true
cancellationRequested = false
messages.append(userMessage)
return messages
return (messages, toolExecutor, toolCallRetryPolicy, max(0, maxToolCallRounds))
}

// Capture model and config for the async operation
let currentModel = model
let currentConfig = config
let currentToolExecutor = capturedState.toolExecutor
let currentToolCallRetryPolicy = capturedState.toolCallRetryPolicy
let currentMaxToolCallRounds = capturedState.maxToolCallRounds

return AsyncThrowingStream { continuation in
let task = Task { [weak self] in
Expand All @@ -677,27 +686,90 @@ public final class ChatSession<Provider: AIProvider & TextGenerator>: @unchecked
return
}

var fullText = ""
var streamError: Error?

do {
// Get the stream from provider using streamWithMetadata
// which accepts messages array
let providerStream = self.provider.streamWithMetadata(
messages: currentMessages,
model: currentModel,
config: currentConfig
)
var loopMessages = capturedState.messages
var turnMessages: [Message] = []
var toolRoundCount = 0

// Iterate and yield tokens
for try await chunk in providerStream {
// Check for cancellation
while true {
try Task.checkCancellation()
try self.throwIfCancelled()

var roundText = ""
var completedToolCalls: [Transcript.ToolCall] = []

let providerStream = self.provider.streamWithMetadata(
messages: loopMessages,
model: currentModel,
config: currentConfig
)

for try await chunk in providerStream {
try Task.checkCancellation()
try self.throwIfCancelled()

// Yield the token (text is non-optional in GenerationChunk)
let tokenText = chunk.text
continuation.yield(tokenText)
fullText += tokenText
if !chunk.text.isEmpty {
continuation.yield(chunk.text)
roundText += chunk.text
}

if let toolCalls = chunk.completedToolCalls, !toolCalls.isEmpty {
completedToolCalls = toolCalls
}
}

let assistantMessage = Message(
role: .assistant,
content: .text(roundText),
metadata: MessageMetadata(
model: currentModel.rawValue,
toolCalls: completedToolCalls.isEmpty ? nil : completedToolCalls
)
)

turnMessages.append(assistantMessage)
loopMessages.append(assistantMessage)

guard !completedToolCalls.isEmpty else {
break
}

guard toolRoundCount < currentMaxToolCallRounds else {
throw AIError.invalidInput(
"Tool-call loop exceeded maxToolCallRounds (\(currentMaxToolCallRounds))."
)
}

guard let currentToolExecutor else {
throw AIError.invalidInput(
"Tool calls were requested but ChatSession.toolExecutor is nil."
)
}

let toolOutputs = try await currentToolExecutor.execute(
toolCalls: completedToolCalls,
retryPolicy: currentToolCallRetryPolicy
)

try Task.checkCancellation()
try self.throwIfCancelled()

for output in toolOutputs {
let toolMessage = Message.toolOutput(output)
turnMessages.append(toolMessage)
loopMessages.append(toolMessage)
}

toolRoundCount += 1
}

// Finalize state under lock on success
self.withLock {
self.messages.append(contentsOf: turnMessages)
self.isGenerating = false
self.cancellationRequested = false
}

} catch is CancellationError {
Expand All @@ -707,24 +779,16 @@ public final class ChatSession<Provider: AIProvider & TextGenerator>: @unchecked
streamError = error
}

// Finalize state under lock
self.withLock {
if let error = streamError {
// Remove user message on error
// Finalize error state under lock
if let error = streamError {
self.withLock {
if let index = self.messages.lastIndex(where: { $0.id == userMessage.id }) {
self.messages.remove(at: index)
}
self.lastError = error
} else {
// Add assistant message on success
let assistantMessage = Message(
role: .assistant,
content: .text(fullText),
metadata: MessageMetadata(model: currentModel.rawValue)
)
self.messages.append(assistantMessage)
self.isGenerating = false
self.cancellationRequested = false
}
self.isGenerating = false
}

// Finish the stream
Expand Down
36 changes: 34 additions & 2 deletions Sources/Conduit/Core/Streaming/GenerationChunk.swift
Original file line number Diff line number Diff line change
Expand Up @@ -106,20 +106,52 @@ public struct PartialToolCall: Sendable, Hashable {
///
/// - Precondition: `id` must not be empty.
/// - Precondition: `toolName` must not be empty.
/// - Precondition: `index` must be in range `0...maxToolCallIndex`.
public init(id: String, toolName: String, index: Int, argumentsFragment: String) {
    // Trap on malformed construction: these invariants are programmer
    // contracts; use `validated(...)` for untrusted server data instead.
    precondition(!id.isEmpty, "PartialToolCall id must not be empty")
    precondition(!toolName.isEmpty, "PartialToolCall toolName must not be empty")
    precondition(
        index >= 0 && index <= maxToolCallIndex,
        "PartialToolCall index must be in range 0...\(maxToolCallIndex), got \(index)"
    )

    self.argumentsFragment = argumentsFragment
    self.index = index
    self.toolName = toolName
    self.id = id
}

/// Builds a `PartialToolCall` after checking the same invariants as `init`,
/// throwing instead of trapping when they are violated.
///
/// Prefer this factory when parsing streamed provider output, where an empty
/// id, a missing tool name, or an out-of-range index can arrive from the
/// server and should surface as a recoverable error rather than a crash.
///
/// - Parameters:
///   - id: Unique identifier for this tool call; must be non-empty.
///   - toolName: Name of the tool being invoked; must be non-empty.
///   - index: Position of this call in the response; must lie in
///     `0...maxToolCallIndex`.
///   - argumentsFragment: Accumulated arguments JSON fragment so far.
/// - Throws: `AIError.invalidInput` describing the first failed check.
/// - Returns: A fully validated `PartialToolCall`.
public static func validated(
    id: String,
    toolName: String,
    index: Int,
    argumentsFragment: String
) throws -> PartialToolCall {
    if id.isEmpty {
        throw AIError.invalidInput("PartialToolCall id must not be empty")
    }
    if toolName.isEmpty {
        throw AIError.invalidInput("PartialToolCall toolName must not be empty")
    }
    if index < 0 || index > maxToolCallIndex {
        throw AIError.invalidInput(
            "PartialToolCall index must be in range 0...\(maxToolCallIndex), got \(index)"
        )
    }
    return PartialToolCall(
        id: id,
        toolName: toolName,
        index: index,
        argumentsFragment: argumentsFragment
    )
}
}

// MARK: - GenerationChunk
Expand Down
52 changes: 43 additions & 9 deletions Sources/Conduit/ModelManagement/ModelCache.swift
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,18 @@ public actor ModelCache {
/// }
/// ```
public func allCachedModels() -> [CachedModelInfo] {
    // Snapshot the values first: `validateCachedEntry` removes stale
    // entries from `cache`, and mutating the dictionary while iterating
    // its `values` view directly would be unsafe.
    let snapshot = Array(cache.values)
    return snapshot
        .filter { validateCachedEntry($0) }
        .sorted { $0.lastAccessedAt > $1.lastAccessedAt }
}

/// Checks if a model is cached.
///
/// This is a fast lookup that only checks the in-memory cache.
/// It does not verify that the files still exist on disk.
/// It validates that the files still exist on disk and prunes
/// stale entries if the OS has evicted the cache.
///
/// - Parameter model: The model identifier to check.
/// - Returns: `true` if the model is in the cache, `false` otherwise.
Expand All @@ -206,7 +211,10 @@ public actor ModelCache {
/// }
/// ```
public func isCached(_ model: ModelIdentifier) -> Bool {
    guard let info = cache[model] else {
        return false
    }
    // Confirm the backing files still exist on disk; prunes the entry
    // (and rewrites metadata) if the OS has evicted the cache.
    return validateCachedEntry(info)
}

/// Gets the cached model info for a model.
Expand All @@ -226,7 +234,10 @@ public actor ModelCache {
/// }
/// ```
public func info(for model: ModelIdentifier) -> CachedModelInfo? {
    guard let info = cache[model] else {
        return nil
    }
    // Treat entries whose files vanished from disk as absent; the
    // validation call prunes such stale metadata as a side effect.
    return validateCachedEntry(info) ? info : nil
}

/// Gets the local file path for a cached model.
Expand All @@ -245,7 +256,23 @@ public actor ModelCache {
/// }
/// ```
public func localPath(for model: ModelIdentifier) -> URL? {
    guard let info = cache[model] else {
        return nil
    }
    // Only hand back a path that still exists on disk; stale entries
    // are pruned by the validation call.
    return validateCachedEntry(info) ? info.path : nil
}

/// Checks that a cached entry's files are still present on disk.
///
/// If the backing path no longer exists (for example, the OS purged the
/// cache directory), the entry is removed from `cache` and the metadata
/// is re-saved on a best-effort basis.
///
/// - Parameter info: The cached entry to validate.
/// - Returns: `true` if the entry's files exist on disk; `false` if the
///   entry was stale and has now been pruned.
private func validateCachedEntry(_ info: CachedModelInfo) -> Bool {
    if FileManager.default.fileExists(atPath: info.path.path) {
        return true
    }
    // Stale: the files are gone. Drop the entry and persist the pruned
    // metadata; a persistence failure is non-fatal here.
    cache.removeValue(forKey: info.identifier)
    try? saveMetadata()
    return false
}

/// Returns the total size of all cached models.
Expand All @@ -265,7 +292,10 @@ public actor ModelCache {
/// }
/// ```
public func totalSize() -> ByteCount {
    // Snapshot before validating: `validateCachedEntry` may mutate
    // `cache` while pruning stale entries.
    let snapshot = Array(cache.values)
    let totalBytes = snapshot
        .filter { validateCachedEntry($0) }
        .reduce(0) { $0 + $1.size.bytes }
    return ByteCount(totalBytes)
}

Expand Down Expand Up @@ -534,15 +564,19 @@ extension ModelCache {

/// Returns the number of cached models.
///
/// - Returns: The count of models in the cache.
/// Stale entries (files removed from disk) are pruned and excluded from the count.
///
/// - Returns: The count of valid models in the cache.
public var count: Int {
    // Snapshot before validating, since validation prunes stale entries
    // from `cache`. Only entries still on disk are counted.
    Array(cache.values).filter { validateCachedEntry($0) }.count
}

/// Whether the cache is empty.
///
/// Returns `true` only when no valid (on-disk) entries remain.
///
/// - Returns: `true` if no models are cached, `false` otherwise.
public var isEmpty: Bool {
    // Short-circuits on the first entry that is still valid on disk,
    // rather than validating (and pruning) every entry via `count`.
    !Array(cache.values).contains { validateCachedEntry($0) }
}
}
Loading
Loading