diff --git a/tests/system/aiplatform/test_language_models.py b/tests/system/aiplatform/test_language_models.py
index 603362f03e..ab1d613128 100644
--- a/tests/system/aiplatform/test_language_models.py
+++ b/tests/system/aiplatform/test_language_models.py
@@ -33,6 +33,7 @@
 )
 from vertexai.preview.language_models import (
     ChatModel,
+    CodeGenerationModel,
     InputOutputTextPair,
     TextGenerationModel,
     TextGenerationResponse,
@@ -434,6 +435,26 @@ def test_batch_prediction_for_textembedding(self):
 
         assert gapic_job.state == gca_job_state.JobState.JOB_STATE_SUCCEEDED
 
+    def test_batch_prediction_for_code_generation(self):
+        source_uri = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/code-bison.batch_prediction_prompts.1.jsonl"
+        destination_uri_prefix = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/predictions/code-bison@001_"
+
+        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+        model = CodeGenerationModel.from_pretrained("code-bison@001")
+        job = model.batch_predict(
+            dataset=source_uri,
+            destination_uri_prefix=destination_uri_prefix,
+            model_parameters={"temperature": 0},
+        )
+
+        job.wait_for_resource_creation()
+        job.wait()
+        gapic_job = job._gca_resource
+        job.delete()
+
+        assert gapic_job.state == gca_job_state.JobState.JOB_STATE_SUCCEEDED
+
     def test_code_generation_streaming(self):
         aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
 
diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py
index ef208b26d5..4c6891c836 100644
--- a/tests/unit/aiplatform/test_language_models.py
+++ b/tests/unit/aiplatform/test_language_models.py
@@ -4324,6 +4324,36 @@ def test_batch_prediction(
             model_parameters={"temperature": 0.1},
         )
 
+    def test_batch_prediction_for_code_generation(self):
+        """Tests batch prediction with a code generation model."""
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _CODE_GENERATION_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = preview_language_models.CodeGenerationModel.from_pretrained(
+                "code-bison@001"
+            )
+
+        with mock.patch.object(
+            target=aiplatform.BatchPredictionJob,
+            attribute="create",
+        ) as mock_create:
+            model.batch_predict(
+                dataset="gs://test-bucket/test_table.jsonl",
+                destination_uri_prefix="gs://test-bucket/results/",
+                model_parameters={},
+            )
+            mock_create.assert_called_once_with(
+                model_name=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/code-bison@001",
+                job_display_name=None,
+                gcs_source="gs://test-bucket/test_table.jsonl",
+                gcs_destination_prefix="gs://test-bucket/results/",
+                model_parameters={},
+            )
+
     def test_batch_prediction_for_text_embedding(self):
         """Tests batch prediction."""
         aiplatform.init(
diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py
index ebbfd9303c..dd7a5458d0 100644
--- a/vertexai/language_models/_language_models.py
+++ b/vertexai/language_models/_language_models.py
@@ -3366,7 +3366,11 @@ def count_tokens(
     )
 
 
-class CodeGenerationModel(_CodeGenerationModel, _TunableTextModelMixin):
+class CodeGenerationModel(
+    _CodeGenerationModel,
+    _TunableTextModelMixin,
+    _ModelWithBatchPredict,
+):
     pass
 
 