Commit d7cf143

fix diffusers offline (#363)

1 parent 36ff383 · commit d7cf143

File tree

4 files changed: +103 -138 lines changed

docker_images/diffusers/app/lora.py
docker_images/diffusers/app/offline.py
docker_images/diffusers/app/pipelines/image_to_image.py
docker_images/diffusers/app/pipelines/text_to_image.py

docker_images/diffusers/app/lora.py
Lines changed: 32 additions & 32 deletions

@@ -1,25 +1,19 @@
 import logging
 
 import torch.nn as nn
-from huggingface_hub import hf_hub_download, model_info
+from app import offline
 from safetensors.torch import load_file
 
 
 logger = logging.getLogger(__name__)
 
 
-class LoRAPipelineMixin(object):
-    def __init__(self):
-        if not hasattr(self, "current_lora_adapter"):
-            self.current_lora_adapter = None
-        if not hasattr(self, "model_id"):
-            self.model_id = None
-        if not hasattr(self, "current_tokens_loaded"):
-            self.current_tokens_loaded = 0
-
+class LoRAPipelineMixin(offline.OfflineBestEffortMixin):
     @staticmethod
     def _get_lora_weight_name(model_data):
-        is_diffusers_lora = LoRAPipelineMixin._is_diffusers_lora(model_data)
+        weight_name_candidate = LoRAPipelineMixin._lora_weights_candidates(model_data)
+        if weight_name_candidate:
+            return weight_name_candidate
         file_to_load = next(
             (
                 file.rfilename
@@ -28,25 +22,28 @@ def _get_lora_weight_name(model_data):
             ),
             None,
         )
-        if not file_to_load and not is_diffusers_lora:
+        if not file_to_load and not weight_name_candidate:
             raise ValueError("No *.safetensors file found for your LoRA")
-        weight_name = file_to_load if not is_diffusers_lora else None
-        return weight_name
+        return file_to_load
 
     @staticmethod
     def _is_lora(model_data):
-        return LoRAPipelineMixin._is_diffusers_lora(
-            model_data
-        ) or "lora" in model_data.cardData.get("tags", [])
+        return LoRAPipelineMixin._lora_weights_candidates(model_data) or (
+            model_data.cardData.get("tags")
+            and "lora" in model_data.cardData.get("tags", [])
+        )
 
     @staticmethod
-    def _is_diffusers_lora(model_data):
-        is_diffusers_lora = any(
-            file.rfilename
-            in ("pytorch_lora_weights.bin", "pytorch_lora_weights.safetensors")
-            for file in model_data.siblings
-        )
-        return is_diffusers_lora
+    def _lora_weights_candidates(model_data):
+        candidate = None
+        for file in model_data.siblings:
+            rfilename = str(file.rfilename)
+            if rfilename.endswith("pytorch_lora_weights.bin"):
+                candidate = rfilename
+            elif rfilename.endswith("pytorch_lora_weights.safetensors"):
+                candidate = rfilename
+                break
+        return candidate
 
     @staticmethod
     def _is_safetensors_pivotal(model_data):
@@ -72,7 +69,8 @@ def _fuse_or_raise(self):
         self.current_lora_adapter = None
         raise
 
-    def _reset_tokenizer_and_encoder(self, tokenizer, text_encoder, token_to_remove):
+    @staticmethod
+    def _reset_tokenizer_and_encoder(tokenizer, text_encoder, token_to_remove):
         token_id = tokenizer(token_to_remove)["input_ids"][1]
         del tokenizer._added_tokens_decoder[token_id]
         del tokenizer._added_tokens_encoder[token_to_remove]
@@ -101,13 +99,14 @@ def _unload_textual_embeddings(self):
 
     def _load_textual_embeddings(self, adapter, model_data):
         if self._is_pivotal_tuning_lora(model_data):
-            embedding_path = hf_hub_download(
+            embedding_path = self._hub_repo_file(
                 repo_id=adapter,
                 filename="embeddings.safetensors"
                 if self._is_safetensors_pivotal(model_data)
                 else "embeddings.pti",
                 repo_type="model",
             )
+
             embeddings = load_file(embedding_path)
             state_dict_clip_l = (
                 embeddings.get("text_encoders_0")
@@ -152,7 +151,7 @@ def _load_lora_adapter(self, kwargs):
         if adapter is not None:
             logger.info("LoRA adapter %s requested", adapter)
             if adapter != self.current_lora_adapter:
-                model_data = model_info(adapter, token=self.use_auth_token)
+                model_data = self._hub_model_info(adapter)
                 if not self._is_lora(model_data):
                     msg = f"Requested adapter {adapter:s} is not a LoRA adapter"
                     logger.error(msg)
@@ -167,10 +166,11 @@ def _load_lora_adapter(self, kwargs):
                     self.current_lora_adapter,
                     adapter,
                 )
-                self.ldm.unfuse_lora()
-                self.ldm.unload_lora_weights()
-                self._unload_textual_embeddings()
-                self.current_lora_adapter = None
+                if self.current_lora_adapter is not None:
+                    self.ldm.unfuse_lora()
+                    self.ldm.unload_lora_weights()
+                    self._unload_textual_embeddings()
+                    self.current_lora_adapter = None
                 logger.info("LoRA weights unloaded, loading new weights")
                 weight_name = self._get_lora_weight_name(model_data=model_data)
 
@@ -184,7 +184,7 @@ def _load_lora_adapter(self, kwargs):
             else:
                 logger.info("LoRA adapter %s already loaded", adapter)
                 # Needed while a LoRA is loaded w/ model
-                model_data = model_info(adapter, token=self.use_auth_token)
+                model_data = self._hub_model_info(adapter)
                 if (
                     self._is_pivotal_tuning_lora(model_data)
                     and self.current_tokens_loaded == 0
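
Note: _lora_weights_candidates replaces the boolean _is_diffusers_lora check and now returns the matching sibling filename itself, preferring pytorch_lora_weights.safetensors over pytorch_lora_weights.bin (a .bin hit is remembered but the scan continues; the first safetensors hit wins and stops the loop). A minimal standalone sketch of that selection logic, not part of the commit, using hypothetical stub objects in place of huggingface_hub model metadata:

# Standalone sketch: SimpleNamespace stubs stand in for ModelInfo.siblings.
from types import SimpleNamespace


def lora_weights_candidates(model_data):
    # Same logic as the new static method: remember a .bin match but keep
    # scanning, so a later .safetensors match takes precedence.
    candidate = None
    for file in model_data.siblings:
        rfilename = str(file.rfilename)
        if rfilename.endswith("pytorch_lora_weights.bin"):
            candidate = rfilename
        elif rfilename.endswith("pytorch_lora_weights.safetensors"):
            candidate = rfilename
            break
    return candidate


model_data = SimpleNamespace(
    siblings=[
        SimpleNamespace(rfilename="pytorch_lora_weights.bin"),
        SimpleNamespace(rfilename="unet/pytorch_lora_weights.safetensors"),
    ]
)
print(lora_weights_candidates(model_data))  # unet/pytorch_lora_weights.safetensors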
docker_images/diffusers/app/offline.py (new file)
Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
+import json
+import logging
+import os
+
+from huggingface_hub import file_download, hf_api, hf_hub_download, model_info, utils
+
+
+logger = logging.getLogger(__name__)
+
+
+class OfflineBestEffortMixin(object):
+    def _hub_repo_file(self, repo_id, filename, repo_type="model"):
+        if self.offline_preferred:
+            try:
+                config_file = hf_hub_download(
+                    repo_id,
+                    filename,
+                    token=self.use_auth_token,
+                    local_files_only=True,
+                    repo_type=repo_type,
+                )
+            except utils.LocalEntryNotFoundError:
+                logger.info("Unable to fetch model index in local cache")
+            else:
+                return config_file
+
+        return hf_hub_download(
+            repo_id, filename, token=self.use_auth_token, repo_type=repo_type
+        )
+
+    def _hub_model_info(self, model_id):
+        """
+        This method tries to fetch the locally cached model_info, if any.
+        If none, it requests the Hub. Useful for pre-cached private models when no token is available.
+        """
+        if self.offline_preferred:
+            cache_root = os.getenv(
+                "DIFFUSERS_CACHE", os.getenv("HUGGINGFACE_HUB_CACHE", "")
+            )
+            folder_name = file_download.repo_folder_name(
+                repo_id=model_id, repo_type="model"
+            )
+            folder_path = os.path.join(cache_root, folder_name)
+            logger.debug("Cache folder path %s", folder_path)
+            filename = os.path.join(folder_path, "hub_model_info.json")
+            try:
+                with open(filename, "r") as f:
+                    model_data = json.load(f)
+            except OSError:
+                logger.info(
+                    "No cached model info found in file %s for model %s. Fetching from the Hub",
+                    filename,
+                    model_id,
+                )
+            else:
+                model_data = hf_api.ModelInfo(**model_data)
+                return model_data
+        model_data = model_info(model_id, token=self.use_auth_token)
+        return model_data
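
The mixin assumes the consuming class sets self.offline_preferred and self.use_auth_token before calling either helper, which is exactly what the pipelines below do in __init__. A hedged usage sketch (the class name is illustrative; app.offline and app.validation are the modules used in this commit, so it only runs inside the image's app package):

import os

from app import offline, validation


class CachedPipeline(offline.OfflineBestEffortMixin):  # illustrative consumer
    def __init__(self, model_id: str):
        # Attributes the mixin's helpers read from self
        self.use_auth_token = os.getenv("HF_API_TOKEN")
        self.offline_preferred = validation.str_to_bool(os.getenv("OFFLINE_PREFERRED"))
        # Cache-first model_info lookup, falling back to a live Hub request
        model_data = self._hub_model_info(model_id)
        # Cache-first file download with the same fallback
        config_file = self._hub_repo_file(model_id, "model_index.json")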

docker_images/diffusers/app/pipelines/image_to_image.py
Lines changed: 4 additions & 49 deletions

@@ -3,7 +3,7 @@
 import os
 
 import torch
-from app import idle, timing, validation
+from app import idle, offline, timing, validation
 from app.pipelines import Pipeline
 from diffusers import (
     AltDiffusionImg2ImgPipeline,
@@ -26,47 +26,20 @@
     StableUnCLIPImg2ImgPipeline,
     StableUnCLIPPipeline,
 )
-from huggingface_hub import file_download, hf_api, hf_hub_download, model_info, utils
 from PIL import Image
 
 
 logger = logging.getLogger(__name__)
 
 
-class ImageToImagePipeline(Pipeline):
+class ImageToImagePipeline(Pipeline, offline.OfflineBestEffortMixin):
     def __init__(self, model_id: str):
         use_auth_token = os.getenv("HF_API_TOKEN")
         self.use_auth_token = use_auth_token
         # This should allow us to make the image work with private models when no token is provided, if the said model
         # is already in local cache
         self.offline_preferred = validation.str_to_bool(os.getenv("OFFLINE_PREFERRED"))
-        fetched = False
-        if self.offline_preferred:
-            cache_root = os.getenv(
-                "DIFFUSERS_CACHE", os.getenv("HUGGINGFACE_HUB_CACHE", "")
-            )
-            folder_name = file_download.repo_folder_name(
-                repo_id=model_id, repo_type="model"
-            )
-            folder_path = os.path.join(cache_root, folder_name)
-            logger.debug("Cache folder path %s", folder_path)
-            filename = os.path.join(folder_path, "hub_model_info.json")
-            try:
-                with open(filename, "r") as f:
-                    model_data = json.load(f)
-            except OSError:
-                logger.info(
-                    "No cached model info found in file %s found for model %s. Fetching on the hub",
-                    filename,
-                    model_id,
-                )
-            else:
-                model_data = hf_api.ModelInfo(**model_data)
-                fetched = True
-
-        if not fetched:
-            model_data = model_info(model_id, token=self.use_auth_token)
-
+        model_data = self._hub_model_info(model_id)
         kwargs = (
             {"safety_checker": None}
             if model_id.startswith("hf-internal-testing/")
@@ -84,25 +57,7 @@ def __init__(self, model_id: str):
                 config_file_name = file_name
                 break
         if config_file_name:
-            fetched = False
-            if self.offline_preferred:
-                try:
-                    config_file = hf_hub_download(
-                        model_id,
-                        config_file_name,
-                        token=self.use_auth_token,
-                        local_files_only=True,
-                    )
-                except utils.LocalEntryNotFoundError:
-                    logger.info("Unable to fetch model index in local cache")
-                else:
-                    fetched = True
-            if not fetched:
-                config_file = hf_hub_download(
-                    model_id,
-                    config_file_name,
-                    token=self.use_auth_token,
-                )
+            config_file = self._hub_repo_file(model_id, config_file_name)
 
             with open(config_file, "r") as f:
                 config_dict = json.load(f)
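
The deleted block above computed the hub_model_info.json cache location inline; _hub_model_info now resolves the same path. A small sketch of where that lookup lands, assuming the same environment variables as the mixin (the repo id is illustrative):

import os

from huggingface_hub import file_download

cache_root = os.getenv("DIFFUSERS_CACHE", os.getenv("HUGGINGFACE_HUB_CACHE", ""))
folder_name = file_download.repo_folder_name(
    repo_id="runwayml/stable-diffusion-v1-5", repo_type="model"
)
# e.g. <cache_root>/models--runwayml--stable-diffusion-v1-5/hub_model_info.json
print(os.path.join(cache_root, folder_name, "hub_model_info.json"))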

docker_images/diffusers/app/pipelines/text_to_image.py
Lines changed: 8 additions & 57 deletions

@@ -5,7 +5,7 @@
 from typing import TYPE_CHECKING
 
 import torch
-from app import idle, lora, timing, validation
+from app import idle, lora, offline, timing, validation
 from app.pipelines import Pipeline
 from diffusers import (
     AutoencoderKL,
@@ -14,7 +14,6 @@
     EulerAncestralDiscreteScheduler,
 )
 from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers
-from huggingface_hub import file_download, hf_api, hf_hub_download, model_info, utils
 
 
 logger = logging.getLogger(__name__)
@@ -23,7 +22,9 @@
     from PIL import Image
 
 
-class TextToImagePipeline(Pipeline, lora.LoRAPipelineMixin):
+class TextToImagePipeline(
+    Pipeline, lora.LoRAPipelineMixin, offline.OfflineBestEffortMixin
+):
     def __init__(self, model_id: str):
         self.current_lora_adapter = None
         self.model_id = None
@@ -32,32 +33,7 @@ def __init__(self, model_id: str):
         # This should allow us to make the image work with private models when no token is provided, if the said model
         # is already in local cache
         self.offline_preferred = validation.str_to_bool(os.getenv("OFFLINE_PREFERRED"))
-        fetched = False
-        if self.offline_preferred:
-            cache_root = os.getenv(
-                "DIFFUSERS_CACHE", os.getenv("HUGGINGFACE_HUB_CACHE", "")
-            )
-            folder_name = file_download.repo_folder_name(
-                repo_id=model_id, repo_type="model"
-            )
-            folder_path = os.path.join(cache_root, folder_name)
-            logger.debug("Cache folder path %s", folder_path)
-            filename = os.path.join(folder_path, "hub_model_info.json")
-            try:
-                with open(filename, "r") as f:
-                    model_data = json.load(f)
-            except OSError:
-                logger.info(
-                    "No cached model info found in file %s found for model %s. Fetching on the hub",
-                    filename,
-                    model_id,
-                )
-            else:
-                model_data = hf_api.ModelInfo(**model_data)
-                fetched = True
-
-        if not fetched:
-            model_data = model_info(model_id, token=self.use_auth_token)
+        model_data = self._hub_model_info(model_id)
 
         kwargs = (
             {"safety_checker": None}
@@ -74,26 +50,7 @@ def __init__(self, model_id: str):
         if self._is_lora(model_data):
            model_type = "LoraModel"
         elif has_model_index:
-            fetched = False
-            if self.offline_preferred:
-                try:
-                    config_file = hf_hub_download(
-                        model_id,
-                        "model_index.json",
-                        token=self.use_auth_token,
-                        local_files_only=True,
-                    )
-                except utils.LocalEntryNotFoundError:
-                    logger.info("Unable to fetch model index in local cache")
-                else:
-                    fetched = True
-
-            if not fetched:
-                config_file = hf_hub_download(
-                    model_id,
-                    "model_index.json",
-                    token=self.use_auth_token,
-                )
+            config_file = self._hub_repo_file(model_id, "model_index.json")
             with open(config_file, "r") as f:
                 config_dict = json.load(f)
             model_type = config_dict.get("_class_name", None)
@@ -107,15 +64,9 @@
             raise ValueError(
                 "No `base_model` found. Please include a `base_model` on your README.md tags"
             )
-
-            weight_name = self._get_lora_weight_name(model_data)
             self._load_sd_with_sdxl_fix(model_to_load, **kwargs)
-            self.ldm.load_lora_weights(
-                model_id, weight_name=weight_name, use_auth_token=self.use_auth_token
-            )
-            self.current_lora_adapter = model_id
-            self._fuse_or_raise()
-            logger.info("LoRA adapter %s loaded", model_id)
+            # The lora will actually be lazily loaded on the fly per request
+            self.current_lora_adapter = None
         else:
             if model_id == "stabilityai/stable-diffusion-xl-base-1.0":
                 self._load_sd_with_sdxl_fix(model_id, **kwargs)
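
With the eager load gone, a LoRA repo id now only loads the base model at construction; the adapter itself is resolved per request via lora.LoRAPipelineMixin._load_lora_adapter (see the lora.py hunks above). A hedged sketch of that request-time flow; the kwargs key name is not visible in this diff and is assumed:

# Hypothetical request-time flow; "lora_adapter" is an assumed kwargs key
# and the adapter id is illustrative.
pipeline = TextToImagePipeline("stabilityai/stable-diffusion-xl-base-1.0")
kwargs = {"lora_adapter": "some-user/some-lora"}
pipeline._load_lora_adapter(kwargs)  # fetches model_info, loads and fuses weights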
