diff --git a/src/sagemaker_huggingface_inference_toolkit/diffusers_utils.py b/src/sagemaker_huggingface_inference_toolkit/diffusers_utils.py index 068a41a..c0de779 100644 --- a/src/sagemaker_huggingface_inference_toolkit/diffusers_utils.py +++ b/src/sagemaker_huggingface_inference_toolkit/diffusers_utils.py @@ -29,24 +29,22 @@ def is_diffusers_available(): if is_diffusers_available(): import torch - from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler, StableDiffusionPipeline + from diffusers import DiffusionPipeline -class SMAutoPipelineForText2Image: +class SMDiffusionPipelineForText2Image: + def __init__(self, model_dir: str, device: str = None): # needs "cuda" for GPU + self.pipeline = None dtype = torch.float32 if device == "cuda": dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16 - device_map = "auto" if device == "cuda" else None + if torch.cuda.device_count() > 1: + device_map = "balanced" + self.pipeline = DiffusionPipeline.from_pretrained(model_dir, torch_dtype=dtype, device_map=device_map) - self.pipeline = AutoPipelineForText2Image.from_pretrained(model_dir, torch_dtype=dtype, device_map=device_map) - # try to use DPMSolverMultistepScheduler - if isinstance(self.pipeline, StableDiffusionPipeline): - try: - self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config) - except Exception: - pass - self.pipeline.to(device) + if not self.pipeline: + self.pipeline = DiffusionPipeline.from_pretrained(model_dir, torch_dtype=dtype).to(device) def __call__( self, @@ -64,7 +62,7 @@ def __call__( DIFFUSERS_TASKS = { - "text-to-image": SMAutoPipelineForText2Image, + "text-to-image": SMDiffusionPipelineForText2Image, } diff --git a/tests/unit/test_diffusers_utils.py b/tests/unit/test_diffusers_utils.py index c00c139..52a7929 100644 --- a/tests/unit/test_diffusers_utils.py +++ b/tests/unit/test_diffusers_utils.py @@ -16,7 +16,7 @@ from transformers.testing_utils import require_torch, slow from PIL import Image -from sagemaker_huggingface_inference_toolkit.diffusers_utils import SMAutoPipelineForText2Image +from sagemaker_huggingface_inference_toolkit.diffusers_utils import SMDiffusionPipelineForText2Image from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub, get_pipeline @@ -28,7 +28,7 @@ def test_get_diffusers_pipeline(): tmpdirname, ) pipe = get_pipeline("text-to-image", -1, storage_dir) - assert isinstance(pipe, SMAutoPipelineForText2Image) + assert isinstance(pipe, SMDiffusionPipelineForText2Image) @slow