diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index e2e8dd2..a63892b 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -45,8 +45,25 @@ jobs: - name: Build run: swift build -v - - name: Test - run: swift test -v + - name: Test with Coverage + run: swift test --enable-code-coverage -v + + - name: Generate Coverage Report + run: | + BIN_PATH=$(swift build --show-bin-path) + XCTEST_PATH=$(find "$BIN_PATH" -name "*.xctest" -type f | head -1) + PROFDATA=$(find .build -name "default.profdata" | head -1) + if [ -n "$XCTEST_PATH" ] && [ -n "$PROFDATA" ]; then + llvm-cov export -format=lcov \ + -instr-profile="$PROFDATA" "$XCTEST_PATH" \ + -ignore-filename-regex='.build|Tests' > coverage.lcov + fi + + - name: Upload Coverage + uses: codecov/codecov-action@v4 + with: + files: coverage.lcov + fail_ci_if_error: false test-macros: name: Verify Macros on Linux @@ -77,3 +94,30 @@ jobs: - name: Test Macros run: swift test --filter ConduitMacrosTests + + test-with-providers: + name: Test with Provider Traits + runs-on: ubuntu-latest + timeout-minutes: 20 + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Swift + uses: vapor/swiftly-action@v0.2 + with: + toolchain: "6.2" + + - name: Cache SPM + uses: actions/cache@v4 + with: + path: | + ~/.cache/org.swift.swiftpm + .build + key: linux-swift-6.2-spm-providers-${{ hashFiles('**/Package.resolved') }} + restore-keys: | + linux-swift-6.2-spm-providers- + + - name: Test with traits + run: swift test --traits OpenAI,Anthropic,Kimi,MiniMax -v diff --git a/Package.swift b/Package.swift index bca2e44..c8fe9b2 100644 --- a/Package.swift +++ b/Package.swift @@ -73,6 +73,18 @@ let package = Package( .package(url: "https://github.com/mattt/llama.swift", .upToNextMajor(from: "2.7484.0")), ], targets: [ + .target( + name: "ConduitCore", + dependencies: [], + path: "Sources/ConduitCore", + publicHeadersPath: "include", + cSettings: [ + .define("CONDUIT_HAS_ACCELERATE", .when(platforms: [.macOS, .iOS, .visionOS, .tvOS, .watchOS])), + ], + linkerSettings: [ + .linkedFramework("Accelerate", .when(platforms: [.macOS, .iOS, .visionOS, .tvOS, .watchOS])), + ] + ), .macro( name: "ConduitMacros", dependencies: [ @@ -86,6 +98,7 @@ let package = Package( .target( name: "Conduit", dependencies: [ + "ConduitCore", "ConduitMacros", .product(name: "OrderedCollections", package: "swift-collections"), .product(name: "Logging", package: "swift-log"), @@ -127,6 +140,13 @@ let package = Package( .enableExperimentalFeature("StrictConcurrency") ] ), + .testTarget( + name: "ConduitCoreTests", + dependencies: [ + "ConduitCore", + ], + path: "Tests/ConduitCoreTests" + ), .testTarget( name: "ConduitMacrosTests", dependencies: [ diff --git a/Sources/Conduit/ChatSession+History.swift b/Sources/Conduit/ChatSession+History.swift new file mode 100644 index 0000000..0db6859 --- /dev/null +++ b/Sources/Conduit/ChatSession+History.swift @@ -0,0 +1,141 @@ +// ChatSession+History.swift +// Conduit + +import Foundation + +// MARK: - History Management + +extension ChatSession { + + /// Clears all messages except the system prompt. + /// + /// If a system message exists at the beginning of the history, + /// it is preserved. All other messages are removed. + /// + /// ## Usage + /// + /// ```swift + /// session.clearHistory() + /// // System prompt is preserved, conversation is reset + /// ``` + public func clearHistory() { + withLock { + if let systemMessage = messages.first, systemMessage.role == .system { + messages = [systemMessage] + } else { + messages = [] + } + } + } + + /// Removes the last user-assistant exchange from history. + /// + /// This removes the most recent pair of user and assistant messages, + /// allowing you to "undo" the last conversation turn. + /// + /// If the last message is a user message without a response, only + /// that user message is removed. + /// + /// ## Usage + /// + /// ```swift + /// // After an unsatisfactory response + /// session.undoLastExchange() + /// // Try again with different phrasing + /// let response = try await session.send("Let me rephrase...") + /// ``` + public func undoLastExchange() { + withLock { + guard !messages.isEmpty else { return } + + // Remove assistant message if it's the last one + if messages.last?.role == .assistant { + messages.removeLast() + } + + // Remove user message if it's now the last one + if messages.last?.role == .user { + messages.removeLast() + } + } + } + + /// Injects a conversation history, preserving the current system prompt. + /// + /// If the current session has a system prompt, it is preserved and + /// the injected history (minus any system messages) is appended. + /// + /// If the current session has no system prompt but the injected + /// history has one, that system prompt is used. + /// + /// ## Usage + /// + /// ```swift + /// // Load saved conversation + /// let savedMessages = loadMessagesFromDisk() + /// session.injectHistory(savedMessages) + /// ``` + /// + /// - Parameter history: The messages to inject. + public func injectHistory(_ history: [Message]) { + withLock { + // Check for existing system prompt + let existingSystemPrompt: Message? = messages.first?.role == .system + ? messages.first + : nil + + // Filter out system messages from injected history + let nonSystemMessages = history.filter { $0.role != .system } + + // Check for system prompt in injected history + let injectedSystemPrompt = history.first { $0.role == .system } + + // Build new message list + if let existingPrompt = existingSystemPrompt { + // Keep existing system prompt + messages = [existingPrompt] + nonSystemMessages + } else if let injectedPrompt = injectedSystemPrompt { + // Use injected system prompt + messages = [injectedPrompt] + nonSystemMessages + } else { + // No system prompt + messages = nonSystemMessages + } + } + } + + // MARK: - Computed Properties + + /// The total number of messages in the conversation. + /// + /// Includes system, user, and assistant messages. + public var messageCount: Int { + withLock { messages.count } + } + + /// The number of user messages in the conversation. + /// + /// Useful for tracking the number of conversation turns. + public var userMessageCount: Int { + withLock { + messages.filter { $0.role == .user }.count + } + } + + /// Whether the session has an active system prompt. + public var hasSystemPrompt: Bool { + withLock { + messages.first?.role == .system + } + } + + /// The current system prompt, if any. + public var systemPrompt: String? { + withLock { + guard let first = messages.first, first.role == .system else { + return nil + } + return first.content.textValue + } + } +} diff --git a/Sources/Conduit/ChatSession.swift b/Sources/Conduit/ChatSession.swift index 31c2ee2..e778793 100644 --- a/Sources/Conduit/ChatSession.swift +++ b/Sources/Conduit/ChatSession.swift @@ -7,140 +7,6 @@ import Foundation import Observation #endif -// MARK: - WarmupConfig - -/// Configuration for model warmup behavior in ChatSession. -/// -/// Model warmup performs a minimal generation pass to pre-compile Metal shaders -/// and initialize the model's attention cache. This trades startup time for -/// improved first-message latency. -/// -/// ## Performance Impact -/// -/// - **Without warmup**: First message has ~2-4 second overhead (shader compilation) -/// - **With warmup**: First message latency is ~100-300ms (normal generation speed) -/// - **Warmup duration**: Typically 1-2 seconds during initialization -/// -/// ## When to Use -/// -/// **Use `.eager` warmup when:** -/// - The model is known at initialization time -/// - First-message latency is critical for user experience -/// - You're willing to pay the cost upfront during session creation -/// - Example: Chat interface where the user expects immediate responses -/// -/// **Use `.default` (no warmup) when:** -/// - The model might change before first use -/// - Initialization speed is more important than first-message speed -/// - The session might be created but not immediately used -/// - Example: Pre-creating sessions for potential future conversations -/// -/// ## Usage -/// -/// ### Eager Warmup (Recommended for Active Chats) -/// -/// ```swift -/// // Warmup automatically on init -/// let session = try await ChatSession( -/// provider: provider, -/// model: .llama3_2_1b, -/// warmup: .eager -/// ) -/// // First message will be fast (~100-300ms) -/// ``` -/// -/// ### Default (No Warmup) -/// -/// ```swift -/// // No warmup overhead during init -/// let session = ChatSession( -/// provider: provider, -/// model: .llama3_2_1b, -/// warmup: .default -/// ) -/// // First message will include warmup time (~2-4s) -/// ``` -/// -/// ### Custom Warmup -/// -/// ```swift -/// let customWarmup = WarmupConfig( -/// warmupOnInit: true, -/// prefillChars: 100, // Larger cache warmup -/// warmupTokens: 10 // More tokens generated -/// ) -/// let session = try await ChatSession( -/// provider: provider, -/// model: .llama3_2_1b, -/// warmup: customWarmup -/// ) -/// ``` -/// -/// ## Properties -/// -/// - `warmupOnInit`: If `true`, performs warmup during session initialization. -/// - `prefillChars`: Number of characters in the warmup prompt. Controls the -/// size of the attention cache that gets warmed up. Default: 50. -/// - `warmupTokens`: Number of tokens to generate during warmup. Higher values -/// warm up longer generation sequences but take longer. Default: 5. -/// -/// ## Static Presets -/// -/// - `.default`: No automatic warmup (`warmupOnInit: false`) -/// - `.eager`: Automatic warmup with default parameters (`warmupOnInit: true`) -public struct WarmupConfig: Sendable { - /// Whether to perform warmup during session initialization. - /// - /// If `true`, the session initializer will call the provider's `warmUp()` - /// method automatically. This trades initialization time for improved - /// first-message latency. - public var warmupOnInit: Bool - - /// Number of characters in the warmup prompt. - /// - /// Controls the size of the attention cache that gets warmed up. Larger - /// values warm up the cache for longer prompts but take slightly longer. - /// - /// Default: 50 characters - public var prefillChars: Int - - /// Number of tokens to generate during warmup. - /// - /// Higher values provide better warmup for longer generation sequences - /// but increase warmup duration. - /// - /// Default: 5 tokens - public var warmupTokens: Int - - /// Creates a custom warmup configuration. - /// - /// - Parameters: - /// - warmupOnInit: Whether to warmup on session init. Default: `false`. - /// - prefillChars: Number of warmup prompt characters. Default: `50`. - /// - warmupTokens: Number of tokens to generate. Default: `5`. - public init( - warmupOnInit: Bool = false, - prefillChars: Int = 50, - warmupTokens: Int = 5 - ) { - self.warmupOnInit = warmupOnInit - self.prefillChars = prefillChars - self.warmupTokens = warmupTokens - } - - /// Default configuration with no automatic warmup. - /// - /// First message will include warmup overhead (~2-4s), but session - /// initialization is fast. - public static let `default` = WarmupConfig(warmupOnInit: false) - - /// Eager warmup configuration. - /// - /// Performs warmup during session initialization. First message will be - /// fast (~100-300ms), but session creation takes longer (~1-2s). - public static let eager = WarmupConfig(warmupOnInit: true) -} - // MARK: - ChatSession /// A stateful session manager for multi-turn chat conversations. @@ -248,7 +114,11 @@ public final class ChatSession: @unchecked /// /// Messages are stored in chronological order. Use factory methods /// like `send(_:)` to add messages rather than modifying directly. - public private(set) var messages: [Message] = [] + /// + /// - Warning: Internal callers (e.g. extensions in separate files) that mutate + /// `messages` directly MUST do so inside a `withLock { }` block. Direct mutation + /// without the lock is unsafe and will cause data races. + public internal(set) var messages: [Message] = [] /// Whether a generation is currently in progress. /// @@ -296,7 +166,10 @@ public final class ChatSession: @unchecked private var cancellationRequested: Bool = false /// Lock for thread-safe access to mutable state. - private let lock = NSLock() + /// + /// Internal visibility to support extensions in separate files + /// (e.g., `ChatSession+History.swift`). + let lock = NSLock() // MARK: - Initialization @@ -415,7 +288,7 @@ public final class ChatSession: @unchecked /// /// - Parameter body: The closure to execute while holding the lock. /// - Returns: The value returned by the closure. - private func withLock(_ body: () throws -> T) rethrows -> T { + func withLock(_ body: () throws -> T) rethrows -> T { lock.lock() defer { lock.unlock() } return try body() @@ -657,18 +530,19 @@ public final class ChatSession: @unchecked public func stream(_ content: String) -> AsyncThrowingStream { let userMessage = Message.user(content) - // Prepare state and capture messages under lock - let currentMessages: [Message] = withLock { + // Prepare state and capture messages + config atomically under lock. + // config is a public var that could change concurrently, so it must be + // captured inside the same critical section as messages. + let (currentMessages, currentConfig): ([Message], GenerateConfig) = withLock { lastError = nil isGenerating = true cancellationRequested = false messages.append(userMessage) - return messages + return (messages, config) } - // Capture model and config for the async operation + // model is a let constant — no lock needed let currentModel = model - let currentConfig = config return AsyncThrowingStream { continuation in let task = Task { [weak self] in @@ -753,105 +627,6 @@ public final class ChatSession: @unchecked } } - // MARK: - History Management - - /// Clears all messages except the system prompt. - /// - /// If a system message exists at the beginning of the history, - /// it is preserved. All other messages are removed. - /// - /// ## Usage - /// - /// ```swift - /// session.clearHistory() - /// // System prompt is preserved, conversation is reset - /// ``` - public func clearHistory() { - withLock { - if let systemMessage = messages.first, systemMessage.role == .system { - messages = [systemMessage] - } else { - messages = [] - } - } - } - - /// Removes the last user-assistant exchange from history. - /// - /// This removes the most recent pair of user and assistant messages, - /// allowing you to "undo" the last conversation turn. - /// - /// If the last message is a user message without a response, only - /// that user message is removed. - /// - /// ## Usage - /// - /// ```swift - /// // After an unsatisfactory response - /// session.undoLastExchange() - /// // Try again with different phrasing - /// let response = try await session.send("Let me rephrase...") - /// ``` - public func undoLastExchange() { - withLock { - guard !messages.isEmpty else { return } - - // Remove assistant message if it's the last one - if messages.last?.role == .assistant { - messages.removeLast() - } - - // Remove user message if it's now the last one - if messages.last?.role == .user { - messages.removeLast() - } - } - } - - /// Injects a conversation history, preserving the current system prompt. - /// - /// If the current session has a system prompt, it is preserved and - /// the injected history (minus any system messages) is appended. - /// - /// If the current session has no system prompt but the injected - /// history has one, that system prompt is used. - /// - /// ## Usage - /// - /// ```swift - /// // Load saved conversation - /// let savedMessages = loadMessagesFromDisk() - /// session.injectHistory(savedMessages) - /// ``` - /// - /// - Parameter history: The messages to inject. - public func injectHistory(_ history: [Message]) { - withLock { - // Check for existing system prompt - let existingSystemPrompt: Message? = messages.first?.role == .system - ? messages.first - : nil - - // Filter out system messages from injected history - let nonSystemMessages = history.filter { $0.role != .system } - - // Check for system prompt in injected history - let injectedSystemPrompt = history.first { $0.role == .system } - - // Build new message list - if let existingPrompt = existingSystemPrompt { - // Keep existing system prompt - messages = [existingPrompt] + nonSystemMessages - } else if let injectedPrompt = injectedSystemPrompt { - // Use injected system prompt - messages = [injectedPrompt] + nonSystemMessages - } else { - // No system prompt - messages = nonSystemMessages - } - } - } - // MARK: - Cancellation /// Cancels any in-progress generation. @@ -900,38 +675,4 @@ public final class ChatSession: @unchecked } } - // MARK: - Computed Properties - - /// The total number of messages in the conversation. - /// - /// Includes system, user, and assistant messages. - public var messageCount: Int { - withLock { messages.count } - } - - /// The number of user messages in the conversation. - /// - /// Useful for tracking the number of conversation turns. - public var userMessageCount: Int { - withLock { - messages.filter { $0.role == .user }.count - } - } - - /// Whether the session has an active system prompt. - public var hasSystemPrompt: Bool { - withLock { - messages.first?.role == .system - } - } - - /// The current system prompt, if any. - public var systemPrompt: String? { - withLock { - guard let first = messages.first, first.role == .system else { - return nil - } - return first.content.textValue - } - } } diff --git a/Sources/Conduit/Conduit.swift b/Sources/Conduit/Conduit.swift index 58f3dcd..51abc08 100644 --- a/Sources/Conduit/Conduit.swift +++ b/Sources/Conduit/Conduit.swift @@ -1,84 +1,13 @@ // Conduit.swift // Conduit // -// A unified Swift SDK for LLM inference across multiple providers: -// - MLX: Local inference on Apple Silicon (offline, privacy-preserving) -// - llama.cpp: Native local GGUF inference via LlamaSwift (offline, portable) -// - HuggingFace: Cloud inference via HF Inference API (online, model variety) -// - Anthropic: Claude API for advanced reasoning and tool use -// - OpenAI: GPT models and DALL-E image generation -// -// Note: Apple Foundation Models (iOS 26+) are supported via FoundationModelsProvider. +// A unified Swift SDK for LLM inference across multiple providers. +// All public types are available via `import Conduit`. // // Copyright 2025. MIT License. import Foundation -// MARK: - Module Re-exports - -// Core Protocols -// TODO: @_exported import when implemented -// - AIProvider -// - TextGenerator -// - EmbeddingGenerator -// - Transcriber -// - TokenCounter -// - ModelManaging - -// Core Types -// TODO: @_exported import when implemented -// - ModelIdentifier -// - Message -// - GenerateConfig -// - EmbeddingResult -// - TranscriptionResult -// - TokenCount - -// Image Generation Types -// - GeneratedImage: Image result with SwiftUI support and save methods -// - ImageGenerationConfig: Configuration for text-to-image (dimensions, steps, guidance) -// - ImageFormat: Supported image formats (PNG, JPEG, WebP) -// - GeneratedImageError: Errors for image operations -// - ImageGenerator: Protocol for text-to-image providers (v1.2.0) -// - ImageGenerationProgress: Progress tracking for local diffusion models (v1.2.0) -// - MLXImageProvider: Local on-device image generation using MLX StableDiffusion (v1.2.0) -// - DiffusionVariant: Supported diffusion model variants (SDXL Turbo, SD 1.5, Flux) (v1.2.0) -// - DiffusionModelRegistry: Registry for managing diffusion model downloads (v1.2.0) -// - DiffusionModelDownloader: Downloads diffusion models from HuggingFace (v1.2.0) - -// Streaming -// TODO: @_exported import when implemented -// - GenerationStream -// - GenerationChunk - -// Errors -// TODO: @_exported import when implemented -// - AIError -// - UnavailabilityReason - -// Providers -// TODO: @_exported import when implemented -// - MLXProvider -// - LlamaProvider -// - HuggingFaceProvider - -// MARK: - Anthropic Provider -// - AnthropicProvider: Anthropic Claude API support -// - AnthropicModelID: Model identifiers for Claude models -// - AnthropicConfiguration: Configuration for Anthropic provider -// - AnthropicAuthentication: API key authentication - -// Model Management -// TODO: @_exported import when implemented -// - ModelManager -// - ModelRegistry -// - ModelCache - -// Builders -// TODO: @_exported import when implemented -// - PromptBuilder -// - MessageBuilder - // MARK: - Version /// The current version of the Conduit framework. diff --git a/Sources/Conduit/Core/Types/EmbeddingResult.swift b/Sources/Conduit/Core/Types/EmbeddingResult.swift index 420c144..6d0a1e2 100644 --- a/Sources/Conduit/Core/Types/EmbeddingResult.swift +++ b/Sources/Conduit/Core/Types/EmbeddingResult.swift @@ -2,6 +2,7 @@ // Conduit import Foundation +import ConduitCore /// The result of an embedding operation. /// @@ -72,19 +73,11 @@ public struct EmbeddingResult: Sendable, Hashable { /// Returns 0 if vectors have different dimensions. public func cosineSimilarity(with other: EmbeddingResult) -> Float { guard vector.count == other.vector.count else { return 0 } - - var dotProduct: Float = 0 - var normA: Float = 0 - var normB: Float = 0 - - for i in vector.indices { - dotProduct += vector[i] * other.vector[i] - normA += vector[i] * vector[i] - normB += other.vector[i] * other.vector[i] + return vector.withUnsafeBufferPointer { a in + other.vector.withUnsafeBufferPointer { b in + conduit_cosine_similarity(a.baseAddress, b.baseAddress, a.count) + } } - - let denominator = sqrt(normA) * sqrt(normB) - return denominator > 0 ? dotProduct / denominator : 0 } /// Computes Euclidean distance to another embedding. @@ -99,13 +92,11 @@ public struct EmbeddingResult: Sendable, Hashable { /// Returns `.infinity` if vectors have different dimensions. public func euclideanDistance(to other: EmbeddingResult) -> Float { guard vector.count == other.vector.count else { return .infinity } - - var sum: Float = 0 - for i in vector.indices { - let diff = vector[i] - other.vector[i] - sum += diff * diff + return vector.withUnsafeBufferPointer { a in + other.vector.withUnsafeBufferPointer { b in + conduit_euclidean_distance(a.baseAddress, b.baseAddress, a.count) + } } - return sqrt(sum) } /// Computes dot product with another embedding. @@ -120,6 +111,10 @@ public struct EmbeddingResult: Sendable, Hashable { /// Returns 0 if vectors have different dimensions. public func dotProduct(with other: EmbeddingResult) -> Float { guard vector.count == other.vector.count else { return 0 } - return zip(vector, other.vector).reduce(0) { $0 + $1.0 * $1.1 } + return vector.withUnsafeBufferPointer { a in + other.vector.withUnsafeBufferPointer { b in + conduit_dot_product(a.baseAddress, b.baseAddress, a.count) + } + } } } diff --git a/Sources/Conduit/Core/Types/WarmupConfig.swift b/Sources/Conduit/Core/Types/WarmupConfig.swift new file mode 100644 index 0000000..ee7ed8e --- /dev/null +++ b/Sources/Conduit/Core/Types/WarmupConfig.swift @@ -0,0 +1,138 @@ +// WarmupConfig.swift +// Conduit + +import Foundation + +// MARK: - WarmupConfig + +/// Configuration for model warmup behavior in ChatSession. +/// +/// Model warmup performs a minimal generation pass to pre-compile Metal shaders +/// and initialize the model's attention cache. This trades startup time for +/// improved first-message latency. +/// +/// ## Performance Impact +/// +/// - **Without warmup**: First message has ~2-4 second overhead (shader compilation) +/// - **With warmup**: First message latency is ~100-300ms (normal generation speed) +/// - **Warmup duration**: Typically 1-2 seconds during initialization +/// +/// ## When to Use +/// +/// **Use `.eager` warmup when:** +/// - The model is known at initialization time +/// - First-message latency is critical for user experience +/// - You're willing to pay the cost upfront during session creation +/// - Example: Chat interface where the user expects immediate responses +/// +/// **Use `.default` (no warmup) when:** +/// - The model might change before first use +/// - Initialization speed is more important than first-message speed +/// - The session might be created but not immediately used +/// - Example: Pre-creating sessions for potential future conversations +/// +/// ## Usage +/// +/// ### Eager Warmup (Recommended for Active Chats) +/// +/// ```swift +/// // Warmup automatically on init +/// let session = try await ChatSession( +/// provider: provider, +/// model: .llama3_2_1b, +/// warmup: .eager +/// ) +/// // First message will be fast (~100-300ms) +/// ``` +/// +/// ### Default (No Warmup) +/// +/// ```swift +/// // No warmup overhead during init +/// let session = ChatSession( +/// provider: provider, +/// model: .llama3_2_1b, +/// warmup: .default +/// ) +/// // First message will include warmup time (~2-4s) +/// ``` +/// +/// ### Custom Warmup +/// +/// ```swift +/// let customWarmup = WarmupConfig( +/// warmupOnInit: true, +/// prefillChars: 100, // Larger cache warmup +/// warmupTokens: 10 // More tokens generated +/// ) +/// let session = try await ChatSession( +/// provider: provider, +/// model: .llama3_2_1b, +/// warmup: customWarmup +/// ) +/// ``` +/// +/// ## Properties +/// +/// - `warmupOnInit`: If `true`, performs warmup during session initialization. +/// - `prefillChars`: Number of characters in the warmup prompt. Controls the +/// size of the attention cache that gets warmed up. Default: 50. +/// - `warmupTokens`: Number of tokens to generate during warmup. Higher values +/// warm up longer generation sequences but take longer. Default: 5. +/// +/// ## Static Presets +/// +/// - `.default`: No automatic warmup (`warmupOnInit: false`) +/// - `.eager`: Automatic warmup with default parameters (`warmupOnInit: true`) +public struct WarmupConfig: Sendable { + /// Whether to perform warmup during session initialization. + /// + /// If `true`, the session initializer will call the provider's `warmUp()` + /// method automatically. This trades initialization time for improved + /// first-message latency. + public var warmupOnInit: Bool + + /// Number of characters in the warmup prompt. + /// + /// Controls the size of the attention cache that gets warmed up. Larger + /// values warm up the cache for longer prompts but take slightly longer. + /// + /// Default: 50 characters + public var prefillChars: Int + + /// Number of tokens to generate during warmup. + /// + /// Higher values provide better warmup for longer generation sequences + /// but increase warmup duration. + /// + /// Default: 5 tokens + public var warmupTokens: Int + + /// Creates a custom warmup configuration. + /// + /// - Parameters: + /// - warmupOnInit: Whether to warmup on session init. Default: `false`. + /// - prefillChars: Number of warmup prompt characters. Default: `50`. + /// - warmupTokens: Number of tokens to generate. Default: `5`. + public init( + warmupOnInit: Bool = false, + prefillChars: Int = 50, + warmupTokens: Int = 5 + ) { + self.warmupOnInit = warmupOnInit + self.prefillChars = prefillChars + self.warmupTokens = warmupTokens + } + + /// Default configuration with no automatic warmup. + /// + /// First message will include warmup overhead (~2-4s), but session + /// initialization is fast. + public static let `default` = WarmupConfig(warmupOnInit: false) + + /// Eager warmup configuration. + /// + /// Performs warmup during session initialization. First message will be + /// fast (~100-300ms), but session creation takes longer (~1-2s). + public static let eager = WarmupConfig(warmupOnInit: true) +} diff --git a/Sources/Conduit/Providers/MiniMax/MiniMaxConfiguration.swift b/Sources/Conduit/Providers/MiniMax/MiniMaxConfiguration.swift index 1044ec2..f135c33 100644 --- a/Sources/Conduit/Providers/MiniMax/MiniMaxConfiguration.swift +++ b/Sources/Conduit/Providers/MiniMax/MiniMaxConfiguration.swift @@ -55,3 +55,58 @@ extension MiniMaxConfiguration { } #endif // CONDUIT_TRAIT_MINIMAX + +// MARK: - Anthropic-Compatible Messages API + +#if CONDUIT_TRAIT_MINIMAX && CONDUIT_TRAIT_OPENAI && CONDUIT_TRAIT_ANTHROPIC + +extension MiniMaxConfiguration { + + /// Base URL for MiniMax's Anthropic-compatible Messages API. + /// + /// MiniMax hosts both an OpenAI-compatible Chat Completions endpoint + /// and an Anthropic-compatible Messages endpoint. They use different + /// authentication headers and request formats. + /// + /// Use this with ``AnthropicConfiguration`` to access the Messages endpoint. + /// + /// - SeeAlso: ``anthropicCompatible(apiKey:)`` + public static let messagesBaseURL = URL(string: "https://minimax-m2.com/api")! + + /// Creates an ``AnthropicConfiguration`` targeting MiniMax's Messages API. + /// + /// MiniMax's Messages API mirrors the Anthropic Claude API format and uses: + /// - `x-api-key` header for authentication (not `Authorization: Bearer`) + /// - `Anthropic-Version: 2023-06-01` header (required) + /// - `POST /api/v1/messages` endpoint + /// + /// This is distinct from ``MiniMaxProvider``, which uses the OpenAI-compatible + /// Chat Completions endpoint (`POST /api/v1/chat/completions`) with Bearer auth. + /// + /// - Note: The Messages API does **not** support streaming. Use non-streaming + /// `generate()` calls with an ``AnthropicProvider`` configured via this method. + /// For streaming, use ``MiniMaxProvider`` instead. + /// + /// ## Usage + /// ```swift + /// let config = try MiniMaxConfiguration.anthropicCompatible(apiKey: "your-key") + /// let provider = AnthropicProvider(configuration: config) + /// let result = try await provider.generate( + /// messages: [.user("Hello")], + /// model: AnthropicModelID("MiniMax-M2"), + /// config: .default + /// ) + /// ``` + /// + /// - Parameter apiKey: Your MiniMax API key. + /// - Returns: An ``AnthropicConfiguration`` pointed at MiniMax's Messages API. + /// - Throws: `AIError.invalidInput` if the base URL is invalid. + public static func anthropicCompatible(apiKey: String) throws -> AnthropicConfiguration { + try AnthropicConfiguration( + authentication: .apiKey(apiKey), + baseURL: messagesBaseURL + ) + } +} + +#endif // CONDUIT_TRAIT_MINIMAX && CONDUIT_TRAIT_OPENAI && CONDUIT_TRAIT_ANTHROPIC diff --git a/Sources/Conduit/Utilities/JsonRepair.swift b/Sources/Conduit/Utilities/JsonRepair.swift index d4d85ea..b9ca87e 100644 --- a/Sources/Conduit/Utilities/JsonRepair.swift +++ b/Sources/Conduit/Utilities/JsonRepair.swift @@ -4,6 +4,7 @@ // Utility for repairing incomplete JSON from streaming responses. import Foundation +import ConduitCore // MARK: - JsonRepair @@ -77,33 +78,57 @@ public enum JsonRepair { /// - Parameter json: The potentially incomplete JSON string /// - Returns: A repaired JSON string that should be valid JSON public static func repair(_ json: String, maximumDepth: Int = 64) -> String { + // Use [CChar] directly to avoid reinterpreting [UInt8] as CChar via assumingMemoryBound, + // which is technically undefined behaviour in Swift's strict memory model. + let utf8: [CChar] = json.utf8.map { CChar(bitPattern: $0) } + let capacity = utf8.count + maximumDepth + 128 // Room for closing brackets + suffix + var output = [CChar](repeating: 0, count: capacity) + + let result = utf8.withUnsafeBufferPointer { inputBuf in + output.withUnsafeMutableBufferPointer { outputBuf in + conduit_json_repair( + inputBuf.baseAddress, + inputBuf.count, + outputBuf.baseAddress, + outputBuf.count, + Int32(maximumDepth) + ) + } + } + + if result >= 0 { + return output.withUnsafeBufferPointer { buf in + String(cString: buf.baseAddress!) + } + } + + // Fallback to Swift implementation if C buffer was too small + return repairSwift(json, maximumDepth: maximumDepth) + } + + /// Original Swift repair implementation, kept as fallback. + private static func repairSwift(_ json: String, maximumDepth: Int = 64) -> String { let trimmed = json.trimmingCharacters(in: .whitespaces) guard !trimmed.isEmpty else { return "{}" } - // Pre-allocate result string with margin for closing brackets var resultBuilder = "" resultBuilder.reserveCapacity(json.count + 100) var state = ParserState(maximumDepth: maximumDepth) - // Single pass: analyze AND build simultaneously for char in json { state.process(char) resultBuilder.append(char) } - // If we're in a string, close it if state.inString { - // Check for partial unicode escape sequence and remove it removePartialUnicodeEscape(&resultBuilder) - // Also handle incomplete escape at the very end if state.escapeNext, let last = resultBuilder.last, last == "\\" { resultBuilder.removeLast() } resultBuilder.append("\"") } - // Remove trailing whitespace and comma in-place while let last = resultBuilder.last, last.isWhitespace { resultBuilder.removeLast() } @@ -111,12 +136,9 @@ public enum JsonRepair { resultBuilder.removeLast() } - // Remove incomplete key-value pairs (key without value, key without colon) resultBuilder = removeIncompleteKeyValuePairs(resultBuilder) - // Close any open brackets/braces for bracket in state.bracketStack.reversed() { - // Before adding a closing bracket, remove any trailing comma while let last = resultBuilder.last, last.isWhitespace { resultBuilder.removeLast() } @@ -126,7 +148,6 @@ public enum JsonRepair { resultBuilder.append(bracket.closing) } - // Final pass: remove trailing commas before existing closing brackets resultBuilder = removeTrailingCommasBeforeClosingBrackets(resultBuilder) return resultBuilder @@ -324,7 +345,10 @@ public enum JsonRepair { private enum JsonContext { case object, array, unknown } private static func findContext(_ chars: [Character], upTo idx: Int) -> JsonContext { - var depth = 0 + // Forward scan with string-awareness to find the innermost unmatched bracket. + // A backward scan without string tracking would miscount brackets inside string + // literals (e.g. {"key": "[value"} — the '[' inside the string is not an opener). + var bracketStack: [Character] = [] var inString = false var escapeNext = false @@ -349,46 +373,20 @@ public enum JsonRepair { case "\"": inString = true case "{": - depth += 1 + bracketStack.append("{") case "}": - depth -= 1 + if bracketStack.last == "{" { bracketStack.removeLast() } case "[": - depth += 1 + bracketStack.append("[") case "]": - depth -= 1 + if bracketStack.last == "[" { bracketStack.removeLast() } default: break } } - // Now scan backwards from idx to find the most recent unmatched opener - var bracketStack: [Character] = [] - inString = false - escapeNext = false - - for i in (0...idx).reversed() { - let char = chars[i] - - // Handle string detection (simplified - scan forward to know if in string) - // Actually, for simplicity, let's just look for the nearest unmatched [ or { - if char == "]" || char == "}" { - bracketStack.append(char) - } else if char == "[" { - if let last = bracketStack.last, last == "]" { - bracketStack.removeLast() - } else { - return .array - } - } else if char == "{" { - if let last = bracketStack.last, last == "}" { - bracketStack.removeLast() - } else { - return .object - } - } - } - - return .unknown + guard let last = bracketStack.last else { return .unknown } + return last == "{" ? .object : .array } /// Removes trailing commas before closing brackets/braces in already-closed JSON. diff --git a/Sources/Conduit/Utilities/PartialJSONDecoder.swift b/Sources/Conduit/Utilities/PartialJSONDecoder.swift index 310f60a..691f699 100644 --- a/Sources/Conduit/Utilities/PartialJSONDecoder.swift +++ b/Sources/Conduit/Utilities/PartialJSONDecoder.swift @@ -5,6 +5,7 @@ // for completing and decoding partial JSON during streaming. import Foundation +import ConduitCore // MARK: - Errors @@ -52,6 +53,32 @@ public final class JSONCompleter { public func complete(_ json: String) throws -> String { guard !json.isEmpty else { return "" } + // Try C implementation first for performance + // C outputs the FULL completed string (input truncated at completion point + suffix) + // Use [CChar] directly to avoid reinterpreting [UInt8] as CChar via assumingMemoryBound, + // which is technically undefined behaviour in Swift's strict memory model. + let utf8: [CChar] = json.utf8.map { CChar(bitPattern: $0) } + let capacity = utf8.count + 256 + var output = [CChar](repeating: 0, count: capacity) + + let result = utf8.withUnsafeBufferPointer { inputBuf in + output.withUnsafeMutableBufferPointer { outputBuf in + conduit_json_complete( + inputBuf.baseAddress, + inputBuf.count, + outputBuf.baseAddress, + outputBuf.count, + Int32(maximumDepth) + ) + } + } + + if result >= 0 { + if result == 0 { return json } // Already complete + return String(cString: output) + } + + // Fallback to Swift implementation for edge cases if let completion = try completion(for: json, from: json.startIndex) { return String(json[.. +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// ============================================================================ +// MARK: - Vector Operations +// ============================================================================ + +/// Computes the dot product of two float vectors. +/// Returns 0 if count is 0. +float conduit_dot_product(const float *a, const float *b, size_t count); + +/// Computes cosine similarity between two float vectors. +/// Returns 0 if either vector has zero magnitude or count is 0. +float conduit_cosine_similarity(const float *a, const float *b, size_t count); + +/// Computes Euclidean distance between two float vectors. +/// Returns 0 if count is 0. +/// +/// HEAP ALLOCATION NOTE (Apple platforms only): On builds with Accelerate enabled, +/// this function allocates a temporary float buffer of `count * sizeof(float)` bytes +/// via malloc to hold the element-wise difference vector before computing the norm +/// with vDSP_dotpr. On allocation failure it falls back to the scalar path. +/// All other functions in this header perform no heap allocation. +float conduit_euclidean_distance(const float *a, const float *b, size_t count); + +/// Computes cosine similarity of `query` against each of `count` vectors in `vectors`. +/// Each vector has `dimensions` floats. Results written to `results` (must hold `count` floats). +void conduit_cosine_similarity_batch( + const float *query, + const float *vectors, + size_t dimensions, + size_t count, + float *results +); + +// ============================================================================ +// MARK: - SSE Parser +// ============================================================================ + +/// Opaque SSE parser handle. +typedef struct conduit_sse_parser conduit_sse_parser_t; + +/// A parsed Server-Sent Event. +typedef struct { + const char *id; // NULL if not present + const char *event; // NULL if not present (implies "message") + const char *data; // Event data payload (never NULL after dispatch, may be "") + int retry; // Retry interval in ms, or -1 if not set +} conduit_sse_event_t; + +/// Callback invoked for each dispatched SSE event. +/// The event fields are valid only for the duration of the callback. +typedef void (*conduit_sse_callback_t)(const conduit_sse_event_t *event, void *context); + +/// Creates a new SSE parser. Returns NULL on allocation failure. +conduit_sse_parser_t *conduit_sse_parser_create(void); + +/// Destroys an SSE parser and frees all associated memory. +void conduit_sse_parser_destroy(conduit_sse_parser_t *parser); + +/// Ingests a single line (without trailing newline). May invoke `callback` if +/// a complete event is dispatched (e.g., on empty line). +void conduit_sse_ingest_line( + conduit_sse_parser_t *parser, + const char *line, + size_t length, + conduit_sse_callback_t callback, + void *context +); + +/// Flushes any pending event at end-of-stream. May invoke `callback`. +void conduit_sse_finish( + conduit_sse_parser_t *parser, + conduit_sse_callback_t callback, + void *context +); + +// ============================================================================ +// MARK: - JSON Repair +// ============================================================================ + +/// Repairs incomplete JSON by closing unclosed strings, arrays, and objects. +/// Removes trailing commas and incomplete key-value pairs. +/// +/// `input`/`input_len`: the potentially incomplete JSON (UTF-8). +/// `output`: caller-allocated buffer to receive the repaired JSON. +/// `output_capacity`: size of `output` in bytes. +/// `max_depth`: maximum bracket nesting depth to track (e.g. 64). +/// +/// Returns the number of bytes written to `output` (excluding NUL terminator), +/// or -1 if `output_capacity` is too small. Output is always NUL-terminated +/// when return value >= 0. +int64_t conduit_json_repair( + const char *input, + size_t input_len, + char *output, + size_t output_capacity, + int max_depth +); + +// ============================================================================ +// MARK: - JSON Completer +// ============================================================================ + +/// Completes partial JSON by appending missing closing characters. +/// +/// `input`/`input_len`: the potentially incomplete JSON (UTF-8). +/// `output`: caller-allocated buffer for the FULL completed JSON string +/// (input truncated at the completion point + suffix appended). +/// `output_capacity`: size of `output` in bytes. +/// `max_depth`: maximum nesting depth (e.g. 64). +/// +/// Returns the number of bytes written to `output` (the full completed string), +/// or -1 if output_capacity is too small. Output is always NUL-terminated +/// when return value >= 0. Returns 0 if the JSON is already complete +/// (output is empty string; caller should use original input as-is). +int64_t conduit_json_complete( + const char *input, + size_t input_len, + char *output, + size_t output_capacity, + int max_depth +); + +// ============================================================================ +// MARK: - Line Buffer +// ============================================================================ + +/// Opaque line buffer handle. +typedef struct conduit_line_buffer conduit_line_buffer_t; + +/// Creates a line buffer with the given initial capacity. +/// Returns NULL on allocation failure. +conduit_line_buffer_t *conduit_line_buffer_create(size_t initial_capacity); + +/// Destroys a line buffer and frees all associated memory. +void conduit_line_buffer_destroy(conduit_line_buffer_t *buf); + +/// Appends `length` bytes from `data` to the buffer. +/// Returns 0 on success, -1 on allocation failure. +int conduit_line_buffer_append(conduit_line_buffer_t *buf, const uint8_t *data, size_t length); + +/// Extracts the next complete line (delimited by \n, \r, or \r\n). +/// On success: writes the line (without delimiter) to `line_out`, sets `line_len` +/// to the number of bytes, and returns 1. +/// On no complete line available: returns 0 and does not modify outputs. +/// On line too large for `line_out_capacity`: returns -1 and does not modify outputs. +/// The oversized line remains unconsumed so callers can grow their buffer and retry. +/// `line_out` must point to a buffer of at least `conduit_line_buffer_pending(buf)` bytes. +/// +/// The delimiter bytes are consumed from the buffer only on success (return 1). +int conduit_line_buffer_next_line( + conduit_line_buffer_t *buf, + char *line_out, + size_t line_out_capacity, + size_t *line_len +); + +/// Returns the number of bytes currently buffered. +size_t conduit_line_buffer_pending(const conduit_line_buffer_t *buf); + +/// Drains all remaining bytes into `out` (for end-of-stream). +/// Returns the number of bytes written. +size_t conduit_line_buffer_drain(conduit_line_buffer_t *buf, char *out, size_t out_capacity); + +#ifdef __cplusplus +} +#endif + +#endif // CONDUIT_CORE_H diff --git a/Sources/ConduitCore/src/conduit_json_completer.c b/Sources/ConduitCore/src/conduit_json_completer.c new file mode 100644 index 0000000..8c31e59 --- /dev/null +++ b/Sources/ConduitCore/src/conduit_json_completer.c @@ -0,0 +1,478 @@ +// conduit_json_completer.c +// ConduitCore +// +// Completes partial JSON by computing the minimal suffix and outputting the +// full completed string (input truncated at the completion point + suffix). +// Operates directly on UTF-8 bytes with O(1) pointer arithmetic. + +#include "conduit_core.h" +#include +#include + +// Internal result: completion suffix + where to apply it +typedef struct { + const char *suffix; // Static or stack-allocated suffix string + size_t suffix_len; + size_t end_offset; // Offset in input where completion applies + // Stack buffer for dynamically composed suffixes (e.g. inner completion + "]"). + // 512 bytes accommodates suffixes up to ~127 levels of nesting before falling back + // to the safe-but-minimal "]" or "}" closer (see composite suffix construction below). + // If the combined suffix ever exceeds this size the result is still valid JSON — + // only the innermost element completion is dropped, producing a shorter output. + char suffix_buf[512]; + bool found; +} completion_t; + +static size_t skip_ws(const char *json, size_t len, size_t pos) { + while (pos < len && (json[pos] == ' ' || json[pos] == '\t' || + json[pos] == '\n' || json[pos] == '\r')) { + pos++; + } + return pos; +} + +// Forward declarations +static completion_t complete_value(const char *json, size_t len, size_t pos, int depth, int max_depth); +static size_t find_end_of_complete_value(const char *json, size_t len, size_t pos, int max_depth); + +// Complete a string starting at pos (which should be '"') +static completion_t complete_string(const char *json, size_t len, size_t pos) { + completion_t r = {0}; + if (pos >= len || json[pos] != '"') return r; + + size_t cur = pos + 1; + bool escaped = false; + + while (cur < len) { + char c = json[cur]; + if (c == '\\') { + escaped = !escaped; + } else if (c == '"' && !escaped) { + return r; // String is complete, no completion needed + } else { + escaped = false; + } + cur++; + } + + // String is incomplete — close it + r.found = true; + r.suffix = "\""; + r.suffix_len = 1; + r.end_offset = cur; + return r; +} + +// Complete a number starting at pos +static completion_t complete_number(const char *json, size_t len, size_t pos) { + completion_t r = {0}; + size_t cur = pos; + + if (cur < len && json[cur] == '-') cur++; + + size_t after_sign = cur; + + // Bare minus at end + if (cur >= len) { + r.found = true; + r.suffix = "0"; + r.suffix_len = 1; + r.end_offset = cur; + return r; + } + + // "-." prefix + if (json[cur] == '.') { + r.found = true; + r.suffix = "0.0"; + r.suffix_len = 3; + r.end_offset = cur; + return r; + } + + // Integer digits + while (cur < len && json[cur] >= '0' && json[cur] <= '9') cur++; + + // Decimal part + if (cur < len && json[cur] == '.') { + cur++; + size_t frac_start = cur; + while (cur < len && json[cur] >= '0' && json[cur] <= '9') cur++; + if (cur == frac_start) { + // Decimal point with no fraction digits + r.found = true; + r.suffix = "0"; + r.suffix_len = 1; + r.end_offset = cur; + return r; + } + } + + // Exponent part + if (cur < len && (json[cur] == 'e' || json[cur] == 'E')) { + cur++; + if (cur < len && (json[cur] == '+' || json[cur] == '-')) cur++; + if (cur >= len || json[cur] < '0' || json[cur] > '9') { + r.found = true; + r.suffix = "0"; + r.suffix_len = 1; + r.end_offset = cur; + return r; + } + while (cur < len && json[cur] >= '0' && json[cur] <= '9') cur++; + } + + // Complete number — no completion needed + return r; +} + +// Complete a special value (true, false, null) +static completion_t complete_special(const char *json, size_t len, size_t pos, + const char *value, size_t value_len) { + completion_t r = {0}; + size_t cur = pos; + size_t matched = 0; + + while (cur < len && matched < value_len) { + if (json[cur] != value[matched]) return r; // Mismatch + cur++; + matched++; + } + + if (matched == value_len) return r; // Fully matched, no completion + + // Partially matched — complete it + r.found = true; + r.suffix = value + matched; + r.suffix_len = value_len - matched; + r.end_offset = cur; + return r; +} + +// Complete an array starting at pos (which should be '[') +static completion_t complete_array(const char *json, size_t len, size_t pos, int depth, int max_depth) { + completion_t r = {0}; + if (pos >= len || json[pos] != '[') return r; + + size_t cur = pos + 1; + bool requires_comma = false; + size_t last_valid = cur; + + cur = skip_ws(json, len, cur); + + if (cur >= len || json[cur] == ']') { + r.found = true; + r.suffix = "]"; + r.suffix_len = 1; + r.end_offset = cur; + return r; + } + + while (cur < len) { + if (json[cur] == ']') return r; // Array is complete + + if (requires_comma) { + if (json[cur] == ',') { + requires_comma = false; + cur++; + cur = skip_ws(json, len, cur); + if (cur >= len) break; + last_valid = cur; + } else { + r.found = true; + r.suffix = "]"; + r.suffix_len = 1; + r.end_offset = last_valid; + return r; + } + } + + if (cur >= len) break; + if (json[cur] == ']') return r; + + completion_t elem = complete_value(json, len, cur, depth + 1, max_depth); + if (elem.found) { + // Build composite suffix: elem completion + "]" + r.found = true; + size_t total = elem.suffix_len + 1; + if (total < sizeof(r.suffix_buf)) { + memcpy(r.suffix_buf, elem.suffix, elem.suffix_len); + r.suffix_buf[elem.suffix_len] = ']'; + r.suffix_buf[total] = '\0'; + r.suffix = r.suffix_buf; + r.suffix_len = total; + } else { + r.suffix = "]"; + r.suffix_len = 1; + } + r.end_offset = elem.end_offset; + return r; + } + + size_t end = find_end_of_complete_value(json, len, cur, max_depth); + cur = end; + last_valid = cur; + requires_comma = true; + } + + r.found = true; + r.suffix = "]"; + r.suffix_len = 1; + r.end_offset = last_valid; + return r; +} + +// Complete an object starting at pos (which should be '{') +static completion_t complete_object(const char *json, size_t len, size_t pos, int depth, int max_depth) { + completion_t r = {0}; + if (pos >= len || json[pos] != '{') return r; + + size_t cur = pos + 1; + bool requires_comma = false; + size_t last_valid = cur; + + cur = skip_ws(json, len, cur); + + if (cur >= len || json[cur] == '}') { + r.found = true; + r.suffix = "}"; + r.suffix_len = 1; + r.end_offset = cur; + return r; + } + + while (cur < len) { + if (json[cur] == '}') return r; + + if (requires_comma) { + if (json[cur] == ',') { + requires_comma = false; + cur++; + cur = skip_ws(json, len, cur); + if (cur >= len) break; + last_valid = cur; + } else { + r.found = true; + r.suffix = "}"; + r.suffix_len = 1; + r.end_offset = last_valid; + return r; + } + } + + if (cur >= len) break; + if (json[cur] == '}') return r; + + // Key + completion_t key_comp = complete_string(json, len, cur); + if (key_comp.found) { + r.found = true; + const char *suffix = ": null}"; + size_t slen = 7; + size_t total = key_comp.suffix_len + slen; + if (total < sizeof(r.suffix_buf)) { + memcpy(r.suffix_buf, key_comp.suffix, key_comp.suffix_len); + memcpy(r.suffix_buf + key_comp.suffix_len, suffix, slen); + r.suffix_buf[total] = '\0'; + r.suffix = r.suffix_buf; + r.suffix_len = total; + } else { + r.suffix = "}"; + r.suffix_len = 1; + } + r.end_offset = key_comp.end_offset; + return r; + } + + size_t key_end = find_end_of_complete_value(json, len, cur, max_depth); + if (key_end <= cur) { + r.found = true; + r.suffix = "}"; + r.suffix_len = 1; + r.end_offset = last_valid; + return r; + } + + cur = key_end; + last_valid = cur; + + // Colon + cur = skip_ws(json, len, cur); + if (cur >= len || json[cur] != ':') { + r.found = true; + // Need to provide ": null}" + const char *suffix = ": null}"; + r.suffix = suffix; + r.suffix_len = 7; + r.end_offset = last_valid; + return r; + } + cur++; + last_valid = cur; + + // Value + cur = skip_ws(json, len, cur); + if (cur >= len) { + r.found = true; + r.suffix = "null}"; + r.suffix_len = 5; + r.end_offset = last_valid; + return r; + } + + completion_t val_comp = complete_value(json, len, cur, depth + 1, max_depth); + if (val_comp.found) { + r.found = true; + size_t total = val_comp.suffix_len + 1; + if (total < sizeof(r.suffix_buf)) { + memcpy(r.suffix_buf, val_comp.suffix, val_comp.suffix_len); + r.suffix_buf[val_comp.suffix_len] = '}'; + r.suffix_buf[total] = '\0'; + r.suffix = r.suffix_buf; + r.suffix_len = total; + } else { + r.suffix = "}"; + r.suffix_len = 1; + } + r.end_offset = val_comp.end_offset; + return r; + } + + size_t val_end = find_end_of_complete_value(json, len, cur, max_depth); + cur = val_end; + last_valid = cur; + requires_comma = true; + } + + r.found = true; + r.suffix = "}"; + r.suffix_len = 1; + r.end_offset = last_valid; + return r; +} + +static completion_t complete_value(const char *json, size_t len, size_t pos, int depth, int max_depth) { + completion_t r = {0}; + if (depth >= max_depth) return r; + + pos = skip_ws(json, len, pos); + if (pos >= len) return r; + + switch (json[pos]) { + case '{': return complete_object(json, len, pos, depth, max_depth); + case '[': return complete_array(json, len, pos, depth, max_depth); + case '"': return complete_string(json, len, pos); + case 't': return complete_special(json, len, pos, "true", 4); + case 'f': return complete_special(json, len, pos, "false", 5); + case 'n': return complete_special(json, len, pos, "null", 4); + case '-': + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': + return complete_number(json, len, pos); + default: + return r; + } +} + +// Find where a complete value ends (returns offset past the value) +static size_t find_end_of_complete_value(const char *json, size_t len, size_t pos, int max_depth) { + pos = skip_ws(json, len, pos); + if (pos >= len) return pos; + + // If value is incomplete, return its end + completion_t c = complete_value(json, len, pos, 0, max_depth); + if (c.found) return c.end_offset; + + switch (json[pos]) { + case '"': { + size_t cur = pos + 1; + bool escaped = false; + while (cur < len) { + if (json[cur] == '\\') escaped = !escaped; + else if (json[cur] == '"' && !escaped) return cur + 1; + else escaped = false; + cur++; + } + return cur; + } + case '{': case '[': { + char open = json[pos]; + char close = (open == '{') ? '}' : ']'; + int level = 0; + size_t cur = pos; + bool in_str = false; + bool esc = false; + while (cur < len) { + char ch = json[cur]; + if (in_str) { + if (ch == '\\') esc = !esc; + else if (ch == '"' && !esc) in_str = false; + else esc = false; + } else { + if (ch == '"') { in_str = true; esc = false; } + else if (ch == open) level++; + else if (ch == close) { level--; if (level == 0) return cur + 1; } + } + cur++; + } + return cur; + } + case 't': + if (pos + 4 <= len && memcmp(json + pos, "true", 4) == 0) return pos + 4; + break; + case 'f': + if (pos + 5 <= len && memcmp(json + pos, "false", 5) == 0) return pos + 5; + break; + case 'n': + if (pos + 4 <= len && memcmp(json + pos, "null", 4) == 0) return pos + 4; + break; + case '-': case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': { + size_t cur = pos; + while (cur < len && (json[cur] == '-' || json[cur] == '+' || + json[cur] == '.' || json[cur] == 'e' || + json[cur] == 'E' || + (json[cur] >= '0' && json[cur] <= '9'))) { + cur++; + } + return cur; + } + default: + break; + } + return pos; +} + +int64_t conduit_json_complete( + const char *input, + size_t input_len, + char *output, + size_t output_capacity, + int max_depth +) { + if (output_capacity < 1) return -1; + + if (input_len == 0) { + output[0] = '\0'; + return 0; + } + + if (max_depth < 1) max_depth = 64; + + completion_t c = complete_value(input, input_len, 0, 0, max_depth); + + if (!c.found) { + output[0] = '\0'; + return 0; + } + + // Output the FULL completed string: input[0..end_offset] + suffix + size_t total = c.end_offset + c.suffix_len; + if (total + 1 > output_capacity) return -1; + + memcpy(output, input, c.end_offset); + memcpy(output + c.end_offset, c.suffix, c.suffix_len); + output[total] = '\0'; + + return (int64_t)total; +} diff --git a/Sources/ConduitCore/src/conduit_json_repair.c b/Sources/ConduitCore/src/conduit_json_repair.c new file mode 100644 index 0000000..52eb57c --- /dev/null +++ b/Sources/ConduitCore/src/conduit_json_repair.c @@ -0,0 +1,358 @@ +// conduit_json_repair.c +// ConduitCore +// +// Single-pass JSON repair on raw UTF-8 bytes. Closes unclosed strings, +// arrays, objects. Removes trailing commas and incomplete key-value pairs. +// No heap allocation beyond the caller-provided output buffer. + +#include "conduit_core.h" +#include +#include + +// Bracket types for the stack +typedef enum { BRACKET_BRACE = 0, BRACKET_SQUARE = 1 } bracket_type_t; + +// JSON context for determining if a trailing string is a key or array element +typedef enum { CTX_UNKNOWN = 0, CTX_OBJECT = 1, CTX_ARRAY = 2 } json_context_t; + +// Helper: skip trailing whitespace from the end of the output +static size_t trim_trailing_whitespace(char *buf, size_t len) { + while (len > 0 && (buf[len - 1] == ' ' || buf[len - 1] == '\t' || + buf[len - 1] == '\n' || buf[len - 1] == '\r')) { + len--; + } + return len; +} + +// Helper: check if a byte is a hex digit +static bool is_hex_digit(char c) { + return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); +} + +// Helper: remove partial unicode escape at end of output +// Looks for \uX, \uXX, \uXXX patterns +static size_t remove_partial_unicode_escape(char *buf, size_t len) { + if (len < 2) return len; + + // Search the last 6 chars for a backslash + size_t search_start = len > 6 ? len - 6 : 0; + size_t backslash_pos = len; // sentinel + + for (size_t i = search_start; i < len; i++) { + if (buf[i] == '\\') backslash_pos = i; + } + + if (backslash_pos >= len) return len; + if (backslash_pos + 1 >= len) return len; + + if (buf[backslash_pos + 1] == 'u') { + // Count hex digits after \u + size_t hex_count = 0; + for (size_t i = backslash_pos + 2; i < len && is_hex_digit(buf[i]); i++) { + hex_count++; + } + if (hex_count < 4) { + return backslash_pos; // Remove the entire \uXX... sequence + } + } + + return len; +} + +// Helper: find the innermost unmatched opener by scanning forward with string-awareness. +// A backward scan without string tracking would miscount brackets inside string literals +// (e.g. {"key": "[value"} — the '[' inside the string is not an array opener). +static json_context_t find_context(const char *buf, size_t len) { + bracket_type_t stack[256]; + int depth = 0; + bool in_string = false; + bool escape_next = false; + + for (size_t i = 0; i < len; i++) { + char c = buf[i]; + if (escape_next) { + escape_next = false; + continue; + } + if (in_string) { + if (c == '\\') escape_next = true; + else if (c == '"') in_string = false; + continue; + } + switch (c) { + case '"': in_string = true; break; + case '{': + if (depth < 256) stack[depth++] = BRACKET_BRACE; + break; + case '}': + if (depth > 0) depth--; + break; + case '[': + if (depth < 256) stack[depth++] = BRACKET_SQUARE; + break; + case ']': + if (depth > 0) depth--; + break; + default: + break; + } + } + + if (depth == 0) return CTX_UNKNOWN; + return (stack[depth - 1] == BRACKET_BRACE) ? CTX_OBJECT : CTX_ARRAY; +} + +// Helper: remove incomplete key-value pairs from end of output +static size_t remove_incomplete_kvp(char *buf, size_t len) { + len = trim_trailing_whitespace(buf, len); + + // Pattern: trailing comma — remove it + if (len > 0 && buf[len - 1] == ',') { + len--; + len = trim_trailing_whitespace(buf, len); + } + + // Pattern: ends with colon (key without value) — remove key: + if (len > 0 && buf[len - 1] == ':') { + len--; // remove colon + len = trim_trailing_whitespace(buf, len); + + // Now remove the quoted key + if (len > 0 && buf[len - 1] == '"') { + len--; // remove closing quote + // Find opening quote + while (len > 0 && buf[len - 1] != '"') { + len--; + } + if (len > 0) len--; // remove opening quote + + // Remove preceding comma and whitespace + len = trim_trailing_whitespace(buf, len); + if (len > 0 && buf[len - 1] == ',') { + len--; + } + } + } + + // Pattern: ends with a quoted string that might be an incomplete key in object context + if (len > 0 && buf[len - 1] == '"') { + // Find the start of this string + size_t close_quote = len - 1; + size_t idx = close_quote; + if (idx > 0) idx--; // skip past the closing quote + + while (idx > 0) { + if (buf[idx] == '"') { + // Check if escaped + size_t backslash_count = 0; + size_t check = idx; + while (check > 0 && buf[check - 1] == '\\') { + backslash_count++; + check--; + } + if (backslash_count % 2 == 0) { + break; // Found unescaped opening quote + } + } + idx--; + } + + // Check what precedes this string + size_t prev = idx; + if (prev > 0) prev--; + while (prev > 0 && (buf[prev] == ' ' || buf[prev] == '\t' || + buf[prev] == '\n' || buf[prev] == '\r')) { + prev--; + } + + if (prev < len && buf[prev] == '{') { + // Object start — this is definitely an incomplete key + len = idx; + len = trim_trailing_whitespace(buf, len); + } else if (prev < len && buf[prev] == ',') { + json_context_t ctx = find_context(buf, prev); + if (ctx == CTX_OBJECT) { + len = idx; + len = trim_trailing_whitespace(buf, len); + if (len > 0 && buf[len - 1] == ',') { + len--; + } + } + } + } + + return len; +} + +// Helper: remove trailing commas before closing brackets in a completed JSON string. +// Safe for in-place use (output == input) because out_idx <= read_idx at every step. +// input is non-const to permit aliased in-place calls without UB. +static size_t remove_trailing_commas(char *input, size_t input_len, + char *output, size_t output_capacity) { + size_t out = 0; + bool in_string = false; + bool escape_next = false; + + for (size_t i = 0; i < input_len && out < output_capacity - 1; i++) { + char c = input[i]; + + if (escape_next) { + escape_next = false; + output[out++] = c; + continue; + } + + if (in_string) { + if (c == '\\') escape_next = true; + else if (c == '"') in_string = false; + output[out++] = c; + continue; + } + + if (c == '"') { + in_string = true; + output[out++] = c; + continue; + } + + if (c == ',') { + // Look ahead for whitespace + closing bracket + size_t j = i + 1; + while (j < input_len && (input[j] == ' ' || input[j] == '\t' || + input[j] == '\n' || input[j] == '\r')) { + j++; + } + if (j < input_len && (input[j] == '}' || input[j] == ']')) { + continue; // Skip this comma + } + } + + output[out++] = c; + } + + return out; +} + +int64_t conduit_json_repair( + const char *input, + size_t input_len, + char *output, + size_t output_capacity, + int max_depth +) { + if (max_depth < 1) max_depth = 1; + if (output_capacity < 3) return -1; // Need at least "{}\0" + + // Skip leading/trailing whitespace from input + size_t start = 0; + while (start < input_len && (input[start] == ' ' || input[start] == '\t' || + input[start] == '\n' || input[start] == '\r')) { + start++; + } + + // Empty input → "{}" + if (start >= input_len) { + output[0] = '{'; + output[1] = '}'; + output[2] = '\0'; + return 2; + } + + // Parser state + bool in_string = false; + bool escape_next = false; + bracket_type_t bracket_stack[256]; // Use fixed stack (capped at 256 depth) + int stack_depth = 0; + int effective_max = max_depth < 256 ? max_depth : 256; + + // First pass: copy input to output while tracking state + // Guard against underflow: we need at least (effective_max + 2) bytes for closers + NUL. + if (output_capacity <= (size_t)(effective_max + 2)) return -1; + size_t out = 0; + size_t capacity_for_content = output_capacity - (size_t)(effective_max + 2); // Reserve space for closers + NUL + + for (size_t i = start; i < input_len && out < capacity_for_content; i++) { + char c = input[i]; + + if (escape_next) { + escape_next = false; + output[out++] = c; + continue; + } + + if (in_string) { + if (c == '\\') { + escape_next = true; + } else if (c == '"') { + in_string = false; + } + output[out++] = c; + continue; + } + + // Not in string + switch (c) { + case '"': + in_string = true; + break; + case '{': + if (stack_depth < effective_max) { + bracket_stack[stack_depth++] = BRACKET_BRACE; + } + break; + case '}': + if (stack_depth > 0) stack_depth--; + break; + case '[': + if (stack_depth < effective_max) { + bracket_stack[stack_depth++] = BRACKET_SQUARE; + } + break; + case ']': + if (stack_depth > 0) stack_depth--; + break; + default: + break; + } + + output[out++] = c; + } + + // If in string: handle partial unicode escape, remove trailing backslash, close quote + if (in_string) { + out = remove_partial_unicode_escape(output, out); + if (escape_next && out > 0 && output[out - 1] == '\\') { + out--; + } + if (out < output_capacity - 1) { + output[out++] = '"'; + } + } + + // Remove trailing whitespace and comma + out = trim_trailing_whitespace(output, out); + if (out > 0 && output[out - 1] == ',') { + out--; + } + + // Remove incomplete key-value pairs + out = remove_incomplete_kvp(output, out); + + // Close open brackets + for (int i = stack_depth - 1; i >= 0 && out < output_capacity - 1; i--) { + // Before adding closer, remove trailing comma + out = trim_trailing_whitespace(output, out); + if (out > 0 && output[out - 1] == ',') { + out--; + } + output[out++] = (bracket_stack[i] == BRACKET_BRACE) ? '}' : ']'; + } + + output[out] = '\0'; + + // Final pass: remove trailing commas before existing closing brackets (in-place) + size_t final_len = remove_trailing_commas(output, out, output, output_capacity); + output[final_len] = '\0'; + + return (int64_t)final_len; +} diff --git a/Sources/ConduitCore/src/conduit_line_buffer.c b/Sources/ConduitCore/src/conduit_line_buffer.c new file mode 100644 index 0000000..b75c0e4 --- /dev/null +++ b/Sources/ConduitCore/src/conduit_line_buffer.c @@ -0,0 +1,160 @@ +// conduit_line_buffer.c +// ConduitCore +// +// High-performance line buffer using a growable byte array with memchr()-based +// newline scanning. O(1) amortized line extraction via read pointer advancement +// (no memmove on every line like the Swift version's removeFirst). + +#include "conduit_core.h" +#include +#include + +struct conduit_line_buffer { + uint8_t *data; + size_t capacity; + size_t read_pos; // Start of unread data + size_t write_pos; // End of written data +}; + +conduit_line_buffer_t *conduit_line_buffer_create(size_t initial_capacity) { + conduit_line_buffer_t *buf = (conduit_line_buffer_t *)calloc(1, sizeof(conduit_line_buffer_t)); + if (!buf) return NULL; + + if (initial_capacity < 256) initial_capacity = 256; + + buf->data = (uint8_t *)malloc(initial_capacity); + if (!buf->data) { + free(buf); + return NULL; + } + + buf->capacity = initial_capacity; + buf->read_pos = 0; + buf->write_pos = 0; + + return buf; +} + +void conduit_line_buffer_destroy(conduit_line_buffer_t *buf) { + if (!buf) return; + free(buf->data); + free(buf); +} + +// Compact the buffer if read_pos has advanced past half the capacity +static void maybe_compact(conduit_line_buffer_t *buf) { + if (buf->read_pos > buf->capacity / 2 && buf->read_pos > 0) { + size_t pending = buf->write_pos - buf->read_pos; + if (pending > 0) { + memmove(buf->data, buf->data + buf->read_pos, pending); + } + buf->read_pos = 0; + buf->write_pos = pending; + } +} + +int conduit_line_buffer_append(conduit_line_buffer_t *buf, const uint8_t *data, size_t length) { + if (!buf || length == 0) return 0; + + size_t needed = buf->write_pos + length; + if (needed > buf->capacity) { + // Try compacting first + maybe_compact(buf); + needed = buf->write_pos + length; + + if (needed > buf->capacity) { + size_t new_cap = buf->capacity * 2; + if (new_cap < needed) new_cap = needed; + uint8_t *new_data = (uint8_t *)realloc(buf->data, new_cap); + if (!new_data) return -1; + buf->data = new_data; + buf->capacity = new_cap; + } + } + + memcpy(buf->data + buf->write_pos, data, length); + buf->write_pos += length; + return 0; +} + +int conduit_line_buffer_next_line( + conduit_line_buffer_t *buf, + char *line_out, + size_t line_out_capacity, + size_t *line_len +) { + if (!buf) return 0; + + size_t pending = buf->write_pos - buf->read_pos; + if (pending == 0) return 0; + + const uint8_t *start = buf->data + buf->read_pos; + + // Use memchr for fast newline scanning — this is the key optimization + // over Swift's firstIndex(where:) which checks two conditions per byte + const uint8_t *newline_lf = (const uint8_t *)memchr(start, '\n', pending); + const uint8_t *newline_cr = (const uint8_t *)memchr(start, '\r', pending); + + // Find the earliest newline + const uint8_t *newline = NULL; + if (newline_lf && newline_cr) { + newline = (newline_lf < newline_cr) ? newline_lf : newline_cr; + } else { + newline = newline_lf ? newline_lf : newline_cr; + } + + if (!newline) return 0; // No complete line yet + + size_t line_bytes = (size_t)(newline - start); + // Return -1 (not 0) when the line exists but is too large for the caller's buffer. + // Returning 0 would be indistinguishable from "no complete line yet" and cause + // callers polling in a loop to spin indefinitely on the unconsumed oversized line. + if (line_bytes >= line_out_capacity) return -1; + + // Copy the line (without delimiter) + memcpy(line_out, start, line_bytes); + line_out[line_bytes] = '\0'; + *line_len = line_bytes; + + // Consume the line + delimiter + size_t consume = line_bytes + 1; + + // Handle \r\n: if we consumed \r and next byte is \n, consume it too + if (*newline == '\r') { + size_t next_pos = buf->read_pos + consume; + if (next_pos < buf->write_pos && buf->data[next_pos] == '\n') { + consume++; + } + } + + buf->read_pos += consume; + + // Compact periodically + maybe_compact(buf); + + return 1; +} + +size_t conduit_line_buffer_pending(const conduit_line_buffer_t *buf) { + if (!buf) return 0; + return buf->write_pos - buf->read_pos; +} + +size_t conduit_line_buffer_drain(conduit_line_buffer_t *buf, char *out, size_t out_capacity) { + if (!buf) return 0; + + size_t pending = buf->write_pos - buf->read_pos; + if (pending == 0) return 0; + + size_t to_copy = pending < out_capacity ? pending : out_capacity; + memcpy(out, buf->data + buf->read_pos, to_copy); + // NUL-terminate when capacity permits (i.e. data did not fill the entire buffer). + // When out_capacity == to_copy the buffer is full of raw bytes with no room for NUL; + // callers must treat the output as raw bytes in that case. + if (to_copy < out_capacity) { + out[to_copy] = '\0'; + } + buf->read_pos += to_copy; + + return to_copy; +} diff --git a/Sources/ConduitCore/src/conduit_sse_parser.c b/Sources/ConduitCore/src/conduit_sse_parser.c new file mode 100644 index 0000000..c60b761 --- /dev/null +++ b/Sources/ConduitCore/src/conduit_sse_parser.c @@ -0,0 +1,251 @@ +// conduit_sse_parser.c +// ConduitCore +// +// Incremental SSE (Server-Sent Events) parser operating on UTF-8 byte buffers. +// No global state. Thread-safe when each parser instance is accessed by one thread. + +#include "conduit_core.h" +#include +#include +#include + +// Internal dynamic string buffer +typedef struct { + char *data; + size_t len; + size_t capacity; +} dyn_str_t; + +static void dyn_str_init(dyn_str_t *s) { + s->data = NULL; + s->len = 0; + s->capacity = 0; +} + +static void dyn_str_free(dyn_str_t *s) { + free(s->data); + s->data = NULL; + s->len = 0; + s->capacity = 0; +} + +static void dyn_str_clear(dyn_str_t *s) { + s->len = 0; +} + +static bool dyn_str_append(dyn_str_t *s, const char *data, size_t len) { + if (len == 0) return true; + size_t needed = s->len + len + 1; + if (needed > s->capacity) { + size_t new_cap = s->capacity ? s->capacity * 2 : 128; + if (new_cap < needed) new_cap = needed; + char *new_data = (char *)realloc(s->data, new_cap); + if (!new_data) return false; + s->data = new_data; + s->capacity = new_cap; + } + memcpy(s->data + s->len, data, len); + s->len += len; + s->data[s->len] = '\0'; + return true; +} + +static bool dyn_str_append_char(dyn_str_t *s, char c) { + return dyn_str_append(s, &c, 1); +} + +static const char *dyn_str_cstr(const dyn_str_t *s) { + return s->data ? s->data : ""; +} + +// Parser internals +struct conduit_sse_parser { + dyn_str_t current_id; + dyn_str_t current_event; + dyn_str_t current_data; + int current_retry; // -1 if not set + bool has_id; + bool has_event; + bool has_data; + + // Persistent state + dyn_str_t last_event_id; + int reconnection_time; +}; + +conduit_sse_parser_t *conduit_sse_parser_create(void) { + conduit_sse_parser_t *p = (conduit_sse_parser_t *)calloc(1, sizeof(conduit_sse_parser_t)); + if (!p) return NULL; + + dyn_str_init(&p->current_id); + dyn_str_init(&p->current_event); + dyn_str_init(&p->current_data); + dyn_str_init(&p->last_event_id); + p->current_retry = -1; + p->has_id = false; + p->has_event = false; + p->has_data = false; + p->reconnection_time = 3000; + + return p; +} + +void conduit_sse_parser_destroy(conduit_sse_parser_t *parser) { + if (!parser) return; + dyn_str_free(&parser->current_id); + dyn_str_free(&parser->current_event); + dyn_str_free(&parser->current_data); + dyn_str_free(&parser->last_event_id); + free(parser); +} + +static void reset_current_event(conduit_sse_parser_t *p) { + dyn_str_clear(&p->current_id); + dyn_str_clear(&p->current_event); + dyn_str_clear(&p->current_data); + p->current_retry = -1; + p->has_id = false; + p->has_event = false; + p->has_data = false; +} + +static void dispatch_if_needed(conduit_sse_parser_t *p, + conduit_sse_callback_t callback, + void *context) { + // If we have no data and no explicit id/event, nothing to dispatch + bool is_data_empty = (p->current_data.len == 0); + bool is_retry_only = is_data_empty && !p->has_id && !p->has_event && !p->has_data; + + if (is_retry_only) { + reset_current_event(p); + return; + } + + if (callback) { + conduit_sse_event_t event; + event.id = p->has_id ? dyn_str_cstr(&p->current_id) : NULL; + event.event = p->has_event ? dyn_str_cstr(&p->current_event) : NULL; + event.data = dyn_str_cstr(&p->current_data); + event.retry = p->current_retry; + + callback(&event, context); + } + + reset_current_event(p); +} + +void conduit_sse_ingest_line( + conduit_sse_parser_t *parser, + const char *line, + size_t length, + conduit_sse_callback_t callback, + void *context +) { + if (!parser) return; + + // Normalize: strip trailing \r (from CRLF) + while (length > 0 && line[length - 1] == '\r') { + length--; + } + + // Strip leading BOM + if (length >= 3 && (unsigned char)line[0] == 0xEF && + (unsigned char)line[1] == 0xBB && (unsigned char)line[2] == 0xBF) { + line += 3; + length -= 3; + } + + // Empty line → dispatch event + if (length == 0) { + dispatch_if_needed(parser, callback, context); + return; + } + + // Comment: starts with ':' + if (line[0] == ':') { + return; + } + + // Parse field:value + const char *colon = (const char *)memchr(line, ':', length); + const char *field = line; + size_t field_len; + const char *value; + size_t value_len; + + if (colon) { + field_len = (size_t)(colon - line); + value = colon + 1; + value_len = length - field_len - 1; + // Skip single leading space in value (per SSE spec) + if (value_len > 0 && value[0] == ' ') { + value++; + value_len--; + } + } else { + field_len = length; + value = ""; + value_len = 0; + } + + if (field_len == 5 && memcmp(field, "event", 5) == 0) { + dyn_str_clear(&parser->current_event); + dyn_str_append(&parser->current_event, value, value_len); + parser->has_event = true; + } + else if (field_len == 4 && memcmp(field, "data", 4) == 0) { + if (parser->current_data.len > 0) { + dyn_str_append_char(&parser->current_data, '\n'); + } + dyn_str_append(&parser->current_data, value, value_len); + parser->has_data = true; + } + else if (field_len == 2 && memcmp(field, "id", 2) == 0) { + // ID must not contain null byte + bool has_null = false; + for (size_t i = 0; i < value_len; i++) { + if (value[i] == '\0') { has_null = true; break; } + } + if (!has_null) { + dyn_str_clear(&parser->current_id); + dyn_str_append(&parser->current_id, value, value_len); + parser->has_id = true; + // Update last event id + dyn_str_clear(&parser->last_event_id); + dyn_str_append(&parser->last_event_id, value, value_len); + } + } + else if (field_len == 5 && memcmp(field, "retry", 5) == 0) { + // Parse as positive integer with overflow guard (max ~24 days in ms) + int ms = 0; + bool valid = (value_len > 0); + for (size_t i = 0; i < value_len && valid; i++) { + if (value[i] >= '0' && value[i] <= '9') { + if (ms > 214748364) { valid = false; break; } // prevent int overflow + ms = ms * 10 + (value[i] - '0'); + } else { + valid = false; + } + } + // Per WHATWG SSE spec §9.2.6, retry:0 is valid and should set the + // reconnection time to 0 ms. The previous guard `ms > 0` violated the spec. + if (valid) { + parser->reconnection_time = ms; + parser->current_retry = ms; + } + } + // Unknown fields are ignored +} + +void conduit_sse_finish( + conduit_sse_parser_t *parser, + conduit_sse_callback_t callback, + void *context +) { + if (!parser) return; + + // Only dispatch if we have non-empty data or explicit id/event + if (parser->current_data.len > 0 || parser->has_id || parser->has_event) { + dispatch_if_needed(parser, callback, context); + } +} diff --git a/Sources/ConduitCore/src/conduit_vector_ops.c b/Sources/ConduitCore/src/conduit_vector_ops.c new file mode 100644 index 0000000..24bff29 --- /dev/null +++ b/Sources/ConduitCore/src/conduit_vector_ops.c @@ -0,0 +1,127 @@ +// conduit_vector_ops.c +// ConduitCore +// +// High-performance vector operations for embedding similarity computation. +// Uses Accelerate/vDSP on Apple platforms, scalar fallback elsewhere. + +#include "conduit_core.h" +#include + +#ifdef CONDUIT_HAS_ACCELERATE +#include +#endif + +float conduit_dot_product(const float *a, const float *b, size_t count) { + if (count == 0) return 0.0f; + +#ifdef CONDUIT_HAS_ACCELERATE + float result = 0.0f; + vDSP_dotpr(a, 1, b, 1, &result, (vDSP_Length)count); + return result; +#else + float result = 0.0f; + for (size_t i = 0; i < count; i++) { + result += a[i] * b[i]; + } + return result; +#endif +} + +float conduit_cosine_similarity(const float *a, const float *b, size_t count) { + if (count == 0) return 0.0f; + +#ifdef CONDUIT_HAS_ACCELERATE + float dot = 0.0f, normA = 0.0f, normB = 0.0f; + vDSP_dotpr(a, 1, b, 1, &dot, (vDSP_Length)count); + vDSP_dotpr(a, 1, a, 1, &normA, (vDSP_Length)count); + vDSP_dotpr(b, 1, b, 1, &normB, (vDSP_Length)count); +#else + float dot = 0.0f, normA = 0.0f, normB = 0.0f; + for (size_t i = 0; i < count; i++) { + dot += a[i] * b[i]; + normA += a[i] * a[i]; + normB += b[i] * b[i]; + } +#endif + + float denom = sqrtf(normA) * sqrtf(normB); + return denom > 0.0f ? dot / denom : 0.0f; +} + +float conduit_euclidean_distance(const float *a, const float *b, size_t count) { + if (count == 0) return 0.0f; + +#ifdef CONDUIT_HAS_ACCELERATE + // diff = a - b, then compute sqrt(dot(diff, diff)) + float *diff = (float *)malloc(count * sizeof(float)); + if (!diff) { + // Fallback to scalar on allocation failure + float sum = 0.0f; + for (size_t i = 0; i < count; i++) { + float d = a[i] - b[i]; + sum += d * d; + } + return sqrtf(sum); + } + vDSP_vsub(b, 1, a, 1, diff, 1, (vDSP_Length)count); + float sum_sq = 0.0f; + vDSP_dotpr(diff, 1, diff, 1, &sum_sq, (vDSP_Length)count); + free(diff); + return sqrtf(sum_sq); +#else + float sum = 0.0f; + for (size_t i = 0; i < count; i++) { + float d = a[i] - b[i]; + sum += d * d; + } + return sqrtf(sum); +#endif +} + +void conduit_cosine_similarity_batch( + const float *query, + const float *vectors, + size_t dimensions, + size_t count, + float *results +) { + if (dimensions == 0 || count == 0) return; + + // Pre-compute query norm +#ifdef CONDUIT_HAS_ACCELERATE + float query_norm_sq = 0.0f; + vDSP_dotpr(query, 1, query, 1, &query_norm_sq, (vDSP_Length)dimensions); +#else + float query_norm_sq = 0.0f; + for (size_t i = 0; i < dimensions; i++) { + query_norm_sq += query[i] * query[i]; + } +#endif + float query_norm = sqrtf(query_norm_sq); + + if (query_norm == 0.0f) { + for (size_t i = 0; i < count; i++) { + results[i] = 0.0f; + } + return; + } + + for (size_t v = 0; v < count; v++) { + const float *vec = vectors + v * dimensions; + +#ifdef CONDUIT_HAS_ACCELERATE + float dot = 0.0f, vec_norm_sq = 0.0f; + vDSP_dotpr(query, 1, vec, 1, &dot, (vDSP_Length)dimensions); + vDSP_dotpr(vec, 1, vec, 1, &vec_norm_sq, (vDSP_Length)dimensions); +#else + float dot = 0.0f, vec_norm_sq = 0.0f; + for (size_t i = 0; i < dimensions; i++) { + dot += query[i] * vec[i]; + vec_norm_sq += vec[i] * vec[i]; + } +#endif + + float vec_norm = sqrtf(vec_norm_sq); + results[v] = vec_norm > 0.0f ? dot / (query_norm * vec_norm) : 0.0f; + } +} diff --git a/Tests/ConduitCoreTests/JsonCompleterCTests.swift b/Tests/ConduitCoreTests/JsonCompleterCTests.swift new file mode 100644 index 0000000..b6feb01 --- /dev/null +++ b/Tests/ConduitCoreTests/JsonCompleterCTests.swift @@ -0,0 +1,162 @@ +// JsonCompleterCTests.swift +// ConduitCoreTests + +import Foundation +import Testing +import ConduitCore + +@Suite("JSON Completer C Tests") +struct JsonCompleterCTests { + + // Helper: calls conduit_json_complete which now returns the FULL completed string + // Returns empty string if JSON is already complete, or the full completed string otherwise + func completeRaw(_ json: String, maxDepth: Int = 64) -> (output: String, resultCode: Int64) { + let input = Array(json.utf8) + let capacity = input.count + 256 + var output = [CChar](repeating: 0, count: capacity) + + let result = input.withUnsafeBufferPointer { inputBuf in + output.withUnsafeMutableBufferPointer { outputBuf in + conduit_json_complete( + inputBuf.baseAddress, + inputBuf.count, + outputBuf.baseAddress, + outputBuf.count, + Int32(maxDepth) + ) + } + } + + return (String(cString: output), result) + } + + // Helper: returns full completed JSON (the original if already complete) + func fullComplete(_ json: String) -> String { + let (output, code) = completeRaw(json) + if code == 0 { return json } // Already complete + if code < 0 { return json } // Error fallback + return output + } + + // MARK: - Complete JSON returns code 0 (already complete) + + @Test("Complete object needs no completion") + func completeObjectNoSuffix() { + let (_, code) = completeRaw(#"{"a": 1}"#) + #expect(code == 0) + } + + @Test("Complete array needs no completion") + func completeArrayNoSuffix() { + let (_, code) = completeRaw("[1, 2, 3]") + #expect(code == 0) + } + + @Test("Complete string needs no completion") + func completeStringNoSuffix() { + let (_, code) = completeRaw(#""hello""#) + #expect(code == 0) + } + + @Test("Empty input returns code 0") + func emptyInput() { + let (_, code) = completeRaw("") + #expect(code == 0) + } + + // MARK: - String Completion + + @Test("Unclosed string gets closing quote") + func unclosedString() { + #expect(fullComplete(#""hello"#) == #""hello""#) + } + + // MARK: - Object Completion + + @Test("Unclosed object with value") + func unclosedObjectWithValue() { + let result = fullComplete(#"{"a": 1"#) + #expect(result == #"{"a": 1}"#) + } + + @Test("Object with incomplete key gets null value") + func objectIncompleteKey() { + let result = fullComplete(#"{"name"#) + // Should complete the string and add : null} + #expect(result == #"{"name": null}"#) + } + + @Test("Object with key and colon but no value") + func objectKeyColonNoValue() { + let result = fullComplete(#"{"name": "#) + // Trailing space after colon is before the completion point, + // so the output truncates at the colon position: {"name":null} + #expect(result == #"{"name":null}"#) + } + + @Test("Object with incomplete string value") + func objectIncompleteStringValue() { + let result = fullComplete(#"{"name": "Alice"#) + #expect(result == #"{"name": "Alice"}"#) + } + + @Test("Nested objects get closed") + func nestedObjectsClosed() { + let result = fullComplete(#"{"a": {"b": 1"#) + #expect(result == #"{"a": {"b": 1}}"#) + } + + // MARK: - Array Completion + + @Test("Unclosed array") + func unclosedArray() { + let result = fullComplete("[1, 2") + #expect(result == "[1, 2]") + } + + @Test("Array with incomplete element") + func arrayIncompleteElement() { + let result = fullComplete(#"[1, "hello"#) + #expect(result == #"[1, "hello"]"#) + } + + // MARK: - Special Values + + @Test("Partial true gets completed") + func partialTrue() { + #expect(fullComplete("tr") == "true") + } + + @Test("Partial false gets completed") + func partialFalse() { + #expect(fullComplete("fal") == "false") + } + + @Test("Partial null gets completed") + func partialNull() { + #expect(fullComplete("nu") == "null") + } + + // MARK: - Number Completion + + @Test("Complete number needs no completion") + func completeNumber() { + let (_, code) = completeRaw("42") + #expect(code == 0) + } + + @Test("Bare minus gets -0") + func bareMinus() { + #expect(fullComplete("-") == "-0") + } + + @Test("Decimal point without fraction") + func decimalNoFraction() { + #expect(fullComplete("3.") == "3.0") + } + + @Test("Minus-dot gets -0.0") + func minusDot() { + #expect(fullComplete("-.") == "-0.0") + } +} diff --git a/Tests/ConduitCoreTests/JsonRepairCTests.swift b/Tests/ConduitCoreTests/JsonRepairCTests.swift new file mode 100644 index 0000000..36b7f44 --- /dev/null +++ b/Tests/ConduitCoreTests/JsonRepairCTests.swift @@ -0,0 +1,176 @@ +// JsonRepairCTests.swift +// ConduitCoreTests + +import Foundation +import Testing +import ConduitCore + +@Suite("JSON Repair C Tests") +struct JsonRepairCTests { + + // Helper to call the C repair function + func repair(_ json: String, maxDepth: Int = 64) -> String { + let input = Array(json.utf8) + let capacity = input.count + 256 + var output = [CChar](repeating: 0, count: capacity) + + let result = input.withUnsafeBufferPointer { inputBuf in + output.withUnsafeMutableBufferPointer { outputBuf in + conduit_json_repair( + inputBuf.baseAddress, + inputBuf.count, + outputBuf.baseAddress, + outputBuf.count, + Int32(maxDepth) + ) + } + } + + guard result >= 0 else { return "" } + return String(cString: output) + } + + // MARK: - String Repairs + + @Test("Unclosed string at end gets closing quote and brace") + func unclosedStringAtEnd() { + let repaired = repair(#"{"name": "Alice"#) + #expect(repaired == #"{"name": "Alice"}"#) + } + + @Test("Incomplete escape sequence at end is handled") + func incompleteEscapeSequence() { + let repaired = repair(#"{"text": "hello\"#) + #expect(repaired == #"{"text": "hello"}"#) + } + + @Test("Partial unicode escape is removed") + func partialUnicodeEscape() { + let repaired = repair(#"{"text": "\u00"#) + #expect(repaired == #"{"text": ""}"#) + } + + @Test("Normal strings pass through unchanged") + func normalStringsUnchanged() { + let complete = #"{"name": "Alice", "age": 30}"# + #expect(repair(complete) == complete) + } + + // MARK: - Object Repairs + + @Test("Unclosed single object") + func unclosedSingleObject() { + #expect(repair(#"{"a": 1"#) == #"{"a": 1}"#) + } + + @Test("Nested unclosed objects") + func nestedUnclosedObjects() { + #expect(repair(#"{"user": {"name": "Bob""#) == #"{"user": {"name": "Bob"}}"#) + } + + @Test("Trailing comma before close removed") + func trailingCommaRemoved() { + #expect(repair(#"{"a": 1,}"#) == #"{"a": 1}"#) + } + + @Test("Multiple trailing commas and whitespace") + func multipleTrailingCommasAndWhitespace() { + #expect(repair(#"{"a": 1, "#) == #"{"a": 1}"#) + } + + @Test("Valid object passes through unchanged") + func validObjectUnchanged() { + let json = #"{"name": "Alice", "age": 30, "active": true}"# + #expect(repair(json) == json) + } + + @Test("Empty object is valid") + func emptyObjectValid() { + #expect(repair("{}") == "{}") + } + + // MARK: - Array Repairs + + @Test("Unclosed array") + func unclosedArray() { + #expect(repair("[1, 2, 3") == "[1, 2, 3]") + } + + @Test("Nested unclosed arrays") + func nestedUnclosedArrays() { + #expect(repair("[[1, 2, [3, 4") == "[[1, 2, [3, 4]]]") + } + + @Test("Mixed array and object closures") + func mixedArrayObjectClosures() { + #expect(repair(#"{"arr": [1, 2"#) == #"{"arr": [1, 2]}"#) + } + + @Test("Array trailing comma fixed") + func arrayTrailingComma() { + #expect(repair("[1, 2, 3,]") == "[1, 2, 3]") + } + + @Test("Empty array is valid") + func emptyArrayValid() { + #expect(repair("[]") == "[]") + } + + // MARK: - Edge Cases + + @Test("Empty input returns empty object") + func emptyInput() { + #expect(repair("") == "{}") + } + + @Test("Deeply nested structures (5+ levels)") + func deeplyNested() { + let repaired = repair(#"{"a": {"b": {"c": {"d": {"e": "value""#) + #expect(repaired == #"{"a": {"b": {"c": {"d": {"e": "value"}}}}}"#) + } + + @Test("Mixed nesting deep") + func mixedNestingDeep() { + let repaired = repair(#"{"data": [{"items": [1, 2, {"nested": [3, 4"#) + #expect(repaired == #"{"data": [{"items": [1, 2, {"nested": [3, 4]}]}]}"#) + } + + @Test("Single opening brace") + func singleBrace() { + #expect(repair("{") == "{}") + } + + @Test("Single opening bracket") + func singleBracket() { + #expect(repair("[") == "[]") + } + + @Test("Whitespace only input") + func whitespaceOnly() { + #expect(repair(" ") == "{}") + } + + @Test("Bracket inside string value is not mistaken for array opener") + func bracketInsideStringValue() { + // The '[' inside the string literal must NOT be treated as an array opener + // by find_context when deciding whether a trailing incomplete key is in + // an object or array context. + let repaired = repair(#"{"key": "[value", "#) + // Should not produce an extra ']' closer from the bracket inside the string + #expect(!repaired.hasSuffix("]}")) + // Must still be valid-ish JSON ending with '}' + #expect(repaired.hasSuffix("}")) + } + + @Test("Boolean values") + func booleanValues() { + #expect(repair(#"{"active": true, "verified": false"#) == + #"{"active": true, "verified": false}"#) + } + + @Test("Null values") + func nullValues() { + #expect(repair(#"{"value": null, "other": 42"#) == + #"{"value": null, "other": 42}"#) + } +} diff --git a/Tests/ConduitCoreTests/LineBufferCTests.swift b/Tests/ConduitCoreTests/LineBufferCTests.swift new file mode 100644 index 0000000..c4cfc5d --- /dev/null +++ b/Tests/ConduitCoreTests/LineBufferCTests.swift @@ -0,0 +1,228 @@ +// LineBufferCTests.swift +// ConduitCoreTests + +import Foundation +import Testing +import ConduitCore + +@Suite("Line Buffer C Tests") +struct LineBufferCTests { + + // MARK: - Basic Line Extraction + + @Test("Extract single line ending with LF") + func singleLineLF() { + let buf = conduit_line_buffer_create(256)! + defer { conduit_line_buffer_destroy(buf) } + + let data: [UInt8] = Array("hello\n".utf8) + conduit_line_buffer_append(buf, data, data.count) + + var line = [CChar](repeating: 0, count: 256) + var lineLen: Int = 0 + let result = conduit_line_buffer_next_line(buf, &line, 256, &lineLen) + + #expect(result == 1) + #expect(lineLen == 5) + #expect(String(cString: line) == "hello") + } + + @Test("Extract single line ending with CR") + func singleLineCR() { + let buf = conduit_line_buffer_create(256)! + defer { conduit_line_buffer_destroy(buf) } + + let data: [UInt8] = Array("hello\r".utf8) + conduit_line_buffer_append(buf, data, data.count) + + var line = [CChar](repeating: 0, count: 256) + var lineLen: Int = 0 + let result = conduit_line_buffer_next_line(buf, &line, 256, &lineLen) + + #expect(result == 1) + #expect(String(cString: line) == "hello") + } + + @Test("Extract single line ending with CRLF") + func singleLineCRLF() { + let buf = conduit_line_buffer_create(256)! + defer { conduit_line_buffer_destroy(buf) } + + let data: [UInt8] = Array("hello\r\n".utf8) + conduit_line_buffer_append(buf, data, data.count) + + var line = [CChar](repeating: 0, count: 256) + var lineLen: Int = 0 + let result = conduit_line_buffer_next_line(buf, &line, 256, &lineLen) + + #expect(result == 1) + #expect(String(cString: line) == "hello") + } + + // MARK: - Multiple Lines + + @Test("Extract multiple lines") + func multipleLines() { + let buf = conduit_line_buffer_create(256)! + defer { conduit_line_buffer_destroy(buf) } + + let data: [UInt8] = Array("line1\nline2\nline3\n".utf8) + conduit_line_buffer_append(buf, data, data.count) + + var line = [CChar](repeating: 0, count: 256) + var lineLen: Int = 0 + + #expect(conduit_line_buffer_next_line(buf, &line, 256, &lineLen) == 1) + #expect(String(cString: line) == "line1") + + #expect(conduit_line_buffer_next_line(buf, &line, 256, &lineLen) == 1) + #expect(String(cString: line) == "line2") + + #expect(conduit_line_buffer_next_line(buf, &line, 256, &lineLen) == 1) + #expect(String(cString: line) == "line3") + + // No more lines + #expect(conduit_line_buffer_next_line(buf, &line, 256, &lineLen) == 0) + } + + // MARK: - Partial Data + + @Test("No line returned for incomplete data") + func incompleteData() { + let buf = conduit_line_buffer_create(256)! + defer { conduit_line_buffer_destroy(buf) } + + let data: [UInt8] = Array("partial".utf8) + conduit_line_buffer_append(buf, data, data.count) + + var line = [CChar](repeating: 0, count: 256) + var lineLen: Int = 0 + + #expect(conduit_line_buffer_next_line(buf, &line, 256, &lineLen) == 0) + } + + @Test("Incremental append completes a line") + func incrementalAppend() { + let buf = conduit_line_buffer_create(256)! + defer { conduit_line_buffer_destroy(buf) } + + var line = [CChar](repeating: 0, count: 256) + var lineLen: Int = 0 + + let part1: [UInt8] = Array("hel".utf8) + conduit_line_buffer_append(buf, part1, part1.count) + #expect(conduit_line_buffer_next_line(buf, &line, 256, &lineLen) == 0) + + let part2: [UInt8] = Array("lo\n".utf8) + conduit_line_buffer_append(buf, part2, part2.count) + #expect(conduit_line_buffer_next_line(buf, &line, 256, &lineLen) == 1) + #expect(String(cString: line) == "hello") + } + + // MARK: - Empty Lines + + @Test("Empty line between two lines") + func emptyLineBetween() { + let buf = conduit_line_buffer_create(256)! + defer { conduit_line_buffer_destroy(buf) } + + let data: [UInt8] = Array("a\n\nb\n".utf8) + conduit_line_buffer_append(buf, data, data.count) + + var line = [CChar](repeating: 0, count: 256) + var lineLen: Int = 0 + + #expect(conduit_line_buffer_next_line(buf, &line, 256, &lineLen) == 1) + #expect(String(cString: line) == "a") + + #expect(conduit_line_buffer_next_line(buf, &line, 256, &lineLen) == 1) + #expect(lineLen == 0) // Empty line + #expect(String(cString: line) == "") + + #expect(conduit_line_buffer_next_line(buf, &line, 256, &lineLen) == 1) + #expect(String(cString: line) == "b") + } + + // MARK: - Drain + + @Test("Drain returns remaining bytes") + func drainRemainder() { + let buf = conduit_line_buffer_create(256)! + defer { conduit_line_buffer_destroy(buf) } + + let data: [UInt8] = Array("line1\nremainder".utf8) + conduit_line_buffer_append(buf, data, data.count) + + var line = [CChar](repeating: 0, count: 256) + var lineLen: Int = 0 + + #expect(conduit_line_buffer_next_line(buf, &line, 256, &lineLen) == 1) + #expect(String(cString: line) == "line1") + + // Drain the remainder + var remainder = [CChar](repeating: 0, count: 256) + let drainLen = conduit_line_buffer_drain(buf, &remainder, 256) + #expect(drainLen == 9) // "remainder" = 9 bytes + } + + // MARK: - Pending Count + + @Test("Pending count tracks buffered bytes") + func pendingCount() { + let buf = conduit_line_buffer_create(256)! + defer { conduit_line_buffer_destroy(buf) } + + #expect(conduit_line_buffer_pending(buf) == 0) + + let data: [UInt8] = Array("hello\n".utf8) + conduit_line_buffer_append(buf, data, data.count) + #expect(conduit_line_buffer_pending(buf) == 6) + + var line = [CChar](repeating: 0, count: 256) + var lineLen: Int = 0 + conduit_line_buffer_next_line(buf, &line, 256, &lineLen) + #expect(conduit_line_buffer_pending(buf) == 0) + } + + @Test("next_line returns -1 (not 0) when line exceeds output buffer capacity") + func nextLineOversizedLine() { + let buf = conduit_line_buffer_create(1024)! + defer { conduit_line_buffer_destroy(buf) } + + // Append a line that is longer than the output buffer we'll provide + let longLine: [UInt8] = Array("ABCDEFGHIJ\n".utf8) // 10 chars + newline + conduit_line_buffer_append(buf, longLine, longLine.count) + + // Provide a 4-byte output buffer — too small for the 10-char line + var line = [CChar](repeating: 0, count: 4) + var lineLen: Int = 0 + let result = conduit_line_buffer_next_line(buf, &line, 4, &lineLen) + + // Must return -1 (not 0) so callers can distinguish "no line" from "line too large" + #expect(result == -1) + // The oversized line must remain unconsumed so callers can retry with a larger buffer + #expect(conduit_line_buffer_pending(buf) == longLine.count) + } + + // MARK: - SSE Simulation + + @Test("Simulated SSE stream with mixed delimiters") + func sseSimulation() { + let buf = conduit_line_buffer_create(1024)! + defer { conduit_line_buffer_destroy(buf) } + + // Simulate an SSE stream chunk + let chunk: [UInt8] = Array("data: hello\r\n\r\ndata: world\r\n\r\n".utf8) + conduit_line_buffer_append(buf, chunk, chunk.count) + + var line = [CChar](repeating: 0, count: 256) + var lineLen: Int = 0 + var lines: [String] = [] + + while conduit_line_buffer_next_line(buf, &line, 256, &lineLen) == 1 { + lines.append(String(cString: line)) + } + + #expect(lines == ["data: hello", "", "data: world", ""]) + } +} diff --git a/Tests/ConduitCoreTests/SSEParserCTests.swift b/Tests/ConduitCoreTests/SSEParserCTests.swift new file mode 100644 index 0000000..6c96c10 --- /dev/null +++ b/Tests/ConduitCoreTests/SSEParserCTests.swift @@ -0,0 +1,247 @@ +// SSEParserCTests.swift +// ConduitCoreTests + +import Foundation +import Testing +import ConduitCore + +@Suite("SSE Parser C Tests") +struct SSEParserCTests { + + // Collected events from callbacks + final class EventCollector: @unchecked Sendable { + struct Event { + let id: String? + let event: String? + let data: String + let retry: Int + } + + var events: [Event] = [] + + static let callback: conduit_sse_callback_t = { eventPtr, contextPtr in + guard let event = eventPtr, let ctx = contextPtr else { return } + let collector = Unmanaged.fromOpaque(ctx).takeUnretainedValue() + + let id = event.pointee.id.map { String(cString: $0) } + let eventType = event.pointee.event.map { String(cString: $0) } + let data = String(cString: event.pointee.data) + let retry = Int(event.pointee.retry) + + collector.events.append(Event(id: id, event: eventType, data: data, retry: retry)) + } + } + + func makeParser() -> OpaquePointer { + conduit_sse_parser_create()! + } + + func ingest(_ parser: OpaquePointer, line: String, collector: EventCollector) { + let ctx = Unmanaged.passUnretained(collector).toOpaque() + line.withCString { cstr in + conduit_sse_ingest_line(parser, cstr, strlen(cstr), EventCollector.callback, ctx) + } + } + + func finish(_ parser: OpaquePointer, collector: EventCollector) { + let ctx = Unmanaged.passUnretained(collector).toOpaque() + conduit_sse_finish(parser, EventCollector.callback, ctx) + } + + // MARK: - Basic Events + + @Test("Simple data event dispatched on empty line") + func simpleDataEvent() { + let parser = makeParser() + defer { conduit_sse_parser_destroy(parser) } + let collector = EventCollector() + + ingest(parser, line: "data: hello world", collector: collector) + #expect(collector.events.isEmpty) // Not dispatched yet + + ingest(parser, line: "", collector: collector) // Empty line dispatches + #expect(collector.events.count == 1) + #expect(collector.events[0].data == "hello world") + #expect(collector.events[0].event == nil) + #expect(collector.events[0].id == nil) + } + + @Test("Event with type") + func eventWithType() { + let parser = makeParser() + defer { conduit_sse_parser_destroy(parser) } + let collector = EventCollector() + + ingest(parser, line: "event: update", collector: collector) + ingest(parser, line: "data: some data", collector: collector) + ingest(parser, line: "", collector: collector) + + #expect(collector.events.count == 1) + #expect(collector.events[0].event == "update") + #expect(collector.events[0].data == "some data") + } + + @Test("Event with id") + func eventWithId() { + let parser = makeParser() + defer { conduit_sse_parser_destroy(parser) } + let collector = EventCollector() + + ingest(parser, line: "id: 42", collector: collector) + ingest(parser, line: "data: test", collector: collector) + ingest(parser, line: "", collector: collector) + + #expect(collector.events.count == 1) + #expect(collector.events[0].id == "42") + #expect(collector.events[0].data == "test") + } + + @Test("Multi-line data joined with newlines") + func multiLineData() { + let parser = makeParser() + defer { conduit_sse_parser_destroy(parser) } + let collector = EventCollector() + + ingest(parser, line: "data: line1", collector: collector) + ingest(parser, line: "data: line2", collector: collector) + ingest(parser, line: "data: line3", collector: collector) + ingest(parser, line: "", collector: collector) + + #expect(collector.events.count == 1) + #expect(collector.events[0].data == "line1\nline2\nline3") + } + + // MARK: - Comments and Unknown Fields + + @Test("Comments are ignored") + func commentsIgnored() { + let parser = makeParser() + defer { conduit_sse_parser_destroy(parser) } + let collector = EventCollector() + + ingest(parser, line: ": this is a comment", collector: collector) + ingest(parser, line: "data: actual data", collector: collector) + ingest(parser, line: "", collector: collector) + + #expect(collector.events.count == 1) + #expect(collector.events[0].data == "actual data") + } + + @Test("Unknown fields are ignored") + func unknownFieldsIgnored() { + let parser = makeParser() + defer { conduit_sse_parser_destroy(parser) } + let collector = EventCollector() + + ingest(parser, line: "custom: field", collector: collector) + ingest(parser, line: "data: test", collector: collector) + ingest(parser, line: "", collector: collector) + + #expect(collector.events.count == 1) + #expect(collector.events[0].data == "test") + } + + // MARK: - Retry + + @Test("Retry field parsed correctly") + func retryField() { + let parser = makeParser() + defer { conduit_sse_parser_destroy(parser) } + let collector = EventCollector() + + ingest(parser, line: "retry: 5000", collector: collector) + ingest(parser, line: "data: test", collector: collector) + ingest(parser, line: "", collector: collector) + + #expect(collector.events.count == 1) + #expect(collector.events[0].retry == 5000) + } + + @Test("Retry field with value 0 is accepted per WHATWG SSE spec") + func retryZero() { + let parser = makeParser() + defer { conduit_sse_parser_destroy(parser) } + let collector = EventCollector() + + // WHATWG SSE spec §9.2.6: retry:0 must set reconnection time to 0 ms + ingest(parser, line: "retry: 0", collector: collector) + ingest(parser, line: "data: test", collector: collector) + ingest(parser, line: "", collector: collector) + + #expect(collector.events.count == 1) + #expect(collector.events[0].retry == 0) + } + + // MARK: - Finish + + @Test("Finish flushes pending event") + func finishFlushes() { + let parser = makeParser() + defer { conduit_sse_parser_destroy(parser) } + let collector = EventCollector() + + ingest(parser, line: "data: pending", collector: collector) + #expect(collector.events.isEmpty) + + finish(parser, collector: collector) + #expect(collector.events.count == 1) + #expect(collector.events[0].data == "pending") + } + + @Test("Finish with nothing pending produces no event") + func finishNoPending() { + let parser = makeParser() + defer { conduit_sse_parser_destroy(parser) } + let collector = EventCollector() + + finish(parser, collector: collector) + #expect(collector.events.isEmpty) + } + + // MARK: - Multiple Events + + @Test("Multiple events in sequence") + func multipleEvents() { + let parser = makeParser() + defer { conduit_sse_parser_destroy(parser) } + let collector = EventCollector() + + ingest(parser, line: "data: first", collector: collector) + ingest(parser, line: "", collector: collector) + ingest(parser, line: "data: second", collector: collector) + ingest(parser, line: "", collector: collector) + + #expect(collector.events.count == 2) + #expect(collector.events[0].data == "first") + #expect(collector.events[1].data == "second") + } + + // MARK: - Edge Cases + + @Test("Leading space in data value stripped per spec") + func leadingSpaceStripped() { + let parser = makeParser() + defer { conduit_sse_parser_destroy(parser) } + let collector = EventCollector() + + ingest(parser, line: "data: two spaces", collector: collector) + ingest(parser, line: "", collector: collector) + + // Only first space stripped, second kept + #expect(collector.events[0].data == " two spaces") + } + + @Test("Field with no colon uses whole line as field name") + func noColonFieldName() { + let parser = makeParser() + defer { conduit_sse_parser_destroy(parser) } + let collector = EventCollector() + + // "data" with no colon is treated as field "data" with empty value + ingest(parser, line: "data", collector: collector) + ingest(parser, line: "", collector: collector) + + #expect(collector.events.count == 1) + #expect(collector.events[0].data == "") + } +} diff --git a/Tests/ConduitCoreTests/VectorOpsCTests.swift b/Tests/ConduitCoreTests/VectorOpsCTests.swift new file mode 100644 index 0000000..9f9fe81 --- /dev/null +++ b/Tests/ConduitCoreTests/VectorOpsCTests.swift @@ -0,0 +1,168 @@ +// VectorOpsCTests.swift +// ConduitCoreTests + +import Foundation +import Testing +import ConduitCore + +@Suite("Vector Operations C Tests") +struct VectorOpsCTests { + + // MARK: - Dot Product + + @Test("Dot product of known vectors") + func dotProductKnown() { + let a: [Float] = [1.0, 2.0, 3.0] + let b: [Float] = [4.0, 5.0, 6.0] + // 1*4 + 2*5 + 3*6 = 32 + let result = conduit_dot_product(a, b, 3) + #expect(abs(result - 32.0) < 0.0001) + } + + @Test("Dot product of orthogonal vectors is 0") + func dotProductOrthogonal() { + let a: [Float] = [1.0, 0.0] + let b: [Float] = [0.0, 1.0] + let result = conduit_dot_product(a, b, 2) + #expect(abs(result) < 0.0001) + } + + @Test("Dot product with zero count returns 0") + func dotProductEmpty() { + let a: [Float] = [1.0] + let b: [Float] = [1.0] + let result = conduit_dot_product(a, b, 0) + #expect(result == 0.0) + } + + // MARK: - Cosine Similarity + + @Test("Cosine similarity of identical vectors is 1") + func cosineSimilaritySame() { + let a: [Float] = [1.0, 0.0, 0.0] + let b: [Float] = [1.0, 0.0, 0.0] + let result = conduit_cosine_similarity(a, b, 3) + #expect(abs(result - 1.0) < 0.0001) + } + + @Test("Cosine similarity of orthogonal vectors is 0") + func cosineSimilarityOrthogonal() { + let a: [Float] = [1.0, 0.0] + let b: [Float] = [0.0, 1.0] + let result = conduit_cosine_similarity(a, b, 2) + #expect(abs(result) < 0.0001) + } + + @Test("Cosine similarity of opposite vectors is -1") + func cosineSimilarityOpposite() { + let a: [Float] = [1.0, 0.0] + let b: [Float] = [-1.0, 0.0] + let result = conduit_cosine_similarity(a, b, 2) + #expect(abs(result - (-1.0)) < 0.0001) + } + + @Test("Cosine similarity returns 0 for zero vector") + func cosineSimilarityZeroVector() { + let a: [Float] = [0.0, 0.0] + let b: [Float] = [1.0, 2.0] + let result = conduit_cosine_similarity(a, b, 2) + #expect(result == 0.0) + } + + @Test("Cosine similarity with zero count returns 0") + func cosineSimilarityEmpty() { + let a: [Float] = [1.0] + let b: [Float] = [1.0] + let result = conduit_cosine_similarity(a, b, 0) + #expect(result == 0.0) + } + + // MARK: - Euclidean Distance + + @Test("Euclidean distance of identical vectors is 0") + func euclideanDistanceSame() { + let a: [Float] = [1.0, 2.0, 3.0] + let b: [Float] = [1.0, 2.0, 3.0] + let result = conduit_euclidean_distance(a, b, 3) + #expect(abs(result) < 0.0001) + } + + @Test("Euclidean distance of known vectors (3-4-5 triangle)") + func euclideanDistanceKnown() { + let a: [Float] = [0.0, 0.0] + let b: [Float] = [3.0, 4.0] + let result = conduit_euclidean_distance(a, b, 2) + #expect(abs(result - 5.0) < 0.0001) + } + + @Test("Euclidean distance with zero count returns 0") + func euclideanDistanceEmpty() { + let a: [Float] = [1.0] + let b: [Float] = [1.0] + let result = conduit_euclidean_distance(a, b, 0) + #expect(result == 0.0) + } + + // MARK: - Batch Cosine Similarity + + @Test("Batch cosine similarity matches individual calls") + func batchCosineSimilarity() { + let query: [Float] = [1.0, 0.0, 0.0] + let vectors: [Float] = [ + 1.0, 0.0, 0.0, // identical → 1.0 + 0.0, 1.0, 0.0, // orthogonal → 0.0 + -1.0, 0.0, 0.0, // opposite → -1.0 + ] + var results: [Float] = [0.0, 0.0, 0.0] + + conduit_cosine_similarity_batch(query, vectors, 3, 3, &results) + + #expect(abs(results[0] - 1.0) < 0.0001) + #expect(abs(results[1]) < 0.0001) + #expect(abs(results[2] - (-1.0)) < 0.0001) + } + + @Test("Batch cosine similarity with zero query returns all zeros") + func batchCosineSimilarityZeroQuery() { + let query: [Float] = [0.0, 0.0] + let vectors: [Float] = [1.0, 2.0, 3.0, 4.0] + var results: [Float] = [999.0, 999.0] + + conduit_cosine_similarity_batch(query, vectors, 2, 2, &results) + + #expect(results[0] == 0.0) + #expect(results[1] == 0.0) + } + + // MARK: - Parity with Swift EmbeddingResult + + @Test("C dot product matches Swift EmbeddingResult.dotProduct") + func parityDotProduct() { + let vecA: [Float] = [0.5, 1.5, -2.0, 0.3] + let vecB: [Float] = [1.0, -0.5, 0.8, 2.1] + + // Swift reference + let swiftResult = zip(vecA, vecB).reduce(Float(0)) { $0 + $1.0 * $1.1 } + let cResult = conduit_dot_product(vecA, vecB, 4) + + #expect(abs(cResult - swiftResult) < 0.0001) + } + + @Test("C cosine similarity matches Swift implementation for non-trivial vectors") + func parityCosineSimilarity() { + let vecA: [Float] = [0.5, 1.5, -2.0, 0.3] + let vecB: [Float] = [1.0, -0.5, 0.8, 2.1] + + // Swift reference (from EmbeddingResult.cosineSimilarity) + var dot: Float = 0, normA: Float = 0, normB: Float = 0 + for i in vecA.indices { + dot += vecA[i] * vecB[i] + normA += vecA[i] * vecA[i] + normB += vecB[i] * vecB[i] + } + let swiftResult = dot / (sqrt(normA) * sqrt(normB)) + + let cResult = conduit_cosine_similarity(vecA, vecB, 4) + #expect(abs(cResult - swiftResult) < 0.0001) + } +} diff --git a/Tests/ConduitTests/Core/AnyCodableTests.swift b/Tests/ConduitTests/Core/AnyCodableTests.swift new file mode 100644 index 0000000..4b090cb --- /dev/null +++ b/Tests/ConduitTests/Core/AnyCodableTests.swift @@ -0,0 +1,204 @@ +// AnyCodableTests.swift +// ConduitTests + +import Foundation +import Testing +@testable import Conduit + +@Suite("AnyCodable Tests") +struct AnyCodableTests { + + // MARK: - Value Cases + + @Test("Null value") + func nullValue() { + let value = AnyCodable(NSNull()) + #expect(value.value == .null) + #expect(value.anyValue is NSNull) + } + + @Test("Bool value") + func boolValue() { + let value = AnyCodable(true) + #expect(value.value == .bool(true)) + #expect(value.anyValue as? Bool == true) + } + + @Test("Int value") + func intValue() { + let value = AnyCodable(42) + #expect(value.value == .int(42)) + #expect(value.anyValue as? Int == 42) + } + + @Test("Double value") + func doubleValue() { + let value = AnyCodable(3.14) + #expect(value.value == .double(3.14)) + #expect(value.anyValue as? Double == 3.14) + } + + @Test("String value") + func stringValue() { + let value = AnyCodable("hello") + #expect(value.value == .string("hello")) + #expect(value.anyValue as? String == "hello") + } + + @Test("Array value") + func arrayValue() { + let value = AnyCodable([1, 2, 3] as [Any]) + if case .array(let arr) = value.value { + #expect(arr.count == 3) + } else { + Issue.record("Expected array value") + } + } + + @Test("Dictionary value") + func dictionaryValue() { + let value = AnyCodable(["key": "value"] as [String: Any]) + if case .object(let dict) = value.value { + #expect(dict["key"]?.value == .string("value")) + } else { + Issue.record("Expected object value") + } + } + + @Test("Unsupported type falls back to null") + func unsupportedType() { + let value = AnyCodable(Date()) + #expect(value.value == .null) + } + + // MARK: - Value Enum Init + + @Test("Init from Value enum") + func initFromValueEnum() { + let value = AnyCodable(value: .string("test")) + #expect(value.value == .string("test")) + } + + // MARK: - Codable Round-Trip + + @Test("Null round-trip") + func nullRoundTrip() throws { + let original = AnyCodable(value: .null) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(AnyCodable.self, from: data) + #expect(decoded.value == .null) + } + + @Test("Bool round-trip") + func boolRoundTrip() throws { + let original = AnyCodable(value: .bool(true)) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(AnyCodable.self, from: data) + #expect(decoded.value == .bool(true)) + } + + @Test("Int round-trip") + func intRoundTrip() throws { + let original = AnyCodable(value: .int(42)) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(AnyCodable.self, from: data) + #expect(decoded.value == .int(42)) + } + + @Test("Double round-trip") + func doubleRoundTrip() throws { + let original = AnyCodable(value: .double(3.14)) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(AnyCodable.self, from: data) + #expect(decoded.value == .double(3.14)) + } + + @Test("String round-trip") + func stringRoundTrip() throws { + let original = AnyCodable(value: .string("test")) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(AnyCodable.self, from: data) + #expect(decoded.value == .string("test")) + } + + @Test("Array round-trip") + func arrayRoundTrip() throws { + let original = AnyCodable(value: .array([ + AnyCodable(value: .int(1)), + AnyCodable(value: .string("two")) + ])) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(AnyCodable.self, from: data) + + if case .array(let arr) = decoded.value { + #expect(arr.count == 2) + #expect(arr[0].value == .int(1)) + #expect(arr[1].value == .string("two")) + } else { + Issue.record("Expected array value") + } + } + + @Test("Object round-trip") + func objectRoundTrip() throws { + let original = AnyCodable(value: .object([ + "name": AnyCodable(value: .string("test")), + "count": AnyCodable(value: .int(5)) + ])) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(AnyCodable.self, from: data) + + if case .object(let dict) = decoded.value { + #expect(dict["name"]?.value == .string("test")) + #expect(dict["count"]?.value == .int(5)) + } else { + Issue.record("Expected object value") + } + } + + // MARK: - Nested Structures + + @Test("Nested arrays round-trip") + func nestedArrays() throws { + let inner = AnyCodable(value: .array([ + AnyCodable(value: .int(1)), + AnyCodable(value: .int(2)) + ])) + let original = AnyCodable(value: .array([inner])) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(AnyCodable.self, from: data) + + if case .array(let outer) = decoded.value, + case .array(let innerArr) = outer[0].value { + #expect(innerArr.count == 2) + } else { + Issue.record("Expected nested array") + } + } + + // MARK: - Hashable + + @Test("Equal values have same hash") + func hashEquality() { + let a = AnyCodable(value: .string("test")) + let b = AnyCodable(value: .string("test")) + #expect(a == b) + #expect(a.hashValue == b.hashValue) + } + + @Test("Different values are unequal") + func hashInequality() { + let a = AnyCodable(value: .int(1)) + let b = AnyCodable(value: .int(2)) + #expect(a != b) + } + + @Test("Can be used in a Set") + func setUsage() { + var set: Set = [] + set.insert(AnyCodable(value: .string("a"))) + set.insert(AnyCodable(value: .string("b"))) + set.insert(AnyCodable(value: .string("a"))) + #expect(set.count == 2) + } +} diff --git a/Tests/ConduitTests/Core/EmbeddingResultTests.swift b/Tests/ConduitTests/Core/EmbeddingResultTests.swift new file mode 100644 index 0000000..194f7f5 --- /dev/null +++ b/Tests/ConduitTests/Core/EmbeddingResultTests.swift @@ -0,0 +1,183 @@ +// EmbeddingResultTests.swift +// ConduitTests + +import Foundation +import Testing +@testable import Conduit + +@Suite("EmbeddingResult Tests") +struct EmbeddingResultTests { + + // MARK: - Test Data + + static let sampleEmbedding = EmbeddingResult( + vector: [0.1, 0.2, 0.3], + text: "Hello", + model: "test-model", + tokenCount: 1 + ) + + // MARK: - Initialization + + @Test("Init stores all properties") + func initProperties() { + let result = EmbeddingResult( + vector: [1.0, 2.0, 3.0], + text: "test text", + model: "bge-small", + tokenCount: 5 + ) + #expect(result.vector == [1.0, 2.0, 3.0]) + #expect(result.text == "test text") + #expect(result.model == "bge-small") + #expect(result.tokenCount == 5) + } + + @Test("Init defaults tokenCount to nil") + func initDefaultTokenCount() { + let result = EmbeddingResult( + vector: [1.0], + text: "hi", + model: "model" + ) + #expect(result.tokenCount == nil) + } + + // MARK: - Dimensions + + @Test("dimensions returns vector count") + func dimensions() { + let result = EmbeddingResult( + vector: [0.1, 0.2, 0.3, 0.4, 0.5], + text: "test", + model: "model" + ) + #expect(result.dimensions == 5) + } + + @Test("dimensions is zero for empty vector") + func dimensionsEmpty() { + let result = EmbeddingResult( + vector: [], + text: "empty", + model: "model" + ) + #expect(result.dimensions == 0) + } + + // MARK: - Cosine Similarity + + @Test("cosineSimilarity of identical vectors is 1") + func cosineSimilaritySame() { + let a = EmbeddingResult(vector: [1.0, 0.0, 0.0], text: "a", model: "m") + let b = EmbeddingResult(vector: [1.0, 0.0, 0.0], text: "b", model: "m") + let similarity = a.cosineSimilarity(with: b) + #expect(abs(similarity - 1.0) < 0.0001) + } + + @Test("cosineSimilarity of orthogonal vectors is 0") + func cosineSimilarityOrthogonal() { + let a = EmbeddingResult(vector: [1.0, 0.0], text: "a", model: "m") + let b = EmbeddingResult(vector: [0.0, 1.0], text: "b", model: "m") + let similarity = a.cosineSimilarity(with: b) + #expect(abs(similarity) < 0.0001) + } + + @Test("cosineSimilarity of opposite vectors is -1") + func cosineSimilarityOpposite() { + let a = EmbeddingResult(vector: [1.0, 0.0], text: "a", model: "m") + let b = EmbeddingResult(vector: [-1.0, 0.0], text: "b", model: "m") + let similarity = a.cosineSimilarity(with: b) + #expect(abs(similarity - (-1.0)) < 0.0001) + } + + @Test("cosineSimilarity returns 0 for different dimensions") + func cosineSimilarityDifferentDimensions() { + let a = EmbeddingResult(vector: [1.0, 2.0], text: "a", model: "m") + let b = EmbeddingResult(vector: [1.0, 2.0, 3.0], text: "b", model: "m") + #expect(a.cosineSimilarity(with: b) == 0) + } + + @Test("cosineSimilarity returns 0 for zero vector") + func cosineSimilarityZeroVector() { + let a = EmbeddingResult(vector: [0.0, 0.0], text: "a", model: "m") + let b = EmbeddingResult(vector: [1.0, 2.0], text: "b", model: "m") + #expect(a.cosineSimilarity(with: b) == 0) + } + + // MARK: - Euclidean Distance + + @Test("euclideanDistance of identical vectors is 0") + func euclideanDistanceSame() { + let a = EmbeddingResult(vector: [1.0, 2.0, 3.0], text: "a", model: "m") + let b = EmbeddingResult(vector: [1.0, 2.0, 3.0], text: "b", model: "m") + #expect(abs(a.euclideanDistance(to: b)) < 0.0001) + } + + @Test("euclideanDistance of known vectors is correct") + func euclideanDistanceKnown() { + let a = EmbeddingResult(vector: [0.0, 0.0], text: "a", model: "m") + let b = EmbeddingResult(vector: [3.0, 4.0], text: "b", model: "m") + let distance = a.euclideanDistance(to: b) + #expect(abs(distance - 5.0) < 0.0001) + } + + @Test("euclideanDistance returns infinity for different dimensions") + func euclideanDistanceDifferentDimensions() { + let a = EmbeddingResult(vector: [1.0, 2.0], text: "a", model: "m") + let b = EmbeddingResult(vector: [1.0, 2.0, 3.0], text: "b", model: "m") + #expect(a.euclideanDistance(to: b) == .infinity) + } + + // MARK: - Dot Product + + @Test("dotProduct of known vectors is correct") + func dotProductKnown() { + let a = EmbeddingResult(vector: [1.0, 2.0, 3.0], text: "a", model: "m") + let b = EmbeddingResult(vector: [4.0, 5.0, 6.0], text: "b", model: "m") + // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32 + let result = a.dotProduct(with: b) + #expect(abs(result - 32.0) < 0.0001) + } + + @Test("dotProduct of orthogonal vectors is 0") + func dotProductOrthogonal() { + let a = EmbeddingResult(vector: [1.0, 0.0], text: "a", model: "m") + let b = EmbeddingResult(vector: [0.0, 1.0], text: "b", model: "m") + #expect(abs(a.dotProduct(with: b)) < 0.0001) + } + + @Test("dotProduct returns 0 for different dimensions") + func dotProductDifferentDimensions() { + let a = EmbeddingResult(vector: [1.0, 2.0], text: "a", model: "m") + let b = EmbeddingResult(vector: [1.0, 2.0, 3.0], text: "b", model: "m") + #expect(a.dotProduct(with: b) == 0) + } + + // MARK: - Hashable + + @Test("Equal embeddings have same hash") + func hashableEqual() { + let a = EmbeddingResult(vector: [1.0, 2.0], text: "hello", model: "m") + let b = EmbeddingResult(vector: [1.0, 2.0], text: "hello", model: "m") + #expect(a.hashValue == b.hashValue) + } + + @Test("Embeddings work in Set") + func hashableSet() { + let a = EmbeddingResult(vector: [1.0, 2.0], text: "hello", model: "m") + let b = EmbeddingResult(vector: [1.0, 2.0], text: "hello", model: "m") + let c = EmbeddingResult(vector: [3.0, 4.0], text: "world", model: "m") + let set: Set = [a, b, c] + #expect(set.count == 2) + } + + // MARK: - Sendable + + @Test("EmbeddingResult is Sendable") + func sendable() async { + let result = Self.sampleEmbedding + let text = await Task { result.text }.value + #expect(text == "Hello") + } +} diff --git a/Tests/ConduitTests/Core/FinishReasonTests.swift b/Tests/ConduitTests/Core/FinishReasonTests.swift new file mode 100644 index 0000000..3c5ba34 --- /dev/null +++ b/Tests/ConduitTests/Core/FinishReasonTests.swift @@ -0,0 +1,100 @@ +// FinishReasonTests.swift +// ConduitTests + +import Foundation +import Testing +@testable import Conduit + +@Suite("FinishReason Tests") +struct FinishReasonTests { + + // MARK: - Raw Values + + @Test("All raw values match expected wire format") + func rawValues() { + #expect(FinishReason.stop.rawValue == "stop") + #expect(FinishReason.maxTokens.rawValue == "max_tokens") + #expect(FinishReason.stopSequence.rawValue == "stop_sequence") + #expect(FinishReason.cancelled.rawValue == "cancelled") + #expect(FinishReason.contentFilter.rawValue == "content_filter") + #expect(FinishReason.toolCall.rawValue == "tool_call") + #expect(FinishReason.toolCalls.rawValue == "tool_calls") + #expect(FinishReason.pauseTurn.rawValue == "pause_turn") + #expect(FinishReason.modelContextWindowExceeded.rawValue == "model_context_window_exceeded") + } + + // MARK: - isToolCallRequest + + @Test("isToolCallRequest returns true for toolCall") + func isToolCallRequestSingular() { + #expect(FinishReason.toolCall.isToolCallRequest) + } + + @Test("isToolCallRequest returns true for toolCalls") + func isToolCallRequestPlural() { + #expect(FinishReason.toolCalls.isToolCallRequest) + } + + @Test("isToolCallRequest returns false for non-tool-call reasons", + arguments: [ + FinishReason.stop, + .maxTokens, + .stopSequence, + .cancelled, + .contentFilter, + .pauseTurn, + .modelContextWindowExceeded + ]) + func isToolCallRequestFalse(reason: FinishReason) { + #expect(!reason.isToolCallRequest) + } + + // MARK: - Codable + + @Test("Codable round-trip for all cases", + arguments: [ + FinishReason.stop, + .maxTokens, + .stopSequence, + .cancelled, + .contentFilter, + .toolCall, + .toolCalls, + .pauseTurn, + .modelContextWindowExceeded + ]) + func codableRoundTrip(reason: FinishReason) throws { + let data = try JSONEncoder().encode(reason) + let decoded = try JSONDecoder().decode(FinishReason.self, from: data) + #expect(reason == decoded) + } + + @Test("Decodes from raw value string") + func decodesFromString() throws { + let json = Data("\"max_tokens\"".utf8) + let decoded = try JSONDecoder().decode(FinishReason.self, from: json) + #expect(decoded == .maxTokens) + } + + @Test("Decodes tool_call from wire format") + func decodesToolCall() throws { + let json = Data("\"tool_call\"".utf8) + let decoded = try JSONDecoder().decode(FinishReason.self, from: json) + #expect(decoded == .toolCall) + } + + @Test("Decodes tool_calls from wire format") + func decodesToolCalls() throws { + let json = Data("\"tool_calls\"".utf8) + let decoded = try JSONDecoder().decode(FinishReason.self, from: json) + #expect(decoded == .toolCalls) + } + + // MARK: - Hashable + + @Test("Can be used in a Set") + func hashable() { + let reasons: Set = [.stop, .maxTokens, .stop] + #expect(reasons.count == 2) + } +} diff --git a/Tests/ConduitTests/Core/GenerationChunkTests.swift b/Tests/ConduitTests/Core/GenerationChunkTests.swift new file mode 100644 index 0000000..2be3b1a --- /dev/null +++ b/Tests/ConduitTests/Core/GenerationChunkTests.swift @@ -0,0 +1,166 @@ +// GenerationChunkTests.swift +// ConduitTests + +import Foundation +import Testing +@testable import Conduit + +@Suite("GenerationChunk Tests") +struct GenerationChunkTests { + + // MARK: - Initialization + + @Test("Default init stores text with default values") + func defaultInit() { + let chunk = GenerationChunk(text: "Hello") + + #expect(chunk.text == "Hello") + #expect(chunk.tokenCount == 1) + #expect(chunk.tokenId == nil) + #expect(chunk.logprob == nil) + #expect(chunk.topLogprobs == nil) + #expect(chunk.tokensPerSecond == nil) + #expect(!chunk.isComplete) + #expect(chunk.finishReason == nil) + #expect(chunk.usage == nil) + #expect(chunk.partialToolCall == nil) + #expect(chunk.completedToolCalls == nil) + #expect(chunk.reasoningDetails == nil) + } + + @Test("Full init stores all properties") + func fullInit() { + let timestamp = Date() + let usage = UsageStats(promptTokens: 10, completionTokens: 20) + let logprob = TokenLogprob(token: "hello", logprob: -0.5) + + let chunk = GenerationChunk( + text: "test", + tokenCount: 2, + tokenId: 42, + logprob: -0.3, + topLogprobs: [logprob], + tokensPerSecond: 50.0, + isComplete: true, + finishReason: .stop, + timestamp: timestamp, + usage: usage + ) + + #expect(chunk.text == "test") + #expect(chunk.tokenCount == 2) + #expect(chunk.tokenId == 42) + #expect(chunk.logprob == -0.3) + #expect(chunk.topLogprobs?.count == 1) + #expect(chunk.tokensPerSecond == 50.0) + #expect(chunk.isComplete) + #expect(chunk.finishReason == .stop) + #expect(chunk.timestamp == timestamp) + #expect(chunk.usage == usage) + } + + // MARK: - Factory Methods + + @Test(".completion() creates final chunk with empty text") + func completionFactory() { + let chunk = GenerationChunk.completion(finishReason: .stop) + + #expect(chunk.text == "") + #expect(chunk.tokenCount == 0) + #expect(chunk.isComplete) + #expect(chunk.finishReason == .stop) + } + + @Test(".completion() with maxTokens reason") + func completionMaxTokens() { + let chunk = GenerationChunk.completion(finishReason: .maxTokens) + + #expect(chunk.isComplete) + #expect(chunk.finishReason == .maxTokens) + } + + // MARK: - Computed Properties + + @Test("hasToolCallUpdates is false with no tool data") + func hasToolCallUpdatesFalse() { + let chunk = GenerationChunk(text: "test") + #expect(!chunk.hasToolCallUpdates) + } + + @Test("hasToolCallUpdates is true with partial tool call") + func hasToolCallUpdatesPartial() { + let partial = PartialToolCall( + id: "call-1", + toolName: "weather", + index: 0, + argumentsFragment: "{\"loc\":" + ) + let chunk = GenerationChunk(text: "", partialToolCall: partial) + #expect(chunk.hasToolCallUpdates) + } + + @Test("hasToolCallUpdates is true with completed tool calls") + func hasToolCallUpdatesCompleted() { + let toolCall = Transcript.ToolCall( + id: "call-1", + toolName: "weather", + arguments: GeneratedContent(kind: .null) + ) + let chunk = GenerationChunk(text: "", completedToolCalls: [toolCall]) + #expect(chunk.hasToolCallUpdates) + } + + @Test("hasToolCallUpdates is false with empty completed tool calls array") + func hasToolCallUpdatesEmptyCompleted() { + let chunk = GenerationChunk(text: "", completedToolCalls: []) + #expect(!chunk.hasToolCallUpdates) + } + + @Test("hasReasoningDetails is false with nil") + func hasReasoningDetailsFalse() { + let chunk = GenerationChunk(text: "test") + #expect(!chunk.hasReasoningDetails) + } + + @Test("hasReasoningDetails is false with empty array") + func hasReasoningDetailsEmpty() { + let chunk = GenerationChunk(text: "test", reasoningDetails: []) + #expect(!chunk.hasReasoningDetails) + } + + // MARK: - Equatable + + @Test("Equal chunks are equal") + func equality() { + let timestamp = Date() + let a = GenerationChunk(text: "hello", timestamp: timestamp) + let b = GenerationChunk(text: "hello", timestamp: timestamp) + #expect(a == b) + } + + @Test("Different text makes chunks unequal") + func inequalityText() { + let timestamp = Date() + let a = GenerationChunk(text: "hello", timestamp: timestamp) + let b = GenerationChunk(text: "world", timestamp: timestamp) + #expect(a != b) + } + + @Test("Different timestamps make chunks unequal") + func inequalityTimestamp() { + let a = GenerationChunk(text: "hello", timestamp: Date(timeIntervalSince1970: 0)) + let b = GenerationChunk(text: "hello", timestamp: Date(timeIntervalSince1970: 1)) + #expect(a != b) + } +} + +// MARK: - maxToolCallIndex Tests + +@Suite("maxToolCallIndex Tests") +struct MaxToolCallIndexTests { + + @Test("maxToolCallIndex is 100") + func value() { + #expect(maxToolCallIndex == 100) + } +} diff --git a/Tests/ConduitTests/Core/GenerationResultTests.swift b/Tests/ConduitTests/Core/GenerationResultTests.swift new file mode 100644 index 0000000..69e97c6 --- /dev/null +++ b/Tests/ConduitTests/Core/GenerationResultTests.swift @@ -0,0 +1,146 @@ +// GenerationResultTests.swift +// ConduitTests + +import Foundation +import Testing +@testable import Conduit + +@Suite("GenerationResult Tests") +struct GenerationResultTests { + + // MARK: - Initialization + + @Test("Full initialization stores all properties") + func fullInit() { + let usage = UsageStats(promptTokens: 10, completionTokens: 20) + let rateLimit = RateLimitInfo(requestId: "req-1") + let logprob = TokenLogprob(token: "hello", logprob: -0.5) + + let result = GenerationResult( + text: "Hello world", + tokenCount: 5, + generationTime: 1.2, + tokensPerSecond: 4.17, + finishReason: .stop, + logprobs: [logprob], + usage: usage, + rateLimitInfo: rateLimit, + toolCalls: [], + reasoningDetails: [] + ) + + #expect(result.text == "Hello world") + #expect(result.tokenCount == 5) + #expect(result.generationTime == 1.2) + #expect(result.tokensPerSecond == 4.17) + #expect(result.finishReason == .stop) + #expect(result.logprobs?.count == 1) + #expect(result.usage == usage) + #expect(result.rateLimitInfo == rateLimit) + #expect(result.toolCalls.isEmpty) + #expect(result.reasoningDetails.isEmpty) + } + + @Test("Default parameters produce empty collections") + func defaultParams() { + let result = GenerationResult( + text: "test", + tokenCount: 1, + generationTime: 0.1, + tokensPerSecond: 10, + finishReason: .stop + ) + + #expect(result.logprobs == nil) + #expect(result.usage == nil) + #expect(result.rateLimitInfo == nil) + #expect(result.toolCalls.isEmpty) + #expect(result.reasoningDetails.isEmpty) + } + + // MARK: - Factory Methods + + @Test(".text() factory creates result with default metadata") + func textFactory() { + let result = GenerationResult.text("Hello") + + #expect(result.text == "Hello") + #expect(result.tokenCount == 0) + #expect(result.generationTime == 0) + #expect(result.tokensPerSecond == 0) + #expect(result.finishReason == .stop) + #expect(result.toolCalls.isEmpty) + #expect(result.reasoningDetails.isEmpty) + } + + // MARK: - Computed Properties + + @Test("hasToolCalls returns false when no tool calls") + func hasToolCallsFalse() { + let result = GenerationResult.text("Hello") + #expect(!result.hasToolCalls) + } + + @Test("hasReasoningDetails returns false when empty") + func hasReasoningDetailsFalse() { + let result = GenerationResult.text("Hello") + #expect(!result.hasReasoningDetails) + } + + // MARK: - Equatable + + @Test("Equal results compare equal") + func equality() { + let a = GenerationResult.text("Hello") + let b = GenerationResult.text("Hello") + #expect(a == b) + } + + @Test("Different text makes results unequal") + func inequality() { + let a = GenerationResult.text("Hello") + let b = GenerationResult.text("World") + #expect(a != b) + } + + @Test("Different finish reasons make results unequal") + func inequalityFinishReason() { + let a = GenerationResult( + text: "Hello", tokenCount: 0, generationTime: 0, + tokensPerSecond: 0, finishReason: .stop + ) + let b = GenerationResult( + text: "Hello", tokenCount: 0, generationTime: 0, + tokensPerSecond: 0, finishReason: .maxTokens + ) + #expect(a != b) + } + + // MARK: - Message Bridge + + @Test("assistantMessage creates message with correct role and content") + func assistantMessage() { + let result = GenerationResult( + text: "Response text", + tokenCount: 10, + generationTime: 0.5, + tokensPerSecond: 20, + finishReason: .stop + ) + + let message = result.assistantMessage() + + #expect(message.role == .assistant) + #expect(message.content.textValue == "Response text") + #expect(message.metadata?.tokenCount == 10) + #expect(message.metadata?.generationTime == 0.5) + #expect(message.metadata?.tokensPerSecond == 20) + } + + @Test("assistantMessage with no tool calls has nil toolCalls in metadata") + func assistantMessageNoToolCalls() { + let result = GenerationResult.text("Hello") + let message = result.assistantMessage() + #expect(message.metadata?.toolCalls == nil) + } +} diff --git a/Tests/ConduitTests/Core/GenerationSchemaTests.swift b/Tests/ConduitTests/Core/GenerationSchemaTests.swift new file mode 100644 index 0000000..6f6483e --- /dev/null +++ b/Tests/ConduitTests/Core/GenerationSchemaTests.swift @@ -0,0 +1,324 @@ +// GenerationSchemaTests.swift +// ConduitTests + +import Foundation +import Testing +@testable import Conduit + +@Suite("GenerationSchema Node Tests") +struct GenerationSchemaNodeTests { + + // MARK: - Boolean Node + + @Test("Boolean node encodes as type boolean") + func booleanEncode() throws { + let node = GenerationSchema.Node.boolean + let data = try JSONEncoder().encode(node) + let json = String(data: data, encoding: .utf8)! + #expect(json.contains("\"type\":\"boolean\"") || json.contains("\"type\" : \"boolean\"")) + } + + @Test("Boolean node round-trips through Codable") + func booleanRoundTrip() throws { + let original = GenerationSchema.Node.boolean + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(GenerationSchema.Node.self, from: data) + + if case .boolean = decoded { + // success + } else { + Issue.record("Expected boolean node") + } + } + + // MARK: - String Node + + @Test("String node encodes with type string") + func stringEncode() throws { + let node = GenerationSchema.Node.string( + GenerationSchema.StringNode(description: "A name", pattern: nil, enumChoices: nil) + ) + let data = try JSONEncoder().encode(node) + let json = String(data: data, encoding: .utf8)! + #expect(json.contains("string")) + #expect(json.contains("A name")) + } + + @Test("String node with enum choices round-trips") + func stringEnumRoundTrip() throws { + let original = GenerationSchema.Node.string( + GenerationSchema.StringNode( + description: "Color", + pattern: nil, + enumChoices: ["red", "green", "blue"] + ) + ) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(GenerationSchema.Node.self, from: data) + + if case .string(let str) = decoded { + #expect(str.enumChoices == ["red", "green", "blue"]) + #expect(str.description == "Color") + } else { + Issue.record("Expected string node") + } + } + + @Test("String node with pattern round-trips") + func stringPatternRoundTrip() throws { + let original = GenerationSchema.Node.string( + GenerationSchema.StringNode( + description: nil, + pattern: "^[a-z]+$", + enumChoices: nil + ) + ) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(GenerationSchema.Node.self, from: data) + + if case .string(let str) = decoded { + #expect(str.pattern == "^[a-z]+$") + } else { + Issue.record("Expected string node") + } + } + + // MARK: - Number Node + + @Test("Number node encodes as type number") + func numberEncode() throws { + let node = GenerationSchema.Node.number( + GenerationSchema.NumberNode( + description: "Temperature", + minimum: 0, + maximum: 100, + integerOnly: false + ) + ) + let data = try JSONEncoder().encode(node) + let json = String(data: data, encoding: .utf8)! + #expect(json.contains("number")) + } + + @Test("Integer-only number node encodes as type integer") + func integerEncode() throws { + let node = GenerationSchema.Node.number( + GenerationSchema.NumberNode( + description: "Count", + minimum: nil, + maximum: nil, + integerOnly: true + ) + ) + let data = try JSONEncoder().encode(node) + let json = String(data: data, encoding: .utf8)! + #expect(json.contains("integer")) + } + + @Test("Number node with range round-trips") + func numberRangeRoundTrip() throws { + let original = GenerationSchema.Node.number( + GenerationSchema.NumberNode( + description: "Score", + minimum: 1.0, + maximum: 10.0, + integerOnly: false + ) + ) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(GenerationSchema.Node.self, from: data) + + if case .number(let num) = decoded { + #expect(num.minimum == 1.0) + #expect(num.maximum == 10.0) + #expect(!num.integerOnly) + #expect(num.description == "Score") + } else { + Issue.record("Expected number node") + } + } + + // MARK: - Array Node + + @Test("Array node encodes with items") + func arrayEncode() throws { + let node = GenerationSchema.Node.array( + GenerationSchema.ArrayNode( + description: "Tags", + items: .string(GenerationSchema.StringNode(description: nil, pattern: nil, enumChoices: nil)), + minItems: 1, + maxItems: 10 + ) + ) + let data = try JSONEncoder().encode(node) + let json = String(data: data, encoding: .utf8)! + #expect(json.contains("array")) + } + + @Test("Array node round-trips with constraints") + func arrayRoundTrip() throws { + let original = GenerationSchema.Node.array( + GenerationSchema.ArrayNode( + description: nil, + items: .boolean, + minItems: 2, + maxItems: 5 + ) + ) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(GenerationSchema.Node.self, from: data) + + if case .array(let arr) = decoded { + #expect(arr.minItems == 2) + #expect(arr.maxItems == 5) + if case .boolean = arr.items { + // correct + } else { + Issue.record("Expected boolean items") + } + } else { + Issue.record("Expected array node") + } + } + + // MARK: - Object Node + + @Test("Object node encodes with properties") + func objectEncode() throws { + let node = GenerationSchema.Node.object( + GenerationSchema.ObjectNode( + description: "A person", + properties: [ + "name": .string(GenerationSchema.StringNode(description: nil, pattern: nil, enumChoices: nil)), + "age": .number(GenerationSchema.NumberNode(description: nil, minimum: 0, maximum: nil, integerOnly: true)) + ], + required: ["name", "age"] + ) + ) + let data = try JSONEncoder().encode(node) + let json = String(data: data, encoding: .utf8)! + #expect(json.contains("object")) + #expect(json.contains("name")) + #expect(json.contains("age")) + } + + @Test("Object node round-trips") + func objectRoundTrip() throws { + let original = GenerationSchema.Node.object( + GenerationSchema.ObjectNode( + description: "Test", + properties: [ + "flag": .boolean + ], + required: ["flag"] + ) + ) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(GenerationSchema.Node.self, from: data) + + if case .object(let obj) = decoded { + #expect(obj.description == "Test") + #expect(obj.properties.count == 1) + #expect(obj.required.contains("flag")) + } else { + Issue.record("Expected object node") + } + } + + // MARK: - Ref Node + + @Test("Ref node encodes with $ref prefix") + func refEncode() throws { + let node = GenerationSchema.Node.ref("MyType") + let data = try JSONEncoder().encode(node) + let json = String(data: data, encoding: .utf8)! + #expect(json.contains("#/$defs/MyType")) + } + + @Test("Ref node round-trips") + func refRoundTrip() throws { + let original = GenerationSchema.Node.ref("MyType") + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(GenerationSchema.Node.self, from: data) + + if case .ref(let name) = decoded { + #expect(name == "MyType") + } else { + Issue.record("Expected ref node") + } + } + + // MARK: - AnyOf Node + + @Test("AnyOf node encodes with choices") + func anyOfEncode() throws { + let node = GenerationSchema.Node.anyOf([.boolean, .ref("Option")]) + let data = try JSONEncoder().encode(node) + let json = String(data: data, encoding: .utf8)! + #expect(json.contains("anyOf")) + } + + @Test("AnyOf node round-trips") + func anyOfRoundTrip() throws { + let original = GenerationSchema.Node.anyOf([ + .boolean, + .string(GenerationSchema.StringNode(description: nil, pattern: nil, enumChoices: nil)) + ]) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(GenerationSchema.Node.self, from: data) + + if case .anyOf(let nodes) = decoded { + #expect(nodes.count == 2) + } else { + Issue.record("Expected anyOf node") + } + } +} + +// MARK: - SchemaError Tests + +@Suite("GenerationSchema.SchemaError Tests") +struct GenerationSchemaErrorTests { + + @Test("duplicateType has descriptive message") + func duplicateTypeError() { + let error = GenerationSchema.SchemaError.duplicateType( + schema: "root", + type: "MyType", + context: .init(debugDescription: "test") + ) + #expect(error.errorDescription?.contains("MyType") == true) + } + + @Test("emptyTypeChoices has descriptive message") + func emptyTypeChoicesError() { + let error = GenerationSchema.SchemaError.emptyTypeChoices( + schema: "TestEnum", + context: .init(debugDescription: "test") + ) + #expect(error.errorDescription?.contains("TestEnum") == true) + } + + @Test("undefinedReferences lists missing refs") + func undefinedReferencesError() { + let error = GenerationSchema.SchemaError.undefinedReferences( + schema: "root", + references: ["Foo", "Bar"], + context: .init(debugDescription: "test") + ) + #expect(error.errorDescription?.contains("Foo") == true) + } + + @Test("All errors have recovery suggestions") + func recoverySuggestions() { + let errors: [GenerationSchema.SchemaError] = [ + .duplicateType(schema: nil, type: "T", context: .init(debugDescription: "")), + .duplicateProperty(schema: "S", property: "p", context: .init(debugDescription: "")), + .emptyTypeChoices(schema: "S", context: .init(debugDescription: "")), + .undefinedReferences(schema: nil, references: [], context: .init(debugDescription: "")) + ] + for error in errors { + #expect(error.recoverySuggestion != nil) + } + } +} diff --git a/Tests/ConduitTests/Core/ModelCapabilitiesTests.swift b/Tests/ConduitTests/Core/ModelCapabilitiesTests.swift new file mode 100644 index 0000000..8db644d --- /dev/null +++ b/Tests/ConduitTests/Core/ModelCapabilitiesTests.swift @@ -0,0 +1,163 @@ +// ModelCapabilitiesTests.swift +// ConduitTests + +import Foundation +import Testing +@testable import Conduit + +@Suite("ModelCapabilities Tests") +struct ModelCapabilitiesTests { + + // MARK: - Static Presets + + @Test("textOnly preset has correct capabilities") + func textOnlyPreset() { + let caps = ModelCapabilities.textOnly + + #expect(!caps.supportsVision) + #expect(caps.supportsTextGeneration) + #expect(!caps.supportsEmbeddings) + #expect(caps.architectureType == nil) + #expect(caps.contextWindowSize == nil) + } + + @Test("vlm preset has correct capabilities") + func vlmPreset() { + let caps = ModelCapabilities.vlm + + #expect(caps.supportsVision) + #expect(caps.supportsTextGeneration) + #expect(!caps.supportsEmbeddings) + #expect(caps.architectureType == .vlm) + } + + @Test("embedding preset has correct capabilities") + func embeddingPreset() { + let caps = ModelCapabilities.embedding + + #expect(!caps.supportsVision) + #expect(!caps.supportsTextGeneration) + #expect(caps.supportsEmbeddings) + #expect(caps.architectureType == nil) + } + + // MARK: - Custom Init + + @Test("Custom init stores all properties") + func customInit() { + let caps = ModelCapabilities( + supportsVision: true, + supportsTextGeneration: true, + supportsEmbeddings: false, + architectureType: .qwen2VL, + contextWindowSize: 32768 + ) + + #expect(caps.supportsVision) + #expect(caps.supportsTextGeneration) + #expect(!caps.supportsEmbeddings) + #expect(caps.architectureType == .qwen2VL) + #expect(caps.contextWindowSize == 32768) + } + + // MARK: - Hashable + + @Test("Equal capabilities have same hash") + func hashEquality() { + let a = ModelCapabilities.textOnly + let b = ModelCapabilities.textOnly + #expect(a == b) + #expect(a.hashValue == b.hashValue) + } + + @Test("Different capabilities are unequal") + func inequality() { + #expect(ModelCapabilities.textOnly != ModelCapabilities.vlm) + #expect(ModelCapabilities.textOnly != ModelCapabilities.embedding) + #expect(ModelCapabilities.vlm != ModelCapabilities.embedding) + } +} + +// MARK: - ArchitectureType Tests + +@Suite("ArchitectureType Tests") +struct ArchitectureTypeTests { + + // MARK: - supportsVision + + @Test("Vision architectures return true for supportsVision", + arguments: [ + ArchitectureType.vlm, + .llava, + .qwen2VL, + .pixtral, + .paligemma, + .idefics, + .mllama, + .phi3Vision, + .cogvlm, + .internvl, + .minicpmV, + .florence, + .blip + ]) + func visionArchitectures(arch: ArchitectureType) { + #expect(arch.supportsVision) + } + + @Test("Non-vision architectures return false for supportsVision", + arguments: [ + ArchitectureType.llama, + .mistral, + .qwen, + .phi, + .gemma, + .bert, + .bge, + .nomic + ]) + func nonVisionArchitectures(arch: ArchitectureType) { + #expect(!arch.supportsVision) + } + + // MARK: - CaseIterable + + @Test("All architecture cases are enumerated") + func allCases() { + // 25 total: 5 text + 13 vision + 3 embedding + 4 others + #expect(ArchitectureType.allCases.count > 20) + } + + // MARK: - Raw Values + + @Test("Custom raw values are correct") + func customRawValues() { + #expect(ArchitectureType.qwen2VL.rawValue == "qwen2_vl") + #expect(ArchitectureType.phi3Vision.rawValue == "phi3_v") + #expect(ArchitectureType.minicpmV.rawValue == "minicpm_v") + } + + @Test("Default raw values match case name") + func defaultRawValues() { + #expect(ArchitectureType.llama.rawValue == "llama") + #expect(ArchitectureType.mistral.rawValue == "mistral") + #expect(ArchitectureType.bert.rawValue == "bert") + } + + // MARK: - Codable + + @Test("Codable round-trip for all cases", + arguments: ArchitectureType.allCases) + func codableRoundTrip(arch: ArchitectureType) throws { + let data = try JSONEncoder().encode(arch) + let decoded = try JSONDecoder().decode(ArchitectureType.self, from: data) + #expect(arch == decoded) + } + + @Test("Decodes from raw value string") + func decodesFromRawValue() throws { + let json = Data("\"qwen2_vl\"".utf8) + let decoded = try JSONDecoder().decode(ArchitectureType.self, from: json) + #expect(decoded == .qwen2VL) + } +} diff --git a/Tests/ConduitTests/Core/RateLimitInfoTests.swift b/Tests/ConduitTests/Core/RateLimitInfoTests.swift new file mode 100644 index 0000000..d77b9ff --- /dev/null +++ b/Tests/ConduitTests/Core/RateLimitInfoTests.swift @@ -0,0 +1,174 @@ +// RateLimitInfoTests.swift +// ConduitTests + +import Foundation +import Testing +@testable import Conduit + +@Suite("RateLimitInfo Tests") +struct RateLimitInfoTests { + + // MARK: - Header Parsing + + @Test("Parses all Anthropic headers") + func parsesAllHeaders() { + let headers: [String: String] = [ + "request-id": "req-abc123", + "anthropic-organization-id": "org-xyz", + "anthropic-ratelimit-requests-limit": "100", + "anthropic-ratelimit-tokens-limit": "50000", + "anthropic-ratelimit-requests-remaining": "95", + "anthropic-ratelimit-tokens-remaining": "48000", + "retry-after": "30" + ] + + let info = RateLimitInfo(headers: headers) + + #expect(info.requestId == "req-abc123") + #expect(info.organizationId == "org-xyz") + #expect(info.limitRequests == 100) + #expect(info.limitTokens == 50000) + #expect(info.remainingRequests == 95) + #expect(info.remainingTokens == 48000) + #expect(info.retryAfter == 30) + } + + @Test("Case-insensitive header matching") + func caseInsensitive() { + let headers: [String: String] = [ + "Request-Id": "req-upper", + "Anthropic-Organization-Id": "org-upper", + "Anthropic-Ratelimit-Requests-Limit": "200" + ] + + let info = RateLimitInfo(headers: headers) + + #expect(info.requestId == "req-upper") + #expect(info.organizationId == "org-upper") + #expect(info.limitRequests == 200) + } + + @Test("Missing headers produce nil values") + func missingHeaders() { + let info = RateLimitInfo(headers: [:]) + + #expect(info.requestId == nil) + #expect(info.organizationId == nil) + #expect(info.limitRequests == nil) + #expect(info.limitTokens == nil) + #expect(info.remainingRequests == nil) + #expect(info.remainingTokens == nil) + #expect(info.resetRequests == nil) + #expect(info.resetTokens == nil) + #expect(info.retryAfter == nil) + } + + @Test("Invalid numeric values produce nil") + func invalidNumericValues() { + let headers: [String: String] = [ + "anthropic-ratelimit-requests-limit": "not-a-number", + "retry-after": "invalid" + ] + + let info = RateLimitInfo(headers: headers) + + #expect(info.limitRequests == nil) + #expect(info.retryAfter == nil) + } + + // MARK: - Date Parsing + + @Test("Parses ISO8601 date with fractional seconds") + func parsesDateWithFractional() { + let headers: [String: String] = [ + "anthropic-ratelimit-requests-reset": "2025-01-15T10:30:00.500Z" + ] + + let info = RateLimitInfo(headers: headers) + #expect(info.resetRequests != nil) + } + + @Test("Parses ISO8601 date without fractional seconds") + func parsesDateWithoutFractional() { + let headers: [String: String] = [ + "anthropic-ratelimit-tokens-reset": "2025-01-15T10:30:00Z" + ] + + let info = RateLimitInfo(headers: headers) + #expect(info.resetTokens != nil) + } + + @Test("Invalid date string produces nil") + func invalidDate() { + let headers: [String: String] = [ + "anthropic-ratelimit-requests-reset": "not-a-date" + ] + + let info = RateLimitInfo(headers: headers) + #expect(info.resetRequests == nil) + } + + // MARK: - Explicit Init + + @Test("Explicit init stores all values") + func explicitInit() { + let date = Date() + let info = RateLimitInfo( + requestId: "req-1", + organizationId: "org-1", + limitRequests: 100, + limitTokens: 50000, + remainingRequests: 95, + remainingTokens: 48000, + resetRequests: date, + resetTokens: date, + retryAfter: 60 + ) + + #expect(info.requestId == "req-1") + #expect(info.organizationId == "org-1") + #expect(info.limitRequests == 100) + #expect(info.limitTokens == 50000) + #expect(info.remainingRequests == 95) + #expect(info.remainingTokens == 48000) + #expect(info.resetRequests == date) + #expect(info.resetTokens == date) + #expect(info.retryAfter == 60) + } + + // MARK: - Hashable + + @Test("Equal RateLimitInfo values are equal") + func equality() { + let a = RateLimitInfo(requestId: "req-1", limitRequests: 100) + let b = RateLimitInfo(requestId: "req-1", limitRequests: 100) + #expect(a == b) + } + + @Test("Different RateLimitInfo values are unequal") + func inequality() { + let a = RateLimitInfo(requestId: "req-1") + let b = RateLimitInfo(requestId: "req-2") + #expect(a != b) + } + + // MARK: - Codable + + @Test("Codable round-trip preserves values") + func codableRoundTrip() throws { + let original = RateLimitInfo( + requestId: "req-1", + organizationId: "org-1", + limitRequests: 100, + limitTokens: 50000, + remainingRequests: 95, + remainingTokens: 48000, + retryAfter: 30 + ) + + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(RateLimitInfo.self, from: data) + + #expect(original == decoded) + } +} diff --git a/Tests/ConduitTests/Core/TokenLogprobTests.swift b/Tests/ConduitTests/Core/TokenLogprobTests.swift new file mode 100644 index 0000000..5ca6131 --- /dev/null +++ b/Tests/ConduitTests/Core/TokenLogprobTests.swift @@ -0,0 +1,106 @@ +// TokenLogprobTests.swift +// ConduitTests + +import Foundation +import Testing +@testable import Conduit + +@Suite("TokenLogprob Tests") +struct TokenLogprobTests { + + // MARK: - Initialization + + @Test("Init stores all properties") + func initProperties() { + let logprob = TokenLogprob(token: "hello", logprob: -0.5, tokenId: 42) + #expect(logprob.token == "hello") + #expect(logprob.logprob == -0.5) + #expect(logprob.tokenId == 42) + } + + @Test("Init defaults tokenId to nil") + func initDefaultTokenId() { + let logprob = TokenLogprob(token: "world", logprob: -1.0) + #expect(logprob.tokenId == nil) + } + + // MARK: - Probability + + @Test("probability computes exp of logprob") + func probabilityComputation() { + let logprob = TokenLogprob(token: "a", logprob: 0.0) + // exp(0) = 1.0 + #expect(abs(logprob.probability - 1.0) < 0.0001) + } + + @Test("probability of -1 is approximately 0.368") + func probabilityNegativeOne() { + let logprob = TokenLogprob(token: "b", logprob: -1.0) + // exp(-1) ≈ 0.3679 + #expect(abs(logprob.probability - exp(-1.0)) < 0.0001) + } + + @Test("probability of very negative logprob is near 0") + func probabilityVeryNegative() { + let logprob = TokenLogprob(token: "c", logprob: -100.0) + #expect(logprob.probability >= 0) + #expect(logprob.probability < 0.0001) + } + + // MARK: - Codable + + @Test("Codable round-trip preserves all fields") + func codableRoundTrip() throws { + let original = TokenLogprob(token: "hello", logprob: -0.5, tokenId: 42) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(TokenLogprob.self, from: data) + #expect(decoded.token == original.token) + #expect(decoded.logprob == original.logprob) + #expect(decoded.tokenId == original.tokenId) + } + + @Test("Codable round-trip with nil tokenId") + func codableRoundTripNilTokenId() throws { + let original = TokenLogprob(token: "world", logprob: -1.0) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(TokenLogprob.self, from: data) + #expect(decoded.token == original.token) + #expect(decoded.logprob == original.logprob) + #expect(decoded.tokenId == nil) + } + + // MARK: - Hashable + + @Test("Equal TokenLogprobs have same hash") + func hashableEqual() { + let a = TokenLogprob(token: "x", logprob: -0.5, tokenId: 1) + let b = TokenLogprob(token: "x", logprob: -0.5, tokenId: 1) + #expect(a == b) + #expect(a.hashValue == b.hashValue) + } + + @Test("Different TokenLogprobs are not equal") + func hashableNotEqual() { + let a = TokenLogprob(token: "x", logprob: -0.5, tokenId: 1) + let b = TokenLogprob(token: "y", logprob: -0.5, tokenId: 1) + #expect(a != b) + } + + @Test("TokenLogprobs work in Set") + func hashableSet() { + let a = TokenLogprob(token: "x", logprob: -0.5) + let b = TokenLogprob(token: "x", logprob: -0.5) + let c = TokenLogprob(token: "y", logprob: -1.0) + let set: Set = [a, b, c] + #expect(set.count == 2) + } + + // MARK: - Sendable + + @Test("TokenLogprob is Sendable") + func sendable() async { + let logprob = TokenLogprob(token: "test", logprob: -0.5, tokenId: 10) + let token = await Task { logprob.token }.value + #expect(token == "test") + } +} diff --git a/Tests/ConduitTests/Core/ToolMessageTests.swift b/Tests/ConduitTests/Core/ToolMessageTests.swift new file mode 100644 index 0000000..cd39504 --- /dev/null +++ b/Tests/ConduitTests/Core/ToolMessageTests.swift @@ -0,0 +1,158 @@ +// ToolMessageTests.swift +// ConduitTests + +import Foundation +import Testing +@testable import Conduit + +@Suite("ToolMessage Tests") +struct ToolMessageTests { + + // MARK: - Transcript.ToolCall Helpers + + @Test("ToolCall argumentsString returns JSON") + func toolCallArgumentsString() { + let content = GeneratedContent(kind: .structure( + properties: ["location": GeneratedContent(kind: .string("San Francisco"))], + orderedKeys: ["location"] + )) + let call = Transcript.ToolCall( + id: "call-1", + toolName: "get_weather", + arguments: content + ) + + let argsString = call.argumentsString + #expect(argsString.contains("San Francisco")) + } + + @Test("ToolCall argumentsData returns valid data") + func toolCallArgumentsData() throws { + let content = GeneratedContent(kind: .structure( + properties: ["query": GeneratedContent(kind: .string("test"))], + orderedKeys: ["query"] + )) + let call = Transcript.ToolCall( + id: "call-1", + toolName: "search", + arguments: content + ) + + let data = try call.argumentsData() + #expect(!data.isEmpty) + } + + // MARK: - Transcript.ToolOutput Helpers + + @Test("ToolOutput text extracts text segments") + func toolOutputText() { + let output = Transcript.ToolOutput( + id: "call-1", + toolName: "search", + segments: [ + .text(Transcript.TextSegment(content: "Result 1")), + .text(Transcript.TextSegment(content: "Result 2")) + ] + ) + + let text = output.text + #expect(text.contains("Result 1")) + #expect(text.contains("Result 2")) + } + + @Test("ToolOutput init from call preserves id and name") + func toolOutputFromCall() { + let call = Transcript.ToolCall( + id: "call-42", + toolName: "calculator", + arguments: GeneratedContent(kind: .null) + ) + + let output = Transcript.ToolOutput( + call: call, + segments: [.text(Transcript.TextSegment(content: "42"))] + ) + + #expect(output.id == "call-42") + #expect(output.toolName == "calculator") + #expect(output.text == "42") + } + + // MARK: - Message.toolOutput + + @Test("Message.toolOutput creates tool role message") + func messageToolOutput() { + let output = Transcript.ToolOutput( + id: "call-1", + toolName: "weather", + segments: [.text(Transcript.TextSegment(content: "Sunny, 72°F"))] + ) + + let message = Message.toolOutput(output) + + #expect(message.role == .tool) + #expect(message.content.textValue == "Sunny, 72°F") + #expect(message.metadata?.custom?["tool_call_id"] == "call-1") + #expect(message.metadata?.custom?["tool_name"] == "weather") + } + + @Test("Message.toolOutput from call and content") + func messageToolOutputFromCallAndContent() { + let call = Transcript.ToolCall( + id: "call-1", + toolName: "search", + arguments: GeneratedContent(kind: .null) + ) + + let message = Message.toolOutput(call: call, content: "Found 5 results") + + #expect(message.role == .tool) + #expect(message.content.textValue == "Found 5 results") + } + + // MARK: - Collection Extension + + @Test("call(named:) finds tool call by name") + func callNamed() { + let calls = [ + Transcript.ToolCall(id: "1", toolName: "weather", arguments: GeneratedContent(kind: .null)), + Transcript.ToolCall(id: "2", toolName: "search", arguments: GeneratedContent(kind: .null)), + Transcript.ToolCall(id: "3", toolName: "calculator", arguments: GeneratedContent(kind: .null)) + ] + + let found = calls.call(named: "search") + #expect(found?.id == "2") + } + + @Test("call(named:) returns nil when not found") + func callNamedNotFound() { + let calls = [ + Transcript.ToolCall(id: "1", toolName: "weather", arguments: GeneratedContent(kind: .null)) + ] + + let found = calls.call(named: "nonexistent") + #expect(found == nil) + } + + @Test("calls(named:) filters multiple matches") + func callsNamed() { + let calls = [ + Transcript.ToolCall(id: "1", toolName: "search", arguments: GeneratedContent(kind: .null)), + Transcript.ToolCall(id: "2", toolName: "weather", arguments: GeneratedContent(kind: .null)), + Transcript.ToolCall(id: "3", toolName: "search", arguments: GeneratedContent(kind: .null)) + ] + + let found = calls.calls(named: "search") + #expect(found.count == 2) + } + + @Test("calls(named:) returns empty when no matches") + func callsNamedEmpty() { + let calls = [ + Transcript.ToolCall(id: "1", toolName: "weather", arguments: GeneratedContent(kind: .null)) + ] + + let found = calls.calls(named: "nonexistent") + #expect(found.isEmpty) + } +} diff --git a/Tests/ConduitTests/Core/UsageStatsTests.swift b/Tests/ConduitTests/Core/UsageStatsTests.swift new file mode 100644 index 0000000..061eda6 --- /dev/null +++ b/Tests/ConduitTests/Core/UsageStatsTests.swift @@ -0,0 +1,95 @@ +// UsageStatsTests.swift +// ConduitTests + +import Foundation +import Testing +@testable import Conduit + +@Suite("UsageStats Tests") +struct UsageStatsTests { + + // MARK: - Initialization + + @Test("Init stores prompt and completion tokens") + func initialization() { + let usage = UsageStats(promptTokens: 100, completionTokens: 50) + + #expect(usage.promptTokens == 100) + #expect(usage.completionTokens == 50) + } + + // MARK: - Computed Properties + + @Test("totalTokens sums prompt and completion tokens") + func totalTokens() { + let usage = UsageStats(promptTokens: 100, completionTokens: 50) + #expect(usage.totalTokens == 150) + } + + @Test("totalTokens with zero values") + func totalTokensZero() { + let usage = UsageStats(promptTokens: 0, completionTokens: 0) + #expect(usage.totalTokens == 0) + } + + @Test("totalTokens with large values") + func totalTokensLarge() { + let usage = UsageStats(promptTokens: 100_000, completionTokens: 50_000) + #expect(usage.totalTokens == 150_000) + } + + // MARK: - Codable + + @Test("Codable round-trip preserves values") + func codableRoundTrip() throws { + let original = UsageStats(promptTokens: 42, completionTokens: 17) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(UsageStats.self, from: data) + + #expect(decoded.promptTokens == 42) + #expect(decoded.completionTokens == 17) + } + + @Test("JSON encoding produces expected keys") + func jsonKeys() throws { + let usage = UsageStats(promptTokens: 10, completionTokens: 20) + let data = try JSONEncoder().encode(usage) + let json = try JSONDecoder().decode([String: Int].self, from: data) + + #expect(json["promptTokens"] == 10) + #expect(json["completionTokens"] == 20) + // totalTokens is computed, should not appear in JSON + #expect(json["totalTokens"] == nil) + } + + // MARK: - Hashable + + @Test("Equal UsageStats have same hash") + func hashEquality() { + let a = UsageStats(promptTokens: 10, completionTokens: 20) + let b = UsageStats(promptTokens: 10, completionTokens: 20) + + #expect(a == b) + #expect(a.hashValue == b.hashValue) + } + + @Test("Different UsageStats are unequal") + func hashInequality() { + let a = UsageStats(promptTokens: 10, completionTokens: 20) + let b = UsageStats(promptTokens: 10, completionTokens: 30) + + #expect(a != b) + } + + @Test("UsageStats can be used in a Set") + func setUsage() { + let a = UsageStats(promptTokens: 10, completionTokens: 20) + let b = UsageStats(promptTokens: 30, completionTokens: 40) + + var set: Set = [] + set.insert(a) + set.insert(b) + + #expect(set.count == 2) + } +} diff --git a/Tests/ConduitTests/Core/WarmupConfigTests.swift b/Tests/ConduitTests/Core/WarmupConfigTests.swift new file mode 100644 index 0000000..af2ca4b --- /dev/null +++ b/Tests/ConduitTests/Core/WarmupConfigTests.swift @@ -0,0 +1,80 @@ +// WarmupConfigTests.swift +// ConduitTests + +import Foundation +import Testing +@testable import Conduit + +@Suite("WarmupConfig Tests") +struct WarmupConfigTests { + + // MARK: - Static Presets + + @Test("default preset has warmupOnInit false") + func defaultPreset() { + let config = WarmupConfig.default + + #expect(!config.warmupOnInit) + #expect(config.prefillChars == 50) + #expect(config.warmupTokens == 5) + } + + @Test("eager preset has warmupOnInit true") + func eagerPreset() { + let config = WarmupConfig.eager + + #expect(config.warmupOnInit) + #expect(config.prefillChars == 50) + #expect(config.warmupTokens == 5) + } + + // MARK: - Custom Initialization + + @Test("Custom init stores all properties") + func customInit() { + let config = WarmupConfig( + warmupOnInit: true, + prefillChars: 100, + warmupTokens: 10 + ) + + #expect(config.warmupOnInit) + #expect(config.prefillChars == 100) + #expect(config.warmupTokens == 10) + } + + @Test("Default parameter values in init") + func defaultParameters() { + let config = WarmupConfig() + + #expect(!config.warmupOnInit) + #expect(config.prefillChars == 50) + #expect(config.warmupTokens == 5) + } + + // MARK: - Mutability + + @Test("Properties are mutable") + func mutability() { + var config = WarmupConfig.default + + config.warmupOnInit = true + config.prefillChars = 200 + config.warmupTokens = 20 + + #expect(config.warmupOnInit) + #expect(config.prefillChars == 200) + #expect(config.warmupTokens == 20) + } + + // MARK: - Sendable + + @Test("WarmupConfig is Sendable") + func sendable() { + let config = WarmupConfig.eager + Task { + // This compiles only if WarmupConfig is Sendable + _ = config.warmupOnInit + } + } +} diff --git a/Tests/ConduitTests/Extensions/ArrayExtensionTests.swift b/Tests/ConduitTests/Extensions/ArrayExtensionTests.swift new file mode 100644 index 0000000..913e515 --- /dev/null +++ b/Tests/ConduitTests/Extensions/ArrayExtensionTests.swift @@ -0,0 +1,108 @@ +// ArrayExtensionTests.swift +// ConduitTests + +import Foundation +import Testing +@testable import Conduit + +@Suite("Array Extension Tests") +struct ArrayMessageExtensionTests { + + // MARK: - Test Data + + static let sampleMessages: [Message] = [ + .system("You are helpful."), + .user("Hello!"), + .assistant("Hi there!"), + .user("How are you?"), + .assistant("I'm doing well!") + ] + + // MARK: - userMessages + + @Test("userMessages filters only user role messages") + func userMessages() { + let users = Self.sampleMessages.userMessages + #expect(users.count == 2) + #expect(users.allSatisfy { $0.role == .user }) + } + + @Test("userMessages returns empty for no user messages") + func userMessagesEmpty() { + let messages: [Message] = [.system("System"), .assistant("Hi")] + #expect(messages.userMessages.isEmpty) + } + + // MARK: - assistantMessages + + @Test("assistantMessages filters only assistant role messages") + func assistantMessages() { + let assistants = Self.sampleMessages.assistantMessages + #expect(assistants.count == 2) + #expect(assistants.allSatisfy { $0.role == .assistant }) + } + + @Test("assistantMessages returns empty for no assistant messages") + func assistantMessagesEmpty() { + let messages: [Message] = [.system("System"), .user("Hi")] + #expect(messages.assistantMessages.isEmpty) + } + + // MARK: - systemMessage + + @Test("systemMessage returns first system message") + func systemMessage() { + let system = Self.sampleMessages.systemMessage + #expect(system != nil) + #expect(system?.role == .system) + #expect(system?.content.textValue == "You are helpful.") + } + + @Test("systemMessage returns nil when no system message exists") + func systemMessageNil() { + let messages: [Message] = [.user("Hello"), .assistant("Hi")] + #expect(messages.systemMessage == nil) + } + + // MARK: - withoutSystem + + @Test("withoutSystem removes system messages") + func withoutSystem() { + let filtered = Self.sampleMessages.withoutSystem + #expect(filtered.count == 4) + #expect(filtered.allSatisfy { $0.role != .system }) + } + + @Test("withoutSystem on messages without system returns all") + func withoutSystemNoChange() { + let messages: [Message] = [.user("Hello"), .assistant("Hi")] + let filtered = messages.withoutSystem + #expect(filtered.count == 2) + } + + // MARK: - totalTextLength + + @Test("totalTextLength sums character counts") + func totalTextLength() { + let messages: [Message] = [ + .user("Hello"), // 5 chars + .assistant("World") // 5 chars + ] + #expect(messages.totalTextLength == 10) + } + + @Test("totalTextLength is zero for empty array") + func totalTextLengthEmpty() { + let messages: [Message] = [] + #expect(messages.totalTextLength == 0) + } + + @Test("totalTextLength includes system message") + func totalTextLengthWithSystem() { + let messages: [Message] = [ + .system("Be brief."), // 9 chars + .user("Hi") // 2 chars + ] + #expect(messages.totalTextLength == 11) + } +} diff --git a/Tests/ConduitTests/Extensions/IntContextExtensionTests.swift b/Tests/ConduitTests/Extensions/IntContextExtensionTests.swift new file mode 100644 index 0000000..44ff94f --- /dev/null +++ b/Tests/ConduitTests/Extensions/IntContextExtensionTests.swift @@ -0,0 +1,110 @@ +// IntContextExtensionTests.swift +// ConduitTests + +import Foundation +import Testing +@testable import Conduit + +@Suite("Int Context Extension Tests") +struct IntContextExtensionTests { + + // MARK: - Context Window Constants + + @Test("context4K is 4096") + func context4K() { + #expect(Int.context4K == 4_096) + } + + @Test("context8K is 8192") + func context8K() { + #expect(Int.context8K == 8_192) + } + + @Test("context16K is 16384") + func context16K() { + #expect(Int.context16K == 16_384) + } + + @Test("context32K is 32768") + func context32K() { + #expect(Int.context32K == 32_768) + } + + @Test("context64K is 65536") + func context64K() { + #expect(Int.context64K == 65_536) + } + + @Test("context128K is 131072") + func context128K() { + #expect(Int.context128K == 131_072) + } + + @Test("context200K is 200000") + func context200K() { + #expect(Int.context200K == 200_000) + } + + @Test("context1M is 1000000") + func context1M() { + #expect(Int.context1M == 1_000_000) + } + + // MARK: - contextDescription + + @Test("Standard sizes have colloquial descriptions", + arguments: [ + (Int.context4K, "4K"), + (Int.context8K, "8K"), + (Int.context16K, "16K"), + (Int.context32K, "32K"), + (Int.context64K, "64K"), + (Int.context128K, "128K"), + (Int.context200K, "200K"), + (Int.context1M, "1M") + ]) + func standardContextDescriptions(size: Int, expected: String) { + #expect(size.contextDescription == expected) + } + + @Test("Non-standard sizes use binary K") + func nonStandardDescription() { + let size = 2048 + #expect(size.contextDescription == "2K") + } + + @Test("Large non-standard sizes use M") + func largeSizeDescription() { + let size = 2_000_000 + #expect(size.contextDescription == "2M") + } + + @Test("Small values shown as-is") + func smallSizeDescription() { + let size = 512 + #expect(size.contextDescription == "512") + } + + // MARK: - isStandardContextSize + + @Test("Standard sizes return true for isStandardContextSize", + arguments: [ + Int.context4K, + .context8K, + .context16K, + .context32K, + .context64K, + .context128K, + .context200K, + .context1M + ]) + func isStandard(size: Int) { + #expect(size.isStandardContextSize) + } + + @Test("Non-standard size returns false") + func isNotStandard() { + #expect(!2048.isStandardContextSize) + #expect(!10000.isStandardContextSize) + } +} diff --git a/Tests/ConduitTests/ImageGeneration/DiffusionModelDownloaderTests.swift b/Tests/ConduitTests/ImageGeneration/DiffusionModelDownloaderTests.swift index 0b580a2..a4f01d9 100644 --- a/Tests/ConduitTests/ImageGeneration/DiffusionModelDownloaderTests.swift +++ b/Tests/ConduitTests/ImageGeneration/DiffusionModelDownloaderTests.swift @@ -3,7 +3,7 @@ // // This file requires the MLX trait (Hub) to be enabled. -#if canImport(Hub) +#if CONDUIT_TRAIT_MLX import Foundation import Testing @@ -773,4 +773,4 @@ struct DiffusionModelDownloaderTests { } } -#endif // canImport(Hub) +#endif // CONDUIT_TRAIT_MLX diff --git a/Tests/ConduitTests/Providers/Kimi/KimiAuthenticationTests.swift b/Tests/ConduitTests/Providers/Kimi/KimiAuthenticationTests.swift new file mode 100644 index 0000000..0211d67 --- /dev/null +++ b/Tests/ConduitTests/Providers/Kimi/KimiAuthenticationTests.swift @@ -0,0 +1,236 @@ +// KimiAuthenticationTests.swift +// ConduitTests + +#if CONDUIT_TRAIT_KIMI && CONDUIT_TRAIT_OPENAI + +import Foundation +import Testing +@testable import Conduit + +@Suite("KimiAuthentication Tests") +struct KimiAuthenticationTests { + + // MARK: - AuthType Cases + + @Test("apiKey case stores the key") + func apiKeyCaseStoresKey() { + let auth = KimiAuthentication(type: .apiKey("sk-test-123")) + if case .apiKey(let key) = auth.type { + #expect(key == "sk-test-123") + } else { + Issue.record("Expected .apiKey case") + } + } + + @Test("auto case exists") + func autoCaseExists() { + let auth = KimiAuthentication(type: .auto) + if case .auto = auth.type { + // success + } else { + Issue.record("Expected .auto case") + } + } + + // MARK: - Static Factory: apiKey + + @Test("Static apiKey factory creates correct type") + func staticApiKeyFactory() { + let auth = KimiAuthentication.apiKey("sk-moonshot-abc") + + if case .apiKey(let key) = auth.type { + #expect(key == "sk-moonshot-abc") + } else { + Issue.record("Expected .apiKey case from static factory") + } + } + + @Test("Static apiKey factory resolves apiKey property") + func staticApiKeyFactoryResolvesApiKey() { + let auth = KimiAuthentication.apiKey("sk-resolve-me") + #expect(auth.apiKey == "sk-resolve-me") + } + + // MARK: - Static Factory: auto + + @Test("Static auto creates auto type") + func staticAutoFactory() { + let auth = KimiAuthentication.auto + if case .auto = auth.type { + // success + } else { + Issue.record("Expected .auto case from static property") + } + } + + // MARK: - apiKey Property + + @Test("apiKey property returns key for apiKey type") + func apiKeyPropertyWithApiKeyType() { + let auth = KimiAuthentication.apiKey("sk-my-key") + #expect(auth.apiKey == "sk-my-key") + } + + @Test("apiKey property checks environment for auto type") + func apiKeyPropertyWithAutoType() { + // auto type reads from ProcessInfo.processInfo.environment["MOONSHOT_API_KEY"] + // In test environment, this is likely nil unless explicitly set + let auth = KimiAuthentication.auto + let envKey = ProcessInfo.processInfo.environment["MOONSHOT_API_KEY"] + #expect(auth.apiKey == envKey) + } + + // MARK: - isValid + + @Test("isValid returns true for non-empty API key") + func isValidWithNonEmptyKey() { + let auth = KimiAuthentication.apiKey("sk-valid-key") + #expect(auth.isValid == true) + } + + @Test("isValid returns false for empty API key") + func isValidWithEmptyKey() { + let auth = KimiAuthentication.apiKey("") + #expect(auth.isValid == false) + } + + @Test("isValid for auto depends on environment") + func isValidAutoType() { + let auth = KimiAuthentication.auto + let envKey = ProcessInfo.processInfo.environment["MOONSHOT_API_KEY"] + let expectedValid = envKey?.isEmpty == false + #expect(auth.isValid == expectedValid) + } + + // MARK: - Hashable + + @Test("Equal authentications have same hash") + func hashableEquality() { + let auth1 = KimiAuthentication.apiKey("sk-same-key") + let auth2 = KimiAuthentication.apiKey("sk-same-key") + + #expect(auth1 == auth2) + #expect(auth1.hashValue == auth2.hashValue) + } + + @Test("Different API keys are not equal") + func hashableInequalityDifferentKeys() { + let auth1 = KimiAuthentication.apiKey("sk-key-1") + let auth2 = KimiAuthentication.apiKey("sk-key-2") + + #expect(auth1 != auth2) + } + + @Test("apiKey and auto are not equal") + func hashableInequalityDifferentTypes() { + let auth1 = KimiAuthentication.apiKey("sk-some-key") + let auth2 = KimiAuthentication.auto + + #expect(auth1 != auth2) + } + + @Test("auto instances are equal") + func hashableAutoEquality() { + let auth1 = KimiAuthentication.auto + let auth2 = KimiAuthentication(type: .auto) + + #expect(auth1 == auth2) + } + + @Test("Authentications work in a Set") + func authenticationSet() { + var set: Set = [] + set.insert(.apiKey("sk-a")) + set.insert(.apiKey("sk-b")) + set.insert(.auto) + set.insert(.apiKey("sk-a")) // duplicate + + #expect(set.count == 3) + } + + // MARK: - Codable Round-Trip + + @Test("Codable round-trip for apiKey type") + func codableRoundTripApiKey() throws { + let original = KimiAuthentication.apiKey("sk-codable-test") + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(KimiAuthentication.self, from: data) + + #expect(original == decoded) + #expect(decoded.apiKey == "sk-codable-test") + } + + @Test("Codable round-trip for auto type") + func codableRoundTripAuto() throws { + let original = KimiAuthentication.auto + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(KimiAuthentication.self, from: data) + + #expect(original == decoded) + } + + @Test("Codable round-trip preserves AuthType apiKey case") + func codableRoundTripAuthTypeApiKey() throws { + let original = KimiAuthentication.AuthType.apiKey("sk-inner") + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(KimiAuthentication.AuthType.self, from: data) + + if case .apiKey(let key) = decoded { + #expect(key == "sk-inner") + } else { + Issue.record("Expected decoded .apiKey case") + } + } + + @Test("Codable round-trip preserves AuthType auto case") + func codableRoundTripAuthTypeAuto() throws { + let original = KimiAuthentication.AuthType.auto + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(KimiAuthentication.AuthType.self, from: data) + + if case .auto = decoded { + // success + } else { + Issue.record("Expected decoded .auto case") + } + } + + // MARK: - CustomDebugStringConvertible + + @Test("Debug description for apiKey masks the key") + func debugDescriptionApiKey() { + let auth = KimiAuthentication.apiKey("sk-secret-should-not-appear") + #expect(auth.debugDescription == "KimiAuthentication.apiKey(***)") + #expect(!auth.debugDescription.contains("sk-secret-should-not-appear")) + } + + @Test("Debug description for auto shows auto") + func debugDescriptionAuto() { + let auth = KimiAuthentication.auto + #expect(auth.debugDescription == "KimiAuthentication.auto") + } + + // MARK: - Sendable + + @Test("KimiAuthentication is Sendable") + func sendableConformance() async { + let auth = KimiAuthentication.apiKey("sk-sendable-test") + let task = Task { auth.apiKey } + let result = await task.value + #expect(result == "sk-sendable-test") + } + + @Test("KimiAuthentication.AuthType is Sendable") + func authTypeSendableConformance() async { + let authType = KimiAuthentication.AuthType.apiKey("sk-sendable") + let task = Task { authType } + let result = await task.value + if case .apiKey(let key) = result { + #expect(key == "sk-sendable") + } else { + Issue.record("Expected .apiKey case from Task result") + } + } +} + +#endif // CONDUIT_TRAIT_KIMI && CONDUIT_TRAIT_OPENAI diff --git a/Tests/ConduitTests/Providers/Kimi/KimiConfigurationTests.swift b/Tests/ConduitTests/Providers/Kimi/KimiConfigurationTests.swift new file mode 100644 index 0000000..803a977 --- /dev/null +++ b/Tests/ConduitTests/Providers/Kimi/KimiConfigurationTests.swift @@ -0,0 +1,294 @@ +// KimiConfigurationTests.swift +// ConduitTests + +#if CONDUIT_TRAIT_KIMI && CONDUIT_TRAIT_OPENAI + +import Foundation +import Testing +@testable import Conduit + +@Suite("KimiConfiguration Tests") +struct KimiConfigurationTests { + + // MARK: - Default Initialization + + @Test("Default init uses auto authentication") + func defaultInitAuthentication() { + let config = KimiConfiguration() + #expect(config.authentication.type == .auto) + } + + @Test("Default init uses correct base URL") + func defaultInitBaseURL() { + let config = KimiConfiguration() + #expect(config.baseURL == URL(string: "https://api.moonshot.cn/v1")!) + } + + @Test("Default init uses 120 second timeout") + func defaultInitTimeout() { + let config = KimiConfiguration() + #expect(config.timeout == 120.0) + } + + @Test("Default init uses 3 max retries") + func defaultInitMaxRetries() { + let config = KimiConfiguration() + #expect(config.maxRetries == 3) + } + + // MARK: - Custom Initialization + + @Test("Init with all custom parameters") + func customInit() { + let auth = KimiAuthentication.apiKey("sk-test-key") + let url = URL(string: "https://custom.api.com/v2")! + let config = KimiConfiguration( + authentication: auth, + baseURL: url, + timeout: 60.0, + maxRetries: 5 + ) + + #expect(config.authentication.apiKey == "sk-test-key") + #expect(config.baseURL == url) + #expect(config.timeout == 60.0) + #expect(config.maxRetries == 5) + } + + @Test("Init with only authentication parameter") + func initWithAuthOnly() { + let config = KimiConfiguration(authentication: .apiKey("sk-key")) + #expect(config.authentication.apiKey == "sk-key") + #expect(config.baseURL == URL(string: "https://api.moonshot.cn/v1")!) + #expect(config.timeout == 120.0) + #expect(config.maxRetries == 3) + } + + // MARK: - Static Factory + + @Test("Standard factory creates config with API key") + func standardFactory() { + let config = KimiConfiguration.standard(apiKey: "sk-moonshot-abc123") + + #expect(config.authentication.apiKey == "sk-moonshot-abc123") + #expect(config.baseURL == URL(string: "https://api.moonshot.cn/v1")!) + #expect(config.timeout == 120.0) + #expect(config.maxRetries == 3) + } + + // MARK: - hasValidAuthentication + + @Test("hasValidAuthentication with API key returns true") + func hasValidAuthenticationWithKey() { + let config = KimiConfiguration.standard(apiKey: "sk-valid") + #expect(config.hasValidAuthentication == true) + } + + @Test("hasValidAuthentication with empty API key returns false") + func hasValidAuthenticationEmptyKey() { + let config = KimiConfiguration(authentication: .apiKey("")) + #expect(config.hasValidAuthentication == false) + } + + // MARK: - Fluent API: apiKey + + @Test("Fluent apiKey sets authentication") + func fluentApiKey() { + let config = KimiConfiguration().apiKey("sk-new-key") + + #expect(config.authentication.apiKey == "sk-new-key") + } + + @Test("Fluent apiKey preserves other fields") + func fluentApiKeyPreservesOther() { + let original = KimiConfiguration( + baseURL: URL(string: "https://custom.com")!, + timeout: 60.0, + maxRetries: 5 + ) + let updated = original.apiKey("sk-new-key") + + #expect(updated.baseURL == URL(string: "https://custom.com")!) + #expect(updated.timeout == 60.0) + #expect(updated.maxRetries == 5) + #expect(updated.authentication.apiKey == "sk-new-key") + } + + @Test("Fluent apiKey does not mutate original") + func fluentApiKeyImmutability() { + let original = KimiConfiguration() + let _ = original.apiKey("sk-new-key") + + #expect(original.authentication.type == .auto) + } + + // MARK: - Fluent API: timeout + + @Test("Fluent timeout sets timeout value") + func fluentTimeout() { + let config = KimiConfiguration().timeout(300.0) + #expect(config.timeout == 300.0) + } + + @Test("Fluent timeout clamps negative to zero") + func fluentTimeoutClampsNegative() { + let config = KimiConfiguration().timeout(-10.0) + #expect(config.timeout == 0.0) + } + + @Test("Fluent timeout preserves other fields") + func fluentTimeoutPreservesOther() { + let original = KimiConfiguration.standard(apiKey: "sk-key") + let updated = original.timeout(45.0) + + #expect(updated.authentication.apiKey == "sk-key") + #expect(updated.baseURL == URL(string: "https://api.moonshot.cn/v1")!) + #expect(updated.maxRetries == 3) + #expect(updated.timeout == 45.0) + } + + @Test("Fluent timeout does not mutate original") + func fluentTimeoutImmutability() { + let original = KimiConfiguration() + let _ = original.timeout(999.0) + + #expect(original.timeout == 120.0) + } + + // MARK: - Fluent API: maxRetries + + @Test("Fluent maxRetries sets retry count") + func fluentMaxRetries() { + let config = KimiConfiguration().maxRetries(10) + #expect(config.maxRetries == 10) + } + + @Test("Fluent maxRetries clamps negative to zero") + func fluentMaxRetriesClampsNegative() { + let config = KimiConfiguration().maxRetries(-5) + #expect(config.maxRetries == 0) + } + + @Test("Fluent maxRetries preserves other fields") + func fluentMaxRetriesPreservesOther() { + let original = KimiConfiguration.standard(apiKey: "sk-key").timeout(60.0) + let updated = original.maxRetries(7) + + #expect(updated.authentication.apiKey == "sk-key") + #expect(updated.timeout == 60.0) + #expect(updated.baseURL == URL(string: "https://api.moonshot.cn/v1")!) + #expect(updated.maxRetries == 7) + } + + @Test("Fluent maxRetries does not mutate original") + func fluentMaxRetriesImmutability() { + let original = KimiConfiguration() + let _ = original.maxRetries(99) + + #expect(original.maxRetries == 3) + } + + // MARK: - Fluent API: Chaining + + @Test("Fluent methods can be chained") + func fluentChaining() { + let config = KimiConfiguration() + .apiKey("sk-chained") + .timeout(90.0) + .maxRetries(2) + + #expect(config.authentication.apiKey == "sk-chained") + #expect(config.timeout == 90.0) + #expect(config.maxRetries == 2) + } + + // MARK: - Hashable + + @Test("Equal configurations have same hash") + func hashableEquality() { + let config1 = KimiConfiguration.standard(apiKey: "sk-test") + let config2 = KimiConfiguration.standard(apiKey: "sk-test") + + #expect(config1 == config2) + #expect(config1.hashValue == config2.hashValue) + } + + @Test("Different configurations are not equal") + func hashableInequality() { + let config1 = KimiConfiguration.standard(apiKey: "sk-one") + let config2 = KimiConfiguration.standard(apiKey: "sk-two") + + #expect(config1 != config2) + } + + @Test("Configurations work in a Set") + func configSet() { + let config1 = KimiConfiguration.standard(apiKey: "sk-a") + let config2 = KimiConfiguration.standard(apiKey: "sk-b") + let config3 = KimiConfiguration.standard(apiKey: "sk-a") // duplicate + + var set: Set = [] + set.insert(config1) + set.insert(config2) + set.insert(config3) + + #expect(set.count == 2) + } + + // MARK: - Codable Round-Trip + + @Test("Codable round-trip preserves all fields") + func codableRoundTrip() throws { + let original = KimiConfiguration( + authentication: .apiKey("sk-round-trip-test"), + baseURL: URL(string: "https://custom.moonshot.cn/v2")!, + timeout: 180.0, + maxRetries: 7 + ) + + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(KimiConfiguration.self, from: data) + + #expect(decoded.authentication.apiKey == "sk-round-trip-test") + #expect(decoded.baseURL == URL(string: "https://custom.moonshot.cn/v2")!) + #expect(decoded.timeout == 180.0) + #expect(decoded.maxRetries == 7) + #expect(original == decoded) + } + + @Test("Codable round-trip with default config") + func codableRoundTripDefault() throws { + let original = KimiConfiguration() + + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(KimiConfiguration.self, from: data) + + #expect(decoded.baseURL == original.baseURL) + #expect(decoded.timeout == original.timeout) + #expect(decoded.maxRetries == original.maxRetries) + #expect(original == decoded) + } + + @Test("Codable round-trip with standard factory") + func codableRoundTripStandardFactory() throws { + let original = KimiConfiguration.standard(apiKey: "sk-encode-me") + + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(KimiConfiguration.self, from: data) + + #expect(decoded.authentication.apiKey == "sk-encode-me") + #expect(original == decoded) + } + + // MARK: - Sendable + + @Test("KimiConfiguration is Sendable") + func sendableConformance() async { + let config = KimiConfiguration.standard(apiKey: "sk-sendable") + let task = Task { config.timeout } + let result = await task.value + #expect(result == 120.0) + } +} + +#endif // CONDUIT_TRAIT_KIMI && CONDUIT_TRAIT_OPENAI diff --git a/Tests/ConduitTests/Providers/Kimi/KimiModelIDTests.swift b/Tests/ConduitTests/Providers/Kimi/KimiModelIDTests.swift new file mode 100644 index 0000000..eb80bd3 --- /dev/null +++ b/Tests/ConduitTests/Providers/Kimi/KimiModelIDTests.swift @@ -0,0 +1,195 @@ +// KimiModelIDTests.swift +// ConduitTests + +#if CONDUIT_TRAIT_KIMI && CONDUIT_TRAIT_OPENAI + +import Foundation +import Testing +@testable import Conduit + +@Suite("KimiModelID Tests") +struct KimiModelIDTests { + + // MARK: - Predefined Model Cases + + @Test("Predefined model kimiK2_5 has correct raw value") + func kimiK2_5RawValue() { + #expect(KimiModelID.kimiK2_5.rawValue == "kimi-k2-5") + } + + @Test("Predefined model kimiK2 has correct raw value") + func kimiK2RawValue() { + #expect(KimiModelID.kimiK2.rawValue == "kimi-k2") + } + + @Test("Predefined model kimiK1_5 has correct raw value") + func kimiK1_5RawValue() { + #expect(KimiModelID.kimiK1_5.rawValue == "kimi-k1-5") + } + + // MARK: - Initialization + + @Test("Init with raw string") + func initWithRawString() { + let model = KimiModelID("custom-model") + #expect(model.rawValue == "custom-model") + } + + @Test("Init with rawValue parameter") + func initWithRawValueParameter() { + let model = KimiModelID(rawValue: "custom-model-2") + #expect(model.rawValue == "custom-model-2") + } + + @Test("String literal initialization") + func stringLiteralInit() { + let model: KimiModelID = "my-custom-kimi" + #expect(model.rawValue == "my-custom-kimi") + } + + // MARK: - ModelIdentifying Conformance + + @Test("Provider is kimi") + func providerIsKimi() { + #expect(KimiModelID.kimiK2_5.provider == .kimi) + #expect(KimiModelID.kimiK2.provider == .kimi) + #expect(KimiModelID.kimiK1_5.provider == .kimi) + } + + @Test("Custom model provider is kimi") + func customModelProviderIsKimi() { + let model = KimiModelID("anything") + #expect(model.provider == .kimi) + } + + // MARK: - Display Name + + @Test("Display name strips kimi- prefix and replaces dashes with dots") + func displayNameFormatting() { + #expect(KimiModelID.kimiK2_5.displayName == "Kimi k2.5") + #expect(KimiModelID.kimiK2.displayName == "Kimi k2") + #expect(KimiModelID.kimiK1_5.displayName == "Kimi k1.5") + } + + @Test("Display name for custom model without kimi prefix") + func displayNameCustomModel() { + let model = KimiModelID("some-other-model") + #expect(model.displayName == "some.other.model") + } + + // MARK: - Description + + @Test("Description includes provider tag and raw value") + func descriptionFormat() { + #expect(KimiModelID.kimiK2_5.description == "[Kimi] kimi-k2-5") + #expect(KimiModelID.kimiK2.description == "[Kimi] kimi-k2") + #expect(KimiModelID.kimiK1_5.description == "[Kimi] kimi-k1-5") + } + + @Test("Description for custom model") + func descriptionCustomModel() { + let model = KimiModelID("my-model") + #expect(model.description == "[Kimi] my-model") + } + + // MARK: - Hashable + + @Test("Equal models have same hash") + func hashableEquality() { + let model1 = KimiModelID("kimi-k2-5") + let model2 = KimiModelID.kimiK2_5 + #expect(model1 == model2) + #expect(model1.hashValue == model2.hashValue) + } + + @Test("Different models are not equal") + func hashableInequality() { + #expect(KimiModelID.kimiK2_5 != KimiModelID.kimiK2) + #expect(KimiModelID.kimiK2 != KimiModelID.kimiK1_5) + } + + @Test("Models work in a Set") + func modelSet() { + var set: Set = [] + set.insert(.kimiK2_5) + set.insert(.kimiK2) + set.insert(.kimiK1_5) + set.insert(.kimiK2_5) // duplicate + + #expect(set.count == 3) + #expect(set.contains(.kimiK2_5)) + #expect(set.contains(.kimiK2)) + #expect(set.contains(.kimiK1_5)) + } + + // MARK: - Codable Round-Trip + + @Test("Codable round-trip for predefined models") + func codableRoundTripPredefined() throws { + let models: [KimiModelID] = [.kimiK2_5, .kimiK2, .kimiK1_5] + let encoder = JSONEncoder() + let decoder = JSONDecoder() + + for original in models { + let data = try encoder.encode(original) + let decoded = try decoder.decode(KimiModelID.self, from: data) + #expect(original == decoded) + #expect(original.rawValue == decoded.rawValue) + } + } + + @Test("Codable round-trip for custom model") + func codableRoundTripCustom() throws { + let original = KimiModelID("my-custom-kimi-model") + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(KimiModelID.self, from: data) + + #expect(original == decoded) + #expect(decoded.rawValue == "my-custom-kimi-model") + } + + @Test("Encodes as single string value") + func encodesAsSingleValue() throws { + let model = KimiModelID.kimiK2_5 + let data = try JSONEncoder().encode(model) + let jsonString = String(data: data, encoding: .utf8)! + + #expect(jsonString == "\"kimi-k2-5\"") + } + + @Test("Decodes from single string value") + func decodesFromSingleValue() throws { + let json = "\"kimi-k2\"".data(using: .utf8)! + let decoded = try JSONDecoder().decode(KimiModelID.self, from: json) + + #expect(decoded == KimiModelID.kimiK2) + #expect(decoded.rawValue == "kimi-k2") + } + + // MARK: - ExpressibleByStringLiteral + + @Test("String literal creates correct model") + func stringLiteralCreation() { + let model: KimiModelID = "kimi-k2-5" + #expect(model == KimiModelID.kimiK2_5) + } + + @Test("String literal preserves arbitrary strings") + func stringLiteralArbitrary() { + let model: KimiModelID = "totally-custom-id" + #expect(model.rawValue == "totally-custom-id") + #expect(model.provider == .kimi) + } + + // MARK: - Sendable + + @Test("KimiModelID is Sendable") + func sendableConformance() async { + let model = KimiModelID.kimiK2_5 + let task = Task { model.rawValue } + let result = await task.value + #expect(result == "kimi-k2-5") + } +} + +#endif // CONDUIT_TRAIT_KIMI && CONDUIT_TRAIT_OPENAI diff --git a/Tests/ConduitTests/Providers/MiniMax/MiniMaxAuthenticationTests.swift b/Tests/ConduitTests/Providers/MiniMax/MiniMaxAuthenticationTests.swift new file mode 100644 index 0000000..3ea25e4 --- /dev/null +++ b/Tests/ConduitTests/Providers/MiniMax/MiniMaxAuthenticationTests.swift @@ -0,0 +1,245 @@ +// MiniMaxAuthenticationTests.swift +// ConduitTests +// +// Unit tests for MiniMax authentication. + +#if CONDUIT_TRAIT_MINIMAX && CONDUIT_TRAIT_OPENAI +import Testing +import Foundation +@testable import Conduit + +@Suite("MiniMax Authentication Tests") +struct MiniMaxAuthenticationTests { + + // MARK: - AuthType Cases + + @Test("AuthType apiKey case holds value") + func authTypeApiKey() { + let authType = MiniMaxAuthentication.AuthType.apiKey("my-secret-key") + if case .apiKey(let key) = authType { + #expect(key == "my-secret-key") + } else { + Issue.record("Expected .apiKey case") + } + } + + @Test("AuthType auto case") + func authTypeAuto() { + let authType = MiniMaxAuthentication.AuthType.auto + if case .auto = authType { + // pass + } else { + Issue.record("Expected .auto case") + } + } + + // MARK: - Static Factory Methods + + @Test("apiKey factory creates correct authentication") + func apiKeyFactory() { + let auth = MiniMaxAuthentication.apiKey("factory-key") + #expect(auth.apiKey == "factory-key") + } + + @Test("auto static property creates auto authentication") + func autoFactory() { + let auth = MiniMaxAuthentication.auto + if case .auto = auth.type { + // pass + } else { + Issue.record("Expected .auto type") + } + } + + // MARK: - Init + + @Test("Init with apiKey type") + func initWithApiKeyType() { + let auth = MiniMaxAuthentication(type: .apiKey("init-key")) + #expect(auth.apiKey == "init-key") + } + + @Test("Init with auto type") + func initWithAutoType() { + let auth = MiniMaxAuthentication(type: .auto) + if case .auto = auth.type { + // pass + } else { + Issue.record("Expected .auto type") + } + } + + // MARK: - apiKey Computed Property + + @Test("apiKey returns key for apiKey type") + func apiKeyComputedPropertyWithKey() { + let auth = MiniMaxAuthentication.apiKey("computed-key") + #expect(auth.apiKey == "computed-key") + } + + @Test("apiKey returns environment variable for auto type") + func apiKeyComputedPropertyWithAuto() { + // When the env var is not set, apiKey returns nil + // We can't reliably test the env-var-set case without modifying the environment, + // but we can verify the property exists and returns something (possibly nil). + let auth = MiniMaxAuthentication.auto + // The actual value depends on the environment; just verify no crash. + _ = auth.apiKey + } + + @Test("apiKey with empty string returns empty string") + func apiKeyEmptyString() { + let auth = MiniMaxAuthentication.apiKey("") + #expect(auth.apiKey == "") + } + + // MARK: - isValid + + @Test("isValid is true for non-empty API key") + func isValidWithKey() { + let auth = MiniMaxAuthentication.apiKey("valid-key") + #expect(auth.isValid == true) + } + + @Test("isValid is false for empty API key") + func isValidWithEmptyKey() { + let auth = MiniMaxAuthentication.apiKey("") + #expect(auth.isValid == false) + } + + // MARK: - Equatable + + @Test("Same apiKey authentications are equal") + func equatableSameKey() { + let auth1 = MiniMaxAuthentication.apiKey("same-key") + let auth2 = MiniMaxAuthentication.apiKey("same-key") + #expect(auth1 == auth2) + } + + @Test("Different apiKey authentications are not equal") + func equatableDifferentKey() { + let auth1 = MiniMaxAuthentication.apiKey("key-1") + let auth2 = MiniMaxAuthentication.apiKey("key-2") + #expect(auth1 != auth2) + } + + @Test("Auto authentications are equal") + func equatableAuto() { + let auth1 = MiniMaxAuthentication.auto + let auth2 = MiniMaxAuthentication.auto + #expect(auth1 == auth2) + } + + @Test("apiKey and auto are not equal") + func equatableApiKeyVsAuto() { + let apiKeyAuth = MiniMaxAuthentication.apiKey("key") + let autoAuth = MiniMaxAuthentication.auto + #expect(apiKeyAuth != autoAuth) + } + + // MARK: - Hashable + + @Test("Same authentications have same hash") + func hashableSame() { + let auth1 = MiniMaxAuthentication.apiKey("hash-key") + let auth2 = MiniMaxAuthentication.apiKey("hash-key") + #expect(auth1.hashValue == auth2.hashValue) + } + + @Test("Can be used in a Set") + func hashableInSet() { + var authSet: Set = [] + authSet.insert(.apiKey("key1")) + authSet.insert(.apiKey("key2")) + authSet.insert(.auto) + authSet.insert(.auto) // duplicate + + #expect(authSet.count == 3) + } + + // MARK: - CustomDebugStringConvertible + + @Test("Debug description for apiKey hides the key") + func debugDescriptionApiKey() { + let auth = MiniMaxAuthentication.apiKey("super-secret") + #expect(auth.debugDescription == "MiniMaxAuthentication.apiKey(***)") + #expect(!auth.debugDescription.contains("super-secret")) + } + + @Test("Debug description for auto shows auto") + func debugDescriptionAuto() { + let auth = MiniMaxAuthentication.auto + #expect(auth.debugDescription == "MiniMaxAuthentication.auto") + } + + // MARK: - Codable Round-Trip + + @Test("Codable round-trip for apiKey authentication") + func codableRoundTripApiKey() throws { + let original = MiniMaxAuthentication.apiKey("codable-key") + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(MiniMaxAuthentication.self, from: data) + #expect(original == decoded) + #expect(decoded.apiKey == "codable-key") + } + + @Test("Codable round-trip for auto authentication") + func codableRoundTripAuto() throws { + let original = MiniMaxAuthentication.auto + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(MiniMaxAuthentication.self, from: data) + #expect(original == decoded) + } + + @Test("Codable round-trip preserves AuthType apiKey") + func codableRoundTripAuthType() throws { + let original = MiniMaxAuthentication.AuthType.apiKey("type-key") + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(MiniMaxAuthentication.AuthType.self, from: data) + if case .apiKey(let key) = decoded { + #expect(key == "type-key") + } else { + Issue.record("Expected .apiKey case after decoding") + } + } + + @Test("Codable round-trip preserves AuthType auto") + func codableRoundTripAuthTypeAuto() throws { + let original = MiniMaxAuthentication.AuthType.auto + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(MiniMaxAuthentication.AuthType.self, from: data) + if case .auto = decoded { + // pass + } else { + Issue.record("Expected .auto case after decoding") + } + } + + // MARK: - Sendable + + @Test("Authentication is Sendable across tasks") + func sendableConformance() async { + let auth = MiniMaxAuthentication.apiKey("sendable-key") + let result = await Task { auth.apiKey }.value + #expect(result == "sendable-key") + } + + // MARK: - Edge Cases + + @Test("API key with whitespace is preserved") + func apiKeyWithWhitespace() { + let auth = MiniMaxAuthentication.apiKey(" key with spaces ") + #expect(auth.apiKey == " key with spaces ") + #expect(auth.isValid == true) + } + + @Test("Very long API key is preserved") + func veryLongApiKey() { + let longKey = String(repeating: "a", count: 1000) + let auth = MiniMaxAuthentication.apiKey(longKey) + #expect(auth.apiKey == longKey) + #expect(auth.isValid == true) + } +} + +#endif // CONDUIT_TRAIT_MINIMAX && CONDUIT_TRAIT_OPENAI diff --git a/Tests/ConduitTests/Providers/MiniMax/MiniMaxConfigurationTests.swift b/Tests/ConduitTests/Providers/MiniMax/MiniMaxConfigurationTests.swift new file mode 100644 index 0000000..6407fdc --- /dev/null +++ b/Tests/ConduitTests/Providers/MiniMax/MiniMaxConfigurationTests.swift @@ -0,0 +1,262 @@ +// MiniMaxConfigurationTests.swift +// ConduitTests +// +// Unit tests for MiniMax configuration. + +#if CONDUIT_TRAIT_MINIMAX && CONDUIT_TRAIT_OPENAI +import Testing +import Foundation +@testable import Conduit + +@Suite("MiniMax Configuration Tests") +struct MiniMaxConfigurationTests { + + // MARK: - Default Initialization + + @Test("Default init uses expected values") + func defaultInit() { + let config = MiniMaxConfiguration() + #expect(config.authentication.type == .auto) + #expect(config.baseURL == URL(string: "https://minimax-m2.com/api/v1")!) + #expect(config.timeout == 120.0) + #expect(config.maxRetries == 3) + } + + @Test("Default authentication is auto") + func defaultAuthenticationIsAuto() { + let config = MiniMaxConfiguration() + #expect(config.authentication == .auto) + } + + // MARK: - Custom Initialization + + @Test("Init with custom authentication") + func customAuthentication() { + let config = MiniMaxConfiguration(authentication: .apiKey("test-key")) + #expect(config.authentication.apiKey == "test-key") + } + + @Test("Init with custom base URL") + func customBaseURL() { + let url = URL(string: "https://custom.example.com/api")! + let config = MiniMaxConfiguration(baseURL: url) + #expect(config.baseURL == url) + } + + @Test("Init with custom timeout") + func customTimeout() { + let config = MiniMaxConfiguration(timeout: 30.0) + #expect(config.timeout == 30.0) + } + + @Test("Init with custom max retries") + func customMaxRetries() { + let config = MiniMaxConfiguration(maxRetries: 5) + #expect(config.maxRetries == 5) + } + + @Test("Init with all custom values") + func allCustomValues() { + let url = URL(string: "https://custom.example.com")! + let config = MiniMaxConfiguration( + authentication: .apiKey("my-key"), + baseURL: url, + timeout: 60.0, + maxRetries: 10 + ) + #expect(config.authentication.apiKey == "my-key") + #expect(config.baseURL == url) + #expect(config.timeout == 60.0) + #expect(config.maxRetries == 10) + } + + // MARK: - Static Factory Methods + + @Test("standard(apiKey:) sets API key authentication") + func standardFactoryMethod() { + let config = MiniMaxConfiguration.standard(apiKey: "sk-test-123") + #expect(config.authentication.apiKey == "sk-test-123") + } + + @Test("standard(apiKey:) uses default base URL") + func standardFactoryMethodBaseURL() { + let config = MiniMaxConfiguration.standard(apiKey: "key") + #expect(config.baseURL == URL(string: "https://minimax-m2.com/api/v1")!) + } + + @Test("standard(apiKey:) uses default timeout and retries") + func standardFactoryMethodDefaults() { + let config = MiniMaxConfiguration.standard(apiKey: "key") + #expect(config.timeout == 120.0) + #expect(config.maxRetries == 3) + } + + // MARK: - hasValidAuthentication + + @Test("hasValidAuthentication is true with API key") + func hasValidAuthWithKey() { + let config = MiniMaxConfiguration(authentication: .apiKey("valid-key")) + #expect(config.hasValidAuthentication == true) + } + + @Test("hasValidAuthentication is false with empty API key") + func hasValidAuthEmptyKey() { + let config = MiniMaxConfiguration(authentication: .apiKey("")) + #expect(config.hasValidAuthentication == false) + } + + // MARK: - Fluent API + + @Test("apiKey fluent method sets authentication") + func fluentApiKey() { + let config = MiniMaxConfiguration().apiKey("fluent-key") + #expect(config.authentication.apiKey == "fluent-key") + } + + @Test("apiKey fluent method returns new instance") + func fluentApiKeyNewInstance() { + let original = MiniMaxConfiguration() + let modified = original.apiKey("new-key") + // Original unchanged + #expect(original.authentication == .auto) + #expect(modified.authentication.apiKey == "new-key") + } + + @Test("timeout fluent method sets timeout") + func fluentTimeout() { + let config = MiniMaxConfiguration().timeout(45.0) + #expect(config.timeout == 45.0) + } + + @Test("timeout fluent method clamps negative to zero") + func fluentTimeoutClampsNegative() { + let config = MiniMaxConfiguration().timeout(-10.0) + #expect(config.timeout == 0.0) + } + + @Test("timeout fluent method allows zero") + func fluentTimeoutAllowsZero() { + let config = MiniMaxConfiguration().timeout(0.0) + #expect(config.timeout == 0.0) + } + + @Test("timeout fluent method returns new instance") + func fluentTimeoutNewInstance() { + let original = MiniMaxConfiguration() + let modified = original.timeout(30.0) + #expect(original.timeout == 120.0) + #expect(modified.timeout == 30.0) + } + + @Test("maxRetries fluent method sets max retries") + func fluentMaxRetries() { + let config = MiniMaxConfiguration().maxRetries(7) + #expect(config.maxRetries == 7) + } + + @Test("maxRetries fluent method clamps negative to zero") + func fluentMaxRetriesClampsNegative() { + let config = MiniMaxConfiguration().maxRetries(-3) + #expect(config.maxRetries == 0) + } + + @Test("maxRetries fluent method allows zero") + func fluentMaxRetriesAllowsZero() { + let config = MiniMaxConfiguration().maxRetries(0) + #expect(config.maxRetries == 0) + } + + @Test("maxRetries fluent method returns new instance") + func fluentMaxRetriesNewInstance() { + let original = MiniMaxConfiguration() + let modified = original.maxRetries(1) + #expect(original.maxRetries == 3) + #expect(modified.maxRetries == 1) + } + + @Test("Fluent methods can be chained") + func fluentChaining() { + let config = MiniMaxConfiguration() + .apiKey("chain-key") + .timeout(90.0) + .maxRetries(2) + + #expect(config.authentication.apiKey == "chain-key") + #expect(config.timeout == 90.0) + #expect(config.maxRetries == 2) + } + + // MARK: - Hashable + + @Test("Equal configurations have same hash") + func hashableEquality() { + let config1 = MiniMaxConfiguration.standard(apiKey: "key") + let config2 = MiniMaxConfiguration.standard(apiKey: "key") + #expect(config1.hashValue == config2.hashValue) + } + + @Test("Can be used in a Set") + func hashableInSet() { + let config1 = MiniMaxConfiguration.standard(apiKey: "key1") + let config2 = MiniMaxConfiguration.standard(apiKey: "key2") + var configSet: Set = [] + configSet.insert(config1) + configSet.insert(config2) + #expect(configSet.count == 2) + } + + // MARK: - Codable Round-Trip + + @Test("Codable round-trip with default config") + func codableRoundTripDefault() throws { + let original = MiniMaxConfiguration() + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(MiniMaxConfiguration.self, from: data) + #expect(original == decoded) + } + + @Test("Codable round-trip with API key") + func codableRoundTripWithApiKey() throws { + let original = MiniMaxConfiguration.standard(apiKey: "test-key-123") + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(MiniMaxConfiguration.self, from: data) + #expect(original == decoded) + #expect(decoded.authentication.apiKey == "test-key-123") + } + + @Test("Codable round-trip with custom values") + func codableRoundTripCustom() throws { + let original = MiniMaxConfiguration( + authentication: .apiKey("custom"), + baseURL: URL(string: "https://example.com")!, + timeout: 42.0, + maxRetries: 7 + ) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(MiniMaxConfiguration.self, from: data) + #expect(original == decoded) + #expect(decoded.baseURL == URL(string: "https://example.com")!) + #expect(decoded.timeout == 42.0) + #expect(decoded.maxRetries == 7) + } + + @Test("Codable round-trip preserves base URL") + func codableRoundTripBaseURL() throws { + let url = URL(string: "https://custom-minimax.example.com/v2")! + let original = MiniMaxConfiguration(baseURL: url) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(MiniMaxConfiguration.self, from: data) + #expect(decoded.baseURL == url) + } + + // MARK: - Sendable + + @Test("Configuration is Sendable across tasks") + func sendableConformance() async { + let config = MiniMaxConfiguration.standard(apiKey: "sendable-key") + let result = await Task { config.authentication.apiKey }.value + #expect(result == "sendable-key") + } +} + +#endif // CONDUIT_TRAIT_MINIMAX && CONDUIT_TRAIT_OPENAI diff --git a/Tests/ConduitTests/Providers/MiniMax/MiniMaxModelIDTests.swift b/Tests/ConduitTests/Providers/MiniMax/MiniMaxModelIDTests.swift new file mode 100644 index 0000000..64eea4e --- /dev/null +++ b/Tests/ConduitTests/Providers/MiniMax/MiniMaxModelIDTests.swift @@ -0,0 +1,233 @@ +// MiniMaxModelIDTests.swift +// ConduitTests +// +// Unit tests for MiniMax model identifiers. + +#if CONDUIT_TRAIT_MINIMAX && CONDUIT_TRAIT_OPENAI +import Testing +import Foundation +@testable import Conduit + +@Suite("MiniMax Model ID Tests") +struct MiniMaxModelIDTests { + + // MARK: - Static Model Constants + + @Test("minimaxM2 has correct raw value") + func minimaxM2RawValue() { + #expect(MiniMaxModelID.minimaxM2.rawValue == "MiniMax-M2") + } + + @Test("minimaxM2_1 has correct raw value") + func minimaxM2_1RawValue() { + #expect(MiniMaxModelID.minimaxM2_1.rawValue == "MiniMax-M2.1") + } + + @Test("minimaxM2_5 has correct raw value") + func minimaxM2_5RawValue() { + #expect(MiniMaxModelID.minimaxM2_5.rawValue == "MiniMax-M2.5") + } + + // MARK: - Initialization + + @Test("Init with raw value string") + func initWithRawValue() { + let model = MiniMaxModelID(rawValue: "custom-model") + #expect(model.rawValue == "custom-model") + } + + @Test("Init with positional string") + func initWithPositionalString() { + let model = MiniMaxModelID("my-model") + #expect(model.rawValue == "my-model") + } + + // MARK: - ModelIdentifying Conformance + + @Test("Provider is minimax for all static models") + func providerIsMiniMax() { + #expect(MiniMaxModelID.minimaxM2.provider == .minimax) + #expect(MiniMaxModelID.minimaxM2_1.provider == .minimax) + #expect(MiniMaxModelID.minimaxM2_5.provider == .minimax) + } + + @Test("Provider is minimax for custom model") + func providerIsMiniMaxCustom() { + let model = MiniMaxModelID("anything") + #expect(model.provider == .minimax) + } + + @Test("Display name equals raw value") + func displayNameEqualsRawValue() { + let model = MiniMaxModelID("test-model") + #expect(model.displayName == "test-model") + #expect(model.displayName == model.rawValue) + } + + @Test("Display name for static models") + func displayNameStaticModels() { + #expect(MiniMaxModelID.minimaxM2.displayName == "MiniMax-M2") + #expect(MiniMaxModelID.minimaxM2_1.displayName == "MiniMax-M2.1") + #expect(MiniMaxModelID.minimaxM2_5.displayName == "MiniMax-M2.5") + } + + // MARK: - CustomStringConvertible + + @Test("Description includes provider prefix and raw value") + func descriptionFormat() { + let model = MiniMaxModelID("MiniMax-M2") + #expect(model.description == "[MiniMax] MiniMax-M2") + } + + @Test("Description for custom model") + func descriptionCustomModel() { + let model = MiniMaxModelID("my-fine-tuned-model") + #expect(model.description == "[MiniMax] my-fine-tuned-model") + } + + // MARK: - ExpressibleByStringLiteral + + @Test("String literal initialization") + func stringLiteralInit() { + let model: MiniMaxModelID = "literal-model" + #expect(model.rawValue == "literal-model") + } + + @Test("String literal produces same result as explicit init") + func stringLiteralEquivalence() { + let fromLiteral: MiniMaxModelID = "MiniMax-M2" + let fromInit = MiniMaxModelID("MiniMax-M2") + #expect(fromLiteral == fromInit) + } + + // MARK: - Hashable + + @Test("Equal models have same hash") + func hashableEquality() { + let model1 = MiniMaxModelID("MiniMax-M2") + let model2 = MiniMaxModelID("MiniMax-M2") + #expect(model1.hashValue == model2.hashValue) + } + + @Test("Can be used in a Set") + func hashableInSet() { + var modelSet: Set = [] + modelSet.insert(.minimaxM2) + modelSet.insert(.minimaxM2_1) + modelSet.insert(.minimaxM2_5) + modelSet.insert(.minimaxM2) // duplicate + + #expect(modelSet.count == 3) + #expect(modelSet.contains(.minimaxM2)) + #expect(modelSet.contains(.minimaxM2_1)) + #expect(modelSet.contains(.minimaxM2_5)) + } + + @Test("Can be used as Dictionary key") + func hashableAsDictionaryKey() { + var dict: [MiniMaxModelID: String] = [:] + dict[.minimaxM2] = "M2" + dict[.minimaxM2_1] = "M2.1" + + #expect(dict[.minimaxM2] == "M2") + #expect(dict[.minimaxM2_1] == "M2.1") + #expect(dict[.minimaxM2_5] == nil) + } + + // MARK: - Equatable + + @Test("Same raw values are equal") + func equatable() { + let a = MiniMaxModelID("MiniMax-M2") + let b = MiniMaxModelID("MiniMax-M2") + #expect(a == b) + } + + @Test("Different raw values are not equal") + func notEquatable() { + let a = MiniMaxModelID("MiniMax-M2") + let b = MiniMaxModelID("MiniMax-M2.1") + #expect(a != b) + } + + @Test("Static constant equals manually created instance with same value") + func staticEqualsManual() { + let manual = MiniMaxModelID("MiniMax-M2") + #expect(MiniMaxModelID.minimaxM2 == manual) + } + + // MARK: - Codable Round-Trip + + @Test("Codable round-trip for static model") + func codableRoundTripStatic() throws { + let original = MiniMaxModelID.minimaxM2 + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(MiniMaxModelID.self, from: data) + #expect(original == decoded) + #expect(decoded.rawValue == "MiniMax-M2") + } + + @Test("Codable round-trip for custom model") + func codableRoundTripCustom() throws { + let original = MiniMaxModelID("my-custom-model") + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(MiniMaxModelID.self, from: data) + #expect(original == decoded) + #expect(decoded.rawValue == "my-custom-model") + } + + @Test("Codable round-trip preserves all static models") + func codableRoundTripAllStatic() throws { + let models: [MiniMaxModelID] = [.minimaxM2, .minimaxM2_1, .minimaxM2_5] + let encoder = JSONEncoder() + let decoder = JSONDecoder() + + for model in models { + let data = try encoder.encode(model) + let decoded = try decoder.decode(MiniMaxModelID.self, from: data) + #expect(model == decoded, "Round-trip failed for \(model.rawValue)") + } + } + + @Test("Encodes as single string value") + func encodesAsSingleValue() throws { + let model = MiniMaxModelID.minimaxM2 + let data = try JSONEncoder().encode(model) + let jsonString = String(data: data, encoding: .utf8)! + #expect(jsonString == "\"MiniMax-M2\"") + } + + @Test("Decodes from bare string") + func decodesFromBareString() throws { + let json = "\"MiniMax-M2.5\"".data(using: .utf8)! + let decoded = try JSONDecoder().decode(MiniMaxModelID.self, from: json) + #expect(decoded.rawValue == "MiniMax-M2.5") + #expect(decoded == .minimaxM2_5) + } + + @Test("Codable round-trip in array context") + func codableRoundTripArray() throws { + let models: [MiniMaxModelID] = [.minimaxM2, .minimaxM2_1, .minimaxM2_5] + let data = try JSONEncoder().encode(models) + let decoded = try JSONDecoder().decode([MiniMaxModelID].self, from: data) + #expect(decoded == models) + } + + // MARK: - Edge Cases + + @Test("Empty string model ID") + func emptyStringModelID() { + let model = MiniMaxModelID("") + #expect(model.rawValue == "") + #expect(model.displayName == "") + #expect(model.description == "[MiniMax] ") + } + + @Test("Model ID with special characters") + func specialCharactersModelID() { + let model = MiniMaxModelID("model/v2.0-beta+rc1") + #expect(model.rawValue == "model/v2.0-beta+rc1") + } +} + +#endif // CONDUIT_TRAIT_MINIMAX && CONDUIT_TRAIT_OPENAI diff --git a/Tests/ConduitTests/Providers/OpenAI/AzureConfigurationTests.swift b/Tests/ConduitTests/Providers/OpenAI/AzureConfigurationTests.swift new file mode 100644 index 0000000..277463d --- /dev/null +++ b/Tests/ConduitTests/Providers/OpenAI/AzureConfigurationTests.swift @@ -0,0 +1,222 @@ +// AzureConfigurationTests.swift +// Conduit Tests +// +// Tests for AzureConfiguration, ContentFilteringMode, and related types. + +#if CONDUIT_TRAIT_OPENAI || CONDUIT_TRAIT_OPENROUTER +import Foundation +import Testing +@testable import Conduit + +@Suite("AzureConfiguration Tests") +struct AzureConfigurationTests { + + // MARK: - Initialization + + @Test("Default init has expected values") + func defaultInit() { + let config = AzureConfiguration(resource: "my-resource", deployment: "gpt4-deploy") + #expect(config.resource == "my-resource") + #expect(config.deployment == "gpt4-deploy") + #expect(config.apiVersion == "2024-02-15-preview") + #expect(config.contentFiltering == .default) + #expect(config.enableStreaming == true) + #expect(config.region == nil) + } + + @Test("Custom init preserves all values") + func customInit() { + let config = AzureConfiguration( + resource: "res", + deployment: "dep", + apiVersion: "2024-02-01", + contentFiltering: .strict, + enableStreaming: false, + region: "eastus" + ) + #expect(config.resource == "res") + #expect(config.deployment == "dep") + #expect(config.apiVersion == "2024-02-01") + #expect(config.contentFiltering == .strict) + #expect(config.enableStreaming == false) + #expect(config.region == "eastus") + } + + // MARK: - URL Generation + + @Test("baseURL includes resource name") + func baseURLIncludesResource() { + let config = AzureConfiguration(resource: "my-company-openai", deployment: "dep") + #expect(config.baseURL.absoluteString == "https://my-company-openai.openai.azure.com/openai") + } + + @Test("chatCompletionsURL includes deployment and api-version") + func chatCompletionsURL() { + let config = AzureConfiguration( + resource: "res", + deployment: "gpt4-deploy", + apiVersion: "2024-02-15-preview" + ) + let url = config.chatCompletionsURL + #expect(url.absoluteString.contains("deployments/gpt4-deploy/chat/completions")) + #expect(url.absoluteString.contains("api-version=2024-02-15-preview")) + } + + @Test("embeddingsURL includes deployment and api-version") + func embeddingsURL() { + let config = AzureConfiguration( + resource: "res", + deployment: "embed-deploy", + apiVersion: "2024-02-01" + ) + let url = config.embeddingsURL + #expect(url.absoluteString.contains("deployments/embed-deploy/embeddings")) + #expect(url.absoluteString.contains("api-version=2024-02-01")) + } + + @Test("imagesGenerationsURL includes deployment and api-version") + func imagesGenerationsURL() { + let config = AzureConfiguration( + resource: "res", + deployment: "dalle-deploy", + apiVersion: "2024-02-15-preview" + ) + let url = config.imagesGenerationsURL + #expect(url.absoluteString.contains("deployments/dalle-deploy/images/generations")) + #expect(url.absoluteString.contains("api-version=2024-02-15-preview")) + } + + // MARK: - Fluent API + + @Test("Fluent apiVersion returns updated copy") + func fluentApiVersion() { + let config = AzureConfiguration(resource: "res", deployment: "dep") + .apiVersion("2024-02-01") + #expect(config.apiVersion == "2024-02-01") + } + + @Test("Fluent contentFiltering returns updated copy") + func fluentContentFiltering() { + let config = AzureConfiguration(resource: "res", deployment: "dep") + .contentFiltering(.strict) + #expect(config.contentFiltering == .strict) + } + + @Test("Fluent withStrictFiltering sets strict mode") + func fluentWithStrictFiltering() { + let config = AzureConfiguration(resource: "res", deployment: "dep") + .withStrictFiltering() + #expect(config.contentFiltering == .strict) + } + + @Test("Fluent streaming returns updated copy") + func fluentStreaming() { + let config = AzureConfiguration(resource: "res", deployment: "dep") + .streaming(false) + #expect(config.enableStreaming == false) + } + + @Test("Fluent region returns updated copy") + func fluentRegion() { + let config = AzureConfiguration(resource: "res", deployment: "dep") + .region("westus2") + #expect(config.region == "westus2") + } + + // MARK: - API Versions + + @Test("Known API versions have expected values") + func knownApiVersions() { + #expect(AzureConfiguration.APIVersion.latestStable == "2024-02-15-preview") + #expect(AzureConfiguration.APIVersion.ga2024 == "2024-02-01") + #expect(AzureConfiguration.APIVersion.legacy == "2023-05-15") + } + + // MARK: - Codable + + @Test("AzureConfiguration round-trips through JSON") + func codableRoundTrip() throws { + let original = AzureConfiguration( + resource: "my-resource", + deployment: "gpt4", + apiVersion: "2024-02-15-preview", + contentFiltering: .strict, + enableStreaming: false, + region: "eastus" + ) + + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(AzureConfiguration.self, from: data) + + #expect(decoded.resource == original.resource) + #expect(decoded.deployment == original.deployment) + #expect(decoded.apiVersion == original.apiVersion) + #expect(decoded.contentFiltering == original.contentFiltering) + #expect(decoded.enableStreaming == original.enableStreaming) + #expect(decoded.region == original.region) + } + + @Test("AzureConfiguration with nil region round-trips") + func codableRoundTripNilRegion() throws { + let original = AzureConfiguration(resource: "res", deployment: "dep") + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(AzureConfiguration.self, from: data) + #expect(decoded.region == nil) + } + + // MARK: - Hashable / Equatable + + @Test("Equal configurations are equal") + func equalConfigurationsEqual() { + let a = AzureConfiguration(resource: "res", deployment: "dep") + let b = AzureConfiguration(resource: "res", deployment: "dep") + #expect(a == b) + } + + @Test("Different configurations are not equal") + func differentConfigurationsNotEqual() { + let a = AzureConfiguration(resource: "res1", deployment: "dep") + let b = AzureConfiguration(resource: "res2", deployment: "dep") + #expect(a != b) + } + + @Test("Sendable conformance compiles") + func sendableConformance() { + let config: Sendable = AzureConfiguration(resource: "r", deployment: "d") + #expect(config is AzureConfiguration) + } +} + +// MARK: - ContentFilteringMode Tests + +@Suite("ContentFilteringMode Tests") +struct ContentFilteringModeTests { + + @Test("All cases have correct raw values") + func rawValues() { + #expect(ContentFilteringMode.default.rawValue == "default") + #expect(ContentFilteringMode.strict.rawValue == "strict") + #expect(ContentFilteringMode.reduced.rawValue == "reduced") + #expect(ContentFilteringMode.none.rawValue == "none") + } + + @Test("Description returns human-readable string") + func descriptions() { + #expect(ContentFilteringMode.default.description == "Default filtering") + #expect(ContentFilteringMode.strict.description == "Strict filtering") + #expect(ContentFilteringMode.reduced.description == "Reduced filtering") + #expect(ContentFilteringMode.none.description == "No filtering") + } + + @Test("Codable round-trip for all cases") + func codableRoundTrip() throws { + let cases: [ContentFilteringMode] = [.default, .strict, .reduced, .none] + for mode in cases { + let data = try JSONEncoder().encode(mode) + let decoded = try JSONDecoder().decode(ContentFilteringMode.self, from: data) + #expect(decoded == mode) + } + } +} + +#endif // CONDUIT_TRAIT_OPENAI || CONDUIT_TRAIT_OPENROUTER diff --git a/Tests/ConduitTests/Providers/OpenAI/OllamaConfigurationTests.swift b/Tests/ConduitTests/Providers/OpenAI/OllamaConfigurationTests.swift new file mode 100644 index 0000000..96b2b1c --- /dev/null +++ b/Tests/ConduitTests/Providers/OpenAI/OllamaConfigurationTests.swift @@ -0,0 +1,412 @@ +// OllamaConfigurationTests.swift +// Conduit Tests +// +// Tests for OllamaConfiguration, OllamaModelStatus, and OllamaServerStatus. + +#if CONDUIT_TRAIT_OPENAI || CONDUIT_TRAIT_OPENROUTER +import Foundation +import Testing +@testable import Conduit + +@Suite("OllamaConfiguration Tests") +struct OllamaConfigurationTests { + + // MARK: - Default Initialization + + @Test("Default init has expected values") + func defaultInit() { + let config = OllamaConfiguration() + #expect(config.keepAlive == nil) + #expect(config.pullOnMissing == false) + #expect(config.numParallel == nil) + #expect(config.numGPU == nil) + #expect(config.mainGPU == nil) + #expect(config.lowVRAM == false) + #expect(config.numCtx == nil) + #expect(config.healthCheck == true) + #expect(config.healthCheckTimeout == 5.0) + } + + // MARK: - Custom Initialization + + @Test("Custom init preserves all values") + func customInit() { + let config = OllamaConfiguration( + keepAlive: "10m", + pullOnMissing: true, + numParallel: 4, + numGPU: 32, + mainGPU: 0, + lowVRAM: true, + numCtx: 4096, + healthCheck: false, + healthCheckTimeout: 10.0 + ) + #expect(config.keepAlive == "10m") + #expect(config.pullOnMissing == true) + #expect(config.numParallel == 4) + #expect(config.numGPU == 32) + #expect(config.mainGPU == 0) + #expect(config.lowVRAM == true) + #expect(config.numCtx == 4096) + #expect(config.healthCheck == false) + #expect(config.healthCheckTimeout == 10.0) + } + + // MARK: - Static Presets + + @Test("Default preset matches default init") + func defaultPreset() { + let config = OllamaConfiguration.default + #expect(config.keepAlive == nil) + #expect(config.pullOnMissing == false) + #expect(config.healthCheck == true) + } + + @Test("lowMemory preset has correct values") + func lowMemoryPreset() { + let config = OllamaConfiguration.lowMemory + #expect(config.keepAlive == "1m") + #expect(config.lowVRAM == true) + } + + @Test("interactive preset has longer keep-alive") + func interactivePreset() { + let config = OllamaConfiguration.interactive + #expect(config.keepAlive == "30m") + #expect(config.healthCheck == true) + } + + @Test("batch preset unloads immediately and skips health check") + func batchPreset() { + let config = OllamaConfiguration.batch + #expect(config.keepAlive == "0") + #expect(config.healthCheck == false) + } + + @Test("alwaysOn preset keeps models loaded indefinitely") + func alwaysOnPreset() { + let config = OllamaConfiguration.alwaysOn + #expect(config.keepAlive == "-1") + #expect(config.healthCheck == true) + } + + // MARK: - Options Generation + + @Test("options returns empty dict when no GPU settings") + func optionsEmpty() { + let config = OllamaConfiguration() + let opts = config.options() + #expect(opts.isEmpty) + } + + @Test("options includes numGPU when set") + func optionsIncludesNumGPU() { + let config = OllamaConfiguration(numGPU: 16) + let opts = config.options() + #expect(opts["num_gpu"] as? Int == 16) + } + + @Test("options includes mainGPU when set") + func optionsIncludesMainGPU() { + let config = OllamaConfiguration(mainGPU: 1) + let opts = config.options() + #expect(opts["main_gpu"] as? Int == 1) + } + + @Test("options includes lowVRAM when true") + func optionsIncludesLowVRAM() { + let config = OllamaConfiguration(lowVRAM: true) + let opts = config.options() + #expect(opts["low_vram"] as? Bool == true) + } + + @Test("options does not include lowVRAM when false") + func optionsOmitsLowVRAMWhenFalse() { + let config = OllamaConfiguration(lowVRAM: false) + let opts = config.options() + #expect(opts["low_vram"] == nil) + } + + @Test("options includes numCtx when set") + func optionsIncludesNumCtx() { + let config = OllamaConfiguration(numCtx: 8192) + let opts = config.options() + #expect(opts["num_ctx"] as? Int == 8192) + } + + @Test("options includes all GPU settings when all are set") + func optionsAllGPUSettings() { + let config = OllamaConfiguration( + numGPU: 32, + mainGPU: 0, + lowVRAM: true, + numCtx: 4096 + ) + let opts = config.options() + #expect(opts["num_gpu"] as? Int == 32) + #expect(opts["main_gpu"] as? Int == 0) + #expect(opts["low_vram"] as? Bool == true) + #expect(opts["num_ctx"] as? Int == 4096) + } + + // MARK: - Fluent API + + @Test("Fluent keepAlive returns updated copy") + func fluentKeepAlive() { + let config = OllamaConfiguration.default.keepAlive("15m") + #expect(config.keepAlive == "15m") + } + + @Test("Fluent pullOnMissing returns updated copy") + func fluentPullOnMissing() { + let config = OllamaConfiguration.default.pullOnMissing(true) + #expect(config.pullOnMissing == true) + } + + @Test("Fluent numParallel returns updated copy") + func fluentNumParallel() { + let config = OllamaConfiguration.default.numParallel(8) + #expect(config.numParallel == 8) + } + + @Test("Fluent numGPU returns updated copy") + func fluentNumGPU() { + let config = OllamaConfiguration.default.numGPU(24) + #expect(config.numGPU == 24) + } + + @Test("Fluent cpuOnly sets numGPU to 0") + func fluentCpuOnly() { + let config = OllamaConfiguration.default.cpuOnly() + #expect(config.numGPU == 0) + } + + @Test("Fluent lowVRAM returns updated copy") + func fluentLowVRAM() { + let config = OllamaConfiguration.default.lowVRAM(true) + #expect(config.lowVRAM == true) + } + + @Test("Fluent contextSize returns updated copy") + func fluentContextSize() { + let config = OllamaConfiguration.default.contextSize(16384) + #expect(config.numCtx == 16384) + } + + @Test("Fluent healthCheck returns updated copy") + func fluentHealthCheck() { + let config = OllamaConfiguration.default.healthCheck(false) + #expect(config.healthCheck == false) + } + + @Test("Fluent API chaining works correctly") + func fluentChaining() { + let config = OllamaConfiguration.default + .keepAlive("10m") + .pullOnMissing(true) + .numGPU(16) + .contextSize(4096) + .healthCheck(false) + + #expect(config.keepAlive == "10m") + #expect(config.pullOnMissing == true) + #expect(config.numGPU == 16) + #expect(config.numCtx == 4096) + #expect(config.healthCheck == false) + } + + // MARK: - Codable + + @Test("OllamaConfiguration round-trips through JSON") + func codableRoundTrip() throws { + let original = OllamaConfiguration( + keepAlive: "5m", + pullOnMissing: true, + numParallel: 2, + numGPU: 16, + mainGPU: 0, + lowVRAM: true, + numCtx: 2048, + healthCheck: false, + healthCheckTimeout: 10.0 + ) + + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(OllamaConfiguration.self, from: data) + + #expect(decoded.keepAlive == original.keepAlive) + #expect(decoded.pullOnMissing == original.pullOnMissing) + #expect(decoded.numParallel == original.numParallel) + #expect(decoded.numGPU == original.numGPU) + #expect(decoded.mainGPU == original.mainGPU) + #expect(decoded.lowVRAM == original.lowVRAM) + #expect(decoded.numCtx == original.numCtx) + #expect(decoded.healthCheck == original.healthCheck) + #expect(decoded.healthCheckTimeout == original.healthCheckTimeout) + } + + @Test("Default OllamaConfiguration round-trips through JSON") + func codableDefaultRoundTrip() throws { + let original = OllamaConfiguration.default + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(OllamaConfiguration.self, from: data) + #expect(decoded == original) + } + + @Test("All presets round-trip through JSON") + func codablePresetsRoundTrip() throws { + let presets: [OllamaConfiguration] = [ + .default, .lowMemory, .interactive, .batch, .alwaysOn + ] + for original in presets { + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(OllamaConfiguration.self, from: data) + #expect(decoded == original) + } + } + + // MARK: - Hashable / Equatable + + @Test("Equal configurations are equal") + func equalConfigurations() { + let a = OllamaConfiguration.default + let b = OllamaConfiguration.default + #expect(a == b) + } + + @Test("Different configurations are not equal") + func differentConfigurations() { + let a = OllamaConfiguration.default + let b = OllamaConfiguration.lowMemory + #expect(a != b) + } + + @Test("Sendable conformance compiles") + func sendableConformance() { + let config: Sendable = OllamaConfiguration.default + #expect(config is OllamaConfiguration) + } +} + +// MARK: - OllamaModelStatus Tests + +@Suite("OllamaModelStatus Tests") +struct OllamaModelStatusTests { + + @Test("available status exists") + func availableStatus() { + let status = OllamaModelStatus.available + if case .available = status { + // pass + } else { + Issue.record("Expected .available") + } + } + + @Test("pulling status stores progress") + func pullingStatus() { + let status = OllamaModelStatus.pulling(progress: 0.75) + if case .pulling(let progress) = status { + #expect(progress == 0.75) + } else { + Issue.record("Expected .pulling") + } + } + + @Test("notAvailable status exists") + func notAvailableStatus() { + let status = OllamaModelStatus.notAvailable + if case .notAvailable = status { + // pass + } else { + Issue.record("Expected .notAvailable") + } + } + + @Test("unknown status exists") + func unknownStatus() { + let status = OllamaModelStatus.unknown + if case .unknown = status { + // pass + } else { + Issue.record("Expected .unknown") + } + } + + @Test("Same statuses are equal") + func sameStatusesEqual() { + #expect(OllamaModelStatus.available == .available) + #expect(OllamaModelStatus.notAvailable == .notAvailable) + #expect(OllamaModelStatus.unknown == .unknown) + #expect(OllamaModelStatus.pulling(progress: 0.5) == .pulling(progress: 0.5)) + } + + @Test("Different statuses are not equal") + func differentStatusesNotEqual() { + #expect(OllamaModelStatus.available != .notAvailable) + #expect(OllamaModelStatus.pulling(progress: 0.5) != .pulling(progress: 0.7)) + } +} + +// MARK: - OllamaServerStatus Tests + +@Suite("OllamaServerStatus Tests") +struct OllamaServerStatusTests { + + @Test("running status exists") + func runningStatus() { + let status = OllamaServerStatus.running + if case .running = status { + // pass + } else { + Issue.record("Expected .running") + } + } + + @Test("notResponding status exists") + func notRespondingStatus() { + let status = OllamaServerStatus.notResponding + if case .notResponding = status { + // pass + } else { + Issue.record("Expected .notResponding") + } + } + + @Test("error status stores message") + func errorStatus() { + let status = OllamaServerStatus.error("connection refused") + if case .error(let message) = status { + #expect(message == "connection refused") + } else { + Issue.record("Expected .error") + } + } + + @Test("unknown status exists") + func unknownStatus() { + let status = OllamaServerStatus.unknown + if case .unknown = status { + // pass + } else { + Issue.record("Expected .unknown") + } + } + + @Test("Same statuses are equal") + func sameStatusesEqual() { + #expect(OllamaServerStatus.running == .running) + #expect(OllamaServerStatus.notResponding == .notResponding) + #expect(OllamaServerStatus.unknown == .unknown) + #expect(OllamaServerStatus.error("msg") == .error("msg")) + } + + @Test("Different statuses are not equal") + func differentStatusesNotEqual() { + #expect(OllamaServerStatus.running != .notResponding) + #expect(OllamaServerStatus.error("a") != .error("b")) + } +} + +#endif // CONDUIT_TRAIT_OPENAI || CONDUIT_TRAIT_OPENROUTER diff --git a/Tests/ConduitTests/Providers/OpenAI/OpenAIAuthenticationTests.swift b/Tests/ConduitTests/Providers/OpenAI/OpenAIAuthenticationTests.swift new file mode 100644 index 0000000..da7e2b8 --- /dev/null +++ b/Tests/ConduitTests/Providers/OpenAI/OpenAIAuthenticationTests.swift @@ -0,0 +1,335 @@ +// OpenAIAuthenticationTests.swift +// Conduit Tests +// +// Tests for OpenAIAuthentication enum cases, resolution, headers, and security. + +#if CONDUIT_TRAIT_OPENAI || CONDUIT_TRAIT_OPENROUTER +import Foundation +import Testing +@testable import Conduit + +@Suite("OpenAIAuthentication Tests") +struct OpenAIAuthenticationTests { + + // MARK: - Enum Cases + + @Test("none case exists and resolves to nil") + func noneResolves() { + let auth = OpenAIAuthentication.none + #expect(auth.resolve() == nil) + } + + @Test("bearer case stores and resolves the token") + func bearerResolves() { + let auth = OpenAIAuthentication.bearer("sk-test123") + #expect(auth.resolve() == "sk-test123") + } + + @Test("apiKey case stores key and resolves it") + func apiKeyResolves() { + let auth = OpenAIAuthentication.apiKey("azure-key", headerName: "api-key") + #expect(auth.resolve() == "azure-key") + } + + @Test("apiKey case uses default headerName of api-key") + func apiKeyDefaultHeader() { + let auth = OpenAIAuthentication.apiKey("key123") + #expect(auth.headerName == "api-key") + } + + @Test("environment case resolves from environment variable") + func environmentResolves() { + // This test checks the mechanism; actual env value depends on environment + let auth = OpenAIAuthentication.environment("LIKELY_UNSET_VAR_FOR_TESTING_12345") + #expect(auth.resolve() == nil) + } + + @Test("auto case checks known env variables") + func autoChecksEnvVars() { + let auth = OpenAIAuthentication.auto + // Cannot guarantee env vars are set, but the mechanism should not crash + _ = auth.resolve() + } + + // MARK: - isConfigured + + @Test("none is considered configured") + func noneIsConfigured() { + #expect(OpenAIAuthentication.none.isConfigured) + } + + @Test("bearer with non-empty token is configured") + func bearerNonEmptyIsConfigured() { + #expect(OpenAIAuthentication.bearer("sk-test").isConfigured) + } + + @Test("bearer with empty token is not configured") + func bearerEmptyNotConfigured() { + #expect(!OpenAIAuthentication.bearer("").isConfigured) + } + + @Test("apiKey with non-empty key is configured") + func apiKeyNonEmptyIsConfigured() { + #expect(OpenAIAuthentication.apiKey("key123", headerName: "api-key").isConfigured) + } + + @Test("apiKey with empty key is not configured") + func apiKeyEmptyNotConfigured() { + #expect(!OpenAIAuthentication.apiKey("", headerName: "api-key").isConfigured) + } + + // MARK: - Header Name + + @Test("none has nil header name") + func noneHeaderName() { + #expect(OpenAIAuthentication.none.headerName == nil) + } + + @Test("bearer uses Authorization header") + func bearerHeaderName() { + #expect(OpenAIAuthentication.bearer("token").headerName == "Authorization") + } + + @Test("apiKey uses custom header name") + func apiKeyHeaderName() { + #expect(OpenAIAuthentication.apiKey("key", headerName: "x-api-key").headerName == "x-api-key") + } + + @Test("environment uses Authorization header") + func environmentHeaderName() { + #expect(OpenAIAuthentication.environment("VAR").headerName == "Authorization") + } + + @Test("auto uses Authorization header") + func autoHeaderName() { + #expect(OpenAIAuthentication.auto.headerName == "Authorization") + } + + // MARK: - Header Value + + @Test("none has nil header value") + func noneHeaderValue() { + #expect(OpenAIAuthentication.none.headerValue == nil) + } + + @Test("bearer header value has Bearer prefix") + func bearerHeaderValue() { + #expect(OpenAIAuthentication.bearer("sk-test").headerValue == "Bearer sk-test") + } + + @Test("apiKey header value is the raw key") + func apiKeyHeaderValue() { + #expect(OpenAIAuthentication.apiKey("my-key", headerName: "api-key").headerValue == "my-key") + } + + // MARK: - apply(to:) + + @Test("apply adds auth header to URLRequest") + func applyAddsHeader() { + var request = URLRequest(url: URL(string: "https://api.openai.com/v1/chat/completions")!) + let auth = OpenAIAuthentication.bearer("sk-test") + auth.apply(to: &request) + #expect(request.value(forHTTPHeaderField: "Authorization") == "Bearer sk-test") + } + + @Test("apply with none does not modify request") + func applyNoneDoesNotModify() { + var request = URLRequest(url: URL(string: "https://localhost/v1")!) + let auth = OpenAIAuthentication.none + auth.apply(to: &request) + #expect(request.value(forHTTPHeaderField: "Authorization") == nil) + } + + @Test("apply with apiKey sets custom header") + func applyApiKey() { + var request = URLRequest(url: URL(string: "https://azure.com/openai")!) + let auth = OpenAIAuthentication.apiKey("azure-key", headerName: "api-key") + auth.apply(to: &request) + #expect(request.value(forHTTPHeaderField: "api-key") == "azure-key") + } + + // MARK: - Convenience Initializers + + @Test("from(apiKey:) creates bearer authentication") + func fromApiKey() { + let auth = OpenAIAuthentication.from(apiKey: "sk-test123") + #expect(auth == .bearer("sk-test123")) + } + + @Test("for(endpoint:apiKey:) returns .none for Ollama") + func forEndpointOllama() { + let auth = OpenAIAuthentication.for(endpoint: .ollama()) + #expect(auth == .none) + } + + @Test("for(endpoint:apiKey:) returns bearer for OpenAI with key") + func forEndpointOpenAIWithKey() { + let auth = OpenAIAuthentication.for(endpoint: .openAI, apiKey: "sk-test") + #expect(auth == .bearer("sk-test")) + } + + @Test("for(endpoint:apiKey:) returns environment for OpenAI without key") + func forEndpointOpenAIWithoutKey() { + let auth = OpenAIAuthentication.for(endpoint: .openAI) + #expect(auth == .environment("OPENAI_API_KEY")) + } + + @Test("for(endpoint:apiKey:) returns bearer for OpenRouter with key") + func forEndpointOpenRouterWithKey() { + let auth = OpenAIAuthentication.for(endpoint: .openRouter, apiKey: "or-key") + #expect(auth == .bearer("or-key")) + } + + @Test("for(endpoint:apiKey:) returns environment for OpenRouter without key") + func forEndpointOpenRouterWithoutKey() { + let auth = OpenAIAuthentication.for(endpoint: .openRouter) + #expect(auth == .environment("OPENROUTER_API_KEY")) + } + + @Test("for(endpoint:apiKey:) returns apiKey for Azure with key") + func forEndpointAzureWithKey() { + let auth = OpenAIAuthentication.for( + endpoint: .azure(resource: "res", deployment: "dep", apiVersion: "v1"), + apiKey: "azure-key" + ) + #expect(auth == .apiKey("azure-key", headerName: "api-key")) + } + + @Test("for(endpoint:apiKey:) returns environment for Azure without key") + func forEndpointAzureWithoutKey() { + let auth = OpenAIAuthentication.for( + endpoint: .azure(resource: "res", deployment: "dep", apiVersion: "v1") + ) + #expect(auth == .environment("AZURE_OPENAI_API_KEY")) + } + + @Test("for(endpoint:apiKey:) returns bearer for custom with key") + func forEndpointCustomWithKey() { + let url = URL(string: "https://custom.com/v1")! + let auth = OpenAIAuthentication.for(endpoint: .custom(url), apiKey: "my-key") + #expect(auth == .bearer("my-key")) + } + + @Test("for(endpoint:apiKey:) returns auto for custom without key") + func forEndpointCustomWithoutKey() { + let url = URL(string: "https://custom.com/v1")! + let auth = OpenAIAuthentication.for(endpoint: .custom(url)) + #expect(auth == .auto) + } + + // MARK: - Equatable + + @Test("Same bearer tokens are equal") + func sameBearerTokensEqual() { + #expect(OpenAIAuthentication.bearer("token") == .bearer("token")) + } + + @Test("Different bearer tokens are not equal") + func differentBearerTokensNotEqual() { + #expect(OpenAIAuthentication.bearer("token1") != .bearer("token2")) + } + + @Test("Same apiKey with same header are equal") + func sameApiKeysEqual() { + let a = OpenAIAuthentication.apiKey("key", headerName: "api-key") + let b = OpenAIAuthentication.apiKey("key", headerName: "api-key") + #expect(a == b) + } + + @Test("Same apiKey with different headers are not equal") + func sameKeyDifferentHeaderNotEqual() { + let a = OpenAIAuthentication.apiKey("key", headerName: "api-key") + let b = OpenAIAuthentication.apiKey("key", headerName: "x-api-key") + #expect(a != b) + } + + @Test("Different enum cases are not equal") + func differentCasesNotEqual() { + #expect(OpenAIAuthentication.none != .auto) + #expect(OpenAIAuthentication.bearer("key") != .apiKey("key")) + #expect(OpenAIAuthentication.bearer("key") != .none) + } + + @Test("none equals none") + func noneEqualsNone() { + #expect(OpenAIAuthentication.none == .none) + } + + @Test("auto equals auto") + func autoEqualsAuto() { + #expect(OpenAIAuthentication.auto == .auto) + } + + @Test("environment with same variable name are equal") + func sameEnvironmentEqual() { + #expect(OpenAIAuthentication.environment("VAR") == .environment("VAR")) + } + + @Test("environment with different variable names are not equal") + func differentEnvironmentNotEqual() { + #expect(OpenAIAuthentication.environment("VAR1") != .environment("VAR2")) + } + + // MARK: - Hashable + + @Test("Hashing works for set usage") + func hashableForSet() { + var set: Set = [] + set.insert(.none) + set.insert(.bearer("key")) + set.insert(.auto) + // Due to security-aware hashing, different bearers hash the same + // but the set still works correctly via Equatable + set.insert(.bearer("other-key")) + #expect(set.count >= 3) + } + + // MARK: - Debug Description + + @Test("debugDescription redacts bearer token") + func debugDescriptionRedactsBearer() { + let auth = OpenAIAuthentication.bearer("sk-secret-key") + #expect(auth.debugDescription == "OpenAIAuthentication.bearer(***)") + #expect(!auth.debugDescription.contains("sk-secret-key")) + } + + @Test("debugDescription redacts apiKey") + func debugDescriptionRedactsApiKey() { + let auth = OpenAIAuthentication.apiKey("my-secret", headerName: "api-key") + #expect(auth.debugDescription.contains("***")) + #expect(!auth.debugDescription.contains("my-secret")) + #expect(auth.debugDescription.contains("api-key")) + } + + @Test("debugDescription shows environment variable name") + func debugDescriptionShowsEnvVar() { + let auth = OpenAIAuthentication.environment("OPENAI_API_KEY") + #expect(auth.debugDescription.contains("OPENAI_API_KEY")) + } + + @Test("debugDescription for none") + func debugDescriptionNone() { + #expect(OpenAIAuthentication.none.debugDescription == "OpenAIAuthentication.none") + } + + @Test("debugDescription for auto") + func debugDescriptionAuto() { + #expect(OpenAIAuthentication.auto.debugDescription == "OpenAIAuthentication.auto") + } + + @Test("description matches debugDescription") + func descriptionMatchesDebug() { + let auth = OpenAIAuthentication.bearer("test") + #expect(auth.description == auth.debugDescription) + } + + // MARK: - Sendable + + @Test("Sendable conformance compiles") + func sendableConformance() { + let auth: Sendable = OpenAIAuthentication.bearer("test") + #expect(auth is OpenAIAuthentication) + } +} + +#endif // CONDUIT_TRAIT_OPENAI || CONDUIT_TRAIT_OPENROUTER diff --git a/Tests/ConduitTests/Providers/OpenAI/OpenAICapabilitiesTests.swift b/Tests/ConduitTests/Providers/OpenAI/OpenAICapabilitiesTests.swift new file mode 100644 index 0000000..ef8727e --- /dev/null +++ b/Tests/ConduitTests/Providers/OpenAI/OpenAICapabilitiesTests.swift @@ -0,0 +1,297 @@ +// OpenAICapabilitiesTests.swift +// Conduit Tests +// +// Tests for OpenAICapabilities option set flags, presets, and methods. + +#if CONDUIT_TRAIT_OPENAI || CONDUIT_TRAIT_OPENROUTER +import Foundation +import Testing +@testable import Conduit + +@Suite("OpenAICapabilities Tests") +struct OpenAICapabilitiesTests { + + // MARK: - Individual Capability Flags + + @Test("Each capability flag has a unique raw value") + func uniqueRawValues() { + let flags: [OpenAICapabilities] = [ + .textGeneration, .streaming, .embeddings, + .imageGeneration, .transcription, .functionCalling, + .jsonMode, .vision, .textToSpeech, + .parallelFunctionCalling, .structuredOutputs + ] + let rawValues = flags.map(\.rawValue) + let uniqueValues = Set(rawValues) + #expect(rawValues.count == uniqueValues.count) + } + + @Test("Capability raw values are powers of 2") + func rawValuesArePowersOf2() { + #expect(OpenAICapabilities.textGeneration.rawValue == 1) + #expect(OpenAICapabilities.streaming.rawValue == 2) + #expect(OpenAICapabilities.embeddings.rawValue == 4) + #expect(OpenAICapabilities.imageGeneration.rawValue == 8) + #expect(OpenAICapabilities.transcription.rawValue == 16) + #expect(OpenAICapabilities.functionCalling.rawValue == 32) + #expect(OpenAICapabilities.jsonMode.rawValue == 64) + #expect(OpenAICapabilities.vision.rawValue == 128) + #expect(OpenAICapabilities.textToSpeech.rawValue == 256) + #expect(OpenAICapabilities.parallelFunctionCalling.rawValue == 512) + #expect(OpenAICapabilities.structuredOutputs.rawValue == 1024) + } + + // MARK: - Preset Capability Sets + + @Test("OpenAI preset includes all expected capabilities") + func openAIPreset() { + let caps = OpenAICapabilities.openAI + #expect(caps.contains(.textGeneration)) + #expect(caps.contains(.streaming)) + #expect(caps.contains(.embeddings)) + #expect(caps.contains(.imageGeneration)) + #expect(caps.contains(.transcription)) + #expect(caps.contains(.functionCalling)) + #expect(caps.contains(.jsonMode)) + #expect(caps.contains(.vision)) + #expect(caps.contains(.textToSpeech)) + #expect(caps.contains(.parallelFunctionCalling)) + #expect(caps.contains(.structuredOutputs)) + } + + @Test("OpenRouter preset includes expected capabilities") + func openRouterPreset() { + let caps = OpenAICapabilities.openRouter + #expect(caps.contains(.textGeneration)) + #expect(caps.contains(.streaming)) + #expect(caps.contains(.embeddings)) + #expect(caps.contains(.functionCalling)) + #expect(caps.contains(.jsonMode)) + #expect(caps.contains(.vision)) + #expect(!caps.contains(.imageGeneration)) + #expect(!caps.contains(.transcription)) + #expect(!caps.contains(.textToSpeech)) + } + + @Test("Ollama preset includes expected capabilities") + func ollamaPreset() { + let caps = OpenAICapabilities.ollama + #expect(caps.contains(.textGeneration)) + #expect(caps.contains(.streaming)) + #expect(caps.contains(.embeddings)) + #expect(caps.contains(.vision)) + #expect(!caps.contains(.imageGeneration)) + #expect(!caps.contains(.transcription)) + #expect(!caps.contains(.functionCalling)) + #expect(!caps.contains(.textToSpeech)) + } + + @Test("textOnly preset contains only textGeneration and streaming") + func textOnlyPreset() { + let caps = OpenAICapabilities.textOnly + #expect(caps.contains(.textGeneration)) + #expect(caps.contains(.streaming)) + #expect(!caps.contains(.embeddings)) + #expect(!caps.contains(.imageGeneration)) + #expect(!caps.contains(.functionCalling)) + } + + @Test("all preset includes every capability") + func allPreset() { + let caps = OpenAICapabilities.all + #expect(caps == .openAI) + } + + // MARK: - supports / supportsAll / supportsAny + + @Test("supports returns true for contained capability") + func supportsContained() { + let caps = OpenAICapabilities.openAI + #expect(caps.supports(.textGeneration)) + #expect(caps.supports(.vision)) + } + + @Test("supports returns false for non-contained capability") + func supportsNotContained() { + let caps = OpenAICapabilities.textOnly + #expect(!caps.supports(.embeddings)) + #expect(!caps.supports(.imageGeneration)) + } + + @Test("supportsAll returns true when all capabilities present") + func supportsAllPresent() { + let caps = OpenAICapabilities.openAI + let required: OpenAICapabilities = [.textGeneration, .streaming, .vision] + #expect(caps.supportsAll(required)) + } + + @Test("supportsAll returns false when some capabilities missing") + func supportsAllMissing() { + let caps = OpenAICapabilities.textOnly + let required: OpenAICapabilities = [.textGeneration, .embeddings] + #expect(!caps.supportsAll(required)) + } + + @Test("supportsAny returns true when at least one capability present") + func supportsAnyPresent() { + let caps = OpenAICapabilities.textOnly + let check: OpenAICapabilities = [.textGeneration, .embeddings] + #expect(caps.supportsAny(check)) + } + + @Test("supportsAny returns false when no capabilities present") + func supportsAnyNonePresent() { + let caps = OpenAICapabilities.textOnly + let check: OpenAICapabilities = [.imageGeneration, .transcription] + #expect(!caps.supportsAny(check)) + } + + // MARK: - missing(from:) + + @Test("missing returns capabilities that are not present") + func missingReturnsAbsent() { + let caps = OpenAICapabilities.textOnly + let required: OpenAICapabilities = [.textGeneration, .embeddings, .vision] + let missing = caps.missing(from: required) + #expect(missing.contains(.embeddings)) + #expect(missing.contains(.vision)) + #expect(!missing.contains(.textGeneration)) + } + + @Test("missing returns empty when all capabilities present") + func missingReturnsEmptyWhenAllPresent() { + let caps = OpenAICapabilities.openAI + let required: OpenAICapabilities = [.textGeneration, .streaming] + let missing = caps.missing(from: required) + #expect(missing.isEmpty) + } + + // MARK: - Descriptions + + @Test("descriptions returns human-readable names") + func descriptionsHumanReadable() { + let caps: OpenAICapabilities = [.textGeneration, .streaming] + let descs = caps.descriptions + #expect(descs.contains("Text Generation")) + #expect(descs.contains("Streaming")) + #expect(descs.count == 2) + } + + @Test("descriptions returns empty array for no capabilities") + func descriptionsEmpty() { + let caps = OpenAICapabilities(rawValue: 0) + #expect(caps.descriptions.isEmpty) + } + + @Test("All individual capabilities have descriptions") + func allCapabilitiesHaveDescriptions() { + let allIndividual: [OpenAICapabilities] = [ + .textGeneration, .streaming, .embeddings, + .imageGeneration, .transcription, .functionCalling, + .jsonMode, .vision, .textToSpeech, + .parallelFunctionCalling, .structuredOutputs + ] + for cap in allIndividual { + #expect(!cap.descriptions.isEmpty) + #expect(cap.descriptions.count == 1) + } + } + + // MARK: - CustomStringConvertible + + @Test("Description for capabilities with flags lists them") + func descriptionWithFlags() { + let caps: OpenAICapabilities = [.textGeneration] + #expect(caps.description.contains("Text Generation")) + #expect(caps.description.hasPrefix("OpenAICapabilities(")) + } + + @Test("Description for empty capabilities says none") + func descriptionEmpty() { + let caps = OpenAICapabilities(rawValue: 0) + #expect(caps.description == "OpenAICapabilities(none)") + } + + // MARK: - OptionSet Operations + + @Test("Union of capability sets works") + func unionWorks() { + let a: OpenAICapabilities = [.textGeneration] + let b: OpenAICapabilities = [.streaming] + let combined = a.union(b) + #expect(combined.contains(.textGeneration)) + #expect(combined.contains(.streaming)) + } + + @Test("Intersection of capability sets works") + func intersectionWorks() { + let a: OpenAICapabilities = [.textGeneration, .streaming, .vision] + let b: OpenAICapabilities = [.streaming, .vision, .embeddings] + let common = a.intersection(b) + #expect(common.contains(.streaming)) + #expect(common.contains(.vision)) + #expect(!common.contains(.textGeneration)) + #expect(!common.contains(.embeddings)) + } + + @Test("Subtracting capability sets works") + func subtractingWorks() { + let a: OpenAICapabilities = [.textGeneration, .streaming, .vision] + let b: OpenAICapabilities = [.streaming] + let result = a.subtracting(b) + #expect(result.contains(.textGeneration)) + #expect(result.contains(.vision)) + #expect(!result.contains(.streaming)) + } + + // MARK: - Endpoint Default Capabilities + + @Test("OpenAI endpoint returns openAI capabilities") + func openAIEndpointCapabilities() { + #expect(OpenAIEndpoint.openAI.defaultCapabilities == .openAI) + } + + @Test("OpenRouter endpoint returns openRouter capabilities") + func openRouterEndpointCapabilities() { + #expect(OpenAIEndpoint.openRouter.defaultCapabilities == .openRouter) + } + + @Test("Ollama endpoint returns ollama capabilities") + func ollamaEndpointCapabilities() { + #expect(OpenAIEndpoint.ollama().defaultCapabilities == .ollama) + } + + @Test("Azure endpoint returns expected capabilities") + func azureEndpointCapabilities() { + let caps = OpenAIEndpoint.azure(resource: "r", deployment: "d", apiVersion: "v").defaultCapabilities + #expect(caps.contains(.textGeneration)) + #expect(caps.contains(.streaming)) + #expect(caps.contains(.functionCalling)) + #expect(caps.contains(.jsonMode)) + } + + @Test("Custom endpoint returns textOnly capabilities") + func customEndpointCapabilities() { + let url = URL(string: "https://custom.com/v1")! + #expect(OpenAIEndpoint.custom(url).defaultCapabilities == .textOnly) + } + + // MARK: - Hashable / Sendable + + @Test("Capabilities are hashable for set usage") + func hashable() { + var set: Set = [] + set.insert(.textOnly) + set.insert(.openAI) + set.insert(.textOnly) + #expect(set.count == 2) + } + + @Test("Sendable conformance compiles") + func sendableConformance() { + let caps: Sendable = OpenAICapabilities.openAI + #expect(caps is OpenAICapabilities) + } +} + +#endif // CONDUIT_TRAIT_OPENAI || CONDUIT_TRAIT_OPENROUTER diff --git a/Tests/ConduitTests/Providers/OpenAI/OpenAIConfigurationTests.swift b/Tests/ConduitTests/Providers/OpenAI/OpenAIConfigurationTests.swift new file mode 100644 index 0000000..4d74c15 --- /dev/null +++ b/Tests/ConduitTests/Providers/OpenAI/OpenAIConfigurationTests.swift @@ -0,0 +1,476 @@ +// OpenAIConfigurationTests.swift +// Conduit Tests +// +// Tests for OpenAIConfiguration initialization, defaults, fluent API, and Codable. + +#if CONDUIT_TRAIT_OPENAI || CONDUIT_TRAIT_OPENROUTER +import Foundation +import Testing +@testable import Conduit + +@Suite("OpenAIConfiguration Tests") +struct OpenAIConfigurationTests { + + // MARK: - Default Initialization + + @Test("Default init has expected values") + func defaultInit() { + let config = OpenAIConfiguration() + #expect(config.endpoint == .openAI) + #expect(config.authentication == .auto) + #expect(config.apiVariant == .chatCompletions) + #expect(config.timeout == 60.0) + #expect(config.maxRetries == 3) + #expect(config.retryConfig == .default) + #expect(config.defaultHeaders.isEmpty) + #expect(config.userAgent == nil) + #expect(config.organizationID == nil) + #expect(config.openRouterConfig == nil) + #expect(config.azureConfig == nil) + #expect(config.ollamaConfig == nil) + } + + @Test("Default static property matches default init") + func defaultStaticProperty() { + let config = OpenAIConfiguration.default + #expect(config.endpoint == .openAI) + #expect(config.authentication == .auto) + #expect(config.timeout == 60.0) + #expect(config.maxRetries == 3) + } + + // MARK: - Clamping + + @Test("Negative timeout is clamped to zero") + func negativeTimeoutClamped() { + let config = OpenAIConfiguration(timeout: -10) + #expect(config.timeout == 0) + } + + @Test("Negative maxRetries is clamped to zero") + func negativeMaxRetriesClamped() { + let config = OpenAIConfiguration(maxRetries: -3) + #expect(config.maxRetries == 0) + } + + // MARK: - Static Presets + + @Test("openRouter preset has correct endpoint and authentication") + func openRouterPreset() { + let config = OpenAIConfiguration.openRouter + #expect(config.endpoint == .openRouter) + #expect(config.authentication == .environment("OPENROUTER_API_KEY")) + #expect(config.openRouterConfig != nil) + } + + @Test("ollama preset has correct endpoint and no auth") + func ollamaPreset() { + let config = OpenAIConfiguration.ollama + #expect(config.authentication == .none) + #expect(config.ollamaConfig != nil) + } + + @Test("longRunning preset has extended timeout and aggressive retry") + func longRunningPreset() { + let config = OpenAIConfiguration.longRunning + #expect(config.timeout == 120.0) + #expect(config.maxRetries == 5) + #expect(config.retryConfig == .aggressive) + } + + @Test("noRetry preset has zero retries") + func noRetryPreset() { + let config = OpenAIConfiguration.noRetry + #expect(config.maxRetries == 0) + #expect(config.retryConfig == .none) + } + + // MARK: - Convenience Initializers + + @Test("openAI(apiKey:) creates bearer auth with openAI endpoint") + func openAIConvenience() { + let config = OpenAIConfiguration.openAI(apiKey: "sk-test123") + #expect(config.endpoint == .openAI) + #expect(config.authentication == .bearer("sk-test123")) + } + + @Test("openRouter(apiKey:) creates bearer auth with openRouter endpoint") + func openRouterConvenience() { + let config = OpenAIConfiguration.openRouter(apiKey: "or-test456") + #expect(config.endpoint == .openRouter) + #expect(config.authentication == .bearer("or-test456")) + #expect(config.openRouterConfig != nil) + } + + @Test("ollama(host:port:) creates config with no auth") + func ollamaConvenience() { + let config = OpenAIConfiguration.ollama(host: "192.168.1.5", port: 8080) + #expect(config.authentication == .none) + #expect(config.ollamaConfig != nil) + } + + @Test("azure convenience creates apiKey auth with correct endpoint") + func azureConvenience() { + let config = OpenAIConfiguration.azure( + resource: "my-resource", + deployment: "gpt4-deploy", + apiKey: "azure-key-123" + ) + #expect(config.authentication == .apiKey("azure-key-123", headerName: "api-key")) + #expect(config.azureConfig != nil) + #expect(config.azureConfig?.resource == "my-resource") + #expect(config.azureConfig?.deployment == "gpt4-deploy") + } + + @Test("custom(url:) creates config with auto auth when no key") + func customConvenienceNoKey() { + let url = URL(string: "https://my-proxy.com/v1")! + let config = OpenAIConfiguration.custom(url: url) + #expect(config.endpoint == .custom(url)) + #expect(config.authentication == .auto) + } + + @Test("custom(url:apiKey:) creates config with bearer auth") + func customConvenienceWithKey() { + let url = URL(string: "https://my-proxy.com/v1")! + let config = OpenAIConfiguration.custom(url: url, apiKey: "my-key") + #expect(config.authentication == .bearer("my-key")) + } + + // MARK: - Computed Properties + + @Test("hasValidAuthentication returns true for ollama (no auth required)") + func hasValidAuthOllama() { + let config = OpenAIConfiguration.ollama + #expect(config.hasValidAuthentication) + } + + @Test("hasValidAuthentication returns true for bearer with non-empty token") + func hasValidAuthBearer() { + let config = OpenAIConfiguration.openAI(apiKey: "sk-test") + #expect(config.hasValidAuthentication) + } + + @Test("hasValidAuthentication returns false for bearer with empty token") + func hasInvalidAuthEmptyBearer() { + let config = OpenAIConfiguration( + endpoint: .openAI, + authentication: .bearer("") + ) + #expect(!config.hasValidAuthentication) + } + + @Test("capabilities returns endpoint default capabilities") + func capabilitiesReturnsEndpointDefaults() { + let openAI = OpenAIConfiguration.default + #expect(openAI.capabilities == OpenAICapabilities.openAI) + + let openRouter = OpenAIConfiguration.openRouter + #expect(openRouter.capabilities == OpenAICapabilities.openRouter) + } + + // MARK: - Fluent API + + @Test("Fluent endpoint returns updated copy") + func fluentEndpoint() { + let config = OpenAIConfiguration.default.endpoint(.openRouter) + #expect(config.endpoint == .openRouter) + } + + @Test("Fluent authentication returns updated copy") + func fluentAuthentication() { + let config = OpenAIConfiguration.default.authentication(.bearer("sk-test")) + #expect(config.authentication == .bearer("sk-test")) + } + + @Test("Fluent apiVariant returns updated copy") + func fluentApiVariant() { + let config = OpenAIConfiguration.default.apiVariant(.responses) + #expect(config.apiVariant == .responses) + } + + @Test("Fluent apiKey sets bearer authentication") + func fluentApiKey() { + let config = OpenAIConfiguration.default.apiKey("sk-new-key") + #expect(config.authentication == .bearer("sk-new-key")) + } + + @Test("Fluent timeout returns updated copy and clamps negative") + func fluentTimeout() { + let config = OpenAIConfiguration.default.timeout(120.0) + #expect(config.timeout == 120.0) + + let clamped = OpenAIConfiguration.default.timeout(-5.0) + #expect(clamped.timeout == 0.0) + } + + @Test("Fluent maxRetries returns updated copy and clamps negative") + func fluentMaxRetries() { + let config = OpenAIConfiguration.default.maxRetries(10) + #expect(config.maxRetries == 10) + + let clamped = OpenAIConfiguration.default.maxRetries(-1) + #expect(clamped.maxRetries == 0) + } + + @Test("Fluent retryConfig returns updated copy") + func fluentRetryConfig() { + let config = OpenAIConfiguration.default.retryConfig(.aggressive) + #expect(config.retryConfig == .aggressive) + } + + @Test("Fluent noRetries sets maxRetries to zero") + func fluentNoRetries() { + let config = OpenAIConfiguration.default.noRetries() + #expect(config.maxRetries == 0) + } + + @Test("Fluent headers returns updated copy") + func fluentHeaders() { + let config = OpenAIConfiguration.default.headers(["X-Custom": "value"]) + #expect(config.defaultHeaders == ["X-Custom": "value"]) + } + + @Test("Fluent header adds a single header") + func fluentHeader() { + let config = OpenAIConfiguration.default.header("X-Custom", value: "test") + #expect(config.defaultHeaders["X-Custom"] == "test") + } + + @Test("Fluent userAgent returns updated copy") + func fluentUserAgent() { + let config = OpenAIConfiguration.default.userAgent("MyApp/1.0") + #expect(config.userAgent == "MyApp/1.0") + } + + @Test("Fluent organization returns updated copy") + func fluentOrganization() { + let config = OpenAIConfiguration.default.organization("org-123") + #expect(config.organizationID == "org-123") + } + + @Test("Fluent openRouter routing returns updated copy") + func fluentOpenRouterRouting() { + let routing = OpenRouterRoutingConfig(providers: [.anthropic], fallbacks: true) + let config = OpenAIConfiguration.default.openRouter(routing) + #expect(config.openRouterConfig?.providers == [.anthropic]) + } + + @Test("Fluent routing alias works") + func fluentRoutingAlias() { + let config = OpenAIConfiguration.openRouter(apiKey: "test") + .routing(.preferAnthropic) + #expect(config.openRouterConfig?.providers == [.anthropic]) + } + + @Test("Fluent preferring sets provider routing") + func fluentPreferring() { + let config = OpenAIConfiguration.openRouter(apiKey: "test") + .preferring(.openai, .anthropic) + #expect(config.openRouterConfig?.providers == [.openai, .anthropic]) + #expect(config.openRouterConfig?.fallbacks == true) + } + + @Test("Fluent routeByLatency enables latency routing") + func fluentRouteByLatency() { + let config = OpenAIConfiguration.openRouter(apiKey: "test") + .routeByLatency() + #expect(config.openRouterConfig?.routeByLatency == true) + } + + @Test("Fluent routeByLatency creates default config if nil") + func fluentRouteByLatencyCreatesDefault() { + let config = OpenAIConfiguration.default.routeByLatency() + #expect(config.openRouterConfig != nil) + #expect(config.openRouterConfig?.routeByLatency == true) + } + + @Test("Fluent ollama returns updated copy") + func fluentOllama() { + let ollamaConfig = OllamaConfiguration(keepAlive: "10m") + let config = OpenAIConfiguration.default.ollama(ollamaConfig) + #expect(config.ollamaConfig?.keepAlive == "10m") + } + + @Test("Fluent azure returns updated copy") + func fluentAzure() { + let azureConfig = AzureConfiguration(resource: "res", deployment: "dep") + let config = OpenAIConfiguration.default.azure(azureConfig) + #expect(config.azureConfig?.resource == "res") + } + + // MARK: - Build Headers + + @Test("buildHeaders includes Content-Type") + func buildHeadersIncludesContentType() { + let config = OpenAIConfiguration.openAI(apiKey: "sk-test") + let headers = config.buildHeaders() + #expect(headers["Content-Type"] == "application/json") + } + + @Test("buildHeaders includes bearer auth") + func buildHeadersIncludesAuth() { + let config = OpenAIConfiguration.openAI(apiKey: "sk-test") + let headers = config.buildHeaders() + #expect(headers["Authorization"] == "Bearer sk-test") + } + + @Test("buildHeaders includes default User-Agent when custom not set") + func buildHeadersDefaultUserAgent() { + let config = OpenAIConfiguration.openAI(apiKey: "sk-test") + let headers = config.buildHeaders() + #expect(headers["User-Agent"]?.hasPrefix("Conduit/") == true) + } + + @Test("buildHeaders uses custom User-Agent when set") + func buildHeadersCustomUserAgent() { + let config = OpenAIConfiguration.openAI(apiKey: "sk-test") + .userAgent("MyApp/2.0") + let headers = config.buildHeaders() + #expect(headers["User-Agent"] == "MyApp/2.0") + } + + @Test("buildHeaders includes organization ID for OpenAI endpoint") + func buildHeadersIncludesOrgId() { + let config = OpenAIConfiguration.openAI(apiKey: "sk-test") + .organization("org-abc") + let headers = config.buildHeaders() + #expect(headers["OpenAI-Organization"] == "org-abc") + } + + @Test("buildHeaders does not include organization ID for non-OpenAI endpoint") + func buildHeadersOmitsOrgIdForNonOpenAI() { + let config = OpenAIConfiguration.openRouter(apiKey: "or-test") + .organization("org-abc") + let headers = config.buildHeaders() + #expect(headers["OpenAI-Organization"] == nil) + } + + @Test("buildHeaders includes default headers") + func buildHeadersIncludesDefaultHeaders() { + let config = OpenAIConfiguration.openAI(apiKey: "sk-test") + .headers(["X-Custom": "value"]) + let headers = config.buildHeaders() + #expect(headers["X-Custom"] == "value") + } + + @Test("buildHeaders includes OpenRouter site headers") + func buildHeadersIncludesOpenRouterHeaders() { + let routing = OpenRouterRoutingConfig( + siteURL: URL(string: "https://myapp.com"), + appName: "MyApp" + ) + let config = OpenAIConfiguration.openRouter(apiKey: "or-test") + .openRouter(routing) + let headers = config.buildHeaders() + #expect(headers["HTTP-Referer"] == "https://myapp.com") + #expect(headers["X-Title"] == "MyApp") + } + + // MARK: - Codable + + @Test("OpenAIConfiguration round-trips through JSON with default values") + func codableRoundTripDefaults() throws { + let original = OpenAIConfiguration.default + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(OpenAIConfiguration.self, from: data) + + #expect(decoded.endpoint == original.endpoint) + #expect(decoded.timeout == original.timeout) + #expect(decoded.maxRetries == original.maxRetries) + #expect(decoded.defaultHeaders == original.defaultHeaders) + // Authentication is not encoded (security) - decoded should be .auto + #expect(decoded.authentication == .auto) + } + + @Test("OpenAIConfiguration authentication is not persisted in JSON for security") + func codableDoesNotPersistAuth() throws { + let original = OpenAIConfiguration.openAI(apiKey: "sk-secret-key") + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(OpenAIConfiguration.self, from: data) + + // Authentication should be .auto after decoding, not bearer + #expect(decoded.authentication == .auto) + } + + @Test("OpenAIConfiguration azureConfig is not persisted in JSON for security") + func codableDoesNotPersistAzureConfig() throws { + let original = OpenAIConfiguration.azure( + resource: "res", + deployment: "dep", + apiKey: "key" + ) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(OpenAIConfiguration.self, from: data) + + #expect(decoded.azureConfig == nil) + } + + @Test("OpenAIConfiguration preserves optional fields through encoding") + func codablePreservesOptionalFields() throws { + let original = OpenAIConfiguration( + userAgent: "TestAgent", + organizationID: "org-test", + openRouterConfig: .default, + ollamaConfig: .default + ) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(OpenAIConfiguration.self, from: data) + + #expect(decoded.userAgent == "TestAgent") + #expect(decoded.organizationID == "org-test") + #expect(decoded.openRouterConfig != nil) + #expect(decoded.ollamaConfig != nil) + } + + @Test("OpenAIConfiguration apiVariant round-trips") + func codableRoundTripApiVariant() throws { + let original = OpenAIConfiguration.default.apiVariant(.responses) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(OpenAIConfiguration.self, from: data) + #expect(decoded.apiVariant == .responses) + } + + // MARK: - Hashable / Equatable + + @Test("Equal configurations are equal") + func equalConfigurationsAreEqual() { + let a = OpenAIConfiguration.default + let b = OpenAIConfiguration.default + #expect(a == b) + } + + @Test("Different configurations are not equal") + func differentConfigurationsAreNotEqual() { + let a = OpenAIConfiguration.default + let b = OpenAIConfiguration.longRunning + #expect(a != b) + } +} + +// MARK: - OpenAIAPIVariant Tests + +@Suite("OpenAIAPIVariant Type Tests") +struct OpenAIAPIVariantTypeTests { + + @Test("chatCompletions raw value is correct") + func chatCompletionsRawValue() { + #expect(OpenAIAPIVariant.chatCompletions.rawValue == "chatCompletions") + } + + @Test("responses raw value is correct") + func responsesRawValue() { + #expect(OpenAIAPIVariant.responses.rawValue == "responses") + } + + @Test("Codable round-trip for both variants") + func codableRoundTrip() throws { + for variant in [OpenAIAPIVariant.chatCompletions, .responses] { + let data = try JSONEncoder().encode(variant) + let decoded = try JSONDecoder().decode(OpenAIAPIVariant.self, from: data) + #expect(decoded == variant) + } + } +} + +#endif // CONDUIT_TRAIT_OPENAI || CONDUIT_TRAIT_OPENROUTER diff --git a/Tests/ConduitTests/Providers/OpenAI/OpenAIEndpointTests.swift b/Tests/ConduitTests/Providers/OpenAI/OpenAIEndpointTests.swift new file mode 100644 index 0000000..e3f6102 --- /dev/null +++ b/Tests/ConduitTests/Providers/OpenAI/OpenAIEndpointTests.swift @@ -0,0 +1,475 @@ +// OpenAIEndpointTests.swift +// Conduit Tests +// +// Tests for OpenAIEndpoint URL construction, properties, validation, and Codable. + +#if CONDUIT_TRAIT_OPENAI || CONDUIT_TRAIT_OPENROUTER +import Foundation +import Testing +@testable import Conduit + +@Suite("OpenAIEndpoint Tests") +struct OpenAIEndpointTests { + + // MARK: - Base URLs + + @Test("OpenAI base URL is correct") + func openAIBaseURL() { + #expect(OpenAIEndpoint.openAI.baseURL.absoluteString == "https://api.openai.com/v1") + } + + @Test("OpenRouter base URL is correct") + func openRouterBaseURL() { + #expect(OpenAIEndpoint.openRouter.baseURL.absoluteString == "https://openrouter.ai/api/v1") + } + + @Test("Ollama default base URL is correct") + func ollamaDefaultBaseURL() { + let endpoint = OpenAIEndpoint.ollama() + #expect(endpoint.baseURL.absoluteString == "http://localhost:11434/v1") + } + + @Test("Ollama custom host and port base URL is correct") + func ollamaCustomBaseURL() { + let endpoint = OpenAIEndpoint.ollama(host: "192.168.1.10", port: 8080) + #expect(endpoint.baseURL.absoluteString == "http://192.168.1.10:8080/v1") + } + + @Test("Azure base URL includes resource name") + func azureBaseURL() { + let endpoint = OpenAIEndpoint.azure( + resource: "my-resource", + deployment: "gpt4", + apiVersion: "2024-02-15-preview" + ) + #expect(endpoint.baseURL.absoluteString == "https://my-resource.openai.azure.com/openai") + } + + @Test("Custom endpoint base URL is the provided URL") + func customBaseURL() { + let url = URL(string: "https://my-proxy.com/v1")! + let endpoint = OpenAIEndpoint.custom(url) + #expect(endpoint.baseURL == url) + } + + // MARK: - Chat Completions URL + + @Test("OpenAI chat completions URL") + func openAIChatCompletionsURL() { + let url = OpenAIEndpoint.openAI.chatCompletionsURL + #expect(url.absoluteString == "https://api.openai.com/v1/chat/completions") + } + + @Test("OpenRouter chat completions URL") + func openRouterChatCompletionsURL() { + let url = OpenAIEndpoint.openRouter.chatCompletionsURL + #expect(url.absoluteString == "https://openrouter.ai/api/v1/chat/completions") + } + + @Test("Ollama chat completions URL") + func ollamaChatCompletionsURL() { + let url = OpenAIEndpoint.ollama().chatCompletionsURL + #expect(url.absoluteString == "http://localhost:11434/v1/chat/completions") + } + + @Test("Azure chat completions URL includes deployment and api-version") + func azureChatCompletionsURL() { + let endpoint = OpenAIEndpoint.azure( + resource: "my-resource", + deployment: "gpt4-deploy", + apiVersion: "2024-02-15-preview" + ) + let url = endpoint.chatCompletionsURL + #expect(url.absoluteString.contains("deployments/gpt4-deploy/chat/completions")) + #expect(url.absoluteString.contains("api-version=2024-02-15-preview")) + } + + // MARK: - Responses URL + + @Test("OpenAI responses URL") + func openAIResponsesURL() { + let url = OpenAIEndpoint.openAI.responsesURL + #expect(url.absoluteString == "https://api.openai.com/v1/responses") + } + + @Test("Azure responses URL includes deployment and api-version") + func azureResponsesURL() { + let endpoint = OpenAIEndpoint.azure( + resource: "res", + deployment: "dep", + apiVersion: "2024-02-15-preview" + ) + let url = endpoint.responsesURL + #expect(url.absoluteString.contains("deployments/dep/responses")) + #expect(url.absoluteString.contains("api-version=2024-02-15-preview")) + } + + // MARK: - Text Generation URL + + @Test("textGenerationURL returns chat completions for .chatCompletions variant") + func textGenerationURLChatCompletions() { + let url = OpenAIEndpoint.openAI.textGenerationURL(for: .chatCompletions) + #expect(url == OpenAIEndpoint.openAI.chatCompletionsURL) + } + + @Test("textGenerationURL returns responses for .responses variant") + func textGenerationURLResponses() { + let url = OpenAIEndpoint.openAI.textGenerationURL(for: .responses) + #expect(url == OpenAIEndpoint.openAI.responsesURL) + } + + // MARK: - Embeddings URL + + @Test("OpenAI embeddings URL") + func openAIEmbeddingsURL() { + let url = OpenAIEndpoint.openAI.embeddingsURL + #expect(url.absoluteString == "https://api.openai.com/v1/embeddings") + } + + @Test("Azure embeddings URL includes deployment and api-version") + func azureEmbeddingsURL() { + let endpoint = OpenAIEndpoint.azure( + resource: "res", + deployment: "embed-dep", + apiVersion: "2024-02-15-preview" + ) + let url = endpoint.embeddingsURL + #expect(url.absoluteString.contains("deployments/embed-dep/embeddings")) + #expect(url.absoluteString.contains("api-version=2024-02-15-preview")) + } + + // MARK: - Images Generations URL + + @Test("Images generations URL appends correct path") + func imagesGenerationsURL() { + let url = OpenAIEndpoint.openAI.imagesGenerationsURL + #expect(url.absoluteString == "https://api.openai.com/v1/images/generations") + } + + // MARK: - Audio Transcriptions URL + + @Test("Audio transcriptions URL appends correct path") + func audioTranscriptionsURL() { + let url = OpenAIEndpoint.openAI.audioTranscriptionsURL + #expect(url.absoluteString == "https://api.openai.com/v1/audio/transcriptions") + } + + // MARK: - isLocal + + @Test("Ollama is local") + func ollamaIsLocal() { + #expect(OpenAIEndpoint.ollama().isLocal) + } + + @Test("OpenAI is not local") + func openAIIsNotLocal() { + #expect(!OpenAIEndpoint.openAI.isLocal) + } + + @Test("OpenRouter is not local") + func openRouterIsNotLocal() { + #expect(!OpenAIEndpoint.openRouter.isLocal) + } + + @Test("Azure is not local") + func azureIsNotLocal() { + #expect(!OpenAIEndpoint.azure(resource: "r", deployment: "d", apiVersion: "v").isLocal) + } + + @Test("Custom is not local") + func customIsNotLocal() { + let url = URL(string: "https://custom.com")! + #expect(!OpenAIEndpoint.custom(url).isLocal) + } + + // MARK: - requiresAuthentication + + @Test("Ollama does not require authentication") + func ollamaNoAuthRequired() { + #expect(!OpenAIEndpoint.ollama().requiresAuthentication) + } + + @Test("OpenAI requires authentication") + func openAIRequiresAuth() { + #expect(OpenAIEndpoint.openAI.requiresAuthentication) + } + + @Test("OpenRouter requires authentication") + func openRouterRequiresAuth() { + #expect(OpenAIEndpoint.openRouter.requiresAuthentication) + } + + @Test("Azure requires authentication") + func azureRequiresAuth() { + #expect(OpenAIEndpoint.azure(resource: "r", deployment: "d", apiVersion: "v").requiresAuthentication) + } + + @Test("Custom requires authentication") + func customRequiresAuth() { + let url = URL(string: "https://custom.com")! + #expect(OpenAIEndpoint.custom(url).requiresAuthentication) + } + + // MARK: - Display Name + + @Test("OpenAI display name") + func openAIDisplayName() { + #expect(OpenAIEndpoint.openAI.displayName == "OpenAI") + } + + @Test("OpenRouter display name") + func openRouterDisplayName() { + #expect(OpenAIEndpoint.openRouter.displayName == "OpenRouter") + } + + @Test("Ollama default display name") + func ollamaDefaultDisplayName() { + #expect(OpenAIEndpoint.ollama().displayName == "Ollama (Local)") + } + + @Test("Ollama custom host display name") + func ollamaCustomDisplayName() { + let endpoint = OpenAIEndpoint.ollama(host: "192.168.1.5", port: 8080) + #expect(endpoint.displayName == "Ollama (192.168.1.5:8080)") + } + + @Test("Azure display name includes resource") + func azureDisplayName() { + let endpoint = OpenAIEndpoint.azure(resource: "my-resource", deployment: "d", apiVersion: "v") + #expect(endpoint.displayName == "Azure OpenAI (my-resource)") + } + + @Test("Custom display name includes host") + func customDisplayName() { + let url = URL(string: "https://my-proxy.com/v1")! + let endpoint = OpenAIEndpoint.custom(url) + #expect(endpoint.displayName.contains("my-proxy.com")) + } + + // MARK: - Description + + @Test("description matches displayName") + func descriptionMatchesDisplayName() { + #expect(OpenAIEndpoint.openAI.description == OpenAIEndpoint.openAI.displayName) + #expect(OpenAIEndpoint.openRouter.description == OpenAIEndpoint.openRouter.displayName) + } + + // MARK: - Convenience Initializers + + @Test("ollama(url:) parses host and port from URL") + func ollamaFromURL() { + let url = URL(string: "http://192.168.1.10:8080/v1")! + let endpoint = OpenAIEndpoint.ollama(url: url) + #expect(endpoint.baseURL.absoluteString.contains("192.168.1.10")) + #expect(endpoint.baseURL.absoluteString.contains("8080")) + } + + @Test("azure convenience initializer uses default API version") + func azureConvenienceDefaultVersion() { + let endpoint = OpenAIEndpoint.azure(resource: "res", deployment: "dep") + if case .azure(_, _, let version) = endpoint { + #expect(version == "2024-02-15-preview") + } else { + Issue.record("Expected azure endpoint") + } + } + + // MARK: - Validated Constructors + + @Test("ollamaValidated with valid values works") + func ollamaValidatedValid() { + let endpoint = OpenAIEndpoint.ollamaValidated(host: "myhost", port: 9090) + if case .ollama(let host, let port) = endpoint { + #expect(host == "myhost") + #expect(port == 9090) + } else { + Issue.record("Expected ollama endpoint") + } + } + + @Test("ollamaValidated with empty host falls back to localhost") + func ollamaValidatedEmptyHost() { + let endpoint = OpenAIEndpoint.ollamaValidated(host: "", port: 11434) + if case .ollama(let host, _) = endpoint { + #expect(host == "localhost") + } else { + Issue.record("Expected ollama endpoint") + } + } + + @Test("ollamaValidated with invalid port falls back to 11434") + func ollamaValidatedInvalidPort() { + let endpoint = OpenAIEndpoint.ollamaValidated(host: "localhost", port: 99999) + if case .ollama(_, let port) = endpoint { + #expect(port == 11434) + } else { + Issue.record("Expected ollama endpoint") + } + } + + @Test("ollamaValidated with negative port falls back to 11434") + func ollamaValidatedNegativePort() { + let endpoint = OpenAIEndpoint.ollamaValidated(host: "localhost", port: -1) + if case .ollama(_, let port) = endpoint { + #expect(port == 11434) + } else { + Issue.record("Expected ollama endpoint") + } + } + + // MARK: - Validation + + @Test("validateOllamaConfig succeeds with valid values") + func validateOllamaConfigValid() throws { + try OpenAIEndpoint.validateOllamaConfig(host: "localhost", port: 11434) + } + + @Test("validateOllamaConfig throws emptyHost for empty host") + func validateOllamaConfigEmptyHost() { + #expect(throws: OpenAIEndpoint.ValidationError.self) { + try OpenAIEndpoint.validateOllamaConfig(host: "", port: 11434) + } + } + + @Test("validateOllamaConfig throws invalidPort for port 0") + func validateOllamaConfigInvalidPortZero() { + #expect(throws: OpenAIEndpoint.ValidationError.self) { + try OpenAIEndpoint.validateOllamaConfig(host: "localhost", port: 0) + } + } + + @Test("validateOllamaConfig throws invalidPort for port above 65535") + func validateOllamaConfigInvalidPortAbove() { + #expect(throws: OpenAIEndpoint.ValidationError.self) { + try OpenAIEndpoint.validateOllamaConfig(host: "localhost", port: 70000) + } + } + + @Test("ValidationError has localized descriptions") + func validationErrorDescriptions() { + let emptyHost = OpenAIEndpoint.ValidationError.emptyHost + #expect(emptyHost.errorDescription?.contains("empty") == true) + + let invalidPort = OpenAIEndpoint.ValidationError.invalidPort(99999) + #expect(invalidPort.errorDescription?.contains("99999") == true) + #expect(invalidPort.errorDescription?.contains("65535") == true) + } + + // MARK: - Ollama Host Sanitization + + @Test("Ollama sanitizes host by removing scheme prefixes") + func ollamaSanitizesScheme() { + let endpoint = OpenAIEndpoint.ollama(host: "http://myhost", port: 11434) + #expect(endpoint.baseURL.host == "myhost") + } + + @Test("Ollama sanitizes empty host to localhost") + func ollamaSanitizesEmptyHost() { + let endpoint = OpenAIEndpoint.ollama(host: "", port: 11434) + #expect(endpoint.baseURL.host == "localhost") + } + + @Test("Ollama clamps invalid port to default 11434") + func ollamaClampsInvalidPort() { + let endpoint = OpenAIEndpoint.ollama(host: "localhost", port: 99999) + #expect(endpoint.baseURL.port == 11434) + } + + // MARK: - Equatable / Hashable + + @Test("Same endpoints are equal") + func sameEndpointsEqual() { + #expect(OpenAIEndpoint.openAI == .openAI) + #expect(OpenAIEndpoint.openRouter == .openRouter) + #expect(OpenAIEndpoint.ollama() == .ollama()) + } + + @Test("Different endpoints are not equal") + func differentEndpointsNotEqual() { + #expect(OpenAIEndpoint.openAI != .openRouter) + #expect(OpenAIEndpoint.openAI != .ollama()) + } + + @Test("Ollama endpoints with different hosts are not equal") + func ollamaDifferentHosts() { + #expect(OpenAIEndpoint.ollama(host: "a", port: 11434) != .ollama(host: "b", port: 11434)) + } + + @Test("Can be used in a Set") + func hashableForSet() { + var set: Set = [] + set.insert(.openAI) + set.insert(.openRouter) + set.insert(.openAI) + #expect(set.count == 2) + } + + // MARK: - Codable + + @Test("OpenAI endpoint round-trips through JSON") + func codableOpenAI() throws { + let original = OpenAIEndpoint.openAI + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(OpenAIEndpoint.self, from: data) + #expect(decoded == original) + } + + @Test("OpenRouter endpoint round-trips through JSON") + func codableOpenRouter() throws { + let original = OpenAIEndpoint.openRouter + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(OpenAIEndpoint.self, from: data) + #expect(decoded == original) + } + + @Test("Ollama endpoint round-trips through JSON") + func codableOllama() throws { + let original = OpenAIEndpoint.ollama(host: "192.168.1.5", port: 8080) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(OpenAIEndpoint.self, from: data) + #expect(decoded == original) + } + + @Test("Azure endpoint round-trips through JSON") + func codableAzure() throws { + let original = OpenAIEndpoint.azure( + resource: "my-resource", + deployment: "gpt4", + apiVersion: "2024-02-15-preview" + ) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(OpenAIEndpoint.self, from: data) + #expect(decoded == original) + } + + @Test("Custom endpoint round-trips through JSON") + func codableCustom() throws { + let url = URL(string: "https://custom-api.example.com/v1")! + let original = OpenAIEndpoint.custom(url) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(OpenAIEndpoint.self, from: data) + #expect(decoded == original) + } + + @Test("Ollama endpoint decodes with defaults when host/port missing") + func codableOllamaDefaults() throws { + let json = #"{"type":"ollama"}"#.data(using: .utf8)! + let decoded = try JSONDecoder().decode(OpenAIEndpoint.self, from: json) + if case .ollama(let host, let port) = decoded { + #expect(host == "localhost") + #expect(port == 11434) + } else { + Issue.record("Expected ollama endpoint") + } + } + + // MARK: - Sendable + + @Test("Sendable conformance compiles") + func sendableConformance() { + let endpoint: Sendable = OpenAIEndpoint.openAI + #expect(endpoint is OpenAIEndpoint) + } +} + +#endif // CONDUIT_TRAIT_OPENAI || CONDUIT_TRAIT_OPENROUTER diff --git a/Tests/ConduitTests/Providers/OpenAI/OpenAIModelIDTests.swift b/Tests/ConduitTests/Providers/OpenAI/OpenAIModelIDTests.swift new file mode 100644 index 0000000..940f300 --- /dev/null +++ b/Tests/ConduitTests/Providers/OpenAI/OpenAIModelIDTests.swift @@ -0,0 +1,234 @@ +// OpenAIModelIDTests.swift +// Conduit Tests +// +// Tests for OpenAIModelID including static models, helpers, and protocol conformances. + +#if CONDUIT_TRAIT_OPENAI || CONDUIT_TRAIT_OPENROUTER +import Foundation +import Testing +@testable import Conduit + +@Suite("OpenAIModelID Tests") +struct OpenAIModelIDTests { + + // MARK: - Initialization + + @Test("Init from string stores raw value") + func initFromString() { + let model = OpenAIModelID("gpt-4o") + #expect(model.rawValue == "gpt-4o") + } + + @Test("Init from rawValue stores raw value") + func initFromRawValue() { + let model = OpenAIModelID(rawValue: "gpt-4-turbo") + #expect(model.rawValue == "gpt-4-turbo") + } + + @Test("String literal initialization") + func stringLiteralInit() { + let model: OpenAIModelID = "gpt-3.5-turbo" + #expect(model.rawValue == "gpt-3.5-turbo") + } + + // MARK: - Provider Type + + @Test("All OpenAI model IDs have .openAI provider type") + func providerTypeIsOpenAI() { + #expect(OpenAIModelID.gpt4o.provider == .openAI) + #expect(OpenAIModelID.gpt4oMini.provider == .openAI) + #expect(OpenAIModelID("custom-model").provider == .openAI) + } + + // MARK: - Display Name + + @Test("Display name for simple model returns raw value") + func displayNameSimple() { + #expect(OpenAIModelID.gpt4o.displayName == "gpt-4o") + } + + @Test("Display name extracts model from provider-prefixed format") + func displayNameProviderPrefixed() { + let model = OpenAIModelID("openai/gpt-4-turbo") + #expect(model.displayName == "gpt-4-turbo") + } + + @Test("Display name for Ollama format preserves tag") + func displayNameOllamaFormat() { + let model = OpenAIModelID("llama3.2:3b") + #expect(model.displayName == "llama3.2:3b") + } + + // MARK: - Description + + @Test("Description includes OpenAI-Compatible prefix") + func descriptionFormat() { + let model = OpenAIModelID.gpt4o + #expect(model.description == "[OpenAI-Compatible] gpt-4o") + } + + // MARK: - Static OpenAI Models + + @Test("GPT-4 series static models have correct raw values") + func gpt4SeriesRawValues() { + #expect(OpenAIModelID.gpt4o.rawValue == "gpt-4o") + #expect(OpenAIModelID.gpt4oMini.rawValue == "gpt-4o-mini") + #expect(OpenAIModelID.gpt4Turbo.rawValue == "gpt-4-turbo") + #expect(OpenAIModelID.gpt4.rawValue == "gpt-4") + } + + @Test("GPT-3.5 static model has correct raw value") + func gpt35RawValue() { + #expect(OpenAIModelID.gpt35Turbo.rawValue == "gpt-3.5-turbo") + } + + @Test("Reasoning models have correct raw values") + func reasoningModelsRawValues() { + #expect(OpenAIModelID.o1.rawValue == "o1") + #expect(OpenAIModelID.o1Mini.rawValue == "o1-mini") + #expect(OpenAIModelID.o3Mini.rawValue == "o3-mini") + } + + @Test("Embedding models have correct raw values") + func embeddingModelsRawValues() { + #expect(OpenAIModelID.textEmbedding3Small.rawValue == "text-embedding-3-small") + #expect(OpenAIModelID.textEmbedding3Large.rawValue == "text-embedding-3-large") + #expect(OpenAIModelID.textEmbeddingAda002.rawValue == "text-embedding-ada-002") + } + + @Test("Image models have correct raw values") + func imageModelsRawValues() { + #expect(OpenAIModelID.dallE3.rawValue == "dall-e-3") + #expect(OpenAIModelID.dallE2.rawValue == "dall-e-2") + } + + @Test("Audio models have correct raw values") + func audioModelsRawValues() { + #expect(OpenAIModelID.whisper1.rawValue == "whisper-1") + #expect(OpenAIModelID.tts1.rawValue == "tts-1") + #expect(OpenAIModelID.tts1HD.rawValue == "tts-1-hd") + } + + // MARK: - OpenRouter Helpers + + @Test("openRouter helper creates model ID from string") + func openRouterHelper() { + let model = OpenAIModelID.openRouter("anthropic/claude-3-opus") + #expect(model.rawValue == "anthropic/claude-3-opus") + } + + @Test("OpenRouter static models have correct raw values") + func openRouterStaticModels() { + #expect(OpenAIModelID.claudeOpus.rawValue == "anthropic/claude-3-opus") + #expect(OpenAIModelID.claudeSonnet.rawValue == "anthropic/claude-3-sonnet") + #expect(OpenAIModelID.claudeHaiku.rawValue == "anthropic/claude-3-haiku") + #expect(OpenAIModelID.geminiPro.rawValue == "google/gemini-pro") + #expect(OpenAIModelID.geminiPro15.rawValue == "google/gemini-pro-1.5") + #expect(OpenAIModelID.mixtral8x7B.rawValue == "mistralai/mixtral-8x7b-instruct") + #expect(OpenAIModelID.llama31B70B.rawValue == "meta-llama/llama-3.1-70b-instruct") + #expect(OpenAIModelID.llama31B8B.rawValue == "meta-llama/llama-3.1-8b-instruct") + } + + @Test("OpenRouter model display names extract model name from prefix") + func openRouterDisplayNames() { + #expect(OpenAIModelID.claudeOpus.displayName == "claude-3-opus") + #expect(OpenAIModelID.geminiPro.displayName == "gemini-pro") + } + + // MARK: - Ollama Helpers + + @Test("ollama helper creates model ID from string") + func ollamaHelper() { + let model = OpenAIModelID.ollama("llama3.2:3b") + #expect(model.rawValue == "llama3.2:3b") + } + + @Test("Ollama static models have correct raw values") + func ollamaStaticModels() { + #expect(OpenAIModelID.ollamaLlama32.rawValue == "llama3.2") + #expect(OpenAIModelID.ollamaLlama32B3B.rawValue == "llama3.2:3b") + #expect(OpenAIModelID.ollamaLlama32B1B.rawValue == "llama3.2:1b") + #expect(OpenAIModelID.ollamaMistral.rawValue == "mistral") + #expect(OpenAIModelID.ollamaCodeLlama.rawValue == "codellama") + #expect(OpenAIModelID.ollamaPhi3.rawValue == "phi3") + #expect(OpenAIModelID.ollamaGemma2.rawValue == "gemma2") + #expect(OpenAIModelID.ollamaQwen25.rawValue == "qwen2.5") + #expect(OpenAIModelID.ollamaDeepseekCoder.rawValue == "deepseek-coder") + #expect(OpenAIModelID.ollamaNomicEmbed.rawValue == "nomic-embed-text") + } + + // MARK: - Azure Helpers + + @Test("azure deployment helper creates model ID from string") + func azureHelper() { + let model = OpenAIModelID.azure(deployment: "my-gpt4-deployment") + #expect(model.rawValue == "my-gpt4-deployment") + } + + // MARK: - Equatable / Hashable + + @Test("Same raw values are equal") + func sameRawValuesEqual() { + let a = OpenAIModelID("gpt-4o") + let b = OpenAIModelID("gpt-4o") + #expect(a == b) + } + + @Test("Different raw values are not equal") + func differentRawValuesNotEqual() { + let a = OpenAIModelID("gpt-4o") + let b = OpenAIModelID("gpt-4") + #expect(a != b) + } + + @Test("Can be used in a Set") + func usableInSet() { + let set: Set = [.gpt4o, .gpt4oMini, .gpt4o] + #expect(set.count == 2) + } + + @Test("Can be used as dictionary key") + func usableAsDictKey() { + var dict: [OpenAIModelID: String] = [:] + dict[.gpt4o] = "latest" + dict[.gpt4] = "legacy" + #expect(dict[.gpt4o] == "latest") + #expect(dict[.gpt4] == "legacy") + } + + // MARK: - Codable + + @Test("Codable round-trip preserves raw value") + func codableRoundTrip() throws { + let original = OpenAIModelID.gpt4o + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(OpenAIModelID.self, from: data) + #expect(decoded == original) + #expect(decoded.rawValue == "gpt-4o") + } + + @Test("Decodes from plain string JSON") + func decodesFromPlainString() throws { + let json = Data(#""gpt-4-turbo""#.utf8) + let decoded = try JSONDecoder().decode(OpenAIModelID.self, from: json) + #expect(decoded.rawValue == "gpt-4-turbo") + } + + @Test("Encodes to plain string JSON") + func encodesToPlainString() throws { + let model = OpenAIModelID.gpt4oMini + let data = try JSONEncoder().encode(model) + let jsonString = String(data: data, encoding: .utf8) + #expect(jsonString == #""gpt-4o-mini""#) + } + + // MARK: - Sendable + + @Test("Sendable conformance compiles") + func sendableConformance() { + let model: Sendable = OpenAIModelID.gpt4o + #expect(model is OpenAIModelID) + } +} + +#endif // CONDUIT_TRAIT_OPENAI || CONDUIT_TRAIT_OPENROUTER diff --git a/Tests/ConduitTests/Providers/OpenAI/OpenRouterConfigTests.swift b/Tests/ConduitTests/Providers/OpenAI/OpenRouterConfigTests.swift new file mode 100644 index 0000000..5a04a44 --- /dev/null +++ b/Tests/ConduitTests/Providers/OpenAI/OpenRouterConfigTests.swift @@ -0,0 +1,431 @@ +// OpenRouterConfigTests.swift +// Conduit Tests +// +// Tests for OpenRouterRoutingConfig, OpenRouterProvider, and OpenRouterDataCollection. + +#if CONDUIT_TRAIT_OPENAI || CONDUIT_TRAIT_OPENROUTER +import Foundation +import Testing +@testable import Conduit + +@Suite("OpenRouterRoutingConfig Tests") +struct OpenRouterRoutingConfigTests { + + // MARK: - Default Initialization + + @Test("Default init has expected values") + func defaultInit() { + let config = OpenRouterRoutingConfig() + #expect(config.providers == nil) + #expect(config.fallbacks == true) + #expect(config.routeByLatency == false) + #expect(config.requireProvidersForJSON == false) + #expect(config.siteURL == nil) + #expect(config.appName == nil) + #expect(config.routeTag == nil) + #expect(config.dataCollection == nil) + } + + // MARK: - Custom Initialization + + @Test("Custom init preserves all values") + func customInit() { + let siteURL = URL(string: "https://myapp.com")! + let config = OpenRouterRoutingConfig( + providers: [.anthropic, .openai], + fallbacks: false, + routeByLatency: true, + requireProvidersForJSON: true, + siteURL: siteURL, + appName: "MyApp", + routeTag: "production", + dataCollection: .deny + ) + + #expect(config.providers == [.anthropic, .openai]) + #expect(config.fallbacks == false) + #expect(config.routeByLatency == true) + #expect(config.requireProvidersForJSON == true) + #expect(config.siteURL == siteURL) + #expect(config.appName == "MyApp") + #expect(config.routeTag == "production") + #expect(config.dataCollection == .deny) + } + + // MARK: - Static Presets + + @Test("Default preset matches default init") + func defaultPreset() { + let config = OpenRouterRoutingConfig.default + #expect(config.providers == nil) + #expect(config.fallbacks == true) + #expect(config.routeByLatency == false) + } + + @Test("preferOpenAI preset routes to OpenAI") + func preferOpenAIPreset() { + let config = OpenRouterRoutingConfig.preferOpenAI + #expect(config.providers == [.openai]) + #expect(config.fallbacks == true) + } + + @Test("preferAnthropic preset routes to Anthropic") + func preferAnthropicPreset() { + let config = OpenRouterRoutingConfig.preferAnthropic + #expect(config.providers == [.anthropic]) + #expect(config.fallbacks == true) + } + + @Test("fastestProvider preset enables latency routing") + func fastestProviderPreset() { + let config = OpenRouterRoutingConfig.fastestProvider + #expect(config.routeByLatency == true) + #expect(config.fallbacks == true) + } + + // MARK: - Header Generation + + @Test("headers returns empty dict when no site info") + func headersEmptyNoSiteInfo() { + let config = OpenRouterRoutingConfig() + let headers = config.headers() + #expect(headers.isEmpty) + } + + @Test("headers includes HTTP-Referer when siteURL set") + func headersIncludesReferer() { + let config = OpenRouterRoutingConfig(siteURL: URL(string: "https://myapp.com")!) + let headers = config.headers() + #expect(headers["HTTP-Referer"] == "https://myapp.com") + } + + @Test("headers includes X-Title when appName set") + func headersIncludesTitle() { + let config = OpenRouterRoutingConfig(appName: "TestApp") + let headers = config.headers() + #expect(headers["X-Title"] == "TestApp") + } + + @Test("headers includes both site URL and app name") + func headersIncludesBoth() { + let config = OpenRouterRoutingConfig( + siteURL: URL(string: "https://app.com")!, + appName: "MyApp" + ) + let headers = config.headers() + #expect(headers["HTTP-Referer"] == "https://app.com") + #expect(headers["X-Title"] == "MyApp") + } + + // MARK: - Provider Routing + + @Test("providerRouting returns nil for default config") + func providerRoutingNilForDefault() { + let config = OpenRouterRoutingConfig.default + #expect(config.providerRouting() == nil) + } + + @Test("providerRouting includes order when providers set") + func providerRoutingIncludesOrder() { + let config = OpenRouterRoutingConfig(providers: [.anthropic, .openai]) + let routing = config.providerRouting() + let order = routing?["order"] as? [String] + #expect(order == ["anthropic", "openai"]) + } + + @Test("providerRouting includes allow_fallbacks false when disabled") + func providerRoutingDisabledFallbacks() { + let config = OpenRouterRoutingConfig(fallbacks: false) + let routing = config.providerRouting() + #expect(routing?["allow_fallbacks"] as? Bool == false) + } + + @Test("providerRouting includes sort latency when enabled") + func providerRoutingSortLatency() { + let config = OpenRouterRoutingConfig(routeByLatency: true) + let routing = config.providerRouting() + #expect(routing?["sort"] as? String == "latency") + } + + @Test("providerRouting includes require_parameters when JSON required") + func providerRoutingRequireJSON() { + let config = OpenRouterRoutingConfig(requireProvidersForJSON: true) + let routing = config.providerRouting() + #expect(routing?["require_parameters"] as? Bool == true) + } + + @Test("providerRouting includes data_collection when set") + func providerRoutingDataCollection() { + let config = OpenRouterRoutingConfig(dataCollection: .deny) + let routing = config.providerRouting() + #expect(routing?["data_collection"] as? String == "deny") + } + + @Test("providerRouting uses legacy routeTag as data_collection when valid") + func providerRoutingLegacyRouteTag() { + let config = OpenRouterRoutingConfig(routeTag: "allow") + let routing = config.providerRouting() + #expect(routing?["data_collection"] as? String == "allow") + } + + @Test("providerRouting ignores invalid routeTag") + func providerRoutingIgnoresInvalidRouteTag() { + let config = OpenRouterRoutingConfig(routeTag: "custom-tag") + let routing = config.providerRouting() + // "custom-tag" is not a valid OpenRouterDataCollection, so it should not appear + #expect(routing?["data_collection"] == nil) + } + + @Test("providerRouting prefers dataCollection over routeTag") + func providerRoutingPrefersDataCollection() { + let config = OpenRouterRoutingConfig(routeTag: "allow", dataCollection: .deny) + let routing = config.providerRouting() + #expect(routing?["data_collection"] as? String == "deny") + } + + // MARK: - Fluent API + + @Test("Fluent providers returns updated copy") + func fluentProviders() { + let config = OpenRouterRoutingConfig.default + .providers([.google, .mistral]) + #expect(config.providers == [.google, .mistral]) + } + + @Test("Fluent fallbacks returns updated copy") + func fluentFallbacks() { + let config = OpenRouterRoutingConfig.default.fallbacks(false) + #expect(config.fallbacks == false) + } + + @Test("Fluent routeByLatency returns updated copy") + func fluentRouteByLatency() { + let config = OpenRouterRoutingConfig.default.routeByLatency(true) + #expect(config.routeByLatency == true) + } + + @Test("Fluent siteURL returns updated copy") + func fluentSiteURL() { + let url = URL(string: "https://test.com")! + let config = OpenRouterRoutingConfig.default.siteURL(url) + #expect(config.siteURL == url) + } + + @Test("Fluent appName returns updated copy") + func fluentAppName() { + let config = OpenRouterRoutingConfig.default.appName("TestApp") + #expect(config.appName == "TestApp") + } + + @Test("Fluent routeTag returns updated copy") + func fluentRouteTag() { + let config = OpenRouterRoutingConfig.default.routeTag("prod") + #expect(config.routeTag == "prod") + } + + @Test("Fluent dataCollection returns updated copy") + func fluentDataCollection() { + let config = OpenRouterRoutingConfig.default.dataCollection(.deny) + #expect(config.dataCollection == .deny) + } + + @Test("Fluent chaining works correctly") + func fluentChaining() { + let config = OpenRouterRoutingConfig.default + .providers([.anthropic]) + .fallbacks(false) + .routeByLatency(true) + .appName("ChainedApp") + .dataCollection(.allow) + + #expect(config.providers == [.anthropic]) + #expect(config.fallbacks == false) + #expect(config.routeByLatency == true) + #expect(config.appName == "ChainedApp") + #expect(config.dataCollection == .allow) + } + + // MARK: - Codable + + @Test("OpenRouterRoutingConfig round-trips through JSON") + func codableRoundTrip() throws { + let original = OpenRouterRoutingConfig( + providers: [.openai, .anthropic], + fallbacks: false, + routeByLatency: true, + requireProvidersForJSON: true, + siteURL: URL(string: "https://example.com"), + appName: "TestApp", + routeTag: "allow", + dataCollection: .deny + ) + + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(OpenRouterRoutingConfig.self, from: data) + + #expect(decoded.providers == original.providers) + #expect(decoded.fallbacks == original.fallbacks) + #expect(decoded.routeByLatency == original.routeByLatency) + #expect(decoded.requireProvidersForJSON == original.requireProvidersForJSON) + #expect(decoded.siteURL == original.siteURL) + #expect(decoded.appName == original.appName) + #expect(decoded.routeTag == original.routeTag) + #expect(decoded.dataCollection == original.dataCollection) + } + + @Test("Default preset round-trips through JSON") + func codableDefaultRoundTrip() throws { + let original = OpenRouterRoutingConfig.default + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(OpenRouterRoutingConfig.self, from: data) + #expect(decoded == original) + } + + // MARK: - Hashable / Equatable + + @Test("Equal configs are equal") + func equalConfigsEqual() { + let a = OpenRouterRoutingConfig.default + let b = OpenRouterRoutingConfig.default + #expect(a == b) + } + + @Test("Different configs are not equal") + func differentConfigsNotEqual() { + let a = OpenRouterRoutingConfig.preferOpenAI + let b = OpenRouterRoutingConfig.preferAnthropic + #expect(a != b) + } + + @Test("Sendable conformance compiles") + func sendableConformance() { + let config: Sendable = OpenRouterRoutingConfig.default + #expect(config is OpenRouterRoutingConfig) + } +} + +// MARK: - OpenRouterProvider Tests + +@Suite("OpenRouterProvider Tests") +struct OpenRouterProviderTests { + + @Test("All CaseIterable providers exist") + func allCasesExist() { + let allCases = OpenRouterProvider.allCases + #expect(allCases.contains(.openai)) + #expect(allCases.contains(.anthropic)) + #expect(allCases.contains(.google)) + #expect(allCases.contains(.googleAIStudio)) + #expect(allCases.contains(.together)) + #expect(allCases.contains(.fireworks)) + #expect(allCases.contains(.perplexity)) + #expect(allCases.contains(.mistral)) + #expect(allCases.contains(.groq)) + #expect(allCases.contains(.deepseek)) + #expect(allCases.contains(.cohere)) + #expect(allCases.contains(.ai21)) + #expect(allCases.contains(.bedrock)) + #expect(allCases.contains(.azure)) + #expect(allCases.count == 14) + } + + @Test("Raw values are display names") + func rawValuesAreDisplayNames() { + #expect(OpenRouterProvider.openai.rawValue == "OpenAI") + #expect(OpenRouterProvider.anthropic.rawValue == "Anthropic") + #expect(OpenRouterProvider.google.rawValue == "Google") + #expect(OpenRouterProvider.googleAIStudio.rawValue == "Google AI Studio") + #expect(OpenRouterProvider.together.rawValue == "Together") + #expect(OpenRouterProvider.fireworks.rawValue == "Fireworks") + #expect(OpenRouterProvider.perplexity.rawValue == "Perplexity") + #expect(OpenRouterProvider.mistral.rawValue == "Mistral") + #expect(OpenRouterProvider.groq.rawValue == "Groq") + #expect(OpenRouterProvider.deepseek.rawValue == "DeepSeek") + #expect(OpenRouterProvider.cohere.rawValue == "Cohere") + #expect(OpenRouterProvider.ai21.rawValue == "AI21") + #expect(OpenRouterProvider.bedrock.rawValue == "Amazon Bedrock") + #expect(OpenRouterProvider.azure.rawValue == "Azure") + } + + @Test("Slugs are lowercase API identifiers") + func slugsAreLowercase() { + #expect(OpenRouterProvider.openai.slug == "openai") + #expect(OpenRouterProvider.anthropic.slug == "anthropic") + #expect(OpenRouterProvider.google.slug == "google") + #expect(OpenRouterProvider.googleAIStudio.slug == "google-ai-studio") + #expect(OpenRouterProvider.together.slug == "together") + #expect(OpenRouterProvider.fireworks.slug == "fireworks") + #expect(OpenRouterProvider.perplexity.slug == "perplexity") + #expect(OpenRouterProvider.mistral.slug == "mistral") + #expect(OpenRouterProvider.groq.slug == "groq") + #expect(OpenRouterProvider.deepseek.slug == "deepseek") + #expect(OpenRouterProvider.cohere.slug == "cohere") + #expect(OpenRouterProvider.ai21.slug == "ai21") + #expect(OpenRouterProvider.bedrock.slug == "bedrock") + #expect(OpenRouterProvider.azure.slug == "azure") + } + + @Test("displayName matches rawValue") + func displayNameMatchesRawValue() { + for provider in OpenRouterProvider.allCases { + #expect(provider.displayName == provider.rawValue) + } + } + + @Test("Codable round-trip for all providers") + func codableRoundTrip() throws { + for provider in OpenRouterProvider.allCases { + let data = try JSONEncoder().encode(provider) + let decoded = try JSONDecoder().decode(OpenRouterProvider.self, from: data) + #expect(decoded == provider) + } + } + + @Test("Can be used in a Set") + func usableInSet() { + let set: Set = [.openai, .anthropic, .openai] + #expect(set.count == 2) + } +} + +// MARK: - OpenRouterDataCollection Tests + +@Suite("OpenRouterDataCollection Tests") +struct OpenRouterDataCollectionTests { + + @Test("allow has correct raw value") + func allowRawValue() { + #expect(OpenRouterDataCollection.allow.rawValue == "allow") + } + + @Test("deny has correct raw value") + func denyRawValue() { + #expect(OpenRouterDataCollection.deny.rawValue == "deny") + } + + @Test("CaseIterable includes both cases") + func caseIterable() { + let allCases = OpenRouterDataCollection.allCases + #expect(allCases.count == 2) + #expect(allCases.contains(.allow)) + #expect(allCases.contains(.deny)) + } + + @Test("Codable round-trip for both cases") + func codableRoundTrip() throws { + for policy in OpenRouterDataCollection.allCases { + let data = try JSONEncoder().encode(policy) + let decoded = try JSONDecoder().decode(OpenRouterDataCollection.self, from: data) + #expect(decoded == policy) + } + } + + @Test("Can initialize from valid raw values") + func initFromRawValue() { + #expect(OpenRouterDataCollection(rawValue: "allow") == .allow) + #expect(OpenRouterDataCollection(rawValue: "deny") == .deny) + #expect(OpenRouterDataCollection(rawValue: "invalid") == nil) + } +} + +#endif // CONDUIT_TRAIT_OPENAI || CONDUIT_TRAIT_OPENROUTER diff --git a/Tests/ConduitTests/Providers/OpenAI/RetryConfigurationTests.swift b/Tests/ConduitTests/Providers/OpenAI/RetryConfigurationTests.swift new file mode 100644 index 0000000..f3f4bbc --- /dev/null +++ b/Tests/ConduitTests/Providers/OpenAI/RetryConfigurationTests.swift @@ -0,0 +1,534 @@ +// RetryConfigurationTests.swift +// Conduit Tests +// +// Tests for RetryConfiguration, RetryStrategy, and RetryableErrorType. + +#if CONDUIT_TRAIT_OPENAI || CONDUIT_TRAIT_OPENROUTER +import Foundation +import Testing +@testable import Conduit + +// MARK: - RetryConfiguration Tests + +@Suite("RetryConfiguration Tests") +struct RetryConfigurationTests { + + // MARK: - Static Presets + + @Test("Default preset has expected values") + func defaultPreset() { + let config = RetryConfiguration.default + #expect(config.maxRetries == 3) + #expect(config.baseDelay == 1.0) + #expect(config.maxDelay == 30.0) + #expect(config.retryableStatusCodes == [408, 429, 500, 502, 503, 504]) + #expect(config.retryableErrors == [.timeout, .connectionLost, .serverError, .rateLimited]) + } + + @Test("Aggressive preset has expected values") + func aggressivePreset() { + let config = RetryConfiguration.aggressive + #expect(config.maxRetries == 5) + #expect(config.baseDelay == 0.5) + #expect(config.maxDelay == 15.0) + #expect(config.strategy == .exponentialWithJitter()) + } + + @Test("Conservative preset has expected values") + func conservativePreset() { + let config = RetryConfiguration.conservative + #expect(config.maxRetries == 2) + #expect(config.baseDelay == 2.0) + #expect(config.maxDelay == 60.0) + #expect(config.strategy == .exponentialBackoff()) + } + + @Test("None preset disables retries") + func nonePreset() { + let config = RetryConfiguration.none + #expect(config.maxRetries == 0) + } + + // MARK: - Init Clamping + + @Test("Negative maxRetries is clamped to zero") + func negativeMaxRetriesClamped() { + let config = RetryConfiguration(maxRetries: -5) + #expect(config.maxRetries == 0) + } + + @Test("Negative baseDelay is clamped to zero") + func negativeBaseDelayClamped() { + let config = RetryConfiguration(baseDelay: -2.0) + #expect(config.baseDelay == 0.0) + } + + @Test("maxDelay is clamped to at least baseDelay") + func maxDelayClampedToBaseDelay() { + let config = RetryConfiguration(baseDelay: 10.0, maxDelay: 5.0) + #expect(config.maxDelay == 10.0) + } + + // MARK: - Delay Calculation + + @Test("Delay for attempt 0 is always zero") + func delayForAttemptZero() { + let config = RetryConfiguration.default + #expect(config.delay(forAttempt: 0) == 0) + } + + @Test("Delay is capped at maxDelay") + func delayCappedAtMaxDelay() { + let config = RetryConfiguration( + maxRetries: 10, + baseDelay: 1.0, + maxDelay: 5.0, + strategy: .exponentialBackoff(multiplier: 10.0) + ) + let delay = config.delay(forAttempt: 5) + #expect(delay <= 5.0) + } + + @Test("Immediate strategy returns zero delay for all attempts") + func immediateStrategyZeroDelay() { + let config = RetryConfiguration(strategy: .immediate) + for attempt in 0...5 { + #expect(config.delay(forAttempt: attempt) == 0) + } + } + + @Test("Fixed strategy returns constant delay") + func fixedStrategyConstantDelay() { + let config = RetryConfiguration(strategy: .fixed(delay: 2.5), maxDelay: 100.0) + #expect(config.delay(forAttempt: 1) == 2.5) + #expect(config.delay(forAttempt: 2) == 2.5) + #expect(config.delay(forAttempt: 5) == 2.5) + } + + @Test("Exponential backoff doubles delay each attempt") + func exponentialBackoffDoublesDelay() { + let config = RetryConfiguration( + baseDelay: 1.0, + maxDelay: 1000.0, + strategy: .exponentialBackoff(multiplier: 2.0) + ) + #expect(config.delay(forAttempt: 1) == 1.0) // 1.0 * 2^0 = 1.0 + #expect(config.delay(forAttempt: 2) == 2.0) // 1.0 * 2^1 = 2.0 + #expect(config.delay(forAttempt: 3) == 4.0) // 1.0 * 2^2 = 4.0 + #expect(config.delay(forAttempt: 4) == 8.0) // 1.0 * 2^3 = 8.0 + } + + @Test("Exponential with jitter returns non-negative delay close to base exponential") + func exponentialWithJitterReturnsNonNegative() { + let config = RetryConfiguration( + baseDelay: 1.0, + maxDelay: 1000.0, + strategy: .exponentialWithJitter(multiplier: 2.0, jitterFactor: 0.1) + ) + for _ in 0..<20 { + let delay = config.delay(forAttempt: 1) + #expect(delay >= 0) + // With jitterFactor 0.1, delay for attempt 1 should be baseDelay +/- 10% + #expect(delay >= 0.9) + #expect(delay <= 1.1) + } + } + + // MARK: - shouldRetry(statusCode:) + + @Test("Default config retries 429 status code") + func shouldRetry429() { + let config = RetryConfiguration.default + #expect(config.shouldRetry(statusCode: 429)) + } + + @Test("Default config retries 500 status code") + func shouldRetry500() { + let config = RetryConfiguration.default + #expect(config.shouldRetry(statusCode: 500)) + } + + @Test("Default config retries 502 status code") + func shouldRetry502() { + let config = RetryConfiguration.default + #expect(config.shouldRetry(statusCode: 502)) + } + + @Test("Default config retries 503 status code") + func shouldRetry503() { + let config = RetryConfiguration.default + #expect(config.shouldRetry(statusCode: 503)) + } + + @Test("Default config retries 504 status code") + func shouldRetry504() { + let config = RetryConfiguration.default + #expect(config.shouldRetry(statusCode: 504)) + } + + @Test("Default config retries 408 status code") + func shouldRetry408() { + let config = RetryConfiguration.default + #expect(config.shouldRetry(statusCode: 408)) + } + + @Test("Default config does not retry 200") + func shouldNotRetry200() { + let config = RetryConfiguration.default + #expect(!config.shouldRetry(statusCode: 200)) + } + + @Test("Default config does not retry 401") + func shouldNotRetry401() { + let config = RetryConfiguration.default + #expect(!config.shouldRetry(statusCode: 401)) + } + + @Test("Default config does not retry 403") + func shouldNotRetry403() { + let config = RetryConfiguration.default + #expect(!config.shouldRetry(statusCode: 403)) + } + + @Test("Default config does not retry 404") + func shouldNotRetry404() { + let config = RetryConfiguration.default + #expect(!config.shouldRetry(statusCode: 404)) + } + + // MARK: - shouldRetry(errorType:) + + @Test("Default config retries timeout errors") + func shouldRetryTimeout() { + let config = RetryConfiguration.default + #expect(config.shouldRetry(errorType: .timeout)) + } + + @Test("Default config retries connectionLost errors") + func shouldRetryConnectionLost() { + let config = RetryConfiguration.default + #expect(config.shouldRetry(errorType: .connectionLost)) + } + + @Test("Default config retries serverError errors") + func shouldRetryServerError() { + let config = RetryConfiguration.default + #expect(config.shouldRetry(errorType: .serverError)) + } + + @Test("Default config retries rateLimited errors") + func shouldRetryRateLimited() { + let config = RetryConfiguration.default + #expect(config.shouldRetry(errorType: .rateLimited)) + } + + @Test("Default config does not retry dnsFailure errors") + func shouldNotRetryDnsFailure() { + let config = RetryConfiguration.default + #expect(!config.shouldRetry(errorType: .dnsFailure)) + } + + @Test("Default config does not retry sslError errors") + func shouldNotRetrySslError() { + let config = RetryConfiguration.default + #expect(!config.shouldRetry(errorType: .sslError)) + } + + // MARK: - Fluent API + + @Test("Fluent maxRetries returns updated copy") + func fluentMaxRetries() { + let config = RetryConfiguration.default.maxRetries(7) + #expect(config.maxRetries == 7) + #expect(RetryConfiguration.default.maxRetries == 3) + } + + @Test("Fluent maxRetries clamps negative to zero") + func fluentMaxRetriesClamps() { + let config = RetryConfiguration.default.maxRetries(-1) + #expect(config.maxRetries == 0) + } + + @Test("Fluent baseDelay returns updated copy") + func fluentBaseDelay() { + let config = RetryConfiguration.default.baseDelay(5.0) + #expect(config.baseDelay == 5.0) + } + + @Test("Fluent baseDelay clamps negative to zero") + func fluentBaseDelayClampsNegative() { + let config = RetryConfiguration.default.baseDelay(-1.0) + #expect(config.baseDelay == 0.0) + } + + @Test("Fluent maxDelay returns updated copy and clamps to baseDelay") + func fluentMaxDelay() { + let config = RetryConfiguration.default.baseDelay(10.0).maxDelay(5.0) + #expect(config.maxDelay == 10.0) + } + + @Test("Fluent strategy returns updated copy") + func fluentStrategy() { + let config = RetryConfiguration.default.strategy(.immediate) + #expect(config.strategy == .immediate) + } + + @Test("Fluent retryableStatusCodes returns updated copy") + func fluentRetryableStatusCodes() { + let codes: Set = [500] + let config = RetryConfiguration.default.retryableStatusCodes(codes) + #expect(config.retryableStatusCodes == [500]) + } + + @Test("Fluent retryableErrors returns updated copy") + func fluentRetryableErrors() { + let errors: Set = [.timeout] + let config = RetryConfiguration.default.retryableErrors(errors) + #expect(config.retryableErrors == [.timeout]) + } + + @Test("Fluent disabled sets maxRetries to zero") + func fluentDisabled() { + let config = RetryConfiguration.aggressive.disabled() + #expect(config.maxRetries == 0) + } + + // MARK: - Codable + + @Test("RetryConfiguration round-trips through JSON encoding and decoding") + func retryConfigurationCodableRoundTrip() throws { + let original = RetryConfiguration( + maxRetries: 4, + baseDelay: 2.0, + maxDelay: 20.0, + strategy: .fixed(delay: 3.0), + retryableStatusCodes: [429, 503], + retryableErrors: [.timeout, .rateLimited] + ) + + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(RetryConfiguration.self, from: data) + + #expect(decoded.maxRetries == original.maxRetries) + #expect(decoded.baseDelay == original.baseDelay) + #expect(decoded.maxDelay == original.maxDelay) + #expect(decoded.strategy == original.strategy) + #expect(decoded.retryableStatusCodes == original.retryableStatusCodes) + #expect(decoded.retryableErrors == original.retryableErrors) + } + + @Test("RetryConfiguration default preset round-trips through JSON") + func defaultPresetCodableRoundTrip() throws { + let original = RetryConfiguration.default + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(RetryConfiguration.self, from: data) + #expect(decoded == original) + } + + @Test("RetryConfiguration aggressive preset round-trips through JSON") + func aggressivePresetCodableRoundTrip() throws { + let original = RetryConfiguration.aggressive + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(RetryConfiguration.self, from: data) + #expect(decoded == original) + } + + @Test("RetryConfiguration none preset round-trips through JSON") + func nonePresetCodableRoundTrip() throws { + let original = RetryConfiguration.none + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(RetryConfiguration.self, from: data) + #expect(decoded == original) + } + + // MARK: - Hashable / Equatable + + @Test("Equal configurations are equal") + func equalConfigurationsAreEqual() { + let a = RetryConfiguration.default + let b = RetryConfiguration.default + #expect(a == b) + } + + @Test("Different configurations are not equal") + func differentConfigurationsAreNotEqual() { + let a = RetryConfiguration.default + let b = RetryConfiguration.aggressive + #expect(a != b) + } + + @Test("Sendable conformance compiles") + func sendableConformance() { + let config: Sendable = RetryConfiguration.default + #expect(config is RetryConfiguration) + } +} + +// MARK: - RetryStrategy Tests + +@Suite("RetryStrategy Tests") +struct RetryStrategyTests { + + @Test("Immediate strategy always returns 0") + func immediateAlwaysZero() { + let strategy = RetryStrategy.immediate + #expect(strategy.delay(forAttempt: 1, baseDelay: 5.0) == 0) + #expect(strategy.delay(forAttempt: 10, baseDelay: 5.0) == 0) + } + + @Test("Fixed strategy returns constant value regardless of attempt") + func fixedReturnsConstant() { + let strategy = RetryStrategy.fixed(delay: 3.5) + #expect(strategy.delay(forAttempt: 1, baseDelay: 1.0) == 3.5) + #expect(strategy.delay(forAttempt: 5, baseDelay: 1.0) == 3.5) + } + + @Test("Exponential backoff with multiplier 2 doubles per attempt") + func exponentialBackoffMultiplier2() { + let strategy = RetryStrategy.exponentialBackoff(multiplier: 2.0) + #expect(strategy.delay(forAttempt: 1, baseDelay: 1.0) == 1.0) + #expect(strategy.delay(forAttempt: 2, baseDelay: 1.0) == 2.0) + #expect(strategy.delay(forAttempt: 3, baseDelay: 1.0) == 4.0) + } + + @Test("Exponential backoff with multiplier 3 triples per attempt") + func exponentialBackoffMultiplier3() { + let strategy = RetryStrategy.exponentialBackoff(multiplier: 3.0) + #expect(strategy.delay(forAttempt: 1, baseDelay: 1.0) == 1.0) + #expect(strategy.delay(forAttempt: 2, baseDelay: 1.0) == 3.0) + #expect(strategy.delay(forAttempt: 3, baseDelay: 1.0) == 9.0) + } + + @Test("Exponential with jitter returns non-negative values") + func exponentialWithJitterNonNegative() { + let strategy = RetryStrategy.exponentialWithJitter(multiplier: 2.0, jitterFactor: 0.5) + for _ in 0..<50 { + let delay = strategy.delay(forAttempt: 1, baseDelay: 1.0) + #expect(delay >= 0) + } + } + + @Test("Strategies with different cases are not equal") + func differentStrategiesNotEqual() { + #expect(RetryStrategy.immediate != RetryStrategy.fixed(delay: 0)) + #expect(RetryStrategy.exponentialBackoff() != RetryStrategy.exponentialWithJitter()) + } + + @Test("RetryStrategy Codable round-trip for each case") + func strategyCodableRoundTrip() throws { + let strategies: [RetryStrategy] = [ + .immediate, + .fixed(delay: 2.5), + .exponentialBackoff(multiplier: 3.0), + .exponentialWithJitter(multiplier: 2.0, jitterFactor: 0.2) + ] + + for original in strategies { + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(RetryStrategy.self, from: data) + #expect(decoded == original) + } + } +} + +// MARK: - RetryableErrorType Tests + +@Suite("RetryableErrorType Tests") +struct RetryableErrorTypeTests { + + @Test("CaseIterable includes all cases") + func caseIterableIncludesAllCases() { + let allCases = RetryableErrorType.allCases + #expect(allCases.contains(.timeout)) + #expect(allCases.contains(.connectionLost)) + #expect(allCases.contains(.serverError)) + #expect(allCases.contains(.rateLimited)) + #expect(allCases.contains(.dnsFailure)) + #expect(allCases.contains(.sslError)) + #expect(allCases.count == 6) + } + + @Test("URLError.timedOut maps to .timeout") + func timedOutMapsToTimeout() { + let urlError = URLError(.timedOut) + #expect(RetryableErrorType.from(urlError) == .timeout) + } + + @Test("URLError.networkConnectionLost maps to .connectionLost") + func networkConnectionLostMaps() { + let urlError = URLError(.networkConnectionLost) + #expect(RetryableErrorType.from(urlError) == .connectionLost) + } + + @Test("URLError.notConnectedToInternet maps to .connectionLost") + func notConnectedToInternetMaps() { + let urlError = URLError(.notConnectedToInternet) + #expect(RetryableErrorType.from(urlError) == .connectionLost) + } + + @Test("URLError.dnsLookupFailed maps to .dnsFailure") + func dnsLookupFailedMaps() { + let urlError = URLError(.dnsLookupFailed) + #expect(RetryableErrorType.from(urlError) == .dnsFailure) + } + + @Test("URLError.cannotFindHost maps to .dnsFailure") + func cannotFindHostMaps() { + let urlError = URLError(.cannotFindHost) + #expect(RetryableErrorType.from(urlError) == .dnsFailure) + } + + @Test("URLError.secureConnectionFailed maps to .sslError") + func secureConnectionFailedMaps() { + let urlError = URLError(.secureConnectionFailed) + #expect(RetryableErrorType.from(urlError) == .sslError) + } + + @Test("SSL certificate errors are not retryable") + func sslCertificateErrorsNotRetryable() { + let untrusted = URLError(.serverCertificateUntrusted) + #expect(RetryableErrorType.from(untrusted) == nil) + + let badDate = URLError(.serverCertificateHasBadDate) + #expect(RetryableErrorType.from(badDate) == nil) + + let notYetValid = URLError(.serverCertificateNotYetValid) + #expect(RetryableErrorType.from(notYetValid) == nil) + + let unknownRoot = URLError(.serverCertificateHasUnknownRoot) + #expect(RetryableErrorType.from(unknownRoot) == nil) + + let clientRejected = URLError(.clientCertificateRejected) + #expect(RetryableErrorType.from(clientRejected) == nil) + + let clientRequired = URLError(.clientCertificateRequired) + #expect(RetryableErrorType.from(clientRequired) == nil) + } + + @Test("Unrecognized URLError returns nil") + func unrecognizedURLErrorReturnsNil() { + let urlError = URLError(.cancelled) + #expect(RetryableErrorType.from(urlError) == nil) + } + + @Test("RetryableErrorType raw values are stable strings") + func rawValuesAreStableStrings() { + #expect(RetryableErrorType.timeout.rawValue == "timeout") + #expect(RetryableErrorType.connectionLost.rawValue == "connectionLost") + #expect(RetryableErrorType.serverError.rawValue == "serverError") + #expect(RetryableErrorType.rateLimited.rawValue == "rateLimited") + #expect(RetryableErrorType.dnsFailure.rawValue == "dnsFailure") + #expect(RetryableErrorType.sslError.rawValue == "sslError") + } + + @Test("RetryableErrorType Codable round-trip") + func codableRoundTrip() throws { + for errorType in RetryableErrorType.allCases { + let data = try JSONEncoder().encode(errorType) + let decoded = try JSONDecoder().decode(RetryableErrorType.self, from: data) + #expect(decoded == errorType) + } + } +} + +#endif // CONDUIT_TRAIT_OPENAI || CONDUIT_TRAIT_OPENROUTER