Inplace version of batched adjoint/transpose #502

Open
chengchingwen opened this issue Jun 24, 2023 · 7 comments
Comments

@chengchingwen
Member

We are missing in-place versions of batched adjoint/transpose. They are needed to avoid GPU scalar indexing when Base.copy is called on a wrapper, e.g. copy(batched_adjoint(CUDA.randn(3,5,2))). They could be implemented as:

# In-place
function batched_transpose_f!(f, B::AbstractArray{T, 3}, A::AbstractArray{T, 3}) where T
    axes(B,1) == axes(A,2) && axes(B,2) == axes(A,1) && axes(A,3) == axes(B,3) || throw(DimensionMismatch(string(f)))
    Threads.@threads for i in axes(A,3)
        Bi = @view B[:, :, i]
        Ai = @view A[:, :, i]
        LinearAlgebra.transpose_f!(f, Bi, Ai)
    end
    return B
end

using GPUArrays
function batched_transpose_f!(f, B::AnyGPUArray{T, 3}, A::AnyGPUArray{T, 3}) where T
    axes(B,1) == axes(A,2) && axes(B,2) == axes(A,1) && axes(A,3) == axes(B,3) || throw(DimensionMismatch(string(f)))
    GPUArrays.gpu_call(B, A) do ctx, B, A
        idx = GPUArrays.@cartesianidx A
        @inbounds B[idx[2], idx[1], idx[3]] = f(A[idx[1], idx[2], idx[3]])
        return
    end
    return B
end

batched_adjoint!(B, A) = batched_transpose_f!(adjoint, B, A)
batched_transpose!(B, A) = batched_transpose_f!(transpose, B, A)

# copy                                                                                                                    
function Base.copy(x::BatchedAdjoint)
    p = parent(x)
    a1, a2, a3 = axes(p)
    return batched_adjoint!(similar(p, (a2, a1, a3)), p)
end
function Base.copy(x::BatchedTranspose)
    p = parent(x)
    a1, a2, a3 = axes(p)
    return batched_transpose!(similar(p, (a2, a1, a3)), p)
end

which requires an extra dependency on GPUArrays. I'm not sure where this code should live under ext.
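
With these defined, something like the following should avoid the scalar-indexing path entirely (a sketch, assuming CUDA.jl is loaded and the definitions above are in scope):

A = CUDA.randn(3, 5, 2)
B = similar(A, 5, 3, 2)      # preallocated output with the first two dims swapped
batched_adjoint!(B, A)       # fills B in place, no scalar indexing
copy(batched_adjoint(A))     # now allocates and dispatches to batched_adjoint!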

@CarloLucibello
Member

Do you want to file a PR? You can create a NNlibGPUArraysExt extension.
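
Roughly, that means declaring GPUArrays as a weak dependency and registering the extension in NNlib's Project.toml, plus an ext/NNlibGPUArraysExt.jl file holding the GPU methods. A sketch (the UUID would be copied from GPUArrays' own Project.toml, and the module name is just the suggestion above):

# Project.toml (sketch)
[weakdeps]
GPUArrays = "..."                  # UUID from GPUArrays' Project.toml

[extensions]
NNlibGPUArraysExt = "GPUArrays"

# ext/NNlibGPUArraysExt.jl (sketch)
module NNlibGPUArraysExt

using NNlib, GPUArrays

# the AnyGPUArray method of batched_transpose_f! proposed above would go here

end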

@chengchingwen
Member Author

Would NNlibGPUArraysExt be loaded along with NNlibCUDAExt or NNlibAMDGPUExt?

It is also possible to rewrite batched_transpose_f! with KernelAbstractions so that the extra dependency is avoided.

@CarloLucibello
Member

Since both CUDA and AMDGPU will load GPUArrays, the GPUArraysExt will be loaded as well when doing using NNlib, CUDA. That said, using KernelAbstractions seems the best way to do it.

@chengchingwen
Member Author

I'm not familiar with KernelAbstractions. If anyone wants to try, feel free to take over.

@ToucheSir
Member

Here's an untested translation to KernelAbstractions. I made some minor style changes to better match existing KA-using functions like gather! and scatter!. If anyone has a test suite to run this on, feel free to steal the code below for a PR.

using KernelAbstractions
using GPUArrays: AnyGPUArray

function batched_transpose_f!(f, dst::AnyGPUArray{<:Any, 3}, src::AnyGPUArray{<:Any, 3})
    axes(dst, 1) == axes(src, 2) && axes(dst, 2) == axes(src, 1) && axes(src, 3) == axes(dst, 3) || throw(DimensionMismatch(string(f)))
    backend = KernelAbstractions.get_backend(src)
    _batched_transpose_f!(backend)(f, dst, src; ndrange=size(src))
    return dst
end

# one work-item per (i, j, k) element of src
@kernel function _batched_transpose_f!(f::F, dst, @Const(src)) where F
    i, j, k = @index(Global, NTuple)
    @inbounds dst[j, i, k] = f(src[i, j, k])
end
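
A quick sanity check for the translation, assuming the definitions above are in scope and CUDA.jl is loaded (shapes are arbitrary):

using CUDA
A = CUDA.randn(3, 5, 2)
B = similar(A, 5, 3, 2)
batched_transpose_f!(transpose, B, A)
@assert Array(B) == permutedims(Array(A), (2, 1, 3))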

@mcabbott
Member

Isn't this permutedims!, at least for real numbers?
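
For real element types the same in-place result could be had with (sketch):

permutedims!(B, A, (2, 1, 3))    # equivalent to batched_transpose!(B, A)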

@chengchingwen
Member Author

Yes, but we still need to overload Base.copy.
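
A sketch of how that overload could delegate to permutedims (the adjoint case still needs the conjugating kernel for complex element types):

Base.copy(x::BatchedTranspose) = permutedims(parent(x), (2, 1, 3))
# BatchedAdjoint can only take this shortcut when the eltype is real;
# otherwise the conjugation from batched_adjoint! is still required.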
