Skip to content

Commit

Permalink
fix: Retry handling for streamed request bodies (#852)
Browse files Browse the repository at this point in the history
* fix: Rewind a stream before retrying it, don't retry nonseekable streams

* Fix tests so they run without delay

* Fix request count logic to include 1st request
  • Loading branch information
jbelkins authored Nov 8, 2024
1 parent 2519d38 commit abd012b
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 5 deletions.
26 changes: 26 additions & 0 deletions Sources/ClientRuntime/Orchestrator/Orchestrator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,13 @@ public struct Orchestrator<
// If we can't get errorInfo, we definitely can't retry
guard let errorInfo = retryErrorInfoProvider(error) else { return }

// If the body is a nonseekable stream, we also can't retry
do {
guard try readyBodyForRetry(request: copiedRequest) else { return }
} catch {
return
}

// When refreshing fails it throws, indicating we're done retrying
do {
try await strategy.refreshRetryTokenForRetry(tokenToRenew: token, errorInfo: errorInfo)
Expand All @@ -277,6 +284,25 @@ public struct Orchestrator<
}
}

/// Readies the body for retry, and indicates whether the request body may be safely used in a retry.
/// - Parameter request: The request to be retried.
/// - Returns: `true` if the body of the request is safe to retry, `false` otherwise. In general, a request body is retriable if it is not a stream, or
/// if the stream is seekable and successfully seeks to the start position / offset zero.
private func readyBodyForRetry(request: RequestType) throws -> Bool {
switch request.body {
case .stream(let stream):
guard stream.isSeekable else { return false }
do {
try stream.seek(toOffset: 0)
return true
} catch {
return false
}
case .data, .noStream:
return true
}
}

private func attempt(context: InterceptorContextType, attemptCount: Int) async {
// If anything in here fails, the attempt short-circuits and we go to modifyBeforeAttemptCompletion,
// with the thrown error in context.result
Expand Down
70 changes: 65 additions & 5 deletions Tests/ClientRuntimeTests/OrchestratorTests/OrchestratorTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import SmithyRetriesAPI
import SmithyRetries
@_spi(SmithyReadWrite) import SmithyJSON
@_spi(SmithyReadWrite) import SmithyReadWrite
import SmithyStreams

class OrchestratorTests: XCTestCase {
struct TestInput {
Expand Down Expand Up @@ -167,20 +168,23 @@ class OrchestratorTests: XCTestCase {
}

class TraceExecuteRequest: ExecuteRequest {
var succeedAfter: Int
let succeedAfter: Int
var trace: Trace

private(set) var requestCount = 0

init(succeedAfter: Int = 0, trace: Trace) {
self.succeedAfter = succeedAfter
self.trace = trace
}

public func execute(request: HTTPRequest, attributes: Context) async throws -> HTTPResponse {
trace.append("executeRequest")
if succeedAfter <= 0 {
if succeedAfter - requestCount <= 0 {
requestCount += 1
return HTTPResponse(body: request.body, statusCode: .ok)
} else {
succeedAfter -= 1
requestCount += 1
return HTTPResponse(body: request.body, statusCode: .internalServerError)
}
}
Expand Down Expand Up @@ -233,7 +237,7 @@ class OrchestratorTests: XCTestCase {
throw try UnknownHTTPServiceError.makeError(baseError: baseError)
}
})
.retryStrategy(DefaultRetryStrategy(options: RetryStrategyOptions(backoffStrategy: ExponentialBackoffStrategy())))
.retryStrategy(DefaultRetryStrategy(options: RetryStrategyOptions(backoffStrategy: ImmediateBackoffStrategy())))
.retryErrorInfoProvider({ e in
trace.append("errorInfo")
return DefaultRetryErrorInfoProvider.errorInfo(for: e)
Expand Down Expand Up @@ -530,7 +534,7 @@ class OrchestratorTests: XCTestCase {
let initialTokenTrace = Trace()
let initialToken = await asyncResult {
return try await self.traceOrchestrator(trace: initialTokenTrace)
.retryStrategy(ThrowingRetryStrategy(options: RetryStrategyOptions(backoffStrategy: ExponentialBackoffStrategy())))
.retryStrategy(ThrowingRetryStrategy(options: RetryStrategyOptions(backoffStrategy: ImmediateBackoffStrategy())))
.build()
.execute(input: TestInput(foo: ""))
}
Expand Down Expand Up @@ -1315,4 +1319,60 @@ class OrchestratorTests: XCTestCase {
}
}
}

/// Used in retry tests to perform the next retry without waiting, so that tests complete without delay.
private struct ImmediateBackoffStrategy: RetryBackoffStrategy {
func computeNextBackoffDelay(attempt: Int) -> TimeInterval { 0.0 }
}

func test_retry_retriesDataBody() async throws {
let input = TestInput(foo: "bar")
let trace = Trace()
let executeRequest = TraceExecuteRequest(succeedAfter: 2, trace: trace)
let orchestrator = traceOrchestrator(trace: trace)
.retryStrategy(DefaultRetryStrategy(options: RetryStrategyOptions(backoffStrategy: ImmediateBackoffStrategy())))
.serialize({ (input: TestInput, builder: HTTPRequestBuilder, context) in
builder.withBody(.data(Data("\"\(input.foo)\"".utf8)))
})
.executeRequest(executeRequest)
let result = await asyncResult {
return try await orchestrator.build().execute(input: input)
}
XCTAssertNoThrow(try result.get())
XCTAssertEqual(executeRequest.requestCount, 3)
}

func test_retry_doesntRetryNonSeekableStreamBody() async throws {
let input = TestInput(foo: "bar")
let trace = Trace()
let executeRequest = TraceExecuteRequest(succeedAfter: 2, trace: trace)
let orchestrator = traceOrchestrator(trace: trace)
.retryStrategy(DefaultRetryStrategy(options: RetryStrategyOptions(backoffStrategy: ImmediateBackoffStrategy())))
.serialize({ (input: TestInput, builder: HTTPRequestBuilder, context) in
builder.withBody(.stream(BufferedStream(data: Data("\"\(input.foo)\"".utf8), isClosed: true)))
})
.executeRequest(executeRequest)
let result = await asyncResult {
return try await orchestrator.build().execute(input: input)
}
XCTAssertThrowsError(try result.get())
XCTAssertEqual(executeRequest.requestCount, 1)
}

func test_retry_nonSeekableStreamBodySucceeds() async throws {
let input = TestInput(foo: "bar")
let trace = Trace()
let executeRequest = TraceExecuteRequest(succeedAfter: 0, trace: trace)
let orchestrator = traceOrchestrator(trace: trace)
.retryStrategy(DefaultRetryStrategy(options: RetryStrategyOptions(backoffStrategy: ImmediateBackoffStrategy())))
.serialize({ (input: TestInput, builder: HTTPRequestBuilder, context) in
builder.withBody(.stream(BufferedStream(data: Data("\"\(input.foo)\"".utf8), isClosed: true)))
})
.executeRequest(executeRequest)
let result = await asyncResult {
return try await orchestrator.build().execute(input: input)
}
XCTAssertNoThrow(try result.get())
XCTAssertEqual(executeRequest.requestCount, 1)
}
}

0 comments on commit abd012b

Please sign in to comment.