Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Honoring the model's supported_deployment_resources_types #865

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 44 additions & 8 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@

from google.protobuf import field_mask_pb2, json_format

_DEFAULT_MACHINE_TYPE = "n1-standard-2"

_LOGGER = base.Logger(__name__)


Expand Down Expand Up @@ -798,7 +800,7 @@ def _deploy(
self._deploy_call(
self.api_client,
self.resource_name,
model.resource_name,
model,
self._gca_resource.traffic_split,
deployed_model_display_name=deployed_model_display_name,
traffic_percentage=traffic_percentage,
Expand All @@ -823,7 +825,7 @@ def _deploy_call(
cls,
api_client: endpoint_service_client.EndpointServiceClient,
endpoint_resource_name: str,
model_resource_name: str,
model: "Model",
endpoint_resource_traffic_split: Optional[proto.MapField] = None,
deployed_model_display_name: Optional[str] = None,
traffic_percentage: Optional[int] = 0,
Expand All @@ -845,8 +847,8 @@ def _deploy_call(
Required. endpoint_service_client.EndpointServiceClient to make call.
endpoint_resource_name (str):
Required. Endpoint resource name to deploy model to.
model_resource_name (str):
Required. Model resource name of Model to deploy.
model (aiplatform.Model):
Required. Model to be deployed.
endpoint_resource_traffic_split (proto.MapField):
Optional. Endpoint current resource traffic split.
deployed_model_display_name (str):
Expand Down Expand Up @@ -913,6 +915,7 @@ def _deploy_call(
is not 0 or 100.
ValueError: If only `explanation_metadata` or `explanation_parameters`
is specified.
ValueError: If model does not support deployment.
"""

max_replica_count = max(min_replica_count, max_replica_count)
Expand All @@ -923,12 +926,40 @@ def _deploy_call(
)

deployed_model = gca_endpoint_compat.DeployedModel(
model=model_resource_name,
model=model.resource_name,
display_name=deployed_model_display_name,
service_account=service_account,
)

if machine_type:
supports_automatic_resources = (
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
in model.supported_deployment_resources_types
)
supports_dedicated_resources = (
aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES
in model.supported_deployment_resources_types
)
provided_custom_machine_spec = (
machine_type or accelerator_type or accelerator_count
)

# If the model supports both automatic and dedicated deployment resources,
# decide based on the presence of machine spec customizations
use_dedicated_resources = supports_dedicated_resources and (
not supports_automatic_resources or provided_custom_machine_spec
)

if provided_custom_machine_spec and not use_dedicated_resources:
_LOGGER.info(
"Model does not support dedicated deployment resources. "
"The machine_type, accelerator_type and accelerator_count parameters are ignored."
)

if use_dedicated_resources and not machine_type:
machine_type = _DEFAULT_MACHINE_TYPE
_LOGGER.info(f"Using default machine_type: {machine_type}")

if use_dedicated_resources:
machine_spec = gca_machine_resources_compat.MachineSpec(
machine_type=machine_type
)
Expand All @@ -944,11 +975,16 @@ def _deploy_call(
max_replica_count=max_replica_count,
)

else:
elif supports_automatic_resources:
deployed_model.automatic_resources = gca_machine_resources_compat.AutomaticResources(
min_replica_count=min_replica_count,
max_replica_count=max_replica_count,
)
else:
raise ValueError(
"Model does not support deployment. "
"See https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1#google.cloud.aiplatform.v1.Model.FIELDS.repeated.google.cloud.aiplatform.v1.Model.DeploymentResourcesType.google.cloud.aiplatform.v1.Model.supported_deployment_resources_types"
)

# Service will throw an error if only one of metadata and parameters is provided
if explanation_metadata and explanation_parameters:
Expand Down Expand Up @@ -2115,7 +2151,7 @@ def _deploy(
Endpoint._deploy_call(
endpoint.api_client,
endpoint.resource_name,
self.resource_name,
self,
endpoint._gca_resource.traffic_split,
deployed_model_display_name=deployed_model_display_name,
traffic_percentage=traffic_percentage,
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/aiplatform/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,9 @@ def test_create_with_labels(self, create_endpoint_mock, sync):
def test_deploy(self, deploy_model_mock, sync):
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
test_model = models.Model(_TEST_ID)
test_model._gca_resource.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
)
test_endpoint.deploy(test_model, sync=sync)

if not sync:
Expand All @@ -636,6 +639,9 @@ def test_deploy(self, deploy_model_mock, sync):
def test_deploy_with_display_name(self, deploy_model_mock, sync):
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
test_model = models.Model(_TEST_ID)
test_model._gca_resource.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
)
test_endpoint.deploy(
model=test_model, deployed_model_display_name=_TEST_DISPLAY_NAME, sync=sync
)
Expand Down Expand Up @@ -664,6 +670,9 @@ def test_deploy_raise_error_traffic_80(self, sync):
with pytest.raises(ValueError):
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
test_model = models.Model(_TEST_ID)
test_model._gca_resource.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
)
test_endpoint.deploy(model=test_model, traffic_percentage=80, sync=sync)

if not sync:
Expand All @@ -675,6 +684,9 @@ def test_deploy_raise_error_traffic_120(self, sync):
with pytest.raises(ValueError):
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
test_model = models.Model(_TEST_ID)
test_model._gca_resource.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
)
test_endpoint.deploy(model=test_model, traffic_percentage=120, sync=sync)

@pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock")
Expand All @@ -683,6 +695,9 @@ def test_deploy_raise_error_traffic_negative(self, sync):
with pytest.raises(ValueError):
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
test_model = models.Model(_TEST_ID)
test_model._gca_resource.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
)
test_endpoint.deploy(model=test_model, traffic_percentage=-18, sync=sync)

@pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock")
Expand All @@ -691,6 +706,9 @@ def test_deploy_raise_error_min_replica(self, sync):
with pytest.raises(ValueError):
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
test_model = models.Model(_TEST_ID)
test_model._gca_resource.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
)
test_endpoint.deploy(model=test_model, min_replica_count=-1, sync=sync)

@pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock")
Expand All @@ -699,6 +717,9 @@ def test_deploy_raise_error_max_replica(self, sync):
with pytest.raises(ValueError):
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
test_model = models.Model(_TEST_ID)
test_model._gca_resource.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
)
test_endpoint.deploy(model=test_model, max_replica_count=-2, sync=sync)

@pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock")
Expand All @@ -707,6 +728,9 @@ def test_deploy_raise_error_traffic_split(self, sync):
with pytest.raises(ValueError):
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
test_model = models.Model(_TEST_ID)
test_model._gca_resource.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
)
test_endpoint.deploy(model=test_model, traffic_split={"a": 99}, sync=sync)

@pytest.mark.usefixtures("get_model_mock")
Expand All @@ -723,6 +747,9 @@ def test_deploy_with_traffic_percent(self, deploy_model_mock, sync):

test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
test_model = models.Model(_TEST_ID)
test_model._gca_resource.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
)
test_endpoint.deploy(model=test_model, traffic_percentage=70, sync=sync)
if not sync:
test_endpoint.wait()
Expand Down Expand Up @@ -755,6 +782,9 @@ def test_deploy_with_traffic_split(self, deploy_model_mock, sync):

test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
test_model = models.Model(_TEST_ID)
test_model._gca_resource.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
)
test_endpoint.deploy(
model=test_model, traffic_split={"model1": 30, "0": 70}, sync=sync
)
Expand All @@ -781,6 +811,9 @@ def test_deploy_with_traffic_split(self, deploy_model_mock, sync):
def test_deploy_with_dedicated_resources(self, deploy_model_mock, sync):
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
test_model = models.Model(_TEST_ID)
test_model._gca_resource.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES
)
test_endpoint.deploy(
model=test_model,
machine_type=_TEST_MACHINE_TYPE,
Expand Down Expand Up @@ -821,6 +854,9 @@ def test_deploy_with_dedicated_resources(self, deploy_model_mock, sync):
def test_deploy_with_explanations(self, deploy_model_with_explanations_mock, sync):
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
test_model = models.Model(_TEST_ID)
test_model._gca_resource.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES
)
test_endpoint.deploy(
model=test_model,
machine_type=_TEST_MACHINE_TYPE,
Expand Down Expand Up @@ -865,6 +901,9 @@ def test_deploy_with_explanations(self, deploy_model_with_explanations_mock, syn
def test_deploy_with_min_replica_count(self, deploy_model_mock, sync):
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
test_model = models.Model(_TEST_ID)
test_model._gca_resource.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
)
test_endpoint.deploy(model=test_model, min_replica_count=2, sync=sync)

if not sync:
Expand All @@ -889,6 +928,9 @@ def test_deploy_with_min_replica_count(self, deploy_model_mock, sync):
def test_deploy_with_max_replica_count(self, deploy_model_mock, sync):
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
test_model = models.Model(_TEST_ID)
test_model._gca_resource.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
)
test_endpoint.deploy(model=test_model, max_replica_count=2, sync=sync)
if not sync:
test_endpoint.wait()
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/aiplatform/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,10 @@ def test_upload_uploads_and_gets_model_with_custom_location(
def test_deploy(self, deploy_model_mock, sync):

test_model = models.Model(_TEST_ID)
test_model._gca_resource.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
)

test_endpoint = models.Endpoint(_TEST_ID)

assert test_model.deploy(test_endpoint, sync=sync,) == test_endpoint
Expand Down Expand Up @@ -854,6 +858,9 @@ def test_deploy(self, deploy_model_mock, sync):
def test_deploy_no_endpoint(self, deploy_model_mock, sync):

test_model = models.Model(_TEST_ID)
test_model._gca_resource.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
)
test_endpoint = test_model.deploy(sync=sync)

if not sync:
Expand Down Expand Up @@ -881,6 +888,9 @@ def test_deploy_no_endpoint(self, deploy_model_mock, sync):
def test_deploy_no_endpoint_dedicated_resources(self, deploy_model_mock, sync):

test_model = models.Model(_TEST_ID)
test_model._gca_resource.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES
)
test_endpoint = test_model.deploy(
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
Expand Down Expand Up @@ -919,6 +929,9 @@ def test_deploy_no_endpoint_dedicated_resources(self, deploy_model_mock, sync):
@pytest.mark.parametrize("sync", [True, False])
def test_deploy_no_endpoint_with_explanations(self, deploy_model_mock, sync):
test_model = models.Model(_TEST_ID)
test_model._gca_resource.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES
)
test_endpoint = test_model.deploy(
machine_type=_TEST_MACHINE_TYPE,
accelerator_type=_TEST_ACCELERATOR_TYPE,
Expand Down Expand Up @@ -961,6 +974,9 @@ def test_deploy_no_endpoint_with_explanations(self, deploy_model_mock, sync):
def test_deploy_raises_with_impartial_explanation_spec(self):

test_model = models.Model(_TEST_ID)
test_model._gca_resource.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES
)

with pytest.raises(ValueError) as e:
test_model.deploy(
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/aiplatform/test_training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,9 @@ def mock_model_service_get():
model_service_client.ModelServiceClient, "get_model"
) as mock_get_model:
mock_get_model.return_value = gca_model.Model(name=_TEST_MODEL_NAME)
mock_get_model.return_value.supported_deployment_resources_types.append(
aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES
)
yield mock_get_model


Expand Down