Commit fbf2f7c
feat: LLM - Add support for batch prediction to CodeGenerationModel (`code-bison`)

PiperOrigin-RevId: 609627761
Ark-kun authored and copybara-github committed Feb 23, 2024
1 parent 0b55762 commit fbf2f7c
Showing 3 changed files with 56 additions and 1 deletion.
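
This commit wires CodeGenerationModel into the SDK's existing batch prediction machinery via the _ModelWithBatchPredict mixin. A minimal usage sketch, mirroring the new system test below (the project, location, and GCS paths here are illustrative placeholders, not values from the commit):

from google.cloud import aiplatform
from vertexai.preview.language_models import CodeGenerationModel

# Placeholder project/location/bucket values; substitute your own.
aiplatform.init(project="my-project", location="us-central1")

model = CodeGenerationModel.from_pretrained("code-bison@001")
job = model.batch_predict(
    dataset="gs://my-bucket/code_prompts.jsonl",  # JSONL file of prompts
    destination_uri_prefix="gs://my-bucket/predictions/",
    model_parameters={"temperature": 0},
)
job.wait()  # block until the batch prediction job finishes
print(job.state)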
21 changes: 21 additions & 0 deletions tests/system/aiplatform/test_language_models.py
@@ -33,6 +33,7 @@
 )
 from vertexai.preview.language_models import (
     ChatModel,
+    CodeGenerationModel,
     InputOutputTextPair,
     TextGenerationModel,
     TextGenerationResponse,
@@ -434,6 +435,26 @@ def test_batch_prediction_for_textembedding(self):
 
         assert gapic_job.state == gca_job_state.JobState.JOB_STATE_SUCCEEDED
 
+    def test_batch_prediction_for_code_generation(self):
+        source_uri = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/code-bison.batch_prediction_prompts.1.jsonl"
+        destination_uri_prefix = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/predictions/code-bison@001_"
+
+        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+        model = CodeGenerationModel.from_pretrained("code-bison@001")
+        job = model.batch_predict(
+            dataset=source_uri,
+            destination_uri_prefix=destination_uri_prefix,
+            model_parameters={"temperature": 0},
+        )
+
+        job.wait_for_resource_creation()
+        job.wait()
+        gapic_job = job._gca_resource
+        job.delete()
+
+        assert gapic_job.state == gca_job_state.JobState.JOB_STATE_SUCCEEDED
+
     def test_code_generation_streaming(self):
         aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
 
30 changes: 30 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
@@ -4324,6 +4324,36 @@ def test_batch_prediction(
             model_parameters={"temperature": 0.1},
         )
 
+    def test_batch_prediction_for_code_generation(self):
+        """Tests batch prediction."""
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _CODE_GENERATION_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = preview_language_models.CodeGenerationModel.from_pretrained(
+                "code-bison@001"
+            )
+
+        with mock.patch.object(
+            target=aiplatform.BatchPredictionJob,
+            attribute="create",
+        ) as mock_create:
+            model.batch_predict(
+                dataset="gs://test-bucket/test_table.jsonl",
+                destination_uri_prefix="gs://test-bucket/results/",
+                model_parameters={},
+            )
+        mock_create.assert_called_once_with(
+            model_name=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/code-bison@001",
+            job_display_name=None,
+            gcs_source="gs://test-bucket/test_table.jsonl",
+            gcs_destination_prefix="gs://test-bucket/results/",
+            model_parameters={},
+        )
+
     def test_batch_prediction_for_text_embedding(self):
         """Tests batch prediction."""
         aiplatform.init(
6 changes: 5 additions & 1 deletion vertexai/language_models/_language_models.py
@@ -3366,7 +3366,11 @@ def count_tokens(
     )
 
 
-class CodeGenerationModel(_CodeGenerationModel, _TunableTextModelMixin):
+class CodeGenerationModel(
+    _CodeGenerationModel,
+    _TunableTextModelMixin,
+    _ModelWithBatchPredict,
+):
     pass
 
 
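For reference, the unit test above pins down the contract this mixin supplies: batch_predict resolves the model's publisher resource name and delegates to aiplatform.BatchPredictionJob.create. A minimal sketch of that shape, reconstructed from the test's assertions (the attribute name _model_resource_name is hypothetical, and the real mixin implementation in _language_models.py may differ):

from typing import Any, Dict, Optional

from google.cloud import aiplatform


class _ModelWithBatchPredict:
    """Sketch of the mixin contract implied by the unit test; not the SDK source."""

    # Hypothetical attribute standing in for however the model stores its
    # resource name, e.g. "projects/.../publishers/google/models/code-bison@001".
    _model_resource_name: str

    def batch_predict(
        self,
        *,
        dataset: str,
        destination_uri_prefix: str,
        model_parameters: Optional[Dict[str, Any]] = None,
    ) -> aiplatform.BatchPredictionJob:
        # The unit test asserts exactly this delegation to BatchPredictionJob.create.
        return aiplatform.BatchPredictionJob.create(
            model_name=self._model_resource_name,
            job_display_name=None,
            gcs_source=dataset,
            gcs_destination_prefix=destination_uri_prefix,
            model_parameters=model_parameters,
        )

Composing the capability as a mixin keeps the batch prediction logic in one place and lets CodeGenerationModel pick it up with a one-line change to its base classes, the same way the text generation and embedding models exercised by the neighboring tests appear to.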
