Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated _encode_prompt_with_clip and encode_prompt in train_dreamboth_sd3 #9800

Merged
merged 11 commits into from
Nov 6, 2024
26 changes: 17 additions & 9 deletions examples/dreambooth/train_dreambooth_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,20 +902,26 @@ def _encode_prompt_with_clip(
tokenizer,
prompt: str,
device=None,
text_input_ids=None,
num_images_per_prompt: int = 1,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)

text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors="pt",
)
if tokenizer is not None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors="pt",
)

text_input_ids = text_inputs.input_ids
else:
if text_input_ids is None:
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

pooled_prompt_embeds = prompt_embeds[0]
Expand All @@ -937,6 +943,7 @@ def encode_prompt(
max_sequence_length,
device=None,
num_images_per_prompt: int = 1,
text_input_ids_list=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt

Expand All @@ -945,13 +952,14 @@ def encode_prompt(

clip_prompt_embeds_list = []
clip_pooled_prompt_embeds_list = []
for tokenizer, text_encoder in zip(clip_tokenizers, clip_text_encoders):
for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
text_encoder=text_encoder,
tokenizer=tokenizer,
prompt=prompt,
device=device if device is not None else text_encoder.device,
num_images_per_prompt=num_images_per_prompt,
text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
)
clip_prompt_embeds_list.append(prompt_embeds)
clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
Expand Down
Loading