# [Community Pipeline] UnCLIP Text Interpolation Pipeline #2257

## Conversation
This looks super cool, thanks for adding it @Abhinay1997!

Would be interesting to try it out on some of the DALL-E 2 examples, e.g.:

- "a photo of an adult lion → a photo of a lion cub"
- "a photo of a landscape in winter → a photo of a landscape in fall"
- "a photo of a victorian house → a photo of a modern house"

Also think this is a cool pipeline to build a Space with :-)

@williamberman can you also have a look here?
Thanks @patrickvonplaten. 😄 Need your input on the attention mask to be used for the interpolated text embeddings, because the results are not great when the difference in prompt length is large.

cc @williamberman maybe?

@patrickvonplaten, see our discussion here: #1869. @williamberman suggested we use the larger of the two attention masks for now.
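For context, a minimal sketch of the "use the larger of the two" idea; the tokenizer checkpoint and variable names below are illustrative assumptions, not the pipeline's actual code.

```python
from transformers import CLIPTokenizer

# Hypothetical setup for illustration only; the real pipeline uses its own
# tokenizer and the prompts passed by the caller.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
start_prompt = "A photograph of an adult lion"
end_prompt = "A photograph of a lion cub"

start_mask = tokenizer(start_prompt, padding="max_length", return_tensors="pt").attention_mask
end_mask = tokenizer(end_prompt, padding="max_length", return_tensors="pt").attention_mask

# Use the larger of the two masks for every interpolated embedding, so the
# attention span always covers the longer prompt's tokens.
attention_mask = start_mask if start_mask.sum() >= end_mask.sum() else end_mask
```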
```python
for interp_val in np.linspace(0, 1, steps):
    # Use the start and end prompts directly for the 0 and 1 values, as the
    # original embeddings are subjectively better than the slerp results at
    # those endpoints.
    if interp_val == 0:
        text_embeds = start_text_embeds
        last_hidden_state = start_last_hidden_state
    elif interp_val == 1:
        text_embeds = end_text_embeds
        last_hidden_state = end_last_hidden_state
    else:
        text_embeds = UnCLIPTextInterpolationPipeline.slerp(interp_val, start_text_embeds, end_text_embeds)
        last_hidden_state = UnCLIPTextInterpolationPipeline.slerp(
            interp_val, start_last_hidden_state, end_last_hidden_state
        )

    text_model_output.text_embeds = text_embeds.unsqueeze(0).to(device)
    text_model_output.last_hidden_state = last_hidden_state.unsqueeze(0).to(device)

    res = self._generate(
        text_model_output=text_model_output,
        text_attention_mask=attention_mask,
        generator=generator,
        prior_num_inference_steps=prior_num_inference_steps,
        decoder_num_inference_steps=decoder_num_inference_steps,
        super_res_num_inference_steps=super_res_num_inference_steps,
        prior_guidance_scale=prior_guidance_scale,
        decoder_guidance_scale=decoder_guidance_scale,
        output_type=output_type,
        return_dict=return_dict,
    )
```
I think ideally we should batch the embeddings instead of effectively running the pipeline in a loop
Sure. Will batch the pipeline run.
```python
text_model_output.text_embeds = text_embeds.unsqueeze(0).to(device)
text_model_output.last_hidden_state = last_hidden_state.unsqueeze(0).to(device)
```
Ideally we use the interpolated results directly instead of mutating `text_model_output`.
Got it. Will make the change.
@williamberman, made changes based on your feedback. Could you review them when you can?

P.S. Ran the code through black and isort multiple times, but it's still failing the code quality test.

We recently updated the versions of our linters etc. Could you try making sure they're up to date and running `make style` locally before pushing?
```python
    return ImagePipelineOutput(images=image)

@staticmethod
def slerp(val, low, high):
```
Nice! `slerp` doesn't have to be a static or regular method on the class. Let's just move it to a regular function at the top of the file :)
Ohh. Yeah that makes sense.
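For reference, a minimal sketch of what the module-level function could look like (the near-parallel guard is an extra safety check added here, not necessarily part of the merged code):

```python
import torch


def slerp(val, low, high):
    """Spherical linear interpolation between two tensors.

    `val` is the interpolation factor in [0, 1]; `low` and `high` are the
    endpoint tensors (e.g. text embeddings), treated as flattened vectors.
    """
    low_norm = low / torch.norm(low)
    high_norm = high / torch.norm(high)
    dot = (low_norm * high_norm).sum().clamp(-1.0, 1.0)
    omega = torch.acos(dot)
    so = torch.sin(omega)
    if so.abs() < 1e-6:
        # Vectors are nearly parallel: plain lerp avoids dividing by ~0.
        return (1.0 - val) * low + val * high
    return (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high
```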
```python
@torch.no_grad()
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.__call__
def _generate(
```
We try to keep the `__call__` function pretty self-contained, so let's move `_generate` back to inside `__call__`. This should work well with the other comment on batching the interpolated text embeddings :)
Do you mean like this?

```python
def __call__(.....):
    def _generate(.....):
```
Almost! Could we just remove the `_generate` function and have all of the logic directly in the `__call__` method?
Sure!
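A rough skeleton of the agreed-upon shape, sketched here for illustration (the signature is abbreviated, and the step comments summarize rather than reproduce the merged code):

```python
import torch


class UnCLIPTextInterpolationPipeline:
    # ...model components and helpers elided...

    @torch.no_grad()
    def __call__(self, start_prompt, end_prompt, steps=5, generator=None):
        # 1. Encode both prompts once with the text encoder.
        # 2. Slerp between the two embeddings at each interpolation step and
        #    stack the results into a single batch.
        # 3. Run the prior, decoder, and super-resolution stages on that
        #    batch -- the logic that previously lived in _generate, inlined.
        ...
```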
Great start @Abhinay1997!
```python
# (List accumulators; initialized just before the loop in the full pipeline.)
batch_text_embeds = []
batch_last_hidden_state = []

for interp_val in torch.linspace(0, 1, steps):
    text_embeds = slerp(interp_val, text_model_output.text_embeds[0], text_model_output.text_embeds[1])
    last_hidden_state = slerp(
        interp_val, text_model_output.last_hidden_state[0], text_model_output.last_hidden_state[1]
    )
    batch_text_embeds.append(text_embeds.unsqueeze(0))
    batch_last_hidden_state.append(last_hidden_state.unsqueeze(0))

batch_text_embeds = torch.cat(batch_text_embeds)
batch_last_hidden_state = torch.cat(batch_last_hidden_state)
```
Nice!
Love the progress @Abhinay1997! Could you also add some example code for running the pipeline, along with the outputs it gives? :)
### UnCLIP Text Interpolation Pipeline

This diffusion pipeline takes two prompts and interpolates between them using spherical linear interpolation (slerp). The input prompts are converted to text embeddings by the pipeline's text_encoder, and the interpolation is applied to the resulting text embeddings over the specified number of steps (default: 5).

```python
import torch
from diffusers import DiffusionPipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

pipe = DiffusionPipeline.from_pretrained(
    "kakaobrain/karlo-v1-alpha",
    torch_dtype=torch.float16,
    custom_pipeline="unclip_text_interpolation",
)
pipe.to(device)

start_prompt = "A photograph of an adult lion"
end_prompt = "A photograph of a lion cub"
# For best results, keep the prompts close in length to each other.
# Of course, feel free to try out differing lengths.
generator = torch.Generator(device=device).manual_seed(42)

output = pipe(start_prompt, end_prompt, steps=6, generator=generator, enable_sequential_cpu_offload=False)

for i, image in enumerate(output.images):
    image.save("result%s.jpg" % i)
```

The resulting images, in order:

![result]
![result]
![result]
![result]
![result]
![result]
@williamberman Code example for the pipeline.
Awesome, looks basically good to go @Abhinay1997! I needed to merge in master to get the updated linter versions :)
Refactor to linter formatting

Co-authored-by: Will Berman <wlbberman@gmail.com>
Thanks for the help, Will! Hope we are good for the merge now.
Awesome, this is great @Abhinay1997! Would you be interested in making a Space to showcase the pipeline? https://huggingface.co/spaces
Sure @williamberman! I was thinking of doing it once the PR is merged :)
[Community Pipeline] UnCLIP Text Interpolation Pipeline (#2257)

* UnCLIP Text Interpolation Pipeline
* Formatter fixes
* Changes based on feedback
* Formatting fix
* Formatting fix
* isort formatting fix(?)
* Remove duplicate code
* Formatting fix
* Refactor __call__ and change example in readme.
* Update examples/community/unclip_text_interpolation.py (refactor to linter formatting)

Co-authored-by: Will Berman <wlbberman@gmail.com>