@@ -368,6 +368,12 @@ def test_register_model(estimator, model_metrics, drift_check_baselines):
368
368
display_name = "RegisterModelStep" ,
369
369
depends_on = ["TestStep" ],
370
370
tags = [{"Key" : "myKey" , "Value" : "myValue" }],
371
+ sample_payload_url = "s3://test-bucket/model" ,
372
+ task = "IMAGE_CLASSIFICATION" ,
373
+ framework = "TENSORFLOW" ,
374
+ framework_version = "2.9" ,
375
+ nearest_model_name = "resnet50" ,
376
+ data_input_configuration = '{"input_1":[1,224,224,3]}' ,
371
377
)
372
378
assert ordered (register_model .request_dicts ()) == ordered (
373
379
[
@@ -383,6 +389,12 @@ def test_register_model(estimator, model_metrics, drift_check_baselines):
383
389
{
384
390
"Image" : "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri" ,
385
391
"ModelDataUrl" : f"s3://{ BUCKET } /model.tar.gz" ,
392
+ "Framework" : None ,
393
+ "FrameworkVersion" : None ,
394
+ "NearestModelName" : None ,
395
+ "ModelInput" : {
396
+ "DataInputConfig" : None ,
397
+ },
386
398
}
387
399
],
388
400
"SupportedContentTypes" : ["content_type" ],
@@ -412,6 +424,8 @@ def test_register_model(estimator, model_metrics, drift_check_baselines):
412
424
"ModelPackageDescription" : "description" ,
413
425
"ModelPackageGroupName" : "mpg" ,
414
426
"Tags" : [{"Key" : "myKey" , "Value" : "myValue" }],
427
+ "SamplePayloadUrl" : "s3://test-bucket/model" ,
428
+ "Task" : "IMAGE_CLASSIFICATION" ,
415
429
},
416
430
},
417
431
]
@@ -433,6 +447,12 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines):
433
447
drift_check_baselines = drift_check_baselines ,
434
448
approval_status = "Approved" ,
435
449
description = "description" ,
450
+ sample_payload_url = "s3://test-bucket/model" ,
451
+ task = "IMAGE_CLASSIFICATION" ,
452
+ framework = "TENSORFLOW" ,
453
+ framework_version = "2.9" ,
454
+ nearest_model_name = "resnet50" ,
455
+ data_input_configuration = '{"input_1":[1,224,224,3]}' ,
436
456
)
437
457
assert ordered (register_model .request_dicts ()) == ordered (
438
458
[
@@ -446,6 +466,12 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines):
446
466
{
447
467
"Image" : "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:1.15.2-cpu" ,
448
468
"ModelDataUrl" : f"s3://{ BUCKET } /model.tar.gz" ,
469
+ "Framework" : None ,
470
+ "FrameworkVersion" : None ,
471
+ "NearestModelName" : None ,
472
+ "ModelInput" : {
473
+ "DataInputConfig" : None ,
474
+ },
449
475
}
450
476
],
451
477
"SupportedContentTypes" : ["content_type" ],
@@ -474,6 +500,8 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines):
474
500
},
475
501
"ModelPackageDescription" : "description" ,
476
502
"ModelPackageGroupName" : "mpg" ,
503
+ "SamplePayloadUrl" : "s3://test-bucket/model" ,
504
+ "Task" : "IMAGE_CLASSIFICATION" ,
477
505
},
478
506
},
479
507
]
@@ -502,6 +530,12 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines):
502
530
description = "description" ,
503
531
model = pipeline_model ,
504
532
depends_on = ["TestStep" ],
533
+ sample_payload_url = "s3://test-bucket/model" ,
534
+ task = "IMAGE_CLASSIFICATION" ,
535
+ framework = "TENSORFLOW" ,
536
+ framework_version = "2.9" ,
537
+ nearest_model_name = "resnet50" ,
538
+ data_input_configuration = '{"input_1":[1,224,224,3]}' ,
505
539
)
506
540
assert ordered (register_model .request_dicts ()) == ordered (
507
541
[
@@ -517,11 +551,23 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines):
517
551
"Image" : "fakeimage1" ,
518
552
"ModelDataUrl" : "Url1" ,
519
553
"Environment" : [{"k1" : "v1" }, {"k2" : "v2" }],
554
+ "Framework" : "TENSORFLOW" ,
555
+ "FrameworkVersion" : "2.9" ,
556
+ "NearestModelName" : "resnet50" ,
557
+ "ModelInput" : {
558
+ "DataInputConfig" : '{"input_1":[1,224,224,3]}' ,
559
+ },
520
560
},
521
561
{
522
562
"Image" : "fakeimage2" ,
523
563
"ModelDataUrl" : "Url2" ,
524
564
"Environment" : [{"k3" : "v3" }, {"k4" : "v4" }],
565
+ "Framework" : "TENSORFLOW" ,
566
+ "FrameworkVersion" : "2.9" ,
567
+ "NearestModelName" : "resnet50" ,
568
+ "ModelInput" : {
569
+ "DataInputConfig" : '{"input_1":[1,224,224,3]}' ,
570
+ },
525
571
},
526
572
],
527
573
"SupportedContentTypes" : ["content_type" ],
@@ -550,6 +596,8 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines):
550
596
},
551
597
"ModelPackageDescription" : "description" ,
552
598
"ModelPackageGroupName" : "mpg" ,
599
+ "SamplePayloadUrl" : "s3://test-bucket/model" ,
600
+ "Task" : "IMAGE_CLASSIFICATION" ,
553
601
},
554
602
},
555
603
]
@@ -578,6 +626,12 @@ def test_register_model_with_model_repack_with_estimator(
578
626
dependencies = [dummy_requirements ],
579
627
depends_on = ["TestStep" ],
580
628
tags = [{"Key" : "myKey" , "Value" : "myValue" }],
629
+ sample_payload_url = "s3://test-bucket/model" ,
630
+ task = "IMAGE_CLASSIFICATION" ,
631
+ framework = "TENSORFLOW" ,
632
+ framework_version = "2.9" ,
633
+ nearest_model_name = "resnet50" ,
634
+ data_input_configuration = '{"input_1":[1,224,224,3]}' ,
581
635
)
582
636
583
637
request_dicts = register_model .request_dicts ()
@@ -649,6 +703,15 @@ def test_register_model_with_model_repack_with_estimator(
649
703
assert isinstance (
650
704
arguments ["InferenceSpecification" ]["Containers" ][0 ]["ModelDataUrl" ], Properties
651
705
)
706
+ assert arguments ["InferenceSpecification" ]["Containers" ][0 ]["Framework" ] == None
707
+ assert arguments ["InferenceSpecification" ]["Containers" ][0 ]["FrameworkVersion" ] == None
708
+ assert arguments ["InferenceSpecification" ]["Containers" ][0 ]["NearestModelName" ] == None
709
+ assert (
710
+ arguments ["InferenceSpecification" ]["Containers" ][0 ]["ModelInput" ][
711
+ "DataInputConfig"
712
+ ]
713
+ == None
714
+ )
652
715
del arguments ["InferenceSpecification" ]["Containers" ]
653
716
assert ordered (arguments ) == ordered (
654
717
{
@@ -680,6 +743,8 @@ def test_register_model_with_model_repack_with_estimator(
680
743
"ModelPackageDescription" : "description" ,
681
744
"ModelPackageGroupName" : "mpg" ,
682
745
"Tags" : [{"Key" : "myKey" , "Value" : "myValue" }],
746
+ "SamplePayloadUrl" : "s3://test-bucket/model" ,
747
+ "Task" : "IMAGE_CLASSIFICATION" ,
683
748
}
684
749
)
685
750
else :
0 commit comments