Skip to content

Commit 22c337b

Browse files
Update UI To Use New Model Manager (#3548)
PR for the Model Manager UI work related to 3.0 [DONE] - Update ModelType Config names to be specific so that the front end can parse them correctly. - Rebuild frontend schema to reflect these changes. - Update Linear UI Text To Image and Image to Image to work with the new model loader. - Updated the ModelInput component in the Node Editor to work with the new changes. [TODO REMEMBER] - Add proper types for ModelLoaderType in `ModelSelect.tsx` [TODO] - Everything else.
2 parents 2d889e1 + 339e7ce commit 22c337b

File tree

67 files changed

+709
-667
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+709
-667
lines changed

invokeai/app/api/routers/models.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from pydantic import BaseModel, Field, parse_obj_as
88
from ..dependencies import ApiDependencies
99
from invokeai.backend import BaseModelType, ModelType
10-
from invokeai.backend.model_management.models import get_all_model_configs
11-
MODEL_CONFIGS = Union[tuple(get_all_model_configs())]
10+
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
11+
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
1212

1313
models_router = APIRouter(prefix="/v1/models", tags=["models"])
1414

@@ -62,8 +62,7 @@ class ConvertedModelResponse(BaseModel):
6262
info: DiffusersModelInfo = Field(description="The converted model info")
6363

6464
class ModelsList(BaseModel):
65-
models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend
66-
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
65+
models: list[MODEL_CONFIGS]
6766

6867

6968
@models_router.get(
@@ -72,10 +71,10 @@ class ModelsList(BaseModel):
7271
responses={200: {"model": ModelsList }},
7372
)
7473
async def list_models(
75-
base_model: BaseModelType = Query(
74+
base_model: Optional[BaseModelType] = Query(
7675
default=None, description="Base model"
7776
),
78-
model_type: ModelType = Query(
77+
model_type: Optional[ModelType] = Query(
7978
default=None, description="The type of model to get"
8079
),
8180
) -> ModelsList:

invokeai/app/api_app.py

+16
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,22 @@ def custom_openapi():
120120

121121
invoker_schema["output"] = outputs_ref
122122

123+
from invokeai.backend.model_management.models import get_model_config_enums
124+
for model_config_format_enum in set(get_model_config_enums()):
125+
name = model_config_format_enum.__qualname__
126+
127+
if name in openapi_schema["components"]["schemas"]:
128+
# print(f"Config with name {name} already defined")
129+
continue
130+
131+
# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
132+
openapi_schema["components"]["schemas"][name] = dict(
133+
title=name,
134+
description="An enumeration.",
135+
type="string",
136+
enum=list(v.value for v in model_config_format_enum),
137+
)
138+
123139
app.openapi_schema = openapi_schema
124140
return app.openapi_schema
125141

invokeai/app/invocations/model.py

+25-119
Original file line numberDiff line numberDiff line change
@@ -43,115 +43,19 @@ class ModelLoaderOutput(BaseInvocationOutput):
4343
#fmt: on
4444

4545

46-
class SD1ModelLoaderInvocation(BaseInvocation):
47-
"""Loading submodels of selected model."""
46+
class PipelineModelField(BaseModel):
47+
"""Pipeline model field"""
4848

49-
type: Literal["sd1_model_loader"] = "sd1_model_loader"
50-
51-
model_name: str = Field(default="", description="Model to load")
52-
# TODO: precision?
53-
54-
# Schema customisation
55-
class Config(InvocationConfig):
56-
schema_extra = {
57-
"ui": {
58-
"tags": ["model", "loader"],
59-
"type_hints": {
60-
"model_name": "model" # TODO: rename to model_name?
61-
}
62-
},
63-
}
64-
65-
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
66-
67-
base_model = BaseModelType.StableDiffusion1 # TODO:
68-
69-
# TODO: not found exceptions
70-
if not context.services.model_manager.model_exists(
71-
model_name=self.model_name,
72-
base_model=base_model,
73-
model_type=ModelType.Pipeline,
74-
):
75-
raise Exception(f"Unkown model name: {self.model_name}!")
76-
77-
"""
78-
if not context.services.model_manager.model_exists(
79-
model_name=self.model_name,
80-
model_type=SDModelType.Diffusers,
81-
submodel=SDModelType.Tokenizer,
82-
):
83-
raise Exception(
84-
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
85-
)
86-
87-
if not context.services.model_manager.model_exists(
88-
model_name=self.model_name,
89-
model_type=SDModelType.Diffusers,
90-
submodel=SDModelType.TextEncoder,
91-
):
92-
raise Exception(
93-
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
94-
)
95-
96-
if not context.services.model_manager.model_exists(
97-
model_name=self.model_name,
98-
model_type=SDModelType.Diffusers,
99-
submodel=SDModelType.UNet,
100-
):
101-
raise Exception(
102-
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
103-
)
104-
"""
49+
model_name: str = Field(description="Name of the model")
50+
base_model: BaseModelType = Field(description="Base model")
10551

10652

107-
return ModelLoaderOutput(
108-
unet=UNetField(
109-
unet=ModelInfo(
110-
model_name=self.model_name,
111-
base_model=base_model,
112-
model_type=ModelType.Pipeline,
113-
submodel=SubModelType.UNet,
114-
),
115-
scheduler=ModelInfo(
116-
model_name=self.model_name,
117-
base_model=base_model,
118-
model_type=ModelType.Pipeline,
119-
submodel=SubModelType.Scheduler,
120-
),
121-
loras=[],
122-
),
123-
clip=ClipField(
124-
tokenizer=ModelInfo(
125-
model_name=self.model_name,
126-
base_model=base_model,
127-
model_type=ModelType.Pipeline,
128-
submodel=SubModelType.Tokenizer,
129-
),
130-
text_encoder=ModelInfo(
131-
model_name=self.model_name,
132-
base_model=base_model,
133-
model_type=ModelType.Pipeline,
134-
submodel=SubModelType.TextEncoder,
135-
),
136-
loras=[],
137-
),
138-
vae=VaeField(
139-
vae=ModelInfo(
140-
model_name=self.model_name,
141-
base_model=base_model,
142-
model_type=ModelType.Pipeline,
143-
submodel=SubModelType.Vae,
144-
),
145-
)
146-
)
53+
class PipelineModelLoaderInvocation(BaseInvocation):
54+
"""Loads a pipeline model, outputting its submodels."""
14755

148-
# TODO: optimize(less code copy)
149-
class SD2ModelLoaderInvocation(BaseInvocation):
150-
"""Loading submodels of selected model."""
56+
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"
15157

152-
type: Literal["sd2_model_loader"] = "sd2_model_loader"
153-
154-
model_name: str = Field(default="", description="Model to load")
58+
model: PipelineModelField = Field(description="The model to load")
15559
# TODO: precision?
15660

15761
# Schema customisation
@@ -160,22 +64,24 @@ class Config(InvocationConfig):
16064
"ui": {
16165
"tags": ["model", "loader"],
16266
"type_hints": {
163-
"model_name": "model" # TODO: rename to model_name?
67+
"model": "model"
16468
}
16569
},
16670
}
16771

16872
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
16973

170-
base_model = BaseModelType.StableDiffusion2 # TODO:
74+
base_model = self.model.base_model
75+
model_name = self.model.model_name
76+
model_type = ModelType.Pipeline
17177

17278
# TODO: not found exceptions
17379
if not context.services.model_manager.model_exists(
174-
model_name=self.model_name,
80+
model_name=model_name,
17581
base_model=base_model,
176-
model_type=ModelType.Pipeline,
82+
model_type=model_type,
17783
):
178-
raise Exception(f"Unkown model name: {self.model_name}!")
84+
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
17985

18086
"""
18187
if not context.services.model_manager.model_exists(
@@ -210,39 +116,39 @@ def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
210116
return ModelLoaderOutput(
211117
unet=UNetField(
212118
unet=ModelInfo(
213-
model_name=self.model_name,
119+
model_name=model_name,
214120
base_model=base_model,
215-
model_type=ModelType.Pipeline,
121+
model_type=model_type,
216122
submodel=SubModelType.UNet,
217123
),
218124
scheduler=ModelInfo(
219-
model_name=self.model_name,
125+
model_name=model_name,
220126
base_model=base_model,
221-
model_type=ModelType.Pipeline,
127+
model_type=model_type,
222128
submodel=SubModelType.Scheduler,
223129
),
224130
loras=[],
225131
),
226132
clip=ClipField(
227133
tokenizer=ModelInfo(
228-
model_name=self.model_name,
134+
model_name=model_name,
229135
base_model=base_model,
230-
model_type=ModelType.Pipeline,
136+
model_type=model_type,
231137
submodel=SubModelType.Tokenizer,
232138
),
233139
text_encoder=ModelInfo(
234-
model_name=self.model_name,
140+
model_name=model_name,
235141
base_model=base_model,
236-
model_type=ModelType.Pipeline,
142+
model_type=model_type,
237143
submodel=SubModelType.TextEncoder,
238144
),
239145
loras=[],
240146
),
241147
vae=VaeField(
242148
vae=ModelInfo(
243-
model_name=self.model_name,
149+
model_name=model_name,
244150
base_model=base_model,
245-
model_type=ModelType.Pipeline,
151+
model_type=model_type,
246152
submodel=SubModelType.Vae,
247153
),
248154
)

invokeai/app/services/model_manager_service.py

+4-39
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
from abc import ABC, abstractmethod
77
from pathlib import Path
8-
from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING
8+
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
99
from dataclasses import dataclass
1010

1111
from invokeai.backend.model_management.model_manager import (
@@ -69,19 +69,6 @@ def model_exists(
6969
) -> bool:
7070
pass
7171

72-
@abstractmethod
73-
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
74-
"""
75-
Returns the name and typeof the default model, or None
76-
if none is defined.
77-
"""
78-
pass
79-
80-
@abstractmethod
81-
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
82-
"""Sets the default model to the indicated name."""
83-
pass
84-
8572
@abstractmethod
8673
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
8774
"""
@@ -270,17 +257,6 @@ def model_exists(
270257
model_type,
271258
)
272259

273-
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
274-
"""
275-
Returns the name of the default model, or None
276-
if none is defined.
277-
"""
278-
return self.mgr.default_model()
279-
280-
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
281-
"""Sets the default model to the indicated name."""
282-
self.mgr.set_default_model(model_name, base_model, model_type)
283-
284260
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
285261
"""
286262
Given a model name returns a dict-like (OmegaConf) object describing it.
@@ -297,21 +273,10 @@ def list_models(
297273
self,
298274
base_model: Optional[BaseModelType] = None,
299275
model_type: Optional[ModelType] = None
300-
) -> dict:
276+
) -> list[dict]:
277+
# ) -> dict:
301278
"""
302-
Return a dict of models in the format:
303-
{ model_type1:
304-
{ model_name1: {'status': 'active'|'cached'|'not loaded',
305-
'model_name' : name,
306-
'model_type' : SDModelType,
307-
'description': description,
308-
'format': 'folder'|'safetensors'|'ckpt'
309-
},
310-
model_name2: { etc }
311-
},
312-
model_type2:
313-
{ model_name_n: etc
314-
}
279+
Return a list of models.
315280
"""
316281
return self.mgr.list_models(base_model, model_type)
317282

0 commit comments

Comments
 (0)