diff --git a/generative_ai/embeddings/batch_example.py b/generative_ai/embeddings/batch_example.py index 91be92de79b..bffb7419ae4 100644 --- a/generative_ai/embeddings/batch_example.py +++ b/generative_ai/embeddings/batch_example.py @@ -16,10 +16,9 @@ from google.cloud.aiplatform import BatchPredictionJob PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT") -OUTPUT_URI = os.getenv("GCS_OUTPUT_URI") -def embed_text_batch() -> BatchPredictionJob: +def embed_text_batch(OUTPUT_URI: str) -> BatchPredictionJob: """Example of how to generate embeddings from text using batch processing. Read more: https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/batch-prediction-genai-embeddings diff --git a/generative_ai/embeddings/test_embeddings_examples.py b/generative_ai/embeddings/test_embeddings_examples.py index b4472d25a56..b430b978e2c 100644 --- a/generative_ai/embeddings/test_embeddings_examples.py +++ b/generative_ai/embeddings/test_embeddings_examples.py @@ -22,7 +22,6 @@ from google.cloud import aiplatform from google.cloud.aiplatform import initializer as aiplatform_init -import pytest import batch_example import code_retrieval_example @@ -35,10 +34,8 @@ @backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10) -@pytest.fixture(scope="session") def test_embed_text_batch() -> None: - os.environ["GCS_OUTPUT_URI"] = "gs://python-docs-samples-tests/" - batch_prediction_job = batch_example.embed_text_batch() + batch_prediction_job = batch_example.embed_text_batch("gs://python-docs-samples-tests/") assert batch_prediction_job