Skip to content

Commit

Permalink
change signature of orphanedThreadFilter (#45469)
Browse files Browse the repository at this point in the history
- Introduced a new `OrphanedThreadInfo` data class to encapsulate thread information to avoid data races when getting thread-specific information (like stacktrace, or threadLocal values)
- Updated thread filtering and logging to use `OrphanedThreadInfo`
- Enhanced logging for timeout tasks in `LoggingInvocationInterceptor`
- fixed race condition in `AirbyteTraceMessageUtilityTest`
- improved test logging to add the running tests for the destination containers
- Modified `TestContext` to use a default value for `CURRENT_TEST_NAME`
  • Loading branch information
stephane-airbyte authored Sep 17, 2024
1 parent 6469111 commit 6d74db7
Show file tree
Hide file tree
Showing 9 changed files with 126 additions and 58 deletions.
1 change: 1 addition & 0 deletions airbyte-cdk/java/airbyte-cdk/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ corresponds to that version.

| Version | Date | Pull Request | Subject |
|:-----------|:-----------| :----------------------------------------------------------- |:---------------------------------------------------------------------------------------------------------------------------------------------------------------|
| 0.45.0 | 2024-09-16 | [\#45469](https://github.com/airbytehq/airbyte/pull/45469) | Fix some race conditions, improve thread filtering, improve test logging |
| 0.44.22 | 2024-09-10 | [\#45368](https://github.com/airbytehq/airbyte/pull/45368) | Remove excessive debezium logging |
| 0.44.21 | 2024-09-04 | [\#45143](https://github.com/airbytehq/airbyte/pull/45143) | S3-destination: don't overwrite existing files, skip those file indexes instead |
| 0.44.20 | 2024-08-30 | [\#44933](https://github.com/airbytehq/airbyte/pull/44933) | Avro/Parquet destinations: handle `{}` schemas inside objects/arrays |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,28 @@ constructor(
recordTransform: CheckedFunction<ResultSet, T, SQLException>
): Stream<T> {
val connection = dataSource.connection
return JdbcDatabase.Companion.toUnsafeStream<T>(
statementCreator.apply(connection).executeQuery(),
recordTransform
)
.onClose(
Runnable {
try {
LOGGER.info { "closing connection" }
connection.close()
} catch (e: SQLException) {
throw RuntimeException(e)
try {
return JdbcDatabase.Companion.toUnsafeStream<T>(
statementCreator.apply(connection).executeQuery(),
recordTransform
)
.onClose(
Runnable {
try {
LOGGER.info { "closing connection" }
connection.close()
} catch (e: SQLException) {
throw RuntimeException(e)
}
}
}
)
)
} catch (e: Throwable) {
// this is ugly because we usually don't close the connection here.
// We expect the caller to close the returned stream, which will call the onClose
// but if the executeQuery threw an exception, we still need to close the connection
LOGGER.warn(e) { "closing connection because of an Exception" }
connection.close()
throw e
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,45 @@ internal constructor(
}
}

/**
 * Snapshot of a live thread taken while looking for orphaned threads: the [Thread] itself, the
 * [ThreadCreationInfo] stored for it, and the stack trace the thread had when the snapshot was
 * built.
 *
 * The constructor is private; obtain instances through [getForThread] or [getAll].
 */
data class OrphanedThreadInfo
private constructor(
    val thread: Thread,
    val threadCreationInfo: ThreadCreationInfo,
    val lastStackTrace: List<StackTraceElement>
) {
    /**
     * Builds a log-friendly description: thread name, the thread's state at the time of the
     * call, and the stack trace captured when this snapshot was created.
     */
    fun getLogString(): String {
        return String.format(
            "%s (%s)\n Thread stacktrace: %s",
            thread.name,
            thread.state,
            lastStackTrace.joinToString("\n at ")
        )
    }

    companion object {
        // ThreadLocal.get(Thread) is private. So we open it and keep a reference to the
        // opened method
        // NOTE(review): this private overload is a JDK implementation detail and is not present
        // on every JDK version — confirm the minimum supported JDK.
        private val getMethod: Method =
            ThreadLocal::class.java.getDeclaredMethod("get", Thread::class.java).also {
                it.isAccessible = true
            }

        /** Returns a snapshot for every running thread that has a [ThreadCreationInfo]. */
        fun getAll(): List<OrphanedThreadInfo> {
            return ThreadUtils.getAllThreads().mapNotNull { getForThread(it) }
        }

        /**
         * Builds a snapshot for [thread], or returns null when no [ThreadCreationInfo] is
         * stored for it.
         */
        fun getForThread(thread: Thread): OrphanedThreadInfo? {
            // Bail out before touching the stack: capturing a stack trace is comparatively
            // expensive and would be thrown away for threads without creation info.
            val creationInfo =
                (getMethod.invoke(threadCreationInfo, thread) as ThreadCreationInfo?)
                    ?: return null
            return OrphanedThreadInfo(thread, creationInfo, thread.stackTrace.asList())
        }
    }
}

class ThreadCreationInfo {
val stack: List<StackTraceElement> = Thread.currentThread().stackTrace.asList()
val time: Instant = Instant.now()
Expand All @@ -327,25 +366,13 @@ internal constructor(
companion object {
// Thread-local holding the creation info of each thread. The value handed to a child thread
// is always a fresh ThreadCreationInfo rather than the parent's one.
private val threadCreationInfo: InheritableThreadLocal<ThreadCreationInfo> =
    object : InheritableThreadLocal<ThreadCreationInfo>() {
        // parentValue is deliberately ignored and declared nullable: a non-null Kotlin
        // parameter would add a runtime null-check on a value supplied by the JDK.
        override fun childValue(parentValue: ThreadCreationInfo?): ThreadCreationInfo {
            return ThreadCreationInfo()
        }
    }

const val TYPE_AND_DEDUPE_THREAD_NAME: String = "type-and-dedupe"

// ThreadLocal.get(Thread) is private. So we open it and keep a reference to the
// opened method
private val getMethod: Method =
ThreadLocal::class.java.getDeclaredMethod("get", Thread::class.java).also {
it.isAccessible = true
}

@JvmStatic
fun getThreadCreationInfo(thread: Thread): ThreadCreationInfo? {
return getMethod.invoke(threadCreationInfo, thread) as ThreadCreationInfo?
}

/**
* Filters threads that should not be considered when looking for orphaned threads at
* shutdown of the integration runner.
Expand All @@ -355,11 +382,11 @@ internal constructor(
* active so long as the database connection pool is open.
*/
@VisibleForTesting
private val orphanedThreadPredicates: MutableList<(OrphanedThreadInfo) -> Boolean> =
    mutableListOf({ runningThreadInfo: OrphanedThreadInfo ->
        // Default predicate: keep only threads that are not the calling thread (compared by
        // name), are not daemons, and are not the type-and-dedupe thread.
        val candidate = runningThreadInfo.thread
        (candidate.name != Thread.currentThread().name &&
            !candidate.isDaemon &&
            TYPE_AND_DEDUPE_THREAD_NAME != candidate.name)
    })

const val INTERRUPT_THREAD_DELAY_MINUTES: Int = 1
Expand Down Expand Up @@ -402,12 +429,12 @@ internal constructor(
}

/**
 * Registers an additional predicate used when looking for orphaned threads.
 *
 * @param predicate receives the [OrphanedThreadInfo] snapshot of a running thread and returns
 * true when that thread should still be considered orphaned.
 */
@JvmStatic
fun addOrphanedThreadFilter(predicate: (OrphanedThreadInfo) -> (Boolean)) {
    orphanedThreadPredicates.add(predicate)
}

/**
 * Returns true when [threadInfo] satisfies every registered orphaned-thread predicate, i.e.
 * the thread is still considered orphaned.
 */
fun filterOrphanedThread(threadInfo: OrphanedThreadInfo): Boolean {
    return orphanedThreadPredicates.all { predicate -> predicate(threadInfo) }
}

/**
Expand Down Expand Up @@ -437,8 +464,8 @@ internal constructor(
) {
val currentThread = Thread.currentThread()

val runningThreads = ThreadUtils.getAllThreads().filter(::filterOrphanedThread)
if (runningThreads.isNotEmpty()) {
val runningThreadInfos = OrphanedThreadInfo.getAll().filter(::filterOrphanedThread)
if (runningThreadInfos.isNotEmpty()) {
LOGGER.warn {
"""
The main thread is exiting while children non-daemon threads from a connector are still active.
Expand All @@ -457,18 +484,15 @@ internal constructor(
.daemon(true)
.build()
)
for (runningThread in runningThreads) {
val str =
"Active non-daemon thread: " +
dumpThread(runningThread) +
"\ncreationStack=${getThreadCreationInfo(runningThread)}"
for (runningThreadInfo in runningThreadInfos) {
val str = "Active non-daemon thread info: ${runningThreadInfo.getLogString()}"
LOGGER.warn { str }
// even though the main thread is already shutting down, we still leave some
// chances to the children
// threads to close properly on their own.
// So, we schedule an interrupt hook after a fixed time delay instead...
scheduledExecutorService.schedule(
{ runningThread.interrupt() },
{ runningThreadInfo.thread.interrupt() },
interruptTimeDelay.toLong(),
interruptTimeUnit
)
Expand All @@ -493,6 +517,7 @@ internal constructor(
}

private fun dumpThread(thread: Thread): String {
OrphanedThreadInfo.getForThread(thread)
return String.format(
"%s (%s)\n Thread stacktrace: %s",
thread.name,
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version=0.44.23
version=0.45.0
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,18 @@ class AirbyteTraceMessageUtilityTest {
Mockito.mock(RuntimeException::class.java),
"this is a config error"
)
val outJson = Jsons.deserialize(outContent.toString(StandardCharsets.UTF_8))
assertJsonNodeIsTraceMessage(outJson)
val outCt = outContent.toString(StandardCharsets.UTF_8)
var outJson: JsonNode? = null
// because we are running tests in parallel, it's possible that another test is writing to
// stdout while we run this test, in which case we'd see their messages.
// we filter through the messages to find an error (hopefully ours)
for (line in outCt.split('\n')) {
if (line.contains("\"error\"")) {
outJson = Jsons.deserialize(line)
break
}
}
assertJsonNodeIsTraceMessage(outJson!!)
Assertions.assertEquals("config_error", outJson["trace"]["error"]["failure_type"].asText())
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,11 +442,12 @@ ${Jsons.serialize(message2)}""".toByteArray(
} catch (e: Exception) {
throw RuntimeException(e)
}
val runningThreads =
ThreadUtils.getAllThreads().filter(IntegrationRunner::filterOrphanedThread)
val runningThreadInfos =
IntegrationRunner.OrphanedThreadInfo.getAll()
.filter(IntegrationRunner::filterOrphanedThread)

// all threads should be interrupted
Assertions.assertEquals(listOf<Any>(), runningThreads)
Assertions.assertEquals(listOf<Any>(), runningThreadInfos)
Assertions.assertEquals(1, caughtExceptions.size)
}

Expand All @@ -468,11 +469,12 @@ ${Jsons.serialize(message2)}""".toByteArray(
throw RuntimeException(e)
}

val runningThreads =
ThreadUtils.getAllThreads().filter(IntegrationRunner::filterOrphanedThread)
val runningThreadInfos =
IntegrationRunner.OrphanedThreadInfo.getAll()
.filter(IntegrationRunner::filterOrphanedThread)

// a thread that refuses to be interrupted should remain
Assertions.assertEquals(1, runningThreads.size)
Assertions.assertEquals(1, runningThreadInfos.size)
Assertions.assertEquals(1, caughtExceptions.size)
Assertions.assertTrue(exitCalled.get())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import java.time.format.DateTimeParseException
import java.util.*
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException
import java.util.concurrent.atomic.AtomicLong
import java.util.regex.Pattern
import kotlin.concurrent.Volatile
import org.apache.commons.lang3.StringUtils
Expand Down Expand Up @@ -88,7 +89,7 @@ class LoggingInvocationInterceptor : InvocationInterceptor {
logLineSuffix = "execution of unknown intercepted call $methodName"
}
val currentThread = Thread.currentThread()
val timeoutTask = TimeoutInteruptor(currentThread)
val timeoutTask = TimeoutInteruptor(currentThread, logLineSuffix)
val start = Instant.now()
try {
val timeout = reflectiveInvocationContext?.let(::getTimeout)
Expand Down Expand Up @@ -116,6 +117,7 @@ class LoggingInvocationInterceptor : InvocationInterceptor {
val elapsedMs = Duration.between(start, Instant.now()).toMillis()
val t1: Throwable
if (timeoutTask.wasTriggered) {
LOGGER.info { "timeoutTask ${timeoutTask.id} was triggered." }
val timeoutAsString =
DurationFormatUtils.formatDurationWords(elapsedMs, true, true)
t1 =
Expand All @@ -126,6 +128,7 @@ class LoggingInvocationInterceptor : InvocationInterceptor {
)
t1.initCause(throwable)
} else {
LOGGER.info { "timeoutTask ${timeoutTask.id} was not triggered." }
t1 = throwable
}
var belowCurrentCall = false
Expand Down Expand Up @@ -157,25 +160,36 @@ class LoggingInvocationInterceptor : InvocationInterceptor {
throw t1
} finally {
timeoutTask.cancel()
TestContext.CURRENT_TEST_NAME.set(null)
TestContext.CURRENT_TEST_NAME.set(TestContext.NO_RUNNING_TEST)
}
}

/**
 * Timer task that interrupts [parentThread] when it fires.
 *
 * @param parentThread the thread to interrupt.
 * @param context free-form description of the intercepted call, included in the log line
 * emitted when the task fires.
 */
private class TimeoutInteruptor(
    private val parentThread: Thread,
    private val context: String
) : TimerTask() {
    /** Set to true by [run] just before the parent thread is interrupted. */
    @Volatile var wasTriggered: Boolean = false

    /** Identifier used to correlate the log lines of one task (the first one issued is 2). */
    val id = timerIdentifier.incrementAndGet()

    override fun run() {
        // Capture the stack trace once; it shows what the parent thread was doing when the
        // timeout fired.
        val parentStack = parentThread.stackTrace.asList()
        LOGGER.info(
            "TimeoutIterruptor $id interrupting running task on ${parentThread.name}: $context. " +
                "Current Stacktrace is $parentStack"
        )
        // The flag must be raised before the interrupt so that the interrupted thread can
        // tell the interruption came from this task.
        wasTriggered = true
        parentThread.interrupt()
    }

    override fun cancel(): Boolean {
        LOGGER.info("cancelling TimeoutIterruptor $id on ${parentThread.name}")
        return super.cancel()
    }

    companion object {
        private val timerIdentifier = AtomicLong(1)
    }
}

companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,11 @@
package io.airbyte.cdk.extensions

/** Per-thread bookkeeping about the test being executed. */
object TestContext {
    /** Value returned by [CURRENT_TEST_NAME] on a thread where no test name has been set. */
    const val NO_RUNNING_TEST = "NONE"

    /**
     * Name of the test associated with the current thread. Never null: a thread that has not
     * set a value (or has removed it) reads [NO_RUNNING_TEST].
     */
    val CURRENT_TEST_NAME: ThreadLocal<String> = ThreadLocal.withInitial { NO_RUNNING_TEST }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package io.airbyte.workers.internal

import com.google.common.base.Charsets
import com.google.common.base.Preconditions
import io.airbyte.cdk.extensions.TestContext
import io.airbyte.commons.io.IOs
import io.airbyte.commons.io.LineGobbler
import io.airbyte.commons.json.Jsons
Expand Down Expand Up @@ -182,7 +183,7 @@ constructor(

/**
 * Creates the MDC builder used for the destination container's log lines.
 *
 * The log prefix embeds the value of [TestContext.CURRENT_TEST_NAME] read on the calling thread
 * at the time this function is invoked.
 */
fun createContainerLogMdcBuilder(): MdcScope.Builder =
    MdcScope.Builder()
        .setLogPrefix("destination-${TestContext.CURRENT_TEST_NAME.get()}")
        .setPrefixColor(LoggingHelper.Color.YELLOW_BACKGROUND)
val IGNORED_EXIT_CODES: Set<Int> =
setOf(
Expand Down

0 comments on commit 6d74db7

Please sign in to comment.