|
67 | 67 | "123456789712.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.24.0-neuronx-sdk2.14.1"
|
68 | 68 | )
|
69 | 69 |
|
| 70 | +mock_model_data = { |
| 71 | + "S3DataSource": { |
| 72 | + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/huggingface-llm/huggingface-llm-zephyr-7b-gemma" |
| 73 | + "/artifacts/inference-prepack/v1.0.0/", |
| 74 | + "S3DataType": "S3Prefix", |
| 75 | + "CompressionType": "None", |
| 76 | + } |
| 77 | +} |
| 78 | +mock_model_data_str = ( |
| 79 | + "s3://jumpstart-private-cache-prod-us-west-2/huggingface-llm/huggingface-llm-zephyr-7b-gemma" |
| 80 | + "/artifacts/inference-prepack/v1.0.0/" |
| 81 | +) |
| 82 | + |
70 | 83 |
|
71 | 84 | class TestJumpStartBuilder(unittest.TestCase):
|
72 | 85 | @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
|
@@ -527,3 +540,101 @@ def test_tune_for_djl_js_endpoint_mode_ex(
|
527 | 540 |
|
528 | 541 | tuned_model = model.tune()
|
529 | 542 | assert tuned_model == model
|
| 543 | + |
| 544 | + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) |
| 545 | + @patch( |
| 546 | + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", |
| 547 | + return_value=True, |
| 548 | + ) |
| 549 | + @patch( |
| 550 | + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", |
| 551 | + return_value=MagicMock(), |
| 552 | + ) |
| 553 | + @patch( |
| 554 | + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", |
| 555 | + return_value=({"model_type": "t5", "n_head": 71}, True), |
| 556 | + ) |
| 557 | + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) |
| 558 | + @patch( |
| 559 | + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" |
| 560 | + ) |
| 561 | + def test_js_gated_model_in_endpoint_mode( |
| 562 | + self, |
| 563 | + mock_get_nb_instance, |
| 564 | + mock_get_ram_usage_mb, |
| 565 | + mock_prepare_for_tgi, |
| 566 | + mock_pre_trained_model, |
| 567 | + mock_is_jumpstart_model, |
| 568 | + mock_telemetry, |
| 569 | + ): |
| 570 | + builder = ModelBuilder( |
| 571 | + model="facebook/galactica-mock-model-id", |
| 572 | + schema_builder=mock_schema_builder, |
| 573 | + mode=Mode.SAGEMAKER_ENDPOINT, |
| 574 | + ) |
| 575 | + |
| 576 | + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri |
| 577 | + mock_pre_trained_model.return_value.model_data = mock_model_data |
| 578 | + |
| 579 | + model = builder.build() |
| 580 | + |
| 581 | + assert model is not None |
| 582 | + |
| 583 | + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) |
| 584 | + @patch( |
| 585 | + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", |
| 586 | + return_value=True, |
| 587 | + ) |
| 588 | + @patch( |
| 589 | + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", |
| 590 | + return_value=MagicMock(), |
| 591 | + ) |
| 592 | + def test_js_gated_model_in_local_mode( |
| 593 | + self, |
| 594 | + mock_pre_trained_model, |
| 595 | + mock_is_jumpstart_model, |
| 596 | + mock_telemetry, |
| 597 | + ): |
| 598 | + builder = ModelBuilder( |
| 599 | + model="huggingface-llm-zephyr-7b-gemma", |
| 600 | + schema_builder=mock_schema_builder, |
| 601 | + mode=Mode.LOCAL_CONTAINER, |
| 602 | + ) |
| 603 | + |
| 604 | + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri |
| 605 | + mock_pre_trained_model.return_value.model_data = mock_model_data_str |
| 606 | + |
| 607 | + self.assertRaisesRegex( |
| 608 | + ValueError, |
| 609 | + "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode.", |
| 610 | + lambda: builder.build(), |
| 611 | + ) |
| 612 | + |
| 613 | + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) |
| 614 | + @patch( |
| 615 | + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", |
| 616 | + return_value=True, |
| 617 | + ) |
| 618 | + @patch( |
| 619 | + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", |
| 620 | + return_value=MagicMock(), |
| 621 | + ) |
| 622 | + def test_js_gated_model_ex( |
| 623 | + self, |
| 624 | + mock_pre_trained_model, |
| 625 | + mock_is_jumpstart_model, |
| 626 | + mock_telemetry, |
| 627 | + ): |
| 628 | + builder = ModelBuilder( |
| 629 | + model="huggingface-llm-zephyr-7b-gemma", |
| 630 | + schema_builder=mock_schema_builder, |
| 631 | + mode=Mode.LOCAL_CONTAINER, |
| 632 | + ) |
| 633 | + |
| 634 | + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri |
| 635 | + mock_pre_trained_model.return_value.model_data = None |
| 636 | + |
| 637 | + self.assertRaises( |
| 638 | + ValueError, |
| 639 | + lambda: builder.build(), |
| 640 | + ) |
0 commit comments