bulk-cdk: fix bugs surfaced by CAT tests (#44543)
## What
I tried to get the CAT tests running on airbyte-enterprise today. A few failures surfaced bugs in the Bulk CDK: it doesn't always emit STATE or TRACE ERROR messages when required during a READ.

## How
Emit TRACE ERROR messages if the configured streams are bad.
Emit at least one STATE message for each stream with an input state.
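A condensed sketch of the TRACE ERROR half of the fix, mirroring the `StateManagerFactory` change in the diff below (the `OutputConsumer` and protocol model calls are the ones already used there; the helper function name is just for illustration):

```kotlin
import io.airbyte.cdk.output.OutputConsumer
import io.airbyte.protocol.models.v0.AirbyteErrorTraceMessage
import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair
import io.airbyte.protocol.models.v0.StreamDescriptor

// Emit a config-error TRACE message for a configured stream that can't be read,
// so the platform surfaces a config error instead of silently dropping the stream.
fun emitBadStreamTraceError(outputConsumer: OutputConsumer, name: String, namespace: String?) {
    val streamDescriptor = StreamDescriptor().withName(name).withNamespace(namespace)
    val streamLabel: String = AirbyteStreamNameNamespacePair(name, namespace).toString()
    outputConsumer.accept(
        AirbyteErrorTraceMessage()
            .withStreamDescriptor(streamDescriptor)
            .withFailureType(AirbyteErrorTraceMessage.FailureType.CONFIG_ERROR)
            .withMessage("Stream '$streamLabel' not found or not accessible in source.")
    )
}
```

The STATE half works by checkpointing before the stream COMPLETE status goes out: `FeedReader` now calls `maybeCheckpoint()` when it runs out of partitions, and `StateManager` treats a non-null initial state as pending, so a stream that starts from an input state and reads zero records still emits at least one STATE message.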

## Review guide
Commit by commit

## User Impact
None

## Can this PR be safely reverted and rolled back?
- [x] YES 💚
- [ ] NO ❌
postamar authored Aug 22, 2024
1 parent 5d1b1cd commit ff6b1bb
Showing 11 changed files with 281 additions and 12 deletions.
@@ -41,6 +41,9 @@ class FeedReader(
log.info {
"no more partitions to read for '${feed.label}' in round $partitionsCreatorID"
}
// Publish a checkpoint if applicable.
maybeCheckpoint()
// Publish stream completion.
emitStreamStatus(AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE)
break
}
@@ -279,18 +282,25 @@ class FeedReader(
}
} finally {
// Publish a checkpoint if applicable.
val stateMessages: List<AirbyteStateMessage> = root.stateManager.checkpoint()
if (stateMessages.isNotEmpty()) {
log.info { "checkpoint of ${stateMessages.size} state message(s)" }
stateMessages.forEach(root.outputConsumer::accept)
}
maybeCheckpoint()
}
}
}

private suspend fun ctx(nameSuffix: String): CoroutineContext =
coroutineContext + ThreadRenamingCoroutineName("${feed.label}-$nameSuffix") + Dispatchers.IO

private fun maybeCheckpoint() {
val stateMessages: List<AirbyteStateMessage> = root.stateManager.checkpoint()
if (stateMessages.isEmpty()) {
return
}
log.info { "checkpoint of ${stateMessages.size} state message(s)" }
for (stateMessage in stateMessages) {
root.outputConsumer.accept(stateMessage)
}
}

private fun emitStreamStatus(status: AirbyteStreamStatusTraceMessage.AirbyteStreamStatus) {
if (feed is Stream) {
root.outputConsumer.accept(
@@ -96,9 +96,19 @@ class StateManager(
initialState: OpaqueStateValue?,
private val isCheckpointUnique: Boolean = true,
) : StateManagerScopedToFeed {
private var current: OpaqueStateValue? = initialState
private var pending: OpaqueStateValue? = initialState
private var pendingNumRecords: Long = 0L
private var current: OpaqueStateValue?
private var pending: OpaqueStateValue?
private var isPending: Boolean
private var pendingNumRecords: Long

init {
synchronized(this) {
current = initialState
pending = initialState
isPending = initialState != null
pendingNumRecords = 0L
}
}

override fun current(): OpaqueStateValue? = synchronized(this) { current }

@@ -108,13 +118,14 @@
) {
synchronized(this) {
pending = state
isPending = true
pendingNumRecords += numRecords
}
}

fun swap(): Pair<OpaqueStateValue?, Long>? {
synchronized(this) {
if (isCheckpointUnique && pendingNumRecords == 0L && pending == current) {
if (isCheckpointUnique && !isPending) {
return null
}
val returnValue: Pair<OpaqueStateValue?, Long> = pending to pendingNumRecords
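Aside: replacing the `pending == current` comparison with an explicit `isPending` flag is what makes the zero-record case checkpoint. A stripped-down illustration of the bookkeeping (hypothetical names, no synchronization, not the actual `StateManager` API):

```kotlin
// Simplified model: with the old check (pendingNumRecords == 0L && pending == current),
// a feed seeded with an input state that reads no records never checkpoints,
// because pending still equals current. The flag records that a state value
// exists at all, so swap() hands it back exactly once.
class PendingStateSketch(initialState: String?) {
    private var current: String? = initialState
    private var pending: String? = initialState
    private var isPending: Boolean = initialState != null
    private var pendingNumRecords: Long = 0L

    fun current(): String? = current

    fun set(state: String?, numRecords: Long) {
        pending = state
        isPending = true
        pendingNumRecords += numRecords
    }

    fun swap(): Pair<String?, Long>? {
        if (!isPending) return null
        val result = pending to pendingNumRecords
        current = pending
        isPending = false
        pendingNumRecords = 0L
        return result
    }
}

fun main() {
    val sketch = PendingStateSketch(initialState = """{"cursor":"2024-04-30"}""")
    println(sketch.swap()) // ({"cursor":"2024-04-30"}, 0) -- emitted once despite zero records
    println(sketch.swap()) // null -- nothing left to checkpoint
}
```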
@@ -22,11 +22,15 @@ import io.airbyte.cdk.output.FieldTypeMismatch
import io.airbyte.cdk.output.InvalidIncrementalSyncMode
import io.airbyte.cdk.output.InvalidPrimaryKey
import io.airbyte.cdk.output.MultipleStreamsFound
import io.airbyte.cdk.output.OutputConsumer
import io.airbyte.cdk.output.StreamHasNoFields
import io.airbyte.cdk.output.StreamNotFound
import io.airbyte.protocol.models.v0.AirbyteErrorTraceMessage
import io.airbyte.protocol.models.v0.AirbyteStream
import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair
import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog
import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream
import io.airbyte.protocol.models.v0.StreamDescriptor
import io.airbyte.protocol.models.v0.SyncMode
import jakarta.inject.Singleton

Expand All @@ -37,6 +41,7 @@ import jakarta.inject.Singleton
@Singleton
class StateManagerFactory(
val metadataQuerierFactory: MetadataQuerier.Factory<SourceConfiguration>,
val outputConsumer: OutputConsumer,
val handler: CatalogValidationFailureHandler,
) {
/** Generates a [StateManager] instance based on the provided inputs. */
@@ -101,14 +106,28 @@
val jsonSchemaProperties: JsonNode = stream.jsonSchema["properties"]
val name: String = stream.name!!
val namespace: String? = stream.namespace
val streamDescriptor = StreamDescriptor().withName(name).withNamespace(namespace)
val streamLabel: String = AirbyteStreamNameNamespacePair(name, namespace).toString()
when (metadataQuerier.streamNames(namespace).filter { it == name }.size) {
0 -> {
handler.accept(StreamNotFound(name, namespace))
outputConsumer.accept(
AirbyteErrorTraceMessage()
.withStreamDescriptor(streamDescriptor)
.withFailureType(AirbyteErrorTraceMessage.FailureType.CONFIG_ERROR)
.withMessage("Stream '$streamLabel' not found or not accessible in source.")
)
return null
}
1 -> Unit
else -> {
handler.accept(MultipleStreamsFound(name, namespace))
outputConsumer.accept(
AirbyteErrorTraceMessage()
.withStreamDescriptor(streamDescriptor)
.withFailureType(AirbyteErrorTraceMessage.FailureType.CONFIG_ERROR)
.withMessage("Multiple streams '$streamLabel' found in source.")
)
return null
}
}
@@ -153,6 +172,12 @@
}
if (streamFields.isEmpty()) {
handler.accept(StreamHasNoFields(name, namespace))
outputConsumer.accept(
AirbyteErrorTraceMessage()
.withStreamDescriptor(streamDescriptor)
.withFailureType(AirbyteErrorTraceMessage.FailureType.CONFIG_ERROR)
.withMessage("Stream '$streamLabel' has no accessible fields.")
)
return null
}

@@ -95,10 +95,11 @@ data object SyncsTestFixture {
connectionSupplier: Supplier<Connection>,
prelude: (Connection) -> Unit,
configuredCatalog: ConfiguredAirbyteCatalog,
initialState: List<AirbyteStateMessage> = listOf(),
vararg afterRead: AfterRead,
) {
connectionSupplier.get().use(prelude)
var state: List<AirbyteStateMessage> = listOf()
var state: List<AirbyteStateMessage> = initialState
for (step in afterRead) {
val readOutput: BufferingOutputConsumer =
CliRunner.runSource("read", configPojo, configuredCatalog, state)
@@ -113,13 +114,15 @@
connectionSupplier: Supplier<Connection>,
prelude: (Connection) -> Unit,
configuredCatalogResource: String,
initialStateResource: String?,
vararg afterRead: AfterRead,
) {
testReads(
configPojo,
connectionSupplier,
prelude,
configuredCatalogFromResource(configuredCatalogResource),
initialStateFromResource(initialStateResource),
*afterRead,
)
}
@@ -169,6 +172,14 @@
ConfiguredAirbyteCatalog::class.java,
)

fun initialStateFromResource(initialStateResource: String?): List<AirbyteStateMessage> =
if (initialStateResource == null) {
listOf()
} else {
val initialStateJson: String = ResourceUtils.readResource(initialStateResource)
ValidatedJsonUtils.parseList(AirbyteStateMessage::class.java, initialStateJson)
}

interface AfterRead {
fun validate(actualOutput: BufferingOutputConsumer)

@@ -182,7 +193,7 @@
object : AfterRead {
override fun validate(actualOutput: BufferingOutputConsumer) {
// State messages are timing-sensitive and therefore non-deterministic.
// Ignore them.
// Ignore them for now.
val expectedWithoutStates: List<AirbyteMessage> =
expectedMessages
.filterNot { it.type == AirbyteMessage.Type.STATE }
@@ -193,6 +204,19 @@
.filterNot { it.type == AirbyteMessage.Type.STATE }
.sortedBy { Jsons.writeValueAsString(it) }
Assertions.assertIterableEquals(expectedWithoutStates, actualWithoutStates)
// Check for state message counts (null if no state messages).
val expectedCount: Double? =
expectedMessages
.filter { it.type == AirbyteMessage.Type.STATE }
.mapNotNull { it.state?.sourceStats?.recordCount }
.reduceRightOrNull { a: Double, b: Double -> a + b }
val actualCount: Double? =
actualOutput
.messages()
.filter { it.type == AirbyteMessage.Type.STATE }
.mapNotNull { it.state?.sourceStats?.recordCount }
.reduceRightOrNull { a: Double, b: Double -> a + b }
Assertions.assertEquals(expectedCount, actualCount)
}

override fun update(connection: Connection) {
@@ -146,6 +146,50 @@ class H2SourceIntegrationTest {
}
}

@Test
fun testReadStreamStateTooFarAhead() {
H2TestFixture().use { h2: H2TestFixture ->
val configPojo =
H2SourceConfigurationJsonObject().apply {
port = h2.port
database = h2.database
resumablePreferred = true
}
SyncsTestFixture.testReads(
configPojo,
h2::createConnection,
Companion::prelude,
"h2source/incremental-only-catalog.json",
"h2source/state-too-far-ahead.json",
SyncsTestFixture.AfterRead.Companion.fromExpectedMessages(
"h2source/expected-messages-stream-too-far-ahead.json",
),
)
}
}

@Test
fun testReadBadCatalog() {
H2TestFixture().use { h2: H2TestFixture ->
val configPojo =
H2SourceConfigurationJsonObject().apply {
port = h2.port
database = h2.database
resumablePreferred = true
}
SyncsTestFixture.testReads(
configPojo,
h2::createConnection,
Companion::prelude,
"h2source/bad-catalog.json",
initialStateResource = null,
SyncsTestFixture.AfterRead.Companion.fromExpectedMessages(
"h2source/expected-messages-stream-bad-catalog.json",
),
)
}
}

companion object {
@JvmStatic
fun prelude(connection: Connection) {
@@ -0,0 +1,27 @@
{
"streams": [
{
"stream": {
"name": "FOO",
"json_schema": {
"type": "object",
"properties": {
"BAR": {
"type": "string"
}
}
},
"supported_sync_modes": ["full_refresh", "incremental"],
"source_defined_cursor": false,
"default_cursor_field": [],
"source_defined_primary_key": [],
"is_resumable": false,
"namespace": "PUBLIC"
},
"sync_mode": "incremental",
"cursor_field": ["BAR"],
"destination_sync_mode": "overwrite",
"primary_key": []
}
]
}
@@ -0,0 +1,24 @@
[
{
"type": "LOG",
"log": {
"level": "WARN",
"message": "StreamNotFound(streamName=FOO, streamNamespace=PUBLIC)"
}
},
{
"type": "TRACE",
"trace": {
"type": "ERROR",
"emitted_at": 3.1336416e12,
"error": {
"stream_descriptor": {
"name": "FOO",
"namespace": "PUBLIC"
},
"message": "Stream 'PUBLIC_FOO' not found or not accessible in source.",
"failure_type": "config_error"
}
}
}
]
@@ -0,0 +1,51 @@
[
{
"type": "TRACE",
"trace": {
"type": "STREAM_STATUS",
"emitted_at": 3.1336416e12,
"stream_status": {
"stream_descriptor": {
"name": "EVENTS",
"namespace": "PUBLIC"
},
"status": "STARTED"
}
}
},
{
"type": "STATE",
"state": {
"type": "STREAM",
"stream": {
"stream_descriptor": {
"name": "EVENTS",
"namespace": "PUBLIC"
},
"stream_state": {
"primary_key": {},
"cursors": {
"TS": "2024-04-30T00:00:00.000000-04:00"
}
}
},
"sourceStats": {
"recordCount": 0.0
}
}
},
{
"type": "TRACE",
"trace": {
"type": "STREAM_STATUS",
"emitted_at": 3.1336416e12,
"stream_status": {
"stream_descriptor": {
"name": "EVENTS",
"namespace": "PUBLIC"
},
"status": "COMPLETE"
}
}
}
]
@@ -57,7 +57,7 @@
}
},
"sourceStats": {
"recordCount": 2.0
"recordCount": 1.0
}
}
},