Skip to content

Commit 03a3ac7

Browse files
beniericpintaoz-aws
authored andcommitted
Use exact python path in trainer template (#1584)
1 parent fcae9f7 commit 03a3ac7

File tree

5 files changed

+31
-33
lines changed

5 files changed

+31
-33
lines changed

src/sagemaker/modules/templates.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -50,18 +50,16 @@
5050
}}
5151
5252
check_python() {{
53-
if command -v python3 &>/dev/null; then
54-
SM_PYTHON_CMD="python3"
55-
SM_PIP_CMD="pip3"
56-
echo "Found python3"
57-
elif command -v python &>/dev/null; then
58-
SM_PYTHON_CMD="python"
59-
SM_PIP_CMD="pip"
60-
echo "Found python"
61-
else
62-
echo "Python may not be installed"
53+
SM_PYTHON_CMD=$(command -v python3 || command -v python)
54+
SM_PIP_CMD=$(command -v pip3 || command -v pip)
55+
56+
# Check if Python is found
57+
if [[ -z "$SM_PYTHON_CMD" || -z "$SM_PIP_CMD" ]]; then
58+
echo "Error: The Python executable was not found in the system path."
6359
return 1
6460
fi
61+
62+
return 0
6563
}}
6664
6765
trap 'handle_error' ERR

src/sagemaker/modules/testing_notebooks/base_model_trainer.ipynb

+3-3
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
"metadata": {},
4747
"outputs": [],
4848
"source": [
49-
"model_trainer.train()"
49+
"model_trainer.train(wait=False)"
5050
]
5151
},
5252
{
@@ -94,7 +94,7 @@
9494
" environment=env_vars,\n",
9595
")\n",
9696
"\n",
97-
"model_trainer.train()"
97+
"model_trainer.train(wait=False)"
9898
]
9999
},
100100
{
@@ -137,7 +137,7 @@
137137
"metadata": {},
138138
"outputs": [],
139139
"source": [
140-
"model_trainer.train()"
140+
"model_trainer.train(wait=False)"
141141
]
142142
},
143143
{

src/sagemaker/modules/train/model_trainer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def model_post_init(self, __context: Any):
293293
logger.warning("Session not provided. Using default Session.")
294294

295295
if self.role is None:
296-
self.role = get_execution_role()
296+
self.role = get_execution_role(sagemaker_session=self.session)
297297
logger.warning(f"Role not provided. Using default role:\n{self.role}")
298298

299299
if self.base_job_name is None:

src/sagemaker/serve/builder/model_builder.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@
112112
download_huggingface_model_metadata,
113113
)
114114
from sagemaker.serve.validations.optimization import _validate_optimization_configuration
115+
from sagemaker.modules.train import ModelTrainer
115116

116117
logger = logging.getLogger(__name__)
117118

@@ -272,9 +273,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
272273
schema_builder: Optional[SchemaBuilder] = field(
273274
default=None, metadata={"help": "Defines the i/o schema of the model"}
274275
)
275-
model: Optional[Union[object, str, "ModelTrainer", TrainingJob, Estimator]] = field(
276+
model: Optional[Union[object, str, ModelTrainer, TrainingJob, Estimator]] = field(
276277
default=None,
277-
metadata={"help": "Define object from which training artifacts can be extracted"}
278+
metadata={"help": "Define object from which training artifacts can be extracted"},
278279
)
279280
inference_spec: InferenceSpec = field(
280281
default=None,
@@ -851,7 +852,7 @@ def build( # pylint: disable=R0911
851852
Returns:
852853
Type[Model]: A deployable ``Model`` object.
853854
"""
854-
from sagemaker.modules.train.model_trainer import ModelTrainer
855+
855856
self.modes = dict()
856857

857858
if mode:

tests/integ/sagemaker/modules/conftest.py

+15-16
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,27 @@
1414
from __future__ import absolute_import
1515

1616
import pytest
17-
import json
1817

18+
import os
1919
import boto3
20-
from botocore.config import Config
2120
from sagemaker import Session
2221

2322
DEFAULT_REGION = "us-west-2"
2423

24+
2525
@pytest.fixture(scope="module")
26-
def modules_boto_session(request):
27-
config = request.config.getoption("--boto-config")
28-
if config:
29-
return boto3.Session(**json.loads(config))
26+
def modules_sagemaker_session():
27+
region = os.environ.get("AWS_DEFAULT_REGION")
28+
if not region:
29+
os.environ["AWS_DEFAULT_REGION"] = DEFAULT_REGION
30+
region_manual_set = True
3031
else:
31-
return boto3.Session(region_name=DEFAULT_REGION)
32+
region_manual_set = True
3233

33-
@pytest.fixture(scope="module")
34-
def modules_sagemaker_session(request, modules_boto_session):
35-
sagemaker_client = (
36-
modules_boto_session.client(
37-
"sagemaker",
38-
config=Config(retries={"max_attempts": 10, "mode": "standard"})
39-
)
40-
)
41-
return Session(boto_session=modules_boto_session)
34+
boto_session = boto3.Session(region_name=os.environ["AWS_DEFAULT_REGION"])
35+
sagemaker_session = Session(boto_session=boto_session)
36+
37+
yield sagemaker_session
38+
39+
if region_manual_set:
40+
del os.environ["AWS_DEFAULT_REGION"]

0 commit comments

Comments
 (0)