Skip to content

Commit 43e745e

Browse files
committed
unify device_batch for CUDA/ROCm
1 parent f643a0d commit 43e745e

File tree

3 files changed

+16
-6
lines changed

3 files changed

+16
-6
lines changed

src/backbone_cuda.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
using CUDA
22
using CUDA.CUSOLVER, CUDA.CUBLAS
33

4+
# Build the pointer array for a batch of CUDA arrays.
# `batch` is a host `Vector` of `CuArray`s with element type `T`; the result is a
# `CuArray` whose i-th entry is `pointer(batch[i])`.
# NOTE(review): presumably the CUDA counterpart of the ROCm `device_batch` method,
# feeding the batched cuBLAS/cuSOLVER wrappers — confirm against callers.
@inline function device_batch(batch::Vector{<:CuArray{T}}) where {T}
5+
ptrs = pointer.(batch)  # host-side vector with one pointer per array in the batch
6+
return CuArray(ptrs)  # copy the pointer vector into a CuArray
7+
end
8+
49
for (Xpotrf_buffer, Xpotrf, Xtrsm, Xgemm, T) in (
510
(:cusolverDnSpotrf_bufferSize, :cusolverDnSpotrf, :cublasStrsm_v2, :cublasSgemm_v2, :Float32),
611
(:cusolverDnDpotrf_bufferSize, :cusolverDnDpotrf, :cublasDtrsm_v2, :cublasDgemm_v2, :Float64),

src/backbone_rocm.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
using AMDGPU
22
using AMDGPU.rocSOLVER, AMDGPU.rocBLAS
33

4+
# Build the pointer array for a batch of ROCm arrays.
# `batch` is a host `Array` (any dimensionality; the CUDA method is restricted to
# `Vector`) whose elements are `ROCArray`s of type `T`. The result is a `ROCArray`
# holding one `Ptr{E}` per element of `batch`, where `E = eltype(T)`.
function device_batch(batch::Array{T}) where T <: ROCArray
5+
E = eltype(T)  # scalar element type of the ROCArrays in the batch
6+
# NOTE(review): reads the `buf` field of each ROCArray directly; this looks like an
# AMDGPU.jl internal — TODO confirm it is stable across the supported versions.
ROCArray([convert(Ptr{E}, arr.buf[]) for arr in batch])
7+
end
8+
49
for (Xpotrf, Xtrsm, Xgemm, T) in (
510
(:rocsolver_spotrf, :rocblas_strsm, :rocblas_sgemm, :Float32),
611
(:rocsolver_dpotrf, :rocblas_dtrsm, :rocblas_dgemm, :Float64),
@@ -151,7 +156,7 @@ for (XpotrfBatched, XtrsmBatched, XgemmBatched, T) in (
151156
function update_boundary!(M_ptrs_1::ROCVector{<:Ptr{$T}}, M_ptrs_2::ROCVector{<:Ptr{$T}}, d_ptrs::ROCVector{<:Ptr{$T}}, P, n, m)
152157

153158
dh = rocBLAS.handle()
154-
159+
155160
rocBLAS.$XgemmBatched(
156161
dh, rocBLAS.rocblas_operation_transpose, rocBLAS.rocblas_operation_none,
157162
n, 1, n, -one($T),

src/gpu_seq.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,16 @@ end
2828

2929
# Allocate storage for `N` matrices of size `n1 x n2` with element type `T` in one
# buffer of array type `M`, and return four handles onto that same memory:
#   M_vec    - the backing `M{T, 2}` array of size (N*n1*n2, 1), zero-initialised
#   M_tensor - non-owning (n1, n2, N) wrapper around the buffer
#   M_list   - `Vector` of `N` non-owning (n1, n2) wrappers, one per matrix
#   M_ptrs   - result of `device_batch(M_list)`
# NOTE(review): `M` is presumably `CuArray` or `ROCArray` — confirm against callers.
# The wrappers are created with `own=false`, so `M_vec` must outlive them.
function create_matrix_list(N::Int, n1::Int, n2::Int, ::Type{T}, ::Type{M}) where {T, M}
3030

31-
M_vec = M{T, 2}(zeros(N*n1*n2, 1)) #TODO
32-
M_tensor = unsafe_wrap(M{T, 3}, pointer(M_vec), (n1, n2, N); own=false) #TODO
31+
M_vec = M{T, 2}(zeros(N*n1*n2, 1))  # host zeros converted to the backing M{T, 2} buffer
32+
M_tensor = unsafe_wrap(M{T, 3}, pointer(M_vec), (n1, n2, N); own=false)  # same memory viewed as (n1, n2, N)
3333
M_list = Vector{M{T, 2}}(undef, N);
3434
ptr = pointer(M_tensor)
3535

3636
for i in 1:N
3737
M_list[i] = unsafe_wrap(M{T, 2}, ptr + n1*n2*(i-1)*sizeof(T), (n1, n2); own=false)  # i-th matrix starts n1*n2*(i-1) elements into the buffer
3838
end
3939

40-
M_ptrs = CUBLAS.unsafe_batch(M_list) #TODO
40+
M_ptrs = device_batch(M_list)  # replaces the CUBLAS-specific unsafe_batch call
4141

4242
return M_vec, M_tensor, M_list, M_ptrs
4343
end
@@ -66,7 +66,7 @@ function factorize!(data::BlockTriDiagData_seq)
6666
A_ptrs = data.A_ptrs
6767
B_ptrs = data.B_ptrs
6868

69-
@allowscalar cholesky_factorize!(A_ptrs, B_ptrs, N, n) #TODO
69+
@allowscalar cholesky_factorize!(A_ptrs, B_ptrs, N, n) #TODO check if works for both CUDA and ROCm
7070

7171
end
7272

@@ -79,6 +79,6 @@ function solve!(data::BlockTriDiagData_seq)
7979
B_ptrs = data.B_ptrs
8080
d_ptrs = data.d_ptrs
8181

82-
@allowscalar cholesky_solve!(A_ptrs, B_ptrs, d_ptrs, N, n, 1) #TODO
82+
@allowscalar cholesky_solve!(A_ptrs, B_ptrs, d_ptrs, N, n, 1) #TODO check if works for both CUDA and ROCm
8383

8484
end

0 commit comments

Comments
 (0)