Skip to content

Commit 019e859

Browse files
Created more generic pipeline for text-to-image task
Created more generic pipeline for text-to-image task
2 parents 9923001 + 9797ee8 commit 019e859

File tree

2 files changed

+12
-14
lines changed

2 files changed

+12
-14
lines changed

src/sagemaker_huggingface_inference_toolkit/diffusers_utils.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,22 @@ def is_diffusers_available():
2929
if is_diffusers_available():
3030
import torch
3131

32-
from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler, StableDiffusionPipeline
32+
from diffusers import DiffusionPipeline
3333

3434

35-
class SMAutoPipelineForText2Image:
35+
class SMDiffusionPipelineForText2Image:
36+
3637
def __init__(self, model_dir: str, device: str = None): # needs "cuda" for GPU
38+
self.pipeline = None
3739
dtype = torch.float32
3840
if device == "cuda":
3941
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16
40-
device_map = "auto" if device == "cuda" else None
42+
if torch.cuda.device_count() > 1:
43+
device_map = "balanced"
44+
self.pipeline = DiffusionPipeline.from_pretrained(model_dir, torch_dtype=dtype, device_map=device_map)
4145

42-
self.pipeline = AutoPipelineForText2Image.from_pretrained(model_dir, torch_dtype=dtype, device_map=device_map)
43-
# try to use DPMSolverMultistepScheduler
44-
if isinstance(self.pipeline, StableDiffusionPipeline):
45-
try:
46-
self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config)
47-
except Exception:
48-
pass
49-
self.pipeline.to(device)
46+
if not self.pipeline:
47+
self.pipeline = DiffusionPipeline.from_pretrained(model_dir, torch_dtype=dtype).to(device)
5048

5149
def __call__(
5250
self,
@@ -64,7 +62,7 @@ def __call__(
6462

6563

6664
DIFFUSERS_TASKS = {
67-
"text-to-image": SMAutoPipelineForText2Image,
65+
"text-to-image": SMDiffusionPipelineForText2Image,
6866
}
6967

7068

tests/unit/test_diffusers_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from transformers.testing_utils import require_torch, slow
1717

1818
from PIL import Image
19-
from sagemaker_huggingface_inference_toolkit.diffusers_utils import SMAutoPipelineForText2Image
19+
from sagemaker_huggingface_inference_toolkit.diffusers_utils import SMDiffusionPipelineForText2Image
2020
from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub, get_pipeline
2121

2222

@@ -28,7 +28,7 @@ def test_get_diffusers_pipeline():
2828
tmpdirname,
2929
)
3030
pipe = get_pipeline("text-to-image", -1, storage_dir)
31-
assert isinstance(pipe, SMAutoPipelineForText2Image)
31+
assert isinstance(pipe, SMDiffusionPipelineForText2Image)
3232

3333

3434
@slow

0 commit comments

Comments
 (0)