multi device demo #258

Merged: 7 commits merged on Oct 9, 2023

@strint (Collaborator) commented Sep 22, 2023

Supports moving a runtime_state_dict to a different device with runtime_state_dict_to:

    def warmup_with_load(self, file_path, device=None):
        state_dict = flow.load(file_path)
        if device is not None:
            state_dict = flow.nn.Graph.runtime_state_dict_to(state_dict, device)
        self.load_runtime_state_dict(state_dict)

Depends on PR in oneflow: Oneflow-Inc/oneflow#10335
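
To make the idea of retargeting a saved state concrete, here is a minimal pure-Python sketch. It is not OneFlow's implementation of runtime_state_dict_to: `FakeTensor` and `state_dict_to` are hypothetical names, and the nested layout of the saved state is an assumption made only for illustration. It shows the general pattern of walking a nested container and returning a copy with every device-bound entry moved to the requested device.

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a device-bound tensor; not a OneFlow type."""
    name: str
    device: str


def state_dict_to(obj: Any, device: str) -> Any:
    """Return a copy of a nested state with every FakeTensor retargeted to `device`."""
    if isinstance(obj, FakeTensor):
        return replace(obj, device=device)
    if isinstance(obj, dict):
        return {key: state_dict_to(value, device) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(state_dict_to(item, device) for item in obj)
    return obj  # plain metadata (ints, strings, ...) is returned unchanged


# Assumed layout, for illustration only: state that was compiled and saved on cuda:0.
saved = {
    "inputs": [FakeTensor("latent", "cuda:0")],
    "states": {"weight": FakeTensor("unet.weight", "cuda:0")},
    "job_id": 0,
}
moved = state_dict_to(saved, "cuda:1")
```

Returning a new structure instead of mutating the loaded one means the same saved state can be retargeted to several devices, one copy per device.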

Performance check

save

  • speed: 5.99~6.08 it/s
  • mem:
    • before compile: 7.4G
    • after compile: 8.6G
    • run: 15G

load

  • speed: 5.94~6.11 it/s
  • mem:
    • before compile: 8.5G
    • after compile: 8.5G
    • run: 15G

load from cuda 0 to cuda 1

  • speed: 6.17~6.22 it/s
  • mem:
    • before compile: 7.3G
    • after compile: 8.5G
    • run: 15G

load from cuda 0 to cuda 0 and cuda 1

  • speed: 6.05~6.12 it/s
  • mem:
    • before compile: 7.3G
    • after compile: 8.5G
    • run: 15G
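
As a rough way to compare the four runs, the sketch below takes the midpoint of each reported speed range and expresses it relative to the save run. The numbers are copied from the lists above; treating the two ends of each range as a minimum and maximum, and using the midpoint as a summary, are assumptions of this sketch and not part of the original measurement.

```python
# Speed ranges in it/s, copied from the performance check above.
speed_ranges = {
    "save": (5.99, 6.08),
    "load": (5.94, 6.11),
    "load from cuda 0 to cuda 1": (6.17, 6.22),
    "load from cuda 0 to cuda 0 and cuda 1": (6.05, 6.12),
}


def midpoint(low: float, high: float) -> float:
    """Midpoint of a reported range; a crude summary, not a measured mean."""
    return (low + high) / 2


baseline = midpoint(*speed_ranges["save"])

# Percentage difference of each run's midpoint relative to the save run.
relative_change = {
    name: (midpoint(*bounds) - baseline) / baseline * 100
    for name, bounds in speed_ranges.items()
}

for name, delta in relative_change.items():
    print(f"{name}: {midpoint(*speed_ranges[name]):.3f} it/s ({delta:+.2f}% vs save)")
```

By this measure every configuration stays within 3% of the save run, which is consistent with loading onto another device not slowing inference down.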


    if __name__ == '__main__':
        if cmd_args.save:
            run_sd(cmd_args, "cuda:0")
@strint (Collaborator, Author) commented:

cuda 0 compile and save

    devices = ("cuda:0", "cuda:1")
    procs = []
    for device in devices:
        p = mp.get_context("spawn").Process(target=run_sd, args=(cmd_args, device))
@strint (Collaborator, Author) commented Oct 4, 2023:

cuda 0 and 1 load and run
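
The reviewed excerpt stops right after each process is created. A self-contained sketch of the full one-process-per-device pattern might look like the following. `run_on_device` is a hypothetical stand-in for the PR's run_sd, and the start/join steps and helper names are assumptions, since the excerpt does not show them. The "spawn" start method gives each child a fresh interpreter, which is presumably why it is used here: a forked child generally cannot reuse a CUDA context initialized in the parent.

```python
import multiprocessing as mp


def run_on_device(tag, device):
    """Hypothetical stand-in for the PR's run_sd(cmd_args, device)."""
    print(f"[{tag}] load and run on {device}")


def build_processes(target, tag, devices):
    """Create one un-started process per device using the 'spawn' start method."""
    ctx = mp.get_context("spawn")
    return [
        ctx.Process(target=target, args=(tag, device), name=f"worker-{device}")
        for device in devices
    ]


def run_all(procs):
    """Start every process, wait for all of them, and return their exit codes."""
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    return [p.exitcode for p in procs]


if __name__ == "__main__":
    workers = build_processes(run_on_device, "demo", ("cuda:0", "cuda:1"))
    print(run_all(workers))
```

Starting all processes before joining any of them lets the devices work concurrently; joining inside the creation loop would run them one after another.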

    state_dict = flow.load(file_path)
    if device is not None:
        state_dict = flow.nn.Graph.runtime_state_dict_to(state_dict, device)
@strint (Collaborator, Author) commented:

change runtime_state_dict device

@strint strint merged commit 7944a85 into main Oct 9, 2023
@strint strint deleted the new_device branch October 9, 2023 02:37