Commit 902fa80
Merge branch 'aws:master' into master
2 parents: 2508683 + 219ad24

File tree: 10 files changed, +330 −11 lines changed

doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst

Lines changed: 96 additions & 0 deletions
@@ -494,6 +494,102 @@ smdistributed.modelparallel.torch.DistributedOptimizer
 ``state_dict`` contains elements corresponding to only the current
 partition, or to the entire model.
 
+smdistributed.modelparallel.torch.nn.FlashAttentionLayer
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. function:: smdistributed.modelparallel.torch.nn.FlashAttentionLayer(attention_dropout_prob=0.1, attention_head_size=None, scale_attention_scores=True, scale_attn_by_layer_idx=False, layer_idx=None, scale=None, triton_flash_attention=False, use_alibi=False)
+
+   This class supports
+   `FlashAttention <https://github.com/HazyResearch/flash-attention>`_
+   for PyTorch 2.0.
+   It takes the ``qkv`` matrix as an argument through its ``forward`` class method,
+   computes attention scores and probabilities,
+   and then performs the matrix multiplication with the value layer.
+
+   Through this class, the smp library supports
+   custom attention masks such as Attention with
+   Linear Biases (ALiBi), which you can activate by setting
+   ``triton_flash_attention`` and ``use_alibi`` to ``True``.
+
+   Note that the Triton flash attention does not support dropout
+   on the attention probabilities. It uses the standard lower triangular
+   causal mask when causal mode is enabled. It also runs only
+   on P4d and P4de instances, with fp16 or bf16.
+
+   This class computes the scale factor to apply when computing attention.
+   By default, ``scale`` is set to ``None`` and is calculated automatically.
+   When ``scale_attention_scores`` is ``True`` (the default), you must pass a value
+   to ``attention_head_size``. When ``scale_attn_by_layer_idx`` is ``True``,
+   you must pass a value to ``layer_idx``. If both factors are used, they are
+   multiplied as follows: ``(1/(sqrt(attention_head_size) * (layer_idx+1)))``.
+   This scale calculation can be bypassed by specifying a custom scaling
+   factor through ``scale``. In other words, if you specify a value for ``scale``,
+   the set of parameters (``scale_attention_scores``, ``attention_head_size``,
+   ``scale_attn_by_layer_idx``, ``layer_idx``) is overridden and ignored.
+
+   **Parameters**
+
+   * ``attention_dropout_prob`` (float): (default: 0.1) specifies the dropout probability
+     to apply to attention.
+   * ``attention_head_size`` (int): Required when ``scale_attention_scores`` is ``True``.
+     When ``scale_attention_scores`` is passed, this contributes
+     ``1/sqrt(attention_head_size)`` to the scale factor.
+   * ``scale_attention_scores`` (boolean): (default: True) determines whether
+     to multiply the scale factor by ``1/sqrt(attention_head_size)``.
+   * ``layer_idx`` (int): Required when ``scale_attn_by_layer_idx`` is ``True``.
+     The layer id to use for scaling attention by layer id.
+     It contributes ``1/(layer_idx + 1)`` to the scale factor.
+   * ``scale_attn_by_layer_idx`` (boolean): (default: False) determines whether
+     to multiply the scale factor by ``1/(layer_idx + 1)``.
+   * ``scale`` (float): (default: None) If passed, this scale factor is
+     applied, bypassing all of the previous arguments.
+   * ``triton_flash_attention`` (bool): (default: False) If passed, the Triton
+     implementation of flash attention is used. This is necessary to support
+     Attention with Linear Biases (ALiBi) (see the next argument). Note that this
+     version of the kernel doesn’t support dropout.
+   * ``use_alibi`` (bool): (default: False) If passed, enables Attention with
+     Linear Biases (ALiBi) using the mask provided.
+
+   .. method:: forward(self, qkv, attn_mask=None, causal=False)
+
+      Returns a single ``torch.Tensor`` of shape ``(batch_size x num_heads x seq_len x head_size)``,
+      which represents the output of the attention computation.
+
+      **Parameters**
+
+      * ``qkv``: ``torch.Tensor`` in the form of ``(batch_size x seqlen x 3 x num_heads x head_size)``.
+      * ``attn_mask``: ``torch.Tensor`` in the form of ``(batch_size x 1 x 1 x seqlen)``.
+        By default it is ``None``, and using this mask requires ``triton_flash_attention``
+        and ``use_alibi`` to be set. See how to generate the mask in the following code snippet.
+      * ``causal``: When passed, it uses the standard lower triangular mask. The default is ``False``.
+
+      When using ALiBi, the layer needs an attention mask prepared like the following.
+
+      .. code:: python
+
+         def generate_alibi_attn_mask(attention_mask, batch_size, seq_length,
+                                      num_attention_heads, alibi_bias_max=8):
+
+             device, dtype = attention_mask.device, attention_mask.dtype
+             alibi_attention_mask = torch.zeros(
+                 1, num_attention_heads, 1, seq_length, dtype=dtype, device=device
+             )
+
+             alibi_bias = torch.arange(1 - seq_length, 1, dtype=dtype, device=device).view(
+                 1, 1, 1, seq_length
+             )
+             m = torch.arange(1, num_attention_heads + 1, dtype=dtype, device=device)
+             m.mul_(alibi_bias_max / num_attention_heads)
+             alibi_bias = alibi_bias * (1.0 / (2 ** m.view(1, num_attention_heads, 1, 1)))
+
+             alibi_attention_mask.add_(alibi_bias)
+             alibi_attention_mask = alibi_attention_mask[..., :seq_length, :seq_length]
+             if attention_mask is not None and attention_mask.bool().any():
+                 # masked_fill is out-of-place; keep the result so the padded
+                 # positions are actually set to -inf.
+                 alibi_attention_mask = alibi_attention_mask.masked_fill(
+                     attention_mask.bool().view(batch_size, 1, 1, seq_length), float("-inf")
+                 )
+
+             return alibi_attention_mask
 
 smdistributed.modelparallel.torch Context Managers and Util Functions
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
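
For orientation, here is a minimal usage sketch of the layer documented above. It assumes a P4d/P4de instance with the smp library installed and reuses the ``generate_alibi_attn_mask`` helper from the snippet; the shapes, dtype, and argument values are illustrative, not taken from this commit.

.. code:: python

    import torch
    import smdistributed.modelparallel.torch as smp

    # Illustrative sizes; any values matching the documented qkv layout work.
    batch_size, seq_length, num_heads, head_size = 2, 128, 16, 64

    layer = smp.nn.FlashAttentionLayer(
        attention_dropout_prob=0.0,     # the Triton kernel does not support dropout
        attention_head_size=head_size,  # required because scale_attention_scores=True
        scale_attention_scores=True,
        triton_flash_attention=True,    # required for custom masks such as ALiBi
        use_alibi=True,
    )

    # qkv packed as (batch_size x seqlen x 3 x num_heads x head_size), fp16/bf16 on GPU.
    qkv = torch.randn(
        batch_size, seq_length, 3, num_heads, head_size,
        dtype=torch.bfloat16, device="cuda",
    )

    # No padded positions in this example, so the padding mask is all zeros.
    padding_mask = torch.zeros(
        batch_size, seq_length, dtype=torch.bfloat16, device="cuda"
    )
    alibi_mask = generate_alibi_attn_mask(
        padding_mask, batch_size, seq_length, num_heads
    )

    # Output shape: (batch_size x num_heads x seq_len x head_size).
    out = layer(qkv, attn_mask=alibi_mask, causal=True)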
Lines changed: 1 addition & 1 deletion

@@ -1,2 +1,2 @@
 doc8==0.10.1
-Pygments==2.11.2
+Pygments==2.15.0

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 0 deletions
@@ -334,6 +334,8 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
         "training_dependencies",
         "training_vulnerabilities",
         "deprecated",
+        "deprecated_message",
+        "deprecate_warn_message",
         "default_inference_instance_type",
         "supported_inference_instance_types",
         "default_training_instance_type",
@@ -389,6 +391,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
         self.training_dependencies: List[str] = json_obj["training_dependencies"]
         self.training_vulnerabilities: List[str] = json_obj["training_vulnerabilities"]
         self.deprecated: bool = bool(json_obj["deprecated"])
+        self.deprecated_message: Optional[str] = json_obj.get("deprecated_message")
+        self.deprecate_warn_message: Optional[str] = json_obj.get("deprecate_warn_message")
         self.default_inference_instance_type: Optional[str] = json_obj.get(
             "default_inference_instance_type"
         )
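
The two new fields are read with ``dict.get``, so spec documents that predate them still deserialize. A minimal sketch of that pattern, using stand-in spec fragments rather than real JumpStart specs:

.. code:: python

    from typing import Any, Dict, Optional

    # Stand-in spec fragments; real spec documents carry many more fields.
    old_spec: Dict[str, Any] = {"deprecated": False}
    new_spec: Dict[str, Any] = {
        "deprecated": True,
        "deprecated_message": "This model is deprecated; use its successor instead.",
    }

    # Mirrors the from_json change: .get() yields None when a key is absent.
    msg: Optional[str] = old_spec.get("deprecated_message")       # None
    warn: Optional[str] = new_spec.get("deprecate_warn_message")  # None
    print(new_spec.get("deprecated_message"))                     # deprecation text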

src/sagemaker/jumpstart/utils.py

Lines changed: 6 additions & 1 deletion
@@ -415,9 +415,14 @@ def verify_model_region_and_return_specs(
 
     if model_specs.deprecated:
         if not tolerate_deprecated_model:
-            raise DeprecatedJumpStartModelError(model_id=model_id, version=version)
+            raise DeprecatedJumpStartModelError(
+                model_id=model_id, version=version, message=model_specs.deprecated_message
+            )
         LOGGER.warning("Using deprecated JumpStart model '%s' and version '%s'.", model_id, version)
 
+    if model_specs.deprecate_warn_message:
+        LOGGER.warning(model_specs.deprecate_warn_message)
+
     if scope == constants.JumpStartScriptScope.INFERENCE.value and model_specs.inference_vulnerable:
         if not tolerate_vulnerable_model:
             raise VulnerableJumpStartModelError(
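
A hedged sketch of the resulting caller-side behavior; the model id is a placeholder, and the import of ``DeprecatedJumpStartModelError`` from ``sagemaker.jumpstart.exceptions`` is assumed:

.. code:: python

    from sagemaker.jumpstart.enums import JumpStartScriptScope
    from sagemaker.jumpstart.exceptions import DeprecatedJumpStartModelError
    from sagemaker.jumpstart.utils import verify_model_region_and_return_specs

    try:
        specs = verify_model_region_and_return_specs(
            model_id="some-deprecated-model-id",  # placeholder id
            version="*",
            scope=JumpStartScriptScope.INFERENCE,
            region="us-west-2",
        )
    except DeprecatedJumpStartModelError as error:
        # With this change, the error can carry the spec's deprecated_message.
        # Passing tolerate_deprecated_model=True instead logs a warning, and a
        # non-empty deprecate_warn_message is now always logged as a warning.
        print(error)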

src/sagemaker/jumpstart/validators.py

Lines changed: 12 additions & 4 deletions
@@ -15,10 +15,15 @@
 from typing import Any, Dict, List, Optional
 from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
 
-from sagemaker.jumpstart.enums import HyperparameterValidationMode, VariableScope, VariableTypes
-from sagemaker.jumpstart import accessors as jumpstart_accessors
+from sagemaker.jumpstart.enums import (
+    HyperparameterValidationMode,
+    JumpStartScriptScope,
+    VariableScope,
+    VariableTypes,
+)
 from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError
 from sagemaker.jumpstart.types import JumpStartHyperparameter
+from sagemaker.jumpstart.utils import verify_model_region_and_return_specs
 
 
 def _validate_hyperparameter(
@@ -190,8 +195,11 @@ def validate_hyperparameters(
     if region is None:
         region = JUMPSTART_DEFAULT_REGION_NAME
 
-    model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
-        region=region, model_id=model_id, version=model_version
+    model_specs = verify_model_region_and_return_specs(
+        model_id=model_id,
+        version=model_version,
+        region=region,
+        scope=JumpStartScriptScope.TRAINING,
     )
     hyperparameters_specs = model_specs.hyperparameters

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 35 additions & 0 deletions
@@ -14,6 +14,10 @@
 import os
 import time
 
+import pytest
+
+import tests.integ
+
 from sagemaker.jumpstart.model import JumpStartModel
 from tests.integ.sagemaker.jumpstart.constants import (
     ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
@@ -29,6 +33,8 @@
 
 MAX_INIT_TIME_SECONDS = 5
 
+MODEL_PACKAGE_ARN_SUPPORTED_REGIONS = {"us-west-2", "us-east-1"}
+
 
 def test_non_prepacked_jumpstart_model(setup):
 
@@ -73,6 +79,35 @@ def test_prepacked_jumpstart_model(setup):
     assert response is not None
 
 
+@pytest.mark.skipif(
+    tests.integ.test_region() not in MODEL_PACKAGE_ARN_SUPPORTED_REGIONS,
+    reason=f"JumpStart Model Package models unavailable in {tests.integ.test_region()}.",
+)
+def test_model_package_arn_jumpstart_model(setup):
+
+    model_id = "meta-textgeneration-llama-2-7b"
+
+    model = JumpStartModel(
+        model_id=model_id,
+        role=get_sm_session().get_caller_identity_arn(),
+        sagemaker_session=get_sm_session(),
+    )
+
+    # uses ml.g5.2xlarge instance
+    predictor = model.deploy(
+        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
+    )
+
+    payload = {
+        "inputs": "some-payload",
+        "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
+    }
+
+    response = predictor.predict(payload, custom_attributes="accept_eula=true")
+
+    assert response is not None
+
+
 def test_instatiating_model_not_too_slow(setup):
 
     model_id = "catboost-regression-model"

tests/scripts/run-notebook-test.sh

Lines changed: 0 additions & 5 deletions
@@ -126,12 +126,7 @@ echo "set SAGEMAKER_ROLE_ARN=$SAGEMAKER_ROLE_ARN"
 ./amazon-sagemaker-examples/advanced_functionality/kmeans_bring_your_own_model/kmeans_bring_your_own_model.ipynb \
 ./amazon-sagemaker-examples/advanced_functionality/tensorflow_iris_byom/tensorflow_BYOM_iris.ipynb \
 ./amazon-sagemaker-examples/sagemaker-python-sdk/1P_kmeans_highlevel/kmeans_mnist.ipynb \
-./amazon-sagemaker-examples/sagemaker-python-sdk/1P_kmeans_lowlevel/kmeans_mnist_lowlevel.ipynb \
-./amazon-sagemaker-examples/sagemaker-python-sdk/mxnet_gluon_sentiment/mxnet_sentiment_analysis_with_gluon.ipynb \
-./amazon-sagemaker-examples/sagemaker-python-sdk/mxnet_onnx_export/mxnet_onnx_export.ipynb \
 ./amazon-sagemaker-examples/sagemaker-python-sdk/scikit_learn_randomforest/Sklearn_on_SageMaker_end2end.ipynb \
 ./amazon-sagemaker-examples/sagemaker-python-sdk/tensorflow_moving_from_framework_mode_to_script_mode/tensorflow_moving_from_framework_mode_to_script_mode.ipynb \
-./amazon-sagemaker-examples/sagemaker-python-sdk/tensorflow_script_mode_pipe_mode/tensorflow_script_mode_pipe_mode.ipynb \
-./amazon-sagemaker-examples/sagemaker-python-sdk/tensorflow_serving_using_elastic_inference_with_your_own_model/tensorflow_serving_pretrained_model_elastic_inference.ipynb \
 
 (DeleteLifeCycleConfig "$LIFECYCLE_CONFIG_NAME")

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 59 additions & 0 deletions
@@ -14,6 +14,63 @@
 
 
 SPECIAL_MODEL_SPECS_DICT = {
+    "js-model-package-arn": {
+        "model_id": "meta-textgeneration-llama-2-7b-f",
+        "url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/",
+        "version": "1.0.0",
+        "min_sdk_version": "2.173.0",
+        "training_supported": False,
+        "incremental_training_supported": False,
+        "hosting_ecr_specs": {
+            "framework": "pytorch",
+            "framework_version": "1.12.0",
+            "py_version": "py38",
+        },
+        "hosting_artifact_key": "meta-infer/infer-meta-textgeneration-llama-2-7b-f.tar.gz",
+        "hosting_script_key": "source-directory-tarballs/meta/inference/textgeneration/v1.0.0/sourcedir.tar.gz",
+        "hosting_eula_key": "fmhMetadata/eula/llamaEula.txt",
+        "hosting_model_package_arns": {
+            "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/"
+            "llama2-7b-f-e46eb8a833643ed58aaccd81498972c3",
+            "us-east-1": "arn:aws:sagemaker:us-east-1:865070037744:model-package/"
+            "llama2-7b-f-e46eb8a833643ed58aaccd81498972c3",
+        },
+        "inference_vulnerable": False,
+        "inference_dependencies": [],
+        "inference_vulnerabilities": [],
+        "training_vulnerable": False,
+        "training_dependencies": [],
+        "training_vulnerabilities": [],
+        "deprecated": False,
+        "inference_environment_variables": [],
+        "metrics": [],
+        "default_inference_instance_type": "ml.g5.2xlarge",
+        "supported_inference_instance_types": [
+            "ml.g5.2xlarge",
+            "ml.g5.4xlarge",
+            "ml.g5.8xlarge",
+            "ml.g5.12xlarge",
+            "ml.g5.24xlarge",
+            "ml.g5.48xlarge",
+            "ml.p4d.24xlarge",
+        ],
+        "model_kwargs": {},
+        "deploy_kwargs": {
+            "model_data_download_timeout": 3600,
+            "container_startup_health_check_timeout": 3600,
+        },
+        "predictor_specs": {
+            "supported_content_types": ["application/json"],
+            "supported_accept_types": ["application/json"],
+            "default_content_type": "application/json",
+            "default_accept_type": "application/json",
+        },
+        "inference_volume_size": 256,
+        "inference_enable_network_isolation": True,
+        "validation_supported": False,
+        "fine_tuning_supported": False,
+        "resource_name_base": "meta-textgeneration-llama-2-7b-f",
+    },
     "js-trainable-model-prepacked": {
         "model_id": "huggingface-text2text-flan-t5-base",
         "url": "https://huggingface.co/google/flan-t5-base",
@@ -2299,6 +2356,8 @@
         "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
         "training_prepacked_script_key": None,
         "hosting_prepacked_artifact_key": None,
+        "deprecate_warn_message": None,
+        "deprecated_message": None,
         "hosting_model_package_arns": None,
         "hosting_eula_key": None,
         "hyperparameters": [
