Skip to content

Commit cb81d11

Browse files
committed
rebase with master
1 parent 4cef235 commit cb81d11

File tree

7 files changed

+9
-6
lines changed

7 files changed

+9
-6
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
get_wildcard_model_version_msg,
4141
get_wildcard_proprietary_model_version_msg,
4242
)
43-
from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg
4443
from sagemaker.jumpstart.parameters import (
4544
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
4645
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
@@ -480,6 +479,7 @@ def _retrieval_function(
480479
return JumpStartCachedContentValue(
481480
formatted_content=model_specs
482481
)
482+
483483
if data_type == HubType.HUB:
484484
hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info)
485485
response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name)

src/sagemaker/jumpstart/types.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from copy import deepcopy
1616
from enum import Enum
1717
from typing import Any, Dict, List, Optional, Set, Union
18-
from sagemaker.session import Session
1918
from sagemaker.utils import get_instance_type_family, format_tags, Tags
2019
from sagemaker.enums import EndpointType
2120
from sagemaker.model_metrics import ModelMetrics

src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
import dataclasses
2525
import json
2626

27-
import sagemaker
28-
2927

3028
class _UTCFormatter(logging.Formatter):
3129
"""Class that overrides the default local time provider in log formatter."""

src/sagemaker/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import random
2323
import re
2424
import shutil
25-
import sys
2625
import tarfile
2726
import tempfile
2827
import time

tests/unit/sagemaker/jumpstart/test_accessors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def test_jumpstart_proprietary_models_cache_get(mock_cache):
137137
> 0
138138
)
139139

140+
140141
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
141142
def test_jumpstart_models_cache_get_model_specs(mock_cache):
142143
mock_cache.get_specs = Mock()

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
2929
JumpStartModelsCache,
3030
)
31-
from sagemaker.session_settings import SessionSettings
3231
from sagemaker.jumpstart.constants import (
3332
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
3433
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,

tests/unit/sagemaker/jumpstart/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,13 @@ def patched_retrieval_function(
253253
model_type=JumpStartModelType.PROPRIETARY,
254254
)
255255
)
256+
257+
if datatype == HubContentType.MODEL:
258+
_, _, _, model_name, model_version = id_info.split("/")
259+
return JumpStartCachedContentValue(
260+
formatted_content=get_spec_from_base_spec(model_id=model_name, version=model_version)
261+
)
262+
256263
# TODO: Implement
257264
if datatype == HubType.HUB:
258265
return None

0 commit comments

Comments
 (0)