Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

destination-async-framework: move the state emission logic into GlobalAsyncStateManager #35240

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@
import io.airbyte.cdk.integrations.destination_async.buffers.StreamAwareQueue.MessageWithMeta;
import io.airbyte.cdk.integrations.destination_async.state.FlushFailure;
import io.airbyte.cdk.integrations.destination_async.state.GlobalAsyncStateManager;
import io.airbyte.cdk.integrations.destination_async.state.PartialStateWithDestinationStats;
import io.airbyte.commons.json.Jsons;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.protocol.models.v0.StreamDescriptor;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
Expand Down Expand Up @@ -67,8 +64,6 @@ public class FlushWorkers implements AutoCloseable {
private final AtomicBoolean isClosing;
private final GlobalAsyncStateManager stateManager;

private final Object LOCK = new Object();

public FlushWorkers(final BufferDequeue bufferDequeue,
final DestinationFlushFunction flushFunction,
final Consumer<AirbyteMessage> outputRecordCollector,
Expand Down Expand Up @@ -172,7 +167,7 @@ private void flush(final StreamDescriptor desc, final UUID flushWorkerId) {
AirbyteFileUtils.byteCountToDisplaySize(batch.getSizeInBytes()));

flusher.flush(desc, batch.getData().stream().map(MessageWithMeta::message));
emitStateMessages(batch.flushStates(stateIdToCount));
batch.flushStates(stateIdToCount);
}

log.info("Flush Worker ({}) -- Worker finished flushing. Current queue size: {}",
Expand Down Expand Up @@ -222,7 +217,7 @@ public void close() throws Exception {
log.info("Closing flush workers -- all buffers flushed");

// before shutting down the supervisor, flush all state.
emitStateMessages(stateManager.flushStates());
stateManager.flushStates();
supervisorThread.shutdown();
while (!supervisorThread.awaitTermination(5L, TimeUnit.MINUTES)) {
log.info("Waiting for flush worker supervisor to shut down");
Expand All @@ -239,17 +234,6 @@ public void close() throws Exception {
debugLoop.shutdownNow();
}

/**
 * Converts each flushed partial state into a full {@link AirbyteMessage}, attaches its destination
 * stats, and hands it to the output record collector.
 *
 * The whole batch is processed while holding {@code LOCK}, so two threads calling this method never
 * interleave their emissions.
 *
 * @param partials flushed states, together with their destination stats and arrival numbers, to emit
 */
private void emitStateMessages(final List<PartialStateWithDestinationStats> partials) {
  synchronized (LOCK) {
    for (final var flushed : partials) {
      // Rebuild the complete message from its serialized form before decorating it with stats.
      final var serializedState = flushed.stateMessage().getSerialized();
      final AirbyteMessage outgoing = Jsons.deserialize(serializedState, AirbyteMessage.class);
      outgoing.getState().setDestinationStats(flushed.stats());

      final String emittingThread = Thread.currentThread().getName();
      log.info("State with arrival number {} emitted from thread {}", flushed.stateArrivalNumber(), emittingThread);
      outputRecordCollector.accept(outgoing);
    }
  }
}

/**
 * Shortens a flush worker's UUID to the first five characters of its string form.
 *
 * @param flushWorkerId the worker's identifier
 * @return the first five characters of {@code flushWorkerId.toString()}
 */
private static String humanReadableFlushWorkerId(final UUID flushWorkerId) {
  final String fullId = flushWorkerId.toString();
  return fullId.substring(0, 5);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
import io.airbyte.cdk.integrations.destination_async.FlushWorkers;
import io.airbyte.cdk.integrations.destination_async.GlobalMemoryManager;
import io.airbyte.cdk.integrations.destination_async.state.GlobalAsyncStateManager;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.protocol.models.v0.StreamDescriptor;
import java.util.ArrayList;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
Expand All @@ -38,21 +40,22 @@ public class BufferManager {

public static final double MEMORY_LIMIT_RATIO = 0.7;

public BufferManager() {
this((long) (Runtime.getRuntime().maxMemory() * MEMORY_LIMIT_RATIO));
public BufferManager(final Consumer<AirbyteMessage> outputRecordCollector) {
this((long) (Runtime.getRuntime().maxMemory() * MEMORY_LIMIT_RATIO), outputRecordCollector);
}

/**
* @param memoryLimit the amount of estimated memory we allow for all buffers. The
* GlobalMemoryManager will apply back pressure once this quota is filled. "Memory" can be
* released back once flushing finishes. This number should be large enough we don't block
* reading unnecessarily, but small enough we apply back pressure before OOMing.
* @param outputRecordCollector consumer handed to the GlobalAsyncStateManager, which passes it the
* state messages it emits
*/
public BufferManager(final long memoryLimit) {
public BufferManager(final long memoryLimit, final Consumer<AirbyteMessage> outputRecordCollector) {
maxMemory = memoryLimit;
LOGGER.info("Max 'memory' available for buffer allocation {}", FileUtils.byteCountToDisplaySize(maxMemory));
memoryManager = new GlobalMemoryManager(maxMemory);
this.stateManager = new GlobalAsyncStateManager(memoryManager);
this.stateManager = new GlobalAsyncStateManager(memoryManager, outputRecordCollector);
buffers = new ConcurrentHashMap<>();
bufferEnqueue = new BufferEnqueue(memoryManager, buffers, stateManager);
bufferDequeue = new BufferDequeue(memoryManager, buffers, stateManager);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import io.airbyte.cdk.integrations.destination_async.GlobalMemoryManager;
import io.airbyte.cdk.integrations.destination_async.buffers.StreamAwareQueue.MessageWithMeta;
import io.airbyte.cdk.integrations.destination_async.state.GlobalAsyncStateManager;
import io.airbyte.cdk.integrations.destination_async.state.PartialStateWithDestinationStats;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
Expand Down Expand Up @@ -64,9 +63,9 @@ public void close() throws Exception {
*
* @return list of states that can be flushed
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: change java doc.

*/
public List<PartialStateWithDestinationStats> flushStates(final Map<Long, Long> stateIdToCount) {
public void flushStates(final Map<Long, Long> stateIdToCount) {
stateIdToCount.forEach(stateManager::decrement);
return stateManager.flushStates();
stateManager.flushStates();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious, does the locking break if we return the ids to flush and flush them in the FlushWorkers instead?

That way, all the actual action happens in the flush workers, and not in objects scattered around. Mainly a cleanliness suggestion.

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
import com.google.common.base.Strings;
import io.airbyte.cdk.integrations.destination_async.GlobalMemoryManager;
import io.airbyte.cdk.integrations.destination_async.partial_messages.PartialAirbyteMessage;
import io.airbyte.commons.json.Jsons;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.protocol.models.v0.AirbyteStateMessage;
import io.airbyte.protocol.models.v0.AirbyteStateStats;
import io.airbyte.protocol.models.v0.StreamDescriptor;
import java.util.ArrayList;
import java.time.Instant;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand All @@ -25,6 +25,7 @@
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
Expand Down Expand Up @@ -95,11 +96,13 @@ public class GlobalAsyncStateManager {
private long retroactiveGlobalStateId = 0;
// All access to this field MUST be guarded by a synchronized(lock) block
private long arrivalNumber = 0;
private final Consumer<AirbyteMessage> outputRecordCollector;

private final Object LOCK = new Object();
private static final Object LOCK = new Object();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`static` seems unnecessary?


public GlobalAsyncStateManager(final GlobalMemoryManager memoryManager) {
public GlobalAsyncStateManager(final GlobalMemoryManager memoryManager, final Consumer<AirbyteMessage> outputRecordCollector) {
this.memoryManager = memoryManager;
this.outputRecordCollector = outputRecordCollector;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we avoid holding a copy of outputRecordCollector at the instance level and use it only in flushStates? This reduces the possibility of future abuse of this instance variable.

this.memoryAllocated = new AtomicLong(memoryManager.requestMemory());
this.memoryUsed = new AtomicLong();
}
Expand Down Expand Up @@ -161,8 +164,7 @@ public void decrement(final long stateId, final long count) {
*
* @return list of state messages with no more inflight records.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: update javadoc.

*/
public List<PartialStateWithDestinationStats> flushStates() {
final List<PartialStateWithDestinationStats> output = new ArrayList<>();
public void flushStates() {
Copy link
Contributor

@davinchia davinchia Feb 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: what if flushStates took in an output record collector? This helps us prevent passing the record collector everywhere and helps readability.

Looks like it might minimise this PR's change set too.

Long bytesFlushed = 0L;
synchronized (LOCK) {
for (final Map.Entry<StreamDescriptor, LinkedBlockingDeque<Long>> entry : descToStateIdQ.entrySet()) {
Expand Down Expand Up @@ -195,8 +197,13 @@ public List<PartialStateWithDestinationStats> flushStates() {
if (allRecordsCommitted) {
final StateMessageWithArrivalNumber stateMessage = oldestState.getLeft();
final double flushedRecordsAssociatedWithState = stateIdToCounterForPopulatingDestinationStats.get(oldestStateId).doubleValue();
output.add(new PartialStateWithDestinationStats(stateMessage.partialAirbyteStateMessage(),
new AirbyteStateStats().withRecordCount(flushedRecordsAssociatedWithState), stateMessage.arrivalNumber()));

log.info("State with arrival number {} emitted from thread {} at {}", stateMessage.arrivalNumber(), Thread.currentThread().getName(),
Instant.now().toString());
final AirbyteMessage message = Jsons.deserialize(stateMessage.partialAirbyteStateMessage.getSerialized(), AirbyteMessage.class);
message.getState().setDestinationStats(new AirbyteStateStats().withRecordCount(flushedRecordsAssociatedWithState));
outputRecordCollector.accept(message);

bytesFlushed += oldestState.getRight();

// cleanup
Expand All @@ -212,7 +219,6 @@ public List<PartialStateWithDestinationStats> flushStates() {
}

freeBytes(bytesFlushed);
return output;
}

private Long getStateIdAndIncrement(final StreamDescriptor streamDescriptor, final long increment) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ void setup() {
onClose,
flushFunction,
CATALOG,
new BufferManager(),
new BufferManager(outputRecordCollector),
flushFailure,
"default_ns");

Expand Down Expand Up @@ -204,7 +204,7 @@ void testBackPressure() throws Exception {
(hasFailed, recordCounts) -> {},
flushFunction,
CATALOG,
new BufferManager(1024 * 10),
new BufferManager(1024 * 10, outputRecordCollector),
flushFailure,
"default_ns");
when(flushFunction.getOptimalBatchSizeBytes()).thenReturn(0L);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,33 @@
import io.airbyte.cdk.integrations.destination_async.partial_messages.PartialAirbyteMessage;
import io.airbyte.cdk.integrations.destination_async.partial_messages.PartialAirbyteRecordMessage;
import io.airbyte.commons.json.Jsons;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.protocol.models.v0.AirbyteMessage.Type;
import io.airbyte.protocol.models.v0.StreamDescriptor;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.function.Consumer;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;

public class BufferDequeueTest {

private static final int RECORD_SIZE_20_BYTES = 20;
private static final String DEFAULT_NAMESPACE = "foo_namespace";
public static final String RECORD_20_BYTES = "abc";
private static final String STREAM_NAME = "stream1";
private static final StreamDescriptor STREAM_DESC = new StreamDescriptor().withName(STREAM_NAME);
private static final PartialAirbyteMessage RECORD_MSG_20_BYTES = new PartialAirbyteMessage()
.withType(Type.RECORD)
.withRecord(new PartialAirbyteRecordMessage()
.withStream(STREAM_NAME));
private final Consumer<AirbyteMessage> outputRecordCollector = c -> {};

@Nested
class Take {

@Test
void testTakeShouldBestEffortRead() {
final BufferManager bufferManager = new BufferManager();
final BufferManager bufferManager = new BufferManager(outputRecordCollector);
final BufferEnqueue enqueue = bufferManager.getBufferEnqueue();
final BufferDequeue dequeue = bufferManager.getBufferDequeue();

Expand All @@ -57,7 +59,7 @@ void testTakeShouldBestEffortRead() {

@Test
void testTakeShouldReturnAllIfPossible() {
final BufferManager bufferManager = new BufferManager();
final BufferManager bufferManager = new BufferManager(outputRecordCollector);
final BufferEnqueue enqueue = bufferManager.getBufferEnqueue();
final BufferDequeue dequeue = bufferManager.getBufferDequeue();

Expand All @@ -74,7 +76,7 @@ void testTakeShouldReturnAllIfPossible() {

@Test
void testTakeFewerRecordsThanSizeLimitShouldNotError() {
final BufferManager bufferManager = new BufferManager();
final BufferManager bufferManager = new BufferManager(outputRecordCollector);
final BufferEnqueue enqueue = bufferManager.getBufferEnqueue();
final BufferDequeue dequeue = bufferManager.getBufferDequeue();

Expand All @@ -92,7 +94,7 @@ void testTakeFewerRecordsThanSizeLimitShouldNotError() {

@Test
void testMetadataOperationsCorrect() {
final BufferManager bufferManager = new BufferManager();
final BufferManager bufferManager = new BufferManager(outputRecordCollector);
final BufferEnqueue enqueue = bufferManager.getBufferEnqueue();
final BufferDequeue dequeue = bufferManager.getBufferDequeue();

Expand Down Expand Up @@ -120,7 +122,7 @@ void testMetadataOperationsCorrect() {

@Test
void testMetadataOperationsError() {
final BufferManager bufferManager = new BufferManager();
final BufferManager bufferManager = new BufferManager(outputRecordCollector);
final BufferDequeue dequeue = bufferManager.getBufferDequeue();

final var ghostStream = new StreamDescriptor().withName("ghost stream");
Expand All @@ -136,7 +138,7 @@ void testMetadataOperationsError() {

@Test
void cleansUpMemoryForEmptyQueues() throws Exception {
final var bufferManager = new BufferManager();
final var bufferManager = new BufferManager(outputRecordCollector);
final var enqueue = bufferManager.getBufferEnqueue();
final var dequeue = bufferManager.getBufferDequeue();
final var memoryManager = bufferManager.getMemoryManager();
Expand Down
Loading
Loading