Skip to content

Commit c7d3bd2

Browse files
committed
fix: unit tests
1 parent e070e78 commit c7d3bd2

34 files changed

+141
-108
lines changed

src/sagemaker/accept_types.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from typing import List, Optional
1616

1717
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
18+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
1819
from sagemaker.session import Session
1920

2021

@@ -24,7 +25,7 @@ def retrieve_options(
2425
model_version: Optional[str] = None,
2526
tolerate_vulnerable_model: bool = False,
2627
tolerate_deprecated_model: bool = False,
27-
sagemaker_session: Session = Session(),
28+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2829
) -> List[str]:
2930
"""Retrieves the supported accept types for the model matching the given arguments.
3031
@@ -73,7 +74,7 @@ def retrieve_default(
7374
model_version: Optional[str] = None,
7475
tolerate_vulnerable_model: bool = False,
7576
tolerate_deprecated_model: bool = False,
76-
sagemaker_session: Session = Session(),
77+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
7778
) -> str:
7879
"""Retrieves the default accept type for the model matching the given arguments.
7980

src/sagemaker/content_types.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from typing import List, Optional
1616

1717
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
18+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
1819
from sagemaker.session import Session
1920

2021

@@ -24,7 +25,7 @@ def retrieve_options(
2425
model_version: Optional[str] = None,
2526
tolerate_vulnerable_model: bool = False,
2627
tolerate_deprecated_model: bool = False,
27-
sagemaker_session: Session = Session(),
28+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2829
) -> List[str]:
2930
"""Retrieves the supported content types for the model matching the given arguments.
3031
@@ -73,7 +74,7 @@ def retrieve_default(
7374
model_version: Optional[str] = None,
7475
tolerate_vulnerable_model: bool = False,
7576
tolerate_deprecated_model: bool = False,
76-
sagemaker_session: Session = Session(),
77+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
7778
) -> str:
7879
"""Retrieves the default content type for the model matching the given arguments.
7980

src/sagemaker/deserializers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
)
3434

3535
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
36+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
3637
from sagemaker.session import Session
3738

3839

@@ -42,7 +43,7 @@ def retrieve_options(
4243
model_version: Optional[str] = None,
4344
tolerate_vulnerable_model: bool = False,
4445
tolerate_deprecated_model: bool = False,
45-
sagemaker_session: Session = Session(),
46+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
4647
) -> List[BaseDeserializer]:
4748
"""Retrieves the supported deserializers for the model matching the given arguments.
4849
@@ -92,7 +93,7 @@ def retrieve_default(
9293
model_version: Optional[str] = None,
9394
tolerate_vulnerable_model: bool = False,
9495
tolerate_deprecated_model: bool = False,
95-
sagemaker_session: Session = Session(),
96+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
9697
) -> BaseDeserializer:
9798
"""Retrieves the default deserializer for the model matching the given arguments.
9899

src/sagemaker/environment_variables.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from sagemaker.jumpstart import utils as jumpstart_utils
2121
from sagemaker.jumpstart import artifacts
22+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
2223
from sagemaker.session import Session
2324

2425
logger = logging.getLogger(__name__)
@@ -31,7 +32,7 @@ def retrieve_default(
3132
tolerate_vulnerable_model: bool = False,
3233
tolerate_deprecated_model: bool = False,
3334
include_aws_sdk_env_vars: bool = True,
34-
sagemaker_session: Session = Session(),
35+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3536
) -> Dict[str, str]:
3637
"""Retrieves the default container environment variables for the model matching the arguments.
3738

src/sagemaker/hyperparameters.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from sagemaker.jumpstart import utils as jumpstart_utils
2121
from sagemaker.jumpstart import artifacts
22+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
2223
from sagemaker.jumpstart.enums import HyperparameterValidationMode
2324
from sagemaker.jumpstart.validators import validate_hyperparameters
2425
from sagemaker.session import Session
@@ -33,7 +34,7 @@ def retrieve_default(
3334
include_container_hyperparameters: bool = False,
3435
tolerate_vulnerable_model: bool = False,
3536
tolerate_deprecated_model: bool = False,
36-
sagemaker_session: Session = Session(),
37+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3738
) -> Dict[str, str]:
3839
"""Retrieves the default training hyperparameters for the model matching the given arguments.
3940
@@ -92,7 +93,7 @@ def validate(
9293
validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED,
9394
tolerate_vulnerable_model: bool = False,
9495
tolerate_deprecated_model: bool = False,
95-
sagemaker_session: Session = Session(),
96+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
9697
) -> None:
9798
"""Validates hyperparameters for models.
9899

src/sagemaker/image_uris.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
from packaging.version import Version
2222

2323
from sagemaker import utils
24+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
2425
from sagemaker.jumpstart.utils import is_jumpstart_model_input
25-
from sagemaker.session import Session
2626
from sagemaker.spark import defaults
2727
from sagemaker.jumpstart import artifacts
2828
from sagemaker.workflow import is_pipeline_variable
@@ -61,7 +61,7 @@ def retrieve(
6161
sdk_version=None,
6262
inference_tool=None,
6363
serverless_inference_config=None,
64-
sagemaker_session=Session(),
64+
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
6565
) -> str:
6666
"""Retrieves the ECR URI for the Docker image matching the given arguments.
6767
@@ -114,7 +114,7 @@ def retrieve(
114114
sagemaker_session (sagemaker.session.Session): A SageMaker Session
115115
object, used for SageMaker interactions. If not
116116
specified, one is created using the default AWS configuration
117-
chain. (Default: Session()).
117+
chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
118118
119119
Returns:
120120
str: The ECR URI for the corresponding SageMaker Docker image.

src/sagemaker/instance_types.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from sagemaker.jumpstart import utils as jumpstart_utils
2121
from sagemaker.jumpstart import artifacts
22+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
2223
from sagemaker.session import Session
2324

2425
logger = logging.getLogger(__name__)
@@ -31,7 +32,7 @@ def retrieve_default(
3132
scope: Optional[str] = None,
3233
tolerate_vulnerable_model: bool = False,
3334
tolerate_deprecated_model: bool = False,
34-
sagemaker_session: Session = Session(),
35+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3536
) -> str:
3637
"""Retrieves the default instance type for the model matching the given arguments.
3738
@@ -87,7 +88,7 @@ def retrieve(
8788
scope: Optional[str] = None,
8889
tolerate_vulnerable_model: bool = False,
8990
tolerate_deprecated_model: bool = False,
90-
sagemaker_session: Session = Session(),
91+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
9192
) -> List[str]:
9293
"""Retrieves the supported training instance types for the model matching the given arguments.
9394

src/sagemaker/jumpstart/artifacts/environment_variables.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515
from typing import Dict, Optional
1616
from sagemaker.jumpstart.constants import (
17+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
1718
JUMPSTART_DEFAULT_REGION_NAME,
1819
)
1920
from sagemaker.jumpstart.enums import (
@@ -32,7 +33,7 @@ def _retrieve_default_environment_variables(
3233
tolerate_vulnerable_model: bool = False,
3334
tolerate_deprecated_model: bool = False,
3435
include_aws_sdk_env_vars: bool = True,
35-
sagemaker_session: Session = Session(),
36+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3637
) -> Dict[str, str]:
3738
"""Retrieves the inference environment variables for the model matching the given arguments.
3839
@@ -57,7 +58,7 @@ def _retrieve_default_environment_variables(
5758
sagemaker_session (sagemaker.session.Session): A SageMaker Session
5859
object, used for SageMaker interactions. If not
5960
specified, one is created using the default AWS configuration
60-
chain. (Default: Session()).
61+
chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
6162
Returns:
6263
dict: the inference environment variables to use for the model.
6364
"""

src/sagemaker/jumpstart/artifacts/hyperparameters.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515
from typing import Dict, Optional
1616
from sagemaker.jumpstart.constants import (
17+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
1718
JUMPSTART_DEFAULT_REGION_NAME,
1819
)
1920
from sagemaker.jumpstart.enums import (
@@ -33,7 +34,7 @@ def _retrieve_default_hyperparameters(
3334
include_container_hyperparameters: bool = False,
3435
tolerate_vulnerable_model: bool = False,
3536
tolerate_deprecated_model: bool = False,
36-
sagemaker_session: Session = Session(),
37+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3738
):
3839
"""Retrieves the training hyperparameters for the model matching the given arguments.
3940
@@ -61,7 +62,7 @@ def _retrieve_default_hyperparameters(
6162
sagemaker_session (sagemaker.session.Session): A SageMaker Session
6263
object, used for SageMaker interactions. If not
6364
specified, one is created using the default AWS configuration
64-
chain. (Default: Session()).
65+
chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
6566
Returns:
6667
dict: the hyperparameters to use for the model.
6768
"""

src/sagemaker/jumpstart/artifacts/image_uris.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import Optional
1717
from sagemaker import image_uris
1818
from sagemaker.jumpstart.constants import (
19+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
1920
JUMPSTART_DEFAULT_REGION_NAME,
2021
)
2122
from sagemaker.jumpstart.enums import (
@@ -44,7 +45,7 @@ def _retrieve_image_uri(
4445
training_compiler_config: Optional[str] = None,
4546
tolerate_vulnerable_model: bool = False,
4647
tolerate_deprecated_model: bool = False,
47-
sagemaker_session: Session = Session(),
48+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
4849
):
4950
"""Retrieves the container image URI for JumpStart models.
5051
@@ -93,7 +94,7 @@ def _retrieve_image_uri(
9394
sagemaker_session (sagemaker.session.Session): A SageMaker Session
9495
object, used for SageMaker interactions. If not
9596
specified, one is created using the default AWS configuration
96-
chain. (Default: Session()).
97+
chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
9798
Returns:
9899
str: the ECR URI for the corresponding SageMaker Docker image.
99100

src/sagemaker/jumpstart/artifacts/incremental_training.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515
from typing import Optional
1616
from sagemaker.jumpstart.constants import (
17+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
1718
JUMPSTART_DEFAULT_REGION_NAME,
1819
)
1920
from sagemaker.jumpstart.enums import (
@@ -31,7 +32,7 @@ def _model_supports_incremental_training(
3132
region: Optional[str],
3233
tolerate_vulnerable_model: bool = False,
3334
tolerate_deprecated_model: bool = False,
34-
sagemaker_session: Session = Session(),
35+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3536
) -> bool:
3637
"""Returns True if the model supports incremental training.
3738
@@ -52,7 +53,7 @@ def _model_supports_incremental_training(
5253
sagemaker_session (sagemaker.session.Session): A SageMaker Session
5354
object, used for SageMaker interactions. If not
5455
specified, one is created using the default AWS configuration
55-
chain. (Default: Session()).
56+
chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
5657
Returns:
5758
bool: the support status for incremental training.
5859
"""

src/sagemaker/jumpstart/artifacts/instance_types.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from sagemaker.jumpstart.exceptions import NO_AVAILABLE_INSTANCES_ERROR_MSG
1919
from sagemaker.jumpstart.constants import (
20+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2021
JUMPSTART_DEFAULT_REGION_NAME,
2122
)
2223
from sagemaker.jumpstart.enums import (
@@ -35,7 +36,7 @@ def _retrieve_default_instance_type(
3536
region: Optional[str] = None,
3637
tolerate_vulnerable_model: bool = False,
3738
tolerate_deprecated_model: bool = False,
38-
sagemaker_session: Session = Session(),
39+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3940
) -> str:
4041
"""Retrieves the default instance type for the model.
4142
@@ -58,7 +59,7 @@ def _retrieve_default_instance_type(
5859
sagemaker_session (sagemaker.session.Session): A SageMaker Session
5960
object, used for SageMaker interactions. If not
6061
specified, one is created using the default AWS configuration
61-
chain. (Default: Session()).
62+
chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
6263
Returns:
6364
str: the default instance type to use for the model or None.
6465
@@ -101,7 +102,7 @@ def _retrieve_instance_types(
101102
region: Optional[str] = None,
102103
tolerate_vulnerable_model: bool = False,
103104
tolerate_deprecated_model: bool = False,
104-
sagemaker_session: Session = Session(),
105+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
105106
) -> List[str]:
106107
"""Retrieves the supported instance types for the model.
107108
@@ -124,7 +125,7 @@ def _retrieve_instance_types(
124125
sagemaker_session (sagemaker.session.Session): A SageMaker Session
125126
object, used for SageMaker interactions. If not
126127
specified, one is created using the default AWS configuration
127-
chain. (Default: Session()).
128+
chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
128129
Returns:
129130
list: the supported instance types to use for the model or None.
130131

0 commit comments

Comments
 (0)