From 613481d53aa841f64b1b469663d8e1a41b005cac Mon Sep 17 00:00:00 2001
From: Anatolii Yatsuk
Date: Thu, 13 Feb 2025 10:55:10 +0200
Subject: [PATCH] Fix missing cursor for ClientSideIncrementalRecordFilterDecorator

---
 .../declarative/extractors/record_selector.py | 15 +++-
 .../parsers/model_to_component_factory.py     |  5 ++
 .../extractors/test_record_selector.py        | 90 ++++++++++++++++++-
 .../test_model_to_component_factory.py        |  3 +
 4 files changed, 109 insertions(+), 4 deletions(-)

diff --git a/airbyte_cdk/sources/declarative/extractors/record_selector.py b/airbyte_cdk/sources/declarative/extractors/record_selector.py
index f29b8a75a..c37b8035b 100644
--- a/airbyte_cdk/sources/declarative/extractors/record_selector.py
+++ b/airbyte_cdk/sources/declarative/extractors/record_selector.py
@@ -41,6 +41,7 @@ class RecordSelector(HttpSelector):
     _name: Union[InterpolatedString, str] = field(init=False, repr=False, default="")
     record_filter: Optional[RecordFilter] = None
     transformations: List[RecordTransformation] = field(default_factory=lambda: [])
+    transform_before_filtering: bool = False
 
     def __post_init__(self, parameters: Mapping[str, Any]) -> None:
         self._parameters = parameters
@@ -104,9 +105,17 @@ def filter_and_transform(
         Until we decide to move this logic away from the selector, we made this method public so that users like
         AsyncJobRetriever could share the logic of doing transformations on a set of records.
         """
-        filtered_data = self._filter(all_data, stream_state, stream_slice, next_page_token)
-        transformed_data = self._transform(filtered_data, stream_state, stream_slice)
-        normalized_data = self._normalize_by_schema(transformed_data, schema=records_schema)
+        if self.transform_before_filtering:
+            transformed_data = self._transform(all_data, stream_state, stream_slice)
+            transformed_filtered_data = self._filter(
+                transformed_data, stream_state, stream_slice, next_page_token
+            )
+        else:
+            filtered_data = self._filter(all_data, stream_state, stream_slice, next_page_token)
+            transformed_filtered_data = self._transform(filtered_data, stream_state, stream_slice)
+        normalized_data = self._normalize_by_schema(
+            transformed_filtered_data, schema=records_schema
+        )
         for data in normalized_data:
             yield Record(data=data, stream_name=self.name, associated_slice=stream_slice)
 
diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py
index 12464b40a..739d15825 100644
--- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py
+++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py
@@ -2415,6 +2415,8 @@ def create_record_selector(
             if model.record_filter
             else None
         )
+
+        transform_before_filtering = False
         if client_side_incremental_sync:
             record_filter = ClientSideIncrementalRecordFilterDecorator(
                 config=config,
@@ -2424,6 +2426,8 @@
                 else None,
                 **client_side_incremental_sync,
             )
+            transform_before_filtering = True
+
         schema_normalization = (
             TypeTransformer(SCHEMA_TRANSFORMER_TYPE_MAPPING[model.schema_normalization])
             if isinstance(model.schema_normalization, SchemaNormalizationModel)
@@ -2438,6 +2442,7 @@
             transformations=transformations or [],
             schema_normalization=schema_normalization,
             parameters=model.parameters or {},
+            transform_before_filtering=transform_before_filtering,
         )
 
     @staticmethod
diff --git a/unit_tests/sources/declarative/extractors/test_record_selector.py b/unit_tests/sources/declarative/extractors/test_record_selector.py
index ee0f2f94d..5ec883ad2 100644
--- a/unit_tests/sources/declarative/extractors/test_record_selector.py
+++ b/unit_tests/sources/declarative/extractors/test_record_selector.py
@@ -3,7 +3,7 @@
 #
 
 import json
-from unittest.mock import Mock, call
+from unittest.mock import MagicMock, Mock, call
 
 import pytest
 import requests
@@ -220,3 +220,91 @@ def create_schema():
             "field_float": {"type": "number"},
         },
     }
+
+
+@pytest.mark.parametrize("transform_before_filtering", [True, False])
+def test_transform_before_filtering(transform_before_filtering):
+    """
+    Verify that when transform_before_filtering=True, records are modified before
+    filtering. When False, the filter sees the original record data first.
+    """
+
+    # 1) Our response body with 'myfield' set differently
+    #    The first record has myfield=0 (needs transformation to pass)
+    #    The second record has myfield=999 (already passes the filter)
+    body = {"data": [{"id": 1, "myfield": 0}, {"id": 2, "myfield": 999}]}
+
+    # 2) A response object
+    response = requests.Response()
+    response._content = json.dumps(body).encode("utf-8")
+
+    # 3) A simple extractor pulling records from 'data'
+    extractor = DpathExtractor(
+        field_path=["data"], decoder=JsonDecoder(parameters={}), config={}, parameters={}
+    )
+
+    # 4) A filter that keeps only records whose 'myfield' == 999
+    #    i.e.: "{{ record['myfield'] == 999 }}"
+    record_filter = RecordFilter(
+        config={},
+        condition="{{ record['myfield'] == 999 }}",
+        parameters={},
+    )
+
+    # 5) A transformation that sets 'myfield' to 999
+    #    We'll attach it to a mock so we can confirm how many times it was called
+    transformation_mock = MagicMock(spec=RecordTransformation)
+
+    def transformation_side_effect(record, config, stream_state, stream_slice):
+        record["myfield"] = 999
+
+    transformation_mock.transform.side_effect = transformation_side_effect
+
+    # 6) Create a RecordSelector with transform_before_filtering set from our param
+    record_selector = RecordSelector(
+        extractor=extractor,
+        config={},
+        name="test_stream",
+        record_filter=record_filter,
+        transformations=[transformation_mock],
+        schema_normalization=TypeTransformer(TransformConfig.NoTransform),
+        transform_before_filtering=transform_before_filtering,
+        parameters={},
+    )
+
+    # 7) Collect records
+    stream_slice = StreamSlice(partition={}, cursor_slice={})
+    actual_records = list(
+        record_selector.select_records(
+            response=response,
+            records_schema={},  # not using schema in this test
+            stream_state={},
+            stream_slice=stream_slice,
+            next_page_token=None,
+        )
+    )
+
+    # 8) Assert how many records survive
+    if transform_before_filtering:
+        # Both records become myfield=999 BEFORE the filter => both pass
+        assert len(actual_records) == 2
+        # The transformation should be called 2x (once per record)
+        assert transformation_mock.transform.call_count == 2
+    else:
+        # The first record is myfield=0 when the filter sees it => filter excludes it
+        # The second record is myfield=999 => filter includes it
+        assert len(actual_records) == 1
+        # The transformation occurs only on that single surviving record
+        # (the filter is done first, so the first record is already dropped)
+        assert transformation_mock.transform.call_count == 1
+
+    # 9) Check final record data
+    #    If transform_before_filtering=True => we have records [1,2] both with myfield=999
+    #    If transform_before_filtering=False => we have record [2] with myfield=999
+    final_record_data = [r.data for r in actual_records]
+    if transform_before_filtering:
+        assert all(record["myfield"] == 999 for record in final_record_data)
+        assert sorted([r["id"] for r in final_record_data]) == [1, 2]
+    else:
+        assert final_record_data[0]["id"] == 2
+        assert final_record_data[0]["myfield"] == 999
diff --git a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py
index a062cdfc7..faab999cb 100644
--- a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py
+++ b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py
@@ -1194,6 +1194,8 @@ def test_client_side_incremental():
         stream.retriever.record_selector.record_filter, ClientSideIncrementalRecordFilterDecorator
     )
 
+    assert stream.retriever.record_selector.transform_before_filtering == True
+
 
 def test_client_side_incremental_with_partition_router():
     content = """
@@ -1274,6 +1276,7 @@
     assert isinstance(
         stream.retriever.record_selector.record_filter, ClientSideIncrementalRecordFilterDecorator
    )
+    assert stream.retriever.record_selector.transform_before_filtering == True
     assert isinstance(
         stream.retriever.record_selector.record_filter._cursor,
         PerPartitionWithGlobalCursor,
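
Note for reviewers: the behavioral change in filter_and_transform() reduces to swapping the order of
two passes over the extracted records. Below is a minimal standalone sketch of the two orderings.
It uses hypothetical stand-ins (`transform` mutates a record in place, `passes_filter` is a plain
predicate, `select` is an illustrative helper), not the CDK's RecordTransformation/RecordFilter
interfaces:

    from typing import Any, Callable, Dict, Iterable, List

    Record = Dict[str, Any]

    def select(
        records: Iterable[Record],
        transform: Callable[[Record], None],      # mutates a record in place
        passes_filter: Callable[[Record], bool],  # plain predicate stand-in
        transform_before_filtering: bool,
    ) -> List[Record]:
        # Copy each record so the caller's data is not mutated between runs.
        items = [dict(r) for r in records]
        if transform_before_filtering:
            # New client-side-incremental order: the filter sees post-transformation values.
            for record in items:
                transform(record)
            return [r for r in items if passes_filter(r)]
        # Legacy order: the filter sees raw values; only survivors are transformed.
        kept = [r for r in items if passes_filter(r)]
        for record in kept:
            transform(record)
        return kept

    def set_myfield(record: Record) -> None:
        record["myfield"] = 999

    def keep(record: Record) -> bool:
        return record["myfield"] == 999

    data = [{"id": 1, "myfield": 0}, {"id": 2, "myfield": 999}]
    assert len(select(data, set_myfield, keep, transform_before_filtering=True)) == 2
    assert len(select(data, set_myfield, keep, transform_before_filtering=False)) == 1

This mirrors the new unit test: with transform-before-filter, a record that only satisfies the
cursor's filter condition after transformation is still emitted, which is why the factory enables
the flag whenever it wraps the filter in ClientSideIncrementalRecordFilterDecorator.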