Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion nvdiffrast/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import logging
import numpy as np
import os
import platform
import torch
import torch.utils.cpp_extension

Expand Down Expand Up @@ -120,9 +121,31 @@ def get_sort_key(x):
except:
pass

# ---- BEGIN MODIFICATION: Detect conda CUDA include path ----
extra_include_paths = []
conda_prefix = os.environ.get('CONDA_PREFIX')
if conda_prefix and platform.system() == 'Linux': # Only apply this logic on Linux for now
# Default target for nvidia conda channel packages
conda_target_arch = 'x86_64-linux'
# Construct the potential non-standard include path
potential_include_path = os.path.join(conda_prefix, 'targets', conda_target_arch, 'include')
if os.path.isdir(potential_include_path):
logging.getLogger('nvdiffrast').info(f"Detected conda environment and adding potential CUDA include path: {potential_include_path}")
extra_include_paths.append(potential_include_path)
# ---- END MODIFICATION ----

# Compile and load.
source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files]
torch.utils.cpp_extension.load(name=plugin_name, sources=source_paths, extra_cflags=common_opts+cc_opts, extra_cuda_cflags=common_opts+['-lineinfo'], extra_ldflags=ldflags, with_cuda=True, verbose=False)
torch.utils.cpp_extension.load(
name=plugin_name,
sources=source_paths,
extra_cflags=common_opts+cc_opts,
extra_cuda_cflags=common_opts+['-lineinfo'],
extra_ldflags=ldflags,
extra_include_paths=extra_include_paths,
with_cuda=True,
verbose=False
)

# Import, cache, and return the compiled module.
_cached_plugin[gl] = importlib.import_module(plugin_name)
Expand Down