@@ -199,6 +199,253 @@ def test_conditional_pytorch_training_model_registration(
199
199
pass
200
200
201
201
202
+ def test_conditional_pytorch_training_model_registration_without_instance_types (
203
+ sagemaker_session ,
204
+ role ,
205
+ cpu_instance_type ,
206
+ pipeline_name ,
207
+ region_name ,
208
+ ):
209
+ base_dir = os .path .join (DATA_DIR , "pytorch_mnist" )
210
+ entry_point = os .path .join (base_dir , "mnist.py" )
211
+ input_path = sagemaker_session .upload_data (
212
+ path = os .path .join (base_dir , "training" ),
213
+ key_prefix = "integ-test-data/pytorch_mnist/training" ,
214
+ )
215
+ inputs = TrainingInput (s3_data = input_path )
216
+
217
+ instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
218
+ instance_type = "ml.m5.xlarge"
219
+ good_enough_input = ParameterInteger (name = "GoodEnoughInput" , default_value = 1 )
220
+ in_condition_input = ParameterString (name = "Foo" , default_value = "Foo" )
221
+
222
+ task = "IMAGE_CLASSIFICATION"
223
+ sample_payload_url = "s3://test-bucket/model"
224
+ framework = "TENSORFLOW"
225
+ framework_version = "2.9"
226
+ nearest_model_name = "resnet50"
227
+ data_input_configuration = '{"input_1":[1,224,224,3]}'
228
+
229
+ # If image_uri is not provided, the instance_type should not be a pipeline variable
230
+ # since instance_type is used to retrieve image_uri in compile time (PySDK)
231
+ pytorch_estimator = PyTorch (
232
+ entry_point = entry_point ,
233
+ role = role ,
234
+ framework_version = "1.5.0" ,
235
+ py_version = "py3" ,
236
+ instance_count = instance_count ,
237
+ instance_type = instance_type ,
238
+ sagemaker_session = sagemaker_session ,
239
+ )
240
+ step_train = TrainingStep (
241
+ name = "pytorch-train" ,
242
+ estimator = pytorch_estimator ,
243
+ inputs = inputs ,
244
+ )
245
+
246
+ step_register = RegisterModel (
247
+ name = "pytorch-register-model" ,
248
+ estimator = pytorch_estimator ,
249
+ model_data = step_train .properties .ModelArtifacts .S3ModelArtifacts ,
250
+ content_types = ["*" ],
251
+ response_types = ["*" ],
252
+ description = "test-description" ,
253
+ sample_payload_url = sample_payload_url ,
254
+ task = task ,
255
+ framework = framework ,
256
+ framework_version = framework_version ,
257
+ nearest_model_name = nearest_model_name ,
258
+ data_input_configuration = data_input_configuration ,
259
+ )
260
+
261
+ model = Model (
262
+ image_uri = pytorch_estimator .training_image_uri (),
263
+ model_data = step_train .properties .ModelArtifacts .S3ModelArtifacts ,
264
+ sagemaker_session = sagemaker_session ,
265
+ role = role ,
266
+ )
267
+ model_inputs = CreateModelInput (
268
+ instance_type = "ml.m5.large" ,
269
+ accelerator_type = "ml.eia1.medium" ,
270
+ )
271
+ step_model = CreateModelStep (
272
+ name = "pytorch-model" ,
273
+ model = model ,
274
+ inputs = model_inputs ,
275
+ )
276
+
277
+ step_cond = ConditionStep (
278
+ name = "cond-good-enough" ,
279
+ conditions = [
280
+ ConditionGreaterThanOrEqualTo (left = good_enough_input , right = 1 ),
281
+ ConditionIn (value = in_condition_input , in_values = ["foo" , "bar" ]),
282
+ ],
283
+ if_steps = [step_register ],
284
+ else_steps = [step_model ],
285
+ depends_on = [step_train ],
286
+ )
287
+
288
+ pipeline = Pipeline (
289
+ name = pipeline_name ,
290
+ parameters = [
291
+ in_condition_input ,
292
+ good_enough_input ,
293
+ instance_count ,
294
+ ],
295
+ steps = [step_train , step_cond ],
296
+ sagemaker_session = sagemaker_session ,
297
+ )
298
+
299
+ try :
300
+ response = pipeline .create (role )
301
+ create_arn = response ["PipelineArn" ]
302
+ assert re .match (
303
+ rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " ,
304
+ create_arn ,
305
+ )
306
+
307
+ execution = pipeline .start (parameters = {})
308
+ assert re .match (
309
+ rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } /execution/" ,
310
+ execution .arn ,
311
+ )
312
+
313
+ execution = pipeline .start (parameters = {"GoodEnoughInput" : 0 })
314
+ assert re .match (
315
+ rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } /execution/" ,
316
+ execution .arn ,
317
+ )
318
+ finally :
319
+ try :
320
+ pipeline .delete ()
321
+ except Exception :
322
+ pass
323
+
324
+
325
+ def test_conditional_pytorch_training_model_registration_with_one_instance_types (
326
+ sagemaker_session ,
327
+ role ,
328
+ cpu_instance_type ,
329
+ pipeline_name ,
330
+ region_name ,
331
+ ):
332
+ base_dir = os .path .join (DATA_DIR , "pytorch_mnist" )
333
+ entry_point = os .path .join (base_dir , "mnist.py" )
334
+ input_path = sagemaker_session .upload_data (
335
+ path = os .path .join (base_dir , "training" ),
336
+ key_prefix = "integ-test-data/pytorch_mnist/training" ,
337
+ )
338
+ inputs = TrainingInput (s3_data = input_path )
339
+
340
+ instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
341
+ instance_type = "ml.m5.xlarge"
342
+ good_enough_input = ParameterInteger (name = "GoodEnoughInput" , default_value = 1 )
343
+ in_condition_input = ParameterString (name = "Foo" , default_value = "Foo" )
344
+
345
+ task = "IMAGE_CLASSIFICATION"
346
+ sample_payload_url = "s3://test-bucket/model"
347
+ framework = "TENSORFLOW"
348
+ framework_version = "2.9"
349
+ nearest_model_name = "resnet50"
350
+ data_input_configuration = '{"input_1":[1,224,224,3]}'
351
+
352
+ # If image_uri is not provided, the instance_type should not be a pipeline variable
353
+ # since instance_type is used to retrieve image_uri in compile time (PySDK)
354
+ pytorch_estimator = PyTorch (
355
+ entry_point = entry_point ,
356
+ role = role ,
357
+ framework_version = "1.5.0" ,
358
+ py_version = "py3" ,
359
+ instance_count = instance_count ,
360
+ instance_type = instance_type ,
361
+ sagemaker_session = sagemaker_session ,
362
+ )
363
+ step_train = TrainingStep (
364
+ name = "pytorch-train" ,
365
+ estimator = pytorch_estimator ,
366
+ inputs = inputs ,
367
+ )
368
+
369
+ step_register = RegisterModel (
370
+ name = "pytorch-register-model" ,
371
+ estimator = pytorch_estimator ,
372
+ model_data = step_train .properties .ModelArtifacts .S3ModelArtifacts ,
373
+ content_types = ["*" ],
374
+ response_types = ["*" ],
375
+ inference_instances = ["*" ],
376
+ description = "test-description" ,
377
+ sample_payload_url = sample_payload_url ,
378
+ task = task ,
379
+ framework = framework ,
380
+ framework_version = framework_version ,
381
+ nearest_model_name = nearest_model_name ,
382
+ data_input_configuration = data_input_configuration ,
383
+ )
384
+
385
+ model = Model (
386
+ image_uri = pytorch_estimator .training_image_uri (),
387
+ model_data = step_train .properties .ModelArtifacts .S3ModelArtifacts ,
388
+ sagemaker_session = sagemaker_session ,
389
+ role = role ,
390
+ )
391
+ model_inputs = CreateModelInput (
392
+ instance_type = "ml.m5.large" ,
393
+ accelerator_type = "ml.eia1.medium" ,
394
+ )
395
+ step_model = CreateModelStep (
396
+ name = "pytorch-model" ,
397
+ model = model ,
398
+ inputs = model_inputs ,
399
+ )
400
+
401
+ step_cond = ConditionStep (
402
+ name = "cond-good-enough" ,
403
+ conditions = [
404
+ ConditionGreaterThanOrEqualTo (left = good_enough_input , right = 1 ),
405
+ ConditionIn (value = in_condition_input , in_values = ["foo" , "bar" ]),
406
+ ],
407
+ if_steps = [step_register ],
408
+ else_steps = [step_model ],
409
+ depends_on = [step_train ],
410
+ )
411
+
412
+ pipeline = Pipeline (
413
+ name = pipeline_name ,
414
+ parameters = [
415
+ in_condition_input ,
416
+ good_enough_input ,
417
+ instance_count ,
418
+ ],
419
+ steps = [step_train , step_cond ],
420
+ sagemaker_session = sagemaker_session ,
421
+ )
422
+
423
+ try :
424
+ response = pipeline .create (role )
425
+ create_arn = response ["PipelineArn" ]
426
+ assert re .match (
427
+ rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " ,
428
+ create_arn ,
429
+ )
430
+
431
+ execution = pipeline .start (parameters = {})
432
+ assert re .match (
433
+ rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } /execution/" ,
434
+ execution .arn ,
435
+ )
436
+
437
+ execution = pipeline .start (parameters = {"GoodEnoughInput" : 0 })
438
+ assert re .match (
439
+ rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } /execution/" ,
440
+ execution .arn ,
441
+ )
442
+ finally :
443
+ try :
444
+ pipeline .delete ()
445
+ except Exception :
446
+ pass
447
+
448
+
202
449
def test_mxnet_model_registration (
203
450
sagemaker_session ,
204
451
role ,
0 commit comments