Skip to content

Commit 56de8a3

Browse files
atalman and malfet authored
Add manual cuda deps search logic (#90411) (#90426)
If PyTorch is packaged into a wheel with [nvidia-cublas-cu11](https://pypi.org/project/nvidia-cublas-cu11/), which is designated as PureLib, but the `torch` wheel is not, this can cause a torch_globals loading problem. Fix that by searching for `nvidia/cublas/lib/libcublas.so.11` and `nvidia/cudnn/lib/libcudnn.so.8` across all `sys.path` folders. Test plan: ``` docker pull amazonlinux:2 docker run --rm -t amazonlinux:2 bash -c 'yum install -y python3 python3-devel python3-distutils patch;python3 -m pip install torch==1.13.0;curl -OL https://patch-diff.githubusercontent.com/raw/pytorch/pytorch/pull/90411.diff; pushd /usr/local/lib64/python3.7/site-packages; patch -p1 </90411.diff; popd; python3 -c "import torch;print(torch.__version__, torch.cuda.is_available())"' ``` Fixes #88869 Pull Request resolved: #90411 Approved by: https://github.com/atalman Co-authored-by: Nikita Shulga <nshulga@meta.com>
1 parent a4d16e0 commit 56de8a3

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

torch/__init__.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,24 @@
141141
kernel32.SetErrorMode(prev_error_mode)
142142

143143

144+
def _preload_cuda_deps():
    """Preload the cuBLAS and cuDNN shared libraries found under ``sys.path``.

    Each ``sys.path`` entry is checked for
    ``nvidia/cublas/lib/libcublas.so.11`` and
    ``nvidia/cudnn/lib/libcudnn.so.8``. The first entry that contains BOTH
    files wins, and both libraries are loaded with ``ctypes.CDLL``.

    Raises:
        AssertionError: if called on a platform other than Linux.
        OSError: if no ``sys.path`` entry contains both libraries, or if
            ``ctypes.CDLL`` fails to load one of them.
    """
    # Should only be called on Linux if default path resolution has failed
    assert platform.system() == 'Linux', 'Should only be called on Linux'
    for path in sys.path:
        nvidia_path = os.path.join(path, 'nvidia')
        cublas_path = os.path.join(nvidia_path, 'cublas', 'lib', 'libcublas.so.11')
        cudnn_path = os.path.join(nvidia_path, 'cudnn', 'lib', 'libcudnn.so.8')
        # Both libraries must live under the same sys.path entry.
        if os.path.exists(cublas_path) and os.path.exists(cudnn_path):
            break
    else:
        # The loop ran out of candidates. Fail with an explicit message
        # instead of falling through with unbound or stale path variables.
        raise OSError(
            "libcublas.so.11 and libcudnn.so.8 were not found together under "
            "an 'nvidia' directory in any sys.path entry: {}".format(sys.path)
        )

    ctypes.CDLL(cublas_path)
    ctypes.CDLL(cudnn_path)
160+
161+
144162
# See Note [Global dependencies]
145163
def _load_global_deps():
146164
if platform.system() == 'Windows' or sys.executable == 'torch_deploy':
@@ -150,7 +168,15 @@ def _load_global_deps():
150168
here = os.path.abspath(__file__)
151169
lib_path = os.path.join(os.path.dirname(here), 'lib', lib_name)
152170

153-
ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
171+
try:
172+
ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
173+
except OSError as err:
174+
# Can only happen for wheels with cublas as PyPI deps
175+
# As PyTorch is not purelib, but nvidia-cublas-cu11 is
176+
if 'libcublas.so.11' not in err.args[0]:
177+
raise err
178+
_preload_cuda_deps()
179+
ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
154180

155181

156182
if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv('TORCH_USE_RTLD_GLOBAL')) and \

0 commit comments

Comments
 (0)