Skip to content

Commit 76fda15

Browse files
WSHoekstraWalter Hoekstrashahar1
authored and
Lorin
committed
vertex ai training operators: add display_name to rendered fields (apache#43028)
* vertex ai training operators: add display_name to rendered fields * fix validate-operators-init static checks --------- Co-authored-by: Walter Hoekstra <walterhoekstra@bol.com> Co-authored-by: Shahar Epstein <60007259+shahar1@users.noreply.github.com>
1 parent cbe4a2f commit 76fda15

File tree

4 files changed

+31
-2
lines changed

4 files changed

+31
-2
lines changed

providers/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
106106
"dataset_id",
107107
"region",
108108
"impersonation_chain",
109+
"display_name",
110+
"model_display_name",
109111
)
110112
operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink())
111113

@@ -121,6 +123,8 @@ def __init__(
121123
forecast_horizon: int,
122124
data_granularity_unit: str,
123125
data_granularity_count: int,
126+
display_name: str,
127+
model_display_name: str | None = None,
124128
optimization_objective: str | None = None,
125129
column_specs: dict[str, str] | None = None,
126130
column_transformations: list[dict[str, dict[str, str]]] | None = None,
@@ -143,7 +147,12 @@ def __init__(
143147
**kwargs,
144148
) -> None:
145149
super().__init__(
146-
region=region, impersonation_chain=impersonation_chain, parent_model=parent_model, **kwargs
150+
display_name=display_name,
151+
model_display_name=model_display_name,
152+
region=region,
153+
impersonation_chain=impersonation_chain,
154+
parent_model=parent_model,
155+
**kwargs,
147156
)
148157
self.dataset_id = dataset_id
149158
self.target_column = target_column

providers/src/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
163163
:param poll_interval: Interval size which defines how often job status is checked in deferrable mode.
164164
"""
165165

166-
template_fields = ("region", "project_id", "model_name", "impersonation_chain")
166+
template_fields = ("region", "project_id", "model_name", "impersonation_chain", "job_display_name")
167167
operator_extra_links = (VertexAIBatchPredictionJobLink(),)
168168

169169
def __init__(

providers/src/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py

+19
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,8 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
496496
"parent_model",
497497
"dataset_id",
498498
"impersonation_chain",
499+
"display_name",
500+
"model_display_name",
499501
)
500502
operator_extra_links = (
501503
VertexAIModelLink(),
@@ -507,6 +509,8 @@ def __init__(
507509
*,
508510
command: Sequence[str] = [],
509511
region: str,
512+
display_name: str,
513+
model_display_name: str | None = None,
510514
parent_model: str | None = None,
511515
impersonation_chain: str | Sequence[str] | None = None,
512516
dataset_id: str | None = None,
@@ -515,6 +519,8 @@ def __init__(
515519
**kwargs,
516520
) -> None:
517521
super().__init__(
522+
display_name=display_name,
523+
model_display_name=model_display_name,
518524
region=region,
519525
parent_model=parent_model,
520526
impersonation_chain=impersonation_chain,
@@ -949,6 +955,8 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
949955
"region",
950956
"dataset_id",
951957
"impersonation_chain",
958+
"display_name",
959+
"model_display_name",
952960
)
953961
operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink())
954962

@@ -958,6 +966,8 @@ def __init__(
958966
python_package_gcs_uri: str,
959967
python_module_name: str,
960968
region: str,
969+
display_name: str,
970+
model_display_name: str | None = None,
961971
parent_model: str | None = None,
962972
impersonation_chain: str | Sequence[str] | None = None,
963973
dataset_id: str | None = None,
@@ -966,6 +976,8 @@ def __init__(
966976
**kwargs,
967977
) -> None:
968978
super().__init__(
979+
display_name=display_name,
980+
model_display_name=model_display_name,
969981
region=region,
970982
parent_model=parent_model,
971983
impersonation_chain=impersonation_chain,
@@ -1405,6 +1417,8 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
14051417
"requirements",
14061418
"dataset_id",
14071419
"impersonation_chain",
1420+
"display_name",
1421+
"model_display_name",
14081422
)
14091423
operator_extra_links = (
14101424
VertexAIModelLink(),
@@ -1417,6 +1431,8 @@ def __init__(
14171431
script_path: str,
14181432
requirements: Sequence[str] | None = None,
14191433
region: str,
1434+
display_name: str,
1435+
model_display_name: str | None = None,
14201436
parent_model: str | None = None,
14211437
impersonation_chain: str | Sequence[str] | None = None,
14221438
dataset_id: str | None = None,
@@ -1425,6 +1441,8 @@ def __init__(
14251441
**kwargs,
14261442
) -> None:
14271443
super().__init__(
1444+
display_name=display_name,
1445+
model_display_name=model_display_name,
14281446
region=region,
14291447
parent_model=parent_model,
14301448
impersonation_chain=impersonation_chain,
@@ -1732,6 +1750,7 @@ class ListCustomTrainingJobOperator(GoogleCloudBaseOperator):
17321750
"region",
17331751
"project_id",
17341752
"impersonation_chain",
1753+
"display_name",
17351754
]
17361755
operator_extra_links = [
17371756
VertexAITrainingPipelinesLink(),

providers/src/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py

+1
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
147147
"region",
148148
"project_id",
149149
"impersonation_chain",
150+
"display_name",
150151
]
151152
operator_extra_links = (VertexAITrainingLink(),)
152153

0 commit comments

Comments
 (0)