From 23cab80685fcec3779468b98e813a20329a4dddc Mon Sep 17 00:00:00 2001 From: EmanuelB25 Date: Thu, 6 Mar 2025 14:48:10 -0500 Subject: [PATCH] updating testing --- .../test_model_garden_examples.py | 64 ++++++++++++++++--- 1 file changed, 55 insertions(+), 9 deletions(-) diff --git a/generative_ai/model_garden/test_model_garden_examples.py b/generative_ai/model_garden/test_model_garden_examples.py index 03fd19a70eff..47e113fd9892 100644 --- a/generative_ai/model_garden/test_model_garden_examples.py +++ b/generative_ai/model_garden/test_model_garden_examples.py @@ -11,10 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +from typing import Callable import backoff - from google.api_core.exceptions import ResourceExhausted +from google.cloud import storage +from google.cloud.aiplatform import BatchPredictionJob +from google.cloud.aiplatform_v1 import JobState +import pytest import claude_3_batch_prediciton_bq import claude_3_batch_prediction_gcs @@ -22,6 +27,39 @@ import claude_3_tool_example import claude_3_unary_example +PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT") + +INPUT_BUCKET = "kellysun-test-project-europe-west1" +OUTPUT_BUCKET = "python-docs-samples-tests" +OUTPUT_PATH = "batch/batch_text_predict_output" +GCS_OUTPUT_PATH = "gs://python-docs-samples-tests/" +OUTPUT_TABLE = f"bq://{PROJECT_ID}.gen_ai_batch_prediction.predictions" + + +def _clean_resources() -> None: + storage_client = storage.Client() + bucket = storage_client.get_bucket(OUTPUT_BUCKET) + blobs = bucket.list_blobs(prefix=OUTPUT_PATH) + for blob in blobs: + blob.delete() + + +@pytest.fixture(scope="session") +def output_folder() -> str: + yield f"gs://{OUTPUT_BUCKET}/{OUTPUT_PATH}" + _clean_resources() + + +def _main_test(test_func: Callable) -> BatchPredictionJob: + job = None + try: + job = test_func() + assert job.state == JobState.JOB_STATE_SUCCEEDED + return job + finally: + if job is not None: + job.delete() + @backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10) def test_generate_text_streaming() -> None: @@ -44,13 +82,21 @@ def test_generate_text() -> None: assert "bread" in responses.model_dump_json(indent=2) -@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10) -def test_generate_text_gcs() -> None: - responses = claude_3_batch_prediction_gcs.generate_text() - assert "bread" in responses.model_dump_json(indent=2) +def test_batch_gemini_predict_gcs(output_folder: pytest.fixture()) -> None: + output_uri = "gs://python-docs-samples-tests" + job = _main_test( + test_func=lambda: claude_3_batch_prediction_gcs.batch_predict_gemini_createjob( + output_uri + ) + ) + assert GCS_OUTPUT_PATH in job.output_location -@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10) -def test_generate_text_bq() -> None: - responses = claude_3_batch_prediciton_bq.generate_text() - assert "bread" in responses.model_dump_json(indent=2) +def test_batch_gemini_predict_bigquery(output_folder: pytest.fixture()) -> None: + output_uri = f"bq://{PROJECT_ID}.gen_ai_batch_prediction.predictions" + job = _main_test( + test_func=lambda: claude_3_batch_prediciton_bq.batch_predict_gemini_createjob( + output_uri + ) + ) + assert OUTPUT_TABLE in job.output_location