diff --git a/CMakeLists.txt b/CMakeLists.txt
index bb427a565053..2926baac9cdd 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -285,6 +285,9 @@ if(USE_ROCM)
endif()
message(STATUS "CMAKE_HIP_FLAGS: ${CMAKE_HIP_FLAGS}")
+ # Building for ROCm reuses the CUDA code path, so USE_CUDA is defined as well.
+ # ROCm-specific exceptions are guarded by USE_ROCM.
+ add_definitions(-DUSE_CUDA)
add_definitions(-DUSE_ROCM)
endif()
@@ -471,10 +474,21 @@ set(
src/cuda/cuda_algorithms.cu
)
-if(USE_CUDA)
+if(USE_CUDA OR USE_ROCM)
list(APPEND LGBM_SOURCES ${LGBM_CUDA_SOURCES})
endif()
+if(USE_ROCM)
+ set(CU_FILES "")
+ foreach(file IN LISTS LGBM_CUDA_SOURCES)
+ string(REGEX MATCH "\\.cu$" is_cu_file "${file}")
+ if(is_cu_file)
+ list(APPEND CU_FILES "${file}")
+ endif()
+ endforeach()
+ set_source_files_properties(${CU_FILES} PROPERTIES LANGUAGE HIP)
+endif()
+
add_library(lightgbm_objs OBJECT ${LGBM_SOURCES})
if(BUILD_CLI)
@@ -623,6 +637,10 @@ if(USE_CUDA)
endif()
endif()
+if(USE_ROCM)
+ target_link_libraries(lightgbm_objs PUBLIC hip::host)
+endif()
+
if(WIN32)
if(MINGW OR CYGWIN)
target_link_libraries(lightgbm_objs PUBLIC ws2_32 iphlpapi)
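
The new ``add_definitions(-DUSE_CUDA)`` means a ROCm build compiles the existing CUDA code path, with ROCm-only divergences opted in via ``USE_ROCM``. A minimal sketch of the resulting guard convention (not part of the patch; the file contents and function name are illustrative only):

```cpp
// Hypothetical header following the guard convention introduced above:
// a ROCm build defines both USE_CUDA and USE_ROCM, so the bulk of the GPU
// code stays behind USE_CUDA and only genuine divergences check USE_ROCM.
#ifdef USE_CUDA

#ifdef USE_ROCM
#include <hip/hip_runtime.h>   // AMD: HIP runtime headers
#else
#include <cuda_runtime.h>      // NVIDIA: unchanged include path
#endif  // USE_ROCM

void SharedGpuEntryPoint();    // compiled for both vendors

#endif  // USE_CUDA
```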
diff --git a/build-python.sh b/build-python.sh
index a7fce2b9ce3f..9b93110db0dd 100755
--- a/build-python.sh
+++ b/build-python.sh
@@ -54,6 +54,8 @@
# --precompile
# Use precompiled library.
# Only used with 'install' command.
+# --rocm
+# Compile ROCm version.
# --time-costs
# Compile version that outputs time costs for different internal routines.
# --user
@@ -142,6 +144,9 @@ while [ $# -gt 0 ]; do
--cuda)
BUILD_ARGS="${BUILD_ARGS} --config-setting=cmake.define.USE_CUDA=ON"
;;
+ --rocm)
+ BUILD_ARGS="${BUILD_ARGS} --config-setting=cmake.define.USE_ROCM=ON"
+ ;;
--gpu)
BUILD_ARGS="${BUILD_ARGS} --config-setting=cmake.define.USE_GPU=ON"
;;
diff --git a/docs/Installation-Guide.rst b/docs/Installation-Guide.rst
index 072789d91356..0b726bc97417 100644
--- a/docs/Installation-Guide.rst
+++ b/docs/Installation-Guide.rst
@@ -749,6 +749,65 @@ macOS
The CUDA version is not supported on macOS.
+Build ROCm Version
+~~~~~~~~~~~~~~~~~~
+
+The `original GPU version <#build-gpu-version>`__ of LightGBM (``device_type=gpu``) is based on OpenCL.
+
+The ROCm-based version (``device_type=cuda``) is a separate implementation; it reuses ``device_type=cuda`` as a convenience for users, so no parameter changes are needed when switching between NVIDIA and AMD builds. Use this version in Linux environments with an AMD GPU.
+
+Windows
+^^^^^^^
+
+The ROCm version is not supported on Windows.
+Use the `GPU version <#build-gpu-version>`__ (``device_type=gpu``) for GPU acceleration on Windows.
+
+Linux
+^^^^^
+
+On Linux, a ROCm version of LightGBM can be built using
+
+- **CMake**, **gcc** and **ROCm**;
+- **CMake**, **Clang** and **ROCm**.
+
+Please refer to `the ROCm docs`_ for instructions on installing the **ROCm** libraries.
+
+After compilation the executable and ``.so`` files will be in the ``LightGBM/`` folder.
+
+gcc
+***
+
+1. Install `CMake`_, **gcc** and **ROCm**.
+
+2. Run the following commands:
+
+ .. code:: sh
+
+ git clone --recursive https://github.com/microsoft/LightGBM
+ cd LightGBM
+ cmake -B build -S . -DUSE_ROCM=ON
+ cmake --build build -j4
+
+Clang
+*****
+
+1. Install `CMake`_, **Clang**, **OpenMP** and **ROCm**.
+
+2. Run the following commands:
+
+ .. code:: sh
+
+ git clone --recursive https://github.com/microsoft/LightGBM
+ cd LightGBM
+ export CXX=clang++-14 CC=clang-14 # replace "14" with version of Clang installed on your machine
+ cmake -B build -S . -DUSE_ROCM=ON
+ cmake --build build -j4
+
+macOS
+^^^^^
+
+The ROCm version is not supported on macOS.
+
Build Java Wrapper
~~~~~~~~~~~~~~~~~~
@@ -1054,6 +1113,8 @@ gcc
.. _this detailed guide: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html
+.. _the ROCm docs: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/
+
.. _following docs: https://github.com/google/sanitizers/wiki
.. _Ninja: https://ninja-build.org
diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index 3a4d880ef2e0..00b2dade4383 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -264,7 +264,7 @@ Core Parameters
- ``cpu`` supports all LightGBM functionality and is portable across the widest range of operating systems and hardware
- - ``cuda`` offers faster training than ``gpu`` or ``cpu``, but only works on GPUs supporting CUDA
+ - ``cuda`` offers faster training than ``gpu`` or ``cpu``, but only works on GPUs supporting CUDA or ROCm
- ``gpu`` can be faster than ``cpu`` and works on a wider range of GPUs than CUDA
@@ -272,7 +272,7 @@ Core Parameters
- **Note**: for the faster speed, GPU uses 32-bit float point to sum up by default, so this may affect the accuracy for some tasks. You can set ``gpu_use_dp=true`` to enable 64-bit float point, but it will slow down the training
- - **Note**: refer to `Installation Guide <./Installation-Guide.rst>`__ to build LightGBM with GPU or CUDA support
+ - **Note**: refer to `Installation Guide <./Installation-Guide.rst>`__ to build LightGBM with GPU, CUDA, or ROCm support
- ``seed`` :raw-html:`🔗︎`, default = ``None``, type = int, aliases: ``random_seed``, ``random_state``
diff --git a/docs/_static/js/script.js b/docs/_static/js/script.js
index bcc11349b61a..b62a3c539f68 100644
--- a/docs/_static/js/script.js
+++ b/docs/_static/js/script.js
@@ -22,6 +22,7 @@ $(() => {
"#build-mpi-version",
"#build-gpu-version",
"#build-cuda-version",
+ "#build-rocm-version",
"#build-java-wrapper",
"#build-python-package",
"#build-r-package",
diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index a8070a2dd8da..500dd8a1c7fe 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -246,11 +246,11 @@ struct Config {
// alias = device
// desc = device for the tree learning
// desc = ``cpu`` supports all LightGBM functionality and is portable across the widest range of operating systems and hardware
- // desc = ``cuda`` offers faster training than ``gpu`` or ``cpu``, but only works on GPUs supporting CUDA
+ // desc = ``cuda`` offers faster training than ``gpu`` or ``cpu``, but only works on GPUs supporting CUDA or ROCm
// desc = ``gpu`` can be faster than ``cpu`` and works on a wider range of GPUs than CUDA
// desc = **Note**: it is recommended to use the smaller ``max_bin`` (e.g. 63) to get the better speed up
// desc = **Note**: for the faster speed, GPU uses 32-bit float point to sum up by default, so this may affect the accuracy for some tasks. You can set ``gpu_use_dp=true`` to enable 64-bit float point, but it will slow down the training
- // desc = **Note**: refer to `Installation Guide <./Installation-Guide.rst>`__ to build LightGBM with GPU or CUDA support
+ // desc = **Note**: refer to `Installation Guide <./Installation-Guide.rst>`__ to build LightGBM with GPU, CUDA, or ROCm support
std::string device_type = "cpu";
// [no-automatically-extract]
diff --git a/include/LightGBM/cuda/cuda_algorithms.hpp b/include/LightGBM/cuda/cuda_algorithms.hpp
index 9a5d208b98dc..45a1d5073577 100644
--- a/include/LightGBM/cuda/cuda_algorithms.hpp
+++ b/include/LightGBM/cuda/cuda_algorithms.hpp
@@ -9,8 +9,10 @@
#ifdef USE_CUDA
+#ifndef USE_ROCM
#include
#include
+#endif
#include
#include
diff --git a/include/LightGBM/cuda/cuda_random.hpp b/include/LightGBM/cuda/cuda_random.hpp
index efe0d4fc83dd..f1bcef080c30 100644
--- a/include/LightGBM/cuda/cuda_random.hpp
+++ b/include/LightGBM/cuda/cuda_random.hpp
@@ -7,8 +7,10 @@
#ifdef USE_CUDA
+#ifndef USE_ROCM
#include
#include
+#endif
namespace LightGBM {
diff --git a/include/LightGBM/cuda/cuda_rocm_interop.h b/include/LightGBM/cuda/cuda_rocm_interop.h
index 9629ac57347e..1f789ccc6ae8 100644
--- a/include/LightGBM/cuda/cuda_rocm_interop.h
+++ b/include/LightGBM/cuda/cuda_rocm_interop.h
@@ -7,16 +7,59 @@
#ifdef USE_CUDA
-#if defined(__HIP_PLATFORM_AMD__) || defined(__HIP__)
-// ROCm doesn't have __shfl_down_sync, only __shfl_down without mask.
+#if defined(__HIP_PLATFORM_AMD__)
+
+// ROCm doesn't have atomicAdd_block, but it should be semantically the same as atomicAdd
+#define atomicAdd_block atomicAdd
+
+// hipify
+#include
+#define cudaDeviceProp hipDeviceProp_t
+#define cudaDeviceSynchronize hipDeviceSynchronize
+#define cudaError_t hipError_t
+#define cudaFree hipFree
+#define cudaFreeHost hipFreeHost
+#define cudaGetDevice hipGetDevice
+#define cudaGetDeviceProperties hipGetDeviceProperties
+#define cudaGetErrorName hipGetErrorName
+#define cudaGetErrorString hipGetErrorString
+#define cudaGetLastError hipGetLastError
+#define cudaHostAlloc hipHostAlloc
+#define cudaHostAllocPortable hipHostAllocPortable
+#define cudaMalloc hipMalloc
+#define cudaMemcpy hipMemcpy
+#define cudaMemcpyAsync hipMemcpyAsync
+#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
+#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
+#define cudaMemoryTypeHost hipMemoryTypeHost
+#define cudaMemset hipMemset
+#define cudaPointerAttributes hipPointerAttribute_t
+#define cudaPointerGetAttributes hipPointerGetAttributes
+#define cudaSetDevice hipSetDevice
+#define cudaStreamCreate hipStreamCreate
+#define cudaStreamDestroy hipStreamDestroy
+#define cudaStream_t hipStream_t
+#define cudaSuccess hipSuccess
+
+// ROCm 7.0 did add __shfl_down_sync et al, but the following hack still works.
// Since mask is full 0xffffffff, we can use __shfl_down instead.
#define __shfl_down_sync(mask, val, offset) __shfl_down(val, offset)
#define __shfl_up_sync(mask, val, offset) __shfl_up(val, offset)
-// ROCm warpSize is constexpr and is either 32 or 64 depending on gfx arch.
-#define WARPSIZE warpSize
-// ROCm doesn't have atomicAdd_block, but it should be semantically the same as atomicAdd
-#define atomicAdd_block atomicAdd
-#else
+
+// warpSize may only be used in device code.
+// The HIP header used to define warpSize as a constexpr that was either 32 or 64
+// depending on the target device, and always set it to 64 for host code.
+static inline constexpr int WARP_SIZE_INTERNAL() {
+#if defined(__GFX9__)
+ return 64;
+#else // __GFX9__
+ return 32;
+#endif // __GFX9__
+}
+#define WARPSIZE (WARP_SIZE_INTERNAL())
+
+#else // __HIP_PLATFORM_AMD__
// CUDA warpSize is not a constexpr, but always 32
#define WARPSIZE 32
#endif // defined(__HIP_PLATFORM_AMD__) || defined(__HIP__)
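
A minimal sketch (not part of the patch; the function names are made up) of how the interop header is meant to be consumed: CUDA-style code compiles unchanged for AMD once the macros above map the runtime API to HIP, drop the shuffle mask, and pin ``WARPSIZE`` to the target's wavefront width. The sketch assumes compilation as a CUDA/HIP source with ``USE_CUDA`` defined, which the CMake change above guarantees for ROCm builds.

```cpp
#include <cstddef>
#include <LightGBM/cuda/cuda_rocm_interop.h>

// Hypothetical warp-level sum. On NVIDIA hardware __shfl_down_sync is the
// native intrinsic; under __HIP_PLATFORM_AMD__ the macro above rewrites it
// to __shfl_down, and WARPSIZE expands to 64 on gfx9 or 32 on newer targets.
__device__ inline double WarpSum(double val) {
  for (int offset = WARPSIZE / 2; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return val;  // lane 0 ends up holding the warp-wide total
}

// Hypothetical host-side helper written against the CUDA runtime API; every
// cuda* name below is one of the aliases defined above, so the same code
// calls hipMalloc/hipMemcpy/hipFree when compiled for ROCm.
bool CopyToDevice(const double* host_data, double** device_data, std::size_t n) {
  const std::size_t bytes = n * sizeof(double);
  if (cudaMalloc(reinterpret_cast<void**>(device_data), bytes) != cudaSuccess) {
    return false;
  }
  if (cudaMemcpy(*device_data, host_data, bytes, cudaMemcpyHostToDevice) != cudaSuccess) {
    cudaFree(*device_data);
    return false;
  }
  return true;
}
```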
diff --git a/include/LightGBM/cuda/cuda_utils.hu b/include/LightGBM/cuda/cuda_utils.hu
index 4bd84aeb264d..c6a61733ca84 100644
--- a/include/LightGBM/cuda/cuda_utils.hu
+++ b/include/LightGBM/cuda/cuda_utils.hu
@@ -8,8 +8,12 @@
#ifdef USE_CUDA
+#if defined(USE_ROCM)
+#include
+#else
#include
#include
+#endif
#include
#include
diff --git a/include/LightGBM/cuda/vector_cudahost.h b/include/LightGBM/cuda/vector_cudahost.h
index 9a0d69225b7c..3f8ea4d00621 100644
--- a/include/LightGBM/cuda/vector_cudahost.h
+++ b/include/LightGBM/cuda/vector_cudahost.h
@@ -9,9 +9,12 @@
#include
#ifdef USE_CUDA
+#ifndef USE_ROCM
#include
#include
-#endif
+#endif // USE_ROCM
+#include
+#endif // USE_CUDA
#include
enum LGBM_Device {
@@ -66,14 +69,14 @@ struct CHAllocator {
#ifdef USE_CUDA
if (LGBM_config_::current_device == lgbm_device_cuda) {
cudaPointerAttributes attributes;
- cudaPointerGetAttributes(&attributes, p);
- #if CUDA_VERSION >= 10000
+ CUDASUCCESS_OR_FATAL(cudaPointerGetAttributes(&attributes, p));
+ #if CUDA_VERSION >= 10000 || defined(USE_ROCM)
if ((attributes.type == cudaMemoryTypeHost) && (attributes.devicePointer != NULL)) {
- cudaFreeHost(p);
+ CUDASUCCESS_OR_FATAL(cudaFreeHost(p));
}
#else
if ((attributes.memoryType == cudaMemoryTypeHost) && (attributes.devicePointer != NULL)) {
- cudaFreeHost(p);
+ CUDASUCCESS_OR_FATAL(cudaFreeHost(p));
}
#endif
} else {
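
The hunk above routes the runtime calls through ``CUDASUCCESS_OR_FATAL`` so failures are reported instead of silently ignored. For readers unfamiliar with that macro, a hypothetical stand-in in the same spirit could look like the sketch below (this is not LightGBM's actual definition, which lives in its CUDA utilities header); it assumes the CUDA runtime header, or the ROCm interop header, is already included so that ``cudaError_t``, ``cudaSuccess`` and ``cudaGetErrorString`` resolve on both platforms.

```cpp
#include <cstdio>
#include <cstdlib>

// Hypothetical error-check macro: evaluate the runtime call once and abort
// with the API's own error string if it did not return cudaSuccess. Under
// USE_ROCM the cuda* names resolve to their hip* equivalents through
// cuda_rocm_interop.h.
#define CHECK_GPU_CALL(call)                                                 \
  do {                                                                       \
    const cudaError_t status = (call);                                       \
    if (status != cudaSuccess) {                                             \
      std::fprintf(stderr, "[%s:%d] %s failed: %s\n", __FILE__, __LINE__,    \
                   #call, cudaGetErrorString(status));                       \
      std::abort();                                                          \
    }                                                                        \
  } while (0)

// Usage (hypothetical): CHECK_GPU_CALL(cudaFreeHost(p));
```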
diff --git a/src/cuda/cuda_utils.cpp b/src/cuda/cuda_utils.cpp
index b601f9395268..968b4c1e8799 100644
--- a/src/cuda/cuda_utils.cpp
+++ b/src/cuda/cuda_utils.cpp
@@ -5,6 +5,7 @@
#ifdef USE_CUDA
+#include
#include
namespace LightGBM {
diff --git a/src/treelearner/cuda/cuda_best_split_finder.cu b/src/treelearner/cuda/cuda_best_split_finder.cu
index 6b1d16748868..26c740c38bd6 100644
--- a/src/treelearner/cuda/cuda_best_split_finder.cu
+++ b/src/treelearner/cuda/cuda_best_split_finder.cu
@@ -934,7 +934,11 @@ __global__ void FindBestSplitsDiscretizedForLeafKernel(
if (is_feature_used_bytree[inner_feature_index]) {
if (task->is_categorical) {
__threadfence(); // ensure store issued before trap
+#if defined(USE_ROCM)
+ __builtin_trap();
+#else
asm("trap;");
+#endif
} else {
if (!task->reverse) {
if (use_16bit_bin) {
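
``asm("trap;")`` is inline PTX and only assembles for NVIDIA targets; on AMD GPUs, ``__builtin_trap()`` is the equivalent way to abort the kernel. If the same pattern were needed in more kernels, it could be factored into a small device helper; a hypothetical sketch (not part of the patch, which simply inlines the ``#if`` at the call site):

```cpp
// Hypothetical shared helper for aborting from device code on both vendors.
__device__ __forceinline__ void DeviceAbort() {
  __threadfence();     // make preceding stores visible before trapping
#if defined(USE_ROCM)
  __builtin_trap();    // AMD: LLVM trap intrinsic understood by the HIP compiler
#else
  asm("trap;");        // NVIDIA: PTX trap instruction
#endif
}
```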
diff --git a/src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp b/src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp
index 51970bf1b210..e53871fc060c 100644
--- a/src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp
+++ b/src/treelearner/cuda/cuda_single_gpu_tree_learner.hpp
@@ -155,7 +155,7 @@ class CUDASingleGPUTreeLearner: public SerialTreeLearner {
#pragma warning(disable : 4702)
explicit CUDASingleGPUTreeLearner(const Config* tree_config, const bool /*boosting_on_cuda*/) : SerialTreeLearner(tree_config) {
Log::Fatal("CUDA Tree Learner was not enabled in this build.\n"
- "Please recompile with CMake option -DUSE_CUDA=1");
+ "Please recompile with CMake option -DUSE_CUDA=1 (NVIDIA GPUs) or -DUSE_ROCM=1 (AMD GPUs)");
}
};