Skip to content

Commit 7cf54a7

Browse files
authored
Add CountTokensOperator for Google Generative AI CountTokensAPI (#41908)
* Add CountTokensOperator for Google Generative AI CountTokensAPI * Update system test DAG with correct arguments
1 parent 2823acd commit 7cf54a7

File tree

6 files changed

+170
-4
lines changed

6 files changed

+170
-4
lines changed

airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
3333

3434
if TYPE_CHECKING:
35-
from google.cloud.aiplatform_v1 import types
35+
from google.cloud.aiplatform_v1 import types as types_v1
36+
from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
3637

3738

3839
class GenerativeModelHook(GoogleBaseHook):
@@ -367,7 +368,7 @@ def supervised_fine_tuning_train(
367368
adapter_size: int | None = None,
368369
learning_rate_multiplier: float | None = None,
369370
project_id: str = PROVIDE_PROJECT_ID,
370-
) -> types.TuningJob:
371+
) -> types_v1.TuningJob:
371372
"""
372373
Use the Supervised Fine Tuning API to create a tuning job.
373374
@@ -406,3 +407,32 @@ def supervised_fine_tuning_train(
406407
sft_tuning_job.refresh()
407408

408409
return sft_tuning_job
410+
411+
@GoogleBaseHook.fallback_to_default_project_id
412+
def count_tokens(
413+
self,
414+
contents: list,
415+
location: str,
416+
pretrained_model: str = "gemini-pro",
417+
project_id: str = PROVIDE_PROJECT_ID,
418+
) -> types_v1beta1.CountTokensResponse:
419+
"""
420+
Use the Vertex AI Count Tokens API to calculate the number of input tokens before sending a request to the Gemini API.
421+
422+
:param contents: Required. The multi-part content of a message that a user or a program
423+
gives to the generative model, in order to elicit a specific response.
424+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
425+
:param pretrained_model: By default uses the pre-trained model `gemini-pro`,
426+
supporting prompts with text-only input, including natural language
427+
tasks, multi-turn text and code chat, and code generation. It can
428+
output text and code.
429+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
430+
"""
431+
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
432+
433+
model = self.get_generative_model(pretrained_model)
434+
response = model.count_tokens(
435+
contents=contents,
436+
)
437+
438+
return response

airflow/providers/google/cloud/operators/vertex_ai/generative_model.py

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121

2222
from typing import TYPE_CHECKING, Sequence
2323

24-
from google.cloud.aiplatform_v1 import types
24+
from google.cloud.aiplatform_v1 import types as types_v1
25+
from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
2526

2627
from airflow.exceptions import AirflowProviderDeprecationWarning
2728
from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook
@@ -665,4 +666,73 @@ def execute(self, context: Context):
665666
self.xcom_push(context, key="tuned_model_name", value=response.tuned_model_name)
666667
self.xcom_push(context, key="tuned_model_endpoint_name", value=response.tuned_model_endpoint_name)
667668

668-
return types.TuningJob.to_dict(response)
669+
return types_v1.TuningJob.to_dict(response)
670+
671+
672+
class CountTokensOperator(GoogleCloudBaseOperator):
673+
"""
674+
Use the Vertex AI Count Tokens API to calculate the number of input tokens before sending a request to the Gemini API.
675+
676+
:param project_id: Required. The ID of the Google Cloud project that the
677+
service belongs to (templated).
678+
:param contents: Required. The multi-part content of a message that a user or a program
679+
gives to the generative model, in order to elicit a specific response.
680+
:param location: Required. The ID of the Google Cloud location that the
681+
service belongs to (templated).
682+
:param system_instruction: Optional. Instructions for the model to steer it toward better
683+
performance. For example, "Answer as concisely as possible"
684+
:param pretrained_model: By default uses the pre-trained model `gemini-pro`,
685+
supporting prompts with text-only input, including natural language
686+
tasks, multi-turn text and code chat, and code generation. It can
687+
output text and code.
688+
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
689+
:param impersonation_chain: Optional service account to impersonate using short-term
690+
credentials, or chained list of accounts required to get the access_token
691+
of the last account in the list, which will be impersonated in the request.
692+
If set as a string, the account must grant the originating account
693+
the Service Account Token Creator IAM role.
694+
If set as a sequence, the identities from the list must grant
695+
Service Account Token Creator IAM role to the directly preceding identity, with first
696+
account from the list granting this role to the originating account (templated).
697+
"""
698+
699+
template_fields = ("location", "project_id", "impersonation_chain", "contents", "pretrained_model")
700+
701+
def __init__(
702+
self,
703+
*,
704+
project_id: str,
705+
contents: list,
706+
location: str,
707+
pretrained_model: str = "gemini-pro",
708+
gcp_conn_id: str = "google_cloud_default",
709+
impersonation_chain: str | Sequence[str] | None = None,
710+
**kwargs,
711+
) -> None:
712+
super().__init__(**kwargs)
713+
self.project_id = project_id
714+
self.location = location
715+
self.contents = contents
716+
self.pretrained_model = pretrained_model
717+
self.gcp_conn_id = gcp_conn_id
718+
self.impersonation_chain = impersonation_chain
719+
720+
def execute(self, context: Context):
721+
self.hook = GenerativeModelHook(
722+
gcp_conn_id=self.gcp_conn_id,
723+
impersonation_chain=self.impersonation_chain,
724+
)
725+
response = self.hook.count_tokens(
726+
project_id=self.project_id,
727+
location=self.location,
728+
contents=self.contents,
729+
pretrained_model=self.pretrained_model,
730+
)
731+
732+
self.log.info("Total tokens: %s", response.total_tokens)
733+
self.log.info("Total billable characters: %s", response.total_billable_characters)
734+
735+
self.xcom_push(context, key="total_tokens", value=response.total_tokens)
736+
self.xcom_push(context, key="total_billable_characters", value=response.total_billable_characters)
737+
738+
return types_v1beta1.CountTokensResponse.to_dict(response)

docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,17 @@ The operator returns the tuned model's endpoint name in :ref:`XCom <concepts:xco
625625
:start-after: [START how_to_cloud_vertex_ai_supervised_fine_tuning_train_operator]
626626
:end-before: [END how_to_cloud_vertex_ai_supervised_fine_tuning_train_operator]
627627

628+
629+
To calculate the number of input tokens before sending a request to the Gemini API you can use:
630+
:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.CountTokensOperator`.
631+
The operator returns the total number of tokens in :ref:`XCom <concepts:xcom>` under the ``total_tokens`` key.
632+
633+
.. exampleinclude:: /../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
634+
:language: python
635+
:dedent: 4
636+
:start-after: [START how_to_cloud_vertex_ai_count_tokens_operator]
637+
:end-before: [END how_to_cloud_vertex_ai_count_tokens_operator]
638+
628639
Reference
629640
^^^^^^^^^
630641

tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,16 @@ def test_supervised_fine_tuning_train(self, mock_sft_train) -> None:
217217
learning_rate_multiplier=None,
218218
tuned_model_display_name=None,
219219
)
220+
221+
@mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_generative_model"))
222+
def test_count_tokens(self, mock_model) -> None:
223+
self.hook.count_tokens(
224+
project_id=GCP_PROJECT,
225+
contents=TEST_CONTENTS,
226+
location=GCP_LOCATION,
227+
pretrained_model=TEST_MULTIMODAL_PRETRAINED_MODEL,
228+
)
229+
mock_model.assert_called_once_with(TEST_MULTIMODAL_PRETRAINED_MODEL)
230+
mock_model.return_value.count_tokens.assert_called_once_with(
231+
contents=TEST_CONTENTS,
232+
)

tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,12 @@
2626

2727
# For no Pydantic environment, we need to skip the tests
2828
pytest.importorskip("google.cloud.aiplatform_v1")
29+
pytest.importorskip("google.cloud.aiplatform_v1beta1")
2930
vertexai = pytest.importorskip("vertexai.generative_models")
3031
from vertexai.generative_models import HarmBlockThreshold, HarmCategory, Tool, grounding
3132

3233
from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
34+
CountTokensOperator,
3335
GenerateTextEmbeddingsOperator,
3436
GenerativeModelGenerateContentOperator,
3537
PromptLanguageModelOperator,
@@ -417,3 +419,32 @@ def test_execute(
417419
tuned_model_display_name=None,
418420
validation_dataset=None,
419421
)
422+
423+
424+
class TestVertexAICountTokensOperator:
425+
@mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook"))
426+
@mock.patch("google.cloud.aiplatform_v1beta1.types.CountTokensResponse.to_dict")
427+
def test_execute(self, to_dict_mock, mock_hook):
428+
contents = ["In 10 words or less, what is Apache Airflow?"]
429+
pretrained_model = "gemini-pro"
430+
431+
op = CountTokensOperator(
432+
task_id=TASK_ID,
433+
project_id=GCP_PROJECT,
434+
location=GCP_LOCATION,
435+
contents=contents,
436+
pretrained_model=pretrained_model,
437+
gcp_conn_id=GCP_CONN_ID,
438+
impersonation_chain=IMPERSONATION_CHAIN,
439+
)
440+
op.execute(context={"ti": mock.MagicMock()})
441+
mock_hook.assert_called_once_with(
442+
gcp_conn_id=GCP_CONN_ID,
443+
impersonation_chain=IMPERSONATION_CHAIN,
444+
)
445+
mock_hook.return_value.count_tokens.assert_called_once_with(
446+
project_id=GCP_PROJECT,
447+
location=GCP_LOCATION,
448+
contents=contents,
449+
pretrained_model=pretrained_model,
450+
)

tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
from airflow.models.dag import DAG
3131
from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
32+
CountTokensOperator,
3233
GenerativeModelGenerateContentOperator,
3334
TextEmbeddingModelGetEmbeddingsOperator,
3435
TextGenerationModelPredictOperator,
@@ -84,6 +85,16 @@
8485
)
8586
# [END how_to_cloud_vertex_ai_text_embedding_model_get_embeddings_operator]
8687

88+
# [START how_to_cloud_vertex_ai_count_tokens_operator]
89+
count_tokens_task = CountTokensOperator(
90+
task_id="count_tokens_task",
91+
project_id=PROJECT_ID,
92+
contents=CONTENTS,
93+
location=REGION,
94+
pretrained_model=MULTIMODAL_MODEL,
95+
)
96+
# [END how_to_cloud_vertex_ai_count_tokens_operator]
97+
8798
# [START how_to_cloud_vertex_ai_generative_model_generate_content_operator]
8899
generate_content_task = GenerativeModelGenerateContentOperator(
89100
task_id="generate_content_task",

0 commit comments

Comments
 (0)