@@ -128,7 +128,8 @@ def __init__(
128
128
129
129
if self .instance_type in ("local" , "local_gpu" ):
130
130
if not isinstance (sagemaker_session , LocalSession ):
131
- sagemaker_session = LocalSession ()
131
+ # Until Local Mode Processing supports local code, we need to disable it:
132
+ sagemaker_session = LocalSession (disable_local_code = True )
132
133
133
134
self .sagemaker_session = sagemaker_session or Session ()
134
135
@@ -1298,10 +1299,15 @@ def __init__(
1298
1299
self .framework_version = framework_version
1299
1300
self .py_version = py_version
1300
1301
1301
- image_uri , base_job_name = self ._pre_init_normalization (
1302
- instance_type , image_uri , base_job_name , sagemaker_session
1303
- )
1304
-
1302
+ # 1. To finalize/normalize the image_uri or base_job_name, we need to create an
1303
+ # estimator_cls instance.
1304
+ # 2. We want to make it easy for children of FrameworkProcessor to override estimator
1305
+ # creation via a function (to create FrameworkProcessors for Estimators that may have
1306
+ # different signatures - like HuggingFace or others in future).
1307
+ # 3. Super-class __init__ doesn't (currently) do anything with these params besides
1308
+ # storing them
1309
+ #
1310
+ # Therefore we'll init the superclass first and then customize the setup after:
1305
1311
super ().__init__ (
1306
1312
role = role ,
1307
1313
image_uri = image_uri ,
@@ -1318,6 +1324,7 @@ def __init__(
1318
1324
tags = tags ,
1319
1325
network_config = network_config ,
1320
1326
)
1327
+
1321
1328
# This subclass uses the "code" input for actual payload and the ScriptProcessor parent's
1322
1329
# functionality for uploading just a small entrypoint script to invoke it.
1323
1330
self ._CODE_CONTAINER_INPUT_NAME = "entrypoint"
@@ -1326,38 +1333,45 @@ def __init__(
1326
1333
code_location [:- 1 ] if (code_location and code_location .endswith ("/" )) else code_location
1327
1334
)
1328
1335
1329
- def _pre_init_normalization (
1330
- self ,
1331
- instance_type : str ,
1332
- image_uri : Optional [str ] = None ,
1333
- base_job_name : Optional [str ] = None ,
1334
- sagemaker_session : Optional [str ] = None ,
1335
- ) -> Tuple [str , str ]:
1336
- """Normalize job name and container image uri."""
1337
- # Normalize base_job_name
1338
- if base_job_name is None :
1339
- base_job_name = self .estimator_cls ._framework_name
1336
+ if image_uri is None or base_job_name is None :
1337
+ # For these default configuration purposes, we don't need the optional args:
1338
+ est = self ._create_estimator ()
1339
+ if image_uri is None :
1340
+ self .image_uri = est .training_image_uri ()
1340
1341
if base_job_name is None :
1341
- logger .warning ("Framework name is None. Please check with the maintainer." )
1342
- base_job_name = str (base_job_name ) # Keep mypy happy.
1343
-
1344
- # Normalize image uri.
1345
- if image_uri is None :
1346
- # Estimator used only to probe image uri, so can get away with some dummy values.
1347
- est = self .estimator_cls (
1348
- framework_version = self .framework_version ,
1349
- instance_type = instance_type ,
1350
- py_version = self .py_version ,
1351
- image_uri = image_uri ,
1352
- entry_point = "" ,
1353
- role = "" ,
1354
- enable_network_isolation = False ,
1355
- instance_count = 1 , # SKLearn estimator explicitly disables instance_count>1
1356
- sagemaker_session = sagemaker_session ,
1357
- )
1358
- image_uri = est .training_image_uri ()
1342
+ self .base_job_name = est .base_job_name or estimator_cls ._framework_name
1343
+ if base_job_name is None :
1344
+ base_job_name = "framework-processor"
1359
1345
1360
- return image_uri , base_job_name
1346
+ def _create_estimator (
1347
+ self ,
1348
+ entry_point = "" ,
1349
+ source_dir = None ,
1350
+ dependencies = None ,
1351
+ git_config = None ,
1352
+ ):
1353
+ """Instantiate the Framework Estimator that backs this Processor"""
1354
+ return self .estimator_cls (
1355
+ framework_version = self .framework_version ,
1356
+ py_version = self .py_version ,
1357
+ entry_point = entry_point ,
1358
+ source_dir = source_dir ,
1359
+ dependencies = dependencies ,
1360
+ git_config = git_config ,
1361
+ code_location = self .code_location ,
1362
+ enable_network_isolation = False , # True -> uploads to input channel. Not what we want!
1363
+ image_uri = self .image_uri ,
1364
+ role = self .role ,
1365
+ # Estimator instance_count doesn't currently matter to FrameworkProcessor, and the
1366
+ # SKLearn Framework Estimator requires instance_type==1. So here we hard-wire it to 1,
1367
+ # but if it matters in future perhaps we could take self.instance_count here and have
1368
+ # SKLearnProcessor override this function instead:
1369
+ instance_count = 1 ,
1370
+ instance_type = self .instance_type ,
1371
+ sagemaker_session = self .sagemaker_session ,
1372
+ debugger_hook_config = False ,
1373
+ disable_profiler = True ,
1374
+ )
1361
1375
1362
1376
def get_run_args (
1363
1377
self ,
@@ -1555,10 +1569,11 @@ def _pack_and_upload_code(self, code, source_dir, dependencies, git_config, job_
1555
1569
1556
1570
local_code = get_config_value ("local.local_code" , self .sagemaker_session .config )
1557
1571
if self .sagemaker_session .local_mode and local_code :
1558
- # TODO: Can we be more prescriptive about how to not trigger this error?
1559
- # How can user or us force a local mode `Estimator` to run with `local_code=False`?
1560
1572
raise RuntimeError (
1561
- "Local *code* is not currently supported for SageMaker Processing in Local Mode"
1573
+ "SageMaker Processing Local Mode does not currently support 'local code' mode. "
1574
+ "Please use a LocalSession created with disable_local_code=True, or leave "
1575
+ "sagemaker_session unspecified when creating your Processor to have one set up "
1576
+ "automatically."
1562
1577
)
1563
1578
1564
1579
# Upload the bootstrapping code as s3://.../jobname/source/runproc.sh.
@@ -1623,22 +1638,11 @@ def _upload_payload(
1623
1638
"""Upload payload sourcedir.tar.gz to S3."""
1624
1639
# A new estimator instance is required, because each call to ScriptProcessor.run() can
1625
1640
# use different codes.
1626
- estimator = self .estimator_cls (
1641
+ estimator = self ._create_estimator (
1627
1642
entry_point = entry_point ,
1628
1643
source_dir = source_dir ,
1629
1644
dependencies = dependencies ,
1630
1645
git_config = git_config ,
1631
- framework_version = self .framework_version ,
1632
- py_version = self .py_version ,
1633
- code_location = self .code_location , # Upload to <code_loc>/jobname/output/source.tar.gz
1634
- enable_network_isolation = False , # If true, uploads to input channel. Not what we want!
1635
- image_uri = self .image_uri , # The image uri is already normalized by this point.
1636
- role = self .role ,
1637
- instance_type = self .instance_type ,
1638
- instance_count = 1 ,
1639
- sagemaker_session = self .sagemaker_session ,
1640
- debugger_hook_config = False ,
1641
- disable_profiler = True ,
1642
1646
)
1643
1647
1644
1648
estimator ._prepare_for_training (job_name = job_name )
0 commit comments