Skip to content

Commit aa2e62d

Browse files
beniericnargokul
authored andcommitted
Use sagemaker core Session (#1607)
* Use sagemaker core Session * update tests with session * update * flake8 * update docs --------- Co-authored-by: Gokul Anantha Narayanan <[email protected]>
1 parent 1e17a1e commit aa2e62d

File tree

7 files changed

+12
-10
lines changed

7 files changed

+12
-10
lines changed

src/sagemaker/modules/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker_core.main.utils import logger as sagemaker_core_logger
17+
from sagemaker_core.helper.session_helper import Session, get_execution_role # noqa: F401
1718

1819
logger = sagemaker_core_logger

src/sagemaker/modules/train/model_trainer.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from pydantic import BaseModel, ConfigDict, PrivateAttr, validate_call
2828

29-
from sagemaker import get_execution_role, Session
29+
from sagemaker.modules import Session, get_execution_role
3030
from sagemaker.modules.configs import (
3131
Compute,
3232
StoppingCondition,
@@ -119,8 +119,9 @@ class ModelTrainer(BaseModel):
119119
```
120120
121121
Attributes:
122-
sagemaker_session (Optiona(Session)):
123-
The SageMaker session.
122+
session (Optiona(Session)):
123+
The SageMakerCore session. For convinience, can be imported like:
124+
`from sagemaker.modules import Session`.
124125
If not specified, a new session will be created.
125126
role (Optional(str)):
126127
The IAM role ARN for the training job.

tests/integ/sagemaker/modules/conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import os
1919
import boto3
20-
from sagemaker import Session
20+
from sagemaker.modules import Session
2121

2222
DEFAULT_REGION = "us-west-2"
2323

tests/integ/sagemaker/modules/train/test_model_trainer.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
def test_hp_contract_basic_py_script(modules_sagemaker_session):
4141
source_code = SourceCode(
42-
source_dir=f"{DATA_DIR}/modules/params-script",
42+
source_dir=f"{DATA_DIR}/modules/params_script",
4343
entry_script="train.py",
4444
)
4545

@@ -56,7 +56,7 @@ def test_hp_contract_basic_py_script(modules_sagemaker_session):
5656

5757
def test_hp_contract_basic_sh_script(modules_sagemaker_session):
5858
source_code = SourceCode(
59-
source_dir=f"{DATA_DIR}/modules/params-script",
59+
source_dir=f"{DATA_DIR}/modules/params_script",
6060
entry_script="train.sh",
6161
)
6262
model_trainer = ModelTrainer(
@@ -72,7 +72,7 @@ def test_hp_contract_basic_sh_script(modules_sagemaker_session):
7272

7373
def test_hp_contract_mpi_script(modules_sagemaker_session):
7474
source_code = SourceCode(
75-
source_dir=f"{DATA_DIR}/modules/params-script",
75+
source_dir=f"{DATA_DIR}/modules/params_script",
7676
entry_script="train.py",
7777
)
7878
compute = Compute(instance_type="ml.m5.xlarge", instance_count=2)
@@ -91,7 +91,7 @@ def test_hp_contract_mpi_script(modules_sagemaker_session):
9191

9292
def test_hp_contract_torchrun_script(modules_sagemaker_session):
9393
source_code = SourceCode(
94-
source_dir=f"{DATA_DIR}/modules/params-script",
94+
source_dir=f"{DATA_DIR}/modules/params_script",
9595
entry_script="train.py",
9696
)
9797
compute = Compute(instance_type="ml.m5.xlarge", instance_count=2)

tests/unit/sagemaker/modules/train/test_model_trainer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from sagemaker_core.main.resources import TrainingJob
2424

25-
from sagemaker.session import Session
25+
from sagemaker.modules import Session
2626
from sagemaker.modules.train.model_trainer import ModelTrainer
2727
from sagemaker.modules.constants import (
2828
DEFAULT_INSTANCE_TYPE,
@@ -80,7 +80,7 @@
8080

8181
@pytest.fixture(scope="module", autouse=True)
8282
def modules_session():
83-
with patch("sagemaker.session.Session", spec=Session) as session_mock:
83+
with patch("sagemaker.modules.Session", spec=Session) as session_mock:
8484
session_instance = session_mock.return_value
8585
session_instance.default_bucket.return_value = DEFAULT_BUCKET
8686
session_instance.get_caller_identity_arn.return_value = DEFAULT_ROLE

0 commit comments

Comments
 (0)