Skip to content

Commit 389957f

Browse files
authoredMay 4, 2024··
create GNNlib.jl (#432)
1 parent 936da13 commit 389957f

31 files changed

+5497
-0
lines changed
 

‎GNNlib/Project.toml

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
name = "GNNlib"
2+
uuid = "a6a84749-d869-43f8-aacc-be26a1996e48"
3+
authors = ["Carlo Lucibello and contributors"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
8+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
9+
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
10+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
11+
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
12+
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
13+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
14+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
15+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
16+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
17+
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
18+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
19+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
20+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
21+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
22+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
23+
24+
[weakdeps]
25+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
26+
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
27+
28+
[extensions]
29+
GNNlibCUDAExt = "CUDA"
30+
GNNlibSimpleWeightedGraphsExt = "SimpleWeightedGraphs"
31+
32+
[compat]
33+
Adapt = "3, 4"
34+
CUDA = "4, 5"
35+
ChainRulesCore = "1"
36+
DataStructures = "0.18"
37+
Functors = "0.4.1"
38+
Graphs = "1.4"
39+
KrylovKit = "0.6, 0.7"
40+
LinearAlgebra = "1"
41+
MLDatasets = "0.7"
42+
MLUtils = "0.4"
43+
MacroTools = "0.5"
44+
NNlib = "0.9"
45+
NearestNeighbors = "0.4"
46+
Random = "1"
47+
Reexport = "1"
48+
SimpleWeightedGraphs = "1.4.0"
49+
SparseArrays = "1"
50+
Statistics = "1"
51+
StatsBase = "0.34"
52+
cuDNN = "1"
53+
julia = "1.9"
54+
55+
[extras]
56+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
57+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
58+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
59+
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
60+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
61+
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
62+
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
63+
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
64+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
65+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
66+
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
67+
68+
[targets]
69+
test = ["Test", "Adapt", "DataFrames", "InlineStrings", "SimpleWeightedGraphs", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets", "CUDA", "cuDNN"]

‎GNNlib/README.md

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# GNNlib.jl
2+
3+
This package contains a collection deep-learning framework agnostic
4+
building blocks for graph neural networks such as graph convolutional layers and the implementation
5+
of GraphGNN.
6+
7+
In the future it will serve as the foundation of GraphNeuralNetworks.jl (based on Flux,jl).
8+
GNNlib.jl will be to GraphNeuralNetworks.jl what NNlib.jl is to Flux.jl and Lux.jl.
9+
10+
This package is currently under development and may break frequentely.
11+
It is not meant for final users but for GNN libraries developers.
12+
Final user should use GraphNeuralNetworks.jl instead.
13+
14+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
2+
GNNGraphs._rand_dense_vector(A::CUMAT_T) = CUDA.randn(size(A, 1))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
2+
GNNGraphs.dense_zeros_like(a::CUMAT_T, T::Type, sz = size(a)) = CUDA.zeros(T, sz)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
GNNGraphs.iscuarray(x::AnyCuArray) = true
3+
4+
5+
function sort_edge_index(u::AnyCuArray, v::AnyCuArray)
6+
#TODO proper cuda friendly implementation
7+
sort_edge_index(u |> Flux.cpu, v |> Flux.cpu) |> Flux.gpu
8+
end
+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
module GNNlibCUDAExt
2+
3+
using CUDA
4+
using Random, Statistics, LinearAlgebra
5+
using GNNlib
6+
using GNNlib.GNNGraphs
7+
using GNNlib.GNNGraphs: COO_T, ADJMAT_T, SPARSE_T
8+
import GNNlib: propagate
9+
10+
const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}
11+
12+
include("GNNGraphs/query.jl")
13+
include("GNNGraphs/transform.jl")
14+
include("GNNGraphs/utils.jl")
15+
include("msgpass.jl")
16+
17+
end #module

‎GNNlib/ext/GNNlibCUDAExt/msgpass.jl

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
2+
###### PROPAGATE SPECIALIZATIONS ####################
3+
4+
## COPY_XJ
5+
6+
## avoid the fast path on gpu until we have better cuda support
7+
function propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
8+
xi, xj::AnyCuMatrix, e)
9+
propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e)
10+
end
11+
12+
## E_MUL_XJ
13+
14+
## avoid the fast path on gpu until we have better cuda support
15+
function propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
16+
xi, xj::AnyCuMatrix, e::AbstractVector)
17+
propagate((xi, xj, e) -> e_mul_xj(xi, xj, e), g, +, xi, xj, e)
18+
end
19+
20+
## W_MUL_XJ
21+
22+
## avoid the fast path on gpu until we have better cuda support
23+
function propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
24+
xi, xj::AnyCuMatrix, e::Nothing)
25+
propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +, xi, xj, e)
26+
end
27+
28+
# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
29+
# A = adjacency_matrix(g, weighted=false)
30+
# D = compute_degree(A)
31+
# return xj * A * D
32+
# end
33+
34+
# # Zygote bug. Error with sparse matrix without nograd
35+
# compute_degree(A) = Diagonal(1f0 ./ vec(sum(A; dims=2)))
36+
37+
# Flux.Zygote.@nograd compute_degree
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
module GNNlibSimpleWeightedGraphsExt
2+
3+
using GNNlib
4+
using Graphs
5+
using SimpleWeightedGraphs
6+
7+
function GNNlib.GNNGraph(g::T; kws...) where
8+
{T <: Union{SimpleWeightedGraph, SimpleWeightedDiGraph}}
9+
return GNNGraph(g.weights, kws...)
10+
end
11+
12+
end #module

‎GNNlib/src/GNNGraphs/GNNGraphs.jl

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
module GNNGraphs
2+
3+
using SparseArrays
4+
using Functors: @functor
5+
import Graphs
6+
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree,
7+
has_self_loops, is_directed
8+
import MLUtils
9+
using MLUtils: getobs, numobs, ones_like, zeros_like, batch
10+
import NearestNeighbors
11+
import NNlib
12+
import StatsBase
13+
import KrylovKit
14+
using ChainRulesCore
15+
using LinearAlgebra, Random, Statistics
16+
import MLUtils
17+
import Functors
18+
19+
include("chainrules.jl") # hacks for differentiability
20+
21+
include("datastore.jl")
22+
export DataStore
23+
24+
include("abstracttypes.jl")
25+
export AbstractGNNGraph
26+
27+
include("gnngraph.jl")
28+
export GNNGraph,
29+
node_features,
30+
edge_features,
31+
graph_features
32+
33+
include("gnnheterograph.jl")
34+
export GNNHeteroGraph,
35+
num_edge_types,
36+
num_node_types,
37+
edge_type_subgraph
38+
39+
include("temporalsnapshotsgnngraph.jl")
40+
export TemporalSnapshotsGNNGraph,
41+
add_snapshot,
42+
# add_snapshot!,
43+
remove_snapshot
44+
# remove_snapshot!
45+
46+
include("query.jl")
47+
export adjacency_list,
48+
edge_index,
49+
get_edge_weight,
50+
graph_indicator,
51+
has_multi_edges,
52+
is_directed,
53+
is_bidirected,
54+
normalized_laplacian,
55+
scaled_laplacian,
56+
laplacian_lambda_max,
57+
# from Graphs
58+
adjacency_matrix,
59+
degree,
60+
has_self_loops,
61+
has_isolated_nodes,
62+
inneighbors,
63+
outneighbors,
64+
khop_adj
65+
66+
include("transform.jl")
67+
export add_nodes,
68+
add_edges,
69+
add_self_loops,
70+
getgraph,
71+
negative_sample,
72+
rand_edge_split,
73+
remove_self_loops,
74+
remove_edges,
75+
remove_multi_edges,
76+
set_edge_weight,
77+
to_bidirected,
78+
to_unidirected,
79+
random_walk_pe,
80+
remove_nodes,
81+
# from Flux
82+
batch,
83+
unbatch,
84+
# from SparseArrays
85+
blockdiag
86+
87+
include("generate.jl")
88+
export rand_graph,
89+
rand_heterograph,
90+
rand_bipartite_heterograph,
91+
knn_graph,
92+
radius_graph,
93+
rand_temporal_radius_graph,
94+
rand_temporal_hyperbolic_graph
95+
96+
include("sampling.jl")
97+
export sample_neighbors
98+
99+
include("operators.jl")
100+
# Base.intersect
101+
102+
include("convert.jl")
103+
include("utils.jl")
104+
105+
include("gatherscatter.jl")
106+
# _gather, _scatter
107+
108+
end #module

‎GNNlib/src/GNNGraphs/abstracttypes.jl

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
2+
const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V}
3+
const ADJLIST_T = AbstractVector{T} where {T <: AbstractVector{<:Integer}}
4+
const ADJMAT_T = AbstractMatrix
5+
const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
6+
7+
const AVecI = AbstractVector{<:Integer}
8+
9+
# All concrete graph types should be subtypes of AbstractGNNGraph{T}.
10+
# GNNGraph and GNNHeteroGraph are the two concrete types.
11+
abstract type AbstractGNNGraph{T} <: AbstractGraph{Int} end

‎GNNlib/src/GNNGraphs/chainrules.jl

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Taken from https://github.com/JuliaDiff/ChainRules.jl/pull/648
2+
# Remove when merged
3+
4+
function ChainRulesCore.rrule(::Type{T}, ps::Pair...) where {T<:Dict}
5+
ks = map(first, ps)
6+
project_ks, project_vs = map(ProjectTo, ks), map(ProjectTolast, ps)
7+
function Dict_pullback(ȳ)
8+
dps = map(ks, project_ks, project_vs) do k, proj_k, proj_v
9+
dk, dv = proj_k(getkey(ȳ, k, NoTangent())), proj_v(get(ȳ, k, NoTangent()))
10+
Tangent{Pair{typeof(dk), typeof(dv)}}(first = dk, second = dv)
11+
end
12+
return (NoTangent(), dps...)
13+
end
14+
return T(ps...), Dict_pullback
15+
end

0 commit comments

Comments
 (0)
Please sign in to comment.