-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
936da13
commit 389957f
Showing
31 changed files
with
5,497 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
name = "GNNlib" | ||
uuid = "a6a84749-d869-43f8-aacc-be26a1996e48" | ||
authors = ["Carlo Lucibello and contributors"] | ||
version = "0.1.0" | ||
|
||
[deps] | ||
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" | ||
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | ||
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" | ||
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" | ||
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" | ||
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" | ||
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" | ||
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" | ||
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Reexport = "189a3867-3050-52da-a836-e630ba90ab69" | ||
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" | ||
|
||
[weakdeps] | ||
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" | ||
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622" | ||
|
||
[extensions] | ||
GNNlibCUDAExt = "CUDA" | ||
GNNlibSimpleWeightedGraphsExt = "SimpleWeightedGraphs" | ||
|
||
[compat] | ||
Adapt = "3, 4" | ||
CUDA = "4, 5" | ||
ChainRulesCore = "1" | ||
DataStructures = "0.18" | ||
Functors = "0.4.1" | ||
Graphs = "1.4" | ||
KrylovKit = "0.6, 0.7" | ||
LinearAlgebra = "1" | ||
MLDatasets = "0.7" | ||
MLUtils = "0.4" | ||
MacroTools = "0.5" | ||
NNlib = "0.9" | ||
NearestNeighbors = "0.4" | ||
Random = "1" | ||
Reexport = "1" | ||
SimpleWeightedGraphs = "1.4.0" | ||
SparseArrays = "1" | ||
Statistics = "1" | ||
StatsBase = "0.34" | ||
cuDNN = "1" | ||
julia = "1.9" | ||
|
||
[extras] | ||
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" | ||
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" | ||
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" | ||
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" | ||
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" | ||
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" | ||
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" | ||
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" | ||
|
||
[targets] | ||
test = ["Test", "Adapt", "DataFrames", "InlineStrings", "SimpleWeightedGraphs", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets", "CUDA", "cuDNN"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# GNNlib.jl | ||
|
||
This package contains a collection deep-learning framework agnostic | ||
building blocks for graph neural networks such as graph convolutional layers and the implementation | ||
of GraphGNN. | ||
|
||
In the future it will serve as the foundation of GraphNeuralNetworks.jl (based on Flux,jl). | ||
GNNlib.jl will be to GraphNeuralNetworks.jl what NNlib.jl is to Flux.jl and Lux.jl. | ||
|
||
This package is currently under development and may break frequentely. | ||
It is not meant for final users but for GNN libraries developers. | ||
Final user should use GraphNeuralNetworks.jl instead. | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
|
||
GNNGraphs._rand_dense_vector(A::CUMAT_T) = CUDA.randn(size(A, 1)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
|
||
GNNGraphs.dense_zeros_like(a::CUMAT_T, T::Type, sz = size(a)) = CUDA.zeros(T, sz) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
|
||
GNNGraphs.iscuarray(x::AnyCuArray) = true | ||
|
||
|
||
function sort_edge_index(u::AnyCuArray, v::AnyCuArray) | ||
#TODO proper cuda friendly implementation | ||
sort_edge_index(u |> Flux.cpu, v |> Flux.cpu) |> Flux.gpu | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
module GNNlibCUDAExt | ||
|
||
using CUDA | ||
using Random, Statistics, LinearAlgebra | ||
using GNNlib | ||
using GNNlib.GNNGraphs | ||
using GNNlib.GNNGraphs: COO_T, ADJMAT_T, SPARSE_T | ||
import GNNlib: propagate | ||
|
||
const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix} | ||
|
||
include("GNNGraphs/query.jl") | ||
include("GNNGraphs/transform.jl") | ||
include("GNNGraphs/utils.jl") | ||
include("msgpass.jl") | ||
|
||
end #module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
|
||
###### PROPAGATE SPECIALIZATIONS #################### | ||
|
||
## COPY_XJ | ||
|
||
## avoid the fast path on gpu until we have better cuda support | ||
function propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+), | ||
xi, xj::AnyCuMatrix, e) | ||
propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e) | ||
end | ||
|
||
## E_MUL_XJ | ||
|
||
## avoid the fast path on gpu until we have better cuda support | ||
function propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+), | ||
xi, xj::AnyCuMatrix, e::AbstractVector) | ||
propagate((xi, xj, e) -> e_mul_xj(xi, xj, e), g, +, xi, xj, e) | ||
end | ||
|
||
## W_MUL_XJ | ||
|
||
## avoid the fast path on gpu until we have better cuda support | ||
function propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+), | ||
xi, xj::AnyCuMatrix, e::Nothing) | ||
propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +, xi, xj, e) | ||
end | ||
|
||
# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e) | ||
# A = adjacency_matrix(g, weighted=false) | ||
# D = compute_degree(A) | ||
# return xj * A * D | ||
# end | ||
|
||
# # Zygote bug. Error with sparse matrix without nograd | ||
# compute_degree(A) = Diagonal(1f0 ./ vec(sum(A; dims=2))) | ||
|
||
# Flux.Zygote.@nograd compute_degree |
12 changes: 12 additions & 0 deletions
12
GNNlib/ext/GNNlibSimpleWeightedGraphsExt/GNNlibSimpleWeightedGraphsExt.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
module GNNlibSimpleWeightedGraphsExt | ||
|
||
using GNNlib | ||
using Graphs | ||
using SimpleWeightedGraphs | ||
|
||
function GNNlib.GNNGraph(g::T; kws...) where | ||
{T <: Union{SimpleWeightedGraph, SimpleWeightedDiGraph}} | ||
return GNNGraph(g.weights, kws...) | ||
end | ||
|
||
end #module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
module GNNGraphs | ||
|
||
using SparseArrays | ||
using Functors: @functor | ||
import Graphs | ||
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree, | ||
has_self_loops, is_directed | ||
import MLUtils | ||
using MLUtils: getobs, numobs, ones_like, zeros_like, batch | ||
import NearestNeighbors | ||
import NNlib | ||
import StatsBase | ||
import KrylovKit | ||
using ChainRulesCore | ||
using LinearAlgebra, Random, Statistics | ||
import MLUtils | ||
import Functors | ||
|
||
include("chainrules.jl") # hacks for differentiability | ||
|
||
include("datastore.jl") | ||
export DataStore | ||
|
||
include("abstracttypes.jl") | ||
export AbstractGNNGraph | ||
|
||
include("gnngraph.jl") | ||
export GNNGraph, | ||
node_features, | ||
edge_features, | ||
graph_features | ||
|
||
include("gnnheterograph.jl") | ||
export GNNHeteroGraph, | ||
num_edge_types, | ||
num_node_types, | ||
edge_type_subgraph | ||
|
||
include("temporalsnapshotsgnngraph.jl") | ||
export TemporalSnapshotsGNNGraph, | ||
add_snapshot, | ||
# add_snapshot!, | ||
remove_snapshot | ||
# remove_snapshot! | ||
|
||
include("query.jl") | ||
export adjacency_list, | ||
edge_index, | ||
get_edge_weight, | ||
graph_indicator, | ||
has_multi_edges, | ||
is_directed, | ||
is_bidirected, | ||
normalized_laplacian, | ||
scaled_laplacian, | ||
laplacian_lambda_max, | ||
# from Graphs | ||
adjacency_matrix, | ||
degree, | ||
has_self_loops, | ||
has_isolated_nodes, | ||
inneighbors, | ||
outneighbors, | ||
khop_adj | ||
|
||
include("transform.jl") | ||
export add_nodes, | ||
add_edges, | ||
add_self_loops, | ||
getgraph, | ||
negative_sample, | ||
rand_edge_split, | ||
remove_self_loops, | ||
remove_edges, | ||
remove_multi_edges, | ||
set_edge_weight, | ||
to_bidirected, | ||
to_unidirected, | ||
random_walk_pe, | ||
remove_nodes, | ||
# from Flux | ||
batch, | ||
unbatch, | ||
# from SparseArrays | ||
blockdiag | ||
|
||
include("generate.jl") | ||
export rand_graph, | ||
rand_heterograph, | ||
rand_bipartite_heterograph, | ||
knn_graph, | ||
radius_graph, | ||
rand_temporal_radius_graph, | ||
rand_temporal_hyperbolic_graph | ||
|
||
include("sampling.jl") | ||
export sample_neighbors | ||
|
||
include("operators.jl") | ||
# Base.intersect | ||
|
||
include("convert.jl") | ||
include("utils.jl") | ||
|
||
include("gatherscatter.jl") | ||
# _gather, _scatter | ||
|
||
end #module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
|
||
const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V} | ||
const ADJLIST_T = AbstractVector{T} where {T <: AbstractVector{<:Integer}} | ||
const ADJMAT_T = AbstractMatrix | ||
const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T | ||
|
||
const AVecI = AbstractVector{<:Integer} | ||
|
||
# All concrete graph types should be subtypes of AbstractGNNGraph{T}. | ||
# GNNGraph and GNNHeteroGraph are the two concrete types. | ||
abstract type AbstractGNNGraph{T} <: AbstractGraph{Int} end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Taken from https://github.com/JuliaDiff/ChainRules.jl/pull/648 | ||
# Remove when merged | ||
|
||
function ChainRulesCore.rrule(::Type{T}, ps::Pair...) where {T<:Dict} | ||
ks = map(first, ps) | ||
project_ks, project_vs = map(ProjectTo, ks), map(ProjectTo∘last, ps) | ||
function Dict_pullback(ȳ) | ||
dps = map(ks, project_ks, project_vs) do k, proj_k, proj_v | ||
dk, dv = proj_k(getkey(ȳ, k, NoTangent())), proj_v(get(ȳ, k, NoTangent())) | ||
Tangent{Pair{typeof(dk), typeof(dv)}}(first = dk, second = dv) | ||
end | ||
return (NoTangent(), dps...) | ||
end | ||
return T(ps...), Dict_pullback | ||
end |
Oops, something went wrong.