Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply D105 to the Models Module Partly #38277

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,10 +718,12 @@ def _do_render_template_fields(
setattr(parent, attr_name, rendered_content)

def __enter__(self):
"""Enter the context manager for setup or teardown tasks."""
if not self.is_setup and not self.is_teardown:
raise AirflowException("Only setup/teardown tasks can be used as context managers.")
SetupTeardownContext.push_setup_teardown_task(self)
return SetupTeardownContext

def __exit__(self, exc_type, exc_val, exc_tb):
"""Exit the context manager for setup or teardown tasks."""
SetupTeardownContext.set_work_task_roots_and_leaves()
8 changes: 8 additions & 0 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,16 +995,19 @@ def __init__(
SetupTeardownContext.update_context_map(self)

def __eq__(self, other):
"""Determine whether two instances of BaseOperator are equal."""
if type(self) is type(other):
# Use getattr() instead of __dict__ as __dict__ doesn't return
# correct values for properties.
return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps)
return False

def __ne__(self, other):
"""Determine whether two instances of BaseOperator are not equal."""
return not self == other

def __hash__(self):
"""Compute the hash value of the BaseOperator instance."""
hash_components = [type(self)]
for component in self._comps:
val = getattr(self, component, None)
Expand Down Expand Up @@ -1068,6 +1071,7 @@ def __lt__(self, other):
return self

def __setattr__(self, key, value):
"""Set the value of attributes for instances of the BaseOperator class."""
super().__setattr__(key, value)
if self.__from_mapped or self._lock_for_execution:
return # Skip any custom behavior for validation and during execute.
Expand Down Expand Up @@ -1219,6 +1223,7 @@ def on_kill(self) -> None:
"""

def __deepcopy__(self, memo):
"""Deep copy instances of the BaseOperator classes."""
# Hack sorting double chained task lists by task_id to avoid hitting
# max_depth on deepcopy operations.
sys.setrecursionlimit(5000) # TODO fix this in a better way
Expand All @@ -1241,13 +1246,15 @@ def __deepcopy__(self, memo):
return result

def __getstate__(self):
"""Get the state of the object for serialization."""
state = dict(self.__dict__)
if self._log:
del state["_log"]

return state

def __setstate__(self, state):
"""Restore the object's state from a serialized state dictionary."""
self.__dict__ = state

def render_template_fields(
Expand Down Expand Up @@ -1401,6 +1408,7 @@ def get_direct_relatives(self, upstream: bool = False) -> Iterable[Operator]:
return self.downstream_list

def __repr__(self):
"""Return a string representation of the Task object."""
return f"<Task({self.task_type}): {self.task_id}>"

@property
Expand Down
1 change: 1 addition & 0 deletions airflow/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ def get_hook(self, *, hook_params=None):
return hook_class(**{hook.connection_id_attribute_name: self.conn_id}, **hook_params)

def __repr__(self):
"""Return a string representation of the Connection object."""
return self.conn_id or ""

def log_info(self):
Expand Down
12 changes: 12 additions & 0 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def __init__(self, instance: Any, start_field_name: str, end_field_name: str) ->
self._end_field = (end_field_name, getattr(instance, end_field_name))

def __str__(self) -> str:
"""Return a string representation of the InconsistentDataInterval exception."""
return self._template.format(cls=self._class_name, start=self._start_field, end=self._end_field)


Expand Down Expand Up @@ -781,22 +782,27 @@ def validate_setup_teardown(self):
FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule)

def __repr__(self):
"""Return a string representation of the DAG object."""
return f"<DAG: {self.dag_id}>"

def __eq__(self, other):
"""Check if two DAG objects are equal."""
if type(self) == type(other):
# Use getattr() instead of __dict__ as __dict__ doesn't return
# correct values for properties.
return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps)
return False

def __ne__(self, other):
"""Check if two DAG objects are not equal."""
return not self == other

def __lt__(self, other):
"""Compare two DAG objects based on their DAG ID."""
return self.dag_id < other.dag_id

def __hash__(self):
"""Compute the hash value of the DAG object."""
hash_components = [type(self)]
for c in self._comps:
# task_ids returns a list and lists can't be hashed
Expand All @@ -813,10 +819,12 @@ def __hash__(self):

# Context Manager -----------------------------------------------
def __enter__(self):
"""Enter the context manager scope for the DAG object."""
DagContext.push_context_managed_dag(self)
return self

def __exit__(self, _type, _value, _tb):
"""Exit the context manager scope for the DAG object."""
DagContext.pop_context_managed_dag()

# /Context Manager ----------------------------------------------
Expand Down Expand Up @@ -2438,6 +2446,7 @@ def clear_dags(
return count

def __deepcopy__(self, memo):
"""Create a deep copy of the DAG object."""
# Switcharoo to go around deepcopying objects coming through the
# backdoor
cls = self.__class__
Expand Down Expand Up @@ -3521,6 +3530,7 @@ class DagTag(Base):
)

def __repr__(self):
"""Return a string representation of the DagTag object."""
return self.name


Expand All @@ -3542,6 +3552,7 @@ class DagOwnerAttributes(Base):
link = Column(String(500), nullable=False)

def __repr__(self):
"""Return a string representation of the DagOwnerAttributes object."""
return f"<DagOwnerAttributes: dag_id={self.dag_id}, owner={self.owner}, link={self.link}>"

@classmethod
Expand Down Expand Up @@ -3670,6 +3681,7 @@ def __init__(self, concurrency=None, **kwargs):
self.has_task_concurrency_limits = True

def __repr__(self):
"""Return a string representation of the DagModel object."""
return f"<DAG: {self.dag_id}>"

@property
Expand Down
2 changes: 2 additions & 0 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def __init__(
super().__init__()

def __repr__(self):
"""Return a string representation of the DagRun object."""
return (
f"<DagRun {self.dag_id} @ {self.execution_date}: {self.run_id}, state:{self.state}, "
f"queued_at: {self.queued_at}. externally triggered: {self.external_trigger}>"
Expand Down Expand Up @@ -1631,6 +1632,7 @@ def __init__(self, content, user_id=None):
self.user_id = user_id

def __repr__(self):
"""Return a string representation of the DagRunInfo object."""
prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.dagrun_id} {self.run_id}"
if self.map_index != -1:
prefix += f" map_index={self.map_index}"
Expand Down
2 changes: 2 additions & 0 deletions airflow/models/dagwarning.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,11 @@ def __init__(self, dag_id: str, error_type: str, message: str, **kwargs):
self.message = message

def __eq__(self, other) -> bool:
"""Compare two DagRunWarning objects for equality."""
return self.dag_id == other.dag_id and self.warning_type == other.warning_type

def __hash__(self) -> int:
"""Generate a hash value for the DagRunWarning object."""
return hash((self.dag_id, self.warning_type))

@classmethod
Expand Down
13 changes: 13 additions & 0 deletions airflow/models/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,18 @@ def __init__(self, uri: str, **kwargs):
super().__init__(uri=uri, **kwargs)

def __eq__(self, other):
"""Compare two DatasetModel objects for equality."""
if isinstance(other, (self.__class__, Dataset)):
return self.uri == other.uri
else:
return NotImplemented

def __hash__(self):
"""Generate a hash value for the DatasetModel object."""
return hash(self.uri)

def __repr__(self):
"""Return a string representation of the DatasetModel object."""
return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})"


Expand Down Expand Up @@ -139,15 +142,18 @@ class DagScheduleDatasetReference(Base):
)

def __eq__(self, other):
"""Compare two DagScheduleDatasetReference objects for equality."""
if isinstance(other, self.__class__):
return self.dataset_id == other.dataset_id and self.dag_id == other.dag_id
else:
return NotImplemented

def __hash__(self):
"""Generate a hash value for the DagScheduleDatasetReference object."""
return hash(self.__mapper__.primary_key)

def __repr__(self):
"""Return a string representation of the DagScheduleDatasetReference object."""
args = []
for attr in [x.name for x in self.__mapper__.primary_key]:
args.append(f"{attr}={getattr(self, attr)!r}")
Expand Down Expand Up @@ -183,6 +189,7 @@ class TaskOutletDatasetReference(Base):
)

def __eq__(self, other):
"""Compare two TaskOutletDatasetReference objects for equality."""
if isinstance(other, self.__class__):
return (
self.dataset_id == other.dataset_id
Expand All @@ -193,9 +200,11 @@ def __eq__(self, other):
return NotImplemented

def __hash__(self):
"""Generate a hash value for the TaskOutletDatasetReference object."""
return hash(self.__mapper__.primary_key)

def __repr__(self):
"""Return a string representation of the TaskOutletDatasetReference object."""
args = []
for attr in [x.name for x in self.__mapper__.primary_key]:
args.append(f"{attr}={getattr(self, attr)!r}")
Expand Down Expand Up @@ -227,15 +236,18 @@ class DatasetDagRunQueue(Base):
)

def __eq__(self, other):
"""Compare two DatasetDagRunQueue objects for equality."""
if isinstance(other, self.__class__):
return self.dataset_id == other.dataset_id and self.target_dag_id == other.target_dag_id
else:
return NotImplemented

def __hash__(self):
"""Generate a hash value for the DatasetDagRunQueue object."""
return hash(self.__mapper__.primary_key)

def __repr__(self):
"""Return a string representation of the DatasetDagRunQueue object."""
args = []
for attr in [x.name for x in self.__mapper__.primary_key]:
args.append(f"{attr}={getattr(self, attr)!r}")
Expand Down Expand Up @@ -324,6 +336,7 @@ def uri(self):
return self.dataset.uri

def __repr__(self) -> str:
"""Return a string representation of the DatasetEvent object."""
args = []
for attr in [
"id",
Expand Down
1 change: 1 addition & 0 deletions airflow/models/expandinput.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(self, missing: set[str]) -> None:
self.missing = missing

def __str__(self) -> str:
"""Return a string representation of the exception."""
keys = ", ".join(repr(k) for k in sorted(self.missing))
return f"Failed to populate all mapping metadata; missing: {keys}"

Expand Down
1 change: 1 addition & 0 deletions airflow/models/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,5 @@ def __init__(self, event, task_instance=None, owner=None, owner_display_name=Non
self.owner_display_name = owner_display_name or None

def __str__(self) -> str:
"""Return a string representation of the Log object."""
return f"Log({self.event}, {self.task_id}, {self.owner}, {self.owner_display_name}, {self.extra})"
6 changes: 6 additions & 0 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,17 +154,20 @@ class OperatorPartial:
_expand_called: bool = False # Set when expand() is called to ease user debugging.

def __attrs_post_init__(self):
"""Perform post-initialization actions for the OperatorPartial object."""
from airflow.operators.subdag import SubDagOperator

if issubclass(self.operator_class, SubDagOperator):
raise TypeError("Mapping over deprecated SubDagOperator is not supported")
validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)

def __repr__(self) -> str:
"""Return a string representation of the OperatorPartial object."""
args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items())
return f"{self.operator_class.__name__}.partial({args})"

def __del__(self):
"""Perform cleanup actions before the OperatorPartial object is destroyed."""
if not self._expand_called:
try:
task_id = repr(self.kwargs["task_id"])
Expand Down Expand Up @@ -309,12 +312,15 @@ class MappedOperator(AbstractOperator):
)

def __hash__(self):
"""Generate a hash value for the MappedOperator object."""
return id(self)

def __repr__(self):
"""Return a string representation of the MappedOperator object."""
return f"<Mapped({self._task_type}): {self.task_id}>"

def __attrs_post_init__(self):
"""Perform post-initialization actions for the MappedOperator object."""
from airflow.models.xcom_arg import XComArg

if self.get_closest_mapped_task_group() is not None:
Expand Down
10 changes: 0 additions & 10 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -1433,16 +1433,6 @@ combine-as-imports = true
"airflow/kubernetes/pre_7_4_0_compatibility/secret.py" = ["D105"]
"airflow/metrics/protocols.py" = ["D105"]
"airflow/metrics/validators.py" = ["D105"]
"airflow/models/abstractoperator.py" = ["D105"]
"airflow/models/baseoperator.py" = ["D105"]
"airflow/models/connection.py" = ["D105"]
"airflow/models/dag.py" = ["D105"]
"airflow/models/dagrun.py" = ["D105"]
"airflow/models/dagwarning.py" = ["D105"]
"airflow/models/dataset.py" = ["D105"]
"airflow/models/expandinput.py" = ["D105"]
"airflow/models/log.py" = ["D105"]
"airflow/models/mappedoperator.py" = ["D105"]
"airflow/models/param.py" = ["D105"]
"airflow/models/pool.py" = ["D105"]
"airflow/models/renderedtifields.py" = ["D105"]
Expand Down
Loading