From a76ba27755a448a68db295520345a71886b0e0a0 Mon Sep 17 00:00:00 2001
From: Christophe Duong
Date: Tue, 4 Jan 2022 13:12:57 +0100
Subject: [PATCH] Refactor jinja template for scd

---
 .../transform_catalog/stream_processor.py    | 273 +++++++++---------
 1 file changed, 142 insertions(+), 131 deletions(-)

diff --git a/airbyte-integrations/bases/base-normalization/normalization/transform_catalog/stream_processor.py b/airbyte-integrations/bases/base-normalization/normalization/transform_catalog/stream_processor.py
index 84f75ff3e8b41..1a8993ddf8cc3 100644
--- a/airbyte-integrations/bases/base-normalization/normalization/transform_catalog/stream_processor.py
+++ b/airbyte-integrations/bases/base-normalization/normalization/transform_catalog/stream_processor.py
@@ -649,7 +649,144 @@ def safe_cast_to_string(definition: Dict, column_name: str, destination_type: De
         return col
 
     def generate_scd_type_2_model(self, from_table: str, column_names: Dict[str, Tuple[str, str]]) -> str:
-        scd_sql_template = """
+        order_null = "is null asc"
+        if self.destination_type.value == DestinationType.ORACLE.value:
+            order_null = "asc nulls last"
+        if self.destination_type.value == DestinationType.MSSQL.value:
+            # SQL Server treats NULL values as the lowest values, then sorted in ascending order, NULLs come first.
+            order_null = "desc"
+
+        lag_begin = "lag"
+        lag_end = ""
+        input_data_table = "input_data"
+        if self.destination_type == DestinationType.CLICKHOUSE:
+            # ClickHouse doesn't support lag() yet, this is a workaround solution
+            # Ref: https://clickhouse.com/docs/en/sql-reference/window-functions/
+            lag_begin = "anyOrNull"
+            lag_end = "ROWS BETWEEN 1 PRECEDING AND 1 PRECEDING"
+            input_data_table = "input_data_with_active_row_num"
+
+        enable_left_join_null = ""
+        cast_begin = "cast("
+        cast_as = " as "
+        cast_end = ")"
+        if self.destination_type == DestinationType.CLICKHOUSE:
+            enable_left_join_null = "--"
+            cast_begin = "accurateCastOrNull("
+            cast_as = ", '"
+            cast_end = "')"
+
+        # TODO move all cdc columns out of scd models
+        cdc_active_row_pattern = ""
+        cdc_updated_order_pattern = ""
+        cdc_cols = ""
+        quoted_cdc_cols = ""
+        if "_ab_cdc_deleted_at" in column_names.keys():
+            col_cdc_deleted_at = self.name_transformer.normalize_column_name("_ab_cdc_deleted_at")
+            col_cdc_updated_at = self.name_transformer.normalize_column_name("_ab_cdc_updated_at")
+            quoted_col_cdc_deleted_at = self.name_transformer.normalize_column_name("_ab_cdc_deleted_at", in_jinja=True)
+            quoted_col_cdc_updated_at = self.name_transformer.normalize_column_name("_ab_cdc_updated_at", in_jinja=True)
+            cdc_active_row_pattern = f" and {col_cdc_deleted_at} is null"
+            cdc_updated_order_pattern = f", {col_cdc_updated_at} desc"
+            cdc_cols = (
+                f", {cast_begin}{col_cdc_deleted_at}{cast_as}"
+                + "{{ dbt_utils.type_string() }}"
+                + f"{cast_end}"
+                + f", {cast_begin}{col_cdc_updated_at}{cast_as}"
+                + "{{ dbt_utils.type_string() }}"
+                + f"{cast_end}"
+            )
+            quoted_cdc_cols = f", {quoted_col_cdc_deleted_at}, {quoted_col_cdc_updated_at}"
+
+        if "_ab_cdc_log_pos" in column_names.keys():
+            col_cdc_log_pos = self.name_transformer.normalize_column_name("_ab_cdc_log_pos")
+            quoted_col_cdc_log_pos = self.name_transformer.normalize_column_name("_ab_cdc_log_pos", in_jinja=True)
+            cdc_updated_order_pattern += f", {col_cdc_log_pos} desc"
+            cdc_cols += f", {cast_begin}{col_cdc_log_pos}{cast_as}" + "{{ dbt_utils.type_string() }}" + f"{cast_end}"
+            quoted_cdc_cols += f", {quoted_col_cdc_log_pos}"
+
+        jinja_variables = {
+            "active_row": self.name_transformer.normalize_column_name("_airbyte_active_row"),
+            "airbyte_end_at": self.name_transformer.normalize_column_name("_airbyte_end_at"),
+            "airbyte_row_num": self.name_transformer.normalize_column_name("_airbyte_row_num"),
+            "airbyte_start_at": self.name_transformer.normalize_column_name("_airbyte_start_at"),
+            "airbyte_unique_key_scd": self.name_transformer.normalize_column_name(f"{self.airbyte_unique_key}_scd"),
+            "cdc_active_row": cdc_active_row_pattern,
+            "cdc_cols": cdc_cols,
+            "cdc_updated_at_order": cdc_updated_order_pattern,
+            "col_ab_id": self.get_ab_id(),
+            "col_emitted_at": self.get_emitted_at(),
+            "col_normalized_at": self.get_normalized_at(),
+            "cursor_field": self.get_cursor_field(column_names),
+            "enable_left_join_null": enable_left_join_null,
+            "fields": self.list_fields(column_names),
+            "from_table": from_table,
+            "hash_id": self.hash_id(),
+            "input_data_table": input_data_table,
+            "lag_begin": lag_begin,
+            "lag_end": lag_end,
+            "order_null": order_null,
+            "parent_hash_id": self.parent_hash_id(),
+            "primary_key_partition": self.get_primary_key_partition(column_names),
+            "primary_keys": self.list_primary_keys(column_names),
+            "quoted_airbyte_row_num": self.name_transformer.normalize_column_name("_airbyte_row_num", in_jinja=True),
+            "quoted_airbyte_start_at": self.name_transformer.normalize_column_name("_airbyte_start_at", in_jinja=True),
+            "quoted_cdc_cols": quoted_cdc_cols,
+            "quoted_col_emitted_at": self.get_emitted_at(in_jinja=True),
+            "quoted_unique_key": self.get_unique_key(in_jinja=True),
+            "sql_table_comment": self.sql_table_comment(include_from_table=True),
+            "unique_key": self.get_unique_key(),
+        }
+        if self.destination_type == DestinationType.CLICKHOUSE:
+            clickhouse_active_row_sql = Template(
+                """
+input_data_with_active_row_num as (
+    select *,
+      row_number() over (
+        partition by {{ primary_key_partition | join(", ") }}
+        order by
+            {{ cursor_field }} {{ order_null }},
+            {{ cursor_field }} desc,
+            {{ col_emitted_at }} desc{{ cdc_updated_at_order }}
+      ) as _airbyte_active_row_num
+    from input_data
+),"""
+            ).render(jinja_variables)
+            jinja_variables["clickhouse_active_row_sql"] = clickhouse_active_row_sql
+            scd_columns_sql = Template(
+                """
+      case when _airbyte_active_row_num = 1{{ cdc_active_row }} then 1 else 0 end as {{ active_row }},
+      {{ lag_begin }}({{ cursor_field }}) over (
+        partition by {{ primary_key_partition | join(", ") }}
+        order by
+            {{ cursor_field }} {{ order_null }},
+            {{ cursor_field }} desc,
+            {{ col_emitted_at }} desc{{ cdc_updated_at_order }}
+        {{ lag_end }}
+      ) as {{ airbyte_end_at }}"""
+            ).render(jinja_variables)
+            jinja_variables["scd_columns_sql"] = scd_columns_sql
+        else:
+            scd_columns_sql = Template(
+                """
+      lag({{ cursor_field }}) over (
+        partition by {{ primary_key_partition | join(", ") }}
+        order by
+            {{ cursor_field }} {{ order_null }},
+            {{ cursor_field }} desc,
+            {{ col_emitted_at }} desc{{ cdc_updated_at_order }}
+      ) as {{ airbyte_end_at }},
+      case when row_number() over (
+        partition by {{ primary_key_partition | join(", ") }}
+        order by
+            {{ cursor_field }} {{ order_null }},
+            {{ cursor_field }} desc,
+            {{ col_emitted_at }} desc{{ cdc_updated_at_order }}
+      ) = 1{{ cdc_active_row }} then 1 else 0 end as {{ active_row }}"""
+            ).render(jinja_variables)
+            jinja_variables["scd_columns_sql"] = scd_columns_sql
+        sql = Template(
+            """
 -- depends_on: {{ from_table }}
 with
 {{ '{% if is_incremental() %}' }}
@@ -699,19 +836,7 @@ def generate_scd_type_2_model(self, from_table: str, column_names: Dict[str, Tup
     {{ sql_table_comment }}
 ),
 {{ '{% endif %}' }}
-{{ '{%- if var("destination") == "clickhouse" %}' }}
-input_data_with_active_row_num as (
-    select *,
-      row_number() over (
-        partition by {{ primary_key_partition | join(", ") }}
-        order by
-            {{ cursor_field }} {{ order_null }},
-            {{ cursor_field }} desc,
-            {{ col_emitted_at }} desc{{ cdc_updated_at_order }}
-      ) as _airbyte_active_row_num
-    from input_data
-),
-{{ '{%- endif %}' }}
+{{ clickhouse_active_row_sql }}
 scd_data as (
     -- SQL model to build a Type 2 Slowly Changing Dimension (SCD) table for each record identified by their primary key
     select
@@ -727,32 +852,7 @@ def generate_scd_type_2_model(self, from_table: str, column_names: Dict[str, Tup
         {{ field }},
     {%- endfor %}
       {{ cursor_field }} as {{ airbyte_start_at }},
-      {{ '{%- if var("destination") == "clickhouse" %}' }}
-      case when _airbyte_active_row_num = 1{{ cdc_active_row }} then 1 else 0 end as {{ active_row }},
-      {{ lag_begin }}({{ cursor_field }}) over (
-        partition by {{ primary_key_partition | join(", ") }}
-        order by
-            {{ cursor_field }} {{ order_null }},
-            {{ cursor_field }} desc,
-            {{ col_emitted_at }} desc{{ cdc_updated_at_order }}
-        {{ lag_end }}
-      ) as {{ airbyte_end_at }},
-      {{ '{%- else %}' }}
-      lag({{ cursor_field }}) over (
-        partition by {{ primary_key_partition | join(", ") }}
-        order by
-            {{ cursor_field }} {{ order_null }},
-            {{ cursor_field }} desc,
-            {{ col_emitted_at }} desc{{ cdc_updated_at_order }}
-      ) as {{ airbyte_end_at }},
-      case when row_number() over (
-        partition by {{ primary_key_partition | join(", ") }}
-        order by
-            {{ cursor_field }} {{ order_null }},
-            {{ cursor_field }} desc,
-            {{ col_emitted_at }} desc{{ cdc_updated_at_order }}
-      ) = 1{{ cdc_active_row }} then 1 else 0 end as {{ active_row }},
-      {{ '{%- endif %}' }}
+      {{ scd_columns_sql }},
       {{ col_ab_id }},
       {{ col_emitted_at }},
       {{ hash_id }}
@@ -791,97 +891,8 @@ def generate_scd_type_2_model(self, from_table: str, column_names: Dict[str, Tup
     {{ '{{ current_timestamp() }}' }} as {{ col_normalized_at }},
     {{ hash_id }}
 from dedup_data where {{ airbyte_row_num }} = 1
-    """
-        template = Template(scd_sql_template)
-
-        order_null = "is null asc"
-        if self.destination_type.value == DestinationType.ORACLE.value:
-            order_null = "asc nulls last"
-        if self.destination_type.value == DestinationType.MSSQL.value:
-            # SQL Server treats NULL values as the lowest values, then sorted in ascending order, NULLs come first.
-            order_null = "desc"
-
-        lag_begin = "lag"
-        lag_end = ""
-        input_data_table = "input_data"
-        if self.destination_type == DestinationType.CLICKHOUSE:
-            # ClickHouse doesn't support lag() yet, this is a workaround solution
-            # Ref: https://clickhouse.com/docs/en/sql-reference/window-functions/
-            lag_begin = "anyOrNull"
-            lag_end = "ROWS BETWEEN 1 PRECEDING AND 1 PRECEDING"
-            input_data_table = "input_data_with_active_row_num"
-
-        enable_left_join_null = ""
-        cast_begin = "cast("
-        cast_as = " as "
-        cast_end = ")"
-        if self.destination_type == DestinationType.CLICKHOUSE:
-            enable_left_join_null = "--"
-            cast_begin = "accurateCastOrNull("
-            cast_as = ", '"
-            cast_end = "')"
-
-        # TODO move all cdc columns out of scd models
-        cdc_active_row_pattern = ""
-        cdc_updated_order_pattern = ""
-        cdc_cols = ""
-        quoted_cdc_cols = ""
-        if "_ab_cdc_deleted_at" in column_names.keys():
-            col_cdc_deleted_at = self.name_transformer.normalize_column_name("_ab_cdc_deleted_at")
-            col_cdc_updated_at = self.name_transformer.normalize_column_name("_ab_cdc_updated_at")
-            quoted_col_cdc_deleted_at = self.name_transformer.normalize_column_name("_ab_cdc_deleted_at", in_jinja=True)
-            quoted_col_cdc_updated_at = self.name_transformer.normalize_column_name("_ab_cdc_updated_at", in_jinja=True)
-            cdc_active_row_pattern = f" and {col_cdc_deleted_at} is null"
-            cdc_updated_order_pattern = f", {col_cdc_updated_at} desc"
-            cdc_cols = (
-                f", {cast_begin}{col_cdc_deleted_at}{cast_as}"
-                + "{{ dbt_utils.type_string() }}"
-                + f"{cast_end}"
-                + f", {cast_begin}{col_cdc_updated_at}{cast_as}"
-                + "{{ dbt_utils.type_string() }}"
-                + f"{cast_end}"
-            )
-            quoted_cdc_cols = f", {quoted_col_cdc_deleted_at}, {quoted_col_cdc_updated_at}"
-
-        if "_ab_cdc_log_pos" in column_names.keys():
-            col_cdc_log_pos = self.name_transformer.normalize_column_name("_ab_cdc_log_pos")
-            quoted_col_cdc_log_pos = self.name_transformer.normalize_column_name("_ab_cdc_log_pos", in_jinja=True)
-            cdc_updated_order_pattern += f", {col_cdc_log_pos} desc"
-            cdc_cols += f", {cast_begin}{col_cdc_log_pos}{cast_as}" + "{{ dbt_utils.type_string() }}" + f"{cast_end}"
-            quoted_cdc_cols += f", {quoted_col_cdc_log_pos}"
-
-        sql = template.render(
-            order_null=order_null,
-            airbyte_start_at=self.name_transformer.normalize_column_name("_airbyte_start_at"),
-            quoted_airbyte_start_at=self.name_transformer.normalize_column_name("_airbyte_start_at", in_jinja=True),
-            airbyte_end_at=self.name_transformer.normalize_column_name("_airbyte_end_at"),
-            active_row=self.name_transformer.normalize_column_name("_airbyte_active_row"),
-            airbyte_row_num=self.name_transformer.normalize_column_name("_airbyte_row_num"),
-            quoted_airbyte_row_num=self.name_transformer.normalize_column_name("_airbyte_row_num", in_jinja=True),
-            airbyte_unique_key_scd=self.name_transformer.normalize_column_name(f"{self.airbyte_unique_key}_scd"),
-            unique_key=self.get_unique_key(),
-            quoted_unique_key=self.get_unique_key(in_jinja=True),
-            col_ab_id=self.get_ab_id(),
-            col_emitted_at=self.get_emitted_at(),
-            quoted_col_emitted_at=self.get_emitted_at(in_jinja=True),
-            col_normalized_at=self.get_normalized_at(),
-            parent_hash_id=self.parent_hash_id(),
-            fields=self.list_fields(column_names),
-            cursor_field=self.get_cursor_field(column_names),
-            primary_keys=self.list_primary_keys(column_names),
-            primary_key_partition=self.get_primary_key_partition(column_names),
-            hash_id=self.hash_id(),
-            from_table=from_table,
-            sql_table_comment=self.sql_table_comment(include_from_table=True),
-            cdc_active_row=cdc_active_row_pattern,
-            cdc_updated_at_order=cdc_updated_order_pattern,
-            cdc_cols=cdc_cols,
-            quoted_cdc_cols=quoted_cdc_cols,
-            lag_begin=lag_begin,
-            lag_end=lag_end,
-            enable_left_join_null=enable_left_join_null,
-            input_data_table=input_data_table,
-        )
+"""
+        ).render(jinja_variables)
         return sql
 
     def get_cursor_field(self, column_names: Dict[str, Tuple[str, str]], in_jinja: bool = False) -> str:
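
Note on the pattern above: the refactor moves the ClickHouse-specific branching out of the generated dbt SQL (the old {{ '{%- if var("destination") == "clickhouse" %}' }} blocks) and into Python, pre-rendering destination-specific fragments with Jinja and handing the result to the outer SCD template as plain variables. The snippet below is a minimal, self-contained sketch of that nested-rendering idea, not Airbyte code; destination_is_clickhouse, demo_users, and the column names are hypothetical placeholders.

# Sketch: pre-render a destination-specific fragment, then feed it into the outer template.
from jinja2 import Template

destination_is_clickhouse = True  # hypothetical toggle standing in for the DestinationType checks

jinja_variables = {
    "cursor_field": "updated_at",
    "primary_key_partition": ["id"],
}

if destination_is_clickhouse:
    # ClickHouse has no lag(); emulate it with anyOrNull() over a 1-row preceding frame.
    fragment = (
        "anyOrNull({{ cursor_field }}) over (partition by {{ primary_key_partition | join(', ') }} "
        "order by {{ cursor_field }} desc ROWS BETWEEN 1 PRECEDING AND 1 PRECEDING) as _airbyte_end_at"
    )
else:
    fragment = (
        "lag({{ cursor_field }}) over (partition by {{ primary_key_partition | join(', ') }} "
        "order by {{ cursor_field }} desc) as _airbyte_end_at"
    )

# The rendered fragment becomes an ordinary string variable of the outer template,
# so the emitted SQL contains only the branch that applies to this destination.
jinja_variables["scd_columns_sql"] = Template(fragment).render(jinja_variables)

sql = Template("select {{ scd_columns_sql }} from {{ from_table }}").render(
    {**jinja_variables, "from_table": "demo_users"}
)
print(sql)

One consequence of rendering the fragments first is that the generated model files no longer carry destination conditionals, which keeps the emitted SQL readable and specific to the destination being built.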