From 85c0bb1aa7d1708d3419a44efcb4d5931d7c6ca7 Mon Sep 17 00:00:00 2001 From: Joao Amaral <7281460+joaopamaral@users.noreply.github.com> Date: Tue, 8 Oct 2024 19:50:39 -0300 Subject: [PATCH 001/349] Fix revoke stale on airflow < 2.10 --- .../fab/auth_manager/security_manager/override.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/airflow/providers/fab/auth_manager/security_manager/override.py b/airflow/providers/fab/auth_manager/security_manager/override.py index fad32c9f55ba5..17ba5d3efedb9 100644 --- a/airflow/providers/fab/auth_manager/security_manager/override.py +++ b/airflow/providers/fab/auth_manager/security_manager/override.py @@ -1172,7 +1172,14 @@ def _get_or_create_dag_permission(action_name: str, dag_resource_name: str) -> P for perm in existing_dag_perms: non_admin_roles = [role for role in perm.role if role.name != "Admin"] for role in non_admin_roles: - target_perms_for_role = access_control.get(role.name, {}).get(resource_name, set()) + access_control_role = access_control.get(role.name) + target_perms_for_role = set() + if access_control_role: + if isinstance(access_control_role, set): + target_perms_for_role = access_control_role + elif isinstance(access_control_role, dict): + target_perms_for_role = access_control.get(role.name, {}).get(resource_name, + set()) if perm.action.name not in target_perms_for_role: self.log.info( "Revoking '%s' on DAG '%s' for role '%s'", From 6e353ebda577f4233bbd1e989b053446fa6dd978 Mon Sep 17 00:00:00 2001 From: Joao Amaral <7281460+joaopamaral@users.noreply.github.com> Date: Tue, 8 Oct 2024 20:57:25 -0300 Subject: [PATCH 002/349] fix static check --- .../providers/fab/auth_manager/security_manager/override.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/airflow/providers/fab/auth_manager/security_manager/override.py b/airflow/providers/fab/auth_manager/security_manager/override.py index 17ba5d3efedb9..bf996919ab4fa 100644 --- a/airflow/providers/fab/auth_manager/security_manager/override.py +++ b/airflow/providers/fab/auth_manager/security_manager/override.py @@ -1178,8 +1178,9 @@ def _get_or_create_dag_permission(action_name: str, dag_resource_name: str) -> P if isinstance(access_control_role, set): target_perms_for_role = access_control_role elif isinstance(access_control_role, dict): - target_perms_for_role = access_control.get(role.name, {}).get(resource_name, - set()) + target_perms_for_role = access_control.get(role.name, {}).get( + resource_name, set() + ) if perm.action.name not in target_perms_for_role: self.log.info( "Revoking '%s' on DAG '%s' for role '%s'", From d3b1a3fc2ddb04ecfd53ad4a8b6b75d93e355f98 Mon Sep 17 00:00:00 2001 From: arnaubadia Date: Mon, 9 Sep 2024 18:43:33 +0200 Subject: [PATCH 003/349] Add return_immediately as argument to the PubSubPullSensor class (#41842) Co-authored-by: Arnau Badia Sampera --- airflow/providers/google/cloud/sensors/pubsub.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/airflow/providers/google/cloud/sensors/pubsub.py b/airflow/providers/google/cloud/sensors/pubsub.py index 55acee3d7034a..cb224d42979b7 100644 --- a/airflow/providers/google/cloud/sensors/pubsub.py +++ b/airflow/providers/google/cloud/sensors/pubsub.py @@ -69,6 +69,13 @@ class PubSubPullSensor(BaseSensorOperator): full subscription path. 
:param max_messages: The maximum number of messages to retrieve per PubSub pull request + :param return_immediately: If this field set to true, the system will + respond immediately even if it there are no messages available to + return in the ``Pull`` response. Otherwise, the system may wait + (for a bounded amount of time) until at least one message is available, + rather than returning no messages. Warning: setting this field to + ``true`` is discouraged because it adversely impacts the performance + of ``Pull`` operations. We recommend that users do not set this field. :param ack_messages: If True, each message will be acknowledged immediately rather than by any downstream tasks :param gcp_conn_id: The connection ID to use connecting to @@ -102,6 +109,7 @@ def __init__( project_id: str, subscription: str, max_messages: int = 5, + return_immediately: bool = True, ack_messages: bool = False, gcp_conn_id: str = "google_cloud_default", messages_callback: Callable[[list[ReceivedMessage], Context], Any] | None = None, @@ -115,6 +123,7 @@ def __init__( self.project_id = project_id self.subscription = subscription self.max_messages = max_messages + self.return_immediately = return_immediately self.ack_messages = ack_messages self.messages_callback = messages_callback self.impersonation_chain = impersonation_chain @@ -132,7 +141,7 @@ def poke(self, context: Context) -> bool: project_id=self.project_id, subscription=self.subscription, max_messages=self.max_messages, - return_immediately=True, + return_immediately=self.return_immediately, ) handle_messages = self.messages_callback or self._default_message_callback From 1da4332844369c86eaa1f0e0e79ec13b73c7839e Mon Sep 17 00:00:00 2001 From: Vincent Kling Date: Mon, 9 Sep 2024 19:43:20 +0200 Subject: [PATCH 004/349] Move DAGs table to a reusable DataTable component (#42095) * Move DAGs table to a reusable DataTable component Signed-off-by: Vincent Kling * Add missing license Signed-off-by: Vincent Kling * Fix typing Signed-off-by: Vincent Kling --------- Signed-off-by: Vincent Kling --- airflow/ui/src/components/DataTable.test.tsx | 84 ++++++++ airflow/ui/src/components/DataTable.tsx | 202 +++++++++++++++++++ airflow/ui/src/dagsList.tsx | 179 +--------------- 3 files changed, 289 insertions(+), 176 deletions(-) create mode 100644 airflow/ui/src/components/DataTable.test.tsx create mode 100644 airflow/ui/src/components/DataTable.tsx diff --git a/airflow/ui/src/components/DataTable.test.tsx b/airflow/ui/src/components/DataTable.test.tsx new file mode 100644 index 0000000000000..f98b43fe5f610 --- /dev/null +++ b/airflow/ui/src/components/DataTable.test.tsx @@ -0,0 +1,84 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +import { describe, expect, it, vi } from "vitest"; +import { render, screen } from "@testing-library/react"; +import { DataTable } from "./DataTable.tsx"; +import { ColumnDef, PaginationState } from "@tanstack/react-table"; +import "@testing-library/jest-dom"; + +const columns: ColumnDef<{ name: string }>[] = [ + { + accessorKey: "name", + header: "Name", + cell: (info) => info.getValue(), + }, +]; + +const data = [{ name: "John Doe" }, { name: "Jane Doe" }]; + +const pagination: PaginationState = { pageIndex: 0, pageSize: 1 }; +const setPagination = vi.fn(); + +describe("DataTable", () => { + it("renders table with data", () => { + render( + + ); + + expect(screen.getByText("John Doe")).toBeInTheDocument(); + expect(screen.getByText("Jane Doe")).toBeInTheDocument(); + }); + + it("disables previous page button on first page", () => { + render( + + ); + + expect(screen.getByText("<<")).toBeDisabled(); + expect(screen.getByText("<")).toBeDisabled(); + }); + + it("disables next button when on last page", () => { + render( + + ); + + expect(screen.getByText(">>")).toBeDisabled(); + expect(screen.getByText(">")).toBeDisabled(); + }); +}); diff --git a/airflow/ui/src/components/DataTable.tsx b/airflow/ui/src/components/DataTable.tsx new file mode 100644 index 0000000000000..fbdd59a1b90cd --- /dev/null +++ b/airflow/ui/src/components/DataTable.tsx @@ -0,0 +1,202 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +"use client"; + +import { + ColumnDef, + Table as TanStackTable, + flexRender, + getCoreRowModel, + getExpandedRowModel, + getPaginationRowModel, + OnChangeFn, + PaginationState, + Row, + useReactTable, +} from "@tanstack/react-table"; +import { + Box, + Button, + Table as ChakraTable, + TableContainer, + Tbody, + Td, + Th, + Thead, + Tr, +} from "@chakra-ui/react"; +import React, { Fragment } from "react"; + +type DataTableProps = { + data: TData[]; + total?: number; + columns: ColumnDef[]; + renderSubComponent?: (props: { + row: Row; + }) => React.ReactElement | null; + getRowCanExpand?: (row: Row) => boolean; + pagination: PaginationState; + setPagination: OnChangeFn; +}; + +type PaginatorProps = { + table: TanStackTable; +}; + +const TablePaginator = ({ table }: PaginatorProps) => { + const pageInterval = 3; + const currentPageNumber = table.getState().pagination.pageIndex + 1; + const startPageNumber = Math.max(1, currentPageNumber - pageInterval); + const endPageNumber = Math.min( + table.getPageCount(), + startPageNumber + pageInterval * 2 + ); + const pageNumbers = []; + + for (let index = startPageNumber; index <= endPageNumber; index++) { + pageNumbers.push( + + ); + } + + return ( + + + + + {pageNumbers} + + + + ); +}; + +export function DataTable({ + data, + total = 0, + columns, + renderSubComponent = () => null, + getRowCanExpand = () => false, + pagination, + setPagination, +}: DataTableProps) { + const table = useReactTable({ + data, + columns, + getRowCanExpand, + getCoreRowModel: getCoreRowModel(), + getExpandedRowModel: getExpandedRowModel(), + getPaginationRowModel: getPaginationRowModel(), + onPaginationChange: setPagination, + rowCount: total, + manualPagination: true, + state: { + pagination, + }, + }); + + return ( + + + + {table.getHeaderGroups().map((headerGroup) => ( + + {headerGroup.headers.map((header) => { + return ( + + {header.isPlaceholder ? null : ( +
+ {flexRender( + header.column.columnDef.header, + header.getContext() + )} +
+ )} + + ); + })} + + ))} + + + {table.getRowModel().rows.map((row) => { + return ( + + + {/* first row is a normal row */} + {row.getVisibleCells().map((cell) => { + return ( + + {flexRender( + cell.column.columnDef.cell, + cell.getContext() + )} + + ); + })} + + {row.getIsExpanded() && ( + + {/* 2nd row is a custom 1 cell row */} + + {renderSubComponent({ row })} + + + )} + + ); + })} + +
+ +
+ ); +} diff --git a/airflow/ui/src/dagsList.tsx b/airflow/ui/src/dagsList.tsx index b5ac0f5751720..e8f06545de83d 100644 --- a/airflow/ui/src/dagsList.tsx +++ b/airflow/ui/src/dagsList.tsx @@ -17,35 +17,17 @@ * under the License. */ -import React, { Fragment } from "react"; - import { - useReactTable, - getCoreRowModel, - getExpandedRowModel, - getPaginationRowModel, ColumnDef, - flexRender, Row, OnChangeFn, PaginationState, - Table as TanStackTable, } from "@tanstack/react-table"; import { MdExpandMore } from "react-icons/md"; -import { - Box, - Code, - Table as ChakraTable, - Thead, - Button, - Td, - Th, - Tr, - Tbody, - TableContainer, -} from "@chakra-ui/react"; +import { Box, Code } from "@chakra-ui/react"; import { DAG } from "openapi/requests/types.gen"; +import { DataTable } from "src/components/DataTable.tsx"; const columns: ColumnDef[] = [ { @@ -87,161 +69,6 @@ const columns: ColumnDef[] = [ }, ]; -type TableProps = { - data: TData[]; - total: number | undefined; - columns: ColumnDef[]; - renderSubComponent: (props: { row: Row }) => React.ReactElement; - getRowCanExpand: (row: Row) => boolean; - pagination: PaginationState; - setPagination: OnChangeFn; -}; - -type PaginatorProps = { - table: TanStackTable; -}; - -const TablePaginator = ({ table }: PaginatorProps) => { - const pageInterval = 3; - const currentPageNumber = table.getState().pagination.pageIndex + 1; - const startPageNumber = Math.max(1, currentPageNumber - pageInterval); - const endPageNumber = Math.min( - table.getPageCount(), - startPageNumber + pageInterval * 2 - ); - const pageNumbers = []; - - for (let index = startPageNumber; index <= endPageNumber; index++) { - pageNumbers.push( - - ); - } - - return ( - - - - - {pageNumbers} - - - - ); -}; - -const Table = ({ - data, - total, - columns, - renderSubComponent, - getRowCanExpand, - pagination, - setPagination, -}: TableProps) => { - const table = useReactTable({ - data, - columns, - getRowCanExpand, - getCoreRowModel: getCoreRowModel(), - getExpandedRowModel: getExpandedRowModel(), - getPaginationRowModel: getPaginationRowModel(), - onPaginationChange: setPagination, - rowCount: total ?? 0, - manualPagination: true, - state: { - pagination, - }, - }); - - return ( - - - - {table.getHeaderGroups().map((headerGroup) => ( - - {headerGroup.headers.map((header) => { - return ( - - {header.isPlaceholder ? null : ( -
- {flexRender( - header.column.columnDef.header, - header.getContext() - )} -
- )} - - ); - })} - - ))} - - - {table.getRowModel().rows.map((row) => { - return ( - - - {/* first row is a normal row */} - {row.getVisibleCells().map((cell) => { - return ( - - {flexRender( - cell.column.columnDef.cell, - cell.getContext() - )} - - ); - })} - - {row.getIsExpanded() && ( - - {/* 2nd row is a custom 1 cell row */} - - {renderSubComponent({ row })} - - - )} - - ); - })} - -
- -
- ); -}; - const renderSubComponent = ({ row }: { row: Row }) => { return (
@@ -262,7 +89,7 @@ export const DagsList = ({
   setPagination: OnChangeFn;
 }) => {
   return (
-    
Date: Mon, 9 Sep 2024 13:15:30 -0700
Subject: [PATCH 005/349] Fix failing compatibility test: importing missing
 saml for old airflows (#42113)

The Compatibility tests for AWS are failing after recent changes
because they attempt to import the saml library before skipping the tests
when that import is missing.
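
A minimal sketch of the pattern the diff below applies (nothing here beyond
what the patch itself shows): the skip has to run before any import that
needs the optional dependency, otherwise collection fails with an ImportError
instead of skipping cleanly.

    import pytest

    # Skip the whole module early, before anything that needs the optional package.
    pytest.importorskip("onelogin")

    # Imports that (transitively) pull in the SAML dependency come after the skip;
    # the pyproject.toml change below allows the resulting E402 for this file.
    from airflow.www import app as application  # noqa: E402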
---
 pyproject.toml                                                | 1 +
 .../providers/amazon/aws/tests/test_aws_auth_manager.py       | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 3c62b110608f7..7c466743a935c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -387,6 +387,7 @@ combine-as-imports = true
 "airflow/security/kerberos.py" = ["E402"]
 "airflow/security/utils.py" = ["E402"]
 "tests/providers/amazon/aws/auth_manager/security_manager/test_aws_security_manager_override.py" = ["E402"]
+"tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py" = ["E402"]
 "tests/providers/common/io/xcom/test_backend.py" = ["E402"]
 "tests/providers/elasticsearch/log/elasticmock/__init__.py" = ["E402"]
 "tests/providers/elasticsearch/log/elasticmock/utilities/__init__.py" = ["E402"]
diff --git a/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py b/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py
index 44c0bcecc3b49..792df7b155d04 100644
--- a/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py
+++ b/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py
@@ -22,13 +22,13 @@
 import boto3
 import pytest
 
+pytest.importorskip("onelogin")
+
 from airflow.www import app as application
 from tests.system.providers.amazon.aws.utils import set_env_id
 from tests.test_utils.config import conf_vars
 from tests.test_utils.www import check_content_in_response
 
-pytest.importorskip("onelogin")
-
 SAML_METADATA_URL = "/saml/metadata"
 SAML_METADATA_PARSED = {
     "idp": {

From f28146e31c17a2faf5b3628a94dd0bd3436c98c6 Mon Sep 17 00:00:00 2001
From: GPK 
Date: Mon, 9 Sep 2024 21:27:49 +0100
Subject: [PATCH 006/349] Aws executor docs update (#42092)

---
 docs/apache-airflow-providers-amazon/executors/general.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/apache-airflow-providers-amazon/executors/general.rst b/docs/apache-airflow-providers-amazon/executors/general.rst
index 94d0248008a9f..5e4ba28c6e9c5 100644
--- a/docs/apache-airflow-providers-amazon/executors/general.rst
+++ b/docs/apache-airflow-providers-amazon/executors/general.rst
@@ -326,7 +326,7 @@ Create an ECR Repository
 
 This script should be run on the host(s) running the Airflow Scheduler and Webserver, before those processes are started.
 
-The script sets environment variables that configure Airflow to use the Batch Executor and provide necessary information for task execution. Any other configuration changes made (such as for remote logging) should be added to this example script to keep configuration consistent across the Airflow environment.
+The script sets environment variables that configure Airflow to use the |executorName| Executor and provide necessary information for task execution. Any other configuration changes made (such as for remote logging) should be added to this example script to keep configuration consistent across the Airflow environment.
 
 Initialize the Airflow DB
 ~~~~~~~~~~~~~~~~~~~~~~~~~

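A note on the doc change above: the setup script it describes relies on Airflow's
standard `AIRFLOW__<SECTION>__<KEY>` environment-variable overrides, which map onto
options in airflow.cfg. A hedged sketch of that mechanism (the executor class path
below is an illustrative assumption, not a value taken from the provider docs):

    import os

    # Equivalent to setting `executor = ...` under [core] in airflow.cfg.
    # The class path is illustrative only.
    os.environ["AIRFLOW__CORE__EXECUTOR"] = (
        "airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"
    )

    # Airflow resolves environment overrides when an option is read.
    from airflow.configuration import conf  # noqa: E402

    print(conf.get("core", "executor"))
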
From 97f921ed7b2718d92226767482dbbfb71a172241 Mon Sep 17 00:00:00 2001
From: Daniel Standish <15932138+dstandish@users.noreply.github.com>
Date: Mon, 9 Sep 2024 13:51:02 -0700
Subject: [PATCH 007/349] Log info not error in schedule when task has failure
 (#42116)

You have to think about the context where this is emitted.  It is emitted in the scheduler.  It's not a scheduler error.
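
Put differently, a failed task is an expected, user-visible outcome that the
scheduler merely records, whereas a genuine scheduler malfunction still deserves
error-level logging. A rough illustrative sketch of that distinction (generic
logging code, not the DagRun implementation):

    import logging

    log = logging.getLogger(__name__)

    def record_run_outcome(run_id: str, any_task_failed: bool) -> None:
        # Expected business outcome of scheduling a run: report it, don't alarm.
        if any_task_failed:
            log.info("Marking run %s failed", run_id)

    def handle_scheduler_fault(exc: Exception) -> None:
        # An actual malfunction in the scheduler itself: this is an error.
        log.exception("Unexpected scheduler failure: %s", exc)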
---
 airflow/models/dagrun.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index db7cb9443a1fe..c932958861f7a 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -823,7 +823,7 @@ def recalculate(self) -> _UnfinishedStates:
 
         # if all tasks finished and at least one failed, the run failed
         if not unfinished.tis and any(x.state in State.failed_states for x in tis_for_dagrun_state):
-            self.log.error("Marking run %s failed", self)
+            self.log.info("Marking run %s failed", self)
             self.set_state(DagRunState.FAILED)
             self.notify_dagrun_state_changed(msg="task_failure")
 

From f94bca4fb80ab80b855b49c1bf57121ce963095a Mon Sep 17 00:00:00 2001
From: Brent Bovenzi 
Date: Mon, 9 Sep 2024 23:03:41 -0400
Subject: [PATCH 008/349] Improve DAGs table UI (#42119)

* Rebase and fix filter pagination

* Add pluralize test
---
 .gitignore                                    |   1 +
 airflow/ui/package.json                       |   2 +
 airflow/ui/pnpm-lock.yaml                     | 130 ++++++++++
 airflow/ui/src/app.tsx                        |  28 +--
 .../{ => DataTable}/DataTable.test.tsx        |   0
 .../components/{ => DataTable}/DataTable.tsx  |  11 +-
 airflow/ui/src/components/DataTable/index.tsx |  20 ++
 airflow/ui/src/dagsList.tsx                   | 230 +++++++++++++-----
 airflow/ui/src/main.tsx                       |   3 +-
 airflow/ui/src/theme.ts                       |  39 +++
 airflow/ui/src/utils/pluralize.test.ts        |  85 +++++++
 airflow/ui/src/utils/pluralize.ts             |  30 +++
 12 files changed, 488 insertions(+), 91 deletions(-)
 rename airflow/ui/src/components/{ => DataTable}/DataTable.test.tsx (100%)
 rename airflow/ui/src/components/{ => DataTable}/DataTable.tsx (95%)
 create mode 100644 airflow/ui/src/components/DataTable/index.tsx
 create mode 100644 airflow/ui/src/utils/pluralize.test.ts
 create mode 100644 airflow/ui/src/utils/pluralize.ts

diff --git a/.gitignore b/.gitignore
index 3505a4ed8abfe..40845794e3cb7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -172,6 +172,7 @@ pnpm-debug.log*
 .vscode/*
 !.vscode/extensions.json
 /.vite/
+/.pnpm-store/
 
 # Airflow log files when airflow is run locally
 airflow-*.err
diff --git a/airflow/ui/package.json b/airflow/ui/package.json
index 257df8c3dcca6..d78e2f1c69317 100644
--- a/airflow/ui/package.json
+++ b/airflow/ui/package.json
@@ -15,12 +15,14 @@
     "test": "vitest run"
   },
   "dependencies": {
+    "@chakra-ui/anatomy": "^2.2.2",
     "@chakra-ui/react": "^2.8.2",
     "@emotion/react": "^11.13.3",
     "@emotion/styled": "^11.13.0",
     "@tanstack/react-query": "^5.52.1",
     "@tanstack/react-table": "^8.20.1",
     "axios": "^1.7.4",
+    "chakra-react-select": "^4.9.2",
     "framer-motion": "^11.3.29",
     "react": "^18.3.1",
     "react-dom": "^18.3.1",
diff --git a/airflow/ui/pnpm-lock.yaml b/airflow/ui/pnpm-lock.yaml
index 0ff90475d4acd..effe48f41ed40 100644
--- a/airflow/ui/pnpm-lock.yaml
+++ b/airflow/ui/pnpm-lock.yaml
@@ -25,6 +25,9 @@ importers:
 
   .:
     dependencies:
+      '@chakra-ui/anatomy':
+        specifier: ^2.2.2
+        version: 2.2.2
       '@chakra-ui/react':
         specifier: ^2.8.2
         version: 2.8.2(@emotion/react@11.13.3(@types/react@18.3.4)(react@18.3.1))(@emotion/styled@11.13.0(@emotion/react@11.13.3(@types/react@18.3.4)(react@18.3.1))(@types/react@18.3.4)(react@18.3.1))(@types/react@18.3.4)(framer-motion@11.3.29(@emotion/is-prop-valid@1.3.0)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
@@ -43,6 +46,9 @@ importers:
       axios:
         specifier: ^1.7.4
         version: 1.7.4
+      chakra-react-select:
+        specifier: ^4.9.2
+        version: 4.9.2(ygqhzpuo3vwx3we5k6j4i32nqi)
       framer-motion:
         specifier: ^11.3.29
         version: 11.3.29(@emotion/is-prop-valid@1.3.0)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
@@ -879,6 +885,15 @@ packages:
     resolution: {integrity: sha512-BsWiH1yFGjXXS2yvrf5LyuoSIIbPrGUWob917o+BTKuZ7qJdxX8aJLRxs1fS9n6r7vESrq1OUqb68dANcFXuQQ==}
     engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0}
 
+  '@floating-ui/core@1.6.7':
+    resolution: {integrity: sha512-yDzVT/Lm101nQ5TCVeK65LtdN7Tj4Qpr9RTXJ2vPFLqtLxwOrpoxAHAJI8J3yYWUc40J0BDBheaitK5SJmno2g==}
+
+  '@floating-ui/dom@1.6.10':
+    resolution: {integrity: sha512-fskgCFv8J8OamCmyun8MfjB1Olfn+uZKjOKZ0vhYF3gRmEUXcGOjxWL8bBr7i4kIuPZ2KD2S3EUIOxnjC8kl2A==}
+
+  '@floating-ui/utils@0.2.7':
+    resolution: {integrity: sha512-X8R8Oj771YRl/w+c1HqAC1szL8zWQRwFvgDwT129k9ACdBoud/+/rX9V0qiMl6LWUdP9voC2nDVZYPMQQsb6eA==}
+
   '@hey-api/openapi-ts@0.52.0':
     resolution: {integrity: sha512-DA3Zf5ONxMK1PUkK88lAuYbXMgn5BvU5sjJdTAO2YOn6Eu/9ovilBztMzvu8pyY44PmL3n4ex4+f+XIwvgfhvw==}
     engines: {node: ^18.0.0 || >=20.0.0}
@@ -1171,6 +1186,9 @@ packages:
   '@types/react-dom@18.3.0':
     resolution: {integrity: sha512-EhwApuTmMBmXuFOikhQLIBUn6uFg81SwLMOAUgodJF14SOBOCMdU04gDoYi0WOJJHD144TL32z4yDqCW3dnkQg==}
 
+  '@types/react-transition-group@4.4.11':
+    resolution: {integrity: sha512-RM05tAniPZ5DZPzzNFP+DmrcOdD0efDUxMy3145oljWSl3x9ZV5vhme98gTxFrj2lhXvmGNnUiuDyJgY9IKkNA==}
+
   '@types/react@18.3.4':
     resolution: {integrity: sha512-J7W30FTdfCxDDjmfRM+/JqLHBIyl7xUIp9kwK637FGmY7+mkSFSe6L4jpZzhj5QMfLssSDP4/i75AKkrdC7/Jw==}
 
@@ -1436,6 +1454,20 @@ packages:
     resolution: {integrity: sha512-pT1ZgP8rPNqUgieVaEY+ryQr6Q4HXNg8Ei9UnLUrjN4IA7dvQC5JB+/kxVcPNDHyBcc/26CXPkbNzq3qwrOEKA==}
     engines: {node: '>=12'}
 
+  chakra-react-select@4.9.2:
+    resolution: {integrity: sha512-uhvKAJ1I2lbIwdn+wx0YvxX5rtQVI0gXL0apx0CXm3blIxk7qf6YuCh2TnGuGKst8gj8jUFZyhYZiGlcvgbBRQ==}
+    peerDependencies:
+      '@chakra-ui/form-control': ^2.0.0
+      '@chakra-ui/icon': ^3.0.0
+      '@chakra-ui/layout': ^2.0.0
+      '@chakra-ui/media-query': ^3.0.0
+      '@chakra-ui/menu': ^2.0.0
+      '@chakra-ui/spinner': ^2.0.0
+      '@chakra-ui/system': ^2.0.0
+      '@emotion/react': ^11.8.1
+      react: ^18.0.0
+      react-dom: ^18.0.0
+
   chalk@2.4.2:
     resolution: {integrity: sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ==}
     engines: {node: '>=4'}
@@ -1593,6 +1625,9 @@ packages:
   dom-accessibility-api@0.6.3:
     resolution: {integrity: sha512-7ZgogeTnjuHbo+ct10G9Ffp0mif17idi0IyWNVA/wcwcm7NPOD/WEHVP3n7n3MhXqxoIYm8d6MuZohYWIZ4T3w==}
 
+  dom-helpers@5.2.1:
+    resolution: {integrity: sha512-nRCa7CK3VTrM2NmGkIy4cbK7IZlgBE/PYMn55rrXefr5xXDP0LdtfPnblFDoVdcAfslJ7or6iqAUnx0CCGIWQA==}
+
   dotenv@16.4.5:
     resolution: {integrity: sha512-ZmdL2rui+eB2YwhsWzjInR8LldtZHGDoQ1ugH85ppHKwpUHL7j7rN0Ti9NCnGiQbhaZ11FpR+7ao1dNsmduNUg==}
     engines: {node: '>=12'}
@@ -2174,6 +2209,9 @@ packages:
   magic-string@0.30.11:
     resolution: {integrity: sha512-+Wri9p0QHMy+545hKww7YAu5NyzF8iomPL/RQazugQ9+Ez4Ic3mERMd8ZTX5rfK944j+560ZJi8iAwgak1Ac7A==}
 
+  memoize-one@6.0.0:
+    resolution: {integrity: sha512-rkpe71W0N0c0Xz6QD0eJETuWAJGnJ9afsl1srmwPrI+yBCkge5EycXXbYRyvL29zZVUWQCY7InPRCv3GDXuZNw==}
+
   merge-stream@2.0.0:
     resolution: {integrity: sha512-abv/qOcuPfk3URPfDzmZU1LKmuw8kT+0nIHvKrKgFrwifol/doWcdA4ZqsWQ8ENrFKkd67Mfpo/LovbIUsbt3w==}
 
@@ -2468,6 +2506,12 @@ packages:
       '@types/react':
         optional: true
 
+  react-select@5.8.0:
+    resolution: {integrity: sha512-TfjLDo58XrhP6VG5M/Mi56Us0Yt8X7xD6cDybC7yoRMUNm7BGO7qk8J0TLQOua/prb8vUOtsfnXZwfm30HGsAA==}
+    peerDependencies:
+      react: ^16.8.0 || ^17.0.0 || ^18.0.0
+      react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0
+
   react-style-singleton@2.2.1:
     resolution: {integrity: sha512-ZWj0fHEMyWkHzKYUr2Bs/4zU6XLmq9HsgBURm7g5pAVfyn49DgUiNgY2d4lXRlYSiCif9YBGpQleewkcqddc7g==}
     engines: {node: '>=10'}
@@ -2478,6 +2522,12 @@ packages:
       '@types/react':
         optional: true
 
+  react-transition-group@4.4.5:
+    resolution: {integrity: sha512-pZcd1MCJoiKiBR2NRxeCRg13uCXbydPnmB4EOeRrY7480qNWO8IIgQG6zlDkm6uRMsURXPuKq0GWtiM59a5Q6g==}
+    peerDependencies:
+      react: '>=16.6.0'
+      react-dom: '>=16.6.0'
+
   react@18.3.1:
     resolution: {integrity: sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==}
     engines: {node: '>=0.10.0'}
@@ -2768,6 +2818,15 @@ packages:
       '@types/react':
         optional: true
 
+  use-isomorphic-layout-effect@1.1.2:
+    resolution: {integrity: sha512-49L8yCO3iGT/ZF9QttjwLF/ZD9Iwto5LnH5LmEdk/6cFmXddqi2ulF0edxTwjj+7mqvpVVGQWvbXZdn32wRSHA==}
+    peerDependencies:
+      '@types/react': '*'
+      react: ^16.8.0 || ^17.0.0 || ^18.0.0
+    peerDependenciesMeta:
+      '@types/react':
+        optional: true
+
   use-sidecar@1.1.2:
     resolution: {integrity: sha512-epTbsLuzZ7lPClpz2TyryBfztm7m+28DlEv2ZCQ3MDr5ssiwyOwGH/e5F9CkfWjJ1t4clvI58yF822/GUkjjhw==}
     engines: {node: '>=10'}
@@ -3871,6 +3930,17 @@ snapshots:
 
   '@eslint/object-schema@2.1.4': {}
 
+  '@floating-ui/core@1.6.7':
+    dependencies:
+      '@floating-ui/utils': 0.2.7
+
+  '@floating-ui/dom@1.6.10':
+    dependencies:
+      '@floating-ui/core': 1.6.7
+      '@floating-ui/utils': 0.2.7
+
+  '@floating-ui/utils@0.2.7': {}
+
   '@hey-api/openapi-ts@0.52.0(typescript@5.5.4)':
     dependencies:
       '@apidevtools/json-schema-ref-parser': 11.6.4
@@ -4115,6 +4185,10 @@ snapshots:
     dependencies:
       '@types/react': 18.3.4
 
+  '@types/react-transition-group@4.4.11':
+    dependencies:
+      '@types/react': 18.3.4
+
   '@types/react@18.3.4':
     dependencies:
       '@types/prop-types': 15.7.12
@@ -4465,6 +4539,23 @@ snapshots:
       loupe: 3.1.1
       pathval: 2.0.0
 
+  chakra-react-select@4.9.2(ygqhzpuo3vwx3we5k6j4i32nqi):
+    dependencies:
+      '@chakra-ui/form-control': 2.2.0(@chakra-ui/system@2.6.2(@emotion/react@11.13.3(@types/react@18.3.4)(react@18.3.1))(@emotion/styled@11.13.0(@emotion/react@11.13.3(@types/react@18.3.4)(react@18.3.1))(@types/react@18.3.4)(react@18.3.1))(react@18.3.1))(react@18.3.1)
+      '@chakra-ui/icon': 3.2.0(@chakra-ui/system@2.6.2(@emotion/react@11.13.3(@types/react@18.3.4)(react@18.3.1))(@emotion/styled@11.13.0(@emotion/react@11.13.3(@types/react@18.3.4)(react@18.3.1))(@types/react@18.3.4)(react@18.3.1))(react@18.3.1))(react@18.3.1)
+      '@chakra-ui/layout': 2.3.1(@chakra-ui/system@2.6.2(@emotion/react@11.13.3(@types/react@18.3.4)(react@18.3.1))(@emotion/styled@11.13.0(@emotion/react@11.13.3(@types/react@18.3.4)(react@18.3.1))(@types/react@18.3.4)(react@18.3.1))(react@18.3.1))(react@18.3.1)
+      '@chakra-ui/media-query': 3.3.0(@chakra-ui/system@2.6.2(@emotion/react@11.13.3(@types/react@18.3.4)(react@18.3.1))(@emotion/styled@11.13.0(@emotion/react@11.13.3(@types/react@18.3.4)(react@18.3.1))(@types/react@18.3.4)(react@18.3.1))(react@18.3.1))(react@18.3.1)
+      '@chakra-ui/menu': 2.2.1(@chakra-ui/system@2.6.2(@emotion/react@11.13.3(@types/react@18.3.4)(react@18.3.1))(@emotion/styled@11.13.0(@emotion/react@11.13.3(@types/react@18.3.4)(react@18.3.1))(@types/react@18.3.4)(react@18.3.1))(react@18.3.1))(framer-motion@11.3.29(@emotion/is-prop-valid@1.3.0)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
+      '@chakra-ui/spinner': 2.1.0(@chakra-ui/system@2.6.2(@emotion/react@11.13.3(@types/react@18.3.4)(react@18.3.1))(@emotion/styled@11.13.0(@emotion/react@11.13.3(@types/react@18.3.4)(react@18.3.1))(@types/react@18.3.4)(react@18.3.1))(react@18.3.1))(react@18.3.1)
+      '@chakra-ui/system': 2.6.2(@emotion/react@11.13.3(@types/react@18.3.4)(react@18.3.1))(@emotion/styled@11.13.0(@emotion/react@11.13.3(@types/react@18.3.4)(react@18.3.1))(@types/react@18.3.4)(react@18.3.1))(react@18.3.1)
+      '@emotion/react': 11.13.3(@types/react@18.3.4)(react@18.3.1)
+      react: 18.3.1
+      react-dom: 18.3.1(react@18.3.1)
+      react-select: 5.8.0(@types/react@18.3.4)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
+    transitivePeerDependencies:
+      - '@types/react'
+      - supports-color
+
   chalk@2.4.2:
     dependencies:
       ansi-styles: 3.2.1
@@ -4619,6 +4710,11 @@ snapshots:
 
   dom-accessibility-api@0.6.3: {}
 
+  dom-helpers@5.2.1:
+    dependencies:
+      '@babel/runtime': 7.25.4
+      csstype: 3.1.3
+
   dotenv@16.4.5: {}
 
   eastasianwidth@0.2.0: {}
@@ -5295,6 +5391,8 @@ snapshots:
     dependencies:
       '@jridgewell/sourcemap-codec': 1.5.0
 
+  memoize-one@6.0.0: {}
+
   merge-stream@2.0.0: {}
 
   merge2@1.4.1: {}
@@ -5565,6 +5663,23 @@ snapshots:
     optionalDependencies:
       '@types/react': 18.3.4
 
+  react-select@5.8.0(@types/react@18.3.4)(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
+    dependencies:
+      '@babel/runtime': 7.25.4
+      '@emotion/cache': 11.13.1
+      '@emotion/react': 11.13.3(@types/react@18.3.4)(react@18.3.1)
+      '@floating-ui/dom': 1.6.10
+      '@types/react-transition-group': 4.4.11
+      memoize-one: 6.0.0
+      prop-types: 15.8.1
+      react: 18.3.1
+      react-dom: 18.3.1(react@18.3.1)
+      react-transition-group: 4.4.5(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
+      use-isomorphic-layout-effect: 1.1.2(@types/react@18.3.4)(react@18.3.1)
+    transitivePeerDependencies:
+      - '@types/react'
+      - supports-color
+
   react-style-singleton@2.2.1(@types/react@18.3.4)(react@18.3.1):
     dependencies:
       get-nonce: 1.0.1
@@ -5574,6 +5689,15 @@ snapshots:
     optionalDependencies:
       '@types/react': 18.3.4
 
+  react-transition-group@4.4.5(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
+    dependencies:
+      '@babel/runtime': 7.25.4
+      dom-helpers: 5.2.1
+      loose-envify: 1.4.0
+      prop-types: 15.8.1
+      react: 18.3.1
+      react-dom: 18.3.1(react@18.3.1)
+
   react@18.3.1:
     dependencies:
       loose-envify: 1.4.0
@@ -5912,6 +6036,12 @@ snapshots:
     optionalDependencies:
       '@types/react': 18.3.4
 
+  use-isomorphic-layout-effect@1.1.2(@types/react@18.3.4)(react@18.3.1):
+    dependencies:
+      react: 18.3.1
+    optionalDependencies:
+      '@types/react': 18.3.4
+
   use-sidecar@1.1.2(@types/react@18.3.4)(react@18.3.1):
     dependencies:
       detect-node-es: 1.1.0
diff --git a/airflow/ui/src/app.tsx b/airflow/ui/src/app.tsx
index 2f1bc556793ff..ab2789cefb111 100644
--- a/airflow/ui/src/app.tsx
+++ b/airflow/ui/src/app.tsx
@@ -17,40 +17,16 @@
  * under the License.
  */
 
-import { useState } from "react";
-import { Box, Spinner } from "@chakra-ui/react";
-import { PaginationState } from "@tanstack/react-table";
-
-import { useDagServiceGetDags } from "openapi/queries";
+import { Box } from "@chakra-ui/react";
 import { DagsList } from "src/dagsList";
 import { Nav } from "src/nav";
 
 export const App = () => {
-  // TODO: Change this to be taken from airflow.cfg
-  const pageSize = 50;
-  const [pagination, setPagination] = useState({
-    pageIndex: 0,
-    pageSize: pageSize,
-  });
-
-  const { data, isLoading } = useDagServiceGetDags({
-    limit: pagination.pageSize,
-    offset: pagination.pageIndex * pagination.pageSize,
-  });
-
   return (
     
); diff --git a/airflow/ui/src/components/DataTable.test.tsx b/airflow/ui/src/components/DataTable/DataTable.test.tsx similarity index 100% rename from airflow/ui/src/components/DataTable.test.tsx rename to airflow/ui/src/components/DataTable/DataTable.test.tsx diff --git a/airflow/ui/src/components/DataTable.tsx b/airflow/ui/src/components/DataTable/DataTable.tsx similarity index 95% rename from airflow/ui/src/components/DataTable.tsx rename to airflow/ui/src/components/DataTable/DataTable.tsx index fbdd59a1b90cd..4b4b1251f8428 100644 --- a/airflow/ui/src/components/DataTable.tsx +++ b/airflow/ui/src/components/DataTable/DataTable.tsx @@ -17,8 +17,6 @@ * under the License. */ -"use client"; - import { ColumnDef, Table as TanStackTable, @@ -41,6 +39,7 @@ import { Th, Thead, Tr, + useColorModeValue, } from "@chakra-ui/react"; import React, { Fragment } from "react"; @@ -143,10 +142,12 @@ export function DataTable({ }, }); + const theadBg = useColorModeValue("white", "gray.800"); + return ( - - -
+ + + {table.getHeaderGroups().map((headerGroup) => ( {headerGroup.headers.map((header) => { diff --git a/airflow/ui/src/components/DataTable/index.tsx b/airflow/ui/src/components/DataTable/index.tsx new file mode 100644 index 0000000000000..495cf4e4f389b --- /dev/null +++ b/airflow/ui/src/components/DataTable/index.tsx @@ -0,0 +1,20 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +export * from "./DataTable"; diff --git a/airflow/ui/src/dagsList.tsx b/airflow/ui/src/dagsList.tsx index e8f06545de83d..b6c4e6949b97a 100644 --- a/airflow/ui/src/dagsList.tsx +++ b/airflow/ui/src/dagsList.tsx @@ -17,40 +17,64 @@ * under the License. */ +import { useState } from "react"; +import { ColumnDef, PaginationState } from "@tanstack/react-table"; import { - ColumnDef, - Row, - OnChangeFn, - PaginationState, -} from "@tanstack/react-table"; -import { MdExpandMore } from "react-icons/md"; -import { Box, Code } from "@chakra-ui/react"; + Badge, + Button, + ButtonProps, + Checkbox, + Heading, + HStack, + Input, + InputGroup, + InputGroupProps, + InputLeftElement, + InputProps, + InputRightElement, + Select, + Spinner, + Text, + VStack, +} from "@chakra-ui/react"; +import { Select as ReactSelect } from "chakra-react-select"; +import { FiSearch } from "react-icons/fi"; import { DAG } from "openapi/requests/types.gen"; -import { DataTable } from "src/components/DataTable.tsx"; +import { useDagServiceGetDags } from "openapi/queries"; +import { DataTable } from "./components/DataTable"; +import { pluralize } from "./utils/pluralize"; + +const SearchBar = ({ + groupProps, + inputProps, + buttonProps, +}: { + groupProps?: InputGroupProps; + inputProps?: InputProps; + buttonProps?: ButtonProps; +}) => ( + + + + + + + + + +); const columns: ColumnDef[] = [ - { - id: "expander", - header: () => null, - cell: ({ row }) => { - return row.getCanExpand() ? ( - - ) : null; - }, - }, { accessorKey: "dag_display_name", header: "DAG", @@ -61,42 +85,130 @@ const columns: ColumnDef[] = [ }, { accessorKey: "timetable_description", - header: () => "Timetable", + header: () => "Schedule", + cell: (info) => + info.getValue() !== "Never, external triggers only" + ? info.getValue() + : undefined, + }, + { + accessorKey: "next_dagrun", + header: "Next DAG Run", + }, + { + accessorKey: "owner", + header: () => "Owner", + cell: ({ row }) => ( + + {row.original.owners?.map((owner) => {owner})} + + ), }, { - accessorKey: "description", - header: () => "Description", + accessorKey: "tags", + header: () => "Tags", + cell: ({ row }) => ( + + {row.original.tags?.map((tag) => ( + {tag.name} + ))} + + ), }, ]; -const renderSubComponent = ({ row }: { row: Row }) => { - return ( -
-      {JSON.stringify(row.original, null, 2)}
-    
- ); -}; +const QuickFilterButton = ({ children, ...rest }: ButtonProps) => ( + +); + +export const DagsList = () => { + // TODO: Change this to be taken from airflow.cfg + const pageSize = 50; + const [pagination, setPagination] = useState({ + pageIndex: 0, + pageSize: pageSize, + }); + const [showPaused, setShowPaused] = useState(true); + const [orderBy, setOrderBy] = useState(); + + const { data, isLoading } = useDagServiceGetDags({ + limit: pagination.pageSize, + offset: pagination.pageIndex * pagination.pageSize, + onlyActive: true, + paused: showPaused, + orderBy, + }); -export const DagsList = ({ - data, - total, - pagination, - setPagination, -}: { - data: DAG[]; - total: number | undefined; - pagination: PaginationState; - setPagination: OnChangeFn; -}) => { return ( - true} - renderSubComponent={renderSubComponent} - pagination={pagination} - setPagination={setPagination} - /> + <> + {isLoading && } + {!isLoading && !!data?.dags && ( + <> + + + + + + All + Failed + Running + Successful + + { + setShowPaused(!showPaused); + setPagination({ + ...pagination, + pageIndex: 0, + }); + }} + > + Show Paused DAGs + + + + + + + + + + {pluralize("DAG", data.total_entries)} + + + + + true} + pagination={pagination} + setPagination={setPagination} + /> + + )} + ); }; diff --git a/airflow/ui/src/main.tsx b/airflow/ui/src/main.tsx index fa45680d264d6..be9196defdb18 100644 --- a/airflow/ui/src/main.tsx +++ b/airflow/ui/src/main.tsx @@ -22,6 +22,7 @@ import { ChakraProvider } from "@chakra-ui/react"; import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; import { App } from "src/app.tsx"; import axios, { AxiosResponse } from "axios"; +import theme from "./theme"; const queryClient = new QueryClient({ defaultOptions: { @@ -57,7 +58,7 @@ axios.interceptors.response.use( const root = createRoot(document.getElementById("root")!); root.render( - + diff --git a/airflow/ui/src/theme.ts b/airflow/ui/src/theme.ts index 03247b8cc7556..eee148ad09f4c 100644 --- a/airflow/ui/src/theme.ts +++ b/airflow/ui/src/theme.ts @@ -18,6 +18,44 @@ */ import { extendTheme } from "@chakra-ui/react"; +import { tableAnatomy } from "@chakra-ui/anatomy"; +import { createMultiStyleConfigHelpers } from "@chakra-ui/react"; + +const { definePartsStyle, defineMultiStyleConfig } = + createMultiStyleConfigHelpers(tableAnatomy.keys); + +const baseStyle = definePartsStyle((props) => { + const { colorScheme: c, colorMode } = props; + return { + thead: { + tr: { + th: { + borderBottomWidth: 0, + }, + }, + }, + tbody: { + tr: { + "&:nth-of-type(odd)": { + "th, td": { + borderBottomWidth: "0px", + borderColor: colorMode === "light" ? `${c}.50` : `gray.900`, + }, + td: { + background: colorMode === "light" ? `${c}.50` : `gray.900`, + }, + }, + "&:nth-of-type(even)": { + "th, td": { + borderBottomWidth: "0px", + }, + }, + }, + }, + }; +}); + +export const tableTheme = defineMultiStyleConfig({ baseStyle }); const theme = extendTheme({ config: { @@ -36,6 +74,7 @@ const theme = extendTheme({ fontSize: "md", }, }, + Table: tableTheme, }, }); diff --git a/airflow/ui/src/utils/pluralize.test.ts b/airflow/ui/src/utils/pluralize.test.ts new file mode 100644 index 0000000000000..ead9ff32a044e --- /dev/null +++ b/airflow/ui/src/utils/pluralize.test.ts @@ -0,0 +1,85 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { describe, expect, it } from "vitest"; + +import { pluralize } from "./pluralize"; + +type PluralizeTestCase = { + in: [string, number, (string | null)?, boolean?]; + out: string; +}; + +const pluralizeTestCases: PluralizeTestCase[] = [ + { in: ["DAG", 0, undefined, undefined], out: "0 DAGs" }, + { in: ["DAG", 1, undefined, undefined], out: "1 DAG" }, + { in: ["DAG", 12000, undefined, undefined], out: "12,000 DAGs" }, + { in: ["DAG", 12000000, undefined, undefined], out: "12,000,000 DAGs" }, + { in: ["DAG", 0, undefined, undefined], out: "0 DAGs" }, + { in: ["DAG", 1, undefined, undefined], out: "1 DAG" }, + { in: ["DAG", 12000, undefined, undefined], out: "12,000 DAGs" }, + { in: ["DAG", 12000000, undefined, undefined], out: "12,000,000 DAGs" }, + // Omit the count. + { in: ["DAG", 0, null, true], out: "DAGs" }, + { in: ["DAG", 1, null, true], out: "DAG" }, + { in: ["DAG", 12000, null, true], out: "DAGs" }, + { in: ["DAG", 12000000, null, true], out: "DAGs" }, + { in: ["DAG", 0, null, true], out: "DAGs" }, + { in: ["DAG", 1, null, true], out: "DAG" }, + { in: ["DAG", 12000, null, true], out: "DAGs" }, + { in: ["DAG", 12000000, null, true], out: "DAGs" }, + // The casing of the string is preserved. + { in: ["goose", 0, "geese", undefined], out: "0 geese" }, + { in: ["goose", 1, "geese", undefined], out: "1 goose" }, + // The plural form is different from the singular form. + { in: ["Goose", 0, "Geese", undefined], out: "0 Geese" }, + { in: ["Goose", 1, "Geese", undefined], out: "1 Goose" }, + { in: ["Goose", 12000, "Geese", undefined], out: "12,000 Geese" }, + { in: ["Goose", 12000000, "Geese", undefined], out: "12,000,000 Geese" }, + { in: ["Goose", 0, "Geese", undefined], out: "0 Geese" }, + { in: ["Goose", 1, "Geese", undefined], out: "1 Goose" }, + { in: ["Goose", 12000, "Geese", undefined], out: "12,000 Geese" }, + { in: ["Goose", 12000000, "Geese", undefined], out: "12,000,000 Geese" }, + // In the case of "Moose", the plural is the same as the singular and you + // probably wouldn't elect to use this function at all, but there could be + // cases where dynamic data makes it unavoidable. 
+ { in: ["Moose", 0, "Moose", undefined], out: "0 Moose" }, + { in: ["Moose", 1, "Moose", undefined], out: "1 Moose" }, + { in: ["Moose", 12000, "Moose", undefined], out: "12,000 Moose" }, + { in: ["Moose", 12000000, "Moose", undefined], out: "12,000,000 Moose" }, + { in: ["Moose", 0, "Moose", undefined], out: "0 Moose" }, + { in: ["Moose", 1, "Moose", undefined], out: "1 Moose" }, + { in: ["Moose", 12000, "Moose", undefined], out: "12,000 Moose" }, + { in: ["Moose", 12000000, "Moose", undefined], out: "12,000,000 Moose" }, +]; + +describe("pluralize", () => { + it("case", () => { + pluralizeTestCases.forEach((testCase) => + expect( + pluralize( + testCase.in[0], + testCase.in[1], + testCase.in[2], + testCase.in[3] + ) + ).toEqual(testCase.out) + ); + }); +}); diff --git a/airflow/ui/src/utils/pluralize.ts b/airflow/ui/src/utils/pluralize.ts new file mode 100644 index 0000000000000..0fdddb1c69b90 --- /dev/null +++ b/airflow/ui/src/utils/pluralize.ts @@ -0,0 +1,30 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +export const pluralize = ( + singularLabel: string, + count: number | undefined = 0, + pluralLabel?: string | null, + omitCount?: boolean +): string => { + const pluralized = + count === 1 ? singularLabel : pluralLabel || `${singularLabel}s`; + // toLocaleString() will add commas for thousands, millions, etc. + return `${omitCount ? 
"" : `${count.toLocaleString()} `}${pluralized}`; +}; From c521030cc38ddffb7c825947fa656c36bd4dfab2 Mon Sep 17 00:00:00 2001 From: Brent Bovenzi Date: Tue, 10 Sep 2024 09:56:22 -0400 Subject: [PATCH 009/349] Move Dags list state to url params (#42124) * Rebase and fix filter pagination * Add pluralize test * Move table state to url params --- .gitignore | 2 +- airflow/ui/package.json | 3 +- airflow/ui/pnpm-lock.yaml | 34 ++++++ .../components/DataTable/DataTable.test.tsx | 17 +-- .../ui/src/components/DataTable/DataTable.tsx | 110 +++++++++++++----- airflow/ui/src/components/DataTable/index.tsx | 4 +- .../components/DataTable/searchParams.test.ts | 64 ++++++++++ .../src/components/DataTable/searchParams.ts | 97 +++++++++++++++ airflow/ui/src/components/DataTable/types.ts | 25 ++++ .../components/DataTable/useTableUrlState.ts | 59 ++++++++++ airflow/ui/src/dagsList.tsx | 97 ++++++++------- airflow/ui/src/main.tsx | 5 +- airflow/ui/src/utils/test.tsx | 9 +- 13 files changed, 445 insertions(+), 81 deletions(-) create mode 100644 airflow/ui/src/components/DataTable/searchParams.test.ts create mode 100644 airflow/ui/src/components/DataTable/searchParams.ts create mode 100644 airflow/ui/src/components/DataTable/types.ts create mode 100644 airflow/ui/src/components/DataTable/useTableUrlState.ts diff --git a/.gitignore b/.gitignore index 40845794e3cb7..a37af448782a3 100644 --- a/.gitignore +++ b/.gitignore @@ -172,7 +172,7 @@ pnpm-debug.log* .vscode/* !.vscode/extensions.json /.vite/ -/.pnpm-store/ +.pnpm-store # Airflow log files when airflow is run locally airflow-*.err diff --git a/airflow/ui/package.json b/airflow/ui/package.json index d78e2f1c69317..6a436be5ff270 100644 --- a/airflow/ui/package.json +++ b/airflow/ui/package.json @@ -26,7 +26,8 @@ "framer-motion": "^11.3.29", "react": "^18.3.1", "react-dom": "^18.3.1", - "react-icons": "^5.3.0" + "react-icons": "^5.3.0", + "react-router-dom": "^6.26.2" }, "devDependencies": { "@7nohe/openapi-react-query-codegen": "^1.6.0", diff --git a/airflow/ui/pnpm-lock.yaml b/airflow/ui/pnpm-lock.yaml index effe48f41ed40..3d2614e538914 100644 --- a/airflow/ui/pnpm-lock.yaml +++ b/airflow/ui/pnpm-lock.yaml @@ -61,6 +61,9 @@ importers: react-icons: specifier: ^5.3.0 version: 5.3.0(react@18.3.1) + react-router-dom: + specifier: ^6.26.2 + version: 6.26.2(react-dom@18.3.1(react@18.3.1))(react@18.3.1) devDependencies: '@7nohe/openapi-react-query-codegen': specifier: ^1.6.0 @@ -957,6 +960,10 @@ packages: '@popperjs/core@2.11.8': resolution: {integrity: sha512-P1st0aksCrn9sGZhp8GMYwBnQsbvAWsZAX44oXNNvLHGqAOcoVxmjZiohstwQ7SqKnbR47akdNi+uleWD8+g6A==} + '@remix-run/router@1.19.2': + resolution: {integrity: sha512-baiMx18+IMuD1yyvOGaHM9QrVUPGGG0jC+z+IPHnRJWUAUvaKuWKyE8gjDj2rzv3sz9zOGoRSPgeBVHRhZnBlA==} + engines: {node: '>=14.0.0'} + '@rollup/rollup-android-arm-eabi@4.21.0': resolution: {integrity: sha512-WTWD8PfoSAJ+qL87lE7votj3syLavxunWhzCnx3XFxFiI/BA/r3X7MUM8dVrH8rb2r4AiO8jJsr3ZjdaftmnfA==} cpu: [arm] @@ -2506,6 +2513,19 @@ packages: '@types/react': optional: true + react-router-dom@6.26.2: + resolution: {integrity: sha512-z7YkaEW0Dy35T3/QKPYB1LjMK2R1fxnHO8kWpUMTBdfVzZrWOiY9a7CtN8HqdWtDUWd5FY6Dl8HFsqVwH4uOtQ==} + engines: {node: '>=14.0.0'} + peerDependencies: + react: '>=16.8' + react-dom: '>=16.8' + + react-router@6.26.2: + resolution: {integrity: sha512-tvN1iuT03kHgOFnLPfLJ8V95eijteveqdOSk+srqfePtQvqCExB8eHOYnlilbOcyJyKnYkr1vJvf7YqotAJu1A==} + engines: {node: '>=14.0.0'} + peerDependencies: + react: '>=16.8' + react-select@5.8.0: resolution: {integrity: 
sha512-TfjLDo58XrhP6VG5M/Mi56Us0Yt8X7xD6cDybC7yoRMUNm7BGO7qk8J0TLQOua/prb8vUOtsfnXZwfm30HGsAA==} peerDependencies: @@ -4003,6 +4023,8 @@ snapshots: '@popperjs/core@2.11.8': {} + '@remix-run/router@1.19.2': {} + '@rollup/rollup-android-arm-eabi@4.21.0': optional: true @@ -5663,6 +5685,18 @@ snapshots: optionalDependencies: '@types/react': 18.3.4 + react-router-dom@6.26.2(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + '@remix-run/router': 1.19.2 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + react-router: 6.26.2(react@18.3.1) + + react-router@6.26.2(react@18.3.1): + dependencies: + '@remix-run/router': 1.19.2 + react: 18.3.1 + react-select@5.8.0(@types/react@18.3.4)(react-dom@18.3.1(react@18.3.1))(react@18.3.1): dependencies: '@babel/runtime': 7.25.4 diff --git a/airflow/ui/src/components/DataTable/DataTable.test.tsx b/airflow/ui/src/components/DataTable/DataTable.test.tsx index f98b43fe5f610..3072e6a2edcc0 100644 --- a/airflow/ui/src/components/DataTable/DataTable.test.tsx +++ b/airflow/ui/src/components/DataTable/DataTable.test.tsx @@ -34,7 +34,7 @@ const columns: ColumnDef<{ name: string }>[] = [ const data = [{ name: "John Doe" }, { name: "Jane Doe" }]; const pagination: PaginationState = { pageIndex: 0, pageSize: 1 }; -const setPagination = vi.fn(); +const onStateChange = vi.fn(); describe("DataTable", () => { it("renders table with data", () => { @@ -43,8 +43,8 @@ describe("DataTable", () => { data={data} total={2} columns={columns} - pagination={pagination} - setPagination={setPagination} + initialState={{ pagination, sorting: [] }} + onStateChange={onStateChange} /> ); @@ -58,8 +58,8 @@ describe("DataTable", () => { data={data} total={2} columns={columns} - pagination={pagination} - setPagination={setPagination} + initialState={{ pagination, sorting: [] }} + onStateChange={onStateChange} /> ); @@ -73,8 +73,11 @@ describe("DataTable", () => { data={data} total={2} columns={columns} - pagination={{ pageIndex: 1, pageSize: 10 }} - setPagination={setPagination} + initialState={{ + pagination: { pageIndex: 1, pageSize: 10 }, + sorting: [], + }} + onStateChange={onStateChange} /> ); diff --git a/airflow/ui/src/components/DataTable/DataTable.tsx b/airflow/ui/src/components/DataTable/DataTable.tsx index 4b4b1251f8428..55982f3260331 100644 --- a/airflow/ui/src/components/DataTable/DataTable.tsx +++ b/airflow/ui/src/components/DataTable/DataTable.tsx @@ -25,9 +25,10 @@ import { getExpandedRowModel, getPaginationRowModel, OnChangeFn, - PaginationState, Row, useReactTable, + TableState as ReactTableState, + Updater, } from "@tanstack/react-table"; import { Box, @@ -41,7 +42,13 @@ import { Tr, useColorModeValue, } from "@chakra-ui/react"; -import React, { Fragment } from "react"; +import React, { Fragment, useCallback, useRef } from "react"; +import { + TiArrowSortedDown, + TiArrowSortedUp, + TiArrowUnsorted, +} from "react-icons/ti"; +import type { TableState } from "./types"; type DataTableProps = { data: TData[]; @@ -51,8 +58,8 @@ type DataTableProps = { row: Row; }) => React.ReactElement | null; getRowCanExpand?: (row: Row) => boolean; - pagination: PaginationState; - setPagination: OnChangeFn; + initialState?: TableState; + onStateChange?: (state: TableState) => void; }; type PaginatorProps = { @@ -118,15 +125,36 @@ const TablePaginator = ({ table }: PaginatorProps) => { ); }; -export function DataTable({ +export const DataTable = ({ data, total = 0, columns, renderSubComponent = () => null, getRowCanExpand = () => false, - pagination, - setPagination, -}: 
DataTableProps) { + initialState, + onStateChange, +}: DataTableProps) => { + const ref = useRef<{ tableRef: TanStackTable | undefined }>({ + tableRef: undefined, + }); + const handleStateChange = useCallback>( + (updater: Updater) => { + if (ref.current.tableRef && onStateChange) { + const current = ref.current.tableRef.getState(); + const next = typeof updater === "function" ? updater(current) : updater; + + // Only use the controlled state + const nextState = { + sorting: next.sorting, + pagination: next.pagination, + }; + + onStateChange(nextState); + } + }, + [onStateChange] + ); + const table = useReactTable({ data, columns, @@ -134,14 +162,15 @@ export function DataTable({ getCoreRowModel: getCoreRowModel(), getExpandedRowModel: getExpandedRowModel(), getPaginationRowModel: getPaginationRowModel(), - onPaginationChange: setPagination, + onStateChange: handleStateChange, rowCount: total, manualPagination: true, - state: { - pagination, - }, + manualSorting: true, + state: initialState, }); + ref.current.tableRef = table; + const theadBg = useColorModeValue("white", "gray.800"); return ( @@ -150,20 +179,47 @@ export function DataTable({
{table.getHeaderGroups().map((headerGroup) => ( - {headerGroup.headers.map((header) => { - return ( - - ); - })} + {headerGroup.headers.map( + ({ column, id, colSpan, getContext, isPlaceholder }) => { + const sort = column.getIsSorted(); + const canSort = column.getCanSort(); + return ( + + ); + } + )} ))} @@ -200,4 +256,4 @@ export function DataTable({ ); -} +}; diff --git a/airflow/ui/src/components/DataTable/index.tsx b/airflow/ui/src/components/DataTable/index.tsx index 495cf4e4f389b..1b01f6dfd3695 100644 --- a/airflow/ui/src/components/DataTable/index.tsx +++ b/airflow/ui/src/components/DataTable/index.tsx @@ -17,4 +17,6 @@ * under the License. */ -export * from "./DataTable"; +import { DataTable } from "./DataTable"; + +export { DataTable }; diff --git a/airflow/ui/src/components/DataTable/searchParams.test.ts b/airflow/ui/src/components/DataTable/searchParams.test.ts new file mode 100644 index 0000000000000..5a1d788305171 --- /dev/null +++ b/airflow/ui/src/components/DataTable/searchParams.test.ts @@ -0,0 +1,64 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { describe, expect, it } from "vitest"; +import type { TableState } from "./types"; +import { searchParamsToState, stateToSearchParams } from "./searchParams"; + +describe("searchParams", () => { + describe("stateToSearchParams", () => { + it("can serialize table state to search params", () => { + const state: TableState = { + pagination: { + pageIndex: 1, + pageSize: 20, + }, + sorting: [{ id: "name", desc: false }], + }; + expect(stateToSearchParams(state).toString()).toEqual( + "limit=20&offset=1&sort=name" + ); + }); + }); + describe("searchParamsToState", () => { + it("can parse search params back to table state", () => { + expect( + searchParamsToState( + new URLSearchParams("limit=20&offset=0&sort=name&sort=-age"), + { + pagination: { + pageIndex: 1, + pageSize: 5, + }, + sorting: [], + } + ) + ).toEqual({ + pagination: { + pageIndex: 0, + pageSize: 20, + }, + sorting: [ + { id: "name", desc: false }, + { id: "age", desc: true }, + ], + }); + }); + }); +}); diff --git a/airflow/ui/src/components/DataTable/searchParams.ts b/airflow/ui/src/components/DataTable/searchParams.ts new file mode 100644 index 0000000000000..397c26a8302d1 --- /dev/null +++ b/airflow/ui/src/components/DataTable/searchParams.ts @@ -0,0 +1,97 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import type { SortingState } from "@tanstack/react-table"; + +import type { TableState } from "./types"; + +export const LIMIT_PARAM = "limit"; +export const OFFSET_PARAM = "offset"; +export const SORT_PARAM = "sort"; + +export const stateToSearchParams = ( + state: TableState, + defaultTableState?: TableState +): URLSearchParams => { + const queryParams = new URLSearchParams(window.location.search); + + if (state.pagination.pageSize === defaultTableState?.pagination.pageSize) { + queryParams.delete(LIMIT_PARAM); + } else if (state.pagination) { + queryParams.set(LIMIT_PARAM, `${state.pagination.pageSize}`); + } + + if (state.pagination.pageIndex === defaultTableState?.pagination.pageIndex) { + queryParams.delete(OFFSET_PARAM); + } else if (state.pagination) { + queryParams.set(OFFSET_PARAM, `${state.pagination.pageIndex}`); + } + + if (!state.sorting.length) { + queryParams.delete(SORT_PARAM); + } else { + state.sorting.forEach(({ id, desc }) => { + if ( + defaultTableState?.sorting.find( + (sort) => sort.id === id && sort.desc === desc + ) + ) { + queryParams.delete(SORT_PARAM, `${desc ? "-" : ""}${id}`); + } else { + queryParams.set(SORT_PARAM, `${desc ? "-" : ""}${id}`); + } + }); + } + + return queryParams; +}; + +export const searchParamsToState = ( + searchParams: URLSearchParams, + defaultState: TableState +) => { + let urlState: Partial<TableState> = {}; + const pageIndex = searchParams.get(OFFSET_PARAM); + const pageSize = searchParams.get(LIMIT_PARAM); + + if (pageIndex) { + urlState = { + ...urlState, + pagination: { + pageIndex: parseInt(pageIndex, 10), + pageSize: pageSize + ? parseInt(pageSize, 10) + : defaultState.pagination.pageSize, + }, + }; + } + const sorts = searchParams.getAll(SORT_PARAM); + const sorting: SortingState = sorts.map((sort) => ({ + id: sort.replace("-", ""), + desc: sort.startsWith("-"), + })); + urlState = { + ...urlState, + sorting, + }; + return { + ...defaultState, + ...urlState, + }; +}; diff --git a/airflow/ui/src/components/DataTable/types.ts b/airflow/ui/src/components/DataTable/types.ts new file mode 100644 index 0000000000000..52e9e742c2f85 --- /dev/null +++ b/airflow/ui/src/components/DataTable/types.ts @@ -0,0 +1,25 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +import { PaginationState, SortingState } from "@tanstack/react-table"; + +export interface TableState { + pagination: PaginationState; + sorting: SortingState; +} diff --git a/airflow/ui/src/components/DataTable/useTableUrlState.ts b/airflow/ui/src/components/DataTable/useTableUrlState.ts new file mode 100644 index 0000000000000..aaa48d6362953 --- /dev/null +++ b/airflow/ui/src/components/DataTable/useTableUrlState.ts @@ -0,0 +1,59 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { useCallback, useMemo } from "react"; + +import { useSearchParams } from "react-router-dom"; +import type { TableState } from "./types"; +import { searchParamsToState, stateToSearchParams } from "./searchParams"; + +export const defaultTableState: TableState = { + pagination: { + pageIndex: 0, + pageSize: 50, + }, + sorting: [], +}; + +export const useTableURLState = (defaultState?: Partial) => { + const [searchParams, setSearchParams] = useSearchParams(); + + const handleStateChange = useCallback( + (state: TableState) => { + setSearchParams(stateToSearchParams(state, defaultTableState), { + replace: true, + }); + }, + [setSearchParams] + ); + + const tableURLState = useMemo( + () => + searchParamsToState(searchParams, { + ...defaultTableState, + ...defaultState, + }), + [searchParams, defaultState] + ); + + return { + tableURLState, + setTableURLState: handleStateChange, + }; +}; diff --git a/airflow/ui/src/dagsList.tsx b/airflow/ui/src/dagsList.tsx index b6c4e6949b97a..5591570a391f8 100644 --- a/airflow/ui/src/dagsList.tsx +++ b/airflow/ui/src/dagsList.tsx @@ -17,8 +17,8 @@ * under the License. 
*/ -import { useState } from "react"; -import { ColumnDef, PaginationState } from "@tanstack/react-table"; +import { ColumnDef } from "@tanstack/react-table"; +import { useSearchParams } from "react-router-dom"; import { Badge, Button, @@ -34,7 +34,6 @@ import { InputRightElement, Select, Spinner, - Text, VStack, } from "@chakra-ui/react"; import { Select as ReactSelect } from "chakra-react-select"; @@ -44,6 +43,7 @@ import { DAG } from "openapi/requests/types.gen"; import { useDagServiceGetDags } from "openapi/queries"; import { DataTable } from "./components/DataTable"; import { pluralize } from "./utils/pluralize"; +import { useTableURLState } from "./components/DataTable/useTableUrlState"; const SearchBar = ({ groupProps, @@ -76,12 +76,14 @@ const SearchBar = ({ const columns: ColumnDef[] = [ { - accessorKey: "dag_display_name", + accessorKey: "dag_id", header: "DAG", + cell: ({ row }) => row.original.dag_display_name, }, { accessorKey: "is_paused", header: () => "Is Paused", + enableSorting: false, }, { accessorKey: "timetable_description", @@ -90,19 +92,12 @@ const columns: ColumnDef[] = [ info.getValue() !== "Never, external triggers only" ? info.getValue() : undefined, + enableSorting: false, }, { accessorKey: "next_dagrun", header: "Next DAG Run", - }, - { - accessorKey: "owner", - header: () => "Owner", - cell: ({ row }) => ( - - {row.original.owners?.map((owner) => {owner})} - - ), + enableSorting: false, }, { accessorKey: "tags", @@ -114,6 +109,7 @@ const columns: ColumnDef[] = [ ))} ), + enableSorting: false, }, ]; @@ -129,15 +125,20 @@ const QuickFilterButton = ({ children, ...rest }: ButtonProps) => ( ); +const PAUSED_PARAM = "paused"; + export const DagsList = () => { - // TODO: Change this to be taken from airflow.cfg - const pageSize = 50; - const [pagination, setPagination] = useState({ - pageIndex: 0, - pageSize: pageSize, - }); - const [showPaused, setShowPaused] = useState(true); - const [orderBy, setOrderBy] = useState(); + const cardView = false; + const [searchParams, setSearchParams] = useSearchParams(); + + const showPaused = searchParams.get(PAUSED_PARAM) === "true"; + + const { tableURLState, setTableURLState } = useTableURLState(); + const { sorting, pagination } = tableURLState; + + // TODO: update API to accept multiple orderBy params + const sort = sorting[0]; + const orderBy = sort ? `${sort.desc ? 
"-" : ""}${sort.id}` : undefined; const { data, isLoading } = useDagServiceGetDags({ limit: pagination.pageSize, @@ -168,44 +169,56 @@ export const DagsList = () => { { - setShowPaused(!showPaused); - setPagination({ - ...pagination, - pageIndex: 0, + if (showPaused) searchParams.delete(PAUSED_PARAM); + else searchParams.set(PAUSED_PARAM, "true"); + setSearchParams(searchParams); + setTableURLState({ + sorting, + pagination: { ...pagination, pageIndex: 0 }, }); }} > Show Paused DAGs - - - - + {pluralize("DAG", data.total_entries)} - + {cardView && ( + + )} true} - pagination={pagination} - setPagination={setPagination} + initialState={tableURLState} + onStateChange={setTableURLState} /> )} diff --git a/airflow/ui/src/main.tsx b/airflow/ui/src/main.tsx index be9196defdb18..1dfab4a7ac908 100644 --- a/airflow/ui/src/main.tsx +++ b/airflow/ui/src/main.tsx @@ -22,6 +22,7 @@ import { ChakraProvider } from "@chakra-ui/react"; import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; import { App } from "src/app.tsx"; import axios, { AxiosResponse } from "axios"; +import { BrowserRouter } from "react-router-dom"; import theme from "./theme"; const queryClient = new QueryClient({ @@ -60,7 +61,9 @@ const root = createRoot(document.getElementById("root")!); root.render( - + + + ); diff --git a/airflow/ui/src/utils/test.tsx b/airflow/ui/src/utils/test.tsx index f62acc3ec7fc3..0ed5207f67a48 100644 --- a/airflow/ui/src/utils/test.tsx +++ b/airflow/ui/src/utils/test.tsx @@ -20,6 +20,7 @@ import { PropsWithChildren } from "react"; import { ChakraProvider } from "@chakra-ui/react"; import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import { MemoryRouter } from "react-router-dom"; export const Wrapper = ({ children }: PropsWithChildren) => { const queryClient = new QueryClient({ @@ -32,7 +33,9 @@ export const Wrapper = ({ children }: PropsWithChildren) => { return ( - {children} + + {children} + ); }; @@ -40,3 +43,7 @@ export const Wrapper = ({ children }: PropsWithChildren) => { export const ChakraWrapper = ({ children }: PropsWithChildren) => ( {children} ); + +export const RouterWrapper = ({ children }: PropsWithChildren) => ( + {children} +); From 330f6dd428e1c80233648934e341ac5190c7aea2 Mon Sep 17 00:00:00 2001 From: Kacper Muda Date: Tue, 10 Sep 2024 17:02:12 +0200 Subject: [PATCH 010/349] chore: bump OL provider dependencies versions (#42059) Signed-off-by: Kacper Muda --- airflow/providers/openlineage/provider.yaml | 4 ++-- generated/provider_dependencies.json | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/airflow/providers/openlineage/provider.yaml b/airflow/providers/openlineage/provider.yaml index ee416360a1447..c33b3a128ec1f 100644 --- a/airflow/providers/openlineage/provider.yaml +++ b/airflow/providers/openlineage/provider.yaml @@ -50,8 +50,8 @@ dependencies: - apache-airflow-providers-common-sql>=1.6.0 - apache-airflow-providers-common-compat>=1.2.0 - attrs>=22.2 - - openlineage-integration-common>=1.16.0 - - openlineage-python>=1.16.0 + - openlineage-integration-common>=1.22.0 + - openlineage-python>=1.22.0 integrations: - integration-name: OpenLineage diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 7e158f476af4f..18e98f76cd7e7 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -944,8 +944,8 @@ "apache-airflow-providers-common-sql>=1.6.0", "apache-airflow>=2.8.0", "attrs>=22.2", - "openlineage-integration-common>=1.16.0", - 
"openlineage-python>=1.16.0" + "openlineage-integration-common>=1.22.0", + "openlineage-python>=1.22.0" ], "devel-deps": [], "plugins": [ From 810df312925d3f70c2598c90044be164e49fc458 Mon Sep 17 00:00:00 2001 From: Andrii Yerko Date: Tue, 10 Sep 2024 18:51:35 +0300 Subject: [PATCH 011/349] fix typo in Tableau Connection (#42131) Sever -> Server --- docs/apache-airflow-providers-tableau/connections/tableau.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/apache-airflow-providers-tableau/connections/tableau.rst b/docs/apache-airflow-providers-tableau/connections/tableau.rst index 3a3627dbcf603..f898fe9c2dee0 100644 --- a/docs/apache-airflow-providers-tableau/connections/tableau.rst +++ b/docs/apache-airflow-providers-tableau/connections/tableau.rst @@ -61,7 +61,7 @@ Password (optional) Used with password authentication. Host - Specify the `Sever URL + Specify the `Server URL `_ used for Tableau. Extra (optional) From ca2dbc0ba6c4850a5a701dc3c743b5389d839e3a Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Tue, 10 Sep 2024 09:31:39 -0700 Subject: [PATCH 012/349] Make SAML a required dependency of Amazon provider (#42137) Amazon provider with auth manager requires SAML onelogin import and it starts to be more and more problematic to skip the related tests for compatibility. It seems appropriate to move saml to be a required dependency of Amazon provider in this case. Since saml is only used by Amazon provider, we can also safely remove optional extra for it. --- INSTALL | 4 ++-- airflow/providers/papermill/provider.yaml | 1 + contributing-docs/12_airflow_dependencies_and_extras.rst | 4 ++-- generated/provider_dependencies.json | 1 + hatch_build.py | 4 ---- newsfragments/42137.significant.rst | 1 + pyproject.toml | 5 ++--- .../security_manager/test_aws_security_manager_override.py | 1 - .../amazon/aws/auth_manager/test_aws_auth_manager.py | 1 - tests/providers/amazon/aws/auth_manager/views/test_auth.py | 2 -- .../providers/amazon/aws/tests/test_aws_auth_manager.py | 2 -- 11 files changed, 9 insertions(+), 17 deletions(-) create mode 100644 newsfragments/42137.significant.rst diff --git a/INSTALL b/INSTALL index 9a8b1a84118c0..8d81910f071c7 100644 --- a/INSTALL +++ b/INSTALL @@ -257,8 +257,8 @@ Those extras are available as regular core airflow extras - they install optiona # START CORE EXTRAS HERE aiobotocore, apache-atlas, apache-webhdfs, async, cgroups, cloudpickle, github-enterprise, google- -auth, graphviz, kerberos, ldap, leveldb, otel, pandas, password, rabbitmq, s3fs, saml, sentry, -statsd, uv, virtualenv +auth, graphviz, kerberos, ldap, leveldb, otel, pandas, password, rabbitmq, s3fs, sentry, statsd, uv, +virtualenv # END CORE EXTRAS HERE diff --git a/airflow/providers/papermill/provider.yaml b/airflow/providers/papermill/provider.yaml index afd273a69e1b2..ff738b2ec73d1 100644 --- a/airflow/providers/papermill/provider.yaml +++ b/airflow/providers/papermill/provider.yaml @@ -57,6 +57,7 @@ dependencies: - ipykernel - pandas>=2.1.2,<2.2;python_version>="3.9" - pandas>=1.5.3,<2.2;python_version<"3.9" + - python3-saml>=1.16.0 integrations: - integration-name: Papermill diff --git a/contributing-docs/12_airflow_dependencies_and_extras.rst b/contributing-docs/12_airflow_dependencies_and_extras.rst index 59b5c6b8053d6..70f30fa0b7a7a 100644 --- a/contributing-docs/12_airflow_dependencies_and_extras.rst +++ b/contributing-docs/12_airflow_dependencies_and_extras.rst @@ -165,8 +165,8 @@ Those extras are available as regular core airflow extras - they install optiona .. 
START CORE EXTRAS HERE aiobotocore, apache-atlas, apache-webhdfs, async, cgroups, cloudpickle, github-enterprise, google- -auth, graphviz, kerberos, ldap, leveldb, otel, pandas, password, rabbitmq, s3fs, saml, sentry, -statsd, uv, virtualenv +auth, graphviz, kerberos, ldap, leveldb, otel, pandas, password, rabbitmq, s3fs, sentry, statsd, uv, +virtualenv .. END CORE EXTRAS HERE diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 18e98f76cd7e7..3ea4df282d780 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -1014,6 +1014,7 @@ "pandas>=1.5.3,<2.2;python_version<\"3.9\"", "pandas>=2.1.2,<2.2;python_version>=\"3.9\"", "papermill[all]>=2.6.0", + "python3-saml>=1.16.0", "scrapbook[all]" ], "devel-deps": [], diff --git a/hatch_build.py b/hatch_build.py index 02c761c8d5d7c..6233712ce676e 100644 --- a/hatch_build.py +++ b/hatch_build.py @@ -124,10 +124,6 @@ # which can have a conflict with boto3 as mentioned in aiobotocore extra "s3fs>=2023.10.0", ], - "saml": [ - # This is required for support of SAML which might be used by some providers (e.g. Amazon) - "python3-saml>=1.16.0", - ], "sentry": [ "blinker>=1.1", # Sentry SDK 1.33 is broken when greenlets are installed and fails to import diff --git a/newsfragments/42137.significant.rst b/newsfragments/42137.significant.rst new file mode 100644 index 0000000000000..0e1848933a0ae --- /dev/null +++ b/newsfragments/42137.significant.rst @@ -0,0 +1 @@ +Optional ``[saml]`` extra has been removed from Airflow core - instead Amazon Provider gets saml as required dependency. diff --git a/pyproject.toml b/pyproject.toml index 7c466743a935c..cd51540529bcd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,8 +79,8 @@ dynamic = ["version", "optional-dependencies", "dependencies"] # START CORE EXTRAS HERE # # aiobotocore, apache-atlas, apache-webhdfs, async, cgroups, cloudpickle, github-enterprise, google- -# auth, graphviz, kerberos, ldap, leveldb, otel, pandas, password, rabbitmq, s3fs, saml, sentry, -# statsd, uv, virtualenv +# auth, graphviz, kerberos, ldap, leveldb, otel, pandas, password, rabbitmq, s3fs, sentry, statsd, uv, +# virtualenv # # END CORE EXTRAS HERE # @@ -387,7 +387,6 @@ combine-as-imports = true "airflow/security/kerberos.py" = ["E402"] "airflow/security/utils.py" = ["E402"] "tests/providers/amazon/aws/auth_manager/security_manager/test_aws_security_manager_override.py" = ["E402"] -"tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py" = ["E402"] "tests/providers/common/io/xcom/test_backend.py" = ["E402"] "tests/providers/elasticsearch/log/elasticmock/__init__.py" = ["E402"] "tests/providers/elasticsearch/log/elasticmock/utilities/__init__.py" = ["E402"] diff --git a/tests/providers/amazon/aws/auth_manager/security_manager/test_aws_security_manager_override.py b/tests/providers/amazon/aws/auth_manager/security_manager/test_aws_security_manager_override.py index d386ec015aee2..ebb452fb1afb5 100644 --- a/tests/providers/amazon/aws/auth_manager/security_manager/test_aws_security_manager_override.py +++ b/tests/providers/amazon/aws/auth_manager/security_manager/test_aws_security_manager_override.py @@ -51,7 +51,6 @@ class TestAwsSecurityManagerOverride: "airflow.providers.amazon.aws.auth_manager.views.auth.conf.get_mandatory_value", return_value="test" ) def test_register_views(self, mock_get_mandatory_value, override, appbuilder): - pytest.importorskip("onelogin") from airflow.providers.amazon.aws.auth_manager.views.auth import 
AwsAuthManagerAuthenticationViews with patch.object(AwsAuthManagerAuthenticationViews, "idp_data"): diff --git a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py index 0ebad2c0fc601..f54a2a3e5fb1f 100644 --- a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py +++ b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py @@ -790,7 +790,6 @@ def test_get_cli_commands_return_cli_commands(self, auth_manager): "airflow.providers.amazon.aws.auth_manager.views.auth.conf.get_mandatory_value", return_value="test" ) def test_register_views(self, mock_get_mandatory_value, auth_manager_with_appbuilder): - pytest.importorskip("onelogin") from airflow.providers.amazon.aws.auth_manager.views.auth import AwsAuthManagerAuthenticationViews with patch.object(AwsAuthManagerAuthenticationViews, "idp_data"): diff --git a/tests/providers/amazon/aws/auth_manager/views/test_auth.py b/tests/providers/amazon/aws/auth_manager/views/test_auth.py index 4fad8e42578c6..435dd8d2c32fe 100644 --- a/tests/providers/amazon/aws/auth_manager/views/test_auth.py +++ b/tests/providers/amazon/aws/auth_manager/views/test_auth.py @@ -26,8 +26,6 @@ from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS from tests.test_utils.config import conf_vars -pytest.importorskip("onelogin") - pytestmark = [ pytest.mark.skipif(not AIRFLOW_V_2_9_PLUS, reason="Test requires Airflow 2.9+"), pytest.mark.skip_if_database_isolation_mode, diff --git a/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py b/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py index 792df7b155d04..dac7398a1ba8f 100644 --- a/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py +++ b/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py @@ -22,8 +22,6 @@ import boto3 import pytest -pytest.importorskip("onelogin") - from airflow.www import app as application from tests.system.providers.amazon.aws.utils import set_env_id from tests.test_utils.config import conf_vars From 1d2a267c4d65b23aabbacee3fdbcb14463df353b Mon Sep 17 00:00:00 2001 From: Pierre Jeambrun Date: Wed, 11 Sep 2024 01:12:40 +0800 Subject: [PATCH 013/349] Restructure FastAPI API (#42128) * Restructure FastAPI API * Fix CI --- .pre-commit-config.yaml | 10 +-- airflow/{api_ui => api_fastapi}/__init__.py | 0 airflow/{api_ui => api_fastapi}/app.py | 22 +++--- .../gunicorn_config.py | 2 +- airflow/{api_ui => api_fastapi}/main.py | 2 +- .../openapi/v1-generated.yaml | 7 +- .../{api_ui => api_fastapi}/views/__init__.py | 0 airflow/api_fastapi/views/public/__init__.py | 22 ++++++ airflow/api_fastapi/views/ui/__init__.py | 25 +++++++ .../views/ui}/datasets.py | 2 +- airflow/cli/cli_config.py | 40 +++++------ ..._api_command.py => fastapi_api_command.py} | 24 +++---- airflow/cli/commands/standalone_command.py | 8 +-- .../03_contributors_quick_start.rst | 4 +- contributing-docs/08_static_code_checks.rst | 2 +- .../14_node_environment_setup.rst | 4 +- contributing-docs/testing/unit_tests.rst | 2 +- dev/breeze/doc/03_developer_tasks.rst | 6 +- .../src/airflow_breeze/global_constants.py | 2 +- .../src/airflow_breeze/params/shell_params.py | 4 +- .../src/airflow_breeze/utils/run_tests.py | 2 +- .../airflow_breeze/utils/selective_checks.py | 4 +- .../src/airflow_breeze/utils/visuals.py | 6 +- .../tests/test_pytest_args_for_test_types.py | 6 +- dev/breeze/tests/test_selective_checks.py | 2 +- docs/spelling_wordlist.txt | 1 + scripts/ci/docker-compose/base-ports.yml | 2 +- 
.../check_tests_in_right_folders.py | 2 +- ...api_spec.py => update_fastapi_api_spec.py} | 2 +- scripts/in_container/bin/run_tmux | 4 +- ...spec.py => run_update_fastapi_api_spec.py} | 12 +++- tests/{api_ui => api_fastapi}/__init__.py | 0 tests/{api_ui => api_fastapi}/conftest.py | 2 +- .../{api_ui => api_fastapi}/views/__init__.py | 0 tests/api_fastapi/views/ui/__init__.py | 16 +++++ .../views/ui}/test_datasets.py | 0 ...command.py => test_fastapi_api_command.py} | 68 ++++++++++--------- 37 files changed, 195 insertions(+), 122 deletions(-) rename airflow/{api_ui => api_fastapi}/__init__.py (100%) rename airflow/{api_ui => api_fastapi}/app.py (73%) rename airflow/{api_ui => api_fastapi}/gunicorn_config.py (91%) rename airflow/{api_ui => api_fastapi}/main.py (94%) rename airflow/{api_ui => api_fastapi}/openapi/v1-generated.yaml (81%) rename airflow/{api_ui => api_fastapi}/views/__init__.py (100%) create mode 100644 airflow/api_fastapi/views/public/__init__.py create mode 100644 airflow/api_fastapi/views/ui/__init__.py rename airflow/{api_ui/views => api_fastapi/views/ui}/datasets.py (97%) rename airflow/cli/commands/{ui_api_command.py => fastapi_api_command.py} (91%) rename scripts/ci/pre_commit/{update_ui_api_spec.py => update_fastapi_api_spec.py} (94%) rename scripts/in_container/{run_update_ui_api_spec.py => run_update_fastapi_api_spec.py} (79%) rename tests/{api_ui => api_fastapi}/__init__.py (100%) rename tests/{api_ui => api_fastapi}/conftest.py (95%) rename tests/{api_ui => api_fastapi}/views/__init__.py (100%) create mode 100644 tests/api_fastapi/views/ui/__init__.py rename tests/{api_ui/views => api_fastapi/views/ui}/test_datasets.py (100%) rename tests/cli/commands/{test_ui_api_command.py => test_fastapi_api_command.py} (68%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 39fd712e3300f..6d6f77b34302b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -135,7 +135,7 @@ repos: - --fuzzy-match-generates-todo - id: insert-license name: Add license for all YAML files except Helm templates - exclude: ^\.github/.*$|^.*/.*_vendor/|^chart/templates/.*|.*/reproducible_build.yaml$|^airflow/api_ui/openapi/v1-generated.yaml$ + exclude: ^\.github/.*$|^.*/.*_vendor/|^chart/templates/.*|.*/reproducible_build.yaml$|^airflow/api_fastapi/openapi/v1-generated.yaml$ types: [yaml] files: \.ya?ml$ args: @@ -589,7 +589,7 @@ repos: ^airflow/api_connexion/openapi/v1.yaml$| ^airflow/ui/openapi-gen/| ^airflow/cli/commands/internal_api_command.py$| - ^airflow/cli/commands/ui_api_command.py$| + ^airflow/cli/commands/fastapi_api_command.py$| ^airflow/cli/commands/webserver_command.py$| ^airflow/config_templates/| ^airflow/models/baseoperator.py$| @@ -1330,11 +1330,11 @@ repos: files: ^airflow/migrations/versions/.*\.py$|^docs/apache-airflow/migrations-ref\.rst$ additional_dependencies: ['rich>=12.4.4'] - id: generate-openapi-spec - name: Generate the UI API OPENAPI spec + name: Generate the FastAPI API spec language: python - entry: ./scripts/ci/pre_commit/update_ui_api_spec.py + entry: ./scripts/ci/pre_commit/update_fastapi_api_spec.py pass_filenames: false - files: ^airflow/api_ui/.*\.py$ + files: ^airflow/api_fastapi/.*\.py$ additional_dependencies: ['rich>=12.4.4'] - id: update-er-diagram name: Update ER diagram diff --git a/airflow/api_ui/__init__.py b/airflow/api_fastapi/__init__.py similarity index 100% rename from airflow/api_ui/__init__.py rename to airflow/api_fastapi/__init__.py diff --git a/airflow/api_ui/app.py b/airflow/api_fastapi/app.py similarity index 
73% rename from airflow/api_ui/app.py rename to airflow/api_fastapi/app.py index b279f6c90d25e..f54c844461ad3 100644 --- a/airflow/api_ui/app.py +++ b/airflow/api_fastapi/app.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from fastapi import APIRouter, FastAPI +from fastapi import FastAPI from airflow.www.extensions.init_dagbag import get_dag_bag @@ -34,9 +34,9 @@ def init_dag_bag(app: FastAPI) -> None: def create_app() -> FastAPI: app = FastAPI( - description="Internal Rest API for the UI frontend. It is subject to breaking change " - "depending on the need of the frontend. Users should not rely on this API but use the " - "public API instead." + description="Airflow API. All endpoints located under ``/public`` can be used safely, are stable and backward compatible. " + "Endpoints located under ``/ui`` are dedicated to the UI and are subject to breaking change " + "depending on the need of the frontend. Users should not rely on those but use the public ones instead." ) init_dag_bag(app) @@ -48,16 +48,14 @@ def create_app() -> FastAPI: def init_views(app) -> None: """Init views by registering the different routers.""" - from airflow.api_ui.views.datasets import dataset_router + from airflow.api_fastapi.views.public import public_router + from airflow.api_fastapi.views.ui import ui_router - root_router = APIRouter(prefix="/ui") + app.include_router(ui_router) + app.include_router(public_router) - root_router.include_router(dataset_router) - app.include_router(root_router) - - -def cached_app(config=None, testing=False): +def cached_app(config=None, testing=False) -> FastAPI: """Return cached instance of Airflow UI app.""" global app if not app: @@ -65,7 +63,7 @@ def cached_app(config=None, testing=False): return app -def purge_cached_app(): +def purge_cached_app() -> None: """Remove the cached version of the app in global state.""" global app app = None diff --git a/airflow/api_ui/gunicorn_config.py b/airflow/api_fastapi/gunicorn_config.py similarity index 91% rename from airflow/api_ui/gunicorn_config.py rename to airflow/api_fastapi/gunicorn_config.py index 3ee3ba5e7cc2b..70072ad6bdc23 100644 --- a/airflow/api_ui/gunicorn_config.py +++ b/airflow/api_fastapi/gunicorn_config.py @@ -27,7 +27,7 @@ def post_worker_init(_): """ Set process title. - This is used by airflow.cli.commands.ui_api_command to track the status of the worker. + This is used by airflow.cli.commands.fastapi_api_command to track the status of the worker. """ old_title = setproctitle.getproctitle() setproctitle.setproctitle(settings.GUNICORN_WORKER_READY_PREFIX + old_title) diff --git a/airflow/api_ui/main.py b/airflow/api_fastapi/main.py similarity index 94% rename from airflow/api_ui/main.py rename to airflow/api_fastapi/main.py index 175db6b9271d1..9a204306efd4e 100644 --- a/airflow/api_ui/main.py +++ b/airflow/api_fastapi/main.py @@ -17,6 +17,6 @@ from __future__ import annotations -from airflow.api_ui.app import cached_app +from airflow.api_fastapi.app import cached_app app = cached_app() diff --git a/airflow/api_ui/openapi/v1-generated.yaml b/airflow/api_fastapi/openapi/v1-generated.yaml similarity index 81% rename from airflow/api_ui/openapi/v1-generated.yaml rename to airflow/api_fastapi/openapi/v1-generated.yaml index fef897a0718f9..dcd67b84df845 100644 --- a/airflow/api_ui/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/openapi/v1-generated.yaml @@ -1,9 +1,10 @@ openapi: 3.1.0 info: title: FastAPI - description: Internal Rest API for the UI frontend. 
It is subject to breaking change - depending on the need of the frontend. Users should not rely on this API but use - the public API instead. + description: Airflow API. All endpoints located under ``/public`` can be used safely, + are stable and backward compatible. Endpoints located under ``/ui`` are dedicated + to the UI and are subject to breaking change depending on the need of the frontend. + Users should not rely on those but use the public ones instead. version: 0.1.0 paths: /ui/next_run_datasets/{dag_id}: diff --git a/airflow/api_ui/views/__init__.py b/airflow/api_fastapi/views/__init__.py similarity index 100% rename from airflow/api_ui/views/__init__.py rename to airflow/api_fastapi/views/__init__.py diff --git a/airflow/api_fastapi/views/public/__init__.py b/airflow/api_fastapi/views/public/__init__.py new file mode 100644 index 0000000000000..230b4a26c37cb --- /dev/null +++ b/airflow/api_fastapi/views/public/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from fastapi import APIRouter + +public_router = APIRouter(prefix="/public") diff --git a/airflow/api_fastapi/views/ui/__init__.py b/airflow/api_fastapi/views/ui/__init__.py new file mode 100644 index 0000000000000..aa539e2845fad --- /dev/null +++ b/airflow/api_fastapi/views/ui/__init__.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from fastapi import APIRouter + +from airflow.api_fastapi.views.ui.datasets import dataset_router + +ui_router = APIRouter(prefix="/ui") + +ui_router.include_router(dataset_router) diff --git a/airflow/api_ui/views/datasets.py b/airflow/api_fastapi/views/ui/datasets.py similarity index 97% rename from airflow/api_ui/views/datasets.py rename to airflow/api_fastapi/views/ui/datasets.py index 2ab983082fdc0..d6de8ebca0e02 100644 --- a/airflow/api_ui/views/datasets.py +++ b/airflow/api_fastapi/views/ui/datasets.py @@ -29,7 +29,7 @@ # Ultimately we want async routes, with async sqlalchemy session / context manager. # Additional effort to make airflow utility code async, not handled for now and most likely part of the AIP-70 -@dataset_router.get("/next_run_datasets/{dag_id}") +@dataset_router.get("/next_run_datasets/{dag_id}", include_in_schema=False) async def next_run_datasets(dag_id: str, request: Request) -> dict: dag = request.app.state.dag_bag.get_dag(dag_id) diff --git a/airflow/cli/cli_config.py b/airflow/cli/cli_config.py index c882804778cf6..04f78f3dc2a2f 100644 --- a/airflow/cli/cli_config.py +++ b/airflow/cli/cli_config.py @@ -780,39 +780,39 @@ def string_lower_type(val): ) -# ui-api -ARG_UI_API_PORT = Arg( +# fastapi-api +ARG_FASTAPI_API_PORT = Arg( ("-p", "--port"), default=9091, type=int, help="The port on which to run the server", ) -ARG_UI_API_WORKERS = Arg( +ARG_FASTAPI_API_WORKERS = Arg( ("-w", "--workers"), default=4, type=int, - help="Number of workers to run the UI API-on", + help="Number of workers to run the FastAPI API-on", ) -ARG_UI_API_WORKER_TIMEOUT = Arg( +ARG_FASTAPI_API_WORKER_TIMEOUT = Arg( ("-t", "--worker-timeout"), default=120, type=int, - help="The timeout for waiting on UI API workers", + help="The timeout for waiting on FastAPI API workers", ) -ARG_UI_API_HOSTNAME = Arg( +ARG_FASTAPI_API_HOSTNAME = Arg( ("-H", "--hostname"), default="0.0.0.0", # nosec help="Set the hostname on which to run the web server", ) -ARG_UI_API_ACCESS_LOGFILE = Arg( +ARG_FASTAPI_API_ACCESS_LOGFILE = Arg( ("-A", "--access-logfile"), help="The logfile to store the access log. Use '-' to print to stdout", ) -ARG_UI_API_ERROR_LOGFILE = Arg( +ARG_FASTAPI_API_ERROR_LOGFILE = Arg( ("-E", "--error-logfile"), help="The logfile to store the error log. 
Use '-' to print to stderr", ) -ARG_UI_API_ACCESS_LOGFORMAT = Arg( +ARG_FASTAPI_API_ACCESS_LOGFORMAT = Arg( ("-L", "--access-logformat"), help="The access log format for gunicorn logs", ) @@ -1981,21 +1981,21 @@ class GroupCommand(NamedTuple): ), ), ActionCommand( - name="ui-api", - help="Start an Airflow UI API instance", - func=lazy_load_command("airflow.cli.commands.ui_api_command.ui_api"), + name="fastapi-api", + help="Start an Airflow FastAPI API instance", + func=lazy_load_command("airflow.cli.commands.fastapi_api_command.fastapi_api"), args=( - ARG_UI_API_PORT, - ARG_UI_API_WORKERS, - ARG_UI_API_WORKER_TIMEOUT, - ARG_UI_API_HOSTNAME, + ARG_FASTAPI_API_PORT, + ARG_FASTAPI_API_WORKERS, + ARG_FASTAPI_API_WORKER_TIMEOUT, + ARG_FASTAPI_API_HOSTNAME, ARG_PID, ARG_DAEMON, ARG_STDOUT, ARG_STDERR, - ARG_UI_API_ACCESS_LOGFILE, - ARG_UI_API_ERROR_LOGFILE, - ARG_UI_API_ACCESS_LOGFORMAT, + ARG_FASTAPI_API_ACCESS_LOGFILE, + ARG_FASTAPI_API_ERROR_LOGFILE, + ARG_FASTAPI_API_ACCESS_LOGFORMAT, ARG_LOG_FILE, ARG_SSL_CERT, ARG_SSL_KEY, diff --git a/airflow/cli/commands/ui_api_command.py b/airflow/cli/commands/fastapi_api_command.py similarity index 91% rename from airflow/cli/commands/ui_api_command.py rename to airflow/cli/commands/fastapi_api_command.py index cacc1fd7d487c..d50d454347a73 100644 --- a/airflow/cli/commands/ui_api_command.py +++ b/airflow/cli/commands/fastapi_api_command.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""UI API command.""" +"""FastAPI API command.""" from __future__ import annotations @@ -52,8 +52,8 @@ @cli_utils.action_cli @providers_configuration_loaded -def ui_api(args): - """Start Airflow UI API.""" +def fastapi_api(args): + """Start Airflow FastAPI API.""" print(settings.HEADER) access_logfile = args.access_logfile or "-" @@ -62,18 +62,18 @@ def ui_api(args): num_workers = args.workers worker_timeout = args.worker_timeout - worker_class = "airflow.cli.commands.ui_api_command.AirflowUvicornWorker" + worker_class = "airflow.cli.commands.fastapi_api_command.AirflowUvicornWorker" - from airflow.api_ui.app import create_app + from airflow.api_fastapi.app import create_app if args.debug: - print(f"Starting the UI API server on port {args.port} and host {args.hostname} debug.") + print(f"Starting the FastAPI API server on port {args.port} and host {args.hostname} debug.") log.warning("Running in dev mode, ignoring gunicorn args") run_args = [ "fastapi", "dev", - "airflow/api_ui/main.py", + "airflow/api_fastapi/main.py", "--port", str(args.port), "--host", @@ -99,7 +99,7 @@ def ui_api(args): ) ) - pid_file, _, _, _ = setup_locations("ui-api", pid=args.pid) + pid_file, _, _, _ = setup_locations("fastapi-api", pid=args.pid) run_args = [ sys.executable, "-m", @@ -113,7 +113,7 @@ def ui_api(args): "--bind", args.hostname + ":" + str(args.port), "--name", - "airflow-ui-api", + "airflow-fastapi-api", "--pid", pid_file, "--access-logfile", @@ -121,7 +121,7 @@ def ui_api(args): "--error-logfile", str(error_logfile), "--config", - "python:airflow.api_ui.gunicorn_config", + "python:airflow.api_fastapi.gunicorn_config", ] if args.access_logformat and args.access_logformat.strip(): @@ -130,7 +130,7 @@ def ui_api(args): if args.daemon: run_args += ["--daemon"] - run_args += ["airflow.api_ui.app:cached_app()"] + run_args += ["airflow.api_fastapi.app:cached_app()"] # To prevent different workers creating the web app and # all writing to the database at the same time, we use the 
--preload option. @@ -194,7 +194,7 @@ def start_and_monitor_gunicorn(args): monitor_pid_file = str(pid_file_path.with_name(f"{pid_file_path.stem}-monitor{pid_file_path.suffix}")) run_command_with_daemon_option( args=args, - process_name="ui-api", + process_name="fastapi-api", callback=lambda: start_and_monitor_gunicorn(args), should_setup_logging=True, pid_file=monitor_pid_file, diff --git a/airflow/cli/commands/standalone_command.py b/airflow/cli/commands/standalone_command.py index 1f2bf5c9e9b2c..0f8d45eb5f1b4 100644 --- a/airflow/cli/commands/standalone_command.py +++ b/airflow/cli/commands/standalone_command.py @@ -80,10 +80,10 @@ def run(self): command=["webserver"], env=env, ) - self.subcommands["ui-api"] = SubCommand( + self.subcommands["fastapi-api"] = SubCommand( self, - name="ui-api", - command=["ui-api"], + name="fastapi-api", + command=["fastapi-api"], env=env, ) self.subcommands["triggerer"] = SubCommand( @@ -142,7 +142,7 @@ def print_output(self, name: str, output): You can pass multiple lines to output if you wish; it will be split for you. """ color = { - "ui-api": "magenta", + "fastapi-api": "magenta", "webserver": "green", "scheduler": "blue", "triggerer": "cyan", diff --git a/contributing-docs/03_contributors_quick_start.rst b/contributing-docs/03_contributors_quick_start.rst index b7467d6c4a16b..8f7ead6deacc4 100644 --- a/contributing-docs/03_contributors_quick_start.rst +++ b/contributing-docs/03_contributors_quick_start.rst @@ -335,7 +335,7 @@ Using Breeze Ports are forwarded to the running docker containers for webserver and database * 12322 -> forwarded to Airflow ssh server -> airflow:22 * 28080 -> forwarded to Airflow webserver -> airflow:8080 - * 29091 -> forwarded to Airflow UI API -> airflow:9091 + * 29091 -> forwarded to Airflow FastAPI API -> airflow:9091 * 25555 -> forwarded to Flower dashboard -> airflow:5555 * 25433 -> forwarded to Postgres database -> postgres:5432 * 23306 -> forwarded to MySQL database -> mysql:3306 @@ -344,7 +344,7 @@ Using Breeze Here are links to those services that you can use on host: * ssh connection for remote debugging: ssh -p 12322 airflow@127.0.0.1 (password: airflow) * Webserver: http://127.0.0.1:28080 - * UI API: http://127.0.0.1:29091 + * FastAPI API: http://127.0.0.1:29091 * Flower: http://127.0.0.1:25555 * Postgres: jdbc:postgresql://127.0.0.1:25433/airflow?user=postgres&password=airflow * Mysql: jdbc:mysql://127.0.0.1:23306/airflow?user=root diff --git a/contributing-docs/08_static_code_checks.rst b/contributing-docs/08_static_code_checks.rst index 0470b0a57dc66..0f2acf890bb52 100644 --- a/contributing-docs/08_static_code_checks.rst +++ b/contributing-docs/08_static_code_checks.rst @@ -272,7 +272,7 @@ require Breeze Docker image to be built locally. 
+-----------------------------------------------------------+--------------------------------------------------------------+---------+ | generate-airflow-diagrams | Generate airflow diagrams | | +-----------------------------------------------------------+--------------------------------------------------------------+---------+ -| generate-openapi-spec | Generate the UI API OPENAPI spec | * | +| generate-openapi-spec | Generate the FastAPI API spec | * | +-----------------------------------------------------------+--------------------------------------------------------------+---------+ | generate-pypi-readme | Generate PyPI README | | +-----------------------------------------------------------+--------------------------------------------------------------+---------+ diff --git a/contributing-docs/14_node_environment_setup.rst b/contributing-docs/14_node_environment_setup.rst index 77ba3e9542d8d..99da685aeb690 100644 --- a/contributing-docs/14_node_environment_setup.rst +++ b/contributing-docs/14_node_environment_setup.rst @@ -29,9 +29,9 @@ But we want to limit modifications to the legacy ``airflow/www`` views, to mainl 2. The minimum necessary to unblock other Airflow 3.0 feature work 3. Fixes to react views which we haven't migrated over yet, but can still be ported over to the new UI -Custom endpoints for the UI will also be moved away from ``airflow/www/views.py`` and to ``airflow/api_ui``. +Custom endpoints for the UI will also be moved away from ``airflow/www/views.py`` and to ``airflow/api_fastapi``. Contributions to the legacy views file will follow the same rules. -Committers will exercise their judgement on what endpoints should exist in the public ``airflow/api_connexion`` versus the private ``airflow/api_ui`` +Committers will exercise their judgement on what endpoints should exist in the public ``airflow/api_connexion`` versus the private ``airflow/api_fastapi`` Airflow UI ---------- diff --git a/contributing-docs/testing/unit_tests.rst b/contributing-docs/testing/unit_tests.rst index 5c1c2f1584da8..935a7b9b602b4 100644 --- a/contributing-docs/testing/unit_tests.rst +++ b/contributing-docs/testing/unit_tests.rst @@ -96,7 +96,7 @@ test types you want to use in various ``breeze testing`` sub-commands in three w Those test types are defined: * ``Always`` - those are tests that should be always executed (always sub-folder) -* ``API`` - Tests for the Airflow API (api, api_connexion, api_internal, api_ui sub-folders) +* ``API`` - Tests for the Airflow API (api, api_connexion, api_internal, api_fastapi sub-folders) * ``CLI`` - Tests for the Airflow CLI (cli folder) * ``Core`` - for the core Airflow functionality (core, executors, jobs, models, ti_deps, utils sub-folders) * ``Operators`` - tests for the operators (operators folder with exception of Virtualenv Operator tests and diff --git a/dev/breeze/doc/03_developer_tasks.rst b/dev/breeze/doc/03_developer_tasks.rst index 4acfdb4627849..76f43606837e8 100644 --- a/dev/breeze/doc/03_developer_tasks.rst +++ b/dev/breeze/doc/03_developer_tasks.rst @@ -113,7 +113,7 @@ When you run Airflow Breeze, the following ports are automatically forwarded: * 12322 -> forwarded to Airflow ssh server -> airflow:22 * 28080 -> forwarded to Airflow webserver -> airflow:8080 - * 29091 -> forwarded to Airflow UI API -> airflow:9091 + * 29091 -> forwarded to Airflow FastAPI API -> airflow:9091 * 25555 -> forwarded to Flower dashboard -> airflow:5555 * 25433 -> forwarded to Postgres database -> postgres:5432 * 23306 -> forwarded to MySQL database -> 
mysql:3306 @@ -126,7 +126,7 @@ You can connect to these ports/databases using: * ssh connection for remote debugging: ssh -p 12322 airflow@127.0.0.1 pw: airflow * Webserver: http://127.0.0.1:28080 - * UI API: http://127.0.0.1:29091 + * FastAPI API: http://127.0.0.1:29091 * Flower: http://127.0.0.1:25555 * Postgres: jdbc:postgresql://127.0.0.1:25433/airflow?user=postgres&password=airflow * Mysql: jdbc:mysql://127.0.0.1:23306/airflow?user=root @@ -156,7 +156,7 @@ You can change the used host port numbers by setting appropriate environment var * ``SSH_PORT`` * ``WEBSERVER_HOST_PORT`` -* ``UI_API_HOST_PORT`` +* ``FASTAPI_API_HOST_PORT`` * ``POSTGRES_HOST_PORT`` * ``MYSQL_HOST_PORT`` * ``MSSQL_HOST_PORT`` diff --git a/dev/breeze/src/airflow_breeze/global_constants.py b/dev/breeze/src/airflow_breeze/global_constants.py index 0d6eb5f6f99f0..4ee452af6d339 100644 --- a/dev/breeze/src/airflow_breeze/global_constants.py +++ b/dev/breeze/src/airflow_breeze/global_constants.py @@ -253,7 +253,7 @@ def get_default_platform_machine() -> str: SSH_PORT = "12322" WEBSERVER_HOST_PORT = "28080" VITE_DEV_PORT = "5173" -UI_API_HOST_PORT = "29091" +FASTAPI_API_HOST_PORT = "29091" CELERY_BROKER_URLS_MAP = {"rabbitmq": "amqp://guest:guest@rabbitmq:5672", "redis": "redis://redis:6379/0"} SQLITE_URL = "sqlite:////root/airflow/sqlite/airflow.db" diff --git a/dev/breeze/src/airflow_breeze/params/shell_params.py b/dev/breeze/src/airflow_breeze/params/shell_params.py index 2f53c73fc76f7..af74be27c919b 100644 --- a/dev/breeze/src/airflow_breeze/params/shell_params.py +++ b/dev/breeze/src/airflow_breeze/params/shell_params.py @@ -39,6 +39,7 @@ DEFAULT_UV_HTTP_TIMEOUT, DOCKER_DEFAULT_PLATFORM, DRILL_HOST_PORT, + FASTAPI_API_HOST_PORT, FLOWER_HOST_PORT, MOUNT_ALL, MOUNT_PROVIDERS_AND_TESTS, @@ -52,7 +53,6 @@ SSH_PORT, START_AIRFLOW_DEFAULT_ALLOWED_EXECUTOR, TESTABLE_INTEGRATIONS, - UI_API_HOST_PORT, USE_AIRFLOW_MOUNT_SOURCES, WEBSERVER_HOST_PORT, GithubEvents, @@ -575,7 +575,7 @@ def env_variables_for_docker_commands(self) -> dict[str, str]: _set_var(_env, "VERBOSE_COMMANDS", self.verbose_commands) _set_var(_env, "VERSION_SUFFIX_FOR_PYPI", self.version_suffix_for_pypi) _set_var(_env, "WEBSERVER_HOST_PORT", None, WEBSERVER_HOST_PORT) - _set_var(_env, "UI_API_HOST_PORT", None, UI_API_HOST_PORT) + _set_var(_env, "FASTAPI_API_HOST_PORT", None, FASTAPI_API_HOST_PORT) _set_var(_env, "_AIRFLOW_RUN_DB_TESTS_ONLY", self.run_db_tests_only) _set_var(_env, "_AIRFLOW_SKIP_DB_TESTS", self.skip_db_tests) self._generate_env_for_docker_compose_file_if_needed(_env) diff --git a/dev/breeze/src/airflow_breeze/utils/run_tests.py b/dev/breeze/src/airflow_breeze/utils/run_tests.py index de48da86e43fe..840bf2fcad6d2 100644 --- a/dev/breeze/src/airflow_breeze/utils/run_tests.py +++ b/dev/breeze/src/airflow_breeze/utils/run_tests.py @@ -133,7 +133,7 @@ def get_excluded_provider_args(python_version: str) -> list[str]: TEST_TYPE_MAP_TO_PYTEST_ARGS: dict[str, list[str]] = { "Always": ["tests/always"], - "API": ["tests/api", "tests/api_connexion", "tests/api_internal", "tests/api_ui"], + "API": ["tests/api", "tests/api_connexion", "tests/api_internal", "tests/api_fastapi"], "BranchPythonVenv": [ "tests/operators/test_python.py::TestBranchPythonVirtualenvOperator", ], diff --git a/dev/breeze/src/airflow_breeze/utils/selective_checks.py b/dev/breeze/src/airflow_breeze/utils/selective_checks.py index eb38675fb3aab..653d71f9a443a 100644 --- a/dev/breeze/src/airflow_breeze/utils/selective_checks.py +++ b/dev/breeze/src/airflow_breeze/utils/selective_checks.py 
@@ -264,11 +264,11 @@ def __hash__(self): r"^airflow/api/", r"^airflow/api_connexion/", r"^airflow/api_internal/", - r"^airflow/api_ui/", + r"^airflow/api_fastapi/", r"^tests/api/", r"^tests/api_connexion/", r"^tests/api_internal/", - r"^tests/api_ui/", + r"^tests/api_fastapi/", ], SelectiveUnitTestTypes.CLI: [ r"^airflow/cli/", diff --git a/dev/breeze/src/airflow_breeze/utils/visuals.py b/dev/breeze/src/airflow_breeze/utils/visuals.py index a43182713f1d1..b9df215fb1bb1 100644 --- a/dev/breeze/src/airflow_breeze/utils/visuals.py +++ b/dev/breeze/src/airflow_breeze/utils/visuals.py @@ -21,12 +21,12 @@ from __future__ import annotations from airflow_breeze.global_constants import ( + FASTAPI_API_HOST_PORT, FLOWER_HOST_PORT, MYSQL_HOST_PORT, POSTGRES_HOST_PORT, REDIS_HOST_PORT, SSH_PORT, - UI_API_HOST_PORT, WEBSERVER_HOST_PORT, ) from airflow_breeze.utils.path_utils import AIRFLOW_SOURCES_ROOT @@ -83,7 +83,7 @@ Ports are forwarded to the running docker containers for webserver and database * {SSH_PORT} -> forwarded to Airflow ssh server -> airflow:22 * {WEBSERVER_HOST_PORT} -> forwarded to Airflow webserver -> airflow:8080 - * {UI_API_HOST_PORT} -> forwarded to Airflow UI API -> airflow:9091 + * {FASTAPI_API_HOST_PORT} -> forwarded to Airflow FastAPI API -> airflow:9091 * {FLOWER_HOST_PORT} -> forwarded to Flower dashboard -> airflow:5555 * {POSTGRES_HOST_PORT} -> forwarded to Postgres database -> postgres:5432 * {MYSQL_HOST_PORT} -> forwarded to MySQL database -> mysql:3306 @@ -93,7 +93,7 @@ * ssh connection for remote debugging: ssh -p {SSH_PORT} airflow@127.0.0.1 (password: airflow) * Webserver: http://127.0.0.1:{WEBSERVER_HOST_PORT} - * UI API: http://127.0.0.1:{WEBSERVER_HOST_PORT} + * FastAPI API: http://127.0.0.1:{WEBSERVER_HOST_PORT} * Flower: http://127.0.0.1:{FLOWER_HOST_PORT} * Postgres: jdbc:postgresql://127.0.0.1:{POSTGRES_HOST_PORT}/airflow?user=postgres&password=airflow * Mysql: jdbc:mysql://127.0.0.1:{MYSQL_HOST_PORT}/airflow?user=root diff --git a/dev/breeze/tests/test_pytest_args_for_test_types.py b/dev/breeze/tests/test_pytest_args_for_test_types.py index 489dfed86d71a..36a4b157794d2 100644 --- a/dev/breeze/tests/test_pytest_args_for_test_types.py +++ b/dev/breeze/tests/test_pytest_args_for_test_types.py @@ -54,7 +54,7 @@ ), ( "API", - ["tests/api", "tests/api_connexion", "tests/api_internal", "tests/api_ui"], + ["tests/api", "tests/api_connexion", "tests/api_internal", "tests/api_fastapi"], False, ), ( @@ -234,7 +234,7 @@ def test_pytest_args_for_helm_test_types(helm_test_package: str, pytest_args: li [ ( "API", - ["tests/api", "tests/api_connexion", "tests/api_internal", "tests/api_ui"], + ["tests/api", "tests/api_connexion", "tests/api_internal", "tests/api_fastapi"], False, ), ( @@ -250,7 +250,7 @@ def test_pytest_args_for_helm_test_types(helm_test_package: str, pytest_args: li "tests/api", "tests/api_connexion", "tests/api_internal", - "tests/api_ui", + "tests/api_fastapi", "tests/cli", ], False, diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py index 2491b6f920cf8..ac03f57ba3e5e 100644 --- a/dev/breeze/tests/test_selective_checks.py +++ b/dev/breeze/tests/test_selective_checks.py @@ -184,7 +184,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str): ), ( pytest.param( - ("airflow/api_ui/file.py",), + ("airflow/api_fastapi/file.py",), { "all-python-versions": "['3.8']", "all-python-versions-list-as-string": "3.8", diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 
2556113d001dc..c7a7d1221c4b1 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -600,6 +600,7 @@ falsy faq Fargate fargate +fastapi fbee fc fd diff --git a/scripts/ci/docker-compose/base-ports.yml b/scripts/ci/docker-compose/base-ports.yml index 050aaad2dce4f..17a935cd27e1c 100644 --- a/scripts/ci/docker-compose/base-ports.yml +++ b/scripts/ci/docker-compose/base-ports.yml @@ -20,5 +20,5 @@ services: ports: - "${SSH_PORT}:22" - "${WEBSERVER_HOST_PORT}:8080" - - "${UI_API_HOST_PORT}:9091" + - "${FASTAPI_API_HOST_PORT}:9091" - "${FLOWER_HOST_PORT}:5555" diff --git a/scripts/ci/pre_commit/check_tests_in_right_folders.py b/scripts/ci/pre_commit/check_tests_in_right_folders.py index 5a7e2c4ec247b..8260b6ad0d578 100755 --- a/scripts/ci/pre_commit/check_tests_in_right_folders.py +++ b/scripts/ci/pre_commit/check_tests_in_right_folders.py @@ -33,7 +33,7 @@ "api", "api_connexion", "api_internal", - "api_ui", + "api_fastapi", "auth", "callbacks", "charts", diff --git a/scripts/ci/pre_commit/update_ui_api_spec.py b/scripts/ci/pre_commit/update_fastapi_api_spec.py similarity index 94% rename from scripts/ci/pre_commit/update_ui_api_spec.py rename to scripts/ci/pre_commit/update_fastapi_api_spec.py index 9bba385301fbe..15ccaa5ac209e 100755 --- a/scripts/ci/pre_commit/update_ui_api_spec.py +++ b/scripts/ci/pre_commit/update_fastapi_api_spec.py @@ -26,7 +26,7 @@ initialize_breeze_precommit(__name__, __file__) cmd_result = run_command_via_breeze_shell( - ["python3", "/opt/airflow/scripts/in_container/run_update_ui_api_spec.py"], + ["python3", "/opt/airflow/scripts/in_container/run_update_fastapi_api_spec.py"], backend="postgres", skip_environment_initialization=False, ) diff --git a/scripts/in_container/bin/run_tmux b/scripts/in_container/bin/run_tmux index fce52fc2d21da..61e7dbf49d5da 100755 --- a/scripts/in_container/bin/run_tmux +++ b/scripts/in_container/bin/run_tmux @@ -60,9 +60,9 @@ tmux send-keys 'airflow scheduler' C-m tmux select-pane -t 2 tmux split-window -h if [[ ${DEV_MODE=} == "true" ]]; then - tmux send-keys 'airflow ui-api -d' C-m + tmux send-keys 'airflow fastapi-api -d' C-m else - tmux send-keys 'airflow ui-api' C-m + tmux send-keys 'airflow fastapi-api' C-m fi tmux split-window -h diff --git a/scripts/in_container/run_update_ui_api_spec.py b/scripts/in_container/run_update_fastapi_api_spec.py similarity index 79% rename from scripts/in_container/run_update_ui_api_spec.py rename to scripts/in_container/run_update_fastapi_api_spec.py index c21b7905bf21a..4d78bc4afd585 100644 --- a/scripts/in_container/run_update_ui_api_spec.py +++ b/scripts/in_container/run_update_fastapi_api_spec.py @@ -19,11 +19,17 @@ import yaml from fastapi.openapi.utils import get_openapi -from airflow.api_ui.app import cached_app +from airflow.api_fastapi.app import create_app -app = cached_app() +app = create_app() -OPENAPI_SPEC_FILE = "airflow/api_ui/openapi/v1-generated.yaml" +OPENAPI_SPEC_FILE = "airflow/api_fastapi/openapi/v1-generated.yaml" + + +# The persisted openapi spec will list all endpoints (public and ui), this +# is used for code generation. 
+for route in app.routes: + route.__setattr__("include_in_schema", True) with open(OPENAPI_SPEC_FILE, "w+") as f: yaml.dump( diff --git a/tests/api_ui/__init__.py b/tests/api_fastapi/__init__.py similarity index 100% rename from tests/api_ui/__init__.py rename to tests/api_fastapi/__init__.py diff --git a/tests/api_ui/conftest.py b/tests/api_fastapi/conftest.py similarity index 95% rename from tests/api_ui/conftest.py rename to tests/api_fastapi/conftest.py index 9f82802142e97..c5212272d7306 100644 --- a/tests/api_ui/conftest.py +++ b/tests/api_fastapi/conftest.py @@ -19,7 +19,7 @@ import pytest from fastapi.testclient import TestClient -from airflow.api_ui.app import create_app +from airflow.api_fastapi.app import create_app @pytest.fixture diff --git a/tests/api_ui/views/__init__.py b/tests/api_fastapi/views/__init__.py similarity index 100% rename from tests/api_ui/views/__init__.py rename to tests/api_fastapi/views/__init__.py diff --git a/tests/api_fastapi/views/ui/__init__.py b/tests/api_fastapi/views/ui/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/api_fastapi/views/ui/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/tests/api_ui/views/test_datasets.py b/tests/api_fastapi/views/ui/test_datasets.py similarity index 100% rename from tests/api_ui/views/test_datasets.py rename to tests/api_fastapi/views/ui/test_datasets.py diff --git a/tests/cli/commands/test_ui_api_command.py b/tests/cli/commands/test_fastapi_api_command.py similarity index 68% rename from tests/cli/commands/test_ui_api_command.py rename to tests/cli/commands/test_fastapi_api_command.py index 81ee2275bf77b..529c67f5ed821 100644 --- a/tests/cli/commands/test_ui_api_command.py +++ b/tests/cli/commands/test_fastapi_api_command.py @@ -26,7 +26,7 @@ import pytest from rich.console import Console -from airflow.cli.commands import ui_api_command +from airflow.cli.commands import fastapi_api_command from tests.cli.commands._common_cli_classes import _CommonCLIGunicornTestClass console = Console(width=400, color_system="standard") @@ -34,28 +34,28 @@ @pytest.mark.db_test class TestCliInternalAPI(_CommonCLIGunicornTestClass): - main_process_regexp = r"airflow ui-api" + main_process_regexp = r"airflow fastapi-api" @pytest.mark.execution_timeout(210) - def test_cli_ui_api_background(self, tmp_path): + def test_cli_fastapi_api_background(self, tmp_path): parent_path = tmp_path / "gunicorn" parent_path.mkdir() - pidfile_ui_api = parent_path / "pidflow-ui-api.pid" - pidfile_monitor = parent_path / "pidflow-ui-api-monitor.pid" - stdout = parent_path / "airflow-ui-api.out" - stderr = parent_path / "airflow-ui-api.err" - logfile = parent_path / "airflow-ui-api.log" + pidfile_fastapi_api = parent_path / "pidflow-fastapi-api.pid" + pidfile_monitor = parent_path / "pidflow-fastapi-api-monitor.pid" + stdout = parent_path / "airflow-fastapi-api.out" + stderr = parent_path / "airflow-fastapi-api.err" + logfile = parent_path / "airflow-fastapi-api.log" try: # Run internal-api as daemon in background. Note that the wait method is not called. - console.print("[magenta]Starting airflow ui-api --daemon") + console.print("[magenta]Starting airflow fastapi-api --daemon") env = os.environ.copy() proc = subprocess.Popen( [ "airflow", - "ui-api", + "fastapi-api", "--daemon", "--pid", - os.fspath(pidfile_ui_api), + os.fspath(pidfile_fastapi_api), "--stdout", os.fspath(stdout), "--stderr", @@ -69,11 +69,11 @@ def test_cli_ui_api_background(self, tmp_path): pid_monitor = self._wait_pidfile(pidfile_monitor) console.print(f"[blue]Monitor started at {pid_monitor}") - pid_ui_api = self._wait_pidfile(pidfile_ui_api) - console.print(f"[blue]UI API started at {pid_ui_api}") - console.print("[blue]Running airflow ui-api process:") - # Assert that the ui-api and gunicorn processes are running (by name rather than pid). - assert self._find_process(r"airflow ui-api --daemon", print_found_process=True) + pid_fastapi_api = self._wait_pidfile(pidfile_fastapi_api) + console.print(f"[blue]FastAPI API started at {pid_fastapi_api}") + console.print("[blue]Running airflow fastapi-api process:") + # Assert that the fastapi-api and gunicorn processes are running (by name rather than pid). 
+ assert self._find_process(r"airflow fastapi-api --daemon", print_found_process=True) console.print("[blue]Waiting for gunicorn processes:") # wait for gunicorn to start for _ in range(30): @@ -83,16 +83,16 @@ def test_cli_ui_api_background(self, tmp_path): time.sleep(1) console.print("[blue]Running gunicorn processes:") assert self._find_all_processes("^gunicorn", print_found_process=True) - console.print("[magenta]ui-api process started successfully.") + console.print("[magenta]fastapi-api process started successfully.") console.print( "[magenta]Terminating monitor process and expect " - "ui-api and gunicorn processes to terminate as well" + "fastapi-api and gunicorn processes to terminate as well" ) proc = psutil.Process(pid_monitor) proc.terminate() assert proc.wait(120) in (0, None) self._check_processes(ignore_running=False) - console.print("[magenta]All ui-api and gunicorn processes are terminated.") + console.print("[magenta]All fastapi-api and gunicorn processes are terminated.") except Exception: console.print("[red]Exception occurred. Dumping all logs.") # Dump all logs @@ -101,18 +101,20 @@ def test_cli_ui_api_background(self, tmp_path): console.print(file.read_text()) raise - def test_cli_ui_api_debug(self, app): - with mock.patch("subprocess.Popen") as Popen, mock.patch.object(ui_api_command, "GunicornMonitor"): + def test_cli_fastapi_api_debug(self, app): + with mock.patch("subprocess.Popen") as Popen, mock.patch.object( + fastapi_api_command, "GunicornMonitor" + ): port = "9092" hostname = "somehost" - args = self.parser.parse_args(["ui-api", "--port", port, "--hostname", hostname, "--debug"]) - ui_api_command.ui_api(args) + args = self.parser.parse_args(["fastapi-api", "--port", port, "--hostname", hostname, "--debug"]) + fastapi_api_command.fastapi_api(args) Popen.assert_called_with( [ "fastapi", "dev", - "airflow/api_ui/main.py", + "airflow/api_fastapi/main.py", "--port", port, "--host", @@ -121,18 +123,20 @@ def test_cli_ui_api_debug(self, app): close_fds=True, ) - def test_cli_ui_api_args(self): - with mock.patch("subprocess.Popen") as Popen, mock.patch.object(ui_api_command, "GunicornMonitor"): + def test_cli_fastapi_api_args(self): + with mock.patch("subprocess.Popen") as Popen, mock.patch.object( + fastapi_api_command, "GunicornMonitor" + ): args = self.parser.parse_args( [ - "ui-api", + "fastapi-api", "--access-logformat", "custom_log_format", "--pid", "/tmp/x.pid", ] ) - ui_api_command.ui_api(args) + fastapi_api_command.fastapi_api(args) Popen.assert_called_with( [ @@ -142,13 +146,13 @@ def test_cli_ui_api_args(self): "--workers", "4", "--worker-class", - "airflow.cli.commands.ui_api_command.AirflowUvicornWorker", + "airflow.cli.commands.fastapi_api_command.AirflowUvicornWorker", "--timeout", "120", "--bind", "0.0.0.0:9091", "--name", - "airflow-ui-api", + "airflow-fastapi-api", "--pid", "/tmp/x.pid", "--access-logfile", @@ -156,10 +160,10 @@ def test_cli_ui_api_args(self): "--error-logfile", "-", "--config", - "python:airflow.api_ui.gunicorn_config", + "python:airflow.api_fastapi.gunicorn_config", "--access-logformat", "custom_log_format", - "airflow.api_ui.app:cached_app()", + "airflow.api_fastapi.app:cached_app()", "--preload", ], close_fds=True, From 37dfde75d523fa35e178f697269842db1b1d5ef0 Mon Sep 17 00:00:00 2001 From: Hyunsoo Kang Date: Wed, 11 Sep 2024 02:15:46 +0900 Subject: [PATCH 014/349] Fix simple typo in the documentation. 
(#42058) --- docs/apache-airflow/howto/custom-operator.rst | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/docs/apache-airflow/howto/custom-operator.rst b/docs/apache-airflow/howto/custom-operator.rst index ce32654a6b690..f2b8db712dfa6 100644 --- a/docs/apache-airflow/howto/custom-operator.rst +++ b/docs/apache-airflow/howto/custom-operator.rst @@ -281,44 +281,44 @@ templated field: .. code-block:: python class HelloOperator(BaseOperator): - template_fields = "field_a" + template_fields = "foo" - def __init__(field_a_id) -> None: # <- should be def __init__(field_a)-> None - self.field_a = field_a_id # <- should be self.field_a = field_a + def __init__(self, foo_id) -> None: # should be def __init__(self, foo) -> None + self.foo = foo_id # should be self.foo = foo 2. Templated fields' instance members must be assigned with their corresponding parameter from the constructor, either by a direct assignment or by calling the parent's constructor (in which these fields are defined as ``template_fields``) with explicit an assignment of the parameter. -The following example is invalid, as the instance member ``self.field_a`` is not assigned at all, despite being a +The following example is invalid, as the instance member ``self.foo`` is not assigned at all, despite being a templated field: .. code-block:: python class HelloOperator(BaseOperator): - template_fields = ("field_a", "field_b") + template_fields = ("foo", "bar") - def __init__(field_a, field_b) -> None: - self.field_b = field_b + def __init__(self, foo, bar) -> None: + self.bar = bar -The following example is also invalid, as the instance member ``self.field_a`` of ``MyHelloOperator`` is initialized +The following example is also invalid, as the instance member ``self.foo`` of ``MyHelloOperator`` is initialized implicitly as part of the ``kwargs`` passed to its parent constructor: .. code-block:: python class HelloOperator(BaseOperator): - template_fields = "field_a" + template_fields = "foo" - def __init__(field_a) -> None: - self.field_a = field_a + def __init__(self, foo) -> None: + self.foo = foo class MyHelloOperator(HelloOperator): - template_fields = ("field_a", "field_b") + template_fields = ("foo", "bar") - def __init__(field_b, **kwargs) -> None: # <- should be def __init__(field_a, field_b, **kwargs) - super().__init__(**kwargs) # <- should be super().__init__(field_a=field_a, **kwargs) - self.field_b = field_b + def __init__(self, bar, **kwargs) -> None: # should be def __init__(self, foo, bar, **kwargs) + super().__init__(**kwargs) # should be super().__init__(foo=foo, **kwargs) + self.bar = bar 3. Applying actions on the parameter during the assignment in the constructor is not allowed. Any action on the value should be applied in the ``execute()`` method. @@ -327,10 +327,10 @@ Therefore, the following example is invalid: .. code-block:: python class HelloOperator(BaseOperator): - template_fields = "field_a" + template_fields = "foo" - def __init__(field_a) -> None: - self.field_a = field_a.lower() # <- assignment should be only self.field_a = field_a + def __init__(self, foo) -> None: + self.foo = foo.lower() # assignment should be only self.foo = foo When an operator inherits from a base operator and does not have a constructor defined on its own, the limitations above do not apply. However, the templated fields must be set properly in the parent according to those limitations. @@ -340,14 +340,14 @@ Thus, the following example is valid: .. 
code-block:: python class HelloOperator(BaseOperator): - template_fields = "field_a" + template_fields = "foo" - def __init__(field_a) -> None: - self.field_a = field_a + def __init__(self, foo) -> None: + self.foo = foo class MyHelloOperator(HelloOperator): - template_fields = "field_a" + template_fields = "foo" The limitations above are enforced by a pre-commit named 'validate-operators-init'. From 667283c327feca158f4b4847de3878899804e743 Mon Sep 17 00:00:00 2001 From: sc-anssi Date: Tue, 10 Sep 2024 19:22:00 +0200 Subject: [PATCH 015/349] Support multiline input for Params of type string in trigger UI form (#40414) * Add multiline input (textarea) support for Params of type string in trigger UI form * Use the 'format' attribute in the Param for rendering a multiline text area in the trigger UI form * Update example DAG and documentation to illustrate the use of a multiline text Param --- airflow/example_dags/example_params_ui_tutorial.py | 6 ++++++ airflow/www/templates/airflow/trigger.html | 6 ++++++ docs/apache-airflow/core-concepts/params.rst | 3 ++- 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/airflow/example_dags/example_params_ui_tutorial.py b/airflow/example_dags/example_params_ui_tutorial.py index f7fd7e0844957..133df85d07b5c 100644 --- a/airflow/example_dags/example_params_ui_tutorial.py +++ b/airflow/example_dags/example_params_ui_tutorial.py @@ -165,6 +165,12 @@ title="Time Picker", description="Please select a time, use the button on the left for a pop-up tool.", ), + "multiline_text": Param( + "A multiline text Param\nthat will keep the newline\ncharacters in its value.", + description="This field allows for multiline text input. The returned value will be a single string with newline (\\n) characters kept intact.", + type=["string", "null"], + format="multiline", + ), # Fields can be required or not. If the defined fields are typed they are getting required by default # (else they would not pass JSON schema validation) - to make typed fields optional you must # permit the optional "null" type.
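A minimal sketch of how a Param with format="multiline" could be consumed from a task, assuming the TaskFlow API; the dag id, task name and sample text are illustrative and not taken from the commit, only the Param(..., type=..., format="multiline") shape mirrors the hunk above:

from airflow.decorators import dag, task
from airflow.models.param import Param
from airflow.operators.python import get_current_context


@dag(schedule=None, params={"notes": Param("line 1\nline 2", type="string", format="multiline")})
def multiline_param_demo():
    @task
    def print_notes():
        # The trigger form submits the textarea content as a single string with "\n" separators.
        value = get_current_context()["params"]["notes"]
        for line in value.splitlines():
            print(line)

    print_notes()


multiline_param_demo()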
diff --git a/airflow/www/templates/airflow/trigger.html b/airflow/www/templates/airflow/trigger.html index 86e0b4eb4565a..7cdcd337beddf 100644 --- a/airflow/www/templates/airflow/trigger.html +++ b/airflow/www/templates/airflow/trigger.html @@ -127,6 +127,12 @@ {%- if form_details.schema.minimum %} min="{{ form_details.schema.minimum }}"{% endif %} {%- if form_details.schema.maximum %} max="{{ form_details.schema.maximum }}"{% endif %} {%- if form_details.schema.type and not "null" in form_details.schema.type %} required=""{% endif %} /> + {% elif form_details.schema and "string" in form_details.schema.type and "format" in form_details.schema and form_details.schema.format == "multiline" %} + {% else %} Date: Wed, 11 Sep 2024 02:27:53 +0900 Subject: [PATCH 016/349] docs: Remove outdated 'executor' reference from run() method docstring (#42121) --- airflow/models/dag.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 56f7dc89d25b1..95a2f8b6e3105 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2321,7 +2321,6 @@ def run( :param end_date: the end date of the range to run :param mark_success: True to mark jobs as succeeded without running them :param local: True to run the tasks using the LocalExecutor - :param executor: The executor instance to run the tasks :param donot_pickle: True to avoid pickling DAG object and send to workers :param ignore_task_deps: True to skip upstream tasks :param ignore_first_depends_on_past: True to ignore depends_on_past From 851f1cffd5f4ba0dd4a30658bd07732ba638f76f Mon Sep 17 00:00:00 2001 From: Pierre Jeambrun Date: Wed, 11 Sep 2024 03:09:38 +0800 Subject: [PATCH 017/349] Fix task_instance and dag_run links from list views (#42138) --- airflow/www/utils.py | 13 +++++++++++-- airflow/www/views.py | 24 +++++++++++++++--------- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/airflow/www/utils.py b/airflow/www/utils.py index 653f2f1417c16..ef057adbf36ff 100644 --- a/airflow/www/utils.py +++ b/airflow/www/utils.py @@ -403,6 +403,8 @@ def task_instance_link(attr): task_id = attr.get("task_id") run_id = attr.get("run_id") map_index = attr.get("map_index", None) + execution_date = attr.get("execution_date") or attr.get("dag_run.execution_date") + if map_index == -1: map_index = None @@ -412,6 +414,7 @@ def task_instance_link(attr): task_id=task_id, dag_run_id=run_id, map_index=map_index, + execution_date=execution_date, tab="graph", ) url_root = url_for( @@ -421,6 +424,7 @@ def task_instance_link(attr): root=task_id, dag_run_id=run_id, map_index=map_index, + execution_date=execution_date, tab="graph", ) return Markup( @@ -500,10 +504,10 @@ def json_(attr): def dag_link(attr): """Generate a URL to the Graph view for a Dag.""" dag_id = attr.get("dag_id") - execution_date = attr.get("execution_date") + execution_date = attr.get("execution_date") or attr.get("dag_run.execution_date") if not dag_id: return Markup("None") - url = url_for("Airflow.graph", dag_id=dag_id, execution_date=execution_date) + url = url_for("Airflow.grid", dag_id=dag_id, execution_date=execution_date) return Markup('{}').format(url, dag_id) @@ -511,10 +515,15 @@ def dag_run_link(attr): """Generate a URL to the Graph view for a DagRun.""" dag_id = attr.get("dag_id") run_id = attr.get("run_id") + execution_date = attr.get("execution_date") or attr.get("dag_run.execution_date") + + if not dag_id: + return Markup("None") url = url_for( "Airflow.grid", dag_id=dag_id, + execution_date=execution_date, dag_run_id=run_id, 
tab="graph", ) diff --git a/airflow/www/views.py b/airflow/www/views.py index f060242f26686..0fbda988a7d4d 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -708,12 +708,16 @@ def show_traceback(error): "airflow/traceback.html", python_version=sys.version.split(" ")[0] if is_logged_in else "redacted", airflow_version=version if is_logged_in else "redacted", - hostname=get_hostname() - if conf.getboolean("webserver", "EXPOSE_HOSTNAME") and is_logged_in - else "redacted", - info=traceback.format_exc() - if conf.getboolean("webserver", "EXPOSE_STACKTRACE") and is_logged_in - else "Error! Please contact server admin.", + hostname=( + get_hostname() + if conf.getboolean("webserver", "EXPOSE_HOSTNAME") and is_logged_in + else "redacted" + ), + info=( + traceback.format_exc() + if conf.getboolean("webserver", "EXPOSE_STACKTRACE") and is_logged_in + else "Error! Please contact server admin." + ), ), 500, ) @@ -3439,9 +3443,11 @@ def next_run_datasets(self, dag_id): DatasetEvent, and_( DatasetEvent.dataset_id == DatasetModel.id, - DatasetEvent.timestamp >= latest_run.execution_date - if latest_run and latest_run.execution_date - else True, + ( + DatasetEvent.timestamp >= latest_run.execution_date + if latest_run and latest_run.execution_date + else True + ), ), isouter=True, ) From 3c2f8a60b5f492275d2c407a5ef6885acf743eb9 Mon Sep 17 00:00:00 2001 From: Bartosz Jankiewicz Date: Wed, 11 Sep 2024 01:19:01 +0200 Subject: [PATCH 018/349] AIP-59-performance-dags (#41961) --- .pre-commit-config.yaml | 2 +- performance/requirements.txt | 2 + .../performance_dag/performance_dag.py | 267 +++++++ .../bigquery_insert_job_workflow.json | 10 + .../blocking_workflow.json | 10 + .../scheduling_performance.json | 10 + .../single_workflow.json | 11 + .../skipping_workflow.json | 11 + .../tiny_workflow.json | 10 + .../tiny_workflow_extra_kwargs.json | 11 + .../worker_bash_task_template.json | 10 + .../worker_task_template.json | 10 + .../performance_dag/performance_dag_utils.py | 694 ++++++++++++++++++ performance/tests/test_performance_dag.py | 161 ++++ .../tests/test_performance_dag_utils.py | 49 ++ 15 files changed, 1267 insertions(+), 1 deletion(-) create mode 100644 performance/requirements.txt create mode 100644 performance/src/performance_dags/performance_dag/performance_dag.py create mode 100644 performance/src/performance_dags/performance_dag/performance_dag_configurations/bigquery_insert_job_workflow.json create mode 100644 performance/src/performance_dags/performance_dag/performance_dag_configurations/blocking_workflow.json create mode 100644 performance/src/performance_dags/performance_dag/performance_dag_configurations/scheduling_performance.json create mode 100644 performance/src/performance_dags/performance_dag/performance_dag_configurations/single_workflow.json create mode 100644 performance/src/performance_dags/performance_dag/performance_dag_configurations/skipping_workflow.json create mode 100644 performance/src/performance_dags/performance_dag/performance_dag_configurations/tiny_workflow.json create mode 100644 performance/src/performance_dags/performance_dag/performance_dag_configurations/tiny_workflow_extra_kwargs.json create mode 100644 performance/src/performance_dags/performance_dag/performance_dag_configurations/worker_bash_task_template.json create mode 100644 performance/src/performance_dags/performance_dag/performance_dag_configurations/worker_task_template.json create mode 100644 performance/src/performance_dags/performance_dag/performance_dag_utils.py create mode 100644 
performance/tests/test_performance_dag.py create mode 100644 performance/tests/test_performance_dag_utils.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6d6f77b34302b..8bb9ff4302a80 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -351,7 +351,7 @@ repos: args: [--fix] require_serial: true additional_dependencies: ["ruff==0.5.5"] - exclude: ^.*/.*_vendor/|^tests/dags/test_imports.py + exclude: ^.*/.*_vendor/|^tests/dags/test_imports.py|^performance/tests/test_.*.py - id: ruff-format name: Run 'ruff format' for extremely fast Python formatting description: "Run 'ruff format' for extremely fast Python formatting" diff --git a/performance/requirements.txt b/performance/requirements.txt new file mode 100644 index 0000000000000..81794e71d82cf --- /dev/null +++ b/performance/requirements.txt @@ -0,0 +1,2 @@ +apache-airflow==2.10.0 +openlineage-airflow==1.20.5 diff --git a/performance/src/performance_dags/performance_dag/performance_dag.py b/performance/src/performance_dags/performance_dag/performance_dag.py new file mode 100644 index 0000000000000..a07ec175dae95 --- /dev/null +++ b/performance/src/performance_dags/performance_dag/performance_dag.py @@ -0,0 +1,267 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Module generates a number of DAGs for the purpose of performance testing. + +The number of DAGs, +the number and types of tasks in each DAG, and the shape of the DAG are controlled through +environment variables: + +- `PERF_DAGS_COUNT` - number of DAGs to generate +- `PERF_TASKS_COUNT` - number of tasks in each DAG +- `PERF_START_DATE` - start date of the DAGs; if not provided, current time minus `PERF_START_AGO` applies +- `PERF_START_AGO` - start time relative to current time used if PERF_START_DATE is not provided. Default `1h` +- `PERF_SCHEDULE_INTERVAL` - Schedule interval. Default `@once` +- `PERF_SHAPE` - shape of DAG. See `DagShape`. Default `NO_STRUCTURE` +- `PERF_SLEEP_TIME` - A non-negative float value specifying the time of sleep occurring + when each task is executed. Default `0` +- `PERF_OPERATOR_TYPE` - A string identifying the type of operator. Default `bash` +- `PERF_START_PAUSED` - Is DAG paused upon creation. Default `1` +- `PERF_TASKS_TRIGGER_RULE` - A string identifying the rule by which dependencies are applied + for the tasks to get triggered.
Default `TriggerRule.ALL_SUCCESS`) +- `PERF_OPERATOR_EXTRA_KWARGS` - A dictionary with extra kwargs for operator + +""" + +from __future__ import annotations + +import enum +import json +import os +import time +from enum import Enum + +import re2 as re +from performance_dags.performance_dag.performance_dag_utils import ( + parse_schedule_interval, + parse_start_date, +) + +from airflow import DAG +from airflow.models.baseoperator import chain +from airflow.operators.bash import BashOperator +from airflow.operators.python import PythonOperator +from airflow.utils.trigger_rule import TriggerRule + +# DAG File used in performance tests. Its shape can be configured by environment variables. + + +def safe_dag_id(dag_id): + # type: (str) -> str + """Remove invalid characters for dag_id.""" + return re.sub("[^0-9a-zA-Z_]+", "_", dag_id) + + +def get_task_list( + dag_object, + operator_type_str, + task_count, + trigger_rule, + sleep_time, + operator_extra_kwargs, +): + # type: (DAG, str, int, float, int, dict) -> list[BaseOperator] + """ + Return list of tasks of test dag. + + :param dag_object: A DAG object the tasks should be assigned to. + :param operator_type_str: A string identifying the type of operator + :param task_count: An integer specifying the number of tasks to create. + :param trigger_rule: A string identifying the rule by which dependencies are applied + for the tasks to get triggered + :param sleep_time: A non-negative float value specifying the time of sleep occurring + when each task is executed + :param operator_extra_kwargs: A dictionary with extra kwargs for operator + :return list[BaseOperator]: a list of tasks + """ + if operator_type_str == "bash": + task_list = [ + BashOperator( + task_id="__".join(["tasks", f"{i}_of_{task_count}"]), + bash_command=f"sleep {sleep_time}; echo test", + dag=dag_object, + trigger_rule=trigger_rule, + **operator_extra_kwargs, + ) + for i in range(1, task_count + 1) + ] + elif operator_type_str == "python": + + def sleep_function(): + time.sleep(sleep_time) + print("test") + + task_list = [ + PythonOperator( + task_id="__".join(["tasks", f"{i}_of_{task_count}"]), + python_callable=sleep_function, + dag=dag_object, + trigger_rule=trigger_rule, + **operator_extra_kwargs, + ) + for i in range(1, task_count + 1) + ] + else: + raise ValueError(f"Unsupported operator type: {operator_type_str}.") + return task_list + + +def chain_as_binary_tree(*tasks): + # type: (BaseOperator) -> None + """ + Chain tasks as a binary tree where task i is child of task (i - 1) // 2. + + Example: + t0 -> t1 -> t3 -> t7 + | \ + | -> t4 -> t8 + | + -> t2 -> t5 -> t9 + \ + -> t6 + """ + for i in range(1, len(tasks)): + tasks[i].set_upstream(tasks[(i - 1) // 2]) + + +def chain_as_grid(*tasks): + # type: (BaseOperator) -> None + """ + Chain tasks as a grid. 
+ + Example: + t0 -> t1 -> t2 -> t3 + | | | + v v v + t4 -> t5 -> t6 + | | + v v + t7 -> t8 + | + v + t9 + """ + if len(tasks) > 100 * 99 / 2: + raise ValueError("Cannot generate grid DAGs with lateral size larger than 100 tasks.") + grid_size = min([n for n in range(100) if n * (n + 1) / 2 >= len(tasks)]) + + def index(i, j): + """Return the index of node (i, j) on the grid.""" + return int(grid_size * i - i * (i - 1) / 2 + j) + + for i in range(grid_size - 1): + for j in range(grid_size - i - 1): + if index(i + 1, j) < len(tasks): + tasks[index(i + 1, j)].set_downstream(tasks[index(i, j)]) + if index(i, j + 1) < len(tasks): + tasks[index(i, j + 1)].set_downstream(tasks[index(i, j)]) + + +def chain_as_star(*tasks): + # type: (BaseOperator) -> None + """ + Chain tasks as a star (all tasks are children of task 0). + + Example: + t0 -> t1 + | -> t2 + | -> t3 + | -> t4 + | -> t5 + """ + tasks[0].set_downstream(list(tasks[1:])) + + +@enum.unique +class DagShape(Enum): + """Define shape of the Dag that will be used for testing.""" + + NO_STRUCTURE = "no_structure" + LINEAR = "linear" + BINARY_TREE = "binary_tree" + STAR = "star" + GRID = "grid" + + +DAG_COUNT = int(os.environ["PERF_DAGS_COUNT"]) +TASKS_COUNT = int(os.environ["PERF_TASKS_COUNT"]) +START_DATE, DAG_ID_START_DATE = parse_start_date( + os.environ["PERF_START_DATE"], os.environ.get("PERF_START_AGO", "1h") +) +SCHEDULE_INTERVAL_ENV = os.environ.get("PERF_SCHEDULE_INTERVAL", "@once") +SCHEDULE_INTERVAL = parse_schedule_interval(SCHEDULE_INTERVAL_ENV) +SHAPE = DagShape(os.environ["PERF_SHAPE"]) +SLEEP_TIME = float(os.environ.get("PERF_SLEEP_TIME", "0")) +OPERATOR_TYPE = os.environ.get("PERF_OPERATOR_TYPE", "bash") +START_PAUSED = bool(int(os.environ.get("PERF_START_PAUSED", "1"))) +TASKS_TRIGGER_RULE = os.environ.get("PERF_TASKS_TRIGGER_RULE", TriggerRule.ALL_SUCCESS) +OPERATOR_EXTRA_KWARGS = json.loads(os.environ.get("PERF_OPERATOR_EXTRA_KWARGS", "{}")) + +args = {"owner": "airflow", "start_date": START_DATE} + +if "PERF_MAX_RUNS" in os.environ: + if isinstance(SCHEDULE_INTERVAL, str): + raise ValueError("Can't set max runs with string-based schedule_interval") + if "PERF_START_DATE" not in os.environ: + raise ValueError( + "When using 'PERF_MAX_RUNS', please provide the start date as a date string in " + "'%Y-%m-%d %H:%M:%S.%f' format via 'PERF_START_DATE' environment variable." 
+ ) + num_runs = int(os.environ["PERF_MAX_RUNS"]) + args["end_date"] = START_DATE + (SCHEDULE_INTERVAL * (num_runs - 1)) + +for dag_no in range(1, DAG_COUNT + 1): + dag = DAG( + dag_id=safe_dag_id( + "__".join( + [ + os.path.splitext(os.path.basename(__file__))[0], + f"SHAPE={SHAPE.name.lower()}", + f"DAGS_COUNT={dag_no}_of_{DAG_COUNT}", + f"TASKS_COUNT=${TASKS_COUNT}", + f"START_DATE=${DAG_ID_START_DATE}", + f"SCHEDULE_INTERVAL=${SCHEDULE_INTERVAL_ENV}", + ] + ) + ), + default_args=args, + schedule_interval=SCHEDULE_INTERVAL, + is_paused_upon_creation=START_PAUSED, + catchup=True, + ) + + performance_dag_tasks = get_task_list( + dag, + OPERATOR_TYPE, + TASKS_COUNT, + TASKS_TRIGGER_RULE, + SLEEP_TIME, + OPERATOR_EXTRA_KWARGS, + ) + + shape_function_map = { + DagShape.LINEAR: chain, + DagShape.BINARY_TREE: chain_as_binary_tree, + DagShape.STAR: chain_as_star, + DagShape.GRID: chain_as_grid, + } + if SHAPE != DagShape.NO_STRUCTURE: + shape_function_map[SHAPE](*performance_dag_tasks) + + globals()[f"dag_{dag_no}"] = dag diff --git a/performance/src/performance_dags/performance_dag/performance_dag_configurations/bigquery_insert_job_workflow.json b/performance/src/performance_dags/performance_dag/performance_dag_configurations/bigquery_insert_job_workflow.json new file mode 100644 index 0000000000000..562aa490d3448 --- /dev/null +++ b/performance/src/performance_dags/performance_dag/performance_dag_configurations/bigquery_insert_job_workflow.json @@ -0,0 +1,10 @@ +{ + "PERF_DAG_PREFIX": "big_query_insert_job_workflow", + "PERF_DAGS_COUNT": "10", + "PERF_TASKS_COUNT": "2", + "PERF_START_AGO": "1d", + "PERF_SCHEDULE_INTERVAL": "@once", + "PERF_SHAPE": "linear", + "PERF_SLEEP_TIME": "0", + "PERF_OPERATOR_TYPE": "big_query_insert_job" +} diff --git a/performance/src/performance_dags/performance_dag/performance_dag_configurations/blocking_workflow.json b/performance/src/performance_dags/performance_dag/performance_dag_configurations/blocking_workflow.json new file mode 100644 index 0000000000000..5990e96e504d1 --- /dev/null +++ b/performance/src/performance_dags/performance_dag/performance_dag_configurations/blocking_workflow.json @@ -0,0 +1,10 @@ +{ + "PERF_DAG_PREFIX": "sleep_workflow", + "PERF_DAGS_COUNT": "100", + "PERF_TASKS_COUNT": "100", + "PERF_START_AGO": "1d", + "PERF_SCHEDULE_INTERVAL": "@once", + "PERF_SHAPE": "no_structure", + "PERF_SLEEP_TIME": "31536000", + "PERF_OPERATOR_TYPE": "python" +} diff --git a/performance/src/performance_dags/performance_dag/performance_dag_configurations/scheduling_performance.json b/performance/src/performance_dags/performance_dag/performance_dag_configurations/scheduling_performance.json new file mode 100644 index 0000000000000..0567778d672dd --- /dev/null +++ b/performance/src/performance_dags/performance_dag/performance_dag_configurations/scheduling_performance.json @@ -0,0 +1,10 @@ +{ + "PERF_DAG_PREFIX": "workflow", + "PERF_DAGS_COUNT": "10", + "PERF_TASKS_COUNT": "100", + "PERF_START_AGO": "1d", + "PERF_SCHEDULE_INTERVAL": "@once", + "PERF_SHAPE": "no_structure", + "PERF_SLEEP_TIME": "0", + "PERF_OPERATOR_TYPE": "python" +} diff --git a/performance/src/performance_dags/performance_dag/performance_dag_configurations/single_workflow.json b/performance/src/performance_dags/performance_dag/performance_dag_configurations/single_workflow.json new file mode 100644 index 0000000000000..4ebea19c9c224 --- /dev/null +++ b/performance/src/performance_dags/performance_dag/performance_dag_configurations/single_workflow.json @@ -0,0 +1,11 @@ +{ + "PERF_DAG_PREFIX": 
"single_workflow", + "PERF_DAGS_COUNT": "1", + "PERF_TASKS_COUNT": "5000", + "PERF_START_AGO": "150m", + "PERF_SCHEDULE_INTERVAL": "150m", + "PERF_SHAPE": "no_structure", + "PERF_SLEEP_TIME": "0", + "PERF_OPERATOR_TYPE": "python", + "PERF_MAX_RUNS": "1" +} diff --git a/performance/src/performance_dags/performance_dag/performance_dag_configurations/skipping_workflow.json b/performance/src/performance_dags/performance_dag/performance_dag_configurations/skipping_workflow.json new file mode 100644 index 0000000000000..60ab967c2b6fd --- /dev/null +++ b/performance/src/performance_dags/performance_dag/performance_dag_configurations/skipping_workflow.json @@ -0,0 +1,11 @@ +{ + "PERF_DAG_PREFIX": "skipping_workflow", + "PERF_DAGS_COUNT": "100", + "PERF_TASKS_COUNT": "100", + "PERF_TASKS_TRIGGER_RULE": "all_failed", + "PERF_START_AGO": "1d", + "PERF_SCHEDULE_INTERVAL": "@once", + "PERF_SHAPE": "linear", + "PERF_SLEEP_TIME": "0", + "PERF_OPERATOR_TYPE": "python" +} diff --git a/performance/src/performance_dags/performance_dag/performance_dag_configurations/tiny_workflow.json b/performance/src/performance_dags/performance_dag/performance_dag_configurations/tiny_workflow.json new file mode 100644 index 0000000000000..efd2f93a48dce --- /dev/null +++ b/performance/src/performance_dags/performance_dag/performance_dag_configurations/tiny_workflow.json @@ -0,0 +1,10 @@ +{ + "PERF_DAG_PREFIX": "tiny_workflow", + "PERF_DAGS_COUNT": "10", + "PERF_TASKS_COUNT": "2", + "PERF_START_AGO": "1d", + "PERF_SCHEDULE_INTERVAL": "@once", + "PERF_SHAPE": "linear", + "PERF_SLEEP_TIME": "0", + "PERF_OPERATOR_TYPE": "python" +} diff --git a/performance/src/performance_dags/performance_dag/performance_dag_configurations/tiny_workflow_extra_kwargs.json b/performance/src/performance_dags/performance_dag/performance_dag_configurations/tiny_workflow_extra_kwargs.json new file mode 100644 index 0000000000000..51203f12e0d0c --- /dev/null +++ b/performance/src/performance_dags/performance_dag/performance_dag_configurations/tiny_workflow_extra_kwargs.json @@ -0,0 +1,11 @@ +{ + "PERF_DAG_PREFIX": "tiny_workflow", + "PERF_DAGS_COUNT": "10", + "PERF_TASKS_COUNT": "2", + "PERF_START_AGO": "1d", + "PERF_SCHEDULE_INTERVAL": "@once", + "PERF_SHAPE": "linear", + "PERF_SLEEP_TIME": "0", + "PERF_OPERATOR_TYPE": "python", + "PERF_OPERATOR_EXTRA_KWARGS": "{\"doc_md\": \"Test extra kwargs\"}" +} diff --git a/performance/src/performance_dags/performance_dag/performance_dag_configurations/worker_bash_task_template.json b/performance/src/performance_dags/performance_dag/performance_dag_configurations/worker_bash_task_template.json new file mode 100644 index 0000000000000..62ae72b7c529b --- /dev/null +++ b/performance/src/performance_dags/performance_dag/performance_dag_configurations/worker_bash_task_template.json @@ -0,0 +1,10 @@ +{ + "PERF_DAG_PREFIX": "worker_workflow", + "PERF_DAGS_COUNT": "200", + "PERF_TASKS_COUNT": "1", + "PERF_START_AGO": "1d", + "PERF_SCHEDULE_INTERVAL": "@once", + "PERF_SHAPE": "no_structure", + "PERF_SLEEP_TIME": "0", + "PERF_OPERATOR_TYPE": "bash" +} diff --git a/performance/src/performance_dags/performance_dag/performance_dag_configurations/worker_task_template.json b/performance/src/performance_dags/performance_dag/performance_dag_configurations/worker_task_template.json new file mode 100644 index 0000000000000..94c811f57e0c3 --- /dev/null +++ b/performance/src/performance_dags/performance_dag/performance_dag_configurations/worker_task_template.json @@ -0,0 +1,10 @@ +{ + "PERF_DAG_PREFIX": "worker_workflow", + 
"PERF_DAGS_COUNT": "200", + "PERF_TASKS_COUNT": "1", + "PERF_START_AGO": "1d", + "PERF_SCHEDULE_INTERVAL": "@once", + "PERF_SHAPE": "no_structure", + "PERF_SLEEP_TIME": "0", + "PERF_OPERATOR_TYPE": "python" +} diff --git a/performance/src/performance_dags/performance_dag/performance_dag_utils.py b/performance/src/performance_dags/performance_dag/performance_dag_utils.py new file mode 100644 index 0000000000000..8dc1ae49fa650 --- /dev/null +++ b/performance/src/performance_dags/performance_dag/performance_dag_utils.py @@ -0,0 +1,694 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import json +import logging +import os +import tempfile +from collections import OrderedDict +from contextlib import contextmanager +from datetime import datetime, timedelta +from shutil import copyfile +from typing import Callable + +import re2 as re + +import airflow + +log = logging.getLogger(__name__) +log.setLevel(logging.INFO) + +MANDATORY_performance_DAG_VARIABLES = { + "PERF_DAGS_COUNT", + "PERF_TASKS_COUNT", + "PERF_SHAPE", + "PERF_START_DATE", +} + +performance_DAG_VARIABLES_DEFAULT_VALUES = { + "PERF_DAG_FILES_COUNT": "1", + "PERF_DAG_PREFIX": "perf_scheduler", + "PERF_START_AGO": "1h", + "PERF_SCHEDULE_INTERVAL": "@once", + "PERF_SLEEP_TIME": "0", + "PERF_OPERATOR_TYPE": "bash", + "PERF_MAX_RUNS": None, + "PERF_START_PAUSED": "1", +} + +ALLOWED_SHAPES = ("no_structure", "linear", "binary_tree", "star", "grid") + +ALLOWED_OPERATOR_TYPES = ("bash", "big_query_insert_job", "python") + +ALLOWED_TASKS_TRIGGER_RULES = ("all_success", "all_failed") + +# "None" schedule interval is not supported for now so that dag runs are created automatically +ALLOWED_NON_REGEXP_SCHEDULE_INTERVALS = ("@once",) + +DAG_IDS_NOT_ALLOWED_TO_MATCH_PREFIX = ("airflow_monitoring",) + +RE_TIME_DELTA = re.compile( + r"^((?P[\.\d]+?)d)?((?P[\.\d]+?)h)?((?P[\.\d]+?)m)?((?P[\.\d]+?)s)?$" +) + + +def add_perf_start_date_env_to_conf(performance_dag_conf: dict[str, str]) -> None: + """ + Calculate start date based on configuration. + + Calculates value for PERF_START_DATE environment variable and adds it to the performance_dag_conf + if it is not already present there. 
+ + :param performance_dag_conf: dict with environment variables as keys and their values as values + """ + if "PERF_START_DATE" not in performance_dag_conf: + start_ago = get_performance_dag_environment_variable(performance_dag_conf, "PERF_START_AGO") + + perf_start_date = airflow.utils.timezone.utcnow - check_and_parse_time_delta( + "PERF_START_AGO", start_ago + ) + + performance_dag_conf["PERF_START_DATE"] = str(perf_start_date) + + +def validate_performance_dag_conf(performance_dag_conf: dict[str, str]) -> None: + """ + Check `performance_dag_conf` contains a valid configuration for performance DAG. + + :param performance_dag_conf: dict with environment variables as keys and their values as values + + :raises: + TypeError: if performance_dag_conf is not a dict + KeyError: if performance_dag_conf does not contain mandatory environment variables + ValueError: if any value in performance_dag_conf is not a string + """ + if not isinstance(performance_dag_conf, dict): + raise TypeError( + f"performance_dag configuration must be a dictionary containing at least following keys: " + f"{MANDATORY_performance_DAG_VARIABLES}." + ) + + missing_variables = MANDATORY_performance_DAG_VARIABLES.difference(set(performance_dag_conf.keys())) + + if missing_variables: + raise KeyError( + f"Following mandatory environment variables are missing " + f"from performance_dag configuration: {missing_variables}." + ) + + if not all(isinstance(env, str) for env in performance_dag_conf.values()): + raise ValueError("All values of variables must be strings.") + + variable_to_validation_fun_map = { + "PERF_DAGS_COUNT": check_positive_int_convertibility, + "PERF_TASKS_COUNT": check_positive_int_convertibility, + "PERF_START_DATE": check_datetime_convertibility, + "PERF_DAG_FILES_COUNT": check_positive_int_convertibility, + "PERF_DAG_PREFIX": check_dag_prefix, + "PERF_START_AGO": check_and_parse_time_delta, + "PERF_SCHEDULE_INTERVAL": check_schedule_interval, + "PERF_SHAPE": get_check_allowed_values_function(ALLOWED_SHAPES), + "PERF_SLEEP_TIME": check_non_negative_float_convertibility, + "PERF_OPERATOR_TYPE": get_check_allowed_values_function(ALLOWED_OPERATOR_TYPES), + "PERF_MAX_RUNS": check_positive_int_convertibility, + "PERF_START_PAUSED": check_int_convertibility, + "PERF_TASKS_TRIGGER_RULE": get_check_allowed_values_function(ALLOWED_TASKS_TRIGGER_RULES), + "PERF_OPERATOR_EXTRA_KWARGS": check_valid_json, + } + + # we do not need to validate default values of variables + for env_name in variable_to_validation_fun_map: + if env_name in performance_dag_conf: + variable_to_validation_fun_map[env_name](env_name, performance_dag_conf[env_name]) + + check_max_runs_and_schedule_interval_compatibility(performance_dag_conf) + + +def check_int_convertibility(env_name: str, env_value: str) -> None: + """ + Check if value of provided environment variable is convertible to int value. + + :param env_name: name of the environment variable which is being checked. + :param env_value: value of the variable. + + :raises: ValueError: if env_value could not be converted to int value + """ + try: + int(env_value) + except ValueError: + raise ValueError(f"{env_name} value must be convertible to int. Received: '{env_value}'.") + + +def check_positive_int_convertibility(env_name: str, env_value: str) -> None: + """ + Check if string value is a positive integer. + + :param env_name: name of the environment variable which is being checked. + :param env_value: value of the variable. 
+ + :raises: ValueError: if env_value could not be converted to positive int value + """ + try: + converted_value = int(env_value) + check_positive(converted_value) + except ValueError: + raise ValueError(f"{env_name} value must be convertible to positive int. Received: '{env_value}'.") + + +def check_positive(value: int | float) -> None: + """Check if provided value is positive and raises ValueError otherwise.""" + if value <= 0: + raise ValueError + + +def check_datetime_convertibility(env_name: str, env_value: str) -> None: + """ + Check if value of provided environment variable is a date string in expected format. + + :param env_name: name of the environment variable which is being checked. + :param env_value: value of the variable. + """ + try: + datetime.strptime(env_value, "%Y-%m-%d %H:%M:%S.%f") + except Exception: + raise ValueError( + f"Value '{env_value}' of {env_name} cannot be converted " + f"to datetime object in '%Y-%m-%d %H:%M:%S.%f' format." + ) + + +def check_dag_prefix(env_name: str, env_value: str) -> None: + """ + Validate dag prefix value. + + Checks if value of dag prefix env variable is a prefix for one of the forbidden dag ids + (which would cause runs of corresponding DAGs to be collected alongside the real test Dag Runs). + :param env_name: name of the environment variable which is being checked. + :param env_value: value of the variable. + """ + # TODO: allow every environment type to specify its own "forbidden" matching dag ids + + safe_dag_prefix = safe_dag_id(env_value) + + matching_dag_ids = [ + dag_id for dag_id in DAG_IDS_NOT_ALLOWED_TO_MATCH_PREFIX if dag_id.startswith(safe_dag_prefix) + ] + + if matching_dag_ids: + raise ValueError( + f"Value '{env_value}' of {env_name} is not allowed as {safe_dag_prefix} is a prefix " + f"for the following forbidden dag ids: {matching_dag_ids}" + ) + + +def safe_dag_id(dag_id: str) -> str: + """Remove characters that are invalid in dag id from provided string.""" + return re.sub("[^0-9a-zA-Z_]+", "_", dag_id) + + +def check_and_parse_time_delta(env_name: str, env_value: str) -> timedelta: + """ + Validate and parse time delta value. + + Check if value of provided environment variable is a parsable time expression + and returns timedelta object with duration. + + :param env_name: name of the environment variable which is being checked. + :param env_value: value of the variable. + + :return: a timedelta object with the duration specified in env_value string + :rtype: timedelta + + :raises: ValueError: if env_value could not be parsed + """ + parts = RE_TIME_DELTA.match(env_value) + + if parts is None: + raise ValueError( + f"Could not parse any time information from '{env_value}' value of {env_name}. " + f"Examples of valid strings: '8h', '2d8h5m20s', '2m4s'" + ) + + time_params = {name: float(param) for name, param in parts.groupdict().items() if param} + return timedelta(**time_params) + + +def check_schedule_interval(env_name: str, env_value: str) -> None: + """ + Validate schedule_interval value. + + Checks if value of schedule_interval is a parsable time expression + or within a specified set of non-parsable values. + + :param env_name: name of the environment variable which is being checked. + :param env_value: value of the variable. 
+ + :raises: ValueError: if env_value is neither a parsable time expression + nor one of allowed non-parsable values + """ + try: + check_and_parse_time_delta(env_name, env_value) + return + except ValueError as exception: + error_message = str(exception) + + check_allowed_values = get_check_allowed_values_function(ALLOWED_NON_REGEXP_SCHEDULE_INTERVALS) + + try: + check_allowed_values(env_name, env_value) + except ValueError: + log.error(error_message) + raise ValueError( + f"Value '{env_value}' of {env_name} is neither a parsable time expression " + f"nor one of the following: {ALLOWED_NON_REGEXP_SCHEDULE_INTERVALS}." + ) + + +def get_check_allowed_values_function( + values: tuple[str, ...], +) -> Callable[[str, str], None]: + """ + Return function that validates environment variable value. + + Returns a function which will check if value of provided environment variable + is within a specified set of values + + :param values: tuple of any length with allowed string values of environment variable + + :return: a function that checks if given environment variable's value is within the specified + set of values and raises ValueError otherwise + :rtype: Callable[[str, str], None] + """ + + def check_allowed_values(env_name: str, env_value: str) -> None: + """ + Check if value of provided environment variable is within a specified set of values. + + :param env_name: name of the environment variable which is being checked. + :param env_value: value of the variable. + + :raises: ValueError: if env_value is not within a specified set of values + """ + if env_value not in values: + raise ValueError( + f"{env_name} value must be one of the following: {values}. Received: '{env_value}'." + ) + + return check_allowed_values + + +def check_non_negative_float_convertibility(env_name: str, env_value: str) -> None: + """ + Validate if a string is parsable float. + + Checks if value of provided environment variable is convertible to non negative float value. + + :param env_name: name of the environment variable which is being checked. + :param env_value: value of the variable. + + :raises: ValueError: if env_value could not be converted to non negative float value + """ + try: + converted_value = float(env_value) + check_non_negative(converted_value) + except ValueError: + raise ValueError( + f"{env_name} value must be convertible to non negative float. Received: '{env_value}'." + ) + + +def check_non_negative(value: int | float) -> None: + """Check if provided value is not negative and raises ValueError otherwise.""" + if value < 0: + raise ValueError + + +def check_max_runs_and_schedule_interval_compatibility( + performance_dag_conf: dict[str, str], +) -> None: + """ + Validate max_runs value. + + Check if max_runs and schedule_interval values create a valid combination. 
+ + :param performance_dag_conf: dict with environment variables as keys and their values as values + + :raises: ValueError: + if max_runs is specified when schedule_interval is not a duration time expression + if max_runs is not specified when schedule_interval is a duration time expression + if max_runs, schedule_interval and start_ago form a combination which causes end_date + to be in the future + """ + schedule_interval = get_performance_dag_environment_variable( + performance_dag_conf, "PERF_SCHEDULE_INTERVAL" + ) + max_runs = get_performance_dag_environment_variable(performance_dag_conf, "PERF_MAX_RUNS") + start_ago = get_performance_dag_environment_variable(performance_dag_conf, "PERF_START_AGO") + + if schedule_interval == "@once": + if max_runs is not None: + raise ValueError( + "PERF_MAX_RUNS is allowed only if PERF_SCHEDULE_INTERVAL is provided as a time expression." + ) + # if dags are set to be scheduled once, we do not need to check end_date + return + + if max_runs is None: + raise ValueError( + "PERF_MAX_RUNS must be specified if PERF_SCHEDULE_INTERVAL is provided as a time expression." + ) + + max_runs = int(max_runs) + + # make sure that the end_date does not occur in future + current_date = datetime.now() + + start_date = current_date - check_and_parse_time_delta("PERF_START_AGO", start_ago) + + end_date = start_date + ( + check_and_parse_time_delta("PERF_SCHEDULE_INTERVAL", schedule_interval) * (max_runs - 1) + ) + + if current_date < end_date: + raise ValueError( + f"PERF_START_AGO ({start_ago}), " + f"PERF_SCHEDULE_INTERVAL ({schedule_interval}) " + f"and PERF_MAX_RUNS ({max_runs}) " + f"must be specified in such a way that end_date does not occur in the future " + f"(end_date with provided values: {end_date})." + ) + + +def check_valid_json(env_name: str, env_value: str) -> None: + """ + Validate json string. + + Check if value of provided environment variable is a valid json. + + :param env_name: name of the environment variable which is being checked. + :param env_value: value of the variable. + """ + try: + json.loads(env_value) + except json.decoder.JSONDecodeError: + raise ValueError(f"Value '{env_value}' of {env_name} cannot be json decoded.") + + +@contextmanager +def generate_copies_of_performance_dag( + performance_dag_path: str, performance_dag_conf: dict[str, str] +) -> tuple[str, list[str]]: + """ + Create context manager that creates copies of DAG. + + Contextmanager that creates copies of performance DAG inside temporary directory using the + dag prefix env variable as a base for filenames. + + :param performance_dag_path: path to the performance DAG that should be copied. + :param performance_dag_conf: dict with environment variables as keys and their values as values. 
+ + :yields: a pair consisting of path to the temporary directory + and a list with paths to copies of performance DAG + :type: Tuple[str, List[str]] + """ + dag_files_count = int( + get_performance_dag_environment_variable(performance_dag_conf, "PERF_DAG_FILES_COUNT") + ) + + safe_dag_prefix = get_dag_prefix(performance_dag_conf) + + with tempfile.TemporaryDirectory() as temp_dir: + performance_dag_copies = [] + + for i in range(1, dag_files_count + 1): + destination_filename = f"{safe_dag_prefix}_{i}.py" + destination_path = os.path.join(temp_dir, destination_filename) + + copyfile(performance_dag_path, destination_path) + performance_dag_copies.append(destination_path) + + yield temp_dir, performance_dag_copies + + +def get_dag_prefix(performance_dag_conf: dict[str, str]) -> str: + """ + Return DAG prefix. + + Returns prefix that will be assigned to DAGs created with given performance DAG configuration. + + :param performance_dag_conf: dict with environment variables as keys and their values as values + + :return: final form of prefix after substituting inappropriate characters + :rtype: str + """ + dag_prefix = get_performance_dag_environment_variable(performance_dag_conf, "PERF_DAG_PREFIX") + + safe_dag_prefix = safe_dag_id(dag_prefix) + + return safe_dag_prefix + + +def get_dags_count(performance_dag_conf: dict[str, str]) -> int: + """ + Return the number of test DAGs based on given performance DAG configuration. + + :param performance_dag_conf: dict with environment variables as keys and their values as values + + :return: number of test DAGs + :rtype: int + """ + dag_files_count = int( + get_performance_dag_environment_variable(performance_dag_conf, "PERF_DAG_FILES_COUNT") + ) + + dags_per_dag_file = int(get_performance_dag_environment_variable(performance_dag_conf, "PERF_DAGS_COUNT")) + + return dag_files_count * dags_per_dag_file + + +def calculate_number_of_dag_runs(performance_dag_conf: dict[str, str]) -> int: + """ + Calculate how many Dag Runs will be created with given performance DAG configuration. + + :param performance_dag_conf: dict with environment variables as keys and their values as values + + :return: total number of Dag Runs + :rtype: int + """ + max_runs = get_performance_dag_environment_variable(performance_dag_conf, "PERF_MAX_RUNS") + + total_dags_count = get_dags_count(performance_dag_conf) + + # if PERF_MAX_RUNS is missing from the configuration, + # it means that PERF_SCHEDULE_INTERVAL must be '@once' + if max_runs is None: + return total_dags_count + + return int(max_runs) * total_dags_count + + +def prepare_performance_dag_columns( + performance_dag_conf: dict[str, str], +) -> OrderedDict: + """ + Prepare dict containing DAG env variables. + + Prepare an OrderedDict containing chosen performance dag environment variables + that will serve as columns for the results dataframe. 
+ + :param performance_dag_conf: dict with environment variables as keys and their values as values + + :return: a dict with a subset of environment variables + in order in which they should appear in the results dataframe + :rtype: OrderedDict + """ + max_runs = get_performance_dag_environment_variable(performance_dag_conf, "PERF_MAX_RUNS") + + # TODO: if PERF_MAX_RUNS is missing from configuration, then PERF_SCHEDULE_INTERVAL must + # be '@once'; this is an equivalent of PERF_MAX_RUNS being '1', which will be the default value + # once PERF_START_AGO and PERF_SCHEDULE_INTERVAL are removed + + # TODO: we should not ban PERF_SCHEDULE_INTERVAL completely because we will make it impossible + # to run time-based tests (where you run dags constantly for 1h for example). I think we should + # allow setting of only one of them. + # If PERF_MAX_RUNS is set, then PERF_SCHEDULE_INTERVAL should be ignored - default value of 1h + # should be used combined with PERF_START_AGO so that expected number of runs can be created immediately + # If PERF_SCHEDULE_INTERVAL is set and PERF_MAX_RUNS is not, then PERF_START_AGO should be set + # to current date so that dag runs start creating now instead of creating multiple runs from the + # past - but it will be rather hard taking into account time of environment creation. Wasn't + # there some dag option to NOT create past runs? -> catchup + # ALSO either PERF_MAX_RUNS or PERF_SCHEDULE_INTERVAL OR both should be included in results file + if max_runs is None: + max_runs = 1 + else: + max_runs = int(max_runs) + + performance_dag_columns = OrderedDict( + [ + ( + "PERF_DAG_FILES_COUNT", + int(get_performance_dag_environment_variable(performance_dag_conf, "PERF_DAG_FILES_COUNT")), + ), + ( + "PERF_DAGS_COUNT", + int(get_performance_dag_environment_variable(performance_dag_conf, "PERF_DAGS_COUNT")), + ), + ( + "PERF_TASKS_COUNT", + int(get_performance_dag_environment_variable(performance_dag_conf, "PERF_TASKS_COUNT")), + ), + ("PERF_MAX_RUNS", max_runs), + ( + "PERF_SCHEDULE_INTERVAL", + get_performance_dag_environment_variable(performance_dag_conf, "PERF_SCHEDULE_INTERVAL"), + ), + ( + "PERF_SHAPE", + get_performance_dag_environment_variable(performance_dag_conf, "PERF_SHAPE"), + ), + ( + "PERF_SLEEP_TIME", + float(get_performance_dag_environment_variable(performance_dag_conf, "PERF_SLEEP_TIME")), + ), + ( + "PERF_OPERATOR_TYPE", + get_performance_dag_environment_variable(performance_dag_conf, "PERF_OPERATOR_TYPE"), + ), + ] + ) + + add_performance_dag_configuration_type(performance_dag_columns) + + return performance_dag_columns + + +def get_performance_dag_environment_variable(performance_dag_conf: dict[str, str], env_name: str) -> str: + """ + Get env variable value. + + Returns value of environment variable with given env_name based on provided `performance_dag_conf`. + + :param performance_dag_conf: dict with environment variables as keys and their values as values + :param env_name: name of the environment variable value of which should be returned. 
+ + :return: value of environment variable taken from performance_dag_conf or its default value, if it + was not present in the dictionary (if applicable) + :rtype: str + + :raises: ValueError: + if env_name is a mandatory environment variable but it is missing from performance_dag_conf + if env_name is not a valid name of an performance dag environment variable + """ + if env_name in MANDATORY_performance_DAG_VARIABLES: + if env_name not in performance_dag_conf: + raise ValueError( + f"Mandatory environment variable '{env_name}' " + f"is missing from performance dag configuration." + ) + return performance_dag_conf[env_name] + + if env_name not in performance_DAG_VARIABLES_DEFAULT_VALUES: + raise ValueError( + f"Provided environment variable '{env_name}' is not a valid performance dag" + f"configuration variable." + ) + + return performance_dag_conf.get(env_name, performance_DAG_VARIABLES_DEFAULT_VALUES[env_name]) + + +def add_performance_dag_configuration_type( + performance_dag_columns: OrderedDict, +) -> None: + """ + Add a key with type of given performance dag configuration to the columns dict. + + :param performance_dag_columns: a dict with columns containing performance dag configuration + """ + performance_dag_configuration_type = "__".join( + [ + f"{performance_dag_columns['PERF_SHAPE']}", + f"{performance_dag_columns['PERF_DAG_FILES_COUNT']}_dag_files", + f"{performance_dag_columns['PERF_DAGS_COUNT']}_dags", + f"{performance_dag_columns['PERF_TASKS_COUNT']}_tasks", + f"{performance_dag_columns['PERF_MAX_RUNS']}_dag_runs", + f"{performance_dag_columns['PERF_SLEEP_TIME']}_sleep", + f"{performance_dag_columns['PERF_OPERATOR_TYPE']}_operator", + ] + ) + + performance_dag_columns.update({"performance_dag_configuration_type": performance_dag_configuration_type}) + + # move the type key to the beginning of dict + performance_dag_columns.move_to_end("performance_dag_configuration_type", last=False) + + +def parse_time_delta(time_str): + # type: (str) -> datetime.timedelta + """ + Parse a time string e.g. (2h13m) into a timedelta object. + + :param time_str: A string identifying a duration. (eg. 2h13m) + :return datetime.timedelta: A datetime.timedelta object or "@once" + """ + parts = RE_TIME_DELTA.match(time_str) + + if parts is None: + raise ValueError( + f"Could not parse any time information from '{time_str}'. " + "Examples of valid strings: '8h', '2d8h5m20s', '2m4s'" + ) + + time_params = {name: float(param) for name, param in parts.groupdict().items() if param} + return timedelta(**time_params) # type: ignore + + +def parse_start_date(date, start_ago): + """ + Parse date or relative distance to current time. + + Returns the start date for the performance DAGs and string to be used as part of their ids. + + :return Tuple[datetime.datetime, str]: A tuple of datetime.datetime object to be used + as a start_date and a string that should be used as part of the dag_id. + """ + if date: + start_date = datetime.strptime(date, "%Y-%m-%d %H:%M:%S.%f") + dag_id_component = str(int(start_date.timestamp())) + else: + start_date = datetime.now() - parse_time_delta(start_ago) + dag_id_component = start_ago + return start_date, dag_id_component + + +def parse_schedule_interval(time_str): + # type: (str) -> datetime.timedelta + """ + Parse a schedule interval string e.g. (2h13m) or "@once". + + :param time_str: A string identifying a schedule interval. (eg. 
2h13m, None, @once) + :return datetime.timedelta: A datetime.timedelta object or "@once" or None + """ + if time_str == "None": + return None + + if time_str == "@once": + return "@once" + + return parse_time_delta(time_str) diff --git a/performance/tests/test_performance_dag.py b/performance/tests/test_performance_dag.py new file mode 100644 index 0000000000000..e66ceac5a0e26 --- /dev/null +++ b/performance/tests/test_performance_dag.py @@ -0,0 +1,161 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# Note: Any AirflowException raised is expected to cause the TaskInstance +# to be marked in an ERROR state + +from __future__ import annotations + +import json +import os + +import pytest +import re2 as re +from airflow.configuration import conf +from airflow.models import DagBag +from airflow.utils.trigger_rule import TriggerRule + +DAGS_DIR = os.path.join(os.path.dirname(__file__), "../src/performance_dags/performance_dag") + + +def setup_dag( + dag_count="1", + task_count="10", + start_date="", + start_ago="1h", + schedule_interval_env="@once", + dag_shape="no_structure", + sleep_time="0", + operator_type="bash", + start_paused="1", + task_trigger_rule=TriggerRule.ALL_SUCCESS, + **extra_args, +): + os.environ["PERF_DAGS_COUNT"] = dag_count + os.environ["PERF_TASKS_COUNT"] = task_count + os.environ["PERF_START_DATE"] = start_date + os.environ["PERF_START_AGO"] = start_ago + os.environ["PERF_SCHEDULE_INTERVAL"] = schedule_interval_env + os.environ["PERF_SHAPE"] = dag_shape + os.environ["PERF_SLEEP_TIME"] = sleep_time + os.environ["PERF_OPERATOR_TYPE"] = operator_type + os.environ["PERF_START_PAUSED"] = start_paused + os.environ["PERF_TASKS_TRIGGER_RULE"] = task_trigger_rule + os.environ["PERF_OPERATOR_EXTRA_KWARGS"] = json.dumps(extra_args) + + +def get_top_level_tasks(dag): + result = [] + for task in dag.tasks: + if not task.upstream_list: + result.append(task) + return result + + +def get_leaf_tasks(dag): + result = [] + for task in dag.tasks: + if not task.downstream_list: + result.append(task) + return result + + +# Test fixture +@pytest.fixture(scope="session", autouse=True) +def airflow_config(): + """ + Update airflow config for the test. + + It sets the following configuration values: + - core.unit_test_mode: True + - lineage.backend: "" + + Returns: + AirflowConfigParser: The Airflow configuration object. 
+ """ + conf.set("lineage", "backend", "") + return conf + + +def get_dags(dag_count=1, task_count=10, operator_type="bash", dag_shape="no_structure"): + """Generate a tuple of dag_id, in the DagBag.""" + setup_dag( + task_count=str(task_count), + dag_count=str(dag_count), + operator_type=operator_type, + dag_shape=dag_shape, + ) + dag_bag = DagBag(DAGS_DIR, include_examples=False) + + def strip_path_prefix(path): + return os.path.relpath(path, DAGS_DIR) + + return [(k, v, strip_path_prefix(v.fileloc)) for k, v in dag_bag.dags.items()] + + +def get_import_errors(): + """Generate a tuple for import errors in the dag bag.""" + dag_bag = DagBag(DAGS_DIR, include_examples=False) + + def strip_path_prefix(path): + return os.path.relpath(path, DAGS_DIR) + + # prepend "(None,None)" to ensure that a test object is always created even if it's a no op. + return [(None, None)] + [(strip_path_prefix(k), v.strip()) for k, v in dag_bag.import_errors.items()] + + +@pytest.mark.parametrize("rel_path,rv", get_import_errors(), ids=[x[0] for x in get_import_errors()]) +def test_file_imports(rel_path, rv): + """Test for import errors on a file.""" + if rel_path and rv: + pytest.fail(f"{rel_path} failed to import with message \n {rv}") + + +@pytest.mark.parametrize("dag_count,task_count", [(1, 1), (1, 10), (10, 10), (10, 100)]) +def test_performance_dag(dag_count, task_count): + dags = get_dags(dag_count=dag_count, task_count=task_count) + assert len(dags) == dag_count + ids = [x[0] for x in dags] + pattern = f"performance_dag__SHAPE_no_structure__DAGS_COUNT_\\d+_of_{dag_count}__TASKS_COUNT_{task_count}__START_DATE_1h__SCHEDULE_INTERVAL_once" + for id in ids: + assert re.search(pattern, id) + for dag in dags: + performance_dag = dag[1] + assert len(performance_dag.tasks) == task_count, f"DAG has no {task_count} tasks" + for task in performance_dag.tasks: + t_rule = task.trigger_rule + assert t_rule == "all_success", f"{task} in DAG has the trigger rule {t_rule}" + assert task.operator_name == "BashOperator", f"{task} should be based on bash operator" + + +def test_performance_dag_shape_binary_tree(): + def assert_two_downstream(task): + assert len(task.downstream_list) <= 2 + for downstream_task in task.downstream_list: + assert_two_downstream(downstream_task) + + dags = get_dags(task_count=100, dag_shape="binary_tree") + id, dag, _ = dags[0] + assert ( + id + == "performance_dag__SHAPE_binary_tree__DAGS_COUNT_1_of_1__TASKS_COUNT_100__START_DATE_1h__SCHEDULE_INTERVAL_once" + ) + assert len(dag.tasks) == 100 + top_level_tasks = get_top_level_tasks(dag) + assert len(top_level_tasks) == 1 + for task in top_level_tasks: + assert_two_downstream(task) diff --git a/performance/tests/test_performance_dag_utils.py b/performance/tests/test_performance_dag_utils.py new file mode 100644 index 0000000000000..52e4ee984d2a1 --- /dev/null +++ b/performance/tests/test_performance_dag_utils.py @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# Note: Any AirflowException raised is expected to cause the TaskInstance +# to be marked in an ERROR state +from __future__ import annotations + +from datetime import datetime, timedelta + +from performance_dags.performance_dag.performance_dag_utils import ( + parse_start_date, + parse_time_delta, +) + + +def test_parse_time_delta(): + assert parse_time_delta("1h") == timedelta(hours=1) + + +def test_parse_start_date_from_date(): + now = datetime.now() + formatted_time = now.strftime("%Y-%m-%d %H:%M:%S.%f") + date, dag_id_component = parse_start_date(formatted_time, "1h") + assert date == now + assert dag_id_component == str(int(now.timestamp())) + + +def test_parse_start_date_from_offset(): + now = datetime.now() + one_hour_ago = now - timedelta(hours=1) + date, dag_id_component = parse_start_date(None, "1h") + one_hour_ago_timestamp = int(one_hour_ago.timestamp()) + one_hour_ago_result = int(date.timestamp()) + assert abs(one_hour_ago_result - one_hour_ago_timestamp) < 1000 + assert dag_id_component == "1h" From b65a408ea5b9072bae59638af5838e4afb35cb9e Mon Sep 17 00:00:00 2001 From: Brent Bovenzi Date: Tue, 10 Sep 2024 19:49:15 -0400 Subject: [PATCH 019/349] Add Ryan Hamilton as ui codeowner (#42150) --- .github/CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 6740439d54e3b..befcd854edca5 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -30,7 +30,7 @@ /airflow/www/ @ryanahamilton @ashb @bbovenzi @pierrejeambrun # UI -/airflow/ui/ @bbovenzi @pierrejeambrun +/airflow/ui/ @bbovenzi @pierrejeambrun @ryanahamilton # Security/Permissions /airflow/api_connexion/security.py @jhtimmins From bc6d19dd952dd13bda1208254667c45dff2e79cb Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Wed, 11 Sep 2024 06:49:38 -0700 Subject: [PATCH 020/349] Actually move saml to amazon provider (mistakenly added in papermill) (#42148) Follow up after #42137 -> saml was added mistakenly to papermill, not amazon :( --- airflow/providers/amazon/provider.yaml | 1 + airflow/providers/papermill/provider.yaml | 1 - generated/provider_dependencies.json | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 54d36ca420635..b6854d666e0ae 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -108,6 +108,7 @@ dependencies: - asgiref>=2.3.0 - PyAthena>=3.0.10 - jmespath>=0.7.0 + - python3-saml>=1.16.0 additional-extras: - name: pandas diff --git a/airflow/providers/papermill/provider.yaml b/airflow/providers/papermill/provider.yaml index ff738b2ec73d1..afd273a69e1b2 100644 --- a/airflow/providers/papermill/provider.yaml +++ b/airflow/providers/papermill/provider.yaml @@ -57,7 +57,6 @@ dependencies: - ipykernel - pandas>=2.1.2,<2.2;python_version>="3.9" - pandas>=1.5.3,<2.2;python_version<"3.9" - - python3-saml>=1.16.0 integrations: - integration-name: Papermill diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 3ea4df282d780..f70a493d1f888 100644 --- 
a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -36,6 +36,7 @@ "inflection>=0.5.1", "jmespath>=0.7.0", "jsonpath_ng>=1.5.3", + "python3-saml>=1.16.0", "redshift_connector>=2.0.918", "sqlalchemy_redshift>=0.8.6", "watchtower>=3.0.0,!=3.3.0,<4" @@ -1014,7 +1015,6 @@ "pandas>=1.5.3,<2.2;python_version<\"3.9\"", "pandas>=2.1.2,<2.2;python_version>=\"3.9\"", "papermill[all]>=2.6.0", - "python3-saml>=1.16.0", "scrapbook[all]" ], "devel-deps": [], From 7d853a1d3d51bcfbe98149f27f5d9cb78096f7a5 Mon Sep 17 00:00:00 2001 From: Jens Scheffler <95105677+jscheffl@users.noreply.github.com> Date: Wed, 11 Sep 2024 18:48:39 +0200 Subject: [PATCH 021/349] Pin airbyte-api to 0.51.0 (#42154) (#42155) --- airflow/providers/airbyte/provider.yaml | 2 +- generated/provider_dependencies.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/providers/airbyte/provider.yaml b/airflow/providers/airbyte/provider.yaml index ce1e84cbb4651..e421b1c66058c 100644 --- a/airflow/providers/airbyte/provider.yaml +++ b/airflow/providers/airbyte/provider.yaml @@ -51,7 +51,7 @@ versions: dependencies: - apache-airflow>=2.8.0 - - airbyte-api>=0.51.0 + - airbyte-api==0.51.0 # v0.52.0 breaks hooks, see https://github.com/apache/airflow/issues/42154 integrations: - integration-name: Airbyte diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index f70a493d1f888..1c0ad3f1cae66 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -1,7 +1,7 @@ { "airbyte": { "deps": [ - "airbyte-api>=0.51.0", + "airbyte-api==0.51.0", "apache-airflow>=2.8.0" ], "devel-deps": [], From f3a6cc8e8d368d291d384af2cc08830fd218840a Mon Sep 17 00:00:00 2001 From: VladaZakharova Date: Wed, 11 Sep 2024 18:49:02 +0200 Subject: [PATCH 022/349] Fix vertex AI system tests (#42153) --- .../example_vertex_ai_batch_prediction_job.py | 2 +- .../example_vertex_ai_custom_container.py | 2 +- .../example_vertex_ai_model_service.py | 7 ++- .../example_vertex_ai_pipeline_job.py | 50 ++++++++++++------- 4 files changed, 38 insertions(+), 23 deletions(-) diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_batch_prediction_job.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_batch_prediction_job.py index 3f2dfc60ec0c0..38198b5526874 100644 --- a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_batch_prediction_job.py +++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_batch_prediction_job.py @@ -87,7 +87,7 @@ "county": "categorical", } -BIGQUERY_SOURCE = f"bq://{PROJECT_ID}.test_iowa_liquor_sales_forecasting_us.2021_sales_predict" +BIGQUERY_SOURCE = "bq://airflow-system-tests-resources.vertex_ai_training_dataset.data" GCS_DESTINATION_PREFIX = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/output" MODEL_PARAMETERS: dict[str, str] = {} diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_container.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_container.py index f039877f7134a..b8d01f8d71493 100644 --- a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_container.py +++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_container.py @@ -69,7 +69,7 @@ def TABULAR_DATASET(bucket_name): CONTAINER_URI = "gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest" -CUSTOM_CONTAINER_URI = f"us-central1-docker.pkg.dev/{PROJECT_ID}/system-tests/housing:latest" +CUSTOM_CONTAINER_URI = 
"us-central1-docker.pkg.dev/airflow-system-tests-resources/system-tests/housing" MODEL_SERVING_CONTAINER_URI = "gcr.io/cloud-aiplatform/prediction/tf2-cpu.2-2:latest" REPLICA_COUNT = 1 MACHINE_TYPE = "n1-standard-4" diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_model_service.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_model_service.py index b80eaabdd58db..e6ad1e710e4c3 100644 --- a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_model_service.py +++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_model_service.py @@ -87,12 +87,11 @@ CONTAINER_URI = "gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest" -# VERTEX_AI_LOCAL_TRAINING_SCRIPT_PATH should be set for Airflow which is running on distributed system. +# LOCAL_TRAINING_SCRIPT_PATH should be set for Airflow which is running on distributed system. # For example in Composer the correct path is `gcs/data/california_housing_training_script.py`. # Because `gcs/data/` is shared folder for Airflow's workers. -LOCAL_TRAINING_SCRIPT_PATH = os.environ.get( - "VERTEX_AI_LOCAL_TRAINING_SCRIPT_PATH", "california_housing_training_script.py" -) +IS_COMPOSER = bool(os.environ.get("COMPOSER_ENVIRONMENT", "")) +LOCAL_TRAINING_SCRIPT_PATH = "gcs/data/california_housing_training_script.py" if IS_COMPOSER else "" MODEL_OUTPUT_CONFIG = { "artifact_destination": { diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_pipeline_job.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_pipeline_job.py index 130effbb86c39..7dd4aa84fe41e 100644 --- a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_pipeline_job.py +++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_pipeline_job.py @@ -26,6 +26,10 @@ import os from datetime import datetime +from google.cloud.aiplatform import schema +from google.protobuf.json_format import ParseDict +from google.protobuf.struct_pb2 import Value + from airflow.models.dag import DAG from airflow.providers.google.cloud.operators.gcs import ( GCSCreateBucketOperator, @@ -34,6 +38,7 @@ GCSListObjectsOperator, GCSSynchronizeBucketsOperator, ) +from airflow.providers.google.cloud.operators.vertex_ai.dataset import CreateDatasetOperator from airflow.providers.google.cloud.operators.vertex_ai.pipeline_job import ( DeletePipelineJobOperator, GetPipelineJobOperator, @@ -50,22 +55,25 @@ RESOURCE_DATA_BUCKET = "airflow-system-tests-resources" DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-") -TEMPLATE_PATH = "https://us-kfp.pkg.dev/ml-pipeline/google-cloud-registry/automl-tabular/sha256:85e4218fc6604ee82353c9d2ebba20289eb1b71930798c0bb8ce32d8a10de146" +TEMPLATE_PATH = "https://us-kfp.pkg.dev/ml-pipeline/google-cloud-registry/get-vertex-dataset/sha256:f4eb4a2b0aab482c487c1cd62b3c735baaf914be8fa8c4687c06077c1d815a5d" OUTPUT_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}" PARAMETER_VALUES = { - "train_budget_milli_node_hours": 2000, - "optimization_objective": "minimize-log-loss", - "project": PROJECT_ID, - "location": REGION, - "root_dir": OUTPUT_BUCKET, - "target_column": "Adopted", - "training_fraction": 0.8, - "validation_fraction": 0.1, - "test_fraction": 0.1, - "prediction_type": "classification", - "data_source_csv_filenames": f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/vertex-ai/tabular-dataset.csv", - "transformations": f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/vertex-ai/column_transformations.json", + "dataset_resource_name": 
f"projects/{PROJECT_ID}/locations/{REGION}/datasets/tabular-dataset-{ENV_ID}", +} + +DATA_SAMPLE_GCS_OBJECT_NAME = "vertex-ai/california_housing_train.csv" +TABULAR_DATASET = { + "display_name": f"tabular-dataset-{ENV_ID}", + "metadata_schema_uri": schema.dataset.metadata.tabular, + "metadata": ParseDict( + { + "input_config": { + "gcs_source": {"uri": [f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/{DATA_SAMPLE_GCS_OBJECT_NAME}"]} + } + }, + Value(), + ), } @@ -83,15 +91,22 @@ location=REGION, ) - move_pipeline_files = GCSSynchronizeBucketsOperator( - task_id="move_files_to_bucket", + move_dataset_files = GCSSynchronizeBucketsOperator( + task_id="move_dataset_files_to_bucket", source_bucket=RESOURCE_DATA_BUCKET, - source_object="vertex-ai/pipeline", + source_object="vertex-ai/california-housing-data", destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME, destination_object="vertex-ai", recursive=True, ) + create_dataset = CreateDatasetOperator( + task_id="tabular_dataset", + dataset=TABULAR_DATASET, + region=REGION, + project_id=PROJECT_ID, + ) + # [START how_to_cloud_vertex_ai_run_pipeline_job_operator] run_pipeline_job = RunPipelineJobOperator( task_id="run_pipeline_job", @@ -147,7 +162,8 @@ ( # TEST SETUP create_bucket - >> move_pipeline_files + >> move_dataset_files + >> create_dataset # TEST BODY >> run_pipeline_job >> get_pipeline_job From 4aadf477ffacd74e74327df3abfe2082ccdfa968 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 11 Sep 2024 09:49:16 -0700 Subject: [PATCH 023/349] Bump apache-airflow from 2.10.0 to 2.10.1 in /performance (#42149) Bumps [apache-airflow](https://github.com/apache/airflow) from 2.10.0 to 2.10.1. - [Release notes](https://github.com/apache/airflow/releases) - [Changelog](https://github.com/apache/airflow/blob/main/RELEASE_NOTES.rst) - [Commits](https://github.com/apache/airflow/compare/2.10.0...2.10.1) --- updated-dependencies: - dependency-name: apache-airflow dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- performance/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/performance/requirements.txt b/performance/requirements.txt index 81794e71d82cf..89bc2226d225b 100644 --- a/performance/requirements.txt +++ b/performance/requirements.txt @@ -1,2 +1,2 @@ -apache-airflow==2.10.0 +apache-airflow==2.10.1 openlineage-airflow==1.20.5 From ca9bb810c57de02f624e5a5e5c94c51b99b3c867 Mon Sep 17 00:00:00 2001 From: max <42827971+moiseenkov@users.noreply.github.com> Date: Wed, 11 Sep 2024 16:52:19 +0000 Subject: [PATCH 024/349] Exclude partition from BigQuery table name (#42130) Please enter the commit message for your changes. Lines starting --- .../providers/google/cloud/hooks/bigquery.py | 13 +++- .../google/cloud/hooks/test_bigquery.py | 71 +++++++++++++++++-- 2 files changed, 76 insertions(+), 8 deletions(-) diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py index 8e330b17d5cab..b1aed15c458ca 100644 --- a/airflow/providers/google/cloud/hooks/bigquery.py +++ b/airflow/providers/google/cloud/hooks/bigquery.py @@ -2415,10 +2415,13 @@ def var_print(var_name): table_id = cmpt[1] else: raise ValueError( - f"{var_print(var_name)} Expect format of (.
<table>, " + f"{var_print(var_name)}Expect format of (<project.|<project:)<dataset>.<table>
, " f"got {table_input}" ) + # Exclude partition from the table name + table_id = table_id.split("$")[0] + if project_id is None: if var_name is not None: self.log.info( @@ -3282,6 +3285,11 @@ def _escape(s: str) -> str: return e +@deprecated( + planned_removal_date="April 01, 2025", + use_instead="airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.split_tablename", + category=AirflowProviderDeprecationWarning, +) def split_tablename( table_input: str, default_project_id: str, var_name: str | None = None ) -> tuple[str, str, str]: @@ -3330,6 +3338,9 @@ def var_print(var_name): f"{var_print(var_name)}Expect format of (.
, got {table_input}" ) + # Exclude partition from the table name + table_id = table_id.split("$")[0] + if project_id is None: if var_name is not None: log.info( diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py index fcee80d224e96..81db43c0f5310 100644 --- a/tests/providers/google/cloud/hooks/test_bigquery.py +++ b/tests/providers/google/cloud/hooks/test_bigquery.py @@ -1030,12 +1030,11 @@ def test_query_results(self, _, selected_fields, result): == result ) - -class TestBigQueryTableSplitter: - def test_internal_need_default_project(self): + def test_split_tablename_internal_need_default_project(self): with pytest.raises(ValueError, match="INTERNAL: No default project is specified"): - split_tablename("dataset.table", None) + self.hook.split_tablename("dataset.table", None) + @pytest.mark.parametrize("partition", ["$partition", ""]) @pytest.mark.parametrize( "project_expected, dataset_expected, table_expected, table_input", [ @@ -1046,9 +1045,11 @@ def test_internal_need_default_project(self): ("alt1:alt", "dataset", "table", "alt1:alt:dataset.table"), ], ) - def test_split_tablename(self, project_expected, dataset_expected, table_expected, table_input): + def test_split_tablename( + self, project_expected, dataset_expected, table_expected, table_input, partition + ): default_project_id = "project" - project, dataset, table = split_tablename(table_input, default_project_id) + project, dataset, table = self.hook.split_tablename(table_input + partition, default_project_id) assert project_expected == project assert dataset_expected == dataset assert table_expected == table @@ -1080,9 +1081,65 @@ def test_split_tablename(self, project_expected, dataset_expected, table_expecte ), ], ) - def test_invalid_syntax(self, table_input, var_name, exception_message): + def test_split_tablename_invalid_syntax(self, table_input, var_name, exception_message): default_project_id = "project" with pytest.raises(ValueError, match=exception_message.format(table_input)): + self.hook.split_tablename(table_input, default_project_id, var_name) + + +class TestBigQueryTableSplitter: + def test_internal_need_default_project(self): + with pytest.raises(AirflowProviderDeprecationWarning): + split_tablename("dataset.table", None) + + @pytest.mark.parametrize("partition", ["$partition", ""]) + @pytest.mark.parametrize( + "project_expected, dataset_expected, table_expected, table_input", + [ + ("project", "dataset", "table", "dataset.table"), + ("alternative", "dataset", "table", "alternative:dataset.table"), + ("alternative", "dataset", "table", "alternative.dataset.table"), + ("alt1:alt", "dataset", "table", "alt1:alt.dataset.table"), + ("alt1:alt", "dataset", "table", "alt1:alt:dataset.table"), + ], + ) + def test_split_tablename( + self, project_expected, dataset_expected, table_expected, table_input, partition + ): + default_project_id = "project" + with pytest.raises(AirflowProviderDeprecationWarning): + split_tablename(table_input + partition, default_project_id) + + @pytest.mark.parametrize( + "table_input, var_name, exception_message", + [ + ("alt1:alt2:alt3:dataset.table", None, "Use either : or . to specify project got {}"), + ( + "alt1.alt.dataset.table", + None, + r"Expect format of \(\.
<table>, got {}", ), ( "alt1:alt2:alt.dataset.table", "var_x", "Format exception for var_x: Use either : or . to specify project got {}", ), ( "alt1:alt2:alt:dataset.table", "var_x", "Format exception for var_x: Use either : or . to specify project got {}", ), ( "alt1.alt.dataset.table", "var_x", r"Format exception for var_x: Expect format of " r"\(<project\.\|<project:\)<dataset>\.<table>
, got {}", + ), + ], + ) + def test_invalid_syntax(self, table_input, var_name, exception_message): + default_project_id = "project" + with pytest.raises(AirflowProviderDeprecationWarning): split_tablename(table_input, default_project_id, var_name) From 30998f9f8e3280c4e0cbb84081faf60fb27b81c3 Mon Sep 17 00:00:00 2001 From: Yuan Li <82419490+yuan-glu@users.noreply.github.com> Date: Wed, 11 Sep 2024 11:11:16 -0700 Subject: [PATCH 025/349] Added EA to NTHEWILD.md (#42161) --- INTHEWILD.md | 1 + 1 file changed, 1 insertion(+) diff --git a/INTHEWILD.md b/INTHEWILD.md index aacf82fac2049..9dab2090c8a93 100644 --- a/INTHEWILD.md +++ b/INTHEWILD.md @@ -188,6 +188,7 @@ Currently, **officially** using Airflow: 1. [Easy Taxi](http://www.easytaxi.com/) [[@caique-lima](https://github.com/caique-lima) & [@diraol](https://github.com/diraol)] 1. [EBANX](https://www.ebanx.com/) [[@diogodilcl](https://github.com/diogodilcl) & [@estevammr](https://github.com/estevammr) & [@filipe-banzoli](https://github.com/filipe-banzoli) & [@lara-clink](https://github.com/lara-clink) & [@Lucasdsvenancio](https://github.com/Lucasdsvenancio) & [@mariotaddeucci](https://github.com/mariotaddeucci) & [@nadiapetramont](https://github.com/nadiapetramont) & [@nathangngencissk](https://github.com/nathangngencissk) & [@patrickjuan](https://github.com/patrickjuan) & [@raafaadg](https://github.com/raafaadg) & [@samebanx](https://github.com/samebanx) & [@thiagoschonrock](https://github.com/thiagoschonrock) & [@whrocha](https://github.com/whrocha)] 1. [Elai Data](https://www.elaidata.com/) [[@lgov](https://github.com/lgov)] +1. [Electronic Arts](https://www.ea.com/) [[@yuan-glu](https://github.com/yuan-glu)] 1. [EllisDon](http://www.ellisdon.com/) [[@d2kalra](https://github.com/d2kalra) & [@zbasama](https://github.com/zbasama)] 1. [Endesa](https://www.endesa.com) [[@drexpp](https://github.com/drexpp)] 1. [ENECHANGE Ltd.](https://enechange.co.jp/) [[@enechange](https://github.com/enechange)] From efac09442eab8a629b06d19212fd490a61f09c49 Mon Sep 17 00:00:00 2001 From: Ryan Hatter <25823361+RNHTTR@users.noreply.github.com> Date: Wed, 11 Sep 2024 11:17:44 -0700 Subject: [PATCH 026/349] document that running task instances will be marked as skipped when a dagrun times out (#41410) --- airflow/models/dag.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 95a2f8b6e3105..6545293ccf89c 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -430,8 +430,8 @@ class DAG(LoggingMixin): new active DAG runs :param max_consecutive_failed_dag_runs: (experimental) maximum number of consecutive failed DAG runs, beyond this the scheduler will disable the DAG - :param dagrun_timeout: specify how long a DagRun should be up before - timing out / failing, so that new DagRuns can be created. + :param dagrun_timeout: Specify the duration a DagRun should be allowed to run before it times out or + fails. Task instances that are running when a DagRun is timed out will be marked as skipped. :param sla_miss_callback: specify a function or list of functions to call when reporting SLA timeouts. 
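For illustration only, and not taken from the patch above: a minimal, hypothetical sketch of the documented behaviour, using a made-up dag_id and schedule. If a run of this DAG is still going after two hours, the run times out and any task instances still running at that point are marked as skipped.

from datetime import datetime, timedelta

from airflow.models.dag import DAG

dag = DAG(
    dag_id="example_with_run_timeout",  # hypothetical name
    schedule="@daily",
    start_date=datetime(2024, 1, 1),
    dagrun_timeout=timedelta(hours=2),  # bound the duration of each DagRun
)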
See :ref:`sla_miss_callback` for more information about the function signature and parameters that are From 175b591b6e2ce816b4b777cb9e63e16bac4ef7ce Mon Sep 17 00:00:00 2001 From: Elad Kalif <45845474+eladkal@users.noreply.github.com> Date: Wed, 11 Sep 2024 11:19:24 -0700 Subject: [PATCH 027/349] Add Airflow 3 development readme (#41457) * Add Airflow 3 development readme * fix typo * Update dev/README_AIRFLOW3_DEV.md Co-authored-by: Jarek Potiuk * restructure doc * fix static checks * updates * fixes * fix spell checks * clarify cherry picking * add milestones * fix doc * Update dev/README_AIRFLOW3_DEV.md Co-authored-by: Ephraim Anierobi * add 2.11 mental model * Update dev/README_AIRFLOW3_DEV.md Co-authored-by: Jarek Potiuk * fixes --------- Co-authored-by: Jarek Potiuk Co-authored-by: Ephraim Anierobi --- dev/README_AIRFLOW3_DEV.md | 160 +++++++++++++++++++++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 dev/README_AIRFLOW3_DEV.md diff --git a/dev/README_AIRFLOW3_DEV.md b/dev/README_AIRFLOW3_DEV.md new file mode 100644 index 0000000000000..fb626843bdc85 --- /dev/null +++ b/dev/README_AIRFLOW3_DEV.md @@ -0,0 +1,160 @@ + + + +**Table of contents** + +- [Main branch is Airflow 3](#main-branch-is-airflow-3) +- [Contributors](#contributors) + - [Developing for providers and Helm chart](#developing-for-providers-and-helm-chart) + - [Developing for Airflow 3 and 2.10.x / 2.11.x](#developing-for-airflow-3-and-210x--211x) + - [Developing for Airflow 3](#developing-for-airflow-3) + - [Developing for Airflow 2.10.x](#developing-for-airflow-210x) + - [Developing for Airflow 2.11](#developing-for-airflow-211) +- [Committers / PMCs](#committers--pmcs) + - [Merging PRs for providers and Helm chart](#merging-prs-for-providers-and-helm-chart) + - [Merging PR for Airflow 3 and 2.10.x / 2.11.x](#merging-pr-for-airflow-3-and-210x--211x) + - [Merging PRs 2.10.x](#merging-prs-210x) + - [Merging PRs for Airflow 3](#merging-prs-for-airflow-3) + - [Merging PRs for Airflow 2.11](#merging-prs-for-airflow-211) +- [Milestones for PR](#milestones-for-pr) + - [Set 2.10.x milestone](#set-210x-milestone) + - [Set 2.11 milestone](#set-211-milestone) + - [Set 3 milestone](#set-3-milestone) + + + +# Main branch is Airflow 3 + +The main branch is for development of Airflow 3. +Airflow 2.10.x releases will be cut from `v2-10-stable` branch. +Airflow 2.11.x releases will be cut from `v2-11-stable` branch. + +# Contributors + +The following section explains to which branches you should target your PR. + +## Developing for providers and Helm chart + +PRs should target `main` branch. +Make sure your code is only about Providers or Helm chart. +Avoid mixing core changes into the same PR + +## Developing for Airflow 3 and 2.10.x / 2.11.x + +If the PR is relevant for both Airflow 3 and 2, it should target `main` branch. + +Note: The mental model of Airflow 2.11 is bridge release for Airflow 3. +As a result, Airflow 2.11 is not planned to introduce new features other than ones relevant as bridge release for Airflow 3. +That said, we recognize that there may be exceptions. +If you believe a specific feature is a must-have for Airflow 2.11, you will need to raise this as discussion thread on the mailing list. +Points to address to make your case: + +1. You must clarify what is the urgency (i.e., why it can't wait for Airflow 3). +2. You need be willing to deliver the feature for both main branch and Airflow 2.11 branch. +3. You must be willing to provide support future bug fixes as needed. 
+ +Points to consider on how PMC members evaluate the request of exception: + +1. Feature impact - Is it really urgent? How many are affected? +2. Workarounds - Are there any ? +3. Scope of change - Both in code lines / number of files and components changed. +4. Centrality - Is the feature at the heart of Airflow (scheduler, dag parser) or peripheral. +5. Identity of the requester - Is the request from/supported by a member of the community? +6. Similar previous cases approved. +7. Other considerations that may raise by PMC members depending on the case. + +## Developing for Airflow 3 + +PRs should target `main` branch. + +## Developing for Airflow 2.10.x + +PR should target `v2-10-test` branch. When cutting a new release for 2.10 release manager +will sync `v2-10-test` branch to `v2-10-stable` and cut the release from the stable branch. +PR should never target `v2-10-stable` unless specifically instructed by release manager. + +## Developing for Airflow 2.11 + +Version 2.11 is planned to be cut from `v2-10-stable` branch. +The version will contain features relevant as bridge release for Airflow 3. +We will not backport otherwise features from main branch to 2.11 +Note that 2.11 policy may change as 2.11 becomes closer. + +# Committers / PMCs + +The following sections explains the protocol for merging PRs. + +## Merging PRs for providers and Helm chart + +Make sure PR targets `main` branch. +Avoid merging PRs that involve providers + core / helm chart + core +Core part should be extracted to a separated PR. +Exclusions should be pre-approved specifically with a comment by release manager. +Do not treat PR approval (Green V) as exclusion approval. + +## Merging PR for Airflow 3 and 2.10.x / 2.11.x + +The committer who merges the PR is responsible for backporting the PR to `v2-10-test`. +It means that they should create a new PR where the original commit from main is cherry-picked and take care for resolving conflicts. +If the cherry-pick is too complex, then ask the PR author / start your own PR against `v2-10-test` directly with the change. +Note: tracking that the PRs merged as expected is the responsibility of committer who merged the PR. + +Committer may also request from PR author to raise 2 PRs one against `main` branch and one against `v2-10-test` prior to accepting the code change. + +Mistakes happen, and such backport PR work might fall through cracks. Therefore, if the committer thinks that certain PRs should be backported, they should set 2.10.x milestone for them. + +This way release manager can verify (as usual) if all the "expected" PRs have been backported and cherry-pick remaining PRS. + +## Merging PRs 2.10.x + +Make sure PR targets `v2-10-test` branch and merge it when ready. +Make sure your PRs target the `v2-10-test` branch, and it can be merged when ready. +All regular protocols of merging considerations are applied. + +## Merging PRs for Airflow 3 + +Make sure PR target `main` branch. + +### PRs that involve breaking changes + +Make sure it has newsfragment, please allow time for community members to review. +Our goal is to avoid breaking changes whenever possible. Therefore, we should allow time for community members to review PRs that contain such changes - please avoid rushing to merge them. In addition, ensure that these PRs include a newsfragment. + +## Merging PRs for Airflow 2.11 + +TBD + +# Milestones for PR + +## Set 2.10.x milestone + +Milestone will be added only to the original PR. + +1. PR targeting `v2-10-test` directly - mlinestone will be on that PR. 
+2. PR targeting `main` with backport PR targeting `v2-10-test`. Milestone will be added only on the PR targeting `v2-10-main`. + +## Set 2.11 milestone + +TBD +For now, similar procedure as 2.10.x + +## Set 3 milestone + +Set for any feature that targets Airflow 3 only. From 6d33d5efee9d7996beef1da8c70eb7a1a474bdd5 Mon Sep 17 00:00:00 2001 From: mayankymailusfedu <35581473+mayankymailusfedu@users.noreply.github.com> Date: Wed, 11 Sep 2024 11:19:45 -0700 Subject: [PATCH 028/349] Added Aisera to the list of companies using Apache Airflow (#42162) Added Aisera to INTHEWILD.md Co-authored-by: Ryan Hatter <25823361+RNHTTR@users.noreply.github.com> --- INTHEWILD.md | 1 + 1 file changed, 1 insertion(+) diff --git a/INTHEWILD.md b/INTHEWILD.md index 9dab2090c8a93..13ad6f02c27c2 100644 --- a/INTHEWILD.md +++ b/INTHEWILD.md @@ -40,6 +40,7 @@ Currently, **officially** using Airflow: 1. [AirDNA](https://www.airdna.co) 1. [Airfinity](https://www.airfinity.com) [[@sibowyer](https://github.com/sibowyer)] 1. [Airtel](https://www.airtel.in/) [[@harishbisht](https://github.com/harishbisht)] +1. [Aisera](https://aisera.com/) [[@mayankymailusfedu](https://github.com/mayankymailusfedu)] 1. [Akamai](https://www.akamai.com/) [[@anirudhbagri](https://github.com/anirudhbagri)] 1. [Akamas](https://akamas.io) [[@GiovanniPaoloGibilisco](https://github.com/GiovanniPaoloGibilisco), [@lucacavazzana](https://github.com/lucacavazzana)] 1. [Alan](https://alan.eu) [[@charles-go](https://github.com/charles-go)] From c50099ffa233649079527b00807b8c7d62a2cabf Mon Sep 17 00:00:00 2001 From: Karthik Dulam <101580964+kaydee-edb@users.noreply.github.com> Date: Wed, 11 Sep 2024 11:35:23 -0700 Subject: [PATCH 029/349] Added EnterpriseDB to the list of companies using Airflow (#42163) --- INTHEWILD.md | 1 + 1 file changed, 1 insertion(+) diff --git a/INTHEWILD.md b/INTHEWILD.md index 13ad6f02c27c2..29d788503d08e 100644 --- a/INTHEWILD.md +++ b/INTHEWILD.md @@ -195,6 +195,7 @@ Currently, **officially** using Airflow: 1. [ENECHANGE Ltd.](https://enechange.co.jp/) [[@enechange](https://github.com/enechange)] 1. [Energy Solutions](https://www.energy-solution.com) [[@energy-solution](https://github.com/energy-solution/)] 1. [Enigma](https://www.enigma.com) [[@hydrosquall](https://github.com/hydrosquall)] +1. [EnterpriseDB](https://www.Enterprisedb.com) [[@kaydee-edb](https://github.com/kaydee-edb)] 1. [Ericsson](https://www.ericsson.com) [[@khalidzohaib](https://github.com/khalidzohaib)] 1. [Estrategia Educacional](https://github.com/estrategiahq) [[@jonasrla](https://github.com/jonasrla)] 1. [Etsy](https://www.etsy.com) [[@mchalek](https://github.com/mchalek)] From 38c3c28f77ac68187fa150bd9f22b8e93264e174 Mon Sep 17 00:00:00 2001 From: althati <153687755+althati@users.noreply.github.com> Date: Wed, 11 Sep 2024 14:35:53 -0400 Subject: [PATCH 030/349] Adding my company (#42165) --- INTHEWILD.md | 1 + 1 file changed, 1 insertion(+) diff --git a/INTHEWILD.md b/INTHEWILD.md index 29d788503d08e..7389d7a985d74 100644 --- a/INTHEWILD.md +++ b/INTHEWILD.md @@ -168,6 +168,7 @@ Currently, **officially** using Airflow: 1. [dataroots](https://dataroots.io/) [[@datarootsio]](https://github.com/datarootsio) 1. [DataSprints](https://datasprints.com/) [[@lopesdiego12](https://github.com/lopesdiego12) & [@rafaelsantanaep](https://github.com/rafaelsantanaep)] 1. [Datatonic](https://datatonic.com/) [[@teamdatatonic](https://github.com/teamdatatonic)] +1. [Datavant](https://datavant.com)/) [@althati(https://github.com/althati)] 1. 
[Datumo](https://datumo.io) [[@michalmisiewicz](https://github.com/michalmisiewicz)] 1. [Dcard](https://www.dcard.tw/) [[@damon09273](https://github.com/damon09273) & [@bruce3557](https://github.com/bruce3557) & [@kevin1kevin1k](http://github.com/kevin1kevin1k)] 1. [Delft University of Technology](https://www.tudelft.nl/en/) [[@saveriogzz](https://github.com/saveriogzz)] From e17209fe552b2bc99fb2d0fce77f1fbf505837ea Mon Sep 17 00:00:00 2001 From: hikaruhk <25138528+hikaruhk@users.noreply.github.com> Date: Wed, 11 Sep 2024 12:01:43 -0700 Subject: [PATCH 031/349] Added BTIG to the list of companies using Apache Airflow (#42169) Co-authored-by: HikaruT Co-authored-by: Ryan Hatter <25823361+RNHTTR@users.noreply.github.com> --- INTHEWILD.md | 1 + 1 file changed, 1 insertion(+) diff --git a/INTHEWILD.md b/INTHEWILD.md index 7389d7a985d74..f197d07e0a7a5 100644 --- a/INTHEWILD.md +++ b/INTHEWILD.md @@ -105,6 +105,7 @@ Currently, **officially** using Airflow: 1. [Braintree](https://www.braintreepayments.com) [[@coopergillan](https://github.com/coopergillan), [@curiousjazz77](https://github.com/curiousjazz77), [@raymondberg](https://github.com/raymondberg)] 1. [Branch](https://branch.io) [[@sdebarshi](https://github.com/sdebarshi), [@dmitrig01](https://github.com/dmitrig01)] 1. [Breezeline (formerly Atlantic Broadband)](https://www.breezeline.com/) [[@IanDoarn](https://github.com/IanDoarn), [@willsims14](https://github.com/willsims14)] +1. [BTIG](https://www.btig.com/) [[@hikaruhk](https://github.com/hikaruhk)] 1. [BWGI](https://www.bwgi.com.br/) [[@jgmarcel](https://github.com/jgmarcel)] 1. [Bwtech](https://www.bwtech.com/) [[@wolvery](https://github.com/wolvery)] 1. [C2FO](https://www.c2fo.com/) From c540060487fb941ccce8aafe235c9740399ca536 Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Wed, 11 Sep 2024 12:04:21 -0700 Subject: [PATCH 032/349] Fix wrong casing in airbyte hook. 
(#42170) Fixes: #42154 --- airflow/providers/airbyte/hooks/airbyte.py | 2 +- airflow/providers/airbyte/provider.yaml | 2 +- generated/provider_dependencies.json | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/airflow/providers/airbyte/hooks/airbyte.py b/airflow/providers/airbyte/hooks/airbyte.py index d530f3e6685eb..f2a63080b6898 100644 --- a/airflow/providers/airbyte/hooks/airbyte.py +++ b/airflow/providers/airbyte/hooks/airbyte.py @@ -71,7 +71,7 @@ def create_api_session(self) -> AirbyteAPI: credentials = SchemeClientCredentials( client_id=self.conn["client_id"], client_secret=self.conn["client_secret"], - TOKEN_URL=self.conn["token_url"], + token_url=self.conn["token_url"], ) return AirbyteAPI( diff --git a/airflow/providers/airbyte/provider.yaml b/airflow/providers/airbyte/provider.yaml index e421b1c66058c..7db89834ecf5b 100644 --- a/airflow/providers/airbyte/provider.yaml +++ b/airflow/providers/airbyte/provider.yaml @@ -51,7 +51,7 @@ versions: dependencies: - apache-airflow>=2.8.0 - - airbyte-api==0.51.0 # v0.52.0 breaks hooks, see https://github.com/apache/airflow/issues/42154 + - airbyte-api>=0.52.0 integrations: - integration-name: Airbyte diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 1c0ad3f1cae66..a520c4c07ea43 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -1,7 +1,7 @@ { "airbyte": { "deps": [ - "airbyte-api==0.51.0", + "airbyte-api>=0.52.0", "apache-airflow>=2.8.0" ], "devel-deps": [], From c182244433ccde5a08f1b242cb1bbd419e400bb0 Mon Sep 17 00:00:00 2001 From: Vikram Medabalimi Date: Wed, 11 Sep 2024 12:18:23 -0700 Subject: [PATCH 033/349] Added Autodesk to list of companies using Airflow In the wild (#42176) Co-authored-by: Vikram Medabalimi --- INTHEWILD.md | 1 + 1 file changed, 1 insertion(+) diff --git a/INTHEWILD.md b/INTHEWILD.md index f197d07e0a7a5..1fe38394e9522 100644 --- a/INTHEWILD.md +++ b/INTHEWILD.md @@ -62,6 +62,7 @@ Currently, **officially** using Airflow: 1. [Astronomer](https://www.astronomer.io) [[@schnie](https://github.com/schnie), [@ashb](https://github.com/ashb), [@kaxil](https://github.com/kaxil), [@dimberman](https://github.com/dimberman), [@andriisoldatenko](https://github.com/andriisoldatenko), [@ryw](https://github.com/ryw), [@ryanahamilton](https://github.com/ryanahamilton), [@jhtimmins](https://github.com/jhtimmins), [@vikramkoka](https://github.com/vikramkoka), [@jedcunningham](https://github.com/jedcunningham), [@BasPH](https://github.com/basph), [@ephraimbuddy](https://github.com/ephraimbuddy), [@feluelle](https://github.com/feluelle)] 1. [Audiomack](https://audiomack.com) [[@billcrook](https://github.com/billcrook)] 1. [Auth0](https://auth0.com) [[@scottypate](https://github.com/scottypate)], [[@dm03514](https://github.com/dm03514)], [[@karangale](https://github.com/karangale)] +1. [Autodesk](https://autodesk.com) 1. [Automattic](https://automattic.com/) [[@anandnalya](https://github.com/anandnalya), [@bperson](https://github.com/bperson), [@khrol](https://github.com/Khrol), [@xyu](https://github.com/xyu)] 1. [Avesta Technologies](https://avestatechnologies.com) [[@TheRum](https://github.com/TheRum)] 1. 
[Aviva plc](https://www.aviva.com) [[@panatimahesh](https://github.com/panatimahesh)] From 48b6d81bb923cd4d2ba759d9631bc488feac9499 Mon Sep 17 00:00:00 2001 From: ipsatrivedi <45980322+ipsatrivedi@users.noreply.github.com> Date: Wed, 11 Sep 2024 14:18:59 -0500 Subject: [PATCH 034/349] Adding Tekmetric to Users list (#42167) Co-authored-by: Ipsa Trivedi <> --- INTHEWILD.md | 1 + 1 file changed, 1 insertion(+) diff --git a/INTHEWILD.md b/INTHEWILD.md index 1fe38394e9522..f9c25f34b0361 100644 --- a/INTHEWILD.md +++ b/INTHEWILD.md @@ -466,6 +466,7 @@ Currently, **officially** using Airflow: 1. [Talkdesk](https://www.talkdesk.com) 1. [Tapsi](https://tapsi.ir/) 1. [TEK](https://www.tek.fi/en) [[@telac](https://github.com/telac)] +1. [Tekmetric] (https://www.tekmetric.com/) 1. [Telefonica Innovation Alpha](https://www.alpha.company/) [[@Alpha-Health](https://github.com/Alpha-health)] 1. [Telia Company](https://www.teliacompany.com/en) 1. [Ternary Data](https://ternarydata.com/) [[@mhousley](https://github.com/mhousley), [@JoeReis](https://github.com/JoeReis)] From 16d9bd09b53b44b0c703e2042fe64099c84466ea Mon Sep 17 00:00:00 2001 From: svellaiyan Date: Wed, 11 Sep 2024 12:19:10 -0700 Subject: [PATCH 035/349] Added my current company. We use Airflow for production workloads. (#42168) --- INTHEWILD.md | 1 + 1 file changed, 1 insertion(+) diff --git a/INTHEWILD.md b/INTHEWILD.md index f9c25f34b0361..7587b7c4a80e4 100644 --- a/INTHEWILD.md +++ b/INTHEWILD.md @@ -199,6 +199,7 @@ Currently, **officially** using Airflow: 1. [Energy Solutions](https://www.energy-solution.com) [[@energy-solution](https://github.com/energy-solution/)] 1. [Enigma](https://www.enigma.com) [[@hydrosquall](https://github.com/hydrosquall)] 1. [EnterpriseDB](https://www.Enterprisedb.com) [[@kaydee-edb](https://github.com/kaydee-edb)] +1. [Envestnet](https://www.envestnet.com/) [[@svellaiyan](https://github.com/svellaiyan)] 1. [Ericsson](https://www.ericsson.com) [[@khalidzohaib](https://github.com/khalidzohaib)] 1. [Estrategia Educacional](https://github.com/estrategiahq) [[@jonasrla](https://github.com/jonasrla)] 1. [Etsy](https://www.etsy.com) [[@mchalek](https://github.com/mchalek)] From 001bc67ea9099fcb810898bcc0a455af7fedcdf5 Mon Sep 17 00:00:00 2001 From: skandala23 Date: Wed, 11 Sep 2024 15:20:27 -0400 Subject: [PATCH 036/349] adding airflowuser for bloomberg (#42173) * adding airflowuser for bloomberg * Update INTHEWILD.md --------- Co-authored-by: Ryan Hatter <25823361+RNHTTR@users.noreply.github.com> --- INTHEWILD.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/INTHEWILD.md b/INTHEWILD.md index 7587b7c4a80e4..2093c16db726e 100644 --- a/INTHEWILD.md +++ b/INTHEWILD.md @@ -89,7 +89,7 @@ Currently, **officially** using Airflow: 1. [BlaBlaCar](https://www.blablacar.com) [[@puckel](https://github.com/puckel) & [@wmorin](https://github.com/wmorin)] 1. [Blacklane](https://www.blacklane.com) [[@serkef](https://github.com/serkef)] 1. [Bloc](https://www.bloc.io) [[@dpaola2](https://github.com/dpaola2)] -1. [Bloomberg](https://www.techatbloomberg.com) [[@dimberman](https://github.com/dimberman)] +1. [Bloomberg](https://www.techatbloomberg.com) [[@skandala23] (https://github.com/skandala23)] 1. [Bloomreach](https://www.bloomreach.com/) [[@neelborooah](https://github.com/neelborooah) & [@debodirno](https://github.com/debodirno) & [@ayushmnnit](https://github.com/ayushmnnit)] 1. [Blue Yonder](http://www.blue-yonder.com) [[@blue-yonder](https://github.com/blue-yonder)] 1. 
[BlueApron](https://www.blueapron.com) [[@jasonjho](https://github.com/jasonjho) & [@matthewdavidhauser](https://github.com/matthewdavidhauser)] From 286025bc563289e870056c13819f2fb796ed3cab Mon Sep 17 00:00:00 2001 From: Shoaib UR Rehman <23278048+srehman420@users.noreply.github.com> Date: Thu, 12 Sep 2024 00:25:55 +0500 Subject: [PATCH 037/349] Adding support for volume configurations in ECSRunTaskOperator (#42087) * Adding support for volume configurations in ECSRunTaskOperator * Updating template fields override test to include volume_configurations --------- Co-authored-by: glory9211 --- airflow/providers/amazon/aws/operators/ecs.py | 9 +++++++++ tests/providers/amazon/aws/operators/test_ecs.py | 1 + 2 files changed, 10 insertions(+) diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index 1cd8685cf282f..433fd88cd636e 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -373,6 +373,10 @@ class EcsRunTaskOperator(EcsBaseOperator): When capacity_provider_strategy is specified, the launch_type parameter is omitted. If no capacity_provider_strategy or launch_type is specified, the default capacity provider strategy for the cluster is used. + :param volume_configurations: the volume configurations to use when using capacity provider. The name of the volume must match + the name from the task definition. + You can configure the settings like size, volume type, IOPS, throughput and others mentioned in + (https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_TaskManagedEBSVolumeConfiguration.html) :param group: the name of the task group associated with the task :param placement_constraints: an array of placement constraint objects to use for the task @@ -420,6 +424,7 @@ class EcsRunTaskOperator(EcsBaseOperator): "overrides", "launch_type", "capacity_provider_strategy", + "volume_configurations", "group", "placement_constraints", "placement_strategy", @@ -450,6 +455,7 @@ def __init__( overrides: dict, launch_type: str = "EC2", capacity_provider_strategy: list | None = None, + volume_configurations: list | None = None, group: str | None = None, placement_constraints: list | None = None, placement_strategy: list | None = None, @@ -479,6 +485,7 @@ def __init__( self.overrides = overrides self.launch_type = launch_type self.capacity_provider_strategy = capacity_provider_strategy + self.volume_configurations = volume_configurations self.group = group self.placement_constraints = placement_constraints self.placement_strategy = placement_strategy @@ -614,6 +621,8 @@ def _start_task(self): if self.capacity_provider_strategy: run_opts["capacityProviderStrategy"] = self.capacity_provider_strategy + if self.volume_configurations is not None: + run_opts["volumeConfigurations"] = self.volume_configurations elif self.launch_type: run_opts["launchType"] = self.launch_type if self.platform_version is not None: diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py index a6915214a0764..fefdb595dacda 100644 --- a/tests/providers/amazon/aws/operators/test_ecs.py +++ b/tests/providers/amazon/aws/operators/test_ecs.py @@ -172,6 +172,7 @@ def test_template_fields_overrides(self): "overrides", "launch_type", "capacity_provider_strategy", + "volume_configurations", "group", "placement_constraints", "placement_strategy", From 1f08a7a25a6588d571f7c2ab93571389c642f09e Mon Sep 17 00:00:00 2001 From: Srabasti Banerjee Date: Wed, 11 Sep 2024 
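As an illustration of the volume_configurations parameter introduced above (a sketch, not part of the patch itself): a hypothetical EcsRunTaskOperator call that attaches a managed EBS volume while a capacity provider strategy is in use. The cluster, task definition, volume name, size and role ARN are all made-up values, and the volume name must match a volume declared in the task definition.

from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator

run_task = EcsRunTaskOperator(
    task_id="run_task_with_ebs_volume",
    cluster="example-cluster",  # hypothetical cluster
    task_definition="example-task-def",  # hypothetical task definition
    overrides={},
    capacity_provider_strategy=[{"capacityProvider": "FARGATE", "weight": 1}],
    volume_configurations=[
        {
            "name": "data-volume",  # must match the volume name in the task definition
            "managedEBSVolume": {
                "sizeInGiB": 50,
                "volumeType": "gp3",
                "roleArn": "arn:aws:iam::123456789012:role/ecsInfrastructureRole",  # hypothetical
            },
        }
    ],
)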
12:40:20 -0700 Subject: [PATCH 038/349] Add Five9 to Airflow Company List (#42172) --- INTHEWILD.md | 1 + 1 file changed, 1 insertion(+) diff --git a/INTHEWILD.md b/INTHEWILD.md index 2093c16db726e..979598e5c2fd0 100644 --- a/INTHEWILD.md +++ b/INTHEWILD.md @@ -213,6 +213,7 @@ Currently, **officially** using Airflow: 1. [Farfetch](https://github.com/farfetch) [[@davidmarques78](https://github.com/davidmarques78)] 1. [Fathom Health](https://www.fathomhealth.co/) 1. [Firestone Inventing](https://www.hsmap.com/) [[@zihengCat](https://github.com/zihengCat)] +1. [Five9](https://https://www.five9.com/) [[srabasti](https://github.com/Srabasti)] 1. [Fleek Fashion](https://www.fleekfashion.app/) [[@ghodouss](https://github.com/ghodoussG)] 1. [Flipp](https://www.flipp.com) [[@sethwilsonwishabi](https://github.com/sethwilsonwishabi)] 1. [Format](https://www.format.com) [[@format](https://github.com/4ormat) & [@jasonicarter](https://github.com/jasonicarter)] From 17ecb3abf5b06c10411d7cc10adc49ab64210250 Mon Sep 17 00:00:00 2001 From: Utkarsh Sharma Date: Thu, 12 Sep 2024 01:14:56 +0530 Subject: [PATCH 039/349] Airflow 2.10.1 has been released (#42068) --- .github/ISSUE_TEMPLATE/airflow_bug_report.yml | 3 +- Dockerfile | 2 +- README.md | 12 ++--- RELEASE_NOTES.rst | 44 +++++++++++++++++++ airflow/reproducible_build.yaml | 4 +- .../installation/supported-versions.rst | 2 +- generated/PYPI_README.md | 10 ++--- scripts/ci/pre_commit/supported_versions.py | 2 +- 8 files changed, 61 insertions(+), 18 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/airflow_bug_report.yml b/.github/ISSUE_TEMPLATE/airflow_bug_report.yml index 0b9697edf00f7..853b102ef07f8 100644 --- a/.github/ISSUE_TEMPLATE/airflow_bug_report.yml +++ b/.github/ISSUE_TEMPLATE/airflow_bug_report.yml @@ -25,8 +25,7 @@ body: the latest release or main to see if the issue is fixed before reporting it. multiple: false options: - - "2.10.0" - - "2.9.3" + - "2.10.1" - "main (development)" - "Other Airflow 2 version (please specify below)" validations: diff --git a/Dockerfile b/Dockerfile index ed38ba82146d7..5cd1caec434ee 100644 --- a/Dockerfile +++ b/Dockerfile @@ -45,7 +45,7 @@ ARG AIRFLOW_UID="50000" ARG AIRFLOW_USER_HOME_DIR=/home/airflow # latest released version here -ARG AIRFLOW_VERSION="2.10.0" +ARG AIRFLOW_VERSION="2.10.1" ARG PYTHON_BASE_IMAGE="python:3.8-slim-bookworm" diff --git a/README.md b/README.md index 9556022ce5afb..91ddf5e927245 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ Airflow is not a streaming solution, but it is often used to process real-time d Apache Airflow is tested with: -| | Main version (dev) | Stable version (2.10.0) | +| | Main version (dev) | Stable version (2.10.1) | |------------|----------------------------|----------------------------| | Python | 3.8, 3.9, 3.10, 3.11, 3.12 | 3.8, 3.9, 3.10, 3.11, 3.12 | | Platform | AMD64/ARM64(\*) | AMD64/ARM64(\*) | @@ -177,15 +177,15 @@ them to the appropriate format and workflow that your tool requires. ```bash -pip install 'apache-airflow==2.10.0' \ - --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.0/constraints-3.8.txt" +pip install 'apache-airflow==2.10.1' \ + --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.1/constraints-3.8.txt" ``` 2. 
Installing with extras (i.e., postgres, google) ```bash -pip install 'apache-airflow[postgres,google]==2.10.0' \ - --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.0/constraints-3.8.txt" +pip install 'apache-airflow[postgres,google]==2.10.1' \ + --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.1/constraints-3.8.txt" ``` For information on installing provider packages, check @@ -290,7 +290,7 @@ Apache Airflow version life cycle: | Version | Current Patch/Minor | State | First Release | Limited Support | EOL/Terminated | |-----------|-----------------------|-----------|-----------------|-------------------|------------------| -| 2 | 2.10.0 | Supported | Dec 17, 2020 | TBD | TBD | +| 2 | 2.10.1 | Supported | Dec 17, 2020 | TBD | TBD | | 1.10 | 1.10.15 | EOL | Aug 27, 2018 | Dec 17, 2020 | June 17, 2021 | | 1.9 | 1.9.0 | EOL | Jan 03, 2018 | Aug 27, 2018 | Aug 27, 2018 | | 1.8 | 1.8.2 | EOL | Mar 19, 2017 | Jan 03, 2018 | Jan 03, 2018 | diff --git a/RELEASE_NOTES.rst b/RELEASE_NOTES.rst index a5e6d9974c50e..d42074b1146a7 100644 --- a/RELEASE_NOTES.rst +++ b/RELEASE_NOTES.rst @@ -21,6 +21,50 @@ .. towncrier release notes start +Airflow 2.10.1 (2024-09-05) +--------------------------- + +Significant Changes +^^^^^^^^^^^^^^^^^^^ + +No significant changes. + +Bug Fixes +""""""""" +- Handle Example dags case when checking for missing files (#41874) +- Fix logout link in "no roles" error page (#41845) +- Set end_date and duration for triggers completed with end_from_trigger as True. (#41834) +- DAGs are not marked as stale if the dags folder change (#41829) +- Fix compatibility with FAB provider versions <1.3.0 (#41809) +- Don't Fail LocalTaskJob on heartbeat (#41810) +- Remove deprecation warning for cgitb in Plugins Manager (#41793) +- Fix log for notifier(instance) without __name__ (#41699) +- Splitting syspath preparation into stages (#41694) +- Adding url sanitization for extra links (#41680) +- Fix InletEventsAccessors type stub (#41607) +- Fix UI rendering when XCom is INT, FLOAT, BOOL or NULL (#41605) +- Fix try selector refresh (#41503) +- Incorrect try number subtraction producing invalid span id for OTEL airflow (#41535) +- Add WebEncoder for trigger page rendering to avoid render failure (#41485) +- Adding ``tojson`` filter to example_inlet_event_extra example dag (#41890) +- Add backward compatibility check for executors that don't inherit BaseExecutor (#41927) + +Miscellaneous +""""""""""""" +- Bump webpack from 5.76.0 to 5.94.0 in /airflow/www (#41879) +- Adding rel property to hyperlinks in logs (#41783) +- Field Deletion Warning when editing Connections (#41504) +- Make Scarf usage reporting in major+minor versions and counters in buckets (#41900) +- Lower down universal-pathlib minimum to 0.2.2 (#41943) +- Protect against None components of universal pathlib xcom backend (#41938) + +Doc Only Changes +"""""""""""""""" +- Remove Debian bullseye support (#41569) +- Add an example for auth with ``keycloak`` (#41791) + + + Airflow 2.10.0 (2024-08-15) --------------------------- diff --git a/airflow/reproducible_build.yaml b/airflow/reproducible_build.yaml index 24eb1ab9cb526..31e63fbce742b 100644 --- a/airflow/reproducible_build.yaml +++ b/airflow/reproducible_build.yaml @@ -1,2 +1,2 @@ -release-notes-hash: 21b1c588582fcbd521c30e73b4b560b4 -source-date-epoch: 1724146775 +release-notes-hash: aa948d55b0b6062659dbcd0293d73838 +source-date-epoch: 1725624671 diff --git a/docs/apache-airflow/installation/supported-versions.rst 
b/docs/apache-airflow/installation/supported-versions.rst index 0fbeb683526f6..0a7694abbda3d 100644 --- a/docs/apache-airflow/installation/supported-versions.rst +++ b/docs/apache-airflow/installation/supported-versions.rst @@ -29,7 +29,7 @@ Apache Airflow® version life cycle: ========= ===================== ========= =============== ================= ================ Version Current Patch/Minor State First Release Limited Support EOL/Terminated ========= ===================== ========= =============== ================= ================ -2 2.10.0 Supported Dec 17, 2020 TBD TBD +2 2.10.1 Supported Dec 17, 2020 TBD TBD 1.10 1.10.15 EOL Aug 27, 2018 Dec 17, 2020 June 17, 2021 1.9 1.9.0 EOL Jan 03, 2018 Aug 27, 2018 Aug 27, 2018 1.8 1.8.2 EOL Mar 19, 2017 Jan 03, 2018 Jan 03, 2018 diff --git a/generated/PYPI_README.md b/generated/PYPI_README.md index 05d5bdedaf3b8..2b80e73a45f5e 100644 --- a/generated/PYPI_README.md +++ b/generated/PYPI_README.md @@ -54,7 +54,7 @@ Use Airflow to author workflows as directed acyclic graphs (DAGs) of tasks. The Apache Airflow is tested with: -| | Main version (dev) | Stable version (2.10.0) | +| | Main version (dev) | Stable version (2.10.1) | |------------|----------------------------|----------------------------| | Python | 3.8, 3.9, 3.10, 3.11, 3.12 | 3.8, 3.9, 3.10, 3.11, 3.12 | | Platform | AMD64/ARM64(\*) | AMD64/ARM64(\*) | @@ -130,15 +130,15 @@ them to the appropriate format and workflow that your tool requires. ```bash -pip install 'apache-airflow==2.10.0' \ - --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.0/constraints-3.8.txt" +pip install 'apache-airflow==2.10.1' \ + --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.1/constraints-3.8.txt" ``` 2. Installing with extras (i.e., postgres, google) ```bash -pip install 'apache-airflow[postgres,google]==2.10.0' \ - --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.0/constraints-3.8.txt" +pip install 'apache-airflow[postgres,google]==2.10.1' \ + --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.1/constraints-3.8.txt" ``` For information on installing provider packages, check diff --git a/scripts/ci/pre_commit/supported_versions.py b/scripts/ci/pre_commit/supported_versions.py index 0fb68e5afc4b6..b392eaf6d4e01 100755 --- a/scripts/ci/pre_commit/supported_versions.py +++ b/scripts/ci/pre_commit/supported_versions.py @@ -27,7 +27,7 @@ HEADERS = ("Version", "Current Patch/Minor", "State", "First Release", "Limited Support", "EOL/Terminated") SUPPORTED_VERSIONS = ( - ("2", "2.10.0", "Supported", "Dec 17, 2020", "TBD", "TBD"), + ("2", "2.10.1", "Supported", "Dec 17, 2020", "TBD", "TBD"), ("1.10", "1.10.15", "EOL", "Aug 27, 2018", "Dec 17, 2020", "June 17, 2021"), ("1.9", "1.9.0", "EOL", "Jan 03, 2018", "Aug 27, 2018", "Aug 27, 2018"), ("1.8", "1.8.2", "EOL", "Mar 19, 2017", "Jan 03, 2018", "Jan 03, 2018"), From 8f8f01fbf6f68d37fbe4b3cd41340d192ea84be7 Mon Sep 17 00:00:00 2001 From: Gopal Dirisala <39794726+dirrao@users.noreply.github.com> Date: Thu, 12 Sep 2024 01:16:20 +0530 Subject: [PATCH 040/349] Align timers and timing metrics (ms) across all metrics loggers (#39908) --- airflow/config_templates/config.yml | 11 ++++++++++ airflow/metrics/datadog_logger.py | 15 ++++++++++++- airflow/metrics/otel_logger.py | 14 +++++++++++- airflow/metrics/protocols.py | 16 +++++++++++++- airflow/models/taskinstance.py | 19 ++++++++++++++-- newsfragments/39908.significant.rst | 1 + 
tests/_internals/forbidden_warnings.py | 5 +++++ tests/core/test_otel_logger.py | 25 +++++++++++++++++---- tests/core/test_stats.py | 30 ++++++++++++++++++++++---- 9 files changed, 123 insertions(+), 13 deletions(-) create mode 100644 newsfragments/39908.significant.rst diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 2e624db827360..9c9f7c2153c20 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -1106,6 +1106,17 @@ metrics: example: "\"scheduler,executor,dagrun,pool,triggerer,celery\" or \"^scheduler,^executor,heartbeat|timeout\"" default: "" + metrics_consistency_on: + description: | + Enables metrics consistency across all metrics loggers (ex: timer and timing metrics). + + .. warning:: + + It is enabled by default from Airflow 3. + version_added: 2.10.0 + type: string + example: ~ + default: "True" statsd_on: description: | Enables sending metrics to StatsD. diff --git a/airflow/metrics/datadog_logger.py b/airflow/metrics/datadog_logger.py index 156407977305e..c7bcf1986d853 100644 --- a/airflow/metrics/datadog_logger.py +++ b/airflow/metrics/datadog_logger.py @@ -19,9 +19,11 @@ import datetime import logging +import warnings from typing import TYPE_CHECKING from airflow.configuration import conf +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.metrics.protocols import Timer from airflow.metrics.validators import ( AllowListValidator, @@ -40,6 +42,14 @@ log = logging.getLogger(__name__) +metrics_consistency_on = conf.getboolean("metrics", "metrics_consistency_on", fallback=True) +if not metrics_consistency_on: + warnings.warn( + "Timer and timing metrics publish in seconds were deprecated. It is enabled by default from Airflow 3 onwards. Enable metrics consistency to publish all the timer and timing metrics in milliseconds.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + class SafeDogStatsdLogger: """DogStatsd Logger.""" @@ -134,7 +144,10 @@ def timing( tags_list = [] if self.metrics_validator.test(stat): if isinstance(dt, datetime.timedelta): - dt = dt.total_seconds() + if metrics_consistency_on: + dt = dt.total_seconds() * 1000.0 + else: + dt = dt.total_seconds() return self.dogstatsd.timing(metric=stat, value=dt, tags=tags_list) return None diff --git a/airflow/metrics/otel_logger.py b/airflow/metrics/otel_logger.py index 5dac960c169a0..e8d0f54d73288 100644 --- a/airflow/metrics/otel_logger.py +++ b/airflow/metrics/otel_logger.py @@ -31,6 +31,7 @@ from opentelemetry.sdk.resources import SERVICE_NAME, Resource from airflow.configuration import conf +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.metrics.protocols import Timer from airflow.metrics.validators import ( OTEL_NAME_MAX_LENGTH, @@ -71,6 +72,14 @@ # Delimiter is placed between the universal metric prefix and the unique metric name. DEFAULT_METRIC_NAME_DELIMITER = "." +metrics_consistency_on = conf.getboolean("metrics", "metrics_consistency_on", fallback=True) +if not metrics_consistency_on: + warnings.warn( + "Timer and timing metrics publish in seconds were deprecated. It is enabled by default from Airflow 3 onwards. 
Enable metrics consistency to publish all the timer and timing metrics in milliseconds.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + def full_name(name: str, *, prefix: str = DEFAULT_METRIC_NAME_PREFIX) -> str: """Assembles the prefix, delimiter, and name and returns it as a string.""" @@ -274,7 +283,10 @@ def timing( """OTel does not have a native timer, stored as a Gauge whose value is number of seconds elapsed.""" if self.metrics_validator.test(stat) and name_is_otel_safe(self.prefix, stat): if isinstance(dt, datetime.timedelta): - dt = dt.total_seconds() + if metrics_consistency_on: + dt = dt.total_seconds() * 1000.0 + else: + dt = dt.total_seconds() self.metrics_map.set_gauge_value(full_name(prefix=self.prefix, name=stat), float(dt), False, tags) def timer( diff --git a/airflow/metrics/protocols.py b/airflow/metrics/protocols.py index c46942ce95f70..7eef7929e02db 100644 --- a/airflow/metrics/protocols.py +++ b/airflow/metrics/protocols.py @@ -19,12 +19,23 @@ import datetime import time +import warnings from typing import Union +from airflow.configuration import conf +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.typing_compat import Protocol DeltaType = Union[int, float, datetime.timedelta] +metrics_consistency_on = conf.getboolean("metrics", "metrics_consistency_on", fallback=True) +if not metrics_consistency_on: + warnings.warn( + "Timer and timing metrics publish in seconds were deprecated. It is enabled by default from Airflow 3 onwards. Enable metrics consistency to publish all the timer and timing metrics in milliseconds.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + class TimerProtocol(Protocol): """Type protocol for StatsLogger.timer.""" @@ -116,6 +127,9 @@ def start(self) -> Timer: def stop(self, send: bool = True) -> None: """Stop the timer, and optionally send it to stats backend.""" if self._start_time is not None: - self.duration = time.perf_counter() - self._start_time + if metrics_consistency_on: + self.duration = 1000.0 * (time.perf_counter() - self._start_time) # Convert to milliseconds. + else: + self.duration = time.perf_counter() - self._start_time if send and self.real_timer: self.real_timer.stop() diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 165f5c7987305..5f82d84fe5c6c 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -74,6 +74,7 @@ from airflow.exceptions import ( AirflowException, AirflowFailException, + AirflowProviderDeprecationWarning, AirflowRescheduleException, AirflowSensorTimeout, AirflowSkipException, @@ -168,6 +169,14 @@ PAST_DEPENDS_MET = "past_depends_met" +metrics_consistency_on = conf.getboolean("metrics", "metrics_consistency_on", fallback=True) +if not metrics_consistency_on: + warnings.warn( + "Timer and timing metrics publish in seconds were deprecated. It is enabled by default from Airflow 3 onwards. 
Enable metrics consistency to publish all the timer and timing metrics in milliseconds.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + class TaskReturnCode(Enum): """ @@ -2809,7 +2818,10 @@ def emit_state_change_metric(self, new_state: TaskInstanceState) -> None: self.task_id, ) return - timing = timezone.utcnow() - self.queued_dttm + if metrics_consistency_on: + timing = timezone.utcnow() - self.queued_dttm + else: + timing = (timezone.utcnow() - self.queued_dttm).total_seconds() elif new_state == TaskInstanceState.QUEUED: metric_name = "scheduled_duration" if self.start_date is None: @@ -2822,7 +2834,10 @@ def emit_state_change_metric(self, new_state: TaskInstanceState) -> None: self.task_id, ) return - timing = timezone.utcnow() - self.start_date + if metrics_consistency_on: + timing = timezone.utcnow() - self.start_date + else: + timing = (timezone.utcnow() - self.start_date).total_seconds() else: raise NotImplementedError("no metric emission setup for state %s", new_state) diff --git a/newsfragments/39908.significant.rst b/newsfragments/39908.significant.rst new file mode 100644 index 0000000000000..bd4a2967ba4fb --- /dev/null +++ b/newsfragments/39908.significant.rst @@ -0,0 +1 @@ +Publishing timer and timing metrics in seconds has been deprecated. In Airflow 3, ``metrics_consistency_on`` from ``metrics`` is enabled by default. You can disable this for backward compatibility. To publish all timer and timing metrics in milliseconds, ensure metrics consistency is enabled diff --git a/tests/_internals/forbidden_warnings.py b/tests/_internals/forbidden_warnings.py index c78e4b0333f74..324d2ff6f9824 100644 --- a/tests/_internals/forbidden_warnings.py +++ b/tests/_internals/forbidden_warnings.py @@ -62,6 +62,11 @@ def pytest_itemcollected(self, item: pytest.Item): # Add marker at the beginning of the markers list. In this case, it does not conflict with # filterwarnings markers, which are set explicitly in the test suite. item.add_marker(pytest.mark.filterwarnings(f"error::{fw}"), append=False) + item.add_marker( + pytest.mark.filterwarnings( + "ignore:Timer and timing metrics publish in seconds were deprecated. It is enabled by default from Airflow 3 onwards. 
Enable metrics consistency to publish all the timer and timing metrics in milliseconds.:DeprecationWarning" + ) + ) @pytest.hookimpl(hookwrapper=True, trylast=True) def pytest_sessionfinish(self, session: pytest.Session, exitstatus: int): diff --git a/tests/core/test_otel_logger.py b/tests/core/test_otel_logger.py index 6cba116f652b9..d5697e585b45c 100644 --- a/tests/core/test_otel_logger.py +++ b/tests/core/test_otel_logger.py @@ -25,6 +25,7 @@ from opentelemetry.metrics import MeterProvider from airflow.exceptions import InvalidStatsNameException +from airflow.metrics import otel_logger, protocols from airflow.metrics.otel_logger import ( OTEL_NAME_MAX_LENGTH, UP_DOWN_COUNTERS, @@ -234,12 +235,22 @@ def test_gauge_value_is_correct(self, name): assert self.map[full_name(name)].value == 1 - def test_timing_new_metric(self, name): - self.stats.timing(name, dt=123) + @pytest.mark.parametrize( + "metrics_consistency_on", + [True, False], + ) + def test_timing_new_metric(self, metrics_consistency_on, name): + import datetime + + otel_logger.metrics_consistency_on = metrics_consistency_on + + self.stats.timing(name, dt=datetime.timedelta(seconds=123)) self.meter.get_meter().create_observable_gauge.assert_called_once_with( name=full_name(name), callbacks=ANY ) + expected_value = 123000.0 if metrics_consistency_on else 123 + assert self.map[full_name(name)].value == expected_value def test_timing_new_metric_with_tags(self, name): tags = {"hello": "world"} @@ -265,13 +276,19 @@ def test_timing_existing_metric(self, name): # time.perf_count() is called once to get the starting timestamp and again # to get the end timestamp. timer() should return the difference as a float. + @pytest.mark.parametrize( + "metrics_consistency_on", + [True, False], + ) @mock.patch.object(time, "perf_counter", side_effect=[0.0, 3.14]) - def test_timer_with_name_returns_float_and_stores_value(self, mock_time, name): + def test_timer_with_name_returns_float_and_stores_value(self, mock_time, metrics_consistency_on, name): + protocols.metrics_consistency_on = metrics_consistency_on with self.stats.timer(name) as timer: pass assert isinstance(timer.duration, float) - assert timer.duration == 3.14 + expected_duration = 3140.0 if metrics_consistency_on else 3.14 + assert timer.duration == expected_duration assert mock_time.call_count == 2 self.meter.get_meter().create_observable_gauge.assert_called_once_with( name=full_name(name), callbacks=ANY diff --git a/tests/core/test_stats.py b/tests/core/test_stats.py index 902a0ed0037f5..5127b95927a87 100644 --- a/tests/core/test_stats.py +++ b/tests/core/test_stats.py @@ -20,6 +20,7 @@ import importlib import logging import re +import time from unittest import mock from unittest.mock import Mock @@ -28,6 +29,7 @@ import airflow from airflow.exceptions import AirflowConfigException, InvalidStatsNameException, RemovedInAirflow3Warning +from airflow.metrics import datadog_logger, protocols from airflow.metrics.datadog_logger import SafeDogStatsdLogger from airflow.metrics.statsd_logger import SafeStatsdLogger from airflow.metrics.validators import ( @@ -224,24 +226,44 @@ def test_does_send_stats_using_dogstatsd_when_statsd_and_dogstatsd_both_on(self) metric="empty_key", sample_rate=1, tags=[], value=1 ) - def test_timer(self): - with self.dogstatsd.timer("empty_timer"): + @pytest.mark.parametrize( + "metrics_consistency_on", + [True, False], + ) + @mock.patch.object(time, "perf_counter", side_effect=[0.0, 100.0]) + def test_timer(self, time_mock, metrics_consistency_on): + 
protocols.metrics_consistency_on = metrics_consistency_on + + with self.dogstatsd.timer("empty_timer") as timer: pass self.dogstatsd_client.timed.assert_called_once_with("empty_timer", tags=[]) + expected_duration = 100.0 + if metrics_consistency_on: + expected_duration = 1000.0 * 100.0 + assert expected_duration == timer.duration + assert time_mock.call_count == 2 def test_empty_timer(self): with self.dogstatsd.timer(): pass self.dogstatsd_client.timed.assert_not_called() - def test_timing(self): + @pytest.mark.parametrize( + "metrics_consistency_on", + [True, False], + ) + def test_timing(self, metrics_consistency_on): import datetime + datadog_logger.metrics_consistency_on = metrics_consistency_on + self.dogstatsd.timing("empty_timer", 123) self.dogstatsd_client.timing.assert_called_once_with(metric="empty_timer", value=123, tags=[]) self.dogstatsd.timing("empty_timer", datetime.timedelta(seconds=123)) - self.dogstatsd_client.timing.assert_called_with(metric="empty_timer", value=123.0, tags=[]) + self.dogstatsd_client.timing.assert_called_with( + metric="empty_timer", value=123000.0 if metrics_consistency_on else 123.0, tags=[] + ) def test_gauge(self): self.dogstatsd.gauge("empty", 123) From ceb1dc371b4e9470c5485190c9fc4fab71bc9e56 Mon Sep 17 00:00:00 2001 From: ipsatrivedi <45980322+ipsatrivedi@users.noreply.github.com> Date: Wed, 11 Sep 2024 14:58:44 -0500 Subject: [PATCH 041/349] Updating Tekmetric URL in Airflow User's List (#42179) * Adding Tekmetric to Users list * Removing extra space --------- Co-authored-by: Ipsa Trivedi <> --- INTHEWILD.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/INTHEWILD.md b/INTHEWILD.md index 979598e5c2fd0..580a8f4d014c5 100644 --- a/INTHEWILD.md +++ b/INTHEWILD.md @@ -468,7 +468,7 @@ Currently, **officially** using Airflow: 1. [Talkdesk](https://www.talkdesk.com) 1. [Tapsi](https://tapsi.ir/) 1. [TEK](https://www.tek.fi/en) [[@telac](https://github.com/telac)] -1. [Tekmetric] (https://www.tekmetric.com/) +1. [Tekmetric](https://www.tekmetric.com/) 1. [Telefonica Innovation Alpha](https://www.alpha.company/) [[@Alpha-Health](https://github.com/Alpha-health)] 1. [Telia Company](https://www.teliacompany.com/en) 1. [Ternary Data](https://ternarydata.com/) [[@mhousley](https://github.com/mhousley), [@JoeReis](https://github.com/JoeReis)] From 37ab5aff23a3fcb3f9fd83c35fd67518a092a5b3 Mon Sep 17 00:00:00 2001 From: vfeldsher <127131870+vfeldsher@users.noreply.github.com> Date: Wed, 11 Sep 2024 15:59:13 -0400 Subject: [PATCH 042/349] Added new user under Bloomberg to INTHEWILD.md (#42178) --- INTHEWILD.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/INTHEWILD.md b/INTHEWILD.md index 580a8f4d014c5..1cdd5e5dc16ad 100644 --- a/INTHEWILD.md +++ b/INTHEWILD.md @@ -89,7 +89,7 @@ Currently, **officially** using Airflow: 1. [BlaBlaCar](https://www.blablacar.com) [[@puckel](https://github.com/puckel) & [@wmorin](https://github.com/wmorin)] 1. [Blacklane](https://www.blacklane.com) [[@serkef](https://github.com/serkef)] 1. [Bloc](https://www.bloc.io) [[@dpaola2](https://github.com/dpaola2)] -1. [Bloomberg](https://www.techatbloomberg.com) [[@skandala23] (https://github.com/skandala23)] +1. [Bloomberg](https://www.techatbloomberg.com) [[@skandala23] (https://github.com/skandala23) & [@vfeldsher](https://https://github.com/vfeldsher)] 1. [Bloomreach](https://www.bloomreach.com/) [[@neelborooah](https://github.com/neelborooah) & [@debodirno](https://github.com/debodirno) & [@ayushmnnit](https://github.com/ayushmnnit)] 1. 
[Blue Yonder](http://www.blue-yonder.com) [[@blue-yonder](https://github.com/blue-yonder)] 1. [BlueApron](https://www.blueapron.com) [[@jasonjho](https://github.com/jasonjho) & [@matthewdavidhauser](https://github.com/matthewdavidhauser)] From d2e07737e829271603792811afda9a8fd4ba2898 Mon Sep 17 00:00:00 2001 From: leoguzman <22300213+leoguzman@users.noreply.github.com> Date: Wed, 11 Sep 2024 20:00:15 +0000 Subject: [PATCH 043/349] reformat summary commands (#42171) Co-authored-by: Tzu-ping Chung Co-authored-by: Ryan Hatter <25823361+RNHTTR@users.noreply.github.com> --- contributing-docs/10_working_with_git.rst | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/contributing-docs/10_working_with_git.rst b/contributing-docs/10_working_with_git.rst index 6c8e2af6b173b..bc0b74a055501 100644 --- a/contributing-docs/10_working_with_git.rst +++ b/contributing-docs/10_working_with_git.rst @@ -191,11 +191,13 @@ Summary Useful when you understand the flow but don't remember the steps and want a quick reference. -``git fetch --all`` -``git merge-base my-branch apache/main`` -``git checkout my-branch`` -``git rebase HASH --onto apache/main`` -``git push --force-with-lease`` +.. code-block:: console + + git fetch --all + git merge-base my-branch apache/main + git checkout my-branch + git rebase HASH --onto apache/main + git push --force-with-lease ------- From e5711b433837a04c5ffc3ebf4222a9f6b13069bf Mon Sep 17 00:00:00 2001 From: Jens Scheffler <95105677+jscheffl@users.noreply.github.com> Date: Thu, 12 Sep 2024 00:27:56 +0200 Subject: [PATCH 044/349] Adding bare/empty provider package for AIP-69 as starting point (#42046) --- .../airflow_providers_bug_report.yml | 1 + INSTALL | 2 +- airflow/providers/edge/CHANGELOG.rst | 45 +++++++++ airflow/providers/edge/__init__.py | 39 ++++++++ airflow/providers/edge/provider.yaml | 96 +++++++++++++++++++ .../12_airflow_dependencies_and_extras.rst | 2 +- dev/breeze/doc/images/output_build-docs.svg | 8 +- dev/breeze/doc/images/output_build-docs.txt | 2 +- ...release-management_add-back-references.svg | 8 +- ...release-management_add-back-references.txt | 2 +- ...output_release-management_publish-docs.svg | 8 +- ...output_release-management_publish-docs.txt | 2 +- ...t_sbom_generate-providers-requirements.svg | 4 +- ...t_sbom_generate-providers-requirements.txt | 2 +- .../src/airflow_breeze/global_constants.py | 4 +- .../changelog.rst | 25 +++++ .../apache-airflow-providers-edge/commits.rst | 34 +++++++ .../configurations-ref.rst | 18 ++++ docs/apache-airflow-providers-edge/index.rst | 85 ++++++++++++++++ .../installing-providers-from-sources.rst | 18 ++++ .../security.rst | 18 ++++ docs/apache-airflow/extra-packages-ref.rst | 4 + generated/provider_dependencies.json | 11 +++ pyproject.toml | 2 +- tests/providers/edge/__init__.py | 17 ++++ 25 files changed, 434 insertions(+), 23 deletions(-) create mode 100644 airflow/providers/edge/CHANGELOG.rst create mode 100644 airflow/providers/edge/__init__.py create mode 100644 airflow/providers/edge/provider.yaml create mode 100644 docs/apache-airflow-providers-edge/changelog.rst create mode 100644 docs/apache-airflow-providers-edge/commits.rst create mode 100644 docs/apache-airflow-providers-edge/configurations-ref.rst create mode 100644 docs/apache-airflow-providers-edge/index.rst create mode 100644 docs/apache-airflow-providers-edge/installing-providers-from-sources.rst create mode 100644 docs/apache-airflow-providers-edge/security.rst create mode 100644 tests/providers/edge/__init__.py diff 
--git a/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml b/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml index 4cd0cad19fc7b..707a3dabeaf24 100644 --- a/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml +++ b/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml @@ -58,6 +58,7 @@ body: - dingding - discord - docker + - edge - elasticsearch - exasol - fab diff --git a/INSTALL b/INSTALL index 8d81910f071c7..d9ae8088ee30e 100644 --- a/INSTALL +++ b/INSTALL @@ -274,7 +274,7 @@ airbyte, alibaba, amazon, apache.beam, apache.cassandra, apache.drill, apache.dr apache.hdfs, apache.hive, apache.iceberg, apache.impala, apache.kafka, apache.kylin, apache.livy, apache.pig, apache.pinot, apache.spark, apprise, arangodb, asana, atlassian.jira, celery, cloudant, cncf.kubernetes, cohere, common.compat, common.io, common.sql, databricks, datadog, dbt.cloud, -dingding, discord, docker, elasticsearch, exasol, fab, facebook, ftp, github, google, grpc, +dingding, discord, docker, edge, elasticsearch, exasol, fab, facebook, ftp, github, google, grpc, hashicorp, http, imap, influxdb, jdbc, jenkins, microsoft.azure, microsoft.mssql, microsoft.psrp, microsoft.winrm, mongo, mysql, neo4j, odbc, openai, openfaas, openlineage, opensearch, opsgenie, oracle, pagerduty, papermill, pgvector, pinecone, postgres, presto, qdrant, redis, salesforce, diff --git a/airflow/providers/edge/CHANGELOG.rst b/airflow/providers/edge/CHANGELOG.rst new file mode 100644 index 0000000000000..57309ff6cde66 --- /dev/null +++ b/airflow/providers/edge/CHANGELOG.rst @@ -0,0 +1,45 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + +``apache-airflow-providers-edge`` + + +Changelog +--------- + +0.1.0pre0 +......... + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + +0.1.0 +..... + +|experimental| + +Initial version of the provider. + +.. note:: + This provider is currently experimental diff --git a/airflow/providers/edge/__init__.py b/airflow/providers/edge/__init__.py new file mode 100644 index 0000000000000..cf33596f63253 --- /dev/null +++ b/airflow/providers/edge/__init__.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE +# OVERWRITTEN WHEN PREPARING DOCUMENTATION FOR THE PACKAGES. +# +# IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE +# `PROVIDER__INIT__PY_TEMPLATE.py.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY +# +from __future__ import annotations + +import packaging.version + +from airflow import __version__ as airflow_version + +__all__ = ["__version__"] + +__version__ = "0.1.0pre0" + +if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse( + "2.10.0" +): + raise RuntimeError( + f"The package `apache-airflow-providers-edge:{__version__}` needs Apache Airflow 2.10.0+" + ) diff --git a/airflow/providers/edge/provider.yaml b/airflow/providers/edge/provider.yaml new file mode 100644 index 0000000000000..cb775ee7cc7e4 --- /dev/null +++ b/airflow/providers/edge/provider.yaml @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# yaml-language-server: $schema=../../provider.yaml.schema.json +--- +package-name: apache-airflow-providers-edge +name: Edge Executor +description: | + Handle edge workers on remote sites via HTTP(s) connection and orchestrates work over distributed sites + +state: not-ready +source-date-epoch: 1720863625 +# note that those versions are maintained by release manager - do not update them manually +versions: + - 0.1.0pre0 + +dependencies: + - apache-airflow>=2.10.0 + - pydantic>=2.3.0 + +config: + edge: + description: | + This section only applies if you are using the EdgeExecutor in + ``[core]`` section above + options: + api_enabled: + description: | + Flag if the plugin endpoint is enabled to serve Edge Workers. + version_added: ~ + type: boolean + example: "True" + default: "False" + api_url: + description: | + URL endpoint on which the Airflow code edge API is accessible from edge worker. + version_added: ~ + type: string + example: https://airflow.hosting.org/edge_worker/v1/rpcapi + default: ~ + job_poll_interval: + description: | + Edge Worker currently polls for new jobs via HTTP. This parameter defines the number + of seconds it should sleep between polls for new jobs. 
+ Job polling only happens if the Edge Worker seeks for new work. Not if busy. + version_added: ~ + type: integer + example: "5" + default: "5" + heartbeat_interval: + description: | + Edge Worker continuously reports status to the central site. This parameter defines + how often a status with heartbeat should be sent. + During heartbeat status is reported as well as it is checked if a running task is to be terminated. + version_added: ~ + type: integer + example: "10" + default: "30" + worker_concurrency: + description: | + The concurrency that will be used when starting workers with the + ``airflow edge worker`` command. This defines the number of task instances that + a worker will take, so size up your workers based on the resources on + your worker box and the nature of your tasks + version_added: ~ + type: integer + example: ~ + default: "8" + job_success_purge: + description: | + Minutes after which successful jobs for EdgeExecutor are purged from database + version_added: ~ + type: integer + example: ~ + default: "5" + job_fail_purge: + description: | + Minutes after which failed jobs for EdgeExecutor are purged from database + version_added: ~ + type: integer + example: ~ + default: "60" diff --git a/contributing-docs/12_airflow_dependencies_and_extras.rst b/contributing-docs/12_airflow_dependencies_and_extras.rst index 70f30fa0b7a7a..18fdf30e6ee24 100644 --- a/contributing-docs/12_airflow_dependencies_and_extras.rst +++ b/contributing-docs/12_airflow_dependencies_and_extras.rst @@ -182,7 +182,7 @@ airbyte, alibaba, amazon, apache.beam, apache.cassandra, apache.drill, apache.dr apache.hdfs, apache.hive, apache.iceberg, apache.impala, apache.kafka, apache.kylin, apache.livy, apache.pig, apache.pinot, apache.spark, apprise, arangodb, asana, atlassian.jira, celery, cloudant, cncf.kubernetes, cohere, common.compat, common.io, common.sql, databricks, datadog, dbt.cloud, -dingding, discord, docker, elasticsearch, exasol, fab, facebook, ftp, github, google, grpc, +dingding, discord, docker, edge, elasticsearch, exasol, fab, facebook, ftp, github, google, grpc, hashicorp, http, imap, influxdb, jdbc, jenkins, microsoft.azure, microsoft.mssql, microsoft.psrp, microsoft.winrm, mongo, mysql, neo4j, odbc, openai, openfaas, openlineage, opensearch, opsgenie, oracle, pagerduty, papermill, pgvector, pinecone, postgres, presto, qdrant, redis, salesforce, diff --git a/dev/breeze/doc/images/output_build-docs.svg b/dev/breeze/doc/images/output_build-docs.svg index c56344abf3452..8fb52ec33922c 100644 --- a/dev/breeze/doc/images/output_build-docs.svg +++ b/dev/breeze/doc/images/output_build-docs.svg @@ -193,10 +193,10 @@ apache.cassandra | apache.drill | apache.druid | apache.flink | apache.hdfs | apache.hive | apache.iceberg |           apache.impala | apache.kafka | apache.kylin | apache.livy | apache.pig | apache.pinot | apache.spark | apprise |       arangodb | asana | atlassian.jira | celery | cloudant | cncf.kubernetes | cohere | common.compat | common.io |         -common.sql | databricks | datadog | dbt.cloud | dingding | discord | docker | docker-stack | elasticsearch | exasol |  -fab | facebook | ftp | github | google | grpc | hashicorp | helm-chart | http | imap | influxdb | jdbc | jenkins |     -microsoft.azure | microsoft.mssql | microsoft.psrp | microsoft.winrm | mongo | mysql | neo4j | odbc | openai |         -openfaas | openlineage | opensearch | opsgenie | oracle | pagerduty | papermill | pgvector | pinecone | postgres |     +common.sql | databricks | datadog | dbt.cloud | dingding | 
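Not part of the patch: a minimal sketch of how the new ``[edge]`` options declared in ``provider.yaml`` above could be read once the provider is installed, assuming they are exposed through Airflow's standard configuration API (``airflow.configuration.conf``) like other provider-contributed settings. The fallback values simply mirror the defaults declared in ``provider.yaml``.

```python
# Illustrative only -- not part of this patch. Reads the [edge] options added
# in provider.yaml via Airflow's standard config API; the fallbacks mirror the
# declared defaults so the snippet also behaves sensibly if the section is absent.
from airflow.configuration import conf

api_enabled = conf.getboolean("edge", "api_enabled", fallback=False)
api_url = conf.get("edge", "api_url", fallback=None)
job_poll_interval = conf.getint("edge", "job_poll_interval", fallback=5)
heartbeat_interval = conf.getint("edge", "heartbeat_interval", fallback=30)
worker_concurrency = conf.getint("edge", "worker_concurrency", fallback=8)

if api_enabled and api_url:
    print(f"Edge worker would poll {api_url} every {job_poll_interval}s")
```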
discord | docker | docker-stack | edge | elasticsearch |    +exasol | fab | facebook | ftp | github | google | grpc | hashicorp | helm-chart | http | imap | influxdb | jdbc |      +jenkins | microsoft.azure | microsoft.mssql | microsoft.psrp | microsoft.winrm | mongo | mysql | neo4j | odbc | openai +openfaas | openlineage | opensearch | opsgenie | oracle | pagerduty | papermill | pgvector | pinecone | postgres |   presto | qdrant | redis | salesforce | samba | segment | sendgrid | sftp | singularity | slack | smtp | snowflake |    sqlite | ssh | tableau | telegram | teradata | trino | vertica | weaviate | yandex | ydb | zendesk]...                 diff --git a/dev/breeze/doc/images/output_build-docs.txt b/dev/breeze/doc/images/output_build-docs.txt index ab98306e86eec..8a3bc4349b459 100644 --- a/dev/breeze/doc/images/output_build-docs.txt +++ b/dev/breeze/doc/images/output_build-docs.txt @@ -1 +1 @@ -ab4b8155064e482877babe3eeee0008e +767cdd5028d6ac43dd9f2804e0501ee8 diff --git a/dev/breeze/doc/images/output_release-management_add-back-references.svg b/dev/breeze/doc/images/output_release-management_add-back-references.svg index e72a67115d646..65297fbbc1dcc 100644 --- a/dev/breeze/doc/images/output_release-management_add-back-references.svg +++ b/dev/breeze/doc/images/output_release-management_add-back-references.svg @@ -141,10 +141,10 @@ apache.cassandra | apache.drill | apache.druid | apache.flink | apache.hdfs | apache.hive | apache.iceberg |           apache.impala | apache.kafka | apache.kylin | apache.livy | apache.pig | apache.pinot | apache.spark | apprise |       arangodb | asana | atlassian.jira | celery | cloudant | cncf.kubernetes | cohere | common.compat | common.io |         -common.sql | databricks | datadog | dbt.cloud | dingding | discord | docker | docker-stack | elasticsearch | exasol |  -fab | facebook | ftp | github | google | grpc | hashicorp | helm-chart | http | imap | influxdb | jdbc | jenkins |     -microsoft.azure | microsoft.mssql | microsoft.psrp | microsoft.winrm | mongo | mysql | neo4j | odbc | openai |         -openfaas | openlineage | opensearch | opsgenie | oracle | pagerduty | papermill | pgvector | pinecone | postgres |     +common.sql | databricks | datadog | dbt.cloud | dingding | discord | docker | docker-stack | edge | elasticsearch |    +exasol | fab | facebook | ftp | github | google | grpc | hashicorp | helm-chart | http | imap | influxdb | jdbc |      +jenkins | microsoft.azure | microsoft.mssql | microsoft.psrp | microsoft.winrm | mongo | mysql | neo4j | odbc | openai +openfaas | openlineage | opensearch | opsgenie | oracle | pagerduty | papermill | pgvector | pinecone | postgres |   presto | qdrant | redis | salesforce | samba | segment | sendgrid | sftp | singularity | slack | smtp | snowflake |    sqlite | ssh | tableau | telegram | teradata | trino | vertica | weaviate | yandex | ydb | zendesk]...                 
diff --git a/dev/breeze/doc/images/output_release-management_add-back-references.txt b/dev/breeze/doc/images/output_release-management_add-back-references.txt index 0855bdb46a38b..c198abfcb81ac 100644 --- a/dev/breeze/doc/images/output_release-management_add-back-references.txt +++ b/dev/breeze/doc/images/output_release-management_add-back-references.txt @@ -1 +1 @@ -00ce41a98e5da484704dfe0b988d953b +743a6e2ad304078a210877279db4546a diff --git a/dev/breeze/doc/images/output_release-management_publish-docs.svg b/dev/breeze/doc/images/output_release-management_publish-docs.svg index 5f65e2aef2ba5..95d455d7b039c 100644 --- a/dev/breeze/doc/images/output_release-management_publish-docs.svg +++ b/dev/breeze/doc/images/output_release-management_publish-docs.svg @@ -198,10 +198,10 @@ apache.cassandra | apache.drill | apache.druid | apache.flink | apache.hdfs | apache.hive | apache.iceberg |           apache.impala | apache.kafka | apache.kylin | apache.livy | apache.pig | apache.pinot | apache.spark | apprise |       arangodb | asana | atlassian.jira | celery | cloudant | cncf.kubernetes | cohere | common.compat | common.io |         -common.sql | databricks | datadog | dbt.cloud | dingding | discord | docker | docker-stack | elasticsearch | exasol |  -fab | facebook | ftp | github | google | grpc | hashicorp | helm-chart | http | imap | influxdb | jdbc | jenkins |     -microsoft.azure | microsoft.mssql | microsoft.psrp | microsoft.winrm | mongo | mysql | neo4j | odbc | openai |         -openfaas | openlineage | opensearch | opsgenie | oracle | pagerduty | papermill | pgvector | pinecone | postgres |     +common.sql | databricks | datadog | dbt.cloud | dingding | discord | docker | docker-stack | edge | elasticsearch |    +exasol | fab | facebook | ftp | github | google | grpc | hashicorp | helm-chart | http | imap | influxdb | jdbc |      +jenkins | microsoft.azure | microsoft.mssql | microsoft.psrp | microsoft.winrm | mongo | mysql | neo4j | odbc | openai +openfaas | openlineage | opensearch | opsgenie | oracle | pagerduty | papermill | pgvector | pinecone | postgres |   presto | qdrant | redis | salesforce | samba | segment | sendgrid | sftp | singularity | slack | smtp | snowflake |    sqlite | ssh | tableau | telegram | teradata | trino | vertica | weaviate | yandex | ydb | zendesk]...                 
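Not part of the patch: a small sketch of the Airflow-version guard from the new ``airflow/providers/edge/__init__.py`` earlier in this patch, which is also why ``edge`` is added to ``remove-providers`` for the 2.8.x/2.9.x compatibility rows in ``global_constants.py`` below. The helper name is made up for illustration.

```python
# Illustrative only -- not part of this patch. Mirrors the guard in
# airflow/providers/edge/__init__.py: the provider refuses to import on
# Airflow < 2.10.0, hence its removal from the 2.8.x/2.9.x compat rows.
import packaging.version


def edge_provider_supported(airflow_version: str) -> bool:
    # base_version drops pre-release/dev suffixes such as "2.10.0.dev0"
    base = packaging.version.parse(airflow_version).base_version
    return packaging.version.parse(base) >= packaging.version.parse("2.10.0")


assert not edge_provider_supported("2.8.4")   # removed via "remove-providers"
assert not edge_provider_supported("2.9.3")   # removed via "remove-providers"
assert edge_provider_supported("2.10.1")
```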
diff --git a/dev/breeze/doc/images/output_release-management_publish-docs.txt b/dev/breeze/doc/images/output_release-management_publish-docs.txt index 7d1f76c7aabd1..1d94b684f6a01 100644 --- a/dev/breeze/doc/images/output_release-management_publish-docs.txt +++ b/dev/breeze/doc/images/output_release-management_publish-docs.txt @@ -1 +1 @@ -117be5534ac6cb8d193650db5c57c869 +d360bdbf659b84e202be9c8ac76610e5 diff --git a/dev/breeze/doc/images/output_sbom_generate-providers-requirements.svg b/dev/breeze/doc/images/output_sbom_generate-providers-requirements.svg index 00d47b37e579a..9d57216d0fc43 100644 --- a/dev/breeze/doc/images/output_sbom_generate-providers-requirements.svg +++ b/dev/breeze/doc/images/output_sbom_generate-providers-requirements.svg @@ -190,8 +190,8 @@ apache.flink | apache.hdfs | apache.hive | apache.iceberg | apache.impala | apache.kafka |     apache.kylin | apache.livy | apache.pig | apache.pinot | apache.spark | apprise | arangodb |   asana | atlassian.jira | celery | cloudant | cncf.kubernetes | cohere | common.compat |        -common.io | common.sql | databricks | datadog | dbt.cloud | dingding | discord | docker |      -elasticsearch | exasol | fab | facebook | ftp | github | google | grpc | hashicorp | http |    +common.io | common.sql | databricks | datadog | dbt.cloud | dingding | discord | docker | edge +| elasticsearch | exasol | fab | facebook | ftp | github | google | grpc | hashicorp | http |  imap | influxdb | jdbc | jenkins | microsoft.azure | microsoft.mssql | microsoft.psrp |        microsoft.winrm | mongo | mysql | neo4j | odbc | openai | openfaas | openlineage | opensearch  | opsgenie | oracle | pagerduty | papermill | pgvector | pinecone | postgres | presto | qdrant diff --git a/dev/breeze/doc/images/output_sbom_generate-providers-requirements.txt b/dev/breeze/doc/images/output_sbom_generate-providers-requirements.txt index 130177a996118..f913a56a5b80c 100644 --- a/dev/breeze/doc/images/output_sbom_generate-providers-requirements.txt +++ b/dev/breeze/doc/images/output_sbom_generate-providers-requirements.txt @@ -1 +1 @@ -323f0b61bd2248df08b5ce37ea16487d +483ab08cf0222a2966510cd93945537f diff --git a/dev/breeze/src/airflow_breeze/global_constants.py b/dev/breeze/src/airflow_breeze/global_constants.py index 4ee452af6d339..d3d4ad20ab061 100644 --- a/dev/breeze/src/airflow_breeze/global_constants.py +++ b/dev/breeze/src/airflow_breeze/global_constants.py @@ -509,13 +509,13 @@ def get_airflow_extras(): { "python-version": "3.8", "airflow-version": "2.8.4", - "remove-providers": "cloudant fab", + "remove-providers": "cloudant fab edge", "run-tests": "true", }, { "python-version": "3.8", "airflow-version": "2.9.3", - "remove-providers": "cloudant", + "remove-providers": "cloudant edge", "run-tests": "true", }, { diff --git a/docs/apache-airflow-providers-edge/changelog.rst b/docs/apache-airflow-providers-edge/changelog.rst new file mode 100644 index 0000000000000..4a87ccc753b07 --- /dev/null +++ b/docs/apache-airflow-providers-edge/changelog.rst @@ -0,0 +1,25 @@ + + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. 
Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + .. NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE + OVERWRITTEN WHEN PREPARING PACKAGES. + + .. IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE + `PROVIDER_CHANGELOG_TEMPLATE.rst.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY + +.. include:: ../../airflow/providers/edge/CHANGELOG.rst diff --git a/docs/apache-airflow-providers-edge/commits.rst b/docs/apache-airflow-providers-edge/commits.rst new file mode 100644 index 0000000000000..a1e7bd25ab0c1 --- /dev/null +++ b/docs/apache-airflow-providers-edge/commits.rst @@ -0,0 +1,34 @@ + + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + .. NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE + OVERWRITTEN WHEN PREPARING PACKAGES. + + .. IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE + `PROVIDER_COMMITS_TEMPLATE.rst.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY + + .. THE REMAINDER OF THE FILE IS AUTOMATICALLY GENERATED. IT WILL BE OVERWRITTEN AT RELEASE TIME! + +Package apache-airflow-providers-edge +------------------------------------- + +Handle Edge Workers via HTTP(s) connection and distribute work over remote sites + + +This is detailed commit list of changes for versions provider package: ``edge``. +For high-level changelog, see :doc:`package information including changelog `. diff --git a/docs/apache-airflow-providers-edge/configurations-ref.rst b/docs/apache-airflow-providers-edge/configurations-ref.rst new file mode 100644 index 0000000000000..5885c9d91b6e8 --- /dev/null +++ b/docs/apache-airflow-providers-edge/configurations-ref.rst @@ -0,0 +1,18 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. 
include:: ../exts/includes/providers-configurations-ref.rst diff --git a/docs/apache-airflow-providers-edge/index.rst b/docs/apache-airflow-providers-edge/index.rst new file mode 100644 index 0000000000000..8b78170b34385 --- /dev/null +++ b/docs/apache-airflow-providers-edge/index.rst @@ -0,0 +1,85 @@ + + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +``apache-airflow-providers-edge`` +================================= + + +.. toctree:: + :hidden: + :maxdepth: 1 + :caption: Basics + + Home + Changelog + Security + + +.. toctree:: + :hidden: + :maxdepth: 1 + :caption: References + + Configuration + Python API <_api/airflow/providers/edge/index> + PyPI Repository + Installing from sources + +.. THE REMAINDER OF THE FILE IS AUTOMATICALLY GENERATED. IT WILL BE OVERWRITTEN AT RELEASE TIME! + + +.. toctree:: + :hidden: + :maxdepth: 1 + :caption: Commits + + Detailed list of commits + + +apache-airflow-providers-edge package +------------------------------------- + +Handle Edge Workers via HTTP(s) connection and distribute work over remote sites + + +Release: 0.1.0pre0 + +Provider package +---------------- + +This package is for the ``edge`` provider. +All classes for this package are included in the ``airflow.providers.edge`` python package. + +Installation +------------ + +You can install this package on top of an existing Airflow 2 installation via +``pip install apache-airflow-providers-edge``. +For the minimum Airflow version supported, see ``Requirements`` below. + +Requirements +------------ + +The minimum Apache Airflow version supported by this provider package is ``2.10.0``. + +================== ================== +PIP package Version required +================== ================== +``apache-airflow`` ``>=2.10.0`` +``pydantic`` +================== ================== diff --git a/docs/apache-airflow-providers-edge/installing-providers-from-sources.rst b/docs/apache-airflow-providers-edge/installing-providers-from-sources.rst new file mode 100644 index 0000000000000..b4e730f4ff21a --- /dev/null +++ b/docs/apache-airflow-providers-edge/installing-providers-from-sources.rst @@ -0,0 +1,18 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. 
See the License for the + specific language governing permissions and limitations + under the License. + +.. include:: ../exts/includes/installing-providers-from-sources.rst diff --git a/docs/apache-airflow-providers-edge/security.rst b/docs/apache-airflow-providers-edge/security.rst new file mode 100644 index 0000000000000..afa13dac6fc9b --- /dev/null +++ b/docs/apache-airflow-providers-edge/security.rst @@ -0,0 +1,18 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. include:: ../exts/includes/security.rst diff --git a/docs/apache-airflow/extra-packages-ref.rst b/docs/apache-airflow/extra-packages-ref.rst index a96220f63f809..4ccccbdfb7380 100644 --- a/docs/apache-airflow/extra-packages-ref.rst +++ b/docs/apache-airflow/extra-packages-ref.rst @@ -51,6 +51,8 @@ python dependencies for the provided package. +---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+ | cgroups | ``pip install 'apache-airflow[cgroups]'`` | Needed To use CgroupTaskRunner | +---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+ +| edge | ``pip install 'apache-airflow[edge]'`` | Connect Edge Workers via HTTP to the scheduler | ++---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+ | github-enterprise | ``pip install 'apache-airflow[github-enterprise]'`` | GitHub Enterprise auth backend | +---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+ | google-auth | ``pip install 'apache-airflow[google-auth]'`` | Google auth backend | @@ -261,6 +263,8 @@ Some of those enable Airflow to use executors to run tasks with them - other tha +---------------------+-----------------------------------------------------+-----------------------------------------------------------------+----------------------------------------------+ | docker | ``pip install 'apache-airflow[docker]'`` | Docker hooks and operators | | +---------------------+-----------------------------------------------------+-----------------------------------------------------------------+----------------------------------------------+ +| edge | ``pip install 'apache-airflow[edge]'`` | Connect Edge Workers via HTTP to the scheduler | EdgeExecutor | ++---------------------+-----------------------------------------------------+-----------------------------------------------------------------+----------------------------------------------+ | elasticsearch | ``pip install 'apache-airflow[elasticsearch]'`` | Elasticsearch hooks and Log Handler | | 
+---------------------+-----------------------------------------------------+-----------------------------------------------------------------+----------------------------------------------+ | exasol | ``pip install 'apache-airflow[exasol]'`` | Exasol hooks and operators | | diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index a520c4c07ea43..0023c18cd0c08 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -520,6 +520,17 @@ "excluded-python-versions": [], "state": "ready" }, + "edge": { + "deps": [ + "apache-airflow>=2.10.0", + "pydantic>=2.3.0" + ], + "devel-deps": [], + "plugins": [], + "cross-providers-deps": [], + "excluded-python-versions": [], + "state": "not-ready" + }, "elasticsearch": { "deps": [ "apache-airflow-providers-common-sql>=1.17.0", diff --git a/pyproject.toml b/pyproject.toml index cd51540529bcd..caddac66ad18d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -135,7 +135,7 @@ dynamic = ["version", "optional-dependencies", "dependencies"] # apache.hdfs, apache.hive, apache.iceberg, apache.impala, apache.kafka, apache.kylin, apache.livy, # apache.pig, apache.pinot, apache.spark, apprise, arangodb, asana, atlassian.jira, celery, cloudant, # cncf.kubernetes, cohere, common.compat, common.io, common.sql, databricks, datadog, dbt.cloud, -# dingding, discord, docker, elasticsearch, exasol, fab, facebook, ftp, github, google, grpc, +# dingding, discord, docker, edge, elasticsearch, exasol, fab, facebook, ftp, github, google, grpc, # hashicorp, http, imap, influxdb, jdbc, jenkins, microsoft.azure, microsoft.mssql, microsoft.psrp, # microsoft.winrm, mongo, mysql, neo4j, odbc, openai, openfaas, openlineage, opensearch, opsgenie, # oracle, pagerduty, papermill, pgvector, pinecone, postgres, presto, qdrant, redis, salesforce, diff --git a/tests/providers/edge/__init__.py b/tests/providers/edge/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/tests/providers/edge/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. From fa3bfd46e2f7ee11ff00f0b43dd10630d5cabcb7 Mon Sep 17 00:00:00 2001 From: Brent Bovenzi Date: Wed, 11 Sep 2024 19:24:25 -0400 Subject: [PATCH 045/349] do not camelcase xcom entries (#42182) --- airflow/www/static/js/api/index.ts | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/airflow/www/static/js/api/index.ts b/airflow/www/static/js/api/index.ts index 487197d608761..a1442e541798f 100644 --- a/airflow/www/static/js/api/index.ts +++ b/airflow/www/static/js/api/index.ts @@ -65,9 +65,14 @@ axios.interceptors.request.use((config) => { return config; }); -axios.interceptors.response.use((res: AxiosResponse) => - res.data ? 
camelcaseKeys(res.data, { deep: true }) : res -); +// Do not camelCase xCom entry results +axios.interceptors.response.use((res: AxiosResponse) => { + const stopPaths = []; + if (res.config.url?.includes("/xcomEntries/")) { + stopPaths.push("value"); + } + return res.data ? camelcaseKeys(res.data, { deep: true, stopPaths }) : res; +}); axios.defaults.headers.common.Accept = "application/json"; From aa21636c1f59cf33d417d59f35d7bf3f1fbeaecd Mon Sep 17 00:00:00 2001 From: Gopal Dirisala <39794726+dirrao@users.noreply.github.com> Date: Thu, 12 Sep 2024 05:15:07 +0530 Subject: [PATCH 046/349] Logging deprecated configuration removed (#42100) * Deprecated logging configuration removed * news fragment added * Deprecated logging configuration removed * news fragment added --- airflow/configuration.py | 24 ------------------------ newsfragments/42100.significant.rst | 21 +++++++++++++++++++++ tests/core/test_configuration.py | 27 --------------------------- 3 files changed, 21 insertions(+), 51 deletions(-) create mode 100644 newsfragments/42100.significant.rst diff --git a/airflow/configuration.py b/airflow/configuration.py index f4bea46c5c0ba..4238d59054cf3 100644 --- a/airflow/configuration.py +++ b/airflow/configuration.py @@ -327,36 +327,12 @@ def sensitive_config_values(self) -> Set[tuple[str, str]]: # noqa: UP006 # DeprecationWarning will be issued and the old option will be used instead deprecated_options: dict[tuple[str, str], tuple[str, str, str]] = { ("celery", "worker_precheck"): ("core", "worker_precheck", "2.0.0"), - ("logging", "interleave_timestamp_parser"): ("core", "interleave_timestamp_parser", "2.6.1"), - ("logging", "base_log_folder"): ("core", "base_log_folder", "2.0.0"), - ("logging", "remote_logging"): ("core", "remote_logging", "2.0.0"), - ("logging", "remote_log_conn_id"): ("core", "remote_log_conn_id", "2.0.0"), - ("logging", "remote_base_log_folder"): ("core", "remote_base_log_folder", "2.0.0"), - ("logging", "encrypt_s3_logs"): ("core", "encrypt_s3_logs", "2.0.0"), - ("logging", "logging_level"): ("core", "logging_level", "2.0.0"), - ("logging", "fab_logging_level"): ("core", "fab_logging_level", "2.0.0"), - ("logging", "logging_config_class"): ("core", "logging_config_class", "2.0.0"), - ("logging", "colored_console_log"): ("core", "colored_console_log", "2.0.0"), - ("logging", "colored_log_format"): ("core", "colored_log_format", "2.0.0"), - ("logging", "colored_formatter_class"): ("core", "colored_formatter_class", "2.0.0"), - ("logging", "log_format"): ("core", "log_format", "2.0.0"), - ("logging", "simple_log_format"): ("core", "simple_log_format", "2.0.0"), - ("logging", "task_log_prefix_template"): ("core", "task_log_prefix_template", "2.0.0"), - ("logging", "log_filename_template"): ("core", "log_filename_template", "2.0.0"), - ("logging", "log_processor_filename_template"): ("core", "log_processor_filename_template", "2.0.0"), - ("logging", "dag_processor_manager_log_location"): ( - "core", - "dag_processor_manager_log_location", - "2.0.0", - ), - ("logging", "task_log_reader"): ("core", "task_log_reader", "2.0.0"), ("scheduler", "parsing_processes"): ("scheduler", "max_threads", "1.10.14"), ("operators", "default_queue"): ("celery", "default_queue", "2.1.0"), ("core", "hide_sensitive_var_conn_fields"): ("admin", "hide_sensitive_variable_fields", "2.1.0"), ("core", "sensitive_var_conn_names"): ("admin", "sensitive_variable_fields", "2.1.0"), ("core", "default_pool_task_slot_count"): ("core", "non_pooled_task_slot_count", "1.10.4"), ("core", 
"max_active_tasks_per_dag"): ("core", "dag_concurrency", "2.2.0"), - ("logging", "worker_log_server_port"): ("celery", "worker_log_server_port", "2.2.0"), ("api", "access_control_allow_origins"): ("api", "access_control_allow_origin", "2.2.0"), ("api", "auth_backends"): ("api", "auth_backend", "2.3.0"), ("database", "sql_alchemy_conn"): ("core", "sql_alchemy_conn", "2.3.0"), diff --git a/newsfragments/42100.significant.rst b/newsfragments/42100.significant.rst new file mode 100644 index 0000000000000..c256d575a01bf --- /dev/null +++ b/newsfragments/42100.significant.rst @@ -0,0 +1,21 @@ +Removed deprecated logging configuration. + + * Removed deprecated configuration ``interleave_timestamp_parser`` from ``core``. Please use ``interleave_timestamp_parser`` from ``logging`` instead. + * Removed deprecated configuration ``base_log_folder`` from ``core``. Please use ``base_log_folder`` from ``logging`` instead. + * Removed deprecated configuration ``remote_logging`` from ``core``. Please use ``remote_logging`` from ``logging`` instead. + * Removed deprecated configuration ``remote_log_conn_id`` from ``core``. Please use ``remote_log_conn_id`` from ``logging`` instead. + * Removed deprecated configuration ``remote_base_log_folder`` from ``core``. Please use ``remote_base_log_folder`` from ``logging`` instead. + * Removed deprecated configuration ``encrypt_s3_logs`` from ``core``. Please use ``encrypt_s3_logs`` from ``logging`` instead. + * Removed deprecated configuration ``logging_level`` from ``core``. Please use ``logging_level`` from ``logging`` instead. + * Removed deprecated configuration ``fab_logging_level`` from ``core``. Please use ``fab_logging_level`` from ``logging`` instead. + * Removed deprecated configuration ``logging_config_class`` from ``core``. Please use ``logging_config_class`` from ``logging`` instead. + * Removed deprecated configuration ``colored_console_log`` from ``core``. Please use ``colored_console_log`` from ``logging`` instead. + * Removed deprecated configuration ``colored_log_format`` from ``core``. Please use ``colored_log_format`` from ``logging`` instead. + * Removed deprecated configuration ``colored_formatter_class`` from ``core``. Please use ``colored_formatter_class`` from ``logging`` instead. + * Removed deprecated configuration ``log_format`` from ``core``. Please use ``log_format`` from ``logging`` instead. + * Removed deprecated configuration ``simple_log_format`` from ``core``. Please use ``simple_log_format`` from ``logging`` instead. + * Removed deprecated configuration ``task_log_prefix_template`` from ``core``. Please use ``task_log_prefix_template`` from ``logging`` instead. + * Removed deprecated configuration ``log_filename_template`` from ``core``. Please use ``log_filename_template`` from ``logging`` instead. + * Removed deprecated configuration ``log_processor_filename_template`` from ``core``. Please use ``log_processor_filename_template`` from ``logging`` instead. + * Removed deprecated configuration ``dag_processor_manager_log_location`` from ``core``. Please use ``dag_processor_manager_log_location`` from ``logging`` instead. + * Removed deprecated configuration ``task_log_reader`` from ``core``. Please use ``task_log_reader`` from ``logging`` instead. 
diff --git a/tests/core/test_configuration.py b/tests/core/test_configuration.py index d02e9b4b3cb82..af43daa303df8 100644 --- a/tests/core/test_configuration.py +++ b/tests/core/test_configuration.py @@ -956,33 +956,6 @@ def test_deprecated_options(self): with pytest.warns(DeprecationWarning), conf_vars({("celery", "celeryd_concurrency"): "99"}): assert conf.getint("celery", "worker_concurrency") == 99 - @conf_vars( - { - ("logging", "logging_level"): None, - ("core", "logging_level"): None, - } - ) - def test_deprecated_options_with_new_section(self): - # Guarantee we have a deprecated setting, so we test the deprecation - # lookup even if we remove this explicit fallback - with set_deprecated_options( - deprecated_options={("logging", "logging_level"): ("core", "logging_level", "2.0.0")} - ): - # Remove it so we are sure we use the right setting - conf.remove_option("core", "logging_level") - conf.remove_option("logging", "logging_level") - - with pytest.warns(DeprecationWarning): - with mock.patch.dict("os.environ", AIRFLOW__CORE__LOGGING_LEVEL="VALUE"): - assert conf.get("logging", "logging_level") == "VALUE" - - with pytest.warns(FutureWarning, match="Please update your `conf.get"): - with mock.patch.dict("os.environ", AIRFLOW__CORE__LOGGING_LEVEL="VALUE"): - assert conf.get("core", "logging_level") == "VALUE" - - with pytest.warns(DeprecationWarning), conf_vars({("core", "logging_level"): "VALUE"}): - assert conf.get("logging", "logging_level") == "VALUE" - @conf_vars( { ("celery", "result_backend"): None, From 1eabb9bd3c955591a08fe92964a1a4850a4d0544 Mon Sep 17 00:00:00 2001 From: GPK Date: Thu, 12 Sep 2024 05:02:01 +0100 Subject: [PATCH 047/349] Add template field tests to AWS operators part1 (#42183) * adding template_fields tests in operators --- .../amazon/aws/operators/test_athena.py | 4 ++ .../amazon/aws/operators/test_bedrock.py | 27 ++++++++ .../aws/operators/test_cloud_formation.py | 28 ++++++++ .../amazon/aws/operators/test_comprehend.py | 7 ++ .../amazon/aws/operators/test_datasync.py | 5 ++ .../amazon/aws/operators/test_dms.py | 64 +++++++++++++++++++ .../amazon/aws/operators/test_ec2.py | 51 +++++++++++++++ .../amazon/aws/operators/test_ecs.py | 33 ++++++++++ .../amazon/aws/operators/test_eks.py | 44 +++++++++++++ .../aws/operators/test_emr_add_steps.py | 10 +++ .../aws/operators/test_emr_containers.py | 4 ++ .../aws/operators/test_emr_create_job_flow.py | 4 ++ .../aws/operators/test_emr_modify_cluster.py | 4 ++ .../operators/test_emr_notebook_execution.py | 18 ++++++ .../aws/operators/test_emr_serverless.py | 58 +++++++++++++++++ .../operators/test_emr_terminate_job_flow.py | 11 ++++ .../amazon/aws/utils/test_template_fields.py | 28 ++++++++ 17 files changed, 400 insertions(+) create mode 100644 tests/providers/amazon/aws/utils/test_template_fields.py diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py index c132e6456f1d8..102d1fe31e5c1 100644 --- a/tests/providers/amazon/aws/operators/test_athena.py +++ b/tests/providers/amazon/aws/operators/test_athena.py @@ -39,6 +39,7 @@ from airflow.utils import timezone from airflow.utils.timezone import datetime from airflow.utils.types import DagRunType +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields TEST_DAG_ID = "unit_tests" DEFAULT_DATE = datetime(2018, 1, 1) @@ -397,3 +398,6 @@ def mock_get_table_metadata(CatalogName, DatabaseName, TableName): run_facets={"externalQuery": 
ExternalQueryRunFacet(externalQueryId="12345", source="awsathena")}, ) assert op.get_openlineage_facets_on_complete(None) == expected_lineage + + def test_template_fields(self): + validate_template_fields(self.athena) diff --git a/tests/providers/amazon/aws/operators/test_bedrock.py b/tests/providers/amazon/aws/operators/test_bedrock.py index b49d09b52a5c0..8cbb67d6f50df 100644 --- a/tests/providers/amazon/aws/operators/test_bedrock.py +++ b/tests/providers/amazon/aws/operators/test_bedrock.py @@ -35,6 +35,7 @@ BedrockInvokeModelOperator, BedrockRaGOperator, ) +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields if TYPE_CHECKING: from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection @@ -176,6 +177,9 @@ def test_ensure_unique_job_name(self, _, side_effect, ensure_unique_name, mock_c bedrock_hook.get_waiter.assert_not_called() self.operator.defer.assert_not_called() + def test_template_fields(self): + validate_template_fields(self.operator) + class TestBedrockCreateProvisionedModelThroughputOperator: MODEL_ARN = "testProvisionedModelArn" @@ -222,6 +226,9 @@ def test_provisioned_model_wait_combinations( assert bedrock_hook.get_waiter.call_count == wait_for_completion assert self.operator.defer.call_count == deferrable + def test_template_fields(self): + validate_template_fields(self.operator) + class TestBedrockCreateKnowledgeBaseOperator: KNOWLEDGE_BASE_ID = "knowledge_base_id" @@ -288,6 +295,9 @@ def test_returns_id(self, mock_conn): assert result == self.KNOWLEDGE_BASE_ID + def test_template_fields(self): + validate_template_fields(self.operator) + class TestBedrockCreateDataSourceOperator: DATA_SOURCE_ID = "data_source_id" @@ -317,6 +327,9 @@ def test_id_returned(self, mock_conn): assert result == self.DATA_SOURCE_ID + def test_template_fields(self): + validate_template_fields(self.operator) + class TestBedrockIngestDataOperator: INGESTION_JOB_ID = "ingestion_job_id" @@ -348,6 +361,9 @@ def test_id_returned(self, mock_conn): assert result == self.INGESTION_JOB_ID + def test_template_fields(self): + validate_template_fields(self.operator) + class TestBedrockRaGOperator: VECTOR_SEARCH_CONFIG = {"filter": {"equals": {"key": "some key", "value": "some value"}}} @@ -520,3 +536,14 @@ def test_external_sources_build_rag_config(self, prompt_template): **expected_config_without_template, **expected_config_template, } + + def test_template_fields(self): + op = BedrockRaGOperator( + task_id="test_rag", + input="some text prompt", + source_type="EXTERNAL_SOURCES", + model_arn=self.MODEL_ARN, + knowledge_base_id=self.KNOWLEDGE_BASE_ID, + vector_search_config=self.VECTOR_SEARCH_CONFIG, + ) + validate_template_fields(op) diff --git a/tests/providers/amazon/aws/operators/test_cloud_formation.py b/tests/providers/amazon/aws/operators/test_cloud_formation.py index 5de02c3622cfb..4d8fb4d12bd3c 100644 --- a/tests/providers/amazon/aws/operators/test_cloud_formation.py +++ b/tests/providers/amazon/aws/operators/test_cloud_formation.py @@ -28,6 +28,7 @@ CloudFormationDeleteStackOperator, ) from airflow.utils import timezone +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields DEFAULT_DATE = timezone.datetime(2019, 1, 1) DEFAULT_ARGS = {"owner": "airflow", "start_date": DEFAULT_DATE} @@ -87,6 +88,20 @@ def test_create_stack(self, mocked_hook_client): StackName=stack_name, TemplateBody=template_body, TimeoutInMinutes=timeout ) + def test_template_fields(self): + op = CloudFormationCreateStackOperator( + 
task_id="cf_create_stack_init", + stack_name="fake-stack", + cloudformation_parameters={}, + # Generic hooks parameters + aws_conn_id="fake-conn-id", + region_name="eu-west-1", + verify=True, + botocore_config={"read_timeout": 42}, + ) + + validate_template_fields(op) + class TestCloudFormationDeleteStackOperator: def test_init(self): @@ -125,3 +140,16 @@ def test_delete_stack(self, mocked_hook_client): operator.execute(MagicMock()) mocked_hook_client.delete_stack.assert_any_call(StackName=stack_name) + + def test_template_fields(self): + op = CloudFormationDeleteStackOperator( + task_id="cf_delete_stack_init", + stack_name="fake-stack", + # Generic hooks parameters + aws_conn_id="fake-conn-id", + region_name="us-east-1", + verify=False, + botocore_config={"read_timeout": 42}, + ) + + validate_template_fields(op) diff --git a/tests/providers/amazon/aws/operators/test_comprehend.py b/tests/providers/amazon/aws/operators/test_comprehend.py index 60f0fca219111..a86b779b1d502 100644 --- a/tests/providers/amazon/aws/operators/test_comprehend.py +++ b/tests/providers/amazon/aws/operators/test_comprehend.py @@ -29,6 +29,7 @@ ComprehendStartPiiEntitiesDetectionJobOperator, ) from airflow.utils.types import NOTSET +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields if TYPE_CHECKING: from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection @@ -163,6 +164,9 @@ def test_start_pii_entities_detection_job_wait_combinations( assert comprehend_hook.get_waiter.call_count == wait_for_completion assert self.operator.defer.call_count == deferrable + def test_template_fields(self): + validate_template_fields(self.operator) + class TestComprehendCreateDocumentClassifierOperator: CLASSIFIER_ARN = ( @@ -259,3 +263,6 @@ def test_create_document_classifier_wait_combinations( assert response == self.CLASSIFIER_ARN assert comprehend_hook.get_waiter.call_count == wait_for_completion assert self.operator.defer.call_count == deferrable + + def test_template_fields(self): + validate_template_fields(self.operator) diff --git a/tests/providers/amazon/aws/operators/test_datasync.py b/tests/providers/amazon/aws/operators/test_datasync.py index e1a44ce99e28c..18b0e86103c0b 100644 --- a/tests/providers/amazon/aws/operators/test_datasync.py +++ b/tests/providers/amazon/aws/operators/test_datasync.py @@ -29,6 +29,7 @@ from airflow.utils import timezone from airflow.utils.timezone import datetime from airflow.utils.types import DagRunType +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields TEST_DAG_ID = "unit_tests" DEFAULT_DATE = datetime(2018, 1, 1) @@ -363,6 +364,10 @@ def test_return_value(self, mock_get_conn, session, clean_dags_and_dagruns): # ### Check mocks: mock_get_conn.assert_called() + def test_template_fields(self, mock_get_conn): + self.set_up_operator() + validate_template_fields(self.datasync) + @mock_aws @mock.patch.object(DataSyncHook, "get_conn") diff --git a/tests/providers/amazon/aws/operators/test_dms.py b/tests/providers/amazon/aws/operators/test_dms.py index fba14a6370dd7..2528edaef9e0a 100644 --- a/tests/providers/amazon/aws/operators/test_dms.py +++ b/tests/providers/amazon/aws/operators/test_dms.py @@ -34,6 +34,7 @@ ) from airflow.utils import timezone from airflow.utils.types import DagRunType +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields TASK_ARN = "test_arn" @@ -121,6 +122,18 @@ def test_create_task_with_migration_type( assert dms_hook.get_task_status(TASK_ARN) 
== "ready" + def test_template_fields(self): + op = DmsCreateTaskOperator( + task_id="create_task", + **self.TASK_DATA, + aws_conn_id="fake-conn-id", + region_name="ca-west-1", + verify=True, + botocore_config={"read_timeout": 42}, + ) + + validate_template_fields(op) + class TestDmsDeleteTaskOperator: TASK_DATA = { @@ -174,6 +187,19 @@ def test_delete_task( assert dms_hook.get_task_status(TASK_ARN) == "deleting" + def test_template_fields(self): + op = DmsDeleteTaskOperator( + task_id="delete_task", + replication_task_arn=TASK_ARN, + # Generic hooks parameters + aws_conn_id="fake-conn-id", + region_name="us-east-1", + verify=False, + botocore_config={"read_timeout": 42}, + ) + + validate_template_fields(op) + class TestDmsDescribeTasksOperator: FILTER = {"Name": "replication-task-arn", "Values": [TASK_ARN]} @@ -267,6 +293,18 @@ def test_describe_tasks_return_value(self, mock_conn, mock_describe_replication_ assert marker is None assert response == self.MOCK_RESPONSE + def test_template_fields(self): + op = DmsDescribeTasksOperator( + task_id="describe_tasks", + describe_tasks_kwargs={"Filters": [self.FILTER]}, + # Generic hooks parameters + aws_conn_id="fake-conn-id", + region_name="eu-west-2", + verify="/foo/bar/spam.egg", + botocore_config={"read_timeout": 42}, + ) + validate_template_fields(op) + class TestDmsStartTaskOperator: TASK_DATA = { @@ -324,6 +362,19 @@ def test_start_task( assert dms_hook.get_task_status(TASK_ARN) == "starting" + def test_template_fields(self): + op = DmsStartTaskOperator( + task_id="start_task", + replication_task_arn=TASK_ARN, + # Generic hooks parameters + aws_conn_id="fake-conn-id", + region_name="us-west-1", + verify=False, + botocore_config={"read_timeout": 42}, + ) + + validate_template_fields(op) + class TestDmsStopTaskOperator: TASK_DATA = { @@ -376,3 +427,16 @@ def test_stop_task( mock_stop_replication_task.assert_called_once_with(replication_task_arn=TASK_ARN) assert dms_hook.get_task_status(TASK_ARN) == "stopping" + + def test_template_fields(self): + op = DmsStopTaskOperator( + task_id="stop_task", + replication_task_arn=TASK_ARN, + # Generic hooks parameters + aws_conn_id="fake-conn-id", + region_name="eu-west-1", + verify=True, + botocore_config={"read_timeout": 42}, + ) + + validate_template_fields(op) diff --git a/tests/providers/amazon/aws/operators/test_ec2.py b/tests/providers/amazon/aws/operators/test_ec2.py index 8f8a755a84357..a5ea81ff6ae87 100644 --- a/tests/providers/amazon/aws/operators/test_ec2.py +++ b/tests/providers/amazon/aws/operators/test_ec2.py @@ -30,6 +30,7 @@ EC2StopInstanceOperator, EC2TerminateInstanceOperator, ) +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields class BaseEc2TestClass: @@ -87,6 +88,13 @@ def test_create_multiple_instances(self): for id in instance_ids: assert ec2_hook.get_instance_state(instance_id=id) == "running" + def test_template_fields(self): + ec2_operator = EC2CreateInstanceOperator( + task_id="test_create_instance", + image_id="test_image_id", + ) + validate_template_fields(ec2_operator) + class TestEC2TerminateInstanceOperator(BaseEc2TestClass): def test_init(self): @@ -140,6 +148,13 @@ def test_terminate_multiple_instances(self): for id in instance_ids: assert ec2_hook.get_instance_state(instance_id=id) == "terminated" + def test_template_fields(self): + ec2_operator = EC2TerminateInstanceOperator( + task_id="test_terminate_instance", + instance_ids="test_image_id", + ) + validate_template_fields(ec2_operator) + class 
TestEC2StartInstanceOperator(BaseEc2TestClass): def test_init(self): @@ -175,6 +190,17 @@ def test_start_instance(self): # assert instance state is running assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "running" + def test_template_fields(self): + ec2_operator = EC2StartInstanceOperator( + task_id="task_test", + instance_id="i-123abc", + aws_conn_id="aws_conn_test", + region_name="region-test", + check_interval=3, + ) + + validate_template_fields(ec2_operator) + class TestEC2StopInstanceOperator(BaseEc2TestClass): def test_init(self): @@ -210,6 +236,17 @@ def test_stop_instance(self): # assert instance state is running assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "stopped" + def test_template_fields(self): + ec2_operator = EC2StopInstanceOperator( + task_id="task_test", + instance_id="i-123abc", + aws_conn_id="aws_conn_test", + region_name="region-test", + check_interval=3, + ) + + validate_template_fields(ec2_operator) + class TestEC2HibernateInstanceOperator(BaseEc2TestClass): def test_init(self): @@ -322,6 +359,13 @@ def test_cannot_hibernate_some_instances(self): for id in instance_ids: assert ec2_hook.get_instance_state(instance_id=id) == "running" + def test_template_fields(self): + ec2_operator = EC2HibernateInstanceOperator( + task_id="task_test", + instance_ids="i-123abc", + ) + validate_template_fields(ec2_operator) + class TestEC2RebootInstanceOperator(BaseEc2TestClass): def test_init(self): @@ -372,3 +416,10 @@ def test_reboot_multiple_instances(self): terminate_instance.execute(None) for id in instance_ids: assert ec2_hook.get_instance_state(instance_id=id) == "running" + + def test_template_fields(self): + ec2_operator = EC2RebootInstanceOperator( + task_id="task_test", + instance_ids="i-123abc", + ) + validate_template_fields(ec2_operator) diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py index fefdb595dacda..5c2ba16c4cff2 100644 --- a/tests/providers/amazon/aws/operators/test_ecs.py +++ b/tests/providers/amazon/aws/operators/test_ecs.py @@ -39,6 +39,7 @@ from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.types import NOTSET +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields CLUSTER_NAME = "test_cluster" CONTAINER_NAME = "e1ed7aac-d9b2-4315-8726-d2432bf11868" @@ -794,6 +795,17 @@ def test_execute_without_waiter(self, patch_hook_waiters): patch_hook_waiters.assert_not_called() assert result is not None + def test_template_fields(self): + op = EcsCreateClusterOperator( + task_id="task", + cluster_name=CLUSTER_NAME, + deferrable=True, + waiter_delay=12, + waiter_max_attempts=34, + ) + + validate_template_fields(op) + class TestEcsDeleteClusterOperator(EcsBaseTestCase): @pytest.mark.parametrize("waiter_delay, waiter_max_attempts", WAITERS_TEST_CASES) @@ -858,6 +870,17 @@ def test_execute_without_waiter(self, patch_hook_waiters): patch_hook_waiters.assert_not_called() assert result is not None + def test_template_fields(self): + op = EcsDeleteClusterOperator( + task_id="task", + cluster_name=CLUSTER_NAME, + deferrable=True, + waiter_delay=12, + waiter_max_attempts=34, + ) + + validate_template_fields(op) + class TestEcsDeregisterTaskDefinitionOperator(EcsBaseTestCase): warn_message = "'wait_for_completion' and waiter related params have no effect" @@ -914,6 +937,11 @@ def 
test_partial_deprecation_waiters_params( assert not hasattr(ti.task, "waiter_delay") assert not hasattr(ti.task, "waiter_max_attempts") + def test_template_fields(self): + op = EcsDeregisterTaskDefinitionOperator(task_id="task", task_definition=TASK_DEFINITION_NAME) + + validate_template_fields(op) + class TestEcsRegisterTaskDefinitionOperator(EcsBaseTestCase): warn_message = "'wait_for_completion' and waiter related params have no effect" @@ -991,3 +1019,8 @@ def test_partial_deprecation_waiters_params( assert not hasattr(ti.task, "wait_for_completion") assert not hasattr(ti.task, "waiter_delay") assert not hasattr(ti.task, "waiter_max_attempts") + + def test_template_fields(self): + op = EcsRegisterTaskDefinitionOperator(task_id="task", **TASK_DEFINITION_CONFIG) + + validate_template_fields(op) diff --git a/tests/providers/amazon/aws/operators/test_eks.py b/tests/providers/amazon/aws/operators/test_eks.py index 9571ca0962005..399c8e40823ae 100644 --- a/tests/providers/amazon/aws/operators/test_eks.py +++ b/tests/providers/amazon/aws/operators/test_eks.py @@ -51,6 +51,7 @@ TASK_ID, ) from tests.providers.amazon.aws.utils.eks_test_utils import convert_keys +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields from tests.providers.amazon.aws.utils.test_waiter import assert_expected_waiter_type CLUSTER_NAME = "cluster1" @@ -365,6 +366,15 @@ def test_eks_create_cluster_with_deferrable(self, mock_create_cluster, caplog): eks_create_cluster_operator.execute({}) assert "Waiting for EKS Cluster to provision. This will take some time." in caplog.messages + def test_template_fields(self): + op = EksCreateClusterOperator( + task_id=TASK_ID, + **self.create_cluster_params, + compute="fargate", + ) + + validate_template_fields(op) + class TestEksCreateFargateProfileOperator: def setup_method(self) -> None: @@ -445,6 +455,11 @@ def test_create_fargate_profile_deferrable(self, _): exc.value.trigger, EksCreateFargateProfileTrigger ), "Trigger is not a EksCreateFargateProfileTrigger" + def test_template_fields(self): + op = EksCreateFargateProfileOperator(task_id=TASK_ID, **self.create_fargate_profile_params) + + validate_template_fields(op) + class TestEksCreateNodegroupOperator: def setup_method(self) -> None: @@ -536,6 +551,12 @@ def test_create_nodegroup_deferrable_versus_wait_for_completion(self): ) assert operator.wait_for_completion is True + def test_template_fields(self): + op_kwargs = {**self.create_nodegroup_params} + op = EksCreateNodegroupOperator(task_id=TASK_ID, **op_kwargs) + + validate_template_fields(op) + class TestEksDeleteClusterOperator: def setup_method(self) -> None: @@ -575,6 +596,9 @@ def test_eks_delete_cluster_operator_with_deferrable(self): with pytest.raises(TaskDeferred): self.delete_cluster_operator.execute({}) + def test_template_fields(self): + validate_template_fields(self.delete_cluster_operator) + class TestEksDeleteNodegroupOperator: def setup_method(self) -> None: @@ -608,6 +632,9 @@ def test_existing_nodegroup_with_wait(self, mock_delete_nodegroup, mock_waiter): mock_waiter.assert_called_with(mock.ANY, clusterName=CLUSTER_NAME, nodegroupName=NODEGROUP_NAME) assert_expected_waiter_type(mock_waiter, "NodegroupDeleted") + def test_template_fields(self): + validate_template_fields(self.delete_nodegroup_operator) + class TestEksDeleteFargateProfileOperator: def setup_method(self) -> None: @@ -656,6 +683,9 @@ def test_delete_fargate_profile_deferrable(self, _): exc.value.trigger, EksDeleteFargateProfileTrigger ), "Trigger is not a 
EksDeleteFargateProfileTrigger" + def test_template_fields(self): + validate_template_fields(self.delete_fargate_profile_operator) + class TestEksPodOperator: @mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.execute") @@ -767,3 +797,17 @@ def test_on_finish_action_handler( ) for expected_attr in expected_attributes: assert op.__getattribute__(expected_attr) == expected_attributes[expected_attr] + + def test_template_fields(self): + op = EksPodOperator( + task_id="run_pod", + pod_name="run_pod", + cluster_name=CLUSTER_NAME, + image="amazon/aws-cli:latest", + cmds=["sh", "-c", "ls"], + labels={"demo": "hello_world"}, + get_logs=True, + on_finish_action="delete_pod", + ) + + validate_template_fields(op) diff --git a/tests/providers/amazon/aws/operators/test_emr_add_steps.py b/tests/providers/amazon/aws/operators/test_emr_add_steps.py index 9ee99864e00e3..d5a999349aa53 100644 --- a/tests/providers/amazon/aws/operators/test_emr_add_steps.py +++ b/tests/providers/amazon/aws/operators/test_emr_add_steps.py @@ -31,6 +31,7 @@ from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger from airflow.utils import timezone from airflow.utils.types import DagRunType +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields from tests.test_utils import AIRFLOW_MAIN_FOLDER DEFAULT_DATE = timezone.datetime(2017, 1, 1) @@ -274,3 +275,12 @@ def test_emr_add_steps_deferrable(self, mock_add_job_flow_steps, mock_get_log_ur operator.execute(MagicMock()) assert isinstance(exc.value.trigger, EmrAddStepsTrigger), "Trigger is not a EmrAddStepsTrigger" + + def test_template_fields(self): + op = EmrAddStepsOperator( + task_id="test_task", + job_flow_id="j-8989898989", + aws_conn_id="aws_default", + steps=self._config, + ) + validate_template_fields(op) diff --git a/tests/providers/amazon/aws/operators/test_emr_containers.py b/tests/providers/amazon/aws/operators/test_emr_containers.py index feeec1278e155..52306864f3597 100644 --- a/tests/providers/amazon/aws/operators/test_emr_containers.py +++ b/tests/providers/amazon/aws/operators/test_emr_containers.py @@ -25,6 +25,7 @@ from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator, EmrEksCreateClusterOperator from airflow.providers.amazon.aws.triggers.emr import EmrContainerTrigger +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields SUBMIT_JOB_SUCCESS_RETURN = { "ResponseMetadata": {"HTTPStatusCode": 200}, @@ -194,3 +195,6 @@ def test_emr_on_eks_execute_with_failure(self, mock_create_emr_on_eks_cluster): with pytest.raises(AirflowException) as ctx: self.emr_container.execute(None) assert expected_exception_msg in str(ctx.value) + + def test_template_fields(self): + validate_template_fields(self.emr_container) diff --git a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py index 204d292c67b46..860df8c7219ac 100644 --- a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py +++ b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py @@ -32,6 +32,7 @@ from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger from airflow.utils import timezone from airflow.utils.types import DagRunType +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields from tests.providers.amazon.aws.utils.test_waiter import 
assert_expected_waiter_type from tests.test_utils import AIRFLOW_MAIN_FOLDER @@ -203,3 +204,6 @@ def test_create_job_flow_deferrable(self, mocked_hook_client): assert isinstance( exc.value.trigger, EmrCreateJobFlowTrigger ), "Trigger is not a EmrCreateJobFlowTrigger" + + def test_template_fields(self): + validate_template_fields(self.operator) diff --git a/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py b/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py index 6dada442ff79f..6f257288760c3 100644 --- a/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py +++ b/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py @@ -25,6 +25,7 @@ from airflow.models.dag import DAG from airflow.providers.amazon.aws.operators.emr import EmrModifyClusterOperator from airflow.utils import timezone +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields DEFAULT_DATE = timezone.datetime(2017, 1, 1) MODIFY_CLUSTER_SUCCESS_RETURN = {"ResponseMetadata": {"HTTPStatusCode": 200}, "StepConcurrencyLevel": 1} @@ -65,3 +66,6 @@ def test_execute_returns_error(self, mocked_hook_client): with pytest.raises(AirflowException, match="Modify cluster failed"): self.operator.execute(self.mock_context) + + def test_template_fields(self): + validate_template_fields(self.operator) diff --git a/tests/providers/amazon/aws/operators/test_emr_notebook_execution.py b/tests/providers/amazon/aws/operators/test_emr_notebook_execution.py index ef6cb7ebc70ec..6fcd4eeb74629 100644 --- a/tests/providers/amazon/aws/operators/test_emr_notebook_execution.py +++ b/tests/providers/amazon/aws/operators/test_emr_notebook_execution.py @@ -28,6 +28,7 @@ EmrStartNotebookExecutionOperator, EmrStopNotebookExecutionOperator, ) +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields from tests.providers.amazon.aws.utils.test_waiter import assert_expected_waiter_type PARAMS = { @@ -303,3 +304,20 @@ def test_stop_notebook_execution_waiter_config(self, mock_conn, mock_waiter, _): WaiterConfig={"Delay": delay, "MaxAttempts": waiter_max_attempts}, ) assert_expected_waiter_type(mock_waiter, "notebook_stopped") + + def test_template_fields(self): + op = EmrStartNotebookExecutionOperator( + task_id="test-id", + editor_id=PARAMS["EditorId"], + relative_path=PARAMS["RelativePath"], + cluster_id=PARAMS["ExecutionEngine"]["Id"], + service_role=PARAMS["ServiceRole"], + notebook_execution_name=PARAMS["NotebookExecutionName"], + notebook_params=PARAMS["NotebookParams"], + notebook_instance_security_group_id=PARAMS["NotebookInstanceSecurityGroupId"], + master_instance_security_group_id=PARAMS["ExecutionEngine"]["MasterInstanceSecurityGroupId"], + tags=PARAMS["Tags"], + wait_for_completion=True, + ) + + validate_template_fields(op) diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py index 12c5cc938018e..e7a43cf079f0b 100644 --- a/tests/providers/amazon/aws/operators/test_emr_serverless.py +++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py @@ -32,6 +32,7 @@ EmrServerlessStopApplicationOperator, ) from airflow.utils.types import NOTSET +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields if TYPE_CHECKING: from unittest.mock import MagicMock @@ -393,6 +394,25 @@ def test_create_application_deferrable(self, mock_conn): with pytest.raises(TaskDeferred): operator.execute(None) + def test_template_fields(self): + operator = 
EmrServerlessCreateApplicationOperator( + task_id=task_id, + release_label=release_label, + job_type=job_type, + client_request_token=client_request_token, + config=config, + waiter_max_attempts=3, + waiter_delay=0, + ) + + template_fields = list(operator.template_fields) + list(operator.template_fields_renderers.keys()) + + class_fields = operator.__dict__ + + missing_fields = [field for field in template_fields if field not in class_fields] + + assert not missing_fields, f"Templated fields are not available {missing_fields}" + class TestEmrServerlessStartJobOperator: def setup_method(self): @@ -1163,6 +1183,24 @@ def test_links_spark_without_applicationui_enabled( job_run_id=job_run_id, ) + def test_template_fields(self): + operator = EmrServerlessStartJobOperator( + task_id=task_id, + client_request_token=client_request_token, + application_id=application_id, + execution_role_arn=execution_role_arn, + job_driver=job_driver, + configuration_overrides=configuration_overrides, + ) + + template_fields = list(operator.template_fields) + list(operator.template_fields_renderers.keys()) + + class_fields = operator.__dict__ + + missing_fields = [field for field in template_fields if field not in class_fields] + + assert not missing_fields, f"Templated fields are not available {missing_fields}" + class TestEmrServerlessDeleteOperator: @mock.patch.object(EmrServerlessHook, "get_waiter") @@ -1277,6 +1315,19 @@ def test_delete_application_deferrable(self, mock_conn): with pytest.raises(TaskDeferred): operator.execute(None) + def test_template_fields(self): + operator = EmrServerlessDeleteApplicationOperator( + task_id=task_id, application_id=application_id_delete_operator + ) + + template_fields = list(operator.template_fields) + list(operator.template_fields_renderers.keys()) + + class_fields = operator.__dict__ + + missing_fields = [field for field in template_fields if field not in class_fields] + + assert not missing_fields, f"Templated fields are not available {missing_fields}" + class TestEmrServerlessStopOperator: @mock.patch.object(EmrServerlessHook, "get_waiter") @@ -1344,3 +1395,10 @@ def test_stop_application_deferrable_without_force_stop( operator.execute({}) assert "no running jobs found with application ID test" in caplog.messages + + def test_template_fields(self): + operator = EmrServerlessStopApplicationOperator( + task_id=task_id, application_id="test", deferrable=True, force_stop=True + ) + + validate_template_fields(operator) diff --git a/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py b/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py index 2c27c146d2d34..06ab35e4510ba 100644 --- a/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py +++ b/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py @@ -24,6 +24,7 @@ from airflow.exceptions import TaskDeferred from airflow.providers.amazon.aws.operators.emr import EmrTerminateJobFlowOperator from airflow.providers.amazon.aws.triggers.emr import EmrTerminateJobFlowTrigger +from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields TERMINATE_SUCCESS_RETURN = {"ResponseMetadata": {"HTTPStatusCode": 200}} @@ -57,3 +58,13 @@ def test_create_job_flow_deferrable(self, mocked_hook_client): assert isinstance( exc.value.trigger, EmrTerminateJobFlowTrigger ), "Trigger is not a EmrTerminateJobFlowTrigger" + + def test_template_fields(self): + operator = EmrTerminateJobFlowOperator( + task_id="test_task", + job_flow_id="j-8989898989", + 
aws_conn_id="aws_default", + deferrable=True, + ) + + validate_template_fields(operator) diff --git a/tests/providers/amazon/aws/utils/test_template_fields.py b/tests/providers/amazon/aws/utils/test_template_fields.py new file mode 100644 index 0000000000000..689977de9bcc5 --- /dev/null +++ b/tests/providers/amazon/aws/utils/test_template_fields.py @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + + +def validate_template_fields(operator): + template_fields = list(operator.template_fields) + list(operator.template_fields_renderers.keys()) + + class_fields = operator.__dict__ + + missing_fields = [field for field in template_fields if field not in class_fields] + + assert not missing_fields, f"Templated fields are not available {missing_fields}" From ffe0f660ad18f934f6172690d01dba4a4b652baa Mon Sep 17 00:00:00 2001 From: Jens Scheffler <95105677+jscheffl@users.noreply.github.com> Date: Thu, 12 Sep 2024 09:57:12 +0200 Subject: [PATCH 048/349] Add Jens Scheffler for UI as contributor (#42190) --- .github/CODEOWNERS | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index befcd854edca5..75ceb48d78504 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -27,10 +27,10 @@ /airflow/api_connexion/ @ephraimbuddy @pierrejeambrun # WWW -/airflow/www/ @ryanahamilton @ashb @bbovenzi @pierrejeambrun +/airflow/www/ @ryanahamilton @ashb @bbovenzi @pierrejeambrun @jscheffl # UI -/airflow/ui/ @bbovenzi @pierrejeambrun @ryanahamilton +/airflow/ui/ @bbovenzi @pierrejeambrun @ryanahamilton @jscheffl # Security/Permissions /airflow/api_connexion/security.py @jhtimmins @@ -65,6 +65,7 @@ /airflow/providers/cncf/kubernetes @jedcunningham @hussein-awala /airflow/providers/common/sql/ @eladkal /airflow/providers/dbt/cloud/ @josh-fell +/airflow/providers/edge @jscheffl /airflow/providers/hashicorp/ @hussein-awala /airflow/providers/openlineage/ @mobuchowski /airflow/providers/slack/ @eladkal From 8300630eafeab13fbe838c17c07eaa762eeaf226 Mon Sep 17 00:00:00 2001 From: Utkarsh Sharma Date: Thu, 12 Sep 2024 19:03:24 +0530 Subject: [PATCH 049/349] Revert "Handle Example dags case when checking for missing files (#41856)" (#42193) This reverts commit 435e9687b0c56499bc29c21d3cada8ae9e0a8c53. 
--- airflow/dag_processing/manager.py | 11 +-- tests/dag_processing/test_job_runner.py | 89 +++++++++++++------------ 2 files changed, 48 insertions(+), 52 deletions(-) diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py index 7e404307dccd8..fee515dc07164 100644 --- a/airflow/dag_processing/manager.py +++ b/airflow/dag_processing/manager.py @@ -41,7 +41,6 @@ from tabulate import tabulate import airflow.models -from airflow import example_dags from airflow.api_internal.internal_api_call import internal_api_call from airflow.callbacks.callback_requests import CallbackRequest, SlaCallbackRequest from airflow.configuration import conf @@ -70,8 +69,6 @@ from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks -example_dag_folder = next(iter(example_dags.__path__)) - if TYPE_CHECKING: from multiprocessing.connection import Connection as MultiprocessingConnection @@ -530,11 +527,9 @@ def deactivate_stale_dags( for dag in dags_parsed: # When the DAG processor runs as part of the scheduler, and the user changes the DAGs folder, - # DAGs from the previous DAGs folder will be marked as stale. We also need to handle example dags - # differently. Note that this change has no impact on standalone DAG processors. - dag_not_in_current_dag_folder = ( - not os.path.commonpath([dag.fileloc, example_dag_folder]) == example_dag_folder - ) and (os.path.commonpath([dag.fileloc, dag_directory]) != dag_directory) + # DAGs from the previous DAGs folder will be marked as stale. Note that this change has no impact + # on standalone DAG processors. + dag_not_in_current_dag_folder = os.path.commonpath([dag.fileloc, dag_directory]) != dag_directory # The largest valid difference between a DagFileStat's last_finished_time and a DAG's # last_parsed_time is the processor_timeout. Longer than that indicates that the DAG is # no longer present in the file. 
We have a stale_dag_threshold configured to prevent a diff --git a/tests/dag_processing/test_job_runner.py b/tests/dag_processing/test_job_runner.py index 0e15a2d1f6690..9b8437d77d50a 100644 --- a/tests/dag_processing/test_job_runner.py +++ b/tests/dag_processing/test_job_runner.py @@ -773,57 +773,58 @@ def test_scan_stale_dags_when_dag_folder_change(self): def get_dag_string(filename) -> str: return open(TEST_DAG_FOLDER / filename).read() - def add_dag_to_db(file_path, dag_id, processor_subdir): - dagbag = DagBag(file_path, read_dags_from_db=False) - dag = dagbag.get_dag(dag_id) - dag.fileloc = file_path - dag.last_parsed_time = timezone.utcnow() - dag.sync_to_db(processor_subdir=processor_subdir) + with tempfile.TemporaryDirectory() as tmpdir: + old_dag_home = tempfile.mkdtemp(dir=tmpdir) + old_dag_file = tempfile.NamedTemporaryFile(dir=old_dag_home, suffix=".py") + old_dag_file.write(get_dag_string("test_example_bash_operator.py").encode()) + old_dag_file.flush() + new_dag_home = tempfile.mkdtemp(dir=tmpdir) + new_dag_file = tempfile.NamedTemporaryFile(dir=new_dag_home, suffix=".py") + new_dag_file.write(get_dag_string("test_scheduler_dags.py").encode()) + new_dag_file.flush() + + manager = DagProcessorJobRunner( + job=Job(), + processor=DagFileProcessorManager( + dag_directory=new_dag_home, + max_runs=1, + processor_timeout=timedelta(minutes=10), + signal_conn=MagicMock(), + dag_ids=[], + pickle_dags=False, + async_mode=True, + ), + ) - def create_dag_folder(dag_id): - dag_home = tempfile.mkdtemp(dir=tmpdir) - dag_file = tempfile.NamedTemporaryFile(dir=dag_home, suffix=".py") - dag_file.write(get_dag_string(dag_id).encode()) - dag_file.flush() - return dag_home, dag_file + dagbag = DagBag(old_dag_file.name, read_dags_from_db=False) + other_dagbag = DagBag(new_dag_file.name, read_dags_from_db=False) - with tempfile.TemporaryDirectory() as tmpdir: - old_dag_home, old_dag_file = create_dag_folder("test_example_bash_operator.py") - new_dag_home, new_dag_file = create_dag_folder("test_scheduler_dags.py") - example_dag_home, example_dag_file = create_dag_folder("test_dag_warnings.py") - - with mock.patch("airflow.dag_processing.manager.example_dag_folder", example_dag_home): - manager = DagProcessorJobRunner( - job=Job(), - processor=DagFileProcessorManager( - dag_directory=new_dag_home, - max_runs=1, - processor_timeout=timedelta(minutes=10), - signal_conn=MagicMock(), - dag_ids=[], - pickle_dags=False, - async_mode=True, - ), - ) + with create_session() as session: + # Add DAG from old dah home to the DB + dag = dagbag.get_dag("test_example_bash_operator") + dag.fileloc = old_dag_file.name + dag.last_parsed_time = timezone.utcnow() + dag.sync_to_db(processor_subdir=old_dag_home) - with create_session() as session: - add_dag_to_db(old_dag_file.name, "test_example_bash_operator", old_dag_home) - add_dag_to_db(new_dag_file.name, "test_start_date_scheduling", new_dag_home) - add_dag_to_db(example_dag_file.name, "test_dag_warnings", example_dag_home) + # Add DAG from new DAG home to the DB + other_dag = other_dagbag.get_dag("test_start_date_scheduling") + other_dag.fileloc = new_dag_file.name + other_dag.last_parsed_time = timezone.utcnow() + other_dag.sync_to_db(processor_subdir=new_dag_home) - manager.processor._file_paths = [new_dag_file, example_dag_file] + manager.processor._file_paths = [new_dag_file] - active_dag_count = ( - session.query(func.count(DagModel.dag_id)).filter(DagModel.is_active).scalar() - ) - assert active_dag_count == 3 + active_dag_count = ( + 
session.query(func.count(DagModel.dag_id)).filter(DagModel.is_active).scalar() + ) + assert active_dag_count == 2 - manager.processor._scan_stale_dags() + manager.processor._scan_stale_dags() - active_dag_count = ( - session.query(func.count(DagModel.dag_id)).filter(DagModel.is_active).scalar() - ) - assert active_dag_count == 2 + active_dag_count = ( + session.query(func.count(DagModel.dag_id)).filter(DagModel.is_active).scalar() + ) + assert active_dag_count == 1 @mock.patch( "airflow.dag_processing.processor.DagFileProcessorProcess.waitable_handle", new_callable=PropertyMock From 455a1fefb69d96784cb9f0e2fe77a9d0f576df4f Mon Sep 17 00:00:00 2001 From: Nikita Date: Thu, 12 Sep 2024 17:56:02 +0300 Subject: [PATCH 050/349] Fix require_confirmation_dag_change (#42063) * Add 'lower' to require_confirmation_dag_change * Set input attribute only if require-confirmation is true --- airflow/www/static/js/dag.js | 2 +- airflow/www/templates/airflow/dag.html | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/www/static/js/dag.js b/airflow/www/static/js/dag.js index 38930637396aa..f178ac0a5b024 100644 --- a/airflow/www/static/js/dag.js +++ b/airflow/www/static/js/dag.js @@ -55,7 +55,7 @@ $("#pause_resume").on("change", function onChange() { const $input = $(this); const id = $input.data("dag-id"); const isPaused = $input.is(":checked"); - const requireConfirmation = $input.data("require-confirmation"); + const requireConfirmation = $input.is("[data-require-confirmation]"); if (requireConfirmation) { const confirmation = window.confirm( `Are you sure you want to ${isPaused ? "resume" : "pause"} this DAG?` diff --git a/airflow/www/templates/airflow/dag.html b/airflow/www/templates/airflow/dag.html index 1fc711d993ebd..b0c00bcd5c88f 100644 --- a/airflow/www/templates/airflow/dag.html +++ b/airflow/www/templates/airflow/dag.html @@ -120,7 +120,7 @@
[dag.html hunk (body truncated): the pause/resume <input> now emits the data-require-confirmation attribute only inside a Jinja {% if %} ... {% endif %} guard, so the attribute is present only when confirmation is required, matching the attribute-presence check in dag.js above.]
, " f"got {table_input}" ) - - # Exclude partition from the table name - table_id = table_id.split("$")[0] - if project_id is None: if var_name is not None: self.log.info( diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py index 81db43c0f5310..02f442cfc7caa 100644 --- a/tests/providers/google/cloud/hooks/test_bigquery.py +++ b/tests/providers/google/cloud/hooks/test_bigquery.py @@ -1034,7 +1034,6 @@ def test_split_tablename_internal_need_default_project(self): with pytest.raises(ValueError, match="INTERNAL: No default project is specified"): self.hook.split_tablename("dataset.table", None) - @pytest.mark.parametrize("partition", ["$partition", ""]) @pytest.mark.parametrize( "project_expected, dataset_expected, table_expected, table_input", [ @@ -1045,11 +1044,9 @@ def test_split_tablename_internal_need_default_project(self): ("alt1:alt", "dataset", "table", "alt1:alt:dataset.table"), ], ) - def test_split_tablename( - self, project_expected, dataset_expected, table_expected, table_input, partition - ): + def test_split_tablename(self, project_expected, dataset_expected, table_expected, table_input): default_project_id = "project" - project, dataset, table = self.hook.split_tablename(table_input + partition, default_project_id) + project, dataset, table = self.hook.split_tablename(table_input, default_project_id) assert project_expected == project assert dataset_expected == dataset assert table_expected == table From 2b8833db04ca71b3db9526484816e7918ea7a166 Mon Sep 17 00:00:00 2001 From: rom sharon <33751805+romsharon98@users.noreply.github.com> Date: Fri, 27 Sep 2024 15:07:49 +0300 Subject: [PATCH 201/349] Add slack notification for canary build failures (#42394) * add slack notifier --- .github/workflows/ci.yml | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 14fe0bbe4baa9..8625aee73d9e1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,6 +39,7 @@ env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_USERNAME: ${{ github.actor }} IMAGE_TAG: "${{ github.event.pull_request.head.sha || github.sha }}" + SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }} VERBOSE: "true" concurrency: @@ -669,3 +670,31 @@ jobs: include-success-outputs: ${{ needs.build-info.outputs.include-success-outputs }} docker-cache: ${{ needs.build-info.outputs.docker-cache }} canary-run: ${{ needs.build-info.outputs.canary-run }} + + notify-slack-failure: + name: "Notify Slack on Failure" + if: github.event_name == 'schedule' && failure() + runs-on: ["ubuntu-22.04"] + steps: + - name: Notify Slack + id: slack + uses: slackapi/slack-github-action@v1.27.0 + with: + channel-id: 'zzz_webhook_test' + # yamllint disable rule:line-length + payload: | + { + "text": "🚨🕒 Scheduled CI Failure Alert 🕒🚨\n\n*Details:* ", + "blocks": [ + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": "🚨🕒 Scheduled CI Failure Alert 🕒🚨\n\n*Details:* " + } + } + ] + } + # yamllint enable rule:line-length + env: + SLACK_BOT_TOKEN: ${{ env.SLACK_BOT_TOKEN }} From b2f64e7822c94ce43554d4ebfdfb5d8ac4ceacbe Mon Sep 17 00:00:00 2001 From: Pierre Jeambrun Date: Fri, 27 Sep 2024 20:27:27 +0800 Subject: [PATCH 202/349] AIP-84 Add HTTPException openapi documentation (#42508) * Add HTTPException openapi documentation * Update following code review --- airflow/api_fastapi/openapi/__init__.py | 16 ++++++++ airflow/api_fastapi/openapi/exceptions.py | 41 +++++++++++++++++++ 
airflow/api_fastapi/openapi/v1-generated.yaml | 36 ++++++++++++++++ airflow/api_fastapi/views/public/dags.py | 14 +++---- airflow/api_fastapi/views/ui/datasets.py | 2 - .../ui/openapi-gen/requests/schemas.gen.ts | 20 +++++++++ .../ui/openapi-gen/requests/services.gen.ts | 4 ++ airflow/ui/openapi-gen/requests/types.gen.ts | 27 ++++++++++++ 8 files changed, 150 insertions(+), 10 deletions(-) create mode 100644 airflow/api_fastapi/openapi/__init__.py create mode 100644 airflow/api_fastapi/openapi/exceptions.py diff --git a/airflow/api_fastapi/openapi/__init__.py b/airflow/api_fastapi/openapi/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/api_fastapi/openapi/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/api_fastapi/openapi/exceptions.py b/airflow/api_fastapi/openapi/exceptions.py new file mode 100644 index 0000000000000..b3eaf204cc063 --- /dev/null +++ b/airflow/api_fastapi/openapi/exceptions.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from pydantic import BaseModel + + +class HTTPExceptionResponse(BaseModel): + """HTTPException Model used for error response.""" + + detail: str | dict + + +def create_openapi_http_exception_doc(responses_status_code: list[int]) -> dict: + """ + Will create additional response example for errors raised by the endpoint. + + There is no easy way to introspect the code and automatically see what HTTPException are actually + raised by the endpoint implementation. This piece of documentation needs to be kept + in sync with the endpoint code manually. + + Validation error i.e 422 are natively added to the openapi documentation by FastAPI. 
+ """ + responses_status_code = sorted(responses_status_code) + + return {status_code: {"model": HTTPExceptionResponse} for status_code in responses_status_code} diff --git a/airflow/api_fastapi/openapi/v1-generated.yaml b/airflow/api_fastapi/openapi/v1-generated.yaml index f488825449c3a..64e475aeb6baa 100644 --- a/airflow/api_fastapi/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/openapi/v1-generated.yaml @@ -168,6 +168,30 @@ paths: application/json: schema: $ref: '#/components/schemas/DAGResponse' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Bad Request + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found '422': description: Validation Error content: @@ -386,6 +410,18 @@ components: title: DagTagPydantic description: Serializable representation of the DagTag ORM SqlAlchemyModel used by internal API. + HTTPExceptionResponse: + properties: + detail: + anyOf: + - type: string + - type: object + title: Detail + type: object + required: + - detail + title: HTTPExceptionResponse + description: HTTPException Model used for error response. HTTPValidationError: properties: detail: diff --git a/airflow/api_fastapi/views/public/dags.py b/airflow/api_fastapi/views/public/dags.py index 07ab968adc975..a9fe87eef0953 100644 --- a/airflow/api_fastapi/views/public/dags.py +++ b/airflow/api_fastapi/views/public/dags.py @@ -23,6 +23,7 @@ from typing_extensions import Annotated from airflow.api_fastapi.db import apply_filters_to_select, get_session, latest_dag_run_per_dag_id_cte +from airflow.api_fastapi.openapi.exceptions import create_openapi_http_exception_doc from airflow.api_fastapi.parameters import ( QueryDagDisplayNamePatternSearch, QueryDagIdPatternSearch, @@ -95,16 +96,13 @@ async def get_dags( dags = session.scalars(dags_query).all() - try: - return DAGCollectionResponse( - dags=[DAGResponse.model_validate(dag, from_attributes=True) for dag in dags], - total_entries=total_entries, - ) - except ValueError as e: - raise HTTPException(400, f"DAGCollectionSchema error: {str(e)}") + return DAGCollectionResponse( + dags=[DAGResponse.model_validate(dag, from_attributes=True) for dag in dags], + total_entries=total_entries, + ) -@dags_router.patch("/dags/{dag_id}") +@dags_router.patch("/dags/{dag_id}", responses=create_openapi_http_exception_doc([400, 401, 403, 404])) async def patch_dag( dag_id: str, patch_body: DAGPatchBody, diff --git a/airflow/api_fastapi/views/ui/datasets.py b/airflow/api_fastapi/views/ui/datasets.py index 484385031a23d..f5dd2cacb126d 100644 --- a/airflow/api_fastapi/views/ui/datasets.py +++ b/airflow/api_fastapi/views/ui/datasets.py @@ -29,8 +29,6 @@ datasets_router = APIRouter(tags=["Dataset"]) -# Ultimately we want async routes, with async sqlalchemy session / context manager. 
-# Additional effort to make airflow utility code async, not handled for now and most likely part of the AIP-70 @datasets_router.get("/next_run_datasets/{dag_id}", include_in_schema=False) async def next_run_datasets( dag_id: str, diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index d9ce0528c396c..e8c9b5d70cdf4 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -317,6 +317,26 @@ export const $DagTagPydantic = { "Serializable representation of the DagTag ORM SqlAlchemyModel used by internal API.", } as const; +export const $HTTPExceptionResponse = { + properties: { + detail: { + anyOf: [ + { + type: "string", + }, + { + type: "object", + }, + ], + title: "Detail", + }, + }, + type: "object", + required: ["detail"], + title: "HTTPExceptionResponse", + description: "HTTPException Model used for error response.", +} as const; + export const $HTTPValidationError = { properties: { detail: { diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 9c261b3039000..37a4d11873acf 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -102,6 +102,10 @@ export class DagService { body: data.requestBody, mediaType: "application/json", errors: { + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Not Found", 422: "Validation Error", }, }); diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index 803bcd84270c7..16977004e79d6 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -67,6 +67,17 @@ export type DagTagPydantic = { dag_id: string; }; +/** + * HTTPException Model used for error response. + */ +export type HTTPExceptionResponse = { + detail: + | string + | { + [key: string]: unknown; + }; +}; + export type HTTPValidationError = { detail?: Array; }; @@ -149,6 +160,22 @@ export type $OpenApiTs = { * Successful Response */ 200: DAGResponse; + /** + * Bad Request + */ + 400: HTTPExceptionResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; /** * Validation Error */ From 1a6a4625fcaa28fd3156b5622f12b0bba7ac97fc Mon Sep 17 00:00:00 2001 From: GPK Date: Fri, 27 Sep 2024 13:55:24 +0100 Subject: [PATCH 203/349] Fix SparkKubernetesOperator spark name. 
(#42427) * use name parameter from spark yaml config or from operator argument parameter * update tests and name usage condition check * adding test, to check spark name starts with task_id * use set_name function in create_job * remove lower --- .../kubernetes/operators/spark_kubernetes.py | 15 +- ...ication_test_with_no_name_from_config.json | 57 +++++++ ...ication_test_with_no_name_from_config.yaml | 55 +++++++ .../operators/test_spark_kubernetes.py | 143 ++++++++++++++++++ 4 files changed, 265 insertions(+), 5 deletions(-) create mode 100644 tests/providers/cncf/kubernetes/data_files/spark/application_test_with_no_name_from_config.json create mode 100644 tests/providers/cncf/kubernetes/data_files/spark/application_test_with_no_name_from_config.yaml diff --git a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py index 39fadae90e5bd..9bcf46d0d4f57 100644 --- a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py +++ b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import re from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING, Any @@ -83,7 +82,7 @@ def __init__( image: str | None = None, code_path: str | None = None, namespace: str = "default", - name: str = "default", + name: str | None = None, application_file: str | None = None, template_spec=None, get_logs: bool = True, @@ -103,7 +102,6 @@ def __init__( self.code_path = code_path self.application_file = application_file self.template_spec = template_spec - self.name = self.create_job_name() self.kubernetes_conn_id = kubernetes_conn_id self.startup_timeout_seconds = startup_timeout_seconds self.reattach_on_restart = reattach_on_restart @@ -161,8 +159,13 @@ def manage_template_specs(self): return template_body def create_job_name(self): - initial_name = add_unique_suffix(name=self.task_id, max_len=MAX_LABEL_LEN) - return re.sub(r"[^a-z0-9-]+", "-", initial_name.lower()) + name = ( + self.name or self.template_body.get("spark", {}).get("metadata", {}).get("name") or self.task_id + ) + + updated_name = add_unique_suffix(name=name, max_len=MAX_LABEL_LEN) + + return self._set_name(updated_name) @staticmethod def _get_pod_identifying_label_string(labels) -> str: @@ -282,6 +285,8 @@ def custom_obj_api(self) -> CustomObjectsApi: return CustomObjectsApi() def execute(self, context: Context): + self.name = self.create_job_name() + self.log.info("Creating sparkApplication.") self.launcher = CustomObjectLauncher( name=self.name, diff --git a/tests/providers/cncf/kubernetes/data_files/spark/application_test_with_no_name_from_config.json b/tests/providers/cncf/kubernetes/data_files/spark/application_test_with_no_name_from_config.json new file mode 100644 index 0000000000000..1504c40fbd1e9 --- /dev/null +++ b/tests/providers/cncf/kubernetes/data_files/spark/application_test_with_no_name_from_config.json @@ -0,0 +1,57 @@ +{ + "apiVersion":"sparkoperator.k8s.io/v1beta2", + "kind":"SparkApplication", + "metadata":{ + "namespace":"default" + }, + "spec":{ + "type":"Scala", + "mode":"cluster", + "image":"gcr.io/spark-operator/spark:v2.4.5", + "imagePullPolicy":"Always", + "mainClass":"org.apache.spark.examples.SparkPi", + "mainApplicationFile":"local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar", + "sparkVersion":"2.4.5", + "restartPolicy":{ + "type":"Never" + }, + "volumes":[ + { + "name":"test-volume", + "hostPath":{ + 
"path":"/tmp", + "type":"Directory" + } + } + ], + "driver":{ + "cores":1, + "coreLimit":"1200m", + "memory":"512m", + "labels":{ + "version":"2.4.5" + }, + "serviceAccount":"spark", + "volumeMounts":[ + { + "name":"test-volume", + "mountPath":"/tmp" + } + ] + }, + "executor":{ + "cores":1, + "instances":1, + "memory":"512m", + "labels":{ + "version":"2.4.5" + }, + "volumeMounts":[ + { + "name":"test-volume", + "mountPath":"/tmp" + } + ] + } + } +} diff --git a/tests/providers/cncf/kubernetes/data_files/spark/application_test_with_no_name_from_config.yaml b/tests/providers/cncf/kubernetes/data_files/spark/application_test_with_no_name_from_config.yaml new file mode 100644 index 0000000000000..91723980954ee --- /dev/null +++ b/tests/providers/cncf/kubernetes/data_files/spark/application_test_with_no_name_from_config.yaml @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +--- +apiVersion: "sparkoperator.k8s.io/v1beta2" +kind: SparkApplication +metadata: + namespace: default +spec: + type: Scala + mode: cluster + image: "gcr.io/spark-operator/spark:v2.4.5" + imagePullPolicy: Always + mainClass: org.apache.spark.examples.SparkPi + mainApplicationFile: "local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar" + sparkVersion: "2.4.5" + restartPolicy: + type: Never + volumes: + - name: "test-volume" + hostPath: + path: "/tmp" + type: Directory + driver: + cores: 1 + coreLimit: "1200m" + memory: "512m" + labels: + version: 2.4.5 + serviceAccount: spark + volumeMounts: + - name: "test-volume" + mountPath: "/tmp" + executor: + cores: 1 + instances: 1 + memory: "512m" + labels: + version: 2.4.5 + volumeMounts: + - name: "test-volume" + mountPath: "/tmp" diff --git a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py index bc8404aa85607..9c8c40de6558d 100644 --- a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py +++ b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py @@ -273,6 +273,149 @@ def test_create_application_from_yaml_json( version="v1beta2", ) + def test_create_application_from_yaml_json_and_use_name_from_metadata( + self, + mock_create_namespaced_crd, + mock_get_namespaced_custom_object_status, + mock_cleanup, + mock_create_job_name, + mock_get_kube_client, + mock_create_pod, + mock_await_pod_start, + mock_await_pod_completion, + mock_fetch_requested_container_logs, + data_file, + ): + op = SparkKubernetesOperator( + application_file=data_file("spark/application_test.yaml").as_posix(), + kubernetes_conn_id="kubernetes_default_kube_config", + task_id="create_app_and_use_name_from_metadata", + ) + context = create_context(op) + op.execute(context) + TEST_APPLICATION_DICT["metadata"]["name"] = op.name + 
mock_create_namespaced_crd.assert_called_with( + body=TEST_APPLICATION_DICT, + group="sparkoperator.k8s.io", + namespace="default", + plural="sparkapplications", + version="v1beta2", + ) + assert op.name.startswith("default_yaml") + + op = SparkKubernetesOperator( + application_file=data_file("spark/application_test.json").as_posix(), + kubernetes_conn_id="kubernetes_default_kube_config", + task_id="create_app_and_use_name_from_metadata", + ) + context = create_context(op) + op.execute(context) + TEST_APPLICATION_DICT["metadata"]["name"] = op.name + mock_create_namespaced_crd.assert_called_with( + body=TEST_APPLICATION_DICT, + group="sparkoperator.k8s.io", + namespace="default", + plural="sparkapplications", + version="v1beta2", + ) + assert op.name.startswith("default_json") + + def test_create_application_from_yaml_json_and_use_name_from_operator_args( + self, + mock_create_namespaced_crd, + mock_get_namespaced_custom_object_status, + mock_cleanup, + mock_create_job_name, + mock_get_kube_client, + mock_create_pod, + mock_await_pod_start, + mock_await_pod_completion, + mock_fetch_requested_container_logs, + data_file, + ): + op = SparkKubernetesOperator( + application_file=data_file("spark/application_test.yaml").as_posix(), + kubernetes_conn_id="kubernetes_default_kube_config", + task_id="default_yaml", + name="test-spark", + ) + context = create_context(op) + op.execute(context) + TEST_APPLICATION_DICT["metadata"]["name"] = op.name + mock_create_namespaced_crd.assert_called_with( + body=TEST_APPLICATION_DICT, + group="sparkoperator.k8s.io", + namespace="default", + plural="sparkapplications", + version="v1beta2", + ) + assert op.name.startswith("test-spark") + + op = SparkKubernetesOperator( + application_file=data_file("spark/application_test.json").as_posix(), + kubernetes_conn_id="kubernetes_default_kube_config", + task_id="default_json", + name="test-spark", + ) + context = create_context(op) + op.execute(context) + TEST_APPLICATION_DICT["metadata"]["name"] = op.name + mock_create_namespaced_crd.assert_called_with( + body=TEST_APPLICATION_DICT, + group="sparkoperator.k8s.io", + namespace="default", + plural="sparkapplications", + version="v1beta2", + ) + assert op.name.startswith("test-spark") + + def test_create_application_from_yaml_json_and_use_name_task_id( + self, + mock_create_namespaced_crd, + mock_get_namespaced_custom_object_status, + mock_cleanup, + mock_create_job_name, + mock_get_kube_client, + mock_create_pod, + mock_await_pod_start, + mock_await_pod_completion, + mock_fetch_requested_container_logs, + data_file, + ): + op = SparkKubernetesOperator( + application_file=data_file("spark/application_test_with_no_name_from_config.yaml").as_posix(), + kubernetes_conn_id="kubernetes_default_kube_config", + task_id="create_app_and_use_name_from_task_id", + ) + context = create_context(op) + op.execute(context) + TEST_APPLICATION_DICT["metadata"]["name"] = op.name + mock_create_namespaced_crd.assert_called_with( + body=TEST_APPLICATION_DICT, + group="sparkoperator.k8s.io", + namespace="default", + plural="sparkapplications", + version="v1beta2", + ) + assert op.name.startswith("create_app_and_use_name_from_task_id") + + op = SparkKubernetesOperator( + application_file=data_file("spark/application_test_with_no_name_from_config.json").as_posix(), + kubernetes_conn_id="kubernetes_default_kube_config", + task_id="create_app_and_use_name_from_task_id", + ) + context = create_context(op) + op.execute(context) + TEST_APPLICATION_DICT["metadata"]["name"] = op.name + 
mock_create_namespaced_crd.assert_called_with( + body=TEST_APPLICATION_DICT, + group="sparkoperator.k8s.io", + namespace="default", + plural="sparkapplications", + version="v1beta2", + ) + assert op.name.startswith("create_app_and_use_name_from_task_id") + def test_new_template_from_yaml( self, mock_create_namespaced_crd, From 46ba84f6dcda5201eafbf2dbe2e63c10c3e5944a Mon Sep 17 00:00:00 2001 From: dan-js <50807588+dan-js@users.noreply.github.com> Date: Fri, 27 Sep 2024 13:56:48 +0100 Subject: [PATCH 204/349] Fix incorrect operator name in FileTransferOperator example (#42543) --- docs/apache-airflow-providers-common-io/operators.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/apache-airflow-providers-common-io/operators.rst b/docs/apache-airflow-providers-common-io/operators.rst index 7170c18980238..12b4a1c207ff0 100644 --- a/docs/apache-airflow-providers-common-io/operators.rst +++ b/docs/apache-airflow-providers-common-io/operators.rst @@ -38,7 +38,7 @@ location to another. Parameters of the operator are: If the ``src`` and the ``dst`` are both on the same object storage, copy will be performed in the object storage. Otherwise the data will be streamed from the source to the destination. -The example below shows how to instantiate the SQLExecuteQueryOperator task. +The example below shows how to instantiate the FileTransferOperator task. .. exampleinclude:: /../../tests/system/providers/common/io/example_file_transfer_local_to_s3.py :language: python From 1f8b99b8817ddc9acb899617c479ed5eb5ad312b Mon Sep 17 00:00:00 2001 From: GPK Date: Fri, 27 Sep 2024 14:24:42 +0100 Subject: [PATCH 205/349] Pre commit script to validate template fields (#42284) --- .pre-commit-config.yaml | 7 + contributing-docs/08_static_code_checks.rst | 2 + .../doc/images/output_static-checks.svg | 12 +- .../doc/images/output_static-checks.txt | 2 +- .../src/airflow_breeze/pre_commit_ids.py | 1 + .../pre_commit/check_provider_yaml_files.py | 15 +- .../ci/pre_commit/check_template_fields.py | 40 ++++ .../ci/pre_commit/common_precommit_utils.py | 17 ++ scripts/ci/pre_commit/migration_reference.py | 14 +- scripts/ci/pre_commit/update_er_diagram.py | 13 +- .../ci/pre_commit/update_fastapi_api_spec.py | 13 +- .../in_container/run_template_fields_check.py | 180 ++++++++++++++++++ 12 files changed, 279 insertions(+), 37 deletions(-) create mode 100755 scripts/ci/pre_commit/check_template_fields.py create mode 100644 scripts/in_container/run_template_fields_check.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 942b34ca2e6d5..2263086335bc2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1343,6 +1343,13 @@ repos: files: ^airflow/providers/.*/provider\.yaml$ additional_dependencies: ['rich>=12.4.4'] require_serial: true + - id: check-template-fields-valid + name: Check templated fields mapped in operators/sensors + language: python + entry: ./scripts/ci/pre_commit/check_template_fields.py + files: ^airflow/.*/sensors/.*\.py$|^airflow/.*/operators/.*\.py$ + additional_dependencies: [ 'rich>=12.4.4' ] + require_serial: true - id: update-migration-references name: Update migration ref doc language: python diff --git a/contributing-docs/08_static_code_checks.rst b/contributing-docs/08_static_code_checks.rst index 0a3dcacd9e070..d50b9db3e607f 100644 --- a/contributing-docs/08_static_code_checks.rst +++ b/contributing-docs/08_static_code_checks.rst @@ -236,6 +236,8 @@ require Breeze Docker image to be built locally. 
+-----------------------------------------------------------+--------------------------------------------------------+---------+ | check-template-context-variable-in-sync | Sync template context variable refs | | +-----------------------------------------------------------+--------------------------------------------------------+---------+ +| check-template-fields-valid | Check templated fields mapped in operators/sensors | * | ++-----------------------------------------------------------+--------------------------------------------------------+---------+ | check-tests-in-the-right-folders | Check if tests are in the right folders | | +-----------------------------------------------------------+--------------------------------------------------------+---------+ | check-tests-unittest-testcase | Unit tests do not inherit from unittest.TestCase | | diff --git a/dev/breeze/doc/images/output_static-checks.svg b/dev/breeze/doc/images/output_static-checks.svg index ed52a596def64..36b88513a56ba 100644 --- a/dev/breeze/doc/images/output_static-checks.svg +++ b/dev/breeze/doc/images/output_static-checks.svg @@ -356,12 +356,12 @@ check-safe-filter-usage-in-html | check-sql-dependency-common-data-structure |    check-start-date-not-used-in-defaults | check-system-tests-present |              check-system-tests-tocs | check-taskinstance-tis-attrs |                          -check-template-context-variable-in-sync | check-tests-in-the-right-folders |      -check-tests-unittest-testcase | check-urlparse-usage-in-code |                    -check-usage-of-re2-over-re | check-xml | codespell | compile-ui-assets |          -compile-ui-assets-dev | compile-www-assets | compile-www-assets-dev |             -create-missing-init-py-files-tests | debug-statements | detect-private-key |      -doctoc | end-of-file-fixer | fix-encoding-pragma | flynt |                        +check-template-context-variable-in-sync | check-template-fields-valid |           +check-tests-in-the-right-folders | check-tests-unittest-testcase |                +check-urlparse-usage-in-code | check-usage-of-re2-over-re | check-xml | codespell +| compile-ui-assets | compile-ui-assets-dev | compile-www-assets |                +compile-www-assets-dev | create-missing-init-py-files-tests | debug-statements |  +detect-private-key | doctoc | end-of-file-fixer | fix-encoding-pragma | flynt |   generate-airflow-diagrams | generate-openapi-spec | generate-pypi-readme |        identity | insert-license | kubeconform | lint-chart-schema | lint-css |          lint-dockerfile | lint-helm-chart | lint-json-schema | lint-markdown |            diff --git a/dev/breeze/doc/images/output_static-checks.txt b/dev/breeze/doc/images/output_static-checks.txt index 3a3837fbb15bb..9e3ae46130640 100644 --- a/dev/breeze/doc/images/output_static-checks.txt +++ b/dev/breeze/doc/images/output_static-checks.txt @@ -1 +1 @@ -5c6ba60b1865538bce04fc940cd240c6 +e33cdf5f43d8c63290e44e92dc19d2c4 diff --git a/dev/breeze/src/airflow_breeze/pre_commit_ids.py b/dev/breeze/src/airflow_breeze/pre_commit_ids.py index 9a48df5e3f69c..457379f5b90ba 100644 --- a/dev/breeze/src/airflow_breeze/pre_commit_ids.py +++ b/dev/breeze/src/airflow_breeze/pre_commit_ids.py @@ -83,6 +83,7 @@ "check-system-tests-tocs", "check-taskinstance-tis-attrs", "check-template-context-variable-in-sync", + "check-template-fields-valid", "check-tests-in-the-right-folders", "check-tests-unittest-testcase", "check-urlparse-usage-in-code", diff --git a/scripts/ci/pre_commit/check_provider_yaml_files.py 
b/scripts/ci/pre_commit/check_provider_yaml_files.py index fcbe2512910a3..f848e38afa0b2 100755 --- a/scripts/ci/pre_commit/check_provider_yaml_files.py +++ b/scripts/ci/pre_commit/check_provider_yaml_files.py @@ -17,12 +17,15 @@ # under the License. from __future__ import annotations -import os import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.resolve())) -from common_precommit_utils import console, initialize_breeze_precommit, run_command_via_breeze_shell +from common_precommit_utils import ( + initialize_breeze_precommit, + run_command_via_breeze_shell, + validate_cmd_result, +) initialize_breeze_precommit(__name__, __file__) @@ -33,10 +36,4 @@ warn_image_upgrade_needed=True, extra_env={"PYTHONWARNINGS": "default"}, ) -if cmd_result.returncode != 0 and os.environ.get("CI") != "true": - console.print( - "\n[yellow]If you see strange stacktraces above, especially about missing imports " - "run this command:[/]\n" - ) - console.print("[magenta]breeze ci-image build --python 3.8 --upgrade-to-newer-dependencies[/]\n") -sys.exit(cmd_result.returncode) +validate_cmd_result(cmd_result, include_ci_env_check=True) diff --git a/scripts/ci/pre_commit/check_template_fields.py b/scripts/ci/pre_commit/check_template_fields.py new file mode 100755 index 0000000000000..da0b60fbd978f --- /dev/null +++ b/scripts/ci/pre_commit/check_template_fields.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.resolve())) +from common_precommit_utils import ( + initialize_breeze_precommit, + run_command_via_breeze_shell, + validate_cmd_result, +) + +initialize_breeze_precommit(__name__, __file__) +py_files_to_test = sys.argv[1:] + +cmd_result = run_command_via_breeze_shell( + ["python3", "/opt/airflow/scripts/in_container/run_template_fields_check.py", *py_files_to_test], + backend="sqlite", + warn_image_upgrade_needed=True, + extra_env={"PYTHONWARNINGS": "default"}, +) + +validate_cmd_result(cmd_result, include_ci_env_check=True) diff --git a/scripts/ci/pre_commit/common_precommit_utils.py b/scripts/ci/pre_commit/common_precommit_utils.py index 41bc3a5eeaf93..4f62c50cabeaa 100644 --- a/scripts/ci/pre_commit/common_precommit_utils.py +++ b/scripts/ci/pre_commit/common_precommit_utils.py @@ -211,3 +211,20 @@ def check_list_sorted(the_list: list[str], message: str, errors: list[str]) -> b console.print() errors.append(f"ERROR in {message}. 
The elements are not sorted/unique.") return False + + +def validate_cmd_result(cmd_result, include_ci_env_check=False): + if include_ci_env_check: + if cmd_result.returncode != 0 and os.environ.get("CI") != "true": + console.print( + "\n[yellow]If you see strange stacktraces above, especially about missing imports " + "run this command:[/]\n" + ) + console.print("[magenta]breeze ci-image build --python 3.8 --upgrade-to-newer-dependencies[/]\n") + + elif cmd_result.returncode != 0: + console.print( + "[warning]\nIf you see strange stacktraces above, " + "run `breeze ci-image build --python 3.8` and try again." + ) + sys.exit(cmd_result.returncode) diff --git a/scripts/ci/pre_commit/migration_reference.py b/scripts/ci/pre_commit/migration_reference.py index 34d3a94c6a90d..505bea5ca91af 100755 --- a/scripts/ci/pre_commit/migration_reference.py +++ b/scripts/ci/pre_commit/migration_reference.py @@ -21,7 +21,11 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.resolve())) -from common_precommit_utils import console, initialize_breeze_precommit, run_command_via_breeze_shell +from common_precommit_utils import ( + initialize_breeze_precommit, + run_command_via_breeze_shell, + validate_cmd_result, +) initialize_breeze_precommit(__name__, __file__) @@ -29,9 +33,5 @@ ["python3", "/opt/airflow/scripts/in_container/run_migration_reference.py"], backend="sqlite", ) -if cmd_result.returncode != 0: - console.print( - "[warning]\nIf you see strange stacktraces above, " - "run `breeze ci-image build --python 3.8` and try again." - ) -sys.exit(cmd_result.returncode) + +validate_cmd_result(cmd_result) diff --git a/scripts/ci/pre_commit/update_er_diagram.py b/scripts/ci/pre_commit/update_er_diagram.py index e660b47c6e6ae..c4f3cb797cf21 100755 --- a/scripts/ci/pre_commit/update_er_diagram.py +++ b/scripts/ci/pre_commit/update_er_diagram.py @@ -21,7 +21,11 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.resolve())) -from common_precommit_utils import console, initialize_breeze_precommit, run_command_via_breeze_shell +from common_precommit_utils import ( + initialize_breeze_precommit, + run_command_via_breeze_shell, + validate_cmd_result, +) initialize_breeze_precommit(__name__, __file__) @@ -36,9 +40,4 @@ }, ) -if cmd_result.returncode != 0: - console.print( - "[warning]\nIf you see strange stacktraces above, " - "run `breeze ci-image build --python 3.8` and try again." - ) - sys.exit(cmd_result.returncode) +validate_cmd_result(cmd_result) diff --git a/scripts/ci/pre_commit/update_fastapi_api_spec.py b/scripts/ci/pre_commit/update_fastapi_api_spec.py index 15ccaa5ac209e..3d7731c7ef2e2 100755 --- a/scripts/ci/pre_commit/update_fastapi_api_spec.py +++ b/scripts/ci/pre_commit/update_fastapi_api_spec.py @@ -21,7 +21,11 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.resolve())) -from common_precommit_utils import console, initialize_breeze_precommit, run_command_via_breeze_shell +from common_precommit_utils import ( + initialize_breeze_precommit, + run_command_via_breeze_shell, + validate_cmd_result, +) initialize_breeze_precommit(__name__, __file__) @@ -31,9 +35,4 @@ skip_environment_initialization=False, ) -if cmd_result.returncode != 0: - console.print( - "[warning]\nIf you see strange stacktraces above, " - "run `breeze ci-image build --python 3.8` and try again." 
- ) -sys.exit(cmd_result.returncode) +validate_cmd_result(cmd_result) diff --git a/scripts/in_container/run_template_fields_check.py b/scripts/in_container/run_template_fields_check.py new file mode 100644 index 0000000000000..202dce35c5745 --- /dev/null +++ b/scripts/in_container/run_template_fields_check.py @@ -0,0 +1,180 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import ast +import importlib.util +import inspect +import itertools +import pathlib +import sys +import warnings + +import yaml +from rich.console import Console + +try: + from yaml import CSafeLoader as SafeLoader +except ImportError: + from yaml import SafeLoader # type: ignore + +console = Console(width=400, color_system="standard") +ROOT_DIR = pathlib.Path(__file__).resolve().parents[2] + +provider_files_pattern = pathlib.Path(ROOT_DIR, "airflow", "providers").rglob("provider.yaml") +errors: list[str] = [] + +OPERATORS: list[str] = ["sensors", "operators"] +CLASS_IDENTIFIERS: list[str] = ["sensor", "operator"] + +TEMPLATE_TYPES: list[str] = ["template_fields"] + + +class InstanceFieldExtractor(ast.NodeVisitor): + def __init__(self): + self.current_class = None + self.instance_fields = [] + + def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: + if node.name == "__init__": + self.generic_visit(node) + return node + + def visit_Assign(self, node: ast.Assign) -> ast.Assign: + fields = [] + for target in node.targets: + if isinstance(target, ast.Attribute): + fields.append(target.attr) + if fields: + self.instance_fields.extend(fields) + return node + + def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign: + if isinstance(node.target, ast.Attribute): + self.instance_fields.append(node.target.attr) + return node + + +def get_template_fields_and_class_instance_fields(cls): + """ + 1.This method retrieves the operator class and obtains all its parent classes using the method resolution order (MRO). + 2. It then gathers the templated fields declared in both the operator class and its parent classes. + 3. Finally, it retrieves the instance fields of the operator class, specifically the self.fields attributes. 
+ """ + all_template_fields = [] + class_instance_fields = [] + + all_classes = cls.__mro__ + for current_class in all_classes: + if current_class.__init__ is not object.__init__: + cls_attr = current_class.__dict__ + for template_type in TEMPLATE_TYPES: + fields = cls_attr.get(template_type) + if fields: + all_template_fields.extend(fields) + + tree = ast.parse(inspect.getsource(current_class)) + visitor = InstanceFieldExtractor() + visitor.visit(tree) + if visitor.instance_fields: + class_instance_fields.extend(visitor.instance_fields) + return all_template_fields, class_instance_fields + + +def load_yaml_data() -> dict: + """ + It loads all the provider YAML files and retrieves the module referenced within each YAML file. + """ + package_paths = sorted(str(path) for path in provider_files_pattern) + result = {} + for provider_yaml_path in package_paths: + with open(provider_yaml_path) as yaml_file: + provider = yaml.load(yaml_file, SafeLoader) + rel_path = pathlib.Path(provider_yaml_path).relative_to(ROOT_DIR).as_posix() + result[rel_path] = provider + return result + + +def get_providers_modules() -> list[str]: + modules_container = [] + result = load_yaml_data() + + for (_, provider_data), resource_type in itertools.product(result.items(), OPERATORS): + if provider_data.get(resource_type): + for data in provider_data.get(resource_type): + modules_container.extend(data.get("python-modules")) + + return modules_container + + +def is_class_eligible(name: str) -> bool: + for op in CLASS_IDENTIFIERS: + if name.lower().endswith(op): + return True + return False + + +def get_eligible_classes(all_classes): + """ + Filter the results to include only classes that end with `Sensor` or `Operator`. + + """ + + eligible_classes = [(name, cls) for name, cls in all_classes if is_class_eligible(name)] + return eligible_classes + + +def iter_check_template_fields(module: str): + """ + 1. This method imports the providers module and retrieves all the classes defined within it. + 2. It then filters and selects classes related to operators or sensors by checking if the class name ends with "Operator" or "Sensor." + 3. For each operator class, it validates the template fields by inspecting the class instance fields. 
+ """ + with warnings.catch_warnings(record=True): + imported_module = importlib.import_module(module) + classes = inspect.getmembers(imported_module, inspect.isclass) + op_classes = get_eligible_classes(classes) + + for op_class_name, cls in op_classes: + if cls.__module__ == module: + templated_fields, class_instance_fields = get_template_fields_and_class_instance_fields(cls) + + for field in templated_fields: + if field not in class_instance_fields: + errors.append(f"{module}: {op_class_name}: {field}") + + +if __name__ == "__main__": + provider_modules = get_providers_modules() + + if len(sys.argv) > 1: + py_files = sorted(sys.argv[1:]) + modules_to_validate = [ + module_name + for pyfile in py_files + if (module_name := pyfile.rstrip(".py").replace("/", ".")) in provider_modules + ] + else: + modules_to_validate = provider_modules + + [iter_check_template_fields(module) for module in modules_to_validate] + if errors: + console.print("[red]Found Invalid template fields:") + for error in errors: + console.print(f"[red]Error:[/] {error}") + + sys.exit(len(errors)) From 63dcb24767e081513962f15a248c63a0a374ce16 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Fri, 27 Sep 2024 13:25:19 -0700 Subject: [PATCH 206/349] Remove DagRun.is_backfill attribute (#42548) This attribute is only used in one place and is not very useful. --- airflow/models/dagrun.py | 6 +----- newsfragments/42548.significant.rst | 1 + tests/jobs/test_scheduler_job.py | 5 ++--- 3 files changed, 4 insertions(+), 8 deletions(-) create mode 100644 newsfragments/42548.significant.rst diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 3ef1c18f152a4..5d53e51763dff 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -1267,7 +1267,7 @@ def verify_integrity(self, *, session: Session = NEW_SESSION) -> None: def task_filter(task: Operator) -> bool: return task.task_id not in task_ids and ( - self.is_backfill + self.run_type == DagRunType.BACKFILL_JOB or (task.start_date is None or task.start_date <= self.execution_date) and (task.end_date is None or self.execution_date <= task.end_date) ) @@ -1538,10 +1538,6 @@ def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) -> session.flush() yield ti - @property - def is_backfill(self) -> bool: - return self.run_type == DagRunType.BACKFILL_JOB - @classmethod @provide_session def get_latest_runs(cls, session: Session = NEW_SESSION) -> list[DagRun]: diff --git a/newsfragments/42548.significant.rst b/newsfragments/42548.significant.rst new file mode 100644 index 0000000000000..28d6795eebcc6 --- /dev/null +++ b/newsfragments/42548.significant.rst @@ -0,0 +1 @@ +Remove is_backfill attribute from DagRun object diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 52e9dbdeb1a04..78a911153dab2 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -601,8 +601,7 @@ def test_execute_task_instances_backfill_tasks_wont_execute(self, dag_maker): ti1.state = State.SCHEDULED session.merge(ti1) session.flush() - - assert dr1.is_backfill + assert dr1.run_type == DagRunType.BACKFILL_JOB self.job_runner._critical_section_enqueue_task_instances(session) session.flush() @@ -3851,7 +3850,7 @@ def test_adopt_or_reset_orphaned_tasks_backfill_dag(self, dag_maker): session.merge(dr1) session.flush() - assert dr1.is_backfill + assert dr1.run_type == DagRunType.BACKFILL_JOB assert 0 == 
self.job_runner.adopt_or_reset_orphaned_tasks(session=session) session.rollback() From cfc374bbba359ca54db2287bccad2b415446e557 Mon Sep 17 00:00:00 2001 From: Jens Scheffler <95105677+jscheffl@users.noreply.github.com> Date: Sat, 28 Sep 2024 09:34:19 +0200 Subject: [PATCH 207/349] Attempt to correct dependency for Slack notification for canary build (#42551) --- .github/workflows/ci.yml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8625aee73d9e1..8828a30ce3ecd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -673,6 +673,16 @@ jobs: notify-slack-failure: name: "Notify Slack on Failure" + needs: + - basic-tests + - additional-ci-image-checks + - providers + - tests-helm + - tests-special + - tests-with-lowest-direct-resolution + - additional-prod-image-tests + - tests-kubernetes + - finalize-tests if: github.event_name == 'schedule' && failure() runs-on: ["ubuntu-22.04"] steps: From d12b835562dabd2c2e4d060b77315df94c706d6f Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Sat, 28 Sep 2024 15:41:38 +0100 Subject: [PATCH 208/349] Airflow 2.10.2 has been released (#42405) --- .github/ISSUE_TEMPLATE/airflow_bug_report.yml | 2 +- Dockerfile | 2 +- README.md | 12 +++--- RELEASE_NOTES.rst | 40 ++++++++++++++++++- airflow/reproducible_build.yaml | 4 +- .../installation/supported-versions.rst | 2 +- generated/PYPI_README.md | 10 ++--- scripts/ci/pre_commit/supported_versions.py | 2 +- 8 files changed, 55 insertions(+), 19 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/airflow_bug_report.yml b/.github/ISSUE_TEMPLATE/airflow_bug_report.yml index 853b102ef07f8..f835c879f8380 100644 --- a/.github/ISSUE_TEMPLATE/airflow_bug_report.yml +++ b/.github/ISSUE_TEMPLATE/airflow_bug_report.yml @@ -25,7 +25,7 @@ body: the latest release or main to see if the issue is fixed before reporting it. multiple: false options: - - "2.10.1" + - "2.10.2" - "main (development)" - "Other Airflow 2 version (please specify below)" validations: diff --git a/Dockerfile b/Dockerfile index 68f1ed166f12a..3053e0779540d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -45,7 +45,7 @@ ARG AIRFLOW_UID="50000" ARG AIRFLOW_USER_HOME_DIR=/home/airflow # latest released version here -ARG AIRFLOW_VERSION="2.10.1" +ARG AIRFLOW_VERSION="2.10.2" ARG PYTHON_BASE_IMAGE="python:3.8-slim-bookworm" diff --git a/README.md b/README.md index 91ddf5e927245..3169ac5144844 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ Airflow is not a streaming solution, but it is often used to process real-time d Apache Airflow is tested with: -| | Main version (dev) | Stable version (2.10.1) | +| | Main version (dev) | Stable version (2.10.2) | |------------|----------------------------|----------------------------| | Python | 3.8, 3.9, 3.10, 3.11, 3.12 | 3.8, 3.9, 3.10, 3.11, 3.12 | | Platform | AMD64/ARM64(\*) | AMD64/ARM64(\*) | @@ -177,15 +177,15 @@ them to the appropriate format and workflow that your tool requires. ```bash -pip install 'apache-airflow==2.10.1' \ - --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.1/constraints-3.8.txt" +pip install 'apache-airflow==2.10.2' \ + --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.2/constraints-3.8.txt" ``` 2. 
Installing with extras (i.e., postgres, google) ```bash -pip install 'apache-airflow[postgres,google]==2.10.1' \ - --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.1/constraints-3.8.txt" +pip install 'apache-airflow[postgres,google]==2.10.2' \ + --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.2/constraints-3.8.txt" ``` For information on installing provider packages, check @@ -290,7 +290,7 @@ Apache Airflow version life cycle: | Version | Current Patch/Minor | State | First Release | Limited Support | EOL/Terminated | |-----------|-----------------------|-----------|-----------------|-------------------|------------------| -| 2 | 2.10.1 | Supported | Dec 17, 2020 | TBD | TBD | +| 2 | 2.10.2 | Supported | Dec 17, 2020 | TBD | TBD | | 1.10 | 1.10.15 | EOL | Aug 27, 2018 | Dec 17, 2020 | June 17, 2021 | | 1.9 | 1.9.0 | EOL | Jan 03, 2018 | Aug 27, 2018 | Aug 27, 2018 | | 1.8 | 1.8.2 | EOL | Mar 19, 2017 | Jan 03, 2018 | Jan 03, 2018 | diff --git a/RELEASE_NOTES.rst b/RELEASE_NOTES.rst index d42074b1146a7..6c84e45d8aca0 100644 --- a/RELEASE_NOTES.rst +++ b/RELEASE_NOTES.rst @@ -21,6 +21,43 @@ .. towncrier release notes start +Airflow 2.10.2 (2024-09-18) +--------------------------- + +Significant Changes +^^^^^^^^^^^^^^^^^^^ + +No significant changes. + +Bug Fixes +""""""""" +- Revert "Fix: DAGs are not marked as stale if the dags folder change" (#42220, #42217) +- Add missing open telemetry span and correct scheduled slots documentation (#41985) +- Fix require_confirmation_dag_change (#42063) (#42211) +- Only treat null/undefined as falsy when rendering XComEntry (#42199) (#42213) +- Add extra and ``renderedTemplates`` as keys to skip ``camelCasing`` (#42206) (#42208) +- Do not ``camelcase`` xcom entries (#42182) (#42187) +- Fix task_instance and dag_run links from list views (#42138) (#42143) +- Support multi-line input for Params of type string in trigger UI form (#40414) (#42139) +- Fix details tab log url detection (#42104) (#42114) +- Add new type of exception to catch timeout (#42064) (#42078) +- Rewrite how DAG to dataset / dataset alias are stored (#41987) (#42055) +- Allow dataset alias to add more than one dataset events (#42189) (#42247) + +Miscellaneous +""""""""""""" +- Limit universal-pathlib below ``0.2.4`` as it breaks our integration (#42101) +- Auto-fix default deferrable with ``LibCST`` (#42089) +- Deprecate ``--tree`` flag for ``tasks list`` cli command (#41965) + +Doc Only Changes +"""""""""""""""" +- Update ``security_model.rst`` to clear unauthenticated endpoints exceptions (#42085) +- Add note about dataclasses and attrs to XComs page (#42056) +- Improve docs on markdown docs in DAGs (#42013) +- Add warning that listeners can be dangerous (#41968) + + Airflow 2.10.1 (2024-09-05) --------------------------- @@ -38,7 +75,7 @@ Bug Fixes - Fix compatibility with FAB provider versions <1.3.0 (#41809) - Don't Fail LocalTaskJob on heartbeat (#41810) - Remove deprecation warning for cgitb in Plugins Manager (#41793) -- Fix log for notifier(instance) without __name__ (#41699) +- Fix log for notifier(instance) without ``__name__`` (#41699) - Splitting syspath preparation into stages (#41694) - Adding url sanitization for extra links (#41680) - Fix InletEventsAccessors type stub (#41607) @@ -64,7 +101,6 @@ Doc Only Changes - Add an example for auth with ``keycloak`` (#41791) - Airflow 2.10.0 (2024-08-15) --------------------------- diff --git a/airflow/reproducible_build.yaml b/airflow/reproducible_build.yaml index 
31e63fbce742b..1bf308b87a705 100644 --- a/airflow/reproducible_build.yaml +++ b/airflow/reproducible_build.yaml @@ -1,2 +1,2 @@ -release-notes-hash: aa948d55b0b6062659dbcd0293d73838 -source-date-epoch: 1725624671 +release-notes-hash: 828fa8d5e93e215963c0a3e52e7f1e3d +source-date-epoch: 1727075869 diff --git a/docs/apache-airflow/installation/supported-versions.rst b/docs/apache-airflow/installation/supported-versions.rst index 0a7694abbda3d..d82500728ce3b 100644 --- a/docs/apache-airflow/installation/supported-versions.rst +++ b/docs/apache-airflow/installation/supported-versions.rst @@ -29,7 +29,7 @@ Apache Airflow® version life cycle: ========= ===================== ========= =============== ================= ================ Version Current Patch/Minor State First Release Limited Support EOL/Terminated ========= ===================== ========= =============== ================= ================ -2 2.10.1 Supported Dec 17, 2020 TBD TBD +2 2.10.2 Supported Dec 17, 2020 TBD TBD 1.10 1.10.15 EOL Aug 27, 2018 Dec 17, 2020 June 17, 2021 1.9 1.9.0 EOL Jan 03, 2018 Aug 27, 2018 Aug 27, 2018 1.8 1.8.2 EOL Mar 19, 2017 Jan 03, 2018 Jan 03, 2018 diff --git a/generated/PYPI_README.md b/generated/PYPI_README.md index 2b80e73a45f5e..50802f301b753 100644 --- a/generated/PYPI_README.md +++ b/generated/PYPI_README.md @@ -54,7 +54,7 @@ Use Airflow to author workflows as directed acyclic graphs (DAGs) of tasks. The Apache Airflow is tested with: -| | Main version (dev) | Stable version (2.10.1) | +| | Main version (dev) | Stable version (2.10.2) | |------------|----------------------------|----------------------------| | Python | 3.8, 3.9, 3.10, 3.11, 3.12 | 3.8, 3.9, 3.10, 3.11, 3.12 | | Platform | AMD64/ARM64(\*) | AMD64/ARM64(\*) | @@ -130,15 +130,15 @@ them to the appropriate format and workflow that your tool requires. ```bash -pip install 'apache-airflow==2.10.1' \ - --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.1/constraints-3.8.txt" +pip install 'apache-airflow==2.10.2' \ + --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.2/constraints-3.8.txt" ``` 2. 
Installing with extras (i.e., postgres, google) ```bash -pip install 'apache-airflow[postgres,google]==2.10.1' \ - --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.1/constraints-3.8.txt" +pip install 'apache-airflow[postgres,google]==2.10.2' \ + --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.2/constraints-3.8.txt" ``` For information on installing provider packages, check diff --git a/scripts/ci/pre_commit/supported_versions.py b/scripts/ci/pre_commit/supported_versions.py index b392eaf6d4e01..ab8204ab03baa 100755 --- a/scripts/ci/pre_commit/supported_versions.py +++ b/scripts/ci/pre_commit/supported_versions.py @@ -27,7 +27,7 @@ HEADERS = ("Version", "Current Patch/Minor", "State", "First Release", "Limited Support", "EOL/Terminated") SUPPORTED_VERSIONS = ( - ("2", "2.10.1", "Supported", "Dec 17, 2020", "TBD", "TBD"), + ("2", "2.10.2", "Supported", "Dec 17, 2020", "TBD", "TBD"), ("1.10", "1.10.15", "EOL", "Aug 27, 2018", "Dec 17, 2020", "June 17, 2021"), ("1.9", "1.9.0", "EOL", "Jan 03, 2018", "Aug 27, 2018", "Aug 27, 2018"), ("1.8", "1.8.2", "EOL", "Mar 19, 2017", "Jan 03, 2018", "Jan 03, 2018"), From a4f8faabdbe817c57a30df20f7ca10f08115c496 Mon Sep 17 00:00:00 2001 From: GPK Date: Sun, 29 Sep 2024 12:40:37 +0100 Subject: [PATCH 209/349] remove callable functions parameter from kafka operator template_fields (#42555) --- airflow/providers/apache/kafka/operators/consume.py | 1 - airflow/providers/apache/kafka/operators/produce.py | 1 - 2 files changed, 2 deletions(-) diff --git a/airflow/providers/apache/kafka/operators/consume.py b/airflow/providers/apache/kafka/operators/consume.py index 91d2f4f052daf..377b58a46df5e 100644 --- a/airflow/providers/apache/kafka/operators/consume.py +++ b/airflow/providers/apache/kafka/operators/consume.py @@ -68,7 +68,6 @@ class ConsumeFromTopicOperator(BaseOperator): ui_color = BLUE template_fields = ( "topics", - "apply_function", "apply_function_args", "apply_function_kwargs", "kafka_config_id", diff --git a/airflow/providers/apache/kafka/operators/produce.py b/airflow/providers/apache/kafka/operators/produce.py index 04090811b9ec7..e0623128a1f7d 100644 --- a/airflow/providers/apache/kafka/operators/produce.py +++ b/airflow/providers/apache/kafka/operators/produce.py @@ -67,7 +67,6 @@ class ProduceToTopicOperator(BaseOperator): template_fields = ( "topic", - "producer_function", "producer_function_args", "producer_function_kwargs", "kafka_config_id", From 18bc08f9910b174d1fdad51dc603f72978ca1463 Mon Sep 17 00:00:00 2001 From: Danny Liu Date: Sun, 29 Sep 2024 04:42:01 -0700 Subject: [PATCH 210/349] fix PyDocStyle checks (#42557) --- tests/www/views/test_views_task_norun.py | 2 +- tests/www/views/test_views_tasks.py | 4 ++-- tests/www/views/test_views_trigger_dag.py | 2 +- tests/www/views/test_views_variable.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/www/views/test_views_task_norun.py b/tests/www/views/test_views_task_norun.py index a0709c4303d99..2a39b2a60134e 100644 --- a/tests/www/views/test_views_task_norun.py +++ b/tests/www/views/test_views_task_norun.py @@ -32,7 +32,7 @@ @pytest.fixture(scope="module", autouse=True) -def reset_dagruns(): +def _reset_dagruns(): """Clean up stray garbage from other tests.""" clear_db_runs() diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py index 0b52c1f9aef3c..f5cc011fb6f0e 100644 --- a/tests/www/views/test_views_tasks.py +++ b/tests/www/views/test_views_tasks.py @@ -64,13 +64,13 @@ 
@pytest.fixture(scope="module", autouse=True) -def reset_dagruns(): +def _reset_dagruns(): """Clean up stray garbage from other tests.""" clear_db_runs() @pytest.fixture(autouse=True) -def init_dagruns(app): +def _init_dagruns(app): with time_machine.travel(DEFAULT_DATE, tick=False): triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} app.dag_bag.get_dag("example_bash_operator").create_dagrun( diff --git a/tests/www/views/test_views_trigger_dag.py b/tests/www/views/test_views_trigger_dag.py index 01b6713600af3..0c9384a195f5e 100644 --- a/tests/www/views/test_views_trigger_dag.py +++ b/tests/www/views/test_views_trigger_dag.py @@ -40,7 +40,7 @@ @pytest.fixture(autouse=True) -def initialize_one_dag(): +def _initialize_one_dag(): with create_session() as session: DagBag().get_dag("example_bash_operator").sync_to_db(session=session) yield diff --git a/tests/www/views/test_views_variable.py b/tests/www/views/test_views_variable.py index fcdad2bdb0bdd..a91a12ddc470b 100644 --- a/tests/www/views/test_views_variable.py +++ b/tests/www/views/test_views_variable.py @@ -43,7 +43,7 @@ @pytest.fixture(autouse=True) -def clear_variables(): +def _clear_variables(): with create_session() as session: session.query(Variable).delete() From 47e04177d87ab5894c21befc2d82e833f44fcdad Mon Sep 17 00:00:00 2001 From: Kunal Bhattacharya Date: Sun, 29 Sep 2024 21:55:20 +0530 Subject: [PATCH 211/349] Documentation change to highlight difference in usage between params and parameters attributes in SQLExecuteQueryOperator for Postgres (#42564) * Documentation change to call out the difference in usage between params and parameters attributes * Static checks fix --- .../operators/postgres_operator_howto_guide.rst | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/apache-airflow-providers-postgres/operators/postgres_operator_howto_guide.rst b/docs/apache-airflow-providers-postgres/operators/postgres_operator_howto_guide.rst index f9dafe34196b1..09402178aa057 100644 --- a/docs/apache-airflow-providers-postgres/operators/postgres_operator_howto_guide.rst +++ b/docs/apache-airflow-providers-postgres/operators/postgres_operator_howto_guide.rst @@ -123,10 +123,11 @@ Passing Parameters into SQLExecuteQueryOperator for Postgres SQLExecuteQueryOperator provides ``parameters`` attribute which makes it possible to dynamically inject values into your SQL requests during runtime. The BaseOperator class has the ``params`` attribute which is available to the SQLExecuteQueryOperator -by virtue of inheritance. Both ``parameters`` and ``params`` make it possible to dynamically pass in parameters in many -interesting ways. +by virtue of inheritance. While both ``parameters`` and ``params`` make it possible to dynamically pass in parameters in many +interesting ways, their usage is slightly different as demonstrated in the examples below. -To find the owner of the pet called 'Lester': +To find the birth dates of all pets between two dates, when we use the SQL statements directly in our code, we will use the +``parameters`` attribute: .. code-block:: python @@ -137,16 +138,15 @@ To find the owner of the pet called 'Lester': parameters={"begin_date": "2020-01-01", "end_date": "2020-12-31"}, ) -Now lets refactor our ``get_birth_date`` task. Instead of dumping SQL statements directly into our code, let's tidy things up -by creating a sql file. +Now lets refactor our ``get_birth_date`` task. 
Now, instead of dumping SQL statements directly into our code, let's tidy things up +by creating a sql file. And this time we will use the ``params`` attribute which we get for free from the parent ``BaseOperator`` +class. :: -- dags/sql/birth_date.sql SELECT * FROM pet WHERE birth_date BETWEEN SYMMETRIC {{ params.begin_date }} AND {{ params.end_date }}; -And this time we will use the ``params`` attribute which we get for free from the parent ``BaseOperator`` -class. .. code-block:: python From 4f177f30093fc33b885743f284d91ef2202e6cc9 Mon Sep 17 00:00:00 2001 From: Jens Scheffler <95105677+jscheffl@users.noreply.github.com> Date: Sun, 29 Sep 2024 23:02:07 +0200 Subject: [PATCH 212/349] Bugfix/42575 workaround pin azure kusto data (#42576) * Workaround, pin azure-kusto-data to not 4.6.0 --- airflow/providers/microsoft/azure/provider.yaml | 3 ++- generated/provider_dependencies.json | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml index 9110a3046b5d1..45fe28eecffc7 100644 --- a/airflow/providers/microsoft/azure/provider.yaml +++ b/airflow/providers/microsoft/azure/provider.yaml @@ -101,7 +101,8 @@ dependencies: - azure-synapse-artifacts>=0.17.0 - adal>=1.2.7 - azure-storage-file-datalake>=12.9.1 - - azure-kusto-data>=4.1.0 + # azure-kusto-data 4.6.0 breaks main - see https://github.com/apache/airflow/issues/42575 + - azure-kusto-data>=4.1.0,!=4.6.0 - azure-mgmt-datafactory>=2.0.0 - azure-mgmt-containerregistry>=8.0.0 - azure-mgmt-containerinstance>=10.1.0 diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 074c5dd41e93b..b9bc363b15e33 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -804,7 +804,7 @@ "azure-datalake-store>=0.0.45", "azure-identity>=1.3.1", "azure-keyvault-secrets>=4.1.0", - "azure-kusto-data>=4.1.0", + "azure-kusto-data>=4.1.0,!=4.6.0", "azure-mgmt-containerinstance>=10.1.0", "azure-mgmt-containerregistry>=8.0.0", "azure-mgmt-cosmosdb>=3.0.0", From f05c842bcb0682413641b96b7fbf08f142a69bfb Mon Sep 17 00:00:00 2001 From: Topher Anderson <48180628+topherinternational@users.noreply.github.com> Date: Sun, 29 Sep 2024 19:21:11 -0500 Subject: [PATCH 213/349] Bump uv to 0.4.17 (#42574) --- Dockerfile | 2 +- Dockerfile.ci | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 3053e0779540d..cfb894ac87d22 100644 --- a/Dockerfile +++ b/Dockerfile @@ -50,7 +50,7 @@ ARG AIRFLOW_VERSION="2.10.2" ARG PYTHON_BASE_IMAGE="python:3.8-slim-bookworm" ARG AIRFLOW_PIP_VERSION=24.2 -ARG AIRFLOW_UV_VERSION=0.4.7 +ARG AIRFLOW_UV_VERSION=0.4.17 ARG AIRFLOW_USE_UV="false" ARG UV_HTTP_TIMEOUT="300" ARG AIRFLOW_IMAGE_REPOSITORY="https://github.com/apache/airflow" diff --git a/Dockerfile.ci b/Dockerfile.ci index ad944d151adcb..f7b7bb4172025 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -1262,7 +1262,7 @@ ARG DEFAULT_CONSTRAINTS_BRANCH="constraints-main" ARG AIRFLOW_CI_BUILD_EPOCH="10" ARG AIRFLOW_PRE_CACHED_PIP_PACKAGES="true" ARG AIRFLOW_PIP_VERSION=24.2 -ARG AIRFLOW_UV_VERSION=0.4.7 +ARG AIRFLOW_UV_VERSION=0.4.17 ARG AIRFLOW_USE_UV="true" # Setup PIP # By default PIP install run without cache to make image smaller From ea99f74d27b7b83b8e7f7b8a177dc8e0d5a8f2d7 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Sun, 29 Sep 2024 21:43:44 -0700 Subject: [PATCH 214/349] Change default .airflowignore syntax to glob (#42436) Co-authored-by: Shahar Epstein 
<60007259+shahar1@users.noreply.github.com> --- airflow/config_templates/config.yml | 2 +- airflow/utils/file.py | 6 ++-- .../modules_management.rst | 9 +----- docs/apache-airflow/core-concepts/dags.rst | 31 +++++++------------ .../howto/dynamic-dag-generation.rst | 2 +- newsfragments/42436.significant.rst | 7 +++++ tests/dags/.airflowignore | 5 ++- tests/dags/subdir1/.airflowignore | 2 +- tests/plugins/test_plugin_ignore.py | 2 +- 9 files changed, 29 insertions(+), 37 deletions(-) create mode 100644 newsfragments/42436.significant.rst diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index c9abee3c85065..7317fce60e4e6 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -310,7 +310,7 @@ core: version_added: 2.3.0 type: string example: ~ - default: "regexp" + default: "glob" default_task_retries: description: | The number of retries each task is going to have by default. Can be overridden at dag or task level. diff --git a/airflow/utils/file.py b/airflow/utils/file.py index 7081113d5bd46..2e39eb7dd7b52 100644 --- a/airflow/utils/file.py +++ b/airflow/utils/file.py @@ -221,7 +221,7 @@ def _find_path_from_directory( def find_path_from_directory( base_dir_path: str | os.PathLike[str], ignore_file_name: str, - ignore_file_syntax: str = conf.get_mandatory_value("core", "DAG_IGNORE_FILE_SYNTAX", fallback="regexp"), + ignore_file_syntax: str = conf.get_mandatory_value("core", "DAG_IGNORE_FILE_SYNTAX", fallback="glob"), ) -> Generator[str, None, None]: """ Recursively search the base path for a list of file paths that should not be ignored. @@ -232,9 +232,9 @@ def find_path_from_directory( :return: a generator of file paths. """ - if ignore_file_syntax == "glob": + if ignore_file_syntax == "glob" or not ignore_file_syntax: return _find_path_from_directory(base_dir_path, ignore_file_name, _GlobIgnoreRule) - elif ignore_file_syntax == "regexp" or not ignore_file_syntax: + elif ignore_file_syntax == "regexp": return _find_path_from_directory(base_dir_path, ignore_file_name, _RegexpIgnoreRule) else: raise ValueError(f"Unsupported ignore_file_syntax: {ignore_file_syntax}") diff --git a/docs/apache-airflow/administration-and-deployment/modules_management.rst b/docs/apache-airflow/administration-and-deployment/modules_management.rst index dc6be49b1d43d..25adb5f333c91 100644 --- a/docs/apache-airflow/administration-and-deployment/modules_management.rst +++ b/docs/apache-airflow/administration-and-deployment/modules_management.rst @@ -125,14 +125,7 @@ for the paths that should be ignored. You do not need to have that file in any o In the example above the DAGs are only in ``my_custom_dags`` folder, the ``common_package`` should not be scanned by scheduler when searching for DAGS, so we should ignore ``common_package`` folder. You also want to ignore the ``base_dag.py`` if you keep a base DAG there that ``my_dag1.py`` and ``my_dag2.py`` derives -from. Your ``.airflowignore`` should look then like this: - -.. code-block:: none - - my_company/common_package/.* - my_company/my_custom_dags/base_dag\.py - -If ``DAG_IGNORE_FILE_SYNTAX`` is set to ``glob``, the equivalent ``.airflowignore`` file would be: +from. Your ``.airflowignore`` should look then like this (using the default ``glob`` syntax): .. 
code-block:: none diff --git a/docs/apache-airflow/core-concepts/dags.rst b/docs/apache-airflow/core-concepts/dags.rst index fbef745e46d34..f9dc7d64c72e0 100644 --- a/docs/apache-airflow/core-concepts/dags.rst +++ b/docs/apache-airflow/core-concepts/dags.rst @@ -712,19 +712,9 @@ configuration parameter (*added in Airflow 2.3*): ``regexp`` and ``glob``. .. note:: - The default ``DAG_IGNORE_FILE_SYNTAX`` is ``regexp`` to ensure backwards compatibility. + The default ``DAG_IGNORE_FILE_SYNTAX`` is ``glob`` in Airflow 3 or later (in previous versions it was ``regexp``). -For the ``regexp`` pattern syntax (the default), each line in ``.airflowignore`` -specifies a regular expression pattern, and directories or files whose names (not DAG id) -match any of the patterns would be ignored (under the hood, ``Pattern.search()`` is used -to match the pattern). Use the ``#`` character to indicate a comment; all characters -on lines starting with ``#`` will be ignored. - -As with most regexp matching in Airflow, the regexp engine is ``re2``, which explicitly -doesn't support many advanced features, please check its -`documentation `_ for more information. - -With the ``glob`` syntax, the patterns work just like those in a ``.gitignore`` file: +With the ``glob`` syntax (the default), the patterns work just like those in a ``.gitignore`` file: * The ``*`` character will match any number of characters, except ``/`` * The ``?`` character will match any single character, except ``/`` @@ -738,15 +728,18 @@ With the ``glob`` syntax, the patterns work just like those in a ``.gitignore`` is relative to the directory level of the particular .airflowignore file itself. Otherwise the pattern may also match at any level below the .airflowignore level. -The ``.airflowignore`` file should be put in your ``DAG_FOLDER``. For example, you can prepare -a ``.airflowignore`` file using the ``regexp`` syntax with content - -.. code-block:: +For the ``regexp`` pattern syntax, each line in ``.airflowignore`` +specifies a regular expression pattern, and directories or files whose names (not DAG id) +match any of the patterns would be ignored (under the hood, ``Pattern.search()`` is used +to match the pattern). Use the ``#`` character to indicate a comment; all characters +on lines starting with ``#`` will be ignored. - project_a - tenant_[\d] +As with most regexp matching in Airflow, the regexp engine is ``re2``, which explicitly +doesn't support many advanced features, please check its +`documentation `_ for more information. -Or, equivalently, in the ``glob`` syntax +The ``.airflowignore`` file should be put in your ``DAG_FOLDER``. For example, you can prepare +a ``.airflowignore`` file with the ``glob`` syntax .. code-block:: diff --git a/docs/apache-airflow/howto/dynamic-dag-generation.rst b/docs/apache-airflow/howto/dynamic-dag-generation.rst index 5d542a29320b7..9aa988f28bdb1 100644 --- a/docs/apache-airflow/howto/dynamic-dag-generation.rst +++ b/docs/apache-airflow/howto/dynamic-dag-generation.rst @@ -91,7 +91,7 @@ Then you can import and use the ``ALL_TASKS`` constant in all your DAGs like tha ... Don't forget that in this case you need to add empty ``__init__.py`` file in the ``my_company_utils`` folder -and you should add the ``my_company_utils/.*`` line to ``.airflowignore`` file (if using the regexp ignore +and you should add the ``my_company_utils/*`` line to ``.airflowignore`` file (using the default glob syntax), so that the whole folder is ignored by the scheduler when it looks for DAGs. 
diff --git a/newsfragments/42436.significant.rst b/newsfragments/42436.significant.rst new file mode 100644 index 0000000000000..d9dbcfc4c9f5d --- /dev/null +++ b/newsfragments/42436.significant.rst @@ -0,0 +1,7 @@ +Default ``.airflowignore`` syntax changed to ``glob`` + +The default value to the configuration ``[core] dag_ignore_file_syntax`` has +been changed to ``glob``, which better matches the ignore file behavior of many +popular tools. + +To revert to the previous behavior, set the configuration to ``regexp``. diff --git a/tests/dags/.airflowignore b/tests/dags/.airflowignore index 313b04ef81cd4..7daaf22e65efc 100644 --- a/tests/dags/.airflowignore +++ b/tests/dags/.airflowignore @@ -1,3 +1,2 @@ -.*_invalid.* # Skip invalid files -subdir3 # Skip the nested subdir3 directory -# *badrule # This rule is an invalid regex. It would be warned about and skipped. +*_invalid_* # Skip invalid files +subdir3 # Skip the nested subdir3 directory diff --git a/tests/dags/subdir1/.airflowignore b/tests/dags/subdir1/.airflowignore index 8b69a752e69fb..0bfa43be300a1 100644 --- a/tests/dags/subdir1/.airflowignore +++ b/tests/dags/subdir1/.airflowignore @@ -1 +1 @@ -.*_ignore_this.py # Ignore files ending with "_ignore_this.py" +*_ignore_this.py # Ignore files ending with "_ignore_this.py" diff --git a/tests/plugins/test_plugin_ignore.py b/tests/plugins/test_plugin_ignore.py index d995fabd080f8..92951304d2b9f 100644 --- a/tests/plugins/test_plugin_ignore.py +++ b/tests/plugins/test_plugin_ignore.py @@ -77,7 +77,7 @@ def test_find_not_should_ignore_path_regexp(self, tmp_path): "test_load_sub1.py", } ignore_list_file = ".airflowignore" - for file_path in find_path_from_directory(plugin_folder_path, ignore_list_file): + for file_path in find_path_from_directory(plugin_folder_path, ignore_list_file, "regexp"): file_path = Path(file_path) if file_path.is_file() and file_path.suffix == ".py": detected_files.add(file_path.name) From 673da922182a3e4f2686b2a846715b3bd811c37e Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Mon, 30 Sep 2024 14:30:24 +0900 Subject: [PATCH 215/349] Rename dataset related python variable names to asset (#41348) --- RELEASE_NOTES.rst | 4 +- airflow/__init__.py | 6 +- .../endpoints/dag_run_endpoint.py | 16 +- .../endpoints/dataset_endpoint.py | 190 +++--- .../{dataset_schema.py => asset_schema.py} | 88 +-- airflow/api_connexion/security.py | 8 +- airflow/api_fastapi/openapi/v1-generated.yaml | 8 +- airflow/api_fastapi/views/ui/__init__.py | 4 +- .../views/ui/{datasets.py => assets.py} | 34 +- .../endpoints/rpc_api_endpoint.py | 8 +- airflow/{datasets => assets}/__init__.py | 206 +++--- airflow/{datasets => assets}/manager.py | 195 +++--- airflow/{datasets => assets}/metadata.py | 10 +- airflow/auth/managers/base_auth_manager.py | 10 +- .../auth/managers/models/resource_details.py | 4 +- .../managers/simple/simple_auth_manager.py | 6 +- airflow/config_templates/config.yml | 18 +- airflow/dag_processing/collection.py | 147 +++-- airflow/decorators/base.py | 10 +- airflow/example_dags/example_asset_alias.py | 101 +++ .../example_asset_alias_with_no_taskflow.py | 108 ++++ airflow/example_dags/example_assets.py | 192 ++++++ airflow/example_dags/example_dataset_alias.py | 101 --- .../example_dataset_alias_with_no_taskflow.py | 108 ---- airflow/example_dags/example_datasets.py | 192 ------ .../example_dags/example_inlet_event_extra.py | 22 +- .../example_outlet_event_extra.py | 28 +- airflow/io/path.py | 12 +- airflow/jobs/scheduler_job_runner.py | 90 ++- airflow/lineage/__init__.py | 2 +- 
airflow/lineage/hook.py | 126 ++-- airflow/listeners/listener.py | 4 +- .../listeners/spec/{dataset.py => asset.py} | 18 +- airflow/models/__init__.py | 2 +- airflow/models/{dataset.py => asset.py} | 104 ++-- airflow/models/dag.py | 91 +-- airflow/models/taskinstance.py | 74 +-- airflow/operators/python.py | 6 +- airflow/provider.yaml.schema.json | 30 +- .../aws/{datasets => assets}/__init__.py | 0 .../amazon/aws/{datasets => assets}/s3.py | 16 +- .../amazon/aws/auth_manager/avp/entities.py | 2 +- .../aws/auth_manager/aws_auth_manager.py | 16 +- airflow/providers/amazon/aws/hooks/s3.py | 39 +- .../utils/asset_compat_lineage_collector.py | 106 ++++ airflow/providers/amazon/provider.yaml | 14 +- .../common/compat/assets/__init__.py | 77 +++ .../providers/common/compat/lineage/hook.py | 73 ++- .../openlineage/utils}/__init__.py | 0 .../common/compat/openlineage/utils/utils.py | 43 ++ .../compat/security}/__init__.py | 0 .../common/compat/security/permissions.py | 30 + .../datasets => common/io/assets}/__init__.py | 0 .../common/io/{datasets => assets}/file.py | 15 +- airflow/providers/common/io/provider.yaml | 14 +- .../fab/auth_manager/fab_auth_manager.py | 16 +- .../auth_manager/security_manager/override.py | 16 +- airflow/providers/fab/provider.yaml | 1 + airflow/providers/google/provider.yaml | 8 + .../datasets => mysql/assets}/__init__.py | 0 .../mysql/{datasets => assets}/mysql.py | 0 airflow/providers/mysql/provider.yaml | 8 +- .../openlineage/extractors/manager.py | 27 +- .../utils/asset_compat_lineage_collector.py | 108 ++++ airflow/providers/openlineage/utils/utils.py | 38 +- .../providers/postgres/assets}/__init__.py | 0 .../postgres/{datasets => assets}/postgres.py | 0 airflow/providers/postgres/provider.yaml | 8 +- .../providers/trino/assets}/__init__.py | 0 .../trino/{datasets => assets}/trino.py | 0 airflow/providers/trino/provider.yaml | 8 +- airflow/providers_manager.py | 48 +- airflow/reproducible_build.yaml | 4 +- airflow/security/permissions.py | 2 +- airflow/serialization/dag_dependency.py | 2 +- airflow/serialization/enums.py | 13 +- .../pydantic/{dataset.py => asset.py} | 22 +- airflow/serialization/pydantic/dag_run.py | 4 +- .../serialization/pydantic/taskinstance.py | 4 +- airflow/serialization/serialized_objects.py | 111 ++-- airflow/timetables/{datasets.py => assets.py} | 36 +- airflow/timetables/base.py | 24 +- airflow/timetables/simple.py | 40 +- airflow/ui/openapi-gen/queries/common.ts | 18 +- airflow/ui/openapi-gen/queries/prefetch.ts | 36 +- airflow/ui/openapi-gen/queries/queries.ts | 21 +- airflow/ui/openapi-gen/queries/suspense.ts | 52 +- .../ui/openapi-gen/requests/services.gen.ts | 14 +- airflow/ui/openapi-gen/requests/types.gen.ts | 6 +- airflow/utils/context.py | 102 +-- airflow/utils/context.pyi | 28 +- airflow/utils/operator_helpers.py | 2 +- airflow/www/auth.py | 6 +- airflow/www/security_manager.py | 4 +- airflow/www/static/css/graph.css | 8 +- .../www/static/js/dag/details/graph/Node.tsx | 2 +- .../www/static/js/dag/details/graph/index.tsx | 6 +- .../www/static/js/dag/details/graph/utils.ts | 2 +- airflow/www/static/js/datasets/Graph/Node.tsx | 4 +- .../www/static/js/datasets/Graph/index.tsx | 2 +- airflow/www/static/js/datasets/SearchBar.tsx | 2 +- airflow/www/static/js/types/index.ts | 4 +- airflow/www/templates/airflow/dag.html | 16 +- .../templates/airflow/dag_dependencies.html | 4 +- airflow/www/templates/airflow/dags.html | 22 +- airflow/www/views.py | 102 ++- dev/breeze/tests/test_packages.py | 3 + .../tests/test_pytest_args_for_test_types.py 
| 2 +- dev/breeze/tests/test_selective_checks.py | 26 +- .../auth-manager/manage/index.rst | 12 +- .../auth-manager/access-control.rst | 10 +- ...{dataset-schemes.rst => asset-schemes.rst} | 8 +- .../howto/create-custom-providers.rst | 4 +- .../administration-and-deployment/lineage.rst | 12 +- .../listeners.rst | 8 +- .../logging-monitoring/metrics.rst | 6 +- .../authoring-and-scheduling/assets.rst | 532 ++++++++++++++++ .../authoring-and-scheduling/datasets.rst | 532 ---------------- .../authoring-and-scheduling/index.rst | 2 +- .../authoring-and-scheduling/timetable.rst | 18 +- docs/apache-airflow/core-concepts/dag-run.rst | 4 +- .../apache-airflow/core-concepts/taskflow.rst | 14 +- ...uled-dags.png => asset-scheduled-dags.png} | Bin .../img/{datasets.png => assets.png} | Bin docs/apache-airflow/templates-ref.rst | 12 +- .../apache-airflow/tutorial/objectstorage.rst | 2 +- docs/apache-airflow/ui.rst | 4 +- docs/exts/operators_and_hooks_ref.py | 6 +- ...st.jinja2 => asset-uri-schemes.rst.jinja2} | 2 +- docs/spelling_wordlist.txt | 2 + generated/provider_dependencies.json | 10 +- newsfragments/41348.significant.rst | 240 +++++++ .../check_tests_in_right_folders.py | 2 +- scripts/cov/core_coverage.py | 2 +- scripts/cov/other_coverage.py | 4 +- tests/always/test_project_structure.py | 2 + .../endpoints/test_dag_run_endpoint.py | 22 +- .../endpoints/test_dag_source_endpoint.py | 2 +- .../endpoints/test_dataset_endpoint.py | 170 ++--- .../api_connexion/schemas/test_dag_schema.py | 12 +- .../schemas/test_dataset_schema.py | 90 ++- .../ui/{test_datasets.py => test_assets.py} | 4 +- .../common/io/datasets => assets}/__init__.py | 0 tests/{datasets => assets}/test_manager.py | 104 ++-- tests/assets/tests_asset.py | 586 +++++++++++++++++ .../simple/test_simple_auth_manager.py | 6 +- tests/auth/managers/test_base_auth_manager.py | 6 +- tests/conftest.py | 4 +- .../dags/{test_datasets.py => test_assets.py} | 6 +- tests/dags/test_only_empty_tasks.py | 4 +- tests/datasets/test_dataset.py | 588 ------------------ tests/decorators/test_python.py | 10 +- tests/io/test_path.py | 32 +- tests/io/test_wrapper.py | 14 +- tests/jobs/test_scheduler_job.py | 186 +++--- tests/lineage/test_hook.py | 140 ++--- ...{dataset_listener.py => asset_listener.py} | 14 +- ...set_listener.py => test_asset_listener.py} | 26 +- .../models/{test_dataset.py => test_asset.py} | 12 +- tests/models/test_dag.py | 326 +++++----- tests/models/test_dagrun.py | 2 +- tests/models/test_serialized_dag.py | 14 +- tests/models/test_taskinstance.py | 453 +++++++------- tests/operators/test_python.py | 4 +- .../aws/assets}/__init__.py | 0 .../aws/{datasets => assets}/test_s3.py | 22 +- .../aws/auth_manager/test_aws_auth_manager.py | 53 +- tests/providers/amazon/aws/hooks/test_s3.py | 50 +- .../compat/openlineage/utils}/__init__.py | 0 .../compat/openlineage/utils/test_utils.py | 23 + .../compat/security}/__init__.py | 0 .../compat/security/test_permissions.py | 23 + tests/providers/common/io/assets/__init__.py | 16 + .../io/{datasets => assets}/test_file.py | 16 +- .../fab/auth_manager/test_fab_auth_manager.py | 12 +- .../fab/auth_manager/test_security.py | 14 +- tests/providers/mysql/assets/__init__.py | 16 + .../mysql/{datasets => assets}/test_mysql.py | 2 +- .../openlineage/extractors/test_manager.py | 35 +- tests/providers/postgres/assets/__init__.py | 16 + .../{datasets => assets}/test_postgres.py | 2 +- tests/providers/trino/assets/__init__.py | 16 + .../trino/{datasets => assets}/test_trino.py | 2 +- 
tests/serialization/test_dag_serialization.py | 86 +-- tests/serialization/test_pydantic_models.py | 52 +- tests/serialization/test_serde.py | 6 +- .../serialization/test_serialized_objects.py | 48 +- .../microsoft/azure/example_msfabric.py | 4 +- .../providers/papermill/input_notebook.ipynb | 2 + tests/test_utils/compat.py | 33 + tests/test_utils/db.py | 35 +- ..._timetable.py => test_assets_timetable.py} | 134 ++-- tests/utils/test_context.py | 40 +- tests/utils/test_db_cleanup.py | 6 +- tests/utils/test_json.py | 10 +- tests/www/test_auth.py | 2 +- tests/www/views/test_views_acl.py | 8 +- tests/www/views/test_views_dataset.py | 197 +++--- tests/www/views/test_views_grid.py | 50 +- 199 files changed, 5061 insertions(+), 4127 deletions(-) rename airflow/api_connexion/schemas/{dataset_schema.py => asset_schema.py} (65%) rename airflow/api_fastapi/views/ui/{datasets.py => assets.py} (66%) rename airflow/{datasets => assets}/__init__.py (60%) rename airflow/{datasets => assets}/manager.py (53%) rename airflow/{datasets => assets}/metadata.py (80%) create mode 100644 airflow/example_dags/example_asset_alias.py create mode 100644 airflow/example_dags/example_asset_alias_with_no_taskflow.py create mode 100644 airflow/example_dags/example_assets.py delete mode 100644 airflow/example_dags/example_dataset_alias.py delete mode 100644 airflow/example_dags/example_dataset_alias_with_no_taskflow.py delete mode 100644 airflow/example_dags/example_datasets.py rename airflow/listeners/spec/{dataset.py => asset.py} (76%) rename airflow/models/{dataset.py => asset.py} (83%) rename airflow/providers/amazon/aws/{datasets => assets}/__init__.py (100%) rename airflow/providers/amazon/aws/{datasets => assets}/s3.py (73%) create mode 100644 airflow/providers/amazon/aws/utils/asset_compat_lineage_collector.py create mode 100644 airflow/providers/common/compat/assets/__init__.py rename airflow/providers/common/{io/datasets => compat/openlineage/utils}/__init__.py (100%) create mode 100644 airflow/providers/common/compat/openlineage/utils/utils.py rename airflow/providers/{mysql/datasets => common/compat/security}/__init__.py (100%) create mode 100644 airflow/providers/common/compat/security/permissions.py rename airflow/providers/{postgres/datasets => common/io/assets}/__init__.py (100%) rename airflow/providers/common/io/{datasets => assets}/file.py (76%) rename airflow/providers/{trino/datasets => mysql/assets}/__init__.py (100%) rename airflow/providers/mysql/{datasets => assets}/mysql.py (100%) create mode 100644 airflow/providers/openlineage/utils/asset_compat_lineage_collector.py rename {tests/datasets => airflow/providers/postgres/assets}/__init__.py (100%) rename airflow/providers/postgres/{datasets => assets}/postgres.py (100%) rename {tests/providers/amazon/aws/datasets => airflow/providers/trino/assets}/__init__.py (100%) rename airflow/providers/trino/{datasets => assets}/trino.py (100%) rename airflow/serialization/pydantic/{dataset.py => asset.py} (68%) rename airflow/timetables/{datasets.py => assets.py} (71%) rename docs/apache-airflow-providers/core-extensions/{dataset-schemes.rst => asset-schemes.rst} (82%) create mode 100644 docs/apache-airflow/authoring-and-scheduling/assets.rst delete mode 100644 docs/apache-airflow/authoring-and-scheduling/datasets.rst rename docs/apache-airflow/img/{dataset-scheduled-dags.png => asset-scheduled-dags.png} (100%) rename docs/apache-airflow/img/{datasets.png => assets.png} (100%) rename docs/exts/templates/{dataset-uri-schemes.rst.jinja2 => 
asset-uri-schemes.rst.jinja2} (95%) create mode 100644 newsfragments/41348.significant.rst rename tests/api_fastapi/views/ui/{test_datasets.py => test_assets.py} (92%) rename tests/{providers/common/io/datasets => assets}/__init__.py (100%) rename tests/{datasets => assets}/test_manager.py (52%) create mode 100644 tests/assets/tests_asset.py rename tests/dags/{test_datasets.py => test_assets.py} (91%) delete mode 100644 tests/datasets/test_dataset.py rename tests/listeners/{dataset_listener.py => asset_listener.py} (80%) rename tests/listeners/{test_dataset_listener.py => test_asset_listener.py} (72%) rename tests/models/{test_dataset.py => test_asset.py} (73%) rename tests/providers/{mysql/datasets => amazon/aws/assets}/__init__.py (100%) rename tests/providers/amazon/aws/{datasets => assets}/test_s3.py (75%) rename tests/providers/{postgres/datasets => common/compat/openlineage/utils}/__init__.py (100%) create mode 100644 tests/providers/common/compat/openlineage/utils/test_utils.py rename tests/providers/{trino/datasets => common/compat/security}/__init__.py (100%) create mode 100644 tests/providers/common/compat/security/test_permissions.py create mode 100644 tests/providers/common/io/assets/__init__.py rename tests/providers/common/io/{datasets => assets}/test_file.py (83%) create mode 100644 tests/providers/mysql/assets/__init__.py rename tests/providers/mysql/{datasets => assets}/test_mysql.py (97%) create mode 100644 tests/providers/postgres/assets/__init__.py rename tests/providers/postgres/{datasets => assets}/test_postgres.py (97%) create mode 100644 tests/providers/trino/assets/__init__.py rename tests/providers/trino/{datasets => assets}/test_trino.py (97%) rename tests/timetables/{test_datasets_timetable.py => test_assets_timetable.py} (57%) diff --git a/RELEASE_NOTES.rst b/RELEASE_NOTES.rst index 6c84e45d8aca0..69e666461efae 100644 --- a/RELEASE_NOTES.rst +++ b/RELEASE_NOTES.rst @@ -642,7 +642,7 @@ Dataset URIs are now validated on input (#37005) Datasets must use a URI that conform to rules laid down in AIP-60, and the value will be automatically normalized when the DAG file is parsed. See -`documentation on Datasets `_ for +`documentation on Datasets `_ for a more detailed description on the rules. You may need to change your Dataset identifiers if they look like a URI, but are @@ -3264,7 +3264,7 @@ If you have the producer and consumer in different files you do not need to use Datasets represent the abstract concept of a dataset, and (for now) do not have any direct read or write capability - in this release we are adding the foundational feature that we will build upon. -For more info on Datasets please see :doc:`/authoring-and-scheduling/datasets`. +For more info on Datasets please see `Datasets documentation `_. 
Expanded dynamic task mapping support """"""""""""""""""""""""""""""""""""" diff --git a/airflow/__init__.py b/airflow/__init__.py index 8930f190130a7..18f4cc3e3c28c 100644 --- a/airflow/__init__.py +++ b/airflow/__init__.py @@ -55,7 +55,7 @@ __all__ = [ "__version__", "DAG", - "Dataset", + "Asset", "XComArg", ] @@ -76,7 +76,7 @@ # Things to lazy import in form {local_name: ('target_module', 'target_name', 'deprecated')} __lazy_imports: dict[str, tuple[str, str, bool]] = { "DAG": (".models.dag", "DAG", False), - "Dataset": (".datasets", "Dataset", False), + "Asset": (".assets", "Asset", False), "XComArg": (".models.xcom_arg", "XComArg", False), "version": (".version", "", False), # Deprecated lazy imports @@ -86,8 +86,8 @@ # These objects are imported by PEP-562, however, static analyzers and IDE's # have no idea about typing of these objects. # Add it under TYPE_CHECKING block should help with it. + from airflow.models.asset import Asset from airflow.models.dag import DAG - from airflow.models.dataset import Dataset from airflow.models.xcom_arg import XComArg diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index 02847f0a00e92..02d4663837f4e 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -39,6 +39,10 @@ format_datetime, format_parameters, ) +from airflow.api_connexion.schemas.asset_schema import ( + AssetEventCollection, + asset_event_collection_schema, +) from airflow.api_connexion.schemas.dag_run_schema import ( DAGRunCollection, DAGRunCollectionSchema, @@ -50,10 +54,6 @@ set_dagrun_note_form_schema, set_dagrun_state_form_schema, ) -from airflow.api_connexion.schemas.dataset_schema import ( - DatasetEventCollection, - dataset_event_collection_schema, -) from airflow.api_connexion.schemas.task_instance_schema import ( TaskInstanceReferenceCollection, task_instance_reference_collection_schema, @@ -112,12 +112,12 @@ def get_dag_run( @security.requires_access_dag("GET", DagAccessEntity.RUN) -@security.requires_access_dataset("GET") +@security.requires_access_asset("GET") @provide_session def get_upstream_dataset_events( *, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION ) -> APIResponse: - """If dag run is dataset-triggered, return the dataset events that triggered it.""" + """If dag run is dataset-triggered, return the asset events that triggered it.""" dag_run: DagRun | None = session.scalar( select(DagRun).where( DagRun.dag_id == dag_id, @@ -130,8 +130,8 @@ def get_upstream_dataset_events( detail=f"DAGRun with DAG ID: '{dag_id}' and DagRun ID: '{dag_run_id}' not found", ) events = dag_run.consumed_dataset_events - return dataset_event_collection_schema.dump( - DatasetEventCollection(dataset_events=events, total_entries=len(events)) + return asset_event_collection_schema.dump( + AssetEventCollection(dataset_events=events, total_entries=len(events)) ) diff --git a/airflow/api_connexion/endpoints/dataset_endpoint.py b/airflow/api_connexion/endpoints/dataset_endpoint.py index 1a1578266838c..95c3bead3da52 100644 --- a/airflow/api_connexion/endpoints/dataset_endpoint.py +++ b/airflow/api_connexion/endpoints/dataset_endpoint.py @@ -28,24 +28,24 @@ from airflow.api_connexion.endpoints.request_dict import get_json_request_dict from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.parameters import apply_sorting, check_limit, format_datetime, format_parameters -from 
airflow.api_connexion.schemas.dataset_schema import ( - DagScheduleDatasetReference, - DatasetCollection, - DatasetEventCollection, +from airflow.api_connexion.schemas.asset_schema import ( + AssetCollection, + AssetEventCollection, + DagScheduleAssetReference, QueuedEvent, QueuedEventCollection, - TaskOutletDatasetReference, - create_dataset_event_schema, - dataset_collection_schema, - dataset_event_collection_schema, - dataset_event_schema, - dataset_schema, + TaskOutletAssetReference, + asset_collection_schema, + asset_event_collection_schema, + asset_event_schema, + asset_schema, + create_asset_event_schema, queued_event_collection_schema, queued_event_schema, ) -from airflow.datasets import Dataset -from airflow.datasets.manager import dataset_manager -from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel +from airflow.assets import Asset +from airflow.assets.manager import asset_manager +from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel from airflow.utils import timezone from airflow.utils.db import get_query_count from airflow.utils.session import NEW_SESSION, provide_session @@ -60,24 +60,24 @@ RESOURCE_EVENT_PREFIX = "dataset" -@security.requires_access_dataset("GET") +@security.requires_access_asset("GET") @provide_session def get_dataset(*, uri: str, session: Session = NEW_SESSION) -> APIResponse: - """Get a Dataset.""" - dataset = session.scalar( - select(DatasetModel) - .where(DatasetModel.uri == uri) - .options(joinedload(DatasetModel.consuming_dags), joinedload(DatasetModel.producing_tasks)) + """Get an asset .""" + asset = session.scalar( + select(AssetModel) + .where(AssetModel.uri == uri) + .options(joinedload(AssetModel.consuming_dags), joinedload(AssetModel.producing_tasks)) ) - if not dataset: + if not asset: raise NotFound( - "Dataset not found", - detail=f"The Dataset with uri: `{uri}` was not found", + "Asset not found", + detail=f"The Asset with uri: `{uri}` was not found", ) - return dataset_schema.dump(dataset) + return asset_schema.dump(asset) -@security.requires_access_dataset("GET") +@security.requires_access_asset("GET") @format_parameters({"limit": check_limit}) @provide_session def get_datasets( @@ -89,30 +89,30 @@ def get_datasets( order_by: str = "id", session: Session = NEW_SESSION, ) -> APIResponse: - """Get datasets.""" + """Get assets.""" allowed_attrs = ["id", "uri", "created_at", "updated_at"] - total_entries = session.scalars(select(func.count(DatasetModel.id))).one() - query = select(DatasetModel) + total_entries = session.scalars(select(func.count(AssetModel.id))).one() + query = select(AssetModel) if dag_ids: dags_list = dag_ids.split(",") query = query.filter( - (DatasetModel.consuming_dags.any(DagScheduleDatasetReference.dag_id.in_(dags_list))) - | (DatasetModel.producing_tasks.any(TaskOutletDatasetReference.dag_id.in_(dags_list))) + (AssetModel.consuming_dags.any(DagScheduleAssetReference.dag_id.in_(dags_list))) + | (AssetModel.producing_tasks.any(TaskOutletAssetReference.dag_id.in_(dags_list))) ) if uri_pattern: - query = query.where(DatasetModel.uri.ilike(f"%{uri_pattern}%")) + query = query.where(AssetModel.uri.ilike(f"%{uri_pattern}%")) query = apply_sorting(query, order_by, {}, allowed_attrs) - datasets = session.scalars( - query.options(subqueryload(DatasetModel.consuming_dags), subqueryload(DatasetModel.producing_tasks)) + assets = session.scalars( + query.options(subqueryload(AssetModel.consuming_dags), subqueryload(AssetModel.producing_tasks)) .offset(offset) .limit(limit) ).all() - 
return dataset_collection_schema.dump(DatasetCollection(datasets=datasets, total_entries=total_entries)) + return asset_collection_schema.dump(AssetCollection(datasets=assets, total_entries=total_entries)) -@security.requires_access_dataset("GET") +@security.requires_access_asset("GET") @provide_session @format_parameters({"limit": check_limit}) def get_dataset_events( @@ -127,29 +127,29 @@ def get_dataset_events( source_map_index: int | None = None, session: Session = NEW_SESSION, ) -> APIResponse: - """Get dataset events.""" + """Get asset events.""" allowed_attrs = ["source_dag_id", "source_task_id", "source_run_id", "source_map_index", "timestamp"] - query = select(DatasetEvent) + query = select(AssetEvent) if dataset_id: - query = query.where(DatasetEvent.dataset_id == dataset_id) + query = query.where(AssetEvent.dataset_id == dataset_id) if source_dag_id: - query = query.where(DatasetEvent.source_dag_id == source_dag_id) + query = query.where(AssetEvent.source_dag_id == source_dag_id) if source_task_id: - query = query.where(DatasetEvent.source_task_id == source_task_id) + query = query.where(AssetEvent.source_task_id == source_task_id) if source_run_id: - query = query.where(DatasetEvent.source_run_id == source_run_id) + query = query.where(AssetEvent.source_run_id == source_run_id) if source_map_index: - query = query.where(DatasetEvent.source_map_index == source_map_index) + query = query.where(AssetEvent.source_map_index == source_map_index) - query = query.options(subqueryload(DatasetEvent.created_dagruns)) + query = query.options(subqueryload(AssetEvent.created_dagruns)) total_entries = get_query_count(query, session=session) query = apply_sorting(query, order_by, {}, allowed_attrs) events = session.scalars(query.offset(offset).limit(limit)).all() - return dataset_event_collection_schema.dump( - DatasetEventCollection(dataset_events=events, total_entries=total_entries) + return asset_event_collection_schema.dump( + AssetEventCollection(dataset_events=events, total_entries=total_entries) ) @@ -161,79 +161,77 @@ def _generate_queued_event_where_clause( before: str | None = None, permitted_dag_ids: set[str] | None = None, ) -> list: - """Get DatasetDagRunQueue where clause.""" + """Get AssetDagRunQueue where clause.""" where_clause = [] if dag_id is not None: - where_clause.append(DatasetDagRunQueue.target_dag_id == dag_id) + where_clause.append(AssetDagRunQueue.target_dag_id == dag_id) if dataset_id is not None: - where_clause.append(DatasetDagRunQueue.dataset_id == dataset_id) + where_clause.append(AssetDagRunQueue.dataset_id == dataset_id) if uri is not None: where_clause.append( - DatasetDagRunQueue.dataset_id.in_( - select(DatasetModel.id).where(DatasetModel.uri == uri), + AssetDagRunQueue.dataset_id.in_( + select(AssetModel.id).where(AssetModel.uri == uri), ), ) if before is not None: - where_clause.append(DatasetDagRunQueue.created_at < format_datetime(before)) + where_clause.append(AssetDagRunQueue.created_at < format_datetime(before)) if permitted_dag_ids is not None: - where_clause.append(DatasetDagRunQueue.target_dag_id.in_(permitted_dag_ids)) + where_clause.append(AssetDagRunQueue.target_dag_id.in_(permitted_dag_ids)) return where_clause -@security.requires_access_dataset("GET") +@security.requires_access_asset("GET") @security.requires_access_dag("GET") @provide_session def get_dag_dataset_queued_event( *, dag_id: str, uri: str, before: str | None = None, session: Session = NEW_SESSION ) -> APIResponse: - """Get a queued Dataset event for a DAG.""" + """Get a queued 
asset event for a DAG.""" where_clause = _generate_queued_event_where_clause(dag_id=dag_id, uri=uri, before=before) - ddrq = session.scalar( - select(DatasetDagRunQueue) - .join(DatasetModel, DatasetDagRunQueue.dataset_id == DatasetModel.id) + adrq = session.scalar( + select(AssetDagRunQueue) + .join(AssetModel, AssetDagRunQueue.dataset_id == AssetModel.id) .where(*where_clause) ) - if ddrq is None: + if adrq is None: raise NotFound( "Queue event not found", - detail=f"Queue event with dag_id: `{dag_id}` and dataset uri: `{uri}` was not found", + detail=f"Queue event with dag_id: `{dag_id}` and asset uri: `{uri}` was not found", ) - queued_event = {"created_at": ddrq.created_at, "dag_id": dag_id, "uri": uri} + queued_event = {"created_at": adrq.created_at, "dag_id": dag_id, "uri": uri} return queued_event_schema.dump(queued_event) -@security.requires_access_dataset("DELETE") +@security.requires_access_asset("DELETE") @security.requires_access_dag("GET") @provide_session @action_logging def delete_dag_dataset_queued_event( *, dag_id: str, uri: str, before: str | None = None, session: Session = NEW_SESSION ) -> APIResponse: - """Delete a queued Dataset event for a DAG.""" + """Delete a queued asset event for a DAG.""" where_clause = _generate_queued_event_where_clause(dag_id=dag_id, uri=uri, before=before) - delete_stmt = ( - delete(DatasetDagRunQueue).where(*where_clause).execution_options(synchronize_session="fetch") - ) + delete_stmt = delete(AssetDagRunQueue).where(*where_clause).execution_options(synchronize_session="fetch") result = session.execute(delete_stmt) if result.rowcount > 0: return NoContent, HTTPStatus.NO_CONTENT raise NotFound( "Queue event not found", - detail=f"Queue event with dag_id: `{dag_id}` and dataset uri: `{uri}` was not found", + detail=f"Queue event with dag_id: `{dag_id}` and asset uri: `{uri}` was not found", ) -@security.requires_access_dataset("GET") +@security.requires_access_asset("GET") @security.requires_access_dag("GET") @provide_session def get_dag_dataset_queued_events( *, dag_id: str, before: str | None = None, session: Session = NEW_SESSION ) -> APIResponse: - """Get queued Dataset events for a DAG.""" + """Get queued asset events for a DAG.""" where_clause = _generate_queued_event_where_clause(dag_id=dag_id, before=before) query = ( - select(DatasetDagRunQueue, DatasetModel.uri) - .join(DatasetModel, DatasetDagRunQueue.dataset_id == DatasetModel.id) + select(AssetDagRunQueue, AssetModel.uri) + .join(AssetModel, AssetDagRunQueue.dataset_id == AssetModel.id) .where(*where_clause) ) result = session.execute(query).all() @@ -244,23 +242,23 @@ def get_dag_dataset_queued_events( detail=f"Queue event with dag_id: `{dag_id}` was not found", ) queued_events = [ - QueuedEvent(created_at=ddrq.created_at, dag_id=ddrq.target_dag_id, uri=uri) for ddrq, uri in result + QueuedEvent(created_at=adrq.created_at, dag_id=adrq.target_dag_id, uri=uri) for adrq, uri in result ] return queued_event_collection_schema.dump( QueuedEventCollection(queued_events=queued_events, total_entries=total_entries) ) -@security.requires_access_dataset("DELETE") +@security.requires_access_asset("DELETE") @security.requires_access_dag("GET") @action_logging @provide_session def delete_dag_dataset_queued_events( *, dag_id: str, before: str | None = None, session: Session = NEW_SESSION ) -> APIResponse: - """Delete queued Dataset events for a DAG.""" + """Delete queued asset events for a DAG.""" where_clause = _generate_queued_event_where_clause(dag_id=dag_id, before=before) - delete_stmt = 
delete(DatasetDagRunQueue).where(*where_clause) + delete_stmt = delete(AssetDagRunQueue).where(*where_clause) result = session.execute(delete_stmt) if result.rowcount > 0: return NoContent, HTTPStatus.NO_CONTENT @@ -271,87 +269,85 @@ def delete_dag_dataset_queued_events( ) -@security.requires_access_dataset("GET") +@security.requires_access_asset("GET") @provide_session def get_dataset_queued_events( *, uri: str, before: str | None = None, session: Session = NEW_SESSION ) -> APIResponse: - """Get queued Dataset events for a Dataset.""" + """Get queued asset events for an asset.""" permitted_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"]) where_clause = _generate_queued_event_where_clause( uri=uri, before=before, permitted_dag_ids=permitted_dag_ids ) query = ( - select(DatasetDagRunQueue, DatasetModel.uri) - .join(DatasetModel, DatasetDagRunQueue.dataset_id == DatasetModel.id) + select(AssetDagRunQueue, AssetModel.uri) + .join(AssetModel, AssetDagRunQueue.dataset_id == AssetModel.id) .where(*where_clause) ) total_entries = get_query_count(query, session=session) result = session.execute(query).all() if total_entries > 0: queued_events = [ - QueuedEvent(created_at=ddrq.created_at, dag_id=ddrq.target_dag_id, uri=uri) - for ddrq, uri in result + QueuedEvent(created_at=adrq.created_at, dag_id=adrq.target_dag_id, uri=uri) + for adrq, uri in result ] return queued_event_collection_schema.dump( QueuedEventCollection(queued_events=queued_events, total_entries=total_entries) ) raise NotFound( "Queue event not found", - detail=f"Queue event with dataset uri: `{uri}` was not found", + detail=f"Queue event with asset uri: `{uri}` was not found", ) -@security.requires_access_dataset("DELETE") +@security.requires_access_asset("DELETE") @action_logging @provide_session def delete_dataset_queued_events( *, uri: str, before: str | None = None, session: Session = NEW_SESSION ) -> APIResponse: - """Delete queued Dataset events for a Dataset.""" + """Delete queued asset events for an asset.""" permitted_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"]) where_clause = _generate_queued_event_where_clause( uri=uri, before=before, permitted_dag_ids=permitted_dag_ids ) - delete_stmt = ( - delete(DatasetDagRunQueue).where(*where_clause).execution_options(synchronize_session="fetch") - ) + delete_stmt = delete(AssetDagRunQueue).where(*where_clause).execution_options(synchronize_session="fetch") result = session.execute(delete_stmt) if result.rowcount > 0: return NoContent, HTTPStatus.NO_CONTENT raise NotFound( "Queue event not found", - detail=f"Queue event with dataset uri: `{uri}` was not found", + detail=f"Queue event with asset uri: `{uri}` was not found", ) -@security.requires_access_dataset("POST") +@security.requires_access_asset("POST") @provide_session @action_logging def create_dataset_event(session: Session = NEW_SESSION) -> APIResponse: - """Create dataset event.""" + """Create asset event.""" body = get_json_request_dict() try: - json_body = create_dataset_event_schema.load(body) + json_body = create_asset_event_schema.load(body) except ValidationError as err: raise BadRequest(detail=str(err)) uri = json_body["dataset_uri"] - dataset = session.scalar(select(DatasetModel).where(DatasetModel.uri == uri).limit(1)) - if not dataset: - raise NotFound(title="Dataset not found", detail=f"Dataset with uri: '{uri}' not found") + asset = session.scalar(select(AssetModel).where(AssetModel.uri == uri).limit(1)) + if not asset: + raise NotFound(title="Asset not found", 
detail=f"Asset with uri: '{uri}' not found") timestamp = timezone.utcnow() extra = json_body.get("extra", {}) extra["from_rest_api"] = True - dataset_event = dataset_manager.register_dataset_change( - dataset=Dataset(uri), + asset_event = asset_manager.register_asset_change( + asset=Asset(uri), timestamp=timestamp, extra=extra, session=session, ) - if not dataset_event: - raise NotFound(title="Dataset not found", detail=f"Dataset with uri: '{uri}' not found") + if not asset_event: + raise NotFound(title="Asset not found", detail=f"Asset with uri: '{uri}' not found") session.flush() # So we can dump the timestamp. - event = dataset_event_schema.dump(dataset_event) + event = asset_event_schema.dump(asset_event) return event diff --git a/airflow/api_connexion/schemas/dataset_schema.py b/airflow/api_connexion/schemas/asset_schema.py similarity index 65% rename from airflow/api_connexion/schemas/dataset_schema.py rename to airflow/api_connexion/schemas/asset_schema.py index b8aaf2f8fa30e..791941f42016d 100644 --- a/airflow/api_connexion/schemas/dataset_schema.py +++ b/airflow/api_connexion/schemas/asset_schema.py @@ -23,23 +23,23 @@ from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field from airflow.api_connexion.schemas.common_schema import JsonObjectField -from airflow.models.dagrun import DagRun -from airflow.models.dataset import ( - DagScheduleDatasetReference, - DatasetAliasModel, - DatasetEvent, - DatasetModel, - TaskOutletDatasetReference, +from airflow.models.asset import ( + AssetAliasModel, + AssetEvent, + AssetModel, + DagScheduleAssetReference, + TaskOutletAssetReference, ) +from airflow.models.dagrun import DagRun -class TaskOutletDatasetReferenceSchema(SQLAlchemySchema): - """TaskOutletDatasetReference DB schema.""" +class TaskOutletAssetReferenceSchema(SQLAlchemySchema): + """TaskOutletAssetReference DB schema.""" class Meta: """Meta.""" - model = TaskOutletDatasetReference + model = TaskOutletAssetReference dag_id = auto_field() task_id = auto_field() @@ -47,65 +47,65 @@ class Meta: updated_at = auto_field() -class DagScheduleDatasetReferenceSchema(SQLAlchemySchema): - """DagScheduleDatasetReference DB schema.""" +class DagScheduleAssetReferenceSchema(SQLAlchemySchema): + """DagScheduleAssetReference DB schema.""" class Meta: """Meta.""" - model = DagScheduleDatasetReference + model = DagScheduleAssetReference dag_id = auto_field() created_at = auto_field() updated_at = auto_field() -class DatasetAliasSchema(SQLAlchemySchema): - """DatasetAlias DB schema.""" +class AssetAliasSchema(SQLAlchemySchema): + """AssetAlias DB schema.""" class Meta: """Meta.""" - model = DatasetAliasModel + model = AssetAliasModel id = auto_field() name = auto_field() -class DatasetSchema(SQLAlchemySchema): - """Dataset DB schema.""" +class AssetSchema(SQLAlchemySchema): + """Asset DB schema.""" class Meta: """Meta.""" - model = DatasetModel + model = AssetModel id = auto_field() uri = auto_field() extra = JsonObjectField() created_at = auto_field() updated_at = auto_field() - producing_tasks = fields.List(fields.Nested(TaskOutletDatasetReferenceSchema)) - consuming_dags = fields.List(fields.Nested(DagScheduleDatasetReferenceSchema)) - aliases = fields.List(fields.Nested(DatasetAliasSchema)) + producing_tasks = fields.List(fields.Nested(TaskOutletAssetReferenceSchema)) + consuming_dags = fields.List(fields.Nested(DagScheduleAssetReferenceSchema)) + aliases = fields.List(fields.Nested(AssetAliasSchema)) -class DatasetCollection(NamedTuple): - """List of Datasets with meta.""" +class 
AssetCollection(NamedTuple): + """List of Assets with meta.""" - datasets: list[DatasetModel] + datasets: list[AssetModel] total_entries: int -class DatasetCollectionSchema(Schema): - """Dataset Collection Schema.""" +class AssetCollectionSchema(Schema): + """Asset Collection Schema.""" - datasets = fields.List(fields.Nested(DatasetSchema)) + datasets = fields.List(fields.Nested(AssetSchema)) total_entries = fields.Int() -dataset_schema = DatasetSchema() -dataset_collection_schema = DatasetCollectionSchema() +asset_schema = AssetSchema() +asset_collection_schema = AssetCollectionSchema() class BasicDAGRunSchema(SQLAlchemySchema): @@ -127,13 +127,13 @@ class Meta: data_interval_end = auto_field(dump_only=True) -class DatasetEventSchema(SQLAlchemySchema): - """Dataset Event DB schema.""" +class AssetEventSchema(SQLAlchemySchema): + """Asset Event DB schema.""" class Meta: """Meta.""" - model = DatasetEvent + model = AssetEvent id = auto_field() dataset_id = auto_field() @@ -147,30 +147,30 @@ class Meta: timestamp = auto_field() -class DatasetEventCollection(NamedTuple): - """List of Dataset events with meta.""" +class AssetEventCollection(NamedTuple): + """List of Asset events with meta.""" - dataset_events: list[DatasetEvent] + dataset_events: list[AssetEvent] total_entries: int -class DatasetEventCollectionSchema(Schema): - """Dataset Event Collection Schema.""" +class AssetEventCollectionSchema(Schema): + """Asset Event Collection Schema.""" - dataset_events = fields.List(fields.Nested(DatasetEventSchema)) + dataset_events = fields.List(fields.Nested(AssetEventSchema)) total_entries = fields.Int() -class CreateDatasetEventSchema(Schema): - """Create Dataset Event Schema.""" +class CreateAssetEventSchema(Schema): + """Create Asset Event Schema.""" dataset_uri = fields.String() extra = JsonObjectField() -dataset_event_schema = DatasetEventSchema() -dataset_event_collection_schema = DatasetEventCollectionSchema() -create_dataset_event_schema = CreateDatasetEventSchema() +asset_event_schema = AssetEventSchema() +asset_event_collection_schema = AssetEventCollectionSchema() +create_asset_event_schema = CreateAssetEventSchema() class QueuedEvent(NamedTuple): diff --git a/airflow/api_connexion/security.py b/airflow/api_connexion/security.py index 445ded913e56a..1098de3a1f474 100644 --- a/airflow/api_connexion/security.py +++ b/airflow/api_connexion/security.py @@ -24,11 +24,11 @@ from airflow.api_connexion.exceptions import PermissionDenied, Unauthenticated from airflow.auth.managers.models.resource_details import ( AccessView, + AssetDetails, ConfigurationDetails, ConnectionDetails, DagAccessEntity, DagDetails, - DatasetDetails, PoolDetails, VariableDetails, ) @@ -158,14 +158,14 @@ def decorated(*args, **kwargs): return requires_access_decorator -def requires_access_dataset(method: ResourceMethod) -> Callable[[T], T]: +def requires_access_asset(method: ResourceMethod) -> Callable[[T], T]: def requires_access_decorator(func: T): @wraps(func) def decorated(*args, **kwargs): uri: str | None = kwargs.get("uri") return _requires_access( - is_authorized_callback=lambda: get_auth_manager().is_authorized_dataset( - method=method, details=DatasetDetails(uri=uri) + is_authorized_callback=lambda: get_auth_manager().is_authorized_asset( + method=method, details=AssetDetails(uri=uri) ), func=func, args=args, diff --git a/airflow/api_fastapi/openapi/v1-generated.yaml b/airflow/api_fastapi/openapi/v1-generated.yaml index 64e475aeb6baa..c130f3162c6e6 100644 --- a/airflow/api_fastapi/openapi/v1-generated.yaml 
+++ b/airflow/api_fastapi/openapi/v1-generated.yaml @@ -10,9 +10,9 @@ paths: /ui/next_run_datasets/{dag_id}: get: tags: - - Dataset - summary: Next Run Datasets - operationId: next_run_datasets_ui_next_run_datasets__dag_id__get + - Asset + summary: Next Run Assets + operationId: next_run_assets_ui_next_run_datasets__dag_id__get parameters: - name: dag_id in: path @@ -27,7 +27,7 @@ paths: application/json: schema: type: object - title: Response Next Run Datasets Ui Next Run Datasets Dag Id Get + title: Response Next Run Assets Ui Next Run Datasets Dag Id Get '422': description: Validation Error content: diff --git a/airflow/api_fastapi/views/ui/__init__.py b/airflow/api_fastapi/views/ui/__init__.py index 2d95e040403a7..edba930c3d1d1 100644 --- a/airflow/api_fastapi/views/ui/__init__.py +++ b/airflow/api_fastapi/views/ui/__init__.py @@ -18,8 +18,8 @@ from fastapi import APIRouter -from airflow.api_fastapi.views.ui.datasets import datasets_router +from airflow.api_fastapi.views.ui.assets import assets_router ui_router = APIRouter(prefix="/ui") -ui_router.include_router(datasets_router) +ui_router.include_router(assets_router) diff --git a/airflow/api_fastapi/views/ui/datasets.py b/airflow/api_fastapi/views/ui/assets.py similarity index 66% rename from airflow/api_fastapi/views/ui/datasets.py rename to airflow/api_fastapi/views/ui/assets.py index f5dd2cacb126d..458d531facf6a 100644 --- a/airflow/api_fastapi/views/ui/datasets.py +++ b/airflow/api_fastapi/views/ui/assets.py @@ -24,13 +24,13 @@ from airflow.api_fastapi.db import get_session from airflow.models import DagModel -from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel +from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel, DagScheduleAssetReference -datasets_router = APIRouter(tags=["Dataset"]) +assets_router = APIRouter(tags=["Asset"]) -@datasets_router.get("/next_run_datasets/{dag_id}", include_in_schema=False) -async def next_run_datasets( +@assets_router.get("/next_run_datasets/{dag_id}", include_in_schema=False) +async def next_run_assets( dag_id: str, request: Request, session: Annotated[Session, Depends(get_session)], @@ -51,34 +51,34 @@ async def next_run_datasets( dict(info._mapping) for info in session.execute( select( - DatasetModel.id, - DatasetModel.uri, - func.max(DatasetEvent.timestamp).label("lastUpdate"), + AssetModel.id, + AssetModel.uri, + func.max(AssetEvent.timestamp).label("lastUpdate"), ) - .join(DagScheduleDatasetReference, DagScheduleDatasetReference.dataset_id == DatasetModel.id) + .join(DagScheduleAssetReference, DagScheduleAssetReference.dataset_id == AssetModel.id) .join( - DatasetDagRunQueue, + AssetDagRunQueue, and_( - DatasetDagRunQueue.dataset_id == DatasetModel.id, - DatasetDagRunQueue.target_dag_id == DagScheduleDatasetReference.dag_id, + AssetDagRunQueue.dataset_id == AssetModel.id, + AssetDagRunQueue.target_dag_id == DagScheduleAssetReference.dag_id, ), isouter=True, ) .join( - DatasetEvent, + AssetEvent, and_( - DatasetEvent.dataset_id == DatasetModel.id, + AssetEvent.dataset_id == AssetModel.id, ( - DatasetEvent.timestamp >= latest_run.execution_date + AssetEvent.timestamp >= latest_run.execution_date if latest_run and latest_run.execution_date else True ), ), isouter=True, ) - .where(DagScheduleDatasetReference.dag_id == dag_id, ~DatasetModel.is_orphaned) - .group_by(DatasetModel.id, DatasetModel.uri) - .order_by(DatasetModel.uri) + .where(DagScheduleAssetReference.dag_id == dag_id, ~AssetModel.is_orphaned) + 
.group_by(AssetModel.id, AssetModel.uri) + .order_by(AssetModel.uri) ) ] data = {"dataset_expression": dag_model.dataset_expression, "events": events} diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index 8716d9c9cc49d..1cdd2536e1354 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -54,11 +54,11 @@ @functools.lru_cache def initialize_method_map() -> dict[str, Callable]: from airflow.api.common.trigger_dag import trigger_dag + from airflow.assets import expand_alias_to_assets + from airflow.assets.manager import AssetManager from airflow.cli.commands.task_command import _get_ti_db_access from airflow.dag_processing.manager import DagFileProcessorManager from airflow.dag_processing.processor import DagFileProcessor - from airflow.datasets import expand_alias_to_datasets - from airflow.datasets.manager import DatasetManager from airflow.models import Trigger, Variable, XCom from airflow.models.dag import DAG, DagModel from airflow.models.dagrun import DagRun @@ -109,8 +109,8 @@ def initialize_method_map() -> dict[str, Callable]: DagFileProcessorManager.clear_nonexistent_import_errors, DagFileProcessorManager.deactivate_stale_dags, DagWarning.purge_inactive_dag_warnings, - expand_alias_to_datasets, - DatasetManager.register_dataset_change, + expand_alias_to_assets, + AssetManager.register_asset_change, FileTaskHandler._render_filename_db_access, Job._add_to_db, Job._fetch_from_db, diff --git a/airflow/datasets/__init__.py b/airflow/assets/__init__.py similarity index 60% rename from airflow/datasets/__init__.py rename to airflow/assets/__init__.py index 6f7ae99ff7417..9727e408edc2e 100644 --- a/airflow/datasets/__init__.py +++ b/airflow/assets/__init__.py @@ -38,7 +38,7 @@ from airflow.configuration import conf -__all__ = ["Dataset", "DatasetAll", "DatasetAny"] +__all__ = ["Asset", "AssetAll", "AssetAny"] def normalize_noop(parts: SplitResult) -> SplitResult: @@ -55,7 +55,7 @@ def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | N return normalize_noop from airflow.providers_manager import ProvidersManager - return ProvidersManager().dataset_uri_handlers.get(scheme) + return ProvidersManager().asset_uri_handlers.get(scheme) def _get_normalized_scheme(uri: str) -> str: @@ -65,17 +65,17 @@ def _get_normalized_scheme(uri: str) -> str: def _sanitize_uri(uri: str) -> str: """ - Sanitize a dataset URI. + Sanitize an asset URI. This checks for URI validity, and normalizes the URI if needed. A fully normalized URI is returned. """ if not uri: - raise ValueError("Dataset URI cannot be empty") + raise ValueError("Asset URI cannot be empty") if uri.isspace(): - raise ValueError("Dataset URI cannot be just whitespace") + raise ValueError("Asset URI cannot be just whitespace") if not uri.isascii(): - raise ValueError("Dataset URI must only consist of ASCII characters") + raise ValueError("Asset URI must only consist of ASCII characters") parsed = urllib.parse.urlsplit(uri) if not parsed.scheme and not parsed.netloc: # Does not look like a URI. return uri @@ -84,12 +84,12 @@ def _sanitize_uri(uri: str) -> str: if normalized_scheme.startswith("x-"): return uri if normalized_scheme == "airflow": - raise ValueError("Dataset scheme 'airflow' is reserved") + raise ValueError("Asset scheme 'airflow' is reserved") _, auth_exists, normalized_netloc = parsed.netloc.rpartition("@") if auth_exists: # TODO: Collect this into a DagWarning. 
warnings.warn( - "A dataset URI should not contain auth info (e.g. username or " + "An Asset URI should not contain auth info (e.g. username or " "password). It has been automatically dropped.", UserWarning, stacklevel=3, @@ -109,10 +109,10 @@ def _sanitize_uri(uri: str) -> str: try: parsed = normalizer(parsed) except ValueError as exception: - if conf.getboolean("core", "strict_dataset_uri_validation", fallback=False): + if conf.getboolean("core", "strict_asset_uri_validation", fallback=False): raise warnings.warn( - f"The dataset URI {uri} is not AIP-60 compliant: {exception}. " + f"The Asset URI {uri} is not AIP-60 compliant: {exception}. " f"In Airflow 3, this will raise an exception.", UserWarning, stacklevel=3, @@ -120,46 +120,44 @@ def _sanitize_uri(uri: str) -> str: return urllib.parse.urlunsplit(parsed) -def extract_event_key(value: str | Dataset | DatasetAlias) -> str: +def extract_event_key(value: str | Asset | AssetAlias) -> str: """ Extract the key of an inlet or an outlet event. If the input value is a string, it is treated as a URI and sanitized. If the - input is a :class:`Dataset`, the URI it contains is considered sanitized and - returned directly. If the input is a :class:`DatasetAlias`, the name it contains + input is a :class:`Asset`, the URI it contains is considered sanitized and + returned directly. If the input is a :class:`AssetAlias`, the name it contains will be returned directly. :meta private: """ - if isinstance(value, DatasetAlias): + if isinstance(value, AssetAlias): return value.name - if isinstance(value, Dataset): + if isinstance(value, Asset): return value.uri return _sanitize_uri(str(value)) @internal_api_call @provide_session -def expand_alias_to_datasets( - alias: str | DatasetAlias, *, session: Session = NEW_SESSION -) -> list[BaseDataset]: - """Expand dataset alias to resolved datasets.""" - from airflow.models.dataset import DatasetAliasModel +def expand_alias_to_assets(alias: str | AssetAlias, *, session: Session = NEW_SESSION) -> list[BaseAsset]: + """Expand asset alias to resolved assets.""" + from airflow.models.asset import AssetAliasModel - alias_name = alias.name if isinstance(alias, DatasetAlias) else alias + alias_name = alias.name if isinstance(alias, AssetAlias) else alias - dataset_alias_obj = session.scalar( - select(DatasetAliasModel).where(DatasetAliasModel.name == alias_name).limit(1) + asset_alias_obj = session.scalar( + select(AssetAliasModel).where(AssetAliasModel.name == alias_name).limit(1) ) - if dataset_alias_obj: - return [Dataset(uri=dataset.uri, extra=dataset.extra) for dataset in dataset_alias_obj.datasets] + if asset_alias_obj: + return [Asset(uri=asset.uri, extra=asset.extra) for asset in asset_alias_obj.datasets] return [] -class BaseDataset: +class BaseAsset: """ - Protocol for all dataset triggers to use in ``DAG(schedule=...)``. + Protocol for all asset triggers to use in ``DAG(schedule=...)``. 
:meta private: """ @@ -167,19 +165,19 @@ class BaseDataset: def __bool__(self) -> bool: return True - def __or__(self, other: BaseDataset) -> BaseDataset: - if not isinstance(other, BaseDataset): + def __or__(self, other: BaseAsset) -> BaseAsset: + if not isinstance(other, BaseAsset): return NotImplemented - return DatasetAny(self, other) + return AssetAny(self, other) - def __and__(self, other: BaseDataset) -> BaseDataset: - if not isinstance(other, BaseDataset): + def __and__(self, other: BaseAsset) -> BaseAsset: + if not isinstance(other, BaseAsset): return NotImplemented - return DatasetAll(self, other) + return AssetAll(self, other) def as_expression(self) -> Any: """ - Serialize the dataset into its scheduling expression. + Serialize the asset into its scheduling expression. The return value is stored in DagModel for display purposes. It must be JSON-compatible. @@ -191,15 +189,15 @@ def as_expression(self) -> Any: def evaluate(self, statuses: dict[str, bool]) -> bool: raise NotImplementedError - def iter_datasets(self) -> Iterator[tuple[str, Dataset]]: + def iter_assets(self) -> Iterator[tuple[str, Asset]]: raise NotImplementedError - def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]: + def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: raise NotImplementedError def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: """ - Iterate a base dataset as dag dependency. + Iterate a base asset as dag dependency. :meta private: """ @@ -207,36 +205,36 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe @attr.define(unsafe_hash=False) -class DatasetAlias(BaseDataset): - """A represeation of dataset alias which is used to create dataset during the runtime.""" +class AssetAlias(BaseAsset): + """A representation of asset alias which is used to create asset during the runtime.""" name: str - def iter_datasets(self) -> Iterator[tuple[str, Dataset]]: + def iter_assets(self) -> Iterator[tuple[str, Asset]]: return iter(()) - def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]: + def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: yield self.name, self def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: """ - Iterate a dataset alias as dag dependency. + Iterate an asset alias as dag dependency. :meta private: """ yield DagDependency( - source=source or "dataset-alias", - target=target or "dataset-alias", - dependency_type="dataset-alias", + source=source or "asset-alias", + target=target or "asset-alias", + dependency_type="asset-alias", dependency_id=self.name, ) -class DatasetAliasEvent(TypedDict): - """A represeation of dataset event to be triggered by a dataset alias.""" +class AssetAliasEvent(TypedDict): + """A representation of asset event to be triggered by an asset alias.""" source_alias_name: str - dest_dataset_uri: str + dest_asset_uri: str extra: dict[str, Any] @@ -244,7 +242,7 @@ def _set_extra_default(extra: dict | None) -> dict: """ Automatically convert None to an empty dict. - This allows the caller site to continue doing ``Dataset(uri, extra=None)``, + This allows the caller site to continue doing ``Asset(uri, extra=None)``, but still allow the ``extra`` attribute to always be a dict.
""" if extra is None: @@ -253,7 +251,7 @@ def _set_extra_default(extra: dict | None) -> dict: @attr.define(unsafe_hash=False) -class Dataset(os.PathLike, BaseDataset): +class Asset(os.PathLike, BaseAsset): """A representation of data dependencies between workflows.""" uri: str = attr.field( @@ -291,16 +289,16 @@ def normalized_uri(self) -> str | None: def as_expression(self) -> Any: """ - Serialize the dataset into its scheduling expression. + Serialize the asset into its scheduling expression. :meta private: """ return self.uri - def iter_datasets(self) -> Iterator[tuple[str, Dataset]]: + def iter_assets(self) -> Iterator[tuple[str, Asset]]: yield self.uri, self - def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]: + def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: return iter(()) def evaluate(self, statuses: dict[str, bool]) -> bool: @@ -308,51 +306,51 @@ def evaluate(self, statuses: dict[str, bool]) -> bool: def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: """ - Iterate a dataset as dag dependency. + Iterate an asset as dag dependency. :meta private: """ yield DagDependency( - source=source or "dataset", - target=target or "dataset", - dependency_type="dataset", + source=source or "asset", + target=target or "asset", + dependency_type="asset", dependency_id=self.uri, ) -class _DatasetBooleanCondition(BaseDataset): - """Base class for dataset boolean logic.""" +class _AssetBooleanCondition(BaseAsset): + """Base class for asset boolean logic.""" agg_func: Callable[[Iterable], bool] - def __init__(self, *objects: BaseDataset) -> None: - if not all(isinstance(o, BaseDataset) for o in objects): - raise TypeError("expect dataset expressions in condition") + def __init__(self, *objects: BaseAsset) -> None: + if not all(isinstance(o, BaseAsset) for o in objects): + raise TypeError("expect asset expressions in condition") self.objects = [ - _DatasetAliasCondition(obj.name) if isinstance(obj, DatasetAlias) else obj for obj in objects + _AssetAliasCondition(obj.name) if isinstance(obj, AssetAlias) else obj for obj in objects ] def evaluate(self, statuses: dict[str, bool]) -> bool: return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects) - def iter_datasets(self) -> Iterator[tuple[str, Dataset]]: + def iter_assets(self) -> Iterator[tuple[str, Asset]]: seen = set() # We want to keep the first instance. for o in self.objects: - for k, v in o.iter_datasets(): + for k, v in o.iter_assets(): if k in seen: continue yield k, v seen.add(k) - def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]: - """Filter dataest aliases in the condition.""" + def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: + """Filter asset aliases in the condition.""" for o in self.objects: - yield from o.iter_dataset_aliases() + yield from o.iter_asset_aliases() def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: """ - Iterate dataset, dataset aliases and their resolved datasets as dag dependency. + Iterate asset, asset aliases and their resolved assets as dag dependency. 
:meta private: """ @@ -360,104 +358,104 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe yield from obj.iter_dag_dependencies(source=source, target=target) -class DatasetAny(_DatasetBooleanCondition): - """Use to combine datasets schedule references in an "and" relationship.""" +class AssetAny(_AssetBooleanCondition): + """Use to combine assets schedule references in an "and" relationship.""" agg_func = any - def __or__(self, other: BaseDataset) -> BaseDataset: - if not isinstance(other, BaseDataset): + def __or__(self, other: BaseAsset) -> BaseAsset: + if not isinstance(other, BaseAsset): return NotImplemented # Optimization: X | (Y | Z) is equivalent to X | Y | Z. - return DatasetAny(*self.objects, other) + return AssetAny(*self.objects, other) def __repr__(self) -> str: - return f"DatasetAny({', '.join(map(str, self.objects))})" + return f"AssetAny({', '.join(map(str, self.objects))})" def as_expression(self) -> dict[str, Any]: """ - Serialize the dataset into its scheduling expression. + Serialize the asset into its scheduling expression. :meta private: """ return {"any": [o.as_expression() for o in self.objects]} -class _DatasetAliasCondition(DatasetAny): +class _AssetAliasCondition(AssetAny): """ - Use to expand DataAlias as DatasetAny of its resolved Datasets. + Use to expand AssetAlias as AssetAny of its resolved Assets. :meta private: """ def __init__(self, name: str) -> None: self.name = name - self.objects = expand_alias_to_datasets(name) + self.objects = expand_alias_to_assets(name) def __repr__(self) -> str: - return f"_DatasetAliasCondition({', '.join(map(str, self.objects))})" + return f"_AssetAliasCondition({', '.join(map(str, self.objects))})" def as_expression(self) -> Any: """ - Serialize the dataset into its scheduling expression. + Serialize the asset alias into its scheduling expression. :meta private: """ return {"alias": self.name} - def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]: - yield self.name, DatasetAlias(self.name) + def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: + yield self.name, AssetAlias(self.name) def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterator[DagDependency]: """ - Iterate a dataset alias and its resolved datasets as dag dependency. + Iterate an asset alias and its resolved assets as dag dependency. 
 
         :meta private:
         """
         if self.objects:
             for obj in self.objects:
-                dataset = cast(Dataset, obj)
-                uri = dataset.uri
-                # dataset
+                asset = cast(Asset, obj)
+                uri = asset.uri
+                # asset
                 yield DagDependency(
-                    source=f"dataset-alias:{self.name}" if source else "dataset",
-                    target="dataset" if source else f"dataset-alias:{self.name}",
-                    dependency_type="dataset",
+                    source=f"asset-alias:{self.name}" if source else "asset",
+                    target="asset" if source else f"asset-alias:{self.name}",
+                    dependency_type="asset",
                     dependency_id=uri,
                 )
-                # dataset alias
+                # asset alias
                 yield DagDependency(
-                    source=source or f"dataset:{uri}",
-                    target=target or f"dataset:{uri}",
-                    dependency_type="dataset-alias",
+                    source=source or f"asset:{uri}",
+                    target=target or f"asset:{uri}",
+                    dependency_type="asset-alias",
                     dependency_id=self.name,
                 )
         else:
             yield DagDependency(
-                source=source or "dataset-alias",
-                target=target or "dataset-alias",
-                dependency_type="dataset-alias",
+                source=source or "asset-alias",
+                target=target or "asset-alias",
+                dependency_type="asset-alias",
                 dependency_id=self.name,
             )
 
 
-class DatasetAll(_DatasetBooleanCondition):
-    """Use to combine datasets schedule references in an "and" relationship."""
+class AssetAll(_AssetBooleanCondition):
+    """Use to combine assets schedule references in an "and" relationship."""
 
     agg_func = all
 
-    def __and__(self, other: BaseDataset) -> BaseDataset:
-        if not isinstance(other, BaseDataset):
+    def __and__(self, other: BaseAsset) -> BaseAsset:
+        if not isinstance(other, BaseAsset):
             return NotImplemented
         # Optimization: X & (Y & Z) is equivalent to X & Y & Z.
-        return DatasetAll(*self.objects, other)
+        return AssetAll(*self.objects, other)
 
     def __repr__(self) -> str:
-        return f"DatasetAll({', '.join(map(str, self.objects))})"
+        return f"AssetAll({', '.join(map(str, self.objects))})"
 
     def as_expression(self) -> Any:
         """
-        Serialize the dataset into its scheduling expression.
+        Serialize the asset into its scheduling expression.
 
         :meta private:
         """
diff --git a/airflow/datasets/manager.py b/airflow/assets/manager.py
similarity index 53%
rename from airflow/datasets/manager.py
rename to airflow/assets/manager.py
index 6322414bb8499..d68a0efc87d12 100644
--- a/airflow/datasets/manager.py
+++ b/airflow/assets/manager.py
@@ -24,120 +24,121 @@
 from sqlalchemy.orm import joinedload
 
 from airflow.api_internal.internal_api_call import internal_api_call
+from airflow.assets import Asset
 from airflow.configuration import conf
 from airflow.listeners.listener import get_listener_manager
-from airflow.models.dagbag import DagPriorityParsingRequest
-from airflow.models.dataset import (
-    DagScheduleDatasetAliasReference,
-    DagScheduleDatasetReference,
-    DatasetAliasModel,
-    DatasetDagRunQueue,
-    DatasetEvent,
-    DatasetModel,
+from airflow.models.asset import (
+    AssetAliasModel,
+    AssetDagRunQueue,
+    AssetEvent,
+    AssetModel,
+    DagScheduleAssetAliasReference,
+    DagScheduleAssetReference,
 )
+from airflow.models.dagbag import DagPriorityParsingRequest
 from airflow.stats import Stats
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:
     from sqlalchemy.orm.session import Session
 
-    from airflow.datasets import Dataset, DatasetAlias
+    from airflow.assets import Asset, AssetAlias
     from airflow.models.dag import DagModel
     from airflow.models.taskinstance import TaskInstance
 
 
-class DatasetManager(LoggingMixin):
+class AssetManager(LoggingMixin):
     """
-    A pluggable class that manages operations for datasets.
+    A pluggable class that manages operations for assets.
 
- The intent is to have one place to handle all Dataset-related operations, so different - Airflow deployments can use plugins that broadcast dataset events to each other. + The intent is to have one place to handle all Asset-related operations, so different + Airflow deployments can use plugins that broadcast Asset events to each other. """ @classmethod - def create_datasets(cls, datasets: list[Dataset], *, session: Session) -> list[DatasetModel]: - """Create new datasets.""" + def create_assets(cls, assets: list[Asset], *, session: Session) -> list[AssetModel]: + """Create new assets.""" - def _add_one(dataset: Dataset) -> DatasetModel: - model = DatasetModel.from_public(dataset) + def _add_one(asset: Asset) -> AssetModel: + model = AssetModel.from_public(asset) session.add(model) - cls.notify_dataset_created(dataset=dataset) + cls.notify_asset_created(asset=asset) return model - return [_add_one(d) for d in datasets] + return [_add_one(a) for a in assets] @classmethod - def create_dataset_aliases( + def create_asset_aliases( cls, - dataset_aliases: list[DatasetAlias], + asset_aliases: list[AssetAlias], *, session: Session, - ) -> list[DatasetAliasModel]: - """Create new dataset aliases.""" + ) -> list[AssetAliasModel]: + """Create new asset aliases.""" - def _add_one(dataset_alias: DatasetAlias) -> DatasetAliasModel: - model = DatasetAliasModel.from_public(dataset_alias) + def _add_one(asset_alias: AssetAlias) -> AssetAliasModel: + model = AssetAliasModel.from_public(asset_alias) session.add(model) - cls.notify_dataset_alias_created(dataset_alias=dataset_alias) + cls.notify_asset_alias_created(asset_assets=asset_alias) return model - return [_add_one(a) for a in dataset_aliases] + return [_add_one(a) for a in asset_aliases] @classmethod - def _add_dataset_alias_association( + def _add_asset_alias_association( cls, alias_names: Collection[str], - dataset: DatasetModel, + asset: AssetModel, *, session: Session, ) -> None: - already_related = {m.name for m in dataset.aliases} + already_related = {m.name for m in asset.aliases} existing_aliases = { m.name: m - for m in session.scalars(select(DatasetAliasModel).where(DatasetAliasModel.name.in_(alias_names))) + for m in session.scalars(select(AssetAliasModel).where(AssetAliasModel.name.in_(alias_names))) } - dataset.aliases.extend( - existing_aliases.get(name, DatasetAliasModel(name=name)) + asset.aliases.extend( + existing_aliases.get(name, AssetAliasModel(name=name)) for name in alias_names if name not in already_related ) @classmethod @internal_api_call - def register_dataset_change( + def register_asset_change( cls, *, task_instance: TaskInstance | None = None, - dataset: Dataset, + asset: Asset, extra=None, - aliases: Collection[DatasetAlias] = (), + aliases: Collection[AssetAlias] = (), source_alias_names: Iterable[str] | None = None, session: Session, **kwargs, - ) -> DatasetEvent | None: + ) -> AssetEvent | None: """ - Register dataset related changes. + Register asset related changes. 
- For local datasets, look them up, record the dataset event, queue dagruns, and broadcast - the dataset event + For local assets, look them up, record the asset event, queue dagruns, and broadcast + the asset event """ # todo: add test so that all usages of internal_api_call are added to rpc endpoint - dataset_model = session.scalar( - select(DatasetModel) - .where(DatasetModel.uri == dataset.uri) + asset_model = session.scalar( + select(AssetModel) + .where(AssetModel.uri == asset.uri) .options( - joinedload(DatasetModel.aliases), - joinedload(DatasetModel.consuming_dags).joinedload(DagScheduleDatasetReference.dag), + joinedload(AssetModel.aliases), + joinedload(AssetModel.consuming_dags).joinedload(DagScheduleAssetReference.dag), ) ) - if not dataset_model: - cls.logger().warning("DatasetModel %s not found", dataset) + if not asset_model: + cls.logger().warning("AssetModel %s not found", asset) return None - cls._add_dataset_alias_association({alias.name for alias in aliases}, dataset_model, session=session) + cls._add_asset_alias_association({alias.name for alias in aliases}, asset_model, session=session) event_kwargs = { - "dataset_id": dataset_model.id, + "dataset_id": asset_model.id, "extra": extra, } if task_instance: @@ -148,67 +149,65 @@ def register_dataset_change( source_map_index=task_instance.map_index, ) - dataset_event = DatasetEvent(**event_kwargs) - session.add(dataset_event) + asset_event = AssetEvent(**event_kwargs) + session.add(asset_event) session.flush() # Ensure the event is written earlier than DDRQ entries below. - dags_to_queue_from_dataset = { - ref.dag for ref in dataset_model.consuming_dags if ref.dag.is_active and not ref.dag.is_paused + dags_to_queue_from_asset = { + ref.dag for ref in asset_model.consuming_dags if ref.dag.is_active and not ref.dag.is_paused } - dags_to_queue_from_dataset_alias = set() + dags_to_queue_from_asset_alias = set() if source_alias_names: - dataset_alias_models = session.scalars( - select(DatasetAliasModel) - .where(DatasetAliasModel.name.in_(source_alias_names)) + asset_alias_models = session.scalars( + select(AssetAliasModel) + .where(AssetAliasModel.name.in_(source_alias_names)) .options( - joinedload(DatasetAliasModel.consuming_dags).joinedload( - DagScheduleDatasetAliasReference.dag - ) + joinedload(AssetAliasModel.consuming_dags).joinedload(DagScheduleAssetAliasReference.dag) ) ).unique() - for dsa in dataset_alias_models: - dsa.dataset_events.append(dataset_event) - session.add(dsa) + for asset_alias_model in asset_alias_models: + asset_alias_model.dataset_events.append(asset_event) + session.add(asset_alias_model) - dags_to_queue_from_dataset_alias |= { + dags_to_queue_from_asset_alias |= { alias_ref.dag - for alias_ref in dsa.consuming_dags + for alias_ref in asset_alias_model.consuming_dags if alias_ref.dag.is_active and not alias_ref.dag.is_paused } - dags_to_reparse = dags_to_queue_from_dataset_alias - dags_to_queue_from_dataset + dags_to_reparse = dags_to_queue_from_asset_alias - dags_to_queue_from_asset if dags_to_reparse: file_locs = {dag.fileloc for dag in dags_to_reparse} cls._send_dag_priority_parsing_request(file_locs, session) - cls.notify_dataset_changed(dataset=dataset) + cls.notify_asset_changed(asset=asset) - Stats.incr("dataset.updates") + Stats.incr("asset.updates") - dags_to_queue = dags_to_queue_from_dataset | dags_to_queue_from_dataset_alias - cls._queue_dagruns(dataset_id=dataset_model.id, dags_to_queue=dags_to_queue, session=session) - return dataset_event + dags_to_queue = dags_to_queue_from_asset | 
dags_to_queue_from_asset_alias + cls._queue_dagruns(asset_id=asset_model.id, dags_to_queue=dags_to_queue, session=session) + return asset_event @staticmethod - def notify_dataset_created(dataset: Dataset): - """Run applicable notification actions when a dataset is created.""" - get_listener_manager().hook.on_dataset_created(dataset=dataset) + def notify_asset_created(asset: Asset): + """Run applicable notification actions when an asset is created.""" + get_listener_manager().hook.on_asset_created(asset=asset) @staticmethod - def notify_dataset_alias_created(dataset_alias: DatasetAlias): - """Run applicable notification actions when a dataset alias is created.""" - get_listener_manager().hook.on_dataset_alias_created(dataset_alias=dataset_alias) + def notify_asset_alias_created(asset_assets: AssetAlias): + """Run applicable notification actions when an asset alias is created.""" + get_listener_manager().hook.on_asset_alias_created(asset_alias=asset_assets) @staticmethod - def notify_dataset_changed(dataset: Dataset): - """Run applicable notification actions when a dataset is changed.""" - get_listener_manager().hook.on_dataset_changed(dataset=dataset) + def notify_asset_changed(asset: Asset): + """Run applicable notification actions when an asset is changed.""" + get_listener_manager().hook.on_asset_changed(asset=asset) @classmethod - def _queue_dagruns(cls, dataset_id: int, dags_to_queue: set[DagModel], session: Session) -> None: + def _queue_dagruns(cls, asset_id: int, dags_to_queue: set[DagModel], session: Session) -> None: # Possible race condition: if multiple dags or multiple (usually - # mapped) tasks update the same dataset, this can fail with a unique + # mapped) tasks update the same asset, this can fail with a unique # constraint violation. # # If we support it, use ON CONFLICT to do nothing, otherwise @@ -219,15 +218,13 @@ def _queue_dagruns(cls, dataset_id: int, dags_to_queue: set[DagModel], session: return if session.bind.dialect.name == "postgresql": - return cls._postgres_queue_dagruns(dataset_id, dags_to_queue, session) - return cls._slow_path_queue_dagruns(dataset_id, dags_to_queue, session) + return cls._postgres_queue_dagruns(asset_id, dags_to_queue, session) + return cls._slow_path_queue_dagruns(asset_id, dags_to_queue, session) @classmethod - def _slow_path_queue_dagruns( - cls, dataset_id: int, dags_to_queue: set[DagModel], session: Session - ) -> None: + def _slow_path_queue_dagruns(cls, asset_id: int, dags_to_queue: set[DagModel], session: Session) -> None: def _queue_dagrun_if_needed(dag: DagModel) -> str | None: - item = DatasetDagRunQueue(target_dag_id=dag.dag_id, dataset_id=dataset_id) + item = AssetDagRunQueue(target_dag_id=dag.dag_id, dataset_id=asset_id) # Don't error whole transaction when a single RunQueue item conflicts. 
# https://docs.sqlalchemy.org/en/14/orm/session_transaction.html#using-savepoint try: @@ -242,11 +239,11 @@ def _queue_dagrun_if_needed(dag: DagModel) -> str | None: cls.logger().debug("consuming dag ids %s", queued_dag_ids) @classmethod - def _postgres_queue_dagruns(cls, dataset_id: int, dags_to_queue: set[DagModel], session: Session) -> None: + def _postgres_queue_dagruns(cls, asset_id: int, dags_to_queue: set[DagModel], session: Session) -> None: from sqlalchemy.dialects.postgresql import insert values = [{"target_dag_id": dag.dag_id} for dag in dags_to_queue] - stmt = insert(DatasetDagRunQueue).values(dataset_id=dataset_id).on_conflict_do_nothing() + stmt = insert(AssetDagRunQueue).values(dataset_id=asset_id).on_conflict_do_nothing() session.execute(stmt, values) @classmethod @@ -279,19 +276,19 @@ def _postgres_send_dag_priority_parsing_request(cls, file_locs: Iterable[str], s session.execute(stmt, {"fileloc": fileloc for fileloc in file_locs}) -def resolve_dataset_manager() -> DatasetManager: - """Retrieve the dataset manager.""" - _dataset_manager_class = conf.getimport( +def resolve_asset_manager() -> AssetManager: + """Retrieve the asset manager.""" + _asset_manager_class = conf.getimport( section="core", - key="dataset_manager_class", - fallback="airflow.datasets.manager.DatasetManager", + key="asset_manager_class", + fallback="airflow.assets.manager.AssetManager", ) - _dataset_manager_kwargs = conf.getjson( + _asset_manager_kwargs = conf.getjson( section="core", - key="dataset_manager_kwargs", + key="asset_manager_kwargs", fallback={}, ) - return _dataset_manager_class(**_dataset_manager_kwargs) + return _asset_manager_class(**_asset_manager_kwargs) -dataset_manager = resolve_dataset_manager() +asset_manager = resolve_asset_manager() diff --git a/airflow/datasets/metadata.py b/airflow/assets/metadata.py similarity index 80% rename from airflow/datasets/metadata.py rename to airflow/assets/metadata.py index 43dff9287365c..4fd2902afc8bf 100644 --- a/airflow/datasets/metadata.py +++ b/airflow/assets/metadata.py @@ -21,26 +21,26 @@ import attrs -from airflow.datasets import DatasetAlias, extract_event_key +from airflow.assets import AssetAlias, extract_event_key if TYPE_CHECKING: - from airflow.datasets import Dataset + from airflow.assets import Asset @attrs.define(init=False) class Metadata: - """Metadata to attach to a DatasetEvent.""" + """Metadata to attach to a AssetEvent.""" uri: str extra: dict[str, Any] alias_name: str | None = None def __init__( - self, target: str | Dataset, extra: dict[str, Any], alias: DatasetAlias | str | None = None + self, target: str | Asset, extra: dict[str, Any], alias: AssetAlias | str | None = None ) -> None: self.uri = extract_event_key(target) self.extra = extra - if isinstance(alias, DatasetAlias): + if isinstance(alias, AssetAlias): self.alias_name = alias.name else: self.alias_name = alias diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index 6be53da0807e0..69b5969c827c6 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -46,10 +46,10 @@ ) from airflow.auth.managers.models.resource_details import ( AccessView, + AssetDetails, ConfigurationDetails, ConnectionDetails, DagAccessEntity, - DatasetDetails, PoolDetails, VariableDetails, ) @@ -178,18 +178,18 @@ def is_authorized_dag( """ @abstractmethod - def is_authorized_dataset( + def is_authorized_asset( self, *, method: ResourceMethod, - details: DatasetDetails | None = None, + details: 
AssetDetails | None = None, user: BaseUser | None = None, ) -> bool: """ - Return whether the user is authorized to perform a given action on a dataset. + Return whether the user is authorized to perform a given action on an asset. :param method: the method to perform - :param details: optional details about the dataset + :param details: optional details about the asset :param user: the user to perform the action on. If not provided (or None), it uses the current user """ diff --git a/airflow/auth/managers/models/resource_details.py b/airflow/auth/managers/models/resource_details.py index fcbee5a2ad299..6dec2236bf233 100644 --- a/airflow/auth/managers/models/resource_details.py +++ b/airflow/auth/managers/models/resource_details.py @@ -43,8 +43,8 @@ class DagDetails: @dataclass -class DatasetDetails: - """Represents the details of a dataset.""" +class AssetDetails: + """Represents the details of an asset.""" uri: str | None = None diff --git a/airflow/auth/managers/simple/simple_auth_manager.py b/airflow/auth/managers/simple/simple_auth_manager.py index a683aa5472cef..451068733667c 100644 --- a/airflow/auth/managers/simple/simple_auth_manager.py +++ b/airflow/auth/managers/simple/simple_auth_manager.py @@ -36,11 +36,11 @@ from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.resource_details import ( AccessView, + AssetDetails, ConfigurationDetails, ConnectionDetails, DagAccessEntity, DagDetails, - DatasetDetails, PoolDetails, VariableDetails, ) @@ -163,8 +163,8 @@ def is_authorized_dag( allow_role=SimpleAuthManagerRole.USER, ) - def is_authorized_dataset( - self, *, method: ResourceMethod, details: DatasetDetails | None = None, user: BaseUser | None = None + def is_authorized_asset( + self, *, method: ResourceMethod, details: AssetDetails | None = None, user: BaseUser | None = None ) -> bool: return self._is_authorized( method=method, diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 7317fce60e4e6..a6d40a48c039e 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -470,22 +470,22 @@ core: type: string default: "0o077" example: ~ - dataset_manager_class: - description: Class to use as dataset manager. - version_added: 2.4.0 + asset_manager_class: + description: Class to use as asset manager. + version_added: 3.0.0 type: string default: ~ - example: 'airflow.datasets.manager.DatasetManager' - dataset_manager_kwargs: - description: Kwargs to supply to dataset manager. - version_added: 2.4.0 + example: 'airflow.datasets.manager.AssetManager' + asset_manager_kwargs: + description: Kwargs to supply to asset manager. + version_added: 3.0.0 type: string sensitive: true default: ~ example: '{"some_param": "some_value"}' - strict_dataset_uri_validation: + strict_asset_uri_validation: description: | - Dataset URI validation should raise an exception if it is not compliant with AIP-60. + Asset URI validation should raise an exception if it is not compliant with AIP-60. By default this configuration is false, meaning that Airflow 2.x only warns the user. In Airflow 3, this configuration will be enabled by default. 
default: "False" diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py index bcac479d875a3..c8ce5dc873afa 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -35,17 +35,17 @@ from sqlalchemy.orm import joinedload, load_only from sqlalchemy.sql import expression -from airflow.datasets import Dataset, DatasetAlias -from airflow.datasets.manager import dataset_manager +from airflow.assets import Asset, AssetAlias +from airflow.assets.manager import asset_manager +from airflow.models.asset import ( + AssetAliasModel, + AssetModel, + DagScheduleAssetAliasReference, + DagScheduleAssetReference, + TaskOutletAssetReference, +) from airflow.models.dag import DAG, DagModel, DagOwnerAttributes, DagTag from airflow.models.dagrun import DagRun -from airflow.models.dataset import ( - DagScheduleDatasetAliasReference, - DagScheduleDatasetReference, - DatasetAliasModel, - DatasetModel, - TaskOutletDatasetReference, -) from airflow.utils.sqlalchemy import with_row_locks from airflow.utils.timezone import utcnow from airflow.utils.types import DagRunType @@ -209,7 +209,7 @@ def update_dags( ) dm.timetable_summary = dag.timetable.summary dm.timetable_description = dag.timetable.description - dm.dataset_expression = dag.timetable.dataset_condition.as_expression() + dm.dataset_expression = dag.timetable.asset_condition.as_expression() dm.processor_subdir = processor_subdir last_automated_run: DagRun | None = run_info.latest_runs.get(dag.dag_id) @@ -222,7 +222,7 @@ def update_dags( else: dm.calculate_dagrun_date_fields(dag, last_automated_data_interval) - if not dag.timetable.dataset_condition: + if not dag.timetable.asset_condition: dm.schedule_dataset_references = [] dm.schedule_dataset_alias_references = [] # FIXME: STORE NEW REFERENCES. 
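For reference, the expression that update_dags now stores on DagModel via asset_condition.as_expression() is the nested any/all structure produced by the renamed condition classes. The following is a minimal illustrative sketch, not part of the patch; it assumes the airflow.assets module path introduced here and that Asset, AssetAny, and AssetAll keep the Dataset semantics shown in the earlier hunks.

# Illustrative sketch only: the serialized scheduling expression stored in
# DagModel.dataset_expression, assuming the renamed classes keep the
# Dataset-era behaviour shown above.
from airflow.assets import Asset

a = Asset("s3://bucket/a.txt")
b = Asset("s3://bucket/b.txt")
c = Asset("s3://bucket/c.txt")

condition = a | (b & c)  # AssetAny(a, AssetAll(b, c))
print(condition.as_expression())
# -> {'any': ['s3://bucket/a.txt', {'all': ['s3://bucket/b.txt', 's3://bucket/c.txt']}]}

# evaluate() answers "should dependent DAGs be triggered?" given which asset
# URIs have seen events since the last run.
print(condition.evaluate({"s3://bucket/a.txt": False, "s3://bucket/b.txt": True, "s3://bucket/c.txt": True}))
# -> True, because both sides of the "all" branch are satisfied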
@@ -237,44 +237,44 @@ def update_dags( dm.dag_owner_links = [] -def _find_all_datasets(dags: Iterable[DAG]) -> Iterator[Dataset]: +def _find_all_assets(dags: Iterable[DAG]) -> Iterator[Asset]: for dag in dags: - for _, dataset in dag.timetable.dataset_condition.iter_datasets(): - yield dataset + for _, asset in dag.timetable.asset_condition.iter_assets(): + yield asset for task in dag.task_dict.values(): for obj in itertools.chain(task.inlets, task.outlets): - if isinstance(obj, Dataset): + if isinstance(obj, Asset): yield obj -def _find_all_dataset_aliases(dags: Iterable[DAG]) -> Iterator[DatasetAlias]: +def _find_all_asset_aliases(dags: Iterable[DAG]) -> Iterator[AssetAlias]: for dag in dags: - for _, alias in dag.timetable.dataset_condition.iter_dataset_aliases(): + for _, alias in dag.timetable.asset_condition.iter_asset_aliases(): yield alias for task in dag.task_dict.values(): for obj in itertools.chain(task.inlets, task.outlets): - if isinstance(obj, DatasetAlias): + if isinstance(obj, AssetAlias): yield obj -class DatasetModelOperation(NamedTuple): - """Collect dataset/alias objects from DAGs and perform database operations for them.""" +class AssetModelOperation(NamedTuple): + """Collect asset/alias objects from DAGs and perform database operations for them.""" - schedule_dataset_references: dict[str, list[Dataset]] - schedule_dataset_alias_references: dict[str, list[DatasetAlias]] - outlet_references: dict[str, list[tuple[str, Dataset]]] - datasets: dict[str, Dataset] - dataset_aliases: dict[str, DatasetAlias] + schedule_asset_references: dict[str, list[Asset]] + schedule_asset_alias_references: dict[str, list[AssetAlias]] + outlet_references: dict[str, list[tuple[str, Asset]]] + assets: dict[str, Asset] + asset_aliases: dict[str, AssetAlias] @classmethod def collect(cls, dags: dict[str, DAG]) -> Self: coll = cls( - schedule_dataset_references={ - dag_id: [dataset for _, dataset in dag.timetable.dataset_condition.iter_datasets()] + schedule_asset_references={ + dag_id: [asset for _, asset in dag.timetable.asset_condition.iter_assets()] for dag_id, dag in dags.items() }, - schedule_dataset_alias_references={ - dag_id: [alias for _, alias in dag.timetable.dataset_condition.iter_dataset_aliases()] + schedule_asset_alias_references={ + dag_id: [alias for _, alias in dag.timetable.asset_condition.iter_asset_aliases()] for dag_id, dag in dags.items() }, outlet_references={ @@ -282,90 +282,89 @@ def collect(cls, dags: dict[str, DAG]) -> Self: (task_id, outlet) for task_id, task in dag.task_dict.items() for outlet in task.outlets - if isinstance(outlet, Dataset) + if isinstance(outlet, Asset) ] for dag_id, dag in dags.items() }, - datasets={dataset.uri: dataset for dataset in _find_all_datasets(dags.values())}, - dataset_aliases={alias.name: alias for alias in _find_all_dataset_aliases(dags.values())}, + assets={asset.uri: asset for asset in _find_all_assets(dags.values())}, + asset_aliases={alias.name: alias for alias in _find_all_asset_aliases(dags.values())}, ) return coll - def add_datasets(self, *, session: Session) -> dict[str, DatasetModel]: - # Optimization: skip all database calls if no datasets were collected. - if not self.datasets: + def add_assets(self, *, session: Session) -> dict[str, AssetModel]: + # Optimization: skip all database calls if no assets were collected. 
+ if not self.assets: return {} - orm_datasets: dict[str, DatasetModel] = { - dm.uri: dm - for dm in session.scalars(select(DatasetModel).where(DatasetModel.uri.in_(self.datasets))) + orm_assets: dict[str, AssetModel] = { + am.uri: am for am in session.scalars(select(AssetModel).where(AssetModel.uri.in_(self.assets))) } - for model in orm_datasets.values(): + for model in orm_assets.values(): model.is_orphaned = expression.false() - orm_datasets.update( + orm_assets.update( (model.uri, model) - for model in dataset_manager.create_datasets( - [dataset for uri, dataset in self.datasets.items() if uri not in orm_datasets], + for model in asset_manager.create_assets( + [asset for uri, asset in self.assets.items() if uri not in orm_assets], session=session, ) ) - return orm_datasets + return orm_assets - def add_dataset_aliases(self, *, session: Session) -> dict[str, DatasetAliasModel]: - # Optimization: skip all database calls if no dataset aliases were collected. - if not self.dataset_aliases: + def add_asset_aliases(self, *, session: Session) -> dict[str, AssetAliasModel]: + # Optimization: skip all database calls if no asset aliases were collected. + if not self.asset_aliases: return {} - orm_aliases: dict[str, DatasetAliasModel] = { + orm_aliases: dict[str, AssetAliasModel] = { da.name: da for da in session.scalars( - select(DatasetAliasModel).where(DatasetAliasModel.name.in_(self.dataset_aliases)) + select(AssetAliasModel).where(AssetAliasModel.name.in_(self.asset_aliases)) ) } orm_aliases.update( (model.name, model) - for model in dataset_manager.create_dataset_aliases( - [alias for name, alias in self.dataset_aliases.items() if name not in orm_aliases], + for model in asset_manager.create_asset_aliases( + [alias for name, alias in self.asset_aliases.items() if name not in orm_aliases], session=session, ) ) return orm_aliases - def add_dag_dataset_references( + def add_dag_asset_references( self, dags: dict[str, DagModel], - datasets: dict[str, DatasetModel], + assets: dict[str, AssetModel], *, session: Session, ) -> None: - # Optimization: No datasets means there are no references to update. - if not datasets: + # Optimization: No assets means there are no references to update. + if not assets: return - for dag_id, references in self.schedule_dataset_references.items(): + for dag_id, references in self.schedule_asset_references.items(): # Optimization: no references at all; this is faster than repeated delete(). 
if not references: dags[dag_id].schedule_dataset_references = [] continue - referenced_dataset_ids = {dataset.id for dataset in (datasets[r.uri] for r in references)} + referenced_asset_ids = {asset.id for asset in (assets[r.uri] for r in references)} orm_refs = {r.dataset_id: r for r in dags[dag_id].schedule_dataset_references} - for dataset_id, ref in orm_refs.items(): - if dataset_id not in referenced_dataset_ids: + for asset_id, ref in orm_refs.items(): + if asset_id not in referenced_asset_ids: session.delete(ref) session.bulk_save_objects( - DagScheduleDatasetReference(dataset_id=dataset_id, dag_id=dag_id) - for dataset_id in referenced_dataset_ids - if dataset_id not in orm_refs + DagScheduleAssetReference(dataset_id=asset_id, dag_id=dag_id) + for asset_id in referenced_asset_ids + if asset_id not in orm_refs ) - def add_dag_dataset_alias_references( + def add_dag_asset_alias_references( self, dags: dict[str, DagModel], - aliases: dict[str, DatasetAliasModel], + aliases: dict[str, AssetAliasModel], *, session: Session, ) -> None: # Optimization: No aliases means there are no references to update. if not aliases: return - for dag_id, references in self.schedule_dataset_alias_references.items(): + for dag_id, references in self.schedule_asset_alias_references.items(): # Optimization: no references at all; this is faster than repeated delete(). if not references: dags[dag_id].schedule_dataset_alias_references = [] @@ -376,20 +375,20 @@ def add_dag_dataset_alias_references( if alias_id not in referenced_alias_ids: session.delete(ref) session.bulk_save_objects( - DagScheduleDatasetAliasReference(alias_id=alias_id, dag_id=dag_id) + DagScheduleAssetAliasReference(alias_id=alias_id, dag_id=dag_id) for alias_id in referenced_alias_ids if alias_id not in orm_refs ) - def add_task_dataset_references( + def add_task_asset_references( self, dags: dict[str, DagModel], - datasets: dict[str, DatasetModel], + assets: dict[str, AssetModel], *, session: Session, ) -> None: - # Optimization: No datasets means there are no references to update. - if not datasets: + # Optimization: No assets means there are no references to update. + if not assets: return for dag_id, references in self.outlet_references.items(): # Optimization: no references at all; this is faster than repeated delete(). 
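The add_dag_asset_references, add_dag_asset_alias_references, and add_task_asset_references methods all follow the same sync pattern: delete ORM reference rows the re-parsed DAG no longer declares, then bulk-insert only the missing ones. Below is a reduced, hypothetical sketch of that pattern; the names sync_references, delete, and insert are illustrative stand-ins, not Airflow API.

# Hypothetical, simplified version of the reference-sync pattern used above;
# "delete" and "insert" stand in for session.delete() and session.bulk_save_objects().
from typing import Callable, Dict, Set


def sync_references(
    desired_ids: Set[int],
    existing_rows: Dict[int, object],
    delete: Callable[[object], None],
    insert: Callable[[int], None],
) -> None:
    # Remove rows whose id is no longer referenced by the DAG definition.
    for ref_id, row in existing_rows.items():
        if ref_id not in desired_ids:
            delete(row)
    # Create rows only for ids that do not have one yet.
    for ref_id in desired_ids - existing_rows.keys():
        insert(ref_id)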
@@ -397,15 +396,15 @@ def add_task_dataset_references( dags[dag_id].task_outlet_dataset_references = [] continue referenced_outlets = { - (task_id, dataset.id) - for task_id, dataset in ((task_id, datasets[d.uri]) for task_id, d in references) + (task_id, asset.id) + for task_id, asset in ((task_id, assets[d.uri]) for task_id, d in references) } orm_refs = {(r.task_id, r.dataset_id): r for r in dags[dag_id].task_outlet_dataset_references} for key, ref in orm_refs.items(): if key not in referenced_outlets: session.delete(ref) session.bulk_save_objects( - TaskOutletDatasetReference(dataset_id=dataset_id, dag_id=dag_id, task_id=task_id) - for task_id, dataset_id in referenced_outlets - if (task_id, dataset_id) not in orm_refs + TaskOutletAssetReference(dataset_id=asset_id, dag_id=dag_id, task_id=task_id) + for task_id, asset_id in referenced_outlets + if (task_id, asset_id) not in orm_refs ) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 611d363961c51..1ef2c12c702f2 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -40,7 +40,7 @@ import re2 import typing_extensions -from airflow.datasets import Dataset +from airflow.assets import Asset from airflow.models.abstractoperator import DEFAULT_RETRIES, DEFAULT_RETRY_DELAY from airflow.models.baseoperator import ( BaseOperator, @@ -261,7 +261,7 @@ def execute(self, context: Context): # todo make this more generic (move to prepare_lineage) so it deals with non taskflow operators # as well for arg in itertools.chain(self.op_args, self.op_kwargs.values()): - if isinstance(arg, Dataset): + if isinstance(arg, Asset): self.inlets.append(arg) return_value = super().execute(context) return self._handle_output(return_value=return_value, context=context, xcom_push=self.xcom_push) @@ -270,17 +270,17 @@ def _handle_output(self, return_value: Any, context: Context, xcom_push: Callabl """ Handle logic for whether a decorator needs to push a single return value or multiple return values. - It sets outlets if any datasets are found in the returned value(s) + It sets outlets if any assets are found in the returned value(s) :param return_value: :param context: :param xcom_push: """ - if isinstance(return_value, Dataset): + if isinstance(return_value, Asset): self.outlets.append(return_value) if isinstance(return_value, list): for item in return_value: - if isinstance(item, Dataset): + if isinstance(item, Asset): self.outlets.append(item) return return_value diff --git a/airflow/example_dags/example_asset_alias.py b/airflow/example_dags/example_asset_alias.py new file mode 100644 index 0000000000000..4970b1eda2660 --- /dev/null +++ b/airflow/example_dags/example_asset_alias.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
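Related to the decorators/base.py hunk above: after the rename, passing an Asset into a TaskFlow function still registers it as an inlet, and returning one registers it as an outlet. A small sketch under that assumption follows; the DAG id and URIs are invented for illustration.

# Sketch only: relies on the taskflow behaviour shown in decorators/base.py,
# where Asset arguments are appended to inlets and Asset return values to outlets.
import pendulum

from airflow import DAG
from airflow.assets import Asset
from airflow.decorators import task

raw = Asset("s3://bucket/raw.csv")

with DAG(
    dag_id="asset_lineage_sketch",  # hypothetical DAG, not part of the patch
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
    schedule=None,
    catchup=False,
):

    @task
    def transform(source: Asset) -> Asset:
        # 'source' becomes an inlet; the returned Asset becomes an outlet.
        return Asset("s3://bucket/clean.csv")

    transform(raw)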
+""" +Example DAG for demonstrating the behavior of the AssetAlias feature in Airflow, including conditional and +asset expression-based scheduling. + +Notes on usage: + +Turn on all the DAGs. + +Before running any DAG, the schedule of the "asset_alias_example_alias_consumer" DAG will show as "Unresolved AssetAlias". +This is expected because the asset alias has not been resolved into any asset yet. + +Once the "asset_s3_bucket_producer" DAG is triggered, the "asset_s3_bucket_consumer" DAG should be triggered upon completion. +This is because the asset alias "example-alias" is used to add an asset event to the asset "s3://bucket/my-task" +during the "produce_asset_events_through_asset_alias" task. +As the DAG "asset-alias-consumer" relies on asset alias "example-alias" which was previously unresolved, +the DAG "asset-alias-consumer" (along with all the DAGs in the same file) will be re-parsed and +thus update its schedule to the asset "s3://bucket/my-task" and will also be triggered. +""" + +from __future__ import annotations + +import pendulum + +from airflow import DAG +from airflow.assets import Asset, AssetAlias +from airflow.decorators import task + +with DAG( + dag_id="asset_s3_bucket_producer", + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=None, + catchup=False, + tags=["producer", "asset"], +): + + @task(outlets=[Asset("s3://bucket/my-task")]) + def produce_asset_events(): + pass + + produce_asset_events() + +with DAG( + dag_id="asset_alias_example_alias_producer", + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=None, + catchup=False, + tags=["producer", "asset-alias"], +): + + @task(outlets=[AssetAlias("example-alias")]) + def produce_asset_events_through_asset_alias(*, outlet_events=None): + bucket_name = "bucket" + object_path = "my-task" + outlet_events["example-alias"].add(Asset(f"s3://{bucket_name}/{object_path}")) + + produce_asset_events_through_asset_alias() + +with DAG( + dag_id="asset_s3_bucket_consumer", + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=[Asset("s3://bucket/my-task")], + catchup=False, + tags=["consumer", "asset"], +): + + @task + def consume_asset_event(): + pass + + consume_asset_event() + +with DAG( + dag_id="asset_alias_example_alias_consumer", + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=[AssetAlias("example-alias")], + catchup=False, + tags=["consumer", "asset-alias"], +): + + @task(inlets=[AssetAlias("example-alias")]) + def consume_asset_event_from_asset_alias(*, inlet_events=None): + for event in inlet_events[AssetAlias("example-alias")]: + print(event) + + consume_asset_event_from_asset_alias() diff --git a/airflow/example_dags/example_asset_alias_with_no_taskflow.py b/airflow/example_dags/example_asset_alias_with_no_taskflow.py new file mode 100644 index 0000000000000..3293f7e45bb94 --- /dev/null +++ b/airflow/example_dags/example_asset_alias_with_no_taskflow.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example DAG for demonstrating the behavior of the AssetAlias feature in Airflow, including conditional and +asset expression-based scheduling. + +Notes on usage: + +Turn on all the DAGs. + +Before running any DAG, the schedule of the "asset_alias_example_alias_consumer_with_no_taskflow" DAG will show as "unresolved AssetAlias". +This is expected because the asset alias has not been resolved into any asset yet. + +Once the "asset_s3_bucket_producer_with_no_taskflow" DAG is triggered, the "asset_s3_bucket_consumer_with_no_taskflow" DAG should be triggered upon completion. +This is because the asset alias "example-alias-no-taskflow" is used to add an asset event to the asset "s3://bucket/my-task-with-no-taskflow" +during the "produce_asset_events_through_asset_alias_with_no_taskflow" task. Also, the schedule of the "asset_alias_example_alias_consumer_with_no_taskflow" DAG should change to "Asset" as +the asset alias "example-alias-no-taskflow" is now resolved to the asset "s3://bucket/my-task-with-no-taskflow" and this DAG should also be triggered. +""" + +from __future__ import annotations + +import pendulum + +from airflow import DAG +from airflow.assets import Asset, AssetAlias +from airflow.operators.python import PythonOperator + +with DAG( + dag_id="asset_s3_bucket_producer_with_no_taskflow", + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=None, + catchup=False, + tags=["producer", "asset"], +): + + def produce_asset_events(): + pass + + PythonOperator( + task_id="produce_asset_events", + outlets=[Asset("s3://bucket/my-task-with-no-taskflow")], + python_callable=produce_asset_events, + ) + + +with DAG( + dag_id="asset_alias_example_alias_producer_with_no_taskflow", + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=None, + catchup=False, + tags=["producer", "asset-alias"], +): + + def produce_asset_events_through_asset_alias_with_no_taskflow(*, outlet_events=None): + bucket_name = "bucket" + object_path = "my-task" + outlet_events["example-alias-no-taskflow"].add(Asset(f"s3://{bucket_name}/{object_path}")) + + PythonOperator( + task_id="produce_asset_events_through_asset_alias_with_no_taskflow", + outlets=[AssetAlias("example-alias-no-taskflow")], + python_callable=produce_asset_events_through_asset_alias_with_no_taskflow, + ) + +with DAG( + dag_id="asset_s3_bucket_consumer_with_no_taskflow", + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=[Asset("s3://bucket/my-task-with-no-taskflow")], + catchup=False, + tags=["consumer", "asset"], +): + + def consume_asset_event(): + pass + + PythonOperator(task_id="consume_asset_event", python_callable=consume_asset_event) + +with DAG( + dag_id="asset_alias_example_alias_consumer_with_no_taskflow", + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=[AssetAlias("example-alias-no-taskflow")], + catchup=False, + tags=["consumer", "asset-alias"], +): + + def consume_asset_event_from_asset_alias(*, inlet_events=None): + for event in inlet_events[AssetAlias("example-alias-no-taskflow")]: + print(event) + + PythonOperator( + task_id="consume_asset_event_from_asset_alias", + 
python_callable=consume_asset_event_from_asset_alias, + inlets=[AssetAlias("example-alias-no-taskflow")], + ) diff --git a/airflow/example_dags/example_assets.py b/airflow/example_dags/example_assets.py new file mode 100644 index 0000000000000..66369794ed999 --- /dev/null +++ b/airflow/example_dags/example_assets.py @@ -0,0 +1,192 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example DAG for demonstrating the behavior of the Assets feature in Airflow, including conditional and +asset expression-based scheduling. + +Notes on usage: + +Turn on all the DAGs. + +asset_produces_1 is scheduled to run daily. Once it completes, it triggers several DAGs due to its asset +being updated. asset_consumes_1 is triggered immediately, as it depends solely on the asset produced by +asset_produces_1. consume_1_or_2_with_asset_expressions will also be triggered, as its condition of +either asset_produces_1 or asset_produces_2 being updated is satisfied with asset_produces_1. + +asset_consumes_1_and_2 will not be triggered after asset_produces_1 runs because it requires the asset +from asset_produces_2, which has no schedule and must be manually triggered. + +After manually triggering asset_produces_2, several DAGs will be affected. asset_consumes_1_and_2 should +run because both its asset dependencies are now met. consume_1_and_2_with_asset_expressions will be +triggered, as it requires both asset_produces_1 and asset_produces_2 assets to be updated. +consume_1_or_2_with_asset_expressions will be triggered again, since it's conditionally set to run when +either asset is updated. + +consume_1_or_both_2_and_3_with_asset_expressions demonstrates complex asset dependency logic. +This DAG triggers if asset_produces_1 is updated or if both asset_produces_2 and dag3_asset +are updated. This example highlights the capability to combine updates from multiple assets with logical +expressions for advanced scheduling. + +conditional_asset_and_time_based_timetable illustrates the integration of time-based scheduling with +asset dependencies. This DAG is configured to execute either when both asset_produces_1 and +asset_produces_2 assets have been updated or according to a specific cron schedule, showcasing +Airflow's versatility in handling mixed triggers for asset and time-based scheduling. + +The DAGs asset_consumes_1_never_scheduled and asset_consumes_unknown_never_scheduled will not run +automatically as they depend on assets that do not get updated or are not produced by any scheduled tasks. 
+""" + +from __future__ import annotations + +import pendulum + +from airflow.assets import Asset +from airflow.models.dag import DAG +from airflow.operators.bash import BashOperator +from airflow.timetables.assets import AssetOrTimeSchedule +from airflow.timetables.trigger import CronTriggerTimetable + +# [START asset_def] +dag1_asset = Asset("s3://dag1/output_1.txt", extra={"hi": "bye"}) +# [END asset_def] +dag2_asset = Asset("s3://dag2/output_1.txt", extra={"hi": "bye"}) +dag3_asset = Asset("s3://dag3/output_3.txt", extra={"hi": "bye"}) + +with DAG( + dag_id="asset_produces_1", + catchup=False, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule="@daily", + tags=["produces", "asset-scheduled"], +) as dag1: + # [START task_outlet] + BashOperator(outlets=[dag1_asset], task_id="producing_task_1", bash_command="sleep 5") + # [END task_outlet] + +with DAG( + dag_id="asset_produces_2", + catchup=False, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=None, + tags=["produces", "asset-scheduled"], +) as dag2: + BashOperator(outlets=[dag2_asset], task_id="producing_task_2", bash_command="sleep 5") + +# [START dag_dep] +with DAG( + dag_id="asset_consumes_1", + catchup=False, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=[dag1_asset], + tags=["consumes", "asset-scheduled"], +) as dag3: + # [END dag_dep] + BashOperator( + outlets=[Asset("s3://consuming_1_task/asset_other.txt")], + task_id="consuming_1", + bash_command="sleep 5", + ) + +with DAG( + dag_id="asset_consumes_1_and_2", + catchup=False, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=[dag1_asset, dag2_asset], + tags=["consumes", "asset-scheduled"], +) as dag4: + BashOperator( + outlets=[Asset("s3://consuming_2_task/asset_other_unknown.txt")], + task_id="consuming_2", + bash_command="sleep 5", + ) + +with DAG( + dag_id="asset_consumes_1_never_scheduled", + catchup=False, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=[ + dag1_asset, + Asset("s3://unrelated/this-asset-doesnt-get-triggered"), + ], + tags=["consumes", "asset-scheduled"], +) as dag5: + BashOperator( + outlets=[Asset("s3://consuming_2_task/asset_other_unknown.txt")], + task_id="consuming_3", + bash_command="sleep 5", + ) + +with DAG( + dag_id="asset_consumes_unknown_never_scheduled", + catchup=False, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=[ + Asset("s3://unrelated/asset3.txt"), + Asset("s3://unrelated/asset_other_unknown.txt"), + ], + tags=["asset-scheduled"], +) as dag6: + BashOperator( + task_id="unrelated_task", + outlets=[Asset("s3://unrelated_task/asset_other_unknown.txt")], + bash_command="sleep 5", + ) + +with DAG( + dag_id="consume_1_and_2_with_asset_expressions", + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=(dag1_asset & dag2_asset), +) as dag5: + BashOperator( + outlets=[Asset("s3://consuming_2_task/asset_other_unknown.txt")], + task_id="consume_1_and_2_with_asset_expressions", + bash_command="sleep 5", + ) +with DAG( + dag_id="consume_1_or_2_with_asset_expressions", + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=(dag1_asset | dag2_asset), +) as dag6: + BashOperator( + outlets=[Asset("s3://consuming_2_task/asset_other_unknown.txt")], + task_id="consume_1_or_2_with_asset_expressions", + bash_command="sleep 5", + ) +with DAG( + dag_id="consume_1_or_both_2_and_3_with_asset_expressions", + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=(dag1_asset | (dag2_asset & dag3_asset)), +) as dag7: + BashOperator( + 
outlets=[Asset("s3://consuming_2_task/asset_other_unknown.txt")], + task_id="consume_1_or_both_2_and_3_with_asset_expressions", + bash_command="sleep 5", + ) +with DAG( + dag_id="conditional_asset_and_time_based_timetable", + catchup=False, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + schedule=AssetOrTimeSchedule( + timetable=CronTriggerTimetable("0 1 * * 3", timezone="UTC"), assets=(dag1_asset & dag2_asset) + ), + tags=["asset-time-based-timetable"], +) as dag8: + BashOperator( + outlets=[Asset("s3://asset_time_based/asset_other_unknown.txt")], + task_id="conditional_asset_and_time_based_timetable", + bash_command="sleep 5", + ) diff --git a/airflow/example_dags/example_dataset_alias.py b/airflow/example_dags/example_dataset_alias.py deleted file mode 100644 index c50a89e34fb8c..0000000000000 --- a/airflow/example_dags/example_dataset_alias.py +++ /dev/null @@ -1,101 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example DAG for demonstrating the behavior of the DatasetAlias feature in Airflow, including conditional and -dataset expression-based scheduling. - -Notes on usage: - -Turn on all the DAGs. - -Before running any DAG, the schedule of the "dataset_alias_example_alias_consumer" DAG will show as "Unresolved DatasetAlias". -This is expected because the dataset alias has not been resolved into any dataset yet. - -Once the "dataset_s3_bucket_producer" DAG is triggered, the "dataset_s3_bucket_consumer" DAG should be triggered upon completion. -This is because the dataset alias "example-alias" is used to add a dataset event to the dataset "s3://bucket/my-task" -during the "produce_dataset_events_through_dataset_alias" task. -As the DAG "dataset-alias-consumer" relies on dataset alias "example-alias" which was previously unresolved, -the DAG "dataset-alias-consumer" (along with all the DAGs in the same file) will be re-parsed and -thus update its schedule to the dataset "s3://bucket/my-task" and will also be triggered. 
-""" - -from __future__ import annotations - -import pendulum - -from airflow import DAG -from airflow.datasets import Dataset, DatasetAlias -from airflow.decorators import task - -with DAG( - dag_id="dataset_s3_bucket_producer", - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule=None, - catchup=False, - tags=["producer", "dataset"], -): - - @task(outlets=[Dataset("s3://bucket/my-task")]) - def produce_dataset_events(): - pass - - produce_dataset_events() - -with DAG( - dag_id="dataset_alias_example_alias_producer", - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule=None, - catchup=False, - tags=["producer", "dataset-alias"], -): - - @task(outlets=[DatasetAlias("example-alias")]) - def produce_dataset_events_through_dataset_alias(*, outlet_events=None): - bucket_name = "bucket" - object_path = "my-task" - outlet_events["example-alias"].add(Dataset(f"s3://{bucket_name}/{object_path}")) - - produce_dataset_events_through_dataset_alias() - -with DAG( - dag_id="dataset_s3_bucket_consumer", - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule=[Dataset("s3://bucket/my-task")], - catchup=False, - tags=["consumer", "dataset"], -): - - @task - def consume_dataset_event(): - pass - - consume_dataset_event() - -with DAG( - dag_id="dataset_alias_example_alias_consumer", - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule=[DatasetAlias("example-alias")], - catchup=False, - tags=["consumer", "dataset-alias"], -): - - @task(inlets=[DatasetAlias("example-alias")]) - def consume_dataset_event_from_dataset_alias(*, inlet_events=None): - for event in inlet_events[DatasetAlias("example-alias")]: - print(event) - - consume_dataset_event_from_dataset_alias() diff --git a/airflow/example_dags/example_dataset_alias_with_no_taskflow.py b/airflow/example_dags/example_dataset_alias_with_no_taskflow.py deleted file mode 100644 index 7d7227af39f50..0000000000000 --- a/airflow/example_dags/example_dataset_alias_with_no_taskflow.py +++ /dev/null @@ -1,108 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example DAG for demonstrating the behavior of the DatasetAlias feature in Airflow, including conditional and -dataset expression-based scheduling. - -Notes on usage: - -Turn on all the DAGs. - -Before running any DAG, the schedule of the "dataset_alias_example_alias_consumer_with_no_taskflow" DAG will show as "unresolved DatasetAlias". -This is expected because the dataset alias has not been resolved into any dataset yet. - -Once the "dataset_s3_bucket_producer_with_no_taskflow" DAG is triggered, the "dataset_s3_bucket_consumer_with_no_taskflow" DAG should be triggered upon completion. 
-This is because the dataset alias "example-alias-no-taskflow" is used to add a dataset event to the dataset "s3://bucket/my-task-with-no-taskflow" -during the "produce_dataset_events_through_dataset_alias_with_no_taskflow" task. Also, the schedule of the "dataset_alias_example_alias_consumer_with_no_taskflow" DAG should change to "Dataset" as -the dataset alias "example-alias-no-taskflow" is now resolved to the dataset "s3://bucket/my-task-with-no-taskflow" and this DAG should also be triggered. -""" - -from __future__ import annotations - -import pendulum - -from airflow import DAG -from airflow.datasets import Dataset, DatasetAlias -from airflow.operators.python import PythonOperator - -with DAG( - dag_id="dataset_s3_bucket_producer_with_no_taskflow", - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule=None, - catchup=False, - tags=["producer", "dataset"], -): - - def produce_dataset_events(): - pass - - PythonOperator( - task_id="produce_dataset_events", - outlets=[Dataset("s3://bucket/my-task-with-no-taskflow")], - python_callable=produce_dataset_events, - ) - - -with DAG( - dag_id="dataset_alias_example_alias_producer_with_no_taskflow", - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule=None, - catchup=False, - tags=["producer", "dataset-alias"], -): - - def produce_dataset_events_through_dataset_alias_with_no_taskflow(*, outlet_events=None): - bucket_name = "bucket" - object_path = "my-task" - outlet_events["example-alias-no-taskflow"].add(Dataset(f"s3://{bucket_name}/{object_path}")) - - PythonOperator( - task_id="produce_dataset_events_through_dataset_alias_with_no_taskflow", - outlets=[DatasetAlias("example-alias-no-taskflow")], - python_callable=produce_dataset_events_through_dataset_alias_with_no_taskflow, - ) - -with DAG( - dag_id="dataset_s3_bucket_consumer_with_no_taskflow", - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule=[Dataset("s3://bucket/my-task-with-no-taskflow")], - catchup=False, - tags=["consumer", "dataset"], -): - - def consume_dataset_event(): - pass - - PythonOperator(task_id="consume_dataset_event", python_callable=consume_dataset_event) - -with DAG( - dag_id="dataset_alias_example_alias_consumer_with_no_taskflow", - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule=[DatasetAlias("example-alias-no-taskflow")], - catchup=False, - tags=["consumer", "dataset-alias"], -): - - def consume_dataset_event_from_dataset_alias(*, inlet_events=None): - for event in inlet_events[DatasetAlias("example-alias-no-taskflow")]: - print(event) - - PythonOperator( - task_id="consume_dataset_event_from_dataset_alias", - python_callable=consume_dataset_event_from_dataset_alias, - inlets=[DatasetAlias("example-alias-no-taskflow")], - ) diff --git a/airflow/example_dags/example_datasets.py b/airflow/example_dags/example_datasets.py deleted file mode 100644 index 54f15d8a2d802..0000000000000 --- a/airflow/example_dags/example_datasets.py +++ /dev/null @@ -1,192 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example DAG for demonstrating the behavior of the Datasets feature in Airflow, including conditional and -dataset expression-based scheduling. - -Notes on usage: - -Turn on all the DAGs. - -dataset_produces_1 is scheduled to run daily. Once it completes, it triggers several DAGs due to its dataset -being updated. dataset_consumes_1 is triggered immediately, as it depends solely on the dataset produced by -dataset_produces_1. consume_1_or_2_with_dataset_expressions will also be triggered, as its condition of -either dataset_produces_1 or dataset_produces_2 being updated is satisfied with dataset_produces_1. - -dataset_consumes_1_and_2 will not be triggered after dataset_produces_1 runs because it requires the dataset -from dataset_produces_2, which has no schedule and must be manually triggered. - -After manually triggering dataset_produces_2, several DAGs will be affected. dataset_consumes_1_and_2 should -run because both its dataset dependencies are now met. consume_1_and_2_with_dataset_expressions will be -triggered, as it requires both dataset_produces_1 and dataset_produces_2 datasets to be updated. -consume_1_or_2_with_dataset_expressions will be triggered again, since it's conditionally set to run when -either dataset is updated. - -consume_1_or_both_2_and_3_with_dataset_expressions demonstrates complex dataset dependency logic. -This DAG triggers if dataset_produces_1 is updated or if both dataset_produces_2 and dag3_dataset -are updated. This example highlights the capability to combine updates from multiple datasets with logical -expressions for advanced scheduling. - -conditional_dataset_and_time_based_timetable illustrates the integration of time-based scheduling with -dataset dependencies. This DAG is configured to execute either when both dataset_produces_1 and -dataset_produces_2 datasets have been updated or according to a specific cron schedule, showcasing -Airflow's versatility in handling mixed triggers for dataset and time-based scheduling. - -The DAGs dataset_consumes_1_never_scheduled and dataset_consumes_unknown_never_scheduled will not run -automatically as they depend on datasets that do not get updated or are not produced by any scheduled tasks. 
-""" - -from __future__ import annotations - -import pendulum - -from airflow.datasets import Dataset -from airflow.models.dag import DAG -from airflow.operators.bash import BashOperator -from airflow.timetables.datasets import DatasetOrTimeSchedule -from airflow.timetables.trigger import CronTriggerTimetable - -# [START dataset_def] -dag1_dataset = Dataset("s3://dag1/output_1.txt", extra={"hi": "bye"}) -# [END dataset_def] -dag2_dataset = Dataset("s3://dag2/output_1.txt", extra={"hi": "bye"}) -dag3_dataset = Dataset("s3://dag3/output_3.txt", extra={"hi": "bye"}) - -with DAG( - dag_id="dataset_produces_1", - catchup=False, - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule="@daily", - tags=["produces", "dataset-scheduled"], -) as dag1: - # [START task_outlet] - BashOperator(outlets=[dag1_dataset], task_id="producing_task_1", bash_command="sleep 5") - # [END task_outlet] - -with DAG( - dag_id="dataset_produces_2", - catchup=False, - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule=None, - tags=["produces", "dataset-scheduled"], -) as dag2: - BashOperator(outlets=[dag2_dataset], task_id="producing_task_2", bash_command="sleep 5") - -# [START dag_dep] -with DAG( - dag_id="dataset_consumes_1", - catchup=False, - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule=[dag1_dataset], - tags=["consumes", "dataset-scheduled"], -) as dag3: - # [END dag_dep] - BashOperator( - outlets=[Dataset("s3://consuming_1_task/dataset_other.txt")], - task_id="consuming_1", - bash_command="sleep 5", - ) - -with DAG( - dag_id="dataset_consumes_1_and_2", - catchup=False, - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule=[dag1_dataset, dag2_dataset], - tags=["consumes", "dataset-scheduled"], -) as dag4: - BashOperator( - outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")], - task_id="consuming_2", - bash_command="sleep 5", - ) - -with DAG( - dag_id="dataset_consumes_1_never_scheduled", - catchup=False, - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule=[ - dag1_dataset, - Dataset("s3://unrelated/this-dataset-doesnt-get-triggered"), - ], - tags=["consumes", "dataset-scheduled"], -) as dag5: - BashOperator( - outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")], - task_id="consuming_3", - bash_command="sleep 5", - ) - -with DAG( - dag_id="dataset_consumes_unknown_never_scheduled", - catchup=False, - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule=[ - Dataset("s3://unrelated/dataset3.txt"), - Dataset("s3://unrelated/dataset_other_unknown.txt"), - ], - tags=["dataset-scheduled"], -) as dag6: - BashOperator( - task_id="unrelated_task", - outlets=[Dataset("s3://unrelated_task/dataset_other_unknown.txt")], - bash_command="sleep 5", - ) - -with DAG( - dag_id="consume_1_and_2_with_dataset_expressions", - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule=(dag1_dataset & dag2_dataset), -) as dag5: - BashOperator( - outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")], - task_id="consume_1_and_2_with_dataset_expressions", - bash_command="sleep 5", - ) -with DAG( - dag_id="consume_1_or_2_with_dataset_expressions", - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule=(dag1_dataset | dag2_dataset), -) as dag6: - BashOperator( - outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")], - task_id="consume_1_or_2_with_dataset_expressions", - bash_command="sleep 5", - ) -with DAG( - dag_id="consume_1_or_both_2_and_3_with_dataset_expressions", - 
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule=(dag1_dataset | (dag2_dataset & dag3_dataset)), -) as dag7: - BashOperator( - outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")], - task_id="consume_1_or_both_2_and_3_with_dataset_expressions", - bash_command="sleep 5", - ) -with DAG( - dag_id="conditional_dataset_and_time_based_timetable", - catchup=False, - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - schedule=DatasetOrTimeSchedule( - timetable=CronTriggerTimetable("0 1 * * 3", timezone="UTC"), datasets=(dag1_dataset & dag2_dataset) - ), - tags=["dataset-time-based-timetable"], -) as dag8: - BashOperator( - outlets=[Dataset("s3://dataset_time_based/dataset_other_unknown.txt")], - task_id="conditional_dataset_and_time_based_timetable", - bash_command="sleep 5", - ) diff --git a/airflow/example_dags/example_inlet_event_extra.py b/airflow/example_dags/example_inlet_event_extra.py index 4b7567fc2f87e..974534c295b79 100644 --- a/airflow/example_dags/example_inlet_event_extra.py +++ b/airflow/example_dags/example_inlet_event_extra.py @@ -16,7 +16,7 @@ # under the License. """ -Example DAG to demonstrate reading dataset events annotated with extra information. +Example DAG to demonstrate reading asset events annotated with extra information. Also see examples in ``example_outlet_event_extra.py``. """ @@ -25,37 +25,37 @@ import datetime -from airflow.datasets import Dataset +from airflow.assets import Asset from airflow.decorators import task from airflow.models.dag import DAG from airflow.operators.bash import BashOperator -ds = Dataset("s3://output/1.txt") +asset = Asset("s3://output/1.txt") with DAG( - dag_id="read_dataset_event", + dag_id="read_asset_event", catchup=False, start_date=datetime.datetime.min, schedule="@daily", tags=["consumes"], ): - @task(inlets=[ds]) - def read_dataset_event(*, inlet_events=None): - for event in inlet_events[ds][:-2]: + @task(inlets=[asset]) + def read_asset_event(*, inlet_events=None): + for event in inlet_events[asset][:-2]: print(event.extra["hi"]) - read_dataset_event() + read_asset_event() with DAG( - dag_id="read_dataset_event_from_classic", + dag_id="read_asset_event_from_classic", catchup=False, start_date=datetime.datetime.min, schedule="@daily", tags=["consumes"], ): BashOperator( - task_id="read_dataset_event_from_classic", - inlets=[ds], + task_id="read_asset_event_from_classic", + inlets=[asset], bash_command="echo '{{ inlet_events['s3://output/1.txt'][-1].extra | tojson }}'", ) diff --git a/airflow/example_dags/example_outlet_event_extra.py b/airflow/example_dags/example_outlet_event_extra.py index 5f7d986e90fdf..893090460b538 100644 --- a/airflow/example_dags/example_outlet_event_extra.py +++ b/airflow/example_dags/example_outlet_event_extra.py @@ -16,7 +16,7 @@ # under the License. """ -Example DAG to demonstrate annotating a dataset event with extra information. +Example DAG to demonstrate annotating an asset event with extra information. Also see examples in ``example_inlet_event_extra.py``. 
""" @@ -25,16 +25,16 @@ import datetime -from airflow.datasets import Dataset -from airflow.datasets.metadata import Metadata +from airflow.assets import Asset +from airflow.assets.metadata import Metadata from airflow.decorators import task from airflow.models.dag import DAG from airflow.operators.bash import BashOperator -ds = Dataset("s3://output/1.txt") +ds = Asset("s3://output/1.txt") with DAG( - dag_id="dataset_with_extra_by_yield", + dag_id="asset_with_extra_by_yield", catchup=False, start_date=datetime.datetime.min, schedule="@daily", @@ -42,13 +42,13 @@ ): @task(outlets=[ds]) - def dataset_with_extra_by_yield(): + def asset_with_extra_by_yield(): yield Metadata(ds, {"hi": "bye"}) - dataset_with_extra_by_yield() + asset_with_extra_by_yield() with DAG( - dag_id="dataset_with_extra_by_context", + dag_id="asset_with_extra_by_context", catchup=False, start_date=datetime.datetime.min, schedule="@daily", @@ -56,25 +56,25 @@ def dataset_with_extra_by_yield(): ): @task(outlets=[ds]) - def dataset_with_extra_by_context(*, outlet_events=None): + def asset_with_extra_by_context(*, outlet_events=None): outlet_events[ds].extra = {"hi": "bye"} - dataset_with_extra_by_context() + asset_with_extra_by_context() with DAG( - dag_id="dataset_with_extra_from_classic_operator", + dag_id="asset_with_extra_from_classic_operator", catchup=False, start_date=datetime.datetime.min, schedule="@daily", tags=["produces"], ): - def _dataset_with_extra_from_classic_operator_post_execute(context, result): + def _asset_with_extra_from_classic_operator_post_execute(context, result): context["outlet_events"][ds].extra = {"hi": "bye"} BashOperator( - task_id="dataset_with_extra_from_classic_operator", + task_id="asset_with_extra_from_classic_operator", outlets=[ds], bash_command=":", - post_execute=_dataset_with_extra_from_classic_operator_post_execute, + post_execute=_asset_with_extra_from_classic_operator_post_execute, ) diff --git a/airflow/io/path.py b/airflow/io/path.py index 6deafae004959..3526050d12883 100644 --- a/airflow/io/path.py +++ b/airflow/io/path.py @@ -56,9 +56,9 @@ def __getattr__(self, name): def wrapper(*args, **kwargs): self.log.debug("Calling method: %s", name) if name == "read": - get_hook_lineage_collector().add_input_dataset(context=self._path, uri=str(self._path)) + get_hook_lineage_collector().add_input_asset(context=self._path, uri=str(self._path)) elif name == "write": - get_hook_lineage_collector().add_output_dataset(context=self._path, uri=str(self._path)) + get_hook_lineage_collector().add_output_asset(context=self._path, uri=str(self._path)) result = attr(*args, **kwargs) return result @@ -316,8 +316,8 @@ def copy(self, dst: str | ObjectStoragePath, recursive: bool = False, **kwargs) if self.samestore(dst) or self.protocol == "file" or dst.protocol == "file": # only emit this in "optimized" variants - else lineage will be captured by file writes/reads - get_hook_lineage_collector().add_input_dataset(context=self, uri=str(self)) - get_hook_lineage_collector().add_output_dataset(context=dst, uri=str(dst)) + get_hook_lineage_collector().add_input_asset(context=self, uri=str(self)) + get_hook_lineage_collector().add_output_asset(context=dst, uri=str(dst)) # same -> same if self.samestore(dst): @@ -381,8 +381,8 @@ def move(self, path: str | ObjectStoragePath, recursive: bool = False, **kwargs) path = ObjectStoragePath(path) if self.samestore(path): - get_hook_lineage_collector().add_input_dataset(context=self, uri=str(self)) - get_hook_lineage_collector().add_output_dataset(context=path, 
uri=str(path)) + get_hook_lineage_collector().add_input_asset(context=self, uri=str(self)) + get_hook_lineage_collector().add_output_asset(context=path, uri=str(path)) return self.fs.move(self.path, path.path, recursive=recursive, **kwargs) # non-local copy diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index 9438edd4d9187..242154820df9e 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -44,21 +44,21 @@ from airflow.jobs.base_job_runner import BaseJobRunner from airflow.jobs.job import Job, perform_heartbeat from airflow.models import Log +from airflow.models.asset import ( + AssetDagRunQueue, + AssetEvent, + AssetModel, + DagScheduleAssetReference, + TaskOutletAssetReference, +) from airflow.models.dag import DAG, DagModel from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun -from airflow.models.dataset import ( - DagScheduleDatasetReference, - DatasetDagRunQueue, - DatasetEvent, - DatasetModel, - TaskOutletDatasetReference, -) from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance from airflow.stats import Stats from airflow.ti_deps.dependencies_states import EXECUTION_STATES -from airflow.timetables.simple import DatasetTriggeredTimetable +from airflow.timetables.simple import AssetTriggeredTimetable from airflow.traces import utils as trace_utils from airflow.traces.tracer import Trace, add_span from airflow.utils import timezone @@ -1086,7 +1086,7 @@ def _run_scheduler_loop(self) -> None: timers.call_regular_interval( conf.getfloat("scheduler", "parsing_cleanup_interval"), - self._orphan_unreferenced_datasets, + self._orphan_unreferenced_assets, ) if self._standalone_dag_processor: @@ -1286,9 +1286,7 @@ def _create_dagruns_for_dags(self, guard: CommitProhibitorGuard, session: Sessio non_dataset_dags = all_dags_needing_dag_runs.difference(dataset_triggered_dags) self._create_dag_runs(non_dataset_dags, session) if dataset_triggered_dags: - self._create_dag_runs_dataset_triggered( - dataset_triggered_dags, dataset_triggered_dag_info, session - ) + self._create_dag_runs_asset_triggered(dataset_triggered_dags, dataset_triggered_dag_info, session) # commit the session - Release the write lock on DagModel table. guard.commit() @@ -1367,13 +1365,13 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) - # TODO[HA]: Should we do a session.flush() so we don't have to keep lots of state/object in # memory for larger dags? 
or expunge_all() - def _create_dag_runs_dataset_triggered( + def _create_dag_runs_asset_triggered( self, dag_models: Collection[DagModel], dataset_triggered_dag_info: dict[str, tuple[datetime, datetime]], session: Session, ) -> None: - """For DAGs that are triggered by datasets, create dag runs.""" + """For DAGs that are triggered by assets, create dag runs.""" # Bulk Fetch DagRuns with dag_id and execution_date same # as DagModel.dag_id and DagModel.next_dagrun # This list is used to verify if the DagRun already exist so that we don't attempt to create @@ -1396,9 +1394,9 @@ def _create_dag_runs_dataset_triggered( self.log.error("DAG '%s' not found in serialized_dag table", dag_model.dag_id) continue - if not isinstance(dag.timetable, DatasetTriggeredTimetable): + if not isinstance(dag.timetable, AssetTriggeredTimetable): self.log.error( - "DAG '%s' was dataset-scheduled, but didn't have a DatasetTriggeredTimetable!", + "DAG '%s' was asset-scheduled, but didn't have a AssetTriggeredTimetable!", dag_model.dag_id, ) continue @@ -1425,29 +1423,29 @@ def _create_dag_runs_dataset_triggered( .order_by(DagRun.execution_date.desc()) .limit(1) ) - dataset_event_filters = [ - DagScheduleDatasetReference.dag_id == dag.dag_id, - DatasetEvent.timestamp <= exec_date, + asset_event_filters = [ + DagScheduleAssetReference.dag_id == dag.dag_id, + AssetEvent.timestamp <= exec_date, ] if previous_dag_run: - dataset_event_filters.append(DatasetEvent.timestamp > previous_dag_run.execution_date) + asset_event_filters.append(AssetEvent.timestamp > previous_dag_run.execution_date) - dataset_events = session.scalars( - select(DatasetEvent) + asset_events = session.scalars( + select(AssetEvent) .join( - DagScheduleDatasetReference, - DatasetEvent.dataset_id == DagScheduleDatasetReference.dataset_id, + DagScheduleAssetReference, + AssetEvent.dataset_id == DagScheduleAssetReference.dataset_id, ) - .where(*dataset_event_filters) + .where(*asset_event_filters) ).all() - data_interval = dag.timetable.data_interval_for_events(exec_date, dataset_events) + data_interval = dag.timetable.data_interval_for_events(exec_date, asset_events) run_id = dag.timetable.generate_run_id( run_type=DagRunType.DATASET_TRIGGERED, logical_date=exec_date, data_interval=data_interval, session=session, - events=dataset_events, + events=asset_events, ) dag_run = dag.create_dagrun( @@ -1462,10 +1460,10 @@ def _create_dag_runs_dataset_triggered( creating_job_id=self.job.id, triggered_by=DagRunTriggeredByType.DATASET, ) - Stats.incr("dataset.triggered_dagruns") - dag_run.consumed_dataset_events.extend(dataset_events) + Stats.incr("asset.triggered_dagruns") + dag_run.consumed_dataset_events.extend(asset_events) session.execute( - delete(DatasetDagRunQueue).where(DatasetDagRunQueue.target_dag_id == dag_run.dag_id) + delete(AssetDagRunQueue).where(AssetDagRunQueue.target_dag_id == dag_run.dag_id) ) def _should_update_dag_next_dagruns( @@ -2014,40 +2012,40 @@ def _cleanup_stale_dags(self, session: Session = NEW_SESSION) -> None: SerializedDagModel.remove_dag(dag_id=dag.dag_id, session=session) session.flush() - def _set_orphaned(self, dataset: DatasetModel) -> int: - self.log.info("Orphaning unreferenced dataset '%s'", dataset.uri) - dataset.is_orphaned = expression.true() + def _set_orphaned(self, asset: AssetModel) -> int: + self.log.info("Orphaning unreferenced asset '%s'", asset.uri) + asset.is_orphaned = expression.true() return 1 @provide_session - def _orphan_unreferenced_datasets(self, session: Session = NEW_SESSION) -> None: + def 
_orphan_unreferenced_assets(self, session: Session = NEW_SESSION) -> None: """ - Detect orphaned datasets and set is_orphaned flag to True. + Detect orphaned assets and set is_orphaned flag to True. - An orphaned dataset is no longer referenced in any DAG schedule parameters or task outlets. + An orphaned asset is no longer referenced in any DAG schedule parameters or task outlets. """ - orphaned_dataset_query = session.scalars( - select(DatasetModel) + orphaned_asset_query = session.scalars( + select(AssetModel) .join( - DagScheduleDatasetReference, + DagScheduleAssetReference, isouter=True, ) .join( - TaskOutletDatasetReference, + TaskOutletAssetReference, isouter=True, ) - .group_by(DatasetModel.id) - .where(~DatasetModel.is_orphaned) + .group_by(AssetModel.id) + .where(~AssetModel.is_orphaned) .having( and_( - func.count(DagScheduleDatasetReference.dag_id) == 0, - func.count(TaskOutletDatasetReference.dag_id) == 0, + func.count(DagScheduleAssetReference.dag_id) == 0, + func.count(TaskOutletAssetReference.dag_id) == 0, ) ) ) - updated_count = sum(self._set_orphaned(dataset) for dataset in orphaned_dataset_query) - Stats.gauge("dataset.orphaned", updated_count) + updated_count = sum(self._set_orphaned(asset) for asset in orphaned_asset_query) + Stats.gauge("asset.orphaned", updated_count) def _executor_to_tis(self, tis: list[TaskInstance]) -> dict[BaseExecutor, list[TaskInstance]]: """Organize TIs into lists per their respective executor.""" diff --git a/airflow/lineage/__init__.py b/airflow/lineage/__init__.py index 332a04e7250bf..4385f3fbaf586 100644 --- a/airflow/lineage/__init__.py +++ b/airflow/lineage/__init__.py @@ -104,7 +104,7 @@ def prepare_lineage(func: T) -> T: * "auto" -> picks up any outlets from direct upstream tasks that have outlets defined, as such that if A -> B -> C and B does not have outlets but A does, these are provided as inlets. * "list of task_ids" -> picks up outlets from the upstream task_ids - * "list of datasets" -> manually defined list of data + * "list of datasets" -> manually defined list of dataset """ diff --git a/airflow/lineage/hook.py b/airflow/lineage/hook.py index 4ff35e4d9ce82..fd321bcab49cf 100644 --- a/airflow/lineage/hook.py +++ b/airflow/lineage/hook.py @@ -24,7 +24,7 @@ import attr -from airflow.datasets import Dataset +from airflow.assets import Asset from airflow.providers_manager import ProvidersManager from airflow.utils.log.logging_mixin import LoggingMixin @@ -39,15 +39,15 @@ @attr.define -class DatasetLineageInfo: +class AssetLineageInfo: """ - Holds lineage information for a single dataset. + Holds lineage information for a single asset. - This class represents the lineage information for a single dataset, including the dataset itself, + This class represents the lineage information for a single asset, including the asset itself, the count of how many times it has been encountered, and the context in which it was encountered. """ - dataset: Dataset + asset: Asset count: int context: LineageContext @@ -58,133 +58,129 @@ class HookLineage: Holds lineage collected by HookLineageCollector. This class represents the lineage information collected by the `HookLineageCollector`. It stores - the input and output datasets, each with an associated count indicating how many times the dataset + the input and output assets, each with an associated count indicating how many times the asset has been encountered during the hook execution. 
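To show how these structures are typically populated, here is a rough hook-side sketch. It assumes ``get_hook_lineage_collector`` is exported from ``airflow.lineage.hook`` (it is used that way in ``airflow/io/path.py`` above); the hook class and URIs are made up:

from airflow.hooks.base import BaseHook
from airflow.lineage.hook import get_hook_lineage_collector


class CopyHook(BaseHook):
    """Toy hook that only reports lineage; the actual copy is elided."""

    def copy_object(self, src_uri: str, dst_uri: str) -> None:
        collector = get_hook_lineage_collector()
        # The hook instance serves as the LineageContext, mirroring how
        # ObjectStoragePath passes itself in airflow/io/path.py.
        collector.add_input_asset(context=self, uri=src_uri)
        collector.add_output_asset(context=self, uri=dst_uri)


CopyHook().copy_object("s3://src-bucket/a.csv", "s3://dst-bucket/a.csv")

# collected_assets returns HookLineage(inputs=[AssetLineageInfo, ...], outputs=[...]).
for info in get_hook_lineage_collector().collected_assets.inputs:
    print(info.asset.uri, info.count)

When hook lineage collection is disabled, the returned collector is the ``NoOpCollector`` defined below, so these calls simply do nothing.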
""" - inputs: list[DatasetLineageInfo] = attr.ib(factory=list) - outputs: list[DatasetLineageInfo] = attr.ib(factory=list) + inputs: list[AssetLineageInfo] = attr.ib(factory=list) + outputs: list[AssetLineageInfo] = attr.ib(factory=list) class HookLineageCollector(LoggingMixin): """ HookLineageCollector is a base class for collecting hook lineage information. - It is used to collect the input and output datasets of a hook execution. + It is used to collect the input and output assets of a hook execution. """ def __init__(self, **kwargs): super().__init__(**kwargs) - # Dictionary to store input datasets, counted by unique key (dataset URI, MD5 hash of extra + # Dictionary to store input assets, counted by unique key (asset URI, MD5 hash of extra # dictionary, and LineageContext's unique identifier) - self._inputs: dict[str, tuple[Dataset, LineageContext]] = {} - self._outputs: dict[str, tuple[Dataset, LineageContext]] = {} + self._inputs: dict[str, tuple[Asset, LineageContext]] = {} + self._outputs: dict[str, tuple[Asset, LineageContext]] = {} self._input_counts: dict[str, int] = defaultdict(int) self._output_counts: dict[str, int] = defaultdict(int) - def _generate_key(self, dataset: Dataset, context: LineageContext) -> str: + def _generate_key(self, asset: Asset, context: LineageContext) -> str: """ - Generate a unique key for the given dataset and context. + Generate a unique key for the given asset and context. - This method creates a unique key by combining the dataset URI, the MD5 hash of the dataset's extra + This method creates a unique key by combining the asset URI, the MD5 hash of the asset's extra dictionary, and the LineageContext's unique identifier. This ensures that the generated key is - unique for each combination of dataset and context. + unique for each combination of asset and context. """ - extra_str = json.dumps(dataset.extra, sort_keys=True) + extra_str = json.dumps(asset.extra, sort_keys=True) extra_hash = hashlib.md5(extra_str.encode()).hexdigest() - return f"{dataset.uri}_{extra_hash}_{id(context)}" + return f"{asset.uri}_{extra_hash}_{id(context)}" - def create_dataset( - self, scheme: str | None, uri: str | None, dataset_kwargs: dict | None, dataset_extra: dict | None - ) -> Dataset | None: + def create_asset( + self, scheme: str | None, uri: str | None, asset_kwargs: dict | None, asset_extra: dict | None + ) -> Asset | None: """ - Create a Dataset instance using the provided parameters. + Create an asset instance using the provided parameters. - This method attempts to create a Dataset instance using the given parameters. - It first checks if a URI is provided and falls back to using the default dataset factory + This method attempts to create an asset instance using the given parameters. + It first checks if a URI is provided and falls back to using the default asset factory with the given URI if no other information is available. - If a scheme is provided but no URI, it attempts to find a dataset factory that matches + If a scheme is provided but no URI, it attempts to find an asset factory that matches the given scheme. If no such factory is found, it logs an error message and returns None. - If dataset_kwargs is provided, it is used to pass additional parameters to the Dataset - factory. The dataset_extra parameter is also passed to the factory as an ``extra`` parameter. + If asset_kwargs is provided, it is used to pass additional parameters to the asset + factory. The asset_extra parameter is also passed to the factory as an ``extra`` parameter. 
""" if uri: # Fallback to default factory using the provided URI - return Dataset(uri=uri, extra=dataset_extra) + return Asset(uri=uri, extra=asset_extra) if not scheme: self.log.debug( - "Missing required parameter: either 'uri' or 'scheme' must be provided to create a Dataset." + "Missing required parameter: either 'uri' or 'scheme' must be provided to create an asset." ) return None - dataset_factory = ProvidersManager().dataset_factories.get(scheme) - if not dataset_factory: - self.log.debug("Unsupported scheme: %s. Please provide a valid URI to create a Dataset.", scheme) + asset_factory = ProvidersManager().asset_factories.get(scheme) + if not asset_factory: + self.log.debug("Unsupported scheme: %s. Please provide a valid URI to create an asset.", scheme) return None - dataset_kwargs = dataset_kwargs or {} + asset_kwargs = asset_kwargs or {} try: - return dataset_factory(**dataset_kwargs, extra=dataset_extra) + return asset_factory(**asset_kwargs, extra=asset_extra) except Exception as e: - self.log.debug("Failed to create dataset. Skipping. Error: %s", e) + self.log.debug("Failed to create asset. Skipping. Error: %s", e) return None - def add_input_dataset( + def add_input_asset( self, context: LineageContext, scheme: str | None = None, uri: str | None = None, - dataset_kwargs: dict | None = None, - dataset_extra: dict | None = None, + asset_kwargs: dict | None = None, + asset_extra: dict | None = None, ): - """Add the input dataset and its corresponding hook execution context to the collector.""" - dataset = self.create_dataset( - scheme=scheme, uri=uri, dataset_kwargs=dataset_kwargs, dataset_extra=dataset_extra - ) - if dataset: - key = self._generate_key(dataset, context) + """Add the input asset and its corresponding hook execution context to the collector.""" + asset = self.create_asset(scheme=scheme, uri=uri, asset_kwargs=asset_kwargs, asset_extra=asset_extra) + if asset: + key = self._generate_key(asset, context) if key not in self._inputs: - self._inputs[key] = (dataset, context) + self._inputs[key] = (asset, context) self._input_counts[key] += 1 - def add_output_dataset( + def add_output_asset( self, context: LineageContext, scheme: str | None = None, uri: str | None = None, - dataset_kwargs: dict | None = None, - dataset_extra: dict | None = None, + asset_kwargs: dict | None = None, + asset_extra: dict | None = None, ): - """Add the output dataset and its corresponding hook execution context to the collector.""" - dataset = self.create_dataset( - scheme=scheme, uri=uri, dataset_kwargs=dataset_kwargs, dataset_extra=dataset_extra - ) - if dataset: - key = self._generate_key(dataset, context) + """Add the output asset and its corresponding hook execution context to the collector.""" + asset = self.create_asset(scheme=scheme, uri=uri, asset_kwargs=asset_kwargs, asset_extra=asset_extra) + if asset: + key = self._generate_key(asset, context) if key not in self._outputs: - self._outputs[key] = (dataset, context) + self._outputs[key] = (asset, context) self._output_counts[key] += 1 @property - def collected_datasets(self) -> HookLineage: + def collected_assets(self) -> HookLineage: """Get the collected hook lineage information.""" return HookLineage( [ - DatasetLineageInfo(dataset=dataset, count=self._input_counts[key], context=context) - for key, (dataset, context) in self._inputs.items() + AssetLineageInfo(asset=asset, count=self._input_counts[key], context=context) + for key, (asset, context) in self._inputs.items() ], [ - DatasetLineageInfo(dataset=dataset, 
count=self._output_counts[key], context=context) - for key, (dataset, context) in self._outputs.items() + AssetLineageInfo(asset=asset, count=self._output_counts[key], context=context) + for key, (asset, context) in self._outputs.items() ], ) @property def has_collected(self) -> bool: - """Check if any datasets have been collected.""" + """Check if any assets have been collected.""" return len(self._inputs) != 0 or len(self._outputs) != 0 @@ -195,14 +191,14 @@ class NoOpCollector(HookLineageCollector): It is used when you want to disable lineage collection. """ - def add_input_dataset(self, *_, **__): + def add_input_asset(self, *_, **__): pass - def add_output_dataset(self, *_, **__): + def add_output_asset(self, *_, **__): pass @property - def collected_datasets( + def collected_assets( self, ) -> HookLineage: self.log.warning( @@ -219,7 +215,7 @@ def __init__(self, **kwargs): def retrieve_hook_lineage(self) -> HookLineage: """Retrieve hook lineage from HookLineageCollector.""" - hook_lineage = self.lineage_collector.collected_datasets + hook_lineage = self.lineage_collector.collected_assets return hook_lineage diff --git a/airflow/listeners/listener.py b/airflow/listeners/listener.py index 57d0360487bbf..5e8fba55d4395 100644 --- a/airflow/listeners/listener.py +++ b/airflow/listeners/listener.py @@ -46,13 +46,13 @@ class ListenerManager: """Manage listener registration and provides hook property for calling them.""" def __init__(self): - from airflow.listeners.spec import dagrun, dataset, importerrors, lifecycle, taskinstance + from airflow.listeners.spec import asset, dagrun, importerrors, lifecycle, taskinstance self.pm = pluggy.PluginManager("airflow") self.pm.add_hookcall_monitoring(_before_hookcall, _after_hookcall) self.pm.add_hookspecs(lifecycle) self.pm.add_hookspecs(dagrun) - self.pm.add_hookspecs(dataset) + self.pm.add_hookspecs(asset) self.pm.add_hookspecs(taskinstance) self.pm.add_hookspecs(importerrors) diff --git a/airflow/listeners/spec/dataset.py b/airflow/listeners/spec/asset.py similarity index 76% rename from airflow/listeners/spec/dataset.py rename to airflow/listeners/spec/asset.py index eee1a10dd7d89..78b14c8b10aeb 100644 --- a/airflow/listeners/spec/dataset.py +++ b/airflow/listeners/spec/asset.py @@ -22,27 +22,21 @@ from pluggy import HookspecMarker if TYPE_CHECKING: - from airflow.datasets import Dataset, DatasetAlias + from airflow.assets import Asset, AssetAlias hookspec = HookspecMarker("airflow") @hookspec -def on_dataset_created( - dataset: Dataset, -): - """Execute when a new dataset is created.""" +def on_asset_created(asset: Asset): + """Execute when a new asset is created.""" @hookspec -def on_dataset_alias_created( - dataset_alias: DatasetAlias, -): +def on_asset_alias_created(dataset_alias: AssetAlias): """Execute when a new dataset alias is created.""" @hookspec -def on_dataset_changed( - dataset: Dataset, -): - """Execute when dataset change is registered.""" +def on_asset_changed(asset: Asset): + """Execute when asset change is registered.""" diff --git a/airflow/models/__init__.py b/airflow/models/__init__.py index 7bf23e1bbb7d1..375761bc20f52 100644 --- a/airflow/models/__init__.py +++ b/airflow/models/__init__.py @@ -58,9 +58,9 @@ def import_all_models(): for name in __lazy_imports: __getattr__(name) + import airflow.models.asset import airflow.models.backfill import airflow.models.dagwarning - import airflow.models.dataset import airflow.models.errors import airflow.models.serialized_dag import airflow.models.taskinstancehistory diff --git 
a/airflow/models/dataset.py b/airflow/models/asset.py similarity index 83% rename from airflow/models/dataset.py rename to airflow/models/asset.py index 489d6b68a6f15..b99aa86f2c889 100644 --- a/airflow/models/dataset.py +++ b/airflow/models/asset.py @@ -34,7 +34,7 @@ ) from sqlalchemy.orm import relationship -from airflow.datasets import Dataset, DatasetAlias +from airflow.assets import Asset, AssetAlias from airflow.models.base import Base, StringID from airflow.settings import json from airflow.utils import timezone @@ -83,11 +83,11 @@ ) -class DatasetAliasModel(Base): +class AssetAliasModel(Base): """ - A table to store dataset alias. + A table to store asset alias. - :param uri: a string that uniquely identifies the dataset alias + :param uri: a string that uniquely identifies the asset alias """ id = Column(Integer, primary_key=True, autoincrement=True) @@ -111,19 +111,19 @@ class DatasetAliasModel(Base): ) datasets = relationship( - "DatasetModel", + "AssetModel", secondary=alias_association_table, backref="aliases", ) dataset_events = relationship( - "DatasetEvent", + "AssetEvent", secondary=dataset_alias_dataset_event_assocation_table, back_populates="source_aliases", ) - consuming_dags = relationship("DagScheduleDatasetAliasReference", back_populates="dataset_alias") + consuming_dags = relationship("DagScheduleAssetAliasReference", back_populates="dataset_alias") @classmethod - def from_public(cls, obj: DatasetAlias) -> DatasetAliasModel: + def from_public(cls, obj: AssetAlias) -> AssetAliasModel: return cls(name=obj.name) def __repr__(self): @@ -133,20 +133,20 @@ def __hash__(self): return hash(self.name) def __eq__(self, other): - if isinstance(other, (self.__class__, DatasetAlias)): + if isinstance(other, (self.__class__, AssetAlias)): return self.name == other.name else: return NotImplemented - def to_public(self) -> DatasetAlias: - return DatasetAlias(name=self.name) + def to_public(self) -> AssetAlias: + return AssetAlias(name=self.name) -class DatasetModel(Base): +class AssetModel(Base): """ - A table to store datasets. + A table to store assets. 
- :param uri: a string that uniquely identifies the dataset + :param uri: a string that uniquely identifies the asset :param extra: JSON field for arbitrary extra info """ @@ -168,8 +168,8 @@ class DatasetModel(Base): updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False) is_orphaned = Column(Boolean, default=False, nullable=False, server_default="0") - consuming_dags = relationship("DagScheduleDatasetReference", back_populates="dataset") - producing_tasks = relationship("TaskOutletDatasetReference", back_populates="dataset") + consuming_dags = relationship("DagScheduleAssetReference", back_populates="dataset") + producing_tasks = relationship("TaskOutletAssetReference", back_populates="dataset") __tablename__ = "dataset" __table_args__ = ( @@ -178,7 +178,7 @@ class DatasetModel(Base): ) @classmethod - def from_public(cls, obj: Dataset) -> DatasetModel: + def from_public(cls, obj: Asset) -> AssetModel: return cls(uri=obj.uri, extra=obj.extra) def __init__(self, uri: str, **kwargs): @@ -192,7 +192,7 @@ def __init__(self, uri: str, **kwargs): super().__init__(uri=uri, **kwargs) def __eq__(self, other): - if isinstance(other, (self.__class__, Dataset)): + if isinstance(other, (self.__class__, Asset)): return self.uri == other.uri else: return NotImplemented @@ -203,19 +203,19 @@ def __hash__(self): def __repr__(self): return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})" - def to_public(self) -> Dataset: - return Dataset(uri=self.uri, extra=self.extra) + def to_public(self) -> Asset: + return Asset(uri=self.uri, extra=self.extra) -class DagScheduleDatasetAliasReference(Base): - """References from a DAG to a dataset alias of which it is a consumer.""" +class DagScheduleAssetAliasReference(Base): + """References from a DAG to an asset alias of which it is a consumer.""" alias_id = Column(Integer, primary_key=True, nullable=False) dag_id = Column(StringID(), primary_key=True, nullable=False) created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False) - dataset_alias = relationship("DatasetAliasModel", back_populates="consuming_dags") + dataset_alias = relationship("AssetAliasModel", back_populates="consuming_dags") dag = relationship("DagModel", back_populates="schedule_dataset_alias_references") __tablename__ = "dag_schedule_dataset_alias_reference" @@ -251,22 +251,22 @@ def __repr__(self): return f"{self.__class__.__name__}({', '.join(args)})" -class DagScheduleDatasetReference(Base): - """References from a DAG to a dataset of which it is a consumer.""" +class DagScheduleAssetReference(Base): + """References from a DAG to an asset of which it is a consumer.""" dataset_id = Column(Integer, primary_key=True, nullable=False) dag_id = Column(StringID(), primary_key=True, nullable=False) created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False) - dataset = relationship("DatasetModel", back_populates="consuming_dags") + dataset = relationship("AssetModel", back_populates="consuming_dags") dag = relationship("DagModel", back_populates="schedule_dataset_references") queue_records = relationship( - "DatasetDagRunQueue", + "AssetDagRunQueue", primaryjoin="""and_( - DagScheduleDatasetReference.dataset_id == foreign(DatasetDagRunQueue.dataset_id), - DagScheduleDatasetReference.dag_id == 
foreign(DatasetDagRunQueue.target_dag_id), + DagScheduleAssetReference.dataset_id == foreign(AssetDagRunQueue.dataset_id), + DagScheduleAssetReference.dag_id == foreign(AssetDagRunQueue.target_dag_id), )""", cascade="all, delete, delete-orphan", ) @@ -305,8 +305,8 @@ def __repr__(self): return f"{self.__class__.__name__}({', '.join(args)})" -class TaskOutletDatasetReference(Base): - """References from a task to a dataset that it updates / produces.""" +class TaskOutletAssetReference(Base): + """References from a task to an asset that it updates / produces.""" dataset_id = Column(Integer, primary_key=True, nullable=False) dag_id = Column(StringID(), primary_key=True, nullable=False) @@ -314,7 +314,7 @@ class TaskOutletDatasetReference(Base): created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False) - dataset = relationship("DatasetModel", back_populates="producing_tasks") + dataset = relationship("AssetModel", back_populates="producing_tasks") __tablename__ = "task_outlet_dataset_reference" __table_args__ = ( @@ -354,13 +354,13 @@ def __repr__(self): return f"{self.__class__.__name__}({', '.join(args)})" -class DatasetDagRunQueue(Base): - """Model for storing dataset events that need processing.""" +class AssetDagRunQueue(Base): + """Model for storing asset events that need processing.""" dataset_id = Column(Integer, primary_key=True, nullable=False) target_dag_id = Column(StringID(), primary_key=True, nullable=False) created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) - dataset = relationship("DatasetModel", viewonly=True) + dataset = relationship("AssetModel", viewonly=True) __tablename__ = "dataset_dag_run_queue" __table_args__ = ( PrimaryKeyConstraint(dataset_id, target_dag_id, name="datasetdagrunqueue_pkey"), @@ -405,19 +405,19 @@ def __repr__(self): ) -class DatasetEvent(Base): +class AssetEvent(Base): """ - A table to store datasets events. + A table to store assets events. - :param dataset_id: reference to DatasetModel record + :param dataset_id: reference to AssetModel record :param extra: JSON field for arbitrary extra info - :param source_task_id: the task_id of the TI which updated the dataset - :param source_dag_id: the dag_id of the TI which updated the dataset - :param source_run_id: the run_id of the TI which updated the dataset - :param source_map_index: the map_index of the TI which updated the dataset + :param source_task_id: the task_id of the TI which updated the asset + :param source_dag_id: the dag_id of the TI which updated the asset + :param source_run_id: the run_id of the TI which updated the asset + :param source_map_index: the map_index of the TI which updated the asset :param timestamp: the time the event was logged - We use relationships instead of foreign keys so that dataset events are not deleted even + We use relationships instead of foreign keys so that asset events are not deleted even if the foreign key object is. 
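Because these links are plain relationships rather than foreign keys, an event can be inspected from either side without cascading deletes. A read-only sketch (the session is assumed to be supplied by the caller, e.g. via ``@provide_session``):

from sqlalchemy import select
from sqlalchemy.orm import Session

from airflow.models.asset import AssetEvent


def describe_recent_events(session: Session, dag_id: str) -> None:
    # Events recorded by task instances of one DAG, newest first.
    events = session.scalars(
        select(AssetEvent)
        .where(AssetEvent.source_dag_id == dag_id)
        .order_by(AssetEvent.timestamp.desc())
        .limit(10)
    )
    for event in events:
        # .dataset and .source_dag_run are the viewonly relationships defined
        # above; either side may be None if the referenced row is gone.
        uri = event.dataset.uri if event.dataset else "<asset deleted>"
        run = event.source_dag_run.run_id if event.source_dag_run else "<run deleted>"
        print(uri, run, event.extra)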
""" @@ -443,7 +443,7 @@ class DatasetEvent(Base): ) source_aliases = relationship( - "DatasetAliasModel", + "AssetAliasModel", secondary=dataset_alias_dataset_event_assocation_table, back_populates="dataset_events", ) @@ -451,10 +451,10 @@ class DatasetEvent(Base): source_task_instance = relationship( "TaskInstance", primaryjoin="""and_( - DatasetEvent.source_dag_id == foreign(TaskInstance.dag_id), - DatasetEvent.source_run_id == foreign(TaskInstance.run_id), - DatasetEvent.source_task_id == foreign(TaskInstance.task_id), - DatasetEvent.source_map_index == foreign(TaskInstance.map_index), + AssetEvent.source_dag_id == foreign(TaskInstance.dag_id), + AssetEvent.source_run_id == foreign(TaskInstance.run_id), + AssetEvent.source_task_id == foreign(TaskInstance.task_id), + AssetEvent.source_map_index == foreign(TaskInstance.map_index), )""", viewonly=True, lazy="select", @@ -463,16 +463,16 @@ class DatasetEvent(Base): source_dag_run = relationship( "DagRun", primaryjoin="""and_( - DatasetEvent.source_dag_id == foreign(DagRun.dag_id), - DatasetEvent.source_run_id == foreign(DagRun.run_id), + AssetEvent.source_dag_id == foreign(DagRun.dag_id), + AssetEvent.source_run_id == foreign(DagRun.run_id), )""", viewonly=True, lazy="select", uselist=False, ) dataset = relationship( - DatasetModel, - primaryjoin="DatasetEvent.dataset_id == foreign(DatasetModel.id)", + AssetModel, + primaryjoin="AssetEvent.dataset_id == foreign(AssetModel.id)", viewonly=True, lazy="select", uselist=False, diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 91f8aec7302cb..0632819952ae4 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -81,8 +81,8 @@ import airflow.templates from airflow import settings, utils from airflow.api_internal.internal_api_call import internal_api_call +from airflow.assets import Asset, AssetAlias, AssetAll, BaseAsset from airflow.configuration import conf as airflow_conf, secrets_backend_list -from airflow.datasets import BaseDataset, Dataset, DatasetAlias, DatasetAll from airflow.exceptions import ( AirflowException, DuplicateTaskIdFound, @@ -96,12 +96,15 @@ from airflow.executors.executor_loader import ExecutorLoader from airflow.jobs.job import run_job from airflow.models.abstractoperator import AbstractOperator, TaskStateChangeCallback +from airflow.models.asset import ( + AssetDagRunQueue, + AssetModel, +) from airflow.models.base import Base, StringID from airflow.models.baseoperator import BaseOperator from airflow.models.dagcode import DagCode from airflow.models.dagpickle import DagPickle from airflow.models.dagrun import RUN_ID_REGEX, DagRun -from airflow.models.dataset import DatasetDagRunQueue from airflow.models.param import DagParam, ParamsDict from airflow.models.taskinstance import ( Context, @@ -118,8 +121,8 @@ from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable from airflow.timetables.simple import ( + AssetTriggeredTimetable, ContinuousTimetable, - DatasetTriggeredTimetable, NullTimetable, OnceTimetable, ) @@ -163,8 +166,8 @@ ScheduleArg = Union[ ScheduleInterval, Timetable, - BaseDataset, - Collection[Union["Dataset", "DatasetAlias"]], + BaseAsset, + Collection[Union["Asset", "AssetAlias"]], ] @@ -240,16 +243,16 @@ def get_last_dagrun(dag_id, session, include_externally_triggered=False): return session.scalar(query.limit(1)) -def get_dataset_triggered_next_run_info( +def get_asset_triggered_next_run_info( dag_ids: 
list[str], *, session: Session ) -> dict[str, dict[str, int | str]]: """ Get next run info for a list of dag_ids. - Given a list of dag_ids, get string representing how close any that are dataset triggered are - their next run, e.g. "1 of 2 datasets updated". + Given a list of dag_ids, get string representing how close any that are asset triggered are + their next run, e.g. "1 of 2 assets updated". """ - from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue as DDRQ, DatasetModel + from airflow.models.asset import AssetDagRunQueue as ADRQ, DagScheduleAssetReference return { x.dag_id: { @@ -259,24 +262,24 @@ def get_dataset_triggered_next_run_info( } for x in session.execute( select( - DagScheduleDatasetReference.dag_id, + DagScheduleAssetReference.dag_id, # This is a dirty hack to workaround group by requiring an aggregate, - # since grouping by dataset is not what we want to do here...but it works - case((func.count() == 1, func.max(DatasetModel.uri)), else_="").label("uri"), + # since grouping by asset is not what we want to do here...but it works + case((func.count() == 1, func.max(AssetModel.uri)), else_="").label("uri"), func.count().label("total"), - func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)).label("ready"), + func.sum(case((ADRQ.target_dag_id.is_not(None), 1), else_=0)).label("ready"), ) .join( - DDRQ, + ADRQ, and_( - DDRQ.dataset_id == DagScheduleDatasetReference.dataset_id, - DDRQ.target_dag_id == DagScheduleDatasetReference.dag_id, + ADRQ.dataset_id == DagScheduleAssetReference.dataset_id, + ADRQ.target_dag_id == DagScheduleAssetReference.dag_id, ), isouter=True, ) - .join(DatasetModel, DatasetModel.id == DagScheduleDatasetReference.dataset_id) - .group_by(DagScheduleDatasetReference.dag_id) - .where(DagScheduleDatasetReference.dag_id.in_(dag_ids)) + .join(AssetModel, AssetModel.id == DagScheduleAssetReference.dataset_id) + .group_by(DagScheduleAssetReference.dag_id) + .where(DagScheduleAssetReference.dag_id.in_(dag_ids)) ).all() } @@ -386,7 +389,7 @@ class DAG(LoggingMixin): :param description: The description for the DAG to e.g. be shown on the webserver :param schedule: If provided, this defines the rules according to which DAG runs are scheduled. Possible values include a cron expression string, - timedelta object, Timetable, or list of Dataset objects. + timedelta object, Timetable, or list of Asset objects. See also :doc:`/howto/timetable`. :param start_date: The timestamp from which the scheduler will attempt to backfill. 
If this is not provided, backfilling must be done @@ -595,12 +598,12 @@ def __init__( if isinstance(schedule, Timetable): self.timetable = schedule - elif isinstance(schedule, BaseDataset): - self.timetable = DatasetTriggeredTimetable(schedule) + elif isinstance(schedule, BaseAsset): + self.timetable = AssetTriggeredTimetable(schedule) elif isinstance(schedule, Collection) and not isinstance(schedule, str): - if not all(isinstance(x, (Dataset, DatasetAlias)) for x in schedule): - raise ValueError("All elements in 'schedule' should be datasets or dataset aliases") - self.timetable = DatasetTriggeredTimetable(DatasetAll(*schedule)) + if not all(isinstance(x, (Asset, AssetAlias)) for x in schedule): + raise ValueError("All elements in 'schedule' should be assets or asset aliases") + self.timetable = AssetTriggeredTimetable(AssetAll(*schedule)) else: self.timetable = create_timetable(schedule, self.timezone) @@ -873,7 +876,7 @@ def infer_automated_data_interval(self, logical_date: datetime) -> DataInterval: :meta private: """ timetable_type = type(self.timetable) - if issubclass(timetable_type, (NullTimetable, OnceTimetable, DatasetTriggeredTimetable)): + if issubclass(timetable_type, (NullTimetable, OnceTimetable, AssetTriggeredTimetable)): return DataInterval.exact(timezone.coerce_datetime(logical_date)) start = timezone.coerce_datetime(logical_date) if issubclass(timetable_type, CronDataIntervalTimetable): @@ -2649,7 +2652,7 @@ def bulk_write_to_db( if not dags: return - from airflow.dag_processing.collection import DagModelOperation, DatasetModelOperation + from airflow.dag_processing.collection import AssetModelOperation, DagModelOperation log.info("Sync %s DAGs", len(dags)) dag_op = DagModelOperation({dag.dag_id: dag for dag in dags}) @@ -2658,15 +2661,15 @@ def bulk_write_to_db( dag_op.update_dags(orm_dags, processor_subdir=processor_subdir, session=session) DagCode.bulk_sync_to_db((dag.fileloc for dag in dags), session=session) - dataset_op = DatasetModelOperation.collect(dag_op.dags) + asset_op = AssetModelOperation.collect(dag_op.dags) - orm_datasets = dataset_op.add_datasets(session=session) - orm_dataset_aliases = dataset_op.add_dataset_aliases(session=session) + orm_assets = asset_op.add_assets(session=session) + orm_asset_aliases = asset_op.add_asset_aliases(session=session) session.flush() # This populates id so we can create fks in later calls. 
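Tying the ``schedule`` branches in ``DAG.__init__`` above to concrete usage: a plain collection of assets and aliases is wrapped in ``AssetAll``, so the DAG waits for every entry, while a single asset expression is used as-is. A small sketch (URIs and the alias name are illustrative):

import pendulum

from airflow.assets import Asset, AssetAlias
from airflow.models.dag import DAG
from airflow.timetables.simple import AssetTriggeredTimetable

dag = DAG(
    dag_id="wait_for_all_inputs_sketch",
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
    # Wrapped in AssetAll: runs only after both entries have new events.
    schedule=[Asset("s3://example/input_a.txt"), AssetAlias("nightly-export")],
    catchup=False,
)
assert isinstance(dag.timetable, AssetTriggeredTimetable)

# Anything in the collection that is not an Asset/AssetAlias is rejected with a
# ValueError, e.g. DAG(dag_id="bad", schedule=[Asset("s3://example/x"), "oops"]).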
- dataset_op.add_dag_dataset_references(orm_dags, orm_datasets, session=session) - dataset_op.add_dag_dataset_alias_references(orm_dags, orm_dataset_aliases, session=session) - dataset_op.add_task_dataset_references(orm_dags, orm_datasets, session=session) + asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session) + asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases, session=session) + asset_op.add_task_asset_references(orm_dags, orm_assets, session=session) session.flush() @provide_session @@ -2963,18 +2966,18 @@ class DagModel(Base): __table_args__ = (Index("idx_next_dagrun_create_after", next_dagrun_create_after, unique=False),) schedule_dataset_references = relationship( - "DagScheduleDatasetReference", + "DagScheduleAssetReference", back_populates="dag", cascade="all, delete, delete-orphan", ) schedule_dataset_alias_references = relationship( - "DagScheduleDatasetAliasReference", + "DagScheduleAssetAliasReference", back_populates="dag", cascade="all, delete, delete-orphan", ) schedule_datasets = association_proxy("schedule_dataset_references", "dataset") task_outlet_dataset_references = relationship( - "TaskOutletDatasetReference", + "TaskOutletAssetReference", cascade="all, delete, delete-orphan", ) NUM_DAGS_PER_DAGRUN_QUERY = airflow_conf.getint( @@ -3155,7 +3158,7 @@ def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[ """ from airflow.models.serialized_dag import SerializedDagModel - def dag_ready(dag_id: str, cond: BaseDataset, statuses: dict) -> bool | None: + def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict) -> bool | None: # if dag was serialized before 2.9 and we *just* upgraded, # we may be dealing with old version. In that case, # just wait for the dag to be reserialized. @@ -3165,8 +3168,8 @@ def dag_ready(dag_id: str, cond: BaseDataset, statuses: dict) -> bool | None: log.warning("dag '%s' has old serialization; skipping DAG run creation.", dag_id) return None - # this loads all the DDRQ records.... may need to limit num dags - all_records = session.scalars(select(DatasetDagRunQueue)).all() + # this loads all the ADRQ records.... may need to limit num dags + all_records = session.scalars(select(AssetDagRunQueue)).all() by_dag = defaultdict(list) for r in all_records: by_dag[r.target_dag_id].append(r) @@ -3181,7 +3184,7 @@ def dag_ready(dag_id: str, cond: BaseDataset, statuses: dict) -> bool | None: dag_id = ser_dag.dag_id statuses = dag_statuses[dag_id] - if not dag_ready(dag_id, cond=ser_dag.dag.timetable.dataset_condition, statuses=statuses): + if not dag_ready(dag_id, cond=ser_dag.dag.timetable.asset_condition, statuses=statuses): del by_dag[dag_id] del dag_statuses[dag_id] del dag_statuses @@ -3265,13 +3268,13 @@ def calculate_dagrun_date_fields( ) @provide_session - def get_dataset_triggered_next_run_info(self, *, session=NEW_SESSION) -> dict[str, int | str] | None: + def get_asset_triggered_next_run_info(self, *, session=NEW_SESSION) -> dict[str, int | str] | None: if self.dataset_expression is None: return None - # When a dataset alias does not resolve into datasets, get_dataset_triggered_next_run_info returns - # an empty dict as there's no dataset info to get. This method should thus return None. - return get_dataset_triggered_next_run_info([self.dag_id], session=session).get(self.dag_id, None) + # When an asset alias does not resolve into assets, get_asset_triggered_next_run_info returns + # an empty dict as there's no asset info to get. This method should thus return None. 
+ return get_asset_triggered_next_run_info([self.dag_id], session=session).get(self.dag_id, None) # NOTE: Please keep the list of arguments in sync with DAG.__init__. diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index c17acdd2b7212..b19e65486307d 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -67,10 +67,10 @@ from airflow import settings from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call +from airflow.assets import Asset, AssetAlias +from airflow.assets.manager import asset_manager from airflow.compat.functools import cache from airflow.configuration import conf -from airflow.datasets import Dataset, DatasetAlias -from airflow.datasets.manager import dataset_manager from airflow.exceptions import ( AirflowException, AirflowFailException, @@ -87,9 +87,9 @@ XComForMappingNotPushed, ) from airflow.listeners.listener import get_listener_manager +from airflow.models.asset import AssetEvent, AssetModel from airflow.models.base import Base, StringID, TaskInstanceDependencies, _sentinel from airflow.models.dagbag import DagBag -from airflow.models.dataset import DatasetModel from airflow.models.log import Log from airflow.models.param import process_params from airflow.models.renderedtifields import get_serialized_template_fields @@ -154,13 +154,13 @@ from sqlalchemy.sql.expression import ColumnOperators from airflow.models.abstractoperator import TaskStateChangeCallback + from airflow.models.asset import AssetEvent from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG, DagModel from airflow.models.dagrun import DagRun - from airflow.models.dataset import DatasetEvent from airflow.models.operator import Operator + from airflow.serialization.pydantic.asset import AssetEventPydantic from airflow.serialization.pydantic.dag import DagModelPydantic - from airflow.serialization.pydantic.dataset import DatasetEventPydantic from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic from airflow.timetables.base import DataInterval from airflow.typing_compat import Literal, TypeGuard @@ -366,7 +366,7 @@ def _run_raw_task( if not test_mode: _add_log(event=ti.state, task_instance=ti, session=session) if ti.state == TaskInstanceState.SUCCESS: - ti._register_dataset_changes(events=context["outlet_events"], session=session) + ti._register_asset_changes(events=context["outlet_events"], session=session) TaskInstance.save_to_db(ti=ti, session=session) if ti.state == TaskInstanceState.SUCCESS: @@ -1077,7 +1077,7 @@ def get_prev_ds_nodash() -> str | None: return None return prev_ds.replace("-", "") - def get_triggering_events() -> dict[str, list[DatasetEvent | DatasetEventPydantic]]: + def get_triggering_events() -> dict[str, list[AssetEvent | AssetEventPydantic]]: if TYPE_CHECKING: assert session is not None @@ -1087,9 +1087,9 @@ def get_triggering_events() -> dict[str, list[DatasetEvent | DatasetEventPydanti nonlocal dag_run if dag_run not in session: dag_run = session.merge(dag_run, load=False) - dataset_events = dag_run.consumed_dataset_events - triggering_events: dict[str, list[DatasetEvent | DatasetEventPydantic]] = defaultdict(list) - for event in dataset_events: + asset_events = dag_run.consumed_dataset_events + triggering_events: dict[str, list[AssetEvent | AssetEventPydantic]] = defaultdict(list) + for event in asset_events: if event.dataset: triggering_events[event.dataset.uri].append(event) @@ -1144,7 +1144,7 @@ def get_triggering_events() -> 
dict[str, list[DatasetEvent | DatasetEventPydanti "ti": task_instance, "tomorrow_ds": get_tomorrow_ds(), "tomorrow_ds_nodash": get_tomorrow_ds_nodash(), - "triggering_dataset_events": lazy_object_proxy.Proxy(get_triggering_events), + "triggering_asset_events": lazy_object_proxy.Proxy(get_triggering_events), "ts": ts, "ts_nodash": ts_nodash, "ts_nodash_with_tz": ts_nodash_with_tz, @@ -2886,56 +2886,56 @@ def _run_raw_task( session=session, ) - def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Session) -> None: + def _register_asset_changes(self, *, events: OutletEventAccessors, session: Session) -> None: if TYPE_CHECKING: assert self.task - # One task only triggers one dataset event for each dataset with the same extra. - # This tuple[dataset uri, extra] to sets alias names mapping is used to find whether - # there're datasets with same uri but different extra that we need to emit more than one dataset events. - dataset_alias_names: dict[tuple[str, frozenset], set[str]] = defaultdict(set) + # One task only triggers one asset event for each asset with the same extra. + # This tuple[asset uri, extra] to sets alias names mapping is used to find whether + # there're assets with same uri but different extra that we need to emit more than one asset events. + asset_alias_names: dict[tuple[str, frozenset], set[str]] = defaultdict(set) for obj in self.task.outlets or []: self.log.debug("outlet obj %s", obj) - # Lineage can have other types of objects besides datasets - if isinstance(obj, Dataset): - dataset_manager.register_dataset_change( + # Lineage can have other types of objects besides assets + if isinstance(obj, Asset): + asset_manager.register_asset_change( task_instance=self, - dataset=obj, + asset=obj, extra=events[obj].extra, session=session, ) - elif isinstance(obj, DatasetAlias): - for dataset_alias_event in events[obj].dataset_alias_events: - dataset_alias_name = dataset_alias_event["source_alias_name"] - dataset_uri = dataset_alias_event["dest_dataset_uri"] - frozen_extra = frozenset(dataset_alias_event["extra"].items()) - dataset_alias_names[(dataset_uri, frozen_extra)].add(dataset_alias_name) - - dataset_models: dict[str, DatasetModel] = { + elif isinstance(obj, AssetAlias): + for asset_alias_event in events[obj].asset_alias_events: + asset_alias_name = asset_alias_event["source_alias_name"] + asset_uri = asset_alias_event["dest_asset_uri"] + frozen_extra = frozenset(asset_alias_event["extra"].items()) + asset_alias_names[(asset_uri, frozen_extra)].add(asset_alias_name) + + dataset_models: dict[str, AssetModel] = { dataset_obj.uri: dataset_obj for dataset_obj in session.scalars( - select(DatasetModel).where(DatasetModel.uri.in_(uri for uri, _ in dataset_alias_names)) + select(AssetModel).where(AssetModel.uri.in_(uri for uri, _ in asset_alias_names)) ) } - if missing_datasets := [Dataset(uri=u) for u, _ in dataset_alias_names if u not in dataset_models]: + if missing_datasets := [Asset(uri=u) for u, _ in asset_alias_names if u not in dataset_models]: dataset_models.update( (dataset_obj.uri, dataset_obj) - for dataset_obj in dataset_manager.create_datasets(missing_datasets, session=session) + for dataset_obj in asset_manager.create_assets(missing_datasets, session=session) ) self.log.warning("Created new datasets for alias reference: %s", missing_datasets) session.flush() # Needed because we need the id for fk. 
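On the consuming side, the renamed ``triggering_asset_events`` context entry above is what downstream tasks observe. A short sketch, assuming the key can be requested as a TaskFlow keyword argument the same way ``triggering_dataset_events`` could be (the printed fields come from the event model):

from airflow.decorators import task


@task
def report_trigger(*, triggering_asset_events=None):
    # Keys are asset URIs; values are the AssetEvent rows consumed by this run.
    for uri, events in (triggering_asset_events or {}).items():
        for event in events:
            print(f"{uri} updated by {event.source_dag_id} at {event.timestamp}")

Classic operators should see the same mapping through Jinja, e.g. ``{{ triggering_asset_events }}``.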
- for (uri, extra_items), alias_names in dataset_alias_names.items(): - dataset_obj = dataset_models[uri] + for (uri, extra_items), alias_names in asset_alias_names.items(): + asset_obj = dataset_models[uri] self.log.info( 'Creating event for %r through aliases "%s"', - dataset_obj, + asset_obj, ", ".join(alias_names), ) - dataset_manager.register_dataset_change( + asset_manager.register_asset_change( task_instance=self, - dataset=dataset_obj.to_public(), - aliases=[DatasetAlias(name) for name in alias_names], + asset=asset_obj, + aliases=[AssetAlias(name) for name in alias_names], extra=dict(extra_items), session=session, source_alias_names=alias_names, diff --git a/airflow/operators/python.py b/airflow/operators/python.py index 25c0b8ca68632..a4788caedf438 100644 --- a/airflow/operators/python.py +++ b/airflow/operators/python.py @@ -234,7 +234,7 @@ def __init__( def execute(self, context: Context) -> Any: context_merge(context, self.op_kwargs, templates_dict=self.templates_dict) self.op_kwargs = self.determine_kwargs(context) - self._dataset_events = context_get_outlet_events(context) + self._asset_events = context_get_outlet_events(context) return_value = self.execute_callable() if self.show_return_value_in_logs: @@ -253,7 +253,7 @@ def execute_callable(self) -> Any: :return: the return value of the call. """ - runner = ExecutionCallableRunner(self.python_callable, self._dataset_events, logger=self.log) + runner = ExecutionCallableRunner(self.python_callable, self._asset_events, logger=self.log) return runner.run(*self.op_args, **self.op_kwargs) @@ -424,7 +424,7 @@ class _BasePythonVirtualenvOperator(PythonOperator, metaclass=ABCMeta): "dag_run", "task", "params", - "triggering_dataset_events", + "triggering_asset_events", } def __init__( diff --git a/airflow/provider.yaml.schema.json b/airflow/provider.yaml.schema.json index 8f11833ee1c8f..35e266c310ac3 100644 --- a/airflow/provider.yaml.schema.json +++ b/airflow/provider.yaml.schema.json @@ -196,9 +196,37 @@ "type": "string" } }, + "asset-uris": { + "type": "array", + "description": "Asset URI formats", + "items": { + "type": "object", + "properties": { + "schemes": { + "type": "array", + "description": "List of supported URI schemes", + "items": { + "type": "string" + } + }, + "handler": { + "type": ["string", "null"], + "description": "Normalization function for specified URI schemes. Import path to a callable taking and returning a SplitResult. 'null' specifies a no-op." + }, + "factory": { + "type": ["string", "null"], + "description": "Dataset factory for specified URI. Creates AIP-60 compliant Dataset." + }, + "to_openlineage_converter": { + "type": ["string", "null"], + "description": "OpenLineage converter function for specified URI schemes. Import path to a callable accepting a Dataset and LineageContext and returning OpenLineage dataset." 
+ } + } + } + }, "dataset-uris": { "type": "array", - "description": "Dataset URI formats", + "description": "Dataset URI formats (will be removed in Airflow 3.0)", "items": { "type": "object", "properties": { diff --git a/airflow/providers/amazon/aws/datasets/__init__.py b/airflow/providers/amazon/aws/assets/__init__.py similarity index 100% rename from airflow/providers/amazon/aws/datasets/__init__.py rename to airflow/providers/amazon/aws/assets/__init__.py diff --git a/airflow/providers/amazon/aws/datasets/s3.py b/airflow/providers/amazon/aws/assets/s3.py similarity index 73% rename from airflow/providers/amazon/aws/datasets/s3.py rename to airflow/providers/amazon/aws/assets/s3.py index c42ec2bb1cc03..378e7e977ab56 100644 --- a/airflow/providers/amazon/aws/datasets/s3.py +++ b/airflow/providers/amazon/aws/assets/s3.py @@ -18,17 +18,21 @@ from typing import TYPE_CHECKING -from airflow.datasets import Dataset from airflow.providers.amazon.aws.hooks.s3 import S3Hook +try: + from airflow.assets import Asset +except ModuleNotFoundError: + from airflow.datasets import Dataset as Asset # type: ignore[no-redef] + if TYPE_CHECKING: from urllib.parse import SplitResult from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset -def create_dataset(*, bucket: str, key: str, extra=None) -> Dataset: - return Dataset(uri=f"s3://{bucket}/{key}", extra=extra) +def create_asset(*, bucket: str, key: str, extra=None) -> Asset: + return Asset(uri=f"s3://{bucket}/{key}", extra=extra) def sanitize_uri(uri: SplitResult) -> SplitResult: @@ -37,9 +41,9 @@ def sanitize_uri(uri: SplitResult) -> SplitResult: return uri -def convert_dataset_to_openlineage(dataset: Dataset, lineage_context) -> OpenLineageDataset: - """Translate Dataset with valid AIP-60 uri to OpenLineage with assistance from the hook.""" +def convert_asset_to_openlineage(asset: Asset, lineage_context) -> OpenLineageDataset: + """Translate Asset with valid AIP-60 uri to OpenLineage with assistance from the hook.""" from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset - bucket, key = S3Hook.parse_s3_url(dataset.uri) + bucket, key = S3Hook.parse_s3_url(asset.uri) return OpenLineageDataset(namespace=f"s3://{bucket}", name=key if key else "/") diff --git a/airflow/providers/amazon/aws/auth_manager/avp/entities.py b/airflow/providers/amazon/aws/auth_manager/avp/entities.py index 8c2e8855b877d..4db9aed340208 100644 --- a/airflow/providers/amazon/aws/auth_manager/avp/entities.py +++ b/airflow/providers/amazon/aws/auth_manager/avp/entities.py @@ -33,11 +33,11 @@ class AvpEntities(Enum): USER = "User" # Resource types + ASSET = "Asset" CONFIGURATION = "Configuration" CONNECTION = "Connection" CUSTOM = "Custom" DAG = "Dag" - DATASET = "Dataset" MENU = "Menu" POOL = "Pool" VARIABLE = "Variable" diff --git a/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py index c8693da3382e0..face67c38fb57 100644 --- a/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +++ b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py @@ -63,10 +63,7 @@ IsAuthorizedPoolRequest, IsAuthorizedVariableRequest, ) - from airflow.auth.managers.models.resource_details import ( - ConfigurationDetails, - DatasetDetails, - ) + from airflow.auth.managers.models.resource_details import AssetDetails, ConfigurationDetails from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser from 
airflow.www.extensions.init_appbuilder import AirflowAppBuilder @@ -161,15 +158,12 @@ def is_authorized_dag( context=context, ) - def is_authorized_dataset( - self, *, method: ResourceMethod, details: DatasetDetails | None = None, user: BaseUser | None = None + def is_authorized_asset( + self, *, method: ResourceMethod, details: AssetDetails | None = None, user: BaseUser | None = None ) -> bool: - dataset_uri = details.uri if details else None + asset_uri = details.uri if details else None return self.avp_facade.is_authorized( - method=method, - entity_type=AvpEntities.DATASET, - user=user or self.get_user(), - entity_id=dataset_uri, + method=method, entity_type=AvpEntities.ASSET, user=user or self.get_user(), entity_id=asset_uri ) def is_authorized_pool( diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py index b609259f846ba..6efb5953b8eee 100644 --- a/airflow/providers/amazon/aws/hooks/s3.py +++ b/airflow/providers/amazon/aws/hooks/s3.py @@ -40,8 +40,6 @@ from urllib.parse import urlsplit from uuid import uuid4 -from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector - if TYPE_CHECKING: from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject @@ -50,6 +48,8 @@ with suppress(ImportError): from aiobotocore.client import AioBaseClient +from importlib.util import find_spec + from asgiref.sync import sync_to_async from boto3.s3.transfer import S3Transfer, TransferConfig from botocore.exceptions import ClientError @@ -60,6 +60,13 @@ from airflow.providers.amazon.aws.utils.tags import format_tags from airflow.utils.helpers import chunks +if find_spec("airflow.assets"): + from airflow.lineage.hook import get_hook_lineage_collector +else: + # TODO: import from common.compat directly after common.compat providers with + # asset_compat_lineage_collector released + from airflow.providers.amazon.aws.utils.asset_compat_lineage_collector import get_hook_lineage_collector + logger = logging.getLogger(__name__) @@ -1103,11 +1110,11 @@ def load_file( client = self.get_conn() client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args, Config=self.transfer_config) - get_hook_lineage_collector().add_input_dataset( - context=self, scheme="file", dataset_kwargs={"path": filename} + get_hook_lineage_collector().add_input_asset( + context=self, scheme="file", asset_kwargs={"path": filename} ) - get_hook_lineage_collector().add_output_dataset( - context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key} + get_hook_lineage_collector().add_output_asset( + context=self, scheme="s3", asset_kwargs={"bucket": bucket_name, "key": key} ) @unify_bucket_name_and_key @@ -1250,8 +1257,8 @@ def _upload_file_obj( Config=self.transfer_config, ) # No input because file_obj can be anything - handle in calling function if possible - get_hook_lineage_collector().add_output_dataset( - context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key} + get_hook_lineage_collector().add_output_asset( + context=self, scheme="s3", asset_kwargs={"bucket": bucket_name, "key": key} ) def copy_object( @@ -1308,11 +1315,11 @@ def copy_object( response = self.get_conn().copy_object( Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, **kwargs ) - get_hook_lineage_collector().add_input_dataset( - context=self, scheme="s3", dataset_kwargs={"bucket": source_bucket_name, "key": source_bucket_key} + get_hook_lineage_collector().add_input_asset( + context=self, scheme="s3", 
asset_kwargs={"bucket": source_bucket_name, "key": source_bucket_key} ) - get_hook_lineage_collector().add_output_dataset( - context=self, scheme="s3", dataset_kwargs={"bucket": dest_bucket_name, "key": dest_bucket_key} + get_hook_lineage_collector().add_output_asset( + context=self, scheme="s3", asset_kwargs={"bucket": dest_bucket_name, "key": dest_bucket_key} ) return response @@ -1433,10 +1440,10 @@ def download_file( file_path.parent.mkdir(exist_ok=True, parents=True) - get_hook_lineage_collector().add_output_dataset( + get_hook_lineage_collector().add_output_asset( context=self, scheme="file", - dataset_kwargs={"path": file_path if file_path.is_absolute() else file_path.absolute()}, + asset_kwargs={"path": file_path if file_path.is_absolute() else file_path.absolute()}, ) file = open(file_path, "wb") else: @@ -1448,8 +1455,8 @@ def download_file( ExtraArgs=self.extra_args, Config=self.transfer_config, ) - get_hook_lineage_collector().add_input_dataset( - context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key} + get_hook_lineage_collector().add_input_asset( + context=self, scheme="s3", asset_kwargs={"bucket": bucket_name, "key": key} ) return file.name diff --git a/airflow/providers/amazon/aws/utils/asset_compat_lineage_collector.py b/airflow/providers/amazon/aws/utils/asset_compat_lineage_collector.py new file mode 100644 index 0000000000000..50fbc3d0996aa --- /dev/null +++ b/airflow/providers/amazon/aws/utils/asset_compat_lineage_collector.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from importlib.util import find_spec + + +def _get_asset_compat_hook_lineage_collector(): + from airflow.lineage.hook import get_hook_lineage_collector + + collector = get_hook_lineage_collector() + + if all( + getattr(collector, asset_method_name, None) + for asset_method_name in ("add_input_asset", "add_output_asset", "collected_assets") + ): + return collector + + # dataset is renamed as asset in Airflow 3.0 + + from functools import wraps + + from airflow.lineage.hook import DatasetLineageInfo, HookLineage + + DatasetLineageInfo.asset = DatasetLineageInfo.dataset + + def rename_dataset_kwargs_as_assets_kwargs(function): + @wraps(function) + def wrapper(*args, **kwargs): + if "asset_kwargs" in kwargs: + kwargs["dataset_kwargs"] = kwargs.pop("asset_kwargs") + + if "asset_extra" in kwargs: + kwargs["dataset_extra"] = kwargs.pop("asset_extra") + + return function(*args, **kwargs) + + return wrapper + + collector.create_asset = rename_dataset_kwargs_as_assets_kwargs(collector.create_dataset) + collector.add_input_asset = rename_dataset_kwargs_as_assets_kwargs(collector.add_input_dataset) + collector.add_output_asset = rename_dataset_kwargs_as_assets_kwargs(collector.add_output_dataset) + + def collected_assets_compat(collector) -> HookLineage: + """Get the collected hook lineage information.""" + lineage = collector.collected_datasets + return HookLineage( + [ + DatasetLineageInfo(dataset=item.dataset, count=item.count, context=item.context) + for item in lineage.inputs + ], + [ + DatasetLineageInfo(dataset=item.dataset, count=item.count, context=item.context) + for item in lineage.outputs + ], + ) + + setattr( + collector.__class__, + "collected_assets", + property(lambda collector: collected_assets_compat(collector)), + ) + + return collector + + +def get_hook_lineage_collector(): + # HookLineageCollector added in 2.10 + try: + if find_spec("airflow.assets"): + # Dataset has been renamed as Asset in 3.0 + from airflow.lineage.hook import get_hook_lineage_collector + + return get_hook_lineage_collector() + + return _get_asset_compat_hook_lineage_collector() + except ImportError: + + class NoOpCollector: + """ + NoOpCollector is a hook lineage collector that does nothing. + + It is used when you want to disable lineage collection. + """ + + def add_input_asset(self, *_, **__): + pass + + def add_output_asset(self, *_, **__): + pass + + return NoOpCollector() diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 83d66de69a85b..1316cd05231a0 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -563,11 +563,19 @@ sensors: python-modules: - airflow.providers.amazon.aws.sensors.quicksight +asset-uris: + - schemes: [s3] + handler: airflow.providers.amazon.aws.assets.s3.sanitize_uri + to_openlineage_converter: airflow.providers.amazon.aws.assets.s3.convert_asset_to_openlineage + factory: airflow.providers.amazon.aws.assets.s3.create_asset + +# dataset has been renamed to asset in Airflow 3.0 +# This is kept for backward compatibility. 
dataset-uris: - schemes: [s3] - handler: airflow.providers.amazon.aws.datasets.s3.sanitize_uri - to_openlineage_converter: airflow.providers.amazon.aws.datasets.s3.convert_dataset_to_openlineage - factory: airflow.providers.amazon.aws.datasets.s3.create_dataset + handler: airflow.providers.amazon.aws.assets.s3.sanitize_uri + to_openlineage_converter: airflow.providers.amazon.aws.assets.s3.convert_asset_to_openlineage + factory: airflow.providers.amazon.aws.assets.s3.create_asset filesystems: - airflow.providers.amazon.aws.fs.s3 diff --git a/airflow/providers/common/compat/assets/__init__.py b/airflow/providers/common/compat/assets/__init__.py new file mode 100644 index 0000000000000..460204a4e417f --- /dev/null +++ b/airflow/providers/common/compat/assets/__init__.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow import __version__ as AIRFLOW_VERSION + +if TYPE_CHECKING: + from airflow.assets import ( + Asset, + AssetAlias, + AssetAliasEvent, + AssetAll, + AssetAny, + expand_alias_to_assets, + ) + from airflow.auth.managers.models.resource_details import AssetDetails +else: + try: + from airflow.assets import ( + Asset, + AssetAlias, + AssetAliasEvent, + AssetAll, + AssetAny, + expand_alias_to_assets, + ) + from airflow.auth.managers.models.resource_details import AssetDetails + except ModuleNotFoundError: + from packaging.version import Version + + _IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0") + _IS_AIRFLOW_2_9_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.9.0") + + # dataset is renamed to asset since Airflow 3.0 + from airflow.auth.managers.models.resource_details import DatasetDetails as AssetDetails + from airflow.datasets import Dataset as Asset + + if _IS_AIRFLOW_2_9_OR_HIGHER: + from airflow.datasets import ( + DatasetAll as AssetAll, + DatasetAny as AssetAny, + ) + + if _IS_AIRFLOW_2_10_OR_HIGHER: + from airflow.datasets import ( + DatasetAlias as AssetAlias, + DatasetAliasEvent as AssetAliasEvent, + expand_alias_to_datasets as expand_alias_to_assets, + ) + + +__all__ = [ + "Asset", + "AssetAlias", + "AssetAliasEvent", + "AssetAll", + "AssetAny", + "AssetDetails", + "expand_alias_to_assets", +] diff --git a/airflow/providers/common/compat/lineage/hook.py b/airflow/providers/common/compat/lineage/hook.py index dbdbc5bf86f4d..50fbc3d0996aa 100644 --- a/airflow/providers/common/compat/lineage/hook.py +++ b/airflow/providers/common/compat/lineage/hook.py @@ -16,13 +16,78 @@ # under the License. 
from __future__ import annotations +from importlib.util import find_spec + + +def _get_asset_compat_hook_lineage_collector(): + from airflow.lineage.hook import get_hook_lineage_collector + + collector = get_hook_lineage_collector() + + if all( + getattr(collector, asset_method_name, None) + for asset_method_name in ("add_input_asset", "add_output_asset", "collected_assets") + ): + return collector + + # dataset is renamed as asset in Airflow 3.0 + + from functools import wraps + + from airflow.lineage.hook import DatasetLineageInfo, HookLineage + + DatasetLineageInfo.asset = DatasetLineageInfo.dataset + + def rename_dataset_kwargs_as_assets_kwargs(function): + @wraps(function) + def wrapper(*args, **kwargs): + if "asset_kwargs" in kwargs: + kwargs["dataset_kwargs"] = kwargs.pop("asset_kwargs") + + if "asset_extra" in kwargs: + kwargs["dataset_extra"] = kwargs.pop("asset_extra") + + return function(*args, **kwargs) + + return wrapper + + collector.create_asset = rename_dataset_kwargs_as_assets_kwargs(collector.create_dataset) + collector.add_input_asset = rename_dataset_kwargs_as_assets_kwargs(collector.add_input_dataset) + collector.add_output_asset = rename_dataset_kwargs_as_assets_kwargs(collector.add_output_dataset) + + def collected_assets_compat(collector) -> HookLineage: + """Get the collected hook lineage information.""" + lineage = collector.collected_datasets + return HookLineage( + [ + DatasetLineageInfo(dataset=item.dataset, count=item.count, context=item.context) + for item in lineage.inputs + ], + [ + DatasetLineageInfo(dataset=item.dataset, count=item.count, context=item.context) + for item in lineage.outputs + ], + ) + + setattr( + collector.__class__, + "collected_assets", + property(lambda collector: collected_assets_compat(collector)), + ) + + return collector + def get_hook_lineage_collector(): # HookLineageCollector added in 2.10 try: - from airflow.lineage.hook import get_hook_lineage_collector + if find_spec("airflow.assets"): + # Dataset has been renamed as Asset in 3.0 + from airflow.lineage.hook import get_hook_lineage_collector + + return get_hook_lineage_collector() - return get_hook_lineage_collector() + return _get_asset_compat_hook_lineage_collector() except ImportError: class NoOpCollector: @@ -32,10 +97,10 @@ class NoOpCollector: It is used when you want to disable lineage collection. """ - def add_input_dataset(self, *_, **__): + def add_input_asset(self, *_, **__): pass - def add_output_dataset(self, *_, **__): + def add_output_asset(self, *_, **__): pass return NoOpCollector() diff --git a/airflow/providers/common/io/datasets/__init__.py b/airflow/providers/common/compat/openlineage/utils/__init__.py similarity index 100% rename from airflow/providers/common/io/datasets/__init__.py rename to airflow/providers/common/compat/openlineage/utils/__init__.py diff --git a/airflow/providers/common/compat/openlineage/utils/utils.py b/airflow/providers/common/compat/openlineage/utils/utils.py new file mode 100644 index 0000000000000..5492c76d55dcc --- /dev/null +++ b/airflow/providers/common/compat/openlineage/utils/utils.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from functools import wraps +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from airflow.providers.openlineage.utils.utils import translate_airflow_asset +else: + try: + from airflow.providers.openlineage.utils.utils import translate_airflow_asset + except ImportError: + from airflow.providers.openlineage.utils.utils import translate_airflow_dataset + + def rename_asset_as_dataset(function): + @wraps(function) + def wrapper(*args, **kwargs): + if "asset" in kwargs: + kwargs["dataset"] = kwargs.pop("asset") + return function(*args, **kwargs) + + return wrapper + + translate_airflow_asset = rename_asset_as_dataset(translate_airflow_dataset) + + +__all__ = ["translate_airflow_asset"] diff --git a/airflow/providers/mysql/datasets/__init__.py b/airflow/providers/common/compat/security/__init__.py similarity index 100% rename from airflow/providers/mysql/datasets/__init__.py rename to airflow/providers/common/compat/security/__init__.py diff --git a/airflow/providers/common/compat/security/permissions.py b/airflow/providers/common/compat/security/permissions.py new file mode 100644 index 0000000000000..d5c351bdad31e --- /dev/null +++ b/airflow/providers/common/compat/security/permissions.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
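# --- Illustrative usage sketch, not part of the patch: code that needs the asset resource
# name can import it from this compat module (defined just below) and gets "Assets" on
# Airflow 3.0 or "Datasets" on earlier versions, without doing its own version check.
# The constant name EXAMPLE_ASSET_READ is hypothetical.
from airflow.providers.common.compat.security.permissions import RESOURCE_ASSET
from airflow.security import permissions

EXAMPLE_ASSET_READ = (permissions.ACTION_CAN_READ, RESOURCE_ASSET)  # e.g. ("can_read", "Assets")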
+from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from airflow.security.permissions import RESOURCE_ASSET +else: + try: + from airflow.security.permissions import RESOURCE_ASSET + except ImportError: + from airflow.security.permissions import RESOURCE_DATASET as RESOURCE_ASSET + + +__all__ = ["RESOURCE_ASSET"] diff --git a/airflow/providers/postgres/datasets/__init__.py b/airflow/providers/common/io/assets/__init__.py similarity index 100% rename from airflow/providers/postgres/datasets/__init__.py rename to airflow/providers/common/io/assets/__init__.py diff --git a/airflow/providers/common/io/datasets/file.py b/airflow/providers/common/io/assets/file.py similarity index 76% rename from airflow/providers/common/io/datasets/file.py rename to airflow/providers/common/io/assets/file.py index 35d3b227e5223..fadc4cbe1bdc8 100644 --- a/airflow/providers/common/io/datasets/file.py +++ b/airflow/providers/common/io/assets/file.py @@ -19,7 +19,10 @@ import urllib.parse from typing import TYPE_CHECKING -from airflow.datasets import Dataset +try: + from airflow.assets import Asset +except ModuleNotFoundError: + from airflow.datasets import Dataset as Asset # type: ignore[no-redef] if TYPE_CHECKING: from urllib.parse import SplitResult @@ -27,9 +30,9 @@ from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset -def create_dataset(*, path: str, extra=None) -> Dataset: +def create_asset(*, path: str, extra=None) -> Asset: # We assume that we get absolute path starting with / - return Dataset(uri=f"file://{path}", extra=extra) + return Asset(uri=f"file://{path}", extra=extra) def sanitize_uri(uri: SplitResult) -> SplitResult: @@ -38,13 +41,13 @@ def sanitize_uri(uri: SplitResult) -> SplitResult: return uri -def convert_dataset_to_openlineage(dataset: Dataset, lineage_context) -> OpenLineageDataset: +def convert_asset_to_openlineage(asset: Asset, lineage_context) -> OpenLineageDataset: """ - Translate Dataset with valid AIP-60 uri to OpenLineage with assistance from the context. + Translate Asset with valid AIP-60 uri to OpenLineage with assistance from the context. Windows paths are not standardized and can produce unexpected behaviour. """ from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset - parsed = urllib.parse.urlsplit(dataset.uri) + parsed = urllib.parse.urlsplit(asset.uri) return OpenLineageDataset(namespace=f"file://{parsed.netloc}", name=parsed.path) diff --git a/airflow/providers/common/io/provider.yaml b/airflow/providers/common/io/provider.yaml index 870605be33b05..6743cfff86c40 100644 --- a/airflow/providers/common/io/provider.yaml +++ b/airflow/providers/common/io/provider.yaml @@ -53,11 +53,19 @@ operators: xcom: - airflow.providers.common.io.xcom.backend +asset-uris: + - schemes: [file] + handler: airflow.providers.common.io.assets.file.sanitize_uri + to_openlineage_converter: airflow.providers.common.io.assets.file.convert_asset_to_openlineage + factory: airflow.providers.common.io.assets.file.create_asset + +# dataset has been renamed to asset in Airflow 3.0 +# This is kept for backward compatibility. 
dataset-uris: - schemes: [file] - handler: airflow.providers.common.io.datasets.file.sanitize_uri - to_openlineage_converter: airflow.providers.common.io.datasets.file.convert_dataset_to_openlineage - factory: airflow.providers.common.io.datasets.file.create_dataset + handler: airflow.providers.common.io.assets.file.sanitize_uri + to_openlineage_converter: airflow.providers.common.io.assets.file.convert_asset_to_openlineage + factory: airflow.providers.common.io.assets.file.create_asset config: common.io: diff --git a/airflow/providers/fab/auth_manager/fab_auth_manager.py b/airflow/providers/fab/auth_manager/fab_auth_manager.py index 3d0f102650935..425f2d6d2124f 100644 --- a/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -37,7 +37,6 @@ ConnectionDetails, DagAccessEntity, DagDetails, - DatasetDetails, PoolDetails, VariableDetails, ) @@ -67,7 +66,6 @@ RESOURCE_DAG_DEPENDENCIES, RESOURCE_DAG_RUN, RESOURCE_DAG_WARNING, - RESOURCE_DATASET, RESOURCE_DOCS, RESOURCE_IMPORT_ERROR, RESOURCE_JOB, @@ -94,7 +92,15 @@ from airflow.cli.cli_config import ( CLICommand, ) + from airflow.providers.common.compat.assets import AssetDetails from airflow.providers.fab.auth_manager.security_manager.override import FabAirflowSecurityManagerOverride + from airflow.security.permissions import RESOURCE_ASSET +else: + try: + from airflow.security.permissions import RESOURCE_ASSET + except ImportError: + from airflow.security.permissions import RESOURCE_DATASET as RESOURCE_ASSET + _MAP_DAG_ACCESS_ENTITY_TO_FAB_RESOURCE_TYPE: dict[DagAccessEntity, tuple[str, ...]] = { DagAccessEntity.AUDIT_LOG: (RESOURCE_AUDIT_LOG,), @@ -263,10 +269,10 @@ def is_authorized_dag( for resource_type in resource_types ) - def is_authorized_dataset( - self, *, method: ResourceMethod, details: DatasetDetails | None = None, user: BaseUser | None = None + def is_authorized_asset( + self, *, method: ResourceMethod, details: AssetDetails | None = None, user: BaseUser | None = None ) -> bool: - return self._is_authorized(method=method, resource_type=RESOURCE_DATASET, user=user) + return self._is_authorized(method=method, resource_type=RESOURCE_ASSET, user=user) def is_authorized_pool( self, *, method: ResourceMethod, details: PoolDetails | None = None, user: BaseUser | None = None diff --git a/airflow/providers/fab/auth_manager/security_manager/override.py b/airflow/providers/fab/auth_manager/security_manager/override.py index 412e493b7bfe2..023154c5d4dd3 100644 --- a/airflow/providers/fab/auth_manager/security_manager/override.py +++ b/airflow/providers/fab/auth_manager/security_manager/override.py @@ -115,6 +115,12 @@ if TYPE_CHECKING: from airflow.auth.managers.base_auth_manager import ResourceMethod + from airflow.security.permissions import RESOURCE_ASSET +else: + try: + from airflow.security.permissions import RESOURCE_ASSET + except ImportError: + from airflow.security.permissions import RESOURCE_DATASET as RESOURCE_ASSET log = logging.getLogger(__name__) @@ -234,7 +240,7 @@ class FabAirflowSecurityManagerOverride(AirflowSecurityManagerV2): (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_DEPENDENCIES), (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE), (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET), + (permissions.ACTION_CAN_READ, RESOURCE_ASSET), (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY), (permissions.ACTION_CAN_READ, permissions.RESOURCE_POOL), 
(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR), @@ -253,7 +259,7 @@ class FabAirflowSecurityManagerOverride(AirflowSecurityManagerV2): (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DAG), (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DAG_DEPENDENCIES), (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DATASET), + (permissions.ACTION_CAN_ACCESS_MENU, RESOURCE_ASSET), (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_CLUSTER_ACTIVITY), (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DOCS), (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DOCS_MENU), @@ -273,7 +279,7 @@ class FabAirflowSecurityManagerOverride(AirflowSecurityManagerV2): (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DATASET), + (permissions.ACTION_CAN_CREATE, RESOURCE_ASSET), ] # [END security_user_perms] @@ -302,8 +308,8 @@ class FabAirflowSecurityManagerOverride(AirflowSecurityManagerV2): (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_VARIABLE), (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE), (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_XCOM), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DATASET), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DATASET), + (permissions.ACTION_CAN_DELETE, RESOURCE_ASSET), + (permissions.ACTION_CAN_CREATE, RESOURCE_ASSET), ] # [END security_op_perms] diff --git a/airflow/providers/fab/provider.yaml b/airflow/providers/fab/provider.yaml index 63be11c264938..1d5cc820f1f24 100644 --- a/airflow/providers/fab/provider.yaml +++ b/airflow/providers/fab/provider.yaml @@ -47,6 +47,7 @@ versions: dependencies: - apache-airflow>=2.9.0 + - apache-airflow-providers-common-compat>=1.2.0 - flask>=2.2,<2.3 # We are tightly coupled with FAB version as we vendored-in part of FAB code related to security manager # This is done as part of preparation to removing FAB as dependency, but we are not ready for it yet diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 612fd8e29bac3..a64b2ce17a76e 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -768,6 +768,14 @@ sensors: filesystems: - airflow.providers.google.cloud.fs.gcs +asset-uris: + - schemes: [gcp] + handler: null + - schemes: [bigquery] + handler: airflow.providers.google.datasets.bigquery.sanitize_uri + +# dataset has been renamed to asset in Airflow 3.0 +# This is kept for backward compatibility. 
dataset-uris: - schemes: [gcp] handler: null diff --git a/airflow/providers/trino/datasets/__init__.py b/airflow/providers/mysql/assets/__init__.py similarity index 100% rename from airflow/providers/trino/datasets/__init__.py rename to airflow/providers/mysql/assets/__init__.py diff --git a/airflow/providers/mysql/datasets/mysql.py b/airflow/providers/mysql/assets/mysql.py similarity index 100% rename from airflow/providers/mysql/datasets/mysql.py rename to airflow/providers/mysql/assets/mysql.py diff --git a/airflow/providers/mysql/provider.yaml b/airflow/providers/mysql/provider.yaml index 28ba986b64ccd..f0f77f28d0e94 100644 --- a/airflow/providers/mysql/provider.yaml +++ b/airflow/providers/mysql/provider.yaml @@ -113,6 +113,12 @@ connection-types: - hook-class-name: airflow.providers.mysql.hooks.mysql.MySqlHook connection-type: mysql +asset-uris: + - schemes: [mysql, mariadb] + handler: airflow.providers.mysql.assets.mysql.sanitize_uri + +# dataset has been renamed to asset in Airflow 3.0 +# This is kept for backward compatibility. dataset-uris: - schemes: [mysql, mariadb] - handler: airflow.providers.mysql.datasets.mysql.sanitize_uri + handler: airflow.providers.mysql.assets.mysql.sanitize_uri diff --git a/airflow/providers/openlineage/extractors/manager.py b/airflow/providers/openlineage/extractors/manager.py index 74be9e01f4b7b..c72c989d8936e 100644 --- a/airflow/providers/openlineage/extractors/manager.py +++ b/airflow/providers/openlineage/extractors/manager.py @@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Iterator +from airflow.providers.common.compat.openlineage.utils.utils import translate_airflow_asset from airflow.providers.openlineage import conf from airflow.providers.openlineage.extractors import BaseExtractor, OperatorLineage from airflow.providers.openlineage.extractors.base import DefaultExtractor @@ -25,7 +26,6 @@ from airflow.providers.openlineage.extractors.python import PythonExtractor from airflow.providers.openlineage.utils.utils import ( get_unknown_source_attribute_run_facet, - translate_airflow_dataset, try_import_from_string, ) from airflow.utils.log.logging_mixin import LoggingMixin @@ -178,7 +178,16 @@ def extract_inlets_and_outlets( def get_hook_lineage(self) -> tuple[list[Dataset], list[Dataset]] | None: try: - from airflow.lineage.hook import get_hook_lineage_collector + from importlib.util import find_spec + + if find_spec("airflow.assets"): + from airflow.lineage.hook import get_hook_lineage_collector + else: + # TODO: import from common.compat directly after common.compat providers with + # asset_compat_lineage_collector released + from airflow.providers.openlineage.utils.asset_compat_lineage_collector import ( + get_hook_lineage_collector, + ) except ImportError: return None @@ -187,16 +196,14 @@ def get_hook_lineage(self) -> tuple[list[Dataset], list[Dataset]] | None: return ( [ - dataset - for dataset_info in get_hook_lineage_collector().collected_datasets.inputs - if (dataset := translate_airflow_dataset(dataset_info.dataset, dataset_info.context)) - is not None + asset + for asset_info in get_hook_lineage_collector().collected_assets.inputs + if (asset := translate_airflow_asset(asset_info.asset, asset_info.context)) is not None ], [ - dataset - for dataset_info in get_hook_lineage_collector().collected_datasets.outputs - if (dataset := translate_airflow_dataset(dataset_info.dataset, dataset_info.context)) - is not None + asset + for asset_info in get_hook_lineage_collector().collected_assets.outputs + if (asset := 
translate_airflow_asset(asset_info.asset, asset_info.context)) is not None ], ) diff --git a/airflow/providers/openlineage/utils/asset_compat_lineage_collector.py b/airflow/providers/openlineage/utils/asset_compat_lineage_collector.py new file mode 100644 index 0000000000000..8a4d2b61914ff --- /dev/null +++ b/airflow/providers/openlineage/utils/asset_compat_lineage_collector.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from importlib.util import find_spec + +# TODO: replace this module with common.compat provider once common.compat 1.3.0 released + + +def _get_asset_compat_hook_lineage_collector(): + from airflow.lineage.hook import get_hook_lineage_collector + + collector = get_hook_lineage_collector() + + if all( + getattr(collector, asset_method_name, None) + for asset_method_name in ("add_input_asset", "add_output_asset", "collected_assets") + ): + return collector + + # dataset is renamed as asset in Airflow 3.0 + + from functools import wraps + + from airflow.lineage.hook import DatasetLineageInfo, HookLineage + + DatasetLineageInfo.asset = DatasetLineageInfo.dataset + + def rename_dataset_kwargs_as_assets_kwargs(function): + @wraps(function) + def wrapper(*args, **kwargs): + if "asset_kwargs" in kwargs: + kwargs["dataset_kwargs"] = kwargs.pop("asset_kwargs") + + if "asset_extra" in kwargs: + kwargs["dataset_extra"] = kwargs.pop("asset_extra") + + return function(*args, **kwargs) + + return wrapper + + collector.create_asset = rename_dataset_kwargs_as_assets_kwargs(collector.create_dataset) + collector.add_input_asset = rename_dataset_kwargs_as_assets_kwargs(collector.add_input_dataset) + collector.add_output_asset = rename_dataset_kwargs_as_assets_kwargs(collector.add_output_dataset) + + def collected_assets_compat(collector) -> HookLineage: + """Get the collected hook lineage information.""" + lineage = collector.collected_datasets + return HookLineage( + [ + DatasetLineageInfo(dataset=item.dataset, count=item.count, context=item.context) + for item in lineage.inputs + ], + [ + DatasetLineageInfo(dataset=item.dataset, count=item.count, context=item.context) + for item in lineage.outputs + ], + ) + + setattr( + collector.__class__, + "collected_assets", + property(lambda collector: collected_assets_compat(collector)), + ) + + return collector + + +def get_hook_lineage_collector(): + # HookLineageCollector added in 2.10 + try: + if find_spec("airflow.assets"): + # Dataset has been renamed as Asset in 3.0 + from airflow.lineage.hook import get_hook_lineage_collector + + return get_hook_lineage_collector() + + return _get_asset_compat_hook_lineage_collector() + except ImportError: + + class NoOpCollector: + """ + NoOpCollector is a hook lineage collector that does nothing. 
+ + It is used when you want to disable lineage collection. + """ + + def add_input_asset(self, *_, **__): + pass + + def add_output_asset(self, *_, **__): + pass + + return NoOpCollector() diff --git a/airflow/providers/openlineage/utils/utils.py b/airflow/providers/openlineage/utils/utils.py index f283c09e8759c..ca57755d692ea 100644 --- a/airflow/providers/openlineage/utils/utils.py +++ b/airflow/providers/openlineage/utils/utils.py @@ -31,7 +31,6 @@ from packaging.version import Version from airflow import __version__ as AIRFLOW_VERSION -from airflow.datasets import Dataset from airflow.exceptions import AirflowProviderDeprecationWarning # TODO: move this maybe to Airflow's logic? from airflow.models import DAG, BaseOperator, DagRun, MappedOperator from airflow.providers.openlineage import conf @@ -54,6 +53,11 @@ from airflow.utils.log.secrets_masker import Redactable, Redacted, SecretsMasker, should_hide_value_for_key from airflow.utils.module_loading import import_string +try: + from airflow.assets import Asset +except ModuleNotFoundError: + from airflow.datasets import Dataset as Asset # type: ignore[no-redef] + if TYPE_CHECKING: from openlineage.client.event_v2 import Dataset as OpenLineageDataset from openlineage.client.facet_v2 import RunFacet @@ -283,8 +287,8 @@ class TaskInstanceInfo(InfoJsonEncodable): } -class DatasetInfo(InfoJsonEncodable): - """Defines encoding Airflow Dataset object to JSON.""" +class AssetInfo(InfoJsonEncodable): + """Defines encoding Airflow Asset object to JSON.""" includes = ["uri", "extra"] @@ -335,8 +339,8 @@ class TaskInfo(InfoJsonEncodable): if hasattr(task, "task_group") and getattr(task.task_group, "_group_id", None) else None ), - "inlets": lambda task: [DatasetInfo(i) for i in task.inlets if isinstance(i, Dataset)], - "outlets": lambda task: [DatasetInfo(o) for o in task.outlets if isinstance(o, Dataset)], + "inlets": lambda task: [AssetInfo(i) for i in task.inlets if isinstance(i, Asset)], + "outlets": lambda task: [AssetInfo(o) for o in task.outlets if isinstance(o, Asset)], } @@ -641,19 +645,29 @@ def should_use_external_connection(hook) -> bool: return True -def translate_airflow_dataset(dataset: Dataset, lineage_context) -> OpenLineageDataset | None: +def translate_airflow_asset(asset: Asset, lineage_context) -> OpenLineageDataset | None: """ - Convert a Dataset with an AIP-60 compliant URI to an OpenLineageDataset. + Convert a Asset with an AIP-60 compliant URI to an OpenLineageDataset. - This function returns None if no URI normalizer is defined, no dataset converter is found or + This function returns None if no URI normalizer is defined, no asset converter is found or some core Airflow changes are missing and ImportError is raised. 
""" try: - from airflow.datasets import _get_normalized_scheme + from airflow.assets import _get_normalized_scheme + except ModuleNotFoundError: + try: + from airflow.datasets import _get_normalized_scheme # type: ignore[no-redef] + except ImportError: + return None + + try: from airflow.providers_manager import ProvidersManager - ol_converters = ProvidersManager().dataset_to_openlineage_converters - normalized_uri = dataset.normalized_uri + ol_converters = getattr(ProvidersManager(), "asset_to_openlineage_converters", None) + if not ol_converters: + ol_converters = ProvidersManager().dataset_to_openlineage_converters # type: ignore[attr-defined] + + normalized_uri = asset.normalized_uri except (ImportError, AttributeError): return None @@ -666,4 +680,4 @@ def translate_airflow_dataset(dataset: Dataset, lineage_context) -> OpenLineageD if (airflow_to_ol_converter := ol_converters.get(normalized_scheme)) is None: return None - return airflow_to_ol_converter(Dataset(uri=normalized_uri, extra=dataset.extra), lineage_context) + return airflow_to_ol_converter(Asset(uri=normalized_uri, extra=asset.extra), lineage_context) diff --git a/tests/datasets/__init__.py b/airflow/providers/postgres/assets/__init__.py similarity index 100% rename from tests/datasets/__init__.py rename to airflow/providers/postgres/assets/__init__.py diff --git a/airflow/providers/postgres/datasets/postgres.py b/airflow/providers/postgres/assets/postgres.py similarity index 100% rename from airflow/providers/postgres/datasets/postgres.py rename to airflow/providers/postgres/assets/postgres.py diff --git a/airflow/providers/postgres/provider.yaml b/airflow/providers/postgres/provider.yaml index 7ce95986fd65a..edbbaeb1da2c1 100644 --- a/airflow/providers/postgres/provider.yaml +++ b/airflow/providers/postgres/provider.yaml @@ -96,6 +96,12 @@ connection-types: - hook-class-name: airflow.providers.postgres.hooks.postgres.PostgresHook connection-type: postgres +asset-uris: + - schemes: [postgres, postgresql] + handler: airflow.providers.postgres.assets.postgres.sanitize_uri + +# dataset has been renamed to asset in Airflow 3.0 +# This is kept for backward compatibility. dataset-uris: - schemes: [postgres, postgresql] - handler: airflow.providers.postgres.datasets.postgres.sanitize_uri + handler: airflow.providers.postgres.assets.postgres.sanitize_uri diff --git a/tests/providers/amazon/aws/datasets/__init__.py b/airflow/providers/trino/assets/__init__.py similarity index 100% rename from tests/providers/amazon/aws/datasets/__init__.py rename to airflow/providers/trino/assets/__init__.py diff --git a/airflow/providers/trino/datasets/trino.py b/airflow/providers/trino/assets/trino.py similarity index 100% rename from airflow/providers/trino/datasets/trino.py rename to airflow/providers/trino/assets/trino.py diff --git a/airflow/providers/trino/provider.yaml b/airflow/providers/trino/provider.yaml index d4000baaa063f..424be2cca67d9 100644 --- a/airflow/providers/trino/provider.yaml +++ b/airflow/providers/trino/provider.yaml @@ -86,9 +86,15 @@ operators: python-modules: - airflow.providers.trino.operators.trino +asset-uris: + - schemes: [trino] + handler: airflow.providers.trino.assets.trino.sanitize_uri + +# dataset has been renamed to asset in Airflow 3.0 +# This is kept for backward compatibility. 
dataset-uris: - schemes: [trino] - handler: airflow.providers.trino.datasets.trino.sanitize_uri + handler: airflow.providers.trino.assets.trino.sanitize_uri hooks: - integration-name: Trino diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py index dd3e841fa1662..2c673063cb23e 100644 --- a/airflow/providers_manager.py +++ b/airflow/providers_manager.py @@ -91,7 +91,7 @@ def ensure_prefix(field): if TYPE_CHECKING: from urllib.parse import SplitResult - from airflow.datasets import Dataset + from airflow.assets import Asset from airflow.decorators.base import TaskDecorator from airflow.hooks.base import BaseHook from airflow.typing_compat import Literal @@ -426,9 +426,9 @@ def __init__(self): # Keeps dict of hooks keyed by connection type self._hooks_dict: dict[str, HookInfo] = {} self._fs_set: set[str] = set() - self._dataset_uri_handlers: dict[str, Callable[[SplitResult], SplitResult]] = {} - self._dataset_factories: dict[str, Callable[..., Dataset]] = {} - self._dataset_to_openlineage_converters: dict[str, Callable] = {} + self._asset_uri_handlers: dict[str, Callable[[SplitResult], SplitResult]] = {} + self._asset_factories: dict[str, Callable[..., Asset]] = {} + self._asset_to_openlineage_converters: dict[str, Callable] = {} self._taskflow_decorators: dict[str, Callable] = LazyDictWithCache() # type: ignore[assignment] # keeps mapping between connection_types and hook class, package they come from self._hook_provider_dict: dict[str, HookClassProvider] = {} @@ -525,11 +525,11 @@ def initialize_providers_filesystems(self): self.initialize_providers_list() self._discover_filesystems() - @provider_info_cache("dataset_uris") - def initialize_providers_dataset_uri_resources(self): - """Lazy initialization of provider dataset URI handlers, factories, converters etc.""" + @provider_info_cache("asset_uris") + def initialize_providers_asset_uri_resources(self): + """Lazy initialization of provider asset URI handlers, factories, converters etc.""" self.initialize_providers_list() - self._discover_dataset_uri_resources() + self._discover_asset_uri_resources() @provider_info_cache("hook_lineage_writers") @provider_info_cache("taskflow_decorators") @@ -882,9 +882,9 @@ def _discover_filesystems(self) -> None: self._fs_set.add(fs_module_name) self._fs_set = set(sorted(self._fs_set)) - def _discover_dataset_uri_resources(self) -> None: - """Discovers and registers dataset URI handlers, factories, and converters for all providers.""" - from airflow.datasets import normalize_noop + def _discover_asset_uri_resources(self) -> None: + """Discovers and registers asset URI handlers, factories, and converters for all providers.""" + from airflow.assets import normalize_noop def _safe_register_resource( provider_package_name: str, @@ -908,24 +908,24 @@ def _safe_register_resource( resource_registry.update((scheme, resource) for scheme in schemes_list) for provider_name, provider in self._provider_dict.items(): - for uri_info in provider.data.get("dataset-uris", []): + for uri_info in provider.data.get("asset-uris", []): if "schemes" not in uri_info or "handler" not in uri_info: continue # Both schemas and handler must be explicitly set, handler can be set to null common_args = {"schemes_list": uri_info["schemes"], "provider_package_name": provider_name} _safe_register_resource( resource_path=uri_info["handler"], - resource_registry=self._dataset_uri_handlers, + resource_registry=self._asset_uri_handlers, default_resource=normalize_noop, **common_args, ) _safe_register_resource( 
resource_path=uri_info.get("factory"), - resource_registry=self._dataset_factories, + resource_registry=self._asset_factories, **common_args, ) _safe_register_resource( resource_path=uri_info.get("to_openlineage_converter"), - resource_registry=self._dataset_to_openlineage_converters, + resource_registry=self._asset_to_openlineage_converters, **common_args, ) @@ -1325,21 +1325,21 @@ def filesystem_module_names(self) -> list[str]: return sorted(self._fs_set) @property - def dataset_factories(self) -> dict[str, Callable[..., Dataset]]: - self.initialize_providers_dataset_uri_resources() - return self._dataset_factories + def asset_factories(self) -> dict[str, Callable[..., Asset]]: + self.initialize_providers_asset_uri_resources() + return self._asset_factories @property - def dataset_uri_handlers(self) -> dict[str, Callable[[SplitResult], SplitResult]]: - self.initialize_providers_dataset_uri_resources() - return self._dataset_uri_handlers + def asset_uri_handlers(self) -> dict[str, Callable[[SplitResult], SplitResult]]: + self.initialize_providers_asset_uri_resources() + return self._asset_uri_handlers @property - def dataset_to_openlineage_converters( + def asset_to_openlineage_converters( self, ) -> dict[str, Callable]: - self.initialize_providers_dataset_uri_resources() - return self._dataset_to_openlineage_converters + self.initialize_providers_asset_uri_resources() + return self._asset_to_openlineage_converters @property def provider_configs(self) -> list[tuple[str, dict[str, Any]]]: diff --git a/airflow/reproducible_build.yaml b/airflow/reproducible_build.yaml index 1bf308b87a705..8a35282492059 100644 --- a/airflow/reproducible_build.yaml +++ b/airflow/reproducible_build.yaml @@ -1,2 +1,2 @@ -release-notes-hash: 828fa8d5e93e215963c0a3e52e7f1e3d -source-date-epoch: 1727075869 +release-notes-hash: cc9c5c2ea1cade5d714aa4832587e13a +source-date-epoch: 1727595745 diff --git a/airflow/security/permissions.py b/airflow/security/permissions.py index 45b56c342b44e..acd245865a4ad 100644 --- a/airflow/security/permissions.py +++ b/airflow/security/permissions.py @@ -33,7 +33,7 @@ RESOURCE_DAG_RUN_PREFIX = "DAG Run:" RESOURCE_DAG_WARNING = "DAG Warnings" RESOURCE_CLUSTER_ACTIVITY = "Cluster Activity" -RESOURCE_DATASET = "Datasets" +RESOURCE_ASSET = "Assets" RESOURCE_DOCS = "Documentation" RESOURCE_DOCS_MENU = "Docs" RESOURCE_IMPORT_ERROR = "ImportError" diff --git a/airflow/serialization/dag_dependency.py b/airflow/serialization/dag_dependency.py index bff1b39ebe04b..bede95ba9235b 100644 --- a/airflow/serialization/dag_dependency.py +++ b/airflow/serialization/dag_dependency.py @@ -36,7 +36,7 @@ class DagDependency: def node_id(self): """Node ID for graph rendering.""" val = f"{self.dependency_type}" - if self.dependency_type not in ("dataset", "dataset-alias"): + if self.dependency_type not in ("asset", "asset-alias"): val += f":{self.source}:{self.target}" if self.dependency_id: val += f":{self.dependency_id}" diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py index 49a3de3d774c4..dd63366b8a958 100644 --- a/airflow/serialization/enums.py +++ b/airflow/serialization/enums.py @@ -37,8 +37,8 @@ class DagAttributeTypes(str, Enum): """Enum of supported attribute types of DAG.""" DAG = "dag" - DATASET_EVENT_ACCESSORS = "dataset_event_accessors" - DATASET_EVENT_ACCESSOR = "dataset_event_accessor" + ASSET_EVENT_ACCESSORS = "asset_event_accessors" + ASSET_EVENT_ACCESSOR = "asset_event_accessor" OP = "operator" DATETIME = "datetime" TIMEDELTA = "timedelta" @@ -55,16 +55,15 @@ 
class DagAttributeTypes(str, Enum): EDGE_INFO = "edgeinfo" PARAM = "param" XCOM_REF = "xcomref" - DATASET = "dataset" - DATASET_ALIAS = "dataset_alias" - DATASET_ANY = "dataset_any" - DATASET_ALL = "dataset_all" + ASSET = "asset" + ASSET_ALIAS = "asset_alias" + ASSET_ANY = "asset_any" + ASSET_ALL = "asset_all" SIMPLE_TASK_INSTANCE = "simple_task_instance" BASE_JOB = "Job" TASK_INSTANCE = "task_instance" DAG_RUN = "dag_run" DAG_MODEL = "dag_model" - DATA_SET = "data_set" LOG_TEMPLATE = "log_template" CONNECTION = "connection" TASK_CONTEXT = "task_context" diff --git a/airflow/serialization/pydantic/dataset.py b/airflow/serialization/pydantic/asset.py similarity index 68% rename from airflow/serialization/pydantic/dataset.py rename to airflow/serialization/pydantic/asset.py index 0c233a3fd67c6..29806d3bdf911 100644 --- a/airflow/serialization/pydantic/dataset.py +++ b/airflow/serialization/pydantic/asset.py @@ -20,8 +20,8 @@ from pydantic import BaseModel as BaseModelPydantic, ConfigDict -class DagScheduleDatasetReferencePydantic(BaseModelPydantic): - """Serializable version of the DagScheduleDatasetReference ORM SqlAlchemyModel used by internal API.""" +class DagScheduleAssetReferencePydantic(BaseModelPydantic): + """Serializable version of the DagScheduleAssetReference ORM SqlAlchemyModel used by internal API.""" dataset_id: int dag_id: str @@ -31,8 +31,8 @@ class DagScheduleDatasetReferencePydantic(BaseModelPydantic): model_config = ConfigDict(from_attributes=True) -class TaskOutletDatasetReferencePydantic(BaseModelPydantic): - """Serializable version of the TaskOutletDatasetReference ORM SqlAlchemyModel used by internal API.""" +class TaskOutletAssetReferencePydantic(BaseModelPydantic): + """Serializable version of the TaskOutletAssetReference ORM SqlAlchemyModel used by internal API.""" dataset_id: int dag_id: str @@ -43,8 +43,8 @@ class TaskOutletDatasetReferencePydantic(BaseModelPydantic): model_config = ConfigDict(from_attributes=True) -class DatasetPydantic(BaseModelPydantic): - """Serializable representation of the Dataset ORM SqlAlchemyModel used by internal API.""" +class AssetPydantic(BaseModelPydantic): + """Serializable representation of the Asset ORM SqlAlchemyModel used by internal API.""" id: int uri: str @@ -53,14 +53,14 @@ class DatasetPydantic(BaseModelPydantic): updated_at: datetime is_orphaned: bool - consuming_dags: List[DagScheduleDatasetReferencePydantic] - producing_tasks: List[TaskOutletDatasetReferencePydantic] + consuming_dags: List[DagScheduleAssetReferencePydantic] + producing_tasks: List[TaskOutletAssetReferencePydantic] model_config = ConfigDict(from_attributes=True) -class DatasetEventPydantic(BaseModelPydantic): - """Serializable representation of the DatasetEvent ORM SqlAlchemyModel used by internal API.""" +class AssetEventPydantic(BaseModelPydantic): + """Serializable representation of the AssetEvent ORM SqlAlchemyModel used by internal API.""" id: int dataset_id: Optional[int] @@ -70,6 +70,6 @@ class DatasetEventPydantic(BaseModelPydantic): source_run_id: Optional[str] source_map_index: Optional[int] timestamp: datetime - dataset: Optional[DatasetPydantic] + dataset: Optional[AssetPydantic] model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True) diff --git a/airflow/serialization/pydantic/dag_run.py b/airflow/serialization/pydantic/dag_run.py index a3a53c6d941f4..86857452e8310 100644 --- a/airflow/serialization/pydantic/dag_run.py +++ b/airflow/serialization/pydantic/dag_run.py @@ -22,8 +22,8 @@ from pydantic import BaseModel as 
BaseModelPydantic, ConfigDict from airflow.models.dagrun import DagRun +from airflow.serialization.pydantic.asset import AssetEventPydantic from airflow.serialization.pydantic.dag import PydanticDag -from airflow.serialization.pydantic.dataset import DatasetEventPydantic from airflow.utils.types import DagRunTriggeredByType if TYPE_CHECKING: @@ -55,7 +55,7 @@ class DagRunPydantic(BaseModelPydantic): dag_hash: Optional[str] updated_at: Optional[datetime] dag: Optional[PydanticDag] - consumed_dataset_events: List[DatasetEventPydantic] # noqa: UP006 + consumed_dataset_events: List[AssetEventPydantic] # noqa: UP006 log_template_id: Optional[int] triggered_by: Optional[DagRunTriggeredByType] diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py index 549b03680df83..caf44bea4c673 100644 --- a/airflow/serialization/pydantic/taskinstance.py +++ b/airflow/serialization/pydantic/taskinstance.py @@ -509,8 +509,8 @@ def command_as_list( cfg_path=cfg_path, ) - def _register_dataset_changes(self, *, events, session: Session | None = None) -> None: - TaskInstance._register_dataset_changes(self=self, events=events, session=session) # type: ignore[arg-type] + def _register_asset_changes(self, *, events, session: Session | None = None) -> None: + TaskInstance._register_asset_changes(self=self, events=events, session=session) # type: ignore[arg-type] def defer_task(self, exception: TaskDeferred, session: Session | None = None): """Defer task.""" diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index c9c1f11835277..a4801b767acc5 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -34,16 +34,16 @@ from pendulum.tz.timezone import FixedTimezone, Timezone from airflow import macros +from airflow.assets import ( + Asset, + AssetAlias, + AssetAll, + AssetAny, + BaseAsset, + _AssetAliasCondition, +) from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest from airflow.compat.functools import cache -from airflow.datasets import ( - BaseDataset, - Dataset, - DatasetAlias, - DatasetAll, - DatasetAny, - _DatasetAliasCondition, -) from airflow.exceptions import AirflowException, SerializationError, TaskDeferred from airflow.jobs.job import Job from airflow.models import Trigger @@ -63,9 +63,9 @@ from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.helpers import serialize_template_field from airflow.serialization.json_schema import load_dag_schema +from airflow.serialization.pydantic.asset import AssetPydantic from airflow.serialization.pydantic.dag import DagModelPydantic from airflow.serialization.pydantic.dag_run import DagRunPydantic -from airflow.serialization.pydantic.dataset import DatasetPydantic from airflow.serialization.pydantic.job import JobPydantic from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic from airflow.serialization.pydantic.tasklog import LogTemplatePydantic @@ -246,38 +246,38 @@ def __str__(self) -> str: ) -def encode_dataset_condition(var: BaseDataset) -> dict[str, Any]: +def encode_asset_condition(var: BaseAsset) -> dict[str, Any]: """ - Encode a dataset condition. + Encode an asset condition. 
:meta private: """ - if isinstance(var, Dataset): - return {"__type": DAT.DATASET, "uri": var.uri, "extra": var.extra} - if isinstance(var, DatasetAlias): - return {"__type": DAT.DATASET_ALIAS, "name": var.name} - if isinstance(var, DatasetAll): - return {"__type": DAT.DATASET_ALL, "objects": [encode_dataset_condition(x) for x in var.objects]} - if isinstance(var, DatasetAny): - return {"__type": DAT.DATASET_ANY, "objects": [encode_dataset_condition(x) for x in var.objects]} + if isinstance(var, Asset): + return {"__type": DAT.ASSET, "uri": var.uri, "extra": var.extra} + if isinstance(var, AssetAlias): + return {"__type": DAT.ASSET_ALIAS, "name": var.name} + if isinstance(var, AssetAll): + return {"__type": DAT.ASSET_ALL, "objects": [encode_asset_condition(x) for x in var.objects]} + if isinstance(var, AssetAny): + return {"__type": DAT.ASSET_ANY, "objects": [encode_asset_condition(x) for x in var.objects]} raise ValueError(f"serialization not implemented for {type(var).__name__!r}") -def decode_dataset_condition(var: dict[str, Any]) -> BaseDataset: +def decode_asset_condition(var: dict[str, Any]) -> BaseAsset: """ Decode a previously serialized dataset condition. :meta private: """ dat = var["__type"] - if dat == DAT.DATASET: - return Dataset(var["uri"], extra=var["extra"]) - if dat == DAT.DATASET_ALL: - return DatasetAll(*(decode_dataset_condition(x) for x in var["objects"])) - if dat == DAT.DATASET_ANY: - return DatasetAny(*(decode_dataset_condition(x) for x in var["objects"])) - if dat == DAT.DATASET_ALIAS: - return DatasetAlias(name=var["name"]) + if dat == DAT.ASSET: + return Asset(var["uri"], extra=var["extra"]) + if dat == DAT.ASSET_ALL: + return AssetAll(*(decode_asset_condition(x) for x in var["objects"])) + if dat == DAT.ASSET_ANY: + return AssetAny(*(decode_asset_condition(x) for x in var["objects"])) + if dat == DAT.ASSET_ALIAS: + return AssetAlias(name=var["name"]) raise ValueError(f"deserialization not implemented for DAT {dat!r}") @@ -285,23 +285,18 @@ def encode_outlet_event_accessor(var: OutletEventAccessor) -> dict[str, Any]: raw_key = var.raw_key return { "extra": var.extra, - "dataset_alias_events": var.dataset_alias_events, + "asset_alias_events": var.asset_alias_events, "raw_key": BaseSerialization.serialize(raw_key), } def decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor: - # This is added for compatibility. The attribute used to be dataset_alias_event and - # is now dataset_alias_events. 
- if dataset_alias_event := var.get("dataset_alias_event", None): - dataset_alias_events = [dataset_alias_event] - else: - dataset_alias_events = var.get("dataset_alias_events", []) + asset_alias_events = var.get("asset_alias_events", []) outlet_event_accessor = OutletEventAccessor( extra=var["extra"], raw_key=BaseSerialization.deserialize(var["raw_key"]), - dataset_alias_events=dataset_alias_events, + asset_alias_events=asset_alias_events, ) return outlet_event_accessor @@ -482,7 +477,7 @@ def deref(self, dag: DAG) -> ExpandInput: DagRun: DagRunPydantic, DagModel: DagModelPydantic, LogTemplate: LogTemplatePydantic, - Dataset: DatasetPydantic, + Asset: AssetPydantic, Trigger: TriggerPydantic, } _type_to_class: dict[DAT | str, list] = { @@ -491,7 +486,7 @@ def deref(self, dag: DAG) -> ExpandInput: DAT.DAG_RUN: [DagRunPydantic, DagRun], DAT.DAG_MODEL: [DagModelPydantic, DagModel], DAT.LOG_TEMPLATE: [LogTemplatePydantic, LogTemplate], - DAT.DATA_SET: [DatasetPydantic, Dataset], + DAT.ASSET: [AssetPydantic, Asset], DAT.TRIGGER: [TriggerPydantic, Trigger], } _class_to_type = {cls_: type_ for type_, classes in _type_to_class.items() for cls_ in classes} @@ -661,12 +656,12 @@ def serialize( elif isinstance(var, OutletEventAccessors): return cls._encode( cls.serialize(var._dict, strict=strict, use_pydantic_models=use_pydantic_models), # type: ignore[attr-defined] - type_=DAT.DATASET_EVENT_ACCESSORS, + type_=DAT.ASSET_EVENT_ACCESSORS, ) elif isinstance(var, OutletEventAccessor): return cls._encode( encode_outlet_event_accessor(var), - type_=DAT.DATASET_EVENT_ACCESSOR, + type_=DAT.ASSET_EVENT_ACCESSOR, ) elif isinstance(var, DAG): return cls._encode(SerializedDAG.serialize_dag(var), type_=DAT.DAG) @@ -744,8 +739,8 @@ def serialize( return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF) elif isinstance(var, LazySelectSequence): return cls.serialize(list(var)) - elif isinstance(var, BaseDataset): - serialized_dataset = encode_dataset_condition(var) + elif isinstance(var, BaseAsset): + serialized_dataset = encode_asset_condition(var) return cls._encode(serialized_dataset, type_=serialized_dataset.pop("__type")) elif isinstance(var, SimpleTaskInstance): return cls._encode( @@ -826,11 +821,11 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any: return Context(**d) elif type_ == DAT.DICT: return {k: cls.deserialize(v, use_pydantic_models) for k, v in var.items()} - elif type_ == DAT.DATASET_EVENT_ACCESSORS: + elif type_ == DAT.ASSET_EVENT_ACCESSORS: d = OutletEventAccessors() # type: ignore[assignment] d._dict = cls.deserialize(var) # type: ignore[attr-defined] return d - elif type_ == DAT.DATASET_EVENT_ACCESSOR: + elif type_ == DAT.ASSET_EVENT_ACCESSOR: return decode_outlet_event_accessor(var) elif type_ == DAT.DAG: return SerializedDAG.deserialize_dag(var) @@ -872,14 +867,14 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any: return cls._deserialize_param(var) elif type_ == DAT.XCOM_REF: return _XComRef(var) # Delay deserializing XComArg objects until we have the entire DAG. 
- elif type_ == DAT.DATASET: - return Dataset(**var) - elif type_ == DAT.DATASET_ALIAS: - return DatasetAlias(**var) - elif type_ == DAT.DATASET_ANY: - return DatasetAny(*(decode_dataset_condition(x) for x in var["objects"])) - elif type_ == DAT.DATASET_ALL: - return DatasetAll(*(decode_dataset_condition(x) for x in var["objects"])) + elif type_ == DAT.ASSET: + return Asset(**var) + elif type_ == DAT.ASSET_ALIAS: + return AssetAlias(**var) + elif type_ == DAT.ASSET_ANY: + return AssetAny(*(decode_asset_condition(x) for x in var["objects"])) + elif type_ == DAT.ASSET_ALL: + return AssetAll(*(decode_asset_condition(x) for x in var["objects"])) elif type_ == DAT.SIMPLE_TASK_INSTANCE: return SimpleTaskInstance(**cls.deserialize(var)) elif type_ == DAT.CONNECTION: @@ -1041,17 +1036,17 @@ def detect_task_dependencies(task: Operator) -> list[DagDependency]: ) ) for obj in task.outlets or []: - if isinstance(obj, Dataset): + if isinstance(obj, Asset): deps.append( DagDependency( source=task.dag_id, - target="dataset", - dependency_type="dataset", + target="asset", + dependency_type="asset", dependency_id=obj.uri, ) ) - elif isinstance(obj, DatasetAlias): - cond = _DatasetAliasCondition(obj.name) + elif isinstance(obj, AssetAlias): + cond = _AssetAliasCondition(obj.name) deps.extend(cond.iter_dag_dependencies(source=task.dag_id, target="")) return deps @@ -1062,7 +1057,7 @@ def detect_dag_dependencies(dag: DAG | None) -> Iterable[DagDependency]: if not dag: return - yield from dag.timetable.dataset_condition.iter_dag_dependencies(source="", target=dag.dag_id) + yield from dag.timetable.asset_condition.iter_dag_dependencies(source="", target=dag.dag_id) class SerializedBaseOperator(BaseOperator, BaseSerialization): diff --git a/airflow/timetables/datasets.py b/airflow/timetables/assets.py similarity index 71% rename from airflow/timetables/datasets.py rename to airflow/timetables/assets.py index 05db0d66cc2df..b158555590ad5 100644 --- a/airflow/timetables/datasets.py +++ b/airflow/timetables/assets.py @@ -19,9 +19,9 @@ import typing -from airflow.datasets import BaseDataset, DatasetAll +from airflow.assets import AssetAll, BaseAsset from airflow.exceptions import AirflowTimetableInvalid -from airflow.timetables.simple import DatasetTriggeredTimetable as DatasetTriggeredSchedule +from airflow.timetables.simple import AssetTriggeredTimetable from airflow.utils.types import DagRunType if typing.TYPE_CHECKING: @@ -29,56 +29,56 @@ import pendulum - from airflow.datasets import Dataset + from airflow.assets import Asset from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable -class DatasetOrTimeSchedule(DatasetTriggeredSchedule): +class AssetOrTimeSchedule(AssetTriggeredTimetable): """Combine time-based scheduling with event-based scheduling.""" def __init__( self, *, timetable: Timetable, - datasets: Collection[Dataset] | BaseDataset, + assets: Collection[Asset] | BaseAsset, ) -> None: self.timetable = timetable - if isinstance(datasets, BaseDataset): - self.dataset_condition = datasets + if isinstance(assets, BaseAsset): + self.asset_condition = assets else: - self.dataset_condition = DatasetAll(*datasets) + self.asset_condition = AssetAll(*assets) - self.description = f"Triggered by datasets or {timetable.description}" + self.description = f"Triggered by assets or {timetable.description}" self.periodic = timetable.periodic self.can_be_scheduled = timetable.can_be_scheduled self.active_runs_limit = timetable.active_runs_limit @classmethod def deserialize(cls, data: 
dict[str, typing.Any]) -> Timetable: - from airflow.serialization.serialized_objects import decode_dataset_condition, decode_timetable + from airflow.serialization.serialized_objects import decode_asset_condition, decode_timetable return cls( - datasets=decode_dataset_condition(data["dataset_condition"]), + assets=decode_asset_condition(data["asset_condition"]), timetable=decode_timetable(data["timetable"]), ) def serialize(self) -> dict[str, typing.Any]: - from airflow.serialization.serialized_objects import encode_dataset_condition, encode_timetable + from airflow.serialization.serialized_objects import encode_asset_condition, encode_timetable return { - "dataset_condition": encode_dataset_condition(self.dataset_condition), + "asset_condition": encode_asset_condition(self.asset_condition), "timetable": encode_timetable(self.timetable), } def validate(self) -> None: - if isinstance(self.timetable, DatasetTriggeredSchedule): - raise AirflowTimetableInvalid("cannot nest dataset timetables") - if not isinstance(self.dataset_condition, BaseDataset): - raise AirflowTimetableInvalid("all elements in 'datasets' must be datasets") + if isinstance(self.timetable, AssetTriggeredTimetable): + raise AirflowTimetableInvalid("cannot nest asset timetables") + if not isinstance(self.asset_condition, BaseAsset): + raise AirflowTimetableInvalid("all elements in 'assets' must be assets") @property def summary(self) -> str: - return f"Dataset or {self.timetable.summary}" + return f"Asset or {self.timetable.summary}" def infer_manual_data_interval(self, *, run_after: pendulum.DateTime) -> DataInterval: return self.timetable.infer_manual_data_interval(run_after=run_after) diff --git a/airflow/timetables/base.py b/airflow/timetables/base.py index 5d97591856b5a..64a2612026517 100644 --- a/airflow/timetables/base.py +++ b/airflow/timetables/base.py @@ -18,20 +18,20 @@ from typing import TYPE_CHECKING, Any, Iterator, NamedTuple, Sequence -from airflow.datasets import BaseDataset +from airflow.assets import BaseAsset from airflow.typing_compat import Protocol, runtime_checkable if TYPE_CHECKING: from pendulum import DateTime - from airflow.datasets import Dataset, DatasetAlias + from airflow.assets import Asset, AssetAlias from airflow.serialization.dag_dependency import DagDependency from airflow.utils.types import DagRunType -class _NullDataset(BaseDataset): +class _NullAsset(BaseAsset): """ - Sentinel type that represents "no datasets". + Sentinel type that represents "no assets". This is only implemented to make typing easier in timetables, and not expected to be used anywhere else. 
@@ -42,10 +42,10 @@ class _NullDataset(BaseDataset): def __bool__(self) -> bool: return False - def __or__(self, other: BaseDataset) -> BaseDataset: + def __or__(self, other: BaseAsset) -> BaseAsset: return NotImplemented - def __and__(self, other: BaseDataset) -> BaseDataset: + def __and__(self, other: BaseAsset) -> BaseAsset: return NotImplemented def as_expression(self) -> Any: @@ -54,10 +54,10 @@ def as_expression(self) -> Any: def evaluate(self, statuses: dict[str, bool]) -> bool: return False - def iter_datasets(self) -> Iterator[tuple[str, Dataset]]: + def iter_assets(self) -> Iterator[tuple[str, Asset]]: return iter(()) - def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]: + def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: return iter(()) def iter_dag_dependencies(self, source, target) -> Iterator[DagDependency]: @@ -189,11 +189,11 @@ class Timetable(Protocol): as for :class:`~airflow.timetable.simple.ContinuousTimetable`. """ - dataset_condition: BaseDataset = _NullDataset() - """The dataset condition that triggers a DAG using this timetable. + asset_condition: BaseAsset = _NullAsset() + """The asset condition that triggers a DAG using this timetable. - If this is not *None*, this should be a dataset, or a combination of, that - controls the DAG's dataset triggers. + If this is not *None*, this should be an asset, or a combination of, that + controls the DAG's asset triggers. """ @classmethod diff --git a/airflow/timetables/simple.py b/airflow/timetables/simple.py index ad166a641378a..5a931b40dd11d 100644 --- a/airflow/timetables/simple.py +++ b/airflow/timetables/simple.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Collection, Sequence -from airflow.datasets import DatasetAlias, _DatasetAliasCondition +from airflow.assets import AssetAlias, _AssetAliasCondition from airflow.timetables.base import DagRunInfo, DataInterval, Timetable from airflow.utils import timezone @@ -26,8 +26,8 @@ from pendulum import DateTime from sqlalchemy import Session - from airflow.datasets import BaseDataset - from airflow.models.dataset import DatasetEvent + from airflow.assets import BaseAsset + from airflow.models.asset import AssetEvent from airflow.timetables.base import TimeRestriction from airflow.utils.types import DagRunType @@ -152,44 +152,44 @@ def next_dagrun_info( return DagRunInfo.interval(start, end) -class DatasetTriggeredTimetable(_TrivialTimetable): +class AssetTriggeredTimetable(_TrivialTimetable): """ Timetable that never schedules anything. - This should not be directly used anywhere, but only set if a DAG is triggered by datasets. + This should not be directly used anywhere, but only set if a DAG is triggered by assets. 
:meta private: """ - UNRESOLVED_ALIAS_SUMMARY = "Unresolved DatasetAlias" + UNRESOLVED_ALIAS_SUMMARY = "Unresolved AssetAlias" - description: str = "Triggered by datasets" + description: str = "Triggered by assets" - def __init__(self, datasets: BaseDataset) -> None: + def __init__(self, assets: BaseAsset) -> None: super().__init__() - self.dataset_condition = datasets - if isinstance(self.dataset_condition, DatasetAlias): - self.dataset_condition = _DatasetAliasCondition(self.dataset_condition.name) + self.asset_condition = assets + if isinstance(self.asset_condition, AssetAlias): + self.asset_condition = _AssetAliasCondition(self.asset_condition.name) - if not next(self.dataset_condition.iter_datasets(), False): - self._summary = DatasetTriggeredTimetable.UNRESOLVED_ALIAS_SUMMARY + if not next(self.asset_condition.iter_assets(), False): + self._summary = AssetTriggeredTimetable.UNRESOLVED_ALIAS_SUMMARY else: - self._summary = "Dataset" + self._summary = "Asset" @classmethod def deserialize(cls, data: dict[str, Any]) -> Timetable: - from airflow.serialization.serialized_objects import decode_dataset_condition + from airflow.serialization.serialized_objects import decode_asset_condition - return cls(decode_dataset_condition(data["dataset_condition"])) + return cls(decode_asset_condition(data["asset_condition"])) @property def summary(self) -> str: return self._summary def serialize(self) -> dict[str, Any]: - from airflow.serialization.serialized_objects import encode_dataset_condition + from airflow.serialization.serialized_objects import encode_asset_condition - return {"dataset_condition": encode_dataset_condition(self.dataset_condition)} + return {"asset_condition": encode_asset_condition(self.asset_condition)} def generate_run_id( self, @@ -198,7 +198,7 @@ def generate_run_id( logical_date: DateTime, data_interval: DataInterval | None, session: Session | None = None, - events: Collection[DatasetEvent] | None = None, + events: Collection[AssetEvent] | None = None, **extra, ) -> str: from airflow.models.dagrun import DagRun @@ -208,7 +208,7 @@ def generate_run_id( def data_interval_for_events( self, logical_date: DateTime, - events: Collection[DatasetEvent], + events: Collection[AssetEvent], ) -> DataInterval: if not events: return DataInterval(logical_date, logical_date) diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index b8021fed9be3c..46694939ed74e 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -1,20 +1,20 @@ // generated with @7nohe/openapi-react-query-codegen@1.6.0 import { UseQueryResult } from "@tanstack/react-query"; -import { DagService, DatasetService } from "../requests/services.gen"; +import { AssetService, DagService } from "../requests/services.gen"; import { DagRunState } from "../requests/types.gen"; -export type DatasetServiceNextRunDatasetsUiNextRunDatasetsDagIdGetDefaultResponse = +export type AssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetDefaultResponse = Awaited< - ReturnType + ReturnType >; -export type DatasetServiceNextRunDatasetsUiNextRunDatasetsDagIdGetQueryResult< - TData = DatasetServiceNextRunDatasetsUiNextRunDatasetsDagIdGetDefaultResponse, +export type AssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetQueryResult< + TData = AssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetDefaultResponse, TError = unknown, > = UseQueryResult; -export const useDatasetServiceNextRunDatasetsUiNextRunDatasetsDagIdGetKey = - 
"DatasetServiceNextRunDatasetsUiNextRunDatasetsDagIdGet"; -export const UseDatasetServiceNextRunDatasetsUiNextRunDatasetsDagIdGetKeyFn = ( +export const useAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetKey = + "AssetServiceNextRunAssetsUiNextRunDatasetsDagIdGet"; +export const UseAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetKeyFn = ( { dagId, }: { @@ -22,7 +22,7 @@ export const UseDatasetServiceNextRunDatasetsUiNextRunDatasetsDagIdGetKeyFn = ( }, queryKey?: Array, ) => [ - useDatasetServiceNextRunDatasetsUiNextRunDatasetsDagIdGetKey, + useAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetKey, ...(queryKey ?? [{ dagId }]), ]; export type DagServiceGetDagsPublicDagsGetDefaultResponse = Awaited< diff --git a/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow/ui/openapi-gen/queries/prefetch.ts index 6dd99f96b8425..7de7282a9bd01 100644 --- a/airflow/ui/openapi-gen/queries/prefetch.ts +++ b/airflow/ui/openapi-gen/queries/prefetch.ts @@ -1,34 +1,32 @@ // generated with @7nohe/openapi-react-query-codegen@1.6.0 import { type QueryClient } from "@tanstack/react-query"; -import { DagService, DatasetService } from "../requests/services.gen"; +import { AssetService, DagService } from "../requests/services.gen"; import { DagRunState } from "../requests/types.gen"; import * as Common from "./common"; /** - * Next Run Datasets + * Next Run Assets * @param data The data for the request. * @param data.dagId * @returns unknown Successful Response * @throws ApiError */ -export const prefetchUseDatasetServiceNextRunDatasetsUiNextRunDatasetsDagIdGet = - ( - queryClient: QueryClient, - { - dagId, - }: { - dagId: string; - }, - ) => - queryClient.prefetchQuery({ - queryKey: - Common.UseDatasetServiceNextRunDatasetsUiNextRunDatasetsDagIdGetKeyFn({ - dagId, - }), - queryFn: () => - DatasetService.nextRunDatasetsUiNextRunDatasetsDagIdGet({ dagId }), - }); +export const prefetchUseAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGet = ( + queryClient: QueryClient, + { + dagId, + }: { + dagId: string; + }, +) => + queryClient.prefetchQuery({ + queryKey: Common.UseAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetKeyFn( + { dagId }, + ), + queryFn: () => + AssetService.nextRunAssetsUiNextRunDatasetsDagIdGet({ dagId }), + }); /** * Get Dags * Get all DAGs. diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index b771fccfeb947..7cbaac5b2c77d 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -6,19 +6,19 @@ import { UseQueryOptions, } from "@tanstack/react-query"; -import { DagService, DatasetService } from "../requests/services.gen"; +import { AssetService, DagService } from "../requests/services.gen"; import { DAGPatchBody, DagRunState } from "../requests/types.gen"; import * as Common from "./common"; /** - * Next Run Datasets + * Next Run Assets * @param data The data for the request. 
* @param data.dagId * @returns unknown Successful Response * @throws ApiError */ -export const useDatasetServiceNextRunDatasetsUiNextRunDatasetsDagIdGet = < - TData = Common.DatasetServiceNextRunDatasetsUiNextRunDatasetsDagIdGetDefaultResponse, +export const useAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGet = < + TData = Common.AssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetDefaultResponse, TError = unknown, TQueryKey extends Array = unknown[], >( @@ -31,15 +31,12 @@ export const useDatasetServiceNextRunDatasetsUiNextRunDatasetsDagIdGet = < options?: Omit, "queryKey" | "queryFn">, ) => useQuery({ - queryKey: - Common.UseDatasetServiceNextRunDatasetsUiNextRunDatasetsDagIdGetKeyFn( - { dagId }, - queryKey, - ), + queryKey: Common.UseAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetKeyFn( + { dagId }, + queryKey, + ), queryFn: () => - DatasetService.nextRunDatasetsUiNextRunDatasetsDagIdGet({ - dagId, - }) as TData, + AssetService.nextRunAssetsUiNextRunDatasetsDagIdGet({ dagId }) as TData, ...options, }); /** diff --git a/airflow/ui/openapi-gen/queries/suspense.ts b/airflow/ui/openapi-gen/queries/suspense.ts index 7743ce92d2855..18dba7acb4b5b 100644 --- a/airflow/ui/openapi-gen/queries/suspense.ts +++ b/airflow/ui/openapi-gen/queries/suspense.ts @@ -1,43 +1,39 @@ // generated with @7nohe/openapi-react-query-codegen@1.6.0 import { UseQueryOptions, useSuspenseQuery } from "@tanstack/react-query"; -import { DagService, DatasetService } from "../requests/services.gen"; +import { AssetService, DagService } from "../requests/services.gen"; import { DagRunState } from "../requests/types.gen"; import * as Common from "./common"; /** - * Next Run Datasets + * Next Run Assets * @param data The data for the request. * @param data.dagId * @returns unknown Successful Response * @throws ApiError */ -export const useDatasetServiceNextRunDatasetsUiNextRunDatasetsDagIdGetSuspense = - < - TData = Common.DatasetServiceNextRunDatasetsUiNextRunDatasetsDagIdGetDefaultResponse, - TError = unknown, - TQueryKey extends Array = unknown[], - >( - { - dagId, - }: { - dagId: string; - }, - queryKey?: TQueryKey, - options?: Omit, "queryKey" | "queryFn">, - ) => - useSuspenseQuery({ - queryKey: - Common.UseDatasetServiceNextRunDatasetsUiNextRunDatasetsDagIdGetKeyFn( - { dagId }, - queryKey, - ), - queryFn: () => - DatasetService.nextRunDatasetsUiNextRunDatasetsDagIdGet({ - dagId, - }) as TData, - ...options, - }); +export const useAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetSuspense = < + TData = Common.AssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetDefaultResponse, + TError = unknown, + TQueryKey extends Array = unknown[], +>( + { + dagId, + }: { + dagId: string; + }, + queryKey?: TQueryKey, + options?: Omit, "queryKey" | "queryFn">, +) => + useSuspenseQuery({ + queryKey: Common.UseAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetKeyFn( + { dagId }, + queryKey, + ), + queryFn: () => + AssetService.nextRunAssetsUiNextRunDatasetsDagIdGet({ dagId }) as TData, + ...options, + }); /** * Get Dags * Get all DAGs. 
diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 37a4d11873acf..5aa5876d112ad 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -3,25 +3,25 @@ import type { CancelablePromise } from "./core/CancelablePromise"; import { OpenAPI } from "./core/OpenAPI"; import { request as __request } from "./core/request"; import type { - NextRunDatasetsUiNextRunDatasetsDagIdGetData, - NextRunDatasetsUiNextRunDatasetsDagIdGetResponse, + NextRunAssetsUiNextRunDatasetsDagIdGetData, + NextRunAssetsUiNextRunDatasetsDagIdGetResponse, GetDagsPublicDagsGetData, GetDagsPublicDagsGetResponse, PatchDagPublicDagsDagIdPatchData, PatchDagPublicDagsDagIdPatchResponse, } from "./types.gen"; -export class DatasetService { +export class AssetService { /** - * Next Run Datasets + * Next Run Assets * @param data The data for the request. * @param data.dagId * @returns unknown Successful Response * @throws ApiError */ - public static nextRunDatasetsUiNextRunDatasetsDagIdGet( - data: NextRunDatasetsUiNextRunDatasetsDagIdGetData, - ): CancelablePromise { + public static nextRunAssetsUiNextRunDatasetsDagIdGet( + data: NextRunAssetsUiNextRunDatasetsDagIdGetData, + ): CancelablePromise { return __request(OpenAPI, { method: "GET", url: "/ui/next_run_datasets/{dag_id}", diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index 16977004e79d6..bc455f63b6449 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -88,11 +88,11 @@ export type ValidationError = { type: string; }; -export type NextRunDatasetsUiNextRunDatasetsDagIdGetData = { +export type NextRunAssetsUiNextRunDatasetsDagIdGetData = { dagId: string; }; -export type NextRunDatasetsUiNextRunDatasetsDagIdGetResponse = { +export type NextRunAssetsUiNextRunDatasetsDagIdGetResponse = { [key: string]: unknown; }; @@ -122,7 +122,7 @@ export type PatchDagPublicDagsDagIdPatchResponse = DAGResponse; export type $OpenApiTs = { "/ui/next_run_datasets/{dag_id}": { get: { - req: NextRunDatasetsUiNextRunDatasetsDagIdGetData; + req: NextRunAssetsUiNextRunDatasetsDagIdGetData; res: { /** * Successful Response diff --git a/airflow/utils/context.py b/airflow/utils/context.py index a72885401f7b2..e5d30b1e2d7d2 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -40,14 +40,14 @@ import lazy_object_proxy from sqlalchemy import select -from airflow.datasets import ( - Dataset, - DatasetAlias, - DatasetAliasEvent, +from airflow.assets import ( + Asset, + AssetAlias, + AssetAliasEvent, extract_event_key, ) from airflow.exceptions import RemovedInAirflow3Warning -from airflow.models.dataset import DatasetAliasModel, DatasetEvent, DatasetModel +from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel from airflow.utils.db import LazySelectSequence from airflow.utils.types import NOTSET @@ -102,7 +102,7 @@ "ti", "tomorrow_ds", "tomorrow_ds_nodash", - "triggering_dataset_events", + "triggering_asset_events", "ts", "ts_nodash", "ts_nodash_with_tz", @@ -165,40 +165,40 @@ def get(self, key: str, default_conn: Any = None) -> Any: @attrs.define() class OutletEventAccessor: """ - Wrapper to access an outlet dataset event in template. + Wrapper to access an outlet asset event in template. 
:meta private: """ - raw_key: str | Dataset | DatasetAlias + raw_key: str | Asset | AssetAlias extra: dict[str, Any] = attrs.Factory(dict) - dataset_alias_events: list[DatasetAliasEvent] = attrs.field(factory=list) - - def add(self, dataset: Dataset | str, extra: dict[str, Any] | None = None) -> None: - """Add a DatasetEvent to an existing Dataset.""" - if isinstance(dataset, str): - dataset_uri = dataset - elif isinstance(dataset, Dataset): - dataset_uri = dataset.uri + asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list) + + def add(self, asset: Asset | str, extra: dict[str, Any] | None = None) -> None: + """Add an AssetEvent to an existing Asset.""" + if isinstance(asset, str): + asset_uri = asset + elif isinstance(asset, Asset): + asset_uri = asset.uri else: return if isinstance(self.raw_key, str): - dataset_alias_name = self.raw_key - elif isinstance(self.raw_key, DatasetAlias): - dataset_alias_name = self.raw_key.name + asset_alias_name = self.raw_key + elif isinstance(self.raw_key, AssetAlias): + asset_alias_name = self.raw_key.name else: return - event = DatasetAliasEvent( - source_alias_name=dataset_alias_name, dest_dataset_uri=dataset_uri, extra=extra or {} + event = AssetAliasEvent( + source_alias_name=asset_alias_name, dest_asset_uri=asset_uri, extra=extra or {} ) - self.dataset_alias_events.append(event) + self.asset_alias_events.append(event) class OutletEventAccessors(Mapping[str, OutletEventAccessor]): """ - Lazy mapping of outlet dataset event accessors. + Lazy mapping of outlet asset event accessors. :meta private: """ @@ -215,53 +215,53 @@ def __iter__(self) -> Iterator[str]: def __len__(self) -> int: return len(self._dict) - def __getitem__(self, key: str | Dataset | DatasetAlias) -> OutletEventAccessor: + def __getitem__(self, key: str | Asset | AssetAlias) -> OutletEventAccessor: event_key = extract_event_key(key) if event_key not in self._dict: self._dict[event_key] = OutletEventAccessor(extra={}, raw_key=key) return self._dict[event_key] -class LazyDatasetEventSelectSequence(LazySelectSequence[DatasetEvent]): +class LazyAssetEventSelectSequence(LazySelectSequence[AssetEvent]): """ - List-like interface to lazily access DatasetEvent rows. + List-like interface to lazily access AssetEvent rows. :meta private: """ @staticmethod def _rebuild_select(stmt: TextClause) -> Select: - return select(DatasetEvent).from_statement(stmt) + return select(AssetEvent).from_statement(stmt) @staticmethod - def _process_row(row: Row) -> DatasetEvent: + def _process_row(row: Row) -> AssetEvent: return row[0] @attrs.define(init=False) -class InletEventsAccessors(Mapping[str, LazyDatasetEventSelectSequence]): +class InletEventsAccessors(Mapping[str, LazyAssetEventSelectSequence]): """ - Lazy mapping for inlet dataset events accessors. + Lazy mapping for inlet asset events accessors. 
:meta private: """ _inlets: list[Any] - _datasets: dict[str, Dataset] - _dataset_aliases: dict[str, DatasetAlias] + _assets: dict[str, Asset] + _asset_aliases: dict[str, AssetAlias] _session: Session def __init__(self, inlets: list, *, session: Session) -> None: self._inlets = inlets self._session = session - self._datasets = {} - self._dataset_aliases = {} + self._assets = {} + self._asset_aliases = {} for inlet in inlets: - if isinstance(inlet, Dataset): - self._datasets[inlet.uri] = inlet - elif isinstance(inlet, DatasetAlias): - self._dataset_aliases[inlet.name] = inlet + if isinstance(inlet, Asset): + self._assets[inlet.uri] = inlet + elif isinstance(inlet, AssetAlias): + self._asset_aliases[inlet.name] = inlet def __iter__(self) -> Iterator[str]: return iter(self._inlets) @@ -269,28 +269,28 @@ def __iter__(self) -> Iterator[str]: def __len__(self) -> int: return len(self._inlets) - def __getitem__(self, key: int | str | Dataset | DatasetAlias) -> LazyDatasetEventSelectSequence: + def __getitem__(self, key: int | str | Asset | AssetAlias) -> LazyAssetEventSelectSequence: if isinstance(key, int): # Support index access; it's easier for trivial cases. obj = self._inlets[key] - if not isinstance(obj, (Dataset, DatasetAlias)): + if not isinstance(obj, (Asset, AssetAlias)): raise IndexError(key) else: obj = key - if isinstance(obj, DatasetAlias): - dataset_alias = self._dataset_aliases[obj.name] - join_clause = DatasetEvent.source_aliases - where_clause = DatasetAliasModel.name == dataset_alias.name - elif isinstance(obj, (Dataset, str)): - dataset = self._datasets[extract_event_key(obj)] - join_clause = DatasetEvent.dataset - where_clause = DatasetModel.uri == dataset.uri + if isinstance(obj, AssetAlias): + asset_alias = self._asset_aliases[obj.name] + join_clause = AssetEvent.source_aliases + where_clause = AssetAliasModel.name == asset_alias.name + elif isinstance(obj, (Asset, str)): + asset = self._assets[extract_event_key(obj)] + join_clause = AssetEvent.dataset + where_clause = AssetModel.uri == asset.uri else: raise ValueError(key) - return LazyDatasetEventSelectSequence.from_select( - select(DatasetEvent).join(join_clause).where(where_clause), - order_by=[DatasetEvent.timestamp], + return LazyAssetEventSelectSequence.from_select( + select(AssetEvent).join(join_clause).where(where_clause), + order_by=[AssetEvent.timestamp], session=self._session, ) diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi index 658aac5839ec5..4dc4659548ac0 100644 --- a/airflow/utils/context.pyi +++ b/airflow/utils/context.pyi @@ -31,16 +31,16 @@ from typing import Any, Collection, Container, Iterable, Iterator, Mapping, Sequ from pendulum import DateTime from sqlalchemy.orm import Session +from airflow.assets import Asset, AssetAlias, AssetAliasEvent from airflow.configuration import AirflowConfigParser -from airflow.datasets import Dataset, DatasetAlias, DatasetAliasEvent +from airflow.models.asset import AssetEvent from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG from airflow.models.dagrun import DagRun -from airflow.models.dataset import DatasetEvent from airflow.models.param import ParamsDict from airflow.models.taskinstance import TaskInstance +from airflow.serialization.pydantic.asset import AssetEventPydantic from airflow.serialization.pydantic.dag_run import DagRunPydantic -from airflow.serialization.pydantic.dataset import DatasetEventPydantic from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic from 
airflow.typing_compat import TypedDict @@ -62,31 +62,31 @@ class OutletEventAccessor: self, *, extra: dict[str, Any], - raw_key: str | Dataset | DatasetAlias, - dataset_alias_events: list[DatasetAliasEvent], + raw_key: str | Asset | AssetAlias, + asset_alias_events: list[AssetAliasEvent], ) -> None: ... - def add(self, dataset: Dataset | str, extra: dict[str, Any] | None = None) -> None: ... + def add(self, asset: Asset | str, extra: dict[str, Any] | None = None) -> None: ... extra: dict[str, Any] - raw_key: str | Dataset | DatasetAlias - dataset_alias_events: list[DatasetAliasEvent] + raw_key: str | Asset | AssetAlias + asset_alias_events: list[AssetAliasEvent] class OutletEventAccessors(Mapping[str, OutletEventAccessor]): def __iter__(self) -> Iterator[str]: ... def __len__(self) -> int: ... - def __getitem__(self, key: str | Dataset | DatasetAlias) -> OutletEventAccessor: ... + def __getitem__(self, key: str | Asset | AssetAlias) -> OutletEventAccessor: ... -class InletEventsAccessor(Sequence[DatasetEvent]): +class InletEventsAccessor(Sequence[AssetEvent]): @overload - def __getitem__(self, key: int) -> DatasetEvent: ... + def __getitem__(self, key: int) -> AssetEvent: ... @overload - def __getitem__(self, key: slice) -> Sequence[DatasetEvent]: ... + def __getitem__(self, key: slice) -> Sequence[AssetEvent]: ... def __len__(self) -> int: ... class InletEventsAccessors(Mapping[str, InletEventsAccessor]): def __init__(self, inlets: list, *, session: Session) -> None: ... def __iter__(self) -> Iterator[str]: ... def __len__(self) -> int: ... - def __getitem__(self, key: int | str | Dataset | DatasetAlias) -> InletEventsAccessor: ... + def __getitem__(self, key: int | str | Asset | AssetAlias) -> InletEventsAccessor: ... # NOTE: Please keep this in sync with the following: # * KNOWN_CONTEXT_KEYS in airflow/utils/context.py @@ -132,7 +132,7 @@ class Context(TypedDict, total=False): ti: TaskInstance | TaskInstancePydantic tomorrow_ds: str tomorrow_ds_nodash: str - triggering_dataset_events: Mapping[str, Collection[DatasetEvent | DatasetEventPydantic]] + triggering_asset_events: Mapping[str, Collection[AssetEvent | AssetEventPydantic]] ts: str ts_nodash: str ts_nodash_with_tz: str diff --git a/airflow/utils/operator_helpers.py b/airflow/utils/operator_helpers.py index 108a84a9eabb9..6e5e03d8d163b 100644 --- a/airflow/utils/operator_helpers.py +++ b/airflow/utils/operator_helpers.py @@ -251,7 +251,7 @@ def __init__( def run(self, *args, **kwargs) -> Any: import inspect - from airflow.datasets.metadata import Metadata + from airflow.assets.metadata import Metadata from airflow.utils.types import NOTSET if not inspect.isgeneratorfunction(self.func): diff --git a/airflow/www/auth.py b/airflow/www/auth.py index 47a06f52e94bc..74f31d135c1aa 100644 --- a/airflow/www/auth.py +++ b/airflow/www/auth.py @@ -262,9 +262,9 @@ def decorated(*args, **kwargs): return has_access_decorator -def has_access_dataset(method: ResourceMethod) -> Callable[[T], T]: - """Check current user's permissions against required permissions for datasets.""" - return _has_access_no_details(lambda: get_auth_manager().is_authorized_dataset(method=method)) +def has_access_asset(method: ResourceMethod) -> Callable[[T], T]: + """Check current user's permissions against required permissions for assets.""" + return _has_access_no_details(lambda: get_auth_manager().is_authorized_asset(method=method)) def has_access_pool(method: ResourceMethod) -> Callable[[T], T]: diff --git a/airflow/www/security_manager.py 
b/airflow/www/security_manager.py index 926148f7eba86..77fd653b5f416 100644 --- a/airflow/www/security_manager.py +++ b/airflow/www/security_manager.py @@ -40,6 +40,7 @@ from airflow.models import Connection, DagRun, Pool, TaskInstance, Variable from airflow.security.permissions import ( RESOURCE_ADMIN_MENU, + RESOURCE_ASSET, RESOURCE_AUDIT_LOG, RESOURCE_BROWSE_MENU, RESOURCE_CLUSTER_ACTIVITY, @@ -49,7 +50,6 @@ RESOURCE_DAG_CODE, RESOURCE_DAG_DEPENDENCIES, RESOURCE_DAG_RUN, - RESOURCE_DATASET, RESOURCE_DOCS, RESOURCE_DOCS_MENU, RESOURCE_JOB, @@ -253,7 +253,7 @@ def _is_authorized_dag(entity_=None, details_func_=None): details=ConnectionDetails(conn_id=get_connection_id(resource_pk)), user=user, ), - RESOURCE_DATASET: lambda action, resource_pk, user: auth_manager.is_authorized_dataset( + RESOURCE_ASSET: lambda action, resource_pk, user: auth_manager.is_authorized_asset( method=methods[action], user=user, ), diff --git a/airflow/www/static/css/graph.css b/airflow/www/static/css/graph.css index f175a7e025d78..16dc5186af14f 100644 --- a/airflow/www/static/css/graph.css +++ b/airflow/www/static/css/graph.css @@ -161,12 +161,12 @@ g.node text { background-color: #e6f1f2; } -.legend-item.dataset { +.legend-item.asset { float: left; background-color: #fcecd4; } -.legend-item.dataset-alias { +.legend-item.asset-alias { float: left; background-color: #e8cfe4; } @@ -183,10 +183,10 @@ g.node.sensor rect { fill: #e6f1f2; } -g.node.dataset rect { +g.node.asset rect { fill: #fcecd4; } -g.node.dataset-alias rect { +g.node.asset-alias rect { fill: #e8cfe4; } diff --git a/airflow/www/static/js/dag/details/graph/Node.tsx b/airflow/www/static/js/dag/details/graph/Node.tsx index a4e9dee4c8074..daedfb8524e0b 100644 --- a/airflow/www/static/js/dag/details/graph/Node.tsx +++ b/airflow/www/static/js/dag/details/graph/Node.tsx @@ -94,7 +94,7 @@ const Node = (props: NodeProps) => { ); } - if (data.class === "dataset") return ; + if (data.class === "asset") return ; return ; }; diff --git a/airflow/www/static/js/dag/details/graph/index.tsx b/airflow/www/static/js/dag/details/graph/index.tsx index 51dd20d88b105..edafc99fe9c34 100644 --- a/airflow/www/static/js/dag/details/graph/index.tsx +++ b/airflow/www/static/js/dag/details/graph/index.tsx @@ -105,7 +105,7 @@ const getUpstreamDatasets = ( nodes.push({ id: d, value: { - class: "dataset", + class: "asset", label: d, }, }); @@ -202,7 +202,7 @@ const Graph = ({ openGroupIds, onToggleGroups, hoveredTaskState }: Props) => { datasetNodes.push({ id: dataset.uri, value: { - class: "dataset", + class: "asset", label: dataset.uri, }, }); @@ -221,7 +221,7 @@ const Graph = ({ openGroupIds, onToggleGroups, hoveredTaskState }: Props) => { datasetNodes.push({ id: de.datasetUri, value: { - class: "dataset", + class: "asset", label: de.datasetUri, }, }); diff --git a/airflow/www/static/js/dag/details/graph/utils.ts b/airflow/www/static/js/dag/details/graph/utils.ts index 93c8b9c253016..2fb7351525e71 100644 --- a/airflow/www/static/js/dag/details/graph/utils.ts +++ b/airflow/www/static/js/dag/details/graph/utils.ts @@ -92,7 +92,7 @@ export const flattenNodes = ({ onToggleGroups(newGroupIds); }, datasetEvent: - node.value.class === "dataset" + node.value.class === "asset" ? 
datasetEvents?.find((de) => de.datasetUri === node.value.label) : undefined, ...node.value, diff --git a/airflow/www/static/js/datasets/Graph/Node.tsx b/airflow/www/static/js/datasets/Graph/Node.tsx index baef11aa83663..dfb0cf8ed4deb 100644 --- a/airflow/www/static/js/datasets/Graph/Node.tsx +++ b/airflow/www/static/js/datasets/Graph/Node.tsx @@ -70,10 +70,10 @@ const BaseNode = ({ justifyContent="space-between" alignItems="center" > - {type === "dataset" && } + {type === "asset" && } {type === "sensor" && } {type === "trigger" && } - {type === "dataset-alias" && } + {type === "asset-alias" && } {label} )} diff --git a/airflow/www/static/js/datasets/Graph/index.tsx b/airflow/www/static/js/datasets/Graph/index.tsx index 9157a8a1a2617..e960c48ff63a2 100644 --- a/airflow/www/static/js/datasets/Graph/index.tsx +++ b/airflow/www/static/js/datasets/Graph/index.tsx @@ -87,7 +87,7 @@ const Graph = ({ selectedNodeId, onSelect }: Props) => { height: c.height, onSelect: () => { if (onSelect) { - if (c.value.class === "dataset") onSelect({ uri: c.value.label }); + if (c.value.class === "asset") onSelect({ uri: c.value.label }); else if (c.value.class === "dag") onSelect({ dagId: c.value.label }); } diff --git a/airflow/www/static/js/datasets/SearchBar.tsx b/airflow/www/static/js/datasets/SearchBar.tsx index fc47215389a57..33476b8419fdb 100644 --- a/airflow/www/static/js/datasets/SearchBar.tsx +++ b/airflow/www/static/js/datasets/SearchBar.tsx @@ -46,7 +46,7 @@ const SearchBar = ({ (datasetDependencies?.nodes || []).forEach((node) => { if (node.value.class === "dag") dagOptions.push({ value: node.id, label: node.value.label }); - if (node.value.class === "dataset") + if (node.value.class === "asset") datasetOptions.push({ value: node.id, label: node.value.label }); }); diff --git a/airflow/www/static/js/types/index.ts b/airflow/www/static/js/types/index.ts index b390568c8cf3b..1ce07bb350795 100644 --- a/airflow/www/static/js/types/index.ts +++ b/airflow/www/static/js/types/index.ts @@ -135,12 +135,12 @@ interface DepNode { id?: string; class: | "dag" - | "dataset" + | "asset" | "trigger" | "sensor" | "or-gate" | "and-gate" - | "dataset-alias"; + | "asset-alias"; label: string; rx?: number; ry?: number; diff --git a/airflow/www/templates/airflow/dag.html b/airflow/www/templates/airflow/dag.html index b0c00bcd5c88f..0d3a2cf1770ca 100644 --- a/airflow/www/templates/airflow/dag.html +++ b/airflow/www/templates/airflow/dag.html @@ -149,29 +149,29 @@

-            {% if ds_info.total == 1 -%}
-              On {{ ds_info.uri }}
+            {% if asset_info.total == 1 -%}
+              On {{ asset_info.uri }}
             {%- else -%}
-              {{ ds_info.ready }} of {{ ds_info.total }} datasets updated
+              {{ asset_info.ready }} of {{ asset_info.total }} datasets updated
             {%- endif %}
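The asset_info block above renders how many of a DAG's triggering assets are ready before the next run. A DAG opts into that behaviour through its schedule; below is a minimal sketch of a DAG that combines a cron timetable with asset triggers via the AssetOrTimeSchedule class renamed earlier in this patch. The dag id, cron expression, and asset URI are illustrative, and CronTriggerTimetable / EmptyOperator are assumed to be available as in existing Airflow releases rather than introduced here.

    import pendulum

    from airflow import DAG
    from airflow.assets import Asset
    from airflow.operators.empty import EmptyOperator
    from airflow.timetables.assets import AssetOrTimeSchedule
    from airflow.timetables.trigger import CronTriggerTimetable

    # Trigger on the nightly cron *or* on any event for the example asset.
    daily_export = Asset("s3://example-bucket/daily_export.csv")  # illustrative URI

    with DAG(
        dag_id="asset_or_time_example",  # illustrative dag id
        start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
        schedule=AssetOrTimeSchedule(
            timetable=CronTriggerTimetable("0 0 * * *", timezone="UTC"),
            assets=[daily_export],
        ),
        catchup=False,
    ):
        EmptyOperator(task_id="consume")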

diff --git a/airflow/www/templates/airflow/dag_dependencies.html b/airflow/www/templates/airflow/dag_dependencies.html index 542a5b7b47b8f..393b8d0de694d 100644 --- a/airflow/www/templates/airflow/dag_dependencies.html +++ b/airflow/www/templates/airflow/dag_dependencies.html @@ -43,8 +43,8 @@

           dag
           trigger
           sensor
-          dataset
-          dataset alias
+          asset
+          asset alias
Last refresh:
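The asset and asset-alias legend entries above correspond to dependency-graph nodes that are derived from serialized asset conditions. A short sketch of the expected round trip through the encode_asset_condition / decode_asset_condition helpers renamed earlier in this patch, assuming Asset, AssetAlias, AssetAll, and AssetAny behave as shown in the airflow.assets changes; the URIs, extra payload, and alias name are illustrative.

    from airflow.assets import Asset, AssetAlias, AssetAll, AssetAny
    from airflow.serialization.serialized_objects import (
        decode_asset_condition,
        encode_asset_condition,
    )

    # Any event for raw.csv, or events for both clean.csv and the "reports" alias.
    condition = AssetAny(
        Asset("s3://bucket/raw.csv", extra={"team": "data"}),  # illustrative asset
        AssetAll(Asset("s3://bucket/clean.csv"), AssetAlias(name="reports")),
    )

    encoded = encode_asset_condition(condition)
    # encoded is a plain dict tagged with "__type" markers (DAT.ASSET_ANY at the top
    # level), so it can be embedded in the serialized DAG representation.
    restored = decode_asset_condition(encoded)
    assert isinstance(restored, AssetAny)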
diff --git a/airflow/www/templates/airflow/dags.html b/airflow/www/templates/airflow/dags.html index ca374c665aea0..c629936df7c00 100644 --- a/airflow/www/templates/airflow/dags.html +++ b/airflow/www/templates/airflow/dags.html @@ -304,31 +304,31 @@

{{ page_title }}

info -
+ {table.getHeaderGroups().map((headerGroup) => ( {headerGroup.headers.map( diff --git a/airflow/ui/src/components/TogglePause.tsx b/airflow/ui/src/components/TogglePause.tsx new file mode 100644 index 0000000000000..50362187c8ad7 --- /dev/null +++ b/airflow/ui/src/components/TogglePause.tsx @@ -0,0 +1,56 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import { Switch } from "@chakra-ui/react"; +import { useQueryClient } from "@tanstack/react-query"; +import { useCallback } from "react"; + +import { + useDagServiceGetDagsKey, + useDagServicePatchDag, +} from "openapi/queries"; + +type Props = { + readonly dagId: string; + readonly isPaused: boolean; +}; + +export const TogglePause = ({ dagId, isPaused }: Props) => { + const queryClient = useQueryClient(); + + const onSuccess = async () => { + await queryClient.invalidateQueries({ + queryKey: [useDagServiceGetDagsKey], + }); + }; + + const { mutate } = useDagServicePatchDag({ + onSuccess, + }); + + const onChange = useCallback(() => { + mutate({ + dagId, + requestBody: { + is_paused: !isPaused, + }, + }); + }, [dagId, isPaused, mutate]); + + return ; +}; diff --git a/airflow/ui/src/pages/DagsList/DagsFilters.tsx b/airflow/ui/src/pages/DagsList/DagsFilters.tsx new file mode 100644 index 0000000000000..cb2be8322e500 --- /dev/null +++ b/airflow/ui/src/pages/DagsList/DagsFilters.tsx @@ -0,0 +1,86 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +import { HStack, Select, Text, Box } from "@chakra-ui/react"; +import { Select as ReactSelect } from "chakra-react-select"; +import { useCallback } from "react"; +import { useSearchParams } from "react-router-dom"; + +import { useTableURLState } from "src/components/DataTable/useTableUrlState"; +import { QuickFilterButton } from "src/components/QuickFilterButton"; + +const PAUSED_PARAM = "paused"; + +export const DagsFilters = () => { + const [searchParams, setSearchParams] = useSearchParams(); + + const showPaused = searchParams.get(PAUSED_PARAM); + + const { setTableURLState, tableURLState } = useTableURLState(); + const { pagination, sorting } = tableURLState; + + const handlePausedChange: React.ChangeEventHandler = + useCallback( + ({ target: { value } }) => { + if (value === "All") { + searchParams.delete(PAUSED_PARAM); + } else { + searchParams.set(PAUSED_PARAM, value); + } + setSearchParams(searchParams); + setTableURLState({ + pagination: { ...pagination, pageIndex: 0 }, + sorting, + }); + }, + [pagination, searchParams, setSearchParams, setTableURLState, sorting], + ); + + return ( + + + + + State: + + + All + Failed + Running + Successful + + + + + Active: + + + + + + + ); +}; diff --git a/airflow/ui/src/pages/DagsList.tsx b/airflow/ui/src/pages/DagsList/DagsList.tsx similarity index 72% rename from airflow/ui/src/pages/DagsList.tsx rename to airflow/ui/src/pages/DagsList/DagsList.tsx index ab480d2cbabdb..d58e3eaa2038c 100644 --- a/airflow/ui/src/pages/DagsList.tsx +++ b/airflow/ui/src/pages/DagsList/DagsList.tsx @@ -18,7 +18,6 @@ */ import { Badge, - Checkbox, Heading, HStack, Select, @@ -26,30 +25,36 @@ import { VStack, } from "@chakra-ui/react"; import type { ColumnDef } from "@tanstack/react-table"; -import { Select as ReactSelect } from "chakra-react-select"; import { type ChangeEventHandler, useCallback } from "react"; import { useSearchParams } from "react-router-dom"; import { useDagServiceGetDags } from "openapi/queries"; import type { DAGResponse } from "openapi/requests/types.gen"; +import { DataTable } from "src/components/DataTable"; +import { useTableURLState } from "src/components/DataTable/useTableUrlState"; +import { SearchBar } from "src/components/SearchBar"; +import { TogglePause } from "src/components/TogglePause"; +import { pluralize } from "src/utils/pluralize"; -import { DataTable } from "../components/DataTable"; -import { useTableURLState } from "../components/DataTable/useTableUrlState"; -import { QuickFilterButton } from "../components/QuickFilterButton"; -import { SearchBar } from "../components/SearchBar"; -import { pluralize } from "../utils/pluralize"; +import { DagsFilters } from "./DagsFilters"; const columns: Array> = [ + { + accessorKey: "is_paused", + cell: ({ row }) => ( + + ), + enableSorting: false, + header: "", + }, { accessorKey: "dag_id", cell: ({ row }) => row.original.dag_display_name, header: "DAG", }, - { - accessorKey: "is_paused", - enableSorting: false, - header: () => "Is Paused", - }, { accessorKey: "timetable_description", cell: (info) => @@ -82,9 +87,9 @@ const PAUSED_PARAM = "paused"; // eslint-disable-next-line complexity export const DagsList = ({ cardView = false }) => { - const [searchParams, setSearchParams] = useSearchParams(); + const [searchParams] = useSearchParams(); - const showPaused = searchParams.get(PAUSED_PARAM) === "true"; + const showPaused = searchParams.get(PAUSED_PARAM); const { setTableURLState, tableURLState } = useTableURLState(); const { pagination, sorting } = tableURLState; @@ -98,22 +103,9 
@@ export const DagsList = ({ cardView = false }) => { offset: pagination.pageIndex * pagination.pageSize, onlyActive: true, orderBy, - paused: showPaused, + paused: showPaused === null ? undefined : showPaused === "true", }); - const handlePausedChange = useCallback(() => { - searchParams[showPaused ? "delete" : "set"](PAUSED_PARAM, "true"); - setSearchParams(searchParams); - setTableURLState({ pagination: { ...pagination, pageIndex: 0 }, sorting }); - }, [ - pagination, - searchParams, - setSearchParams, - setTableURLState, - showPaused, - sorting, - ]); - const handleSortChange = useCallback>( ({ currentTarget: { value } }) => { setTableURLState({ @@ -136,20 +128,7 @@ export const DagsList = ({ cardView = false }) => { buttonProps={{ isDisabled: true }} inputProps={{ isDisabled: true }} /> - - - - All - Failed - Running - Successful - - - Show Paused DAGs - - - - + {pluralize("DAG", data?.total_entries)} diff --git a/airflow/ui/src/pages/DagsList/index.tsx b/airflow/ui/src/pages/DagsList/index.tsx new file mode 100644 index 0000000000000..df59a682abb67 --- /dev/null +++ b/airflow/ui/src/pages/DagsList/index.tsx @@ -0,0 +1,20 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +export { DagsList } from "./DagsList"; From a12d78cbd7224e074eadcdb0ec21c85fd5ff2d33 Mon Sep 17 00:00:00 2001 From: jonhspyro <121674572+jonhspyro@users.noreply.github.com> Date: Wed, 2 Oct 2024 11:36:27 +0100 Subject: [PATCH 251/349] Correctly select task in DAG Graph View when clicking on its name (#38782) * Fix in DAG Graph View, clicking Task on it's name doesn't select the task. 
(#37932) * Updated TaskName onClick * Fixed missing onToggleCollapse * Added missing changes * Updated: rebase * fixed providers error message * undo fab changes * Update user_command.py --------- Co-authored-by: Brent Bovenzi --- .../static/js/dag/details/graph/DagNode.test.tsx | 14 +++++++++++++- .../www/static/js/dag/details/graph/DagNode.tsx | 8 +++++--- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/airflow/www/static/js/dag/details/graph/DagNode.test.tsx b/airflow/www/static/js/dag/details/graph/DagNode.test.tsx index 34ddac7506c71..7c6dea7584a0b 100644 --- a/airflow/www/static/js/dag/details/graph/DagNode.test.tsx +++ b/airflow/www/static/js/dag/details/graph/DagNode.test.tsx @@ -20,7 +20,7 @@ /* global describe, test, expect */ import React from "react"; -import { render } from "@testing-library/react"; +import { fireEvent, render } from "@testing-library/react"; import { Wrapper } from "src/utils/testUtils"; @@ -124,4 +124,16 @@ describe("Test Graph Node", () => { expect(getByTestId("node")).toHaveStyle("opacity: 0.3"); }); + + test("Clicks on taskName", async () => { + const { getByText } = render(, { + wrapper: Wrapper, + }); + + const taskName = getByText("task_id"); + + fireEvent.click(taskName); + + expect(taskName).toBeInTheDocument(); + }); }); diff --git a/airflow/www/static/js/dag/details/graph/DagNode.tsx b/airflow/www/static/js/dag/details/graph/DagNode.tsx index c2f9b01296c35..4ac1be8ef4ade 100644 --- a/airflow/www/static/js/dag/details/graph/DagNode.tsx +++ b/airflow/www/static/js/dag/details/graph/DagNode.tsx @@ -42,10 +42,10 @@ const DagNode = ({ task, isSelected, latestDagRunId, - onToggleCollapse, isOpen, isActive, setupTeardownType, + onToggleCollapse, labelStyle, style, isZoomedOut, @@ -139,8 +139,10 @@ const DagNode = ({ isOpen={isOpen} isGroup={!!childCount} onClick={(e) => { - e.stopPropagation(); - onToggleCollapse(); + if (childCount) { + e.stopPropagation(); + onToggleCollapse(); + } }} setupTeardownType={setupTeardownType} isZoomedOut={isZoomedOut} From 78a84fd251e730c01f451452a8a64980391d8b66 Mon Sep 17 00:00:00 2001 From: GPK Date: Wed, 2 Oct 2024 12:48:12 +0100 Subject: [PATCH 252/349] =?UTF-8?q?Revert=20"Move=20FSHook/PackageIndexHoo?= =?UTF-8?q?k/SubprocessHook=20to=20standard=20provider=20(#42=E2=80=A6"=20?= =?UTF-8?q?(#42659)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 61d1dbbc7feb9728da125dc00ad05314758036eb. 
--- .../standard => }/hooks/filesystem.py | 0 .../standard => }/hooks/package_index.py | 0 .../standard => }/hooks/subprocess.py | 4 +- airflow/operators/bash.py | 2 +- airflow/providers/standard/hooks/__init__.py | 16 -------- airflow/providers/standard/provider.yaml | 7 ---- airflow/providers_manager.py | 4 +- airflow/sensors/filesystem.py | 2 +- .../logging-monitoring/errors.rst | 2 +- .../operators-and-hooks-ref.rst | 4 +- .../standard => }/hooks/test_package_index.py | 6 +-- .../standard => }/hooks/test_subprocess.py | 6 +-- tests/providers/standard/hooks/__init__.py | 16 -------- .../standard/hooks/test_filesystem.py | 39 ------------------- tests/sensors/test_filesystem.py | 2 +- 15 files changed, 16 insertions(+), 94 deletions(-) rename airflow/{providers/standard => }/hooks/filesystem.py (100%) rename airflow/{providers/standard => }/hooks/package_index.py (100%) rename airflow/{providers/standard => }/hooks/subprocess.py (96%) delete mode 100644 airflow/providers/standard/hooks/__init__.py rename tests/{providers/standard => }/hooks/test_package_index.py (93%) rename tests/{providers/standard => }/hooks/test_subprocess.py (95%) delete mode 100644 tests/providers/standard/hooks/__init__.py delete mode 100644 tests/providers/standard/hooks/test_filesystem.py diff --git a/airflow/providers/standard/hooks/filesystem.py b/airflow/hooks/filesystem.py similarity index 100% rename from airflow/providers/standard/hooks/filesystem.py rename to airflow/hooks/filesystem.py diff --git a/airflow/providers/standard/hooks/package_index.py b/airflow/hooks/package_index.py similarity index 100% rename from airflow/providers/standard/hooks/package_index.py rename to airflow/hooks/package_index.py diff --git a/airflow/providers/standard/hooks/subprocess.py b/airflow/hooks/subprocess.py similarity index 96% rename from airflow/providers/standard/hooks/subprocess.py rename to airflow/hooks/subprocess.py index 9e578a7d8034b..bc20b5c20b4c5 100644 --- a/airflow/providers/standard/hooks/subprocess.py +++ b/airflow/hooks/subprocess.py @@ -52,8 +52,8 @@ def run_command( :param env: Optional dict containing environment variables to be made available to the shell environment in which ``command`` will be executed. If omitted, ``os.environ`` will be used. Note, that in case you have Sentry configured, original variables from the environment - will also be passed to the subprocess with ``SUBPROCESS_`` prefix. See: - https://airflow.apache.org/docs/apache-airflow/stable/administration-and-deployment/logging-monitoring/errors.html for details. + will also be passed to the subprocess with ``SUBPROCESS_`` prefix. See + :doc:`/administration-and-deployment/logging-monitoring/errors` for details. :param output_encoding: encoding to use for decoding stdout :param cwd: Working directory to run the command in. If None (default), the command is run in a temporary directory. 
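Not part of the patch, but for orientation: a minimal sketch of how the hook is used at the import path restored by this revert. The command, the MY_VAR variable and the error handling are illustrative assumptions, not taken from the Airflow sources.

from airflow.hooks.subprocess import SubprocessHook

hook = SubprocessHook()
# run_command returns a SubprocessResult namedtuple with exit_code and output,
# where output is the last line the command wrote to stdout.
result = hook.run_command(
    command=["bash", "-c", "echo $MY_VAR"],  # MY_VAR is a hypothetical example variable
    env={"MY_VAR": "hello"},  # omit env to inherit os.environ, as described in the docstring above
    cwd=None,  # None means the command runs in a temporary directory
)
if result.exit_code != 0:
    raise RuntimeError(f"command failed: {result.output}")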
diff --git a/airflow/operators/bash.py b/airflow/operators/bash.py index bf4a943df6e08..2ec0341a0d1e2 100644 --- a/airflow/operators/bash.py +++ b/airflow/operators/bash.py @@ -24,8 +24,8 @@ from typing import TYPE_CHECKING, Any, Callable, Container, Sequence, cast from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.hooks.subprocess import SubprocessHook from airflow.models.baseoperator import BaseOperator -from airflow.providers.standard.hooks.subprocess import SubprocessHook from airflow.utils.operator_helpers import context_to_airflow_vars from airflow.utils.types import ArgNotSet diff --git a/airflow/providers/standard/hooks/__init__.py b/airflow/providers/standard/hooks/__init__.py deleted file mode 100644 index 13a83393a9124..0000000000000 --- a/airflow/providers/standard/hooks/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/airflow/providers/standard/provider.yaml b/airflow/providers/standard/provider.yaml index 068fde1fe3761..83d8acf0a68b3 100644 --- a/airflow/providers/standard/provider.yaml +++ b/airflow/providers/standard/provider.yaml @@ -50,10 +50,3 @@ sensors: - airflow.providers.standard.sensors.time_delta - airflow.providers.standard.sensors.time - airflow.providers.standard.sensors.weekday - -hooks: - - integration-name: Standard - python-modules: - - airflow.providers.standard.hooks.filesystem - - airflow.providers.standard.hooks.package_index - - airflow.providers.standard.hooks.subprocess diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py index e276c465ef689..2c673063cb23e 100644 --- a/airflow/providers_manager.py +++ b/airflow/providers_manager.py @@ -36,8 +36,8 @@ from packaging.utils import canonicalize_name from airflow.exceptions import AirflowOptionalProviderFeatureException -from airflow.providers.standard.hooks.filesystem import FSHook -from airflow.providers.standard.hooks.package_index import PackageIndexHook +from airflow.hooks.filesystem import FSHook +from airflow.hooks.package_index import PackageIndexHook from airflow.typing_compat import ParamSpec from airflow.utils import yaml from airflow.utils.entry_points import entry_points_with_dist diff --git a/airflow/sensors/filesystem.py b/airflow/sensors/filesystem.py index 4496f5d6abfa4..5d32ab07ad4e7 100644 --- a/airflow/sensors/filesystem.py +++ b/airflow/sensors/filesystem.py @@ -25,7 +25,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.providers.standard.hooks.filesystem import FSHook +from airflow.hooks.filesystem import FSHook from airflow.sensors.base import BaseSensorOperator from airflow.triggers.base import StartTriggerArgs from airflow.triggers.file import FileTrigger diff --git 
a/docs/apache-airflow/administration-and-deployment/logging-monitoring/errors.rst b/docs/apache-airflow/administration-and-deployment/logging-monitoring/errors.rst index 0ad3fa8c5127a..cb09843422321 100644 --- a/docs/apache-airflow/administration-and-deployment/logging-monitoring/errors.rst +++ b/docs/apache-airflow/administration-and-deployment/logging-monitoring/errors.rst @@ -96,7 +96,7 @@ Impact of Sentry on Environment variables passed to Subprocess Hook When Sentry is enabled, by default it changes the standard library to pass all environment variables to subprocesses opened by Airflow. This changes the default behaviour of -:class:`airflow.providers.standard.hooks.subprocess.SubprocessHook` - always all environment variables are passed to the +:class:`airflow.hooks.subprocess.SubprocessHook` - always all environment variables are passed to the subprocess executed with specific set of environment variables. In this case not only the specified environment variables are passed but also all existing environment variables are passed with ``SUBPROCESS_`` prefix added. This happens also for all other subprocesses. diff --git a/docs/apache-airflow/operators-and-hooks-ref.rst b/docs/apache-airflow/operators-and-hooks-ref.rst index d4ac6bda74c34..16b74305a958b 100644 --- a/docs/apache-airflow/operators-and-hooks-ref.rst +++ b/docs/apache-airflow/operators-and-hooks-ref.rst @@ -106,8 +106,8 @@ For details see: :doc:`apache-airflow-providers:operators-and-hooks-ref/index`. * - Hooks - Guides - * - :mod:`airflow.providers.standard.hooks.filesystem` + * - :mod:`airflow.hooks.filesystem` - - * - :mod:`airflow.providers.standard.hooks.subprocess` + * - :mod:`airflow.hooks.subprocess` - diff --git a/tests/providers/standard/hooks/test_package_index.py b/tests/hooks/test_package_index.py similarity index 93% rename from tests/providers/standard/hooks/test_package_index.py rename to tests/hooks/test_package_index.py index 6a90db0715d81..9da429c5a09cf 100644 --- a/tests/providers/standard/hooks/test_package_index.py +++ b/tests/hooks/test_package_index.py @@ -21,8 +21,8 @@ import pytest +from airflow.hooks.package_index import PackageIndexHook from airflow.models.connection import Connection -from airflow.providers.standard.hooks.package_index import PackageIndexHook class MockConnection(Connection): @@ -73,7 +73,7 @@ def mock_get_connection(monkeypatch: pytest.MonkeyPatch, request: pytest.Fixture password: str | None = testdata.get("password", None) expected_result: str | None = testdata.get("expected_result", None) monkeypatch.setattr( - "airflow.providers.standard.hooks.package_index.PackageIndexHook.get_connection", + "airflow.hooks.package_index.PackageIndexHook.get_connection", lambda *_: MockConnection(host, login, password), ) return expected_result @@ -104,7 +104,7 @@ class MockProc: return MockProc() - monkeypatch.setattr("airflow.providers.standard.hooks.package_index.subprocess.run", mock_run) + monkeypatch.setattr("airflow.hooks.package_index.subprocess.run", mock_run) hook_instance = PackageIndexHook() if mock_get_connection: diff --git a/tests/providers/standard/hooks/test_subprocess.py b/tests/hooks/test_subprocess.py similarity index 95% rename from tests/providers/standard/hooks/test_subprocess.py rename to tests/hooks/test_subprocess.py index 2b2e9473359e5..0f625be816887 100644 --- a/tests/providers/standard/hooks/test_subprocess.py +++ b/tests/hooks/test_subprocess.py @@ -26,7 +26,7 @@ import pytest -from airflow.providers.standard.hooks.subprocess import SubprocessHook +from 
airflow.hooks.subprocess import SubprocessHook OS_ENV_KEY = "SUBPROCESS_ENV_TEST" OS_ENV_VAL = "this-is-from-os-environ" @@ -81,11 +81,11 @@ def test_return_value(self, val, expected): @mock.patch.dict("os.environ", clear=True) @mock.patch( - "airflow.providers.standard.hooks.subprocess.TemporaryDirectory", + "airflow.hooks.subprocess.TemporaryDirectory", return_value=MagicMock(__enter__=MagicMock(return_value="/tmp/airflowtmpcatcat")), ) @mock.patch( - "airflow.providers.standard.hooks.subprocess.Popen", + "airflow.hooks.subprocess.Popen", return_value=MagicMock(stdout=MagicMock(readline=MagicMock(side_effect=StopIteration), returncode=0)), ) def test_should_exec_subprocess(self, mock_popen, mock_temporary_directory): diff --git a/tests/providers/standard/hooks/__init__.py b/tests/providers/standard/hooks/__init__.py deleted file mode 100644 index 13a83393a9124..0000000000000 --- a/tests/providers/standard/hooks/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/tests/providers/standard/hooks/test_filesystem.py b/tests/providers/standard/hooks/test_filesystem.py deleted file mode 100644 index bbcd22dc94219..0000000000000 --- a/tests/providers/standard/hooks/test_filesystem.py +++ /dev/null @@ -1,39 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from __future__ import annotations - -import pytest - -from airflow.providers.standard.hooks.filesystem import FSHook - -pytestmark = pytest.mark.db_test - - -class TestFSHook: - def test_get_ui_field_behaviour(self): - fs_hook = FSHook() - assert fs_hook.get_ui_field_behaviour() == { - "hidden_fields": ["host", "schema", "port", "login", "password", "extra"], - "relabeling": {}, - "placeholders": {}, - } - - def test_get_path(self): - fs_hook = FSHook(fs_conn_id="fs_default") - - assert fs_hook.get_path() == "/" diff --git a/tests/sensors/test_filesystem.py b/tests/sensors/test_filesystem.py index 641f2f218f2db..1fb123cfe7248 100644 --- a/tests/sensors/test_filesystem.py +++ b/tests/sensors/test_filesystem.py @@ -40,7 +40,7 @@ @pytest.mark.skip_if_database_isolation_mode # Test is broken in db isolation mode class TestFileSensor: def setup_method(self): - from airflow.providers.standard.hooks.filesystem import FSHook + from airflow.hooks.filesystem import FSHook hook = FSHook() args = {"owner": "airflow", "start_date": DEFAULT_DATE} From 6c00b895b4a64d5da9d7fc3a34045bc01c545450 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Wed, 2 Oct 2024 05:59:14 -0700 Subject: [PATCH 253/349] Add backfill cancellation logic (#42530) --- .../endpoints/backfill_endpoint.py | 40 ++++++++--------- airflow/models/backfill.py | 45 +++++++++++++++++-- tests/models/test_backfill.py | 45 ++++++++++++++++++- 3 files changed, 104 insertions(+), 26 deletions(-) diff --git a/airflow/api_connexion/endpoints/backfill_endpoint.py b/airflow/api_connexion/endpoints/backfill_endpoint.py index baafdeea4f992..a0e728c5bc464 100644 --- a/airflow/api_connexion/endpoints/backfill_endpoint.py +++ b/airflow/api_connexion/endpoints/backfill_endpoint.py @@ -32,8 +32,12 @@ backfill_collection_schema, backfill_schema, ) -from airflow.models.backfill import AlreadyRunningBackfill, Backfill, _create_backfill -from airflow.utils import timezone +from airflow.models.backfill import ( + AlreadyRunningBackfill, + Backfill, + _cancel_backfill, + _create_backfill, +) from airflow.utils.session import NEW_SESSION, provide_session from airflow.www.decorators import action_logging @@ -104,24 +108,6 @@ def unpause_backfill(*, backfill_id, session, **kwargs): return backfill_schema.dump(br) -@provide_session -@backfill_to_dag -@security.requires_access_dag("PUT") -@action_logging -def cancel_backfill(*, backfill_id, session, **kwargs): - br: Backfill = session.get(Backfill, backfill_id) - if br.completed_at is not None: - raise Conflict("Backfill is already completed.") - - br.completed_at = timezone.utcnow() - - # first, pause - if not br.is_paused: - br.is_paused = True - session.commit() - return backfill_schema.dump(br) - - @provide_session @backfill_to_dag @security.requires_access_dag("GET") @@ -155,3 +141,17 @@ def create_backfill( return backfill_schema.dump(backfill_obj) except AlreadyRunningBackfill: raise Conflict(f"There is already a running backfill for dag {dag_id}") + + +@provide_session +@backfill_to_dag +@security.requires_access_dag("PUT") +@action_logging +def cancel_backfill( + *, + backfill_id, + session: Session = NEW_SESSION, # used by backfill_to_dag decorator + **kwargs, +): + br = _cancel_backfill(backfill_id=backfill_id) + return backfill_schema.dump(br) diff --git a/airflow/models/backfill.py b/airflow/models/backfill.py index 6d3a8ee4fa922..db10c804aac0d 100644 --- a/airflow/models/backfill.py +++ b/airflow/models/backfill.py @@ -26,12 +26,13 @@ import logging from 
typing import TYPE_CHECKING -from sqlalchemy import Boolean, Column, ForeignKeyConstraint, Integer, UniqueConstraint, func, select +from sqlalchemy import Boolean, Column, ForeignKeyConstraint, Integer, UniqueConstraint, func, select, update from sqlalchemy.orm import relationship from sqlalchemy_jsonfield import JSONField -from airflow.api_connexion.exceptions import NotFound +from airflow.api_connexion.exceptions import Conflict, NotFound from airflow.exceptions import AirflowException +from airflow.models import DagRun from airflow.models.base import Base, StringID from airflow.models.serialized_dag import SerializedDagModel from airflow.settings import json @@ -48,7 +49,11 @@ class AlreadyRunningBackfill(AirflowException): - """Raised when attempting to create backfill and one already active.""" + """ + Raised when attempting to create backfill and one already active. + + :meta private: + """ class Backfill(Base): @@ -172,7 +177,11 @@ def _create_backfill( session=session, ) except Exception: - dag.log.exception("something failed") + dag.log.exception( + "Error while attempting to create a dag run dag_id='%s' logical_date='%s'", + dag.dag_id, + info.logical_date, + ) session.rollback() session.add( BackfillDagRun( @@ -183,3 +192,31 @@ def _create_backfill( ) session.commit() return br + + +def _cancel_backfill(backfill_id) -> Backfill: + with create_session() as session: + b: Backfill = session.get(Backfill, backfill_id) + if b.completed_at is not None: + raise Conflict("Backfill is already completed.") + + b.completed_at = timezone.utcnow() + + # first, pause + if not b.is_paused: + b.is_paused = True + + session.commit() + + # now, let's mark all queued dag runs as failed + query = ( + update(DagRun) + .where( + DagRun.id.in_(select(BackfillDagRun.dag_run_id).where(BackfillDagRun.backfill_id == b.id)), + DagRun.state == DagRunState.QUEUED, + ) + .values(state=DagRunState.FAILED) + .execution_options(synchronize_session=False) + ) + session.execute(query) + return b diff --git a/tests/models/test_backfill.py b/tests/models/test_backfill.py index 9a845f86803e0..c45625db335de 100644 --- a/tests/models/test_backfill.py +++ b/tests/models/test_backfill.py @@ -24,7 +24,13 @@ from sqlalchemy import select from airflow.models import DagRun -from airflow.models.backfill import AlreadyRunningBackfill, Backfill, BackfillDagRun, _create_backfill +from airflow.models.backfill import ( + AlreadyRunningBackfill, + Backfill, + BackfillDagRun, + _cancel_backfill, + _create_backfill, +) from airflow.operators.python import PythonOperator from airflow.utils.state import DagRunState from tests.test_utils.db import clear_db_backfills, clear_db_dags, clear_db_runs, clear_db_serialized_dags @@ -71,7 +77,7 @@ def test_reverse_and_depends_on_past_fails(dep_on_past, dag_maker, session): @pytest.mark.parametrize("reverse", [True, False]) -def test_simple(reverse, dag_maker, session): +def test_create_backfill_simple(reverse, dag_maker, session): """ Verify simple case behavior. @@ -150,3 +156,38 @@ def test_active_dag_run(dag_maker, session): reverse=False, dag_run_conf={"this": "param"}, ) + + +def test_cancel_backfill(dag_maker, session): + """ + Queued runs should be marked *failed*. + Every other dag run should be left alone. 
+ """ + with dag_maker(schedule="@daily") as dag: + PythonOperator(task_id="hi", python_callable=print) + b = _create_backfill( + dag_id=dag.dag_id, + from_date=pendulum.parse("2021-01-01"), + to_date=pendulum.parse("2021-01-05"), + max_active_runs=2, + reverse=False, + dag_run_conf={}, + ) + query = ( + select(DagRun) + .join(BackfillDagRun.dag_run) + .where(BackfillDagRun.backfill_id == b.id) + .order_by(BackfillDagRun.sort_ordinal) + ) + dag_runs = session.scalars(query).all() + dates = [str(x.logical_date.date()) for x in dag_runs] + expected_dates = ["2021-01-01", "2021-01-02", "2021-01-03", "2021-01-04", "2021-01-05"] + assert dates == expected_dates + assert all(x.state == DagRunState.QUEUED for x in dag_runs) + dag_runs[0].state = "running" + session.commit() + _cancel_backfill(backfill_id=b.id) + session.expunge_all() + dag_runs = session.scalars(query).all() + states = [x.state for x in dag_runs] + assert states == ["running", "failed", "failed", "failed", "failed"] From dc6597b356b744cdc13d99d60625f191a3343bba Mon Sep 17 00:00:00 2001 From: Vincent <97131062+vincbeck@users.noreply.github.com> Date: Wed, 2 Oct 2024 09:48:32 -0400 Subject: [PATCH 254/349] Use FAB auth manager in `test_google_openid` (#42622) --- .../google/common/auth_backend/test_google_openid.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/providers/google/common/auth_backend/test_google_openid.py b/tests/providers/google/common/auth_backend/test_google_openid.py index dab0eae07a23d..260ae0d6fb5e1 100644 --- a/tests/providers/google/common/auth_backend/test_google_openid.py +++ b/tests/providers/google/common/auth_backend/test_google_openid.py @@ -22,6 +22,7 @@ from google.auth.exceptions import GoogleAuthError from airflow.www.app import create_app +from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_pools from tests.test_utils.decorators import dont_initialize_flask_app_submodules @@ -41,7 +42,13 @@ def google_openid_app(): ) def factory(): with conf_vars( - {("api", "auth_backends"): "airflow.providers.google.common.auth_backend.google_openid"} + { + ("api", "auth_backends"): "airflow.providers.google.common.auth_backend.google_openid", + ( + "core", + "auth_manager", + ): "airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager", + } ): _app = create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore _app.config["AUTH_ROLE_PUBLIC"] = None @@ -67,6 +74,7 @@ def admin_user(google_openid_app): return role_admin +@pytest.mark.skipif(not AIRFLOW_V_2_9_PLUS, reason="The tests should be skipped for Airflow < 2.9") @pytest.mark.skip_if_database_isolation_mode @pytest.mark.db_test class TestGoogleOpenID: From c3879aa839f8b1a72f0cda494f1884b6c4cf6d4a Mon Sep 17 00:00:00 2001 From: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> Date: Wed, 2 Oct 2024 08:38:37 -0600 Subject: [PATCH 255/349] Remove "project" from log path in callback docs (#42666) Airflow doesn't have the concept of a "project", unless DAG authors add that layer themselves. 
--- .../logging-monitoring/callbacks.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/apache-airflow/administration-and-deployment/logging-monitoring/callbacks.rst b/docs/apache-airflow/administration-and-deployment/logging-monitoring/callbacks.rst index b54071373cf09..4f74626ab29ba 100644 --- a/docs/apache-airflow/administration-and-deployment/logging-monitoring/callbacks.rst +++ b/docs/apache-airflow/administration-and-deployment/logging-monitoring/callbacks.rst @@ -34,7 +34,7 @@ For example, you may wish to alert when certain tasks have failed, or have the l Callback functions are executed after tasks are completed. Errors in callback functions will show up in scheduler logs rather than task logs. By default, scheduler logs do not show up in the UI and instead can be found in - ``$AIRFLOW_HOME/logs/scheduler/latest/PROJECT/DAG_FILE.py.log`` + ``$AIRFLOW_HOME/logs/scheduler/latest/DAG_FILE.py.log`` Callback Types -------------- From 879537e29cb93cd7d9d2d0ff3503781c8323def0 Mon Sep 17 00:00:00 2001 From: Shahar Epstein <60007259+shahar1@users.noreply.github.com> Date: Wed, 2 Oct 2024 19:07:31 +0300 Subject: [PATCH 256/349] Fix invalid path in lineage.rst (#42655) --- docs/apache-airflow/administration-and-deployment/lineage.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/apache-airflow/administration-and-deployment/lineage.rst b/docs/apache-airflow/administration-and-deployment/lineage.rst index d2ef63d869755..b274809175c03 100644 --- a/docs/apache-airflow/administration-and-deployment/lineage.rst +++ b/docs/apache-airflow/administration-and-deployment/lineage.rst @@ -101,7 +101,7 @@ The collector then uses this data to construct AIP-60 compliant Assets, a standa .. code-block:: python - from airflow.lineage.hook_lineage import get_hook_lineage_collector + from airflow.lineage.hook.lineage import get_hook_lineage_collector class CustomHook(BaseHook): From 300f260ea7f7c54b0833e3ac3fa6ac2bb5a34a14 Mon Sep 17 00:00:00 2001 From: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> Date: Wed, 2 Oct 2024 10:36:30 -0600 Subject: [PATCH 257/349] Add support for PostgreSQL 17 in Breeze (#42644) * Add support for PostgreSQL 17 in Breeze * Fix tests --------- Co-authored-by: Tzu-ping Chung --- README.md | 2 +- dev/breeze/doc/images/output-commands.svg | 42 ++--- dev/breeze/doc/images/output_setup_config.svg | 2 +- dev/breeze/doc/images/output_setup_config.txt | 2 +- dev/breeze/doc/images/output_shell.svg | 140 +++++++-------- dev/breeze/doc/images/output_shell.txt | 2 +- .../doc/images/output_start-airflow.svg | 2 +- .../doc/images/output_start-airflow.txt | 2 +- .../doc/images/output_testing_db-tests.svg | 158 ++++++++--------- .../doc/images/output_testing_db-tests.txt | 2 +- .../output_testing_integration-tests.svg | 50 +++--- .../output_testing_integration-tests.txt | 2 +- .../doc/images/output_testing_tests.svg | 162 +++++++++--------- .../doc/images/output_testing_tests.txt | 2 +- .../src/airflow_breeze/global_constants.py | 4 +- dev/breeze/tests/test_selective_checks.py | 4 +- generated/PYPI_README.md | 2 +- 17 files changed, 296 insertions(+), 284 deletions(-) diff --git a/README.md b/README.md index 3169ac5144844..3cd6416e93405 100644 --- a/README.md +++ b/README.md @@ -102,7 +102,7 @@ Apache Airflow is tested with: | Python | 3.8, 3.9, 3.10, 3.11, 3.12 | 3.8, 3.9, 3.10, 3.11, 3.12 | | Platform | AMD64/ARM64(\*) | AMD64/ARM64(\*) | | Kubernetes | 1.28, 1.29, 1.30, 1.31 | 1.27, 1.28, 1.29, 1.30 | -| PostgreSQL | 12, 13, 14, 
15, 16 | 12, 13, 14, 15, 16 | +| PostgreSQL | 12, 13, 14, 15, 16, 17 | 12, 13, 14, 15, 16 | | MySQL | 8.0, 8.4, Innovation | 8.0, 8.4, Innovation | | SQLite | 3.15.0+ | 3.15.0+ | diff --git a/dev/breeze/doc/images/output-commands.svg b/dev/breeze/doc/images/output-commands.svg index 1556dfef6f5a7..78c753526e449 100644 --- a/dev/breeze/doc/images/output-commands.svg +++ b/dev/breeze/doc/images/output-commands.svg @@ -301,53 +301,53 @@ Usage:breeze[OPTIONSCOMMAND [ARGS]... ╭─ Execution mode ─────────────────────────────────────────────────────────────────────────────────────────────────────╮ ---python-pPython major/minor version used in Airflow image for images. +--python-pPython major/minor version used in Airflow image for images. (>3.8< | 3.9 | 3.10 | 3.11 | 3.12)                           [default: 3.8]                                               ---integrationIntegration(s) to enable when running (can be more than one).                        +--integrationIntegration(s) to enable when running (can be more than one).                        (all | all-testable | cassandra | celery | drill | kafka | kerberos | mongo | mssql  | openlineage | otel | pinot | qdrant | redis | statsd | trino | ydb)                ---standalone-dag-processorRun standalone dag processor for start-airflow. ---database-isolationRun airflow in database isolation mode. +--standalone-dag-processorRun standalone dag processor for start-airflow. +--database-isolationRun airflow in database isolation mode. ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ╭─ Docker Compose selection and cleanup ───────────────────────────────────────────────────────────────────────────────╮ ---project-nameName of the docker-compose project to bring down. The `docker-compose` is for legacy breeze        -project name and you can use `breeze down --project-name docker-compose` to stop all containers    +--project-nameName of the docker-compose project to bring down. The `docker-compose` is for legacy breeze        +project name and you can use `breeze down --project-name docker-compose` to stop all containers    belonging to it.                                                                                   (breeze | pre-commit | docker-compose)                                                             [default: breeze]                                                                                  ---docker-hostOptional - docker host to use when running docker commands. When set, the `--builder` option is    +--docker-hostOptional - docker host to use when running docker commands. When set, the `--builder` option is    ignored when building images.                                                                      (TEXT)                                                                                             ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ╭─ Database ───────────────────────────────────────────────────────────────────────────────────────────────────────────╮ ---backend-bDatabase backend to use. If 'none' is chosen, Breeze will start with an invalid database     +--backend-bDatabase backend to use. If 'none' is chosen, Breeze will start with an invalid database     configuration, meaning there will be no database available, and any attempts to connect to   the Airflow database will fail.                                                              
(>sqlite< | mysql | postgres | none)                                                         [default: sqlite]                                                                            ---postgres-version-PVersion of Postgres used.(>12< | 13 | 14 | 15 | 16)[default: 12] ---mysql-version-MVersion of MySQL used.(>8.0< | 8.4)[default: 8.0] ---db-reset-dReset DB when entering the container. +--postgres-version-PVersion of Postgres used.(>12< | 13 | 14 | 15 | 16 | 17)[default: 12] +--mysql-version-MVersion of MySQL used.(>8.0< | 8.4)[default: 8.0] +--db-reset-dReset DB when entering the container. ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ╭─ Build CI image (before entering shell) ─────────────────────────────────────────────────────────────────────────────╮ ---github-repository-gGitHub repository used to pull, push run images.(TEXT)[default: apache/airflow] ---builderBuildx builder used to perform `docker buildx build` commands.(TEXT) +--github-repository-gGitHub repository used to pull, push run images.(TEXT)[default: apache/airflow] +--builderBuildx builder used to perform `docker buildx build` commands.(TEXT) [default: autodetect]                                          ---use-uv/--no-use-uvUse uv instead of pip as packaging tool to build the image.[default: use-uv] ---uv-http-timeoutTimeout for requests that UV makes (only used in case of UV builds).(INTEGER RANGE) +--use-uv/--no-use-uvUse uv instead of pip as packaging tool to build the image.[default: use-uv] +--uv-http-timeoutTimeout for requests that UV makes (only used in case of UV builds).(INTEGER RANGE) [default: 300; x>=1]                                                 ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ╭─ Other options ──────────────────────────────────────────────────────────────────────────────────────────────────────╮ ---forward-credentials-fForward local credentials to container when running. ---max-timeMaximum time that the command should take - if it takes longer, the command will fail. +--forward-credentials-fForward local credentials to container when running. +--max-timeMaximum time that the command should take - if it takes longer, the command will fail. (INTEGER RANGE)                                                                        ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ╭─ Common options ─────────────────────────────────────────────────────────────────────────────────────────────────────╮ ---answer-aForce answer to questions.(y | n | q | yes | no | quit) ---dry-run-DIf dry-run is set, commands are only printed, not executed. ---verbose-vPrint verbose information about performed steps. ---help-hShow this message and exit. +--answer-aForce answer to questions.(y | n | q | yes | no | quit) +--dry-run-DIf dry-run is set, commands are only printed, not executed. +--verbose-vPrint verbose information about performed steps. +--help-hShow this message and exit. ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ╭─ Developer commands ─────────────────────────────────────────────────────────────────────────────────────────────────╮ start-airflow          Enter breeze environment and starts all Airflow components in the tmux session. 
Compile     diff --git a/dev/breeze/doc/images/output_setup_config.svg b/dev/breeze/doc/images/output_setup_config.svg index 2fb7fab65273c..69780cb5426af 100644 --- a/dev/breeze/doc/images/output_setup_config.svg +++ b/dev/breeze/doc/images/output_setup_config.svg @@ -137,7 +137,7 @@ attempts to connect to the Airflow database will fail.                          (>sqlite< | mysql | postgres | none)                                            [default: sqlite]                                                               ---postgres-version-PVersion of Postgres used.(>12< | 13 | 14 | 15 | 16)[default: 12] +--postgres-version-PVersion of Postgres used.(>12< | 13 | 14 | 15 | 16 | 17)[default: 12] --mysql-version-MVersion of MySQL used.(>8.0< | 8.4)[default: 8.0] --cheatsheet/--no-cheatsheet-C/-cEnable/disable cheatsheet. --asciiart/--no-asciiart-A/-aEnable/disable ASCIIart. diff --git a/dev/breeze/doc/images/output_setup_config.txt b/dev/breeze/doc/images/output_setup_config.txt index f47fa38e42c7b..97d022c37b5e1 100644 --- a/dev/breeze/doc/images/output_setup_config.txt +++ b/dev/breeze/doc/images/output_setup_config.txt @@ -1 +1 @@ -422c8c524b557fcf5924da4c8590935d +783acef079cbdd31cd7880618c20fae5 diff --git a/dev/breeze/doc/images/output_shell.svg b/dev/breeze/doc/images/output_shell.svg index bf8fbc3ee81f6..1e86993b2c466 100644 --- a/dev/breeze/doc/images/output_shell.svg +++ b/dev/breeze/doc/images/output_shell.svg @@ -573,7 +573,7 @@ the Airflow database will fail.                                                              (>sqlite< | mysql | postgres | none)                                                         [default: sqlite]                                                                            ---postgres-version-PVersion of Postgres used.(>12< | 13 | 14 | 15 | 16)[default: 12] +--postgres-version-PVersion of Postgres used.(>12< | 13 | 14 | 15 | 16 | 17)[default: 12] --mysql-version-MVersion of MySQL used.(>8.0< | 8.4)[default: 8.0] --db-reset-dReset DB when entering the container. ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ @@ -620,75 +620,75 @@ --airflow-skip-constraintsDo not use constraints when installing airflow. --clean-airflow-installationClean the airflow installation before installing version specified by --use-airflow-version.                      ---force-lowest-dependenciesRun tests for the lowest direct dependencies of Airflow  -or selected provider if `Provider[PROVIDER_ID]` is used  -as test type.                                            ---install-airflow-with-constraints/--no-install-airflow…Install airflow in a separate step, with constraints     -determined from package or airflow version.              -[default: install-airflow-with-constraints]              ---install-selected-providersComma-separated list of providers selected to be         -installed (implies --use-packages-from-dist).            -(TEXT)                                                   ---package-formatFormat of packages that should be installed from dist. -(wheel | sdist)                                        -[default: wheel]                                       ---providers-constraints-locationLocation of providers constraints to use (remote URL or  -local context file).                                     -(TEXT)                                                   ---providers-constraints-modeMode of constraints for Providers for CI image building. 
-(constraints-source-providers | constraints |            -constraints-no-providers)                                -[default: constraints-source-providers]                  ---providers-constraints-referenceConstraint reference to use for providers installation   -(used in calculated constraints URL). Can be 'default'   -in which case the default constraints-reference is used. -(TEXT)                                                   ---providers-skip-constraintsDo not use constraints when installing providers. ---test-typeType of test to run. With Providers, you can specify     -tests of which providers should be run:                  -`Providers[airbyte,http]` or excluded from the full test -suite: `Providers[-amazon,google]`                       -(All | Default | API | Always | BranchExternalPython |   -BranchPythonVenv | CLI | Core | ExternalPython |         -Operators | Other | PlainAsserts | Providers |           -PythonVenv | Serialization | WWW | All-Postgres |        -All-MySQL | All-Quarantined)                             -[default: Default]                                       ---use-airflow-versionUse (reinstall at entry) Airflow version from PyPI. It   -can also be version (to install from PyPI), `none`,      -`wheel`, or `sdist` to install from `dist` folder, or    -VCS URL to install from                                  -(https://pip.pypa.io/en/stable/topics/vcs-support/).     -Implies --mount-sources `remove`.                        -(none | wheel | sdist | <airflow_version>)               ---use-packages-from-distInstall all found packages (--package-format determines  -type) from 'dist' folder when entering breeze.           -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ -╭─ Upgrading/downgrading/removing selected packages ───────────────────────────────────────────────────────────────────╮ ---upgrade-botoRemove aiobotocore and upgrade botocore and boto to the latest version. ---downgrade-sqlalchemyDowngrade SQLAlchemy to minimum supported version. ---downgrade-pendulumDowngrade Pendulum to minimum supported version. -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ -╭─ DB test flags ──────────────────────────────────────────────────────────────────────────────────────────────────────╮ ---run-db-tests-onlyOnly runs tests that require a database ---skip-db-testsSkip tests that require a database -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ -╭─ Other options ──────────────────────────────────────────────────────────────────────────────────────────────────────╮ ---forward-credentials-fForward local credentials to container when running. ---max-timeMaximum time that the command should take - if it takes longer, the command will fail. -(INTEGER RANGE)                                                                        ---verbose-commandsShow details of commands executed. 
---keep-env-variablesDo not clear environment variables that might have side effect while running tests ---no-db-cleanupDo not clear the database before each test module -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ -╭─ Common options ─────────────────────────────────────────────────────────────────────────────────────────────────────╮ ---answer-aForce answer to questions.(y | n | q | yes | no | quit) ---dry-run-DIf dry-run is set, commands are only printed, not executed. ---excluded-providersJSON-string of dictionary containing excluded providers per python version ({'3.12':       -['provider']})                                                                             -(TEXT)                                                                                     ---verbose-vPrint verbose information about performed steps. ---help-hShow this message and exit. +--excluded-providersJSON-string of dictionary containing excluded providers  +per python version ({'3.12': ['provider']})              +(TEXT)                                                   +--force-lowest-dependenciesRun tests for the lowest direct dependencies of Airflow  +or selected provider if `Provider[PROVIDER_ID]` is used  +as test type.                                            +--install-airflow-with-constraints/--no-install-airflow…Install airflow in a separate step, with constraints     +determined from package or airflow version.              +[default: install-airflow-with-constraints]              +--install-selected-providersComma-separated list of providers selected to be         +installed (implies --use-packages-from-dist).            +(TEXT)                                                   +--package-formatFormat of packages that should be installed from dist. +(wheel | sdist)                                        +[default: wheel]                                       +--providers-constraints-locationLocation of providers constraints to use (remote URL or  +local context file).                                     +(TEXT)                                                   +--providers-constraints-modeMode of constraints for Providers for CI image building. +(constraints-source-providers | constraints |            +constraints-no-providers)                                +[default: constraints-source-providers]                  +--providers-constraints-referenceConstraint reference to use for providers installation   +(used in calculated constraints URL). Can be 'default'   +in which case the default constraints-reference is used. +(TEXT)                                                   +--providers-skip-constraintsDo not use constraints when installing providers. +--test-typeType of test to run. With Providers, you can specify     +tests of which providers should be run:                  +`Providers[airbyte,http]` or excluded from the full test +suite: `Providers[-amazon,google]`                       +(All | Default | API | Always | BranchExternalPython |   +BranchPythonVenv | CLI | Core | ExternalPython |         +Operators | Other | PlainAsserts | Providers |           +PythonVenv | Serialization | WWW | All-Postgres |        +All-MySQL | All-Quarantined)                             +[default: Default]                                       +--use-airflow-versionUse (reinstall at entry) Airflow version from PyPI. 
It   +can also be version (to install from PyPI), `none`,      +`wheel`, or `sdist` to install from `dist` folder, or    +VCS URL to install from                                  +(https://pip.pypa.io/en/stable/topics/vcs-support/).     +Implies --mount-sources `remove`.                        +(none | wheel | sdist | <airflow_version>)               +--use-packages-from-distInstall all found packages (--package-format determines  +type) from 'dist' folder when entering breeze.           +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─ Upgrading/downgrading/removing selected packages ───────────────────────────────────────────────────────────────────╮ +--upgrade-botoRemove aiobotocore and upgrade botocore and boto to the latest version. +--downgrade-sqlalchemyDowngrade SQLAlchemy to minimum supported version. +--downgrade-pendulumDowngrade Pendulum to minimum supported version. +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─ DB test flags ──────────────────────────────────────────────────────────────────────────────────────────────────────╮ +--run-db-tests-onlyOnly runs tests that require a database +--skip-db-testsSkip tests that require a database +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─ Other options ──────────────────────────────────────────────────────────────────────────────────────────────────────╮ +--forward-credentials-fForward local credentials to container when running. +--max-timeMaximum time that the command should take - if it takes longer, the command will fail. +(INTEGER RANGE)                                                                        +--verbose-commandsShow details of commands executed. +--keep-env-variablesDo not clear environment variables that might have side effect while running tests +--no-db-cleanupDo not clear the database before each test module +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─ Common options ─────────────────────────────────────────────────────────────────────────────────────────────────────╮ +--answer-aForce answer to questions.(y | n | q | yes | no | quit) +--dry-run-DIf dry-run is set, commands are only printed, not executed. +--verbose-vPrint verbose information about performed steps. +--help-hShow this message and exit. ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ diff --git a/dev/breeze/doc/images/output_shell.txt b/dev/breeze/doc/images/output_shell.txt index 71be1f4ed5fa6..8529a31200925 100644 --- a/dev/breeze/doc/images/output_shell.txt +++ b/dev/breeze/doc/images/output_shell.txt @@ -1 +1 @@ -4d7e652e8a79290f5ca783e94662ada1 +12f9e4a84051e05a5e0b9ec4fe3c8632 diff --git a/dev/breeze/doc/images/output_start-airflow.svg b/dev/breeze/doc/images/output_start-airflow.svg index 377745370a5b4..55bac0da8cd10 100644 --- a/dev/breeze/doc/images/output_start-airflow.svg +++ b/dev/breeze/doc/images/output_start-airflow.svg @@ -432,7 +432,7 @@ the Airflow database will fail.                                                              
(>sqlite< | mysql | postgres | none)                                                         [default: sqlite]                                                                            ---postgres-version-PVersion of Postgres used.(>12< | 13 | 14 | 15 | 16)[default: 12] +--postgres-version-PVersion of Postgres used.(>12< | 13 | 14 | 15 | 16 | 17)[default: 12] --mysql-version-MVersion of MySQL used.(>8.0< | 8.4)[default: 8.0] --db-reset-dReset DB when entering the container. ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ diff --git a/dev/breeze/doc/images/output_start-airflow.txt b/dev/breeze/doc/images/output_start-airflow.txt index 6399738d71e5d..3aaa148b8ca83 100644 --- a/dev/breeze/doc/images/output_start-airflow.txt +++ b/dev/breeze/doc/images/output_start-airflow.txt @@ -1 +1 @@ -74f2c1895c08408a8caa90eaf96f98cf +f6365a250b86242436df9236025d447f diff --git a/dev/breeze/doc/images/output_testing_db-tests.svg b/dev/breeze/doc/images/output_testing_db-tests.svg index 916e8c6005c29..d9f6d92eef100 100644 --- a/dev/breeze/doc/images/output_testing_db-tests.svg +++ b/dev/breeze/doc/images/output_testing_db-tests.svg @@ -1,4 +1,4 @@ - + --python-pPython major/minor version used in Airflow image for images. (>3.8< | 3.9 | 3.10 | 3.11 | 3.12)                           [default: 3.8]                                               ---postgres-version-PVersion of Postgres used.(>12< | 13 | 14 | 15 | 16)[default: 12] ---mysql-version-MVersion of MySQL used.(>8.0< | 8.4)[default: 8.0] ---forward-credentials-fForward local credentials to container when running. ---force-sa-warnings/--no-force-sa-warningsEnable `sqlalchemy.exc.MovedIn20Warning` during the tests runs. -[default: force-sa-warnings]                                    -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ -╭─ Options for parallel test commands ─────────────────────────────────────────────────────────────────────────────────╮ ---parallelismMaximum number of processes to use while running the operation in parallel. -(INTEGER RANGE)                                                             -[default: 4; 1<=x<=8]                                                       ---skip-cleanupSkip cleanup of temporary files created during parallel run. ---debug-resourcesWhether to show resource information while running in parallel. ---include-success-outputsWhether to include outputs of successful parallel runs (skipped by default). -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ -╭─ Upgrading/downgrading/removing selected packages ───────────────────────────────────────────────────────────────────╮ ---upgrade-botoRemove aiobotocore and upgrade botocore and boto to the latest version. ---downgrade-sqlalchemyDowngrade SQLAlchemy to minimum supported version. ---downgrade-pendulumDowngrade Pendulum to minimum supported version. ---remove-arm-packagesRemoves arm packages from the image to test if ARM collection works -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ -╭─ Advanced flag for tests command ────────────────────────────────────────────────────────────────────────────────────╮ ---airflow-constraints-referenceConstraint reference to use for airflow installation    -(used in calculated constraints URL).                   
-(TEXT)                                                  ---clean-airflow-installationClean the airflow installation before installing        -version specified by --use-airflow-version.             ---excluded-providersJSON-string of dictionary containing excluded providers -per python version ({'3.12': ['provider']})             -(TEXT)                                                  ---force-lowest-dependenciesRun tests for the lowest direct dependencies of Airflow -or selected provider if `Provider[PROVIDER_ID]` is used -as test type.                                           ---github-repository-gGitHub repository used to pull, push run images.(TEXT) -[default: apache/airflow]                        ---image-tagTag of the image which is used to run the image         -(implies --mount-sources=skip).                         -(TEXT)                                                  -[default: latest]                                       ---install-airflow-with-constraints/--no-install-airflo…Install airflow in a separate step, with constraints    -determined from package or airflow version.             -[default: no-install-airflow-with-constraints]          ---package-formatFormat of packages.(wheel | sdist | both) -[default: wheel]    ---providers-constraints-locationLocation of providers constraints to use (remote URL or -local context file).                                    -(TEXT)                                                  ---providers-skip-constraintsDo not use constraints when installing providers. ---use-airflow-versionUse (reinstall at entry) Airflow version from PyPI. It  -can also be version (to install from PyPI), `none`,     -`wheel`, or `sdist` to install from `dist` folder, or   -VCS URL to install from                                 -(https://pip.pypa.io/en/stable/topics/vcs-support/).    -Implies --mount-sources `remove`.                       -(none | wheel | sdist | <airflow_version>)              ---use-packages-from-distInstall all found packages (--package-format determines -type) from 'dist' folder when entering breeze.          ---mount-sourcesChoose scope of local sources that should be mounted,   -skipped, or removed (default = selected).               -(selected | all | skip | remove | tests |               -providers-and-tests)                                    -[default: selected]                                     ---skip-docker-compose-downSkips running docker-compose down after tests ---skip-providersSpace-separated list of provider ids to skip when       -running tests                                           -(TEXT)                                                  ---keep-env-variablesDo not clear environment variables that might have side -effect while running tests                              ---no-db-cleanupDo not clear the database before each test module -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ -╭─ Common options ─────────────────────────────────────────────────────────────────────────────────────────────────────╮ ---dry-run-DIf dry-run is set, commands are only printed, not executed. ---verbose-vPrint verbose information about performed steps. ---help-hShow this message and exit. 
-╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +--postgres-version-PVersion of Postgres used.(>12< | 13 | 14 | 15 | 16 | 17) +[default: 12]             +--mysql-version-MVersion of MySQL used.(>8.0< | 8.4)[default: 8.0] +--forward-credentials-fForward local credentials to container when running. +--force-sa-warnings/--no-force-sa-warningsEnable `sqlalchemy.exc.MovedIn20Warning` during the tests runs. +[default: force-sa-warnings]                                    +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─ Options for parallel test commands ─────────────────────────────────────────────────────────────────────────────────╮ +--parallelismMaximum number of processes to use while running the operation in parallel. +(INTEGER RANGE)                                                             +[default: 4; 1<=x<=8]                                                       +--skip-cleanupSkip cleanup of temporary files created during parallel run. +--debug-resourcesWhether to show resource information while running in parallel. +--include-success-outputsWhether to include outputs of successful parallel runs (skipped by default). +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─ Upgrading/downgrading/removing selected packages ───────────────────────────────────────────────────────────────────╮ +--upgrade-botoRemove aiobotocore and upgrade botocore and boto to the latest version. +--downgrade-sqlalchemyDowngrade SQLAlchemy to minimum supported version. +--downgrade-pendulumDowngrade Pendulum to minimum supported version. +--remove-arm-packagesRemoves arm packages from the image to test if ARM collection works +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─ Advanced flag for tests command ────────────────────────────────────────────────────────────────────────────────────╮ +--airflow-constraints-referenceConstraint reference to use for airflow installation    +(used in calculated constraints URL).                   +(TEXT)                                                  +--clean-airflow-installationClean the airflow installation before installing        +version specified by --use-airflow-version.             +--excluded-providersJSON-string of dictionary containing excluded providers +per python version ({'3.12': ['provider']})             +(TEXT)                                                  +--force-lowest-dependenciesRun tests for the lowest direct dependencies of Airflow +or selected provider if `Provider[PROVIDER_ID]` is used +as test type.                                           +--github-repository-gGitHub repository used to pull, push run images.(TEXT) +[default: apache/airflow]                        +--image-tagTag of the image which is used to run the image         +(implies --mount-sources=skip).                         +(TEXT)                                                  +[default: latest]                                       +--install-airflow-with-constraints/--no-install-airflo…Install airflow in a separate step, with constraints    +determined from package or airflow version.             
+[default: no-install-airflow-with-constraints]          +--package-formatFormat of packages.(wheel | sdist | both) +[default: wheel]    +--providers-constraints-locationLocation of providers constraints to use (remote URL or +local context file).                                    +(TEXT)                                                  +--providers-skip-constraintsDo not use constraints when installing providers. +--use-airflow-versionUse (reinstall at entry) Airflow version from PyPI. It  +can also be version (to install from PyPI), `none`,     +`wheel`, or `sdist` to install from `dist` folder, or   +VCS URL to install from                                 +(https://pip.pypa.io/en/stable/topics/vcs-support/).    +Implies --mount-sources `remove`.                       +(none | wheel | sdist | <airflow_version>)              +--use-packages-from-distInstall all found packages (--package-format determines +type) from 'dist' folder when entering breeze.          +--mount-sourcesChoose scope of local sources that should be mounted,   +skipped, or removed (default = selected).               +(selected | all | skip | remove | tests |               +providers-and-tests)                                    +[default: selected]                                     +--skip-docker-compose-downSkips running docker-compose down after tests +--skip-providersSpace-separated list of provider ids to skip when       +running tests                                           +(TEXT)                                                  +--keep-env-variablesDo not clear environment variables that might have side +effect while running tests                              +--no-db-cleanupDo not clear the database before each test module +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─ Common options ─────────────────────────────────────────────────────────────────────────────────────────────────────╮ +--dry-run-DIf dry-run is set, commands are only printed, not executed. +--verbose-vPrint verbose information about performed steps. +--help-hShow this message and exit. +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ diff --git a/dev/breeze/doc/images/output_testing_db-tests.txt b/dev/breeze/doc/images/output_testing_db-tests.txt index 41be88099eb91..683eef1d0509b 100644 --- a/dev/breeze/doc/images/output_testing_db-tests.txt +++ b/dev/breeze/doc/images/output_testing_db-tests.txt @@ -1 +1 @@ -97cc799ed5c1244b2aeb680f3215021b +5a6490989a911c538b427ca2806d0e3e diff --git a/dev/breeze/doc/images/output_testing_integration-tests.svg b/dev/breeze/doc/images/output_testing_integration-tests.svg index ac4ce00627101..4f86fcef38b29 100644 --- a/dev/breeze/doc/images/output_testing_integration-tests.svg +++ b/dev/breeze/doc/images/output_testing_integration-tests.svg @@ -1,4 +1,4 @@ - +
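The regenerated Breeze help screenshots above only widen the list of allowed values. As a usage sketch (assuming a working Breeze setup), the new version is selected with the existing flags shown in that help output, e.g. breeze shell --backend postgres --postgres-version 17; the default remains 12.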

[Truncated diff fragment; the file header was lost in extraction. The visible lines are from a Jinja template for an "Edge Worker Hosts" page, adding an empty-state branch: ``{% if hosts|length == 0 %}`` around the message "No Edge Workers connected or known currently.", followed by ``{% else %}``.]
- {% if dag.dag_id in dataset_triggered_next_run_info %} - {%- with ds_info = dataset_triggered_next_run_info[dag.dag_id] -%} + + {% if dag.dag_id in asset_triggered_next_run_info %} + {%- with asset_info = asset_triggered_next_run_info[dag.dag_id] -%}
- {% if ds_info.total == 1 -%} - On {{ ds_info.uri[0:40] + '…' if ds_info.uri and ds_info.uri|length > 40 else ds_info.uri|default('', true) }} + {% if asset_info.total == 1 -%} + On {{ asset_info.uri[0:40] + '…' if asset_info.uri and asset_info.uri|length > 40 else asset_info.uri|default('', true) }} {%- else -%} - {{ ds_info.ready }} of {{ ds_info.total }} datasets updated + {{ asset_info.ready }} of {{ asset_info.total }} datasets updated {%- endif %}
diff --git a/airflow/www/views.py b/airflow/www/views.py index 0ef37f71d336c..b3300b517e757 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -87,10 +87,10 @@ set_dag_run_state_to_success, set_state, ) +from airflow.assets import Asset, AssetAlias from airflow.auth.managers.models.resource_details import AccessView, DagAccessEntity, DagDetails from airflow.compat.functools import cache from airflow.configuration import AIRFLOW_CONFIG, conf -from airflow.datasets import Dataset, DatasetAlias from airflow.exceptions import ( AirflowConfigException, AirflowException, @@ -104,9 +104,9 @@ from airflow.jobs.scheduler_job_runner import SchedulerJobRunner from airflow.jobs.triggerer_job_runner import TriggererJobRunner from airflow.models import Connection, DagModel, DagTag, Log, SlaMiss, Trigger, XCom -from airflow.models.dag import get_dataset_triggered_next_run_info +from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel, DagScheduleAssetReference +from airflow.models.dag import get_asset_triggered_next_run_info from airflow.models.dagrun import RUN_ID_REGEX, DagRun, DagRunType -from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel from airflow.models.errors import ParseImportError from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance, TaskInstanceNote @@ -469,9 +469,7 @@ def set_overall_state(record): "label": item.label, "extra_links": item.extra_links, "is_mapped": item_is_mapped, - "has_outlet_datasets": any( - isinstance(i, (Dataset, DatasetAlias)) for i in (item.outlets or []) - ), + "has_outlet_datasets": any(isinstance(i, (Asset, AssetAlias)) for i in (item.outlets or [])), "operator": item.operator_name, "trigger_rule": item.trigger_rule, **setup_teardown_type, @@ -1005,13 +1003,13 @@ def index(self): .all() ) - dataset_triggered_dag_ids = {dag.dag_id for dag in dags if dag.dataset_expression is not None} - if dataset_triggered_dag_ids: - dataset_triggered_next_run_info = get_dataset_triggered_next_run_info( - dataset_triggered_dag_ids, session=session + asset_triggered_dag_ids = {dag.dag_id for dag in dags if dag.dataset_expression is not None} + if asset_triggered_dag_ids: + asset_triggered_next_run_info = get_asset_triggered_next_run_info( + asset_triggered_dag_ids, session=session ) else: - dataset_triggered_next_run_info = {} + asset_triggered_next_run_info = {} file_tokens = {} for dag in dags: @@ -1168,15 +1166,15 @@ def _iter_parsed_moved_data_table_names(): sorting_key=arg_sorting_key, sorting_direction=arg_sorting_direction, auto_refresh_interval=conf.getint("webserver", "auto_refresh_interval"), - dataset_triggered_next_run_info=dataset_triggered_next_run_info, + asset_triggered_next_run_info=asset_triggered_next_run_info, scarf_url=scarf_url, file_tokens=file_tokens, ) @expose("/datasets") - @auth.has_access_dataset("GET") + @auth.has_access_asset("GET") def datasets(self): - """Datasets view.""" + """Assets view.""" state_color_mapping = State.state_color.copy() state_color_mapping["null"] = state_color_mapping.pop(None) return self.render_template( @@ -1222,11 +1220,11 @@ def next_run_datasets_summary(self, session: Session = NEW_SESSION): .where(DagModel.dataset_expression.is_not(None)) ).all() - dataset_triggered_next_run_info = get_dataset_triggered_next_run_info( + asset_triggered_next_run_info = get_asset_triggered_next_run_info( dataset_triggered_dag_ids, session=session ) - return 
flask.json.jsonify(dataset_triggered_next_run_info) + return flask.json.jsonify(asset_triggered_next_run_info) @expose("/dag_stats", methods=["POST"]) @auth.has_access_dag("GET", DagAccessEntity.RUN) @@ -3407,7 +3405,7 @@ def historical_metrics_data(self): @expose("/object/next_run_datasets/") @auth.has_access_dag("GET", DagAccessEntity.RUN) - @auth.has_access_dataset("GET") + @auth.has_access_asset("GET") @mark_fastapi_migration_done def next_run_datasets(self, dag_id): """Return datasets necessary, and their status, for the next dag run.""" @@ -3424,36 +3422,34 @@ def next_run_datasets(self, dag_id): dict(info._mapping) for info in session.execute( select( - DatasetModel.id, - DatasetModel.uri, - func.max(DatasetEvent.timestamp).label("lastUpdate"), - ) - .join( - DagScheduleDatasetReference, DagScheduleDatasetReference.dataset_id == DatasetModel.id + AssetModel.id, + AssetModel.uri, + func.max(AssetEvent.timestamp).label("lastUpdate"), ) + .join(DagScheduleAssetReference, DagScheduleAssetReference.dataset_id == AssetModel.id) .join( - DatasetDagRunQueue, + AssetDagRunQueue, and_( - DatasetDagRunQueue.dataset_id == DatasetModel.id, - DatasetDagRunQueue.target_dag_id == DagScheduleDatasetReference.dag_id, + AssetDagRunQueue.dataset_id == AssetModel.id, + AssetDagRunQueue.target_dag_id == DagScheduleAssetReference.dag_id, ), isouter=True, ) .join( - DatasetEvent, + AssetEvent, and_( - DatasetEvent.dataset_id == DatasetModel.id, + AssetEvent.dataset_id == AssetModel.id, ( - DatasetEvent.timestamp >= latest_run.execution_date + AssetEvent.timestamp >= latest_run.execution_date if latest_run and latest_run.execution_date else True ), ), isouter=True, ) - .where(DagScheduleDatasetReference.dag_id == dag_id, ~DatasetModel.is_orphaned) - .group_by(DatasetModel.id, DatasetModel.uri) - .order_by(DatasetModel.uri) + .where(DagScheduleAssetReference.dag_id == dag_id, ~AssetModel.is_orphaned) + .group_by(AssetModel.id, AssetModel.uri) + .order_by(AssetModel.uri) ) ] data = {"dataset_expression": dag_model.dataset_expression, "events": events} @@ -3473,7 +3469,7 @@ def dataset_dependencies(self): dag_node_id = f"dag:{dag}" if dag_node_id not in nodes_dict: for dep in dependencies: - if dep.dependency_type in ("dag", "dataset", "dataset-alias"): + if dep.dependency_type in ("dag", "asset", "asset-alias"): # add node nodes_dict[dag_node_id] = node_dict(dag_node_id, dag, "dag") if dep.node_id not in nodes_dict: @@ -3509,7 +3505,7 @@ def dataset_dependencies(self): ) @expose("/object/datasets_summary") - @auth.has_access_dataset("GET") + @auth.has_access_asset("GET") def datasets_summary(self): """ Get a summary of datasets. 
@@ -3543,54 +3539,54 @@ def datasets_summary(self): with create_session() as session: if lstripped_orderby == "uri": if order_by.startswith("-"): - order_by = (DatasetModel.uri.desc(),) + order_by = (AssetModel.uri.desc(),) else: - order_by = (DatasetModel.uri.asc(),) + order_by = (AssetModel.uri.asc(),) elif lstripped_orderby == "last_dataset_update": if order_by.startswith("-"): order_by = ( - func.max(DatasetEvent.timestamp).desc(), - DatasetModel.uri.asc(), + func.max(AssetEvent.timestamp).desc(), + AssetModel.uri.asc(), ) if session.bind.dialect.name == "postgresql": order_by = (order_by[0].nulls_last(), *order_by[1:]) else: order_by = ( - func.max(DatasetEvent.timestamp).asc(), - DatasetModel.uri.desc(), + func.max(AssetEvent.timestamp).asc(), + AssetModel.uri.desc(), ) if session.bind.dialect.name == "postgresql": order_by = (order_by[0].nulls_first(), *order_by[1:]) - count_query = select(func.count(DatasetModel.id)) + count_query = select(func.count(AssetModel.id)) has_event_filters = bool(updated_before or updated_after) query = ( select( - DatasetModel.id, - DatasetModel.uri, - func.max(DatasetEvent.timestamp).label("last_dataset_update"), - func.sum(case((DatasetEvent.id.is_not(None), 1), else_=0)).label("total_updates"), + AssetModel.id, + AssetModel.uri, + func.max(AssetEvent.timestamp).label("last_dataset_update"), + func.sum(case((AssetEvent.id.is_not(None), 1), else_=0)).label("total_updates"), ) - .join(DatasetEvent, DatasetEvent.dataset_id == DatasetModel.id, isouter=not has_event_filters) + .join(AssetEvent, AssetEvent.dataset_id == AssetModel.id, isouter=not has_event_filters) .group_by( - DatasetModel.id, - DatasetModel.uri, + AssetModel.id, + AssetModel.uri, ) .order_by(*order_by) ) if has_event_filters: - count_query = count_query.join(DatasetEvent, DatasetEvent.dataset_id == DatasetModel.id) + count_query = count_query.join(AssetEvent, AssetEvent.dataset_id == AssetModel.id) - filters = [~DatasetModel.is_orphaned] + filters = [~AssetModel.is_orphaned] if uri_pattern: - filters.append(DatasetModel.uri.ilike(f"%{uri_pattern}%")) + filters.append(AssetModel.uri.ilike(f"%{uri_pattern}%")) if updated_after: - filters.append(DatasetEvent.timestamp >= updated_after) + filters.append(AssetEvent.timestamp >= updated_after) if updated_before: - filters.append(DatasetEvent.timestamp <= updated_before) + filters.append(AssetEvent.timestamp <= updated_before) query = query.where(*filters).offset(offset).limit(limit) count_query = count_query.where(*filters) diff --git a/dev/breeze/tests/test_packages.py b/dev/breeze/tests/test_packages.py index 9556ae695be8e..39ee245cf36bd 100644 --- a/dev/breeze/tests/test_packages.py +++ b/dev/breeze/tests/test_packages.py @@ -165,6 +165,7 @@ def test_get_documentation_package_path(): "fab", "", """ + "apache-airflow-providers-common-compat>=1.2.0", "apache-airflow>=2.9.0", "flask-appbuilder==4.5.0", "flask-login>=0.6.2", @@ -178,6 +179,7 @@ def test_get_documentation_package_path(): "fab", "dev0", """ + "apache-airflow-providers-common-compat>=1.2.0.dev0", "apache-airflow>=2.9.0.dev0", "flask-appbuilder==4.5.0", "flask-login>=0.6.2", @@ -191,6 +193,7 @@ def test_get_documentation_package_path(): "fab", "beta0", """ + "apache-airflow-providers-common-compat>=1.2.0b0", "apache-airflow>=2.9.0b0", "flask-appbuilder==4.5.0", "flask-login>=0.6.2", diff --git a/dev/breeze/tests/test_pytest_args_for_test_types.py b/dev/breeze/tests/test_pytest_args_for_test_types.py index 36a4b157794d2..7ecbbf4b5bf3c 100644 --- 
a/dev/breeze/tests/test_pytest_args_for_test_types.py +++ b/dev/breeze/tests/test_pytest_args_for_test_types.py @@ -151,13 +151,13 @@ ( "Other", [ + "tests/assets", "tests/auth", "tests/callbacks", "tests/charts", "tests/cluster_policies", "tests/config_templates", "tests/dag_processing", - "tests/datasets", "tests/decorators", "tests/hooks", "tests/io", diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py index 6161c44f6ebf3..4483cae573359 100644 --- a/dev/breeze/tests/test_selective_checks.py +++ b/dev/breeze/tests/test_selective_checks.py @@ -136,7 +136,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str): pytest.param( ("airflow/api/file.py",), { - "affected-providers-list-as-string": "fab", + "affected-providers-list-as-string": "common.compat fab", "all-python-versions": "['3.8']", "all-python-versions-list-as-string": "3.8", "python-versions": "['3.8']", @@ -150,13 +150,13 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str): "skip-pre-commits": "check-provider-yaml-valid,identity,lint-helm-chart,mypy-airflow,mypy-dev," "mypy-docs,mypy-providers,ts-compile-format-lint-ui,ts-compile-format-lint-www", "upgrade-to-newer-dependencies": "false", - "parallel-test-types-list-as-string": "API Always Providers[fab]", - "providers-test-types-list-as-string": "Providers[fab]", - "separate-test-types-list-as-string": "API Always Providers[fab]", + "parallel-test-types-list-as-string": "API Always Providers[common.compat,fab]", + "providers-test-types-list-as-string": "Providers[common.compat,fab]", + "separate-test-types-list-as-string": "API Always Providers[common.compat] Providers[fab]", "needs-mypy": "true", "mypy-folders": "['airflow']", }, - id="Only API tests and DOCS and FAB provider should run", + id="Only API tests and DOCS and common.compat, FAB providers should run", ) ), ( @@ -324,7 +324,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str): "tests/providers/postgres/file.py", ), { - "affected-providers-list-as-string": "amazon common.sql fab google openlineage " + "affected-providers-list-as-string": "amazon common.compat common.sql fab google openlineage " "pgvector postgres", "all-python-versions": "['3.8']", "all-python-versions-list-as-string": "3.8", @@ -340,10 +340,10 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str): "ts-compile-format-lint-ui,ts-compile-format-lint-www", "upgrade-to-newer-dependencies": "false", "parallel-test-types-list-as-string": "API Always Providers[amazon] " - "Providers[common.sql,fab,openlineage,pgvector,postgres] Providers[google]", + "Providers[common.compat,common.sql,fab,openlineage,pgvector,postgres] Providers[google]", "providers-test-types-list-as-string": "Providers[amazon] " - "Providers[common.sql,fab,openlineage,pgvector,postgres] Providers[google]", - "separate-test-types-list-as-string": "API Always Providers[amazon] Providers[common.sql] " + "Providers[common.compat,common.sql,fab,openlineage,pgvector,postgres] Providers[google]", + "separate-test-types-list-as-string": "API Always Providers[amazon] Providers[common.compat] Providers[common.sql] " "Providers[fab] Providers[google] Providers[openlineage] Providers[pgvector] " "Providers[postgres]", "needs-mypy": "true", @@ -1390,7 +1390,7 @@ def test_expected_output_pull_request_v2_7( "airflow/api/file.py", ), { - "affected-providers-list-as-string": "fab", + "affected-providers-list-as-string": "common.compat fab", 
"all-python-versions": "['3.8']", "all-python-versions-list-as-string": "3.8", "ci-image-build": "true", @@ -1398,17 +1398,17 @@ def test_expected_output_pull_request_v2_7( "needs-helm-tests": "false", "run-tests": "true", "docs-build": "true", - "docs-list-as-string": "apache-airflow fab", + "docs-list-as-string": "apache-airflow common.compat fab", "skip-pre-commits": "check-provider-yaml-valid,identity,lint-helm-chart,mypy-airflow,mypy-dev,mypy-docs,mypy-providers," "ts-compile-format-lint-ui,ts-compile-format-lint-www", "run-kubernetes-tests": "false", "upgrade-to-newer-dependencies": "false", "skip-provider-tests": "false", - "parallel-test-types-list-as-string": "API Always CLI Operators Providers[fab] WWW", + "parallel-test-types-list-as-string": "API Always CLI Operators Providers[common.compat,fab] WWW", "needs-mypy": "true", "mypy-folders": "['airflow']", }, - id="No providers tests except fab should run if only CLI/API/Operators/WWW file changed", + id="No providers tests except common.compat fab should run if only CLI/API/Operators/WWW file changed", ), pytest.param( ("airflow/models/test.py",), diff --git a/docs/apache-airflow-providers-amazon/auth-manager/manage/index.rst b/docs/apache-airflow-providers-amazon/auth-manager/manage/index.rst index 0a540b8d32b4a..359f3cfff040c 100644 --- a/docs/apache-airflow-providers-amazon/auth-manager/manage/index.rst +++ b/docs/apache-airflow-providers-amazon/auth-manager/manage/index.rst @@ -164,7 +164,7 @@ This is equivalent to the :doc:`Viewer role in Flask AppBuilder ` for details on how dataset URIs work. +See :doc:`documentation on assets ` for details on how asset URIs work. -.. airflow-dataset-schemes:: +.. airflow-asset-schemes:: :tags: None :header-separator: " diff --git a/docs/apache-airflow-providers/howto/create-custom-providers.rst b/docs/apache-airflow-providers/howto/create-custom-providers.rst index 70588d4532b4d..d95719e38bc6c 100644 --- a/docs/apache-airflow-providers/howto/create-custom-providers.rst +++ b/docs/apache-airflow-providers/howto/create-custom-providers.rst @@ -96,9 +96,9 @@ Exposing customized functionality to the Airflow's core: * ``filesystems`` - this field should contain the list of all the filesystem module names. See :doc:`apache-airflow:core-concepts/objectstorage` for description of the filesystems. -* ``dataset-uris`` - this field should contain the list of the URI schemes together with +* ``asset-uris`` - this field should contain the list of the URI schemes together with class names implementing normalization functions. - See :doc:`apache-airflow:authoring-and-scheduling/datasets` for description of the dataset URIs. + See :doc:`apache-airflow:authoring-and-scheduling/assets` for description of the asset URIs. .. note:: Deprecated values diff --git a/docs/apache-airflow/administration-and-deployment/lineage.rst b/docs/apache-airflow/administration-and-deployment/lineage.rst index de20dd8f1d802..d2ef63d869755 100644 --- a/docs/apache-airflow/administration-and-deployment/lineage.rst +++ b/docs/apache-airflow/administration-and-deployment/lineage.rst @@ -96,8 +96,8 @@ Airflow provides a powerful feature for tracking data lineage not only between t This functionality helps you understand how data flows throughout your Airflow pipelines. A global instance of ``HookLineageCollector`` serves as the central hub for collecting lineage information. -Hooks can send details about datasets they interact with to this collector. 
-The collector then uses this data to construct AIP-60 compliant Datasets, a standard format for describing datasets. +Hooks can send details about assets they interact with to this collector. +The collector then uses this data to construct AIP-60 compliant Assets, a standard format for describing assets. .. code-block:: python @@ -108,8 +108,8 @@ The collector then uses this data to construct AIP-60 compliant Datasets, a stan def run(self): # run actual code collector = get_hook_lineage_collector() - collector.add_input_dataset(self, dataset_kwargs={"scheme": "file", "path": "/tmp/in"}) - collector.add_output_dataset(self, dataset_kwargs={"scheme": "file", "path": "/tmp/out"}) + collector.add_input_asset(self, asset_kwargs={"scheme": "file", "path": "/tmp/in"}) + collector.add_output_asset(self, asset_kwargs={"scheme": "file", "path": "/tmp/out"}) Lineage data collected by the ``HookLineageCollector`` can be accessed using an instance of ``HookLineageReader``, which is registered in an Airflow plugin. @@ -122,7 +122,7 @@ which is registered in an Airflow plugin. class CustomHookLineageReader(HookLineageReader): def get_inputs(self): - return self.lineage_collector.collected_datasets.inputs + return self.lineage_collector.collected_assets.inputs class HookLineageCollectionPlugin(AirflowPlugin): @@ -130,7 +130,7 @@ which is registered in an Airflow plugin. hook_lineage_readers = [CustomHookLineageReader] If no ``HookLineageReader`` is registered within Airflow, a default ``NoOpCollector`` is used instead. -This collector does not create AIP-60 compliant datasets or collect lineage information. +This collector does not create AIP-60 compliant assets or collect lineage information. Lineage Backend diff --git a/docs/apache-airflow/administration-and-deployment/listeners.rst b/docs/apache-airflow/administration-and-deployment/listeners.rst index 4926b12ed6c6d..1fca915a6f1df 100644 --- a/docs/apache-airflow/administration-and-deployment/listeners.rst +++ b/docs/apache-airflow/administration-and-deployment/listeners.rst @@ -91,14 +91,14 @@ You can use these events to react to ``LocalTaskJob`` state changes. :end-before: [END howto_listen_ti_failure_task] -Dataset Events +Asset Events -------------- -- ``on_dataset_created`` +- ``on_asset_created`` - ``on_dataset_alias_created`` -- ``on_dataset_changed`` +- ``on_asset_changed`` -Dataset events occur when Dataset management operations are run. +Asset events occur when Asset management operations are run. 
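A minimal sketch of a listener consuming these hooks may help; it assumes the renamed ``on_asset_created`` / ``on_asset_changed`` hooks keep the single-argument shape of the ``on_dataset_*`` hooks they replace, and the module and plugin names (``asset_listener``, ``AssetListenerPlugin``) are illustrative only.

.. code-block:: python

    # asset_listener.py (illustrative module name)
    from airflow.listeners import hookimpl


    @hookimpl
    def on_asset_created(asset):
        # Runs when a new asset is registered, e.g. the first time it appears as an outlet.
        print(f"asset created: {asset.uri}")


    @hookimpl
    def on_asset_changed(asset):
        # Runs when a task emits an update for the asset.
        print(f"asset changed: {asset.uri}")

The listener module is then registered through a plugin, in the same way as the other listeners described in this document:

.. code-block:: python

    import asset_listener

    from airflow.plugins_manager import AirflowPlugin


    class AssetListenerPlugin(AirflowPlugin):
        name = "asset_listener_plugin"
        listeners = [asset_listener]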
Dag Import Error Events diff --git a/docs/apache-airflow/administration-and-deployment/logging-monitoring/metrics.rst b/docs/apache-airflow/administration-and-deployment/logging-monitoring/metrics.rst index 61985cecea9b0..ac44d1acba9c0 100644 --- a/docs/apache-airflow/administration-and-deployment/logging-monitoring/metrics.rst +++ b/docs/apache-airflow/administration-and-deployment/logging-monitoring/metrics.rst @@ -201,10 +201,10 @@ Name Descripti fully asynchronous) ``triggers.failed`` Number of triggers that errored before they could fire an event ``triggers.succeeded`` Number of triggers that have fired at least one event -``dataset.updates`` Number of updated datasets -``dataset.orphaned`` Number of datasets marked as orphans because they are no longer referenced in DAG +``asset.updates`` Number of updated assets +``asset.orphaned`` Number of assets marked as orphans because they are no longer referenced in DAG schedule parameters or task outlets -``dataset.triggered_dagruns`` Number of DAG runs triggered by a dataset update +``asset.triggered_dagruns`` Number of DAG runs triggered by an asset update ====================================================================== ================================================================ Gauges diff --git a/docs/apache-airflow/authoring-and-scheduling/assets.rst b/docs/apache-airflow/authoring-and-scheduling/assets.rst new file mode 100644 index 0000000000000..d37143367fabe --- /dev/null +++ b/docs/apache-airflow/authoring-and-scheduling/assets.rst @@ -0,0 +1,532 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Data-aware scheduling +===================== + +.. versionadded:: 2.4 + +Quickstart +---------- + +In addition to scheduling DAGs based on time, you can also schedule DAGs to run based on when a task updates an asset. + +.. code-block:: python + + from airflow.assets import asset + + with DAG(...): + MyOperator( + # this task updates example.csv + outlets=[asset("s3://asset-bucket/example.csv")], + ..., + ) + + + with DAG( + # this DAG should be run when example.csv is updated (by dag1) + schedule=[asset("s3://asset-bucket/example.csv")], + ..., + ): + ... + + +.. image:: /img/asset-scheduled-dags.png + + +What is an "asset"? +-------------------- + +An Airflow asset is a logical grouping of data. Upstream producer tasks can update assets, and asset updates contribute to scheduling downstream consumer DAGs. + +`Uniform Resource Identifier (URI) `_ define assets: + +.. code-block:: python + + from airflow.assets import asset + + example_asset = asset("s3://asset-bucket/example.csv") + +Airflow makes no assumptions about the content or location of the data represented by the URI, and treats the URI like a string.
This means that Airflow treats any regular expressions, like ``input_\d+.csv``, or file glob patterns, such as ``input_2022*.csv``, as an attempt to create multiple assets from one declaration, and they will not work. + +You must create assets with a valid URI. Airflow core and providers define various URI schemes that you can use, such as ``file`` (core), ``postgres`` (by the Postgres provider), and ``s3`` (by the Amazon provider). Third-party providers and plugins might also provide their own schemes. These pre-defined schemes have individual semantics that are expected to be followed. + +What is valid URI? +------------------ + +Technically, the URI must conform to the valid character set in RFC 3986, which is basically ASCII alphanumeric characters, plus ``%``, ``-``, ``_``, ``.``, and ``~``. To identify a resource that cannot be represented by URI-safe characters, encode the resource name with `percent-encoding `_. + +The URI is also case sensitive, so ``s3://example/asset`` and ``s3://Example/asset`` are considered different. Note that the *host* part of the URI is also case sensitive, which differs from RFC 3986. + +Do not use the ``airflow`` scheme, which is reserved for Airflow's internals. + +Airflow always prefers using lower cases in schemes, and case sensitivity is needed in the host part of the URI to correctly distinguish between resources. + +.. code-block:: python + + # invalid assets: + reserved = asset("airflow://example_asset") + not_ascii = asset("èxample_datašet") + +If you want to define assets with a scheme that doesn't include additional semantic constraints, use a scheme with the prefix ``x-``. Airflow skips any semantic validation on URIs with these schemes. + +.. code-block:: python + + # valid asset, treated as a plain string + my_ds = asset("x-my-thing://foobarbaz") + +The identifier does not have to be absolute; it can be a scheme-less, relative URI, or even just a simple path or string: + +.. code-block:: python + + # valid assets: + schemeless = asset("//example/asset") + csv_file = asset("example_asset") + +Non-absolute identifiers are considered plain strings that do not carry any semantic meanings to Airflow. + +Extra information on asset +---------------------------- + +If needed, you can include an extra dictionary in an asset: + +.. code-block:: python + + example_asset = asset( + "s3://asset/example.csv", + extra={"team": "trainees"}, + ) + +This can be used to supply a custom description to the asset, such as who has ownership to the target file, or what the file is for. The extra information does not affect an asset's identity. This means a DAG will be triggered by an asset with an identical URI, even if the extra dict is different: + +.. code-block:: python + + with DAG( + dag_id="consumer", + schedule=[asset("s3://asset/example.csv", extra={"different": "extras"})], + ): + ... + + with DAG(dag_id="producer", ...): + MyOperator( + # triggers "consumer" with the given extra! + outlets=[asset("s3://asset/example.csv", extra={"team": "trainees"})], + ..., + ) + +.. note:: **Security Note:** asset URI and extra fields are not encrypted, they are stored in cleartext in Airflow's metadata database. Do NOT store any sensitive values, especially credentials, in either asset URIs or extra key values! + +How to use assets in your DAGs +-------------------------------- + +You can use assets to specify data dependencies in your DAGs.
The following example shows how after the ``producer`` task in the ``producer`` DAG successfully completes, Airflow schedules the ``consumer`` DAG. Airflow marks an asset as ``updated`` only if the task completes successfully. If the task fails or if it is skipped, no update occurs, and Airflow doesn't schedule the ``consumer`` DAG. + +.. code-block:: python + + example_asset = asset("s3://asset/example.csv") + + with DAG(dag_id="producer", ...): + BashOperator(task_id="producer", outlets=[example_asset], ...) + + with DAG(dag_id="consumer", schedule=[example_asset], ...): + ... + + +You can find a listing of the relationships between assets and DAGs in the +:ref:`assets View` + +Multiple assets +----------------- + +Because the ``schedule`` parameter is a list, DAGs can require multiple assets. Airflow schedules a DAG after **all** assets the DAG consumes have been updated at least once since the last time the DAG ran: + +.. code-block:: python + + with DAG( + dag_id="multiple_assets_example", + schedule=[ + example_asset_1, + example_asset_2, + example_asset_3, + ], + ..., + ): + ... + + +If one asset is updated multiple times before all consumed assets update, the downstream DAG still only runs once, as shown in this illustration: + +.. :: + ASCII art representation of this diagram + + example_asset_1 x----x---x---x----------------------x- + example_asset_2 -------x---x-------x------x----x------ + example_asset_3 ---------------x-----x------x--------- + DAG runs created * * + +.. graphviz:: + + graph asset_event_timeline { + graph [layout=neato] + { + node [margin=0 fontcolor=blue width=0.1 shape=point label=""] + e1 [pos="1,2.5!"] + e2 [pos="2,2.5!"] + e3 [pos="2.5,2!"] + e4 [pos="4,2.5!"] + e5 [pos="5,2!"] + e6 [pos="6,2.5!"] + e7 [pos="7,1.5!"] + r7 [pos="7,1!" shape=star width=0.25 height=0.25 fixedsize=shape] + e8 [pos="8,2!"] + e9 [pos="9,1.5!"] + e10 [pos="10,2!"] + e11 [pos="11,1.5!"] + e12 [pos="12,2!"] + e13 [pos="13,2.5!"] + r13 [pos="13,1!" shape=star width=0.25 height=0.25 fixedsize=shape] + } + { + node [shape=none label="" width=0] + end_ds1 [pos="14,2.5!"] + end_ds2 [pos="14,2!"] + end_ds3 [pos="14,1.5!"] + } + + { + node [shape=none margin=0.25 fontname="roboto,sans-serif"] + example_asset_1 [ pos="-0.5,2.5!"] + example_asset_2 [ pos="-0.5,2!"] + example_asset_3 [ pos="-0.5,1.5!"] + dag_runs [label="DagRuns created" pos="-0.5,1!"] + } + + edge [color=lightgrey] + + example_asset_1 -- e1 -- e2 -- e4 -- e6 -- e13 -- end_ds1 + example_asset_2 -- e3 -- e5 -- e8 -- e10 -- e12 -- end_ds2 + example_asset_3 -- e7 -- e9 -- e11 -- end_ds3 + + } + +Attaching extra information to an emitting asset event +-------------------------------------------------------- + +.. versionadded:: 2.10.0 + +A task with an asset outlet can optionally attach extra information before it emits an asset event. This is different +from `Extra information on asset`_. Extra information on an asset statically describes the entity pointed to by the asset URI; extra information on the *asset event* instead should be used to annotate the triggering data change, such as how many rows in the database are changed by the update, or the date range covered by it. + +The easiest way to attach extra information to the asset event is by ``yield``-ing a ``Metadata`` object from a task: + +.. code-block:: python + + from airflow.assets import asset + from airflow.assets.metadata import Metadata + + example_s3_asset = asset("s3://asset/example.csv") + + + @task(outlets=[example_s3_asset]) + def write_to_s3(): + df = ...
# Get a Pandas DataFrame to write. + # Write df to asset... + yield Metadata(example_s3_asset, {"row_count": len(df)}) + +Airflow automatically collects all yielded metadata, and populates asset events with extra information for corresponding metadata objects. + +This can also be done in classic operators. The best way is to subclass the operator and override ``execute``. Alternatively, extras can also be added in a task's ``pre_execute`` or ``post_execute`` hook. If you choose to use hooks, however, remember that they are not rerun when a task is retried, and may cause the extra information to not match actual data in certain scenarios. + +Another way to achieve the same is by accessing ``outlet_events`` in a task's execution context directly: + +.. code-block:: python + + @task(outlets=[example_s3_asset]) + def write_to_s3(*, outlet_events): + outlet_events[example_s3_asset].extra = {"row_count": len(df)} + +There's minimal magic here---Airflow simply writes the yielded values to the exact same accessor. This also works in classic operators, including ``execute``, ``pre_execute``, and ``post_execute``. + +.. _fetching_information_from_previously_emitted_asset_events: + +Fetching information from previously emitted asset events +----------------------------------------------------------- + +.. versionadded:: 2.10.0 + +Events of an asset defined in a task's ``outlets``, as described in the previous section, can be read by a task that declares the same asset in its ``inlets``. An asset event entry contains ``extra`` (see previous section for details), ``timestamp`` indicating when the event was emitted from a task, and ``source_task_instance`` linking the event back to its source. + +Inlet asset events can be read with the ``inlet_events`` accessor in the execution context. Continuing from the ``write_to_s3`` task in the previous section: + +.. code-block:: python + + @task(inlets=[example_s3_asset]) + def post_process_s3_file(*, inlet_events): + events = inlet_events[example_s3_asset] + last_row_count = events[-1].extra["row_count"] + +Each value in the ``inlet_events`` mapping is a sequence-like object that orders past events of a given asset by ``timestamp``, earliest to latest. It supports most of Python's list interface, so you can use ``[-1]`` to access the last event, ``[-2:]`` for the last two, etc. The accessor is lazy and only hits the database when you access items inside it. + + +Fetching information from a triggering asset event +---------------------------------------------------- + +A triggered DAG can fetch information from the asset that triggered it using the ``triggering_asset_events`` template or parameter. See more at :ref:`templates-ref`. + +Example: + +.. code-block:: python + + example_snowflake_asset = asset("snowflake://my_db/my_schema/my_table") + + with DAG(dag_id="load_snowflake_data", schedule="@hourly", ...): + SQLExecuteQueryOperator( + task_id="load", conn_id="snowflake_default", outlets=[example_snowflake_asset], ...
+ ) + + with DAG(dag_id="query_snowflake_data", schedule=[example_snowflake_asset], ...): + SQLExecuteQueryOperator( + task_id="query", + conn_id="snowflake_default", + sql=""" + SELECT * + FROM my_db.my_schema.my_table + WHERE "updated_at" >= '{{ (triggering_asset_events.values() | first | first).source_dag_run.data_interval_start }}' + AND "updated_at" < '{{ (triggering_asset_events.values() | first | first).source_dag_run.data_interval_end }}'; + """, + ) + + @task + def print_triggering_asset_events(triggering_asset_events=None): + for asset, asset_list in triggering_asset_events.items(): + print(asset, asset_list) + print(asset_list[0].source_dag_run.dag_id) + + print_triggering_asset_events() + +Note that this example is using `(.values() | first | first) `_ to fetch the first of one asset given to the DAG, and the first of one AssetEvent for that asset. An implementation can be quite complex if you have multiple assets, potentially with multiple AssetEvents. + + +Manipulating queued asset events through REST API +--------------------------------------------------- + +.. versionadded:: 2.9 + +In this example, the DAG ``waiting_for_asset_1_and_2`` will be triggered when tasks update both assets "asset-1" and "asset-2". Once "asset-1" is updated, Airflow creates a record. This ensures that Airflow knows to trigger the DAG when "asset-2" is updated. We call such records queued asset events. + +.. code-block:: python + + with DAG( + dag_id="waiting_for_asset_1_and_2", + schedule=[asset("asset-1"), asset("asset-2")], + ..., + ): + ... + + +``queuedEvent`` API endpoints are introduced to manipulate such records. + +* Get a queued asset event for a DAG: ``/assets/queuedEvent/{uri}`` +* Get queued asset events for a DAG: ``/dags/{dag_id}/assets/queuedEvent`` +* Delete a queued asset event for a DAG: ``/assets/queuedEvent/{uri}`` +* Delete queued asset events for a DAG: ``/dags/{dag_id}/assets/queuedEvent`` +* Get queued asset events for an asset: ``/dags/{dag_id}/assets/queuedEvent/{uri}`` +* Delete queued asset events for an asset: ``DELETE /dags/{dag_id}/assets/queuedEvent/{uri}`` + + For how to use REST API and the parameters needed for these endpoints, please refer to :doc:`Airflow API `. + +Advanced asset scheduling with conditional expressions +-------------------------------------------------------- + +Apache Airflow includes advanced scheduling capabilities that use conditional expressions with assets. This feature allows you to define complex dependencies for DAG executions based on asset updates, using logical operators for more control on workflow triggers. + +Logical operators for assets +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Airflow supports two logical operators for combining asset conditions: + +- **AND (``&``)**: Specifies that the DAG should be triggered only after all of the specified assets have been updated. +- **OR (``|``)**: Specifies that the DAG should be triggered when any of the specified assets is updated. + +These operators enable you to configure your Airflow workflows to use more complex asset update conditions, making them more dynamic and flexible. + +Example Use +------------- + +**Scheduling based on multiple asset updates** + +To schedule a DAG to run only when two specific assets have both been updated, use the AND operator (``&``): + +..
code-block:: python + + dag1_asset = asset("s3://dag1/output_1.txt") + dag2_asset = asset("s3://dag2/output_1.txt") + + with DAG( + # Consume asset 1 and 2 with asset expressions + schedule=(dag1_asset & dag2_asset), + ..., + ): + ... + +**Scheduling based on any asset update** + +To trigger a DAG execution when either one of two assets is updated, apply the OR operator (``|``): + +.. code-block:: python + + with DAG( + # Consume asset 1 or 2 with asset expressions + schedule=(dag1_asset | dag2_asset), + ..., + ): + ... + +**Complex Conditional Logic** + +For scenarios requiring more intricate conditions, such as triggering a DAG when one asset is updated or when both of two other assets are updated, combine the OR and AND operators: + +.. code-block:: python + + dag3_asset = asset("s3://dag3/output_3.txt") + + with DAG( + # Consume asset 1 or both 2 and 3 with asset expressions + schedule=(dag1_asset | (dag2_asset & dag3_asset)), + ..., + ): + ... + + +Dynamic data events emitting and asset creation through AssetAlias +----------------------------------------------------------------------- +An asset alias can be used to emit asset events of assets with association to the aliases. Downstreams can depend on the resolved assets. This feature allows you to define complex dependencies for DAG executions based on asset updates. + +How to use AssetAlias +~~~~~~~~~~~~~~~~~~~~~~~ + +``AssetAlias`` has one single argument ``name`` that uniquely identifies the asset. The task must first declare the alias as an outlet, and use ``outlet_events`` or yield ``Metadata`` to add events to it. + +The following example creates an asset event against the S3 URI ``f"s3://bucket/my-task"`` with optional extra information ``extra``. If the asset does not exist, Airflow will dynamically create it and log a warning message. + +**Emit an asset event during task execution through outlet_events** + +.. code-block:: python + + from airflow.assets import AssetAlias + + + @task(outlets=[AssetAlias("my-task-outputs")]) + def my_task_with_outlet_events(*, outlet_events): + outlet_events["my-task-outputs"].add(asset("s3://bucket/my-task"), extra={"k": "v"}) + + +**Emit an asset event during task execution through yielding Metadata** + +.. code-block:: python + + from airflow.assets.metadata import Metadata + + + @task(outlets=[AssetAlias("my-task-outputs")]) + def my_task_with_metadata(): + s3_asset = asset("s3://bucket/my-task") + yield Metadata(s3_asset, extra={"k": "v"}, alias="my-task-outputs") + +Only one asset event is emitted for an added asset, even if it is added to the alias multiple times, or added to multiple aliases. However, if different ``extra`` values are passed, it can emit multiple asset events. In the following example, two asset events will be emitted. + +.. code-block:: python + + from airflow.assets import AssetAlias + + + @task( + outlets=[ + AssetAlias("my-task-outputs-1"), + AssetAlias("my-task-outputs-2"), + AssetAlias("my-task-outputs-3"), + ] + ) + def my_task_with_outlet_events(*, outlet_events): + outlet_events["my-task-outputs-1"].add(asset("s3://bucket/my-task"), extra={"k": "v"}) + # This line won't emit an additional asset event as the asset and extra are the same as the previous line. + outlet_events["my-task-outputs-2"].add(asset("s3://bucket/my-task"), extra={"k": "v"}) + # This line will emit an additional asset event as the extra is different.
+ outlet_events["my-task-outputs-3"].add(asset("s3://bucket/my-task"), extra={"k2": "v2"}) + +Scheduling based on asset aliases +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Since asset events added to an alias are just simple asset events, a downstream DAG depending on the actual asset can read asset events of it normally, without considering the associated aliases. A downstream DAG can also depend on an asset alias. The authoring syntax is referencing the ``AssetAlias`` by name, and the associated asset events are picked up for scheduling. Note that a DAG can be triggered by a task with ``outlets=AssetAlias("xxx")`` if and only if the alias is resolved into ``asset("s3://bucket/my-task")``. The DAG runs whenever a task with outlet ``AssetAlias("out")`` gets associated with at least one asset at runtime, regardless of the asset's identity. The downstream DAG is not triggered if no assets are associated to the alias for a particular given task run. This also means we can do conditional asset-triggering. + +The asset alias is resolved to the assets during DAG parsing. Thus, if the "min_file_process_interval" configuration is set to a high value, there is a possibility that the asset alias may not be resolved. To resolve this issue, you can trigger DAG parsing. + +.. code-block:: python + + with DAG(dag_id="asset-producer"): + + @task(outlets=[asset("example-alias")]) + def produce_asset_events(): + pass + + + with DAG(dag_id="asset-alias-producer"): + + @task(outlets=[AssetAlias("example-alias")]) + def produce_asset_events(*, outlet_events): + outlet_events["example-alias"].add(asset("s3://bucket/my-task")) + + + with DAG(dag_id="asset-consumer", schedule=asset("s3://bucket/my-task")): + ... + + with DAG(dag_id="asset-alias-consumer", schedule=AssetAlias("example-alias")): + ... + + +In the example provided, once the DAG ``asset-alias-producer`` is executed, the asset alias ``AssetAlias("example-alias")`` will be resolved to ``asset("s3://bucket/my-task")``. However, the DAG ``asset-alias-consumer`` will have to wait for the next DAG re-parsing to update its schedule. To address this, Airflow will re-parse the DAGs relying on the asset alias ``AssetAlias("example-alias")`` when it's resolved into assets that these DAGs did not previously depend on. As a result, both the "asset-consumer" and "asset-alias-consumer" DAGs will be triggered after the execution of DAG ``asset-alias-producer``. + + +Fetching information from previously emitted asset events through resolved asset aliases +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +As mentioned in :ref:`Fetching information from previously emitted asset events`, inlet asset events can be read with the ``inlet_events`` accessor in the execution context, and you can also use asset aliases to access the asset events triggered by them. + +.. 
code-block:: python + + with DAG(dag_id="asset-alias-producer"): + + @task(outlets=[AssetAlias("example-alias")]) + def produce_asset_events(*, outlet_events): + outlet_events["example-alias"].add(asset("s3://bucket/my-task"), extra={"row_count": 1}) + + + with DAG(dag_id="asset-alias-consumer", schedule=None): + + @task(inlets=[AssetAlias("example-alias")]) + def consume_asset_alias_events(*, inlet_events): + events = inlet_events[AssetAlias("example-alias")] + last_row_count = events[-1].extra["row_count"] + + +Combining asset and time-based schedules +------------------------------------------ + +AssetTimetable Integration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +You can schedule DAGs based on both asset events and time-based schedules using ``AssetOrTimeSchedule``. This allows you to create workflows when a DAG needs both to be triggered by data updates and run periodically according to a fixed timetable. + +For more detailed information on ``AssetOrTimeSchedule``, refer to the corresponding section in :ref:`AssetOrTimeSchedule `. diff --git a/docs/apache-airflow/authoring-and-scheduling/datasets.rst b/docs/apache-airflow/authoring-and-scheduling/datasets.rst deleted file mode 100644 index a69c09bc13b0f..0000000000000 --- a/docs/apache-airflow/authoring-and-scheduling/datasets.rst +++ /dev/null @@ -1,532 +0,0 @@ - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -Data-aware scheduling -===================== - -.. versionadded:: 2.4 - -Quickstart ----------- - -In addition to scheduling DAGs based on time, you can also schedule DAGs to run based on when a task updates a dataset. - -.. code-block:: python - - from airflow.datasets import Dataset - - with DAG(...): - MyOperator( - # this task updates example.csv - outlets=[Dataset("s3://dataset-bucket/example.csv")], - ..., - ) - - - with DAG( - # this DAG should be run when example.csv is updated (by dag1) - schedule=[Dataset("s3://dataset-bucket/example.csv")], - ..., - ): - ... - - -.. image:: /img/dataset-scheduled-dags.png - - -What is a "dataset"? --------------------- - -An Airflow dataset is a logical grouping of data. Upstream producer tasks can update datasets, and dataset updates contribute to scheduling downstream consumer DAGs. - -`Uniform Resource Identifier (URI) `_ define datasets: - -.. code-block:: python - - from airflow.datasets import Dataset - - example_dataset = Dataset("s3://dataset-bucket/example.csv") - -Airflow makes no assumptions about the content or location of the data represented by the URI, and treats the URI like a string. This means that Airflow treats any regular expressions, like ``input_\d+.csv``, or file glob patterns, such as ``input_2022*.csv``, as an attempt to create multiple datasets from one declaration, and they will not work. - -You must create datasets with a valid URI. 
Airflow core and providers define various URI schemes that you can use, such as ``file`` (core), ``postgres`` (by the Postgres provider), and ``s3`` (by the Amazon provider). Third-party providers and plugins might also provide their own schemes. These pre-defined schemes have individual semantics that are expected to be followed. - -What is valid URI? ------------------- - -Technically, the URI must conform to the valid character set in RFC 3986, which is basically ASCII alphanumeric characters, plus ``%``, ``-``, ``_``, ``.``, and ``~``. To identify a resource that cannot be represented by URI-safe characters, encode the resource name with `percent-encoding `_. - -The URI is also case sensitive, so ``s3://example/dataset`` and ``s3://Example/Dataset`` are considered different. Note that the *host* part of the URI is also case sensitive, which differs from RFC 3986. - -Do not use the ``airflow`` scheme, which is is reserved for Airflow's internals. - -Airflow always prefers using lower cases in schemes, and case sensitivity is needed in the host part of the URI to correctly distinguish between resources. - -.. code-block:: python - - # invalid datasets: - reserved = Dataset("airflow://example_dataset") - not_ascii = Dataset("èxample_datašet") - -If you want to define datasets with a scheme that doesn't include additional semantic constraints, use a scheme with the prefix ``x-``. Airflow skips any semantic validation on URIs with these schemes. - -.. code-block:: python - - # valid dataset, treated as a plain string - my_ds = Dataset("x-my-thing://foobarbaz") - -The identifier does not have to be absolute; it can be a scheme-less, relative URI, or even just a simple path or string: - -.. code-block:: python - - # valid datasets: - schemeless = Dataset("//example/dataset") - csv_file = Dataset("example_dataset") - -Non-absolute identifiers are considered plain strings that do not carry any semantic meanings to Airflow. - -Extra information on dataset ----------------------------- - -If needed, you can include an extra dictionary in a dataset: - -.. code-block:: python - - example_dataset = Dataset( - "s3://dataset/example.csv", - extra={"team": "trainees"}, - ) - -This can be used to supply custom description to the dataset, such as who has ownership to the target file, or what the file is for. The extra information does not affect a dataset's identity. This means a DAG will be triggered by a dataset with an identical URI, even if the extra dict is different: - -.. code-block:: python - - with DAG( - dag_id="consumer", - schedule=[Dataset("s3://dataset/example.csv", extra={"different": "extras"})], - ): - ... - - with DAG(dag_id="producer", ...): - MyOperator( - # triggers "consumer" with the given extra! - outlets=[Dataset("s3://dataset/example.csv", extra={"team": "trainees"})], - ..., - ) - -.. note:: **Security Note:** Dataset URI and extra fields are not encrypted, they are stored in cleartext in Airflow's metadata database. Do NOT store any sensitive values, especially credentials, in either dataset URIs or extra key values! - -How to use datasets in your DAGs --------------------------------- - -You can use datasets to specify data dependencies in your DAGs. The following example shows how after the ``producer`` task in the ``producer`` DAG successfully completes, Airflow schedules the ``consumer`` DAG. Airflow marks a dataset as ``updated`` only if the task completes successfully. 
If the task fails or if it is skipped, no update occurs, and Airflow doesn't schedule the ``consumer`` DAG. - -.. code-block:: python - - example_dataset = Dataset("s3://dataset/example.csv") - - with DAG(dag_id="producer", ...): - BashOperator(task_id="producer", outlets=[example_dataset], ...) - - with DAG(dag_id="consumer", schedule=[example_dataset], ...): - ... - - -You can find a listing of the relationships between datasets and DAGs in the -:ref:`Datasets View` - -Multiple Datasets ------------------ - -Because the ``schedule`` parameter is a list, DAGs can require multiple datasets. Airflow schedules a DAG after **all** datasets the DAG consumes have been updated at least once since the last time the DAG ran: - -.. code-block:: python - - with DAG( - dag_id="multiple_datasets_example", - schedule=[ - example_dataset_1, - example_dataset_2, - example_dataset_3, - ], - ..., - ): - ... - - -If one dataset is updated multiple times before all consumed datasets update, the downstream DAG still only runs once, as shown in this illustration: - -.. :: - ASCII art representation of this diagram - - example_dataset_1 x----x---x---x----------------------x- - example_dataset_2 -------x---x-------x------x----x------ - example_dataset_3 ---------------x-----x------x--------- - DAG runs created * * - -.. graphviz:: - - graph dataset_event_timeline { - graph [layout=neato] - { - node [margin=0 fontcolor=blue width=0.1 shape=point label=""] - e1 [pos="1,2.5!"] - e2 [pos="2,2.5!"] - e3 [pos="2.5,2!"] - e4 [pos="4,2.5!"] - e5 [pos="5,2!"] - e6 [pos="6,2.5!"] - e7 [pos="7,1.5!"] - r7 [pos="7,1!" shape=star width=0.25 height=0.25 fixedsize=shape] - e8 [pos="8,2!"] - e9 [pos="9,1.5!"] - e10 [pos="10,2!"] - e11 [pos="11,1.5!"] - e12 [pos="12,2!"] - e13 [pos="13,2.5!"] - r13 [pos="13,1!" shape=star width=0.25 height=0.25 fixedsize=shape] - } - { - node [shape=none label="" width=0] - end_ds1 [pos="14,2.5!"] - end_ds2 [pos="14,2!"] - end_ds3 [pos="14,1.5!"] - } - - { - node [shape=none margin=0.25 fontname="roboto,sans-serif"] - example_dataset_1 [ pos="-0.5,2.5!"] - example_dataset_2 [ pos="-0.5,2!"] - example_dataset_3 [ pos="-0.5,1.5!"] - dag_runs [label="DagRuns created" pos="-0.5,1!"] - } - - edge [color=lightgrey] - - example_dataset_1 -- e1 -- e2 -- e4 -- e6 -- e13 -- end_ds1 - example_dataset_2 -- e3 -- e5 -- e8 -- e10 -- e12 -- end_ds2 - example_dataset_3 -- e7 -- e9 -- e11 -- end_ds3 - - } - -Attaching extra information to an emitting dataset event --------------------------------------------------------- - -.. versionadded:: 2.10.0 - -A task with a dataset outlet can optionally attach extra information before it emits a dataset event. This is different -from `Extra information on dataset`_. Extra information on a dataset statically describes the entity pointed to by the dataset URI; extra information on the *dataset event* instead should be used to annotate the triggering data change, such as how many rows in the database are changed by the update, or the date range covered by it. - -The easiest way to attach extra information to the dataset event is by ``yield``-ing a ``Metadata`` object from a task: - -.. code-block:: python - - from airflow.datasets import Dataset - from airflow.datasets.metadata import Metadata - - example_s3_dataset = Dataset("s3://dataset/example.csv") - - - @task(outlets=[example_s3_dataset]) - def write_to_s3(): - df = ... # Get a Pandas DataFrame to write. - # Write df to dataset... 
- yield Metadata(example_s3_dataset, {"row_count": len(df)}) - -Airflow automatically collects all yielded metadata, and populates dataset events with extra information for corresponding metadata objects. - -This can also be done in classic operators. The best way is to subclass the operator and override ``execute``. Alternatively, extras can also be added in a task's ``pre_execute`` or ``post_execute`` hook. If you choose to use hooks, however, remember that they are not rerun when a task is retried, and may cause the extra information to not match actual data in certain scenarios. - -Another way to achieve the same is by accessing ``outlet_events`` in a task's execution context directly: - -.. code-block:: python - - @task(outlets=[example_s3_dataset]) - def write_to_s3(*, outlet_events): - outlet_events[example_s3_dataset].extra = {"row_count": len(df)} - -There's minimal magic here---Airflow simply writes the yielded values to the exact same accessor. This also works in classic operators, including ``execute``, ``pre_execute``, and ``post_execute``. - -.. _fetching_information_from_previously_emitted_dataset_events: - -Fetching information from previously emitted dataset events ------------------------------------------------------------ - -.. versionadded:: 2.10.0 - -Events of a dataset defined in a task's ``outlets``, as described in the previous section, can be read by a task that declares the same dataset in its ``inlets``. A dataset event entry contains ``extra`` (see previous section for details), ``timestamp`` indicating when the event was emitted from a task, and ``source_task_instance`` linking the event back to its source. - -Inlet dataset events can be read with the ``inlet_events`` accessor in the execution context. Continuing from the ``write_to_s3`` task in the previous section: - -.. code-block:: python - - @task(inlets=[example_s3_dataset]) - def post_process_s3_file(*, inlet_events): - events = inlet_events[example_s3_dataset] - last_row_count = events[-1].extra["row_count"] - -Each value in the ``inlet_events`` mapping is a sequence-like object that orders past events of a given dataset by ``timestamp``, earliest to latest. It supports most of Python's list interface, so you can use ``[-1]`` to access the last event, ``[-2:]`` for the last two, etc. The accessor is lazy and only hits the database when you access items inside it. - - -Fetching information from a triggering dataset event ----------------------------------------------------- - -A triggered DAG can fetch information from the dataset that triggered it using the ``triggering_dataset_events`` template or parameter. See more at :ref:`templates-ref`. - -Example: - -.. code-block:: python - - example_snowflake_dataset = Dataset("snowflake://my_db/my_schema/my_table") - - with DAG(dag_id="load_snowflake_data", schedule="@hourly", ...): - SQLExecuteQueryOperator( - task_id="load", conn_id="snowflake_default", outlets=[example_snowflake_dataset], ... 
- ) - - with DAG(dag_id="query_snowflake_data", schedule=[example_snowflake_dataset], ...): - SQLExecuteQueryOperator( - task_id="query", - conn_id="snowflake_default", - sql=""" - SELECT * - FROM my_db.my_schema.my_table - WHERE "updated_at" >= '{{ (triggering_dataset_events.values() | first | first).source_dag_run.data_interval_start }}' - AND "updated_at" < '{{ (triggering_dataset_events.values() | first | first).source_dag_run.data_interval_end }}'; - """, - ) - - @task - def print_triggering_dataset_events(triggering_dataset_events=None): - for dataset, dataset_list in triggering_dataset_events.items(): - print(dataset, dataset_list) - print(dataset_list[0].source_dag_run.dag_id) - - print_triggering_dataset_events() - -Note that this example is using `(.values() | first | first) `_ to fetch the first of one dataset given to the DAG, and the first of one DatasetEvent for that dataset. An implementation can be quite complex if you have multiple datasets, potentially with multiple DatasetEvents. - - -Manipulating queued dataset events through REST API ---------------------------------------------------- - -.. versionadded:: 2.9 - -In this example, the DAG ``waiting_for_dataset_1_and_2`` will be triggered when tasks update both datasets "dataset-1" and "dataset-2". Once "dataset-1" is updated, Airflow creates a record. This ensures that Airflow knows to trigger the DAG when "dataset-2" is updated. We call such records queued dataset events. - -.. code-block:: python - - with DAG( - dag_id="waiting_for_dataset_1_and_2", - schedule=[Dataset("dataset-1"), Dataset("dataset-2")], - ..., - ): - ... - - -``queuedEvent`` API endpoints are introduced to manipulate such records. - -* Get a queued Dataset event for a DAG: ``/datasets/queuedEvent/{uri}`` -* Get queued Dataset events for a DAG: ``/dags/{dag_id}/datasets/queuedEvent`` -* Delete a queued Dataset event for a DAG: ``/datasets/queuedEvent/{uri}`` -* Delete queued Dataset events for a DAG: ``/dags/{dag_id}/datasets/queuedEvent`` -* Get queued Dataset events for a Dataset: ``/dags/{dag_id}/datasets/queuedEvent/{uri}`` -* Delete queued Dataset events for a Dataset: ``DELETE /dags/{dag_id}/datasets/queuedEvent/{uri}`` - - For how to use REST API and the parameters needed for these endpoints, please refer to :doc:`Airflow API `. - -Advanced dataset scheduling with conditional expressions --------------------------------------------------------- - -Apache Airflow includes advanced scheduling capabilities that use conditional expressions with datasets. This feature allows you to define complex dependencies for DAG executions based on dataset updates, using logical operators for more control on workflow triggers. - -Logical operators for datasets -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Airflow supports two logical operators for combining dataset conditions: - -- **AND (``&``)**: Specifies that the DAG should be triggered only after all of the specified datasets have been updated. -- **OR (``|``)**: Specifies that the DAG should be triggered when any of the specified datasets is updated. - -These operators enable you to configure your Airflow workflows to use more complex dataset update conditions, making them more dynamic and flexible. - -Example Use -------------- - -**Scheduling based on multiple dataset updates** - -To schedule a DAG to run only when two specific datasets have both been updated, use the AND operator (``&``): - -.. 
code-block:: python - - dag1_dataset = Dataset("s3://dag1/output_1.txt") - dag2_dataset = Dataset("s3://dag2/output_1.txt") - - with DAG( - # Consume dataset 1 and 2 with dataset expressions - schedule=(dag1_dataset & dag2_dataset), - ..., - ): - ... - -**Scheduling based on any dataset update** - -To trigger a DAG execution when either one of two datasets is updated, apply the OR operator (``|``): - -.. code-block:: python - - with DAG( - # Consume dataset 1 or 2 with dataset expressions - schedule=(dag1_dataset | dag2_dataset), - ..., - ): - ... - -**Complex Conditional Logic** - -For scenarios requiring more intricate conditions, such as triggering a DAG when one dataset is updated or when both of two other datasets are updated, combine the OR and AND operators: - -.. code-block:: python - - dag3_dataset = Dataset("s3://dag3/output_3.txt") - - with DAG( - # Consume dataset 1 or both 2 and 3 with dataset expressions - schedule=(dag1_dataset | (dag2_dataset & dag3_dataset)), - ..., - ): - ... - - -Dynamic data events emitting and dataset creation through DatasetAlias ------------------------------------------------------------------------ -A dataset alias can be used to emit dataset events of datasets with association to the aliases. Downstreams can depend on resolved dataset. This feature allows you to define complex dependencies for DAG executions based on dataset updates. - -How to use DatasetAlias -~~~~~~~~~~~~~~~~~~~~~~~ - -``DatasetAlias`` has one single argument ``name`` that uniquely identifies the dataset. The task must first declare the alias as an outlet, and use ``outlet_events`` or yield ``Metadata`` to add events to it. - -The following example creates a dataset event against the S3 URI ``f"s3://bucket/my-task"`` with optional extra information ``extra``. If the dataset does not exist, Airflow will dynamically create it and log a warning message. - -**Emit a dataset event during task execution through outlet_events** - -.. code-block:: python - - from airflow.datasets import DatasetAlias - - - @task(outlets=[DatasetAlias("my-task-outputs")]) - def my_task_with_outlet_events(*, outlet_events): - outlet_events["my-task-outputs"].add(Dataset("s3://bucket/my-task"), extra={"k": "v"}) - - -**Emit a dataset event during task execution through yielding Metadata** - -.. code-block:: python - - from airflow.datasets.metadata import Metadata - - - @task(outlets=[DatasetAlias("my-task-outputs")]) - def my_task_with_metadata(): - s3_dataset = Dataset("s3://bucket/my-task") - yield Metadata(s3_dataset, extra={"k": "v"}, alias="my-task-outputs") - -Only one dataset event is emitted for an added dataset, even if it is added to the alias multiple times, or added to multiple aliases. However, if different ``extra`` values are passed, it can emit multiple dataset events. In the following example, two dataset events will be emitted. - -.. code-block:: python - - from airflow.datasets import DatasetAlias - - - @task( - outlets=[ - DatasetAlias("my-task-outputs-1"), - DatasetAlias("my-task-outputs-2"), - DatasetAlias("my-task-outputs-3"), - ] - ) - def my_task_with_outlet_events(*, outlet_events): - outlet_events["my-task-outputs-1"].add(Dataset("s3://bucket/my-task"), extra={"k": "v"}) - # This line won't emit an additional dataset event as the dataset and extra are the same as the previous line. - outlet_events["my-task-outputs-2"].add(Dataset("s3://bucket/my-task"), extra={"k": "v"}) - # This line will emit an additional dataset event as the extra is different. 
- outlet_events["my-task-outputs-3"].add(Dataset("s3://bucket/my-task"), extra={"k2": "v2"}) - -Scheduling based on dataset aliases -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Since dataset events added to an alias are just simple dataset events, a downstream DAG depending on the actual dataset can read dataset events of it normally, without considering the associated aliases. A downstream DAG can also depend on a dataset alias. The authoring syntax is referencing the ``DatasetAlias`` by name, and the associated dataset events are picked up for scheduling. Note that a DAG can be triggered by a task with ``outlets=DatasetAlias("xxx")`` if and only if the alias is resolved into ``Dataset("s3://bucket/my-task")``. The DAG runs whenever a task with outlet ``DatasetAlias("out")`` gets associated with at least one dataset at runtime, regardless of the dataset's identity. The downstream DAG is not triggered if no datasets are associated to the alias for a particular given task run. This also means we can do conditional dataset-triggering. - -The dataset alias is resolved to the datasets during DAG parsing. Thus, if the "min_file_process_interval" configuration is set to a high value, there is a possibility that the dataset alias may not be resolved. To resolve this issue, you can trigger DAG parsing. - -.. code-block:: python - - with DAG(dag_id="dataset-producer"): - - @task(outlets=[Dataset("example-alias")]) - def produce_dataset_events(): - pass - - - with DAG(dag_id="dataset-alias-producer"): - - @task(outlets=[DatasetAlias("example-alias")]) - def produce_dataset_events(*, outlet_events): - outlet_events["example-alias"].add(Dataset("s3://bucket/my-task")) - - - with DAG(dag_id="dataset-consumer", schedule=Dataset("s3://bucket/my-task")): - ... - - with DAG(dag_id="dataset-alias-consumer", schedule=DatasetAlias("example-alias")): - ... - - -In the example provided, once the DAG ``dataset-alias-producer`` is executed, the dataset alias ``DatasetAlias("example-alias")`` will be resolved to ``Dataset("s3://bucket/my-task")``. However, the DAG ``dataset-alias-consumer`` will have to wait for the next DAG re-parsing to update its schedule. To address this, Airflow will re-parse the DAGs relying on the dataset alias ``DatasetAlias("example-alias")`` when it's resolved into datasets that these DAGs did not previously depend on. As a result, both the "dataset-consumer" and "dataset-alias-consumer" DAGs will be triggered after the execution of DAG ``dataset-alias-producer``. - - -Fetching information from previously emitted dataset events through resolved dataset aliases -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -As mentioned in :ref:`Fetching information from previously emitted dataset events`, inlet dataset events can be read with the ``inlet_events`` accessor in the execution context, and you can also use dataset aliases to access the dataset events triggered by them. - -.. 
code-block:: python - - with DAG(dag_id="dataset-alias-producer"): - - @task(outlets=[DatasetAlias("example-alias")]) - def produce_dataset_events(*, outlet_events): - outlet_events["example-alias"].add(Dataset("s3://bucket/my-task"), extra={"row_count": 1}) - - - with DAG(dag_id="dataset-alias-consumer", schedule=None): - - @task(inlets=[DatasetAlias("example-alias")]) - def consume_dataset_alias_events(*, inlet_events): - events = inlet_events[DatasetAlias("example-alias")] - last_row_count = events[-1].extra["row_count"] - - -Combining dataset and time-based schedules ------------------------------------------- - -DatasetTimetable Integration -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -You can schedule DAGs based on both dataset events and time-based schedules using ``DatasetOrTimeSchedule``. This allows you to create workflows when a DAG needs both to be triggered by data updates and run periodically according to a fixed timetable. - -For more detailed information on ``DatasetOrTimeSchedule``, refer to the corresponding section in :ref:`DatasetOrTimeSchedule `. diff --git a/docs/apache-airflow/authoring-and-scheduling/index.rst b/docs/apache-airflow/authoring-and-scheduling/index.rst index 1a042918fc1ba..5ec94d6ca7301 100644 --- a/docs/apache-airflow/authoring-and-scheduling/index.rst +++ b/docs/apache-airflow/authoring-and-scheduling/index.rst @@ -41,5 +41,5 @@ It's recommended that you first review the pages in :doc:`core concepts dict: + def retrieve(src: Asset) -> dict: resp = requests.get(url=src.uri) data = resp.json() return data["data"] @@ -137,14 +137,14 @@ a ``Dataset``, which is ``@attr.define`` decorated, together with TaskFlow. return ret @task() - def load(fahrenheit: dict[int, float]) -> Dataset: + def load(fahrenheit: dict[int, float]) -> Asset: filename = "/tmp/fahrenheit.json" s = json.dumps(fahrenheit) f = open(filename, "w") f.write(s) f.close() - return Dataset(f"file:///{filename}") + return Asset(f"file:///{filename}") data = retrieve(SRC) fahrenheit = to_fahrenheit(data) diff --git a/docs/apache-airflow/img/dataset-scheduled-dags.png b/docs/apache-airflow/img/asset-scheduled-dags.png similarity index 100% rename from docs/apache-airflow/img/dataset-scheduled-dags.png rename to docs/apache-airflow/img/asset-scheduled-dags.png diff --git a/docs/apache-airflow/img/datasets.png b/docs/apache-airflow/img/assets.png similarity index 100% rename from docs/apache-airflow/img/datasets.png rename to docs/apache-airflow/img/assets.png diff --git a/docs/apache-airflow/templates-ref.rst b/docs/apache-airflow/templates-ref.rst index 05d4b10accca5..5524c82f8cc95 100644 --- a/docs/apache-airflow/templates-ref.rst +++ b/docs/apache-airflow/templates-ref.rst @@ -62,10 +62,10 @@ Variable Type Description ``{{ prev_end_date_success }}`` `pendulum.DateTime`_ End date from prior successful :class:`~airflow.models.dagrun.DagRun` (if available). | ``None`` ``{{ inlets }}`` list List of inlets declared on the task. -``{{ inlet_events }}`` dict[str, ...] Access past events of inlet datasets. See :doc:`Datasets `. Added in version 2.10. +``{{ inlet_events }}`` dict[str, ...] Access past events of inlet assets. See :doc:`Assets `. Added in version 2.10. ``{{ outlets }}`` list List of outlets declared on the task. -``{{ outlet_events }}`` dict[str, ...] | Accessors to attach information to dataset events that will be emitted by the current task. - | See :doc:`Datasets `. Added in version 2.10. +``{{ outlet_events }}`` dict[str, ...] 
| Accessors to attach information to asset events that will be emitted by the current task. + | See :doc:`Assets `. Added in version 2.10. ``{{ dag }}`` DAG The currently running :class:`~airflow.models.dag.DAG`. You can read more about DAGs in :doc:`DAGs `. ``{{ task }}`` BaseOperator | The currently running :class:`~airflow.models.baseoperator.BaseOperator`. You can read more about Tasks in :doc:`core-concepts/operators` ``{{ macros }}`` | A reference to the macros package. See Macros_ below. @@ -88,9 +88,9 @@ Variable Type Description ``{{ expanded_ti_count }}`` int | ``None`` | Number of task instances that a mapped task was expanded into. If | the current task is not mapped, this should be ``None``. | Added in version 2.5. -``{{ triggering_dataset_events }}`` dict[str, | If in a Dataset Scheduled DAG, a map of Dataset URI to a list of triggering :class:`~airflow.models.dataset.DatasetEvent` - list[DatasetEvent]] | (there may be more than one, if there are multiple Datasets with different frequencies). - | Read more here :doc:`Datasets `. +``{{ triggering_asset_events }}`` dict[str, | If in a Asset Scheduled DAG, a map of Asset URI to a list of triggering :class:`~airflow.models.asset.AssetEvent` + list[AssetEvent]] | (there may be more than one, if there are multiple Assets with different frequencies). + | Read more here :doc:`Assets `. | Added in version 2.4. =========================================== ===================== =================================================================== diff --git a/docs/apache-airflow/tutorial/objectstorage.rst b/docs/apache-airflow/tutorial/objectstorage.rst index 943e8031a7e58..39e42ddd76627 100644 --- a/docs/apache-airflow/tutorial/objectstorage.rst +++ b/docs/apache-airflow/tutorial/objectstorage.rst @@ -65,7 +65,7 @@ The connection ID can alternatively be passed in with a keyword argument: ObjectStoragePath("s3://airflow-tutorial-data/", conn_id="aws_default") -This is useful when reusing a URL defined for another purpose (e.g. Dataset), +This is useful when reusing a URL defined for another purpose (e.g. Asset), which generally does not contain a username part. The explicit keyword argument takes precedence over the URL's username value if both are specified. diff --git a/docs/apache-airflow/ui.rst b/docs/apache-airflow/ui.rst index 4238cb6ba427c..05d71c9a96176 100644 --- a/docs/apache-airflow/ui.rst +++ b/docs/apache-airflow/ui.rst @@ -61,7 +61,7 @@ Native Airflow dashboard page into the UI to collect several useful metrics for ------------ -.. _ui:datasets-view: +.. _ui:assets-view: Datasets View ............. @@ -72,7 +72,7 @@ Clicking on any dataset in either the list or the graph will highlight it and it ------------ -.. image:: img/datasets.png +.. 
image:: img/assets.png ------------ diff --git a/docs/exts/operators_and_hooks_ref.py b/docs/exts/operators_and_hooks_ref.py index 43f954ebb0c37..fe6cd5d3300d2 100644 --- a/docs/exts/operators_and_hooks_ref.py +++ b/docs/exts/operators_and_hooks_ref.py @@ -519,8 +519,8 @@ class DatasetSchemeDirective(BaseJinjaReferenceDirective): def render_content(self, *, tags: set[str] | None, header_separator: str = DEFAULT_HEADER_SEPARATOR): return _common_render_list_content( header_separator=header_separator, - resource_type="dataset-uris", - template="dataset-uri-schemes.rst.jinja2", + resource_type="asset-uris", + template="asset-uri-schemes.rst.jinja2", ) @@ -538,7 +538,7 @@ def setup(app): app.add_directive("airflow-executors", ExecutorsDirective) app.add_directive("airflow-deferrable-operators", DeferrableOperatorDirective) app.add_directive("airflow-deprecations", DeprecationsDirective) - app.add_directive("airflow-dataset-schemes", DatasetSchemeDirective) + app.add_directive("airflow-asset-schemes", DatasetSchemeDirective) return {"parallel_read_safe": True, "parallel_write_safe": True} diff --git a/docs/exts/templates/dataset-uri-schemes.rst.jinja2 b/docs/exts/templates/asset-uri-schemes.rst.jinja2 similarity index 95% rename from docs/exts/templates/dataset-uri-schemes.rst.jinja2 rename to docs/exts/templates/asset-uri-schemes.rst.jinja2 index aa247507becc4..14cdbcf9aab15 100644 --- a/docs/exts/templates/dataset-uri-schemes.rst.jinja2 +++ b/docs/exts/templates/asset-uri-schemes.rst.jinja2 @@ -27,7 +27,7 @@ Core {{ provider['name'] }} {{ header_separator * (provider['name']|length) }} -{% for uri_entry in provider['dataset-uris'] -%} +{% for uri_entry in provider['asset-uris'] -%} - {% for scheme in uri_entry['schemes'] %}``{{ scheme }}``{% if not loop.last %}, {% endif %}{% endfor %} {% endfor %} diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index b834ccc9b2005..eb6e612e0992b 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -88,6 +88,8 @@ asctime asend asia assertEqualIgnoreMultipleSpaces +AssetEvent +AssetEvents assigment ast astroid diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index b9bc363b15e33..2a81933c6c584 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -397,7 +397,9 @@ ], "devel-deps": [], "plugins": [], - "cross-providers-deps": [], + "cross-providers-deps": [ + "openlineage" + ], "excluded-python-versions": [], "state": "ready" }, @@ -563,6 +565,7 @@ }, "fab": { "deps": [ + "apache-airflow-providers-common-compat>=1.2.0", "apache-airflow>=2.9.0", "flask-appbuilder==4.5.0", "flask-login>=0.6.2", @@ -574,7 +577,9 @@ "kerberos>=1.3.0" ], "plugins": [], - "cross-providers-deps": [], + "cross-providers-deps": [ + "common.compat" + ], "excluded-python-versions": [], "state": "ready" }, @@ -967,6 +972,7 @@ } ], "cross-providers-deps": [ + "common.compat", "common.sql" ], "excluded-python-versions": [], diff --git a/newsfragments/41348.significant.rst b/newsfragments/41348.significant.rst new file mode 100644 index 0000000000000..8b5cc54dd40dc --- /dev/null +++ b/newsfragments/41348.significant.rst @@ -0,0 +1,240 @@ +**Breaking Change** + +* Rename module ``airflow.api_connexion.schemas.dataset_schema`` as ``airflow.api_connexion.schemas.asset_schema`` + + * Rename variable ``create_dataset_event_schema`` as ``create_asset_event_schema`` + * Rename variable ``dataset_collection_schema`` as ``asset_collection_schema`` + * Rename variable 
``dataset_event_collection_schema`` as ``asset_event_collection_schema`` + * Rename variable ``dataset_event_schema`` as ``asset_event_schema`` + * Rename variable ``dataset_schema`` as ``asset_schema`` + * Rename class ``TaskOutletDatasetReferenceSchema`` as ``TaskOutletAssetReferenceSchema`` + * Rename class ``DagScheduleDatasetReferenceSchema`` as ``DagScheduleAssetReferenceSchema`` + * Rename class ``DatasetAliasSchema`` as ``AssetAliasSchema`` + * Rename class ``DatasetSchema`` as ``AssetSchema`` + * Rename class ``DatasetCollection`` as ``AssetCollection`` + * Rename class ``DatasetEventSchema`` as ``AssetEventSchema`` + * Rename class ``DatasetEventCollection`` as ``AssetEventCollection`` + * Rename class ``DatasetEventCollectionSchema`` as ``AssetEventCollectionSchema`` + * Rename class ``CreateDatasetEventSchema`` as ``CreateAssetEventSchema`` + +* Rename module ``airflow.datasets`` as ``airflow.assets`` + + * Rename class ``DatasetAlias`` as ``AssetAlias`` + * Rename class ``DatasetAll`` as ``AssetAll`` + * Rename class ``DatasetAny`` as ``AssetAny`` + * Rename function ``expand_alias_to_datasets`` as ``expand_alias_to_assets`` + * Rename class ``DatasetAliasEvent`` as ``AssetAliasEvent`` + + * Rename method ``dest_dataset_uri`` as ``dest_asset_uri`` + + * Rename class ``BaseDataset`` as ``BaseAsset`` + + * Rename method ``iter_datasets`` as ``iter_assets`` + * Rename method ``iter_dataset_aliases`` as ``iter_asset_aliases`` + + * Rename class ``Dataset`` as ``Asset`` + + * Rename method ``iter_datasets`` as ``iter_assets`` + * Rename method ``iter_dataset_aliases`` as ``iter_asset_aliases`` + + * Rename class ``_DatasetBooleanCondition`` as ``_AssetBooleanCondition`` + + * Rename method ``iter_datasets`` as ``iter_assets`` + * Rename method ``iter_dataset_aliases`` as ``iter_asset_aliases`` + +* Rename module ``airflow.datasets.manager`` as ``airflow.assets.manager`` + + * Rename variable ``dataset_manager`` as ``asset_manager`` + * Rename function ``resolve_dataset_manager`` as ``resolve_asset_manager`` + * Rename class ``DatasetManager`` as ``AssetManager`` + + * Rename method ``register_dataset_change`` as ``register_asset_change`` + * Rename method ``create_datasets`` as ``create_assets`` + * Rename method ``notify_dataset_created`` as ``notify_asset_created`` + * Rename method ``notify_dataset_changed`` as ``notify_asset_changed`` + * Rename method ``notify_dataset_alias_created`` as ``notify_asset_alias_created`` + +* Rename module ``airflow.models.dataset`` as ``airflow.models.asset`` + + * Rename class ``DatasetDagRunQueue`` as ``AssetDagRunQueue`` + * Rename class ``DatasetEvent`` as ``AssetEvent`` + * Rename class ``DatasetModel`` as ``AssetModel`` + * Rename class ``DatasetAliasModel`` as ``AssetAliasModel`` + * Rename class ``DagScheduleDatasetReference`` as ``DagScheduleAssetReference`` + * Rename class ``TaskOutletDatasetReference`` as ``TaskOutletAssetReference`` + * Rename class ``DagScheduleDatasetAliasReference`` as ``DagScheduleAssetAliasReference`` + +* Rename module ``airflow.api_ui.views.datasets`` as ``airflow.api_ui.views.assets`` + + * Rename variable ``dataset_router`` as ``asset_router`` + +* Rename module ``airflow.listeners.spec.dataset`` as ``airflow.listeners.spec.asset`` + + * Rename function ``on_dataset_created`` as ``on_asset_created`` + * Rename function ``on_dataset_changed`` as ``on_asset_changed`` + +* Rename module ``airflow.timetables.datasets`` as ``airflow.timetables.assets`` + + * Rename class ``DatasetOrTimeSchedule`` as
``AssetOrTimeSchedule`` + +* Rename module ``airflow.serialization.pydantic.dataset`` as ``airflow.serialization.pydantic.asset`` + + * Rename class ``DagScheduleDatasetReferencePydantic`` as ``DagScheduleAssetReferencePydantic`` + * Rename class ``TaskOutletDatasetReferencePydantic`` as ``TaskOutletAssetReferencePydantic`` + * Rename class ``DatasetPydantic`` as ``AssetPydantic`` + * Rename class ``DatasetEventPydantic`` as ``AssetEventPydantic`` + +* Rename module ``airflow.datasets.metadata`` as ``airflow.assets.metadata`` + +* In module ``airflow.jobs.scheduler_job_runner`` + + * and its class ``SchedulerJobRunner`` + + * Rename method ``_create_dag_runs_dataset_triggered`` as ``_create_dag_runs_asset_triggered`` + * Rename method ``_orphan_unreferenced_datasets`` as ``_orphan_unreferenced_assets`` + +* In module ``airflow.api_connexion.security`` + + * Rename decorator ``requires_access_dataset`` as ``requires_access_asset`` + +* In module ``airflow.auth.managers.models.resource_details`` + + * Rename class ``DatasetDetails`` as ``AssetDetails`` + +* In module ``airflow.auth.managers.base_auth_manager`` + + * Rename function ``is_authorized_dataset`` as ``is_authorized_asset`` + +* In module ``airflow.timetables.simple`` + + * Rename class ``DatasetTriggeredTimetable`` as ``AssetTriggeredTimetable`` + +* In module ``airflow.lineage.hook`` + + * Rename class ``DatasetLineageInfo`` as ``AssetLineageInfo`` + + * Rename attribute ``dataset`` as ``asset`` + + * In its class ``HookLineageCollector`` + + * Rename method ``create_dataset`` as ``create_asset`` + * Rename method ``add_input_dataset`` as ``add_input_asset`` + * Rename method ``add_output_dataset`` as ``add_output_asset`` + * Rename method ``collected_datasets`` as ``collected_assets`` + +* In module ``airflow.models.dag`` + + * Rename function ``get_dataset_triggered_next_run_info`` as ``get_asset_triggered_next_run_info`` + + * In its class ``DagModel`` + + * Rename method ``get_dataset_triggered_next_run_info`` as ``get_asset_triggered_next_run_info`` + +* In module ``airflow.models.taskinstance`` + + * and its class ``TaskInstance`` + + * Rename method ``_register_dataset_changes`` as ``_register_asset_changes`` + +* In module ``airflow.providers_manager`` + + * and its class ``ProvidersManager`` + + * Rename method ``initialize_providers_dataset_uri_resources`` as ``initialize_providers_asset_uri_resources`` + * Rename attribute ``_discover_dataset_uri_resources`` as ``_discover_asset_uri_resources`` + * Rename property ``dataset_factories`` as ``asset_factories`` + * Rename property ``dataset_uri_handlers`` as ``asset_uri_handlers`` + * Rename property ``dataset_to_openlineage_converters`` as ``asset_to_openlineage_converters`` + +* In module ``airflow.security.permissions`` + + * Rename constant ``RESOURCE_DATASET`` as ``RESOURCE_ASSET`` + +* In module ``airflow.serialization.enums`` + + * and its class DagAttributeTypes + + * Rename attribute ``DATASET_EVENT_ACCESSORS`` as ``ASSET_EVENT_ACCESSORS`` + * Rename attribute ``DATASET_EVENT_ACCESSOR`` as ``ASSET_EVENT_ACCESSOR`` + * Rename attribute ``DATASET`` as ``ASSET`` + * Rename attribute ``DATASET_ALIAS`` as ``ASSET_ALIAS`` + * Rename attribute ``DATASET_ANY`` as ``ASSET_ANY`` + * Rename attribute ``DATASET_ALL`` as ``ASSET_ALL`` + +* In module ``airflow.serialization.pydantic.taskinstance`` + + * and its class ``TaskInstancePydantic`` + + * Rename method ``_register_dataset_changes`` as ``_register_asset_changes`` + +* In module
``airflow.serialization.serialized_objects`` + + * Rename function ``encode_dataset_condition`` as ``encode_asset_condition`` + * Rename function ``decode_dataset_condition`` as ``decode_asset_condition`` + +* In module ``airflow.timetables.base`` + + * Rename class ``_NullDataset`` as ``_NullAsset`` + + * Rename method ``iter_datasets`` as ``iter_assets`` + * Rename method ``iter_dataset_aliases`` as ``iter_asset_aliases`` + +* In module ``airflow.utils.context`` + + * Rename class ``LazyDatasetEventSelectSequence`` as ``LazyAssetEventSelectSequence`` + +* In module ``airflow.www.auth`` + + * Rename function ``has_access_dataset`` as ``has_access_asset`` + +* Rename configuration ``core.strict_dataset_uri_validation`` as ``core.strict_asset_uri_validation`` and ``core.dataset_manager_class`` as ``core.asset_manager_class`` +* Rename example dags ``example_dataset_alias.py``, ``example_dataset_alias_with_no_taskflow.py``, ``example_datasets.py`` as ``example_asset_alias.py``, ``example_asset_alias_with_no_taskflow.py``, ``example_assets.py`` +* Rename DagDependency name ``dataset-alias``, ``dataset`` as ``asset-alias``, ``asset`` +* Rename context key ``triggering_dataset_events`` as ``triggering_asset_events`` +* Rename resource key ``dataset-uris`` as ``asset-uris`` for providers amazon, common.io, mysql, fab, postgres, trino + +* In provider ``airflow.providers.amazon.aws`` + + * Rename package ``datasets`` as ``assets`` + + * In its module ``s3`` + + * Rename method ``create_dataset`` as ``create_asset`` + * Rename method ``convert_dataset_to_openlineage`` as ``convert_asset_to_openlineage`` + + * and its module ``auth_manager.avp.entities`` + + * Rename attribute ``AvpEntities.DATASET`` as ``AvpEntities.ASSET`` + + * and its module ``auth_manager.aws_auth_manager`` + + * Rename function ``is_authorized_dataset`` as ``is_authorized_asset`` + +* In provider ``airflow.providers.common.io`` + + * Rename package ``datasets`` as ``assets`` + + * in its module ``file`` + + * Rename method ``create_dataset`` as ``create_asset`` + * Rename method ``convert_dataset_to_openlineage`` as ``convert_asset_to_openlineage`` + +* In provider ``airflow.providers.fab`` + + * in its module ``auth_manager.fab_auth_manager`` + + * Rename function ``is_authorized_dataset`` as ``is_authorized_asset`` + +* In provider ``airflow.providers.openlineage`` + + * in its module ``utils.utils`` + + * Rename class ``DatasetInfo`` as ``AssetInfo`` + * Rename function ``translate_airflow_dataset`` as ``translate_airflow_asset`` + +* Rename package ``airflow.providers.postgres.datasets`` as ``airflow.providers.postgres.assets`` +* Rename package ``airflow.providers.mysql.datasets`` as ``airflow.providers.mysql.assets`` +* Rename package ``airflow.providers.trino.datasets`` as ``airflow.providers.trino.assets`` +* Add module ``airflow.providers.common.compat.assets`` +* Add module ``airflow.providers.common.compat.openlineage.utils.utils`` +* Add module ``airflow.providers.common.compat.security.permissions`` diff --git a/scripts/ci/pre_commit/check_tests_in_right_folders.py b/scripts/ci/pre_commit/check_tests_in_right_folders.py index 8260b6ad0d578..11d44efd407a7 100755 --- a/scripts/ci/pre_commit/check_tests_in_right_folders.py +++ b/scripts/ci/pre_commit/check_tests_in_right_folders.py @@ -34,6 +34,7 @@ "api_connexion", "api_internal", "api_fastapi", + "assets", "auth", "callbacks", "charts", @@ -45,7 +46,6 @@ "dags", "dags_corrupted",
"dags_with_system_exit", - "datasets", "decorators", "executors", "hooks", diff --git a/scripts/cov/core_coverage.py b/scripts/cov/core_coverage.py index 0facd4bb1c5d7..2d8ac091c6e0d 100644 --- a/scripts/cov/core_coverage.py +++ b/scripts/cov/core_coverage.py @@ -47,6 +47,7 @@ "airflow/jobs/triggerer_job_runner.py", # models "airflow/models/abstractoperator.py", + "airflow/models/asset.py", "airflow/models/base.py", "airflow/models/baseoperator.py", "airflow/models/connection.py", @@ -57,7 +58,6 @@ "airflow/models/dagpickle.py", "airflow/models/dagrun.py", "airflow/models/dagwarning.py", - "airflow/models/dataset.py", "airflow/models/expandinput.py", "airflow/models/log.py", "airflow/models/mappedoperator.py", diff --git a/scripts/cov/other_coverage.py b/scripts/cov/other_coverage.py index 6543d2fc780e0..dae7733ec5c15 100644 --- a/scripts/cov/other_coverage.py +++ b/scripts/cov/other_coverage.py @@ -37,7 +37,7 @@ "airflow/callbacks", "airflow/config_templates", "airflow/dag_processing", - "airflow/datasets", + "airflow/assets", "airflow/decorators", "airflow/hooks", "airflow/io", @@ -79,7 +79,7 @@ "tests/cluster_policies", "tests/config_templates", "tests/dag_processing", - "tests/datasets", + "tests/assets", "tests/decorators", "tests/hooks", "tests/io", diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py index c387f6173ca2f..b27729a68a261 100644 --- a/tests/always/test_project_structure.py +++ b/tests/always/test_project_structure.py @@ -75,6 +75,7 @@ def test_providers_modules_should_have_tests(self): "tests/providers/amazon/aws/triggers/test_step_function.py", "tests/providers/amazon/aws/utils/test_rds.py", "tests/providers/amazon/aws/utils/test_sagemaker.py", + "tests/providers/amazon/aws/utils/test_asset_compat_lineage_collector.py", "tests/providers/amazon/aws/waiters/test_base_waiter.py", "tests/providers/apache/cassandra/hooks/test_cassandra.py", "tests/providers/apache/drill/operators/test_drill.py", @@ -150,6 +151,7 @@ def test_providers_modules_should_have_tests(self): "tests/providers/google/test_go_module_utils.py", "tests/providers/microsoft/azure/operators/test_adls.py", "tests/providers/microsoft/azure/transfers/test_azure_blob_to_gcs.py", + "tests/providers/openlineage/utils/test_asset_compat_lineage_collector.py", "tests/providers/slack/notifications/test_slack_notifier.py", "tests/providers/snowflake/triggers/test_snowflake_trigger.py", "tests/providers/yandex/hooks/test_yandexcloud_dataproc.py", diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index deb5fe0af2daa..f3921da7b9c29 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -24,10 +24,10 @@ import time_machine from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP -from airflow.datasets import Dataset +from airflow.assets import Asset +from airflow.models.asset import AssetEvent, AssetModel from airflow.models.dag import DAG, DagModel from airflow.models.dagrun import DagRun -from airflow.models.dataset import DatasetEvent, DatasetModel from airflow.models.param import Param from airflow.operators.empty import EmptyOperator from airflow.security import permissions @@ -57,7 +57,7 @@ def configured_app(minimal_app_for_api): role_name="Test", permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET), + (permissions.ACTION_CAN_READ, 
permissions.RESOURCE_ASSET), (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY), (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), @@ -72,7 +72,7 @@ def configured_app(minimal_app_for_api): role_name="TestNoDagRunCreatePermission", permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY), (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), @@ -1911,16 +1911,16 @@ def test_should_respond_404(self): @pytest.mark.need_serialized_dag class TestGetDagRunDatasetTriggerEvents(TestDagRunEndpoint): def test_should_respond_200(self, dag_maker, session): - dataset1 = Dataset(uri="ds1") + asset1 = Asset(uri="ds1") with dag_maker(dag_id="source_dag", start_date=timezone.utcnow(), session=session): - EmptyOperator(task_id="task", outlets=[dataset1]) + EmptyOperator(task_id="task", outlets=[asset1]) dr = dag_maker.create_dagrun() ti = dr.task_instances[0] - ds1_id = session.query(DatasetModel.id).filter_by(uri=dataset1.uri).scalar() - event = DatasetEvent( - dataset_id=ds1_id, + asset1_id = session.query(AssetModel.id).filter_by(uri=asset1.uri).scalar() + event = AssetEvent( + dataset_id=asset1_id, source_task_id=ti.task_id, source_dag_id=ti.dag_id, source_run_id=ti.run_id, @@ -1945,8 +1945,8 @@ def test_should_respond_200(self, dag_maker, session): "dataset_events": [ { "timestamp": event.timestamp.isoformat(), - "dataset_id": ds1_id, - "dataset_uri": dataset1.uri, + "dataset_id": asset1_id, + "dataset_uri": asset1.uri, "extra": {}, "id": event.id, "source_dag_id": ti.dag_id, diff --git a/tests/api_connexion/endpoints/test_dag_source_endpoint.py b/tests/api_connexion/endpoints/test_dag_source_endpoint.py index 1e5389d377440..a8d1224e034c3 100644 --- a/tests/api_connexion/endpoints/test_dag_source_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_source_endpoint.py @@ -37,7 +37,7 @@ EXAMPLE_DAG_ID = "example_bash_operator" TEST_DAG_ID = "latest_only" NOT_READABLE_DAG_ID = "latest_only_with_trigger" -TEST_MULTIPLE_DAGS_ID = "dataset_produces_1" +TEST_MULTIPLE_DAGS_ID = "asset_produces_1" @pytest.fixture(scope="module") diff --git a/tests/api_connexion/endpoints/test_dataset_endpoint.py b/tests/api_connexion/endpoints/test_dataset_endpoint.py index 25f8012039109..5caec0ac2a131 100644 --- a/tests/api_connexion/endpoints/test_dataset_endpoint.py +++ b/tests/api_connexion/endpoints/test_dataset_endpoint.py @@ -25,14 +25,14 @@ from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models import DagModel -from airflow.models.dagrun import DagRun -from airflow.models.dataset import ( - DagScheduleDatasetReference, - DatasetDagRunQueue, - DatasetEvent, - DatasetModel, - TaskOutletDatasetReference, +from airflow.models.asset import ( + AssetDagRunQueue, + AssetEvent, + AssetModel, + DagScheduleAssetReference, + TaskOutletAssetReference, ) +from airflow.models.dagrun import DagRun from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import provide_session @@ -40,7 +40,7 @@ from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.asserts import assert_queries_count from tests.test_utils.config import conf_vars -from 
tests.test_utils.db import clear_db_datasets, clear_db_runs +from tests.test_utils.db import clear_db_assets, clear_db_runs from tests.test_utils.www import _check_last_log pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] @@ -54,8 +54,8 @@ def configured_app(minimal_app_for_api): username="test", role_name="Test", permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DATASET), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), + (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_ASSET), ], ) create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore @@ -65,8 +65,8 @@ def configured_app(minimal_app_for_api): role_name="TestQueuedEvent", permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DATASET), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ASSET), ], ) @@ -84,30 +84,30 @@ class TestDatasetEndpoint: def setup_attrs(self, configured_app) -> None: self.app = configured_app self.client = self.app.test_client() - clear_db_datasets() + clear_db_assets() clear_db_runs() def teardown_method(self) -> None: - clear_db_datasets() + clear_db_assets() clear_db_runs() def _create_dataset(self, session): - dataset_model = DatasetModel( + asset_model = AssetModel( id=1, uri="s3://bucket/key", extra={"foo": "bar"}, created_at=timezone.parse(self.default_time), updated_at=timezone.parse(self.default_time), ) - session.add(dataset_model) + session.add(asset_model) session.commit() - return dataset_model + return asset_model class TestGetDatasetEndpoint(TestDatasetEndpoint): def test_should_respond_200(self, session): self._create_dataset(session) - assert session.query(DatasetModel).count() == 1 + assert session.query(AssetModel).count() == 1 with assert_queries_count(6): response = self.client.get( @@ -133,9 +133,9 @@ def test_should_respond_404(self): ) assert response.status_code == 404 assert { - "detail": "The Dataset with uri: `s3://bucket/key` was not found", + "detail": "The Asset with uri: `s3://bucket/key` was not found", "status": 404, - "title": "Dataset not found", + "title": "Asset not found", "type": EXCEPTIONS_LINK_MAP[404], } == response.json @@ -147,8 +147,8 @@ def test_should_raises_401_unauthenticated(self, session): class TestGetDatasets(TestDatasetEndpoint): def test_should_respond_200(self, session): - datasets = [ - DatasetModel( + assets = [ + AssetModel( id=i, uri=f"s3://bucket/key/{i}", extra={"foo": "bar"}, @@ -157,9 +157,9 @@ def test_should_respond_200(self, session): ) for i in [1, 2] ] - session.add_all(datasets) + session.add_all(assets) session.commit() - assert session.query(DatasetModel).count() == 2 + assert session.query(AssetModel).count() == 2 with assert_queries_count(10): response = self.client.get("/api/v1/datasets", environ_overrides={"REMOTE_USER": "test"}) @@ -193,8 +193,8 @@ def test_should_respond_200(self, session): } def test_order_by_raises_400_for_invalid_attr(self, session): - datasets = [ - DatasetModel( + assets = [ + AssetModel( uri=f"s3://bucket/key/{i}", extra={"foo": "bar"}, created_at=timezone.parse(self.default_time), @@ -202,9 +202,9 @@ def test_order_by_raises_400_for_invalid_attr(self, session): ) for i in [1, 2] ] - session.add_all(datasets) + session.add_all(assets) session.commit() - 
assert session.query(DatasetModel).count() == 2 + assert session.query(AssetModel).count() == 2 response = self.client.get( "/api/v1/datasets?order_by=fake", environ_overrides={"REMOTE_USER": "test"} @@ -215,8 +215,8 @@ def test_order_by_raises_400_for_invalid_attr(self, session): assert response.json["detail"] == msg def test_should_raises_401_unauthenticated(self, session): - datasets = [ - DatasetModel( + assets = [ + AssetModel( uri=f"s3://bucket/key/{i}", extra={"foo": "bar"}, created_at=timezone.parse(self.default_time), @@ -224,9 +224,9 @@ def test_should_raises_401_unauthenticated(self, session): ) for i in [1, 2] ] - session.add_all(datasets) + session.add_all(assets) session.commit() - assert session.query(DatasetModel).count() == 2 + assert session.query(AssetModel).count() == 2 response = self.client.get("/api/v1/datasets") @@ -254,11 +254,11 @@ def test_should_raises_401_unauthenticated(self, session): ) @provide_session def test_filter_datasets_by_uri_pattern_works(self, url, expected_datasets, session): - dataset1 = DatasetModel("s3://folder/key") - dataset2 = DatasetModel("gcp://bucket/key") - dataset3 = DatasetModel("somescheme://dataset/key") - dataset4 = DatasetModel("wasb://some_dataset_bucket_/key") - session.add_all([dataset1, dataset2, dataset3, dataset4]) + asset1 = AssetModel("s3://folder/key") + asset2 = AssetModel("gcp://bucket/key") + asset3 = AssetModel("somescheme://dataset/key") + asset4 = AssetModel("wasb://some_dataset_bucket_/key") + session.add_all([asset1, asset2, asset3, asset4]) session.commit() response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 200 @@ -273,12 +273,12 @@ def test_filter_datasets_by_dag_ids_works(self, dag_ids, expected_num, session): dag1 = DagModel(dag_id="dag1") dag2 = DagModel(dag_id="dag2") dag3 = DagModel(dag_id="dag3") - dataset1 = DatasetModel("s3://folder/key") - dataset2 = DatasetModel("gcp://bucket/key") - dataset3 = DatasetModel("somescheme://dataset/key") - dag_ref1 = DagScheduleDatasetReference(dag_id="dag1", dataset=dataset1) - dag_ref2 = DagScheduleDatasetReference(dag_id="dag2", dataset=dataset2) - task_ref1 = TaskOutletDatasetReference(dag_id="dag3", task_id="task1", dataset=dataset3) + dataset1 = AssetModel("s3://folder/key") + dataset2 = AssetModel("gcp://bucket/key") + dataset3 = AssetModel("somescheme://dataset/key") + dag_ref1 = DagScheduleAssetReference(dag_id="dag1", dataset=dataset1) + dag_ref2 = DagScheduleAssetReference(dag_id="dag2", dataset=dataset2) + task_ref1 = TaskOutletAssetReference(dag_id="dag3", task_id="task1", dataset=dataset3) session.add_all([dataset1, dataset2, dataset3, dag1, dag2, dag3, dag_ref1, dag_ref2, task_ref1]) session.commit() response = self.client.get( @@ -300,13 +300,13 @@ def test_filter_datasets_by_dag_ids_and_uri_pattern_works( dag1 = DagModel(dag_id="dag1") dag2 = DagModel(dag_id="dag2") dag3 = DagModel(dag_id="dag3") - dataset1 = DatasetModel("s3://folder/key") - dataset2 = DatasetModel("gcp://bucket/key") - dataset3 = DatasetModel("somescheme://dataset/key") - dag_ref1 = DagScheduleDatasetReference(dag_id="dag1", dataset=dataset1) - dag_ref2 = DagScheduleDatasetReference(dag_id="dag2", dataset=dataset2) - task_ref1 = TaskOutletDatasetReference(dag_id="dag3", task_id="task1", dataset=dataset3) - session.add_all([dataset1, dataset2, dataset3, dag1, dag2, dag3, dag_ref1, dag_ref2, task_ref1]) + asset1 = AssetModel("s3://folder/key") + asset2 = AssetModel("gcp://bucket/key") + asset3 = AssetModel("somescheme://dataset/key") + 
dag_ref1 = DagScheduleAssetReference(dag_id="dag1", dataset=asset1) + dag_ref2 = DagScheduleAssetReference(dag_id="dag2", dataset=asset2) + task_ref1 = TaskOutletAssetReference(dag_id="dag3", task_id="task1", dataset=asset3) + session.add_all([asset1, asset2, asset3, dag1, dag2, dag3, dag_ref1, dag_ref2, task_ref1]) session.commit() response = self.client.get( f"/api/v1/datasets?dag_ids={dag_ids}&uri_pattern={uri_pattern}", @@ -333,8 +333,8 @@ class TestGetDatasetsEndpointPagination(TestDatasetEndpoint): ) @provide_session def test_limit_and_offset(self, url, expected_dataset_uris, session): - datasets = [ - DatasetModel( + assets = [ + AssetModel( uri=f"s3://bucket/key/{i}", extra={"foo": "bar"}, created_at=timezone.parse(self.default_time), @@ -342,7 +342,7 @@ def test_limit_and_offset(self, url, expected_dataset_uris, session): ) for i in range(1, 110) ] - session.add_all(datasets) + session.add_all(assets) session.commit() response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) @@ -352,8 +352,8 @@ def test_limit_and_offset(self, url, expected_dataset_uris, session): assert dataset_uris == expected_dataset_uris def test_should_respect_page_size_limit_default(self, session): - datasets = [ - DatasetModel( + assets = [ + AssetModel( uri=f"s3://bucket/key/{i}", extra={"foo": "bar"}, created_at=timezone.parse(self.default_time), @@ -361,7 +361,7 @@ def test_should_respect_page_size_limit_default(self, session): ) for i in range(1, 110) ] - session.add_all(datasets) + session.add_all(assets) session.commit() response = self.client.get("/api/v1/datasets", environ_overrides={"REMOTE_USER": "test"}) @@ -371,8 +371,8 @@ def test_should_respect_page_size_limit_default(self, session): @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self, session): - datasets = [ - DatasetModel( + assets = [ + AssetModel( uri=f"s3://bucket/key/{i}", extra={"foo": "bar"}, created_at=timezone.parse(self.default_time), @@ -380,7 +380,7 @@ def test_should_return_conf_max_if_req_max_above_conf(self, session): ) for i in range(1, 200) ] - session.add_all(datasets) + session.add_all(assets) session.commit() response = self.client.get("/api/v1/datasets?limit=180", environ_overrides={"REMOTE_USER": "test"}) @@ -402,10 +402,10 @@ def test_should_respond_200(self, session): "created_dagruns": [], } - events = [DatasetEvent(id=i, timestamp=timezone.parse(self.default_time), **common) for i in [1, 2]] + events = [AssetEvent(id=i, timestamp=timezone.parse(self.default_time), **common) for i in [1, 2]] session.add_all(events) session.commit() - assert session.query(DatasetEvent).count() == 2 + assert session.query(AssetEvent).count() == 2 response = self.client.get("/api/v1/datasets/events", environ_overrides={"REMOTE_USER": "test"}) @@ -441,8 +441,8 @@ def test_should_respond_200(self, session): ) @provide_session def test_filtering(self, attr, value, session): - datasets = [ - DatasetModel( + assets = [ + AssetModel( id=i, uri=f"s3://bucket/key/{i}", extra={"foo": "bar"}, @@ -451,10 +451,10 @@ def test_filtering(self, attr, value, session): ) for i in [1, 2, 3] ] - session.add_all(datasets) + session.add_all(assets) session.commit() events = [ - DatasetEvent( + AssetEvent( id=i, dataset_id=i, source_dag_id=f"dag{i}", @@ -467,7 +467,7 @@ def test_filtering(self, attr, value, session): ] session.add_all(events) session.commit() - assert session.query(DatasetEvent).count() == 3 + assert session.query(AssetEvent).count() == 3 response = self.client.get( 
f"/api/v1/datasets/events?{attr}={value}", environ_overrides={"REMOTE_USER": "test"} @@ -480,7 +480,7 @@ def test_filtering(self, attr, value, session): { "id": 2, "dataset_id": 2, - "dataset_uri": datasets[1].uri, + "dataset_uri": assets[1].uri, "extra": {}, "source_dag_id": "dag2", "source_task_id": "task2", @@ -496,7 +496,7 @@ def test_filtering(self, attr, value, session): def test_order_by_raises_400_for_invalid_attr(self, session): self._create_dataset(session) events = [ - DatasetEvent( + AssetEvent( dataset_id=1, extra="{'foo': 'bar'}", source_dag_id="foo", @@ -509,7 +509,7 @@ def test_order_by_raises_400_for_invalid_attr(self, session): ] session.add_all(events) session.commit() - assert session.query(DatasetEvent).count() == 2 + assert session.query(AssetEvent).count() == 2 response = self.client.get( "/api/v1/datasets/events?order_by=fake", environ_overrides={"REMOTE_USER": "test"} @@ -525,7 +525,7 @@ def test_should_raises_401_unauthenticated(self, session): def test_includes_created_dagrun(self, session): self._create_dataset(session) - event = DatasetEvent( + event = AssetEvent( id=1, dataset_id=1, timestamp=timezone.parse(self.default_time), @@ -685,7 +685,7 @@ class TestGetDatasetEventsEndpointPagination(TestDatasetEndpoint): def test_limit_and_offset(self, url, expected_event_runids, session): self._create_dataset(session) events = [ - DatasetEvent( + AssetEvent( dataset_id=1, source_dag_id="foo", source_task_id="bar", @@ -707,7 +707,7 @@ def test_limit_and_offset(self, url, expected_event_runids, session): def test_should_respect_page_size_limit_default(self, session): self._create_dataset(session) events = [ - DatasetEvent( + AssetEvent( dataset_id=1, source_dag_id="foo", source_task_id="bar", @@ -729,7 +729,7 @@ def test_should_respect_page_size_limit_default(self, session): def test_should_return_conf_max_if_req_max_above_conf(self, session): self._create_dataset(session) events = [ - DatasetEvent( + AssetEvent( dataset_id=1, source_dag_id="foo", source_task_id="bar", @@ -761,10 +761,10 @@ def time_freezer(self) -> Generator: freezer.stop() def _create_dataset_dag_run_queues(self, dag_id, dataset_id, session): - ddrq = DatasetDagRunQueue(target_dag_id=dag_id, dataset_id=dataset_id) - session.add(ddrq) + adrq = AssetDagRunQueue(target_dag_id=dag_id, dataset_id=dataset_id) + session.add(adrq) session.commit() - return ddrq + return adrq class TestGetDagDatasetQueuedEvent(TestQueuedEventEndpoint): @@ -799,7 +799,7 @@ def test_should_respond_404(self): assert response.status_code == 404 assert { - "detail": "Queue event with dag_id: `not_exists` and dataset uri: `not_exists` was not found", + "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found", "status": 404, "title": "Queue event not found", "type": EXCEPTIONS_LINK_MAP[404], @@ -832,10 +832,10 @@ def test_delete_should_respond_204(self, session, create_dummy_dag): dataset_uri = "s3://bucket/key" dataset_id = self._create_dataset(session).id - ddrq = DatasetDagRunQueue(target_dag_id=dag_id, dataset_id=dataset_id) - session.add(ddrq) + adrq = AssetDagRunQueue(target_dag_id=dag_id, dataset_id=dataset_id) + session.add(adrq) session.commit() - conn = session.query(DatasetDagRunQueue).all() + conn = session.query(AssetDagRunQueue).all() assert len(conn) == 1 response = self.client.delete( @@ -844,7 +844,7 @@ def test_delete_should_respond_204(self, session, create_dummy_dag): ) assert response.status_code == 204 - conn = session.query(DatasetDagRunQueue).all() + conn = 
session.query(AssetDagRunQueue).all() assert len(conn) == 0 _check_last_log( session, dag_id=dag_id, event="api.delete_dag_dataset_queued_event", execution_date=None @@ -861,7 +861,7 @@ def test_should_respond_404(self): assert response.status_code == 404 assert { - "detail": "Queue event with dag_id: `not_exists` and dataset uri: `not_exists` was not found", + "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found", "status": 404, "title": "Queue event not found", "type": EXCEPTIONS_LINK_MAP[404], @@ -1013,7 +1013,7 @@ def test_should_respond_404(self): assert response.status_code == 404 assert { - "detail": "Queue event with dataset uri: `not_exists` was not found", + "detail": "Queue event with asset uri: `not_exists` was not found", "status": 404, "title": "Queue event not found", "type": EXCEPTIONS_LINK_MAP[404], @@ -1051,7 +1051,7 @@ def test_delete_should_respond_204(self, session, create_dummy_dag): ) assert response.status_code == 204 - conn = session.query(DatasetDagRunQueue).all() + conn = session.query(AssetDagRunQueue).all() assert len(conn) == 0 _check_last_log(session, dag_id=None, event="api.delete_dataset_queued_events", execution_date=None) @@ -1065,7 +1065,7 @@ def test_should_respond_404(self): assert response.status_code == 404 assert { - "detail": "Queue event with dataset uri: `not_exists` was not found", + "detail": "Queue event with asset uri: `not_exists` was not found", "status": 404, "title": "Queue event not found", "type": EXCEPTIONS_LINK_MAP[404], diff --git a/tests/api_connexion/schemas/test_dag_schema.py b/tests/api_connexion/schemas/test_dag_schema.py index a4a86bc05cc9b..1a7e345421c62 100644 --- a/tests/api_connexion/schemas/test_dag_schema.py +++ b/tests/api_connexion/schemas/test_dag_schema.py @@ -27,7 +27,7 @@ DAGDetailSchema, DAGSchema, ) -from airflow.datasets import Dataset +from airflow.assets import Asset from airflow.models import DagModel, DagTag from airflow.models.dag import DAG @@ -210,9 +210,9 @@ def test_serialize_test_dag_detail_schema(url_safe_serializer): @pytest.mark.skip_if_database_isolation_mode @pytest.mark.db_test -def test_serialize_test_dag_with_dataset_schedule_detail_schema(url_safe_serializer): - dataset1 = Dataset(uri="s3://bucket/obj1") - dataset2 = Dataset(uri="s3://bucket/obj2") +def test_serialize_test_dag_with_asset_schedule_detail_schema(url_safe_serializer): + asset1 = Asset(uri="s3://bucket/obj1") + asset2 = Asset(uri="s3://bucket/obj2") dag = DAG( dag_id="test_dag", start_date=datetime(2020, 6, 19), @@ -220,7 +220,7 @@ def test_serialize_test_dag_with_dataset_schedule_detail_schema(url_safe_seriali orientation="LR", default_view="duration", params={"foo": 1}, - schedule=dataset1 & dataset2, + schedule=asset1 & asset2, tags=["example1", "example2"], ) schema = DAGDetailSchema() @@ -255,7 +255,7 @@ def test_serialize_test_dag_with_dataset_schedule_detail_schema(url_safe_seriali key=lambda val: val["name"], ), "template_searchpath": None, - "timetable_summary": "Dataset", + "timetable_summary": "Asset", "timezone": UTC_JSON_REPR, "max_active_runs": 16, "max_consecutive_failed_dag_runs": 0, diff --git a/tests/api_connexion/schemas/test_dataset_schema.py b/tests/api_connexion/schemas/test_dataset_schema.py index c07eed2236a67..a9a5ce9e9673b 100644 --- a/tests/api_connexion/schemas/test_dataset_schema.py +++ b/tests/api_connexion/schemas/test_dataset_schema.py @@ -19,26 +19,26 @@ import pytest import time_machine -from airflow.api_connexion.schemas.dataset_schema import ( - 
DatasetCollection, - DatasetEventCollection, - dataset_collection_schema, - dataset_event_collection_schema, - dataset_event_schema, - dataset_schema, +from airflow.api_connexion.schemas.asset_schema import ( + AssetCollection, + AssetEventCollection, + asset_collection_schema, + asset_event_collection_schema, + asset_event_schema, + asset_schema, ) -from airflow.datasets import Dataset -from airflow.models.dataset import DatasetAliasModel, DatasetEvent, DatasetModel +from airflow.assets import Asset +from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel from airflow.operators.empty import EmptyOperator -from tests.test_utils.db import clear_db_dags, clear_db_datasets +from tests.test_utils.db import clear_db_assets, clear_db_dags pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] -class TestDatasetSchemaBase: +class TestAssetSchemaBase: def setup_method(self) -> None: clear_db_dags() - clear_db_datasets() + clear_db_assets() self.timestamp = "2022-06-10T12:02:44+00:00" self.freezer = time_machine.travel(self.timestamp, tick=False) self.freezer.start() @@ -46,12 +46,12 @@ def setup_method(self) -> None: def teardown_method(self) -> None: self.freezer.stop() clear_db_dags() - clear_db_datasets() + clear_db_assets() -class TestDatasetSchema(TestDatasetSchemaBase): +class TestAssetSchema(TestAssetSchemaBase): def test_serialize(self, dag_maker, session): - dataset = Dataset( + dataset = Asset( uri="s3://bucket/key", extra={"foo": "bar"}, ) @@ -62,9 +62,9 @@ def test_serialize(self, dag_maker, session): ): EmptyOperator(task_id="task2") - dataset_model = session.query(DatasetModel).filter_by(uri=dataset.uri).one() + asset_model = session.query(AssetModel).filter_by(uri=dataset.uri).one() - serialized_data = dataset_schema.dump(dataset_model) + serialized_data = asset_schema.dump(asset_model) serialized_data["id"] = 1 assert serialized_data == { "id": 1, @@ -91,24 +91,22 @@ def test_serialize(self, dag_maker, session): } -class TestDatasetCollectionSchema(TestDatasetSchemaBase): +class TestAssetCollectionSchema(TestAssetSchemaBase): def test_serialize(self, session): - datasets = [ - DatasetModel( + assets = [ + AssetModel( uri=f"s3://bucket/key/{i+1}", extra={"foo": "bar"}, ) for i in range(2) ] - dataset_aliases = [DatasetAliasModel(name=f"alias_{i}") for i in range(2)] - for dataset_alias in dataset_aliases: - dataset_alias.datasets.append(datasets[0]) - session.add_all(datasets) - session.add_all(dataset_aliases) + asset_aliases = [AssetAliasModel(name=f"alias_{i}") for i in range(2)] + for asset_alias in asset_aliases: + asset_alias.datasets.append(assets[0]) + session.add_all(assets) + session.add_all(asset_aliases) session.flush() - serialized_data = dataset_collection_schema.dump( - DatasetCollection(datasets=datasets, total_entries=2) - ) + serialized_data = asset_collection_schema.dump(AssetCollection(datasets=assets, total_entries=2)) serialized_data["datasets"][0]["id"] = 1 serialized_data["datasets"][1]["id"] = 2 serialized_data["datasets"][0]["aliases"][0]["id"] = 1 @@ -143,14 +141,14 @@ def test_serialize(self, session): } -class TestDatasetEventSchema(TestDatasetSchemaBase): +class TestAssetEventSchema(TestAssetSchemaBase): def test_serialize(self, session): - d = DatasetModel("s3://abc") - session.add(d) + assetssetsset = AssetModel("s3://abc") + session.add(assetssetsset) session.commit() - event = DatasetEvent( + event = AssetEvent( id=1, - dataset_id=d.id, + dataset_id=assetssetsset.id, extra={"foo": "bar"}, 
source_dag_id="foo", source_task_id="bar", @@ -159,10 +157,10 @@ def test_serialize(self, session): ) session.add(event) session.flush() - serialized_data = dataset_event_schema.dump(event) + serialized_data = asset_event_schema.dump(event) assert serialized_data == { "id": 1, - "dataset_id": d.id, + "dataset_id": assetssetsset.id, "dataset_uri": "s3://abc", "extra": {"foo": "bar"}, "source_dag_id": "foo", @@ -174,14 +172,14 @@ def test_serialize(self, session): } -class TestDatasetEventCreateSchema(TestDatasetSchemaBase): +class TestDatasetEventCreateSchema(TestAssetSchemaBase): def test_serialize(self, session): - d = DatasetModel("s3://abc") - session.add(d) + asset = AssetModel("s3://abc") + session.add(asset) session.commit() - event = DatasetEvent( + event = AssetEvent( id=1, - dataset_id=d.id, + dataset_id=asset.id, extra={"foo": "bar"}, source_dag_id=None, source_task_id=None, @@ -190,10 +188,10 @@ def test_serialize(self, session): ) session.add(event) session.flush() - serialized_data = dataset_event_schema.dump(event) + serialized_data = asset_event_schema.dump(event) assert serialized_data == { "id": 1, - "dataset_id": d.id, + "dataset_id": asset.id, "dataset_uri": "s3://abc", "extra": {"foo": "bar"}, "source_dag_id": None, @@ -205,7 +203,7 @@ def test_serialize(self, session): } -class TestDatasetEventCollectionSchema(TestDatasetSchemaBase): +class TestAssetEventCollectionSchema(TestAssetSchemaBase): def test_serialize(self, session): common = { "dataset_id": 10, @@ -217,11 +215,11 @@ def test_serialize(self, session): "created_dagruns": [], } - events = [DatasetEvent(id=i, **common) for i in [1, 2]] + events = [AssetEvent(id=i, **common) for i in [1, 2]] session.add_all(events) session.flush() - serialized_data = dataset_event_collection_schema.dump( - DatasetEventCollection(dataset_events=events, total_entries=2) + serialized_data = asset_event_collection_schema.dump( + AssetEventCollection(dataset_events=events, total_entries=2) ) assert serialized_data == { "dataset_events": [ diff --git a/tests/api_fastapi/views/ui/test_datasets.py b/tests/api_fastapi/views/ui/test_assets.py similarity index 92% rename from tests/api_fastapi/views/ui/test_datasets.py rename to tests/api_fastapi/views/ui/test_assets.py index 12b22e4bbb9ef..7aff14249b4db 100644 --- a/tests/api_fastapi/views/ui/test_datasets.py +++ b/tests/api_fastapi/views/ui/test_assets.py @@ -18,7 +18,7 @@ import pytest -from airflow.datasets import Dataset +from airflow.assets import Asset from airflow.operators.empty import EmptyOperator from tests.conftest import initial_db_init @@ -35,7 +35,7 @@ def cleanup(): def test_next_run_datasets(test_client, dag_maker): - with dag_maker(dag_id="upstream", schedule=[Dataset(uri="s3://bucket/key/1")], serialized=True): + with dag_maker(dag_id="upstream", schedule=[Asset(uri="s3://bucket/key/1")], serialized=True): EmptyOperator(task_id="task1") dag_maker.create_dagrun() diff --git a/tests/providers/common/io/datasets/__init__.py b/tests/assets/__init__.py similarity index 100% rename from tests/providers/common/io/datasets/__init__.py rename to tests/assets/__init__.py diff --git a/tests/datasets/test_manager.py b/tests/assets/test_manager.py similarity index 52% rename from tests/datasets/test_manager.py rename to tests/assets/test_manager.py index 9b8b0c180d48e..0539fdace52ba 100644 --- a/tests/datasets/test_manager.py +++ b/tests/assets/test_manager.py @@ -24,13 +24,13 @@ import pytest from sqlalchemy import delete -from airflow.datasets import Dataset -from 
airflow.datasets.manager import DatasetManager +from airflow.assets import Asset +from airflow.assets.manager import AssetManager from airflow.listeners.listener import get_listener_manager +from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel, DagScheduleAssetReference from airflow.models.dag import DagModel -from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic -from tests.listeners import dataset_listener +from tests.listeners import asset_listener pytestmark = pytest.mark.db_test @@ -90,93 +90,93 @@ def create_mock_dag(): yield mock_dag -class TestDatasetManager: - def test_register_dataset_change_dataset_doesnt_exist(self, mock_task_instance): - dsem = DatasetManager() +class TestAssetManager: + def test_register_asset_change_asset_doesnt_exist(self, mock_task_instance): + dsem = AssetManager() - dataset = Dataset(uri="dataset_doesnt_exist") + asset = Asset(uri="asset_doesnt_exist") mock_session = mock.Mock() # Gotta mock up the query results mock_session.scalar.return_value = None - dsem.register_dataset_change(task_instance=mock_task_instance, dataset=dataset, session=mock_session) + dsem.register_asset_change(task_instance=mock_task_instance, asset=asset, session=mock_session) - # Ensure that we have ignored the dataset and _not_ created a DatasetEvent or - # DatasetDagRunQueue rows + # Ensure that we have ignored the asset and _not_ created a AssetEvent or + # AssetDagRunQueue rows mock_session.add.assert_not_called() mock_session.merge.assert_not_called() - def test_register_dataset_change(self, session, dag_maker, mock_task_instance): - dsem = DatasetManager() + def test_register_asset_change(self, session, dag_maker, mock_task_instance): + dsem = AssetManager() - ds = Dataset(uri="test_dataset_uri") + ds = Asset(uri="test_asset_uri") dag1 = DagModel(dag_id="dag1", is_active=True) dag2 = DagModel(dag_id="dag2", is_active=True) session.add_all([dag1, dag2]) - dsm = DatasetModel(uri="test_dataset_uri") - session.add(dsm) - dsm.consuming_dags = [DagScheduleDatasetReference(dag_id=dag.dag_id) for dag in (dag1, dag2)] - session.execute(delete(DatasetDagRunQueue)) + asm = AssetModel(uri="test_asset_uri") + session.add(asm) + asm.consuming_dags = [DagScheduleAssetReference(dag_id=dag.dag_id) for dag in (dag1, dag2)] + session.execute(delete(AssetDagRunQueue)) session.flush() - dsem.register_dataset_change(task_instance=mock_task_instance, dataset=ds, session=session) + dsem.register_asset_change(task_instance=mock_task_instance, asset=ds, session=session) session.flush() - # Ensure we've created a dataset - assert session.query(DatasetEvent).filter_by(dataset_id=dsm.id).count() == 1 - assert session.query(DatasetDagRunQueue).count() == 2 + # Ensure we've created an asset + assert session.query(AssetEvent).filter_by(dataset_id=asm.id).count() == 1 + assert session.query(AssetDagRunQueue).count() == 2 - def test_register_dataset_change_no_downstreams(self, session, mock_task_instance): - dsem = DatasetManager() + def test_register_asset_change_no_downstreams(self, session, mock_task_instance): + dsem = AssetManager() - ds = Dataset(uri="never_consumed") - dsm = DatasetModel(uri="never_consumed") - session.add(dsm) - session.execute(delete(DatasetDagRunQueue)) + ds = Asset(uri="never_consumed") + asm = AssetModel(uri="never_consumed") + session.add(asm) + session.execute(delete(AssetDagRunQueue)) session.flush() - 
dsem.register_dataset_change(task_instance=mock_task_instance, dataset=ds, session=session) + dsem.register_asset_change(task_instance=mock_task_instance, asset=ds, session=session) session.flush() - # Ensure we've created a dataset - assert session.query(DatasetEvent).filter_by(dataset_id=dsm.id).count() == 1 - assert session.query(DatasetDagRunQueue).count() == 0 + # Ensure we've created an asset + assert session.query(AssetEvent).filter_by(dataset_id=asm.id).count() == 1 + assert session.query(AssetDagRunQueue).count() == 0 @pytest.mark.skip_if_database_isolation_mode - def test_register_dataset_change_notifies_dataset_listener(self, session, mock_task_instance): - dsem = DatasetManager() - dataset_listener.clear() - get_listener_manager().add_listener(dataset_listener) + def test_register_asset_change_notifies_asset_listener(self, session, mock_task_instance): + dsem = AssetManager() + asset_listener.clear() + get_listener_manager().add_listener(asset_listener) - ds = Dataset(uri="test_dataset_uri_2") + ds = Asset(uri="test_asset_uri_2") dag1 = DagModel(dag_id="dag3") session.add(dag1) - dsm = DatasetModel(uri="test_dataset_uri_2") - session.add(dsm) - dsm.consuming_dags = [DagScheduleDatasetReference(dag_id=dag1.dag_id)] + asm = AssetModel(uri="test_asset_uri_2") + session.add(asm) + asm.consuming_dags = [DagScheduleAssetReference(dag_id=dag1.dag_id)] session.flush() - dsem.register_dataset_change(task_instance=mock_task_instance, dataset=ds, session=session) + dsem.register_asset_change(task_instance=mock_task_instance, asset=ds, session=session) session.flush() # Ensure the listener was notified - assert len(dataset_listener.changed) == 1 - assert dataset_listener.changed[0].uri == ds.uri + assert len(asset_listener.changed) == 1 + assert asset_listener.changed[0].uri == ds.uri @pytest.mark.skip_if_database_isolation_mode - def test_create_datasets_notifies_dataset_listener(self, session): - dsem = DatasetManager() - dataset_listener.clear() - get_listener_manager().add_listener(dataset_listener) + def test_create_assets_notifies_asset_listener(self, session): + asset_manager = AssetManager() + asset_listener.clear() + get_listener_manager().add_listener(asset_listener) - ds = Dataset(uri="test_dataset_uri_3") + asset = Asset(uri="test_asset_uri_3") - dsms = dsem.create_datasets([ds], session=session) + asms = asset_manager.create_assets([asset], session=session) # Ensure the listener was notified - assert len(dataset_listener.created) == 1 - assert len(dsms) == 1 - assert dataset_listener.created[0].uri == ds.uri == dsms[0].uri + assert len(asset_listener.created) == 1 + assert len(asms) == 1 + assert asset_listener.created[0].uri == asset.uri == asms[0].uri diff --git a/tests/assets/tests_asset.py b/tests/assets/tests_asset.py new file mode 100644 index 0000000000000..da6ef8ee79e39 --- /dev/null +++ b/tests/assets/tests_asset.py @@ -0,0 +1,586 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import os +from collections import defaultdict +from typing import Callable +from unittest.mock import patch + +import pytest +from sqlalchemy.sql import select + +from airflow.assets import ( + Asset, + AssetAlias, + AssetAll, + AssetAny, + BaseAsset, + _AssetAliasCondition, + _get_normalized_scheme, + _sanitize_uri, +) +from airflow.models.asset import AssetAliasModel, AssetDagRunQueue, AssetModel +from airflow.models.serialized_dag import SerializedDagModel +from airflow.operators.empty import EmptyOperator +from airflow.serialization.serialized_objects import BaseSerialization, SerializedDAG +from tests.test_utils.config import conf_vars + + +@pytest.fixture +def clear_assets(): + from tests.test_utils.db import clear_db_assets + + clear_db_assets() + yield + clear_db_assets() + + +@pytest.mark.parametrize( + ["uri"], + [ + pytest.param("", id="empty"), + pytest.param("\n\t", id="whitespace"), + pytest.param("a" * 3001, id="too_long"), + pytest.param("airflow://xcom/dag/task", id="reserved_scheme"), + pytest.param("😊", id="non-ascii"), + ], +) +def test_invalid_uris(uri): + with pytest.raises(ValueError): + Asset(uri=uri) + + +@pytest.mark.parametrize( + "uri, normalized", + [ + pytest.param("foobar", "foobar", id="scheme-less"), + pytest.param("foo:bar", "foo:bar", id="scheme-less-colon"), + pytest.param("foo/bar", "foo/bar", id="scheme-less-slash"), + pytest.param("s3://bucket/key/path", "s3://bucket/key/path", id="normal"), + pytest.param("file:///123/456/", "file:///123/456", id="trailing-slash"), + ], +) +def test_uri_with_scheme(uri: str, normalized: str) -> None: + asset = Asset(uri) + EmptyOperator(task_id="task1", outlets=[asset]) + assert asset.uri == normalized + assert os.fspath(asset) == normalized + + +def test_uri_with_auth() -> None: + with pytest.warns(UserWarning) as record: + asset = Asset("ftp://user@localhost/foo.txt") + assert len(record) == 1 + assert str(record[0].message) == ( + "An Asset URI should not contain auth info (e.g. username or " + "password). It has been automatically dropped." 
+ ) + EmptyOperator(task_id="task1", outlets=[asset]) + assert asset.uri == "ftp://localhost/foo.txt" + assert os.fspath(asset) == "ftp://localhost/foo.txt" + + +def test_uri_without_scheme(): + asset = Asset(uri="example_asset") + EmptyOperator(task_id="task1", outlets=[asset]) + + +def test_fspath(): + uri = "s3://example/asset" + asset = Asset(uri=uri) + assert os.fspath(asset) == uri + + +def test_equal_when_same_uri(): + uri = "s3://example/asset" + asset1 = Asset(uri=uri) + asset2 = Asset(uri=uri) + assert asset1 == asset2 + + +def test_not_equal_when_different_uri(): + asset1 = Asset(uri="s3://example/asset") + asset2 = Asset(uri="s3://other/asset") + assert asset1 != asset2 + + +def test_asset_logic_operations(): + result_or = asset1 | asset2 + assert isinstance(result_or, AssetAny) + result_and = asset1 & asset2 + assert isinstance(result_and, AssetAll) + + +def test_asset_iter_assets(): + assert list(asset1.iter_assets()) == [("s3://bucket1/data1", asset1)] + + +@pytest.mark.db_test +def test_asset_iter_asset_aliases(): + base_asset = AssetAll( + AssetAlias("example-alias-1"), + Asset("1"), + AssetAny( + Asset("2"), + AssetAlias("example-alias-2"), + Asset("3"), + AssetAll(AssetAlias("example-alias-3"), Asset("4"), AssetAlias("example-alias-4")), + ), + AssetAll(AssetAlias("example-alias-5"), Asset("5")), + ) + assert list(base_asset.iter_asset_aliases()) == [ + (f"example-alias-{i}", AssetAlias(f"example-alias-{i}")) for i in range(1, 6) + ] + + +def test_asset_evaluate(): + assert asset1.evaluate({"s3://bucket1/data1": True}) is True + assert asset1.evaluate({"s3://bucket1/data1": False}) is False + + +def test_asset_any_operations(): + result_or = (asset1 | asset2) | asset3 + assert isinstance(result_or, AssetAny) + assert len(result_or.objects) == 3 + result_and = (asset1 | asset2) & asset3 + assert isinstance(result_and, AssetAll) + + +def test_asset_all_operations(): + result_or = (asset1 & asset2) | asset3 + assert isinstance(result_or, AssetAny) + result_and = (asset1 & asset2) & asset3 + assert isinstance(result_and, AssetAll) + + +def test_assset_boolean_condition_evaluate_iter(): + """ + Tests _AssetBooleanCondition's evaluate and iter_assets methods through AssetAny and AssetAll. + Ensures AssetAny evaluate returns True with any true condition, AssetAll evaluate returns False if + any condition is false, and both classes correctly iterate over assets without duplication. 
+ """ + any_condition = AssetAny(asset1, asset2) + all_condition = AssetAll(asset1, asset2) + assert any_condition.evaluate({"s3://bucket1/data1": False, "s3://bucket2/data2": True}) is True + assert all_condition.evaluate({"s3://bucket1/data1": True, "s3://bucket2/data2": False}) is False + + # Testing iter_assets indirectly through the subclasses + assets_any = dict(any_condition.iter_assets()) + assets_all = dict(all_condition.iter_assets()) + assert assets_any == {"s3://bucket1/data1": asset1, "s3://bucket2/data2": asset2} + assert assets_all == {"s3://bucket1/data1": asset1, "s3://bucket2/data2": asset2} + + +@pytest.mark.parametrize( + "inputs, scenario, expected", + [ + # Scenarios for AssetAny + ((True, True, True), "any", True), + ((True, True, False), "any", True), + ((True, False, True), "any", True), + ((True, False, False), "any", True), + ((False, False, True), "any", True), + ((False, True, False), "any", True), + ((False, True, True), "any", True), + ((False, False, False), "any", False), + # Scenarios for AssetAll + ((True, True, True), "all", True), + ((True, True, False), "all", False), + ((True, False, True), "all", False), + ((True, False, False), "all", False), + ((False, False, True), "all", False), + ((False, True, False), "all", False), + ((False, True, True), "all", False), + ((False, False, False), "all", False), + ], +) +def test_asset_logical_conditions_evaluation_and_serialization(inputs, scenario, expected): + class_ = AssetAny if scenario == "any" else AssetAll + assets = [Asset(uri=f"s3://abc/{i}") for i in range(123, 126)] + condition = class_(*assets) + + statuses = {asset.uri: status for asset, status in zip(assets, inputs)} + assert ( + condition.evaluate(statuses) == expected + ), f"Condition evaluation failed for inputs {inputs} and scenario '{scenario}'" + + # Serialize and deserialize the condition to test persistence + serialized = BaseSerialization.serialize(condition) + deserialized = BaseSerialization.deserialize(serialized) + assert deserialized.evaluate(statuses) == expected, "Serialization round-trip failed" + + +@pytest.mark.parametrize( + "status_values, expected_evaluation", + [ + ((False, True, True), False), # AssetAll requires all conditions to be True, but d1 is False + ((True, True, True), True), # All conditions are True + ((True, False, True), True), # d1 is True, and AssetAny condition (d2 or d3 being True) is met + ((True, False, False), False), # d1 is True, but neither d2 nor d3 meet the AssetAny condition + ], +) +def test_nested_asset_conditions_with_serialization(status_values, expected_evaluation): + # Define assets + d1 = Asset(uri="s3://abc/123") + d2 = Asset(uri="s3://abc/124") + d3 = Asset(uri="s3://abc/125") + + # Create a nested condition: AssetAll with d1 and AssetAny with d2 and d3 + nested_condition = AssetAll(d1, AssetAny(d2, d3)) + + statuses = { + d1.uri: status_values[0], + d2.uri: status_values[1], + d3.uri: status_values[2], + } + + assert nested_condition.evaluate(statuses) == expected_evaluation, "Initial evaluation mismatch" + + serialized_condition = BaseSerialization.serialize(nested_condition) + deserialized_condition = BaseSerialization.deserialize(serialized_condition) + + assert ( + deserialized_condition.evaluate(statuses) == expected_evaluation + ), "Post-serialization evaluation mismatch" + + +@pytest.fixture +def create_test_assets(session): + """Fixture to create test assets and corresponding models.""" + assets = [Asset(uri=f"hello{i}") for i in range(1, 3)] + for asset in assets: + 
session.add(AssetModel(uri=asset.uri)) + session.commit() + return assets + + +@pytest.mark.db_test +@pytest.mark.usefixtures("clear_assets") +def test_asset_trigger_setup_and_serialization(session, dag_maker, create_test_assets): + assets = create_test_assets + + # Create DAG with asset triggers + with dag_maker(schedule=AssetAny(*assets)) as dag: + EmptyOperator(task_id="hello") + + # Verify assets are set up correctly + assert isinstance(dag.timetable.asset_condition, AssetAny), "DAG assets should be an instance of AssetAny" + + # Round-trip the DAG through serialization + deserialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) + + # Verify serialization and deserialization integrity + assert isinstance( + deserialized_dag.timetable.asset_condition, AssetAny + ), "Deserialized assets should maintain type AssetAny" + assert ( + deserialized_dag.timetable.asset_condition.objects == dag.timetable.asset_condition.objects + ), "Deserialized assets should match original" + + +@pytest.mark.db_test +@pytest.mark.usefixtures("clear_assets") +def test_asset_dag_run_queue_processing(session, clear_assets, dag_maker, create_test_assets): + assets = create_test_assets + asset_models = session.query(AssetModel).all() + + with dag_maker(schedule=AssetAny(*assets)) as dag: + EmptyOperator(task_id="hello") + + # Add AssetDagRunQueue entries to simulate asset event processing + for am in asset_models: + session.add(AssetDagRunQueue(dataset_id=am.id, target_dag_id=dag.dag_id)) + session.commit() + + # Fetch and evaluate asset triggers for all DAGs affected by asset events + records = session.scalars(select(AssetDagRunQueue)).all() + dag_statuses = defaultdict(lambda: defaultdict(bool)) + for record in records: + dag_statuses[record.target_dag_id][record.dataset.uri] = True + + serialized_dags = session.execute( + select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys())) + ).fetchall() + + for (serialized_dag,) in serialized_dags: + dag = SerializedDAG.deserialize(serialized_dag.data) + for asset_uri, status in dag_statuses[dag.dag_id].items(): + cond = dag.timetable.asset_condition + assert cond.evaluate({asset_uri: status}), "DAG trigger evaluation failed" + + +@pytest.mark.db_test +@pytest.mark.usefixtures("clear_assets") +def test_dag_with_complex_asset_condition(session, dag_maker): + # Create Asset instances + d1 = Asset(uri="hello1") + d2 = Asset(uri="hello2") + + # Create and add AssetModel instances to the session + am1 = AssetModel(uri=d1.uri) + am2 = AssetModel(uri=d2.uri) + session.add_all([am1, am2]) + session.commit() + + # Setup a DAG with complex asset triggers (AssetAny with AssetAll) + with dag_maker(schedule=AssetAny(d1, AssetAll(d2, d1))) as dag: + EmptyOperator(task_id="hello") + + assert isinstance( + dag.timetable.asset_condition, AssetAny + ), "DAG's asset trigger should be an instance of AssetAny" + assert any( + isinstance(trigger, AssetAll) for trigger in dag.timetable.asset_condition.objects + ), "DAG's asset trigger should include AssetAll" + + serialized_triggers = SerializedDAG.serialize(dag.timetable.asset_condition) + + deserialized_triggers = SerializedDAG.deserialize(serialized_triggers) + + assert isinstance( + deserialized_triggers, AssetAny + ), "Deserialized triggers should be an instance of AssetAny" + assert any( + isinstance(trigger, AssetAll) for trigger in deserialized_triggers.objects + ), "Deserialized triggers should include AssetAll" + + serialized_timetable_dict = 
SerializedDAG.to_dict(dag)["dag"]["timetable"]["__var"] + assert ( + "asset_condition" in serialized_timetable_dict + ), "Serialized timetable should contain 'asset_condition'" + assert isinstance( + serialized_timetable_dict["asset_condition"], dict + ), "Serialized 'asset_condition' should be a dict" + + +def assets_equal(a1: BaseAsset, a2: BaseAsset) -> bool: + if type(a1) is not type(a2): + return False + + if isinstance(a1, Asset) and isinstance(a2, Asset): + return a1.uri == a2.uri + + elif isinstance(a1, (AssetAny, AssetAll)) and isinstance(a2, (AssetAny, AssetAll)): + if len(a1.objects) != len(a2.objects): + return False + + # Compare each pair of objects + for obj1, obj2 in zip(a1.objects, a2.objects): + # If obj1 or obj2 is a Asset, AssetAny, or AssetAll instance, + # recursively call assets_equal + if not assets_equal(obj1, obj2): + return False + return True + + return False + + +asset1 = Asset(uri="s3://bucket1/data1") +asset2 = Asset(uri="s3://bucket2/data2") +asset3 = Asset(uri="s3://bucket3/data3") +asset4 = Asset(uri="s3://bucket4/data4") +asset5 = Asset(uri="s3://bucket5/data5") + +test_cases = [ + (lambda: asset1, asset1), + (lambda: asset1 & asset2, AssetAll(asset1, asset2)), + (lambda: asset1 | asset2, AssetAny(asset1, asset2)), + (lambda: asset1 | (asset2 & asset3), AssetAny(asset1, AssetAll(asset2, asset3))), + (lambda: asset1 | asset2 & asset3, AssetAny(asset1, AssetAll(asset2, asset3))), + ( + lambda: ((asset1 & asset2) | asset3) & (asset4 | asset5), + AssetAll(AssetAny(AssetAll(asset1, asset2), asset3), AssetAny(asset4, asset5)), + ), + (lambda: asset1 & asset2 | asset3, AssetAny(AssetAll(asset1, asset2), asset3)), + ( + lambda: (asset1 | asset2) & (asset3 | asset4), + AssetAll(AssetAny(asset1, asset2), AssetAny(asset3, asset4)), + ), + ( + lambda: (asset1 & asset2) | (asset3 & (asset4 | asset5)), + AssetAny(AssetAll(asset1, asset2), AssetAll(asset3, AssetAny(asset4, asset5))), + ), + ( + lambda: (asset1 & asset2) & (asset3 & asset4), + AssetAll(asset1, asset2, AssetAll(asset3, asset4)), + ), + (lambda: asset1 | asset2 | asset3, AssetAny(asset1, asset2, asset3)), + (lambda: asset1 & asset2 & asset3, AssetAll(asset1, asset2, asset3)), + ( + lambda: ((asset1 & asset2) | asset3) & (asset4 | asset5), + AssetAll(AssetAny(AssetAll(asset1, asset2), asset3), AssetAny(asset4, asset5)), + ), +] + + +@pytest.mark.parametrize("expression, expected", test_cases) +def test_evaluate_assets_expression(expression, expected): + expr = expression() + assert assets_equal(expr, expected) + + +@pytest.mark.parametrize( + "expression, error", + [ + pytest.param( + lambda: asset1 & 1, # type: ignore[operator] + "unsupported operand type(s) for &: 'Asset' and 'int'", + id="&", + ), + pytest.param( + lambda: asset1 | 1, # type: ignore[operator] + "unsupported operand type(s) for |: 'Asset' and 'int'", + id="|", + ), + pytest.param( + lambda: AssetAll(1, asset1), # type: ignore[arg-type] + "expect asset expressions in condition", + id="AssetAll", + ), + pytest.param( + lambda: AssetAny(1, asset1), # type: ignore[arg-type] + "expect asset expressions in condition", + id="AssetAny", + ), + ], +) +def test_assets_expression_error(expression: Callable[[], None], error: str) -> None: + with pytest.raises(TypeError) as info: + expression() + assert str(info.value) == error + + +def test_get_normalized_scheme(): + assert _get_normalized_scheme("http://example.com") == "http" + assert _get_normalized_scheme("HTTPS://example.com") == "https" + assert _get_normalized_scheme("ftp://example.com") == 
"ftp" + assert _get_normalized_scheme("file://") == "file" + + assert _get_normalized_scheme("example.com") == "" + assert _get_normalized_scheme("") == "" + assert _get_normalized_scheme(" ") == "" + + +def _mock_get_uri_normalizer_raising_error(normalized_scheme): + def normalizer(uri): + raise ValueError("Incorrect URI format") + + return normalizer + + +def _mock_get_uri_normalizer_noop(normalized_scheme): + def normalizer(uri): + return uri + + return normalizer + + +@patch("airflow.assets._get_uri_normalizer", _mock_get_uri_normalizer_raising_error) +@patch("airflow.assets.warnings.warn") +def test_sanitize_uri_raises_warning(mock_warn): + _sanitize_uri("postgres://localhost:5432/database.schema.table") + msg = mock_warn.call_args.args[0] + assert "The Asset URI postgres://localhost:5432/database.schema.table is not AIP-60 compliant" in msg + assert "In Airflow 3, this will raise an exception." in msg + + +@patch("airflow.assets._get_uri_normalizer", _mock_get_uri_normalizer_raising_error) +@conf_vars({("core", "strict_asset_uri_validation"): "True"}) +def test_sanitize_uri_raises_exception(): + with pytest.raises(ValueError) as e_info: + _sanitize_uri("postgres://localhost:5432/database.schema.table") + assert isinstance(e_info.value, ValueError) + assert str(e_info.value) == "Incorrect URI format" + + +@patch("airflow.assets._get_uri_normalizer", lambda x: None) +def test_normalize_uri_no_normalizer_found(): + asset = Asset(uri="any_uri_without_normalizer_defined") + assert asset.normalized_uri is None + + +@patch("airflow.assets._get_uri_normalizer", _mock_get_uri_normalizer_raising_error) +def test_normalize_uri_invalid_uri(): + asset = Asset(uri="any_uri_not_aip60_compliant") + assert asset.normalized_uri is None + + +@patch("airflow.assets._get_uri_normalizer", _mock_get_uri_normalizer_noop) +@patch("airflow.assets._get_normalized_scheme", lambda x: "valid_scheme") +def test_normalize_uri_valid_uri(): + asset = Asset(uri="valid_aip60_uri") + assert asset.normalized_uri == "valid_aip60_uri" + + +@pytest.mark.skip_if_database_isolation_mode +@pytest.mark.db_test +@pytest.mark.usefixtures("clear_assets") +class Test_AssetAliasCondition: + @pytest.fixture + def asset_1(self, session): + """Example asset links to asset alias resolved_asset_alias_2.""" + asset_uri = "test_uri" + asset_1 = AssetModel(id=1, uri=asset_uri) + + session.add(asset_1) + session.commit() + + return asset_1 + + @pytest.fixture + def asset_alias_1(self, session): + """Example asset alias links to no assets.""" + alias_name = "test_name" + asset_alias_model = AssetAliasModel(name=alias_name) + + session.add(asset_alias_model) + session.commit() + + return asset_alias_model + + @pytest.fixture + def resolved_asset_alias_2(self, session, asset_1): + """Example asset alias links to asset asset_alias_1.""" + asset_name = "test_name_2" + asset_alias_2 = AssetAliasModel(name=asset_name) + asset_alias_2.datasets.append(asset_1) + + session.add(asset_alias_2) + session.commit() + + return asset_alias_2 + + def test_init(self, asset_alias_1, asset_1, resolved_asset_alias_2): + cond = _AssetAliasCondition(name=asset_alias_1.name) + assert cond.objects == [] + + cond = _AssetAliasCondition(name=resolved_asset_alias_2.name) + assert cond.objects == [Asset(uri=asset_1.uri)] + + def test_as_expression(self, asset_alias_1, resolved_asset_alias_2): + for assset_alias in (asset_alias_1, resolved_asset_alias_2): + cond = _AssetAliasCondition(assset_alias.name) + assert cond.as_expression() == {"alias": assset_alias.name} + + def 
test_evalute(self, asset_alias_1, resolved_asset_alias_2, asset_1): + cond = _AssetAliasCondition(asset_alias_1.name) + assert cond.evaluate({asset_1.uri: True}) is False + + cond = _AssetAliasCondition(resolved_asset_alias_2.name) + assert cond.evaluate({asset_1.uri: True}) is True diff --git a/tests/auth/managers/simple/test_simple_auth_manager.py b/tests/auth/managers/simple/test_simple_auth_manager.py index a11c79063d042..d4bd4e4fbfed2 100644 --- a/tests/auth/managers/simple/test_simple_auth_manager.py +++ b/tests/auth/managers/simple/test_simple_auth_manager.py @@ -140,7 +140,7 @@ def test_get_user_return_none_when_not_logged_in(self, mock_is_logged_in, auth_m "is_authorized_configuration", "is_authorized_connection", "is_authorized_dag", - "is_authorized_dataset", + "is_authorized_asset", "is_authorized_pool", "is_authorized_variable", ], @@ -206,7 +206,7 @@ def test_is_authorized_view_methods( [ "is_authorized_configuration", "is_authorized_connection", - "is_authorized_dataset", + "is_authorized_asset", "is_authorized_pool", "is_authorized_variable", ], @@ -258,7 +258,7 @@ def test_is_authorized_methods_user_role_required( @patch.object(SimpleAuthManager, "is_logged_in") @pytest.mark.parametrize( "api", - ["is_authorized_dag", "is_authorized_dataset", "is_authorized_pool"], + ["is_authorized_dag", "is_authorized_asset", "is_authorized_pool"], ) @pytest.mark.parametrize( "role, method, result", diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py index cd9652fb465d0..82efe20048b71 100644 --- a/tests/auth/managers/test_base_auth_manager.py +++ b/tests/auth/managers/test_base_auth_manager.py @@ -35,9 +35,9 @@ from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.resource_details import ( AccessView, + AssetDetails, ConfigurationDetails, DagAccessEntity, - DatasetDetails, ) @@ -73,8 +73,8 @@ def is_authorized_dag( ) -> bool: raise NotImplementedError() - def is_authorized_dataset( - self, *, method: ResourceMethod, details: DatasetDetails | None = None, user: BaseUser | None = None + def is_authorized_asset( + self, *, method: ResourceMethod, details: AssetDetails | None = None, user: BaseUser | None = None ) -> bool: raise NotImplementedError() diff --git a/tests/conftest.py b/tests/conftest.py index 7e3affd2a1b27..60d009416fe8e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -973,10 +973,10 @@ def __call__( def cleanup(self): from airflow.models import DagModel, DagRun, TaskInstance, XCom - from airflow.models.dataset import DatasetEvent from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskmap import TaskMap from airflow.utils.retries import run_with_db_retries + from tests.test_utils.compat import AssetEvent for attempt in run_with_db_retries(logger=self.log): with attempt: @@ -1004,7 +1004,7 @@ def cleanup(self): self.session.query(TaskMap).filter(TaskMap.dag_id.in_(dag_ids)).delete( synchronize_session=False, ) - self.session.query(DatasetEvent).filter(DatasetEvent.source_dag_id.in_(dag_ids)).delete( + self.session.query(AssetEvent).filter(AssetEvent.source_dag_id.in_(dag_ids)).delete( synchronize_session=False, ) self.session.commit() diff --git a/tests/dags/test_datasets.py b/tests/dags/test_assets.py similarity index 91% rename from tests/dags/test_datasets.py rename to tests/dags/test_assets.py index 4bdef9f6978cb..a4ecd6aad4a6a 100644 --- a/tests/dags/test_datasets.py +++ b/tests/dags/test_assets.py @@ -19,14 +19,14 @@ from datetime import 
datetime -from airflow.datasets import Dataset +from airflow.assets import Asset from airflow.exceptions import AirflowFailException, AirflowSkipException from airflow.models.dag import DAG from airflow.operators.bash import BashOperator from airflow.operators.python import PythonOperator -skip_task_dag_dataset = Dataset("s3://dag_with_skip_task/output_1.txt", extra={"hi": "bye"}) -fail_task_dag_dataset = Dataset("s3://dag_with_fail_task/output_1.txt", extra={"hi": "bye"}) +skip_task_dag_dataset = Asset("s3://dag_with_skip_task/output_1.txt", extra={"hi": "bye"}) +fail_task_dag_dataset = Asset("s3://dag_with_fail_task/output_1.txt", extra={"hi": "bye"}) def raise_skip_exc(): diff --git a/tests/dags/test_only_empty_tasks.py b/tests/dags/test_only_empty_tasks.py index 68f1dc5e897ae..2cea9c3c6b173 100644 --- a/tests/dags/test_only_empty_tasks.py +++ b/tests/dags/test_only_empty_tasks.py @@ -20,7 +20,7 @@ from datetime import datetime from typing import Sequence -from airflow.datasets import Dataset +from airflow.assets import Asset from airflow.models.dag import DAG from airflow.operators.empty import EmptyOperator @@ -56,4 +56,4 @@ def __init__(self, body, *args, **kwargs): EmptyOperator(task_id="test_task_on_success", on_success_callback=lambda *args, **kwargs: None) - EmptyOperator(task_id="test_task_outlets", outlets=[Dataset("hello")]) + EmptyOperator(task_id="test_task_outlets", outlets=[Asset("hello")]) diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py deleted file mode 100644 index 8221a5aea8aa3..0000000000000 --- a/tests/datasets/test_dataset.py +++ /dev/null @@ -1,588 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from __future__ import annotations - -import os -from collections import defaultdict -from typing import Callable -from unittest.mock import patch - -import pytest -from sqlalchemy.sql import select - -from airflow.datasets import ( - BaseDataset, - Dataset, - DatasetAlias, - DatasetAll, - DatasetAny, - _DatasetAliasCondition, - _get_normalized_scheme, - _sanitize_uri, -) -from airflow.models.dataset import DatasetAliasModel, DatasetDagRunQueue, DatasetModel -from airflow.models.serialized_dag import SerializedDagModel -from airflow.operators.empty import EmptyOperator -from airflow.serialization.serialized_objects import BaseSerialization, SerializedDAG -from tests.test_utils.config import conf_vars - - -@pytest.fixture -def clear_datasets(): - from tests.test_utils.db import clear_db_datasets - - clear_db_datasets() - yield - clear_db_datasets() - - -@pytest.mark.parametrize( - ["uri"], - [ - pytest.param("", id="empty"), - pytest.param("\n\t", id="whitespace"), - pytest.param("a" * 3001, id="too_long"), - pytest.param("airflow://xcom/dag/task", id="reserved_scheme"), - pytest.param("😊", id="non-ascii"), - ], -) -def test_invalid_uris(uri): - with pytest.raises(ValueError): - Dataset(uri=uri) - - -@pytest.mark.parametrize( - "uri, normalized", - [ - pytest.param("foobar", "foobar", id="scheme-less"), - pytest.param("foo:bar", "foo:bar", id="scheme-less-colon"), - pytest.param("foo/bar", "foo/bar", id="scheme-less-slash"), - pytest.param("s3://bucket/key/path", "s3://bucket/key/path", id="normal"), - pytest.param("file:///123/456/", "file:///123/456", id="trailing-slash"), - ], -) -def test_uri_with_scheme(uri: str, normalized: str) -> None: - dataset = Dataset(uri) - EmptyOperator(task_id="task1", outlets=[dataset]) - assert dataset.uri == normalized - assert os.fspath(dataset) == normalized - - -def test_uri_with_auth() -> None: - with pytest.warns(UserWarning) as record: - dataset = Dataset("ftp://user@localhost/foo.txt") - assert len(record) == 1 - assert str(record[0].message) == ( - "A dataset URI should not contain auth info (e.g. username or " - "password). It has been automatically dropped." 
- ) - EmptyOperator(task_id="task1", outlets=[dataset]) - assert dataset.uri == "ftp://localhost/foo.txt" - assert os.fspath(dataset) == "ftp://localhost/foo.txt" - - -def test_uri_without_scheme(): - dataset = Dataset(uri="example_dataset") - EmptyOperator(task_id="task1", outlets=[dataset]) - - -def test_fspath(): - uri = "s3://example/dataset" - dataset = Dataset(uri=uri) - assert os.fspath(dataset) == uri - - -def test_equal_when_same_uri(): - uri = "s3://example/dataset" - dataset1 = Dataset(uri=uri) - dataset2 = Dataset(uri=uri) - assert dataset1 == dataset2 - - -def test_not_equal_when_different_uri(): - dataset1 = Dataset(uri="s3://example/dataset") - dataset2 = Dataset(uri="s3://other/dataset") - assert dataset1 != dataset2 - - -def test_dataset_logic_operations(): - result_or = dataset1 | dataset2 - assert isinstance(result_or, DatasetAny) - result_and = dataset1 & dataset2 - assert isinstance(result_and, DatasetAll) - - -def test_dataset_iter_datasets(): - assert list(dataset1.iter_datasets()) == [("s3://bucket1/data1", dataset1)] - - -@pytest.mark.db_test -def test_dataset_iter_dataset_aliases(): - base_dataset = DatasetAll( - DatasetAlias("example-alias-1"), - Dataset("1"), - DatasetAny( - Dataset("2"), - DatasetAlias("example-alias-2"), - Dataset("3"), - DatasetAll(DatasetAlias("example-alias-3"), Dataset("4"), DatasetAlias("example-alias-4")), - ), - DatasetAll(DatasetAlias("example-alias-5"), Dataset("5")), - ) - assert list(base_dataset.iter_dataset_aliases()) == [ - (f"example-alias-{i}", DatasetAlias(f"example-alias-{i}")) for i in range(1, 6) - ] - - -def test_dataset_evaluate(): - assert dataset1.evaluate({"s3://bucket1/data1": True}) is True - assert dataset1.evaluate({"s3://bucket1/data1": False}) is False - - -def test_dataset_any_operations(): - result_or = (dataset1 | dataset2) | dataset3 - assert isinstance(result_or, DatasetAny) - assert len(result_or.objects) == 3 - result_and = (dataset1 | dataset2) & dataset3 - assert isinstance(result_and, DatasetAll) - - -def test_dataset_all_operations(): - result_or = (dataset1 & dataset2) | dataset3 - assert isinstance(result_or, DatasetAny) - result_and = (dataset1 & dataset2) & dataset3 - assert isinstance(result_and, DatasetAll) - - -def test_datasetbooleancondition_evaluate_iter(): - """ - Tests _DatasetBooleanCondition's evaluate and iter_datasets methods through DatasetAny and DatasetAll. - Ensures DatasetAny evaluate returns True with any true condition, DatasetAll evaluate returns False if - any condition is false, and both classes correctly iterate over datasets without duplication. 
- """ - any_condition = DatasetAny(dataset1, dataset2) - all_condition = DatasetAll(dataset1, dataset2) - assert any_condition.evaluate({"s3://bucket1/data1": False, "s3://bucket2/data2": True}) is True - assert all_condition.evaluate({"s3://bucket1/data1": True, "s3://bucket2/data2": False}) is False - - # Testing iter_datasets indirectly through the subclasses - datasets_any = dict(any_condition.iter_datasets()) - datasets_all = dict(all_condition.iter_datasets()) - assert datasets_any == {"s3://bucket1/data1": dataset1, "s3://bucket2/data2": dataset2} - assert datasets_all == {"s3://bucket1/data1": dataset1, "s3://bucket2/data2": dataset2} - - -@pytest.mark.parametrize( - "inputs, scenario, expected", - [ - # Scenarios for DatasetAny - ((True, True, True), "any", True), - ((True, True, False), "any", True), - ((True, False, True), "any", True), - ((True, False, False), "any", True), - ((False, False, True), "any", True), - ((False, True, False), "any", True), - ((False, True, True), "any", True), - ((False, False, False), "any", False), - # Scenarios for DatasetAll - ((True, True, True), "all", True), - ((True, True, False), "all", False), - ((True, False, True), "all", False), - ((True, False, False), "all", False), - ((False, False, True), "all", False), - ((False, True, False), "all", False), - ((False, True, True), "all", False), - ((False, False, False), "all", False), - ], -) -def test_dataset_logical_conditions_evaluation_and_serialization(inputs, scenario, expected): - class_ = DatasetAny if scenario == "any" else DatasetAll - datasets = [Dataset(uri=f"s3://abc/{i}") for i in range(123, 126)] - condition = class_(*datasets) - - statuses = {dataset.uri: status for dataset, status in zip(datasets, inputs)} - assert ( - condition.evaluate(statuses) == expected - ), f"Condition evaluation failed for inputs {inputs} and scenario '{scenario}'" - - # Serialize and deserialize the condition to test persistence - serialized = BaseSerialization.serialize(condition) - deserialized = BaseSerialization.deserialize(serialized) - assert deserialized.evaluate(statuses) == expected, "Serialization round-trip failed" - - -@pytest.mark.parametrize( - "status_values, expected_evaluation", - [ - ((False, True, True), False), # DatasetAll requires all conditions to be True, but d1 is False - ((True, True, True), True), # All conditions are True - ((True, False, True), True), # d1 is True, and DatasetAny condition (d2 or d3 being True) is met - ((True, False, False), False), # d1 is True, but neither d2 nor d3 meet the DatasetAny condition - ], -) -def test_nested_dataset_conditions_with_serialization(status_values, expected_evaluation): - # Define datasets - d1 = Dataset(uri="s3://abc/123") - d2 = Dataset(uri="s3://abc/124") - d3 = Dataset(uri="s3://abc/125") - - # Create a nested condition: DatasetAll with d1 and DatasetAny with d2 and d3 - nested_condition = DatasetAll(d1, DatasetAny(d2, d3)) - - statuses = { - d1.uri: status_values[0], - d2.uri: status_values[1], - d3.uri: status_values[2], - } - - assert nested_condition.evaluate(statuses) == expected_evaluation, "Initial evaluation mismatch" - - serialized_condition = BaseSerialization.serialize(nested_condition) - deserialized_condition = BaseSerialization.deserialize(serialized_condition) - - assert ( - deserialized_condition.evaluate(statuses) == expected_evaluation - ), "Post-serialization evaluation mismatch" - - -@pytest.fixture -def create_test_datasets(session): - """Fixture to create test datasets and corresponding models.""" - datasets 
= [Dataset(uri=f"hello{i}") for i in range(1, 3)] - for dataset in datasets: - session.add(DatasetModel(uri=dataset.uri)) - session.commit() - return datasets - - -@pytest.mark.db_test -@pytest.mark.usefixtures("clear_datasets") -def test_dataset_trigger_setup_and_serialization(session, dag_maker, create_test_datasets): - datasets = create_test_datasets - - # Create DAG with dataset triggers - with dag_maker(schedule=DatasetAny(*datasets)) as dag: - EmptyOperator(task_id="hello") - - # Verify datasets are set up correctly - assert isinstance( - dag.timetable.dataset_condition, DatasetAny - ), "DAG datasets should be an instance of DatasetAny" - - # Round-trip the DAG through serialization - deserialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) - - # Verify serialization and deserialization integrity - assert isinstance( - deserialized_dag.timetable.dataset_condition, DatasetAny - ), "Deserialized datasets should maintain type DatasetAny" - assert ( - deserialized_dag.timetable.dataset_condition.objects == dag.timetable.dataset_condition.objects - ), "Deserialized datasets should match original" - - -@pytest.mark.db_test -@pytest.mark.usefixtures("clear_datasets") -def test_dataset_dag_run_queue_processing(session, clear_datasets, dag_maker, create_test_datasets): - datasets = create_test_datasets - dataset_models = session.query(DatasetModel).all() - - with dag_maker(schedule=DatasetAny(*datasets)) as dag: - EmptyOperator(task_id="hello") - - # Add DatasetDagRunQueue entries to simulate dataset event processing - for dm in dataset_models: - session.add(DatasetDagRunQueue(dataset_id=dm.id, target_dag_id=dag.dag_id)) - session.commit() - - # Fetch and evaluate dataset triggers for all DAGs affected by dataset events - records = session.scalars(select(DatasetDagRunQueue)).all() - dag_statuses = defaultdict(lambda: defaultdict(bool)) - for record in records: - dag_statuses[record.target_dag_id][record.dataset.uri] = True - - serialized_dags = session.execute( - select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys())) - ).fetchall() - - for (serialized_dag,) in serialized_dags: - dag = SerializedDAG.deserialize(serialized_dag.data) - for dataset_uri, status in dag_statuses[dag.dag_id].items(): - cond = dag.timetable.dataset_condition - assert cond.evaluate({dataset_uri: status}), "DAG trigger evaluation failed" - - -@pytest.mark.db_test -@pytest.mark.usefixtures("clear_datasets") -def test_dag_with_complex_dataset_condition(session, dag_maker): - # Create Dataset instances - d1 = Dataset(uri="hello1") - d2 = Dataset(uri="hello2") - - # Create and add DatasetModel instances to the session - dm1 = DatasetModel(uri=d1.uri) - dm2 = DatasetModel(uri=d2.uri) - session.add_all([dm1, dm2]) - session.commit() - - # Setup a DAG with complex dataset triggers (DatasetAny with DatasetAll) - with dag_maker(schedule=DatasetAny(d1, DatasetAll(d2, d1))) as dag: - EmptyOperator(task_id="hello") - - assert isinstance( - dag.timetable.dataset_condition, DatasetAny - ), "DAG's dataset trigger should be an instance of DatasetAny" - assert any( - isinstance(trigger, DatasetAll) for trigger in dag.timetable.dataset_condition.objects - ), "DAG's dataset trigger should include DatasetAll" - - serialized_triggers = SerializedDAG.serialize(dag.timetable.dataset_condition) - - deserialized_triggers = SerializedDAG.deserialize(serialized_triggers) - - assert isinstance( - deserialized_triggers, DatasetAny - ), "Deserialized triggers should be an instance of 
DatasetAny" - assert any( - isinstance(trigger, DatasetAll) for trigger in deserialized_triggers.objects - ), "Deserialized triggers should include DatasetAll" - - serialized_timetable_dict = SerializedDAG.to_dict(dag)["dag"]["timetable"]["__var"] - assert ( - "dataset_condition" in serialized_timetable_dict - ), "Serialized timetable should contain 'dataset_condition'" - assert isinstance( - serialized_timetable_dict["dataset_condition"], dict - ), "Serialized 'dataset_condition' should be a dict" - - -def datasets_equal(d1: BaseDataset, d2: BaseDataset) -> bool: - if type(d1) is not type(d2): - return False - - if isinstance(d1, Dataset) and isinstance(d2, Dataset): - return d1.uri == d2.uri - - elif isinstance(d1, (DatasetAny, DatasetAll)) and isinstance(d2, (DatasetAny, DatasetAll)): - if len(d1.objects) != len(d2.objects): - return False - - # Compare each pair of objects - for obj1, obj2 in zip(d1.objects, d2.objects): - # If obj1 or obj2 is a Dataset, DatasetAny, or DatasetAll instance, - # recursively call datasets_equal - if not datasets_equal(obj1, obj2): - return False - return True - - return False - - -dataset1 = Dataset(uri="s3://bucket1/data1") -dataset2 = Dataset(uri="s3://bucket2/data2") -dataset3 = Dataset(uri="s3://bucket3/data3") -dataset4 = Dataset(uri="s3://bucket4/data4") -dataset5 = Dataset(uri="s3://bucket5/data5") - -test_cases = [ - (lambda: dataset1, dataset1), - (lambda: dataset1 & dataset2, DatasetAll(dataset1, dataset2)), - (lambda: dataset1 | dataset2, DatasetAny(dataset1, dataset2)), - (lambda: dataset1 | (dataset2 & dataset3), DatasetAny(dataset1, DatasetAll(dataset2, dataset3))), - (lambda: dataset1 | dataset2 & dataset3, DatasetAny(dataset1, DatasetAll(dataset2, dataset3))), - ( - lambda: ((dataset1 & dataset2) | dataset3) & (dataset4 | dataset5), - DatasetAll(DatasetAny(DatasetAll(dataset1, dataset2), dataset3), DatasetAny(dataset4, dataset5)), - ), - (lambda: dataset1 & dataset2 | dataset3, DatasetAny(DatasetAll(dataset1, dataset2), dataset3)), - ( - lambda: (dataset1 | dataset2) & (dataset3 | dataset4), - DatasetAll(DatasetAny(dataset1, dataset2), DatasetAny(dataset3, dataset4)), - ), - ( - lambda: (dataset1 & dataset2) | (dataset3 & (dataset4 | dataset5)), - DatasetAny(DatasetAll(dataset1, dataset2), DatasetAll(dataset3, DatasetAny(dataset4, dataset5))), - ), - ( - lambda: (dataset1 & dataset2) & (dataset3 & dataset4), - DatasetAll(dataset1, dataset2, DatasetAll(dataset3, dataset4)), - ), - (lambda: dataset1 | dataset2 | dataset3, DatasetAny(dataset1, dataset2, dataset3)), - (lambda: dataset1 & dataset2 & dataset3, DatasetAll(dataset1, dataset2, dataset3)), - ( - lambda: ((dataset1 & dataset2) | dataset3) & (dataset4 | dataset5), - DatasetAll(DatasetAny(DatasetAll(dataset1, dataset2), dataset3), DatasetAny(dataset4, dataset5)), - ), -] - - -@pytest.mark.parametrize("expression, expected", test_cases) -def test_evaluate_datasets_expression(expression, expected): - expr = expression() - assert datasets_equal(expr, expected) - - -@pytest.mark.parametrize( - "expression, error", - [ - pytest.param( - lambda: dataset1 & 1, # type: ignore[operator] - "unsupported operand type(s) for &: 'Dataset' and 'int'", - id="&", - ), - pytest.param( - lambda: dataset1 | 1, # type: ignore[operator] - "unsupported operand type(s) for |: 'Dataset' and 'int'", - id="|", - ), - pytest.param( - lambda: DatasetAll(1, dataset1), # type: ignore[arg-type] - "expect dataset expressions in condition", - id="DatasetAll", - ), - pytest.param( - lambda: DatasetAny(1, dataset1), # 
type: ignore[arg-type] - "expect dataset expressions in condition", - id="DatasetAny", - ), - ], -) -def test_datasets_expression_error(expression: Callable[[], None], error: str) -> None: - with pytest.raises(TypeError) as info: - expression() - assert str(info.value) == error - - -def test_get_normalized_scheme(): - assert _get_normalized_scheme("http://example.com") == "http" - assert _get_normalized_scheme("HTTPS://example.com") == "https" - assert _get_normalized_scheme("ftp://example.com") == "ftp" - assert _get_normalized_scheme("file://") == "file" - - assert _get_normalized_scheme("example.com") == "" - assert _get_normalized_scheme("") == "" - assert _get_normalized_scheme(" ") == "" - - -def _mock_get_uri_normalizer_raising_error(normalized_scheme): - def normalizer(uri): - raise ValueError("Incorrect URI format") - - return normalizer - - -def _mock_get_uri_normalizer_noop(normalized_scheme): - def normalizer(uri): - return uri - - return normalizer - - -@patch("airflow.datasets._get_uri_normalizer", _mock_get_uri_normalizer_raising_error) -@patch("airflow.datasets.warnings.warn") -def test_sanitize_uri_raises_warning(mock_warn): - _sanitize_uri("postgres://localhost:5432/database.schema.table") - msg = mock_warn.call_args.args[0] - assert "The dataset URI postgres://localhost:5432/database.schema.table is not AIP-60 compliant" in msg - assert "In Airflow 3, this will raise an exception." in msg - - -@patch("airflow.datasets._get_uri_normalizer", _mock_get_uri_normalizer_raising_error) -@conf_vars({("core", "strict_dataset_uri_validation"): "True"}) -def test_sanitize_uri_raises_exception(): - with pytest.raises(ValueError) as e_info: - _sanitize_uri("postgres://localhost:5432/database.schema.table") - assert isinstance(e_info.value, ValueError) - assert str(e_info.value) == "Incorrect URI format" - - -@patch("airflow.datasets._get_uri_normalizer", lambda x: None) -def test_normalize_uri_no_normalizer_found(): - dataset = Dataset(uri="any_uri_without_normalizer_defined") - assert dataset.normalized_uri is None - - -@patch("airflow.datasets._get_uri_normalizer", _mock_get_uri_normalizer_raising_error) -def test_normalize_uri_invalid_uri(): - dataset = Dataset(uri="any_uri_not_aip60_compliant") - assert dataset.normalized_uri is None - - -@patch("airflow.datasets._get_uri_normalizer", _mock_get_uri_normalizer_noop) -@patch("airflow.datasets._get_normalized_scheme", lambda x: "valid_scheme") -def test_normalize_uri_valid_uri(): - dataset = Dataset(uri="valid_aip60_uri") - assert dataset.normalized_uri == "valid_aip60_uri" - - -@pytest.mark.skip_if_database_isolation_mode -@pytest.mark.db_test -@pytest.mark.usefixtures("clear_datasets") -class Test_DatasetAliasCondition: - @pytest.fixture - def ds_1(self, session): - """Example dataset links to dataset alias resolved_dsa_2.""" - ds_uri = "test_uri" - ds_1 = DatasetModel(id=1, uri=ds_uri) - - session.add(ds_1) - session.commit() - - return ds_1 - - @pytest.fixture - def dsa_1(self, session): - """Example dataset alias links to no datasets.""" - dsa_name = "test_name" - dsa_1 = DatasetAliasModel(name=dsa_name) - - session.add(dsa_1) - session.commit() - - return dsa_1 - - @pytest.fixture - def resolved_dsa_2(self, session, ds_1): - """Example dataset alias links to no dataset dsa_1.""" - dsa_name = "test_name_2" - dsa_2 = DatasetAliasModel(name=dsa_name) - dsa_2.datasets.append(ds_1) - - session.add(dsa_2) - session.commit() - - return dsa_2 - - def test_init(self, dsa_1, ds_1, resolved_dsa_2): - cond = 
_DatasetAliasCondition(name=dsa_1.name) - assert cond.objects == [] - - cond = _DatasetAliasCondition(name=resolved_dsa_2.name) - assert cond.objects == [Dataset(uri=ds_1.uri)] - - def test_as_expression(self, dsa_1, resolved_dsa_2): - for dsa in (dsa_1, resolved_dsa_2): - cond = _DatasetAliasCondition(dsa.name) - assert cond.as_expression() == {"alias": dsa.name} - - def test_evalute(self, dsa_1, resolved_dsa_2, ds_1): - cond = _DatasetAliasCondition(dsa_1.name) - assert cond.evaluate({ds_1.uri: True}) is False - - cond = _DatasetAliasCondition(resolved_dsa_2.name) - assert cond.evaluate({ds_1.uri: True}) is True diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index 96473518cc982..adbf96a0f41ba 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -983,8 +983,8 @@ def other(x): ... @pytest.mark.skip_if_database_isolation_mode # Test is broken in db isolation mode -def test_task_decorator_dataset(dag_maker, session): - from airflow.datasets import Dataset +def test_task_decorator_asset(dag_maker, session): + from airflow.assets import Asset result = None uri = "s3://bucket/name" @@ -992,11 +992,11 @@ def test_task_decorator_dataset(dag_maker, session): with dag_maker(session=session) as dag: @dag.task() - def up1() -> Dataset: - return Dataset(uri) + def up1() -> Asset: + return Asset(uri) @dag.task() - def up2(src: Dataset) -> str: + def up2(src: Asset) -> str: return src.uri @dag.task() diff --git a/tests/io/test_path.py b/tests/io/test_path.py index 195c2423b1822..0e504b586b3d2 100644 --- a/tests/io/test_path.py +++ b/tests/io/test_path.py @@ -29,7 +29,7 @@ from fsspec.implementations.memory import MemoryFileSystem from fsspec.registry import _registry as _fsspec_registry, register_implementation -from airflow.datasets import Dataset +from airflow.assets import Asset from airflow.io import _register_filesystems, get_fs from airflow.io.path import ObjectStoragePath from airflow.io.store import _STORE_CACHE, ObjectStore, attach @@ -280,12 +280,12 @@ def test_move_local(self, hook_lineage_collector): _to.unlink() - collected_datasets = hook_lineage_collector.collected_datasets + collected_assets = hook_lineage_collector.collected_assets - assert len(collected_datasets.inputs) == 1 - assert len(collected_datasets.outputs) == 1 - assert collected_datasets.inputs[0].dataset == Dataset(uri=_from_path) - assert collected_datasets.outputs[0].dataset == Dataset(uri=_to_path) + assert len(collected_assets.inputs) == 1 + assert len(collected_assets.outputs) == 1 + assert collected_assets.inputs[0].asset == Asset(uri=_from_path) + assert collected_assets.outputs[0].asset == Asset(uri=_to_path) def test_move_remote(self, hook_lineage_collector): attach("fakefs", fs=FakeRemoteFileSystem()) @@ -303,12 +303,12 @@ def test_move_remote(self, hook_lineage_collector): _to.unlink() - collected_datasets = hook_lineage_collector.collected_datasets + collected_assets = hook_lineage_collector.collected_assets - assert len(collected_datasets.inputs) == 1 - assert len(collected_datasets.outputs) == 1 - assert collected_datasets.inputs[0].dataset == Dataset(uri=str(_from)) - assert collected_datasets.outputs[0].dataset == Dataset(uri=str(_to)) + assert len(collected_assets.inputs) == 1 + assert len(collected_assets.outputs) == 1 + assert collected_assets.inputs[0].asset == Asset(uri=str(_from)) + assert collected_assets.outputs[0].asset == Asset(uri=str(_to)) def test_copy_remote_remote(self, hook_lineage_collector): attach("ffs", 
fs=FakeRemoteFileSystem(skip_instance_cache=True)) @@ -338,11 +338,11 @@ def test_copy_remote_remote(self, hook_lineage_collector): _from.rmdir(recursive=True) _to.rmdir(recursive=True) - assert len(hook_lineage_collector.collected_datasets.inputs) == 1 - assert hook_lineage_collector.collected_datasets.inputs[0].dataset == Dataset(uri=str(_from_file)) + assert len(hook_lineage_collector.collected_assets.inputs) == 1 + assert hook_lineage_collector.collected_assets.inputs[0].asset == Asset(uri=str(_from_file)) # Empty file - shutil.copyfileobj does nothing - assert len(hook_lineage_collector.collected_datasets.outputs) == 0 + assert len(hook_lineage_collector.collected_assets.outputs) == 0 def test_serde_objectstoragepath(self): path = "file:///bucket/key/part1/part2" @@ -402,12 +402,12 @@ def test_backwards_compat(self): # Reset the cache to avoid side effects _register_filesystems.cache_clear() - def test_dataset(self): + def test_asset(self): attach("s3", fs=FakeRemoteFileSystem()) p = "s3" f = "/tmp/foo" - i = Dataset(uri=f"{p}://{f}", extra={"foo": "bar"}) + i = Asset(uri=f"{p}://{f}", extra={"foo": "bar"}) o = ObjectStoragePath(i) assert o.protocol == p assert o.path == f diff --git a/tests/io/test_wrapper.py b/tests/io/test_wrapper.py index e00c5ab22bf64..641eda84d1a4f 100644 --- a/tests/io/test_wrapper.py +++ b/tests/io/test_wrapper.py @@ -19,13 +19,13 @@ import uuid from unittest.mock import patch -from airflow.datasets import Dataset +from airflow.assets import Asset from airflow.io.path import ObjectStoragePath @patch("airflow.providers_manager.ProvidersManager") def test_wrapper_catches_reads_writes(providers_manager, hook_lineage_collector): - providers_manager.return_value._dataset_factories = lambda x: Dataset(uri=x) + providers_manager.return_value._asset_factories = lambda x: Asset(uri=x) uri = f"file:///tmp/{str(uuid.uuid4())}" path = ObjectStoragePath(uri) file = path.open("w") @@ -33,7 +33,7 @@ def test_wrapper_catches_reads_writes(providers_manager, hook_lineage_collector) file.close() assert len(hook_lineage_collector._outputs) == 1 - assert next(iter(hook_lineage_collector._outputs.values()))[0] == Dataset(uri=uri) + assert next(iter(hook_lineage_collector._outputs.values()))[0] == Asset(uri=uri) file = path.open("r") file.read() @@ -42,23 +42,23 @@ def test_wrapper_catches_reads_writes(providers_manager, hook_lineage_collector) path.unlink(missing_ok=True) assert len(hook_lineage_collector._inputs) == 1 - assert next(iter(hook_lineage_collector._inputs.values()))[0] == Dataset(uri=uri) + assert next(iter(hook_lineage_collector._inputs.values()))[0] == Asset(uri=uri) @patch("airflow.providers_manager.ProvidersManager") def test_wrapper_works_with_contextmanager(providers_manager, hook_lineage_collector): - providers_manager.return_value._dataset_factories = lambda x: Dataset(uri=x) + providers_manager.return_value._asset_factories = lambda x: Asset(uri=x) uri = f"file:///tmp/{str(uuid.uuid4())}" path = ObjectStoragePath(uri) with path.open("w") as file: file.write("asdf") assert len(hook_lineage_collector._outputs) == 1 - assert next(iter(hook_lineage_collector._outputs.values()))[0] == Dataset(uri=uri) + assert next(iter(hook_lineage_collector._outputs.values()))[0] == Asset(uri=uri) with path.open("r") as file: file.read() path.unlink(missing_ok=True) assert len(hook_lineage_collector._inputs) == 1 - assert next(iter(hook_lineage_collector._inputs.values()))[0] == Dataset(uri=uri) + assert next(iter(hook_lineage_collector._inputs.values()))[0] == Asset(uri=uri) 
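For orientation between the two file diffs, here is a minimal sketch of how the Dataset -> Asset rename exercised in the test hunks above maps onto ordinary DAG code. It assumes only names that actually appear in these patches (airflow.assets.Asset, the unchanged outlets parameter, asset-based schedule arguments, and the | composition that the renamed tests_asset.py asserts builds AssetAny); the dag_ids and URIs are illustrative only.

from datetime import datetime

from airflow.assets import Asset  # previously: from airflow.datasets import Dataset
from airflow.models.dag import DAG
from airflow.operators.empty import EmptyOperator

# URIs and the extra payload reuse test data shown above; they carry no special meaning here.
raw_data = Asset("s3://bucket1/data1", extra={"hi": "bye"})
clean_data = Asset("s3://bucket2/data2")

# Producer side: the asset is still attached through `outlets`, exactly as in the renamed tests.
with DAG(dag_id="asset_producer", start_date=datetime(2024, 1, 1), schedule=None):
    EmptyOperator(task_id="produce", outlets=[raw_data])

# Consumer side: a plain list of assets schedules the DAG the way a list of Datasets used to,
# and `raw_data | clean_data` builds an AssetAny condition (`&` builds AssetAll), matching the
# operator tests in tests/assets/tests_asset.py.
with DAG(dag_id="asset_consumer", start_date=datetime(2024, 1, 1), schedule=(raw_data | clean_data)):
    EmptyOperator(task_id="consume")

In the hunks above only class and module names change; the assertions about lineage collection and asset-driven scheduling are carried over unchanged.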
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 78a911153dab2..32662d7d873db 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -36,12 +36,12 @@ import airflow.example_dags from airflow import settings +from airflow.assets import Asset +from airflow.assets.manager import AssetManager from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest from airflow.callbacks.database_callback_sink import DatabaseCallbackSink from airflow.callbacks.pipe_callback_sink import PipeCallbackSink from airflow.dag_processing.manager import DagFileProcessorAgent -from airflow.datasets import Dataset -from airflow.datasets.manager import DatasetManager from airflow.exceptions import AirflowException, RemovedInAirflow3Warning from airflow.executors.base_executor import BaseExecutor from airflow.executors.executor_constants import MOCK_EXECUTOR @@ -50,10 +50,10 @@ from airflow.jobs.job import Job, run_job from airflow.jobs.local_task_job_runner import LocalTaskJobRunner from airflow.jobs.scheduler_job_runner import SchedulerJobRunner +from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel from airflow.models.dag import DAG, DagModel from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun -from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel from airflow.models.db_callback_request import DbCallbackRequest from airflow.models.pool import Pool from airflow.models.serialized_dag import SerializedDagModel @@ -74,8 +74,8 @@ from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS from tests.test_utils.config import conf_vars, env_vars from tests.test_utils.db import ( + clear_db_assets, clear_db_dags, - clear_db_datasets, clear_db_import_errors, clear_db_jobs, clear_db_pools, @@ -141,7 +141,7 @@ def clean_db(): clear_db_sla_miss() clear_db_import_errors() clear_db_jobs() - clear_db_datasets() + clear_db_assets() # DO NOT try to run clear_db_serialized_dags() here - this will break the tests # The tests expect DAGs to be fully loaded here via setUpClass method below @@ -4094,7 +4094,7 @@ def test_create_dag_runs(self, dag_maker): assert dag.get_last_dagrun().creating_job_id == scheduler_job.id @pytest.mark.need_serialized_dag - def test_create_dag_runs_datasets(self, session, dag_maker): + def test_create_dag_runs_assets(self, session, dag_maker): """ Test various invariants of _create_dag_runs. 
@@ -4103,21 +4103,21 @@ def test_create_dag_runs_datasets(self, session, dag_maker): - That dag_model has next_dagrun """ - dataset1 = Dataset(uri="ds1") - dataset2 = Dataset(uri="ds2") + asset1 = Asset(uri="ds1") + asset2 = Asset(uri="ds2") - with dag_maker(dag_id="datasets-1", start_date=timezone.utcnow(), session=session): - BashOperator(task_id="task", bash_command="echo 1", outlets=[dataset1]) + with dag_maker(dag_id="assets-1", start_date=timezone.utcnow(), session=session): + BashOperator(task_id="task", bash_command="echo 1", outlets=[asset1]) dr = dag_maker.create_dagrun( run_id="run1", execution_date=(DEFAULT_DATE + timedelta(days=100)), data_interval=(DEFAULT_DATE + timedelta(days=10), DEFAULT_DATE + timedelta(days=11)), ) - ds1_id = session.query(DatasetModel.id).filter_by(uri=dataset1.uri).scalar() + asset1_id = session.query(AssetModel.id).filter_by(uri=asset1.uri).scalar() - event1 = DatasetEvent( - dataset_id=ds1_id, + event1 = AssetEvent( + dataset_id=asset1_id, source_task_id="task", source_dag_id=dr.dag_id, source_run_id=dr.run_id, @@ -4132,8 +4132,8 @@ def test_create_dag_runs_datasets(self, session, dag_maker): data_interval=(DEFAULT_DATE + timedelta(days=5), DEFAULT_DATE + timedelta(days=6)), ) - event2 = DatasetEvent( - dataset_id=ds1_id, + event2 = AssetEvent( + dataset_id=asset1_id, source_task_id="task", source_dag_id=dr.dag_id, source_run_id=dr.run_id, @@ -4141,18 +4141,18 @@ def test_create_dag_runs_datasets(self, session, dag_maker): ) session.add(event2) - with dag_maker(dag_id="datasets-consumer-multiple", schedule=[dataset1, dataset2]): + with dag_maker(dag_id="assets-consumer-multiple", schedule=[asset1, asset2]): pass dag2 = dag_maker.dag - with dag_maker(dag_id="datasets-consumer-single", schedule=[dataset1]): + with dag_maker(dag_id="assets-consumer-single", schedule=[asset1]): pass dag3 = dag_maker.dag session = dag_maker.session session.add_all( [ - DatasetDagRunQueue(dataset_id=ds1_id, target_dag_id=dag2.dag_id), - DatasetDagRunQueue(dataset_id=ds1_id, target_dag_id=dag3.dag_id), + AssetDagRunQueue(dataset_id=asset1_id, target_dag_id=dag2.dag_id), + AssetDagRunQueue(dataset_id=asset1_id, target_dag_id=dag3.dag_id), ] ) session.flush() @@ -4169,24 +4169,24 @@ def dict_from_obj(obj): """Get dict of column attrs from SqlAlchemy object.""" return {k.key: obj.__dict__.get(k) for k in obj.__mapper__.column_attrs} - # dag3 should be triggered since it only depends on dataset1, and it's been queued + # dag3 should be triggered since it only depends on asset1, and it's been queued created_run = session.query(DagRun).filter(DagRun.dag_id == dag3.dag_id).one() assert created_run.state == State.QUEUED assert created_run.start_date is None - # we don't have __eq__ defined on DatasetEvent because... given the fact that in the future - # we may register events from other systems, dataset_id + timestamp might not be enough PK + # we don't have __eq__ defined on AssetEvent because... 
given the fact that in the future + we may register events from other systems, asset_id + timestamp might not be enough PK assert list(map(dict_from_obj, created_run.consumed_dataset_events)) == list( map(dict_from_obj, [event1, event2]) ) assert created_run.data_interval_start == DEFAULT_DATE + timedelta(days=5) assert created_run.data_interval_end == DEFAULT_DATE + timedelta(days=11) - # dag2 DDRQ record should still be there since the dag run was *not* triggered - assert session.query(DatasetDagRunQueue).filter_by(target_dag_id=dag2.dag_id).one() is not None - # dag2 should not be triggered since it depends on both dataset 1 and 2 + # dag2 ADRQ record should still be there since the dag run was *not* triggered + assert session.query(AssetDagRunQueue).filter_by(target_dag_id=dag2.dag_id).one() is not None + # dag2 should not be triggered since it depends on both asset 1 and 2 assert session.query(DagRun).filter(DagRun.dag_id == dag2.dag_id).one_or_none() is None - # dag3 DDRQ record should be deleted since the dag run was triggered - assert session.query(DatasetDagRunQueue).filter_by(target_dag_id=dag3.dag_id).one_or_none() is None + # dag3 ADRQ record should be deleted since the dag run was triggered + assert session.query(AssetDagRunQueue).filter_by(target_dag_id=dag3.dag_id).one_or_none() is None assert dag3.get_last_dagrun().creating_job_id == scheduler_job.id @@ -4199,47 +4199,47 @@ def dict_from_obj(obj): ], ) def test_no_create_dag_runs_when_dag_disabled(self, session, dag_maker, disable, enable): - ds = Dataset("ds") + ds = Asset("ds") with dag_maker(dag_id="consumer", schedule=[ds], session=session): pass with dag_maker(dag_id="producer", schedule="@daily", session=session): BashOperator(task_id="task", bash_command="echo 1", outlets=ds) - dsm = DatasetManager() + asset_manager = AssetManager() - ds_id = session.scalars(select(DatasetModel.id).filter_by(uri=ds.uri)).one() + asset_id = session.scalars(select(AssetModel.id).filter_by(uri=ds.uri)).one() - dse_q = select(DatasetEvent).where(DatasetEvent.dataset_id == ds_id).order_by(DatasetEvent.timestamp) - ddrq_q = select(DatasetDagRunQueue).where( - DatasetDagRunQueue.dataset_id == ds_id, DatasetDagRunQueue.target_dag_id == "consumer" + ase_q = select(AssetEvent).where(AssetEvent.dataset_id == asset_id).order_by(AssetEvent.timestamp) + adrq_q = select(AssetDagRunQueue).where( + AssetDagRunQueue.dataset_id == asset_id, AssetDagRunQueue.target_dag_id == "consumer" ) # Simulate the consumer DAG being disabled. session.execute(update(DagModel).where(DagModel.dag_id == "consumer").values(**disable)) - # A DDRQ is not scheduled although an event is emitted. + # An ADRQ is not scheduled although an event is emitted. dr1: DagRun = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) - dsm.register_dataset_change( + asset_manager.register_asset_change( task_instance=dr1.get_task_instance("task", session=session), - dataset=ds, + asset=ds, session=session, ) session.flush() - assert session.scalars(dse_q).one().source_run_id == dr1.run_id - assert session.scalars(ddrq_q).one_or_none() is None + assert session.scalars(ase_q).one().source_run_id == dr1.run_id + assert session.scalars(adrq_q).one_or_none() is None # Simulate the consumer DAG being enabled. session.execute(update(DagModel).where(DagModel.dag_id == "consumer").values(**enable)) - # A DDRQ should be scheduled for the new event, but not the previous one. + # An ADRQ should be scheduled for the new event, but not the previous one.
dr2: DagRun = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED) - dsm.register_dataset_change( + asset_manager.register_asset_change( task_instance=dr2.get_task_instance("task", session=session), - dataset=ds, + asset=ds, session=session, ) session.flush() - assert [e.source_run_id for e in session.scalars(dse_q)] == [dr1.run_id, dr2.run_id] - assert session.scalars(ddrq_q).one().target_dag_id == "consumer" + assert [e.source_run_id for e in session.scalars(ase_q)] == [dr1.run_id, dr2.run_id] + assert session.scalars(adrq_q).one().target_dag_id == "consumer" @time_machine.travel(DEFAULT_DATE + datetime.timedelta(days=1, seconds=9), tick=False) @mock.patch("airflow.jobs.scheduler_job_runner.Stats.timing") @@ -5728,87 +5728,85 @@ def test_update_dagrun_state_for_paused_dag_not_for_backfill(self, dag_maker, se (backfill_run,) = DagRun.find(dag_id=dag.dag_id, run_type=DagRunType.BACKFILL_JOB, session=session) assert backfill_run.state == State.RUNNING - def test_dataset_orphaning(self, dag_maker, session): - dataset1 = Dataset(uri="ds1") - dataset2 = Dataset(uri="ds2") - dataset3 = Dataset(uri="ds3") - dataset4 = Dataset(uri="ds4") + def test_asset_orphaning(self, dag_maker, session): + asset1 = Asset(uri="ds1") + asset2 = Asset(uri="ds2") + asset3 = Asset(uri="ds3") + asset4 = Asset(uri="ds4") - with dag_maker(dag_id="datasets-1", schedule=[dataset1, dataset2], session=session): - BashOperator(task_id="task", bash_command="echo 1", outlets=[dataset3, dataset4]) + with dag_maker(dag_id="assets-1", schedule=[asset1, asset2], session=session): - BashOperator(task_id="task", bash_command="echo 1", outlets=[asset3, asset4]) - non_orphaned_dataset_count = session.query(DatasetModel).filter(~DatasetModel.is_orphaned).count() - assert non_orphaned_dataset_count == 4 - orphaned_dataset_count = session.query(DatasetModel).filter(DatasetModel.is_orphaned).count() - assert orphaned_dataset_count == 0 + non_orphaned_asset_count = session.query(AssetModel).filter(~AssetModel.is_orphaned).count() + assert non_orphaned_asset_count == 4 + orphaned_asset_count = session.query(AssetModel).filter(AssetModel.is_orphaned).count() + assert orphaned_asset_count == 0 - # now remove 2 dataset references - with dag_maker(dag_id="datasets-1", schedule=[dataset1], session=session): - BashOperator(task_id="task", bash_command="echo 1", outlets=[dataset3]) + # now remove 2 asset references + with dag_maker(dag_id="assets-1", schedule=[asset1], session=session): + BashOperator(task_id="task", bash_command="echo 1", outlets=[asset3]) scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) - self.job_runner._orphan_unreferenced_datasets(session=session) + self.job_runner._orphan_unreferenced_assets(session=session) session.flush() # and find the orphans - non_orphaned_datasets = [ - dataset.uri - for dataset in session.query(DatasetModel.uri) - .filter(~DatasetModel.is_orphaned) - .order_by(DatasetModel.uri) + non_orphaned_assets = [ + asset.uri + for asset in session.query(AssetModel.uri) + .filter(~AssetModel.is_orphaned) + .order_by(AssetModel.uri) ] - assert non_orphaned_datasets == ["ds1", "ds3"] - orphaned_datasets = [ - dataset.uri - for dataset in session.query(DatasetModel.uri) - .filter(DatasetModel.is_orphaned) - .order_by(DatasetModel.uri) + assert non_orphaned_assets == ["ds1", "ds3"] + orphaned_assets = [ + asset.uri + for asset in session.query(AssetModel.uri).filter(AssetModel.is_orphaned).order_by(AssetModel.uri) ] - assert orphaned_datasets == ["ds2",
"ds4"] + assert orphaned_assets == ["ds2", "ds4"] - def test_dataset_orphaning_ignore_orphaned_datasets(self, dag_maker, session): - dataset1 = Dataset(uri="ds1") + def test_asset_orphaning_ignore_orphaned_assets(self, dag_maker, session): + asset1 = Asset(uri="ds1") - with dag_maker(dag_id="datasets-1", schedule=[dataset1], session=session): + with dag_maker(dag_id="assets-1", schedule=[asset1], session=session): BashOperator(task_id="task", bash_command="echo 1") - non_orphaned_dataset_count = session.query(DatasetModel).filter(~DatasetModel.is_orphaned).count() - assert non_orphaned_dataset_count == 1 - orphaned_dataset_count = session.query(DatasetModel).filter(DatasetModel.is_orphaned).count() - assert orphaned_dataset_count == 0 + non_orphaned_asset_count = session.query(AssetModel).filter(~AssetModel.is_orphaned).count() + assert non_orphaned_asset_count == 1 + orphaned_asset_count = session.query(AssetModel).filter(AssetModel.is_orphaned).count() + assert orphaned_asset_count == 0 - # now remove dataset1 reference - with dag_maker(dag_id="datasets-1", schedule=None, session=session): + # now remove asset1 reference + with dag_maker(dag_id="assets-1", schedule=None, session=session): BashOperator(task_id="task", bash_command="echo 1") scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) - self.job_runner._orphan_unreferenced_datasets(session=session) + self.job_runner._orphan_unreferenced_assets(session=session) session.flush() - orphaned_datasets_before_rerun = ( - session.query(DatasetModel.updated_at, DatasetModel.uri) - .filter(DatasetModel.is_orphaned) - .order_by(DatasetModel.uri) + orphaned_assets_before_rerun = ( + session.query(AssetModel.updated_at, AssetModel.uri) + .filter(AssetModel.is_orphaned) + .order_by(AssetModel.uri) ) - assert [dataset.uri for dataset in orphaned_datasets_before_rerun] == ["ds1"] - updated_at_timestamps = [dataset.updated_at for dataset in orphaned_datasets_before_rerun] + assert [asset.uri for asset in orphaned_assets_before_rerun] == ["ds1"] + updated_at_timestamps = [asset.updated_at for asset in orphaned_assets_before_rerun] - # when rerunning we should ignore the already orphaned datasets and thus the updated_at timestamp + # when rerunning we should ignore the already orphaned assets and thus the updated_at timestamp # should remain the same - self.job_runner._orphan_unreferenced_datasets(session=session) + self.job_runner._orphan_unreferenced_assets(session=session) session.flush() - orphaned_datasets_after_rerun = ( - session.query(DatasetModel.updated_at, DatasetModel.uri) - .filter(DatasetModel.is_orphaned) - .order_by(DatasetModel.uri) + orphaned_assets_after_rerun = ( + session.query(AssetModel.updated_at, AssetModel.uri) + .filter(AssetModel.is_orphaned) + .order_by(AssetModel.uri) ) - assert [dataset.uri for dataset in orphaned_datasets_after_rerun] == ["ds1"] - assert updated_at_timestamps == [dataset.updated_at for dataset in orphaned_datasets_after_rerun] + assert [asset.uri for asset in orphaned_assets_after_rerun] == ["ds1"] + assert updated_at_timestamps == [asset.updated_at for asset in orphaned_assets_after_rerun] def test_misconfigured_dags_doesnt_crash_scheduler(self, session, dag_maker, caplog): """Test that if dagrun creation throws an exception, the scheduler doesn't crash""" diff --git a/tests/lineage/test_hook.py b/tests/lineage/test_hook.py index 67059f91b4da1..c076b19aecedd 100644 --- a/tests/lineage/test_hook.py +++ b/tests/lineage/test_hook.py @@ -22,11 +22,11 @@ import 
pytest from airflow import plugins_manager -from airflow.datasets import Dataset +from airflow.assets import Asset from airflow.hooks.base import BaseHook from airflow.lineage import hook from airflow.lineage.hook import ( - DatasetLineageInfo, + AssetLineageInfo, HookLineage, HookLineageCollector, HookLineageReader, @@ -40,156 +40,124 @@ class TestHookLineageCollector: def setup_method(self): self.collector = HookLineageCollector() - def test_are_datasets_collected(self): + def test_are_assets_collected(self): assert self.collector is not None - assert self.collector.collected_datasets == HookLineage() + assert self.collector.collected_assets == HookLineage() input_hook = BaseHook() output_hook = BaseHook() - self.collector.add_input_dataset(input_hook, uri="s3://in_bucket/file") - self.collector.add_output_dataset( - output_hook, uri="postgres://example.com:5432/database/default/table" - ) - assert self.collector.collected_datasets == HookLineage( - [DatasetLineageInfo(dataset=Dataset("s3://in_bucket/file"), count=1, context=input_hook)], + self.collector.add_input_asset(input_hook, uri="s3://in_bucket/file") + self.collector.add_output_asset(output_hook, uri="postgres://example.com:5432/database/default/table") + assert self.collector.collected_assets == HookLineage( + [AssetLineageInfo(asset=Asset("s3://in_bucket/file"), count=1, context=input_hook)], [ - DatasetLineageInfo( - dataset=Dataset("postgres://example.com:5432/database/default/table"), + AssetLineageInfo( + asset=Asset("postgres://example.com:5432/database/default/table"), count=1, context=output_hook, ) ], ) - @patch("airflow.lineage.hook.Dataset") - def test_add_input_dataset(self, mock_dataset): - dataset = MagicMock(spec=Dataset, extra={}) - mock_dataset.return_value = dataset + @patch("airflow.lineage.hook.Asset") + def test_add_input_asset(self, mock_asset): + asset = MagicMock(spec=Asset, extra={}) + mock_asset.return_value = asset hook = MagicMock() - self.collector.add_input_dataset(hook, uri="test_uri") + self.collector.add_input_asset(hook, uri="test_uri") - assert next(iter(self.collector._inputs.values())) == (dataset, hook) - mock_dataset.assert_called_once_with(uri="test_uri", extra=None) + assert next(iter(self.collector._inputs.values())) == (asset, hook) + mock_asset.assert_called_once_with(uri="test_uri", extra=None) - def test_grouping_datasets(self): + def test_grouping_assets(self): hook_1 = MagicMock() hook_2 = MagicMock() uri = "test://uri/" - self.collector.add_input_dataset(context=hook_1, uri=uri) - self.collector.add_input_dataset(context=hook_2, uri=uri) - self.collector.add_input_dataset(context=hook_1, uri=uri, dataset_extra={"key": "value"}) + self.collector.add_input_asset(context=hook_1, uri=uri) + self.collector.add_input_asset(context=hook_2, uri=uri) + self.collector.add_input_asset(context=hook_1, uri=uri, asset_extra={"key": "value"}) - collected_inputs = self.collector.collected_datasets.inputs + collected_inputs = self.collector.collected_assets.inputs assert len(collected_inputs) == 3 - assert collected_inputs[0].dataset.uri == "test://uri/" - assert collected_inputs[0].dataset == collected_inputs[1].dataset + assert collected_inputs[0].asset.uri == "test://uri/" + assert collected_inputs[0].asset == collected_inputs[1].asset assert collected_inputs[0].count == 1 assert collected_inputs[0].context == collected_inputs[2].context == hook_1 assert collected_inputs[1].count == 1 assert collected_inputs[1].context == hook_2 assert collected_inputs[2].count == 1 - assert 
collected_inputs[2].dataset.extra == {"key": "value"} + assert collected_inputs[2].asset.extra == {"key": "value"} @patch("airflow.lineage.hook.ProvidersManager") - def test_create_dataset(self, mock_providers_manager): - def create_dataset(arg1, arg2="default", extra=None): - return Dataset(uri=f"myscheme://{arg1}/{arg2}", extra=extra) - - test_scheme = "myscheme" - mock_providers_manager.return_value.dataset_factories = {test_scheme: create_dataset} - - test_uri = "urischeme://value_a/value_b" - test_kwargs = {"arg1": "value_1"} - test_kwargs_uri = "myscheme://value_1/default" - test_extra = {"key": "value"} - - # test uri arg - should take precedence over the keyword args + scheme - assert self.collector.create_dataset( - scheme=test_scheme, uri=test_uri, dataset_kwargs=test_kwargs, dataset_extra=None - ) == Dataset(test_uri) - assert self.collector.create_dataset( - scheme=test_scheme, uri=test_uri, dataset_kwargs=test_kwargs, dataset_extra={} - ) == Dataset(test_uri) - assert self.collector.create_dataset( - scheme=test_scheme, uri=test_uri, dataset_kwargs=test_kwargs, dataset_extra=test_extra - ) == Dataset(test_uri, extra=test_extra) - - # test keyword args - assert self.collector.create_dataset( - scheme=test_scheme, uri=None, dataset_kwargs=test_kwargs, dataset_extra=None - ) == Dataset(test_kwargs_uri) - assert self.collector.create_dataset( - scheme=test_scheme, uri=None, dataset_kwargs=test_kwargs, dataset_extra={} - ) == Dataset(test_kwargs_uri) - assert self.collector.create_dataset( - scheme=test_scheme, + def test_create_asset(self, mock_providers_manager): + def create_asset(arg1, arg2="default", extra=None): + return Asset(uri=f"myscheme://{arg1}/{arg2}", extra=extra or {}) + + mock_providers_manager.return_value.asset_factories = {"myscheme": create_asset} + assert self.collector.create_asset( + scheme="myscheme", uri=None, asset_kwargs={"arg1": "value_1"}, asset_extra=None + ) == Asset("myscheme://value_1/default") + assert self.collector.create_asset( + scheme="myscheme", uri=None, - dataset_kwargs={**test_kwargs, "arg2": "value_2"}, - dataset_extra=test_extra, - ) == Dataset("myscheme://value_1/value_2", extra=test_extra) - - # missing both uri and scheme - assert ( - self.collector.create_dataset( - scheme=None, uri=None, dataset_kwargs=test_kwargs, dataset_extra=None - ) - is None - ) + asset_kwargs={"arg1": "value_1", "arg2": "value_2"}, + asset_extra={"key": "value"}, + ) == Asset("myscheme://value_1/value_2", extra={"key": "value"}) @patch("airflow.lineage.hook.ProvidersManager") - def test_create_dataset_no_factory(self, mock_providers_manager): + def test_create_asset_no_factory(self, mock_providers_manager): test_scheme = "myscheme" - mock_providers_manager.return_value.dataset_factories = {} + mock_providers_manager.return_value.asset_factories = {} test_kwargs = {"arg1": "value_1"} assert ( - self.collector.create_dataset( - scheme=test_scheme, uri=None, dataset_kwargs=test_kwargs, dataset_extra=None + self.collector.create_asset( + scheme=test_scheme, uri=None, asset_kwargs=test_kwargs, asset_extra=None ) is None ) @patch("airflow.lineage.hook.ProvidersManager") - def test_create_dataset_factory_exception(self, mock_providers_manager): - def create_dataset(extra=None, **kwargs): + def test_create_asset_factory_exception(self, mock_providers_manager): + def create_asset(extra=None, **kwargs): raise RuntimeError("Factory error") test_scheme = "myscheme" - mock_providers_manager.return_value.dataset_factories = {test_scheme: create_dataset} + 
mock_providers_manager.return_value.asset_factories = {test_scheme: create_asset} test_kwargs = {"arg1": "value_1"} assert ( - self.collector.create_dataset( - scheme=test_scheme, uri=None, dataset_kwargs=test_kwargs, dataset_extra=None + self.collector.create_asset( + scheme=test_scheme, uri=None, asset_kwargs=test_kwargs, asset_extra=None ) is None ) - def test_collected_datasets(self): + def test_collected_assets(self): context_input = MagicMock() context_output = MagicMock() - self.collector.add_input_dataset(context_input, uri="test://input") - self.collector.add_output_dataset(context_output, uri="test://output") + self.collector.add_input_asset(context_input, uri="test://input") + self.collector.add_output_asset(context_output, uri="test://output") - hook_lineage = self.collector.collected_datasets + hook_lineage = self.collector.collected_assets assert len(hook_lineage.inputs) == 1 - assert hook_lineage.inputs[0].dataset.uri == "test://input/" + assert hook_lineage.inputs[0].asset.uri == "test://input/" assert hook_lineage.inputs[0].context == context_input assert len(hook_lineage.outputs) == 1 - assert hook_lineage.outputs[0].dataset.uri == "test://output/" + assert hook_lineage.outputs[0].asset.uri == "test://output/" def test_has_collected(self): collector = HookLineageCollector() assert not collector.has_collected - collector._inputs = {"unique_key": (MagicMock(spec=Dataset), MagicMock())} + collector._inputs = {"unique_key": (MagicMock(spec=Asset), MagicMock())} assert collector.has_collected diff --git a/tests/listeners/dataset_listener.py b/tests/listeners/asset_listener.py similarity index 80% rename from tests/listeners/dataset_listener.py rename to tests/listeners/asset_listener.py index 0e4b768c696f1..e7adf580363b8 100644 --- a/tests/listeners/dataset_listener.py +++ b/tests/listeners/asset_listener.py @@ -23,21 +23,21 @@ from airflow.listeners import hookimpl if typing.TYPE_CHECKING: - from airflow.datasets import Dataset + from airflow.assets import Asset -changed: list[Dataset] = [] -created: list[Dataset] = [] +changed: list[Asset] = [] +created: list[Asset] = [] @hookimpl -def on_dataset_changed(dataset): - changed.append(copy.deepcopy(dataset)) +def on_asset_changed(asset): + changed.append(copy.deepcopy(asset)) @hookimpl -def on_dataset_created(dataset): - created.append(copy.deepcopy(dataset)) +def on_asset_created(asset): + created.append(copy.deepcopy(asset)) def clear(): diff --git a/tests/listeners/test_dataset_listener.py b/tests/listeners/test_asset_listener.py similarity index 72% rename from tests/listeners/test_dataset_listener.py rename to tests/listeners/test_asset_listener.py index b0ac6223e79ea..bb93acd8a0fff 100644 --- a/tests/listeners/test_dataset_listener.py +++ b/tests/listeners/test_asset_listener.py @@ -18,33 +18,33 @@ import pytest -from airflow.datasets import Dataset +from airflow.assets import Asset from airflow.listeners.listener import get_listener_manager -from airflow.models.dataset import DatasetModel +from airflow.models.asset import AssetModel from airflow.operators.empty import EmptyOperator from airflow.utils.session import provide_session -from tests.listeners import dataset_listener +from tests.listeners import asset_listener @pytest.fixture(autouse=True) def clean_listener_manager(): lm = get_listener_manager() lm.clear() - lm.add_listener(dataset_listener) + lm.add_listener(asset_listener) yield lm = get_listener_manager() lm.clear() - dataset_listener.clear() + asset_listener.clear() 
@pytest.mark.skip_if_database_isolation_mode # Test is broken in db isolation mode @pytest.mark.db_test @provide_session -def test_dataset_listener_on_dataset_changed_gets_calls(create_task_instance_of_operator, session): - dataset_uri = "test_dataset_uri" - ds = Dataset(uri=dataset_uri) - ds_model = DatasetModel(uri=dataset_uri) - session.add(ds_model) +def test_asset_listener_on_asset_changed_gets_calls(create_task_instance_of_operator, session): + asset_uri = "test_asset_uri" + asset = Asset(uri=asset_uri) + asset_model = AssetModel(uri=asset_uri) + session.add(asset_model) session.flush() @@ -53,9 +53,9 @@ def test_dataset_listener_on_dataset_changed_gets_calls(create_task_instance_of_ dag_id="producing_dag", task_id="test_task", session=session, - outlets=[ds], + outlets=[asset], ) ti.run() - assert len(dataset_listener.changed) == 1 - assert dataset_listener.changed[0].uri == dataset_uri + assert len(asset_listener.changed) == 1 + assert asset_listener.changed[0].uri == asset_uri diff --git a/tests/models/test_dataset.py b/tests/models/test_asset.py similarity index 73% rename from tests/models/test_dataset.py rename to tests/models/test_asset.py index f562e2347b008..5b35a0c89529e 100644 --- a/tests/models/test_dataset.py +++ b/tests/models/test_asset.py @@ -17,13 +17,13 @@ from __future__ import annotations -from airflow.datasets import DatasetAlias -from airflow.models.dataset import DatasetAliasModel +from airflow.assets import AssetAlias +from airflow.models.asset import AssetAliasModel -class TestDatasetAliasModel: +class TestAssetAliasModel: def test_from_public(self): - dataset_alias = DatasetAlias(name="test_alias") - dataset_alias_model = DatasetAliasModel.from_public(dataset_alias) + asset_alias = AssetAlias(name="test_alias") + asset_alias_model = AssetAliasModel.from_public(asset_alias) - assert dataset_alias_model.name == "test_alias" + assert asset_alias_model.name == "test_alias" diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index df4a892768816..ab67c3778c262 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -38,8 +38,8 @@ from sqlalchemy import inspect, select from airflow import settings +from airflow.assets import Asset, AssetAlias, AssetAll, AssetAny from airflow.configuration import conf -from airflow.datasets import Dataset, DatasetAlias, DatasetAll, DatasetAny from airflow.decorators import setup, task as task_decorator, teardown from airflow.exceptions import ( AirflowException, @@ -51,6 +51,13 @@ from airflow.executors import executor_loader from airflow.executors.local_executor import LocalExecutor from airflow.executors.sequential_executor import SequentialExecutor +from airflow.models.asset import ( + AssetAliasModel, + AssetDagRunQueue, + AssetEvent, + AssetModel, + TaskOutletAssetReference, +) from airflow.models.baseoperator import BaseOperator from airflow.models.dag import ( DAG, @@ -60,16 +67,9 @@ DagTag, ExecutorLoader, dag as dag_decorator, - get_dataset_triggered_next_run_info, + get_asset_triggered_next_run_info, ) from airflow.models.dagrun import DagRun -from airflow.models.dataset import ( - DatasetAliasModel, - DatasetDagRunQueue, - DatasetEvent, - DatasetModel, - TaskOutletDatasetReference, -) from airflow.models.param import DagParam, Param, ParamsDict from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskfail import TaskFail @@ -81,8 +81,8 @@ from airflow.templates import NativeEnvironment, SandboxedEnvironment from airflow.timetables.base import DagRunInfo, 
DataInterval, TimeRestriction, Timetable from airflow.timetables.simple import ( + AssetTriggeredTimetable, ContinuousTimetable, - DatasetTriggeredTimetable, NullTimetable, OnceTimetable, ) @@ -105,7 +105,7 @@ from tests.test_utils.asserts import assert_queries_count from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS from tests.test_utils.config import conf_vars -from tests.test_utils.db import clear_db_dags, clear_db_datasets, clear_db_runs, clear_db_serialized_dags +from tests.test_utils.db import clear_db_assets, clear_db_dags, clear_db_runs, clear_db_serialized_dags from tests.test_utils.mapping import expand_mapped_task from tests.test_utils.mock_plugins import mock_plugin_manager from tests.test_utils.timetables import cron_timetable, delta_timetable @@ -133,24 +133,24 @@ def clear_dags(): @pytest.fixture -def clear_datasets(): - clear_db_datasets() +def clear_assets(): + clear_db_assets() yield - clear_db_datasets() + clear_db_assets() class TestDag: def setup_method(self) -> None: clear_db_runs() clear_db_dags() - clear_db_datasets() + clear_db_assets() self.patcher_dag_code = mock.patch("airflow.models.dag.DagCode.bulk_sync_to_db") self.patcher_dag_code.start() def teardown_method(self) -> None: clear_db_runs() clear_db_dags() - clear_db_datasets() + clear_db_assets() self.patcher_dag_code.stop() @staticmethod @@ -1004,47 +1004,47 @@ def test_bulk_write_to_db_has_import_error(self): assert not model.has_import_errors session.close() - def test_bulk_write_to_db_datasets(self): + def test_bulk_write_to_db_assets(self): """ - Ensure that datasets referenced in a dag are correctly loaded into the database. + Ensure that assets referenced in a dag are correctly loaded into the database. """ - dag_id1 = "test_dataset_dag1" - dag_id2 = "test_dataset_dag2" - task_id = "test_dataset_task" - uri1 = "s3://dataset/1" - d1 = Dataset(uri1, extra={"not": "used"}) - d2 = Dataset("s3://dataset/2") - d3 = Dataset("s3://dataset/3") + dag_id1 = "test_asset_dag1" + dag_id2 = "test_asset_dag2" + task_id = "test_asset_task" + uri1 = "s3://asset/1" + d1 = Asset(uri1, extra={"not": "used"}) + d2 = Asset("s3://asset/2") + d3 = Asset("s3://asset/3") dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[d1]) EmptyOperator(task_id=task_id, dag=dag1, outlets=[d2, d3]) dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE, schedule=None) - EmptyOperator(task_id=task_id, dag=dag2, outlets=[Dataset(uri1, extra={"should": "be used"})]) + EmptyOperator(task_id=task_id, dag=dag2, outlets=[Asset(uri1, extra={"should": "be used"})]) session = settings.Session() dag1.clear() DAG.bulk_write_to_db([dag1, dag2], session=session) session.commit() - stored_datasets = {x.uri: x for x in session.query(DatasetModel).all()} - d1_orm = stored_datasets[d1.uri] - d2_orm = stored_datasets[d2.uri] - d3_orm = stored_datasets[d3.uri] - assert stored_datasets[uri1].extra == {"should": "be used"} - assert [x.dag_id for x in d1_orm.consuming_dags] == [dag_id1] - assert [(x.task_id, x.dag_id) for x in d1_orm.producing_tasks] == [(task_id, dag_id2)] + stored_assets = {x.uri: x for x in session.query(AssetModel).all()} + asset1_orm = stored_assets[d1.uri] + asset2_orm = stored_assets[d2.uri] + asset3_orm = stored_assets[d3.uri] + assert stored_assets[uri1].extra == {"should": "be used"} + assert [x.dag_id for x in asset1_orm.consuming_dags] == [dag_id1] + assert [(x.task_id, x.dag_id) for x in asset1_orm.producing_tasks] == [(task_id, dag_id2)] assert set( session.query( - TaskOutletDatasetReference.task_id, - 
TaskOutletDatasetReference.dag_id, - TaskOutletDatasetReference.dataset_id, + TaskOutletAssetReference.task_id, + TaskOutletAssetReference.dag_id, + TaskOutletAssetReference.dataset_id, ) - .filter(TaskOutletDatasetReference.dag_id.in_((dag_id1, dag_id2))) + .filter(TaskOutletAssetReference.dag_id.in_((dag_id1, dag_id2))) .all() ) == { - (task_id, dag_id1, d2_orm.id), - (task_id, dag_id1, d3_orm.id), - (task_id, dag_id2, d1_orm.id), + (task_id, dag_id1, asset2_orm.id), + (task_id, dag_id1, asset3_orm.id), + (task_id, dag_id2, asset1_orm.id), } - # now that we have verified that a new dag has its dataset references recorded properly, + # now that we have verified that a new dag has its asset references recorded properly, # we need to verify that *changes* are recorded properly. # so if any references are *removed*, they should also be deleted from the DB # so let's remove some references and see what happens @@ -1055,96 +1055,96 @@ def test_bulk_write_to_db_datasets(self): DAG.bulk_write_to_db([dag1, dag2], session=session) session.commit() session.expunge_all() - stored_datasets = {x.uri: x for x in session.query(DatasetModel).all()} - d1_orm = stored_datasets[d1.uri] - d2_orm = stored_datasets[d2.uri] - assert [x.dag_id for x in d1_orm.consuming_dags] == [] + stored_assets = {x.uri: x for x in session.query(AssetModel).all()} + asset1_orm = stored_assets[d1.uri] + asset2_orm = stored_assets[d2.uri] + assert [x.dag_id for x in asset1_orm.consuming_dags] == [] assert set( session.query( - TaskOutletDatasetReference.task_id, - TaskOutletDatasetReference.dag_id, - TaskOutletDatasetReference.dataset_id, + TaskOutletAssetReference.task_id, + TaskOutletAssetReference.dag_id, + TaskOutletAssetReference.dataset_id, ) - .filter(TaskOutletDatasetReference.dag_id.in_((dag_id1, dag_id2))) + .filter(TaskOutletAssetReference.dag_id.in_((dag_id1, dag_id2))) .all() - ) == {(task_id, dag_id1, d2_orm.id)} + ) == {(task_id, dag_id1, asset2_orm.id)} - def test_bulk_write_to_db_unorphan_datasets(self): + def test_bulk_write_to_db_unorphan_assets(self): """ - Datasets can lose their last reference and be orphaned, but then if a reference to them reappears, we - need to un-orphan those datasets + Assets can lose their last reference and be orphaned, but then if a reference to them reappears, we + need to un-orphan those assets """ with create_session() as session: - # Create four datasets - two that have references and two that are unreferenced and marked as + # Create four assets - two that have references and two that are unreferenced and marked as # orphans - dataset1 = Dataset(uri="ds1") - dataset2 = Dataset(uri="ds2") - session.add(DatasetModel(uri=dataset2.uri, is_orphaned=True)) - dataset3 = Dataset(uri="ds3") - dataset4 = Dataset(uri="ds4") - session.add(DatasetModel(uri=dataset4.uri, is_orphaned=True)) + asset1 = Asset(uri="ds1") + asset2 = Asset(uri="ds2") + session.add(AssetModel(uri=asset2.uri, is_orphaned=True)) + asset3 = Asset(uri="ds3") + asset4 = Asset(uri="ds4") + session.add(AssetModel(uri=asset4.uri, is_orphaned=True)) session.flush() - dag1 = DAG(dag_id="datasets-1", start_date=DEFAULT_DATE, schedule=[dataset1]) - BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[dataset3]) + dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE, schedule=[asset1]) + BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[asset3]) DAG.bulk_write_to_db([dag1], session=session) # Double check - non_orphaned_datasets = [ - dataset.uri - for dataset in 
session.query(DatasetModel.uri) - .filter(~DatasetModel.is_orphaned) - .order_by(DatasetModel.uri) + non_orphaned_assets = [ + asset.uri + for asset in session.query(AssetModel.uri) + .filter(~AssetModel.is_orphaned) + .order_by(AssetModel.uri) ] - assert non_orphaned_datasets == ["ds1", "ds3"] - orphaned_datasets = [ - dataset.uri - for dataset in session.query(DatasetModel.uri) - .filter(DatasetModel.is_orphaned) - .order_by(DatasetModel.uri) + assert non_orphaned_assets == ["ds1", "ds3"] + orphaned_assets = [ + asset.uri + for asset in session.query(AssetModel.uri) + .filter(AssetModel.is_orphaned) + .order_by(AssetModel.uri) ] - assert orphaned_datasets == ["ds2", "ds4"] + assert orphaned_assets == ["ds2", "ds4"] - # Now add references to the two unreferenced datasets - dag1 = DAG(dag_id="datasets-1", start_date=DEFAULT_DATE, schedule=[dataset1, dataset2]) - BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[dataset3, dataset4]) + # Now add references to the two unreferenced assets + dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE, schedule=[asset1, asset2]) + BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[asset3, asset4]) DAG.bulk_write_to_db([dag1], session=session) # and count the orphans and non-orphans - non_orphaned_dataset_count = session.query(DatasetModel).filter(~DatasetModel.is_orphaned).count() - assert non_orphaned_dataset_count == 4 - orphaned_dataset_count = session.query(DatasetModel).filter(DatasetModel.is_orphaned).count() - assert orphaned_dataset_count == 0 - - def test_bulk_write_to_db_dataset_aliases(self): - """ - Ensure that dataset aliases referenced in a dag are correctly loaded into the database. - """ - dag_id1 = "test_dataset_alias_dag1" - dag_id2 = "test_dataset_alias_dag2" - task_id = "test_dataset_task" - da1 = DatasetAlias(name="da1") - da2 = DatasetAlias(name="da2") - da2_2 = DatasetAlias(name="da2") - da3 = DatasetAlias(name="da3") + non_orphaned_asset_count = session.query(AssetModel).filter(~AssetModel.is_orphaned).count() + assert non_orphaned_asset_count == 4 + orphaned_asset_count = session.query(AssetModel).filter(AssetModel.is_orphaned).count() + assert orphaned_asset_count == 0 + + def test_bulk_write_to_db_asset_aliases(self): + """ + Ensure that asset aliases referenced in a dag are correctly loaded into the database. 
+ """ + dag_id1 = "test_asset_alias_dag1" + dag_id2 = "test_asset_alias_dag2" + task_id = "test_asset_task" + asset_alias_1 = AssetAlias(name="asset_alias_1") + asset_alias_2 = AssetAlias(name="asset_alias_2") + asset_alias_2_2 = AssetAlias(name="asset_alias_2") + asset_alias_3 = AssetAlias(name="asset_alias_3") dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=None) - EmptyOperator(task_id=task_id, dag=dag1, outlets=[da1, da2, da3]) + EmptyOperator(task_id=task_id, dag=dag1, outlets=[asset_alias_1, asset_alias_2, asset_alias_3]) dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE, schedule=None) - EmptyOperator(task_id=task_id, dag=dag2, outlets=[da2_2, da3]) + EmptyOperator(task_id=task_id, dag=dag2, outlets=[asset_alias_2_2, asset_alias_3]) session = settings.Session() DAG.bulk_write_to_db([dag1, dag2], session=session) session.commit() - stored_dataset_aliases = {x.name: x for x in session.query(DatasetAliasModel).all()} - da1_orm = stored_dataset_aliases[da1.name] - da2_orm = stored_dataset_aliases[da2.name] - da3_orm = stored_dataset_aliases[da3.name] - assert da1_orm.name == "da1" - assert da2_orm.name == "da2" - assert da3_orm.name == "da3" - assert len(stored_dataset_aliases) == 3 + stored_asset_alias_models = {x.name: x for x in session.query(AssetAliasModel).all()} + asset_alias_1_orm = stored_asset_alias_models[asset_alias_1.name] + asset_alias_2_orm = stored_asset_alias_models[asset_alias_2.name] + asset_alias_3_orm = stored_asset_alias_models[asset_alias_3.name] + assert asset_alias_1_orm.name == "asset_alias_1" + assert asset_alias_2_orm.name == "asset_alias_2" + assert asset_alias_3_orm.name == "asset_alias_3" + assert len(stored_asset_alias_models) == 3 def test_sync_to_db(self): dag = DAG("dag", start_date=DEFAULT_DATE, schedule=None) @@ -1664,10 +1664,10 @@ def test_timetable_and_description_from_schedule_arg( assert dag.timetable == expected_timetable assert dag.timetable.description == interval_description - def test_timetable_and_description_from_dataset(self): - dag = DAG("test_schedule_arg", schedule=[Dataset(uri="hello")], start_date=TEST_DATE) - assert dag.timetable == DatasetTriggeredTimetable(Dataset(uri="hello")) - assert dag.timetable.description == "Triggered by datasets" + def test_timetable_and_description_from_asset(self): + dag = DAG("test_schedule_interval_arg", schedule=[Asset(uri="hello")], start_date=TEST_DATE) + assert dag.timetable == AssetTriggeredTimetable(Asset(uri="hello")) + assert dag.timetable.description == "Triggered by assets" @pytest.mark.parametrize( "timetable, expected_description", @@ -2400,7 +2400,7 @@ def test_continuous_schedule_linmits_max_active_runs(self): class TestDagModel: def _clean(self): clear_db_dags() - clear_db_datasets() + clear_db_assets() clear_db_runs() def setup_method(self): @@ -2432,13 +2432,13 @@ def test_dags_needing_dagruns_not_too_early(self): session.rollback() session.close() - def test_dags_needing_dagruns_datasets(self, dag_maker, session): - dataset = Dataset(uri="hello") + def test_dags_needing_dagruns_assets(self, dag_maker, session): + asset = Asset(uri="hello") with dag_maker( session=session, dag_id="my_dag", max_active_runs=1, - schedule=[dataset], + schedule=[asset], start_date=pendulum.now().add(days=-2), ) as dag: EmptyOperator(task_id="dummy") @@ -2450,8 +2450,8 @@ def test_dags_needing_dagruns_datasets(self, dag_maker, session): # add queue records so we'll need a run dag_model = session.query(DagModel).filter(DagModel.dag_id == dag.dag_id).one() - dataset_model: DatasetModel = 
dag_model.schedule_datasets[0] - session.add(DatasetDagRunQueue(dataset_id=dataset_model.id, target_dag_id=dag_model.dag_id)) + asset_model: AssetModel = dag_model.schedule_datasets[0] + session.add(AssetDagRunQueue(dataset_id=asset_model.id, target_dag_id=dag_model.dag_id)) session.flush() query, _ = DagModel.dags_needing_dagruns(session) dag_models = query.all() @@ -2474,19 +2474,19 @@ def test_dags_needing_dagruns_datasets(self, dag_maker, session): dag_models = query.all() assert dag_models == [dag_model] - def test_dags_needing_dagruns_dataset_aliases(self, dag_maker, session): - # link dataset_alias hello_alias to dataset hello - dataset_model = DatasetModel(uri="hello") - dataset_alias_model = DatasetAliasModel(name="hello_alias") - dataset_alias_model.datasets.append(dataset_model) - session.add_all([dataset_model, dataset_alias_model]) + def test_dags_needing_dagruns_asset_aliases(self, dag_maker, session): + # link asset_alias hello_alias to asset hello + asset_model = AssetModel(uri="hello") + asset_alias_model = AssetAliasModel(name="hello_alias") + asset_alias_model.datasets.append(asset_model) + session.add_all([asset_model, asset_alias_model]) session.commit() with dag_maker( session=session, dag_id="my_dag", max_active_runs=1, - schedule=[DatasetAlias(name="hello_alias")], + schedule=[AssetAlias(name="hello_alias")], start_date=pendulum.now().add(days=-2), ): EmptyOperator(task_id="dummy") @@ -2498,8 +2498,8 @@ def test_dags_needing_dagruns_dataset_aliases(self, dag_maker, session): # add queue records so we'll need a run dag_model = dag_maker.dag_model - dataset_model: DatasetModel = dag_model.schedule_datasets[0] - session.add(DatasetDagRunQueue(dataset_id=dataset_model.id, target_dag_id=dag_model.dag_id)) + asset_model: AssetModel = dag_model.schedule_datasets[0] + session.add(AssetDagRunQueue(dataset_id=asset_model.id, target_dag_id=dag_model.dag_id)) session.flush() query, _ = DagModel.dags_needing_dagruns(session) dag_models = query.all() @@ -2663,20 +2663,20 @@ def test__processor_dags_folder(self, session): assert sdm.dag._processor_dags_folder == settings.DAGS_FOLDER @pytest.mark.need_serialized_dag - def test_dags_needing_dagruns_dataset_triggered_dag_info_queued_times(self, session, dag_maker): - dataset1 = Dataset(uri="ds1") - dataset2 = Dataset(uri="ds2") + def test_dags_needing_dagruns_asset_triggered_dag_info_queued_times(self, session, dag_maker): + asset1 = Asset(uri="ds1") + asset2 = Asset(uri="ds2") - for dag_id, dataset in [("datasets-1", dataset1), ("datasets-2", dataset2)]: + for dag_id, asset in [("assets-1", asset1), ("assets-2", asset2)]: with dag_maker(dag_id=dag_id, start_date=timezone.utcnow(), session=session): - EmptyOperator(task_id="task", outlets=[dataset]) + EmptyOperator(task_id="task", outlets=[asset]) dr = dag_maker.create_dagrun() - ds_id = session.query(DatasetModel.id).filter_by(uri=dataset.uri).scalar() + asset_id = session.query(AssetModel.id).filter_by(uri=asset.uri).scalar() session.add( - DatasetEvent( - dataset_id=ds_id, + AssetEvent( + dataset_id=asset_id, source_task_id="task", source_dag_id=dr.dag_id, source_run_id=dr.run_id, @@ -2684,18 +2684,20 @@ def test_dags_needing_dagruns_dataset_triggered_dag_info_queued_times(self, sess ) ) - ds1_id = session.query(DatasetModel.id).filter_by(uri=dataset1.uri).scalar() - ds2_id = session.query(DatasetModel.id).filter_by(uri=dataset2.uri).scalar() + asset1_id = session.query(AssetModel.id).filter_by(uri=asset1.uri).scalar() + asset2_id = 
session.query(AssetModel.id).filter_by(uri=asset2.uri).scalar() - with dag_maker(dag_id="datasets-consumer-multiple", schedule=[dataset1, dataset2]) as dag: + with dag_maker(dag_id="assets-consumer-multiple", schedule=[asset1, asset2]) as dag: pass session.flush() session.add_all( [ - DatasetDagRunQueue(dataset_id=ds1_id, target_dag_id=dag.dag_id, created_at=DEFAULT_DATE), - DatasetDagRunQueue( - dataset_id=ds2_id, target_dag_id=dag.dag_id, created_at=DEFAULT_DATE + timedelta(hours=1) + AssetDagRunQueue(dataset_id=asset1_id, target_dag_id=dag.dag_id, created_at=DEFAULT_DATE), + AssetDagRunQueue( + dataset_id=asset2_id, + target_dag_id=dag.dag_id, + created_at=DEFAULT_DATE + timedelta(hours=1), ), ] ) @@ -2708,16 +2710,16 @@ def test_dags_needing_dagruns_dataset_triggered_dag_info_queued_times(self, sess assert first_queued_time == DEFAULT_DATE assert last_queued_time == DEFAULT_DATE + timedelta(hours=1) - def test_dataset_expression(self, session: Session) -> None: + def test_asset_expression(self, session: Session) -> None: dag = DAG( - dag_id="test_dag_dataset_expression", - schedule=DatasetAny( - Dataset("s3://dag1/output_1.txt", {"hi": "bye"}), - DatasetAll( - Dataset("s3://dag2/output_1.txt", {"hi": "bye"}), - Dataset("s3://dag3/output_3.txt", {"hi": "bye"}), + dag_id="test_dag_asset_expression", + schedule=AssetAny( + Asset("s3://dag1/output_1.txt", {"hi": "bye"}), + AssetAll( + Asset("s3://dag2/output_1.txt", {"hi": "bye"}), + Asset("s3://dag3/output_3.txt", {"hi": "bye"}), ), - DatasetAlias(name="test_name"), + AssetAlias(name="test_name"), ), start_date=datetime.datetime.min, ) @@ -3424,43 +3426,43 @@ def test__tags_mutable(): @pytest.mark.need_serialized_dag -def test_get_dataset_triggered_next_run_info(dag_maker, clear_datasets): - dataset1 = Dataset(uri="ds1") - dataset2 = Dataset(uri="ds2") - dataset3 = Dataset(uri="ds3") - with dag_maker(dag_id="datasets-1", schedule=[dataset2]): +def test_get_asset_triggered_next_run_info(dag_maker, clear_assets): + asset1 = Asset(uri="ds1") + asset2 = Asset(uri="ds2") + asset3 = Asset(uri="ds3") + with dag_maker(dag_id="assets-1", schedule=[asset2]): pass dag1 = dag_maker.dag - with dag_maker(dag_id="datasets-2", schedule=[dataset1, dataset2]): + with dag_maker(dag_id="assets-2", schedule=[asset1, asset2]): pass dag2 = dag_maker.dag - with dag_maker(dag_id="datasets-3", schedule=[dataset1, dataset2, dataset3]): + with dag_maker(dag_id="assets-3", schedule=[asset1, asset2, asset3]): pass dag3 = dag_maker.dag session = dag_maker.session - ds1_id = session.query(DatasetModel.id).filter_by(uri=dataset1.uri).scalar() + asset1_id = session.query(AssetModel.id).filter_by(uri=asset1.uri).scalar() session.bulk_save_objects( [ - DatasetDagRunQueue(dataset_id=ds1_id, target_dag_id=dag2.dag_id), - DatasetDagRunQueue(dataset_id=ds1_id, target_dag_id=dag3.dag_id), + AssetDagRunQueue(dataset_id=asset1_id, target_dag_id=dag2.dag_id), + AssetDagRunQueue(dataset_id=asset1_id, target_dag_id=dag3.dag_id), ] ) session.flush() - datasets = session.query(DatasetModel.uri).order_by(DatasetModel.id).all() + assets = session.query(AssetModel.uri).order_by(AssetModel.id).all() - info = get_dataset_triggered_next_run_info([dag1.dag_id], session=session) + info = get_asset_triggered_next_run_info([dag1.dag_id], session=session) assert info[dag1.dag_id] == { "ready": 0, "total": 1, - "uri": datasets[0].uri, + "uri": assets[0].uri, } # This time, check both dag2 and dag3 at the same time (tests filtering) - info = get_dataset_triggered_next_run_info([dag2.dag_id, 
dag3.dag_id], session=session) + info = get_asset_triggered_next_run_info([dag2.dag_id, dag3.dag_id], session=session) assert info[dag2.dag_id] == { "ready": 1, "total": 2, @@ -3474,19 +3476,19 @@ def test_get_dataset_triggered_next_run_info(dag_maker, clear_datasets): @pytest.mark.need_serialized_dag -def test_get_dataset_triggered_next_run_info_with_unresolved_dataset_alias(dag_maker, clear_datasets): - dataset_alias1 = DatasetAlias(name="alias") +def test_get_dataset_triggered_next_run_info_with_unresolved_dataset_alias(dag_maker, clear_assets): + dataset_alias1 = AssetAlias(name="alias") with dag_maker(dag_id="dag-1", schedule=[dataset_alias1]): pass dag1 = dag_maker.dag session = dag_maker.session session.flush() - info = get_dataset_triggered_next_run_info([dag1.dag_id], session=session) + info = get_asset_triggered_next_run_info([dag1.dag_id], session=session) assert info == {} dag1_model = DagModel.get_dagmodel(dag1.dag_id) - assert dag1_model.get_dataset_triggered_next_run_info(session=session) is None + assert dag1_model.get_asset_triggered_next_run_info(session=session) is None def test_dag_uses_timetable_for_run_id(session): diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index d2f70ce69314b..c7dacaeb291e4 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -85,7 +85,7 @@ def _clean_db(): db.clear_db_pools() db.clear_db_dags() db.clear_db_variables() - db.clear_db_datasets() + db.clear_db_assets() db.clear_db_xcom() db.clear_db_task_fail() diff --git a/tests/models/test_serialized_dag.py b/tests/models/test_serialized_dag.py index 9f83280f8eb38..b8fddc655dae5 100644 --- a/tests/models/test_serialized_dag.py +++ b/tests/models/test_serialized_dag.py @@ -25,7 +25,7 @@ import pytest import airflow.example_dags as example_dags_module -from airflow.datasets import Dataset +from airflow.assets import Asset from airflow.models.dag import DAG from airflow.models.dagbag import DagBag from airflow.models.dagcode import DagCode @@ -237,16 +237,16 @@ def test_order_of_deps_is_consistent(self): dag_id="example", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), schedule=[ - Dataset("1"), - Dataset("2"), - Dataset("3"), - Dataset("4"), - Dataset("5"), + Asset("1"), + Asset("2"), + Asset("3"), + Asset("4"), + Asset("5"), ], ) as dag6: BashOperator( task_id="any", - outlets=[Dataset("0*"), Dataset("6*")], + outlets=[Asset("0*"), Asset("6*")], bash_command="sleep 5", ) deps_order = [x["dependency_id"] for x in SerializedDAG.serialize_dag(dag6)["dag_dependencies"]] diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index d2922db267805..8c334366f0488 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -39,7 +39,7 @@ from sqlalchemy import select from airflow import settings -from airflow.datasets import DatasetAlias +from airflow.assets import AssetAlias from airflow.decorators import task, task_group from airflow.example_dags.plugins.workday import AfterWorkdayTimetable from airflow.exceptions import ( @@ -53,11 +53,11 @@ UnmappableXComTypePushed, XComForMappingNotPushed, ) +from airflow.models.asset import AssetAliasModel, AssetDagRunQueue, AssetEvent, AssetModel from airflow.models.connection import Connection from airflow.models.dag import DAG from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun -from airflow.models.dataset import DatasetAliasModel, DatasetDagRunQueue, DatasetEvent, DatasetModel from airflow.models.expandinput import 
EXPAND_INPUT_EMPTY, NotFullyPopulated from airflow.models.param import process_params from airflow.models.pool import Pool @@ -158,7 +158,7 @@ def clean_db(): db.clear_db_task_fail() db.clear_rendered_ti_fields() db.clear_db_task_reschedule() - db.clear_db_datasets() + db.clear_db_assets() db.clear_db_xcom() def setup_method(self): @@ -2269,16 +2269,16 @@ def test_success_callback_no_race_condition(self, create_task_instance): assert ti.state == State.SUCCESS @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode - def test_outlet_datasets(self, create_task_instance): + def test_outlet_assets(self, create_task_instance): """ - Verify that when we have an outlet dataset on a task, and the task - completes successfully, a DatasetDagRunQueue is logged. + Verify that when we have an outlet asset on a task, and the task + completes successfully, an AssetDagRunQueue is logged. """ - from airflow.example_dags import example_datasets - from airflow.example_dags.example_datasets import dag1 + from airflow.example_dags import example_assets + from airflow.example_dags.example_assets import dag1 session = settings.Session() - dagbag = DagBag(dag_folder=example_datasets.__file__) + dagbag = DagBag(dag_folder=example_assets.__file__) dagbag.collect_dags(only_if_updated=False, safe_mode=False) dagbag.sync_to_db(session=session) run_id = str(uuid4()) @@ -2293,54 +2293,54 @@ def test_outlet_datasets(self, create_task_instance): ti.refresh_from_db() assert ti.state == TaskInstanceState.SUCCESS - # check that no other dataset events recorded + # check that no other asset events recorded event = ( - session.query(DatasetEvent) - .join(DatasetEvent.dataset) - .filter(DatasetEvent.source_task_instance == ti) + session.query(AssetEvent) + .join(AssetEvent.dataset) + .filter(AssetEvent.source_task_instance == ti) .one() ) assert event assert event.dataset - # check that one queue record created for each dag that depends on dataset 1 - assert session.query(DatasetDagRunQueue.target_dag_id).filter_by( - dataset_id=event.dataset.id - ).order_by(DatasetDagRunQueue.target_dag_id).all() == [ - ("conditional_dataset_and_time_based_timetable",), - ("consume_1_and_2_with_dataset_expressions",), - ("consume_1_or_2_with_dataset_expressions",), - ("consume_1_or_both_2_and_3_with_dataset_expressions",), - ("dataset_consumes_1",), - ("dataset_consumes_1_and_2",), - ("dataset_consumes_1_never_scheduled",), + # check that one queue record created for each dag that depends on asset 1 + assert session.query(AssetDagRunQueue.target_dag_id).filter_by(dataset_id=event.dataset.id).order_by( + AssetDagRunQueue.target_dag_id + ).all() == [ + ("asset_consumes_1",), + ("asset_consumes_1_and_2",), + ("asset_consumes_1_never_scheduled",), + ("conditional_asset_and_time_based_timetable",), + ("consume_1_and_2_with_asset_expressions",), + ("consume_1_or_2_with_asset_expressions",), + ("consume_1_or_both_2_and_3_with_asset_expressions",), ] - # check that one event record created for dataset1 and this TI - assert session.query(DatasetModel.uri).join(DatasetEvent.dataset).filter( - DatasetEvent.source_task_instance == ti + # check that one event record created for asset1 and this TI + assert session.query(AssetModel.uri).join(AssetEvent.dataset).filter( + AssetEvent.source_task_instance == ti ).one() == ("s3://dag1/output_1.txt",) - # check that the dataset event has an earlier timestamp than the DDRQ's - ddrq_timestamps = ( - session.query(DatasetDagRunQueue.created_at).filter_by(dataset_id=event.dataset.id).all() +
# check that the asset event has an earlier timestamp than the ADRQ's + adrq_timestamps = ( + session.query(AssetDagRunQueue.created_at).filter_by(dataset_id=event.dataset.id).all() ) assert all( - event.timestamp < ddrq_timestamp for (ddrq_timestamp,) in ddrq_timestamps - ), f"Some items in {[str(t) for t in ddrq_timestamps]} are earlier than {event.timestamp}" + event.timestamp < adrq_timestamp for (adrq_timestamp,) in adrq_timestamps + ), f"Some items in {[str(t) for t in adrq_timestamps]} are earlier than {event.timestamp}" @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode - def test_outlet_datasets_failed(self, create_task_instance): + def test_outlet_assets_failed(self, create_task_instance): """ - Verify that when we have an outlet dataset on a task, and the task - failed, a DatasetDagRunQueue is not logged, and a DatasetEvent is + Verify that when we have an outlet asset on a task, and the task + failed, a AssetDagRunQueue is not logged, and an AssetEvent is not generated """ - from tests.dags import test_datasets - from tests.dags.test_datasets import dag_with_fail_task + from tests.dags import test_assets + from tests.dags.test_assets import dag_with_fail_task session = settings.Session() - dagbag = DagBag(dag_folder=test_datasets.__file__) + dagbag = DagBag(dag_folder=test_assets.__file__) dagbag.collect_dags(only_if_updated=False, safe_mode=False) dagbag.sync_to_db(session=session) run_id = str(uuid4()) @@ -2356,10 +2356,10 @@ def test_outlet_datasets_failed(self, create_task_instance): assert ti.state == TaskInstanceState.FAILED # check that no dagruns were queued - assert session.query(DatasetDagRunQueue).count() == 0 + assert session.query(AssetDagRunQueue).count() == 0 - # check that no dataset events were generated - assert session.query(DatasetEvent).count() == 0 + # check that no asset events were generated + assert session.query(AssetEvent).count() == 0 @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode def test_mapped_current_state(self, dag_maker): @@ -2386,17 +2386,17 @@ def raise_an_exception(placeholder: int): assert task_instance.current_state() == TaskInstanceState.SUCCESS @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode - def test_outlet_datasets_skipped(self): + def test_outlet_assets_skipped(self): """ - Verify that when we have an outlet dataset on a task, and the task - is skipped, a DatasetDagRunQueue is not logged, and a DatasetEvent is + Verify that when we have an outlet asset on a task, and the task + is skipped, a AssetDagRunQueue is not logged, and an AssetEvent is not generated """ - from tests.dags import test_datasets - from tests.dags.test_datasets import dag_with_skip_task + from tests.dags import test_assets + from tests.dags.test_assets import dag_with_skip_task session = settings.Session() - dagbag = DagBag(dag_folder=test_datasets.__file__) + dagbag = DagBag(dag_folder=test_assets.__file__) dagbag.collect_dags(only_if_updated=False, safe_mode=False) dagbag.sync_to_db(session=session) run_id = str(uuid4()) @@ -2411,30 +2411,30 @@ def test_outlet_datasets_skipped(self): assert ti.state == TaskInstanceState.SKIPPED # check that no dagruns were queued - assert session.query(DatasetDagRunQueue).count() == 0 + assert session.query(AssetDagRunQueue).count() == 0 - # check that no dataset events were generated - assert session.query(DatasetEvent).count() == 0 + # check that no asset events were generated + assert session.query(AssetEvent).count() == 0 
@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode - def test_outlet_dataset_extra(self, dag_maker, session): - from airflow.datasets import Dataset + def test_outlet_asset_extra(self, dag_maker, session): + from airflow.assets import Asset with dag_maker(schedule=None, session=session) as dag: - @task(outlets=Dataset("test_outlet_dataset_extra_1")) + @task(outlets=Asset("test_outlet_asset_extra_1")) def write1(*, outlet_events): - outlet_events["test_outlet_dataset_extra_1"].extra = {"foo": "bar"} + outlet_events["test_outlet_asset_extra_1"].extra = {"foo": "bar"} write1() def _write2_post_execute(context, _): - context["outlet_events"]["test_outlet_dataset_extra_2"].extra = {"x": 1} + context["outlet_events"]["test_outlet_asset_extra_2"].extra = {"x": 1} BashOperator( task_id="write2", bash_command=":", - outlets=Dataset("test_outlet_dataset_extra_2"), + outlets=Asset("test_outlet_asset_extra_2"), post_execute=_write2_post_execute, ) @@ -2443,30 +2443,30 @@ def _write2_post_execute(context, _): ti.refresh_from_task(dag.get_task(ti.task_id)) ti.run(session=session) - events = dict(iter(session.execute(select(DatasetEvent.source_task_id, DatasetEvent)))) + events = dict(iter(session.execute(select(AssetEvent.source_task_id, AssetEvent)))) assert set(events) == {"write1", "write2"} assert events["write1"].source_dag_id == dr.dag_id assert events["write1"].source_run_id == dr.run_id assert events["write1"].source_task_id == "write1" - assert events["write1"].dataset.uri == "test_outlet_dataset_extra_1" + assert events["write1"].dataset.uri == "test_outlet_asset_extra_1" assert events["write1"].extra == {"foo": "bar"} assert events["write2"].source_dag_id == dr.dag_id assert events["write2"].source_run_id == dr.run_id assert events["write2"].source_task_id == "write2" - assert events["write2"].dataset.uri == "test_outlet_dataset_extra_2" + assert events["write2"].dataset.uri == "test_outlet_asset_extra_2" assert events["write2"].extra == {"x": 1} @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode - def test_outlet_dataset_extra_ignore_different(self, dag_maker, session): - from airflow.datasets import Dataset + def test_outlet_asset_extra_ignore_different(self, dag_maker, session): + from airflow.assets import Asset with dag_maker(schedule=None, session=session): - @task(outlets=Dataset("test_outlet_dataset_extra")) + @task(outlets=Asset("test_outlet_asset_extra")) def write(*, outlet_events): - outlet_events["test_outlet_dataset_extra"].extra = {"one": 1} + outlet_events["test_outlet_asset_extra"].extra = {"one": 1} outlet_events["different_uri"].extra = {"foo": "bar"} # Will be silently dropped. 
write() @@ -2474,34 +2474,34 @@ def write(*, outlet_events): dr: DagRun = dag_maker.create_dagrun() dr.get_task_instance("write").run(session=session) - event = session.scalars(select(DatasetEvent)).one() + event = session.scalars(select(AssetEvent)).one() assert event.source_dag_id == dr.dag_id assert event.source_run_id == dr.run_id assert event.source_task_id == "write" assert event.extra == {"one": 1} @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode - def test_outlet_dataset_extra_yield(self, dag_maker, session): - from airflow.datasets import Dataset - from airflow.datasets.metadata import Metadata + def test_outlet_asset_extra_yield(self, dag_maker, session): + from airflow.assets import Asset + from airflow.assets.metadata import Metadata with dag_maker(schedule=None, session=session) as dag: - @task(outlets=Dataset("test_outlet_dataset_extra_1")) + @task(outlets=Asset("test_outlet_asset_extra_1")) def write1(): result = "write_1 result" - yield Metadata("test_outlet_dataset_extra_1", {"foo": "bar"}) + yield Metadata("test_outlet_asset_extra_1", {"foo": "bar"}) return result write1() def _write2_post_execute(context, result): - yield Metadata("test_outlet_dataset_extra_2", {"x": 1}) + yield Metadata("test_outlet_asset_extra_2", {"x": 1}) BashOperator( task_id="write2", bash_command=":", - outlets=Dataset("test_outlet_dataset_extra_2"), + outlets=Asset("test_outlet_asset_extra_2"), post_execute=_write2_post_execute, ) @@ -2515,37 +2515,37 @@ def _write2_post_execute(context, result): ).one() assert xcom.value == "write_1 result" - events = dict(iter(session.execute(select(DatasetEvent.source_task_id, DatasetEvent)))) + events = dict(iter(session.execute(select(AssetEvent.source_task_id, AssetEvent)))) assert set(events) == {"write1", "write2"} assert events["write1"].source_dag_id == dr.dag_id assert events["write1"].source_run_id == dr.run_id assert events["write1"].source_task_id == "write1" - assert events["write1"].dataset.uri == "test_outlet_dataset_extra_1" + assert events["write1"].dataset.uri == "test_outlet_asset_extra_1" assert events["write1"].extra == {"foo": "bar"} assert events["write2"].source_dag_id == dr.dag_id assert events["write2"].source_run_id == dr.run_id assert events["write2"].source_task_id == "write2" - assert events["write2"].dataset.uri == "test_outlet_dataset_extra_2" + assert events["write2"].dataset.uri == "test_outlet_asset_extra_2" assert events["write2"].extra == {"x": 1} @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode - def test_outlet_dataset_alias(self, dag_maker, session): - from airflow.datasets import Dataset, DatasetAlias + def test_outlet_asset_alias(self, dag_maker, session): + from airflow.assets import Asset, AssetAlias - ds_uri = "test_outlet_dataset_alias_test_case_ds" - dsa_name_1 = "test_outlet_dataset_alias_test_case_dsa_1" + asset_uri = "test_outlet_asset_alias_test_case_ds" + alias_name_1 = "test_outlet_asset_alias_test_case_asset_alias_1" - ds1 = DatasetModel(id=1, uri=ds_uri) + ds1 = AssetModel(id=1, uri=asset_uri) session.add(ds1) session.commit() with dag_maker(dag_id="producer_dag", schedule=None, session=session) as dag: - @task(outlets=DatasetAlias(dsa_name_1)) + @task(outlets=AssetAlias(alias_name_1)) def producer(*, outlet_events): - outlet_events[dsa_name_1].add(Dataset(ds_uri)) + outlet_events[alias_name_1].add(Asset(asset_uri)) producer() @@ -2556,7 +2556,7 @@ def producer(*, outlet_events): ti.run(session=session) producer_events = session.execute( - 
select(DatasetEvent).where(DatasetEvent.source_task_id == "producer") + select(AssetEvent).where(AssetEvent.source_task_id == "producer") ).fetchall() assert len(producer_events) == 1 @@ -2566,39 +2566,45 @@ def producer(*, outlet_events): assert producer_event.source_dag_id == "producer_dag" assert producer_event.source_run_id == "test" assert producer_event.source_map_index == -1 - assert producer_event.dataset.uri == ds_uri + assert producer_event.dataset.uri == asset_uri assert len(producer_event.source_aliases) == 1 assert producer_event.extra == {} - assert producer_event.source_aliases[0].name == dsa_name_1 + assert producer_event.source_aliases[0].name == alias_name_1 - ds_obj = session.scalar(select(DatasetModel).where(DatasetModel.uri == ds_uri)) - assert len(ds_obj.aliases) == 1 - assert ds_obj.aliases[0].name == dsa_name_1 + asset_obj = session.scalar(select(AssetModel).where(AssetModel.uri == asset_uri)) + assert len(asset_obj.aliases) == 1 + assert asset_obj.aliases[0].name == alias_name_1 - dsa_obj = session.scalar(select(DatasetAliasModel).where(DatasetAliasModel.name == dsa_name_1)) - assert len(dsa_obj.datasets) == 1 - assert dsa_obj.datasets[0].uri == ds_uri + asset_alias_obj = session.scalar(select(AssetAliasModel).where(AssetAliasModel.name == alias_name_1)) + assert len(asset_alias_obj.datasets) == 1 + assert asset_alias_obj.datasets[0].uri == asset_uri @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode - def test_outlet_multiple_dataset_alias(self, dag_maker, session): - from airflow.datasets import Dataset, DatasetAlias + def test_outlet_multiple_asset_alias(self, dag_maker, session): + from airflow.assets import Asset, AssetAlias - ds_uri = "test_outlet_mdsa_ds" - dsa_name_1 = "test_outlet_mdsa_dsa_1" - dsa_name_2 = "test_outlet_mdsa_dsa_2" - dsa_name_3 = "test_outlet_mdsa_dsa_3" + asset_uri = "test_outlet_maa_ds" + asset_alias_name_1 = "test_outlet_maa_asset_alias_1" + asset_alias_name_2 = "test_outlet_maa_asset_alias_2" + asset_alias_name_3 = "test_outlet_maa_asset_alias_3" - ds1 = DatasetModel(id=1, uri=ds_uri) + ds1 = AssetModel(id=1, uri=asset_uri) session.add(ds1) session.commit() with dag_maker(dag_id="producer_dag", schedule=None, session=session) as dag: - @task(outlets=[DatasetAlias(dsa_name_1), DatasetAlias(dsa_name_2), DatasetAlias(dsa_name_3)]) + @task( + outlets=[ + AssetAlias(asset_alias_name_1), + AssetAlias(asset_alias_name_2), + AssetAlias(asset_alias_name_3), + ] + ) def producer(*, outlet_events): - outlet_events[dsa_name_1].add(Dataset(ds_uri)) - outlet_events[dsa_name_2].add(Dataset(ds_uri)) - outlet_events[dsa_name_3].add(Dataset(ds_uri), extra={"k": "v"}) + outlet_events[asset_alias_name_1].add(Asset(asset_uri)) + outlet_events[asset_alias_name_2].add(Asset(asset_uri)) + outlet_events[asset_alias_name_3].add(Asset(asset_uri), extra={"k": "v"}) producer() @@ -2609,7 +2615,7 @@ def producer(*, outlet_events): ti.run(session=session) producer_events = session.execute( - select(DatasetEvent).where(DatasetEvent.source_task_id == "producer") + select(AssetEvent).where(AssetEvent.source_task_id == "producer") ).fetchall() assert len(producer_events) == 2 @@ -2619,44 +2625,51 @@ def producer(*, outlet_events): assert producer_event.source_dag_id == "producer_dag" assert producer_event.source_run_id == "test" assert producer_event.source_map_index == -1 - assert producer_event.dataset.uri == ds_uri + assert producer_event.dataset.uri == asset_uri if not producer_event.extra: assert producer_event.extra == {} assert 
len(producer_event.source_aliases) == 2 - assert {alias.name for alias in producer_event.source_aliases} == {dsa_name_1, dsa_name_2} + assert {alias.name for alias in producer_event.source_aliases} == { + asset_alias_name_1, + asset_alias_name_2, + } else: assert producer_event.extra == {"k": "v"} assert len(producer_event.source_aliases) == 1 - assert producer_event.source_aliases[0].name == dsa_name_3 - - ds_obj = session.scalar(select(DatasetModel).where(DatasetModel.uri == ds_uri)) - assert len(ds_obj.aliases) == 3 - assert {alias.name for alias in ds_obj.aliases} == {dsa_name_1, dsa_name_2, dsa_name_3} + assert producer_event.source_aliases[0].name == asset_alias_name_3 + + asset_obj = session.scalar(select(AssetModel).where(AssetModel.uri == asset_uri)) + assert len(asset_obj.aliases) == 3 + assert {alias.name for alias in asset_obj.aliases} == { + asset_alias_name_1, + asset_alias_name_2, + asset_alias_name_3, + } - dsa_objs = session.scalars(select(DatasetAliasModel)).all() - assert len(dsa_objs) == 3 - for dsa_obj in dsa_objs: - assert len(dsa_obj.datasets) == 1 - assert dsa_obj.datasets[0].uri == ds_uri + asset_alias_objs = session.scalars(select(AssetAliasModel)).all() + assert len(asset_alias_objs) == 3 + for asset_alias_obj in asset_alias_objs: + assert len(asset_alias_obj.datasets) == 1 + assert asset_alias_obj.datasets[0].uri == asset_uri @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode - def test_outlet_dataset_alias_through_metadata(self, dag_maker, session): - from airflow.datasets import DatasetAlias - from airflow.datasets.metadata import Metadata + def test_outlet_asset_alias_through_metadata(self, dag_maker, session): + from airflow.assets import AssetAlias + from airflow.assets.metadata import Metadata - ds_uri = "test_outlet_dataset_alias_through_metadata_ds" - dsa_name = "test_outlet_dataset_alias_through_metadata_dsa" + asset_uri = "test_outlet_asset_alias_through_metadata_ds" + asset_alias_name = "test_outlet_asset_alias_through_metadata_asset_alias" - ds1 = DatasetModel(id=1, uri="test_outlet_dataset_alias_through_metadata_ds") + ds1 = AssetModel(id=1, uri="test_outlet_asset_alias_through_metadata_ds") session.add(ds1) session.commit() with dag_maker(dag_id="producer_dag", schedule=None, session=session) as dag: - @task(outlets=DatasetAlias(dsa_name)) + @task(outlets=AssetAlias(asset_alias_name)) def producer(*, outlet_events): - yield Metadata(ds_uri, extra={"key": "value"}, alias=dsa_name) + yield Metadata(asset_uri, extra={"key": "value"}, alias=asset_alias_name) producer() @@ -2666,37 +2679,37 @@ def producer(*, outlet_events): ti.refresh_from_task(dag.get_task(ti.task_id)) ti.run(session=session) - producer_event = session.scalar(select(DatasetEvent).where(DatasetEvent.source_task_id == "producer")) + producer_event = session.scalar(select(AssetEvent).where(AssetEvent.source_task_id == "producer")) assert producer_event.source_task_id == "producer" assert producer_event.source_dag_id == "producer_dag" assert producer_event.source_run_id == "test" assert producer_event.source_map_index == -1 - assert producer_event.dataset.uri == ds_uri + assert producer_event.dataset.uri == asset_uri assert producer_event.extra == {"key": "value"} assert len(producer_event.source_aliases) == 1 - assert producer_event.source_aliases[0].name == dsa_name + assert producer_event.source_aliases[0].name == asset_alias_name - ds_obj = session.scalar(select(DatasetModel).where(DatasetModel.uri == ds_uri)) - assert len(ds_obj.aliases) == 1 - assert 
ds_obj.aliases[0].name == dsa_name + asset_obj = session.scalar(select(AssetModel).where(AssetModel.uri == asset_uri)) + assert len(asset_obj.aliases) == 1 + assert asset_obj.aliases[0].name == asset_alias_name - dsa_obj = session.scalar(select(DatasetAliasModel)) - assert len(dsa_obj.datasets) == 1 - assert dsa_obj.datasets[0].uri == ds_uri + asset_alias_obj = session.scalar(select(AssetAliasModel)) + assert len(asset_alias_obj.datasets) == 1 + assert asset_alias_obj.datasets[0].uri == asset_uri @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode - def test_outlet_dataset_alias_dataset_not_exists(self, dag_maker, session): - from airflow.datasets import Dataset, DatasetAlias + def test_outlet_asset_alias_asset_not_exists(self, dag_maker, session): + from airflow.assets import Asset, AssetAlias - dsa_name = "test_outlet_dataset_alias_dataset_not_exists_dsa" - ds_uri = "did_not_exists" + asset_alias_name = "test_outlet_asset_alias_asset_not_exists_asset_alias" + asset_uri = "did_not_exists" with dag_maker(dag_id="producer_dag", schedule=None, session=session) as dag: - @task(outlets=DatasetAlias(dsa_name)) + @task(outlets=AssetAlias(asset_alias_name)) def producer(*, outlet_events): - outlet_events[dsa_name].add(Dataset(ds_uri), extra={"key": "value"}) + outlet_events[asset_alias_name].add(Asset(asset_uri), extra={"key": "value"}) producer() @@ -2706,51 +2719,51 @@ def producer(*, outlet_events): ti.refresh_from_task(dag.get_task(ti.task_id)) ti.run(session=session) - producer_event = session.scalar(select(DatasetEvent).where(DatasetEvent.source_task_id == "producer")) + producer_event = session.scalar(select(AssetEvent).where(AssetEvent.source_task_id == "producer")) assert producer_event.source_task_id == "producer" assert producer_event.source_dag_id == "producer_dag" assert producer_event.source_run_id == "test" assert producer_event.source_map_index == -1 - assert producer_event.dataset.uri == ds_uri + assert producer_event.dataset.uri == asset_uri assert producer_event.extra == {"key": "value"} assert len(producer_event.source_aliases) == 1 - assert producer_event.source_aliases[0].name == dsa_name + assert producer_event.source_aliases[0].name == asset_alias_name - ds_obj = session.scalar(select(DatasetModel).where(DatasetModel.uri == ds_uri)) - assert len(ds_obj.aliases) == 1 - assert ds_obj.aliases[0].name == dsa_name + asset_obj = session.scalar(select(AssetModel).where(AssetModel.uri == asset_uri)) + assert len(asset_obj.aliases) == 1 + assert asset_obj.aliases[0].name == asset_alias_name - dsa_obj = session.scalar(select(DatasetAliasModel)) - assert len(dsa_obj.datasets) == 1 - assert dsa_obj.datasets[0].uri == ds_uri + asset_alias_obj = session.scalar(select(AssetAliasModel)) + assert len(asset_alias_obj.datasets) == 1 + assert asset_alias_obj.datasets[0].uri == asset_uri @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode - def test_inlet_dataset_extra(self, dag_maker, session): - from airflow.datasets import Dataset + def test_inlet_asset_extra(self, dag_maker, session): + from airflow.assets import Asset read_task_evaluated = False with dag_maker(schedule=None, session=session): - @task(outlets=Dataset("test_inlet_dataset_extra")) + @task(outlets=Asset("test_inlet_asset_extra")) def write(*, ti, outlet_events): - outlet_events["test_inlet_dataset_extra"].extra = {"from": ti.task_id} + outlet_events["test_inlet_asset_extra"].extra = {"from": ti.task_id} - @task(inlets=Dataset("test_inlet_dataset_extra")) + 
@task(inlets=Asset("test_inlet_asset_extra")) def read(*, inlet_events): - second_event = inlet_events["test_inlet_dataset_extra"][1] - assert second_event.uri == "test_inlet_dataset_extra" + second_event = inlet_events["test_inlet_asset_extra"][1] + assert second_event.uri == "test_inlet_asset_extra" assert second_event.extra == {"from": "write2"} - last_event = inlet_events["test_inlet_dataset_extra"][-1] - assert last_event.uri == "test_inlet_dataset_extra" + last_event = inlet_events["test_inlet_asset_extra"][-1] + assert last_event.uri == "test_inlet_asset_extra" assert last_event.extra == {"from": "write3"} with pytest.raises(KeyError): inlet_events["does_not_exist"] with pytest.raises(IndexError): - inlet_events["test_inlet_dataset_extra"][5] + inlet_events["test_inlet_asset_extra"][5] # TODO: Support slices. @@ -2780,42 +2793,42 @@ def read(*, inlet_events): assert read_task_evaluated @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode - def test_inlet_dataset_alias_extra(self, dag_maker, session): - ds_uri = "test_inlet_dataset_extra_ds" - dsa_name = "test_inlet_dataset_extra_dsa" - - ds_model = DatasetModel(id=1, uri=ds_uri) - dsa_model = DatasetAliasModel(name=dsa_name) - dsa_model.datasets.append(ds_model) - session.add_all([ds_model, dsa_model]) + def test_inlet_asset_alias_extra(self, dag_maker, session): + asset_uri = "test_inlet_asset_extra_ds" + asset_alias_name = "test_inlet_asset_extra_asset_alias" + + asset_model = AssetModel(id=1, uri=asset_uri) + asset_alias_model = AssetAliasModel(name=asset_alias_name) + asset_alias_model.datasets.append(asset_model) + session.add_all([asset_model, asset_alias_model]) session.commit() - from airflow.datasets import Dataset, DatasetAlias + from airflow.assets import Asset, AssetAlias read_task_evaluated = False with dag_maker(schedule=None, session=session): - @task(outlets=DatasetAlias(dsa_name)) + @task(outlets=AssetAlias(asset_alias_name)) def write(*, ti, outlet_events): - outlet_events[dsa_name].add(Dataset(ds_uri), extra={"from": ti.task_id}) + outlet_events[asset_alias_name].add(Asset(asset_uri), extra={"from": ti.task_id}) - @task(inlets=DatasetAlias(dsa_name)) + @task(inlets=AssetAlias(asset_alias_name)) def read(*, inlet_events): - second_event = inlet_events[DatasetAlias(dsa_name)][1] - assert second_event.uri == ds_uri + second_event = inlet_events[AssetAlias(asset_alias_name)][1] + assert second_event.uri == asset_uri assert second_event.extra == {"from": "write2"} - last_event = inlet_events[DatasetAlias(dsa_name)][-1] - assert last_event.uri == ds_uri + last_event = inlet_events[AssetAlias(asset_alias_name)][-1] + assert last_event.uri == asset_uri assert last_event.extra == {"from": "write3"} with pytest.raises(KeyError): inlet_events["does_not_exist"] with pytest.raises(KeyError): - inlet_events[DatasetAlias("does_not_exist")] + inlet_events[AssetAlias("does_not_exist")] with pytest.raises(IndexError): - inlet_events[DatasetAlias(dsa_name)][5] + inlet_events[AssetAlias(asset_alias_name)][5] nonlocal read_task_evaluated read_task_evaluated = True @@ -2842,21 +2855,21 @@ def read(*, inlet_events): assert not dr.task_instance_scheduling_decisions(session=session).schedulable_tis assert read_task_evaluated - def test_inlet_unresolved_dataset_alias(self, dag_maker, session): - dsa_name = "test_inlet_dataset_extra_dsa" + def test_inlet_unresolved_asset_alias(self, dag_maker, session): + asset_alias_name = "test_inlet_asset_extra_asset_alias" - dsa_model = DatasetAliasModel(name=dsa_name) - 
session.add(dsa_model) + asset_alias_model = AssetAliasModel(name=asset_alias_name) + session.add(asset_alias_model) session.commit() - from airflow.datasets import DatasetAlias + from airflow.assets import AssetAlias with dag_maker(schedule=None, session=session): - @task(inlets=DatasetAlias(dsa_name)) + @task(inlets=AssetAlias(asset_alias_name)) def read(*, inlet_events): with pytest.raises(IndexError): - inlet_events[DatasetAlias(dsa_name)][0] + inlet_events[AssetAlias(asset_alias_name)][0] read() @@ -2879,16 +2892,16 @@ def read(*, inlet_events): (lambda x: x[-5:5], []), ], ) - def test_inlet_dataset_extra_slice(self, dag_maker, session, slicer, expected): - from airflow.datasets import Dataset + def test_inlet_asset_extra_slice(self, dag_maker, session, slicer, expected): + from airflow.assets import Asset - ds_uri = "test_inlet_dataset_extra_slice" + asset_uri = "test_inlet_asset_extra_slice" with dag_maker(dag_id="write", schedule="@daily", params={"i": -1}, session=session): - @task(outlets=Dataset(ds_uri)) + @task(outlets=Asset(asset_uri)) def write(*, params, outlet_events): - outlet_events[ds_uri].extra = {"from": params["i"]} + outlet_events[asset_uri].extra = {"from": params["i"]} write() @@ -2905,10 +2918,10 @@ def write(*, params, outlet_events): with dag_maker(dag_id="read", schedule=None, session=session): - @task(inlets=Dataset(ds_uri)) + @task(inlets=Asset(asset_uri)) def read(*, inlet_events): nonlocal result - result = [e.extra for e in slicer(inlet_events[ds_uri])] + result = [e.extra for e in slicer(inlet_events[asset_uri])] read() @@ -2933,23 +2946,23 @@ def read(*, inlet_events): (lambda x: x[-5:5], []), ], ) - def test_inlet_dataset_alias_extra_slice(self, dag_maker, session, slicer, expected): - ds_uri = "test_inlet_dataset_alias_extra_slice_ds" - dsa_name = "test_inlet_dataset_alias_extra_slice_dsa" - - ds_model = DatasetModel(id=1, uri=ds_uri) - dsa_model = DatasetAliasModel(name=dsa_name) - dsa_model.datasets.append(ds_model) - session.add_all([ds_model, dsa_model]) + def test_inlet_asset_alias_extra_slice(self, dag_maker, session, slicer, expected): + asset_uri = "test_inlet_asset_alias_extra_slice_ds" + asset_alias_name = "test_inlet_asset_alias_extra_slice_asset_alias" + + asset_model = AssetModel(id=1, uri=asset_uri) + asset_alias_model = AssetAliasModel(name=asset_alias_name) + asset_alias_model.datasets.append(asset_model) + session.add_all([asset_model, asset_alias_model]) session.commit() - from airflow.datasets import Dataset + from airflow.assets import Asset with dag_maker(dag_id="write", schedule="@daily", params={"i": -1}, session=session): - @task(outlets=DatasetAlias(dsa_name)) + @task(outlets=AssetAlias(asset_alias_name)) def write(*, params, outlet_events): - outlet_events[dsa_name].add(Dataset(ds_uri), {"from": params["i"]}) + outlet_events[asset_alias_name].add(Asset(asset_uri), {"from": params["i"]}) write() @@ -2966,10 +2979,10 @@ def write(*, params, outlet_events): with dag_maker(dag_id="read", schedule=None, session=session): - @task(inlets=DatasetAlias(dsa_name)) + @task(inlets=AssetAlias(asset_alias_name)) def read(*, inlet_events): nonlocal result - result = [e.extra for e in slicer(inlet_events[DatasetAlias(dsa_name)])] + result = [e.extra for e in slicer(inlet_events[AssetAlias(asset_alias_name)])] read() @@ -2983,16 +2996,16 @@ def read(*, inlet_events): assert result == expected @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode - def test_changing_of_dataset_when_ddrq_is_already_populated(self, 
dag_maker): + def test_changing_of_asset_when_adrq_is_already_populated(self, dag_maker): """ - Test that when a task that produces dataset has ran, that changing the consumer - dag dataset will not cause primary key blank-out + Test that when a task that produces asset has ran, that changing the consumer + dag asset will not cause primary key blank-out """ - from airflow.datasets import Dataset + from airflow.assets import Asset with dag_maker(schedule=None, serialized=True) as dag1: - @task(outlets=Dataset("test/1")) + @task(outlets=Asset("test/1")) def test_task1(): print(1) @@ -3001,7 +3014,7 @@ def test_task1(): dr1 = dag_maker.create_dagrun() test_task1 = dag1.get_task("test_task1") - with dag_maker(dag_id="testdag", schedule=[Dataset("test/1")], serialized=True): + with dag_maker(dag_id="testdag", schedule=[Asset("test/1")], serialized=True): @task def test_task2(): @@ -3011,8 +3024,8 @@ def test_task2(): ti = dr1.get_task_instance(task_id="test_task1") ti.run() - # Change the dataset. - with dag_maker(dag_id="testdag", schedule=[Dataset("test2/1")], serialized=True): + # Change the asset. + with dag_maker(dag_id="testdag", schedule=[Asset("test2/1")], serialized=True): @task def test_task2(): @@ -3153,19 +3166,19 @@ def test_get_previous_start_date_none(self, dag_maker): assert ti_1.start_date is None @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode - def test_context_triggering_dataset_events_none(self, session, create_task_instance): + def test_context_triggering_asset_events_none(self, session, create_task_instance): ti = create_task_instance() template_context = ti.get_template_context() assert ti in session session.expunge_all() - assert template_context["triggering_dataset_events"] == {} + assert template_context["triggering_asset_events"] == {} @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode - def test_context_triggering_dataset_events(self, create_dummy_dag, session): - ds1 = DatasetModel(id=1, uri="one") - ds2 = DatasetModel(id=2, uri="two") + def test_context_triggering_asset_events(self, create_dummy_dag, session): + ds1 = AssetModel(id=1, uri="one") + ds2 = AssetModel(id=2, uri="two") session.add_all([ds1, ds2]) session.commit() @@ -3173,7 +3186,7 @@ def test_context_triggering_dataset_events(self, create_dummy_dag, session): triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} # it's easier to fake a manual run here dag, task1 = create_dummy_dag( - dag_id="test_triggering_dataset_events", + dag_id="test_triggering_asset_events", schedule=None, start_date=DEFAULT_DATE, task_id="test_context", @@ -3189,9 +3202,9 @@ def test_context_triggering_dataset_events(self, create_dummy_dag, session): data_interval=(execution_date, execution_date), **triggered_by_kwargs, ) - ds1_event = DatasetEvent(dataset_id=1) - ds2_event_1 = DatasetEvent(dataset_id=2) - ds2_event_2 = DatasetEvent(dataset_id=2) + ds1_event = AssetEvent(dataset_id=1) + ds2_event_1 = AssetEvent(dataset_id=2) + ds2_event_2 = AssetEvent(dataset_id=2) dr.consumed_dataset_events.append(ds1_event) dr.consumed_dataset_events.append(ds2_event_1) dr.consumed_dataset_events.append(ds2_event_2) @@ -3207,7 +3220,7 @@ def test_context_triggering_dataset_events(self, create_dummy_dag, session): template_context = ti.get_template_context() - assert template_context["triggering_dataset_events"] == { + assert template_context["triggering_asset_events"] == { "one": [ds1_event], "two": [ds2_event_1, ds2_event_2], } @@ 
-4187,7 +4200,7 @@ def _clean(): db.clear_db_dags() db.clear_db_sla_miss() db.clear_db_import_errors() - db.clear_db_datasets() + db.clear_db_assets() def setup_method(self) -> None: self._clean() diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py index 28c66511a3d38..c98dc7018a4a1 100644 --- a/tests/operators/test_python.py +++ b/tests/operators/test_python.py @@ -894,8 +894,8 @@ def test_virtualenv_serializable_context_fields(self, create_task_instance): "ti", "var", # Accessor for Variable; var->json and var->value. "conn", # Accessor for Connection. - "inlet_events", # Accessor for inlet DatasetEvent. - "outlet_events", # Accessor for outlet DatasetEvent. + "inlet_events", # Accessor for inlet AssetEvent. + "outlet_events", # Accessor for outlet AssetEvent. ] ti = create_task_instance(dag_id=self.dag_id, task_id=self.task_id, schedule=None) diff --git a/tests/providers/mysql/datasets/__init__.py b/tests/providers/amazon/aws/assets/__init__.py similarity index 100% rename from tests/providers/mysql/datasets/__init__.py rename to tests/providers/amazon/aws/assets/__init__.py diff --git a/tests/providers/amazon/aws/datasets/test_s3.py b/tests/providers/amazon/aws/assets/test_s3.py similarity index 75% rename from tests/providers/amazon/aws/datasets/test_s3.py rename to tests/providers/amazon/aws/assets/test_s3.py index 893d6acf677bc..e918c9fdffa2f 100644 --- a/tests/providers/amazon/aws/datasets/test_s3.py +++ b/tests/providers/amazon/aws/assets/test_s3.py @@ -20,10 +20,10 @@ import pytest -from airflow.datasets import Dataset -from airflow.providers.amazon.aws.datasets.s3 import ( - convert_dataset_to_openlineage, - create_dataset, +from airflow.providers.amazon.aws.assets.s3 import ( + Asset, + convert_asset_to_openlineage, + create_asset, sanitize_uri, ) from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -50,9 +50,9 @@ def test_sanitize_uri_no_path(): assert result.path == "" -def test_create_dataset(): - assert create_dataset(bucket="test-bucket", key="test-path") == Dataset(uri="s3://test-bucket/test-path") - assert create_dataset(bucket="test-bucket", key="test-dir/test-path") == Dataset( +def test_create_asset(): + assert create_asset(bucket="test-bucket", key="test-path") == Asset(uri="s3://test-bucket/test-path") + assert create_asset(bucket="test-bucket", key="test-dir/test-path") == Asset( uri="s3://test-bucket/test-dir/test-path" ) @@ -65,15 +65,15 @@ def test_sanitize_uri_trailing_slash(): assert result.path == "/" -def test_convert_dataset_to_openlineage_valid(): +def test_convert_asset_to_openlineage_valid(): uri = "s3://bucket/dir/file.txt" - ol_dataset = convert_dataset_to_openlineage(dataset=Dataset(uri=uri), lineage_context=S3Hook()) + ol_dataset = convert_asset_to_openlineage(asset=Asset(uri=uri), lineage_context=S3Hook()) assert ol_dataset.namespace == "s3://bucket" assert ol_dataset.name == "dir/file.txt" @pytest.mark.parametrize("uri", ("s3://bucket", "s3://bucket/")) -def test_convert_dataset_to_openlineage_no_path(uri): - ol_dataset = convert_dataset_to_openlineage(dataset=Dataset(uri=uri), lineage_context=S3Hook()) +def test_convert_asset_to_openlineage_no_path(uri): + ol_dataset = convert_asset_to_openlineage(asset=Asset(uri=uri), lineage_context=S3Hook()) assert ol_dataset.namespace == "s3://bucket" assert ol_dataset.name == "/" diff --git a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py index f54a2a3e5fb1f..d827ba3ff0e6d 100644 --- 
a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py +++ b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py @@ -23,10 +23,24 @@ from flask import Flask, session from flask_appbuilder.menu import MenuItem +from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities +from airflow.providers.amazon.aws.auth_manager.avp.facade import AwsAuthManagerAmazonVerifiedPermissionsFacade +from airflow.providers.amazon.aws.auth_manager.aws_auth_manager import AwsAuthManager from airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override import ( AwsSecurityManagerOverride, ) +from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser +from airflow.security.permissions import ( + RESOURCE_AUDIT_LOG, + RESOURCE_CLUSTER_ACTIVITY, + RESOURCE_CONNECTION, + RESOURCE_VARIABLE, +) +from airflow.www import app as application +from airflow.www.extensions.init_appbuilder import init_appbuilder from tests.test_utils.compat import AIRFLOW_V_2_8_PLUS, AIRFLOW_V_2_9_PLUS +from tests.test_utils.config import conf_vars +from tests.test_utils.www import check_content_in_response try: from airflow.auth.managers.models.resource_details import ( @@ -35,7 +49,6 @@ ConnectionDetails, DagAccessEntity, DagDetails, - DatasetDetails, PoolDetails, VariableDetails, ) @@ -47,24 +60,18 @@ ) else: raise -from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities -from airflow.providers.amazon.aws.auth_manager.avp.facade import AwsAuthManagerAmazonVerifiedPermissionsFacade -from airflow.providers.amazon.aws.auth_manager.aws_auth_manager import AwsAuthManager -from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser -from airflow.security.permissions import ( - RESOURCE_AUDIT_LOG, - RESOURCE_CLUSTER_ACTIVITY, - RESOURCE_CONNECTION, - RESOURCE_DATASET, - RESOURCE_VARIABLE, -) -from airflow.www import app as application -from airflow.www.extensions.init_appbuilder import init_appbuilder -from tests.test_utils.config import conf_vars -from tests.test_utils.www import check_content_in_response if TYPE_CHECKING: from airflow.auth.managers.base_auth_manager import ResourceMethod + from airflow.auth.managers.models.resource_details import AssetDetails + from airflow.security.permissions import RESOURCE_ASSET +else: + try: + from airflow.auth.managers.models.resource_details import AssetDetails + from airflow.security.permissions import RESOURCE_ASSET + except ImportError: + from airflow.auth.managers.models.resource_details import DatasetDetails as AssetDetails + from airflow.security.permissions import RESOURCE_DATASET as RESOURCE_ASSET pytestmark = [ pytest.mark.skipif(not AIRFLOW_V_2_9_PLUS, reason="Test requires Airflow 2.9+"), @@ -324,12 +331,12 @@ def test_is_authorized_dag( "details, user, expected_user, expected_entity_id", [ (None, None, ANY, None), - (DatasetDetails(uri="uri"), mock, mock, "uri"), + (AssetDetails(uri="uri"), mock, mock, "uri"), ], ) @patch.object(AwsAuthManager, "avp_facade") @patch.object(AwsAuthManager, "get_user") - def test_is_authorized_dataset( + def test_is_authorized_asset( self, mock_get_user, mock_avp_facade, @@ -343,12 +350,12 @@ def test_is_authorized_dataset( mock_avp_facade.is_authorized = is_authorized method: ResourceMethod = "GET" - result = auth_manager.is_authorized_dataset(method=method, details=details, user=user) + result = auth_manager.is_authorized_asset(method=method, details=details, user=user) if not user: mock_get_user.assert_called_once() 
is_authorized.assert_called_once_with( - method=method, entity_type=AvpEntities.DATASET, user=expected_user, entity_id=expected_entity_id + method=method, entity_type=AvpEntities.ASSET, user=expected_user, entity_id=expected_entity_id ) assert result @@ -611,7 +618,7 @@ def test_filter_permitted_menu_items(self, mock_get_user, auth_manager, test_use "request": { "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, "action": {"actionType": "Airflow::Action", "actionId": "Menu.MENU"}, - "resource": {"entityType": "Airflow::Menu", "entityId": "Datasets"}, + "resource": {"entityType": "Airflow::Menu", "entityId": RESOURCE_ASSET}, }, "decision": "DENY", }, @@ -649,7 +656,7 @@ def test_filter_permitted_menu_items(self, mock_get_user, auth_manager, test_use result = auth_manager.filter_permitted_menu_items( [ MenuItem("Category1", childs=[MenuItem(RESOURCE_CONNECTION), MenuItem(RESOURCE_VARIABLE)]), - MenuItem("Category2", childs=[MenuItem(RESOURCE_DATASET)]), + MenuItem("Category2", childs=[MenuItem(RESOURCE_ASSET)]), MenuItem(RESOURCE_CLUSTER_ACTIVITY), MenuItem(RESOURCE_AUDIT_LOG), MenuItem("CustomPage"), @@ -679,7 +686,7 @@ def test_filter_permitted_menu_items(self, mock_get_user, auth_manager, test_use { "method": "MENU", "entity_type": AvpEntities.MENU, - "entity_id": "Datasets", + "entity_id": RESOURCE_ASSET, }, {"method": "MENU", "entity_type": AvpEntities.MENU, "entity_id": "Cluster Activity"}, {"method": "MENU", "entity_type": AvpEntities.MENU, "entity_id": "Audit Logs"}, diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py index 97696c64b6e7a..43c4b94445b6f 100644 --- a/tests/providers/amazon/aws/hooks/test_s3.py +++ b/tests/providers/amazon/aws/hooks/test_s3.py @@ -31,9 +31,9 @@ from botocore.exceptions import ClientError from moto import mock_aws -from airflow.datasets import Dataset from airflow.exceptions import AirflowException from airflow.models import Connection +from airflow.providers.amazon.aws.assets.s3 import Asset from airflow.providers.amazon.aws.exceptions import S3HookUriParseFailure from airflow.providers.amazon.aws.hooks.s3 import ( NO_ACL, @@ -58,6 +58,21 @@ def s3_bucket(mocked_s3_res): return bucket +if AIRFLOW_V_2_10_PLUS: + + @pytest.fixture + def hook_lineage_collector(): + from airflow.lineage import hook + from airflow.providers.amazon.aws.hooks.s3 import get_hook_lineage_collector + + hook._hook_lineage_collector = None + hook._hook_lineage_collector = hook.HookLineageCollector() + + yield get_hook_lineage_collector() + + hook._hook_lineage_collector = None + + class TestAwsS3Hook: @mock_aws def test_get_conn(self): @@ -429,9 +444,10 @@ def test_load_string(self, s3_bucket): @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0") def test_load_string_exposes_lineage(self, s3_bucket, hook_lineage_collector): hook = S3Hook() + hook.load_string("Contént", "my_key", s3_bucket) - assert len(hook_lineage_collector.collected_datasets.outputs) == 1 - assert hook_lineage_collector.collected_datasets.outputs[0].dataset == Dataset( + assert len(hook_lineage_collector.collected_assets.outputs) == 1 + assert hook_lineage_collector.collected_assets.outputs[0].asset == Asset( uri=f"s3://{s3_bucket}/my_key" ) @@ -1023,8 +1039,8 @@ def test_load_file_exposes_lineage(self, s3_bucket, tmp_path, hook_lineage_colle path = tmp_path / "testfile" path.write_text("Content") hook.load_file(path, "my_key", s3_bucket) - assert 
len(hook_lineage_collector.collected_datasets.outputs) == 1 - assert hook_lineage_collector.collected_datasets.outputs[0].dataset == Dataset( + assert len(hook_lineage_collector.collected_assets.outputs) == 1 + assert hook_lineage_collector.collected_assets.outputs[0].asset == Asset( uri=f"s3://{s3_bucket}/my_key" ) @@ -1095,13 +1111,13 @@ def test_copy_object_ol_instrumentation(self, s3_bucket, hook_lineage_collector) "get_conn", ): mock_hook.copy_object("my_key", "my_key3", s3_bucket, s3_bucket) - assert len(hook_lineage_collector.collected_datasets.inputs) == 1 - assert hook_lineage_collector.collected_datasets.inputs[0].dataset == Dataset( + assert len(hook_lineage_collector.collected_assets.inputs) == 1 + assert hook_lineage_collector.collected_assets.inputs[0].asset == Asset( uri=f"s3://{s3_bucket}/my_key" ) - assert len(hook_lineage_collector.collected_datasets.outputs) == 1 - assert hook_lineage_collector.collected_datasets.outputs[0].dataset == Dataset( + assert len(hook_lineage_collector.collected_assets.outputs) == 1 + assert hook_lineage_collector.collected_assets.outputs[0].asset == Asset( uri=f"s3://{s3_bucket}/my_key3" ) @@ -1233,8 +1249,8 @@ def test_download_file_exposes_lineage(self, mock_temp_file, tmp_path, hook_line s3_hook.download_file(key=key, bucket_name=bucket) - assert len(hook_lineage_collector.collected_datasets.inputs) == 1 - assert hook_lineage_collector.collected_datasets.inputs[0].dataset == Dataset( + assert len(hook_lineage_collector.collected_assets.inputs) == 1 + assert hook_lineage_collector.collected_assets.inputs[0].asset == Asset( uri="s3://test_bucket/test_key" ) @@ -1285,14 +1301,14 @@ def test_download_file_with_preserve_name_exposes_lineage( use_autogenerated_subdir=False, ) - assert len(hook_lineage_collector.collected_datasets.inputs) == 1 - assert hook_lineage_collector.collected_datasets.inputs[0].dataset == Dataset( - uri="s3://test_bucket/test_key/test.log" + assert len(hook_lineage_collector.collected_assets.inputs) == 1 + assert hook_lineage_collector.collected_assets.inputs[0].asset == Asset( + uri="s3://test_bucket/test_key/test.log", extra={} ) - assert len(hook_lineage_collector.collected_datasets.outputs) == 1 - assert hook_lineage_collector.collected_datasets.outputs[0].dataset == Dataset( - uri=f"file://{local_path}/test.log", + assert len(hook_lineage_collector.collected_assets.outputs) == 1 + assert hook_lineage_collector.collected_assets.outputs[0].asset == Asset( + uri=f"file://{local_path}/test.log", extra={} ) @mock.patch("airflow.providers.amazon.aws.hooks.s3.open") diff --git a/tests/providers/postgres/datasets/__init__.py b/tests/providers/common/compat/openlineage/utils/__init__.py similarity index 100% rename from tests/providers/postgres/datasets/__init__.py rename to tests/providers/common/compat/openlineage/utils/__init__.py diff --git a/tests/providers/common/compat/openlineage/utils/test_utils.py b/tests/providers/common/compat/openlineage/utils/test_utils.py new file mode 100644 index 0000000000000..72af469a1becd --- /dev/null +++ b/tests/providers/common/compat/openlineage/utils/test_utils.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + + +def test_import(): + from airflow.providers.common.compat.openlineage.utils.utils import translate_airflow_asset + + assert translate_airflow_asset is not None diff --git a/tests/providers/trino/datasets/__init__.py b/tests/providers/common/compat/security/__init__.py similarity index 100% rename from tests/providers/trino/datasets/__init__.py rename to tests/providers/common/compat/security/__init__.py diff --git a/tests/providers/common/compat/security/test_permissions.py b/tests/providers/common/compat/security/test_permissions.py new file mode 100644 index 0000000000000..40a13832f9e25 --- /dev/null +++ b/tests/providers/common/compat/security/test_permissions.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + + +def test_import(): + from airflow.providers.common.compat.security.permissions import RESOURCE_ASSET + + assert RESOURCE_ASSET is not None diff --git a/tests/providers/common/io/assets/__init__.py b/tests/providers/common/io/assets/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/common/io/assets/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/tests/providers/common/io/datasets/test_file.py b/tests/providers/common/io/assets/test_file.py similarity index 83% rename from tests/providers/common/io/datasets/test_file.py rename to tests/providers/common/io/assets/test_file.py index d8d53247a6796..21357f933fde8 100644 --- a/tests/providers/common/io/datasets/test_file.py +++ b/tests/providers/common/io/assets/test_file.py @@ -20,11 +20,11 @@ import pytest -from airflow.datasets import Dataset from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset -from airflow.providers.common.io.datasets.file import ( - convert_dataset_to_openlineage, - create_dataset, +from airflow.providers.common.io.assets.file import ( + Asset, + convert_asset_to_openlineage, + create_asset, sanitize_uri, ) @@ -47,8 +47,8 @@ def test_sanitize_uri_invalid(uri): sanitize_uri(urlsplit(uri)) -def test_file_dataset(): - assert create_dataset(path="/asdf/fdsa") == Dataset(uri="file:///asdf/fdsa") +def test_file_asset(): + assert create_asset(path="/asdf/fdsa") == Asset(uri="file:///asdf/fdsa") @pytest.mark.parametrize( @@ -62,6 +62,6 @@ def test_file_dataset(): ("file:///C://dir/file", OpenLineageDataset(namespace="file://", name="/C://dir/file")), ), ) -def test_convert_dataset_to_openlineage(uri, ol_dataset): - result = convert_dataset_to_openlineage(Dataset(uri=uri), None) +def test_convert_asset_to_openlineage(uri, ol_dataset): + result = convert_asset_to_openlineage(Asset(uri=uri), None) assert result == ol_dataset diff --git a/tests/providers/fab/auth_manager/test_fab_auth_manager.py b/tests/providers/fab/auth_manager/test_fab_auth_manager.py index b755afcc70d03..d727b6090822f 100644 --- a/tests/providers/fab/auth_manager/test_fab_auth_manager.py +++ b/tests/providers/fab/auth_manager/test_fab_auth_manager.py @@ -48,7 +48,6 @@ RESOURCE_CONNECTION, RESOURCE_DAG, RESOURCE_DAG_RUN, - RESOURCE_DATASET, RESOURCE_DOCS, RESOURCE_JOB, RESOURCE_PLUGIN, @@ -62,11 +61,20 @@ if TYPE_CHECKING: from airflow.auth.managers.base_auth_manager import ResourceMethod + from airflow.security.permissions import RESOURCE_ASSET +else: + try: + from airflow.security.permissions import RESOURCE_ASSET + except ImportError: + from airflow.security.permissions import ( + RESOURCE_DATASET as RESOURCE_ASSET, + ) + IS_AUTHORIZED_METHODS_SIMPLE = { "is_authorized_configuration": RESOURCE_CONFIG, "is_authorized_connection": RESOURCE_CONNECTION, - "is_authorized_dataset": RESOURCE_DATASET, + "is_authorized_asset": RESOURCE_ASSET, "is_authorized_variable": RESOURCE_VARIABLE, } diff --git a/tests/providers/fab/auth_manager/test_security.py b/tests/providers/fab/auth_manager/test_security.py index b6aca2d4513a5..156b5cf626271 100644 --- a/tests/providers/fab/auth_manager/test_security.py +++ b/tests/providers/fab/auth_manager/test_security.py @@ -22,6 +22,7 @@ import json import logging import os +from typing import TYPE_CHECKING from unittest import mock from unittest.mock import patch @@ -60,6 +61,15 @@ from tests.test_utils.mock_security_manager import MockSecurityManager from tests.test_utils.permissions import _resource_name +if TYPE_CHECKING: + from airflow.security.permissions import RESOURCE_ASSET +else: + try: + from airflow.security.permissions import RESOURCE_ASSET + except ImportError: + from airflow.security.permissions import RESOURCE_DATASET as RESOURCE_ASSET + + pytestmark = pytest.mark.db_test READ_WRITE = {permissions.RESOURCE_DAG: {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT}} @@ -435,7 +445,7 @@ def 
test_get_user_roles_for_anonymous_user(app, security_manager): (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_DEPENDENCIES), (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE), (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET), + (permissions.ACTION_CAN_READ, RESOURCE_ASSET), (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY), (permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR), (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), @@ -454,7 +464,7 @@ def test_get_user_roles_for_anonymous_user(app, security_manager): (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DAG), (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DAG_DEPENDENCIES), (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DATASET), + (permissions.ACTION_CAN_ACCESS_MENU, RESOURCE_ASSET), (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_CLUSTER_ACTIVITY), (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_JOB), (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_SLA_MISS), diff --git a/tests/providers/mysql/assets/__init__.py b/tests/providers/mysql/assets/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/mysql/assets/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/tests/providers/mysql/datasets/test_mysql.py b/tests/providers/mysql/assets/test_mysql.py similarity index 97% rename from tests/providers/mysql/datasets/test_mysql.py rename to tests/providers/mysql/assets/test_mysql.py index 5f31d72991f27..28e44558f31d6 100644 --- a/tests/providers/mysql/datasets/test_mysql.py +++ b/tests/providers/mysql/assets/test_mysql.py @@ -21,7 +21,7 @@ import pytest -from airflow.providers.mysql.datasets.mysql import sanitize_uri +from airflow.providers.mysql.assets.mysql import sanitize_uri @pytest.mark.parametrize( diff --git a/tests/providers/openlineage/extractors/test_manager.py b/tests/providers/openlineage/extractors/test_manager.py index 479347179bd17..601a456604843 100644 --- a/tests/providers/openlineage/extractors/test_manager.py +++ b/tests/providers/openlineage/extractors/test_manager.py @@ -25,7 +25,6 @@ from openlineage.client.event_v2 import Dataset as OpenLineageDataset from openlineage.client.facet_v2 import documentation_dataset, ownership_dataset, schema_dataset -from airflow.datasets import Dataset from airflow.io.path import ObjectStoragePath from airflow.lineage.entities import Column, File, Table, User from airflow.models.baseoperator import BaseOperator @@ -33,12 +32,36 @@ from airflow.operators.python import PythonOperator from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.extractors.manager import ExtractorManager +from airflow.providers.openlineage.utils.utils import Asset from airflow.utils.state import State from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS if TYPE_CHECKING: from airflow.utils.context import Context +if AIRFLOW_V_2_10_PLUS: + + @pytest.fixture + def hook_lineage_collector(): + from importlib.util import find_spec + + from airflow.lineage import hook + + if find_spec("airflow.assets"): + # Dataset has been renamed as Asset in 3.0 + from airflow.lineage.hook import get_hook_lineage_collector + else: + from airflow.providers.openlineage.utils.asset_compat_lineage_collector import ( + get_hook_lineage_collector, + ) + + hook._hook_lineage_collector = None + hook._hook_lineage_collector = hook.HookLineageCollector() + + yield get_hook_lineage_collector() + + hook._hook_lineage_collector = None + @pytest.mark.parametrize( ("uri", "dataset"), @@ -213,8 +236,8 @@ def test_extractor_manager_uses_hook_level_lineage(hook_lineage_collector): del task.get_openlineage_facets_on_complete ti = MagicMock() - hook_lineage_collector.add_input_dataset(None, uri="s3://bucket/input_key") - hook_lineage_collector.add_output_dataset(None, uri="s3://bucket/output_key") + hook_lineage_collector.add_input_asset(None, uri="s3://bucket/input_key") + hook_lineage_collector.add_output_asset(None, uri="s3://bucket/output_key") extractor_manager = ExtractorManager() metadata = extractor_manager.extract_metadata(dagrun=dagrun, task=task, complete=True, task_instance=ti) @@ -236,7 +259,7 @@ def get_openlineage_facets_on_start(self): dagrun = MagicMock() task = FakeSupportedOperator(task_id="test_task_extractor") ti = MagicMock() - hook_lineage_collector.add_input_dataset(None, uri="s3://bucket/input_key") + hook_lineage_collector.add_input_asset(None, uri="s3://bucket/input_key") extractor_manager = ExtractorManager() metadata = extractor_manager.extract_metadata(dagrun=dagrun, task=task, complete=True, task_instance=ti) @@ -269,7 +292,7 @@ def use_read(): ti.run() - datasets = hook_lineage_collector.collected_datasets + datasets = hook_lineage_collector.collected_assets assert 
len(datasets.outputs) == 1 - assert datasets.outputs[0].dataset == Dataset(uri=path) + assert datasets.outputs[0].asset == Asset(uri=path) diff --git a/tests/providers/postgres/assets/__init__.py b/tests/providers/postgres/assets/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/postgres/assets/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/postgres/datasets/test_postgres.py b/tests/providers/postgres/assets/test_postgres.py similarity index 97% rename from tests/providers/postgres/datasets/test_postgres.py rename to tests/providers/postgres/assets/test_postgres.py index 40d6bf11d235d..82c64759a290a 100644 --- a/tests/providers/postgres/datasets/test_postgres.py +++ b/tests/providers/postgres/assets/test_postgres.py @@ -21,7 +21,7 @@ import pytest -from airflow.providers.postgres.datasets.postgres import sanitize_uri +from airflow.providers.postgres.assets.postgres import sanitize_uri @pytest.mark.parametrize( diff --git a/tests/providers/trino/assets/__init__.py b/tests/providers/trino/assets/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/trino/assets/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/tests/providers/trino/datasets/test_trino.py b/tests/providers/trino/assets/test_trino.py similarity index 97% rename from tests/providers/trino/datasets/test_trino.py rename to tests/providers/trino/assets/test_trino.py index 12cacd4eb0cf2..4ebea16d9fc6b 100644 --- a/tests/providers/trino/datasets/test_trino.py +++ b/tests/providers/trino/assets/test_trino.py @@ -21,7 +21,7 @@ import pytest -from airflow.providers.trino.datasets.trino import sanitize_uri +from airflow.providers.trino.assets.trino import sanitize_uri @pytest.mark.parametrize( diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 758c7f496ed93..6910514776afe 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -42,7 +42,7 @@ from kubernetes.client import models as k8s import airflow -from airflow.datasets import Dataset +from airflow.assets import Asset from airflow.decorators import teardown from airflow.decorators.base import DecoratedOperator from airflow.exceptions import ( @@ -1656,16 +1656,16 @@ class DerivedSensor(ExternalTaskSensor): ] @pytest.mark.db_test - def test_dag_deps_datasets_with_duplicate_dataset(self): + def test_dag_deps_assets_with_duplicate_asset(self): """ - Check that dag_dependencies node is populated correctly for a DAG with duplicate datasets. + Check that dag_dependencies node is populated correctly for a DAG with duplicate assets. """ from airflow.sensors.external_task import ExternalTaskSensor - d1 = Dataset("d1") - d2 = Dataset("d2") - d3 = Dataset("d3") - d4 = Dataset("d4") + d1 = Asset("d1") + d2 = Asset("d2") + d3 = Asset("d3") + d4 = Asset("d4") execution_date = datetime(2020, 1, 1) with DAG(dag_id="test", start_date=execution_date, schedule=[d1, d1, d1, d1, d1]) as dag: ExternalTaskSensor( @@ -1673,13 +1673,13 @@ def test_dag_deps_datasets_with_duplicate_dataset(self): external_dag_id="external_dag_id", mode="reschedule", ) - BashOperator(task_id="dataset_writer", bash_command="echo hello", outlets=[d2, d2, d2, d3]) + BashOperator(task_id="asset_writer", bash_command="echo hello", outlets=[d2, d2, d2, d3]) @dag.task(outlets=[d4]) - def other_dataset_writer(x): + def other_asset_writer(x): pass - other_dataset_writer.expand(x=[1, 2]) + other_asset_writer.expand(x=[1, 2]) dag = SerializedDAG.to_dict(dag) actual = sorted(dag["dag"]["dag_dependencies"], key=lambda x: tuple(x.values())) @@ -1687,8 +1687,8 @@ def other_dataset_writer(x): [ { "source": "test", - "target": "dataset", - "dependency_type": "dataset", + "target": "asset", + "dependency_type": "asset", "dependency_id": "d4", }, { @@ -1699,44 +1699,44 @@ def other_dataset_writer(x): }, { "source": "test", - "target": "dataset", - "dependency_type": "dataset", + "target": "asset", + "dependency_type": "asset", "dependency_id": "d3", }, { "source": "test", - "target": "dataset", - "dependency_type": "dataset", + "target": "asset", + "dependency_type": "asset", "dependency_id": "d2", }, { - "source": "dataset", + "source": "asset", "target": "test", - "dependency_type": "dataset", + "dependency_type": "asset", "dependency_id": "d1", }, { "dependency_id": "d1", - "dependency_type": "dataset", - "source": "dataset", + "dependency_type": "asset", + "source": "asset", "target": "test", }, { "dependency_id": "d1", - "dependency_type": "dataset", - "source": "dataset", + "dependency_type": "asset", + "source": "asset", "target": "test", }, { "dependency_id": "d1", - "dependency_type": "dataset", - "source": 
"dataset", + "dependency_type": "asset", + "source": "asset", "target": "test", }, { "dependency_id": "d1", - "dependency_type": "dataset", - "source": "dataset", + "dependency_type": "asset", + "source": "asset", "target": "test", }, ], @@ -1745,16 +1745,16 @@ def other_dataset_writer(x): assert actual == expected @pytest.mark.db_test - def test_dag_deps_datasets(self): + def test_dag_deps_assets(self): """ - Check that dag_dependencies node is populated correctly for a DAG with datasets. + Check that dag_dependencies node is populated correctly for a DAG with assets. """ from airflow.sensors.external_task import ExternalTaskSensor - d1 = Dataset("d1") - d2 = Dataset("d2") - d3 = Dataset("d3") - d4 = Dataset("d4") + d1 = Asset("d1") + d2 = Asset("d2") + d3 = Asset("d3") + d4 = Asset("d4") execution_date = datetime(2020, 1, 1) with DAG(dag_id="test", start_date=execution_date, schedule=[d1]) as dag: ExternalTaskSensor( @@ -1762,13 +1762,13 @@ def test_dag_deps_datasets(self): external_dag_id="external_dag_id", mode="reschedule", ) - BashOperator(task_id="dataset_writer", bash_command="echo hello", outlets=[d2, d3]) + BashOperator(task_id="asset_writer", bash_command="echo hello", outlets=[d2, d3]) @dag.task(outlets=[d4]) - def other_dataset_writer(x): + def other_asset_writer(x): pass - other_dataset_writer.expand(x=[1, 2]) + other_asset_writer.expand(x=[1, 2]) dag = SerializedDAG.to_dict(dag) actual = sorted(dag["dag"]["dag_dependencies"], key=lambda x: tuple(x.values())) @@ -1776,8 +1776,8 @@ def other_dataset_writer(x): [ { "source": "test", - "target": "dataset", - "dependency_type": "dataset", + "target": "asset", + "dependency_type": "asset", "dependency_id": "d4", }, { @@ -1788,20 +1788,20 @@ def other_dataset_writer(x): }, { "source": "test", - "target": "dataset", - "dependency_type": "dataset", + "target": "asset", + "dependency_type": "asset", "dependency_id": "d3", }, { "source": "test", - "target": "dataset", - "dependency_type": "dataset", + "target": "asset", + "dependency_type": "asset", "dependency_id": "d2", }, { - "source": "dataset", + "source": "asset", "target": "test", - "dependency_type": "dataset", + "dependency_type": "asset", "dependency_id": "d1", }, ], diff --git a/tests/serialization/test_pydantic_models.py b/tests/serialization/test_pydantic_models.py index d29cba23d86e7..55c61ea220263 100644 --- a/tests/serialization/test_pydantic_models.py +++ b/tests/serialization/test_pydantic_models.py @@ -27,16 +27,16 @@ from airflow.jobs.job import Job from airflow.jobs.local_task_job_runner import LocalTaskJobRunner from airflow.models import MappedOperator -from airflow.models.dag import DAG, DagModel, create_timetable -from airflow.models.dataset import ( - DagScheduleDatasetReference, - DatasetEvent, - DatasetModel, - TaskOutletDatasetReference, +from airflow.models.asset import ( + AssetEvent, + AssetModel, + DagScheduleAssetReference, + TaskOutletAssetReference, ) +from airflow.models.dag import DAG, DagModel, create_timetable +from airflow.serialization.pydantic.asset import AssetEventPydantic from airflow.serialization.pydantic.dag import DagModelPydantic from airflow.serialization.pydantic.dag_run import DagRunPydantic -from airflow.serialization.pydantic.dataset import DatasetEventPydantic from airflow.serialization.pydantic.job import JobPydantic from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic from airflow.serialization.serialized_objects import BaseSerialization @@ -222,15 +222,15 @@ def 
test_serializing_pydantic_local_task_job(session, create_task_instance): @pytest.mark.skip_if_database_isolation_mode @pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled") def test_serializing_pydantic_dataset_event(session, create_task_instance, create_dummy_dag): - ds1 = DatasetModel(id=1, uri="one", extra={"foo": "bar"}) - ds2 = DatasetModel(id=2, uri="two") + ds1 = AssetModel(id=1, uri="one", extra={"foo": "bar"}) + ds2 = AssetModel(id=2, uri="two") session.add_all([ds1, ds2]) session.commit() # it's easier to fake a manual run here dag, task1 = create_dummy_dag( - dag_id="test_triggering_dataset_events", + dag_id="test_triggering_asset_events", schedule=None, start_date=DEFAULT_DATE, task_id="test_context", @@ -250,29 +250,29 @@ def test_serializing_pydantic_dataset_event(session, create_task_instance, creat data_interval=(execution_date, execution_date), **triggered_by_kwargs, ) - ds1_event = DatasetEvent(dataset_id=1) - ds2_event_1 = DatasetEvent(dataset_id=2) - ds2_event_2 = DatasetEvent(dataset_id=2) - - dag_ds_ref = DagScheduleDatasetReference(dag_id=dag.dag_id) - session.add(dag_ds_ref) - dag_ds_ref.dataset = ds1 - task_ds_ref = TaskOutletDatasetReference(task_id=task1.task_id, dag_id=dag.dag_id) + asset1_event = AssetEvent(dataset_id=1) + asset2_event_1 = AssetEvent(dataset_id=2) + asset2_event_2 = AssetEvent(dataset_id=2) + + dag_asset_ref = DagScheduleAssetReference(dag_id=dag.dag_id) + session.add(dag_asset_ref) + dag_asset_ref.dataset = ds1 + task_ds_ref = TaskOutletAssetReference(task_id=task1.task_id, dag_id=dag.dag_id) session.add(task_ds_ref) task_ds_ref.dataset = ds1 - dr.consumed_dataset_events.append(ds1_event) - dr.consumed_dataset_events.append(ds2_event_1) - dr.consumed_dataset_events.append(ds2_event_2) + dr.consumed_dataset_events.append(asset1_event) + dr.consumed_dataset_events.append(asset2_event_1) + dr.consumed_dataset_events.append(asset2_event_2) session.commit() TracebackSessionForTests.set_allow_db_access(session, False) - print(ds2_event_2.dataset.consuming_dags) - pydantic_dse1 = DatasetEventPydantic.model_validate(ds1_event) + print(asset2_event_2.dataset.consuming_dags) + pydantic_dse1 = AssetEventPydantic.model_validate(asset1_event) json_string1 = pydantic_dse1.model_dump_json() print(json_string1) - pydantic_dse2 = DatasetEventPydantic.model_validate(ds2_event_1) + pydantic_dse2 = AssetEventPydantic.model_validate(asset2_event_1) json_string2 = pydantic_dse2.model_dump_json() print(json_string2) @@ -280,13 +280,13 @@ def test_serializing_pydantic_dataset_event(session, create_task_instance, creat json_string_dr = pydantic_dag_run.model_dump_json() print(json_string_dr) - deserialized_model1 = DatasetEventPydantic.model_validate_json(json_string1) + deserialized_model1 = AssetEventPydantic.model_validate_json(json_string1) assert deserialized_model1.dataset.id == 1 assert deserialized_model1.dataset.uri == "one" assert len(deserialized_model1.dataset.consuming_dags) == 1 assert len(deserialized_model1.dataset.producing_tasks) == 1 - deserialized_model2 = DatasetEventPydantic.model_validate_json(json_string2) + deserialized_model2 = AssetEventPydantic.model_validate_json(json_string2) assert deserialized_model2.dataset.id == 2 assert deserialized_model2.dataset.uri == "two" assert len(deserialized_model2.dataset.consuming_dags) == 0 diff --git a/tests/serialization/test_serde.py b/tests/serialization/test_serde.py index cc50e772248a7..a36013d20cfa7 100644 --- a/tests/serialization/test_serde.py +++ b/tests/serialization/test_serde.py 
@@ -28,7 +28,7 @@ import pytest from pydantic import BaseModel -from airflow.datasets import Dataset +from airflow.assets import Asset from airflow.serialization.serde import ( CLASSNAME, DATA, @@ -336,7 +336,7 @@ def test_backwards_compat(self): """ uri = "s3://does/not/exist" data = { - "__type": "airflow.datasets.Dataset", + "__type": "airflow.assets.Asset", "__source": None, "__var": { "__var": { @@ -364,7 +364,7 @@ def test_backwards_compat_wrapped(self): assert e["extra"] == {"hi": "bye"} def test_encode_dataset(self): - dataset = Dataset("mytest://dataset") + dataset = Asset("mytest://dataset") obj = deserialize(serialize(dataset)) assert dataset.uri == obj.uri diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py index 104062ff6c318..0bc8a67ef879a 100644 --- a/tests/serialization/test_serialized_objects.py +++ b/tests/serialization/test_serialized_objects.py @@ -31,7 +31,7 @@ from pendulum.tz.timezone import Timezone from pydantic import BaseModel -from airflow.datasets import Dataset, DatasetAlias, DatasetAliasEvent +from airflow.assets import Asset, AssetAlias, AssetAliasEvent from airflow.exceptions import ( AirflowException, AirflowFailException, @@ -40,10 +40,10 @@ TaskDeferred, ) from airflow.jobs.job import Job +from airflow.models.asset import AssetEvent from airflow.models.connection import Connection from airflow.models.dag import DAG, DagModel, DagTag from airflow.models.dagrun import DagRun -from airflow.models.dataset import DatasetEvent from airflow.models.param import Param from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance from airflow.models.tasklog import LogTemplate @@ -51,9 +51,9 @@ from airflow.operators.empty import EmptyOperator from airflow.operators.python import PythonOperator from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding +from airflow.serialization.pydantic.asset import AssetEventPydantic, AssetPydantic from airflow.serialization.pydantic.dag import DagModelPydantic, DagTagPydantic from airflow.serialization.pydantic.dag_run import DagRunPydantic -from airflow.serialization.pydantic.dataset import DatasetEventPydantic, DatasetPydantic from airflow.serialization.pydantic.job import JobPydantic from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic from airflow.serialization.pydantic.tasklog import LogTemplatePydantic @@ -163,7 +163,7 @@ def equal_exception(a: AirflowException, b: AirflowException) -> bool: def equal_outlet_event_accessor(a: OutletEventAccessor, b: OutletEventAccessor) -> bool: - return a.raw_key == b.raw_key and a.extra == b.extra and a.dataset_alias_events == b.dataset_alias_events + return a.raw_key == b.raw_key and a.extra == b.extra and a.asset_alias_events == b.asset_alias_events class MockLazySelectSequence(LazySelectSequence): @@ -232,7 +232,7 @@ def __len__(self) -> int: None, ), (MockLazySelectSequence(), None, lambda a, b: len(a) == len(b) and isinstance(b, list)), - (Dataset(uri="test"), DAT.DATASET, equals), + (Asset(uri="test"), DAT.ASSET, equals), (SimpleTaskInstance.from_ti(ti=TI), DAT.SIMPLE_TASK_INSTANCE, equals), ( Connection(conn_id="TEST_ID", uri="mysql://"), @@ -240,24 +240,24 @@ def __len__(self) -> int: lambda a, b: a.get_uri() == b.get_uri(), ), ( - OutletEventAccessor(raw_key=Dataset(uri="test"), extra={"key": "value"}, dataset_alias_events=[]), - DAT.DATASET_EVENT_ACCESSOR, + OutletEventAccessor(raw_key=Asset(uri="test"), extra={"key": "value"}, asset_alias_events=[]), + 
DAT.ASSET_EVENT_ACCESSOR, equal_outlet_event_accessor, ), ( OutletEventAccessor( - raw_key=DatasetAlias(name="test_alias"), + raw_key=AssetAlias(name="test_alias"), extra={"key": "value"}, - dataset_alias_events=[ - DatasetAliasEvent(source_alias_name="test_alias", dest_dataset_uri="test_uri", extra={}) + asset_alias_events=[ + AssetAliasEvent(source_alias_name="test_alias", dest_asset_uri="test_uri", extra={}) ], ), - DAT.DATASET_EVENT_ACCESSOR, + DAT.ASSET_EVENT_ACCESSOR, equal_outlet_event_accessor, ), ( - OutletEventAccessor(raw_key="test", extra={"key": "value"}, dataset_alias_events=[]), - DAT.DATASET_EVENT_ACCESSOR, + OutletEventAccessor(raw_key="test", extra={"key": "value"}, asset_alias_events=[]), + DAT.ASSET_EVENT_ACCESSOR, equal_outlet_event_accessor, ), ( @@ -326,8 +326,8 @@ def test_backcompat_deserialize_connection(conn_uri): id=1, filename="test_file", elasticsearch_id="test_id", created_at=datetime.now() ), DagTagPydantic: DagTag(), - DatasetPydantic: Dataset("uri", {}), - DatasetEventPydantic: DatasetEvent(), + AssetPydantic: Asset("uri", {}), + AssetEventPydantic: AssetEvent(), } @@ -354,14 +354,14 @@ def test_backcompat_deserialize_connection(conn_uri): lambda a, b: equal_time(a.execution_date, b.execution_date) and equal_time(a.start_date, b.start_date), ), - # DataSet is already serialized by non-Pydantic serialization. Is DatasetPydantic needed then? + # Asset is already serialized by non-Pydantic serialization. Is AssetPydantic needed then? # ( - # Dataset( + # Asset( # uri="foo://bar", # extra={"foo": "bar"}, # ), - # DatasetPydantic, - # DAT.DATA_SET, + # AssetPydantic, + # DAT.ASSET, # lambda a, b: a.uri == b.uri and a.extra == b.extra, # ), ( @@ -429,12 +429,12 @@ def test_all_pydantic_models_round_trip(): continue classes.add(obj) exclusion_list = { - "DatasetPydantic", + "AssetPydantic", "DagTagPydantic", - "DagScheduleDatasetReferencePydantic", - "TaskOutletDatasetReferencePydantic", + "DagScheduleAssetReferencePydantic", + "TaskOutletAssetReferencePydantic", "DagOwnerAttributesPydantic", - "DatasetEventPydantic", + "AssetEventPydantic", "TriggerPydantic", } for c in sorted(classes, key=str): @@ -490,7 +490,7 @@ def test_serialized_mapped_operator_unmap(dag_maker): assert serialized_unmapped_task.dag is serialized_dag -def test_ser_of_dataset_event_accessor(): +def test_ser_of_asset_event_accessor(): # todo: (Airflow 3.0) we should force reserialization on upgrade d = OutletEventAccessors() d["hi"].extra = "blah1" # todo: this should maybe be forbidden? i.e. can extra be any json or just dict? 
diff --git a/tests/system/providers/microsoft/azure/example_msfabric.py b/tests/system/providers/microsoft/azure/example_msfabric.py index 7d62a49e0bc31..5f8b0657c4019 100644 --- a/tests/system/providers/microsoft/azure/example_msfabric.py +++ b/tests/system/providers/microsoft/azure/example_msfabric.py @@ -19,7 +19,7 @@ from datetime import datetime from airflow import models -from airflow.datasets import Dataset +from airflow.assets import Asset from airflow.providers.microsoft.azure.operators.msgraph import MSGraphAsyncOperator DAG_ID = "example_msfabric" @@ -44,7 +44,7 @@ query_parameters={"jobType": "Pipeline"}, dag=dag, outlets=[ - Dataset( + Asset( "workspaces/e90b2873-4812-4dfb-9246-593638165644/items/65448530-e5ec-4aeb-a97e-7cebf5d67c18/jobs/instances?jobType=Pipeline" ) ], diff --git a/tests/system/providers/papermill/input_notebook.ipynb b/tests/system/providers/papermill/input_notebook.ipynb index 6c1d53a5a780c..e985432160a8b 100644 --- a/tests/system/providers/papermill/input_notebook.ipynb +++ b/tests/system/providers/papermill/input_notebook.ipynb @@ -36,6 +36,8 @@ "metadata": {}, "outputs": [], "source": [ + "from __future__ import annotations\n", + "\n", "import scrapbook as sb" ] }, diff --git a/tests/test_utils/compat.py b/tests/test_utils/compat.py index 5daf429cf641f..ca1d7e9c77dfa 100644 --- a/tests/test_utils/compat.py +++ b/tests/test_utils/compat.py @@ -55,6 +55,39 @@ from airflow.models.baseoperator import BaseOperatorLink +if TYPE_CHECKING: + from airflow.models.asset import ( + AssetAliasModel, + AssetDagRunQueue, + AssetEvent, + AssetModel, + DagScheduleAssetReference, + TaskOutletAssetReference, + ) +else: + try: + from airflow.models.asset import ( + AssetAliasModel, + AssetDagRunQueue, + AssetEvent, + AssetModel, + DagScheduleAssetReference, + TaskOutletAssetReference, + ) + except ModuleNotFoundError: + # dataset is renamed to asset since Airflow 3.0 + from airflow.models.dataset import ( + DagScheduleDatasetReference as DagScheduleAssetReference, + DatasetDagRunQueue as AssetDagRunQueue, + DatasetEvent as AssetEvent, + DatasetModel as AssetModel, + TaskOutletDatasetReference as TaskOutletAssetReference, + ) + + if AIRFLOW_V_2_10_PLUS: + from airflow.models.dataset import DatasetAliasModel as AssetAliasModel + + def deserialize_operator(serialized_operator: dict[str, Any]) -> Operator: if AIRFLOW_V_2_10_PLUS: # In airflow 2.10+ we can deserialize operator using regular deserialize method. 
diff --git a/tests/test_utils/db.py b/tests/test_utils/db.py index bd56ed9175cc4..a5dd94e2d009d 100644 --- a/tests/test_utils/db.py +++ b/tests/test_utils/db.py @@ -38,18 +38,19 @@ from airflow.models.dag import DagOwnerAttributes from airflow.models.dagcode import DagCode from airflow.models.dagwarning import DagWarning -from airflow.models.dataset import ( - DagScheduleDatasetReference, - DatasetDagRunQueue, - DatasetEvent, - DatasetModel, - TaskOutletDatasetReference, -) from airflow.models.serialized_dag import SerializedDagModel from airflow.security.permissions import RESOURCE_DAG_PREFIX from airflow.utils.db import add_default_pool_if_not_exists, create_default_connections, reflect_tables from airflow.utils.session import create_session -from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS, ParseImportError +from tests.test_utils.compat import ( + AIRFLOW_V_2_10_PLUS, + AssetDagRunQueue, + AssetEvent, + AssetModel, + DagScheduleAssetReference, + ParseImportError, + TaskOutletAssetReference, +) def clear_db_runs(): @@ -74,17 +75,17 @@ def clear_db_backfills(): session.query(Backfill).delete() -def clear_db_datasets(): +def clear_db_assets(): with create_session() as session: - session.query(DatasetEvent).delete() - session.query(DatasetModel).delete() - session.query(DatasetDagRunQueue).delete() - session.query(DagScheduleDatasetReference).delete() - session.query(TaskOutletDatasetReference).delete() + session.query(AssetEvent).delete() + session.query(AssetModel).delete() + session.query(AssetDagRunQueue).delete() + session.query(DagScheduleAssetReference).delete() + session.query(TaskOutletAssetReference).delete() if AIRFLOW_V_2_10_PLUS: - from airflow.models.dataset import DatasetAliasModel + from tests.test_utils.compat import AssetAliasModel - session.query(DatasetAliasModel).delete() + session.query(AssetAliasModel).delete() def clear_db_dags(): @@ -231,7 +232,7 @@ def clear_dag_specific_permissions(): def clear_all(): clear_db_runs() - clear_db_datasets() + clear_db_assets() clear_db_dags() clear_db_serialized_dags() clear_db_sla_miss() diff --git a/tests/timetables/test_datasets_timetable.py b/tests/timetables/test_assets_timetable.py similarity index 57% rename from tests/timetables/test_datasets_timetable.py rename to tests/timetables/test_assets_timetable.py index b456b9bf5dc9c..f2105891c7298 100644 --- a/tests/timetables/test_datasets_timetable.py +++ b/tests/timetables/test_assets_timetable.py @@ -23,11 +23,11 @@ import pytest from pendulum import DateTime -from airflow.datasets import Dataset, DatasetAlias -from airflow.models.dataset import DatasetAliasModel, DatasetEvent, DatasetModel +from airflow.assets import Asset, AssetAlias +from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel +from airflow.timetables.assets import AssetOrTimeSchedule from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable -from airflow.timetables.datasets import DatasetOrTimeSchedule -from airflow.timetables.simple import DatasetTriggeredTimetable +from airflow.timetables.simple import AssetTriggeredTimetable from airflow.utils.types import DagRunType if TYPE_CHECKING: @@ -103,45 +103,45 @@ def test_timetable() -> MockTimetable: @pytest.fixture -def test_datasets() -> list[Dataset]: - """Pytest fixture for creating a list of Dataset objects.""" - return [Dataset("test_dataset")] +def test_assets() -> list[Asset]: + """Pytest fixture for creating a list of Asset objects.""" + return [Asset("test_asset")] @pytest.fixture -def 
dataset_timetable(test_timetable: MockTimetable, test_datasets: list[Dataset]) -> DatasetOrTimeSchedule: +def asset_timetable(test_timetable: MockTimetable, test_assets: list[Asset]) -> AssetOrTimeSchedule: """ - Pytest fixture for creating a DatasetTimetable object. + Pytest fixture for creating a AssetOrTimeSchedule object. :param test_timetable: The test timetable instance. - :param test_datasets: A list of Dataset instances. + :param test_assets: A list of Asset instances. """ - return DatasetOrTimeSchedule(timetable=test_timetable, datasets=test_datasets) + return AssetOrTimeSchedule(timetable=test_timetable, assets=test_assets) -def test_serialization(dataset_timetable: DatasetOrTimeSchedule, monkeypatch: Any) -> None: +def test_serialization(asset_timetable: AssetOrTimeSchedule, monkeypatch: Any) -> None: """ - Tests the serialization method of DatasetTimetable. + Tests the serialization method of AssetOrTimeSchedule. - :param dataset_timetable: The DatasetTimetable instance to test. + :param asset_timetable: The AssetOrTimeSchedule instance to test. :param monkeypatch: The monkeypatch fixture from pytest. """ monkeypatch.setattr( "airflow.serialization.serialized_objects.encode_timetable", lambda x: "mock_serialized_timetable" ) - serialized = dataset_timetable.serialize() + serialized = asset_timetable.serialize() assert serialized == { "timetable": "mock_serialized_timetable", - "dataset_condition": { - "__type": "dataset_all", - "objects": [{"__type": "dataset", "uri": "test_dataset", "extra": {}}], + "asset_condition": { + "__type": "asset_all", + "objects": [{"__type": "asset", "uri": "test_asset", "extra": {}}], }, } def test_deserialization(monkeypatch: Any) -> None: """ - Tests the deserialization method of DatasetTimetable. + Tests the deserialization method of AssetOrTimeSchedule. :param monkeypatch: The monkeypatch fixture from pytest. """ @@ -150,55 +150,55 @@ def test_deserialization(monkeypatch: Any) -> None: ) mock_serialized_data = { "timetable": "mock_serialized_timetable", - "dataset_condition": { - "__type": "dataset_all", - "objects": [{"__type": "dataset", "uri": "test_dataset", "extra": None}], + "asset_condition": { + "__type": "asset_all", + "objects": [{"__type": "asset", "uri": "test_asset", "extra": None}], }, } - deserialized = DatasetOrTimeSchedule.deserialize(mock_serialized_data) - assert isinstance(deserialized, DatasetOrTimeSchedule) + deserialized = AssetOrTimeSchedule.deserialize(mock_serialized_data) + assert isinstance(deserialized, AssetOrTimeSchedule) -def test_infer_manual_data_interval(dataset_timetable: DatasetOrTimeSchedule) -> None: +def test_infer_manual_data_interval(asset_timetable: AssetOrTimeSchedule) -> None: """ - Tests the infer_manual_data_interval method of DatasetTimetable. + Tests the infer_manual_data_interval method of AssetOrTimeSchedule. - :param dataset_timetable: The DatasetTimetable instance to test. + :param asset_timetable: The AssetOrTimeSchedule instance to test. """ run_after = DateTime.now() - result = dataset_timetable.infer_manual_data_interval(run_after=run_after) + result = asset_timetable.infer_manual_data_interval(run_after=run_after) assert isinstance(result, DataInterval) -def test_next_dagrun_info(dataset_timetable: DatasetOrTimeSchedule) -> None: +def test_next_dagrun_info(asset_timetable: AssetOrTimeSchedule) -> None: """ - Tests the next_dagrun_info method of DatasetTimetable. + Tests the next_dagrun_info method of AssetOrTimeSchedule. - :param dataset_timetable: The DatasetTimetable instance to test. 
+ :param asset_timetable: The AssetOrTimeSchedule instance to test. """ last_interval = DataInterval.exact(DateTime.now()) restriction = TimeRestriction(earliest=DateTime.now(), latest=None, catchup=True) - result = dataset_timetable.next_dagrun_info( + result = asset_timetable.next_dagrun_info( last_automated_data_interval=last_interval, restriction=restriction ) assert result is None or isinstance(result, DagRunInfo) -def test_generate_run_id(dataset_timetable: DatasetOrTimeSchedule) -> None: +def test_generate_run_id(asset_timetable: AssetOrTimeSchedule) -> None: """ - Tests the generate_run_id method of DatasetTimetable. + Tests the generate_run_id method of AssetOrTimeSchedule. - :param dataset_timetable: The DatasetTimetable instance to test. + :param asset_timetable: The AssetOrTimeSchedule instance to test. """ - run_id = dataset_timetable.generate_run_id( + run_id = asset_timetable.generate_run_id( run_type=DagRunType.MANUAL, extra_args="test", logical_date=DateTime.now(), data_interval=None ) assert isinstance(run_id, str) @pytest.fixture -def dataset_events(mocker) -> list[DatasetEvent]: - """Pytest fixture for creating mock DatasetEvent objects.""" +def asset_events(mocker) -> list[AssetEvent]: + """Pytest fixture for creating mock AssetEvent objects.""" now = DateTime.now() earlier = now.subtract(days=1) later = now.add(days=1) @@ -212,9 +212,9 @@ def dataset_events(mocker) -> list[DatasetEvent]: mock_dag_run_later.data_interval_start = now mock_dag_run_later.data_interval_end = later - # Create DatasetEvent objects with mock source_dag_run - event_earlier = DatasetEvent(timestamp=earlier, dataset_id=1) - event_later = DatasetEvent(timestamp=later, dataset_id=1) + # Create AssetEvent objects with mock source_dag_run + event_earlier = AssetEvent(timestamp=earlier, dataset_id=1) + event_later = AssetEvent(timestamp=later, dataset_id=1) # Use mocker to set the source_dag_run attribute to avoid SQLAlchemy's instrumentation mocker.patch.object(event_earlier, "source_dag_run", new=mock_dag_run_earlier) @@ -224,54 +224,50 @@ def dataset_events(mocker) -> list[DatasetEvent]: def test_data_interval_for_events( - dataset_timetable: DatasetOrTimeSchedule, dataset_events: list[DatasetEvent] + asset_timetable: AssetOrTimeSchedule, asset_events: list[AssetEvent] ) -> None: """ - Tests the data_interval_for_events method of DatasetTimetable. + Tests the data_interval_for_events method of AssetOrTimeSchedule. - :param dataset_timetable: The DatasetTimetable instance to test. - :param dataset_events: A list of mock DatasetEvent instances. + :param asset_timetable: The AssetOrTimeSchedule instance to test. + :param asset_events: A list of mock AssetEvent instances. """ - data_interval = dataset_timetable.data_interval_for_events( - logical_date=DateTime.now(), events=dataset_events - ) + data_interval = asset_timetable.data_interval_for_events(logical_date=DateTime.now(), events=asset_events) assert data_interval.start == min( - event.timestamp for event in dataset_events + event.timestamp for event in asset_events ), "Data interval start does not match" assert data_interval.end == max( - event.timestamp for event in dataset_events + event.timestamp for event in asset_events ), "Data interval end does not match" -def test_run_ordering_inheritance(dataset_timetable: DatasetOrTimeSchedule) -> None: +def test_run_ordering_inheritance(asset_timetable: AssetOrTimeSchedule) -> None: """ - Tests that DatasetOrTimeSchedule inherits run_ordering from its parent class correctly. 
+ Tests that AssetOrTimeSchedule inherits run_ordering from its parent class correctly. - :param dataset_timetable: The DatasetTimetable instance to test. + :param asset_timetable: The AssetOrTimeSchedule instance to test. """ assert hasattr( - dataset_timetable, "run_ordering" - ), "DatasetOrTimeSchedule should have 'run_ordering' attribute" - parent_run_ordering = getattr(DatasetTriggeredTimetable, "run_ordering", None) - assert ( - dataset_timetable.run_ordering == parent_run_ordering - ), "run_ordering does not match the parent class" + asset_timetable, "run_ordering" + ), "AssetOrTimeSchedule should have 'run_ordering' attribute" + parent_run_ordering = getattr(AssetTriggeredTimetable, "run_ordering", None) + assert asset_timetable.run_ordering == parent_run_ordering, "run_ordering does not match the parent class" @pytest.mark.db_test def test_summary(session: Session) -> None: - dataset_model = DatasetModel(uri="test_dataset") - dataset_alias_model = DatasetAliasModel(name="test_dataset_alias") - session.add_all([dataset_model, dataset_alias_model]) + asset_model = AssetModel(uri="test_asset") + asset_alias_model = AssetAliasModel(name="test_asset_alias") + session.add_all([asset_model, asset_alias_model]) session.commit() - dataset_alias = DatasetAlias("test_dataset_alias") - table = DatasetTriggeredTimetable(dataset_alias) - assert table.summary == "Unresolved DatasetAlias" + asset_alias = AssetAlias("test_asset_alias") + table = AssetTriggeredTimetable(asset_alias) + assert table.summary == "Unresolved AssetAlias" - dataset_alias_model.datasets.append(dataset_model) - session.add(dataset_alias_model) + asset_alias_model.datasets.append(asset_model) + session.add(asset_alias_model) session.commit() - table = DatasetTriggeredTimetable(dataset_alias) - assert table.summary == "Dataset" + table = AssetTriggeredTimetable(asset_alias) + assert table.summary == "Asset" diff --git a/tests/utils/test_context.py b/tests/utils/test_context.py index 0f4f80f36504c..5d2f7543b6299 100644 --- a/tests/utils/test_context.py +++ b/tests/utils/test_context.py @@ -20,55 +20,55 @@ import pytest -from airflow.datasets import Dataset, DatasetAlias, DatasetAliasEvent -from airflow.models.dataset import DatasetAliasModel, DatasetModel +from airflow.assets import Asset, AssetAlias, AssetAliasEvent +from airflow.models.asset import AssetAliasModel, AssetModel from airflow.utils.context import OutletEventAccessor, OutletEventAccessors class TestOutletEventAccessor: @pytest.mark.parametrize( - "raw_key, dataset_alias_events", + "raw_key, asset_alias_events", ( ( - DatasetAlias("test_alias"), - [DatasetAliasEvent(source_alias_name="test_alias", dest_dataset_uri="test_uri", extra={})], + AssetAlias("test_alias"), + [AssetAliasEvent(source_alias_name="test_alias", dest_asset_uri="test_uri", extra={})], ), - (Dataset("test_uri"), []), + (Asset("test_uri"), []), ), ) - def test_add(self, raw_key, dataset_alias_events): + def test_add(self, raw_key, asset_alias_events): outlet_event_accessor = OutletEventAccessor(raw_key=raw_key, extra={}) - outlet_event_accessor.add(Dataset("test_uri")) - assert outlet_event_accessor.dataset_alias_events == dataset_alias_events + outlet_event_accessor.add(Asset("test_uri")) + assert outlet_event_accessor.asset_alias_events == asset_alias_events @pytest.mark.db_test @pytest.mark.parametrize( - "raw_key, dataset_alias_events", + "raw_key, asset_alias_events", ( ( - DatasetAlias("test_alias"), - [DatasetAliasEvent(source_alias_name="test_alias", dest_dataset_uri="test_uri", 
extra={})], + AssetAlias("test_alias"), + [AssetAliasEvent(source_alias_name="test_alias", dest_asset_uri="test_uri", extra={})], ), ( "test_alias", - [DatasetAliasEvent(source_alias_name="test_alias", dest_dataset_uri="test_uri", extra={})], + [AssetAliasEvent(source_alias_name="test_alias", dest_asset_uri="test_uri", extra={})], ), - (Dataset("test_uri"), []), + (Asset("test_uri"), []), ), ) - def test_add_with_db(self, raw_key, dataset_alias_events, session): - dsm = DatasetModel(uri="test_uri") - dsam = DatasetAliasModel(name="test_alias") - session.add_all([dsm, dsam]) + def test_add_with_db(self, raw_key, asset_alias_events, session): + asm = AssetModel(uri="test_uri") + aam = AssetAliasModel(name="test_alias") + session.add_all([asm, aam]) session.flush() outlet_event_accessor = OutletEventAccessor(raw_key=raw_key, extra={"not": ""}) outlet_event_accessor.add("test_uri", extra={}) - assert outlet_event_accessor.dataset_alias_events == dataset_alias_events + assert outlet_event_accessor.asset_alias_events == asset_alias_events class TestOutletEventAccessors: - @pytest.mark.parametrize("key", ("test", Dataset("test"), DatasetAlias("test_alias"))) + @pytest.mark.parametrize("key", ("test", Asset("test"), AssetAlias("test_alias"))) def test____get_item___dict_key_not_exists(self, key): outlet_event_accessors = OutletEventAccessors() assert len(outlet_event_accessors) == 0 diff --git a/tests/utils/test_db_cleanup.py b/tests/utils/test_db_cleanup.py index 06e99523faa2f..0a8cd9c90c962 100644 --- a/tests/utils/test_db_cleanup.py +++ b/tests/utils/test_db_cleanup.py @@ -48,7 +48,7 @@ run_cleanup, ) from airflow.utils.session import create_session -from tests.test_utils.db import clear_db_dags, clear_db_datasets, clear_db_runs, drop_tables_with_prefix +from tests.test_utils.db import clear_db_assets, clear_db_dags, clear_db_runs, drop_tables_with_prefix pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] @@ -57,11 +57,11 @@ def clean_database(): """Fixture that cleans the database before and after every test.""" clear_db_runs() - clear_db_datasets() + clear_db_assets() clear_db_dags() yield # Test runs here clear_db_dags() - clear_db_datasets() + clear_db_assets() clear_db_runs() diff --git a/tests/utils/test_json.py b/tests/utils/test_json.py index 5e7e6eb1e5c56..5a58b5d790329 100644 --- a/tests/utils/test_json.py +++ b/tests/utils/test_json.py @@ -26,7 +26,7 @@ import pendulum import pytest -from airflow.datasets import Dataset +from airflow.assets import Asset from airflow.utils import json as utils_json @@ -85,11 +85,11 @@ def test_encode_raises(self): cls=utils_json.XComEncoder, ) - def test_encode_xcom_dataset(self): - dataset = Dataset("mytest://dataset") - s = json.dumps(dataset, cls=utils_json.XComEncoder) + def test_encode_xcom_asset(self): + asset = Asset("mytest://asset") + s = json.dumps(asset, cls=utils_json.XComEncoder) obj = json.loads(s, cls=utils_json.XComDecoder) - assert dataset.uri == obj.uri + assert asset.uri == obj.uri @pytest.mark.parametrize( "data", diff --git a/tests/www/test_auth.py b/tests/www/test_auth.py index 613812968d27f..40e731b060972 100644 --- a/tests/www/test_auth.py +++ b/tests/www/test_auth.py @@ -34,7 +34,7 @@ "decorator_name, is_authorized_method_name", [ ("has_access_configuration", "is_authorized_configuration"), - ("has_access_dataset", "is_authorized_dataset"), + ("has_access_asset", "is_authorized_asset"), ("has_access_view", "is_authorized_view"), ], ) diff --git a/tests/www/views/test_views_acl.py 
b/tests/www/views/test_views_acl.py index 053e0f339fb3f..139644f67a6da 100644 --- a/tests/www/views/test_views_acl.py +++ b/tests/www/views/test_views_acl.py @@ -264,22 +264,22 @@ def test_dag_autocomplete_success(client_all_dags): {"name": "airflow", "type": "owner", "dag_display_name": None}, { "dag_display_name": None, - "name": "dataset_alias_example_alias_consumer_with_no_taskflow", + "name": "asset_alias_example_alias_consumer_with_no_taskflow", "type": "dag", }, { "dag_display_name": None, - "name": "dataset_alias_example_alias_producer_with_no_taskflow", + "name": "asset_alias_example_alias_producer_with_no_taskflow", "type": "dag", }, { "dag_display_name": None, - "name": "dataset_s3_bucket_consumer_with_no_taskflow", + "name": "asset_s3_bucket_consumer_with_no_taskflow", "type": "dag", }, { "dag_display_name": None, - "name": "dataset_s3_bucket_producer_with_no_taskflow", + "name": "asset_s3_bucket_producer_with_no_taskflow", "type": "dag", }, { diff --git a/tests/www/views/test_views_dataset.py b/tests/www/views/test_views_dataset.py index 797ed40ba009e..3d3351bb6493a 100644 --- a/tests/www/views/test_views_dataset.py +++ b/tests/www/views/test_views_dataset.py @@ -21,11 +21,11 @@ import pytest from dateutil.tz import UTC -from airflow.datasets import Dataset -from airflow.models.dataset import DatasetEvent, DatasetModel +from airflow.assets import Asset +from airflow.models.asset import AssetEvent, AssetModel from airflow.operators.empty import EmptyOperator from tests.test_utils.asserts import assert_queries_count -from tests.test_utils.db import clear_db_datasets +from tests.test_utils.db import clear_db_assets pytestmark = pytest.mark.db_test @@ -33,23 +33,23 @@ class TestDatasetEndpoint: @pytest.fixture(autouse=True) def cleanup(self): - clear_db_datasets() + clear_db_assets() yield - clear_db_datasets() + clear_db_assets() class TestGetDatasets(TestDatasetEndpoint): def test_should_respond_200(self, admin_client, session): - datasets = [ - DatasetModel( + assets = [ + AssetModel( id=i, uri=f"s3://bucket/key/{i}", ) for i in [1, 2] ] - session.add_all(datasets) + session.add_all(assets) session.commit() - assert session.query(DatasetModel).count() == 2 + assert session.query(AssetModel).count() == 2 with assert_queries_count(10): response = admin_client.get("/object/datasets_summary") @@ -75,15 +75,15 @@ def test_should_respond_200(self, admin_client, session): } def test_order_by_raises_400_for_invalid_attr(self, admin_client, session): - datasets = [ - DatasetModel( + assets = [ + AssetModel( uri=f"s3://bucket/key/{i}", ) for i in [1, 2] ] - session.add_all(datasets) + session.add_all(assets) session.commit() - assert session.query(DatasetModel).count() == 2 + assert session.query(AssetModel).count() == 2 response = admin_client.get("/object/datasets_summary?order_by=fake") @@ -92,15 +92,10 @@ def test_order_by_raises_400_for_invalid_attr(self, admin_client, session): assert response.json["detail"] == msg def test_order_by_raises_400_for_invalid_datetimes(self, admin_client, session): - datasets = [ - DatasetModel( - uri=f"s3://bucket/key/{i}", - ) - for i in [1, 2] - ] - session.add_all(datasets) + assets = [AssetModel(uri=f"s3://bucket/key/{i}") for i in [1, 2]] + session.add_all(assets) session.commit() - assert session.query(DatasetModel).count() == 2 + assert session.query(AssetModel).count() == 2 response = admin_client.get("/object/datasets_summary?updated_before=null") @@ -115,25 +110,25 @@ def test_order_by_raises_400_for_invalid_datetimes(self, admin_client, 
session): def test_filter_by_datetimes(self, admin_client, session): today = pendulum.today("UTC") - datasets = [ - DatasetModel( + assets = [ + AssetModel( id=i, uri=f"s3://bucket/key/{i}", ) for i in range(1, 4) ] - session.add_all(datasets) - # Update datasets, one per day, starting with datasets[0], ending with datasets[2] - dataset_events = [ - DatasetEvent( - dataset_id=datasets[i].id, - timestamp=today.add(days=-len(datasets) + i + 1), + session.add_all(assets) + # Update assets, one per day, starting with assets[0], ending with assets[2] + asset_events = [ + AssetEvent( + dataset_id=assets[i].id, + timestamp=today.add(days=-len(assets) + i + 1), ) - for i in range(len(datasets)) + for i in range(len(assets)) ] - session.add_all(dataset_events) + session.add_all(asset_events) session.commit() - assert session.query(DatasetModel).count() == len(datasets) + assert session.query(AssetModel).count() == len(assets) cutoff = today.add(days=-1).add(minutes=-5).to_iso8601_string() response = admin_client.get(f"/object/datasets_summary?updated_after={cutoff}") @@ -150,7 +145,7 @@ def test_filter_by_datetimes(self, admin_client, session): assert [json_dict["id"] for json_dict in response.json["datasets"]] == [1, 2] @pytest.mark.parametrize( - "order_by, ordered_dataset_ids", + "order_by, ordered_asset_ids", [ ("uri", [1, 2, 3, 4]), ("-uri", [4, 3, 2, 1]), @@ -158,50 +153,50 @@ def test_filter_by_datetimes(self, admin_client, session): ("-last_dataset_update", [2, 3, 1, 4]), ], ) - def test_order_by(self, admin_client, session, order_by, ordered_dataset_ids): - datasets = [ - DatasetModel( + def test_order_by(self, admin_client, session, order_by, ordered_asset_ids): + assets = [ + AssetModel( id=i, uri=f"s3://bucket/key/{i}", ) - for i in range(1, len(ordered_dataset_ids) + 1) + for i in range(1, len(ordered_asset_ids) + 1) ] - session.add_all(datasets) - dataset_events = [ - DatasetEvent( - dataset_id=datasets[2].id, + session.add_all(assets) + asset_events = [ + AssetEvent( + dataset_id=assets[2].id, timestamp=pendulum.today("UTC").add(days=-3), ), - DatasetEvent( - dataset_id=datasets[1].id, + AssetEvent( + dataset_id=assets[1].id, timestamp=pendulum.today("UTC").add(days=-2), ), - DatasetEvent( - dataset_id=datasets[1].id, + AssetEvent( + dataset_id=assets[1].id, timestamp=pendulum.today("UTC").add(days=-1), ), ] - session.add_all(dataset_events) + session.add_all(asset_events) session.commit() - assert session.query(DatasetModel).count() == len(ordered_dataset_ids) + assert session.query(AssetModel).count() == len(ordered_asset_ids) response = admin_client.get(f"/object/datasets_summary?order_by={order_by}") assert response.status_code == 200 - assert ordered_dataset_ids == [json_dict["id"] for json_dict in response.json["datasets"]] - assert response.json["total_entries"] == len(ordered_dataset_ids) + assert ordered_asset_ids == [json_dict["id"] for json_dict in response.json["datasets"]] + assert response.json["total_entries"] == len(ordered_asset_ids) def test_search_uri_pattern(self, admin_client, session): - datasets = [ - DatasetModel( + assets = [ + AssetModel( id=i, uri=f"s3://bucket/key_{i}", ) for i in [1, 2] ] - session.add_all(datasets) + session.add_all(assets) session.commit() - assert session.query(DatasetModel).count() == 2 + assert session.query(AssetModel).count() == 2 uri_pattern = "key_2" response = admin_client.get(f"/object/datasets_summary?uri_pattern={uri_pattern}") @@ -246,92 +241,92 @@ def test_search_uri_pattern(self, admin_client, session): 
@pytest.mark.need_serialized_dag def test_correct_counts_update(self, admin_client, session, dag_maker, app, monkeypatch): with monkeypatch.context() as m: - datasets = [Dataset(uri=f"s3://bucket/key/{i}") for i in [1, 2, 3, 4, 5]] + assets = [Asset(uri=f"s3://bucket/key/{i}") for i in [1, 2, 3, 4, 5]] - # DAG that produces dataset #1 + # DAG that produces asset #1 with dag_maker(dag_id="upstream", schedule=None, serialized=True, session=session): - EmptyOperator(task_id="task1", outlets=[datasets[0]]) + EmptyOperator(task_id="task1", outlets=[assets[0]]) - # DAG that is consumes only datasets #1 and #2 - with dag_maker(dag_id="downstream", schedule=datasets[:2], serialized=True, session=session): + # DAG that is consumes only assets #1 and #2 + with dag_maker(dag_id="downstream", schedule=assets[:2], serialized=True, session=session): EmptyOperator(task_id="task1") - # We create multiple dataset-producing and dataset-consuming DAGs because the query requires + # We create multiple asset-producing and asset-consuming DAGs because the query requires # COUNT(DISTINCT ...) for total_updates, or else it returns a multiple of the correct number due - # to the outer joins with DagScheduleDatasetReference and TaskOutletDatasetReference - # Two independent DAGs that produce dataset #3 + # to the outer joins with DagScheduleAssetReference and TaskOutletAssetReference + # Two independent DAGs that produce asset #3 with dag_maker(dag_id="independent_producer_1", serialized=True, session=session): - EmptyOperator(task_id="task1", outlets=[datasets[2]]) + EmptyOperator(task_id="task1", outlets=[assets[2]]) with dag_maker(dag_id="independent_producer_2", serialized=True, session=session): - EmptyOperator(task_id="task1", outlets=[datasets[2]]) - # Two independent DAGs that consume dataset #4 + EmptyOperator(task_id="task1", outlets=[assets[2]]) + # Two independent DAGs that consume asset #4 with dag_maker( dag_id="independent_consumer_1", - schedule=[datasets[3]], + schedule=[assets[3]], serialized=True, session=session, ): EmptyOperator(task_id="task1") with dag_maker( dag_id="independent_consumer_2", - schedule=[datasets[3]], + schedule=[assets[3]], serialized=True, session=session, ): EmptyOperator(task_id="task1") - # Independent DAG that is produces and consumes the same dataset, #5 + # Independent DAG that is produces and consumes the same asset, #5 with dag_maker( dag_id="independent_producer_self_consumer", - schedule=[datasets[4]], + schedule=[assets[4]], serialized=True, session=session, ): - EmptyOperator(task_id="task1", outlets=[datasets[4]]) + EmptyOperator(task_id="task1", outlets=[assets[4]]) m.setattr(app, "dag_bag", dag_maker.dagbag) - ds1_id = session.query(DatasetModel.id).filter_by(uri=datasets[0].uri).scalar() - ds2_id = session.query(DatasetModel.id).filter_by(uri=datasets[1].uri).scalar() - ds3_id = session.query(DatasetModel.id).filter_by(uri=datasets[2].uri).scalar() - ds4_id = session.query(DatasetModel.id).filter_by(uri=datasets[3].uri).scalar() - ds5_id = session.query(DatasetModel.id).filter_by(uri=datasets[4].uri).scalar() + asset1_id = session.query(AssetModel.id).filter_by(uri=assets[0].uri).scalar() + asset2_id = session.query(AssetModel.id).filter_by(uri=assets[1].uri).scalar() + asset3_id = session.query(AssetModel.id).filter_by(uri=assets[2].uri).scalar() + asset4_id = session.query(AssetModel.id).filter_by(uri=assets[3].uri).scalar() + asset5_id = session.query(AssetModel.id).filter_by(uri=assets[4].uri).scalar() - # dataset 1 events + # asset 1 events 
session.add_all( [ - DatasetEvent( - dataset_id=ds1_id, + AssetEvent( + dataset_id=asset1_id, timestamp=pendulum.DateTime(2022, 8, 1, i, tzinfo=UTC), ) for i in range(3) ] ) - # dataset 3 events + # asset 3 events session.add_all( [ - DatasetEvent( - dataset_id=ds3_id, + AssetEvent( + dataset_id=asset3_id, timestamp=pendulum.DateTime(2022, 8, 1, i, tzinfo=UTC), ) for i in range(3) ] ) - # dataset 4 events + # asset 4 events session.add_all( [ - DatasetEvent( - dataset_id=ds4_id, + AssetEvent( + dataset_id=asset4_id, timestamp=pendulum.DateTime(2022, 8, 1, i, tzinfo=UTC), ) for i in range(4) ] ) - # dataset 5 events + # asset 5 events session.add_all( [ - DatasetEvent( - dataset_id=ds5_id, + AssetEvent( + dataset_id=asset5_id, timestamp=pendulum.DateTime(2022, 8, 1, i, tzinfo=UTC), ) for i in range(5) @@ -346,31 +341,31 @@ def test_correct_counts_update(self, admin_client, session, dag_maker, app, monk assert response_data == { "datasets": [ { - "id": ds1_id, + "id": asset1_id, "uri": "s3://bucket/key/1", "last_dataset_update": "2022-08-01T02:00:00+00:00", "total_updates": 3, }, { - "id": ds2_id, + "id": asset2_id, "uri": "s3://bucket/key/2", "last_dataset_update": None, "total_updates": 0, }, { - "id": ds3_id, + "id": asset3_id, "uri": "s3://bucket/key/3", "last_dataset_update": "2022-08-01T02:00:00+00:00", "total_updates": 3, }, { - "id": ds4_id, + "id": asset4_id, "uri": "s3://bucket/key/4", "last_dataset_update": "2022-08-01T03:00:00+00:00", "total_updates": 4, }, { - "id": ds5_id, + "id": asset5_id, "uri": "s3://bucket/key/5", "last_dataset_update": "2022-08-01T04:00:00+00:00", "total_updates": 5, @@ -395,14 +390,14 @@ class TestGetDatasetsEndpointPagination(TestDatasetEndpoint): ], ) def test_limit_and_offset(self, admin_client, session, url, expected_dataset_uris): - datasets = [ - DatasetModel( + assets = [ + AssetModel( uri=f"s3://bucket/key/{i}", extra={"foo": "bar"}, ) for i in range(1, 10) ] - session.add_all(datasets) + session.add_all(assets) session.commit() response = admin_client.get(url) @@ -412,14 +407,14 @@ def test_limit_and_offset(self, admin_client, session, url, expected_dataset_uri assert dataset_uris == expected_dataset_uris def test_should_respect_page_size_limit_default(self, admin_client, session): - datasets = [ - DatasetModel( + assets = [ + AssetModel( uri=f"s3://bucket/key/{i}", extra={"foo": "bar"}, ) for i in range(1, 60) ] - session.add_all(datasets) + session.add_all(assets) session.commit() response = admin_client.get("/object/datasets_summary") @@ -428,14 +423,14 @@ def test_should_respect_page_size_limit_default(self, admin_client, session): assert len(response.json["datasets"]) == 25 def test_should_return_max_if_req_above(self, admin_client, session): - datasets = [ - DatasetModel( + assets = [ + AssetModel( uri=f"s3://bucket/key/{i}", extra={"foo": "bar"}, ) for i in range(1, 60) ] - session.add_all(datasets) + session.add_all(assets) session.commit() response = admin_client.get("/object/datasets_summary?limit=180") @@ -446,7 +441,7 @@ def test_should_return_max_if_req_above(self, admin_client, session): class TestGetDatasetNextRunSummary(TestDatasetEndpoint): def test_next_run_dataset_summary(self, dag_maker, admin_client): - with dag_maker(dag_id="upstream", schedule=[Dataset(uri="s3://bucket/key/1")], serialized=True): + with dag_maker(dag_id="upstream", schedule=[Asset(uri="s3://bucket/key/1")], serialized=True): EmptyOperator(task_id="task1") response = admin_client.post("/next_run_datasets_summary", data={"dag_ids": ["upstream"]}) diff --git 
a/tests/www/views/test_views_grid.py b/tests/www/views/test_views_grid.py index 8726b67e1dad3..b4dd6f6082e57 100644 --- a/tests/www/views/test_views_grid.py +++ b/tests/www/views/test_views_grid.py @@ -24,11 +24,11 @@ import pytest from dateutil.tz import UTC -from airflow.datasets import Dataset +from airflow.assets import Asset from airflow.decorators import task_group from airflow.lineage.entities import File from airflow.models import DagBag -from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel +from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel from airflow.operators.empty import EmptyOperator from airflow.utils import timezone from airflow.utils.state import DagRunState, TaskInstanceState @@ -36,7 +36,7 @@ from airflow.utils.types import DagRunType from airflow.www.views import dag_to_grid from tests.test_utils.asserts import assert_queries_count -from tests.test_utils.db import clear_db_datasets, clear_db_runs +from tests.test_utils.db import clear_db_assets, clear_db_runs from tests.test_utils.mock_operators import MockOperator pytestmark = pytest.mark.db_test @@ -56,10 +56,10 @@ def examples_dag_bag(): @pytest.fixture(autouse=True) def clean(): clear_db_runs() - clear_db_datasets() + clear_db_assets() yield clear_db_runs() - clear_db_datasets() + clear_db_assets() @pytest.fixture @@ -419,7 +419,7 @@ def test_query_count(dag_with_runs, session): dag_to_grid(run1.dag, (run1, run2), session) -def test_has_outlet_dataset_flag(admin_client, dag_maker, session, app, monkeypatch): +def test_has_outlet_asset_flag(admin_client, dag_maker, session, app, monkeypatch): with monkeypatch.context() as m: # Remove global operator links for this test m.setattr("airflow.plugins_manager.global_operator_extra_links", []) @@ -430,8 +430,8 @@ def test_has_outlet_dataset_flag(admin_client, dag_maker, session, app, monkeypa lineagefile = File("/tmp/does_not_exist") EmptyOperator(task_id="task1") EmptyOperator(task_id="task2", outlets=[lineagefile]) - EmptyOperator(task_id="task3", outlets=[Dataset("foo"), lineagefile]) - EmptyOperator(task_id="task4", outlets=[Dataset("foo")]) + EmptyOperator(task_id="task3", outlets=[Asset("foo"), lineagefile]) + EmptyOperator(task_id="task4", outlets=[Asset("foo")]) m.setattr(app, "dag_bag", dag_maker.dagbag) resp = admin_client.get(f"/object/grid_data?dag_id={DAG_ID}", follow_redirects=True) @@ -470,37 +470,37 @@ def _expected_task_details(task_id, has_outlet_datasets): @pytest.mark.need_serialized_dag def test_next_run_datasets(admin_client, dag_maker, session, app, monkeypatch): with monkeypatch.context() as m: - datasets = [Dataset(uri=f"s3://bucket/key/{i}") for i in [1, 2]] + assets = [Asset(uri=f"s3://bucket/key/{i}") for i in [1, 2]] - with dag_maker(dag_id=DAG_ID, schedule=datasets, serialized=True, session=session): + with dag_maker(dag_id=DAG_ID, schedule=assets, serialized=True, session=session): EmptyOperator(task_id="task1") m.setattr(app, "dag_bag", dag_maker.dagbag) - ds1_id = session.query(DatasetModel.id).filter_by(uri=datasets[0].uri).scalar() - ds2_id = session.query(DatasetModel.id).filter_by(uri=datasets[1].uri).scalar() - ddrq = DatasetDagRunQueue( - target_dag_id=DAG_ID, dataset_id=ds1_id, created_at=pendulum.DateTime(2022, 8, 2, tzinfo=UTC) + asset1_id = session.query(AssetModel.id).filter_by(uri=assets[0].uri).scalar() + asset2_id = session.query(AssetModel.id).filter_by(uri=assets[1].uri).scalar() + adrq = AssetDagRunQueue( + target_dag_id=DAG_ID, dataset_id=asset1_id, 
created_at=pendulum.DateTime(2022, 8, 2, tzinfo=UTC) ) - session.add(ddrq) - dataset_events = [ - DatasetEvent( - dataset_id=ds1_id, + session.add(adrq) + asset_events = [ + AssetEvent( + dataset_id=asset1_id, extra={}, timestamp=pendulum.DateTime(2022, 8, 1, 1, tzinfo=UTC), ), - DatasetEvent( - dataset_id=ds1_id, + AssetEvent( + dataset_id=asset1_id, extra={}, timestamp=pendulum.DateTime(2022, 8, 2, 1, tzinfo=UTC), ), - DatasetEvent( - dataset_id=ds1_id, + AssetEvent( + dataset_id=asset1_id, extra={}, timestamp=pendulum.DateTime(2022, 8, 2, 2, tzinfo=UTC), ), ] - session.add_all(dataset_events) + session.add_all(asset_events) session.commit() resp = admin_client.get(f"/object/next_run_datasets/{DAG_ID}", follow_redirects=True) @@ -509,8 +509,8 @@ def test_next_run_datasets(admin_client, dag_maker, session, app, monkeypatch): assert resp.json == { "dataset_expression": {"all": ["s3://bucket/key/1", "s3://bucket/key/2"]}, "events": [ - {"id": ds1_id, "uri": "s3://bucket/key/1", "lastUpdate": "2022-08-02T02:00:00+00:00"}, - {"id": ds2_id, "uri": "s3://bucket/key/2", "lastUpdate": None}, + {"id": asset1_id, "uri": "s3://bucket/key/1", "lastUpdate": "2022-08-02T02:00:00+00:00"}, + {"id": asset2_id, "uri": "s3://bucket/key/2", "lastUpdate": None}, ], } From 8350351293734d72273d6f62a0175c4e833e5869 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Sun, 29 Sep 2024 23:27:41 -0700 Subject: [PATCH 216/349] Add 'name' and 'group' to DatasetModel (#42407) The unique index is also modified to include 'name', so now an asset is considered unique if *either* the name or URI is different. This makes no difference for the moment---the name is simply populated from URI. We'll add a public interface to set the name in a later PR. This PR strictly only touches the model so it does not conflict with too many things, and can be merged quickly. The unique index on DatasetAliasModel is also renamed since we were using a wrong naming convention on both models. Since the index namespace is shared in the entire database, the index name should include additional components. The idx_name_unique is still usable, but we should be a better citizen and name this the right way(tm). 
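As a quick illustration of the new identity semantics, a minimal sketch (not part of this patch; it assumes the change below is applied and that AssetModel is imported from airflow.models.asset as in the diff):

    from airflow.models.asset import AssetModel

    # For now the name is simply populated from the URI, as noted above.
    a = AssetModel(uri="s3://bucket/key/1")
    b = AssetModel(uri="s3://bucket/key/1")
    c = AssetModel(uri="s3://bucket/key/2")

    # Equality and hashing now take both name and uri into account, mirroring
    # the new unique index idx_dataset_name_uri_unique on (name, uri).
    assert a == b and hash(a) == hash(b)
    assert a != c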
--- airflow/assets/__init__.py | 2 +- ...4_3_0_0_add_name_field_to_dataset_model.py | 94 + airflow/models/asset.py | 43 +- airflow/utils/db.py | 2 +- docs/apache-airflow/img/airflow_erd.sha256 | 2 +- docs/apache-airflow/img/airflow_erd.svg | 2266 +++++++++-------- docs/apache-airflow/migrations-ref.rst | 4 +- 7 files changed, 1272 insertions(+), 1141 deletions(-) create mode 100644 airflow/migrations/versions/0034_3_0_0_add_name_field_to_dataset_model.py diff --git a/airflow/assets/__init__.py b/airflow/assets/__init__.py index 9727e408edc2e..deb9aa593ded5 100644 --- a/airflow/assets/__init__.py +++ b/airflow/assets/__init__.py @@ -256,7 +256,7 @@ class Asset(os.PathLike, BaseAsset): uri: str = attr.field( converter=_sanitize_uri, - validator=[attr.validators.min_len(1), attr.validators.max_len(3000)], + validator=[attr.validators.min_len(1), attr.validators.max_len(1500)], ) extra: dict[str, Any] = attr.field(factory=dict, converter=_set_extra_default) diff --git a/airflow/migrations/versions/0034_3_0_0_add_name_field_to_dataset_model.py b/airflow/migrations/versions/0034_3_0_0_add_name_field_to_dataset_model.py new file mode 100644 index 0000000000000..5c8aec69e9be9 --- /dev/null +++ b/airflow/migrations/versions/0034_3_0_0_add_name_field_to_dataset_model.py @@ -0,0 +1,94 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Add name and group fields to DatasetModel. + +The unique index on DatasetModel is also modified to include name. Existing rows +have their name copied from URI. + +While not strictly related to other changes, the index name on DatasetAliasModel +is also renamed. Index names are scoped to the entire database. Airflow generally +includes the table's name to manually scope the index, but ``idx_uri_unique`` +(on DatasetModel) and ``idx_name_unique`` (on DatasetAliasModel) do not do this. +The one on DatasetModel is already renamed in this PR (to include name), so we +also rename the one on DatasetAliasModel here for consistency. + +Revision ID: 0d9e73a75ee4 +Revises: 16cbcb1c8c36 +Create Date: 2024-08-13 09:45:32.213222 +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.orm import Session + +# revision identifiers, used by Alembic. +revision = "0d9e73a75ee4" +down_revision = "16cbcb1c8c36" +branch_labels = None +depends_on = None +airflow_version = "3.0.0" + +_STRING_COLUMN_TYPE = sa.String(length=1500).with_variant( + sa.String(length=1500, collation="latin1_general_cs"), + dialect_name="mysql", +) + + +def upgrade(): + # Fix index name on DatasetAlias. + with op.batch_alter_table("dataset_alias", schema=None) as batch_op: + batch_op.drop_index("idx_name_unique") + batch_op.create_index("idx_dataset_alias_name_unique", ["name"], unique=True) + # Add 'name' column. 
Set it to nullable for now. + with op.batch_alter_table("dataset", schema=None) as batch_op: + batch_op.add_column(sa.Column("name", _STRING_COLUMN_TYPE)) + batch_op.add_column(sa.Column("group", _STRING_COLUMN_TYPE, default=str, nullable=False)) + # Fill name from uri column. + Session(bind=op.get_bind()).execute(sa.text("update dataset set name=uri")) + # Set the name column non-nullable. + # Now with values in there, we can create the new unique constraint and index. + # Due to MySQL restrictions, we are also reducing the length on uri. + with op.batch_alter_table("dataset", schema=None) as batch_op: + batch_op.alter_column("name", existing_type=_STRING_COLUMN_TYPE, nullable=False) + batch_op.alter_column("uri", type_=_STRING_COLUMN_TYPE, nullable=False) + batch_op.drop_index("idx_uri_unique") + batch_op.create_index("idx_dataset_name_uri_unique", ["name", "uri"], unique=True) + + +def downgrade(): + with op.batch_alter_table("dataset", schema=None) as batch_op: + batch_op.drop_index("idx_dataset_name_uri_unique") + batch_op.create_index("idx_uri_unique", ["uri"], unique=True) + with op.batch_alter_table("dataset", schema=None) as batch_op: + batch_op.drop_column("group") + batch_op.drop_column("name") + batch_op.alter_column( + "uri", + type_=sa.String(length=3000).with_variant( + sa.String(length=3000, collation="latin1_general_cs"), + dialect_name="mysql", + ), + nullable=False, + ) + with op.batch_alter_table("dataset_alias", schema=None) as batch_op: + batch_op.drop_index("idx_dataset_alias_name_unique") + batch_op.create_index("idx_name_unique", ["name"], unique=True) diff --git a/airflow/models/asset.py b/airflow/models/asset.py index b99aa86f2c889..fb56bc4bf1ecf 100644 --- a/airflow/models/asset.py +++ b/airflow/models/asset.py @@ -106,7 +106,7 @@ class AssetAliasModel(Base): __tablename__ = "dataset_alias" __table_args__ = ( - Index("idx_name_unique", name, unique=True), + Index("idx_dataset_alias_name_unique", name, unique=True), {"sqlite_autoincrement": True}, # ensures PK values not reused ) @@ -151,10 +151,22 @@ class AssetModel(Base): """ id = Column(Integer, primary_key=True, autoincrement=True) + name = Column( + String(length=1500).with_variant( + String( + length=1500, + # latin1 allows for more indexed length in mysql + # and this field should only be ascii chars + collation="latin1_general_cs", + ), + "mysql", + ), + nullable=False, + ) uri = Column( - String(length=3000).with_variant( + String(length=1500).with_variant( String( - length=3000, + length=1500, # latin1 allows for more indexed length in mysql # and this field should only be ascii chars collation="latin1_general_cs", @@ -163,7 +175,21 @@ class AssetModel(Base): ), nullable=False, ) + group = Column( + String(length=1500).with_variant( + String( + length=1500, + # latin1 allows for more indexed length in mysql + # and this field should only be ascii chars + collation="latin1_general_cs", + ), + "mysql", + ), + default=str, + nullable=False, + ) extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={}) + created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False) is_orphaned = Column(Boolean, default=False, nullable=False, server_default="0") @@ -173,7 +199,7 @@ class AssetModel(Base): __tablename__ = "dataset" __table_args__ = ( - Index("idx_uri_unique", uri, unique=True), + Index("idx_dataset_name_uri_unique", name, uri, unique=True), {"sqlite_autoincrement": 
True}, # ensures PK values not reused ) @@ -189,16 +215,15 @@ def __init__(self, uri: str, **kwargs): parsed = urlsplit(uri) if parsed.scheme and parsed.scheme.lower() == "airflow": raise ValueError("Scheme `airflow` is reserved.") - super().__init__(uri=uri, **kwargs) + super().__init__(name=uri, uri=uri, **kwargs) def __eq__(self, other): if isinstance(other, (self.__class__, Asset)): - return self.uri == other.uri - else: - return NotImplemented + return self.name == other.name and self.uri == other.uri + return NotImplemented def __hash__(self): - return hash(self.uri) + return hash((self.name, self.uri)) def __repr__(self): return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})" diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 512195c3aa963..8a254d4fef4d8 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -96,7 +96,7 @@ class MappedClassProtocol(Protocol): "2.9.0": "1949afb29106", "2.9.2": "686269002441", "2.10.0": "22ed7efa9da2", - "3.0.0": "16cbcb1c8c36", + "3.0.0": "0d9e73a75ee4", } diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256 index 237c598ec1dc8..e4a952da1b9fd 100644 --- a/docs/apache-airflow/img/airflow_erd.sha256 +++ b/docs/apache-airflow/img/airflow_erd.sha256 @@ -1 +1 @@ -f4379048d3f13f35aaba824c00450c17ad4deea9af82b5498d755a12f8a85a37 \ No newline at end of file +c33e9a583a5b29eb748ebd50e117643e11bcb2a9b61ec017efd690621e22769b \ No newline at end of file diff --git a/docs/apache-airflow/img/airflow_erd.svg b/docs/apache-airflow/img/airflow_erd.svg index 65f94c58ad24a..76fbd8f841f25 100644 --- a/docs/apache-airflow/img/airflow_erd.svg +++ b/docs/apache-airflow/img/airflow_erd.svg
[The roughly 2,266 changed lines of this auto-generated ERD image (see the diffstat above) are omitted here. The regenerated diagram only reflects the model change in this commit: the dataset table now shows the new name and group columns (VARCHAR(1500), NOT NULL) and its uri column is narrowed from VARCHAR(3000) to VARCHAR(1500); the rest is re-rendered layout of the existing tables and relationships.]
diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst index ded2b290b5a4e..a547d03d75be6 100644 --- a/docs/apache-airflow/migrations-ref.rst +++ b/docs/apache-airflow/migrations-ref.rst @@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru
+-------------------------+------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=========================+==================+===================+==============================================================+ -| ``16cbcb1c8c36`` (head) | ``522625f6d606`` | ``3.0.0`` | Remove redundant index. | +| ``0d9e73a75ee4`` (head) | ``16cbcb1c8c36`` | ``3.0.0`` | Add name and group fields to DatasetModel. | ++-------------------------+------------------+-------------------+--------------------------------------------------------------+ +| ``16cbcb1c8c36`` | ``522625f6d606`` | ``3.0.0`` | Remove redundant index. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``522625f6d606`` | ``1cdc775ca98f`` | ``3.0.0`` | Add tables for backfill. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ From 32b37c90ad89eda63a3ce90740d070f641a84098 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Mon, 30 Sep 2024 12:10:22 +0100 Subject: [PATCH 217/349] Ensure consistent Seriailized DAG hashing (#42517) * Ensure consistent Seriailized DAG hashing The serialized DAG dictionary is not ordered correctly when creating hashes, and that causes inconsistent hashes, leading to frequent update of the serialized DAG table. Changes: Implemented sorting for serialized DAG dictionaries and nested structures to ensure consistent and predictable serialization order for hashing. Using `sort_keys` in `json.dumps` is not enough to sort the nested structures in the serialized DAG. Added serialize and deserialize methods for DagParam and ArgNotSet to allow for more structured serialization. Updated serialize_template_field to handle objects that implement the serialize method. This was done because of DagParam and ArgNotSet in the template fields. Previously, it produced an object, but with this change, it now serialises to a consistent object. * Move hashing to a method * fixup! Move hashing to a method * Add test --- airflow/models/param.py | 25 ++++++++++++++++++++++ airflow/models/serialized_dag.py | 32 ++++++++++++++++++++++++++--- airflow/serialization/helpers.py | 7 +++++-- airflow/utils/types.py | 8 ++++++++ tests/models/test_serialized_dag.py | 19 +++++++++++++++++ 5 files changed, 86 insertions(+), 5 deletions(-) diff --git a/airflow/models/param.py b/airflow/models/param.py index a4bbce2b768c6..895cd2af8bb42 100644 --- a/airflow/models/param.py +++ b/airflow/models/param.py @@ -290,6 +290,7 @@ def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET): current_dag.params[name] = default self._name = name self._default = default + self.current_dag = current_dag def iter_references(self) -> Iterable[tuple[Operator, str]]: return () @@ -304,6 +305,30 @@ def resolve(self, context: Context, *, include_xcom: bool = True) -> Any: return context["params"][self._name] raise AirflowException(f"No value could be resolved for parameter {self._name}") + def serialize(self) -> dict: + """Serialize the DagParam object into a dictionary.""" + return { + "dag_id": self.current_dag.dag_id, + "name": self._name, + "default": self._default, + } + + @classmethod + def deserialize(cls, data: dict, dags: dict) -> DagParam: + """ + Deserializes the dictionary back into a DagParam object. + + :param data: The serialized representation of the DagParam. 
+ :param dags: A dictionary of available DAGs to look up the DAG. + """ + dag_id = data["dag_id"] + # Retrieve the current DAG from the provided DAGs dictionary + current_dag = dags.get(dag_id) + if not current_dag: + raise ValueError(f"DAG with id {dag_id} not found.") + + return cls(current_dag=current_dag, name=data["name"], default=data["default"]) + def process_params( dag: DAG, diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py index dec843451a98a..32be31d721e34 100644 --- a/airflow/models/serialized_dag.py +++ b/airflow/models/serialized_dag.py @@ -22,7 +22,7 @@ import logging import zlib from datetime import timedelta -from typing import TYPE_CHECKING, Collection +from typing import TYPE_CHECKING, Any, Collection import sqlalchemy_jsonfield from sqlalchemy import BigInteger, Column, Index, LargeBinary, String, and_, exc, or_, select @@ -114,9 +114,10 @@ def __init__(self, dag: DAG, processor_subdir: str | None = None) -> None: self.processor_subdir = processor_subdir dag_data = SerializedDAG.to_dict(dag) - dag_data_json = json.dumps(dag_data, sort_keys=True).encode("utf-8") + self.dag_hash = SerializedDagModel.hash(dag_data) - self.dag_hash = md5(dag_data_json).hexdigest() + # partially ordered json data + dag_data_json = json.dumps(dag_data, sort_keys=True).encode("utf-8") if COMPRESS_SERIALIZED_DAGS: self._data = None @@ -132,6 +133,30 @@ def __init__(self, dag: DAG, processor_subdir: str | None = None) -> None: def __repr__(self) -> str: return f"" + @classmethod + def hash(cls, dag_data): + """Hash the data to get the dag_hash.""" + dag_data = cls._sort_serialized_dag_dict(dag_data) + data_json = json.dumps(dag_data, sort_keys=True).encode("utf-8") + return md5(data_json).hexdigest() + + @classmethod + def _sort_serialized_dag_dict(cls, serialized_dag: Any): + """Recursively sort json_dict and its nested dictionaries and lists.""" + if isinstance(serialized_dag, dict): + return {k: cls._sort_serialized_dag_dict(v) for k, v in sorted(serialized_dag.items())} + elif isinstance(serialized_dag, list): + if all(isinstance(i, dict) for i in serialized_dag): + if all("task_id" in i.get("__var", {}) for i in serialized_dag): + return sorted( + [cls._sort_serialized_dag_dict(i) for i in serialized_dag], + key=lambda x: x["__var"]["task_id"], + ) + elif all(isinstance(item, str) for item in serialized_dag): + return sorted(serialized_dag) + return [cls._sort_serialized_dag_dict(i) for i in serialized_dag] + return serialized_dag + @classmethod @provide_session def write_dag( @@ -149,6 +174,7 @@ def write_dag( :param dag: a DAG to be written into database :param min_update_interval: minimal interval in seconds to update serialized DAG + :param processor_subdir: The dag directory of the processor :param session: ORM Session :returns: Boolean indicating if the DAG was written to the DB diff --git a/airflow/serialization/helpers.py b/airflow/serialization/helpers.py index 7f97e7f2ff299..85bf3a1cc551c 100644 --- a/airflow/serialization/helpers.py +++ b/airflow/serialization/helpers.py @@ -44,14 +44,17 @@ def is_jsonable(x): max_length = conf.getint("core", "max_templated_field_length") if not is_jsonable(template_field): - serialized = str(template_field) + try: + serialized = template_field.serialize() + except AttributeError: + serialized = str(template_field) if len(serialized) > max_length: rendered = redact(serialized, name) return ( "Truncated. You can change this behaviour in [core]max_templated_field_length. " f"{rendered[:max_length - 79]!r}... 
" ) - return str(template_field) + return serialized else: if not template_field: return template_field diff --git a/airflow/utils/types.py b/airflow/utils/types.py index 86af13832755b..a19b2534b03fb 100644 --- a/airflow/utils/types.py +++ b/airflow/utils/types.py @@ -41,6 +41,14 @@ def is_arg_passed(arg: Union[ArgNotSet, None] = NOTSET) -> bool: is_arg_passed(None) # True. """ + @staticmethod + def serialize(): + return "NOTSET" + + @classmethod + def deserialize(cls): + return cls + NOTSET = ArgNotSet() """Sentinel value for argument default. See ``ArgNotSet``.""" diff --git a/tests/models/test_serialized_dag.py b/tests/models/test_serialized_dag.py index b8fddc655dae5..d9a77e55edaf5 100644 --- a/tests/models/test_serialized_dag.py +++ b/tests/models/test_serialized_dag.py @@ -23,6 +23,7 @@ import pendulum import pytest +from sqlalchemy import select import airflow.example_dags as example_dags_module from airflow.assets import Asset @@ -264,3 +265,21 @@ def test_order_of_deps_is_consistent(self): # dag hash should not change without change in structure (we're in a loop) assert this_dag_hash == first_dag_hash + + def test_example_dag_hashes_are_always_consistent(self, session): + """ + This test asserts that the hashes of the example dags are always consistent. + """ + + def get_hash_set(): + example_dags = self._write_example_dags() + ordered_example_dags = dict(sorted(example_dags.items())) + hashes = set() + for dag_id in ordered_example_dags.keys(): + smd = session.execute(select(SDM.dag_hash).where(SDM.dag_id == dag_id)).one() + hashes.add(smd.dag_hash) + return hashes + + first_hashes = get_hash_set() + # assert that the hashes are the same + assert first_hashes == get_hash_set() From 018478c889248ac8e36cdabfda38e8dabba7388a Mon Sep 17 00:00:00 2001 From: codecae Date: Mon, 30 Sep 2024 10:01:35 -0400 Subject: [PATCH 218/349] reduce eyestrain in dark mode with reduced contrast and saturation (#42567) * reduce eyestrain in dark mode with reduced contrast and saturation * feat: readjusted saturation --------- Co-authored-by: Curtis Bangert --- airflow/www/static/css/bootstrap-theme.css | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/www/static/css/bootstrap-theme.css b/airflow/www/static/css/bootstrap-theme.css index 921d795b96b12..94e1ee5887715 100644 --- a/airflow/www/static/css/bootstrap-theme.css +++ b/airflow/www/static/css/bootstrap-theme.css @@ -37,7 +37,7 @@ html { -webkit-text-size-adjust: 100%; } html[data-color-scheme="dark"] { - filter: invert(100%) hue-rotate(180deg); + filter: invert(100%) hue-rotate(180deg) saturate(90%) contrast(85%); } /* Default icons to not display until the data-color-scheme has been set */ From 76629596575a6eb05452b2ead97d1faaaefb43db Mon Sep 17 00:00:00 2001 From: Pierre Jeambrun Date: Mon, 30 Sep 2024 23:04:25 +0800 Subject: [PATCH 219/349] AIP-84 Migrate patch dags to FastAPI API (#42545) * AIP-84 Migrate patch dags to FastAPI API * Fix CI --- .../api_connexion/endpoints/dag_endpoint.py | 1 + airflow/api_fastapi/db/__init__.py | 16 +++ airflow/api_fastapi/db/common.py | 83 ++++++++++++ airflow/api_fastapi/{db.py => db/dags.py} | 54 +++----- airflow/api_fastapi/openapi/v1-generated.yaml | 123 +++++++++++++++++- airflow/api_fastapi/parameters.py | 50 +++++-- airflow/api_fastapi/views/public/dags.py | 93 ++++++++----- airflow/api_fastapi/views/ui/assets.py | 2 +- airflow/ui/openapi-gen/queries/common.ts | 3 + airflow/ui/openapi-gen/queries/queries.ts | 88 ++++++++++++- .../ui/openapi-gen/requests/services.gen.ts | 50 
++++++- airflow/ui/openapi-gen/requests/types.gen.ts | 44 +++++++ tests/api_fastapi/views/public/test_dags.py | 94 ++++++++++--- 13 files changed, 596 insertions(+), 105 deletions(-) create mode 100644 airflow/api_fastapi/db/__init__.py create mode 100644 airflow/api_fastapi/db/common.py rename airflow/api_fastapi/{db.py => db/dags.py} (55%) diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py index 6fca5ae7c93d5..5d10a97dedce6 100644 --- a/airflow/api_connexion/endpoints/dag_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_endpoint.py @@ -165,6 +165,7 @@ def patch_dag(*, dag_id: str, update_mask: UpdateMask = None, session: Session = return dag_schema.dump(dag) +@mark_fastapi_migration_done @security.requires_access_dag("PUT") @format_parameters({"limit": check_limit}) @action_logging diff --git a/airflow/api_fastapi/db/__init__.py b/airflow/api_fastapi/db/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/api_fastapi/db/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/api_fastapi/db/common.py b/airflow/api_fastapi/db/common.py new file mode 100644 index 0000000000000..f611eaa64f07d --- /dev/null +++ b/airflow/api_fastapi/db/common.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence + +from airflow.utils.db import get_query_count +from airflow.utils.session import NEW_SESSION, create_session, provide_session + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + from sqlalchemy.sql import Select + + from airflow.api_fastapi.parameters import BaseParam + + +async def get_session() -> Session: + """ + Dependency for providing a session. + + For non route function please use the :class:`airflow.utils.session.provide_session` decorator. + + Example usage: + + .. 
code:: python + + @router.get("/your_path") + def your_route(session: Annotated[Session, Depends(get_session)]): + pass + """ + with create_session() as session: + yield session + + +def apply_filters_to_select(base_select: Select, filters: Sequence[BaseParam | None]) -> Select: + base_select = base_select + for filter in filters: + if filter is None: + continue + base_select = filter.to_orm(base_select) + + return base_select + + +@provide_session +def paginated_select( + base_select: Select, + filters: Sequence[BaseParam], + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, + session: Session = NEW_SESSION, +) -> Select: + base_select = apply_filters_to_select( + base_select, + filters, + ) + + total_entries = get_query_count(base_select, session=session) + + # TODO: Re-enable when permissions are handled. Readable / writable entities, + # for instance: + # readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user) + # dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags)) + + base_select = apply_filters_to_select(base_select, [order_by, offset, limit]) + + return base_select, total_entries diff --git a/airflow/api_fastapi/db.py b/airflow/api_fastapi/db/dags.py similarity index 55% rename from airflow/api_fastapi/db.py rename to airflow/api_fastapi/db/dags.py index c3ed01a0aefec..7cd7cc9cd955d 100644 --- a/airflow/api_fastapi/db.py +++ b/airflow/api_fastapi/db/dags.py @@ -17,45 +17,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING - from sqlalchemy import func, select +from airflow.models.dag import DagModel from airflow.models.dagrun import DagRun -from airflow.utils.session import create_session - -if TYPE_CHECKING: - from sqlalchemy.orm import Session - from sqlalchemy.sql import Select - - from airflow.api_fastapi.parameters import BaseParam - - -async def get_session() -> Session: - """ - Dependency for providing a session. - - For non route function please use the :class:`airflow.utils.session.provide_session` decorator. - - Example usage: - - .. code:: python - - @router.get("/your_path") - def your_route(session: Annotated[Session, Depends(get_session)]): - pass - """ - with create_session() as session: - yield session - - -def apply_filters_to_select(base_select: Select, filters: list[BaseParam]) -> Select: - select = base_select - for filter in filters: - select = filter.to_orm(select) - - return select - latest_dag_run_per_dag_id_cte = ( select(DagRun.dag_id, func.max(DagRun.start_date).label("start_date")) @@ -63,3 +28,20 @@ def apply_filters_to_select(base_select: Select, filters: list[BaseParam]) -> Se .group_by(DagRun.dag_id) .cte() ) + + +dags_select_with_latest_dag_run = ( + select(DagModel) + .join( + latest_dag_run_per_dag_id_cte, + DagModel.dag_id == latest_dag_run_per_dag_id_cte.c.dag_id, + isouter=True, + ) + .join( + DagRun, + DagRun.start_date == latest_dag_run_per_dag_id_cte.c.start_date + and DagRun.dag_id == latest_dag_run_per_dag_id_cte.c.dag_id, + isouter=True, + ) + .order_by(DagModel.dag_id) +) diff --git a/airflow/api_fastapi/openapi/v1-generated.yaml b/airflow/api_fastapi/openapi/v1-generated.yaml index c130f3162c6e6..a38a1021890d6 100644 --- a/airflow/api_fastapi/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/openapi/v1-generated.yaml @@ -131,12 +131,133 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + patch: + tags: + - DAG + summary: Patch Dags + description: Patch multiple DAGs. 
+ operationId: patch_dags_public_dags_patch + parameters: + - name: update_mask + in: query + required: false + schema: + anyOf: + - type: array + items: + type: string + - type: 'null' + title: Update Mask + - name: limit + in: query + required: false + schema: + type: integer + default: 100 + title: Limit + - name: offset + in: query + required: false + schema: + type: integer + default: 0 + title: Offset + - name: tags + in: query + required: false + schema: + type: array + items: + type: string + title: Tags + - name: owners + in: query + required: false + schema: + type: array + items: + type: string + title: Owners + - name: dag_id_pattern + in: query + required: false + schema: + anyOf: + - type: string + - type: 'null' + title: Dag Id Pattern + - name: only_active + in: query + required: false + schema: + type: boolean + default: true + title: Only Active + - name: paused + in: query + required: false + schema: + anyOf: + - type: boolean + - type: 'null' + title: Paused + - name: last_dag_run_state + in: query + required: false + schema: + anyOf: + - $ref: '#/components/schemas/DagRunState' + - type: 'null' + title: Last Dag Run State + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/DAGPatchBody' + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/DAGCollectionResponse' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Bad Request + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' /public/dags/{dag_id}: patch: tags: - DAG summary: Patch Dag - description: Update the specific DAG. + description: Patch the specific DAG. 
operationId: patch_dag_public_dags__dag_id__patch parameters: - name: dag_id diff --git a/airflow/api_fastapi/parameters.py b/airflow/api_fastapi/parameters.py index 09eea5f6e055b..504014602f3b5 100644 --- a/airflow/api_fastapi/parameters.py +++ b/airflow/api_fastapi/parameters.py @@ -37,9 +37,10 @@ class BaseParam(Generic[T], ABC): """Base class for filters.""" - def __init__(self) -> None: + def __init__(self, skip_none: bool = True) -> None: self.value: T | None = None self.attribute: ColumnElement | None = None + self.skip_none = skip_none @abstractmethod def to_orm(self, select: Select) -> Select: @@ -58,7 +59,7 @@ class _LimitFilter(BaseParam[int]): """Filter on the limit.""" def to_orm(self, select: Select) -> Select: - if self.value is None: + if self.value is None and self.skip_none: return select return select.limit(self.value) @@ -71,7 +72,7 @@ class _OffsetFilter(BaseParam[int]): """Filter on offset.""" def to_orm(self, select: Select) -> Select: - if self.value is None: + if self.value is None and self.skip_none: return select return select.offset(self.value) @@ -83,7 +84,7 @@ class _PausedFilter(BaseParam[bool]): """Filter on is_paused.""" def to_orm(self, select: Select) -> Select: - if self.value is None: + if self.value is None and self.skip_none: return select return select.where(DagModel.is_paused == self.value) @@ -95,7 +96,7 @@ class _OnlyActiveFilter(BaseParam[bool]): """Filter on is_active.""" def to_orm(self, select: Select) -> Select: - if self.value: + if self.value and self.skip_none: return select.where(DagModel.is_active == self.value) return select @@ -106,33 +107,40 @@ def depends(self, only_active: bool = True) -> _OnlyActiveFilter: class _SearchParam(BaseParam[str]): """Search on attribute.""" - def __init__(self, attribute: ColumnElement) -> None: - super().__init__() + def __init__(self, attribute: ColumnElement, skip_none: bool = True) -> None: + super().__init__(skip_none) self.attribute: ColumnElement = attribute def to_orm(self, select: Select) -> Select: - if self.value is None: + if self.value is None and self.skip_none: return select return select.where(self.attribute.ilike(f"%{self.value}")) + def transform_aliases(self, value: str | None) -> str | None: + if value == "~": + value = "%" + return value + class _DagIdPatternSearch(_SearchParam): """Search on dag_id.""" - def __init__(self) -> None: - super().__init__(DagModel.dag_id) + def __init__(self, skip_none: bool = True) -> None: + super().__init__(DagModel.dag_id, skip_none) def depends(self, dag_id_pattern: str | None = None) -> _DagIdPatternSearch: + dag_id_pattern = super().transform_aliases(dag_id_pattern) return self.set_value(dag_id_pattern) class _DagDisplayNamePatternSearch(_SearchParam): """Search on dag_display_name.""" - def __init__(self) -> None: - super().__init__(DagModel.dag_display_name) + def __init__(self, skip_none: bool = True) -> None: + super().__init__(DagModel.dag_display_name, skip_none) def depends(self, dag_display_name_pattern: str | None = None) -> _DagDisplayNamePatternSearch: + dag_display_name_pattern = super().transform_aliases(dag_display_name_pattern) return self.set_value(dag_display_name_pattern) @@ -149,6 +157,9 @@ def __init__(self, allowed_attrs: list[str]) -> None: self.allowed_attrs = allowed_attrs def to_orm(self, select: Select) -> Select: + if self.skip_none is False: + raise ValueError(f"Cannot set 'skip_none' to False on a {type(self)}") + if self.value is None: return select @@ -165,6 +176,10 @@ def to_orm(self, select: Select) -> Select: # 
MySQL does not support `nullslast`, and True/False ordering depends on the # database implementation. nullscheck = case((column.isnot(None), 0), else_=1) + + # Reset default sorting + select = select.order_by(None) + if self.value[0] == "-": return select.order_by(nullscheck, column.desc(), DagModel.dag_id.desc()) else: @@ -178,6 +193,9 @@ class _TagsFilter(BaseParam[List[str]]): """Filter on tags.""" def to_orm(self, select: Select) -> Select: + if self.skip_none is False: + raise ValueError(f"Cannot set 'skip_none' to False on a {type(self)}") + if not self.value: return select @@ -192,6 +210,9 @@ class _OwnersFilter(BaseParam[List[str]]): """Filter on owners.""" def to_orm(self, select: Select) -> Select: + if self.skip_none is False: + raise ValueError(f"Cannot set 'skip_none' to False on a {type(self)}") + if not self.value: return select @@ -206,7 +227,7 @@ class _LastDagRunStateFilter(BaseParam[DagRunState]): """Filter on the state of the latest DagRun.""" def to_orm(self, select: Select) -> Select: - if self.value is None: + if self.value is None and self.skip_none: return select return select.where(DagRun.state == self.value) @@ -223,6 +244,9 @@ def depends(self, last_dag_run_state: DagRunState | None = None) -> _LastDagRunS QueryDagDisplayNamePatternSearch = Annotated[ _DagDisplayNamePatternSearch, Depends(_DagDisplayNamePatternSearch().depends) ] +QueryDagIdPatternSearchWithNone = Annotated[ + _DagIdPatternSearch, Depends(_DagIdPatternSearch(skip_none=False).depends) +] QueryTagsFilter = Annotated[_TagsFilter, Depends(_TagsFilter().depends)] QueryOwnersFilter = Annotated[_OwnersFilter, Depends(_OwnersFilter().depends)] # DagRun diff --git a/airflow/api_fastapi/views/public/dags.py b/airflow/api_fastapi/views/public/dags.py index a9fe87eef0953..a6c25d6568c1e 100644 --- a/airflow/api_fastapi/views/public/dags.py +++ b/airflow/api_fastapi/views/public/dags.py @@ -18,15 +18,20 @@ from __future__ import annotations from fastapi import APIRouter, Depends, HTTPException, Query -from sqlalchemy import select +from sqlalchemy import update from sqlalchemy.orm import Session from typing_extensions import Annotated -from airflow.api_fastapi.db import apply_filters_to_select, get_session, latest_dag_run_per_dag_id_cte +from airflow.api_fastapi.db.common import ( + get_session, + paginated_select, +) +from airflow.api_fastapi.db.dags import dags_select_with_latest_dag_run from airflow.api_fastapi.openapi.exceptions import create_openapi_http_exception_doc from airflow.api_fastapi.parameters import ( QueryDagDisplayNamePatternSearch, QueryDagIdPatternSearch, + QueryDagIdPatternSearchWithNone, QueryLastDagRunStateFilter, QueryLimit, QueryOffset, @@ -38,8 +43,6 @@ ) from airflow.api_fastapi.serializers.dags import DAGCollectionResponse, DAGPatchBody, DAGResponse from airflow.models import DagModel -from airflow.models.dagrun import DagRun -from airflow.utils.db import get_query_count dags_router = APIRouter(tags=["DAG"]) @@ -66,35 +69,16 @@ async def get_dags( session: Annotated[Session, Depends(get_session)], ) -> DAGCollectionResponse: """Get all DAGs.""" - dags_query = ( - select(DagModel) - .join( - latest_dag_run_per_dag_id_cte, - DagModel.dag_id == latest_dag_run_per_dag_id_cte.c.dag_id, - isouter=True, - ) - .join( - DagRun, - DagRun.start_date == latest_dag_run_per_dag_id_cte.c.start_date - and DagRun.dag_id == latest_dag_run_per_dag_id_cte.c.dag_id, - isouter=True, - ) - ) - - dags_query = apply_filters_to_select( - dags_query, + dags_select, total_entries = paginated_select( + 
dags_select_with_latest_dag_run, [only_active, paused, dag_id_pattern, dag_display_name_pattern, tags, owners, last_dag_run_state], + order_by, + offset, + limit, + session, ) - # TODO: Re-enable when permissions are handled. - # readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user) - # dags_query = dags_query.where(DagModel.dag_id.in_(readable_dags)) - - total_entries = get_query_count(dags_query, session=session) - - dags_query = apply_filters_to_select(dags_query, [order_by, offset, limit]) - - dags = session.scalars(dags_query).all() + dags = session.scalars(dags_select).all() return DAGCollectionResponse( dags=[DAGResponse.model_validate(dag, from_attributes=True) for dag in dags], @@ -109,7 +93,7 @@ async def patch_dag( session: Annotated[Session, Depends(get_session)], update_mask: list[str] | None = Query(None), ) -> DAGResponse: - """Update the specific DAG.""" + """Patch the specific DAG.""" dag = session.get(DagModel, dag_id) if dag is None: @@ -127,3 +111,50 @@ async def patch_dag( setattr(dag, attr_name, attr_value) return DAGResponse.model_validate(dag, from_attributes=True) + + +@dags_router.patch("/dags", responses=create_openapi_http_exception_doc([400, 401, 403, 404])) +async def patch_dags( + patch_body: DAGPatchBody, + limit: QueryLimit, + offset: QueryOffset, + tags: QueryTagsFilter, + owners: QueryOwnersFilter, + dag_id_pattern: QueryDagIdPatternSearchWithNone, + only_active: QueryOnlyActiveFilter, + paused: QueryPausedFilter, + last_dag_run_state: QueryLastDagRunStateFilter, + session: Annotated[Session, Depends(get_session)], + update_mask: list[str] | None = Query(None), +) -> DAGCollectionResponse: + """Patch multiple DAGs.""" + if update_mask: + if update_mask != ["is_paused"]: + raise HTTPException(400, "Only `is_paused` field can be updated through the REST API") + else: + update_mask = ["is_paused"] + + dags_select, total_entries = paginated_select( + dags_select_with_latest_dag_run, + [only_active, paused, dag_id_pattern, tags, owners, last_dag_run_state], + None, + offset, + limit, + session, + ) + + dags = session.scalars(dags_select).all() + + dags_to_update = {dag.dag_id for dag in dags} + + session.execute( + update(DagModel) + .where(DagModel.dag_id.in_(dags_to_update)) + .values(is_paused=patch_body.is_paused) + .execution_options(synchronize_session="fetch") + ) + + return DAGCollectionResponse( + dags=[DAGResponse.model_validate(dag, from_attributes=True) for dag in dags], + total_entries=total_entries, + ) diff --git a/airflow/api_fastapi/views/ui/assets.py b/airflow/api_fastapi/views/ui/assets.py index 458d531facf6a..739c7d64af439 100644 --- a/airflow/api_fastapi/views/ui/assets.py +++ b/airflow/api_fastapi/views/ui/assets.py @@ -22,7 +22,7 @@ from sqlalchemy.orm import Session from typing_extensions import Annotated -from airflow.api_fastapi.db import get_session +from airflow.api_fastapi.db.common import get_session from airflow.models import DagModel from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel, DagScheduleAssetReference diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index 46694939ed74e..b1508c86c0c4b 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -76,6 +76,9 @@ export const UseDagServiceGetDagsPublicDagsGetKeyFn = ( }, ]), ]; +export type DagServicePatchDagsPublicDagsPatchMutationResult = Awaited< + ReturnType +>; export type DagServicePatchDagPublicDagsDagIdPatchMutationResult = Awaited< 
ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index 7cbaac5b2c77d..5eda2a3d0e4d2 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -118,9 +118,95 @@ export const useDagServiceGetDagsPublicDagsGet = < }) as TData, ...options, }); +/** + * Patch Dags + * Patch multiple DAGs. + * @param data The data for the request. + * @param data.requestBody + * @param data.updateMask + * @param data.limit + * @param data.offset + * @param data.tags + * @param data.owners + * @param data.dagIdPattern + * @param data.onlyActive + * @param data.paused + * @param data.lastDagRunState + * @returns DAGCollectionResponse Successful Response + * @throws ApiError + */ +export const useDagServicePatchDagsPublicDagsPatch = < + TData = Common.DagServicePatchDagsPublicDagsPatchMutationResult, + TError = unknown, + TContext = unknown, +>( + options?: Omit< + UseMutationOptions< + TData, + TError, + { + dagIdPattern?: string; + lastDagRunState?: DagRunState; + limit?: number; + offset?: number; + onlyActive?: boolean; + owners?: string[]; + paused?: boolean; + requestBody: DAGPatchBody; + tags?: string[]; + updateMask?: string[]; + }, + TContext + >, + "mutationFn" + >, +) => + useMutation< + TData, + TError, + { + dagIdPattern?: string; + lastDagRunState?: DagRunState; + limit?: number; + offset?: number; + onlyActive?: boolean; + owners?: string[]; + paused?: boolean; + requestBody: DAGPatchBody; + tags?: string[]; + updateMask?: string[]; + }, + TContext + >({ + mutationFn: ({ + dagIdPattern, + lastDagRunState, + limit, + offset, + onlyActive, + owners, + paused, + requestBody, + tags, + updateMask, + }) => + DagService.patchDagsPublicDagsPatch({ + dagIdPattern, + lastDagRunState, + limit, + offset, + onlyActive, + owners, + paused, + requestBody, + tags, + updateMask, + }) as unknown as Promise, + ...options, + }); /** * Patch Dag - * Update the specific DAG. + * Patch the specific DAG. * @param data The data for the request. * @param data.dagId * @param data.requestBody diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 5aa5876d112ad..7fb6306afbc67 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -7,6 +7,8 @@ import type { NextRunAssetsUiNextRunDatasetsDagIdGetResponse, GetDagsPublicDagsGetData, GetDagsPublicDagsGetResponse, + PatchDagsPublicDagsPatchData, + PatchDagsPublicDagsPatchResponse, PatchDagPublicDagsDagIdPatchData, PatchDagPublicDagsDagIdPatchResponse, } from "./types.gen"; @@ -77,9 +79,55 @@ export class DagService { }); } + /** + * Patch Dags + * Patch multiple DAGs. + * @param data The data for the request. 
+ * @param data.requestBody + * @param data.updateMask + * @param data.limit + * @param data.offset + * @param data.tags + * @param data.owners + * @param data.dagIdPattern + * @param data.onlyActive + * @param data.paused + * @param data.lastDagRunState + * @returns DAGCollectionResponse Successful Response + * @throws ApiError + */ + public static patchDagsPublicDagsPatch( + data: PatchDagsPublicDagsPatchData, + ): CancelablePromise { + return __request(OpenAPI, { + method: "PATCH", + url: "/public/dags", + query: { + update_mask: data.updateMask, + limit: data.limit, + offset: data.offset, + tags: data.tags, + owners: data.owners, + dag_id_pattern: data.dagIdPattern, + only_active: data.onlyActive, + paused: data.paused, + last_dag_run_state: data.lastDagRunState, + }, + body: data.requestBody, + mediaType: "application/json", + errors: { + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Not Found", + 422: "Validation Error", + }, + }); + } + /** * Patch Dag - * Update the specific DAG. + * Patch the specific DAG. * @param data The data for the request. * @param data.dagId * @param data.requestBody diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index bc455f63b6449..0fe7134ba8c31 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -111,6 +111,21 @@ export type GetDagsPublicDagsGetData = { export type GetDagsPublicDagsGetResponse = DAGCollectionResponse; +export type PatchDagsPublicDagsPatchData = { + dagIdPattern?: string | null; + lastDagRunState?: DagRunState | null; + limit?: number; + offset?: number; + onlyActive?: boolean; + owners?: Array; + paused?: boolean | null; + requestBody: DAGPatchBody; + tags?: Array; + updateMask?: Array | null; +}; + +export type PatchDagsPublicDagsPatchResponse = DAGCollectionResponse; + export type PatchDagPublicDagsDagIdPatchData = { dagId: string; requestBody: DAGPatchBody; @@ -151,6 +166,35 @@ export type $OpenApiTs = { 422: HTTPValidationError; }; }; + patch: { + req: PatchDagsPublicDagsPatchData; + res: { + /** + * Successful Response + */ + 200: DAGCollectionResponse; + /** + * Bad Request + */ + 400: HTTPExceptionResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; }; "/public/dags/{dag_id}": { patch: { diff --git a/tests/api_fastapi/views/public/test_dags.py b/tests/api_fastapi/views/public/test_dags.py index 6e400f11cc0d2..7b68ebe512a25 100644 --- a/tests/api_fastapi/views/public/test_dags.py +++ b/tests/api_fastapi/views/public/test_dags.py @@ -112,37 +112,37 @@ def setup(dag_maker) -> None: "query_params, expected_total_entries, expected_ids", [ # Filters - ({}, 2, ["test_dag1", "test_dag2"]), - ({"limit": 1}, 2, ["test_dag1"]), - ({"offset": 1}, 2, ["test_dag2"]), - ({"tags": ["example"]}, 1, ["test_dag1"]), - ({"only_active": False}, 3, ["test_dag1", "test_dag2", "test_dag3"]), - ({"paused": True, "only_active": False}, 1, ["test_dag3"]), - ({"paused": False}, 2, ["test_dag1", "test_dag2"]), - ({"owners": ["airflow"]}, 2, ["test_dag1", "test_dag2"]), - ({"owners": ["test_owner"], "only_active": False}, 1, ["test_dag3"]), - ({"last_dag_run_state": "success", "only_active": False}, 1, ["test_dag3"]), - ({"last_dag_run_state": "failed", "only_active": False}, 1, ["test_dag1"]), + ({}, 2, [DAG1_ID, DAG2_ID]), + 
({"limit": 1}, 2, [DAG1_ID]), + ({"offset": 1}, 2, [DAG2_ID]), + ({"tags": ["example"]}, 1, [DAG1_ID]), + ({"only_active": False}, 3, [DAG1_ID, DAG2_ID, DAG3_ID]), + ({"paused": True, "only_active": False}, 1, [DAG3_ID]), + ({"paused": False}, 2, [DAG1_ID, DAG2_ID]), + ({"owners": ["airflow"]}, 2, [DAG1_ID, DAG2_ID]), + ({"owners": ["test_owner"], "only_active": False}, 1, [DAG3_ID]), + ({"last_dag_run_state": "success", "only_active": False}, 1, [DAG3_ID]), + ({"last_dag_run_state": "failed", "only_active": False}, 1, [DAG1_ID]), # # Sort - ({"order_by": "-dag_id"}, 2, ["test_dag2", "test_dag1"]), - ({"order_by": "-dag_display_name"}, 2, ["test_dag2", "test_dag1"]), - ({"order_by": "dag_display_name"}, 2, ["test_dag1", "test_dag2"]), - ({"order_by": "next_dagrun", "only_active": False}, 3, ["test_dag3", "test_dag1", "test_dag2"]), - ({"order_by": "last_run_state", "only_active": False}, 3, ["test_dag1", "test_dag3", "test_dag2"]), - ({"order_by": "-last_run_state", "only_active": False}, 3, ["test_dag3", "test_dag1", "test_dag2"]), + ({"order_by": "-dag_id"}, 2, [DAG2_ID, DAG1_ID]), + ({"order_by": "-dag_display_name"}, 2, [DAG2_ID, DAG1_ID]), + ({"order_by": "dag_display_name"}, 2, [DAG1_ID, DAG2_ID]), + ({"order_by": "next_dagrun", "only_active": False}, 3, [DAG3_ID, DAG1_ID, DAG2_ID]), + ({"order_by": "last_run_state", "only_active": False}, 3, [DAG1_ID, DAG3_ID, DAG2_ID]), + ({"order_by": "-last_run_state", "only_active": False}, 3, [DAG3_ID, DAG1_ID, DAG2_ID]), ( {"order_by": "last_run_start_date", "only_active": False}, 3, - ["test_dag1", "test_dag3", "test_dag2"], + [DAG1_ID, DAG3_ID, DAG2_ID], ), ( {"order_by": "-last_run_start_date", "only_active": False}, 3, - ["test_dag3", "test_dag1", "test_dag2"], + [DAG3_ID, DAG1_ID, DAG2_ID], ), # Search - ({"dag_id_pattern": "1"}, 1, ["test_dag1"]), - ({"dag_display_name_pattern": "display2"}, 1, ["test_dag2"]), + ({"dag_id_pattern": "1"}, 1, [DAG1_ID]), + ({"dag_display_name_pattern": "display2"}, 1, [DAG2_ID]), ], ) def test_get_dags(test_client, query_params, expected_total_entries, expected_ids): @@ -173,3 +173,55 @@ def test_patch_dag(test_client, query_params, dag_id, body, expected_status_code if expected_status_code == 200: body = response.json() assert body["is_paused"] == expected_is_paused + + +@pytest.mark.parametrize( + "query_params, body, expected_status_code, expected_ids, expected_paused_ids", + [ + ({"update_mask": ["field_1", "is_paused"]}, {"is_paused": True}, 400, None, None), + ( + {"only_active": False}, + {"is_paused": True}, + 200, + [], + [], + ), # no-op because the dag_id_pattern is not provided + ( + {"only_active": False, "dag_id_pattern": "~"}, + {"is_paused": True}, + 200, + [DAG1_ID, DAG2_ID, DAG3_ID], + [DAG1_ID, DAG2_ID, DAG3_ID], + ), + ( + {"only_active": False, "dag_id_pattern": "~"}, + {"is_paused": False}, + 200, + [DAG1_ID, DAG2_ID, DAG3_ID], + [], + ), + ( + {"dag_id_pattern": "~"}, + {"is_paused": True}, + 200, + [DAG1_ID, DAG2_ID], + [DAG1_ID, DAG2_ID], + ), + ( + {"dag_id_pattern": "dag1"}, + {"is_paused": True}, + 200, + [DAG1_ID], + [DAG1_ID], + ), + ], +) +def test_patch_dags(test_client, query_params, body, expected_status_code, expected_ids, expected_paused_ids): + response = test_client.patch("/public/dags", json=body, params=query_params) + + assert response.status_code == expected_status_code + if expected_status_code == 200: + body = response.json() + assert [dag["dag_id"] for dag in body["dags"]] == expected_ids + paused_dag_ids = [dag["dag_id"] for dag in body["dags"] if 
dag["is_paused"]] + assert paused_dag_ids == expected_paused_ids From d2b116c7976a0d9de388082bccf415c64a03f25f Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Mon, 30 Sep 2024 08:58:36 -0700 Subject: [PATCH 220/349] KubernetesHook kube_config extra can take dict (#41413) Previously had to be json-encoded string which is less convenient when defining the conn in json. --------- Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> --- airflow/providers/cncf/kubernetes/hooks/kubernetes.py | 2 ++ tests/providers/cncf/kubernetes/hooks/test_kubernetes.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index 9f7e33696eb87..a810e8f9ed522 100644 --- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -250,6 +250,8 @@ def get_conn(self) -> client.ApiClient: if kubeconfig is not None: with tempfile.NamedTemporaryFile() as temp_config: self.log.debug("loading kube_config from: connection kube_config") + if isinstance(kubeconfig, dict): + kubeconfig = json.dumps(kubeconfig) temp_config.write(kubeconfig.encode()) temp_config.flush() self._is_in_cluster = False diff --git a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py index 348974eacdfa6..065768def24ea 100644 --- a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py +++ b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py @@ -79,6 +79,7 @@ def setup_class(cls) -> None: ("in_cluster", {"in_cluster": True}), ("in_cluster_empty", {"in_cluster": ""}), ("kube_config", {"kube_config": '{"test": "kube"}'}), + ("kube_config_dict", {"kube_config": {"test": "kube"}}), ("kube_config_path", {"kube_config_path": "path/to/file"}), ("kube_config_empty", {"kube_config": ""}), ("kube_config_path_empty", {"kube_config_path": ""}), @@ -285,6 +286,7 @@ def test_kube_config_path( ( (None, False), ("kube_config", True), + ("kube_config_dict", True), ("kube_config_empty", False), ), ) From 779a227d4bbdbd0b71af84a237a169f7fac22b79 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Mon, 30 Sep 2024 12:58:31 -0700 Subject: [PATCH 221/349] Speed up boring cyborg consistency pre-commit check (#42589) This is typically the slowest pre-commit besides mypy, and it runs every time. Previously it loaded all filenames into memory and ran glob filter on that. It seems faster to apply glob against the file system directly. This makes pre-commit much faster. Previously took around 4 seconds, now about a half a second. --- scripts/ci/pre_commit/boring_cyborg.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/scripts/ci/pre_commit/boring_cyborg.py b/scripts/ci/pre_commit/boring_cyborg.py index cf852b12bb6da..ec674485b5457 100755 --- a/scripts/ci/pre_commit/boring_cyborg.py +++ b/scripts/ci/pre_commit/boring_cyborg.py @@ -17,13 +17,11 @@ # under the License. 
from __future__ import annotations -import subprocess import sys from pathlib import Path import yaml from termcolor import colored -from wcmatch import glob if __name__ not in ("__main__", "__mp_main__"): raise SystemExit( @@ -33,9 +31,8 @@ CONFIG_KEY = "labelPRBasedOnFilePath" -current_files = subprocess.check_output(["git", "ls-files"]).decode().splitlines() -git_root = Path(subprocess.check_output(["git", "rev-parse", "--show-toplevel"]).decode().strip()) -cyborg_config_path = git_root / ".github" / "boring-cyborg.yml" +repo_root = Path(__file__).parent.parent.parent.parent +cyborg_config_path = repo_root / ".github" / "boring-cyborg.yml" cyborg_config = yaml.safe_load(cyborg_config_path.read_text()) if CONFIG_KEY not in cyborg_config: raise SystemExit(f"Missing section {CONFIG_KEY}") @@ -43,12 +40,14 @@ errors = [] for label, patterns in cyborg_config[CONFIG_KEY].items(): for pattern in patterns: - if glob.globfilter(current_files, pattern, flags=glob.G | glob.E): + try: + next(Path(repo_root).glob(pattern)) continue - yaml_path = f"{CONFIG_KEY}.{label}" - errors.append( - f"Unused pattern [{colored(pattern, 'cyan')}] in [{colored(yaml_path, 'cyan')}] section." - ) + except StopIteration: + yaml_path = f"{CONFIG_KEY}.{label}" + errors.append( + f"Unused pattern [{colored(pattern, 'cyan')}] in [{colored(yaml_path, 'cyan')}] section." + ) if errors: print(f"Found {colored(str(len(errors)), 'red')} problems:") From 7ead4e844f4957e323e1e118ca2080708465de28 Mon Sep 17 00:00:00 2001 From: JISHAN GARGACHARYA <34843832+jishangarg@users.noreply.github.com> Date: Tue, 1 Oct 2024 12:02:59 +0530 Subject: [PATCH 222/349] Doc update - Airflow local settings no longer importable from dags folder (#42231) --------- Co-authored-by: Jishan Garg --- docs/apache-airflow/howto/set-config.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/apache-airflow/howto/set-config.rst b/docs/apache-airflow/howto/set-config.rst index 4f19159a810d4..2a03b2bbf5ee4 100644 --- a/docs/apache-airflow/howto/set-config.rst +++ b/docs/apache-airflow/howto/set-config.rst @@ -179,6 +179,8 @@ where you can configure such local settings - This is usually done in the ``airf You should create a ``airflow_local_settings.py`` file and put it in a directory in ``sys.path`` or in the ``$AIRFLOW_HOME/config`` folder. (Airflow adds ``$AIRFLOW_HOME/config`` to ``sys.path`` when Airflow is initialized) +Starting from Airflow 2.10.1, the $AIRFLOW_HOME/dags folder is no longer included in sys.path at initialization, so any local settings in that folder will not be imported. Ensure that airflow_local_settings.py is located in a path that is part of sys.path during initialization, like $AIRFLOW_HOME/config. +For more context about this change, see the `mailing list announcement `_. 
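To make the guidance above concrete: a minimal airflow_local_settings.py placed at $AIRFLOW_HOME/config/airflow_local_settings.py could look like the sketch below; the cluster policy it defines is only an illustrative example of what such a module might contain.

    # $AIRFLOW_HOME/config/airflow_local_settings.py
    from airflow.models.baseoperator import BaseOperator


    def task_policy(task: BaseOperator) -> None:
        # Cluster policy hook, applied to every task when DAG files are parsed.
        # Illustrative rule: make sure every task retries at least once.
        task.retries = max(task.retries or 0, 1)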
You can see the example of such local settings here: From 329cbf8b6db07b11c2eab6282f316b42349ddfca Mon Sep 17 00:00:00 2001 From: Dewen Kong Date: Tue, 1 Oct 2024 02:40:33 -0400 Subject: [PATCH 223/349] add flexibility for redis service (#41811) * add service type options for redis * additional value * update based on testing * fix syntax * update description * Update chart/templates/redis/redis-service.yaml Co-authored-by: rom sharon <33751805+romsharon98@users.noreply.github.com> * Update chart/templates/redis/redis-service.yaml Co-authored-by: rom sharon <33751805+romsharon98@users.noreply.github.com> --------- Co-authored-by: rom sharon <33751805+romsharon98@users.noreply.github.com> --- chart/templates/redis/redis-service.yaml | 10 ++++++ chart/values.schema.json | 33 +++++++++++++++++++ chart/values.yaml | 8 +++++ helm_tests/other/test_redis.py | 41 ++++++++++++++++++++++++ 4 files changed, 92 insertions(+) diff --git a/chart/templates/redis/redis-service.yaml b/chart/templates/redis/redis-service.yaml index 17d4c8d5e4836..ee010901ef84e 100644 --- a/chart/templates/redis/redis-service.yaml +++ b/chart/templates/redis/redis-service.yaml @@ -35,7 +35,14 @@ metadata: {{- toYaml . | nindent 4 }} {{- end }} spec: +{{- if eq .Values.redis.service.type "ClusterIP" }} type: ClusterIP + {{- if .Values.redis.service.clusterIP }} + clusterIP: {{ .Values.redis.service.clusterIP }} + {{- end }} +{{- else }} + type: {{ .Values.redis.service.type }} +{{- end }} selector: tier: airflow component: redis @@ -45,4 +52,7 @@ spec: protocol: TCP port: {{ .Values.ports.redisDB }} targetPort: {{ .Values.ports.redisDB }} + {{- if (and (eq .Values.redis.service.type "NodePort") (not (empty .Values.redis.service.nodePort))) }} + nodePort: {{ .Values.redis.service.nodePort }} + {{- end }} {{- end }} diff --git a/chart/values.schema.json b/chart/values.schema.json index 948f09f3b9a4d..d8b5de41c8eb8 100644 --- a/chart/values.schema.json +++ b/chart/values.schema.json @@ -7670,6 +7670,39 @@ "type": "integer", "default": 600 }, + "service": { + "description": "service configuration.", + "type": "object", + "additionalProperties": false, + "properties": { + "type": { + "description": "Service type.", + "enum": [ + "ClusterIP", + "NodePort", + "LoadBalancer" + ], + "type": "string", + "default": "ClusterIP" + }, + "clusterIP": { + "description": "If using `ClusterIP` service type, custom IP address can be specified.", + "type": [ + "string", + "null" + ], + "default": null + }, + "nodePort": { + "description": "If using `NodePort` service type, custom node port can be specified.", + "type": [ + "integer", + "null" + ], + "default": null + } + } + }, "persistence": { "description": "Persistence configuration.", "type": "object", diff --git a/chart/values.yaml b/chart/values.yaml index 7bfa733a905b4..0edb9f2bd7cd3 100644 --- a/chart/values.yaml +++ b/chart/values.yaml @@ -2378,6 +2378,14 @@ redis: # Annotations to add to worker kubernetes service account. 
annotations: {} + service: + # service type, default: ClusterIP + type: "ClusterIP" + # If using ClusterIP service type, custom IP address can be specified + clusterIP: + # If using NodePort service type, custom node port can be specified + nodePort: + persistence: # Enable persistent volumes enabled: true diff --git a/helm_tests/other/test_redis.py b/helm_tests/other/test_redis.py index a5a6f2099e4ab..8c44567420314 100644 --- a/helm_tests/other/test_redis.py +++ b/helm_tests/other/test_redis.py @@ -452,3 +452,44 @@ def test_overridden_automount_service_account_token(self): show_only=["templates/redis/redis-serviceaccount.yaml"], ) assert jmespath.search("automountServiceAccountToken", docs[0]) is False + + +class TestRedisService: + """Tests redis service.""" + + @pytest.mark.parametrize( + "redis_values, expected", + [ + ({"redis": {"service": {"type": "ClusterIP"}}}, "ClusterIP"), + ({"redis": {"service": {"type": "NodePort"}}}, "NodePort"), + ({"redis": {"service": {"type": "LoadBalancer"}}}, "LoadBalancer"), + ], + ) + def test_redis_service_type(self, redis_values, expected): + docs = render_chart( + values=redis_values, + show_only=["templates/redis/redis-service.yaml"], + ) + assert expected == jmespath.search("spec.type", docs[0]) + + def test_redis_service_nodeport(self): + docs = render_chart( + values={ + "redis": { + "service": {"type": "NodePort", "nodePort": 11111}, + }, + }, + show_only=["templates/redis/redis-service.yaml"], + ) + assert 11111 == jmespath.search("spec.ports[0].nodePort", docs[0]) + + def test_redis_service_clusterIP(self): + docs = render_chart( + values={ + "redis": { + "service": {"type": "ClusterIP", "clusterIP": "127.0.0.1"}, + }, + }, + show_only=["templates/redis/redis-service.yaml"], + ) + assert "127.0.0.1" == jmespath.search("spec.clusterIP", docs[0]) From ea1432f315439a1908f0e6d0caa39656c38e664b Mon Sep 17 00:00:00 2001 From: Howard Yoo <32691630+howardyoo@users.noreply.github.com> Date: Tue, 1 Oct 2024 02:04:06 -0500 Subject: [PATCH 224/349] Support of host.name in OTEL metrics and usage of OTEL_RESOURCE_ATTRIBUTES in metrics (#42428) * fixes: 42425, and 42424 * fixed static type check failure --- airflow/metrics/otel_logger.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/airflow/metrics/otel_logger.py b/airflow/metrics/otel_logger.py index 14080eb2d8313..6d7d6e8fffa1c 100644 --- a/airflow/metrics/otel_logger.py +++ b/airflow/metrics/otel_logger.py @@ -28,7 +28,7 @@ from opentelemetry.metrics import Observation from opentelemetry.sdk.metrics import MeterProvider from opentelemetry.sdk.metrics._internal.export import ConsoleMetricExporter, PeriodicExportingMetricReader -from opentelemetry.sdk.resources import SERVICE_NAME, Resource +from opentelemetry.sdk.resources import HOST_NAME, SERVICE_NAME, Resource from airflow.configuration import conf from airflow.exceptions import AirflowProviderDeprecationWarning @@ -40,6 +40,7 @@ get_validator, stat_name_otel_handler, ) +from airflow.utils.net import get_hostname if TYPE_CHECKING: from opentelemetry.metrics import Instrument @@ -410,7 +411,7 @@ def get_otel_logger(cls) -> SafeOtelLogger: debug = conf.getboolean("metrics", "otel_debugging_on") service_name = conf.get("metrics", "otel_service") - resource = Resource(attributes={SERVICE_NAME: service_name}) + resource = Resource.create(attributes={HOST_NAME: get_hostname(), SERVICE_NAME: service_name}) protocol = "https" if ssl_active else "http" endpoint = f"{protocol}://{host}:{port}/v1/metrics" From 
c83d50961c881be68f74536fbaccc64aa3af382d Mon Sep 17 00:00:00 2001 From: Pierre Jeambrun Date: Tue, 1 Oct 2024 15:42:37 +0800 Subject: [PATCH 225/349] Update fastapi operation ids (#42588) * Update operation id automatically * Cherry pick Brent change --------- Co-authored-by: Brent Bovenzi --- airflow/api_fastapi/openapi/v1-generated.yaml | 10 +- airflow/api_fastapi/views/public/__init__.py | 5 +- airflow/api_fastapi/views/public/dags.py | 5 +- airflow/api_fastapi/views/router.py | 93 +++++++++++++++++++ airflow/api_fastapi/views/ui/__init__.py | 5 +- airflow/api_fastapi/views/ui/assets.py | 5 +- airflow/ui/openapi-gen/queries/common.ts | 44 ++++----- airflow/ui/openapi-gen/queries/prefetch.ts | 15 ++- airflow/ui/openapi-gen/queries/queries.ts | 32 +++---- airflow/ui/openapi-gen/queries/suspense.ts | 20 ++-- .../ui/openapi-gen/requests/services.gen.ts | 40 ++++---- airflow/ui/openapi-gen/requests/types.gen.ts | 24 ++--- airflow/ui/package.json | 2 +- airflow/ui/src/App.test.tsx | 7 +- airflow/ui/src/pages/DagsList.tsx | 4 +- 15 files changed, 193 insertions(+), 118 deletions(-) create mode 100644 airflow/api_fastapi/views/router.py diff --git a/airflow/api_fastapi/openapi/v1-generated.yaml b/airflow/api_fastapi/openapi/v1-generated.yaml index a38a1021890d6..b08ef42c16df1 100644 --- a/airflow/api_fastapi/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/openapi/v1-generated.yaml @@ -12,7 +12,7 @@ paths: tags: - Asset summary: Next Run Assets - operationId: next_run_assets_ui_next_run_datasets__dag_id__get + operationId: next_run_assets parameters: - name: dag_id in: path @@ -27,7 +27,7 @@ paths: application/json: schema: type: object - title: Response Next Run Assets Ui Next Run Datasets Dag Id Get + title: Response Next Run Assets '422': description: Validation Error content: @@ -40,7 +40,7 @@ paths: - DAG summary: Get Dags description: Get all DAGs. - operationId: get_dags_public_dags_get + operationId: get_dags parameters: - name: limit in: query @@ -136,7 +136,7 @@ paths: - DAG summary: Patch Dags description: Patch multiple DAGs. - operationId: patch_dags_public_dags_patch + operationId: patch_dags parameters: - name: update_mask in: query @@ -258,7 +258,7 @@ paths: - DAG summary: Patch Dag description: Patch the specific DAG. 
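The change above in otel_logger.py from Resource(attributes=...) to Resource.create(...) is what allows OTEL_RESOURCE_ATTRIBUTES to contribute resource attributes, because Resource.create() runs the environment-variable resource detector while the plain constructor only stores the attributes it is given. A small sketch of that behaviour, with example attribute values only:

    import os

    from opentelemetry.sdk.resources import HOST_NAME, SERVICE_NAME, Resource

    # Example only: extra attributes supplied through the standard OTel env var.
    os.environ["OTEL_RESOURCE_ATTRIBUTES"] = "deployment.environment=staging,team=data-platform"

    # Resource.create() merges detected attributes (including those from
    # OTEL_RESOURCE_ATTRIBUTES) with the explicitly passed ones.
    resource = Resource.create(
        attributes={HOST_NAME: "worker-1.example.internal", SERVICE_NAME: "Airflow"}
    )
    print(resource.attributes)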
- operationId: patch_dag_public_dags__dag_id__patch + operationId: patch_dag parameters: - name: dag_id in: path diff --git a/airflow/api_fastapi/views/public/__init__.py b/airflow/api_fastapi/views/public/__init__.py index b6466536c3359..1c2511fc82ac2 100644 --- a/airflow/api_fastapi/views/public/__init__.py +++ b/airflow/api_fastapi/views/public/__init__.py @@ -17,11 +17,10 @@ from __future__ import annotations -from fastapi import APIRouter - from airflow.api_fastapi.views.public.dags import dags_router +from airflow.api_fastapi.views.router import AirflowRouter -public_router = APIRouter(prefix="/public") +public_router = AirflowRouter(prefix="/public") public_router.include_router(dags_router) diff --git a/airflow/api_fastapi/views/public/dags.py b/airflow/api_fastapi/views/public/dags.py index a6c25d6568c1e..3761d593d2fd0 100644 --- a/airflow/api_fastapi/views/public/dags.py +++ b/airflow/api_fastapi/views/public/dags.py @@ -17,7 +17,7 @@ from __future__ import annotations -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import Depends, HTTPException, Query from sqlalchemy import update from sqlalchemy.orm import Session from typing_extensions import Annotated @@ -42,9 +42,10 @@ SortParam, ) from airflow.api_fastapi.serializers.dags import DAGCollectionResponse, DAGPatchBody, DAGResponse +from airflow.api_fastapi.views.router import AirflowRouter from airflow.models import DagModel -dags_router = APIRouter(tags=["DAG"]) +dags_router = AirflowRouter(tags=["DAG"]) @dags_router.get("/dags") diff --git a/airflow/api_fastapi/views/router.py b/airflow/api_fastapi/views/router.py new file mode 100644 index 0000000000000..5bf07e0fe834a --- /dev/null +++ b/airflow/api_fastapi/views/router.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from enum import Enum +from typing import Any, Callable, Sequence + +from fastapi import APIRouter, params +from fastapi.datastructures import Default +from fastapi.routing import APIRoute +from fastapi.types import DecoratedCallable, IncEx +from fastapi.utils import generate_unique_id +from starlette.responses import JSONResponse, Response +from starlette.routing import BaseRoute + + +class AirflowRouter(APIRouter): + """Extends the FastAPI default router.""" + + def api_route( + self, + path: str, + *, + response_model: Any = Default(None), + status_code: int | None = None, + tags: list[str | Enum] | None = None, + dependencies: Sequence[params.Depends] | None = None, + summary: str | None = None, + description: str | None = None, + response_description: str = "Successful Response", + responses: dict[int | str, dict[str, Any]] | None = None, + deprecated: bool | None = None, + methods: list[str] | None = None, + operation_id: str | None = None, + response_model_include: IncEx | None = None, + response_model_exclude: IncEx | None = None, + response_model_by_alias: bool = True, + response_model_exclude_unset: bool = False, + response_model_exclude_defaults: bool = False, + response_model_exclude_none: bool = False, + include_in_schema: bool = True, + response_class: type[Response] = Default(JSONResponse), + name: str | None = None, + callbacks: list[BaseRoute] | None = None, + openapi_extra: dict[str, Any] | None = None, + generate_unique_id_function: Callable[[APIRoute], str] = Default(generate_unique_id), + ) -> Callable[[DecoratedCallable], DecoratedCallable]: + def decorator(func: DecoratedCallable) -> DecoratedCallable: + self.add_api_route( + path, + func, + response_model=response_model, + status_code=status_code, + tags=tags, + dependencies=dependencies, + summary=summary, + description=description, + response_description=response_description, + responses=responses, + deprecated=deprecated, + methods=methods, + operation_id=operation_id or func.__name__, + response_model_include=response_model_include, + response_model_exclude=response_model_exclude, + response_model_by_alias=response_model_by_alias, + response_model_exclude_unset=response_model_exclude_unset, + response_model_exclude_defaults=response_model_exclude_defaults, + response_model_exclude_none=response_model_exclude_none, + include_in_schema=include_in_schema, + response_class=response_class, + name=name, + callbacks=callbacks, + openapi_extra=openapi_extra, + generate_unique_id_function=generate_unique_id_function, + ) + return func + + return decorator diff --git a/airflow/api_fastapi/views/ui/__init__.py b/airflow/api_fastapi/views/ui/__init__.py index edba930c3d1d1..8495ac5e5e6a4 100644 --- a/airflow/api_fastapi/views/ui/__init__.py +++ b/airflow/api_fastapi/views/ui/__init__.py @@ -16,10 +16,9 @@ # under the License. 
from __future__ import annotations -from fastapi import APIRouter - +from airflow.api_fastapi.views.router import AirflowRouter from airflow.api_fastapi.views.ui.assets import assets_router -ui_router = APIRouter(prefix="/ui") +ui_router = AirflowRouter(prefix="/ui") ui_router.include_router(assets_router) diff --git a/airflow/api_fastapi/views/ui/assets.py b/airflow/api_fastapi/views/ui/assets.py index 739c7d64af439..01cc9fd1cfbff 100644 --- a/airflow/api_fastapi/views/ui/assets.py +++ b/airflow/api_fastapi/views/ui/assets.py @@ -17,16 +17,17 @@ from __future__ import annotations -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import Depends, HTTPException, Request from sqlalchemy import and_, func, select from sqlalchemy.orm import Session from typing_extensions import Annotated from airflow.api_fastapi.db.common import get_session +from airflow.api_fastapi.views.router import AirflowRouter from airflow.models import DagModel from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel, DagScheduleAssetReference -assets_router = APIRouter(tags=["Asset"]) +assets_router = AirflowRouter(tags=["Asset"]) @assets_router.get("/next_run_datasets/{dag_id}", include_in_schema=False) diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index b1508c86c0c4b..96e49cc6d7673 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -4,37 +4,31 @@ import { UseQueryResult } from "@tanstack/react-query"; import { AssetService, DagService } from "../requests/services.gen"; import { DagRunState } from "../requests/types.gen"; -export type AssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetDefaultResponse = - Awaited< - ReturnType - >; -export type AssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetQueryResult< - TData = AssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetDefaultResponse, +export type AssetServiceNextRunAssetsDefaultResponse = Awaited< + ReturnType +>; +export type AssetServiceNextRunAssetsQueryResult< + TData = AssetServiceNextRunAssetsDefaultResponse, TError = unknown, > = UseQueryResult; -export const useAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetKey = - "AssetServiceNextRunAssetsUiNextRunDatasetsDagIdGet"; -export const UseAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetKeyFn = ( +export const useAssetServiceNextRunAssetsKey = "AssetServiceNextRunAssets"; +export const UseAssetServiceNextRunAssetsKeyFn = ( { dagId, }: { dagId: string; }, queryKey?: Array, -) => [ - useAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetKey, - ...(queryKey ?? [{ dagId }]), -]; -export type DagServiceGetDagsPublicDagsGetDefaultResponse = Awaited< - ReturnType +) => [useAssetServiceNextRunAssetsKey, ...(queryKey ?? 
[{ dagId }])]; +export type DagServiceGetDagsDefaultResponse = Awaited< + ReturnType >; -export type DagServiceGetDagsPublicDagsGetQueryResult< - TData = DagServiceGetDagsPublicDagsGetDefaultResponse, +export type DagServiceGetDagsQueryResult< + TData = DagServiceGetDagsDefaultResponse, TError = unknown, > = UseQueryResult; -export const useDagServiceGetDagsPublicDagsGetKey = - "DagServiceGetDagsPublicDagsGet"; -export const UseDagServiceGetDagsPublicDagsGetKeyFn = ( +export const useDagServiceGetDagsKey = "DagServiceGetDags"; +export const UseDagServiceGetDagsKeyFn = ( { dagDisplayNamePattern, dagIdPattern, @@ -60,7 +54,7 @@ export const UseDagServiceGetDagsPublicDagsGetKeyFn = ( } = {}, queryKey?: Array, ) => [ - useDagServiceGetDagsPublicDagsGetKey, + useDagServiceGetDagsKey, ...(queryKey ?? [ { dagDisplayNamePattern, @@ -76,9 +70,9 @@ export const UseDagServiceGetDagsPublicDagsGetKeyFn = ( }, ]), ]; -export type DagServicePatchDagsPublicDagsPatchMutationResult = Awaited< - ReturnType +export type DagServicePatchDagsMutationResult = Awaited< + ReturnType >; -export type DagServicePatchDagPublicDagsDagIdPatchMutationResult = Awaited< - ReturnType +export type DagServicePatchDagMutationResult = Awaited< + ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow/ui/openapi-gen/queries/prefetch.ts index 7de7282a9bd01..95c2c7b737348 100644 --- a/airflow/ui/openapi-gen/queries/prefetch.ts +++ b/airflow/ui/openapi-gen/queries/prefetch.ts @@ -12,7 +12,7 @@ import * as Common from "./common"; * @returns unknown Successful Response * @throws ApiError */ -export const prefetchUseAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGet = ( +export const prefetchUseAssetServiceNextRunAssets = ( queryClient: QueryClient, { dagId, @@ -21,11 +21,8 @@ export const prefetchUseAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGet = ( }, ) => queryClient.prefetchQuery({ - queryKey: Common.UseAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetKeyFn( - { dagId }, - ), - queryFn: () => - AssetService.nextRunAssetsUiNextRunDatasetsDagIdGet({ dagId }), + queryKey: Common.UseAssetServiceNextRunAssetsKeyFn({ dagId }), + queryFn: () => AssetService.nextRunAssets({ dagId }), }); /** * Get Dags @@ -44,7 +41,7 @@ export const prefetchUseAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGet = ( * @returns DAGCollectionResponse Successful Response * @throws ApiError */ -export const prefetchUseDagServiceGetDagsPublicDagsGet = ( +export const prefetchUseDagServiceGetDags = ( queryClient: QueryClient, { dagDisplayNamePattern, @@ -71,7 +68,7 @@ export const prefetchUseDagServiceGetDagsPublicDagsGet = ( } = {}, ) => queryClient.prefetchQuery({ - queryKey: Common.UseDagServiceGetDagsPublicDagsGetKeyFn({ + queryKey: Common.UseDagServiceGetDagsKeyFn({ dagDisplayNamePattern, dagIdPattern, lastDagRunState, @@ -84,7 +81,7 @@ export const prefetchUseDagServiceGetDagsPublicDagsGet = ( tags, }), queryFn: () => - DagService.getDagsPublicDagsGet({ + DagService.getDags({ dagDisplayNamePattern, dagIdPattern, lastDagRunState, diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index 5eda2a3d0e4d2..985bf952e3eb3 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -17,8 +17,8 @@ import * as Common from "./common"; * @returns unknown Successful Response * @throws ApiError */ -export const useAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGet = < - TData = Common.AssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetDefaultResponse, 
+export const useAssetServiceNextRunAssets = < + TData = Common.AssetServiceNextRunAssetsDefaultResponse, TError = unknown, TQueryKey extends Array = unknown[], >( @@ -31,12 +31,8 @@ export const useAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGet = < options?: Omit, "queryKey" | "queryFn">, ) => useQuery({ - queryKey: Common.UseAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetKeyFn( - { dagId }, - queryKey, - ), - queryFn: () => - AssetService.nextRunAssetsUiNextRunDatasetsDagIdGet({ dagId }) as TData, + queryKey: Common.UseAssetServiceNextRunAssetsKeyFn({ dagId }, queryKey), + queryFn: () => AssetService.nextRunAssets({ dagId }) as TData, ...options, }); /** @@ -56,8 +52,8 @@ export const useAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGet = < * @returns DAGCollectionResponse Successful Response * @throws ApiError */ -export const useDagServiceGetDagsPublicDagsGet = < - TData = Common.DagServiceGetDagsPublicDagsGetDefaultResponse, +export const useDagServiceGetDags = < + TData = Common.DagServiceGetDagsDefaultResponse, TError = unknown, TQueryKey extends Array = unknown[], >( @@ -88,7 +84,7 @@ export const useDagServiceGetDagsPublicDagsGet = < options?: Omit, "queryKey" | "queryFn">, ) => useQuery({ - queryKey: Common.UseDagServiceGetDagsPublicDagsGetKeyFn( + queryKey: Common.UseDagServiceGetDagsKeyFn( { dagDisplayNamePattern, dagIdPattern, @@ -104,7 +100,7 @@ export const useDagServiceGetDagsPublicDagsGet = < queryKey, ), queryFn: () => - DagService.getDagsPublicDagsGet({ + DagService.getDags({ dagDisplayNamePattern, dagIdPattern, lastDagRunState, @@ -135,8 +131,8 @@ export const useDagServiceGetDagsPublicDagsGet = < * @returns DAGCollectionResponse Successful Response * @throws ApiError */ -export const useDagServicePatchDagsPublicDagsPatch = < - TData = Common.DagServicePatchDagsPublicDagsPatchMutationResult, +export const useDagServicePatchDags = < + TData = Common.DagServicePatchDagsMutationResult, TError = unknown, TContext = unknown, >( @@ -190,7 +186,7 @@ export const useDagServicePatchDagsPublicDagsPatch = < tags, updateMask, }) => - DagService.patchDagsPublicDagsPatch({ + DagService.patchDags({ dagIdPattern, lastDagRunState, limit, @@ -214,8 +210,8 @@ export const useDagServicePatchDagsPublicDagsPatch = < * @returns DAGResponse Successful Response * @throws ApiError */ -export const useDagServicePatchDagPublicDagsDagIdPatch = < - TData = Common.DagServicePatchDagPublicDagsDagIdPatchMutationResult, +export const useDagServicePatchDag = < + TData = Common.DagServicePatchDagMutationResult, TError = unknown, TContext = unknown, >( @@ -244,7 +240,7 @@ export const useDagServicePatchDagPublicDagsDagIdPatch = < TContext >({ mutationFn: ({ dagId, requestBody, updateMask }) => - DagService.patchDagPublicDagsDagIdPatch({ + DagService.patchDag({ dagId, requestBody, updateMask, diff --git a/airflow/ui/openapi-gen/queries/suspense.ts b/airflow/ui/openapi-gen/queries/suspense.ts index 18dba7acb4b5b..dc8b99dfb2188 100644 --- a/airflow/ui/openapi-gen/queries/suspense.ts +++ b/airflow/ui/openapi-gen/queries/suspense.ts @@ -12,8 +12,8 @@ import * as Common from "./common"; * @returns unknown Successful Response * @throws ApiError */ -export const useAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetSuspense = < - TData = Common.AssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetDefaultResponse, +export const useAssetServiceNextRunAssetsSuspense = < + TData = Common.AssetServiceNextRunAssetsDefaultResponse, TError = unknown, TQueryKey extends Array = unknown[], >( @@ -26,12 +26,8 @@ export 
const useAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetSuspense = < options?: Omit, "queryKey" | "queryFn">, ) => useSuspenseQuery({ - queryKey: Common.UseAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetKeyFn( - { dagId }, - queryKey, - ), - queryFn: () => - AssetService.nextRunAssetsUiNextRunDatasetsDagIdGet({ dagId }) as TData, + queryKey: Common.UseAssetServiceNextRunAssetsKeyFn({ dagId }, queryKey), + queryFn: () => AssetService.nextRunAssets({ dagId }) as TData, ...options, }); /** @@ -51,8 +47,8 @@ export const useAssetServiceNextRunAssetsUiNextRunDatasetsDagIdGetSuspense = < * @returns DAGCollectionResponse Successful Response * @throws ApiError */ -export const useDagServiceGetDagsPublicDagsGetSuspense = < - TData = Common.DagServiceGetDagsPublicDagsGetDefaultResponse, +export const useDagServiceGetDagsSuspense = < + TData = Common.DagServiceGetDagsDefaultResponse, TError = unknown, TQueryKey extends Array = unknown[], >( @@ -83,7 +79,7 @@ export const useDagServiceGetDagsPublicDagsGetSuspense = < options?: Omit, "queryKey" | "queryFn">, ) => useSuspenseQuery({ - queryKey: Common.UseDagServiceGetDagsPublicDagsGetKeyFn( + queryKey: Common.UseDagServiceGetDagsKeyFn( { dagDisplayNamePattern, dagIdPattern, @@ -99,7 +95,7 @@ export const useDagServiceGetDagsPublicDagsGetSuspense = < queryKey, ), queryFn: () => - DagService.getDagsPublicDagsGet({ + DagService.getDags({ dagDisplayNamePattern, dagIdPattern, lastDagRunState, diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 7fb6306afbc67..be216bd534c61 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -3,14 +3,14 @@ import type { CancelablePromise } from "./core/CancelablePromise"; import { OpenAPI } from "./core/OpenAPI"; import { request as __request } from "./core/request"; import type { - NextRunAssetsUiNextRunDatasetsDagIdGetData, - NextRunAssetsUiNextRunDatasetsDagIdGetResponse, - GetDagsPublicDagsGetData, - GetDagsPublicDagsGetResponse, - PatchDagsPublicDagsPatchData, - PatchDagsPublicDagsPatchResponse, - PatchDagPublicDagsDagIdPatchData, - PatchDagPublicDagsDagIdPatchResponse, + NextRunAssetsData, + NextRunAssetsResponse, + GetDagsData, + GetDagsResponse, + PatchDagsData, + PatchDagsResponse, + PatchDagData, + PatchDagResponse, } from "./types.gen"; export class AssetService { @@ -21,9 +21,9 @@ export class AssetService { * @returns unknown Successful Response * @throws ApiError */ - public static nextRunAssetsUiNextRunDatasetsDagIdGet( - data: NextRunAssetsUiNextRunDatasetsDagIdGetData, - ): CancelablePromise { + public static nextRunAssets( + data: NextRunAssetsData, + ): CancelablePromise { return __request(OpenAPI, { method: "GET", url: "/ui/next_run_datasets/{dag_id}", @@ -55,9 +55,9 @@ export class DagService { * @returns DAGCollectionResponse Successful Response * @throws ApiError */ - public static getDagsPublicDagsGet( - data: GetDagsPublicDagsGetData = {}, - ): CancelablePromise { + public static getDags( + data: GetDagsData = {}, + ): CancelablePromise { return __request(OpenAPI, { method: "GET", url: "/public/dags", @@ -96,9 +96,9 @@ export class DagService { * @returns DAGCollectionResponse Successful Response * @throws ApiError */ - public static patchDagsPublicDagsPatch( - data: PatchDagsPublicDagsPatchData, - ): CancelablePromise { + public static patchDags( + data: PatchDagsData, + ): CancelablePromise { return __request(OpenAPI, { method: "PATCH", url: "/public/dags", @@ 
-135,9 +135,9 @@ export class DagService { * @returns DAGResponse Successful Response * @throws ApiError */ - public static patchDagPublicDagsDagIdPatch( - data: PatchDagPublicDagsDagIdPatchData, - ): CancelablePromise { + public static patchDag( + data: PatchDagData, + ): CancelablePromise { return __request(OpenAPI, { method: "PATCH", url: "/public/dags/{dag_id}", diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index 0fe7134ba8c31..e1db8310a1dc1 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -88,15 +88,15 @@ export type ValidationError = { type: string; }; -export type NextRunAssetsUiNextRunDatasetsDagIdGetData = { +export type NextRunAssetsData = { dagId: string; }; -export type NextRunAssetsUiNextRunDatasetsDagIdGetResponse = { +export type NextRunAssetsResponse = { [key: string]: unknown; }; -export type GetDagsPublicDagsGetData = { +export type GetDagsData = { dagDisplayNamePattern?: string | null; dagIdPattern?: string | null; lastDagRunState?: DagRunState | null; @@ -109,9 +109,9 @@ export type GetDagsPublicDagsGetData = { tags?: Array; }; -export type GetDagsPublicDagsGetResponse = DAGCollectionResponse; +export type GetDagsResponse = DAGCollectionResponse; -export type PatchDagsPublicDagsPatchData = { +export type PatchDagsData = { dagIdPattern?: string | null; lastDagRunState?: DagRunState | null; limit?: number; @@ -124,20 +124,20 @@ export type PatchDagsPublicDagsPatchData = { updateMask?: Array | null; }; -export type PatchDagsPublicDagsPatchResponse = DAGCollectionResponse; +export type PatchDagsResponse = DAGCollectionResponse; -export type PatchDagPublicDagsDagIdPatchData = { +export type PatchDagData = { dagId: string; requestBody: DAGPatchBody; updateMask?: Array | null; }; -export type PatchDagPublicDagsDagIdPatchResponse = DAGResponse; +export type PatchDagResponse = DAGResponse; export type $OpenApiTs = { "/ui/next_run_datasets/{dag_id}": { get: { - req: NextRunAssetsUiNextRunDatasetsDagIdGetData; + req: NextRunAssetsData; res: { /** * Successful Response @@ -154,7 +154,7 @@ export type $OpenApiTs = { }; "/public/dags": { get: { - req: GetDagsPublicDagsGetData; + req: GetDagsData; res: { /** * Successful Response @@ -167,7 +167,7 @@ export type $OpenApiTs = { }; }; patch: { - req: PatchDagsPublicDagsPatchData; + req: PatchDagsData; res: { /** * Successful Response @@ -198,7 +198,7 @@ export type $OpenApiTs = { }; "/public/dags/{dag_id}": { patch: { - req: PatchDagPublicDagsDagIdPatchData; + req: PatchDagData; res: { /** * Successful Response diff --git a/airflow/ui/package.json b/airflow/ui/package.json index c7d79f792a59e..1f77334074f03 100644 --- a/airflow/ui/package.json +++ b/airflow/ui/package.json @@ -11,7 +11,7 @@ "lint:fix": "eslint --fix && tsc --p tsconfig.app.json", "format": "pnpm prettier --write .", "preview": "vite preview", - "codegen": "openapi-rq -i \"../api_fastapi/openapi/v1-generated.yaml\" -c axios --format prettier -o openapi-gen", + "codegen": "openapi-rq -i \"../api_fastapi/openapi/v1-generated.yaml\" -c axios --format prettier -o openapi-gen --operationId", "test": "vitest run", "coverage": "vitest run --coverage" }, diff --git a/airflow/ui/src/App.test.tsx b/airflow/ui/src/App.test.tsx index d34cf016befdb..5efcf90f1a05d 100644 --- a/airflow/ui/src/App.test.tsx +++ b/airflow/ui/src/App.test.tsx @@ -105,10 +105,9 @@ beforeEach(() => { isLoading: false, } as QueryObserverSuccessResult; - vi.spyOn( - openapiQueriesModule, - 
"useDagServiceGetDagsPublicDagsGet", - ).mockImplementation(() => returnValue); + vi.spyOn(openapiQueriesModule, "useDagServiceGetDags").mockImplementation( + () => returnValue, + ); }); afterEach(() => { diff --git a/airflow/ui/src/pages/DagsList.tsx b/airflow/ui/src/pages/DagsList.tsx index fe764f117e45d..ab480d2cbabdb 100644 --- a/airflow/ui/src/pages/DagsList.tsx +++ b/airflow/ui/src/pages/DagsList.tsx @@ -30,7 +30,7 @@ import { Select as ReactSelect } from "chakra-react-select"; import { type ChangeEventHandler, useCallback } from "react"; import { useSearchParams } from "react-router-dom"; -import { useDagServiceGetDagsPublicDagsGet } from "openapi/queries"; +import { useDagServiceGetDags } from "openapi/queries"; import type { DAGResponse } from "openapi/requests/types.gen"; import { DataTable } from "../components/DataTable"; @@ -93,7 +93,7 @@ export const DagsList = ({ cardView = false }) => { const [sort] = sorting; const orderBy = sort ? `${sort.desc ? "-" : ""}${sort.id}` : undefined; - const { data, isLoading } = useDagServiceGetDagsPublicDagsGet({ + const { data, isLoading } = useDagServiceGetDags({ limit: pagination.pageSize, offset: pagination.pageIndex * pagination.pageSize, onlyActive: true, From 1045c8009b1f3db076a8728f5bce0ff70fb66bc6 Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Tue, 1 Oct 2024 01:02:38 -0700 Subject: [PATCH 226/349] Limit build-images workflow to main and v2-10 branches (#42601) There is no need to run image builds for PRs to old branches. --- .github/workflows/build-images.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/build-images.yml b/.github/workflows/build-images.yml index 1256fd2f0da6e..abf966faede02 100644 --- a/.github/workflows/build-images.yml +++ b/.github/workflows/build-images.yml @@ -21,6 +21,10 @@ run-name: > Build images for ${{ github.event.pull_request.title }} ${{ github.event.pull_request._links.html.href }} on: # yamllint disable-line rule:truthy pull_request_target: + branches: + - main + - v2-10-stable + - v2-10-test permissions: # all other permissions are set to none contents: read From 807f267081ce944bf0568d82bbc3617bf0921fd8 Mon Sep 17 00:00:00 2001 From: Jakub Dardzinski Date: Tue, 1 Oct 2024 14:16:03 +0200 Subject: [PATCH 227/349] openlineage: add unit test for listener hooks on dag run state changes. (#42554) openlineage: cover task instance failure in unit tests. 
Signed-off-by: Jakub Dardzinski --- tests/dags/test_openlineage_execution.py | 12 +++++- .../openlineage/plugins/test_execution.py | 11 +++++ .../openlineage/plugins/test_listener.py | 43 +++++++++++++++++++ 3 files changed, 65 insertions(+), 1 deletion(-) diff --git a/tests/dags/test_openlineage_execution.py b/tests/dags/test_openlineage_execution.py index 475e43ef6ac2e..f8db91611e848 100644 --- a/tests/dags/test_openlineage_execution.py +++ b/tests/dags/test_openlineage_execution.py @@ -27,13 +27,16 @@ class OpenLineageExecutionOperator(BaseOperator): - def __init__(self, *, stall_amount=0, **kwargs) -> None: + def __init__(self, *, stall_amount=0, fail=False, **kwargs) -> None: super().__init__(**kwargs) self.stall_amount = stall_amount + self.fail = fail def execute(self, context): self.log.error("STALL AMOUNT %s", self.stall_amount) time.sleep(1) + if self.fail: + raise Exception("Failed") def get_openlineage_facets_on_start(self): return OperatorLineage(inputs=[Dataset(namespace="test", name="on-start")]) @@ -43,6 +46,11 @@ def get_openlineage_facets_on_complete(self, task_instance): time.sleep(self.stall_amount) return OperatorLineage(inputs=[Dataset(namespace="test", name="on-complete")]) + def get_openlineage_facets_on_failure(self, task_instance): + self.log.error("STALL AMOUNT %s", self.stall_amount) + time.sleep(self.stall_amount) + return OperatorLineage(inputs=[Dataset(namespace="test", name="on-failure")]) + with DAG( dag_id="test_openlineage_execution", @@ -57,3 +65,5 @@ def get_openlineage_facets_on_complete(self, task_instance): mid_stall = OpenLineageExecutionOperator(task_id="execute_mid_stall", stall_amount=15) long_stall = OpenLineageExecutionOperator(task_id="execute_long_stall", stall_amount=30) + + fail = OpenLineageExecutionOperator(task_id="execute_fail", fail=True) diff --git a/tests/providers/openlineage/plugins/test_execution.py b/tests/providers/openlineage/plugins/test_execution.py index 3adaaac582dd7..8c0bdd55a1f96 100644 --- a/tests/providers/openlineage/plugins/test_execution.py +++ b/tests/providers/openlineage/plugins/test_execution.py @@ -124,6 +124,17 @@ def test_not_stalled_task_emits_proper_lineage(self): assert has_value_in_events(events, ["inputs", "name"], "on-start") assert has_value_in_events(events, ["inputs", "name"], "on-complete") + @pytest.mark.db_test + @conf_vars({("openlineage", "transport"): f'{{"type": "file", "log_file_path": "{listener_path}"}}'}) + def test_not_stalled_failing_task_emits_proper_lineage(self): + task_name = "execute_fail" + run_id = "test_failure" + self.setup_job(task_name, run_id) + + events = get_sorted_events(tmp_dir) + assert has_value_in_events(events, ["inputs", "name"], "on-start") + assert has_value_in_events(events, ["inputs", "name"], "on-failure") + @conf_vars( { ("openlineage", "transport"): f'{{"type": "file", "log_file_path": "{listener_path}"}}', diff --git a/tests/providers/openlineage/plugins/test_listener.py b/tests/providers/openlineage/plugins/test_listener.py index 92467a58af8c5..57c0134f79d82 100644 --- a/tests/providers/openlineage/plugins/test_listener.py +++ b/tests/providers/openlineage/plugins/test_listener.py @@ -606,6 +606,49 @@ def test_listener_on_dag_run_state_changes_configure_process_pool_size(mock_exec mock_executor.return_value.submit.assert_called_once() +class MockExecutor: + def __init__(self, *args, **kwargs): + self.submitted = False + self.succeeded = False + self.result = None + + def submit(self, fn, /, *args, **kwargs): + self.submitted = True + try: + fn(*args, 
**kwargs) + self.succeeded = True + except Exception: + pass + return MagicMock() + + def shutdown(self, *args, **kwargs): + print("Shutting down") + + +@pytest.mark.parametrize( + ("method", "dag_run_state"), + [ + ("on_dag_run_running", DagRunState.RUNNING), + ("on_dag_run_success", DagRunState.SUCCESS), + ("on_dag_run_failed", DagRunState.FAILED), + ], +) +@patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.emit") +def test_listener_on_dag_run_state_changes(mock_emit, method, dag_run_state, create_task_instance): + mock_executor = MockExecutor() + ti = create_task_instance(dag_id="dag", task_id="op") + # Change the state explicitly to set end_date following the logic in the method + ti.dag_run.set_state(dag_run_state) + with mock.patch( + "airflow.providers.openlineage.plugins.listener.ProcessPoolExecutor", return_value=mock_executor + ): + listener = OpenLineageListener() + getattr(listener, method)(ti.dag_run, None) + assert mock_executor.submitted is True + assert mock_executor.succeeded is True + mock_emit.assert_called_once() + + def test_listener_logs_failed_serialization(): listener = OpenLineageListener() callback_future = Future() From 0c49fbeffbdd6e46a967417ff58c79a187110c7f Mon Sep 17 00:00:00 2001 From: Vincent <97131062+vincbeck@users.noreply.github.com> Date: Tue, 1 Oct 2024 09:59:11 -0400 Subject: [PATCH 228/349] Update Rest API tests to no longer rely on FAB auth manager. Move tests specific to FAB permissions to FAB provider (#42523) --- .../managers/simple/simple_auth_manager.py | 7 +- airflow/auth/managers/simple/user.py | 6 +- .../0034_3_0_0_update_user_id_type.py | 52 +++ ..._3_0_0_add_name_field_to_dataset_model.py} | 4 +- airflow/models/dagrun.py | 2 +- airflow/models/taskinstance.py | 2 +- .../api/auth/backend/basic_auth.py | 4 +- .../api/auth/backend/kerberos_auth.py | 2 +- .../fab/auth_manager/models/anonymous_user.py | 2 +- docs/apache-airflow/img/airflow_erd.sha256 | 2 +- docs/apache-airflow/img/airflow_erd.svg | 4 +- docs/apache-airflow/migrations-ref.rst | 5 +- tests/api_connexion/conftest.py | 11 +- .../endpoints/test_backfill_endpoint.py | 31 +- .../endpoints/test_config_endpoint.py | 12 +- .../endpoints/test_connection_endpoint.py | 17 +- .../endpoints/test_dag_endpoint.py | 100 +--- .../endpoints/test_dag_parsing.py | 16 +- .../endpoints/test_dag_run_endpoint.py | 154 +------ .../endpoints/test_dag_source_endpoint.py | 71 +-- .../endpoints/test_dag_stats_endpoint.py | 15 +- .../endpoints/test_dag_warning_endpoint.py | 33 +- .../endpoints/test_dataset_endpoint.py | 234 +--------- .../endpoints/test_event_log_endpoint.py | 80 +--- .../endpoints/test_extra_link_endpoint.py | 20 +- .../endpoints/test_import_error_endpoint.py | 170 +------ .../endpoints/test_log_endpoint.py | 9 +- .../test_mapped_task_instance_endpoint.py | 25 +- .../endpoints/test_plugin_endpoint.py | 12 +- .../endpoints/test_pool_endpoint.py | 17 +- .../endpoints/test_provider_endpoint.py | 12 +- .../endpoints/test_task_endpoint.py | 16 +- .../endpoints/test_task_instance_endpoint.py | 219 +-------- .../endpoints/test_variable_endpoint.py | 37 +- .../endpoints/test_xcom_endpoint.py | 74 +-- tests/api_connexion/test_auth.py | 188 ++------ tests/api_connexion/test_security.py | 8 +- .../api_endpoints/api_connexion_utils.py | 116 +++++ .../remote_user_api_auth_backend.py | 81 ++++ .../auth_manager/api_endpoints/test_auth.py | 176 ++++++++ .../api_endpoints/test_backfill_endpoint.py | 264 +++++++++++ .../auth_manager/api_endpoints}/test_cors.py | 35 +- 
.../api_endpoints/test_dag_endpoint.py | 252 +++++++++++ .../api_endpoints/test_dag_run_endpoint.py | 273 +++++++++++ .../api_endpoints/test_dag_source_endpoint.py | 144 ++++++ .../test_dag_warning_endpoint.py | 84 ++++ .../api_endpoints/test_dataset_endpoint.py | 327 ++++++++++++++ .../api_endpoints/test_event_log_endpoint.py | 151 +++++++ .../test_import_error_endpoint.py | 221 +++++++++ .../test_role_and_permission_endpoint.py | 22 +- .../test_role_and_permission_schema.py | 22 +- .../test_task_instance_endpoint.py | 427 ++++++++++++++++++ .../api_endpoints/test_user_endpoint.py | 15 +- .../api_endpoints/test_user_schema.py | 3 +- .../api_endpoints/test_variable_endpoint.py | 88 ++++ .../api_endpoints/test_xcom_endpoint.py | 230 ++++++++++ tests/providers/fab/auth_manager/conftest.py | 17 +- .../fab/auth_manager/test_security.py | 2 +- .../auth_manager/views/test_permissions.py | 2 +- .../fab/auth_manager/views/test_roles_list.py | 2 +- .../fab/auth_manager/views/test_user.py | 2 +- .../fab/auth_manager/views/test_user_edit.py | 2 +- .../fab/auth_manager/views/test_user_stats.py | 2 +- tests/test_utils/api_connexion_utils.py | 64 +-- .../remote_user_api_auth_backend.py | 32 +- .../www/views/test_views_custom_user_views.py | 5 +- tests/www/views/test_views_dagrun.py | 6 +- tests/www/views/test_views_home.py | 2 +- tests/www/views/test_views_tasks.py | 6 +- tests/www/views/test_views_variable.py | 2 +- 70 files changed, 3208 insertions(+), 1542 deletions(-) create mode 100644 airflow/migrations/versions/0034_3_0_0_update_user_id_type.py rename airflow/migrations/versions/{0034_3_0_0_add_name_field_to_dataset_model.py => 0035_3_0_0_add_name_field_to_dataset_model.py} (98%) create mode 100644 tests/providers/fab/auth_manager/api_endpoints/api_connexion_utils.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/remote_user_api_auth_backend.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_auth.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_backfill_endpoint.py rename tests/{api_connexion => providers/fab/auth_manager/api_endpoints}/test_cors.py (81%) create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_dag_endpoint.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_dataset_endpoint.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_event_log_endpoint.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_import_error_endpoint.py rename tests/{api_connexion/schemas => providers/fab/auth_manager/api_endpoints}/test_role_and_permission_schema.py (85%) create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_variable_endpoint.py create mode 100644 tests/providers/fab/auth_manager/api_endpoints/test_xcom_endpoint.py diff --git a/airflow/auth/managers/simple/simple_auth_manager.py b/airflow/auth/managers/simple/simple_auth_manager.py index 451068733667c..4a9639a998c46 100644 --- a/airflow/auth/managers/simple/simple_auth_manager.py +++ b/airflow/auth/managers/simple/simple_auth_manager.py @@ -221,7 +221,12 @@ def _is_authorized( user = self.get_user() if not user: 
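A minimal sketch of what this guard changes, assuming the SimpleAuthManagerUser model updated later in this same patch (the usernames here are made up for illustration):

from airflow.auth.managers.simple.user import SimpleAuthManagerUser

# A user created without a role now short-circuits _is_authorized() to False.
no_role_user = SimpleAuthManagerUser(username="guest", role=None)
assert no_role_user.get_role() is None

# A user with a role still resolves to a SimpleAuthManagerRole member via .upper().
admin_user = SimpleAuthManagerUser(username="ops", role="admin")
assert admin_user.get_role().upper() == "ADMIN"
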
return False - role_str = user.get_role().upper() + + user_role = user.get_role() + if not user_role: + return False + + role_str = user_role.upper() role = SimpleAuthManagerRole[role_str] if role == SimpleAuthManagerRole.ADMIN: return True diff --git a/airflow/auth/managers/simple/user.py b/airflow/auth/managers/simple/user.py index fa032f596ee44..f4591b0b1c751 100644 --- a/airflow/auth/managers/simple/user.py +++ b/airflow/auth/managers/simple/user.py @@ -24,10 +24,10 @@ class SimpleAuthManagerUser(BaseUser): User model for users managed by the simple auth manager. :param username: The username - :param role: The role associated to the user + :param role: The role associated to the user. If not provided, the user has no permission """ - def __init__(self, *, username: str, role: str) -> None: + def __init__(self, *, username: str, role: str | None) -> None: self.username = username self.role = role @@ -37,5 +37,5 @@ def get_id(self) -> str: def get_name(self) -> str: return self.username - def get_role(self): + def get_role(self) -> str | None: return self.role diff --git a/airflow/migrations/versions/0034_3_0_0_update_user_id_type.py b/airflow/migrations/versions/0034_3_0_0_update_user_id_type.py new file mode 100644 index 0000000000000..321a1e2bbafa8 --- /dev/null +++ b/airflow/migrations/versions/0034_3_0_0_update_user_id_type.py @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Update dag_run_note.user_id and task_instance_note.user_id columns to String. + +Revision ID: 44eabb1904b4 +Revises: 16cbcb1c8c36 +Create Date: 2024-09-27 09:57:29.830521 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
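As a rough illustration only, the upgrade below is approximately equivalent to the following non-batch Alembic calls (an assumption for readability; the batch form in the migration is what actually runs, and SQLite goes through a table rebuild instead of an in-place ALTER):

import sqlalchemy as sa
from alembic import op

# Widen the note tables' user_id columns so they can hold auth-manager usernames
# (strings) rather than FAB's integer user ids.
op.alter_column("dag_run_note", "user_id", type_=sa.String(length=128))
op.alter_column("task_instance_note", "user_id", type_=sa.String(length=128))
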
+revision = "44eabb1904b4" +down_revision = "16cbcb1c8c36" +branch_labels = None +depends_on = None +airflow_version = "3.0.0" + + +def upgrade(): + with op.batch_alter_table("dag_run_note") as batch_op: + batch_op.alter_column("user_id", type_=sa.String(length=128)) + with op.batch_alter_table("task_instance_note") as batch_op: + batch_op.alter_column("user_id", type_=sa.String(length=128)) + + +def downgrade(): + with op.batch_alter_table("dag_run_note") as batch_op: + batch_op.alter_column("user_id", type_=sa.Integer(), postgresql_using="user_id::integer") + with op.batch_alter_table("task_instance_note") as batch_op: + batch_op.alter_column("user_id", type_=sa.Integer(), postgresql_using="user_id::integer") diff --git a/airflow/migrations/versions/0034_3_0_0_add_name_field_to_dataset_model.py b/airflow/migrations/versions/0035_3_0_0_add_name_field_to_dataset_model.py similarity index 98% rename from airflow/migrations/versions/0034_3_0_0_add_name_field_to_dataset_model.py rename to airflow/migrations/versions/0035_3_0_0_add_name_field_to_dataset_model.py index 5c8aec69e9be9..6016dd9658908 100644 --- a/airflow/migrations/versions/0034_3_0_0_add_name_field_to_dataset_model.py +++ b/airflow/migrations/versions/0035_3_0_0_add_name_field_to_dataset_model.py @@ -30,7 +30,7 @@ also rename the one on DatasetAliasModel here for consistency. Revision ID: 0d9e73a75ee4 -Revises: 16cbcb1c8c36 +Revises: 44eabb1904b4 Create Date: 2024-08-13 09:45:32.213222 """ @@ -42,7 +42,7 @@ # revision identifiers, used by Alembic. revision = "0d9e73a75ee4" -down_revision = "16cbcb1c8c36" +down_revision = "44eabb1904b4" branch_labels = None depends_on = None airflow_version = "3.0.0" diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 5d53e51763dff..4928c7fcbd8f7 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -1687,7 +1687,7 @@ class DagRunNote(Base): __tablename__ = "dag_run_note" - user_id = Column(Integer, nullable=True) + user_id = Column(String(128), nullable=True) dag_run_id = Column(Integer, primary_key=True, nullable=False) content = Column(String(1000).with_variant(Text(1000), "mysql")) created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index b19e65486307d..333a4cad91cbe 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -4002,7 +4002,7 @@ class TaskInstanceNote(TaskInstanceDependencies): __tablename__ = "task_instance_note" - user_id = Column(Integer, nullable=True) + user_id = Column(String(128), nullable=True) task_id = Column(StringID(), primary_key=True, nullable=False) dag_id = Column(StringID(), primary_key=True, nullable=False) run_id = Column(StringID(), primary_key=True, nullable=False) diff --git a/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py b/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py index 3a0328fe9962c..7b50338733453 100644 --- a/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py +++ b/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py @@ -62,9 +62,7 @@ def requires_authentication(function: T): @wraps(function) def decorated(*args, **kwargs): - if auth_current_user() is not None or current_app.appbuilder.get_app.config.get( - "AUTH_ROLE_PUBLIC", None - ): + if auth_current_user() is not None or current_app.config.get("AUTH_ROLE_PUBLIC", None): return function(*args, **kwargs) else: return Response("Unauthorized", 401, {"WWW-Authenticate": 
"Basic"}) diff --git a/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py b/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py index d8d5a95ee676b..f2038b27597c1 100644 --- a/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py +++ b/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py @@ -124,7 +124,7 @@ def requires_authentication(function: T, find_user: Callable[[str], BaseUser] | @wraps(function) def decorated(*args, **kwargs): - if current_app.appbuilder.get_app.config.get("AUTH_ROLE_PUBLIC", None): + if current_app.config.get("AUTH_ROLE_PUBLIC", None): response = function(*args, **kwargs) return make_response(response) diff --git a/airflow/providers/fab/auth_manager/models/anonymous_user.py b/airflow/providers/fab/auth_manager/models/anonymous_user.py index 2f294fd9e5d0e..9afb2cdff635f 100644 --- a/airflow/providers/fab/auth_manager/models/anonymous_user.py +++ b/airflow/providers/fab/auth_manager/models/anonymous_user.py @@ -35,7 +35,7 @@ class AnonymousUser(AnonymousUserMixin, BaseUser): @property def roles(self): if not self._roles: - public_role = current_app.appbuilder.get_app.config.get("AUTH_ROLE_PUBLIC", None) + public_role = current_app.config.get("AUTH_ROLE_PUBLIC", None) self._roles = {current_app.appbuilder.sm.find_role(public_role)} if public_role else set() return self._roles diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256 index e4a952da1b9fd..bca068fde6749 100644 --- a/docs/apache-airflow/img/airflow_erd.sha256 +++ b/docs/apache-airflow/img/airflow_erd.sha256 @@ -1 +1 @@ -c33e9a583a5b29eb748ebd50e117643e11bcb2a9b61ec017efd690621e22769b \ No newline at end of file +64dfad12dfd49f033c4723c2f3bb3bac58dd956136fb24a87a2e5a6ae176ec1a \ No newline at end of file diff --git a/docs/apache-airflow/img/airflow_erd.svg b/docs/apache-airflow/img/airflow_erd.svg index 76fbd8f841f25..4eb6c2ee70917 100644 --- a/docs/apache-airflow/img/airflow_erd.svg +++ b/docs/apache-airflow/img/airflow_erd.svg @@ -1394,7 +1394,7 @@ user_id - [INTEGER] + [VARCHAR(100)] @@ -1813,7 +1813,7 @@ user_id - [INTEGER] + [VARCHAR(100)] diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst index a547d03d75be6..e4fb2dfa332eb 100644 --- a/docs/apache-airflow/migrations-ref.rst +++ b/docs/apache-airflow/migrations-ref.rst @@ -39,7 +39,10 @@ Here's the list of all the Database Migrations that are executed via when you ru +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=========================+==================+===================+==============================================================+ -| ``0d9e73a75ee4`` (head) | ``16cbcb1c8c36`` | ``3.0.0`` | Add name and group fields to DatasetModel. | +| ``0d9e73a75ee4`` (head) | ``44eabb1904b4`` | ``3.0.0`` | Add name and group fields to DatasetModel. | ++-------------------------+------------------+-------------------+--------------------------------------------------------------+ +| ``44eabb1904b4`` | ``16cbcb1c8c36`` | ``3.0.0`` | Update dag_run_note.user_id and task_instance_note.user_id | +| | | | columns to String. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``16cbcb1c8c36`` | ``522625f6d606`` | ``3.0.0`` | Remove redundant index. 
| +-------------------------+------------------+-------------------+--------------------------------------------------------------+ diff --git a/tests/api_connexion/conftest.py b/tests/api_connexion/conftest.py index 38e7b58cb5981..6a23b2cf11d93 100644 --- a/tests/api_connexion/conftest.py +++ b/tests/api_connexion/conftest.py @@ -36,9 +36,16 @@ def minimal_app_for_api(): ] ) def factory(): - with conf_vars({("api", "auth_backends"): "tests.test_utils.remote_user_api_auth_backend"}): + with conf_vars( + { + ("api", "auth_backends"): "tests.test_utils.remote_user_api_auth_backend", + ( + "core", + "auth_manager", + ): "airflow.auth.managers.simple.simple_auth_manager.SimpleAuthManager", + } + ): _app = app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore - _app.config["AUTH_ROLE_PUBLIC"] = None return _app return factory() diff --git a/tests/api_connexion/endpoints/test_backfill_endpoint.py b/tests/api_connexion/endpoints/test_backfill_endpoint.py index 51a4faf40055c..07b2a3fd56c2d 100644 --- a/tests/api_connexion/endpoints/test_backfill_endpoint.py +++ b/tests/api_connexion/endpoints/test_backfill_endpoint.py @@ -29,7 +29,6 @@ from airflow.models.dag import DAG from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.empty import EmptyOperator -from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import provide_session from tests.test_utils.api_connexion_utils import create_user, delete_user @@ -50,25 +49,11 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG), - ], - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - create_user(app, username="test_granular_permissions", role_name="TestGranularDag") # type: ignore - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "TEST_DAG_1", - access_control={ - "TestGranularDag": { - permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} - }, - }, + role_name="admin", ) + create_user(app, username="test_no_permissions", role_name=None) with DAG( DAG_ID, @@ -93,9 +78,8 @@ def configured_app(minimal_app_for_api): yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_granular_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestBackfillEndpoint: @@ -178,7 +162,6 @@ def test_should_respond_200(self, session): @pytest.mark.parametrize( "user, expected", [ - ("test_granular_permissions", 200), ("test_no_permissions", 403), ("test", 200), (None, 401), @@ -240,7 +223,6 @@ def test_no_exist(self, session): @pytest.mark.parametrize( "user, expected", [ - ("test_granular_permissions", 200), ("test_no_permissions", 403), ("test", 200), (None, 401), @@ -268,7 +250,6 @@ class TestCreateBackfill(TestBackfillEndpoint): @pytest.mark.parametrize( "user, expected", [ - ("test_granular_permissions", 200), ("test_no_permissions", 403), ("test", 200), (None, 401), @@ -347,7 +328,6 @@ def test_should_respond_200(self, session): @pytest.mark.parametrize( "user, expected", [ - ("test_granular_permissions", 200), 
("test_no_permissions", 403), ("test", 200), (None, 401), @@ -409,7 +389,6 @@ def test_should_respond_200(self, session): @pytest.mark.parametrize( "user, expected", [ - ("test_granular_permissions", 200), ("test_no_permissions", 403), ("test", 200), (None, 401), diff --git a/tests/api_connexion/endpoints/test_config_endpoint.py b/tests/api_connexion/endpoints/test_config_endpoint.py index 475753a4a902e..bd88c491c952b 100644 --- a/tests/api_connexion/endpoints/test_config_endpoint.py +++ b/tests/api_connexion/endpoints/test_config_endpoint.py @@ -21,7 +21,6 @@ import pytest -from airflow.security import permissions from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars @@ -54,18 +53,17 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)], # type: ignore + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) with conf_vars({("webserver", "expose_config"): "True"}): yield minimal_app_for_api - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestGetConfig: diff --git a/tests/api_connexion/endpoints/test_connection_endpoint.py b/tests/api_connexion/endpoints/test_connection_endpoint.py index a19b046aa2747..a140046656e31 100644 --- a/tests/api_connexion/endpoints/test_connection_endpoint.py +++ b/tests/api_connexion/endpoints/test_connection_endpoint.py @@ -24,7 +24,6 @@ from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models import Connection from airflow.secrets.environment_variables import CONN_ENV_PREFIX -from airflow.security import permissions from airflow.utils.session import provide_session from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars @@ -38,22 +37,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_CONNECTION), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_CONNECTION), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestConnectionEndpoint: diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index 9905b4e27ab2c..6d4ffc2d06d2c 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -28,7 +28,6 @@ from airflow.models.dag import DAG from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.empty import EmptyOperator -from 
airflow.security import permissions from airflow.utils.session import provide_session from airflow.utils.state import TaskInstanceState from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user @@ -56,33 +55,11 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG), - ], - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - create_user(app, username="test_granular_permissions", role_name="TestGranularDag") # type: ignore - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "TEST_DAG_1", - access_control={ - "TestGranularDag": { - permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} - }, - }, - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "TEST_DAG_1", - access_control={ - "TestGranularDag": { - permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} - }, - }, + role_name="admin", ) + create_user(app, username="test_no_permissions", role_name=None) with DAG( DAG_ID, @@ -107,9 +84,8 @@ def configured_app(minimal_app_for_api): yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_granular_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestDagEndpoint: @@ -258,13 +234,6 @@ def test_should_respond_200_with_schedule_none(self, session): "pickle_id": None, } == response.json - def test_should_respond_200_with_granular_dag_access(self): - self._create_dag_models(1) - response = self.client.get( - "/api/v1/dags/TEST_DAG_1", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) - assert response.status_code == 200 - def test_should_respond_404(self): response = self.client.get("/api/v1/dags/INVALID_DAG", environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 404 @@ -282,13 +251,6 @@ def test_should_raise_403_forbidden(self): ) assert response.status_code == 403 - def test_should_respond_403_with_granular_access_for_different_dag(self): - self._create_dag_models(3) - response = self.client.get( - "/api/v1/dags/TEST_DAG_2", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) - assert response.status_code == 403 - @pytest.mark.parametrize( "fields", [ @@ -961,15 +923,6 @@ def test_filter_dags_by_dag_id_works(self, url, expected_dag_ids): assert expected_dag_ids == dag_ids - def test_should_respond_200_with_granular_dag_access(self): - self._create_dag_models(3) - response = self.client.get( - "/api/v1/dags", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) - assert response.status_code == 200 - assert len(response.json["dags"]) == 1 - assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" - @pytest.mark.parametrize( "url, expected_dag_ids", [ @@ -1252,18 +1205,6 @@ def test_should_respond_200_on_patch_is_paused(self, url_safe_serializer, sessio session, dag_id="TEST_DAG_1", event="api.patch_dag", execution_date=None, expected_extra=payload ) - def test_should_respond_200_on_patch_with_granular_dag_access(self, session): - self._create_dag_models(1) - response = self.client.patch( - "/api/v1/dags/TEST_DAG_1", - json={ - 
"is_paused": False, - }, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - assert response.status_code == 200 - _check_last_log(session, dag_id="TEST_DAG_1", event="api.patch_dag", execution_date=None) - def test_should_respond_400_on_invalid_request(self): patch_body = { "is_paused": True, @@ -1279,24 +1220,6 @@ def test_should_respond_400_on_invalid_request(self): "type": EXCEPTIONS_LINK_MAP[400], } - def test_validation_error_raises_400(self): - patch_body = { - "ispaused": True, - } - dag_model = self._create_dag_model() - response = self.client.patch( - f"/api/v1/dags/{dag_model.dag_id}", - json=patch_body, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - assert response.status_code == 400 - assert response.json == { - "detail": "{'ispaused': ['Unknown field.']}", - "status": 400, - "title": "Bad Request", - "type": EXCEPTIONS_LINK_MAP[400], - } - def test_non_existing_dag_raises_not_found(self): patch_body = { "is_paused": True, @@ -1820,19 +1743,6 @@ def test_filter_dags_by_dag_id_works(self, url, expected_dag_ids): assert expected_dag_ids == dag_ids - def test_should_respond_200_with_granular_dag_access(self): - self._create_dag_models(3) - response = self.client.patch( - "api/v1/dags?dag_id_pattern=~", - json={ - "is_paused": False, - }, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - assert response.status_code == 200 - assert len(response.json["dags"]) == 1 - assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" - @pytest.mark.parametrize( "url, expected_dag_ids", [ diff --git a/tests/api_connexion/endpoints/test_dag_parsing.py b/tests/api_connexion/endpoints/test_dag_parsing.py index 521d8d9e8cd99..ae42a565dd052 100644 --- a/tests/api_connexion/endpoints/test_dag_parsing.py +++ b/tests/api_connexion/endpoints/test_dag_parsing.py @@ -24,7 +24,6 @@ from airflow.models import DagBag from airflow.models.dagbag import DagPriorityParsingRequest -from airflow.security import permissions from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.db import clear_db_dag_parsing_requests @@ -45,21 +44,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)], # type: ignore + role_name="admin", ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - TEST_DAG_ID, - access_control={"Test": [permissions.ACTION_CAN_EDIT]}, - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestDagParsingRequest: diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index f3921da7b9c29..73c75b98a43b1 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -30,12 +30,11 @@ from airflow.models.dagrun import DagRun from airflow.models.param import Param from airflow.operators.empty import EmptyOperator -from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import create_session, provide_session from airflow.utils.state 
import DagRunState, State from airflow.utils.types import DagRunType -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user +from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags @@ -52,79 +51,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - ], - ) - create_user( - app, # type: ignore - username="test_no_dag_run_create_permission", - role_name="TestNoDagRunCreatePermission", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - ], - ) - create_user( - app, # type: ignore - username="test_dag_view_only", - role_name="TestViewDags", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - ], - ) - create_user( - app, # type: ignore - username="test_view_dags", - role_name="TestViewDags", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), - ], + role_name="admin", ) - create_user( - app, # type: ignore - username="test_granular_permissions", - role_name="TestGranularDag", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN)], - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "TEST_DAG_ID", - access_control={ - "TestGranularDag": {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ}, - "TestNoDagRunCreatePermission": {permissions.RESOURCE_DAG_RUN: {permissions.ACTION_CAN_CREATE}}, - }, - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_dag_view_only") # type: ignore - delete_user(app, username="test_view_dags") # type: ignore - delete_user(app, username="test_granular_permissions") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_no_dag_run_create_permission") # type: ignore - delete_roles(app) + delete_user(app, username="test") + delete_user(app, 
username="test_no_permissions") class TestDagRunEndpoint: @@ -499,16 +435,6 @@ def test_should_return_all_with_tilde_as_dag_id_and_all_dag_permissions(self): dag_run_ids = [dag_run["dag_id"] for dag_run in response.json["dag_runs"]] assert dag_run_ids == expected_dag_run_ids - def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self): - self._create_test_dag_run(extra_dag=True) - expected_dag_run_ids = ["TEST_DAG_ID", "TEST_DAG_ID"] - response = self.client.get( - "api/v1/dags/~/dagRuns", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) - assert response.status_code == 200 - dag_run_ids = [dag_run["dag_id"] for dag_run in response.json["dag_runs"]] - assert dag_run_ids == expected_dag_run_ids - def test_should_raises_401_unauthenticated(self): self._create_test_dag_run() @@ -907,57 +833,6 @@ def test_order_by_raises_for_invalid_attr(self): msg = "Ordering with 'dag_ru' is disallowed or the attribute does not exist on the model" assert response.json["detail"] == msg - def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self): - self._create_test_dag_run(extra_dag=True) - expected_response_json_1 = { - "dag_id": "TEST_DAG_ID", - "dag_run_id": "TEST_DAG_RUN_ID_1", - "end_date": None, - "state": "running", - "execution_date": self.default_time, - "logical_date": self.default_time, - "external_trigger": True, - "start_date": self.default_time, - "conf": {}, - "data_interval_end": None, - "data_interval_start": None, - "last_scheduling_decision": None, - "run_type": "manual", - "note": None, - } - expected_response_json_1.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {}) - expected_response_json_2 = { - "dag_id": "TEST_DAG_ID", - "dag_run_id": "TEST_DAG_RUN_ID_2", - "end_date": None, - "state": "running", - "execution_date": self.default_time_2, - "logical_date": self.default_time_2, - "external_trigger": True, - "start_date": self.default_time, - "conf": {}, - "data_interval_end": None, - "data_interval_start": None, - "last_scheduling_decision": None, - "run_type": "manual", - "note": None, - } - expected_response_json_2.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {}) - - response = self.client.post( - "api/v1/dags/~/dagRuns/list", - json={"dag_ids": []}, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - assert response.status_code == 200 - assert response.json == { - "dag_runs": [ - expected_response_json_1, - expected_response_json_2, - ], - "total_entries": 2, - } - @pytest.mark.parametrize( "payload, error", [ @@ -1328,15 +1203,6 @@ def test_raises_validation_error_for_invalid_params(self): assert response.status_code == 400 assert "Invalid input for param" in response.json["detail"] - def test_dagrun_trigger_with_dag_level_permissions(self): - self._create_dag("TEST_DAG_ID") - response = self.client.post( - "api/v1/dags/TEST_DAG_ID/dagRuns", - json={"conf": {"validated_number": 1}}, - environ_overrides={"REMOTE_USER": "test_no_dag_run_create_permission"}, - ) - assert response.status_code == 200 - @mock.patch("airflow.api_connexion.endpoints.dag_run_endpoint.get_airflow_app") def test_dagrun_creation_exception_is_handled(self, mock_get_app, session): self._create_dag("TEST_DAG_ID") @@ -1627,11 +1493,7 @@ def test_should_raises_401_unauthenticated(self): assert_401(response) - @pytest.mark.parametrize( - "username", - ["test_dag_view_only", "test_view_dags", "test_granular_permissions", "test_no_permissions"], - ) - def test_should_raises_403_unauthorized(self, 
username): + def test_should_raises_403_unauthorized(self): self._create_dag("TEST_DAG_ID") response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", @@ -1639,7 +1501,7 @@ def test_should_raises_403_unauthorized(self, username): "dag_run_id": "TEST_DAG_RUN_ID_1", "execution_date": self.default_time, }, - environ_overrides={"REMOTE_USER": username}, + environ_overrides={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 diff --git a/tests/api_connexion/endpoints/test_dag_source_endpoint.py b/tests/api_connexion/endpoints/test_dag_source_endpoint.py index a8d1224e034c3..f4df56ba629ae 100644 --- a/tests/api_connexion/endpoints/test_dag_source_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_source_endpoint.py @@ -23,7 +23,6 @@ import pytest from airflow.models import DagBag -from airflow.security import permissions from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.db import clear_db_dag_code, clear_db_dags, clear_db_serialized_dags @@ -44,29 +43,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)], # type: ignore + role_name="admin", ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - TEST_DAG_ID, - access_control={"Test": [permissions.ACTION_CAN_READ]}, - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - EXAMPLE_DAG_ID, - access_control={"Test": [permissions.ACTION_CAN_READ]}, - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - TEST_MULTIPLE_DAGS_ID, - access_control={"Test": [permissions.ACTION_CAN_READ]}, - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestGetSource: @@ -123,18 +109,6 @@ def test_should_respond_200_json(self, url_safe_serializer): assert dag_docstring in response.json["content"] assert "application/json" == response.headers["Content-Type"] - def test_should_respond_406(self, url_safe_serializer): - dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) - dagbag.sync_to_db() - test_dag: DAG = dagbag.dags[TEST_DAG_ID] - - url = f"/api/v1/dagSources/{url_safe_serializer.dumps(test_dag.fileloc)}" - response = self.client.get( - url, headers={"Accept": "image/webp"}, environ_overrides={"REMOTE_USER": "test"} - ) - - assert 406 == response.status_code - def test_should_respond_404(self): wrong_fileloc = "abcd1234" url = f"/api/v1/dagSources/{wrong_fileloc}" @@ -167,38 +141,3 @@ def test_should_raise_403_forbidden(self, url_safe_serializer): environ_overrides={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 - - def test_should_respond_403_not_readable(self, url_safe_serializer): - dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) - dagbag.sync_to_db() - dag: DAG = dagbag.dags[NOT_READABLE_DAG_ID] - - response = self.client.get( - f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, - ) - read_dag = self.client.get( - f"/api/v1/dags/{NOT_READABLE_DAG_ID}", - environ_overrides={"REMOTE_USER": "test"}, - ) - assert response.status_code == 403 - assert 
read_dag.status_code == 403 - - def test_should_respond_403_some_dags_not_readable_in_the_file(self, url_safe_serializer): - dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) - dagbag.sync_to_db() - dag: DAG = dagbag.dags[TEST_MULTIPLE_DAGS_ID] - - response = self.client.get( - f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, - ) - - read_dag = self.client.get( - f"/api/v1/dags/{TEST_MULTIPLE_DAGS_ID}", - environ_overrides={"REMOTE_USER": "test"}, - ) - assert response.status_code == 403 - assert read_dag.status_code == 200 diff --git a/tests/api_connexion/endpoints/test_dag_stats_endpoint.py b/tests/api_connexion/endpoints/test_dag_stats_endpoint.py index 36fc54d3a5b17..9ab5b49765931 100644 --- a/tests/api_connexion/endpoints/test_dag_stats_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_stats_endpoint.py @@ -22,7 +22,6 @@ from airflow.models.dag import DAG, DagModel from airflow.models.dagrun import DagRun -from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import DagRunState @@ -38,21 +37,17 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestDagStatsEndpoint: diff --git a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py index 3e7c805173b39..f156d8921c0e6 100644 --- a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py @@ -22,7 +22,6 @@ from airflow.models.dag import DagModel from airflow.models.dagwarning import DagWarning -from airflow.security import permissions from airflow.utils.session import create_session from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.db import clear_db_dag_warnings, clear_db_dags @@ -34,30 +33,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - ], # type: ignore - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - create_user( - app, # type:ignore - username="test_with_dag2_read", - role_name="TestWithDag2Read", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), - (permissions.ACTION_CAN_READ, f"{permissions.RESOURCE_DAG_PREFIX}dag2"), - ], # type: ignore + role_name="admin", ) + create_user(app, username="test_no_permissions", role_name=None) yield minimal_app_for_api - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, 
username="test_with_dag2_read") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestBaseDagWarning: @@ -162,11 +147,3 @@ def test_should_raise_403_forbidden(self): "/api/v1/dagWarnings", environ_overrides={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 - - def test_should_raise_403_forbidden_when_user_has_no_dag_read_permission(self): - response = self.client.get( - "/api/v1/dagWarnings", - environ_overrides={"REMOTE_USER": "test_with_dag2_read"}, - query_string={"dag_id": "dag1"}, - ) - assert response.status_code == 403 diff --git a/tests/api_connexion/endpoints/test_dataset_endpoint.py b/tests/api_connexion/endpoints/test_dataset_endpoint.py index 5caec0ac2a131..76c164654c9d8 100644 --- a/tests/api_connexion/endpoints/test_dataset_endpoint.py +++ b/tests/api_connexion/endpoints/test_dataset_endpoint.py @@ -33,7 +33,6 @@ TaskOutletAssetReference, ) from airflow.models.dagrun import DagRun -from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import provide_session from airflow.utils.types import DagRunType @@ -50,31 +49,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_ASSET), - ], - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - create_user( - app, # type: ignore - username="test_queued_event", - role_name="TestQueuedEvent", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ASSET), - ], + role_name="admin", ) + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test_queued_event") # type: ignore - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestDatasetEndpoint: @@ -768,43 +752,6 @@ def _create_dataset_dag_run_queues(self, dag_id, dataset_id, session): class TestGetDagDatasetQueuedEvent(TestQueuedEventEndpoint): - @pytest.mark.usefixtures("time_freezer") - def test_should_respond_200(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - dataset_id = self._create_dataset(session).id - self._create_dataset_dag_run_queues(dag_id, dataset_id, session) - dataset_uri = "s3://bucket/key" - - response = self.client.get( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 200 - assert response.json == { - "created_at": self.default_time, - "uri": "s3://bucket/key", - "dag_id": "dag", - } - - def test_should_respond_404(self): - dag_id = "not_exists" - dataset_uri = "not_exists" - - response = self.client.get( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert { - "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } == response.json - def 
test_should_raises_401_unauthenticated(self, session): dag_id = "dummy" dataset_uri = "dummy" @@ -826,47 +773,6 @@ def test_should_raise_403_forbidden(self, session): class TestDeleteDagDatasetQueuedEvent(TestDatasetEndpoint): - def test_delete_should_respond_204(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - dataset_uri = "s3://bucket/key" - dataset_id = self._create_dataset(session).id - - adrq = AssetDagRunQueue(target_dag_id=dag_id, dataset_id=dataset_id) - session.add(adrq) - session.commit() - conn = session.query(AssetDagRunQueue).all() - assert len(conn) == 1 - - response = self.client.delete( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 204 - conn = session.query(AssetDagRunQueue).all() - assert len(conn) == 0 - _check_last_log( - session, dag_id=dag_id, event="api.delete_dag_dataset_queued_event", execution_date=None - ) - - def test_should_respond_404(self): - dag_id = "not_exists" - dataset_uri = "not_exists" - - response = self.client.delete( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert { - "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } == response.json - def test_should_raises_401_unauthenticated(self, session): dag_id = "dummy" dataset_uri = "dummy" @@ -884,46 +790,6 @@ def test_should_raise_403_forbidden(self, session): class TestGetDagDatasetQueuedEvents(TestQueuedEventEndpoint): - @pytest.mark.usefixtures("time_freezer") - def test_should_respond_200(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - dataset_id = self._create_dataset(session).id - self._create_dataset_dag_run_queues(dag_id, dataset_id, session) - - response = self.client.get( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 200 - assert response.json == { - "queued_events": [ - { - "created_at": self.default_time, - "uri": "s3://bucket/key", - "dag_id": "dag", - } - ], - "total_entries": 1, - } - - def test_should_respond_404(self): - dag_id = "not_exists" - - response = self.client.get( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert { - "detail": "Queue event with dag_id: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } == response.json - def test_should_raises_401_unauthenticated(self): dag_id = "dummy" @@ -943,22 +809,6 @@ def test_should_raise_403_forbidden(self): class TestDeleteDagDatasetQueuedEvents(TestDatasetEndpoint): - def test_should_respond_404(self): - dag_id = "not_exists" - - response = self.client.delete( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert { - "detail": "Queue event with dag_id: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } == response.json - def test_should_raises_401_unauthenticated(self): dag_id = "dummy" @@ -978,47 +828,6 @@ def 
test_should_raise_403_forbidden(self): class TestGetDatasetQueuedEvents(TestQueuedEventEndpoint): - @pytest.mark.usefixtures("time_freezer") - def test_should_respond_200(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - dataset_id = self._create_dataset(session).id - self._create_dataset_dag_run_queues(dag_id, dataset_id, session) - dataset_uri = "s3://bucket/key" - - response = self.client.get( - f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 200 - assert response.json == { - "queued_events": [ - { - "created_at": self.default_time, - "uri": "s3://bucket/key", - "dag_id": "dag", - } - ], - "total_entries": 1, - } - - def test_should_respond_404(self): - dataset_uri = "not_exists" - - response = self.client.get( - f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert { - "detail": "Queue event with asset uri: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } == response.json - def test_should_raises_401_unauthenticated(self): dataset_uri = "not_exists" @@ -1038,39 +847,6 @@ def test_should_raise_403_forbidden(self): class TestDeleteDatasetQueuedEvents(TestQueuedEventEndpoint): - def test_delete_should_respond_204(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - dataset_id = self._create_dataset(session).id - self._create_dataset_dag_run_queues(dag_id, dataset_id, session) - dataset_uri = "s3://bucket/key" - - response = self.client.delete( - f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 204 - conn = session.query(AssetDagRunQueue).all() - assert len(conn) == 0 - _check_last_log(session, dag_id=None, event="api.delete_dataset_queued_events", execution_date=None) - - def test_should_respond_404(self): - dataset_uri = "not_exists" - - response = self.client.delete( - f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert { - "detail": "Queue event with asset uri: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } == response.json - def test_should_raises_401_unauthenticated(self): dataset_uri = "not_exists" diff --git a/tests/api_connexion/endpoints/test_event_log_endpoint.py b/tests/api_connexion/endpoints/test_event_log_endpoint.py index 0fdef1a3af2b6..e5ca3d301765a 100644 --- a/tests/api_connexion/endpoints/test_event_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_event_log_endpoint.py @@ -20,7 +20,6 @@ from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models import Log -from airflow.security import permissions from airflow.utils import timezone from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars @@ -33,32 +32,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], # type: ignore + role_name="admin", ) - create_user( - app, # type:ignore - username="test_granular", - role_name="TestGranular", - 
permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], # type: ignore - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "TEST_DAG_ID_1", - access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "TEST_DAG_ID_2", - access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_granular") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") @pytest.fixture @@ -274,33 +257,6 @@ def test_should_raises_401_unauthenticated(self, log_model): assert_401(response) - def test_should_filter_eventlogs_by_allowed_attributes(self, create_log_model, session): - eventlog1 = create_log_model( - event="TEST_EVENT_1", - dag_id="TEST_DAG_ID_1", - task_id="TEST_TASK_ID_1", - owner="TEST_OWNER_1", - when=self.default_time, - ) - eventlog2 = create_log_model( - event="TEST_EVENT_2", - dag_id="TEST_DAG_ID_2", - task_id="TEST_TASK_ID_2", - owner="TEST_OWNER_2", - when=self.default_time_2, - ) - session.add_all([eventlog1, eventlog2]) - session.commit() - for attr in ["dag_id", "task_id", "owner", "event"]: - attr_value = f"TEST_{attr}_1".upper() - response = self.client.get( - f"/api/v1/eventLogs?{attr}={attr_value}", environ_overrides={"REMOTE_USER": "test_granular"} - ) - assert response.status_code == 200 - assert response.json["total_entries"] == 1 - assert len(response.json["event_logs"]) == 1 - assert response.json["event_logs"][0][attr] == attr_value - def test_should_filter_eventlogs_by_when(self, create_log_model, session): eventlog1 = create_log_model(event="TEST_EVENT_1", when=self.default_time) eventlog2 = create_log_model(event="TEST_EVENT_2", when=self.default_time_2) @@ -339,32 +295,6 @@ def test_should_filter_eventlogs_by_run_id(self, create_log_model, session): assert {eventlog["event"] for eventlog in response.json["event_logs"]} == expected_eventlogs assert all({eventlog["run_id"] == run_id for eventlog in response.json["event_logs"]}) - def test_should_filter_eventlogs_by_included_events(self, create_log_model): - for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]: - create_log_model(event=event, when=self.default_time) - response = self.client.get( - "/api/v1/eventLogs?included_events=TEST_EVENT_1,TEST_EVENT_2", - environ_overrides={"REMOTE_USER": "test_granular"}, - ) - assert response.status_code == 200 - response_data = response.json - assert len(response_data["event_logs"]) == 2 - assert response_data["total_entries"] == 2 - assert {"TEST_EVENT_1", "TEST_EVENT_2"} == {x["event"] for x in response_data["event_logs"]} - - def test_should_filter_eventlogs_by_excluded_events(self, create_log_model): - for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]: - create_log_model(event=event, when=self.default_time) - response = self.client.get( - "/api/v1/eventLogs?excluded_events=TEST_EVENT_1,TEST_EVENT_2", - environ_overrides={"REMOTE_USER": "test_granular"}, - ) - assert response.status_code == 200 - response_data = response.json - assert len(response_data["event_logs"]) == 1 - assert response_data["total_entries"] == 1 - assert {"cli_scheduler"} == {x["event"] for x in response_data["event_logs"]} - 
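With the per-DAG filtering cases moved to the FAB provider, the remaining core tests only need to distinguish the two users created in the fixture above; a minimal sketch of that access pattern, reusing the fixture's usernames and the endpoint exercised in these tests:

def _sketch_event_log_access(client):
    # "test" is created with role_name="admin" and can list every event log entry.
    ok = client.get("/api/v1/eventLogs", environ_overrides={"REMOTE_USER": "test"})
    assert ok.status_code == 200

    # "test_no_permissions" is created with role_name=None and is rejected outright.
    denied = client.get("/api/v1/eventLogs", environ_overrides={"REMOTE_USER": "test_no_permissions"})
    assert denied.status_code == 403
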
class TestGetEventLogPagination(TestEventLogEndpoint): @pytest.mark.parametrize( diff --git a/tests/api_connexion/endpoints/test_extra_link_endpoint.py b/tests/api_connexion/endpoints/test_extra_link_endpoint.py index 1e9226ede9847..2c3eacdc91dc0 100644 --- a/tests/api_connexion/endpoints/test_extra_link_endpoint.py +++ b/tests/api_connexion/endpoints/test_extra_link_endpoint.py @@ -26,7 +26,6 @@ from airflow.models.dagbag import DagBag from airflow.models.xcom import XCom from airflow.plugins_manager import AirflowPlugin -from airflow.security import permissions from airflow.timetables.base import DataInterval from airflow.utils import timezone from airflow.utils.state import DagRunState @@ -48,21 +47,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestGetExtraLinks: @@ -78,8 +72,8 @@ def setup_attrs(self, configured_app, session) -> None: self.dag = self._create_dag() self.app.dag_bag = DagBag(os.devnull, include_examples=False) - self.app.dag_bag.dags = {self.dag.dag_id: self.dag} # type: ignore - self.app.dag_bag.sync_to_db() # type: ignore + self.app.dag_bag.dags = {self.dag.dag_id: self.dag} + self.app.dag_bag.sync_to_db() triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} self.dag.create_dagrun( diff --git a/tests/api_connexion/endpoints/test_import_error_endpoint.py b/tests/api_connexion/endpoints/test_import_error_endpoint.py index 635e159bb292c..af2b83ebb1eed 100644 --- a/tests/api_connexion/endpoints/test_import_error_endpoint.py +++ b/tests/api_connexion/endpoints/test_import_error_endpoint.py @@ -21,15 +21,12 @@ import pytest from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP -from airflow.models.dag import DagModel -from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import provide_session from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.compat import ParseImportError from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_import_errors -from tests.test_utils.permissions import _resource_name pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] @@ -40,42 +37,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR), - ], # type: ignore - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - create_user( - app, # type:ignore - username="test_single_dag", - role_name="TestSingleDAG", - permissions=[(permissions.ACTION_CAN_READ, 
permissions.RESOURCE_IMPORT_ERROR)], # type: ignore - ) - # For some reason, DAG level permissions are not synced when in the above list of perms, - # so do it manually here: - app.appbuilder.sm.bulk_sync_roles( - [ - { - "role": "TestSingleDAG", - "perms": [ - ( - permissions.ACTION_CAN_READ, - _resource_name(TEST_DAG_IDS[0], permissions.RESOURCE_DAG), - ) - ], - } - ] + role_name="admin", ) + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_single_dag") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestBaseImportError: @@ -152,72 +123,6 @@ def test_should_raise_403_forbidden(self): ) assert response.status_code == 403 - def test_should_raise_403_forbidden_without_dag_read(self, session): - import_error = ParseImportError( - filename="Lorem_ipsum.py", - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(import_error) - session.commit() - - response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 403 - - def test_should_return_200_with_single_dag_read(self, session): - dag_model = DagModel(dag_id=TEST_DAG_IDS[0], fileloc="Lorem_ipsum.py") - session.add(dag_model) - import_error = ParseImportError( - filename="Lorem_ipsum.py", - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(import_error) - session.commit() - - response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 200 - response_data = response.json - response_data["import_error_id"] = 1 - assert { - "filename": "Lorem_ipsum.py", - "import_error_id": 1, - "stack_trace": "Lorem ipsum", - "timestamp": "2020-06-10T12:00:00+00:00", - } == response_data - - def test_should_return_200_redacted_with_single_dag_read_in_dagfile(self, session): - for dag_id in TEST_DAG_IDS: - dag_model = DagModel(dag_id=dag_id, fileloc="Lorem_ipsum.py") - session.add(dag_model) - import_error = ParseImportError( - filename="Lorem_ipsum.py", - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(import_error) - session.commit() - - response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 200 - response_data = response.json - response_data["import_error_id"] = 1 - assert { - "filename": "Lorem_ipsum.py", - "import_error_id": 1, - "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file", - "timestamp": "2020-06-10T12:00:00+00:00", - } == response_data - class TestGetImportErrorsEndpoint(TestBaseImportError): def test_get_import_errors(self, session): @@ -328,71 +233,6 @@ def test_should_raises_401_unauthenticated(self, session): assert_401(response) - def test_get_import_errors_single_dag(self, session): - for dag_id in TEST_DAG_IDS: - fake_filename = f"/tmp/{dag_id}.py" - dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename) - session.add(dag_model) - importerror = ParseImportError( - filename=fake_filename, - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(importerror) 
- session.commit() - - response = self.client.get( - "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 200 - response_data = response.json - self._normalize_import_errors(response_data["import_errors"]) - assert { - "import_errors": [ - { - "filename": "/tmp/test_dag.py", - "import_error_id": 1, - "stack_trace": "Lorem ipsum", - "timestamp": "2020-06-10T12:00:00+00:00", - }, - ], - "total_entries": 1, - } == response_data - - def test_get_import_errors_single_dag_in_dagfile(self, session): - for dag_id in TEST_DAG_IDS: - fake_filename = "/tmp/all_in_one.py" - dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename) - session.add(dag_model) - - importerror = ParseImportError( - filename="/tmp/all_in_one.py", - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(importerror) - session.commit() - - response = self.client.get( - "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 200 - response_data = response.json - self._normalize_import_errors(response_data["import_errors"]) - assert { - "import_errors": [ - { - "filename": "/tmp/all_in_one.py", - "import_error_id": 1, - "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file", - "timestamp": "2020-06-10T12:00:00+00:00", - }, - ], - "total_entries": 1, - } == response_data - class TestGetImportErrorsEndpointPagination(TestBaseImportError): @pytest.mark.parametrize( diff --git a/tests/api_connexion/endpoints/test_log_endpoint.py b/tests/api_connexion/endpoints/test_log_endpoint.py index 420d2dd65f89c..2b112e3221843 100644 --- a/tests/api_connexion/endpoints/test_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_log_endpoint.py @@ -30,7 +30,6 @@ from airflow.decorators import task from airflow.models.dag import DAG from airflow.operators.empty import EmptyOperator -from airflow.security import permissions from airflow.utils import timezone from airflow.utils.types import DagRunType from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user @@ -46,13 +45,9 @@ def configured_app(minimal_app_for_api): create_user( app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") + create_user(app, username="test_no_permissions", role_name=None) yield app diff --git a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py index 72cdccdee68df..fc53b8952f4aa 100644 --- a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py @@ -28,12 +28,11 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.dagbag import DagBag from airflow.models.taskmap import TaskMap -from airflow.security import permissions from airflow.utils.platform import getuser from airflow.utils.session import provide_session from airflow.utils.state import State, TaskInstanceState from airflow.utils.timezone import datetime -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user +from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.db 
import clear_db_runs, clear_db_sla_miss, clear_rendered_ti_fields from tests.test_utils.mock_operators import MockOperator @@ -50,24 +49,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_roles(app) + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestMappedTaskInstanceEndpoint: @@ -133,8 +124,8 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None): session.add(ti) self.app.dag_bag = DagBag(os.devnull, include_examples=False) - self.app.dag_bag.dags = {dag_id: dag_maker.dag} # type: ignore - self.app.dag_bag.sync_to_db() # type: ignore + self.app.dag_bag.dags = {dag_id: dag_maker.dag} + self.app.dag_bag.sync_to_db() session.flush() mapped.expand_mapped_task(dr.run_id, session=session) diff --git a/tests/api_connexion/endpoints/test_plugin_endpoint.py b/tests/api_connexion/endpoints/test_plugin_endpoint.py index edf925cf0fa73..0cd630375a282 100644 --- a/tests/api_connexion/endpoints/test_plugin_endpoint.py +++ b/tests/api_connexion/endpoints/test_plugin_endpoint.py @@ -24,7 +24,6 @@ from airflow.hooks.base import BaseHook from airflow.plugins_manager import AirflowPlugin -from airflow.security import permissions from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.timetables.base import Timetable from airflow.utils.module_loading import qualname @@ -105,17 +104,16 @@ class MockPlugin(AirflowPlugin): def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_PLUGIN)], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestPluginsEndpoint: diff --git a/tests/api_connexion/endpoints/test_pool_endpoint.py b/tests/api_connexion/endpoints/test_pool_endpoint.py index 87439a5811945..2cc095d077aa9 100644 --- a/tests/api_connexion/endpoints/test_pool_endpoint.py +++ b/tests/api_connexion/endpoints/test_pool_endpoint.py @@ -20,7 +20,6 @@ from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models.pool import Pool -from airflow.security import permissions from airflow.utils.session import provide_session from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars @@ -35,22 +34,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: 
ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_POOL), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_POOL), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_POOL), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_POOL), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestBasePoolEndpoints: diff --git a/tests/api_connexion/endpoints/test_provider_endpoint.py b/tests/api_connexion/endpoints/test_provider_endpoint.py index 16e5989cc56db..b4cf8f10a92ae 100644 --- a/tests/api_connexion/endpoints/test_provider_endpoint.py +++ b/tests/api_connexion/endpoints/test_provider_endpoint.py @@ -21,7 +21,6 @@ import pytest from airflow.providers_manager import ProviderInfo -from airflow.security import permissions from tests.test_utils.api_connexion_utils import create_user, delete_user pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] @@ -54,17 +53,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_PROVIDER)], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestBaseProviderEndpoint: diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py index d0a4fb903c8b8..b2e068bd507fe 100644 --- a/tests/api_connexion/endpoints/test_task_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_endpoint.py @@ -27,7 +27,6 @@ from airflow.models.expandinput import EXPAND_INPUT_EMPTY from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.empty import EmptyOperator -from airflow.security import permissions from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags @@ -38,21 +37,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestTaskEndpoint: diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py 
b/tests/api_connexion/endpoints/test_task_instance_endpoint.py index 25ded6c814b72..b5b3163e988d0 100644 --- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py @@ -25,19 +25,17 @@ from sqlalchemy import select from sqlalchemy.orm import contains_eager -from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import TriggererJobRunner from airflow.models import DagRun, SlaMiss, TaskInstance, Trigger from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF from airflow.models.taskinstancehistory import TaskInstanceHistory -from airflow.security import permissions from airflow.utils.platform import getuser from airflow.utils.session import provide_session from airflow.utils.state import State from airflow.utils.timezone import datetime from airflow.utils.types import DagRunType -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user +from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.db import clear_db_runs, clear_db_sla_miss, clear_rendered_ti_fields from tests.test_utils.www import _check_last_log @@ -55,69 +53,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], - ) - create_user( - app, # type: ignore - username="test_dag_read_only", - role_name="TestDagReadOnly", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], - ) - create_user( - app, # type: ignore - username="test_task_read_only", - role_name="TestTaskReadOnly", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], - ) - create_user( - app, # type: ignore - username="test_read_only_one_dag", - role_name="TestReadOnlyOneDag", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], - ) - # For some reason, "DAG:example_python_operator" is not synced when in the above list of perms, - # so do it manually here: - app.appbuilder.sm.bulk_sync_roles( - [ - { - "role": "TestReadOnlyOneDag", - "perms": [(permissions.ACTION_CAN_READ, "DAG:example_python_operator")], - } - ] + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_dag_read_only") # type: ignore - delete_user(app, username="test_task_read_only") # type: 
ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_read_only_one_dag") # type: ignore - delete_roles(app) + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestTaskInstanceEndpoint: @@ -219,9 +164,8 @@ def setup_method(self): def teardown_method(self): clear_db_runs() - @pytest.mark.parametrize("username", ["test", "test_dag_read_only", "test_task_read_only"]) @provide_session - def test_should_respond_200(self, username, session): + def test_should_respond_200(self, session): self.create_task_instances(session) # Update ti and set operator to None to # test that operator field is nullable. @@ -232,7 +176,7 @@ def test_should_respond_200(self, username, session): session.commit() response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", - environ_overrides={"REMOTE_USER": username}, + environ_overrides={"REMOTE_USER": "test"}, ) assert response.status_code == 200 assert response.json == { @@ -723,36 +667,11 @@ def test_should_respond_200(self, task_instances, update_extras, url, expected_t assert response.json["total_entries"] == expected_ti assert len(response.json["task_instances"]) == expected_ti - @pytest.mark.parametrize( - "task_instances, user, expected_ti", - [ - pytest.param( - { - "example_python_operator": 2, - "example_skip_dag": 1, - }, - "test_read_only_one_dag", - 2, - ), - pytest.param( - { - "example_python_operator": 1, - "example_skip_dag": 2, - }, - "test_read_only_one_dag", - 1, - ), - pytest.param( - { - "example_python_operator": 1, - "example_skip_dag": 2, - }, - "test", - 3, - ), - ], - ) - def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_ti, session): + def test_return_TI_only_from_readable_dags(self, session): + task_instances = { + "example_python_operator": 1, + "example_skip_dag": 2, + } for dag_id in task_instances: self.create_task_instances( session, @@ -763,11 +682,11 @@ def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_ dag_id=dag_id, ) response = self.client.get( - "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": user} + "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json["total_entries"] == expected_ti - assert len(response.json["task_instances"]) == expected_ti + assert response.json["total_entries"] == 3 + assert len(response.json["task_instances"]) == 3 def test_should_respond_200_for_dag_id_filter(self, session): self.create_task_instances(session) @@ -898,44 +817,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): "test", id="test executor filter", ), - pytest.param( - [ - {"pool": "test_pool_1"}, - {"pool": "test_pool_2"}, - {"pool": "test_pool_3"}, - ], - True, - {"pool": ["test_pool_1", "test_pool_2"]}, - 2, - "test_dag_read_only", - id="test pool filter", - ), - pytest.param( - [ - {"state": State.RUNNING}, - {"state": State.QUEUED}, - {"state": State.SUCCESS}, - {"state": State.NONE}, - ], - False, - {"state": ["running", "queued", "none"]}, - 3, - "test_task_read_only", - id="test state filter", - ), - pytest.param( - [ - {"state": State.NONE}, - {"state": State.NONE}, - {"state": State.NONE}, - {"state": State.NONE}, - ], - False, - {}, - 4, - "test_task_read_only", - id="test dag with null states", - ), pytest.param( [ {"duration": 100}, @@ -948,36 +829,6 @@ class 
TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): "test", id="test duration filter", ), - pytest.param( - [ - {"end_date": DEFAULT_DATETIME_1}, - {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, - {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, - ], - True, - { - "end_date_gte": DEFAULT_DATETIME_STR_1, - "end_date_lte": DEFAULT_DATETIME_STR_2, - }, - 2, - "test_task_read_only", - id="test end date filter", - ), - pytest.param( - [ - {"start_date": DEFAULT_DATETIME_1}, - {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, - {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, - ], - True, - { - "start_date_gte": DEFAULT_DATETIME_STR_1, - "start_date_lte": DEFAULT_DATETIME_STR_2, - }, - 2, - "test_dag_read_only", - id="test start date filter", - ), pytest.param( [ {"execution_date": DEFAULT_DATETIME_1}, @@ -1162,24 +1013,6 @@ def test_should_raise_403_forbidden(self): ) assert response.status_code == 403 - def test_returns_403_forbidden_when_user_has_access_to_only_some_dags(self, session): - self.create_task_instances(session=session) - self.create_task_instances(session=session, dag_id="example_skip_dag") - payload = {"dag_ids": ["example_python_operator", "example_skip_dag"]} - - response = self.client.post( - "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": "test_read_only_one_dag"}, - json=payload, - ) - assert response.status_code == 403 - assert response.json == { - "detail": "User not allowed to access some of these DAGs: ['example_python_operator', 'example_skip_dag']", - "status": 403, - "title": "Forbidden", - "type": EXCEPTIONS_LINK_MAP[403], - } - def test_should_raise_400_for_no_json(self): response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", @@ -1794,11 +1627,10 @@ def test_should_raises_401_unauthenticated(self): ) assert_401(response) - @pytest.mark.parametrize("username", ["test_no_permissions", "test_dag_read_only", "test_task_read_only"]) - def test_should_raise_403_forbidden(self, username: str): + def test_should_raise_403_forbidden(self): response = self.client.post( "/api/v1/dags/example_python_operator/clearTaskInstances", - environ_overrides={"REMOTE_USER": username}, + environ_overrides={"REMOTE_USER": "test_no_permissions"}, json={ "dry_run": False, "reset_dag_runs": True, @@ -2043,11 +1875,10 @@ def test_should_raises_401_unauthenticated(self): ) assert_401(response) - @pytest.mark.parametrize("username", ["test_no_permissions", "test_dag_read_only", "test_task_read_only"]) - def test_should_raise_403_forbidden(self, username): + def test_should_raise_403_forbidden(self): response = self.client.post( "/api/v1/dags/example_python_operator/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": username}, + environ_overrides={"REMOTE_USER": "test_no_permissions"}, json={ "dry_run": True, "task_id": "print_the_context", @@ -2386,11 +2217,10 @@ def test_should_raises_401_unauthenticated(self): ) assert_401(response) - @pytest.mark.parametrize("username", ["test_no_permissions", "test_dag_read_only", "test_task_read_only"]) - def test_should_raise_403_forbidden(self, username): + def test_should_raise_403_forbidden(self): response = self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": username}, + environ_overrides={"REMOTE_USER": "test_no_permissions"}, json={ "dry_run": True, "new_state": "failed", @@ -2748,14 +2578,13 @@ def setup_method(self): def teardown_method(self): clear_db_runs() - @pytest.mark.parametrize("username", ["test", 
"test_dag_read_only", "test_task_read_only"]) @provide_session - def test_should_respond_200(self, username, session): + def test_should_respond_200(self, session): self.create_task_instances(session, task_instances=[{"state": State.SUCCESS}], with_ti_history=True) response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries/1", - environ_overrides={"REMOTE_USER": username}, + environ_overrides={"REMOTE_USER": "test"}, ) assert response.status_code == 200 assert response.json == { diff --git a/tests/api_connexion/endpoints/test_variable_endpoint.py b/tests/api_connexion/endpoints/test_variable_endpoint.py index 81405df08b045..aa5f7c99674f8 100644 --- a/tests/api_connexion/endpoints/test_variable_endpoint.py +++ b/tests/api_connexion/endpoints/test_variable_endpoint.py @@ -22,7 +22,6 @@ from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models import Variable -from airflow.security import permissions from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_variables @@ -36,40 +35,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_VARIABLE), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE), - ], - ) - create_user( - app, # type: ignore - username="test_read_only", - role_name="TestReadOnly", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE), - ], - ) - create_user( - app, # type: ignore - username="test_delete_only", - role_name="TestDeleteOnly", - permissions=[ - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_read_only") # type: ignore - delete_user(app, username="test_delete_only") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestVariableEndpoint: @@ -131,8 +106,6 @@ class TestGetVariable(TestVariableEndpoint): "user, expected_status_code", [ ("test", 200), - ("test_read_only", 200), - ("test_delete_only", 403), ("test_no_permissions", 403), ], ) diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py index 7a51714c5b299..809e537f9f88d 100644 --- a/tests/api_connexion/endpoints/test_xcom_endpoint.py +++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py @@ -26,7 +26,6 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.xcom import BaseXCom, XCom, resolve_xcom_backend from airflow.operators.empty import EmptyOperator -from airflow.security import permissions from airflow.utils.dates import parse_execution_date from airflow.utils.session import create_session from airflow.utils.timezone import utcnow @@ -52,32 +51,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, 
username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), - ], - ) - create_user( - app, # type: ignore - username="test_granular_permissions", - role_name="TestGranularDag", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), - ], - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "test-dag-id-1", - access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]}, + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") def _compare_xcom_collections(collection1: dict, collection_2: dict): @@ -435,53 +418,6 @@ def test_should_respond_200_with_tilde_and_access_to_all_dags(self): }, ) - def test_should_respond_200_with_tilde_and_granular_dag_access(self): - dag_id_1 = "test-dag-id-1" - task_id_1 = "test-task-id-1" - execution_date = "2005-04-02T00:00:00+00:00" - execution_date_parsed = parse_execution_date(execution_date) - dag_run_id_1 = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed) - self._create_xcom_entries(dag_id_1, dag_run_id_1, execution_date_parsed, task_id_1) - - dag_id_2 = "test-dag-id-2" - task_id_2 = "test-task-id-2" - run_id_2 = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed) - self._create_xcom_entries(dag_id_2, run_id_2, execution_date_parsed, task_id_2) - self._create_invalid_xcom_entries(execution_date_parsed) - response = self.client.get( - "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries", - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - - assert 200 == response.status_code - response_data = response.json - for xcom_entry in response_data["xcom_entries"]: - xcom_entry["timestamp"] = "TIMESTAMP" - _compare_xcom_collections( - response_data, - { - "xcom_entries": [ - { - "dag_id": dag_id_1, - "execution_date": execution_date, - "key": "test-xcom-key-1", - "task_id": task_id_1, - "timestamp": "TIMESTAMP", - "map_index": -1, - }, - { - "dag_id": dag_id_1, - "execution_date": execution_date, - "key": "test-xcom-key-2", - "task_id": task_id_1, - "timestamp": "TIMESTAMP", - "map_index": -1, - }, - ], - "total_entries": 2, - }, - ) - def test_should_respond_200_with_map_index(self): dag_id = "test-dag-id" task_id = "test-task-id" diff --git a/tests/api_connexion/test_auth.py b/tests/api_connexion/test_auth.py index 7d1dcc088273c..54e5632ad84d1 100644 --- a/tests/api_connexion/test_auth.py +++ b/tests/api_connexion/test_auth.py @@ -16,15 +16,15 @@ # under the License. 
from __future__ import annotations -from base64 import b64encode +from unittest.mock import patch import pytest -from flask_login import current_user +from airflow.auth.managers.simple.simple_auth_manager import SimpleAuthManager +from airflow.auth.managers.simple.user import SimpleAuthManagerUser from tests.test_utils.api_connexion_utils import assert_401 from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_pools -from tests.test_utils.www import client_with_login pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] @@ -34,101 +34,6 @@ class BaseTestAuth: def set_attrs(self, minimal_app_for_api): self.app = minimal_app_for_api - sm = self.app.appbuilder.sm - tester = sm.find_user(username="test") - if not tester: - role_admin = sm.find_role("Admin") - sm.add_user( - username="test", - first_name="test", - last_name="test", - email="test@fab.org", - role=role_admin, - password="test", - ) - - -class TestBasicAuth(BaseTestAuth): - @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_api): - from airflow.www.extensions.init_security import init_api_auth - - old_auth = getattr(minimal_app_for_api, "api_auth") - - try: - with conf_vars( - {("api", "auth_backends"): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"} - ): - init_api_auth(minimal_app_for_api) - yield - finally: - setattr(minimal_app_for_api, "api_auth", old_auth) - - def test_success(self): - token = "Basic " + b64encode(b"test:test").decode() - clear_db_pools() - - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) - assert current_user.email == "test@fab.org" - - assert response.status_code == 200 - assert response.json == { - "pools": [ - { - "name": "default_pool", - "slots": 128, - "occupied_slots": 0, - "running_slots": 0, - "queued_slots": 0, - "scheduled_slots": 0, - "deferred_slots": 0, - "open_slots": 128, - "description": "Default pool", - "include_deferred": False, - }, - ], - "total_entries": 1, - } - - @pytest.mark.parametrize( - "token", - [ - "basic", - "basic ", - "bearer", - "test:test", - b64encode(b"test:test").decode(), - "bearer ", - "basic: ", - "basic 123", - ], - ) - def test_malformed_headers(self, token): - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) - assert response.status_code == 401 - assert response.headers["Content-Type"] == "application/problem+json" - assert response.headers["WWW-Authenticate"] == "Basic" - assert_401(response) - - @pytest.mark.parametrize( - "token", - [ - "basic " + b64encode(b"test").decode(), - "basic " + b64encode(b"test:").decode(), - "basic " + b64encode(b"test:123").decode(), - "basic " + b64encode(b"test test").decode(), - ], - ) - def test_invalid_auth_header(self, token): - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) - assert response.status_code == 401 - assert response.headers["Content-Type"] == "application/problem+json" - assert response.headers["WWW-Authenticate"] == "Basic" - assert_401(response) - class TestSessionAuth(BaseTestAuth): @pytest.fixture(autouse=True, scope="class") @@ -144,74 +49,37 @@ def with_session_backend(self, minimal_app_for_api): finally: setattr(minimal_app_for_api, "api_auth", old_auth) - def test_success(self): + @patch.object(SimpleAuthManager, "is_logged_in", return_value=True) + 
@patch.object( + SimpleAuthManager, "get_user", return_value=SimpleAuthManagerUser(username="test", role="admin") + ) + def test_success(self, *args): clear_db_pools() - admin_user = client_with_login(self.app, username="test", password="test") - response = admin_user.get("/api/v1/pools") - assert response.status_code == 200 - assert response.json == { - "pools": [ - { - "name": "default_pool", - "slots": 128, - "occupied_slots": 0, - "running_slots": 0, - "queued_slots": 0, - "scheduled_slots": 0, - "deferred_slots": 0, - "open_slots": 128, - "description": "Default pool", - "include_deferred": False, - }, - ], - "total_entries": 1, - } - - def test_failure(self): with self.app.test_client() as test_client: response = test_client.get("/api/v1/pools") - assert response.status_code == 401 - assert response.headers["Content-Type"] == "application/problem+json" - assert_401(response) - - -class TestSessionWithBasicAuthFallback(BaseTestAuth): - @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_api): - from airflow.www.extensions.init_security import init_api_auth - - old_auth = getattr(minimal_app_for_api, "api_auth") - - try: - with conf_vars( - { - ( - "api", - "auth_backends", - ): "airflow.api.auth.backend.session,airflow.providers.fab.auth_manager.api.auth.backend.basic_auth" - } - ): - init_api_auth(minimal_app_for_api) - yield - finally: - setattr(minimal_app_for_api, "api_auth", old_auth) - - def test_basic_auth_fallback(self): - token = "Basic " + b64encode(b"test:test").decode() - clear_db_pools() - - # request uses session - admin_user = client_with_login(self.app, username="test", password="test") - response = admin_user.get("/api/v1/pools") - assert response.status_code == 200 - - # request uses basic auth - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) assert response.status_code == 200 + assert response.json == { + "pools": [ + { + "name": "default_pool", + "slots": 128, + "occupied_slots": 0, + "running_slots": 0, + "queued_slots": 0, + "scheduled_slots": 0, + "deferred_slots": 0, + "open_slots": 128, + "description": "Default pool", + "include_deferred": False, + }, + ], + "total_entries": 1, + } - # request without session or basic auth header + def test_failure(self): with self.app.test_client() as test_client: response = test_client.get("/api/v1/pools") assert response.status_code == 401 + assert response.headers["Content-Type"] == "application/problem+json" + assert_401(response) diff --git a/tests/api_connexion/test_security.py b/tests/api_connexion/test_security.py index 13a5dd4e25af1..c6a112b1a1bb9 100644 --- a/tests/api_connexion/test_security.py +++ b/tests/api_connexion/test_security.py @@ -18,7 +18,6 @@ import pytest -from airflow.security import permissions from tests.test_utils.api_connexion_utils import create_user, delete_user pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] @@ -28,15 +27,14 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)], # type: ignore + role_name="admin", ) yield minimal_app_for_api - delete_user(app, username="test") # type: ignore + delete_user(app, username="test") class TestSession: diff --git a/tests/providers/fab/auth_manager/api_endpoints/api_connexion_utils.py 
b/tests/providers/fab/auth_manager/api_endpoints/api_connexion_utils.py new file mode 100644 index 0000000000000..61d923d5ff125 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/api_connexion_utils.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from contextlib import contextmanager + +from tests.test_utils.compat import ignore_provider_compatibility_error + +with ignore_provider_compatibility_error("2.9.0+", __file__): + from airflow.providers.fab.auth_manager.security_manager.override import EXISTING_ROLES + + +@contextmanager +def create_test_client(app, user_name, role_name, permissions): + """ + Helper function to create a client with a temporary user which will be deleted once done + """ + client = app.test_client() + with create_user_scope(app, username=user_name, role_name=role_name, permissions=permissions) as _: + resp = client.post("/login/", data={"username": user_name, "password": user_name}) + assert resp.status_code == 302 + yield client + + +@contextmanager +def create_user_scope(app, username, **kwargs): + """ + Helper function designed to be used with pytest fixture mainly. + It will create a user and provide it for the fixture via YIELD (generator) + then will tidy up once test is complete + """ + test_user = create_user(app, username, **kwargs) + + try: + yield test_user + finally: + delete_user(app, username) + + +def create_user(app, username, role_name=None, email=None, permissions=None): + appbuilder = app.appbuilder + + # Removes user and role so each test has isolated test data. 
+    delete_user(app, username)
+    role = None
+    if role_name:
+        delete_role(app, role_name)
+        role = create_role(app, role_name, permissions)
+    else:
+        role = []
+
+    return appbuilder.sm.add_user(
+        username=username,
+        first_name=username,
+        last_name=username,
+        email=email or f"{username}@example.org",
+        role=role,
+        password=username,
+    )
+
+
+def create_role(app, name, permissions=None):
+    appbuilder = app.appbuilder
+    role = appbuilder.sm.find_role(name)
+    if not role:
+        role = appbuilder.sm.add_role(name)
+    if not permissions:
+        permissions = []
+    for permission in permissions:
+        perm_object = appbuilder.sm.get_permission(*permission)
+        appbuilder.sm.add_permission_to_role(role, perm_object)
+    return role
+
+
+def set_user_single_role(app, user, role_name):
+    role = create_role(app, role_name)
+    if role not in user.roles:
+        user.roles = [role]
+        app.appbuilder.sm.update_user(user)
+        user._perms = None
+
+
+def delete_role(app, name):
+    if name not in EXISTING_ROLES:
+        if app.appbuilder.sm.find_role(name):
+            app.appbuilder.sm.delete_role(name)
+
+
+def delete_roles(app):
+    for role in app.appbuilder.sm.get_all_roles():
+        delete_role(app, role.name)
+
+
+def delete_user(app, username):
+    appbuilder = app.appbuilder
+    for user in appbuilder.sm.get_all_users():
+        if user.username == username:
+            _ = [
+                delete_role(app, role.name) for role in user.roles if role and role.name not in EXISTING_ROLES
+            ]
+            appbuilder.sm.del_register_user(user)
+            break
diff --git a/tests/providers/fab/auth_manager/api_endpoints/remote_user_api_auth_backend.py b/tests/providers/fab/auth_manager/api_endpoints/remote_user_api_auth_backend.py
new file mode 100644
index 0000000000000..b7714e5192e6a
--- /dev/null
+++ b/tests/providers/fab/auth_manager/api_endpoints/remote_user_api_auth_backend.py
@@ -0,0 +1,81 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Default authentication backend - everything is allowed"""
+
+from __future__ import annotations
+
+import logging
+from functools import wraps
+from typing import TYPE_CHECKING, Callable, TypeVar, cast
+
+from flask import Response, request
+from flask_login import login_user
+
+from airflow.utils.airflow_flask_app import get_airflow_app
+
+if TYPE_CHECKING:
+    from requests.auth import AuthBase
+
+log = logging.getLogger(__name__)
+
+CLIENT_AUTH: tuple[str, str] | AuthBase | None = None
+
+
+def init_app(_):
+    """Initializes authentication backend"""
+
+
+T = TypeVar("T", bound=Callable)
+
+
+def _lookup_user(user_email_or_username: str):
+    security_manager = get_airflow_app().appbuilder.sm
+    user = security_manager.find_user(email=user_email_or_username) or security_manager.find_user(
+        username=user_email_or_username
+    )
+    if not user:
+        return None
+
+    if not user.is_active:
+        return None
+
+    return user
+
+
+def requires_authentication(function: T):
+    """Decorator for functions that require authentication"""
+
+    @wraps(function)
+    def decorated(*args, **kwargs):
+        user_id = request.remote_user
+        if not user_id:
+            log.debug("Missing REMOTE_USER.")
+            return Response("Forbidden", 403)
+
+        log.debug("Looking for user: %s", user_id)
+
+        user = _lookup_user(user_id)
+        if not user:
+            return Response("Forbidden", 403)
+
+        log.debug("Found user: %s", user)
+
+        login_user(user, remember=False)
+        return function(*args, **kwargs)
+
+    return cast(T, decorated)
diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_auth.py b/tests/providers/fab/auth_manager/api_endpoints/test_auth.py
new file mode 100644
index 0000000000000..d3012e2f1b43e
--- /dev/null
+++ b/tests/providers/fab/auth_manager/api_endpoints/test_auth.py
@@ -0,0 +1,176 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from base64 import b64encode
+
+import pytest
+from flask_login import current_user
+
+from tests.test_utils.api_connexion_utils import assert_401
+from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS
+from tests.test_utils.config import conf_vars
+from tests.test_utils.db import clear_db_pools
+from tests.test_utils.www import client_with_login
+
+pytestmark = [
+    pytest.mark.db_test,
+    pytest.mark.skip_if_database_isolation_mode,
+    pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
+]
+
+
+class BaseTestAuth:
+    @pytest.fixture(autouse=True)
+    def set_attrs(self, minimal_app_for_auth_api):
+        self.app = minimal_app_for_auth_api
+
+        sm = self.app.appbuilder.sm
+        tester = sm.find_user(username="test")
+        if not tester:
+            role_admin = sm.find_role("Admin")
+            sm.add_user(
+                username="test",
+                first_name="test",
+                last_name="test",
+                email="test@fab.org",
+                role=role_admin,
+                password="test",
+            )
+
+
+class TestBasicAuth(BaseTestAuth):
+    @pytest.fixture(autouse=True, scope="class")
+    def with_basic_auth_backend(self, minimal_app_for_auth_api):
+        from airflow.www.extensions.init_security import init_api_auth
+
+        old_auth = getattr(minimal_app_for_auth_api, "api_auth")
+
+        try:
+            with conf_vars(
+                {("api", "auth_backends"): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"}
+            ):
+                init_api_auth(minimal_app_for_auth_api)
+                yield
+        finally:
+            setattr(minimal_app_for_auth_api, "api_auth", old_auth)
+
+    def test_success(self):
+        token = "Basic " + b64encode(b"test:test").decode()
+        clear_db_pools()
+
+        with self.app.test_client() as test_client:
+            response = test_client.get("/api/v1/pools", headers={"Authorization": token})
+            assert current_user.email == "test@fab.org"
+
+        assert response.status_code == 200
+        assert response.json == {
+            "pools": [
+                {
+                    "name": "default_pool",
+                    "slots": 128,
+                    "occupied_slots": 0,
+                    "running_slots": 0,
+                    "queued_slots": 0,
+                    "scheduled_slots": 0,
+                    "deferred_slots": 0,
+                    "open_slots": 128,
+                    "description": "Default pool",
+                    "include_deferred": False,
+                },
+            ],
+            "total_entries": 1,
+        }
+
+    @pytest.mark.parametrize(
+        "token",
+        [
+            "basic",
+            "basic ",
+            "bearer",
+            "test:test",
+            b64encode(b"test:test").decode(),
+            "bearer ",
+            "basic: ",
+            "basic 123",
+        ],
+    )
+    def test_malformed_headers(self, token):
+        with self.app.test_client() as test_client:
+            response = test_client.get("/api/v1/pools", headers={"Authorization": token})
+            assert response.status_code == 401
+            assert response.headers["Content-Type"] == "application/problem+json"
+            assert response.headers["WWW-Authenticate"] == "Basic"
+            assert_401(response)
+
+    @pytest.mark.parametrize(
+        "token",
+        [
+            "basic " + b64encode(b"test").decode(),
+            "basic " + b64encode(b"test:").decode(),
+            "basic " + b64encode(b"test:123").decode(),
+            "basic " + b64encode(b"test test").decode(),
+        ],
+    )
+    def test_invalid_auth_header(self, token):
+        with self.app.test_client() as test_client:
+            response = test_client.get("/api/v1/pools", headers={"Authorization": token})
+            assert response.status_code == 401
+            assert response.headers["Content-Type"] == "application/problem+json"
+            assert response.headers["WWW-Authenticate"] == "Basic"
+            assert_401(response)
+
+
+class TestSessionWithBasicAuthFallback(BaseTestAuth):
+    @pytest.fixture(autouse=True, scope="class")
+    def with_basic_auth_backend(self, minimal_app_for_auth_api):
+        from airflow.www.extensions.init_security import init_api_auth
+
+        old_auth = getattr(minimal_app_for_auth_api, "api_auth")
+
+        try:
+            with conf_vars(
+                {
+                    (
+                        "api",
+                        "auth_backends",
+                    ): "airflow.api.auth.backend.session,airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"
+                }
+            ):
+                init_api_auth(minimal_app_for_auth_api)
+                yield
+        finally:
+            setattr(minimal_app_for_auth_api, "api_auth", old_auth)
+
+    def test_basic_auth_fallback(self):
+        token = "Basic " + b64encode(b"test:test").decode()
+        clear_db_pools()
+
+        # request uses session
+        admin_user = client_with_login(self.app, username="test", password="test")
+        response = admin_user.get("/api/v1/pools")
+        assert response.status_code == 200
+
+        # request uses basic auth
+        with self.app.test_client() as test_client:
+            response = test_client.get("/api/v1/pools", headers={"Authorization": token})
+            assert response.status_code == 200
+
+        # request without session or basic auth header
+        with self.app.test_client() as test_client:
+            response = test_client.get("/api/v1/pools")
+            assert response.status_code == 401
diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_backfill_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_backfill_endpoint.py
new file mode 100644
index 0000000000000..56f135d457e9c
--- /dev/null
+++ b/tests/providers/fab/auth_manager/api_endpoints/test_backfill_endpoint.py
@@ -0,0 +1,264 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations + +import os +from datetime import datetime +from unittest import mock +from urllib.parse import urlencode + +import pendulum +import pytest + +from airflow.models import DagBag, DagModel +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS + +try: + from airflow.models.backfill import Backfill +except ImportError: + if AIRFLOW_V_3_0_PLUS: + raise + else: + pass +from airflow.models.dag import DAG +from airflow.models.serialized_dag import SerializedDagModel +from airflow.operators.empty import EmptyOperator +from airflow.security import permissions +from airflow.utils import timezone +from airflow.utils.session import provide_session +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.db import clear_db_backfills, clear_db_dags, clear_db_runs, clear_db_serialized_dags + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +DAG_ID = "test_dag" +TASK_ID = "op1" +DAG2_ID = "test_dag2" +DAG3_ID = "test_dag3" +UTC_JSON_REPR = "UTC" if pendulum.__version__.startswith("3") else "Timezone('UTC')" + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + + create_user(app, username="test_granular_permissions", role_name="TestGranularDag") + app.appbuilder.sm.sync_perm_for_dag( + "TEST_DAG_1", + access_control={ + "TestGranularDag": { + permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} + }, + }, + ) + + with DAG( + DAG_ID, + schedule=None, + start_date=datetime(2020, 6, 15), + doc_md="details", + params={"foo": 1}, + tags=["example"], + ) as dag: + EmptyOperator(task_id=TASK_ID) + + with DAG(DAG2_ID, schedule=None, start_date=datetime(2020, 6, 15)) as dag2: # no doc_md + EmptyOperator(task_id=TASK_ID) + + with DAG(DAG3_ID, schedule=None) as dag3: # DAG start_date set to None + EmptyOperator(task_id=TASK_ID, start_date=datetime(2019, 6, 12)) + + dag_bag = DagBag(os.devnull, include_examples=False) + dag_bag.dags = {dag.dag_id: dag, dag2.dag_id: dag2, dag3.dag_id: dag3} + + app.dag_bag = dag_bag + + yield app + + delete_user(app, username="test_granular_permissions") + + +class TestBackfillEndpoint: + @staticmethod + def clean_db(): + clear_db_backfills() + clear_db_runs() + clear_db_dags() + clear_db_serialized_dags() + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.clean_db() + self.app = configured_app + self.client = self.app.test_client() # type:ignore + self.dag_id = DAG_ID + self.dag2_id = DAG2_ID + self.dag3_id = DAG3_ID + + def teardown_method(self) -> None: + self.clean_db() + + @provide_session + def _create_dag_models(self, *, count=1, dag_id_prefix="TEST_DAG", is_paused=False, session=None): + dags = [] + for num in range(1, count + 1): + dag_model = DagModel( + dag_id=f"{dag_id_prefix}_{num}", + fileloc=f"/tmp/dag_{num}.py", + is_active=True, + timetable_summary="0 0 * * *", + is_paused=is_paused, + ) + session.add(dag_model) + dags.append(dag_model) + return dags + + @provide_session + def _create_deactivated_dag(self, session=None): + dag_model = DagModel( + dag_id="TEST_DAG_DELETED_1", + fileloc="/tmp/dag_del_1.py", + schedule_interval="2 2 * * *", + is_active=False, + ) + session.add(dag_model) + + +class TestListBackfills(TestBackfillEndpoint): + def test_should_respond_200_with_granular_dag_access(self, session): + (dag,) = self._create_dag_models() + 
from_date = timezone.utcnow() + to_date = timezone.utcnow() + b = Backfill( + dag_id=dag.dag_id, + from_date=from_date, + to_date=to_date, + ) + + session.add(b) + session.commit() + kwargs = {} + kwargs.update(environ_overrides={"REMOTE_USER": "test_granular_permissions"}) + response = self.client.get("/api/v1/backfills?dag_id=TEST_DAG_1", **kwargs) + assert response.status_code == 200 + + +class TestGetBackfill(TestBackfillEndpoint): + def test_should_respond_200_with_granular_dag_access(self, session): + (dag,) = self._create_dag_models() + from_date = timezone.utcnow() + to_date = timezone.utcnow() + backfill = Backfill( + dag_id=dag.dag_id, + from_date=from_date, + to_date=to_date, + ) + session.add(backfill) + session.commit() + kwargs = {} + kwargs.update(environ_overrides={"REMOTE_USER": "test_granular_permissions"}) + response = self.client.get(f"/api/v1/backfills/{backfill.id}", **kwargs) + assert response.status_code == 200 + + +class TestCreateBackfill(TestBackfillEndpoint): + def test_create_backfill(self, session, dag_maker): + with dag_maker(session=session, dag_id="TEST_DAG_1", schedule="0 * * * *") as dag: + EmptyOperator(task_id="mytask") + session.add(SerializedDagModel(dag)) + session.commit() + session.query(DagModel).all() + from_date = pendulum.parse("2024-01-01") + from_date_iso = from_date.isoformat() + to_date = pendulum.parse("2024-02-01") + to_date_iso = to_date.isoformat() + max_active_runs = 5 + query = urlencode( + query={ + "dag_id": dag.dag_id, + "from_date": f"{from_date_iso}", + "to_date": f"{to_date_iso}", + "max_active_runs": max_active_runs, + "reverse": False, + } + ) + kwargs = {} + kwargs.update(environ_overrides={"REMOTE_USER": "test_granular_permissions"}) + + response = self.client.post( + f"/api/v1/backfills?{query}", + **kwargs, + ) + assert response.status_code == 200 + assert response.json == { + "completed_at": mock.ANY, + "created_at": mock.ANY, + "dag_id": "TEST_DAG_1", + "dag_run_conf": None, + "from_date": from_date_iso, + "id": mock.ANY, + "is_paused": False, + "max_active_runs": 5, + "to_date": to_date_iso, + "updated_at": mock.ANY, + } + + +class TestPauseBackfill(TestBackfillEndpoint): + def test_should_respond_200_with_granular_dag_access(self, session): + (dag,) = self._create_dag_models() + from_date = timezone.utcnow() + to_date = timezone.utcnow() + backfill = Backfill( + dag_id=dag.dag_id, + from_date=from_date, + to_date=to_date, + ) + session.add(backfill) + session.commit() + kwargs = {} + kwargs.update(environ_overrides={"REMOTE_USER": "test_granular_permissions"}) + response = self.client.post(f"/api/v1/backfills/{backfill.id}/pause", **kwargs) + assert response.status_code == 200 + + +class TestCancelBackfill(TestBackfillEndpoint): + def test_should_respond_200_with_granular_dag_access(self, session): + (dag,) = self._create_dag_models() + from_date = timezone.utcnow() + to_date = timezone.utcnow() + backfill = Backfill( + dag_id=dag.dag_id, + from_date=from_date, + to_date=to_date, + ) + session.add(backfill) + session.commit() + kwargs = {} + kwargs.update(environ_overrides={"REMOTE_USER": "test_granular_permissions"}) + response = self.client.post(f"/api/v1/backfills/{backfill.id}/cancel", **kwargs) + assert response.status_code == 200 + # now it is marked as completed + assert pendulum.parse(response.json["completed_at"]) + + # get conflict when canceling already-canceled backfill + response = self.client.post(f"/api/v1/backfills/{backfill.id}/cancel", **kwargs) + assert response.status_code == 409 diff --git 
a/tests/api_connexion/test_cors.py b/tests/providers/fab/auth_manager/api_endpoints/test_cors.py similarity index 81% rename from tests/api_connexion/test_cors.py rename to tests/providers/fab/auth_manager/api_endpoints/test_cors.py index a2b7f0ebca743..b44eab8820ec6 100644 --- a/tests/api_connexion/test_cors.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_cors.py @@ -20,16 +20,21 @@ import pytest +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_pools -pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] class BaseTestAuth: @pytest.fixture(autouse=True) - def set_attrs(self, minimal_app_for_api): - self.app = minimal_app_for_api + def set_attrs(self, minimal_app_for_auth_api): + self.app = minimal_app_for_auth_api sm = self.app.appbuilder.sm tester = sm.find_user(username="test") @@ -47,19 +52,19 @@ def set_attrs(self, minimal_app_for_api): class TestEmptyCors(BaseTestAuth): @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_api): + def with_basic_auth_backend(self, minimal_app_for_auth_api): from airflow.www.extensions.init_security import init_api_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + old_auth = getattr(minimal_app_for_auth_api, "api_auth") try: with conf_vars( {("api", "auth_backends"): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"} ): - init_api_auth(minimal_app_for_api) + init_api_auth(minimal_app_for_auth_api) yield finally: - setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(minimal_app_for_auth_api, "api_auth", old_auth) def test_empty_cors_headers(self): token = "Basic " + b64encode(b"test:test").decode() @@ -75,10 +80,10 @@ def test_empty_cors_headers(self): class TestCorsOrigin(BaseTestAuth): @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_api): + def with_basic_auth_backend(self, minimal_app_for_auth_api): from airflow.www.extensions.init_security import init_api_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + old_auth = getattr(minimal_app_for_auth_api, "api_auth") try: with conf_vars( @@ -90,10 +95,10 @@ def with_basic_auth_backend(self, minimal_app_for_api): ("api", "access_control_allow_origins"): "http://apache.org http://example.com", } ): - init_api_auth(minimal_app_for_api) + init_api_auth(minimal_app_for_auth_api) yield finally: - setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(minimal_app_for_auth_api, "api_auth", old_auth) def test_cors_origin_reflection(self): token = "Basic " + b64encode(b"test:test").decode() @@ -119,10 +124,10 @@ def test_cors_origin_reflection(self): class TestCorsWildcard(BaseTestAuth): @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_api): + def with_basic_auth_backend(self, minimal_app_for_auth_api): from airflow.www.extensions.init_security import init_api_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + old_auth = getattr(minimal_app_for_auth_api, "api_auth") try: with conf_vars( @@ -134,10 +139,10 @@ def with_basic_auth_backend(self, minimal_app_for_api): ("api", "access_control_allow_origins"): "*", } ): - init_api_auth(minimal_app_for_api) + init_api_auth(minimal_app_for_auth_api) yield finally: - 
setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(minimal_app_for_auth_api, "api_auth", old_auth) def test_cors_origin_reflection(self): token = "Basic " + b64encode(b"test:test").decode() diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dag_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dag_endpoint.py new file mode 100644 index 0000000000000..b78ac58e442e0 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_dag_endpoint.py @@ -0,0 +1,252 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +from datetime import datetime + +import pendulum +import pytest + +from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from airflow.models import DagBag, DagModel +from airflow.models.dag import DAG +from airflow.operators.empty import EmptyOperator +from airflow.security import permissions +from airflow.utils.session import provide_session +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags +from tests.test_utils.www import _check_last_log + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture +def current_file_token(url_safe_serializer) -> str: + return url_safe_serializer.dumps(__file__) + + +DAG_ID = "test_dag" +TASK_ID = "op1" +DAG2_ID = "test_dag2" +DAG3_ID = "test_dag3" +UTC_JSON_REPR = "UTC" if pendulum.__version__.startswith("3") else "Timezone('UTC')" + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + + create_user(app, username="test_granular_permissions", role_name="TestGranularDag") + app.appbuilder.sm.sync_perm_for_dag( + "TEST_DAG_1", + access_control={ + "TestGranularDag": { + permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} + }, + }, + ) + app.appbuilder.sm.sync_perm_for_dag( + "TEST_DAG_1", + access_control={ + "TestGranularDag": { + permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} + }, + }, + ) + + with DAG( + DAG_ID, + schedule=None, + start_date=datetime(2020, 6, 15), + doc_md="details", + params={"foo": 1}, + tags=["example"], + ) as dag: + EmptyOperator(task_id=TASK_ID) + + with DAG(DAG2_ID, schedule=None, start_date=datetime(2020, 6, 15)) as dag2: # no doc_md + EmptyOperator(task_id=TASK_ID) + + with DAG(DAG3_ID, schedule=None) as dag3: # DAG start_date set to None + EmptyOperator(task_id=TASK_ID, start_date=datetime(2019, 6, 12)) + + dag_bag = 
DagBag(os.devnull, include_examples=False) + dag_bag.dags = {dag.dag_id: dag, dag2.dag_id: dag2, dag3.dag_id: dag3} + + app.dag_bag = dag_bag + + yield app + + delete_user(app, username="test_granular_permissions") + + +class TestDagEndpoint: + @staticmethod + def clean_db(): + clear_db_runs() + clear_db_dags() + clear_db_serialized_dags() + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.clean_db() + self.app = configured_app + self.client = self.app.test_client() # type:ignore + self.dag_id = DAG_ID + self.dag2_id = DAG2_ID + self.dag3_id = DAG3_ID + + def teardown_method(self) -> None: + self.clean_db() + + @provide_session + def _create_dag_models(self, count, dag_id_prefix="TEST_DAG", is_paused=False, session=None): + for num in range(1, count + 1): + dag_model = DagModel( + dag_id=f"{dag_id_prefix}_{num}", + fileloc=f"/tmp/dag_{num}.py", + timetable_summary="2 2 * * *", + is_active=True, + is_paused=is_paused, + ) + session.add(dag_model) + + @provide_session + def _create_dag_model_for_details_endpoint(self, dag_id, session=None): + dag_model = DagModel( + dag_id=dag_id, + fileloc="/tmp/dag.py", + timetable_summary="2 2 * * *", + is_active=True, + is_paused=False, + ) + session.add(dag_model) + + @provide_session + def _create_dag_model_for_details_endpoint_with_dataset_expression(self, dag_id, session=None): + dag_model = DagModel( + dag_id=dag_id, + fileloc="/tmp/dag.py", + timetable_summary="2 2 * * *", + is_active=True, + is_paused=False, + dataset_expression={ + "any": [ + "s3://dag1/output_1.txt", + {"all": ["s3://dag2/output_1.txt", "s3://dag3/output_3.txt"]}, + ] + }, + ) + session.add(dag_model) + + @provide_session + def _create_deactivated_dag(self, session=None): + dag_model = DagModel( + dag_id="TEST_DAG_DELETED_1", + fileloc="/tmp/dag_del_1.py", + timetable_summary="2 2 * * *", + is_active=False, + ) + session.add(dag_model) + + +class TestGetDag(TestDagEndpoint): + def test_should_respond_200_with_granular_dag_access(self): + self._create_dag_models(1) + response = self.client.get( + "/api/v1/dags/TEST_DAG_1", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + ) + assert response.status_code == 200 + + def test_should_respond_403_with_granular_access_for_different_dag(self): + self._create_dag_models(3) + response = self.client.get( + "/api/v1/dags/TEST_DAG_2", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + ) + assert response.status_code == 403 + + +class TestGetDags(TestDagEndpoint): + def test_should_respond_200_with_granular_dag_access(self): + self._create_dag_models(3) + response = self.client.get( + "/api/v1/dags", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + ) + assert response.status_code == 200 + assert len(response.json["dags"]) == 1 + assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" + + +class TestPatchDag(TestDagEndpoint): + @provide_session + def _create_dag_model(self, session=None): + dag_model = DagModel( + dag_id="TEST_DAG_1", fileloc="/tmp/dag_1.py", timetable_summary="2 2 * * *", is_paused=True + ) + session.add(dag_model) + return dag_model + + def test_should_respond_200_on_patch_with_granular_dag_access(self, session): + self._create_dag_models(1) + response = self.client.patch( + "/api/v1/dags/TEST_DAG_1", + json={ + "is_paused": False, + }, + environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + ) + assert response.status_code == 200 + _check_last_log(session, dag_id="TEST_DAG_1", event="api.patch_dag", execution_date=None) + + def 
test_validation_error_raises_400(self): + patch_body = { + "ispaused": True, + } + dag_model = self._create_dag_model() + response = self.client.patch( + f"/api/v1/dags/{dag_model.dag_id}", + json=patch_body, + environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + ) + assert response.status_code == 400 + assert response.json == { + "detail": "{'ispaused': ['Unknown field.']}", + "status": 400, + "title": "Bad Request", + "type": EXCEPTIONS_LINK_MAP[400], + } + + +class TestPatchDags(TestDagEndpoint): + def test_should_respond_200_with_granular_dag_access(self): + self._create_dag_models(3) + response = self.client.patch( + "api/v1/dags?dag_id_pattern=~", + json={ + "is_paused": False, + }, + environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + ) + assert response.status_code == 200 + assert len(response.json["dags"]) == 1 + assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py new file mode 100644 index 0000000000000..a58ea08ff31cf --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py @@ -0,0 +1,273 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from datetime import timedelta + +import pytest + +from airflow.models.dag import DAG, DagModel +from airflow.models.dagrun import DagRun +from airflow.models.param import Param +from airflow.security import permissions +from airflow.utils import timezone +from airflow.utils.session import create_session +from airflow.utils.state import DagRunState +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS + +try: + from airflow.utils.types import DagRunTriggeredByType, DagRunType +except ImportError: + if AIRFLOW_V_3_0_PLUS: + raise + else: + pass +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user, + delete_roles, + delete_user, +) +from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + + create_user( + app, + username="test_no_dag_run_create_permission", + role_name="TestNoDagRunCreatePermission", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), + ], + ) + create_user( + app, + username="test_dag_view_only", + role_name="TestViewDags", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), + ], + ) + create_user( + app, + username="test_view_dags", + role_name="TestViewDags", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), + ], + ) + create_user( + app, + username="test_granular_permissions", + role_name="TestGranularDag", + permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN)], + ) + app.appbuilder.sm.sync_perm_for_dag( + "TEST_DAG_ID", + access_control={ + "TestGranularDag": {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ}, + "TestNoDagRunCreatePermission": {permissions.RESOURCE_DAG_RUN: {permissions.ACTION_CAN_CREATE}}, + }, + ) + + yield app + + delete_user(app, username="test_dag_view_only") + delete_user(app, username="test_view_dags") + delete_user(app, username="test_granular_permissions") + delete_user(app, username="test_no_dag_run_create_permission") + delete_roles(app) + + +class TestDagRunEndpoint: + default_time = "2020-06-11T18:00:00+00:00" + default_time_2 = "2020-06-12T18:00:00+00:00" + default_time_3 = "2020-06-13T18:00:00+00:00" + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + clear_db_runs() + clear_db_serialized_dags() + clear_db_dags() + + def teardown_method(self) -> None: + clear_db_runs() + clear_db_dags() + clear_db_serialized_dags() + + def 
_create_dag(self, dag_id): + dag_instance = DagModel(dag_id=dag_id) + dag_instance.is_active = True + with create_session() as session: + session.add(dag_instance) + dag = DAG(dag_id=dag_id, schedule=None, params={"validated_number": Param(1, minimum=1, maximum=10)}) + self.app.dag_bag.bag_dag(dag) + return dag_instance + + def _create_test_dag_run(self, state=DagRunState.RUNNING, extra_dag=False, commit=True, idx_start=1): + dag_runs = [] + dags = [] + triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} + + for i in range(idx_start, idx_start + 2): + if i == 1: + dags.append(DagModel(dag_id="TEST_DAG_ID", is_active=True)) + dagrun_model = DagRun( + dag_id="TEST_DAG_ID", + run_id=f"TEST_DAG_RUN_ID_{i}", + run_type=DagRunType.MANUAL, + execution_date=timezone.parse(self.default_time) + timedelta(days=i - 1), + start_date=timezone.parse(self.default_time), + external_trigger=True, + state=state, + **triggered_by_kwargs, + ) + dagrun_model.updated_at = timezone.parse(self.default_time) + dag_runs.append(dagrun_model) + + if extra_dag: + for i in range(idx_start + 2, idx_start + 4): + dags.append(DagModel(dag_id=f"TEST_DAG_ID_{i}")) + dag_runs.append( + DagRun( + dag_id=f"TEST_DAG_ID_{i}", + run_id=f"TEST_DAG_RUN_ID_{i}", + run_type=DagRunType.MANUAL, + execution_date=timezone.parse(self.default_time_2), + start_date=timezone.parse(self.default_time), + external_trigger=True, + state=state, + ) + ) + if commit: + with create_session() as session: + session.add_all(dag_runs) + session.add_all(dags) + return dag_runs + + +class TestGetDagRuns(TestDagRunEndpoint): + def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self): + self._create_test_dag_run(extra_dag=True) + expected_dag_run_ids = ["TEST_DAG_ID", "TEST_DAG_ID"] + response = self.client.get( + "api/v1/dags/~/dagRuns", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + ) + assert response.status_code == 200 + dag_run_ids = [dag_run["dag_id"] for dag_run in response.json["dag_runs"]] + assert dag_run_ids == expected_dag_run_ids + + +class TestGetDagRunBatch(TestDagRunEndpoint): + def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self): + self._create_test_dag_run(extra_dag=True) + expected_response_json_1 = { + "dag_id": "TEST_DAG_ID", + "dag_run_id": "TEST_DAG_RUN_ID_1", + "end_date": None, + "state": "running", + "execution_date": self.default_time, + "logical_date": self.default_time, + "external_trigger": True, + "start_date": self.default_time, + "conf": {}, + "data_interval_end": None, + "data_interval_start": None, + "last_scheduling_decision": None, + "run_type": "manual", + "note": None, + } + expected_response_json_1.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {}) + expected_response_json_2 = { + "dag_id": "TEST_DAG_ID", + "dag_run_id": "TEST_DAG_RUN_ID_2", + "end_date": None, + "state": "running", + "execution_date": self.default_time_2, + "logical_date": self.default_time_2, + "external_trigger": True, + "start_date": self.default_time, + "conf": {}, + "data_interval_end": None, + "data_interval_start": None, + "last_scheduling_decision": None, + "run_type": "manual", + "note": None, + } + expected_response_json_2.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {}) + + response = self.client.post( + "api/v1/dags/~/dagRuns/list", + json={"dag_ids": []}, + environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + ) + assert response.status_code == 200 + assert response.json == { 
+ "dag_runs": [ + expected_response_json_1, + expected_response_json_2, + ], + "total_entries": 2, + } + + +class TestPostDagRun(TestDagRunEndpoint): + def test_dagrun_trigger_with_dag_level_permissions(self): + self._create_dag("TEST_DAG_ID") + response = self.client.post( + "api/v1/dags/TEST_DAG_ID/dagRuns", + json={"conf": {"validated_number": 1}}, + environ_overrides={"REMOTE_USER": "test_no_dag_run_create_permission"}, + ) + assert response.status_code == 200 + + @pytest.mark.parametrize( + "username", + ["test_dag_view_only", "test_view_dags", "test_granular_permissions"], + ) + def test_should_raises_403_unauthorized(self, username): + self._create_dag("TEST_DAG_ID") + response = self.client.post( + "api/v1/dags/TEST_DAG_ID/dagRuns", + json={ + "dag_run_id": "TEST_DAG_RUN_ID_1", + "execution_date": self.default_time, + }, + environ_overrides={"REMOTE_USER": username}, + ) + assert response.status_code == 403 diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py new file mode 100644 index 0000000000000..f0d9b0da298c6 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import ast +import os +from typing import TYPE_CHECKING + +import pytest + +from airflow.models import DagBag +from airflow.security import permissions +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_dag_code, clear_db_dags, clear_db_serialized_dags + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + +if TYPE_CHECKING: + from airflow.models.dag import DAG + +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) +EXAMPLE_DAG_FILE = os.path.join("airflow", "example_dags", "example_bash_operator.py") +EXAMPLE_DAG_ID = "example_bash_operator" +TEST_DAG_ID = "latest_only" +NOT_READABLE_DAG_ID = "latest_only_with_trigger" +TEST_MULTIPLE_DAGS_ID = "asset_produces_1" + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, + username="test", + role_name="Test", + permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)], + ) + app.appbuilder.sm.sync_perm_for_dag( + TEST_DAG_ID, + access_control={"Test": [permissions.ACTION_CAN_READ]}, + ) + app.appbuilder.sm.sync_perm_for_dag( + EXAMPLE_DAG_ID, + access_control={"Test": [permissions.ACTION_CAN_READ]}, + ) + app.appbuilder.sm.sync_perm_for_dag( + TEST_MULTIPLE_DAGS_ID, + access_control={"Test": [permissions.ACTION_CAN_READ]}, + ) + + yield app + + delete_user(app, username="test") + + +class TestGetSource: + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + self.clear_db() + + def teardown_method(self) -> None: + self.clear_db() + + @staticmethod + def clear_db(): + clear_db_dags() + clear_db_serialized_dags() + clear_db_dag_code() + + @staticmethod + def _get_dag_file_docstring(fileloc: str) -> str | None: + with open(fileloc) as f: + file_contents = f.read() + module = ast.parse(file_contents) + docstring = ast.get_docstring(module) + return docstring + + def test_should_respond_406(self, url_safe_serializer): + dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) + dagbag.sync_to_db() + test_dag: DAG = dagbag.dags[TEST_DAG_ID] + + url = f"/api/v1/dagSources/{url_safe_serializer.dumps(test_dag.fileloc)}" + response = self.client.get( + url, headers={"Accept": "image/webp"}, environ_overrides={"REMOTE_USER": "test"} + ) + + assert 406 == response.status_code + + def test_should_respond_403_not_readable(self, url_safe_serializer): + dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) + dagbag.sync_to_db() + dag: DAG = dagbag.dags[NOT_READABLE_DAG_ID] + + response = self.client.get( + f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}", + headers={"Accept": "text/plain"}, + environ_overrides={"REMOTE_USER": "test"}, + ) + read_dag = self.client.get( + f"/api/v1/dags/{NOT_READABLE_DAG_ID}", + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 403 + assert read_dag.status_code == 403 + + def test_should_respond_403_some_dags_not_readable_in_the_file(self, url_safe_serializer): + dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) + dagbag.sync_to_db() + dag: DAG = dagbag.dags[TEST_MULTIPLE_DAGS_ID] + + response = self.client.get( + 
f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}", + headers={"Accept": "text/plain"}, + environ_overrides={"REMOTE_USER": "test"}, + ) + + read_dag = self.client.get( + f"/api/v1/dags/{TEST_MULTIPLE_DAGS_ID}", + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 403 + assert read_dag.status_code == 200 diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py new file mode 100644 index 0000000000000..adfde1cc5b3eb --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.models.dag import DagModel +from airflow.models.dagwarning import DagWarning +from airflow.security import permissions +from airflow.utils.session import create_session +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_dag_warnings, clear_db_dags + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, # type:ignore + username="test_with_dag2_read", + role_name="TestWithDag2Read", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), + (permissions.ACTION_CAN_READ, f"{permissions.RESOURCE_DAG_PREFIX}dag2"), + ], + ) + + yield app + + delete_user(app, username="test_with_dag2_read") + + +class TestBaseDagWarning: + timestamp = "2020-06-10T12:00" + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + + def teardown_method(self) -> None: + clear_db_dag_warnings() + clear_db_dags() + + +class TestGetDagWarningEndpoint(TestBaseDagWarning): + def setup_class(self): + clear_db_dag_warnings() + clear_db_dags() + + def setup_method(self): + with create_session() as session: + session.add(DagModel(dag_id="dag1")) + session.add(DagWarning("dag1", "non-existent pool", "test message")) + session.commit() + + def test_should_raise_403_forbidden_when_user_has_no_dag_read_permission(self): + response = self.client.get( + "/api/v1/dagWarnings", + environ_overrides={"REMOTE_USER": "test_with_dag2_read"}, + query_string={"dag_id": "dag1"}, + ) + assert response.status_code == 403 diff --git 
a/tests/providers/fab/auth_manager/api_endpoints/test_dataset_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dataset_endpoint.py new file mode 100644 index 0000000000000..4d302722223d8 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_dataset_endpoint.py @@ -0,0 +1,327 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Generator + +import pytest +import time_machine + +from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS + +try: + from airflow.models.asset import AssetDagRunQueue, AssetModel +except ImportError: + if AIRFLOW_V_3_0_PLUS: + raise + else: + pass +from airflow.security import permissions +from airflow.utils import timezone +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.db import clear_db_assets, clear_db_runs +from tests.test_utils.www import _check_last_log + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, + username="test_queued_event", + role_name="TestQueuedEvent", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ASSET), + ], + ) + + yield app + + delete_user(app, username="test_queued_event") + + +class TestAssetEndpoint: + default_time = "2020-06-11T18:00:00+00:00" + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() + clear_db_assets() + clear_db_runs() + + def teardown_method(self) -> None: + clear_db_assets() + clear_db_runs() + + def _create_asset(self, session): + asset_model = AssetModel( + id=1, + uri="s3://bucket/key", + extra={"foo": "bar"}, + created_at=timezone.parse(self.default_time), + updated_at=timezone.parse(self.default_time), + ) + session.add(asset_model) + session.commit() + return asset_model + + +class TestQueuedEventEndpoint(TestAssetEndpoint): + @pytest.fixture + def time_freezer(self) -> Generator: + freezer = time_machine.travel(self.default_time, tick=False) + freezer.start() + + yield + + freezer.stop() + + def _create_asset_dag_run_queues(self, dag_id, dataset_id, session): + ddrq = AssetDagRunQueue(target_dag_id=dag_id, dataset_id=dataset_id) + session.add(ddrq) + session.commit() + return ddrq + + +class TestGetDagDatasetQueuedEvent(TestQueuedEventEndpoint): + 
@pytest.mark.usefixtures("time_freezer") + def test_should_respond_200(self, session, create_dummy_dag): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + dataset_id = self._create_asset(session).id + self._create_asset_dag_run_queues(dag_id, dataset_id, session) + dataset_uri = "s3://bucket/key" + + response = self.client.get( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 200 + assert response.json == { + "created_at": self.default_time, + "uri": "s3://bucket/key", + "dag_id": "dag", + } + + def test_should_respond_404(self): + dag_id = "not_exists" + dataset_uri = "not_exists" + + response = self.client.get( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json + + +class TestDeleteDagDatasetQueuedEvent(TestAssetEndpoint): + def test_delete_should_respond_204(self, session, create_dummy_dag): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + dataset_uri = "s3://bucket/key" + dataset_id = self._create_asset(session).id + + ddrq = AssetDagRunQueue(target_dag_id=dag_id, dataset_id=dataset_id) + session.add(ddrq) + session.commit() + conn = session.query(AssetDagRunQueue).all() + assert len(conn) == 1 + + response = self.client.delete( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 204 + conn = session.query(AssetDagRunQueue).all() + assert len(conn) == 0 + _check_last_log( + session, dag_id=dag_id, event="api.delete_dag_dataset_queued_event", execution_date=None + ) + + def test_should_respond_404(self): + dag_id = "not_exists" + dataset_uri = "not_exists" + + response = self.client.delete( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json + + +class TestGetDagDatasetQueuedEvents(TestQueuedEventEndpoint): + @pytest.mark.usefixtures("time_freezer") + def test_should_respond_200(self, session, create_dummy_dag): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + dataset_id = self._create_asset(session).id + self._create_asset_dag_run_queues(dag_id, dataset_id, session) + + response = self.client.get( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 200 + assert response.json == { + "queued_events": [ + { + "created_at": self.default_time, + "uri": "s3://bucket/key", + "dag_id": "dag", + } + ], + "total_entries": 1, + } + + def test_should_respond_404(self): + dag_id = "not_exists" + + response = self.client.get( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with dag_id: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": 
EXCEPTIONS_LINK_MAP[404], + } == response.json + + +class TestDeleteDagDatasetQueuedEvents(TestAssetEndpoint): + def test_should_respond_404(self): + dag_id = "not_exists" + + response = self.client.delete( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with dag_id: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json + + +class TestGetDatasetQueuedEvents(TestQueuedEventEndpoint): + @pytest.mark.usefixtures("time_freezer") + def test_should_respond_200(self, session, create_dummy_dag): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + dataset_id = self._create_asset(session).id + self._create_asset_dag_run_queues(dag_id, dataset_id, session) + dataset_uri = "s3://bucket/key" + + response = self.client.get( + f"/api/v1/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 200 + assert response.json == { + "queued_events": [ + { + "created_at": self.default_time, + "uri": "s3://bucket/key", + "dag_id": "dag", + } + ], + "total_entries": 1, + } + + def test_should_respond_404(self): + dataset_uri = "not_exists" + + response = self.client.get( + f"/api/v1/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with asset uri: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json + + +class TestDeleteDatasetQueuedEvents(TestQueuedEventEndpoint): + def test_delete_should_respond_204(self, session, create_dummy_dag): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + dataset_id = self._create_asset(session).id + self._create_asset_dag_run_queues(dag_id, dataset_id, session) + dataset_uri = "s3://bucket/key" + + response = self.client.delete( + f"/api/v1/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 204 + conn = session.query(AssetDagRunQueue).all() + assert len(conn) == 0 + _check_last_log(session, dag_id=None, event="api.delete_dataset_queued_events", execution_date=None) + + def test_should_respond_404(self): + dataset_uri = "not_exists" + + response = self.client.delete( + f"/api/v1/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with asset uri: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_event_log_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_event_log_endpoint.py new file mode 100644 index 0000000000000..acf3ca62684a1 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_event_log_endpoint.py @@ -0,0 +1,151 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.models import Log +from airflow.security import permissions +from airflow.utils import timezone +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_logs + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, + username="test_granular", + role_name="TestGranular", + permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], + ) + app.appbuilder.sm.sync_perm_for_dag( + "TEST_DAG_ID_1", + access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, + ) + app.appbuilder.sm.sync_perm_for_dag( + "TEST_DAG_ID_2", + access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, + ) + + yield app + + delete_user(app, username="test_granular") + + +@pytest.fixture +def task_instance(session, create_task_instance, request): + return create_task_instance( + session=session, + dag_id="TEST_DAG_ID", + task_id="TEST_TASK_ID", + run_id="TEST_RUN_ID", + execution_date=request.instance.default_time, + ) + + +@pytest.fixture +def create_log_model(create_task_instance, task_instance, session, request): + def maker(event, when, **kwargs): + log_model = Log( + event=event, + task_instance=task_instance, + **kwargs, + ) + log_model.dttm = when + + session.add(log_model) + session.flush() + return log_model + + return maker + + +class TestEventLogEndpoint: + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + clear_db_logs() + self.default_time = timezone.parse("2020-06-10T20:00:00+00:00") + self.default_time_2 = timezone.parse("2020-06-11T07:00:00+00:00") + + def teardown_method(self) -> None: + clear_db_logs() + + +class TestGetEventLogs(TestEventLogEndpoint): + def test_should_filter_eventlogs_by_allowed_attributes(self, create_log_model, session): + eventlog1 = create_log_model( + event="TEST_EVENT_1", + dag_id="TEST_DAG_ID_1", + task_id="TEST_TASK_ID_1", + owner="TEST_OWNER_1", + when=self.default_time, + ) + eventlog2 = create_log_model( + event="TEST_EVENT_2", + dag_id="TEST_DAG_ID_2", + task_id="TEST_TASK_ID_2", + owner="TEST_OWNER_2", + when=self.default_time_2, + ) + session.add_all([eventlog1, eventlog2]) + session.commit() + for attr in ["dag_id", "task_id", "owner", "event"]: + attr_value = f"TEST_{attr}_1".upper() + response = self.client.get( + f"/api/v1/eventLogs?{attr}={attr_value}", environ_overrides={"REMOTE_USER": "test_granular"} + ) + assert response.status_code == 200 + assert response.json["total_entries"] == 1 + assert len(response.json["event_logs"]) == 1 + assert response.json["event_logs"][0][attr] == attr_value + + def test_should_filter_eventlogs_by_included_events(self, 
create_log_model): + for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]: + create_log_model(event=event, when=self.default_time) + response = self.client.get( + "/api/v1/eventLogs?included_events=TEST_EVENT_1,TEST_EVENT_2", + environ_overrides={"REMOTE_USER": "test_granular"}, + ) + assert response.status_code == 200 + response_data = response.json + assert len(response_data["event_logs"]) == 2 + assert response_data["total_entries"] == 2 + assert {"TEST_EVENT_1", "TEST_EVENT_2"} == {x["event"] for x in response_data["event_logs"]} + + def test_should_filter_eventlogs_by_excluded_events(self, create_log_model): + for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]: + create_log_model(event=event, when=self.default_time) + response = self.client.get( + "/api/v1/eventLogs?excluded_events=TEST_EVENT_1,TEST_EVENT_2", + environ_overrides={"REMOTE_USER": "test_granular"}, + ) + assert response.status_code == 200 + response_data = response.json + assert len(response_data["event_logs"]) == 1 + assert response_data["total_entries"] == 1 + assert {"cli_scheduler"} == {x["event"] for x in response_data["event_logs"]} diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_import_error_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_import_error_endpoint.py new file mode 100644 index 0000000000000..a2fa1d028a3f2 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_import_error_endpoint.py @@ -0,0 +1,221 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import pytest + +from airflow.models.dag import DagModel +from airflow.security import permissions +from airflow.utils import timezone +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS, ParseImportError +from tests.test_utils.db import clear_db_dags, clear_db_import_errors +from tests.test_utils.permissions import _resource_name + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + +TEST_DAG_IDS = ["test_dag", "test_dag2"] + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, + username="test_single_dag", + role_name="TestSingleDAG", + permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)], + ) + # For some reason, DAG level permissions are not synced when in the above list of perms, + # so do it manually here: + app.appbuilder.sm.bulk_sync_roles( + [ + { + "role": "TestSingleDAG", + "perms": [ + ( + permissions.ACTION_CAN_READ, + _resource_name(TEST_DAG_IDS[0], permissions.RESOURCE_DAG), + ) + ], + } + ] + ) + + yield app + + delete_user(app, username="test_single_dag") + + +class TestBaseImportError: + timestamp = "2020-06-10T12:00" + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + + clear_db_import_errors() + clear_db_dags() + + def teardown_method(self) -> None: + clear_db_import_errors() + clear_db_dags() + + @staticmethod + def _normalize_import_errors(import_errors): + for i, import_error in enumerate(import_errors, 1): + import_error["import_error_id"] = i + + +class TestGetImportErrorEndpoint(TestBaseImportError): + def test_should_raise_403_forbidden_without_dag_read(self, session): + import_error = ParseImportError( + filename="Lorem_ipsum.py", + stacktrace="Lorem ipsum", + timestamp=timezone.parse(self.timestamp, timezone="UTC"), + ) + session.add(import_error) + session.commit() + + response = self.client.get( + f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} + ) + + assert response.status_code == 403 + + def test_should_return_200_with_single_dag_read(self, session): + dag_model = DagModel(dag_id=TEST_DAG_IDS[0], fileloc="Lorem_ipsum.py") + session.add(dag_model) + import_error = ParseImportError( + filename="Lorem_ipsum.py", + stacktrace="Lorem ipsum", + timestamp=timezone.parse(self.timestamp, timezone="UTC"), + ) + session.add(import_error) + session.commit() + + response = self.client.get( + f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} + ) + + assert response.status_code == 200 + response_data = response.json + response_data["import_error_id"] = 1 + assert { + "filename": "Lorem_ipsum.py", + "import_error_id": 1, + "stack_trace": "Lorem ipsum", + "timestamp": "2020-06-10T12:00:00+00:00", + } == response_data + + def test_should_return_200_redacted_with_single_dag_read_in_dagfile(self, session): + for dag_id in TEST_DAG_IDS: + dag_model = DagModel(dag_id=dag_id, fileloc="Lorem_ipsum.py") + session.add(dag_model) + import_error = ParseImportError( + filename="Lorem_ipsum.py", + stacktrace="Lorem ipsum", + timestamp=timezone.parse(self.timestamp, timezone="UTC"), + ) + 
session.add(import_error) + session.commit() + + response = self.client.get( + f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} + ) + + assert response.status_code == 200 + response_data = response.json + response_data["import_error_id"] = 1 + assert { + "filename": "Lorem_ipsum.py", + "import_error_id": 1, + "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file", + "timestamp": "2020-06-10T12:00:00+00:00", + } == response_data + + +class TestGetImportErrorsEndpoint(TestBaseImportError): + def test_get_import_errors_single_dag(self, session): + for dag_id in TEST_DAG_IDS: + fake_filename = f"/tmp/{dag_id}.py" + dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename) + session.add(dag_model) + importerror = ParseImportError( + filename=fake_filename, + stacktrace="Lorem ipsum", + timestamp=timezone.parse(self.timestamp, timezone="UTC"), + ) + session.add(importerror) + session.commit() + + response = self.client.get( + "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} + ) + + assert response.status_code == 200 + response_data = response.json + self._normalize_import_errors(response_data["import_errors"]) + assert { + "import_errors": [ + { + "filename": "/tmp/test_dag.py", + "import_error_id": 1, + "stack_trace": "Lorem ipsum", + "timestamp": "2020-06-10T12:00:00+00:00", + }, + ], + "total_entries": 1, + } == response_data + + def test_get_import_errors_single_dag_in_dagfile(self, session): + for dag_id in TEST_DAG_IDS: + fake_filename = "/tmp/all_in_one.py" + dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename) + session.add(dag_model) + + importerror = ParseImportError( + filename="/tmp/all_in_one.py", + stacktrace="Lorem ipsum", + timestamp=timezone.parse(self.timestamp, timezone="UTC"), + ) + session.add(importerror) + session.commit() + + response = self.client.get( + "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} + ) + + assert response.status_code == 200 + response_data = response.json + self._normalize_import_errors(response_data["import_errors"]) + assert { + "import_errors": [ + { + "filename": "/tmp/all_in_one.py", + "import_error_id": 1, + "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file", + "timestamp": "2020-06-10T12:00:00+00:00", + }, + ], + "total_entries": 1, + } == response_data diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py index 30cfaeb227903..413a49a9d86a1 100644 --- a/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py @@ -19,6 +19,13 @@ import pytest from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_role, + create_user, + delete_role, + delete_user, +) +from tests.test_utils.api_connexion_utils import assert_401 from tests.test_utils.compat import ignore_provider_compatibility_error with ignore_provider_compatibility_error("2.9.0+", __file__): @@ -27,13 +34,6 @@ from airflow.security import permissions -from tests.test_utils.api_connexion_utils import ( - assert_401, - create_role, - create_user, - delete_role, - delete_user, -) pytestmark = pytest.mark.db_test @@ -42,7 +42,7 @@ def configured_app(minimal_app_for_auth_api): app = 
minimal_app_for_auth_api create_user( - app, # type: ignore + app, username="test", role_name="Test", permissions=[ @@ -53,11 +53,11 @@ def configured_app(minimal_app_for_auth_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_ACTION), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name="TestNoPermissions") yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestRoleEndpoint: diff --git a/tests/api_connexion/schemas/test_role_and_permission_schema.py b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_schema.py similarity index 85% rename from tests/api_connexion/schemas/test_role_and_permission_schema.py rename to tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_schema.py index f2967d519794c..4a2f0068e5e4a 100644 --- a/tests/api_connexion/schemas/test_role_and_permission_schema.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_schema.py @@ -31,19 +31,19 @@ class TestRoleCollectionItemSchema: @pytest.fixture(scope="class") - def role(self, minimal_app_for_api): + def role(self, minimal_app_for_auth_api): yield create_role( - minimal_app_for_api, # type: ignore + minimal_app_for_auth_api, # type: ignore name="Test", permissions=[ (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION), ], ) - delete_role(minimal_app_for_api, "Test") + delete_role(minimal_app_for_auth_api, "Test") @pytest.fixture(autouse=True) - def _set_attrs(self, minimal_app_for_api, role): - self.app = minimal_app_for_api + def _set_attrs(self, minimal_app_for_auth_api, role): + self.app = minimal_app_for_auth_api self.role = role def test_serialize(self): @@ -67,26 +67,26 @@ def test_deserialize(self): class TestRoleCollectionSchema: @pytest.fixture(scope="class") - def role1(self, minimal_app_for_api): + def role1(self, minimal_app_for_auth_api): yield create_role( - minimal_app_for_api, # type: ignore + minimal_app_for_auth_api, # type: ignore name="Test1", permissions=[ (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION), ], ) - delete_role(minimal_app_for_api, "Test1") + delete_role(minimal_app_for_auth_api, "Test1") @pytest.fixture(scope="class") - def role2(self, minimal_app_for_api): + def role2(self, minimal_app_for_auth_api): yield create_role( - minimal_app_for_api, # type: ignore + minimal_app_for_auth_api, # type: ignore name="Test2", permissions=[ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), ], ) - delete_role(minimal_app_for_api, "Test2") + delete_role(minimal_app_for_auth_api, "Test2") def test_serialize(self, role1, role2): instance = RoleCollection([role1, role2], total_entries=2) diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py new file mode 100644 index 0000000000000..69b3c221eae93 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py @@ -0,0 +1,427 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime as dt +import urllib + +import pytest + +from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from airflow.models import DagRun, TaskInstance +from airflow.security import permissions +from airflow.utils.session import provide_session +from airflow.utils.state import State +from airflow.utils.timezone import datetime +from airflow.utils.types import DagRunType +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user, + delete_roles, + delete_user, +) +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_runs, clear_db_sla_miss, clear_rendered_ti_fields + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + +DEFAULT_DATETIME_1 = datetime(2020, 1, 1) +DEFAULT_DATETIME_STR_1 = "2020-01-01T00:00:00+00:00" +DEFAULT_DATETIME_STR_2 = "2020-01-02T00:00:00+00:00" + +QUOTED_DEFAULT_DATETIME_STR_1 = urllib.parse.quote(DEFAULT_DATETIME_STR_1) +QUOTED_DEFAULT_DATETIME_STR_2 = urllib.parse.quote(DEFAULT_DATETIME_STR_2) + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, + username="test_dag_read_only", + role_name="TestDagReadOnly", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), + ], + ) + create_user( + app, + username="test_task_read_only", + role_name="TestTaskReadOnly", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + ], + ) + create_user( + app, + username="test_read_only_one_dag", + role_name="TestReadOnlyOneDag", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + ], + ) + # For some reason, "DAG:example_python_operator" is not synced when in the above list of perms, + # so do it manually here: + app.appbuilder.sm.bulk_sync_roles( + [ + { + "role": "TestReadOnlyOneDag", + "perms": [(permissions.ACTION_CAN_READ, "DAG:example_python_operator")], + } + ] + ) + + yield app + + delete_user(app, username="test_dag_read_only") + delete_user(app, username="test_task_read_only") + delete_user(app, username="test_read_only_one_dag") + delete_roles(app) + + +class TestTaskInstanceEndpoint: + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app, dagbag) -> None: + self.default_time = DEFAULT_DATETIME_1 + self.ti_init = { + 
"execution_date": self.default_time, + "state": State.RUNNING, + } + self.ti_extras = { + "start_date": self.default_time + dt.timedelta(days=1), + "end_date": self.default_time + dt.timedelta(days=2), + "pid": 100, + "duration": 10000, + "pool": "default_pool", + "queue": "default_queue", + "job_id": 0, + } + self.app = configured_app + self.client = self.app.test_client() # type:ignore + clear_db_runs() + clear_db_sla_miss() + clear_rendered_ti_fields() + self.dagbag = dagbag + + def create_task_instances( + self, + session, + dag_id: str = "example_python_operator", + update_extras: bool = True, + task_instances=None, + dag_run_state=State.RUNNING, + with_ti_history=False, + ): + """Method to create task instances using kwargs and default arguments""" + + dag = self.dagbag.get_dag(dag_id) + tasks = dag.tasks + counter = len(tasks) + if task_instances is not None: + counter = min(len(task_instances), counter) + + run_id = "TEST_DAG_RUN_ID" + execution_date = self.ti_init.pop("execution_date", self.default_time) + dr = None + + tis = [] + for i in range(counter): + if task_instances is None: + pass + elif update_extras: + self.ti_extras.update(task_instances[i]) + else: + self.ti_init.update(task_instances[i]) + + if "execution_date" in self.ti_init: + run_id = f"TEST_DAG_RUN_ID_{i}" + execution_date = self.ti_init.pop("execution_date") + dr = None + + if not dr: + dr = DagRun( + run_id=run_id, + dag_id=dag_id, + execution_date=execution_date, + run_type=DagRunType.MANUAL, + state=dag_run_state, + ) + session.add(dr) + ti = TaskInstance(task=tasks[i], **self.ti_init) + session.add(ti) + ti.dag_run = dr + ti.note = "placeholder-note" + + for key, value in self.ti_extras.items(): + setattr(ti, key, value) + tis.append(ti) + + session.commit() + if with_ti_history: + for ti in tis: + ti.try_number = 1 + session.merge(ti) + session.commit() + dag.clear() + for ti in tis: + ti.try_number = 2 + ti.queue = "default_queue" + session.merge(ti) + session.commit() + return tis + + +class TestGetTaskInstance(TestTaskInstanceEndpoint): + def setup_method(self): + clear_db_runs() + + def teardown_method(self): + clear_db_runs() + + @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) + @provide_session + def test_should_respond_200(self, username, session): + self.create_task_instances(session) + # Update ti and set operator to None to + # test that operator field is nullable. 
+ # This prevents issue when users upgrade to 2.0+ + # from 1.10.x + # https://github.com/apache/airflow/issues/14421 + session.query(TaskInstance).update({TaskInstance.operator: None}, synchronize_session="fetch") + session.commit() + response = self.client.get( + "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", + environ_overrides={"REMOTE_USER": username}, + ) + assert response.status_code == 200 + + +class TestGetTaskInstances(TestTaskInstanceEndpoint): + @pytest.mark.parametrize( + "task_instances, user, expected_ti", + [ + pytest.param( + { + "example_python_operator": 2, + "example_skip_dag": 1, + }, + "test_read_only_one_dag", + 2, + ), + pytest.param( + { + "example_python_operator": 1, + "example_skip_dag": 2, + }, + "test_read_only_one_dag", + 1, + ), + ], + ) + def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_ti, session): + for dag_id in task_instances: + self.create_task_instances( + session, + task_instances=[ + {"execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=i)} + for i in range(task_instances[dag_id]) + ], + dag_id=dag_id, + ) + response = self.client.get( + "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": user} + ) + assert response.status_code == 200 + assert response.json["total_entries"] == expected_ti + assert len(response.json["task_instances"]) == expected_ti + + +class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): + @pytest.mark.parametrize( + "task_instances, update_extras, payload, expected_ti_count, username", + [ + pytest.param( + [ + {"pool": "test_pool_1"}, + {"pool": "test_pool_2"}, + {"pool": "test_pool_3"}, + ], + True, + {"pool": ["test_pool_1", "test_pool_2"]}, + 2, + "test_dag_read_only", + id="test pool filter", + ), + pytest.param( + [ + {"state": State.RUNNING}, + {"state": State.QUEUED}, + {"state": State.SUCCESS}, + {"state": State.NONE}, + ], + False, + {"state": ["running", "queued", "none"]}, + 3, + "test_task_read_only", + id="test state filter", + ), + pytest.param( + [ + {"state": State.NONE}, + {"state": State.NONE}, + {"state": State.NONE}, + {"state": State.NONE}, + ], + False, + {}, + 4, + "test_task_read_only", + id="test dag with null states", + ), + pytest.param( + [ + {"end_date": DEFAULT_DATETIME_1}, + {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, + {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, + ], + True, + { + "end_date_gte": DEFAULT_DATETIME_STR_1, + "end_date_lte": DEFAULT_DATETIME_STR_2, + }, + 2, + "test_task_read_only", + id="test end date filter", + ), + pytest.param( + [ + {"start_date": DEFAULT_DATETIME_1}, + {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, + {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, + ], + True, + { + "start_date_gte": DEFAULT_DATETIME_STR_1, + "start_date_lte": DEFAULT_DATETIME_STR_2, + }, + 2, + "test_dag_read_only", + id="test start date filter", + ), + ], + ) + def test_should_respond_200( + self, task_instances, update_extras, payload, expected_ti_count, username, session + ): + self.create_task_instances( + session, + update_extras=update_extras, + task_instances=task_instances, + ) + response = self.client.post( + "/api/v1/dags/~/dagRuns/~/taskInstances/list", + environ_overrides={"REMOTE_USER": username}, + json=payload, + ) + assert response.status_code == 200, response.json + assert expected_ti_count == response.json["total_entries"] + assert expected_ti_count == len(response.json["task_instances"]) + + def 
test_returns_403_forbidden_when_user_has_access_to_only_some_dags(self, session): + self.create_task_instances(session=session) + self.create_task_instances(session=session, dag_id="example_skip_dag") + payload = {"dag_ids": ["example_python_operator", "example_skip_dag"]} + + response = self.client.post( + "/api/v1/dags/~/dagRuns/~/taskInstances/list", + environ_overrides={"REMOTE_USER": "test_read_only_one_dag"}, + json=payload, + ) + assert response.status_code == 403 + assert response.json == { + "detail": "User not allowed to access some of these DAGs: ['example_python_operator', 'example_skip_dag']", + "status": 403, + "title": "Forbidden", + "type": EXCEPTIONS_LINK_MAP[403], + } + + +class TestPostSetTaskInstanceState(TestTaskInstanceEndpoint): + @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) + def test_should_raise_403_forbidden(self, username): + response = self.client.post( + "/api/v1/dags/example_python_operator/updateTaskInstancesState", + environ_overrides={"REMOTE_USER": username}, + json={ + "dry_run": True, + "task_id": "print_the_context", + "execution_date": DEFAULT_DATETIME_1.isoformat(), + "include_upstream": True, + "include_downstream": True, + "include_future": True, + "include_past": True, + "new_state": "failed", + }, + ) + assert response.status_code == 403 + + +class TestPatchTaskInstance(TestTaskInstanceEndpoint): + ENDPOINT_URL = ( + "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context" + ) + + @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) + def test_should_raise_403_forbidden(self, username): + response = self.client.patch( + self.ENDPOINT_URL, + environ_overrides={"REMOTE_USER": username}, + json={ + "dry_run": True, + "new_state": "failed", + }, + ) + assert response.status_code == 403 + + +class TestGetTaskInstanceTry(TestTaskInstanceEndpoint): + def setup_method(self): + clear_db_runs() + + def teardown_method(self): + clear_db_runs() + + @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) + @provide_session + def test_should_respond_200(self, username, session): + self.create_task_instances(session, task_instances=[{"state": State.SUCCESS}], with_ti_history=True) + + response = self.client.get( + "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries/1", + environ_overrides={"REMOTE_USER": username}, + ) + assert response.status_code == 200 diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py index bc400c8a43fad..7f2c885bab52c 100644 --- a/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py @@ -30,7 +30,12 @@ with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.models import User -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_role, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user, + delete_role, + delete_user, +) +from tests.test_utils.api_connexion_utils import assert_401 from tests.test_utils.config import conf_vars pytestmark = pytest.mark.db_test @@ -43,7 +48,7 @@ def configured_app(minimal_app_for_auth_api): app = minimal_app_for_auth_api create_user( - app, # type: ignore + app, username="test", role_name="Test", permissions=[ @@ 
-53,12 +58,12 @@ def configured_app(minimal_app_for_auth_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_USER), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name="TestNoPermissions") yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") delete_role(app, name="TestNoPermissions") diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py b/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py index 265407622e269..f3399de6a9775 100644 --- a/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py @@ -18,6 +18,7 @@ import pytest +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_role, delete_role from tests.test_utils.compat import ignore_provider_compatibility_error with ignore_provider_compatibility_error("2.9.0+", __file__): @@ -30,8 +31,6 @@ DEFAULT_TIME = "2021-01-09T13:59:56.336000+00:00" -from tests.test_utils.api_connexion_utils import create_role, delete_role # noqa: E402 - pytestmark = pytest.mark.db_test diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_variable_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_variable_endpoint.py new file mode 100644 index 0000000000000..a8e71e1a82466 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_variable_endpoint.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import pytest + +from airflow.models import Variable +from airflow.security import permissions +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_variables + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + + create_user( + app, + username="test_read_only", + role_name="TestReadOnly", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE), + ], + ) + create_user( + app, + username="test_delete_only", + role_name="TestDeleteOnly", + permissions=[ + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE), + ], + ) + + yield app + + delete_user(app, username="test_read_only") + delete_user(app, username="test_delete_only") + + +class TestVariableEndpoint: + @pytest.fixture(autouse=True) + def setup_method(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + clear_db_variables() + + def teardown_method(self) -> None: + clear_db_variables() + + +class TestGetVariable(TestVariableEndpoint): + @pytest.mark.parametrize( + "user, expected_status_code", + [ + ("test_read_only", 200), + ("test_delete_only", 403), + ], + ) + def test_read_variable(self, user, expected_status_code): + expected_value = '{"foo": 1}' + Variable.set("TEST_VARIABLE_KEY", expected_value) + response = self.client.get( + "/api/v1/variables/TEST_VARIABLE_KEY", environ_overrides={"REMOTE_USER": user} + ) + assert response.status_code == expected_status_code + if expected_status_code == 200: + assert response.json == {"key": "TEST_VARIABLE_KEY", "value": expected_value, "description": None} diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_xcom_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_xcom_endpoint.py new file mode 100644 index 0000000000000..01336f9957c6d --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_xcom_endpoint.py @@ -0,0 +1,230 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from datetime import timedelta + +import pytest + +from airflow.models.dag import DagModel +from airflow.models.dagrun import DagRun +from airflow.models.taskinstance import TaskInstance +from airflow.models.xcom import BaseXCom, XCom +from airflow.operators.empty import EmptyOperator +from airflow.security import permissions +from airflow.utils.dates import parse_execution_date +from airflow.utils.session import create_session +from airflow.utils.types import DagRunType +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +class CustomXCom(BaseXCom): + @classmethod + def deserialize_value(cls, xcom: XCom): + return f"real deserialized {super().deserialize_value(xcom)}" + + def orm_deserialize_value(self): + return f"orm deserialized {super().orm_deserialize_value()}" + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + + create_user( + app, + username="test_granular_permissions", + role_name="TestGranularDag", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), + ], + ) + app.appbuilder.sm.sync_perm_for_dag( + "test-dag-id-1", + access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]}, + ) + + yield app + + delete_user(app, username="test_granular_permissions") + + +def _compare_xcom_collections(collection1: dict, collection_2: dict): + assert collection1.get("total_entries") == collection_2.get("total_entries") + + def sort_key(record): + return ( + record.get("dag_id"), + record.get("task_id"), + record.get("execution_date"), + record.get("map_index"), + record.get("key"), + ) + + assert sorted(collection1.get("xcom_entries", []), key=sort_key) == sorted( + collection_2.get("xcom_entries", []), key=sort_key + ) + + +class TestXComEndpoint: + @staticmethod + def clean_db(): + clear_db_dags() + clear_db_runs() + clear_db_xcom() + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + """ + Setup For XCom endpoint TC + """ + self.app = configured_app + self.client = self.app.test_client() # type:ignore + # clear existing xcoms + self.clean_db() + + def teardown_method(self) -> None: + """ + Clear Hanging XComs + """ + self.clean_db() + + +class TestGetXComEntries(TestXComEndpoint): + def test_should_respond_200_with_tilde_and_granular_dag_access(self): + dag_id_1 = "test-dag-id-1" + task_id_1 = "test-task-id-1" + execution_date = "2005-04-02T00:00:00+00:00" + execution_date_parsed = parse_execution_date(execution_date) + dag_run_id_1 = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed) + self._create_xcom_entries(dag_id_1, dag_run_id_1, execution_date_parsed, task_id_1) + + dag_id_2 = "test-dag-id-2" + task_id_2 = "test-task-id-2" + run_id_2 = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed) + self._create_xcom_entries(dag_id_2, run_id_2, execution_date_parsed, task_id_2) + self._create_invalid_xcom_entries(execution_date_parsed) + response = self.client.get( + "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries", + environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + ) + + assert 200 == 
response.status_code + response_data = response.json + for xcom_entry in response_data["xcom_entries"]: + xcom_entry["timestamp"] = "TIMESTAMP" + _compare_xcom_collections( + response_data, + { + "xcom_entries": [ + { + "dag_id": dag_id_1, + "execution_date": execution_date, + "key": "test-xcom-key-1", + "task_id": task_id_1, + "timestamp": "TIMESTAMP", + "map_index": -1, + }, + { + "dag_id": dag_id_1, + "execution_date": execution_date, + "key": "test-xcom-key-2", + "task_id": task_id_1, + "timestamp": "TIMESTAMP", + "map_index": -1, + }, + ], + "total_entries": 2, + }, + ) + + def _create_xcom_entries(self, dag_id, run_id, execution_date, task_id, mapped_ti=False): + with create_session() as session: + dag = DagModel(dag_id=dag_id) + session.add(dag) + dagrun = DagRun( + dag_id=dag_id, + run_id=run_id, + execution_date=execution_date, + start_date=execution_date, + run_type=DagRunType.MANUAL, + ) + session.add(dagrun) + if mapped_ti: + for i in [0, 1]: + ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id, map_index=i) + ti.dag_id = dag_id + session.add(ti) + else: + ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id) + ti.dag_id = dag_id + session.add(ti) + + for i in [1, 2]: + if mapped_ti: + key = "test-xcom-key" + map_index = i - 1 + else: + key = f"test-xcom-key-{i}" + map_index = -1 + + XCom.set( + key=key, value="TEST", run_id=run_id, task_id=task_id, dag_id=dag_id, map_index=map_index + ) + + def _create_invalid_xcom_entries(self, execution_date): + """ + Invalid XCom entries to test join query + """ + with create_session() as session: + dag = DagModel(dag_id="invalid_dag") + session.add(dag) + dagrun = DagRun( + dag_id="invalid_dag", + run_id="invalid_run_id", + execution_date=execution_date + timedelta(days=1), + start_date=execution_date, + run_type=DagRunType.MANUAL, + ) + session.add(dagrun) + dagrun1 = DagRun( + dag_id="invalid_dag", + run_id="not_this_run_id", + execution_date=execution_date, + start_date=execution_date, + run_type=DagRunType.MANUAL, + ) + session.add(dagrun1) + ti = TaskInstance(EmptyOperator(task_id="invalid_task"), run_id="not_this_run_id") + ti.dag_id = "invalid_dag" + session.add(ti) + for i in [1, 2]: + XCom.set( + key=f"invalid-xcom-key-{i}", + value="TEST", + run_id="not_this_run_id", + task_id="invalid_task", + dag_id="invalid_dag", + ) diff --git a/tests/providers/fab/auth_manager/conftest.py b/tests/providers/fab/auth_manager/conftest.py index 22c29dd229fa1..a8fbe5fbdaaae 100644 --- a/tests/providers/fab/auth_manager/conftest.py +++ b/tests/providers/fab/auth_manager/conftest.py @@ -30,7 +30,10 @@ def minimal_app_for_auth_api(): "init_appbuilder", "init_api_auth", "init_api_auth_provider", + "init_api_connexion", "init_api_error_handlers", + "init_airflow_session_interface", + "init_appbuilder_views", ] ) def factory(): @@ -39,7 +42,11 @@ def factory(): ( "api", "auth_backends", - ): "tests.test_utils.remote_user_api_auth_backend,airflow.api.auth.backend.session" + ): "tests.providers.fab.auth_manager.api_endpoints.remote_user_api_auth_backend,airflow.api.auth.backend.session", + ( + "core", + "auth_manager", + ): "airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager", } ): _app = app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore @@ -58,3 +65,11 @@ def set_auth_role_public(request): yield app.config["AUTH_ROLE_PUBLIC"] = auto_role_public + + +@pytest.fixture(scope="module") +def dagbag(): + from airflow.models import DagBag + + DagBag(include_examples=True, 
read_dags_from_db=False).sync_to_db() + return DagBag(include_examples=True, read_dags_from_db=True) diff --git a/tests/providers/fab/auth_manager/test_security.py b/tests/providers/fab/auth_manager/test_security.py index 156b5cf626271..bebb52c256fc8 100644 --- a/tests/providers/fab/auth_manager/test_security.py +++ b/tests/providers/fab/auth_manager/test_security.py @@ -49,7 +49,7 @@ from airflow.www.auth import get_access_denied_message from airflow.www.extensions.init_auth_manager import get_auth_manager from airflow.www.utils import CustomSQLAInterface -from tests.test_utils.api_connexion_utils import ( +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( create_user, create_user_scope, delete_role, diff --git a/tests/providers/fab/auth_manager/views/test_permissions.py b/tests/providers/fab/auth_manager/views/test_permissions.py index 0b1073df287fa..f24d9b738343b 100644 --- a/tests/providers/fab/auth_manager/views/test_permissions.py +++ b/tests/providers/fab/auth_manager/views/test_permissions.py @@ -21,7 +21,7 @@ from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS from tests.test_utils.www import client_with_login diff --git a/tests/providers/fab/auth_manager/views/test_roles_list.py b/tests/providers/fab/auth_manager/views/test_roles_list.py index 156f07df41209..8de63ad5ba88a 100644 --- a/tests/providers/fab/auth_manager/views/test_roles_list.py +++ b/tests/providers/fab/auth_manager/views/test_roles_list.py @@ -21,7 +21,7 @@ from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS from tests.test_utils.www import client_with_login diff --git a/tests/providers/fab/auth_manager/views/test_user.py b/tests/providers/fab/auth_manager/views/test_user.py index 6660ab926d886..62b03a99e7c2c 100644 --- a/tests/providers/fab/auth_manager/views/test_user.py +++ b/tests/providers/fab/auth_manager/views/test_user.py @@ -21,7 +21,7 @@ from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS from tests.test_utils.www import client_with_login diff --git a/tests/providers/fab/auth_manager/views/test_user_edit.py b/tests/providers/fab/auth_manager/views/test_user_edit.py index 65937b6f83d33..8099f67948183 100644 --- a/tests/providers/fab/auth_manager/views/test_user_edit.py +++ b/tests/providers/fab/auth_manager/views/test_user_edit.py @@ -21,7 +21,7 @@ from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS from tests.test_utils.www import client_with_login diff --git a/tests/providers/fab/auth_manager/views/test_user_stats.py 
b/tests/providers/fab/auth_manager/views/test_user_stats.py index 8cb260fcf1ec4..ae09cf92252c6 100644 --- a/tests/providers/fab/auth_manager/views/test_user_stats.py +++ b/tests/providers/fab/auth_manager/views/test_user_stats.py @@ -21,7 +21,7 @@ from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS from tests.test_utils.www import client_with_login diff --git a/tests/test_utils/api_connexion_utils.py b/tests/test_utils/api_connexion_utils.py index af746b2d55468..48869ee48078d 100644 --- a/tests/test_utils/api_connexion_utils.py +++ b/tests/test_utils/api_connexion_utils.py @@ -17,6 +17,7 @@ from __future__ import annotations from contextlib import contextmanager +from typing import TYPE_CHECKING from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from tests.test_utils.compat import ignore_provider_compatibility_error @@ -24,6 +25,9 @@ with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.security_manager.override import EXISTING_ROLES +if TYPE_CHECKING: + from flask import Flask + @contextmanager def create_test_client(app, user_name, role_name, permissions): @@ -44,7 +48,11 @@ def create_user_scope(app, username, **kwargs): It will create a user and provide it for the fixture via YIELD (generator) then will tidy up once test is complete """ - test_user = create_user(app, username, **kwargs) + from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user as create_user_fab, + ) + + test_user = create_user_fab(app, username, **kwargs) try: yield test_user @@ -52,27 +60,20 @@ def create_user_scope(app, username, **kwargs): delete_user(app, username) -def create_user(app, username, role_name=None, email=None, permissions=None): - appbuilder = app.appbuilder - +def create_user(app: Flask, username: str, role_name: str | None): # Removes user and role so each test has isolated test data. 
delete_user(app, username) - role = None - if role_name: - delete_role(app, role_name) - role = create_role(app, role_name, permissions) - else: - role = [] - - return appbuilder.sm.add_user( - username=username, - first_name=username, - last_name=username, - email=email or f"{username}@example.org", - role=role, - password=username, + + users = app.config.get("SIMPLE_AUTH_MANAGER_USERS", []) + users.append( + { + "username": username, + "role": role_name, + } ) + app.config["SIMPLE_AUTH_MANAGER_USERS"] = users + def create_role(app, name, permissions=None): appbuilder = app.appbuilder @@ -87,14 +88,6 @@ def create_role(app, name, permissions=None): return role -def set_user_single_role(app, user, role_name): - role = create_role(app, role_name) - if role not in user.roles: - user.roles = [role] - app.appbuilder.sm.update_user(user) - user._perms = None - - def delete_role(app, name): if name not in EXISTING_ROLES: if app.appbuilder.sm.find_role(name): @@ -106,20 +99,11 @@ def delete_roles(app): delete_role(app, role.name) -def delete_user(app, username): - appbuilder = app.appbuilder - for user in appbuilder.sm.get_all_users(): - if user.username == username: - _ = [ - delete_role(app, role.name) for role in user.roles if role and role.name not in EXISTING_ROLES - ] - appbuilder.sm.del_register_user(user) - break - - -def delete_users(app): - for user in app.appbuilder.sm.get_all_users(): - delete_user(app, user.username) +def delete_user(app: Flask, username): + users = app.config.get("SIMPLE_AUTH_MANAGER_USERS", []) + users = [user for user in users if user["username"] != username] + + app.config["SIMPLE_AUTH_MANAGER_USERS"] = users def assert_401(response): diff --git a/tests/test_utils/remote_user_api_auth_backend.py b/tests/test_utils/remote_user_api_auth_backend.py index b7714e5192e6a..59df201e530e4 100644 --- a/tests/test_utils/remote_user_api_auth_backend.py +++ b/tests/test_utils/remote_user_api_auth_backend.py @@ -15,17 +15,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Default authentication backend - everything is allowed""" - from __future__ import annotations import logging from functools import wraps from typing import TYPE_CHECKING, Callable, TypeVar, cast -from flask import Response, request -from flask_login import login_user +from flask import Response, request, session +from airflow.auth.managers.simple.user import SimpleAuthManagerUser from airflow.utils.airflow_flask_app import get_airflow_app if TYPE_CHECKING: @@ -36,25 +34,15 @@ CLIENT_AUTH: tuple[str, str] | AuthBase | None = None -def init_app(_): - """Initializes authentication backend""" +def init_app(_): ... 
T = TypeVar("T", bound=Callable) -def _lookup_user(user_email_or_username: str): - security_manager = get_airflow_app().appbuilder.sm - user = security_manager.find_user(email=user_email_or_username) or security_manager.find_user( - username=user_email_or_username - ) - if not user: - return None - - if not user.is_active: - return None - - return user +def _lookup_user(username: str): + users = get_airflow_app().config.get("SIMPLE_AUTH_MANAGER_USERS", []) + return next((user for user in users if user["username"] == username), None) def requires_authentication(function: T): @@ -69,13 +57,13 @@ def decorated(*args, **kwargs): log.debug("Looking for user: %s", user_id) - user = _lookup_user(user_id) - if not user: + user_dict = _lookup_user(user_id) + if not user_dict: return Response("Forbidden", 403) - log.debug("Found user: %s", user) + log.debug("Found user: %s", user_dict) + session["user"] = SimpleAuthManagerUser(username=user_dict["username"], role=user_dict["role"]) - login_user(user, remember=False) return function(*args, **kwargs) return cast(T, decorated) diff --git a/tests/www/views/test_views_custom_user_views.py b/tests/www/views/test_views_custom_user_views.py index ae6d0132827c2..84947a8e5f36f 100644 --- a/tests/www/views/test_views_custom_user_views.py +++ b/tests/www/views/test_views_custom_user_views.py @@ -27,7 +27,10 @@ from airflow import settings from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_role +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user as create_user, + delete_role, +) from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login pytestmark = pytest.mark.db_test diff --git a/tests/www/views/test_views_dagrun.py b/tests/www/views/test_views_dagrun.py index 39c17d086f379..d95955246ac78 100644 --- a/tests/www/views/test_views_dagrun.py +++ b/tests/www/views/test_views_dagrun.py @@ -24,7 +24,11 @@ from airflow.utils import timezone from airflow.utils.session import create_session from airflow.www.views import DagRunModelView -from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user, + delete_roles, + delete_user, +) from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login from tests.www.views.test_views_tasks import _get_appbuilder_pk_string diff --git a/tests/www/views/test_views_home.py b/tests/www/views/test_views_home.py index 5393115041392..ddec0c0bcfed3 100644 --- a/tests/www/views/test_views_home.py +++ b/tests/www/views/test_views_home.py @@ -27,7 +27,7 @@ from airflow.utils.state import State from airflow.www.utils import UIAlert from airflow.www.views import FILTER_LASTRUN_COOKIE, FILTER_STATUS_COOKIE, FILTER_TAGS_COOKIE -from tests.test_utils.api_connexion_utils import create_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user from tests.test_utils.db import clear_db_dags, clear_db_import_errors, clear_db_serialized_dags from tests.test_utils.permissions import _resource_name from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py index 
f5cc011fb6f0e..7b65051724c27 100644 --- a/tests/www/views/test_views_tasks.py +++ b/tests/www/views/test_views_tasks.py @@ -44,7 +44,11 @@ from airflow.utils.state import DagRunState, State from airflow.utils.types import DagRunType from airflow.www.views import TaskInstanceModelView, _safe_parse_datetime -from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user, + delete_roles, + delete_user, +) from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_runs, clear_db_xcom diff --git a/tests/www/views/test_views_variable.py b/tests/www/views/test_views_variable.py index a91a12ddc470b..b7fa8b37c52c8 100644 --- a/tests/www/views/test_views_variable.py +++ b/tests/www/views/test_views_variable.py @@ -25,7 +25,7 @@ from airflow.models import Variable from airflow.security import permissions from airflow.utils.session import create_session -from tests.test_utils.api_connexion_utils import create_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user from tests.test_utils.www import ( _check_last_log, check_content_in_response, From f3e1b32c3a18b2525c306468f40c8b6d99fc410f Mon Sep 17 00:00:00 2001 From: Brent Bovenzi Date: Tue, 1 Oct 2024 17:25:18 +0200 Subject: [PATCH 229/349] Add Docs button to Nav (#42586) * Add Docs button to new UI nav * Add Docs menu button to Nav * Use src alias * Address PR feedback, update documentation * Delete airflow/ui/.env.local --- .gitignore | 1 + airflow/ui/.env.example | 23 +++++++ airflow/ui/src/layouts/Nav/DocsButton.tsx | 67 +++++++++++++++++++ airflow/ui/src/layouts/{ => Nav}/Nav.tsx | 9 ++- .../ui/src/layouts/{ => Nav}/NavButton.tsx | 17 ++--- airflow/ui/src/layouts/Nav/index.tsx | 20 ++++++ airflow/ui/src/layouts/Nav/navButtonProps.ts | 30 +++++++++ airflow/ui/src/main.tsx | 2 +- airflow/ui/src/vite-env.d.ts | 9 +++ .../14_node_environment_setup.rst | 17 +++++ 10 files changed, 178 insertions(+), 17 deletions(-) create mode 100644 airflow/ui/.env.example create mode 100644 airflow/ui/src/layouts/Nav/DocsButton.tsx rename airflow/ui/src/layouts/{ => Nav}/Nav.tsx (94%) rename airflow/ui/src/layouts/{ => Nav}/NavButton.tsx (83%) create mode 100644 airflow/ui/src/layouts/Nav/index.tsx create mode 100644 airflow/ui/src/layouts/Nav/navButtonProps.ts diff --git a/.gitignore b/.gitignore index 257331cb4e90b..a9c055041d980 100644 --- a/.gitignore +++ b/.gitignore @@ -111,6 +111,7 @@ celerybeat-schedule # dotenv .env +.env.local .autoenv*.zsh # virtualenv diff --git a/airflow/ui/.env.example b/airflow/ui/.env.example new file mode 100644 index 0000000000000..9374d93de6bca --- /dev/null +++ b/airflow/ui/.env.example @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +#/ + + +# This is an example. You should make your own `.env.local` file for development + +VITE_FASTAPI_URL="http://localhost:29091" diff --git a/airflow/ui/src/layouts/Nav/DocsButton.tsx b/airflow/ui/src/layouts/Nav/DocsButton.tsx new file mode 100644 index 0000000000000..07a4b93dfaede --- /dev/null +++ b/airflow/ui/src/layouts/Nav/DocsButton.tsx @@ -0,0 +1,67 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import { + IconButton, + Link, + Menu, + MenuButton, + MenuItem, + MenuList, +} from "@chakra-ui/react"; +import { FiBookOpen } from "react-icons/fi"; + +import { navButtonProps } from "./navButtonProps"; + +const links = [ + { + href: "https://airflow.apache.org/docs/", + title: "Documentation", + }, + { + href: "https://github.com/apache/airflow", + title: "GitHub Repo", + }, + { + href: `${import.meta.env.VITE_FASTAPI_URL}/docs`, + title: "REST API Reference", + }, +]; + +export const DocsButton = () => ( + + } + {...navButtonProps} + /> + + {links.map((link) => ( + + {link.title} + + ))} + + +); diff --git a/airflow/ui/src/layouts/Nav.tsx b/airflow/ui/src/layouts/Nav/Nav.tsx similarity index 94% rename from airflow/ui/src/layouts/Nav.tsx rename to airflow/ui/src/layouts/Nav/Nav.tsx index 4900540cd96d1..55bfd4480e0f4 100644 --- a/airflow/ui/src/layouts/Nav.tsx +++ b/airflow/ui/src/layouts/Nav/Nav.tsx @@ -37,8 +37,10 @@ import { FiSun, } from "react-icons/fi"; -import { AirflowPin } from "../assets/AirflowPin"; -import { DagIcon } from "../assets/DagIcon"; +import { AirflowPin } from "src/assets/AirflowPin"; +import { DagIcon } from "src/assets/DagIcon"; + +import { DocsButton } from "./DocsButton"; import { NavButton } from "./NavButton"; export const Nav = () => { @@ -78,7 +80,7 @@ export const Nav = () => { } isDisabled - title="Datasets" + title="Assets" /> } @@ -103,6 +105,7 @@ export const Nav = () => { icon={} title="Return to legacy UI" /> + ( - diff --git a/airflow/ui/src/layouts/Nav/index.tsx b/airflow/ui/src/layouts/Nav/index.tsx new file mode 100644 index 0000000000000..403e140919b04 --- /dev/null +++ b/airflow/ui/src/layouts/Nav/index.tsx @@ -0,0 +1,20 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +export { Nav } from "./Nav"; diff --git a/airflow/ui/src/layouts/Nav/navButtonProps.ts b/airflow/ui/src/layouts/Nav/navButtonProps.ts new file mode 100644 index 0000000000000..740348bc9676b --- /dev/null +++ b/airflow/ui/src/layouts/Nav/navButtonProps.ts @@ -0,0 +1,30 @@ +/*! + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import type { ButtonProps } from "@chakra-ui/react"; + +export const navButtonProps: ButtonProps = { + alignItems: "center", + borderRadius: "none", + flexDir: "column", + height: 16, + transition: "0.2s background-color ease-in-out", + variant: "ghost", + whiteSpace: "wrap", + width: 24, +}; diff --git a/airflow/ui/src/main.tsx b/airflow/ui/src/main.tsx index ca5fbed04b6ce..7b762508ea7b3 100644 --- a/airflow/ui/src/main.tsx +++ b/airflow/ui/src/main.tsx @@ -43,7 +43,7 @@ const queryClient = new QueryClient({ }, }); -axios.defaults.baseURL = "http://localhost:29091"; +axios.defaults.baseURL = import.meta.env.VITE_FASTAPI_URL; // redirect to login page if the API responds with unauthorized or forbidden errors axios.interceptors.response.use( diff --git a/airflow/ui/src/vite-env.d.ts b/airflow/ui/src/vite-env.d.ts index a1fdcdd1e6fc5..193866687bff9 100644 --- a/airflow/ui/src/vite-env.d.ts +++ b/airflow/ui/src/vite-env.d.ts @@ -1,3 +1,4 @@ +/* eslint-disable @typescript-eslint/consistent-type-definitions */ /*! * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -18,3 +19,11 @@ */ /// + +interface ImportMetaEnv { + readonly VITE_FASTAPI_URL: string; +} + +interface ImportMeta { + readonly env: ImportMetaEnv; +} diff --git a/contributing-docs/14_node_environment_setup.rst b/contributing-docs/14_node_environment_setup.rst index 8d98f0860fc8b..7b10f0b0d5ed5 100644 --- a/contributing-docs/14_node_environment_setup.rst +++ b/contributing-docs/14_node_environment_setup.rst @@ -84,6 +84,23 @@ Project Structure - ``/src/components`` shared components across the UI - ``/dist`` build files +Local Environment Variables +--------------------------- + +Copy the example environment + +.. code-block:: bash + + cp .env.example .env.local + +If you run into CORS issues, you may need to add some variables to your Breeze config, ``files/airflow-breeze-config/variables.env``: + +.. 
code-block:: bash + + export AIRFLOW__API__ACCESS_CONTROL_ALLOW_HEADERS="Origin, Access-Control-Request-Method" + export AIRFLOW__API__ACCESS_CONTROL_ALLOW_METHODS="*" + export AIRFLOW__API__ACCESS_CONTROL_ALLOW_ORIGINS="http://localhost:28080,http://localhost:8080" + DEPRECATED Airflow WWW From 09c2d4ae801d75856b66886214db524c5be2a319 Mon Sep 17 00:00:00 2001 From: Elad Kalif <45845474+eladkal@users.noreply.github.com> Date: Tue, 1 Oct 2024 22:31:17 +0700 Subject: [PATCH 230/349] Update providers metadata 2024-10-01 (#42611) --- generated/provider_metadata.json | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/generated/provider_metadata.json b/generated/provider_metadata.json index a73e3da9f6fce..56199e2f82c3d 100644 --- a/generated/provider_metadata.json +++ b/generated/provider_metadata.json @@ -2763,6 +2763,10 @@ "1.17.0": { "associated_airflow_version": "2.10.1", "date_released": "2024-09-24T13:49:56Z" + }, + "1.17.1": { + "associated_airflow_version": "2.10.1", + "date_released": "2024-10-01T09:05:14Z" } }, "databricks": { @@ -6225,6 +6229,10 @@ "1.12.0": { "associated_airflow_version": "2.10.1", "date_released": "2024-09-24T13:49:56Z" + }, + "1.12.1": { + "associated_airflow_version": "2.10.1", + "date_released": "2024-10-01T09:05:14Z" } }, "opensearch": { From d413f0b76c7fb5808f40bcab2a3975b4d256cac4 Mon Sep 17 00:00:00 2001 From: Julian Maicher Date: Tue, 1 Oct 2024 21:23:40 +0200 Subject: [PATCH 231/349] Prevent redirect loop on /home with tags/lastrun filters (#42607) (#42609) Closes #42607 --- airflow/www/views.py | 17 +++++++------ tests/www/views/test_views_home.py | 38 ++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/airflow/www/views.py b/airflow/www/views.py index b3300b517e757..a361b9bd50c16 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -814,19 +814,22 @@ def index(self): return redirect(url_for("Airflow.index")) filter_tags_cookie_val = flask_session.get(FILTER_TAGS_COOKIE) + filter_lastrun_cookie_val = flask_session.get(FILTER_LASTRUN_COOKIE) + + # update filter args in url from session values if needed + if (not arg_tags_filter and filter_tags_cookie_val) or ( + not arg_lastrun_filter and filter_lastrun_cookie_val + ): + tags = arg_tags_filter or (filter_tags_cookie_val and filter_tags_cookie_val.split(",")) + lastrun = arg_lastrun_filter or filter_lastrun_cookie_val + return redirect(url_for("Airflow.index", tags=tags, lastrun=lastrun)) + if arg_tags_filter: flask_session[FILTER_TAGS_COOKIE] = ",".join(arg_tags_filter) - elif filter_tags_cookie_val: - # If tags exist in cookie, but not URL, add them to the URL - return redirect(url_for("Airflow.index", tags=filter_tags_cookie_val.split(","))) - filter_lastrun_cookie_val = flask_session.get(FILTER_LASTRUN_COOKIE) if arg_lastrun_filter: arg_lastrun_filter = arg_lastrun_filter.strip().lower() flask_session[FILTER_LASTRUN_COOKIE] = arg_lastrun_filter - elif filter_lastrun_cookie_val: - # If tags exist in cookie, but not URL, add them to the URL - return redirect(url_for("Airflow.index", lastrun=filter_lastrun_cookie_val)) if arg_status_filter is None: filter_status_cookie_val = flask_session.get(FILTER_STATUS_COOKIE) diff --git a/tests/www/views/test_views_home.py b/tests/www/views/test_views_home.py index ddec0c0bcfed3..44dda24feecbc 100644 --- a/tests/www/views/test_views_home.py +++ b/tests/www/views/test_views_home.py @@ -466,3 +466,41 @@ def test_analytics_pixel(user_client, is_enabled, should_have_pixel): 
check_content_in_response("apacheairflow.gateway.scarf.sh", resp) else: check_content_not_in_response("apacheairflow.gateway.scarf.sh", resp) + + +@pytest.mark.parametrize( + "url, filter_tags_cookie_val, filter_lastrun_cookie_val, expected_filter_tags, expected_filter_lastrun", + [ + ("home", None, None, [], None), + # from url only + ("home?tags=example&tags=test", None, None, ["example", "test"], None), + ("home?lastrun=running", None, None, [], "running"), + ("home?tags=example&tags=test&lastrun=running", None, None, ["example", "test"], "running"), + # from cookie only + ("home", "example,test", None, ["example", "test"], None), + ("home", None, "running", [], "running"), + ("home", "example,test", "running", ["example", "test"], "running"), + # from url and cookie + ("home?tags=example", "example,test", None, ["example"], None), + ("home?lastrun=failed", None, "running", [], "failed"), + ("home?tags=example", None, "running", ["example"], "running"), + ("home?lastrun=running", "example,test", None, ["example", "test"], "running"), + ("home?tags=example&lastrun=running", "example,test", "failed", ["example"], "running"), + ], +) +def test_filter_cookie_eval( + working_dags, + admin_client, + url, + filter_tags_cookie_val, + filter_lastrun_cookie_val, + expected_filter_tags, + expected_filter_lastrun, +): + with admin_client.session_transaction() as flask_session: + flask_session[FILTER_TAGS_COOKIE] = filter_tags_cookie_val + flask_session[FILTER_LASTRUN_COOKIE] = filter_lastrun_cookie_val + + resp = admin_client.get(url, follow_redirects=True) + assert resp.request.args.getlist("tags") == expected_filter_tags + assert resp.request.args.get("lastrun") == expected_filter_lastrun From 13f115820d2420a8775edc23dc753b999f3ddb19 Mon Sep 17 00:00:00 2001 From: Vincent <97131062+vincbeck@users.noreply.github.com> Date: Tue, 1 Oct 2024 15:37:19 -0400 Subject: [PATCH 232/349] Remove `AIRFLOW_V_2_7_PLUS` constant (#42627) --- contributing-docs/testing/unit_tests.rst | 4 ++-- tests/test_utils/compat.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/contributing-docs/testing/unit_tests.rst b/contributing-docs/testing/unit_tests.rst index 935a7b9b602b4..cc6513eaa2e5f 100644 --- a/contributing-docs/testing/unit_tests.rst +++ b/contributing-docs/testing/unit_tests.rst @@ -1184,10 +1184,10 @@ are not part of the public API. We deal with it in one of the following ways: .. 
code-block:: python - from tests.test_utils.compat import AIRFLOW_V_2_7_PLUS + from tests.test_utils.compat import AIRFLOW_V_2_8_PLUS - @pytest.mark.skipif(not AIRFLOW_V_2_7_PLUS, reason="The tests should be skipped for Airflow < 2.7") + @pytest.mark.skipif(not AIRFLOW_V_2_8_PLUS, reason="The tests should be skipped for Airflow < 2.8") def some_test_that_only_works_for_airflow_2_7_plus(): pass diff --git a/tests/test_utils/compat.py b/tests/test_utils/compat.py index ca1d7e9c77dfa..09f3653db82d8 100644 --- a/tests/test_utils/compat.py +++ b/tests/test_utils/compat.py @@ -42,7 +42,6 @@ from airflow import __version__ as airflow_version AIRFLOW_VERSION = Version(airflow_version) -AIRFLOW_V_2_7_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("2.7.0") AIRFLOW_V_2_8_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("2.8.0") AIRFLOW_V_2_9_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("2.9.0") AIRFLOW_V_2_10_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("2.10.0") From 4380bffcbe495d5711767dfefeda53fa8fb35a8b Mon Sep 17 00:00:00 2001 From: GPK Date: Tue, 1 Oct 2024 21:00:29 +0100 Subject: [PATCH 233/349] Move FSHook/PackageIndexHook/SubprocessHook to standard provider (#42506) * move hooks to standard providers * fix document build and adding hooks to provider yaml file * adding fshook tests * marking as db test * doc reference update to subprocess hook --- airflow/operators/bash.py | 2 +- airflow/providers/standard/hooks/__init__.py | 16 ++++++++ .../standard}/hooks/filesystem.py | 0 .../standard}/hooks/package_index.py | 0 .../standard}/hooks/subprocess.py | 4 +- airflow/providers/standard/provider.yaml | 7 ++++ airflow/providers_manager.py | 4 +- airflow/sensors/filesystem.py | 2 +- .../logging-monitoring/errors.rst | 2 +- .../operators-and-hooks-ref.rst | 4 +- tests/providers/standard/hooks/__init__.py | 16 ++++++++ .../standard/hooks/test_filesystem.py | 39 +++++++++++++++++++ .../standard}/hooks/test_package_index.py | 6 +-- .../standard}/hooks/test_subprocess.py | 6 +-- tests/sensors/test_filesystem.py | 2 +- 15 files changed, 94 insertions(+), 16 deletions(-) create mode 100644 airflow/providers/standard/hooks/__init__.py rename airflow/{ => providers/standard}/hooks/filesystem.py (100%) rename airflow/{ => providers/standard}/hooks/package_index.py (100%) rename airflow/{ => providers/standard}/hooks/subprocess.py (96%) create mode 100644 tests/providers/standard/hooks/__init__.py create mode 100644 tests/providers/standard/hooks/test_filesystem.py rename tests/{ => providers/standard}/hooks/test_package_index.py (93%) rename tests/{ => providers/standard}/hooks/test_subprocess.py (95%) diff --git a/airflow/operators/bash.py b/airflow/operators/bash.py index 2ec0341a0d1e2..bf4a943df6e08 100644 --- a/airflow/operators/bash.py +++ b/airflow/operators/bash.py @@ -24,8 +24,8 @@ from typing import TYPE_CHECKING, Any, Callable, Container, Sequence, cast from airflow.exceptions import AirflowException, AirflowSkipException -from airflow.hooks.subprocess import SubprocessHook from airflow.models.baseoperator import BaseOperator +from airflow.providers.standard.hooks.subprocess import SubprocessHook from airflow.utils.operator_helpers import context_to_airflow_vars from airflow.utils.types import ArgNotSet diff --git a/airflow/providers/standard/hooks/__init__.py b/airflow/providers/standard/hooks/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/standard/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the 
Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/hooks/filesystem.py b/airflow/providers/standard/hooks/filesystem.py similarity index 100% rename from airflow/hooks/filesystem.py rename to airflow/providers/standard/hooks/filesystem.py diff --git a/airflow/hooks/package_index.py b/airflow/providers/standard/hooks/package_index.py similarity index 100% rename from airflow/hooks/package_index.py rename to airflow/providers/standard/hooks/package_index.py diff --git a/airflow/hooks/subprocess.py b/airflow/providers/standard/hooks/subprocess.py similarity index 96% rename from airflow/hooks/subprocess.py rename to airflow/providers/standard/hooks/subprocess.py index bc20b5c20b4c5..9e578a7d8034b 100644 --- a/airflow/hooks/subprocess.py +++ b/airflow/providers/standard/hooks/subprocess.py @@ -52,8 +52,8 @@ def run_command( :param env: Optional dict containing environment variables to be made available to the shell environment in which ``command`` will be executed. If omitted, ``os.environ`` will be used. Note, that in case you have Sentry configured, original variables from the environment - will also be passed to the subprocess with ``SUBPROCESS_`` prefix. See - :doc:`/administration-and-deployment/logging-monitoring/errors` for details. + will also be passed to the subprocess with ``SUBPROCESS_`` prefix. See: + https://airflow.apache.org/docs/apache-airflow/stable/administration-and-deployment/logging-monitoring/errors.html for details. :param output_encoding: encoding to use for decoding stdout :param cwd: Working directory to run the command in. If None (default), the command is run in a temporary directory. 
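For reference, a minimal sketch of calling the relocated hook, assuming only the new import path shown in the hunk above; the command and environment values are illustrative and not part of the patch:

.. code-block:: python

    from airflow.providers.standard.hooks.subprocess import SubprocessHook

    hook = SubprocessHook()
    # env replaces os.environ for the child process; omit it to inherit the
    # current environment (plus the SUBPROCESS_-prefixed copies added when
    # Sentry is configured, as the docstring above notes).
    result = hook.run_command(
        command=["bash", "-c", "echo $SUBPROCESS_ENV_TEST"],
        env={"SUBPROCESS_ENV_TEST": "hello"},
    )
    # run_command returns a SubprocessResult namedtuple; `output` holds the
    # last line written to stdout.
    assert result.exit_code == 0
    print(result.output)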
diff --git a/airflow/providers/standard/provider.yaml b/airflow/providers/standard/provider.yaml index 83d8acf0a68b3..068fde1fe3761 100644 --- a/airflow/providers/standard/provider.yaml +++ b/airflow/providers/standard/provider.yaml @@ -50,3 +50,10 @@ sensors: - airflow.providers.standard.sensors.time_delta - airflow.providers.standard.sensors.time - airflow.providers.standard.sensors.weekday + +hooks: + - integration-name: Standard + python-modules: + - airflow.providers.standard.hooks.filesystem + - airflow.providers.standard.hooks.package_index + - airflow.providers.standard.hooks.subprocess diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py index 2c673063cb23e..e276c465ef689 100644 --- a/airflow/providers_manager.py +++ b/airflow/providers_manager.py @@ -36,8 +36,8 @@ from packaging.utils import canonicalize_name from airflow.exceptions import AirflowOptionalProviderFeatureException -from airflow.hooks.filesystem import FSHook -from airflow.hooks.package_index import PackageIndexHook +from airflow.providers.standard.hooks.filesystem import FSHook +from airflow.providers.standard.hooks.package_index import PackageIndexHook from airflow.typing_compat import ParamSpec from airflow.utils import yaml from airflow.utils.entry_points import entry_points_with_dist diff --git a/airflow/sensors/filesystem.py b/airflow/sensors/filesystem.py index 5d32ab07ad4e7..4496f5d6abfa4 100644 --- a/airflow/sensors/filesystem.py +++ b/airflow/sensors/filesystem.py @@ -25,7 +25,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.hooks.filesystem import FSHook +from airflow.providers.standard.hooks.filesystem import FSHook from airflow.sensors.base import BaseSensorOperator from airflow.triggers.base import StartTriggerArgs from airflow.triggers.file import FileTrigger diff --git a/docs/apache-airflow/administration-and-deployment/logging-monitoring/errors.rst b/docs/apache-airflow/administration-and-deployment/logging-monitoring/errors.rst index cb09843422321..0ad3fa8c5127a 100644 --- a/docs/apache-airflow/administration-and-deployment/logging-monitoring/errors.rst +++ b/docs/apache-airflow/administration-and-deployment/logging-monitoring/errors.rst @@ -96,7 +96,7 @@ Impact of Sentry on Environment variables passed to Subprocess Hook When Sentry is enabled, by default it changes the standard library to pass all environment variables to subprocesses opened by Airflow. This changes the default behaviour of -:class:`airflow.hooks.subprocess.SubprocessHook` - always all environment variables are passed to the +:class:`airflow.providers.standard.hooks.subprocess.SubprocessHook` - always all environment variables are passed to the subprocess executed with specific set of environment variables. In this case not only the specified environment variables are passed but also all existing environment variables are passed with ``SUBPROCESS_`` prefix added. This happens also for all other subprocesses. diff --git a/docs/apache-airflow/operators-and-hooks-ref.rst b/docs/apache-airflow/operators-and-hooks-ref.rst index 16b74305a958b..d4ac6bda74c34 100644 --- a/docs/apache-airflow/operators-and-hooks-ref.rst +++ b/docs/apache-airflow/operators-and-hooks-ref.rst @@ -106,8 +106,8 @@ For details see: :doc:`apache-airflow-providers:operators-and-hooks-ref/index`. 
* - Hooks - Guides - * - :mod:`airflow.hooks.filesystem` + * - :mod:`airflow.providers.standard.hooks.filesystem` - - * - :mod:`airflow.hooks.subprocess` + * - :mod:`airflow.providers.standard.hooks.subprocess` - diff --git a/tests/providers/standard/hooks/__init__.py b/tests/providers/standard/hooks/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/standard/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/standard/hooks/test_filesystem.py b/tests/providers/standard/hooks/test_filesystem.py new file mode 100644 index 0000000000000..bbcd22dc94219 --- /dev/null +++ b/tests/providers/standard/hooks/test_filesystem.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import pytest + +from airflow.providers.standard.hooks.filesystem import FSHook + +pytestmark = pytest.mark.db_test + + +class TestFSHook: + def test_get_ui_field_behaviour(self): + fs_hook = FSHook() + assert fs_hook.get_ui_field_behaviour() == { + "hidden_fields": ["host", "schema", "port", "login", "password", "extra"], + "relabeling": {}, + "placeholders": {}, + } + + def test_get_path(self): + fs_hook = FSHook(fs_conn_id="fs_default") + + assert fs_hook.get_path() == "/" diff --git a/tests/hooks/test_package_index.py b/tests/providers/standard/hooks/test_package_index.py similarity index 93% rename from tests/hooks/test_package_index.py rename to tests/providers/standard/hooks/test_package_index.py index 9da429c5a09cf..6a90db0715d81 100644 --- a/tests/hooks/test_package_index.py +++ b/tests/providers/standard/hooks/test_package_index.py @@ -21,8 +21,8 @@ import pytest -from airflow.hooks.package_index import PackageIndexHook from airflow.models.connection import Connection +from airflow.providers.standard.hooks.package_index import PackageIndexHook class MockConnection(Connection): @@ -73,7 +73,7 @@ def mock_get_connection(monkeypatch: pytest.MonkeyPatch, request: pytest.Fixture password: str | None = testdata.get("password", None) expected_result: str | None = testdata.get("expected_result", None) monkeypatch.setattr( - "airflow.hooks.package_index.PackageIndexHook.get_connection", + "airflow.providers.standard.hooks.package_index.PackageIndexHook.get_connection", lambda *_: MockConnection(host, login, password), ) return expected_result @@ -104,7 +104,7 @@ class MockProc: return MockProc() - monkeypatch.setattr("airflow.hooks.package_index.subprocess.run", mock_run) + monkeypatch.setattr("airflow.providers.standard.hooks.package_index.subprocess.run", mock_run) hook_instance = PackageIndexHook() if mock_get_connection: diff --git a/tests/hooks/test_subprocess.py b/tests/providers/standard/hooks/test_subprocess.py similarity index 95% rename from tests/hooks/test_subprocess.py rename to tests/providers/standard/hooks/test_subprocess.py index 0f625be816887..2b2e9473359e5 100644 --- a/tests/hooks/test_subprocess.py +++ b/tests/providers/standard/hooks/test_subprocess.py @@ -26,7 +26,7 @@ import pytest -from airflow.hooks.subprocess import SubprocessHook +from airflow.providers.standard.hooks.subprocess import SubprocessHook OS_ENV_KEY = "SUBPROCESS_ENV_TEST" OS_ENV_VAL = "this-is-from-os-environ" @@ -81,11 +81,11 @@ def test_return_value(self, val, expected): @mock.patch.dict("os.environ", clear=True) @mock.patch( - "airflow.hooks.subprocess.TemporaryDirectory", + "airflow.providers.standard.hooks.subprocess.TemporaryDirectory", return_value=MagicMock(__enter__=MagicMock(return_value="/tmp/airflowtmpcatcat")), ) @mock.patch( - "airflow.hooks.subprocess.Popen", + "airflow.providers.standard.hooks.subprocess.Popen", return_value=MagicMock(stdout=MagicMock(readline=MagicMock(side_effect=StopIteration), returncode=0)), ) def test_should_exec_subprocess(self, mock_popen, mock_temporary_directory): diff --git a/tests/sensors/test_filesystem.py b/tests/sensors/test_filesystem.py index 1fb123cfe7248..641f2f218f2db 100644 --- a/tests/sensors/test_filesystem.py +++ b/tests/sensors/test_filesystem.py @@ -40,7 +40,7 @@ @pytest.mark.skip_if_database_isolation_mode # Test is broken in db isolation mode class TestFileSensor: def setup_method(self): - from airflow.hooks.filesystem import FSHook + from airflow.providers.standard.hooks.filesystem import FSHook hook 
= FSHook() args = {"owner": "airflow", "start_date": DEFAULT_DATE} From 828a6d872e01826c515ca146971a0314e151b9df Mon Sep 17 00:00:00 2001 From: rom sharon <33751805+romsharon98@users.noreply.github.com> Date: Tue, 1 Oct 2024 23:05:45 +0300 Subject: [PATCH 234/349] send notification to internal-ci-cd channel (#42630) --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8828a30ce3ecd..716323cb9acfd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -690,7 +690,7 @@ jobs: id: slack uses: slackapi/slack-github-action@v1.27.0 with: - channel-id: 'zzz_webhook_test' + channel-id: 'internal-airflow-ci-cd' # yamllint disable rule:line-length payload: | { From 730347b35c544fa0c400580af40eadc7054b0d1c Mon Sep 17 00:00:00 2001 From: Kalyan Date: Wed, 2 Oct 2024 03:29:54 +0530 Subject: [PATCH 235/349] Add heartbeat metric for DAG processor (#42398) --------- Signed-off-by: kalyanr --- airflow/jobs/dag_processor_job_runner.py | 16 ++++++++++------ chart/files/statsd-mappings.yml | 6 ++++++ .../logging-monitoring/metrics.rst | 1 + 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/airflow/jobs/dag_processor_job_runner.py b/airflow/jobs/dag_processor_job_runner.py index 76b2ab5925540..28128efba474b 100644 --- a/airflow/jobs/dag_processor_job_runner.py +++ b/airflow/jobs/dag_processor_job_runner.py @@ -17,18 +17,18 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from airflow.jobs.base_job_runner import BaseJobRunner from airflow.jobs.job import Job, perform_heartbeat +from airflow.stats import Stats from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: - from airflow.dag_processing.manager import DagFileProcessorManager - + from sqlalchemy.orm import Session -def empty_callback(_: Any) -> None: - pass + from airflow.dag_processing.manager import DagFileProcessorManager class DagProcessorJobRunner(BaseJobRunner, LoggingMixin): @@ -52,7 +52,7 @@ def __init__( self.processor = processor self.processor.heartbeat = lambda: perform_heartbeat( job=self.job, - heartbeat_callback=empty_callback, + heartbeat_callback=self.heartbeat_callback, only_if_necessary=True, ) @@ -67,3 +67,7 @@ def _execute(self) -> int | None: self.processor.terminate() self.processor.end() return None + + @provide_session + def heartbeat_callback(self, session: Session = NEW_SESSION) -> None: + Stats.incr("dag_processor_heartbeat", 1, 1) diff --git a/chart/files/statsd-mappings.yml b/chart/files/statsd-mappings.yml index 86d773fd20b7f..cef9593dd16d3 100644 --- a/chart/files/statsd-mappings.yml +++ b/chart/files/statsd-mappings.yml @@ -46,6 +46,12 @@ mappings: labels: type: counter + - match: airflow.dag_processor_heartbeat + match_type: regex + name: "airflow_dag_processor_heartbeat" + labels: + type: counter + - match: airflow.dag.*.*.duration name: "airflow_task_duration" labels: diff --git a/docs/apache-airflow/administration-and-deployment/logging-monitoring/metrics.rst b/docs/apache-airflow/administration-and-deployment/logging-monitoring/metrics.rst index ac44d1acba9c0..079aa5d397a41 100644 --- a/docs/apache-airflow/administration-and-deployment/logging-monitoring/metrics.rst +++ b/docs/apache-airflow/administration-and-deployment/logging-monitoring/metrics.rst @@ -159,6 +159,7 @@ Name Descripti ``previously_succeeded`` Number of previously succeeded task 
instances. Metric with dag_id and task_id tagging. ``zombies_killed`` Zombie tasks killed. Metric with dag_id and task_id tagging. ``scheduler_heartbeat`` Scheduler heartbeats +``dag_processor_heartbeat`` Standalone DAG processor heartbeats ``dag_processing.processes`` Relative number of currently running DAG parsing processes (ie this delta is negative when, since the last metric was sent, processes have completed). Metric with file_path and action tagging. From e8fbe037e247ed69dddd5d7232656547b2e26470 Mon Sep 17 00:00:00 2001 From: Alexander Millin Date: Wed, 2 Oct 2024 03:44:04 +0300 Subject: [PATCH 236/349] Fix the order of tasks during serialization (#42219) SerializedDagModel().dag_hash may change if the order of tasks is not fixed --- airflow/serialization/serialized_objects.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index a4801b767acc5..08944391b8166 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -1604,7 +1604,9 @@ def serialize_dag(cls, dag: DAG) -> dict: try: serialized_dag = cls.serialize_to_json(dag, cls._decorated_fields) serialized_dag["_processor_dags_folder"] = DAGS_FOLDER - serialized_dag["tasks"] = [cls.serialize(task) for _, task in dag.task_dict.items()] + serialized_dag["tasks"] = [ + cls.serialize(dag.task_dict[task_id]) for task_id in sorted(dag.task_dict) + ] dag_deps = [ dep From 375ea2d79545c0e8292b49e68833e17ce16ffc99 Mon Sep 17 00:00:00 2001 From: Kyle Thatcher <33584092+Kytha@users.noreply.github.com> Date: Tue, 1 Oct 2024 20:46:16 -0400 Subject: [PATCH 237/349] Remove state sync during celery task processing (#41870) --- airflow/providers/celery/executors/celery_executor.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/airflow/providers/celery/executors/celery_executor.py b/airflow/providers/celery/executors/celery_executor.py index 93037bb31c136..807c77ab98782 100644 --- a/airflow/providers/celery/executors/celery_executor.py +++ b/airflow/providers/celery/executors/celery_executor.py @@ -300,9 +300,6 @@ def _process_tasks(self, task_tuples: list[TaskTuple]) -> None: # which point we don't need the ID anymore anyway self.event_buffer[key] = (TaskInstanceState.QUEUED, result.task_id) - # If the task runs _really quickly_ we may already have a result! 
- self.update_task_state(key, result.state, getattr(result, "info", None)) - def _send_tasks_to_celery(self, task_tuples_to_send: list[TaskInstanceInCelery]): from airflow.providers.celery.executors.celery_executor_utils import send_task_to_executor From 71112f0da3a0737920816cb998c18f467851cdfb Mon Sep 17 00:00:00 2001 From: phi-friday Date: Wed, 2 Oct 2024 10:12:30 +0900 Subject: [PATCH 238/349] fix: rm `skip_if` and `run_if` in python source (#41832) --- airflow/utils/decorators.py | 2 +- tests/utils/test_decorators.py | 128 +++++++++++++++++++++++++++++++++ 2 files changed, 129 insertions(+), 1 deletion(-) create mode 100644 tests/utils/test_decorators.py diff --git a/airflow/utils/decorators.py b/airflow/utils/decorators.py index 4cad5ab9e6073..e299999423e56 100644 --- a/airflow/utils/decorators.py +++ b/airflow/utils/decorators.py @@ -49,7 +49,7 @@ def _remove_task_decorator(py_source, decorator_name): after_decorator = after_decorator[1:] return before_decorator + after_decorator - decorators = ["@setup", "@teardown", task_decorator_name] + decorators = ["@setup", "@teardown", "@task.skip_if", "@task.run_if", task_decorator_name] for decorator in decorators: python_source = _remove_task_decorator(python_source, decorator) return python_source diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py new file mode 100644 index 0000000000000..19d3ec31d0311 --- /dev/null +++ b/tests/utils/test_decorators.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from airflow.decorators import task + +if TYPE_CHECKING: + from airflow.decorators.base import Task, TaskDecorator + +_CONDITION_DECORATORS = frozenset({"skip_if", "run_if"}) +_NO_SOURCE_DECORATORS = frozenset({"sensor"}) +DECORATORS = sorted( + set(x for x in dir(task) if not x.startswith("_")) - _CONDITION_DECORATORS - _NO_SOURCE_DECORATORS +) +DECORATORS_USING_SOURCE = ("external_python", "virtualenv", "branch_virtualenv", "branch_external_python") + + +@pytest.fixture +def decorator(request: pytest.FixtureRequest) -> TaskDecorator: + decorator_factory = getattr(task, request.param) + + kwargs = {} + if "external" in request.param: + kwargs["python"] = "python3" + return decorator_factory(**kwargs) + + +@pytest.mark.parametrize("decorator", DECORATORS_USING_SOURCE, indirect=["decorator"]) +def test_task_decorator_using_source(decorator: TaskDecorator): + @decorator + def f(): + return ["some_task"] + + assert parse_python_source(f, "decorator") == 'def f():\n return ["some_task"]\n' + + +@pytest.mark.parametrize("decorator", DECORATORS, indirect=["decorator"]) +def test_skip_if(decorator: TaskDecorator): + @task.skip_if(lambda context: True) + @decorator + def f(): + return "hello world" + + assert parse_python_source(f, "decorator") == 'def f():\n return "hello world"\n' + + +@pytest.mark.parametrize("decorator", DECORATORS, indirect=["decorator"]) +def test_run_if(decorator: TaskDecorator): + @task.run_if(lambda context: True) + @decorator + def f(): + return "hello world" + + assert parse_python_source(f, "decorator") == 'def f():\n return "hello world"\n' + + +def test_skip_if_and_run_if(): + @task.skip_if(lambda context: True) + @task.run_if(lambda context: True) + @task.virtualenv() + def f(): + return "hello world" + + assert parse_python_source(f) == 'def f():\n return "hello world"\n' + + +def test_run_if_and_skip_if(): + @task.run_if(lambda context: True) + @task.skip_if(lambda context: True) + @task.virtualenv() + def f(): + return "hello world" + + assert parse_python_source(f) == 'def f():\n return "hello world"\n' + + +def test_skip_if_allow_decorator(): + def non_task_decorator(func): + return func + + @task.skip_if(lambda context: True) + @task.virtualenv() + @non_task_decorator + def f(): + return "hello world" + + assert parse_python_source(f) == '@non_task_decorator\ndef f():\n return "hello world"\n' + + +def test_run_if_allow_decorator(): + def non_task_decorator(func): + return func + + @task.run_if(lambda context: True) + @task.virtualenv() + @non_task_decorator + def f(): + return "hello world" + + assert parse_python_source(f) == '@non_task_decorator\ndef f():\n return "hello world"\n' + + +def parse_python_source(task: Task, custom_operator_name: str | None = None) -> str: + operator = task().operator + if custom_operator_name: + custom_operator_name = ( + custom_operator_name if custom_operator_name.startswith("@") else f"@{custom_operator_name}" + ) + operator.__dict__["custom_operator_name"] = custom_operator_name + return operator.get_python_source() From bdee2d36f6ec88b7ff1d2007042a3c6c2b24a30b Mon Sep 17 00:00:00 2001 From: phi-friday Date: Wed, 2 Oct 2024 10:13:38 +0900 Subject: [PATCH 239/349] fix: task flow dynamic mapping with default_args (#41592) --- airflow/decorators/base.py | 25 ++++++++++++++++++------- tests/decorators/test_mapped.py | 24 ++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/airflow/decorators/base.py 
b/airflow/decorators/base.py index 1ef2c12c702f2..e650c1920a870 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -431,18 +431,29 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag() task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current_task_group(dag) - partial_kwargs, partial_params = get_merged_defaults( + default_args, partial_params = get_merged_defaults( dag=dag, task_group=task_group, task_params=task_kwargs.pop("params", None), task_default_args=task_kwargs.pop("default_args", None), ) - partial_kwargs.update( - task_kwargs, - is_setup=self.is_setup, - is_teardown=self.is_teardown, - on_failure_fail_dagrun=self.on_failure_fail_dagrun, - ) + partial_kwargs: dict[str, Any] = { + "is_setup": self.is_setup, + "is_teardown": self.is_teardown, + "on_failure_fail_dagrun": self.on_failure_fail_dagrun, + } + base_signature = inspect.signature(BaseOperator) + ignore = { + "default_args", # This is target we are working on now. + "kwargs", # A common name for a keyword argument. + "do_xcom_push", # In the same boat as `multiple_outputs` + "multiple_outputs", # We will use `self.multiple_outputs` instead. + "params", # Already handled above `partial_params`. + "task_concurrency", # Deprecated(replaced by `max_active_tis_per_dag`). + } + partial_keys = set(base_signature.parameters) - ignore + partial_kwargs.update({key: value for key, value in default_args.items() if key in partial_keys}) + partial_kwargs.update(task_kwargs) task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group) if task_group: diff --git a/tests/decorators/test_mapped.py b/tests/decorators/test_mapped.py index 3812367425f8b..2d3747b5f34ef 100644 --- a/tests/decorators/test_mapped.py +++ b/tests/decorators/test_mapped.py @@ -17,6 +17,9 @@ # under the License. from __future__ import annotations +import pytest + +from airflow.decorators import task from airflow.models.dag import DAG from airflow.utils.task_group import TaskGroup from tests.models import DEFAULT_DATE @@ -36,3 +39,24 @@ def f(z): dag.get_task("t1") == x1.operator dag.get_task("g.t2") == x2.operator + + +@pytest.mark.db_test +def test_mapped_task_with_arbitrary_default_args(dag_maker, session): + default_args = {"some": "value", "not": "in", "the": "task", "or": "dag"} + with dag_maker(session=session, default_args=default_args): + + @task.python(do_xcom_push=True) + def f(x: int, y: int) -> int: + return x + y + + f.partial(y=10).expand(x=[1, 2, 3]) + + dag_run = dag_maker.create_dagrun(session=session) + decision = dag_run.task_instance_scheduling_decisions(session=session) + xcoms = set() + for ti in decision.schedulable_tis: + ti.run(session=session) + xcoms.add(ti.xcom_pull(session=session, task_ids=ti.task_id, map_indexes=ti.map_index)) + + assert xcoms == {11, 12, 13} From 4fea1bc0b246a51d81e88e524341811f2bdf4405 Mon Sep 17 00:00:00 2001 From: Usiel Riedl Date: Wed, 2 Oct 2024 09:41:11 +0800 Subject: [PATCH 240/349] Adds new `triggerer.capacity_left[.]` metric (#41323) After reducing the default capacity our deployment, it is rather close to the total capacity at certain times, hence it would be useful to be able to create monitoring and alerting based on the left capacity. The new metric will enable better alerting (not relying on hardcoded capacity values) and even auto-scaling if wished for. 
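As a rough illustration of the auto-scaling idea mentioned above, the sketch below derives a triggerer replica count from per-host ``triggerer.capacity_left`` values; the metric name comes from this change, but the helper, its thresholds, and the assumed default capacity of 1000 are illustrative and not part of Airflow.

.. code-block:: python

    def desired_triggerer_replicas(
        capacity_left_by_host: dict[str, int],
        default_capacity: int = 1000,
        headroom: int = 100,
    ) -> int:
        """Return a replica count based on how much spare trigger capacity is left."""
        current = len(capacity_left_by_host)
        total_left = sum(capacity_left_by_host.values())
        if total_left < headroom:
            return current + 1  # fleet is close to saturation, scale out
        if total_left > default_capacity + headroom and current > 1:
            return current - 1  # a whole triggerer's worth of capacity is idle, scale in
        return current


    # e.g. one triggerer almost full, one mostly idle -> keep both replicas
    print(desired_triggerer_replicas({"triggerer-0": 20, "triggerer-1": 700}))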
--- airflow/jobs/triggerer_job_runner.py | 6 ++++++ .../logging-monitoring/metrics.rst | 3 +++ 2 files changed, 9 insertions(+) diff --git a/airflow/jobs/triggerer_job_runner.py b/airflow/jobs/triggerer_job_runner.py index b41af29f376ba..defde4a16471c 100644 --- a/airflow/jobs/triggerer_job_runner.py +++ b/airflow/jobs/triggerer_job_runner.py @@ -430,9 +430,15 @@ def emit_metrics(self): Stats.gauge( "triggers.running", len(self.trigger_runner.triggers), tags={"hostname": self.job.hostname} ) + + capacity_left = self.capacity - len(self.trigger_runner.triggers) + Stats.gauge(f"triggerer.capacity_left.{self.job.hostname}", capacity_left) + Stats.gauge("triggerer.capacity_left", capacity_left, tags={"hostname": self.job.hostname}) + span = Trace.get_current_span() span.set_attribute("trigger host", self.job.hostname) span.set_attribute("triggers running", len(self.trigger_runner.triggers)) + span.set_attribute("capacity left", capacity_left) class TriggerDetails(TypedDict): diff --git a/docs/apache-airflow/administration-and-deployment/logging-monitoring/metrics.rst b/docs/apache-airflow/administration-and-deployment/logging-monitoring/metrics.rst index 079aa5d397a41..7ce9b9b765a9c 100644 --- a/docs/apache-airflow/administration-and-deployment/logging-monitoring/metrics.rst +++ b/docs/apache-airflow/administration-and-deployment/logging-monitoring/metrics.rst @@ -248,6 +248,9 @@ Name Description ``triggers.running.`` Number of triggers currently running for a triggerer (described by hostname) ``triggers.running`` Number of triggers currently running for a triggerer (described by hostname). Metric with hostname tagging. +``triggerer.capacity_left.`` Capacity left on a triggerer to run triggers (described by hostname) +``triggerer.capacity_left`` Capacity left on a triggerer to run triggers (described by hostname). + Metric with hostname tagging. ==================================================== ======================================================================== Timers From 296f84c8741c2a0d30d4d625d5b6372798bfc93d Mon Sep 17 00:00:00 2001 From: Andor Markus <51825189+andormarkus@users.noreply.github.com> Date: Wed, 2 Oct 2024 03:51:48 +0200 Subject: [PATCH 241/349] [HELM] - Add guide how to PgBouncer with Kubernetes Secret (#42460) * feat: Add guide how to PgBouncer with Kubernetes Secret * feat: Add guide how to PgBouncer with Kubernetes Secret --------- Co-authored-by: Andor Markus (AllCloud) --- docs/helm-chart/production-guide.rst | 74 ++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/docs/helm-chart/production-guide.rst b/docs/helm-chart/production-guide.rst index ee1fc2308be43..020394c8583fc 100644 --- a/docs/helm-chart/production-guide.rst +++ b/docs/helm-chart/production-guide.rst @@ -91,10 +91,84 @@ If you are using PostgreSQL as your database, you will likely want to enable `Pg Airflow can open a lot of database connections due to its distributed nature and using a connection pooler can significantly reduce the number of open connections on the database. +Database credentials stored Values file +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: yaml + + pgbouncer: + enabled: true + + +Database credentials stored Kubernetes Secret +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The default connection string in this case will not work you need to modify accordingly + +.. 
code-block:: bash + + kubectl create secret generic mydatabase --from-literal=connection=postgresql://user:pass@pgbouncer_svc_name.deployment_namespace:6543/airflow-metadata + +Two additional Kubernetes Secret required to PgBouncer able to properly work in this configuration: + +``airflow-pgbouncer-stats`` + +.. code-block:: bash + + kubectl create secret generic airflow-pgbouncer-stats --from-literal=connection=postgresql://user:pass@127.0.0.1:6543/pgbouncer?sslmode=disable + +``airflow-pgbouncer-config`` + +.. code-block:: yaml + + apiVersion: v1 + kind: Secret + metadata: + name: airflow-pgbouncer-config + data: + pgbouncer.ini: dmFsdWUtMg0KDQo= + users.txt: dmFsdWUtMg0KDQo= + + +``pgbouncer.ini`` equal to the base64 encoded version of this text + +.. code-block:: text + + [databases] + airflow-metadata = host={external_database_host} dbname={external_database_dbname} port=5432 pool_size=10 + + [pgbouncer] + pool_mode = transaction + listen_port = 6543 + listen_addr = * + auth_type = scram-sha-256 + auth_file = /etc/pgbouncer/users.txt + stats_users = postgres + ignore_startup_parameters = extra_float_digits + max_client_conn = 100 + verbose = 0 + log_disconnections = 0 + log_connections = 0 + + server_tls_sslmode = prefer + server_tls_ciphers = normal + +``users.txt`` equal to the base64 encoded version of this text + +.. code-block:: text + + "{ external_database_host }" "{ external_database_pass }" + +The ``values.yaml`` should looks like this + .. code-block:: yaml pgbouncer: enabled: true + configSecretName: airflow-pgbouncer-config + metricsExporterSidecar: + statsSecretName: airflow-pgbouncer-stats + Depending on the size of your Airflow instance, you may want to adjust the following as well (defaults are shown): From 830f35bf5acfdd63d26da47492e6a0acbb22eb3d Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Tue, 1 Oct 2024 18:54:15 -0700 Subject: [PATCH 242/349] Add dag run creation logic for backfill (#42529) Add basic backfill creation logic. This will be refined, but we're trying to be incremental here. 
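A minimal sketch of what this adds, modelled on the tests introduced in this patch: creating a backfill over a date range via the new internal helper. The DAG id and conf values are placeholders, and the call assumes the DAG is already serialized to the metadata database.

.. code-block:: python

    import pendulum

    from airflow.models.backfill import _create_backfill

    backfill = _create_backfill(
        dag_id="my_daily_dag",  # placeholder DAG id
        from_date=pendulum.parse("2021-01-01"),
        to_date=pendulum.parse("2021-01-05"),
        max_active_runs=2,
        reverse=False,
        dag_run_conf={"note": "created by backfill"},
    )
    print(backfill)  # Backfill(self.dag_id=..., self.from_date=..., self.to_date=...)

Each schedule interval between ``from_date`` and ``to_date`` gets a queued dag run, recorded against the backfill through ``BackfillDagRun`` in sort order.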
--- .../endpoints/backfill_endpoint.py | 54 ++----- airflow/models/backfill.py | 127 ++++++++++++++- airflow/utils/types.py | 1 + .../endpoints/test_backfill_endpoint.py | 30 ++-- tests/models/test_backfill.py | 152 ++++++++++++++++++ 5 files changed, 308 insertions(+), 56 deletions(-) create mode 100644 tests/models/test_backfill.py diff --git a/airflow/api_connexion/endpoints/backfill_endpoint.py b/airflow/api_connexion/endpoints/backfill_endpoint.py index f974be4d75d82..baafdeea4f992 100644 --- a/airflow/api_connexion/endpoints/backfill_endpoint.py +++ b/airflow/api_connexion/endpoints/backfill_endpoint.py @@ -19,9 +19,10 @@ import logging from functools import wraps -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import pendulum +from pendulum import DateTime from sqlalchemy import select from airflow.api_connexion import security @@ -31,8 +32,7 @@ backfill_collection_schema, backfill_schema, ) -from airflow.models.backfill import Backfill -from airflow.models.serialized_dag import SerializedDagModel +from airflow.models.backfill import AlreadyRunningBackfill, Backfill, _create_backfill from airflow.utils import timezone from airflow.utils.session import NEW_SESSION, provide_session from airflow.www.decorators import action_logging @@ -64,33 +64,6 @@ def wrapper(*, backfill_id, session, **kwargs): return wrapper -@provide_session -def _create_backfill( - *, - dag_id: str, - from_date: str, - to_date: str, - max_active_runs: int, - reverse: bool, - dag_run_conf: dict | None, - session: Session = NEW_SESSION, -) -> Backfill: - serdag = session.get(SerializedDagModel, dag_id) - if not serdag: - raise NotFound(f"Could not find dag {dag_id}") - - br = Backfill( - dag_id=dag_id, - from_date=pendulum.parse(from_date), - to_date=pendulum.parse(to_date), - max_active_runs=max_active_runs, - dag_run_conf=dag_run_conf, - ) - session.add(br) - session.commit() - return br - - @security.requires_access_dag("GET") @action_logging @provide_session @@ -170,12 +143,15 @@ def create_backfill( reverse: bool = False, dag_run_conf: dict | None = None, ) -> APIResponse: - backfill_obj = _create_backfill( - dag_id=dag_id, - from_date=from_date, - to_date=to_date, - max_active_runs=max_active_runs, - reverse=reverse, - dag_run_conf=dag_run_conf, - ) - return backfill_schema.dump(backfill_obj) + try: + backfill_obj = _create_backfill( + dag_id=dag_id, + from_date=cast(DateTime, pendulum.parse(from_date)), + to_date=cast(DateTime, pendulum.parse(to_date)), + max_active_runs=max_active_runs, + reverse=reverse, + dag_run_conf=dag_run_conf, + ) + return backfill_schema.dump(backfill_obj) + except AlreadyRunningBackfill: + raise Conflict(f"There is already a running backfill for dag {dag_id}") diff --git a/airflow/models/backfill.py b/airflow/models/backfill.py index 8ff2541353688..6d3a8ee4fa922 100644 --- a/airflow/models/backfill.py +++ b/airflow/models/backfill.py @@ -15,15 +15,40 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +""" +Internal classes for management of dag backfills. 
+ +:meta private: +""" + from __future__ import annotations -from sqlalchemy import Boolean, Column, Integer, UniqueConstraint +import logging +from typing import TYPE_CHECKING + +from sqlalchemy import Boolean, Column, ForeignKeyConstraint, Integer, UniqueConstraint, func, select +from sqlalchemy.orm import relationship from sqlalchemy_jsonfield import JSONField +from airflow.api_connexion.exceptions import NotFound +from airflow.exceptions import AirflowException from airflow.models.base import Base, StringID +from airflow.models.serialized_dag import SerializedDagModel from airflow.settings import json from airflow.utils import timezone +from airflow.utils.session import create_session from airflow.utils.sqlalchemy import UtcDateTime +from airflow.utils.state import DagRunState +from airflow.utils.types import DagRunTriggeredByType, DagRunType + +if TYPE_CHECKING: + from pendulum import DateTime + +log = logging.getLogger(__name__) + + +class AlreadyRunningBackfill(AirflowException): + """Raised when attempting to create backfill and one already active.""" class Backfill(Base): @@ -47,6 +72,11 @@ class Backfill(Base): completed_at = Column(UtcDateTime, nullable=True) updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False) + backfill_dag_run_associations = relationship("BackfillDagRun", back_populates="backfill") + + def __repr__(self): + return f"Backfill({self.dag_id=}, {self.from_date=}, {self.to_date=})" + class BackfillDagRun(Base): """Mapping table between backfill run and dag run.""" @@ -59,4 +89,97 @@ class BackfillDagRun(Base): ) # the run might already exist; we could store the reason we did not create sort_ordinal = Column(Integer, nullable=False) - __table_args__ = (UniqueConstraint("backfill_id", "dag_run_id", name="ix_bdr_backfill_id_dag_run_id"),) + backfill = relationship("Backfill", back_populates="backfill_dag_run_associations") + dag_run = relationship("DagRun") + + __table_args__ = ( + UniqueConstraint("backfill_id", "dag_run_id", name="ix_bdr_backfill_id_dag_run_id"), + ForeignKeyConstraint( + [backfill_id], + ["backfill.id"], + name="bdr_backfill_fkey", + ondelete="cascade", + ), + ForeignKeyConstraint( + [dag_run_id], + ["dag_run.id"], + name="bdr_dag_run_fkey", + ondelete="set null", + ), + ) + + +def _create_backfill( + *, + dag_id: str, + from_date: DateTime, + to_date: DateTime, + max_active_runs: int, + reverse: bool, + dag_run_conf: dict | None, +) -> Backfill | None: + with create_session() as session: + serdag = session.get(SerializedDagModel, dag_id) + if not serdag: + raise NotFound(f"Could not find dag {dag_id}") + + num_active = session.scalar( + select(func.count()).where(Backfill.dag_id == dag_id, Backfill.completed_at.is_(None)) + ) + if num_active > 0: + raise AlreadyRunningBackfill( + f"Another backfill is running for dag {dag_id}. " + f"There can be only one running backfill per dag." 
+ ) + + br = Backfill( + dag_id=dag_id, + from_date=from_date, + to_date=to_date, + max_active_runs=max_active_runs, + dag_run_conf=dag_run_conf, + ) + session.add(br) + session.commit() + + dag = serdag.dag + depends_on_past = any(x.depends_on_past for x in dag.tasks) + if depends_on_past: + if reverse is True: + raise ValueError( + "Backfill cannot be run in reverse when the dag has tasks where depends_on_past=True" + ) + + backfill_sort_ordinal = 0 + dagrun_info_list = dag.iter_dagrun_infos_between(from_date, to_date) + if reverse: + dagrun_info_list = reversed([x for x in dag.iter_dagrun_infos_between(from_date, to_date)]) + for info in dagrun_info_list: + backfill_sort_ordinal += 1 + log.info("creating backfill dag run %s dag_id=%s backfill_id=%s, info=", dag.dag_id, br.id, info) + dr = None + try: + dr = dag.create_dagrun( + triggered_by=DagRunTriggeredByType.BACKFILL, + execution_date=info.logical_date, + data_interval=info.data_interval, + start_date=timezone.utcnow(), + state=DagRunState.QUEUED, + external_trigger=False, + conf=br.dag_run_conf, + run_type=DagRunType.BACKFILL_JOB, + creating_job_id=None, + session=session, + ) + except Exception: + dag.log.exception("something failed") + session.rollback() + session.add( + BackfillDagRun( + backfill_id=br.id, + dag_run_id=dr.id if dr else None, # this means we failed to create the dag run + sort_ordinal=backfill_sort_ordinal, + ) + ) + session.commit() + return br diff --git a/airflow/utils/types.py b/airflow/utils/types.py index a19b2534b03fb..80ee1d644d4d2 100644 --- a/airflow/utils/types.py +++ b/airflow/utils/types.py @@ -119,3 +119,4 @@ class DagRunTriggeredByType(enum.Enum): TEST = "test" # for dag.test() TIMETABLE = "timetable" # for timetable based triggering DATASET = "dataset" # for dataset_triggered run type + BACKFILL = "backfill" diff --git a/tests/api_connexion/endpoints/test_backfill_endpoint.py b/tests/api_connexion/endpoints/test_backfill_endpoint.py index 07b2a3fd56c2d..dd086339b73ac 100644 --- a/tests/api_connexion/endpoints/test_backfill_endpoint.py +++ b/tests/api_connexion/endpoints/test_backfill_endpoint.py @@ -27,14 +27,13 @@ from airflow.models import DagBag, DagModel from airflow.models.backfill import Backfill from airflow.models.dag import DAG -from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.empty import EmptyOperator from airflow.utils import timezone from airflow.utils.session import provide_session from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.db import clear_db_backfills, clear_db_dags, clear_db_runs, clear_db_serialized_dags -pytestmark = [pytest.mark.db_test] +pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag] DAG_ID = "test_dag" @@ -44,6 +43,20 @@ UTC_JSON_REPR = "UTC" if pendulum.__version__.startswith("3") else "Timezone('UTC')" +def _clean_db(): + clear_db_backfills() + clear_db_runs() + clear_db_dags() + clear_db_serialized_dags() + + +@pytest.fixture(autouse=True) +def clean_db(): + _clean_db() + yield + _clean_db() + + @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): app = minimal_app_for_api @@ -83,25 +96,14 @@ def configured_app(minimal_app_for_api): class TestBackfillEndpoint: - @staticmethod - def clean_db(): - clear_db_backfills() - clear_db_runs() - clear_db_dags() - clear_db_serialized_dags() - @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.clean_db() self.app = configured_app self.client = self.app.test_client() # 
type:ignore self.dag_id = DAG_ID self.dag2_id = DAG2_ID self.dag3_id = DAG3_ID - def teardown_method(self) -> None: - self.clean_db() - @provide_session def _create_dag_models(self, *, count=1, dag_id_prefix="TEST_DAG", is_paused=False, session=None): dags = [] @@ -258,8 +260,6 @@ class TestCreateBackfill(TestBackfillEndpoint): def test_create_backfill(self, user, expected, session, dag_maker): with dag_maker(session=session, dag_id="TEST_DAG_1", schedule="0 * * * *") as dag: EmptyOperator(task_id="mytask") - session.add(SerializedDagModel(dag)) - session.commit() session.query(DagModel).all() from_date = pendulum.parse("2024-01-01") from_date_iso = from_date.isoformat() diff --git a/tests/models/test_backfill.py b/tests/models/test_backfill.py new file mode 100644 index 0000000000000..9a845f86803e0 --- /dev/null +++ b/tests/models/test_backfill.py @@ -0,0 +1,152 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from contextlib import nullcontext + +import pendulum +import pytest +from sqlalchemy import select + +from airflow.models import DagRun +from airflow.models.backfill import AlreadyRunningBackfill, Backfill, BackfillDagRun, _create_backfill +from airflow.operators.python import PythonOperator +from airflow.utils.state import DagRunState +from tests.test_utils.db import clear_db_backfills, clear_db_dags, clear_db_runs, clear_db_serialized_dags + +pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag] + + +def _clean_db(): + clear_db_backfills() + clear_db_runs() + clear_db_dags() + clear_db_serialized_dags() + + +@pytest.fixture(autouse=True) +def clean_db(): + _clean_db() + yield + _clean_db() + + +@pytest.mark.parametrize("dep_on_past", [True, False]) +def test_reverse_and_depends_on_past_fails(dep_on_past, dag_maker, session): + with dag_maker() as dag: + PythonOperator(task_id="hi", python_callable=print, depends_on_past=dep_on_past) + session.commit() + cm = nullcontext() + if dep_on_past: + cm = pytest.raises(ValueError, match="cannot be run in reverse") + b = None + with cm: + b = _create_backfill( + dag_id=dag.dag_id, + from_date=pendulum.parse("2021-01-01"), + to_date=pendulum.parse("2021-01-05"), + max_active_runs=2, + reverse=True, + dag_run_conf={}, + ) + if dep_on_past: + assert b is None + else: + assert b is not None + + +@pytest.mark.parametrize("reverse", [True, False]) +def test_simple(reverse, dag_maker, session): + """ + Verify simple case behavior. + + This test verifies that runs in the range are created according + to schedule intervals, and the sort ordinal is correct. Also verifies + that dag runs are created in the queued state. 
+ """ + with dag_maker(schedule="@daily") as dag: + PythonOperator(task_id="hi", python_callable=print) + b = _create_backfill( + dag_id=dag.dag_id, + from_date=pendulum.parse("2021-01-01"), + to_date=pendulum.parse("2021-01-05"), + max_active_runs=2, + reverse=reverse, + dag_run_conf={}, + ) + query = ( + select(DagRun) + .join(BackfillDagRun.dag_run) + .where(BackfillDagRun.backfill_id == b.id) + .order_by(BackfillDagRun.sort_ordinal) + ) + dag_runs = session.scalars(query).all() + dates = [str(x.logical_date.date()) for x in dag_runs] + expected_dates = ["2021-01-01", "2021-01-02", "2021-01-03", "2021-01-04", "2021-01-05"] + if reverse: + expected_dates = list(reversed(expected_dates)) + assert dates == expected_dates + assert all(x.state == DagRunState.QUEUED for x in dag_runs) + + +def test_params_stored_correctly(dag_maker, session): + with dag_maker(schedule="@daily") as dag: + PythonOperator(task_id="hi", python_callable=print) + b = _create_backfill( + dag_id=dag.dag_id, + from_date=pendulum.parse("2021-01-01"), + to_date=pendulum.parse("2021-01-05"), + max_active_runs=263, + reverse=False, + dag_run_conf={"this": "param"}, + ) + session.expunge_all() + b_stored = session.get(Backfill, b.id) + assert all( + ( + b_stored.dag_id == b.dag_id, + b_stored.from_date == b.from_date, + b_stored.to_date == b.to_date, + b_stored.max_active_runs == b.max_active_runs, + b_stored.dag_run_conf == b.dag_run_conf, + ) + ) + + +def test_active_dag_run(dag_maker, session): + with dag_maker(schedule="@daily") as dag: + PythonOperator(task_id="hi", python_callable=print) + session.commit() + b1 = _create_backfill( + dag_id=dag.dag_id, + from_date=pendulum.parse("2021-01-01"), + to_date=pendulum.parse("2021-01-05"), + max_active_runs=10, + reverse=False, + dag_run_conf={"this": "param"}, + ) + assert b1 is not None + with pytest.raises(AlreadyRunningBackfill, match="Another backfill is running for dag"): + _create_backfill( + dag_id=dag.dag_id, + from_date=pendulum.parse("2021-02-01"), + to_date=pendulum.parse("2021-02-05"), + max_active_runs=10, + reverse=False, + dag_run_conf={"this": "param"}, + ) From 5600388cdb18565ab901f5c21c3d693cabe3def9 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Tue, 1 Oct 2024 19:23:44 -0700 Subject: [PATCH 243/349] All executors should inherit from BaseExecutor (#41904) --- .../executors/celery_kubernetes_executor.py | 34 +++++++++++++++---- .../executors/local_kubernetes_executor.py | 34 +++++++++++++++---- 2 files changed, 56 insertions(+), 12 deletions(-) diff --git a/airflow/providers/celery/executors/celery_kubernetes_executor.py b/airflow/providers/celery/executors/celery_kubernetes_executor.py index bc2ed7904f5a5..acd1afcba995a 100644 --- a/airflow/providers/celery/executors/celery_kubernetes_executor.py +++ b/airflow/providers/celery/executors/celery_kubernetes_executor.py @@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Sequence from airflow.configuration import conf +from airflow.executors.base_executor import BaseExecutor from airflow.providers.celery.executors.celery_executor import CeleryExecutor try: @@ -30,18 +31,21 @@ raise AirflowOptionalProviderFeatureException(e) -from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.providers_configuration_loader import providers_configuration_loaded if TYPE_CHECKING: from airflow.callbacks.base_callback_sink import BaseCallbackSink from airflow.callbacks.callback_requests import CallbackRequest - from airflow.executors.base_executor import 
CommandType, EventBufferValueType, QueuedTaskInstanceType + from airflow.executors.base_executor import ( + CommandType, + EventBufferValueType, + QueuedTaskInstanceType, + ) from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey -class CeleryKubernetesExecutor(LoggingMixin): +class CeleryKubernetesExecutor(BaseExecutor): """ CeleryKubernetesExecutor consists of CeleryExecutor and KubernetesExecutor. @@ -71,11 +75,21 @@ def kubernetes_queue(self) -> str: def __init__(self, celery_executor: CeleryExecutor, kubernetes_executor: KubernetesExecutor): super().__init__() - self._job_id: int | None = None + self._job_id: int | str | None = None self.celery_executor = celery_executor self.kubernetes_executor = kubernetes_executor self.kubernetes_executor.kubernetes_queue = self.kubernetes_queue + @property + def _task_event_logs(self): + self.celery_executor._task_event_logs += self.kubernetes_executor._task_event_logs + self.kubernetes_executor._task_event_logs.clear() + return self.celery_executor._task_event_logs + + @_task_event_logs.setter + def _task_event_logs(self, value): + """Not implemented for hybrid executors.""" + @property def queued_tasks(self) -> dict[TaskInstanceKey, QueuedTaskInstanceType]: """Return queued tasks from celery and kubernetes executor.""" @@ -84,13 +98,21 @@ def queued_tasks(self) -> dict[TaskInstanceKey, QueuedTaskInstanceType]: return queued_tasks + @queued_tasks.setter + def queued_tasks(self, value) -> None: + """Not implemented for hybrid executors.""" + @property def running(self) -> set[TaskInstanceKey]: """Return running tasks from celery and kubernetes executor.""" return self.celery_executor.running.union(self.kubernetes_executor.running) + @running.setter + def running(self, value) -> None: + """Not implemented for hybrid executors.""" + @property - def job_id(self) -> int | None: + def job_id(self) -> int | str | None: """ Inherited attribute from BaseExecutor. 
@@ -100,7 +122,7 @@ def job_id(self) -> int | None: return self._job_id @job_id.setter - def job_id(self, value: int | None) -> None: + def job_id(self, value: int | str | None) -> None: """Expose job ID for SchedulerJob.""" self._job_id = value self.kubernetes_executor.job_id = value diff --git a/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py b/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py index 75de1101c59ba..63755d3d11a1c 100644 --- a/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py +++ b/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py @@ -20,18 +20,22 @@ from typing import TYPE_CHECKING, Sequence from airflow.configuration import conf +from airflow.executors.base_executor import BaseExecutor from airflow.providers.cncf.kubernetes.executors.kubernetes_executor import KubernetesExecutor -from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: from airflow.callbacks.base_callback_sink import BaseCallbackSink from airflow.callbacks.callback_requests import CallbackRequest - from airflow.executors.base_executor import CommandType, EventBufferValueType, QueuedTaskInstanceType + from airflow.executors.base_executor import ( + CommandType, + EventBufferValueType, + QueuedTaskInstanceType, + ) from airflow.executors.local_executor import LocalExecutor from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey -class LocalKubernetesExecutor(LoggingMixin): +class LocalKubernetesExecutor(BaseExecutor): """ Chooses between LocalExecutor and KubernetesExecutor based on the queue defined on the task. @@ -57,11 +61,21 @@ class LocalKubernetesExecutor(LoggingMixin): def __init__(self, local_executor: LocalExecutor, kubernetes_executor: KubernetesExecutor): super().__init__() - self._job_id: str | None = None + self._job_id: int | str | None = None self.local_executor = local_executor self.kubernetes_executor = kubernetes_executor self.kubernetes_executor.kubernetes_queue = self.KUBERNETES_QUEUE + @property + def _task_event_logs(self): + self.local_executor._task_event_logs += self.kubernetes_executor._task_event_logs + self.kubernetes_executor._task_event_logs.clear() + return self.local_executor._task_event_logs + + @_task_event_logs.setter + def _task_event_logs(self, value): + """Not implemented for hybrid executors.""" + @property def queued_tasks(self) -> dict[TaskInstanceKey, QueuedTaskInstanceType]: """Return queued tasks from local and kubernetes executor.""" @@ -70,13 +84,21 @@ def queued_tasks(self) -> dict[TaskInstanceKey, QueuedTaskInstanceType]: return queued_tasks + @queued_tasks.setter + def queued_tasks(self, value) -> None: + """Not implemented for hybrid executors.""" + @property def running(self) -> set[TaskInstanceKey]: """Return running tasks from local and kubernetes executor.""" return self.local_executor.running.union(self.kubernetes_executor.running) + @running.setter + def running(self, value) -> None: + """Not implemented for hybrid executors.""" + @property - def job_id(self) -> str | None: + def job_id(self) -> int | str | None: """ Inherited attribute from BaseExecutor. 
@@ -86,7 +108,7 @@ def job_id(self) -> str | None: return self._job_id @job_id.setter - def job_id(self, value: str | None) -> None: + def job_id(self, value: int | str | None) -> None: """Expose job ID for SchedulerJob.""" self._job_id = value self.kubernetes_executor.job_id = value From a7ff6a57e0e27dc9fcaf7124ca4cb71a7b63ac94 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Tue, 1 Oct 2024 21:35:14 -0700 Subject: [PATCH 244/349] Revert "Fix the order of tasks during serialization (#42219)" (#42646) This reverts commit adb9466bd7ce1c92e51f11a90d39fd557c99dc5b a.k.a. PR #42219. Was causing tests to fail. --- airflow/serialization/serialized_objects.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 08944391b8166..a4801b767acc5 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -1604,9 +1604,7 @@ def serialize_dag(cls, dag: DAG) -> dict: try: serialized_dag = cls.serialize_to_json(dag, cls._decorated_fields) serialized_dag["_processor_dags_folder"] = DAGS_FOLDER - serialized_dag["tasks"] = [ - cls.serialize(dag.task_dict[task_id]) for task_id in sorted(dag.task_dict) - ] + serialized_dag["tasks"] = [cls.serialize(task) for _, task in dag.task_dict.items()] dag_deps = [ dep From f5e37119d405e23d886598accaddb92329598ec7 Mon Sep 17 00:00:00 2001 From: Karen Braganza Date: Wed, 2 Oct 2024 02:20:43 -0400 Subject: [PATCH 245/349] Check pool_slots on partial task import instead of execution (#39724) Co-authored-by: Ryan Hatter <25823361+RNHTTR@users.noreply.github.com> --- airflow/decorators/base.py | 6 ++++++ airflow/models/baseoperator.py | 5 +++++ tests/models/test_mappedoperator.py | 9 +++++++++ 3 files changed, 20 insertions(+) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index e650c1920a870..bb9602d50c1cd 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -468,6 +468,12 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date", None)) if partial_kwargs.get("pool") is None: partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME + if "pool_slots" in partial_kwargs: + if partial_kwargs["pool_slots"] < 1: + dag_str = "" + if dag: + dag_str = f" in dag {dag.dag_id}" + raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") partial_kwargs["retries"] = parse_retries(partial_kwargs.get("retries", DEFAULT_RETRIES)) partial_kwargs["retry_delay"] = coerce_timedelta( partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY), diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 20656586ba01e..9e0c8e1e69b61 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -358,6 +358,11 @@ def partial( partial_kwargs["end_date"] = timezone.convert_to_utc(partial_kwargs["end_date"]) if partial_kwargs["pool"] is None: partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME + if partial_kwargs["pool_slots"] < 1: + dag_str = "" + if dag: + dag_str = f" in dag {dag.dag_id}" + raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") partial_kwargs["retries"] = parse_retries(partial_kwargs["retries"]) partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay") if partial_kwargs["max_retry_delay"] is not None: diff --git 
a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 2b0cd50165c45..0571e07e671f8 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -220,6 +220,15 @@ def test_partial_on_class_invalid_ctor_args() -> None: MockOperator.partial(task_id="a", foo="bar", bar=2) +def test_partial_on_invalid_pool_slots_raises() -> None: + """Test that when we pass an invalid value to pool_slots in partial(), + + i.e. if the value is not an integer, an error is raised at import time.""" + + with pytest.raises(TypeError, match="'<' not supported between instances of 'str' and 'int'"): + MockOperator.partial(task_id="pool_slots_test", pool="test", pool_slots="a").expand(arg1=[1, 2, 3]) + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode @pytest.mark.parametrize( ["num_existing_tis", "expected"], From 7ea01a3f02e84876f5bb41f2ff685aa549483d58 Mon Sep 17 00:00:00 2001 From: TakawaAkirayo <153728772+TakawaAkirayo@users.noreply.github.com> Date: Wed, 2 Oct 2024 14:51:08 +0800 Subject: [PATCH 246/349] Add retry logic in the scheduler for updating trigger timeouts in case of deadlocks. (#41429) * Add retry in update trigger timeout * add ut for these cases * use OperationalError in ut to describe deadlock scenarios * [MINOR] add newsfragment for this PR * [MINOR] refactor UT for mypy check --- airflow/jobs/scheduler_job_runner.py | 36 +++++++------ newsfragments/41429.improvement.rst | 1 + tests/jobs/test_scheduler_job.py | 78 +++++++++++++++++++++++++++- 3 files changed, 98 insertions(+), 17 deletions(-) create mode 100644 newsfragments/41429.improvement.rst diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index 242154820df9e..de6ce5019b9de 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -1884,23 +1884,27 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> int: return len(to_reset) @provide_session - def check_trigger_timeouts(self, session: Session = NEW_SESSION) -> None: + def check_trigger_timeouts( + self, max_retries: int = MAX_DB_RETRIES, session: Session = NEW_SESSION + ) -> None: """Mark any "deferred" task as failed if the trigger or execution timeout has passed.""" - num_timed_out_tasks = session.execute( - update(TI) - .where( - TI.state == TaskInstanceState.DEFERRED, - TI.trigger_timeout < timezone.utcnow(), - ) - .values( - state=TaskInstanceState.SCHEDULED, - next_method="__fail__", - next_kwargs={"error": "Trigger/execution timeout"}, - trigger_id=None, - ) - ).rowcount - if num_timed_out_tasks: - self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks) + for attempt in run_with_db_retries(max_retries, logger=self.log): + with attempt: + num_timed_out_tasks = session.execute( + update(TI) + .where( + TI.state == TaskInstanceState.DEFERRED, + TI.trigger_timeout < timezone.utcnow(), + ) + .values( + state=TaskInstanceState.SCHEDULED, + next_method="__fail__", + next_kwargs={"error": "Trigger/execution timeout"}, + trigger_id=None, + ) + ).rowcount + if num_timed_out_tasks: + self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks) # [START find_zombies] def _find_zombies(self) -> None: diff --git a/newsfragments/41429.improvement.rst b/newsfragments/41429.improvement.rst new file mode 100644 index 0000000000000..6d04d5dfe61af --- /dev/null +++ b/newsfragments/41429.improvement.rst @@ -0,0 +1 @@ +Add ``run_with_db_retries`` when 
the scheduler updates the deferred Task as failed to tolerate database deadlock issues. diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 32662d7d873db..40a7220698407 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -148,7 +148,7 @@ def clean_db(): @pytest.fixture(autouse=True) def per_test(self) -> Generator: self.clean_db() - self.job_runner = None + self.job_runner: SchedulerJobRunner | None = None yield @@ -5192,6 +5192,82 @@ def test_timeout_triggers(self, dag_maker): assert ti1.next_method == "__fail__" assert ti2.state == State.DEFERRED + def test_retry_on_db_error_when_update_timeout_triggers(self, dag_maker): + """ + Tests that it will retry on DB error like deadlock when updating timeout triggers. + """ + from sqlalchemy.exc import OperationalError + + retry_times = 3 + + session = settings.Session() + # Create the test DAG and task + with dag_maker( + dag_id="test_retry_on_db_error_when_update_timeout_triggers", + start_date=DEFAULT_DATE, + schedule="@once", + max_active_runs=1, + session=session, + ): + EmptyOperator(task_id="dummy1") + + # Mock the db failure within retry times + might_fail_session = MagicMock(wraps=session) + + def check_if_trigger_timeout(max_retries: int): + def make_side_effect(): + call_count = 0 + + def side_effect(*args, **kwargs): + nonlocal call_count + if call_count < retry_times - 1: + call_count += 1 + raise OperationalError("any_statement", "any_params", "any_orig") + else: + return session.execute(*args, **kwargs) + + return side_effect + + might_fail_session.execute.side_effect = make_side_effect() + + try: + # Create a Task Instance for the task that is allegedly deferred + # but past its timeout, and one that is still good. + # We don't actually need a linked trigger here; the code doesn't check. + dr1 = dag_maker.create_dagrun() + dr2 = dag_maker.create_dagrun( + run_id="test2", execution_date=DEFAULT_DATE + datetime.timedelta(seconds=1) + ) + ti1 = dr1.get_task_instance("dummy1", session) + ti2 = dr2.get_task_instance("dummy1", session) + ti1.state = State.DEFERRED + ti1.trigger_timeout = timezone.utcnow() - datetime.timedelta(seconds=60) + ti2.state = State.DEFERRED + ti2.trigger_timeout = timezone.utcnow() + datetime.timedelta(seconds=60) + session.flush() + + # Boot up the scheduler and make it check timeouts + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + + self.job_runner.check_trigger_timeouts(max_retries=max_retries, session=might_fail_session) + + # Make sure that TI1 is now scheduled to fail, and 2 wasn't touched + session.refresh(ti1) + session.refresh(ti2) + assert ti1.state == State.SCHEDULED + assert ti1.next_method == "__fail__" + assert ti2.state == State.DEFERRED + finally: + self.clean_db() + + # Positive case, will retry until success before reach max retry times + check_if_trigger_timeout(retry_times) + + # Negative case: no retries, execute only once. 
+ with pytest.raises(OperationalError): + check_if_trigger_timeout(1) + def test_find_zombies_nothing(self): executor = MockExecutor(do_update=False) scheduler_job = Job(executor=executor) From 4ee2b0b806f643125b086251acd08c299394c149 Mon Sep 17 00:00:00 2001 From: GPK Date: Wed, 2 Oct 2024 07:59:17 +0100 Subject: [PATCH 247/349] Fix consistent return response from PubSubPullSensor (#42080) * fix consistent return response pubsubsensor * removed messages_callback argument to pubsub trigger and using it in execute_complete * updated variable name * updates as per comments, added return types and refactored logic * update types, tests and use inherit exception --- .../providers/google/cloud/sensors/pubsub.py | 24 +++++++- .../providers/google/cloud/triggers/pubsub.py | 22 +++----- .../google/cloud/sensors/test_pubsub.py | 48 ++++++++++++++++ .../google/cloud/triggers/test_pubsub.py | 55 ++++++++++++++++++- 4 files changed, 129 insertions(+), 20 deletions(-) diff --git a/airflow/providers/google/cloud/sensors/pubsub.py b/airflow/providers/google/cloud/sensors/pubsub.py index cb224d42979b7..aa74411f072e5 100644 --- a/airflow/providers/google/cloud/sensors/pubsub.py +++ b/airflow/providers/google/cloud/sensors/pubsub.py @@ -22,6 +22,7 @@ from datetime import timedelta from typing import TYPE_CHECKING, Any, Callable, Sequence +from google.cloud import pubsub_v1 from google.cloud.pubsub_v1.types import ReceivedMessage from airflow.configuration import conf @@ -34,6 +35,10 @@ from airflow.utils.context import Context +class PubSubMessageTransformException(AirflowException): + """Raise when messages failed to convert pubsub received format.""" + + class PubSubPullSensor(BaseSensorOperator): """ Pulls messages from a PubSub subscription and passes them through XCom. @@ -170,7 +175,6 @@ def execute(self, context: Context) -> None: subscription=self.subscription, max_messages=self.max_messages, ack_messages=self.ack_messages, - messages_callback=self.messages_callback, poke_interval=self.poke_interval, gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -178,14 +182,28 @@ def execute(self, context: Context) -> None: method_name="execute_complete", ) - def execute_complete(self, context: dict[str, Any], event: dict[str, str | list[str]]) -> str | list[str]: - """Return immediately and relies on trigger to throw a success event. 
Callback for the trigger.""" + def execute_complete(self, context: Context, event: dict[str, str | list[str]]) -> Any: + """If messages_callback is provided, execute it; otherwise, return immediately with trigger event message.""" if event["status"] == "success": self.log.info("Sensor pulls messages: %s", event["message"]) + if self.messages_callback: + received_messages = self._convert_to_received_messages(event["message"]) + _return_value = self.messages_callback(received_messages, context) + return _return_value + return event["message"] self.log.info("Sensor failed: %s", event["message"]) raise AirflowException(event["message"]) + def _convert_to_received_messages(self, messages: Any) -> list[ReceivedMessage]: + try: + received_messages = [pubsub_v1.types.ReceivedMessage(msg) for msg in messages] + return received_messages + except Exception as e: + raise PubSubMessageTransformException( + f"Error converting triggerer event message back to received message format: {e}" + ) + def _default_message_callback( self, pulled_messages: list[ReceivedMessage], diff --git a/airflow/providers/google/cloud/triggers/pubsub.py b/airflow/providers/google/cloud/triggers/pubsub.py index 535bfe2ba1c68..db3fe409e942b 100644 --- a/airflow/providers/google/cloud/triggers/pubsub.py +++ b/airflow/providers/google/cloud/triggers/pubsub.py @@ -19,16 +19,13 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Sequence +from typing import Any, AsyncIterator, Sequence + +from google.cloud.pubsub_v1.types import ReceivedMessage from airflow.providers.google.cloud.hooks.pubsub import PubSubAsyncHook from airflow.triggers.base import BaseTrigger, TriggerEvent -if TYPE_CHECKING: - from google.cloud.pubsub_v1.types import ReceivedMessage - - from airflow.utils.context import Context - class PubsubPullTrigger(BaseTrigger): """ @@ -41,11 +38,6 @@ class PubsubPullTrigger(BaseTrigger): :param ack_messages: If True, each message will be acknowledged immediately rather than by any downstream tasks :param gcp_conn_id: Reference to google cloud connection id - :param messages_callback: (Optional) Callback to process received messages. - Its return value will be saved to XCom. - If you are pulling large messages, you probably want to provide a custom callback. - If not provided, the default implementation will convert `ReceivedMessage` objects - into JSON-serializable dicts using `google.protobuf.json_format.MessageToDict` function. 
:param poke_interval: polling period in seconds to check for the status :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token @@ -64,7 +56,6 @@ def __init__( max_messages: int, ack_messages: bool, gcp_conn_id: str, - messages_callback: Callable[[list[ReceivedMessage], Context], Any] | None = None, poke_interval: float = 10.0, impersonation_chain: str | Sequence[str] | None = None, ): @@ -73,7 +64,6 @@ def __init__( self.subscription = subscription self.max_messages = max_messages self.ack_messages = ack_messages - self.messages_callback = messages_callback self.poke_interval = poke_interval self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain @@ -88,7 +78,6 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "subscription": self.subscription, "max_messages": self.max_messages, "ack_messages": self.ack_messages, - "messages_callback": self.messages_callback, "poke_interval": self.poke_interval, "gcp_conn_id": self.gcp_conn_id, "impersonation_chain": self.impersonation_chain, @@ -106,7 +95,10 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] ): if self.ack_messages: await self.message_acknowledgement(pulled_messages) - yield TriggerEvent({"status": "success", "message": pulled_messages}) + + messages_json = [ReceivedMessage.to_dict(m) for m in pulled_messages] + + yield TriggerEvent({"status": "success", "message": messages_json}) return self.log.info("Sleeping for %s seconds.", self.poke_interval) await asyncio.sleep(self.poke_interval) diff --git a/tests/providers/google/cloud/sensors/test_pubsub.py b/tests/providers/google/cloud/sensors/test_pubsub.py index a77167dda3037..5a3fb170b7482 100644 --- a/tests/providers/google/cloud/sensors/test_pubsub.py +++ b/tests/providers/google/cloud/sensors/test_pubsub.py @@ -21,6 +21,7 @@ from unittest import mock import pytest +from google.cloud import pubsub_v1 from google.cloud.pubsub_v1.types import ReceivedMessage from airflow.exceptions import AirflowException, TaskDeferred @@ -197,3 +198,50 @@ def test_pubsub_pull_sensor_async_execute_complete(self): with mock.patch.object(operator.log, "info") as mock_log_info: operator.execute_complete(context={}, event={"status": "success", "message": test_message}) mock_log_info.assert_called_with("Sensor pulls messages: %s", test_message) + + @mock.patch("airflow.providers.google.cloud.sensors.pubsub.PubSubHook") + def test_pubsub_pull_sensor_async_execute_complete_use_message_callback(self, mock_hook): + test_message = [ + { + "ack_id": "UAYWLF1GSFE3GQhoUQ5PXiM_NSAoRRIJB08CKF15MU0sQVhwaFENGXJ9YHxrUxsDV0ECel1RGQdoTm11H4GglfRLQ1RrWBIHB01Vel5TEwxoX11wBnm4vPO6v8vgfwk9OpX-8tltO6ywsP9GZiM9XhJLLD5-LzlFQV5AEkwkDERJUytDCypYEU4EISE-MD5FU0Q", + "message": { + "data": "aGkgZnJvbSBjbG91ZCBjb25zb2xlIQ==", + "message_id": "12165864188103151", + "publish_time": "2024-08-28T11:49:50.962Z", + "attributes": {}, + "ordering_key": "", + }, + "delivery_attempt": 0, + } + ] + + received_messages = [pubsub_v1.types.ReceivedMessage(msg) for msg in test_message] + + messages_callback_return_value = "custom_message_from_callback" + + def messages_callback( + pulled_messages: list[ReceivedMessage], + context: dict[str, Any], + ): + assert pulled_messages == received_messages + + assert isinstance(context, dict) + for key in context.keys(): + assert isinstance(key, str) + + return messages_callback_return_value + + operator = PubSubPullSensor( + task_id="test_task", + 
ack_messages=True, + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + deferrable=True, + messages_callback=messages_callback, + ) + mock_hook.return_value.pull.return_value = received_messages + + with mock.patch.object(operator.log, "info") as mock_log_info: + resp = operator.execute_complete(context={}, event={"status": "success", "message": test_message}) + mock_log_info.assert_called_with("Sensor pulls messages: %s", test_message) + assert resp == messages_callback_return_value diff --git a/tests/providers/google/cloud/triggers/test_pubsub.py b/tests/providers/google/cloud/triggers/test_pubsub.py index d2294eb61414b..e1a4e178d2918 100644 --- a/tests/providers/google/cloud/triggers/test_pubsub.py +++ b/tests/providers/google/cloud/triggers/test_pubsub.py @@ -16,9 +16,13 @@ # under the License. from __future__ import annotations +from unittest import mock + import pytest +from google.cloud.pubsub_v1.types import ReceivedMessage from airflow.providers.google.cloud.triggers.pubsub import PubsubPullTrigger +from airflow.triggers.base import TriggerEvent TEST_POLL_INTERVAL = 10 TEST_GCP_CONN_ID = "google_cloud_default" @@ -34,13 +38,25 @@ def trigger(): subscription="subscription", max_messages=MAX_MESSAGES, ack_messages=ACK_MESSAGES, - messages_callback=None, poke_interval=TEST_POLL_INTERVAL, gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=None, ) +async def generate_messages(count: int) -> list[ReceivedMessage]: + return [ + ReceivedMessage( + ack_id=f"{i}", + message={ + "data": f"Message {i}".encode(), + "attributes": {"type": "generated message"}, + }, + ) + for i in range(1, count + 1) + ] + + class TestPubsubPullTrigger: def test_async_pubsub_pull_trigger_serialization_should_execute_successfully(self, trigger): """ @@ -54,8 +70,43 @@ def test_async_pubsub_pull_trigger_serialization_should_execute_successfully(sel "subscription": "subscription", "max_messages": MAX_MESSAGES, "ack_messages": ACK_MESSAGES, - "messages_callback": None, "poke_interval": TEST_POLL_INTERVAL, "gcp_conn_id": TEST_GCP_CONN_ID, "impersonation_chain": None, } + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.pubsub.PubSubAsyncHook.pull") + async def test_async_pubsub_pull_trigger_return_event(self, mock_pull): + mock_pull.return_value = generate_messages(1) + trigger = PubsubPullTrigger( + project_id=PROJECT_ID, + subscription="subscription", + max_messages=MAX_MESSAGES, + ack_messages=False, + poke_interval=TEST_POLL_INTERVAL, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=None, + ) + + expected_event = TriggerEvent( + { + "status": "success", + "message": [ + { + "ack_id": "1", + "message": { + "data": "TWVzc2FnZSAx", + "attributes": {"type": "generated message"}, + "message_id": "", + "ordering_key": "", + }, + "delivery_attempt": 0, + } + ], + } + ) + + response = await trigger.run().asend(None) + + assert response == expected_event From 51cb1ff7170bc44343a0da6b6689ab4583e4a64b Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 2 Oct 2024 01:35:43 -0700 Subject: [PATCH 248/349] Fix type-ignore comment for typing changes (#42656) --- tests/jobs/test_scheduler_job.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 40a7220698407..97d84da9c4d58 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -5552,7 +5552,7 @@ def spy(*args, **kwargs): def watch_set_state(dr: DagRun, state, **kwargs): if state in (DagRunState.SUCCESS, 
DagRunState.FAILED): # Stop the scheduler - self.job_runner.num_runs = 1 # type: ignore[attr-defined] + self.job_runner.num_runs = 1 # type: ignore[union-attr] orig_set_state(dr, state, **kwargs) # type: ignore[call-arg] def watch_heartbeat(*args, **kwargs): From e0cf55efcc91c0836bdada843f0f24550fc9bf6c Mon Sep 17 00:00:00 2001 From: Bugra Ozturk Date: Wed, 2 Oct 2024 10:47:56 +0200 Subject: [PATCH 249/349] AIP-84 Migrate delete a connection to FastAPI API (#42571) * Include connections router and migrate delete a connection endpoint to fastapi * Mark tests as db_test * Use only pyfixture session * make method async * setup method to setup_attrs * Convert APIRouter tags, make setup method unified * Use AirflowRouter over fastapi.APIRouter --- .../endpoints/connection_endpoint.py | 2 + airflow/api_fastapi/openapi/v1-generated.yaml | 41 ++++++++++++ airflow/api_fastapi/views/public/__init__.py | 2 + .../api_fastapi/views/public/connections.py | 47 ++++++++++++++ airflow/ui/openapi-gen/queries/common.ts | 9 ++- airflow/ui/openapi-gen/queries/queries.ts | 45 ++++++++++++- .../ui/openapi-gen/requests/services.gen.ts | 30 +++++++++ airflow/ui/openapi-gen/requests/types.gen.ts | 33 ++++++++++ .../views/public/test_connections.py | 63 +++++++++++++++++++ 9 files changed, 270 insertions(+), 2 deletions(-) create mode 100644 airflow/api_fastapi/views/public/connections.py create mode 100644 tests/api_fastapi/views/public/test_connections.py diff --git a/airflow/api_connexion/endpoints/connection_endpoint.py b/airflow/api_connexion/endpoints/connection_endpoint.py index c17a9280d78f8..b28c9dfcafa79 100644 --- a/airflow/api_connexion/endpoints/connection_endpoint.py +++ b/airflow/api_connexion/endpoints/connection_endpoint.py @@ -40,6 +40,7 @@ from airflow.secrets.environment_variables import CONN_ENV_PREFIX from airflow.security import permissions from airflow.utils import helpers +from airflow.utils.api_migration import mark_fastapi_migration_done from airflow.utils.log.action_logger import action_event_from_permission from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.strings import get_random_string @@ -53,6 +54,7 @@ RESOURCE_EVENT_PREFIX = "connection" +@mark_fastapi_migration_done @security.requires_access_connection("DELETE") @provide_session @action_logging( diff --git a/airflow/api_fastapi/openapi/v1-generated.yaml b/airflow/api_fastapi/openapi/v1-generated.yaml index b08ef42c16df1..a54e0e4ca57dd 100644 --- a/airflow/api_fastapi/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/openapi/v1-generated.yaml @@ -319,6 +319,47 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + /public/connections/{connection_id}: + delete: + tags: + - Connection + summary: Delete Connection + description: Delete a connection entry. 
+ operationId: delete_connection + parameters: + - name: connection_id + in: path + required: true + schema: + type: string + title: Connection Id + responses: + '204': + description: Successful Response + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' components: schemas: DAGCollectionResponse: diff --git a/airflow/api_fastapi/views/public/__init__.py b/airflow/api_fastapi/views/public/__init__.py index 1c2511fc82ac2..9c0eefebb875e 100644 --- a/airflow/api_fastapi/views/public/__init__.py +++ b/airflow/api_fastapi/views/public/__init__.py @@ -17,6 +17,7 @@ from __future__ import annotations +from airflow.api_fastapi.views.public.connections import connections_router from airflow.api_fastapi.views.public.dags import dags_router from airflow.api_fastapi.views.router import AirflowRouter @@ -24,3 +25,4 @@ public_router.include_router(dags_router) +public_router.include_router(connections_router) diff --git a/airflow/api_fastapi/views/public/connections.py b/airflow/api_fastapi/views/public/connections.py new file mode 100644 index 0000000000000..d418e10026796 --- /dev/null +++ b/airflow/api_fastapi/views/public/connections.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from fastapi import Depends, HTTPException +from sqlalchemy import select +from sqlalchemy.orm import Session +from typing_extensions import Annotated + +from airflow.api_fastapi.db.common import get_session +from airflow.api_fastapi.openapi.exceptions import create_openapi_http_exception_doc +from airflow.api_fastapi.views.router import AirflowRouter +from airflow.models import Connection + +connections_router = AirflowRouter(tags=["Connection"]) + + +@connections_router.delete( + "/connections/{connection_id}", + status_code=204, + responses=create_openapi_http_exception_doc([401, 403, 404]), +) +async def delete_connection( + connection_id: str, + session: Annotated[Session, Depends(get_session)], +): + """Delete a connection entry.""" + connection = session.scalar(select(Connection).filter_by(conn_id=connection_id)) + + if connection is None: + raise HTTPException(404, f"The Connection with connection_id: `{connection_id}` was not found") + + session.delete(connection) diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index 96e49cc6d7673..fcddded7dc121 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -1,7 +1,11 @@ // generated with @7nohe/openapi-react-query-codegen@1.6.0 import { UseQueryResult } from "@tanstack/react-query"; -import { AssetService, DagService } from "../requests/services.gen"; +import { + AssetService, + ConnectionService, + DagService, +} from "../requests/services.gen"; import { DagRunState } from "../requests/types.gen"; export type AssetServiceNextRunAssetsDefaultResponse = Awaited< @@ -76,3 +80,6 @@ export type DagServicePatchDagsMutationResult = Awaited< export type DagServicePatchDagMutationResult = Awaited< ReturnType >; +export type ConnectionServiceDeleteConnectionMutationResult = Awaited< + ReturnType +>; diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index 985bf952e3eb3..f83c151b91e23 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -6,7 +6,11 @@ import { UseQueryOptions, } from "@tanstack/react-query"; -import { AssetService, DagService } from "../requests/services.gen"; +import { + AssetService, + ConnectionService, + DagService, +} from "../requests/services.gen"; import { DAGPatchBody, DagRunState } from "../requests/types.gen"; import * as Common from "./common"; @@ -247,3 +251,42 @@ export const useDagServicePatchDag = < }) as unknown as Promise, ...options, }); +/** + * Delete Connection + * Delete a connection entry. + * @param data The data for the request. 
+ * @param data.connectionId + * @returns void Successful Response + * @throws ApiError + */ +export const useConnectionServiceDeleteConnection = < + TData = Common.ConnectionServiceDeleteConnectionMutationResult, + TError = unknown, + TContext = unknown, +>( + options?: Omit< + UseMutationOptions< + TData, + TError, + { + connectionId: string; + }, + TContext + >, + "mutationFn" + >, +) => + useMutation< + TData, + TError, + { + connectionId: string; + }, + TContext + >({ + mutationFn: ({ connectionId }) => + ConnectionService.deleteConnection({ + connectionId, + }) as unknown as Promise, + ...options, + }); diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index be216bd534c61..24c960d2b7d5f 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -11,6 +11,8 @@ import type { PatchDagsResponse, PatchDagData, PatchDagResponse, + DeleteConnectionData, + DeleteConnectionResponse, } from "./types.gen"; export class AssetService { @@ -159,3 +161,31 @@ export class DagService { }); } } + +export class ConnectionService { + /** + * Delete Connection + * Delete a connection entry. + * @param data The data for the request. + * @param data.connectionId + * @returns void Successful Response + * @throws ApiError + */ + public static deleteConnection( + data: DeleteConnectionData, + ): CancelablePromise { + return __request(OpenAPI, { + method: "DELETE", + url: "/public/connections/{connection_id}", + path: { + connection_id: data.connectionId, + }, + errors: { + 401: "Unauthorized", + 403: "Forbidden", + 404: "Not Found", + 422: "Validation Error", + }, + }); + } +} diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index e1db8310a1dc1..b38d5c00a69f3 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -134,6 +134,12 @@ export type PatchDagData = { export type PatchDagResponse = DAGResponse; +export type DeleteConnectionData = { + connectionId: string; +}; + +export type DeleteConnectionResponse = void; + export type $OpenApiTs = { "/ui/next_run_datasets/{dag_id}": { get: { @@ -227,4 +233,31 @@ export type $OpenApiTs = { }; }; }; + "/public/connections/{connection_id}": { + delete: { + req: DeleteConnectionData; + res: { + /** + * Successful Response + */ + 204: void; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; }; diff --git a/tests/api_fastapi/views/public/test_connections.py b/tests/api_fastapi/views/public/test_connections.py new file mode 100644 index 0000000000000..cfdca1d67984d --- /dev/null +++ b/tests/api_fastapi/views/public/test_connections.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.models import Connection +from airflow.utils.session import provide_session +from tests.test_utils.db import clear_db_connections + +pytestmark = pytest.mark.db_test + +TEST_CONN_ID = "test_connection_id" +TEST_CONN_TYPE = "test_type" + + +@provide_session +def _create_connection(session) -> None: + connection_model = Connection(conn_id=TEST_CONN_ID, conn_type=TEST_CONN_TYPE) + session.add(connection_model) + + +class TestConnectionEndpoint: + @pytest.fixture(autouse=True) + def setup(self) -> None: + clear_db_connections(False) + + def teardown_method(self) -> None: + clear_db_connections() + + def create_connection(self): + _create_connection() + + +class TestDeleteConnection(TestConnectionEndpoint): + def test_delete_should_respond_204(self, test_client, session): + self.create_connection() + conns = session.query(Connection).all() + assert len(conns) == 1 + response = test_client.delete(f"/public/connections/{TEST_CONN_ID}") + assert response.status_code == 204 + connection = session.query(Connection).all() + assert len(connection) == 0 + + def test_delete_should_respond_404(self, test_client): + response = test_client.delete(f"/public/connections/{TEST_CONN_ID}") + assert response.status_code == 404 + body = response.json() + assert f"The Connection with connection_id: `{TEST_CONN_ID}` was not found" == body["detail"] From 908b8e74e2a06dff9f1e10d3ad0019160b644b14 Mon Sep 17 00:00:00 2001 From: Brent Bovenzi Date: Wed, 2 Oct 2024 11:08:53 +0200 Subject: [PATCH 250/349] Add is_paused toggle (#42621) * Add pause/unpause DAG toggle * wire up onSuccess handler * Refactor query names --- airflow/ui/package.json | 1 + airflow/ui/pnpm-lock.yaml | 9 ++ airflow/ui/rules/react.js | 3 +- .../ui/src/components/DataTable/DataTable.tsx | 2 +- airflow/ui/src/components/TogglePause.tsx | 56 ++++++++++++ airflow/ui/src/pages/DagsList/DagsFilters.tsx | 86 +++++++++++++++++++ .../ui/src/pages/{ => DagsList}/DagsList.tsx | 63 +++++--------- airflow/ui/src/pages/DagsList/index.tsx | 20 +++++ 8 files changed, 196 insertions(+), 44 deletions(-) create mode 100644 airflow/ui/src/components/TogglePause.tsx create mode 100644 airflow/ui/src/pages/DagsList/DagsFilters.tsx rename airflow/ui/src/pages/{ => DagsList}/DagsList.tsx (72%) create mode 100644 airflow/ui/src/pages/DagsList/index.tsx diff --git a/airflow/ui/package.json b/airflow/ui/package.json index 1f77334074f03..82c6370f9dcba 100644 --- a/airflow/ui/package.json +++ b/airflow/ui/package.json @@ -32,6 +32,7 @@ }, "devDependencies": { "@7nohe/openapi-react-query-codegen": "^1.6.0", + "@eslint/compat": "^1.1.1", "@eslint/js": "^9.10.0", "@stylistic/eslint-plugin": "^2.8.0", "@tanstack/eslint-plugin-query": "^5.52.0", diff --git a/airflow/ui/pnpm-lock.yaml b/airflow/ui/pnpm-lock.yaml index 0f9f256941f5e..515e7fea5279d 100644 --- a/airflow/ui/pnpm-lock.yaml +++ b/airflow/ui/pnpm-lock.yaml @@ -51,6 +51,9 @@ importers: '@7nohe/openapi-react-query-codegen': specifier: ^1.6.0 version: 1.6.0(commander@12.1.0)(glob@11.0.0)(magicast@0.3.5)(ts-morph@23.0.0)(typescript@5.5.4) + 
'@eslint/compat': + specifier: ^1.1.1 + version: 1.1.1 '@eslint/js': specifier: ^9.10.0 version: 9.10.0 @@ -920,6 +923,10 @@ packages: resolution: {integrity: sha512-G/M/tIiMrTAxEWRfLfQJMmGNX28IxBg4PBz8XqQhqUHLFI6TL2htpIB1iQCj144V5ee/JaKyT9/WZ0MGZWfA7A==} engines: {node: ^12.0.0 || ^14.0.0 || >=16.0.0} + '@eslint/compat@1.1.1': + resolution: {integrity: sha512-lpHyRyplhGPL5mGEh6M9O5nnKk0Gz4bFI+Zu6tKlPpDUN7XshWvH9C/px4UVm87IAANE0W81CEsNGbS1KlzXpA==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + '@eslint/config-array@0.18.0': resolution: {integrity: sha512-fTxvnS1sRMu3+JjXwJG0j/i4RT9u4qJ+lqS/yCGap4lH4zZGzQ7tu+xZqQmcMZq5OBZDL4QRxQzRjkWcGt8IVw==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} @@ -4368,6 +4375,8 @@ snapshots: '@eslint-community/regexpp@4.11.0': {} + '@eslint/compat@1.1.1': {} + '@eslint/config-array@0.18.0': dependencies: '@eslint/object-schema': 2.1.4 diff --git a/airflow/ui/rules/react.js b/airflow/ui/rules/react.js index 8b4b610078c53..4c8d8b8ba5f09 100644 --- a/airflow/ui/rules/react.js +++ b/airflow/ui/rules/react.js @@ -20,6 +20,7 @@ /** * @import { FlatConfig } from "@typescript-eslint/utils/ts-eslint"; */ +import { fixupPluginRules } from "@eslint/compat"; import jsxA11y from "eslint-plugin-jsx-a11y"; import react from "eslint-plugin-react"; import reactHooks from "eslint-plugin-react-hooks"; @@ -57,7 +58,7 @@ export const reactRefreshNamespace = "react-refresh"; export const reactRules = /** @type {const} @satisfies {FlatConfig.Config} */ ({ plugins: { [jsxA11yNamespace]: jsxA11y, - [reactHooksNamespace]: reactHooks, + [reactHooksNamespace]: fixupPluginRules(reactHooks), [reactNamespace]: react, [reactRefreshNamespace]: reactRefresh, }, diff --git a/airflow/ui/src/components/DataTable/DataTable.tsx b/airflow/ui/src/components/DataTable/DataTable.tsx index a4bf1255a4ba6..705d7883f07d2 100644 --- a/airflow/ui/src/components/DataTable/DataTable.tsx +++ b/airflow/ui/src/components/DataTable/DataTable.tsx @@ -115,7 +115,7 @@ export const DataTable = ({ return ( -
+ + + + + + + + + + + {% for host in hosts %} + + + + + + + + + + + {% endfor %} +
HostnameStateQueuesFirst OnlineLast Heart BeatActive JobsSystem Information
{{ host.worker_name }} + {%- if host.state == "offline" -%} + {{ host.state }} + {%- elif host.last_update.timestamp() <= five_min_ago.timestamp() -%} + Reported {{ host.state }} + but no heartbeat + {%- elif host.state == "starting" -%} + {{ host.state }} + {%- elif host.state == "running" -%} + {{ host.state }} + {%- elif host.state == "idle" -%} + {{ host.state }} + {%- elif host.state == "terminating" -%} + {{ host.state }} + {%- elif host.state == "unknown" -%} + {{ host.state }} + {%- else -%} + {{ host.state }} + {%- endif -%} + {% if host.queues %}{{ host.queues }}{% else %}(all){% endif %}{% if host.last_update %}{% endif %}{{ host.jobs_active }} +
    + {% for item in host.sysinfo_json %} +
  • {{ item }}: {{ host.sysinfo_json[item] }}
  • + {% endfor %} +
+
+ {% endif %} + {% endblock %} + diff --git a/airflow/providers/edge/plugins/templates/edge_worker_jobs.html b/airflow/providers/edge/plugins/templates/edge_worker_jobs.html new file mode 100644 index 0000000000000..a73e0f1d485f4 --- /dev/null +++ b/airflow/providers/edge/plugins/templates/edge_worker_jobs.html @@ -0,0 +1,63 @@ +{# + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + #} + + {% extends base_template %} + + {% block title %} + Edge Worker Jobs + {% endblock %} + + {% block content %} +

Edge Worker Jobs

+ {% if jobs|length == 0 %} +

No jobs running currently

+ {% else %} + + + + + + + + + + + + + + + + {% for job in jobs %} + + + + + + + + + + + + + {% endfor %} +
DAG IDTask IDRun IDMap IndexTry NumberStateQueueQueued DTTMEdge WorkerLast Update
{{ job.dag_id }}{{ job.task_id }}{{ job.run_id }}{% if job.map_index >= 0 %}{{ job.map_index }}{% else %}-{% endif %}{{ job.try_number }}{{ html_states[job.state] }}{{ job.queue }}{% if job.edge_worker %}{{ job.edge_worker }}{% endif %}{% if job.last_update %}{% endif %}
+ {% endif %} + {% endblock %} + diff --git a/airflow/providers/edge/provider.yaml b/airflow/providers/edge/provider.yaml index cb775ee7cc7e4..6525b7bb846ff 100644 --- a/airflow/providers/edge/provider.yaml +++ b/airflow/providers/edge/provider.yaml @@ -32,6 +32,10 @@ dependencies: - apache-airflow>=2.10.0 - pydantic>=2.3.0 +plugins: + - name: edge_executor + plugin-class: airflow.providers.edge.plugins.edge_executor_plugin.EdgeExecutorPlugin + config: edge: description: | diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 10631afb9b292..59da56f180744 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -528,7 +528,12 @@ "pydantic>=2.3.0" ], "devel-deps": [], - "plugins": [], + "plugins": [ + { + "name": "edge_executor", + "plugin-class": "airflow.providers.edge.plugins.edge_executor_plugin.EdgeExecutorPlugin" + } + ], "cross-providers-deps": [], "excluded-python-versions": [], "state": "not-ready" diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py index cb59afd36742a..7e4bedbfb8c1b 100644 --- a/tests/plugins/test_plugins_manager.py +++ b/tests/plugins/test_plugins_manager.py @@ -417,7 +417,7 @@ def test_does_not_double_import_entrypoint_provider_plugins(self): assert len(plugins_manager.plugins) == 0 plugins_manager.load_entrypoint_plugins() plugins_manager.load_providers_plugins() - assert len(plugins_manager.plugins) == 3 + assert len(plugins_manager.plugins) == 4 class TestPluginsDirectorySource: diff --git a/tests/providers/edge/api_endpoints/__init__.py b/tests/providers/edge/api_endpoints/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/tests/providers/edge/api_endpoints/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/edge/api_endpoints/test_health_endpoint.py b/tests/providers/edge/api_endpoints/test_health_endpoint.py new file mode 100644 index 0000000000000..1bfc9e5c0c5bf --- /dev/null +++ b/tests/providers/edge/api_endpoints/test_health_endpoint.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.edge.api_endpoints.health_endpoint import health + + +def test_health(): + assert health() == {} diff --git a/tests/providers/edge/api_endpoints/test_rpc_api_endpoint.py b/tests/providers/edge/api_endpoints/test_rpc_api_endpoint.py new file mode 100644 index 0000000000000..becf2f9397e31 --- /dev/null +++ b/tests/providers/edge/api_endpoints/test_rpc_api_endpoint.py @@ -0,0 +1,281 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Generator +from unittest import mock + +import pytest + +from airflow.api_connexion.exceptions import PermissionDenied +from airflow.configuration import conf +from airflow.models.baseoperator import BaseOperator +from airflow.models.connection import Connection +from airflow.models.dagrun import DagRun +from airflow.models.taskinstance import TaskInstance +from airflow.models.xcom import XCom +from airflow.operators.empty import EmptyOperator +from airflow.providers.edge.api_endpoints.rpc_api_endpoint import _initialize_method_map +from airflow.providers.edge.models.edge_job import EdgeJob +from airflow.providers.edge.models.edge_logs import EdgeLogs +from airflow.providers.edge.models.edge_worker import EdgeWorker +from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic +from airflow.serialization.serialized_objects import BaseSerialization +from airflow.settings import _ENABLE_AIP_44 +from airflow.utils.jwt_signer import JWTSigner +from airflow.utils.state import State +from airflow.www import app +from tests.test_utils.decorators import dont_initialize_flask_app_submodules +from tests.test_utils.mock_plugins import mock_plugin_manager + +# Note: Sounds a bit strange to disable internal API tests in isolation mode but... +# As long as the test is modelled to run its own internal API endpoints, it is conflicting +# to the test setup with a dedicated internal API server. 
+pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] + + +def test_initialize_method_map(): + method_map = _initialize_method_map() + assert len(method_map) > 70 + for method in [ + # Test some basics + XCom.get_value, + XCom.get_one, + XCom.clear, + XCom.set, + DagRun.get_previous_dagrun, + DagRun.get_previous_scheduled_dagrun, + DagRun.get_task_instances, + DagRun.fetch_task_instance, + # Test some for Edge + EdgeJob.reserve_task, + EdgeJob.set_state, + EdgeLogs.push_logs, + EdgeWorker.register_worker, + EdgeWorker.set_state, + ]: + method_key = f"{method.__module__}.{method.__qualname__}" + assert method_key in method_map.keys() + + +if TYPE_CHECKING: + from flask import Flask + +TEST_METHOD_NAME = "test_method" +TEST_METHOD_WITH_LOG_NAME = "test_method_with_log" +TEST_API_ENDPOINT = "/edge_worker/v1/rpcapi" + +mock_test_method = mock.MagicMock() + +pytest.importorskip("pydantic", minversion="2.0.0") + + +def equals(a, b) -> bool: + return a == b + + +@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled") +class TestRpcApiEndpoint: + @pytest.fixture(scope="session") + def minimal_app_for_edge_api(self) -> Flask: + @dont_initialize_flask_app_submodules( + skip_all_except=[ + "init_api_auth", # This is needed for Airflow 2.10 compat tests + "init_appbuilder", + "init_plugins", + ] + ) + def factory() -> Flask: + import airflow.providers.edge.plugins.edge_executor_plugin as plugin_module + + class TestingEdgeExecutorPlugin(plugin_module.EdgeExecutorPlugin): + flask_blueprints = [plugin_module._get_api_endpoints(), plugin_module.template_bp] + + testing_edge_plugin = TestingEdgeExecutorPlugin() + assert len(testing_edge_plugin.flask_blueprints) > 0 + with mock_plugin_manager(plugins=[testing_edge_plugin]): + return app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore + + return factory() + + @pytest.fixture + def setup_attrs(self, minimal_app_for_edge_api: Flask) -> Generator: + self.app = minimal_app_for_edge_api + self.client = self.app.test_client() # type:ignore + mock_test_method.reset_mock() + mock_test_method.side_effect = None + with mock.patch( + "airflow.providers.edge.api_endpoints.rpc_api_endpoint._initialize_method_map" + ) as mock_initialize_method_map: + mock_initialize_method_map.return_value = { + TEST_METHOD_NAME: mock_test_method, + } + yield mock_initialize_method_map + + @pytest.fixture + def signer(self) -> JWTSigner: + return JWTSigner( + secret_key=conf.get("core", "internal_api_secret_key"), + expiration_time_in_seconds=conf.getint("core", "internal_api_clock_grace", fallback=30), + audience="api", + ) + + @pytest.mark.parametrize( + "input_params, method_result, result_cmp_func, method_params", + [ + ({}, None, lambda got, _: got == b"", {}), + ({}, "test_me", equals, {}), + ( + BaseSerialization.serialize({"dag_id": 15, "task_id": "fake-task"}), + ("dag_id_15", "fake-task", 1), + equals, + {"dag_id": 15, "task_id": "fake-task"}, + ), + ( + {}, + TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", state=State.RUNNING), + lambda a, b: a.model_dump() == TaskInstancePydantic.model_validate(b).model_dump() + and isinstance(a.task, BaseOperator), + {}, + ), + ( + {}, + Connection(conn_id="test_conn", conn_type="http", host="", password=""), + lambda a, b: a.get_uri() == b.get_uri() and a.conn_id == b.conn_id, + {}, + ), + ], + ) + def test_method( + self, input_params, method_result, result_cmp_func, method_params, setup_attrs, signer: JWTSigner + ): + mock_test_method.return_value = 
method_result + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + "Authorization": signer.generate_signed_token({"method": TEST_METHOD_NAME}), + } + input_data = { + "jsonrpc": "2.0", + "method": TEST_METHOD_NAME, + "params": input_params, + } + response = self.client.post( + TEST_API_ENDPOINT, + headers=headers, + data=json.dumps(input_data), + ) + assert response.status_code == 200 + if method_result: + response_data = BaseSerialization.deserialize(json.loads(response.data), use_pydantic_models=True) + else: + response_data = response.data + + assert result_cmp_func(response_data, method_result) + + mock_test_method.assert_called_once_with(**method_params, session=mock.ANY) + + def test_method_with_exception(self, setup_attrs, signer: JWTSigner): + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + "Authorization": signer.generate_signed_token({"method": TEST_METHOD_NAME}), + } + mock_test_method.side_effect = ValueError("Error!!!") + data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": {}} + + response = self.client.post(TEST_API_ENDPOINT, headers=headers, data=json.dumps(data)) + assert response.status_code == 500 + assert response.data, b"Error executing method: test_method." + mock_test_method.assert_called_once() + + def test_unknown_method(self, setup_attrs, signer: JWTSigner): + UNKNOWN_METHOD = "i-bet-it-does-not-exist" + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + "Authorization": signer.generate_signed_token({"method": UNKNOWN_METHOD}), + } + data = {"jsonrpc": "2.0", "method": UNKNOWN_METHOD, "params": {}} + + response = self.client.post(TEST_API_ENDPOINT, headers=headers, data=json.dumps(data)) + assert response.status_code == 400 + assert response.data.startswith(b"Unrecognized method: i-bet-it-does-not-exist.") + mock_test_method.assert_not_called() + + def test_invalid_jsonrpc(self, setup_attrs, signer: JWTSigner): + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + "Authorization": signer.generate_signed_token({"method": TEST_METHOD_NAME}), + } + data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": {}} + + response = self.client.post(TEST_API_ENDPOINT, headers=headers, data=json.dumps(data)) + assert response.status_code == 400 + assert response.data.startswith(b"Expected jsonrpc 2.0 request.") + mock_test_method.assert_not_called() + + def test_missing_token(self, setup_attrs): + mock_test_method.return_value = None + + input_data = { + "jsonrpc": "2.0", + "method": TEST_METHOD_NAME, + "params": {}, + } + with pytest.raises(PermissionDenied, match="Unable to authenticate API via token."): + self.client.post( + TEST_API_ENDPOINT, + headers={"Content-Type": "application/json", "Accept": "application/json"}, + data=json.dumps(input_data), + ) + + def test_invalid_token(self, setup_attrs, signer: JWTSigner): + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + "Authorization": signer.generate_signed_token({"method": "WRONG_METHOD_NAME"}), + } + data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": {}} + + with pytest.raises( + PermissionDenied, match="Bad Signature. Please use only the tokens provided by the API." 
+ ): + self.client.post(TEST_API_ENDPOINT, headers=headers, data=json.dumps(data)) + + def test_missing_accept(self, setup_attrs, signer: JWTSigner): + headers = { + "Content-Type": "application/json", + "Authorization": signer.generate_signed_token({"method": "WRONG_METHOD_NAME"}), + } + data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": {}} + + with pytest.raises(PermissionDenied, match="Expected Accept: application/json"): + self.client.post(TEST_API_ENDPOINT, headers=headers, data=json.dumps(data)) + + def test_wrong_accept(self, setup_attrs, signer: JWTSigner): + headers = { + "Content-Type": "application/json", + "Accept": "application/html", + "Authorization": signer.generate_signed_token({"method": "WRONG_METHOD_NAME"}), + } + data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": {}} + + with pytest.raises(PermissionDenied, match="Expected Accept: application/json"): + self.client.post(TEST_API_ENDPOINT, headers=headers, data=json.dumps(data)) diff --git a/tests/providers/edge/plugins/__init__.py b/tests/providers/edge/plugins/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/tests/providers/edge/plugins/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/edge/plugins/test_edge_executor_plugin.py b/tests/providers/edge/plugins/test_edge_executor_plugin.py new file mode 100644 index 0000000000000..e3422b17da3c8 --- /dev/null +++ b/tests/providers/edge/plugins/test_edge_executor_plugin.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import importlib + +import pytest + +from airflow.plugins_manager import AirflowPlugin +from airflow.providers.edge.plugins import edge_executor_plugin +from tests.test_utils.config import conf_vars + + +def test_plugin_inactive(): + with conf_vars({("edge", "api_enabled"): "false"}): + importlib.reload(edge_executor_plugin) + + from airflow.providers.edge.plugins.edge_executor_plugin import ( + EDGE_EXECUTOR_ACTIVE, + EdgeExecutorPlugin, + ) + + rep = EdgeExecutorPlugin() + assert not EDGE_EXECUTOR_ACTIVE + assert len(rep.flask_blueprints) == 0 + assert len(rep.appbuilder_views) == 0 + + +def test_plugin_active(): + with conf_vars({("edge", "api_enabled"): "true"}): + importlib.reload(edge_executor_plugin) + + from airflow.providers.edge.plugins.edge_executor_plugin import ( + EDGE_EXECUTOR_ACTIVE, + EdgeExecutorPlugin, + ) + + rep = EdgeExecutorPlugin() + assert EDGE_EXECUTOR_ACTIVE + assert len(rep.flask_blueprints) == 2 + assert len(rep.appbuilder_views) == 2 + + +@pytest.fixture +def plugin(): + from airflow.providers.edge.plugins.edge_executor_plugin import EdgeExecutorPlugin + + return EdgeExecutorPlugin() + + +def test_plugin_is_airflow_plugin(plugin): + assert isinstance(plugin, AirflowPlugin) From 446ac403a540e5c21c2d7ea84e22bbc1f19f84a3 Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Thu, 3 Oct 2024 03:00:54 -0700 Subject: [PATCH 276/349] Update min version of Pydantic to 2.6.4 (#42694) Pydantic 2.6.4 fixes problem with AliasGenerator to throw error when generating schema - see an issue in Pydantic repository https://github.com/pydantic/pydantic/issues/8768 --- airflow/providers/edge/provider.yaml | 2 +- generated/provider_dependencies.json | 2 +- hatch_build.py | 2 +- newsfragments/41857.significant.rst | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/airflow/providers/edge/provider.yaml b/airflow/providers/edge/provider.yaml index 6525b7bb846ff..d6644271a02f3 100644 --- a/airflow/providers/edge/provider.yaml +++ b/airflow/providers/edge/provider.yaml @@ -30,7 +30,7 @@ versions: dependencies: - apache-airflow>=2.10.0 - - pydantic>=2.3.0 + - pydantic>=2.6.4 plugins: - name: edge_executor diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 59da56f180744..7dc5e337292b8 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -525,7 +525,7 @@ "edge": { "deps": [ "apache-airflow>=2.10.0", - "pydantic>=2.3.0" + "pydantic>=2.6.4" ], "devel-deps": [], "plugins": [ diff --git a/hatch_build.py b/hatch_build.py index 6e3d77981e3d4..765e71ff98962 100644 --- a/hatch_build.py +++ b/hatch_build.py @@ -467,7 +467,7 @@ 'pendulum>=3.0.0,<4.0;python_version>="3.12"', "pluggy>=1.5.0", "psutil>=5.8.0", - "pydantic>=2.6.0", + "pydantic>=2.6.4", "pygments>=2.0.1", "pyjwt>=2.0.0", "python-daemon>=3.0.0", diff --git a/newsfragments/41857.significant.rst b/newsfragments/41857.significant.rst index df3c85853eee7..f0b06f2811b1f 100644 --- a/newsfragments/41857.significant.rst +++ b/newsfragments/41857.significant.rst @@ -1,3 +1,3 @@ **Breaking Change** -Airflow core now depends on ``pydantic>=2.3.0``. If you have Pydantic v1 installed, please upgrade. +Airflow core now depends on Pydantic v2. If you have Pydantic v1 installed, please upgrade. 
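For context on the version floor raised above: the referenced pydantic issue concerns JSON schema generation failing when an ``AliasGenerator`` is configured, which 2.6.4 fixed. A minimal illustrative sketch, not part of any patch in this series — the model name and fields are made up purely for the example:

    # Illustrative only: schema generation on a model configured with an
    # AliasGenerator, which could error out on pydantic < 2.6.4
    # (https://github.com/pydantic/pydantic/issues/8768) and works on >= 2.6.4.
    from pydantic import AliasGenerator, BaseModel, ConfigDict
    from pydantic.alias_generators import to_camel


    class ExampleResponse(BaseModel):
        # Hypothetical fields, chosen only to demonstrate the alias behaviour.
        model_config = ConfigDict(alias_generator=AliasGenerator(validation_alias=to_camel))

        conn_id: str
        conn_type: str


    # Succeeds on pydantic>=2.6.4; earlier releases could raise while
    # generating the schema for a model using AliasGenerator.
    print(ExampleResponse.model_json_schema(by_alias=True))
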
From b5bafdf00b4711da7a82ff8c283722238bf807b6 Mon Sep 17 00:00:00 2001 From: Lorin Dawson <22798188+R7L208@users.noreply.github.com> Date: Thu, 3 Oct 2024 06:36:30 -0600 Subject: [PATCH 277/349] Add `on_kill` to Databricks Workflow Operator (#42115) * add on_kill override to databricks workflow operator * on_kill equivalent for DatabricksSqlOperator * add tests for create_timeout_thread * add note for on_kill in DatabricksCopyIntoOperator * chore: static checks * remove changes for databricks_sql.py for PR isolated to databricks_workflows.py --------- Co-authored-by: Lorin --- .../operators/databricks_workflow.py | 27 ++++++++++++++++++- .../operators/test_databricks_workflow.py | 22 +++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/airflow/providers/databricks/operators/databricks_workflow.py b/airflow/providers/databricks/operators/databricks_workflow.py index 15333dc69118b..6df8e2d025cea 100644 --- a/airflow/providers/databricks/operators/databricks_workflow.py +++ b/airflow/providers/databricks/operators/databricks_workflow.py @@ -52,7 +52,7 @@ class WorkflowRunMetadata: """ conn_id: str - job_id: str + job_id: int run_id: int @@ -116,6 +116,7 @@ def __init__( self.notebook_params = notebook_params or {} self.tasks_to_convert = tasks_to_convert or [] self.relevant_upstreams = [task_id] + self.workflow_run_metadata: WorkflowRunMetadata | None = None super().__init__(task_id=task_id, **kwargs) def _get_hook(self, caller: str) -> DatabricksHook: @@ -212,12 +213,36 @@ def execute(self, context: Context) -> Any: self._wait_for_job_to_start(run_id) + self.workflow_run_metadata = WorkflowRunMetadata( + self.databricks_conn_id, + job_id, + run_id, + ) + return { "conn_id": self.databricks_conn_id, "job_id": job_id, "run_id": run_id, } + def on_kill(self) -> None: + if self.workflow_run_metadata: + run_id = self.workflow_run_metadata.run_id + job_id = self.workflow_run_metadata.job_id + + self._hook.cancel_run(run_id) + self.log.info( + "Run: %(run_id)s of job_id: %(job_id)s was requested to be cancelled.", + {"run_id": run_id, "job_id": job_id}, + ) + else: + self.log.error( + """ + Error: Workflow Run metadata is not populated, so the run was not canceled. This could be due + to the workflow not being started or an error in the workflow creation process. 
+ """ + ) + class DatabricksWorkflowTaskGroup(TaskGroup): """ diff --git a/tests/providers/databricks/operators/test_databricks_workflow.py b/tests/providers/databricks/operators/test_databricks_workflow.py index 4c3f54b800ae9..fbc429ed1d9a8 100644 --- a/tests/providers/databricks/operators/test_databricks_workflow.py +++ b/tests/providers/databricks/operators/test_databricks_workflow.py @@ -28,6 +28,7 @@ from airflow.providers.databricks.hooks.databricks import RunLifeCycleState from airflow.providers.databricks.operators.databricks_workflow import ( DatabricksWorkflowTaskGroup, + WorkflowRunMetadata, _CreateDatabricksWorkflowOperator, _flatten_node, ) @@ -59,6 +60,11 @@ def mock_task_group(): return mock_group +@pytest.fixture +def mock_workflow_run_metadata(): + return MagicMock(spec=WorkflowRunMetadata) + + def test_flatten_node(): """Test that _flatten_node returns a flat list of operators.""" task_group = MagicMock(spec=DatabricksWorkflowTaskGroup) @@ -231,3 +237,19 @@ def test_task_group_root_tasks_set_upstream_to_operator(mock_databricks_workflow create_operator_instance = mock_databricks_workflow_operator.return_value task1.set_upstream.assert_called_once_with(create_operator_instance) + + +def test_on_kill(mock_databricks_hook, context, mock_workflow_run_metadata): + """Test that _CreateDatabricksWorkflowOperator.execute runs the task group.""" + operator = _CreateDatabricksWorkflowOperator(task_id="test_task", databricks_conn_id="databricks_default") + operator.workflow_run_metadata = mock_workflow_run_metadata + + RUN_ID = 789 + + mock_workflow_run_metadata.conn_id = operator.databricks_conn_id + mock_workflow_run_metadata.job_id = "123" + mock_workflow_run_metadata.run_id = RUN_ID + + operator.on_kill() + + operator._hook.cancel_run.assert_called_once_with(RUN_ID) From 0a9fb4e423559547d9e7f130b9a22d30dc394a1c Mon Sep 17 00:00:00 2001 From: Pierre Jeambrun Date: Fri, 4 Oct 2024 00:50:30 +0800 Subject: [PATCH 278/349] AIP-84 Serve new UI from FastAPI API (#42663) * Serve new UI from FastAPI API * Fix CI * Fix CI another try --- airflow/api_fastapi/app.py | 32 ++++++++++++++- airflow/ui/.env.example | 3 +- airflow/ui/src/App.tsx | 1 - airflow/ui/src/layouts/Nav/DocsButton.tsx | 2 +- airflow/ui/src/layouts/Nav/Nav.tsx | 2 +- airflow/ui/src/main.tsx | 4 +- airflow/ui/src/vite-env.d.ts | 2 +- airflow/ui/vite.config.ts | 4 +- airflow/www/app.py | 2 - airflow/www/extensions/init_react_ui.py | 40 ------------------- airflow/www/templates/airflow/main.html | 2 +- .../14_node_environment_setup.rst | 10 ----- .../run_update_fastapi_api_spec.py | 2 + 13 files changed, 41 insertions(+), 65 deletions(-) delete mode 100644 airflow/www/extensions/init_react_ui.py diff --git a/airflow/api_fastapi/app.py b/airflow/api_fastapi/app.py index 6f8bbcdf149b4..6b9df0ed8b7f8 100644 --- a/airflow/api_fastapi/app.py +++ b/airflow/api_fastapi/app.py @@ -16,9 +16,16 @@ # under the License. 
from __future__ import annotations -from fastapi import FastAPI +import os +from pathlib import Path + +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import HTMLResponse +from fastapi.staticfiles import StaticFiles +from fastapi.templating import Jinja2Templates +from airflow.settings import AIRFLOW_PATH from airflow.www.extensions.init_dagbag import get_dag_bag app: FastAPI | None = None @@ -70,6 +77,29 @@ def init_views(app) -> None: app.include_router(ui_router) app.include_router(public_router) + dev_mode = os.environ.get("DEV_MODE", False) == "true" + + directory = Path(AIRFLOW_PATH) / ("airflow/ui/dev" if dev_mode else "airflow/ui/dist") + + # During python tests or when the backend is run without having the frontend build + # those directories might not exist. App should not fail initializing in those scenarios. + Path(directory).mkdir(exist_ok=True) + + templates = Jinja2Templates(directory=directory) + + app.mount( + "/static", + StaticFiles( + directory=directory, + html=True, + ), + name="webapp_static_folder", + ) + + @app.get("/webapp/{rest_of_path:path}", response_class=HTMLResponse, include_in_schema=False) + def webapp(request: Request, rest_of_path: str): + return templates.TemplateResponse("/index.html", {"request": request}, media_type="text/html") + def cached_app(config=None, testing=False) -> FastAPI: """Return cached instance of Airflow UI app.""" diff --git a/airflow/ui/.env.example b/airflow/ui/.env.example index 9374d93de6bca..3e3c1569f1238 100644 --- a/airflow/ui/.env.example +++ b/airflow/ui/.env.example @@ -19,5 +19,4 @@ # This is an example. You should make your own `.env.local` file for development - -VITE_FASTAPI_URL="http://localhost:29091" +VITE_LEGACY_API_URL="http://localhost:28080" diff --git a/airflow/ui/src/App.tsx b/airflow/ui/src/App.tsx index 0eb603b46a330..3c5e9d866f0c9 100644 --- a/airflow/ui/src/App.tsx +++ b/airflow/ui/src/App.tsx @@ -22,7 +22,6 @@ import { DagsList } from "src/pages/DagsList"; import { BaseLayout } from "./layouts/BaseLayout"; -// Note: When changing routes, make sure to update init_react_ui.py too export const App = () => ( } path="/"> diff --git a/airflow/ui/src/layouts/Nav/DocsButton.tsx b/airflow/ui/src/layouts/Nav/DocsButton.tsx index 07a4b93dfaede..e85d923b88f54 100644 --- a/airflow/ui/src/layouts/Nav/DocsButton.tsx +++ b/airflow/ui/src/layouts/Nav/DocsButton.tsx @@ -38,7 +38,7 @@ const links = [ title: "GitHub Repo", }, { - href: `${import.meta.env.VITE_FASTAPI_URL}/docs`, + href: `/docs`, title: "REST API Reference", }, ]; diff --git a/airflow/ui/src/layouts/Nav/Nav.tsx b/airflow/ui/src/layouts/Nav/Nav.tsx index 55bfd4480e0f4..9886b5eb75760 100644 --- a/airflow/ui/src/layouts/Nav/Nav.tsx +++ b/airflow/ui/src/layouts/Nav/Nav.tsx @@ -101,7 +101,7 @@ export const Nav = () => { } title="Return to legacy UI" /> diff --git a/airflow/ui/src/main.tsx b/airflow/ui/src/main.tsx index 7b762508ea7b3..daf4bcd024cd6 100644 --- a/airflow/ui/src/main.tsx +++ b/airflow/ui/src/main.tsx @@ -43,8 +43,6 @@ const queryClient = new QueryClient({ }, }); -axios.defaults.baseURL = import.meta.env.VITE_FASTAPI_URL; - // redirect to login page if the API responds with unauthorized or forbidden errors axios.interceptors.response.use( (response: AxiosResponse) => response, @@ -61,7 +59,7 @@ axios.interceptors.response.use( const root = createRoot(document.querySelector("#root") as HTMLDivElement); root.render( - + diff --git a/airflow/ui/src/vite-env.d.ts 
b/airflow/ui/src/vite-env.d.ts index 193866687bff9..8a62dd17206eb 100644 --- a/airflow/ui/src/vite-env.d.ts +++ b/airflow/ui/src/vite-env.d.ts @@ -21,7 +21,7 @@ /// interface ImportMetaEnv { - readonly VITE_FASTAPI_URL: string; + readonly VITE_LEGACY_API_URL: string; } interface ImportMeta { diff --git a/airflow/ui/vite.config.ts b/airflow/ui/vite.config.ts index 06ad450f377a1..7bc48d640418a 100644 --- a/airflow/ui/vite.config.ts +++ b/airflow/ui/vite.config.ts @@ -29,8 +29,8 @@ export default defineConfig({ name: "transform-url-src", transformIndexHtml: (html) => html - .replace(`src="/assets/`, `src="/ui/assets/`) - .replace(`href="/`, `href="/ui/`), + .replace(`src="/assets/`, `src="/static/assets/`) + .replace(`href="/`, `href="/webapp/`), }, ], resolve: { alias: { openapi: "/openapi-gen", src: "/src" } }, diff --git a/airflow/www/app.py b/airflow/www/app.py index f5e1191fb43fb..3409510b5a1a6 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -40,7 +40,6 @@ from airflow.www.extensions.init_dagbag import init_dagbag from airflow.www.extensions.init_jinja_globals import init_jinja_globals from airflow.www.extensions.init_manifest_files import configure_manifest_files -from airflow.www.extensions.init_react_ui import init_react_ui from airflow.www.extensions.init_robots import init_robots from airflow.www.extensions.init_security import ( init_api_auth, @@ -155,7 +154,6 @@ def create_app(config=None, testing=False): with flask_app.app_context(): init_appbuilder(flask_app) - init_react_ui(flask_app) init_appbuilder_views(flask_app) init_appbuilder_links(flask_app) init_plugins(flask_app) diff --git a/airflow/www/extensions/init_react_ui.py b/airflow/www/extensions/init_react_ui.py deleted file mode 100644 index 872a22c059476..0000000000000 --- a/airflow/www/extensions/init_react_ui.py +++ /dev/null @@ -1,40 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from __future__ import annotations - -import os - -from flask import Blueprint - - -def init_react_ui(app): - dev_mode = os.environ.get("DEV_MODE", False) == "true" - - bp = Blueprint( - "ui", - __name__, - # The dev mode index file points to the vite dev server instead of static build files - static_folder="../../ui/dev" if dev_mode else "../../ui/dist", - static_url_path="/ui", - ) - - @bp.route("/ui", defaults={"page": ""}) - @bp.route("/ui/") - def index(page): - return bp.send_static_file("index.html") - - app.register_blueprint(bp) diff --git a/airflow/www/templates/airflow/main.html b/airflow/www/templates/airflow/main.html index 69aa6faaaf0de..008418d7e5b78 100644 --- a/airflow/www/templates/airflow/main.html +++ b/airflow/www/templates/airflow/main.html @@ -99,7 +99,7 @@ {% if auth_manager.is_logged_in() %} {% call show_message(category='info', dismissible=true) %} We have a new UI for Airflow 3.0 - Check it out now! + Check it out now! {% endcall %} {% endif %} {% endblock %} diff --git a/contributing-docs/14_node_environment_setup.rst b/contributing-docs/14_node_environment_setup.rst index 7b10f0b0d5ed5..81ced88240ac5 100644 --- a/contributing-docs/14_node_environment_setup.rst +++ b/contributing-docs/14_node_environment_setup.rst @@ -93,16 +93,6 @@ Copy the example environment cp .env.example .env.local -If you run into CORS issues, you may need to add some variables to your Breeze config, ``files/airflow-breeze-config/variables.env``: - -.. code-block:: bash - - export AIRFLOW__API__ACCESS_CONTROL_ALLOW_HEADERS="Origin, Access-Control-Request-Method" - export AIRFLOW__API__ACCESS_CONTROL_ALLOW_METHODS="*" - export AIRFLOW__API__ACCESS_CONTROL_ALLOW_ORIGINS="http://localhost:28080,http://localhost:8080" - - - DEPRECATED Airflow WWW ---------------------- diff --git a/scripts/in_container/run_update_fastapi_api_spec.py b/scripts/in_container/run_update_fastapi_api_spec.py index 4d78bc4afd585..5d31b0bee3f0c 100644 --- a/scripts/in_container/run_update_fastapi_api_spec.py +++ b/scripts/in_container/run_update_fastapi_api_spec.py @@ -29,6 +29,8 @@ # The persisted openapi spec will list all endpoints (public and ui), this # is used for code generation. 
 for route in app.routes:
+    if getattr(route, "name") == "webapp":
+        continue
     route.__setattr__("include_in_schema", True)
 
 with open(OPENAPI_SPEC_FILE, "w+") as f:

From 7f1588ee185319fbfc57a1774bab56acc9a94d76 Mon Sep 17 00:00:00 2001
From: Josix
Date: Fri, 4 Oct 2024 02:36:21 +0900
Subject: [PATCH 279/349] fix(shell_params): prevent generating `,celery` in extra when there are no other extra items (#42709)

---
 dev/breeze/src/airflow_breeze/params/shell_params.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/dev/breeze/src/airflow_breeze/params/shell_params.py b/dev/breeze/src/airflow_breeze/params/shell_params.py
index af74be27c919b..36fa44bb8fed5 100644
--- a/dev/breeze/src/airflow_breeze/params/shell_params.py
+++ b/dev/breeze/src/airflow_breeze/params/shell_params.py
@@ -332,7 +332,9 @@ def compose_file(self) -> str:
                 get_console().print(
                     "[warning]Adding `celery` extras as it is implicitly needed by celery executor"
                 )
-                self.airflow_extras = ",".join(current_extras.split(",") + ["celery"])
+                self.airflow_extras = (
+                    ",".join(current_extras.split(",") + ["celery"]) if current_extras else "celery"
+                )
         compose_file_list.append(DOCKER_COMPOSE_DIR / "base.yml")
         self.add_docker_in_docker(compose_file_list)

From 0ab281d147711be3e922e16d84fc59abd2d39f96 Mon Sep 17 00:00:00 2001
From: Vincent <97131062+vincbeck@users.noreply.github.com>
Date: Thu, 3 Oct 2024 13:37:42 -0400
Subject: [PATCH 280/349] Mention in simple auth manager doc how to read/update passwords directly from file (#42710)

* Mention in simple auth manager doc how to read/update passwords directly from file

* Fix static checks
---
 docs/apache-airflow/core-concepts/auth-manager/simple.rst | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docs/apache-airflow/core-concepts/auth-manager/simple.rst b/docs/apache-airflow/core-concepts/auth-manager/simple.rst
index bef2e5032f0d7..f418ca15f2981 100644
--- a/docs/apache-airflow/core-concepts/auth-manager/simple.rst
+++ b/docs/apache-airflow/core-concepts/auth-manager/simple.rst
@@ -51,6 +51,8 @@ Each user needs two pieces of information:
 The password is auto-generated for each user and printed out in the webserver logs.
 When generated, these passwords are also saved in your environment, therefore they will not change if you stop or restart your environment.
 
+The passwords are saved in the file ``generated/simple_auth_manager_passwords.json.generated``; you can read and update them directly in the file as well if desired.
+
 ..
_roles-permissions: Manage roles and permissions From a912fef043925a5a666ac8be10accf426f330ae1 Mon Sep 17 00:00:00 2001 From: Maksim Date: Thu, 3 Oct 2024 12:49:46 -0700 Subject: [PATCH 281/349] Update tensorflow image uris for VertexAI system tests (#42707) --- .../cloud/vertex_ai/example_vertex_ai_custom_container.py | 3 +-- .../google/cloud/vertex_ai/example_vertex_ai_custom_job.py | 4 ++-- .../vertex_ai/example_vertex_ai_custom_job_python_package.py | 4 ++-- .../google/cloud/vertex_ai/example_vertex_ai_model_service.py | 4 ++-- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_container.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_container.py index b8d01f8d71493..dc09a8be90ed7 100644 --- a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_container.py +++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_container.py @@ -68,9 +68,8 @@ def TABULAR_DATASET(bucket_name): } -CONTAINER_URI = "gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest" CUSTOM_CONTAINER_URI = "us-central1-docker.pkg.dev/airflow-system-tests-resources/system-tests/housing" -MODEL_SERVING_CONTAINER_URI = "gcr.io/cloud-aiplatform/prediction/tf2-cpu.2-2:latest" +MODEL_SERVING_CONTAINER_URI = "us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-2:latest" REPLICA_COUNT = 1 MACHINE_TYPE = "n1-standard-4" ACCELERATOR_TYPE = "ACCELERATOR_TYPE_UNSPECIFIED" diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py index b2856a28a23d0..8762feb85ba39 100644 --- a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py +++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py @@ -68,8 +68,8 @@ def TABULAR_DATASET(bucket_name): } -CONTAINER_URI = "gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest" -MODEL_SERVING_CONTAINER_URI = "gcr.io/cloud-aiplatform/prediction/tf2-cpu.2-2:latest" +CONTAINER_URI = "us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-2:latest" +MODEL_SERVING_CONTAINER_URI = "us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-2:latest" REPLICA_COUNT = 1 # LOCAL_TRAINING_SCRIPT_PATH should be set for Airflow which is running on distributed system. 
diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job_python_package.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job_python_package.py index 33105d273f159..49a8d870bc394 100644 --- a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job_python_package.py +++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job_python_package.py @@ -68,8 +68,8 @@ def TABULAR_DATASET(bucket_name): } -CONTAINER_URI = "gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest" -MODEL_SERVING_CONTAINER_URI = "gcr.io/cloud-aiplatform/prediction/tf2-cpu.2-2:latest" +CONTAINER_URI = "us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-2:latest" +MODEL_SERVING_CONTAINER_URI = "us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-2:latest" REPLICA_COUNT = 1 MACHINE_TYPE = "n1-standard-4" ACCELERATOR_TYPE = "ACCELERATOR_TYPE_UNSPECIFIED" diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_model_service.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_model_service.py index e6ad1e710e4c3..b06f8287798df 100644 --- a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_model_service.py +++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_model_service.py @@ -85,7 +85,7 @@ ), } -CONTAINER_URI = "gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest" +CONTAINER_URI = "us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-2:latest" # LOCAL_TRAINING_SCRIPT_PATH should be set for Airflow which is running on distributed system. # For example in Composer the correct path is `gcs/data/california_housing_training_script.py`. @@ -99,7 +99,7 @@ }, "export_format_id": "custom-trained", } -MODEL_SERVING_CONTAINER_URI = "gcr.io/cloud-aiplatform/prediction/tf2-cpu.2-2:latest" +MODEL_SERVING_CONTAINER_URI = "us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-2:latest" MODEL_OBJ = { "display_name": f"model-{ENV_ID}", "artifact_uri": "{{ti.xcom_pull('custom_task')['artifactUri']}}", From 9f4cf3bdee743293cd3f8dcde81a2421ea2f526d Mon Sep 17 00:00:00 2001 From: GPK Date: Thu, 3 Oct 2024 21:36:09 +0100 Subject: [PATCH 282/349] fix PubSubAsyncHook in PubsubPullTrigger to use gcp_conn_id (#42671) --- .../providers/google/cloud/triggers/pubsub.py | 10 +++++++++- .../google/cloud/triggers/test_pubsub.py | 20 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/airflow/providers/google/cloud/triggers/pubsub.py b/airflow/providers/google/cloud/triggers/pubsub.py index db3fe409e942b..e98603006f725 100644 --- a/airflow/providers/google/cloud/triggers/pubsub.py +++ b/airflow/providers/google/cloud/triggers/pubsub.py @@ -19,6 +19,7 @@ from __future__ import annotations import asyncio +from functools import cached_property from typing import Any, AsyncIterator, Sequence from google.cloud.pubsub_v1.types import ReceivedMessage @@ -67,7 +68,6 @@ def __init__( self.poke_interval = poke_interval self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - self.hook = PubSubAsyncHook() def serialize(self) -> tuple[str, dict[str, Any]]: """Serialize PubsubPullTrigger arguments and classpath.""" @@ -113,3 +113,11 @@ async def message_acknowledgement(self, pulled_messages): messages=pulled_messages, ) self.log.info("Acknowledged ack_ids from subscription %s", self.subscription) + + @cached_property + def hook(self) -> PubSubAsyncHook: + return PubSubAsyncHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + 
project_id=self.project_id, + ) diff --git a/tests/providers/google/cloud/triggers/test_pubsub.py b/tests/providers/google/cloud/triggers/test_pubsub.py index e1a4e178d2918..60acd2b7d4c2e 100644 --- a/tests/providers/google/cloud/triggers/test_pubsub.py +++ b/tests/providers/google/cloud/triggers/test_pubsub.py @@ -110,3 +110,23 @@ async def test_async_pubsub_pull_trigger_return_event(self, mock_pull): response = await trigger.run().asend(None) assert response == expected_event + + @mock.patch("airflow.providers.google.cloud.triggers.pubsub.PubSubAsyncHook") + def test_hook(self, mock_async_hook): + trigger = PubsubPullTrigger( + project_id=PROJECT_ID, + subscription="subscription", + max_messages=MAX_MESSAGES, + ack_messages=False, + poke_interval=TEST_POLL_INTERVAL, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=None, + ) + async_hook_actual = trigger.hook + + mock_async_hook.assert_called_once_with( + gcp_conn_id=trigger.gcp_conn_id, + impersonation_chain=trigger.impersonation_chain, + project_id=trigger.project_id, + ) + assert async_hook_actual == mock_async_hook.return_value From 38941c508da99d6b42d0cf39d4c50edd8911173c Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 4 Oct 2024 08:57:51 +0900 Subject: [PATCH 283/349] Rename dataset endpoints as asset endpoints (#42579) * feat(api_connexion): rename dataset_endpoint module as asset_endpoint * feat(api_connexion/openapi): rename tag Dataset as Asset * feat(api_connexion): rename create_dataset_event as create_asset_event * feat(api_connexion): rename schema CreateDatasetEvent as CreateAssetEvent * test(api_connexion): rename test_dataset_endpoint as test_asset_endpoint * feat(api_connexion): rename delete_dataset_queued_events as delete_asset_queued_events * feat(api_connexion): rename get_dataset_queued_events as get_asset_queued_events * feat(api_connexion): rename delete_dag_dataset_queued_events as delete_dag_asset_queued_events * feat(api_connexion): rename delete_dag_dataset_queued_event as delete_dag_asset_queued_event * feat(api_connexion): rename get_dag_dataset_queued_events as get_dag_asset_queued_events * feat(api_connexion): rename get_dag_dataset_queued_event as get_dag_asset_queued_event * refactor(api_connexion): remove unused dataset_id in _generate_queued_event_where_clause * feat(api_connexion): rename get_dataset_events as get_asset_events * feat(api_connexion): rename get_datasets as get_assets * feat(api_connexion): rename get_dataset as get_asset * feat(api_connexion/openapi): update api docs * feat(js): rename DatasetEvents as AssetEvents * feat(js): rename DatasetDetails as AssetDetails * feat(js): rename DatasetList as AssetList * feat(js/api): rename useUpstreamDatasetEvents as useUpstreamAssetEvents * feat(js/api): rename useDatasetsSummary as useAssetsSummary * feat(js/api): rename useDatasetDependencies as useAssetDependencies * feat(js/api): rename useDatasetEvents as useAssetEvents * feat(js/api): rename useDatasets as useAssets * feat(js/api): rename useDataset as useAsset * feat(api_connexion): rename get_upstream_dataset_events as get_upstream_asset_events * feat(api_connexion/openapi/v1): rename DatasetURI as AssetURI * feat(api_connexion/openapi/v1): rename DatasetCollection as AssetCollection * feat(api_connexion/openapi/v1): rename DagScheduleDatasetReference as DagScheduleAssetReference * feat(api_connexion/openapi/v1): rename TaskOutletDatasetReference as TaskOutletAssetReference * feat(js/api): rename DatasetEventCollection as AssetEventCollection * feat(api_connexion/openapi/v1): 
rename DatasetEvent as AssetEvent * feat(api_connexion/openapi/v1): rename Dataset as Asset * docs(api_connexion/openapi/v1): update dataset to asset in v1.yaml * feat(api_connexion): rename endpoint datasets as assets * test(api_connexion): rename dataset as asset * fix(api_connexion/openapi/v1): fix queued_events property name error * feat(api_fastapi): rename next_run_datasets as next_run_assets * test: resolve test conflict * docs(newsfragments): add newsfragments for dataset to asset endpoint rename * feat(js/api): rename datasetEvents as assetEvents * feat(js/api): rename variable datasetEvent as assetEvent * feat(js/api): rename dataset_api as asset_api --- ...{dataset_endpoint.py => asset_endpoint.py} | 41 +-- .../endpoints/dag_run_endpoint.py | 8 +- airflow/api_connexion/openapi/v1.yaml | 258 ++++++------- airflow/api_connexion/schemas/asset_schema.py | 10 +- airflow/api_fastapi/openapi/v1-generated.yaml | 2 +- airflow/api_fastapi/views/ui/assets.py | 2 +- .../ui/openapi-gen/requests/services.gen.ts | 2 +- airflow/ui/openapi-gen/requests/types.gen.ts | 2 +- airflow/www/static/js/api/index.ts | 28 +- .../js/api/{useDataset.ts => useAsset.ts} | 6 +- ...ependencies.ts => useAssetDependencies.ts} | 6 +- ...{useDatasetEvents.ts => useAssetEvents.ts} | 24 +- .../js/api/{useDatasets.ts => useAssets.ts} | 4 +- ...DatasetsSummary.ts => useAssetsSummary.ts} | 2 +- ...DatasetEvent.ts => useCreateAssetEvent.ts} | 19 +- ...setEvents.ts => useUpstreamAssetEvents.ts} | 22 +- .../static/js/components/DatasetEventCard.tsx | 33 +- .../js/components/SourceTaskInstance.tsx | 12 +- .../details/dagRun/DatasetTriggerEvents.tsx | 14 +- .../js/dag/details/graph/DatasetNode.tsx | 24 +- .../www/static/js/dag/details/graph/Node.tsx | 4 +- .../www/static/js/dag/details/graph/index.tsx | 40 +- .../www/static/js/dag/details/graph/utils.ts | 10 +- .../taskInstance/DatasetUpdateEvents.tsx | 14 +- airflow/www/static/js/datasetUtils.js | 8 +- .../{DatasetDetails.tsx => AssetDetails.tsx} | 12 +- .../{DatasetEvents.tsx => AssetEvents.tsx} | 20 +- ...tasetsList.test.tsx => AssetList.test.tsx} | 20 +- .../{DatasetsList.tsx => AssetsList.tsx} | 12 +- ...eDatasetEvent.tsx => CreateAssetEvent.tsx} | 12 +- .../www/static/js/datasets/Graph/index.tsx | 4 +- airflow/www/static/js/datasets/Main.tsx | 20 +- airflow/www/static/js/datasets/SearchBar.tsx | 2 +- airflow/www/static/js/types/api-generated.ts | 346 +++++++++--------- airflow/www/static/js/types/index.ts | 2 +- airflow/www/templates/airflow/dag.html | 4 +- airflow/www/templates/airflow/datasets.html | 6 +- airflow/www/templates/airflow/grid.html | 2 +- clients/python/README.md | 24 +- .../auth-manager/access-control.rst | 6 +- newsfragments/42579.significant.rst | 20 + ...set_endpoint.py => test_asset_endpoint.py} | 250 ++++++------- .../endpoints/test_dag_run_endpoint.py | 8 +- .../schemas/test_dataset_schema.py | 16 +- tests/api_fastapi/views/ui/test_assets.py | 4 +- 45 files changed, 696 insertions(+), 689 deletions(-) rename airflow/api_connexion/endpoints/{dataset_endpoint.py => asset_endpoint.py} (92%) rename airflow/www/static/js/api/{useDataset.ts => useAsset.ts} (86%) rename airflow/www/static/js/api/{useDatasetDependencies.ts => useAssetDependencies.ts} (94%) rename airflow/www/static/js/api/{useDatasetEvents.ts => useAssetEvents.ts} (80%) rename airflow/www/static/js/api/{useDatasets.ts => useAssets.ts} (90%) rename airflow/www/static/js/api/{useDatasetsSummary.ts => useAssetsSummary.ts} (98%) rename airflow/www/static/js/api/{useCreateDatasetEvent.ts => 
useCreateAssetEvent.ts} (77%) rename airflow/www/static/js/api/{useUpstreamDatasetEvents.ts => useUpstreamAssetEvents.ts} (67%) rename airflow/www/static/js/datasets/{DatasetDetails.tsx => AssetDetails.tsx} (91%) rename airflow/www/static/js/datasets/{DatasetEvents.tsx => AssetEvents.tsx} (87%) rename airflow/www/static/js/datasets/{DatasetsList.test.tsx => AssetList.test.tsx} (87%) rename airflow/www/static/js/datasets/{DatasetsList.tsx => AssetsList.tsx} (95%) rename airflow/www/static/js/datasets/{CreateDatasetEvent.tsx => CreateAssetEvent.tsx} (88%) create mode 100644 newsfragments/42579.significant.rst rename tests/api_connexion/endpoints/{test_dataset_endpoint.py => test_asset_endpoint.py} (74%) diff --git a/airflow/api_connexion/endpoints/dataset_endpoint.py b/airflow/api_connexion/endpoints/asset_endpoint.py similarity index 92% rename from airflow/api_connexion/endpoints/dataset_endpoint.py rename to airflow/api_connexion/endpoints/asset_endpoint.py index 95c3bead3da52..cbbe542ea7987 100644 --- a/airflow/api_connexion/endpoints/dataset_endpoint.py +++ b/airflow/api_connexion/endpoints/asset_endpoint.py @@ -57,13 +57,13 @@ from airflow.api_connexion.types import APIResponse -RESOURCE_EVENT_PREFIX = "dataset" +RESOURCE_EVENT_PREFIX = "asset" @security.requires_access_asset("GET") @provide_session -def get_dataset(*, uri: str, session: Session = NEW_SESSION) -> APIResponse: - """Get an asset .""" +def get_asset(*, uri: str, session: Session = NEW_SESSION) -> APIResponse: + """Get an asset.""" asset = session.scalar( select(AssetModel) .where(AssetModel.uri == uri) @@ -80,7 +80,7 @@ def get_dataset(*, uri: str, session: Session = NEW_SESSION) -> APIResponse: @security.requires_access_asset("GET") @format_parameters({"limit": check_limit}) @provide_session -def get_datasets( +def get_assets( *, limit: int, offset: int = 0, @@ -109,18 +109,18 @@ def get_datasets( .offset(offset) .limit(limit) ).all() - return asset_collection_schema.dump(AssetCollection(datasets=assets, total_entries=total_entries)) + return asset_collection_schema.dump(AssetCollection(assets=assets, total_entries=total_entries)) @security.requires_access_asset("GET") @provide_session @format_parameters({"limit": check_limit}) -def get_dataset_events( +def get_asset_events( *, limit: int, offset: int = 0, order_by: str = "timestamp", - dataset_id: int | None = None, + asset_id: int | None = None, source_dag_id: str | None = None, source_task_id: str | None = None, source_run_id: str | None = None, @@ -132,8 +132,8 @@ def get_dataset_events( query = select(AssetEvent) - if dataset_id: - query = query.where(AssetEvent.dataset_id == dataset_id) + if asset_id: + query = query.where(AssetEvent.dataset_id == asset_id) if source_dag_id: query = query.where(AssetEvent.source_dag_id == source_dag_id) if source_task_id: @@ -149,14 +149,13 @@ def get_dataset_events( query = apply_sorting(query, order_by, {}, allowed_attrs) events = session.scalars(query.offset(offset).limit(limit)).all() return asset_event_collection_schema.dump( - AssetEventCollection(dataset_events=events, total_entries=total_entries) + AssetEventCollection(asset_events=events, total_entries=total_entries) ) def _generate_queued_event_where_clause( *, dag_id: str | None = None, - dataset_id: int | None = None, uri: str | None = None, before: str | None = None, permitted_dag_ids: set[str] | None = None, @@ -165,8 +164,6 @@ def _generate_queued_event_where_clause( where_clause = [] if dag_id is not None: where_clause.append(AssetDagRunQueue.target_dag_id == 
dag_id) - if dataset_id is not None: - where_clause.append(AssetDagRunQueue.dataset_id == dataset_id) if uri is not None: where_clause.append( AssetDagRunQueue.dataset_id.in_( @@ -183,7 +180,7 @@ def _generate_queued_event_where_clause( @security.requires_access_asset("GET") @security.requires_access_dag("GET") @provide_session -def get_dag_dataset_queued_event( +def get_dag_asset_queued_event( *, dag_id: str, uri: str, before: str | None = None, session: Session = NEW_SESSION ) -> APIResponse: """Get a queued asset event for a DAG.""" @@ -206,7 +203,7 @@ def get_dag_dataset_queued_event( @security.requires_access_dag("GET") @provide_session @action_logging -def delete_dag_dataset_queued_event( +def delete_dag_asset_queued_event( *, dag_id: str, uri: str, before: str | None = None, session: Session = NEW_SESSION ) -> APIResponse: """Delete a queued asset event for a DAG.""" @@ -224,7 +221,7 @@ def delete_dag_dataset_queued_event( @security.requires_access_asset("GET") @security.requires_access_dag("GET") @provide_session -def get_dag_dataset_queued_events( +def get_dag_asset_queued_events( *, dag_id: str, before: str | None = None, session: Session = NEW_SESSION ) -> APIResponse: """Get queued asset events for a DAG.""" @@ -253,7 +250,7 @@ def get_dag_dataset_queued_events( @security.requires_access_dag("GET") @action_logging @provide_session -def delete_dag_dataset_queued_events( +def delete_dag_asset_queued_events( *, dag_id: str, before: str | None = None, session: Session = NEW_SESSION ) -> APIResponse: """Delete queued asset events for a DAG.""" @@ -271,7 +268,7 @@ def delete_dag_dataset_queued_events( @security.requires_access_asset("GET") @provide_session -def get_dataset_queued_events( +def get_asset_queued_events( *, uri: str, before: str | None = None, session: Session = NEW_SESSION ) -> APIResponse: """Get queued asset events for an asset.""" @@ -303,7 +300,7 @@ def get_dataset_queued_events( @security.requires_access_asset("DELETE") @action_logging @provide_session -def delete_dataset_queued_events( +def delete_asset_queued_events( *, uri: str, before: str | None = None, session: Session = NEW_SESSION ) -> APIResponse: """Delete queued asset events for an asset.""" @@ -325,7 +322,7 @@ def delete_dataset_queued_events( @security.requires_access_asset("POST") @provide_session @action_logging -def create_dataset_event(session: Session = NEW_SESSION) -> APIResponse: +def create_asset_event(session: Session = NEW_SESSION) -> APIResponse: """Create asset event.""" body = get_json_request_dict() try: @@ -333,7 +330,7 @@ def create_dataset_event(session: Session = NEW_SESSION) -> APIResponse: except ValidationError as err: raise BadRequest(detail=str(err)) - uri = json_body["dataset_uri"] + uri = json_body["asset_uri"] asset = session.scalar(select(AssetModel).where(AssetModel.uri == uri).limit(1)) if not asset: raise NotFound(title="Asset not found", detail=f"Asset with uri: '{uri}' not found") @@ -341,7 +338,7 @@ def create_dataset_event(session: Session = NEW_SESSION) -> APIResponse: extra = json_body.get("extra", {}) extra["from_rest_api"] = True asset_event = asset_manager.register_asset_change( - asset=Asset(uri), + asset=Asset(uri=uri), timestamp=timestamp, extra=extra, session=session, diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index 02d4663837f4e..44891c0ef2c84 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -114,10 +114,8 @@ 
def get_dag_run( @security.requires_access_dag("GET", DagAccessEntity.RUN) @security.requires_access_asset("GET") @provide_session -def get_upstream_dataset_events( - *, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION -) -> APIResponse: - """If dag run is dataset-triggered, return the asset events that triggered it.""" +def get_upstream_asset_events(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: + """If dag run is asset-triggered, return the asset events that triggered it.""" dag_run: DagRun | None = session.scalar( select(DagRun).where( DagRun.dag_id == dag_id, @@ -131,7 +129,7 @@ def get_upstream_dataset_events( ) events = dag_run.consumed_dataset_events return asset_event_collection_schema.dump( - AssetEventCollection(dataset_events=events, total_entries=len(events)) + AssetEventCollection(asset_events=events, total_entries=len(events)) ) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index 15ad6fd8a4f63..828a3af25e879 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -1181,26 +1181,26 @@ paths: "404": $ref: "#/components/responses/NotFound" - /dags/{dag_id}/dagRuns/{dag_run_id}/upstreamDatasetEvents: + /dags/{dag_id}/dagRuns/{dag_run_id}/upstreamAssetEvents: parameters: - $ref: "#/components/parameters/DAGID" - $ref: "#/components/parameters/DAGRunID" get: - summary: Get dataset events for a DAG run + summary: Get asset events for a DAG run description: | - Get datasets for a dag run. + Get asset for a dag run. *New in version 2.4.0* x-openapi-router-controller: airflow.api_connexion.endpoints.dag_run_endpoint - operationId: get_upstream_dataset_events - tags: [DAGRun, Dataset] + operationId: get_upstream_asset_events + tags: [DAGRun, Asset] responses: "200": description: Success. content: application/json: schema: - $ref: "#/components/schemas/DatasetEventCollection" + $ref: "#/components/schemas/AssetEventCollection" "401": $ref: "#/components/responses/Unauthenticated" "403": @@ -1245,22 +1245,22 @@ paths: "404": $ref: "#/components/responses/NotFound" - /dags/{dag_id}/datasets/queuedEvent/{uri}: + /dags/{dag_id}/assets/queuedEvent/{uri}: parameters: - $ref: "#/components/parameters/DAGID" - - $ref: "#/components/parameters/DatasetURI" + - $ref: "#/components/parameters/AssetURI" get: - summary: Get a queued Dataset event for a DAG + summary: Get a queued asset event for a DAG description: | - Get a queued Dataset event for a DAG. + Get a queued asset event for a DAG. *New in version 2.9.0* - x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint - operationId: get_dag_dataset_queued_event + x-openapi-router-controller: airflow.api_connexion.endpoints.asset_endpoint + operationId: get_dag_asset_queued_event parameters: - $ref: "#/components/parameters/Before" - tags: [Dataset] + tags: [Asset] responses: "200": description: Success. @@ -1276,16 +1276,16 @@ paths: $ref: "#/components/responses/NotFound" delete: - summary: Delete a queued Dataset event for a DAG. + summary: Delete a queued Asset event for a DAG. description: | - Delete a queued Dataset event for a DAG. + Delete a queued Asset event for a DAG. 
*New in version 2.9.0* - x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint - operationId: delete_dag_dataset_queued_event + x-openapi-router-controller: airflow.api_connexion.endpoints.asset_endpoint + operationId: delete_dag_asset_queued_event parameters: - $ref: "#/components/parameters/Before" - tags: [Dataset] + tags: [Asset] responses: "204": description: Success. @@ -1298,21 +1298,21 @@ paths: "404": $ref: "#/components/responses/NotFound" - /dags/{dag_id}/datasets/queuedEvent: + /dags/{dag_id}/assets/queuedEvent: parameters: - $ref: "#/components/parameters/DAGID" get: - summary: Get queued Dataset events for a DAG. + summary: Get queued Asset events for a DAG. description: | - Get queued Dataset events for a DAG. + Get queued Asset events for a DAG. *New in version 2.9.0* - x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint - operationId: get_dag_dataset_queued_events + x-openapi-router-controller: airflow.api_connexion.endpoints.asset_endpoint + operationId: get_dag_asset_queued_events parameters: - $ref: "#/components/parameters/Before" - tags: [Dataset] + tags: [Asset] responses: "200": description: Success. @@ -1328,16 +1328,16 @@ paths: $ref: "#/components/responses/NotFound" delete: - summary: Delete queued Dataset events for a DAG. + summary: Delete queued Asset events for a DAG. description: | - Delete queued Dataset events for a DAG. + Delete queued Asset events for a DAG. *New in version 2.9.0* - x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint - operationId: delete_dag_dataset_queued_events + x-openapi-router-controller: airflow.api_connexion.endpoints.asset_endpoint + operationId: delete_dag_asset_queued_events parameters: - $ref: "#/components/parameters/Before" - tags: [Dataset] + tags: [Asset] responses: "204": description: Success. @@ -1371,21 +1371,21 @@ paths: "404": $ref: "#/components/responses/NotFound" - /datasets/queuedEvent/{uri}: + /assets/queuedEvent/{uri}: parameters: - - $ref: "#/components/parameters/DatasetURI" + - $ref: "#/components/parameters/AssetURI" get: - summary: Get queued Dataset events for a Dataset. + summary: Get queued Asset events for an Asset. description: | - Get queued Dataset events for a Dataset + Get queued Asset events for an Asset *New in version 2.9.0* - x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint - operationId: get_dataset_queued_events + x-openapi-router-controller: airflow.api_connexion.endpoints.asset_endpoint + operationId: get_asset_queued_events parameters: - $ref: "#/components/parameters/Before" - tags: [Dataset] + tags: [Asset] responses: "200": description: Success. @@ -1401,16 +1401,16 @@ paths: $ref: "#/components/responses/NotFound" delete: - summary: Delete queued Dataset events for a Dataset. + summary: Delete queued Asset events for an Asset. description: | - Delete queued Dataset events for a Dataset. + Delete queued Asset events for a Asset. *New in version 2.9.0* - x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint - operationId: delete_dataset_queued_events + x-openapi-router-controller: airflow.api_connexion.endpoints.asset_endpoint + operationId: delete_asset_queued_events parameters: - $ref: "#/components/parameters/Before" - tags: [Dataset] + tags: [Asset] responses: "204": description: Success. 
@@ -2517,12 +2517,12 @@ paths: "403": $ref: "#/components/responses/PermissionDenied" - /datasets: + /assets: get: - summary: List datasets - x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint - operationId: get_datasets - tags: [Dataset] + summary: List assets + x-openapi-router-controller: airflow.api_connexion.endpoints.asset_endpoint + operationId: get_assets + tags: [Asset] parameters: - $ref: "#/components/parameters/PageLimit" - $ref: "#/components/parameters/PageOffset" @@ -2533,14 +2533,14 @@ paths: type: string required: false description: | - If set, only return datasets with uris matching this pattern. + If set, only return assets with uris matching this pattern. - name: dag_ids in: query schema: type: string required: false description: | - One or more DAG IDs separated by commas to filter datasets by associated DAGs either consuming or producing. + One or more DAG IDs separated by commas to filter assets by associated DAGs either consuming or producing. *New in version 2.9.0* responses: @@ -2549,28 +2549,28 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/DatasetCollection" + $ref: "#/components/schemas/AssetCollection" "401": $ref: "#/components/responses/Unauthenticated" "403": $ref: "#/components/responses/PermissionDenied" - /datasets/{uri}: + /assets/{uri}: parameters: - - $ref: "#/components/parameters/DatasetURI" + - $ref: "#/components/parameters/AssetURI" get: - summary: Get a dataset - description: Get a dataset by uri. - x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint - operationId: get_dataset - tags: [Dataset] + summary: Get an asset + description: Get an asset by uri. + x-openapi-router-controller: airflow.api_connexion.endpoints.asset_endpoint + operationId: get_asset + tags: [Asset] responses: "200": description: Success. 
content: application/json: schema: - $ref: "#/components/schemas/Dataset" + $ref: "#/components/schemas/Asset" "401": $ref: "#/components/responses/Unauthenticated" "403": @@ -2578,18 +2578,18 @@ paths: "404": $ref: "#/components/responses/NotFound" - /datasets/events: + /assets/events: get: - summary: Get dataset events - description: Get dataset events - x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint - operationId: get_dataset_events - tags: [Dataset] + summary: Get asset events + description: Get asset events + x-openapi-router-controller: airflow.api_connexion.endpoints.asset_endpoint + operationId: get_asset_events + tags: [Asset] parameters: - $ref: "#/components/parameters/PageLimit" - $ref: "#/components/parameters/PageOffset" - $ref: "#/components/parameters/OrderBy" - - $ref: "#/components/parameters/FilterDatasetID" + - $ref: "#/components/parameters/FilterAssetID" - $ref: "#/components/parameters/FilterSourceDAGID" - $ref: "#/components/parameters/FilterSourceTaskID" - $ref: "#/components/parameters/FilterSourceRunID" @@ -2600,7 +2600,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/DatasetEventCollection" + $ref: "#/components/schemas/AssetEventCollection" "401": $ref: "#/components/responses/Unauthenticated" "403": @@ -2608,24 +2608,24 @@ paths: "404": $ref: "#/components/responses/NotFound" post: - summary: Create dataset event - description: Create dataset event - x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint - operationId: create_dataset_event - tags: [Dataset] + summary: Create asset event + description: Create asset event + x-openapi-router-controller: airflow.api_connexion.endpoints.asset_endpoint + operationId: create_asset_event + tags: [Asset] requestBody: required: true content: application/json: schema: - $ref: '#/components/schemas/CreateDatasetEvent' + $ref: '#/components/schemas/CreateAssetEvent' responses: '200': description: Success. content: application/json: schema: - $ref: '#/components/schemas/DatasetEvent' + $ref: '#/components/schemas/AssetEvent' "400": $ref: "#/components/responses/BadRequest" '401': @@ -4133,7 +4133,7 @@ components: nullable: true dataset_expression: type: object - description: Nested dataset any/all conditions + description: Nested asset any/all conditions nullable: true doc_md: type: string @@ -4507,133 +4507,133 @@ components: $ref: "#/components/schemas/Resource" description: The permission resource - Dataset: + Asset: description: | - A dataset item. + An asset item. *New in version 2.4.0* type: object properties: id: type: integer - description: The dataset id + description: The asset id uri: type: string - description: The dataset uri + description: The asset uri nullable: false extra: type: object - description: The dataset extra + description: The asset extra nullable: true created_at: type: string - description: The dataset creation time + description: The asset creation time nullable: false updated_at: type: string - description: The dataset update time + description: The asset update time nullable: false consuming_dags: type: array items: - $ref: "#/components/schemas/DagScheduleDatasetReference" + $ref: "#/components/schemas/DagScheduleAssetReference" producing_tasks: type: array items: - $ref: "#/components/schemas/TaskOutletDatasetReference" + $ref: "#/components/schemas/TaskOutletAssetReference" - TaskOutletDatasetReference: + TaskOutletAssetReference: description: | - A datasets reference to an upstream task. 
+ An asset reference to an upstream task. *New in version 2.4.0* type: object properties: dag_id: type: string - description: The DAG ID that updates the dataset. + description: The DAG ID that updates the asset. nullable: true task_id: type: string - description: The task ID that updates the dataset. + description: The task ID that updates the asset. nullable: true created_at: type: string - description: The dataset creation time + description: The asset creation time nullable: false updated_at: type: string - description: The dataset update time + description: The asset update time nullable: false - DagScheduleDatasetReference: + DagScheduleAssetReference: description: | - A datasets reference to a downstream DAG. + An asset reference to a downstream DAG. *New in version 2.4.0* type: object properties: dag_id: type: string - description: The DAG ID that depends on the dataset. + description: The DAG ID that depends on the asset. nullable: true created_at: type: string - description: The dataset reference creation time + description: The asset reference creation time nullable: false updated_at: type: string - description: The dataset reference update time + description: The asset reference update time nullable: false - DatasetCollection: + AssetCollection: description: | - A collection of datasets. + A collection of assets. *New in version 2.4.0* type: object allOf: - type: object properties: - datasets: + assets: type: array items: - $ref: "#/components/schemas/Dataset" + $ref: "#/components/schemas/Asset" - $ref: "#/components/schemas/CollectionInfo" - DatasetEvent: + AssetEvent: description: | - A dataset event. + An asset event. *New in version 2.4.0* type: object properties: dataset_id: type: integer - description: The dataset id + description: The asset id dataset_uri: type: string - description: The URI of the dataset + description: The URI of the asset nullable: false extra: type: object - description: The dataset event extra + description: The asset event extra nullable: true source_dag_id: type: string - description: The DAG ID that updated the dataset. + description: The DAG ID that updated the asset. nullable: true source_task_id: type: string - description: The task ID that updated the dataset. + description: The task ID that updated the asset. nullable: true source_run_id: type: string - description: The DAG run ID that updated the dataset. + description: The DAG run ID that updated the asset. nullable: true source_map_index: type: integer - description: The task map index that updated the dataset. + description: The task map index that updated the asset. nullable: true created_dagruns: type: array @@ -4641,21 +4641,21 @@ components: $ref: "#/components/schemas/BasicDAGRun" timestamp: type: string - description: The dataset event creation time + description: The asset event creation time nullable: false - CreateDatasetEvent: + CreateAssetEvent: type: object required: - - dataset_uri + - asset_uri properties: - dataset_uri: + asset_uri: type: string - description: The URI of the dataset + description: The URI of the asset nullable: false extra: type: object - description: The dataset event extra + description: The asset event extra nullable: true QueuedEvent: @@ -4663,7 +4663,7 @@ components: properties: uri: type: string - description: The datata uri. + description: The asset uri. dag_id: type: string description: The DAG ID. @@ -4674,14 +4674,14 @@ components: QueuedEventCollection: description: | - A collection of Dataset Dag Run Queues. 
+ A collection of asset Dag Run Queues. *New in version 2.9.0* type: object allOf: - type: object properties: - datasets: + queued_events: type: array items: $ref: "#/components/schemas/QueuedEvent" @@ -4737,19 +4737,19 @@ components: state: $ref: "#/components/schemas/DagState" - DatasetEventCollection: + AssetEventCollection: description: | - A collection of dataset events. + A collection of asset events. *New in version 2.4.0* type: object allOf: - type: object properties: - dataset_events: + asset_events: type: array items: - $ref: "#/components/schemas/DatasetEvent" + $ref: "#/components/schemas/AssetEvent" - $ref: "#/components/schemas/CollectionInfo" # Configuration @@ -5545,14 +5545,14 @@ components: required: true description: The import error ID. - DatasetURI: + AssetURI: in: path name: uri schema: type: string format: path required: true - description: The encoded Dataset URI + description: The encoded Asset URI PoolName: in: path @@ -5733,40 +5733,40 @@ components: *New in version 2.2.0* - FilterDatasetID: + FilterAssetID: in: query - name: dataset_id + name: asset_id schema: type: integer - description: The Dataset ID that updated the dataset. + description: The Asset ID that updated the asset. FilterSourceDAGID: in: query name: source_dag_id schema: type: string - description: The DAG ID that updated the dataset. + description: The DAG ID that updated the asset. FilterSourceTaskID: in: query name: source_task_id schema: type: string - description: The task ID that updated the dataset. + description: The task ID that updated the asset. FilterSourceRunID: in: query name: source_run_id schema: type: string - description: The DAG run ID that updated the dataset. + description: The DAG run ID that updated the asset. FilterSourceMapIndex: in: query name: source_map_index schema: type: integer - description: The map index that updated the dataset. + description: The map index that updated the asset. 
FilterMapIndex: in: query @@ -6024,12 +6024,12 @@ components: security: [] tags: + - name: Asset - name: Config - name: Connection - name: DAG - name: DAGRun - name: DagWarning - - name: Dataset - name: EventLog - name: ImportError - name: Monitoring diff --git a/airflow/api_connexion/schemas/asset_schema.py b/airflow/api_connexion/schemas/asset_schema.py index 791941f42016d..662f73a50d8b9 100644 --- a/airflow/api_connexion/schemas/asset_schema.py +++ b/airflow/api_connexion/schemas/asset_schema.py @@ -93,14 +93,14 @@ class Meta: class AssetCollection(NamedTuple): """List of Assets with meta.""" - datasets: list[AssetModel] + assets: list[AssetModel] total_entries: int class AssetCollectionSchema(Schema): """Asset Collection Schema.""" - datasets = fields.List(fields.Nested(AssetSchema)) + assets = fields.List(fields.Nested(AssetSchema)) total_entries = fields.Int() @@ -150,21 +150,21 @@ class Meta: class AssetEventCollection(NamedTuple): """List of Asset events with meta.""" - dataset_events: list[AssetEvent] + asset_events: list[AssetEvent] total_entries: int class AssetEventCollectionSchema(Schema): """Asset Event Collection Schema.""" - dataset_events = fields.List(fields.Nested(AssetEventSchema)) + asset_events = fields.List(fields.Nested(AssetEventSchema)) total_entries = fields.Int() class CreateAssetEventSchema(Schema): """Create Asset Event Schema.""" - dataset_uri = fields.String() + asset_uri = fields.String() extra = JsonObjectField() diff --git a/airflow/api_fastapi/openapi/v1-generated.yaml b/airflow/api_fastapi/openapi/v1-generated.yaml index ce488a996af47..23c4ecf545d9f 100644 --- a/airflow/api_fastapi/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/openapi/v1-generated.yaml @@ -7,7 +7,7 @@ info: Users should not rely on those but use the public ones instead. 
version: 0.1.0 paths: - /ui/next_run_datasets/{dag_id}: + /ui/next_run_assets/{dag_id}: get: tags: - Asset diff --git a/airflow/api_fastapi/views/ui/assets.py b/airflow/api_fastapi/views/ui/assets.py index 01cc9fd1cfbff..4a4ad1d0df9b4 100644 --- a/airflow/api_fastapi/views/ui/assets.py +++ b/airflow/api_fastapi/views/ui/assets.py @@ -30,7 +30,7 @@ assets_router = AirflowRouter(tags=["Asset"]) -@assets_router.get("/next_run_datasets/{dag_id}", include_in_schema=False) +@assets_router.get("/next_run_assets/{dag_id}", include_in_schema=False) async def next_run_assets( dag_id: str, request: Request, diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 0aefb56d06e66..0e91fa416571e 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -30,7 +30,7 @@ export class AssetService { ): CancelablePromise { return __request(OpenAPI, { method: "GET", - url: "/ui/next_run_datasets/{dag_id}", + url: "/ui/next_run_assets/{dag_id}", path: { dag_id: data.dagId, }, diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index c37106abc8fcd..b87a172363584 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -203,7 +203,7 @@ export type DeleteConnectionData = { export type DeleteConnectionResponse = void; export type $OpenApiTs = { - "/ui/next_run_datasets/{dag_id}": { + "/ui/next_run_assets/{dag_id}": { get: { req: NextRunAssetsData; res: { diff --git a/airflow/www/static/js/api/index.ts b/airflow/www/static/js/api/index.ts index a4a45a08bfef8..c2a9885b2c7ea 100644 --- a/airflow/www/static/js/api/index.ts +++ b/airflow/www/static/js/api/index.ts @@ -32,14 +32,14 @@ import useMarkTaskDryRun from "./useMarkTaskDryRun"; import useGraphData from "./useGraphData"; import useGridData from "./useGridData"; import useMappedInstances from "./useMappedInstances"; -import useDatasets from "./useDatasets"; -import useDatasetsSummary from "./useDatasetsSummary"; -import useDataset from "./useDataset"; -import useDatasetDependencies from "./useDatasetDependencies"; -import useDatasetEvents from "./useDatasetEvents"; +import useAssets from "./useAssets"; +import useAssetsSummary from "./useAssetsSummary"; +import useAsset from "./useAsset"; +import useAssetDependencies from "./useAssetDependencies"; +import useAssetEvents from "./useAssetEvents"; import useSetDagRunNote from "./useSetDagRunNote"; import useSetTaskInstanceNote from "./useSetTaskInstanceNote"; -import useUpstreamDatasetEvents from "./useUpstreamDatasetEvents"; +import useUpstreamAssetEvents from "./useUpstreamAssetEvents"; import useTaskInstance from "./useTaskInstance"; import useTaskFailedDependency from "./useTaskFailedDependency"; import useDag from "./useDag"; @@ -53,7 +53,7 @@ import useHistoricalMetricsData from "./useHistoricalMetricsData"; import { useTaskXcomEntry, useTaskXcomCollection } from "./useTaskXcom"; import useEventLogs from "./useEventLogs"; import useCalendarData from "./useCalendarData"; -import useCreateDatasetEvent from "./useCreateDatasetEvent"; +import useCreateAssetEvent from "./useCreateAssetEvent"; import useRenderedK8s from "./useRenderedK8s"; import useTaskDetail from "./useTaskDetail"; import useTIHistory from "./useTIHistory"; @@ -85,11 +85,11 @@ export { useDagDetails, useDagRuns, useDags, - useDataset, - useDatasets, - useDatasetDependencies, - useDatasetEvents, - useDatasetsSummary, + 
useAsset, + useAssets, + useAssetDependencies, + useAssetEvents, + useAssetsSummary, useExtraLinks, useGraphData, useGridData, @@ -105,14 +105,14 @@ export { useSetDagRunNote, useSetTaskInstanceNote, useTaskInstance, - useUpstreamDatasetEvents, + useUpstreamAssetEvents, useHistoricalMetricsData, useTaskXcomEntry, useTaskXcomCollection, useTaskFailedDependency, useEventLogs, useCalendarData, - useCreateDatasetEvent, + useCreateAssetEvent, useRenderedK8s, useTaskDetail, useTIHistory, diff --git a/airflow/www/static/js/api/useDataset.ts b/airflow/www/static/js/api/useAsset.ts similarity index 86% rename from airflow/www/static/js/api/useDataset.ts rename to airflow/www/static/js/api/useAsset.ts index 4793464fac378..b490ca6e46565 100644 --- a/airflow/www/static/js/api/useDataset.ts +++ b/airflow/www/static/js/api/useAsset.ts @@ -27,12 +27,12 @@ interface Props { uri: string; } -export default function useDataset({ uri }: Props) { +export default function useAsset({ uri }: Props) { return useQuery(["dataset", uri], () => { - const datasetUrl = getMetaValue("dataset_api").replace( + const datasetUrl = getMetaValue("asset_api").replace( "__URI__", encodeURIComponent(uri) ); - return axios.get(datasetUrl); + return axios.get(datasetUrl); }); } diff --git a/airflow/www/static/js/api/useDatasetDependencies.ts b/airflow/www/static/js/api/useAssetDependencies.ts similarity index 94% rename from airflow/www/static/js/api/useDatasetDependencies.ts rename to airflow/www/static/js/api/useAssetDependencies.ts index d2ba627f64458..11e7219c53fc8 100644 --- a/airflow/www/static/js/api/useDatasetDependencies.ts +++ b/airflow/www/static/js/api/useAssetDependencies.ts @@ -82,15 +82,15 @@ const formatDependencies = async ({ edges, nodes }: DatasetDependencies) => { return graph as DatasetGraph; }; -export default function useDatasetDependencies() { +export default function useAssetDependencies() { return useQuery("datasetDependencies", async () => { const datasetDepsUrl = getMetaValue("dataset_dependencies_url"); return axios.get(datasetDepsUrl); }); } -export const useDatasetGraphs = () => { - const { data: datasetDependencies } = useDatasetDependencies(); +export const useAssetGraphs = () => { + const { data: datasetDependencies } = useAssetDependencies(); return useQuery(["datasetGraphs", datasetDependencies], () => { if (datasetDependencies) { return formatDependencies(datasetDependencies); diff --git a/airflow/www/static/js/api/useDatasetEvents.ts b/airflow/www/static/js/api/useAssetEvents.ts similarity index 80% rename from airflow/www/static/js/api/useDatasetEvents.ts rename to airflow/www/static/js/api/useAssetEvents.ts index 30e4670a87d3e..068bb471ef64e 100644 --- a/airflow/www/static/js/api/useDatasetEvents.ts +++ b/airflow/www/static/js/api/useAssetEvents.ts @@ -23,16 +23,16 @@ import { useQuery, UseQueryOptions } from "react-query"; import { getMetaValue } from "src/utils"; import URLSearchParamsWrapper from "src/utils/URLSearchParamWrapper"; import type { - DatasetEventCollection, - GetDatasetEventsVariables, + AssetEventCollection, + GetAssetEventsVariables, } from "src/types/api-generated"; -interface Props extends GetDatasetEventsVariables { - options?: UseQueryOptions; +interface Props extends GetAssetEventsVariables { + options?: UseQueryOptions; } -const useDatasetEvents = ({ - datasetId, +const useAssetEvents = ({ + assetId, sourceDagId, sourceRunId, sourceTaskId, @@ -42,10 +42,10 @@ const useDatasetEvents = ({ orderBy, options, }: Props) => { - const query = useQuery( + const query = 
useQuery( [ "datasets-events", - datasetId, + assetId, sourceDagId, sourceRunId, sourceTaskId, @@ -55,14 +55,14 @@ const useDatasetEvents = ({ orderBy, ], () => { - const datasetsUrl = getMetaValue("dataset_events_api"); + const datasetsUrl = getMetaValue("asset_events_api"); const params = new URLSearchParamsWrapper(); if (limit) params.set("limit", limit.toString()); if (offset) params.set("offset", offset.toString()); if (orderBy) params.set("order_by", orderBy); - if (datasetId) params.set("dataset_id", datasetId.toString()); + if (assetId) params.set("asset_id", assetId.toString()); if (sourceDagId) params.set("source_dag_id", sourceDagId); if (sourceRunId) params.set("source_run_id", sourceRunId); if (sourceTaskId) params.set("source_task_id", sourceTaskId); @@ -80,8 +80,8 @@ const useDatasetEvents = ({ ); return { ...query, - data: query.data ?? { datasetEvents: [], totalEntries: 0 }, + data: query.data ?? { assetEvents: [], totalEntries: 0 }, }; }; -export default useDatasetEvents; +export default useAssetEvents; diff --git a/airflow/www/static/js/api/useDatasets.ts b/airflow/www/static/js/api/useAssets.ts similarity index 90% rename from airflow/www/static/js/api/useDatasets.ts rename to airflow/www/static/js/api/useAssets.ts index db46415062c1a..3654c583c12ef 100644 --- a/airflow/www/static/js/api/useDatasets.ts +++ b/airflow/www/static/js/api/useAssets.ts @@ -28,7 +28,7 @@ interface Props { enabled?: boolean; } -export default function useDatasets({ dagIds, enabled = true }: Props) { +export default function useAssets({ dagIds, enabled = true }: Props) { return useQuery( ["datasets", dagIds], () => { @@ -36,7 +36,7 @@ export default function useDatasets({ dagIds, enabled = true }: Props) { const dagIdsParam = dagIds && dagIds.length ? { dag_ids: dagIds.join(",") } : {}; - return axios.get(datasetsUrl, { + return axios.get(datasetsUrl, { params: { ...dagIdsParam, }, diff --git a/airflow/www/static/js/api/useDatasetsSummary.ts b/airflow/www/static/js/api/useAssetsSummary.ts similarity index 98% rename from airflow/www/static/js/api/useDatasetsSummary.ts rename to airflow/www/static/js/api/useAssetsSummary.ts index 6f902946f6296..66b56ca9f6925 100644 --- a/airflow/www/static/js/api/useDatasetsSummary.ts +++ b/airflow/www/static/js/api/useAssetsSummary.ts @@ -42,7 +42,7 @@ interface Props { updatedAfter?: DateOption; } -export default function useDatasetsSummary({ +export default function useAssetsSummary({ limit, offset, order, diff --git a/airflow/www/static/js/api/useCreateDatasetEvent.ts b/airflow/www/static/js/api/useCreateAssetEvent.ts similarity index 77% rename from airflow/www/static/js/api/useCreateDatasetEvent.ts rename to airflow/www/static/js/api/useCreateAssetEvent.ts index f14b35ee375fe..7d2322c33ce9d 100644 --- a/airflow/www/static/js/api/useCreateDatasetEvent.ts +++ b/airflow/www/static/js/api/useCreateAssetEvent.ts @@ -29,22 +29,19 @@ interface Props { uri?: string; } -const createDatasetUrl = getMetaValue("create_dataset_event_api"); +const createAssetUrl = getMetaValue("create_asset_event_api"); -export default function useCreateDatasetEvent({ datasetId, uri }: Props) { +export default function useCreateAssetEvent({ datasetId, uri }: Props) { const queryClient = useQueryClient(); const errorToast = useErrorToast(); return useMutation( - ["createDatasetEvent", uri], - (extra?: API.DatasetEvent["extra"]) => - axios.post( - createDatasetUrl, - { - dataset_uri: uri, - extra: extra || {}, - } - ), + ["createAssetEvent", uri], + (extra?: API.AssetEvent["extra"]) => + 
axios.post(createAssetUrl, { + asset_uri: uri, + extra: extra || {}, + }), { onSuccess: () => { queryClient.invalidateQueries(["datasets-events", datasetId]); diff --git a/airflow/www/static/js/api/useUpstreamDatasetEvents.ts b/airflow/www/static/js/api/useUpstreamAssetEvents.ts similarity index 67% rename from airflow/www/static/js/api/useUpstreamDatasetEvents.ts rename to airflow/www/static/js/api/useUpstreamAssetEvents.ts index 32d1c7aeff2d8..437205501d6c5 100644 --- a/airflow/www/static/js/api/useUpstreamDatasetEvents.ts +++ b/airflow/www/static/js/api/useUpstreamAssetEvents.ts @@ -22,30 +22,30 @@ import { useQuery, UseQueryOptions } from "react-query"; import { getMetaValue } from "src/utils"; import type { - DatasetEventCollection, - GetUpstreamDatasetEventsVariables, + AssetEventCollection, + GetUpstreamAssetEventsVariables, } from "src/types/api-generated"; -interface Props extends GetUpstreamDatasetEventsVariables { - options?: UseQueryOptions; +interface Props extends GetUpstreamAssetEventsVariables { + options?: UseQueryOptions; } -const useUpstreamDatasetEvents = ({ dagId, dagRunId, options }: Props) => { +const useUpstreamAssetEvents = ({ dagId, dagRunId, options }: Props) => { const upstreamEventsUrl = ( - getMetaValue("upstream_dataset_events_api") || - `api/v1/dags/${dagId}/dagRuns/_DAG_RUN_ID_/upstreamDatasetEvents` + getMetaValue("upstream_asset_events_api") || + `api/v1/dags/${dagId}/dagRuns/_DAG_RUN_ID_/upstreamAssetEvents` ).replace("_DAG_RUN_ID_", encodeURIComponent(dagRunId)); - const query = useQuery( - ["upstreamDatasetEvents", dagRunId], + const query = useQuery( + ["upstreamAssetEvents", dagRunId], () => axios.get(upstreamEventsUrl), options ); return { ...query, - data: query.data ?? { datasetEvents: [], totalEntries: 0 }, + data: query.data ?? 
{ assetEvents: [], totalEntries: 0 }, }; }; -export default useUpstreamDatasetEvents; +export default useUpstreamAssetEvents; diff --git a/airflow/www/static/js/components/DatasetEventCard.tsx b/airflow/www/static/js/components/DatasetEventCard.tsx index 2367c8efa9b4a..9dd1ee91e3731 100644 --- a/airflow/www/static/js/components/DatasetEventCard.tsx +++ b/airflow/www/static/js/components/DatasetEventCard.tsx @@ -21,7 +21,7 @@ import React from "react"; import { isEmpty } from "lodash"; import { TbApi } from "react-icons/tb"; -import type { DatasetEvent } from "src/types/api-generated"; +import type { AssetEvent } from "src/types/api-generated"; import { Box, Flex, @@ -43,7 +43,7 @@ import SourceTaskInstance from "./SourceTaskInstance"; import TriggeredDagRuns from "./TriggeredDagRuns"; type CardProps = { - datasetEvent: DatasetEvent; + assetEvent: AssetEvent; showSource?: boolean; showTriggeredDagRuns?: boolean; }; @@ -51,7 +51,7 @@ type CardProps = { const datasetsUrl = getMetaValue("datasets_url"); const DatasetEventCard = ({ - datasetEvent, + assetEvent, showSource = true, showTriggeredDagRuns = true, }: CardProps) => { @@ -60,14 +60,16 @@ const DatasetEventCard = ({ const selectedUri = decodeURIComponent(searchParams.get("uri") || ""); const containerRef = useContainerRef(); - const { from_rest_api: fromRestApi, ...extra } = - datasetEvent?.extra as Record; + const { from_rest_api: fromRestApi, ...extra } = assetEvent?.extra as Record< + string, + string + >; return ( - @@ -111,17 +112,17 @@ const DatasetEventCard = ({ )} - {!!datasetEvent.sourceTaskId && ( - + {!!assetEvent.sourceTaskId && ( + )} )} - {showTriggeredDagRuns && !!datasetEvent?.createdDagruns?.length && ( + {showTriggeredDagRuns && !!assetEvent?.createdDagruns?.length && ( <> Triggered Dag Runs: - + )} diff --git a/airflow/www/static/js/components/SourceTaskInstance.tsx b/airflow/www/static/js/components/SourceTaskInstance.tsx index 4c63198c5f40c..4343d3ce82443 100644 --- a/airflow/www/static/js/components/SourceTaskInstance.tsx +++ b/airflow/www/static/js/components/SourceTaskInstance.tsx @@ -22,7 +22,7 @@ import { Box, Link, Tooltip, Flex } from "@chakra-ui/react"; import { FiLink } from "react-icons/fi"; import { useTaskInstance } from "src/api"; -import type { DatasetEvent } from "src/types/api-generated"; +import type { AssetEvent } from "src/types/api-generated"; import { useContainerRef } from "src/context/containerRef"; import { SimpleStatus } from "src/dag/StatusBox"; import InstanceTooltip from "src/components/InstanceTooltip"; @@ -30,20 +30,16 @@ import type { TaskInstance } from "src/types"; import { getMetaValue } from "src/utils"; type SourceTIProps = { - datasetEvent: DatasetEvent; + assetEvent: AssetEvent; showLink?: boolean; }; const gridUrl = getMetaValue("grid_url"); const dagId = getMetaValue("dag_id") || "__DAG_ID__"; -const SourceTaskInstance = ({ - datasetEvent, - showLink = true, -}: SourceTIProps) => { +const SourceTaskInstance = ({ assetEvent, showLink = true }: SourceTIProps) => { const containerRef = useContainerRef(); - const { sourceDagId, sourceRunId, sourceTaskId, sourceMapIndex } = - datasetEvent; + const { sourceDagId, sourceRunId, sourceTaskId, sourceMapIndex } = assetEvent; const { data: taskInstance } = useTaskInstance({ dagId: sourceDagId || "", diff --git a/airflow/www/static/js/dag/details/dagRun/DatasetTriggerEvents.tsx b/airflow/www/static/js/dag/details/dagRun/DatasetTriggerEvents.tsx index 5fa585830b437..6deedb073e8d9 100644 --- 
a/airflow/www/static/js/dag/details/dagRun/DatasetTriggerEvents.tsx +++ b/airflow/www/static/js/dag/details/dagRun/DatasetTriggerEvents.tsx @@ -19,10 +19,10 @@ import React, { useMemo } from "react"; import { Box, Text } from "@chakra-ui/react"; -import { useUpstreamDatasetEvents } from "src/api"; +import { useUpstreamAssetEvents } from "src/api"; import type { DagRun as DagRunType } from "src/types"; import { CardDef, CardList } from "src/components/Table"; -import type { DatasetEvent } from "src/types/api-generated"; +import type { AssetEvent } from "src/types/api-generated"; import DatasetEventCard from "src/components/DatasetEventCard"; import { getMetaValue } from "src/utils"; @@ -32,17 +32,17 @@ interface Props { const dagId = getMetaValue("dag_id"); -const cardDef: CardDef = { +const cardDef: CardDef = { card: ({ row }) => ( - + ), }; const DatasetTriggerEvents = ({ runId }: Props) => { const { - data: { datasetEvents = [] }, + data: { assetEvents = [] }, isLoading, - } = useUpstreamDatasetEvents({ dagRunId: runId, dagId }); + } = useUpstreamAssetEvents({ dagRunId: runId, dagId }); const columns = useMemo( () => [ @@ -66,7 +66,7 @@ const DatasetTriggerEvents = ({ runId }: Props) => { [] ); - const data = useMemo(() => datasetEvents, [datasetEvents]); + const data = useMemo(() => assetEvents, [assetEvents]); return ( diff --git a/airflow/www/static/js/dag/details/graph/DatasetNode.tsx b/airflow/www/static/js/dag/details/graph/DatasetNode.tsx index bfd288f072dc3..d80f399b032a3 100644 --- a/airflow/www/static/js/dag/details/graph/DatasetNode.tsx +++ b/airflow/www/static/js/dag/details/graph/DatasetNode.tsx @@ -47,11 +47,11 @@ import type { CustomNodeProps } from "./Node"; const datasetsUrl = getMetaValue("datasets_url"); const DatasetNode = ({ - data: { label, height, width, latestDagRunId, isZoomedOut, datasetEvent }, + data: { label, height, width, latestDagRunId, isZoomedOut, assetEvent }, }: NodeProps) => { const containerRef = useContainerRef(); - const { from_rest_api: fromRestApi } = (datasetEvent?.extra || {}) as Record< + const { from_rest_api: fromRestApi } = (assetEvent?.extra || {}) as Record< string, string >; @@ -61,8 +61,8 @@ const DatasetNode = ({ Dataset - {!!datasetEvent && ( + {!!assetEvent && ( {/* @ts-ignore */} - {moment(datasetEvent.timestamp).fromNow()} + {moment(assetEvent.timestamp).fromNow()} )} @@ -120,23 +120,23 @@ const DatasetNode = ({ {label} - {!!datasetEvent && ( + {!!assetEvent && ( -