Skip to content

Commit

Permalink
create GNNlib.jl (#432)
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello authored May 4, 2024
1 parent 936da13 commit 389957f
Show file tree
Hide file tree
Showing 31 changed files with 5,497 additions and 0 deletions.
69 changes: 69 additions & 0 deletions GNNlib/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
name = "GNNlib"
uuid = "a6a84749-d869-43f8-aacc-be26a1996e48"
authors = ["Carlo Lucibello and contributors"]
version = "0.1.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"

[extensions]
GNNlibCUDAExt = "CUDA"
GNNlibSimpleWeightedGraphsExt = "SimpleWeightedGraphs"

[compat]
Adapt = "3, 4"
CUDA = "4, 5"
ChainRulesCore = "1"
DataStructures = "0.18"
Functors = "0.4.1"
Graphs = "1.4"
KrylovKit = "0.6, 0.7"
LinearAlgebra = "1"
MLDatasets = "0.7"
MLUtils = "0.4"
MacroTools = "0.5"
NNlib = "0.9"
NearestNeighbors = "0.4"
Random = "1"
Reexport = "1"
SimpleWeightedGraphs = "1.4.0"
SparseArrays = "1"
Statistics = "1"
StatsBase = "0.34"
cuDNN = "1"
julia = "1.9"

[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["Test", "Adapt", "DataFrames", "InlineStrings", "SimpleWeightedGraphs", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets", "CUDA", "cuDNN"]
14 changes: 14 additions & 0 deletions GNNlib/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# GNNlib.jl

This package contains a collection of deep-learning-framework-agnostic
building blocks for graph neural networks, such as graph convolutional layers and the implementation
of `GNNGraph`.

In the future it will serve as the foundation of GraphNeuralNetworks.jl (based on Flux.jl).
GNNlib.jl will be to GraphNeuralNetworks.jl what NNlib.jl is to Flux.jl and Lux.jl.

This package is currently under development and may break frequently.
It is not meant for final users but for developers of GNN libraries.
Final users should use GraphNeuralNetworks.jl instead.


2 changes: 2 additions & 0 deletions GNNlib/ext/GNNlibCUDAExt/GNNGraphs/query.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

# CUDA specialization of `_rand_dense_vector`: for a CUDA dense or sparse matrix `A`
# (see `CUMAT_T`), draw a `CUDA.randn` vector with one entry per row of `A`.
function GNNGraphs._rand_dense_vector(A::CUMAT_T)
    nrows = size(A, 1)
    return CUDA.randn(nrows)
end
2 changes: 2 additions & 0 deletions GNNlib/ext/GNNlibCUDAExt/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

# CUDA specialization of `dense_zeros_like`: allocate a zero-filled array with
# `CUDA.zeros`, using element type `T` and size `sz` (defaults to the size of `a`).
function GNNGraphs.dense_zeros_like(a::CUMAT_T, T::Type, sz = size(a))
    return CUDA.zeros(T, sz)
end
8 changes: 8 additions & 0 deletions GNNlib/ext/GNNlibCUDAExt/GNNGraphs/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

# Any `AnyCuArray` argument is reported as a CUDA array.
function GNNGraphs.iscuarray(x::AnyCuArray)
    return true
end


# CUDA fallback for `GNNGraphs.sort_edge_index`.
#
# The edge index vectors `u` and `v` are copied to host memory, sorted there by the
# generic implementation, and the result is moved back to the GPU with `CUDA.cu`.
#
# This replaces the previous `Flux.cpu` / `Flux.gpu` round trip: Flux is not a
# dependency of this package and is not loaded in this extension module. The method
# is also attached to `GNNGraphs.sort_edge_index` (like the other specializations in
# this extension) instead of creating a new, unrelated local function.
# NOTE(review): assumes `GNNGraphs.sort_edge_index` is defined in GNNGraphs/utils.jl
# and returns arrays (or a tuple of arrays) — confirm.
function GNNGraphs.sort_edge_index(u::AnyCuArray, v::AnyCuArray)
    # TODO proper cuda friendly implementation
    sorted = GNNGraphs.sort_edge_index(Array(u), Array(v))
    return CUDA.cu(sorted)
end
17 changes: 17 additions & 0 deletions GNNlib/ext/GNNlibCUDAExt/GNNlibCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Package extension providing CUDA-specific methods for GNNlib.
# It is registered in Project.toml under `[extensions]` with CUDA as its trigger.
module GNNlibCUDAExt

using CUDA
using Random, Statistics, LinearAlgebra
using GNNlib
using GNNlib.GNNGraphs
using GNNlib.GNNGraphs: COO_T, ADJMAT_T, SPARSE_T
import GNNlib: propagate

# Any CUDA matrix this extension specializes on: dense (possibly wrapped) or sparse.
const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}

# The included files rely on the names brought into scope above (e.g. `CUMAT_T`),
# so they must come after the imports and the const definition.
include("GNNGraphs/query.jl")
include("GNNGraphs/transform.jl")
include("GNNGraphs/utils.jl")
include("msgpass.jl")

end #module
37 changes: 37 additions & 0 deletions GNNlib/ext/GNNlibCUDAExt/msgpass.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@

###### PROPAGATE SPECIALIZATIONS ####################

## COPY_XJ

## avoid the fast path on gpu until we have better cuda support
# `copy_xj` with `+` aggregation on CUDA node features.
# The message function is wrapped in a closure so the inner call no longer matches
# any `::typeof(copy_xj)` specialization of `propagate`.
function propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
                   xi, xj::AnyCuMatrix, e)
    wrapped_msg(x_i, x_j, e_ij) = copy_xj(x_i, x_j, e_ij)
    return propagate(wrapped_msg, g, +, xi, xj, e)
end

## E_MUL_XJ

## avoid the fast path on gpu until we have better cuda support
# `e_mul_xj` with `+` aggregation on CUDA node features and vector edge features.
# The message function is wrapped in a closure so the inner call no longer matches
# any `::typeof(e_mul_xj)` specialization of `propagate`.
function propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
                   xi, xj::AnyCuMatrix, e::AbstractVector)
    wrapped_msg(x_i, x_j, e_ij) = e_mul_xj(x_i, x_j, e_ij)
    return propagate(wrapped_msg, g, +, xi, xj, e)
end

## W_MUL_XJ

## avoid the fast path on gpu until we have better cuda support
# `w_mul_xj` with `+` aggregation on CUDA node features and no edge features.
# The message function is wrapped in a closure so the inner call no longer matches
# any `::typeof(w_mul_xj)` specialization of `propagate`.
function propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
                   xi, xj::AnyCuMatrix, e::Nothing)
    wrapped_msg(x_i, x_j, e_ij) = w_mul_xj(x_i, x_j, e_ij)
    return propagate(wrapped_msg, g, +, xi, xj, e)
end

# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
# A = adjacency_matrix(g, weighted=false)
# D = compute_degree(A)
# return xj * A * D
# end

# # Zygote bug. Error with sparse matrix without nograd
# compute_degree(A) = Diagonal(1f0 ./ vec(sum(A; dims=2)))

# Flux.Zygote.@nograd compute_degree
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Package extension adding a `GNNGraph` constructor for SimpleWeightedGraphs types.
# It is registered in Project.toml under `[extensions]` with SimpleWeightedGraphs
# as its trigger.
module GNNlibSimpleWeightedGraphsExt

using GNNlib
using Graphs
using SimpleWeightedGraphs

# Build a `GNNGraph` from the weight matrix of a `SimpleWeightedGraph` or
# `SimpleWeightedDiGraph`. All keyword arguments are forwarded unchanged to the
# `GNNGraph` constructor that accepts a matrix.
# NOTE(review): `g.weights` is passed as-is; confirm that its row/column orientation
# matches what `GNNGraph` expects for directed graphs.
function GNNlib.GNNGraph(g::T; kws...) where
         {T <: Union{SimpleWeightedGraph, SimpleWeightedDiGraph}}
    # The `;` is required: `GNNGraph(g.weights, kws...)` would splat the keyword
    # arguments as positional `Pair`s instead of forwarding them as keywords.
    return GNNGraph(g.weights; kws...)
end

end #module
108 changes: 108 additions & 0 deletions GNNlib/src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Graph data structures and graph-level utilities used by GNNlib.
# The module only wires things together: it loads dependencies, includes the
# implementation files, and declares the exported names next to each include.
module GNNGraphs

using SparseArrays
using Functors: @functor
import Graphs
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree,
              has_self_loops, is_directed
import MLUtils
using MLUtils: getobs, numobs, ones_like, zeros_like, batch
import NearestNeighbors
import NNlib
import StatsBase
import KrylovKit
using ChainRulesCore
using LinearAlgebra, Random, Statistics
import Functors

include("chainrules.jl") # hacks for differentiability

include("datastore.jl")
export DataStore

include("abstracttypes.jl")
export AbstractGNNGraph

include("gnngraph.jl")
export GNNGraph,
       node_features,
       edge_features,
       graph_features

include("gnnheterograph.jl")
export GNNHeteroGraph,
       num_edge_types,
       num_node_types,
       edge_type_subgraph

include("temporalsnapshotsgnngraph.jl")
export TemporalSnapshotsGNNGraph,
       add_snapshot,
       # add_snapshot!,
       remove_snapshot
       # remove_snapshot!

include("query.jl")
export adjacency_list,
       edge_index,
       get_edge_weight,
       graph_indicator,
       has_multi_edges,
       is_directed,
       is_bidirected,
       normalized_laplacian,
       scaled_laplacian,
       laplacian_lambda_max,
       # from Graphs
       adjacency_matrix,
       degree,
       has_self_loops,
       has_isolated_nodes,
       inneighbors,
       outneighbors,
       khop_adj

include("transform.jl")
export add_nodes,
       add_edges,
       add_self_loops,
       getgraph,
       negative_sample,
       rand_edge_split,
       remove_self_loops,
       remove_edges,
       remove_multi_edges,
       set_edge_weight,
       to_bidirected,
       to_unidirected,
       random_walk_pe,
       remove_nodes,
       # batching utilities (`batch` is imported from MLUtils above)
       batch,
       unbatch,
       # from SparseArrays
       blockdiag

include("generate.jl")
export rand_graph,
       rand_heterograph,
       rand_bipartite_heterograph,
       knn_graph,
       radius_graph,
       rand_temporal_radius_graph,
       rand_temporal_hyperbolic_graph

include("sampling.jl")
export sample_neighbors

include("operators.jl")
# Base.intersect

include("convert.jl")
include("utils.jl")

include("gatherscatter.jl")
# _gather, _scatter

end #module
11 changes: 11 additions & 0 deletions GNNlib/src/GNNGraphs/abstracttypes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@

# Coordinate (COO) representation: a 3-tuple whose first two entries are integer
# vectors of the same type `T`; the third entry `V` is unconstrained
# (presumably edge weights or `nothing` — confirm against the GNNGraph constructors).
const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V}
# Adjacency-list representation: a vector of integer vectors.
const ADJLIST_T = AbstractVector{T} where {T <: AbstractVector{<:Integer}}
# Adjacency-matrix representation: any `AbstractMatrix`.
const ADJMAT_T = AbstractMatrix
const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T

# Shorthand for any vector of integers.
const AVecI = AbstractVector{<:Integer}

# All concrete graph types should be subtypes of AbstractGNNGraph{T}.
# GNNGraph and GNNHeteroGraph are the two concrete types.
abstract type AbstractGNNGraph{T} <: AbstractGraph{Int} end
15 changes: 15 additions & 0 deletions GNNlib/src/GNNGraphs/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Taken from https://github.com/JuliaDiff/ChainRules.jl/pull/648
# Remove when merged

# Reverse-mode rule for building a `Dict` (or a subtype `T`) from `Pair`s.
#
# Returns the constructed dictionary `T(ps...)` together with a pullback. Given the
# cotangent `ȳ` of the dictionary, the pullback returns `NoTangent()` for the type
# argument followed by one `Tangent` per input pair, holding the projected cotangents
# of that pair's key (`first`) and value (`second`). Entries missing from `ȳ` fall
# back to `NoTangent()`.
function ChainRulesCore.rrule(::Type{T}, ps::Pair...) where {T<:Dict}
    ks = map(first, ps)
    # One projector per key and one per value. `ProjectTo ∘ last` builds the
    # projector from the value (second element) of each pair; the previous
    # `ProjectTolast` was an undefined identifier.
    project_ks, project_vs = map(ProjectTo, ks), map(ProjectTo ∘ last, ps)
    function Dict_pullback(ȳ)
        dps = map(ks, project_ks, project_vs) do k, proj_k, proj_v
            dk, dv = proj_k(getkey(ȳ, k, NoTangent())), proj_v(get(ȳ, k, NoTangent()))
            Tangent{Pair{typeof(dk), typeof(dv)}}(first = dk, second = dv)
        end
        return (NoTangent(), dps...)
    end
    return T(ps...), Dict_pullback
end
Loading

0 comments on commit 389957f

Please sign in to comment.