Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create GNNlib.jl #432

Merged
merged 1 commit into from
May 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions GNNlib/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
name = "GNNlib"
uuid = "a6a84749-d869-43f8-aacc-be26a1996e48"
authors = ["Carlo Lucibello and contributors"]
version = "0.1.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"

[extensions]
GNNlibCUDAExt = "CUDA"
GNNlibSimpleWeightedGraphsExt = "SimpleWeightedGraphs"

[compat]
Adapt = "3, 4"
CUDA = "4, 5"
ChainRulesCore = "1"
DataStructures = "0.18"
Functors = "0.4.1"
Graphs = "1.4"
KrylovKit = "0.6, 0.7"
LinearAlgebra = "1"
MLDatasets = "0.7"
MLUtils = "0.4"
MacroTools = "0.5"
NNlib = "0.9"
NearestNeighbors = "0.4"
Random = "1"
Reexport = "1"
SimpleWeightedGraphs = "1.4.0"
SparseArrays = "1"
Statistics = "1"
StatsBase = "0.34"
cuDNN = "1"
julia = "1.9"

[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["Test", "Adapt", "DataFrames", "InlineStrings", "SimpleWeightedGraphs", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets", "CUDA", "cuDNN"]
14 changes: 14 additions & 0 deletions GNNlib/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# GNNlib.jl

This package contains a collection deep-learning framework agnostic
building blocks for graph neural networks such as graph convolutional layers and the implementation
of GraphGNN.

In the future it will serve as the foundation of GraphNeuralNetworks.jl (based on Flux,jl).
GNNlib.jl will be to GraphNeuralNetworks.jl what NNlib.jl is to Flux.jl and Lux.jl.

This package is currently under development and may break frequentely.
It is not meant for final users but for GNN libraries developers.
Final user should use GraphNeuralNetworks.jl instead.


2 changes: 2 additions & 0 deletions GNNlib/ext/GNNlibCUDAExt/GNNGraphs/query.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

GNNGraphs._rand_dense_vector(A::CUMAT_T) = CUDA.randn(size(A, 1))
2 changes: 2 additions & 0 deletions GNNlib/ext/GNNlibCUDAExt/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

GNNGraphs.dense_zeros_like(a::CUMAT_T, T::Type, sz = size(a)) = CUDA.zeros(T, sz)
8 changes: 8 additions & 0 deletions GNNlib/ext/GNNlibCUDAExt/GNNGraphs/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

GNNGraphs.iscuarray(x::AnyCuArray) = true


function sort_edge_index(u::AnyCuArray, v::AnyCuArray)
#TODO proper cuda friendly implementation
sort_edge_index(u |> Flux.cpu, v |> Flux.cpu) |> Flux.gpu
end
17 changes: 17 additions & 0 deletions GNNlib/ext/GNNlibCUDAExt/GNNlibCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module GNNlibCUDAExt

using CUDA
using Random, Statistics, LinearAlgebra
using GNNlib
using GNNlib.GNNGraphs
using GNNlib.GNNGraphs: COO_T, ADJMAT_T, SPARSE_T
import GNNlib: propagate

const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}

include("GNNGraphs/query.jl")
include("GNNGraphs/transform.jl")
include("GNNGraphs/utils.jl")
include("msgpass.jl")

end #module
37 changes: 37 additions & 0 deletions GNNlib/ext/GNNlibCUDAExt/msgpass.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@

###### PROPAGATE SPECIALIZATIONS ####################

## COPY_XJ

## avoid the fast path on gpu until we have better cuda support
function propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
xi, xj::AnyCuMatrix, e)
propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e)
end

## E_MUL_XJ

## avoid the fast path on gpu until we have better cuda support
function propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
xi, xj::AnyCuMatrix, e::AbstractVector)
propagate((xi, xj, e) -> e_mul_xj(xi, xj, e), g, +, xi, xj, e)
end

## W_MUL_XJ

## avoid the fast path on gpu until we have better cuda support
function propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
xi, xj::AnyCuMatrix, e::Nothing)
propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +, xi, xj, e)
end

# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
# A = adjacency_matrix(g, weighted=false)
# D = compute_degree(A)
# return xj * A * D
# end

# # Zygote bug. Error with sparse matrix without nograd
# compute_degree(A) = Diagonal(1f0 ./ vec(sum(A; dims=2)))

# Flux.Zygote.@nograd compute_degree
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module GNNlibSimpleWeightedGraphsExt

using GNNlib
using Graphs
using SimpleWeightedGraphs

function GNNlib.GNNGraph(g::T; kws...) where
{T <: Union{SimpleWeightedGraph, SimpleWeightedDiGraph}}
return GNNGraph(g.weights, kws...)
end

end #module
108 changes: 108 additions & 0 deletions GNNlib/src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
module GNNGraphs

using SparseArrays
using Functors: @functor
import Graphs
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree,
has_self_loops, is_directed
import MLUtils
using MLUtils: getobs, numobs, ones_like, zeros_like, batch
import NearestNeighbors
import NNlib
import StatsBase
import KrylovKit
using ChainRulesCore
using LinearAlgebra, Random, Statistics
import MLUtils
import Functors

include("chainrules.jl") # hacks for differentiability

include("datastore.jl")
export DataStore

include("abstracttypes.jl")
export AbstractGNNGraph

include("gnngraph.jl")
export GNNGraph,
node_features,
edge_features,
graph_features

include("gnnheterograph.jl")
export GNNHeteroGraph,
num_edge_types,
num_node_types,
edge_type_subgraph

include("temporalsnapshotsgnngraph.jl")
export TemporalSnapshotsGNNGraph,
add_snapshot,
# add_snapshot!,
remove_snapshot
# remove_snapshot!

include("query.jl")
export adjacency_list,
edge_index,
get_edge_weight,
graph_indicator,
has_multi_edges,
is_directed,
is_bidirected,
normalized_laplacian,
scaled_laplacian,
laplacian_lambda_max,
# from Graphs
adjacency_matrix,
degree,
has_self_loops,
has_isolated_nodes,
inneighbors,
outneighbors,
khop_adj

include("transform.jl")
export add_nodes,
add_edges,
add_self_loops,
getgraph,
negative_sample,
rand_edge_split,
remove_self_loops,
remove_edges,
remove_multi_edges,
set_edge_weight,
to_bidirected,
to_unidirected,
random_walk_pe,
remove_nodes,
# from Flux
batch,
unbatch,
# from SparseArrays
blockdiag

include("generate.jl")
export rand_graph,
rand_heterograph,
rand_bipartite_heterograph,
knn_graph,
radius_graph,
rand_temporal_radius_graph,
rand_temporal_hyperbolic_graph

include("sampling.jl")
export sample_neighbors

include("operators.jl")
# Base.intersect

include("convert.jl")
include("utils.jl")

include("gatherscatter.jl")
# _gather, _scatter

end #module
11 changes: 11 additions & 0 deletions GNNlib/src/GNNGraphs/abstracttypes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@

const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V}
const ADJLIST_T = AbstractVector{T} where {T <: AbstractVector{<:Integer}}
const ADJMAT_T = AbstractMatrix
const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T

const AVecI = AbstractVector{<:Integer}

# All concrete graph types should be subtypes of AbstractGNNGraph{T}.
# GNNGraph and GNNHeteroGraph are the two concrete types.
abstract type AbstractGNNGraph{T} <: AbstractGraph{Int} end
15 changes: 15 additions & 0 deletions GNNlib/src/GNNGraphs/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Taken from https://github.com/JuliaDiff/ChainRules.jl/pull/648
# Remove when merged

function ChainRulesCore.rrule(::Type{T}, ps::Pair...) where {T<:Dict}
ks = map(first, ps)
project_ks, project_vs = map(ProjectTo, ks), map(ProjectTo∘last, ps)
function Dict_pullback(ȳ)
dps = map(ks, project_ks, project_vs) do k, proj_k, proj_v
dk, dv = proj_k(getkey(ȳ, k, NoTangent())), proj_v(get(ȳ, k, NoTangent()))
Tangent{Pair{typeof(dk), typeof(dv)}}(first = dk, second = dv)
end
return (NoTangent(), dps...)
end
return T(ps...), Dict_pullback
end
Loading
Loading