Skip to content

Commit

Permalink
feat: Ray on Vertex enables XGBoost register model with custom versio…
Browse files Browse the repository at this point in the history
…n using pre-built container

PiperOrigin-RevId: 619575247
  • Loading branch information
yinghsienwu authored and copybara-github committed Mar 27, 2024
1 parent b587a8d commit e45ef96
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import os
import pickle
import ray
import ray.cloudpickle as cpickle
import tempfile
from typing import Optional, TYPE_CHECKING

Expand Down Expand Up @@ -117,7 +118,9 @@ def _get_estimator_from(
Raises:
ValueError: Invalid Argument.
RuntimeError: Model not found.
"""

ray_version = ray.__version__
if ray_version == "2.4.0":
if not isinstance(checkpoint, ray_sklearn.SklearnCheckpoint):
Expand All @@ -133,8 +136,25 @@ def _get_estimator_from(
)
return checkpoint.get_estimator()

# get_model() signature changed in future versions
try:
return checkpoint.get_estimator()
return checkpoint.get_model()
except AttributeError:
raise RuntimeError("Unsupported Ray version.")
model_file_name = ray.train.sklearn.SklearnCheckpoint.MODEL_FILENAME

model_path = os.path.join(checkpoint.path, model_file_name)

if os.path.exists(model_path):
with open(model_path, mode="rb") as f:
obj = pickle.load(f)
else:
try:
# Download from GCS to temp and then load_model
with tempfile.TemporaryDirectory() as temp_dir:
gcs_utils.download_from_gcs("gs://" + checkpoint.path, temp_dir)
with open(f"{temp_dir}/{model_file_name}", mode="rb") as f:
obj = cpickle.load(f)
except Exception as e:
raise RuntimeError(
f"{model_file_name} not found in this checkpoint due to: {e}."
)
return obj
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@
# limitations under the License.

import logging
from typing import Optional
import os
import ray
from ray.air._internal.torch_utils import load_torch_model
import tempfile
from google.cloud.aiplatform.utils import gcs_utils
from typing import Optional


try:
from ray.train import torch as ray_torch
Expand Down Expand Up @@ -51,6 +56,8 @@ def get_pytorch_model_from(
Raises:
ValueError: Invalid Argument.
ModuleNotFoundError: PyTorch isn't installed.
RuntimeError: Model not found.
"""
ray_version = ray.__version__
if ray_version == "2.4.0":
Expand All @@ -67,8 +74,33 @@ def get_pytorch_model_from(
)
return checkpoint.get_model(model=model)

# get_model() signature changed in future versions
try:
return checkpoint.get_model()
except AttributeError:
raise RuntimeError("Unsupported Ray version.")
model_file_name = ray.train.torch.TorchCheckpoint.MODEL_FILENAME

model_path = os.path.join(checkpoint.path, model_file_name)

try:
import torch

except ModuleNotFoundError as mnfe:
raise ModuleNotFoundError("PyTorch isn't installed.") from mnfe

if os.path.exists(model_path):
model_or_state_dict = torch.load(model_path, map_location="cpu")
else:
try:
# Download from GCS to temp and then load_model
with tempfile.TemporaryDirectory() as temp_dir:
gcs_utils.download_from_gcs("gs://" + checkpoint.path, temp_dir)
model_or_state_dict = torch.load(
f"{temp_dir}/{model_file_name}", map_location="cpu"
)
except Exception as e:
raise RuntimeError(
f"{model_file_name} not found in this checkpoint due to: {e}."
)

model = load_torch_model(saved_model=model_or_state_dict, model_definition=model)
return model
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def register_xgboost(
checkpoint: "ray_xgboost.XGBoostCheckpoint",
artifact_uri: Optional[str] = None,
display_name: Optional[str] = None,
xgboost_version: Optional[str] = None,
**kwargs,
) -> aiplatform.Model:
"""Uploads a Ray XGBoost Checkpoint as XGBoost Model to Model Registry.
Expand Down Expand Up @@ -75,6 +76,9 @@ def register_xgboost(
display_name (str):
Optional. The display name of the Model. The name can be up to 128
characters long and can be consist of any UTF-8 characters.
xgboost_version (str): Optional. The version of the XGBoost serving container.
Supported versions: ["0.82", "0.90", "1.1", "1.2", "1.3", "1.4", "1.6", "1.7", "2.0"].
If the version is not specified, the latest version is used.
**kwargs:
Any kwargs will be passed to aiplatform.Model registration.
Expand All @@ -96,14 +100,16 @@ def register_xgboost(

model_dir = os.path.join(artifact_uri, display_model_name)
file_path = os.path.join(model_dir, constants._PICKLE_FILE_NAME)
if xgboost_version is None:
xgboost_version = constants._XGBOOST_VERSION

with tempfile.NamedTemporaryFile(suffix=constants._PICKLE_EXTENTION) as temp_file:
pickle.dump(model, temp_file)
gcs_utils.upload_to_gcs(temp_file.name, file_path)
return aiplatform.Model.upload_xgboost_model_file(
model_file_path=temp_file.name,
display_name=display_model_name,
xgboost_version=constants._XGBOOST_VERSION,
xgboost_version=xgboost_version,
**kwargs,
)

Expand All @@ -121,6 +127,8 @@ def _get_xgboost_model_from(
Raises:
ValueError: Invalid Argument.
ModuleNotFoundError: XGBoost isn't installed.
RuntimeError: Model not found.
"""
ray_version = ray.__version__
if ray_version == "2.4.0":
Expand All @@ -137,8 +145,33 @@ def _get_xgboost_model_from(
)
return checkpoint.get_model()

# get_model() signature changed in future versions
try:
# This works for Ray v2.5
return checkpoint.get_model()
except AttributeError:
raise RuntimeError("Unsupported Ray version.")
# This works for Ray v2.9
model_file_name = ray.train.xgboost.XGBoostCheckpoint.MODEL_FILENAME

model_path = os.path.join(checkpoint.path, model_file_name)

try:
import xgboost

except ModuleNotFoundError as mnfe:
raise ModuleNotFoundError("XGBoost isn't installed.") from mnfe

booster = xgboost.Booster()
if os.path.exists(model_path):
booster.load_model(model_path)
return booster

try:
# Download from GCS to temp and then load_model
with tempfile.TemporaryDirectory() as temp_dir:
gcs_utils.download_from_gcs("gs://" + checkpoint.path, temp_dir)
booster.load_model(f"{temp_dir}/{model_file_name}")
return booster
except Exception as e:
raise RuntimeError(
f"{model_file_name} not found in this checkpoint due to: {e}."
)
3 changes: 3 additions & 0 deletions tests/unit/vertex_ray/test_ray_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def test_convert_checkpoint_to_sklearn_raise_exception(
"ray.train.sklearn.SklearnCheckpoint .*"
)

@tc.rovminversion
def test_convert_checkpoint_to_sklearn_model_succeed(
self, ray_sklearn_checkpoint
) -> None:
Expand All @@ -302,6 +303,7 @@ def test_convert_checkpoint_to_sklearn_model_succeed(
y_pred = estimator.predict([[10, 11]])
assert y_pred[0] is not None

@tc.rovminversion
def test_register_sklearn_succeed(
self,
ray_sklearn_checkpoint,
Expand All @@ -325,6 +327,7 @@ def test_register_sklearn_succeed(
pickle_dump.assert_called_once()
gcs_utils_upload_to_gcs.assert_called_once()

@tc.rovminversion
def test_register_sklearn_initialized_succeed(
self,
ray_sklearn_checkpoint,
Expand Down

0 comments on commit e45ef96

Please sign in to comment.