fix(low code): Fix missing cursor for ClientSideIncrementalRecordFilterDecorator #334

Merged · 1 commit · Feb 13, 2025
15 changes: 12 additions & 3 deletions airbyte_cdk/sources/declarative/extractors/record_selector.py
@@ -41,6 +41,7 @@ class RecordSelector(HttpSelector):
     _name: Union[InterpolatedString, str] = field(init=False, repr=False, default="")
     record_filter: Optional[RecordFilter] = None
     transformations: List[RecordTransformation] = field(default_factory=lambda: [])
+    transform_before_filtering: bool = False
 
     def __post_init__(self, parameters: Mapping[str, Any]) -> None:
         self._parameters = parameters
@@ -104,9 +105,17 @@ def filter_and_transform(
         Until we decide to move this logic away from the selector, we made this method public so that users like AsyncJobRetriever could
         share the logic of doing transformations on a set of records.
         """
-        filtered_data = self._filter(all_data, stream_state, stream_slice, next_page_token)
-        transformed_data = self._transform(filtered_data, stream_state, stream_slice)
-        normalized_data = self._normalize_by_schema(transformed_data, schema=records_schema)
+        if self.transform_before_filtering:
+            transformed_data = self._transform(all_data, stream_state, stream_slice)
+            transformed_filtered_data = self._filter(
+                transformed_data, stream_state, stream_slice, next_page_token
+            )
+        else:
+            filtered_data = self._filter(all_data, stream_state, stream_slice, next_page_token)
+            transformed_filtered_data = self._transform(filtered_data, stream_state, stream_slice)
+        normalized_data = self._normalize_by_schema(
+            transformed_filtered_data, schema=records_schema
+        )
         for data in normalized_data:
             yield Record(data=data, stream_name=self.name, associated_slice=stream_slice)
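
The behavioral difference is easiest to see in isolation. Below is a minimal, runnable sketch (hypothetical records and helper functions, not CDK code) of the two orderings: with filter-first, a record whose filter-relevant field is only populated by a transformation is dropped before it ever gains that field; with transform-first, it survives.

# Illustration only: "cursor" stands in for a field the filter inspects,
# and transform() for a transformation that populates it.
records = [{"id": 1, "cursor": None}, {"id": 2, "cursor": "2025-02-13"}]

def transform(record):
    return {**record, "cursor": record["cursor"] or "2025-02-13"}

def keep(record):
    return record["cursor"] is not None

# Filter first (old behavior): record 1 has no cursor yet, so it is dropped.
filter_first = [transform(r) for r in records if keep(r)]
assert [r["id"] for r in filter_first] == [2]

# Transform first (new behavior): both records gain a cursor and pass.
transform_first = [r for r in map(transform, records) if keep(r)]
assert [r["id"] for r in transform_first] == [1, 2]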

@@ -2415,6 +2415,8 @@ def create_record_selector(
             if model.record_filter
             else None
         )
+
+        transform_before_filtering = False
         if client_side_incremental_sync:
             record_filter = ClientSideIncrementalRecordFilterDecorator(
                 config=config,
Expand All @@ -2424,6 +2426,8 @@ def create_record_selector(
else None,
**client_side_incremental_sync,
)
transform_before_filtering = True

schema_normalization = (
TypeTransformer(SCHEMA_TRANSFORMER_TYPE_MAPPING[model.schema_normalization])
if isinstance(model.schema_normalization, SchemaNormalizationModel)
@@ -2438,6 +2442,7 @@
             transformations=transformations or [],
             schema_normalization=schema_normalization,
             parameters=model.parameters or {},
+            transform_before_filtering=transform_before_filtering,
         )
 
     @staticmethod
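
This factory change is presumably the "missing cursor" scenario from the PR title: ClientSideIncrementalRecordFilterDecorator filters records by comparing their cursor value against stream state, yet that cursor field may only be populated by a transformation, so the factory now opts client-side incremental syncs into transform-first ordering. A self-contained sketch of the failure mode, using hypothetical stand-ins rather than the real CDK classes:

# Hypothetical stand-ins for the incremental filter and a transformation
# that derives the cursor field; not the real CDK implementations.
stream_state = {"updated_at": "2025-01-01"}

def add_cursor(record):
    # Transformation deriving the cursor field from another field.
    return {**record, "updated_at": record["modified"][:10]}

def is_newer(record):
    # Rough stand-in for the client-side incremental cursor check.
    return record["updated_at"] > stream_state["updated_at"]

raw = [{"id": 7, "modified": "2025-02-13T09:30:00Z"}]

# Filter-first would raise KeyError("updated_at") because the cursor field
# is missing until add_cursor runs; transform-first succeeds:
assert [r["id"] for r in filter(is_newer, map(add_cursor, raw))] == [7]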
@@ -3,7 +3,7 @@
 #
 
 import json
-from unittest.mock import Mock, call
+from unittest.mock import MagicMock, Mock, call
 
 import pytest
 import requests
@@ -220,3 +220,91 @@ def create_schema():
             "field_float": {"type": "number"},
         },
     }
+
+
+@pytest.mark.parametrize("transform_before_filtering", [True, False])
+def test_transform_before_filtering(transform_before_filtering):
+    """
+    Verify that when transform_before_filtering=True, records are modified before
+    filtering. When False, the filter sees the original record data first.
+    """
+
+    # 1) Our response body with 'myfield' set differently
+    # The first record has myfield=0 (needs transformation to pass)
+    # The second record has myfield=999 (already passes the filter)
+    body = {"data": [{"id": 1, "myfield": 0}, {"id": 2, "myfield": 999}]}
+
+    # 2) A response object
+    response = requests.Response()
+    response._content = json.dumps(body).encode("utf-8")
+
+    # 3) A simple extractor pulling records from 'data'
+    extractor = DpathExtractor(
+        field_path=["data"], decoder=JsonDecoder(parameters={}), config={}, parameters={}
+    )
+
+    # 4) A filter that keeps only records whose 'myfield' == 999
+    # i.e.: "{{ record['myfield'] == 999 }}"
+    record_filter = RecordFilter(
+        config={},
+        condition="{{ record['myfield'] == 999 }}",
+        parameters={},
+    )
+
+    # 5) A transformation that sets 'myfield' to 999
+    # We'll attach it to a mock so we can confirm how many times it was called
+    transformation_mock = MagicMock(spec=RecordTransformation)
+
+    def transformation_side_effect(record, config, stream_state, stream_slice):
+        record["myfield"] = 999
+
+    transformation_mock.transform.side_effect = transformation_side_effect
+
+    # 6) Create a RecordSelector with transform_before_filtering set from our param
+    record_selector = RecordSelector(
+        extractor=extractor,
+        config={},
+        name="test_stream",
+        record_filter=record_filter,
+        transformations=[transformation_mock],
+        schema_normalization=TypeTransformer(TransformConfig.NoTransform),
+        transform_before_filtering=transform_before_filtering,
+        parameters={},
+    )
+
+    # 7) Collect records
+    stream_slice = StreamSlice(partition={}, cursor_slice={})
+    actual_records = list(
+        record_selector.select_records(
+            response=response,
+            records_schema={},  # not using schema in this test
+            stream_state={},
+            stream_slice=stream_slice,
+            next_page_token=None,
+        )
+    )
+
+    # 8) Assert how many records survive
+    if transform_before_filtering:
+        # Both records become myfield=999 BEFORE the filter => both pass
+        assert len(actual_records) == 2
+        # The transformation should be called 2x (once per record)
+        assert transformation_mock.transform.call_count == 2
+    else:
+        # The first record is myfield=0 when the filter sees it => filter excludes it
+        # The second record is myfield=999 => filter includes it
+        assert len(actual_records) == 1
+        # The transformation occurs only on that single surviving record
+        # (the filter is done first, so the first record is already dropped)
+        assert transformation_mock.transform.call_count == 1
+
+    # 9) Check final record data
+    # If transform_before_filtering=True => we have records [1, 2] both with myfield=999
+    # If transform_before_filtering=False => we have record [2] with myfield=999
+    final_record_data = [r.data for r in actual_records]
+    if transform_before_filtering:
+        assert all(record["myfield"] == 999 for record in final_record_data)
+        assert sorted([r["id"] for r in final_record_data]) == [1, 2]
+    else:
+        assert final_record_data[0]["id"] == 2
+        assert final_record_data[0]["myfield"] == 999
@@ -1194,6 +1194,8 @@ def test_client_side_incremental():
         stream.retriever.record_selector.record_filter, ClientSideIncrementalRecordFilterDecorator
     )
 
+    assert stream.retriever.record_selector.transform_before_filtering == True
+
 
 def test_client_side_incremental_with_partition_router():
     content = """
@@ -1274,6 +1276,7 @@ def test_client_side_incremental_with_partition_router():
     assert isinstance(
         stream.retriever.record_selector.record_filter, ClientSideIncrementalRecordFilterDecorator
    )
+    assert stream.retriever.record_selector.transform_before_filtering == True
     assert isinstance(
         stream.retriever.record_selector.record_filter._cursor,
         PerPartitionWithGlobalCursor,