-
Notifications
You must be signed in to change notification settings - Fork 124
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support changing runtime_state_dict's device with `runtime_state_dict_to` ```python def warmup_with_load(self, file_path, device=None): state_dict = flow.load(file_path) if device is not None: state_dict = flow.nn.Graph.runtime_state_dict_to(state_dict, device) self.load_runtime_state_dict(state_dict) ``` Depends on PR in oneflow: Oneflow-Inc/oneflow#10335 # Performance check ## save - speed: 5.99~6.08 it/s - mem: - before compile: 7.4G - after compile: 8.6 G - run: 15G ## load - speed: 5.94~6.11 it/s - mem: - before compile: 8.5G - after compile: 8.5G - run: 15G ## load from cuda 0 to cuda 1 - speed: 6.17~6.22 it/s - mem: - before compile: 7.3G - after compile: 8.5G - run: 15G ## load from cuda 0 to cuda 0 and cuda 1 - speed: 6.05~6.12 it/s - mem: - before compile: 7.3G - after compile: 8.5G - run: 15G --------- Co-authored-by: Shenghang Tsai <jackalcooper@gmail.com>
- Loading branch information
1 parent
db15295
commit 7944a85
Showing
3 changed files
with
102 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# Compile and save to oneflow graph example: python examples/text_to_image_sdxl_mp_load.py --save | ||
# Compile and load to new device example: python examples/text_to_image_sdxl_mp_load.py --load | ||
|
||
import os | ||
import argparse | ||
|
||
# cv2 must be imported before diffusers and oneflow to avlid error: AttributeError: module 'cv2.gapi' has no attribute 'wip' | ||
# Maybe bacause oneflow use a lower version of cv2 | ||
import cv2 | ||
import oneflow as flow | ||
import torch | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--base", type=str, default="stabilityai/stable-diffusion-xl-base-1.0" | ||
) | ||
parser.add_argument("--variant", type=str, default="fp16") | ||
parser.add_argument( | ||
"--prompt", | ||
type=str, | ||
default="street style, detailed, raw photo, woman, face, shot on CineStill 800T", | ||
) | ||
parser.add_argument("--n_steps", type=int, default=30) | ||
parser.add_argument("--saved_image", type=str, required=False, default="sdxl-out.png") | ||
parser.add_argument("--seed", type=int, default=1) | ||
parser.add_argument("--save", action=argparse.BooleanOptionalAction) | ||
parser.add_argument("--load", action=argparse.BooleanOptionalAction) | ||
parser.add_argument("--file", type=str, required=False, default="unet_compiled") | ||
cmd_args = parser.parse_args() | ||
|
||
|
||
def run_sd(cmd_args, device): | ||
# oneflow_compile should be imported before importing any diffusers | ||
from onediff.infer_compiler import oneflow_compile | ||
from diffusers import DiffusionPipeline | ||
|
||
# Normal SDXL pipeline init. | ||
seed = torch.Generator(device).manual_seed(cmd_args.seed) | ||
output_type = "pil" | ||
# SDXL base: StableDiffusionXLPipeline | ||
base = DiffusionPipeline.from_pretrained( | ||
cmd_args.base, | ||
torch_dtype=torch.float16, | ||
variant=cmd_args.variant, | ||
use_safetensors=True, | ||
) | ||
base.to(device) | ||
|
||
# Compile unet with oneflow | ||
print("unet is compiled to oneflow.") | ||
base.unet = oneflow_compile(base.unet) | ||
|
||
# Load compiled unet with oneflow | ||
if cmd_args.load: | ||
print("loading graphs...") | ||
base.unet.warmup_with_load("base_" + cmd_args.file, device) | ||
|
||
# Normal SDXL run | ||
# sizes = [1024, 896, 768] | ||
sizes = [1024] | ||
for h in sizes: | ||
for w in sizes: | ||
for i in range(3): | ||
image = base( | ||
prompt=cmd_args.prompt, | ||
height=h, | ||
width=w, | ||
generator=seed, | ||
num_inference_steps=cmd_args.n_steps, | ||
output_type=output_type, | ||
).images | ||
image[0].save(f"h{h}-w{w}-i{i}-{cmd_args.saved_image}") | ||
|
||
# Save compiled unet with oneflow | ||
if cmd_args.save: | ||
print("saving graphs...") | ||
base.unet.save_graph("base_" + cmd_args.file) | ||
|
||
if __name__ == '__main__': | ||
if cmd_args.save: | ||
run_sd(cmd_args, "cuda:0") | ||
|
||
if cmd_args.load: | ||
import torch.multiprocessing as mp | ||
# multi device/process run | ||
devices = ("cuda:0", "cuda:1") | ||
procs = [] | ||
for device in devices: | ||
p = mp.get_context("spawn").Process(target=run_sd, args=(cmd_args, device)) | ||
p.start() | ||
procs.append(p) | ||
|
||
for p in procs: | ||
p.join() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters