@@ -29,24 +29,22 @@ def is_diffusers_available():
 if is_diffusers_available():
     import torch
 
-    from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler, StableDiffusionPipeline
+    from diffusers import DiffusionPipeline
 
 
-class SMAutoPipelineForText2Image:
+class SMDiffusionPipelineForText2Image:
+
     def __init__(self, model_dir: str, device: str = None):  # needs "cuda" for GPU
+        self.pipeline = None
         dtype = torch.float32
         if device == "cuda":
             dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16
-        device_map = "auto" if device == "cuda" else None
+            if torch.cuda.device_count() > 1:
+                device_map = "balanced"
+                self.pipeline = DiffusionPipeline.from_pretrained(model_dir, torch_dtype=dtype, device_map=device_map)
 
-        self.pipeline = AutoPipelineForText2Image.from_pretrained(model_dir, torch_dtype=dtype, device_map=device_map)
-        # try to use DPMSolverMultistepScheduler
-        if isinstance(self.pipeline, StableDiffusionPipeline):
-            try:
-                self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config)
-            except Exception:
-                pass
-        self.pipeline.to(device)
+        if not self.pipeline:
+            self.pipeline = DiffusionPipeline.from_pretrained(model_dir, torch_dtype=dtype).to(device)
 
     def __call__(
         self,
@@ -64,7 +62,7 @@ def __call__(
 
 
 DIFFUSERS_TASKS = {
-    "text-to-image": SMAutoPipelineForText2Image,
+    "text-to-image": SMDiffusionPipelineForText2Image,
 }
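For context, a minimal usage sketch of the reworked wrapper, assuming the class is importable from the toolkit's diffusers utilities and that __call__ takes the prompt as its first positional argument (that part of the signature is truncated in the hunk above); the model directory and prompt are placeholder values:

# Minimal sketch (not part of this diff): construct the wrapper the same way a
# DIFFUSERS_TASKS consumer would. model_dir and the prompt are placeholders.
pipeline = SMDiffusionPipelineForText2Image(model_dir="/opt/ml/model", device="cuda")

# Per the constructor above: on a CUDA host it picks bf16/fp16; with more than
# one GPU it loads the DiffusionPipeline with device_map="balanced", otherwise
# it moves the whole pipeline onto the requested device.
image = pipeline("an astronaut riding a horse on the moon")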