Skip to content

Commit dce4e4b

Browse files
fix cuda ext
1 parent 80c672a commit dce4e4b

File tree

3 files changed

+16
-19
lines changed

3 files changed

+16
-19
lines changed

GNNlib/Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GNNlib"
22
uuid = "a6a84749-d869-43f8-aacc-be26a1996e48"
33
authors = ["Carlo Lucibello and contributors"]
4-
version = "0.2.0-DEV"
4+
version = "0.2.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,37 @@
1+
module GNNlibCUDAExt
2+
3+
using CUDA
4+
using Random, Statistics, LinearAlgebra
5+
using GNNlib: GNNlib, propagate, copy_xj, e_mul_xj, w_mul_xj
6+
using GNNGraphs: GNNGraph, COO_T, SPARSE_T
17

28
###### PROPAGATE SPECIALIZATIONS ####################
39

410
## COPY_XJ
511

612
## avoid the fast path on gpu until we have better cuda support
7-
function propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
8-
xi, xj::AnyCuMatrix, e)
13+
function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
14+
xi, xj::AnyCuMatrix, e)
915
propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e)
1016
end
1117

1218
## E_MUL_XJ
1319

1420
## avoid the fast path on gpu until we have better cuda support
15-
function propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
16-
xi, xj::AnyCuMatrix, e::AbstractVector)
21+
function GNNlib.propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
22+
xi, xj::AnyCuMatrix, e::AbstractVector)
1723
propagate((xi, xj, e) -> e_mul_xj(xi, xj, e), g, +, xi, xj, e)
1824
end
1925

2026
## W_MUL_XJ
2127

2228
## avoid the fast path on gpu until we have better cuda support
23-
function propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
24-
xi, xj::AnyCuMatrix, e::Nothing)
29+
function GNNlib.propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
30+
xi, xj::AnyCuMatrix, e::Nothing)
2531
propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +, xi, xj, e)
2632
end
2733

28-
# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
34+
# function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
2935
# A = adjacency_matrix(g, weighted=false)
3036
# D = compute_degree(A)
3137
# return xj * A * D
@@ -35,3 +41,5 @@ end
3541
# compute_degree(A) = Diagonal(1f0 ./ vec(sum(A; dims=2)))
3642

3743
# Flux.Zygote.@nograd compute_degree
44+
45+
end #module

GNNlib/ext/GNNlibCUDAExt/GNNlibCUDAExt.jl

-11
This file was deleted.

0 commit comments

Comments
 (0)