From da9a6188652110346b50902f252d471724810ffe Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 8 Nov 2022 14:56:59 +0100 Subject: [PATCH 1/2] initial attempt --- .DS_Store | Bin 0 -> 8196 bytes Project.toml | 1 + docs/.DS_Store | Bin 0 -> 6148 bytes src/.DS_Store | Bin 0 -> 6148 bytes src/GNNGraphs/GNNGraphs.jl | 1 + src/GNNGraphs/convert.jl | 56 +++++++++++++++++++++++++++++++++---- src/GNNGraphs/gnngraph.jl | 4 ++- src/GNNGraphs/query.jl | 4 +-- src/GNNGraphs/utils.jl | 9 +++++- src/layers/conv.jl | 1 + test/.DS_Store | Bin 0 -> 6148 bytes test/runtests.jl | 2 +- test_graphblas.jl | 13 +++++++++ 13 files changed, 80 insertions(+), 11 deletions(-) create mode 100644 .DS_Store create mode 100644 docs/.DS_Store create mode 100644 src/.DS_Store create mode 100644 test/.DS_Store create mode 100644 test_graphblas.jl diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..1537fbafbe076b644e8692c1b63f2a91133b8f36 GIT binary patch literal 8196 zcmeHMzl#$=6n>LLvoR!tD{DbftQ4`+!Z$y%naPhUMZ6i9`6lz;+wa>iVfJPZ07#)*%>yg| zfQ6l5CXK_4#^S6`wU}Hvib${?yn4R0Nn62Grdco!7zd04#sTAiao}HZ0H4{cS#zHI zB5FqCfN@|~I>7b^8#_Zs;!q&pI&k1e0K^0?^M>;bn;_Z}9f?DM7{NwFDWWJP@`yo1 zIj-9>&yhG3D9VAz!v~SfM4nIxXUFT?JROK5(2T|b7R3h)X?6hN>rg@Z!m%OOYNeYpiE}mad2IVhL11 zQBCOf!a#ce(ha<3*vaMIlUO{lXK&IvLKU!}bQ?fTV_*Oo-UN?ECL=5W8iyp*+1pBZHBetvOj zF>9aA=LUn6b?Dfsb8F?hUdyKsnSZ$J%H0|Ei_g)xjkD->G-KE84{A)c^!xqat5<*Kv*Mw4es}paLxjz{7c{=_2QU&B{^roygX#s{#MY zr>>>y+$m4fO9G@sy#cCe!yQczJiV;OR2^@ZBOOt#%0DZv@xnXC8^@o18#Xmo^6mII zqH;dJ5pl{om2o<%@x_LWF999q`oy+F4t!N(D(AEG$L(m_k+mx8thmNY_7zd04|EB{aQ7jb;cyFV>-X?D|h6^_K+t^uSc|(D`1P7MF h*7)i=P#jw3`gGZSgM59Cm1pWwJ9^4Hg zpBFaIBR>_fjx-o8qcv=ICX=gMUGwslY1f?eH@7ys=CyurI&B$eFJ8KSYxp=ENBpVe zkl}_SL@lSCnp_HdLcJ_Rw-oI?g2hN#qwsjb!7dH4M&-)hzFm-C^j`p|ce`Fsw)<6oOm z=Npy+4M-t@-CC94f1F#YH66rHW=yKMY6XkdT0Z$*TgZ!KDZwA6i=3tP+^+)6y~L|p z$AzzNUVoBbbgAc@m7_hE)JL2gdr?csSU>hN(G$=NXa@c?1GGPwsDzfnQli*8un|`P z#3l|4!9Lw3D8^K1DJ&(T1%=5_L>bD|6@$re95Npl1I*QMs3PGDm9YjlEDG^6d^dA91gKjhff0cnBInCI^ literal 0 HcmV?d00001 diff --git a/src/.DS_Store b/src/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..2f4b7e5948d9d5199914a5652f771fdd1fb9faeb GIT binary patch literal 6148 zcmeHKy-veG4EB`@rDEyGc(2em2vJZsmWsWU7CIzK1VV`I78V|gi5YkW5bpt>?Ndoe zBQYUV*_H3R_~*{~62(0ta@oytLNq3#1j;xVVb~(s4sH)%gB1&*xAjo z9PLwq_v(&Ts1I*e!%vU*KmE;`-k#qdK5nI^)*5B?T}@w1s}6eQI!ufKW55{L6$a3= zS&{=q8;t>Dz!=ytz~2WCWlSRmMfr4KkRt$a2(t+0e3sxGA2E#>6ybq5O$BPIi$@Hn z>9G44mqrYVnodr3#yW9l7f&coc8A?dIJs2OMq|JjXfv>(E7Zf)sm`E!|EVtrAs0i$yC%`mfP=p0ye*}aE8;pTJ GW#9{9KT_QQ literal 0 HcmV?d00001 diff --git a/src/GNNGraphs/GNNGraphs.jl b/src/GNNGraphs/GNNGraphs.jl index a0baa83ee..14229b8e8 100644 --- a/src/GNNGraphs/GNNGraphs.jl +++ b/src/GNNGraphs/GNNGraphs.jl @@ -15,6 +15,7 @@ using ChainRulesCore using LinearAlgebra, Random, Statistics import MLUtils using MLUtils: getobs, numobs +using SuiteSparseGraphBLAS include("gnngraph.jl") export GNNGraph, diff --git a/src/GNNGraphs/convert.jl b/src/GNNGraphs/convert.jl index 8da7345a4..b460168ca 100644 --- a/src/GNNGraphs/convert.jl +++ b/src/GNNGraphs/convert.jl @@ -89,12 +89,12 @@ function to_dense(A::ADJMAT_T, T=nothing; dir=:out, num_nodes=nothing, weighted= if dir == :in A = A' end + if !weighted + A = binarize(A, T) + end if T != eltype(A) A = T.(A) end - if !weighted - A = map(x -> ifelse(x > 0, T(1), T(0)), A) - end return A, num_nodes, num_edges end @@ -154,15 +154,15 @@ function to_sparse(A::ADJMAT_T, T=nothing; dir=:out, num_nodes=nothing, weighted if dir == :in A = A' end + if !weighted + A = binarize(A, T) + end if T != eltype(A) A = T.(A) end if !(A isa AbstractSparseMatrix) A = sparse(A) end - if !weighted - A = map(x -> ifelse(x > 0, T(1), T(0)), A) - end return A, num_nodes, num_edges end @@ -187,3 +187,47 @@ function to_sparse(coo::COO_T, T=nothing; dir=:out, num_nodes=nothing, weighted= end return A, num_nodes, num_edges end + +# GBMatrix + +function to_graphblas(A::ADJMAT_T, T=nothing; dir=:out, num_nodes=nothing, weighted=true) + @assert dir ∈ [:out, :in] + num_nodes = size(A, 1) + @assert num_nodes == size(A, 2) + T = T === nothing ? eltype(A) : T + num_edges = A isa AbstractSparseMatrix ? nnz(A) : count(!=(0), A) + if dir == :in + A = A' + end + A = GBMatrix(A, fill=T(0)) + if !weighted + A = binarize(A, T) + end + if T != eltype(A) + A = T.(A) + end + + return A, num_nodes, num_edges +end + +function to_graphblas(adj_list::ADJLIST_T, T=nothing; dir=:out, num_nodes=nothing, weighted=true) + coo, num_nodes, num_edges = to_coo(adj_list; dir, num_nodes) + return to_graphblas(coo; num_nodes) +end + +function to_graphblas(coo::COO_T, T=nothing; dir=:out, num_nodes=nothing, weighted=true) + s, t, eweight = coo + T = T === nothing ? (eweight === nothing ? eltype(s) : eltype(eweight)) : T + + if eweight === nothing || !weighted + eweight = fill!(similar(s, T), 1) + end + + num_nodes::Int = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes + A = GBMatrix(s, t, eweight, num_nodes, num_nodes, fill=T(0)) + num_edges::Int = nnz(A) + if eltype(A) != T + A = T.(A) + end + return A, num_nodes, num_edges +end diff --git a/src/GNNGraphs/gnngraph.jl b/src/GNNGraphs/gnngraph.jl index ab4f3a789..efc3b61a5 100644 --- a/src/GNNGraphs/gnngraph.jl +++ b/src/GNNGraphs/gnngraph.jl @@ -127,7 +127,7 @@ function GNNGraph(data::D; gdata = (;), ) where D <: Union{COO_T, ADJMAT_T, ADJLIST_T} - @assert graph_type ∈ [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested" + @assert graph_type ∈ [:coo, :dense, :sparse, :graphblas] "Invalid graph_type $graph_type requested" @assert dir ∈ [:in, :out] if graph_type == :coo @@ -136,6 +136,8 @@ function GNNGraph(data::D; graph, num_nodes, num_edges = to_dense(data; num_nodes, dir) elseif graph_type == :sparse graph, num_nodes, num_edges = to_sparse(data; num_nodes, dir) + elseif graph_type == :graphblas + graph, num_nodes, num_edges = to_graphblas(data; num_nodes, dir) end num_graphs = !isnothing(graph_indicator) ? maximum(graph_indicator) : 1 diff --git a/src/GNNGraphs/query.jl b/src/GNNGraphs/query.jl index a537a90c4..a6c101c69 100644 --- a/src/GNNGraphs/query.jl +++ b/src/GNNGraphs/query.jl @@ -157,7 +157,7 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType=eltype(g); @assert dir ∈ [:in, :out] A = g.graph if !weighted - A = binarize(A) + A = binarize(A, T) end A = T != eltype(A) ? T.(A) : A return dir == :out ? A : A' @@ -232,7 +232,7 @@ function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T::TT=nothing; dir=:out, edge_we end A = adjacency_matrix(g) if edge_weight === false - A = binarize(A) + A = binarize(A, T) end A = eltype(A) != T ? T.(A) : A return dir == :out ? vec(sum(A, dims=2)) : diff --git a/src/GNNGraphs/utils.jl b/src/GNNGraphs/utils.jl index 0e5c498c0..bae152ba3 100644 --- a/src/GNNGraphs/utils.jl +++ b/src/GNNGraphs/utils.jl @@ -172,7 +172,14 @@ function edge_decoding(idx, n; directed=true) return s, t end -binarize(x) = map(>(0), x) +# binarize(x) = map(>(0), x) +binarize(x) = binarize(x, Bool) +binarize(x, T::Type{Bool}) = x .> 0 +binarize(x, T) = T.(x .> 0) + +binarize(x::GBMatrix, T::Type{Bool}) = x .> 0 +binarize(x::GBMatrix, T) = T.(binarize(x, Bool)) + @non_differentiable binarize(x...) @non_differentiable edge_encoding(x...) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index c1409c71f..75285e807 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -107,6 +107,7 @@ function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}, edge_weight::EW=nothing else x = propagate(copy_xj, g, +, xj=x) end + @show x c x = x .* c' if Dout >= Din x = l.weight * x diff --git a/test/.DS_Store b/test/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..91c07da1854603d93ee95abd4820bc7e35761ce0 GIT binary patch literal 6148 zcmeHKy-LJD5S~2`$)U)Vmg%kX075(>*NG4;1#>@mMMA=fsEwxcoh+=Ztb7O`z(??# zoh3haInhQ$X2R^3%+BmizFjg|A~K`vSwJ)-A`i;g8=&bB?q^+-p0n%%jUA()l;(7m zUIy`+w;KK;1H5(~*6fxN%BZn^(>yCA;l2AVpDd4}Nt#c%?;x3bXT{4=x8Lkre#Ih) z>p6Kp(tyTvf@@4mWLMBac2@26@Y?Ow%=jBgc^6mFcJ>C@BBiCrJ#h9u`^>F zH?wgk6k}(H-M8*!LP2en0cD`gz?M7?x&9w-KL59a^hp^|2L2TTrXR)82(RR7>)Ok4 ut@WW7P!{$p1*;I;*isB%F2(y$E3o@~0}LIOg0MjJN5ImcjWY17415BXw{Ua- literal 0 HcmV?d00001 diff --git a/test/runtests.jl b/test/runtests.jl index 6b08fb50f..decc16d96 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -42,7 +42,7 @@ tests = [ !CUDA.functional() && @warn("CUDA unavailable, not testing GPU support") -@testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:coo, :dense, :sparse) +@testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:graphblas, :coo, :dense, :sparse) global GRAPH_T = graph_type global TEST_GPU = CUDA.functional() && (GRAPH_T != :sparse) diff --git a/test_graphblas.jl b/test_graphblas.jl new file mode 100644 index 000000000..db3edfaf9 --- /dev/null +++ b/test_graphblas.jl @@ -0,0 +1,13 @@ +using GraphNeuralNetworks +using SuiteSparseGraphBLAS +using LinearAlgebra, SparseArrays + +g = rand_graph(10, 20, graph_type=:graphblas) +x = rand(2, 10) +m = GCNConv(2 => 3) +A = adjacency_matrix(g) +@assert A isa GBMatrix +@assert A + I isa GBMatrix +@assert Float32.(A) isa GBMatrix + +m(g, x) \ No newline at end of file From 8d2ded08cc4132d5b3056de5c5b13bb6657e7cf9 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 8 Nov 2022 14:58:16 +0100 Subject: [PATCH 2/2] cleanup --- .DS_Store | Bin 8196 -> 0 bytes src/.DS_Store | Bin 6148 -> 0 bytes test/.DS_Store | Bin 6148 -> 0 bytes 3 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 .DS_Store delete mode 100644 src/.DS_Store delete mode 100644 test/.DS_Store diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 1537fbafbe076b644e8692c1b63f2a91133b8f36..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8196 zcmeHMzl#$=6n>LLvoR!tD{DbftQ4`+!Z$y%naPhUMZ6i9`6lz;+wa>iVfJPZ07#)*%>yg| zfQ6l5CXK_4#^S6`wU}Hvib${?yn4R0Nn62Grdco!7zd04#sTAiao}HZ0H4{cS#zHI zB5FqCfN@|~I>7b^8#_Zs;!q&pI&k1e0K^0?^M>;bn;_Z}9f?DM7{NwFDWWJP@`yo1 zIj-9>&yhG3D9VAz!v~SfM4nIxXUFT?JROK5(2T|b7R3h)X?6hN>rg@Z!m%OOYNeYpiE}mad2IVhL11 zQBCOf!a#ce(ha<3*vaMIlUO{lXK&IvLKU!}bQ?fTV_*Oo-UN?ECL=5W8iyp*+1pBZHBetvOj zF>9aA=LUn6b?Dfsb8F?hUdyKsnSZ$J%H0|Ei_g)xjkD->G-KE84{A)c^!xqat5<*Kv*Mw4es}paLxjz{7c{=_2QU&B{^roygX#s{#MY zr>>>y+$m4fO9G@sy#cCe!yQczJiV;OR2^@ZBOOt#%0DZv@xnXC8^@o18#Xmo^6mII zqH;dJ5pl{om2o<%@x_LWF999q`oy+F4t!N(D(AEG$L(m_k+mx8thmNY_7zd04|EB{aQ7jb;cyFV>-X?D|h6^_K+t^uSc|(D`1P7MF h*7?Ndoe zBQYUV*_H3R_~*{~62(0ta@oytLNq3#1j;xVVb~(s4sH)%gB1&*xAjo z9PLwq_v(&Ts1I*e!%vU*KmE;`-k#qdK5nI^)*5B?T}@w1s}6eQI!ufKW55{L6$a3= zS&{=q8;t>Dz!=ytz~2WCWlSRmMfr4KkRt$a2(t+0e3sxGA2E#>6ybq5O$BPIi$@Hn z>9G44mqrYVnodr3#yW9l7f&coc8A?dIJs2OMq|JjXfv>(E7Zf)sm`E!|EVtrAs0i$yC%`mfP=p0ye*}aE8;pTJ GW#9{9KT_QQ diff --git a/test/.DS_Store b/test/.DS_Store deleted file mode 100644 index 91c07da1854603d93ee95abd4820bc7e35761ce0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKy-LJD5S~2`$)U)Vmg%kX075(>*NG4;1#>@mMMA=fsEwxcoh+=Ztb7O`z(??# zoh3haInhQ$X2R^3%+BmizFjg|A~K`vSwJ)-A`i;g8=&bB?q^+-p0n%%jUA()l;(7m zUIy`+w;KK;1H5(~*6fxN%BZn^(>yCA;l2AVpDd4}Nt#c%?;x3bXT{4=x8Lkre#Ih) z>p6Kp(tyTvf@@4mWLMBac2@26@Y?Ow%=jBgc^6mFcJ>C@BBiCrJ#h9u`^>F zH?wgk6k}(H-M8*!LP2en0cD`gz?M7?x&9w-KL59a^hp^|2L2TTrXR)82(RR7>)Ok4 ut@WW7P!{$p1*;I;*isB%F2(y$E3o@~0}LIOg0MjJN5ImcjWY17415BXw{Ua-