module GNNlibCUDAExt

using CUDA
using Random, Statistics, LinearAlgebra
using GNNlib: GNNlib, propagate, copy_xj, e_mul_xj, w_mul_xj
using GNNGraphs: GNNGraph, COO_T, SPARSE_T

###### PROPAGATE SPECIALIZATIONS ####################

## COPY_XJ

## avoid the fast path on gpu until we have better cuda support
# CUDA method of `propagate` for the `copy_xj` message with `+` aggregation.
# NOTE: the method must be defined as `GNNlib.propagate` (qualified) so it extends
# the parent function instead of shadowing it with a new local `propagate`.
# Wrapping `copy_xj` in an anonymous function defeats dispatch on `typeof(copy_xj)`,
# forcing the generic gather/scatter path rather than the sparse-matmul fast path,
# which is not yet well supported on CUDA.
function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}},
                          ::typeof(+), xi, xj::AnyCuMatrix, e)
    return propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e)
end
|
11 | 17 |
|
12 | 18 | ## E_MUL_XJ
|
13 | 19 |
|
14 | 20 | ## avoid the fast path on gpu until we have better cuda support
|
15 |
# CUDA method of `propagate` for the `e_mul_xj` message (edge-weight scaling of xj)
# with `+` aggregation; `e` is the per-edge weight vector.
# NOTE: must be defined as the qualified `GNNlib.propagate` so it extends the parent
# function instead of shadowing it. The anonymous-function wrapper defeats dispatch
# on `typeof(e_mul_xj)`, avoiding the sparse fast path that lacks good CUDA support.
function GNNlib.propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}},
                          ::typeof(+), xi, xj::AnyCuMatrix, e::AbstractVector)
    return propagate((xi, xj, e) -> e_mul_xj(xi, xj, e), g, +, xi, xj, e)
end
|
19 | 25 |
|
20 | 26 | ## W_MUL_XJ
|
21 | 27 |
|
22 | 28 | ## avoid the fast path on gpu until we have better cuda support
|
23 |
# CUDA method of `propagate` for the `w_mul_xj` message (graph-stored edge weights,
# hence `e::Nothing`) with `+` aggregation.
# NOTE: must be defined as the qualified `GNNlib.propagate` so it extends the parent
# function instead of shadowing it. The anonymous-function wrapper defeats dispatch
# on `typeof(w_mul_xj)`, avoiding the sparse fast path that lacks good CUDA support.
function GNNlib.propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}},
                          ::typeof(+), xi, xj::AnyCuMatrix, e::Nothing)
    return propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +, xi, xj, e)
end
|
27 | 33 |
|
28 |
| -# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e) |
| 34 | +# function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e) |
29 | 35 | # A = adjacency_matrix(g, weighted=false)
|
30 | 36 | # D = compute_degree(A)
|
31 | 37 | # return xj * A * D
|
|
35 | 41 | # compute_degree(A) = Diagonal(1f0 ./ vec(sum(A; dims=2)))
|
36 | 42 |
|
37 | 43 | # Flux.Zygote.@nograd compute_degree

end #module