From daaa4cc5083ccaf08fe55c0f4c05c94dd2b2cfc9 Mon Sep 17 00:00:00 2001 From: satoshi-sh Date: Mon, 18 Mar 2024 14:49:42 -0500 Subject: [PATCH 1/3] Fix:fixed first 4 files for d105. --- airflow/models/abstractoperator.py | 2 ++ airflow/models/baseoperator.py | 8 ++++++++ airflow/models/connection.py | 1 + airflow/models/dag.py | 12 ++++++++++++ pyproject.toml | 4 ---- 5 files changed, 23 insertions(+), 4 deletions(-) diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index f2d179f01ba35..ac3bd73c81bff 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -718,10 +718,12 @@ def _do_render_template_fields( setattr(parent, attr_name, rendered_content) def __enter__(self): + """Enter the context manager for setup or teardown tasks.""" if not self.is_setup and not self.is_teardown: raise AirflowException("Only setup/teardown tasks can be used as context managers.") SetupTeardownContext.push_setup_teardown_task(self) return SetupTeardownContext def __exit__(self, exc_type, exc_val, exc_tb): + """Exit the context manager for setup or teardown tasks.""" SetupTeardownContext.set_work_task_roots_and_leaves() diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 72fa1aff264f6..606191503c1ac 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -995,6 +995,7 @@ def __init__( SetupTeardownContext.update_context_map(self) def __eq__(self, other): + """Determine whether two instances of BaseOperator are equal.""" if type(self) is type(other): # Use getattr() instead of __dict__ as __dict__ doesn't return # correct values for properties. 
@@ -1002,9 +1003,11 @@ def __eq__(self, other): return False def __ne__(self, other): + """Determine whether two instances of BaseOperator are not equal.""" return not self == other def __hash__(self): + """Compute the hash value of the BaseOperator instance.""" hash_components = [type(self)] for component in self._comps: val = getattr(self, component, None) @@ -1068,6 +1071,7 @@ def __lt__(self, other): return self def __setattr__(self, key, value): + """Set the value of attributes for instances of the BaseOperator class.""" super().__setattr__(key, value) if self.__from_mapped or self._lock_for_execution: return # Skip any custom behavior for validation and during execute. @@ -1219,6 +1223,7 @@ def on_kill(self) -> None: """ def __deepcopy__(self, memo): + """Deep copy instances of the BaseOperator classes.""" # Hack sorting double chained task lists by task_id to avoid hitting # max_depth on deepcopy operations. sys.setrecursionlimit(5000) # TODO fix this in a better way @@ -1241,6 +1246,7 @@ def __deepcopy__(self, memo): return result def __getstate__(self): + """Get the state of the object for serialization.""" state = dict(self.__dict__) if self._log: del state["_log"] @@ -1248,6 +1254,7 @@ def __getstate__(self): return state def __setstate__(self, state): + """Restore the object's state from a serialized state dictionary.""" self.__dict__ = state def render_template_fields( @@ -1401,6 +1408,7 @@ def get_direct_relatives(self, upstream: bool = False) -> Iterable[Operator]: return self.downstream_list def __repr__(self): + """Return a string representation of the Task object.""" return f"" @property diff --git a/airflow/models/connection.py b/airflow/models/connection.py index b9b4975f89835..1812990c19dbc 100644 --- a/airflow/models/connection.py +++ b/airflow/models/connection.py @@ -416,6 +416,7 @@ def get_hook(self, *, hook_params=None): return hook_class(**{hook.connection_id_attribute_name: self.conn_id}, **hook_params) def __repr__(self): + """Return a 
string representation of the Connection object.""" return self.conn_id or "" def log_info(self): diff --git a/airflow/models/dag.py b/airflow/models/dag.py index a230c94fd7529..a983b452308fe 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -198,6 +198,7 @@ def __init__(self, instance: Any, start_field_name: str, end_field_name: str) -> self._end_field = (end_field_name, getattr(instance, end_field_name)) def __str__(self) -> str: + """Return a string representation of the InconsistentDataInterval exception.""" return self._template.format(cls=self._class_name, start=self._start_field, end=self._end_field) @@ -781,9 +782,11 @@ def validate_setup_teardown(self): FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule) def __repr__(self): + """Return a string representation of the DAG object.""" return f"" def __eq__(self, other): + """Check if two DAG objects are equal.""" if type(self) == type(other): # Use getattr() instead of __dict__ as __dict__ doesn't return # correct values for properties. 
@@ -791,12 +794,15 @@ def __eq__(self, other): return False def __ne__(self, other): + """Check if two DAG objects are not equal.""" return not self == other def __lt__(self, other): + """Compare two DAG objects based on their DAG ID.""" return self.dag_id < other.dag_id def __hash__(self): + """Compute the hash value of the DAG object.""" hash_components = [type(self)] for c in self._comps: # task_ids returns a list and lists can't be hashed @@ -813,10 +819,12 @@ def __hash__(self): # Context Manager ----------------------------------------------- def __enter__(self): + """Enter the context manager scope for the DAG object.""" DagContext.push_context_managed_dag(self) return self def __exit__(self, _type, _value, _tb): + """Exit the context manager scope for the DAG object.""" DagContext.pop_context_managed_dag() # /Context Manager ---------------------------------------------- @@ -2438,6 +2446,7 @@ def clear_dags( return count def __deepcopy__(self, memo): + """Create a deep copy of the DAG object.""" # Switcharoo to go around deepcopying objects coming through the # backdoor cls = self.__class__ @@ -3521,6 +3530,7 @@ class DagTag(Base): ) def __repr__(self): + """Return a string representation of the DagTag object.""" return self.name @@ -3542,6 +3552,7 @@ class DagOwnerAttributes(Base): link = Column(String(500), nullable=False) def __repr__(self): + """Return a string representation of the DagOwnerAttributes object.""" return f"" @classmethod @@ -3670,6 +3681,7 @@ def __init__(self, concurrency=None, **kwargs): self.has_task_concurrency_limits = True def __repr__(self): + """Return a string representation of the DagModel object.""" return f"" @property diff --git a/pyproject.toml b/pyproject.toml index b913274769daa..9bc67e4d705c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1433,10 +1433,6 @@ combine-as-imports = true "airflow/kubernetes/pre_7_4_0_compatibility/secret.py" = ["D105"] "airflow/metrics/protocols.py" = ["D105"] 
"airflow/metrics/validators.py" = ["D105"] -"airflow/models/abstractoperator.py" = ["D105"] -"airflow/models/baseoperator.py" = ["D105"] -"airflow/models/connection.py" = ["D105"] -"airflow/models/dag.py" = ["D105"] "airflow/models/dagrun.py" = ["D105"] "airflow/models/dagwarning.py" = ["D105"] "airflow/models/dataset.py" = ["D105"] From 6b2cc437e27786bc23bbe6e92418a27b4fc7c9cc Mon Sep 17 00:00:00 2001 From: satoshi-sh Date: Mon, 18 Mar 2024 18:07:47 -0500 Subject: [PATCH 2/3] fix:fixed 6 files. --- airflow/models/dagrun.py | 10 +++++++--- airflow/models/dagwarning.py | 2 ++ airflow/models/dataset.py | 13 +++++++++++++ airflow/models/expandinput.py | 1 + airflow/models/log.py | 1 + airflow/models/mappedoperator.py | 6 ++++++ pyproject.toml | 6 ------ 7 files changed, 30 insertions(+), 9 deletions(-) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index f9e7022b5fd37..edae65129d129 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -241,6 +241,7 @@ def __init__( super().__init__() def __repr__(self): + """Return a string representation of the DagRun object.""" return ( f"" @@ -896,9 +897,11 @@ def recalculate(self) -> _UnfinishedStates: self.run_id, self.start_date, self.end_date, - (self.end_date - self.start_date).total_seconds() - if self.start_date and self.end_date - else None, + ( + (self.end_date - self.start_date).total_seconds() + if self.start_date and self.end_date + else None + ), self._state, self.external_trigger, self.run_type, @@ -1631,6 +1634,7 @@ def __init__(self, content, user_id=None): self.user_id = user_id def __repr__(self): + """Return a string representation of the DagRunInfo object.""" prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.dagrun_id} {self.run_id}" if self.map_index != -1: prefix += f" map_index={self.map_index}" diff --git a/airflow/models/dagwarning.py b/airflow/models/dagwarning.py index 789fe0172784b..2f330d4927ecf 100644 --- a/airflow/models/dagwarning.py +++ 
b/airflow/models/dagwarning.py @@ -64,9 +64,11 @@ def __init__(self, dag_id: str, error_type: str, message: str, **kwargs): self.message = message def __eq__(self, other) -> bool: + """Compare two DagRunWarning objects for equality.""" return self.dag_id == other.dag_id and self.warning_type == other.warning_type def __hash__(self) -> int: + """Generate a hash value for the DagRunWarning object.""" return hash((self.dag_id, self.warning_type)) @classmethod diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py index aa10eb3809756..b8c67ad584a5c 100644 --- a/airflow/models/dataset.py +++ b/airflow/models/dataset.py @@ -91,15 +91,18 @@ def __init__(self, uri: str, **kwargs): super().__init__(uri=uri, **kwargs) def __eq__(self, other): + """Compare two DatasetModel objects for equality.""" if isinstance(other, (self.__class__, Dataset)): return self.uri == other.uri else: return NotImplemented def __hash__(self): + """Generate a hash value for the DatasetModel object.""" return hash(self.uri) def __repr__(self): + """Return a string representation of the DatasetModel object.""" return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})" @@ -139,15 +142,18 @@ class DagScheduleDatasetReference(Base): ) def __eq__(self, other): + """Compare two DagScheduleDatasetReference objects for equality.""" if isinstance(other, self.__class__): return self.dataset_id == other.dataset_id and self.dag_id == other.dag_id else: return NotImplemented def __hash__(self): + """Generate a hash value for the DagScheduleDatasetReference object.""" return hash(self.__mapper__.primary_key) def __repr__(self): + """Return a string representation of the DagScheduleDatasetReference object.""" args = [] for attr in [x.name for x in self.__mapper__.primary_key]: args.append(f"{attr}={getattr(self, attr)!r}") @@ -183,6 +189,7 @@ class TaskOutletDatasetReference(Base): ) def __eq__(self, other): + """Compare two TaskOutletDatasetReference objects for equality.""" if 
isinstance(other, self.__class__): return ( self.dataset_id == other.dataset_id @@ -193,9 +200,11 @@ def __eq__(self, other): return NotImplemented def __hash__(self): + """Generate a hash value for the TaskOutletDatasetReference object.""" return hash(self.__mapper__.primary_key) def __repr__(self): + """Return a string representation of the TaskOutletDatasetReference object.""" args = [] for attr in [x.name for x in self.__mapper__.primary_key]: args.append(f"{attr}={getattr(self, attr)!r}") @@ -227,15 +236,18 @@ class DatasetDagRunQueue(Base): ) def __eq__(self, other): + """Compare two DatasetDagRunQueue objects for equality.""" if isinstance(other, self.__class__): return self.dataset_id == other.dataset_id and self.target_dag_id == other.target_dag_id else: return NotImplemented def __hash__(self): + """Generate a hash value for the DatasetDagRunQueue object.""" return hash(self.__mapper__.primary_key) def __repr__(self): + """Return a string representation of the DatasetDagRunQueue object.""" args = [] for attr in [x.name for x in self.__mapper__.primary_key]: args.append(f"{attr}={getattr(self, attr)!r}") @@ -324,6 +336,7 @@ def uri(self): return self.dataset.uri def __repr__(self) -> str: + """Return a string representation of the DatasetEvent object.""" args = [] for attr in [ "id", diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py index a20280e5a1111..cef92870b5acd 100644 --- a/airflow/models/expandinput.py +++ b/airflow/models/expandinput.py @@ -104,6 +104,7 @@ def __init__(self, missing: set[str]) -> None: self.missing = missing def __str__(self) -> str: + """Return a string representation of the exception.""" keys = ", ".join(repr(k) for k in sorted(self.missing)) return f"Failed to populate all mapping metadata; missing: {keys}" diff --git a/airflow/models/log.py b/airflow/models/log.py index fbf26b23276d6..99733623f01cc 100644 --- a/airflow/models/log.py +++ b/airflow/models/log.py @@ -78,4 +78,5 @@ def __init__(self, event, 
task_instance=None, owner=None, owner_display_name=Non self.owner_display_name = owner_display_name or None def __str__(self) -> str: + """Return a string representation of the Log object.""" return f"Log({self.event}, {self.task_id}, {self.owner}, {self.owner_display_name}, {self.extra})" diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index c1d75fc11ea92..01e2c35e63fc2 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -154,6 +154,7 @@ class OperatorPartial: _expand_called: bool = False # Set when expand() is called to ease user debugging. def __attrs_post_init__(self): + """Perform post-initialization actions for the OperatorPartial object.""" from airflow.operators.subdag import SubDagOperator if issubclass(self.operator_class, SubDagOperator): @@ -161,10 +162,12 @@ def __attrs_post_init__(self): validate_mapping_kwargs(self.operator_class, "partial", self.kwargs) def __repr__(self) -> str: + """Return a string representation of the OperatorPartial object.""" args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items()) return f"{self.operator_class.__name__}.partial({args})" def __del__(self): + """Perform cleanup actions before the OperatorPartial object is destroyed.""" if not self._expand_called: try: task_id = repr(self.kwargs["task_id"]) @@ -309,12 +312,15 @@ class MappedOperator(AbstractOperator): ) def __hash__(self): + """Generate a hash value for the MappedOperator object.""" return id(self) def __repr__(self): + """Return a string representation of the MappedOperator object.""" return f"" def __attrs_post_init__(self): + """Perform post-initialization actions for the MappedOperator object.""" from airflow.models.xcom_arg import XComArg if self.get_closest_mapped_task_group() is not None: diff --git a/pyproject.toml b/pyproject.toml index 9bc67e4d705c3..2de9b8b11ab0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1433,12 +1433,6 @@ combine-as-imports = true 
"airflow/kubernetes/pre_7_4_0_compatibility/secret.py" = ["D105"] "airflow/metrics/protocols.py" = ["D105"] "airflow/metrics/validators.py" = ["D105"] -"airflow/models/dagrun.py" = ["D105"] -"airflow/models/dagwarning.py" = ["D105"] -"airflow/models/dataset.py" = ["D105"] -"airflow/models/expandinput.py" = ["D105"] -"airflow/models/log.py" = ["D105"] -"airflow/models/mappedoperator.py" = ["D105"] "airflow/models/param.py" = ["D105"] "airflow/models/pool.py" = ["D105"] "airflow/models/renderedtifields.py" = ["D105"] From 96f38f1f6a66e0c73643852842048087fff4e38b Mon Sep 17 00:00:00 2001 From: satoshi-sh Date: Mon, 18 Mar 2024 18:37:23 -0500 Subject: [PATCH 3/3] Reverted auto-formatted code. --- airflow/models/dagrun.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index edae65129d129..05901eddb90b6 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -897,11 +897,9 @@ def recalculate(self) -> _UnfinishedStates: self.run_id, self.start_date, self.end_date, - ( - (self.end_date - self.start_date).total_seconds() - if self.start_date and self.end_date - else None - ), + (self.end_date - self.start_date).total_seconds() + if self.start_date and self.end_date + else None, self._state, self.external_trigger, self.run_type,