Skip to content

Commit

Permalink
use GNNlib in GNN.jl (#464)
Browse files Browse the repository at this point in the history
* use GNNlib in GNN.jl

* cleanup

* ported all graph convs

* workflow

* fix

* fix gcn_con

* fix gcn_con

* add comments
  • Loading branch information
CarloLucibello authored Jul 28, 2024
1 parent a9700f9 commit 9338ed7
Show file tree
Hide file tree
Showing 12 changed files with 265 additions and 1,212 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_GraphNeuralNetworks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
# dev mono repo versions
pkg"registry up"
Pkg.update()
pkg"dev ./GNNGraphs ."
pkg"dev ./GNNGraphs ./GNNlib ."
Pkg.test("GraphNeuralNetworks"; coverage=true)
- uses: julia-actions/julia-processcoverage@v1
with:
Expand Down
107 changes: 51 additions & 56 deletions GNNlib/src/GNNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,70 +12,65 @@ using .GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
check_num_nodes, check_num_edges,
EType, NType # for heteroconvs

export
# utils
reduce_nodes,
reduce_edges,
softmax_nodes,
softmax_edges,
broadcast_nodes,
broadcast_edges,
softmax_edge_neighbors,
# msgpass
apply_edges,
aggregate_neighbors,
propagate,
copy_xj,
copy_xi,
xi_dot_xj,
xi_sub_xj,
xj_sub_xi,
e_mul_xj,
w_mul_xj
include("utils.jl")
export reduce_nodes,
reduce_edges,
softmax_nodes,
softmax_edges,
broadcast_nodes,
broadcast_edges,
softmax_edge_neighbors

include("msgpass.jl")
export apply_edges,
aggregate_neighbors,
propagate,
copy_xj,
copy_xi,
xi_dot_xj,
xi_sub_xj,
xj_sub_xi,
e_mul_xj,
w_mul_xj

## The following methods are defined but not exported

# # layers/basic
# dot_decoder,

# # layers/conv
# agnn_conv,
# cg_conv,
# cheb_conv,
# edge_conv,
# egnn_conv,
# gat_conv,
# gatv2_conv,
# gated_graph_conv,
# gcn_conv,
# gin_conv,
# gmm_conv,
# graph_conv,
# megnet_conv,
# nn_conv,
# res_gated_graph_conv,
# sage_conv,
# sg_conv,
# transformer_conv,
include("layers/basic.jl")
export dot_decoder

# # layers/temporalconv
# a3tgcn_conv,
include("layers/conv.jl")
export agnn_conv,
cg_conv,
cheb_conv,
d_conv,
edge_conv,
egnn_conv,
gat_conv,
gatv2_conv,
gated_graph_conv,
gcn_conv,
gin_conv,
gmm_conv,
graph_conv,
megnet_conv,
nn_conv,
res_gated_graph_conv,
sage_conv,
sg_conv,
tag_conv,
transformer_conv

# # layers/pool
# global_pool,
# global_attention_pool,
# set2set_pool,
# topk_pool,
# topk_index,
include("layers/temporalconv.jl")
export a3tgcn_conv

include("layers/pool.jl")
export global_pool,
global_attention_pool,
set2set_pool,
topk_pool,
topk_index

include("utils.jl")
include("layers/basic.jl")
include("layers/conv.jl")
# include("layers/heteroconv.jl") # no functional part at the moment
include("layers/temporalconv.jl")
include("layers/pool.jl")
include("msgpass.jl")

end #module

Loading

0 comments on commit 9338ed7

Please sign in to comment.