Skip to content

Commit

Permalink
add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Jul 28, 2024
1 parent ffdb7e2 commit 1ab21cd
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions GNNlib/src/layers/conv.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
####################### GCNConv ######################################

check_gcnconv_input(g::AbstractGNNGraph{<:ADJMAT_T}, edge_weight::AbstractVector) =
throw(ArgumentError("Providing external edge_weight is not yet supported for adjacency matrix graphs"))
Expand Down Expand Up @@ -77,6 +78,8 @@ function gcn_conv(l, g::GNNGraph{<:ADJMAT_T}, x, edge_weight::EW, norm_fn::F, co
return gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight)
end

####################### ChebConv ######################################

function cheb_conv(l, g::GNNGraph, X::AbstractMatrix{T}) where {T}
check_num_nodes(g, X)
@assert size(X, 1) == size(l.weight, 2) "Input feature size must match input channel size."
Expand All @@ -94,6 +97,8 @@ function cheb_conv(l, g::GNNGraph, X::AbstractMatrix{T}) where {T}
return Y .+ l.bias
end

####################### GraphConv ######################################

function graph_conv(l, g::AbstractGNNGraph, x)
check_num_nodes(g, x)
xj, xi = expand_srcdst(g, x)
Expand All @@ -102,6 +107,8 @@ function graph_conv(l, g::AbstractGNNGraph, x)
return l.σ.(x .+ l.bias)
end

####################### GATConv ######################################

function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing)
check_num_nodes(g, x)
@assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer"
Expand Down Expand Up @@ -158,6 +165,8 @@ function gat_message(l, Wxi, Wxj, e)
return (; logα, Wxj)
end

####################### GATv2Conv ######################################

function gatv2_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing)
check_num_nodes(g, x)
@assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer"
Expand Down Expand Up @@ -202,6 +211,7 @@ function gatv2_message(l, Wxi, Wxj, e)
return (; logα, Wxj)
end

####################### GatedGraphConv ######################################

# TODO PIRACY! remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521
@non_differentiable fill!(x...)
Expand All @@ -222,6 +232,8 @@ function gated_graph_conv(l, g::GNNGraph, H::AbstractMatrix{S}) where {S <: Real
return H
end

####################### EdgeConv ######################################

function edge_conv(l, g::AbstractGNNGraph, x)
check_num_nodes(g, x)
xj, xi = expand_srcdst(g, x)
Expand All @@ -233,6 +245,7 @@ end

edge_conv_message(l, xi, xj, e) = l.nn(vcat(xi, xj .- xi))

####################### GINConv ######################################

function gin_conv(l, g::AbstractGNNGraph, x)
check_num_nodes(g, x)
Expand All @@ -243,6 +256,8 @@ function gin_conv(l, g::AbstractGNNGraph, x)
return l.nn((1 .+ ofeltype(xi, l.ϵ)) .* xi .+ m)
end

####################### NNConv ######################################

function nn_conv(l, g::GNNGraph, x::AbstractMatrix, e)
check_num_nodes(g, x)
message = Fix1(nn_conv_message, l)
Expand All @@ -258,6 +273,8 @@ function nn_conv_message(l, xi, xj, e)
return reshape(m, :, nedges)
end

####################### SAGEConv ######################################

function sage_conv(l, g::AbstractGNNGraph, x)
check_num_nodes(g, x)
xj, xi = expand_srcdst(g, x)
Expand All @@ -266,6 +283,8 @@ function sage_conv(l, g::AbstractGNNGraph, x)
return x
end

####################### ResGatedConv ######################################

function res_gated_graph_conv(l, g::AbstractGNNGraph, x)
check_num_nodes(g, x)
xj, xi = expand_srcdst(g, x)
Expand All @@ -281,6 +300,8 @@ function res_gated_graph_conv(l, g::AbstractGNNGraph, x)
return l.σ.(l.U * xi .+ m .+ l.bias)
end

####################### CGConv ######################################

function cg_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing)
check_num_nodes(g, x)
xj, xi = expand_srcdst(g, x)
Expand Down Expand Up @@ -312,6 +333,7 @@ function cg_message(l, xi, xj, e)
return l.dense_f(z) .* l.dense_s(z)
end

####################### AGNNConv ######################################

function agnn_conv(l, g::GNNGraph, x::AbstractMatrix)
check_num_nodes(g, x)
Expand All @@ -330,6 +352,8 @@ function agnn_conv(l, g::GNNGraph, x::AbstractMatrix)
return x
end

####################### MegNetConv ######################################

function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
check_num_nodes(g, x)

Expand All @@ -344,6 +368,8 @@ function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
return x̄, ē
end

####################### GMMConv ######################################

function gmm_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
(nin, ein), out = l.ch #Notational Simplicity

Expand Down Expand Up @@ -375,6 +401,8 @@ function gmm_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
return m
end

####################### SGCConv ######################################

# this layer is not stable enough to be supported by GNNHeteroGraph type
# due to it's looping mechanism
function sgc_conv(l, g::GNNGraph, x::AbstractMatrix{T},
Expand Down Expand Up @@ -426,6 +454,8 @@ function sgc_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix,
return sgc_conv(l, g, x, edge_weight)
end

####################### EGNNGConv ######################################

function egnn_conv(l, g::GNNGraph, h::AbstractMatrix, x::AbstractMatrix, e = nothing)
if l.num_features.edge > 0
@assert e!==nothing "Edge features must be provided."
Expand Down Expand Up @@ -464,6 +494,8 @@ function egnn_message(l, xi, xj, e)
return (; x = msg_x, h = msg_h)
end

######################## SGConv ######################################

# this layer is not stable enough to be supported by GNNHeteroGraph type
# due to it's looping mechanism
function sg_conv(l, g::GNNGraph, x::AbstractMatrix{T},
Expand Down Expand Up @@ -515,6 +547,8 @@ function sg_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix,
return sg_conv(l, g, x, edge_weight)
end

######################## TransformerConv ######################################

function transformer_conv(l, g::GNNGraph, x::AbstractMatrix, e::Union{AbstractMatrix, Nothing} = nothing)
check_num_nodes(g, x)

Expand Down Expand Up @@ -594,6 +628,7 @@ function transformer_message_main(xi, xj, e)
end


######################## TAGConv ######################################

function tag_conv(l, g::GNNGraph, x::AbstractMatrix{T},
edge_weight::EW = nothing) where
Expand Down Expand Up @@ -654,6 +689,7 @@ function tag_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix,
return l(g, x, edge_weight)
end

######################## DConv ######################################

function d_conv(l, g::GNNGraph, x::AbstractMatrix)
#A = adjacency_matrix(g, weighted = true)
Expand Down

0 comments on commit 1ab21cd

Please sign in to comment.