Skip to content

Commit fbeaa59

Browse files
Load CDK: Bugfix: ObjectLoader eos can arrive out-of-order
1 parent f5a25f4 commit fbeaa59

File tree

12 files changed

+99
-63
lines changed

12 files changed

+99
-63
lines changed

airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/message/Batch.kt

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ import io.airbyte.cdk.load.command.DestinationStream
4949
* // etc...
5050
* ```
5151
*/
52+
@Deprecated("This serves the old-style StreamLoader interface. See BatchState")
5253
interface Batch {
5354
val groupId: String?
5455

airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/message/PartitionedQueue.kt

-26
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
package io.airbyte.cdk.load.message
66

77
import io.airbyte.cdk.load.util.CloseableCoroutine
8-
import kotlinx.coroutines.channels.Channel
98
import kotlinx.coroutines.flow.Flow
10-
import kotlinx.coroutines.flow.merge
119

1210
interface PartitionedQueue<T> : CloseableCoroutine {
1311
val partitions: Int
@@ -39,27 +37,3 @@ class StrictPartitionedQueue<T>(private val queues: Array<MessageQueue<T>>) : Pa
3937
queues.forEach { it.close() }
4038
}
4139
}
42-
43-
/**
44-
* This is for the use case where you want workers to grab work as it becomes available but still be
45-
* able to receive notifications that are guaranteed to be consumed by every partition.
46-
*/
47-
class SinglePartitionQueueWithMultiPartitionBroadcast<T>(
48-
private val sharedQueue: MessageQueue<T>,
49-
override val partitions: Int
50-
) : PartitionedQueue<T> {
51-
private val broadcastChannels =
52-
StrictPartitionedQueue(
53-
(0 until partitions).map { ChannelMessageQueue<T>(Channel(1)) }.toTypedArray()
54-
)
55-
56-
override fun consume(partition: Int): Flow<T> =
57-
merge(sharedQueue.consume(), broadcastChannels.consume(partition))
58-
override suspend fun publish(value: T, partition: Int) = sharedQueue.publish(value)
59-
override suspend fun broadcast(value: T) = broadcastChannels.broadcast(value)
60-
61-
override suspend fun close() {
62-
sharedQueue.close()
63-
broadcastChannels.close()
64-
}
65-
}

airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/state/StreamManager.kt

+4-7
Original file line numberDiff line numberDiff line change
@@ -372,13 +372,10 @@ class DefaultStreamManager(
372372
inputCount: Long
373373
) {
374374
val taskKey = TaskKey(taskName, part)
375-
if (taskCompletionCounts.containsKey(taskKey)) {
376-
// TODO: Promote this to a hard failure as part of the subsequent bugfix.
377-
log.warn {
378-
""""$taskKey received input after seeing end-of-stream
379-
(checkpointCounts=$checkpointCounts, inputCount=$inputCount, sawEosAt=${taskCompletionCounts[taskKey]})
380-
This indicates data was processed out of order and future bookkeeping might be corrupt. Failing hard."""
381-
}
375+
check(!taskCompletionCounts.containsKey(taskKey)) {
376+
""""$taskKey received input after seeing end-of-stream
377+
(checkpointCounts=$checkpointCounts, inputCount=$inputCount, sawEosAt=${taskCompletionCounts[taskKey]})
378+
This indicates data was processed out of order and future bookkeeping might be corrupt. Failing hard."""
382379
}
383380
val idToValue =
384381
namedCheckpointCounts.getOrPut(TaskKey(taskName, part) to state) { ConcurrentHashMap() }

airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/LoadPipelineStepTask.kt

+5
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ class LoadPipelineStepTask<S : AutoCloseable, K1 : WithStream, T, K2 : WithStrea
8989
try {
9090
when (input) {
9191
is PipelineMessage -> {
92+
if (stateStore.streamsEnded.contains(input.key.stream)) {
93+
throw IllegalStateException(
94+
"$taskName[$part] received input for complete stream ${input.key.stream}. This indicates data was processed out of order and future bookkeeping might be corrupt. Failing hard."
95+
)
96+
}
9297
// Get or create the accumulator state associated w/ the input key.
9398
val stateWithCounts =
9499
stateStore.stateWithCounts

airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/internal/LoadPipelineStepTaskUTest.kt

+43
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import io.airbyte.cdk.load.pipeline.FinalOutput
2121
import io.airbyte.cdk.load.pipeline.NoOutput
2222
import io.airbyte.cdk.load.pipeline.OutputPartitioner
2323
import io.airbyte.cdk.load.state.CheckpointId
24+
import io.airbyte.cdk.load.test.util.CoroutineTestUtils.Companion.assertDoesNotThrow
25+
import io.airbyte.cdk.load.test.util.CoroutineTestUtils.Companion.assertThrows
2426
import io.airbyte.cdk.load.util.setOnce
2527
import io.mockk.coEvery
2628
import io.mockk.coVerify
@@ -465,4 +467,45 @@ class LoadPipelineStepTaskUTest {
465467
fun `end-of-stream not forwarded if all tasks do not receive it`() = runEndOfStreamTest(false)
466468

467469
@Test fun `end-of-stream forwarded to all tasks if all receive it`() = runEndOfStreamTest(true)
470+
471+
@Test
472+
fun `records received after end-of-stream throws`() = runTest {
473+
val key1 = StreamKey(DestinationStream.Descriptor("namespace", "stream1"))
474+
val part = 66666
475+
476+
val task = createTask(part, batchAccumulatorWithUpdate)
477+
478+
coEvery { batchUpdateQueue.publish(any()) } returns Unit
479+
coEvery { inputFlow.collect(any()) } coAnswers
480+
{
481+
val collector = firstArg<FlowCollector<PipelineEvent<StreamKey, String>>>()
482+
483+
// Emit end-of-stream for stream1, end-of-stream for stream2
484+
collector.emit(endOfStreamEvent(key1))
485+
collector.emit(messageEvent(key1, "value", emptyMap()))
486+
}
487+
488+
assertThrows(IllegalStateException::class) { task.execute() }
489+
}
490+
491+
@Test
492+
fun `records received for stream A after end-of-stream B do not throw`() = runTest {
493+
val key1 = StreamKey(DestinationStream.Descriptor("namespace", "stream1"))
494+
val key2 = StreamKey(DestinationStream.Descriptor("namespace", "stream2"))
495+
val part = 66666
496+
497+
val task = createTask(part, batchAccumulatorWithUpdate)
498+
499+
coEvery { batchUpdateQueue.publish(any()) } returns Unit
500+
coEvery { batchAccumulatorWithUpdate.start(any(), any()) } returns Closeable()
501+
coEvery { batchAccumulatorWithUpdate.accept(any(), any()) } returns NoOutput(Closeable())
502+
coEvery { inputFlow.collect(any()) } coAnswers
503+
{
504+
val collector = firstArg<FlowCollector<PipelineEvent<StreamKey, String>>>()
505+
collector.emit(endOfStreamEvent(key2))
506+
collector.emit(messageEvent(key1, "value", emptyMap()))
507+
}
508+
509+
assertDoesNotThrow { task.execute() }
510+
}
468511
}

airbyte-cdk/bulk/toolkits/load-object-storage/src/main/kotlin/io/airbyte/cdk/load/pipline/object_storage/ObjectLoaderFormattedPartPartitioner.kt

+5-4
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@ package io.airbyte.cdk.load.pipline.object_storage
66

77
import io.airbyte.cdk.load.message.WithStream
88
import io.airbyte.cdk.load.pipeline.OutputPartitioner
9+
import kotlin.random.Random
910

1011
/**
11-
* The technically correct partitioning is round-robin, but since we use
12-
* [io.airbyte.cdk.load.message.SinglePartitionQueueWithMultiPartitionBroadcast], the partition is
13-
* immaterial, so it's simpler just to return 0 here.
12+
* Distribute the parts randomly across loaders. (Testing shows this is the most efficient pattern.)
1413
*/
1514
class ObjectLoaderFormattedPartPartitioner<K : WithStream, T> :
1615
OutputPartitioner<K, T, ObjectKey, ObjectLoaderPartFormatter.FormattedPart> {
16+
private val prng = Random(System.currentTimeMillis())
17+
1718
override fun getOutputKey(
1819
inputKey: K,
1920
output: ObjectLoaderPartFormatter.FormattedPart
@@ -22,6 +23,6 @@ class ObjectLoaderFormattedPartPartitioner<K : WithStream, T> :
2223
}
2324

2425
override fun getPart(outputKey: ObjectKey, numParts: Int): Int {
25-
return 0
26+
return prng.nextInt(numParts)
2627
}
2728
}

airbyte-cdk/bulk/toolkits/load-object-storage/src/main/kotlin/io/airbyte/cdk/load/pipline/object_storage/ObjectLoaderLoadedPartPartitioner.kt

+6-4
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,19 @@ import io.airbyte.cdk.load.message.WithStream
88
import io.airbyte.cdk.load.pipeline.OutputPartitioner
99

1010
/**
11-
* The technically correct partitioning is round-robin, but since we use
12-
* [io.airbyte.cdk.load.message.SinglePartitionQueueWithMultiPartitionBroadcast], the partition is
13-
* immaterial, so it's simpler just to return 0 here.
11+
* Distribute the loaded parts to the upload completers by key. (Distributing the completes
12+
* efficiently is not as important as not forcing the uploaders to coordinate with each other, so
13+
* instead we focus on operational simplicity: all fact-of-loaded part signals for the same key go
14+
* to the same upload completer.)
1415
*/
1516
class ObjectLoaderLoadedPartPartitioner<K : WithStream, T> :
1617
OutputPartitioner<K, T, ObjectKey, ObjectLoaderPartLoader.PartResult> {
18+
1719
override fun getOutputKey(inputKey: K, output: ObjectLoaderPartLoader.PartResult): ObjectKey {
1820
return ObjectKey(inputKey.stream, output.objectKey)
1921
}
2022

2123
override fun getPart(outputKey: ObjectKey, numParts: Int): Int {
22-
return 0
24+
return Math.floorMod(outputKey.objectKey.hashCode(), numParts)
2325
}
2426
}

airbyte-cdk/bulk/toolkits/load-object-storage/src/main/kotlin/io/airbyte/cdk/load/pipline/object_storage/ObjectLoaderPartLoader.kt

+2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ class ObjectLoaderPartLoader(
6060

6161
sealed interface PartResult : WithBatchState {
6262
val objectKey: String
63+
override val state: BatchState
64+
get() = BatchState.STAGED
6365
}
6466
data class LoadedPart(
6567
val upload: Deferred<StreamingUpload<*>>,

airbyte-cdk/bulk/toolkits/load-object-storage/src/main/kotlin/io/airbyte/cdk/load/pipline/object_storage/ObjectLoaderPartQueueFactory.kt

+27-14
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ package io.airbyte.cdk.load.pipline.object_storage
66

77
import io.airbyte.cdk.load.command.DestinationStream
88
import io.airbyte.cdk.load.message.ChannelMessageQueue
9+
import io.airbyte.cdk.load.message.PartitionedQueue
910
import io.airbyte.cdk.load.message.PipelineEvent
10-
import io.airbyte.cdk.load.message.SinglePartitionQueueWithMultiPartitionBroadcast
11+
import io.airbyte.cdk.load.message.StrictPartitionedQueue
1112
import io.airbyte.cdk.load.message.WithStream
1213
import io.airbyte.cdk.load.state.ReservationManager
1314
import io.airbyte.cdk.load.state.Reserved
@@ -63,8 +64,8 @@ class ObjectLoaderPartQueueFactory(
6364
fun objectLoaderClampedPartSizeBytes(
6465
@Named("objectLoaderMemoryReservation") reservation: Reserved<ObjectLoader>
6566
): Long {
66-
// 1 per worker, plus at least one in the queue.
67-
val maxNumPartsInMemory = loader.numPartWorkers + loader.numUploadWorkers + 1
67+
// 1 per worker, plus at least one per partition leading to the upload workers.
68+
val maxNumPartsInMemory = loader.numPartWorkers + (loader.numUploadWorkers * 2)
6869
val maxPartSizeBytes = reservation.bytesReserved / maxNumPartsInMemory
6970

7071
if (loader.partSizeBytes > maxPartSizeBytes) {
@@ -86,7 +87,7 @@ class ObjectLoaderPartQueueFactory(
8687
): Int {
8788
val maxNumParts = reservation.bytesReserved / clampedPartSizeBytes
8889
val numWorkersHoldingParts = loader.numPartWorkers + loader.numUploadWorkers
89-
val maxQueueCapacity = maxNumParts - numWorkersHoldingParts
90+
val maxQueueCapacity = (maxNumParts - numWorkersHoldingParts) / loader.numUploadWorkers
9091
// Our earlier calculations should ensure this is always at least 1, but
9192
// we'll clamp it to be safe.
9293
val capacity = maxQueueCapacity.toInt().coerceAtLeast(1)
@@ -106,11 +107,17 @@ class ObjectLoaderPartQueueFactory(
106107
@Requires(bean = ObjectLoader::class)
107108
fun objectLoaderPartQueue(
108109
@Named("objectLoaderPartQueueCapacity") capacity: Int
109-
): SinglePartitionQueueWithMultiPartitionBroadcast<
110-
PipelineEvent<ObjectKey, ObjectLoaderPartFormatter.FormattedPart>> {
111-
return SinglePartitionQueueWithMultiPartitionBroadcast(
112-
ChannelMessageQueue(Channel(capacity)),
113-
loader.numUploadWorkers
110+
): PartitionedQueue<PipelineEvent<ObjectKey, ObjectLoaderPartFormatter.FormattedPart>> {
111+
return StrictPartitionedQueue(
112+
(0 until loader.numUploadWorkers)
113+
.map {
114+
ChannelMessageQueue(
115+
Channel<PipelineEvent<ObjectKey, ObjectLoaderPartFormatter.FormattedPart>>(
116+
capacity
117+
)
118+
)
119+
}
120+
.toTypedArray()
114121
)
115122
}
116123

@@ -123,11 +130,17 @@ class ObjectLoaderPartQueueFactory(
123130
@Named("objectLoaderLoadedPartQueue")
124131
@Requires(bean = ObjectLoader::class)
125132
fun objectLoaderLoadedPartQueue():
126-
SinglePartitionQueueWithMultiPartitionBroadcast<
127-
PipelineEvent<ObjectKey, ObjectLoaderPartLoader.PartResult>> {
128-
return SinglePartitionQueueWithMultiPartitionBroadcast(
129-
ChannelMessageQueue(Channel(OBJECT_LOADER_MAX_ENQUEUED_COMPLETIONS)),
130-
1
133+
PartitionedQueue<PipelineEvent<ObjectKey, ObjectLoaderPartLoader.PartResult>> {
134+
return StrictPartitionedQueue(
135+
(0 until loader.numUploadCompleters)
136+
.map {
137+
ChannelMessageQueue(
138+
Channel<PipelineEvent<ObjectKey, ObjectLoaderPartLoader.PartResult>>(
139+
OBJECT_LOADER_MAX_ENQUEUED_COMPLETIONS / loader.numUploadWorkers
140+
)
141+
)
142+
}
143+
.toTypedArray()
131144
)
132145
}
133146
}

airbyte-cdk/bulk/toolkits/load-object-storage/src/test/kotlin/io/airbyte/cdk/load/pipeline/object_storage/ObjectLoaderPartQueueTest.kt

+4-4
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,19 @@ class ObjectLoaderPartQueueTest {
3434
fun `part queue clamps part size if too many workers`() {
3535
val beanFactory = ObjectLoaderPartQueueFactory(objectLoader)
3636
every { objectLoader.numPartWorkers } returns 5
37-
every { objectLoader.numUploadWorkers } returns 3
37+
every { objectLoader.numUploadWorkers } returns 3 // this will be doubled for calcs
3838
every { objectLoader.partSizeBytes } returns 100
3939
val memoryReservation = mockk<Reserved<ObjectLoader>>(relaxed = true)
4040
every { memoryReservation.bytesReserved } returns 800
4141
val clampedSize = beanFactory.objectLoaderClampedPartSizeBytes(memoryReservation)
42-
Assertions.assertEquals(800 / 9, clampedSize)
42+
Assertions.assertEquals(800 / 11, clampedSize)
4343
}
4444

4545
@Test
4646
fun `part queue does not clamp part size if not too many workers`() {
4747
val beanFactory = ObjectLoaderPartQueueFactory(objectLoader)
4848
every { objectLoader.numPartWorkers } returns 5
49-
every { objectLoader.numUploadWorkers } returns 1
49+
every { objectLoader.numUploadWorkers } returns 1 // this will be doubled for calcs
5050
every { objectLoader.partSizeBytes } returns 100
5151
val memoryReservation = mockk<Reserved<ObjectLoader>>(relaxed = true)
5252
every { memoryReservation.bytesReserved } returns 800
@@ -58,7 +58,7 @@ class ObjectLoaderPartQueueTest {
5858
fun `queue capacity is derived from clamped size and available memory`() {
5959
val beanFactory = ObjectLoaderPartQueueFactory(objectLoader)
6060
every { objectLoader.numPartWorkers } returns 3
61-
every { objectLoader.numUploadWorkers } returns 1
61+
every { objectLoader.numUploadWorkers } returns 1 // this will be doubled for calcs
6262
val clampedPartSize = 150L
6363
val memoryReservation = mockk<Reserved<ObjectLoader>>(relaxed = true)
6464
every { memoryReservation.bytesReserved } returns 910

airbyte-integrations/connectors/destination-s3-data-lake/metadata.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ data:
1717
- name: SECRET_DESTINATION-ICEBERG_V2_S3_GLUE_ASSUME_ROLE_CONFIG
1818
fileName: glue_assume_role.json
1919
secretStore:
20-
type: GSM
20+
type: GSMs
2121
alias: airbyte-connector-testing-secret-store
2222
- name: SECRET_DESTINATION-ICEBERG_V2_S3_GLUE_ASSUME_ROLE_SYSTEM_AWS_CONFIG
2323
fileName: glue_aws_assume_role.json

airbyte-integrations/connectors/destination-s3/src/main/kotlin/S3V2ObjectLoader.kt

+1-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
package io.airbyte.integrations.destination.s3_v2
66

7-
import io.airbyte.cdk.load.pipeline.RoundRobinInputPartitioner
87
import io.airbyte.cdk.load.write.object_storage.ObjectLoader
98

109
/**
@@ -20,5 +19,4 @@ class S3V2ObjectLoader(config: S3V2Configuration<*>) : ObjectLoader {
2019
override val partSizeBytes: Long = config.partSizeBytes
2120
}
2221

23-
// @Singleton
24-
class S3V2RoundRobinInputPartitioner : RoundRobinInputPartitioner()
22+
// @Singleton class S3V2RoundRobinInputPartitioner : RoundRobinInputPartitioner()

0 commit comments

Comments
 (0)