diff --git a/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml b/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml index 6a715fec57501..a42dd97d4bdfb 100644 --- a/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml +++ b/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml @@ -94,6 +94,7 @@ body: - presto - qdrant - redis + - remote - salesforce - samba - segment diff --git a/INSTALL b/INSTALL index 78506f9a571fc..7b3cac3282c2e 100644 --- a/INSTALL +++ b/INSTALL @@ -277,9 +277,9 @@ cncf.kubernetes, cohere, common.compat, common.io, common.sql, databricks, datad dingding, discord, docker, elasticsearch, exasol, fab, facebook, ftp, github, google, grpc, hashicorp, http, imap, influxdb, jdbc, jenkins, microsoft.azure, microsoft.mssql, microsoft.psrp, microsoft.winrm, mongo, mysql, neo4j, odbc, openai, openfaas, openlineage, opensearch, opsgenie, -oracle, pagerduty, papermill, pgvector, pinecone, postgres, presto, qdrant, redis, salesforce, -samba, segment, sendgrid, sftp, singularity, slack, smtp, snowflake, sqlite, ssh, tableau, tabular, -telegram, teradata, trino, vertica, weaviate, yandex, ydb, zendesk +oracle, pagerduty, papermill, pgvector, pinecone, postgres, presto, qdrant, redis, remote, +salesforce, samba, segment, sendgrid, sftp, singularity, slack, smtp, snowflake, sqlite, ssh, +tableau, tabular, telegram, teradata, trino, vertica, weaviate, yandex, ydb, zendesk # END PROVIDER EXTRAS HERE diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index d2b952c5b716d..d8e69cef68ea5 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -39,7 +39,7 @@ @functools.lru_cache -def _initialize_map() -> dict[str, Callable]: +def initialize_method_map() -> dict[str, Callable]: from airflow.cli.commands.task_command import _get_ti_db_access from airflow.dag_processing.manager import DagFileProcessorManager from airflow.dag_processing.processor import DagFileProcessor @@ -148,7 +148,7 @@ def internal_airflow_api(body: dict[str, Any]) -> APIResponse: if json_rpc != "2.0": return log_and_build_error_response(message="Expected jsonrpc 2.0 request.", status=400) - methods_map = _initialize_map() + methods_map = initialize_method_map() method_name = body.get("method") if method_name not in methods_map: return log_and_build_error_response(message=f"Unrecognized method: {method_name}.", status=400) diff --git a/airflow/api_internal/internal_api_call.py b/airflow/api_internal/internal_api_call.py index c3a67d03ee18c..e9e665dadc8ce 100644 --- a/airflow/api_internal/internal_api_call.py +++ b/airflow/api_internal/internal_api_call.py @@ -21,10 +21,11 @@ import json import logging from functools import wraps -from typing import Callable, TypeVar +from typing import TYPE_CHECKING, Callable, TypeVar import requests import tenacity +from requests.auth import HTTPBasicAuth from urllib3.exceptions import NewConnectionError from airflow.configuration import conf @@ -32,6 +33,9 @@ from airflow.settings import _ENABLE_AIP_44 from airflow.typing_compat import ParamSpec +if TYPE_CHECKING: + from requests.auth import AuthBase + PS = ParamSpec("PS") RT = TypeVar("RT") @@ -44,6 +48,7 @@ class InternalApiConfig: _initialized = False _use_internal_api = False _internal_api_endpoint = "" + _internal_api_auth: AuthBase | None = None @staticmethod def force_database_direct_access(): @@ -56,6 +61,19 @@ def force_database_direct_access(): InternalApiConfig._initialized = True 
InternalApiConfig._use_internal_api = False + @staticmethod + def force_api_access(api_endpoint: str, auth: AuthBase): + """ + Force using Internal API with provided endpoint. + + All methods decorated with internal_api_call will always be executed remote/via API. + This mode is needed for remote setups/remote executor. + """ + InternalApiConfig._initialized = True + InternalApiConfig._use_internal_api = True + InternalApiConfig._internal_api_endpoint = api_endpoint + InternalApiConfig._internal_api_auth = auth + @staticmethod def get_use_internal_api(): if not InternalApiConfig._initialized: @@ -68,21 +86,31 @@ def get_internal_api_endpoint(): InternalApiConfig._init_values() return InternalApiConfig._internal_api_endpoint + @staticmethod + def get_auth() -> AuthBase | None: + return InternalApiConfig._internal_api_auth + @staticmethod def _init_values(): use_internal_api = conf.getboolean("core", "database_access_isolation", fallback=False) if use_internal_api and not _ENABLE_AIP_44: raise RuntimeError("The AIP_44 is not enabled so you cannot use it.") - internal_api_endpoint = "" if use_internal_api: - internal_api_url = conf.get("core", "internal_api_url") - internal_api_endpoint = internal_api_url + "/internal_api/v1/rpcapi" - if not internal_api_endpoint.startswith("http://"): - raise AirflowConfigException("[core]internal_api_url must start with http://") + internal_api_endpoint = conf.get("core", "internal_api_url") + if internal_api_endpoint.find("/", 8) == -1: + internal_api_endpoint = internal_api_endpoint + "/internal_api/v1/rpcapi" + if not internal_api_endpoint.startswith("http://") and not internal_api_endpoint.startswith( + "https://" + ): + raise AirflowConfigException("[core]internal_api_url must start with http:// or https://") + InternalApiConfig._internal_api_endpoint = internal_api_endpoint + internal_api_user = conf.get("core", "internal_api_user") + internal_api_password = conf.get("core", "internal_api_password") + if internal_api_user and internal_api_password: + InternalApiConfig._internal_api_auth = HTTPBasicAuth(internal_api_user, internal_api_password) InternalApiConfig._initialized = True InternalApiConfig._use_internal_api = use_internal_api - InternalApiConfig._internal_api_endpoint = internal_api_endpoint def internal_api_call(func: Callable[PS, RT]) -> Callable[PS, RT]: @@ -112,7 +140,8 @@ def internal_api_call(func: Callable[PS, RT]) -> Callable[PS, RT]: def make_jsonrpc_request(method_name: str, params_json: str) -> bytes: data = {"jsonrpc": "2.0", "method": method_name, "params": params_json} internal_api_endpoint = InternalApiConfig.get_internal_api_endpoint() - response = requests.post(url=internal_api_endpoint, data=json.dumps(data), headers=headers) + auth = InternalApiConfig.get_auth() + response = requests.post(url=internal_api_endpoint, data=json.dumps(data), headers=headers, auth=auth) if response.status_code != 200: raise AirflowException( f"Got {response.status_code}:{response.reason} when sending " diff --git a/airflow/example_dags/integration_test.py b/airflow/example_dags/integration_test.py new file mode 100644 index 0000000000000..0baa78a6b28b8 --- /dev/null +++ b/airflow/example_dags/integration_test.py @@ -0,0 +1,126 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +In this DAG all critical functions as integration test are contained. + +The DAG should work in all standard setups without error. +""" + +from __future__ import annotations + +from datetime import datetime + +from airflow.decorators import task, task_group +from airflow.exceptions import AirflowNotFoundException +from airflow.hooks.base import BaseHook +from airflow.models.dag import DAG +from airflow.models.param import Param +from airflow.models.variable import Variable +from airflow.operators.bash import BashOperator +from airflow.operators.empty import EmptyOperator +from airflow.operators.python import PythonOperator + +with DAG( + dag_id="integration_test", + dag_display_name="Integration Test", + description=__doc__.partition(".")[0], + doc_md=__doc__, + schedule=None, + start_date=datetime(2024, 7, 1), + tags=["example", "params", "integration test"], + params={ + "mapping_count": Param( + 4, + type="integer", + title="Mapping Count", + description="Amount of tasks that should be mapped", + ), + }, +) as dag: + + @task + def my_setup(): + print("Assume this is a setup task") + + @task + def mapping_from_params(**context) -> list[int]: + mapping_count: int = context["params"]["mapping_count"] + return list(range(1, mapping_count + 1)) + + @task + def add_one(x: int): + return x + 1 + + @task + def sum_it(values): + total = sum(values) + print(f"Total was {total}") + + @task_group(prefix_group_id=False) + def mapping_task_group(): + added_values = add_one.expand(x=mapping_from_params()) + sum_it(added_values) + + @task.branch + def branching(): + return ["bash", "virtualenv", "variable", "connection", "classic_bash", "classic_python"] + + @task.bash + def bash(): + return "echo hello world" + + @task.virtualenv(requirements="numpy") + def virtualenv(): + import numpy + + print(f"Welcome to virtualenv with numpy version {numpy.__version__}.") + + @task + def variable(): + Variable.set("integration_test_key", "value") + assert Variable.get("integration_test_key") == "value" # noqa: S101 + Variable.delete("integration_test_key") + + @task + def connection(): + try: + conn = BaseHook.get_connection("integration_test") + print(f"Got connection {conn}") + except AirflowNotFoundException: + print("Connection not found... 
but also OK.") + + @task_group(prefix_group_id=False) + def standard_tasks_group(): + classic_bash = BashOperator( + task_id="classic_bash", bash_command="echo Parameter is {{ params.mapping_count }}" + ) + + empty = EmptyOperator(task_id="not_executed") + + def python_call(): + print("Hello world") + + classic_py = PythonOperator(task_id="classic_python", python_callable=python_call) + + branching() >> [bash(), virtualenv(), variable(), connection(), classic_bash, classic_py, empty] + + @task + def my_teardown(): + print("Assume this is a teardown task") + + my_setup().as_setup() >> [mapping_task_group(), standard_tasks_group()] >> my_teardown().as_teardown() diff --git a/airflow/executors/executor_constants.py b/airflow/executors/executor_constants.py index 4e4923beb477b..f7cef9dbc1ea4 100644 --- a/airflow/executors/executor_constants.py +++ b/airflow/executors/executor_constants.py @@ -34,6 +34,7 @@ class ConnectorSource(Enum): CELERY_EXECUTOR = "CeleryExecutor" CELERY_KUBERNETES_EXECUTOR = "CeleryKubernetesExecutor" KUBERNETES_EXECUTOR = "KubernetesExecutor" +REMOTE_EXECUTOR = "RemoteExecutor" DEBUG_EXECUTOR = "DebugExecutor" MOCK_EXECUTOR = "MockExecutor" CORE_EXECUTOR_NAMES = { @@ -43,6 +44,7 @@ class ConnectorSource(Enum): CELERY_EXECUTOR, CELERY_KUBERNETES_EXECUTOR, KUBERNETES_EXECUTOR, + REMOTE_EXECUTOR, DEBUG_EXECUTOR, MOCK_EXECUTOR, } diff --git a/airflow/executors/executor_loader.py b/airflow/executors/executor_loader.py index 31b9a369bc3fc..fe61590abcb21 100644 --- a/airflow/executors/executor_loader.py +++ b/airflow/executors/executor_loader.py @@ -34,6 +34,7 @@ KUBERNETES_EXECUTOR, LOCAL_EXECUTOR, LOCAL_KUBERNETES_EXECUTOR, + REMOTE_EXECUTOR, SEQUENTIAL_EXECUTOR, ConnectorSource, ) @@ -70,6 +71,7 @@ class ExecutorLoader: "executors.celery_kubernetes_executor.CeleryKubernetesExecutor", KUBERNETES_EXECUTOR: "airflow.providers.cncf.kubernetes." "executors.kubernetes_executor.KubernetesExecutor", + REMOTE_EXECUTOR: "airflow.providers.remote.executors.RemoteExecutor", DEBUG_EXECUTOR: "airflow.executors.debug_executor.DebugExecutor", } @@ -334,6 +336,9 @@ def validate_database_executor_compatibility(cls, executor: type[BaseExecutor]) if InternalApiConfig.get_use_internal_api(): return + if executor.__name__ == REMOTE_EXECUTOR: + return + from airflow.settings import engine # SQLite only works with single threaded executors diff --git a/airflow/providers/remote/CHANGELOG.rst b/airflow/providers/remote/CHANGELOG.rst new file mode 100644 index 0000000000000..5eaf0ae9d514a --- /dev/null +++ b/airflow/providers/remote/CHANGELOG.rst @@ -0,0 +1,38 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +.. 
NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + +``apache-airflow-providers-remote-executor`` + + +Changelog +--------- + +0.1.0 +..... + +|experimental| + +Initial version of the provider. + +.. note:: + This provider is currently experimental diff --git a/airflow/providers/remote/TODO.md b/airflow/providers/remote/TODO.md new file mode 100644 index 0000000000000..d56ca419a1513 --- /dev/null +++ b/airflow/providers/remote/TODO.md @@ -0,0 +1,133 @@ + + +# Implementation Status + +https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-69+Remote+Executor + +## Primary Functionality - MVP + +- [x] Model -> DB Table creation +- [ ] Breeze + - [x] Support RemoteExecutor + - [ ] Start Remote Worker panel + - [x] Hatch dynamically load plugin(s) +- [x] Bootstrap Provider Package +- [ ] Executor class + - [x] Writes new jobs + - [ ] Acknowledge success/fail + - [ ] Can terminate a job + - [ ] Archiving of job table +- [ ] Plugin + - [x] REST API + - [ ] REST API Authentication + - [x] Expose jobs via UI + - [x] Expose active remote worker +- [ ] Remote Worker + - [x] CLI + - [x] Get a job and execute it + - [x] Report result + - [x] Heartbeat + - [x] Queues + - [x] Retry on connection loss + - [x] Send logs (basic) + - [ ] Send logs also to TaskFileHandler + Archive logs on completions (DB) + - [ ] Can terminate job + - [ ] Check version match + - [ ] Handle SIG-INT/CTRL+C and gracefully terminate and complete job + - [ ] Add a stop command +- [ ] Web UI + - [ ] Show logs while executing + - [x] Show logs after completion +- [x] Configurability +- [x] Documentation + - [x] References in Airflow core + - [x] Provider Package docs +- [ ] Tests + - [ ] Pytest + - [ ] Breeze integration tests in Github +- [x] Known problems + - [x] AIP-44 related + - [x] Mapped Tasks? + - [x] Branching Operator/SkipMixin + - [x] RenderedTaskinstanceFields + - [x] Templated Fields? 
+ - [x] AIP-44 Integration Tests +- [ ] AIP-69 + - [x] Draft + - [x] Update specs + - [ ] Vote + +## Future Feature Collection + +- [ ] Support for API token on top of normal auth +- [ ] API token per worker +- [ ] Plugin + - [ ] Overview about queues + - [ ] Allow starting REST API separate + - [ ] Administrative maintenance / temporary disable jobs on worker +- [ ] Remote Worker + - [x] Multiple jobs / concurrency + - [ ] Publish system metrics with heartbeats + - [ ] Integration into telemetry to send metrics +- [ ] API token provisioning can be automated +- [ ] Test/Support on Windows +- [ ] Scaling test +- [ ] Airflow 3 / AIP-72 Migration + - [ ] Thin deployment + - [ ] DAG Code push (no need to GIT Sync) + - [ ] Move task context generation from Remote to Executor ("Need to know", depends on Task Execution API) + +## Notes + +### Test on Windows + +Create wheel on Linux: + +``` bash +breeze release-management generate-constraints --python 3.10 +breeze release-management prepare-provider-packages --package-format wheel --include-removed-providers remote +breeze release-management prepare-airflow-package +``` + +Copy the files to Windows + +On Windows "cheat sheet", Assume Python 3.10 installed, files mounted in Z:\Temp: + +``` text +python -m venv airflow-venv +airflow-venv\Scripts\activate.bat + +pip install --constraint Z:\temp\constraints-source-providers-3.10.txt Z:\temp\apache_airflow_providers_remote-0.1.0-py3-none-any.whl Z:\temp\apache_airflow-2.10.0.dev0-py3-none-any.whl + +set AIRFLOW_ENABLE_AIP_44=true +set AIRFLOW__CORE__DATABASE_ACCESS_ISOLATION=True +set AIRFLOW__CORE__INTERNAL_API_URL=http://nas:8080/remote_worker/v1/rpcapi +set AIRFLOW__SCHEDULER__SCHEDULE_AFTER_TASK_EXECUTION=False +set AIRFLOW__CORE__EXECUTOR=RemoteExecutor +set AIRFLOW__CORE__DAGS_FOLDER=dags +set AIRFLOW__LOGGING__BASE_LOG_FOLDER=logs + +airflow remote worker --concurrency 4 --queues windows +``` + +Notes on Windows: + +- PR https://github.com/apache/airflow/pull/40424 fixes PythonOperator +- Log folder temple must replace run_id colons _or_ DAG must be triggered with Run ID w/o colons as not allowed as file name in Windows diff --git a/airflow/providers/remote/__init__.py b/airflow/providers/remote/__init__.py new file mode 100644 index 0000000000000..92b5989e4627d --- /dev/null +++ b/airflow/providers/remote/__init__.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE +# OVERWRITTEN WHEN PREPARING DOCUMENTATION FOR THE PACKAGES. 
+# +# IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE +# `PROVIDER__INIT__PY_TEMPLATE.py.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY +# +from __future__ import annotations + +import packaging.version + +from airflow import __version__ as airflow_version + +__all__ = ["__version__"] + +__version__ = "0.1.0" + +if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse( + "2.10.0" +): + raise RuntimeError( + f"The package `apache-airflow-providers-remote-executor:{__version__}` needs Apache Airflow 2.10.0+" + ) diff --git a/airflow/providers/remote/_start_remote_worker.sh b/airflow/providers/remote/_start_remote_worker.sh new file mode 100755 index 0000000000000..8c9190a85d097 --- /dev/null +++ b/airflow/providers/remote/_start_remote_worker.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +unset AIRFLOW__DATABASE__SQL_ALCHEMY_CONN +unset AIRFLOW__CELERY__RESULT_BACKEND +unset POSTGRES_HOST_PORT +unset BACKEND +unset POSTGRES_VERSION +unset DATABASE_ISOLATION + +export AIRFLOW_ENABLE_AIP_44=true +export AIRFLOW__REMOTE__API_ENABLED=true +export AIRFLOW__REMOTE__API_URL=http://localhost:8080/remote_worker/v1/rpcapi +export AIRFLOW__SCHEDULER__SCHEDULE_AFTER_TASK_EXECUTION=False + +# Ensure logs are smelling like remote and are not visible to other components +export AIRFLOW__LOGGING__BASE_LOG_FOLDER=remote_logs + +airflow remote worker --concurrency 8 --user admin --password admin + +# Eventually start with: +# airflow remote worker --concurrency 8 --queues remote + + +# Note: Webserver must be started with: +# AIRFLOW__API__AUTH_BACKENDS=airflow.providers.fab.auth_manager.api.auth.backend.basic_auth AIRFLOW__REMOTE__API_ENABLED=true airflow webserver diff --git a/airflow/providers/remote/api_endpoints/__init__.py b/airflow/providers/remote/api_endpoints/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/remote/api_endpoints/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/remote/api_endpoints/health_endpoint.py b/airflow/providers/remote/api_endpoints/health_endpoint.py new file mode 100644 index 0000000000000..a6c8a9c7950da --- /dev/null +++ b/airflow/providers/remote/api_endpoints/health_endpoint.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + + +def health(): + return {} diff --git a/airflow/providers/remote/api_endpoints/rpc_api_endpoint.py b/airflow/providers/remote/api_endpoints/rpc_api_endpoint.py new file mode 100644 index 0000000000000..2c29911e5ecb3 --- /dev/null +++ b/airflow/providers/remote/api_endpoints/rpc_api_endpoint.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
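
The worker-facing RPC endpoint defined below mirrors the core internal API: it accepts JSON-RPC 2.0 bodies whose `method` is the dotted path of a mapped callable and whose `params` are BaseSerialization-encoded keyword arguments. A minimal sketch of such a request, assuming a local webserver and the `admin`/`admin` basic-auth credentials used elsewhere in this PR (the method name and the empty `params` are purely illustrative; real calls go through `InternalApiConfig`/`make_jsonrpc_request`):

```python
import json

import requests
from requests.auth import HTTPBasicAuth

payload = {
    "jsonrpc": "2.0",
    # Dotted "<module>.<qualname>" path of a callable registered in the method map (illustrative).
    "method": "airflow.providers.remote.models.remote_worker.RemoteWorker.set_state",
    "params": {},  # normally produced by BaseSerialization.serialize(..., use_pydantic_models=True)
}
response = requests.post(
    url="http://localhost:8080/remote_worker/v1/rpcapi",  # assumed local setup
    data=json.dumps(payload),
    headers={"Content-Type": "application/json"},
    auth=HTTPBasicAuth("admin", "admin"),
)
response.raise_for_status()
```
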
+ +from __future__ import annotations + +import functools +import json +import logging +from typing import TYPE_CHECKING, Any, Callable +from uuid import uuid4 + +from flask import Response + +from airflow.api_connexion.security import requires_access_custom_view +from airflow.serialization.serialized_objects import BaseSerialization +from airflow.utils.session import create_session + +if TYPE_CHECKING: + from airflow.api_connexion.types import APIResponse + +log = logging.getLogger(__name__) + +REMOTE_WORKER_API_ROLE = "Remote Worker API" + + +@functools.lru_cache +def _initialize_method_map() -> dict[str, Callable]: + from airflow.api_internal.endpoints.rpc_api_endpoint import initialize_method_map + from airflow.providers.remote.models.remote_job import RemoteJob + from airflow.providers.remote.models.remote_logs import RemoteLogs + from airflow.providers.remote.models.remote_worker import RemoteWorker + + internal_api_functions = initialize_method_map().values() + functions: list[Callable] = [ + # TODO Trim down functions really needed by remote worker / rebase from AIP-44 + *internal_api_functions, + # Additional things from Remote Executor + RemoteJob.reserve_task, + RemoteJob.set_state, + RemoteLogs.push_logs, + RemoteWorker.register_worker, + RemoteWorker.set_state, + ] + return {f"{func.__module__}.{func.__qualname__}": func for func in functions} + + +def log_and_build_error_response(message, status): + error_id = uuid4() + server_message = message + f" error_id={error_id}" + log.exception(server_message) + client_message = message + f" The server side traceback may be identified with error_id={error_id}" + return Response(response=client_message, status=status) + + +@requires_access_custom_view("POST", REMOTE_WORKER_API_ROLE) +def remote_worker_api(body: dict[str, Any]) -> APIResponse: + """Handle Remote Worker API `/remote_worker/v1/rpcapi` endpoint.""" + log.debug("Got request") + json_rpc = body.get("jsonrpc") + if json_rpc != "2.0": + return log_and_build_error_response(message="Expected jsonrpc 2.0 request.", status=400) + + methods_map = _initialize_method_map() + method_name = body.get("method") + if method_name not in methods_map: + return log_and_build_error_response(message=f"Unrecognized method: {method_name}.", status=400) + + handler = methods_map[method_name] + params = {} + try: + if body.get("params"): + params_json = body.get("params") + params = BaseSerialization.deserialize(params_json, use_pydantic_models=True) + except Exception: + return log_and_build_error_response(message="Error deserializing parameters.", status=400) + + log.debug("Calling method %s\nparams: %s", method_name, params) + try: + # Session must be created there as it may be needed by serializer for lazy-loaded fields. + with create_session() as session: + output = handler(**params, session=session) + output_json = BaseSerialization.serialize(output, use_pydantic_models=True) + response = json.dumps(output_json) if output_json is not None else None + return Response(response=response, headers={"Content-Type": "application/json"}) + except Exception: + return log_and_build_error_response(message=f"Error executing method '{method_name}'.", status=500) diff --git a/airflow/providers/remote/cli/__init__.py b/airflow/providers/remote/cli/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/remote/cli/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/remote/cli/remote_command.py b/airflow/providers/remote/cli/remote_command.py new file mode 100644 index 0000000000000..3801664f96fd2 --- /dev/null +++ b/airflow/providers/remote/cli/remote_command.py @@ -0,0 +1,235 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import logging +import os +import platform +import signal +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from subprocess import Popen +from time import sleep + +from requests.auth import HTTPBasicAuth + +from airflow.api_internal.internal_api_call import InternalApiConfig +from airflow.cli.cli_config import ARG_VERBOSE, ActionCommand, Arg +from airflow.configuration import conf +from airflow.exceptions import AirflowException +from airflow.providers.remote.models.remote_job import RemoteJob +from airflow.providers.remote.models.remote_logs import RemoteLogs +from airflow.providers.remote.models.remote_worker import RemoteWorker, RemoteWorkerState +from airflow.utils import cli as cli_utils +from airflow.utils.state import TaskInstanceState + +logger = logging.getLogger(__name__) + + +def _hostname() -> str: + if platform.system() == "Windows": + return platform.uname().node + else: + return os.uname()[1] + + +@dataclass +class Job: + """Holds all information for a task/job to be executed as bundle.""" + + remote_job: RemoteJob + process: Popen + logfile: Path + logsize: int + """Last size of log file, point of last chunk push.""" + + +@dataclass +class WorkerState: + """Holds all information about the state of the worker instance.""" + + hostname: str + queues: list[str] | None + drain_worker: bool + api_url: str + user: str + password: str + + +def _fetch_job(state: WorkerState, jobs: list[Job]) -> bool: + """Fetch and start a new job from central site.""" + logger.debug("Attempting to fetch a new job...") + remote_job = RemoteJob.reserve_task(state.hostname, state.queues) + if remote_job: + logger.info("Received job: %s", remote_job) + env = os.environ.copy() + 
env["AIRFLOW__CORE__DATABASE_ACCESS_ISOLATION"] = "True" + env["AIRFLOW__CORE__INTERNAL_API_URL"] = state.api_url + env["AIRFLOW__CORE__INTERNAL_API_USER"] = state.user + env["AIRFLOW__CORE__INTERNAL_API_PASSWORD"] = state.password + process = Popen(remote_job.command, close_fds=True, env=env) + logfile = RemoteLogs.logfile_path( + remote_job.dag_id, + remote_job.run_id, + remote_job.task_id, + remote_job.map_index, + remote_job.try_number, + ) + jobs.append(Job(remote_job, process, logfile, 0)) + RemoteJob.set_state(remote_job.key, TaskInstanceState.RUNNING) + return True + + logger.info("No new job to process%s", f", {len(jobs)} still running" if jobs else "") + return False + + +def _check_running_jobs(jobs: list[Job]) -> None: + """Check which of the running tasks/jobs are completed and report back.""" + for i in range(len(jobs) - 1, -1, -1): + job = jobs[i] + job.process.poll() + if job.process.returncode is not None: + jobs.remove(job) + if job.process.returncode == 0: + logger.info("Job completed: %s", job.remote_job) + RemoteJob.set_state(job.remote_job.key, TaskInstanceState.SUCCESS) + else: + logger.error("Job failed: %s", job.remote_job) + RemoteJob.set_state(job.remote_job.key, TaskInstanceState.FAILED) + if job.logfile.exists() and job.logfile.stat().st_size > job.logsize: + with job.logfile.open("r") as logfile: + logfile.seek(job.logsize, os.SEEK_SET) + logdata = logfile.read() + RemoteLogs.push_logs( + dag_id=job.remote_job.dag_id, + run_id=job.remote_job.run_id, + task_id=job.remote_job.task_id, + map_index=job.remote_job.map_index, + try_number=job.remote_job.try_number, + log_chunk_time=datetime.now(), + log_chunk_data=logdata, + ) + job.logsize += len(logdata) + + +def _heartbeat(state: WorkerState, jobs: list[Job]) -> None: + """Report liveness state of worker to central site with stats.""" + reported_state = ( + (RemoteWorkerState.TERMINATING if state.drain_worker else RemoteWorkerState.RUNNING) + if jobs + else RemoteWorkerState.IDLE + ) + RemoteWorker.set_state(state.hostname, reported_state, len(jobs), {}) + + +@cli_utils.action_cli(check_db=False) +def worker(args): + """Start Airflow Remote worker.""" + worker_state = WorkerState( + hostname=args.remote_hostname or _hostname(), + queues=args.queues.split(",") if args.queues else None, + drain_worker=False, + api_url=conf.get("remote", "api_url"), + user=args.user or conf.get("remote", "user"), + password=args.password or conf.get("remote", "password"), + ) + job_poll_interval = conf.getint("remote", "job_poll_interval") + heartbeat_interval = conf.getint("remote", "heartbeat_interval") + if not worker_state.api_url: + raise SystemExit("Error: API URL is not configured, please correct configuration.") + logger.info("Starting worker with API endpoint %s", worker_state.api_url) + InternalApiConfig.force_api_access( + worker_state.api_url, HTTPBasicAuth(worker_state.user, worker_state.password) + ) + + concurrency: int = args.concurrency + jobs: list[Job] = [] + new_job = False + try: + last_heartbeat = RemoteWorker.register_worker( + worker_state.hostname, RemoteWorkerState.STARTING, worker_state.queues + ).last_update + except AirflowException as e: + if "403" in str(e): + raise SystemExit(f"Error: API endpoint authentication failed, check your credentials: {e}") + raise SystemExit("Error: API endpoint is not ready, please set [remote] api_enabled=True.") + + def signal_handler(sig, frame): + logger.info("Request to show down remote worker received, waiting for jobs to complete.") + worker_state.drain_worker = True 
+ + signal.signal(signal.SIGINT, signal_handler) + + while not worker_state.drain_worker or jobs: + if not worker_state.drain_worker and len(jobs) < concurrency: + new_job = _fetch_job(worker_state, jobs) + _check_running_jobs(jobs) + + if ( + worker_state.drain_worker + or datetime.now().timestamp() - last_heartbeat.timestamp() > heartbeat_interval + ): + _heartbeat(worker_state, jobs) + last_heartbeat = datetime.now() + + if not new_job: + sleep(job_poll_interval) + new_job = False + + logger.info("Quitting worker, signal being offline.") + RemoteWorker.set_state(worker_state.hostname, RemoteWorkerState.OFFLINE, 0, {}) + + +ARG_CONCURRENCY = Arg( + ("-c", "--concurrency"), + type=int, + help="The number of worker processes", + default=1, +) +ARG_QUEUES = Arg( + ("-q", "--queues"), + help="Comma delimited list of queues to serve, serve all queues if not provided.", +) +ARG_REMOTE_HOSTNAME = Arg( + ("-H", "--remote-hostname"), + help="Set the hostname of worker if you have multiple workers on a single machine", +) +ARG_REMOTE_USER = Arg( + ("-u", "--user"), + help="Username to authenticate against Remote Worker API endpoint", +) +ARG_REMOTE_PASSWORD = Arg( + ("-p", "--password"), + help="Password to authenticate against Remote Worker API endpoint", +) +REMOTE_COMMANDS: list[ActionCommand] = [ + ActionCommand( + name=worker.__name__, + help=worker.__doc__, + func=worker, + args=( + ARG_CONCURRENCY, + ARG_QUEUES, + ARG_REMOTE_HOSTNAME, + ARG_REMOTE_USER, + ARG_REMOTE_PASSWORD, + ARG_VERBOSE, + ), + ), +] diff --git a/airflow/providers/remote/executors/__init__.py b/airflow/providers/remote/executors/__init__.py new file mode 100644 index 0000000000000..d42ad286855f3 --- /dev/null +++ b/airflow/providers/remote/executors/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflow.providers.remote.executors.remote_executor import RemoteExecutor + +__all__ = ["RemoteExecutor"] diff --git a/airflow/providers/remote/executors/remote_executor.py b/airflow/providers/remote/executors/remote_executor.py new file mode 100644 index 0000000000000..4ca19ee00bb3e --- /dev/null +++ b/airflow/providers/remote/executors/remote_executor.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from airflow.cli.cli_config import GroupCommand +from airflow.executors.base_executor import BaseExecutor +from airflow.models.abstractoperator import DEFAULT_QUEUE +from airflow.models.taskinstance import TaskInstanceState +from airflow.providers.remote.cli.remote_command import REMOTE_COMMANDS +from airflow.providers.remote.models import RemoteJobModel +from airflow.providers.remote.models.remote_logs import RemoteLogsModel +from airflow.providers.remote.models.remote_worker import RemoteWorkerModel +from airflow.utils.db import DBLocks, create_global_lock +from airflow.utils.session import NEW_SESSION, provide_session + +if TYPE_CHECKING: + import argparse + + from sqlalchemy.orm import Session + + from airflow.executors.base_executor import CommandType + from airflow.models.taskinstance import TaskInstance + from airflow.models.taskinstancekey import TaskInstanceKey + + +class RemoteExecutor(BaseExecutor): + """Implementation of the remote executor to distribute work to remote workers via HTTP.""" + + @provide_session + def start(self, session: Session = NEW_SESSION): + """If Remote Executor provider is loaded first time, ensure table exists.""" + with create_global_lock(session=session, lock=DBLocks.MIGRATIONS): + engine = session.get_bind().engine + RemoteJobModel.metadata.create_all(engine) + RemoteLogsModel.metadata.create_all(engine) + RemoteWorkerModel.metadata.create_all(engine) + + @provide_session + def execute_async( + self, + key: TaskInstanceKey, + command: CommandType, + queue: str | None = None, + executor_config: Any | None = None, + session: Session = NEW_SESSION, + ) -> None: + """Execute asynchronously.""" + self.validate_airflow_tasks_run_command(command) + session.add( + RemoteJobModel( + dag_id=key.dag_id, + task_id=key.task_id, + run_id=key.run_id, + map_index=key.map_index, + try_number=key.try_number, + state=TaskInstanceState.QUEUED, + queue=queue or DEFAULT_QUEUE, + command=str(command), + ) + ) + + def sync(self) -> None: + """Sync will get called periodically by the heartbeat method.""" + # TODO This might be used to clean the task table and upload completed logs + + def get_task_log(self, ti: TaskInstance, try_number: int) -> tuple[list[str], list[str]]: + """ + Return the task logs. + + :param ti: A TaskInstance object + :param try_number: current try_number to read log from + :return: tuple of logs and messages + """ + return [], [] + + def end(self) -> None: + """End the executor.""" + self.log.info("Shutting down RemoteExecutor") + + def terminate(self): + """Terminate the executor is not doing anything.""" + + def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]: # pragma: no cover + """ + Handle remnants of tasks that were failed because they were stuck in queued. + + Tasks can get stuck in queued. If such a task is detected, it will be marked + as `UP_FOR_RETRY` if the task instance has remaining retries or marked as `FAILED` + if it doesn't. 
+ + :param tis: List of Task Instances to clean up + :return: List of readable task instances for a warning message + """ + raise NotImplementedError() + + @staticmethod + def get_cli_commands() -> list[GroupCommand]: + return [ + GroupCommand( + name="remote", + help="Remote worker components", + description=( + "Start and manage remote worker. Works only when using RemoteExecutor. For more information, " + "see https://airflow.apache.org/docs/apache-airflow-providers-remote/stable/remote_executor.html" + ), + subcommands=REMOTE_COMMANDS, + ), + ] + + +def _get_parser() -> argparse.ArgumentParser: + """ + Generate documentation; used by Sphinx. + + :meta private: + """ + return RemoteExecutor._get_parser() diff --git a/airflow/providers/remote/models/__init__.py b/airflow/providers/remote/models/__init__.py new file mode 100644 index 0000000000000..7d1d709cf8da4 --- /dev/null +++ b/airflow/providers/remote/models/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.remote.models.remote_job import RemoteJobModel +from airflow.providers.remote.models.remote_logs import RemoteLogsModel +from airflow.providers.remote.models.remote_worker import RemoteWorkerModel + +__all__ = ["RemoteJobModel", "RemoteLogsModel", "RemoteWorkerModel"] diff --git a/airflow/providers/remote/models/remote_job.py b/airflow/providers/remote/models/remote_job.py new file mode 100644 index 0000000000000..93afd5d125758 --- /dev/null +++ b/airflow/providers/remote/models/remote_job.py @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
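
`execute_async` above persists the task's `CommandType` (a list of CLI arguments) as its string representation in the 1000-character `command` column; the worker side recovers it with `ast.literal_eval` in `RemoteJob.reserve_task` below. A small round-trip sketch with an illustrative command:

```python
from ast import literal_eval

# Illustrative command list as the executor would receive it (CommandType is a list of strings).
command = ["airflow", "tasks", "run", "integration_test", "bash", "manual__2024-07-01T00:00:00+00:00"]

stored = str(command)            # what RemoteExecutor.execute_async writes to RemoteJobModel.command
restored = literal_eval(stored)  # what RemoteJob.reserve_task hands back to the worker

assert restored == command
```

Using `literal_eval` rather than `eval` keeps the column a plain string while ensuring only Python literals can come back out of it.
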
+from __future__ import annotations + +from ast import literal_eval +from datetime import datetime +from typing import TYPE_CHECKING, List, Optional + +from sqlalchemy import ( + Column, + Index, + Integer, + String, + select, + text, +) + +from airflow.api_internal.internal_api_call import internal_api_call +from airflow.models.base import Base, StringID +from airflow.models.taskinstancekey import TaskInstanceKey +from airflow.serialization.serialized_objects import add_pydantic_class_type_mapping +from airflow.utils import timezone +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.pydantic import BaseModel as BaseModelPydantic, ConfigDict, is_pydantic_2_installed +from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks +from airflow.utils.state import TaskInstanceState + +if TYPE_CHECKING: + from sqlalchemy.orm.session import Session + + +class RemoteJobModel(Base, LoggingMixin): + """ + A job which is queued, waiting or running on a Remote Worker. + + Each tuple in the database represents and describes the state of one job. + """ + + __tablename__ = "remote_job" + dag_id = Column(StringID(), primary_key=True, nullable=False) + task_id = Column(StringID(), primary_key=True, nullable=False) + run_id = Column(StringID(), primary_key=True, nullable=False) + map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1")) + try_number = Column(Integer, primary_key=True, default=0) + state = Column(String(20)) + queue = Column(String(256)) + command = Column(String(1000)) + queued_dttm = Column(UtcDateTime) + remote_worker = Column(String(64)) + last_update = Column(UtcDateTime) + + def __init__( + self, + dag_id: str, + task_id: str, + run_id: str, + map_index: int, + try_number: int, + state: str, + queue: str, + command: str, + queued_dttm: datetime | None = None, + remote_worker: str | None = None, + last_update: datetime | None = None, + ): + self.dag_id = dag_id + self.task_id = task_id + self.run_id = run_id + self.map_index = map_index + self.try_number = try_number + self.state = state + self.queue = queue + self.command = command + self.queued_dttm = queued_dttm or timezone.utcnow() + self.remote_worker = remote_worker + self.last_update = last_update + super().__init__() + + __table_args__ = (Index("rj_order", state, queued_dttm, queue),) + + +class RemoteJob(BaseModelPydantic, LoggingMixin): + """Accessor for remote jobs as logical model.""" + + dag_id: str + task_id: str + run_id: str + map_index: int + try_number: int + state: TaskInstanceState + queue: str + command: List[str] # noqa: UP006 - prevent Sphinx failing + queued_dttm: datetime + remote_worker: Optional[str] # noqa: UP007 - prevent Sphinx failing + last_update: Optional[datetime] # noqa: UP007 - prevent Sphinx failing + model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True) + + @property + def key(self) -> TaskInstanceKey: + return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index) + + @staticmethod + @internal_api_call + @provide_session + def reserve_task( + worker_name: str, queues: list[str] | None = None, session: Session = NEW_SESSION + ) -> RemoteJob | None: + query = ( + select(RemoteJobModel) + .where(RemoteJobModel.state == TaskInstanceState.QUEUED) + .order_by(RemoteJobModel.queued_dttm) + ) + if queues: + query = query.where(RemoteJobModel.queue.in_(queues)) + query = query.limit(1) + query = with_row_locks(query, 
of=RemoteJobModel, session=session, skip_locked=True) + job: RemoteJobModel = session.scalar(query) + if not job: + return None + job.state = TaskInstanceState.RUNNING + job.remote_worker = worker_name + job.last_update = timezone.utcnow() + session.commit() + return RemoteJob( + dag_id=job.dag_id, + task_id=job.task_id, + run_id=job.run_id, + map_index=job.map_index, + try_number=job.try_number, + state=job.state, + queue=job.queue, + command=literal_eval(job.command), + queued_dttm=job.queued_dttm, + remote_worker=job.remote_worker, + last_update=job.last_update, + ) + + @staticmethod + @internal_api_call + @provide_session + def set_state(task: TaskInstanceKey | tuple, state: TaskInstanceState, session: Session = NEW_SESSION): + if isinstance(task, tuple): + task = TaskInstanceKey(*task) + query = select(RemoteJobModel).where( + RemoteJobModel.dag_id == task.dag_id, + RemoteJobModel.task_id == task.task_id, + RemoteJobModel.run_id == task.run_id, + RemoteJobModel.map_index == task.map_index, + RemoteJobModel.try_number == task.try_number, + ) + job: RemoteJobModel = session.scalar(query) + job.state = state + job.last_update = timezone.utcnow() + session.commit() + + def __hash__(self): + return f"{self.dag_id}|{self.task_id}|{self.run_id}|{self.map_index}|{self.try_number}".__hash__() + + +if is_pydantic_2_installed(): + RemoteJob.model_rebuild() + +add_pydantic_class_type_mapping("remote_job", RemoteJobModel, RemoteJob) diff --git a/airflow/providers/remote/models/remote_logs.py b/airflow/providers/remote/models/remote_logs.py new file mode 100644 index 0000000000000..f308759d35d43 --- /dev/null +++ b/airflow/providers/remote/models/remote_logs.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
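
`reserve_task` above claims the oldest queued row under `with_row_locks(..., skip_locked=True)`, so several workers can poll the same table without double-claiming a job, and flips the row to `RUNNING` before returning it. A condensed sketch of the claim-run-report cycle that `remote_command.py` implements (environment handling, log shipping and concurrency are omitted; assumes the internal API is configured as in `worker()`):

```python
from __future__ import annotations

from subprocess import Popen

from airflow.providers.remote.models.remote_job import RemoteJob
from airflow.utils.state import TaskInstanceState


def run_one_job(hostname: str, queues: list[str] | None) -> bool:
    """Claim a single queued job, execute it and report the result; return True if a job was found."""
    job = RemoteJob.reserve_task(hostname, queues)  # atomically marks the row as RUNNING
    if not job:
        return False
    process = Popen(job.command)
    process.wait()
    state = TaskInstanceState.SUCCESS if process.returncode == 0 else TaskInstanceState.FAILED
    RemoteJob.set_state(job.key, state)
    return True
```
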
+from __future__ import annotations + +from datetime import datetime +from functools import lru_cache +from pathlib import Path +from typing import TYPE_CHECKING + +from sqlalchemy import ( + Column, + Integer, + Text, + text, +) +from sqlalchemy.dialects.mysql import MEDIUMTEXT + +from airflow.api_internal.internal_api_call import internal_api_call +from airflow.configuration import conf +from airflow.models.base import Base, StringID +from airflow.models.taskinstance import TaskInstance +from airflow.serialization.serialized_objects import add_pydantic_class_type_mapping +from airflow.utils.log.file_task_handler import FileTaskHandler +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.pydantic import BaseModel as BaseModelPydantic, ConfigDict, is_pydantic_2_installed +from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.sqlalchemy import UtcDateTime + +if TYPE_CHECKING: + from sqlalchemy.orm.session import Session + + +class RemoteLogsModel(Base, LoggingMixin): + """ + Temporary collected logs from a remote worker while job runs remote. + + As the remote worker in most cases has a local file system and the web UI no access + to read files from remote, remote workers will send incremental chunks of logs + of running jobs to the central site. As log storage backends in most cloud cases can not + append logs, the table is used as buffer to receive. Upon task completion logs can be + flushed to task log handler. + + Log data therefore is collected in chunks and is only temporary. + """ + + __tablename__ = "remote_logs" + dag_id = Column(StringID(), primary_key=True, nullable=False) + task_id = Column(StringID(), primary_key=True, nullable=False) + run_id = Column(StringID(), primary_key=True, nullable=False) + map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1")) + try_number = Column(Integer, primary_key=True, default=0) + log_chunk_time = Column(UtcDateTime, primary_key=True, nullable=False) + log_chunk_data = Column(Text().with_variant(MEDIUMTEXT(), "mysql"), nullable=False) + + def __init__( + self, + dag_id: str, + task_id: str, + run_id: str, + map_index: int, + try_number: int, + log_chunk_time: datetime, + log_chunk_data: str, + ): + self.dag_id = dag_id + self.task_id = task_id + self.run_id = run_id + self.map_index = map_index + self.try_number = try_number + self.log_chunk_time = log_chunk_time + self.log_chunk_data = log_chunk_data + super().__init__() + + +class RemoteLogs(BaseModelPydantic, LoggingMixin): + """Accessor for remote worker instances as logical model.""" + + dag_id: str + task_id: str + run_id: str + map_index: int + try_number: int + log_chunk_time: datetime + log_chunk_data: str + model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True) + + @staticmethod + @internal_api_call + @provide_session + def push_logs( + dag_id: str, + task_id: str, + run_id: str, + map_index: int, + try_number: int, + log_chunk_time: datetime, + log_chunk_data: str, + session: Session = NEW_SESSION, + ) -> None: + log_chunk = RemoteLogsModel( + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + try_number=try_number, + log_chunk_time=log_chunk_time, + log_chunk_data=log_chunk_data, + ) + session.add(log_chunk) + # Write logs to local file to make them accessible + logfile_path = RemoteLogs.logfile_path(dag_id, run_id, task_id, map_index, try_number) + if not logfile_path.exists(): + new_folder_permissions = int( + conf.get("logging", 
"file_task_handler_new_folder_permissions", fallback="0o775"), 8 + ) + logfile_path.parent.mkdir(parents=True, exist_ok=True, mode=new_folder_permissions) + with logfile_path.open("a") as logfile: + logfile.write(log_chunk_data) + + @staticmethod + @lru_cache + def logfile_path(dag_id: str, run_id: str, task_id: str, map_index: int, try_number: int) -> Path: + """Elaborate the path and filename to expect from task execution.""" + ti = TaskInstance.get_task_instance( + dag_id=dag_id, + run_id=run_id, + task_id=task_id, + map_index=map_index, + ) + if TYPE_CHECKING: + assert ti + base_log_folder = conf.get("logging", "base_log_folder", fallback="NOT AVAILABLE") + return Path(base_log_folder, FileTaskHandler(base_log_folder, None)._render_filename(ti, try_number)) + + +if is_pydantic_2_installed(): + RemoteLogs.model_rebuild() + +add_pydantic_class_type_mapping("remote_worker", RemoteLogsModel, RemoteLogs) diff --git a/airflow/providers/remote/models/remote_worker.py b/airflow/providers/remote/models/remote_worker.py new file mode 100644 index 0000000000000..2a3d47e4de24d --- /dev/null +++ b/airflow/providers/remote/models/remote_worker.py @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import json +from datetime import datetime +from enum import Enum +from typing import TYPE_CHECKING, List, Optional + +from sqlalchemy import ( + Column, + Integer, + String, + select, +) + +from airflow.api_internal.internal_api_call import internal_api_call +from airflow.models.base import Base +from airflow.serialization.serialized_objects import add_pydantic_class_type_mapping +from airflow.utils import timezone +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.pydantic import BaseModel as BaseModelPydantic, ConfigDict, is_pydantic_2_installed +from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.sqlalchemy import UtcDateTime + +if TYPE_CHECKING: + from sqlalchemy.orm.session import Session + + +class RemoteWorkerModel(Base, LoggingMixin): + """A Remote Worker instance which reports the state and health.""" + + __tablename__ = "remote_worker" + worker_name = Column(String(64), primary_key=True, nullable=False) + state = Column(String(20)) + queues = Column(String(256)) + first_online = Column(UtcDateTime) + last_update = Column(UtcDateTime) + jobs_active = Column(Integer, default=0) + jobs_taken = Column(Integer, default=0) + jobs_success = Column(Integer, default=0) + jobs_failed = Column(Integer, default=0) + sysinfo = Column(String(256)) + + def __init__( + self, + worker_name: str, + state: str, + queues: list[str] | None, + first_online: datetime | None = None, + last_update: datetime | None = None, + ): + self.worker_name = worker_name + self.state = state + self.queues = ", ".join(queues) if queues else None + self.first_online = first_online or timezone.utcnow() + self.last_update = last_update + super().__init__() + + +class RemoteWorkerState(str, Enum): + """Status of a remote worker instance.""" + + STARTING = "starting" + """Remote worker is in initialization.""" + RUNNING = "running" + """Remote worker is actively running a task.""" + IDLE = "idle" + """Remote worker is active and waiting for a task.""" + TERMINATING = "terminating" + """Remote worker is completing work and stopping.""" + OFFLINE = "offline" + """Remote worker was show down.""" + UNKNOWN = "unknown" + """No heartbeat signal from worker for some time, remote worker probably down.""" + + +class RemoteWorker(BaseModelPydantic, LoggingMixin): + """Accessor for remote worker instances as logical model.""" + + worker_name: str + state: RemoteWorkerState + queues: Optional[List[str]] # noqa: UP006,UP007 - prevent Sphinx failing + first_online: datetime + last_update: Optional[datetime] = None # noqa: UP007 - prevent Sphinx failing + jobs_active: int + jobs_taken: int + jobs_success: int + jobs_failed: int + sysinfo: str + model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True) + + @staticmethod + @internal_api_call + @provide_session + def register_worker( + worker_name: str, state: RemoteWorkerState, queues: list[str] | None, session: Session = NEW_SESSION + ) -> RemoteWorker: + query = select(RemoteWorkerModel).where(RemoteWorkerModel.worker_name == worker_name) + worker: RemoteWorkerModel = session.scalar(query) + if not worker: + worker = RemoteWorkerModel(worker_name=worker_name, state=state, queues=queues) + worker.state = state + worker.queues = queues + worker.last_update = timezone.utcnow() + session.add(worker) + return RemoteWorker( + worker_name=worker_name, + state=state, + queues=worker.queues, + first_online=worker.first_online, + last_update=worker.last_update, + 
+            jobs_active=worker.jobs_active or 0,
+            jobs_taken=worker.jobs_taken or 0,
+            jobs_success=worker.jobs_success or 0,
+            jobs_failed=worker.jobs_failed or 0,
+            sysinfo=worker.sysinfo or "{}",
+        )
+
+    @staticmethod
+    @internal_api_call
+    @provide_session
+    def set_state(
+        worker_name: str,
+        state: RemoteWorkerState,
+        jobs_active: int,
+        sysinfo: dict[str, str],
+        session: Session = NEW_SESSION,
+    ):
+        query = select(RemoteWorkerModel).where(RemoteWorkerModel.worker_name == worker_name)
+        worker: RemoteWorkerModel = session.scalar(query)
+        worker.state = state
+        worker.jobs_active = jobs_active
+        worker.sysinfo = json.dumps(sysinfo)
+        worker.last_update = timezone.utcnow()
+        session.commit()
+
+
+if is_pydantic_2_installed():
+    RemoteWorker.model_rebuild()
+
+add_pydantic_class_type_mapping("remote_worker", RemoteWorkerModel, RemoteWorker)
diff --git a/airflow/providers/remote/openapi/__init__.py b/airflow/providers/remote/openapi/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/airflow/providers/remote/openapi/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/remote/openapi/remote_worker_api_v1.yaml b/airflow/providers/remote/openapi/remote_worker_api_v1.yaml
new file mode 100644
index 0000000000000..0a4afe1b2bc78
--- /dev/null
+++ b/airflow/providers/remote/openapi/remote_worker_api_v1.yaml
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+---
+openapi: 3.0.2
+info:
+  title: Airflow Remote Worker API
+  version: 0.1.0
+  description: |
+    This is the Airflow Remote Worker API - the access endpoint for workers
+    that run Apache Airflow jobs remotely. It also proxies the internal API
+    to remote endpoints.
+
+    It is not intended to be used by any external code.
+ + You can find more information in AIP-69 + https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-69+Remote+Executor + + +servers: + - url: /remote_worker/v1 + description: Airflow Remote Worker API +paths: + "/rpcapi": + post: + deprecated: false + x-openapi-router-controller: airflow.providers.remote.api_endpoints.rpc_api_endpoint + operationId: remote_worker_api + tags: + - JSONRPC + parameters: [] + responses: + '200': + description: Successful response + requestBody: + x-body-name: body + required: true + content: + application/json: + schema: + type: object + required: + - method + - jsonrpc + - params + properties: + jsonrpc: + type: string + default: '2.0' + description: JSON-RPC Version (2.0) + method: + type: string + description: Method name + params: + title: Parameters + type: object + "/health": + get: + operationId: health + deprecated: false + x-openapi-router-controller: airflow.providers.remote.api_endpoints.health_endpoint + tags: + - JSONRPC + parameters: [] + responses: + '200': + description: Successful response +x-headers: [] +x-explorer-enabled: true +x-proxy-enabled: true +components: + schemas: + JsonRpcRequired: + type: object + required: + - method + - jsonrpc + properties: + method: + type: string + description: Method name + jsonrpc: + type: string + default: '2.0' + description: JSON-RPC Version (2.0) + discriminator: + propertyName: method_name +tags: [] diff --git a/airflow/providers/remote/plugins/__init__.py b/airflow/providers/remote/plugins/__init__.py new file mode 100644 index 0000000000000..6d623fb01beed --- /dev/null +++ b/airflow/providers/remote/plugins/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.remote.plugins.remote_executor_plugin import RemoteExecutorPlugin + +__all__ = ["RemoteExecutorPlugin"] diff --git a/airflow/providers/remote/plugins/remote_executor_plugin.py b/airflow/providers/remote/plugins/remote_executor_plugin.py new file mode 100644 index 0000000000000..56d3c1b0a314a --- /dev/null +++ b/airflow/providers/remote/plugins/remote_executor_plugin.py @@ -0,0 +1,135 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from connexion import FlaskApi
+from flask import Blueprint
+from flask_appbuilder import BaseView, expose
+from sqlalchemy import select
+
+from airflow.auth.managers.models.resource_details import AccessView
+from airflow.configuration import conf
+from airflow.models.taskinstance import TaskInstanceState
+from airflow.plugins_manager import AirflowPlugin
+from airflow.providers.remote.api_endpoints.rpc_api_endpoint import REMOTE_WORKER_API_ROLE
+from airflow.security.permissions import ACTION_CAN_CREATE
+from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.yaml import safe_load
+from airflow.www import utils as wwwutils
+from airflow.www.auth import has_access_view
+from airflow.www.constants import SWAGGER_BUNDLE, SWAGGER_ENABLED
+from airflow.www.extensions.init_auth_manager import auth_manager
+from airflow.www.extensions.init_views import _CustomErrorRequestBodyValidator, _LazyResolver
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
+
+def _get_api_endpoints() -> Blueprint:
+    folder = Path(__file__).parents[1].resolve()  # this is airflow/providers/remote/
+    with folder.joinpath("openapi", "remote_worker_api_v1.yaml").open() as f:
+        specification = safe_load(f)
+    bp = FlaskApi(
+        specification=specification,
+        resolver=_LazyResolver(),
+        base_path="/remote_worker/v1",
+        options={"swagger_ui": SWAGGER_ENABLED, "swagger_path": SWAGGER_BUNDLE.__fspath__()},
+        strict_validation=True,
+        validate_responses=True,
+        validator_map={"body": _CustomErrorRequestBodyValidator},
+    ).blueprint
+    # Need to exempt CSRF to make the API usable
+    from airflow.www.app import csrf, get_auth_manager
+
+    csrf.exempt(bp)
+    asm = get_auth_manager().appbuilder.sm
+    permission = asm.create_permission(ACTION_CAN_CREATE, REMOTE_WORKER_API_ROLE)
+    role = asm.find_role("Admin")
+    asm.add_permission_to_role(role, permission)
+    return bp
+
+
+# registers airflow/providers/remote/plugins/templates as a Jinja template folder
+template_bp = Blueprint(
+    "template_blueprint",
+    __name__,
+    template_folder="templates",
+)
+
+
+class RemoteWorkerJobs(BaseView):
+    """Simple view to show remote worker jobs."""
+
+    default_view = "jobs"
+
+    @expose("/jobs")
+    @has_access_view(AccessView.JOBS)
+    @provide_session
+    def jobs(self, session: Session = NEW_SESSION):
+        from airflow.providers.remote.models.remote_job import RemoteJobModel
+
+        jobs = session.scalars(select(RemoteJobModel)).all()
+        html_states = {
+            str(state): wwwutils.state_token(str(state)) for state in TaskInstanceState.__members__.values()
+        }
+        return self.render_template("remote_worker_jobs.html", jobs=jobs, html_states=html_states)
+
+
+class RemoteWorkerHosts(BaseView):
+    """Simple view to show remote worker status."""
+
+    default_view = "status"
+
+    @expose("/status")
+    @has_access_view(AccessView.JOBS)
+    @provide_session
+    def status(self, session: Session = NEW_SESSION):
+        from airflow.providers.remote.models.remote_worker import RemoteWorkerModel
+
+        hosts = session.scalars(select(RemoteWorkerModel)).all()
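+        # One row per registered worker; state-specific styling is applied in the template.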
return self.render_template("remote_worker_hosts.html", hosts=hosts) + + +# Check if RemoteExecutor is actually loaded (auth_manager is set if we run on the webserver) +REMOTE_EXECUTOR_ACTIVE = conf.getboolean("remote", "api_enabled") and auth_manager + + +class RemoteExecutorPlugin(AirflowPlugin): + """Remote Executor Plugin - provides API endpoints for remote workers in Webserver.""" + + name = "remote_executor" + flask_blueprints = [_get_api_endpoints(), template_bp] if REMOTE_EXECUTOR_ACTIVE else [] + appbuilder_views = ( + [ + { + "name": "Remote Worker Jobs", + "category": "Admin", + "view": RemoteWorkerJobs(), + }, + { + "name": "Remote Worker Hosts", + "category": "Admin", + "view": RemoteWorkerHosts(), + }, + ] + if REMOTE_EXECUTOR_ACTIVE + else [] + ) diff --git a/airflow/providers/remote/plugins/templates/remote_worker_hosts.html b/airflow/providers/remote/plugins/templates/remote_worker_hosts.html new file mode 100644 index 0000000000000..ad645d96a13e3 --- /dev/null +++ b/airflow/providers/remote/plugins/templates/remote_worker_hosts.html @@ -0,0 +1,78 @@ +{# + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + #} + + {% extends base_template %} + + {% block title %} + Remote Worker Hosts + {% endblock %} + + {% block content %} +
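+  {# Overview of all registered remote workers with their state, queues, heartbeat times and job counters. #}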
+    {% if not hosts %}
+      <p>No Remote Workers connected or known currently.</p>
+    {% else %}
+      <table class="table table-striped table-bordered">
+        <thead>
+          <tr>
+            <th>Hostname</th>
+            <th>State</th>
+            <th>Queues</th>
+            <th>First Online</th>
+            <th>Last Heart Beat</th>
+            <th>Active Jobs</th>
+            <th>Jobs Taken</th>
+            <th>Jobs Success</th>
+            <th>Jobs Failed</th>
+            <th>System Information</th>
+          </tr>
+        </thead>
+        <tbody>
+          {% for host in hosts %}
+          <tr>
+            <td>{{ host.worker_name }}</td>
+            <td>
+              {%- if host.state == "starting" -%}
+                <span class="label label-info">{{ host.state }}</span>
+              {%- elif host.state == "running" -%}
+                <span class="label label-success">{{ host.state }}</span>
+              {%- elif host.state == "idle" -%}
+                <span class="label label-primary">{{ host.state }}</span>
+              {%- elif host.state == "terminating" -%}
+                <span class="label label-warning">{{ host.state }}</span>
+              {%- elif host.state == "offline" -%}
+                <span class="label label-default">{{ host.state }}</span>
+              {%- elif host.state == "unknown" -%}
+                <span class="label label-danger">{{ host.state }}</span>
+              {%- else -%}
+                {{ host.state }}
+              {%- endif -%}
+            </td>
+            <td>{% if host.queues %}{{ host.queues }}{% else %}(all){% endif %}</td>
+            <td>{{ host.first_online }}</td>
+            <td>{% if host.last_update %}{{ host.last_update }}{% endif %}</td>
+            <td>{{ host.jobs_active }}</td>
+            <td>{{ host.jobs_taken }}</td>
+            <td>{{ host.jobs_success }}</td>
+            <td>{{ host.jobs_failed }}</td>
+            <td>{{ host.sysinfo }}</td>
+          </tr>
+          {% endfor %}
+        </tbody>
+      </table>
+    {% endif %}
+  {% endblock %}
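For orientation, a rough usage sketch (not part of this diff) of how a worker process might drive the RemoteWorker accessor defined above; the worker name, queue list, sysinfo payload and heartbeat interval are illustrative assumptions:

    import time

    from airflow.providers.remote.models.remote_worker import RemoteWorker, RemoteWorkerState

    worker_name = "remote-worker-1"  # hypothetical host name
    # Announce the worker once at startup; @internal_api_call routes this through
    # the internal/remote API when database access isolation is enabled.
    RemoteWorker.register_worker(
        worker_name=worker_name,
        state=RemoteWorkerState.STARTING,
        queues=["default"],  # None would mean "listen on all queues"
    )
    while True:
        # Periodic heartbeat: report state, current load and free-form system info.
        RemoteWorker.set_state(
            worker_name=worker_name,
            state=RemoteWorkerState.IDLE,
            jobs_active=0,
            sysinfo={"concurrency": "8"},
        )
        time.sleep(10)  # assumed heartbeat interval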
+    {% if not jobs %}
+      <p>No jobs running currently</p>
+    {% else %}
+      <table class="table table-striped table-bordered">
+        <thead>
+          <tr>
+            <th>DAG ID</th>
+            <th>Task ID</th>
+            <th>Run ID</th>
+            <th>Map Index</th>
+            <th>Try Number</th>
+            <th>State</th>
+            <th>Queue</th>
+            <th>Queued DTTM</th>
+            <th>Remote Worker</th>
+            <th>Last Update</th>
+          </tr>
+        </thead>
+        <tbody>
+          {% for job in jobs %}
+          <tr>
+            <td>{{ job.dag_id }}</td>
+            <td>{{ job.task_id }}</td>
+            <td>{{ job.run_id }}</td>
+            <td>{{ job.map_index }}</td>
+            <td>{{ job.try_number }}</td>
+            <td>{{ html_states[job.state] }}</td>
+            <td>{{ job.queue }}</td>
+            <td></td>
+            <td>{{ job.remote_worker }}</td>
+            <td></td>