Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class TraceRecord implements Serializable {
transient private String executorName
transient private CloudMachineInfo machineInfo
transient private ContainerMeta containerMeta
transient private Integer numSpotInterruptions

/**
* Convert the given value to a string
Expand Down Expand Up @@ -611,6 +612,14 @@ class TraceRecord implements Serializable {
this.machineInfo = value
}

/**
 * @return The spot interruption count stored on this record, or {@code null}
 *      when none has been set. The backing field is declared {@code transient}.
 */
Integer getNumSpotInterruptions() {
    return this.numSpotInterruptions
}

/**
 * Store the spot interruption count on this record.
 *
 * @param numSpotInterruptions The count to keep; {@code null} is stored as-is
 */
void setNumSpotInterruptions(Integer numSpotInterruptions) {
    // direct field access makes the target unambiguous next to the same-named parameter
    this.@numSpotInterruptions = numSpotInterruptions
}

/**
 * @return The container metadata held by this record (a {@code transient} field)
 */
ContainerMeta getContainerMeta() {
    return this.containerMeta
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,4 +344,29 @@ class TraceRecordTest extends Specification {
then:
thrown(NoSuchFileException)
}

// Checks the numSpotInterruptions getter/setter pair and that the value is
// not carried across a serialize/deserialize round trip.
def 'should manage numSpotInterruptions and not persist it across serialization'() {
given:
def rec = new TraceRecord()

expect:
// a freshly created record exposes no count, through the getter ...
rec.getNumSpotInterruptions() == null
and:
// ... and through property-style access
rec.numSpotInterruptions == null

when:
rec.setNumSpotInterruptions(3)

then:
// the value just set is readable through both access styles
rec.getNumSpotInterruptions() == 3
rec.numSpotInterruptions == 3

when:
// round-trip the record through its own serialize/deserialize methods
def buf = rec.serialize()
def rec2 = TraceRecord.deserialize(buf)

then:
// the deserialized copy does not carry the count set on the original
rec2.getNumSpotInterruptions() == null
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -916,10 +916,48 @@ class AwsBatchTaskHandler extends TaskHandler implements BatchHandler<String,Job
return machineInfo
}

/**
 * Work out how many attempts of the given AWS Batch job ended because the
 * underlying EC2 host was terminated, which is how spot reclamations show up
 * in the attempt status reason.
 *
 * @param jobId The AWS Batch Job Id
 * @return The number of attempts whose status reason starts with {@code Host EC2};
 *      {@code 0} when the job has no attempts; {@code null} when the job id is
 *      missing, the task is not completed, the job cannot be described or the
 *      lookup fails
 */
protected Integer getNumSpotInterruptions(String jobId) {
    // nothing to report until there is a job id and the task has completed
    if( !(jobId && isCompleted()) )
        return null

    try {
        final job = describeJob(jobId)
        if( !job )
            return null

        final attempts = job.attempts()
        if( !attempts )
            return 0

        // AWS Batch reports spot interruptions as "Host EC2 (instance i-xxx) terminated.";
        // matching on the prefix keeps the check independent of the instance id
        final int count = attempts.count { it.statusReason()?.startsWith('Host EC2') } as int
        log.trace "Job $jobId had $count spot interruptions"
        return count
    }
    catch (Exception e) {
        // best effort: a failed lookup must not break trace record creation
        log.debug "[AWS BATCH] Unable to count spot interruptions for job=$jobId - ${e.message}"
        return null
    }
}

/**
 * Build the trace record for this task, extending the record produced by the
 * parent handler with AWS Batch specific details.
 *
 * @return The parent trace record with {@code native_id} set to the AWS Batch
 *      job id, plus the machine info and the spot interruption count (the
 *      latter is {@code null} when it cannot be determined)
 */
@Override
TraceRecord getTraceRecord() {
    def result = super.getTraceRecord()
    result.put('native_id', jobId)
    result.machineInfo = getMachineInfo()
    result.numSpotInterruptions = getNumSpotInterruptions(jobId)
    return result
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import nextflow.script.ProcessConfig
import nextflow.util.CacheHelper
import nextflow.util.MemoryUnit
import software.amazon.awssdk.services.batch.BatchClient
import software.amazon.awssdk.services.batch.model.AttemptDetail
import software.amazon.awssdk.services.batch.model.DescribeJobDefinitionsRequest
import software.amazon.awssdk.services.batch.model.DescribeJobDefinitionsResponse
import software.amazon.awssdk.services.batch.model.DescribeJobsRequest
Expand Down Expand Up @@ -905,7 +906,7 @@ class AwsBatchTaskHandlerTest extends Specification {
when:
def trace = handler.getTraceRecord()
then:
1 * handler.isCompleted() >> false
2 * handler.isCompleted() >> false
1 * handler.getMachineInfo() >> new CloudMachineInfo('x1.large', 'us-east-1b', PriceModel.spot)

and:
Expand All @@ -916,6 +917,48 @@ class AwsBatchTaskHandlerTest extends Specification {
trace.machineInfo.priceModel == PriceModel.spot
}

// Checks that the trace record of a completed job carries the executor details
// and a spot interruption count derived from the job attempts.
def 'should create the trace record when job is completed with spot interruptions' () {
given:
// minimal executor/processor/task graph required to build a trace record
def exec = Mock(Executor) { getName() >> 'awsbatch' }
def processor = Mock(TaskProcessor)
processor.getExecutor() >> exec
processor.getName() >> 'foo'
processor.getConfig() >> new ProcessConfig(Mock(BaseScript))
def task = Mock(TaskRun)
task.getProcessor() >> processor
task.getConfig() >> GroovyMock(TaskConfig)
def proxy = Mock(AwsBatchProxy)
// spy on the real handler so selected methods can be stubbed below
def handler = Spy(AwsBatchTaskHandler)
handler.@client = proxy
handler.task = task
handler.@jobId = 'xyz-123'
handler.setStatus(TaskStatus.COMPLETED)

// two attempts: the first has a status reason starting with 'Host EC2',
// the second has an unrelated status reason
def attempt1 = GroovyMock(AttemptDetail)
def attempt2 = GroovyMock(AttemptDetail)
attempt1.statusReason() >> 'Host EC2 (instance i-123) terminated.'
attempt1.container() >> null
attempt2.statusReason() >> 'Essential container in task exited'
attempt2.container() >> null
def job = JobDetail.builder().attempts([attempt1, attempt2]).build()

// Stub BEFORE calling the method
handler.isCompleted() >> true
handler.getMachineInfo() >> new CloudMachineInfo('x1.large', 'us-east-1b', PriceModel.spot)
handler.describeJob('xyz-123') >> job

when:
def trace = handler.getTraceRecord()

then:
trace.native_id == 'xyz-123'
trace.executorName == 'awsbatch'
trace.machineInfo.type == 'x1.large'
trace.machineInfo.zone == 'us-east-1b'
trace.machineInfo.priceModel == PriceModel.spot
// only attempt1 matches the 'Host EC2' prefix
trace.numSpotInterruptions == 1
}

def 'should render submit command' () {
given:
def executor = Spy(AwsBatchExecutor)
Expand Down Expand Up @@ -1138,4 +1181,56 @@ class AwsBatchTaskHandlerTest extends Specification {
1 | true | true | 1
2 | true | true | 2
}

// Checks that a completed job yields a count of zero both when it has no
// attempts and when none of its attempts has a 'Host EC2' status reason.
def 'should return zero spot interruptions when no attempts or non-spot terminations exist'() {
given:
def handler = Spy(AwsBatchTaskHandler)
// neither status reason starts with 'Host EC2'
def attempt1 = GroovyMock(AttemptDetail) {
statusReason() >> 'Essential container in task exited'
}
def attempt2 = GroovyMock(AttemptDetail) {
statusReason() >> 'Some other reason'
}

when:
def resultNoAttempts = handler.getNumSpotInterruptions('job-123')
then:
// case 1: the described job has an empty attempts list
1 * handler.isCompleted() >> true
1 * handler.describeJob('job-123') >> JobDetail.builder().attempts([]).build()
resultNoAttempts == 0

when:
def resultNonSpot = handler.getNumSpotInterruptions('job-456')
then:
// case 2: attempts exist but none matches the spot interruption prefix
1 * handler.isCompleted() >> true
1 * handler.describeJob('job-456') >> JobDetail.builder().attempts([attempt1, attempt2]).build()
resultNonSpot == 0
}

// Checks the three paths on which getNumSpotInterruptions returns null:
// task not completed, missing job id, and a failing job lookup.
def 'should return null when job cannot be processed'() {
given:
def handler = Spy(AwsBatchTaskHandler)

when:
def resultNotCompleted = handler.getNumSpotInterruptions('job-123')
then:
// a task that is not completed must not trigger a describeJob call
1 * handler.isCompleted() >> false
0 * handler.describeJob(_)
resultNotCompleted == null

when:
def resultNullJobId = handler.getNumSpotInterruptions(null)
then:
// a null job id returns before the completion status is even consulted
0 * handler.isCompleted()
0 * handler.describeJob(_)
resultNullJobId == null

when:
def resultException = handler.getNumSpotInterruptions('job-789')
then:
// an exception raised by describeJob results in null rather than propagating
1 * handler.isCompleted() >> true
1 * handler.describeJob('job-789') >> { throw new RuntimeException("Error") }
resultException == null
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -635,12 +635,52 @@ class GoogleBatchTaskHandler extends TaskHandler implements FusionAwareTask {
return machineInfo
}

/**
 * Work out how many status events of this Google Batch task report a task
 * execution that exited with the spot preemption exit code.
 *
 * @param jobId The Google Batch Job Id
 * @return The number of status events carrying a task execution with exit code
 *      {@code 50001}; {@code 0} when the status has no events; {@code null} when
 *      the job id or task id is missing, the task is not completed, no status
 *      is available or the lookup fails
 */
protected Integer getNumSpotInterruptions(String jobId) {
    // nothing to report until both ids are known and the task has completed
    if( !(jobId && taskId && isCompleted()) )
        return null

    try {
        final status = client.getTaskStatus(jobId, taskId)
        if( !status )
            return null

        // a valid status without events means no interruptions occurred
        final events = status.statusEventsList
        if( !events )
            return 0

        // Google Batch uses exit code 50001 for spot preemption
        return events.count { it.hasTaskExecution() && it.taskExecution.exitCode == 50001 } as Integer
    }
    catch (Exception e) {
        // best effort: a failed lookup must not break trace record creation
        log.debug "[GOOGLE BATCH] Unable to count spot interruptions for job=$jobId task=$taskId - ${e.message}"
        return null
    }
}

/**
 * Build the trace record for this task, extending the record produced by the
 * parent handler with Google Batch specific details.
 *
 * @return The parent trace record with the machine info and the spot
 *      interruption count; {@code native_id} is set to {@code jobId/taskId/uid}
 *      only when both the job id and the uid are available
 */
@Override
TraceRecord getTraceRecord() {
    final record = super.getTraceRecord()
    if( jobId && uid ) {
        record.put('native_id', "$jobId/$taskId/$uid")
    }
    record.machineInfo = getMachineInfo()
    record.numSpotInterruptions = getNumSpotInterruptions(jobId)
    return record
}

Expand Down
Loading
Loading