Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions Sources/Conduit/ChatSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -742,13 +742,14 @@ public final class ChatSession<Provider: AIProvider & TextGenerator>: @unchecked

// Handle stream cancellation
continuation.onTermination = { @Sendable [weak self] termination in
if case .cancelled = termination {
task.cancel()
guard case .cancelled = termination else {
self?.withLock { self?.generationTask = nil }
return
}
task.cancel()
guard let strongSelf = self else { return }
strongSelf.withLock {
strongSelf.generationTask = nil
}
strongSelf.withLock { strongSelf.generationTask = nil }
Task { await strongSelf.provider.cancelGeneration() }
}
}
}
Expand Down
25 changes: 10 additions & 15 deletions Sources/Conduit/Core/Streaming/GenerationChunk.swift
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ public struct PartialToolCall: Sendable, Hashable {
///
/// ## Valid Range
///
/// The index must be in the range `0...100` (see ``maxToolCallIndex``). This bound exists to:
/// The index is clamped to `0...maxToolCallIndex` (see ``maxToolCallIndex``). This bound exists to:
/// - Prevent unbounded memory allocation in streaming accumulators
/// - Provide defense against malformed server responses
/// - Ensure predictable behavior across all providers
///
/// Most real-world use cases involve indices 0-9, as models rarely invoke more than
/// 10 tools in parallel.
///
/// - Precondition: Must be in range `0...maxToolCallIndex` (0...100).
/// - Note: Values outside `0...maxToolCallIndex` are clamped rather than rejected.
/// - SeeAlso: ``maxToolCallIndex``
public let index: Int

Expand All @@ -104,20 +104,15 @@ public struct PartialToolCall: Sendable, Hashable {
/// - index: Index of this tool call in the response. Must be in range `0...maxToolCallIndex`.
/// - argumentsFragment: Current accumulated arguments JSON fragment.
///
/// - Precondition: `id` must not be empty.
/// - Precondition: `toolName` must not be empty.
/// - Precondition: `index` must be in range `0...maxToolCallIndex` (0...100).
/// - Note: Invalid values are sanitized to preserve non-crashing behavior when parsing
/// untrusted provider streaming data:
/// - Empty `id` becomes `"unknown_tool_call"`
/// - Empty `toolName` becomes `"unknown_tool"`
/// - `index` is clamped to `0...maxToolCallIndex`
public init(id: String, toolName: String, index: Int, argumentsFragment: String) {
precondition(!id.isEmpty, "PartialToolCall id must not be empty")
precondition(!toolName.isEmpty, "PartialToolCall toolName must not be empty")
precondition(
(0...maxToolCallIndex).contains(index),
"PartialToolCall index must be in range 0...\(maxToolCallIndex), got \(index)"
)

self.id = id
self.toolName = toolName
self.index = index
self.id = id.isEmpty ? "unknown_tool_call" : id
self.toolName = toolName.isEmpty ? "unknown_tool" : toolName
self.index = min(max(index, 0), maxToolCallIndex)
self.argumentsFragment = argumentsFragment
}
}
Expand Down
18 changes: 18 additions & 0 deletions Sources/Conduit/Core/Utilities/URLSessionAsyncBytes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ public struct AsyncLineSequence: AsyncSequence, Sendable {
// Check if we have a complete line in the buffer
if let newlineIndex = buffer.firstIndex(where: { $0 == UInt8(ascii: "\n") || $0 == UInt8(ascii: "\r") }) {
let delimiter = buffer[newlineIndex]

// Handle CRLF split across chunk boundaries by peeking one extra byte.
if delimiter == UInt8(ascii: "\r"), newlineIndex == buffer.count - 1 {
if let nextByte = try await bytesIterator.next() {
buffer.append(nextByte)
continue
}
}

let lineBytes = Array(buffer[..<newlineIndex])
var removeCount = newlineIndex + 1
if delimiter == UInt8(ascii: "\r"),
Expand Down Expand Up @@ -345,6 +354,15 @@ public struct AsyncLineSequence: AsyncSequence, Sendable {
// Check if we have a complete line in the buffer
if let newlineIndex = buffer.firstIndex(where: { $0 == UInt8(ascii: "\n") || $0 == UInt8(ascii: "\r") }) {
let delimiter = buffer[newlineIndex]

// Handle CRLF split across chunk boundaries by peeking one extra byte.
if delimiter == UInt8(ascii: "\r"), newlineIndex == buffer.count - 1 {
if let nextByte = try await bytesIterator.next() {
buffer.append(nextByte)
continue
}
}

let lineBytes = Array(buffer[..<newlineIndex])
var removeCount = newlineIndex + 1
if delimiter == UInt8(ascii: "\r"),
Expand Down
17 changes: 14 additions & 3 deletions Sources/Conduit/Providers/OpenAI/OpenAIProvider+Streaming.swift
Original file line number Diff line number Diff line change
Expand Up @@ -306,11 +306,11 @@ extension OpenAIProvider {
for tc in toolCalls {
guard let index = tc["index"] as? Int else { continue }

// Validate index is within reasonable bounds (0...100)
guard (0...100).contains(index) else {
// Validate index is within bounded range.
guard (0...maxToolCallIndex).contains(index) else {
let toolName = (tc["function"] as? [String: Any])?["name"] as? String ?? "unknown"
logger.warning(
"Skipping tool call '\(toolName)' with invalid index \(index) (must be 0...100)"
"Skipping tool call '\(toolName)' with invalid index \(index) (must be 0...\(maxToolCallIndex))"
)
continue
}
Expand Down Expand Up @@ -627,6 +627,7 @@ extension OpenAIProvider {
var sseParser = ServerSentEventParser()
var reasoningBuffer = ""
var toolAccumulatorsByID: [String: ResponsesToolAccumulator] = [:]
var skippedToolAccumulatorIDs: Set<String> = []
var nextToolIndex = 0

func finalizeToolCalls() -> [Transcript.ToolCall] {
Expand Down Expand Up @@ -705,6 +706,16 @@ extension OpenAIProvider {

case .toolCallCreated, .toolCallDelta:
guard let callID = decoded.toolCallID else { continue }
guard !skippedToolAccumulatorIDs.contains(callID) else { continue }

if toolAccumulatorsByID[callID] == nil, nextToolIndex > maxToolCallIndex {
skippedToolAccumulatorIDs.insert(callID)
logger.warning(
"Skipping tool call '\(callID)' because index exceeded maxToolCallIndex (\(maxToolCallIndex))"
)
continue
}

var accumulator = toolAccumulatorsByID[callID] ?? ResponsesToolAccumulator(
id: callID,
name: decoded.toolName ?? "unknown_tool",
Expand Down
45 changes: 44 additions & 1 deletion Tests/ConduitTests/ChatSessionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ actor MockTextProvider: AIProvider, @preconcurrency TextGenerator {
/// Optional artificial delay per generate call for cancellation tests.
private var _generationDelayNanos: UInt64 = 0

/// Optional artificial delay per streamed chunk for cancellation tests.
private var _streamChunkDelayNanos: UInt64 = 0

/// Number of times cancelGeneration was called.
private var _cancelCallCount: Int = 0

// MARK: - Accessors for Test Assertions

var responseToReturn: String {
Expand All @@ -60,6 +66,11 @@ actor MockTextProvider: AIProvider, @preconcurrency TextGenerator {
set { _generateCallCount = newValue }
}

var cancelCallCount: Int {
get { _cancelCallCount }
set { _cancelCallCount = newValue }
}

var receivedMessagesByGenerateCall: [[Message]] {
get { _receivedMessagesByGenerateCall }
set { _receivedMessagesByGenerateCall = newValue }
Expand All @@ -73,6 +84,10 @@ actor MockTextProvider: AIProvider, @preconcurrency TextGenerator {
_generationDelayNanos = nanoseconds
}

func setStreamChunkDelay(nanoseconds: UInt64) {
_streamChunkDelayNanos = nanoseconds
}

// MARK: - AIProvider

var isAvailable: Bool { true }
Expand Down Expand Up @@ -119,6 +134,7 @@ actor MockTextProvider: AIProvider, @preconcurrency TextGenerator {
_lastReceivedMessages = messages
let responseText = _responseToReturn
let throwError = _shouldThrowError
let streamChunkDelay = _streamChunkDelayNanos

return AsyncThrowingStream { continuation in
if throwError {
Expand All @@ -129,6 +145,9 @@ actor MockTextProvider: AIProvider, @preconcurrency TextGenerator {
let words = responseText.split(separator: " ")
Task {
for (index, word) in words.enumerated() {
if streamChunkDelay > 0 {
try? await Task.sleep(nanoseconds: streamChunkDelay)
}
let isLast = index == words.count - 1
let chunk = GenerationChunk(
text: String(word) + (isLast ? "" : " "),
Expand All @@ -144,7 +163,7 @@ actor MockTextProvider: AIProvider, @preconcurrency TextGenerator {
}

func cancelGeneration() async {
// No-op for tests
_cancelCallCount += 1
}

// MARK: - TextGenerator Protocol Methods
Expand Down Expand Up @@ -207,6 +226,8 @@ actor MockTextProvider: AIProvider, @preconcurrency TextGenerator {
_lastReceivedMessages = []
_generateCallCount = 0
_generationDelayNanos = 0
_streamChunkDelayNanos = 0
_cancelCallCount = 0
_queuedGenerationResults = []
_receivedMessagesByGenerateCall = []
}
Expand Down Expand Up @@ -786,6 +807,28 @@ struct ChatSessionTests {
}
}

@Test("stream cancellation propagates to provider cancelGeneration")
func streamCancellationPropagatesToProvider() async throws {
let provider = MockTextProvider()
await provider.setStreamChunkDelay(nanoseconds: 200_000_000)
let session = try await ChatSession(provider: provider, model: .llama3_2_1b)

let consumer = Task {
for try await _ in session.stream("Stream slowly") {
}
}

try await Task.sleep(nanoseconds: 30_000_000)
consumer.cancel()
_ = await consumer.result
// Yield to the cooperative scheduler so the fire-and-forget
// cancelGeneration() Task has a chance to run before we assert.
await Task.yield()

let cancelCount = await provider.cancelCallCount
#expect(cancelCount >= 1)
}

// MARK: - Clear History Tests

@Test("clearHistory removes all messages except system")
Expand Down
94 changes: 59 additions & 35 deletions Tests/ConduitTests/Core/PartialToolCallTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -176,36 +176,38 @@ struct PartialToolCallTests {
#expect(partial.index == 99)
}

// Note: The following preconditions cannot be runtime-tested because they cause crashes:
// - index < 0 triggers precondition failure
// - index > 100 triggers precondition failure
// These preconditions are documented in PartialToolCall.init and should be enforced
// by callers. The precondition messages are:
// - "PartialToolCall index must be in range 0...100, got \(index)"
@Test("Index below 0 is clamped to lower boundary")
func indexBelowRangeIsClamped() {
let partial = PartialToolCall(
id: "call_negative",
toolName: "tool",
index: -1,
argumentsFragment: "{}"
)

#expect(partial.index == 0)
}

@Test("Index above max is clamped to upper boundary")
func indexAboveRangeIsClamped() {
let partial = PartialToolCall(
id: "call_above",
toolName: "tool",
index: maxToolCallIndex + 1,
argumentsFragment: "{}"
)

#expect(partial.index == maxToolCallIndex)
}
}

// MARK: - Precondition Documentation Tests

@Suite("Precondition Documentation")
struct PreconditionDocumentationTests {

// Note: Preconditions cannot be easily tested at runtime because they cause crashes.
// The following preconditions exist on PartialToolCall.init:
//
// 1. Empty id must trigger precondition failure:
// precondition(!id.isEmpty, "PartialToolCall id must not be empty")
//
// 2. Empty toolName must trigger precondition failure:
// precondition(!toolName.isEmpty, "PartialToolCall toolName must not be empty")
//
// 3. Index outside 0...100 must trigger precondition failure:
// precondition((0...100).contains(index), "PartialToolCall index must be in range 0...100, got \(index)")
//
// These tests document the expected behavior without being able to verify it at runtime.

@Test("Non-empty id satisfies precondition")
func nonEmptyIdSatisfiesPrecondition() {
// This test verifies that a non-empty id does not trigger the precondition
// MARK: - Sanitization Tests

@Suite("Sanitization")
struct SanitizationTests {

@Test("Non-empty id is preserved")
func nonEmptyIdIsPreserved() {
let partial = PartialToolCall(
id: "valid_id",
toolName: "tool",
Expand All @@ -216,9 +218,8 @@ struct PartialToolCallTests {
#expect(!partial.id.isEmpty)
}

@Test("Non-empty tool name satisfies precondition")
func nonEmptyToolNameSatisfiesPrecondition() {
// This test verifies that a non-empty toolName does not trigger the precondition
@Test("Non-empty tool name is preserved")
func nonEmptyToolNameIsPreserved() {
let partial = PartialToolCall(
id: "call_123",
toolName: "valid_tool",
Expand All @@ -229,17 +230,40 @@ struct PartialToolCallTests {
#expect(!partial.toolName.isEmpty)
}

@Test("Index within range satisfies precondition")
func indexWithinRangeSatisfiesPrecondition() {
// This test verifies that an index within 0...100 does not trigger the precondition
@Test("Index within range is preserved")
func indexWithinRangeIsPreserved() {
let partial = PartialToolCall(
id: "call_123",
toolName: "tool",
index: 50,
argumentsFragment: ""
)

#expect((0...100).contains(partial.index))
#expect((0...maxToolCallIndex).contains(partial.index))
}

@Test("Empty id is sanitized")
func emptyIdIsSanitized() {
let partial = PartialToolCall(
id: "",
toolName: "tool",
index: 1,
argumentsFragment: "{}"
)

#expect(partial.id == "unknown_tool_call")
}

@Test("Empty tool name is sanitized")
func emptyToolNameIsSanitized() {
let partial = PartialToolCall(
id: "call_123",
toolName: "",
index: 1,
argumentsFragment: "{}"
)

#expect(partial.toolName == "unknown_tool")
}
}

Expand Down
Loading