Skip to content

Commit 5003f35

Browse files
yishan-puYishan Pu
andauthored
fix: update tuning code samples to improve readability (#13384)
Co-authored-by: Yishan Pu <[email protected]>
1 parent b2f4217 commit 5003f35

7 files changed

+44
-36
lines changed

genai/tuning/test_tuning_examples.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def test_tuning_textgen_with_txt(mock_genai_client: MagicMock) -> None:
113113
mock_genai_client.return_value.tunings.get.return_value = mock_tuning_job
114114
mock_genai_client.return_value.models.generate_content.return_value = mock_response
115115

116-
tuning_textgen_with_txt.test_tuned_endpoint("test-tuning-job")
116+
tuning_textgen_with_txt.predict_with_tuned_endpoint("test-tuning-job")
117117

118118
mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1"))
119119
mock_genai_client.return_value.tunings.get.assert_called_once()
@@ -277,7 +277,7 @@ def test_tuning_with_checkpoints_textgen_with_txt(mock_genai_client: MagicMock)
277277
mock_genai_client.return_value.tunings.get.return_value = mock_tuning_job
278278
mock_genai_client.return_value.models.generate_content.return_value = mock_response
279279

280-
tuning_with_checkpoints_textgen_with_txt.test_checkpoint("test-tuning-job")
280+
tuning_with_checkpoints_textgen_with_txt.predict_with_checkpoints("test-tuning-job")
281281

282282
mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1"))
283283
mock_genai_client.return_value.tunings.get.assert_called_once()

genai/tuning/tuning_job_get.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,16 @@
1313
# limitations under the License.
1414

1515

16-
def get_tuning_job(name: str) -> str:
16+
def get_tuning_job(tuning_job_name: str) -> str:
1717
# [START googlegenaisdk_tuning_job_get]
1818
from google import genai
1919
from google.genai.types import HttpOptions
2020

2121
client = genai.Client(http_options=HttpOptions(api_version="v1"))
2222

2323
# Get the tuning job and the tuned model.
24-
# Eg. name = "projects/123456789012/locations/us-central1/tuningJobs/123456789012345"
25-
tuning_job = client.tunings.get(name=name)
24+
# Eg. tuning_job_name = "projects/123456789012/locations/us-central1/tuningJobs/123456789012345"
25+
tuning_job = client.tunings.get(name=tuning_job_name)
2626

2727
print(tuning_job.tuned_model.model)
2828
print(tuning_job.tuned_model.endpoint)
@@ -37,5 +37,5 @@ def get_tuning_job(name: str) -> str:
3737

3838

3939
if __name__ == "__main__":
40-
tuning_job_name = input("Tuning job name: ")
41-
get_tuning_job(tuning_job_name)
40+
input_tuning_job_name = input("Tuning job name: ")
41+
get_tuning_job(input_tuning_job_name)

genai/tuning/tuning_textgen_with_txt.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,32 @@
1313
# limitations under the License.
1414

1515

16-
def test_tuned_endpoint(name: str) -> str:
16+
def predict_with_tuned_endpoint(tuning_job_name: str) -> str:
1717
# [START googlegenaisdk_tuning_textgen_with_txt]
1818
from google import genai
1919
from google.genai.types import HttpOptions
2020

2121
client = genai.Client(http_options=HttpOptions(api_version="v1"))
2222

2323
# Get the tuning job and the tuned model.
24-
# Eg. name = "projects/123456789012/locations/us-central1/tuningJobs/123456789012345"
25-
tuning_job = client.tunings.get(name=name)
24+
# Eg. tuning_job_name = "projects/123456789012/locations/us-central1/tuningJobs/123456789012345"
25+
tuning_job = client.tunings.get(name=tuning_job_name)
2626

2727
contents = "Why is the sky blue?"
2828

29-
# Tests the default checkpoint
29+
# Predicts with the tuned endpoint.
3030
response = client.models.generate_content(
3131
model=tuning_job.tuned_model.endpoint,
3232
contents=contents,
3333
)
3434
print(response.text)
35+
# Example response:
36+
# The sky is blue because ...
3537

3638
# [END googlegenaisdk_tuning_textgen_with_txt]
3739
return response.text
3840

3941

4042
if __name__ == "__main__":
41-
tuning_job_name = input("Tuning job name: ")
42-
test_tuned_endpoint(tuning_job_name)
43+
input_tuning_job_name = input("Tuning job name: ")
44+
predict_with_tuned_endpoint(input_tuning_job_name)

genai/tuning/tuning_with_checkpoints_get_model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,16 @@
1313
# limitations under the License.
1414

1515

16-
def get_tuned_model_with_checkpoints(name: str) -> str:
16+
def get_tuned_model_with_checkpoints(tuning_job_name: str) -> str:
1717
# [START googlegenaisdk_tuning_with_checkpoints_get_model]
1818
from google import genai
1919
from google.genai.types import HttpOptions
2020

2121
client = genai.Client(http_options=HttpOptions(api_version="v1"))
2222

2323
# Get the tuning job and the tuned model.
24-
# Eg. name = "projects/123456789012/locations/us-central1/tuningJobs/123456789012345"
25-
tuning_job = client.tunings.get(name=name)
24+
# Eg. tuning_job_name = "projects/123456789012/locations/us-central1/tuningJobs/123456789012345"
25+
tuning_job = client.tunings.get(name=tuning_job_name)
2626
tuned_model = client.models.get(model=tuning_job.tuned_model.model)
2727
print(tuned_model)
2828
# Example response:
@@ -44,5 +44,5 @@ def get_tuned_model_with_checkpoints(name: str) -> str:
4444

4545

4646
if __name__ == "__main__":
47-
tuning_job_name = input("Tuning job name: ")
48-
get_tuned_model_with_checkpoints(tuning_job_name)
47+
input_tuning_job_name = input("Tuning job name: ")
48+
get_tuned_model_with_checkpoints(input_tuning_job_name)

genai/tuning/tuning_with_checkpoints_list_checkpoints.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,16 @@
1313
# limitations under the License.
1414

1515

16-
def list_checkpoints(name: str) -> str:
16+
def list_checkpoints(tuning_job_name: str) -> str:
1717
# [START googlegenaisdk_tuning_with_checkpoints_list_checkpoints]
1818
from google import genai
1919
from google.genai.types import HttpOptions
2020

2121
client = genai.Client(http_options=HttpOptions(api_version="v1"))
2222

2323
# Get the tuning job and the tuned model.
24-
# Eg. name = "projects/123456789012/locations/us-central1/tuningJobs/123456789012345"
25-
tuning_job = client.tunings.get(name=name)
24+
# Eg. tuning_job_name = "projects/123456789012/locations/us-central1/tuningJobs/123456789012345"
25+
tuning_job = client.tunings.get(name=tuning_job_name)
2626

2727
if tuning_job.tuned_model.checkpoints:
2828
for i, checkpoint in enumerate(tuning_job.tuned_model.checkpoints):
@@ -36,5 +36,5 @@ def list_checkpoints(name: str) -> str:
3636

3737

3838
if __name__ == "__main__":
39-
tuning_job_name = input("Tuning job name: ")
40-
list_checkpoints(tuning_job_name)
39+
input_tuning_job_name = input("Tuning job name: ")
40+
list_checkpoints(input_tuning_job_name)

genai/tuning/tuning_with_checkpoints_set_default_checkpoint.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,16 @@
1313
# limitations under the License.
1414

1515

16-
def set_default_checkpoint(name: str, checkpoint_id: str) -> str:
16+
def set_default_checkpoint(tuning_job_name: str, checkpoint_id: str) -> str:
1717
# [START googlegenaisdk_tuning_with_checkpoints_set_default]
1818
from google import genai
1919
from google.genai.types import HttpOptions, UpdateModelConfig
2020

2121
client = genai.Client(http_options=HttpOptions(api_version="v1"))
2222

2323
# Get the tuning job and the tuned model.
24-
# Eg. name = "projects/123456789012/locations/us-central1/tuningJobs/123456789012345"
25-
tuning_job = client.tunings.get(name=name)
24+
# Eg. tuning_job_name = "projects/123456789012/locations/us-central1/tuningJobs/123456789012345"
25+
tuning_job = client.tunings.get(name=tuning_job_name)
2626
tuned_model = client.models.get(model=tuning_job.tuned_model.model)
2727

2828
print(f"Default checkpoint: {tuned_model.default_checkpoint_id}")
@@ -49,6 +49,6 @@ def set_default_checkpoint(name: str, checkpoint_id: str) -> str:
4949

5050

5151
if __name__ == "__main__":
52-
tuning_job_name = input("Tuning job name: ")
52+
input_tuning_job_name = input("Tuning job name: ")
5353
default_checkpoint_id = input("Default checkpoint id: ")
54-
set_default_checkpoint(tuning_job_name, default_checkpoint_id)
54+
set_default_checkpoint(input_tuning_job_name, default_checkpoint_id)

genai/tuning/tuning_with_checkpoints_textgen_with_txt.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,44 +13,50 @@
1313
# limitations under the License.
1414

1515

16-
def test_checkpoint(name: str) -> str:
16+
def predict_with_checkpoints(tuning_job_name: str) -> str:
1717
# [START googlegenaisdk_tuning_with_checkpoints_test]
1818
from google import genai
1919
from google.genai.types import HttpOptions
2020

2121
client = genai.Client(http_options=HttpOptions(api_version="v1"))
2222

2323
# Get the tuning job and the tuned model.
24-
# Eg. name = "projects/123456789012/locations/us-central1/tuningJobs/123456789012345"
25-
tuning_job = client.tunings.get(name=name)
24+
# Eg. tuning_job_name = "projects/123456789012/locations/us-central1/tuningJobs/123456789012345"
25+
tuning_job = client.tunings.get(name=tuning_job_name)
2626

2727
contents = "Why is the sky blue?"
2828

29-
# Tests the default checkpoint
29+
# Predicts with the default checkpoint.
3030
response = client.models.generate_content(
3131
model=tuning_job.tuned_model.endpoint,
3232
contents=contents,
3333
)
3434
print(response.text)
35+
# Example response:
36+
# The sky is blue because ...
3537

36-
# Tests Checkpoint 1
38+
# Predicts with Checkpoint 1.
3739
checkpoint1_response = client.models.generate_content(
3840
model=tuning_job.tuned_model.checkpoints[0].endpoint,
3941
contents=contents,
4042
)
4143
print(checkpoint1_response.text)
44+
# Example response:
45+
# The sky is blue because ...
4246

43-
# Tests Checkpoint 2
47+
# Predicts with Checkpoint 2.
4448
checkpoint2_response = client.models.generate_content(
4549
model=tuning_job.tuned_model.checkpoints[1].endpoint,
4650
contents=contents,
4751
)
4852
print(checkpoint2_response.text)
53+
# Example response:
54+
# The sky is blue because ...
4955

5056
# [END googlegenaisdk_tuning_with_checkpoints_test]
5157
return response.text
5258

5359

5460
if __name__ == "__main__":
55-
tuning_job_name = input("Tuning job name: ")
56-
test_checkpoint(tuning_job_name)
61+
input_tuning_job_name = input("Tuning job name: ")
62+
predict_with_checkpoints(input_tuning_job_name)

0 commit comments

Comments
 (0)