Skip to content

Commit

Permalink
make detailed comment in cuda_shift
Browse files Browse the repository at this point in the history
  • Loading branch information
dreamerlin committed Aug 3, 2020
1 parent 1104da6 commit a5bbb26
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions mmaction/models/backbones/cuda_shift/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@


def _find_cuda_home():
# guess rule 3 of torch.utils.cpp_extension
# set directory path to CUDA which supports `nvcc` command as CUDA_HOME
# other than 3 Guess in PyTorch
nvcc = subprocess.check_output(['which', 'nvcc']).decode().rstrip('\r\n')
cuda_home = os.path.dirname(os.path.dirname(nvcc))
print(f'find cuda home:{cuda_home}')
return cuda_home


# remember to overwrite PyTorch auto-detected cuda home which
# may not be our expected
# overwrite PyTorch auto-detected CUDA_HOME which may not be our expected,
# because PyTorch determines the CUDA_HOME using this priority:
# os.environ['CUDA_HOME'] > path to CUDA supporting `nvcc` > '/usr/local/cuda',
# although The first two guess should be the same in most cases.
torch.utils.cpp_extension.CUDA_HOME = _find_cuda_home()
CUDA_HOME = torch.utils.cpp_extension.CUDA_HOME

Expand All @@ -34,11 +37,10 @@ def _find_cuda_home():
else:
raise ValueError('CUDA is not available')

extra_compile_args = dict(cxx=[]) # ['-fopenmp', '-std=c99']
extra_compile_args = dict(cxx=[])
extra_compile_args['nvcc'] = []

this_file = os.path.dirname(os.path.realpath(__file__))
print(this_file)
sources = [os.path.join(this_file, fname) for fname in sources]
headers = [os.path.join(this_file, fname) for fname in headers]
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]
Expand Down

0 comments on commit a5bbb26

Please sign in to comment.