Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor SCD Jinja templates for normalization #9278

Merged
merged 1 commit into from
Jan 4, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,144 @@ def safe_cast_to_string(definition: Dict, column_name: str, destination_type: De
return col

def generate_scd_type_2_model(self, from_table: str, column_names: Dict[str, Tuple[str, str]]) -> str:
scd_sql_template = """
order_null = "is null asc"
if self.destination_type.value == DestinationType.ORACLE.value:
order_null = "asc nulls last"
if self.destination_type.value == DestinationType.MSSQL.value:
# SQL Server treats NULL values as the lowest values, then sorted in ascending order, NULLs come first.
order_null = "desc"

lag_begin = "lag"
lag_end = ""
input_data_table = "input_data"
if self.destination_type == DestinationType.CLICKHOUSE:
# ClickHouse doesn't support lag() yet, this is a workaround solution
# Ref: https://clickhouse.com/docs/en/sql-reference/window-functions/
lag_begin = "anyOrNull"
lag_end = "ROWS BETWEEN 1 PRECEDING AND 1 PRECEDING"
input_data_table = "input_data_with_active_row_num"

enable_left_join_null = ""
cast_begin = "cast("
cast_as = " as "
cast_end = ")"
if self.destination_type == DestinationType.CLICKHOUSE:
enable_left_join_null = "--"
cast_begin = "accurateCastOrNull("
cast_as = ", '"
cast_end = "')"

# TODO move all cdc columns out of scd models
cdc_active_row_pattern = ""
cdc_updated_order_pattern = ""
cdc_cols = ""
quoted_cdc_cols = ""
if "_ab_cdc_deleted_at" in column_names.keys():
col_cdc_deleted_at = self.name_transformer.normalize_column_name("_ab_cdc_deleted_at")
col_cdc_updated_at = self.name_transformer.normalize_column_name("_ab_cdc_updated_at")
quoted_col_cdc_deleted_at = self.name_transformer.normalize_column_name("_ab_cdc_deleted_at", in_jinja=True)
quoted_col_cdc_updated_at = self.name_transformer.normalize_column_name("_ab_cdc_updated_at", in_jinja=True)
cdc_active_row_pattern = f" and {col_cdc_deleted_at} is null"
cdc_updated_order_pattern = f", {col_cdc_updated_at} desc"
cdc_cols = (
f", {cast_begin}{col_cdc_deleted_at}{cast_as}"
+ "{{ dbt_utils.type_string() }}"
+ f"{cast_end}"
+ f", {cast_begin}{col_cdc_updated_at}{cast_as}"
+ "{{ dbt_utils.type_string() }}"
+ f"{cast_end}"
)
quoted_cdc_cols = f", {quoted_col_cdc_deleted_at}, {quoted_col_cdc_updated_at}"

if "_ab_cdc_log_pos" in column_names.keys():
col_cdc_log_pos = self.name_transformer.normalize_column_name("_ab_cdc_log_pos")
quoted_col_cdc_log_pos = self.name_transformer.normalize_column_name("_ab_cdc_log_pos", in_jinja=True)
cdc_updated_order_pattern += f", {col_cdc_log_pos} desc"
cdc_cols += f", {cast_begin}{col_cdc_log_pos}{cast_as}" + "{{ dbt_utils.type_string() }}" + f"{cast_end}"
quoted_cdc_cols += f", {quoted_col_cdc_log_pos}"

jinja_variables = {
"active_row": self.name_transformer.normalize_column_name("_airbyte_active_row"),
"airbyte_end_at": self.name_transformer.normalize_column_name("_airbyte_end_at"),
"airbyte_row_num": self.name_transformer.normalize_column_name("_airbyte_row_num"),
"airbyte_start_at": self.name_transformer.normalize_column_name("_airbyte_start_at"),
"airbyte_unique_key_scd": self.name_transformer.normalize_column_name(f"{self.airbyte_unique_key}_scd"),
"cdc_active_row": cdc_active_row_pattern,
"cdc_cols": cdc_cols,
"cdc_updated_at_order": cdc_updated_order_pattern,
"col_ab_id": self.get_ab_id(),
"col_emitted_at": self.get_emitted_at(),
"col_normalized_at": self.get_normalized_at(),
"cursor_field": self.get_cursor_field(column_names),
"enable_left_join_null": enable_left_join_null,
"fields": self.list_fields(column_names),
"from_table": from_table,
"hash_id": self.hash_id(),
"input_data_table": input_data_table,
"lag_begin": lag_begin,
"lag_end": lag_end,
"order_null": order_null,
"parent_hash_id": self.parent_hash_id(),
"primary_key_partition": self.get_primary_key_partition(column_names),
"primary_keys": self.list_primary_keys(column_names),
"quoted_airbyte_row_num": self.name_transformer.normalize_column_name("_airbyte_row_num", in_jinja=True),
"quoted_airbyte_start_at": self.name_transformer.normalize_column_name("_airbyte_start_at", in_jinja=True),
"quoted_cdc_cols": quoted_cdc_cols,
"quoted_col_emitted_at": self.get_emitted_at(in_jinja=True),
"quoted_unique_key": self.get_unique_key(in_jinja=True),
"sql_table_comment": self.sql_table_comment(include_from_table=True),
"unique_key": self.get_unique_key(),
}
if self.destination_type == DestinationType.CLICKHOUSE:
clickhouse_active_row_sql = Template(
"""
input_data_with_active_row_num as (
select *,
row_number() over (
partition by {{ primary_key_partition | join(", ") }}
order by
{{ cursor_field }} {{ order_null }},
{{ cursor_field }} desc,
{{ col_emitted_at }} desc{{ cdc_updated_at_order }}
) as _airbyte_active_row_num
from input_data
),"""
).render(jinja_variables)
jinja_variables["clickhouse_active_row_sql"] = clickhouse_active_row_sql
scd_columns_sql = Template(
"""
case when _airbyte_active_row_num = 1{{ cdc_active_row }} then 1 else 0 end as {{ active_row }},
{{ lag_begin }}({{ cursor_field }}) over (
partition by {{ primary_key_partition | join(", ") }}
order by
{{ cursor_field }} {{ order_null }},
{{ cursor_field }} desc,
{{ col_emitted_at }} desc{{ cdc_updated_at_order }}
{{ lag_end }}
) as {{ airbyte_end_at }}"""
).render(jinja_variables)
jinja_variables["scd_columns_sql"] = scd_columns_sql
else:
scd_columns_sql = Template(
"""
lag({{ cursor_field }}) over (
partition by {{ primary_key_partition | join(", ") }}
order by
{{ cursor_field }} {{ order_null }},
{{ cursor_field }} desc,
{{ col_emitted_at }} desc{{ cdc_updated_at_order }}
) as {{ airbyte_end_at }},
case when row_number() over (
partition by {{ primary_key_partition | join(", ") }}
order by
{{ cursor_field }} {{ order_null }},
{{ cursor_field }} desc,
{{ col_emitted_at }} desc{{ cdc_updated_at_order }}
) = 1{{ cdc_active_row }} then 1 else 0 end as {{ active_row }}"""
).render(jinja_variables)
jinja_variables["scd_columns_sql"] = scd_columns_sql
sql = Template(
"""
-- depends_on: {{ from_table }}
with
{{ '{% if is_incremental() %}' }}
Expand Down Expand Up @@ -699,19 +836,7 @@ def generate_scd_type_2_model(self, from_table: str, column_names: Dict[str, Tup
{{ sql_table_comment }}
),
{{ '{% endif %}' }}
{{ '{%- if var("destination") == "clickhouse" %}' }}
input_data_with_active_row_num as (
select *,
row_number() over (
partition by {{ primary_key_partition | join(", ") }}
order by
{{ cursor_field }} {{ order_null }},
{{ cursor_field }} desc,
{{ col_emitted_at }} desc{{ cdc_updated_at_order }}
) as _airbyte_active_row_num
from input_data
),
{{ '{%- endif %}' }}
{{ clickhouse_active_row_sql }}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does jinja insert an empty string if this variable is unset?

scd_data as (
-- SQL model to build a Type 2 Slowly Changing Dimension (SCD) table for each record identified by their primary key
select
Expand All @@ -727,32 +852,7 @@ def generate_scd_type_2_model(self, from_table: str, column_names: Dict[str, Tup
{{ field }},
{%- endfor %}
{{ cursor_field }} as {{ airbyte_start_at }},
{{ '{%- if var("destination") == "clickhouse" %}' }}
case when _airbyte_active_row_num = 1{{ cdc_active_row }} then 1 else 0 end as {{ active_row }},
{{ lag_begin }}({{ cursor_field }}) over (
partition by {{ primary_key_partition | join(", ") }}
order by
{{ cursor_field }} {{ order_null }},
{{ cursor_field }} desc,
{{ col_emitted_at }} desc{{ cdc_updated_at_order }}
{{ lag_end }}
) as {{ airbyte_end_at }},
{{ '{%- else %}' }}
lag({{ cursor_field }}) over (
partition by {{ primary_key_partition | join(", ") }}
order by
{{ cursor_field }} {{ order_null }},
{{ cursor_field }} desc,
{{ col_emitted_at }} desc{{ cdc_updated_at_order }}
) as {{ airbyte_end_at }},
case when row_number() over (
partition by {{ primary_key_partition | join(", ") }}
order by
{{ cursor_field }} {{ order_null }},
{{ cursor_field }} desc,
{{ col_emitted_at }} desc{{ cdc_updated_at_order }}
) = 1{{ cdc_active_row }} then 1 else 0 end as {{ active_row }},
{{ '{%- endif %}' }}
{{ scd_columns_sql }},
{{ col_ab_id }},
{{ col_emitted_at }},
{{ hash_id }}
Expand Down Expand Up @@ -791,97 +891,8 @@ def generate_scd_type_2_model(self, from_table: str, column_names: Dict[str, Tup
{{ '{{ current_timestamp() }}' }} as {{ col_normalized_at }},
{{ hash_id }}
from dedup_data where {{ airbyte_row_num }} = 1
"""
template = Template(scd_sql_template)

order_null = "is null asc"
if self.destination_type.value == DestinationType.ORACLE.value:
order_null = "asc nulls last"
if self.destination_type.value == DestinationType.MSSQL.value:
# SQL Server treats NULL values as the lowest values, then sorted in ascending order, NULLs come first.
order_null = "desc"

lag_begin = "lag"
lag_end = ""
input_data_table = "input_data"
if self.destination_type == DestinationType.CLICKHOUSE:
# ClickHouse doesn't support lag() yet, this is a workaround solution
# Ref: https://clickhouse.com/docs/en/sql-reference/window-functions/
lag_begin = "anyOrNull"
lag_end = "ROWS BETWEEN 1 PRECEDING AND 1 PRECEDING"
input_data_table = "input_data_with_active_row_num"

enable_left_join_null = ""
cast_begin = "cast("
cast_as = " as "
cast_end = ")"
if self.destination_type == DestinationType.CLICKHOUSE:
enable_left_join_null = "--"
cast_begin = "accurateCastOrNull("
cast_as = ", '"
cast_end = "')"

# TODO move all cdc columns out of scd models
cdc_active_row_pattern = ""
cdc_updated_order_pattern = ""
cdc_cols = ""
quoted_cdc_cols = ""
if "_ab_cdc_deleted_at" in column_names.keys():
col_cdc_deleted_at = self.name_transformer.normalize_column_name("_ab_cdc_deleted_at")
col_cdc_updated_at = self.name_transformer.normalize_column_name("_ab_cdc_updated_at")
quoted_col_cdc_deleted_at = self.name_transformer.normalize_column_name("_ab_cdc_deleted_at", in_jinja=True)
quoted_col_cdc_updated_at = self.name_transformer.normalize_column_name("_ab_cdc_updated_at", in_jinja=True)
cdc_active_row_pattern = f" and {col_cdc_deleted_at} is null"
cdc_updated_order_pattern = f", {col_cdc_updated_at} desc"
cdc_cols = (
f", {cast_begin}{col_cdc_deleted_at}{cast_as}"
+ "{{ dbt_utils.type_string() }}"
+ f"{cast_end}"
+ f", {cast_begin}{col_cdc_updated_at}{cast_as}"
+ "{{ dbt_utils.type_string() }}"
+ f"{cast_end}"
)
quoted_cdc_cols = f", {quoted_col_cdc_deleted_at}, {quoted_col_cdc_updated_at}"

if "_ab_cdc_log_pos" in column_names.keys():
col_cdc_log_pos = self.name_transformer.normalize_column_name("_ab_cdc_log_pos")
quoted_col_cdc_log_pos = self.name_transformer.normalize_column_name("_ab_cdc_log_pos", in_jinja=True)
cdc_updated_order_pattern += f", {col_cdc_log_pos} desc"
cdc_cols += f", {cast_begin}{col_cdc_log_pos}{cast_as}" + "{{ dbt_utils.type_string() }}" + f"{cast_end}"
quoted_cdc_cols += f", {quoted_col_cdc_log_pos}"

sql = template.render(
order_null=order_null,
airbyte_start_at=self.name_transformer.normalize_column_name("_airbyte_start_at"),
quoted_airbyte_start_at=self.name_transformer.normalize_column_name("_airbyte_start_at", in_jinja=True),
airbyte_end_at=self.name_transformer.normalize_column_name("_airbyte_end_at"),
active_row=self.name_transformer.normalize_column_name("_airbyte_active_row"),
airbyte_row_num=self.name_transformer.normalize_column_name("_airbyte_row_num"),
quoted_airbyte_row_num=self.name_transformer.normalize_column_name("_airbyte_row_num", in_jinja=True),
airbyte_unique_key_scd=self.name_transformer.normalize_column_name(f"{self.airbyte_unique_key}_scd"),
unique_key=self.get_unique_key(),
quoted_unique_key=self.get_unique_key(in_jinja=True),
col_ab_id=self.get_ab_id(),
col_emitted_at=self.get_emitted_at(),
quoted_col_emitted_at=self.get_emitted_at(in_jinja=True),
col_normalized_at=self.get_normalized_at(),
parent_hash_id=self.parent_hash_id(),
fields=self.list_fields(column_names),
cursor_field=self.get_cursor_field(column_names),
primary_keys=self.list_primary_keys(column_names),
primary_key_partition=self.get_primary_key_partition(column_names),
hash_id=self.hash_id(),
from_table=from_table,
sql_table_comment=self.sql_table_comment(include_from_table=True),
cdc_active_row=cdc_active_row_pattern,
cdc_updated_at_order=cdc_updated_order_pattern,
cdc_cols=cdc_cols,
quoted_cdc_cols=quoted_cdc_cols,
lag_begin=lag_begin,
lag_end=lag_end,
enable_left_join_null=enable_left_join_null,
input_data_table=input_data_table,
)
"""
).render(jinja_variables)
return sql

def get_cursor_field(self, column_names: Dict[str, Tuple[str, str]], in_jinja: bool = False) -> str:
Expand Down