@@ -43,115 +43,19 @@ class ModelLoaderOutput(BaseInvocationOutput):
43
43
#fmt: on
44
44
45
45
46
- class SD1ModelLoaderInvocation ( BaseInvocation ):
47
- """Loading submodels of selected model. """
46
+ class PipelineModelField ( BaseModel ):
47
+ """Pipeline model field """
48
48
49
- type : Literal ["sd1_model_loader" ] = "sd1_model_loader"
50
-
51
- model_name : str = Field (default = "" , description = "Model to load" )
52
- # TODO: precision?
53
-
54
- # Schema customisation
55
- class Config (InvocationConfig ):
56
- schema_extra = {
57
- "ui" : {
58
- "tags" : ["model" , "loader" ],
59
- "type_hints" : {
60
- "model_name" : "model" # TODO: rename to model_name?
61
- }
62
- },
63
- }
64
-
65
- def invoke (self , context : InvocationContext ) -> ModelLoaderOutput :
66
-
67
- base_model = BaseModelType .StableDiffusion1 # TODO:
68
-
69
- # TODO: not found exceptions
70
- if not context .services .model_manager .model_exists (
71
- model_name = self .model_name ,
72
- base_model = base_model ,
73
- model_type = ModelType .Pipeline ,
74
- ):
75
- raise Exception (f"Unkown model name: { self .model_name } !" )
76
-
77
- """
78
- if not context.services.model_manager.model_exists(
79
- model_name=self.model_name,
80
- model_type=SDModelType.Diffusers,
81
- submodel=SDModelType.Tokenizer,
82
- ):
83
- raise Exception(
84
- f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
85
- )
86
-
87
- if not context.services.model_manager.model_exists(
88
- model_name=self.model_name,
89
- model_type=SDModelType.Diffusers,
90
- submodel=SDModelType.TextEncoder,
91
- ):
92
- raise Exception(
93
- f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
94
- )
95
-
96
- if not context.services.model_manager.model_exists(
97
- model_name=self.model_name,
98
- model_type=SDModelType.Diffusers,
99
- submodel=SDModelType.UNet,
100
- ):
101
- raise Exception(
102
- f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
103
- )
104
- """
49
+ model_name : str = Field (description = "Name of the model" )
50
+ base_model : BaseModelType = Field (description = "Base model" )
105
51
106
52
107
- return ModelLoaderOutput (
108
- unet = UNetField (
109
- unet = ModelInfo (
110
- model_name = self .model_name ,
111
- base_model = base_model ,
112
- model_type = ModelType .Pipeline ,
113
- submodel = SubModelType .UNet ,
114
- ),
115
- scheduler = ModelInfo (
116
- model_name = self .model_name ,
117
- base_model = base_model ,
118
- model_type = ModelType .Pipeline ,
119
- submodel = SubModelType .Scheduler ,
120
- ),
121
- loras = [],
122
- ),
123
- clip = ClipField (
124
- tokenizer = ModelInfo (
125
- model_name = self .model_name ,
126
- base_model = base_model ,
127
- model_type = ModelType .Pipeline ,
128
- submodel = SubModelType .Tokenizer ,
129
- ),
130
- text_encoder = ModelInfo (
131
- model_name = self .model_name ,
132
- base_model = base_model ,
133
- model_type = ModelType .Pipeline ,
134
- submodel = SubModelType .TextEncoder ,
135
- ),
136
- loras = [],
137
- ),
138
- vae = VaeField (
139
- vae = ModelInfo (
140
- model_name = self .model_name ,
141
- base_model = base_model ,
142
- model_type = ModelType .Pipeline ,
143
- submodel = SubModelType .Vae ,
144
- ),
145
- )
146
- )
53
+ class PipelineModelLoaderInvocation (BaseInvocation ):
54
+ """Loads a pipeline model, outputting its submodels."""
147
55
148
- # TODO: optimize(less code copy)
149
- class SD2ModelLoaderInvocation (BaseInvocation ):
150
- """Loading submodels of selected model."""
56
+ type : Literal ["pipeline_model_loader" ] = "pipeline_model_loader"
151
57
152
- type : Literal ["sd2_model_loader" ] = "sd2_model_loader"
153
-
154
- model_name : str = Field (default = "" , description = "Model to load" )
58
+ model : PipelineModelField = Field (description = "The model to load" )
155
59
# TODO: precision?
156
60
157
61
# Schema customisation
@@ -160,22 +64,24 @@ class Config(InvocationConfig):
160
64
"ui" : {
161
65
"tags" : ["model" , "loader" ],
162
66
"type_hints" : {
163
- "model_name " : "model" # TODO: rename to model_name?
67
+ "model " : "model"
164
68
}
165
69
},
166
70
}
167
71
168
72
def invoke (self , context : InvocationContext ) -> ModelLoaderOutput :
169
73
170
- base_model = BaseModelType .StableDiffusion2 # TODO:
74
+ base_model = self .model .base_model
75
+ model_name = self .model .model_name
76
+ model_type = ModelType .Pipeline
171
77
172
78
# TODO: not found exceptions
173
79
if not context .services .model_manager .model_exists (
174
- model_name = self . model_name ,
80
+ model_name = model_name ,
175
81
base_model = base_model ,
176
- model_type = ModelType . Pipeline ,
82
+ model_type = model_type ,
177
83
):
178
- raise Exception (f"Unkown model name : { self . model_name } ! " )
84
+ raise Exception (f"Unknown { base_model } { model_type } model : { model_name } " )
179
85
180
86
"""
181
87
if not context.services.model_manager.model_exists(
@@ -210,39 +116,39 @@ def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
210
116
return ModelLoaderOutput (
211
117
unet = UNetField (
212
118
unet = ModelInfo (
213
- model_name = self . model_name ,
119
+ model_name = model_name ,
214
120
base_model = base_model ,
215
- model_type = ModelType . Pipeline ,
121
+ model_type = model_type ,
216
122
submodel = SubModelType .UNet ,
217
123
),
218
124
scheduler = ModelInfo (
219
- model_name = self . model_name ,
125
+ model_name = model_name ,
220
126
base_model = base_model ,
221
- model_type = ModelType . Pipeline ,
127
+ model_type = model_type ,
222
128
submodel = SubModelType .Scheduler ,
223
129
),
224
130
loras = [],
225
131
),
226
132
clip = ClipField (
227
133
tokenizer = ModelInfo (
228
- model_name = self . model_name ,
134
+ model_name = model_name ,
229
135
base_model = base_model ,
230
- model_type = ModelType . Pipeline ,
136
+ model_type = model_type ,
231
137
submodel = SubModelType .Tokenizer ,
232
138
),
233
139
text_encoder = ModelInfo (
234
- model_name = self . model_name ,
140
+ model_name = model_name ,
235
141
base_model = base_model ,
236
- model_type = ModelType . Pipeline ,
142
+ model_type = model_type ,
237
143
submodel = SubModelType .TextEncoder ,
238
144
),
239
145
loras = [],
240
146
),
241
147
vae = VaeField (
242
148
vae = ModelInfo (
243
- model_name = self . model_name ,
149
+ model_name = model_name ,
244
150
base_model = base_model ,
245
- model_type = ModelType . Pipeline ,
151
+ model_type = model_type ,
246
152
submodel = SubModelType .Vae ,
247
153
),
248
154
)
0 commit comments