tests: Implement integration tests covering JumpStart PrivateHub workflows #4883
@@ -16,24 +16,44 @@
import boto3
import pytest
from botocore.config import Config
from sagemaker.jumpstart.hub.hub import Hub
from sagemaker.session import Session
from tests.integ.sagemaker.jumpstart.constants import (
    ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
    ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME,
    HUB_NAME_PREFIX,
    JUMPSTART_TAG,
    SM_JUMPSTART_PUBLIC_HUB_NAME,
)

from sagemaker.jumpstart.types import (
    HubContentType,
)

from tests.integ.sagemaker.jumpstart.utils import (
    get_test_artifact_bucket,
    get_test_suite_id,
    get_sm_session,
)

from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME


def _setup():
    print("Setting up...")
    test_suite_id = get_test_suite_id()
    test_hub_name = f"{HUB_NAME_PREFIX}{test_suite_id}"
    test_hub_description = "PySDK Integ Test Private Hub"

    os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: test_suite_id})
    os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME: test_hub_name})

    # Create a private hub to use for the test session
    hub = Hub(
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
        sagemaker_session=get_sm_session(),
    )
    hub.create(description=test_hub_description)

Reviewer: should we necessarily create a Hub every time a JS integ test is run? Does this bring any problems?
Author: I don't think of any problems with this strategy, to be honest; we're cleaning it up at the end. Do you think we should be approaching it differently?


def _teardown():
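The setup above keys every resource (env vars, hub name) off a per-run test suite id, which is what keeps concurrent pytest sessions in the same account/region from colliding. A minimal sketch of how such an id could be generated — the uuid-based scheme and the `make_test_suite_id` name are assumptions for illustration; the real `get_test_suite_id` may differ:

```python
import uuid


def make_test_suite_id() -> str:
    # A short random suffix keeps concurrent pytest sessions from
    # colliding on hub names in the same account/region.
    return uuid.uuid4().hex[:12]


suite_id = make_test_suite_id()
hub_name = f"PySDK-HubTest-{suite_id}"
```

Because the hub name embeds the suite id, teardown can target exactly the hub this session created rather than everything matching a prefix.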
@@ -113,6 +133,37 @@ def _teardown():
    bucket = s3_resource.Bucket(test_cache_bucket)
    bucket.objects.filter(Prefix=test_suite_id + "/").delete()

    # delete private hubs
    _delete_hubs(sagemaker_session)


def _delete_hubs(sagemaker_session):
    # list hubs created by PySDK integration tests
    list_hub_response = sagemaker_session.list_hubs(
        name_contains=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
    )

Reviewer: can we create a utility to get the hub name from the env var?
Reviewer: question: would there ever be several hubs here?
Author: yeah, there'll be only one; I changed this recently. Previously I was deleting all the hubs starting with a specific prefix, but that was messing up the concurrent pytest executions. Let me not use list_hubs here, thanks.

    for hub in list_hub_response["HubSummaries"]:
        if hub["HubName"] != SM_JUMPSTART_PUBLIC_HUB_NAME:
            # delete all hub contents first
            _delete_hub_contents(sagemaker_session, hub["HubName"])
            sagemaker_session.delete_hub(hub["HubName"])

Reviewer: while we don't delete the public hub, should we also restrict deletion to only the hubs associated with a PySDK integ test? This could delete hubs we don't want to in the account that runs the integ tests.
Author: I am already filtering to only the hubs related to the PySDK integ test on line 142.
Reviewer: concurrent runs of the test will cause an issue with this clean-up strategy. Should it clean up only the hub created by this specific test?
Author: yeah, I actually had to change it later on; I have updated the commit.
Reviewer: isn't the suite id unique to each run? That would prevent overlapping test sessions on the same account/region from interfering with each other.
Author: yes, but I wasn't deleting the unique hub created by a specific pytest session; I was deleting all the hubs created by the PySDK tests, regardless of which session created them.
Reviewer: careful, if there is more than one hub to delete, you might get throttled by the TPS limit of 1.
Author: there'll be only one hub at a time.


def _delete_hub_contents(sagemaker_session, test_hub_name):

Reviewer: nit: this deletes only the model references (not models or other contents); recommend taking an optional arg for which content type to delete.

    # list hub contents for the given hub
    list_hub_content_response = sagemaker_session.list_hub_contents(
        hub_name=test_hub_name, hub_content_type=HubContentType.MODEL_REFERENCE.value
    )

    # delete hub contents for the given hub
    for model in list_hub_content_response["HubContentSummaries"]:
        sagemaker_session.delete_hub_content_reference(
            hub_name=test_hub_name,
            hub_content_type=HubContentType.MODEL_REFERENCE.value,
            hub_content_name=model["HubContentName"],
        )

Reviewer: careful with throttling here.


@pytest.fixture(scope="session", autouse=True)
def setup(request):
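Regarding the throttling warnings above (delete APIs with a TPS limit of 1), one option is to wrap each delete call in an exponential backoff. This is a hedged sketch, not the SDK's actual behavior: `ThrottlingError`, `delete_with_backoff`, and the stubbed `flaky_delete` are hypothetical stand-ins for the real botocore throttling exception and SageMaker delete calls.

```python
import time


class ThrottlingError(Exception):
    """Hypothetical stand-in for the service's throttling exception."""


def delete_with_backoff(delete_fn, max_attempts=5, base_delay=0.01):
    # Retry a delete call with exponential backoff when throttled.
    for attempt in range(max_attempts):
        try:
            return delete_fn()
        except ThrottlingError:
            if attempt == max_attempts - 1:
                raise
            time.sleep(base_delay * (2 ** attempt))


# Stub that throttles twice, then succeeds.
calls = {"n": 0}


def flaky_delete():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ThrottlingError()
    return "deleted"


result = delete_with_backoff(flaky_delete)
```

In the real teardown the same wrapper could surround `delete_hub_content_reference` and `delete_hub`, which matters if the content-type nit above is adopted and more than one item ends up being deleted per hub.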
@@ -37,8 +37,13 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:

ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID = "JUMPSTART_SDK_TEST_SUITE_ID"

ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME = "JUMPSTART_SDK_TEST_HUB_NAME"

JUMPSTART_TAG = "JumpStart-SDK-Integ-Test-Suite-Id"

SM_JUMPSTART_PUBLIC_HUB_NAME = "SageMakerPublicHub"

Reviewer: nit: likely this is already defined somewhere in src.

HUB_NAME_PREFIX = "PySDK-HubTest-"

TRAINING_DATASET_MODEL_DICT = {
    ("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"),
@@ -0,0 +1,174 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os
import time

import pytest
from sagemaker.enums import EndpointType
from sagemaker.jumpstart.hub.hub import Hub
from sagemaker.jumpstart.hub.utils import generate_hub_arn_for_init_kwargs
from sagemaker.predictor import retrieve_default

from sagemaker.jumpstart.constants import JUMPSTART_LOGGER

import tests.integ

from sagemaker.jumpstart.model import JumpStartModel
from tests.integ.sagemaker.jumpstart.constants import (
    ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME,
    ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
    JUMPSTART_TAG,
)
from tests.integ.sagemaker.jumpstart.utils import (
    get_public_hub_model_arn,
    get_sm_session,
)

MAX_INIT_TIME_SECONDS = 5

TEST_MODEL_IDS = {
    "catboost-classification-model",
    "huggingface-txt2img-conflictx-complex-lineart",
    "meta-textgeneration-llama-2-7b",
    "meta-textgeneration-llama-3-2-1b",
    "catboost-regression-model",
}

Reviewer (on lines +40 to +46): I think the integ test runs in PDX. Double check these models are available in that region.
Author: they should be; I chose these models from the existing JumpStart hub integ tests.
@pytest.fixture(scope="session")
def add_models():
    # Create model references to test in the hub
    hub_instance = Hub(
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
        sagemaker_session=get_sm_session(),
    )
    for model in TEST_MODEL_IDS:
        hub_instance.create_model_reference(model_arn=get_public_hub_model_arn(hub_instance, model))

Reviewer: nit: consider renaming this fixture.
Author: sure, thanks.
def test_jumpstart_hub_model(setup, add_models):
    JUMPSTART_LOGGER.info("starting test")
    JUMPSTART_LOGGER.info(f"get identity {get_sm_session().get_caller_identity_arn()}")

    model_id = "catboost-classification-model"

Reviewer: can we make …

    model = JumpStartModel(
        model_id=model_id,
        role=get_sm_session().get_caller_identity_arn(),
        sagemaker_session=get_sm_session(),
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
    )

Reviewer: can we test via the HubArn path as well, since we noticed an issue around that once?
Author: sorry, I didn't get it; this is the HubArn path, right? We ask customers to provide hub_name in the JumpStart model parameter, but we convert it into a HubArn right after that. Customers can also provide the ARN directly to the model class, but in that case we just leave it as-is and it gets passed through to the rest of the code.

    # uses ml.m5.4xlarge instance
    model.deploy(
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
    )

Reviewer: we should assert the success status of the endpoint by adding a wait step to poll on status.
Reviewer: +1.
Reviewer: +1; we may even want to consider sending it one request using the default payload, although that's optional.
Author: I think I missed that default payload request here. Similar to the other inference tests, let me add it here.
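The reviewers' point above is that `deploy()` alone doesn't prove the endpoint came up healthy; an explicit poll-on-status wait step makes the assertion real. A sketch of such a wait loop — the stubbed `describe_fn` stands in for `sagemaker_client.describe_endpoint` (or a boto3 waiter), which is assumed rather than shown:

```python
import time


def wait_for_endpoint(describe_fn, poll_seconds=0.0, max_polls=60):
    # Poll the describe call until the endpoint leaves the "Creating" state.
    for _ in range(max_polls):
        status = describe_fn()["EndpointStatus"]
        if status != "Creating":
            return status
        time.sleep(poll_seconds)
    raise TimeoutError("endpoint never left the Creating state")


# Stubbed describe call: two in-progress responses, then success.
statuses = iter(["Creating", "Creating", "InService"])
final_status = wait_for_endpoint(lambda: {"EndpointStatus": next(statuses)})
```

In the test itself the result would be asserted, e.g. `assert final_status == "InService"`, before sending a sample inference request.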
def test_jumpstart_hub_gated_model(setup, add_models):
    model_id = "meta-textgeneration-llama-3-2-1b"

    model = JumpStartModel(
        model_id=model_id,
        role=get_sm_session().get_caller_identity_arn(),
        sagemaker_session=get_sm_session(),
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
    )

    # uses ml.g6.xlarge instance

Reviewer: this might change unless you pin the model version.

    predictor = model.deploy(
        accept_eula=True,
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
    )

    payload = {
        "inputs": "some-payload",
        "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
    }

Reviewer: retrieve the payload from the SDK utility, please; otherwise this might fail if we ever release a new version.

    response = predictor.predict(payload, custom_attributes="accept_eula=true")

Reviewer: remove the …

    assert response is not None
def test_jumpstart_gated_model_inference_component_enabled(setup, add_models):
    model_id = "meta-textgeneration-llama-2-7b"

    hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]

    region = tests.integ.test_region()

    sagemaker_session = get_sm_session()

    hub_arn = generate_hub_arn_for_init_kwargs(
        hub_name=hub_name, region=region, session=sagemaker_session
    )

    model = JumpStartModel(
        model_id=model_id,
        role=get_sm_session().get_caller_identity_arn(),
        sagemaker_session=sagemaker_session,
        hub_name=hub_name,
    )

    # uses ml.g5.2xlarge instance

Reviewer: same comment (pin the model version).

    model.deploy(
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
        accept_eula=True,
        endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
    )

    predictor = retrieve_default(
        endpoint_name=model.endpoint_name,
        sagemaker_session=sagemaker_session,
        tolerate_vulnerable_model=True,
        hub_arn=hub_arn,
    )

    payload = {
        "inputs": "some-payload",
        "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
    }

Reviewer: same comment, please (retrieve the payload from the SDK utility).

    response = predictor.predict(payload)

    assert response is not None

    model = JumpStartModel.attach(
        predictor.endpoint_name, sagemaker_session=sagemaker_session, hub_name=hub_name
    )
    assert model.model_id == model_id
    assert model.endpoint_name == predictor.endpoint_name
    assert model.inference_component_name == predictor.component_name
def test_instantiating_model(setup, add_models):

Reviewer: I wonder if we can run all the current JumpStart tests, but against a PrivateHub. The risk is that this test coverage drifts from the non-private-hub tests as new models/features are added.
Author: I am not sure; are you thinking of reusing the test code of the JumpStart public hub tests and integrating PrivateHub workflow triggers there? Won't it be too complicated? Keeping them separate improves readability. Also, not all features available for content type Model are available for content type ModelReference. Sorry, I can't understand what you mean by this.
Reviewer: I think Evan is suggesting defining a single set of model tests and systematically running them against both the public hub and a private hub. That would be substantial rework from your current PR, though.
Reviewer: fix typo please: …
Author: +1, that would require more effort and rework. I can take it as an improvement, but my question is: what's the benefit? Is it better design, or just a personal preference?

    model_id = "catboost-regression-model"

    start_time = time.perf_counter()

    JumpStartModel(
        model_id=model_id,
        role=get_sm_session().get_caller_identity_arn(),
        sagemaker_session=get_sm_session(),
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
    )

    elapsed_time = time.perf_counter() - start_time

    assert elapsed_time <= MAX_INIT_TIME_SECONDS
@@ -0,0 +1,39 @@
import pytest
from sagemaker.jumpstart.hub.hub import Hub

from tests.integ.sagemaker.jumpstart.utils import (
    get_sm_session,
    get_test_suite_id,
)
from tests.integ.sagemaker.jumpstart.constants import (
    HUB_NAME_PREFIX,
)


@pytest.fixture
def hub_instance():
    HUB_NAME = f"{HUB_NAME_PREFIX}-{get_test_suite_id()}"
    hub = Hub(HUB_NAME, sagemaker_session=get_sm_session())
    yield hub


def test_private_hub(setup, hub_instance):
    # Create hub
    create_hub_response = hub_instance.create(
        description="This is a Test Private Hub.",
        display_name="PySDK integration tests Hub",
        search_keywords=["jumpstart-sdk-integ-test"],
    )

    # Create hub verifications
    assert create_hub_response is not None

    # Describe hub
    hub_description = hub_instance.describe()
    assert hub_description is not None

    # Delete hub
    delete_hub_response = hub_instance.delete()
    assert delete_hub_response is not None
@@ -0,0 +1,34 @@
import os
from sagemaker.jumpstart.hub.hub import Hub

from sagemaker.jumpstart.hub.interfaces import DescribeHubContentResponse
from tests.integ.sagemaker.jumpstart.utils import (
    get_sm_session,
    get_public_hub_model_arn,
)
from tests.integ.sagemaker.jumpstart.constants import (
    ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME,
)


def test_hub_model_reference(setup):
    model_id = "meta-textgenerationneuron-llama-3-2-1b-instruct"

    hub_instance = Hub(
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
        sagemaker_session=get_sm_session(),
    )

    # Create model reference

Reviewer: nit: unnecessary comment.

    create_model_response = hub_instance.create_model_reference(
        model_arn=get_public_hub_model_arn(hub_instance, model_id)
    )
    assert create_model_response is not None

    # Describe model
    describe_model_response = hub_instance.describe_model(model_name=model_id)
    assert describe_model_response is not None
    assert isinstance(describe_model_response, DescribeHubContentResponse)

Reviewer: optional: consider asserting that …
Reviewer: +1, could you address this if you have time? I think what we are currently checking is pretty shallow; no harm in asserting these in the tests.

    # Delete model reference
    delete_model_response = hub_instance.delete_model_reference(model_name=model_id)
    assert delete_model_response is not None
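The reviewers above suggest going beyond `is not None` checks. A sketch of the kind of field-level assertions they likely mean, against a stubbed describe response — the dictionary shape and field names here are assumptions for illustration, not the actual `DescribeHubContentResponse` attributes:

```python
model_id = "meta-textgenerationneuron-llama-3-2-1b-instruct"

# Stubbed describe response; the real response object's attribute
# names may differ from these hypothetical keys.
describe_model_response = {
    "hub_content_name": model_id,
    "hub_content_type": "ModelReference",
}

# Field-level checks catch regressions that a bare `is not None` misses,
# e.g. the wrong content being returned for the requested model.
assert describe_model_response["hub_content_name"] == model_id
assert describe_model_response["hub_content_type"] == "ModelReference"
```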
@@ -32,6 +32,7 @@
)
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
from sagemaker.jumpstart.hub.hub import Hub

from sagemaker.session import Session

@@ -115,6 +116,20 @@ def download_file(local_download_path, s3_bucket, s3_key, s3_client) -> None:
    s3_client.download_file(s3_bucket, s3_key, local_download_path)


def get_public_hub_model_arn(hub: Hub, model_id: str) -> str:
    filter_value = f"model_id == {model_id}"
    response = hub.list_sagemaker_public_hub_models(filter=filter_value)

Reviewer: question: why not use a describe call? Or do you intend to test discovery through list? In that case, please rename the method accordingly.
Author: this is a util function for these tests, not a test itself. A describe call would return a public hub content ARN with the model version at the end, and create_model_reference doesn't accept an ARN with a model version in it. The list method, on the other hand, returns the model name and a public hub ARN that the create_model_reference API call accepts.

    models = response["hub_content_summaries"]
    while response["next_token"]:
        response = hub.list_sagemaker_public_hub_models(
            filter=filter_value, next_token=response["next_token"]
        )
        models.extend(response["hub_content_summaries"])

    return models[0]["hub_content_arn"]

Reviewer: if we only care about the first model ARN, do we need to paginate all the responses?
Author: you are right, we don't need to paginate here; in fact this function only returns a single model because I am filtering on the exact model id. The output is a list, so I am accessing the first element.
Reviewer: can you move this to the unit tests? Integ tests are only for when we need to create new resources.
Author: sorry, I didn't get that. This is the utility function to get the public hub content ARN, which in turn we use to create the model reference in the hub; why do we need to move it to unit tests?


class EndpointInvoker:
    def __init__(
        self,
Reviewer: can we add a unit test for this? It seems the current coverage didn't cover this bug.
Author: synced with @evakravi offline; this needs to be covered through unit tests and I'll add it as a fast follow.
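On the pagination question above: since the filter matches an exact model_id, the helper can return on the first match instead of draining every page. A stdlib sketch with faked pages — the snake_case keys mirror the ones used in `get_public_hub_model_arn`, and the ARN value is illustrative only:

```python
def first_matching_arn(pages):
    # Return the first summary's ARN instead of exhausting pagination.
    for page in pages:
        for summary in page["hub_content_summaries"]:
            return summary["hub_content_arn"]
    return None


# Faked paginated responses, as an iterator of pages; a real caller would
# wrap list_sagemaker_public_hub_models and its next_token loop instead.
pages = [
    {"hub_content_summaries": [], "next_token": "token-1"},
    {
        "hub_content_summaries": [
            {"hub_content_arn": "arn:aws:sagemaker:us-west-2:aws:hub-content/SageMakerPublicHub/Model/example"}
        ],
        "next_token": None,
    },
]
arn = first_matching_arn(pages)
```

This also sidesteps the throttling concern: at most one extra list call happens after the match is found, rather than one per remaining page.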