2828
"""
    create_matrix_list(N::Int, n1::Int, n2::Int, ::Type{T}, ::Type{M}) where {T, M}

Allocate one zero-initialised buffer holding `N` blocks of size `n1 × n2` and
return four handles onto that same memory:

- `M_vec`    -- the owning `M{T, 2}` buffer of size `(N * n1 * n2, 1)`;
- `M_tensor` -- a non-owning `M{T, 3}` wrapper of size `(n1, n2, N)`;
- `M_list`   -- a `Vector{M{T, 2}}` of `N` non-owning `(n1, n2)` wrappers, where
  entry `i` starts `n1 * n2 * (i - 1) * sizeof(T)` bytes into the buffer;
- `M_ptrs`   -- the result of `device_batch(M_list)`.

Every wrapper is created with `own = false`, so `M_vec` must be kept alive for
as long as any of the other returned values is in use.

`M` is the array type constructor (presumably `Array` or a GPU array type such
as `CuArray`/`ROCArray` -- confirm against the callers).
"""
function create_matrix_list(N::Int, n1::Int, n2::Int, ::Type{T}, ::Type{M}) where {T, M}

    # Allocate zeros of element type T directly, avoiding a Float64 temporary
    # that would then have to be converted.
    M_vec = M{T, 2}(zeros(T, N * n1 * n2, 1))
    M_tensor = unsafe_wrap(M{T, 3}, pointer(M_vec), (n1, n2, N); own = false)
    M_list = Vector{M{T, 2}}(undef, N)
    ptr = pointer(M_tensor)

    # Pointer arithmetic is done in bytes, hence the explicit sizeof(T).
    block_bytes = n1 * n2 * sizeof(T)
    for i in 1:N
        M_list[i] = unsafe_wrap(M{T, 2}, ptr + block_bytes * (i - 1), (n1, n2); own = false)
    end

    M_ptrs = device_batch(M_list)

    return M_vec, M_tensor, M_list, M_ptrs
end
@@ -66,7 +66,7 @@ function factorize!(data::BlockTriDiagData_seq)
6666 A_ptrs = data. A_ptrs
6767 B_ptrs = data. B_ptrs
6868
69- @allowscalar cholesky_factorize!(A_ptrs, B_ptrs, N, n) # TODO
69+ @allowscalar cholesky_factorize!(A_ptrs, B_ptrs, N, n) # TODO check if works for both CUDA and ROCm
7070
7171end
7272
@@ -79,6 +79,6 @@ function solve!(data::BlockTriDiagData_seq)
7979 B_ptrs = data. B_ptrs
8080 d_ptrs = data. d_ptrs
8181
82- @allowscalar cholesky_solve!(A_ptrs, B_ptrs, d_ptrs, N, n, 1 ) # TODO
82+ @allowscalar cholesky_solve!(A_ptrs, B_ptrs, d_ptrs, N, n, 1 ) # TODO check if works for both CUDA and ROCm
8383
8484end
0 commit comments