Skip to content

Commit df9f47b

Browse files
support loras with replicate (#3054)
1 parent 1e40c62 commit df9f47b

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from .nebius import NebiusConversationalTask, NebiusTextGenerationTask, NebiusTextToImageTask
2424
from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask
2525
from .openai import OpenAIConversationalTask
26-
from .replicate import ReplicateTask, ReplicateTextToSpeechTask
26+
from .replicate import ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask
2727
from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
2828
from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask
2929

@@ -115,7 +115,7 @@
115115
"conversational": OpenAIConversationalTask(),
116116
},
117117
"replicate": {
118-
"text-to-image": ReplicateTask("text-to-image"),
118+
"text-to-image": ReplicateTextToImageTask(),
119119
"text-to-speech": ReplicateTextToSpeechTask(),
120120
"text-to-video": ReplicateTask("text-to-video"),
121121
},

src/huggingface_hub/inference/_providers/replicate.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,19 @@ def get_response(self, response: Union[bytes, Dict], request_params: Optional[Re
4747
return get_session().get(output_url).content
4848

4949

50+
class ReplicateTextToImageTask(ReplicateTask):
51+
def __init__(self):
52+
super().__init__("text-to-image")
53+
54+
def _prepare_payload_as_dict(
55+
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
56+
) -> Optional[Dict]:
57+
payload: Dict = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) # type: ignore[assignment]
58+
if provider_mapping_info.adapter_weights_path is not None:
59+
payload["input"]["lora_weights"] = f"https://huggingface.co/{provider_mapping_info.hf_model_id}"
60+
return payload
61+
62+
5063
class ReplicateTextToSpeechTask(ReplicateTask):
5164
def __init__(self):
5265
super().__init__("text-to-speech")

0 commit comments

Comments
 (0)