Skip to content

Commit

Permalink
Update CUDA compute capability to support Blackwell (#7047)
Browse files Browse the repository at this point in the history
Update CUDA compute capability for cross compile according to wiki page.
https://en.wikipedia.org/wiki/CUDA#GPUs_supported

---------

Signed-off-by: Hongwei <hongweichen@microsoft.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
  • Loading branch information
hwchen2017 authored and tohtana committed Feb 28, 2025
1 parent e946615 commit 38e9bf3
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions op_builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,20 @@ def installed_cuda_version(name=""):

def get_default_compute_capabilities():
compute_caps = DEFAULT_COMPUTE_CAPABILITIES
# Update compute capability according to: https://en.wikipedia.org/wiki/CUDA#GPUs_supported
import torch.utils.cpp_extension
if torch.utils.cpp_extension.CUDA_HOME is not None and installed_cuda_version()[0] >= 11:
if installed_cuda_version()[0] == 11 and installed_cuda_version()[1] == 0:
# Special treatment of CUDA 11.0 because compute_86 is not supported.
compute_caps += ";8.0"
else:
if torch.utils.cpp_extension.CUDA_HOME is not None:
if installed_cuda_version()[0] == 11:
if installed_cuda_version()[1] >= 0:
compute_caps += ";8.0"
if installed_cuda_version()[1] >= 1:
compute_caps += ";8.6"
if installed_cuda_version()[1] >= 8:
compute_caps += ";9.0"
elif installed_cuda_version()[0] == 12:
compute_caps += ";8.0;8.6;9.0"
if installed_cuda_version()[1] >= 8:
compute_caps += ";10.0;12.0"
return compute_caps


Expand Down

0 comments on commit 38e9bf3

Please sign in to comment.