multi device demo (#258)
Support changing the device of a graph's `runtime_state_dict` with `runtime_state_dict_to`:
```python
    def warmup_with_load(self, file_path, device=None):
        state_dict = flow.load(file_path)
        if device is not None:
            state_dict = flow.nn.Graph.runtime_state_dict_to(state_dict, device)
        self.load_runtime_state_dict(state_dict)
```
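
For example, a graph compiled and saved on `cuda:0` can be warmed up on `cuda:1`. A minimal sketch based on the example script added in this PR (the model ID and the graph file name follow that script's defaults):

```python
import torch

# oneflow_compile should be imported before importing any diffusers
from onediff.infer_compiler import oneflow_compile
from diffusers import DiffusionPipeline

base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
base.to("cuda:1")

# Wrap the unet with the oneflow compiler, then warm it up from a graph
# that was saved on cuda:0, remapping its runtime state dict to cuda:1.
base.unet = oneflow_compile(base.unet)
base.unet.warmup_with_load("base_unet_compiled", device="cuda:1")
```

This is the path exercised by `examples/text_to_image_sdxl_mp_load.py --load` below, which spawns one such process per device.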


Depends on this PR in oneflow:
Oneflow-Inc/oneflow#10335

# Performance check
## save
- speed: 5.99~6.08 it/s
- mem:
  - before compile: 7.4G
  - after compile: 8.6G
  - run: 15G

## load
- speed: 5.94~6.11 it/s
- mem:
  - before compile: 8.5G
  - after compile: 8.5G
  - run: 15G

## load from cuda 0 to cuda 1
- speed: 6.17~6.22 it/s
- mem:
  - before compile: 7.3G
  - after compile: 8.5G
  - run: 15G

## load from cuda 0 to cuda 0 and cuda 1
- speed: 6.05~6.12 it/s
- mem:
  - before compile: 7.3G
  - after compile: 8.5G
  - run: 15G

---------

Co-authored-by: Shenghang Tsai <jackalcooper@gmail.com>
strint and jackalcooper authored Oct 9, 2023
1 parent db15295 commit 7944a85
Showing 3 changed files with 102 additions and 26 deletions.
examples/text_to_image_sdxl.py (3 changes: 0 additions & 3 deletions)
@@ -10,9 +10,6 @@
import cv2
import oneflow as flow
import torch
-import logging
-
-logger = logging.getLogger(__name__)

# oneflow_compile should be imported before importing any diffusers
from onediff.infer_compiler import oneflow_compile
examples/text_to_image_sdxl_mp_load.py (94 changes: 94 additions & 0 deletions)
@@ -0,0 +1,94 @@
# Compile and save the oneflow graph, example: python examples/text_to_image_sdxl_mp_load.py --save
# Load the compiled graph onto new devices, example: python examples/text_to_image_sdxl_mp_load.py --load

import os
import argparse

# cv2 must be imported before diffusers and oneflow to avoid the error: AttributeError: module 'cv2.gapi' has no attribute 'wip'
# Maybe because oneflow uses a lower version of cv2
import cv2
import oneflow as flow
import torch

parser = argparse.ArgumentParser()
parser.add_argument(
    "--base", type=str, default="stabilityai/stable-diffusion-xl-base-1.0"
)
parser.add_argument("--variant", type=str, default="fp16")
parser.add_argument(
    "--prompt",
    type=str,
    default="street style, detailed, raw photo, woman, face, shot on CineStill 800T",
)
parser.add_argument("--n_steps", type=int, default=30)
parser.add_argument("--saved_image", type=str, required=False, default="sdxl-out.png")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--save", action=argparse.BooleanOptionalAction)
parser.add_argument("--load", action=argparse.BooleanOptionalAction)
parser.add_argument("--file", type=str, required=False, default="unet_compiled")
cmd_args = parser.parse_args()


def run_sd(cmd_args, device):
    # oneflow_compile should be imported before importing any diffusers
    from onediff.infer_compiler import oneflow_compile
    from diffusers import DiffusionPipeline

    # Normal SDXL pipeline init.
    seed = torch.Generator(device).manual_seed(cmd_args.seed)
    output_type = "pil"
    # SDXL base: StableDiffusionXLPipeline
    base = DiffusionPipeline.from_pretrained(
        cmd_args.base,
        torch_dtype=torch.float16,
        variant=cmd_args.variant,
        use_safetensors=True,
    )
    base.to(device)

    # Compile unet with oneflow
    print("unet is compiled to oneflow.")
    base.unet = oneflow_compile(base.unet)

    # Load compiled unet with oneflow
    if cmd_args.load:
        print("loading graphs...")
        base.unet.warmup_with_load("base_" + cmd_args.file, device)

    # Normal SDXL run
    # sizes = [1024, 896, 768]
    sizes = [1024]
    for h in sizes:
        for w in sizes:
            for i in range(3):
                image = base(
                    prompt=cmd_args.prompt,
                    height=h,
                    width=w,
                    generator=seed,
                    num_inference_steps=cmd_args.n_steps,
                    output_type=output_type,
                ).images
                image[0].save(f"h{h}-w{w}-i{i}-{cmd_args.saved_image}")

    # Save compiled unet with oneflow
    if cmd_args.save:
        print("saving graphs...")
        base.unet.save_graph("base_" + cmd_args.file)

if __name__ == '__main__':
    if cmd_args.save:
        run_sd(cmd_args, "cuda:0")

    if cmd_args.load:
        import torch.multiprocessing as mp
        # multi device/process run
        devices = ("cuda:0", "cuda:1")
        procs = []
        for device in devices:
            p = mp.get_context("spawn").Process(target=run_sd, args=(cmd_args, device))
            p.start()
            procs.append(p)

        for p in procs:
            p.join()
src/onediff/infer_compiler/with_oneflow_compile.py (31 changes: 8 additions & 23 deletions)
@@ -32,28 +32,13 @@ def __init__(self, unet):
        # os.environ["ONEFLOW_MLIR_FUSE_KERNEL_LAUNCH"] = "1"
        # os.environ["ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH"] = "1"

-    def build(
-        self,
-        latent_model_input,
-        t,
-        encoder_hidden_states,
-        cross_attention_kwargs=None,
-        added_cond_kwargs=None,
-        return_dict=False,
-    ):
-        encoder_hidden_states = flow._C.amp_white_identity(encoder_hidden_states)
-        pred = self.unet(
-            latent_model_input,
-            t,
-            encoder_hidden_states=encoder_hidden_states,
-            cross_attention_kwargs=cross_attention_kwargs,
-            added_cond_kwargs=added_cond_kwargs,
-            return_dict=return_dict,
-        )
-        return pred
-
-    def warmup_with_load(self, file_path):
+    def build(self, *args, **kwargs):
+        return self.unet(*args, **kwargs)
+
+    def warmup_with_load(self, file_path, device=None):
        state_dict = flow.load(file_path)
+        if device is not None:
+            state_dict = flow.nn.Graph.runtime_state_dict_to(state_dict, device)
        self.load_runtime_state_dict(state_dict)

    def save_graph(self, file_path):
@@ -103,8 +88,8 @@ def __call__(self, *args, **kwargs):
        out = out_tree.map_leaf(output_fn)
        return out[0]

-    def warmup_with_load(self, file_path):
-        self._dpl_graph.warmup_with_load(file_path)
+    def warmup_with_load(self, file_path, device=None):
+        self._dpl_graph.warmup_with_load(file_path, device)

    def save_graph(self, file_path):
        self._dpl_graph.save_graph(file_path)