Skip to content

Commit

Permalink
Destination Snowflake add test to avoid duplicated staged data (#9412)
Browse files Browse the repository at this point in the history
* fix for jdk 17

* added unit test

* refactoring

* replace Exception with SQLException

Co-authored-by: vmaltsev <vitalii.maltsev@globallogic.com>
  • Loading branch information
VitaliiMaltsev and vmaltsev authored Jan 12, 2022
1 parent 1581512 commit 5383439
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import io.airbyte.protocol.models.ConfiguredAirbyteCatalog;
import io.airbyte.protocol.models.ConfiguredAirbyteStream;
import io.airbyte.protocol.models.DestinationSyncMode;

import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -178,7 +180,7 @@ private OnCloseFunction onCloseFunction(final JdbcDatabase database,
path);
try {
sqlOperations.copyIntoTmpTableFromStage(database, path, srcTableName, schemaName);
} catch (Exception e){
} catch (SQLException e){
sqlOperations.cleanUpStage(database, path);
LOGGER.info("Cleaning stage path {}", path);
throw new RuntimeException("Failed to upload data from stage "+ path, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,36 @@

package io.airbyte.integrations.destination.snowflake;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.*;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableMap;
import io.airbyte.commons.io.IOs;
import io.airbyte.commons.jackson.MoreMappers;
import io.airbyte.commons.json.Jsons;
import io.airbyte.db.jdbc.JdbcDatabase;
import io.airbyte.integrations.base.AirbyteMessageConsumer;
import io.airbyte.integrations.base.Destination;
import io.airbyte.protocol.models.AirbyteMessage;
import io.airbyte.protocol.models.AirbyteRecordMessage;
import io.airbyte.protocol.models.CatalogHelpers;
import io.airbyte.protocol.models.ConfiguredAirbyteCatalog;
import io.airbyte.protocol.models.DestinationSyncMode;
import io.airbyte.protocol.models.Field;
import io.airbyte.protocol.models.JsonSchemaPrimitive;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

import java.nio.file.Path;
import java.sql.SQLException;
import java.time.Instant;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class SnowflakeDestinationTest {

private static final ObjectMapper mapper = MoreMappers.initMapper();
Expand Down Expand Up @@ -53,4 +75,49 @@ public void useInsertStrategyTest() {
assertFalse(SnowflakeDestination.isS3Copy(stubConfig));
}

@Test
public void testCleanupStageOnFailure() throws Exception {

JdbcDatabase mockDb = mock(JdbcDatabase.class);
SnowflakeStagingSqlOperations sqlOperations = mock(SnowflakeStagingSqlOperations.class);
final var testMessages = generateTestMessages();
final JsonNode config = Jsons.deserialize(IOs.readFile(Path.of("secrets/insert_config.json")));

AirbyteMessageConsumer airbyteMessageConsumer = new SnowflakeInternalStagingConsumerFactory()
.create(Destination::defaultOutputRecordCollector, mockDb,
sqlOperations, new SnowflakeSQLNameTransformer(), config, getCatalog());
doThrow(SQLException.class).when(sqlOperations).copyIntoTmpTableFromStage(any(),anyString(),anyString(),anyString());

airbyteMessageConsumer.start();
for (AirbyteMessage m : testMessages) {
airbyteMessageConsumer.accept(m);
}
assertThrows(RuntimeException.class, airbyteMessageConsumer::close);

verify(sqlOperations, times(1)).cleanUpStage(any(),anyString());
}

private List<AirbyteMessage> generateTestMessages() {
return IntStream.range(0, 3)
.boxed()
.map(i -> new AirbyteMessage()
.withType(AirbyteMessage.Type.RECORD)
.withRecord(new AirbyteRecordMessage()
.withStream("test")
.withNamespace("test_staging")
.withEmittedAt(Instant.now().toEpochMilli())
.withData(Jsons.jsonNode(ImmutableMap.of("id", i, "name", "human " + i)))))
.collect(Collectors.toList());
}

ConfiguredAirbyteCatalog getCatalog() {
return new ConfiguredAirbyteCatalog().withStreams(List.of(
CatalogHelpers.createConfiguredAirbyteStream(
"test",
"test_staging",
Field.of("id", JsonSchemaPrimitive.NUMBER),
Field.of("name", JsonSchemaPrimitive.STRING))
.withDestinationSyncMode(DestinationSyncMode.OVERWRITE)));
}

}

0 comments on commit 5383439

Please sign in to comment.